libeblearn
/home/rex/ebltrunk/core/libeblearn/include/detector.h
00001 /***************************************************************************
00002  *   Copyright (C) 2010 by Pierre Sermanet *
00003  *   pierre.sermanet@gmail.com *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef DETECTOR_H_
00034 #define DETECTOR_H_
00035 
00036 #include "libidx.h"
00037 #include "ebl_states.h"
00038 #include "ebl_arch.h"
00039 #include "ebl_answer.h"
00040 #include "ebl_merge.h"
00041 #include "ebl_preprocessing.h"
00042 #include "bbox.h"
00043 #include "nms.h"
00044 
00045 using namespace std;
00046 
00047 namespace ebl {
00048 
00050   // detector
00051 
// t_scaling -- strategy a detector uses to choose the set of input
// resolutions it runs at (consumed by detector::set_scaling_type() and
// detector::compute_scales()). The per-value notes below are inferred from
// the detector's method and member names and are unconfirmed.
// NOTE(review): every enumerator has an explicit value; keep the numbers
// stable in case they are stored or exchanged as plain integers elsewhere.
enum t_scaling {
  MANUAL         = 0, // presumably: resolutions given explicitly (manual_scales)
  SCALES         = 1, // presumably: a list of scale factors (scale_factors)
  NSCALES        = 2, // presumably: a fixed number of resolutions (nscales)
  SCALES_STEP    = 3, // presumably: a constant step between resolutions
  ORIGINAL       = 4, // presumably: the input's original resolution only
  NETWORK        = 5, // presumably: the network's own input size (netdim)
  SCALES_STEP_UP = 6  // presumably: stepped resolutions, including upscaling
};
00063 
00064   template <typename T, class Tstate = fstate_idx<T> >
00065     class detector {
          //! detector: runs the network `thenet` over several input
          //! resolutions (multi_res_fprop) and turns its outputs into
          //! `bboxes` (fprop() returns a bboxes&).
          //!
          //! Lifetimes (from the member declarations below): `thenet`, `mout`
          //! and `merr` are reference members, so the network and the two
          //! streams given to the constructor must outlive the detector.
          //!
          //! NOTE(review): the class declares a destructor and holds raw
          //! pointers (resizepp, input, tmp, minput, pnms, answer) but declares
          //! no copy constructor or copy assignment. The reference members make
          //! copy-assignment unusable, but an implicit copy constructor still
          //! exists and would copy those raw pointers. If ~detector() deletes
          //! any of them (resizepp_delete suggests it may delete resizepp), a
          //! copied detector causes a double delete -- treat detector as
          //! non-copyable.
00066   public:
00067
00069     // constructors
00070
          //! Constructor.
          //!   thenet       network module; presumably bound to the reference
          //!                member `thenet`, so it must outlive the detector.
          //!   labels       class names (member `labels` is a vector<string>
          //!                value, so presumably a copy is kept).
          //!   answer       optional answer module, default NULL
          //!                (compare member `answer`).
          //!   resize       optional resizing module, default NULL
          //!                (compare members resizepp / resizepp_delete).
          //!   background   optional background class name, default NULL
          //!                (compare set_bgclass()).
          //!   out, err     message streams, default std::cout / std::cerr;
          //!                presumably bound to the reference members
          //!                mout / merr, so they must outlive the detector.
          //!   adapt_scales default false (compare member adapt_scales).
00084     detector(module_1_1<T,Tstate> &thenet, vector<string> &labels,
00085              answer_module<T,T,T,Tstate> *answer = NULL,
00086              resizepp_module<T,Tstate> *resize = NULL,
00087              const char *background = NULL,
00088              std::ostream &out = std::cout, std::ostream &err = std::cerr,
00089              bool adapt_scales = false);
00090
          //! Virtual, so a subclass is destroyed correctly through a detector*.
00092     virtual ~detector();
00093
00095     // configuration
00096
          //! Scaling strategy selection (see t_scaling); the chosen strategy is
          //! presumably stored in member `restype`.
00098     void set_scaling_original();
00100     void set_scaling_type(t_scaling type);
          //! Resolution setters.
          //! NOTE(review): set_resolutions(int) and
          //! set_resolutions(double, double = 1.0, double = 1.0) both accept an
          //! integer argument through a standard conversion of the same rank,
          //! so passing an unsigned or long (e.g. a uint variable) is an
          //! ambiguous call -- cast to int or double explicitly. Also note
          //! set_resolution (singular, one double factor) is a different name.
00102     void set_resolutions(const midxdim &scales);
00104     void set_resolutions(const vector<double> &factors);
00106     void set_resolution(double factor);
00110     void set_resolutions(int resolutions);
00114     void set_resolutions(idx<uint> &resolutions);
          //! NOTE(review): takes floats, while members hzpad / wzpad are uint;
          //! confirm in the definition how the values are converted/stored.
00117     void set_zpads(float hzpad, float wzpad);
00125     void set_resolutions(double scales_steps, double max_scale = 1.0,
00126                          double min_scale = 1.0);
          //! Index of class `name` among the labels; the value returned for an
          //! unknown name is not visible from this header.
00128     int get_class_id(const string &name);
00130     void set_bgclass(const char *bg = NULL);
          //! Returns bool -- presumably whether `mask` named a usable class.
00135     bool set_mask_class(const char *mask);
00137     void set_silent();
          //! Resolution bounds (presumably stored in min_size / max_size).
00142     void set_min_resolution(uint min_size);
00147     void set_max_resolution(uint max_size);
00149     void set_raw_thresholds(vector<float> &t);
00150
          //! Non-maximum-suppression settings (presumably configure `pnms`).
          //! Every argument has a default; `pre_threshold` shares its name with
          //! the member of the same name.
00153     void set_nms(t_nms type = nms_overlap, float pre_threshold = 0.0,
00154                  float post_threshold = 0.0,
00155                  float pre_hfact = 1.0, float pre_wfact = 1.0,
00156                  float post_hfact = 1.0, float post_wfact = 1.0,
00157                  float woverh = 1.0, float max_overlap = 1.0,
00158                  float max_hcenter_dist = 0.0, float max_wcenter_dist = 0.0,
00159                  float vote_max_overlap = 1.0,float vote_max_hcenter_dist = 0.0,
00160                  float vote_max_wcenter_dist = 0.0);
00161
00163     void set_scaler_mode(bool set);
00165     void set_smoothing(uint type);
00166
          //! Takes two caller-owned states by non-const reference; whether they
          //! are kept after the call is not visible here (compare members
          //! mem_optimization, optimization_swap, keep_inputs).
00173     void set_mem_optimization(Tstate &in, Tstate &out,
00174                               bool keep_inputs = false);
00176     void set_netdim(idxdim &d);
00182     void set_outputs_dumping(const char *name);
00184     void set_bboxes_off();
          //! Non-const reference; type matches member `labels`.
00186     vector<string>& get_labels();
00188     void set_ignore_outsiders();
00192     void set_corners_inference(uint type);
00196     void set_bbox_decision(uint type);
00197     void set_bbox_scalings(mfidxdim &scalings);
00198
00200     // execution
00201
          //! Runs detection on `img` (fname optional, default NULL). Returns a
          //! reference -- presumably to a member (raw_bboxes / pruned_bboxes)
          //! that later calls overwrite; copy the result if it must persist.
00207     template <class Tin> bboxes& fprop(idx<Tin> &img, const char *fname = NULL);
00209     void fprop_nms(bboxes &in, bboxes &out);
          //! Reference return; type matches member `odetections`.
00212     vector<idx<T> >& get_originals();
00215     midx<T> get_preprocessed(const bbox &b);
          //! Reference return; type matches member `ppdetections`.
00223     svector<midx<T> >& get_preprocessed(bboxes &out, uint n = 0,
00224                                         bool diverse = false,
00225                                         uint pre_diverse_max = 100);
00234     svector<midx<T> >& get_preprocessed(bboxes &in, bboxes &out, uint n = 0,
00235                                         bool diverse = false,
00236                                         uint pre_diverse_max = 100);
00239     idx<T> get_mask(string &classname);
00241     uint get_total_saved();
          //! Returns string& (type matches member `save_dir`).
00248     string& set_save(const string &directory, uint nmax = 0,
00249                      bool diversity = false);
00251     void init(idxdim &dinput, const char *frame_name = NULL);
00252
00253   protected:
00254     // scales methods //////////////////////////////////////////////////////////
00255
          //! Fills `scales` for strategy `type`; the compute_resolutions*
          //! overloads below take the inputs each strategy needs (a count, a
          //! list of factors, or a step).
00262     void compute_scales(midxdim &scales, idxdim &netdim, idxdim &mindim,
00263                         idxdim &maxdim, idxdim &indim, t_scaling type,
00264                         uint nscales, double scales_step,
00265                         const char *frame_name = NULL);
00268     void compute_resolutions(midxdim &scales,
00269                              idxdim &mindim, idxdim &maxdim, uint nscales);
00272     void compute_resolutions(midxdim &scales,
00273                              idxdim &indim, vector<double> &scale_factors);
00278     void compute_resolutions(midxdim &scales, idxdim &mindim,
00279                              idxdim &maxdim, double scales_step);
00284     void compute_resolutions_up(midxdim &scales, idxdim &indim,
00285                                 idxdim &mindim, idxdim &maxdim,
00286                                 double scales_step);
00290     void validate_resolutions();
00291
00292     // bboxes operations ///////////////////////////////////////////////////////
00293
00295     void smooth_outputs();
00298     void update_merge_alignment();
00301     void get_corners(mstate<Tstate> &outputs);
          //! NOTE(review): parameter `bboxes` reuses the type name `bboxes`;
          //! legal here, but it hides the type for the rest of the declaration.
00304     void extract_bboxes(T threshold, bboxes &bboxes);
00307     void save_bboxes(bboxes &bboxes, const string &dir,
00308                      const char *frame_name = NULL);
00312     void add_class(const char *name);
00313
00314     // processing methods //////////////////////////////////////////////////////
00315
00320     template <class Tin> void prepare(idx<Tin> &img, const char *fname = NULL);
00326     void prepare_scale(uint i);
00328     void multi_res_fprop();
00329
00330     // member variables ////////////////////////////////////////////////////////
          // NOTE(review): declaration order below is the initialization order;
          // keep it in sync with the constructor's initializer list.
00331   protected:
00332     module_1_1<T,Tstate>        &thenet; //!< Network; a reference, must outlive *this.
00333     resizepp_module<T,Tstate>   *resizepp; //!< Raw pointer; ownership presumably tracked by resizepp_delete.
00334     bool                 resizepp_delete; //!< Presumably true when the destructor must delete resizepp.
00335     idx<T>               image;
00336     double               contrast;
00337     double               brightness;
00338     idx<float>           sizes;
00339     fstate_idx<T>        finput; 
00340     Tstate              *input;        //!< Raw pointer; ownership not visible here.
00341     mstate<Tstate>       output;       
00342     Tstate              *tmp;           //!< Raw pointer; ownership not visible here.
00343     Tstate              *minput; //!< Raw pointer; ownership not visible here.
00344     svector<mstate<Tstate> > ppinputs; 
00345     svector<mstate<Tstate> > outputs; 
00346     vector<string>       labels; //!< Class names (returned by get_labels()).
00347   protected:
00348     // dimensions //////////////////////////////////////////////////////////////
00349     idxdim               indim; 
00350     idxdim               netdim; 
00351     bool                 netdim_fixed; 
00352     // bboxes //////////////////////////////////////////////////////////////////
00353     vector<rect<int> >   original_bboxes; 
00354     int                  bgclass;
00355     int                  mask_class;
00356     idx<T>               mask;
00357     nms                  *pnms; //!< Raw pointer; presumably configured by set_nms().
00358     // scales //////////////////////////////////////////////////////////////
00359     midxdim              scales; 
00360     midxdim              actual_scales; 
00361     midxdim              manual_scales; 
00362     vector<double>       scale_factors; 
00363     uint                 nscales; 
00364     double               scales_step;
00365     double               min_scale;
00366     double               max_scale;
00367     t_scaling            restype; //!< Current scaling strategy (see t_scaling).
00368     // saving //////////////////////////////////////////////////////////////
00369     bool                 silent; 
00370     bool                 save_mode; 
00371     string               save_dir; 
00372     vector<uint>         save_counts; 
00373     bboxes               raw_bboxes; 
00374     bboxes               pruned_bboxes; 
00375     uint                 min_size; 
00376     uint                 max_size; 
00377     vector<idx<T> >      odetections; //!< Same type as get_originals()' return.
00378     svector<midx<T> >    ppdetections; //!< Same type as get_preprocessed(bboxes&...)'s return.
00379     bool                 bodetections; 
00380     bool                 bppdetections; 
00381     uint                 save_max_per_frame; 
00382     bool                 diverse_ordering; 
00383     bool                 mem_optimization; 
00384     bool                 optimization_swap; 
00385     bool                 keep_inputs; 
00386     uint                 hzpad; //!< uint, although set_zpads() takes a float.
00387     uint                 wzpad; //!< uint, although set_zpads() takes a float.
00388     // printing ////////////////////////////////////////////////////////////////
00389     std::ostream         &mout; //!< Output stream reference; must outlive *this.
00390     std::ostream         &merr; //!< Error stream reference; must outlive *this.
00391     // smoothing //////////////////////////////////////////////////////////////
00392     uint                 smoothing_type;
00393     idx<T>               smoothing_kernel;
00394     bool                 initialized;
00395     string               outputs_dump; 
00396     bool                 bboxes_off; 
00397     bool                 adapt_scales; 
00398     bool                 scaler_mode;
00399     answer_module<T,T,T,Tstate> *answer;
00400     mstate<Tstate>       answers; 
00401     bool                 ignore_outsiders; 
00402     uint corners_inference; 
00403     bool corners_infered; 
00404     mfidxdim itl, itr, ibl, ibr; 
00405     mfidxdim pptl, pptr, ppbl, ppbr; 
00406     float    pre_threshold; 
00407     vector<float> raw_thresholds; 
00408     vector<uint>  scale_indices; 
00409     uint bbox_decision; 
00410     mfidxdim bbox_scalings;
00411
00412     // friends /////////////////////////////////////////////////////////////////
          // These class templates may access the protected members above.
00413     template <typename T2, class Tstate2> friend class detector_gui;
00414     template <typename T2> friend class detection_thread;
00415     template <typename T2, class Tstate2> friend class bootstrapping;
00416   };
00417 
00418 } // end namespace ebl
00419 
00420 #include "detector.hpp"
00421 
00422 #endif /* DETECTOR_H_ */