libeblearn
/***************************************************************************
 *   Copyright (C) 2008 by Yann LeCun and Pierre Sermanet                 *
 *   yann@cs.nyu.edu, pierre.sermanet@gmail.com                           *
 *   All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Redistribution under a license not approved by the Open Source
 *       Initiative (http://www.opensource.org) must display the
 *       following acknowledgement in all advertising material:
 *        This product includes software developed at the Courant
 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
 *     * The names of the authors may not be used to endorse or promote
 *       products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ***************************************************************************/
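
// Implementation of the supervised_trainer template: a driver that pairs a
// trainable_module with its parameter vector and runs energy-based,
// gradient-descent supervised training over a labeled_datasource.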
#include <typeinfo>
#include "utils.h"

namespace ebl {

// supervised_trainer ////////////////////////////////////////////////////////

template <typename Tnet, typename Tdata, typename Tlabel>
supervised_trainer<Tnet, Tdata, Tlabel>::
supervised_trainer(trainable_module<Tnet,Tdata,Tlabel> &m,
                   parameter<Tnet, bbstate_idx<Tnet> > &p)
    : machine(m), param(p), energy(), answers(NULL), label(NULL), age(0),
      iteration(-1), iteration_ptr(NULL), prettied(false), progress_cnt(0),
      test_running(false) {
  energy.dx.set(1.0);  // d(E)/dE is always 1
  energy.ddx.set(0.0); // dd(E)/dE is always 0
  cout << "Training with: " << m.describe() << endl;
}

template <typename Tnet, typename Tdata, typename Tlabel>
supervised_trainer<Tnet, Tdata, Tlabel>::~supervised_trainer() {
}

// per-sample methods ////////////////////////////////////////////////////////

template <typename Tnet, typename Tdata, typename Tlabel>
bool supervised_trainer<Tnet, Tdata, Tlabel>::
test_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds,
            bbstate_idx<Tnet> &label, bbstate_idx<Tnet> &answers,
            infer_param &infp) {
  machine.compute_answers(answers);
  return machine.correct(answers, label);
}

template <typename Tnet, typename Tdata, typename Tlabel>
Tnet supervised_trainer<Tnet, Tdata, Tlabel>::
train_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds, gd_param &args) {
  TIMING2("until train_sample");
  machine.fprop(ds, energy); // forward: compute energy for current sample
  param.clear_dx();          // zero parameter gradients
  machine.bprop(ds, energy); // backward: accumulate dE/dw
  param.update(args);        // gradient step, following the gd_param policy
  TIMING2("entire train_sample");
  return energy.x.get();
}

// epoch methods /////////////////////////////////////////////////////////////

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::
test(labeled_datasource<Tnet, Tdata, Tlabel> &ds, classifier_meter &log,
     infer_param &infp, uint max_test) {
  init(ds, &log, true);
  idx<Tnet> target;
  uint ntest = ds.size();
  if (max_test > 0) { // limit the number of tests
    ntest = std::min(ntest, max_test);
    cout << "Limiting the number of tested samples to " << ntest << endl;
  }
  // loop over test samples
  uint i = 0;
  do {
    TIMING2("until beginning of sample test");
    ds.fprop_label_net(*label);
    machine.fprop(ds, energy);
    bool correct = test_sample(ds, *label, *answers, infp);
    target = machine.compute_targets(ds);
    machine.update_log(log, age, energy.x, answers->x, label->x, target,
                       machine.out1.x);
    // use the energy as a distance measure for sample picking probabilities
    ds.set_sample_energy((double) energy.x.get(), correct, machine.out1.x,
                         answers->x, target);
    ds.pretty_progress();
    update_progress(); // tell the outside world we're still running
    TIMING2("sample test (" << machine.msin1 << ")");
  } while (ds.next() && i++ < ntest);
  ds.normalize_all_probas();
  // TODO: simplify this
  vector<string*> lblstr;
  class_datasource<Tnet,Tdata,Tlabel> *cds =
      dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>*>(&ds);
  if (cds)
    lblstr = cds->get_label_strings();
  log.display(iteration, ds.name(), &lblstr, ds.is_test());
  cout << endl;
}
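
// Note: train_sample() above is one plain stochastic gradient step
// (fprop -> energy, bprop -> dE/dw, update per gd_param). A minimal driver
// sketch follows; the concrete types, variable names and hyper-parameter
// values are illustrative assumptions only, not part of this file:
//
//   supervised_trainer<float, float, int> trainer(machine, theparam);
//   trainer.compute_diaghessian(train_ds, 100, 0.02); // per-param rates
//   trainer.train(train_ds, trainlog, gdp, 1, infp, 0, 0, 0.02);
//   trainer.test(test_ds, testlog, infp, 0);          // 0 = test all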

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::
train(labeled_datasource<Tnet, Tdata, Tlabel> &ds, classifier_meter &log,
      gd_param &gdp, int niter, infer_param &infp,
      intg hessian_period, intg nhessian, double mu) {
  cout << "training on " << niter * ds.get_epoch_size();
  if (nhessian == 0) {
    cout << " samples and disabling 2nd order derivative calculation" << endl;
    param.set_epsilon(1.0);
  } else {
    cout << " samples and recomputing 2nd order "
         << "derivatives on " << nhessian << " samples after every "
         << hessian_period << " trained samples..." << endl;
  }
  timer t;
  init(ds, &log);
  bool selected = true, correct;
  idx<Tnet> target;
  for (int i = 0; i < niter; ++i) { // niter iterations
    t.start();
    ds.init_epoch();
    // train on the lowest size common to all classes (times # of classes)
    while (!ds.epoch_done()) {
      // recompute 2nd order derivatives periodically
      if (hessian_period > 0 && nhessian > 0 &&
          ds.get_epoch_count() % hessian_period == 0)
        compute_diaghessian(ds, nhessian, mu);
      // get label
      ds.fprop_label_net(*label);
      if (selected) // selected for training
        train_sample(ds, gdp);
      // test if the answer is correct
      correct = test_sample(ds, *label, *answers, infp);
      // compute targets before logging so the log sees the current sample
      target = machine.compute_targets(ds);
      machine.update_log(log, age, energy.x, answers->x, label->x, target,
                         machine.out1.x);
      // use energy and answer as distance for sample picking probabilities
      ds.set_sample_energy((double) energy.x.get(), correct, machine.out1.x,
                           answers->x, target);
      // log.update(age, output, label.get(), energy);
      age++;
      // select next sample
      selected = ds.next_train();
      ds.pretty_progress();
      // decrease the learning rate if annealing is enabled
      if (gdp.anneal_period > 0 && ((age - 1) % gdp.anneal_period) == 0) {
        gdp.eta = gdp.eta /
            (1 + ((age / gdp.anneal_period) * gdp.anneal_value));
        cout << "age: " << age << " updated eta=" << gdp.eta << endl;
      }
      update_progress(); // tell the outside world we're still running
    }
    ds.normalize_all_probas();
    cout << "epoch_count=" << ds.get_epoch_count() << endl;
    cout << "training_time="; t.pretty_elapsed();
    cout << endl;
    // report accuracy on trained samples
    if (test_running) {
      cout << "Training running test:" << endl;
      // TODO: simplify this
      class_datasource<Tnet,Tdata,Tlabel> *cds =
          dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>*>(&ds);
      log.display(iteration, ds.name(), cds ? cds->lblstr : NULL,
                  ds.is_test());
      cout << endl;
    }
  }
}
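
// Note: compute_diaghessian() below follows the stochastic diagonal
// Levenberg-Marquardt heuristic: bbprop() estimates the diagonal of the
// Hessian, the estimate is averaged over niter samples, and
// compute_epsilons(mu) then derives per-parameter learning-rate scales
// (roughly 1 / (mu + h_ii), in the spirit of LeCun et al., "Efficient
// BackProp"; the exact formula lives in parameter::compute_epsilons).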
template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet,Tdata,Tlabel>::
compute_diaghessian(labeled_datasource<Tnet,Tdata,Tlabel> &ds, intg niter,
                    double mu) {
  cout << "computing 2nd order derivatives on " << niter
       << " samples..." << endl;
  timer t;
  t.start();
  // init(ds, NULL);
  // ds.init_epoch();
  ds.save_state();              // save current ds state
  ds.set_count_pickings(false); // do not count these samples as training picks
  param.clear_ddeltax();
  // accumulate 2nd order derivatives over niter samples
  for (int i = 0; i < niter; ++i) {
    machine.fprop(ds, energy);
    param.clear_dx();
    machine.bprop(ds, energy);
    param.clear_ddx();
    machine.bbprop(ds, energy);
    param.update_ddeltax((1 / (double) niter), 1.0);
    while (!ds.next_train()) ; // skip all non-selected samples
    ds.pretty_progress();
    update_progress(); // tell the outside world we're still running
  }
  ds.restore_state(); // set ds state back
  param.compute_epsilons(mu);
  cout << "diaghessian inf: " << idx_min(param.epsilons);
  cout << " sup: " << idx_max(param.epsilons);
  cout << " diaghessian_minutes=" << t.elapsed_minutes() << endl;
}

// accessors /////////////////////////////////////////////////////////////////

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::set_iteration(int i) {
  cout << "Setting iteration id to " << i << endl;
  iteration = i;
}

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::
pretty(labeled_datasource<Tnet, Tdata, Tlabel> &ds) {
  if (!prettied) {
    // pretty-print sizes of input/output for each module the first time
    mfidxdim d(ds.sample_mfdims());
    cout << "machine sizes: " << d << machine.mod1.pretty(d) << endl
         << "trainable parameters: " << param.x << endl;
    prettied = true;
  }
}

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::
set_progress_file(const std::string &f) {
  progress_file = f;
  cout << "Setting progress file to \"" << f << "\"" << endl;
}

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::update_progress() {
  progress_cnt++;
#if __WINDOWS__ != 1
  // tell the outside world we are still running, every 20 samples
  if (!progress_file.empty() && progress_cnt % 20 == 0)
    touch_file(progress_file);
#endif
}
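
// Note: update_progress() above acts as a heartbeat: once every 20 samples
// it touches progress_file (if one was set with set_progress_file()) so an
// external watchdog can tell the trainer is still alive; the touch is
// compiled out on Windows builds.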

// internal methods //////////////////////////////////////////////////////////

template <typename Tnet, typename Tdata, typename Tlabel>
void supervised_trainer<Tnet, Tdata, Tlabel>::
init(labeled_datasource<Tnet, Tdata, Tlabel> &ds,
     classifier_meter *log, bool new_iteration) {
  pretty(ds); // pretty info
  // if not allocated, allocate answers. answers are allocated dynamically
  // based on ds dimensions because fstate_idx cannot change orders.
  idxdim d = ds.sample_dims();
  d.setdims(1);
  if (answers)
    delete answers;
  answers = new bbstate_idx<Tnet>(d);
  // same for the label buffer
  idxdim dl = ds.label_dims();
  if (label)
    delete label;
  label = new bbstate_idx<Tnet>(dl);
  // reinit ds
  ds.seek_begin();
  if (log) { // reinit logger
    class_datasource<Tnet,Tdata,Tlabel> *cds =
        dynamic_cast<class_datasource<Tnet,Tdata,Tlabel>* >(&ds);
    log->init(cds ? cds->get_nclasses() : 0);
  }
  // new iteration
  if (new_iteration) {
    if (!iteration_ptr)
      iteration_ptr = (void *) &ds;
    if (iteration_ptr == (void *) &ds)
      ++iteration;
  }
}

} // end namespace ebl
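
// Note on init(): the iteration counter only advances when init() is called
// with the datasource that first started counting (iteration_ptr), so
// interleaving test() on a train set and a test set advances the iteration
// id once per train/test pair rather than twice. Also, since this is a
// template implementation, it is presumably included from the corresponding
// header rather than compiled as its own translation unit (the usual C++
// idiom for templates; the build setup is not shown here).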