libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_machines.h
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *
00005  * Redistribution and use in source and binary forms, with or without
00006  * modification, are permitted provided that the following conditions are met:
00007  *     * Redistributions of source code must retain the above copyright
00008  *       notice, this list of conditions and the following disclaimer.
00009  *     * Redistributions in binary form must reproduce the above copyright
00010  *       notice, this list of conditions and the following disclaimer in the
00011  *       documentation and/or other materials provided with the distribution.
00012  *     * Redistribution under a license not approved by the Open Source
00013  *       Initiative (http://www.opensource.org) must display the
00014  *       following acknowledgement in all advertising material:
00015  *        This product includes software developed at the Courant
00016  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00017  *     * The names of the authors may not be used to endorse or promote products
00018  *       derived from this software without specific prior written permission.
00019  *
00020  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00021  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00022  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00023  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00024  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00025  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00026  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00027  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00028  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00029  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00030  ***************************************************************************/
00031 
00032 #ifndef EBL_MACHINES_H_
00033 #define EBL_MACHINES_H_
00034 
00035 #include "ebl_defines.h"
00036 #include "libidx.h"
00037 #include "ebl_states.h"
00038 #include "ebl_normalization.h"
00039 #include "ebl_basic.h"
00040 #include "ebl_cost.h"
00041 #include "ebl_arch.h"
00042 #include "ebl_nonlinearity.h"
00043 #include "ebl_layers.h"
00044 
00045 namespace ebl {
00046 
00049   template <typename T, class Tstate = bbstate_idx<T> >
00050     class net_cscscf : public layers<T,Tstate> {
00051   public:
00054     net_cscscf();
00057     net_cscscf(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0,
00058                idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00059                idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 
00060                idx<intg> &tbl2, intg outthick, bool norm = false,
00061                bool mirror = false, bool tanh = false,
00062                bool shrink = false, bool diag = false);
00063     virtual ~net_cscscf();
00064 
00079     void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 
00080               idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00081               idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 
00082               idx<intg> &tbl2, intg outthick, bool norm = false,
00083               bool mirror = false, bool tanh = false, bool shrink = false,
00084               bool diag = false);
00085   };
00086 
00089   template <typename T, class Tstate = bbstate_idx<T> >
00090     class net_cscf : public layers<T,Tstate> {
00091   public:
00094     net_cscf();
00097     net_cscf(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0,
00098              idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00099              idx<intg> &tbl1, intg outthick, bool norm = false,
00100              bool mirror = false, bool tanh = false,
00101              bool shrink = false, bool diag = false, bool lut_features = false,
00102              idx<T> *lut = NULL);
00103     virtual ~net_cscf();
00104 
00119     void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 
00120               idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00121               idx<intg> &tbl1, intg outthick, bool norm = false,
00122               bool mirror = false, bool tanh = false, bool shrink = false,
00123               bool diag = false, bool lut_features = false, idx<T> *lut = NULL);
00124   };
00125 
00128   template <typename T, class Tstate = bbstate_idx<T> >
00129     class net_cscc : public layers<T,Tstate> {
00130   public:
00133     net_cscc();
00136     net_cscc(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0,
00137              idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00138              idx<intg> &tbl1, idx<intg> &tbl2, intg outthick, bool norm = false,
00139              bool mirror = false, bool tanh = false,
00140              bool shrink = false, bool diag = false);
00141     virtual ~net_cscc();
00142 
00157     void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 
00158               idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00159               idx<intg> &tbl1, idx<intg> &tbl2, intg outthick, 
00160               bool norm = false,
00161               bool mirror = false, bool tanh = false, bool shrink = false,
00162               bool diag = false);
00163   };
00164 
00167   template <typename T, class Tstate = bbstate_idx<T> >
00168     class net_cscsc : public layers<T,Tstate> {
00169   public:
00172     net_cscsc();
00177     net_cscsc(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 
00178               idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00179               idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 
00180               idx<intg> &tbl2, bool norm = false,
00181               bool mirror = false, bool tanh = false,
00182               bool shrink = false, bool diag = false, bool norm_pos = false);
00183     virtual ~net_cscsc();
00184 
00201     void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 
00202               idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 
00203               idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 
00204               idx<intg> &tbl2, bool norm = false, bool mirror = false,
00205               bool tanh = false, bool shrink = false, bool diag = false,
00206               bool norm_pos = false);
00207   };
00208 
00215   template <typename T, class Tstate = bbstate_idx<T> >
00216     class lenet_cscsc : public net_cscsc<T,Tstate> {
00217   public:
00218     lenet_cscsc(parameter<T,Tstate> &prm, intg image_height, intg image_width,
00219                 intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1,
00220                 intg si1, intg sj1, intg output_size,
00221                 bool norm = false, bool color = false, bool mirror = false,
00222                 bool tanh = false, bool shrink = false, bool diag = false, 
00223                 bool norm_pos = false,
00224                 idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL,
00225                 idx<intg> *table2_ = NULL);
00226     virtual ~lenet_cscsc() {}
00227   };
00228 
00235   template <typename T, class Tstate = bbstate_idx<T> >
00236     class lenet : public net_cscscf<T,Tstate> {
00237   public:
00238     lenet(parameter<T,Tstate> &prm, intg image_height, intg image_width,
00239           intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1,
00240           intg si1, intg sj1, intg hid, intg output_size,
00241           bool norm = false, bool color = false, bool mirror = false,
00242           bool tanh = false, bool shrink = false, bool diag = false,
00243           idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL,
00244           idx<intg> *table2_ = NULL);
00245     virtual ~lenet() {}
00246   };
00247 
00254   template <typename T, class Tstate = bbstate_idx<T> >
00255     class lenet_cscf : public net_cscf<T,Tstate> {
00256   public:
00257     lenet_cscf(parameter<T,Tstate> &prm, intg image_height, intg image_width,
00258                intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1,
00259                intg output_size, bool norm = false, bool color = false,
00260                bool mirror = false, bool tanh = false, bool shrink = false, 
00261                bool diag = false,
00262                idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL);
00263     virtual ~lenet_cscf() {}
00264   };
00265 
00291   template <typename T, class Tstate = bbstate_idx<T> >
00292     class lenet5 : public net_cscscf<T,Tstate> {
00293   public:
00294     lenet5(parameter<T,Tstate> &prm, intg image_height, intg image_width,
00295            intg ki0, intg kj0, intg si0, intg sj0,
00296            intg ki1, intg kj1, intg si1, intg sj1,
00297            intg hid, intg output_size, bool norm = false, bool mirror = false,
00298            bool tanh = false, bool shrink = false, bool diag = false);
00299     virtual ~lenet5() {}
00300   };
00301 
00305   template <typename T, class Tstate = bbstate_idx<T> >
00306     class lenet7 : public net_cscscf<T,Tstate> {
00307   public:
00310     lenet7(parameter<T,Tstate> &prm, intg image_height, intg image_width,
00311            intg output_size, bool norm = false, bool mirror = false,
00312            bool tanh = false, bool shrink = false, bool diag = false);
00313     virtual ~lenet7() {}
00314   };
00315   
00320   template <typename T, class Tstate = bbstate_idx<T> >
00321     class lenet7_binocular : public net_cscscf<T,Tstate> {
00322   public:
00325     lenet7_binocular(parameter<T,Tstate> &prm, intg image_height,
00326                      intg image_width,
00327                      intg output_size, bool norm = false, bool mirror = false,
00328                      bool tanh = false, bool shrink = false, bool diag = false);
00329     virtual ~lenet7_binocular() {}
00330   };
00331   
00334   template <typename Tdata, class Tlabel, class Tstate = bbstate_idx<Tdata> >
00335     class supervised_euclidean_machine
00336     : public fc_ebm2<Tdata, Tstate, bbstate_idx<Tlabel>, Tstate> {
00337   public:
00338     euclidean_module<Tdata, Tlabel>     fcost;  // euclidean cost function
00339     Tstate                      fout;   // hidden state in between
00340 
00341     supervised_euclidean_machine(module_1_1<Tdata,Tstate> &net_,
00342                                  idx<Tdata> &targets, idxdim &dims);
00343     virtual ~supervised_euclidean_machine();
00344   };
00345   
00347   // some connection tables
00348 
00350   static intg connection_table_6_16[60][2] =
00351     {{0, 0},  {1, 0},  {2, 0},
00352      {1, 1},  {2, 1},  {3, 1},
00353      {2, 2},  {3, 2},  {4, 2},
00354      {3, 3},  {4, 3},  {5, 3},
00355      {4, 4},  {5, 4},  {0, 4},
00356      {5, 5},  {0, 5},  {1, 5},
00357      
00358      {0, 6},  {1, 6},  {2, 6},  {3, 6},
00359      {1, 7},  {2, 7},  {3, 7},  {4, 7},
00360      {2, 8},  {3, 8},  {4, 8},  {5, 8},
00361      {3, 9},  {4, 9},  {5, 9},  {0, 9},
00362      {4, 10}, {5, 10}, {0, 10}, {1, 10},
00363      {5, 11}, {0, 11}, {1, 11}, {2, 11},
00364      
00365      {0, 12}, {1, 12}, {3, 12}, {4, 12},
00366      {1, 13}, {2, 13}, {4, 13}, {5, 13},
00367      {2, 14}, {3, 14}, {5, 14}, {0, 14},
00368      
00369      {0, 15}, {1, 15}, {2, 15}, {3, 15}, {4, 15}, {5, 15}};
00370   
00372   static intg connection_table_8_24[96][2] =
00373     {{0,  0}, {2,  0}, {4,  0}, {5,  0},
00374      {0,  1}, {2,  1}, {4,  1}, {6,  1},
00375      {0,  2}, {2,  2}, {4,  2}, {7,  2},
00376      {0,  3}, {2,  3}, {5,  3}, {6,  3},
00377      {0,  4}, {2,  4}, {5,  4}, {7,  4},
00378      {0,  5}, {2,  5}, {6,  5}, {7,  5},
00379      {1,  6}, {3,  6}, {4,  6}, {5,  6},
00380      {1,  7}, {3,  7}, {4,  7}, {6,  7},
00381      {1,  8}, {3,  8}, {4,  8}, {7,  8},
00382      {1,  9}, {3,  9}, {5,  9}, {6,  9},
00383      {1, 10}, {3, 10}, {5, 10}, {7, 10},
00384      {1, 11}, {3, 11}, {6, 11}, {7, 11},
00385      {1, 12}, {2, 12}, {4, 12}, {5, 12},
00386      {1, 13}, {2, 13}, {4, 13}, {6, 13},
00387      {1, 14}, {2, 14}, {4, 14}, {7, 14},
00388      {1, 15}, {2, 15}, {5, 15}, {6, 15},
00389      {1, 16}, {2, 16}, {5, 16}, {7, 16},
00390      {1, 17}, {2, 17}, {6, 17}, {7, 17},
00391      {0, 18}, {3, 18}, {4, 18}, {5, 18},
00392      {0, 19}, {3, 19}, {4, 19}, {6, 19},
00393      {0, 20}, {3, 20}, {4, 20}, {7, 20},
00394      {0, 21}, {3, 21}, {5, 21}, {6, 21},
00395      {0, 22}, {3, 22}, {5, 22}, {7, 22},
00396      {0, 23}, {3, 23}, {6, 23}, {7, 23}};
00397 
00398 } // namespace ebl {
00399 
00400 #include "ebl_machines.hpp"
00401 
00402 #endif /* EBL_MACHINES_H_ */