libeblearn
/home/rex/ebltrunk/core/libeblearn/include/datasource.h
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef DATASOURCE_H_
00034 #define DATASOURCE_H_
00035 
00036 #include "ebl_defines.h"
00037 #include "libidx.h"
00038 #include "ebl_states.h"
00039 #include "ebl_arch.h"
00040 #include <map>
00041 
00042 using namespace std;
00043 
00044 namespace ebl {
00045   
00050   template<typename Tnet, typename Tdata> class datasource {
        //! Base class template of every data source declared in this header.
        //! Tdata is the element type of the stored samples ('data', 'datas');
        //! Tnet is the element type of the states handed to the fprop methods
        //! and of the 'bias' / 'coeff' members.
        //! NOTE(review): only declarations are visible here; the definitions
        //! are presumably in "datasource.hpp", included at the end of this file.
00051   public:
          //! Map type keyed by uint with idx<Tdata> values.
00052     typedef map<uint,idx<Tdata> > t_pick_map;
00053     
          //! Default constructor (takes no data).
00055     datasource(); 
          //! Construct from a multi-matrix of samples; 'name' is optional.
00063     datasource(midx<Tdata> &data, const char *name = NULL);
          //! Construct from a single matrix of samples; 'name' is optional.
00068     datasource(idx<Tdata> &data, const char *name = NULL);
          //! Construct from a file name; 'name' is optional.
          //! NOTE(review): presumably loads the samples from 'data_fname' --
          //! confirm in datasource.hpp.
00074     datasource(const char *data_fname, const char *name = NULL);
          //! Virtual destructor: this class is used as a polymorphic base.
00076     virtual ~datasource();
00077 
00078     // initialization //////////////////////////////////////////////////////////
00079     
          //! Non-virtual init overloads taking the same arguments as the
          //! midx / idx constructors above, except that 'name' has no default.
00081     void init(midx<Tdata> &data, const char *name);
00083     void init(idx<Tdata> &data, const char *name);
00084     
00085     // data access methods /////////////////////////////////////////////////////
00086 
          //! fprop_data overloads, one per kind of state accepted: a multi-state
          //! of any Tstate (non-virtual template), an fstate_idx<Tnet> and a
          //! bbstate_idx<Tnet>.
          //! NOTE(review): presumably they write the current sample into 's';
          //! confirm in datasource.hpp.
00088     template <class Tstate>
00089       void fprop_data(mstate<Tstate> &s);
00091     virtual void fprop_data(fstate_idx<Tnet> &s);
00093     virtual void fprop_data(bbstate_idx<Tnet> &s);
          //! Single-state fprop. Derived classes declare fprop overloads with
          //! more arguments, which hide this one when called through them.
00095     virtual void fprop(bbstate_idx<Tnet> &s);
          //! Return a sample, as an idx of Tdata, for position 'index'.
00097     virtual idx<Tdata> get_sample(intg index);
          //! Return an idx of Tnet for 'index' (see 'raw_outputs' below).
          //! NOTE(review): the meaning of the default index -1 is not visible
          //! in this header; presumably "the current sample" -- confirm.
00103     virtual idx<Tnet> get_raw_output(intg index = -1);
00104 
00105     // iterating methods ///////////////////////////////////////////////////////
00106     
          //! Select the sample at 'index'.
00109     virtual void select_sample(intg index);
00112     virtual void shuffle();
          //! Iteration steps; both return a bool whose meaning is defined by the
          //! implementation (not visible here). Overridden by class_datasource.
00119     virtual bool next();
00133     virtual bool next_train();
00134 
00135     // accessors ///////////////////////////////////////////////////////////////
00136 
          //! NOTE(review): presumably store into the public 'bias' / 'coeff'
          //! members declared below.
00138     virtual void set_data_bias(Tnet bias);
00140     virtual void set_data_coeff(Tnet coeff);
00142     virtual unsigned int size();
          //! Sample dimensions, as an idxdim and as an mfidxdim (see the
          //! 'sampledims' and 'samplemfdims' members).
00145     virtual idxdim sample_dims();
00148     virtual mfidxdim sample_mfdims();
00149     
          //! The parameter names mirror the "predictions" members below
          //! (energies, correct, raw_outputs, answers, targets).
          //! NOTE(review): presumably records them for the current sample.
00171     virtual void set_sample_energy(double e, bool correct,
00172                                    idx<Tnet> &raw_outputs,
00173                                    idx<Tnet> &answers,
00174                                    idx<Tnet> &target);
          //! See the 'bkeep_outputs' member.
00178     virtual void keep_outputs(bool keep = true);
          //! The commented-out documentation of the class_datasource override
          //! reads: "Normalize picking probabilities by maximum probability for
          //! all classes if perclass_norm is true, or globally otherwise."
00180     virtual void normalize_all_probas();
          //! Takes an optional pointer to a vector of intg (default NULL).
          //! NOTE(review): class_datasource declares normalize_probas(int),
          //! which hides this overload instead of overriding it.
00182     virtual void normalize_probas(vector<intg> *cindinces = NULL);
00185     virtual void seek_begin();
00188     virtual void seek_begin_train();
          //! See the 'shuffle_passes' member.
00194     virtual void set_shuffle_passes(bool activate);
00195 
          //! The parameter names mirror the members 'weigh_samples',
          //! 'hardest_focus', 'perclass_norm' and 'sample_min_proba'.
          //! Defaults: hardest_focus = false, perclass_norm = true,
          //! min_proba = 0.0.
00215     virtual void set_weigh_samples(bool activate, bool hardest_focus = false, 
00216                                    bool perclass_norm = true, 
00217                                    double min_proba = 0.0);    
          //! See the 'test_set' member.
00221     virtual void set_test();
00223     virtual bool is_test();
          //! Epoch accessors; see the "epoch variables" members below.
00228     virtual intg get_epoch_size();
00230     virtual intg get_epoch_count();
00234     virtual void set_epoch_size(intg sz);
00239     virtual void set_epoch_mode(uint mode);
          //! Documented on the class_datasource overrides (commented-out copy)
          //! as: epoch_done() "Return true if current epoch is finished. Call
          //! init_epoch() to restart a new epoch."; init_epoch() "Restarts a new
          //! epoch, i.e. resets counters but do not reset iterators positions."
00242     virtual bool epoch_done();
00245     virtual void init_epoch();
          //! Documented on the class_datasource override as: "Output statistics
          //! of samples picking, i.e. the number of times each sample has been
          //! picked for training."
00248     virtual void save_pickings(const char *name = NULL);    
00250     virtual bool get_count_pickings();    
00252     virtual void set_count_pickings(bool count = true);    
          //! Returns a non-const reference to a string (see '_name').
00254     virtual string& name();
          //! See the 'epoch_show' member ("show modulo").
00256     virtual void set_epoch_show(uint modulo);
          //! See the '_ignore_correct' member.
00258     virtual void ignore_correct(bool ignore = true);
00260     virtual bool mstate_samples();
00261 
00262     // state saving ////////////////////////////////////////////////////////////
00263     
          //! Documented on the class_datasource overrides as: save_state() "Save
          //! internal iterators. Calling restore_state() will return to the
          //! current sample."; restore_state() "Restore previously saved
          //! internal iterators."
00266     virtual void save_state();    
00268     virtual void restore_state();
00269 
00270     // pretty methods //////////////////////////////////////////////////////////
00271     
00273     virtual void pretty();
          //! Documented on the class_datasource override as: "Pretty the
          //! progress of current epoch." If 'newline' is true, end with a new
          //! line.
00276     virtual void pretty_progress(bool newline = true);
00277     
00279   protected:    
00280 
00281     // initialization //////////////////////////////////////////////////////////
00282 
          //! Second-stage, non-virtual init taking only the name.
          //! NOTE(review): presumably shared by both public init() overloads.
00284     void init2(const char *name);
00285 
00286     // picking methods /////////////////////////////////////////////////////////
00287     
00290     virtual bool pick_current();
          //! Returns a reference to a map<uint,intg> (same type as 'picksmap').
00292     virtual map<uint,intg>& get_pickings();
00293 
00294     // members /////////////////////////////////////////////////////////////////
00295   public:
          // NOTE(review): every member from here down to '_name' is public.
00296     Tnet                bias;
00297     Tnet                coeff;
00298     // data
00299     idx<Tdata>          data;   // samples
00300     midx<Tdata>         datas;  // samples (multi-matrix).
00301     idx<double>         probas; 
00302     // predictions
00303     idx<double>         energies;       
00304     idx<ubyte>          correct; 
00305     idx<Tnet>           raw_outputs; 
00306     idx<Tnet>           answers; 
00307     idx<Tnet>           targets;        
00308     // picking /////////////////////////////////////////////////////////////////
00309     idx<uint>           pick_count;     
00310     bool                count_pickings; 
          // class_datasource lists this one with its state-saving members.
00311     bool                count_pickings_save;
00312     unsigned int        height;
00313     unsigned int        width;
00314     string              _name;
00315   protected:
00316     vector<intg>        counts; // # of samples / class
00317     map<uint,intg>      picksmap;
00318     bool                multimat; 
00319     bool                bkeep_outputs; 
00320 
00321     // (unbalanced) iterating indices
00322     intg                it; 
00323     intg                it_test; 
00324     intg                it_train; 
00325     idx<intg>           indices; 
00326 
00327     // state saving
          // One '*_saved' member per iterator / epoch counter declared in this
          // class, plus a flag.
00328     bool                state_saved; 
00329     intg                it_saved; 
00330     intg                it_test_saved; 
00331     intg                it_train_saved; 
00332     idx<intg>           indices_saved; 
00333     intg                epoch_cnt_saved;
00334     intg                epoch_pick_cnt_saved;   
00335     vector<intg>        epoch_done_counters_saved;
00337     // features switches
00338     bool                shuffle_passes; 
00339     bool                test_set; 
00340 
00341     // epoch variables
00342     vector<intg>        epoch_done_counters;
00343     intg                epoch_sz;
00344     intg                epoch_cnt;
00345     intg                epoch_pick_cnt; 
00346     uint                epoch_show;     // show modulo
00347     intg                epoch_show_printed;
00348     uint                epoch_mode; 
00349     timer               epoch_timer;
00350     timer               test_timer;
00351     uint                not_picked;
00352     bool                hardest_focus; 
00353     bool                _ignore_correct; 
00354 
00355     // sample picking with probabilities
00356     bool                weigh_samples; 
          // NOTE(review): class_datasource declares its own 'perclass_norm',
          // which shadows this one inside that class.
00357     bool                perclass_norm; 
00358     double              sample_min_proba; 
00359     idxdim              sampledims; 
00360     mfidxdim            samplemfdims; 
00361   };
00362 
00365 
00370   template <typename Tnet, typename Tdata, typename Tlabel>
00371     class labeled_datasource : public datasource<Tnet, Tdata> {
        //! A datasource that additionally holds labels (idx<Tlabel>), scales
        //! (idx<intg>) and jitters (midx<float>); see the members at the bottom.
        //! Tlabel is the element type of the labels.
00372   public:
          //! Default constructor (takes no data).
00374     labeled_datasource();
          //! Construct from in-memory samples (multi-matrix or single matrix)
          //! and their labels; 'name' is optional.
00379     labeled_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 
00380                        const char *name = NULL);
00383     labeled_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 
00384                        const char *name = NULL);
          //! Construct from a single root dataset name; 'name' is optional.
          //! NOTE(review): how the individual file names are derived from
          //! 'root_ds' is not visible here -- see init_root() in datasource.hpp.
00389     labeled_datasource(const char *root_ds, const char *name = NULL);
          //! Construct from a root plus separate data / labels names; jitters,
          //! scales and name are optional (NULL by default).
00394     labeled_datasource(const char *root, const char *data_name,
00395                        const char *labels_name, const char *jitters_name = NULL,
00396                        const char *scales_name = NULL, const char *name = NULL);
00398     virtual ~labeled_datasource();
00399 
00400     // init methods ////////////////////////////////////////////////////////////
00401     
          //! Non-virtual init overloads. Being declared here, they hide the
          //! datasource::init overloads for calls made through this class.
00403     void init(midx<Tdata> &data, idx<Tlabel> &labels, const char *name);
00405     void init(idx<Tdata> &data, idx<Tlabel> &labels, const char *name);
          //! File-based init.
          //! CAUTION: the optional arguments are ordered (jitters, name, scales,
          //! max_size) here, whereas init_root() and the constructor above use
          //! (jitters, scales, name). All are const char*, so passing them in
          //! the wrong order compiles silently.
00408     void init(const char *data_fname, const char *labels_fname,
00409               const char *jitters_fname = NULL, const char *name = NULL,
00410               const char *scales_fname = NULL, uint max_size = 0);
          //! Same arguments as the 6-argument constructor.
00412     void init_root(const char *root, const char *data_fname,
00413                    const char *labels_fname, const char *jitters_fname = NULL,
00414                    const char *scales_fname = NULL, const char *name = NULL);
          //! Same arguments as the (root_ds, name) constructor.
00418     void init_root(const char *root_dsname, const char *name = NULL);
00419 
00420     // data access /////////////////////////////////////////////////////////////
00421 
          //! Two-state fprop: one state of Tnet ('out') and one of Tlabel
          //! ('label'). It hides datasource::fprop(bbstate_idx<Tnet>&) for calls
          //! made through this class (no using-declaration re-exposes it).
00425     virtual void fprop(bbstate_idx<Tnet> &out, bbstate_idx<Tlabel> &label);
          //! Label fprop into a state of the label type ...
00427     virtual void fprop_label(fstate_idx<Tlabel> &s);
          //! ... and into states of the network type Tnet.
          //! NOTE(review): presumably applies 'label_bias' / 'label_coeff';
          //! confirm in datasource.hpp.
00430     virtual void fprop_label_net(fstate_idx<Tnet> &s);
00433     virtual void fprop_label_net(bbstate_idx<Tnet> &s);
00435     virtual void fprop_jitter(bbstate_idx<Tnet> &s);
          //! Returns an intg (the 'scales' member is an idx<intg>).
00437     virtual intg fprop_scale();
00438 
00439     // accessors ///////////////////////////////////////////////////////////////
00440     
00442     virtual bool included_sample(intg index);
00444     virtual intg count_included_samples();
          //! Overrides datasource::pretty().
00446     virtual void pretty();
00448     virtual void pretty_scales();
          //! See the 'labeldims' member.
00450     virtual idxdim label_dims();
          //! NOTE(review): presumably store into 'label_bias' / 'label_coeff'.
00452     virtual void set_label_bias(Tnet bias);
00454     virtual void set_label_coeff(Tnet coeff);
          //! See the 'scales_loaded' member.
00456     virtual bool has_scales();
00457 
00458     // friends //////////////////////////////////////////////////////////////
          // Every specialization of these class templates is a friend.
00459     template <typename T1, typename T2, typename T3>
00460       friend class labeled_datasource_gui;
00461     template <typename T1, typename T2, typename T3>
00462       friend class supervised_trainer;
00463     template <typename T1, typename T2, typename T3>
00464       friend class supervised_trainer_gui;
00465 
00466     // protected methods ///////////////////////////////////////////////////////
00467   protected:
00468 
          //! Label-specific part of initialization.
          //! NOTE(review): presumably called by the public init() overloads.
00470     void init_labels(idx<Tlabel> &labels, const char *name);
00471 
00472     // members /////////////////////////////////////////////////////////////////
00473   protected:
          // The using-declarations below make members of the dependent base
          // datasource<Tnet,Tdata> usable by unqualified name in this template;
          // being in a protected section, they also make the public base
          // members '_name' and 'data' protected when accessed through this
          // class.
00474     using datasource<Tnet,Tdata>::_name;
00475     // data
00476     Tnet                label_bias;
00477     Tnet                label_coeff;
00478     using datasource<Tnet,Tdata>::data;
00479     idx<Tlabel> labels;         // labels
00480     idx<intg>   scales;         // scales
00481     midx<float> jitters;        
00482     bool        scales_loaded; 
00483     // iterating
00484     using datasource<Tnet,Tdata>::it;
00485     using datasource<Tnet,Tdata>::epoch_sz;
00486     // dimensions
00487     using datasource<Tnet,Tdata>::sampledims;
00488     idxdim      jitters_maxdim; 
00489     idxdim      labeldims;      
00490     using datasource<Tnet,Tdata>::multimat;
00491   };
00492 
00495 
00499   template <typename Tnet, typename Tdata, typename Tlabel>
00500     class class_datasource : public labeled_datasource<Tnet, Tdata, Tlabel> {
        //! A labeled_datasource whose labels are classes: it adds class names,
        //! class exclusion and class-balanced iteration (see the members at
        //! the bottom).
        //! The quoted sentences below come from the commented-out copy of this
        //! class's documentation kept inside hierarchy_datasource.
00501   public:
          //! Default constructor (takes no data).
00503     class_datasource();
00504 
          //! In-memory constructors. Class names are given either as an
          //! optional vector of string pointers ('lblstr', default NULL) ...
00510     class_datasource(midx<Tdata> &data, idx<Tlabel> &labels, 
00511                      vector<string*> *lblstr = NULL, const char *name = NULL);
00515     class_datasource(idx<Tdata> &data, idx<Tlabel> &labels, 
00516                      vector<string*> *lblstr = NULL, const char *name = NULL);
          //! ... or as a mandatory idx<ubyte> 'classes' (see init_strings()).
00522     class_datasource(midx<Tdata> &data, idx<Tlabel> &labels,
00523                      idx<ubyte> &classes, const char *name = NULL);
00527     class_datasource(idx<Tdata> &data, idx<Tlabel> &labels,
00528                      idx<ubyte> &classes, const char *name = NULL);
          //! File-based constructor; everything after 'labels_name' is optional.
00532     class_datasource(const char *data_name,
00533                      const char *labels_name, const char *jitters_name = NULL,
00534                      const char *scales_name = NULL, 
00535                      const char *classes_name = NULL, const char *name = NULL);
          //! Copy constructor.
00537     class_datasource(const class_datasource<Tnet, Tdata, Tlabel> &ds);
00539     virtual ~class_datasource();
00540 
00541     // init methods ////////////////////////////////////////////////////////////
00542 
00544     void defaults();
          //! Takes the same idx<ubyte> 'classes' as the constructors above.
          //! NOTE(review): presumably converts it into the label strings.
00546     virtual void init_strings(idx<ubyte> &classes);
00549     void init_local(vector<string*> *lblstr);
          //! Non-virtual init overloads; they hide labeled_datasource::init for
          //! calls made through this class.
00551     void init(midx<Tdata> &data, idx<Tlabel> &labels, 
00552               vector<string*> *lblstr, const char *name);
00554     void init(idx<Tdata> &data, idx<Tlabel> &labels, 
00555               vector<string*> *lblstr, const char *name);
          //! File-based init. Optional arguments are ordered (jitters, scales,
          //! classes, name, max_size); note that labeled_datasource::init uses
          //! a different order (jitters, name, scales, max_size).
00558     void init(const char *data_fname, const char *labels_fname,
00559               const char *jitters_fname = NULL, const char *scales_fname = NULL,
00560               const char *classes_fname = NULL, const char *name = NULL,
00561               uint max_size = 0);
00563     void init_root(const char *root, const char *data_fname,
00564                    const char *labels_fname, const char *jitters_fname = NULL,
00565                    const char *scales_fname = NULL,
00566                    const char *classes_fname = NULL, const char *name = NULL);
00570     void init_root(const char *root_dsname, const char *name = NULL);
          //! Virtual; overridden by hierarchy_datasource.
00575     virtual void init_class_labels();
00576     
00577     // data access /////////////////////////////////////////////////////////////
00578     
          //! Returns a label by value.
          //! NOTE(review): presumably the label of the current sample.
00582     virtual Tlabel get_label();
00583 
00584     // iterating ///////////////////////////////////////////////////////////////
00585     
00587     virtual bool included_sample(intg index);
00589     virtual intg count_included_samples();
00592     virtual void seek_begin();
00595     virtual void seek_begin_train();
          //! Iteration steps. "Balance is used only by next_train(), not by
          //! next()."
00602     virtual bool next();
00617     virtual bool next_train();
00621     virtual void next_balanced_class();
00625     virtual void reset_class_order();
          //! See the 'random_class_order' member.
00628     virtual void set_random_class_order(bool ran);
          //! Restrict the classes in use, given a count 'n', an 'offset'
          //! (default 0) and a 'random' flag (default false).
          //! NOTE(review): exact selection rule is in datasource.hpp.
00631     virtual void limit_classes(intg n, intg offset = 0, bool random = false);
          //! "If 'bal' is true, make the next_train() method call sequentially
          //! one sample of each class instead of following the dataset's
          //! distribution. This is important to use when the dataset is
          //! unbalanced. This is set to true by default."
00637     virtual void set_balanced(bool bal = true);
          //! "Return true if current epoch is finished. Call init_epoch() to
          //! restart a new epoch."
00640     virtual bool epoch_done();
          //! "Restarts a new epoch, i.e. resets counters but do not reset
          //! iterators positions."
00643     virtual void init_epoch();
          //! "Normalize picking probabilities by maximum probability for all
          //! classes if perclass_norm is true, or globally otherwise."
00646     virtual void normalize_all_probas();
          //! "Normalize picking probabilities by maximum probability of classid
          //! if perclass_norm is true, or globally otherwise."
          //! NOTE(review): the parameter type differs from
          //! datasource::normalize_probas(vector<intg>*), so this declares a new
          //! virtual that hides the base one instead of overriding it; calls
          //! made through a datasource reference never reach this function.
00649     virtual void normalize_probas(int classid = -1);
00650     
00651     // accessors ///////////////////////////////////////////////////////////////
00652     
          //! "Return the number of classes."
00654     virtual intg get_nclasses();
          //! "Return the label id corresponding to name, or -1 if not found."
00656     virtual int get_class_id(const char *name);
          //! "Return the label string for index id."
00658     virtual string& get_class_name(int id);
          //! "Returns a reference to a vector of each label string."
00660     virtual vector<string*>& get_label_strings();
          //! "Return the lowest (non-zero) size per class, multiplied by the
          //! number of classes. e.g. if a dataset has 10 classes with 100
          //! examples and 5 classes with 50 examples, it will return
          //! 50 * (10 + 5) = 750, whereas size() will return 1250. This is
          //! useful to keep iterations to a meaningful size when a class has
          //! many more examples than another."
00668     virtual intg get_lowest_common_size();
00669 
00670     // picking methods /////////////////////////////////////////////////////////
00671 
          //! "Output statistics of samples picking, i.e. the number of times
          //! each sample has been picked for training."
00674     virtual void save_pickings(const char *name = NULL);    
          //! "Write plot of m organized by class and correctness".
00676     template <typename T>
00677       void write_classed_pickings(idx<T> &m, idx<ubyte> &correct,
00678                                   string &name_, const char *name2_ = NULL,
00679                                   bool plot_correct = true,
00680                                   const char *ylabel = "");    
00681 
00682     // state saving ////////////////////////////////////////////////////////////
00683     
          //! "Save internal iterators. Calling restore_state() will return to
          //! the current sample."
00686     virtual void save_state();    
          //! "Restore previously saved internal iterators."
00688     virtual void restore_state();
00689 
00690     // pretty methods //////////////////////////////////////////////////////////
00691 
00693     virtual void pretty();
00695     virtual void pretty_scales();
          //! "Pretty the progress of current epoch." If 'newline' is true, end
          //! pretty with a new line.
00698     virtual void pretty_progress(bool newline = true);
00699     
00700     // friends //////////////////////////////////////////////////////////////
00701     template <typename T1, typename T2, typename T3>
00702       friend class class_datasource_gui;
00703     template <typename T1, typename T2, typename T3>
00704       friend class supervised_trainer;
00705 
00706   protected:
00707     
          //! Overrides datasource::pick_current().
00710     virtual bool pick_current();
00711     
00712     // members /////////////////////////////////////////////////////////////////
00713   protected:
          // The using-declarations in this section make members of the
          // dependent base classes usable by unqualified name.
00714     using datasource<Tnet,Tdata>::_name;
00715     // classes
00716     intg                 nclasses;      
          // NOTE(review): raw pointer to a vector of raw string pointers;
          // ownership is not visible from this header.
00717     vector<string*>     *lblstr;        
00718     vector<string*>      clblstr;       
00719     bool                 bexclusion;    
00720     vector<bool>         excluded;      
00721     intg                 included;      
00722     // data
00723     using datasource<Tnet,Tdata>::data;
00724     using labeled_datasource<Tnet,Tdata,Tlabel>::labels;
00725     idx<Tlabel> olabels; 
00726     using datasource<Tnet,Tdata>::correct;
00727     using datasource<Tnet,Tdata>::energies;
00728     using datasource<Tnet,Tdata>::probas;
00729     // iterating
00730     using datasource<Tnet,Tdata>::it;
00731     using datasource<Tnet,Tdata>::epoch_mode;
00732     using datasource<Tnet,Tdata>::epoch_show;
00733     using datasource<Tnet,Tdata>::epoch_sz;
00734     using datasource<Tnet,Tdata>::epoch_timer;
00735     using datasource<Tnet,Tdata>::epoch_show_printed;
00736     // class-balanced iterating indices
00737     using datasource<Tnet,Tdata>::epoch_done_counters;
00738     bool                 balance;       
          // Balanced iterating indices.
00739     vector<vector<intg> > bal_indices;  
          // Sample iterators for each class.
00740     vector<uint>         bal_it;        
00741     vector<uint>         class_order;   
00742     bool                 random_class_order;
          // Iterator on classes.
00743     uint                 class_it;      
00744     uint                 class_it_it;   
00745     // sample picking with probabilities
          // Normalize probas per class.
          // NOTE(review): this is a new member that shadows the protected
          // datasource::perclass_norm; the two are distinct objects, so code in
          // datasource that writes its own 'perclass_norm' does not update this
          // one. Verify against datasource.hpp.
00746     bool                 perclass_norm; 
00747     using datasource<Tnet,Tdata>::epoch_pick_cnt;       
00748     using datasource<Tnet,Tdata>::epoch_cnt;    
00749     using datasource<Tnet,Tdata>::count_pickings;
00750     using datasource<Tnet,Tdata>::pick_count;
00751     using datasource<Tnet,Tdata>::counts;
00752     using datasource<Tnet,Tdata>::weigh_samples;
00753     // state saving
          // '*_saved' counterparts of the balanced-iteration members above.
00754     vector<vector<intg> > bal_indices_saved;
00755     vector<uint>        bal_it_saved;
00756     uint                class_it_saved;
00757     uint                class_it_it_saved;
00758     using datasource<Tnet,Tdata>::state_saved;
00759     using datasource<Tnet,Tdata>::count_pickings_save;
00760     using datasource<Tnet,Tdata>::it_saved;
00761     using datasource<Tnet,Tdata>::it_test;
00762     using datasource<Tnet,Tdata>::it_test_saved;
00763     using datasource<Tnet,Tdata>::it_train;
00764     using datasource<Tnet,Tdata>::it_train_saved;
00765     // misc 
00766     using datasource<Tnet,Tdata>::sampledims;
00767     using datasource<Tnet,Tdata>::test_set;
00768     using datasource<Tnet,Tdata>::shuffle_passes;
00769     using datasource<Tnet,Tdata>::not_picked;
00770   };
00771 
00774 
00777   template <typename Tlabel> class class_node {
        //! A node of a class tree: it has a label, a name, a parent pointer, a
        //! list of child nodes and a list of sample indices (see the members).
        //! hierarchy_datasource is a friend and stores vectors of these nodes.
00778   public:
          //! Construct a node from a label id and a name.
          //! NOTE(review): '_name' is a reference member, presumably bound to
          //! 'name'; if so, the caller's string must outlive the node.
00781     class_node(Tlabel id, string &name);
00783     virtual ~class_node();
          //! Emptiness queries; see the 'bempty' and 'iempty' members.
          //! NOTE(review): presumably empty() covers the whole subtree and
          //! internally_empty() only this node's own samples -- confirm.
00786     virtual bool empty();
00788     virtual bool internally_empty();
          //! Returns an intg (the samples are stored as vector<intg>).
          //! NOTE(review): presumably the next sample index, iterating over
          //! this node and its children; confirm in datasource.hpp.
00793     virtual intg next();
          //! Add a child node; the pointer is stored in 'children'.
          //! NOTE(review): ownership of 'child' is not visible here.
00796     virtual void add_child(class_node *child);
          //! Add a sample index to this node (see 'samples').
00798     virtual void add_sample(intg index);
          // The functions below are declared both inline and virtual; no body
          // is given in this header.
00800     inline virtual intg nsamples();
          //! Label of this node ...
00802     inline virtual Tlabel label();
          //! ... and a label for a given depth.
          //! NOTE(review): presumably the label of the ancestor at 'depth'.
00805     inline virtual Tlabel label(uint depth);
00807     inline virtual uint depth();
          //! Set the depth; returns a uint whose meaning is defined by the
          //! implementation (not visible here).
00810     virtual uint set_depth(uint d);
          //! Returns a non-const reference to the node's name.
00812     virtual string& name();
          //! Returns the parent pointer (see 'parent').
00814     virtual class_node<Tlabel>* get_parent();
00816     virtual bool is_parent(Tlabel lab);
00817       
00818   protected:
00823     virtual void set_non_empty();
00825     virtual void set_parent(class_node *parent);
00826     
00827     // members /////////////////////////////////////////////////////////////////
00828   protected:
00829     Tlabel               _label; 
          // Reference, not a copy: see the lifetime note on the constructor.
00830     string               &_name;
00831     // hierarchy variables /////////////////////////////////////////////////////
00832     class_node          *parent;        
00833     vector<class_node<Tlabel>*>  children;      
00834     typename vector<class_node<Tlabel>*>::iterator it_children;
00835     bool bempty; 
00836     bool iempty; 
00837     uint _depth; 
00838     // samples corresponding to this node //////////////////////////////////////
00839     vector<intg>         samples;       
00840     vector<intg>::iterator it_samples;  
00841 
00842     // friends /////////////////////////////////////////////////////////////////
00843     template <typename Tnet1, typename Tdata1, typename Tlabel1>
00844       friend class hierarchy_datasource;
00845   };
00846   
00849 
00852   template <typename Tnet, typename Tdata, typename Tlabel>
00853     class hierarchy_datasource : public class_datasource<Tnet, Tdata, Tlabel> {
00854   public:
00856     hierarchy_datasource();
00864     hierarchy_datasource(midx<Tdata> &data, idx<Tlabel> &labels,
00865                          idx<Tlabel> *parents = NULL,
00866                          vector<string*> *lblstr = NULL,
00867                          const char *name = NULL);
00873     hierarchy_datasource(idx<Tdata> &data, idx<Tlabel> &labels,
00874                          idx<Tlabel> *parents = NULL,
00875                          vector<string*> *lblstr = NULL,
00876                          const char *name = NULL);
00882     hierarchy_datasource(idx<Tdata> &data, idx<Tlabel> &labels,
00883                          idx<Tlabel> *parents = NULL,
00884                          idx<ubyte> *classes = NULL, const char *name = NULL);
00891     hierarchy_datasource(const char *data_name, const char *labels_name,
00892                          const char *parents_name = NULL,
00893                          const char *jitters_name = NULL,
00894                          const char *scales_name = NULL,
00895                          const char *classes_name = NULL,
00896                          const char *name = NULL,
00897                          uint max_size = 0);
00899     virtual ~hierarchy_datasource();
00900 
00901     // init methods ////////////////////////////////////////////////////////////
00902 
00904     void init_parents(idx<Tlabel> *parents = NULL);
00909     virtual void init_class_labels();
00910     
00911     // data access /////////////////////////////////////////////////////////////
00912     
00916     virtual Tlabel get_parent();
00918     virtual bool is_parent_of(Tlabel lab1, Tlabel lab2);
00920     virtual vector<class_node<Tlabel>*>& get_nodes();
00923     virtual vector<class_node<Tlabel>*>& get_nodes_by_depth();
00924 
00926     virtual void fprop_label(fstate_idx<Tlabel> &s);
00929     virtual void fprop_label_net(fstate_idx<Tnet> &s);
00932     virtual void fprop_label_net(bbstate_idx<Tnet> &s);
00936     virtual Tlabel get_label();
00942     virtual Tlabel get_label(uint depth, intg index = -1);
00945     idx<Tlabel>& get_depth_labels();
00947     uint get_nbrothers(class_node<Tlabel> &n);
00948     
00949     // iterating ///////////////////////////////////////////////////////////////
00950 
00952     void set_depth_balanced(bool bal);
00958     virtual void set_current_depth(uint depth);
00960     virtual uint get_current_depth();
00962     virtual void incr_current_depth();
00980     virtual bool next_train();
00981 
00982     /* //! If 'bal' is true, make the next_train() method call sequentially one */
00983     /* //! sample of each class instead of following the dataset's distribution. */
00984     /* //! This is important to use when the dataset is unbalanced. */
00985     /* //! This is set to true by default. */
00986     /* //! Balance is used only by next_train(), not by next(). */
00987     /* virtual void set_balanced(bool bal = true); */
00988     /* //! Return true if current epoch is finished. Call init_epoch() to */
00989     /* //! restart a new epoch. */
00990     /* virtual bool epoch_done(); */
00991     /* //! Restarts a new epoch, i.e. resets counters but do not reset iterators */
00992     /* //! positions. */
00993     /* virtual void init_epoch(); */
00994     /* //! Normalize picking probabilities by maximum probability for all classes */
00995     /* //! if perclass_norm is true, or globally otherwise. */
00996     /* virtual void normalize_all_probas(); */
00997     /* //! Normalize picking probabilities by maximum probability of classid if */
00998     /* //! perclass_norm is true, or globally otherwise. */
00999     /* virtual void normalize_probas(int classid = -1); */
01000     
01001     /* // accessors /////////////////////////////////////////////////////////////// */
01002     
01003     /* //! Return the number of classes. */
01004     /* virtual intg get_nclasses(); */
01005     /* //! Return the label id corresponding to name, or -1 if not found. */
01006     /* virtual int get_class_id(const char *name); */
01007     /* //! Return the label string for index id. */
01008     /* virtual string& get_class_name(int id); */
01009     /* //! Returns a reference to a vector of each label string. */
01010     /* virtual vector<string*>& get_label_strings(); */
01011     /* //! Return the lowest (non-zero) size per class, multiplied by the number */
01012     /* //! of classes. */
01013     /* //! e.g. if a dataset has 10 classes with 100 examples and 5 classes with */
01014     /* //! 50 examples, it will return 50 * (10 + 5) = 750, whereas size() */
01015     /* //! will return 1250. */
01016     /* //! This is useful to keep iterations to a meaningful size when a class */
01017     /* //! has many more examples than another. */
01018     /* virtual intg get_lowest_common_size(); */
01019 
01020     /* // picking methods ///////////////////////////////////////////////////////// */
01021 
01022     /* //! Output statistics of samples picking, i.e. the number of times each */
01023     /* //! sample has been picked for training. */
01024     /* virtual void save_pickings(const char *name = NULL);     */
01025     /* //! Write plot of m organized by class and correctness */
01026     /* template <typename T> */
01027     /*   void write_classed_pickings(idx<T> &m, idx<ubyte> &correct, */
01028     /*                            string &name_, const char *name2_ = NULL, */
01029     /*                            bool plot_correct = true, */
01030     /*                            const char *ylabel = "");     */
01031 
01032     /* // state saving //////////////////////////////////////////////////////////// */
01033     
01034     /* //! Save internal iterators. Calling restore_state() will return to the */
01035     /* //! current sample. */
01036     /* virtual void save_state();     */
01037     /* //! Restore previously saved internal iterators. */
01038     /* virtual void restore_state(); */
01039 
01040     // pretty methods //////////////////////////////////////////////////////////
01041 
01043     virtual void pretty();
01044     /* //! Pretty the progress of current epoch. */
01045     /* //! \param newline If true, end pretty with a new line. */
01046     /* virtual void pretty_progress(bool newline = true); */
01047 
01049     virtual void print_path(Tlabel l);
01050     
01051     // friends //////////////////////////////////////////////////////////////
01052     template <typename T1, typename T2, typename T3>
01053       friend class class_datasource_gui;
01054 
01055     // members /////////////////////////////////////////////////////////////////
01056   protected:
01057     // hierarchy
01058     vector<class_node<Tlabel>*> all_nodes; 
01059     vector<class_node<Tlabel>*> all_nodes_by_depth; 
01060     vector<vector<class_node<Tlabel>*>*> all_depths; 
01061 
01062 
01063 
01064     vector<vector<class_node<Tlabel>*>*> complete_depths;
01065     // data
01066     idx<Tlabel> *parents; 
01067     idx<Tlabel> depth_labels; 
01068     using labeled_datasource<Tnet,Tdata,Tlabel>::labels;
01069     using class_datasource<Tnet,Tdata,Tlabel>::olabels;
01070     using datasource<Tnet,Tdata>::_name;
01071     // classes
01072     using class_datasource<Tnet,Tdata,Tlabel>::nclasses;
01073     using class_datasource<Tnet,Tdata,Tlabel>::lblstr;  
01074     using class_datasource<Tnet,Tdata,Tlabel>::clblstr; 
01075     using datasource<Tnet,Tdata>::data;
01076     using datasource<Tnet,Tdata>::correct;
01077     using datasource<Tnet,Tdata>::energies;
01078     using datasource<Tnet,Tdata>::probas;
01079     // iterating
01080     uint current_depth; 
01081     vector<uint> it_depths; 
01082     using datasource<Tnet,Tdata>::it;
01083     /* using datasource<Tnet,Tdata>::epoch_mode; */
01084     /* using datasource<Tnet,Tdata>::epoch_show; */
01085     using datasource<Tnet,Tdata>::epoch_sz;
01086     /* using datasource<Tnet,Tdata>::epoch_timer; */
01087     /* using datasource<Tnet,Tdata>::epoch_show_printed; */
01088     // class-balanced iterating indices
01089     /* using datasource<Tnet,Tdata>::epoch_done_counters; */
01090     bool depth_balance; 
01091     using class_datasource<Tnet,Tdata,Tlabel>::balance; 
01092     /* vector<vector<intg> > bal_indices;       //!< Balanced iterating indices. */
01093     /* vector<uint>      bal_it;        //!< Sample iterators for each class. */
01094     /* uint              class_it;      //!< Iterator on classes. */
01095     /* // sample picking with probabilities */
01096     /* bool              perclass_norm; //!< Normalize probas per class. */
01097     using datasource<Tnet,Tdata>::epoch_pick_cnt;       
01098     using datasource<Tnet,Tdata>::epoch_cnt;    
01099     /* using datasource<Tnet,Tdata>::count_pickings; */
01100     using datasource<Tnet,Tdata>::pick_count;
01101     /* using datasource<Tnet,Tdata>::counts; */
01102     using datasource<Tnet,Tdata>::weigh_samples;
01103     /* // state saving */
01104     /* vector<vector<intg> > bal_indices_saved; */
01105     /* vector<uint>        bal_it_saved; */
01106     /* uint                class_it_saved; */
01107     /* using datasource<Tnet,Tdata>::state_saved; */
01108     /* using datasource<Tnet,Tdata>::count_pickings_save; */
01109     /* using datasource<Tnet,Tdata>::it_saved; */
01110     /* using datasource<Tnet,Tdata>::it_test; */
01111     /* using datasource<Tnet,Tdata>::it_test_saved; */
01112     /* using datasource<Tnet,Tdata>::it_train; */
01113     /* using datasource<Tnet,Tdata>::it_train_saved; */
01114     // misc
01115     using datasource<Tnet,Tdata>::sampledims;
01116     using datasource<Tnet,Tdata>::test_set;
01117     /* using datasource<Tnet,Tdata>::shuffle_passes; */
01118     /* using datasource<Tnet,Tdata>::not_picked; */
01119   };
01120 
01123   template <typename Tnet, typename Tdata, typename Tlabel>
01124     class labeled_pair_datasource
01125     : public labeled_datasource<Tnet, Tdata, Tlabel> {
        //! A labeled_datasource that additionally holds a public 'pairs'
        //! matrix of intg and whose fprop fills two data states at once.
        //! NOTE(review): presumably each entry of 'pairs' holds the indices of
        //! two samples -- confirm in datasource.hpp.
01126   public:
01127     idx<intg>                                   pairs;
01128     /* typename idx<intg>::dimension_iterator   pairsIter; */
01129 
          //! File-based constructor. 'name_' defaults to NULL, 'bias' to 0 and
          //! 'coeff' to 1.0.
          //! NOTE(review): 'bias' is a Tdata and 'coeff' a float here, whereas
          //! the datasource members 'bias' / 'coeff' are Tnet. A classes file is
          //! requested although the base is labeled_datasource, not
          //! class_datasource.
01131     labeled_pair_datasource(const char *data_fname, const char *labels_fname,
01132                             const char *classes_fname, const char *pairs_fname,
01133                             const char *name_ = NULL,
01134                             Tdata bias = 0, float coeff = 1.0);
          //! In-memory counterpart of the constructor above.
01136     labeled_pair_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
01137                             idx<ubyte> &classes_, idx<intg> &pairs_,
01138                             const char *name_ = NULL,
01139                             Tdata bias = 0, float coeff = 1.0);    
01141     virtual ~labeled_pair_datasource();
01142 
          //! Three-state fprop: two data states ('d1', 'd2') and one label
          //! state. It hides the fprop overloads of the base classes for calls
          //! made through this class.
01144     virtual void fprop(bbstate_idx<Tnet> &d1, bbstate_idx<Tnet> &d2,
01145                        bbstate_idx<Tlabel> &label);
          //! Overrides of datasource::next(), seek_begin() and size().
          //! NOTE(review): presumably they iterate over / count 'pairs' instead
          //! of individual samples.
01147     virtual bool next();
01150     virtual void seek_begin();
01152     virtual unsigned int size();
01153   };
01154 
01161   template <typename Tnet, typename Tdata, typename Tlabel>
01162     class mnist_datasource : public class_datasource<Tnet, Tdata, Tlabel> {
        //! A class_datasource specialised for MNIST (by name): it is built from
        //! a root directory and overrides fprop_data.
01163   public:
          //! Construct from a root path, a flag 'train_data' and a size.
          //! NOTE(review): presumably 'train_data' selects the training or the
          //! testing files and 'size' limits the number of samples -- confirm.
01169     mnist_datasource(const char *root, bool train_data, uint size);
          //! Same, but selecting the data by 'name' instead of a flag.
          //! CAUTION: with a literal 0 / NULL as second argument the call is
          //! ambiguous between this overload and the bool one.
01176     mnist_datasource(const char *root, const char *name, uint size);
01178     virtual ~mnist_datasource();
01179 
          //! Overrides datasource::fprop_data(bbstate_idx<Tnet>&). The other
          //! fprop_data overloads of datasource are hidden for calls made
          //! through this class.
01181     virtual void fprop_data(bbstate_idx<Tnet> &s);
01182     
01183   protected:
          //! NOTE(review): labeled_datasource::init with this signature is NOT
          //! virtual, so this function hides it rather than overriding it; base
          //! class code calling init() still reaches the base version.
01185     virtual void init(idx<Tdata> &data, idx<Tlabel> &labels, const char *name);
01186 
01187     // members /////////////////////////////////////////////////////////////////
01188   public:
          // These using-declarations are in a public section, so they also make
          // the named base members (some of them protected in datasource or
          // labeled_datasource) publicly accessible through mnist_datasource.
01189     using datasource<Tnet,Tdata>::bias;
01190     using datasource<Tnet,Tdata>::coeff;
01191     using datasource<Tnet,Tdata>::sampledims;
01192     using datasource<Tnet,Tdata>::height;
01193     using datasource<Tnet,Tdata>::width;
01194     using datasource<Tnet,Tdata>::data;
01195     using datasource<Tnet,Tdata>::datas;
01196     using datasource<Tnet,Tdata>::multimat;
01197     using labeled_datasource<Tnet,Tdata,Tlabel>::labels;
01198     using datasource<Tnet,Tdata>::it;    
01199   };
01200 
01202   // Helper functions
01203   
01207   template <typename Tdata>
        //! Helper returning a matrix of Tdata built from a number of classes
        //! and a target value.
        //! NOTE(review): presumably one row per class, with 'target' marking
        //! the class's own entry -- confirm the exact layout and the value of
        //! the other entries in datasource.hpp.
01208   idx<Tdata> create_target_matrix(intg nclasses, Tdata target);
01209 
01210 } // end namespace ebl
01211 
01212 #include "datasource.hpp"
01213 
01214 #endif /* DATASOURCE_H_ */