libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_basic.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Yann LeCun, Pierre Sermanet and Soumith Chintala*
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com, soumith@gmail.com  *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_BASIC_H_
00034 #define EBL_BASIC_H_
00035 
00036 #include "ebl_defines.h"
00037 #include "libidx.h"
00038 #include "ebl_arch.h"
00039 #include "ebl_states.h"
00040 #include "ebl_utils.h"
00041 #include "bbox.h"
00042 
00043 namespace ebl {
00044 
00046   // linear_module: module_1_1 subclass built from a parameter and two sizes; its only data member is the public state 'w'.
00051   template <typename T, class Tstate = bbstate_idx<T> >
00052     class linear_module: public module_1_1<T, Tstate> {
00053   public:
00059     linear_module(parameter<T,Tstate> *p, intg in, intg out, // NOTE(review): 'in'/'out' presumably size 'w' allocated in 'p' -- confirm in the implementation
00060                   const char *name = "linear"); // name defaults to "linear"
00062     virtual ~linear_module();
00064     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00066     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00068     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00070     virtual int replicable_order() { return 1; } // always returns 1 for this module
00072     virtual void forget(forget_param_linear &fp); // NOTE(review): presumably re-initializes 'w' according to fp -- confirm
00074     virtual void normalize();
00077     virtual fidxdim fprop_size(fidxdim &i_size); // takes the input size by non-const reference, returns a size
00080     virtual fidxdim bprop_size(const fidxdim &o_size); // takes the output size by const reference, returns a size
00084     virtual linear_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns a linear_module pointer; 'p' is optional (NULL by default)
00086     virtual void load_x(idx<T> &weights); // NOTE(review): presumably loads 'weights' into 'w' -- confirm
00088     virtual std::string describe(); // returns a description string
00091     virtual void dump_fprop(Tstate &in, Tstate &out);
00092 
00093   /* bool resize_output(Tstate &in, Tstate &out); */
00094 
00095 
00096     // members ////////////////////////////////////////////////////////
00097   public:
00098     Tstate w; // NOTE(review): presumably the weight state -- confirm
00099   };
00100 
00109   DECLARE_REPLICABLE_MODULE_1_1(linear_module_replicable, // NOTE(review): macro is defined outside this file; presumably declares linear_module_replicable -- confirm
00110                                 linear_module, T, Tstate, // the module class named here is linear_module, with its template parameters
00111                                 (parameter<T,Tstate> *p, intg in, intg out,
00112                                  const char *name = "linear_replicable"), // same parameter list as linear_module's constructor, but a different default name
00113                                 (p, in, out, name)); // NOTE(review): presumably the arguments forwarded to linear_module's constructor -- confirm against the macro
00114 
00116   // convolution_module: module_1_1 subclass built from a parameter, kernel dims, stride dims and an intg table.
00122   template <typename T, class Tstate = bbstate_idx<T> >
00123     class convolution_module : public module_1_1<T, Tstate> {
00124   public:
00133     convolution_module(parameter<T,Tstate> *p, idxdim &ker, idxdim &stride,
00134                        idx<intg> &table, const char *name = "convolution", // NOTE(review): 'table' presumably lists input/output connections -- confirm
00135                        bool crop = true); // 'crop' defaults to true
00137     virtual ~convolution_module();
00139     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00141     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00143     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00145     virtual void forget(forget_param_linear &fp); // NOTE(review): presumably re-initializes 'kernel' according to fp -- confirm
00147     virtual int replicable_order() { return 3; } // always returns 3 for this module
00150     virtual bool resize_output(Tstate &in, Tstate &out); // NOTE(review): meaning of the bool result is not visible here -- confirm
00153     virtual fidxdim fprop_size(fidxdim &i_size); // takes the input size by non-const reference, returns a size
00156     virtual fidxdim bprop_size(const fidxdim &o_size); // takes the output size by const reference, returns a size
00160     virtual convolution_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns a convolution_module pointer; 'p' is optional (NULL by default)
00162     virtual void load_x(idx<T> &weights); // NOTE(review): presumably loads 'weights' into 'kernel' -- confirm
00164     virtual std::string describe(); // returns a description string
00167     virtual void dump_fprop(Tstate &in, Tstate &out);
00168 
00169     // members ////////////////////////////////////////////////////////
00170   public:
00171     intg                tablemax; // NOTE(review): presumably derived from 'table' -- confirm
00172     Tstate              kernel;
00173     intg                thickness;
00174     idxdim              ker; // same type and name as the constructor's 'ker' argument
00175     idxdim              stride; // same type and name as the constructor's 'stride' argument
00176     idx<intg>           table;  // same type and name as the constructor's 'table' argument
00177   protected:
00178     bool                warnings_shown;
00179     bool                fulltable; 
00180     bool                float_precision; 
00181     bool                double_precision; 
00182     bool                crop; // same name as the constructor's 'crop' argument
00183   // IPP members ////////////////////////////////////////////////////////
00184     idx<T>              revkernel; 
00185     idx<T>              outtmp; 
00186     bool                ipp_err_printed; 
00187     bool                use_ipp; 
00188   };
00189 
00198   DECLARE_REPLICABLE_MODULE_1_1(convolution_module_replicable, // NOTE(review): macro is defined outside this file; presumably declares convolution_module_replicable -- confirm
00199                                 convolution_module, T, Tstate, // the module class named here is convolution_module, with its template parameters
00200                                 (parameter<T,Tstate> *p,
00201                                  idxdim &ker, idxdim &stride, idx<intg> &table,
00202                                  const char *name = "convolution_replicable"), // unlike convolution_module's constructor, this list has no 'crop' parameter
00203                                 (p, ker, stride, table, name)); // NOTE(review): presumably forwarded to convolution_module's constructor, leaving 'crop' at its default -- confirm
00204 
00206   // addc_module: module_1_1 subclass built from a parameter and a size; its only data member is the public state 'bias'.
00211   template <typename T, class Tstate = bbstate_idx<T> >
00212     class addc_module: public module_1_1<T, Tstate> {
00213   public:
00219     addc_module(parameter<T,Tstate> *p, intg size, const char *name = "addc"); // NOTE(review): 'size' presumably sizes 'bias' -- confirm
00221     virtual ~addc_module();
00223     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00225     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00227     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00229     virtual void forget(forget_param_linear &fp); // NOTE(review): presumably re-initializes 'bias' according to fp -- confirm
00233     virtual addc_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns an addc_module pointer; 'p' is optional (NULL by default)
00235     virtual void load_x(idx<T> &weights); // NOTE(review): presumably loads 'weights' into 'bias' -- confirm
00237     virtual std::string describe(); // returns a description string
00240     virtual void dump_fprop(Tstate &in, Tstate &out);
00241 
00242     // members ////////////////////////////////////////////////////////
00243   public:
00244     Tstate  bias; 
00245   };
00246 
00248   // power_module: module_1_1 subclass constructed from a single value 'p' of type T, stored privately.
00255   // TODO: write specialized modules square and sqrt to run faster
00256   template <typename T, class Tstate = bbstate_idx<T> >
00257     class power_module : public module_1_1<T,Tstate> {
00258   public:
00261     power_module(T p); // NOTE(review): 'p' is presumably the exponent -- confirm
00263     virtual ~power_module();
00265     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00267     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00269     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00270 
00271     // members ////////////////////////////////////////////////////////
00272   private:
00273     T p; // same name and type as the constructor argument
00274     idx<T> tt; // NOTE(review): presumably a temporary buffer -- confirm
00275   };
00276 
00278   // diff_module: module_2_1 subclass (two input states, one output state) with no data members of its own.
00281   template <typename T, class Tstate = bbstate_idx<T> >
00282     class diff_module : public module_2_1<T, Tstate> {
00283   public:
00285     diff_module();
00287     virtual ~diff_module();
00289     virtual void fprop(Tstate &in1, Tstate &in2, Tstate &out); // forward propagation from in1 and in2 to out
00291     virtual void bprop(Tstate &in1, Tstate &in2, Tstate &out); // backward propagation from out to in1 and in2
00293     virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &out); // second-derivative backward propagation from out to in1 and in2
00294   };
00295 
00297   // mul_module: module_2_1 subclass (two input states, one output state) holding one private idx<T> buffer.
00300   template <typename T, class Tstate = bbstate_idx<T> >
00301     class mul_module : public module_2_1<T, Tstate> {
00302   private:
00303     idx<T> tmp; // NOTE(review): presumably a temporary buffer -- confirm
00304 
00305   public:
00307     mul_module();
00309     virtual ~mul_module();
00311     virtual void fprop(Tstate &in1, Tstate &in2, Tstate &out); // forward propagation from in1 and in2 to out
00313     virtual void bprop(Tstate &in1, Tstate &in2, Tstate &out); // backward propagation from out to in1 and in2
00315     virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &out); // second-derivative backward propagation from out to in1 and in2
00316   };
00317 
00319   // thres_module: module_1_1 subclass configured by two public values of type T, 'thres' and 'val'.
00322   template <typename T, class Tstate = bbstate_idx<T> >
00323     class thres_module : public module_1_1<T,Tstate> {
00324   public:
00325     T thres; // NOTE(review): presumably the threshold -- confirm
00326     T val; // NOTE(review): presumably the value used relative to the threshold -- confirm
00327 
00328   public:
00333     thres_module(T thres, T val); // arguments share the names of the two public members
00335     virtual ~thres_module();
00337     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00339     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00341     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00342   };
00343 
00344 
00346   // cutborder_module: module_1_1 subclass configured by two private ints, 'nrow' and 'ncol'.
00350   template <typename T, class Tstate = bbstate_idx<T> >
00351     class cutborder_module : public module_1_1<T,Tstate> { // public base, like every other module in this header, so instances convert to module_1_1<T,Tstate>*
00352   private:
00353     int nrow, ncol; // NOTE(review): presumably the border sizes set from the constructor's nr/nc -- confirm
00354 
00355   public:
00360     cutborder_module(int nr, int nc);
00362     virtual ~cutborder_module();
00364     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00366     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00368     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00369   };
00370 
00372   // zpad_module: module_1_1 subclass holding padding sizes in 'pad' (idxdim) and 'pads' (midxdim); NOTE(review): presumably zero-padding -- confirm.
00375   template <typename T, class Tstate = bbstate_idx<T> >
00376     class zpad_module : public module_1_1<T,Tstate> {
00377   public:
00380     zpad_module(const char *name = "zpad"); // constructor taking only a name
00386     zpad_module(int nrows, int ncolumns); // constructor taking two ints
00393     zpad_module(int top, int left, int bottom, int right); // constructor taking four ints, one per side
00398     zpad_module(idxdim &kernel_size, const char *name = "zpad"); // constructor taking a single kernel size
00403     zpad_module(midxdim &kernels, const char *name = "zpad"); // constructor taking several kernel sizes
00405     virtual ~zpad_module();
00406     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); // forward propagation overload on mstate inputs/outputs
00408     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00410     virtual void fprop(Tstate &in, idx<T> &out); // forward propagation overload writing into a plain idx
00412     virtual void fprop(idx<T> &in, idx<T> &out); // forward propagation overload on plain idx input and output
00414     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00416     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00418     virtual idxdim get_paddings(); // returns an idxdim of paddings
00421     virtual idxdim get_paddings(idxdim &kernel); // overload taking one kernel size
00424     virtual midxdim get_paddings(midxdim &kernels); // overload taking several kernel sizes, returning a midxdim
00426     virtual void set_paddings(int top, int left, int bottom, int right); // set paddings from four ints, one per side
00428     virtual void set_paddings(idxdim &pads); // set paddings from an idxdim
00430     virtual void set_kernel(idxdim &kernel); // NOTE(review): presumably derives paddings from one kernel size -- confirm
00432     virtual void set_kernels(midxdim &kernels); // NOTE(review): presumably derives paddings from several kernel sizes -- confirm
00435     virtual fidxdim fprop_size(fidxdim &i_size); // takes the input size by non-const reference, returns a size
00438     virtual fidxdim bprop_size(const fidxdim &o_size); // takes the output size by const reference, returns a size
00440     virtual mfidxdim fprop_size(mfidxdim &isize); // multi-size overload of fprop_size
00442     virtual mfidxdim bprop_size(mfidxdim &osize); // multi-size overload of bprop_size
00444     virtual std::string describe(); // returns a description string
00448     virtual zpad_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns a zpad_module pointer; 'p' is optional (NULL by default)
00449 
00450   protected:
00451     idxdim pad; // protected so subclasses can access it (mirrorpad_module re-exposes it with a using-declaration)
00452     midxdim pads; 
00453   };
00454 
00456   // mirrorpad_module: zpad_module subclass that redeclares two fprop overloads and copy(); NOTE(review): presumably pads by mirroring -- confirm.
00459   template <typename T, class Tstate = bbstate_idx<T> >
00460     class mirrorpad_module : public zpad_module<T,Tstate> {
00461   public:
00466     mirrorpad_module(int nr, int nc); // constructor taking two ints
00471     mirrorpad_module(idxdim &kernel_size); // constructor taking a kernel size
00473     virtual ~mirrorpad_module();
00475     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00477     virtual void fprop(Tstate &in, idx<T> &out); // forward propagation overload writing into a plain idx
00481     virtual mirrorpad_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns a mirrorpad_module pointer; 'p' is optional (NULL by default)
00482   protected:
00483     using zpad_module<T,Tstate>::pad; // brings the dependent base member 'pad' into scope so it can be named unqualified
00484   };
00485 
00487   // fsum_module: module_1_1 subclass configured by a bool 'div' and a float 'split', both stored as protected members.
00490   template <typename T, class Tstate = bbstate_idx<T> >
00491     class fsum_module : public module_1_1<T,Tstate> {
00492   public:
00497     fsum_module(bool div = false, float split = 1.0); // defaults: div = false, split = 1.0
00499     virtual ~fsum_module();
00501     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00503     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00505     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00506   protected:
00507     bool div; // NOTE(review): presumably selects dividing the sum -- confirm
00508     float split; // NOTE(review): meaning not visible in this header -- confirm
00509   };
00510 
00512   // range_lut_module: module_1_1 subclass declaring only fprop; configured by an idx<T> 'value_range'.
00515   template <typename T, class Tstate = bbstate_idx<T> >
00516     class range_lut_module : public module_1_1<T,Tstate> {
00517   public:
00526     range_lut_module(idx<T> *value_range); // takes a pointer, while the member below is held by value
00528     virtual ~range_lut_module();
00530     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00531     /* //! backward propagation from out to in */
00532     /* virtual void bprop(Tstate &in, Tstate &out); */
00533     /* //! second-derivative backward propagation from out to in */
00534     /* virtual void bbprop(Tstate &in, Tstate &out); */
00535   protected:
00536     idx<T>      value_range; // NOTE(review): presumably a lookup table of ranges and values -- confirm
00537   };
00538 
00540   // binarize_module: module_1_1 subclass declaring only fprop; configured by a threshold and two values of type T.
00543   template <typename T, class Tstate = bbstate_idx<T> >
00544     class binarize_module : public module_1_1<T,Tstate> {
00545   public:
00547     binarize_module(T threshold, T false_value, T true_value); // arguments share the names of the protected members below
00549     virtual ~binarize_module();
00551     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00552     /* //! backward propagation from out to in */
00553     /* virtual void bprop(Tstate &in, Tstate &out); */
00554     /* //! second-derivative backward propagation from out to in */
00555     /* virtual void bbprop(Tstate &in, Tstate &out); */
00556   protected:
00557     T   threshold; // NOTE(review): presumably inputs are compared against this -- confirm
00558     T   false_value; // NOTE(review): presumably the output when the comparison fails -- confirm
00559     T   true_value; // NOTE(review): presumably the output when the comparison succeeds -- confirm
00560   };
00561 
00563   // diag_module: module_1_1 subclass built from a parameter and a thickness; its only data member is the protected state 'coeff'.
00565   template <typename T, class Tstate = bbstate_idx<T> >
00566     class diag_module : public module_1_1<T,Tstate> {
00567   public:
00572     diag_module(parameter<T,Tstate> *p, intg thickness, // NOTE(review): 'thickness' presumably sizes 'coeff' -- confirm
00573                 const char *name = "diag"); // name defaults to "diag"
00575     virtual ~diag_module();
00577     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00579     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00581     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00584     virtual bool resize_output(Tstate &in, Tstate &out); // NOTE(review): meaning of the bool result is not visible here -- confirm
00586     virtual void load_x(idx<T> &weights); // NOTE(review): presumably loads 'weights' into 'coeff' -- confirm
00588     virtual std::string describe(); // returns a description string
00592     virtual diag_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); // returns a diag_module pointer; 'p' is optional (NULL by default)
00593   protected:
00594     Tstate      coeff;
00595   };
00596 
00598   // copy_module: module_1_1 subclass with no data members of its own; NOTE(review): presumably copies input to output -- confirm.
00601   template <typename T, class Tstate = bbstate_idx<T> >
00602     class copy_module : public module_1_1<T,Tstate> {
00603   public:
00605     copy_module(const char *name = "copy"); // name defaults to "copy"
00607     virtual ~copy_module();
00609     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00611     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00613     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00615     virtual std::string describe(); // returns a description string
00616   };
00617 
00619   // back_module: module_1_1 subclass declaring fprop but no bprop/bbprop; holds three idx<T> pointers and a pixel size.
00620   template <typename T, class Tstate = bbstate_idx<T> >
00621     class back_module : public module_1_1<T,Tstate> {
00622   public:
00624     back_module(const char *name = "back"); // name defaults to "back"
00626     virtual ~back_module();
00628     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00631     virtual bool resize_output(Tstate &in, Tstate &out); // NOTE(review): meaning of the bool result is not visible here -- confirm
00633     virtual std::string describe(); // returns a description string
00638     virtual fidxdim bprop_size(const fidxdim &o_size); // takes the output size by const reference, returns a size
00640     void bb(std::vector<bbox*> &boxes); // non-virtual; receives a vector of bbox pointers by reference
00641 
00642   protected:
00643     idx<T>      *s0; // raw pointer; NOTE(review): ownership and lifetime are not visible in this header -- confirm
00644     idx<T>      *s1; // raw pointer; same caveat as s0
00645     idx<T>      *s2; // raw pointer; same caveat as s0
00646     idxdim       pixel_size;
00647   };
00648 
00650   // printer_module: module_1_1 subclass with no data members of its own; NOTE(review): presumably prints the states it is given -- confirm.
00653   template <typename T, class Tstate = bbstate_idx<T> >
00654     class printer_module : public module_1_1<T,Tstate> { // public base, like every other module in this header, so instances convert to module_1_1<T,Tstate>*
00655 
00656   public:
00657     printer_module(const char *name = "printer"); // name defaults to "printer"
00659     virtual ~printer_module();
00661     virtual void fprop(Tstate &in, Tstate &out); // forward propagation from in to out
00663     virtual void bprop(Tstate &in, Tstate &out); // backward propagation from out to in
00665     virtual void bbprop(Tstate &in, Tstate &out); // second-derivative backward propagation from out to in
00666   };
00667 
00668 
00669 
00670 } // namespace ebl {
00671 
00672 #include "ebl_basic.hpp"
00673 
00674 #endif /* EBL_BASIC_H_ */