libeblearntools
/home/rex/ebltrunk/tools/libeblearntools/include/netconf.hpp
/***************************************************************************
 *   Copyright (C) 2010 by Pierre Sermanet *
 *   pierre.sermanet@gmail.com *
 *   All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Redistribution under a license not approved by the Open Source
 *       Initiative (http://www.opensource.org) must display the
 *       following acknowledgement in all advertising material:
 *        This product includes software developed at the Courant
 *        Institute of Mathematical Sciences (http://cims.nyu.edu).
 *     * The names of the authors may not be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ***************************************************************************/

#ifndef NETCONF_HPP_
#define NETCONF_HPP_

namespace ebl {

  template <typename T>
  bool get_param2(configuration &conf, const string &module_name,
                  const string &var_name, T &p, intg thickness,
                  intg noutputs) {
    string pn = module_name; pn << "_" << var_name;
    // check that variable is present
    if (!conf.exists(pn)) {
      // not found
      cerr << "error: required parameter " << pn << " not found" << endl;
      return false;
    }
    std::string val_in = conf.get_string(pn);
    if (!val_in.compare("thickness"))
      p = (T) thickness; // special value
    else if (!val_in.compare("noutputs"))
      p = (T) noutputs; // special value
    else // get value from configuration
      conf.get(p, pn);
    return true;
  }
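  // Usage sketch for get_param2 (hypothetical configuration values; names are
  // illustrative, not part of the library): for a module named "linear0", a
  // conf entry "linear0_in = thickness" resolves to the current thickness and
  // "linear0_out = noutputs" to the number of network outputs:
  //
  //   intg lin = 0, lout = 0;
  //   if (get_param2(conf, "linear0", "in", lin, thick, nout) &&
  //       get_param2(conf, "linear0", "out", lout, thick, nout)) {
  //     // lin == thick when conf contains "linear0_in = thickness"
  //   }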

  template <typename T>
  bool get_param(configuration &conf, const string &module_name,
                 const string &var_name, T &p, bool optional = false) {
    string pn = module_name; pn << "_" << var_name;
    // check that variable is present
    if (!conf.exists(pn)) {
      // not found
      if (!optional)
        cerr << "error: required parameter " << pn << " not found" << endl;
      return false;
    }
    std::string val_in = conf.get_string(pn);
    conf.get(p, pn);
    return true;
  }
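  // Usage sketch for get_param (names hypothetical): a required lookup fails
  // loudly and returns false, while an optional one fails silently so the
  // caller's default survives:
  //
  //   string skernel;
  //   bool mirror = true; // default kept if "conv0_mirror" is absent
  //   if (!get_param(conf, "conv0", "kernel", skernel)) return NULL;
  //   get_param(conf, "conv0", "mirror", mirror, true); // optional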

  // select network based on configuration
  template <typename T, class Tstate>
  module_1_1<T,Tstate>*
  create_network(parameter<T, Tstate> &theparam, configuration &conf,
                 intg &thick, int nout,
                 const char *varname, bool isbranch, bool narrow,
                 intg narrow_dim, intg narrow_size, intg narrow_offset,
                 vector<layers<T,Tstate>*> *branches,
                 vector<intg> *branches_thick,
                 map<string,module_1_1<T,Tstate>*> *shared_) {
    // if we don't find the generic architecture variable, try the old style
    // way with 'net-type'
    // if (!conf.exists(varname))
    //  return create_network_old(theparam, conf, nout);
    // else, use the arch list
    map<string,module_1_1<T,Tstate>*> *shared = shared_;
    if (!shared) shared = new map<string,module_1_1<T,Tstate>*>;
    list<string> arch = string_to_stringlist(conf.get_string(varname));
    uint arch_size = arch.size();
    layers<T,Tstate>* l =
      new layers<T,Tstate>(true, varname, isbranch, narrow, narrow_dim,
                           narrow_size, narrow_offset);
    // remember thickness output of branches
    if (!branches_thick)
      branches_thick = new vector<intg>;
    // keep list of existing branches so far
    if (!branches)
      branches = new vector<layers<T,Tstate>*>;
    // info
    cout << "Creating a network with " << nout << " outputs and "
         << arch_size << " modules (input thickness is " << thick
         << "): " << conf.get_string(varname) << endl;
    try {
      // loop over each module
      for (uint i = 0; i < arch_size; ++i) {
        cout << varname << " " << i << ": ";
        // get first module name of the list and remove it from list
        string name = arch.front(); arch.pop_front();
        int errid = 0; // id of error thrown by create_module
        module_1_1<T,Tstate> *module = NULL;
        try {
          module = create_module<T,Tstate>
            (name, theparam, conf, nout, thick, *shared,
             branches, branches_thick);
        } catch (int err) { errid = err; }
        // add module to layers, if not null
        if (module) {
          // add the module
          l->add_module(module);
          cout << "Added " << module->describe() << " (#params "
               << theparam.x.nelements() << ", thickness " << thick
               << ")" << endl;
        } else {
          switch (errid) {
          case 1: eblwarn("ignoring module " << name);
            arch_size--; // decrease expected size of architecture
            break ;
          default: eblerror("failed to load module " << name); break ;
          }
        }
      }
      if (isbranch) // remember last thickness for this branch
        branches_thick->push_back(thick);
      if (arch_size != l->size())
        eblerror("Some error occurred when loading modules, expected to load "
                 << arch_size << " modules but only " << l->size()
                 << " were successfully loaded");
      cout << varname << ": loaded " << l->size() << " modules." << endl;
    } eblcatcherror();
    if (!shared_) // shared was allocated here, we can delete it
      delete shared;
    return l;
  }
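  // Usage sketch for create_network (configuration contents hypothetical):
  // the variable named by 'varname' holds a comma-separated module list, and
  // each module then reads its own "<name>_<param>" variables, e.g.:
  //
  //   arch = conv0,tanh0,subs0,linear0
  //   conv0_kernel = 5x5
  //   conv0_stride = 1x1
  //
  //   parameter<T,Tstate> theparam;
  //   intg thick = 1; // input thickness, updated as modules are added
  //   module_1_1<T,Tstate> *net =
  //     create_network<T,Tstate>(theparam, conf, thick, noutputs, "arch",
  //                              false, false, 0, 0, 0, NULL, NULL, NULL);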

  // create an individual module based on configuration
  template <typename T, class Tstate>
  module_1_1<T,Tstate>*
  create_module(const string &name, parameter<T, Tstate> &theparam,
                configuration &conf, int &nout, intg &thick,
                map<string,module_1_1<T,Tstate>*> &shared,
                vector<layers<T,Tstate>*> *branches,
                vector<intg> *branches_thick) {
    string type = strip_last_num(name);
    module_1_1<T,Tstate> *module = NULL;
    // switch on each possible type of module
    // shared //////////////////////////////////////////////////////////////////
    // first check if the module we're loading is shared
    string sshared; bool bshared = false, bshared_exists = false;
    if (get_param(conf, name, "shared", sshared, true))
      bshared = (bool) string_to_int(sshared);
    // check if we already have it in stock
    typename map<string,module_1_1<T,Tstate>*>::iterator i =
      shared.find(name);
    if (i != shared.end()) bshared_exists = true; // module already allocated
    // merge ///////////////////////////////////////////////////////////////
    if (!type.compare("merge")) {
      string type, strbranches;
      if (!get_param(conf, name, "type", type)) return NULL;
      // switch on merging type
      if (type.compare("mflat")) { // all types but mflat
        get_param(conf, name, "branches", strbranches, true);
        // get vector of branch buffers to merge
        vector<mstate<Tstate>**> inputs;
        list<string> b = string_to_stringlist(strbranches);
        for (list<string>::iterator bi = b.begin(); bi != b.end(); bi++) {
          layers<T,Tstate> *branch = NULL;
          for (uint k = 0; k < branches->size(); ++k) {
            if (!strcmp((*branches)[k]->name(), (*bi).c_str())) {
              branch = (*branches)[k];
              thick += (*branches_thick)[k]; // update thickness
              break ;
            }
          }
          if (!branch)
            eblerror(name << " is trying to merge branch " << *bi
                     << " but branch not found");
          inputs.push_back(&(branch->intern_out));
        }
        // switch on merging type
        if (!type.compare("concat")) { // concatenate
          intg concat_dim;
          string sstates;
          vector<vector<uint> > states;
          if (!get_param(conf, name, "dim", concat_dim)) return NULL;
          if (get_param(conf, name, "states", sstates, true)) {
            vector<string> s = string_to_stringvector(sstates, ';');
            EDEBUG("s: " << s);
            for (uint i = 0; i < s.size(); ++i) {
              vector<uint> v = string_to_uintvector(s[i]);
              EDEBUG("v: " << v);
              states.push_back(v);
            }
          }
          // create module
          if (states.size() > 0)
            module = (module_1_1<T,Tstate>*)
              new merge_module<T,Tstate>(states, concat_dim, name.c_str());
          else // old-style with branches (TODO: remove)
            module = (module_1_1<T,Tstate>*)
              new merge_module<T,Tstate>(inputs, concat_dim, name.c_str(),
                                         strbranches.c_str());
        } else if (!type.compare("flat")) { // flatten
          string strides, ins, bstrides, bins;
          if (!get_param(conf, name, "in", ins)) return NULL;
          if (!get_param(conf, name, "stride", strides)) return NULL;
          if (!get_param(conf, name, "branches_in", bins)) return NULL;
          if (!get_param(conf, name, "branches_stride", bstrides))
            return NULL;
          idxdim in = string_to_idxdim(ins);
          fidxdim stride = string_to_fidxdim(strides.c_str());
          midxdim bin = string_to_idxdimvector(bins.c_str());
          mfidxdim bstride = string_to_fidxdimvector(bstrides.c_str());
          module = (module_1_1<T,Tstate>*)
            new flat_merge_module<T,Tstate>(inputs, in, bin, stride,
                                            bstride, name.c_str(),
                                            strbranches.c_str());
        } else eblerror("unknown merge_type " << type);
      } else if (!type.compare("mflat")) { // multi-state flatten
        string strides, ins, sscales;
        bool bpad = false;
        mfidxdim bstride, scales;
        get_param(conf, name, "pad", bpad, true);
        if (!get_param(conf, name, "ins", ins)) return NULL;
        if (!get_param(conf, name, "strides", strides)) return NULL;
        if (get_param(conf, name, "scales", sscales, true))
          scales = string_to_fidxdimvector(sscales.c_str());
        intg hextra = 0, wextra = 0;
        float ss = 1, edge = 0;
        get_param(conf, name, "hextra", hextra, true);
        get_param(conf, name, "wextra", wextra, true);
        get_param(conf, name, "subsampling", ss, true);
        get_param(conf, name, "edge", edge, true);

        midxdim bin = string_to_idxdimvector(ins.c_str());
        EDEBUG("flat_merge " << strides);
        bstride = string_to_fidxdimvector(strides.c_str());
        EDEBUG("bstride " << bstride);
        module = (module_1_1<T,Tstate>*)
          new flat_merge_module<T,Tstate>
          (bin, bstride, bpad, name.c_str(), &scales,
           hextra, wextra, ss, edge);
      } else eblerror("unknown merge_type " << type);
    }
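    // Example configuration for a merge module (variable values are
    // hypothetical): a "concat" merge reads "<name>_dim" and optionally
    // "<name>_states", an "mflat" merge reads "<name>_ins"/"<name>_strides":
    //
    //   arch = ...,merge0,...
    //   merge0_type = concat
    //   merge0_dim = 0
    //   merge0_states = 0,1;2,3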
    // branch //////////////////////////////////////////////////////////////
    else if (!type.compare("branch")) {
      layers<T,Tstate> *branch = NULL;
      string type;
      bool narrow = false;
      intg narrow_dim, narrow_size, narrow_offset;
      if (!get_param(conf, name, "type", type)) return NULL;
      // get narrow parameters
      if (!type.compare("narrow")) { // narrow input
        narrow = true;
        if (!get_param(conf, name, "narrow_dim", narrow_dim)) return NULL;
        if (!get_param(conf, name, "narrow_size", narrow_size)) return NULL;
        if (!get_param(conf, name, "narrow_offset", narrow_offset)) return NULL;
      }
      cout << "Creating branch " << name;
      if (narrow)
        cout << ", narrow dim: " << narrow_dim << ", size: "
             << narrow_size << ", offset: " << narrow_offset;
      cout << endl;
      // add branch
      branch = (layers<T,Tstate>*) create_network<T,Tstate>
        (theparam, conf, thick, nout, name.c_str(), true,
         narrow, narrow_dim, narrow_size, narrow_offset, branches,
         branches_thick, &shared);
      branches->push_back(branch);
      module = (module_1_1<T,Tstate>*) branch;
    }
    // narrow //////////////////////////////////////////////////////////////////
    else if (!type.compare("narrow")) {
      intg dim, size;
      vector<intg> offsets;
      string soff;
      bool narrow_states = false;
      if (!get_param(conf, name, "dim", dim)) return NULL;
      if (!get_param(conf, name, "size", size)) return NULL;
      if (!get_param(conf, name, "offset", soff)) return NULL;
      get_param(conf, name, "narrow_states", narrow_states, true);
      offsets = string_to_intgvector(soff.c_str());
      module = new narrow_module<T,Tstate>(dim, size, offsets, narrow_states,
                                           name.c_str());
    }
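    // Example configuration for a branch with its own sub-architecture
    // (values hypothetical): the branch name doubles as the architecture
    // variable for the nested create_network call:
    //
    //   arch = branch0,...,merge0
    //   branch0_type = narrow
    //   branch0_narrow_dim = 0
    //   branch0_narrow_size = 1
    //   branch0_narrow_offset = 0
    //   branch0 = conv0,tanh0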
    // table //////////////////////////////////////////////////////////////////
    else if (!type.compare("table")) {
      vector<intg> tbl;
      string sin;
      intg total = -1;
      if (!get_param(conf, name, "in", sin)) return NULL;
      if (!get_param(conf, name, "total", total)) return NULL;
      tbl = string_to_intgvector(sin.c_str());
      module = new table_module<T,Tstate>(tbl, total, name.c_str());
    }
    // interlace ///////////////////////////////////////////////////////////////
    else if (!type.compare("interlace")) {
      uint stride = 0;
      if (!get_param(conf, name, "stride", stride)) return NULL;
      module = new interlace_module<T,Tstate>(stride, name.c_str());
    }
    // preprocessing //////////////////////////////////////////////////////
    else if (!type.compare("rgb_to_ypuv") || !type.compare("rgb_to_ynuv")
             || !type.compare("rgb_to_yp") || !type.compare("rgb_to_yn")
             || !type.compare("y_to_yp")) {
      // get parameters for normalization
      string skernel; idxdim kernel;
      bool mirror = true, globn = true;
      t_norm mode = WSTD_NORM;
      if (get_param(conf, name, "kernel", skernel))
        kernel = string_to_idxdim(skernel);
      get_param(conf, name, "mirror", mirror, true);
      get_param(conf, name, "global_norm", globn, true);
      // create modules
      if (!type.compare("rgb_to_ypuv") || !type.compare("rgb_to_ynuv")) {
        module = (module_1_1<T,Tstate>*)
          new rgb_to_ynuv_module<T,Tstate>(kernel, mirror, mode, globn);
      } else if (!type.compare("rgb_to_yp") || !type.compare("rgb_to_yn")) {
        module = (module_1_1<T,Tstate>*)
          new rgb_to_yn_module<T,Tstate>(kernel, mirror, mode, globn);
      } else if (!type.compare("y_to_yp")) {
        module = (module_1_1<T,Tstate>*)
          new y_to_yp_module<T,Tstate>(kernel, mirror);
      }
    } else if (!type.compare("rgb_to_yuv"))
      module = (module_1_1<T,Tstate>*) new rgb_to_yuv_module<T,Tstate>();
    else if (!type.compare("rgb_to_y"))
      module = (module_1_1<T,Tstate>*) new rgb_to_y_module<T,Tstate>();
    else if (!type.compare("mschan")) {
      string snstates;
      if (!get_param(conf, name, "nstates", snstates)) return NULL;
      uint nstates = string_to_uint(snstates);
      module = (module_1_1<T,Tstate>*)
        new mschan_module<T,Tstate>(nstates, name.c_str());
    }
    // ms ////////////////////////////////////////////////////////////////
    else if (!type.compare("ms") || !type.compare("msc")) {
      string spipe;
      spipe << name << "_pipe";
      std::vector<module_1_1<T,Tstate>*> pipes;
      // loop while pipes exist
      vector<string> matches = conf.get_all_strings(spipe);
      intg thick2 = thick;
      for (uint i = 0; i < matches.size(); ++i) {
        thick2 = thick;
        string sp = matches[i];
        if (conf.exists(sp)) {
          module_1_1<T,Tstate>* m =
            create_network<T,Tstate>(theparam, conf, thick2, nout, sp.c_str(),
                                     false, 0,0,0,0, branches, branches_thick,
                                     &shared);
          // check the module was created
          if (!m) {
            cerr << "expected a module in " << spipe << endl;
            return NULL;
          }
          // add it
          pipes.push_back(m);
        } else {
          cout << "adding empty pipe (just passing data along) from variable "
               << sp << endl;
          pipes.push_back(NULL);
        }
      }
      thick = thick2;
      if (pipes.size() == 0) {
        eblwarn("no pipes found in module " << name.c_str()
                << ", ignoring it");
        throw 1; // ignore this module
      }
      // get switching parameter
      string sswitch;
      midxdim switches;
      if (get_param(conf, name, "switch", sswitch, true))
        switches = string_to_idxdimvector(sswitch.c_str());
      // ms
      if (!type.compare("ms")) {
        bool replicate_inputs = false;
        get_param(conf, name, "replicate_inputs", replicate_inputs, true);
        ms_module<T,Tstate> *ms =
          new ms_module<T,Tstate>(pipes, replicate_inputs, name.c_str());
        ms->set_switch(switches);
        module = (module_1_1<T,Tstate>*) ms;
      } else if (!type.compare("msc")) { // msc
        uint nsize = 0, nsize2 = 0, stride = 1;
        if (!get_param(conf, name, "nsize", nsize)) return NULL;
        get_param(conf, name, "nsize2", nsize2, true);
        get_param(conf, name, "stride", stride, true);
        msc_module<T,Tstate> *msc = new msc_module<T,Tstate>
          (pipes, nsize, stride, nsize2, name.c_str());
        msc->set_switch(switches);
        module = (module_1_1<T,Tstate>*) msc;
      }
      EDEBUG("type: " << type << " " << module->describe());
    }
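    // Example configuration for a multi-scale "ms" module (values
    // hypothetical): each "<name>_pipeN" variable names a sub-architecture
    // variable, which may be left undefined to pass data through unchanged:
    //
    //   arch = ms0,...
    //   ms0_pipe0 = scale0
    //   ms0_pipe1 = scale1
    //   scale0 = conv0,tanh0
    //   scale1 = conv1,tanh1
    //   ms0_replicate_inputs = 1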
    // zpad /////////////////////////////////////////////////////////
    else if (!type.compare("zpad")) {
      string szpad;
      midxdim dims;
      if (get_param(conf, name, "dims", szpad))
        dims = string_to_idxdimvector(szpad.c_str());
      module = (module_1_1<T,Tstate>*)
        new zpad_module<T,Tstate>(dims, name.c_str());
    }
    // jitter //////////////////////////////////////////////////////////////////
    else if (!type.compare("jitter")) {
      jitter_module<T,Tstate> *j = new jitter_module<T,Tstate>(name.c_str());
      module = (module_1_1<T,Tstate>*) j;
      string str, srot, ssc, ssh, sel, spad;
      if (get_param(conf, name, "translations", str, true)) {
        vector<int> tr = string_to_intvector(str.c_str());
        j->set_translations(tr);
      }
      if (get_param(conf, name, "rotations", srot, true)) {
        vector<float> rot = string_to_floatvector(srot.c_str());
        j->set_rotations(rot);
      }
      if (get_param(conf, name, "scalings", ssc, true)) {
        vector<float> sc = string_to_floatvector(ssc.c_str());
        j->set_scalings(sc);
      }
      if (get_param(conf, name, "shears", ssh, true)) {
        vector<float> sh = string_to_floatvector(ssh.c_str());
        j->set_shears(sh);
      }
      if (get_param(conf, name, "elastic", sel, true)) {
        vector<float> el = string_to_floatvector(sel.c_str());
        j->set_elastics(el);
      }
      if (get_param(conf, name, "padding", spad, true)) {
        vector<uint> sp = string_to_uintvector(spad.c_str());
        j->set_padding(sp);
      }
    }
    // resizepp /////////////////////////////////////////////////////////
    else if (!type.compare("resizepp")) {
      string pps;
      // first get the preprocessing module
      if (!get_param(conf, name, "pp", pps)) return NULL;
      string pps_type = strip_last_num(pps);
      module_1_1<T,Tstate> *pp =
        create_module<T,Tstate>(pps, theparam, conf, nout, thick, shared,
                                branches, branches_thick);
      if (!pp) {
        cerr << "expected a preprocessing module in " << name << endl;
        return NULL;
      }
      string szpad, ssize, sfovea;
      idxdim zpad, size;
      if (get_param(conf, name, "zpad", szpad, true))
        zpad = string_to_idxdim(szpad);
      if (get_param(conf, name, "size", ssize, true)) {
        size = string_to_idxdim(ssize);
        module = (module_1_1<T,Tstate>*)
          new resizepp_module<T,Tstate>(size, MEAN_RESIZE, pp, true, &zpad);
      } else if (get_param(conf, name, "fovea", sfovea, true)) {
        //TODO: might have to add fovea_scale_size
        vector<double> fovea = string_to_doublevector(sfovea);
        module = (module_1_1<T,Tstate>*)
          new fovea_module<T,Tstate>(fovea, false, MEAN_RESIZE, pp, true,&zpad);
      } else
        module = (module_1_1<T,Tstate>*)
          new resizepp_module<T,Tstate>(MEAN_RESIZE, pp, true, &zpad);
    }
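    // Example configuration for a resizepp module (values hypothetical):
    // "<name>_pp" names another module from this factory (here a channel
    // preprocessor), which is applied together with the resizing:
    //
    //   arch = resizepp0,...
    //   resizepp0_pp = rgb_to_ynuv0
    //   resizepp0_size = 96x96
    //   rgb_to_ynuv0_kernel = 7x7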
    // resize /////////////////////////////////////////////////////////
    else if (!type.compare("resize")) {
      double resizeh, resizew;
      uint hzpad = 0, wzpad = 0;
      string szpad;
      if (!get_param(conf, name, "hratio", resizeh)) return NULL;
      if (!get_param(conf, name, "wratio", resizew)) return NULL;
      if (get_param(conf, name, "zpad", szpad, true)) {
        vector<uint> zp = string_to_uintvector(szpad, 'x');
        hzpad = zp[0];
        wzpad = zp[1];
      }
      // create module
      module = (module_1_1<T,Tstate>*)
        new resize_module<T,Tstate>(resizeh, resizew, BILINEAR_RESIZE,
                                    hzpad, wzpad);
    }
    // laplacian pyramid ///////////////////////////////////////////////////////
    else if (!type.compare("lpyramid")) {
      uint nscales = 0;
      string pp, skernels, sscalings, szpad;
      bool globnorm = true, locnorm = true, locnorm2 = false,
        color_lnorm = false, cnorm_across = true;
      midxdim zpads;
      if (!get_param(conf, name, "nscales", nscales)) return NULL;
      get_param(conf, name, "pp", pp, true);
      if (!get_param(conf, name, "kernels", skernels)) return NULL;
      midxdim kernels = string_to_idxdimvector(skernels.c_str());
      get_param(conf, name, "globalnorm", globnorm, true);
      get_param(conf, name, "localnorm", locnorm, true);
      get_param(conf, name, "localnorm2", locnorm2, true);
      get_param(conf, name, "cnorm_across", cnorm_across, true);
      get_param(conf, name, "color_lnorm", color_lnorm, true);
      if (get_param(conf, name, "zpad", szpad, true))
        zpads = string_to_idxdimvector(szpad.c_str());

      vector<float> scalings;
      if (get_param(conf, name, "scalings", sscalings, true))
        scalings = string_to_floatvector(sscalings.c_str());
      // create module
      module = (module_1_1<T,Tstate>*)
        create_preprocessing<T,Tstate>(pp.c_str(), kernels, zpads, "bilinear",
                                       true,
                                       nscales, NULL, NULL, globnorm, locnorm,
                                       locnorm2, color_lnorm,
                                       cnorm_across, 1.0, 1.0,
                                       scalings.size() > 0 ? &scalings : NULL);
    }
    // convolution /////////////////////////////////////////////////////////
    else if (!type.compare("conv") || !type.compare("convl")) {
      idxdim kernel, stride;
      string skernel, sstride;
      idx<intg> table(1, 1);
      if (get_param(conf, name, "kernel", skernel, true))
        kernel = string_to_idxdim(skernel);
      if (get_param(conf, name, "stride", sstride, true))
        stride = string_to_idxdim(sstride);
      if (!load_table(conf, name, table, thick, nout)) return NULL;
      // update thickness
      idx<intg> tblmax = table.select(1, 1);
      thick = 1 + idx_max(tblmax);
      // create module
      if (!type.compare("conv")) // conv module
        module = (module_1_1<T,Tstate>*)
          //      new convolution_module_replicable<T,Tstate>
          new convolution_module<T,Tstate>
          (bshared_exists? NULL : &theparam, kernel, stride, table,
           name.c_str());
      else if (!type.compare("convl")) // conv layer
        module = (module_1_1<T,Tstate>*)
          new convolution_layer<T,Tstate>
          (bshared_exists? NULL : &theparam, kernel, stride, table,
           true /* tanh */, name.c_str());
    }
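    // Example configuration for a convolution module (values hypothetical;
    // the exact variable names read by load_table are an assumption): the
    // connection table determines the output thickness, which becomes
    // 1 + the maximum target index of the table:
    //
    //   arch = conv0,...
    //   conv0_kernel = 7x7
    //   conv0_stride = 1x1
    //   conv0_table_in = 1    # full table from 1 input...
    //   conv0_table_out = 8   # ...to 8 output feature maps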
    // subsampling ///////////////////////////////////////////////////////
    else if (!type.compare("subs") || !type.compare("subsl")
             || !type.compare("maxss")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      // create module
      if (!type.compare("subs")) // subsampling module
        module = (module_1_1<T,Tstate>*)
          new subsampling_module_replicable<T,Tstate>
          (bshared_exists? NULL : &theparam, thick, kernel, stride,
           name.c_str());
      else if (!type.compare("subsl"))
        module = (module_1_1<T,Tstate>*)
          new subsampling_layer<T,Tstate>
          (bshared_exists? NULL : &theparam, thick, kernel, stride, true,
           name.c_str());
      else if (!type.compare("maxss"))
        module = (module_1_1<T,Tstate>*)
          new maxss_module<T,Tstate>(thick, kernel, stride, name.c_str());
    }
    // average pyramid /////////////////////////////////////////////////////////
    else if (!type.compare("avg_pyramid")) {
      string sstride;
      if (!get_param(conf, name, "strides", sstride)) return NULL;
      midxdim strides = string_to_idxdimvector(sstride.c_str());
      module = (module_1_1<T,Tstate>*)
        new average_pyramid_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, strides, name.c_str());
    }
    // wavg_pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("wavgpool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      module = (module_1_1<T,Tstate>*)
        new wavg_pooling_module<T,Tstate>(thick, kernel, stride, name.c_str());
    }
    // l1pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l1pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 1, name.c_str());
    }
    // l2pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l2pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 2, name.c_str());
    }
    // l4pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l4pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 4, name.c_str());
    }
    // l6pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l6pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 6, name.c_str());
    }
    // l8pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l8pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 8, name.c_str());
    }
    // l10pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l10pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 10, name.c_str());
    }
    // l12pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l12pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 12, name.c_str());
    }
    // l14pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l14pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 14, name.c_str());
    }
    // l16pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l16pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 16, name.c_str());
    }
    // l32pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l32pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 32, name.c_str());
    }
    // l64pooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("l64pool")) {
      string skernel, sstride;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, 64, name.c_str());
    }
    // lppooling ///////////////////////////////////////////////////////////////
    else if (!type.compare("lppool")) {
      string skernel, sstride;
      uint pool_power;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      if (!get_param(conf, name, "stride", sstride)) return NULL;
      if (!get_param(conf, name, "power", pool_power)) return NULL;
      idxdim kernel = string_to_idxdim(skernel);
      idxdim stride = string_to_idxdim(sstride);
      intg th = thick;
      get_param(conf, name, "thickness", th, true);
      module = (module_1_1<T,Tstate>*)
        new lppooling_module<T,Tstate>(th, kernel, stride, pool_power, name.c_str());
    }
    // linear //////////////////////////////////////////////////////////////
    else if (!type.compare("linear") || !type.compare("linear_replicable")) {
      intg lin, lout;
      if (!get_param2(conf, name, "in", lin, thick, nout)) return NULL;
      if (!get_param2(conf, name, "out", lout, thick, nout)) return NULL;
      // create module
      if (!type.compare("linear"))
        module = (module_1_1<T,Tstate>*) new linear_module<T,Tstate>
          (bshared_exists? NULL : &theparam, lin, lout, name.c_str());
      else
        module = (module_1_1<T,Tstate>*) new linear_module_replicable<T,Tstate>
          (bshared_exists? NULL : &theparam, lin, lout, name.c_str());
      thick = lout; // update thickness
    }
    // addc //////////////////////////////////////////////////////////////
    else if (!type.compare("addc"))
      module = (module_1_1<T,Tstate>*) new addc_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, name.c_str());
    // diag //////////////////////////////////////////////////////////////
    else if (!type.compare("diag"))
      module = (module_1_1<T,Tstate>*) new diag_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, name.c_str());
    // copy //////////////////////////////////////////////////////////////
    else if (!type.compare("copy"))
      module = (module_1_1<T,Tstate>*) new copy_module<T,Tstate>
        (name.c_str());
    // printer //////////////////////////////////////////////////////////////
    else if (!type.compare("printer"))
      module = (module_1_1<T,Tstate>*) new printer_module<T,Tstate>
        (name.c_str());
    // normalization ///////////////////////////////////////////////////////////
    else if (!type.compare("wstd") || !type.compare("cnorm")
             || !type.compare("snorm") || !type.compare("dnorm")) {
      intg wthick = thick;
      string skernel;
      bool learn = false, learn_mean = false, fsum_div = false;
      double cgauss = 2.0, epsilon = 1e-6;
      float fsum_split = 1.0;
      if (!get_param(conf, name, "kernel", skernel)) return NULL;
      idxdim ker = string_to_idxdim(skernel);
      // set optional number of features (default is 'thick')
      get_param(conf, name, "features", wthick, true);
      get_param(conf, name, "learn", learn, true);
      get_param(conf, name, "learn_mean", learn_mean, true);
      get_param(conf, name, "gaussian_coeff", cgauss, true);
      get_param(conf, name, "fsum_div", fsum_div, true);
      get_param(conf, name, "fsum_split", fsum_split, true);
      get_param(conf, name, "epsilon", epsilon, true);
      // normalization modules
      if (!type.compare("wstd") || !type.compare("cnorm"))
        module = (module_1_1<T,Tstate>*) new contrast_norm_module<T,Tstate>
          (ker, wthick, conf.exists_true("mirror"), true, false,
           learn ? &theparam : NULL, name.c_str(), true, learn_mean, cgauss,
           fsum_div, fsum_split, epsilon);
      else if (!type.compare("snorm"))
        module = (module_1_1<T,Tstate>*) new subtractive_norm_module<T,Tstate>
          (ker, wthick, conf.exists_true("mirror"), false,
           learn ? &theparam : NULL, name.c_str(), true, cgauss,
           fsum_div, fsum_split);
      else if (!type.compare("dnorm"))
        module = (module_1_1<T,Tstate>*) new divisive_norm_module<T,Tstate>
          (ker, wthick, conf.exists_true("mirror"), true,
           learn ? &theparam : NULL, name.c_str(), true, cgauss, fsum_div,
           fsum_split, epsilon);
    }
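    // Example configuration for a normalization module (values hypothetical):
    // "snorm" is subtractive-only, "dnorm" divisive-only, and "cnorm" (or the
    // older "wstd") chains both:
    //
    //   arch = ...,cnorm0,...
    //   cnorm0_kernel = 9x9
    //   cnorm0_features = 16
    //   cnorm0_gaussian_coeff = 2.0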
    // smooth shrink ///////////////////////////////////////////////////////////
    else if (!type.compare("sshrink")) {
      string sbias, sbeta;
      T beta = (T) 10, bias = (T) .3;
      if (get_param(conf, name, "beta", sbeta, true))
        beta = (T) string_to_double(sbeta);
      if (get_param(conf, name, "bias", sbias, true))
        bias = (T) string_to_double(sbias);
      module = (module_1_1<T,Tstate>*) new smooth_shrink_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, beta, bias);
    }
    // linear shrink ///////////////////////////////////////////////////////////
    else if (!type.compare("lshrink")) {
      string sbias;
      T bias = 0;
      if (get_param(conf, name, "bias", sbias, true))
        bias = (T) string_to_double(sbias);
      module = (module_1_1<T,Tstate>*) new linear_shrink_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, bias);
    }
    // tanh shrink /////////////////////////////////////////////////////////////
    else if (!type.compare("tshrink")) {
      bool diags = false;
      get_param(conf, name, "coefficients", diags, true);
      module = (module_1_1<T,Tstate>*) new tanh_shrink_module<T,Tstate>
        (bshared_exists? NULL : &theparam, thick, diags);
      // tanh ///////////////////////////////////////////////////////////////
    } else if (!type.compare("tanh"))
      module = (module_1_1<T,Tstate>*) new tanh_module<T,Tstate>();
    // stdsig //////////////////////////////////////////////////////////////
    else if (!type.compare("stdsig"))
      module = (module_1_1<T,Tstate>*) new stdsigmoid_module<T,Tstate>();
    // abs //////////////////////////////////////////////////////////////
    else if (!type.compare("abs"))
      module = (module_1_1<T,Tstate>*) new abs_module<T,Tstate>();
    // back //////////////////////////////////////////////////////////////
    else if (!type.compare("back"))
      module = (module_1_1<T,Tstate>*) new back_module<T,Tstate>();
    // lua //////////////////////////////////////////////////////////////
    else if (!type.compare("lua")) {
      string script;
      if (!get_param(conf, name, "script", script)) return NULL;
      module = (module_1_1<T,Tstate>*) new lua_module<T,Tstate>(script.c_str());
    } else
      cout << "unknown module type " << type << endl;
    // check if the module we're loading is shared
    if (module && bshared) { // this module is shared with others
      // check if we already have it in stock
      typename map<string,module_1_1<T,Tstate>*>::iterator i =
        shared.find(name);
      if (i != shared.end()) { // already exist
        delete module;
        module = i->second->copy(); // load a shared copy instead
        cout << "Loaded a shared copy of " << name << ". ";
      }
      else // we don't have it, add it
        shared[name] = module; // save this copy for future sharing
    }
    // add an ebm1 wrapper around this module if requested
    string sebm;
    if (get_param(conf, name, "energy", sebm, true)) {
      // create penalty module
      ebm_1<T,Tstate> *e = create_ebm1<T,Tstate>(sebm, conf);
      if (!e) eblerror("failed to create ebm1 from " << sebm);
      // create hybrid penalty / module_1_1
      module = new ebm_module_1_1<T,Tstate>(module, e, sebm.c_str());
    }
    return module;
  }
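  // Example configuration for weight sharing and an energy wrapper (values
  // hypothetical): setting "<name>_shared = 1" makes later occurrences of the
  // same module name reuse a copy of the first instance's weights, and
  // "<name>_energy" wraps the module with an ebm_1 penalty:
  //
  //   conv0_shared = 1
  //   conv0_energy = l1penalty0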

  // create an energy (ebm_1) module based on configuration
  template <typename T, class Tstate>
  ebm_1<T,Tstate>* create_ebm1(const string &name, configuration &conf) {
    string type = strip_last_num(name);
    ebm_1<T,Tstate> *ebm = NULL;
    // switch on each possible type of module
    if (!type.compare("l1penalty")) {
      T threshold = 0, coeff = 1;
      get_param(conf, name, "threshold", threshold, true);
      get_param(conf, name, "coeff", coeff, true);
      ebm = (ebm_1<T,Tstate>*) new l1_penalty<T,Tstate>(threshold, coeff);
    }
    else cout << "unknown ebm1 type " << type << endl;
    return ebm;
  }
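  // Example configuration for an l1penalty ebm1 (values hypothetical),
  // typically referenced from a module's "<name>_energy" variable:
  //
  //   l1penalty0_threshold = 0.0001
  //   l1penalty0_coeff = 0.01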

  template <typename T, typename Tds1, typename Tds2, class Tstate>
  answer_module<T,Tds1,Tds2,Tstate>*
  create_answer(configuration &conf, uint noutputs,
                const char *varname) {
    string name = conf.get_string(varname);
    string type = strip_last_num(name);
    answer_module<T,Tds1,Tds2,Tstate> *module = NULL;
    // loop on possible answer modules /////////////////////////////////////////
    if (!type.compare("class_answer")) {
      string factor_name, binary_name, tconf_name, tanh_name, force_name;
      t_confidence tconf = confidence_max;
      bool binary = false, btanh = false;
      float factor = 1.0;
      int force = -1;
      if (get_param(conf, name, "factor", factor_name, true))
        factor = string_to_float(factor_name);
      if (get_param(conf, name, "binary", binary_name, true))
        binary = (bool) string_to_uint(binary_name);
      if (get_param(conf, name, "confidence", tconf_name, true))
        tconf = (t_confidence) string_to_uint(tconf_name);
      if (get_param(conf, name, "tanh", tanh_name, true))
        btanh = (bool) string_to_uint(tanh_name);
      if (get_param(conf, name, "force_class", force_name, true))
        force = string_to_int(force_name);
      module = new class_answer<T,Tds1,Tds2,Tstate>
        (noutputs, factor, binary, tconf, btanh, name.c_str(), force);
    } else if (!type.compare("vote_answer")) {
      string factor_name, binary_name, tconf_name, tanh_name;
      t_confidence tconf = confidence_max;
      bool binary = false, btanh = false;
      float factor = 1.0;
      if (get_param(conf, name, "factor", factor_name, true))
        factor = string_to_float(factor_name);
      if (get_param(conf, name, "binary", binary_name, true))
        binary = (bool) string_to_uint(binary_name);
      if (get_param(conf, name, "confidence", tconf_name, true))
        tconf = (t_confidence) string_to_uint(tconf_name);
      if (get_param(conf, name, "tanh", tanh_name, true))
        btanh = (bool) string_to_uint(tanh_name);
      module = new vote_answer<T,Tds1,Tds2,Tstate>
        (noutputs, factor, binary, tconf, btanh, name.c_str());
    } else if (!type.compare("regression_answer")) {
      string threshold_name;
      float64 threshold = 0.0;
      if (get_param(conf, name, "threshold", threshold_name, true))
        threshold = (float64) string_to_double(threshold_name);
      module = new regression_answer<T,Tds1,Tds2,Tstate>
        (noutputs, threshold, name.c_str());
    } else if (!type.compare("scaler_answer")) {
      string negative_name, raw_name, threshold_name, spatial_name;
      bool raw_conf = false, spatial = false;
      float threshold = 0.0;
      if (!get_param(conf, name, "negative", negative_name)) return NULL;
      if (get_param(conf, name, "rawconf", raw_name, true))
        raw_conf = (bool) string_to_uint(raw_name);
      if (get_param(conf, name, "threshold", threshold_name, true))
        threshold = (float) string_to_float(threshold_name);
      if (get_param(conf, name, "spatial", spatial_name, true))
        spatial = (bool) string_to_uint(spatial_name);
      module = new scaler_answer<T,Tds1,Tds2,Tstate>
        (1, 0, raw_conf, threshold, spatial, name.c_str());
    } else if (!type.compare("scalerclass_answer")) {
      string factor_name, binary_name, tconf_name, tanh_name,
        jsize_name, joff_name, mgauss_name, pconf_name, pbconf_name,
        coeffs_name, biases_name;
      t_confidence tconf = confidence_max;
      bool binary = false, btanh = false,
        predict_conf = false, predict_bconf = false;
      float factor = 1.0, mgauss = 1.5;
      uint jsize = 1, joff = 0;
      idx<T> coeffs, biases;
      bool coeffs_set = false, biases_set = false;
      if (get_param(conf, name, "factor", factor_name, true))
        factor = string_to_float(factor_name);
      if (get_param(conf, name, "binary", binary_name, true))
        binary = (bool) string_to_uint(binary_name);
      if (get_param(conf, name, "confidence", tconf_name, true))
        tconf = (t_confidence) string_to_uint(tconf_name);
      if (get_param(conf, name, "tanh", tanh_name, true))
        btanh = (bool) string_to_uint(tanh_name);
      if (get_param(conf, name, "jsize", jsize_name, true))
        jsize = string_to_uint(jsize_name);
      if (get_param(conf, name, "joffset", joff_name, true))
        joff = string_to_uint(joff_name);
      if (get_param(conf, name, "mgauss", mgauss_name, true))
        mgauss = string_to_float(mgauss_name);
      if (get_param(conf, name, "predict_conf", pconf_name, true))
        predict_conf = (bool) string_to_uint(pconf_name);
      if (get_param(conf, name, "predict_bconf", pbconf_name, true))
        predict_bconf = (bool) string_to_uint(pbconf_name);
      if (get_param(conf, name, "coeffs", coeffs_name, true)) {
        coeffs = string_to_idx<T>(coeffs_name.c_str());
        coeffs_set = true;
      }
      if (get_param(conf, name, "biases", biases_name, true)) {
        biases = string_to_idx<T>(biases_name.c_str());
        biases_set = true;
      }
      module =
        new scalerclass_answer<T,Tds1,Tds2,Tstate>
        (noutputs, factor, binary, tconf, btanh, jsize, joff, mgauss,
         predict_conf, predict_bconf, biases_set ? &biases : NULL,
         coeffs_set ? &coeffs : NULL, name.c_str());
    } else
      cerr << "unknown answer type " << type << endl;
    return module;
  }
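  // Usage sketch for create_answer (variable values hypothetical): the
  // "answer" conf variable names the answer module, whose own parameters are
  // then read from "<name>_<param>" variables:
  //
  //   answer = class_answer0
  //   class_answer0_confidence = 1
  //   class_answer0_tanh = 1
  //
  //   answer_module<T,Tds1,Tds2,Tstate> *ans =
  //     create_answer<T,Tds1,Tds2,Tstate>(conf, nclasses, "answer");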
01009 
01010   template <typename T, typename Tds1, typename Tds2, class Tstate>
01011   trainable_module<T,Tds1,Tds2,Tstate>*
01012   create_trainer(configuration &conf, labeled_datasource<T,Tds1,Tds2> &ds,
01013                  module_1_1<T,Tstate> &net,
01014                  answer_module<T,Tds1,Tds2,Tstate> &answer,
01015                  const char *varname) {
01016     string name = conf.get_string(varname);
    string type = strip_last_num(name);
    trainable_module<T,Tds1,Tds2,Tstate> *module = NULL;
    // switch on each possible type of trainer module
    if (!type.compare("trainable_module")) {
      ebm_2<Tstate> *energy = NULL;
      string energy_name, switcher;
      if (!get_param(conf, name, "energy", energy_name)) return NULL;
      string energy_type = strip_last_num(energy_name);
      get_param(conf, name, "switcher", switcher, true);

      // dispatch on possible energy modules /////////////////////////////////
      if (!energy_type.compare("l2_energy")) {
        energy = new l2_energy<T,Tstate>(energy_name.c_str());
      } else if (!energy_type.compare("scalerclass_energy")) {
        string tanh_name, jsize_name, jselection_name, dist_name, scale_name,
          pconf_name, pbconf_name, coeffs_name, biases_name;
        bool apply_tanh = false, predict_conf = false, predict_bconf = false;
        uint jsize = 1, jselection = 0;
        float dist_coeff = 1.0, scale_coeff = 1.0;
        idx<T> coeffs, biases;
        bool coeffs_set = false, biases_set = false;
        if (get_param(conf, energy_name, "tanh", tanh_name, true))
          apply_tanh = (bool) string_to_uint(tanh_name);
        if (get_param(conf, energy_name, "jsize", jsize_name, true))
          jsize = string_to_uint(jsize_name);
        if (get_param(conf, energy_name, "jselection", jselection_name, true))
          jselection = string_to_uint(jselection_name);
        if (get_param(conf, energy_name, "distcoeff", dist_name, true))
          dist_coeff = string_to_float(dist_name);
        if (get_param(conf, energy_name, "scalecoeff", scale_name, true))
          scale_coeff = string_to_float(scale_name);
        if (get_param(conf, energy_name, "predict_conf", pconf_name, true))
          predict_conf = (bool) string_to_uint(pconf_name);
        if (get_param(conf, energy_name, "predict_bconf", pbconf_name, true))
          predict_bconf = (bool) string_to_uint(pbconf_name);
        if (get_param(conf, energy_name, "coeffs", coeffs_name, true)) {
          coeffs = string_to_idx<T>(coeffs_name.c_str());
          coeffs_set = true;
        }
        if (get_param(conf, energy_name, "biases", biases_name, true)) {
          biases = string_to_idx<T>(biases_name.c_str());
          biases_set = true;
        }
        energy =
          new scalerclass_energy<T,Tstate>(apply_tanh, jsize, jselection,
                                           dist_coeff, scale_coeff,
                                           predict_conf, predict_bconf,
                                           biases_set ? &biases : NULL,
                                           coeffs_set ? &coeffs : NULL,
                                           energy_name.c_str());
      } else if (!energy_type.compare("scaler_energy")) {
        energy = new scaler_energy<T,Tstate>(energy_name.c_str());
      } else
        eblerror("unknown energy type " << energy_type);

      // allocate trainer module
      module = new trainable_module<T,Tds1,Tds2,Tstate>
        (*energy, net, NULL, NULL, &answer, name.c_str(), switcher.c_str());
    }
    if (!module)
      eblerror("unknown trainer module type " << type);
    return module;
  }
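
  // A hypothetical configuration sketch (not part of the original source) for
  // the trainer branch above. Variable names follow the module_name + "_" +
  // var_name convention used by get_param(); the instance names
  // "trainable_module0", "l2_energy0" and "switcher0" are invented for
  // illustration:
  //   trainable_module0_energy = l2_energy0
  //   trainable_module0_switcher = switcher0   (optional)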

  template <typename T, class Tstate>
  resizepp_module<T,Tstate>*
  create_preprocessing(uint height, uint width, const char *ppchan,
                       idxdim &kersz, const char *resize_method,
                       bool keep_aspect_ratio, int lpyramid,
                       vector<double> *fovea, midxdim *fovea_scale_size,
                       bool globnorm, bool locnorm, bool locnorm2,
                       bool color_lnorm, bool cnorm_across,
                       double hscale, double wscale, vector<float> *scalings) {
    midxdim kers;
    kers.push_back(kersz);
    return create_preprocessing<T,Tstate>
      (height, width, ppchan, kers, resize_method, keep_aspect_ratio, lpyramid,
       fovea, fovea_scale_size, globnorm, locnorm, locnorm2, color_lnorm,
       cnorm_across, hscale, wscale, scalings);
  }
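
  // Usage sketch (hypothetical, for illustration only): build a grayscale
  // "Yn" preprocessor resizing inputs to 96x96 with bilinear interpolation,
  // using the single-kernel overload above. The 7x7 normalization kernel and
  // all flag values below are assumptions, not values from the original
  // source.
  template <typename T, class Tstate>
  resizepp_module<T,Tstate>* example_create_yn_preprocessing() {
    idxdim ker(7, 7); // local normalization kernel size (assumed)
    return create_preprocessing<T,Tstate>
      (96, 96, "Yn", ker, "bilinear", true /* keep_aspect_ratio */,
       0 /* no laplacian pyramid */, NULL /* fovea */, NULL /* fovea sizes */,
       true /* globnorm */, false /* locnorm */, false /* locnorm2 */,
       false /* color_lnorm */, false /* cnorm_across */,
       1.0 /* hscale */, 1.0 /* wscale */, NULL /* scalings */);
  }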

  template <typename T, class Tstate>
  resizepp_module<T,Tstate>*
  create_preprocessing(midxdim &dims, const char *ppchan,
                       midxdim &kers, midxdim &zpads, const char *resize_method,
                       bool keep_aspect_ratio, int lpyramid,
                       vector<double> *fovea, midxdim *fovea_scale_size,
                       bool globnorm, bool locnorm,
                       bool locnorm2, bool color_lnorm, bool cnorm_across,
                       double hscale, double wscale, vector<float> *scalings) {
    module_1_1<T,Tstate> *chanmodule = NULL;
    resizepp_module<T,Tstate> *ppmodule = NULL;
    if (kers.size() == 0) eblerror("expected at least 1 kernel dimension");
    idxdim kersz = kers[0];
    // set name of preprocessing
    string name;
    if (dims.size() == 0) eblerror("expected at least 1 idxdim in dims");
    idxdim d = dims[0];
    int height = d.dim(0), width = d.dim(1);
    name << ppchan << kersz << "_" << resize_method << height << "x" << width;
    if (!keep_aspect_ratio) name << "_noaspratio";
    if (!globnorm) name << "_nognorm";
    // set default min/max val for display
    T minval = (T) -2, maxval = (T) 2;
    t_norm tn = WSTD_NORM; bool mir = true;
    // create channel preprocessing module
    if (!strcmp(ppchan, "YpUV") || !strcmp(ppchan, "YnUV")) {
      chanmodule = new rgb_to_ynuv_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "Yp") || !strcmp(ppchan, "Yn")) {
      chanmodule = new rgb_to_yn_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "YnUVn")) {
      chanmodule = new rgb_to_ynuvn_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "YnUnVn")) {
      chanmodule = new rgb_to_ynunvn_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "YUVn")) {
      chanmodule = new rgb_to_yuvn_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "RGBn")) {
      chanmodule = new rgb_to_rgbn_module<T,Tstate>(kersz, mir, tn, globnorm);
    } else if (!strcmp(ppchan, "YUV")) {
      chanmodule = new rgb_to_yuv_module<T,Tstate>(globnorm);
    } else if (!strcmp(ppchan, "HSV")) {
      eblerror("HSV pp module not implemented");
    } else if (!strcmp(ppchan, "RGB")) {
      // no preprocessing module, just set min/max val for display
      minval = (T) 0;
      maxval = (T) 255;
    } else eblerror("undefined channel preprocessing " << ppchan);
    // initialize resizing method
    uint resiz = 0;
    if (!strcmp(resize_method, "bilinear")) resiz = BILINEAR_RESIZE;
    else if (!strcmp(resize_method, "gaussian")) resiz = GAUSSIAN_RESIZE;
    else if (!strcmp(resize_method, "mean")) resiz = MEAN_RESIZE;
    else eblerror("undefined resizing method " << resize_method);
    // create resizing module
    // fovea resize
    if (fovea && fovea->size() > 0) {
      if (!fovea_scale_size || fovea_scale_size->size() != fovea->size())
        eblerror("expected same number of parameters in fovea and "
                 << "fovea_scale_size");
      ppmodule = new fovea_module<T,Tstate>(*fovea, *fovea_scale_size, d, true,
                                            resiz, chanmodule);
      name << "_fovea" << fovea->size();
    } else if (lpyramid > 0) { // laplacian pyramid resize
      laplacian_pyramid_module<T,Tstate> *pyr =
        new laplacian_pyramid_module<T,Tstate>
        (lpyramid, kers, dims, resiz, chanmodule, false, NULL, globnorm,
         locnorm, locnorm2, color_lnorm, cnorm_across, keep_aspect_ratio);
      if (scalings) pyr->set_scalings(*scalings);
      ppmodule = pyr;
      if (!locnorm) name << "_nolnorm";
      if (!locnorm2) name << "_nolnorm2";
      if (color_lnorm) {
        name << "_colorlnorm";
        if (cnorm_across) name << "across";
      }
      name << "_lpyramid" << lpyramid;
    } else // regular resize
      ppmodule = new resizepp_module<T,Tstate>(d, resiz, chanmodule, true,
                                               NULL, keep_aspect_ratio);
    ppmodule->set_scale_factor(hscale, wscale);
    ppmodule->set_display_range(minval, maxval);
    ppmodule->set_name(name.c_str());
    ppmodule->set_zpad(zpads);
    return ppmodule;
  }
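
  // Usage sketch (hypothetical): a 2-scale laplacian pyramid front-end on
  // "YnUVn" channels, built with the multi-kernel overload above. All sizes
  // and flags below are invented for illustration, not taken from the
  // original source.
  template <typename T, class Tstate>
  resizepp_module<T,Tstate>* example_create_pyramid_preprocessing() {
    midxdim dims, kers, zpads;
    dims.push_back(idxdim(64, 64)); // target input size (assumed)
    kers.push_back(idxdim(9, 9));   // normalization kernel size (assumed)
    zpads.push_back(idxdim(0, 0));  // no zero-padding
    return create_preprocessing<T,Tstate>
      (dims, "YnUVn", kers, zpads, "bilinear", true /* keep_aspect_ratio */,
       2 /* lpyramid: 2 scales */, NULL /* fovea */, NULL /* fovea sizes */,
       true /* globnorm */, true /* locnorm */, false /* locnorm2 */,
       false /* color_lnorm */, false /* cnorm_across */,
       1.0 /* hscale */, 1.0 /* wscale */, NULL /* scalings */);
  }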

  template <typename T, class Tstate>
  resizepp_module<T,Tstate>*
  create_preprocessing(const char *ppchan,
                       midxdim &kers, midxdim &zpads, const char *resize_method,
                       bool keep_aspect_ratio, int lpyramid,
                       vector<double> *fovea, midxdim *fovea_scale_size,
                       bool globnorm, bool locnorm,
                       bool locnorm2, bool color_lnorm, bool cnorm_across,
                       double hscale, double wscale, vector<float> *scalings) {
    midxdim d;
    d.push_back(idxdim(0, 0)); // 0x0 placeholder: no explicit target size
    return create_preprocessing<T,Tstate>
      (d, ppchan, kers, zpads, resize_method, keep_aspect_ratio, lpyramid,
       fovea, fovea_scale_size, globnorm, locnorm, locnorm2, color_lnorm,
       cnorm_across, hscale, wscale, scalings);
  }

  // select network based on configuration, using old-style variables
  template <typename T, class Tstate>
  module_1_1<T,Tstate>* create_network_old(parameter<T, Tstate> &theparam,
                                           configuration &conf, int noutputs) {
    string net_type = conf.get_string("net_type");
    // load custom tables if defined
    string mname;
    idx<intg> t0(1,1), t1(1,1), t2(1,1),
      *table0 = NULL, *table1 = NULL, *table2 = NULL;
    intg thick = -1;
    mname = "conv0";
    if (load_table(conf, mname, t0, thick, noutputs))
      table0 = &t0;
    mname = "conv1";
    if (load_table(conf, mname, t1, thick, noutputs))
      table1 = &t1;
    mname = "conv2";
    if (load_table(conf, mname, t2, thick, noutputs))
      table2 = &t2;
    // create networks
    // cscscf ////////////////////////////////////////////////////////////////
    if (!strcmp(net_type.c_str(), "cscscf")) {
      return (module_1_1<T,Tstate>*) new lenet<T,Tstate>
        (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"),
         conf.get_uint("net_c1h"), conf.get_uint("net_c1w"),
         conf.get_uint("net_s1h"), conf.get_uint("net_s1w"),
         conf.get_uint("net_c2h"), conf.get_uint("net_c2w"),
         conf.get_uint("net_s2h"), conf.get_uint("net_s2w"),
         conf.get_uint("net_full"), noutputs,
         conf.get_bool("absnorm"), conf.get_bool("color"),
         conf.get_bool("mirror"), conf.get_bool("use_tanh"),
         conf.exists_true("use_shrink"), conf.exists_true("use_diag"),
         table0, table1, table2);
    // cscsc /////////////////////////////////////////////////////////////////
    } else if (!strcmp(net_type.c_str(), "cscsc")) {
      return (module_1_1<T,Tstate>*) new lenet_cscsc<T,Tstate>
        (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"),
         conf.get_uint("net_c1h"), conf.get_uint("net_c1w"),
         conf.get_uint("net_s1h"), conf.get_uint("net_s1w"),
         conf.get_uint("net_c2h"), conf.get_uint("net_c2w"),
         conf.get_uint("net_s2h"), conf.get_uint("net_s2w"),
         noutputs, conf.get_bool("absnorm"), conf.get_bool("color"),
         conf.get_bool("mirror"), conf.get_bool("use_tanh"),
         conf.exists_true("use_shrink"), conf.exists_true("use_diag"),
         conf.exists_true("norm_pos"), table0, table1, table2);
    // cscf //////////////////////////////////////////////////////////////////
    } else if (!strcmp(net_type.c_str(), "cscf")) {
      return (module_1_1<T,Tstate>*) new lenet_cscf<T,Tstate>
        (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"),
         conf.get_uint("net_c1h"), conf.get_uint("net_c1w"),
         conf.get_uint("net_s1h"), conf.get_uint("net_s1w"),
         conf.get_uint("net_c2h"), conf.get_uint("net_c2w"),
         noutputs, conf.get_bool("absnorm"), conf.get_bool("color"),
         conf.get_bool("mirror"), conf.get_bool("use_tanh"),
         conf.exists_true("use_shrink"), conf.exists_true("use_diag"),
         table0, table1);
    // cscc //////////////////////////////////////////////////////////////////
    } else if (!strcmp(net_type.c_str(), "cscc")) {
      if (!table0 || !table1 || !table2)
        eblerror("undefined connection tables");
      return (module_1_1<T,Tstate>*) new net_cscc<T,Tstate>
        (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"),
         conf.get_uint("net_c1h"), conf.get_uint("net_c1w"), *table0,
         conf.get_uint("net_s1h"), conf.get_uint("net_s1w"),
         conf.get_uint("net_c2h"), conf.get_uint("net_c2w"), *table1,
         *table2, noutputs, conf.get_bool("absnorm"),
         conf.get_bool("mirror"), conf.get_bool("use_tanh"),
         conf.exists_true("use_shrink"), conf.exists_true("use_diag"));
    } else
      eblerror("unknown network type " << net_type);
    return NULL;
  }
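
  // A hypothetical configuration sketch for the old-style "cscscf" network
  // above. All values are invented for illustration; the variable meanings
  // are inferred from the constructor calls (c* = convolution kernel sizes,
  // s* = subsampling sizes, net_full = fully-connected layer size):
  //   net_type = cscscf
  //   net_ih = 96
  //   net_iw = 96
  //   net_c1h = 7
  //   net_c1w = 7
  //   net_s1h = 2
  //   net_s1w = 2
  //   net_c2h = 7
  //   net_c2w = 7
  //   net_s2h = 2
  //   net_s2w = 2
  //   net_full = 100
  //   absnorm = 1
  //   color = 0
  //   mirror = 0
  //   use_tanh = 1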

  template <class Tmodule, typename T, class Tstate>
  bool load_module(configuration &conf, module_1_1<T,Tstate> &m,
                   const string &module_name, const string &type) {
    if (!dynamic_cast<Tmodule*>(&m))
      eblerror("cannot cast module " << module_name << " (\"" << m.name()
               << "\") into a " << type << " type");
    string name = module_name; name << "_weights";
    if (!conf.exists(name))
      return false; // do nothing if variable not found
    string filename = conf.get_string(name.c_str());
    idx<T> w = load_matrix<T>(filename);
    m.load_x(w);
    cout << "Loaded weights into " << module_name << " from " << filename
         << " (dims " << w << " min " << idx_min(w) << " max "
         << idx_max(w) << " mean " << idx_mean(w) << ")" << endl;
    return true;
  }
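
  // Hypothetical example: to make load_module() load weights into a module
  // named "conv0", the configuration would define a variable named
  // <module_name>_weights pointing to a saved matrix file (filename invented):
  //   conv0_weights = conv0_weights.mat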

  // manually load weights into a network based on configuration
  template <typename T, class Tstate>
  uint manually_load_network(layers<T,Tstate> &l, configuration &conf,
                             const char *varname) {
    list<string> arch = string_to_stringlist(conf.get_string(varname));
    uint arch_size = arch.size();
    cout << "Loading network manually using module list: "
         << conf.get_string(varname) << endl;
    uint n = 0;
    // loop over each module
    for (uint i = 0; i < arch_size; ++i) {
      // get first module name of the list and remove it from list
      string name = arch.front(); arch.pop_front();
      string type = strip_last_num(name);
      module_1_1<T,Tstate> *m = l.modules[i];
      // switch on each possible type of module
      if (!type.compare("conv"))
        n += load_module<convolution_module<T,Tstate>,T,Tstate>
          (conf, *m, name, type);
      else if (!type.compare("addc"))
        n += load_module<addc_module<T,Tstate>,T,Tstate>(conf, *m, name, type);
      else if (!type.compare("linear"))
        n += load_module<linear_module<T,Tstate>,T,Tstate>(conf, *m, name,
                                                           type);
      else if (!type.compare("diag"))
        n += load_module<diag_module<T,Tstate>,T,Tstate>(conf, *m, name, type);
      else if (!type.compare("ms")) {
        ms_module<T,Tstate> *msm = dynamic_cast<ms_module<T,Tstate>*>(m);
        if (!msm)
          eblerror("expected a ms module while trying to load module "
                   << name << " but found: " << typeid(*m).name());
        for (uint pi = 0; pi < msm->npipes(); ++pi) {
          module_1_1<T,Tstate> *pipe = msm->get_pipe(pi);
          if (!pipe) continue;
          if (!dynamic_cast<layers<T,Tstate>*>(pipe))
            eblerror("expected a layers module in pipes[" << pi << "] while "
                     << "trying to load module " << pipe->name()
                     << " but found: " << typeid(*pipe).name());
          n += manually_load_network(*((layers<T,Tstate>*)pipe), conf,
                                     pipe->name());
        }
      }
      else if (!type.compare("branch")) {
        if (!dynamic_cast<layers<T,Tstate>*>(m))
          eblerror("expected a layers module with type branch while trying "
                   << "to load module " << name << " but found: "
                   << typeid(*m).name());
        n += manually_load_network(*((layers<T,Tstate>*)m), conf,
                                   name.c_str());
      }
    }
    if (!l.is_branch())
      cout << "Loaded " << n << " weight matrices." << endl;
    return n;
  }
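
  // Usage sketch (hypothetical): given a network previously assembled as a
  // layers<T,Tstate> stack whose modules match, in order, the names listed in
  // a configuration variable, e.g. "arch = conv0,addc0,linear0", the weights
  // of all listed modules could be loaded with:
  //   uint n = manually_load_network(net, conf, "arch");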

} // end namespace ebl

#endif /* NETCONF_HPP_ */