// libeblearn
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_MERGE_H_ 00034 #define EBL_MERGE_H_ 00035 00036 #include "ebl_arch.h" 00037 #include "ebl_march.h" 00038 00039 namespace ebl { 00040 00041 // flat_merge //////////////////////////////////////////////////////////////// 00042 00044 template <typename T, class Tstate> class zpad_module; 00045 00050 template <typename T, class Tstate = bbstate_idx<T> > 00051 class flat_merge_module : public m2s_module<T, Tstate> { 00052 public: 00056 flat_merge_module(std::vector<Tstate**> &inputs, 00057 idxdim &in, midxdim &ins, 00058 fidxdim &stride, mfidxdim &strides, 00059 const char *name_ = "flatmerge", const char *list = NULL); 00063 flat_merge_module(std::vector<mstate<Tstate>**> &inputs, 00064 idxdim &in, midxdim &ins, 00065 fidxdim &stride, mfidxdim &strides, 00066 const char *name_ = "flatmerge", const char *list = NULL); 00073 flat_merge_module(midxdim &ins, mfidxdim &strides, bool pad = false, 00074 const char *name_ = "flatmerge", mfidxdim *scales = NULL, 00075 intg hextra = 0, intg wextra = 0, float ss = 1, 00076 float edge = 0); 00077 virtual ~flat_merge_module(); 00079 // generic states methods 00081 virtual void fprop(Tstate &in, Tstate &out); 00083 virtual void bprop(Tstate &in, Tstate &out); 00085 virtual void bbprop(Tstate &in, Tstate &out); 00087 // multi-states methods 00089 virtual void fprop(mstate<Tstate> &in, Tstate &out); 
00091 virtual void bprop(mstate<Tstate> &in, Tstate &out); 00093 virtual void bbprop(mstate<Tstate> &in, Tstate &out); 00097 virtual idxdim fprop_size(idxdim &i_size); 00100 virtual fidxdim bprop_size(const fidxdim &o_size); 00102 virtual mfidxdim bprop_size(mfidxdim &osize); 00104 virtual std::string describe(); 00106 virtual uint get_ninputs(); 00108 virtual mfidxdim get_strides(); 00110 virtual mfidxdim get_scales(); 00112 virtual flat_merge_module<T,Tstate>* copy(); 00113 /* //! Set paddings to be applied to each input scale. */ 00114 /* virtual void set_paddings(mfidxdim &pads); */ 00116 virtual void set_offsets(vector<vector<int> > &off); 00118 virtual void set_strides(mfidxdim &s); 00119 00120 protected: 00128 idxdim compute_pad(idxdim &window, float subsampling, float edge, 00129 float scale, fidxdim &stride); 00130 00131 private: 00132 std::vector<Tstate**> inputs; 00133 idxdim din; 00134 midxdim dins; 00135 fidxdim stride; 00136 mfidxdim strides; 00137 std::string merge_list; 00138 std::vector<zpad_module<T,Tstate>*> zpads; 00139 Tstate *in0; 00140 std::vector<Tstate*> pinputs; 00141 bool use_pinputs; 00142 mstate<Tstate> padded; 00143 zpad_module<T,Tstate> padder; 00144 bool bpad; 00145 mfidxdim scales; 00146 mfidxdim paddings; 00147 vector<vector<int> > offsets; 00148 00149 // TEMP 00150 intg hextra, wextra; 00151 float subsampling, edge; 00152 }; 00153 00154 // mstate_merge ////////////////////////////////////////////////////////////// 00155 00157 template <typename T, class Tstate = bbstate_idx<T> > 00158 class mstate_merge_module : public module_1_1<T, Tstate> { 00159 public: 00163 mstate_merge_module(midxdim &ins, mfidxdim &strides, 00164 const char *name_ = "mstate_merge"); 00165 virtual ~mstate_merge_module(); 00166 // multi-states methods //////////////////////////////////////////////////// 00167 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00168 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00169 virtual void 
bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00173 virtual idxdim fprop_size(idxdim &i_size); 00176 virtual fidxdim bprop_size(const fidxdim &o_size); 00178 virtual std::string describe(); 00179 00180 private: 00181 midxdim dins; 00182 mfidxdim dstrides; 00183 }; 00184 00185 // merge ///////////////////////////////////////////////////////////////////// 00186 00191 template <typename T, class Tstate = bbstate_idx<T> > 00192 class merge_module : public module_1_1<T, Tstate> { 00193 public: 00199 merge_module(std::vector<Tstate**> &inputs, intg concat_dim, 00200 const char *name_ = "merge", const char *list = NULL); 00202 merge_module(std::vector<mstate<Tstate>**> &inputs, intg concat_dim, 00203 const char *name_ = "merge", const char *list = NULL); 00207 merge_module(std::vector<std::vector<uint> > &states, intg concat_dim, 00208 const char *name_ = "merge"); 00209 virtual ~merge_module(); 00211 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00213 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00215 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00217 virtual void fprop(Tstate &in, Tstate &out); 00219 virtual void bprop(Tstate &in, Tstate &out); 00221 virtual void bbprop(Tstate &in, Tstate &out); 00223 virtual std::string describe(); 00224 00225 // internal members //////////////////////////////////////////////////////// 00226 protected: 00228 virtual void merge(mstate<Tstate> &in, Tstate &out); 00229 00230 private: 00231 std::vector<Tstate**> inputs; 00232 std::vector<mstate<Tstate>**> msinputs; 00233 std::string merge_list; 00234 std::vector<std::vector<uint> > states_list; 00235 intg concat_dim; 00236 }; 00237 00238 // interlace ///////////////////////////////////////////////////////////////// 00239 00242 template <typename T, class Tstate = bbstate_idx<T> > 00243 class interlace_module : public module_1_1<T, Tstate> { 00244 public: 00246 interlace_module(uint stride, const char *name = "interlace_module"); 00248 
virtual ~interlace_module(); 00250 virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out); 00252 virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out); 00254 virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out); 00256 virtual mfidxdim bprop_size(mfidxdim &osize); 00258 virtual std::string describe(); 00260 virtual interlace_module<T,Tstate>* copy(); 00261 private: 00262 uint stride; 00263 }; 00264 00265 } // namespace ebl { 00266 00267 #include "ebl_merge.hpp" 00268 00269 #endif /* EBL_MERGE_H_ */