libeblearntools
|
00001 /*************************************************************************** 00002 * Copyright (C) 2009 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef DATASET_H_ 00034 #define DATASET_H_ 00035 00036 #define MKDIR_RIGHTS 0755 00037 00038 #define DATASET_SAVE "dataset" 00039 #define DYNSET_SAVE "dynset" 00040 00041 #include "libidx.h" 00042 #include "libeblearn.h" 00043 00044 typedef int t_label; 00045 00046 // jitter 00047 typedef float t_jitter; 00048 #define JITTERS 4 // number of jitter variables 00049 00050 namespace ebl { 00051 00053 class EXPORT object : public rect<int> { 00054 public: 00055 // constructors/destructors //////////////////////////////////////////////// 00056 00058 object(uint id); 00060 virtual ~object(); 00061 00062 // accessors /////////////////////////////////////////////////////////////// 00063 00065 virtual void set_rect(int xmin, int ymin, int xmax, int ymax); 00067 virtual void set_visible(int xmin, int ymin, int xmax, int ymax); 00069 virtual void set_centroid(int x, int y); 00070 00071 // members ///////////////////////////////////////////////////////////////// 00072 uint id; 00073 rect<int> *visible; 00074 pair<int,int> *centroid; 00075 string name; 00076 bool difficult; 00077 bool truncated; 00078 bool occluded; 00079 string pose; 00080 std::vector<object*> parts; 00081 bool ignored; 00082 }; 00083 00085 class EXPORT jitter { 00086 public: 00089 jitter(float h_, float w_, float s_, float r_, int spatial_norm = 1); 00094 jitter(rect<float> &context, rect<float> &jit, int spatial_norm = 1); 00096 jitter(); 00098 virtual ~jitter(); 00101 template <typename T> 00102 rect<T> get_rect(const rect<T> &r, float ratio = 1.0); 00104 const idx<t_jitter>& get_jitter_vector() const; 00106 void set(const idx<t_jitter> &j); 00107 00108 // members ///////////////////////////////////////////////////////////////// 00109 public: 00110 float h; 00111 float w; 00112 float s; 00113 float r; 00114 private: 00115 idx<t_jitter> jitts; 00116 }; 00117 00120 template <class Tdata> class dataset { 00121 public: 00122 00124 // constructors/allocation 00125 00130 dataset(const char *name, const char *inroot = NULL); 00131 00133 virtual ~dataset(); 00134 00139 bool alloc(intg max = 0); 00140 00142 // data manipulation 00143 00145 virtual bool extract(); 00147 virtual void extract_statistics(); 00151 bool split_max_and_save(const char *name1, const char *name2, 00152 intg max, const string &outroot); 00155 void split_max(dataset<Tdata> &ds1, dataset<Tdata> &ds2, intg max); 00158 void merge_and_save(const char *name1, const char *name2, 00159 const string &outroot); 00161 void shuffle(); 00163 virtual void set_unique_label(const string &class_name); 00164 00166 // data preprocessing 00167 00169 void set_planar_loading(); 00170 00172 // accessors 00173 00175 const idxdim &get_sample_outdim(); 00177 intg size(); 00179 t_label get_label_from_class(const string &class_name); 00181 void set_display(bool display); 00183 void set_sleepdisplay(uint delay); 00185 void set_preprocessing(vector<resizepp_module<fs(Tdata)>*> &p); 00188 virtual void set_outdims(const idxdim &d); 00190 virtual void set_outdir(const char *s, const char *tmp = NULL); 00193 void set_mindims(const idxdim &d); 00195 void set_maxdims(const idxdim &d); 00198 void set_scales(const vector<double> &sc, const string &od); 00201 void set_fovea(const vector<double> &scales); 00203 void set_max_per_class(intg max); 00205 void set_max_data(intg max); 00207 void set_image_pattern(const string &p); 00209 void set_exclude(const vector<string> &ex); 00211 void set_include(const vector<string> &inc); 00214 void set_save(const string &save); 00216 void set_individual_save(bool b); 00218 void set_separate_layers_save(bool b); 00222 void set_name(const string &name); 00225 void set_label(const string &label); 00229 void set_bbox_woverh(float factor); 00232 void set_nopadded(bool nopadded); 00235 void set_videobox(uint nframes, uint stride); 00238 void set_jitter(uint tjitter_step, uint tjitter_hmin, uint tjitter_hmax, 00239 uint tjitter_wmin, uint tjitter_wmax, 00240 uint scale_steps, float scale_min, float scale_max, 00241 uint rotation_steps, float rotation_range, 00242 uint njitter); 00249 virtual void set_minvisibility(float minvis); 00251 void set_wmirror(); 00252 00255 void save_display(const string &dir, uint h = 0, uint w = 0); 00259 void use_pose(); 00263 void use_parts(); 00267 void use_parts_only(); 00273 bool full(t_label label = -1); 00277 virtual intg count_total(); 00278 00280 // I/O 00281 00284 bool load(const string &root); 00287 bool save(const string &root, bool save_data = true); 00288 00290 // print methods 00291 00293 void print_classes(); 00295 void print_stats(); 00296 00298 // Helper functions 00299 00301 static idx<ubyte> build_classes_idx(vector<string> &classes); 00302 00303 protected: 00304 00306 // allocation 00307 00313 bool allocate(intg n, idxdim &d); 00314 00316 // data manipulation 00317 00330 virtual bool add_data(midx<Tdata> &d, const t_label label, 00331 const string *class_name, 00332 const char *filename = NULL, 00333 const rect<int> *r = NULL, 00334 pair<int,int> *center = NULL, 00335 const rect<int> *visr = NULL, 00336 const rect<int> *cropr = NULL, 00337 const vector<object*> *objs = NULL, 00338 const jitter *jittforce = NULL); 00340 void add_data2(midx<Tdata> &sample, t_label label, const string *class_name, 00341 const char *filename, const jitter *jitt, 00342 idx<t_jitter> *js); 00344 void add_label(t_label label, const string *class_name, 00345 const char *filename, const jitter *jitt, 00346 idx<t_jitter> *js); 00348 virtual void clear_classes(); 00350 virtual bool add_class(const string &class_name); 00352 virtual void set_classes(idx<ubyte> &classidx); 00353 00355 virtual intg count_samples(); 00356 00360 void split(dataset<Tdata> &ds1, dataset<Tdata> &ds2); 00361 00362 template <class Toriginal> 00363 bool save_scales(idx<Toriginal> &d, const string &filename); 00364 00367 virtual bool included(t_label &lab); 00370 virtual bool included(const string &class_name); 00371 00373 // data preprocessing 00374 00386 midx<Tdata> preprocess_data(midx<Tdata> &d, const string *class_name, 00387 const char *filename = NULL, 00388 const rect<int> *r = NULL, double scale = 0, 00389 rect<int> *outr = NULL, 00390 pair<int,int> *center = NULL, 00391 jitter *jitt = NULL, 00392 const rect<int> *visr = NULL, 00393 const rect<int> *cropr = NULL, 00394 rect<int> *inr_out = NULL); 00395 00405 void display_added(midx<Tdata> &added, idx<Tdata> &original, 00406 const string *class_name, 00407 const char *filename = NULL, 00408 const rect<int> *inr = NULL, 00409 const rect<int> *origr = NULL, 00410 bool active_sleepd = true, 00411 pair<int,int> *center = NULL, 00412 const rect<int> *visr = NULL, 00413 const rect<int> *cropr = NULL, 00414 const vector<object*> *objs = NULL, 00415 const jitter *jitt = NULL, 00416 idx<t_jitter> *js = NULL, 00417 uint *woriginal = NULL); 00418 00420 // Helper functions 00421 00423 string& get_class_string(t_label id); 00424 00426 t_label get_class_id(const string &name); 00427 00429 void compute_stats(); 00430 00432 uint count_matches(const string &dir, const string &pattern); 00433 00436 void process_dir(const string &dir, const string &ext, 00437 const string &class_name); 00438 00440 virtual void load_data(const string &fname); 00441 00444 virtual void compute_random_jitter(); 00445 00446 protected: 00447 // data //////////////////////////////////////////////////////// 00448 midx<Tdata> data; 00449 idx<t_label> labels; 00450 idx<intg> ids; 00451 midx<t_jitter> jitters; 00452 vector<string> classes; 00453 idx<t_label> classpairs; 00454 idx<t_label> deformpairs; 00455 // data helpers //////////////////////////////////////////////// 00456 uint height; 00457 uint width; 00458 bool allocated; 00459 bool no_outdims; 00460 idxdim outdims; 00461 idxdim mindims; 00462 idxdim maxdims; 00463 bool maxdims_set; 00464 idxdim datadims; 00465 uint nlayers; 00466 intg data_cnt; 00467 intg processed_cnt; 00468 intg max_data; 00469 bool max_data_set; 00470 intg total_samples; 00471 idx<intg> max_per_class; 00472 intg mpc; 00473 bool max_per_class_set; 00474 midx<Tdata> load_img; 00475 bool scale_mode; 00476 vector<double> scales; 00477 bool interleaved_input; 00478 bool load_planar; 00479 vector<string> exclude; 00480 vector<string> include; 00481 bool usepose; 00482 bool useparts; 00483 bool usepartsonly; 00484 string save_mode; 00485 bool individual_save; 00486 bool separate_layers_save; 00487 list<string> images_list; 00488 bool wmirror; 00489 // bbox transformations ///////////////////////////////////////////////// 00490 float bbox_woverh; 00491 string force_label; 00492 bool nopadded; 00493 float minvisibility; 00494 // jitter /////////////////////////////////////////////////////// 00495 int tjitter_step; 00496 int tjitter_hmin; 00497 int tjitter_hmax; 00498 int tjitter_wmin; 00499 int tjitter_wmax; 00500 int sjitter_steps; 00501 float sjitter_min; 00502 float sjitter_max; 00503 int rjitter_steps; 00504 float rjitter; 00505 uint njitter; 00506 bool bjitter; 00507 vector<jitter> random_jitter; 00508 // names /////////////////////////////////////////////////////// 00509 string name; 00510 string data_fname; 00511 string labels_fname; 00512 string jitters_fname; 00513 string ids_fname; 00514 string classes_fname; 00515 string classpairs_fname; 00516 string deformpairs_fname; 00517 // directories ///////////////////////////////////////////////// 00518 string inroot; 00519 string outdir; 00520 string outtmp; 00521 string extension; 00522 // display ///////////////////////////////////////////////////// 00523 bool display_extraction; 00524 bool display_result; 00525 bool sleep_display; 00526 uint sleep_delay; 00527 bool bsave_display; 00528 string save_display_dir; 00529 // stats /////////////////////////////////////////////////////// 00530 uint nclasses; 00531 idx<intg> class_tally; 00532 idx<intg> add_tally; 00533 uint add_errors; 00534 timer xtimer; 00535 // preprocessing /////////////////////////////////////////////// 00536 bool do_preprocessing; 00537 vector<resizepp_module<fs(Tdata)>*> ppmods; 00538 string pp_names; 00539 rect<int> original_bbox; 00540 vector<double> fovea; 00541 // videobox /////////////////////////////////////////////// 00542 bool do_videobox; 00543 uint videobox_nframes; 00544 uint videobox_stride; 00545 }; 00546 00548 // Helper functions 00549 00551 EXPORT void build_fname(string &ds_name, const char *fname, string &fullname); 00552 00555 EXPORT uint count_matches(const string &dir, const string &pattern); 00556 00558 // loading errors 00559 00562 template <typename T> 00563 bool loading_error(idx<T> &mat, string &fname); 00566 template <typename T> 00567 bool loading_error(midx<T> &mat, string &fname); 00570 template <typename T> 00571 bool loading_warning(idx<T> &mat, string &fname); 00574 template <typename T> 00575 bool loading_warning(midx<T> &mat, string &fname); 00578 template <typename T> 00579 bool loading_nowarning(idx<T> &mat, string &fname); 00580 00581 } // end namespace ebl 00582 00583 #include "dataset.hpp" 00584 00585 #endif /* DATASET_H_ */