libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 namespace ebl { 00034 00036 // fstate_idx 00037 00038 template<typename T> 00039 fstate_idx<T>::~fstate_idx() { 00040 } 00041 00043 // constructors from specific dimensions using a fparameter 00044 00045 template<typename T> 00046 fstate_idx<T>::fstate_idx() { 00047 clear(); 00048 } 00049 00050 template<typename T> 00051 fstate_idx<T>::fstate_idx(intg s0) { 00052 x = idx<T>(s0); 00053 clear(); 00054 } 00055 00056 template<typename T> 00057 fstate_idx<T>::fstate_idx(intg s0, intg s1) { 00058 x = idx<T>(s0, s1); 00059 clear(); 00060 } 00061 00062 template<typename T> 00063 fstate_idx<T>::fstate_idx(intg s0, intg s1, intg s2) { 00064 x = idx<T>(s0, s1, s2); 00065 clear(); 00066 } 00067 00068 template<typename T> 00069 fstate_idx<T>::fstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4, 00070 intg s5, intg s6, intg s7) { 00071 x = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00072 clear(); 00073 } 00074 00075 template<typename T> 00076 fstate_idx<T>::fstate_idx(const idxdim &d) { 00077 x = idx<T>(d); 00078 clear(); 00079 } 00080 00081 template <typename T> 00082 fstate_idx<T>::fstate_idx(intg n, fstate_idx<T> &s) { 00083 idxdim d = s.x.get_idxdim(); 00084 d.setdims(n); 00085 x = idx<T>(d); 00086 clear(); 00087 } 00088 00090 // constructors from specific dimensions using a fparameter 00091 00092 template<typename T> 00093 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st) { 00094 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0); 00095 if (st) 00096 st->resize(st->footprint() + nelements()); 00097 clear(); 00098 } 00099 00100 template<typename T> 00101 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0) 00102 { 00103 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, s0); 00104 if (st) 00105 st->resize(st->footprint() + nelements()); 00106 clear(); 00107 } 00108 00109 template<typename T> 00110 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1) 00111 { 00112 x = idx<T>(st ? st->x.getstorage() : NULL, 00113 st ? st->x.footprint() : 0, s0, s1); 00114 if (st) 00115 st->resize(st->footprint() + nelements()); 00116 clear(); 00117 } 00118 00119 template<typename T> 00120 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, 00121 intg s2) { 00122 x = idx<T>(st ? st->x.getstorage() : NULL, 00123 st ? st->x.footprint() : 0, s0, s1, s2); 00124 if (st) 00125 st->resize(st->footprint() + nelements()); 00126 clear(); 00127 } 00128 00129 template<typename T> 00130 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st, intg s0, intg s1, 00131 intg s2, intg s3, intg s4, intg s5, intg s6, 00132 intg s7) { 00133 x = idx<T>(st ? st->x.getstorage() : NULL, 00134 st ? st->x.footprint() : 0, s0, s1, s2, s3, s4, s5, s6, s7); 00135 if (st) 00136 st->resize(st->footprint() + nelements()); 00137 clear(); 00138 } 00139 00140 template<typename T> 00141 fstate_idx<T>::fstate_idx(parameter<T,fstate_idx<T> > *st, const idxdim &d) 00142 { 00143 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, d); 00144 if (st) 00145 st->resize(st->footprint() + nelements()); 00146 clear(); 00147 } 00148 00150 // constructors from other fstate_idx's dimensions 00151 00152 template<typename T> 00153 fstate_idx<T>::fstate_idx(const idx<T> &_x) { 00154 x = idx<T>(_x); 00155 } 00156 00158 // clear methods 00159 00160 template <typename T> void fstate_idx<T>::clear() { 00161 idx_clear(x); 00162 } 00163 00164 template <typename T> void fstate_idx<T>::clear_x() { 00165 idx_clear(x); 00166 } 00167 00169 // information methods 00170 00171 template <typename T> intg fstate_idx<T>::nelements() { 00172 return x.nelements(); 00173 } 00174 00175 template <typename T> intg fstate_idx<T>::footprint() { 00176 return x.footprint(); 00177 } 00178 00179 template <typename T> intg fstate_idx<T>::size() { 00180 return x.footprint(); 00181 } 00182 00184 // resize methods 00185 00186 template <typename T> 00187 void fstate_idx<T>::resize(intg s0, intg s1, intg s2, 00188 intg s3, intg s4, intg s5, 00189 intg s6, intg s7) { 00190 if (!x.same_dim(s0, s1, s2, s3, s4, s5, s6, s7)) { // save some time 00191 x.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00192 } 00193 } 00194 00195 template <typename T> 00196 void fstate_idx<T>::resize(const idxdim &d) { 00197 if (!x.same_dim(d)) { // save some time if dimensions are the same 00198 x.resize(d); 00199 } 00200 } 00201 00202 template <typename T> 00203 void fstate_idx<T>::resize1(intg dimn, intg size) { 00204 if (x.dim(dimn) != size) { // save some time if size is already set. 00205 x.resize1(dimn, size); 00206 } 00207 } 00208 00209 template <typename T> 00210 void fstate_idx<T>::resize_as(fstate_idx<T>& s) { 00211 idxdim d(s.x.spec); // use same dimensions as s 00212 resize(d); 00213 } 00214 00215 template <typename T> 00216 void fstate_idx<T>::resize_as_but1(fstate_idx<T>& s, intg fixed_dim) { 00217 idxdim d(s.x.spec); // use same dimensions as s 00218 d.setdim(fixed_dim, x.dim(fixed_dim)); 00219 resize(d); 00220 } 00221 00224 00225 template <typename T> 00226 fstate_idx<T> fstate_idx<T>::select(int dimension, intg slice_index) { 00227 fstate_idx<T> s = *this; 00228 s.x = s.x.select(dimension, slice_index); 00229 return s; 00230 } 00231 00232 template <typename T> 00233 fstate_idx<T> fstate_idx<T>::narrow(int d, intg sz, intg o) { 00234 fstate_idx<T> s = *this; 00235 s.x = s.x.narrow(d, sz, o); 00236 return s; 00237 } 00238 00241 00242 template <typename T> 00243 fstate_idx<T> fstate_idx<T>::make_copy() { 00244 fstate_idx<T> other(x.get_idxdim()); 00245 idx_copy(x, other.x); 00246 return other; 00247 } 00248 00249 template <typename T> 00250 fstate_idx<T>& fstate_idx<T>::operator=(const fstate_idx<T>& other) { 00251 this->x = other.x; 00252 return *this; 00253 } 00254 00255 template <typename T> 00256 void fstate_idx<T>::copy(fstate_idx<T> &s) { 00257 idx_copy(s.x, x); 00258 } 00259 00262 00263 template <typename T> 00264 void fstate_idx<T>::pretty() { 00265 cout << "x: "; this->x.pretty(); 00266 } 00267 00268 template <typename T> 00269 void fstate_idx<T>::print() { 00270 cout << "x: "; this->x.print(); 00271 } 00272 00274 // bstate_idx 00275 00276 template<typename T> 00277 bstate_idx<T>::~bstate_idx() { 00278 } 00279 00281 // constructors from specific dimensions using a bparameter 00282 00283 template<typename T> 00284 bstate_idx<T>::bstate_idx() { 00285 clear(); 00286 } 00287 00288 template<typename T> 00289 bstate_idx<T>::bstate_idx(intg s0) { 00290 x = idx<T>(s0); 00291 dx = idx<T>(s0); 00292 clear(); 00293 } 00294 00295 template<typename T> 00296 bstate_idx<T>::bstate_idx(intg s0, intg s1) { 00297 x = idx<T>(s0, s1); 00298 dx = idx<T>(s0, s1); 00299 clear(); 00300 } 00301 00302 template<typename T> 00303 bstate_idx<T>::bstate_idx(intg s0, intg s1, intg s2) { 00304 x = idx<T>(s0, s1, s2); 00305 dx = idx<T>(s0, s1, s2); 00306 clear(); 00307 } 00308 00309 template<typename T> 00310 bstate_idx<T>::bstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4, 00311 intg s5, intg s6, intg s7) { 00312 x = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00313 dx = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00314 clear(); 00315 } 00316 00317 template<typename T> 00318 bstate_idx<T>::bstate_idx(const idxdim &d) { 00319 x = idx<T>(d); 00320 dx = idx<T>(d); 00321 clear(); 00322 } 00323 00324 template <typename T> 00325 bstate_idx<T>::bstate_idx(intg n, bstate_idx<T> &s) { 00326 idxdim d = s.x.get_idxdim(); 00327 d.setdims(n); 00328 x = idx<T>(d); 00329 dx = idx<T>(d); 00330 clear(); 00331 } 00332 00334 // constructors from specific dimensions using a bparameter 00335 00336 template<typename T> 00337 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st) { 00338 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0); 00339 dx = idx<T>(st ? st->dx.getstorage() : NULL, st ? st->dx.footprint() : 0); 00340 if (st) 00341 st->resize(st->footprint() + nelements()); 00342 clear(); 00343 } 00344 00345 template<typename T> 00346 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0) { 00347 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, s0); 00348 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00349 st ? st->dx.footprint() : 0, s0); 00350 if (st) 00351 st->resize(st->footprint() + nelements()); 00352 clear(); 00353 } 00354 00355 template<typename T> 00356 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1) { 00357 x = idx<T>(st ? st->x.getstorage() : NULL, 00358 st ? st->x.footprint() : 0, s0, s1); 00359 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00360 st ? st->dx.footprint() : 0, s0, s1); 00361 if (st) 00362 st->resize(st->footprint() + nelements()); 00363 clear(); 00364 } 00365 00366 template<typename T> 00367 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, 00368 intg s2) { 00369 x = idx<T>(st ? st->x.getstorage() : NULL, 00370 st ? st->x.footprint() : 0, s0, s1, s2); 00371 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00372 st ? st->dx.footprint() : 0, s0, s1, s2); 00373 if (st) 00374 st->resize(st->footprint() + nelements()); 00375 clear(); 00376 } 00377 00378 template<typename T> 00379 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st, intg s0, intg s1, 00380 intg s2, intg s3, intg s4, intg s5, 00381 intg s6, intg s7) { 00382 x = idx<T>(st ? st->x.getstorage() : NULL, 00383 st ? st->x.footprint() : 0, s0, s1, s2, s3, s4, s5, s6, s7); 00384 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00385 st ? st->dx.footprint : 0, s0, s1, s2, s3, s4, s5, s6,s7); 00386 if (st) 00387 st->resize(st->footprint() + nelements()); 00388 clear(); 00389 } 00390 00391 template<typename T> 00392 bstate_idx<T>::bstate_idx(parameter<T,bstate_idx<T> > *st, const idxdim &d) { 00393 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, d); 00394 dx = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, d); 00395 if (st) 00396 st->resize(st->footprint() + nelements()); 00397 clear(); 00398 } 00399 00401 // constructors from other bstate_idx's dimensions 00402 00403 template<typename T> 00404 bstate_idx<T>::bstate_idx(const idx<T> &_x, const idx<T> &_dx) { 00405 x = idx<T>(_x); 00406 dx = idx<T>(_dx); 00407 } 00408 00410 // clear methods 00411 00412 template <typename T> void bstate_idx<T>::clear() { 00413 idx_clear(x); 00414 idx_clear(dx); 00415 } 00416 00417 template <typename T> void bstate_idx<T>::clear_dx() { 00418 idx_clear(dx); 00419 } 00420 00422 // resize methods 00423 00424 template <typename T> 00425 void bstate_idx<T>::resize(intg s0, intg s1, intg s2, 00426 intg s3, intg s4, intg s5, 00427 intg s6, intg s7) { 00428 if (!x.same_dim(s0, s1, s2, s3, s4, s5, s6, s7)) { // save some time 00429 x.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00430 dx.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00431 } 00432 } 00433 00434 template <typename T> 00435 void bstate_idx<T>::resize(const idxdim &d) { 00436 if (!x.same_dim(d)) { // save some time if dimensions are the same 00437 x.resize(d); 00438 dx.resize(d); 00439 } 00440 } 00441 00442 template <typename T> 00443 void bstate_idx<T>::resize1(intg dimn, intg size) { 00444 if (x.dim(dimn) != size) { // save some time if size is already set. 00445 x.resize1(dimn, size); 00446 dx.resize1(dimn, size); 00447 } 00448 } 00449 00450 template <typename T> 00451 void bstate_idx<T>::resize_as(bstate_idx<T>& s) { 00452 idxdim d(s.x.spec); // use same dimensions as s 00453 resize(d); 00454 } 00455 00456 template <typename T> 00457 void bstate_idx<T>::resize_as_but1(bstate_idx<T>& s, intg fixed_dim) { 00458 idxdim d(s.x.spec); // use same dimensions as s 00459 d.setdim(fixed_dim, x.dim(fixed_dim)); 00460 resize(d); 00461 } 00462 00463 // template <typename T> 00464 // void bstate_idx<T>::resize(const intg* dimsBegin, const intg* dimsEnd) { 00465 // x.resize(dimsBegin, dimsEnd); 00466 // dx.resize(dimsBegin, dimsEnd); 00467 // ddx.resize(dimsBegin, dimsEnd); 00468 // } 00469 00470 template <typename T> 00471 void bstate_idx<T>::update_gd(gd_param &arg) { 00472 if (arg.decay_l2 > 0) { 00473 idx_dotcacc(x, arg.decay_l2, dx); 00474 } 00475 if (arg.decay_l1 > 0) { 00476 idx_signdotcacc(x, (T) arg.decay_l1, dx); 00477 } 00478 idx_dotcacc(dx, -arg.eta, x); 00479 } 00480 00483 00484 template <typename T> 00485 bstate_idx<T> bstate_idx<T>::select(int dimension, intg slice_index) { 00486 bstate_idx<T> s = *this; 00487 s.x = s.x.select(dimension, slice_index); 00488 s.dx = s.dx.select(dimension, slice_index); 00489 return s; 00490 } 00491 00492 template <typename T> 00493 bstate_idx<T> bstate_idx<T>::narrow(int d, intg sz, intg o) { 00494 bstate_idx<T> s = *this; 00495 s.x = s.x.narrow(d, sz, o); 00496 s.dx = s.dx.narrow(d, sz, o); 00497 return s; 00498 } 00499 00502 00503 template <typename T> 00504 bstate_idx<T> bstate_idx<T>::make_copy() { 00505 bstate_idx<T> other(x.get_idxdim()); 00506 idx_copy(x, other.x); 00507 idx_copy(dx, other.dx); 00508 return other; 00509 } 00510 00511 template <typename T> 00512 bstate_idx<T>& bstate_idx<T>::operator=(const bstate_idx<T>& other) { 00513 this->x = other.x; 00514 this->dx = other.dx; 00515 return *this; 00516 } 00517 00518 template <typename T> 00519 void bstate_idx<T>::copy(bstate_idx<T> &s) { 00520 idx_copy(s.x, x); 00521 idx_copy(s.dx, dx); 00522 } 00523 00526 00527 template <typename T> 00528 void bstate_idx<T>::pretty() { 00529 cout << "x: "; this->x.pretty(); 00530 cout << "dx: "; this->dx.pretty(); 00531 } 00532 00533 template <typename T> 00534 void bstate_idx<T>::print() { 00535 cout << "x: "; this->x.print(); 00536 cout << " dx: "; this->dx.print(); 00537 } 00538 00540 // bbstate_idx 00541 00542 template<typename T> 00543 bbstate_idx<T>::~bbstate_idx() { 00544 } 00545 00547 // constructors from specific dimensions using a bbparameter 00548 00549 template<typename T> 00550 bbstate_idx<T>::bbstate_idx() { 00551 clear(); 00552 } 00553 00554 template<typename T> 00555 bbstate_idx<T>::bbstate_idx(intg s0) { 00556 x = idx<T>(s0); 00557 dx = idx<T>(s0); 00558 ddx = idx<T>(s0); 00559 clear(); 00560 } 00561 00562 template<typename T> 00563 bbstate_idx<T>::bbstate_idx(intg s0, intg s1) { 00564 x = idx<T>(s0, s1); 00565 dx = idx<T>(s0, s1); 00566 ddx = idx<T>(s0, s1); 00567 clear(); 00568 } 00569 00570 template<typename T> 00571 bbstate_idx<T>::bbstate_idx(intg s0, intg s1, intg s2) { 00572 x = idx<T>(s0, s1, s2); 00573 dx = idx<T>(s0, s1, s2); 00574 ddx = idx<T>(s0, s1, s2); 00575 clear(); 00576 } 00577 00578 template<typename T> 00579 bbstate_idx<T>::bbstate_idx(intg s0, intg s1, intg s2, intg s3, intg s4, 00580 intg s5, intg s6, intg s7) { 00581 x = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00582 dx = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00583 ddx = idx<T>(s0, s1, s2, s3, s4, s5, s6, s7); 00584 clear(); 00585 } 00586 00587 template<typename T> 00588 bbstate_idx<T>::bbstate_idx(const idxdim &d) { 00589 x = idx<T>(d); 00590 dx = idx<T>(d); 00591 ddx = idx<T>(d); 00592 clear(); 00593 } 00594 00595 template <typename T> 00596 bbstate_idx<T>::bbstate_idx(intg n, bbstate_idx<T> &s) { 00597 idxdim d = s.x.get_idxdim(); 00598 d.setdims(n); 00599 x = idx<T>(d); 00600 dx = idx<T>(d); 00601 ddx = idx<T>(d); 00602 clear(); 00603 } 00604 00606 // constructors from specific dimensions using a bbparameter 00607 00608 template<typename T> 00609 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st) { 00610 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0); 00611 dx = idx<T>(st ? st->dx.getstorage() : NULL, st ? st->dx.footprint() : 0); 00612 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, st ? st->ddx.footprint() :0); 00613 if (st) 00614 st->resize(st->footprint() + nelements()); 00615 clear(); 00616 } 00617 00618 template<typename T> 00619 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0) { 00620 x = idx<T>(st ? st->x.getstorage() : NULL, st ? st->x.footprint() : 0, s0); 00621 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00622 st ? st->dx.footprint() : 0, s0); 00623 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, 00624 st ? st->ddx.footprint() : 0, s0); 00625 if (st) 00626 st->resize(st->footprint() + nelements()); 00627 clear(); 00628 } 00629 00630 template<typename T> 00631 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, 00632 intg s1) { 00633 x = idx<T>(st ? st->x.getstorage() : NULL, 00634 st ? st->x.footprint() : 0, s0, s1); 00635 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00636 st ? st->dx.footprint() : 0, s0, s1); 00637 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, 00638 st ? st->ddx.footprint() : 0, s0, s1); 00639 if (st) 00640 st->resize(st->footprint() + nelements()); 00641 clear(); 00642 } 00643 00644 template<typename T> 00645 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, 00646 intg s1, intg s2) { 00647 x = idx<T>(st ? st->x.getstorage() : NULL, 00648 st ? st->x.footprint() : 0, s0, s1, s2); 00649 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00650 st ? st->dx.footprint() : 0, s0, s1, s2); 00651 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, 00652 st ? st->ddx.footprint() : 0, s0, s1, s2); 00653 if (st) 00654 st->resize(st->footprint() + nelements()); 00655 clear(); 00656 } 00657 00658 template<typename T> 00659 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st, intg s0, 00660 intg s1, intg s2, 00661 intg s3, intg s4, intg s5, intg s6, intg s7) { 00662 x = idx<T>(st ? st->x.getstorage() : NULL, 00663 st ? st->x.footprint() : 0, s0, s1, s2, s3, s4, s5, s6, s7); 00664 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00665 st ? st->dx.footprint() : 0, s0, s1, s2, s3, s4, s5, s6,s7); 00666 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, 00667 st ? st->ddx.footprint() : 0, s0, s1, s2, s3, s4, s5, s6, s7); 00668 if (st) 00669 st->resize(st->footprint() + nelements()); 00670 clear(); 00671 } 00672 00673 template<typename T> 00674 bbstate_idx<T>::bbstate_idx(parameter<T,bbstate_idx<T> > *st, 00675 const idxdim &d) { 00676 x = idx<T>(st ? st->x.getstorage() : NULL, 00677 st ? st->x.footprint() : 0, d); 00678 dx = idx<T>(st ? st->dx.getstorage() : NULL, 00679 st ? st->dx.footprint() : 0, d); 00680 ddx = idx<T>(st ? st->ddx.getstorage() : NULL, 00681 st ? st->ddx.footprint() : 0, d); 00682 if (st) 00683 st->resize(st->footprint() + nelements()); 00684 clear(); 00685 } 00686 00688 // constructors from existing idx 00689 00690 // template<typename T> 00691 // bbstate_idx<T>::bbstate_idx(const idx<T> &_x) { 00692 // idxdim d(_x); 00693 // x = idx<T>(_x); 00694 // dx = idx<T>(d); 00695 // ddx = idx<T>(d); 00696 // } 00697 00698 template<typename T> 00699 bbstate_idx<T>::bbstate_idx(const idx<T> &_x, const idx<T> &_dx, 00700 const idx<T> &_ddx) { 00701 x = idx<T>(_x); 00702 dx = idx<T>(_dx); 00703 ddx = idx<T>(_ddx); 00704 } 00705 00707 // clear methods 00708 00709 template <typename T> void bbstate_idx<T>::clear() { 00710 idx_clear(x); 00711 idx_clear(dx); 00712 idx_clear(ddx); 00713 } 00714 00715 template <typename T> void bbstate_idx<T>::clear_ddx() { 00716 idx_clear(ddx); 00717 } 00718 00720 // resize methods 00721 00722 template <typename T> 00723 void bbstate_idx<T>::resize(intg s0, intg s1, intg s2, 00724 intg s3, intg s4, intg s5, 00725 intg s6, intg s7) { 00726 if (!x.same_dim(s0, s1, s2, s3, s4, s5, s6, s7)) { // save some time 00727 x.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00728 dx.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00729 ddx.resize(s0, s1, s2, s3, s4, s5, s6, s7); 00730 } 00731 } 00732 00733 template <typename T> 00734 void bbstate_idx<T>::resize(const idxdim &d) { 00735 if (!x.same_dim(d)) { // save some time if dimensions are the same 00736 x.resize(d); 00737 dx.resize(d); 00738 ddx.resize(d); 00739 } 00740 } 00741 00742 template <typename T> 00743 void bbstate_idx<T>::resize1(intg dimn, intg size) { 00744 if (x.dim(dimn) != size) { // save some time if size is already set. 00745 x.resize1(dimn, size); 00746 dx.resize1(dimn, size); 00747 ddx.resize1(dimn, size); 00748 } 00749 } 00750 00751 template <typename T> 00752 void bbstate_idx<T>::resize_as(bbstate_idx<T>& s) { 00753 idxdim d(s.x.spec); // use same dimensions as s 00754 resize(d); 00755 } 00756 00757 template <typename T> 00758 void bbstate_idx<T>::resize_as_but1(bbstate_idx<T>& s, intg fixed_dim) { 00759 idxdim d(s.x.spec); // use same dimensions as s 00760 d.setdim(fixed_dim, x.dim(fixed_dim)); 00761 resize(d); 00762 } 00763 00764 // template <typename T> 00765 // void bbstate_idx<T>::resize(const intg* dimsBegin, const intg* dimsEnd) { 00766 // x.resize(dimsBegin, dimsEnd); 00767 // dx.resize(dimsBegin, dimsEnd); 00768 // ddx.resize(dimsBegin, dimsEnd); 00769 // } 00770 00773 00774 template <typename T> 00775 bbstate_idx<T> bbstate_idx<T>::select(int dimension, intg slice_index) { 00776 bbstate_idx<T> s = *this; 00777 s.x = s.x.select(dimension, slice_index); 00778 s.dx = s.dx.select(dimension, slice_index); 00779 s.ddx = s.ddx.select(dimension, slice_index); 00780 return s; 00781 } 00782 00783 template <typename T> 00784 bbstate_idx<T> bbstate_idx<T>::narrow(int d, intg sz, intg o) { 00785 bbstate_idx<T> s = *this; 00786 s.x = s.x.narrow(d, sz, o); 00787 s.dx = s.dx.narrow(d, sz, o); 00788 s.ddx = s.ddx.narrow(d, sz, o); 00789 return s; 00790 } 00791 00794 00795 template <typename T> 00796 bbstate_idx<T> bbstate_idx<T>::make_copy() { 00797 bbstate_idx<T> other(x.get_idxdim()); 00798 idx_copy(x, other.x); 00799 idx_copy(dx, other.dx); 00800 idx_copy(ddx, other.ddx); 00801 return other; 00802 } 00803 00804 template <typename T> 00805 bbstate_idx<T>& bbstate_idx<T>::operator=(const bbstate_idx<T>& other) { 00806 this->x = other.x; 00807 this->dx = other.dx; 00808 this->ddx = other.ddx; 00809 return *this; 00810 } 00811 00812 template <typename T> 00813 void bbstate_idx<T>::copy(bbstate_idx<T> &s) { 00814 idx_copy(s.x, x); 00815 idx_copy(s.dx, dx); 00816 idx_copy(s.ddx, ddx); 00817 } 00818 00821 00822 template <typename T> 00823 void bbstate_idx<T>::pretty() { 00824 cout << "x: "; this->x.pretty(); 00825 cout << "dx: "; this->dx.pretty(); 00826 cout << "ddx: "; this->ddx.pretty(); 00827 } 00828 00829 template <typename T> 00830 void bbstate_idx<T>::print() { 00831 cout << "x: "; this->x.print(); 00832 cout << " dx: "; this->dx.print(); 00833 cout << " ddx: "; this->ddx.print(); 00834 } 00835 00837 // parameter 00838 00839 template <typename T> 00840 parameter<T,fstate_idx<T> >::parameter(intg initial_size) 00841 : fstate_idx<T>(initial_size) { 00842 resize(0); 00843 } 00844 00845 template <typename T> 00846 parameter<T,fstate_idx<T> >::parameter(const char *param_filename) 00847 : fstate_idx<T>(1) { 00848 if (!load_x(param_filename)) { 00849 cerr << "failed to open " << param_filename << endl; 00850 eblerror("failed to load parameter file in parameter constructor"); 00851 } 00852 } 00853 00854 template <typename T> 00855 parameter<T,fstate_idx<T> >::~parameter() { 00856 } 00857 00858 // TODO-0: BUG: a parameter object casted in state_idx* and called 00859 // with resize(n) calls state_idx::resize instead of parameter<T>::resize 00860 // a temporary unclean solution is to use the same parameters as 00861 // in state_idx::resize in parameter<T>::resize: 00862 // resize(intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, 00863 // intg s6, intg s7); 00864 template <typename T> 00865 void parameter<T,fstate_idx<T> >::resize(intg s0) { 00866 x.resize(s0); 00867 } 00868 00869 template <typename T> 00870 bool parameter<T,fstate_idx<T> >::save_x(const char *s) { 00871 if (!save_matrix(x, s)) 00872 return false; 00873 return true; 00874 } 00875 00876 template <typename T> 00877 bool parameter<T,fstate_idx<T> >::load_x(std::vector<string> &files) { 00878 if (files.size() == 0) eblerror("expected at least 1 file to load"); 00879 idx<T> w = load_matrix<T>(files[0]); 00880 for (uint i = 1; i < files.size(); ++i) { 00881 idx<T> tmp = load_matrix<T>(files[i]); 00882 w = idx_concat(w, tmp); 00883 } 00884 cout << "Concatenated " << files.size() << " matrices into 1: " 00885 << w << " from " << files << endl; 00886 return load_x(w); 00887 } 00888 00889 template <typename T> 00890 bool parameter<T,fstate_idx<T> >::load_x(const char *s) { 00891 #ifndef __NOSTL__ 00892 try { 00893 #endif 00894 idx<T> m = load_matrix<T>(s); 00895 if ((x.dim(0) != 1) // param has been enlarged by network construction 00896 && (x.dim(0) != m.dim(0))) { // trying to load incompatible network 00897 eblerror("Trying to load a network with " << m.dim(0) 00898 << " parameters into a network with " << x.dim(0) 00899 << " parameters"); 00900 } 00901 this->resize(m.dim(0)); 00902 idx_copy(m, x); 00903 cout << "Loaded weights from " << s << ": " << x << endl; 00904 return true; 00905 #ifndef __NOSTL__ 00906 } catch(string &err) { 00907 cout << err << endl; 00908 eblerror("failed to load weights"); 00909 } 00910 #endif 00911 return false; 00912 } 00913 00914 template <typename T> 00915 bool parameter<T,fstate_idx<T> >::load_x(idx<T> &m) { 00916 if ((x.dim(0) != 1) // param has been enlarged by network construction 00917 && (x.dim(0) != m.dim(0))) { // trying to load incompatible network 00918 eblerror("Trying to load a network with " << m.dim(0) 00919 << " parameters into a network with " << x.dim(0) 00920 << " parameters"); 00921 } 00922 this->resize(m.dim(0)); 00923 idx_copy(m, x); 00924 cout << "Loaded weights from " << m << ": " << x << endl; 00925 return true; 00926 } 00927 00929 // parameter<T,bstate_idx<T> > 00930 00931 template <typename T> 00932 parameter<T,bstate_idx<T> >::parameter(intg initial_size) 00933 : bstate_idx<T>(initial_size), //gradient(initial_size), 00934 deltax(initial_size), epsilons(initial_size) { 00935 //idx_clear(gradient); 00936 idx_clear(deltax); 00937 idx_clear(epsilons); 00938 resize(0); 00939 } 00940 00941 template <typename T> 00942 parameter<T,bstate_idx<T> >::parameter(const char *param_filename) 00943 : bstate_idx<T>(1), //gradient(1), 00944 deltax(1), epsilons(1) { 00945 if (!load_x(param_filename)) { 00946 cerr << "failed to open " << param_filename << endl; 00947 eblerror("failed to load bparameter file in bparameter constructor"); 00948 } 00949 } 00950 00951 template <typename T> 00952 parameter<T,bstate_idx<T> >::~parameter() { 00953 } 00954 00955 // TODO-0: BUG: a bparameter object casted in state_idx* and called 00956 // with resize(n) calls state_idx::resize instead of parameter<T,bstate_idx<T> >resize 00957 // a temporary unclean solution is to use the same bparameters as 00958 // in state_idx::resize in parameter<T,bstate_idx<T> >resize: 00959 // resize(intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, 00960 // intg s6, intg s7); 00961 template <typename T> 00962 void parameter<T,bstate_idx<T> >::resize(intg s0) { 00963 x.resize(s0); 00964 dx.resize(s0); 00965 //gradient.resize(s0); 00966 deltax.resize(s0); 00967 epsilons.resize(s0); 00968 } 00969 00970 template <typename T> 00971 bool parameter<T,bstate_idx<T> >::save_x(const char *s) { 00972 if (!save_matrix(x, s)) 00973 return false; 00974 return true; 00975 } 00976 00977 template <typename T> 00978 bool parameter<T,bstate_idx<T> >::load_x(const char *s) { 00979 try { 00980 idx<T> m = load_matrix<T>(s); 00981 if ((x.dim(0) != 1) // param has been enlarged by network construction 00982 && (x.dim(0) != m.dim(0))) // trying to load incompatible network 00983 eblerror("Trying to load a network with " << m.dim(0) 00984 << " parameters into a network with " << x.dim(0) 00985 << " parameters"); 00986 this->resize(m.dim(0)); 00987 idx_copy(m, x); 00988 cout << "Loaded weights from " << s << ": " << x << endl; 00989 return true; 00990 } catch(string &err) { 00991 cout << err << endl; 00992 eblerror("failed to load weights"); 00993 } 00994 return false; 00995 } 00996 00997 template <typename T> 00998 void parameter<T,bstate_idx<T> >::update(gd_param &arg) { 00999 update_gd(arg); 01000 } 01001 01002 template <typename T> 01003 void parameter<T,bstate_idx<T> >::clear_deltax() { 01004 idx_clear(deltax); 01005 } 01006 01008 // protected methods 01009 01010 template <typename T> 01011 void parameter<T,bstate_idx<T> >::set_epsilon(T m) { 01012 idx_fill(epsilons, m); 01013 } 01014 01015 template <typename T> 01016 void parameter<T,bstate_idx<T> >::update_gd(gd_param &arg) { 01017 if (arg.decay_l2 > 0) 01018 idx_dotcacc(x, arg.decay_l2, dx); 01019 if (arg.decay_l1 > 0) 01020 idx_signdotcacc(x, (T) arg.decay_l1, dx); 01021 if (arg.inertia == 0) { 01022 idx_mul(dx, epsilons, dx); 01023 idx_dotcacc(dx, -arg.eta, x); 01024 } else { 01025 update_deltax((T) (1 - arg.inertia), (T) arg.inertia); 01026 idx_mul(deltax, epsilons, deltax); 01027 idx_dotcacc(deltax, -arg.eta, x); 01028 } 01029 } 01030 01031 template <typename T> 01032 void parameter<T,bstate_idx<T> >::update_deltax(T knew, T kold) { 01033 idx_lincomb(dx, knew, deltax, kold, deltax); 01034 } 01035 01037 // parameter<T,bbstate_idx<T> > 01038 01039 template <typename T> 01040 parameter<T,bbstate_idx<T> >::parameter(intg initial_size) 01041 : bbstate_idx<T>(initial_size), //gradient(initial_size), 01042 deltax(initial_size), epsilons(initial_size), ddeltax(initial_size) { 01043 //idx_clear(gradient); 01044 idx_clear(deltax); 01045 idx_clear(epsilons); 01046 idx_clear(ddeltax); 01047 resize(0); 01048 } 01049 01050 template <typename T> 01051 parameter<T,bbstate_idx<T> >::parameter(const char *param_filename) 01052 : bbstate_idx<T>(1), //gradient(1), 01053 deltax(1), epsilons(1), ddeltax(1) { 01054 if (!load_x(param_filename)) { 01055 cerr << "failed to open " << param_filename << endl; 01056 eblerror("failed to load bbparameter file in bbparameter constructor"); 01057 } 01058 } 01059 01060 template <typename T> 01061 parameter<T,bbstate_idx<T> >::~parameter() { 01062 } 01063 01064 // TODO-0: BUG: a bbparameter object casted in state_idx* and called 01065 // with resize(n) calls state_idx::resize instead of 01066 // parameter<T,bbstate_idx<T> >::resize 01067 // a temporary unclean solution is to use the same bbparameters as 01068 // in state_idx::resize in parameter<T,bbstate_idx<T> >::resize: 01069 // resize(intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, 01070 // intg s6, intg s7); 01071 template <typename T> 01072 void parameter<T,bbstate_idx<T> >::resize(intg s0) { 01073 this->x.resize(s0); 01074 this->dx.resize(s0); 01075 this->ddx.resize(s0); 01076 //gradient.resize(s0); 01077 deltax.resize(s0); 01078 epsilons.resize(s0); 01079 ddeltax.resize(s0); 01080 } 01081 01082 template <typename T> 01083 bool parameter<T,bbstate_idx<T> >::save_x(const char *s) { 01084 if (!save_matrix(this->x, s)) 01085 return false; 01086 return true; 01087 } 01088 01089 template <typename T> 01090 bool parameter<T,bbstate_idx<T> >::load_x(std::vector<string> &files) { 01091 if (files.size() == 0) eblerror("expected at least 1 file to load"); 01092 idx<T> w = load_matrix<T>(files[0]); 01093 for (uint i = 1; i < files.size(); ++i) { 01094 idx<T> tmp = load_matrix<T>(files[i]); 01095 w = idx_concat(w, tmp); 01096 } 01097 cout << "Concatenated " << files.size() << " matrices into 1: " 01098 << w << " from " << files << endl; 01099 load_x(w); 01100 return true; 01101 } 01102 01103 template <typename T> 01104 bool parameter<T,bbstate_idx<T> >::load_x(const char *s) { 01105 try { 01106 idx<T> m = load_matrix<T>(s); 01107 if ((x.dim(0) != 1) // param has been enlarged by network construction 01108 && (x.dim(0) != m.dim(0))) // trying to load incompatible network 01109 eblerror("Trying to load a network with " << m.dim(0) 01110 << " parameters into a network with " << x.dim(0) 01111 << " parameters"); 01112 this->resize(m.dim(0)); 01113 idx_copy(m, this->x); 01114 cout << "Loaded weights from " << s << ": " << this->x << endl; 01115 return true; 01116 } catch(string &err) { 01117 cout << err << endl; 01118 eblerror("failed to load weights"); 01119 } 01120 return false; 01121 } 01122 01123 template <typename T> 01124 bool parameter<T,bbstate_idx<T> >::load_x(idx<T> &m) { 01125 if ((x.dim(0) != 1) // param has been enlarged by network construction 01126 && (x.dim(0) != m.dim(0))) { // trying to load incompatible network 01127 eblerror("Trying to load a network with " << m.dim(0) 01128 << " parameters into a network with " << x.dim(0) 01129 << " parameters"); 01130 } 01131 this->resize(m.dim(0)); 01132 idx_copy(m, x); 01133 cout << "Loaded weights from " << m << ": " << x << endl; 01134 return true; 01135 } 01136 01137 template <typename T> 01138 void parameter<T,bbstate_idx<T> >::update(gd_param &arg) { 01139 update_gd(arg); 01140 } 01141 01142 template <typename T> 01143 void parameter<T,bbstate_idx<T> >::clear_deltax() { 01144 idx_clear(deltax); 01145 } 01146 01147 template <typename T> 01148 void parameter<T,bbstate_idx<T> >::clear_ddeltax() { 01149 idx_clear(ddeltax); 01150 } 01151 01152 template <typename T> 01153 void parameter<T,bbstate_idx<T> >::set_epsilon(T m) { 01154 idx_fill(epsilons, m); 01155 } 01156 01157 template <typename T> 01158 void parameter<T,bbstate_idx<T> >::compute_epsilons(T mu) { 01159 idx_addc(ddeltax, mu, epsilons); 01160 idx_inv(epsilons, epsilons); 01161 } 01162 01163 template <typename T> 01164 void parameter<T,bbstate_idx<T> >::update_ddeltax(T knew, T kold) { 01165 idx_lincomb(this->ddx, knew, ddeltax, kold, ddeltax); 01166 } 01167 01169 // protected methods 01170 01171 template <typename T> 01172 void parameter<T,bbstate_idx<T> >::update_gd(gd_param &arg) { 01173 if (arg.decay_l2 > 0) 01174 idx_dotcacc(this->x, arg.decay_l2, this->dx); 01175 if (arg.decay_l1 > 0) 01176 idx_signdotcacc(this->x, (T) arg.decay_l1, this->dx); 01177 if (arg.inertia == 0) { 01178 idx_mul(this->dx, epsilons, this->dx); 01179 idx_dotcacc(this->dx, -arg.eta, this->x); 01180 } else { 01181 update_deltax((T) (1 - arg.inertia), (T) arg.inertia); 01182 idx_mul(deltax, epsilons, deltax); 01183 idx_dotcacc(deltax, -arg.eta, this->x); 01184 } 01185 } 01186 01187 template <typename T> 01188 void parameter<T,bbstate_idx<T> >::update_deltax(T knew, T kold) { 01189 idx_lincomb(this->dx, knew, deltax, kold, deltax); 01190 } 01191 01193 // fstate_idxlooper 01194 01195 template <typename T> 01196 state_idxlooper<fstate_idx<T> >::state_idxlooper(fstate_idx<T> &s, int ld) 01197 : fstate_idx<T>(s.x.select(ld, 0)), lx(s.x, ld) { 01198 } 01199 01200 template <typename T> 01201 state_idxlooper<fstate_idx<T> >::~state_idxlooper() { 01202 } 01203 01204 template <typename T> 01205 void state_idxlooper<fstate_idx<T> >::next() { 01206 lx.next(); 01207 x = lx; 01208 } 01209 01210 // return true when done. 01211 template <typename T> 01212 bool state_idxlooper<fstate_idx<T> >::notdone() { 01213 return lx.notdone(); 01214 } 01215 01217 // bstate_idxlooper 01218 01219 template <typename T> 01220 state_idxlooper<bstate_idx<T> >::state_idxlooper(bstate_idx<T> &s, int ld) 01221 : bstate_idx<T>(s.x.select(ld, 0), s.dx.select(ld, 0)), 01222 lx(s.x, ld), ldx(s.dx, ld) { 01223 } 01224 01225 template <typename T> 01226 state_idxlooper<bstate_idx<T> >::~state_idxlooper() { 01227 } 01228 01229 template <typename T> 01230 void state_idxlooper<bstate_idx<T> >::next() { 01231 lx.next(); 01232 ldx.next(); 01233 x = lx; 01234 dx = ldx; 01235 } 01236 01237 // return true when done. 01238 template <typename T> 01239 bool state_idxlooper<bstate_idx<T> >::notdone() { 01240 return lx.notdone(); 01241 } 01242 01244 // bbstate_idxlooper 01245 01246 template <typename T> 01247 state_idxlooper<bbstate_idx<T> >::state_idxlooper(bbstate_idx<T> &s, int ld) 01248 : bbstate_idx<T>(s.x.select(ld, 0), 01249 s.dx.select(ld, 0), 01250 s.ddx.select(ld, 0)), 01251 lx(s.x, ld), ldx(s.dx, ld), lddx(s.ddx, ld) { 01252 } 01253 01254 template <typename T> 01255 state_idxlooper<bbstate_idx<T> >::~state_idxlooper() { 01256 } 01257 01258 template <typename T> 01259 void state_idxlooper<bbstate_idx<T> >::next() { 01260 lx.next(); 01261 ldx.next(); 01262 lddx.next(); 01263 x = lx; 01264 dx = ldx; 01265 ddx = lddx; 01266 } 01267 01268 // return true when done. 01269 template <typename T> 01270 bool state_idxlooper<bbstate_idx<T> >::notdone() { 01271 return lx.notdone(); 01272 } 01273 01275 // mstate_idx 01276 01277 template <class Tstate> mstate<Tstate>::mstate() {} 01278 01279 template <class Tstate> 01280 mstate<Tstate>::mstate(const mstate<Tstate> &ms, intg dims, intg nstates) { 01281 //EDEBUG("constructing new mstate from " << ms); 01282 nstates = (nstates == -1 ? ms.size() : nstates); 01283 for (uint i = 0; i < nstates; ++i) { 01284 idxdim d = ms.at_const(i).x.get_idxdim(); 01285 d.setdims(dims); 01286 Tstate *nt = new Tstate(d); 01287 this->push_back(nt); 01288 } 01289 } 01290 01291 template <class Tstate> 01292 mstate<Tstate>::mstate(const mstate<Tstate> &other) { 01293 svector<Tstate>::copy(other); 01294 } 01295 01296 template <class Tstate> mstate<Tstate>::~mstate() { 01297 } 01298 01299 template <class Tstate> 01300 void mstate<Tstate>::clear_x() { 01301 for (it = this->begin(); it != this->end(); ++it) 01302 it->clear_x(); 01303 } 01304 01305 template <class Tstate> 01306 void mstate<Tstate>::clear_dx() { 01307 for (it = this->begin(); it != this->end(); ++it) 01308 it->clear_dx(); 01309 } 01310 01311 template <class Tstate> 01312 void mstate<Tstate>::clear_ddx() { 01313 for (it = this->begin(); it != this->end(); ++it) 01314 it->clear_ddx(); 01315 } 01316 01317 template <class Tstate> 01318 void mstate<Tstate>::copy(mstate<Tstate> &s) { 01319 for (uint i = 0; i < s.size(); ++i) { 01320 Tstate& cpy = s[i]; 01321 // add more states if necessary 01322 if (i >= this->size()) 01323 this->push_back(new Tstate(cpy.x.get_idxdim())); 01324 // copy 01325 Tstate& local = (*this)[i]; 01326 local.copy(cpy); 01327 } 01328 } 01329 01330 template <class Tstate> template <typename T> 01331 void mstate<Tstate>::copy(midx<T> &s) { 01332 for (int i = 0; i < s.dim(0); ++i) { 01333 idx<T> d = s.get(i); 01334 // add more states if necessary 01335 if (i >= (int) this->size()) 01336 this->push_back(new Tstate(d.get_idxdim())); 01337 Tstate& local = (*this)[i]; 01338 local.x = d; 01339 } 01340 } 01341 01342 template <class Tstate> template <typename T> 01343 midx<T> mstate<Tstate>::copy() { 01344 midx<T> s(this->size()); 01345 for (uint i = 0; i < this->size(); ++i) { 01346 Tstate& local = (*this)[i]; 01347 s.set(local.x, i); 01348 } 01349 return s; 01350 } 01351 01352 template <class Tstate> 01353 mstate<Tstate> mstate<Tstate>::narrow(intg size, intg offset) { 01354 if ((uint) (size + offset) > this->size()) 01355 eblerror("cannot narrow this vector of size " << this->size() 01356 << " to size " << size << " starting at offset " << offset); 01357 mstate<Tstate> ms; 01358 for (uint i = 0; i < size; ++i) 01359 ms.push_back(this->at(i + offset)); 01360 EDEBUG("narrowed " << *this << " into " << ms); 01361 return ms; 01362 } 01363 01364 template <class Tstate> 01365 mstate<Tstate> mstate<Tstate>::narrow(int dimension, intg size, intg offset) { 01366 eblerror("not implemented"); 01367 mstate<Tstate> ms; 01368 return ms; 01369 } 01370 01371 template <class Tstate> 01372 mstate<Tstate> mstate<Tstate>::narrow(midxdim &dims) { 01373 if (dims.size() != this->size()) 01374 eblerror("expected same size input regions and states but got: " 01375 << dims << " and " << *this); 01376 mstate<Tstate> all; 01377 for (uint i = 0; i < dims.size(); ++i) { 01378 Tstate in = (*this)[i]; 01379 idxdim d = dims[i]; 01380 // narrow input, ignoring 1st dim 01381 for (uint j = 1; j < d.order(); ++j) 01382 in = in.narrow(j, d.dim(j), d.offset(j)); 01383 all.push_back(new Tstate(in)); 01384 } 01385 return all; 01386 } 01387 01388 template <class Tstate> 01389 mstate<Tstate> mstate<Tstate>::narrow_max(mfidxdim &dims) { 01390 if (dims.size() != this->size()) 01391 eblerror("expected same size input regions and states but got: " 01392 << dims << " and " << *this); 01393 mstate<Tstate> all; 01394 for (uint i = 0; i < dims.size(); ++i) { 01395 Tstate in = (*this)[i]; 01396 fidxdim d = dims[i]; 01397 // narrow input, ignoring 1st dim 01398 for (uint j = 1; j < d.order(); ++j) 01399 in = in.narrow(j, (intg) std::min(in.x.dim(j) - d.offset(j), d.dim(j)), 01400 (intg) d.offset(j)); 01401 all.push_back(new Tstate(in)); 01402 } 01403 return all; 01404 } 01405 01406 template <class Tstate> template <class T> 01407 void mstate<Tstate>::get_midx(mfidxdim &dims, midx<T> &all) { 01408 // first narrow state 01409 mstate<Tstate> n = this->narrow(dims); 01410 // now set all x into an midx 01411 all = midx<T>(n.size()); 01412 for (uint i = 0; i < n.size(); ++i) 01413 all.set(n[i].x, i); 01414 } 01415 01416 template <class Tstate> template <class T> 01417 void mstate<Tstate>::get_max_midx(mfidxdim &dims, midx<T> &all) { 01418 // first narrow state 01419 mstate<Tstate> n = this->narrow_max(dims); 01420 // now set all x into an midx 01421 all = midx<T>(n.size()); 01422 for (uint i = 0; i < n.size(); ++i) 01423 all.set(n[i].x, i); 01424 } 01425 01426 template <class Tstate> template <class T> 01427 void mstate<Tstate>::get_padded_midx(mfidxdim &dims, midx<T> &all) { 01428 if (dims.size() != this->size()) 01429 eblerror("expected same size input regions and states but got: " 01430 << dims << " and " << *this); 01431 all.clear(); 01432 all.resize(dims.size_existing()); 01433 uint ooff, ioff, osize, n = 0; 01434 bool bcopy; 01435 for (uint i = 0; i < dims.size(); ++i) { 01436 if (dims.exists(i)) { 01437 idx<T> in = (*this)[i].x; 01438 idxdim d(dims[i]); 01439 d.setdim(0, in.dim(0)); 01440 idx<T> out(d); 01441 idx_clear(out); 01442 all.set(out, n); 01443 bcopy = true; 01444 // narrow input, ignoring 1st dim 01445 for (uint j = 1; j < d.order(); ++j) { 01446 // if no overlap, skip this state 01447 if (d.offset(j) >= in.dim(j) || d.offset(j) + d.dim(j) <= 0) { 01448 bcopy = false; 01449 break ; 01450 } 01451 // determine narrow params 01452 ooff = (uint) std::max(0, (int) - d.offset(j)); 01453 ioff = (uint) std::max(0, (int) d.offset(j)); 01454 osize = (uint) std::min(d.dim(j) - ooff, in.dim(j) - ioff); 01455 // narrow 01456 out = out.narrow(j, osize, ooff); 01457 in = in.narrow(j, osize, ioff); 01458 } 01459 // copy 01460 if (bcopy) idx_copy(in, out); 01461 n++; 01462 } 01463 } 01464 } 01465 01466 template <class Tstate> 01467 void mstate<Tstate>::resize(mstate<Tstate> &s2, uint nmax) { 01468 uint sz = s2.size(); 01469 if (nmax > 0) 01470 sz = std::min((uint) s2.size(), nmax); 01471 if (this->size() != sz) 01472 mstate<Tstate>::clear(); 01473 for (uint i = 0; i < sz; ++i) { 01474 Tstate &t2 = s2[i]; 01475 idxdim d = t2.x; 01476 if (this->size() < sz) { 01477 Tstate *nt = new Tstate(d); 01478 this->push_back(nt); 01479 } 01480 else { // state already exists 01481 Tstate &t = (*this)[i]; 01482 if (t.x.order() != d.order()) // wrong order, reassign state 01483 t = Tstate(d); 01484 else if (t.x.get_idxdim() != d) // correct order but wrong dimensions 01485 t.resize(d); 01486 } 01487 } 01488 } 01489 01490 template <class Tstate> template <class Tstate2> 01491 void mstate<Tstate>::resize(mstate<Tstate2> &other) { 01492 bool reset = false; 01493 // check we have the right number of states in out 01494 if (other.nstates() != this->nstates()) 01495 reset = true; 01496 // check that all states have the right orders 01497 for (uint i = 0; i < other.nstates() && !reset; ++i) { 01498 Tstate2 &sother = other[i]; 01499 Tstate &sthis = (*this)[i]; 01500 if (sthis.x.order() != sother.x.order()) 01501 reset = true; 01502 } 01503 // allocate 01504 if (reset) { 01505 mstate<Tstate>::clear(); 01506 for (uint i = 0; i < other.nstates(); ++i) { 01507 Tstate2 &sother = other[i]; 01508 Tstate *nt = new Tstate(sother.x.get_idxdim()); 01509 this->push_back(nt); 01510 } 01511 } 01512 } 01513 01514 template <class Tstate> 01515 idxdim& mstate<Tstate>::get_idxdim0() { 01516 Tstate& s0 = (*this)[0]; 01517 return s0.x.get_idxdim(); 01518 } 01519 01521 // stream operators 01522 01523 template <typename T> 01524 EXPORT std::ostream& operator<<(std::ostream &out, const fstate_idx<T> &m) { 01525 out << "(x:" << m.x << ")"; 01526 return out; 01527 } 01528 01529 template <typename T> 01530 EXPORT std::ostream& operator<<(std::ostream &out, const bstate_idx<T> &m) { 01531 out << "(x:" << m.x << "," << m.dx << ")"; 01532 return out; 01533 } 01534 01535 template <typename T> 01536 EXPORT std::ostream& operator<<(std::ostream &out, const bbstate_idx<T> &m) { 01537 out << "(x:" << m.x << "," << m.dx << "," << m.ddx << ")"; 01538 return out; 01539 } 01540 01541 template <class Tstate> 01542 EXPORT std::ostream& operator<<(std::ostream &out, const mstate<Tstate> &m) { 01543 out << "["; 01544 if (m.size() == 0) 01545 out << "empty"; 01546 else { 01547 // const Tstate &s = m[0]; 01548 const Tstate &s = m.at_const(0); 01549 out << s; 01550 for (uint i = 1; i < m.size(); ++i) { 01551 // const Tstate &st = m[i]; 01552 const Tstate &st = m.at_const(i); 01553 out << "," << st; 01554 } 01555 } 01556 out << "]"; 01557 return out; 01558 } 01559 01560 } // end namespace ebl