libeblearn
|
/***************************************************************************
 *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
 *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
 *   All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Redistribution under a license not approved by the Open Source
 *       Initiative (http://www.opensource.org) must display the
 *       following acknowledgement in all advertising material:
 *        This product includes software developed at the Courant
 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
 *     * The names of the authors may not be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_COST_H_ 00034 #define EBL_COST_H_ 00035 00036 #include "libidx.h" 00037 #include "ebl_arch.h" 00038 00039 namespace ebl { 00040 00043 template<typename T1, typename T2, class Tstate1 = bbstate_idx<T1>, 00044 class Tstate2 = bbstate_idx<T2> > 00045 class cost_module : public ebm_2<Tstate1, Tstate2, Tstate1> { 00046 public: 00048 idx<T1> &targets; 00050 Tstate1 in2; 00052 idx<T1> energies; 00053 00056 cost_module(idx<T1> &targets_); 00057 virtual ~cost_module(); 00058 }; 00059 00065 template<typename T1, typename T2, class Tstate1 = bbstate_idx<T1>, 00066 class Tstate2 = bbstate_idx<T2> > 00067 class euclidean_module : public cost_module<T1, T2, Tstate1, Tstate2> { 00068 public: 00070 euclidean_module(idx<T1> &targets_); 00071 00073 virtual ~euclidean_module(); 00074 00079 virtual void fprop(Tstate1 &in1, Tstate2 &in2, Tstate1 &energy); 00080 00087 virtual void bprop(Tstate1 &in1, Tstate2 &in2, Tstate1 &energy); 00088 00092 virtual void bbprop(Tstate1 &in1, Tstate2 &in2, Tstate1 &energy); 00093 00095 virtual void forget(forget_param_linear &fp) {} 00096 00098 virtual double infer2(Tstate1 &i1, Tstate2 &i2, infer_param &ip, 00099 Tstate2 *label = NULL, Tstate1 *energy = NULL); 00100 00101 protected: 00102 using cost_module<T1, T2, Tstate1, Tstate2>::targets; 00103 using cost_module<T1, T2, Tstate1, Tstate2>::in2; 00104 using 
cost_module<T1, T2, Tstate1, Tstate2>::energies; 00105 }; 00106 00110 template<class T> 00111 class logadd_layer { //: public module_1_1<fstate_idx, fstate_idx> { // TODO 00112 public: 00113 idx<T> expdist; 00114 idx<T> sumexp; 00115 00116 logadd_layer(intg thick, intg si, intg sj); 00117 virtual ~logadd_layer() { 00118 } 00119 void fprop(fstate_idx<T> *in, fstate_idx<T> *out); 00120 void bprop(fstate_idx<T> *in, fstate_idx<T> *out); 00121 00124 void bbprop(fstate_idx<T> *in, fstate_idx<T> *out); 00125 }; 00126 00129 template<class T> class distance_l2 : public ebm_2<T> { 00130 private: 00131 idx<T> tmp; 00132 00133 public: 00134 distance_l2(); 00135 virtual ~distance_l2(); 00136 00137 virtual void fprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00138 fstate_idx<T> &energy); 00139 virtual void bprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00140 fstate_idx<T> &energy); 00141 virtual void bbprop(fstate_idx<T> &in1, fstate_idx<T> &in2, 00142 fstate_idx<T> &energy); 00143 virtual void forget(forget_param_linear &fp); 00144 virtual void infer2_copy(fstate_idx<T> &i1, fstate_idx<T> &i2, 00145 fstate_idx<T> &energy); 00146 }; 00147 00150 template<class T> class penalty_l1 : public ebm_1<T> { 00151 private: 00152 double threshold; 00153 00154 public: 00155 penalty_l1(T threshold_); 00156 virtual ~penalty_l1(); 00157 00158 virtual void fprop(fstate_idx<T> &in, fstate_idx<T> &energy); 00159 virtual void bprop(fstate_idx<T> &in, fstate_idx<T> &energy); 00160 virtual void bbprop(fstate_idx<T> &in, fstate_idx<T> &energy); 00161 virtual void forget(forget_param_linear &fp); 00162 }; 00163 00164 } // namespace ebl { 00165 00166 #include "ebl_cost.hpp" 00167 00168 #endif /* EBLCOST_H_ */