// libeblearn
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Yann LeCun, Pierre Sermanet and Soumith Chintala* 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com, soumith@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_POOLING_H_ 00034 #define EBL_POOLING_H_ 00035 00036 #include "ebl_defines.h" 00037 #include "libidx.h" 00038 #include "ebl_arch.h" 00039 #include "ebl_states.h" 00040 #include "ebl_utils.h" 00041 #include "ebl_preprocessing.h" 00042 00043 namespace ebl { 00044 00046 // subsampling_module 00052 template <typename T, class Tstate = bbstate_idx<T> > 00053 class subsampling_module: public module_1_1<T, Tstate> { 00054 public: 00065 subsampling_module(parameter<T,Tstate> *p, uint thickness, idxdim &kernel, 00066 idxdim &stride, const char *name = "subsampling", 00067 bool crop = true, bool pad = true); 00069 virtual ~subsampling_module(); 00071 virtual void fprop(Tstate &in, Tstate &out); 00073 virtual void bprop(Tstate &in, Tstate &out); 00075 virtual void bbprop(Tstate &in, Tstate &out); 00077 virtual void forget(forget_param_linear &fp); 00079 virtual int replicable_order() { return 3; } 00081 virtual bool resize_output(Tstate &in, Tstate &out); 00084 virtual fidxdim fprop_size(fidxdim &i_size); 00087 virtual fidxdim bprop_size(const fidxdim &o_size); 00089 virtual subsampling_module<T,Tstate>* copy(); 00091 virtual std::string describe(); 00094 virtual void dump_fprop(Tstate &in, Tstate &out); 00095 00096 // members //////////////////////////////////////////////////////// 00097 public: 00098 Tstate coeff; 00099 Tstate sub; 00100 uint 
thickness; 00101 idxdim kernel; 00102 idxdim stride; 00103 protected: 00104 bool crop; 00105 bool pad; 00106 }; 00107 00116 DECLARE_REPLICABLE_MODULE_1_1(subsampling_module_replicable, 00117 subsampling_module, T, Tstate, 00118 (parameter<T,Tstate> *p, uint thickness, 00119 idxdim &kernel, idxdim &strides, 00120 const char *name = "subsampling_replicable"), 00121 (p, thickness, kernel, strides, name)); 00122 00124 // lppooling_module 00125 00128 template <typename T, class Tstate = bbstate_idx<T> > 00129 class lppooling_module: public module_1_1<T, Tstate> { 00130 public: 00139 lppooling_module(uint thickness, idxdim &kernel, idxdim &stride, 00140 uint lppower = 2, 00141 const char *name = "lppooling", 00142 bool crop = true); 00144 virtual ~lppooling_module(); 00146 virtual void fprop(Tstate &in, Tstate &out); 00148 virtual void bprop(Tstate &in, Tstate &out); 00150 virtual void bbprop(Tstate &in, Tstate &out); 00153 virtual fidxdim fprop_size(fidxdim &i_size); 00156 virtual fidxdim bprop_size(const fidxdim &o_size); 00158 virtual lppooling_module<T,Tstate>* copy(); 00160 virtual std::string describe(); 00163 virtual void dump_fprop(Tstate &in, Tstate &out); 00164 00165 // members //////////////////////////////////////////////////////// 00166 protected: 00167 uint thickness; 00168 idxdim kernel; 00169 idxdim stride; 00170 bool crop; 00171 uint lp_pow; 00172 convolution_module<T,Tstate> *conv; 00173 power_module<T,Tstate> sqmod; 00174 power_module<T,Tstate> sqrtmod; 00175 Tstate squared, convolved; 00176 parameter<T,Tstate> param; 00177 }; 00178 00180 // wavg_pooling_module 00181 00184 template <typename T, class Tstate = bbstate_idx<T> > 00185 class wavg_pooling_module: public module_1_1<T, Tstate> { 00186 public: 00195 wavg_pooling_module(uint thickness, idxdim &kernel, idxdim &stride, 00196 const char *name = "wavg_pooling", bool crop = true); 00198 virtual ~wavg_pooling_module(); 00200 virtual void fprop(Tstate &in, Tstate &out); 00202 virtual void bprop(Tstate 
&in, Tstate &out); 00204 virtual void bbprop(Tstate &in, Tstate &out); 00207 virtual fidxdim fprop_size(fidxdim &i_size); 00210 virtual fidxdim bprop_size(const fidxdim &o_size); 00212 virtual wavg_pooling_module<T,Tstate>* copy(); 00214 virtual std::string describe(); 00217 virtual void dump_fprop(Tstate &in, Tstate &out); 00218 00219 // members //////////////////////////////////////////////////////// 00220 protected: 00221 uint thickness; 00222 idxdim kernel; 00223 idxdim stride; 00224 bool crop; 00225 convolution_module<T,Tstate> *conv; 00226 parameter<T,Tstate> param; 00227 }; 00228 00230 // pyramid_module 00232 template <typename T, class Tstate = bbstate_idx<T> > 00233 class pyramid_module 00234 : public resizepp_module<T,Tstate>, public s2m_module<T,Tstate> { 00235 public: 00249 pyramid_module(uint nscales, float scaling_ratio, idxdim &dsize, 00250 uint mode = MEAN_RESIZE, module_1_1<T,Tstate> *pp = NULL, 00251 bool own_pp = false, idxdim *dzpad = NULL, 00252 const char *name = "pyramid_module"); 00266 pyramid_module(uint nscales, float scaling_ratio, uint mode = MEAN_RESIZE, 00267 module_1_1<T,Tstate> *pp = NULL, bool own_pp = false, 00268 idxdim *dzpad = NULL, const char *name = "pyramid_module"); 00269 virtual ~pyramid_module(); 00271 virtual void fprop(Tstate &in, Tstate &out); 00274 virtual void fprop(Tstate &in, mstate<Tstate> &out); 00277 virtual void fprop(Tstate &in, midx<T> &out); 00279 virtual void bprop(Tstate &in, mstate<Tstate> &out); 00281 virtual void bbprop(Tstate &in, mstate<Tstate> &out); 00283 virtual mfidxdim bprop_size(mfidxdim &osize); 00285 virtual std::string describe(); 00286 /* //! Returns bounding boxes of each scale in the input space. */ 00287 /* const vector<rect<int> >& get_input_bboxes(); */ 00288 /* //! Returns bounding boxes of each scale in the output space. 
*/ 00289 /* const vector<rect<int> >& get_original_bboxes(); */ 00290 00291 // members //////////////////////////////////////////////////////// 00292 protected: 00293 uint nscales; 00294 float scaling_ratio; 00295 }; 00296 00298 // average_pyramid_module 00300 template <typename T, class Tstate = bbstate_idx<T> > 00301 class average_pyramid_module : public s2m_module<T,Tstate> { 00302 public: 00308 average_pyramid_module(parameter<T,Tstate> *p, uint thickness, 00309 midxdim &strides, 00310 const char *name = "average_pyramid_module"); 00311 virtual ~average_pyramid_module(); 00313 virtual void fprop(Tstate &in, Tstate &out); 00316 virtual void fprop(Tstate &in, mstate<Tstate> &out); 00318 virtual void bprop(Tstate &in, mstate<Tstate> &out); 00320 virtual void bbprop(Tstate &in, mstate<Tstate> &out); 00322 virtual std::string describe(); 00324 virtual mfidxdim bprop_size(mfidxdim &osize); 00325 00326 // members //////////////////////////////////////////////////////// 00327 protected: 00328 midxdim strides; 00329 std::vector<subsampling_module<T,Tstate>*> mods; 00330 bool well_behaved; 00331 }; 00332 00334 // maxss_module 00336 template <typename T, class Tstate = bbstate_idx<T> > 00337 class maxss_module : public module_1_1<T,Tstate> { 00338 public: 00343 maxss_module(uint thickness, idxdim &kernel, idxdim &stride, 00344 const char *name = "maxss"); 00346 virtual ~maxss_module(); 00348 virtual void fprop(Tstate &in, Tstate &out); 00350 virtual void bprop(Tstate &in, Tstate &out); 00352 virtual void bbprop(Tstate &in, Tstate &out); 00355 virtual bool resize_output(Tstate &in, Tstate &out); 00357 virtual int replicable_order() { return 3; } 00360 virtual fidxdim fprop_size(fidxdim &i_size); 00363 virtual fidxdim bprop_size(const fidxdim &o_size); 00367 virtual maxss_module<T,Tstate>* copy(parameter<T,Tstate> *p = NULL); 00369 virtual std::string describe(); 00370 // members //////////////////////////////////////////////////////// 00371 protected: 00372 uint thickness; 
00373 idxdim kernel; 00374 idxdim stride; 00375 idx<int> switches; 00376 bool float_precision; 00377 bool double_precision; 00378 //#ifdef __TH__ 00379 idx<T> indices; 00380 //#endif 00381 }; 00382 00383 } // namespace ebl { 00384 00385 #include "ebl_pooling.hpp" 00386 00387 #endif /* EBL_POOLING_H_ */