// libeblearn
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_ENERGY_H_ 00034 #define EBL_ENERGY_H_ 00035 00036 #include "libidx.h" 00037 #include "ebl_arch.h" 00038 #include "ebl_nonlinearity.h" 00039 00040 namespace ebl { 00041 00042 // l2_energy ///////////////////////////////////////////////////////////////// 00045 template<typename T, class Tstate = bbstate_idx<T> > 00046 class l2_energy : public ebm_2<Tstate> { 00047 public: 00048 l2_energy(const char *name = "l2_energy"); 00049 virtual ~l2_energy(); 00050 virtual void fprop(Tstate &in1, Tstate &in2, Tstate &energy); 00051 virtual void bprop(Tstate &in1, Tstate &in2, Tstate &energy); 00052 virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &energy); 00053 virtual void infer2_copy(Tstate &i1, Tstate &i2, Tstate &energy); 00055 virtual std::string describe(); 00056 protected: 00057 idx<T> tmp; 00058 }; 00059 00060 // l1_penalty //////////////////////////////////////////////////////////////// 00062 template<typename T, class Tstate = bbstate_idx<T> > 00063 class l1_penalty : public ebm_1<T,Tstate> { 00064 public: 00068 l1_penalty(T threshold = 0, T coeff = 1); 00069 virtual ~l1_penalty(); 00070 virtual void fprop(Tstate &in, Tstate &energy); 00071 virtual void bprop(Tstate &in, Tstate &energy); 00072 virtual void bbprop(Tstate &in, Tstate &energy); 00074 virtual std::string describe(); 00075 // member variables 00076 protected: 00077 T 
threshold; 00078 T coeff; 00079 }; 00080 00081 // cross_entropy_energy ////////////////////////////////////////////////////// 00085 template<typename T, class Tstate = bbstate_idx<T> > 00086 class cross_entropy_energy : public ebm_2<Tstate> { 00087 public: 00088 cross_entropy_energy(const char *name = "cross_entropy_energy"); 00089 virtual ~cross_entropy_energy(); 00090 virtual void fprop(Tstate &in1, Tstate &in2, Tstate &energy); 00091 virtual void bprop(Tstate &in1, Tstate &in2, Tstate &energy); 00092 virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &energy); 00093 virtual void infer2_copy(Tstate &i1, Tstate &i2, Tstate &energy); 00095 virtual std::string describe(); 00096 protected: 00097 idx<T> tmp; 00098 }; 00099 00101 // scalerclass_energy 00102 00103 template<typename T, class Tstate = bbstate_idx<T> > 00104 class scalerclass_energy : public l2_energy<T,Tstate> { 00105 public: 00109 scalerclass_energy(bool apply_tanh = false, uint jsize = 1, 00110 uint jitter_selection = 0, float dist_coeff = 1.0, 00111 float scale_coeff = 1.0, bool predict_conf = false, 00112 bool predict_bconf = false, 00113 idx<T> *biases = NULL, idx<T> *coeffs = NULL, 00114 const char *name = "scalerclass_energy"); 00116 virtual ~scalerclass_energy(); 00117 virtual void fprop(Tstate &in, Tstate &scale, Tstate &energy); 00118 virtual void bprop(Tstate &in, Tstate &scale, Tstate &energy); 00119 virtual void bbprop(Tstate &in, Tstate &scale, Tstate &energy); 00121 virtual void infer2(Tstate &i1, Tstate &scale, infer_param &ip, 00122 Tstate *energy = NULL); 00124 virtual std::string describe(); 00125 // members 00126 public: 00127 idx<T> last_target_raw; 00128 Tstate last_target; 00129 protected: 00130 uint jsize; 00131 bool apply_tanh; 00132 uint jitter_selection; 00133 float dist_coeff; 00134 float scale_coeff; 00135 Tstate tmp; 00136 Tstate tmp2; 00137 Tstate last_class_target; 00138 Tstate last_jitt_target; 00139 Tstate last_conf_target; 00140 idx<T> best_target; 00141 
tanh_module<T,Tstate> mtanh; 00142 bool predict_conf; 00143 bool predict_bconf; 00144 idx<T> *biases; 00145 idx<T> *coeffs; 00146 }; 00147 00149 // scaler_energy 00150 00151 template<typename T, class Tstate = bbstate_idx<T> > 00152 class scaler_energy : public ebm_2<Tstate> { 00153 public: 00156 scaler_energy(const char *name = "scaler_energy"); 00158 virtual ~scaler_energy(); 00159 virtual void fprop(Tstate &in, Tstate &in2, Tstate &energy); 00160 virtual void bprop(Tstate &in, Tstate &in2, Tstate &energy); 00161 virtual void bbprop(Tstate &in, Tstate &in2, Tstate &energy); 00163 virtual void infer2(Tstate &i1, Tstate &in2, infer_param &ip, 00164 Tstate *energy = NULL); 00166 virtual std::string describe(); 00167 }; 00168 00169 } // namespace ebl { 00170 00171 #include "ebl_energy.hpp" 00172 00173 #endif /* EBL_ENERGY_H_ */