libeblearn
/home/rex/ebltrunk/core/libeblearn/include/ebl_trainer.hpp
/***************************************************************************
 *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet *
 *   yann@cs.nyu.edu, pierre.sermanet@gmail.com *
 *   All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Redistribution under a license not approved by the Open Source
 *       Initiative (http://www.opensource.org) must display the
 *       following acknowledgement in all advertising material:
 *        This product includes software developed at the Courant
 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
 *     * The names of the authors may not be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ***************************************************************************/

#include <typeinfo>
#include "utils.h"

namespace ebl {

  // supervised_trainer ////////////////////////////////////////////////////////

  template <typename Tnet, typename Tdata, typename Tlabel>
  supervised_trainer<Tnet, Tdata, Tlabel>::
  supervised_trainer(trainable_module<Tnet,Tdata,Tlabel> &m,
                     parameter<Tnet, bbstate_idx<Tnet> > &p)
    : machine(m), param(p), energy(), answers(NULL), label(NULL), age(0),
      iteration(-1), iteration_ptr(NULL), prettied(false), progress_cnt(0),
      test_running(false) {
    energy.dx.set(1.0); // d(E)/dE is always 1
    energy.ddx.set(0.0); // dd(E)/dE is always 0
    cout << "Training with: " << m.describe() << endl;
  }
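
  // A minimal usage sketch (illustrative only: the datasources, meters and
  // gd/infer parameters are assumed to be built elsewhere, and every variable
  // name below is hypothetical):
  //
  //   supervised_trainer<t_net, t_data, t_label> trainer(machine, theparam);
  //   for (int e = 0; e < nepochs; ++e) {
  //     // recompute 2nd order terms on 100 samples every 4000 samples
  //     trainer.train(train_ds, trainmeter, gdp, 1, infp, 4000, 100, 0.02);
  //     trainer.test(train_ds, trainmeter, infp, 0);
  //     trainer.test(test_ds, testmeter, infp, 0);
  //   }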

  template <typename Tnet, typename Tdata, typename Tlabel>
  supervised_trainer<Tnet, Tdata, Tlabel>::~supervised_trainer() {
  }

  // per-sample methods ////////////////////////////////////////////////////////

  template <typename Tnet, typename Tdata, typename Tlabel>
  bool supervised_trainer<Tnet, Tdata, Tlabel>::
  test_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds,
              bbstate_idx<Tnet> &label, bbstate_idx<Tnet> &answers,
              infer_param &infp) {
    machine.compute_answers(answers);
    return machine.correct(answers, label);
  }

  template <typename Tnet, typename Tdata, typename Tlabel>
  Tnet supervised_trainer<Tnet, Tdata, Tlabel>::
  train_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds, gd_param &args) {
    TIMING2("until train_sample");
    machine.fprop(ds, energy);
    param.clear_dx();
    machine.bprop(ds, energy);
    param.update(args);
    TIMING2("entire train_sample");
    return energy.x.get();
  }
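
  // For reference, each train_sample call performs one stochastic gradient
  // step: fprop computes the energy E of the current sample, bprop fills
  // param.dx with dE/dw, and param.update(args) applies, roughly,
  //
  //   w <- w - eta * epsilon_w * dE/dw
  //
  // where eta is args.eta and epsilon_w is the per-weight scale computed by
  // compute_diaghessian (1.0 when 2nd order terms are disabled). Additional
  // terms configured in gd_param (e.g. inertia or decay) are also applied
  // inside param.update.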

  // epoch methods /////////////////////////////////////////////////////////////

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::
  test(labeled_datasource<Tnet, Tdata, Tlabel> &ds, classifier_meter &log,
       infer_param &infp, uint max_test) {
    init(ds, &log, true);
    idx<Tnet> target;
    uint ntest = ds.size();
    if (max_test > 0) { // limit the number of tests
      ntest = std::min(ntest, max_test);
      cout << "Limiting the number of tested samples to " << ntest << endl;
    }
    // loop
    uint i = 0;
    do {
      TIMING2("until beginning of sample test");
      ds.fprop_label_net(*label);
      machine.fprop(ds, energy);
      bool correct = test_sample(ds, *label, *answers, infp);
      target = machine.compute_targets(ds);
      machine.update_log(log, age, energy.x, answers->x, label->x, target,
                         machine.out1.x);
      // use the energy as a distance measure for sample picking probabilities
      ds.set_sample_energy((double) energy.x.get(), correct, machine.out1.x,
                           answers->x, target);
      ds.pretty_progress();
      update_progress(); // tell the outside world we're still running
      TIMING2("sample test (" << machine.msin1 << ")");
    } while (ds.next() && ++i < ntest);
    ds.normalize_all_probas();
    // TODO: simplify this
    vector<string*> lblstr;
    class_datasource<Tnet,Tdata,Tlabel> *cds =
      dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>*>(&ds);
    if (cds)
      lblstr = cds->get_label_strings();
    log.display(iteration, ds.name(), &lblstr, ds.is_test());
    cout << endl;
  }

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::
  train(labeled_datasource<Tnet, Tdata, Tlabel> &ds, classifier_meter &log,
        gd_param &gdp, int niter, infer_param &infp,
        intg hessian_period, intg nhessian, double mu) {
    cout << "training on " << niter * ds.get_epoch_size();
    if (nhessian == 0) {
      cout << " samples and disabling 2nd order derivative calculation" << endl;
      param.set_epsilon(1.0);
    } else {
      cout << " samples and recomputing 2nd order "
           << "derivatives on " << nhessian << " samples after every "
           << hessian_period << " trained samples..." << endl;
    }
    timer t;
    init(ds, &log);
    bool selected = true, correct;
    idx<Tnet> target;
    for (int i = 0; i < niter; ++i) { // niter iterations
      t.start();
      ds.init_epoch();
      // train on an epoch of the smallest per-class sample count common to
      // all classes (times the number of classes)
      while (!ds.epoch_done()) {
        // recompute 2nd order derivatives
        if (hessian_period > 0 && nhessian > 0 &&
            ds.get_epoch_count() % hessian_period == 0)
          compute_diaghessian(ds, nhessian, mu);
        // get label
        ds.fprop_label_net(*label);
        if (selected) // selected for training
          train_sample(ds, gdp);
        // test if the answer is correct
        correct = test_sample(ds, *label, *answers, infp);
        target = machine.compute_targets(ds);
        machine.update_log(log, age, energy.x, answers->x, label->x, target,
                           machine.out1.x);
        // use energy and answer as a distance measure for sample picking
        // probabilities
        ds.set_sample_energy((double) energy.x.get(), correct, machine.out1.x,
                             answers->x, target);
        //      log.update(age, output, label.get(), energy);
        age++;
        // select next sample
        selected = ds.next_train();
        ds.pretty_progress();
        // decrease the learning rate if an annealing schedule is specified
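        // (each time this triggers, eta is divided by
        //  1 + (age / anneal_period) * anneal_value; e.g. with
        //  anneal_period = 1 and anneal_value = 0.001, eta decays
        //  approximately like 1/t)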
        if (gdp.anneal_period > 0 && ((age - 1) % gdp.anneal_period) == 0) {
          gdp.eta = gdp.eta /
            (1 + ((age / gdp.anneal_period) * gdp.anneal_value));
          cout << "age: " << age << " updated eta=" << gdp.eta << endl;
        }
        update_progress(); // tell the outside world we're still running
      }
      ds.normalize_all_probas();
      cout << "epoch_count=" << ds.get_epoch_count() << endl;
      cout << "training_time="; t.pretty_elapsed();
      cout << endl;
      // report accuracy on trained samples
      if (test_running) {
        cout << "Training running test:" << endl;
        // TODO: simplify this
        class_datasource<Tnet,Tdata,Tlabel> *cds =
          dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>*>(&ds);
        log.display(iteration, ds.name(), cds ? cds->lblstr : NULL,
                    ds.is_test());
        cout << endl;
      }
    }
  }

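  // Note: the method below implements a stochastic diagonal
  // Levenberg-Marquardt estimate (cf. LeCun et al., "Efficient BackProp"):
  // bbprop accumulates the diagonal 2nd derivatives of the energy, averaged
  // over niter samples, and param.compute_epsilons(mu) turns them into
  // per-weight learning rate scales, roughly epsilon_i = 1 / (mu + h_i),
  // where h_i is the averaged curvature and mu damps near-zero values.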
  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet,Tdata,Tlabel>::
  compute_diaghessian(labeled_datasource<Tnet,Tdata,Tlabel> &ds, intg niter,
                      double mu) {
    cout << "computing 2nd order derivatives on " << niter
         << " samples..." << endl;
    timer t;
    t.start();
//     init(ds, NULL);
//     ds.init_epoch();
    ds.save_state(); // save current ds state
    ds.set_count_pickings(false); // do not count these samples in training
    param.clear_ddeltax();
    // loop
    for (int i = 0; i < niter; ++i) {
      machine.fprop(ds, energy);
      param.clear_dx();
      machine.bprop(ds, energy);
      param.clear_ddx();
      machine.bbprop(ds, energy);
      param.update_ddeltax((1 / (double) niter), 1.0);
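      // update_ddeltax above accumulates a running average: each pass adds
      // (1/niter) * ddx into ddeltax, leaving the mean 2nd derivatives over
      // the niter samples when the loop completes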
      while (!ds.next_train()) ; // skip all non-selected samples
      ds.pretty_progress();
      update_progress(); // tell the outside world we're still running
    }
    ds.restore_state(); // set ds state back
    param.compute_epsilons(mu);
    cout << "diaghessian inf: " << idx_min(param.epsilons);
    cout << " sup: " << idx_max(param.epsilons);
    cout << " diaghessian_minutes=" << t.elapsed_minutes() << endl;
  }

  // accessors /////////////////////////////////////////////////////////////////

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::set_iteration(int i) {
    cout << "Setting iteration id to " << i << endl;
    iteration = i;
  }

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::
  pretty(labeled_datasource<Tnet, Tdata, Tlabel> &ds) {
    if (!prettied) {
      // pretty-print input/output sizes of each module, the first time only
      mfidxdim d(ds.sample_mfdims());
      cout << "machine sizes: " << d << machine.mod1.pretty(d) << endl
           << "trainable parameters: " << param.x << endl;
      prettied = true;
    }
  }

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::
  set_progress_file(const std::string &f) {
    progress_file = f;
    cout << "Setting progress file to \"" << f << "\"" << endl;
  }

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::update_progress() {
    progress_cnt++;
#if __WINDOWS__ != 1
    // tell the outside world we are still running, every 20 samples
    if (!progress_file.empty() && progress_cnt % 20 == 0)
      touch_file(progress_file);
#endif
  }

  // internal methods //////////////////////////////////////////////////////////

  template <typename Tnet, typename Tdata, typename Tlabel>
  void supervised_trainer<Tnet, Tdata, Tlabel>::
  init(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
       classifier_meter *log, bool new_iteration) {
    pretty(ds); // pretty info
    // (re)allocate answers dynamically based on ds dimensions, because an
    // fstate_idx cannot change orders once allocated
    idxdim d = ds.sample_dims();
    d.setdims(1);
    if (answers)
      delete answers;
    answers = new bbstate_idx<Tnet>(d);
    // same for the label buffer
    idxdim dl = ds.label_dims();
    if (label)
      delete label;
    label = new bbstate_idx<Tnet>(dl);
    // reinit ds
    ds.seek_begin();
    if (log) { // reinit logger
      class_datasource<Tnet,Tdata,Tlabel> *cds =
        dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>* >(&ds);
      log->init(cds ? cds->get_nclasses() : 0);
    }
    // start a new iteration, but only when looping over the same datasource
    // that started the iteration counter
    if (new_iteration) {
      if (!iteration_ptr)
        iteration_ptr = (void *) &ds;
      if (iteration_ptr == (void *) &ds)
        ++iteration;
    }
  }

} // end namespace ebl