libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef EBL_ANSWER_H_ 00034 #define EBL_ANSWER_H_ 00035 00036 #include "ebl_defines.h" 00037 #include "libidx.h" 00038 #include "ebl_states.h" 00039 #include "ebl_arch.h" 00040 #include "datasource.h" 00041 #include "ebl_nonlinearity.h" 00042 #include "ebl_energy.h" 00043 #include "ebl_logger.h" 00044 00045 using namespace std; 00046 00047 namespace ebl { 00048 00061 enum t_confidence { confidence_sqrdist = 0, 00062 confidence_single = 1, 00063 confidence_max = 2 }; 00064 00066 // answer modules 00067 00072 template <typename T, typename Tds1 = T, typename Tds2 = T, 00073 class Tstate = bbstate_idx<T> > 00074 class answer_module : public module_1_1<T,Tstate> { 00075 public: 00076 answer_module(uint nfeatures, const char *name = "answer_module"); 00077 virtual ~answer_module(); 00078 00079 // single-state propagation //////////////////////////////////////////////// 00082 virtual void fprop(Tstate &in, Tstate &out); 00084 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00087 virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00090 virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00091 00092 // multi-state propagation ///////////////////////////////////////////////// 00094 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, 00095 mstate<Tstate> &out); 00098 virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds, 00099 mstate<Tstate> &out); 00102 virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds, 00103 mstate<Tstate> &out); 00104 00107 virtual bool correct(Tstate &answer, Tstate &label); 00109 virtual void update_log(classifier_meter &log, intg age, idx<T> &energy, 00110 idx<T> &answer, idx<T> &label, idx<T> &target, 00111 idx<T> &rawout); 00113 virtual void forget(forget_param_linear &fp); 00115 virtual std::string describe(); 00117 virtual uint get_nfeatures(); 00118 protected: 00119 uint nfeatures; 00120 }; 00121 00125 template <typename T, typename Tds1 = T, typename Tds2 = T, 00126 class Tstate = bbstate_idx<T> > 00127 class class_answer : public answer_module<T,Tds1,Tds2,Tstate> { 00128 public: 00136 class_answer(uint nclasses, double target_factor = 1.0, 00137 bool binary_target = false, 00138 t_confidence conf = confidence_max, 00139 bool apply_tanh = false, const char *name = "class_answer", 00140 int force_class = -1); 00141 virtual ~class_answer(); 00144 virtual void fprop(Tstate &in, Tstate &out); 00145 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00147 virtual bool correct(Tstate &answer, Tstate &label); 00149 virtual void update_log(classifier_meter &log, intg age, idx<T> &energy, 00150 idx<T> &answer, idx<T> &label, idx<T> &target, 00151 idx<T> &rawout); 00153 virtual std::string describe(); 00154 // members 00155 protected: 00156 idx<T> targets; 00157 idx<T> target; 00158 t_confidence conf_type; 00159 T conf_ratio; 00160 T conf_shift; 00161 bool binary_target; 00162 bool resize_output; 00163 bool apply_tanh; 00164 Tstate tmp; 00165 bbstate_idx<Tds2> last_label; 00166 tanh_module<T,Tstate> mtanh; 00167 T target_min; 00168 T target_max; 00169 int force_class; 00170 }; 00171 00176 template <typename T, typename Tds1 = T, typename Tds2 = T, 00177 class Tstate = bbstate_idx<T> > 00178 class scalerclass_answer : public class_answer<T,Tds1,Tds2,Tstate> { 00179 public: 00191 scalerclass_answer(uint nclasses, double target_factor = 1.0, 00192 bool binary_target = false, 00193 t_confidence conf = confidence_max, 00194 bool apply_tanh = false, 00195 uint jsize = 3, uint joffset = 0, 00196 float mgauss = 1.5, bool predict_conf = false, 00197 bool predict_bconf = false, 00198 idx<T> *biases = NULL, idx<T> *coeffs = NULL, 00199 const char *name = "scalerclass_answer"); 00200 virtual ~scalerclass_answer(); 00203 virtual void fprop(Tstate &in, Tstate &out); 00204 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00206 virtual void update_log(classifier_meter &log, intg age, idx<T> &energy, 00207 idx<T> &answer, idx<T> &label, idx<T> &target, 00208 idx<T> &rawout); 00210 virtual std::string describe(); 00211 // members 00212 protected: 00213 bbstate_idx<T> jitter; 00214 using class_answer<T,Tds1,Tds2,Tstate>::resize_output; 00215 using class_answer<T,Tds1,Tds2,Tstate>::targets; 00216 using class_answer<T,Tds1,Tds2,Tstate>::conf_type; 00217 using class_answer<T,Tds1,Tds2,Tstate>::apply_tanh; 00218 Tstate out_class; 00219 Tstate tmp1; 00220 Tstate tmp2; 00221 uint jsize; 00222 uint joffset; 00223 float scale_mgauss; 00224 bool predict_conf; 00225 bool predict_bconf; 00226 uint pconf_offset; 00227 std::vector<std::string> log_fields; 00228 idx<T> *biases; 00229 idx<T> *coeffs; 00230 }; 00231 00235 template <typename T, typename Tds1 = T, typename Tds2 = T, 00236 class Tstate = bbstate_idx<T> > 00237 class scaler_answer : public answer_module<T,Tds1,Tds2,Tstate> { 00238 public: 00244 scaler_answer(uint negative_id_, uint positive_id_, 00245 bool raw_confidence = false, float threshold = 0.0, 00246 bool spatial = false, 00247 const char *name = "scaler_answer"); 00248 virtual ~scaler_answer(); 00251 virtual void fprop(Tstate &in, Tstate &out); 00254 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00256 virtual std::string describe(); 00257 // members 00258 protected: 00259 uint negative_id; 00260 uint positive_id; 00261 bool raw_confidence; 00262 bbstate_idx<T> jitter; 00263 T threshold; 00264 bool spatial; 00265 uint jsize; 00266 }; 00267 00276 template <typename T, typename Tds1 = T, typename Tds2 = T, 00277 class Tstate = bbstate_idx<T> > 00278 class regression_answer : public answer_module<T,Tds1,Tds2,Tstate> { 00279 public: 00282 regression_answer(uint nfeatures, float64 threshold = 0.0, 00283 const char *name = "regression_answer"); 00284 virtual ~regression_answer(); 00286 virtual void fprop(Tstate &in, Tstate &out); 00289 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Tstate &out); 00292 virtual bool correct(Tstate &answer, Tstate &label); 00294 virtual void update_log(classifier_meter &log, intg age, idx<T> &energy, 00295 idx<T> &answer, idx<T> &label, idx<T> &target, 00296 idx<T> &rawout); 00298 virtual std::string describe(); 00299 // members 00300 protected: 00301 float64 threshold; 00302 }; 00303 00306 template <typename T, typename Tds1 = T, typename Tds2 = T, 00307 class Tstate = bbstate_idx<T> > 00308 class vote_answer : public class_answer<T,Tds1,Tds2,Tstate> { 00309 public: 00316 vote_answer(uint nclasses, double target_factor = 1.0, 00317 bool binary_target = false, 00318 t_confidence conf = confidence_max, 00319 bool apply_tanh = false, 00320 const char *name = "vote_answer"); 00321 virtual ~vote_answer(); 00324 virtual void fprop(Tstate &in, Tstate &out); 00325 }; 00326 00328 // trainable_module 00329 00334 template <typename T, typename Tds1 = T, typename Tds2 = T, 00335 class Tin1 = bbstate_idx<T>, class Tin2 = Tin1, class Ten = Tin1> 00336 class trainable_module { 00337 public: 00348 trainable_module(ebm_2<Tin1,Tin2,Ten> &energy, 00349 module_1_1<T,Tin1> &mod1, 00350 module_1_1<T,Tin2> *mod2 = NULL, 00351 answer_module<T,Tds1,Tds2,Tin1> *dsmod1 = NULL, 00352 answer_module<T,Tds1,Tds2,Tin2> *dsmod2 = NULL, 00353 const char *name = "trainable_module", 00354 const char *switcher = ""); 00355 virtual ~trainable_module(); 00356 00357 virtual void fprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy); 00358 virtual void bprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy); 00359 virtual void bbprop(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy); 00360 virtual int infer2(labeled_datasource<T,Tds1,Tds2> &ds, Ten &energy); 00362 virtual void forget(forget_param_linear &fp); 00366 virtual const Tin1& compute_answers(); 00370 virtual void compute_answers(Tin1 &ans); 00372 virtual bool correct(Tin1 &answer, Tin1 &label); 00374 virtual void update_log(classifier_meter &log, intg age, idx<T> &energy, 00375 idx<T> &answer, idx<T> &label, idx<T> &target, 00376 idx<T> &rawout); 00380 virtual idx<T> compute_targets(labeled_datasource<T,Tds1,Tds2> &ds); 00382 virtual const char* name(); 00384 virtual std::string describe(); 00385 00386 // friends ///////////////////////////////////////////////////////////////// 00387 friend class trainable_module_gui; 00388 template<class T1, class T2, class T3> friend class supervised_trainer; 00389 template<class T1, class T2, class T3> friend class supervised_trainer_gui; 00390 00391 // internal methods //////////////////////////////////////////////////////// 00392 protected: 00394 void update_scale(labeled_datasource<T,Tds1,Tds2> &ds); 00395 00396 // members ///////////////////////////////////////////////////////////////// 00397 protected: 00398 ebm_2<Tin1,Tin2,Ten> &energy_mod; 00399 module_1_1<T,Tin1,Tin1> &mod1; 00400 module_1_1<T,Tin1,Tin1> *mod2; 00401 answer_module<T,Tds1,Tds2,Tin1> *dsmod1; 00402 answer_module<T,Tds1,Tds2,Tin2> *dsmod2; 00403 // intermediate buffers 00404 mstate<Tin1> msin1; 00405 Tin1 in1; 00406 Tin1 out1; 00407 Tin2 in2; 00408 Tin2 out2; 00409 Tin1 answers; 00410 Tin1 targets; 00411 string mod_name; 00412 Ten tmp_energy; 00413 ms_module<T,Tin1> *ms_switch; 00414 }; 00415 00416 // utility functions ///////////////////////////////////////////////////////// 00417 00418 template <typename T> 00419 void print_targets(idx<T> &targets); 00420 00421 } // end namespace ebl 00422 00423 #include "ebl_answer.hpp" 00424 00425 #endif /* EBL_ANSWER_H_ */