libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun, Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 namespace ebl { 00034 00035 // module_1_1 //////////////////////////////////////////////////////////////// 00036 00037 template <typename T, class Tin, class Tout> 00038 module_1_1<T,Tin,Tout>::module_1_1(const char *name, bool bresize_) 00039 : module(name), bresize(bresize_), memoptimized(false), 00040 bmstate_input(false), bmstate_output(false), ninputs(1), noutputs(1) { 00041 } 00042 00043 template <typename T, class Tin, class Tout> 00044 module_1_1<T,Tin,Tout>::~module_1_1() { 00045 EDEBUG("deleting module_1_1: " << _name); 00046 } 00047 00048 // single-state methods ////////////////////////////////////////////////////// 00049 00050 template <typename T, class Tin, class Tout> 00051 void module_1_1<T,Tin,Tout>::fprop(Tin &in, Tout &out) { 00052 err_not_implemented(); } 00053 00054 template <typename T, class Tin, class Tout> 00055 void module_1_1<T,Tin,Tout>::bprop(Tin &in, Tout &out) { 00056 err_not_implemented(); } 00057 00058 template <typename T, class Tin, class Tout> 00059 void module_1_1<T,Tin,Tout>::bbprop(Tin &in, Tout &out) { 00060 err_not_implemented(); } 00061 00062 template <typename T, class Tin, class Tout> 00063 void module_1_1<T,Tin,Tout>::dump_fprop(Tin &in, Tout &out) { 00064 fprop(in, out); // no dumping by default, just fproping. 00065 } 00066 00067 // multi-state methods /////////////////////////////////////////////////////// 00068 00069 template <typename T, class Tin, class Tout> 00070 void module_1_1<T,Tin,Tout>::fprop(mstate<Tin> &in, mstate<Tout> &out) { 00071 // check that in/out have at least 1 state and the same number of them. 00072 if (in.size() == 0) eblerror("input should have at least 1"); 00073 // if (in.size() != out.size()) out.resize(in); 00074 out.resize(in); 00075 // run regular fprop on each states 00076 for (uint i = 0; i < in.size(); ++i) { 00077 EDEBUG(this->name() << ": fprop at index " << i << " in: " 00078 << in << " and out: " << out); 00079 Tin &fin = in[i]; 00080 Tout &fout = out[i]; 00081 EDEBUG(this->name() << ": in.x " << fin.x << ", min " << idx_min(fin.x) 00082 << " max " << idx_max(fin.x)); 00083 fprop(fin, fout); 00084 } 00085 // remember number of input/outputs 00086 ninputs = in.size(); 00087 noutputs = out.size(); 00088 } 00089 00090 template <typename T, class Tin, class Tout> 00091 void module_1_1<T,Tin,Tout>::bprop(mstate<Tin> &in, mstate<Tout> &out) { 00092 // run regular bbprop on each states 00093 for (int i = (int) in.size() - 1; i >= 0; --i) { 00094 EDEBUG(this->name() << ": bprop at index " << i << " in: " 00095 << in << " and out: " << out); 00096 Tin &bin = in[i]; 00097 Tout &bout = out[i]; 00098 EDEBUG(this->name() << ": bprop in.x " << bin.x 00099 << " min " << idx_min(bin.x) << " max " << idx_max(bin.x) 00100 << " out.x " << bout.x 00101 << " min " << idx_min(bout.x) << " max " << idx_max(bout.x)); 00102 bprop(bin, bout); 00103 } 00104 } 00105 00106 template <typename T, class Tin, class Tout> 00107 void module_1_1<T,Tin,Tout>::bbprop(mstate<Tin> &in, mstate<Tout> &out) { 00108 // run regular bbprop on each states 00109 for (int i = (int) in.size() - 1; i >= 0; --i) { 00110 Tin &bbin = in[i]; 00111 Tout &bbout = out[i]; 00112 bbprop(bbin, bbout); 00113 } 00114 } 00115 00116 template <typename T, class Tin, class Tout> 00117 void module_1_1<T,Tin,Tout>::dump_fprop(mstate<Tin> &in, mstate<Tout> &out) { 00118 // check that in/out have at least 1 state and the same number of them. 00119 if (in.size() == 0) 00120 eblerror("input should have at least 1"); 00121 if (in.size() != out.size()) 00122 out.resize(in); 00123 // run regular bbprop on each states 00124 for (uint i = 0; i < in.size(); ++i) { 00125 Tin &fin = in[i]; 00126 Tout &fout = out[i]; 00127 EDEBUG(this->name() << ": in.x " << fin.x << ", min " << idx_min(fin.x) 00128 << " max " << idx_max(fin.x)); 00129 this->dump_fprop(fin, fout); 00130 } 00131 // remember number of input/outputs 00132 ninputs = in.size(); 00133 noutputs = out.size(); 00134 } 00135 00136 // multi to single state methods ///////////////////////////////////////////// 00137 00138 template <typename T, class Tin, class Tout> 00139 void module_1_1<T,Tin,Tout>::fprop(mstate<Tin> &in, Tout &out) { 00140 err_not_implemented(); } 00141 00142 template <typename T, class Tin, class Tout> 00143 void module_1_1<T,Tin,Tout>::bprop(mstate<Tin> &in, Tout &out) { 00144 err_not_implemented(); } 00145 00146 template <typename T, class Tin, class Tout> 00147 void module_1_1<T,Tin,Tout>::bbprop(mstate<Tin> &in, Tout &out) { 00148 err_not_implemented(); } 00149 00150 template <typename T, class Tin, class Tout> 00151 void module_1_1<T,Tin,Tout>::dump_fprop(mstate<Tin> &in, Tout &out) { 00152 } // eblerror("not implemented for " << this->name()); } 00153 00154 // single to multi state methods ///////////////////////////////////////////// 00155 00156 template <typename T, class Tin, class Tout> 00157 void module_1_1<T,Tin,Tout>::fprop(Tin &in, mstate<Tout> &out) { 00158 err_not_implemented(); } 00159 00160 template <typename T, class Tin, class Tout> 00161 void module_1_1<T,Tin,Tout>::bprop(Tin &in, mstate<Tout> &out) { 00162 err_not_implemented(); } 00163 00164 template <typename T, class Tin, class Tout> 00165 void module_1_1<T,Tin,Tout>::bbprop(Tin &in, mstate<Tout> &out) { 00166 err_not_implemented(); } 00167 00168 template <typename T, class Tin, class Tout> 00169 void module_1_1<T,Tin,Tout>::dump_fprop(Tin &in, mstate<Tout> &out) { 00170 err_not_implemented(); } 00171 00173 00174 template <typename T, class Tin, class Tout> 00175 void module_1_1<T,Tin,Tout>::forget(forget_param_linear& fp) { 00176 } 00177 00178 template <typename T, class Tin, class Tout> 00179 void module_1_1<T,Tin,Tout>::normalize() { 00180 } 00181 00182 template <typename T, class Tin, class Tout> 00183 int module_1_1<T,Tin,Tout>::replicable_order() { return -1; } 00184 00185 template <typename T, class Tin, class Tout> 00186 bool module_1_1<T,Tin,Tout>::ignored(Tin &in, Tout &out) { 00187 if (this->_enabled) return false; 00188 idx_copy(in.x, out.x); 00189 return true; 00190 } 00191 00192 // resizing ////////////////////////////////////////////////////////////////// 00193 00194 template <typename T, class Tin, class Tout> 00195 bool module_1_1<T,Tin,Tout>::resize_output(Tin &in, Tout &out, idxdim *d) { 00196 //if (!this->bresize) return false; // no resizing 00197 if (&in == &out) return false; // resize only when in and out are different 00198 TIMING_RESIZING_ACCSTART(); // start accumulating resizing time 00199 if (d) { // use d as target dims 00200 if (d->order() != out.x.order()) { // re-allocate buffer 00201 EDEBUG(this->name() << ": reallocating output from " << out.x 00202 << " to " << d); 00203 out = Tout(*d); 00204 } else if (*d != out.x.get_idxdim()) { // resize buffer 00205 EDEBUG(this->name() << ": resizing output from " << out.x << " to " 00206 << *d); 00207 out.resize(*d); 00208 } else { 00209 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00210 return false; 00211 } 00212 } else { // use in.x as target dims 00213 if (in.x.order() != out.x.order()) { // re-allocate buffer 00214 EDEBUG(this->name() << ": reallocating output from " << out.x 00215 << " to " << in.x.get_idxdim()); 00216 out = Tout(in.x.get_idxdim()); 00217 } else if (in.x.get_idxdim() != out.x.get_idxdim()) { // resize buffer 00218 EDEBUG(this->name() << ": resizing output from " << out.x << " to " 00219 << in.x.get_idxdim()); 00220 out.resize(in.x.get_idxdim()); 00221 } else { 00222 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00223 return false; 00224 } 00225 } 00226 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00227 return true; 00228 } 00229 00230 template <typename T, class Tin, class Tout> 00231 bool module_1_1<T,Tin,Tout>::resize_output(Tin &in, idx<T> &out, idxdim *d) { 00232 //if (!this->bresize) return false; // no resizing 00233 if (&in.x == &out) return false; // resize only when different 00234 TIMING_RESIZING_ACCSTART(); // start accumulating resizing time 00235 if (d) { // use d as target dims 00236 if (d->order() != out.order()) { // re-allocate buffer 00237 EDEBUG(this->name() << ": reallocating output from " << out 00238 << " to " << d); 00239 out = idx<T>(*d); 00240 } else if (*d != out.get_idxdim()) { // resize buffer 00241 EDEBUG(this->name() << ": resizing output from " << out << " to " << *d); 00242 out.resize(*d); 00243 } else { 00244 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00245 return false; 00246 } 00247 } else { // use in.x as target dims 00248 if (in.x.order() != out.order()) { // re-allocate buffer 00249 EDEBUG(this->name() << ": reallocating output from " << out 00250 << " to " << in.x.get_idxdim()); 00251 out = idx<T>(in.x.get_idxdim()); 00252 } else if (in.x.get_idxdim() != out.get_idxdim()) { // resize buffer 00253 EDEBUG(this->name() << ": resizing output from " << out << " to " 00254 << in.x.get_idxdim()); 00255 out.resize(in.x.get_idxdim()); 00256 } else { 00257 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00258 return false; 00259 } 00260 } 00261 TIMING_RESIZING_ACCSTOP(); // stop accumulating resizing time 00262 return true; 00263 } 00264 00265 template <typename T, class Tin, class Tout> 00266 fidxdim module_1_1<T,Tin,Tout>::fprop_size(fidxdim &isize) { 00267 return isize; 00268 } 00269 00270 template <typename T, class Tin, class Tout> 00271 fidxdim module_1_1<T,Tin,Tout>::bprop_size(const fidxdim &osize) { 00272 //EDEBUG(this->name() << ": " << osize << " -> same"); 00273 return osize; 00274 } 00275 00276 template <typename T, class Tin, class Tout> 00277 mfidxdim module_1_1<T,Tin,Tout>::fprop_size(mfidxdim &isize) { 00278 mfidxdim osize; 00279 for (uint i = 0; i < isize.size(); ++i) 00280 if (!isize.exists(i)) osize.push_back_empty(); 00281 else osize.push_back(this->fprop_size(isize[i])); 00282 return osize; 00283 } 00284 00285 // template <typename T, class Tin, class Tout> 00286 // mfidxdim module_1_1<T,Tin,Tout>::bprop_size(const mfidxdim &osize) { 00287 // EDEBUG(this->name() << ": " << osize << " -> ..."); 00288 // mfidxdim isize; 00289 // for (mfidxdim::const_iterator i = osize.begin(); i != osize.end(); ++i) { 00290 // if (i.exists()) 00291 // isize.push_back(this->bprop_size(*i)); 00292 // else 00293 // isize.push_back_empty(); 00294 // } 00295 // EDEBUG(this->name() << ": " << osize << " -> " << isize); 00296 // return isize; 00297 // } 00298 00299 template <typename T, class Tin, class Tout> 00300 mfidxdim module_1_1<T,Tin,Tout>::bprop_size(mfidxdim &osize) { 00301 EDEBUG(this->name() << ": " << osize << " b-> ..."); 00302 mfidxdim isize; 00303 for (mfidxdim::iterator i = osize.begin(); i != osize.end(); ++i) { 00304 if (i.exists()) 00305 isize.push_back(this->bprop_size(*i)); 00306 else 00307 isize.push_back_empty(); 00308 } 00309 EDEBUG(this->name() << ": " << osize << " b-> " << isize); 00310 return isize; 00311 } 00312 00313 template <typename T, class Tin, class Tout> 00314 std::string module_1_1<T,Tin,Tout>::pretty(idxdim &isize) { 00315 std::string s; 00316 fidxdim d = isize; 00317 s << " -> " << this->_name.c_str() << " -> " << fprop_size(d); 00318 return s; 00319 } 00320 00321 template <typename T, class Tin, class Tout> 00322 std::string module_1_1<T,Tin,Tout>::pretty(mfidxdim &isize) { 00323 std::string s; 00324 midxdim d = fprop_size(isize); 00325 s << " -> " << this->_name.c_str() << " -> " << d; 00326 return s; 00327 } 00328 00329 template <typename T, class Tin, class Tout> 00330 module_1_1<T,Tin,Tout>* module_1_1<T,Tin,Tout>::copy() { 00331 return this->copy(NULL); 00332 } 00333 00334 template <typename T, class Tin, class Tout> 00335 module_1_1<T,Tin,Tout>* module_1_1<T,Tin,Tout>::copy(parameter<T,Tin> *p) { 00336 eblerror("deep copy not implemented in " << this->name()); 00337 return NULL; 00338 } 00339 00340 template <typename T, class Tin, class Tout> 00341 bool module_1_1<T,Tin,Tout>::optimize_fprop(Tin& in, Tout& out){ 00342 return true; 00343 } 00344 00345 template <typename T, class Tin, class Tout> 00346 bool module_1_1<T,Tin,Tout>::optimize_fprop(mstate<Tin>& in, 00347 mstate<Tout>& out){ 00348 eblerror("memory optimization not implemented for mstates"); 00349 return false; 00350 } 00351 00352 template <typename T, class Tin, class Tout> 00353 void module_1_1<T,Tin,Tout>::load_x(idx<T> &weights) { 00354 err_not_implemented(); } 00355 00356 template <typename T, class Tin, class Tout> 00357 module_1_1<T,Tin,Tout>* module_1_1<T,Tin,Tout>::last_module() { 00358 return this; 00359 } 00360 00361 template <typename T, class Tin, class Tout> 00362 bool module_1_1<T,Tin,Tout>::mstate_input() { 00363 return bmstate_input; 00364 } 00365 00366 template <typename T, class Tin, class Tout> 00367 bool module_1_1<T,Tin,Tout>::mstate_output() { 00368 return bmstate_output; 00369 } 00370 00371 template <typename T, class Tin, class Tout> 00372 uint module_1_1<T,Tin,Tout>::get_ninputs() { 00373 return ninputs; 00374 } 00375 00376 template <typename T, class Tin, class Tout> 00377 uint module_1_1<T,Tin,Tout>::get_noutputs() { 00378 return noutputs; 00379 } 00380 00382 // module_2_1 00383 00384 template <typename T, class Tin1, class Tin2, class Tout> 00385 module_2_1<T,Tin1,Tin2,Tout>::module_2_1(const char *name_) 00386 : module(name_), bresize(true) { 00387 } 00388 00389 template <typename T, class Tin1, class Tin2, class Tout> 00390 module_2_1<T,Tin1,Tin2,Tout>::~module_2_1() { 00391 #ifdef __DEBUG__ 00392 cout << "deleting module_2_1: " << _name << endl; 00393 #endif 00394 } 00395 00397 // generic state methods 00398 00399 template <typename T, class Tin1, class Tin2, class Tout> 00400 void module_2_1<T,Tin1,Tin2,Tout>::fprop(Tin1 &in1, Tin2 &in2, Tout &out) { 00401 err_not_implemented(); } 00402 00403 template <typename T, class Tin1, class Tin2, class Tout> 00404 void module_2_1<T,Tin1,Tin2,Tout>::bprop(Tin1 &in1, Tin2 &in2, Tout &out) { 00405 err_not_implemented(); } 00406 00407 template <typename T, class Tin1, class Tin2, class Tout> 00408 void module_2_1<T,Tin1,Tin2,Tout>::bbprop(Tin1 &in1, Tin2 &in2, Tout &out){ 00409 err_not_implemented(); } 00410 00412 // multi-state methods 00413 00414 template <typename T, class Tin1, class Tin2, class Tout> 00415 void module_2_1<T,Tin1,Tin2,Tout>::fprop(mstate<Tin1> &in1, mstate<Tin2> &in2, 00416 mstate<Tout> &out) { 00417 // check that in/out have at least 1 state and the same number of them. 00418 if (in1.size() == 0 || in2.size() == 0 || out.size() == 0 00419 || in1.size() != out.size() || in2.size() != out.size()) 00420 eblerror("in1, in2 and out don't have at least 1 state or don't have the " 00421 << "same number of states: in1: " << in2.size() 00422 << " in2: " << in2.size() << " out: " << out.size()); 00423 // run regular bbprop on each states 00424 for (uint i = 0; i < in1.size(); ++i) { 00425 Tin1 &fin1 = in1[i]; 00426 Tin2 &fin2 = in2[i]; 00427 Tout &fout = out[i]; 00428 fprop(fin1, fin2, fout); 00429 } 00430 } 00431 00432 template <typename T, class Tin1, class Tin2, class Tout> 00433 void module_2_1<T,Tin1,Tin2,Tout>::bprop(mstate<Tin1> &in1, mstate<Tin2> &in2, 00434 mstate<Tout> &out) { 00435 // check that in/out have at least 1 state and the same number of them. 00436 if (in1.size() == 0 || in2.size() == 0 || out.size() == 0 00437 || in1.size() != out.size() || in2.size() != out.size()) 00438 eblerror("in1, in2 and out don't have at least 1 state or don't have the " 00439 << "same number of states: in1: " << in2.size() 00440 << " in2: " << in2.size() << " out: " << out.size()); 00441 // run regular bbprop on each states 00442 for (uint i = 0; i < in1.size(); ++i) { 00443 Tin1 &bin1 = in1[i]; 00444 Tin2 &bin2 = in2[i]; 00445 Tout &bout = out[i]; 00446 bprop(bin1, bin2, bout); 00447 } 00448 } 00449 00450 template <typename T, class Tin1, class Tin2, class Tout> 00451 void module_2_1<T,Tin1,Tin2,Tout>::bbprop(mstate<Tin1> &in1, 00452 mstate<Tin2> &in2, 00453 mstate<Tout> &out) { 00454 // check that in/out have at least 1 state and the same number of them. 00455 if (in1.size() == 0 || in2.size() == 0 || out.size() == 0 00456 || in1.size() != out.size() || in2.size() != out.size()) 00457 eblerror("in1, in2 and out don't have at least 1 state or don't have the " 00458 << "same number of states: in1: " << in2.size() 00459 << " in2: " << in2.size() << " out: " << out.size()); 00460 // run regular bbprop on each states 00461 for (uint i = 0; i < in1.size(); ++i) { 00462 Tin1 &bin1 = in1[i]; 00463 Tin2 &bin2 = in2[i]; 00464 Tout &bout = out[i]; 00465 bbprop(bin1, bin2, bout); 00466 } 00467 } 00468 00470 00471 template <typename T, class Tin1, class Tin2, class Tout> 00472 void module_2_1<T,Tin1,Tin2,Tout>::forget(forget_param &fp) { 00473 err_not_implemented(); } 00474 00475 template <typename T, class Tin1, class Tin2, class Tout> 00476 void module_2_1<T,Tin1,Tin2,Tout>::normalize() { err_not_implemented(); } 00477 00478 template <typename T, class Tin1, class Tin2, class Tout> 00479 bool module_2_1<T,Tin1,Tin2,Tout>::resize_output(Tin1 &in1, Tin2 &in2, 00480 Tout &out, idxdim *d) { 00481 if (!bresize) return false; // no resizing 00482 if (&in1 == &out) return false; // resize only when in and out are different 00483 if (d) { // use d as target dims 00484 if (d->order() != out.x.order()) { // re-allocate buffer 00485 EDEBUG(this->name() << ": reallocating output from " << out.x 00486 << " to " << d); 00487 out = Tout(*d); 00488 } else if (*d != out.x.get_idxdim()) { // resize buffer 00489 EDEBUG(this->name() << ": resizing output from " << out.x << " to " 00490 << *d); 00491 out.resize(*d); 00492 } else return false; 00493 } else { // use in.x as target dims 00494 if (in1.x.order() != out.x.order()) { // re-allocate buffer 00495 EDEBUG(this->name() << ": reallocating output from " << out.x 00496 << " to " << in1.x.get_idxdim()); 00497 out = Tout(in1.x.get_idxdim()); 00498 } else if (in1.x.get_idxdim() != out.x.get_idxdim()) { // resize buffer 00499 EDEBUG(this->name() << ": resizing output from " << out.x << " to " 00500 << in1.x.get_idxdim()); 00501 out.resize(in1.x.get_idxdim()); 00502 } else return false; 00503 } 00504 if (in2.x.get_idxdim() != in1.x.get_idxdim()) 00505 eblerror("expected same size inputs " << in1.x << " and " << in2.x); 00506 return true; 00507 } 00508 00510 // ebm_1 00511 00512 template <typename T, class Tin, class Ten> 00513 ebm_1<T,Tin,Ten>::ebm_1(const char *n) : module(n) { 00514 } 00515 00516 template <typename T, class Tin, class Ten> 00517 ebm_1<T,Tin,Ten>::~ebm_1() { 00518 } 00519 00520 template <typename T, class Tin, class Ten> 00521 void ebm_1<T,Tin,Ten>::fprop(Tin &in, Ten &energy) { 00522 err_not_implemented(); } 00523 00524 template <typename T, class Tin, class Ten> 00525 void ebm_1<T,Tin,Ten>::bprop(Tin &in, Ten &energy) { 00526 err_not_implemented(); } 00527 00528 template <typename T, class Tin, class Ten> 00529 void ebm_1<T,Tin,Ten>::bbprop(Tin &in, Ten &energy) { 00530 err_not_implemented(); } 00531 00532 template <typename T, class Tin, class Ten> 00533 void ebm_1<T,Tin,Ten>::forget(forget_param &fp) { 00534 err_not_implemented(); } 00535 00536 template <typename T, class Tin, class Ten> 00537 void ebm_1<T,Tin,Ten>::normalize() { err_not_implemented(); } 00538 00540 // ebm_module_1_1 00541 00542 template <typename T, class Tin, class Tout, class Ten> 00543 ebm_module_1_1<T,Tin,Tout,Ten>:: 00544 ebm_module_1_1(module_1_1<T,Tin,Tout> *m, ebm_1<T,Ten> *e, const char *name_) 00545 : module_1_1<T,Tin,Tout>(name_), module(m), ebm(e) { 00546 if (!m) eblerror("expected non-null module"); 00547 if (!e) eblerror("expected non-null ebm"); 00548 energy.dx.set((T)1.0); // d(E)/dE is always 1 00549 energy.ddx.set((T)0.0); // dd(E)/dE is always 0 00550 } 00551 00552 template <typename T, class Tin, class Tout, class Ten> 00553 ebm_module_1_1<T,Tin,Tout,Ten>::~ebm_module_1_1() { 00554 delete module; 00555 delete ebm; 00556 } 00557 00559 // single-state methods 00560 00561 template <typename T, class Tin, class Tout, class Ten> 00562 void ebm_module_1_1<T,Tin,Tout,Ten>::fprop(Tin &in, Tout &out) { 00563 EDEBUG(this->name() << ": " << module->name() << ": in " << in); 00564 module->fprop(in, out); 00565 ebm->fprop(out, energy); 00566 } 00567 00568 template <typename T, class Tin, class Tout, class Ten> 00569 void ebm_module_1_1<T,Tin,Tout,Ten>::bprop(Tin &in, Tout &out) { 00570 EDEBUG(this->name() << ": " << module->name() << ": bprop in " << in); 00571 ebm->bprop(out, energy); 00572 module->bprop(in, out); 00573 } 00574 00575 template <typename T, class Tin, class Tout, class Ten> 00576 void ebm_module_1_1<T,Tin,Tout,Ten>::bbprop(Tin &in, Tout &out) { 00577 ebm->bbprop(out, energy); 00578 module->bbprop(in, out); 00579 } 00580 00581 template <typename T, class Tin, class Tout, class Ten> 00582 void ebm_module_1_1<T,Tin,Tout,Ten>::forget(forget_param_linear &fp) { 00583 module->forget(fp); 00584 } 00585 00586 template <typename T, class Tin, class Tout, class Ten> 00587 Ten& ebm_module_1_1<T,Tin,Tout,Ten>::get_energy() { 00588 return energy; 00589 } 00590 00591 template <typename T, class Tin, class Tout, class Ten> 00592 fidxdim ebm_module_1_1<T,Tin,Tout,Ten>::fprop_size(fidxdim &isize) { 00593 return module->fprop_size(isize); 00594 } 00595 00596 template <typename T, class Tin, class Tout, class Ten> 00597 fidxdim ebm_module_1_1<T,Tin,Tout,Ten>::bprop_size(const fidxdim &osize) { 00598 return module->bprop_size(osize); 00599 } 00600 00601 template <typename T, class Tin, class Tout, class Ten> 00602 std::string ebm_module_1_1<T,Tin,Tout,Ten>::describe() { 00603 std::string desc; 00604 desc << "ebm_module_1_1 " << this->name() << " contains a module_1_1: " 00605 << module->describe() << ", and an ebm1: " << ebm->describe(); 00606 return desc; 00607 } 00608 00610 // ebm_2 00611 00612 template <class Tin1, class Tin2, class Ten> 00613 ebm_2<Tin1,Tin2,Ten>::ebm_2(const char *name_) : module(name_) { 00614 } 00615 00616 template <class Tin1, class Tin2, class Ten> 00617 ebm_2<Tin1,Tin2,Ten>::~ebm_2() { 00618 } 00619 00620 template <class Tin1, class Tin2, class Ten> 00621 void ebm_2<Tin1,Tin2,Ten>::fprop(Tin1 &i1, Tin2 &i2,Ten &energy){ 00622 err_not_implemented(); } 00623 00624 template <class Tin1, class Tin2, class Ten> 00625 void ebm_2<Tin1,Tin2,Ten>::bprop(Tin1 &i1, Tin2 &i2,Ten &energy){ 00626 err_not_implemented(); } 00627 00628 template <class Tin1, class Tin2, class Ten> 00629 void ebm_2<Tin1,Tin2,Ten>::bbprop(Tin1 &i1, Tin2 &i2,Ten &energy) 00630 { err_not_implemented(); } 00631 00632 template <class Tin1, class Tin2, class Ten> 00633 void ebm_2<Tin1,Tin2,Ten>::bprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy) { 00634 err_not_implemented(); } 00635 00636 template <class Tin1, class Tin2, class Ten> 00637 void ebm_2<Tin1,Tin2,Ten>::bprop2_copy(Tin1 &i1, Tin2 &i2, Ten &energy) { 00638 err_not_implemented(); } 00639 00640 template <class Tin1, class Tin2, class Ten> 00641 void ebm_2<Tin1,Tin2,Ten>::bbprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy) { 00642 err_not_implemented(); } 00643 00644 template <class Tin1, class Tin2, class Ten> 00645 void ebm_2<Tin1,Tin2,Ten>::bbprop2_copy(Tin1 &i1, Tin2 &i2, Ten &energy) { 00646 err_not_implemented(); } 00647 00648 template <class Tin1, class Tin2, class Ten> 00649 void ebm_2<Tin1,Tin2,Ten>::forget(forget_param_linear &fp) { 00650 err_not_implemented(); } 00651 00652 template <class Tin1, class Tin2, class Ten> 00653 void ebm_2<Tin1,Tin2,Ten>::normalize() { 00654 err_not_implemented(); } 00655 00656 template <class Tin1, class Tin2, class Ten> 00657 double ebm_2<Tin1,Tin2,Ten>::infer1(Tin1 &i1, Tin2 &i2, Ten &energy, 00658 infer_param &ip) { 00659 err_not_implemented(); return 0; } 00660 00661 template <class Tin1, class Tin2, class Ten> 00662 double ebm_2<Tin1,Tin2,Ten>::infer2(Tin1 &i1, Tin2 &i2, infer_param &ip, 00663 Tin2 *label, Ten *energy) { 00664 err_not_implemented(); return 0; } 00665 00666 template <class Tin1, class Tin2, class Ten> 00667 void ebm_2<Tin1,Tin2,Ten>::infer2_copy(Tin1 &i1, Tin2 &i2, Ten &energy) { 00668 err_not_implemented(); } 00669 00671 // layers 00672 00673 template <typename T, class Tstate> 00674 layers<T,Tstate>::layers(bool oc, const char *name_, 00675 bool is_branch, bool narrow, intg dim, 00676 intg sz, intg offset) 00677 : module_1_1<T,Tstate>(name_), intern_out(NULL), 00678 hi(NULL), ho(NULL), htmp(NULL), 00679 /* parallelism */ 00680 branch(is_branch), intern_h0(NULL), intern_h1(NULL), 00681 // narrowing 00682 branch_narrow(narrow), narrow_dim(dim), narrow_size(sz), 00683 narrow_offset(offset) { 00684 this->own_contents = oc; 00685 msin.push_back(new Tstate(1)); 00686 msout.push_back(new Tstate(1)); 00687 } 00688 00689 // Clean vectors. Module doesn't have ownership of sub-modules 00690 template <typename T, class Tstate> 00691 layers<T,Tstate>::~layers() { 00692 if (this->own_contents) { 00693 for (unsigned int i=0; i < modules.size(); i++) 00694 delete modules[i]; 00695 if (!this->memoptimized) { 00696 for(unsigned int i=0;i < hiddens.size(); i++) 00697 delete hiddens[i]; 00698 } 00699 } 00700 } 00701 00702 template <typename T, class Tstate> 00703 void layers<T,Tstate>:: 00704 add_module(module_1_1<T, Tstate, Tstate>* module) { 00705 // regular addition 00706 modules.push_back(module); 00707 hiddens.push_back(NULL); 00708 // update what type of input/output are expected 00709 this->bmstate_input = modules[0]->mstate_input(); 00710 this->bmstate_output = modules[modules.size() - 1]->mstate_output(); 00711 } 00712 00713 // TODO: fix optimize fprop 00714 00715 // template <typename T, class Tstate> 00716 // bool layers<T,Tstate>::optimize_fprop(Mstate& in, Mstate& out){ 00717 // this->memoptimized = true; 00718 // if (modules.empty()) 00719 // eblerror("trying to fprop through empty layers"); 00720 // // initialize buffers 00721 // hi = ∈ 00722 // ho = &out; 00723 // // parallelism: do not modify input nor output 00724 // if (branch) { 00725 // // create our internal buffers with all dimensions set to 1 00726 // intern_h0 = new Mstate(in); 00727 // intern_h1 = new Mstate(in); 00728 // ho = intern_h0; 00729 // } 00730 // // loop over modules 00731 // for (uint i = 0; i < modules.size(); i++) { 00732 // hiddens[i] = ho; 00733 // // parallelism: for first module, do not allow optim with in buffer 00734 // if (branch && i == 0) { 00735 // hi = intern_h1; // now we use only internal buffers 00736 // swap_buffers(); // swap hi and ho 00737 // } else { 00738 // // call optimization on submodules, and remember if they put 00739 // // the output in ho (swap == true) or not (swap == false). 00740 // bool swap = modules[i]->optimize_fprop(*hi,*ho); 00741 // // if output is truly in ho, swap buffers, otherwise do nothing. 00742 // // if module was a branch, it di 00743 // if (swap) 00744 // swap_buffers(); 00745 // } 00746 // } 00747 // // parallelism: remember which buffer contains the output 00748 // if (branch) { 00749 // intern_out = hiddens[modules.size() - 1]; 00750 // // a branch does not output to current track, so the output for the 00751 // // mother branch is actually the branch's input, which is left in in 00752 // return false; // output is in in 00753 // } 00754 // // tell the outside if the output is in in or out 00755 // if (hiddens[modules.size() - 1] == &out) 00756 // return true; // output is in out 00757 // return false; // output is in in 00758 // } 00759 00760 // fprop ///////////////////////////////////////////////////////////////////// 00761 00762 template <typename T, class Tstate> 00763 void layers<T,Tstate>::fprop(Tstate& in, Tstate& out) { 00764 msin[0] = in; 00765 msout[0] = out; 00766 fprop(msin, msout); 00767 out = msout[0]; 00768 } 00769 00770 template <typename T, class Tstate> 00771 void layers<T,Tstate>::fprop(mstate<Tstate>& in, Tstate& out) { 00772 msout[0] = out; 00773 fprop(in, msout); 00774 out = msout[0]; 00775 } 00776 00777 template <typename T, class Tstate> 00778 void layers<T,Tstate>::fprop(Tstate& in, mstate<Tstate>& out) { 00779 msin[0] = in; 00780 fprop(msin, out); 00781 } 00782 00783 template <typename T, class Tstate> 00784 void layers<T,Tstate>::fprop(mstate<Tstate>& in, mstate<Tstate>& out) { 00785 if (modules.empty() && !branch) 00786 eblerror("trying to fprop through empty layers"); 00787 // initialize buffers 00788 hi = ∈ 00789 ho = &out; 00790 // narrow input data if required by branch 00791 mstate<Tstate> narrowed; 00792 if (branch && branch_narrow) { 00793 eblerror("not implemented"); 00794 // narrowed = hi->narrow(narrow_dim, narrow_size, narrow_offset); 00795 // //EDEBUG("branch narrowing input " << hi->x << " to " << narrowed.x); 00796 // hi = &narrowed; 00797 } 00798 // loop over modules 00799 for(int i = 0; i < (int) modules.size(); i++){ 00800 LOCAL_TIMING_START(); // timing debugging 00801 // if last module, output into out 00802 if (i == (int) modules.size() - 1 && !branch) ho = &out; 00803 else { // not last module, use hidden buffers 00804 ho = (mstate<Tstate>*) hiddens[i]; 00805 // allocate hidden buffer if necessary 00806 if (ho == NULL) { 00807 // allocate mstates with only 1 state. 00808 hiddens[i] = new mstate<Tstate>(in, 1, 1); 00809 ho = (mstate<Tstate>*) hiddens[i]; 00810 } 00811 } 00812 // run module 00813 // modules[i]->fprop(*hi, *ho); 00814 module_1_1<T,Tstate> *mod = modules[i]; 00815 DEBUGMEM_PRETTY("before " << mod->name() << " fprop: "); 00816 if (mod->mstate_input() == mod->mstate_output()) // s-s or ms-ms 00817 mod->fprop(*hi, *ho); 00818 else { // s-ms or ms-s 00819 if (mod->mstate_output()) { // s-ms 00820 Tstate &sin = (*hi)[0]; 00821 mod->fprop(sin, *ho); 00822 } else { // ms-s 00823 if (ho->size() == 0) ho->push_back(new Tstate(1)); 00824 Tstate &sout = (*ho)[0]; 00825 mod->fprop(*hi, sout); 00826 } 00827 } 00828 00829 // keep same input if current module is a branch, otherwise take out as in 00830 bool isbranch = false; 00831 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 00832 ((layers<T,Tstate>*)modules[i])->branch) 00833 isbranch = true; 00834 if (!isbranch) 00835 hi = ho; 00836 if (isbranch && i + 1 == (int) modules.size()) 00837 ho = hi; // if last module is branch, set the input to be the branch out 00838 LOCAL_TIMING_REPORT(mod->name()); // timing debugging 00839 } 00840 if (branch) // remember output buffer (did not output to out) 00841 intern_out = ho; 00842 // remember number of input/outputs 00843 this->ninputs = in.size(); 00844 this->noutputs = out.size(); 00845 } 00846 00847 // bprop ///////////////////////////////////////////////////////////////////// 00848 00849 template <typename T, class Tstate> 00850 void layers<T,Tstate>::bprop(Tstate& in, Tstate& out) { 00851 msin[0] = in; 00852 msout[0] = out; 00853 bprop(msin, msout); 00854 in = msin[0]; 00855 } 00856 00857 template <typename T, class Tstate> 00858 void layers<T,Tstate>::bprop(mstate<Tstate>& in, Tstate& out) { 00859 msout[0] = out; 00860 bprop(in, msout); 00861 } 00862 00863 template <typename T, class Tstate> 00864 void layers<T,Tstate>::bprop(Tstate& in, mstate<Tstate>& out) { 00865 msin[0] = in; 00866 bprop(msin, out); 00867 } 00868 00869 template <typename T, class Tstate> 00870 void layers<T,Tstate>::bprop(mstate<Tstate>& in, mstate<Tstate>& out) { 00871 if (this->memoptimized) 00872 eblerror("cannot bprop while using dual-buffer memory optimization"); 00873 if (modules.empty()) 00874 eblerror("trying to bprop through empty layers"); 00875 // clear hidden states 00876 clear_dx(); 00877 EDEBUG(this->name() << ": in " << in); 00878 // init buffers 00879 hi = &out; 00880 ho = &out; 00881 // last will be manual 00882 for (int i = (int) modules.size() - 1; i >= 0; i--){ 00883 LOCAL_TIMING_START(); // timing debugging 00884 // set input 00885 if (i == 0) hi = ∈ 00886 else hi = hiddens[i - 1]; 00887 // run module 00888 EDEBUG(this->name() << " layers bprop hi " << *hi << " ho " << *ho); 00889 // modules[i]->bprop(*hi, *ho); 00890 module_1_1<T,Tstate> *mod = modules[i]; 00891 if (mod->mstate_input() == mod->mstate_output()) // s-s or ms-ms 00892 mod->bprop(*hi, *ho); 00893 else { // s-ms or ms-s 00894 if (mod->mstate_output()) { // s-ms 00895 Tstate &sin = (*hi)[0]; 00896 mod->bprop(sin, *ho); 00897 } else { // ms-s 00898 Tstate &sout = (*ho)[0]; 00899 mod->bprop(*hi, sout); 00900 } 00901 } 00902 // shift output pointer to input 00903 ho = hi; 00904 LOCAL_TIMING_REPORT(mod->name() << " bprop"); // timing debugging 00905 } 00906 } 00907 00908 // bbprop //////////////////////////////////////////////////////////////////// 00909 00910 template <typename T, class Tstate> 00911 void layers<T,Tstate>::bbprop(Tstate& in, Tstate& out) { 00912 msin[0] = in; 00913 msout[0] = out; 00914 bbprop(msin, msout); 00915 in = msin[0]; 00916 } 00917 00918 template <typename T, class Tstate> 00919 void layers<T,Tstate>::bbprop(mstate<Tstate>& in, Tstate& out) { 00920 msout[0] = out; 00921 bbprop(in, msout); 00922 } 00923 00924 template <typename T, class Tstate> 00925 void layers<T,Tstate>::bbprop(Tstate& in, mstate<Tstate>& out) { 00926 msin[0] = in; 00927 bbprop(msin, out); 00928 } 00929 00930 template <typename T, class Tstate> 00931 void layers<T,Tstate>::bbprop(mstate<Tstate>& in, mstate<Tstate>& out) { 00932 if (this->memoptimized) 00933 eblerror("cannot bbprop while using dual-buffer memory optimization"); 00934 if (modules.empty()) 00935 eblerror("trying to bbprop through empty layers"); 00936 00937 // clear hidden states 00938 // do not clear if we are a branch, it must have been cleared already by 00939 // main branch 00940 if (!branch) 00941 clear_ddx(); 00942 00943 hi = &out; 00944 ho = &out; 00945 00946 if (branch) // we are a branch, use the internal output 00947 ho = intern_out; 00948 00949 // last will be manual 00950 for(int i = (int) modules.size() - 1; i >= 0; i--){ 00951 LOCAL_TIMING_START(); // timing debugging 00952 // set input 00953 if (i == 0) 00954 hi = ∈ 00955 else 00956 hi = hiddens[i-1]; 00957 // if previous module is a branch, take its input as input 00958 if (i > 0 && dynamic_cast<layers<T,Tstate>*>(modules[i - 1]) && 00959 ((layers<T,Tstate>*)modules[i - 1])->branch) { 00960 if (i >= 2) 00961 hi = hiddens[i - 2]; 00962 else // i == 1 00963 hi = ∈ 00964 } 00965 // run module 00966 // modules[i]->bbprop(*hi, *ho); 00967 module_1_1<T,Tstate> *mod = modules[i]; 00968 if (mod->mstate_input() == mod->mstate_output()) // s-s or ms-ms 00969 mod->bbprop(*hi, *ho); 00970 else { // s-ms or ms-s 00971 if (mod->mstate_output()) { // s-ms 00972 Tstate &sin = (*hi)[0]; 00973 mod->bbprop(sin, *ho); 00974 } else { // ms-s 00975 Tstate &sout = (*ho)[0]; 00976 mod->bbprop(*hi, sout); 00977 } 00978 } 00979 00980 00981 // shift output pointer to input 00982 ho = hi; 00983 LOCAL_TIMING_REPORT(mod->name() << " bbprop"); // timing debugging 00984 } 00985 } 00986 00987 // dump_fprop //////////////////////////////////////////////////////////////// 00988 00989 template <typename T, class Tstate> 00990 void layers<T,Tstate>::dump_fprop(Tstate& in, Tstate& out) { 00991 msin[0] = in; 00992 msout[0] = out; 00993 dump_fprop(msin, msout); 00994 out = msout[0]; 00995 } 00996 00997 template <typename T, class Tstate> 00998 void layers<T,Tstate>::dump_fprop(mstate<Tstate>& in, Tstate& out) { 00999 msout[0] = out; 01000 dump_fprop(in, msout); 01001 out = msout[0]; 01002 } 01003 01004 template <typename T, class Tstate> 01005 void layers<T,Tstate>::dump_fprop(Tstate& in, mstate<Tstate>& out) { 01006 msin[0] = in; 01007 dump_fprop(msin, out); 01008 } 01009 01010 template <typename T, class Tstate> 01011 void layers<T,Tstate>::dump_fprop(mstate<Tstate>& in, mstate<Tstate>& out) { 01012 if (modules.empty() && !branch) 01013 eblerror("trying to dump_fprop through empty layers"); 01014 // initialize buffers 01015 hi = ∈ 01016 ho = &out; 01017 // narrow input data if required by branch 01018 mstate<Tstate> narrowed; 01019 if (branch && branch_narrow) { 01020 eblerror("not implemented"); 01021 // narrowed = hi->narrow(narrow_dim, narrow_size, narrow_offset); 01022 // //EDEBUG("branch narrowing input " << hi->x << " to " << narrowed.x); 01023 // hi = &narrowed; 01024 } 01025 // loop over modules 01026 for(int i = 0; i < (int) modules.size(); i++){ 01027 // if last module, output into out 01028 if (i == (int) modules.size() - 1 && !branch) 01029 ho = &out; 01030 else { // not last module, use hidden buffers 01031 ho = (mstate<Tstate>*) hiddens[i]; 01032 // allocate hidden buffer if necessary 01033 if (ho == NULL) { 01034 // allocate mstates with only 1 state. 01035 hiddens[i] = new mstate<Tstate>(in, 1, 1); 01036 ho = (mstate<Tstate>*) hiddens[i]; 01037 } 01038 } 01039 // run module 01040 01041 // modules[i]->dump_fprop(*hi, *ho); 01042 module_1_1<T,Tstate> *mod = modules[i]; 01043 if (mod->mstate_input() == mod->mstate_output()) // s-s or ms-ms 01044 mod->dump_fprop(*hi, *ho); 01045 else { // s-ms or ms-s 01046 if (mod->mstate_output()) { // s-ms 01047 Tstate &sin = (*hi)[0]; 01048 mod->dump_fprop(sin, *ho); 01049 } else { // ms-s 01050 Tstate &sout = (*ho)[0]; 01051 mod->dump_fprop(*hi, sout); 01052 } 01053 } 01054 TIMING1(mod->name()); 01055 01056 // keep same input if current module is a branch, otherwise take out as in 01057 bool isbranch = false; 01058 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01059 ((layers<T,Tstate>*)modules[i])->branch) 01060 isbranch = true; 01061 if (!isbranch) 01062 hi = ho; 01063 if (isbranch && i + 1 == (int) modules.size()) 01064 ho = hi; // if last module is branch, set the input to be the branch out 01065 } 01066 if (branch) // remember output buffer (did not output to out) 01067 intern_out = ho; 01068 // remember number of input/outputs 01069 this->ninputs = in.size(); 01070 this->noutputs = out.size(); 01071 } 01072 01074 01075 template <typename T, class Tstate> 01076 void layers<T,Tstate>::forget(forget_param_linear& fp){ 01077 if (modules.empty() && !branch) 01078 eblerror("trying to forget through empty layers"); 01079 01080 for(unsigned int i=0; i<modules.size(); i++){ 01081 module_1_1<T,Tstate,Tstate> *tt = modules[i]; 01082 tt->forget(fp); 01083 } 01084 } 01085 01086 template <typename T, class Tstate> 01087 void layers<T,Tstate>::normalize(){ 01088 if (modules.empty()) 01089 eblerror("trying to normalize through empty layers"); 01090 01091 for(unsigned int i=0; i<modules.size(); i++){ 01092 modules[i]->normalize(); 01093 } 01094 } 01095 01096 template <typename T, class Tstate> 01097 fidxdim layers<T,Tstate>::fprop_size(fidxdim &isize) { 01098 fidxdim os(isize); 01101 for (unsigned int i = 0; i < modules.size(); i++) { 01102 module_1_1<T,Tstate,Tstate> *tt = modules[i]; 01103 // determine if module is a branch 01104 bool isbranch = false; 01105 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01106 ((layers<T,Tstate>*)modules[i])->branch) 01107 isbranch = true; 01108 // do not go to branches 01109 if (!isbranch) 01110 os = tt->fprop_size(os); 01111 } 01113 isize = bprop_size(os); 01114 return os; 01115 } 01116 01117 template <typename T, class Tstate> 01118 fidxdim layers<T,Tstate>::bprop_size(const fidxdim &osize) { 01119 fidxdim isize(osize); 01121 for (int i = (int) modules.size() - 1; i >= 0; i--) { 01122 module_1_1<T,Tstate,Tstate> *tt = modules[i]; 01123 // determine if module is a branch 01124 bool isbranch = false; 01125 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01126 ((layers<T,Tstate>*)modules[i])->branch) 01127 isbranch = true; 01128 // do not go to branches 01129 if (!isbranch) 01130 isize = tt->bprop_size(isize); 01131 } 01132 return isize; 01133 } 01134 01135 template <typename T, class Tstate> 01136 mfidxdim layers<T,Tstate>::fprop_size(mfidxdim &isize) { 01137 mfidxdim os(isize); 01140 for (unsigned int i = 0; i < modules.size(); i++) { 01141 module_1_1<T,Tstate,Tstate> *tt = modules[i]; 01142 // determine if module is a branch 01143 bool isbranch = false; 01144 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01145 ((layers<T,Tstate>*)modules[i])->branch) 01146 isbranch = true; 01147 // do not go to branches 01148 if (!isbranch) 01149 os = tt->fprop_size(os); 01150 } 01152 isize = bprop_size(os); 01153 this->ninputs = isize.size(); 01154 this->noutputs = os.size(); 01155 return os; 01156 } 01157 01158 template <typename T, class Tstate> 01159 mfidxdim layers<T,Tstate>::bprop_size(mfidxdim &osize) { 01160 mfidxdim isize(osize); 01162 for (int i = (int) modules.size() - 1; i >= 0; i--) { 01163 module_1_1<T,Tstate,Tstate> *tt = modules[i]; 01164 // determine if module is a branch 01165 bool isbranch = false; 01166 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01167 ((layers<T,Tstate>*)modules[i])->branch) 01168 isbranch = true; 01169 // do not go to branches 01170 if (!isbranch) { 01171 //EDEBUG(this->name() << ": layers bprop_size before: " << isize); 01172 isize = tt->bprop_size(isize); 01173 //EDEBUG(this->name() << ": layers bprop_size after: " << isize); 01174 } 01175 } 01176 //EDEBUG(this->name() << ": " << osize << " -> " << isize); 01177 return isize; 01178 } 01179 01180 template <typename T, class Tstate> 01181 layers<T,Tstate>* layers<T,Tstate>::copy() { 01182 layers<T,Tstate> *l2 = new layers<T,Tstate>(true); 01184 int niter = this->modules.size(); 01185 for(int i = 0; i < niter; i++) { 01186 l2->add_module((module_1_1<T,Tstate>*)this->modules[i]->copy()); 01187 if (this->hiddens[i] != NULL) { 01188 l2->hiddens[i] = new mstate<Tstate>(*(this->hiddens[i])); 01189 l2->hiddens[i]->copy(*(l2->hiddens[i])); 01190 } 01191 } 01192 return l2; 01193 } 01194 01195 template <typename T, class Tstate> 01196 void layers<T,Tstate>::swap_buffers() { 01197 htmp = hi; 01198 hi = ho; 01199 ho = htmp; 01200 } 01201 01202 template <typename T, class Tstate> 01203 uint layers<T,Tstate>::size() { 01204 return modules.size(); 01205 } 01206 01207 template <typename T, class Tstate> 01208 std::string layers<T,Tstate>::pretty(idxdim &isize) { 01209 mfidxdim is(isize); 01210 return this->pretty(is); 01211 } 01212 01213 template <typename T, class Tstate> 01214 std::string layers<T,Tstate>::pretty(mfidxdim &isize) { 01215 std::string s; 01216 mfidxdim is(isize); 01219 for (unsigned int i = 0; i < modules.size(); i++) { 01220 module_1_1<T,Tstate> *tt = modules[i]; 01221 // determine if module is a branch 01222 bool isbranch = false; 01223 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01224 ((layers<T,Tstate>*)modules[i])->branch) 01225 isbranch = true; 01226 // do not go to branches 01227 if (!isbranch) { 01228 s << tt->pretty(is); 01229 mfidxdim mis(is); 01230 mis = tt->fprop_size(mis); 01231 is = mis; 01232 } 01233 } 01234 return s; 01235 } 01236 01237 template <typename T, class Tstate> 01238 void layers<T,Tstate>::clear_dx() { 01239 // clear hidden states 01240 for (uint i = 0; i<hiddens.size(); i++){ 01241 if (hiddens[i]) 01242 hiddens[i]->clear_dx(); 01243 } 01244 // clear hidden states of branches 01245 for (uint i = 0; i < modules.size(); ++i) { 01246 // check if this module is a branch 01247 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01248 ((layers<T,Tstate>*)modules[i])->branch) { 01249 // if yes, clear its hidden states 01250 layers<T,Tstate> *branch = (layers<T,Tstate>*) modules[i]; 01251 branch->clear_dx(); 01252 } 01253 } 01254 } 01255 01256 template <typename T, class Tstate> 01257 void layers<T,Tstate>::clear_ddx() { 01258 // clear hidden states 01259 for (uint i = 0; i < hiddens.size(); i++) { 01260 if (hiddens[i]) 01261 hiddens[i]->clear_ddx(); 01262 } 01263 // clear hidden states of branches 01264 for (uint i = 0; i < modules.size(); ++i) { 01265 // check if this module is a branch 01266 if (dynamic_cast<layers<T,Tstate>*>(modules[i]) && 01267 ((layers<T,Tstate>*)modules[i])->branch) { 01268 // if yes, clear its hidden states 01269 layers<T,Tstate> *branch = (layers<T,Tstate>*) modules[i]; 01270 branch->clear_ddx(); 01271 } 01272 } 01273 } 01274 01275 template <typename T, class Tstate> 01276 bool layers<T,Tstate>::is_branch() { 01277 return branch; 01278 } 01279 01280 template <typename T, class Tstate> 01281 module_1_1<T, Tstate, Tstate>* 01282 layers<T,Tstate>::find(const char *name) { 01283 for (uint i = 0; i < modules.size(); ++i) { 01284 module_1_1<T, Tstate, Tstate>* m = modules[i]; 01285 if (!strcmp(name, m->name())) 01286 return m; 01287 } 01288 return NULL; // not found 01289 } 01290 01291 template <typename T, class Tstate> 01292 module_1_1<T, Tstate, Tstate>* 01293 layers<T,Tstate>::last_module() { 01294 if (modules.size() == 0) 01295 eblerror("requires at least 1 module"); 01296 return modules[modules.size() - 1]->last_module(); 01297 } 01298 01299 template <typename T, class Tstate> 01300 std::string layers<T,Tstate>::describe(uint indent) { 01301 std::string desc; 01302 desc << "Module " << this->name() << " contains " 01303 << (int) modules.size() << " modules:\n"; 01304 for(uint i = 0; i < modules.size(); ++i) { 01305 for (uint j = 0; j < indent; ++j) desc << "\t"; 01306 desc << i << ": " << modules[i]->describe(); 01307 if (i != modules.size() - 1) desc << "\n"; 01308 } 01309 return desc; 01310 } 01311 01312 template <typename T, class Tstate> 01313 bool layers<T,Tstate>::mstate_input() { 01314 if (modules.size()) 01315 return modules[0]->mstate_input(); 01316 return this->bmstate_input; 01317 } 01318 01319 template <typename T, class Tstate> 01320 bool layers<T,Tstate>::mstate_output() { 01321 if (modules.size()) 01322 return modules[modules.size() - 1]->mstate_output(); 01323 return this->bmstate_output; 01324 } 01325 01326 template <typename T, class Tstate> 01327 void layers<T,Tstate>::set_output_streams(std::ostream &out, 01328 std::ostream &err) { 01329 for(uint i = 0; i < modules.size(); ++i) 01330 modules[i]->set_output_streams(out, err); 01331 } 01332 01334 // layers_2 01335 01336 template <typename T, class Tin, class Thid, class Tout> 01337 layers_2<T,Tin,Thid,Tout>::layers_2(module_1_1<T,Tin,Thid> &l1, 01338 Thid &h, module_1_1<T,Thid,Tout> &l2) 01339 : layer1(l1), hidden(h), layer2(l2) { 01340 } 01341 01342 // Do nothing. Module doesn't have ownership of sub-modules 01343 template <typename T, class Tin, class Thid, class Tout> 01344 layers_2<T,Tin,Thid,Tout>::~layers_2() { 01345 } 01346 01347 template <typename T, class Tin, class Thid, class Tout> 01348 void layers_2<T,Tin,Thid,Tout>::fprop(Tin &in, Tout &out) { 01349 layer1.fprop(in, hidden); 01350 layer2.fprop(hidden, out); 01351 } 01352 01353 template <typename T, class Tin, class Thid, class Tout> 01354 void layers_2<T,Tin,Thid,Tout>::bprop(Tin &in, Tout &out) { 01355 hidden.clear_dx(); 01356 layer2.bprop(hidden, out); 01357 layer1.bprop(in, hidden); 01358 } 01359 01360 template <typename T, class Tin, class Thid, class Tout> 01361 void layers_2<T,Tin,Thid,Tout>::bbprop(Tin &in, Tout &out) { 01362 hidden.clear_ddx(); 01363 layer2.bbprop(hidden, out); 01364 layer1.bbprop(in, hidden); 01365 } 01366 01367 template <typename T, class Tin, class Thid, class Tout> 01368 void layers_2<T,Tin,Thid,Tout>::forget(forget_param_linear &fp) { 01369 layer1.forget(fp); 01370 layer2.forget(fp); 01371 } 01372 01373 template <typename T, class Tin, class Thid, class Tout> 01374 void layers_2<T,Tin,Thid,Tout>::normalize() { 01375 layer1.normalize(); 01376 layer2.normalize(); 01377 } 01378 01379 template <typename T, class Tin, class Thid, class Tout> 01380 fidxdim layers_2<T,Tin,Thid,Tout>::fprop_size(fidxdim &isize) { 01381 fidxdim os(isize); 01382 os = layer1.fprop_size(os); 01383 os = layer2.fprop_size(os); 01385 isize = bprop_size(os); 01386 return os; 01387 } 01388 01389 template <typename T, class Tin, class Thid, class Tout> 01390 fidxdim layers_2<T,Tin,Thid,Tout>::bprop_size(const fidxdim &osize) { 01391 fidxdim isize(osize); 01392 isize = layer2.bprop_size(isize); 01393 isize = layer1.bprop_size(isize); 01394 return isize; 01395 } 01396 01397 template <typename T, class Tin, class Thid, class Tout> 01398 std::string layers_2<T,Tin,Thid,Tout>::pretty(idxdim &isize) { 01399 std::string s; 01400 idxdim is(isize); 01401 s << layer1.pretty(is); 01402 s << " -> "; 01403 is = layer1.fprop_size(is); 01404 s << layer2.pretty(is); 01405 return s; 01406 } 01407 01409 01410 template <typename T, class Tin, class Thid, class Ten> 01411 fc_ebm1<T,Tin,Thid,Ten>::fc_ebm1(module_1_1<T,Tin,Thid> &fm, 01412 Thid &fo, ebm_1<T,Thid,Ten> &fc) 01413 : fmod(fm), fout(fo), fcost(fc) { 01414 } 01415 01416 template <typename T, class Tin, class Thid, class Ten> 01417 fc_ebm1<T,Tin,Thid,Ten>::~fc_ebm1() {} 01418 01419 template <typename T, class Tin, class Thid, class Ten> 01420 void fc_ebm1<T,Tin,Thid,Ten>::fprop(Tin &in, Ten &energy) { 01421 fmod.fprop(in, fout); 01422 fcost.fprop(fout, energy); 01423 } 01424 01425 template <typename T, class Tin, class Thid, class Ten> 01426 void fc_ebm1<T,Tin,Thid,Ten>::bprop(Tin &in, Ten &energy) { 01427 fout.clear_dx(); 01428 fcost.bprop(fout, energy); 01429 fmod.bprop(in, fout); 01430 } 01431 01432 template <typename T, class Tin, class Thid, class Ten> 01433 void fc_ebm1<T,Tin,Thid,Ten>::bbprop(Tin &in, Ten &energy) { 01434 fout.clear_ddx(); 01435 fcost.bbprop(fout, energy); 01436 fmod.bbprop(in, fout); 01437 } 01438 01439 template <typename T, class Tin, class Thid, class Ten> 01440 void fc_ebm1<T,Tin,Thid,Ten>::forget(forget_param &fp) { 01441 fmod.forget(fp); 01442 fcost.forget(fp); 01443 } 01444 01446 01447 template <typename T, class Tin1, class Tin2, class Ten> 01448 fc_ebm2<T,Tin1,Tin2,Ten>::fc_ebm2(module_1_1<T,Tin1,Tin1> &fm, 01449 Tin1 &fo, 01450 ebm_2<Tin1,Tin2,Ten> &fc) 01451 : fmod(fm), fout(fo), fcost(fc) { 01452 } 01453 01454 template <typename T, class Tin1, class Tin2, class Ten> 01455 fc_ebm2<T,Tin1,Tin2,Ten>::~fc_ebm2() {} 01456 01457 template <typename T, class Tin1, class Tin2, class Ten> 01458 void fc_ebm2<T,Tin1,Tin2,Ten>::fprop(Tin1 &in1, Tin2 &in2, Ten &energy) { 01459 fmod.fprop(in1, fout); 01460 fcost.fprop(fout, in2, energy); 01461 #ifdef __DUMP_STATES__ // used to debug 01462 save_matrix(energy.x, "dump_fc_ebm2_energy.x.mat"); 01463 save_matrix(in1.x, "dump_fc_ebm2_cost_in1.x.mat"); 01464 #endif 01465 } 01466 01467 template <typename T, class Tin1, class Tin2, class Ten> 01468 void fc_ebm2<T,Tin1,Tin2,Ten>::bprop(Tin1 &in1, Tin2 &in2, Ten &energy) { 01469 fout.clear_dx(); 01470 // in2.clear_dx(); // TODO this assumes Tin2 == fstate_idx 01471 fcost.bprop(fout, in2, energy); 01472 fmod.bprop(in1, fout); 01473 } 01474 01475 template <typename T, class Tin1, class Tin2, class Ten> 01476 void fc_ebm2<T,Tin1,Tin2,Ten>::bbprop(Tin1 &in1, Tin2 &in2, Ten &energy){ 01477 fout.clear_ddx(); 01478 // in2.clear_ddx(); // TODO this assumes Tin2 == fstate_idx 01479 fcost.bbprop(fout, in2, energy); 01480 fmod.bbprop(in1, fout); 01481 } 01482 01483 template <typename T, class Tin1, class Tin2, class Ten> 01484 void fc_ebm2<T,Tin1,Tin2,Ten>::forget(forget_param_linear &fp) { 01485 fmod.forget(fp); 01486 fcost.forget(fp); 01487 } 01488 01489 template <typename T, class Tin1, class Tin2, class Ten> 01490 double fc_ebm2<T,Tin1,Tin2,Ten>::infer2(Tin1 &i1, Tin2 &i2, 01491 infer_param &ip, Tin2 *label, 01492 Ten *energy) { 01493 fmod.fprop(i1, fout); // first propagate all the way up 01494 return fcost.infer2(fout, i2, ip, label, energy); //then infer from energies 01495 } 01496 01498 // generic replicable modules classes 01499 01500 // check that orders of input and module are compatible 01501 template <typename T, class Tstate> 01502 void check_replicable_orders(module_1_1<T,Tstate> &m, Tstate& in) { 01503 if (in.x.order() < 0) 01504 eblerror("module_1_1_replicable cannot replicate this module (order -1)"); 01505 if (in.x.order() < m.replicable_order()) 01506 eblerror("input order must be >= to module's operating order, input is " 01507 << in.x << " but module operates with order " 01508 << m.replicable_order()); 01509 if (in.x.order() > MAXDIMS) 01510 eblerror("cannot replicate using more dimensions than MAXDIMS"); 01511 } 01512 01515 template <class Tmodule, class Tstate> 01516 void module_eloop2_fprop(Tmodule &m, Tstate &in, Tstate &out) { 01517 if (m.replicable_order() == in.x.order()) { 01518 m.Tmodule::fprop(in, out); 01519 } else if (m.replicable_order() > in.x.order()) { 01520 eblerror("input order must be >= to module's operating order, input is " 01521 << in.x << " but module operates with order " 01522 << m.replicable_order()); 01523 } else { 01524 state_idx_eloop2(iin, in, Tstate, oout, out, Tstate) { 01525 module_eloop2_fprop<Tmodule,Tstate>(m, (Tstate&) iin, (Tstate&) oout); 01526 } 01527 } 01528 } 01529 01532 template <class Tmodule, class Tstate> 01533 void module_eloop2_bprop(Tmodule &m, Tstate &in, Tstate &out) { 01534 if (m.replicable_order() == in.x.order()) { 01535 m.Tmodule::bprop(in, out); 01536 } else if (m.replicable_order() > in.x.order()) { 01537 eblerror("the order of the input should be greater or equal to module's\ 01538 operating order"); 01539 } else { 01540 state_idx_eloop2(iin, in, Tstate, oout, out, Tstate) { 01541 module_eloop2_bprop<Tmodule,Tstate>(m, (Tstate&) iin, (Tstate&) oout); 01542 } 01543 } 01544 } 01545 01548 template <class Tmodule, class Tstate> 01549 void module_eloop2_bbprop(Tmodule &m, Tstate &in, Tstate &out) { 01550 if (m.replicable_order() == in.x.order()) { 01551 m.Tmodule::bbprop(in, out); 01552 } else if (m.replicable_order() > in.x.order()) { 01553 eblerror("the order of the input should be greater or equal to module's\ 01554 operating order"); 01555 } else { 01556 state_idx_eloop2(iin, in, Tstate, oout, out, Tstate) { 01557 module_eloop2_bbprop<Tmodule,Tstate>(m, (Tstate&) iin, (Tstate&) oout); 01558 } 01559 } 01560 } 01561 01562 template <class Tmodule, typename T, class Tstate> 01563 module_1_1_replicable<Tmodule,T,Tstate>::module_1_1_replicable(Tmodule &m) 01564 : module(m) { 01565 } 01566 01567 template <class Tmodule, typename T, class Tstate> 01568 module_1_1_replicable<Tmodule,T,Tstate>::~module_1_1_replicable() { 01569 } 01570 01571 template <class Tmodule, typename T, class Tstate> 01572 void module_1_1_replicable<Tmodule,T,Tstate>::fprop(Tstate &in, Tstate &out) { 01573 check_replicable_orders(module, in); // check for orders compatibility 01574 module.resize_output(in, out); // resize output 01575 module_eloop2_fprop<Tmodule,Tstate>(module, (Tstate&) in, (Tstate&) out); 01576 } 01577 01578 template <class Tmodule, typename T, class Tstate> 01579 void module_1_1_replicable<Tmodule,T,Tstate>::bprop(Tstate &in, Tstate &out) { 01580 check_replicable_orders(module, in); // check for orders compatibility 01581 module_eloop2_bprop<Tmodule,Tstate>(module, (Tstate&) in, (Tstate&) out); 01582 } 01583 01584 template <class Tmodule, typename T, class Tstate> 01585 void module_1_1_replicable<Tmodule,T,Tstate>::bbprop(Tstate &in, Tstate &out){ 01586 check_replicable_orders(module, in); // check for orders compatibility 01587 module_eloop2_bbprop<Tmodule,Tstate>(module, (Tstate&) in, (Tstate&) out); 01588 } 01589 01591 01592 template <typename T, class Tstate> 01593 narrow_module<T,Tstate>::narrow_module(int dim_, intg size_, intg offset_, 01594 bool narrow_states_) 01595 : module_1_1<T,Tstate>("narrow_module"), dim(dim_), size(size_), 01596 narrow_states(narrow_states_) { 01597 this->bmstate_input = true; // this module takes multi-state inputs 01598 this->bmstate_output = true; // this module takes multi-state outputs 01599 offsets.push_back(offset_); 01600 } 01601 01602 template <typename T, class Tstate> 01603 narrow_module<T,Tstate>:: 01604 narrow_module(int dim_, intg size_, vector<intg> &offsets_, bool states_, 01605 const char *name_) 01606 : module_1_1<T,Tstate>(name_), 01607 dim(dim_), size(size_), offsets(offsets_), narrow_states(states_) { 01608 this->bmstate_input = true; // this module takes multi-state inputs 01609 this->bmstate_output = true; // this module takes multi-state outputs 01610 } 01611 01612 template <typename T, class Tstate> 01613 narrow_module<T,Tstate>::~narrow_module() { 01614 } 01615 01616 template <typename T, class Tstate> 01617 void narrow_module<T,Tstate>::fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01618 // narrow each state of multi-state in 01619 if (narrow_states) { 01620 out.resize(in); 01621 for (uint i = 0; i < in.size(); ++i) 01622 fprop(in[i], out[i]); 01623 } else { // narrow multi-state itself 01624 if (dim == 0) { // narrow on states 01625 out.resize(in, offsets.size() * size); 01626 for (uint o = 0; o < offsets.size(); ++o) { 01627 intg offset = offsets[o]; 01628 if ((intg) in.size() < offset + size) 01629 eblerror("expected at least " << offset + size 01630 << " states in narrow of dimension " 01631 << dim << " at offset " << offset << " to size " << size 01632 << " but found only " << in.size() << " states"); 01633 for (intg i = offset; i < offset + size; ++i) 01634 out[i - offset + o * size] = in[i]; 01635 } 01636 } else eblerror("not implemented"); 01637 } 01638 this->ninputs = in.size(); 01639 this->noutputs = out.size(); 01640 EDEBUG("narrowed " << in << " to " << out); 01641 } 01642 01643 template <typename T, class Tstate> 01644 void narrow_module<T,Tstate>::bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01645 // TODO: assign states back to their input location? 01646 } 01647 01648 template <typename T, class Tstate> 01649 void narrow_module<T,Tstate>::bbprop(mstate<Tstate> &in, mstate<Tstate> &out){ 01650 // TODO: assign states back to their input location? 01651 } 01652 01653 template <typename T, class Tstate> 01654 void narrow_module<T,Tstate>::fprop(Tstate &in, Tstate &out) { 01655 // TODO: handle multiple offsets by copying narrows next to each other 01656 intg offset = offsets[0]; 01657 out = in.narrow(dim, size, offset); 01658 } 01659 01660 template <typename T, class Tstate> 01661 std::string narrow_module<T,Tstate>::describe() { 01662 std::string s; 01663 s << "narrow_module " << this->name() << " narrowing dimension " << dim 01664 << " to size " << size << " starting at offset(s) " << offsets; 01665 return s; 01666 } 01667 01668 template <typename T, class Tstate> 01669 narrow_module<T,Tstate>* narrow_module<T,Tstate>::copy() { 01670 narrow_module<T,Tstate> *l2 = 01671 new narrow_module<T,Tstate>(dim, size, offsets, narrow_states); 01672 return l2; 01673 } 01674 01675 template <typename T, class Tstate> 01676 mfidxdim narrow_module<T,Tstate>::fprop_size(mfidxdim &isize) { 01677 EDEBUG(this->name() << ": " << isize << " f-> ..."); 01678 mfidxdim osize; 01679 if (narrow_states) { eblerror("not implemented"); 01680 } else { 01681 if (dim == 0) { // narrow on states 01682 osize.resize_default(offsets.size() * size); 01683 for (uint o = 0; o < offsets.size(); ++o) { 01684 intg offset = offsets[o]; 01685 if ((intg) isize.size() < offset + size) 01686 eblerror("expected at least " << offset + size 01687 << " states in narrow of dimension " 01688 << dim << " at offset " << offset << " to size " << size 01689 << " but found only " << isize.size() << " states"); 01690 for (intg i = offset; i < offset + size; ++i) 01691 if (isize.exists(i)) osize.set(isize[i], i - offset + o * size); 01692 } 01693 } else eblerror("not implemented"); 01694 } 01695 this->ninputs = isize.size(); 01696 this->noutputs = osize.size(); 01697 EDEBUG(this->name() << ": " << isize << " f-> " << osize); 01698 return osize; 01699 } 01700 01701 template <typename T, class Tstate> 01702 mfidxdim narrow_module<T,Tstate>::bprop_size(mfidxdim &osize) { 01703 EDEBUG(this->name() << ": " << osize << " b-> ..."); 01704 // eblwarn("temporary no bpropsize in narrow"); 01705 // return osize; 01706 mfidxdim isize; 01707 uint offset = offsets[0]; 01708 for (uint i = 0; i < offset; ++i) 01709 isize.push_back_empty(); 01710 isize.push_back(osize); 01711 for (uint i = offset + size; i < this->ninputs; ++i) 01712 isize.push_back_empty(); 01713 EDEBUG(this->name() << ": " << osize << " b-> " << isize); 01714 return isize; 01715 } 01716 01718 01719 template <typename T, class Tstate> 01720 table_module<T,Tstate>::table_module(vector<intg> &tbl, intg tot, 01721 const char *name_) 01722 : module_1_1<T,Tstate>(name_), table(tbl), total(tot) { 01723 this->bmstate_input = true; // this module takes multi-state inputs 01724 this->bmstate_output = true; // this module takes multi-state outputs 01725 } 01726 01727 template <typename T, class Tstate> 01728 table_module<T,Tstate>::~table_module() { 01729 } 01730 01731 template <typename T, class Tstate> 01732 void table_module<T,Tstate>::fprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01733 out.clear(); 01734 for (uint i = 0; i < table.size(); ++i) { 01735 intg k = table[i]; 01736 if (k < 0 || k >= (intg) in.size()) 01737 eblerror("trying to access index " << k << " in inputs " << in); 01738 out.push_back(in[k]); 01739 } 01740 this->ninputs = in.size(); 01741 this->noutputs = out.size(); 01742 EDEBUG(this->name() << ": mapped " << in << " to " << out); 01743 } 01744 01745 template <typename T, class Tstate> 01746 void table_module<T,Tstate>::bprop(mstate<Tstate> &in, mstate<Tstate> &out) { 01747 EDEBUG(this->name() << " bprop: in: " << in); 01748 EDEBUG(this->name() << " bprop: out: " << out); 01749 // TODO: assign states back to their input location? 01750 } 01751 01752 template <typename T, class Tstate> 01753 void table_module<T,Tstate>::bbprop(mstate<Tstate> &in, mstate<Tstate> &out){ 01754 // TODO: assign states back to their input location? 01755 } 01756 01757 template <typename T, class Tstate> 01758 std::string table_module<T,Tstate>::describe() { 01759 std::string s; 01760 s << "table_module " << this->name() << " with input list " << table; 01761 return s; 01762 } 01763 01764 template <typename T, class Tstate> 01765 table_module<T,Tstate>* table_module<T,Tstate>::copy() { 01766 table_module<T,Tstate> *l2 = 01767 new table_module<T,Tstate>(table, total, this->name()); 01768 return l2; 01769 } 01770 01771 template <typename T, class Tstate> 01772 mfidxdim table_module<T,Tstate>::fprop_size(mfidxdim &isize) { 01773 mfidxdim osize; 01774 for (uint i = 0; i < table.size(); ++i) { 01775 intg k = table[i]; 01776 if (k < 0 || k >= (intg) isize.size()) 01777 eblerror("trying to access index " << k << " in inputs " << isize); 01778 osize.push_back(isize[k]); 01779 } 01780 return osize; 01781 } 01782 01783 template <typename T, class Tstate> 01784 mfidxdim table_module<T,Tstate>::bprop_size(mfidxdim &osize) { 01785 mfidxdim isize; 01786 uint n = total; 01787 for (uint i = 0; i < table.size(); ++i) 01788 if (table[i] + 1 > n) n = table[i] + 1; 01789 for (uint i = 0; i < n; ++i) 01790 isize.push_back_empty(); 01791 for (uint i = 0; i < table.size(); ++i) { 01792 intg k = table[i]; 01793 if (osize.exists(i)) isize.set(osize[i], k); 01794 } 01795 EDEBUG(this->name() << ": " << osize << " b-> " << isize); 01796 return isize; 01797 } 01798 01800 // network sizes methods 01801 01802 template <typename T, class Tstate> 01803 idxdim network_mindims(module_1_1<T,Tstate> &m, uint order) { 01804 idxdim d; 01805 for (uint i = 0; i < order; ++i) 01806 d.insert_dim(0, 1); 01807 fidxdim fd = d; 01808 d = m.bprop_size(fd); 01809 return d; 01810 } 01811 01812 } // end namespace ebl