libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_merge.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_MERGE_H_
00034 #define EBL_MERGE_H_
00035 
00036 #include "ebl_arch.h"
00037 #include "ebl_march.h"
00038 
00039 namespace ebl {
00040 
00041   // flat_merge ////////////////////////////////////////////////////////////////
00042 
00044   template <typename T, class Tstate> class zpad_module;
00045 
00050   template <typename T, class Tstate = bbstate_idx<T> >
00051     class flat_merge_module : public m2s_module<T, Tstate> {
00052   public:
00056     flat_merge_module(std::vector<Tstate**> &inputs,
00057                       idxdim &in, midxdim &ins,
00058                       fidxdim &stride, mfidxdim &strides,
00059                       const char *name_ = "flatmerge", const char *list = NULL);
00063     flat_merge_module(std::vector<mstate<Tstate>**> &inputs,
00064                       idxdim &in, midxdim &ins,
00065                       fidxdim &stride, mfidxdim &strides,
00066                       const char *name_ = "flatmerge", const char *list = NULL);
00073     flat_merge_module(midxdim &ins, mfidxdim &strides, bool pad = false,
00074                       const char *name_ = "flatmerge", mfidxdim *scales = NULL,
00075                       intg hextra = 0, intg wextra = 0, float ss = 1,
00076                       float edge = 0);
00077     virtual ~flat_merge_module();
00079     // generic states methods
00081     virtual void fprop(Tstate &in, Tstate &out);
00083     virtual void bprop(Tstate &in, Tstate &out);
00085     virtual void bbprop(Tstate &in, Tstate &out);
00087     // multi-states methods
00089     virtual void fprop(mstate<Tstate> &in, Tstate &out);
00091     virtual void bprop(mstate<Tstate> &in, Tstate &out);
00093     virtual void bbprop(mstate<Tstate> &in, Tstate &out);
00097     virtual idxdim fprop_size(idxdim &i_size);
00100     virtual fidxdim bprop_size(const fidxdim &o_size);
00102     virtual mfidxdim bprop_size(mfidxdim &osize);
00104     virtual std::string describe();
00106     virtual uint get_ninputs();
00108     virtual mfidxdim get_strides();
00110     virtual mfidxdim get_scales();
00112     virtual flat_merge_module<T,Tstate>* copy();
00113     /* //! Set paddings to be applied to each input scale. */
00114     /* virtual void set_paddings(mfidxdim &pads); */
00116     virtual void set_offsets(vector<vector<int> > &off);
00118     virtual void set_strides(mfidxdim &s);
00119 
00120   protected:
00128     idxdim compute_pad(idxdim &window, float subsampling, float edge,
00129                        float scale, fidxdim &stride);
00130 
00131   private:
00132     std::vector<Tstate**> inputs;
00133     idxdim din;
00134     midxdim dins;
00135     fidxdim stride;
00136     mfidxdim strides;
00137     std::string merge_list; 
00138     std::vector<zpad_module<T,Tstate>*> zpads;
00139     Tstate *in0;
00140     std::vector<Tstate*> pinputs;
00141     bool use_pinputs;
00142     mstate<Tstate> padded;
00143     zpad_module<T,Tstate> padder;
00144     bool bpad; 
00145     mfidxdim scales; 
00146     mfidxdim paddings; 
00147     vector<vector<int> > offsets; 
00148 
00149   // TEMP
00150   intg hextra, wextra;
00151   float subsampling, edge;
00152   };
00153 
00154   // mstate_merge //////////////////////////////////////////////////////////////
00155 
00157   template <typename T, class Tstate = bbstate_idx<T> >
00158     class mstate_merge_module : public module_1_1<T, Tstate> {
00159   public:
00163     mstate_merge_module(midxdim &ins, mfidxdim &strides,
00164                         const char *name_ = "mstate_merge");
00165     virtual ~mstate_merge_module();
00166     // multi-states methods ////////////////////////////////////////////////////
00167     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00168     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00169     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00173     virtual idxdim fprop_size(idxdim &i_size);
00176     virtual fidxdim bprop_size(const fidxdim &o_size);
00178     virtual std::string describe();
00179 
00180   private:
00181     midxdim dins;
00182     mfidxdim dstrides;
00183   };
00184 
00185   // merge /////////////////////////////////////////////////////////////////////
00186 
00191   template <typename T, class Tstate = bbstate_idx<T> >
00192     class merge_module : public module_1_1<T, Tstate> {
00193   public:
00199     merge_module(std::vector<Tstate**> &inputs, intg concat_dim,
00200                  const char *name_ = "merge", const char *list = NULL);
00202     merge_module(std::vector<mstate<Tstate>**> &inputs, intg concat_dim,
00203                  const char *name_ = "merge", const char *list = NULL);
00207     merge_module(std::vector<std::vector<uint> > &states, intg concat_dim,
00208                  const char *name_ = "merge");
00209     virtual ~merge_module();
00211     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00213     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00215     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00217     virtual void fprop(Tstate &in, Tstate &out);
00219     virtual void bprop(Tstate &in, Tstate &out);
00221     virtual void bbprop(Tstate &in, Tstate &out);
00223     virtual std::string describe();
00224 
00225     // internal members ////////////////////////////////////////////////////////
00226   protected:
00228     virtual void merge(mstate<Tstate> &in, Tstate &out);
00229 
00230   private:
00231     std::vector<Tstate**> inputs;
00232     std::vector<mstate<Tstate>**> msinputs;
00233     std::string merge_list;
00234     std::vector<std::vector<uint> > states_list;
00235     intg concat_dim; 
00236   };
00237 
00238   // interlace /////////////////////////////////////////////////////////////////
00239 
00242   template <typename T, class Tstate = bbstate_idx<T> >
00243     class interlace_module : public module_1_1<T, Tstate> {
00244   public:
00246     interlace_module(uint stride, const char *name = "interlace_module");
00248     virtual ~interlace_module();
00250     virtual void fprop(mstate<Tstate> &in, mstate<Tstate> &out);
00252     virtual void bprop(mstate<Tstate> &in, mstate<Tstate> &out);
00254     virtual void bbprop(mstate<Tstate> &in, mstate<Tstate> &out);
00256     virtual mfidxdim bprop_size(mfidxdim &osize);
00258     virtual std::string describe();
00260     virtual interlace_module<T,Tstate>* copy();
00261   private:
00262     uint stride; 
00263   };
00264 
00265 } // namespace ebl {
00266 
00267 #include "ebl_merge.hpp"
00268 
00269 #endif /* EBL_MERGE_H_ */