libeblearn
|
00001 /*************************************************************************** 00002 * Copyright (C) 2011 by Pierre Sermanet * 00003 * pierre.sermanet@gmail.com * 00004 * All rights reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions are met: 00008 * * Redistributions of source code must retain the above copyright 00009 * notice, this list of conditions and the following disclaimer. 00010 * * Redistributions in binary form must reproduce the above copyright 00011 * notice, this list of conditions and the following disclaimer in the 00012 * documentation and/or other materials provided with the distribution. 00013 * * Redistribution under a license not approved by the Open Source 00014 * Initiative (http://www.opensource.org) must display the 00015 * following acknowledgement in all advertising material: 00016 * This product includes software developed at the Courant 00017 * Institute of Mathematical Sciences (http://cims.nyu.edu). 00018 * * The names of the authors may not be used to endorse or promote products 00019 * derived from this software without specific prior written permission. 00020 * 00021 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED 00022 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 00023 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 00024 * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY 00025 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 00026 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 00027 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 00028 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00029 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 ***************************************************************************/ 00032 00033 #ifndef DATASOURCE_HPP_ 00034 #define DATASOURCE_HPP_ 00035 00036 #include <ostream> 00037 #include <algorithm> 00038 00039 using namespace std; 00040 00041 namespace ebl { 00042 00044 // datasource 00045 00046 template <typename Tnet, typename Tdata> 00047 datasource<Tnet,Tdata>::datasource() { 00048 } 00049 00050 template <typename Tnet, typename Tdata> 00051 datasource<Tnet,Tdata>:: 00052 datasource(midx<Tdata> &data_, const char *name_) { 00053 multimat = true; // data matrix is composed of multiple matrices 00054 init(data_, name_); 00055 init_epoch(); 00056 pretty(); // print information about the dataset 00057 } 00058 00059 template <typename Tnet, typename Tdata> 00060 datasource<Tnet,Tdata>:: 00061 datasource(idx<Tdata> &data_, const char *name_) { 00062 multimat = false; // data matrix is composed of multiple matrices 00063 init(data_, name_); 00064 init_epoch(); 00065 pretty(); // print information about the dataset 00066 } 00067 00068 template <typename Tnet, typename Tdata> 00069 datasource<Tnet,Tdata>:: 00070 datasource(const char *data_fname, const char *name_) { 00071 try { 00072 if (has_multiple_matrices(data_fname)) { 00073 multimat = true; 00074 midx<Tdata> datas_ = load_matrices<Tdata>(data_fname); 00075 init(datas_, name_); 00076 } else { 00077 multimat = false; 00078 idx<Tdata> data_ = load_matrix<Tdata>(data_fname); 00079 init(data_, name_); 00080 } 00081 init_epoch(); 00082 pretty(); // print information about the dataset 00083 } eblcatcherror(); 00084 } 00085 00086 template <typename Tnet, typename Tdata> 00087 datasource<Tnet,Tdata>::~datasource() { 00088 } 00089 00091 // init methods 00092 00093 template<typename Tnet, typename Tdata> 00094 void datasource<Tnet,Tdata>:: 00095 init(midx<Tdata> &datas_, const char *name_) { 00096 datas = datas_; 00097 data = (idx<Tdata>&) datas_; 00098 multimat = true; 00099 init2(name_); 00100 } 00101 00102 template<typename Tnet, typename Tdata> 00103 void datasource<Tnet,Tdata>:: 00104 init(idx<Tdata> &data_, const char *name_) { 00105 multimat = false; 00106 data = data_; 00107 init2(name_); 00108 } 00109 00110 template<typename Tnet, typename Tdata> 00111 void datasource<Tnet,Tdata>:: 00112 init2(const char *name_) { 00113 // init randomization 00114 if (!drand_ini) // only re-init if not initialized 00115 dynamic_init_drand(); // initialize random seed 00116 // no bias and coeff by default (0 and 1) 00117 bias = (Tnet) 0; 00118 coeff = (Tnet) 1.0; 00119 // iterating 00120 it = 0; 00121 it_test = 0; 00122 it_train = 0; 00123 shuffle_passes = false; 00124 test_set = false; 00125 epoch_sz = 0; 00126 epoch_cnt = 0; 00127 epoch_pick_cnt = 0; 00128 epoch_mode = 1; // default (1): all samples are seen at least once. 00129 hardest_focus = false; 00130 _ignore_correct = false; 00131 // state saving 00132 state_saved = false; 00133 // buffers assigments/allocations 00134 indices = idx<intg>(data.dim(0)); 00135 indices_saved = idx<intg>(data.dim(0)); 00136 probas = idx<double>(data.dim(0)); 00137 energies = idx<double>(data.dim(0)); 00138 raw_outputs = idx<Tnet>(1, 1); 00139 pick_count = idx<uint>(data.dim(0)); 00140 correct = idx<ubyte>(data.dim(0)); 00141 answers = idx<Tnet>(1, 1); 00142 targets = idx<Tnet>(1, 1); 00143 // pickings 00144 idx_clear(pick_count); 00145 count_pickings = true; 00146 sample_min_proba = 0.0; 00147 // intialize buffers 00148 idx_fill(correct, 0); 00149 idx_fill(answers, 0); 00150 idx_fill(targets, 0); 00151 idx_fill(probas, 1.0); // default picking probability for a sample is 1 00152 idx_fill(energies, -1.0); 00153 idx_fill(raw_outputs, 0); 00154 _name = (name_ ? name_ : "Unknown Dataset"); 00155 // iterating 00156 set_shuffle_passes(true); // for next_train only 00157 set_weigh_samples(true, true, true, 0.0); // for next_train only 00158 seek_begin(); 00159 seek_begin_train(); 00160 epoch_sz = size(); //get_lowest_common_size(); 00161 epoch_mode = 1; 00162 cout << _name << ": Each training epoch sees " << epoch_sz 00163 << " samples." << endl; 00164 not_picked = 0; 00165 epoch_show = 50; // print epoch count message every epoch_show 00166 epoch_show_printed = -1; // last epoch count we have printed 00167 // fill indices with original data order 00168 for (it = 0; it < data.dim(0); ++it) 00169 indices.set(it, it); 00170 // set sample dimensions 00171 if (multimat) { 00172 bool found = false; 00173 uint i = 0; 00174 idx<Tdata> e; 00175 if (datas.order() == 2) { 00176 for (intg i = 0; i < datas.dim(0) && !found; ++i) { 00177 for (intg j = 0; j < datas.dim(1); ++j) { 00178 if (datas.exists(i, j)) { 00179 found = true; 00180 e = datas.get(i, j); 00181 samplemfdims.push_back_new(e.get_idxdim()); 00182 } 00183 } 00184 } 00185 } else { 00186 while (!found && i < datas.dim(0)) { 00187 if (datas.exists(i)) { 00188 found = true; 00189 e = datas.get(i); 00190 samplemfdims.push_back_new(e.get_idxdim()); 00191 } 00192 cout << endl; 00193 i++; 00194 } 00195 } 00196 if (!found) 00197 eblerror("no sample found in multi-matrix data " << datas); 00198 sampledims = e.get_idxdim(); 00199 } else 00200 sampledims = idxdim(data.select(0, 0)); 00201 if (sampledims.order() == 2) 00202 sampledims.insert_dim(0, 1); 00203 if (sampledims.order() > 2) { 00204 height = sampledims.dim(1); 00205 width = sampledims.dim(2); 00206 } 00207 // initialize index to 0 00208 it = 0; 00209 // shuffle data indices 00210 shuffle(); 00211 bkeep_outputs = false; 00212 } 00213 00215 // accessors 00216 00217 template <typename Tnet, typename Tdata> 00218 unsigned int datasource<Tnet,Tdata>::size() { 00219 if (multimat) 00220 return datas.dim(0); 00221 return data.dim(0); 00222 } 00223 00224 template <typename Tnet, typename Tdata> 00225 idxdim datasource<Tnet,Tdata>::sample_dims() { 00226 return sampledims; 00227 } 00228 00229 template <typename Tnet, typename Tdata> 00230 mfidxdim datasource<Tnet,Tdata>::sample_mfdims() { 00231 return samplemfdims; 00232 } 00233 00234 template <typename Tnet, typename Tdata> 00235 string& datasource<Tnet,Tdata>::name() { 00236 return _name; 00237 } 00238 00239 template <typename Tnet, typename Tdata> 00240 void datasource<Tnet,Tdata>::set_test() { 00241 test_set = true; 00242 cout << _name << ": This is a testing set only." << endl; 00243 } 00244 00245 template <typename Tnet, typename Tdata> 00246 bool datasource<Tnet,Tdata>::is_test() { 00247 return test_set; 00248 } 00249 00250 template <typename Tnet, typename Tdata> 00251 intg datasource<Tnet,Tdata>::get_epoch_size() { 00252 return epoch_sz; 00253 } 00254 00255 template <typename Tnet, typename Tdata> 00256 intg datasource<Tnet,Tdata>::get_epoch_count() { 00257 return epoch_cnt; 00258 } 00259 00260 template <typename Tnet, typename Tdata> 00261 void datasource<Tnet,Tdata>::set_epoch_size(intg sz) { 00262 cout << _name << ": Setting epoch size to " << sz << endl; 00263 epoch_sz = sz; 00264 } 00265 00266 template <typename Tnet, typename Tdata> 00267 void datasource<Tnet,Tdata>::set_epoch_mode(uint mode) { 00268 epoch_mode = mode; 00269 cout << _name << ": Setting epoch mode to " << epoch_mode; 00270 switch (epoch_mode) { 00271 case 0: cout << " (fixed number of samples)" << endl; break ; 00272 case 1: cout << " (see all samples at least once)" << endl; break ; 00273 default: eblerror("unknown mode"); 00274 } 00275 } 00276 00278 // data access methods 00279 00280 template <typename Tnet, typename Tdata> template <class Tstate> 00281 void datasource<Tnet,Tdata>::fprop_data(mstate<Tstate> &out) { 00282 // number of states to put in out 00283 uint nstates = 1; 00284 if (multimat && datas.order() == 2) { 00285 // count actual number of submatrices for this sample 00286 for (nstates = 0; nstates < datas.dim(1); ++nstates) 00287 if (!datas.exists(it, nstates)) break ; 00288 EDEBUG("Number of submatrices for sample " << it << ": " << nstates); 00289 } 00290 // reallocate if necessary 00291 if (out.size() != nstates) { 00292 out.clear(); 00293 idxdim d = sampledims; 00294 d.setdims(1); 00295 for (uint i = 0; i < nstates; ++i) { 00296 Tstate ts(d); 00297 out.push_back(new Tstate(ts)); 00298 } 00299 } 00300 // copy data 00301 if (multimat) { // multiple matrices per sample 00302 for (uint i = 0; i < nstates; ++i) { 00303 idx<Tdata> dat; 00304 if (datas.order() == 2) 00305 dat = datas.get(it, i); 00306 else if (datas.order() == 1) 00307 dat = datas.get(it); 00308 else eblerror("not implemented"); 00309 // resize if necessary 00310 idxdim d(dat); 00311 Tstate &ts = out[i]; 00312 if (ts.x.get_idxdim() != d) 00313 ts.resize(d); 00314 // copy (and cast) data 00315 idx_copy(dat, ts.x); 00316 if (bias != 0.0) 00317 idx_addc(ts.x, bias, ts.x); 00318 if (coeff != 1.0) 00319 idx_dotc(ts.x, coeff, ts.x); 00320 } 00321 } else { // single matrix per sample 00322 // resize if necessary 00323 Tstate &ts = out[0]; 00324 if (ts.x.get_idxdim() != sampledims) 00325 ts.resize(sampledims); 00326 // copy data 00327 idx<Tdata> dat = data[it]; 00328 idx_copy(dat, ts.x); 00329 if (bias != 0.0) 00330 idx_addc(ts.x, bias, ts.x); 00331 if (coeff != 1.0) 00332 idx_dotc(ts.x, coeff, ts.x); 00333 } 00334 } 00335 00336 template <typename Tnet, typename Tdata> 00337 void datasource<Tnet,Tdata>::fprop_data(fstate_idx<Tnet> &out) { 00338 if (out.x.order() != sampledims.order()) 00339 out = fstate_idx<Tnet>(sampledims); 00340 else 00341 out.resize(sampledims); 00342 idx<Tdata> dat; 00343 if (multimat) 00344 dat = datas.get(it); 00345 else 00346 dat = data[it]; 00347 idx_copy(dat, out.x); 00348 if (bias != 0.0) 00349 idx_addc(out.x, bias, out.x); 00350 if (coeff != 1.0) 00351 idx_dotc(out.x, coeff, out.x); 00352 } 00353 00354 template <typename Tnet, typename Tdata> 00355 void datasource<Tnet,Tdata>::fprop_data(bbstate_idx<Tnet> &out) { 00356 if (out.x.order() != sampledims.order()) 00357 out = bbstate_idx<Tnet>(sampledims); 00358 else 00359 out.resize(sampledims); 00360 fprop_data((fstate_idx<Tnet>&) out); 00361 } 00362 00363 template <typename Tnet, typename Tdata> 00364 void datasource<Tnet,Tdata>::fprop(bbstate_idx<Tnet> &out) { 00365 if (out.x.order() != sampledims.order()) 00366 out = bbstate_idx<Tnet>(sampledims); 00367 else 00368 out.resize(sampledims); 00369 fprop_data((fstate_idx<Tnet>&) out); 00370 } 00371 00372 template <typename Tnet, typename Tdata> 00373 idx<Tdata> datasource<Tnet,Tdata>::get_sample(intg index) { 00374 if (multimat) 00375 return datas.get(index); 00376 else 00377 return data[index]; 00378 } 00379 00380 template <typename Tnet, typename Tdata> 00381 idx<Tnet> datasource<Tnet,Tdata>::get_raw_output(intg index) { 00382 if (index >= 0) 00383 return raw_outputs[index]; 00384 return raw_outputs[it]; 00385 } 00386 00388 // iterating methods 00389 00390 template <typename Tnet, typename Tdata> 00391 void datasource<Tnet,Tdata>::select_sample(intg index) { 00392 if (index < 0 || index >= data.dim(0)) 00393 eblthrow("cannot select index " << index 00394 << " in datasource of dimensions " << data); 00395 it = index; 00396 } 00397 00398 template <typename Tnet, typename Tdata> 00399 void datasource<Tnet,Tdata>::shuffle() { 00400 // shuffle indices to the data 00401 idx_shuffle(indices); 00402 } 00403 00404 template <typename Tnet, typename Tdata> 00405 bool datasource<Tnet,Tdata>::next() { 00406 // increment test iterator 00407 it_test++; 00408 // reset if reached end 00409 if (it_test >= data.dim(0)) { 00410 seek_begin(); 00411 return false; 00412 } 00413 // set main iterator used by fprop 00414 it = it_test; 00415 return true; 00416 } 00417 00418 template <typename Tnet, typename Tdata> 00419 bool datasource<Tnet,Tdata>::next_train() { 00420 // check that this datasource is allowed to call this method 00421 if (test_set) 00422 eblerror("forbidden call of next_train() on testing sets"); 00423 bool pick = false; 00424 not_picked++; 00425 // increment iterator 00426 it_train++; 00427 // reset if reached end 00428 if (it_train >= indices.dim(0)) { 00429 if (shuffle_passes) 00430 shuffle(); // shuffle indices to the data 00431 // reset iterator 00432 seek_begin_train(); 00433 // normalize probabilities, mapping [0..max] to [0..1] 00434 if (weigh_samples) 00435 normalize_probas(); 00436 } 00437 it = indices.get(it_train); // set main iterator to the train iterator 00438 // recursively loop until we find a sample that is picked for this class 00439 pick = this->pick_current(); 00440 epoch_cnt++; 00441 if (pick) { 00442 #ifdef __DEBUG__ 00443 cout << "Picking sample " << it << ", pickings: " << pick_count.get(it) 00444 << ", energy: " << energies.get(it) << ", correct: " 00445 << (int) correct.get(it); 00446 if (weigh_samples) cout << ", proba: " << probas.get(it) << ")"; 00447 cout << endl; 00448 #endif 00449 // increment pick counter for this sample 00450 if (count_pickings) pick_count.set(pick_count.get(it) + 1, it); 00451 // increment sample counter 00452 epoch_pick_cnt++; 00453 not_picked = 0; 00454 return true; 00455 } else { 00456 EDEBUG("Not picking sample " << it << ", pickings: " << pick_count.get(it) 00457 << ", energy: " << energies.get(it) << ", correct: " 00458 << (int) correct.get(it) << ", proba: " << probas.get(it) << ")"); 00459 return false; 00460 } 00461 } 00462 00463 // accessors /////////////////////////////////////////////////////////////// 00464 00465 template <typename Tnet, typename Tdata> 00466 void datasource<Tnet,Tdata>::set_data_bias(Tnet b) { 00467 bias = b; 00468 cout << _name << ": Setting data bias to " << bias << endl; 00469 } 00470 00471 template <typename Tnet, typename Tdata> 00472 void datasource<Tnet,Tdata>::set_data_coeff(Tnet c) { 00473 coeff = c; 00474 cout << _name << ": Setting data coefficient to " << coeff << endl; 00475 } 00476 00477 template <typename Tnet, typename Tdata> 00478 bool datasource<Tnet,Tdata>::epoch_done() { 00479 switch (epoch_mode) { 00480 case 0: // fixed number of samples 00481 if (epoch_cnt >= epoch_sz) 00482 return true; 00483 break ; 00484 case 1: // see all samples at least once 00485 // TODO: same as case 0? 00486 if (epoch_cnt >= epoch_sz) 00487 return true; 00488 break ; 00489 default: eblerror("unknown epoch_mode"); 00490 } 00491 return false; 00492 } 00493 00494 template <typename Tnet, typename Tdata> 00495 void datasource<Tnet,Tdata>::init_epoch() { 00496 epoch_cnt = 0; 00497 epoch_pick_cnt = 0; 00498 epoch_timer.restart(); 00499 epoch_show_printed = -1; // last epoch count we have printed 00500 // if we have prior information about each sample energy and classification 00501 // let's use it to initialize the picking probabilities. 00502 if (weigh_samples) 00503 this->normalize_all_probas(); 00504 } 00505 00506 template <typename Tnet, typename Tdata> 00507 void datasource<Tnet,Tdata>::seek_begin() { 00508 it_test = 0; // reset test iterator 00509 it = it_test; // set main iterator to test iterator 00510 test_timer.restart(); 00511 } 00512 00513 template <typename Tnet, typename Tdata> 00514 void datasource<Tnet,Tdata>::seek_begin_train() { 00515 // reset train iterator 00516 it_train = 0; 00517 // set main iterator to train iterator 00518 it = indices.get(it_train); 00519 } 00520 00521 template <typename Tnet, typename Tdata> 00522 void datasource<Tnet,Tdata>::set_shuffle_passes(bool activate) { 00523 shuffle_passes = activate; 00524 cout << _name 00525 << ": Shuffling of samples (training only) after each pass is " 00526 << (shuffle_passes ? "activated" : "deactivated") << "." << endl; 00527 } 00528 00530 // picking probability methods 00531 00532 template <typename Tnet, typename Tdata> 00533 void datasource<Tnet,Tdata>::normalize_all_probas() { 00534 if (weigh_samples) 00535 normalize_probas(); 00536 } 00537 00538 template <typename Tnet, typename Tdata> 00539 void datasource<Tnet,Tdata>::normalize_probas(vector<intg> *cindices) { 00540 double maxproba = 0, minproba = (numeric_limits<double>::max)(); 00541 double maxenergy = 0, sum = 0; //, energy_ratio, maxenergy2; 00542 vector<intg> allindices; 00543 if (weigh_samples && !is_test()) { 00544 if (!cindices) { // use all samples 00545 cout << _name << ": Normalizing all probabilities"; 00546 allindices.resize(energies.dim(0)); // allocate 00547 for (intg i = 0; i < energies.dim(0); ++i) 00548 allindices[i] = i; 00549 cindices = &allindices; 00550 } 00551 idx<double> sorted_energies(cindices->size()); 00552 // normalize probas for this class, mapping [0..max] to [0..1] 00553 maxenergy = 0; sum = 0; 00554 intg nincorrect = 0, ncorrect = 0, i = 0; 00555 // get max and sum 00556 for (vector<intg>::iterator j = cindices->begin(); 00557 j != cindices->end(); ++j) { 00558 // don't take correct ones into account 00559 if (energies.get(*j) < 0) // energy not set yet 00560 continue ; 00561 if (correct.get(*j) == 1) { // correct 00562 ncorrect++; 00563 if (_ignore_correct) 00564 continue ; // skip this one 00565 } else 00566 nincorrect++; 00567 // max and sum 00568 maxenergy = (std::max)(energies.get(*j), maxenergy); 00569 sum += energies.get(*j); 00570 sorted_energies.set(energies.get(*j), i++); 00571 } 00572 cout << ", nincorrect: " << nincorrect << ", ncorrect: " << ncorrect; 00573 // no incorrect set all to 1 00574 if (!nincorrect) { 00575 idx_fill(probas, 1.0); 00576 cout << endl; 00577 return ; 00578 } 00579 // We choose 2 pivot points in the sorted energies curve, 00580 // one will be used as maximum energy and the other as the minimum 00581 // energy. This helps to have a meaningful range of energies not 00582 // biased by single extrema. 00583 double e1, e2; 00584 sorted_energies.resize(nincorrect); 00585 idx_sortup(sorted_energies); 00586 intg pivot1 = (intg) (sorted_energies.dim(0) * (float) .25); 00587 intg pivot2 = std::min(sorted_energies.dim(0) - 1, 00588 (intg) (sorted_energies.dim(0)*(float)1.0));//.75); 00589 if (sorted_energies.dim(0) == 0) { 00590 e1 = 0; e2 = 1; 00591 } else { 00592 e1 = sorted_energies.get(pivot1); 00593 e2 = sorted_energies.get(pivot2); 00594 } 00595 // the ratio of total energies over n times the max energy 00596 //energy_ratio = sum / (maxenergy * cindices->size()); 00597 // the max probability will be proportional to the energy ratio 00598 // this balances the probabilities so that outliers don't take 00599 // all the probabilites 00600 //maxenergy2 = maxenergy * energy_ratio; 00601 cout << ", max energy: " << maxenergy; 00602 // << ", energy ratio " << energy_ratio 00603 // << " and normalized max energy " << maxenergy2; 00604 // normalize 00605 for (vector<intg>::iterator j = cindices->begin(); 00606 j != cindices->end(); ++j) { 00607 double e = energies.get(*j); 00608 // set proba 0 for correct samples if we ignore correct ones 00609 if (e >= 0 && _ignore_correct && correct.get(*j) == 1) 00610 probas.set(0.0, *j); 00611 else { 00612 // compute probas 00613 double den = e2 - e1; 00614 if (e < 0 || maxenergy == 0 || den == 0) // energy not set yet 00615 probas.set(1.0, *j); 00616 else { 00617 probas.set((std::max)((double) 0, (std::min)((e - e1) / den, 00618 (double) 1)), *j); 00619 if (!hardest_focus) // concentrate on easiest misclassified 00620 probas.set(1 - probas.get(*j), *j); // reverse proba 00621 // iprobas.at(*j)->set((std::max)(sample_min_proba, 00622 // e / maxenergy2)); 00623 // remember min and max proba 00624 maxproba = (std::max)(probas.get(*j), maxproba); 00625 minproba = (std::min)(probas.get(*j), minproba); 00626 } 00627 } 00628 } 00629 cout << ", Min/Max probas are: " << minproba << ", " 00630 << maxproba << endl; 00631 } 00632 } 00633 00634 template <typename Tnet, typename Tdata> 00635 void datasource<Tnet,Tdata>::set_sample_energy(double e, bool correct_, 00636 idx<Tnet> &raw_, 00637 idx<Tnet> &answers_, 00638 idx<Tnet> &target) { 00639 energies.set(e, it); 00640 correct.set(correct_ ? 1 : 0, it); 00641 00642 // store model outputs for current sample 00643 if (bkeep_outputs) { 00644 // resize buffers if necessary 00645 idx<Tnet> ans = answers_.view_as_order(1); 00646 idx<Tnet> ra = raw_.view_as_order(1); 00647 idxdim d(ans), draw(ra); 00648 d.insert_dim(0, data.dim(0)); 00649 draw.insert_dim(0, data.dim(0)); 00650 if (raw_outputs.get_idxdim() != draw) { 00651 raw_outputs.resize(draw); 00652 idx_clear(raw_outputs); 00653 } 00654 if (answers.get_idxdim() != d) { 00655 answers.resize(d); 00656 idx_clear(answers); 00657 } 00658 if (targets.get_idxdim() != draw) { 00659 targets.resize(draw); 00660 idx_clear(targets); 00661 } 00662 // copy raw 00663 idx<Tnet> raw = raw_outputs.select(0, it); 00664 idx_copy(ra, raw); 00665 // copy answers 00666 idx<Tnet> answer = answers.select(0, it); 00667 idx_copy(ans, answer); 00668 // copy target 00669 idx<Tnet> tgt = targets.select(0, it); 00670 idx_copy(target, tgt); 00671 } 00672 } 00673 00674 template <typename Tnet, typename Tdata> 00675 void datasource<Tnet,Tdata>::keep_outputs(bool keep) { 00676 bkeep_outputs = keep; 00677 cout << (bkeep_outputs ? "Keeping" : "Not keeping") 00678 << " model outputs for each sample." << endl; 00679 } 00680 00681 template <typename Tnet, typename Tdata> 00682 void datasource<Tnet,Tdata>::save_pickings(const char *name_) { 00683 // plot file 00684 string name = "pickings"; 00685 if (name_) 00686 name = name_; 00687 string fname = name; 00688 fname += ".plot"; 00689 ofstream fp(fname.c_str()); 00690 if (!fp) { 00691 cerr << "failed to open " << fname << endl; 00692 eblerror("failed to open file for writing"); 00693 } 00694 eblerror("not implemented"); 00695 // typename idx<uint>::dimension_iterator i = pick_count.dim_begin(0); 00696 // uint j = 0; 00697 // for ( ; i.notdone(); i++, j++) 00698 // fp << j << " " << i->get() << endl; 00699 // fp.close(); 00700 cout << _name << ": Wrote picking statistics in " << fname << endl; 00701 // p file 00702 string fname2 = name; 00703 fname2 += ".p"; 00704 ofstream fp2(fname2.c_str()); 00705 if (!fp2) { 00706 cerr << "failed to open " << fname2 << endl; 00707 eblerror("failed to open file for writing"); 00708 } 00709 fp2 << "plot \"" << fname << "\" with impulse" << endl; 00710 fp2.close(); 00711 cout << _name << ": Wrote gnuplot file in " << fname2 << endl; 00712 } 00713 00714 template <typename Tnet, typename Tdata> 00715 void datasource<Tnet,Tdata>::ignore_correct(bool ignore) { 00716 _ignore_correct = ignore; 00717 if (ignore) 00718 cout << (ignore ? "Ignoring" : "Using") << 00719 " correctly classified samples for training." << endl; 00720 } 00721 00722 template <typename Tnet, typename Tdata> 00723 bool datasource<Tnet,Tdata>::mstate_samples() { 00724 return multimat; 00725 } 00726 00727 template <typename Tnet, typename Tdata> 00728 bool datasource<Tnet,Tdata>::get_count_pickings() { 00729 return count_pickings; 00730 } 00731 00732 template <typename Tnet, typename Tdata> 00733 void datasource<Tnet,Tdata>::set_count_pickings(bool count) { 00734 count_pickings = count; 00735 } 00736 00737 template <typename Tnet, typename Tdata> 00738 void datasource<Tnet,Tdata>:: 00739 set_weigh_samples(bool activate, bool hardest_focus_, bool perclass_norm_, 00740 double min_proba) { 00741 weigh_samples = activate; 00742 hardest_focus = hardest_focus_; 00743 perclass_norm = perclass_norm_; 00744 sample_min_proba = MIN(1.0, min_proba); 00745 cout << _name 00746 << ": Weighing of samples (training only) based on classification is " 00747 << (weigh_samples ? "activated" : "deactivated") << "." << endl; 00748 if (activate) { 00749 cout << _name << ": learning is focused on " 00750 << (hardest_focus ? "hardest" : "easiest") 00751 << " misclassified samples" << endl; 00752 if (!_ignore_correct && !hardest_focus) 00753 cerr << "Warning: correct samples are not ignored and focus is on " 00754 << "easiest samples, this may not be optimal" << endl; 00755 cout << "Sample picking probabilities are normalized " 00756 << (perclass_norm ? "per class" : "globally") 00757 << " with minimum probability " << sample_min_proba << endl; 00758 } 00759 } 00760 00762 // pretty methods 00763 00764 template <typename Tnet, typename Tdata> 00765 void datasource<Tnet,Tdata>::pretty_progress(bool newline) { 00766 // train pretty 00767 intg i = epoch_cnt, sz = epoch_sz; 00768 string pre = "training: "; 00769 // test pretty 00770 if (is_test()) { 00771 i = it_test; 00772 sz = this->size(); 00773 pre = "testing: "; 00774 } 00775 // common code 00776 if (epoch_show > 0 && i % epoch_show == 0 && epoch_show_printed != i) { 00777 epoch_show_printed = i; // remember last time printed 00778 cout << pre << i << " / " << sz 00779 << ", elapsed: " << test_timer.elapsed() << ", ETA: " 00780 << test_timer. 00781 elapsed((long) ((sz - i) * 00782 (test_timer.elapsed_seconds() 00783 /(double)std::max((intg)1,i)))); 00784 if (newline) 00785 cout << endl; 00786 } 00787 } 00788 00789 template <typename Tnet, typename Tdata> 00790 void datasource<Tnet,Tdata>::pretty() { 00791 cout << _name << ": dataset \"" << _name << "\" contains " << data.dim(0) 00792 << " samples of dimension " << sampledims 00793 << " and defines an epoch as " << epoch_sz << " samples." << endl; 00794 } 00795 00797 // state saving 00798 00799 template <typename Tnet, typename Tdata> 00800 void datasource<Tnet,Tdata>::save_state() { 00801 state_saved = true; 00802 count_pickings_save = count_pickings; 00803 it_saved = it; // save main iterator 00804 it_test_saved = it_test; 00805 it_train_saved = it_train; 00806 for (intg k = 0; k < indices.dim(0); ++k) 00807 indices_saved[k] = indices[k]; 00808 } 00809 00810 template <typename Tnet, typename Tdata> 00811 void datasource<Tnet,Tdata>::restore_state() { 00812 if (!state_saved) 00813 eblerror("state not saved, call save_state() before restore_state()"); 00814 count_pickings = count_pickings_save; 00815 it = it_saved; // restore main iterator 00816 it_test = it_test_saved; 00817 it_train = it_train_saved; 00818 for (intg k = 0; k < indices.dim(0); ++k) 00819 indices[k] = indices_saved[k]; 00820 } 00821 00822 template <typename Tnet, typename Tdata> 00823 void datasource<Tnet,Tdata>::set_epoch_show(uint modulo) { 00824 cout << _name << ": Print training count every " << modulo 00825 << " samples." << endl; 00826 epoch_show = modulo; 00827 } 00828 00830 // protected pickings methods 00831 00832 template <typename Tnet, typename Tdata> 00833 bool datasource<Tnet,Tdata>::pick_current() { 00834 if (test_set) // check that this datasource is allowed to call this method 00835 eblerror("forbidden call of pick_current() on testing sets"); 00836 if (!weigh_samples) // always pick sample when not using probabilities 00837 return true; 00838 // draw random number between 0 and 1 and return true if lower 00839 // than sample's probability. 00840 double r = drand(); // [0..1] 00841 if (r <= probas.get(it)) 00842 return true; 00843 return false; 00844 } 00845 00846 template <typename Tnet, typename Tdata> 00847 map<uint,intg>& datasource<Tnet,Tdata>::get_pickings() { 00848 picksmap.clear(); 00849 // typename idx<uint>::dimension_iterator i = pick_count.dim_begin(0); 00850 // uint j = 0; 00851 // for ( ; i.notdone(); i++, j++) 00852 // picksmap[i->get()] = j; 00853 eblerror("not implemented"); 00854 return picksmap; 00855 } 00856 00858 // labeled_datasource 00859 00860 template <typename Tnet, typename Tdata, typename Tlabel> 00861 labeled_datasource<Tnet, Tdata, Tlabel>::labeled_datasource() { 00862 } 00863 00864 template <typename Tnet, typename Tdata, typename Tlabel> 00865 labeled_datasource<Tnet, Tdata, Tlabel>:: 00866 labeled_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_, 00867 const char *name_) { 00868 init(data_, labels_, name_); 00869 this->init_epoch(); 00870 this->pretty(); // print information about this dataset 00871 } 00872 00873 template <typename Tnet, typename Tdata, typename Tlabel> 00874 labeled_datasource<Tnet, Tdata, Tlabel>:: 00875 labeled_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 00876 const char *name_) { 00877 init(data_, labels_, name_); 00878 this->init_epoch(); 00879 this->pretty(); // print information about this dataset 00880 } 00881 00882 template <typename Tnet, typename Tdata, typename Tlabel> 00883 labeled_datasource<Tnet, Tdata, Tlabel>:: 00884 labeled_datasource(const char *root_dsname, const char *name_) { 00885 init_root(root_dsname, name_); 00886 this->init_epoch(); 00887 this->pretty(); // print information about this dataset 00888 } 00889 00890 template <typename Tnet, typename Tdata, typename Tlabel> 00891 labeled_datasource<Tnet, Tdata, Tlabel>:: 00892 labeled_datasource(const char *root, const char *data_name, 00893 const char *labels_name, const char *jitters_name, 00894 const char *scales_name, const char *name_) { 00895 init_root(root, data_name, labels_name, jitters_name, scales_name, name_); 00896 this->init_epoch(); 00897 this->pretty(); // print information about this dataset 00898 } 00899 00900 template <typename Tnet, typename Tdata, typename Tlabel> 00901 labeled_datasource<Tnet, Tdata, Tlabel>::~labeled_datasource() { 00902 } 00903 00905 // init methods 00906 00907 template <typename Tnet, typename Tdata, typename Tlabel> 00908 void labeled_datasource<Tnet, Tdata, Tlabel>:: 00909 init(midx<Tdata> &data_, idx<Tlabel> &labels_, const char *name) { 00910 scales_loaded = false; 00911 datasource<Tnet,Tdata>::init(data_, name); 00912 init_labels(labels_, name); 00913 } 00914 00915 template <typename Tnet, typename Tdata, typename Tlabel> 00916 void labeled_datasource<Tnet, Tdata, Tlabel>:: 00917 init(idx<Tdata> &data_, idx<Tlabel> &labels_, const char *name) { 00918 scales_loaded = false; 00919 datasource<Tnet,Tdata>::init(data_, name); 00920 init_labels(labels_, name); 00921 } 00922 00923 template <typename Tnet, typename Tdata, typename Tlabel> 00924 void labeled_datasource<Tnet, Tdata, Tlabel>:: 00925 init_labels(idx<Tlabel> &labels_, const char *name) { 00926 labels = labels_; 00927 label_bias = 0; 00928 label_coeff = 1; 00929 // set label dimensions 00930 labeldims = labels.get_idxdim(); 00931 if (labeldims.order() > 1) 00932 labeldims.remove_dim(0); 00933 else 00934 labeldims.setdim(0, 1); 00935 } 00936 00937 template <typename Tnet, typename Tdata, typename Tlabel> 00938 void labeled_datasource<Tnet, Tdata, Tlabel>:: 00939 init(const char *data_fname, const char *labels_fname, 00940 const char *jitters_fname, const char *scales_fname, const char *name_, 00941 uint max_size) { 00942 // load jitters 00943 if (jitters_fname && strlen(jitters_fname) != 0) { 00944 try { 00945 jitters = load_matrices<float>(jitters_fname); 00946 jitters_maxdim = jitters.get_maxdim(); 00947 } 00948 catch (string &err) { cerr << "warning: " << err << endl; } 00949 } else cout << "No jitter information loaded." << endl; 00950 // load labels 00951 idx<Tlabel> lab; 00952 try { 00953 lab = load_matrix<Tlabel>(labels_fname); 00954 } catch (string &err) { 00955 cerr << err << endl; 00956 eblerror("Failed to load dataset file"); 00957 } 00958 // limit number of samples 00959 if (max_size > 0) { 00960 cout << "Limiting " << name_<< " to " << max_size << " samples." <<endl; 00961 // lab = lab.narrow(0, std::min((intg) max_size, lab.dim(0)), 0); 00962 } 00963 // load data 00964 multimat = has_multiple_matrices(data_fname); 00965 try { 00966 if (multimat) { 00967 midx<Tdata> dat = load_matrices<Tdata>(data_fname); 00968 if (max_size > 0) 00969 dat = dat.narrow(0, std::min((intg) max_size, dat.dim(0)), 0); 00970 // init 00971 labeled_datasource<Tnet, Tdata, Tlabel>::init(dat, lab, name_); 00972 } else { 00973 idx<Tdata> dat = load_matrix<Tdata>(data_fname); 00974 if (max_size > 0) 00975 dat = dat.narrow(0, std::min((intg) max_size, dat.dim(0)), 0); 00976 // init 00977 labeled_datasource<Tnet, Tdata, Tlabel>::init(dat, lab, name_); 00978 } 00979 } catch (string &err) { 00980 cerr << err << endl; 00981 eblerror("Failed to load dataset file"); 00982 } 00983 // load scales 00984 scales_loaded = false; 00985 if (scales_fname && strlen(scales_fname) != 0) { 00986 try { 00987 scales = load_matrix<intg>(scales_fname); 00988 scales_loaded = true; 00989 } 00990 catch (string &err) { cerr << "warning: " << err << endl; } 00991 } else cout << "No scale information loaded." << endl; 00992 } 00993 00994 template <typename Tnet, typename Tdata, typename Tlabel> 00995 void labeled_datasource<Tnet, Tdata, Tlabel>:: 00996 init_root(const char *root, const char *data_name, const char *labels_name, 00997 const char *jitters_name, const char *scales_name, 00998 const char *name_) { 00999 string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname; 01000 data_fname << root << "/" << data_name << "_" << DATA_NAME 01001 << MATRIX_EXTENSION; 01002 labels_fname << root << "/" << labels_name << "_" << LABELS_NAME 01003 << MATRIX_EXTENSION; 01004 if (jitters_name) 01005 jitters_fname << root << "/" << jitters_name 01006 << "_" << JITTERS_NAME << MATRIX_EXTENSION; 01007 if (scales_name) 01008 scales_fname << root << "/" << scales_name 01009 << "_" << SCALES_NAME << MATRIX_EXTENSION; 01010 init(data_fname.c_str(), labels_fname.c_str(), 01011 jitters_name ? jitters_fname.c_str() : NULL, 01012 scales_name ? scales_fname.c_str() : NULL, name_); 01013 } 01014 01015 template <typename Tnet, typename Tdata, typename Tlabel> 01016 void labeled_datasource<Tnet, Tdata, Tlabel>:: 01017 init_root(const char *root_dsname, const char *name_) { 01018 string data_fname, labels_fname, classes_fname, jitters_fname, 01019 scales_fname; 01020 data_fname << root_dsname << "_" << DATA_NAME << MATRIX_EXTENSION; 01021 labels_fname << root_dsname << "_" << LABELS_NAME << MATRIX_EXTENSION; 01022 classes_fname << root_dsname << "_" << CLASSES_NAME << MATRIX_EXTENSION; 01023 jitters_fname << root_dsname << "_" << JITTERS_NAME << MATRIX_EXTENSION; 01024 scales_fname << root_dsname << "_" << SCALES_NAME << MATRIX_EXTENSION; 01025 init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(), 01026 scales_fname.c_str(), name_); 01027 } 01028 01029 // data access /////////////////////////////////////////////////////////////// 01030 01031 template <typename Tnet, typename Tdata, typename Tlabel> 01032 void labeled_datasource<Tnet, Tdata, Tlabel>:: 01033 fprop(bbstate_idx<Tnet> &out, bbstate_idx<Tlabel> &label) { 01034 this->fprop_data(out); 01035 this->fprop_label(label); 01036 } 01037 01038 template <typename Tnet, typename Tdata, typename Tlabel> 01039 void labeled_datasource<Tnet,Tdata,Tlabel>:: 01040 fprop_label(fstate_idx<Tlabel> &label) { 01041 idx<Tlabel> lab = labels[it]; 01042 idx_copy(lab, label.x); 01043 if (label_bias != 0) 01044 idx_addc(label.x, label_bias, label.x); 01045 if (label_coeff != 1) 01046 idx_dotc(label.x, label_coeff, label.x); 01047 } 01048 01049 template <typename Tnet, typename Tdata, typename Tlabel> 01050 void labeled_datasource<Tnet,Tdata,Tlabel>:: 01051 fprop_label_net(fstate_idx<Tnet> &label) { 01052 idx<Tlabel> lab = labels[it]; 01053 idx_copy(lab, label.x); 01054 if (label_bias != 0) 01055 idx_addc(label.x, label_bias, label.x); 01056 if (label_coeff != 1) 01057 idx_dotc(label.x, label_coeff, label.x); 01058 } 01059 01060 template <typename Tnet, typename Tdata, typename Tlabel> 01061 void labeled_datasource<Tnet,Tdata,Tlabel>:: 01062 fprop_label_net(bbstate_idx<Tnet> &label) { 01063 fprop_label_net((fstate_idx<Tnet>&) label); 01064 } 01065 01066 template <typename Tnet, typename Tdata, typename Tlabel> 01067 void labeled_datasource<Tnet,Tdata,Tlabel>:: 01068 fprop_jitter(bbstate_idx<Tnet> &jitt) { 01069 if (jitters.order() < 1) eblerror("jitter information was not loaded"); 01070 if (jitters.exists(it)) { 01071 idx<float> j = jitters.get(it); 01072 idxdim d(j.get_idxdim()); 01073 if (jitt.x.get_idxdim() != d) 01074 jitt.resize(d); 01075 idx_copy(j, jitt.x); 01076 } else { // fprop an empty jitter 01077 idxdim d(jitters.get_maxdim()); 01078 d.setdim(0, 1); 01079 if (jitt.x.get_idxdim() != d) 01080 jitt.resize(d); 01081 idx_clear(jitt.x); 01082 } 01083 } 01084 01085 template <typename Tnet, typename Tdata, typename Tlabel> 01086 intg labeled_datasource<Tnet,Tdata,Tlabel>::fprop_scale() { 01087 if (!scales_loaded) eblthrow("scales information not present"); 01088 return scales.get(it); 01089 } 01090 01091 // accessors ///////////////////////////////////////////////////////////////// 01092 01093 template <typename Tnet, typename Tdata, typename Tlabel> 01094 bool labeled_datasource<Tnet,Tdata,Tlabel>::included_sample(intg index) { 01095 if (index >= data.dim(0)) 01096 eblerror("cannot check inclusion of sample " << index 01097 << ", only " << data.dim(0) << " samples"); 01098 return true; 01099 } 01100 01101 template <typename Tnet, typename Tdata, typename Tlabel> 01102 intg labeled_datasource<Tnet,Tdata,Tlabel>::count_included_samples() { 01103 return this->size(); 01104 } 01105 01106 template <typename Tnet, typename Tdata, typename Tlabel> 01107 void labeled_datasource<Tnet, Tdata, Tlabel>::pretty() { 01108 cout << _name << ": (regression) labeled dataset \"" << _name 01109 << "\" contains " 01110 << data.dim(0) << " samples of dimension " << sampledims 01111 << " and defines an epoch as " << epoch_sz << " samples."; 01112 pretty_scales(); 01113 } 01114 01115 template <typename Tnet, typename Tdata, typename Tlabel> 01116 void labeled_datasource<Tnet, Tdata, Tlabel>::pretty_scales() { 01117 if (scales_loaded) { 01118 intg maxscale = idx_max(scales); 01119 vector<intg> tally(maxscale + 1, 0); 01120 idx_bloop1(scale, scales, intg) { 01121 intg s = scale.get(); 01122 if (s < 0) eblerror("unexpected negative value"); 01123 tally[s] = tally[s] + 1; 01124 } 01125 intg nscales = 0; 01126 for (intg i = 0; i < (intg) tally.size(); ++i) 01127 if (tally[i] > 0) nscales++; 01128 // print scales distribution for each class 01129 cout << _name << ": has " << nscales << " scales"; 01130 for (intg i = 0; i < (intg) tally.size(); ++i) 01131 cout << ", " << i << ": " << tally[i]; 01132 cout << endl; 01133 } else cout << _name << ": no scales information." << endl; 01134 } 01135 01136 template <typename Tnet, typename Tdata, typename Tlabel> 01137 idxdim labeled_datasource<Tnet, Tdata, Tlabel>::label_dims() { 01138 return labeldims; 01139 } 01140 01141 template <typename Tnet, typename Tdata, typename Tlabel> 01142 void labeled_datasource<Tnet,Tdata,Tlabel>::set_label_bias(Tnet b) { 01143 label_bias = b; 01144 cout << _name << ": Setting labels bias to " << label_bias << endl; 01145 } 01146 01147 template <typename Tnet, typename Tdata, typename Tlabel> 01148 void labeled_datasource<Tnet,Tdata,Tlabel>::set_label_coeff(Tnet c) { 01149 label_coeff = c; 01150 cout << _name << ": Setting labels coefficient to " << label_coeff << endl; 01151 } 01152 01153 template <typename Tnet, typename Tdata, typename Tlabel> 01154 bool labeled_datasource<Tnet,Tdata,Tlabel>::has_scales() { 01155 return scales_loaded; 01156 } 01157 01159 // class_datasource 01160 01161 template <typename Tnet, typename Tdata, typename Tlabel> 01162 class_datasource<Tnet, Tdata, Tlabel>::class_datasource() 01163 : lblstr(NULL) { 01164 defaults(); 01165 } 01166 01167 template <typename Tnet, typename Tdata, typename Tlabel> 01168 class_datasource<Tnet, Tdata, Tlabel>:: 01169 class_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_, 01170 vector<string*> *lblstr_, const char *name_) { 01171 defaults(); 01172 init(data_, labels_, name_, lblstr_); 01173 this->init_epoch(); 01174 this->pretty(); // print info about dataset 01175 } 01176 01177 template <typename Tnet, typename Tdata, typename Tlabel> 01178 class_datasource<Tnet, Tdata, Tlabel>:: 01179 class_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 01180 vector<string*> *lblstr_, const char *name_) { 01181 defaults(); 01182 init(data_, labels_, name_, lblstr_); 01183 this->init_epoch(); 01184 this->pretty(); // print info about dataset 01185 } 01186 01187 template <typename Tnet, typename Tdata, typename Tlabel> 01188 class_datasource<Tnet, Tdata, Tlabel>:: 01189 class_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_, 01190 idx<ubyte> &classes, const char *name_) { 01191 defaults(); 01192 init_strings(classes); 01193 init(data_, labels_, this->lblstr, name_); 01194 this->init_epoch(); 01195 this->pretty(); // print info about dataset 01196 } 01197 01198 template <typename Tnet, typename Tdata, typename Tlabel> 01199 class_datasource<Tnet, Tdata, Tlabel>:: 01200 class_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 01201 idx<ubyte> &classes, const char *name_) { 01202 defaults(); 01203 init_strings(classes); 01204 init(data_, labels_, this->lblstr, name_); 01205 this->init_epoch(); 01206 this->pretty(); // print info about dataset 01207 } 01208 01209 template <typename Tnet, typename Tdata, typename Tlabel> 01210 class_datasource<Tnet, Tdata, Tlabel>:: 01211 class_datasource(const char *data_name, const char *labels_name, 01212 const char *jitters_name, const char *scales_name, 01213 const char *classes_name, const char *name_) { 01214 defaults(); 01215 init(data_name, labels_name, jitters_name, scales_name, classes_name,name_); 01216 this->init_epoch(); 01217 this->pretty(); // print info about dataset 01218 } 01219 01220 template <typename Tnet, typename Tdata, typename Tlabel> 01221 class_datasource<Tnet, Tdata, Tlabel>:: 01222 class_datasource(const class_datasource<Tnet, Tdata, Tlabel> &ds) 01223 : datasource<Tnet,Tdata>((const datasource<Tnet,Tdata>&) ds), 01224 lblstr(NULL) { 01225 defaults(); 01226 if (ds.lblstr) { 01227 this->lblstr = new vector<string*>; 01228 for (unsigned int i = 0; i < ds.lblstr->size(); ++i) { 01229 this->lblstr->push_back(new string(*ds.lblstr->at(i))); 01230 } 01231 } 01232 } 01233 01234 template <typename Tnet, typename Tdata, typename Tlabel> 01235 class_datasource<Tnet, Tdata, Tlabel>::~class_datasource() { 01236 if (lblstr) { // this class owns lblstr and its content 01237 vector<string*>::iterator i = lblstr->begin(); 01238 for ( ; i != lblstr->end(); ++i) 01239 if (*i) 01240 delete *i; 01241 delete lblstr; 01242 } 01243 } 01244 01246 // init methods 01247 01248 template <typename Tnet, typename Tdata, typename Tlabel> 01249 void class_datasource<Tnet, Tdata, Tlabel>::defaults() { 01250 balance = true; 01251 bexclusion = false; 01252 random_class_order = true; 01253 } 01254 01255 template <typename Tnet, typename Tdata, typename Tlabel> 01256 void class_datasource<Tnet, Tdata, Tlabel>:: 01257 init_strings(idx<ubyte> &classes) { 01258 this->lblstr = NULL; 01259 // load classes strings 01260 if (classes.order() == 2) { 01261 this->lblstr = new vector<string*>; 01262 idx_bloop1(classe, classes, ubyte) { 01263 this->lblstr->push_back(new string((const char*) classe.idx_ptr())); 01264 } 01265 } 01266 } 01267 01268 template <typename Tnet, typename Tdata, typename Tlabel> 01269 void class_datasource<Tnet, Tdata, Tlabel>:: 01270 init_local(vector<string*> *lblstr_) { 01271 nclasses = (intg) idx_max(labels) + 1; 01272 if (lblstr_) 01273 nclasses = std::max(nclasses, (intg) lblstr_->size()); 01274 // assign classes strings 01275 this->lblstr = lblstr_; 01276 // if no names are given and discrete, use indices as names 01277 if (!this->lblstr) { 01278 this->lblstr = new vector<string*>; 01279 ostringstream o; 01280 int imax = (int) idx_max(this->labels); 01281 for (int i = 0; i <= imax; ++i) { 01282 o << i; 01283 this->lblstr->push_back(new string(o.str())); 01284 o.str(""); 01285 } 01286 } 01287 init_class_labels(); 01288 nclasses = std::max((intg) idx_max(labels) + 1, (intg) clblstr.size()); 01289 // count number of samples per class 01290 counts.resize(nclasses); 01291 fill(counts.begin(), counts.end(), 0); 01292 idx_bloop1(lab, labels, Tlabel) { 01293 counts[(size_t)lab.gget()]++; 01294 } 01295 // balance 01296 set_balanced(true); // balance dataset for each class in next_train 01297 perclass_norm = false; 01298 included = nclasses; 01299 seek_begin(); 01300 seek_begin_train(); 01301 } 01302 01303 template <typename Tnet, typename Tdata, typename Tlabel> 01304 void class_datasource<Tnet, Tdata, Tlabel>:: 01305 init(midx<Tdata> &data_, idx<Tlabel> &labels_, vector<string*> *lblstr_, 01306 const char *name_) { 01307 labeled_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, name_); 01308 class_datasource<Tnet, Tdata, Tlabel>::init_local(lblstr_); 01309 } 01310 01311 template <typename Tnet, typename Tdata, typename Tlabel> 01312 void class_datasource<Tnet, Tdata, Tlabel>:: 01313 init(idx<Tdata> &data_, idx<Tlabel> &labels_, vector<string*> *lblstr_, 01314 const char *name_) { 01315 labeled_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, name_); 01316 class_datasource<Tnet, Tdata, Tlabel>::init_local(lblstr_); 01317 } 01318 01319 template <typename Tnet, typename Tdata, typename Tlabel> 01320 void class_datasource<Tnet, Tdata, Tlabel>:: 01321 init(const char *data_fname, const char *labels_fname, 01322 const char *jitters_fname, const char *scales_fname, 01323 const char *classes_fname, const char *name_, 01324 uint max_size) { 01325 // load classes 01326 idx<ubyte> classes; 01327 bool classes_found = false; 01328 if (classes_fname && strlen(classes_fname) != 0) { 01329 try { 01330 classes = load_matrix<ubyte>(classes_fname); 01331 classes_found = true; 01332 } catch (string &err) { cerr << "warning: " << err << endl; } 01333 } else 01334 cout << "No category names found, using numbers." << endl; 01335 // classes names are optional, use numbers by default if not specified 01336 if (classes_found) { 01337 this->lblstr = new vector<string*>; 01338 idx_bloop1(classe, classes, ubyte) { 01339 this->lblstr->push_back(new string((const char*) classe.idx_ptr())); 01340 } 01341 } 01342 // init 01343 labeled_datasource<Tnet, Tdata, Tlabel>:: 01344 init(data_fname, labels_fname, jitters_fname, scales_fname, name_, 01345 max_size); 01346 class_datasource<Tnet, Tdata, Tlabel>::init_local(this->lblstr); 01347 } 01348 01349 template <typename Tnet, typename Tdata, typename Tlabel> 01350 void class_datasource<Tnet, Tdata, Tlabel>:: 01351 init_root(const char *root, const char *data_name, const char *labels_name, 01352 const char *jitters_name, const char *scales_name, 01353 const char *classes_name, const char *name_) { 01354 string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname; 01355 data_fname << root << "/" << data_name << "_" << DATA_NAME 01356 << MATRIX_EXTENSION; 01357 labels_fname << root << "/" << labels_name << "_" << LABELS_NAME 01358 << MATRIX_EXTENSION; 01359 classes_fname << root << "/" << classes_name << "_" << CLASSES_NAME 01360 << MATRIX_EXTENSION; 01361 jitters_fname << root << "/" << (jitters_name ? jitters_name :classes_name) 01362 << "_" << JITTERS_NAME << MATRIX_EXTENSION; 01363 scales_fname << root << "/" << (scales_name ? scales_name :classes_name) 01364 << "_" << SCALES_NAME << MATRIX_EXTENSION; 01365 init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(), 01366 scales_fname.c_str(), classes_fname.c_str(), name_); 01367 } 01368 01369 template <typename Tnet, typename Tdata, typename Tlabel> 01370 void class_datasource<Tnet, Tdata, Tlabel>:: 01371 init_root(const char *root_dsname, const char *name_) { 01372 string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname; 01373 data_fname << root_dsname << "_" << DATA_NAME << MATRIX_EXTENSION; 01374 labels_fname << root_dsname << "_" << LABELS_NAME << MATRIX_EXTENSION; 01375 classes_fname << root_dsname << "_" << CLASSES_NAME << MATRIX_EXTENSION; 01376 jitters_fname << root_dsname << "_" << JITTERS_NAME << MATRIX_EXTENSION; 01377 scales_fname << root_dsname << "_" << SCALES_NAME << MATRIX_EXTENSION; 01378 init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(), 01379 scales_fname.c_str(), classes_fname.c_str(), name_); 01380 } 01381 01382 template <typename Tnet, typename Tdata, typename Tlabel> 01383 void class_datasource<Tnet, Tdata, Tlabel>::init_class_labels() { 01384 std::map<Tlabel,Tlabel> mlabs; 01385 if (olabels.get_idxdim() != labels.get_idxdim()) 01386 olabels = idx<Tlabel>(labels.get_idxdim()); 01387 idx_copy(labels, olabels); // keep original labels 01388 idx_sortup(labels); 01389 // add all classes into table 01390 intg i = 0; 01391 clblstr.clear(); 01392 { idx_bloop1(lab, labels, Tlabel) { 01393 Tlabel l = lab.gget(); 01394 if (mlabs.find(l) == mlabs.end() 01395 && (!bexclusion || !excluded[(uint)l])) { 01396 mlabs[l] = i++; 01397 clblstr.push_back((*lblstr)[(uint)l]); 01398 } 01399 }} 01400 // now add excluded classes into table 01401 { idx_bloop1(lab, labels, Tlabel) { 01402 Tlabel l = lab.gget(); 01403 if (mlabs.find(l) == mlabs.end() 01404 && (bexclusion && excluded[(uint)l])) { 01405 mlabs[l] = i++; 01406 } 01407 }} 01408 // now replace all original labels by their new value 01409 idx_copy(olabels, labels); 01410 { idx_bloop1(lab, labels, Tlabel) { 01411 Tlabel l = lab.gget(); 01412 if (mlabs.find(l) != mlabs.end()) 01413 lab.sset(mlabs[l]); 01414 else 01415 eblerror("label " << (uint) l << " not found"); 01416 }} 01417 // replace exclusion table 01418 if (bexclusion) { 01419 for (uint i = 0; i < excluded.size(); ++i) 01420 excluded[i] = i < clblstr.size() ? false : true; 01421 } 01422 if (balance) 01423 set_balanced(true); // balance dataset for each class in next_train 01424 } 01425 01427 // data access 01428 01429 template <typename Tnet, typename Tdata, typename Tlabel> 01430 Tlabel class_datasource<Tnet,Tdata,Tlabel>::get_label() { 01431 idx<Tlabel> lab = labels[it]; 01432 if (lab.order() == 0) 01433 return lab.get(); 01434 else if (lab.order() == 1 && lab.dim(0) == 1) 01435 return lab.get(0); 01436 else 01437 eblerror("expected single-element labels"); 01438 return 0; 01439 } 01440 01442 // iterating 01443 01444 template <typename Tnet, typename Tdata, typename Tlabel> 01445 bool class_datasource<Tnet,Tdata,Tlabel>::included_sample(intg index) { 01446 if (!bexclusion) 01447 return true; 01448 if (index >= data.dim(0)) 01449 eblerror("cannot check inclusion of sample " << index 01450 << ", only " << data.dim(0) << " samples"); 01451 return !excluded[(int) labels.gget(index)]; 01452 } 01453 01454 template <typename Tnet, typename Tdata, typename Tlabel> 01455 intg class_datasource<Tnet,Tdata,Tlabel>::count_included_samples() { 01456 if (!bexclusion) return this->size(); 01457 intg n = 0; 01458 idx_bloop1(lab, labels, Tlabel) { 01459 if (!excluded[(int) lab.gget()]) 01460 n++; 01461 } 01462 return n; 01463 } 01464 01465 template <typename Tnet, typename Tdata, typename Tlabel> 01466 void class_datasource<Tnet,Tdata,Tlabel>::seek_begin() { 01467 datasource<Tnet,Tdata>::seek_begin(); 01468 if (bexclusion) { 01469 while (excluded[(int)get_label()]) 01470 this->next(); 01471 EDEBUG("seek_begin to label: " << (int) get_label() << " it: " << it); 01472 } 01473 } 01474 01475 template <typename Tnet, typename Tdata, typename Tlabel> 01476 void class_datasource<Tnet,Tdata,Tlabel>::seek_begin_train() { 01477 datasource<Tnet,Tdata>::seek_begin_train(); 01478 if (bexclusion) { 01479 while (excluded[(int)get_label()]) 01480 this->next(); 01481 EDEBUG("seek_begin to label: " << (int) get_label() << " it: " << it); 01482 } 01483 } 01484 01485 01486 template <typename Tnet, typename Tdata, typename Tlabel> 01487 bool class_datasource<Tnet,Tdata,Tlabel>::next() { 01488 if (!bexclusion) 01489 return datasource<Tnet,Tdata>::next(); 01490 // handle excluded classes 01491 bool b = datasource<Tnet,Tdata>::next(); 01492 while (b && excluded[(int)get_label()]) 01493 b = datasource<Tnet,Tdata>::next(); 01494 return b; 01495 } 01496 01497 template <typename Tnet, typename Tdata, typename Tlabel> 01498 bool class_datasource<Tnet,Tdata,Tlabel>::next_train() { 01499 // check that this datasource is allowed to call this method 01500 if (test_set) 01501 eblerror("forbidden call of next_train() on testing sets"); 01502 if (!balance) // do not balance by class 01503 return datasource<Tnet,Tdata>::next_train(); 01504 bool pick = false; 01505 not_picked++; 01506 // balanced: return samples in class-balanced order 01507 // get pointer to first non empty class 01508 while (!bal_indices[class_it].size()) 01509 next_balanced_class(); 01510 it = bal_indices[class_it][bal_it[class_it]]; 01511 bal_it[class_it] += 1; 01512 // decide if we want to select this sample for training 01513 pick = this->pick_current(); 01514 // decrement epoch counter 01515 // if (epoch_done_counters[class_it] > 0) 01516 epoch_done_counters[class_it] = epoch_done_counters[class_it] - 1; 01517 if (bal_it[class_it] >= bal_indices[class_it].size()) { 01518 // returning to begining of list for this class 01519 bal_it[class_it] = 0; 01520 // shuffling list for this class 01521 if (shuffle_passes) { 01522 vector<intg> &clist = bal_indices[class_it]; 01523 random_shuffle(clist.begin(), clist.end()); 01524 } 01525 if (weigh_samples) 01526 normalize_probas(class_it); 01527 } 01528 // recursion failsafe, allow 1000 max recursions 01529 if (!bexclusion && 01530 not_picked > MIN(1000, (intg) bal_indices[class_it].size())) { 01531 // we called recursion on this method more than number of class samples 01532 // give up and show current sample 01533 pick = true; 01534 } 01535 if (pick) { 01536 // increment pick counter for this sample 01537 if (count_pickings) pick_count.set(pick_count.get(it) + 1, it); 01538 #ifdef __DEBUG__ 01539 cout << "Picking sample " << it << " (label: " << (int)get_label() 01540 << ", name: " << *(clblstr[(int)get_label()]) << ", pickings: " 01541 << pick_count.get(it) << ", energy: " << energies.get(it) 01542 << ", correct: " << (int) correct.get(it); 01543 if (weigh_samples) cout << ", proba: " << probas.get(it); 01544 cout << ")" << endl; 01545 #endif 01546 epoch_cnt++; 01547 // increment sample counter 01548 epoch_pick_cnt++; 01549 // if we picked a sample, jump to next class 01550 next_balanced_class(); 01551 not_picked = 0; 01552 this->pretty_progress(); 01553 return true; 01554 } else { 01555 #ifdef __DEBUG__ 01556 if (!bexclusion) 01557 cout << "Not picking sample " << it << " (label: " 01558 << (int) labels.gget(it) << ", pickings: " << pick_count.get(it) 01559 << ", energy: " << energies.get(it) 01560 << ", correct: " << (int) correct.get(it) 01561 << ", proba: " << probas.get(it) << ")" << endl; 01562 #endif 01563 this->pretty_progress(); 01564 return false; 01565 } 01566 } 01567 01568 template <typename Tnet, typename Tdata, typename Tlabel> 01569 void class_datasource<Tnet,Tdata,Tlabel>::next_balanced_class() { 01570 class_it_it++; 01571 if (class_it_it >= class_order.size()) { 01572 class_it_it = 0; 01573 reset_class_order(); 01574 } 01575 class_it = class_order[class_it_it]; 01576 if (bexclusion && excluded[class_it]) 01577 return next_balanced_class(); 01578 } 01579 01580 template <typename Tnet, typename Tdata, typename Tlabel> 01581 void class_datasource<Tnet,Tdata,Tlabel>::reset_class_order() { 01582 if (random_class_order) // randomize classes order 01583 random_shuffle(class_order.begin(), class_order.end()); 01584 } 01585 01586 template <typename Tnet, typename Tdata, typename Tlabel> 01587 void class_datasource<Tnet,Tdata,Tlabel>::set_balanced(bool bal) { 01588 balance = bal; 01589 if (!balance) // unbalanced 01590 cout << _name << ": Setting training as unbalanced (not taking class " 01591 << "distributions into account)." << endl; 01592 else { // balanced 01593 cout << _name << ": Setting training as balanced (taking class " 01594 << "distributions into account)." << endl; 01595 // compute vector of sample indices for each class 01596 bal_indices.clear(); 01597 bal_it.clear(); 01598 epoch_done_counters.clear(); 01599 class_it = 0; 01600 class_it_it = 0; 01601 for (intg i = 0; i < nclasses; ++i) { 01602 vector<intg> indices; 01603 bal_indices.push_back(indices); 01604 bal_it.push_back(0); // init iterators 01605 class_order.push_back(i); // init iterators 01606 } 01607 reset_class_order(); 01608 // distribute sample indices into each vector based on label 01609 for (uint i = 0; i < this->size(); ++i) 01610 bal_indices[(intg) (labels.gget(i))].push_back(i); 01611 for (uint i = 0; i < bal_indices.size(); ++i) { 01612 // shuffle 01613 random_shuffle(bal_indices[i].begin(), bal_indices[i].end()); 01614 // init epoch counters 01615 epoch_done_counters.push_back(bal_indices[i].size()); 01616 } 01617 } 01618 } 01619 01620 template <typename Tnet, typename Tdata, typename Tlabel> 01621 void class_datasource<Tnet,Tdata,Tlabel>::set_random_class_order(bool ran) { 01622 random_class_order = ran; 01623 cout << "Classes order is " << (random_class_order ? "" : "not") 01624 << " random." << endl; 01625 } 01626 01627 template <typename Tnet, typename Tdata, typename Tlabel> 01628 void class_datasource<Tnet,Tdata,Tlabel>:: 01629 limit_classes(intg n, intg offset, bool random) { 01630 // enable exclusion except for the n classes after offset 01631 if ((offset == 0 && n >= nclasses) || offset >= nclasses) { 01632 eblwarn("ignoring attempt to limit classes to " << n 01633 << " starting at offset " << offset << " because there are only " 01634 << nclasses << " classes"); 01635 } else { 01636 bexclusion = true; 01637 excluded.clear(); 01638 for (intg i = 0; i < nclasses; ++i) { 01639 if (i < offset || i >= offset + n) 01640 excluded.push_back(true); 01641 else 01642 excluded.push_back(false); 01643 } 01644 if (random) 01645 random_shuffle(excluded.begin(), excluded.end()); 01646 included = std::min(nclasses, n + offset) - offset; 01647 cout << "Excluded all but " << included 01648 << " classes (offset: " << offset << ", n: " << n << ")" << endl; 01649 init_class_labels(); 01650 } 01651 } 01652 01653 template <typename Tnet, typename Tdata, typename Tlabel> 01654 bool class_datasource<Tnet,Tdata,Tlabel>::epoch_done() { 01655 switch (epoch_mode) { 01656 case 0: // fixed number of samples 01657 if (epoch_cnt >= epoch_sz) 01658 return true; 01659 break ; 01660 case 1: // see all samples at least once 01661 if (balance) { 01662 // check that all classes are done 01663 for (uint i = 0; i < epoch_done_counters.size(); ++i) { 01664 if (epoch_done_counters[i] > 0) 01665 return false; 01666 } 01667 return true; // all classes are done 01668 } else { // do not balance, use epoch_sz 01669 if (epoch_cnt >= epoch_sz) 01670 return true; 01671 } 01672 break ; 01673 default: eblerror("unknown epoch_mode"); 01674 } 01675 return false; 01676 } 01677 01678 template <typename Tnet, typename Tdata, typename Tlabel> 01679 void class_datasource<Tnet,Tdata,Tlabel>::init_epoch() { 01680 epoch_cnt = 0; 01681 epoch_pick_cnt = 0; 01682 epoch_timer.restart(); 01683 epoch_show_printed = -1; // last epoch count we have printed 01684 if (balance) { 01685 uint maxsize = 0; 01686 // for balanced training, set each class to not done. 01687 for (uint k = 0; k < bal_indices.size(); ++k) { 01688 epoch_done_counters[k] = bal_indices[k].size(); 01689 if (bal_indices[k].size() > maxsize) 01690 maxsize = bal_indices[k].size(); 01691 } 01692 if (epoch_mode == 1) // for ETA estimation only, estimate epoch size 01693 epoch_sz = maxsize * bal_indices.size(); 01694 } 01695 // if we have prior information about each sample energy and classification 01696 // let's use it to initialize the picking probabilities. 01697 if (weigh_samples) 01698 this->normalize_all_probas(); 01699 } 01700 01701 template <typename Tnet, typename Tdata, typename Tlabel> 01702 void class_datasource<Tnet,Tdata,Tlabel>::normalize_all_probas() { 01703 if (weigh_samples) { 01704 if (perclass_norm && balance) { 01705 for (uint i = 0; i < bal_indices.size(); ++i) 01706 normalize_probas(i); 01707 } else 01708 normalize_probas(); 01709 } 01710 } 01711 01712 template <typename Tnet, typename Tdata, typename Tlabel> 01713 void class_datasource<Tnet,Tdata,Tlabel>::normalize_probas(int classid) { 01714 vector<intg> *cindices = NULL; 01715 if (perclass_norm && balance) { // use only class_it class samples 01716 if (classid < 0) 01717 eblerror("class id cannot be negative"); 01718 uint class_it = (uint) classid; 01719 cindices = &(bal_indices[class_it]); 01720 cout << _name << ": Normalizing probabilities for class" << class_it; 01721 datasource<Tnet,Tdata>::normalize_probas(cindices); 01722 } else // use all samples 01723 datasource<Tnet,Tdata>::normalize_probas(); 01724 } 01725 01727 // accessors 01728 01729 template <typename Tnet, typename Tdata, typename Tlabel> 01730 intg class_datasource<Tnet, Tdata, Tlabel>::get_nclasses() { 01731 if (bexclusion) 01732 return included; 01733 return nclasses; 01734 } 01735 01736 template <typename Tnet, typename Tdata, typename Tlabel> 01737 int class_datasource<Tnet, Tdata, Tlabel>::get_class_id(const char *name) { 01738 int id_ = -1; 01739 vector<string*>::iterator i = clblstr.begin(); 01740 for (int j = 0; i != clblstr.end(); ++i, ++j) { 01741 if (!strcmp(name, (*i)->c_str())) 01742 id_ = j; 01743 } 01744 return id_; 01745 } 01746 01747 template <typename Tnet, typename Tdata, typename Tlabel> 01748 std::string &class_datasource<Tnet, Tdata, Tlabel>::get_class_name(int id) { 01749 if (id >= (int) clblstr.size()) 01750 eblerror("requesting label string at index " << id 01751 << " but string vector has only " << clblstr.size() 01752 << " elements."); 01753 string *s = clblstr[id]; 01754 if (!s) 01755 eblerror("empty label string"); 01756 return *s; 01757 } 01758 01759 template <typename Tnet, typename Tdata, typename Tlabel> 01760 std::vector<std::string*>& class_datasource<Tnet, Tdata, Tlabel>:: 01761 get_label_strings() { 01762 return clblstr; 01763 } 01764 01765 template <typename Tnet, typename Tdata, typename Tlabel> 01766 intg class_datasource<Tnet,Tdata,Tlabel>::get_lowest_common_size() { 01767 intg min_nonzero = (std::numeric_limits<intg>::max)(); 01768 for (vector<intg>::iterator i = counts.begin(); i != counts.end(); ++i) { 01769 if ((*i < min_nonzero) && (*i != 0)) 01770 min_nonzero = *i; 01771 } 01772 if (min_nonzero == (std::numeric_limits<intg>::max)()) 01773 eblerror("empty dataset"); 01774 return min_nonzero * nclasses; 01775 } 01776 01777 template <typename Tnet, typename Tdata, typename Tlabel> 01778 void class_datasource<Tnet,Tdata,Tlabel>::save_pickings(const char *name_) { 01779 // non-class plotting 01780 datasource<Tnet,Tdata>::save_pickings(name_); 01781 string name = "pickings"; 01782 if (name_) 01783 name = name_; 01784 // plot by class 01785 write_classed_pickings(pick_count, correct, name); 01786 write_classed_pickings(energies, correct, name, "_energies"); 01787 idx<double> e = idx_copy(energies); 01788 idx<ubyte> c = idx_copy(correct); 01789 idx_sortup(e, c); 01790 write_classed_pickings(e, c, name, "_sorted_energies"); 01791 idx<double> p = idx_copy(probas); 01792 c = idx_copy(correct); 01793 idx_sortup(p, c); 01794 write_classed_pickings(p, c, name, "_sorted_probas"); 01795 p = idx_copy(probas); 01796 e = idx_copy(energies); 01797 c = idx_copy(correct); 01798 idx_sortup(e, c, p); 01799 write_classed_pickings(p, c, name, "_probas_sorted_by_energy", true, 01800 "Picking probability"); 01801 write_classed_pickings(p, c, name, "_probas_sorted_by_energy_wrong_only", 01802 false, "Picking probability"); 01803 write_classed_pickings(e, c, name, "_energies_sorted_by_energy_wrong_only", 01804 false, "Energy"); 01805 } 01806 01807 template <typename Tnet, typename Tdata, typename Tlabel> 01808 template <typename T> 01809 void class_datasource<Tnet,Tdata,Tlabel>:: 01810 write_classed_pickings(idx<T> &m, idx<ubyte> &c, string &name_, 01811 const char *name2_, bool plot_correct, 01812 const char *ylabel) { 01813 string name = name_; 01814 if (name2_) 01815 name += name2_; 01816 name += "_classed"; 01817 // sorted classed plot file 01818 if (labels.order() == 1) { // single label value 01819 string fname = name; 01820 fname += ".plot"; 01821 ofstream fp(fname.c_str()); 01822 if (!fp) { 01823 cerr << "failed to open " << fname << endl; 01824 eblerror("failed to open file for writing"); 01825 } 01826 eblerror("not implemented"); 01827 // typename idx<T>::dimension_iterator i = m.dim_begin(0); 01828 // typename idx<Tlabel>::dimension_iterator l = labels.dim_begin(0); 01829 // typename idx<ubyte>::dimension_iterator ic = c.dim_begin(0); 01830 // uint j = 0; 01831 // for ( ; i.notdone(); i++, l++, ic++) { 01832 // if (!plot_correct && ic->get() == 1) 01833 // continue ; // don't plot correct ones 01834 // fp << j++; 01835 // for (Tlabel k = 0; k < (Tlabel) nclasses; ++k) { 01836 // if (k == l.get()) { 01837 // if (ic->get() == 0) { // wrong 01838 // fp << "\t" << i->get(); 01839 // if (plot_correct) 01840 // fp << "\t?"; 01841 // } else if (plot_correct) // correct 01842 // fp << "\t?\t" << i->get(); 01843 // } else { 01844 // fp << "\t?"; 01845 // if (plot_correct) 01846 // fp << "\t?"; 01847 // } 01848 // } 01849 // fp << endl; 01850 // } 01851 fp.close(); 01852 cout << _name << ": Wrote picking statistics in " << fname << endl; 01853 // p file 01854 string fname2 = name; 01855 fname2 += ".p"; 01856 ofstream fp2(fname2.c_str()); 01857 if (!fp2) { 01858 cerr << "failed to open " << fname2 << endl; 01859 eblerror("failed to open file for writing"); 01860 } 01861 fp2 << "set title \"" << name << "\"; set ylabel \"" << ylabel 01862 << "\"; plot \"" 01863 << fname << "\" using 1:2 title \"class 0 wrong\" with impulse"; 01864 if (plot_correct) 01865 fp2 << ", \"" 01866 << fname << "\" using 1:3 title \"class 0 correct\" with impulse"; 01867 for (uint k = 1; k < nclasses; ++k) { 01868 fp2 << ", \"" << fname << "\" using 1:" << k * (plot_correct?2:1) + 2 01869 << " title \"class " << k << " wrong\" with impulse"; 01870 if (plot_correct) 01871 fp2 << ", \"" << fname << "\" using 1:" << k * 2 + 3 01872 << " title \"class " << k << " correct\" with impulse"; 01873 } 01874 fp << endl; 01875 fp2.close(); 01876 cout << _name << ": Wrote gnuplot file in " << fname2 << endl; 01877 } 01878 } 01879 01881 // state saving 01882 01883 template <typename Tnet, typename Tdata, typename Tlabel> 01884 void class_datasource<Tnet,Tdata,Tlabel>::save_state() { 01885 state_saved = true; 01886 count_pickings_save = count_pickings; 01887 it_saved = it; // save main iterator 01888 it_test_saved = it_test; 01889 it_train_saved = it_train; 01890 this->epoch_cnt_saved = epoch_cnt; 01891 this->epoch_pick_cnt_saved = epoch_pick_cnt; 01892 this->epoch_done_counters_saved = epoch_done_counters; 01893 if (!balance) // save (unbalanced) iterators 01894 datasource<Tnet,Tdata>::save_state(); 01895 else { // save balanced iterators 01896 bal_it_saved.clear(); 01897 bal_indices_saved.clear(); 01898 for (uint k = 0; k < bal_it.size(); ++k) { 01899 bal_it_saved.push_back(bal_it[k]); 01900 vector<intg> indices; 01901 for (uint l = 0; l < bal_indices[k].size(); ++l) 01902 indices.push_back(bal_indices[k][l]); 01903 bal_indices_saved.push_back(indices); 01904 } 01905 class_it_saved = class_it; 01906 class_it_it_saved = class_it_it; 01907 } 01908 } 01909 01910 template <typename Tnet, typename Tdata, typename Tlabel> 01911 void class_datasource<Tnet,Tdata,Tlabel>::restore_state() { 01912 if (!state_saved) 01913 eblerror("state not saved, call save_state() before restore_state()"); 01914 count_pickings = count_pickings_save; 01915 it = it_saved; // restore main iterator 01916 it_test = it_test_saved; 01917 it_train = it_train_saved; 01918 epoch_cnt = this->epoch_cnt_saved; 01919 epoch_pick_cnt = this->epoch_pick_cnt_saved; 01920 epoch_done_counters = this->epoch_done_counters_saved; 01921 if (!balance) // restore unbalanced 01922 datasource<Tnet,Tdata>::restore_state(); 01923 else { // restore balanced iterators 01924 for (uint k = 0; k < bal_it.size(); ++k) { 01925 bal_it[k] = bal_it_saved[k]; 01926 for (uint l = 0; l < bal_indices[k].size(); ++l) 01927 bal_indices[k][l] = bal_indices_saved[k][l]; 01928 } 01929 class_it = class_it_saved; 01930 class_it_it = class_it_it_saved; 01931 } 01932 } 01933 01935 // pretty methods 01936 01937 template <typename Tnet, typename Tdata, typename Tlabel> 01938 void class_datasource<Tnet,Tdata,Tlabel>::pretty_progress(bool newline) { 01939 if (this->is_test()) 01940 datasource<Tnet,Tdata>::pretty_progress(newline); 01941 else { 01942 if (epoch_show > 0 && epoch_pick_cnt % epoch_show == 0 && 01943 epoch_show_printed != (intg) epoch_pick_cnt) { 01944 datasource<Tnet,Tdata>::pretty_progress(false); 01945 if (balance && epoch_done_counters.size() < 50) { 01946 cout << ", remaining:"; 01947 for (uint i = 0; i < epoch_done_counters.size(); ++i) { 01948 cout << " " << i << ": " << epoch_done_counters[i]; 01949 } 01950 } 01951 if (newline) 01952 cout << endl; 01953 } 01954 } 01955 } 01956 01957 template <typename Tnet, typename Tdata, typename Tlabel> 01958 void class_datasource<Tnet, Tdata, Tlabel>::pretty() { 01959 cout << _name << ": classification dataset \"" << _name 01960 << "\" contains " 01961 << data.dim(0) << " samples of dimension " << sampledims 01962 << " and defines an epoch as " << epoch_sz << " samples." << endl; 01963 cout << this->_name << ": It has " << nclasses << " classes: "; 01964 uint i; 01965 for (i = 0; i < this->counts.size(); ++i) 01966 if (!bexclusion || !excluded[i]) 01967 cout << this->counts[i] << " \"" << *(clblstr[i]) << "\" "; 01968 cout << endl; 01969 pretty_scales(); 01970 } 01971 01972 template <typename Tnet, typename Tdata, typename Tlabel> 01973 void class_datasource<Tnet, Tdata, Tlabel>::pretty_scales() { 01974 labeled_datasource<Tnet,Tdata,Tlabel>::pretty_scales(); 01975 if (this->scales_loaded) { 01976 intg maxscale = idx_max(this->scales); 01977 for (uint c = 0; c < this->counts.size(); ++c) 01978 if (!bexclusion || !excluded[c]) { 01979 vector<intg> tally(maxscale + 1, 0); 01980 idx_bloop2(scale, this->scales, intg, label, labels, Tlabel) { 01981 if (label.get() != (Tlabel) c) continue ; 01982 intg s = scale.get(); 01983 if (s < 0) eblerror("unexpected negative value"); 01984 tally[s] = tally[s] + 1; 01985 } 01986 intg nscales = 0; 01987 for (intg i = 0; i < (intg) tally.size(); ++i) 01988 if (tally[i] > 0) nscales++; 01989 cout << _name << ": " << *(clblstr[c]) << " has " 01990 << nscales << " scales"; 01991 for (intg i = 0; i < (intg) tally.size(); ++i) 01992 cout << ", " << i << ": " << tally[i]; 01993 cout << endl; 01994 } 01995 } 01996 } 01997 01998 // protected methods ///////////////////////////////////////////////////////// 01999 02000 template <typename Tnet, typename Tdata, typename Tlabel> 02001 bool class_datasource<Tnet,Tdata,Tlabel>::pick_current() { 02002 if (bexclusion && excluded[(int) get_label()]) 02003 return false; 02004 return datasource<Tnet,Tdata>::pick_current(); 02005 } 02006 02009 02010 template <typename Tlabel> 02011 class_node<Tlabel>::class_node(Tlabel id, string &name_) 02012 : _label(id), _name(name_), parent(NULL), children(), 02013 it_children(children.begin()), 02014 bempty(true), iempty(true), _depth(0), it_samples(samples.begin()) { 02015 //EDEBUG("creating node " << this << " label: " << id << " name: " 02016 //<< _name); 02017 } 02018 02019 template <typename Tlabel> 02020 class_node<Tlabel>::~class_node() { 02021 } 02022 02023 template <typename Tlabel> 02024 bool class_node<Tlabel>::empty() { 02025 return bempty; 02026 } 02027 02028 template <typename Tlabel> 02029 bool class_node<Tlabel>::internally_empty() { 02030 return iempty; 02031 } 02032 02033 template <typename Tlabel> 02034 intg class_node<Tlabel>::next() { 02035 if (bempty) eblerror("cannot call next() on empty node"); 02036 if (it_children < children.begin()) it_children = children.begin(); 02037 if (it_samples < samples.begin()) it_samples = samples.begin(); 02038 intg id = -1; 02039 if (it_children == children.end()) { // we reached the end of children 02040 // roll back children iterator 02041 it_children = children.begin(); 02042 if (samples.size() > 0) { // return internal samples if present 02043 id = *it_samples; 02044 it_samples++; 02045 if (it_samples == samples.end()) // roll back samples iterator 02046 it_samples = samples.begin(); 02047 return id; 02048 } 02049 } 02050 // return a sample from children 02051 if ((*it_children)->empty()) { 02052 it_children++; 02053 return next(); // skip current empty child 02054 } else // this child is not empty 02055 return (*it_children)->next(); 02056 } 02057 02058 template <typename Tlabel> 02059 void class_node<Tlabel>::add_child(class_node *child) { 02060 // make sure child is not added twice 02061 for (typename vector<class_node<Tlabel>*>::iterator i = children.begin(); 02062 i != children.end(); ++i) 02063 if (*i == child) 02064 eblerror("trying to push same child node twice"); 02065 // push 02066 children.push_back(child); 02067 // set child's parent to this node 02068 child->set_parent(this); 02069 // propagate up if child is non-empty 02070 if (!child->empty()) 02071 set_non_empty(); 02072 } 02073 02074 template <typename Tlabel> 02075 void class_node<Tlabel>::add_sample(intg index) { 02076 samples.push_back(index); 02077 iempty = false; 02078 // propagate information back to parents 02079 set_non_empty(); 02080 } 02081 02082 template <typename Tlabel> 02083 intg class_node<Tlabel>::nsamples() { 02084 // TODO: we might want to use idx here to handle large sets (intg) 02085 return samples.size(); 02086 } 02087 02088 template <typename Tlabel> 02089 Tlabel class_node<Tlabel>::label() { 02090 return _label; 02091 } 02092 02093 template <typename Tlabel> 02094 Tlabel class_node<Tlabel>::label(uint depth) { 02095 // current depth is lower or equal to target depth, return current label 02096 if (_depth <= depth || !parent) 02097 return _label; 02098 // current depth is higher than target depth, call parent's label 02099 return parent->label(depth); 02100 } 02101 02102 template <typename Tlabel> 02103 uint class_node<Tlabel>::depth() { 02104 return _depth; 02105 } 02106 02107 template <typename Tlabel> 02108 uint class_node<Tlabel>::set_depth(uint d) { 02109 _depth = d; 02110 uint dmax = d; 02111 for (typename vector<class_node<Tlabel>*>::iterator i = children.begin(); 02112 i != children.end(); ++i) 02113 dmax = std::max(dmax, (*i)->set_depth(d + 1)); 02114 return dmax; 02115 } 02116 02117 template <typename Tlabel> 02118 string &class_node<Tlabel>::name() { 02119 return _name; 02120 } 02121 02122 template <typename Tlabel> 02123 class_node<Tlabel>* class_node<Tlabel>::get_parent() { 02124 return parent; 02125 } 02126 02127 template <typename Tlabel> 02128 bool class_node<Tlabel>::is_parent(Tlabel lab) { 02129 if (lab == _label) 02130 return true; 02131 if (parent) 02132 return parent->is_parent(lab); 02133 return false; 02134 } 02135 02136 // protected methods ///////////////////////////////////////////////////////// 02137 02138 template <typename Tlabel> 02139 void class_node<Tlabel>::set_non_empty() { 02140 bempty = false; 02141 // propagate information back to parents 02142 if (parent) 02143 parent->set_non_empty(); 02144 } 02145 02146 template <typename Tlabel> 02147 void class_node<Tlabel>::set_parent(class_node *p) { 02148 parent = p; 02149 } 02150 02152 // hierarchy_datasource 02153 02154 template <typename Tnet, typename Tdata, typename Tlabel> 02155 hierarchy_datasource<Tnet, Tdata, Tlabel>::hierarchy_datasource() 02156 : class_datasource<Tnet,Tdata,Tlabel>() { 02157 } 02158 02159 template <typename Tnet, typename Tdata, typename Tlabel> 02160 hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02161 hierarchy_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_, 02162 idx<Tlabel> *parents_, 02163 vector<string*> *lblstr_, const char *name_) { 02164 class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, name_, lblstr_); 02165 this->init_parents(parents_); 02166 this->init_epoch(); 02167 this->pretty(); // print info about dataset 02168 } 02169 02170 template <typename Tnet, typename Tdata, typename Tlabel> 02171 hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02172 hierarchy_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 02173 idx<Tlabel> *parents_, 02174 vector<string*> *lblstr_, const char *name_) { 02175 class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, name_, lblstr_); 02176 this->init_parents(parents_); 02177 this->init_epoch(); 02178 this->pretty(); // print info about dataset 02179 } 02180 02181 template <typename Tnet, typename Tdata, typename Tlabel> 02182 hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02183 hierarchy_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 02184 idx<Tlabel> *parents_, idx<ubyte> *classes, 02185 const char *name_) { 02186 if (classes) 02187 this->init_strings(*classes); 02188 class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, this->lblstr, 02189 name_); 02190 this->init_parents(parents_); 02191 this->init_epoch(); 02192 this->pretty(); // print info about dataset 02193 } 02194 02195 template <typename Tnet, typename Tdata, typename Tlabel> 02196 hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02197 hierarchy_datasource(const char *data_name, const char *labels_name, 02198 const char *parents_name, const char *jitters_name, 02199 const char *scales_name, const char *classes_name, 02200 const char *name_, uint max_size) { 02201 class_datasource<Tnet,Tdata,Tlabel>:: 02202 init(data_name, labels_name, jitters_name, scales_name, classes_name, 02203 name_, max_size); 02204 // load parent 02205 idx<Tlabel> parents_; 02206 if (parents_name) { 02207 try { 02208 parents_ = load_matrix<Tlabel>(parents_name); 02209 } eblcatcherror(); 02210 } 02211 // inits 02212 this->init_parents(parents_name ? &parents_ : NULL); 02213 this->init_epoch(); 02214 this->pretty(); // print info about dataset 02215 } 02216 02217 // template <typename Tnet, typename Tdata, typename Tlabel> 02218 // hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02219 // hierarchy_datasource(const hierarchy_datasource<Tnet, Tdata, Tlabel> &ds) 02220 // : class_datasource<Tnet,Tdata>((const class_datasource<Tnet,Tdata>&) ds) { 02221 // } 02222 02223 template <typename Tnet, typename Tdata, typename Tlabel> 02224 hierarchy_datasource<Tnet, Tdata, Tlabel>::~hierarchy_datasource() { 02225 // delete all nodes 02226 for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin(); 02227 i != all_nodes.end(); ++i) { 02228 class_node<Tlabel> *n = *i; 02229 if (n) delete n; 02230 } 02231 // delete depth vectors 02232 for (typename vector<vector<class_node<Tlabel>*>*>::iterator i = 02233 all_depths.begin(); i != all_depths.end(); ++i) { 02234 vector<class_node<Tlabel>*> *n = *i; 02235 if (n) delete n; 02236 } 02237 if (parents) 02238 delete parents; 02239 } 02240 02242 // init methods 02243 02244 template <typename Tnet, typename Tdata, typename Tlabel> 02245 void hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02246 init_parents(idx<Tlabel> *parents_) { 02247 if (parents_) 02248 parents = new idx<Tlabel>(*parents_); 02249 else { 02250 eblwarn("no parents hierarchy specified, initializing with a flat " 02251 << "hierarchy"); 02252 parents = new idx<Tlabel>(nclasses, 2); 02253 for (Tlabel i = 0; i < (Tlabel) nclasses; ++i) { 02254 parents->set(i, i, 0); 02255 parents->set(-1, i, 1); // node i has no parent 02256 } 02257 } 02258 // check that parent id do not exceed number of classes 02259 if (parents) { 02260 if (idx_max(*parents) > nclasses) 02261 eblerror("maximum parent id (" << idx_max(*parents) 02262 << ") cannot exceed number of classes (" << nclasses << ")"); 02263 if (parents->dim(1) != 2) 02264 eblerror("expected dim 1 to be size 2 (child/parent)"); 02265 } 02266 if (!lblstr) 02267 eblerror("expected class strings to be defined"); 02268 // create hierarchy tree for efficient balanced samples ordering 02269 all_nodes.resize(nclasses, NULL); 02270 // add a new node for each class 02271 Tlabel l = 0; 02272 for (vector<string*>::iterator i = lblstr->begin(); i != lblstr->end();++i){ 02273 class_node<Tlabel> *node = all_nodes[l]; 02274 if (!node) { 02275 node = new class_node<Tlabel>(l, *((*lblstr)[l])); 02276 all_nodes[l] = node; 02277 } 02278 l++; 02279 } 02280 // add samples to each node 02281 intg i = 0; 02282 idx_bloop1(lab, labels, Tlabel) { 02283 Tlabel l = lab.get(); 02284 class_node<Tlabel> *node = all_nodes[l]; 02285 if (!node) eblerror("node " << l << " is not defined"); 02286 // add sample to node 02287 node->add_sample(i); 02288 i++; 02289 } 02290 // associate each class to its parent 02291 if (parents) { 02292 idx_bloop1(par, *parents, Tlabel) { 02293 class_node<Tlabel> *child = all_nodes[par.get(0)]; 02294 if (!child) eblerror("no node with id " << par.get(0)); 02295 if (par.get(1) >= 0) { 02296 class_node<Tlabel> *parent = all_nodes[par.get(1)]; 02297 parent->add_child(child); 02298 //EDEBUG("adding "<< child->name() << " as child of " << parent->name()); 02299 } 02300 } 02301 } 02302 // assign depth to all nodes, starting from all orphan nodes 02303 uint maxdepth = 0; 02304 for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin(); 02305 i != all_nodes.end(); ++i) { 02306 class_node<Tlabel> *n = *i; 02307 if (n && !n->get_parent()) 02308 maxdepth = std::max(maxdepth, n->set_depth(0)); 02309 } 02310 // remember nodes arranged by depth 02311 all_depths.resize(maxdepth + 1, NULL); 02312 for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin(); 02313 i != all_nodes.end(); ++i) { 02314 class_node<Tlabel> *n = *i; 02315 if (n) { 02316 vector<class_node<Tlabel>*> *d = all_depths[n->depth()]; 02317 if (!d) { 02318 d = new vector<class_node<Tlabel>*>; 02319 all_depths[n->depth()] = d; 02320 } 02321 d->push_back(n); 02322 //EDEBUG("assigning " << n->name() << " to depth " << n->depth()); 02323 } 02324 } 02325 // remember nodes arranged by depth and keep nodes with lower depth that 02326 // have samples internally. this garantees that all samples are used even 02327 // at lower depths than current one. 02328 complete_depths.resize(maxdepth + 1, NULL); 02329 for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin(); 02330 i != all_nodes.end(); ++i) { 02331 class_node<Tlabel> *n = *i; 02332 if (n) { 02333 vector<class_node<Tlabel>*> *d = complete_depths[n->depth()]; 02334 if (!d) { 02335 d = new vector<class_node<Tlabel>*>; 02336 complete_depths[n->depth()] = d; 02337 } 02338 d->push_back(n); 02339 // add node to all higher depths if it has internal samples 02340 if (!n->internally_empty()) { 02341 for (uint i = n->depth() + 1; i < complete_depths.size(); ++i) { 02342 vector<class_node<Tlabel>*> *dd = complete_depths[i]; 02343 if (!dd) { 02344 dd = new vector<class_node<Tlabel>*>; 02345 complete_depths[i] = dd; 02346 } 02347 dd->push_back(n); 02348 } 02349 } 02350 //EDEBUG("assigning " << n->name() << " to depth " << n->depth()); 02351 } 02352 } 02353 // order all by depth 02354 for (typename vector<vector<class_node<Tlabel>*>*>::iterator i = 02355 all_depths.begin(); i != all_depths.end(); ++i) { 02356 for (typename vector<class_node<Tlabel>*>::iterator j = (*i)->begin(); 02357 j != (*i)->end(); ++j) { 02358 all_nodes_by_depth.push_back(*j); 02359 } 02360 } 02361 // initialize depths iterators 02362 it_depths.resize(complete_depths.size(), 0); 02363 // allocate depth labels 02364 depth_labels = idx<Tlabel>(labels.get_idxdim()); 02365 set_current_depth(0); 02366 it = 0; 02367 depth_balance = false; 02368 } 02369 02370 template <typename Tnet, typename Tdata, typename Tlabel> 02371 void hierarchy_datasource<Tnet, Tdata, Tlabel>::init_class_labels() { 02372 if (olabels.get_idxdim() != labels.get_idxdim()) 02373 olabels = idx<Tlabel>(labels.get_idxdim()); 02374 idx_copy(labels, olabels); // keep original labels 02375 clblstr.clear(); 02376 for (uint i = 0; i < lblstr->size(); ++i) 02377 clblstr.push_back((*lblstr)[i]); 02378 } 02379 02381 // data access 02382 02383 template <typename Tnet, typename Tdata, typename Tlabel> 02384 Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_parent() { 02385 // idx<Tlabel> lab = labels[it]; 02386 // if (lab.order() != 1 && lab.dim(0) != 1) 02387 // eblerror("expected single-element labels"); 02388 // return parents.get(lab.get(0)); 02389 return -1; 02390 } 02391 02392 template <typename Tnet, typename Tdata, typename Tlabel> 02393 bool hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02394 is_parent_of(Tlabel l1, Tlabel l2) { 02395 class_node<Tlabel> *n = all_nodes[l2]; 02396 if (!n) eblerror("node " << l2 << " not found"); 02397 return n->is_parent(l1); 02398 } 02399 02400 template <typename Tnet, typename Tdata, typename Tlabel> 02401 vector<class_node<Tlabel>*>& hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02402 get_nodes() { 02403 return all_nodes; 02404 } 02405 02406 template <typename Tnet, typename Tdata, typename Tlabel> 02407 vector<class_node<Tlabel>*>& 02408 hierarchy_datasource<Tnet,Tdata,Tlabel>::get_nodes_by_depth() { 02409 return all_nodes_by_depth; 02410 } 02411 02412 template <typename Tnet, typename Tdata, typename Tlabel> 02413 void hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02414 fprop_label(fstate_idx<Tlabel> &label) { 02415 label.x.sset(get_label()); 02416 } 02417 02418 template <typename Tnet, typename Tdata, typename Tlabel> 02419 void hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02420 fprop_label_net(fstate_idx<Tnet> &label) { 02421 label.x.sset((Tnet) get_label()); 02422 } 02423 02424 template <typename Tnet, typename Tdata, typename Tlabel> 02425 void hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02426 fprop_label_net(bbstate_idx<Tnet> &label) { 02427 label.x.sset((Tnet) get_label()); 02428 } 02429 02430 template <typename Tnet, typename Tdata, typename Tlabel> 02431 Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_label() { 02432 // select label based on current depth and current iterator 02433 Tlabel l = labels[it].gget(); 02434 return all_nodes[l]->label(current_depth); 02435 } 02436 02437 template <typename Tnet, typename Tdata, typename Tlabel> 02438 Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_label(uint d, intg index){ 02439 // select label based on current depth and current iterator 02440 Tlabel l; 02441 if (index < 0) 02442 l = labels[it].gget(); 02443 else 02444 l = labels[index].gget(); 02445 return all_nodes[l]->label(d); 02446 } 02447 02448 template <typename Tnet, typename Tdata, typename Tlabel> 02449 idx<Tlabel>& hierarchy_datasource<Tnet,Tdata,Tlabel>::get_depth_labels() { 02450 return depth_labels; 02451 } 02452 02453 template <typename Tnet, typename Tdata, typename Tlabel> 02454 uint hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02455 get_nbrothers(class_node<Tlabel> &n) { 02456 if (n.parent) 02457 return n.parent->children.size(); 02458 vector<class_node<Tlabel>*>* d0 = all_depths[0]; 02459 if (d0) 02460 return d0->size(); 02461 return 0; 02462 } 02463 02465 // iterating 02466 02467 template <typename Tnet, typename Tdata, typename Tlabel> 02468 void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_depth_balanced(bool bal) { 02469 depth_balance = bal; 02470 if (!depth_balance) // unbalanced 02471 cout << _name << ": Setting training as depth-unbalanced." << endl; 02472 else // balanced 02473 cout << _name << ": Setting training as depth-balanced." << endl; 02474 } 02475 02476 template <typename Tnet, typename Tdata, typename Tlabel> 02477 void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_current_depth(uint depth) { 02478 if (depth >= all_depths.size()) 02479 eblerror("cannot set current depth to " << depth << " because it is " 02480 << "more than maximum depth " << all_depths.size()); 02481 current_depth = depth; 02482 // fill depth_labels matrix with labels of all samples given current depth 02483 idx_bloop2(l, labels, Tlabel, dl, depth_labels, Tlabel) { 02484 dl.set(all_nodes[l.get()]->label(current_depth)); 02485 } 02486 } 02487 02488 template <typename Tnet, typename Tdata, typename Tlabel> 02489 uint hierarchy_datasource<Tnet,Tdata,Tlabel>::get_current_depth() { 02490 return current_depth; 02491 } 02492 02493 template <typename Tnet, typename Tdata, typename Tlabel> 02494 void hierarchy_datasource<Tnet,Tdata,Tlabel>::incr_current_depth() { 02495 if (current_depth + 1 >= all_depths.size()) 02496 cout << "warning: cannot increment current depth beyond maximum (" 02497 << all_depths.size() << ")" << endl; 02498 else 02499 current_depth++; 02500 set_current_depth(current_depth); 02501 } 02502 02503 template <typename Tnet, typename Tdata, typename Tlabel> 02504 bool hierarchy_datasource<Tnet,Tdata,Tlabel>::next_train() { 02505 // check that this datasource is allowed to call this method 02506 if (test_set) 02507 eblerror("forbidden call of next_train() on testing sets"); 02508 if (!depth_balance) // do not balance by depth 02509 return class_datasource<Tnet,Tdata,Tlabel>::next_train(); 02510 // balanced training 02511 vector<class_node<Tlabel>*> *nodes = complete_depths[current_depth]; 02512 uint itd = it_depths[current_depth]; 02513 // set it for further fprop calls 02514 class_node<Tlabel> *node = (*nodes)[itd]; 02515 it = node->next(); 02516 // increment depth iterator 02517 itd++; 02518 it_depths[current_depth] = itd; 02519 if (itd >= nodes->size()) 02520 it_depths[current_depth] = 0; 02521 // increment epoch counters 02522 epoch_cnt++; 02523 epoch_pick_cnt++; 02524 #ifdef __DEBUG__ 02525 cout << "Picking sample " << it << " (label: " << (int)get_label(); 02526 if (lblstr) 02527 cout << ", name: " << *((*lblstr)[(int)get_label()]); 02528 cout << ", pickings: " << pick_count.get(it) << ", energy: " 02529 << energies.get(it) << ", correct: " << (int) correct.get(it); 02530 if (weigh_samples) cout << ", proba: " << probas.get(it); 02531 cout << ")" << endl; 02532 #endif 02533 this->pretty_progress(); 02534 return true; 02535 } 02536 02537 // template <typename Tnet, typename Tdata, typename Tlabel> 02538 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_balanced(bool bal) { 02539 // balance = bal; 02540 // if (!balance) // unbalanced 02541 // cout << _name << ": Setting training as unbalanced (not taking class " 02542 // << "distributions into account)." << endl; 02543 // else { // balanced 02544 // cout << _name << ": Setting training as balanced (taking class " 02545 // << "distributions into account)." << endl; 02546 // // compute vector of sample indices for each class 02547 // bal_indices.clear(); 02548 // bal_it.clear(); 02549 // epoch_done_counters.clear(); 02550 // class_it = 0; 02551 // for (intg i = 0; i < nclasses; ++i) { 02552 // vector<intg> indices; 02553 // bal_indices.push_back(indices); 02554 // bal_it.push_back(0); // init iterators 02555 // } 02556 // // distribute sample indices into each vector based on label 02557 // for (uint i = 0; i < this->size(); ++i) 02558 // bal_indices[(intg) (labels.gget(i))].push_back(i); 02559 // for (uint i = 0; i < bal_indices.size(); ++i) { 02560 // // shuffle 02561 // random_shuffle(bal_indices[i].begin(), bal_indices[i].end()); 02562 // // init epoch counters 02563 // epoch_done_counters.push_back(bal_indices[i].size()); 02564 // } 02565 // } 02566 // } 02567 02568 // template <typename Tnet, typename Tdata, typename Tlabel> 02569 // bool hierarchy_datasource<Tnet,Tdata,Tlabel>::epoch_done() { 02570 // switch (epoch_mode) { 02571 // case 0: // fixed number of samples 02572 // if (epoch_cnt >= epoch_sz) 02573 // return true; 02574 // break ; 02575 // case 1: // see all samples at least once 02576 // if (balance) { 02577 // // check that all classes are done 02578 // for (uint i = 0; i < epoch_done_counters.size(); ++i) { 02579 // if (epoch_done_counters[i] > 0) 02580 // return false; 02581 // } 02582 // return true; // all classes are done 02583 // } else { // do not balance, use epoch_sz 02584 // if (epoch_cnt >= epoch_sz) 02585 // return true; 02586 // } 02587 // break ; 02588 // default: eblerror("unknown epoch_mode"); 02589 // } 02590 // return false; 02591 // } 02592 02593 // template <typename Tnet, typename Tdata, typename Tlabel> 02594 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::init_epoch() { 02595 // epoch_cnt = 0; 02596 // epoch_pick_cnt = 0; 02597 // epoch_timer.restart(); 02598 // epoch_show_printed = -1; // last epoch count we have printed 02599 // if (balance) { 02600 // uint maxsize = 0; 02601 // // for balanced training, set each class to not done. 02602 // for (uint k = 0; k < bal_indices.size(); ++k) { 02603 // epoch_done_counters[k] = bal_indices[k].size(); 02604 // if (bal_indices[k].size() > maxsize) 02605 // maxsize = bal_indices[k].size(); 02606 // } 02607 // if (epoch_mode == 1) // for ETA estimation only, estimate epoch size 02608 // epoch_sz = maxsize * bal_indices.size(); 02609 // } 02610 // // if we have prior information about each sample energy and classification 02611 // // let's use it to initialize the picking probabilities. 02612 // this->normalize_all_probas(); 02613 // } 02614 02615 // template <typename Tnet, typename Tdata, typename Tlabel> 02616 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::normalize_all_probas() { 02617 // if (weigh_samples) { 02618 // if (perclass_norm && balance) { 02619 // for (uint i = 0; i < bal_indices.size(); ++i) 02620 // normalize_probas(i); 02621 // } else 02622 // normalize_probas(); 02623 // } 02624 // } 02625 02626 // template <typename Tnet, typename Tdata, typename Tlabel> 02627 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::normalize_probas(int classid) { 02628 // vector<intg> *cindices = NULL; 02629 // if (perclass_norm && balance) { // use only class_it class samples 02630 // if (classid < 0) 02631 // eblerror("class id cannot be negative"); 02632 // uint class_it = (uint) classid; 02633 // cindices = &(bal_indices[class_it]); 02634 // cout << _name << ": Normalizing probabilities for class" << class_it; 02635 // datasource<Tnet,Tdata>::normalize_probas(cindices); 02636 // } else // use all samples 02637 // datasource<Tnet,Tdata>::normalize_probas(); 02638 // } 02639 02640 // ////////////////////////////////////////////////////////////////////////////// 02641 // // accessors 02642 02643 // template <typename Tnet, typename Tdata, typename Tlabel> 02644 // intg hierarchy_datasource<Tnet, Tdata, Tlabel>::get_nclasses() { 02645 // return nclasses; 02646 // } 02647 02648 // template <typename Tnet, typename Tdata, typename Tlabel> 02649 // int hierarchy_datasource<Tnet, Tdata, Tlabel>::get_class_id(const char *name) { 02650 // int id_ = -1; 02651 // vector<string*>::iterator i = lblstr->begin(); 02652 // for (int j = 0; i != lblstr->end(); ++i, ++j) { 02653 // if (!strcmp(name, (*i)->c_str())) 02654 // id_ = j; 02655 // } 02656 // return id_; 02657 // } 02658 02659 // template <typename Tnet, typename Tdata, typename Tlabel> 02660 // std::string &hierarchy_datasource<Tnet, Tdata, Tlabel>::get_class_name(int id) { 02661 // if (!lblstr) 02662 // eblerror("no label strings"); 02663 // if (id >= (int) lblstr->size()) 02664 // eblerror("requesting label string at index " << id 02665 // << " but string vector has only " << lblstr->size() 02666 // << " elements."); 02667 // string *s = (*lblstr)[id]; 02668 // if (!s) 02669 // eblerror("empty label string"); 02670 // return *s; 02671 // } 02672 02673 // template <typename Tnet, typename Tdata, typename Tlabel> 02674 // std::vector<std::string*>& hierarchy_datasource<Tnet, Tdata, Tlabel>:: 02675 // get_label_strings() { 02676 // if (!lblstr) 02677 // eblerror("expected label strings to be set"); 02678 // return *lblstr; 02679 // } 02680 02681 // template <typename Tnet, typename Tdata, typename Tlabel> 02682 // intg hierarchy_datasource<Tnet,Tdata,Tlabel>::get_lowest_common_size() { 02683 // intg min_nonzero = (std::numeric_limits<intg>::max)(); 02684 // for (vector<intg>::iterator i = counts.begin(); i != counts.end(); ++i) { 02685 // if ((*i < min_nonzero) && (*i != 0)) 02686 // min_nonzero = *i; 02687 // } 02688 // if (min_nonzero == (std::numeric_limits<intg>::max)()) 02689 // eblerror("empty dataset"); 02690 // return min_nonzero * nclasses; 02691 // } 02692 02693 // template <typename Tnet, typename Tdata, typename Tlabel> 02694 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::save_pickings(const char *name_) { 02695 // // non-class plotting 02696 // datasource<Tnet,Tdata>::save_pickings(name_); 02697 // string name = "pickings"; 02698 // if (name_) 02699 // name = name_; 02700 // // plot by class 02701 // write_classed_pickings(pick_count, correct, name); 02702 // write_classed_pickings(energies, correct, name, "_energies"); 02703 // idx<double> e = idx_copy(energies); 02704 // idx<ubyte> c = idx_copy(correct); 02705 // idx_sortup(e, c); 02706 // write_classed_pickings(e, c, name, "_sorted_energies"); 02707 // idx<double> p = idx_copy(probas); 02708 // c = idx_copy(correct); 02709 // idx_sortup(p, c); 02710 // write_classed_pickings(p, c, name, "_sorted_probas"); 02711 // p = idx_copy(probas); 02712 // e = idx_copy(energies); 02713 // c = idx_copy(correct); 02714 // idx_sortup(e, c, p); 02715 // write_classed_pickings(p, c, name, "_probas_sorted_by_energy", true, 02716 // "Picking probability"); 02717 // write_classed_pickings(p, c, name, "_probas_sorted_by_energy_wrong_only", 02718 // false, "Picking probability"); 02719 // write_classed_pickings(e, c, name, "_energies_sorted_by_energy_wrong_only", 02720 // false, "Energy"); 02721 // } 02722 02723 // template <typename Tnet, typename Tdata, typename Tlabel> 02724 // template <typename T> 02725 // void hierarchy_datasource<Tnet,Tdata,Tlabel>:: 02726 // write_classed_pickings(idx<T> &m, idx<ubyte> &c, string &name_, 02727 // const char *name2_, bool plot_correct, 02728 // const char *ylabel) { 02729 // string name = name_; 02730 // if (name2_) 02731 // name += name2_; 02732 // name += "_classed"; 02733 // // sorted classed plot file 02734 // if (labels.order() == 1) { // single label value 02735 // string fname = name; 02736 // fname += ".plot"; 02737 // ofstream fp(fname.c_str()); 02738 // if (!fp) { 02739 // cerr << "failed to open " << fname << endl; 02740 // eblerror("failed to open file for writing"); 02741 // } 02742 // typename idx<T>::dimension_iterator i = m.dim_begin(0); 02743 // typename idx<Tlabel>::dimension_iterator l = labels.dim_begin(0); 02744 // typename idx<ubyte>::dimension_iterator ic = c.dim_begin(0); 02745 // uint j = 0; 02746 // for ( ; i.notdone(); i++, l++, ic++) { 02747 // if (!plot_correct && ic->get() == 1) 02748 // continue ; // don't plot correct ones 02749 // fp << j++; 02750 // for (Tlabel k = 0; k < (Tlabel) nclasses; ++k) { 02751 // if (k == l.get()) { 02752 // if (ic->get() == 0) { // wrong 02753 // fp << "\t" << i->get(); 02754 // if (plot_correct) 02755 // fp << "\t?"; 02756 // } else if (plot_correct) // correct 02757 // fp << "\t?\t" << i->get(); 02758 // } else { 02759 // fp << "\t?"; 02760 // if (plot_correct) 02761 // fp << "\t?"; 02762 // } 02763 // } 02764 // fp << endl; 02765 // } 02766 // fp.close(); 02767 // cout << _name << ": Wrote picking statistics in " << fname << endl; 02768 // // p file 02769 // string fname2 = name; 02770 // fname2 += ".p"; 02771 // ofstream fp2(fname2.c_str()); 02772 // if (!fp2) { 02773 // cerr << "failed to open " << fname2 << endl; 02774 // eblerror("failed to open file for writing"); 02775 // } 02776 // fp2 << "set title \"" << name << "\"; set ylabel \"" << ylabel 02777 // << "\"; plot \"" 02778 // << fname << "\" using 1:2 title \"class 0 wrong\" with impulse"; 02779 // if (plot_correct) 02780 // fp2 << ", \"" 02781 // << fname << "\" using 1:3 title \"class 0 correct\" with impulse"; 02782 // for (uint k = 1; k < nclasses; ++k) { 02783 // fp2 << ", \"" << fname << "\" using 1:" << k * (plot_correct?2:1) + 2 02784 // << " title \"class " << k << " wrong\" with impulse"; 02785 // if (plot_correct) 02786 // fp2 << ", \"" << fname << "\" using 1:" << k * 2 + 3 02787 // << " title \"class " << k << " correct\" with impulse"; 02788 // } 02789 // fp << endl; 02790 // fp2.close(); 02791 // cout << _name << ": Wrote gnuplot file in " << fname2 << endl; 02792 // } 02793 // } 02794 02795 // ////////////////////////////////////////////////////////////////////////////// 02796 // // state saving 02797 02798 // template <typename Tnet, typename Tdata, typename Tlabel> 02799 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::save_state() { 02800 // state_saved = true; 02801 // count_pickings_save = count_pickings; 02802 // it_saved = it; // save main iterator 02803 // it_test_saved = it_test; 02804 // it_train_saved = it_train; 02805 // if (!balance) // save (unbalanced) iterators 02806 // datasource<Tnet,Tdata>::save_state(); 02807 // else { // save balanced iterators 02808 // bal_it_saved.clear(); 02809 // bal_indices_saved.clear(); 02810 // for (uint k = 0; k < bal_it.size(); ++k) { 02811 // bal_it_saved.push_back(bal_it[k]); 02812 // vector<intg> indices; 02813 // for (uint l = 0; l < bal_indices[k].size(); ++l) 02814 // indices.push_back(bal_indices[k][l]); 02815 // bal_indices_saved.push_back(indices); 02816 // } 02817 // class_it_saved = class_it; 02818 // } 02819 // } 02820 02821 // template <typename Tnet, typename Tdata, typename Tlabel> 02822 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::restore_state() { 02823 // if (!state_saved) 02824 // eblerror("state not saved, call save_state() before restore_state()"); 02825 // count_pickings = count_pickings_save; 02826 // it = it_saved; // restore main iterator 02827 // it_test = it_test_saved; 02828 // it_train = it_train_saved; 02829 // if (!balance) // restore unbalanced 02830 // datasource<Tnet,Tdata>::restore_state(); 02831 // else { // restore balanced iterators 02832 // for (uint k = 0; k < bal_it.size(); ++k) { 02833 // bal_it[k] = bal_it_saved[k]; 02834 // for (uint l = 0; l < bal_indices[k].size(); ++l) 02835 // bal_indices[k][l] = bal_indices_saved[k][l]; 02836 // } 02837 // class_it = class_it_saved; 02838 // } 02839 // } 02840 02842 // pretty methods 02843 02844 // template <typename Tnet, typename Tdata, typename Tlabel> 02845 // void hierarchy_datasource<Tnet,Tdata,Tlabel>::pretty_progress(bool newline) { 02846 // if (epoch_show > 0 && epoch_pick_cnt % epoch_show == 0 && 02847 // epoch_show_printed != (intg) epoch_pick_cnt) { 02848 // datasource<Tnet,Tdata>::pretty_progress(false); 02849 // if (balance) { 02850 // cout << ", remaining:"; 02851 // for (uint i = 0; i < epoch_done_counters.size(); ++i) { 02852 // cout << " " << i << ": " << epoch_done_counters[i]; 02853 // } 02854 // } 02855 // if (newline) 02856 // cout << endl; 02857 // } 02858 // } 02859 02860 template <typename Tnet, typename Tdata, typename Tlabel> 02861 void hierarchy_datasource<Tnet, Tdata, Tlabel>::pretty() { 02862 cout << _name << ": hierarchy class dataset \"" << _name 02863 << "\" contains " 02864 << data.dim(0) << " samples of dimension " << sampledims 02865 << " and defines an epoch as " << epoch_sz << " samples." << endl; 02866 if (lblstr) { 02867 cout << this->_name << ": It has " << nclasses << " classes: "; 02868 uint i; 02869 for (i = 0; i < this->counts.size() - 1; ++i) 02870 cout << this->counts[i] << " \"" << *(*lblstr)[i] << "\", "; 02871 cout << "and " << this->counts[i] << " \"" << *(*lblstr)[i] << "\"."; 02872 cout << endl; 02873 // pretty hierarchy 02874 cout << "Hierarchy by depth:" << endl; 02875 for (typename vector<vector<class_node<Tlabel>*>*>::iterator 02876 j = all_depths.begin(); j != all_depths.end(); ++j) { 02877 vector<class_node<Tlabel>*> &d = **j; 02878 cout << "depth " << j - all_depths.begin() << " (" 02879 << d.size() << " nodes): "; 02880 for (typename vector<class_node<Tlabel>*>::iterator k = d.begin(); 02881 k != d.end(); ++k) 02882 cout << (*k)->name() << " "; 02883 cout << endl; 02884 } 02885 // pretty complete-depth hierarchy 02886 cout << "Hierarchy by depth, keeping lower depth nodes with samples:" 02887 << endl; 02888 for (typename vector<vector<class_node<Tlabel>*>*>::iterator 02889 j = complete_depths.begin(); j != complete_depths.end(); ++j) { 02890 vector<class_node<Tlabel>*> &d = **j; 02891 cout << "depth " << j - complete_depths.begin() << " (" 02892 << d.size() << " nodes): "; 02893 for (typename vector<class_node<Tlabel>*>::iterator k = d.begin(); 02894 k != d.end(); ++k) 02895 cout << (*k)->name() << " "; 02896 cout << endl; 02897 } 02898 } 02899 } 02900 02901 template <typename Tnet, typename Tdata, typename Tlabel> 02902 void hierarchy_datasource<Tnet, Tdata, Tlabel>::print_path(Tlabel l) { 02903 bool first = true; 02904 for (class_node<Tlabel> *n = all_nodes[l]; n != NULL; n = n->get_parent()) { 02905 if (first) 02906 first = false; 02907 else 02908 cout << " <- "; 02909 cout << n->name(); 02910 } 02911 } 02912 02914 // labeled_pair_datasource 02915 02916 // constructor 02917 template <typename Tnet, typename Tdata, typename Tlabel> 02918 labeled_pair_datasource<Tnet, Tdata, Tlabel>:: 02919 labeled_pair_datasource(const char *data_fname, const char *labels_fname, 02920 const char *classes_fname, const char *pairs_fname, 02921 const char *name_, Tdata b, float c) 02922 : labeled_datasource<Tnet, Tdata, Tlabel>(data_fname, labels_fname, 02923 classes_fname, name_, b, c), 02924 pairs(1, 1) { //, pairsIter(pairs, 0) { 02925 // init current class 02926 try { 02927 pairs = load_matrix<intg>(pairs_fname); 02928 } catch(string &err) { 02929 cerr << "error: " << err << endl; 02930 cerr << "failed to load dataset file " << pairs_fname << endl; 02931 eblerror("Failed to load dataset file"); 02932 } 02933 eblerror("not implemented"); 02934 // typename idx<intg>::dimension_iterator diter(pairs, 0); 02935 // pairsIter = diter; 02936 } 02937 02938 // constructor 02939 template <typename Tnet, typename Tdata, typename Tlabel> 02940 labeled_pair_datasource<Tnet, Tdata, Tlabel>:: 02941 labeled_pair_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_, 02942 idx<ubyte> &classes_, idx<intg> &pairs_, 02943 const char *name_, Tdata b, float c) 02944 : labeled_datasource<Tnet, Tdata, Tlabel>(data_, labels_, classes_, name_, 02945 b, c), 02946 pairs(pairs_) { //, pairsIter(pairs, 0) { 02947 } 02948 02949 // destructor 02950 template <typename Tnet, typename Tdata, typename Tlabel> 02951 labeled_pair_datasource<Tnet, Tdata, Tlabel>::~labeled_pair_datasource() { 02952 } 02953 02954 // fprop pair 02955 template <typename Tnet, typename Tdata, typename Tlabel> 02956 void labeled_pair_datasource<Tnet, Tdata, Tlabel>:: 02957 fprop(bbstate_idx<Tnet> &in1, bbstate_idx<Tnet> &in2, 02958 bbstate_idx<Tlabel> &label) { 02959 eblerror("fixme"); 02960 // in1.resize(this->sample_dims()); 02961 // in2.resize(this->sample_dims()); 02962 // intg id1 = pairsIter.get(0), id2 = pairsIter.get(1); 02963 // Tlabel lab = this->labels.get(id1); 02964 // label.x.set(lab); 02965 // idx<Tdata> im1 = this->data[id1], im2 = this->data[id2]; 02966 // idx_copy(im1, in1.x); 02967 // idx_copy(im2, in2.x); 02968 // idx_addc(in1.x, this->bias, in1.x); 02969 // idx_dotc(in1.x, this->coeff, in1.x); 02970 // idx_addc(in2.x, this->bias, in2.x); 02971 // idx_dotc(in2.x, this->coeff, in2.x); 02972 } 02973 02974 // next pair 02975 template <typename Tnet, typename Tdata, typename Tlabel> 02976 bool labeled_pair_datasource<Tnet, Tdata, Tlabel>::next() { 02977 eblerror("not implemented"); 02978 // ++pairsIter; 02979 // if(!pairsIter.notdone()) { 02980 // pairsIter = pairs.dim_begin(0); 02981 // return false; 02982 // } 02983 return true; 02984 } 02985 02986 // begin pair 02987 template <typename Tnet, typename Tdata, typename Tlabel> 02988 void labeled_pair_datasource<Tnet, Tdata, Tlabel>::seek_begin() { 02989 eblerror("not implemented"); 02990 // pairsIter = pairs.dim_begin(0); 02991 } 02992 02993 template <typename Tnet, typename Tdata, typename Tlabel> 02994 unsigned int labeled_pair_datasource<Tnet, Tdata, Tlabel>::size() { 02995 return pairs.dim(0); 02996 } 02997 02999 // mnist_datasource 03000 03001 template <typename Tnet, typename Tdata, typename Tlabel> 03002 mnist_datasource<Tnet, Tdata, Tlabel>:: 03003 mnist_datasource(const char *root, bool train_data, uint size) { 03004 try { 03005 // load dataset 03006 string datafile, labelfile, name, setname = "MNIST"; 03007 if (train_data) // training set 03008 setname = "train"; 03009 else // testing set 03010 setname = "t10k"; 03011 datafile << root << "/" << setname << "-images-idx3-ubyte"; 03012 labelfile << root << "/" << setname << "-labels-idx1-ubyte"; 03013 name << "MNIST " << setname; 03014 idx<Tdata> dat = load_matrix<Tdata>(datafile); 03015 idx<Tlabel> labs = load_matrix<Tlabel>(labelfile); 03016 dat = dat.narrow(0, MIN(dat.dim(0), (intg) size), 0); 03017 labs = labs.narrow(0, MIN(labs.dim(0), (intg) size), 0); 03018 mnist_datasource<Tnet,Tdata,Tlabel>::init(dat, labs, name.c_str()); 03019 if (!train_data) 03020 this->set_test(); // remember that this is the testing set 03021 this->init_epoch(); 03022 this->pretty(); // print info about dataset 03023 } catch(string &err) { 03024 eblerror("failed to load mnist dataset: " << err); 03025 } catch(bad_alloc& ba) { 03026 eblerror("bad_alloc: " << ba.what()); 03027 } 03028 } 03029 03030 template <typename Tnet, typename Tdata, typename Tlabel> 03031 mnist_datasource<Tnet, Tdata, Tlabel>:: 03032 mnist_datasource(const char *root, const char *name, uint size) { 03033 try { 03034 // load dataset 03035 string datafile, labelfile; 03036 datafile << root << "/" << name << "_" << DATA_NAME << MATRIX_EXTENSION; 03037 labelfile << root << "/" << name 03038 << "_" << LABELS_NAME << MATRIX_EXTENSION; 03039 idx<Tdata> dat = load_matrix<Tdata>(datafile); 03040 idx<Tlabel> labs = load_matrix<Tlabel>(labelfile); 03041 dat = dat.narrow(0, MIN((uint) dat.dim(0), size), 0); 03042 labs = labs.narrow(0, MIN((uint) labs.dim(0), size), 0); 03043 mnist_datasource<Tnet,Tdata,Tlabel>::init(dat, labs, name); 03044 this->init_epoch(); 03045 this->pretty(); // print info about dataset 03046 } catch(string &err) { 03047 eblerror("failed to load mnist dataset: " << err); 03048 } catch(bad_alloc& ba) { 03049 eblerror("bad_alloc: " << ba.what()); 03050 } 03051 } 03052 03053 template <typename Tnet, typename Tdata, typename Tlabel> 03054 mnist_datasource<Tnet, Tdata, Tlabel>::~mnist_datasource() { 03055 } 03056 03058 03059 template <typename Tnet, typename Tdata, typename Tlabel> 03060 void mnist_datasource<Tnet, Tdata, Tlabel>:: 03061 fprop_data(bbstate_idx<Tnet> &out) { 03062 if (out.x.order() != sampledims.order()) 03063 out = bbstate_idx<Tnet>(sampledims); 03064 else 03065 out.resize(this->sample_dims()); 03066 idx<Tdata> dat; 03067 if (multimat) 03068 dat = datas.get(it); 03069 else 03070 dat = data[it]; 03071 uint ni = data.dim(1); 03072 uint nj = data.dim(2); 03073 uint di = (uint) (0.5 * (height - ni)); 03074 uint dj = (uint) (0.5 * (width - nj)); 03075 out.clear_x(); 03076 idx<Tnet> tgt = out.x.select(0, 0); 03077 tgt = tgt.narrow(0, ni, di); 03078 tgt = tgt.narrow(1, nj, dj); 03079 idx_copy(dat, tgt); 03080 // bias and coeff 03081 if (bias != 0) 03082 idx_addc(out.x, bias, out.x); 03083 if (coeff != 1) 03084 idx_dotc(out.x, coeff, out.x); 03085 } 03086 03087 template <typename Tnet, typename Tdata, typename Tlabel> 03088 void mnist_datasource<Tnet, Tdata, Tlabel>:: 03089 init(idx<Tdata> &data_, idx<Tlabel> &labels_, const char *name_) { 03090 class_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, NULL, name_); 03091 this->set_data_coeff(.01); // scale input data from [0,255] to [0,2.55] 03092 // mnist is actually 28x28, but we add some padding 03093 sampledims = idxdim(1, 32, 32); 03094 height = sampledims.dim(1); 03095 width = sampledims.dim(2); 03096 } 03097 03099 03100 template <typename Tdata> 03101 idx<Tdata> create_target_matrix(intg nclasses, Tdata target) { 03102 // fill matrix with 1-of-n code 03103 idx<Tdata> targets(nclasses, nclasses); 03104 idx_fill(targets, -target); 03105 for (int i = 0; i < nclasses; ++i) { 03106 targets.set(target, i, i); 03107 } 03108 return targets; // return by copy 03109 } 03110 03111 } // end namespace ebl 03112 03113 #endif /*DATASOURCE_HPP_*/