libeblearntools
/home/rex/ebltrunk/tools/libeblearntools/include/dataset.h
00001 /***************************************************************************
00002  *   Copyright (C) 2009 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source 
00014  *       Initiative (http://www.opensource.org) must display the 
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef DATASET_H_
00034 #define DATASET_H_
00035 
00036 #define MKDIR_RIGHTS 0755 // octal literal; presumably the permission mode used when creating directories -- confirm in dataset.hpp
00037 
      // String constants. NOTE(review): presumably the accepted values of the
      // save mode (see dataset::set_save / save_mode) -- confirm in dataset.hpp.
00038 #define DATASET_SAVE "dataset"
00039 #define DYNSET_SAVE "dynset"
00040 
00041 #include "libidx.h"
00042 #include "libeblearn.h"
00043 
00044 typedef int t_label; // label type (global namespace); used below for idx<t_label> and class ids
00045 
00046 // jitter
00047 typedef float t_jitter; // element type of jitter vectors (idx<t_jitter>, midx<t_jitter>)
00048 #define JITTERS 4 // number of jitter variables (presumably h, w, s, r of ebl::jitter -- confirm)
00049 
00050 namespace ebl {
00051 
00053   class EXPORT object : public rect<int> {
        // An annotated object: it is itself a rect<int> (public base) and adds
        // an id plus extra annotation fields (visible region, centroid, name,
        // pose, boolean flags and a list of sub-objects).
        // Only declarations are visible in this header.
00054   public:
00055     // constructors/destructors ////////////////////////////////////////////////
00056     
          // Construct from an id.
          // NOTE(review): not 'explicit', so a uint converts implicitly to object.
00058     object(uint id);
          // Virtual destructor.
          // NOTE(review): whether it frees 'visible', 'centroid' or 'parts' is
          // not visible here -- confirm in the implementation.
00060     virtual ~object();
00061 
00062     // accessors ///////////////////////////////////////////////////////////////
00063 
          // Setters taking corner coordinates (xmin, ymin, xmax, ymax).
          // NOTE(review): presumably set_rect updates the rect<int> base and
          // set_visible fills 'visible' -- confirm.
00065     virtual void set_rect(int xmin, int ymin, int xmax, int ymax);
00067     virtual void set_visible(int xmin, int ymin, int xmax, int ymax);
          // NOTE(review): presumably fills 'centroid' with (x, y) -- confirm.
00069     virtual void set_centroid(int x, int y);
00070 
00071     // members /////////////////////////////////////////////////////////////////
00072     uint id; // identifier given at construction (presumably; see object(uint id))
00073     rect<int> *visible; // raw pointer; ownership and NULL-when-unset not visible here -- confirm
00074     pair<int,int> *centroid; // raw pointer; same ownership caveat as 'visible'
00075     string name; 
          // NOTE(review): the flags and 'pose' below look like per-object
          // annotation attributes read from a label file -- confirm their source.
00076     bool difficult; 
00077     bool truncated; 
00078     bool occluded; 
00079     string pose; 
00080     std::vector<object*> parts; // raw pointers to other objects; ownership not visible here
00081     bool ignored; 
00082   };
00083 
00085   class EXPORT jitter {
        // Holds four public float parameters (h, w, s, r) together with a
        // private idx<t_jitter> vector.
        // NOTE(review): h/w presumably are height/width translations, s a scale
        // and r a rotation (cf. the tjitter_*/sjitter_*/rjitter_* members of
        // dataset) -- confirm against the implementation.
00086   public:
          // Construct from the four parameters; spatial_norm defaults to 1.
          // NOTE(review): the role of spatial_norm is not visible here.
00089     jitter(float h_, float w_, float s_, float r_, int spatial_norm = 1);
          // Construct from two rectangles.
          // NOTE(review): presumably derives h, w, s, r from 'jit' relative to
          // 'context' -- confirm.
00094     jitter(rect<float> &context, rect<float> &jit, int spatial_norm = 1);
          // Default constructor.
00096     jitter();
00098     virtual ~jitter();
          // Returns a rect<T> computed from 'r'; 'ratio' defaults to 1.0.
          // NOTE(review): presumably applies this jitter to 'r' -- confirm.
00101     template <typename T> 
00102       rect<T> get_rect(const rect<T> &r, float ratio = 1.0);
          // Read-only access to an idx<t_jitter> (presumably 'jitts').
00104     const idx<t_jitter>& get_jitter_vector() const;
          // Set this jitter from an idx<t_jitter>.
00106     void set(const idx<t_jitter> &j);
00107     
00108     // members /////////////////////////////////////////////////////////////////
00109   public:
00110     float h; 
00111     float w; 
00112     float s; 
00113     float r; 
00114   private:
00115     idx<t_jitter> jitts; // vector form of the jitter (presumably what get_jitter_vector returns)
00116   };
00117   
00120   template <class Tdata> class dataset {
        // Dataset builder/container templated on the sample element type Tdata.
        // It stores samples (midx<Tdata>), labels, ids, jitters and class names,
        // and declares configuration setters, load/save and printing methods.
        // It has virtual members and a protected section, so it can be subclassed.
        // Only declarations are visible here; this header includes "dataset.hpp"
        // at its end, which presumably holds the template definitions.
        // NOTE(review): string, vector, list and pair are used unqualified, so
        // this header relies on names brought into scope by the included
        // libidx.h / libeblearn.h -- confirm.
00121   public:
00122 
00124     // constructors/allocation
00125 
          // 'inroot' is optional (NULL by default).
          // NOTE(review): the parameters presumably initialize the 'name' and
          // 'inroot' members -- confirm.
00130     dataset(const char *name, const char *inroot = NULL);
00131 
00133     virtual ~dataset();
00134 
          // Returns a bool status; 'max' defaults to 0.
          // NOTE(review): the meaning of max == 0 is not visible here.
00139     bool alloc(intg max = 0);
00140 
00142     // data manipulation
00143 
00145     virtual bool extract();
00147     virtual void extract_statistics();
00151     bool split_max_and_save(const char *name1, const char *name2,
00152                             intg max, const string &outroot);
00155     void split_max(dataset<Tdata> &ds1, dataset<Tdata> &ds2, intg max);
00158     void merge_and_save(const char *name1, const char *name2,
00159                         const string &outroot);
00161     void shuffle();
00163     virtual void set_unique_label(const string &class_name);
00164 
00166     // data preprocessing
00167 
00169     void set_planar_loading();
00170 
00172     // accessors    
00173 
00175     const idxdim &get_sample_outdim();
00177     intg size();
00179     t_label get_label_from_class(const string &class_name);
00181     void set_display(bool display);
00183     void set_sleepdisplay(uint delay);
          // Takes a vector of raw resizepp_module pointers by non-const reference.
          // NOTE(review): ownership of the pointed-to modules is not visible here.
00185     void set_preprocessing(vector<resizepp_module<fs(Tdata)>*> &p);
00188     virtual void set_outdims(const idxdim &d);
00190     virtual void set_outdir(const char *s, const char *tmp = NULL);
00193     void set_mindims(const idxdim &d);
00195     void set_maxdims(const idxdim &d);
00198     void set_scales(const vector<double> &sc, const string &od);
00201     void set_fovea(const vector<double> &scales);
00203     void set_max_per_class(intg max);
00205     void set_max_data(intg max);
00207     void set_image_pattern(const string &p);
00209     void set_exclude(const vector<string> &ex);
00211     void set_include(const vector<string> &inc);
00214     void set_save(const string &save);
00216     void set_individual_save(bool b);
00218     void set_separate_layers_save(bool b);
00222     void set_name(const string &name);
00225     void set_label(const string &label);
00229     void set_bbox_woverh(float factor);
00232     void set_nopadded(bool nopadded);
00235     void set_videobox(uint nframes, uint stride);
          // NOTE(review): the translation and step parameters are uint here,
          // while the corresponding members (tjitter_*, sjitter_steps,
          // rjitter_steps) are declared int below.
00238     void set_jitter(uint tjitter_step, uint tjitter_hmin, uint tjitter_hmax,
00239                     uint tjitter_wmin, uint tjitter_wmax,
00240                     uint scale_steps, float scale_min, float scale_max,
00241                     uint rotation_steps, float rotation_range,
00242                     uint njitter);
00249     virtual void set_minvisibility(float minvis);
00251     void set_wmirror();
00252     
          // 'h' and 'w' default to 0.
00255     void save_display(const string &dir, uint h = 0, uint w = 0);
00259     void use_pose();
00263     void use_parts();
00267     void use_parts_only();
          // 'label' defaults to -1.
          // NOTE(review): -1 presumably means "no specific class" -- confirm.
00273     bool full(t_label label = -1);
00277     virtual intg count_total();
00278     
00280     // I/O
00281 
00284     bool load(const string &root);
          // 'save_data' defaults to true.
00287     bool save(const string &root, bool save_data = true);
00288 
00290     // print methods
00291 
00293     void print_classes();
00295     void print_stats();
00296     
00298     // Helper functions
00299 
          // Static: callable without a dataset instance.
00301     static idx<ubyte> build_classes_idx(vector<string> &classes);
00302 
00303   protected:
00304     
00306     // allocation
00307     
00313     bool allocate(intg n, idxdim &d);
00314 
00316     // data manipulation
00317 
          // Virtual hook for adding one sample. Everything after 'class_name' is
          // an optional pointer defaulting to NULL.
00330     virtual bool add_data(midx<Tdata> &d, const t_label label,
00331                           const string *class_name,
00332                           const char *filename = NULL,
00333                           const rect<int> *r = NULL,
00334                           pair<int,int> *center = NULL,
00335                           const rect<int> *visr = NULL,
00336                           const rect<int> *cropr = NULL,
00337                           const vector<object*> *objs = NULL,
00338                           const jitter *jittforce = NULL);
00340     void add_data2(midx<Tdata> &sample, t_label label, const string *class_name,
00341                    const char *filename, const jitter *jitt,
00342                    idx<t_jitter> *js);
00344     void add_label(t_label label, const string *class_name,
00345                    const char *filename, const jitter *jitt,
00346                    idx<t_jitter> *js);
00348     virtual void clear_classes();
00350     virtual bool add_class(const string &class_name);
00352     virtual void set_classes(idx<ubyte> &classidx);
00353 
00355     virtual intg count_samples();
00356 
00360     void split(dataset<Tdata> &ds1, dataset<Tdata> &ds2);
00361 
00362     template <class Toriginal>
00363       bool save_scales(idx<Toriginal> &d, const string &filename);
00364 
          // Two overloads: by label (non-const reference) and by class name.
00367     virtual bool included(t_label &lab);
00370     virtual bool included(const string &class_name);
00371 
00373     // data preprocessing
00374 
          // Returns a midx<Tdata> by value. 'scale' defaults to 0 and all pointer
          // parameters after 'class_name' default to NULL.
          // NOTE(review): 'outr' and 'inr_out' are non-const pointers,
          // presumably output parameters -- confirm.
00386     midx<Tdata> preprocess_data(midx<Tdata> &d, const string *class_name,
00387                                 const char *filename = NULL,
00388                                 const rect<int> *r = NULL, double scale = 0,
00389                                 rect<int> *outr = NULL,
00390                                 pair<int,int> *center = NULL,
00391                                 jitter *jitt = NULL,
00392                                 const rect<int> *visr = NULL,
00393                                 const rect<int> *cropr = NULL,
00394                                 rect<int> *inr_out = NULL);
00395 
00405     void display_added(midx<Tdata> &added, idx<Tdata> &original, 
00406                        const string *class_name, 
00407                        const char *filename = NULL,
00408                        const rect<int> *inr = NULL,  
00409                        const rect<int> *origr = NULL,  
00410                        bool active_sleepd = true,
00411                        pair<int,int> *center = NULL,
00412                        const rect<int> *visr = NULL,
00413                        const rect<int> *cropr = NULL,
00414                        const vector<object*> *objs = NULL,
00415                        const jitter *jitt = NULL,
00416                        idx<t_jitter> *js = NULL,
00417                        uint *woriginal = NULL);
00418 
00420     // Helper functions
00421 
          // Returns a non-const reference to a string.
          // NOTE(review): presumably an element of 'classes' -- confirm.
00423     string& get_class_string(t_label id);
00424 
00426     t_label get_class_id(const string &name);
00427     
00429     void compute_stats();
00430 
          // Same signature as the free function ebl::count_matches declared
          // after this class.
00432     uint count_matches(const string &dir, const string &pattern);
00433 
00436     void process_dir(const string &dir, const string &ext,
00437                      const string &class_name);
00438     
00440     virtual void load_data(const string &fname);
00441 
00444     virtual void compute_random_jitter();
00445     
00446   protected: // (redundant: the section above is already protected)
          // NOTE(review): per-member documentation is not present in this view.
          // The section banners are original; member meanings are not confirmed
          // beyond their names and types.
00447     // data ////////////////////////////////////////////////////////
00448     midx<Tdata>         data;           
00449     idx<t_label>        labels;         
00450     idx<intg>           ids;            
00451     midx<t_jitter>      jitters;        
00452     vector<string>      classes;        
00453     idx<t_label>        classpairs;     
00454     idx<t_label>        deformpairs;    
00455     // data helpers ////////////////////////////////////////////////
00456     uint                height;         
00457     uint                width;          
00458     bool                allocated;      
00459     bool                no_outdims;     
00460     idxdim              outdims;        
00461     idxdim              mindims;        
00462     idxdim              maxdims;        
00463     bool                maxdims_set;    
00464     idxdim              datadims;       
00465     uint                nlayers;        
00466     intg                data_cnt;       
00467     intg                processed_cnt;  
00468     intg                max_data;       
00469     bool                max_data_set;   
00470     intg                total_samples;  
00471     idx<intg>           max_per_class;  
00472     intg                mpc;            
00473     bool                max_per_class_set;      
00474     midx<Tdata>         load_img;       
00475     bool                scale_mode;     
00476     vector<double>      scales;         
00477     bool                interleaved_input; 
00478     bool                load_planar;    
00479     vector<string>      exclude;        
00480     vector<string>      include;        
00481     bool                usepose;        
00482     bool                useparts;       
00483     bool                usepartsonly;   
00484     string              save_mode;      
00485     bool                individual_save; 
00486     bool                separate_layers_save; 
00487     list<string>        images_list;    
00488     bool                wmirror;        
00489     // bbox transformations /////////////////////////////////////////////////
00490     float               bbox_woverh;    
00491     string              force_label;    
00492     bool                nopadded;       
00493     float               minvisibility;
00494     // jitter ///////////////////////////////////////////////////////
          // Signed ints here, although set_jitter() takes uint for its
          // translation and step arguments.
00495     int                 tjitter_step;   
00496     int                 tjitter_hmin;   
00497     int                 tjitter_hmax;   
00498     int                 tjitter_wmin;   
00499     int                 tjitter_wmax;   
00500     int                 sjitter_steps;  
00501     float               sjitter_min;    
00502     float               sjitter_max;    
00503     int                 rjitter_steps;  
00504     float               rjitter;        
00505     uint                njitter;        
00506     bool                bjitter;        
00507     vector<jitter>      random_jitter;  
00508     // names ///////////////////////////////////////////////////////
00509     string              name;           
00510     string              data_fname;     
00511     string              labels_fname;   
00512     string              jitters_fname;  
00513     string              ids_fname;      
00514     string              classes_fname;  
00515     string              classpairs_fname;       
00516     string              deformpairs_fname;      
00517     // directories /////////////////////////////////////////////////
00518     string              inroot;         
00519     string              outdir;         
00520     string              outtmp;         
00521     string              extension;      
00522     // display /////////////////////////////////////////////////////
00523     bool                display_extraction;     
00524     bool                display_result; 
00525     bool                sleep_display;  
00526     uint                sleep_delay;    
00527     bool                bsave_display;
00528     string              save_display_dir;
00529     // stats ///////////////////////////////////////////////////////
00530     uint                nclasses;       
00531     idx<intg>           class_tally;    
00532     idx<intg>           add_tally;      
00533     uint                add_errors;     
00534     timer               xtimer;         
00535     // preprocessing ///////////////////////////////////////////////
00536     bool                do_preprocessing;       
00537     vector<resizepp_module<fs(Tdata)>*> ppmods;   // raw pointers; ownership not visible here
00538     string              pp_names;       
00539     rect<int>           original_bbox;  
00540     vector<double>      fovea; 
00541     // videobox ///////////////////////////////////////////////
00542     bool                do_videobox; 
00543     uint                videobox_nframes; 
00544     uint                videobox_stride;
00545   };
00546   
00548   // Helper functions
00549   
00551   EXPORT void build_fname(string &ds_name, const char *fname, string &fullname);
        // ^ NOTE(review): 'fullname' is a non-const reference, presumably the
        //   output built from 'ds_name' and 'fname' -- confirm in the definition.
00552 
        // Free-function counterpart of dataset::count_matches (same signature).
        // NOTE(review): presumably counts entries of 'dir' matching 'pattern' --
        // confirm in the definition.
00555   EXPORT uint count_matches(const string &dir, const string &pattern);
00556 
00558   // loading errors
00559   
00562   template <typename T>
00563     bool loading_error(idx<T> &mat, string &fname);
        // The loading helpers above and below all take a matrix and a file name
        // by non-const reference and return bool. loading_error and
        // loading_warning are overloaded for idx<T> and midx<T>;
        // loading_nowarning is declared for idx<T> only.
        // NOTE(review): presumably they load 'fname' into 'mat' and differ in how
        // a failure is reported (error / warning / silent) -- confirm in the
        // definitions.
00566   template <typename T>
00567     bool loading_error(midx<T> &mat, string &fname);
00570   template <typename T>
00571     bool loading_warning(idx<T> &mat, string &fname);
00574   template <typename T>
00575     bool loading_warning(midx<T> &mat, string &fname);
00578   template <typename T>
00579     bool loading_nowarning(idx<T> &mat, string &fname);
00580 
00581 } // end namespace ebl
00582 
00583 #include "dataset.hpp"
00584 
00585 #endif /* DATASET_H_ */