libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_MARCH_H_ 00034 #define EBL_MARCH_H_ 00035 00036 #include "ebl_arch.h" 00037 00038 namespace ebl { 00039 00042 template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin> 00043 class s2m_module : virtual public module_1_1<T,Tin,Tout> { 00044 public: 00047 s2m_module(uint nstates, const char *name = "s2m_module"); 00048 virtual ~s2m_module(); 00049 virtual void fprop(Tin &in, mstate<Tout> &out); 00050 virtual void bprop(Tin &in, mstate<Tout> &out); 00051 virtual void bbprop(Tin &in, mstate<Tout> &out); 00053 virtual mfidxdim bprop_size(mfidxdim &osize); 00055 virtual uint nstates(); 00056 protected: 00061 virtual void resize_output(Tin &in, mstate<Tout> &out, idxdim *d = NULL); 00062 // members ///////////////////////////////////////////////////////////////// 00063 protected: 00064 uint _nstates; 00065 }; 00066 00069 template<typename T, class Tin = bbstate_idx<T>, class Tout = Tin> 00070 class m2s_module : virtual public module_1_1<T,Tin,Tout> { 00071 public: 00074 m2s_module(uint nstates, const char *name = "m2s_module"); 00075 virtual ~m2s_module(); 00076 virtual void fprop(mstate<Tin> &in, Tout &out); 00077 virtual void bprop(mstate<Tin> &in, Tout &out); 00078 virtual void bbprop(mstate<Tin> &in, Tout &out); 00080 virtual mfidxdim bprop_size(mfidxdim &osize); 00082 virtual uint nstates(); 00083 // members ///////////////////////////////////////////////////////////////// 00084 protected: 00085 uint _nstates; 00086 }; 00087 00094 template<typename T, class Tstate = bbstate_idx<T> > 00095 class ms_module : public module_1_1<T,Tstate,Tstate> { 00096 public: 00097 // constructors //////////////////////////////////////////////////////////// 00099 ms_module(bool replicate_inputs = false, const char *name = "ms_module"); 00101 ms_module(module_1_1<T,Tstate> *pipe, uint n = 1, 00102 bool replicate_inputs = false, const char *name = "ms_module"); 00105 ms_module(std::vector<module_1_1<T,Tstate>*> &pipes, 00106 bool replicate_inputs = false, const char *name = "ms_module"); 00108 virtual ~ms_module(); 00109 00110 // multi-state inputs and outputs ////////////////////////////////////////// 00111 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00112 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00113 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00116 virtual void dump_fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00118 virtual void forget(forget_param_linear& fp); 00119 00120 // sizes propagations ////////////////////////////////////////////////////// 00126 virtual fidxdim fprop_size(fidxdim &i_size); 00131 virtual mfidxdim fprop_size(mfidxdim &isize); 00135 virtual fidxdim bprop_size(const fidxdim &o_size); 00139 virtual mfidxdim bprop_size(mfidxdim &o_size); 00140 00141 // printing //////////////////////////////////////////////////////////////// 00145 virtual std::string pretty(idxdim &isize); 00149 virtual std::string pretty(mfidxdim &isize); 00151 virtual std::string describe(); 00152 00153 // accessors /////////////////////////////////////////////////////////////// 00155 virtual void set_switch(midxdim &sizes); 00157 virtual void set_switch(intg id); 00159 virtual uint npipes(); 00161 virtual module_1_1<T,Tstate>* get_pipe(uint i); 00164 virtual module_1_1<T,Tstate>* last_module(); 00165 00166 // internal methods //////////////////////////////////////////////////////// 00167 protected: 00169 virtual void init(); 00171 virtual void init_fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00173 virtual void switch_pipes(mstate<Tstate> &in); 00174 00175 // variable members //////////////////////////////////////////////////////// 00176 protected: 00177 std::vector<module_1_1<T,Tstate>*> pipes; 00178 std::vector<module_1_1<T,Tstate>*> used_pipes; 00179 std::vector<uint> pipes_noutputs; 00180 svector<mstate<Tstate> > ins; 00181 svector<mstate<Tstate> > mbuffers; 00182 bool replicate_inputs; 00183 midxdim switches; 00184 bool bindex; 00185 intg switch_id; 00186 00187 // friends ///////////////////////////////////////////////////////////////// 00188 template <typename T1, class Ts, class Tc> 00189 friend EXPORT Tc* arch_find(ms_module<T1,Ts> *m, Tc *c); 00190 template <typename T1, class Ts, class Tc> 00191 friend EXPORT std::vector<Tc*> arch_find_all(ms_module<T1,Ts> *m, Tc *c, 00192 std::vector<Tc*> *); 00193 template <typename T1, class Ts, class Tc> 00194 friend EXPORT ms_module<T1,Ts>* 00195 arch_narrow(ms_module<T1,Ts> *m, Tc *c, bool i, bool *f); 00196 friend class ms_module_gui; 00197 }; 00198 00202 template<typename T, class Tstate = bbstate_idx<T> > 00203 class msc_module : public ms_module<T,Tstate> { 00204 public: 00211 msc_module(std::vector<module_1_1<T,Tstate>*> &pipes, uint nsize = 1, 00212 uint stride = 1, uint nsize2 = 0, 00213 const char *name = "msc_module"); 00215 virtual ~msc_module(); 00222 virtual fidxdim fprop_size(fidxdim &i_size); 00226 virtual fidxdim bprop_size(const fidxdim &o_size); 00230 virtual mfidxdim bprop_size(mfidxdim &o_size); 00232 virtual std::string describe(); 00233 00234 // internal methods //////////////////////////////////////////////////////// 00235 protected: 00237 virtual void init_fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00238 00239 // variable members ////////////////////////////////////////////////////////// 00240 protected: 00241 using ms_module<T,Tstate>::pipes; 00242 using ms_module<T,Tstate>::used_pipes; 00243 using ms_module<T,Tstate>::ins; 00244 using ms_module<T,Tstate>::pipes_noutputs; 00245 uint nsize; 00246 uint stride; 00247 uint nsize2; 00248 }; 00249 00250 // arch find methods ///////////////////////////////////////////////////////// 00251 00254 template <typename T, class Tstate, class Tcast> 00255 EXPORT Tcast* arch_find(module_1_1<T,Tstate> *m, Tcast *c); 00258 template <typename T, class Tstate, class Tcast> 00259 EXPORT Tcast* arch_find(layers<T,Tstate> *m, Tcast *c); 00262 template <typename T, class Tstate, class Tcast> 00263 EXPORT Tcast* arch_find(ms_module<T,Tstate> *m, Tcast *c); 00264 00265 // arch find_all methods ///////////////////////////////////////////////////// 00266 00269 template <typename T, class Tstate, class Tcast> 00270 EXPORT std::vector<Tcast*> arch_find_all(module_1_1<T,Tstate> *m, Tcast *c, 00271 std::vector<Tcast*> *v = NULL); 00274 template <typename T, class Tstate, class Tcast> 00275 EXPORT std::vector<Tcast*> arch_find_all(layers<T,Tstate> *m, Tcast *c, 00276 std::vector<Tcast*> *v = NULL); 00279 template <typename T, class Tstate, class Tcast> 00280 EXPORT std::vector<Tcast*> arch_find_all(ms_module<T,Tstate> *m, Tcast *c, 00281 std::vector<Tcast*> *v = NULL); 00282 00283 // arch narrow methods /////////////////////////////////////////////////////// 00284 00289 template <typename T, class Tstate, class Tcast> 00290 EXPORT module_1_1<T,Tstate>* 00291 arch_narrow(module_1_1<T,Tstate> *m, Tcast *c, bool included = true, 00292 bool *found = NULL); 00297 template <typename T, class Tstate, class Tcast> 00298 EXPORT layers<T,Tstate>* 00299 arch_narrow(layers<T,Tstate> *m, Tcast *c, bool included = true, 00300 bool *found = NULL); 00305 template <typename T, class Tstate, class Tcast> 00306 EXPORT ms_module<T,Tstate>* 00307 arch_narrow(ms_module<T,Tstate> *m, Tcast *c, bool included = true, 00308 bool *found = NULL); 00309 00310 } // namespace ebl { 00311 00312 #include "ebl_march.hpp" 00313 00314 #endif /* EBL_MARCH_H_ */