libeblearntools
/home/rex/ebltrunk/tools/libeblearntools/include/pascalbg_dataset.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2009 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source 
00014  *       Initiative (http://www.opensource.org) must display the 
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef PASCALBG_DATASET_HPP_
00034 #define PASCALBG_DATASET_HPP_
00035 
00036 #include <algorithm>
00037 
00038 #include "xml_utils.h"
00039 
00040 #ifdef __BOOST__
00041 #define BOOST_FILESYSTEM_VERSION 2
00042 #include "boost/filesystem.hpp"
00043 #include "boost/regex.hpp"
00044 using namespace boost::filesystem;
00045 using namespace boost;
00046 #endif
00047 
00048 using namespace std;
00049 
00050 namespace ebl {
00051 
00053   // constructors & initializations
00054 
00055   template <class Tdata>
00056   pascalbg_dataset<Tdata>::pascalbg_dataset(const char *name_,
00057                                             const char *inroot_,
00058                                             const char *outdir_,
00059                                             uint max_folders_,
00060                                             bool ignore_diff, bool ignore_trunc,
00061                                             bool ignore_occl,
00062                                             const char *annotations,
00063                                             const char *outtmp_)
00064     : pascal_dataset<Tdata>(name_, inroot_, ignore_diff, ignore_trunc,
00065                             ignore_occl, annotations) {
00066     this->set_outdir(outdir_, outtmp_);
00067     max_folders = max_folders_;
00068     data_cnt = 0;
00069     save_mode = "mat";
00070   }
00071 
00072   template <class Tdata>
00073   pascalbg_dataset<Tdata>::~pascalbg_dataset() {
00074   }
00075 
00077   // data extraction
00078 
00079   template <class Tdata>
00080   bool pascalbg_dataset<Tdata>::extract() {
00081 #ifdef __BOOST__    
00082 #ifdef __XML__    
00083     cout << "Extracting samples from files into dataset..." << endl;
00084     // adding data to dataset using all xml files in annroot
00085     path p(annroot);
00086     if (!exists(p))
00087       eblerror("Annotation path " << annroot << " does not exist.");
00088     xtimer.start();
00089     processed_cnt = 0;
00090     // find all xml files recursively (and randomize list)
00091     list<string> *files = find_fullfiles(annroot, XML_PATTERN, NULL, false, 
00092                                          true, true);
00093     if (!files || files->size() == 0)
00094       eblerror("no xml files found in " << annroot << " using file pattern "
00095                << XML_PATTERN);
00096     cout << "Found " << files->size() << " xml files." << endl;
00097     for (list<string>::iterator i = files->begin(); i != files->end(); ++i) {
00098       this->process_xml(*i);
00099       processed_cnt++;
00100       if (this->full())
00101         break;
00102     }
00103     cout << "Extracted and saved " << data_cnt;
00104     cout << " background patches from dataset." << endl;
00105     cout << "Extraction time: " << xtimer.elapsed() << endl;
00106     if (files) delete files;
00107 #endif /* __XML__ */
00108 #endif /* __BOOSt__ */
00109     return true;
00110   }
00111 
00112 #ifdef __BOOST__ // disable some derived methods if BOOST not available
00113 #ifdef __XML__ // disable some derived methods if XML not available
00114 
00116   // process xml
00117 
00118   // Note: the difficult flag is ignored, so that we don't take
00119   // background patches even in difficult bounding boxes.
00120   template <class Tdata>
00121   bool pascalbg_dataset<Tdata>::process_xml(const string &xmlfile) {
00122     string image_filename, image_fullname, folder;
00123     vector<rect<int> > bboxes;
00124     string obj_classname, pose;
00125     bool pose_found = false;
00126     Node::NodeList::iterator oiter;
00127       
00128     // parse xml file
00129     try {
00130       DomParser parser;
00131       //    parser.set_validate();
00132       parser.parse_file(xmlfile);
00133       if (parser) {
00134         // initialize root node and list
00135         const Node* pNode = parser.get_document()->get_root_node();
00136         Node::NodeList list = pNode->get_children();
00137         // get image filename
00138         for(Node::NodeList::iterator iter = list.begin();
00139             iter != list.end(); ++iter) {
00140           if (!strcmp((*iter)->get_name().c_str(), "filename")) {
00141             xml_get_string(*iter, image_filename);
00142           } else if (!strcmp((*iter)->get_name().c_str(), "folder")) {
00143             xml_get_string(*iter, folder);
00144           }
00145         }
00146         image_fullname = imgroot;
00147         if (!folder.empty())
00148           image_fullname << "/" << folder << "/";
00149         image_fullname += image_filename;
00150         // include folder into filename to avoid conflicts
00151         if (!folder.empty()) { 
00152           string tmp;
00153           tmp << folder << "_" << image_filename;
00154           tmp = string_replace(tmp, "/", "_");
00155           image_filename = tmp;
00156         }
00157         // parse all objects in image
00158         for(Node::NodeList::iterator iter = list.begin();
00159             iter != list.end(); ++iter) {
00160           if (!strcmp((*iter)->get_name().c_str(), "object")) {
00161             // get object's properties
00162             Node::NodeList olist = (*iter)->get_children();
00163             for(oiter = olist.begin(); oiter != olist.end(); ++oiter) {
00164               if (!strcmp((*oiter)->get_name().c_str(), "name"))
00165                 xml_get_string(*oiter, obj_classname);
00166               else if (!strcmp((*oiter)->get_name().c_str(), "pose")) {
00167                 xml_get_string(*oiter, pose);
00168                 pose_found = true;
00169               }
00170             }
00171             // add object's bbox
00172             if (!usepartsonly) {
00173               // add object's class to dataset
00174               if (usepose && pose_found) { // append pose to class name
00175                 obj_classname += "_";
00176                 obj_classname += pose;
00177               }
00178               if (dataset<Tdata>::included(obj_classname)) {
00179                 bboxes.push_back(get_object(*iter));
00180               }
00181             }
00183             // parts
00184             if (useparts || usepartsonly) {
00185               string part_classname;
00186       
00187               // add part's class to dataset
00188               for(oiter = olist.begin();oiter != olist.end(); ++oiter) {
00189                 if (!strcmp((*oiter)->get_name().c_str(), "part")) {
00190                   // get part's name
00191                   Node::NodeList plist = (*oiter)->get_children();
00192                   for(Node::NodeList::iterator piter = plist.begin();
00193                       piter != plist.end(); ++piter) {
00194                     if (!strcmp((*piter)->get_name().c_str(), "name")) {
00195                       xml_get_string(*piter, part_classname);
00196                       // found a part and its name, add it
00197                       if (usepose && pose_found) {
00198                         part_classname += "_";
00199                         part_classname += pose;
00200                       }
00201                       if (dataset<Tdata>::included(part_classname)) {
00202                         bboxes.push_back(get_object(*oiter));
00203                       }
00204                     }
00205                   }
00206                 }
00207               }
00208             }
00209           }
00210         }
00211       }
00212     } catch (const std::exception& ex) {
00213       cerr << "error: Xml exception caught: " << ex.what() << endl;
00214       return false;
00215     } catch (const char *err) {
00216       cerr << "error: " << err << endl;
00217       return false;
00218     }
00219     try {
00220       // load image 
00221       idx<ubyte> img = load_image<ubyte>(image_fullname);
00222       // extract patches given image and bounding boxes
00223       process_image(img, bboxes, image_filename);
00224     } catch(string &err) {
00225       cerr << "error: failed to add " << image_fullname;
00226       cerr << ": " << endl << err << endl;
00227       add_errors++;
00228     }
00229     return true;
00230   }
00231   
00233   // process 1 object of an xml file
00234 
00235   template <class Tdata>
00236   rect<int> pascalbg_dataset<Tdata>::get_object(Node* onode) {
00237     unsigned int xmin = 0, ymin = 0, xmax = 0, ymax = 0;
00238     
00239     // parse object node
00240     Node::NodeList list = onode->get_children();
00241     for(Node::NodeList::iterator iter = list.begin();
00242         iter != list.end(); ++iter) {
00243       // parse bounding box
00244       if (!strcmp((*iter)->get_name().c_str(), "bndbox")) {
00245         Node::NodeList blist = (*iter)->get_children();
00246         for(Node::NodeList::iterator biter = blist.begin();
00247             biter != blist.end(); ++biter) {
00248           // save xmin, ymin, xmax and ymax
00249           if (!strcmp((*biter)->get_name().c_str(), "xmin"))
00250             xmin = xml_get_uint(*biter);
00251           else if (!strcmp((*biter)->get_name().c_str(), "ymin"))
00252             ymin = xml_get_uint(*biter);
00253           else if (!strcmp((*biter)->get_name().c_str(), "xmax"))
00254             xmax = xml_get_uint(*biter);
00255           else if (!strcmp((*biter)->get_name().c_str(), "ymax"))
00256             ymax = xml_get_uint(*biter);
00257         }
00258       } // else get object class name
00259     }
00260     rect<int> r(ymin, xmin, ymax - ymin, xmax - xmin);
00261     return r;
00262   }
00263   
00265   // process object's image
00266 
00267   template <class Tdata>
00268   void pascalbg_dataset<Tdata>::
00269   process_image(idx<ubyte> &img, vector<rect<int> >& bboxes,
00270                 const string &image_filename) {
00271     vector<rect<int> > patch_bboxes;
00272     vector<rect<int> >::iterator ibb;
00273     idxdim d(img);
00274     ostringstream fname;
00275     bool overlap;
00276     
00277     // for each scale, find patches and save them
00278     for (vector<double>::iterator i = scales.begin(); i != scales.end(); ++i) {
00279       patch_bboxes.clear();
00280       // rescale original bboxes
00281 //       double ratio = std::max(outdims.dim(0) / (double) img.dim(0),
00282 //                            outdims.dim(1) / (double) img.dim(1)) * *i;
00283       double ratio = *i;
00284       // extract all non overlapping patches with dimensions outdims that
00285       // do not overlap with bounding boxes
00286       rect<int> patch(0, 0, outdims.dim(0), outdims.dim(1));
00287       patch = patch * ratio;
00288       for (patch.h0 = 0; patch.h0 + patch.height <= img.dim(0);
00289            patch.h0 += patch.height) {
00290         for (patch.w0 = 0; patch.w0 + patch.width <= img.dim(1);
00291              patch.w0 += patch.width) {
00292           // test if patch overlaps with any bounding box or is outside of image
00293           overlap = false;
00294           for (ibb = bboxes.begin(); ibb != bboxes.end(); ++ibb) {
00295             if (patch.overlap(*ibb)) {
00296               overlap = true;
00297               break ;
00298             }
00299           }
00300           if (!overlap) {
00301             // push patch to list of extracted patches
00302             patch_bboxes.push_back(patch);
00303           }
00304         }
00305       }
00306       fname.str("");
00307       fname << image_filename << "_scale" << *i;
00308       if (patch_bboxes.size() == 0)
00309         cout << "No background patches could be extracted at scale " 
00310              << *i << endl;
00311       else {
00312         save_patches(img, image_filename, patch_bboxes, bboxes,
00313                      outtmp, max_folders, fname.str());
00314       }
00315     }
00316   }
00317   
00319   // save patches
00320 
00321   template <class Tdata>
00322   void pascalbg_dataset<Tdata>::save_patches(idx<ubyte> &im, const string &image_filename,
00323                                              vector<rect<int> > &patch_bboxes,
00324                                              vector<rect<int> > &objs_bboxes,
00325                                              const string &outd,
00326                                              uint max_folders,
00327                                              const string &filename) {
00328     ostringstream folder, fname;
00329     string cname = "background";
00330     rect<int> inr;
00331     // change image type from ubyte to Tdata
00332     idx<Tdata> img(im.get_idxdim());
00333     idx_copy(im, img);
00334     try {
00335       mkdir_full(outd.c_str());
00336       uint i;
00337       // shuffle randomly vector of patches to avoid taking top left corner
00338       // as first patch every time
00339       random_shuffle(patch_bboxes.begin(), patch_bboxes.end());
00340       // loop on patches
00341       for (i = 0; (i < patch_bboxes.size()) && (i < max_folders); ++i) {
00342         // extract patch
00343         rect<int> p = patch_bboxes[i];
00344         midx<Tdata> patch(1);
00345         patch.set(img, 0);
00346         // TODO: fix nasty memory leak if assigning patch = pp(patch) by deleting
00347         // references in destructor of midx
00348         midx<Tdata> patch2 = this->preprocess_data(patch, &cname, image_filename.c_str(),
00349                                                    &p, 0, NULL, NULL, 
00350                                                    NULL, NULL, NULL, &inr);     
00351         patch.clear();
00352         // create folder if doesn't exist
00353         folder.str("");
00354         folder << outd << "/" << "bg" << i+1 << "/";
00355         mkdir_full(folder.str().c_str());
00356         folder << "/background/";
00357         mkdir_full(folder.str().c_str());
00358         // save patch in folder
00359         // switch saving behavior
00360         fname.str("");
00361         fname << folder.str() << filename << ".bg" << i+1;
00362         if (!strcmp(save_mode.c_str(), "mat")) { // lush matrix mode
00363           fname << MATRIX_EXTENSION;
00364           //      idx<Tdata> patch2 = patch.shift_dim(2, 0);
00365           patch2.shift_dim_internal(2, 0);
00366           if (!save_matrices(patch2, fname.str())) {
00367             patch2.clear();
00368             throw fname.str();
00369           }
00370         } else { // image file mode
00371           eblerror("fix implementation");
00372           // fname << "." << save_mode;
00373           // idx<Tdata> tmp = patch;
00374           // // scale image to 0 255 if preprocessed
00375           // if (strcmp(ppconv_type.c_str(), "RGB")) {
00376           //   idx_addc(tmp, (Tdata) 1.0, tmp);
00377           //   idx_dotc(tmp, (Tdata) 127.5, tmp);
00378           // }
00379           // save_image(fname.str(), tmp, save_mode.c_str());
00380         }
00381         images_list.push_back(fname.str()); // add image to files list
00382         cout << data_cnt++ << ": saved " << fname.str().c_str() 
00383              << " " << patch2 << ", eta: " << xtimer.eta(data_cnt, max_data)
00384              << ", elapsed: " << xtimer.elapsed() << endl;
00385         display_patch(patch2, img, image_filename, cname, p, inr,
00386                       objs_bboxes, patch_bboxes);
00387         // TEMPORARY MEMORY LEAK FIX (use smart srg pointer to clear
00388         // automatically on object deletion)
00389         patch2.clear();
00390       }
00391 //       if (i < patches.size()) // reached max_folders, fill-up last one
00392 //      for ( ; i < patches.size(); ++i) {
00393 //        // save patch in folder
00394 //        fname.str("");
00395 //        fname << folder.str() << filename << ".bg" << i+1 << ".mat";
00396 //        if (!save_matrix(patches[i], fname.str()))
00397 //          throw fname.str();
00398 //        cout << data_cnt++ << ": saved " << fname.str().c_str() << endl;
00399 //      }
00400     } catch (const string &err) {
00401       cerr << "error: failed to save patch in " << err << endl;
00402     }
00403   }
00404 
00405   template <class Tdata>
00406   void pascalbg_dataset<Tdata>::
00407   display_patch(midx<Tdata> &patch, idx<Tdata> &img,
00408                 const string &image_filename, const string &cname,
00409                 rect<int> &pbbox, rect<int> &r, vector<rect<int> > &objs_bboxes,
00410                 vector<rect<int> > &patch_bboxes) {
00411 #ifdef __GUI__
00412       if (display_extraction) {
00413         disable_window_updates();
00414         // display
00415         //idx<Tdata> im3 = patch.shift_dim(2, 0);
00416         uint h = 47, w = 0;
00417         display_added(patch, img, &cname, image_filename.c_str(), NULL,
00418                       &r, false, NULL, NULL, NULL, NULL, NULL, NULL, &w);
00419         // draw patch bboxes
00420         vector<rect<int> >::iterator ibb;
00421         for (ibb = patch_bboxes.begin(); ibb != patch_bboxes.end(); ++ibb)
00422           draw_box(h + ibb->h0, w + ibb->w0,
00423                    ibb->height, ibb->width, 0, 255, 0);
00424         // draw objects bboxes
00425         for (ibb = objs_bboxes.begin(); ibb != objs_bboxes.end(); ++ibb)
00426           draw_box(h + ibb->h0, w + ibb->w0,
00427                    ibb->height, ibb->width, 255, 255, 0);
00428         // draw requested patch bbox
00429         draw_box(h + pbbox.h0, w + pbbox.w0,
00430                  pbbox.height, pbbox.width, 0, 0, 255);
00431         enable_window_updates();
00432         if (sleep_display)
00433           millisleep((long) sleep_delay);
00434       }
00435 #endif
00436   }
00437 
00438 
00439 #endif /* __XML__ */
00440 #endif /* __BOOST__ */
00441 
00442 } // end namespace ebl
00443 
00444 #endif /* PASCALBG_DATASET_HPP_ */