libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_march.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_MARCH_H_
00034 #define EBL_MARCH_H_
00035 
00036 #include "ebl_arch.h"
00037 
00038 namespace ebl {
00039 
00042   template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin>
00043     class s2m_module : virtual public module_1_1<T,Tin,Tout> {
00044   public:
00047     s2m_module(uint nstates, const char *name = "s2m_module");
00048     virtual ~s2m_module();
00049     virtual void fprop(Tin &in, mstate<Tout> &out);
00050     virtual void bprop(Tin &in, mstate<Tout> &out);
00051     virtual void bbprop(Tin &in, mstate<Tout> &out);
00053     virtual mfidxdim bprop_size(mfidxdim &osize);
00055     virtual uint nstates();
00056   protected:
00061     virtual void resize_output(Tin &in, mstate<Tout> &out, idxdim *d = NULL);
00062     // members /////////////////////////////////////////////////////////////////
00063   protected:
00064     uint _nstates;
00065   };
00066 
00069   template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin>
00070     class m2s_module : virtual public module_1_1<T,Tin,Tout> {
00071   public:
00074     m2s_module(uint nstates, const char *name = "m2s_module");
00075     virtual ~m2s_module();
00076     virtual void fprop(mstate<Tin> &in, Tout &out);
00077     virtual void bprop(mstate<Tin> &in, Tout &out);
00078     virtual void bbprop(mstate<Tin> &in, Tout &out);
00080     virtual mfidxdim bprop_size(mfidxdim &osize);
00082     virtual uint nstates();
00083     // members /////////////////////////////////////////////////////////////////
00084   protected:
00085     uint _nstates;
00086   };
00087 
00094   template<typename T, class Tstate = bbstate_idx<T> >
00095     class ms_module : public module_1_1<T,Tstate,Tstate> {
00096   public:
00097     // constructors ////////////////////////////////////////////////////////////
00099     ms_module(bool replicate_inputs = false, const char *name = "ms_module");
00101     ms_module(module_1_1<T,Tstate> *pipe, uint n = 1,
00102               bool replicate_inputs = false, const char *name = "ms_module");
00105     ms_module(std::vector<module_1_1<T,Tstate>*> &pipes,
00106               bool replicate_inputs = false, const char *name = "ms_module");
00108     virtual ~ms_module();
00109 
00110     // multi-state inputs and outputs //////////////////////////////////////////
00111     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00112     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00113     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00116     virtual void dump_fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00118     virtual void forget(forget_param_linear& fp);
00119 
00120     // sizes propagations //////////////////////////////////////////////////////
00126     virtual fidxdim fprop_size(fidxdim &i_size);
00131     virtual mfidxdim fprop_size(mfidxdim &isize);
00135     virtual fidxdim bprop_size(const fidxdim &o_size);
00139     virtual mfidxdim bprop_size(mfidxdim &o_size);
00140 
00141     // printing ////////////////////////////////////////////////////////////////
00145     virtual std::string pretty(idxdim &isize);
00149     virtual std::string pretty(mfidxdim &isize);
00151     virtual std::string describe();
00152 
00153     // accessors ///////////////////////////////////////////////////////////////
00155     virtual void set_switch(midxdim &sizes);
00157     virtual void set_switch(intg id);
00159     virtual uint npipes();
00161     virtual module_1_1<T,Tstate>* get_pipe(uint i);
00164     virtual module_1_1<T,Tstate>* last_module();
00165 
00166     // internal methods ////////////////////////////////////////////////////////
00167   protected:
00169     virtual void init();
00171     virtual void init_fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00173     virtual void switch_pipes(mstate<Tstate> &in);
00174 
00175     // variable members ////////////////////////////////////////////////////////
00176   protected:
00177     std::vector<module_1_1<T,Tstate>*> pipes; 
00178     std::vector<module_1_1<T,Tstate>*> used_pipes; 
00179     std::vector<uint> pipes_noutputs; 
00180     svector<mstate<Tstate> > ins; 
00181     svector<mstate<Tstate> > mbuffers; 
00182     bool replicate_inputs; 
00183     midxdim switches; 
00184     bool bindex; 
00185     intg switch_id; 
00186 
00187     // friends /////////////////////////////////////////////////////////////////
00188     template <typename T1, class Ts, class Tc>
00189       friend EXPORT Tc* arch_find(ms_module<T1,Ts> *m, Tc *c);
00190     template <typename T1, class Ts, class Tc>
00191     friend EXPORT std::vector<Tc*> arch_find_all(ms_module<T1,Ts> *m, Tc *c,
00192                                                  std::vector<Tc*> *);
00193     template <typename T1, class Ts, class Tc>
00194     friend EXPORT ms_module<T1,Ts>*
00195       arch_narrow(ms_module<T1,Ts> *m, Tc *c, bool i, bool *f);
00196     friend class ms_module_gui;
00197   };
00198 
00202   template<typename T, class Tstate = bbstate_idx<T> >
00203     class msc_module : public ms_module<T,Tstate> {
00204   public:
00211     msc_module(std::vector<module_1_1<T,Tstate>*> &pipes, uint nsize = 1,
00212                uint stride = 1, uint nsize2 = 0,
00213                const char *name = "msc_module");
00215     virtual ~msc_module();
00222     virtual fidxdim fprop_size(fidxdim &i_size);
00226     virtual fidxdim bprop_size(const fidxdim &o_size);
00230     virtual mfidxdim bprop_size(mfidxdim &o_size);
00232     virtual std::string describe();
00233 
00234     // internal methods ////////////////////////////////////////////////////////
00235   protected:
00237     virtual void init_fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00238 
00239   // variable members //////////////////////////////////////////////////////////
00240   protected:
00241     using ms_module<T,Tstate>::pipes;
00242     using ms_module<T,Tstate>::used_pipes;
00243     using ms_module<T,Tstate>::ins;
00244     using ms_module<T,Tstate>::pipes_noutputs;
00245     uint nsize; 
00246     uint stride; 
00247     uint nsize2; 
00248   };
00249 
00250   // arch find methods /////////////////////////////////////////////////////////
00251 
00254   template <typename T, class Tstate, class Tcast>
00255     EXPORT Tcast* arch_find(module_1_1<T,Tstate> *m, Tcast *c);
00258   template <typename T, class Tstate, class Tcast>
00259     EXPORT Tcast* arch_find(layers<T,Tstate> *m, Tcast *c);
00262   template <typename T, class Tstate, class Tcast>
00263     EXPORT Tcast* arch_find(ms_module<T,Tstate> *m, Tcast *c);
00264 
00265   // arch find_all methods /////////////////////////////////////////////////////
00266 
00269   template <typename T, class Tstate, class Tcast>
00270     EXPORT std::vector<Tcast*> arch_find_all(module_1_1<T,Tstate> *m, Tcast *c,
00271                                              std::vector<Tcast*> *v = NULL);
00274   template <typename T, class Tstate, class Tcast>
00275     EXPORT std::vector<Tcast*> arch_find_all(layers<T,Tstate> *m, Tcast *c,
00276                                              std::vector<Tcast*> *v = NULL);
00279   template <typename T, class Tstate, class Tcast>
00280     EXPORT std::vector<Tcast*> arch_find_all(ms_module<T,Tstate> *m, Tcast *c,
00281                                              std::vector<Tcast*> *v = NULL);
00282 
00283   // arch narrow methods ///////////////////////////////////////////////////////
00284 
00289   template <typename T, class Tstate, class Tcast>
00290     EXPORT module_1_1<T,Tstate>*
00291     arch_narrow(module_1_1<T,Tstate> *m, Tcast *c, bool included = true,
00292                 bool *found = NULL);
00297   template <typename T, class Tstate, class Tcast>
00298     EXPORT layers<T,Tstate>*
00299     arch_narrow(layers<T,Tstate> *m, Tcast *c, bool included = true,
00300                 bool *found = NULL);
00305   template <typename T, class Tstate, class Tcast>
00306     EXPORT ms_module<T,Tstate>*
00307     arch_narrow(ms_module<T,Tstate> *m, Tcast *c, bool included = true,
00308                 bool *found = NULL);
00309 
00310 } // namespace ebl {
00311 
00312 #include "ebl_march.hpp"
00313 
00314 #endif /* EBL_MARCH_H_ */