libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_arch.h
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_ARCH_H_
00034 #define EBL_ARCH_H_
00035 
00036 #include "libidx.h"
00037 #include "ebl_defines.h"
00038 #include "ebl_states.h"
00039 
00040 #ifndef __NOSTL__
00041 #include <vector>
00042 #endif
00043 
00044 using namespace std;
00045 
00046 namespace ebl {
00047 
//! Calls eblerror("module buffers should be bstate_idx") when 'ptr' converts
//! to false (for example a null pointer); does nothing otherwise.
//! The test is written as "if (ptr) {} else ..." for two reasons:
//!  - the whole macro argument is used as the condition, so an argument such
//!    as 'a && b' is no longer mis-parsed as '!a && b';
//!  - the macro's 'if' already owns an 'else', so an 'else' that follows the
//!    macro at the call site binds to the caller's own 'if'.
//! Call sites keep the same form as before: check_bstate(p);
#define check_bstate(ptr) \
  if (ptr) {} else eblerror("module buffers should be bstate_idx")
//! Calls eblerror("module buffers should be bbstate_idx") when 'ptr' converts
//! to false (for example a null pointer); does nothing otherwise.
//! The test is written as "if (ptr) {} else ..." for two reasons:
//!  - the whole macro argument is used as the condition, so an argument such
//!    as 'a && b' is no longer mis-parsed as '!a && b';
//!  - the macro's 'if' already owns an 'else', so an 'else' that follows the
//!    macro at the call site binds to the caller's own 'if'.
//! Call sites keep the same form as before: check_bbstate(p);
#define check_bbstate(ptr) \
  if (ptr) {} else eblerror("module buffers should be bbstate_idx")
00052 
00053 
00055   // module
00056 
//! Non-template base class of the module hierarchy declared in this header:
//! module_1_1, module_2_1, ebm_1 and ebm_2 all derive from it.
//! It carries a name, two output-stream pointers and two boolean flags.
//! Only declarations are visible here; the method bodies live elsewhere.
00058   class EXPORT module {
00059   public:
          //! Constructor; 'name' defaults to "module".
          //! NOTE(review): presumably stored in _name -- confirm in the
          //! implementation file.
00060     module(const char *name = "module");
          //! Virtual destructor, so derived modules can be deleted through a
          //! module pointer.
00061     virtual ~module();
          //! Returns the module name as a C string.
00063     virtual const char* name();
          //! Takes a new name for this module.
00065     virtual void set_name(const char *name);
          //! Takes the two streams 'out' and 'err'.
          //! NOTE(review): presumably their addresses are kept in mout and
          //! merr -- confirm.
00067     virtual void set_output_streams(std::ostream &out, std::ostream &err);
          //! Returns a textual description of this module.
00069     virtual std::string describe();
          //! Same as describe(), with an extra 'indent' argument.
          //! NOTE(review): presumably an indentation level for nested
          //! output -- confirm.
00071     virtual std::string describe(uint indent);
          //! NOTE(review): enable()/disable() presumably set and clear
          //! _enabled -- confirm.
00073     virtual void enable();
00075     virtual void disable();
00076     // members /////////////////////////////////////////////////////////////////
00077   protected:
          //! Name of this module.
00078     std::string _name; 
          //! Pointers to output streams (see set_output_streams()).
00079     std::ostream *mout, *merr; 
          //! NOTE(review): meaning not visible in this header; presumably
          //! suppresses output when true.
00080     bool silent;
          //! NOTE(review): presumably toggled by enable()/disable().
00081     bool _enabled;
00082   };
00083 
00085   // module_1_1
00086 
//! Module template with one input and one output. 'Tin' is the input state
//! type (default bbstate_idx<T>) and 'Tout' the output state type (default
//! Tin). Each propagation method is declared for all four combinations of
//! single-state (Tin / Tout) and multi-state (mstate<Tin> / mstate<Tout>)
//! arguments.
//! NOTE(review): by naming convention fprop / bprop / bbprop are presumably
//! the forward pass, the gradient back-propagation and the second-derivative
//! back-propagation; the bodies are not visible in this header, so confirm
//! against ebl_arch.hpp (included at the bottom of this file).
00088   template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin>
00089     class EXPORT module_1_1 : public module {
00090   public:
          //! Constructor. 'name' defaults to "module_1_1" and 'bresize' to
          //! true.
          //! NOTE(review): 'bresize' presumably initializes the protected
          //! member of the same name -- confirm.
00092     module_1_1(const char *name = "module_1_1", bool bresize = true);
00093     virtual ~module_1_1();
00094     // generic states methods //////////////////////////////////////////////////
          //! Single-state input 'in', single-state output 'out'.
00095     virtual void fprop(Tin &in, Tout &out);
00096     virtual void bprop(Tin &in, Tout &out);
00097     virtual void bbprop(Tin &in, Tout &out);
          //! NOTE(review): presumably an fprop variant that also dumps
          //! data -- what is dumped and where is not visible here.
00100     virtual void dump_fprop(Tin &in, Tout &out);
00101     // multi-states methods ////////////////////////////////////////////////////
          //! Multi-state input and multi-state output.
00102     virtual void fprop(mstate<Tin> &in, mstate<Tout> &out);
00103     virtual void bprop(mstate<Tin> &in, mstate<Tout> &out);
00104     virtual void bbprop(mstate<Tin> &in, mstate<Tout> &out);
00107     virtual void dump_fprop(mstate<Tin> &in, mstate<Tout> &out);
00108     // multi-states to single-state methods ////////////////////////////////////
          //! Multi-state input, single-state output.
00109     virtual void fprop(mstate<Tin> &in, Tout &out);
00110     virtual void bprop(mstate<Tin> &in, Tout &out);
00111     virtual void bbprop(mstate<Tin> &in, Tout &out);
00114     virtual void dump_fprop(mstate<Tin> &in, Tout &out);
00115     // single-states to multi-state methods ////////////////////////////////////
          //! Single-state input, multi-state output.
00116     virtual void fprop(Tin &in, mstate<Tout> &out);
00117     virtual void bprop(Tin &in, mstate<Tout> &out);
00118     virtual void bbprop(Tin &in, mstate<Tout> &out);
00121     virtual void dump_fprop(Tin &in, mstate<Tout> &out);
          //! NOTE(review): presumably (re)initializes the module according
          //! to 'fp' -- confirm.
00123     virtual void forget(forget_param_linear& fp);
          //! NOTE(review): what is normalized is not visible in this header.
00124     virtual void normalize();
          //! Returns an int; the DECLARE_REPLICABLE_MODULE_1_1 macro below
          //! treats a value <= 0 as "this module is not replicable".
00126     virtual int  replicable_order();
00127     // resizing ////////////////////////////////////////////////////////////////
          //! Returns a bool for the pair (in, out).
          //! NOTE(review): presumably true when this module should be
          //! skipped for that pair -- confirm.
00130     virtual bool ignored(Tin &in, Tout &out);
          //! Output-resizing hooks; 'd' is an optional idxdim pointer that
          //! defaults to NULL. The meaning of the bool result is not visible
          //! here.
00136     virtual bool resize_output(Tin &in, Tout &out, idxdim *d = NULL);
00142     virtual bool resize_output(Tin &in, idx<T> &out, idxdim *d = NULL);
          //! Size helpers on a single set of dimensions.
          //! NOTE(review): presumably fprop_size maps an input size to the
          //! matching output size and bprop_size does the reverse -- confirm.
00147     virtual fidxdim fprop_size(fidxdim &isize);
00151     virtual fidxdim bprop_size(const fidxdim &osize);
          //! Same helper on multiple sets of dimensions.
00156     virtual mfidxdim fprop_size(mfidxdim &isize);
00157     /* //! Returns input dimensions corresponding to multiple output dimensions */
00158     /* //! 'osize'. Implementation of this method helps automatic scaling of input */
00159     /* //! data but is optional. */
00160     /* virtual mfidxdim bprop_size(const mfidxdim &osize); */
          //! Non-const-reference version of the method disabled just above.
00164     virtual mfidxdim bprop_size(mfidxdim &osize);
          //! Return a string built from the given input size(s).
00168     virtual std::string pretty(idxdim &isize);
00172     virtual std::string pretty(mfidxdim &isize);
          //! Returns a pointer to a module_1_1; ownership of the returned
          //! object is not expressed by the signature -- check the callers.
00174     virtual module_1_1<T, Tin, Tout>* copy();
          //! Same as copy(), taking a parameter<T,Tin> pointer 'p'.
00177     virtual module_1_1<T, Tin, Tout>* copy(parameter<T,Tin> *p);
          //! NOTE(review): presumably related to the 'memoptimized' flag
          //! below; the meaning of the bool result is not visible here.
00182     virtual bool optimize_fprop(Tin &in, Tout &out);
00184     virtual bool optimize_fprop(mstate<Tin> &in, mstate<Tout> &out);
          //! Takes an idx<T> named 'weights'.
          //! NOTE(review): presumably loads it as the module's weights --
          //! confirm.
00187     virtual void load_x(idx<T> &weights);
          //! Returns a module_1_1 pointer.
          //! NOTE(review): presumably 'this' for a plain module, the last
          //! contained module for containers -- confirm.
00190     virtual module_1_1<T,Tin,Tout>* last_module();
          //! NOTE(review): the four accessors below presumably return
          //! bmstate_input, bmstate_output, ninputs and noutputs -- confirm.
00192     virtual bool mstate_input();
00194     virtual bool mstate_output();
00196     virtual uint get_ninputs();
00198     virtual uint get_noutputs();
00199 
00200   // variable members //////////////////////////////////////////////////////////
00201   public:
00202     // these variables describe internal buffers declared to be displayed
00203     // by external display objects.
00204     std::vector<idx<T> >     internals; 
00205     std::vector<std::string>    internals_str;  
00206   protected:
          //! Resize flag; the DECLARE_REPLICABLE_MODULE_1_1 macro below sets
          //! it to false in the classes it generates.
00207     bool                        bresize; 
00208     bool                        memoptimized; 
00209     bool                        bmstate_input;  
00210     bool                        bmstate_output; 
00211     uint                        ninputs, noutputs; 
00212   };
00213 
00215   // module_2_1
00216 
//! Module template with two inputs and one output. 'Tin1' is the first input
//! state type (default bbstate_idx<T>); 'Tin2' and 'Tout' default to Tin1.
//! Propagation methods are declared for all-single-state and
//! all-multi-state arguments only.
//! NOTE(review): fprop / bprop / bbprop are presumably forward, gradient and
//! second-derivative propagation; bodies are not visible in this header.
00218   template<typename T, class Tin1 = bbstate_idx<T>, class Tin2 = Tin1,
00219     class Tout = Tin1>
00220     class EXPORT module_2_1 : public module {
00221   public:
          //! Constructor; 'name' defaults to "module_2_1".
00222     module_2_1(const char *name = "module_2_1");
00223     virtual ~module_2_1();
00225     // generic states methods
00226     virtual void fprop(Tin1 &in1, Tin2 &in2, Tout &out);
00227     virtual void bprop(Tin1 &in1, Tin2 &in2, Tout &out);
00228     virtual void bbprop(Tin1 &in1, Tin2 &in2, Tout &out);
00230     // multi-states methods
00231     virtual void fprop(mstate<Tin1> &in1, mstate<Tin2> &in2,
00232                        mstate<Tout> &out);
00233     virtual void bprop(mstate<Tin1> &in1, mstate<Tin2> &in2,
00234                        mstate<Tout> &out);
00235     virtual void bbprop(mstate<Tin1> &in1, mstate<Tin2> &in2,
00236                         mstate<Tout> &out);
          //! Takes a forget_param (module_1_1::forget takes a
          //! forget_param_linear instead).
00238     virtual void forget(forget_param &fp);
00239     virtual void normalize();
          //! Output-resizing hook; 'd' is an optional idxdim pointer that
          //! defaults to NULL. The meaning of the bool result is not visible
          //! here.
00245     virtual bool resize_output(Tin1 &in1, Tin2 &in2, Tout &out,
00246                                idxdim *d = NULL);
00247 
00248   protected:
          //! NOTE(review): presumably controls whether outputs are resized;
          //! the constructor takes no matching argument, so its initial
          //! value must be checked in the implementation.
00249     bool         bresize;       
00250   };
00251 
//! Module template with a single input 'in' of type Tin (default
//! bbstate_idx<T>) and an 'energy' argument of type Ten (default Tin).
//! It is not declared with EXPORT, unlike module_1_1 and module_2_1.
//! NOTE(review): "ebm" presumably stands for energy-based model, with
//! 'energy' being the output of fprop -- confirm in the implementation.
00253   template<typename T, class Tin = bbstate_idx<T>, class Ten = Tin>
00254     class ebm_1 : public module {
00255   public:
          //! Constructor; 'name' defaults to "ebm_1".
00256     ebm_1(const char *name = "ebm_1");
00257     virtual ~ebm_1();
00258     virtual void fprop(Tin &in, Ten &energy);
00259     virtual void bprop(Tin &in, Ten &energy);
00260     virtual void bbprop(Tin &in, Ten &energy);
          //! Takes a forget_param (not a forget_param_linear).
00261     virtual void forget(forget_param &fp);
00262     virtual void normalize();
00263   };
00264 
00265   // ebm_module_1_1 ////////////////////////////////////////////////////////////
00266 
//! A module_1_1<T,Tin,Tout> that holds a pointer to another module_1_1, a
//! pointer to an ebm_1 and an energy buffer of type Ten.
//! NOTE(review): presumably fprop runs the wrapped module and then feeds its
//! output to the ebm_1, leaving the result in 'energy' -- confirm in the
//! implementation.
00270   template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin,
00271     class Ten = Tin>
00272     class ebm_module_1_1 : public module_1_1<T,Tin,Tout> {
00273   public:
          //! Constructor taking the module 'm', the energy module 'e' and a
          //! name (default "ebm_module_1_1"). Given ebm_1's default template
          //! arguments, ebm_1<T,Ten> is ebm_1<T,Ten,Ten>.
          //! NOTE(review): whether this object takes ownership of 'm' and
          //! 'e' is not visible here -- check the destructor.
00276     ebm_module_1_1(module_1_1<T,Tin,Tout> *m, ebm_1<T,Ten> *e,
00277                    const char *name = "ebm_module_1_1");
00278     virtual ~ebm_module_1_1();
00279     virtual void fprop(Tin &in, Tout &out);
00280     virtual void bprop(Tin &in, Tout &out);
00281     virtual void bbprop(Tin &in, Tout &out);
00282     virtual void forget(forget_param_linear &fp);
          //! Returns a reference to a Ten.
          //! NOTE(review): presumably the 'energy' member -- confirm.
00284     virtual Ten& get_energy();
00289     virtual fidxdim fprop_size(fidxdim &isize);
00293     virtual fidxdim bprop_size(const fidxdim &osize);
00295     virtual std::string describe();
00296   protected:
          //! Pointer to a module_1_1. Inside this class the member name
          //! 'module' hides the class name ebl::module.
00297     module_1_1<T,Tin,Tout>      *module;
          //! Pointer to an ebm_1.
00298     ebm_1<T,Ten>                *ebm;
          //! Energy buffer held by value.
00299     Ten                          energy;
00300   };
00301 
00303   // ebm_2
00304 
//! Module template with two inputs 'i1' (Tin1) and 'i2' (Tin2, default Tin1)
//! and an 'energy' argument of type Ten (default Tin1). Unlike the other
//! templates in this header it has no scalar type parameter T and no default
//! for Tin1.
//! NOTE(review): the semantics of the *_copy and infer* methods are not
//! visible in this header; the notes below only restate the signatures.
00306   template<class Tin1, class Tin2 = Tin1, class Ten = Tin1>
00307     class ebm_2 : public module {
00308   public:
          //! Constructor; 'name' defaults to "ebm_2".
00309     ebm_2(const char *name = "ebm_2");
00310     virtual ~ebm_2();
00312     virtual void fprop(Tin1 &i1, Tin2 &i2, Ten &energy);
00314     virtual void bprop(Tin1 &i1, Tin2 &i2, Ten &energy);
00316     virtual void bbprop(Tin1 &i1, Tin2 &i2, Ten &energy);
00317 
          //! Variants of bprop / bbprop suffixed 1_copy and 2_copy.
          //! NOTE(review): the digit presumably selects which of the two
          //! inputs is concerned -- confirm.
00318     virtual void bprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy);
00319     virtual void bprop2_copy(Tin1 &i1, Tin2 &i2, Ten &energy);
00320     virtual void bbprop1_copy(Tin1 &i1, Tin2 &i2, Ten &energy);
00321     virtual void bbprop2_copy(Tin1 &i1, Tin2 &i2, Ten &energy);
00322     virtual void forget(forget_param_linear &fp);
00323     virtual void normalize();
          //! Returns a double. Note the parameter order differs from infer2:
          //! here 'energy' comes before 'ip' and is mandatory.
00325     virtual double infer1(Tin1 &i1, Tin2 &i2, Ten &energy, infer_param &ip);
          //! Returns a double. 'label' and 'energy' are optional pointers
          //! that default to NULL.
00328     virtual double infer2(Tin1 &i1, Tin2 &i2, infer_param &ip,
00329                           Tin2 *label = NULL, Ten *energy = NULL);
00330     virtual void infer2_copy(Tin1 &i1, Tin2 &i2, Ten &energy);
00331   };
00332 
00334   // layers
00335 
//! Container module: a module_1_1<T,Tstate,Tstate> that holds a vector of
//! module_1_1<T,Tstate,Tstate> pointers ('modules') plus intermediate
//! mstate buffers ('hiddens' and the hi / ho / htmp pointers).
//! NOTE(review): presumably the contained modules are run in sequence, each
//! one feeding the next through the hidden buffers -- confirm in the
//! implementation.
00337   template<typename T, class Tstate = bbstate_idx<T> >
00338     class layers : public module_1_1<T, Tstate, Tstate> {
00339   public:
          //! Constructor. Defaults: oc = true, name = "layers",
          //! is_branch = false, narrow = false, dim = sz = offset = 0.
          //! NOTE(review): the arguments presumably initialize own_contents,
          //! branch, branch_narrow, narrow_dim, narrow_size and
          //! narrow_offset respectively -- confirm.
00345     layers(bool oc = true, const char *name = "layers",
00346            bool is_branch = false, bool narrow = false,
00347            intg dim = 0, intg sz = 0, intg offset = 0);
00348     virtual ~layers();
          //! Takes a pointer to a module.
          //! NOTE(review): presumably appended to 'modules'; whether the
          //! container then owns it presumably depends on own_contents.
00351     virtual void add_module(module_1_1<T, Tstate, Tstate>* module);
00356   // TODO: fix optimize fprop
00357   //virtual bool optimize_fprop(Mstate &in, Mstate &out);
00358     // single states methods ///////////////////////////////////////////////////
00359     virtual void fprop(Tstate &in, Tstate &out);
00360     virtual void bprop(Tstate &in, Tstate &out);
00361     virtual void bbprop(Tstate &in, Tstate &out);
00364     virtual void dump_fprop(Tstate &in, Tstate &out);
00365     // multi to single states methods //////////////////////////////////////////
00366     virtual void fprop(mstate<Tstate> &in, Tstate &out);
00367     virtual void bprop(mstate<Tstate> &in, Tstate &out);
00368     virtual void bbprop(mstate<Tstate> &in, Tstate &out);
00371     virtual void dump_fprop(mstate<Tstate> &in, Tstate &out);
00372     // single to multi states methods //////////////////////////////////////////
00373     virtual void fprop(Tstate &in, mstate<Tstate> &out);
00374     virtual void bprop(Tstate &in, mstate<Tstate> &out);
00375     virtual void bbprop(Tstate &in, mstate<Tstate> &out);
00378     virtual void dump_fprop(Tstate &in, mstate<Tstate> &out);
00379     // multi to multi methods //////////////////////////////////////////////////
00380     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00381     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00382     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00385     virtual void dump_fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00387     virtual void forget(forget_param_linear &fp);
00388     virtual void normalize();
          //! Size helpers with the same signatures as in module_1_1.
00394     virtual fidxdim fprop_size(fidxdim &i_size);
00398     virtual fidxdim bprop_size(const fidxdim &o_size);
00403     virtual mfidxdim fprop_size(mfidxdim &isize);
00407     virtual mfidxdim bprop_size(mfidxdim &o_size);
00411     virtual std::string pretty(idxdim &isize);
00415     virtual std::string pretty(mfidxdim &isize);
          //! Returns a layers pointer (covariant with module_1_1::copy());
          //! ownership of the result is not expressed by the signature.
00417     virtual layers<T,Tstate>* copy();
          //! NOTE(review): presumably exchanges the hi / ho buffers --
          //! confirm.
00419     virtual void swap_buffers();
          //! Returns a uint.
          //! NOTE(review): presumably the number of contained modules.
00421     virtual uint size();
          //! NOTE(review): presumably clear gradient (dx) and second
          //! derivative (ddx) buffers of the hidden states -- confirm.
00424     virtual void clear_dx();
00427     virtual void clear_ddx();
          //! Non-virtual.
          //! NOTE(review): presumably returns the 'branch' flag.
00429     bool is_branch();
          //! Non-virtual lookup by name; returns a module pointer.
          //! NOTE(review): the result when no module matches is not visible
          //! here (presumably NULL).
00431     module_1_1<T, Tstate, Tstate>* find(const char *name);
00434     virtual module_1_1<T,Tstate>* last_module();
          //! Declaring describe(uint) here hides module::describe() for
          //! calls made on a layers object; thanks to the default argument,
          //! describe() with no argument resolves to this overload with
          //! indent = 0.
00436     virtual std::string describe(uint indent = 0);
00438     virtual bool mstate_input();
00440     virtual bool mstate_output();
00442     virtual void set_output_streams(std::ostream &out, std::ostream &err);
00443 
00444     // friends /////////////////////////////////////////////////////////////////
          //! layers_gui is given access to the protected members below.
00445     friend class layers_gui;
00446 
00447     // member variables ////////////////////////////////////////////////////////
00448   public:
          //! Pointers to the contained modules.
00449     std::vector<module_1_1<T, Tstate, Tstate>*> modules;
          //! Pointers to multi-state buffers.
          //! NOTE(review): presumably one between each pair of consecutive
          //! modules -- confirm.
00450     std::vector<mstate<Tstate>*>                        hiddens;
00451     mstate<Tstate>* intern_out; 
00452   protected:
          //! NOTE(review): presumably true when this container is
          //! responsible for deleting its modules -- check the destructor.
00453     bool own_contents;
00454     mstate<Tstate>* hi; 
00455     mstate<Tstate>* ho; 
00456     mstate<Tstate>* htmp; 
00457     // used for parallelism ///////////////////////////////////////////////////
00458     bool branch; 
00459     mstate<Tstate>* intern_h0; 
00460     mstate<Tstate>* intern_h1; 
00461     bool branch_narrow; 
00462     intg narrow_dim; 
00463     intg narrow_size; 
00464     intg narrow_offset; 
          //! Multi-state buffers held by value.
00465     mstate<Tstate> msin, msout; 
00466   };
00467 
00469   // layers_2
00470 
//! A module_1_1<T,Tin,Tout> made of two modules and one intermediate state,
//! all held by reference: layer1 (Tin -> Thid), hidden (Thid) and layer2
//! (Thid -> Tout). Because the members are references, the referenced
//! objects must outlive this object.
//! NOTE(review): presumably fprop runs layer1 into 'hidden' and then layer2
//! from 'hidden' -- confirm in the implementation.
00471   template<typename T, class Tin = bbstate_idx<T>, class Thid = Tin,
00472     class Tout = Tin>
00473     class EXPORT layers_2 : public module_1_1<T, Tin, Tout> {
00474   public:
00475     module_1_1<T, Tin, Thid>    &layer1;
00476     Thid                        &hidden;
00477     module_1_1<T, Thid, Tout>   &layer2;
00478 
          //! Constructor taking the first module 'l1', the hidden state 'h'
          //! and the second module 'l2', in the same order and with the same
          //! types as the three reference members above.
00479     layers_2(module_1_1<T, Tin, Thid> &l1, Thid &h,
00480              module_1_1<T, Thid, Tout> &l2);
00481     virtual ~layers_2();
00482     virtual void fprop(Tin &in, Tout &out);
00483     virtual void bprop(Tin &in, Tout &out);
00484     virtual void bbprop(Tin &in, Tout &out);
00485     virtual void forget(forget_param_linear &fp);
00486     virtual void normalize();
          //! Size helpers with the same signatures as in module_1_1.
00492     virtual fidxdim fprop_size(fidxdim &i_size);
00496     virtual fidxdim bprop_size(const fidxdim &o_size);
00500     virtual std::string pretty(idxdim &isize);
00501   };
00502 
//! An ebm_1<T,Tin,Ten> made of a module (fmod: Tin -> Thid), an intermediate
//! state (fout: Thid) and an ebm_1 taking that state (fcost), all held by
//! reference; the referenced objects must outlive this object.
//! NOTE(review): presumably fprop runs fmod into 'fout' and then fcost from
//! 'fout' to 'energy' -- confirm in the implementation.
00506   template<typename T, class Tin = bbstate_idx<T>, class Thid = Tin,
00507     class Ten = Tin>
00508     class EXPORT fc_ebm1 : public ebm_1<T, Tin, Ten> {
00509   public:
00510     module_1_1<T,Tin,Thid>      &fmod;
00511     Thid                        &fout;
00512     ebm_1<T,Thid,Ten>           &fcost;
00513 
          //! Constructor taking 'fm', 'fo' and 'fc', in the same order and
          //! with the same types as the three reference members above.
00514     fc_ebm1(module_1_1<T,Tin,Thid> &fm, Thid &fo, ebm_1<T,Thid,Ten> &fc);
00515     virtual ~fc_ebm1();
00516 
00517     virtual void fprop(Tin &in, Ten &energy);
00518     virtual void bprop(Tin &in, Ten &energy);
00519     virtual void bbprop(Tin &in, Ten &energy);
          //! Takes a forget_param, matching ebm_1::forget.
00520     virtual void forget(forget_param &fp);
00521   };
00522 
//! An ebm_2<Tin1,Tin2,Ten> made of a module (fmod: Tin1 -> Tin1), an
//! intermediate state (fout: Tin1) and an ebm_2 (fcost), all held by
//! reference; the referenced objects must outlive this object.
//! NOTE(review): presumably fprop runs fmod on 'in1' into 'fout' and then
//! fcost on ('fout', 'in2') -- confirm in the implementation.
00526   template<typename T, class Tin1 = bbstate_idx<T>, class Tin2 = Tin1,
00527     class Ten = Tin1>
00528     class EXPORT fc_ebm2 : public ebm_2<Tin1, Tin2, Ten> {
00529   public:
00530     module_1_1<T, Tin1, Tin1>   &fmod;
00531     Tin1                        &fout;
00532     ebm_2<Tin1, Tin2, Ten>      &fcost;
00533 
          //! Constructor taking 'fm', 'fo' and 'fc'. The type of 'fm',
          //! module_1_1<T,Tin1>, is the same type as fmod's because
          //! module_1_1's third template argument defaults to the second.
00534     fc_ebm2(module_1_1<T, Tin1> &fm, Tin1 &fo, ebm_2<Tin1, Tin2, Ten> &fc);
00535     virtual ~fc_ebm2();
00536 
00537     virtual void fprop(Tin1 &in1, Tin2 &in2, Ten &energy);
00538     virtual void bprop(Tin1 &in1, Tin2 &in2, Ten &energy);
00539     virtual void bbprop(Tin1 &in1, Tin2 &in2, Ten &energy);
00540     virtual void forget(forget_param_linear &fp);
          //! Same signature and default arguments as ebm_2::infer2.
00541     virtual double infer2(Tin1 &i1, Tin2 &i2, infer_param &ip,
00542                           Tin2 *label = NULL, Ten *energy = NULL);
00543   };
00544 
00546   // helper functions
00547 
//! Free function template taking a module 'm' and an input state 'in'.
//! Both template arguments must be deducible or given explicitly (no
//! defaults).
//! NOTE(review): presumably compares m.replicable_order() with the order of
//! 'in' and raises an error on mismatch -- the body is not visible in this
//! header.
00549   template<typename T, class Tstate>
00550     void check_replicable_orders(module_1_1<T,Tstate> &m, Tstate& in);
00551 
00553   // generic replicable module classes
00554 
//! Helper holding a reference to a module of type Tmodule and exposing
//! fprop / bprop / bbprop on single states. It does not derive from module.
//! The DECLARE_REPLICABLE_MODULE_1_1 macro below creates one instance per
//! generated class and forwards that class's fprop / bprop / bbprop to it.
//! The referenced module must outlive this object.
//! NOTE(review): presumably these methods apply 'module' repeatedly over the
//! extra dimensions of 'in' -- confirm in the implementation.
00559   template<class Tmodule, typename T, class Tstate = bbstate_idx<T> >
00560     class EXPORT module_1_1_replicable {
00561   public:
          //! The wrapped module, held by reference.
00562     Tmodule &module;
          //! Constructor taking the module 'm' to wrap.
00563     module_1_1_replicable(Tmodule &m);
00564     virtual ~module_1_1_replicable();
00565     virtual void fprop(Tstate &in, Tstate &out);
00566     virtual void bprop(Tstate &in, Tstate &out);
00567     virtual void bbprop(Tstate &in, Tstate &out);
00568   };
00569 
//! Declares a class template 'replicable_module' deriving from
//! base_module<T,Tstate>.
//! Macro arguments:
//!  - replicable_module: name of the class to generate;
//!  - base_module: class template to derive from;
//!  - T, Tstate: names used for the generated template parameters;
//!  - types_arguments: parenthesized constructor parameter list;
//!  - arguments: parenthesized argument list forwarded to the base
//!    constructor.
//! The generated constructor allocates a module_1_1_replicable wrapping
//! *this, sets bresize to false, and calls eblerror("this module is not
//! replicable") when replicable_order() <= 0. The generated fprop / bprop /
//! bbprop forward to that wrapper, and the destructor deletes it.
//! The expansion ends with '}' only, so the use site must supply the
//! terminating ';'.
//! NOTE(review): the generated class owns 'rep' through new/delete but
//! declares no copy constructor or assignment operator, so copying an
//! instance would leave two objects deleting the same pointer -- check that
//! instances are never copied.
//! (Comments are kept outside the macro body: a '//' comment on a continued
//! line would swallow the lines that follow it.)
00583 #define DECLARE_REPLICABLE_MODULE_1_1(replicable_module, base_module,   \
00584                                       T, Tstate,                        \
00585                                       types_arguments, arguments)       \
00586   template <typename T, class Tstate = bbstate_idx<T> >                 \
00587     class EXPORT replicable_module : public base_module<T,Tstate> {             \
00588   public:                                                               \
00589     module_1_1_replicable<base_module<T,Tstate>,T,Tstate> *rep;         \
00590     replicable_module types_arguments : base_module<T,Tstate> arguments { \
00591       rep = new module_1_1_replicable<base_module<T,Tstate>,T,Tstate>(*this); \
00592       this->bresize = false;                                            \
00593       if (this->replicable_order() <= 0)                                \
00594         eblerror("this module is not replicable"); }                    \
00595     virtual ~replicable_module() { delete rep; }                        \
00596     virtual void fprop(Tstate &in, Tstate &out)                         \
00597     { rep->fprop(in, out); }                                            \
00598     virtual void bprop(Tstate &in, Tstate &out)                         \
00599     { rep->bprop(in, out); }                                            \
00600     virtual void bbprop(Tstate &in, Tstate &out)                        \
00601     { rep->bbprop(in, out); }                                           \
00602   }
00603 
//! A module_1_1<T,Tstate,Tstate> configured with a dimension index, a size
//! and one or more offsets.
//! NOTE(review): presumably outputs a narrowed view of its input, i.e.
//! 'size' elements of dimension 'dim' starting at each offset -- confirm in
//! the implementation.
//! The unqualified 'vector' used below relies on the 'using namespace std;'
//! directive near the top of this header.
00607   template<typename T, class Tstate = bbstate_idx<T> >
00608     class EXPORT narrow_module : public module_1_1<T,Tstate,Tstate> {
00609   public:
          //! Constructor with a single offset. 'narrow_states' defaults to
          //! false. This overload takes no name argument.
00614     narrow_module(int d, intg size, intg offset, bool narrow_states = false);
          //! Constructor with several offsets. 'narrow_states' defaults to
          //! false and 'name' to "narrow".
00619     narrow_module(int d, intg size, vector<intg> &offsets,
00620                   bool narrow_states = false, const char *name = "narrow");
00622     virtual ~narrow_module();
00623     // multi-state inputs and outputs //////////////////////////////////////////
00624     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00625     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00626     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00627     // single-state inputs and outputs /////////////////////////////////////////
          //! Only fprop is declared for single states in this class; the
          //! single-state bprop / bbprop are not redeclared here.
00628     virtual void fprop(Tstate &in, Tstate &out);
00630     virtual std::string describe();
          //! Returns a narrow_module pointer (covariant with
          //! module_1_1::copy()); ownership of the result is not expressed
          //! by the signature.
00632     virtual narrow_module<T,Tstate>* copy();
00637     virtual mfidxdim fprop_size(mfidxdim &isize);
00641     virtual mfidxdim bprop_size(mfidxdim &osize);
00642   protected:
          //! NOTE(review): the members below presumably hold the
          //! constructor arguments d, size, offset(s) and narrow_states.
00643     int dim; 
00644     intg size; 
00645     vector<intg> offsets; 
00646     bool narrow_states; 
00647   };
00648 
//! A module_1_1<T,Tstate,Tstate> configured with a vector of intg and a
//! total; only the multi-state propagation methods are redeclared here.
//! NOTE(review): presumably 'table' selects or maps entries of the input
//! multi-state to the output multi-state -- confirm in the implementation.
//! The unqualified 'vector' used below relies on the 'using namespace std;'
//! directive near the top of this header.
00650   template<typename T, class Tstate = bbstate_idx<T> >
00651     class EXPORT table_module : public module_1_1<T,Tstate,Tstate> {
00652   public:
          //! Constructor taking 'inputs', 'total' and a name (default
          //! "table_module").
          //! NOTE(review): 'inputs' presumably initializes 'table' and
          //! 'total' the member of the same name -- confirm.
00655     table_module(vector<intg> &inputs, intg total,
00656                  const char *name = "table_module");
00658     virtual ~table_module();
00659     // multi-state inputs and outputs //////////////////////////////////////////
00660     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00661     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00662     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00664     virtual std::string describe();
          //! Returns a table_module pointer (covariant with
          //! module_1_1::copy()); ownership of the result is not expressed
          //! by the signature.
00666     virtual table_module<T,Tstate>* copy();
00671     virtual mfidxdim fprop_size(mfidxdim &isize);
00675     virtual mfidxdim bprop_size(mfidxdim &osize);
00676   protected:
00677     vector<intg> table; 
00678     intg total; 
00679   };
00680 
00681   // network sizes methods /////////////////////////////////////////////////////
00682 
//! Free function template taking a module 'm' and an 'order', and returning
//! an idxdim.
//! NOTE(review): judging by the name, presumably the minimum input
//! dimensions accepted by 'm' for inputs of the given order -- the body is
//! not visible in this header.
00685   template <typename T, class Tstate>
00686     EXPORT idxdim network_mindims(module_1_1<T,Tstate> &m, uint order);
00687 
00688 } // namespace ebl
00689 
00690 #include "ebl_arch.hpp"
00691 
00692 #endif /* EBL_ARCH_H_ */