00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef DATASOURCE_H_
00034 #define DATASOURCE_H_
00035
00036 #include "libidx.h"
00037 #include "Ebl.h"
00038
00039 #ifdef __GUI__
00040 #include "libidxgui.h"
00041 #endif
00042
00043 using namespace std;
00044
00045 namespace ebl {
00046
00047 template<typename Tdata, typename Tlabel> class LabeledDataSource {
00048 public:
00049 double bias;
00050 double coeff;
00051 Idx<Tdata> data;
00052 Idx<Tlabel> labels;
00053 typename Idx<Tdata>::dimension_iterator dataIter;
00054 typename Idx<Tlabel>::dimension_iterator labelsIter;
00055 unsigned int height;
00056 unsigned int width;
00057 vector<string*> *lblstr;
00058 const char *name;
00059 unsigned int display_wid;
00060
00062 LabeledDataSource();
00063
00064 void init(Idx<Tdata> &inp, Idx<Tlabel> &lbl, double b, double c,
00065 const char *name, vector<string*> *lblstr);
00066
00073 LabeledDataSource(Idx<Tdata> &inputs, Idx<Tlabel> &labels,
00074 double b = 0.0, double c = 0.01,
00075 const char *name = NULL,
00076 vector<string*> *lblstr = NULL);
00077
00078 virtual ~LabeledDataSource();
00079
00081 void virtual fprop(state_idx &datum, Idx<Tlabel> &label);
00082
00084 virtual int size();
00085
00087
00088 virtual int tell() { return -1; };
00089
00091 virtual void next();
00092
00094 virtual void seek_begin();
00095
00096 virtual void display(unsigned int nh, unsigned int nw,
00097 unsigned int h0 = 0, unsigned int w0 = 0,
00098 double zoom = 1.0, int wid = -1,
00099 const char *wname = NULL);
00100
00101 virtual void draw(unsigned int nh, unsigned int nw, unsigned int h0 = 0,
00102 unsigned int w0 = 0, double zoom = 1.0);
00103 };
00104
00106
00112 template<class Tdata, class Tlabel>
00113 class MnistDataSource : public LabeledDataSource<Tdata, Tlabel> {
00114 public:
00115 double bias;
00116 double coeff;
00117
00119 MnistDataSource() {};
00120
00128 MnistDataSource(Idx<Tdata> &inp, Idx<Tlabel> &lbl,
00129 intg w, intg h, double b, double c,
00130 const char *name = NULL);
00131 virtual ~MnistDataSource () {}
00132
00133 virtual void init(Idx<Tdata> &inp, Idx<Tlabel> &lbl, intg w, intg h,
00134 double b, double c, const char *name);
00135
00139 virtual void fprop(state_idx &out, Idx<Tlabel> &label);
00140
00141 virtual void display(unsigned int nh, unsigned int nw,
00142 unsigned int h0 = 0, unsigned int w0 = 0,
00143 double zoom = 1.0, int wid = -1,
00144 const char *wname = NULL);
00145 };
00146
00148
00149
00153 template<class Tdata, class Tlabel>
00154 bool load_mnist_dataset(const char *directory,
00155 MnistDataSource<Tdata,Tlabel> &train_ds,
00156 MnistDataSource<Tdata,Tlabel> &test_ds,
00157 int train_size, int test_size);
00158
00162 Idx<double> create_target_matrix(intg nclasses, double target);
00163
00165
00169
00170
00171
00172
00173
00174
00175
00179
00180
00181
00182
00184
00185
00187
00188
00189
00190
00191 }
00192
00193 #include "DataSource.hpp"
00194
00195 #endif