libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef DATASOURCE_H_ 00034 #define DATASOURCE_H_ 00035 00036 #include "ebl_defines.h" 00037 #include "libidx.h" 00038 #include "ebl_states.h" 00039 #include "ebl_arch.h" 00040 #include <map> 00041 00042 using namespace std; 00043 00044 namespace ebl { 00045 00050 template<typename Tnet, typename Tdata> class datasource { 00051 public: 00052 typedef map<uint,idx<Tdata> > t_pick_map; 00053 00055 datasource(); 00063 datasource(midx<Tdata> &data, const char *name = NULL); 00068 datasource(idx<Tdata> &data, const char *name = NULL); 00074 datasource(const char *data_fname, const char *name = NULL); 00076 virtual ~datasource(); 00077 00078 // intialization /////////////////////////////////////////////////////////// 00079 00081 void init(midx<Tdata> &data, const char *name); 00083 void init(idx<Tdata> &data, const char *name); 00084 00085 // data access methods ///////////////////////////////////////////////////// 00086 00088 template <class Tstate> 00089 void fprop_data(mstate<Tstate> &s); 00091 virtual void fprop_data(fstate_idx<Tnet> &s); 00093 virtual void fprop_data(bbstate_idx<Tnet> &s); 00095 virtual void fprop(bbstate_idx<Tnet> &s); 00097 virtual idx<Tdata> get_sample(intg index); 00103 virtual idx<Tnet> get_raw_output(intg index = -1); 00104 00105 // iterating methods /////////////////////////////////////////////////////// 00106 00109 virtual void select_sample(intg index); 00112 virtual void shuffle(); 00119 virtual bool next(); 00133 virtual bool next_train(); 00134 00135 // accessors /////////////////////////////////////////////////////////////// 00136 00138 virtual void set_data_bias(Tnet bias); 00140 virtual void set_data_coeff(Tnet coeff); 00142 virtual unsigned int size(); 00145 virtual idxdim sample_dims(); 00148 virtual mfidxdim sample_mfdims(); 00149 00171 virtual void set_sample_energy(double e, bool correct, 00172 idx<Tnet> &raw_outputs, 00173 idx<Tnet> &answers, 00174 idx<Tnet> &target); 00178 virtual void keep_outputs(bool keep = true); 00180 virtual void normalize_all_probas(); 00182 virtual void normalize_probas(vector<intg> *cindinces = NULL); 00185 virtual void seek_begin(); 00188 virtual void seek_begin_train(); 00194 virtual void set_shuffle_passes(bool activate); 00195 00215 virtual void set_weigh_samples(bool activate, bool hardest_focus = false, 00216 bool perclass_norm = true, 00217 double min_proba = 0.0); 00221 virtual void set_test(); 00223 virtual bool is_test(); 00228 virtual intg get_epoch_size(); 00230 virtual intg get_epoch_count(); 00234 virtual void set_epoch_size(intg sz); 00239 virtual void set_epoch_mode(uint mode); 00242 virtual bool epoch_done(); 00245 virtual void init_epoch(); 00248 virtual void save_pickings(const char *name = NULL); 00250 virtual bool get_count_pickings(); 00252 virtual void set_count_pickings(bool count = true); 00254 virtual string& name(); 00256 virtual void set_epoch_show(uint modulo); 00258 virtual void ignore_correct(bool ignore = true); 00260 virtual bool mstate_samples(); 00261 00262 // state saving //////////////////////////////////////////////////////////// 00263 00266 virtual void save_state(); 00268 virtual void restore_state(); 00269 00270 // pretty methods ////////////////////////////////////////////////////////// 00271 00273 virtual void pretty(); 00276 virtual void pretty_progress(bool newline = true); 00277 00279 protected: 00280 00281 // intialization /////////////////////////////////////////////////////////// 00282 00284 void init2(const char *name); 00285 00286 // picking methods ///////////////////////////////////////////////////////// 00287 00290 virtual bool pick_current(); 00292 virtual map<uint,intg>& get_pickings(); 00293 00294 // members ///////////////////////////////////////////////////////////////// 00295 public: 00296 Tnet bias; 00297 Tnet coeff; 00298 // data 00299 idx<Tdata> data; // samples 00300 midx<Tdata> datas; // samples (multi-matrix). 00301 idx<double> probas; 00302 // predictions 00303 idx<double> energies; 00304 idx<ubyte> correct; 00305 idx<Tnet> raw_outputs; 00306 idx<Tnet> answers; 00307 idx<Tnet> targets; 00308 // picking ///////////////////////////////////////////////////////////////// 00309 idx<uint> pick_count; 00310 bool count_pickings; 00311 bool count_pickings_save; 00312 unsigned int height; 00313 unsigned int width; 00314 string _name; 00315 protected: 00316 vector<intg> counts; // # of samples / class 00317 map<uint,intg> picksmap; 00318 bool multimat; 00319 bool bkeep_outputs; 00320 00321 // (unbalanced) iterating indices 00322 intg it; 00323 intg it_test; 00324 intg it_train; 00325 idx<intg> indices; 00326 00327 // state saving 00328 bool state_saved; 00329 intg it_saved; 00330 intg it_test_saved; 00331 intg it_train_saved; 00332 idx<intg> indices_saved; 00333 intg epoch_cnt_saved; 00334 intg epoch_pick_cnt_saved; 00335 vector<intg> epoch_done_counters_saved; 00337 // features switches 00338 bool shuffle_passes; 00339 bool test_set; 00340 00341 // epoch variables 00342 vector<intg> epoch_done_counters; 00343 intg epoch_sz; 00344 intg epoch_cnt; 00345 intg epoch_pick_cnt; 00346 uint epoch_show; // show modulo 00347 intg epoch_show_printed; 00348 uint epoch_mode; 00349 timer epoch_timer; 00350 timer test_timer; 00351 uint not_picked; 00352 bool hardest_focus; 00353 bool _ignore_correct; 00354 00355 // sample picking with probabilities 00356 bool weigh_samples; 00357 bool perclass_norm; 00358 double sample_min_proba; 00359 idxdim sampledims; 00360 mfidxdim samplemfdims; 00361 }; 00362 00365 00370 template <typename Tnet, typename Tdata, typename Tlabel> 00371 class labeled_datasource : public datasource<Tnet, Tdata> { 00372 public: 00374 labeled_datasource(); 00379 labeled_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 00380 const char *name = NULL); 00383 labeled_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 00384 const char *name = NULL); 00389 labeled_datasource(const char *root_ds, const char *name = NULL); 00394 labeled_datasource(const char *root, const char *data_name, 00395 const char *labels_name, const char *jitters_name = NULL, 00396 const char *scales_name = NULL, const char *name = NULL); 00398 virtual ~labeled_datasource(); 00399 00400 // init methods //////////////////////////////////////////////////////////// 00401 00403 void init(midx<Tdata> &data, idx<Tlabel> &labels, const char *name); 00405 void init(idx<Tdata> &data, idx<Tlabel> &labels, const char *name); 00408 void init(const char *data_fname, const char *labels_fname, 00409 const char *jitters_fname = NULL, const char *name = NULL, 00410 const char *scales_fname = NULL, uint max_size = 0); 00412 void init_root(const char *root, const char *data_fname, 00413 const char *labels_fname, const char *jitters_fname = NULL, 00414 const char *scales_fname = NULL, const char *name = NULL); 00418 void init_root(const char *root_dsname, const char *name = NULL); 00419 00420 // data access ///////////////////////////////////////////////////////////// 00421 00425 virtual void fprop(bbstate_idx<Tnet> &out, bbstate_idx<Tlabel> &label); 00427 virtual void fprop_label(fstate_idx<Tlabel> &s); 00430 virtual void fprop_label_net(fstate_idx<Tnet> &s); 00433 virtual void fprop_label_net(bbstate_idx<Tnet> &s); 00435 virtual void fprop_jitter(bbstate_idx<Tnet> &s); 00437 virtual intg fprop_scale(); 00438 00439 // accessors /////////////////////////////////////////////////////////////// 00440 00442 virtual bool included_sample(intg index); 00444 virtual intg count_included_samples(); 00446 virtual void pretty(); 00448 virtual void pretty_scales(); 00450 virtual idxdim label_dims(); 00452 virtual void set_label_bias(Tnet bias); 00454 virtual void set_label_coeff(Tnet coeff); 00456 virtual bool has_scales(); 00457 00458 // friends ////////////////////////////////////////////////////////////// 00459 template <typename T1, typename T2, typename T3> 00460 friend class labeled_datasource_gui; 00461 template <typename T1, typename T2, typename T3> 00462 friend class supervised_trainer; 00463 template <typename T1, typename T2, typename T3> 00464 friend class supervised_trainer_gui; 00465 00466 // protected methods /////////////////////////////////////////////////////// 00467 protected: 00468 00470 void init_labels(idx<Tlabel> &labels, const char *name); 00471 00472 // members ///////////////////////////////////////////////////////////////// 00473 protected: 00474 using datasource<Tnet,Tdata>::_name; 00475 // data 00476 Tnet label_bias; 00477 Tnet label_coeff; 00478 using datasource<Tnet,Tdata>::data; 00479 idx<Tlabel> labels; // labels 00480 idx<intg> scales; // scales 00481 midx<float> jitters; 00482 bool scales_loaded; 00483 // iterating 00484 using datasource<Tnet,Tdata>::it; 00485 using datasource<Tnet,Tdata>::epoch_sz; 00486 // dimensions 00487 using datasource<Tnet,Tdata>::sampledims; 00488 idxdim jitters_maxdim; 00489 idxdim labeldims; 00490 using datasource<Tnet,Tdata>::multimat; 00491 }; 00492 00495 00499 template <typename Tnet, typename Tdata, typename Tlabel> 00500 class class_datasource : public labeled_datasource<Tnet, Tdata, Tlabel> { 00501 public: 00503 class_datasource(); 00504 00510 class_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 00511 vector<string*> *lblstr = NULL, const char *name = NULL); 00515 class_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 00516 vector<string*> *lblstr = NULL, const char *name = NULL); 00522 class_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 00523 idx<ubyte> &classes, const char *name = NULL); 00527 class_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 00528 idx<ubyte> &classes, const char *name = NULL); 00532 class_datasource(const char *data_name, 00533 const char *labels_name, const char *jitters_name = NULL, 00534 const char *scales_name = NULL, 00535 const char *classes_name = NULL, const char *name = NULL); 00537 class_datasource(const class_datasource<Tnet, Tdata, Tlabel> &ds); 00539 virtual ~class_datasource(); 00540 00541 // init methods //////////////////////////////////////////////////////////// 00542 00544 void defaults(); 00546 virtual void init_strings(idx<ubyte> &classes); 00549 void init_local(vector<string*> *lblstr); 00551 void init(midx<Tdata> &data, idx<Tlabel> &labels, 00552 vector<string*> *lblstr, const char *name); 00554 void init(idx<Tdata> &data, idx<Tlabel> &labels, 00555 vector<string*> *lblstr, const char *name); 00558 void init(const char *data_fname, const char *labels_fname, 00559 const char *jitters_fname = NULL, const char *scales_fname = NULL, 00560 const char *classes_fname = NULL, const char *name = NULL, 00561 uint max_size = 0); 00563 void init_root(const char *root, const char *data_fname, 00564 const char *labels_fname, const char *jitters_fname = NULL, 00565 const char *scales_fname = NULL, 00566 const char *classes_fname = NULL, const char *name = NULL); 00570 void init_root(const char *root_dsname, const char *name = NULL); 00575 virtual void init_class_labels(); 00576 00577 // data access ///////////////////////////////////////////////////////////// 00578 00582 virtual Tlabel get_label(); 00583 00584 // iterating /////////////////////////////////////////////////////////////// 00585 00587 virtual bool included_sample(intg index); 00589 virtual intg count_included_samples(); 00592 virtual void seek_begin(); 00595 virtual void seek_begin_train(); 00602 virtual bool next(); 00617 virtual bool next_train(); 00621 virtual void next_balanced_class(); 00625 virtual void reset_class_order(); 00628 virtual void set_random_class_order(bool ran); 00631 virtual void limit_classes(intg n, intg offset = 0, bool random = false); 00637 virtual void set_balanced(bool bal = true); 00640 virtual bool epoch_done(); 00643 virtual void init_epoch(); 00646 virtual void normalize_all_probas(); 00649 virtual void normalize_probas(int classid = -1); 00650 00651 // accessors /////////////////////////////////////////////////////////////// 00652 00654 virtual intg get_nclasses(); 00656 virtual int get_class_id(const char *name); 00658 virtual string& get_class_name(int id); 00660 virtual vector<string*>& get_label_strings(); 00668 virtual intg get_lowest_common_size(); 00669 00670 // picking methods ///////////////////////////////////////////////////////// 00671 00674 virtual void save_pickings(const char *name = NULL); 00676 template <typename T> 00677 void write_classed_pickings(idx<T> &m, idx<ubyte> &correct, 00678 string &name_, const char *name2_ = NULL, 00679 bool plot_correct = true, 00680 const char *ylabel = ""); 00681 00682 // state saving //////////////////////////////////////////////////////////// 00683 00686 virtual void save_state(); 00688 virtual void restore_state(); 00689 00690 // pretty methods ////////////////////////////////////////////////////////// 00691 00693 virtual void pretty(); 00695 virtual void pretty_scales(); 00698 virtual void pretty_progress(bool newline = true); 00699 00700 // friends ////////////////////////////////////////////////////////////// 00701 template <typename T1, typename T2, typename T3> 00702 friend class class_datasource_gui; 00703 template <typename T1, typename T2, typename T3> 00704 friend class supervised_trainer; 00705 00706 protected: 00707 00710 virtual bool pick_current(); 00711 00712 // members ///////////////////////////////////////////////////////////////// 00713 protected: 00714 using datasource<Tnet,Tdata>::_name; 00715 // classes 00716 intg nclasses; 00717 vector<string*> *lblstr; 00718 vector<string*> clblstr; 00719 bool bexclusion; 00720 vector<bool> excluded; 00721 intg included; 00722 // data 00723 using datasource<Tnet,Tdata>::data; 00724 using labeled_datasource<Tnet,Tdata,Tlabel>::labels; 00725 idx<Tlabel> olabels; 00726 using datasource<Tnet,Tdata>::correct; 00727 using datasource<Tnet,Tdata>::energies; 00728 using datasource<Tnet,Tdata>::probas; 00729 // iterating 00730 using datasource<Tnet,Tdata>::it; 00731 using datasource<Tnet,Tdata>::epoch_mode; 00732 using datasource<Tnet,Tdata>::epoch_show; 00733 using datasource<Tnet,Tdata>::epoch_sz; 00734 using datasource<Tnet,Tdata>::epoch_timer; 00735 using datasource<Tnet,Tdata>::epoch_show_printed; 00736 // class-balanced iterating indices 00737 using datasource<Tnet,Tdata>::epoch_done_counters; 00738 bool balance; 00739 vector<vector<intg> > bal_indices; 00740 vector<uint> bal_it; 00741 vector<uint> class_order; 00742 bool random_class_order; 00743 uint class_it; 00744 uint class_it_it; 00745 // sample picking with probabilities 00746 bool perclass_norm; 00747 using datasource<Tnet,Tdata>::epoch_pick_cnt; 00748 using datasource<Tnet,Tdata>::epoch_cnt; 00749 using datasource<Tnet,Tdata>::count_pickings; 00750 using datasource<Tnet,Tdata>::pick_count; 00751 using datasource<Tnet,Tdata>::counts; 00752 using datasource<Tnet,Tdata>::weigh_samples; 00753 // state saving 00754 vector<vector<intg> > bal_indices_saved; 00755 vector<uint> bal_it_saved; 00756 uint class_it_saved; 00757 uint class_it_it_saved; 00758 using datasource<Tnet,Tdata>::state_saved; 00759 using datasource<Tnet,Tdata>::count_pickings_save; 00760 using datasource<Tnet,Tdata>::it_saved; 00761 using datasource<Tnet,Tdata>::it_test; 00762 using datasource<Tnet,Tdata>::it_test_saved; 00763 using datasource<Tnet,Tdata>::it_train; 00764 using datasource<Tnet,Tdata>::it_train_saved; 00765 // misc 00766 using datasource<Tnet,Tdata>::sampledims; 00767 using datasource<Tnet,Tdata>::test_set; 00768 using datasource<Tnet,Tdata>::shuffle_passes; 00769 using datasource<Tnet,Tdata>::not_picked; 00770 }; 00771 00774 00777 template <typename Tlabel> class class_node { 00778 public: 00781 class_node(Tlabel id, string &name); 00783 virtual ~class_node(); 00786 virtual bool empty(); 00788 virtual bool internally_empty(); 00793 virtual intg next(); 00796 virtual void add_child(class_node *child); 00798 virtual void add_sample(intg index); 00800 inline virtual intg nsamples(); 00802 inline virtual Tlabel label(); 00805 inline virtual Tlabel label(uint depth); 00807 inline virtual uint depth(); 00810 virtual uint set_depth(uint d); 00812 virtual string& name(); 00814 virtual class_node<Tlabel>* get_parent(); 00816 virtual bool is_parent(Tlabel lab); 00817 00818 protected: 00823 virtual void set_non_empty(); 00825 virtual void set_parent(class_node *parent); 00826 00827 // members ///////////////////////////////////////////////////////////////// 00828 protected: 00829 Tlabel _label; 00830 string &_name; 00831 // hierarchy variables ///////////////////////////////////////////////////// 00832 class_node *parent; 00833 vector<class_node<Tlabel>*> children; 00834 typename vector<class_node<Tlabel>*>::iterator it_children; 00835 bool bempty; 00836 bool iempty; 00837 uint _depth; 00838 // samples corresponding to this node ////////////////////////////////////// 00839 vector<intg> samples; 00840 vector<intg>::iterator it_samples; 00841 00842 // friends ///////////////////////////////////////////////////////////////// 00843 template <typename Tnet1, typename Tdata1, typename Tlabel1> 00844 friend class hierarchy_datasource; 00845 }; 00846 00849 00852 template <typename Tnet, typename Tdata, typename Tlabel> 00853 class hierarchy_datasource : public class_datasource<Tnet, Tdata, Tlabel> { 00854 public: 00856 hierarchy_datasource(); 00864 hierarchy_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 00865 idx<Tlabel> *parents = NULL, 00866 vector<string*> *lblstr = NULL, 00867 const char *name = NULL); 00873 hierarchy_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 00874 idx<Tlabel> *parents = NULL, 00875 vector<string*> *lblstr = NULL, 00876 const char *name = NULL); 00882 hierarchy_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 00883 idx<Tlabel> *parents = NULL, 00884 idx<ubyte> *classes = NULL, const char *name = NULL); 00891 hierarchy_datasource(const char *data_name, const char *labels_name, 00892 const char *parents_name = NULL, 00893 const char *jitters_name = NULL, 00894 const char *scales_name = NULL, 00895 const char *classes_name = NULL, 00896 const char *name = NULL, 00897 uint max_size = 0); 00899 virtual ~hierarchy_datasource(); 00900 00901 // init methods //////////////////////////////////////////////////////////// 00902 00904 void init_parents(idx<Tlabel> *parents = NULL); 00909 virtual void init_class_labels(); 00910 00911 // data access ///////////////////////////////////////////////////////////// 00912 00916 virtual Tlabel get_parent(); 00918 virtual bool is_parent_of(Tlabel lab1, Tlabel lab2); 00920 virtual vector<class_node<Tlabel>*>& get_nodes(); 00923 virtual vector<class_node<Tlabel>*>& get_nodes_by_depth(); 00924 00926 virtual void fprop_label(fstate_idx<Tlabel> &s); 00929 virtual void fprop_label_net(fstate_idx<Tnet> &s); 00932 virtual void fprop_label_net(bbstate_idx<Tnet> &s); 00936 virtual Tlabel get_label(); 00942 virtual Tlabel get_label(uint depth, intg index = -1); 00945 idx<Tlabel>& get_depth_labels(); 00947 uint get_nbrothers(class_node<Tlabel> &n); 00948 00949 // iterating /////////////////////////////////////////////////////////////// 00950 00952 void set_depth_balanced(bool bal); 00958 virtual void set_current_depth(uint depth); 00960 virtual uint get_current_depth(); 00962 virtual void incr_current_depth(); 00980 virtual bool next_train(); 00981 00982 /* //! If 'bal' is true, make the next_train() method call sequentially one */ 00983 /* //! sample of each class instead of following the dataset's distribution. */ 00984 /* //! This is important to use when the dataset is unbalanced. */ 00985 /* //! This is set to true by default. */ 00986 /* //! Balance is used only by next_train(), not by next(). */ 00987 /* virtual void set_balanced(bool bal = true); */ 00988 /* //! Return true if current epoch is finished. Call init_epoch() to */ 00989 /* //! restart a new epoch. */ 00990 /* virtual bool epoch_done(); */ 00991 /* //! Restarts a new epoch, i.e. resets counters but do not reset iterators */ 00992 /* //! positions. */ 00993 /* virtual void init_epoch(); */ 00994 /* //! Normalize picking probabilities by maximum probability for all classes */ 00995 /* //! if perclass_norm is true, or globally otherwise. */ 00996 /* virtual void normalize_all_probas(); */ 00997 /* //! Normalize picking probabilities by maximum probability of classid if */ 00998 /* //! perclass_norm is true, or globally otherwise. */ 00999 /* virtual void normalize_probas(int classid = -1); */ 01000 01001 /* // accessors /////////////////////////////////////////////////////////////// */ 01002 01003 /* //! Return the number of classes. */ 01004 /* virtual intg get_nclasses(); */ 01005 /* //! Return the label id corresponding to name, or -1 if not found. */ 01006 /* virtual int get_class_id(const char *name); */ 01007 /* //! Return the label string for index id. */ 01008 /* virtual string& get_class_name(int id); */ 01009 /* //! Returns a reference to a vector of each label string. */ 01010 /* virtual vector<string*>& get_label_strings(); */ 01011 /* //! Return the lowest (non-zero) size per class, multiplied by the number */ 01012 /* //! of classes. */ 01013 /* //! e.g. if a dataset has 10 classes with 100 examples and 5 classes with */ 01014 /* //! 50 examples, it will return 50 * (10 + 5) = 750, whereas size() */ 01015 /* //! will return 1250. */ 01016 /* //! This is useful to keep iterations to a meaningful size when a class */ 01017 /* //! has many more examples than another. */ 01018 /* virtual intg get_lowest_common_size(); */ 01019 01020 /* // picking methods ///////////////////////////////////////////////////////// */ 01021 01022 /* //! Output statistics of samples picking, i.e. the number of times each */ 01023 /* //! sample has been picked for training. */ 01024 /* virtual void save_pickings(const char *name = NULL); */ 01025 /* //! Write plot of m organized by class and correctness */ 01026 /* template <typename T> */ 01027 /* void write_classed_pickings(idx<T> &m, idx<ubyte> &correct, */ 01028 /* string &name_, const char *name2_ = NULL, */ 01029 /* bool plot_correct = true, */ 01030 /* const char *ylabel = ""); */ 01031 01032 /* // state saving //////////////////////////////////////////////////////////// */ 01033 01034 /* //! Save internal iterators. Calling restore_state() will return to the */ 01035 /* //! current sample. */ 01036 /* virtual void save_state(); */ 01037 /* //! Restore previously saved internal iterators. */ 01038 /* virtual void restore_state(); */ 01039 01040 // pretty methods ////////////////////////////////////////////////////////// 01041 01043 virtual void pretty(); 01044 /* //! Pretty the progress of current epoch. */ 01045 /* //! \param newline If true, end pretty with a new line. */ 01046 /* virtual void pretty_progress(bool newline = true); */ 01047 01049 virtual void print_path(Tlabel l); 01050 01051 // friends ////////////////////////////////////////////////////////////// 01052 template <typename T1, typename T2, typename T3> 01053 friend class class_datasource_gui; 01054 01055 // members ///////////////////////////////////////////////////////////////// 01056 protected: 01057 // hierarchy 01058 vector<class_node<Tlabel>*> all_nodes; 01059 vector<class_node<Tlabel>*> all_nodes_by_depth; 01060 vector<vector<class_node<Tlabel>*>*> all_depths; 01061 01062 01063 01064 vector<vector<class_node<Tlabel>*>*> complete_depths; 01065 // data 01066 idx<Tlabel> *parents; 01067 idx<Tlabel> depth_labels; 01068 using labeled_datasource<Tnet,Tdata,Tlabel>::labels; 01069 using class_datasource<Tnet,Tdata,Tlabel>::olabels; 01070 using datasource<Tnet,Tdata>::_name; 01071 // classes 01072 using class_datasource<Tnet,Tdata,Tlabel>::nclasses; 01073 using class_datasource<Tnet,Tdata,Tlabel>::lblstr; 01074 using class_datasource<Tnet,Tdata,Tlabel>::clblstr; 01075 using datasource<Tnet,Tdata>::data; 01076 using datasource<Tnet,Tdata>::correct; 01077 using datasource<Tnet,Tdata>::energies; 01078 using datasource<Tnet,Tdata>::probas; 01079 // iterating 01080 uint current_depth; 01081 vector<uint> it_depths; 01082 using datasource<Tnet,Tdata>::it; 01083 /* using datasource<Tnet,Tdata>::epoch_mode; */ 01084 /* using datasource<Tnet,Tdata>::epoch_show; */ 01085 using datasource<Tnet,Tdata>::epoch_sz; 01086 /* using datasource<Tnet,Tdata>::epoch_timer; */ 01087 /* using datasource<Tnet,Tdata>::epoch_show_printed; */ 01088 // class-balanced iterating indices 01089 /* using datasource<Tnet,Tdata>::epoch_done_counters; */ 01090 bool depth_balance; 01091 using class_datasource<Tnet,Tdata,Tlabel>::balance; 01092 /* vector<vector<intg> > bal_indices; //!< Balanced iterating indices. */ 01093 /* vector<uint> bal_it; //!< Sample iterators for each class. */ 01094 /* uint class_it; //!< Iterator on classes. */ 01095 /* // sample picking with probabilities */ 01096 /* bool perclass_norm; //!< Normalize probas per class. */ 01097 using datasource<Tnet,Tdata>::epoch_pick_cnt; 01098 using datasource<Tnet,Tdata>::epoch_cnt; 01099 /* using datasource<Tnet,Tdata>::count_pickings; */ 01100 using datasource<Tnet,Tdata>::pick_count; 01101 /* using datasource<Tnet,Tdata>::counts; */ 01102 using datasource<Tnet,Tdata>::weigh_samples; 01103 /* // state saving */ 01104 /* vector<vector<intg> > bal_indices_saved; */ 01105 /* vector<uint> bal_it_saved; */ 01106 /* uint class_it_saved; */ 01107 /* using datasource<Tnet,Tdata>::state_saved; */ 01108 /* using datasource<Tnet,Tdata>::count_pickings_save; */ 01109 /* using datasource<Tnet,Tdata>::it_saved; */ 01110 /* using datasource<Tnet,Tdata>::it_test; */ 01111 /* using datasource<Tnet,Tdata>::it_test_saved; */ 01112 /* using datasource<Tnet,Tdata>::it_train; */ 01113 /* using datasource<Tnet,Tdata>::it_train_saved; */ 01114 // misc 01115 using datasource<Tnet,Tdata>::sampledims; 01116 using datasource<Tnet,Tdata>::test_set; 01117 /* using datasource<Tnet,Tdata>::shuffle_passes; */ 01118 /* using datasource<Tnet,Tdata>::not_picked; */ 01119 }; 01120 01123 template <typename Tnet, typename Tdata, typename Tlabel> 01124 class labeled_pair_datasource 01125 : public labeled_datasource<Tnet, Tdata, Tlabel> { 01126 public: 01127 idx<intg> pairs; 01128 /* typename idx<intg>::dimension_iterator pairsIter; */ 01129 01131 labeled_pair_datasource(const char *data_fname, const char *labels_fname, 01132 const char *classes_fname, const char *pairs_fname, 01133 const char *name_ = NULL, 01134 Tdata bias = 0, float coeff = 1.0); 01136 labeled_pair_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 01137 idx<ubyte> &classes_, idx<intg> &pairs_, 01138 const char *name_ = NULL, 01139 Tdata bias = 0, float coeff = 1.0); 01141 virtual ~labeled_pair_datasource(); 01142 01144 virtual void fprop(bbstate_idx<Tnet> &d1, bbstate_idx<Tnet> &d2, 01145 bbstate_idx<Tlabel> &label); 01147 virtual bool next(); 01150 virtual void seek_begin(); 01152 virtual unsigned int size(); 01153 }; 01154 01161 template <typename Tnet, typename Tdata, typename Tlabel> 01162 class mnist_datasource : public class_datasource<Tnet, Tdata, Tlabel> { 01163 public: 01169 mnist_datasource(const char *root, bool train_data, uint size); 01176 mnist_datasource(const char *root, const char *name, uint size); 01178 virtual ~mnist_datasource(); 01179 01181 virtual void fprop_data(bbstate_idx<Tnet> &s); 01182 01183 protected: 01185 virtual void init(idx<Tdata> &data, idx<Tlabel> &labels, const char *name); 01186 01187 // members ///////////////////////////////////////////////////////////////// 01188 public: 01189 using datasource<Tnet,Tdata>::bias; 01190 using datasource<Tnet,Tdata>::coeff; 01191 using datasource<Tnet,Tdata>::sampledims; 01192 using datasource<Tnet,Tdata>::height; 01193 using datasource<Tnet,Tdata>::width; 01194 using datasource<Tnet,Tdata>::data; 01195 using datasource<Tnet,Tdata>::datas; 01196 using datasource<Tnet,Tdata>::multimat; 01197 using labeled_datasource<Tnet,Tdata,Tlabel>::labels; 01198 using datasource<Tnet,Tdata>::it; 01199 }; 01200 01202 // Helper functions 01203 01207 template <typename Tdata> 01208 idx<Tdata> create_target_matrix(intg nclasses, Tdata target); 01209 01210 } // end namespace ebl 01211 01212 #include "datasource.hpp" 01213 01214 #endif /* DATASOURCE_H_ */