libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * 00005 * Redistribution and use in source and binary forms, with or without 00006 * modification, are permitted provided that the following conditions are met: 00007 * * Redistributions of source code must retain the above copyright 00008 * notice, this list of conditions and the following disclaimer. 00009 * * Redistributions in binary form must reproduce the above copyright 00010 * notice, this list of conditions and the following disclaimer in the 00011 * documentation and/or other materials provided with the distribution. 00012 * * Redistribution under a license not approved by the Open Source 00013 * Initiative (http://www.opensource.org) must display the 00014 * following acknowledgement in all advertising material: 00015 * This product includes software developed at the Courant 00016 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00017 * * The names of the authors may not be used to endorse or promote products 00018 * derived from this software without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00021 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00022 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00023 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00024 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00025 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00026 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00027 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00028 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00029 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00030 ***************************************************************************/ 00031 00032 #ifndef EBL_NONLINEARITY_H_ 00033 #define EBL_NONLINEARITY_H_ 00034 00035 #include "ebl_defines.h" 00036 #include "libidx.h" 00037 #include "ebl_states.h" 00038 #include "ebl_basic.h" 00039 #include "ebl_arch.h" 00040 00041 namespace ebl { 00042 00045 template <typename T, class Tstate = bbstate_idx<T> > 00046 class stdsigmoid_module: public module_1_1<T,Tstate> { 00047 public: 00049 stdsigmoid_module(); 00050 virtual ~stdsigmoid_module(); 00052 virtual void fprop(Tstate &in, Tstate &out); 00054 virtual void bprop(Tstate &in, Tstate &out); 00056 virtual void bbprop(Tstate &in, Tstate &out); 00058 virtual stdsigmoid_module<T,Tstate>* copy(); 00059 protected: 00060 idx<T> tmp; 00061 }; 00062 00065 template <typename T, class Tstate = bbstate_idx<T> > 00066 class tanh_module: public module_1_1<T,Tstate> { 00067 public: 00069 tanh_module(); 00070 virtual ~tanh_module(); 00072 void fprop(Tstate &in, Tstate &out); 00074 void bprop(Tstate &in, Tstate &out); 00076 void bbprop(Tstate &in, Tstate &out); 00078 virtual tanh_module<T,Tstate>* copy(); 00079 protected: 00080 idx<T> tmp; 00081 }; 00082 00092 template <typename T, class Tstate = bbstate_idx<T> > 00093 class softmax: public module_1_1<T,Tstate> { 00094 public: 00095 double beta; 00096 00097 // <b> is the parameter beta in the softmax 00098 // large <b> turns the softmax into a max 00099 // <b> equal to 0 turns the softmax into 1/N 00100 00101 private: 00102 void resize_nsame(Tstate &in, Tstate &out, int n); 00103 00104 public: 00105 softmax(double b); 00106 ~softmax() {}; 00107 void fprop(Tstate &in, Tstate &out); 00108 void bprop(Tstate &in, Tstate &out); 00109 void bbprop(Tstate &in, Tstate &out); 00110 }; 00111 00113 // abs_module 00118 template <typename T, class Tstate = bbstate_idx<T> > 00119 class abs_module: public module_1_1<T, Tstate> { 00120 public: 00123 abs_module(double thresh = 0.0); 00125 virtual ~abs_module(); 00127 virtual void fprop(Tstate &in, Tstate &out); 00129 virtual void bprop(Tstate &in, Tstate &out); 00131 virtual void bbprop(Tstate &in, Tstate &out); 00133 virtual abs_module<T,Tstate>* copy(); 00134 private: 00135 double threshold; 00136 }; 00137 00139 // linear_shrink_module 00144 template <typename T, class Tstate = bbstate_idx<T> > 00145 class linear_shrink_module: public module_1_1<T, Tstate> { 00146 public: 00149 linear_shrink_module(parameter<T,Tstate> *p, intg nf, T bias = 0); 00151 virtual ~linear_shrink_module(); 00153 virtual void fprop(Tstate &in, Tstate &out); 00155 virtual void bprop(Tstate &in, Tstate &out); 00157 virtual void bbprop(Tstate &in, Tstate &out); 00159 virtual linear_shrink_module<T,Tstate>* copy(); 00161 virtual std::string describe(); 00162 protected: 00163 Tstate bias; 00164 T default_bias; 00165 }; 00166 00168 // smooth_shrink_module 00173 template <typename T, class Tstate = bbstate_idx<T> > 00174 class smooth_shrink_module: public module_1_1<T, Tstate> { 00175 public: 00178 smooth_shrink_module(parameter<T,Tstate> *p, intg nf, T beta = 10, 00179 T bias = .3); 00181 virtual ~smooth_shrink_module(); 00183 virtual void fprop(Tstate &in, Tstate &out); 00185 virtual void bprop(Tstate &in, Tstate &out); 00187 virtual void bbprop(Tstate &in, Tstate &out); 00189 virtual smooth_shrink_module<T,Tstate>* copy(); 00190 00191 public: 00192 Tstate beta, bias; 00193 private: 00194 Tstate ebb, ebx, tin; 00195 abs_module<T,Tstate> absmod; 00196 T default_beta, default_bias; 00197 }; 00198 00200 // tanh_shrink_module 00205 template <typename T, class Tstate = bbstate_idx<T> > 00206 class tanh_shrink_module: public module_1_1<T, Tstate> { 00207 public: 00212 tanh_shrink_module(parameter<T,Tstate> *p, intg nf, bool diags = false); 00214 virtual ~tanh_shrink_module(); 00216 virtual void fprop(Tstate &in, Tstate &out); 00218 virtual void bprop(Tstate &in, Tstate &out); 00220 virtual void bbprop(Tstate &in, Tstate &out); 00222 virtual tanh_shrink_module<T,Tstate>* copy(); 00224 virtual std::string describe(); 00225 protected: 00226 intg nfeatures; 00227 Tstate abuf, tbuf, bbuf; 00228 diag_module<T,Tstate> *alpha, *beta; 00229 tanh_module<T,Tstate> mtanh; 00230 diff_module<T,Tstate> difmod; 00231 bool diags; 00232 }; 00233 00234 } // namespace ebl { 00235 00236 #include "ebl_nonlinearity.hpp" 00237 00238 #endif /* EBL_NONLINEARITY_H_ */