libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Yann LeCun, Pierre Sermanet and Soumith Chintala* 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com, soumith@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_BASIC_H_ 00034 #define EBL_BASIC_H_ 00035 00036 #include "ebl_defines.h" 00037 #include "libidx.h" 00038 #include "ebl_arch.h" 00039 #include "ebl_states.h" 00040 #include "ebl_utils.h" 00041 #include "bbox.h" 00042 00043 namespace ebl { 00044 00046 // linear_module 00051 template <typename T, class Tstate = bbstate_idx<T> > 00052 class linear_module: public module_1_1<T, Tstate> { 00053 public: 00059 linear_module(parameter<T,Tstate> *p, intg in, intg out, 00060 const char *name = "linear"); 00062 virtual ~linear_module(); 00064 virtual void fprop(Tstate &in, Tstate &out); 00066 virtual void bprop(Tstate &in, Tstate &out); 00068 virtual void bbprop(Tstate &in, Tstate &out); 00070 virtual int replicable_order() { return 1; } 00072 virtual void forget(forget_param_linear &fp); 00074 virtual void normalize(); 00077 virtual fidxdim fprop_size(fidxdim &i_size); 00080 virtual fidxdim bprop_size(const fidxdim &o_size); 00084 virtual linear_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00086 virtual void load_x(idx<T> &weights); 00088 virtual std::string describe(); 00091 virtual void dump_fprop(Tstate &in, Tstate &out); 00092 00093 /* bool resize_output(Tstate &in, Tstate &out); */ 00094 00095 00096 // members //////////////////////////////////////////////////////// 00097 public: 00098 Tstate w; 00099 }; 00100 00109 DECLARE_REPLICABLE_MODULE_1_1(linear_module_replicable, 00110 linear_module, T, Tstate, 00111 (parameter<T,Tstate> *p, intg in, intg out, 00112 const char *name = "linear_replicable"), 00113 (p, in, out, name)); 00114 00116 // convolution_module 00122 template <typename T, class Tstate = bbstate_idx<T> > 00123 class convolution_module : public module_1_1<T, Tstate> { 00124 public: 00133 convolution_module(parameter<T,Tstate> *p, idxdim &ker, idxdim &stride, 00134 idx<intg> &table, const char *name = "convolution", 00135 bool crop = true); 00137 virtual ~convolution_module(); 00139 virtual void fprop(Tstate &in, Tstate &out); 00141 virtual void bprop(Tstate &in, Tstate &out); 00143 virtual void bbprop(Tstate &in, Tstate &out); 00145 virtual void forget(forget_param_linear &fp); 00147 virtual int replicable_order() { return 3; } 00150 virtual bool resize_output(Tstate &in, Tstate &out); 00153 virtual fidxdim fprop_size(fidxdim &i_size); 00156 virtual fidxdim bprop_size(const fidxdim &o_size); 00160 virtual convolution_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00162 virtual void load_x(idx<T> &weights); 00164 virtual std::string describe(); 00167 virtual void dump_fprop(Tstate &in, Tstate &out); 00168 00169 // members //////////////////////////////////////////////////////// 00170 public: 00171 intg tablemax; 00172 Tstate kernel; 00173 intg thickness; 00174 idxdim ker; 00175 idxdim stride; 00176 idx<intg> table; 00177 protected: 00178 bool warnings_shown; 00179 bool fulltable; 00180 bool float_precision; 00181 bool double_precision; 00182 bool crop; 00183 // IPP members //////////////////////////////////////////////////////// 00184 idx<T> revkernel; 00185 idx<T> outtmp; 00186 bool ipp_err_printed; 00187 bool use_ipp; 00188 }; 00189 00198 DECLARE_REPLICABLE_MODULE_1_1(convolution_module_replicable, 00199 convolution_module, T, Tstate, 00200 (parameter<T,Tstate> *p, 00201 idxdim &ker, idxdim &stride, idx<intg> &table, 00202 const char *name = "convolution_replicable"), 00203 (p, ker, stride, table, name)); 00204 00206 // addc_module 00211 template <typename T, class Tstate = bbstate_idx<T> > 00212 class addc_module: public module_1_1<T, Tstate> { 00213 public: 00219 addc_module(parameter<T,Tstate> *p, intg size, const char *name = "addc"); 00221 virtual ~addc_module(); 00223 virtual void fprop(Tstate &in, Tstate &out); 00225 virtual void bprop(Tstate &in, Tstate &out); 00227 virtual void bbprop(Tstate &in, Tstate &out); 00229 virtual void forget(forget_param_linear &fp); 00233 virtual addc_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00235 virtual void load_x(idx<T> &weights); 00237 virtual std::string describe(); 00240 virtual void dump_fprop(Tstate &in, Tstate &out); 00241 00242 // members //////////////////////////////////////////////////////// 00243 public: 00244 Tstate bias; 00245 }; 00246 00248 // power_module 00255 // TODO: write specialized modules square and sqrt to run faster 00256 template <typename T, class Tstate = bbstate_idx<T> > 00257 class power_module : public module_1_1<T,Tstate> { 00258 public: 00261 power_module(T p); 00263 virtual ~power_module(); 00265 virtual void fprop(Tstate &in, Tstate &out); 00267 virtual void bprop(Tstate &in, Tstate &out); 00269 virtual void bbprop(Tstate &in, Tstate &out); 00270 00271 // members //////////////////////////////////////////////////////// 00272 private: 00273 T p; 00274 idx<T> tt; 00275 }; 00276 00278 // diff_module 00281 template <typename T, class Tstate = bbstate_idx<T> > 00282 class diff_module : public module_2_1<T, Tstate> { 00283 public: 00285 diff_module(); 00287 virtual ~diff_module(); 00289 virtual void fprop(Tstate &in1, Tstate &in2, Tstate &out); 00291 virtual void bprop(Tstate &in1, Tstate &in2, Tstate &out); 00293 virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &out); 00294 }; 00295 00297 // mul_module 00300 template <typename T, class Tstate = bbstate_idx<T> > 00301 class mul_module : public module_2_1<T, Tstate> { 00302 private: 00303 idx<T> tmp; 00304 00305 public: 00307 mul_module(); 00309 virtual ~mul_module(); 00311 virtual void fprop(Tstate &in1, Tstate &in2, Tstate &out); 00313 virtual void bprop(Tstate &in1, Tstate &in2, Tstate &out); 00315 virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &out); 00316 }; 00317 00319 // thres_module 00322 template <typename T, class Tstate = bbstate_idx<T> > 00323 class thres_module : public module_1_1<T,Tstate> { 00324 public: 00325 T thres; 00326 T val; 00327 00328 public: 00333 thres_module(T thres, T val); 00335 virtual ~thres_module(); 00337 virtual void fprop(Tstate &in, Tstate &out); 00339 virtual void bprop(Tstate &in, Tstate &out); 00341 virtual void bbprop(Tstate &in, Tstate &out); 00342 }; 00343 00344 00346 // cutborder_module 00350 template <typename T, class Tstate = bbstate_idx<T> > 00351 class cutborder_module : module_1_1<T,Tstate> { 00352 private: 00353 int nrow, ncol; 00354 00355 public: 00360 cutborder_module(int nr, int nc); 00362 virtual ~cutborder_module(); 00364 virtual void fprop(Tstate &in, Tstate &out); 00366 virtual void bprop(Tstate &in, Tstate &out); 00368 virtual void bbprop(Tstate &in, Tstate &out); 00369 }; 00370 00372 // zpad_module 00375 template <typename T, class Tstate = bbstate_idx<T> > 00376 class zpad_module : public module_1_1<T,Tstate> { 00377 public: 00380 zpad_module(const char *name = "zpad"); 00386 zpad_module(int nrows, int ncolumns); 00393 zpad_module(int top, int left, int bottom, int right); 00398 zpad_module(idxdim &kernel_size, const char *name = "zpad"); 00403 zpad_module(midxdim &kernels, const char *name = "zpad"); 00405 virtual ~zpad_module(); 00406 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00408 virtual void fprop(Tstate &in, Tstate &out); 00410 virtual void fprop(Tstate &in, idx<T> &out); 00412 virtual void fprop(idx<T> &in, idx<T> &out); 00414 virtual void bprop(Tstate &in, Tstate &out); 00416 virtual void bbprop(Tstate &in, Tstate &out); 00418 virtual idxdim get_paddings(); 00421 virtual idxdim get_paddings(idxdim &kernel); 00424 virtual midxdim get_paddings(midxdim &kernels); 00426 virtual void set_paddings(int top, int left, int bottom, int right); 00428 virtual void set_paddings(idxdim &pads); 00430 virtual void set_kernel(idxdim &kernel); 00432 virtual void set_kernels(midxdim &kernels); 00435 virtual fidxdim fprop_size(fidxdim &i_size); 00438 virtual fidxdim bprop_size(const fidxdim &o_size); 00440 virtual mfidxdim fprop_size(mfidxdim &isize); 00442 virtual mfidxdim bprop_size(mfidxdim &osize); 00444 virtual std::string describe(); 00448 virtual zpad_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00449 00450 protected: 00451 idxdim pad; 00452 midxdim pads; 00453 }; 00454 00456 // mirrorpad_module 00459 template <typename T, class Tstate = bbstate_idx<T> > 00460 class mirrorpad_module : public zpad_module<T,Tstate> { 00461 public: 00466 mirrorpad_module(int nr, int nc); 00471 mirrorpad_module(idxdim &kernel_size); 00473 virtual ~mirrorpad_module(); 00475 virtual void fprop(Tstate &in, Tstate &out); 00477 virtual void fprop(Tstate &in, idx<T> &out); 00481 virtual mirrorpad_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00482 protected: 00483 using zpad_module<T,Tstate>::pad; 00484 }; 00485 00487 // fsum_module 00490 template <typename T, class Tstate = bbstate_idx<T> > 00491 class fsum_module : public module_1_1<T,Tstate> { 00492 public: 00497 fsum_module(bool div = false, float split = 1.0); 00499 virtual ~fsum_module(); 00501 virtual void fprop(Tstate &in, Tstate &out); 00503 virtual void bprop(Tstate &in, Tstate &out); 00505 virtual void bbprop(Tstate &in, Tstate &out); 00506 protected: 00507 bool div; 00508 float split; 00509 }; 00510 00512 // range_lut_module 00515 template <typename T, class Tstate = bbstate_idx<T> > 00516 class range_lut_module : public module_1_1<T,Tstate> { 00517 public: 00526 range_lut_module(idx<T> *value_range); 00528 virtual ~range_lut_module(); 00530 virtual void fprop(Tstate &in, Tstate &out); 00531 /* //! backward propagation from out to in */ 00532 /* virtual void bprop(Tstate &in, Tstate &out); */ 00533 /* //! second-derivative backward propagation from out to in */ 00534 /* virtual void bbprop(Tstate &in, Tstate &out); */ 00535 protected: 00536 idx<T> value_range; 00537 }; 00538 00540 // binarize_module 00543 template <typename T, class Tstate = bbstate_idx<T> > 00544 class binarize_module : public module_1_1<T,Tstate> { 00545 public: 00547 binarize_module(T threshold, T false_value, T true_value); 00549 virtual ~binarize_module(); 00551 virtual void fprop(Tstate &in, Tstate &out); 00552 /* //! backward propagation from out to in */ 00553 /* virtual void bprop(Tstate &in, Tstate &out); */ 00554 /* //! second-derivative backward propagation from out to in */ 00555 /* virtual void bbprop(Tstate &in, Tstate &out); */ 00556 protected: 00557 T threshold; 00558 T false_value; 00559 T true_value; 00560 }; 00561 00563 // diag_module 00565 template <typename T, class Tstate = bbstate_idx<T> > 00566 class diag_module : public module_1_1<T,Tstate> { 00567 public: 00572 diag_module(parameter<T,Tstate> *p, intg thickness, 00573 const char *name = "diag"); 00575 virtual ~diag_module(); 00577 virtual void fprop(Tstate &in, Tstate &out); 00579 virtual void bprop(Tstate &in, Tstate &out); 00581 virtual void bbprop(Tstate &in, Tstate &out); 00584 virtual bool resize_output(Tstate &in, Tstate &out); 00586 virtual void load_x(idx<T> &weights); 00588 virtual std::string describe(); 00592 virtual diag_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00593 protected: 00594 Tstate coeff; 00595 }; 00596 00598 // copy_module 00601 template <typename T, class Tstate = bbstate_idx<T> > 00602 class copy_module : public module_1_1<T,Tstate> { 00603 public: 00605 copy_module(const char *name = "copy"); 00607 virtual ~copy_module(); 00609 virtual void fprop(Tstate &in, Tstate &out); 00611 virtual void bprop(Tstate &in, Tstate &out); 00613 virtual void bbprop(Tstate &in, Tstate &out); 00615 virtual std::string describe(); 00616 }; 00617 00619 // back_module 00620 template <typename T, class Tstate = bbstate_idx<T> > 00621 class back_module : public module_1_1<T,Tstate> { 00622 public: 00624 back_module(const char *name = "back"); 00626 virtual ~back_module(); 00628 virtual void fprop(Tstate &in, Tstate &out); 00631 virtual bool resize_output(Tstate &in, Tstate &out); 00633 virtual std::string describe(); 00638 virtual fidxdim bprop_size(const fidxdim &o_size); 00640 void bb(std::vector<bbox*> &boxes); 00641 00642 protected: 00643 idx<T> *s0; 00644 idx<T> *s1; 00645 idx<T> *s2; 00646 idxdim pixel_size; 00647 }; 00648 00650 // printer_module 00653 template <typename T, class Tstate = bbstate_idx<T> > 00654 class printer_module : module_1_1<T,Tstate> { 00655 00656 public: 00657 printer_module(const char *name = "printer"); 00659 virtual ~printer_module(); 00661 virtual void fprop(Tstate &in, Tstate &out); 00663 virtual void bprop(Tstate &in, Tstate &out); 00665 virtual void bbprop(Tstate &in, Tstate &out); 00666 }; 00667 00668 00669 00670 } // namespace ebl { 00671 00672 #include "ebl_basic.hpp" 00673 00674 #endif /* EBL_BASIC_H_ */