libeblearntools
|
00001 /*************************************************************************** 00002 * Copyright (C) 2010 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef NETCONF_HPP_ 00034 #define NETCONF_HPP_ 00035 00036 namespace ebl { 00037 00041 template <typename T> 00042 bool get_param2(configuration &conf, const string &module_name, 00043 const string &var_name, T &p, intg thickness, 00044 intg noutputs) { 00045 string pn = module_name; pn << "_" << var_name; 00046 // check that variable is present 00047 if (!conf.exists(pn)) { 00048 // not found 00049 cerr << "error: required parameter " << pn << " not found" << endl; 00050 return false; 00051 } 00052 std::string val_in = conf.get_string(pn); 00053 if (!val_in.compare("thickness")) 00054 p = (T) thickness; // special value 00055 else if (!val_in.compare("noutputs")) 00056 p = (T) noutputs; // special value 00057 else // get int value 00058 conf.get(p, pn); 00059 return true; 00060 } 00061 00065 template <typename T> 00066 bool get_param(configuration &conf, const string &module_name, 00067 const string &var_name, T &p, bool optional = false) { 00068 string pn = module_name; pn << "_" << var_name; 00069 // check that variable is present 00070 if (!conf.exists(pn)) { 00071 // not found 00072 if (!optional) 00073 cerr << "error: required parameter " << pn << " not found" << endl; 00074 return false; 00075 } 00076 std::string val_in = conf.get_string(pn); 00077 conf.get(p, pn); 00078 return true; 00079 } 00080 00081 // select network based on configuration 00082 template <typename T, class Tstate> 00083 module_1_1<T,Tstate>* 00084 create_network(parameter<T, Tstate> &theparam, configuration &conf, 00085 intg &thick, int nout, 00086 const char *varname, bool isbranch, bool narrow, 00087 intg narrow_dim, intg narrow_size, intg narrow_offset, 00088 vector<layers<T,Tstate>*> *branches, 00089 vector<intg> *branches_thick, 00090 map<string,module_1_1<T,Tstate>*> *shared_) { 00091 // if we don't find the generic architecture variable, try the old style 00092 // way with 'net-type' 00093 // if (!conf.exists(varname)) 00094 // return create_network_old(theparam, conf, nout); 00095 // else, use the arch list 00096 map<string,module_1_1<T,Tstate>*> *shared = shared_; 00097 if (!shared) shared = new map<string,module_1_1<T,Tstate>*>; 00098 list<string> arch = string_to_stringlist(conf.get_string(varname)); 00099 uint arch_size = arch.size(); 00100 layers<T,Tstate>* l = 00101 new layers<T,Tstate>(true, varname, isbranch, narrow, narrow_dim, 00102 narrow_size, narrow_offset); 00103 // remember thickness output of branches 00104 if (!branches_thick) 00105 branches_thick = new vector<intg>; 00106 // keep list of existing branches so far 00107 if (!branches) 00108 branches = new vector<layers<T,Tstate>*>; 00109 // info 00110 cout << "Creating a network with " << nout << " outputs and " 00111 << arch_size << " modules (input thickness is " << thick 00112 << "): " << conf.get_string(varname) << endl; 00113 try { 00114 // loop over each module 00115 for (uint i = 0; i < arch_size; ++i) { 00116 cout << varname << " " << i << ": "; 00117 // get first module name of the list and remove it from list 00118 string name = arch.front(); arch.pop_front(); 00119 int errid = 0; // id of error thrown by create_module 00120 module_1_1<T,Tstate> *module = NULL; 00121 try { 00122 module = create_module<T,Tstate> 00123 (name, theparam, conf, nout, thick, *shared, 00124 branches, branches_thick); 00125 } catch (int err) { errid = err; } 00126 // add module to layers, if not null 00127 if (module) { 00128 // add the module 00129 l->add_module(module); 00130 cout << "Added " << module->describe() << " (#params " 00131 << theparam.x.nelements() << ", thickness " << thick 00132 << ")" << endl; 00133 } else { 00134 switch (errid) { 00135 case 1: eblwarn("ignoring module " << name); 00136 arch_size--; // decrease expected size of architecture 00137 break ; 00138 default: eblerror("failed to load module " << name); break ; 00139 } 00140 } 00141 } 00142 if (isbranch) // remember last thickness for this branch 00143 branches_thick->push_back(thick); 00144 if (arch_size != l->size()) 00145 eblerror("Some error occurred when loading modules, expected to load " 00146 << arch_size << " modules but only " << l->size() 00147 << " were successfully loaded"); 00148 cout << varname << ": loaded " << l->size() << " modules." << endl; 00149 } eblcatcherror(); 00150 if (!shared_) // shared was allocated here, we can delete it 00151 delete shared; 00152 return l; 00153 } 00154 00155 // select network based on configuration 00156 template <typename T, class Tstate> 00157 module_1_1<T,Tstate>* 00158 create_module(const string &name, parameter<T, Tstate> &theparam, 00159 configuration &conf, int &nout, intg &thick, 00160 map<string,module_1_1<T,Tstate>*> &shared, 00161 vector<layers<T,Tstate>*> *branches, 00162 vector<intg> *branches_thick) { 00163 string type = strip_last_num(name); 00164 module_1_1<T,Tstate> *module = NULL; 00165 // switch on each possible type of module 00166 // shared ////////////////////////////////////////////////////////////////// 00167 // first check if the module we're loading is shared 00168 string sshared; bool bshared = false, bshared_exists = false; 00169 if (get_param(conf, name, "shared", sshared, true)) 00170 bshared = (bool) string_to_int(sshared); 00171 // check if we already have it in stock 00172 typename map<string,module_1_1<T,Tstate>*>::iterator i = 00173 shared.find(name); 00174 if (i != shared.end()) bshared_exists = true; // module already allocated 00175 // merge /////////////////////////////////////////////////////////////// 00176 if (!type.compare("merge")) { 00177 string type, strbranches; 00178 if (!get_param(conf, name, "type", type)) return NULL; 00179 // switch on merging type 00180 if (type.compare("mflat")) { // all types but mflat 00181 get_param(conf, name, "branches", strbranches, true); 00182 // get vector of branch buffers to merge 00183 vector<mstate<Tstate>**> inputs; 00184 list<string> b = string_to_stringlist(strbranches); 00185 for (list<string>::iterator bi = b.begin(); bi != b.end(); bi++) { 00186 layers<T,Tstate> *branch = NULL; 00187 for (uint k = 0; k < branches->size(); ++k) { 00188 if (!strcmp((*branches)[k]->name(), (*bi).c_str())) { 00189 branch = (*branches)[k]; 00190 thick += (*branches_thick)[k]; // update thickness 00191 break ; 00192 } 00193 } 00194 if (!branch) 00195 eblerror(name << " is trying to merge branch " << *bi 00196 << " but branch not found"); 00197 inputs.push_back(&(branch->intern_out)); 00198 } 00199 // switch on merging type 00200 if (!type.compare("concat")) { // concatenate 00201 intg concat_dim; 00202 string sstates; 00203 vector<vector<uint> > states; 00204 if (!get_param(conf, name, "dim", concat_dim)) return NULL; 00205 if (get_param(conf, name, "states", sstates, true)) { 00206 vector<string> s = string_to_stringvector(sstates, ';'); 00207 EDEBUG("s: " << s); 00208 for (uint i = 0; i < s.size(); ++i) { 00209 vector<uint> v = string_to_uintvector(s[i]); 00210 EDEBUG("v: " << v); 00211 states.push_back(v); 00212 } 00213 } 00214 // create module 00215 if (states.size() > 0) 00216 module = (module_1_1<T,Tstate>*) 00217 new merge_module<T,Tstate>(states, concat_dim, name.c_str()); 00218 else // old-style with branches (TODO: remove) 00219 module = (module_1_1<T,Tstate>*) 00220 new merge_module<T,Tstate>(inputs, concat_dim, name.c_str(), 00221 strbranches.c_str()); 00222 } else if (!type.compare("flat")) { // flatten 00223 string strides, ins, bstrides, bins; 00224 if (!get_param(conf, name, "in", ins)) return NULL; 00225 if (!get_param(conf, name, "stride", strides)) return NULL; 00226 if (!get_param(conf, name, "branches_in", bins)) return NULL; 00227 if (!get_param(conf, name, "branches_stride", bstrides)) 00228 return NULL; 00229 idxdim in = string_to_idxdim(ins); 00230 fidxdim stride = string_to_fidxdim(strides.c_str()); 00231 midxdim bin = string_to_idxdimvector(bins.c_str()); 00232 mfidxdim bstride = string_to_fidxdimvector(bstrides.c_str()); 00233 module = (module_1_1<T,Tstate>*) 00234 new flat_merge_module<T,Tstate>(inputs, in, bin, stride, 00235 bstride, name.c_str(), 00236 strbranches.c_str()); 00237 } else eblerror("unknown merge_type " << type); 00238 } else if (!type.compare("mflat")) { // multi-state flatten 00239 string strides, ins, sscales; 00240 bool bpad = false; 00241 mfidxdim bstride, scales; 00242 get_param(conf, name, "pad", bpad, true); 00243 if (!get_param(conf, name, "ins", ins)) return NULL; 00244 if (!get_param(conf, name, "strides", strides)) return NULL; 00245 if (get_param(conf, name, "scales", sscales, true)) 00246 scales = string_to_fidxdimvector(sscales.c_str()); 00247 intg hextra = 0, wextra = 0; 00248 float ss = 1, edge = 0; 00249 get_param(conf, name, "hextra", hextra, true); 00250 get_param(conf, name, "wextra", wextra, true); 00251 get_param(conf, name, "subsampling", ss, true); 00252 get_param(conf, name, "edge", edge, true); 00253 00254 midxdim bin = string_to_idxdimvector(ins.c_str()); 00255 EDEBUG("flat_merge " << strides); 00256 bstride = string_to_fidxdimvector(strides.c_str()); 00257 EDEBUG("bstride " << bstride); 00258 module = (module_1_1<T,Tstate>*) 00259 new flat_merge_module<T,Tstate> 00260 (bin, bstride, bpad, name.c_str(), &scales, 00261 hextra, wextra, ss, edge); 00262 } else eblerror("unknown merge_type " << type); 00263 } 00264 // branch ////////////////////////////////////////////////////////////// 00265 else if (!type.compare("branch")) { 00266 layers<T,Tstate> *branch = NULL; 00267 string type; 00268 bool narrow = false; 00269 intg narrow_dim, narrow_size, narrow_offset; 00270 if (!get_param(conf, name, "type", type)) return NULL; 00271 // get narrow parameters 00272 if (!type.compare("narrow")) { // narrow input 00273 narrow = true; 00274 if (!get_param(conf, name, "narrow_dim", narrow_dim)) return NULL; 00275 if (!get_param(conf, name, "narrow_size", narrow_size)) return NULL; 00276 if (!get_param(conf, name, "narrow_offset", narrow_offset)) return NULL; 00277 } 00278 cout << "Creating branch " << name; 00279 if (narrow) 00280 cout << ", narrow dim: " << narrow_dim << ", size: " 00281 << narrow_size << ", offset: " << narrow_offset; 00282 cout << endl; 00283 // add branch 00284 branch = (layers<T,Tstate>*) create_network<T,Tstate> 00285 (theparam, conf, thick, nout, name.c_str(), true, 00286 narrow, narrow_dim, narrow_size, narrow_offset, branches, 00287 branches_thick, &shared); 00288 branches->push_back(branch); 00289 module = (module_1_1<T,Tstate>*) branch; 00290 } 00291 // narrow ////////////////////////////////////////////////////////////////// 00292 else if (!type.compare("narrow")) { 00293 intg dim, size; 00294 vector<intg> offsets; 00295 string soff; 00296 bool narrow_states = false; 00297 if (!get_param(conf, name, "dim", dim)) return NULL; 00298 if (!get_param(conf, name, "size", size)) return NULL; 00299 if (!get_param(conf, name, "offset", soff)) return NULL; 00300 get_param(conf, name, "narrow_states", narrow_states, true); 00301 offsets = string_to_intgvector(soff.c_str()); 00302 module = new narrow_module<T,Tstate>(dim, size, offsets, narrow_states, 00303 name.c_str()); 00304 } 00305 // table ////////////////////////////////////////////////////////////////// 00306 else if (!type.compare("table")) { 00307 vector<intg> tbl; 00308 string sin; 00309 intg total = -1; 00310 if (!get_param(conf, name, "in", sin)) return NULL; 00311 if (!get_param(conf, name, "total", total)) return NULL; 00312 tbl = string_to_intgvector(sin.c_str()); 00313 module = new table_module<T,Tstate>(tbl, total, name.c_str()); 00314 } 00315 // interlace /////////////////////////////////////////////////////////////// 00316 else if (!type.compare("interlace")) { 00317 uint stride = 0; 00318 if (!get_param(conf, name, "stride", stride)) return NULL; 00319 module = new interlace_module<T,Tstate>(stride, name.c_str()); 00320 } 00321 // preprocessing ////////////////////////////////////////////////////// 00322 else if (!type.compare("rgb_to_ypuv") || !type.compare("rgb_to_ynuv") 00323 || !type.compare("rgb_to_yp") || !type.compare("rgb_to_yn") 00324 || !type.compare("y_to_yp")) { 00325 // get parameters for normalization 00326 string skernel; idxdim kernel; 00327 bool mirror = true, globn = true; 00328 t_norm mode = WSTD_NORM; 00329 if (get_param(conf, name, "kernel", skernel)) 00330 kernel = string_to_idxdim(skernel); 00331 get_param(conf, name, "mirror", mirror, true); 00332 get_param(conf, name, "global_norm", globn, true); 00333 // create modules 00334 if (!type.compare("rgb_to_ypuv") || !type.compare("rgb_to_ynuv")) { 00335 module = (module_1_1<T,Tstate>*) 00336 new rgb_to_ynuv_module<T,Tstate>(kernel, mirror, mode, globn); 00337 } else if (!type.compare("rgb_to_yp") || !type.compare("rgb_to_yn")) { 00338 module = (module_1_1<T,Tstate>*) 00339 new rgb_to_yn_module<T,Tstate>(kernel, mirror, mode, globn); 00340 } else if (!type.compare("y_to_yp")) { 00341 module = (module_1_1<T,Tstate>*) 00342 new y_to_yp_module<T,Tstate>(kernel, mirror); 00343 } 00344 } else if (!type.compare("rgb_to_yuv")) 00345 module = (module_1_1<T,Tstate>*) new rgb_to_yuv_module<T,Tstate>(); 00346 else if (!type.compare("rgb_to_y")) 00347 module = (module_1_1<T,Tstate>*) new rgb_to_y_module<T,Tstate>(); 00348 else if (!type.compare("mschan")) { 00349 string snstates; 00350 if (!get_param(conf, name, "nstates", snstates)) return NULL; 00351 uint nstates = string_to_uint(snstates); 00352 module = (module_1_1<T,Tstate>*) 00353 new mschan_module<T,Tstate>(nstates, name.c_str()); 00354 } 00355 // ms //////////////////////////////////////////////////////////////// 00356 else if (!type.compare("ms") || !type.compare("msc")) { 00357 string spipe; 00358 spipe << name << "_pipe"; 00359 std::vector<module_1_1<T,Tstate>*> pipes; 00360 // loop while pipes exist 00361 vector<string> matches = conf.get_all_strings(spipe); 00362 intg thick2 = thick; 00363 for (uint i = 0; i < matches.size(); ++i) { 00364 thick2 = thick; 00365 string sp = matches[i]; 00366 if (conf.exists(sp)) { 00367 module_1_1<T,Tstate>* m = 00368 create_network<T,Tstate>(theparam, conf, thick2, nout, sp.c_str(), 00369 false, 0,0,0,0, branches, branches_thick, 00370 &shared); 00371 // check the module was created 00372 if (!m) { 00373 cerr << "expected a module in " << spipe << endl; 00374 return NULL; 00375 } 00376 // add it 00377 pipes.push_back(m); 00378 } else { 00379 cout << "adding empty pipe (just passing data along) from variable " 00380 << sp << endl; 00381 pipes.push_back(NULL); 00382 } 00383 } 00384 thick = thick2; 00385 if (pipes.size() == 0) { 00386 eblwarn("no pipes found in module " << name.c_str() 00387 << ", ignoring it"); 00388 throw 1; // ignore this module 00389 } 00390 // get switching parameter 00391 string sswitch; 00392 midxdim switches; 00393 if (get_param(conf, name, "switch", sswitch, true)) 00394 switches = string_to_idxdimvector(sswitch.c_str()); 00395 // ms 00396 if (!type.compare("ms")) { 00397 bool replicate_inputs = false; 00398 get_param(conf, name, "replicate_inputs", replicate_inputs, true); 00399 ms_module<T,Tstate> *ms = 00400 new ms_module<T,Tstate>(pipes, replicate_inputs, name.c_str()); 00401 ms->set_switch(switches); 00402 module = (module_1_1<T,Tstate>*) ms; 00403 } else if (!type.compare("msc")) { // msc 00404 uint nsize = 0, nsize2 = 0, stride = 1; 00405 if (!get_param(conf, name, "nsize", nsize)) return NULL; 00406 get_param(conf, name, "nsize2", nsize2, true); 00407 get_param(conf, name, "stride", stride, true); 00408 msc_module<T,Tstate> *msc = new msc_module<T,Tstate> 00409 (pipes, nsize, stride, nsize2, name.c_str()); 00410 msc->set_switch(switches); 00411 module = (module_1_1<T,Tstate>*) msc; 00412 } 00413 EDEBUG("type: " << type << " " << module->describe()); 00414 } 00415 // zpad ///////////////////////////////////////////////////////// 00416 else if (!type.compare("zpad")) { 00417 string szpad; 00418 midxdim dims; 00419 if (get_param(conf, name, "dims", szpad)) 00420 dims = string_to_idxdimvector(szpad.c_str()); 00421 module = (module_1_1<T,Tstate>*) 00422 new zpad_module<T,Tstate>(dims, name.c_str()); 00423 } 00424 // jitter ////////////////////////////////////////////////////////////////// 00425 else if (!type.compare("jitter")) { 00426 jitter_module<T,Tstate> *j = new jitter_module<T,Tstate>(name.c_str()); 00427 module = (module_1_1<T,Tstate>*) j; 00428 string str, srot, ssc, ssh, sel, spad; 00429 if (get_param(conf, name, "translations", str, true)) { 00430 vector<int> tr = string_to_intvector(str.c_str()); 00431 j->set_translations(tr); 00432 } 00433 if (get_param(conf, name, "rotations", srot, true)) { 00434 vector<float> rot = string_to_floatvector(srot.c_str()); 00435 j->set_rotations(rot); 00436 } 00437 if (get_param(conf, name, "scalings", ssc, true)) { 00438 vector<float> sc = string_to_floatvector(ssc.c_str()); 00439 j->set_scalings(sc); 00440 } 00441 if (get_param(conf, name, "shears", ssh, true)) { 00442 vector<float> sh = string_to_floatvector(ssh.c_str()); 00443 j->set_shears(sh); 00444 } 00445 if (get_param(conf, name, "elastic", sel, true)) { 00446 vector<float> el = string_to_floatvector(sel.c_str()); 00447 j->set_elastics(el); 00448 } 00449 if (get_param(conf, name, "padding", spad, true)) { 00450 vector<uint> sp = string_to_uintvector(spad.c_str()); 00451 j->set_padding(sp); 00452 } 00453 } 00454 // resizepp ///////////////////////////////////////////////////////// 00455 else if (!type.compare("resizepp")) { 00456 string pps; 00457 // first get the preprocessing module 00458 if (!get_param(conf, name, "pp", pps)) return NULL; 00459 string pps_type = strip_last_num(pps); 00460 module_1_1<T,Tstate> *pp = 00461 create_module<T,Tstate>(pps, theparam, conf, nout, thick, shared, 00462 branches, branches_thick); 00463 if (!pp) { 00464 cerr << "expected a preprocessing module in " << name << endl; 00465 return NULL; 00466 } 00467 string szpad, ssize, sfovea; 00468 idxdim zpad, size; 00469 if (get_param(conf, name, "zpad", szpad, true)) 00470 zpad = string_to_idxdim(szpad); 00471 if (get_param(conf, name, "size", ssize, true)) { 00472 size = string_to_idxdim(ssize); 00473 module = (module_1_1<T,Tstate>*) 00474 new resizepp_module<T,Tstate>(size, MEAN_RESIZE, pp, true, &zpad); 00475 } else if (get_param(conf, name, "fovea", sfovea, true)) { 00476 //TODO: might have to add fovea_scale_size 00477 vector<double> fovea = string_to_doublevector(sfovea); 00478 module = (module_1_1<T,Tstate>*) 00479 new fovea_module<T,Tstate>(fovea, false, MEAN_RESIZE, pp, true,&zpad); 00480 } else 00481 module = (module_1_1<T,Tstate>*) 00482 new resizepp_module<T,Tstate>(MEAN_RESIZE, pp, true, &zpad); 00483 } 00484 // resize ///////////////////////////////////////////////////////// 00485 else if (!type.compare("resize")) { 00486 double resizeh, resizew; 00487 uint hzpad = 0, wzpad = 0; 00488 string szpad; 00489 if (!get_param(conf, name, "hratio", resizeh)) return NULL; 00490 if (!get_param(conf, name, "wratio", resizew)) return NULL; 00491 if (get_param(conf, name, "zpad", szpad, true)) { 00492 vector<uint> zp = string_to_uintvector(szpad, 'x'); 00493 hzpad = zp[0]; 00494 wzpad= zp[1]; 00495 } 00496 // create module 00497 module = (module_1_1<T,Tstate>*) 00498 new resize_module<T,Tstate>(resizeh, resizew, BILINEAR_RESIZE, 00499 hzpad, wzpad); 00500 } 00501 // resize ///////////////////////////////////////////////////////// 00502 else if (!type.compare("lpyramid")) { 00503 uint nscales = 0; 00504 string pp, skernels, sscalings, szpad; 00505 bool globnorm = true, locnorm = true, locnorm2 = false, 00506 color_lnorm = false, cnorm_across = true; 00507 midxdim zpads; 00508 if (!get_param(conf, name, "nscales", nscales)) return NULL; 00509 get_param(conf, name, "pp", pp, true); 00510 if (!get_param(conf, name, "kernels", skernels)) return NULL; 00511 midxdim kernels = string_to_idxdimvector(skernels.c_str()); 00512 get_param(conf, name, "globalnorm", globnorm, true); 00513 get_param(conf, name, "localnorm", locnorm, true); 00514 get_param(conf, name, "localnorm2", locnorm2, true); 00515 get_param(conf, name, "cnorm_across", cnorm_across, true); 00516 get_param(conf, name, "color_lnorm", color_lnorm, true); 00517 if (get_param(conf, name, "zpad", szpad, true)) 00518 zpads = string_to_idxdimvector(szpad.c_str()); 00519 00520 vector<float> scalings; 00521 if (get_param(conf, name, "scalings", sscalings, true)) 00522 scalings = string_to_floatvector(sscalings.c_str()); 00523 // create module 00524 module = (module_1_1<T,Tstate>*) 00525 create_preprocessing<T,Tstate>(pp.c_str(), kernels, zpads, "bilinear", 00526 true, 00527 nscales, NULL, NULL, globnorm, locnorm, 00528 locnorm2, color_lnorm, 00529 cnorm_across, 1.0, 1.0, 00530 scalings.size() > 0 ? &scalings : NULL); 00531 } 00532 // convolution ///////////////////////////////////////////////////////// 00533 else if (!type.compare("conv") || !type.compare("convl")) { 00534 idxdim kernel, stride; 00535 string skernel, sstride; 00536 idx<intg> table(1, 1); 00537 if (get_param(conf, name, "kernel", skernel, true)) 00538 kernel = string_to_idxdim(skernel); 00539 if (get_param(conf, name, "stride", sstride, true)) 00540 stride = string_to_idxdim(sstride); 00541 if (!load_table(conf, name, table, thick, nout)) return NULL; 00542 // update thickness 00543 idx<intg> tblmax = table.select(1, 1); 00544 thick = 1 + idx_max(tblmax); 00545 // create module 00546 if (!type.compare("conv")) // conv module 00547 module = (module_1_1<T,Tstate>*) 00548 // new convolution_module_replicable<T,Tstate> 00549 new convolution_module<T,Tstate> 00550 (bshared_exists? NULL : &theparam, kernel, stride, table, 00551 name.c_str()); 00552 else if (!type.compare("convl")) // conv layer 00553 module = (module_1_1<T,Tstate>*) 00554 new convolution_layer<T,Tstate> 00555 (bshared_exists? NULL : &theparam, kernel, stride, table, 00556 true /* tanh */, name.c_str()); 00557 } 00558 // subsampling /////////////////////////////////////////////////////// 00559 else if (!type.compare("subs") || !type.compare("subsl") 00560 || !type.compare("maxss")) { 00561 string skernel, sstride; 00562 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00563 if (!get_param(conf, name, "stride", sstride)) return NULL; 00564 idxdim kernel = string_to_idxdim(skernel); 00565 idxdim stride = string_to_idxdim(sstride); 00566 // create module 00567 if (!type.compare("subs")) // subsampling module 00568 module = (module_1_1<T,Tstate>*) 00569 new subsampling_module_replicable<T,Tstate> 00570 (bshared_exists? NULL : &theparam, thick, kernel, stride, 00571 name.c_str()); 00572 else if (!type.compare("subsl")) 00573 module = (module_1_1<T,Tstate>*) 00574 new subsampling_layer<T,Tstate> 00575 (bshared_exists? NULL : &theparam, thick, kernel, stride, true, 00576 name.c_str()); 00577 else if (!type.compare("maxss")) 00578 module = (module_1_1<T,Tstate>*) 00579 new maxss_module<T,Tstate>(thick, kernel, stride, name.c_str()); 00580 } 00581 // subsampling /////////////////////////////////////////////////////// 00582 else if (!type.compare("avg_pyramid")) { 00583 string sstride; 00584 if (!get_param(conf, name, "strides", sstride)) return NULL; 00585 midxdim strides = string_to_idxdimvector(sstride.c_str()); 00586 module = (module_1_1<T,Tstate>*) 00587 new average_pyramid_module<T,Tstate> 00588 (bshared_exists? NULL : &theparam, thick, strides, name.c_str()); 00589 } 00590 // wavg_pooling /////////////////////////////////////////////////////////////// 00591 else if (!type.compare("wavgpool")) { 00592 string skernel, sstride; 00593 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00594 if (!get_param(conf, name, "stride", sstride)) return NULL; 00595 idxdim kernel = string_to_idxdim(skernel); 00596 idxdim stride = string_to_idxdim(sstride); 00597 module = (module_1_1<T,Tstate>*) 00598 new wavg_pooling_module<T,Tstate>(thick, kernel, stride, name.c_str()); 00599 } 00600 // l1pooling /////////////////////////////////////////////////////////////// 00601 else if (!type.compare("l1pool")) { 00602 string skernel, sstride; 00603 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00604 if (!get_param(conf, name, "stride", sstride)) return NULL; 00605 idxdim kernel = string_to_idxdim(skernel); 00606 idxdim stride = string_to_idxdim(sstride); 00607 intg th = thick; 00608 get_param(conf, name, "thickness", th, true); 00609 module = (module_1_1<T,Tstate>*) 00610 new lppooling_module<T,Tstate>(th, kernel, stride, 1, name.c_str()); 00611 } 00612 // l2pooling /////////////////////////////////////////////////////////////// 00613 else if (!type.compare("l2pool")) { 00614 string skernel, sstride; 00615 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00616 if (!get_param(conf, name, "stride", sstride)) return NULL; 00617 idxdim kernel = string_to_idxdim(skernel); 00618 idxdim stride = string_to_idxdim(sstride); 00619 intg th = thick; 00620 get_param(conf, name, "thickness", th, true); 00621 module = (module_1_1<T,Tstate>*) 00622 new lppooling_module<T,Tstate>(th, kernel, stride, 2, name.c_str()); 00623 } 00624 // l4pooling /////////////////////////////////////////////////////////////// 00625 else if (!type.compare("l4pool")) { 00626 string skernel, sstride; 00627 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00628 if (!get_param(conf, name, "stride", sstride)) return NULL; 00629 idxdim kernel = string_to_idxdim(skernel); 00630 idxdim stride = string_to_idxdim(sstride); 00631 intg th = thick; 00632 get_param(conf, name, "thickness", th, true); 00633 module = (module_1_1<T,Tstate>*) 00634 new lppooling_module<T,Tstate>(th, kernel, stride, 4, name.c_str()); 00635 } 00636 // l6pooling /////////////////////////////////////////////////////////////// 00637 else if (!type.compare("l6pool")) { 00638 string skernel, sstride; 00639 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00640 if (!get_param(conf, name, "stride", sstride)) return NULL; 00641 idxdim kernel = string_to_idxdim(skernel); 00642 idxdim stride = string_to_idxdim(sstride); 00643 intg th = thick; 00644 get_param(conf, name, "thickness", th, true); 00645 module = (module_1_1<T,Tstate>*) 00646 new lppooling_module<T,Tstate>(th, kernel, stride, 6, name.c_str()); 00647 } 00648 // l8pooling /////////////////////////////////////////////////////////////// 00649 else if (!type.compare("l8pool")) { 00650 string skernel, sstride; 00651 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00652 if (!get_param(conf, name, "stride", sstride)) return NULL; 00653 idxdim kernel = string_to_idxdim(skernel); 00654 idxdim stride = string_to_idxdim(sstride); 00655 intg th = thick; 00656 get_param(conf, name, "thickness", th, true); 00657 module = (module_1_1<T,Tstate>*) 00658 new lppooling_module<T,Tstate>(th, kernel, stride, 8, name.c_str()); 00659 } 00660 // l10pooling /////////////////////////////////////////////////////////////// 00661 else if (!type.compare("l10pool")) { 00662 string skernel, sstride; 00663 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00664 if (!get_param(conf, name, "stride", sstride)) return NULL; 00665 idxdim kernel = string_to_idxdim(skernel); 00666 idxdim stride = string_to_idxdim(sstride); 00667 intg th = thick; 00668 get_param(conf, name, "thickness", th, true); 00669 module = (module_1_1<T,Tstate>*) 00670 new lppooling_module<T,Tstate>(th, kernel, stride, 10, name.c_str()); 00671 } 00672 // l12pooling /////////////////////////////////////////////////////////////// 00673 else if (!type.compare("l12pool")) { 00674 string skernel, sstride; 00675 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00676 if (!get_param(conf, name, "stride", sstride)) return NULL; 00677 idxdim kernel = string_to_idxdim(skernel); 00678 idxdim stride = string_to_idxdim(sstride); 00679 intg th = thick; 00680 get_param(conf, name, "thickness", th, true); 00681 module = (module_1_1<T,Tstate>*) 00682 new lppooling_module<T,Tstate>(th, kernel, stride, 12, name.c_str()); 00683 } 00684 // l14pooling /////////////////////////////////////////////////////////////// 00685 else if (!type.compare("l14pool")) { 00686 string skernel, sstride; 00687 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00688 if (!get_param(conf, name, "stride", sstride)) return NULL; 00689 idxdim kernel = string_to_idxdim(skernel); 00690 idxdim stride = string_to_idxdim(sstride); 00691 intg th = thick; 00692 get_param(conf, name, "thickness", th, true); 00693 module = (module_1_1<T,Tstate>*) 00694 new lppooling_module<T,Tstate>(th, kernel, stride, 14, name.c_str()); 00695 } 00696 // l16pooling /////////////////////////////////////////////////////////////// 00697 else if (!type.compare("l16pool")) { 00698 string skernel, sstride; 00699 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00700 if (!get_param(conf, name, "stride", sstride)) return NULL; 00701 idxdim kernel = string_to_idxdim(skernel); 00702 idxdim stride = string_to_idxdim(sstride); 00703 intg th = thick; 00704 get_param(conf, name, "thickness", th, true); 00705 module = (module_1_1<T,Tstate>*) 00706 new lppooling_module<T,Tstate>(th, kernel, stride, 16, name.c_str()); 00707 } 00708 // l32pooling /////////////////////////////////////////////////////////////// 00709 else if (!type.compare("l32pool")) { 00710 string skernel, sstride; 00711 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00712 if (!get_param(conf, name, "stride", sstride)) return NULL; 00713 idxdim kernel = string_to_idxdim(skernel); 00714 idxdim stride = string_to_idxdim(sstride); 00715 intg th = thick; 00716 get_param(conf, name, "thickness", th, true); 00717 module = (module_1_1<T,Tstate>*) 00718 new lppooling_module<T,Tstate>(th, kernel, stride, 32, name.c_str()); 00719 } 00720 // l64pooling /////////////////////////////////////////////////////////////// 00721 else if (!type.compare("l64pool")) { 00722 string skernel, sstride; 00723 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00724 if (!get_param(conf, name, "stride", sstride)) return NULL; 00725 idxdim kernel = string_to_idxdim(skernel); 00726 idxdim stride = string_to_idxdim(sstride); 00727 intg th = thick; 00728 get_param(conf, name, "thickness", th, true); 00729 module = (module_1_1<T,Tstate>*) 00730 new lppooling_module<T,Tstate>(th, kernel, stride, 64, name.c_str()); 00731 } 00732 // lppooling /////////////////////////////////////////////////////////////// 00733 else if (!type.compare("lppool")) { 00734 string skernel, sstride; 00735 uint pool_power; 00736 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00737 if (!get_param(conf, name, "stride", sstride)) return NULL; 00738 if (!get_param(conf, name, "power", pool_power)) return NULL; 00739 idxdim kernel = string_to_idxdim(skernel); 00740 idxdim stride = string_to_idxdim(sstride); 00741 intg th = thick; 00742 get_param(conf, name, "thickness", th, true); 00743 module = (module_1_1<T,Tstate>*) 00744 new lppooling_module<T,Tstate>(th, kernel, stride, pool_power, name.c_str()); 00745 } 00746 // linear ////////////////////////////////////////////////////////////// 00747 else if (!type.compare("linear") || !type.compare("linear_replicable")) { 00748 intg lin, lout; 00749 if (!get_param2(conf, name, "in", lin, thick, nout)) return NULL; 00750 if (!get_param2(conf, name, "out", lout, thick, nout)) return NULL; 00751 // create module 00752 if (!type.compare("linear")) 00753 module = (module_1_1<T,Tstate>*) new linear_module<T,Tstate> 00754 (bshared_exists? NULL : &theparam, lin, lout, name.c_str()); 00755 else 00756 module = (module_1_1<T,Tstate>*) new linear_module_replicable<T,Tstate> 00757 (bshared_exists? NULL : &theparam, lin, lout, name.c_str()); 00758 thick = lout; // update thickness 00759 } 00760 // addc ////////////////////////////////////////////////////////////// 00761 else if (!type.compare("addc")) 00762 module = (module_1_1<T,Tstate>*) new addc_module<T,Tstate> 00763 (bshared_exists? NULL : &theparam, thick, name.c_str()); 00764 // diag ////////////////////////////////////////////////////////////// 00765 else if (!type.compare("diag")) 00766 module = (module_1_1<T,Tstate>*) new diag_module<T,Tstate> 00767 (bshared_exists? NULL : &theparam, thick, name.c_str()); 00768 // copy ////////////////////////////////////////////////////////////// 00769 else if (!type.compare("copy")) 00770 module = (module_1_1<T,Tstate>*) new copy_module<T,Tstate> 00771 (name.c_str()); 00772 // printer ////////////////////////////////////////////////////////////// 00773 else if (!type.compare("printer")) 00774 module = (module_1_1<T,Tstate>*) new printer_module<T,Tstate> 00775 (name.c_str()); 00776 // normalization /////////////////////////////////////////////////////////// 00777 else if (!type.compare("wstd") || !type.compare("cnorm") 00778 || !type.compare("snorm") || !type.compare("dnorm")) { 00779 intg wthick = thick; 00780 string skernel; 00781 bool learn = false, learn_mean = false, fsum_div = false; 00782 double cgauss = 2.0, epsilon = 1e-6; 00783 float fsum_split = 1.0; 00784 if (!get_param(conf, name, "kernel", skernel)) return NULL; 00785 idxdim ker = string_to_idxdim(skernel); 00786 // set optional number of features (default is 'thick') 00787 get_param(conf, name, "features", wthick, true); 00788 get_param(conf, name, "learn", learn, true); 00789 get_param(conf, name, "learn_mean", learn_mean, true); 00790 get_param(conf, name, "gaussian_coeff", cgauss, true); 00791 get_param(conf, name, "fsum_div", fsum_div, true); 00792 get_param(conf, name, "fsum_split", fsum_split, true); 00793 get_param(conf, name, "epsilon", epsilon, true); 00794 // normalization modules 00795 if (!type.compare("wstd") || !type.compare("cnorm")) 00796 module = (module_1_1<T,Tstate>*) new contrast_norm_module<T,Tstate> 00797 (ker, wthick, conf.exists_true("mirror"), true, false, 00798 learn ? &theparam : NULL, name.c_str(), true, learn_mean, cgauss, 00799 fsum_div, fsum_split, epsilon); 00800 else if (!type.compare("snorm")) 00801 module = (module_1_1<T,Tstate>*) new subtractive_norm_module<T,Tstate> 00802 (ker, wthick, conf.exists_true("mirror"), false, 00803 learn ? &theparam : NULL, name.c_str(), true, cgauss, 00804 fsum_div, fsum_split); 00805 else if (!type.compare("dnorm")) 00806 module = (module_1_1<T,Tstate>*) new divisive_norm_module<T,Tstate> 00807 (ker, wthick, conf.exists_true("mirror"), true, 00808 learn ? &theparam : NULL, name.c_str(), true, cgauss, fsum_div, 00809 fsum_split, epsilon); 00810 } 00811 // smooth shrink /////////////////////////////////////////////////////////// 00812 else if (!type.compare("sshrink")) { 00813 string sbias, sbeta; 00814 T beta = (T) 10, bias = (T) .3; 00815 if (get_param(conf, name, "beta", sbeta, true)) 00816 beta = (T) string_to_double(sbeta); 00817 if (get_param(conf, name, "bias", bias, true)) 00818 bias = (T) string_to_double(sbias); 00819 module = (module_1_1<T,Tstate>*) new smooth_shrink_module<T,Tstate> 00820 (bshared_exists? NULL : &theparam, thick, beta, bias); 00821 } 00822 // linear shrink /////////////////////////////////////////////////////////// 00823 else if (!type.compare("lshrink")) { 00824 string sbias; 00825 T bias = 0; 00826 if (get_param(conf, name, "bias", sbias, true)) 00827 bias = (T) string_to_double(sbias); 00828 module = (module_1_1<T,Tstate>*) new linear_shrink_module<T,Tstate> 00829 (bshared_exists? NULL : &theparam, thick, bias); 00830 } 00831 // linear shrink /////////////////////////////////////////////////////////// 00832 else if (!type.compare("tshrink")) { 00833 bool diags = false; 00834 get_param(conf, name, "coefficients", diags, true); 00835 module = (module_1_1<T,Tstate>*) new tanh_shrink_module<T,Tstate> 00836 (bshared_exists? NULL : &theparam, thick, diags); 00837 // tanh /////////////////////////////////////////////////////////////// 00838 } else if (!type.compare("tanh")) 00839 module = (module_1_1<T,Tstate>*) new tanh_module<T,Tstate>(); 00840 // stdsig ////////////////////////////////////////////////////////////// 00841 else if (!type.compare("stdsig")) 00842 module = (module_1_1<T,Tstate>*) new stdsigmoid_module<T,Tstate>(); 00843 // abs ////////////////////////////////////////////////////////////// 00844 else if (!type.compare("abs")) 00845 module = (module_1_1<T,Tstate>*) new abs_module<T,Tstate>(); 00846 // abs ////////////////////////////////////////////////////////////// 00847 else if (!type.compare("back")) 00848 module = (module_1_1<T,Tstate>*) new back_module<T,Tstate>(); 00849 // abs ////////////////////////////////////////////////////////////// 00850 else if (!type.compare("lua")) { 00851 string script; 00852 if (!get_param(conf, name, "script", script)) return NULL; 00853 module = (module_1_1<T,Tstate>*) new lua_module<T,Tstate>(script.c_str()); 00854 } else 00855 cout << "unknown module type " << type << endl; 00856 // check if the module we're loading is shared 00857 if (module && bshared) { // this module is shared with others 00858 // check if we already have it in stock 00859 typename map<string,module_1_1<T,Tstate>*>::iterator i = 00860 shared.find(name); 00861 if (i != shared.end()) { // already exist 00862 delete module; 00863 module = i->second->copy(); // load a shared copy instead 00864 cout << "Loaded a shared copy of " << name << ". "; 00865 } 00866 else // we don't have it, add it 00867 shared[name] = module; // save this copy for future sharing 00868 } 00869 // add an ebm1 wrapper around this module if requested 00870 string sebm; 00871 if (get_param(conf, name, "energy", sebm, true)) { 00872 // create penalty module 00873 ebm_1<T,Tstate> *e = create_ebm1<T,Tstate>(sebm, conf); 00874 if (!e) eblerror("failed to create ebm1 from " << sebm); 00875 // create hybrid penalty / module_1_1 00876 module = new ebm_module_1_1<T,Tstate>(module, e, sebm.c_str()); 00877 } 00878 return module; 00879 } 00880 00881 // select network based on configuration 00882 template <typename T, class Tstate> 00883 ebm_1<T,Tstate>* create_ebm1(const string &name, configuration &conf) { 00884 string type = strip_last_num(name); 00885 ebm_1<T,Tstate> *ebm = NULL; 00886 // switch on each possible type of module 00887 if (!type.compare("l1penalty")) { 00888 T threshold = 0, coeff = 1; 00889 get_param(conf, name, "threshold", threshold, true); 00890 get_param(conf, name, "coeff", coeff, true); 00891 ebm = (ebm_1<T,Tstate>*) new l1_penalty<T,Tstate>(threshold, coeff); 00892 } 00893 else cout << "unknown ebm1 type " << type << endl; 00894 return ebm; 00895 } 00896 00897 template <typename T, typename Tds1, typename Tds2, class Tstate> 00898 answer_module<T,Tds1,Tds2,Tstate>* 00899 create_answer(configuration &conf, uint noutputs, 00900 const char *varname) { 00901 string name = conf.get_string(varname); 00902 string type = strip_last_num(name); 00903 answer_module<T,Tds1,Tds2,Tstate> *module = NULL; 00904 // loop on possible answer modules ///////////////////////////////////////// 00905 if (!type.compare("class_answer")) { 00906 string factor_name, binary_name, tconf_name, tanh_name, force_name; 00907 t_confidence tconf = confidence_max; 00908 bool binary = false, btanh = false; 00909 float factor = 1.0; 00910 int force = -1; 00911 if (get_param(conf, name, "factor", factor_name, true)) 00912 factor = string_to_float(factor_name); 00913 if (get_param(conf, name, "binary", binary_name, true)) 00914 binary = (bool) string_to_uint(binary_name); 00915 if (get_param(conf, name, "confidence", tconf_name, true)) 00916 tconf = (t_confidence) string_to_uint(tconf_name); 00917 if (get_param(conf, name, "tanh", tanh_name, true)) 00918 btanh = (bool) string_to_uint(tanh_name); 00919 if (get_param(conf, name, "force_class", force_name, true)) 00920 force = string_to_int(force_name); 00921 module = new class_answer<T,Tds1,Tds2,Tstate> 00922 (noutputs, factor, binary, tconf, btanh, name.c_str(), force); 00924 } else if (!type.compare("vote_answer")) { 00925 string factor_name, binary_name, tconf_name, tanh_name; 00926 t_confidence tconf = confidence_max; 00927 bool binary = false, btanh = false; 00928 float factor = 1.0; 00929 if (get_param(conf, name, "factor", factor_name, true)) 00930 factor = string_to_float(factor_name); 00931 if (get_param(conf, name, "binary", binary_name, true)) 00932 binary = (bool) string_to_uint(binary_name); 00933 if (get_param(conf, name, "confidence", tconf_name, true)) 00934 tconf = (t_confidence) string_to_uint(tconf_name); 00935 if (get_param(conf, name, "tanh", tanh_name, true)) 00936 btanh = (bool) string_to_uint(tanh_name); 00937 module = new vote_answer<T,Tds1,Tds2,Tstate> 00938 (noutputs, factor, binary, tconf, btanh, name.c_str()); 00940 } else if (!type.compare("regression_answer")) { 00941 string threshold_name; 00942 float64 threshold = 0.0; 00943 if (get_param(conf, name, "threshold", threshold_name, true)) 00944 threshold = (float64) string_to_double(threshold_name); 00945 module = new regression_answer<T,Tds1,Tds2,Tstate> 00946 (noutputs, threshold, name.c_str()); 00948 } else if (!type.compare("scaler_answer")) { 00949 string negative_name, raw_name, threshold_name, spatial_name; 00950 bool raw_conf = false, spatial = false; 00951 float threshold = 0.0; 00952 if (!get_param(conf, name, "negative", negative_name)) return NULL; 00953 if (get_param(conf, name, "rawconf", raw_name, true)) 00954 raw_conf = (bool) string_to_uint(raw_name); 00955 if (get_param(conf, name, "threshold", threshold_name, true)) 00956 threshold = (float) string_to_float(threshold_name); 00957 if (get_param(conf, name, "spatial", spatial_name, true)) 00958 spatial = (bool) string_to_uint(spatial_name); 00959 module = new scaler_answer<T,Tds1,Tds2,Tstate> 00960 (1, 0, raw_conf, threshold, spatial, name.c_str()); 00962 } else if (!type.compare("scalerclass_answer")) { 00963 string factor_name, binary_name, tconf_name, tanh_name, 00964 jsize_name, joff_name, mgauss_name, pconf_name, pbconf_name, 00965 coeffs_name, biases_name; 00966 t_confidence tconf = confidence_max; 00967 bool binary = false, btanh = false, 00968 predict_conf = false, predict_bconf = false; 00969 float factor = 1.0, mgauss = 1.5; 00970 uint jsize = 1, joff = 0; 00971 idx<T> coeffs, biases; 00972 bool coeffs_set = false, biases_set = false; 00973 if (get_param(conf, name, "factor", factor_name, true)) 00974 factor = string_to_float(factor_name); 00975 if (get_param(conf, name, "binary", binary_name, true)) 00976 binary = (bool) string_to_uint(binary_name); 00977 if (get_param(conf, name, "confidence", tconf_name, true)) 00978 tconf = (t_confidence) string_to_uint(tconf_name); 00979 if (get_param(conf, name, "tanh", tanh_name, true)) 00980 btanh = (bool) string_to_uint(tanh_name); 00981 if (get_param(conf, name, "jsize", jsize_name, true)) 00982 jsize = string_to_uint(jsize_name); 00983 if (get_param(conf, name, "joffset", joff_name, true)) 00984 joff = string_to_uint(joff_name); 00985 if (get_param(conf, name, "mgauss", mgauss_name, true)) 00986 mgauss = string_to_float(mgauss_name); 00987 if (get_param(conf, name, "predict_conf", pconf_name, true)) 00988 predict_conf = (bool) string_to_uint(pconf_name); 00989 if (get_param(conf, name, "predict_bconf", pbconf_name, true)) 00990 predict_bconf = (bool) string_to_uint(pbconf_name); 00991 if (get_param(conf, name, "coeffs", coeffs_name, true)) { 00992 coeffs = string_to_idx<T>(coeffs_name.c_str()); 00993 coeffs_set = true; 00994 } 00995 if (get_param(conf, name, "biases", biases_name, true)) { 00996 biases = string_to_idx<T>(biases_name.c_str()); 00997 biases_set = true; 00998 } 00999 module = 01000 new scalerclass_answer<T,Tds1,Tds2,Tstate> 01001 (noutputs, factor, binary, tconf, btanh, jsize, joff, mgauss, 01002 predict_conf, predict_bconf, biases_set ? &biases : NULL, 01003 coeffs_set ? &coeffs : NULL, name.c_str()); 01005 } else 01006 cerr << "unknown answer type " << type << endl; 01007 return module; 01008 } 01009 01010 template <typename T, typename Tds1, typename Tds2, class Tstate> 01011 trainable_module<T,Tds1,Tds2,Tstate>* 01012 create_trainer(configuration &conf, labeled_datasource<T,Tds1,Tds2> &ds, 01013 module_1_1<T,Tstate> &net, 01014 answer_module<T,Tds1,Tds2,Tstate> &answer, 01015 const char *varname) { 01016 string name = conf.get_string(varname); 01017 string type = strip_last_num(name); 01018 trainable_module<T,Tds1,Tds2,Tstate> *module = NULL; 01019 // switch on each possible type of trainer module 01020 if (!type.compare("trainable_module")) { 01021 ebm_2<Tstate> *energy = NULL; 01022 string energy_name, switcher; 01023 if (!get_param(conf, name, "energy", energy_name)) return NULL; 01024 string energy_type = strip_last_num(energy_name); 01025 get_param(conf, name, "switcher", switcher, true); 01026 01027 // loop on possible energy modules /////////////////////////////////////// 01028 if (!energy_type.compare("l2_energy")) { 01029 energy = new l2_energy<T,Tstate>(energy_name.c_str()); 01030 } else if (!energy_type.compare("scalerclass_energy")) { 01031 string tanh_name, jsize_name, jselection_name, dist_name, scale_name, 01032 pconf_name, pbconf_name, coeffs_name, biases_name; 01033 bool apply_tanh = false, predict_conf = false, predict_bconf = false; 01034 uint jsize = 1, jselection = 0; 01035 float dist_coeff = 1.0, scale_coeff = 1.0; 01036 idx<T> coeffs, biases; 01037 bool coeffs_set = false, biases_set = false; 01038 if (get_param(conf, energy_name, "tanh", tanh_name, true)) 01039 apply_tanh = (bool) string_to_uint(tanh_name); 01040 if (get_param(conf, energy_name, "jsize", jsize_name, true)) 01041 jsize = string_to_uint(jsize_name); 01042 if (get_param(conf, energy_name, "jselection", jselection_name, true)) 01043 jselection = string_to_uint(jselection_name); 01044 if (get_param(conf, energy_name, "distcoeff", dist_name, true)) 01045 dist_coeff = string_to_float(dist_name); 01046 if (get_param(conf, energy_name, "scalecoeff", scale_name, true)) 01047 scale_coeff = string_to_float(scale_name); 01048 if (get_param(conf, energy_name, "predict_conf", pconf_name, true)) 01049 predict_conf = (bool) string_to_uint(pconf_name); 01050 if (get_param(conf, energy_name, "predict_bconf", pbconf_name, true)) 01051 predict_bconf = (bool) string_to_uint(pbconf_name); 01052 if (get_param(conf, energy_name, "coeffs", coeffs_name, true)) { 01053 coeffs = string_to_idx<T>(coeffs_name.c_str()); 01054 coeffs_set = true; 01055 } 01056 if (get_param(conf, energy_name, "biases", biases_name, true)) { 01057 biases = string_to_idx<T>(biases_name.c_str()); 01058 biases_set = true; 01059 } 01060 energy = 01061 new scalerclass_energy<T,Tstate>(apply_tanh, jsize, jselection, 01062 dist_coeff, scale_coeff, 01063 predict_conf, predict_bconf, 01064 biases_set ? &biases : NULL, 01065 coeffs_set ? &coeffs : NULL, 01066 energy_name.c_str()); 01067 } else if (!energy_type.compare("scaler_energy")) { 01068 energy = new scaler_energy<T,Tstate>(energy_name.c_str()); 01069 } else 01070 eblerror("unknown energy type " << energy_type); 01071 01072 // allocate trainer module 01073 module = new trainable_module<T,Tds1,Tds2,Tstate> 01074 (*energy, net, NULL, NULL, &answer, name.c_str(), switcher.c_str()); 01075 } 01076 if (!module) 01077 eblerror("no trainer module found"); 01078 return module; 01079 } 01080 01081 template <typename T, class Tstate> 01082 resizepp_module<T,Tstate>* 01083 create_preprocessing(uint height, uint width, const char *ppchan, 01084 idxdim &kersz, const char *resize_method, 01085 bool keep_aspect_ratio, int lpyramid, 01086 vector<double> *fovea, midxdim *fovea_scale_size, 01087 bool globnorm, bool locnorm, bool locnorm2, 01088 bool color_lnorm, bool cnorm_across, 01089 double hscale, double wscale, vector<float> *scalings) { 01090 midxdim kers; 01091 kers.push_back(kersz); 01092 return create_preprocessing<T,Tstate> 01093 (height, width, ppchan, kers, resize_method, keep_aspect_ratio, lpyramid, 01094 fovea, fovea_scale_size, globnorm, locnorm, locnorm2, color_lnorm, 01095 cnorm_across, hscale, wscale, scalings); 01096 } 01097 01098 template <typename T, class Tstate> 01099 resizepp_module<T,Tstate>* 01100 create_preprocessing(midxdim &dims, const char *ppchan, 01101 midxdim &kers, midxdim &zpads, const char *resize_method, 01102 bool keep_aspect_ratio, int lpyramid, 01103 vector<double> *fovea, midxdim *fovea_scale_size, 01104 bool globnorm, bool locnorm, 01105 bool locnorm2, bool color_lnorm, bool cnorm_across, 01106 double hscale, double wscale, vector<float> *scalings) { 01107 module_1_1<T,Tstate> *chanmodule = NULL; 01108 resizepp_module<T,Tstate> *ppmodule = NULL; 01109 if (kers.size() == 0) eblerror("expected at least 1 ker dims"); 01110 idxdim kersz = kers[0]; 01111 // set name of preprocessing 01112 string name; 01113 if (dims.size() == 0) eblerror("expected at least 1 idxdim in dims"); 01114 idxdim d = dims[0]; 01115 int height = d.dim(0), width = d.dim(1); 01116 name << ppchan << kersz << "_" << resize_method << height << "x" << width; 01117 if (!keep_aspect_ratio) name << "_noaspratio"; 01118 if (!globnorm) name << "_nognorm"; 01119 // set default min/max val for display 01120 T minval = (T) -2, maxval = (T) 2; 01121 t_norm tn = WSTD_NORM; bool mir = true; 01122 // create channel preprocessing module 01123 if (!strcmp(ppchan, "YpUV") || !strcmp(ppchan, "YnUV")) { 01124 chanmodule = new rgb_to_ynuv_module<T,Tstate>(kersz, mir, tn, globnorm); 01125 } else if (!strcmp(ppchan, "Yp") || !strcmp(ppchan, "Yn")) { 01126 chanmodule = new rgb_to_yn_module<T,Tstate>(kersz, mir, tn, globnorm); 01127 } else if (!strcmp(ppchan, "YnUVn")) { 01128 chanmodule = new rgb_to_ynuvn_module<T,Tstate>(kersz, mir, tn, globnorm); 01129 } else if (!strcmp(ppchan, "YnUnVn")) { 01130 chanmodule = new rgb_to_ynunvn_module<T,Tstate>(kersz, mir, tn, globnorm); 01131 } else if (!strcmp(ppchan, "YUVn")) { 01132 chanmodule = new rgb_to_yuvn_module<T,Tstate>(kersz, mir, tn, globnorm); 01133 } else if (!strcmp(ppchan, "RGBn")) { 01134 chanmodule = new rgb_to_rgbn_module<T,Tstate>(kersz, mir, tn, globnorm); 01135 } else if (!strcmp(ppchan, "YUV")) { 01136 chanmodule = new rgb_to_yuv_module<T,Tstate>(globnorm); 01137 } else if (!strcmp(ppchan, "HSV")) { 01138 eblerror("HSV pp module not implemented"); 01139 } else if (!strcmp(ppchan, "RGB")) { 01140 // no preprocessing module, just set min/max val for display 01141 minval = (T) 0; 01142 maxval = (T) 255; 01143 } else eblerror("undefined channel preprocessing " << ppchan); 01144 // initialize resizing method 01145 uint resiz = 0; 01146 if (!strcmp(resize_method, "bilinear")) resiz = BILINEAR_RESIZE; 01147 else if (!strcmp(resize_method, "gaussian")) resiz = GAUSSIAN_RESIZE; 01148 else if (!strcmp(resize_method, "mean")) resiz = MEAN_RESIZE; 01149 else eblerror("undefined resizing method" << resize_method); 01150 // create resizing module 01151 // fovea resize 01152 if (fovea && fovea->size() > 0) { 01153 if (!fovea_scale_size || fovea_scale_size->size() != fovea->size()) 01154 eblerror("expected same number of parameters in fovea and " 01155 << "fovea_scale_size"); 01156 ppmodule = new fovea_module<T,Tstate>(*fovea, *fovea_scale_size, d, true, 01157 resiz, chanmodule); 01158 name << "_fovea" << fovea->size(); 01159 } else if (lpyramid > 0) { // laplacian pyramid resize 01160 laplacian_pyramid_module<T,Tstate> *pyr = 01161 new laplacian_pyramid_module<T,Tstate> 01162 (lpyramid, kers, dims, resiz, chanmodule, false, NULL, globnorm, 01163 locnorm, locnorm2, color_lnorm, cnorm_across, keep_aspect_ratio); 01164 if (scalings) pyr->set_scalings(*scalings); 01165 ppmodule = pyr; 01166 if (!locnorm) name << "_nolnorm"; 01167 if (!locnorm2) name << "_nolnorm2"; 01168 if (color_lnorm) { 01169 name << "_colorlnorm"; 01170 if (cnorm_across) name << "across"; 01171 } 01172 name << "_lpyramid" << lpyramid; 01173 } else // regular resize 01174 ppmodule = new resizepp_module<T,Tstate>(d, resiz, chanmodule, true, 01175 NULL, keep_aspect_ratio); 01176 ppmodule->set_scale_factor(hscale, wscale); 01177 ppmodule->set_display_range(minval, maxval); 01178 ppmodule->set_name(name.c_str()); 01179 ppmodule->set_zpad(zpads); 01180 return ppmodule; 01181 } 01182 01183 template <typename T, class Tstate> 01184 resizepp_module<T,Tstate>* 01185 create_preprocessing(const char *ppchan, 01186 midxdim &kers, midxdim &zpads, const char *resize_method, 01187 bool keep_aspect_ratio, int lpyramid, 01188 vector<double> *fovea, midxdim *fovea_scale_size, 01189 bool globnorm, bool locnorm, 01190 bool locnorm2, bool color_lnorm, bool cnorm_across, 01191 double hscale, double wscale, vector<float> *scalings) { 01192 midxdim d; 01193 d.push_back(idxdim(0, 0)); 01194 return create_preprocessing<T,Tstate> 01195 (d, ppchan, kers, zpads, resize_method, keep_aspect_ratio, lpyramid, 01196 fovea, fovea_scale_size, globnorm, locnorm, locnorm2, color_lnorm, 01197 cnorm_across, hscale, wscale, scalings); 01198 } 01199 01200 // select network based on configuration, using old-style variables 01201 template <typename T, class Tstate> 01202 module_1_1<T,Tstate>* create_network_old(parameter<T, Tstate> &theparam, 01203 configuration &conf, int noutputs) { 01204 string net_type = conf.get_string("net_type"); 01205 // load custom tables if defined 01206 string mname; 01207 idx<intg> t0(1,1), t1(1,1), t2(1,1), 01208 *table0 = NULL, *table1 = NULL, *table2 = NULL; 01209 intg thick = -1; 01210 mname = "conv0"; 01211 if (load_table(conf, mname, t0, thick, noutputs)) 01212 table0 = &t0; 01213 mname = "conv1"; 01214 if (load_table(conf, mname, t1, thick, noutputs)) 01215 table1 = &t1; 01216 mname = "conv2"; 01217 if (load_table(conf, mname, t2, thick, noutputs)) 01218 table2 = &t2; 01219 // create networks 01220 // cscscf //////////////////////////////////////////////////////////////// 01221 if (!strcmp(net_type.c_str(), "cscscf")) { 01222 return (module_1_1<T,Tstate>*) new lenet<T,Tstate> 01223 (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"), 01224 conf.get_uint("net_c1h"), conf.get_uint("net_c1w"), 01225 conf.get_uint("net_s1h"), conf.get_uint("net_s1w"), 01226 conf.get_uint("net_c2h"), conf.get_uint("net_c2w"), 01227 conf.get_uint("net_s2h"), conf.get_uint("net_s2w"), 01228 conf.get_uint("net_full"), noutputs, 01229 conf.get_bool("absnorm"), conf.get_bool("color"), 01230 conf.get_bool("mirror"), conf.get_bool("use_tanh"), 01231 conf.exists_true("use_shrink"), conf.exists_true("use_diag"), 01232 table0, table1, table2); 01233 // cscsc //////////////////////////////////////////////////////////////// 01234 } else if (!strcmp(net_type.c_str(), "cscsc")) { 01235 return (module_1_1<T,Tstate>*) new lenet_cscsc<T,Tstate> 01236 (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"), 01237 conf.get_uint("net_c1h"), conf.get_uint("net_c1w"), 01238 conf.get_uint("net_s1h"), conf.get_uint("net_s1w"), 01239 conf.get_uint("net_c2h"), conf.get_uint("net_c2w"), 01240 conf.get_uint("net_s2h"), conf.get_uint("net_s2w"), 01241 noutputs, conf.get_bool("absnorm"), conf.get_bool("color"), 01242 conf.get_bool("mirror"), conf.get_bool("use_tanh"), 01243 conf.exists_true("use_shrink"), conf.exists_true("use_diag"), 01244 conf.exists_true("norm_pos"), table0, table1, table2); 01245 // cscf //////////////////////////////////////////////////////////////// 01246 } else if (!strcmp(net_type.c_str(), "cscf")) { 01247 return (module_1_1<T,Tstate>*) new lenet_cscf<T,Tstate> 01248 (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"), 01249 conf.get_uint("net_c1h"), conf.get_uint("net_c1w"), 01250 conf.get_uint("net_s1h"), conf.get_uint("net_s1w"), 01251 conf.get_uint("net_c2h"), conf.get_uint("net_c2w"), 01252 noutputs, conf.get_bool("absnorm"), conf.get_bool("color"), 01253 conf.get_bool("mirror"), conf.get_bool("use_tanh"), 01254 conf.exists_true("use_shrink"), conf.exists_true("use_diag"), 01255 table0, table1); 01256 // cscc //////////////////////////////////////////////////////////////// 01257 } else if (!strcmp(net_type.c_str(), "cscc")) { 01258 if (!table0 || !table1 || !table2) 01259 eblerror("undefined connection tables"); 01260 return (module_1_1<T,Tstate>*) new net_cscc<T,Tstate> 01261 (theparam, conf.get_uint("net_ih"), conf.get_uint("net_iw"), 01262 conf.get_uint("net_c1h"), conf.get_uint("net_c1w"), *table0, 01263 conf.get_uint("net_s1h"), conf.get_uint("net_s1w"), 01264 conf.get_uint("net_c2h"), conf.get_uint("net_c2w"), *table1, 01265 *table2, noutputs, conf.get_bool("absnorm"), 01266 conf.get_bool("mirror"), conf.get_bool("use_tanh"), 01267 conf.exists_true("use_shrink"), conf.exists_true("use_diag")); 01268 } else { 01269 cerr << "network type: " << net_type << endl; 01270 eblerror("unknown network type"); 01271 } 01272 return NULL; 01273 } 01274 01276 template <class Tmodule, typename T, class Tstate> 01277 bool load_module(configuration &conf, module_1_1<T,Tstate> &m, 01278 const string &module_name, const string &type) { 01279 if (!dynamic_cast<Tmodule*>(&m)) 01280 eblerror("cannot cast module " << module_name << " (\"" << m.name() 01281 << "\") into a " << type << " type"); 01282 string name = module_name; name << "_weights"; 01283 if (!conf.exists(name)) 01284 return false; // do nothing if variable not found 01285 string filename = conf.get_string(name.c_str()); 01286 idx<T> w = load_matrix<T>(filename); 01287 m.load_x(w); 01288 cout << "Loaded weights " << w << " into " << module_name << " from " 01289 << filename << " (dims " << w << " min " << idx_min(w) << " max " 01290 << idx_max(w) << " mean " << idx_mean(w) << ")" << endl; 01291 return true; 01292 } 01293 01294 // select network based on configuration 01295 template <typename T, class Tstate> 01296 uint manually_load_network(layers<T,Tstate> &l, configuration &conf, 01297 const char *varname) { 01298 list<string> arch = string_to_stringlist(conf.get_string(varname)); 01299 uint arch_size = arch.size(); 01300 cout << "Loading network manually using module list: " 01301 << conf.get_string(varname) << endl; 01302 uint n = 0; 01303 // loop over each module 01304 for (uint i = 0; i < arch_size; ++i) { 01305 // get first module name of the list and remove it from list 01306 string name = arch.front(); arch.pop_front(); 01307 string type = strip_last_num(name); 01308 module_1_1<T,Tstate> *m = l.modules[i]; 01309 // switch on each possible type of module 01310 if (!type.compare("conv")) 01311 n += load_module<convolution_module<T,Tstate>,T,Tstate> 01312 (conf, *m, name, type); 01313 else if (!type.compare("addc")) 01314 n += load_module<addc_module<T,Tstate>,T,Tstate>(conf, *m, name, type); 01315 else if (!type.compare("linear")) 01316 n += load_module<linear_module<T,Tstate>,T,Tstate>(conf, *m, name,type); 01317 else if (!type.compare("diag")) 01318 n += load_module<diag_module<T,Tstate>,T,Tstate>(conf, *m, name, type); 01319 else if (!type.compare("ms")) { 01320 ms_module<T,Tstate> *msm = dynamic_cast<ms_module<T,Tstate>*>(m); 01321 if (!msm) 01322 eblerror("expected a ms module while trying to load module " 01323 << name << " but found: "<< typeid(m).name()); 01324 for (uint pi = 0; pi < msm->npipes(); ++pi) { 01325 module_1_1<T,Tstate> *pipe = msm->get_pipe(pi); 01326 if (!pipe) continue ; 01327 if (!dynamic_cast<layers<T,Tstate>*>(pipe)) 01328 eblerror("expected a layers module in pipes[" << pi << "] while " 01329 << "trying to load module " << pipe->name() 01330 << " but found: " << typeid(pipe).name()); 01331 n += manually_load_network(*((layers<T,Tstate>*)pipe), conf, 01332 pipe->name()); 01333 } 01334 } 01335 else if (!type.compare("branch")) { 01336 if (!dynamic_cast<layers<T,Tstate>*>(m)) 01337 eblerror("expected a layers module with type branch while trying " 01338 << "to load module " << name << " but found: " 01339 << typeid(m).name()); 01340 n += manually_load_network(*((layers<T,Tstate>*)m), conf, 01341 name.c_str()); 01342 } 01343 } 01344 if (!l.is_branch()) 01345 cout << "Loaded " << n << " weights." << endl; 01346 return n; 01347 } 01348 01349 } // end namespace ebl 01350 01351 #endif /* NETCONF_HPP_ */