libeblearntools
/home/rex/ebltrunk/tools/libeblearntools/include/train_utils.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source 
00014  *       Initiative (http://www.opensource.org) must display the 
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef TRAIN_UTILS_HPP_
00034 #define TRAIN_UTILS_HPP_
00035 
00036 namespace ebl {
00037 
00039   // testing and saving
00040 
00041   template <typename Tnet, typename Tdata, typename Tlabel>
00042   void test_and_save(uint iter, configuration &conf, string &conffname,
00043                      parameter<Tnet> &theparam,
00044                      supervised_trainer<Tnet,Tdata,Tlabel> &thetrainer,
00045                      labeled_datasource<Tnet,Tdata,Tlabel> &train_ds,
00046                      labeled_datasource<Tnet,Tdata,Tlabel> &test_ds,
00047                      classifier_meter &trainmeter,
00048                      classifier_meter &testmeter,
00049                      infer_param &infp, gd_param &gdp, string &shortname,
00050                      long iteration_seconds) {
00051     ostringstream wname, wfname;
00052     // save samples picking statistics
00053     if (conf.exists_true("save_pickings")) {
00054       string fname; fname << "pickings_" << iter;
00055       train_ds.save_pickings(fname.c_str());
00056     }
00057     // save weights and confusion matrix for test set
00058     wname.str("");
00059     if (conf.exists("job_name"))
00060       wname << conf.get_string("job_name");
00061     wname << "_net" << setfill('0') << setw(5) << iter;
00062     wfname.str(""); wfname << wname.str() << ".mat";
00063     if (conf.exists_false("save_weights"))
00064       cout << "Not saving weights (save_weights set to 0)." << endl;
00065     else {
00066       cout << "saving net to " << wfname.str() << endl;
00067       theparam.save_x(wfname.str().c_str()); // save trained network
00068       cout << "saved=" << wfname.str() << endl;
00069     }
00070     // test
00071     test(iter, conf, conffname, theparam, thetrainer, train_ds, test_ds,
00072          trainmeter, testmeter, infp, gdp, shortname);
00073     // set retrain to next iteration with current saved weights
00074     ostringstream progress;
00075     progress << "retrain_iteration = " << iter + 1 << endl
00076              << "retrain_weights = " << wfname.str() << endl;
00077     if (iteration_seconds > 0)
00078       progress << "meta_timeout = " << iteration_seconds * 1.2 << endl;
00079     // save progress
00080     job::write_progress(iter + 1, conf.get_uint("iterations"),
00081                         progress.str().c_str());
00082     // save confusion
00083     if (conf.exists_true("save_confusion")) {
00084       string fname; fname << wname.str() << "_confusion_test.mat";
00085       cout << "saving confusion to " << fname << endl;
00086       save_matrix(testmeter.get_confusion(), fname.c_str());
00087     }
00088   }
00089 
00090   template <typename Tnet, typename Tdata, typename Tlabel>
00091   void test(uint iter, configuration &conf, string &conffname,
00092             parameter<Tnet> &theparam,
00093             supervised_trainer<Tnet,Tdata,Tlabel> &thetrainer,
00094             labeled_datasource<Tnet,Tdata,Tlabel> &train_ds,
00095             labeled_datasource<Tnet,Tdata,Tlabel> &test_ds,
00096             classifier_meter &trainmeter,
00097             classifier_meter &testmeter,
00098             infer_param &infp, gd_param &gdp, string &shortname) {
00099     timer ttest;
00100     ostringstream wname, wfname;
00101 
00102     //   // some code to average several random solutions
00103     //     cout << "Testing...";
00104     //     if (original_tests > 1) cout << " (" << original_tests << " times)";
00105     //     cout << endl;
00106     //     ttest.restart();
00107     //     for (uint i = 0; i < original_tests; ++i) {
00108     //       if (test_only && original_tests > 1) {
00109     //  // we obviously wanna test several random solutions
00110     //  cout << "Initializing weights from random." << endl;
00111     //  thenet.forget(fgp);
00112     //       }
00113     //       if (!no_training_test)
00114     //  thetrainer.test(train_ds, trainmeter, infp);
00115     //       thetrainer.test(test_ds, testmeter, infp);
00116     //       cout << "testing_time="; ttest.pretty_elapsed(); cout << endl;
00117     //     }
00118     //     if (test_only && original_tests > 1) {
00119     //       // display averages over all tests
00120     //       testmeter.display_average(test_ds.name(), test_ds.lblstr, 
00121     //                          test_ds.is_test());
00122     //       trainmeter.display_average(train_ds.name(), train_ds.lblstr, 
00123     //                           train_ds.is_test());
00124     //     }
00125     cout << "Testing on " << test_ds.size() << " samples..." << endl;
00126     uint maxtest = conf.exists("max_testing") ? conf.get_uint("max_testing") :0;
00127     ttest.start();
00128     if (!conf.exists_true("no_training_test"))
00129       thetrainer.test(train_ds, trainmeter, infp, maxtest);     // test
00130     if (!conf.exists_true("no_testing_test"))
00131       thetrainer.test(test_ds, testmeter, infp, maxtest);       // test
00132     cout << "testing_time="; ttest.pretty_elapsed(); cout << endl;
00133     // detection test
00134     if (conf.exists_true("detection_test")) {
00135       uint dt_nthreads = 1;
00136       if (conf.exists("detection_test_nthreads"))
00137         dt_nthreads = conf.get_uint("detection_test_nthreads");
00138       timer dtest;
00139       dtest.start();
00140       // copy config file and augment it and detect it
00141       string cmd, params;
00142       if (conf.exists("detection_params")) {
00143         params = conf.get_string("detection_params");
00144         params = string_replaceall(params, "\\n", "\n");
00145       }
00146       cmd << "cp " << conffname << " tmp.conf && echo \"silent=1\n"
00147           << "nthreads=" << dt_nthreads << "\nevaluate=1\nweights_file=" 
00148           << wfname.str() << "\n" << params
00149           << "\" >> tmp.conf && detect tmp.conf";
00150       if (std::system(cmd.c_str()))
00151         cerr << "warning: failed to execute: " << cmd << endl;
00152       cout << "detection_test_time="; dtest.pretty_elapsed(); cout << endl;
00153     }
00154 #ifdef __GUI__ // display
00155     static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui(shortname.c_str());
00156     static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui2(shortname.c_str());
00157     bool display = conf.exists_true("show_train"); // enable/disable display
00158     uint ninternals = conf.exists("show_train_ninternals") ? 
00159       conf.get_uint("show_train_ninternals") : 1; // # examples' to display
00160     bool show_train_errors = conf.exists_true("show_train_errors");
00161     bool show_train_correct = conf.exists_true("show_train_correct");
00162     bool show_val_errors = conf.exists_true("show_val_errors");
00163     bool show_val_correct = conf.exists_true("show_val_correct");
00164     bool show_raw_outputs = conf.exists_true("show_raw_outputs");
00165     bool show_all_jitter = conf.exists_true("show_all_jitter");
00166     uint hsample = conf.exists("show_hsample") ?conf.get_uint("show_hsample"):5;
00167     uint wsample = conf.exists("show_wsample") ?conf.get_uint("show_wsample"):5;
00168     if (display) {
00169       cout << "Displaying training..." << endl;
00170       if (show_train_errors) {
00171         stgui2.display_correctness(true, true, thetrainer, train_ds, infp,
00172                                    hsample, wsample, show_raw_outputs,
00173                                    show_all_jitter);
00174         stgui2.display_correctness(true, false, thetrainer, train_ds, infp,
00175                                    hsample, wsample, show_raw_outputs,
00176                                    show_all_jitter);
00177       }
00178       if (show_train_correct) {
00179         stgui2.display_correctness(false, true, thetrainer, train_ds, infp,
00180                                    hsample, wsample, show_raw_outputs,
00181                                    show_all_jitter);
00182         stgui2.display_correctness(false, false, thetrainer, train_ds, infp,
00183                                    hsample, wsample, show_raw_outputs,
00184                                    show_all_jitter);
00185       }
00186       if (show_val_errors) {
00187         stgui.display_correctness(true, true, thetrainer, test_ds, infp,
00188                                   hsample, wsample, show_raw_outputs,
00189                                   show_all_jitter);
00190         stgui.display_correctness(true, false, thetrainer, test_ds, infp,
00191                                   hsample, wsample, show_raw_outputs,
00192                                   show_all_jitter);
00193       }
00194       if (show_val_correct) {
00195         stgui.display_correctness(false, true, thetrainer, test_ds, infp,
00196                                   hsample, wsample, show_raw_outputs,
00197                                   show_all_jitter);
00198         stgui.display_correctness(false, false, thetrainer, test_ds, infp,
00199                                   hsample, wsample, show_raw_outputs,
00200                                   show_all_jitter);
00201       }
00202       stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals);
00203     }
00204 #endif
00205   }
00206 
00207   template <typename Tnet, typename Tdata, typename Tlabel>
00208   labeled_datasource<Tnet,Tdata,Tlabel>* 
00209   create_validation_set(configuration &conf, uint &noutputs, string &valdata) {
00210     bool classification = conf.exists_true("classification");
00211     valdata = conf.get_string("val");
00212     string vallabels, valclasses, valjitters, valscales;
00213     vallabels = conf.try_get_string("val_labels");
00214     valclasses = conf.try_get_string("val_classes");
00215     valjitters = conf.try_get_string("val_jitters");
00216     valscales = conf.try_get_string("val_scales");
00217     uint maxval = 0;
00218     if (conf.exists("val_size")) maxval = conf.get_uint("val_size");
00219     labeled_datasource<Tnet,Tdata,Tlabel> *val_ds = NULL;
00220     if (classification) { // classification task
00221       class_datasource<Tnet,Tdata,Tlabel> *ds =
00222         new class_datasource<Tnet,Tdata,Tlabel>;
00223       ds->init(valdata.c_str(), vallabels.c_str(), valjitters.c_str(),
00224                valscales.c_str(), valclasses.c_str(), "val", maxval);
00225       if (conf.exists("limit_classes"))
00226         ds->limit_classes(conf.get_int("limit_classes"), 0, 
00227                           conf.exists_true("limit_classes_random"));
00228       noutputs = ds->get_nclasses();
00229       val_ds = ds;
00230     } else { // regression task
00231       val_ds = new labeled_datasource<Tnet,Tdata,Tlabel>;
00232       val_ds->init(valdata.c_str(), vallabels.c_str(), valjitters.c_str(),
00233                    valscales.c_str(), "val", maxval);
00234       idxdim d = val_ds->label_dims();
00235       noutputs = d.nelements();
00236     }
00237     val_ds->set_test(); // test is the test set, used for reporting
00238     val_ds->pretty();
00239     if (conf.exists("data_bias"))
00240       val_ds->set_data_bias((Tnet)conf.get_double("data_bias"));
00241     if (conf.exists("data_coeff"))
00242       val_ds->set_data_coeff((Tnet)conf.get_double("data_coeff"));
00243     if (conf.exists("label_bias"))
00244       val_ds->set_label_bias((Tnet)conf.get_double("label_bias"));
00245     if (conf.exists("label_coeff"))
00246       val_ds->set_label_coeff((Tnet)conf.get_double("label_coeff"));
00247     if (conf.exists("epoch_show_modulo"))
00248       val_ds->set_epoch_show(conf.get_uint("epoch_show_modulo"));
00249     val_ds->keep_outputs(conf.exists_true("keep_outputs"));
00250     return val_ds;    
00251   }
00252 
00253   template <typename Tnet, typename Tdata, typename Tlabel>
00254   labeled_datasource<Tnet,Tdata,Tlabel>* 
00255   create_training_set(configuration &conf, uint &noutputs, string &traindata) {
00256     bool classification = conf.exists_true("classification");
00257     traindata = conf.get_string("train");
00258     string trainlabels, trainclasses, trainjitters, trainscales;
00259     trainlabels = conf.try_get_string("train_labels");
00260     trainclasses = conf.try_get_string("train_classes");
00261     trainjitters = conf.try_get_string("train_jitters");
00262     trainscales = conf.try_get_string("train_scales");
00263     uint maxtrain = 0;
00264     if (conf.exists("train_size")) maxtrain = conf.get_uint("train_size");
00265     labeled_datasource<Tnet,Tdata,Tlabel> *train_ds = NULL;
00266     if (classification) { // classification task
00267       class_datasource<Tnet,Tdata,Tlabel> *ds =
00268         new class_datasource<Tnet,Tdata,Tlabel>;
00269       ds->init(traindata.c_str(), trainlabels.c_str(),
00270                trainjitters.c_str(), trainscales.c_str(), 
00271                trainclasses.c_str(), "train", maxtrain);
00272       if (conf.exists("balanced_training"))
00273         ds->set_balanced(conf.get_bool("balanced_training"));
00274       if (conf.exists("random_class_order"))
00275         ds->set_random_class_order(conf.get_bool("random_class_order"));
00276       if (conf.exists("limit_classes"))
00277         ds->limit_classes(conf.get_int("limit_classes"), 0, 
00278                           conf.exists_true("limit_classes_random"));
00279       noutputs = ds->get_nclasses();
00280       train_ds = ds;
00281     } else { // regression task
00282       train_ds = new labeled_datasource<Tnet,Tdata,Tlabel>;
00283       train_ds->init(traindata.c_str(), trainlabels.c_str(),
00284                      trainjitters.c_str(), trainscales.c_str(), 
00285                      "train", maxtrain);
00286       idxdim d = train_ds->label_dims();
00287       noutputs = d.nelements();
00288     }
00289     train_ds->ignore_correct(conf.exists_true("ignore_correct"));
00290     train_ds->set_weigh_samples(conf.exists_true("sample_probabilities"),
00291                                 conf.exists_true("hardest_focus"),
00292                                 conf.exists_true("per_class_norm"),
00293                                 conf.exists("min_sample_weight") ?
00294                                 conf.get_double("min_sample_weight") : 0.0);
00295     train_ds->set_shuffle_passes(conf.exists_bool("shuffle_passes"));
00296     if (conf.exists("epoch_size"))
00297       train_ds->set_epoch_size(conf.get_int("epoch_size"));
00298     if (conf.exists("epoch_mode"))
00299       train_ds->set_epoch_mode(conf.get_uint("epoch_mode"));
00300     if (conf.exists("epoch_show_modulo"))
00301       train_ds->set_epoch_show(conf.get_uint("epoch_show_modulo"));
00302     train_ds->pretty();
00303     if (conf.exists("data_bias"))
00304       train_ds->set_data_bias((Tnet)conf.get_double("data_bias"));
00305     if (conf.exists("data_coeff"))
00306       train_ds->set_data_coeff((Tnet)conf.get_double("data_coeff"));
00307     if (conf.exists("label_bias"))
00308       train_ds->set_label_bias((Tnet)conf.get_double("label_bias"));
00309     if (conf.exists("label_coeff"))
00310       train_ds->set_label_coeff((Tnet)conf.get_double("label_coeff"));
00311     train_ds->keep_outputs(conf.exists_true("keep_outputs"));
00312     return train_ds;
00313   }
00314   
00315 } // end namespace ebl
00316 
00317 #endif /* TRAIN_UTILS_HPP_ */
00318