libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2008 by Yann LeCun and Pierre Sermanet * 00003 * yann@cs.nyu.edu, pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_TRAINER_H_ 00034 #define EBL_TRAINER_H_ 00035 00036 #include "libidx.h" 00037 #include "ebl_arch.h" 00038 #include "ebl_machines.h" 00039 #include "ebl_logger.h" 00040 #include "datasource.h" 00041 #include "ebl_answer.h" 00042 00043 namespace ebl { 00044 00052 template<typename Tnet, typename Tdata, typename Tlabel> 00053 class supervised_trainer { 00054 public: 00056 supervised_trainer(trainable_module<Tnet,Tdata,Tlabel> &m, 00057 parameter<Tnet,bbstate_idx<Tnet> > &p); 00059 virtual ~supervised_trainer(); 00060 00061 // per-sample methods ////////////////////////////////////////////////////// 00062 00065 bool test_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds, 00066 bbstate_idx<Tnet> &label, bbstate_idx<Tnet> &answers, 00067 infer_param &infp); 00071 Tnet train_sample(labeled_datasource<Tnet,Tdata,Tlabel> &ds, 00072 gd_param &arg); 00073 00074 // epoch methods /////////////////////////////////////////////////////////// 00075 00079 void test(labeled_datasource<Tnet, Tdata, Tlabel> &ds, 00080 classifier_meter &log, infer_param &infp, 00081 uint max_test = 0); 00091 void train(labeled_datasource<Tnet, Tdata, Tlabel> &ds, 00092 classifier_meter &log, gd_param &args, int niter, 00093 infer_param &infp, 00094 intg hessian_period = 0, intg nhessian = 0, double mu = .02); 00096 void compute_diaghessian(labeled_datasource<Tnet, Tdata, Tlabel> &ds, 00097 intg niter, double mu); 00098 00099 // accessors /////////////////////////////////////////////////////////////// 00100 00103 void set_iteration(int i); 00105 void pretty(labeled_datasource<Tnet, Tdata, Tlabel> &ds); 00109 void set_progress_file(const std::string &s); 00112 void update_progress(); 00113 00114 // friends ///////////////////////////////////////////////////////////////// 00115 00116 // template <class Tdata, class Tlabel> friend class supervised_trainer_gui; 00117 template <class T1, class T2, class T3> friend class supervised_trainer_gui; 00118 00119 // internal methods //////////////////////////////////////////////////////// 00120 protected: 00121 00125 void init(labeled_datasource<Tnet, Tdata, Tlabel> &ds, 00126 classifier_meter *log = NULL, bool new_iteration = false); 00127 00128 // members ///////////////////////////////////////////////////////////////// 00129 protected: 00130 trainable_module<Tnet,Tdata,Tlabel> &machine; 00131 parameter<Tnet, bbstate_idx<Tnet> > ¶m; 00132 bbstate_idx<Tnet> energy; 00133 bbstate_idx<Tnet> *answers; 00134 bbstate_idx<Tnet> *label; 00135 intg age; 00136 int iteration; 00137 void *iteration_ptr; 00138 bool prettied; 00139 std::string progress_file; 00140 intg progress_cnt; 00141 bool test_running; 00142 }; 00143 00144 } // namespace ebl { 00145 00146 #include "ebl_trainer.hpp" 00147 00148 #endif /* EBL_TRAINER_H_ */