libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_trainer.h
00001 /***************************************************************************
00002  *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
00003  *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef EBL_TRAINER_H_
00034 #define EBL_TRAINER_H_
00035 
00036 #include "libidx.h"
00037 #include "ebl_arch.h"
00038 #include "ebl_machines.h"
00039 #include "ebl_logger.h"
00040 #include "datasource.h"
00041 #include "ebl_answer.h"
00042 
00043 namespace ebl {
00044 
00052   template<typename Tnet, typename Tdata, typename Tlabel>
00053     class supervised_trainer {
          //! Trainer class template that holds a reference to a
          //! trainable_module and to a parameter object, and declares
          //! per-sample and per-epoch train/test entry points that operate on a
          //! labeled_datasource.
          //! Only declarations appear here; the definitions are presumably in
          //! ebl_trainer.hpp, which this header includes after the namespace.
          //! NOTE(review): Tnet looks like the network's numeric type and
          //! Tdata/Tlabel the datasource's sample/label types -- confirm.
00054   public:
          //! Construct a trainer from the module 'm' to train and the
          //! parameter object 'p'. Both are taken by reference and the class
          //! stores reference members of the same types ('machine', 'param'),
          //! so presumably both must outlive the trainer -- confirm in the .hpp.
00056     supervised_trainer(trainable_module<Tnet,Tdata,Tlabel> &m,
00057                        parameter<Tnet,bbstate_idx<Tnet> > &p);
          //! Virtual destructor, so the class can be used as a polymorphic base.
00059     virtual ~supervised_trainer();
00060 
00061     // per-sample methods //////////////////////////////////////////////////////
00062 
          //! Test on a single sample of 'ds' using inference parameters 'infp'.
          //! 'label' and 'answers' are non-const references, so they are
          //! presumably outputs filled by the call.
          //! NOTE(review): the meaning of the bool return value is not visible
          //! here (possibly "answer was correct") -- confirm in the .hpp.
00065     bool test_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds,
00066                      bbstate_idx<Tnet> &label, bbstate_idx<Tnet> &answers,
00067                      infer_param &infp);
          //! Train on a single sample of 'ds' with gradient-descent parameters
          //! 'arg'.
          //! NOTE(review): returns a Tnet; presumably the sample's energy/loss
          //! -- confirm in the .hpp.
00071     Tnet train_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds,
00072                       gd_param &arg);
00073 
00074     // epoch methods ///////////////////////////////////////////////////////////
00075 
          //! Test over datasource 'ds', reporting into the meter 'log', using
          //! inference parameters 'infp'.
          //! 'max_test' defaults to 0; presumably 0 means "no limit on the
          //! number of tested samples" -- confirm in the .hpp.
00079     void test(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
00080               classifier_meter &log, infer_param &infp,
00081               uint max_test = 0);
          //! Train over datasource 'ds', reporting into the meter 'log'.
          //! 'args' carries the gradient-descent parameters, 'niter' an
          //! iteration count and 'infp' the inference parameters.
          //! 'hessian_period', 'nhessian' and 'mu' default to 0, 0 and .02.
          //! NOTE(review): these three look like the settings for periodic
          //! calls to compute_diaghessian() (0 presumably disabling them) --
          //! confirm in the .hpp.
00091     void train(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
00092                classifier_meter &log, gd_param &args, int niter,
00093                infer_param &infp, 
00094                intg hessian_period = 0, intg nhessian = 0, double mu = .02);
          //! Diagonal-Hessian computation over 'ds', parameterized by an
          //! iteration count 'niter' and a value 'mu'.
          //! NOTE(review): exact role of 'mu' is not visible here -- confirm.
00096     void compute_diaghessian(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
00097                              intg niter, double mu);
00098 
00099     // accessors ///////////////////////////////////////////////////////////////
00100     
          //! Set the iteration number to 'i' (presumably stored in the
          //! 'iteration' member).
00103     void set_iteration(int i);
          //! Pretty-printing hook taking the datasource 'ds'; presumably tied to
          //! the 'prettied' flag below -- confirm in the .hpp.
00105     void pretty(labeled_datasource<Tnet, Tdata, Tlabel> &ds);
          //! Set the progress file name to 's' (presumably stored in the
          //! 'progress_file' member).
00109     void set_progress_file(const std::string &s);
          //! Progress update hook; presumably uses 'progress_file' and
          //! 'progress_cnt' -- confirm in the .hpp.
00112     void update_progress();
00113 
00114     // friends /////////////////////////////////////////////////////////////////
00115     
          // The commented-out line is a two-parameter form of the friend
          // declaration; the active one grants friendship to every
          // specialization of the three-parameter supervised_trainer_gui.
00116     // template <class Tdata, class Tlabel> friend class supervised_trainer_gui;
00117     template <class T1, class T2, class T3> friend class supervised_trainer_gui;
00118 
00119     // internal methods ////////////////////////////////////////////////////////
00120   protected:
00121     
          //! Internal initialization taking the datasource 'ds', an optional
          //! meter 'log' (NULL by default) and a 'new_iteration' flag (false by
          //! default).
          //! NOTE(review): presumably called at the start of train()/test() --
          //! confirm in the .hpp.
00125     void init(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
00126               classifier_meter *log = NULL, bool new_iteration = false);
00127 
00128     // members /////////////////////////////////////////////////////////////////
00129   protected:
00130     trainable_module<Tnet,Tdata,Tlabel> &machine; // module reference (not owned)
00131     parameter<Tnet, bbstate_idx<Tnet> >   &param; // parameter reference (not owned)
00132     bbstate_idx<Tnet>    energy; // presumably the energy/loss output state
00133     bbstate_idx<Tnet>   *answers; // raw pointer; ownership not visible here -- confirm
00134     bbstate_idx<Tnet>   *label; // raw pointer; ownership not visible here -- confirm
00135     intg                 age; // presumably a count of trained samples -- confirm
00136     int                  iteration; // current iteration number (see set_iteration)
00137     void                *iteration_ptr; // untyped pointer; pointee type not visible here
00138     bool                 prettied; // presumably true once pretty() has run -- confirm
00139     std::string          progress_file; // progress file name (see set_progress_file)
00140     intg                 progress_cnt; // presumably a counter used by update_progress()
00141     bool                 test_running; // presumably true while testing is under way
00142   };
00143 
00144 } // namespace ebl
00145 
00146 #include "ebl_trainer.hpp"
00147 
00148 #endif /* EBL_TRAINER_H_ */