libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * 00005 * Redistribution and use in source and binary forms, with or without 00006 * modification, are permitted provided that the following conditions are met: 00007 * * Redistributions of source code must retain the above copyright 00008 * notice, this list of conditions and the following disclaimer. 00009 * * Redistributions in binary form must reproduce the above copyright 00010 * notice, this list of conditions and the following disclaimer in the 00011 * documentation and/or other materials provided with the distribution. 00012 * * Redistribution under a license not approved by the Open Source 00013 * Initiative (http://www.opensource.org) must display the 00014 * following acknowledgement in all advertising material: 00015 * This product includes software developed at the Courant 00016 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00017 * * The names of the authors may not be used to endorse or promote products 00018 * derived from this software without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00021 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00022 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00023 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00024 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00025 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00026 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00027 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00028 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00029 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00030 ***************************************************************************/ 00031 00032 #ifndef EBL_MACHINES_H_ 00033 #define EBL_MACHINES_H_ 00034 00035 #include "ebl_defines.h" 00036 #include "libidx.h" 00037 #include "ebl_states.h" 00038 #include "ebl_normalization.h" 00039 #include "ebl_basic.h" 00040 #include "ebl_cost.h" 00041 #include "ebl_arch.h" 00042 #include "ebl_nonlinearity.h" 00043 #include "ebl_layers.h" 00044 00045 namespace ebl { 00046 00049 template <typename T, class Tstate = bbstate_idx<T> > 00050 class net_cscscf : public layers<T,Tstate> { 00051 public: 00054 net_cscscf(); 00057 net_cscscf(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00058 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00059 idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 00060 idx<intg> &tbl2, intg outthick, bool norm = false, 00061 bool mirror = false, bool tanh = false, 00062 bool shrink = false, bool diag = false); 00063 virtual ~net_cscscf(); 00064 00079 void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00080 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00081 idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 00082 idx<intg> &tbl2, intg outthick, bool norm = false, 00083 bool mirror = false, bool tanh = false, bool shrink = false, 00084 bool diag = false); 00085 }; 00086 00089 template <typename T, class Tstate = bbstate_idx<T> > 00090 class net_cscf : public layers<T,Tstate> { 00091 public: 00094 net_cscf(); 00097 net_cscf(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00098 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00099 idx<intg> &tbl1, intg outthick, bool norm = false, 00100 bool mirror = false, bool tanh = false, 00101 bool shrink = false, bool diag = false, bool lut_features = false, 00102 idx<T> *lut = NULL); 00103 virtual ~net_cscf(); 00104 00119 void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00120 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00121 idx<intg> &tbl1, intg outthick, bool norm = false, 00122 bool mirror = false, bool tanh = false, bool shrink = false, 00123 bool diag = false, bool lut_features = false, idx<T> *lut = NULL); 00124 }; 00125 00128 template <typename T, class Tstate = bbstate_idx<T> > 00129 class net_cscc : public layers<T,Tstate> { 00130 public: 00133 net_cscc(); 00136 net_cscc(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00137 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00138 idx<intg> &tbl1, idx<intg> &tbl2, intg outthick, bool norm = false, 00139 bool mirror = false, bool tanh = false, 00140 bool shrink = false, bool diag = false); 00141 virtual ~net_cscc(); 00142 00157 void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00158 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00159 idx<intg> &tbl1, idx<intg> &tbl2, intg outthick, 00160 bool norm = false, 00161 bool mirror = false, bool tanh = false, bool shrink = false, 00162 bool diag = false); 00163 }; 00164 00167 template <typename T, class Tstate = bbstate_idx<T> > 00168 class net_cscsc : public layers<T,Tstate> { 00169 public: 00172 net_cscsc(); 00177 net_cscsc(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00178 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00179 idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 00180 idx<intg> &tbl2, bool norm = false, 00181 bool mirror = false, bool tanh = false, 00182 bool shrink = false, bool diag = false, bool norm_pos = false); 00183 virtual ~net_cscsc(); 00184 00201 void init(parameter<T,Tstate> &prm, intg ini, intg inj, intg ki0, intg kj0, 00202 idx<intg> &tbl0, intg si0, intg sj0, intg ki1, intg kj1, 00203 idx<intg> &tbl1, intg si1, intg sj1, intg ki2, intg kj2, 00204 idx<intg> &tbl2, bool norm = false, bool mirror = false, 00205 bool tanh = false, bool shrink = false, bool diag = false, 00206 bool norm_pos = false); 00207 }; 00208 00215 template <typename T, class Tstate = bbstate_idx<T> > 00216 class lenet_cscsc : public net_cscsc<T,Tstate> { 00217 public: 00218 lenet_cscsc(parameter<T,Tstate> &prm, intg image_height, intg image_width, 00219 intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1, 00220 intg si1, intg sj1, intg output_size, 00221 bool norm = false, bool color = false, bool mirror = false, 00222 bool tanh = false, bool shrink = false, bool diag = false, 00223 bool norm_pos = false, 00224 idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL, 00225 idx<intg> *table2_ = NULL); 00226 virtual ~lenet_cscsc() {} 00227 }; 00228 00235 template <typename T, class Tstate = bbstate_idx<T> > 00236 class lenet : public net_cscscf<T,Tstate> { 00237 public: 00238 lenet(parameter<T,Tstate> &prm, intg image_height, intg image_width, 00239 intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1, 00240 intg si1, intg sj1, intg hid, intg output_size, 00241 bool norm = false, bool color = false, bool mirror = false, 00242 bool tanh = false, bool shrink = false, bool diag = false, 00243 idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL, 00244 idx<intg> *table2_ = NULL); 00245 virtual ~lenet() {} 00246 }; 00247 00254 template <typename T, class Tstate = bbstate_idx<T> > 00255 class lenet_cscf : public net_cscf<T,Tstate> { 00256 public: 00257 lenet_cscf(parameter<T,Tstate> &prm, intg image_height, intg image_width, 00258 intg ki0, intg kj0, intg si0, intg sj0, intg ki1, intg kj1, 00259 intg output_size, bool norm = false, bool color = false, 00260 bool mirror = false, bool tanh = false, bool shrink = false, 00261 bool diag = false, 00262 idx<intg> *table0_ = NULL, idx<intg> *table1_ = NULL); 00263 virtual ~lenet_cscf() {} 00264 }; 00265 00291 template <typename T, class Tstate = bbstate_idx<T> > 00292 class lenet5 : public net_cscscf<T,Tstate> { 00293 public: 00294 lenet5(parameter<T,Tstate> &prm, intg image_height, intg image_width, 00295 intg ki0, intg kj0, intg si0, intg sj0, 00296 intg ki1, intg kj1, intg si1, intg sj1, 00297 intg hid, intg output_size, bool norm = false, bool mirror = false, 00298 bool tanh = false, bool shrink = false, bool diag = false); 00299 virtual ~lenet5() {} 00300 }; 00301 00305 template <typename T, class Tstate = bbstate_idx<T> > 00306 class lenet7 : public net_cscscf<T,Tstate> { 00307 public: 00310 lenet7(parameter<T,Tstate> &prm, intg image_height, intg image_width, 00311 intg output_size, bool norm = false, bool mirror = false, 00312 bool tanh = false, bool shrink = false, bool diag = false); 00313 virtual ~lenet7() {} 00314 }; 00315 00320 template <typename T, class Tstate = bbstate_idx<T> > 00321 class lenet7_binocular : public net_cscscf<T,Tstate> { 00322 public: 00325 lenet7_binocular(parameter<T,Tstate> &prm, intg image_height, 00326 intg image_width, 00327 intg output_size, bool norm = false, bool mirror = false, 00328 bool tanh = false, bool shrink = false, bool diag = false); 00329 virtual ~lenet7_binocular() {} 00330 }; 00331 00334 template <typename Tdata, class Tlabel, class Tstate = bbstate_idx<Tdata> > 00335 class supervised_euclidean_machine 00336 : public fc_ebm2<Tdata, Tstate, bbstate_idx<Tlabel>, Tstate> { 00337 public: 00338 euclidean_module<Tdata, Tlabel> fcost; // euclidean cost function 00339 Tstate fout; // hidden state in between 00340 00341 supervised_euclidean_machine(module_1_1<Tdata,Tstate> &net_, 00342 idx<Tdata> &targets, idxdim &dims); 00343 virtual ~supervised_euclidean_machine(); 00344 }; 00345 00347 // some connection tables 00348 00350 static intg connection_table_6_16[60][2] = 00351 {{0, 0}, {1, 0}, {2, 0}, 00352 {1, 1}, {2, 1}, {3, 1}, 00353 {2, 2}, {3, 2}, {4, 2}, 00354 {3, 3}, {4, 3}, {5, 3}, 00355 {4, 4}, {5, 4}, {0, 4}, 00356 {5, 5}, {0, 5}, {1, 5}, 00357 00358 {0, 6}, {1, 6}, {2, 6}, {3, 6}, 00359 {1, 7}, {2, 7}, {3, 7}, {4, 7}, 00360 {2, 8}, {3, 8}, {4, 8}, {5, 8}, 00361 {3, 9}, {4, 9}, {5, 9}, {0, 9}, 00362 {4, 10}, {5, 10}, {0, 10}, {1, 10}, 00363 {5, 11}, {0, 11}, {1, 11}, {2, 11}, 00364 00365 {0, 12}, {1, 12}, {3, 12}, {4, 12}, 00366 {1, 13}, {2, 13}, {4, 13}, {5, 13}, 00367 {2, 14}, {3, 14}, {5, 14}, {0, 14}, 00368 00369 {0, 15}, {1, 15}, {2, 15}, {3, 15}, {4, 15}, {5, 15}}; 00370 00372 static intg connection_table_8_24[96][2] = 00373 {{0, 0}, {2, 0}, {4, 0}, {5, 0}, 00374 {0, 1}, {2, 1}, {4, 1}, {6, 1}, 00375 {0, 2}, {2, 2}, {4, 2}, {7, 2}, 00376 {0, 3}, {2, 3}, {5, 3}, {6, 3}, 00377 {0, 4}, {2, 4}, {5, 4}, {7, 4}, 00378 {0, 5}, {2, 5}, {6, 5}, {7, 5}, 00379 {1, 6}, {3, 6}, {4, 6}, {5, 6}, 00380 {1, 7}, {3, 7}, {4, 7}, {6, 7}, 00381 {1, 8}, {3, 8}, {4, 8}, {7, 8}, 00382 {1, 9}, {3, 9}, {5, 9}, {6, 9}, 00383 {1, 10}, {3, 10}, {5, 10}, {7, 10}, 00384 {1, 11}, {3, 11}, {6, 11}, {7, 11}, 00385 {1, 12}, {2, 12}, {4, 12}, {5, 12}, 00386 {1, 13}, {2, 13}, {4, 13}, {6, 13}, 00387 {1, 14}, {2, 14}, {4, 14}, {7, 14}, 00388 {1, 15}, {2, 15}, {5, 15}, {6, 15}, 00389 {1, 16}, {2, 16}, {5, 16}, {7, 16}, 00390 {1, 17}, {2, 17}, {6, 17}, {7, 17}, 00391 {0, 18}, {3, 18}, {4, 18}, {5, 18}, 00392 {0, 19}, {3, 19}, {4, 19}, {6, 19}, 00393 {0, 20}, {3, 20}, {4, 20}, {7, 20}, 00394 {0, 21}, {3, 21}, {5, 21}, {6, 21}, 00395 {0, 22}, {3, 22}, {5, 22}, {7, 22}, 00396 {0, 23}, {3, 23}, {6, 23}, {7, 23}}; 00397 00398 } // namespace ebl { 00399 00400 #include "ebl_machines.hpp" 00401 00402 #endif /* EBL_MACHINES_H_ */