libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_energy.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_ENERGY_H_
00034 #define EBL_ENERGY_H_
00035 
00036 #include "libidx.h"
00037 #include "ebl_arch.h"
00038 #include "ebl_nonlinearity.h"
00039 
00040 namespace ebl {
00041 
00042   // l2_energy /////////////////////////////////////////////////////////////////
00045   template<typename T, class Tstate = bbstate_idx<T> >
00046     class l2_energy : public ebm_2<Tstate> {
00047   public:
00048     l2_energy(const char *name = "l2_energy");
00049     virtual ~l2_energy();
00050     virtual void fprop(Tstate &in1, Tstate &in2, Tstate &energy);
00051     virtual void bprop(Tstate &in1, Tstate &in2, Tstate &energy);
00052     virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &energy);
00053     virtual void infer2_copy(Tstate &i1, Tstate &i2, Tstate &energy);
00055     virtual std::string describe();
00056   protected:
00057     idx<T> tmp; 
00058   };
00059 
00060   // l1_penalty ////////////////////////////////////////////////////////////////
00062   template<typename T, class Tstate = bbstate_idx<T> >
00063     class l1_penalty : public ebm_1<T,Tstate> {
00064   public:
00068     l1_penalty(T threshold = 0, T coeff = 1);
00069     virtual ~l1_penalty();
00070     virtual void fprop(Tstate &in, Tstate &energy);
00071     virtual void bprop(Tstate &in, Tstate &energy);
00072     virtual void bbprop(Tstate &in, Tstate &energy);
00074     virtual std::string describe();
00075     // member variables 
00076   protected:
00077     T threshold;
00078     T coeff;
00079   };
00080 
00081   // cross_entropy_energy //////////////////////////////////////////////////////
00085   template<typename T, class Tstate = bbstate_idx<T> >
00086     class cross_entropy_energy : public ebm_2<Tstate> {
00087   public:
00088     cross_entropy_energy(const char *name = "cross_entropy_energy");
00089     virtual ~cross_entropy_energy();
00090     virtual void fprop(Tstate &in1, Tstate &in2, Tstate &energy);
00091     virtual void bprop(Tstate &in1, Tstate &in2, Tstate &energy);
00092     virtual void bbprop(Tstate &in1, Tstate &in2, Tstate &energy);
00093     virtual void infer2_copy(Tstate &i1, Tstate &i2, Tstate &energy);
00095     virtual std::string describe();
00096   protected:
00097     idx<T> tmp; 
00098   };
00099 
00101   // scalerclass_energy
00102   
00103   template<typename T, class Tstate = bbstate_idx<T> >
00104     class scalerclass_energy : public l2_energy<T,Tstate> {
00105   public:
00109     scalerclass_energy(bool apply_tanh = false, uint jsize = 1,
00110                        uint jitter_selection = 0, float dist_coeff = 1.0,
00111                        float scale_coeff = 1.0, bool predict_conf = false,
00112                        bool predict_bconf = false,
00113                        idx<T> *biases = NULL, idx<T> *coeffs = NULL,
00114                        const char *name = "scalerclass_energy");
00116     virtual ~scalerclass_energy();
00117     virtual void fprop(Tstate &in, Tstate &scale, Tstate &energy);
00118     virtual void bprop(Tstate &in, Tstate &scale, Tstate &energy);
00119     virtual void bbprop(Tstate &in, Tstate &scale, Tstate &energy);
00121     virtual void infer2(Tstate &i1, Tstate &scale, infer_param &ip, 
00122                         Tstate *energy = NULL);
00124     virtual std::string describe();
00125     // members
00126   public:
00127     idx<T>              last_target_raw; 
00128     Tstate              last_target; 
00129   protected:
00130     uint jsize; 
00131     bool                apply_tanh; 
00132     uint                jitter_selection; 
00133     float               dist_coeff; 
00134     float               scale_coeff; 
00135     Tstate              tmp;        
00136     Tstate              tmp2;       
00137     Tstate              last_class_target; 
00138     Tstate              last_jitt_target; 
00139     Tstate              last_conf_target; 
00140     idx<T>              best_target; 
00141     tanh_module<T,Tstate> mtanh;    
00142     bool predict_conf; 
00143     bool predict_bconf; 
00144     idx<T>              *biases; 
00145     idx<T>              *coeffs; 
00146   };
00147 
00149   // scaler_energy
00150   
00151   template<typename T, class Tstate = bbstate_idx<T> >
00152     class scaler_energy : public ebm_2<Tstate> {
00153   public:
00156     scaler_energy(const char *name = "scaler_energy");
00158     virtual ~scaler_energy();
00159     virtual void fprop(Tstate &in, Tstate &in2, Tstate &energy);
00160     virtual void bprop(Tstate &in, Tstate &in2, Tstate &energy);
00161     virtual void bbprop(Tstate &in, Tstate &in2, Tstate &energy);
00163     virtual void infer2(Tstate &i1, Tstate &in2, infer_param &ip, 
00164                         Tstate *energy = NULL);
00166     virtual std::string describe();
00167   };
00168 
00169 } // namespace ebl {
00170 
00171 #include "ebl_energy.hpp"
00172 
00173 #endif /* EBL_ENERGY_H_ */