libeblearn
/home/rex/ebltrunk/core/libeblearn/include/datasource.hpp
00001 /***************************************************************************
00002  *   Copyright (C) 2011 by Pierre Sermanet   *
00003  *   pierre.sermanet@gmail.com   *
00004  *   All rights reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions are met:
00008  *     * Redistributions of source code must retain the above copyright
00009  *       notice, this list of conditions and the following disclaimer.
00010  *     * Redistributions in binary form must reproduce the above copyright
00011  *       notice, this list of conditions and the following disclaimer in the
00012  *       documentation and/or other materials provided with the distribution.
00013  *     * Redistribution under a license not approved by the Open Source
00014  *       Initiative (http://www.opensource.org) must display the
00015  *       following acknowledgement in all advertising material:
00016  *        This product includes software developed at the Courant
00017  *        Institute of Mathematical Sciences (http://cims.nyu.edu).
00018  *     * The names of the authors may not be used to endorse or promote products
00019  *       derived from this software without specific prior written permission.
00020  *
00021  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
00022  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024  * DISCLAIMED. IN NO EVENT SHALL ThE AUTHORS BE LIABLE FOR ANY
00025  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00026  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00027  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00028  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00029  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00030  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031  ***************************************************************************/
00032 
00033 #ifndef DATASOURCE_HPP_
00034 #define DATASOURCE_HPP_
00035 
00036 #include <ostream>
00037 #include <algorithm>
00038 
00039 using namespace std;
00040 
00041 namespace ebl {
00042 
00044   // datasource
00045 
00046   template <typename Tnet, typename Tdata>
00047   datasource<Tnet,Tdata>::datasource() {
00048   }
00049 
00050   template <typename Tnet, typename Tdata>
00051   datasource<Tnet,Tdata>::
00052   datasource(midx<Tdata> &data_, const char *name_) {
00053     multimat = true; // data matrix is composed of multiple matrices
00054     init(data_, name_);
00055     init_epoch();
00056     pretty(); // print information about the dataset
00057   }
00058 
00059   template <typename Tnet, typename Tdata>
00060   datasource<Tnet,Tdata>::
00061   datasource(idx<Tdata> &data_, const char *name_) {
00062     multimat = false; // data matrix is composed of multiple matrices
00063     init(data_, name_);
00064     init_epoch();
00065     pretty(); // print information about the dataset
00066   }
00067 
00068   template <typename Tnet, typename Tdata>
00069   datasource<Tnet,Tdata>::
00070   datasource(const char *data_fname, const char *name_) {
00071     try {
00072       if (has_multiple_matrices(data_fname)) {
00073         multimat = true;
00074         midx<Tdata> datas_ = load_matrices<Tdata>(data_fname);
00075         init(datas_, name_);
00076       } else {
00077         multimat = false;
00078         idx<Tdata> data_ = load_matrix<Tdata>(data_fname);
00079         init(data_, name_);
00080       }
00081       init_epoch();
00082       pretty(); // print information about the dataset
00083     } eblcatcherror();
00084   }
00085 
00086   template <typename Tnet, typename Tdata>
00087   datasource<Tnet,Tdata>::~datasource() {
00088   }
00089 
00091   // init methods
00092 
00093   template<typename Tnet, typename Tdata>
00094   void datasource<Tnet,Tdata>::
00095   init(midx<Tdata> &datas_, const char *name_) {
00096     datas = datas_;
00097     data = (idx<Tdata>&) datas_;
00098     multimat = true;
00099     init2(name_);
00100   }
00101 
00102   template<typename Tnet, typename Tdata>
00103   void datasource<Tnet,Tdata>::
00104   init(idx<Tdata> &data_, const char *name_) {
00105     multimat = false;
00106     data = data_;
00107     init2(name_);
00108   }
00109 
00110   template<typename Tnet, typename Tdata>
00111   void datasource<Tnet,Tdata>::
00112   init2(const char *name_) {
00113     // init randomization
00114     if (!drand_ini) // only re-init if not initialized
00115       dynamic_init_drand(); // initialize random seed
00116     // no bias and coeff by default (0 and 1)
00117     bias = (Tnet) 0;
00118     coeff = (Tnet) 1.0;
00119     // iterating
00120     it = 0;
00121     it_test = 0;
00122     it_train = 0;
00123     shuffle_passes = false;
00124     test_set = false;
00125     epoch_sz = 0;
00126     epoch_cnt = 0;
00127     epoch_pick_cnt = 0;
00128     epoch_mode = 1; // default (1): all samples are seen at least once.
00129     hardest_focus = false;
00130     _ignore_correct = false;
00131     // state saving
00132     state_saved = false;
00133     // buffers assigments/allocations
00134     indices = idx<intg>(data.dim(0));
00135     indices_saved = idx<intg>(data.dim(0));
00136     probas = idx<double>(data.dim(0));
00137     energies = idx<double>(data.dim(0));
00138     raw_outputs = idx<Tnet>(1, 1);
00139     pick_count = idx<uint>(data.dim(0));
00140     correct = idx<ubyte>(data.dim(0));
00141     answers = idx<Tnet>(1, 1);
00142     targets = idx<Tnet>(1, 1);
00143     // pickings
00144     idx_clear(pick_count);
00145     count_pickings = true;
00146     sample_min_proba = 0.0;
00147     // intialize buffers
00148     idx_fill(correct, 0);
00149     idx_fill(answers, 0);
00150     idx_fill(targets, 0);
00151     idx_fill(probas, 1.0); // default picking probability for a sample is 1
00152     idx_fill(energies, -1.0);
00153     idx_fill(raw_outputs, 0);
00154     _name = (name_ ? name_ : "Unknown Dataset");
00155     // iterating
00156     set_shuffle_passes(true); // for next_train only
00157     set_weigh_samples(true, true, true, 0.0); // for next_train only
00158     seek_begin();
00159     seek_begin_train();
00160     epoch_sz = size(); //get_lowest_common_size();
00161     epoch_mode = 1;
00162     cout << _name << ": Each training epoch sees " << epoch_sz
00163          << " samples." << endl;
00164     not_picked = 0;
00165     epoch_show = 50; // print epoch count message every epoch_show
00166     epoch_show_printed = -1; // last epoch count we have printed
00167     // fill indices with original data order
00168     for (it = 0; it < data.dim(0); ++it)
00169       indices.set(it, it);
00170     // set sample dimensions
00171     if (multimat) {
00172       bool found = false;
00173       uint i = 0;
00174       idx<Tdata> e;      
00175       if (datas.order() == 2) {
00176         for (intg i = 0; i < datas.dim(0) && !found; ++i) {
00177           for (intg j = 0; j < datas.dim(1); ++j) {
00178             if (datas.exists(i, j)) {
00179               found = true;
00180               e = datas.get(i, j);
00181               samplemfdims.push_back_new(e.get_idxdim());
00182             }
00183           }
00184         }
00185       } else {
00186         while (!found && i < datas.dim(0)) {
00187           if (datas.exists(i)) {
00188             found = true;
00189             e = datas.get(i);
00190             samplemfdims.push_back_new(e.get_idxdim());
00191           }
00192           cout << endl;
00193           i++;
00194         }
00195       }
00196       if (!found)
00197         eblerror("no sample found in multi-matrix data " << datas);
00198       sampledims = e.get_idxdim();
00199     } else
00200       sampledims = idxdim(data.select(0, 0));
00201     if (sampledims.order() == 2)
00202       sampledims.insert_dim(0, 1);
00203     if (sampledims.order() > 2) {
00204       height = sampledims.dim(1);
00205       width = sampledims.dim(2);
00206     }
00207     // initialize index to 0
00208     it = 0;
00209     // shuffle data indices
00210     shuffle();
00211     bkeep_outputs = false;
00212   }
00213 
00215   // accessors
00216 
00217   template <typename Tnet, typename Tdata>
00218   unsigned int datasource<Tnet,Tdata>::size() {
00219     if (multimat)
00220       return datas.dim(0);
00221     return data.dim(0);
00222   }
00223 
00224   template <typename Tnet, typename Tdata>
00225   idxdim datasource<Tnet,Tdata>::sample_dims() {
00226     return sampledims;
00227   }
00228 
00229   template <typename Tnet, typename Tdata>
00230   mfidxdim datasource<Tnet,Tdata>::sample_mfdims() {
00231     return samplemfdims;
00232   }
00233 
00234   template <typename Tnet, typename Tdata>
00235   string& datasource<Tnet,Tdata>::name() {
00236     return _name;
00237   }
00238 
00239   template <typename Tnet, typename Tdata>
00240   void datasource<Tnet,Tdata>::set_test() {
00241     test_set = true;
00242     cout << _name << ": This is a testing set only." << endl;
00243   }
00244 
00245   template <typename Tnet, typename Tdata>
00246   bool datasource<Tnet,Tdata>::is_test() {
00247     return test_set;
00248   }
00249 
00250   template <typename Tnet, typename Tdata>
00251   intg datasource<Tnet,Tdata>::get_epoch_size() {
00252     return epoch_sz;
00253   }
00254 
00255   template <typename Tnet, typename Tdata>
00256   intg datasource<Tnet,Tdata>::get_epoch_count() {
00257     return epoch_cnt;
00258   }
00259 
00260   template <typename Tnet, typename Tdata>
00261   void datasource<Tnet,Tdata>::set_epoch_size(intg sz) {
00262     cout << _name << ": Setting epoch size to " << sz << endl;
00263     epoch_sz = sz;
00264   }
00265 
00266   template <typename Tnet, typename Tdata>
00267   void datasource<Tnet,Tdata>::set_epoch_mode(uint mode) {
00268     epoch_mode = mode;
00269     cout << _name << ": Setting epoch mode to " << epoch_mode;
00270     switch (epoch_mode) {
00271     case 0: cout << " (fixed number of samples)" << endl; break ;
00272     case 1: cout << " (see all samples at least once)" << endl; break ;
00273     default: eblerror("unknown mode");
00274     }
00275   }
00276 
00278   // data access methods
00279 
00280   template <typename Tnet, typename Tdata> template <class Tstate>
00281   void datasource<Tnet,Tdata>::fprop_data(mstate<Tstate> &out) {
00282     // number of states to put in out
00283     uint nstates = 1;
00284     if (multimat && datas.order() == 2) {
00285       // count actual number of submatrices for this sample
00286       for (nstates = 0; nstates < datas.dim(1); ++nstates)
00287         if (!datas.exists(it, nstates)) break ;
00288       EDEBUG("Number of submatrices for sample " << it << ": " << nstates);
00289     }
00290     // reallocate if necessary
00291     if (out.size() != nstates) {
00292       out.clear();
00293       idxdim d = sampledims;
00294       d.setdims(1);
00295       for (uint i = 0; i < nstates; ++i) {
00296         Tstate ts(d);
00297         out.push_back(new Tstate(ts));
00298       }
00299     }
00300     // copy data
00301     if (multimat) { // multiple matrices per sample
00302       for (uint i = 0; i < nstates; ++i) {
00303         idx<Tdata> dat;
00304         if (datas.order() == 2)
00305           dat = datas.get(it, i);
00306         else if (datas.order() == 1)
00307           dat = datas.get(it);
00308         else eblerror("not implemented");
00309         // resize if necessary
00310         idxdim d(dat);
00311         Tstate &ts = out[i];
00312         if (ts.x.get_idxdim() != d)
00313           ts.resize(d);
00314         // copy (and cast) data
00315         idx_copy(dat, ts.x);
00316         if (bias != 0.0)
00317           idx_addc(ts.x, bias, ts.x);
00318         if (coeff != 1.0)
00319           idx_dotc(ts.x, coeff, ts.x);
00320       }
00321     } else { // single matrix per sample
00322       // resize if necessary
00323       Tstate &ts = out[0];
00324       if (ts.x.get_idxdim() != sampledims)
00325         ts.resize(sampledims);
00326       // copy data
00327       idx<Tdata> dat = data[it];
00328       idx_copy(dat, ts.x);
00329       if (bias != 0.0)
00330         idx_addc(ts.x, bias, ts.x);
00331       if (coeff != 1.0)
00332         idx_dotc(ts.x, coeff, ts.x);
00333     }
00334   }
00335 
00336   template <typename Tnet, typename Tdata>
00337   void datasource<Tnet,Tdata>::fprop_data(fstate_idx<Tnet> &out) {
00338     if (out.x.order() != sampledims.order())
00339       out = fstate_idx<Tnet>(sampledims);
00340     else
00341       out.resize(sampledims);
00342     idx<Tdata> dat;
00343     if (multimat)
00344       dat = datas.get(it);
00345     else
00346       dat = data[it];
00347     idx_copy(dat, out.x);
00348     if (bias != 0.0)
00349       idx_addc(out.x, bias, out.x);
00350     if (coeff != 1.0)
00351       idx_dotc(out.x, coeff, out.x);
00352   }
00353 
00354   template <typename Tnet, typename Tdata>
00355   void datasource<Tnet,Tdata>::fprop_data(bbstate_idx<Tnet> &out) {
00356     if (out.x.order() != sampledims.order())
00357       out = bbstate_idx<Tnet>(sampledims);
00358     else
00359       out.resize(sampledims);
00360     fprop_data((fstate_idx<Tnet>&) out);
00361   }
00362 
00363   template <typename Tnet, typename Tdata>
00364   void datasource<Tnet,Tdata>::fprop(bbstate_idx<Tnet> &out) {
00365     if (out.x.order() != sampledims.order())
00366       out = bbstate_idx<Tnet>(sampledims);
00367     else
00368       out.resize(sampledims);
00369     fprop_data((fstate_idx<Tnet>&) out);
00370   }
00371 
00372   template <typename Tnet, typename Tdata>
00373   idx<Tdata> datasource<Tnet,Tdata>::get_sample(intg index) {
00374     if (multimat)
00375       return datas.get(index);
00376     else
00377       return data[index];
00378   }
00379 
00380   template <typename Tnet, typename Tdata>
00381   idx<Tnet> datasource<Tnet,Tdata>::get_raw_output(intg index) {
00382     if (index >= 0)
00383       return raw_outputs[index];
00384     return raw_outputs[it];
00385   }
00386 
00388   // iterating methods
00389 
00390   template <typename Tnet, typename Tdata>
00391   void datasource<Tnet,Tdata>::select_sample(intg index) {
00392     if (index < 0 || index >= data.dim(0))
00393       eblthrow("cannot select index " << index
00394                << " in datasource of dimensions " << data);
00395     it = index;
00396   }
00397 
00398   template <typename Tnet, typename Tdata>
00399   void datasource<Tnet,Tdata>::shuffle() {
00400     // shuffle indices to the data
00401     idx_shuffle(indices);
00402   }
00403 
00404   template <typename Tnet, typename Tdata>
00405   bool datasource<Tnet,Tdata>::next() {
00406     // increment test iterator
00407     it_test++;
00408     // reset if reached end
00409     if (it_test >= data.dim(0)) {
00410       seek_begin();
00411       return false;
00412     }
00413     // set main iterator used by fprop
00414     it = it_test;
00415     return true;
00416   }
00417 
00418   template <typename Tnet, typename Tdata>
00419   bool datasource<Tnet,Tdata>::next_train() {
00420     // check that this datasource is allowed to call this method
00421     if (test_set)
00422       eblerror("forbidden call of next_train() on testing sets");
00423     bool pick = false;
00424     not_picked++;
00425     // increment iterator
00426     it_train++;
00427     // reset if reached end
00428     if (it_train >= indices.dim(0)) {
00429       if (shuffle_passes)
00430         shuffle(); // shuffle indices to the data
00431       // reset iterator
00432       seek_begin_train();
00433       // normalize probabilities, mapping [0..max] to [0..1]
00434       if (weigh_samples)
00435         normalize_probas();
00436     }
00437     it = indices.get(it_train); // set main iterator to the train iterator
00438     // recursively loop until we find a sample that is picked for this class
00439     pick = this->pick_current();
00440     epoch_cnt++;
00441     if (pick) {
00442 #ifdef __DEBUG__
00443       cout << "Picking sample " << it << ", pickings: " << pick_count.get(it)
00444            << ", energy: " << energies.get(it) << ", correct: "
00445            << (int) correct.get(it);
00446       if (weigh_samples) cout << ", proba: " << probas.get(it) << ")";
00447       cout << endl;
00448 #endif
00449       // increment pick counter for this sample
00450       if (count_pickings) pick_count.set(pick_count.get(it) + 1, it);
00451       // increment sample counter
00452       epoch_pick_cnt++;
00453       not_picked = 0;
00454       return true;
00455     } else {
00456       EDEBUG("Not picking sample " << it << ", pickings: " << pick_count.get(it)
00457             << ", energy: " << energies.get(it) << ", correct: "
00458             << (int) correct.get(it) << ", proba: " << probas.get(it) << ")");
00459       return false;
00460     }
00461   }
00462 
00463   // accessors ///////////////////////////////////////////////////////////////
00464 
00465   template <typename Tnet, typename Tdata>
00466   void datasource<Tnet,Tdata>::set_data_bias(Tnet b) {
00467     bias = b;
00468     cout << _name << ": Setting data bias to " << bias << endl;
00469   }
00470 
00471   template <typename Tnet, typename Tdata>
00472   void datasource<Tnet,Tdata>::set_data_coeff(Tnet c) {
00473     coeff = c;
00474     cout << _name << ": Setting data coefficient to " << coeff << endl;
00475   }
00476 
00477   template <typename Tnet, typename Tdata>
00478   bool datasource<Tnet,Tdata>::epoch_done() {
00479     switch (epoch_mode) {
00480     case 0: // fixed number of samples
00481       if (epoch_cnt >= epoch_sz)
00482         return true;
00483       break ;
00484     case 1: // see all samples at least once
00485       // TODO: same as case 0?
00486       if (epoch_cnt >= epoch_sz)
00487         return true;
00488       break ;
00489     default: eblerror("unknown epoch_mode");
00490     }
00491     return false;
00492   }
00493 
00494   template <typename Tnet, typename Tdata>
00495   void datasource<Tnet,Tdata>::init_epoch() {
00496     epoch_cnt = 0;
00497     epoch_pick_cnt = 0;
00498     epoch_timer.restart();
00499     epoch_show_printed = -1; // last epoch count we have printed
00500     // if we have prior information about each sample energy and classification
00501     // let's use it to initialize the picking probabilities.
00502     if (weigh_samples)
00503       this->normalize_all_probas();
00504   }
00505 
00506   template <typename Tnet, typename Tdata>
00507   void datasource<Tnet,Tdata>::seek_begin() {
00508     it_test = 0; // reset test iterator
00509     it = it_test; // set main iterator to test iterator
00510     test_timer.restart();
00511   }
00512 
00513   template <typename Tnet, typename Tdata>
00514   void datasource<Tnet,Tdata>::seek_begin_train() {
00515     // reset train iterator
00516     it_train = 0;
00517     // set main iterator to train iterator
00518     it = indices.get(it_train);
00519   }
00520 
00521   template <typename Tnet, typename Tdata>
00522   void datasource<Tnet,Tdata>::set_shuffle_passes(bool activate) {
00523     shuffle_passes = activate;
00524     cout << _name
00525          << ": Shuffling of samples (training only) after each pass is "
00526          << (shuffle_passes ? "activated" : "deactivated") << "." << endl;
00527   }
00528 
00530   // picking probability methods
00531 
00532   template <typename Tnet, typename Tdata>
00533   void datasource<Tnet,Tdata>::normalize_all_probas() {
00534     if (weigh_samples)
00535       normalize_probas();
00536   }
00537 
00538   template <typename Tnet, typename Tdata>
00539   void datasource<Tnet,Tdata>::normalize_probas(vector<intg> *cindices) {
00540     double maxproba = 0, minproba = (numeric_limits<double>::max)();
00541     double maxenergy = 0, sum = 0; //, energy_ratio, maxenergy2;
00542     vector<intg> allindices;
00543     if (weigh_samples && !is_test()) {
00544       if (!cindices) { // use all samples
00545         cout << _name << ": Normalizing all probabilities";
00546         allindices.resize(energies.dim(0)); // allocate
00547         for (intg i = 0; i < energies.dim(0); ++i)
00548           allindices[i] = i;
00549         cindices = &allindices;
00550       }
00551       idx<double> sorted_energies(cindices->size());
00552       // normalize probas for this class, mapping [0..max] to [0..1]
00553       maxenergy = 0; sum = 0;
00554       intg nincorrect = 0, ncorrect = 0, i = 0;
00555       // get max and sum
00556       for (vector<intg>::iterator j = cindices->begin();
00557            j != cindices->end(); ++j) {
00558         // don't take correct ones into account
00559         if (energies.get(*j) < 0) // energy not set yet
00560           continue ;
00561         if (correct.get(*j) == 1) { // correct
00562           ncorrect++;
00563           if (_ignore_correct)
00564             continue ; // skip this one
00565         } else
00566           nincorrect++;
00567         // max and sum
00568         maxenergy = (std::max)(energies.get(*j), maxenergy);
00569         sum += energies.get(*j);
00570         sorted_energies.set(energies.get(*j), i++);
00571       }
00572       cout << ", nincorrect: " << nincorrect << ", ncorrect: " << ncorrect;
00573       // no incorrect set all to 1
00574       if (!nincorrect) {
00575         idx_fill(probas, 1.0);
00576         cout << endl;
00577         return ;
00578       }
00579       // We choose 2 pivot points in the sorted energies curve,
00580       // one will be used as maximum energy and the other as the minimum
00581       // energy. This helps to have a meaningful range of energies not
00582       // biased by single extrema.
00583       double e1, e2;
00584       sorted_energies.resize(nincorrect);
00585       idx_sortup(sorted_energies);
00586       intg pivot1 = (intg) (sorted_energies.dim(0) * (float) .25);
00587       intg pivot2 = std::min(sorted_energies.dim(0) - 1,
00588                              (intg) (sorted_energies.dim(0)*(float)1.0));//.75);
00589       if (sorted_energies.dim(0) == 0) {
00590         e1 = 0; e2 = 1;
00591       } else {
00592         e1 = sorted_energies.get(pivot1);
00593         e2 = sorted_energies.get(pivot2);
00594       }
00595       // the ratio of total energies over n times the max energy
00596       //energy_ratio = sum / (maxenergy * cindices->size());
00597       // the max probability will be proportional to the energy ratio
00598       // this balances the probabilities so that outliers don't take
00599       // all the probabilites
00600       //maxenergy2 = maxenergy * energy_ratio;
00601       cout << ", max energy: " << maxenergy;
00602            // << ", energy ratio " << energy_ratio
00603            // << " and normalized max energy " << maxenergy2;
00604       // normalize
00605       for (vector<intg>::iterator j = cindices->begin();
00606            j != cindices->end(); ++j) {
00607         double e = energies.get(*j);
00608         // set proba 0 for correct samples if we ignore correct ones
00609         if (e >= 0 && _ignore_correct && correct.get(*j) == 1)
00610           probas.set(0.0, *j);
00611         else {
00612           // compute probas
00613           double den = e2 - e1;
00614           if (e < 0 || maxenergy == 0 || den == 0) // energy not set yet
00615             probas.set(1.0, *j);
00616           else {
00617             probas.set((std::max)((double) 0, (std::min)((e - e1) / den,
00618                                                           (double) 1)), *j);
00619             if (!hardest_focus) // concentrate on easiest misclassified
00620               probas.set(1 - probas.get(*j), *j); // reverse proba
00621             // iprobas.at(*j)->set((std::max)(sample_min_proba,
00622             //                             e / maxenergy2));
00623             // remember min and max proba
00624             maxproba = (std::max)(probas.get(*j), maxproba);
00625             minproba = (std::min)(probas.get(*j), minproba);
00626           }
00627         }
00628       }
00629       cout << ", Min/Max probas are: " << minproba << ", "
00630            << maxproba << endl;
00631     }
00632   }
00633 
00634   template <typename Tnet, typename Tdata>
00635   void datasource<Tnet,Tdata>::set_sample_energy(double e, bool correct_,
00636                                                  idx<Tnet> &raw_,
00637                                                  idx<Tnet> &answers_,
00638                                                  idx<Tnet> &target) {
00639     energies.set(e, it);
00640     correct.set(correct_ ? 1 : 0, it);
00641 
00642     // store model outputs for current sample
00643     if (bkeep_outputs) {
00644       // resize buffers if necessary
00645       idx<Tnet> ans = answers_.view_as_order(1);
00646       idx<Tnet> ra = raw_.view_as_order(1);
00647       idxdim d(ans), draw(ra);
00648       d.insert_dim(0, data.dim(0));
00649       draw.insert_dim(0, data.dim(0));
00650       if (raw_outputs.get_idxdim() != draw) {
00651         raw_outputs.resize(draw);
00652         idx_clear(raw_outputs);
00653       }
00654       if (answers.get_idxdim() != d) {
00655         answers.resize(d);
00656         idx_clear(answers);
00657       }
00658       if (targets.get_idxdim() != draw) {
00659         targets.resize(draw);
00660         idx_clear(targets);
00661       }
00662       // copy raw
00663       idx<Tnet> raw = raw_outputs.select(0, it);
00664       idx_copy(ra, raw);
00665       // copy answers
00666       idx<Tnet> answer = answers.select(0, it);
00667       idx_copy(ans, answer);
00668       // copy target
00669       idx<Tnet> tgt = targets.select(0, it);
00670       idx_copy(target, tgt);
00671     }
00672   }
00673 
00674   template <typename Tnet, typename Tdata>
00675   void datasource<Tnet,Tdata>::keep_outputs(bool keep) {
00676     bkeep_outputs = keep;
00677     cout << (bkeep_outputs ? "Keeping" : "Not keeping")
00678          << " model outputs for each sample." << endl;
00679   }
00680 
00681   template <typename Tnet, typename Tdata>
00682   void datasource<Tnet,Tdata>::save_pickings(const char *name_) {
00683     // plot file
00684     string name = "pickings";
00685     if (name_)
00686       name = name_;
00687     string fname = name;
00688     fname += ".plot";
00689     ofstream fp(fname.c_str());
00690     if (!fp) {
00691       cerr << "failed to open " << fname << endl;
00692       eblerror("failed to open file for writing");
00693     }
00694     eblerror("not implemented");
00695     // typename idx<uint>::dimension_iterator i = pick_count.dim_begin(0);
00696     // uint j = 0;
00697     // for ( ; i.notdone(); i++, j++)
00698     //   fp << j << " " << i->get() << endl;
00699     // fp.close();
00700     cout << _name << ": Wrote picking statistics in " << fname << endl;
00701     // p file
00702     string fname2 = name;
00703     fname2 += ".p";
00704     ofstream fp2(fname2.c_str());
00705     if (!fp2) {
00706       cerr << "failed to open " << fname2 << endl;
00707       eblerror("failed to open file for writing");
00708     }
00709     fp2 << "plot \"" << fname << "\" with impulse" << endl;
00710     fp2.close();
00711     cout << _name << ": Wrote gnuplot file in " << fname2 << endl;
00712   }
00713 
00714   template <typename Tnet, typename Tdata>
00715   void datasource<Tnet,Tdata>::ignore_correct(bool ignore) {
00716     _ignore_correct = ignore;
00717     if (ignore)
00718       cout << (ignore ? "Ignoring" : "Using") <<
00719         " correctly classified samples for training." << endl;
00720   }
00721 
00722   template <typename Tnet, typename Tdata>
00723   bool datasource<Tnet,Tdata>::mstate_samples() {
00724     return multimat;
00725   }
00726 
00727   template <typename Tnet, typename Tdata>
00728   bool datasource<Tnet,Tdata>::get_count_pickings() {
00729     return count_pickings;
00730   }
00731 
00732   template <typename Tnet, typename Tdata>
00733   void datasource<Tnet,Tdata>::set_count_pickings(bool count) {
00734     count_pickings = count;
00735   }
00736 
00737   template <typename Tnet, typename Tdata>
00738   void datasource<Tnet,Tdata>::
00739   set_weigh_samples(bool activate, bool hardest_focus_, bool perclass_norm_,
00740                     double min_proba) {
00741     weigh_samples = activate;
00742     hardest_focus = hardest_focus_;
00743     perclass_norm = perclass_norm_;
00744     sample_min_proba = MIN(1.0, min_proba);
00745     cout << _name
00746          << ": Weighing of samples (training only) based on classification is "
00747          << (weigh_samples ? "activated" : "deactivated") << "." << endl;
00748     if (activate) {
00749       cout << _name << ": learning is focused on "
00750            << (hardest_focus ? "hardest" : "easiest")
00751            << " misclassified samples" << endl;
00752       if (!_ignore_correct && !hardest_focus)
00753         cerr << "Warning: correct samples are not ignored and focus is on "
00754              << "easiest samples, this may not be optimal" << endl;
00755       cout << "Sample picking probabilities are normalized "
00756            << (perclass_norm ? "per class" : "globally")
00757            << " with minimum probability " << sample_min_proba << endl;
00758     }
00759   }
00760 
00762   // pretty methods
00763 
  /// Prints a progress message ("training: i / sz" or "testing: i / sz")
  /// with the elapsed time and an estimated remaining time, every
  /// 'epoch_show' samples and at most once per value of the counter.
  /// \param newline if true, the message is followed by a line break.
  template <typename Tnet, typename Tdata>
  void datasource<Tnet,Tdata>::pretty_progress(bool newline) {
    // train pretty
    intg i = epoch_cnt, sz = epoch_sz;
    string pre = "training: ";
    // test pretty
    if (is_test()) {
      i = it_test;
      sz = this->size();
      pre = "testing: ";
    }
    // common code
    // NOTE(review): test_timer is used for training progress as well, while
    // init_epoch() restarts epoch_timer -- confirm this is intended.
    if (epoch_show > 0 && i % epoch_show == 0 && epoch_show_printed != i) {
      epoch_show_printed = i; // remember last time printed
      // ETA = remaining samples * average seconds per sample so far
      // (the divisor is at least 1 to avoid dividing by zero)
      cout << pre << i << " / " << sz
           << ", elapsed: " << test_timer.elapsed() << ", ETA: "
           << test_timer.
        elapsed((long) ((sz - i) *
                (test_timer.elapsed_seconds()
                 /(double)std::max((intg)1,i))));
      if (newline)
        cout << endl;
    }
  }
00788 
00789   template <typename Tnet, typename Tdata>
00790   void datasource<Tnet,Tdata>::pretty() {
00791     cout << _name << ": dataset \"" << _name << "\" contains " << data.dim(0)
00792          << " samples of dimension " << sampledims
00793          << " and defines an epoch as " << epoch_sz << " samples." << endl;
00794   }
00795 
00797   // state saving
00798 
00799   template <typename Tnet, typename Tdata>
00800   void datasource<Tnet,Tdata>::save_state() {
00801     state_saved = true;
00802     count_pickings_save = count_pickings;
00803     it_saved = it; // save main iterator
00804     it_test_saved = it_test;
00805     it_train_saved = it_train;
00806     for (intg k = 0; k < indices.dim(0); ++k)
00807       indices_saved[k] = indices[k];
00808   }
00809 
00810   template <typename Tnet, typename Tdata>
00811   void datasource<Tnet,Tdata>::restore_state() {
00812     if (!state_saved)
00813       eblerror("state not saved, call save_state() before restore_state()");
00814     count_pickings = count_pickings_save;
00815     it = it_saved; // restore main iterator
00816     it_test = it_test_saved;
00817     it_train = it_train_saved;
00818     for (intg k = 0; k < indices.dim(0); ++k)
00819       indices[k] = indices_saved[k];
00820   }
00821 
00822   template <typename Tnet, typename Tdata>
00823   void datasource<Tnet,Tdata>::set_epoch_show(uint modulo) {
00824     cout << _name << ": Print training count every " << modulo
00825          << " samples." << endl;
00826     epoch_show = modulo;
00827   }
00828 
00830   // protected pickings methods
00831 
00832   template <typename Tnet, typename Tdata>
00833   bool datasource<Tnet,Tdata>::pick_current() {
00834     if (test_set) // check that this datasource is allowed to call this method
00835       eblerror("forbidden call of pick_current() on testing sets");
00836     if (!weigh_samples) // always pick sample when not using probabilities
00837       return true;
00838     // draw random number between 0 and 1 and return true if lower
00839     // than sample's probability.
00840     double r = drand(); // [0..1]
00841     if (r <= probas.get(it))
00842       return true;
00843     return false;
00844   }
00845 
00846   template <typename Tnet, typename Tdata>
00847     map<uint,intg>& datasource<Tnet,Tdata>::get_pickings() {
00848     picksmap.clear();
00849     // typename idx<uint>::dimension_iterator i = pick_count.dim_begin(0);
00850     // uint j = 0;
00851     // for ( ; i.notdone(); i++, j++)
00852     //   picksmap[i->get()] = j;
00853     eblerror("not implemented");
00854     return picksmap;
00855   }
00856 
00858   // labeled_datasource
00859 
00860   template <typename Tnet, typename Tdata, typename Tlabel>
00861   labeled_datasource<Tnet, Tdata, Tlabel>::labeled_datasource() {
00862   }
00863 
00864   template <typename Tnet, typename Tdata, typename Tlabel>
00865   labeled_datasource<Tnet, Tdata, Tlabel>::
00866   labeled_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_,
00867                      const char *name_) {
00868     init(data_, labels_, name_);
00869     this->init_epoch();
00870     this->pretty(); // print information about this dataset
00871   }
00872 
00873   template <typename Tnet, typename Tdata, typename Tlabel>
00874   labeled_datasource<Tnet, Tdata, Tlabel>::
00875   labeled_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
00876                      const char *name_) {
00877     init(data_, labels_, name_);
00878     this->init_epoch();
00879     this->pretty(); // print information about this dataset
00880   }
00881 
00882   template <typename Tnet, typename Tdata, typename Tlabel>
00883   labeled_datasource<Tnet, Tdata, Tlabel>::
00884   labeled_datasource(const char *root_dsname, const char *name_) {
00885     init_root(root_dsname, name_);
00886     this->init_epoch();
00887     this->pretty(); // print information about this dataset
00888   }
00889 
00890   template <typename Tnet, typename Tdata, typename Tlabel>
00891   labeled_datasource<Tnet, Tdata, Tlabel>::
00892   labeled_datasource(const char *root, const char *data_name,
00893                      const char *labels_name, const char *jitters_name,
00894                      const char *scales_name, const char *name_) {
00895     init_root(root, data_name, labels_name, jitters_name, scales_name, name_);
00896     this->init_epoch();
00897     this->pretty(); // print information about this dataset
00898   }
00899 
00900   template <typename Tnet, typename Tdata, typename Tlabel>
00901   labeled_datasource<Tnet, Tdata, Tlabel>::~labeled_datasource() {
00902   }
00903 
00905   // init methods
00906 
00907   template <typename Tnet, typename Tdata, typename Tlabel>
00908   void labeled_datasource<Tnet, Tdata, Tlabel>::
00909   init(midx<Tdata> &data_, idx<Tlabel> &labels_, const char *name) {
00910     scales_loaded = false;
00911     datasource<Tnet,Tdata>::init(data_, name);
00912     init_labels(labels_, name);
00913   }
00914 
00915   template <typename Tnet, typename Tdata, typename Tlabel>
00916   void labeled_datasource<Tnet, Tdata, Tlabel>::
00917   init(idx<Tdata> &data_, idx<Tlabel> &labels_, const char *name) {
00918     scales_loaded = false;
00919     datasource<Tnet,Tdata>::init(data_, name);
00920     init_labels(labels_, name);
00921   }
00922 
00923   template <typename Tnet, typename Tdata, typename Tlabel>
00924   void labeled_datasource<Tnet, Tdata, Tlabel>::
00925   init_labels(idx<Tlabel> &labels_, const char *name) {
00926     labels = labels_;
00927     label_bias = 0;
00928     label_coeff = 1;
00929     // set label dimensions
00930     labeldims = labels.get_idxdim();
00931     if (labeldims.order() > 1)
00932       labeldims.remove_dim(0);
00933     else
00934       labeldims.setdim(0, 1);
00935   }
00936 
00937   template <typename Tnet, typename Tdata, typename Tlabel>
00938   void labeled_datasource<Tnet, Tdata, Tlabel>::
00939   init(const char *data_fname, const char *labels_fname,
00940        const char *jitters_fname, const char *scales_fname, const char *name_,
00941        uint max_size) {
00942     // load jitters
00943     if (jitters_fname && strlen(jitters_fname) != 0) {
00944       try {
00945         jitters = load_matrices<float>(jitters_fname);
00946         jitters_maxdim = jitters.get_maxdim();
00947       }
00948       catch (string &err) { cerr << "warning: " << err << endl; }
00949     } else cout << "No jitter information loaded." << endl;
00950     // load labels
00951     idx<Tlabel> lab;
00952     try {
00953       lab = load_matrix<Tlabel>(labels_fname);
00954     } catch (string &err) {
00955       cerr << err << endl;
00956       eblerror("Failed to load dataset file");
00957     }
00958     // limit number of samples
00959     if (max_size > 0) {
00960       cout << "Limiting " << name_<< " to " << max_size << " samples." <<endl;
00961       //      lab = lab.narrow(0, std::min((intg) max_size, lab.dim(0)), 0);
00962     }
00963     // load data
00964     multimat = has_multiple_matrices(data_fname);
00965     try {
00966       if (multimat) {
00967         midx<Tdata> dat = load_matrices<Tdata>(data_fname);
00968         if (max_size > 0)
00969           dat = dat.narrow(0, std::min((intg) max_size, dat.dim(0)), 0);
00970         // init
00971         labeled_datasource<Tnet, Tdata, Tlabel>::init(dat, lab, name_);
00972       } else {
00973         idx<Tdata> dat = load_matrix<Tdata>(data_fname);
00974         if (max_size > 0)
00975           dat = dat.narrow(0, std::min((intg) max_size, dat.dim(0)), 0);
00976         // init
00977         labeled_datasource<Tnet, Tdata, Tlabel>::init(dat, lab, name_);
00978       }
00979     } catch (string &err) {
00980       cerr << err << endl;
00981       eblerror("Failed to load dataset file");
00982     }
00983     // load scales
00984     scales_loaded = false;
00985     if (scales_fname && strlen(scales_fname) != 0) {
00986       try {
00987         scales = load_matrix<intg>(scales_fname);
00988         scales_loaded = true;
00989       }
00990       catch (string &err) { cerr << "warning: " << err << endl; }
00991     } else cout << "No scale information loaded." << endl;
00992   }
00993 
00994   template <typename Tnet, typename Tdata, typename Tlabel>
00995   void labeled_datasource<Tnet, Tdata, Tlabel>::
00996   init_root(const char *root, const char *data_name, const char *labels_name,
00997             const char *jitters_name, const char *scales_name, 
00998             const char *name_) {
00999     string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname;
01000     data_fname << root << "/" << data_name << "_" << DATA_NAME
01001                << MATRIX_EXTENSION;
01002     labels_fname << root << "/" << labels_name << "_" << LABELS_NAME
01003                  << MATRIX_EXTENSION;
01004     if (jitters_name)
01005       jitters_fname << root << "/" << jitters_name
01006                     << "_" << JITTERS_NAME << MATRIX_EXTENSION;
01007     if (scales_name)
01008       scales_fname << root << "/" << scales_name
01009                     << "_" << SCALES_NAME << MATRIX_EXTENSION;
01010     init(data_fname.c_str(), labels_fname.c_str(),
01011          jitters_name ? jitters_fname.c_str() : NULL, 
01012          scales_name ? scales_fname.c_str() : NULL, name_);
01013   }
01014 
01015   template <typename Tnet, typename Tdata, typename Tlabel>
01016   void labeled_datasource<Tnet, Tdata, Tlabel>::
01017   init_root(const char *root_dsname, const char *name_) {
01018     string data_fname, labels_fname, classes_fname, jitters_fname,
01019       scales_fname;
01020     data_fname << root_dsname << "_" << DATA_NAME << MATRIX_EXTENSION;
01021     labels_fname << root_dsname << "_" << LABELS_NAME << MATRIX_EXTENSION;
01022     classes_fname << root_dsname << "_" << CLASSES_NAME << MATRIX_EXTENSION;
01023     jitters_fname << root_dsname << "_" << JITTERS_NAME << MATRIX_EXTENSION;
01024     scales_fname << root_dsname << "_" << SCALES_NAME << MATRIX_EXTENSION;
01025     init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(),
01026          scales_fname.c_str(), name_);
01027   }
01028 
01029   // data access ///////////////////////////////////////////////////////////////
01030 
01031   template <typename Tnet, typename Tdata, typename Tlabel>
01032   void labeled_datasource<Tnet, Tdata, Tlabel>::
01033   fprop(bbstate_idx<Tnet> &out, bbstate_idx<Tlabel> &label) {
01034     this->fprop_data(out);
01035     this->fprop_label(label);
01036   }
01037 
01038   template <typename Tnet, typename Tdata, typename Tlabel>
01039   void labeled_datasource<Tnet,Tdata,Tlabel>::
01040   fprop_label(fstate_idx<Tlabel> &label) {
01041     idx<Tlabel> lab = labels[it];
01042     idx_copy(lab, label.x);
01043     if (label_bias != 0)
01044       idx_addc(label.x, label_bias, label.x);
01045     if (label_coeff != 1)
01046       idx_dotc(label.x, label_coeff, label.x);
01047   }
01048 
01049   template <typename Tnet, typename Tdata, typename Tlabel>
01050   void labeled_datasource<Tnet,Tdata,Tlabel>::
01051   fprop_label_net(fstate_idx<Tnet> &label) {
01052     idx<Tlabel> lab = labels[it];
01053     idx_copy(lab, label.x);
01054     if (label_bias != 0)
01055       idx_addc(label.x, label_bias, label.x);
01056     if (label_coeff != 1)
01057       idx_dotc(label.x, label_coeff, label.x);
01058   }
01059 
01060   template <typename Tnet, typename Tdata, typename Tlabel>
01061   void labeled_datasource<Tnet,Tdata,Tlabel>::
01062   fprop_label_net(bbstate_idx<Tnet> &label) {
01063     fprop_label_net((fstate_idx<Tnet>&) label);
01064   }
01065 
01066   template <typename Tnet, typename Tdata, typename Tlabel>
01067   void labeled_datasource<Tnet,Tdata,Tlabel>::
01068   fprop_jitter(bbstate_idx<Tnet> &jitt) {
01069     if (jitters.order() < 1) eblerror("jitter information was not loaded");
01070     if (jitters.exists(it)) {
01071       idx<float> j = jitters.get(it);
01072       idxdim d(j.get_idxdim());
01073       if (jitt.x.get_idxdim() != d)
01074         jitt.resize(d);
01075       idx_copy(j, jitt.x);
01076     } else { // fprop an empty jitter
01077       idxdim d(jitters.get_maxdim());
01078       d.setdim(0, 1);
01079       if (jitt.x.get_idxdim() != d)
01080         jitt.resize(d);
01081       idx_clear(jitt.x);
01082     }
01083   }
01084 
01085   template <typename Tnet, typename Tdata, typename Tlabel>
01086   intg labeled_datasource<Tnet,Tdata,Tlabel>::fprop_scale() {
01087     if (!scales_loaded) eblthrow("scales information not present");
01088     return scales.get(it);
01089   }
01090 
01091   // accessors /////////////////////////////////////////////////////////////////
01092 
01093   template <typename Tnet, typename Tdata, typename Tlabel>
01094   bool labeled_datasource<Tnet,Tdata,Tlabel>::included_sample(intg index) {
01095     if (index >= data.dim(0))
01096       eblerror("cannot check inclusion of sample " << index
01097                << ", only " << data.dim(0) << " samples");
01098     return true;
01099   }
01100 
01101   template <typename Tnet, typename Tdata, typename Tlabel>
01102   intg labeled_datasource<Tnet,Tdata,Tlabel>::count_included_samples() {
01103     return this->size();
01104   }
01105 
01106   template <typename Tnet, typename Tdata, typename Tlabel>
01107   void labeled_datasource<Tnet, Tdata, Tlabel>::pretty() {
01108     cout << _name << ": (regression) labeled dataset \"" << _name
01109          << "\" contains "
01110          << data.dim(0) << " samples of dimension " << sampledims
01111          << " and defines an epoch as " << epoch_sz << " samples.";
01112     pretty_scales();
01113   }
01114 
01115   template <typename Tnet, typename Tdata, typename Tlabel>
01116   void labeled_datasource<Tnet, Tdata, Tlabel>::pretty_scales() {
01117     if (scales_loaded) {
01118       intg maxscale = idx_max(scales);
01119       vector<intg> tally(maxscale + 1, 0);
01120       idx_bloop1(scale, scales, intg) {
01121         intg s = scale.get();
01122         if (s < 0) eblerror("unexpected negative value");
01123         tally[s] = tally[s] + 1;
01124       }
01125       intg nscales = 0;
01126       for (intg i = 0; i < (intg) tally.size(); ++i)
01127         if (tally[i] > 0) nscales++;
01128       // print scales distribution for each class
01129       cout << _name << ": has " << nscales << " scales";
01130       for (intg i = 0; i < (intg) tally.size(); ++i)
01131         cout << ", " << i << ": " << tally[i];
01132       cout << endl;      
01133     } else cout << _name << ": no scales information." << endl;
01134   }
01135 
01136   template <typename Tnet, typename Tdata, typename Tlabel>
01137   idxdim labeled_datasource<Tnet, Tdata, Tlabel>::label_dims() {
01138     return labeldims;
01139   }
01140 
01141   template <typename Tnet, typename Tdata, typename Tlabel>
01142   void labeled_datasource<Tnet,Tdata,Tlabel>::set_label_bias(Tnet b) {
01143     label_bias = b;
01144     cout << _name << ": Setting labels bias to " << label_bias << endl;
01145   }
01146 
01147   template <typename Tnet, typename Tdata, typename Tlabel>
01148   void labeled_datasource<Tnet,Tdata,Tlabel>::set_label_coeff(Tnet c) {
01149     label_coeff = c;
01150     cout << _name << ": Setting labels coefficient to " << label_coeff << endl;
01151   }
01152 
01153   template <typename Tnet, typename Tdata, typename Tlabel>
01154   bool labeled_datasource<Tnet,Tdata,Tlabel>::has_scales() {
01155     return scales_loaded;
01156   }
01157 
01159   // class_datasource
01160 
01161   template <typename Tnet, typename Tdata, typename Tlabel>
01162   class_datasource<Tnet, Tdata, Tlabel>::class_datasource()
01163     : lblstr(NULL) {
01164     defaults();
01165   }
01166 
01167   template <typename Tnet, typename Tdata, typename Tlabel>
01168   class_datasource<Tnet, Tdata, Tlabel>::
01169   class_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_,
01170                    vector<string*> *lblstr_, const char *name_) {
01171     defaults();
01172     init(data_, labels_, name_, lblstr_);
01173     this->init_epoch();
01174     this->pretty(); // print info about dataset
01175   }
01176 
01177   template <typename Tnet, typename Tdata, typename Tlabel>
01178   class_datasource<Tnet, Tdata, Tlabel>::
01179   class_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
01180                    vector<string*> *lblstr_, const char *name_) {
01181     defaults();
01182     init(data_, labels_, name_, lblstr_);
01183     this->init_epoch();
01184     this->pretty(); // print info about dataset
01185   }
01186 
01187   template <typename Tnet, typename Tdata, typename Tlabel>
01188   class_datasource<Tnet, Tdata, Tlabel>::
01189   class_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_,
01190                    idx<ubyte> &classes, const char *name_) {
01191     defaults();
01192     init_strings(classes);
01193     init(data_, labels_, this->lblstr, name_);
01194     this->init_epoch();
01195     this->pretty(); // print info about dataset
01196   }
01197 
01198   template <typename Tnet, typename Tdata, typename Tlabel>
01199   class_datasource<Tnet, Tdata, Tlabel>::
01200   class_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
01201                    idx<ubyte> &classes, const char *name_) {
01202     defaults();
01203     init_strings(classes);
01204     init(data_, labels_, this->lblstr, name_);
01205     this->init_epoch();
01206     this->pretty(); // print info about dataset
01207   }
01208 
01209   template <typename Tnet, typename Tdata, typename Tlabel>
01210   class_datasource<Tnet, Tdata, Tlabel>::
01211   class_datasource(const char *data_name, const char *labels_name, 
01212                    const char *jitters_name, const char *scales_name,
01213                    const char *classes_name, const char *name_) {
01214     defaults();
01215     init(data_name, labels_name, jitters_name, scales_name, classes_name,name_);
01216     this->init_epoch();
01217     this->pretty(); // print info about dataset
01218   }
01219 
01220   template <typename Tnet, typename Tdata, typename Tlabel>
01221   class_datasource<Tnet, Tdata, Tlabel>::
01222   class_datasource(const class_datasource<Tnet, Tdata, Tlabel> &ds)
01223     : datasource<Tnet,Tdata>((const datasource<Tnet,Tdata>&) ds),
01224       lblstr(NULL) {
01225     defaults();
01226     if (ds.lblstr) {
01227       this->lblstr = new vector<string*>;
01228       for (unsigned int i = 0; i < ds.lblstr->size(); ++i) {
01229         this->lblstr->push_back(new string(*ds.lblstr->at(i)));
01230       }
01231     }
01232   }
01233 
01234   template <typename Tnet, typename Tdata, typename Tlabel>
01235   class_datasource<Tnet, Tdata, Tlabel>::~class_datasource() {
01236     if (lblstr) { // this class owns lblstr and its content
01237       vector<string*>::iterator i = lblstr->begin();
01238       for ( ; i != lblstr->end(); ++i)
01239         if (*i)
01240           delete *i;
01241       delete lblstr;
01242     }
01243   }
01244 
01246   // init methods
01247 
01248   template <typename Tnet, typename Tdata, typename Tlabel>
01249   void class_datasource<Tnet, Tdata, Tlabel>::defaults() {
01250     balance = true;
01251     bexclusion = false;
01252     random_class_order = true;
01253   }
01254 
01255   template <typename Tnet, typename Tdata, typename Tlabel>
01256   void class_datasource<Tnet, Tdata, Tlabel>::
01257   init_strings(idx<ubyte> &classes) {
01258     this->lblstr = NULL;
01259     // load classes strings
01260     if (classes.order() == 2) {
01261       this->lblstr = new vector<string*>;
01262       idx_bloop1(classe, classes, ubyte) {
01263         this->lblstr->push_back(new string((const char*) classe.idx_ptr()));
01264       }
01265     }
01266   }
01267 
01268   template <typename Tnet, typename Tdata, typename Tlabel>
01269   void class_datasource<Tnet, Tdata, Tlabel>::
01270   init_local(vector<string*> *lblstr_) {
01271     nclasses = (intg) idx_max(labels) + 1;
01272     if (lblstr_)
01273       nclasses = std::max(nclasses, (intg) lblstr_->size());
01274     // assign classes strings
01275     this->lblstr = lblstr_;
01276     // if no names are given and discrete, use indices as names
01277     if (!this->lblstr) {
01278       this->lblstr = new vector<string*>;
01279       ostringstream o;
01280       int imax = (int) idx_max(this->labels);
01281       for (int i = 0; i <= imax; ++i) {
01282         o << i;
01283         this->lblstr->push_back(new string(o.str()));
01284         o.str("");
01285       }
01286     }
01287     init_class_labels();
01288     nclasses = std::max((intg) idx_max(labels) + 1, (intg) clblstr.size());
01289     // count number of samples per class
01290     counts.resize(nclasses);
01291     fill(counts.begin(), counts.end(), 0);
01292     idx_bloop1(lab, labels, Tlabel) {
01293       counts[(size_t)lab.gget()]++;
01294     }
01295     // balance
01296     set_balanced(true); // balance dataset for each class in next_train
01297     perclass_norm = false;
01298     included = nclasses;
01299     seek_begin();
01300     seek_begin_train();
01301   }
01302 
01303   template <typename Tnet, typename Tdata, typename Tlabel>
01304   void class_datasource<Tnet, Tdata, Tlabel>::
01305   init(midx<Tdata> &data_, idx<Tlabel> &labels_, vector<string*> *lblstr_,
01306        const char *name_) {
01307     labeled_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, name_);
01308     class_datasource<Tnet, Tdata, Tlabel>::init_local(lblstr_);
01309   }
01310 
01311   template <typename Tnet, typename Tdata, typename Tlabel>
01312   void class_datasource<Tnet, Tdata, Tlabel>::
01313   init(idx<Tdata> &data_, idx<Tlabel> &labels_, vector<string*> *lblstr_,
01314        const char *name_) {
01315     labeled_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, name_);
01316     class_datasource<Tnet, Tdata, Tlabel>::init_local(lblstr_);
01317   }
01318 
01319   template <typename Tnet, typename Tdata, typename Tlabel>
01320   void class_datasource<Tnet, Tdata, Tlabel>::
01321   init(const char *data_fname, const char *labels_fname,
01322        const char *jitters_fname, const char *scales_fname, 
01323        const char *classes_fname, const char *name_,
01324        uint max_size) {
01325     // load classes
01326     idx<ubyte> classes;
01327     bool classes_found = false;
01328     if (classes_fname && strlen(classes_fname) != 0) {
01329       try {
01330         classes = load_matrix<ubyte>(classes_fname);
01331         classes_found = true;
01332       } catch (string &err) { cerr << "warning: " << err << endl; }
01333     } else
01334       cout << "No category names found, using numbers." << endl;
01335     // classes names are optional, use numbers by default if not specified
01336     if (classes_found) {
01337       this->lblstr = new vector<string*>;
01338       idx_bloop1(classe, classes, ubyte) {
01339         this->lblstr->push_back(new string((const char*) classe.idx_ptr()));
01340       }
01341     }
01342     // init
01343     labeled_datasource<Tnet, Tdata, Tlabel>::
01344       init(data_fname, labels_fname, jitters_fname, scales_fname, name_, 
01345            max_size);
01346     class_datasource<Tnet, Tdata, Tlabel>::init_local(this->lblstr);
01347   }
01348 
01349   template <typename Tnet, typename Tdata, typename Tlabel>
01350   void class_datasource<Tnet, Tdata, Tlabel>::
01351   init_root(const char *root, const char *data_name, const char *labels_name,
01352             const char *jitters_name, const char *scales_name, 
01353             const char *classes_name, const char *name_) {
01354     string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname;
01355     data_fname << root << "/" << data_name << "_" << DATA_NAME
01356                << MATRIX_EXTENSION;
01357     labels_fname << root << "/" << labels_name << "_" << LABELS_NAME
01358                  << MATRIX_EXTENSION;
01359     classes_fname << root << "/" << classes_name << "_" << CLASSES_NAME
01360                   << MATRIX_EXTENSION;
01361     jitters_fname << root << "/" << (jitters_name ? jitters_name :classes_name)
01362                   << "_" << JITTERS_NAME << MATRIX_EXTENSION;
01363     scales_fname << root << "/" << (scales_name ? scales_name :classes_name)
01364                   << "_" << SCALES_NAME << MATRIX_EXTENSION;
01365     init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(),
01366          scales_fname.c_str(), classes_fname.c_str(), name_);
01367   }
01368 
01369   template <typename Tnet, typename Tdata, typename Tlabel>
01370   void class_datasource<Tnet, Tdata, Tlabel>::
01371   init_root(const char *root_dsname, const char *name_) {
01372     string data_fname, labels_fname, classes_fname, jitters_fname, scales_fname;
01373     data_fname << root_dsname << "_" << DATA_NAME << MATRIX_EXTENSION;
01374     labels_fname << root_dsname << "_" << LABELS_NAME << MATRIX_EXTENSION;
01375     classes_fname << root_dsname << "_" << CLASSES_NAME << MATRIX_EXTENSION;
01376     jitters_fname << root_dsname << "_" << JITTERS_NAME << MATRIX_EXTENSION;
01377     scales_fname << root_dsname << "_" << SCALES_NAME << MATRIX_EXTENSION;
01378     init(data_fname.c_str(), labels_fname.c_str(), jitters_fname.c_str(),
01379          scales_fname.c_str(), classes_fname.c_str(), name_);
01380   }
01381 
01382   template <typename Tnet, typename Tdata, typename Tlabel>
01383   void class_datasource<Tnet, Tdata, Tlabel>::init_class_labels() {
01384     std::map<Tlabel,Tlabel> mlabs;
01385     if (olabels.get_idxdim() != labels.get_idxdim())
01386       olabels = idx<Tlabel>(labels.get_idxdim());
01387     idx_copy(labels, olabels); // keep original labels
01388     idx_sortup(labels);
01389     // add all classes into table
01390     intg i = 0;
01391     clblstr.clear();
01392     { idx_bloop1(lab, labels, Tlabel) {
01393         Tlabel l = lab.gget();
01394         if (mlabs.find(l) == mlabs.end()
01395             && (!bexclusion || !excluded[(uint)l])) {
01396           mlabs[l] = i++;
01397           clblstr.push_back((*lblstr)[(uint)l]);
01398         }
01399       }}
01400     // now add excluded classes into table
01401     { idx_bloop1(lab, labels, Tlabel) {
01402         Tlabel l = lab.gget();
01403         if (mlabs.find(l) == mlabs.end()
01404             && (bexclusion && excluded[(uint)l])) {
01405           mlabs[l] = i++;
01406         }
01407       }}
01408     // now replace all original labels by their new value
01409     idx_copy(olabels, labels);
01410     { idx_bloop1(lab, labels, Tlabel) {
01411         Tlabel l = lab.gget();
01412         if (mlabs.find(l) != mlabs.end())
01413           lab.sset(mlabs[l]);
01414         else
01415           eblerror("label " << (uint) l << " not found");
01416       }}
01417     // replace exclusion table
01418     if (bexclusion) {
01419       for (uint i = 0; i < excluded.size(); ++i)
01420           excluded[i] = i < clblstr.size() ? false : true;
01421     }
01422     if (balance)
01423       set_balanced(true); // balance dataset for each class in next_train
01424   }
01425 
01427   // data access
01428 
01429   template <typename Tnet, typename Tdata, typename Tlabel>
01430   Tlabel class_datasource<Tnet,Tdata,Tlabel>::get_label() {
01431     idx<Tlabel> lab = labels[it];
01432     if (lab.order() == 0)
01433       return lab.get();
01434     else if (lab.order() == 1 && lab.dim(0) == 1)
01435       return lab.get(0);
01436     else
01437       eblerror("expected single-element labels");
01438     return 0;
01439   }
01440 
01442   // iterating
01443 
01444   template <typename Tnet, typename Tdata, typename Tlabel>
01445   bool class_datasource<Tnet,Tdata,Tlabel>::included_sample(intg index) {
01446     if (!bexclusion)
01447       return true;
01448     if (index >= data.dim(0))
01449       eblerror("cannot check inclusion of sample " << index
01450                << ", only " << data.dim(0) << " samples");
01451     return !excluded[(int) labels.gget(index)];
01452   }
01453 
01454   template <typename Tnet, typename Tdata, typename Tlabel>
01455   intg class_datasource<Tnet,Tdata,Tlabel>::count_included_samples() {
01456     if (!bexclusion) return this->size();
01457     intg n = 0;
01458     idx_bloop1(lab, labels, Tlabel) {
01459       if (!excluded[(int) lab.gget()])
01460         n++;
01461     }
01462     return n;
01463   }
01464 
01465   template <typename Tnet, typename Tdata, typename Tlabel>
01466   void class_datasource<Tnet,Tdata,Tlabel>::seek_begin() {
01467     datasource<Tnet,Tdata>::seek_begin();
01468     if (bexclusion) {
01469       while (excluded[(int)get_label()])
01470         this->next();
01471       EDEBUG("seek_begin to label: " << (int) get_label() << " it: " << it);
01472     }
01473   }
01474 
01475   template <typename Tnet, typename Tdata, typename Tlabel>
01476   void class_datasource<Tnet,Tdata,Tlabel>::seek_begin_train() {
01477     datasource<Tnet,Tdata>::seek_begin_train();
01478     if (bexclusion) {
01479       while (excluded[(int)get_label()])
01480         this->next();
01481       EDEBUG("seek_begin to label: " << (int) get_label() << " it: " << it);
01482     }
01483   }
01484 
01485 
01486   template <typename Tnet, typename Tdata, typename Tlabel>
01487   bool class_datasource<Tnet,Tdata,Tlabel>::next() {
01488     if (!bexclusion)
01489       return datasource<Tnet,Tdata>::next();
01490     // handle excluded classes
01491     bool b = datasource<Tnet,Tdata>::next();
01492     while (b && excluded[(int)get_label()])
01493       b = datasource<Tnet,Tdata>::next();
01494     return b;
01495   }
01496 
01497   template <typename Tnet, typename Tdata, typename Tlabel>
01498   bool class_datasource<Tnet,Tdata,Tlabel>::next_train() {
01499     // check that this datasource is allowed to call this method
01500     if (test_set)
01501       eblerror("forbidden call of next_train() on testing sets");
01502     if (!balance) // do not balance by class
01503       return datasource<Tnet,Tdata>::next_train();
01504     bool pick = false;
01505     not_picked++;
01506     // balanced: return samples in class-balanced order
01507     // get pointer to first non empty class
01508     while (!bal_indices[class_it].size())
01509       next_balanced_class();
01510     it = bal_indices[class_it][bal_it[class_it]];
01511     bal_it[class_it] += 1;
01512     // decide if we want to select this sample for training
01513     pick = this->pick_current();
01514     // decrement epoch counter
01515     //      if (epoch_done_counters[class_it] > 0)
01516     epoch_done_counters[class_it] = epoch_done_counters[class_it] - 1;
01517     if (bal_it[class_it] >= bal_indices[class_it].size()) {
01518       // returning to begining of list for this class
01519       bal_it[class_it] = 0;
01520       // shuffling list for this class
01521       if (shuffle_passes) {
01522         vector<intg> &clist = bal_indices[class_it];
01523         random_shuffle(clist.begin(), clist.end());
01524       }
01525       if (weigh_samples)
01526         normalize_probas(class_it);
01527     }
01528     // recursion failsafe, allow 1000 max recursions
01529     if (!bexclusion &&
01530         not_picked > MIN(1000, (intg) bal_indices[class_it].size())) {
01531       // we called recursion on this method more than number of class samples
01532       // give up and show current sample
01533       pick = true;
01534     }
01535     if (pick) {
01536       // increment pick counter for this sample
01537       if (count_pickings) pick_count.set(pick_count.get(it) + 1, it);
01538 #ifdef __DEBUG__
01539       cout << "Picking sample " << it << " (label: " << (int)get_label()
01540            << ", name: " << *(clblstr[(int)get_label()]) << ", pickings: "
01541            << pick_count.get(it) << ", energy: " << energies.get(it)
01542            << ", correct: " << (int) correct.get(it);
01543       if (weigh_samples) cout << ", proba: " << probas.get(it);
01544       cout << ")" << endl;
01545 #endif
01546       epoch_cnt++;
01547       // increment sample counter
01548       epoch_pick_cnt++;
01549       // if we picked a sample, jump to next class
01550       next_balanced_class();
01551       not_picked = 0;
01552       this->pretty_progress();
01553       return true;
01554     } else {
01555 #ifdef __DEBUG__
01556       if (!bexclusion)
01557         cout << "Not picking sample " << it << " (label: "
01558              << (int) labels.gget(it) << ", pickings: " << pick_count.get(it)
01559              << ", energy: " << energies.get(it)
01560              << ", correct: " << (int) correct.get(it)
01561              << ", proba: " << probas.get(it) << ")" << endl;
01562 #endif
01563       this->pretty_progress();
01564       return false;
01565     }
01566   }
01567 
01568   template <typename Tnet, typename Tdata, typename Tlabel>
01569   void class_datasource<Tnet,Tdata,Tlabel>::next_balanced_class() {
01570     class_it_it++;
01571     if (class_it_it >= class_order.size()) {
01572       class_it_it = 0;
01573       reset_class_order();
01574     }
01575     class_it = class_order[class_it_it];
01576     if (bexclusion && excluded[class_it])
01577       return next_balanced_class();
01578   }
01579 
01580   template <typename Tnet, typename Tdata, typename Tlabel>
01581   void class_datasource<Tnet,Tdata,Tlabel>::reset_class_order() {
01582     if (random_class_order) // randomize classes order
01583       random_shuffle(class_order.begin(), class_order.end());
01584   }
01585 
01586   template <typename Tnet, typename Tdata, typename Tlabel>
01587   void class_datasource<Tnet,Tdata,Tlabel>::set_balanced(bool bal) {
01588     balance = bal;
01589     if (!balance) // unbalanced
01590       cout << _name << ": Setting training as unbalanced (not taking class "
01591            << "distributions into account)." << endl;
01592     else { // balanced
01593       cout << _name << ": Setting training as balanced (taking class "
01594            << "distributions into account)." << endl;
01595       // compute vector of sample indices for each class
01596       bal_indices.clear();
01597       bal_it.clear();
01598       epoch_done_counters.clear();
01599       class_it = 0;
01600       class_it_it = 0;
01601       for (intg i = 0; i < nclasses; ++i) {
01602         vector<intg> indices;
01603         bal_indices.push_back(indices);
01604         bal_it.push_back(0); // init iterators
01605         class_order.push_back(i); // init iterators
01606       }
01607       reset_class_order();
01608       // distribute sample indices into each vector based on label
01609       for (uint i = 0; i < this->size(); ++i)
01610         bal_indices[(intg) (labels.gget(i))].push_back(i);
01611       for (uint i = 0; i < bal_indices.size(); ++i) {
01612         // shuffle
01613         random_shuffle(bal_indices[i].begin(), bal_indices[i].end());
01614         // init epoch counters
01615         epoch_done_counters.push_back(bal_indices[i].size());
01616       }
01617     }
01618   }
01619 
01620   template <typename Tnet, typename Tdata, typename Tlabel>
01621   void class_datasource<Tnet,Tdata,Tlabel>::set_random_class_order(bool ran) {
01622     random_class_order = ran;
01623     cout << "Classes order is " << (random_class_order ? "" : "not")
01624          << " random." << endl;
01625   }
01626 
01627   template <typename Tnet, typename Tdata, typename Tlabel>
01628   void class_datasource<Tnet,Tdata,Tlabel>::
01629   limit_classes(intg n, intg offset, bool random) {
01630     // enable exclusion except for the n classes after offset
01631     if ((offset == 0 && n >= nclasses) || offset >= nclasses) {
01632       eblwarn("ignoring attempt to limit classes to " << n
01633               << " starting at offset " << offset << " because there are only "
01634               << nclasses << " classes");
01635     } else {
01636       bexclusion = true;
01637       excluded.clear();
01638       for (intg i = 0; i < nclasses; ++i) {
01639         if (i < offset || i >= offset + n)
01640           excluded.push_back(true);
01641         else
01642           excluded.push_back(false);
01643       }
01644       if (random)
01645         random_shuffle(excluded.begin(), excluded.end());
01646       included = std::min(nclasses, n + offset) - offset;
01647       cout << "Excluded all but " << included
01648            << " classes (offset: " << offset << ", n: " << n << ")" << endl;
01649       init_class_labels();
01650     }
01651   }
01652 
01653   template <typename Tnet, typename Tdata, typename Tlabel>
01654   bool class_datasource<Tnet,Tdata,Tlabel>::epoch_done() {
01655     switch (epoch_mode) {
01656     case 0: // fixed number of samples
01657       if (epoch_cnt >= epoch_sz)
01658         return true;
01659       break ;
01660     case 1: // see all samples at least once
01661       if (balance) {
01662         // check that all classes are done
01663         for (uint i = 0; i < epoch_done_counters.size(); ++i) {
01664           if (epoch_done_counters[i] > 0)
01665             return false;
01666         }
01667         return true; // all classes are done
01668       } else { // do not balance, use epoch_sz
01669         if (epoch_cnt >= epoch_sz)
01670           return true;
01671       }
01672       break ;
01673     default: eblerror("unknown epoch_mode");
01674     }
01675     return false;
01676   }
01677 
01678   template <typename Tnet, typename Tdata, typename Tlabel>
01679   void class_datasource<Tnet,Tdata,Tlabel>::init_epoch() {
01680     epoch_cnt = 0;
01681     epoch_pick_cnt = 0;
01682     epoch_timer.restart();
01683     epoch_show_printed = -1; // last epoch count we have printed
01684     if (balance) {
01685       uint maxsize = 0;
01686       // for balanced training, set each class to not done.
01687       for (uint k = 0; k < bal_indices.size(); ++k) {
01688         epoch_done_counters[k] = bal_indices[k].size();
01689         if (bal_indices[k].size() > maxsize)
01690           maxsize = bal_indices[k].size();
01691       }
01692       if (epoch_mode == 1) // for ETA estimation only, estimate epoch size
01693         epoch_sz = maxsize * bal_indices.size();
01694     }
01695     // if we have prior information about each sample energy and classification
01696     // let's use it to initialize the picking probabilities.
01697     if (weigh_samples)
01698       this->normalize_all_probas();
01699   }
01700 
01701   template <typename Tnet, typename Tdata, typename Tlabel>
01702   void class_datasource<Tnet,Tdata,Tlabel>::normalize_all_probas() {
01703     if (weigh_samples) {
01704       if (perclass_norm && balance) {
01705         for (uint i = 0; i < bal_indices.size(); ++i)
01706           normalize_probas(i);
01707       } else
01708         normalize_probas();
01709     }
01710   }
01711 
01712   template <typename Tnet, typename Tdata, typename Tlabel>
01713   void class_datasource<Tnet,Tdata,Tlabel>::normalize_probas(int classid) {
01714     vector<intg> *cindices = NULL;
01715     if (perclass_norm && balance) { // use only class_it class samples
01716       if (classid < 0)
01717         eblerror("class id cannot be negative");
01718       uint class_it = (uint) classid;
01719       cindices = &(bal_indices[class_it]);
01720       cout << _name << ": Normalizing probabilities for class" << class_it;
01721       datasource<Tnet,Tdata>::normalize_probas(cindices);
01722     } else // use all samples
01723       datasource<Tnet,Tdata>::normalize_probas();
01724   }
01725 
01727   // accessors
01728 
01729   template <typename Tnet, typename Tdata, typename Tlabel>
01730   intg class_datasource<Tnet, Tdata, Tlabel>::get_nclasses() {
01731     if (bexclusion)
01732       return included;
01733     return nclasses;
01734   }
01735 
01736   template <typename Tnet, typename Tdata, typename Tlabel>
01737   int class_datasource<Tnet, Tdata, Tlabel>::get_class_id(const char *name) {
01738     int id_ = -1;
01739     vector<string*>::iterator i = clblstr.begin();
01740     for (int j = 0; i != clblstr.end(); ++i, ++j) {
01741       if (!strcmp(name, (*i)->c_str()))
01742         id_ = j;
01743     }
01744     return id_;
01745   }
01746 
01747   template <typename Tnet, typename Tdata, typename Tlabel>
01748   std::string &class_datasource<Tnet, Tdata, Tlabel>::get_class_name(int id) {
01749     if (id >= (int) clblstr.size())
01750       eblerror("requesting label string at index " << id
01751                << " but string vector has only " << clblstr.size()
01752                << " elements.");
01753     string *s = clblstr[id];
01754     if (!s)
01755       eblerror("empty label string");
01756     return *s;
01757   }
01758 
01759   template <typename Tnet, typename Tdata, typename Tlabel>
01760   std::vector<std::string*>& class_datasource<Tnet, Tdata, Tlabel>::
01761   get_label_strings() {
01762     return clblstr;
01763   }
01764 
01765   template <typename Tnet, typename Tdata, typename Tlabel>
01766   intg class_datasource<Tnet,Tdata,Tlabel>::get_lowest_common_size() {
01767     intg min_nonzero = (std::numeric_limits<intg>::max)();
01768     for (vector<intg>::iterator i = counts.begin(); i != counts.end(); ++i) {
01769       if ((*i < min_nonzero) && (*i != 0))
01770         min_nonzero = *i;
01771     }
01772     if (min_nonzero == (std::numeric_limits<intg>::max)())
01773       eblerror("empty dataset");
01774     return min_nonzero * nclasses;
01775   }
01776 
01777   template <typename Tnet, typename Tdata, typename Tlabel>
01778   void class_datasource<Tnet,Tdata,Tlabel>::save_pickings(const char *name_) {
01779     // non-class plotting
01780     datasource<Tnet,Tdata>::save_pickings(name_);
01781     string name = "pickings";
01782     if (name_)
01783       name = name_;
01784     // plot by class
01785     write_classed_pickings(pick_count, correct, name);
01786     write_classed_pickings(energies, correct, name, "_energies");
01787     idx<double> e = idx_copy(energies);
01788     idx<ubyte> c = idx_copy(correct);
01789     idx_sortup(e, c);
01790     write_classed_pickings(e, c, name, "_sorted_energies");
01791     idx<double> p = idx_copy(probas);
01792     c = idx_copy(correct);
01793     idx_sortup(p, c);
01794     write_classed_pickings(p, c, name, "_sorted_probas");
01795     p = idx_copy(probas);
01796     e = idx_copy(energies);
01797     c = idx_copy(correct);
01798     idx_sortup(e, c, p);
01799     write_classed_pickings(p, c, name, "_probas_sorted_by_energy", true,
01800                            "Picking probability");
01801     write_classed_pickings(p, c, name, "_probas_sorted_by_energy_wrong_only",
01802                            false, "Picking probability");
01803     write_classed_pickings(e, c, name, "_energies_sorted_by_energy_wrong_only",
01804                            false, "Energy");
01805   }
01806 
01807   template <typename Tnet, typename Tdata, typename Tlabel>
01808   template <typename T>
01809   void class_datasource<Tnet,Tdata,Tlabel>::
01810   write_classed_pickings(idx<T> &m, idx<ubyte> &c, string &name_,
01811                          const char *name2_, bool plot_correct,
01812                          const char *ylabel) {
01813     string name = name_;
01814     if (name2_)
01815       name += name2_;
01816     name += "_classed";
01817     // sorted classed plot file
01818     if (labels.order() == 1) { // single label value
01819       string fname = name;
01820       fname += ".plot";
01821       ofstream fp(fname.c_str());
01822       if (!fp) {
01823         cerr << "failed to open " << fname << endl;
01824         eblerror("failed to open file for writing");
01825       }
01826       eblerror("not implemented");
01827       // typename idx<T>::dimension_iterator i = m.dim_begin(0);
01828       // typename idx<Tlabel>::dimension_iterator l = labels.dim_begin(0);
01829       // typename idx<ubyte>::dimension_iterator ic = c.dim_begin(0);
01830       // uint j = 0;
01831       // for ( ; i.notdone(); i++, l++, ic++) {
01832       //        if (!plot_correct && ic->get() == 1)
01833       //          continue ; // don't plot correct ones
01834       //        fp << j++;
01835       //        for (Tlabel k = 0; k < (Tlabel) nclasses; ++k) {
01836       //          if (k == l.get()) {
01837       //            if (ic->get() == 0) { // wrong
01838       //              fp << "\t" << i->get();
01839       //              if (plot_correct)
01840       //                fp << "\t?";
01841       //            } else if (plot_correct) // correct
01842       //              fp << "\t?\t" << i->get();
01843       //          } else {
01844       //            fp << "\t?";
01845       //            if (plot_correct)
01846       //              fp << "\t?";
01847       //          }
01848       //        }
01849       //        fp << endl;
01850       // }
01851       fp.close();
01852       cout << _name << ": Wrote picking statistics in " << fname << endl;
01853       // p file
01854       string fname2 = name;
01855       fname2 += ".p";
01856       ofstream fp2(fname2.c_str());
01857       if (!fp2) {
01858         cerr << "failed to open " << fname2 << endl;
01859         eblerror("failed to open file for writing");
01860       }
01861       fp2 << "set title \"" << name << "\"; set ylabel \"" << ylabel
01862           << "\"; plot \""
01863           << fname << "\" using 1:2 title \"class 0 wrong\" with impulse";
01864       if (plot_correct)
01865         fp2 << ", \""
01866             << fname << "\" using 1:3 title \"class 0 correct\" with impulse";
01867       for (uint k = 1; k < nclasses; ++k) {
01868         fp2 << ", \"" << fname << "\" using 1:" << k * (plot_correct?2:1) + 2
01869             << " title \"class " << k << " wrong\" with impulse";
01870         if (plot_correct)
01871           fp2 << ", \"" << fname << "\" using 1:" << k * 2 + 3
01872               << " title \"class " << k << " correct\" with impulse";
01873       }
01874       fp << endl;
01875       fp2.close();
01876       cout << _name << ": Wrote gnuplot file in " << fname2 << endl;
01877     }
01878   }
01879 
01881   // state saving
01882 
01883   template <typename Tnet, typename Tdata, typename Tlabel>
01884   void class_datasource<Tnet,Tdata,Tlabel>::save_state() {
01885     state_saved = true;
01886     count_pickings_save = count_pickings;
01887     it_saved = it; // save main iterator
01888     it_test_saved = it_test;
01889     it_train_saved = it_train;
01890     this->epoch_cnt_saved = epoch_cnt;
01891     this->epoch_pick_cnt_saved = epoch_pick_cnt;
01892     this->epoch_done_counters_saved = epoch_done_counters;
01893     if (!balance) // save (unbalanced) iterators
01894       datasource<Tnet,Tdata>::save_state();
01895     else { // save balanced iterators
01896       bal_it_saved.clear();
01897       bal_indices_saved.clear();
01898       for (uint k = 0; k < bal_it.size(); ++k) {
01899         bal_it_saved.push_back(bal_it[k]);
01900         vector<intg> indices;
01901         for (uint l = 0; l < bal_indices[k].size(); ++l)
01902           indices.push_back(bal_indices[k][l]);
01903         bal_indices_saved.push_back(indices);
01904       }
01905       class_it_saved = class_it;
01906       class_it_it_saved = class_it_it;
01907     }
01908   }
01909 
01910   template <typename Tnet, typename Tdata, typename Tlabel>
01911   void class_datasource<Tnet,Tdata,Tlabel>::restore_state() {
01912     if (!state_saved)
01913       eblerror("state not saved, call save_state() before restore_state()");
01914     count_pickings = count_pickings_save;
01915     it = it_saved; // restore main iterator
01916     it_test = it_test_saved;
01917     it_train = it_train_saved;
01918     epoch_cnt = this->epoch_cnt_saved;
01919     epoch_pick_cnt = this->epoch_pick_cnt_saved;
01920     epoch_done_counters = this->epoch_done_counters_saved;
01921     if (!balance) // restore unbalanced
01922       datasource<Tnet,Tdata>::restore_state();
01923     else { // restore balanced iterators
01924       for (uint k = 0; k < bal_it.size(); ++k) {
01925         bal_it[k] = bal_it_saved[k];
01926         for (uint l = 0; l < bal_indices[k].size(); ++l)
01927           bal_indices[k][l] = bal_indices_saved[k][l];
01928       }
01929       class_it = class_it_saved;
01930       class_it_it = class_it_it_saved;
01931     }
01932   }
01933 
01935   // pretty methods
01936 
01937   template <typename Tnet, typename Tdata, typename Tlabel>
01938   void class_datasource<Tnet,Tdata,Tlabel>::pretty_progress(bool newline) {
01939     if (this->is_test())
01940       datasource<Tnet,Tdata>::pretty_progress(newline);
01941     else {
01942       if (epoch_show > 0 && epoch_pick_cnt % epoch_show == 0 &&
01943           epoch_show_printed != (intg) epoch_pick_cnt) {
01944         datasource<Tnet,Tdata>::pretty_progress(false);
01945         if (balance && epoch_done_counters.size() < 50) {
01946           cout << ", remaining:";
01947           for (uint i = 0; i < epoch_done_counters.size(); ++i) {
01948             cout << " " << i << ": " << epoch_done_counters[i];
01949           }
01950         }
01951         if (newline)
01952           cout << endl;
01953       }
01954     }
01955   }
01956 
01957   template <typename Tnet, typename Tdata, typename Tlabel>
01958   void class_datasource<Tnet, Tdata, Tlabel>::pretty() {
01959     cout << _name << ": classification dataset \"" << _name
01960          << "\" contains "
01961          << data.dim(0) << " samples of dimension " << sampledims
01962          << " and defines an epoch as " << epoch_sz << " samples." << endl;
01963     cout << this->_name << ": It has " << nclasses << " classes: ";
01964     uint i;
01965     for (i = 0; i < this->counts.size(); ++i)
01966       if (!bexclusion || !excluded[i])
01967         cout << this->counts[i] << " \"" << *(clblstr[i]) << "\" ";
01968     cout << endl;
01969     pretty_scales();
01970   }
01971 
01972   template <typename Tnet, typename Tdata, typename Tlabel>
01973   void class_datasource<Tnet, Tdata, Tlabel>::pretty_scales() {
01974     labeled_datasource<Tnet,Tdata,Tlabel>::pretty_scales();
01975     if (this->scales_loaded) {
01976       intg maxscale = idx_max(this->scales);
01977       for (uint c = 0; c < this->counts.size(); ++c)
01978         if (!bexclusion || !excluded[c]) {
01979           vector<intg> tally(maxscale + 1, 0);
01980           idx_bloop2(scale, this->scales, intg, label, labels, Tlabel) {
01981             if (label.get() != (Tlabel) c) continue ;
01982             intg s = scale.get();
01983             if (s < 0) eblerror("unexpected negative value");
01984             tally[s] = tally[s] + 1;
01985           }
01986           intg nscales = 0;
01987           for (intg i = 0; i < (intg) tally.size(); ++i)
01988             if (tally[i] > 0) nscales++;
01989           cout << _name << ": " << *(clblstr[c]) << " has " 
01990                << nscales << " scales";
01991           for (intg i = 0; i < (intg) tally.size(); ++i)
01992             cout << ", " << i << ": " << tally[i];
01993           cout << endl;
01994         }
01995     }
01996   }
01997 
01998   // protected methods /////////////////////////////////////////////////////////
01999 
02000   template <typename Tnet, typename Tdata, typename Tlabel>
02001   bool class_datasource<Tnet,Tdata,Tlabel>::pick_current() {
02002     if (bexclusion && excluded[(int) get_label()])
02003       return false;
02004     return datasource<Tnet,Tdata>::pick_current();
02005   }
02006 
02009 
02010   template <typename Tlabel>
02011   class_node<Tlabel>::class_node(Tlabel id, string &name_)
02012     : _label(id), _name(name_), parent(NULL), children(),
02013       it_children(children.begin()),
02014       bempty(true), iempty(true), _depth(0), it_samples(samples.begin()) {
02015     //EDEBUG("creating node " << this << " label: " << id << " name: "
02016     //<< _name);
02017   }
02018 
02019   template <typename Tlabel>
02020   class_node<Tlabel>::~class_node() {
02021   }
02022 
02023   template <typename Tlabel>
02024   bool class_node<Tlabel>::empty() {
02025     return bempty;
02026   }
02027 
02028   template <typename Tlabel>
02029   bool class_node<Tlabel>::internally_empty() {
02030     return iempty;
02031   }
02032 
02033   template <typename Tlabel>
02034   intg class_node<Tlabel>::next() {
02035     if (bempty) eblerror("cannot call next() on empty node");
02036     if (it_children < children.begin()) it_children = children.begin();
02037     if (it_samples < samples.begin()) it_samples = samples.begin();
02038     intg id = -1;
02039     if (it_children == children.end()) { // we reached the end of children
02040       // roll back children iterator
02041       it_children = children.begin();
02042       if (samples.size() > 0) { // return internal samples if present
02043         id = *it_samples;
02044         it_samples++;
02045         if (it_samples == samples.end()) // roll back samples iterator
02046           it_samples = samples.begin();
02047         return id;
02048       }
02049     }
02050     // return a sample from children
02051     if ((*it_children)->empty()) {
02052       it_children++;
02053       return next(); // skip current empty child
02054     } else // this child is not empty
02055       return (*it_children)->next();
02056   }
02057 
02058   template <typename Tlabel>
02059   void class_node<Tlabel>::add_child(class_node *child) {
02060     // make sure child is not added twice
02061     for (typename vector<class_node<Tlabel>*>::iterator i = children.begin();
02062          i != children.end(); ++i)
02063       if (*i == child)
02064         eblerror("trying to push same child node twice");
02065     // push
02066     children.push_back(child);
02067     // set child's parent to this node
02068     child->set_parent(this);
02069     // propagate up if child is non-empty
02070     if (!child->empty())
02071       set_non_empty();
02072   }
02073 
02074   template <typename Tlabel>
02075   void class_node<Tlabel>::add_sample(intg index) {
02076     samples.push_back(index);
02077     iempty = false;
02078     // propagate information back to parents
02079     set_non_empty();
02080   }
02081 
02082   template <typename Tlabel>
02083   intg class_node<Tlabel>::nsamples() {
02084     // TODO: we might want to use idx here to handle large sets (intg)
02085     return samples.size();
02086   }
02087 
02088   template <typename Tlabel>
02089   Tlabel class_node<Tlabel>::label() {
02090     return _label;
02091   }
02092 
02093   template <typename Tlabel>
02094   Tlabel class_node<Tlabel>::label(uint depth) {
02095     // current depth is lower or equal to target depth, return current label
02096     if (_depth <= depth || !parent)
02097       return _label;
02098     // current depth is higher than target depth, call parent's label
02099     return parent->label(depth);
02100   }
02101 
02102   template <typename Tlabel>
02103   uint class_node<Tlabel>::depth() {
02104     return _depth;
02105   }
02106 
02107   template <typename Tlabel>
02108   uint class_node<Tlabel>::set_depth(uint d) {
02109     _depth = d;
02110     uint dmax = d;
02111     for (typename vector<class_node<Tlabel>*>::iterator i = children.begin();
02112          i != children.end(); ++i)
02113       dmax = std::max(dmax, (*i)->set_depth(d + 1));
02114     return dmax;
02115   }
02116 
02117   template <typename Tlabel>
02118   string &class_node<Tlabel>::name() {
02119     return _name;
02120   }
02121 
02122   template <typename Tlabel>
02123   class_node<Tlabel>* class_node<Tlabel>::get_parent() {
02124     return parent;
02125   }
02126 
02127   template <typename Tlabel>
02128   bool class_node<Tlabel>::is_parent(Tlabel lab) {
02129     if (lab == _label)
02130       return true;
02131     if (parent)
02132       return parent->is_parent(lab);
02133     return false;
02134   }
02135 
02136   // protected methods /////////////////////////////////////////////////////////
02137 
02138   template <typename Tlabel>
02139   void class_node<Tlabel>::set_non_empty() {
02140     bempty = false;
02141     // propagate information back to parents
02142     if (parent)
02143       parent->set_non_empty();
02144   }
02145 
02146   template <typename Tlabel>
02147   void class_node<Tlabel>::set_parent(class_node *p) {
02148     parent = p;
02149   }
02150 
02152   // hierarchy_datasource
02153 
02154   template <typename Tnet, typename Tdata, typename Tlabel>
02155   hierarchy_datasource<Tnet, Tdata, Tlabel>::hierarchy_datasource()
02156     : class_datasource<Tnet,Tdata,Tlabel>() {
02157   }
02158 
02159   template <typename Tnet, typename Tdata, typename Tlabel>
02160   hierarchy_datasource<Tnet, Tdata, Tlabel>::
02161   hierarchy_datasource(midx<Tdata> &data_, idx<Tlabel> &labels_,
02162                        idx<Tlabel> *parents_,
02163                        vector<string*> *lblstr_, const char *name_) {
02164     class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, name_, lblstr_);
02165     this->init_parents(parents_);
02166     this->init_epoch();
02167     this->pretty(); // print info about dataset
02168   }
02169 
02170   template <typename Tnet, typename Tdata, typename Tlabel>
02171   hierarchy_datasource<Tnet, Tdata, Tlabel>::
02172   hierarchy_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
02173                        idx<Tlabel> *parents_,
02174                        vector<string*> *lblstr_, const char *name_) {
02175     class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, name_, lblstr_);
02176     this->init_parents(parents_);
02177     this->init_epoch();
02178     this->pretty(); // print info about dataset
02179   }
02180 
02181   template <typename Tnet, typename Tdata, typename Tlabel>
02182   hierarchy_datasource<Tnet, Tdata, Tlabel>::
02183   hierarchy_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
02184                        idx<Tlabel> *parents_, idx<ubyte> *classes,
02185                        const char *name_) {
02186     if (classes)
02187       this->init_strings(*classes);
02188     class_datasource<Tnet,Tdata,Tlabel>::init(data_, labels_, this->lblstr,
02189                                               name_);
02190     this->init_parents(parents_);
02191     this->init_epoch();
02192     this->pretty(); // print info about dataset
02193   }
02194 
02195   template <typename Tnet, typename Tdata, typename Tlabel>
02196   hierarchy_datasource<Tnet, Tdata, Tlabel>::
02197   hierarchy_datasource(const char *data_name, const char *labels_name,
02198                        const char *parents_name, const char *jitters_name,
02199                        const char *scales_name, const char *classes_name, 
02200                        const char *name_, uint max_size) {
02201     class_datasource<Tnet,Tdata,Tlabel>::
02202       init(data_name, labels_name, jitters_name, scales_name, classes_name,
02203            name_, max_size);
02204     // load parent
02205     idx<Tlabel> parents_;
02206     if (parents_name) {
02207       try {
02208         parents_ = load_matrix<Tlabel>(parents_name);
02209       } eblcatcherror();
02210     }
02211     // inits
02212     this->init_parents(parents_name ? &parents_ : NULL);
02213     this->init_epoch();
02214     this->pretty(); // print info about dataset
02215   }
02216 
02217   // template <typename Tnet, typename Tdata, typename Tlabel>
02218   // hierarchy_datasource<Tnet, Tdata, Tlabel>::
02219   // hierarchy_datasource(const hierarchy_datasource<Tnet, Tdata, Tlabel> &ds)
02220   //   : class_datasource<Tnet,Tdata>((const class_datasource<Tnet,Tdata>&) ds) {
02221   // }
02222 
02223   template <typename Tnet, typename Tdata, typename Tlabel>
02224   hierarchy_datasource<Tnet, Tdata, Tlabel>::~hierarchy_datasource() {
02225     // delete all nodes
02226     for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin();
02227          i != all_nodes.end(); ++i) {
02228       class_node<Tlabel> *n = *i;
02229       if (n) delete n;
02230     }
02231     // delete depth vectors
02232     for (typename vector<vector<class_node<Tlabel>*>*>::iterator i =
02233            all_depths.begin(); i != all_depths.end(); ++i) {
02234       vector<class_node<Tlabel>*> *n = *i;
02235       if (n) delete n;
02236     }
02237     if (parents)
02238       delete parents;
02239   }
02240 
02242   // init methods
02243 
02244   template <typename Tnet, typename Tdata, typename Tlabel>
02245   void hierarchy_datasource<Tnet, Tdata, Tlabel>::
02246   init_parents(idx<Tlabel> *parents_) {
02247     if (parents_)
02248       parents = new idx<Tlabel>(*parents_);
02249     else {
02250       eblwarn("no parents hierarchy specified, initializing with a flat "
02251               << "hierarchy");
02252       parents = new idx<Tlabel>(nclasses, 2);
02253       for (Tlabel i = 0; i < (Tlabel) nclasses; ++i) {
02254         parents->set(i, i, 0);
02255         parents->set(-1, i, 1); // node i has no parent
02256       }
02257     }
02258     // check that parent id do not exceed number of classes
02259     if (parents) {
02260       if (idx_max(*parents) > nclasses)
02261         eblerror("maximum parent id (" << idx_max(*parents)
02262                  << ") cannot exceed number of classes (" << nclasses << ")");
02263       if (parents->dim(1) != 2)
02264         eblerror("expected dim 1 to be size 2 (child/parent)");
02265     }
02266     if (!lblstr)
02267       eblerror("expected class strings to be defined");
02268     // create hierarchy tree for efficient balanced samples ordering
02269     all_nodes.resize(nclasses, NULL);
02270     // add a new node for each class
02271     Tlabel l = 0;
02272     for (vector<string*>::iterator i = lblstr->begin(); i != lblstr->end();++i){
02273       class_node<Tlabel> *node = all_nodes[l];
02274       if (!node) {
02275         node = new class_node<Tlabel>(l, *((*lblstr)[l]));
02276         all_nodes[l] = node;
02277       }
02278       l++;
02279     }
02280     // add samples to each node
02281     intg i = 0;
02282     idx_bloop1(lab, labels, Tlabel) {
02283       Tlabel l = lab.get();
02284       class_node<Tlabel> *node = all_nodes[l];
02285       if (!node) eblerror("node " << l << " is not defined");
02286       // add sample to node
02287       node->add_sample(i);
02288       i++;
02289     }
02290     // associate each class to its parent
02291     if (parents) {
02292       idx_bloop1(par, *parents, Tlabel) {
02293         class_node<Tlabel> *child = all_nodes[par.get(0)];
02294         if (!child) eblerror("no node with id " << par.get(0));
02295         if (par.get(1) >= 0) {
02296           class_node<Tlabel> *parent = all_nodes[par.get(1)];
02297           parent->add_child(child);
02298         //EDEBUG("adding "<< child->name() << " as child of " << parent->name());
02299         }
02300       }
02301     }
02302     // assign depth to all nodes, starting from all orphan nodes
02303     uint maxdepth = 0;
02304     for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin();
02305          i != all_nodes.end(); ++i) {
02306       class_node<Tlabel> *n = *i;
02307       if (n && !n->get_parent())
02308         maxdepth = std::max(maxdepth, n->set_depth(0));
02309     }
02310     // remember nodes arranged by depth
02311     all_depths.resize(maxdepth + 1, NULL);
02312     for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin();
02313          i != all_nodes.end(); ++i) {
02314       class_node<Tlabel> *n = *i;
02315       if (n) {
02316         vector<class_node<Tlabel>*> *d = all_depths[n->depth()];
02317         if (!d) {
02318           d = new vector<class_node<Tlabel>*>;
02319           all_depths[n->depth()] = d;
02320         }
02321         d->push_back(n);
02322         //EDEBUG("assigning " << n->name() << " to depth " << n->depth());
02323       }
02324     }
02325     // remember nodes arranged by depth and keep nodes with lower depth that
02326     // have samples internally. this garantees that all samples are used even
02327     // at lower depths than current one.
02328     complete_depths.resize(maxdepth + 1, NULL);
02329     for (typename vector<class_node<Tlabel>*>::iterator i = all_nodes.begin();
02330          i != all_nodes.end(); ++i) {
02331       class_node<Tlabel> *n = *i;
02332       if (n) {
02333         vector<class_node<Tlabel>*> *d = complete_depths[n->depth()];
02334         if (!d) {
02335           d = new vector<class_node<Tlabel>*>;
02336           complete_depths[n->depth()] = d;
02337         }
02338         d->push_back(n);
02339         // add node to all higher depths if it has internal samples
02340         if (!n->internally_empty()) {
02341           for (uint i = n->depth() + 1; i < complete_depths.size(); ++i) {
02342             vector<class_node<Tlabel>*> *dd = complete_depths[i];
02343             if (!dd) {
02344               dd = new vector<class_node<Tlabel>*>;
02345               complete_depths[i] = dd;
02346             }
02347             dd->push_back(n);
02348           }
02349         }
02350         //EDEBUG("assigning " << n->name() << " to depth " << n->depth());
02351       }
02352     }
02353     // order all by depth
02354     for (typename vector<vector<class_node<Tlabel>*>*>::iterator i =
02355            all_depths.begin(); i != all_depths.end(); ++i) {
02356       for (typename vector<class_node<Tlabel>*>::iterator j = (*i)->begin();
02357            j != (*i)->end(); ++j) {
02358         all_nodes_by_depth.push_back(*j);
02359       }
02360     }
02361     // initialize depths iterators
02362     it_depths.resize(complete_depths.size(), 0);
02363     // allocate depth labels
02364     depth_labels = idx<Tlabel>(labels.get_idxdim());
02365     set_current_depth(0);
02366     it = 0;
02367     depth_balance = false;
02368   }
02369 
02370   template <typename Tnet, typename Tdata, typename Tlabel>
02371   void hierarchy_datasource<Tnet, Tdata, Tlabel>::init_class_labels() {
02372     if (olabels.get_idxdim() != labels.get_idxdim())
02373       olabels = idx<Tlabel>(labels.get_idxdim());
02374     idx_copy(labels, olabels); // keep original labels
02375     clblstr.clear();
02376     for (uint i = 0; i < lblstr->size(); ++i)
02377       clblstr.push_back((*lblstr)[i]);
02378   }
02379 
02381   // data access
02382 
02383   template <typename Tnet, typename Tdata, typename Tlabel>
02384   Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_parent() {
02385     // idx<Tlabel> lab = labels[it];
02386     // if (lab.order() != 1 && lab.dim(0) != 1)
02387     //   eblerror("expected single-element labels");
02388     // return parents.get(lab.get(0));
02389     return -1;
02390   }
02391 
02392   template <typename Tnet, typename Tdata, typename Tlabel>
02393   bool hierarchy_datasource<Tnet,Tdata,Tlabel>::
02394   is_parent_of(Tlabel l1, Tlabel l2) {
02395     class_node<Tlabel> *n = all_nodes[l2];
02396     if (!n) eblerror("node " << l2  << " not found");
02397     return n->is_parent(l1);
02398   }
02399 
02400   template <typename Tnet, typename Tdata, typename Tlabel>
02401   vector<class_node<Tlabel>*>& hierarchy_datasource<Tnet,Tdata,Tlabel>::
02402   get_nodes() {
02403     return all_nodes;
02404   }
02405 
02406   template <typename Tnet, typename Tdata, typename Tlabel>
02407   vector<class_node<Tlabel>*>&
02408   hierarchy_datasource<Tnet,Tdata,Tlabel>::get_nodes_by_depth() {
02409     return all_nodes_by_depth;
02410   }
02411 
02412   template <typename Tnet, typename Tdata, typename Tlabel>
02413   void hierarchy_datasource<Tnet,Tdata,Tlabel>::
02414   fprop_label(fstate_idx<Tlabel> &label) {
02415     label.x.sset(get_label());
02416   }
02417 
02418   template <typename Tnet, typename Tdata, typename Tlabel>
02419   void hierarchy_datasource<Tnet,Tdata,Tlabel>::
02420   fprop_label_net(fstate_idx<Tnet> &label) {
02421     label.x.sset((Tnet) get_label());
02422   }
02423 
02424   template <typename Tnet, typename Tdata, typename Tlabel>
02425   void hierarchy_datasource<Tnet,Tdata,Tlabel>::
02426   fprop_label_net(bbstate_idx<Tnet> &label) {
02427     label.x.sset((Tnet) get_label());
02428   }
02429 
02430   template <typename Tnet, typename Tdata, typename Tlabel>
02431   Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_label() {
02432     // select label based on current depth and current iterator
02433     Tlabel l = labels[it].gget();
02434     return all_nodes[l]->label(current_depth);
02435   }
02436 
02437   template <typename Tnet, typename Tdata, typename Tlabel>
02438   Tlabel hierarchy_datasource<Tnet,Tdata,Tlabel>::get_label(uint d, intg index){
02439     // select label based on current depth and current iterator
02440     Tlabel l;
02441     if (index < 0)
02442       l = labels[it].gget();
02443     else
02444       l = labels[index].gget();
02445     return all_nodes[l]->label(d);
02446   }
02447 
02448   template <typename Tnet, typename Tdata, typename Tlabel>
02449   idx<Tlabel>& hierarchy_datasource<Tnet,Tdata,Tlabel>::get_depth_labels() {
02450     return depth_labels;
02451   }
02452 
02453   template <typename Tnet, typename Tdata, typename Tlabel>
02454   uint hierarchy_datasource<Tnet,Tdata,Tlabel>::
02455   get_nbrothers(class_node<Tlabel> &n) {
02456     if (n.parent)
02457       return n.parent->children.size();
02458     vector<class_node<Tlabel>*>* d0 = all_depths[0];
02459     if (d0)
02460       return d0->size();
02461     return 0;
02462   }
02463 
02465   // iterating
02466 
02467   template <typename Tnet, typename Tdata, typename Tlabel>
02468   void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_depth_balanced(bool bal) {
02469     depth_balance = bal;
02470     if (!depth_balance) // unbalanced
02471       cout << _name << ": Setting training as depth-unbalanced." << endl;
02472     else // balanced
02473       cout << _name << ": Setting training as depth-balanced." << endl;
02474   }
02475 
02476   template <typename Tnet, typename Tdata, typename Tlabel>
02477   void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_current_depth(uint depth) {
02478     if (depth >= all_depths.size())
02479       eblerror("cannot set current depth to " << depth << " because it is "
02480                << "more than maximum depth " << all_depths.size());
02481     current_depth = depth;
02482     // fill depth_labels matrix with labels of all samples given current depth
02483     idx_bloop2(l, labels, Tlabel, dl, depth_labels, Tlabel) {
02484       dl.set(all_nodes[l.get()]->label(current_depth));
02485     }
02486   }
02487 
02488   template <typename Tnet, typename Tdata, typename Tlabel>
02489   uint hierarchy_datasource<Tnet,Tdata,Tlabel>::get_current_depth() {
02490     return current_depth;
02491   }
02492 
02493   template <typename Tnet, typename Tdata, typename Tlabel>
02494   void hierarchy_datasource<Tnet,Tdata,Tlabel>::incr_current_depth() {
02495     if (current_depth + 1 >= all_depths.size())
02496       cout << "warning: cannot increment current depth beyond maximum ("
02497            << all_depths.size() << ")" << endl;
02498     else
02499       current_depth++;
02500     set_current_depth(current_depth);
02501   }
02502 
02503   template <typename Tnet, typename Tdata, typename Tlabel>
02504   bool hierarchy_datasource<Tnet,Tdata,Tlabel>::next_train() {
02505     // check that this datasource is allowed to call this method
02506     if (test_set)
02507       eblerror("forbidden call of next_train() on testing sets");
02508     if (!depth_balance) // do not balance by depth
02509       return class_datasource<Tnet,Tdata,Tlabel>::next_train();
02510     // balanced training
02511     vector<class_node<Tlabel>*> *nodes = complete_depths[current_depth];
02512     uint itd = it_depths[current_depth];
02513     // set it for further fprop calls
02514     class_node<Tlabel> *node = (*nodes)[itd];
02515     it = node->next();
02516     // increment depth iterator
02517     itd++;
02518     it_depths[current_depth] = itd;
02519     if (itd >= nodes->size())
02520       it_depths[current_depth] = 0;
02521     // increment epoch counters
02522     epoch_cnt++;
02523     epoch_pick_cnt++;
02524 #ifdef __DEBUG__
02525     cout << "Picking sample " << it << " (label: " << (int)get_label();
02526     if (lblstr)
02527       cout << ", name: " << *((*lblstr)[(int)get_label()]);
02528     cout << ", pickings: " << pick_count.get(it) << ", energy: "
02529          << energies.get(it) << ", correct: " << (int) correct.get(it);
02530     if (weigh_samples) cout << ", proba: " << probas.get(it);
02531     cout << ")" << endl;
02532 #endif
02533     this->pretty_progress();
02534     return true;
02535   }
02536 
02537 //   template <typename Tnet, typename Tdata, typename Tlabel>
02538 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::set_balanced(bool bal) {
02539 //     balance = bal;
02540 //     if (!balance) // unbalanced
02541 //       cout << _name << ": Setting training as unbalanced (not taking class "
02542 //         << "distributions into account)." << endl;
02543 //     else { // balanced
02544 //       cout << _name << ": Setting training as balanced (taking class "
02545 //         << "distributions into account)." << endl;
02546 //       // compute vector of sample indices for each class
02547 //       bal_indices.clear();
02548 //       bal_it.clear();
02549 //       epoch_done_counters.clear();
02550 //       class_it = 0;
02551 //       for (intg i = 0; i < nclasses; ++i) {
02552 //      vector<intg> indices;
02553 //      bal_indices.push_back(indices);
02554 //      bal_it.push_back(0); // init iterators
02555 //       }
02556 //       // distribute sample indices into each vector based on label
02557 //       for (uint i = 0; i < this->size(); ++i)
02558 //      bal_indices[(intg) (labels.gget(i))].push_back(i);
02559 //       for (uint i = 0; i < bal_indices.size(); ++i) {
02560 //      // shuffle
02561 //      random_shuffle(bal_indices[i].begin(), bal_indices[i].end());
02562 //      // init epoch counters
02563 //      epoch_done_counters.push_back(bal_indices[i].size());
02564 //       }
02565 //     }
02566 //   }
02567 
02568 //   template <typename Tnet, typename Tdata, typename Tlabel>
02569 //   bool hierarchy_datasource<Tnet,Tdata,Tlabel>::epoch_done() {
02570 //     switch (epoch_mode) {
02571 //     case 0: // fixed number of samples
02572 //       if (epoch_cnt >= epoch_sz)
02573 //      return true;
02574 //       break ;
02575 //     case 1: // see all samples at least once
02576 //       if (balance) {
02577 //      // check that all classes are done
02578 //      for (uint i = 0; i < epoch_done_counters.size(); ++i) {
02579 //        if (epoch_done_counters[i] > 0)
02580 //          return false;
02581 //      }
02582 //      return true; // all classes are done
02583 //       } else { // do not balance, use epoch_sz
02584 //      if (epoch_cnt >= epoch_sz)
02585 //        return true;
02586 //       }
02587 //       break ;
02588 //     default: eblerror("unknown epoch_mode");
02589 //     }
02590 //     return false;
02591 //   }
02592 
02593 //   template <typename Tnet, typename Tdata, typename Tlabel>
02594 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::init_epoch() {
02595 //     epoch_cnt = 0;
02596 //     epoch_pick_cnt = 0;
02597 //     epoch_timer.restart();
02598 //     epoch_show_printed = -1; // last epoch count we have printed
02599 //     if (balance) {
02600 //       uint maxsize = 0;
02601 //       // for balanced training, set each class to not done.
02602 //       for (uint k = 0; k < bal_indices.size(); ++k) {
02603 //      epoch_done_counters[k] = bal_indices[k].size();
02604 //      if (bal_indices[k].size() > maxsize)
02605 //        maxsize = bal_indices[k].size();
02606 //       }
02607 //       if (epoch_mode == 1) // for ETA estimation only, estimate epoch size
02608 //      epoch_sz = maxsize * bal_indices.size();
02609 //     }
02610 //     // if we have prior information about each sample energy and classification
02611 //     // let's use it to initialize the picking probabilities.
02612 //     this->normalize_all_probas();
02613 //   }
02614 
02615 //   template <typename Tnet, typename Tdata, typename Tlabel>
02616 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::normalize_all_probas() {
02617 //     if (weigh_samples) {
02618 //       if (perclass_norm && balance) {
02619 //      for (uint i = 0; i < bal_indices.size(); ++i)
02620 //        normalize_probas(i);
02621 //       } else
02622 //      normalize_probas();
02623 //     }
02624 //   }
02625 
02626 //   template <typename Tnet, typename Tdata, typename Tlabel>
02627 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::normalize_probas(int classid) {
02628 //     vector<intg> *cindices = NULL;
02629 //     if (perclass_norm && balance) { // use only class_it class samples
02630 //       if (classid < 0)
02631 //      eblerror("class id cannot be negative");
02632 //       uint class_it = (uint) classid;
02633 //       cindices = &(bal_indices[class_it]);
02634 //       cout << _name << ": Normalizing probabilities for class" << class_it;
02635 //       datasource<Tnet,Tdata>::normalize_probas(cindices);
02636 //     } else // use all samples
02637 //       datasource<Tnet,Tdata>::normalize_probas();
02638 //   }
02639 
02640 //   //////////////////////////////////////////////////////////////////////////////
02641 //   // accessors
02642 
02643 //   template <typename Tnet, typename Tdata, typename Tlabel>
02644 //   intg hierarchy_datasource<Tnet, Tdata, Tlabel>::get_nclasses() {
02645 //     return nclasses;
02646 //   }
02647 
02648 //   template <typename Tnet, typename Tdata, typename Tlabel>
02649 //   int hierarchy_datasource<Tnet, Tdata, Tlabel>::get_class_id(const char *name) {
02650 //     int id_ = -1;
02651 //     vector<string*>::iterator i = lblstr->begin();
02652 //     for (int j = 0; i != lblstr->end(); ++i, ++j) {
02653 //       if (!strcmp(name, (*i)->c_str()))
02654 //      id_ = j;
02655 //     }
02656 //     return id_;
02657 //   }
02658 
02659 //   template <typename Tnet, typename Tdata, typename Tlabel>
02660 //   std::string &hierarchy_datasource<Tnet, Tdata, Tlabel>::get_class_name(int id) {
02661 //     if (!lblstr)
02662 //       eblerror("no label strings");
02663 //     if (id >= (int) lblstr->size())
02664 //       eblerror("requesting label string at index " << id
02665 //             << " but string vector has only " << lblstr->size()
02666 //             << " elements.");
02667 //     string *s = (*lblstr)[id];
02668 //     if (!s)
02669 //       eblerror("empty label string");
02670 //     return *s;
02671 //   }
02672 
02673 //   template <typename Tnet, typename Tdata, typename Tlabel>
02674 //   std::vector<std::string*>& hierarchy_datasource<Tnet, Tdata, Tlabel>::
02675 //   get_label_strings() {
02676 //     if (!lblstr)
02677 //       eblerror("expected label strings to be set");
02678 //     return *lblstr;
02679 //   }
02680 
02681 //   template <typename Tnet, typename Tdata, typename Tlabel>
02682 //   intg hierarchy_datasource<Tnet,Tdata,Tlabel>::get_lowest_common_size() {
02683 //     intg min_nonzero = (std::numeric_limits<intg>::max)();
02684 //     for (vector<intg>::iterator i = counts.begin(); i != counts.end(); ++i) {
02685 //       if ((*i < min_nonzero) && (*i != 0))
02686 //      min_nonzero = *i;
02687 //     }
02688 //     if (min_nonzero == (std::numeric_limits<intg>::max)())
02689 //       eblerror("empty dataset");
02690 //     return min_nonzero * nclasses;
02691 //   }
02692 
02693 //   template <typename Tnet, typename Tdata, typename Tlabel>
02694 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::save_pickings(const char *name_) {
02695 //     // non-class plotting
02696 //     datasource<Tnet,Tdata>::save_pickings(name_);
02697 //     string name = "pickings";
02698 //     if (name_)
02699 //       name = name_;
02700 //     // plot by class
02701 //     write_classed_pickings(pick_count, correct, name);
02702 //     write_classed_pickings(energies, correct, name, "_energies");
02703 //     idx<double> e = idx_copy(energies);
02704 //     idx<ubyte> c = idx_copy(correct);
02705 //     idx_sortup(e, c);
02706 //     write_classed_pickings(e, c, name, "_sorted_energies");
02707 //     idx<double> p = idx_copy(probas);
02708 //     c = idx_copy(correct);
02709 //     idx_sortup(p, c);
02710 //     write_classed_pickings(p, c, name, "_sorted_probas");
02711 //     p = idx_copy(probas);
02712 //     e = idx_copy(energies);
02713 //     c = idx_copy(correct);
02714 //     idx_sortup(e, c, p);
02715 //     write_classed_pickings(p, c, name, "_probas_sorted_by_energy", true,
02716 //                         "Picking probability");
02717 //     write_classed_pickings(p, c, name, "_probas_sorted_by_energy_wrong_only",
02718 //                         false, "Picking probability");
02719 //     write_classed_pickings(e, c, name, "_energies_sorted_by_energy_wrong_only",
02720 //                         false, "Energy");
02721 //   }
02722 
02723 //   template <typename Tnet, typename Tdata, typename Tlabel>
02724 //   template <typename T>
02725 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::
02726 //   write_classed_pickings(idx<T> &m, idx<ubyte> &c, string &name_,
02727 //                       const char *name2_, bool plot_correct,
02728 //                       const char *ylabel) {
02729 //     string name = name_;
02730 //     if (name2_)
02731 //       name += name2_;
02732 //     name += "_classed";
02733 //     // sorted classed plot file
02734 //     if (labels.order() == 1) { // single label value
02735 //       string fname = name;
02736 //       fname += ".plot";
02737 //       ofstream fp(fname.c_str());
02738 //       if (!fp) {
02739 //      cerr << "failed to open " << fname << endl;
02740 //      eblerror("failed to open file for writing");
02741 //       }
02742 //       typename idx<T>::dimension_iterator i = m.dim_begin(0);
02743 //       typename idx<Tlabel>::dimension_iterator l = labels.dim_begin(0);
02744 //       typename idx<ubyte>::dimension_iterator ic = c.dim_begin(0);
02745 //       uint j = 0;
02746 //       for ( ; i.notdone(); i++, l++, ic++) {
02747 //      if (!plot_correct && ic->get() == 1)
02748 //        continue ; // don't plot correct ones
02749 //      fp << j++;
02750 //      for (Tlabel k = 0; k < (Tlabel) nclasses; ++k) {
02751 //        if (k == l.get()) {
02752 //          if (ic->get() == 0) { // wrong
02753 //            fp << "\t" << i->get();
02754 //            if (plot_correct)
02755 //              fp << "\t?";
02756 //          } else if (plot_correct) // correct
02757 //            fp << "\t?\t" << i->get();
02758 //        } else {
02759 //          fp << "\t?";
02760 //          if (plot_correct)
02761 //            fp << "\t?";
02762 //        }
02763 //      }
02764 //      fp << endl;
02765 //       }
02766 //       fp.close();
02767 //       cout << _name << ": Wrote picking statistics in " << fname << endl;
02768 //       // p file
02769 //       string fname2 = name;
02770 //       fname2 += ".p";
02771 //       ofstream fp2(fname2.c_str());
02772 //       if (!fp2) {
02773 //      cerr << "failed to open " << fname2 << endl;
02774 //      eblerror("failed to open file for writing");
02775 //       }
02776 //       fp2 << "set title \"" << name << "\"; set ylabel \"" << ylabel
02777 //        << "\"; plot \""
02778 //        << fname << "\" using 1:2 title \"class 0 wrong\" with impulse";
02779 //       if (plot_correct)
02780 //      fp2 << ", \""
02781 //          << fname << "\" using 1:3 title \"class 0 correct\" with impulse";
02782 //       for (uint k = 1; k < nclasses; ++k) {
02783 //      fp2 << ", \"" << fname << "\" using 1:" << k * (plot_correct?2:1) + 2
02784 //          << " title \"class " << k << " wrong\" with impulse";
02785 //      if (plot_correct)
02786 //        fp2 << ", \"" << fname << "\" using 1:" << k * 2 + 3
02787 //            << " title \"class " << k << " correct\" with impulse";
02788 //       }
02789 //       fp << endl;
02790 //       fp2.close();
02791 //       cout << _name << ": Wrote gnuplot file in " << fname2 << endl;
02792 //     }
02793 //   }
02794 
02795 //   //////////////////////////////////////////////////////////////////////////////
02796 //   // state saving
02797 
02798 //   template <typename Tnet, typename Tdata, typename Tlabel>
02799 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::save_state() {
02800 //     state_saved = true;
02801 //     count_pickings_save = count_pickings;
02802 //     it_saved = it; // save main iterator
02803 //     it_test_saved = it_test;
02804 //     it_train_saved = it_train;
02805 //     if (!balance) // save (unbalanced) iterators
02806 //       datasource<Tnet,Tdata>::save_state();
02807 //     else { // save balanced iterators
02808 //       bal_it_saved.clear();
02809 //       bal_indices_saved.clear();
02810 //       for (uint k = 0; k < bal_it.size(); ++k) {
02811 //      bal_it_saved.push_back(bal_it[k]);
02812 //      vector<intg> indices;
02813 //      for (uint l = 0; l < bal_indices[k].size(); ++l)
02814 //        indices.push_back(bal_indices[k][l]);
02815 //      bal_indices_saved.push_back(indices);
02816 //       }
02817 //       class_it_saved = class_it;
02818 //     }
02819 //   }
02820 
02821 //   template <typename Tnet, typename Tdata, typename Tlabel>
02822 //   void hierarchy_datasource<Tnet,Tdata,Tlabel>::restore_state() {
02823 //     if (!state_saved)
02824 //       eblerror("state not saved, call save_state() before restore_state()");
02825 //     count_pickings = count_pickings_save;
02826 //     it = it_saved; // restore main iterator
02827 //     it_test = it_test_saved;
02828 //     it_train = it_train_saved;
02829 //     if (!balance) // restore unbalanced
02830 //       datasource<Tnet,Tdata>::restore_state();
02831 //     else { // restore balanced iterators
02832 //       for (uint k = 0; k < bal_it.size(); ++k) {
02833 //      bal_it[k] = bal_it_saved[k];
02834 //      for (uint l = 0; l < bal_indices[k].size(); ++l)
02835 //        bal_indices[k][l] = bal_indices_saved[k][l];
02836 //       }
02837 //       class_it = class_it_saved;
02838 //     }
02839 //   }
02840 
02842   // pretty methods
02843 
02844   // template <typename Tnet, typename Tdata, typename Tlabel>
02845   // void hierarchy_datasource<Tnet,Tdata,Tlabel>::pretty_progress(bool newline) {
02846   //   if (epoch_show > 0 && epoch_pick_cnt % epoch_show == 0 &&
02847   //    epoch_show_printed != (intg) epoch_pick_cnt) {
02848   //     datasource<Tnet,Tdata>::pretty_progress(false);
02849   //     if (balance) {
02850   //    cout << ", remaining:";
02851   //    for (uint i = 0; i < epoch_done_counters.size(); ++i) {
02852   //      cout << " " << i << ": " << epoch_done_counters[i];
02853   //    }
02854   //     }
02855   //     if (newline)
02856   //    cout << endl;
02857   //   }
02858   // }
02859 
02860   template <typename Tnet, typename Tdata, typename Tlabel>
02861   void hierarchy_datasource<Tnet, Tdata, Tlabel>::pretty() {
02862     cout << _name << ": hierarchy class dataset \"" << _name
02863          << "\" contains "
02864          << data.dim(0) << " samples of dimension " << sampledims
02865          << " and defines an epoch as " << epoch_sz << " samples." << endl;
02866     if (lblstr) {
02867       cout << this->_name << ": It has " << nclasses << " classes: ";
02868       uint i;
02869       for (i = 0; i < this->counts.size() - 1; ++i)
02870         cout << this->counts[i] << " \"" << *(*lblstr)[i] << "\", ";
02871       cout << "and " << this->counts[i] << " \"" << *(*lblstr)[i] << "\".";
02872       cout << endl;
02873       // pretty hierarchy
02874       cout << "Hierarchy by depth:" << endl;
02875       for (typename vector<vector<class_node<Tlabel>*>*>::iterator
02876              j = all_depths.begin(); j != all_depths.end(); ++j) {
02877         vector<class_node<Tlabel>*> &d = **j;
02878         cout << "depth " << j - all_depths.begin() << " ("
02879              << d.size() << " nodes): ";
02880         for (typename vector<class_node<Tlabel>*>::iterator k = d.begin();
02881              k != d.end(); ++k)
02882           cout << (*k)->name() << " ";
02883         cout << endl;
02884       }
02885       // pretty complete-depth hierarchy
02886       cout << "Hierarchy by depth, keeping lower depth nodes with samples:"
02887            << endl;
02888       for (typename vector<vector<class_node<Tlabel>*>*>::iterator
02889              j = complete_depths.begin(); j != complete_depths.end(); ++j) {
02890         vector<class_node<Tlabel>*> &d = **j;
02891         cout << "depth " << j - complete_depths.begin() << " ("
02892              << d.size() << " nodes): ";
02893         for (typename vector<class_node<Tlabel>*>::iterator k = d.begin();
02894              k != d.end(); ++k)
02895           cout << (*k)->name() << " ";
02896         cout << endl;
02897       }
02898     }
02899   }
02900 
02901   template <typename Tnet, typename Tdata, typename Tlabel>
02902   void hierarchy_datasource<Tnet, Tdata, Tlabel>::print_path(Tlabel l) {
02903     bool first = true;
02904     for (class_node<Tlabel> *n = all_nodes[l]; n != NULL; n = n->get_parent()) {
02905       if (first)
02906         first = false;
02907       else
02908         cout << " <- ";
02909       cout << n->name();
02910     }
02911   }
02912 
02914   // labeled_pair_datasource
02915 
02916   // constructor
02917   template <typename Tnet, typename Tdata, typename Tlabel>
02918   labeled_pair_datasource<Tnet, Tdata, Tlabel>::
02919   labeled_pair_datasource(const char *data_fname, const char *labels_fname,
02920                           const char *classes_fname, const char *pairs_fname,
02921                           const char *name_, Tdata b, float c)
02922     : labeled_datasource<Tnet, Tdata, Tlabel>(data_fname, labels_fname,
02923                                               classes_fname, name_, b, c),
02924       pairs(1, 1) { //, pairsIter(pairs, 0) {
02925     // init current class
02926     try {
02927       pairs = load_matrix<intg>(pairs_fname);
02928     } catch(string &err) {
02929       cerr << "error: " << err << endl;
02930       cerr << "failed to load dataset file " << pairs_fname << endl;
02931       eblerror("Failed to load dataset file");
02932     }
02933     eblerror("not implemented");
02934     //    typename idx<intg>::dimension_iterator         diter(pairs, 0);
02935     // pairsIter = diter;
02936   }
02937 
02938   // constructor
02939   template <typename Tnet, typename Tdata, typename Tlabel>
02940   labeled_pair_datasource<Tnet, Tdata, Tlabel>::
02941   labeled_pair_datasource(idx<Tdata> &data_, idx<Tlabel> &labels_,
02942                           idx<ubyte> &classes_, idx<intg> &pairs_,
02943                           const char *name_, Tdata b, float c)
02944     : labeled_datasource<Tnet, Tdata, Tlabel>(data_, labels_, classes_, name_,
02945                                               b, c),
02946       pairs(pairs_) { //, pairsIter(pairs, 0) {
02947   }
02948 
02949   // destructor
02950   template <typename Tnet, typename Tdata, typename Tlabel>
02951   labeled_pair_datasource<Tnet, Tdata, Tlabel>::~labeled_pair_datasource() {
02952   }
02953 
02954   // fprop pair
02955   template <typename Tnet, typename Tdata, typename Tlabel>
02956   void labeled_pair_datasource<Tnet, Tdata, Tlabel>::
02957   fprop(bbstate_idx<Tnet> &in1, bbstate_idx<Tnet> &in2,
02958         bbstate_idx<Tlabel> &label) {
02959     eblerror("fixme");
02960     // in1.resize(this->sample_dims());
02961     // in2.resize(this->sample_dims());
02962     // intg id1 = pairsIter.get(0), id2 = pairsIter.get(1);
02963     // Tlabel lab = this->labels.get(id1);
02964     // label.x.set(lab);
02965     // idx<Tdata> im1 = this->data[id1], im2 = this->data[id2];
02966     // idx_copy(im1, in1.x);
02967     // idx_copy(im2, in2.x);
02968     // idx_addc(in1.x, this->bias, in1.x);
02969     // idx_dotc(in1.x, this->coeff, in1.x);
02970     // idx_addc(in2.x, this->bias, in2.x);
02971     // idx_dotc(in2.x, this->coeff, in2.x);
02972   }
02973 
02974   // next pair
02975   template <typename Tnet, typename Tdata, typename Tlabel>
02976   bool labeled_pair_datasource<Tnet, Tdata, Tlabel>::next() {
02977     eblerror("not implemented");
02978     // ++pairsIter;
02979     // if(!pairsIter.notdone()) {
02980     //   pairsIter = pairs.dim_begin(0);
02981     //   return false;
02982     // }
02983     return true;
02984   }
02985 
02986   // begin pair
02987   template <typename Tnet, typename Tdata, typename Tlabel>
02988   void labeled_pair_datasource<Tnet, Tdata, Tlabel>::seek_begin() {
02989     eblerror("not implemented");
02990     // pairsIter = pairs.dim_begin(0);
02991   }
02992 
02993   template <typename Tnet, typename Tdata, typename Tlabel>
02994   unsigned int labeled_pair_datasource<Tnet, Tdata, Tlabel>::size() {
02995     return pairs.dim(0);
02996   }
02997 
02999   // mnist_datasource
03000 
03001   template <typename Tnet, typename Tdata, typename Tlabel>
03002   mnist_datasource<Tnet, Tdata, Tlabel>::
03003   mnist_datasource(const char *root, bool train_data, uint size) {
03004     try {
03005       // load dataset
03006       string datafile, labelfile, name, setname = "MNIST";
03007       if (train_data) // training set
03008         setname = "train";
03009       else // testing set
03010         setname = "t10k";
03011       datafile << root << "/" << setname << "-images-idx3-ubyte";
03012       labelfile << root << "/" << setname << "-labels-idx1-ubyte";
03013       name << "MNIST " << setname;
03014       idx<Tdata> dat = load_matrix<Tdata>(datafile);
03015       idx<Tlabel> labs = load_matrix<Tlabel>(labelfile);
03016       dat = dat.narrow(0, MIN(dat.dim(0), (intg) size), 0);
03017       labs = labs.narrow(0, MIN(labs.dim(0), (intg) size), 0);
03018       mnist_datasource<Tnet,Tdata,Tlabel>::init(dat, labs, name.c_str());
03019       if (!train_data)
03020         this->set_test(); // remember that this is the testing set
03021       this->init_epoch();
03022       this->pretty(); // print info about dataset
03023     } catch(string &err) {
03024       eblerror("failed to load mnist dataset: " << err);
03025     } catch(bad_alloc& ba) {
03026       eblerror("bad_alloc: " << ba.what());
03027     }
03028   }
03029 
03030   template <typename Tnet, typename Tdata, typename Tlabel>
03031   mnist_datasource<Tnet, Tdata, Tlabel>::
03032   mnist_datasource(const char *root, const char *name, uint size) {
03033     try {
03034       // load dataset
03035       string datafile, labelfile;
03036       datafile << root << "/" << name << "_" << DATA_NAME << MATRIX_EXTENSION;
03037       labelfile << root << "/" << name
03038                 << "_" << LABELS_NAME << MATRIX_EXTENSION;
03039       idx<Tdata> dat = load_matrix<Tdata>(datafile);
03040       idx<Tlabel> labs = load_matrix<Tlabel>(labelfile);
03041       dat = dat.narrow(0, MIN((uint) dat.dim(0), size), 0);
03042       labs = labs.narrow(0, MIN((uint) labs.dim(0), size), 0);
03043       mnist_datasource<Tnet,Tdata,Tlabel>::init(dat, labs, name);
03044       this->init_epoch();
03045       this->pretty(); // print info about dataset
03046     } catch(string &err) {
03047       eblerror("failed to load mnist dataset: " << err);
03048     } catch(bad_alloc& ba) {
03049       eblerror("bad_alloc: " << ba.what());
03050     }
03051   }
03052 
03053   template <typename Tnet, typename Tdata, typename Tlabel>
03054   mnist_datasource<Tnet, Tdata, Tlabel>::~mnist_datasource() {
03055   }
03056 
03058 
03059   template <typename Tnet, typename Tdata, typename Tlabel>
03060   void mnist_datasource<Tnet, Tdata, Tlabel>::
03061   fprop_data(bbstate_idx<Tnet> &out) {
03062     if (out.x.order() != sampledims.order())
03063       out = bbstate_idx<Tnet>(sampledims);
03064     else
03065       out.resize(this->sample_dims());
03066     idx<Tdata> dat;
03067     if (multimat)
03068       dat = datas.get(it);
03069     else
03070       dat = data[it];
03071     uint ni = data.dim(1);
03072     uint nj = data.dim(2);
03073     uint di = (uint) (0.5 * (height - ni));
03074     uint dj = (uint) (0.5 * (width - nj));
03075     out.clear_x();
03076     idx<Tnet> tgt = out.x.select(0, 0);
03077     tgt = tgt.narrow(0, ni, di);
03078     tgt = tgt.narrow(1, nj, dj);
03079     idx_copy(dat, tgt);
03080     // bias and coeff
03081     if (bias != 0)
03082       idx_addc(out.x, bias, out.x);
03083     if (coeff != 1)
03084       idx_dotc(out.x, coeff, out.x);
03085   }
03086 
03087   template <typename Tnet, typename Tdata, typename Tlabel>
03088   void mnist_datasource<Tnet, Tdata, Tlabel>::
03089   init(idx<Tdata> &data_, idx<Tlabel> &labels_, const char *name_) {
03090     class_datasource<Tnet, Tdata, Tlabel>::init(data_, labels_, NULL, name_);
03091     this->set_data_coeff(.01); // scale input data from [0,255] to [0,2.55]
03092     // mnist is actually 28x28, but we add some padding
03093     sampledims = idxdim(1, 32, 32);
03094     height = sampledims.dim(1);
03095     width = sampledims.dim(2);
03096   }
03097 
03099 
03100   template <typename Tdata>
03101   idx<Tdata> create_target_matrix(intg nclasses, Tdata target) {
03102     // fill matrix with 1-of-n code
03103     idx<Tdata> targets(nclasses, nclasses);
03104     idx_fill(targets, -target);
03105     for (int i = 0; i < nclasses; ++i) {
03106       targets.set(target, i, i);
03107     }
03108     return targets; // return by copy
03109   }
03110 
03111 } // end namespace ebl
03112 
03113 #endif /*DATASOURCE_HPP_*/