libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_pooling.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Yann LeCun, Pierre Sermanet and Soumith Chintala*
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com, soumith@gmail.com  *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_POOLING_H_
00034 #define EBL_POOLING_H_
00035 
00036 #include "ebl_defines.h"
00037 #include "libidx.h"
00038 #include "ebl_arch.h"
00039 #include "ebl_states.h"
00040 #include "ebl_utils.h"
00041 #include "ebl_preprocessing.h"
00042 
00043 namespace ebl {
00044 
00046   // subsampling_module
//! subsampling_module: module_1_1 subsampling layer using a 'kernel'-sized
//! window moved by 'stride'.
//! Constructor parameters:
//!   p         parameter object; presumably the trainable storage backing
//!             'coeff' (and possibly 'sub') -- TODO confirm in ebl_pooling.hpp.
//!   thickness presumably the number of input feature maps; kept in the
//!             like-named member.
//!   kernel    subsampling window size (non-const reference).
//!   stride    step between windows (non-const reference).
//!   name      module name, defaults to "subsampling".
//!   crop, pad boolean options, both default to true; their exact effect on
//!             input/output sizes lives in ebl_pooling.hpp (not visible here).
00052   template <typename T, class Tstate = bbstate_idx<T> >
00053     class subsampling_module: public module_1_1<T, Tstate> {
00054   public:
00065     subsampling_module(parameter<T,Tstate> *p, uint thickness, idxdim &kernel,
00066                        idxdim &stride, const char *name = "subsampling",
00067                        bool crop = true, bool pad = true);
00069     virtual ~subsampling_module();
00071     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation from 'in' to 'out'.
00073     virtual void bprop(Tstate &in, Tstate &out); //!< Backward propagation of gradients (out -> in, presumably).
00075     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably propagates second derivatives (cf. default Tstate = bbstate_idx).
00077     virtual void forget(forget_param_linear &fp); //!< Re-initializes internal parameters according to 'fp'.
00079     virtual int replicable_order() { return 3; } //!< Always 3; presumably the input order handled natively before replication -- confirm in module_1_1.
00081     virtual bool resize_output(Tstate &in, Tstate &out); //!< Resizes 'out' for 'in'; meaning of the returned bool defined in the .hpp -- TODO confirm.
00084     virtual fidxdim fprop_size(fidxdim &i_size); //!< Output size for input size 'i_size'. NOTE(review): non-const ref (unlike bprop_size), so the impl may adjust i_size -- confirm.
00087     virtual fidxdim bprop_size(const fidxdim &o_size); //!< Input size corresponding to output size 'o_size'.
00089     virtual subsampling_module<T,Tstate>* copy(); //!< Returns a new copy of this module; caller presumably owns it.
00091     virtual std::string describe(); //!< Human-readable description of this module.
00094     virtual void dump_fprop(Tstate &in, Tstate &out); //!< fprop variant presumably also dumping in/out for debugging -- confirm.
00095 
00096     // members ////////////////////////////////////////////////////////
00097   public:
00098     Tstate              coeff; //!< Coefficients, presumably trainable and allocated from 'p'.
00099     Tstate              sub; //!< Internal buffer; name suggests the raw subsampled result -- confirm.
00100     uint                thickness; //!< Constructor's 'thickness'.
00101     idxdim              kernel; //!< Constructor's 'kernel' (window size).
00102     idxdim              stride; //!< Constructor's 'stride'.
00103   protected:
00104     bool                crop; //!< Constructor's 'crop' option.
00105     bool                pad; //!< Constructor's 'pad' option.
00106   };
00107 
//! subsampling_module_replicable: subsampling_module wrapped by the
//! DECLARE_REPLICABLE_MODULE_1_1 macro (defined elsewhere, presumably in
//! ebl_arch.h) so it can be replicated over inputs of higher order than
//! subsampling_module::replicable_order().
//! The first parenthesized list is presumably the wrapper's constructor
//! signature; the second is the argument list forwarded to the
//! subsampling_module constructor. 'crop' and 'pad' are exposed and forwarded
//! so callers of the replicable variant can override them; their defaults
//! (true, true) match subsampling_module's own, so existing callers behave
//! exactly as before.
00116   DECLARE_REPLICABLE_MODULE_1_1(subsampling_module_replicable, 
00117                                 subsampling_module, T, Tstate,
00118                                 (parameter<T,Tstate> *p, uint thickness,
00119                                  idxdim &kernel, idxdim &strides,
00120                                  const char *name = "subsampling_replicable",
                                 bool crop = true, bool pad = true),
00121                                 (p, thickness, kernel, strides, name, crop, pad));
00122   
00124   // lppooling_module
00125   
//! lppooling_module: module_1_1 Lp pooling over a 'kernel'-sized window moved
//! by 'stride'. Judging from the members, the computation is presumably
//! power (sqmod -> 'squared'), then window sum via 'conv' (-> 'convolved'),
//! then root (sqrtmod) -- TODO confirm in ebl_pooling.hpp.
//! Constructor parameters:
//!   thickness presumably the number of input feature maps.
//!   kernel    pooling window size (non-const reference).
//!   stride    step between windows (non-const reference).
//!   lppower   the 'p' of Lp pooling, defaults to 2.
//!   name      module name, defaults to "lppooling".
//!   crop      boolean option, defaults to true (effect defined in the .hpp).
//! Unlike subsampling_module, no parameter object is passed in: the module
//! keeps its own 'param' member.
00128   template <typename T, class Tstate = bbstate_idx<T> >
00129     class lppooling_module: public module_1_1<T, Tstate> {
00130   public:
00139   lppooling_module(uint thickness, idxdim &kernel, idxdim &stride,
00140                      uint lppower = 2, 
00141                      const char *name = "lppooling", 
00142                      bool crop = true);
00144     virtual ~lppooling_module();
00146     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation from 'in' to 'out'.
00148     virtual void bprop(Tstate &in, Tstate &out); //!< Backward propagation of gradients (out -> in, presumably).
00150     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably propagates second derivatives.
00153     virtual fidxdim fprop_size(fidxdim &i_size); //!< Output size for input size 'i_size' (non-const ref; impl may adjust it -- confirm).
00156     virtual fidxdim bprop_size(const fidxdim &o_size); //!< Input size corresponding to output size 'o_size'.
00158     virtual lppooling_module<T,Tstate>* copy(); //!< Returns a new copy of this module; caller presumably owns it.
00160     virtual std::string describe(); //!< Human-readable description of this module.
00163     virtual void dump_fprop(Tstate &in, Tstate &out); //!< fprop variant presumably also dumping in/out for debugging -- confirm.
00164 
00165     // members ////////////////////////////////////////////////////////
00166   protected:
00167     uint                thickness; //!< Constructor's 'thickness'.
00168     idxdim              kernel; //!< Constructor's 'kernel'.
00169     idxdim              stride; //!< Constructor's 'stride'.
00170     bool                crop; //!< Constructor's 'crop' option.
00171     uint                lp_pow; //!< Constructor's 'lppower'.
00172     convolution_module<T,Tstate> *conv; //!< Window-summing convolution. NOTE(review): raw pointer; creation/deletion happen in the .hpp -- confirm ownership.
00173     power_module<T,Tstate> sqmod; //!< Named after the p=2 default; presumably raises the input to lp_pow.
00174     power_module<T,Tstate> sqrtmod; //!< Named after the p=2 default; presumably takes the 1/lp_pow root.
00175     Tstate squared, convolved; //!< Intermediate states between the power, conv and root stages.
00176     parameter<T,Tstate> param; //!< Module-local parameter storage, presumably backing 'conv'.
00177   };
00178 
00180   // wavg_pooling_module
00181   
//! wavg_pooling_module: module_1_1 pooling over a 'kernel'-sized window moved
//! by 'stride'. The name suggests weighted-average pooling, presumably
//! implemented with the internal 'conv' module -- TODO confirm in
//! ebl_pooling.hpp.
//! Constructor parameters:
//!   thickness presumably the number of input feature maps.
//!   kernel    pooling window size (non-const reference).
//!   stride    step between windows (non-const reference).
//!   name      module name, defaults to "wavg_pooling".
//!   crop      boolean option, defaults to true (effect defined in the .hpp).
//! No parameter object is passed in: the module keeps its own 'param'.
00184   template <typename T, class Tstate = bbstate_idx<T> >
00185     class wavg_pooling_module: public module_1_1<T, Tstate> {
00186   public:
00195     wavg_pooling_module(uint thickness, idxdim &kernel, idxdim &stride,
00196                      const char *name = "wavg_pooling", bool crop = true);
00198     virtual ~wavg_pooling_module();
00200     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation from 'in' to 'out'.
00202     virtual void bprop(Tstate &in, Tstate &out); //!< Backward propagation of gradients (out -> in, presumably).
00204     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably propagates second derivatives.
00207     virtual fidxdim fprop_size(fidxdim &i_size); //!< Output size for input size 'i_size' (non-const ref; impl may adjust it -- confirm).
00210     virtual fidxdim bprop_size(const fidxdim &o_size); //!< Input size corresponding to output size 'o_size'.
00212     virtual wavg_pooling_module<T,Tstate>* copy(); //!< Returns a new copy of this module; caller presumably owns it.
00214     virtual std::string describe(); //!< Human-readable description of this module.
00217     virtual void dump_fprop(Tstate &in, Tstate &out); //!< fprop variant presumably also dumping in/out for debugging -- confirm.
00218 
00219     // members ////////////////////////////////////////////////////////
00220   protected:
00221     uint                thickness; //!< Constructor's 'thickness'.
00222     idxdim              kernel; //!< Constructor's 'kernel'.
00223     idxdim              stride; //!< Constructor's 'stride'.
00224     bool                crop; //!< Constructor's 'crop' option.
00225     convolution_module<T,Tstate> *conv; //!< Internal convolution. NOTE(review): raw pointer; creation/deletion happen in the .hpp -- confirm ownership.
00226     parameter<T,Tstate> param; //!< Module-local parameter storage, presumably backing 'conv'.
00227   };
00228 
00230   // pyramid_module
//! pyramid_module: derives from both resizepp_module and s2m_module and
//! produces 'nscales' scales of its input, presumably each resized by
//! 'scaling_ratio' relative to the previous one -- TODO confirm in
//! ebl_pooling.hpp.
//! Constructor parameters (two overloads; the second omits 'dsize'):
//!   nscales       number of scales.
//!   scaling_ratio ratio between scales.
//!   dsize         (first overload only) presumably the target size of the
//!                 first scale.
//!   mode          resizing mode, defaults to MEAN_RESIZE.
//!   pp            optional module_1_1 applied as preprocessing, defaults to NULL.
//!   own_pp        defaults to false; presumably whether this module deletes pp.
//!   dzpad         optional pointer to a padding size, defaults to NULL.
//!   name          module name, defaults to "pyramid_module".
//! NOTE(review): declaring bprop/bbprop here hides every same-named overload
//! from both bases for calls made through a pyramid_module; add
//! using-declarations if the (Tstate, Tstate) base versions must stay callable.
00232   template <typename T, class Tstate = bbstate_idx<T> >
00233     class pyramid_module
00234     : public resizepp_module<T,Tstate>, public s2m_module<T,Tstate> {
00235   public:
00249     pyramid_module(uint nscales, float scaling_ratio, idxdim &dsize,
00250                    uint mode = MEAN_RESIZE, module_1_1<T,Tstate> *pp = NULL,
00251                    bool own_pp = false, idxdim *dzpad = NULL, 
00252                    const char *name = "pyramid_module");
00266     pyramid_module(uint nscales, float scaling_ratio, uint mode = MEAN_RESIZE,
00267                    module_1_1<T,Tstate> *pp = NULL, bool own_pp = false,
00268                    idxdim *dzpad = NULL, const char *name = "pyramid_module");
00269     virtual ~pyramid_module();
00271     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation into a single output state.
00274     virtual void fprop(Tstate &in, mstate<Tstate> &out); //!< Forward propagation into a multi-state, presumably one state per scale.
00277     virtual void fprop(Tstate &in, midx<T> &out); //!< Forward propagation into a midx<T>, presumably one idx per scale.
00279     virtual void bprop(Tstate &in, mstate<Tstate> &out); //!< Backward propagation from the multi-state output.
00281     virtual void bbprop(Tstate &in, mstate<Tstate> &out); //!< Presumably second-derivative propagation from the multi-state output.
00283     virtual mfidxdim bprop_size(mfidxdim &osize); //!< Input size(s) corresponding to the multi-dimension output sizes 'osize'.
00285     virtual std::string describe(); //!< Human-readable description of this module.
00286     /* //! Returns bounding boxes of each scale in the input space. */
00287     /* const vector<rect<int> >& get_input_bboxes(); */
00288     /* //! Returns bounding boxes of each scale in the output space. */
00289     /* const vector<rect<int> >& get_original_bboxes(); */
00290 
00291     // members ////////////////////////////////////////////////////////
00292   protected:
00293     uint nscales; //!< Constructor's 'nscales'.
00294     float scaling_ratio; //!< Constructor's 'scaling_ratio'.
00295   };
00296   
00298   // average_pyramid_module
//! average_pyramid_module: s2m_module that produces a multi-state output,
//! presumably one subsampled version of the input per entry of 'strides',
//! each computed by a subsampling_module held in 'mods' -- TODO confirm in
//! ebl_pooling.hpp.
//! Constructor parameters:
//!   p         parameter object, presumably forwarded to each subsampling_module.
//!   thickness presumably the number of input feature maps.
//!   strides   multi-dimension list of strides (midxdim, non-const reference).
//!   name      module name, defaults to "average_pyramid_module".
00300   template <typename T, class Tstate = bbstate_idx<T> >
00301     class average_pyramid_module : public s2m_module<T,Tstate> {
00302   public:
00308     average_pyramid_module(parameter<T,Tstate> *p, uint thickness,
00309                            midxdim &strides, 
00310                            const char *name = "average_pyramid_module");
00311     virtual ~average_pyramid_module();
00313     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation into a single output state.
00316     virtual void fprop(Tstate &in, mstate<Tstate> &out); //!< Forward propagation into a multi-state output.
00318     virtual void bprop(Tstate &in, mstate<Tstate> &out); //!< Backward propagation from the multi-state output.
00320     virtual void bbprop(Tstate &in, mstate<Tstate> &out); //!< Presumably second-derivative propagation from the multi-state output.
00322     virtual std::string describe(); //!< Human-readable description of this module.
00324     virtual mfidxdim bprop_size(mfidxdim &osize); //!< Input size(s) corresponding to the multi-dimension output sizes 'osize'.
00325 
00326     // members ////////////////////////////////////////////////////////
00327   protected:
00328     midxdim strides; //!< Constructor's 'strides'.
00329     std::vector<subsampling_module<T,Tstate>*> mods; //!< Subsampling modules. NOTE(review): raw pointers; presumably created in ctor and deleted in dtor -- confirm.
00330     bool well_behaved; //!< TODO(review): meaning not visible in this header; document from the .hpp.
00331   };
00332 
00334   // maxss_module
//! maxss_module: module_1_1 max subsampling over a 'kernel'-sized window
//! moved by 'stride' (name-based reading -- TODO confirm in ebl_pooling.hpp).
//! Constructor parameters:
//!   thickness presumably the number of input feature maps.
//!   kernel    pooling window size (non-const reference).
//!   stride    step between windows (non-const reference).
//!   name      module name, defaults to "maxss".
00336   template <typename T, class Tstate = bbstate_idx<T> >
00337     class maxss_module : public module_1_1<T,Tstate> {
00338   public:
00343     maxss_module(uint thickness, idxdim &kernel, idxdim &stride,
00344                  const char *name = "maxss");
00346     virtual ~maxss_module();
00348     virtual void fprop(Tstate &in, Tstate &out); //!< Forward propagation from 'in' to 'out'.
00350     virtual void bprop(Tstate &in, Tstate &out); //!< Backward propagation of gradients (out -> in, presumably).
00352     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably propagates second derivatives.
00355     virtual bool resize_output(Tstate &in, Tstate &out); //!< Resizes 'out' for 'in'; meaning of the returned bool defined in the .hpp -- TODO confirm.
00357     virtual int replicable_order() { return 3; } //!< Always 3; presumably the input order handled natively before replication.
00360     virtual fidxdim fprop_size(fidxdim &i_size); //!< Output size for input size 'i_size' (non-const ref; impl may adjust it -- confirm).
00363     virtual fidxdim bprop_size(const fidxdim &o_size); //!< Input size corresponding to output size 'o_size'.
00367     virtual maxss_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); //!< Returns a new copy. NOTE(review): differs from sibling copy() signatures; default args on virtuals bind to the static type, and this must match the base's copy() to override it -- confirm.
00369     virtual std::string describe(); //!< Human-readable description of this module.
00370     // members ////////////////////////////////////////////////////////
00371   protected:
00372     uint        thickness; //!< Constructor's 'thickness'.   
00373     idxdim      kernel; //!< Constructor's 'kernel'.       
00374     idxdim      stride; //!< Constructor's 'stride'.       
00375     idx<int>    switches; //!< Presumably the argmax positions recorded in fprop for use in bprop -- confirm.     
00376     bool float_precision; //!< Presumably true when T is float -- confirm where it is set.     
00377     bool double_precision; //!< Presumably true when T is double -- confirm where it is set.    
00378   //#ifdef __TH__
00379     idx<T>       indices; //!< Typed idx<T> (unlike 'switches'); the __TH__ guard around it is commented out, so it is always declared.     
00380   //#endif
00381   };
00382   
00383 } // namespace ebl {
00384 
00385 #include "ebl_pooling.hpp"
00386 
00387 #endif /* EBL_POOLING_H_ */