libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_answer.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_ANSWER_H_
00034 #define EBL_ANSWER_H_
00035 
00036 #include "ebl_defines.h"
00037 #include "libidx.h"
00038 #include "ebl_states.h"
00039 #include "ebl_arch.h"
00040 #include "datasource.h"
00041 #include "ebl_nonlinearity.h"
00042 #include "ebl_energy.h"
00043 #include "ebl_logger.h"
00044 
00045 using namespace std;
00046 
00047 namespace ebl {
00048 
00061   enum t_confidence { confidence_sqrdist = 0,
00062                       confidence_single  = 1,
00063                       confidence_max     = 2 };
00064   
00066   // answer modules
00067   
00072   template <typename T, typename Tds1 = T, typename Tds2 = T,
00073     class Tstate = bbstate_idx<T> >
00074     class answer_module : public module_1_1<T,Tstate> {
00075   public:
00076     answer_module(uint nfeatures, const char *name = "answer_module");
00077     virtual ~answer_module();
00078 
00079     // single-state propagation ////////////////////////////////////////////////
00082     virtual void fprop(Tstate &in, Tstate &out);
00084     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00087     virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00090     virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00091 
00092     // multi-state propagation /////////////////////////////////////////////////
00094     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds,
00095                        mstate<Tstate> &out);
00098     virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds,
00099                        mstate<Tstate> &out);
00102     virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds,
00103                         mstate<Tstate> &out);
00104 
00107     virtual bool correct(Tstate &answer, Tstate &label);
00109     virtual void update_log(classifier_meter &log, intg age, idx<T> &energy,
00110                             idx<T> &answer, idx<T> &label, idx<T> &target,
00111                             idx<T> &rawout);
00113     virtual void forget(forget_param_linear &fp);
00115     virtual std::string describe();
00117     virtual uint get_nfeatures();
00118   protected:
00119     uint nfeatures; 
00120   };
00121 
00125   template <typename T, typename Tds1 = T, typename Tds2 = T,
00126     class Tstate = bbstate_idx<T> >
00127     class class_answer : public answer_module<T,Tds1,Tds2,Tstate> {
00128   public:
00136     class_answer(uint nclasses, double target_factor = 1.0,
00137                  bool binary_target = false,
00138                  t_confidence conf = confidence_max,
00139                  bool apply_tanh = false, const char *name = "class_answer",
00140                  int force_class = -1);
00141     virtual ~class_answer();
00144     virtual void fprop(Tstate &in, Tstate &out);
00145     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00147     virtual bool correct(Tstate &answer, Tstate &label);
00149     virtual void update_log(classifier_meter &log, intg age, idx<T> &energy,
00150                             idx<T> &answer, idx<T> &label, idx<T> &target,
00151                             idx<T> &rawout);
00153     virtual std::string describe();
00154     // members
00155   protected:
00156     idx<T>              targets;    
00157     idx<T>              target;     
00158     t_confidence        conf_type;  
00159     T                   conf_ratio; 
00160     T                   conf_shift; 
00161     bool                binary_target;
00162     bool                resize_output; 
00163     bool                apply_tanh; 
00164     Tstate              tmp;        
00165     bbstate_idx<Tds2>    last_label; 
00166     tanh_module<T,Tstate> mtanh;    
00167     T                   target_min;
00168     T                   target_max;
00169     int                 force_class;
00170   };
00171 
00176   template <typename T, typename Tds1 = T, typename Tds2 = T,
00177     class Tstate = bbstate_idx<T> >
00178     class scalerclass_answer : public class_answer<T,Tds1,Tds2,Tstate> {
00179   public:
00191     scalerclass_answer(uint nclasses, double target_factor = 1.0,
00192                        bool binary_target = false,
00193                        t_confidence conf = confidence_max,
00194                        bool apply_tanh = false,
00195                        uint jsize = 3, uint joffset = 0,
00196                        float mgauss = 1.5, bool predict_conf = false,
00197                        bool predict_bconf = false,
00198                        idx<T> *biases = NULL, idx<T> *coeffs = NULL,
00199                        const char *name = "scalerclass_answer");
00200     virtual ~scalerclass_answer();
00203     virtual void fprop(Tstate &in, Tstate &out);
00204     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00206     virtual void update_log(classifier_meter &log, intg age, idx<T> &energy,
00207                             idx<T> &answer, idx<T> &label, idx<T> &target,
00208                             idx<T> &rawout);
00210     virtual std::string describe();
00211     // members
00212   protected:
00213     bbstate_idx<T> jitter; 
00214     using class_answer<T,Tds1,Tds2,Tstate>::resize_output;
00215     using class_answer<T,Tds1,Tds2,Tstate>::targets;
00216     using class_answer<T,Tds1,Tds2,Tstate>::conf_type;
00217     using class_answer<T,Tds1,Tds2,Tstate>::apply_tanh;
00218     Tstate              out_class;       
00219     Tstate              tmp1;       
00220     Tstate              tmp2;       
00221     uint jsize; 
00222     uint joffset; 
00223     float scale_mgauss; 
00224     bool predict_conf; 
00225     bool predict_bconf; 
00226     uint pconf_offset; 
00227     std::vector<std::string> log_fields; 
00228     idx<T>              *biases; 
00229     idx<T>              *coeffs; 
00230   };
00231 
00235   template <typename T, typename Tds1 = T, typename Tds2 = T,
00236     class Tstate = bbstate_idx<T> >
00237     class scaler_answer : public answer_module<T,Tds1,Tds2,Tstate> {
00238   public:    
00244     scaler_answer(uint negative_id_, uint positive_id_,
00245                   bool raw_confidence = false, float threshold = 0.0,
00246                   bool spatial = false,
00247                   const char *name = "scaler_answer");
00248     virtual ~scaler_answer();
00251     virtual void fprop(Tstate &in, Tstate &out);
00254     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00256     virtual std::string describe();
00257     // members
00258   protected:
00259     uint negative_id; 
00260     uint positive_id; 
00261     bool raw_confidence; 
00262     bbstate_idx<T> jitter; 
00263     T threshold; 
00264     bool spatial; 
00265     uint jsize; 
00266   };
00267 
00276   template <typename T, typename Tds1 = T, typename Tds2 = T,
00277     class Tstate = bbstate_idx<T> >
00278     class regression_answer : public answer_module<T,Tds1,Tds2,Tstate> {
00279   public:
00282     regression_answer(uint nfeatures, float64 threshold = 0.0,
00283                       const char *name = "regression_answer");
00284     virtual ~regression_answer();
00286     virtual void fprop(Tstate &in, Tstate &out);
00289     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out);
00292     virtual bool correct(Tstate &answer, Tstate &label);
00294     virtual void update_log(classifier_meter &log, intg age, idx<T> &energy,
00295                             idx<T> &answer, idx<T> &label, idx<T> &target,
00296                             idx<T> &rawout);
00298     virtual std::string describe();
00299     // members
00300   protected:
00301     float64          threshold; 
00302   };
00303   
00306   template <typename T, typename Tds1 = T, typename Tds2 = T,
00307     class Tstate = bbstate_idx<T> >
00308     class vote_answer : public class_answer<T,Tds1,Tds2,Tstate> {
00309   public:
00316     vote_answer(uint nclasses, double target_factor = 1.0,
00317                  bool binary_target = false,
00318                  t_confidence conf = confidence_max,
00319                  bool apply_tanh = false,
00320                  const char *name = "vote_answer");
00321     virtual ~vote_answer();
00324     virtual void fprop(Tstate &in, Tstate &out);
00325   };
00326 
00328   // trainable_module
00329 
00334   template <typename T, typename Tds1 = T, typename Tds2 = T,
00335     class Tin1 = bbstate_idx<T>, class Tin2 = Tin1, class Ten = Tin1>
00336     class trainable_module {
00337   public:
00348     trainable_module(ebm_2<Tin1,Tin2,Ten> &energy,
00349                      module_1_1<T,Tin1> &mod1,
00350                      module_1_1<T,Tin2> *mod2 = NULL,
00351                      answer_module<T,Tds1,Tds2,Tin1> *dsmod1 = NULL,
00352                      answer_module<T,Tds1,Tds2,Tin2> *dsmod2 = NULL,
00353                      const char *name = "trainable_module",
00354                      const char *switcher = "");
00355     virtual ~trainable_module();
00356 
00357     virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy);
00358     virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy);
00359     virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy);
00360     virtual int infer2(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy);
00362     virtual void forget(forget_param_linear &fp);
00366     virtual const Tin1& compute_answers();
00370     virtual void compute_answers(Tin1 &ans);
00372     virtual bool correct(Tin1 &answer, Tin1 &label);
00374     virtual void update_log(classifier_meter &log, intg age, idx<T> &energy,
00375                             idx<T> &answer, idx<T> &label, idx<T> &target,
00376                             idx<T> &rawout);
00380     virtual idx<T> compute_targets(labeled_datasource<T,Tds1,Tds2> &ds);
00382     virtual const char* name();
00384     virtual std::string describe();
00385 
00386     // friends /////////////////////////////////////////////////////////////////
00387     friend class trainable_module_gui;
00388     template<class T1, class T2, class T3> friend class supervised_trainer;
00389     template<class T1, class T2, class T3> friend class supervised_trainer_gui;
00390 
00391     // internal methods ////////////////////////////////////////////////////////
00392   protected:
00394     void update_scale(labeled_datasource<T,Tds1,Tds2> &ds);
00395   
00396     // members /////////////////////////////////////////////////////////////////
00397   protected:
00398     ebm_2<Tin1,Tin2,Ten>                &energy_mod;
00399     module_1_1<T,Tin1,Tin1>             &mod1;
00400     module_1_1<T,Tin1,Tin1>             *mod2;
00401     answer_module<T,Tds1,Tds2,Tin1>     *dsmod1;
00402     answer_module<T,Tds1,Tds2,Tin2>     *dsmod2;
00403     // intermediate buffers
00404     mstate<Tin1>                         msin1;
00405     Tin1                                 in1;
00406     Tin1                                 out1;  
00407     Tin2                                 in2;
00408     Tin2                                 out2;
00409     Tin1                                 answers;
00410     Tin1                                 targets;
00411     string                               mod_name;
00412     Ten                                  tmp_energy;
00413     ms_module<T,Tin1>                    *ms_switch;
00414   };
00415 
00416   // utility functions /////////////////////////////////////////////////////////
00417   
00418   template <typename T>
00419     void print_targets(idx<T> &targets);
00420     
00421 } // end namespace ebl
00422 
00423 #include "ebl_answer.hpp"
00424 
00425 #endif /* EBL_ANSWER_H_ */