// libeblearn
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_ARCH_H_ 00034 #define EBL_ARCH_H_ 00035 00036 #include "libidx.h" 00037 #include "ebl_defines.h" 00038 #include "ebl_states.h" 00039 00040 #ifndef __NOSTL__ 00041 #include <vector> 00042 #endif 00043 00044 using namespace std; 00045 00046 namespace ebl { 00047 00048 #define check_bstate(ptr) \ 00049 if (!ptr) eblerror("module buffers should be bstate_idx") 00050 #define check_bbstate(ptr) \ 00051 if (!ptr) eblerror("module buffers should be bbstate_idx") 00052 00053 00055 // module 00056 00058 class EXPORT module { 00059 public: 00060 module(const char *name = "module"); 00061 virtual ~module(); 00063 virtual const char* name(); 00065 virtual void set_name(const char *name); 00067 virtual void set_output_streams(std::ostream &out, std::ostream &err); 00069 virtual std::string describe(); 00071 virtual std::string describe(uint indent); 00073 virtual void enable(); 00075 virtual void disable(); 00076 // members ///////////////////////////////////////////////////////////////// 00077 protected: 00078 std::string _name; 00079 std::ostream *mout, *merr; 00080 bool silent; 00081 bool _enabled; 00082 }; 00083 00085 // module_1_1 00086 00088 template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin> 00089 class EXPORT module_1_1 : public module { 00090 public: 00092 module_1_1(const char *name = "module_1_1", bool bresize = true); 
00093 virtual ~module_1_1(); 00094 // generic states methods ////////////////////////////////////////////////// 00095 virtual void fprop(Tin &in, Tout &out); 00096 virtual void bprop(Tin &in, Tout &out); 00097 virtual void bbprop(Tin &in, Tout &out); 00100 virtual void dump_fprop(Tin &in, Tout &out); 00101 // multi-states methods //////////////////////////////////////////////////// 00102 virtual void fprop(mstate<Tin> &in, mstate<Tout> &out); 00103 virtual void bprop(mstate<Tin> &in, mstate<Tout> &out); 00104 virtual void bbprop(mstate<Tin> &in, mstate<Tout> &out); 00107 virtual void dump_fprop(mstate<Tin> &in, mstate<Tout> &out); 00108 // multi-states to single-state methods //////////////////////////////////// 00109 virtual void fprop(mstate<Tin> &in, Tout &out); 00110 virtual void bprop(mstate<Tin> &in, Tout &out); 00111 virtual void bbprop(mstate<Tin> &in, Tout &out); 00114 virtual void dump_fprop(mstate<Tin> &in, Tout &out); 00115 // single-states to multi-state methods //////////////////////////////////// 00116 virtual void fprop(Tin &in, mstate<Tout> &out); 00117 virtual void bprop(Tin &in, mstate<Tout> &out); 00118 virtual void bbprop(Tin &in, mstate<Tout> &out); 00121 virtual void dump_fprop(Tin &in, mstate<Tout> &out); 00123 virtual void forget(forget_param_linear& fp); 00124 virtual void normalize(); 00126 virtual int replicable_order(); 00127 // resizing //////////////////////////////////////////////////////////////// 00130 virtual bool ignored(Tin &in, Tout &out); 00136 virtual bool resize_output(Tin &in, Tout &out, idxdim *d = NULL); 00142 virtual bool resize_output(Tin &in, idx<T> &out, idxdim *d = NULL); 00147 virtual fidxdim fprop_size(fidxdim &isize); 00151 virtual fidxdim bprop_size(const fidxdim &osize); 00156 virtual mfidxdim fprop_size(mfidxdim &isize); 00157 /* //! Returns input dimensions corresponding to multiple output dimensions */ 00158 /* //! 'osize'. Implementation of this method helps automatic scaling of input */ 00159 /* //! 
data but is optional. */ 00160 /* virtual mfidxdim bprop_size(const mfidxdim &osize); */ 00164 virtual mfidxdim bprop_size(mfidxdim &osize); 00168 virtual std::string pretty(idxdim &isize); 00172 virtual std::string pretty(mfidxdim &isize); 00174 virtual module_1_1<T, Tin, Tout>* copy(); 00177 virtual module_1_1<T, Tin, Tout>* copy(parameter<T,Tin> *p); 00182 virtual bool optimize_fprop(Tin &in, Tout &out); 00184 virtual bool optimize_fprop(mstate<Tin> &in, mstate<Tout> &out); 00187 virtual void load_x(idx<T> &weights); 00190 virtual module_1_1<T,Tin,Tout>* last_module(); 00192 virtual bool mstate_input(); 00194 virtual bool mstate_output(); 00196 virtual uint get_ninputs(); 00198 virtual uint get_noutputs(); 00199 00200 // variable members ////////////////////////////////////////////////////////// 00201 public: 00202 // these variables describe internal buffers declared to be displayed 00203 // by external display objects. 00204 std::vector<idx<T> > internals; 00205 std::vector<std::string> internals_str; 00206 protected: 00207 bool bresize; 00208 bool memoptimized; 00209 bool bmstate_input; 00210 bool bmstate_output; 00211 uint ninputs, noutputs; 00212 }; 00213 00215 // module_2_1 00216 00218 template<typename T, class Tin1 = bbstate_idx<T>, class Tin2 = Tin1, 00219 class Tout = Tin1> 00220 class EXPORT module_2_1 : public module { 00221 public: 00222 module_2_1(const char *name = "module_2_1"); 00223 virtual ~module_2_1(); 00225 // generic states methods 00226 virtual void fprop(Tin1 &in1, Tin2 &in2, Tout &out); 00227 virtual void bprop(Tin1 &in1, Tin2 &in2, Tout &out); 00228 virtual void bbprop(Tin1 &in1, Tin2 &in2, Tout &out); 00230 // multi-states methods 00231 virtual void fprop(mstate<Tin1> &in1, mstate<Tin2> &in2, 00232 mstate<Tout> &out); 00233 virtual void bprop(mstate<Tin1> &in1, mstate<Tin2> &in2, 00234 mstate<Tout> &out); 00235 virtual void bbprop(mstate<Tin1> &in1, mstate<Tin2> &in2, 00236 mstate<Tout> &out); 00238 virtual void forget(forget_param 
&fp); 00239 virtual void normalize(); 00245 virtual bool resize_output(Tin1 &in1, Tin2 &in2, Tout &out, 00246 idxdim *d = NULL); 00247 00248 protected: 00249 bool bresize; 00250 }; 00251 00253 template<typename T, class Tin = bbstate_idx<T>, class Ten = Tin> 00254 class ebm_1 : public module { 00255 public: 00256 ebm_1(const char *name = "ebm_1"); 00257 virtual ~ebm_1(); 00258 virtual void fprop(Tin &in, Ten &energy); 00259 virtual void bprop(Tin &in, Ten &energy); 00260 virtual void bbprop(Tin &in, Ten &energy); 00261 virtual void forget(forget_param &fp); 00262 virtual void normalize(); 00263 }; 00264 00265 // ebm_module_1_1 //////////////////////////////////////////////////////////// 00266 00270 template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin, 00271 class Ten = Tin> 00272 class ebm_module_1_1 : public module_1_1<T,Tin,Tout> { 00273 public: 00276 ebm_module_1_1(module_1_1<T,Tin,Tout> *m, ebm_1<T,Ten> *e, 00277 const char *name = "ebm_module_1_1"); 00278 virtual ~ebm_module_1_1(); 00279 virtual void fprop(Tin &in, Tout &out); 00280 virtual void bprop(Tin &in, Tout &out); 00281 virtual void bbprop(Tin &in, Tout &out); 00282 virtual void forget(forget_param_linear &fp); 00284 virtual Ten& get_energy(); 00289 virtual fidxdim fprop_size(fidxdim &isize); 00293 virtual fidxdim bprop_size(const fidxdim &osize); 00295 virtual std::string describe(); 00296 protected: 00297 module_1_1<T,Tin,Tout> *module; 00298 ebm_1<T,Ten> *ebm; 00299 Ten energy; 00300 }; 00301 00303 // ebm_2 00304 00306 template<class Tin1, class Tin2 = Tin1, class Ten = Tin1> 00307 class ebm_2 : public module { 00308 public: 00309 ebm_2(const char *name = "ebm_2"); 00310 virtual ~ebm_2(); 00312 virtual void fprop(Tin1 &i1, Tin2 &i2, Ten &energy); 00314 virtual void bprop(Tin1 &i1, Tin2 &i2, Ten &energy); 00316 virtual void bbprop(Tin1 &i1, Tin2 &i2, Ten &energy); 00317 00318 virtual void bprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy); 00319 virtual void bprop2_copy(Tin1 &i1, Tin2 &i2, Ten 
&energy); 00320 virtual void bbprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy); 00321 virtual void bbprop2_copy(Tin1 &i1, Tin2 &i2, Ten &energy); 00322 virtual void forget(forget_param_linear &fp); 00323 virtual void normalize(); 00325 virtual double infer1(Tin1 &i1, Tin2 &i2, Ten &energy, infer_param &ip); 00328 virtual double infer2(Tin1 &i1, Tin2 &i2, infer_param &ip, 00329 Tin2 *label = NULL, Ten *energy = NULL); 00330 virtual void infer2_copy(Tin1 &i1, Tin2 &i2, Ten &energy); 00331 }; 00332 00334 // layers 00335 00337 template<typename T, class Tstate = bbstate_idx<T> > 00338 class layers : public module_1_1<T, Tstate, Tstate> { 00339 public: 00345 layers(bool oc = true, const char *name = "layers", 00346 bool is_branch = false, bool narrow = false, 00347 intg dim = 0, intg sz = 0, intg offset = 0); 00348 virtual ~layers(); 00351 virtual void add_module(module_1_1<T, Tstate, Tstate>* module); 00356 // TODO: fix optimize fprop 00357 //virtual bool optimize_fprop(Mstate &in, Mstate &out); 00358 // single states methods /////////////////////////////////////////////////// 00359 virtual void fprop(Tstate &in, Tstate &out); 00360 virtual void bprop(Tstate &in, Tstate &out); 00361 virtual void bbprop(Tstate &in, Tstate &out); 00364 virtual void dump_fprop(Tstate &in, Tstate &out); 00365 // multi to single states methods ////////////////////////////////////////// 00366 virtual void fprop(mstate<Tstate> &in, Tstate &out); 00367 virtual void bprop(mstate<Tstate> &in, Tstate &out); 00368 virtual void bbprop(mstate<Tstate> &in, Tstate &out); 00371 virtual void dump_fprop(mstate<Tstate> &in, Tstate &out); 00372 // single to multi states methods ////////////////////////////////////////// 00373 virtual void fprop(Tstate &in, mstate<Tstate> &out); 00374 virtual void bprop(Tstate &in, mstate<Tstate> &out); 00375 virtual void bbprop(Tstate &in, mstate<Tstate> &out); 00378 virtual void dump_fprop(Tstate &in, mstate<Tstate> &out); 00379 // multi to multi methods 
////////////////////////////////////////////////// 00380 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00381 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00382 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00385 virtual void dump_fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00387 virtual void forget(forget_param_linear &fp); 00388 virtual void normalize(); 00394 virtual fidxdim fprop_size(fidxdim &i_size); 00398 virtual fidxdim bprop_size(const fidxdim &o_size); 00403 virtual mfidxdim fprop_size(mfidxdim &isize); 00407 virtual mfidxdim bprop_size(mfidxdim &o_size); 00411 virtual std::string pretty(idxdim &isize); 00415 virtual std::string pretty(mfidxdim &isize); 00417 virtual layers<T,Tstate>* copy(); 00419 virtual void swap_buffers(); 00421 virtual uint size(); 00424 virtual void clear_dx(); 00427 virtual void clear_ddx(); 00429 bool is_branch(); 00431 module_1_1<T, Tstate, Tstate>* find(const char *name); 00434 virtual module_1_1<T,Tstate>* last_module(); 00436 virtual std::string describe(uint indent = 0); 00438 virtual bool mstate_input(); 00440 virtual bool mstate_output(); 00442 virtual void set_output_streams(std::ostream &out, std::ostream &err); 00443 00444 // friends ///////////////////////////////////////////////////////////////// 00445 friend class layers_gui; 00446 00447 // member variables //////////////////////////////////////////////////////// 00448 public: 00449 std::vector<module_1_1<T, Tstate, Tstate>*> modules; 00450 std::vector<mstate<Tstate>*> hiddens; 00451 mstate<Tstate>* intern_out; 00452 protected: 00453 bool own_contents; 00454 mstate<Tstate>* hi; 00455 mstate<Tstate>* ho; 00456 mstate<Tstate>* htmp; 00457 // used for parallelism /////////////////////////////////////////////////// 00458 bool branch; 00459 mstate<Tstate>* intern_h0; 00460 mstate<Tstate>* intern_h1; 00461 bool branch_narrow; 00462 intg narrow_dim; 00463 intg narrow_size; 00464 intg narrow_offset; 00465 mstate<Tstate> msin, 
msout; 00466 }; 00467 00469 // layers_2 00470 00471 template<typename T, class Tin = bbstate_idx<T>, class Thid = Tin, 00472 class Tout = Tin> 00473 class EXPORT layers_2 : public module_1_1<T, Tin, Tout> { 00474 public: 00475 module_1_1<T, Tin, Thid> &layer1; 00476 Thid &hidden; 00477 module_1_1<T, Thid, Tout> &layer2; 00478 00479 layers_2(module_1_1<T, Tin, Thid> &l1, Thid &h, 00480 module_1_1<T, Thid, Tout> &l2); 00481 virtual ~layers_2(); 00482 virtual void fprop(Tin &in, Tout &out); 00483 virtual void bprop(Tin &in, Tout &out); 00484 virtual void bbprop(Tin &in, Tout &out); 00485 virtual void forget(forget_param_linear &fp); 00486 virtual void normalize(); 00492 virtual fidxdim fprop_size(fidxdim &i_size); 00496 virtual fidxdim bprop_size(const fidxdim &o_size); 00500 virtual std::string pretty(idxdim &isize); 00501 }; 00502 00506 template<typename T, class Tin = bbstate_idx<T>, class Thid = Tin, 00507 class Ten = Tin> 00508 class EXPORT fc_ebm1 : public ebm_1<T, Tin, Ten> { 00509 public: 00510 module_1_1<T,Tin,Thid> &fmod; 00511 Thid &fout; 00512 ebm_1<T,Thid,Ten> &fcost; 00513 00514 fc_ebm1(module_1_1<T,Tin,Thid> &fm, Thid &fo, ebm_1<T,Thid,Ten> &fc); 00515 virtual ~fc_ebm1(); 00516 00517 virtual void fprop(Tin &in, Ten &energy); 00518 virtual void bprop(Tin &in, Ten &energy); 00519 virtual void bbprop(Tin &in, Ten &energy); 00520 virtual void forget(forget_param &fp); 00521 }; 00522 00526 template<typename T, class Tin1 = bbstate_idx<T>, class Tin2 = Tin1, 00527 class Ten = Tin1> 00528 class EXPORT fc_ebm2 : public ebm_2<Tin1, Tin2, Ten> { 00529 public: 00530 module_1_1<T, Tin1, Tin1> &fmod; 00531 Tin1 &fout; 00532 ebm_2<Tin1, Tin2, Ten> &fcost; 00533 00534 fc_ebm2(module_1_1<T, Tin1> &fm, Tin1 &fo, ebm_2<Tin1, Tin2, Ten> &fc); 00535 virtual ~fc_ebm2(); 00536 00537 virtual void fprop(Tin1 &in1, Tin2 &in2, Ten &energy); 00538 virtual void bprop(Tin1 &in1, Tin2 &in2, Ten &energy); 00539 virtual void bbprop(Tin1 &in1, Tin2 &in2, Ten &energy); 00540 virtual 
void forget(forget_param_linear &fp); 00541 virtual double infer2(Tin1 &i1, Tin2 &i2, infer_param &ip, 00542 Tin2 *label = NULL, Ten *energy = NULL); 00543 }; 00544 00546 // helper functions 00547 00549 template<typename T, class Tstate> 00550 void check_replicable_orders(module_1_1<T,Tstate> &m, Tstate& in); 00551 00553 // generic replicable module classes 00554 00559 template<class Tmodule, typename T, class Tstate = bbstate_idx<T> > 00560 class EXPORT module_1_1_replicable { 00561 public: 00562 Tmodule &module; 00563 module_1_1_replicable(Tmodule &m); 00564 virtual ~module_1_1_replicable(); 00565 virtual void fprop(Tstate &in, Tstate &out); 00566 virtual void bprop(Tstate &in, Tstate &out); 00567 virtual void bbprop(Tstate &in, Tstate &out); 00568 }; 00569 00583 #define DECLARE_REPLICABLE_MODULE_1_1(replicable_module, base_module, \ 00584 T, Tstate, \ 00585 types_arguments, arguments) \ 00586 template <typename T, class Tstate = bbstate_idx<T> > \ 00587 class EXPORT replicable_module : public base_module<T,Tstate> { \ 00588 public: \ 00589 module_1_1_replicable<base_module<T,Tstate>,T,Tstate> *rep; \ 00590 replicable_module types_arguments : base_module<T,Tstate> arguments { \ 00591 rep = new module_1_1_replicable<base_module<T,Tstate>,T,Tstate>(*this); \ 00592 this->bresize = false; \ 00593 if (this->replicable_order() <= 0) \ 00594 eblerror("this module is not replicable"); } \ 00595 virtual ~replicable_module() { delete rep; } \ 00596 virtual void fprop(Tstate &in, Tstate &out) \ 00597 { rep->fprop(in, out); } \ 00598 virtual void bprop(Tstate &in, Tstate &out) \ 00599 { rep->bprop(in, out); } \ 00600 virtual void bbprop(Tstate &in, Tstate &out) \ 00601 { rep->bbprop(in, out); } \ 00602 } 00603 00607 template<typename T, class Tstate = bbstate_idx<T> > 00608 class EXPORT narrow_module : public module_1_1<T,Tstate,Tstate> { 00609 public: 00614 narrow_module(int d, intg size, intg offset, bool narrow_states = false); 00619 narrow_module(int d, intg size, 
vector<intg> &offsets, 00620 bool narrow_states = false, const char *name = "narrow"); 00622 virtual ~narrow_module(); 00623 // multi-state inputs and outputs ////////////////////////////////////////// 00624 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00625 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00626 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00627 // single-state inputs and outputs ///////////////////////////////////////// 00628 virtual void fprop(Tstate &in, Tstate &out); 00630 virtual std::string describe(); 00632 virtual narrow_module<T,Tstate>* copy(); 00637 virtual mfidxdim fprop_size(mfidxdim &isize); 00641 virtual mfidxdim bprop_size(mfidxdim &osize); 00642 protected: 00643 int dim; 00644 intg size; 00645 vector<intg> offsets; 00646 bool narrow_states; 00647 }; 00648 00650 template<typename T, class Tstate = bbstate_idx<T> > 00651 class EXPORT table_module : public module_1_1<T,Tstate,Tstate> { 00652 public: 00655 table_module(vector<intg> &inputs, intg total, 00656 const char *name = "table_module"); 00658 virtual ~table_module(); 00659 // multi-state inputs and outputs ////////////////////////////////////////// 00660 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00661 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00662 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00664 virtual std::string describe(); 00666 virtual table_module<T,Tstate>* copy(); 00671 virtual mfidxdim fprop_size(mfidxdim &isize); 00675 virtual mfidxdim bprop_size(mfidxdim &osize); 00676 protected: 00677 vector<intg> table; 00678 intg total; 00679 }; 00680 00681 // network sizes methods ///////////////////////////////////////////////////// 00682 00685 template <typename T, class Tstate> 00686 EXPORT idxdim network_mindims(module_1_1<T,Tstate> &m, uint order); 00687 00688 } // namespace ebl { 00689 00690 #include "ebl_arch.hpp" 00691 00692 #endif /* EBL_ARCH_H_ */