libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_nonlinearity.h
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *
00005  * Redistribution and use in source and binary forms, with or without
00006  * modification, are permitted provided that the following conditions are met:
00007  *     * Redistributions of source code must retain the above copyright
00008  *       notice, this list of conditions and the following disclaimer.
00009  *     * Redistributions in binary form must reproduce the above copyright
00010  *       notice, this list of conditions and the following disclaimer in the
00011  *       documentation and/or other materials provided with the distribution.
00012  *     * Redistribution under a license not approved by the Open Source
00013  *       Initiative (http://www.opensource.org) must display the
00014  *       following acknowledgement in all advertising material:
00015  *        This product includes software developed at the Courant
00016  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00017  *     * The names of the authors may not be used to endorse or promote products
00018  *       derived from this software without specific prior written permission.
00019  *
00020  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00021  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00022  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00023  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00024  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00025  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00026  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00027  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00028  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00029  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00030  ***************************************************************************/
00031 
00032 #ifndef EBL_NONLINEARITY_H_
00033 #define EBL_NONLINEARITY_H_
00034 
00035 #include "ebl_defines.h"
00036 #include "libidx.h"
00037 #include "ebl_states.h"
00038 #include "ebl_basic.h"
00039 #include "ebl_arch.h"
00040 
00041 namespace ebl {
00042 
00045   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00046     class stdsigmoid_module: public module_1_1<T,Tstate> { // Sigmoid non-linearity module (per its name); declarations only.
00047   public:
00049     stdsigmoid_module(); //!< Default constructor.
00050     virtual ~stdsigmoid_module(); //!< Virtual destructor.
00052     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00054     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00056     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00058     virtual stdsigmoid_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00059   protected:
00060     idx<T> tmp; //!< idx<T> buffer; presumably scratch storage for the propagation methods -- TODO confirm.
00061   };
00062 
00065   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00066     class tanh_module: public module_1_1<T,Tstate> { // Hyperbolic-tangent non-linearity module (per its name); declarations only.
00067   public:
00069     tanh_module(); //!< Default constructor.
00070     virtual ~tanh_module(); //!< Virtual destructor.
00072     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00074     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00076     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00078     virtual tanh_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00079   protected:
00080     idx<T> tmp; //!< idx<T> buffer; presumably scratch storage for the propagation methods -- TODO confirm.
00081   };
00082 
00092   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00093     class softmax: public module_1_1<T,Tstate> { // Softmax module parameterized by beta; declarations only.
00094   public:
00095     double beta; //!< Softmax parameter beta; see the note just below.
00096 
00097     // The constructor argument <b> is the parameter beta in the softmax:
00098     //  - a large <b> turns the softmax into a max;
00099     //  - <b> equal to 0 turns the softmax into 1/N.
00100 
00101   private:
00102     void resize_nsame(Tstate &in, Tstate &out, int n); // Private helper; presumably resizes <out> to match <in> -- TODO confirm role of <n>.
00103 
00104   public:
00105     softmax(double b); //!< Constructs the module from beta value <b>.
00106     virtual ~softmax() {} //!< Empty destructor, spelled virtual like the sibling modules.
00107     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00108     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00109     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00110   };
00111 
00113   // abs_module
00118   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00119     class abs_module: public module_1_1<T, Tstate> { // Absolute-value module (per its name); declarations only.
00120   public:
00123     abs_module(double thresh = 0.0); //!< <thresh> defaults to 0.0; presumably stored in <threshold> -- TODO confirm.
00125     virtual ~abs_module(); //!< Virtual destructor.
00127     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00129     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00131     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00133     virtual abs_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00134   private:
00135     double threshold; //!< Threshold value; how it affects the output is not visible here -- TODO confirm.
00136   };
00137 
00139   // linear_shrink_module
00144   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00145     class linear_shrink_module: public module_1_1<T, Tstate> { // Linear shrinkage module (per its name); declarations only.
00146   public:
00149     linear_shrink_module(parameter<T,Tstate> *p, intg nf, T bias = 0); //!< <bias> defaults to 0; <nf> is presumably the feature count -- TODO confirm use of <p>.
00151     virtual ~linear_shrink_module(); //!< Virtual destructor.
00153     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00155     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00157     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00159     virtual linear_shrink_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00161     virtual std::string describe(); //!< Returns a std::string; presumably a human-readable description of the module.
00162   protected:
00163     Tstate bias; //!< Bias held as a state object; presumably allocated in <p> so it can be trained -- TODO confirm.
00164     T default_bias; //!< Scalar of type T; presumably the constructor's <bias> argument -- TODO confirm.
00165   };
00166 
00168   // smooth_shrink_module
00173   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00174     class smooth_shrink_module: public module_1_1<T, Tstate> { // Smooth shrinkage module (per its name); declarations only.
00175   public:
00178     smooth_shrink_module(parameter<T,Tstate> *p, intg nf, T beta = 10, // <beta> defaults to 10; <nf> is presumably the feature count.
00179                          T bias = .3); //!< <bias> defaults to .3 -- TODO confirm how <p> is used.
00181     virtual ~smooth_shrink_module(); //!< Virtual destructor.
00183     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00185     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00187     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00189     virtual smooth_shrink_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00190 
00191   public:
00192     Tstate beta, bias; //!< Publicly accessible state objects for beta and bias.
00193   private:
00194     Tstate ebb, ebx, tin; //!< Private state objects; presumably intermediate buffers -- TODO confirm.
00195     abs_module<T,Tstate> absmod; //!< Embedded abs_module instance, held by value.
00196     T default_beta, default_bias; //!< Scalars of type T; presumably the constructor's <beta>/<bias> arguments -- TODO confirm.
00197   };
00198 
00200   // tanh_shrink_module
00205   template <typename T, class Tstate = bbstate_idx<T> > // Tstate defaults to bbstate_idx<T>.
00206     class tanh_shrink_module: public module_1_1<T, Tstate> { // Tanh-based shrinkage module (per its name); declarations only.
00207   public:
00212     tanh_shrink_module(parameter<T,Tstate> *p, intg nf, bool diags = false); //!< <diags> defaults to false; <nf> is presumably the feature count -- TODO confirm.
00214     virtual ~tanh_shrink_module(); //!< Virtual destructor.
00216     virtual void fprop(Tstate &in, Tstate &out); //!< Presumably forward propagation from <in> to <out> -- TODO confirm.
00218     virtual void bprop(Tstate &in, Tstate &out); //!< Presumably gradient back-propagation from <out> to <in> -- TODO confirm.
00220     virtual void bbprop(Tstate &in, Tstate &out); //!< Presumably second-derivative back-propagation -- TODO confirm.
00222     virtual tanh_shrink_module<T,Tstate>* copy(); //!< Returns a module pointer; presumably a new copy owned by the caller -- TODO confirm.
00224     virtual std::string describe(); //!< Returns a std::string; presumably a human-readable description of the module.
00225   protected:
00226     intg                         nfeatures; //!< Feature count; presumably the constructor's <nf> -- TODO confirm.
00227     Tstate                       abuf, tbuf, bbuf; //!< State objects; presumably intermediate buffers -- TODO confirm.
00228     diag_module<T,Tstate>       *alpha, *beta; //!< Raw diag_module pointers. NOTE(review): ownership/deletion not visible here -- confirm.
00229     tanh_module<T,Tstate>        mtanh; //!< Embedded tanh_module instance, held by value.
00230     diff_module<T,Tstate>        difmod; //!< Embedded diff_module instance, held by value.
00231     bool                         diags; //!< Flag sharing its name with the constructor's <diags> parameter.
00232   };
00233 
00234 } // namespace ebl {
00235 
00236 #include "ebl_nonlinearity.hpp"
00237 
00238 #endif /* EBL_NONLINEARITY_H_ */