libeblearntools
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. 
IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef TRAIN_UTILS_HPP_ 00034 #define TRAIN_UTILS_HPP_ 00035 00036 namespace ebl { 00037 00039 // testing and saving 00040 00041 template <typename Tnet, typename Tdata, typename Tlabel> 00042 void test_and_save(uint iter, configuration &conf, string &conffname, 00043 parameter<Tnet> &theparam, 00044 supervised_trainer<Tnet,Tdata,Tlabel> &thetrainer, 00045 labeled_datasource<Tnet,Tdata,Tlabel> &train_ds, 00046 labeled_datasource<Tnet,Tdata,Tlabel> &test_ds, 00047 classifier_meter &trainmeter, 00048 classifier_meter &testmeter, 00049 infer_param &infp, gd_param &gdp, string &shortname, 00050 long iteration_seconds) { 00051 ostringstream wname, wfname; 00052 // save samples picking statistics 00053 if (conf.exists_true("save_pickings")) { 00054 string fname; fname << "pickings_" << iter; 00055 train_ds.save_pickings(fname.c_str()); 00056 } 00057 // save weights and confusion matrix for test set 00058 wname.str(""); 00059 if (conf.exists("job_name")) 00060 wname << conf.get_string("job_name"); 00061 wname << "_net" << setfill('0') << setw(5) << iter; 00062 wfname.str(""); wfname << wname.str() << ".mat"; 00063 if (conf.exists_false("save_weights")) 00064 cout << "Not saving weights (save_weights set to 0)." 
<< endl; 00065 else { 00066 cout << "saving net to " << wfname.str() << endl; 00067 theparam.save_x(wfname.str().c_str()); // save trained network 00068 cout << "saved=" << wfname.str() << endl; 00069 } 00070 // test 00071 test(iter, conf, conffname, theparam, thetrainer, train_ds, test_ds, 00072 trainmeter, testmeter, infp, gdp, shortname); 00073 // set retrain to next iteration with current saved weights 00074 ostringstream progress; 00075 progress << "retrain_iteration = " << iter + 1 << endl 00076 << "retrain_weights = " << wfname.str() << endl; 00077 if (iteration_seconds > 0) 00078 progress << "meta_timeout = " << iteration_seconds * 1.2 << endl; 00079 // save progress 00080 job::write_progress(iter + 1, conf.get_uint("iterations"), 00081 progress.str().c_str()); 00082 // save confusion 00083 if (conf.exists_true("save_confusion")) { 00084 string fname; fname << wname.str() << "_confusion_test.mat"; 00085 cout << "saving confusion to " << fname << endl; 00086 save_matrix(testmeter.get_confusion(), fname.c_str()); 00087 } 00088 } 00089 00090 template <typename Tnet, typename Tdata, typename Tlabel> 00091 void test(uint iter, configuration &conf, string &conffname, 00092 parameter<Tnet> &theparam, 00093 supervised_trainer<Tnet,Tdata,Tlabel> &thetrainer, 00094 labeled_datasource<Tnet,Tdata,Tlabel> &train_ds, 00095 labeled_datasource<Tnet,Tdata,Tlabel> &test_ds, 00096 classifier_meter &trainmeter, 00097 classifier_meter &testmeter, 00098 infer_param &infp, gd_param &gdp, string &shortname) { 00099 timer ttest; 00100 ostringstream wname, wfname; 00101 00102 // // some code to average several random solutions 00103 // cout << "Testing..."; 00104 // if (original_tests > 1) cout << " (" << original_tests << " times)"; 00105 // cout << endl; 00106 // ttest.restart(); 00107 // for (uint i = 0; i < original_tests; ++i) { 00108 // if (test_only && original_tests > 1) { 00109 // // we obviously wanna test several random solutions 00110 // cout << "Initializing weights 
from random." << endl; 00111 // thenet.forget(fgp); 00112 // } 00113 // if (!no_training_test) 00114 // thetrainer.test(train_ds, trainmeter, infp); 00115 // thetrainer.test(test_ds, testmeter, infp); 00116 // cout << "testing_time="; ttest.pretty_elapsed(); cout << endl; 00117 // } 00118 // if (test_only && original_tests > 1) { 00119 // // display averages over all tests 00120 // testmeter.display_average(test_ds.name(), test_ds.lblstr, 00121 // test_ds.is_test()); 00122 // trainmeter.display_average(train_ds.name(), train_ds.lblstr, 00123 // train_ds.is_test()); 00124 // } 00125 cout << "Testing on " << test_ds.size() << " samples..." << endl; 00126 uint maxtest = conf.exists("max_testing") ? conf.get_uint("max_testing") :0; 00127 ttest.start(); 00128 if (!conf.exists_true("no_training_test")) 00129 thetrainer.test(train_ds, trainmeter, infp, maxtest); // test 00130 if (!conf.exists_true("no_testing_test")) 00131 thetrainer.test(test_ds, testmeter, infp, maxtest); // test 00132 cout << "testing_time="; ttest.pretty_elapsed(); cout << endl; 00133 // detection test 00134 if (conf.exists_true("detection_test")) { 00135 uint dt_nthreads = 1; 00136 if (conf.exists("detection_test_nthreads")) 00137 dt_nthreads = conf.get_uint("detection_test_nthreads"); 00138 timer dtest; 00139 dtest.start(); 00140 // copy config file and augment it and detect it 00141 string cmd, params; 00142 if (conf.exists("detection_params")) { 00143 params = conf.get_string("detection_params"); 00144 params = string_replaceall(params, "\\n", "\n"); 00145 } 00146 cmd << "cp " << conffname << " tmp.conf && echo \"silent=1\n" 00147 << "nthreads=" << dt_nthreads << "\nevaluate=1\nweights_file=" 00148 << wfname.str() << "\n" << params 00149 << "\" >> tmp.conf && detect tmp.conf"; 00150 if (std::system(cmd.c_str())) 00151 cerr << "warning: failed to execute: " << cmd << endl; 00152 cout << "detection_test_time="; dtest.pretty_elapsed(); cout << endl; 00153 } 00154 #ifdef __GUI__ // display 00155 
static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui(shortname.c_str()); 00156 static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui2(shortname.c_str()); 00157 bool display = conf.exists_true("show_train"); // enable/disable display 00158 uint ninternals = conf.exists("show_train_ninternals") ? 00159 conf.get_uint("show_train_ninternals") : 1; // # examples' to display 00160 bool show_train_errors = conf.exists_true("show_train_errors"); 00161 bool show_train_correct = conf.exists_true("show_train_correct"); 00162 bool show_val_errors = conf.exists_true("show_val_errors"); 00163 bool show_val_correct = conf.exists_true("show_val_correct"); 00164 bool show_raw_outputs = conf.exists_true("show_raw_outputs"); 00165 bool show_all_jitter = conf.exists_true("show_all_jitter"); 00166 uint hsample = conf.exists("show_hsample") ?conf.get_uint("show_hsample"):5; 00167 uint wsample = conf.exists("show_wsample") ?conf.get_uint("show_wsample"):5; 00168 if (display) { 00169 cout << "Displaying training..." 
<< endl; 00170 if (show_train_errors) { 00171 stgui2.display_correctness(true, true, thetrainer, train_ds, infp, 00172 hsample, wsample, show_raw_outputs, 00173 show_all_jitter); 00174 stgui2.display_correctness(true, false, thetrainer, train_ds, infp, 00175 hsample, wsample, show_raw_outputs, 00176 show_all_jitter); 00177 } 00178 if (show_train_correct) { 00179 stgui2.display_correctness(false, true, thetrainer, train_ds, infp, 00180 hsample, wsample, show_raw_outputs, 00181 show_all_jitter); 00182 stgui2.display_correctness(false, false, thetrainer, train_ds, infp, 00183 hsample, wsample, show_raw_outputs, 00184 show_all_jitter); 00185 } 00186 if (show_val_errors) { 00187 stgui.display_correctness(true, true, thetrainer, test_ds, infp, 00188 hsample, wsample, show_raw_outputs, 00189 show_all_jitter); 00190 stgui.display_correctness(true, false, thetrainer, test_ds, infp, 00191 hsample, wsample, show_raw_outputs, 00192 show_all_jitter); 00193 } 00194 if (show_val_correct) { 00195 stgui.display_correctness(false, true, thetrainer, test_ds, infp, 00196 hsample, wsample, show_raw_outputs, 00197 show_all_jitter); 00198 stgui.display_correctness(false, false, thetrainer, test_ds, infp, 00199 hsample, wsample, show_raw_outputs, 00200 show_all_jitter); 00201 } 00202 stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals); 00203 } 00204 #endif 00205 } 00206 00207 template <typename Tnet, typename Tdata, typename Tlabel> 00208 labeled_datasource<Tnet,Tdata,Tlabel>* 00209 create_validation_set(configuration &conf, uint &noutputs, string &valdata) { 00210 bool classification = conf.exists_true("classification"); 00211 valdata = conf.get_string("val"); 00212 string vallabels, valclasses, valjitters, valscales; 00213 vallabels = conf.try_get_string("val_labels"); 00214 valclasses = conf.try_get_string("val_classes"); 00215 valjitters = conf.try_get_string("val_jitters"); 00216 valscales = conf.try_get_string("val_scales"); 00217 uint maxval = 0; 00218 if 
(conf.exists("val_size")) maxval = conf.get_uint("val_size"); 00219 labeled_datasource<Tnet,Tdata,Tlabel> *val_ds = NULL; 00220 if (classification) { // classification task 00221 class_datasource<Tnet,Tdata,Tlabel> *ds = 00222 new class_datasource<Tnet,Tdata,Tlabel>; 00223 ds->init(valdata.c_str(), vallabels.c_str(), valjitters.c_str(), 00224 valscales.c_str(), valclasses.c_str(), "val", maxval); 00225 if (conf.exists("limit_classes")) 00226 ds->limit_classes(conf.get_int("limit_classes"), 0, 00227 conf.exists_true("limit_classes_random")); 00228 noutputs = ds->get_nclasses(); 00229 val_ds = ds; 00230 } else { // regression task 00231 val_ds = new labeled_datasource<Tnet,Tdata,Tlabel>; 00232 val_ds->init(valdata.c_str(), vallabels.c_str(), valjitters.c_str(), 00233 valscales.c_str(), "val", maxval); 00234 idxdim d = val_ds->label_dims(); 00235 noutputs = d.nelements(); 00236 } 00237 val_ds->set_test(); // test is the test set, used for reporting 00238 val_ds->pretty(); 00239 if (conf.exists("data_bias")) 00240 val_ds->set_data_bias((Tnet)conf.get_double("data_bias")); 00241 if (conf.exists("data_coeff")) 00242 val_ds->set_data_coeff((Tnet)conf.get_double("data_coeff")); 00243 if (conf.exists("label_bias")) 00244 val_ds->set_label_bias((Tnet)conf.get_double("label_bias")); 00245 if (conf.exists("label_coeff")) 00246 val_ds->set_label_coeff((Tnet)conf.get_double("label_coeff")); 00247 if (conf.exists("epoch_show_modulo")) 00248 val_ds->set_epoch_show(conf.get_uint("epoch_show_modulo")); 00249 val_ds->keep_outputs(conf.exists_true("keep_outputs")); 00250 return val_ds; 00251 } 00252 00253 template <typename Tnet, typename Tdata, typename Tlabel> 00254 labeled_datasource<Tnet,Tdata,Tlabel>* 00255 create_training_set(configuration &conf, uint &noutputs, string &traindata) { 00256 bool classification = conf.exists_true("classification"); 00257 traindata = conf.get_string("train"); 00258 string trainlabels, trainclasses, trainjitters, trainscales; 00259 trainlabels = 
conf.try_get_string("train_labels"); 00260 trainclasses = conf.try_get_string("train_classes"); 00261 trainjitters = conf.try_get_string("train_jitters"); 00262 trainscales = conf.try_get_string("train_scales"); 00263 uint maxtrain = 0; 00264 if (conf.exists("train_size")) maxtrain = conf.get_uint("train_size"); 00265 labeled_datasource<Tnet,Tdata,Tlabel> *train_ds = NULL; 00266 if (classification) { // classification task 00267 class_datasource<Tnet,Tdata,Tlabel> *ds = 00268 new class_datasource<Tnet,Tdata,Tlabel>; 00269 ds->init(traindata.c_str(), trainlabels.c_str(), 00270 trainjitters.c_str(), trainscales.c_str(), 00271 trainclasses.c_str(), "train", maxtrain); 00272 if (conf.exists("balanced_training")) 00273 ds->set_balanced(conf.get_bool("balanced_training")); 00274 if (conf.exists("random_class_order")) 00275 ds->set_random_class_order(conf.get_bool("random_class_order")); 00276 if (conf.exists("limit_classes")) 00277 ds->limit_classes(conf.get_int("limit_classes"), 0, 00278 conf.exists_true("limit_classes_random")); 00279 noutputs = ds->get_nclasses(); 00280 train_ds = ds; 00281 } else { // regression task 00282 train_ds = new labeled_datasource<Tnet,Tdata,Tlabel>; 00283 train_ds->init(traindata.c_str(), trainlabels.c_str(), 00284 trainjitters.c_str(), trainscales.c_str(), 00285 "train", maxtrain); 00286 idxdim d = train_ds->label_dims(); 00287 noutputs = d.nelements(); 00288 } 00289 train_ds->ignore_correct(conf.exists_true("ignore_correct")); 00290 train_ds->set_weigh_samples(conf.exists_true("sample_probabilities"), 00291 conf.exists_true("hardest_focus"), 00292 conf.exists_true("per_class_norm"), 00293 conf.exists("min_sample_weight") ? 
00294 conf.get_double("min_sample_weight") : 0.0); 00295 train_ds->set_shuffle_passes(conf.exists_bool("shuffle_passes")); 00296 if (conf.exists("epoch_size")) 00297 train_ds->set_epoch_size(conf.get_int("epoch_size")); 00298 if (conf.exists("epoch_mode")) 00299 train_ds->set_epoch_mode(conf.get_uint("epoch_mode")); 00300 if (conf.exists("epoch_show_modulo")) 00301 train_ds->set_epoch_show(conf.get_uint("epoch_show_modulo")); 00302 train_ds->pretty(); 00303 if (conf.exists("data_bias")) 00304 train_ds->set_data_bias((Tnet)conf.get_double("data_bias")); 00305 if (conf.exists("data_coeff")) 00306 train_ds->set_data_coeff((Tnet)conf.get_double("data_coeff")); 00307 if (conf.exists("label_bias")) 00308 train_ds->set_label_bias((Tnet)conf.get_double("label_bias")); 00309 if (conf.exists("label_coeff")) 00310 train_ds->set_label_coeff((Tnet)conf.get_double("label_coeff")); 00311 train_ds->keep_outputs(conf.exists_true("keep_outputs")); 00312 return train_ds; 00313 } 00314 00315 } // end namespace ebl 00316 00317 #endif /* TRAIN_UTILS_HPP_ */ 00318