00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef EBLARCH_H_
00033 #define EBLARCH_H_
00034
00035 #include "Defines.h"
00036 #include "Idx.h"
00037 #include "Blas.h"
00038 #include "EblStates.h"
00039
00040 #ifdef __GUI__
00041 #include "libidxgui.h"
00042 #endif
00043
00044 namespace ebl {
00045
00047
00048
00050 template<class Tin, class Tout> class module_1_1 {
00051 public:
00052 bool bResize;
00053 module_1_1() { bResize = true; }
00054 virtual ~module_1_1() {}
00055 virtual void fprop(Tin &in, Tout &out);
00056 virtual void bprop(Tin &in, Tout &out);
00057 virtual void bbprop(Tin &in, Tout &out);
00058 virtual void forget(forget_param_linear& fp);
00059 virtual void normalize();
00061 virtual int replicable_order();
00062 virtual void resize_output(Tin &in, Tout &out);
00064 virtual void display_fprop(Tin &in, Tout &out, unsigned int &h0,
00065 unsigned int &w0, double zoom,
00066 bool show_out = false);
00067 };
00068
00070 template<class Tin1, class Tin2, class Tout> class module_2_1 {
00071 public:
00072 virtual ~module_2_1() {};
00073 virtual void fprop(Tin1 &in1, Tin2 &in2, Tout &out);
00074 virtual void bprop(Tin1 &in1, Tin2 &in2, Tout &out);
00075 virtual void bbprop(Tin1 &in1, Tin2 &in2, Tout &out);
00076 virtual void forget(forget_param &fp);
00077 virtual void normalize();
00078 };
00079
00082 template<class Tin> class ebm_1 {
00083 public:
00084 virtual ~ebm_1() {};
00085 virtual void fprop(Tin &in, state_idx &energy);
00086 virtual void bprop(Tin &in, state_idx &energy);
00087 virtual void bbprop(Tin &in, state_idx &energy);
00088 virtual void forget(forget_param &fp);
00089 virtual void normalize();
00090 };
00091
00094 template<class Tin1, class Tin2> class ebm_2 {
00095 public:
00096 virtual ~ebm_2() {};
00098 virtual void fprop(Tin1 &i1, Tin2 &i2, state_idx &energy);
00100 virtual void bprop(Tin1 &i1, Tin2 &i2, state_idx &energy);
00102 virtual void bbprop(Tin1 &i1, Tin2 &i2, state_idx &energy);
00103
00104 virtual void bprop1_copy(Tin1 &i1, Tin2 &i2, state_idx &energy);
00105 virtual void bprop2_copy(Tin1 &i1, Tin2 &i2, state_idx &energy);
00106 virtual void bbprop1_copy(Tin1 &i1, Tin2 &i2, state_idx &energy);
00107 virtual void bbprop2_copy(Tin1 &i1, Tin2 &i2, state_idx &energy);
00108 virtual void forget(forget_param_linear &fp);
00109 virtual void normalize();
00110
00112 virtual double infer1(Tin1 &i1, Tin2 &i2, state_idx &energy,
00113 infer_param &ip);
00115 virtual double infer2(Tin1 &i1, Tin2 &i2, state_idx &energy,
00116 infer_param &ip);
00117 };
00118
00120
00121
00122 template<class Tin, class Thid, class Tout>
00123 class layers_2: public module_1_1<Tin, Tout> {
00124 public:
00125 module_1_1<Tin, Thid> &layer1;
00126 Thid &hidden;
00127 module_1_1<Thid, Tout> &layer2;
00128
00129 layers_2(module_1_1<Tin, Thid> &l1, Thid &h, module_1_1<Thid, Tout> &l2);
00130 virtual ~layers_2();
00131 virtual void fprop(Tin &in, Tout &out);
00132 virtual void bprop(Tin &in, Tout &out);
00133 virtual void bbprop(Tin &in, Tout &out);
00134 virtual void forget(forget_param &fp);
00135 virtual void normalize();
00136 };
00137
00138 template<class T> class layers_n: public module_1_1<T, T> {
00139 public:
00140 std::vector<module_1_1<T, T>*> *modules;
00141 std::vector<T*> *hiddens;
00142
00143 layers_n();
00144 layers_n(bool oc);
00145 virtual ~layers_n();
00146 void addModule(module_1_1 <T, T>* module, T* hidden);
00147 void addLastModule(module_1_1 <T, T>* module);
00148 virtual void fprop(T &in, T &out);
00149 virtual void bprop(T &in, T &out);
00150 virtual void bbprop(T &in, T &out);
00151 virtual void forget(forget_param_linear &fp);
00152 virtual void normalize();
00154 void display_fprop(T &in, T &out, unsigned int &h0, unsigned int &w0,
00155 double zoom, bool show_out = true);
00156 private:
00157 bool own_contents;
00158 };
00159
00163 template<class Tin, class Thid> class fc_ebm1: public ebm_1<Tin> {
00164 public:
00165 module_1_1<Tin, Thid> &fmod;
00166 Thid &fout;
00167 ebm_1<Thid> &fcost;
00168
00169 fc_ebm1(module_1_1<Tin, Thid> &fm, Thid &fo, ebm_1<Thid> &fc);
00170 virtual ~fc_ebm1();
00171
00172 virtual void fprop(Tin &in, state_idx &energy);
00173 virtual void bprop(Tin &in, state_idx &energy);
00174 virtual void bbprop(Tin &in, state_idx &energy);
00175 virtual void forget(forget_param &fp);
00176 };
00177
00181 template<class Tin1, class Tin2, class Thid>
00182 class fc_ebm2: public ebm_2<Tin1, Tin2> {
00183 public:
00184 module_1_1<Tin1, Thid> &fmod;
00185 Thid &fout;
00186 ebm_2<Thid, Tin2> &fcost;
00187
00188 fc_ebm2(module_1_1<Tin1, Thid> &fm, Thid &fo, ebm_2<Thid, Tin2> &fc);
00189 virtual ~fc_ebm2();
00190
00191 virtual void fprop(Tin1 &in1, Tin2 &in2, state_idx &energy);
00192 virtual void bprop(Tin1 &in1, Tin2 &in2, state_idx &energy);
00193 virtual void bbprop(Tin1 &in1, Tin2 &in2, state_idx &energy);
00194 virtual void forget(forget_param_linear &fp);
00195 virtual double infer2(Tin1 &i1, Tin2 &i2, state_idx &energy,
00196 infer_param &ip);
00197 virtual void display_fprop(Tin1 &i1, Tin2 &i2, state_idx &energy,
00198 unsigned int &h0, unsigned int &w0,
00199 double zoom, bool show_out = false);
00200 };
00201
00203
00204
00206 void check_replicable_orders(module_1_1<state_idx, state_idx> &m,
00207 state_idx& in);
00208
00210
00211
00216 template<class T> class module_1_1_replicable {
00217 public:
00218 T &module;
00219 module_1_1_replicable(T &m);
00220 virtual ~module_1_1_replicable();
00221 virtual void fprop(state_idx &in, state_idx &out);
00222 virtual void bprop(state_idx &in, state_idx &out);
00223 virtual void bbprop(state_idx &in, state_idx &out);
00224 };
00225
//! Declares class `replicable_module`, derived publicly from `base_module`.
//! Its state_idx fprop/bprop/bbprop forward to a member
//! module_1_1_replicable<base_module> (`rep`) that wraps the object itself.
//! \param replicable_module  name of the class being declared
//! \param base_module        base class; like any macro argument it must not
//!                           contain unparenthesized commas
//! \param types_arguments    parenthesized constructor parameter list
//! \param arguments          parenthesized argument list passed to base_module
//! The constructor sets bResize to false and calls ylerror() when
//! replicable_order() <= 0. The expansion ends with '}', so the use site
//! supplies the terminating ';'.
#define DECLARE_REPLICABLE_MODULE_1_1(replicable_module, base_module,       \
                                      types_arguments, arguments)           \
  class replicable_module : public base_module {                            \
  public:                                                                   \
    module_1_1_replicable<base_module> rep;                                 \
                                                                            \
    replicable_module types_arguments                                       \
      : base_module arguments, rep(*this) {                                 \
      bResize = false;                                                      \
      if (replicable_order() <= 0)                                          \
        ylerror("this module is not replicable");                           \
    }                                                                       \
                                                                            \
    virtual ~replicable_module() {}                                         \
                                                                            \
    virtual void fprop(state_idx &in, state_idx &out) {                     \
      rep.fprop(in, out);                                                   \
    }                                                                       \
    virtual void bprop(state_idx &in, state_idx &out) {                     \
      rep.bprop(in, out);                                                   \
    }                                                                       \
    virtual void bbprop(state_idx &in, state_idx &out) {                    \
      rep.bbprop(in, out);                                                  \
    }                                                                       \
  }
00252
00253 }
00254
00255 #include "EblArch.hpp"
00256
00257 #endif