|
MNIST demo: handwritten digit recognition using the eblearn C++ library
The eblearn (energy-based learning) C++
library libeblearn contains machine learning algorithms
which can be used for computer vision.
The library has a generic and modular architecture, allowing
easy prototyping and building of different algorithms
(supervised or unsupervised learning) and configurations
from basic modules. Those algorithms have been used for a variety of
applications, including robotics in the
Learning Applied to Ground Robots (LAGR) DARPA project.
In this demo, we show how to train a convolutional neural
network to identify the digit value in images of handwritten
digits.
For that purpose, we use
the MNIST
dataset which contains 60,000 images of handwritten digits
for training and
10,000 for testing.
Dataset and classification
This is a visualization of the dataset. The right part
shows the hardest examples: the 0.9% of samples that the network did
not classify correctly in this demo.
First 100 samples of the MNIST
testing dataset, with the ground truth on the left,
correct and incorrect answers in the middle (incorrect answers are
boxed, but there are none among the first 100 samples), and
only the incorrectly classified samples on the right.
Internal Neural Network states
The internal states of the first two samples of the MNIST
full testing set (10,000 samples), before and after training.
a. before training
|
b. after training down to 0.9%
error rate.
|
Quick training and testing demo
Here we quickly go through the main components of the demo so that you can start experimenting with the library yourself as quickly as possible.
1 #include "libeblearn.h"
2
3 #ifdef __GUI__
4 #include "libeblearngui.h"
5 #endif
6
7 using namespace std;
8 using namespace ebl; // all eblearn objects are under the ebl namespace
9
10 // argv[1] is expected to contain the directory of the mnist dataset
11 #ifdef __GUI__
12 MAIN_QTHREAD() { // this is the macro replacing main to enable multithreaded gui
13 #else
14 int main(int argc, char **argv) { // regular main without gui
15 #endif
16 cout << "* MNIST demo: learning handwritten digits using the eblearn";
17 cout << " C++ library *" << endl;
18 if (argc != 2) {
19 cout << "Usage: ./mnist <my mnist directory>" << endl;
20 eblerror("MNIST path not specified");
21 }
22 init_drand(time(NULL)); // initialize random seed
23
24 intg trsize = 60000; // maximum training set size: 60000
25 intg tesize = 10000; // maximum testing set size: 10000
26
27 //! load MNIST datasets: trize for training set and tesize for testing set
28 mnist_datasource<ubyte,ubyte> train_ds, test_ds;
29 load_mnist_dataset(argv[1], train_ds, test_ds, trsize, tesize);
30
31 //! create 1-of-n targets with target 1.0 for shown class, -1.0 for the rest
32 idx<double> targets = create_target_matrix(1+idx_max(train_ds.labels), 1.0);
33
34 //! create the network weights, network and trainer
35 idxdim dims(train_ds.sample_dims()); // get order and dimensions of sample
36 parameter theparam(60000); // create trainable parameter
37 lenet5 l5(theparam, 32, 32, 5, 5, 2, 2, 5, 5, 2, 2, 120, targets.dim(0));
38 supervised_euclidean_machine thenet(l5, targets, dims);
39 supervised_trainer<ubyte,ubyte> thetrainer(thenet, theparam);
40 supervised_trainer_gui stgui; // the gui to display supervised_trainer
41
42 //! a classifier-meter measures classification errors
43 classifier_meter trainmeter, testmeter;
44
45 //! initialize the network weights
46 forget_param_linear fgp(1, 0.5);
47 thenet.forget(fgp);
48
49 // learning parameters
50 gd_param gdp(/* double leta*/ 0.0001,
51 /* double ln */ 0.0,
52 /* double l1 */ 0.0,
53 /* double l2 */ 0.0,
54 /* int dtime */ 0,
55 /* double iner */0.0,
56 /* double a_v */ 0.0,
57 /* double a_t */ 0.0,
58 /* double g_t*/ 0.0);
59 infer_param infp;
60
61 // estimate second derivative on 100 iterations, using mu=0.02
62 cout << "Computing second derivatives on MNIST dataset: ";
63 thetrainer.compute_diaghessian(train_ds, 100, 0.02);
64
65 // first show classification results without training
66 thetrainer.test(train_ds, trainmeter, infp);
67 thetrainer.test(test_ds, testmeter, infp);
68 stgui.display_datasource(thetrainer, test_ds, infp, 10, 10);
69 stgui.display_internals(thetrainer, test_ds, infp, 2);
70
71 // now do training iterations
72 cout << "Training network on MNIST with " << train_ds.size();
73 cout << " training samples and " << test_ds.size() << " test samples:" << endl;
74 for (int i = 0; i < 100; ++i) {
75 thetrainer.train(train_ds, trainmeter, gdp, 1); // train
76 thetrainer.test(train_ds, trainmeter, infp); // test
77 thetrainer.test(test_ds, testmeter, infp); // test
78 stgui.display_datasource(thetrainer, test_ds, infp, 10, 10); // display
79 stgui.display_internals(thetrainer, test_ds, infp, 2); // display
80 thetrainer.compute_diaghessian(train_ds, 100, 0.02); // recompute 2nd der
81 }
82 return 0;
83 }
Output
The output of the demo. In addition to the graphical
outputs, the results are printed to the terminal,
showing the error rates on both the training and testing
sets. Here we reach a 0.9% error rate on the testing set at the 37th
iteration.
/eblearn-trunk/bin$ ./mnist ../../datasets/mnist/
* MNIST demo: learning handwritten digits using the eblearn C++ library *
Computing second derivatives on MNIST dataset: diaghessian inf: 0.987101 sup: 49.7794
MNIST TRAINING set: iter# 0 [0] sz=60000 energy=5.0819 correct=14.695% errors=85.305% rejects=0%
MNIST TESTING set: iter# 0 [0] sz=10000 energy=5.08783 correct=15.09% errors=84.91% rejects=0%
Training network on MNIST with 60000 training samples and 10000 test samples:
MNIST TRAINING set: iter# 1 [60000] sz=60000 energy=0.0701102 correct=98.0467% errors=1.95333% rejects=0%
MNIST TESTING set: iter# 1 [60000] sz=10000 energy=0.0653707 correct=98.15% errors=1.85% rejects=0%
diaghessian inf: 26.3591 sup: 49.9996
MNIST TRAINING set: iter# 2 [120000] sz=60000 energy=0.0546717 correct=98.47% errors=1.53% rejects=0%
MNIST TESTING set: iter# 2 [120000] sz=10000 energy=0.0544861 correct=98.32% errors=1.68% rejects=0%
diaghessian inf: 33.4199 sup: 49.9999
MNIST TRAINING set: iter# 3 [180000] sz=60000 energy=0.0429966 correct=98.8167% errors=1.18333% rejects=0%
MNIST TESTING set: iter# 3 [180000] sz=10000 energy=0.0476517 correct=98.62% errors=1.38% rejects=0%
diaghessian inf: 26.9359 sup: 50
MNIST TRAINING set: iter# 4 [240000] sz=60000 energy=0.0380343 correct=98.97% errors=1.03% rejects=0%
MNIST TESTING set: iter# 4 [240000] sz=10000 energy=0.0462413 correct=98.7% errors=1.3% rejects=0%
diaghessian inf: 27.4113 sup: 50
MNIST TRAINING set: iter# 5 [300000] sz=60000 energy=0.0331095 correct=99.13% errors=0.87% rejects=0%
MNIST TESTING set: iter# 5 [300000] sz=10000 energy=0.0469152 correct=98.72% errors=1.28% rejects=0%
diaghessian inf: 25.9917 sup: 50
MNIST TRAINING set: iter# 6 [360000] sz=60000 energy=0.0354568 correct=99.0633% errors=0.936667% rejects=0%
MNIST TESTING set: iter# 6 [360000] sz=10000 energy=0.0507304 correct=98.6% errors=1.4% rejects=0%
diaghessian inf: 31.1135 sup: 49.9999
MNIST TRAINING set: iter# 7 [420000] sz=60000 energy=0.0300788 correct=99.2267% errors=0.773333% rejects=0%
MNIST TESTING set: iter# 7 [420000] sz=10000 energy=0.0458589 correct=98.73% errors=1.27% rejects=0%
diaghessian inf: 31.014 sup: 50
MNIST TRAINING set: iter# 8 [480000] sz=60000 energy=0.0256149 correct=99.34% errors=0.66% rejects=0%
MNIST TESTING set: iter# 8 [480000] sz=10000 energy=0.0434282 correct=98.72% errors=1.28% rejects=0%
diaghessian inf: 31.8147 sup: 50
MNIST TRAINING set: iter# 9 [540000] sz=60000 energy=0.0205026 correct=99.485% errors=0.515% rejects=0%
MNIST TESTING set: iter# 9 [540000] sz=10000 energy=0.0401836 correct=98.82% errors=1.18% rejects=0%
diaghessian inf: 33.4823 sup: 50
MNIST TRAINING set: iter# 10 [600000] sz=60000 energy=0.0218916 correct=99.4783% errors=0.521667% rejects=0%
MNIST TESTING set: iter# 10 [600000] sz=10000 energy=0.0430964 correct=98.79% errors=1.21% rejects=0%
diaghessian inf: 31.214 sup: 50
MNIST TRAINING set: iter# 11 [660000] sz=60000 energy=0.0167645 correct=99.57% errors=0.43% rejects=0%
MNIST TESTING set: iter# 11 [660000] sz=10000 energy=0.0386548 correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 36.5078 sup: 50
MNIST TRAINING set: iter# 12 [720000] sz=60000 energy=0.0184436 correct=99.5317% errors=0.468333% rejects=0%
MNIST TESTING set: iter# 12 [720000] sz=10000 energy=0.0410647 correct=98.86% errors=1.14% rejects=0%
diaghessian inf: 42.0382 sup: 50
MNIST TRAINING set: iter# 13 [780000] sz=60000 energy=0.0181611 correct=99.5883% errors=0.411667% rejects=0%
MNIST TESTING set: iter# 13 [780000] sz=10000 energy=0.0421034 correct=98.88% errors=1.12% rejects=0%
diaghessian inf: 44.4364 sup: 50
MNIST TRAINING set: iter# 14 [840000] sz=60000 energy=0.0196715 correct=99.5183% errors=0.481667% rejects=0%
MNIST TESTING set: iter# 14 [840000] sz=10000 energy=0.0441488 correct=98.78% errors=1.22% rejects=0%
diaghessian inf: 37.5865 sup: 50
MNIST TRAINING set: iter# 15 [900000] sz=60000 energy=0.0139162 correct=99.665% errors=0.335% rejects=0%
MNIST TESTING set: iter# 15 [900000] sz=10000 energy=0.0395798 correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 46.7421 sup: 50
MNIST TRAINING set: iter# 16 [960000] sz=60000 energy=0.0125451 correct=99.6933% errors=0.306667% rejects=0%
MNIST TESTING set: iter# 16 [960000] sz=10000 energy=0.0391716 correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 41.7718 sup: 50
MNIST TRAINING set: iter# 17 [1020000] sz=60000 energy=0.0134095 correct=99.66% errors=0.34% rejects=0%
MNIST TESTING set: iter# 17 [1020000] sz=10000 energy=0.0409884 correct=98.78% errors=1.22% rejects=0%
diaghessian inf: 31.8297 sup: 50
MNIST TRAINING set: iter# 18 [1080000] sz=60000 energy=0.00921904 correct=99.7717% errors=0.228333% rejects=0%
MNIST TESTING set: iter# 18 [1080000] sz=10000 energy=0.0371163 correct=98.95% errors=1.05% rejects=0%
diaghessian inf: 34.3347 sup: 50
MNIST TRAINING set: iter# 19 [1140000] sz=60000 energy=0.00918398 correct=99.7683% errors=0.231667% rejects=0%
MNIST TESTING set: iter# 19 [1140000] sz=10000 energy=0.0373996 correct=98.92% errors=1.08% rejects=0%
diaghessian inf: 47.2771 sup: 50
MNIST TRAINING set: iter# 20 [1200000] sz=60000 energy=0.00964223 correct=99.7617% errors=0.238333% rejects=0%
MNIST TESTING set: iter# 20 [1200000] sz=10000 energy=0.0383149 correct=98.87% errors=1.13% rejects=0%
diaghessian inf: 45.1764 sup: 50
MNIST TRAINING set: iter# 21 [1260000] sz=60000 energy=0.00844604 correct=99.7817% errors=0.218333% rejects=0%
MNIST TESTING set: iter# 21 [1260000] sz=10000 energy=0.0377095 correct=98.86% errors=1.14% rejects=0%
diaghessian inf: 48.9093 sup: 50
MNIST TRAINING set: iter# 22 [1320000] sz=60000 energy=0.00752063 correct=99.8217% errors=0.178333% rejects=0%
MNIST TESTING set: iter# 22 [1320000] sz=10000 energy=0.0367743 correct=98.93% errors=1.07% rejects=0%
diaghessian inf: 41.4839 sup: 50
MNIST TRAINING set: iter# 23 [1380000] sz=60000 energy=0.0063252 correct=99.835% errors=0.165% rejects=0%
MNIST TESTING set: iter# 23 [1380000] sz=10000 energy=0.0353768 correct=98.99% errors=1.01% rejects=0%
diaghessian inf: 45.3013 sup: 50
MNIST TRAINING set: iter# 24 [1440000] sz=60000 energy=0.00708443 correct=99.8183% errors=0.181667% rejects=0%
MNIST TESTING set: iter# 24 [1440000] sz=10000 energy=0.0383037 correct=98.94% errors=1.06% rejects=0%
diaghessian inf: 42.3925 sup: 50
MNIST TRAINING set: iter# 25 [1500000] sz=60000 energy=0.00498736 correct=99.8517% errors=0.148333% rejects=0%
MNIST TESTING set: iter# 25 [1500000] sz=10000 energy=0.0368142 correct=99% errors=1% rejects=0%
diaghessian inf: 47.6495 sup: 50
MNIST TRAINING set: iter# 26 [1560000] sz=60000 energy=0.00692641 correct=99.83% errors=0.17% rejects=0%
MNIST TESTING set: iter# 26 [1560000] sz=10000 energy=0.0393558 correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 28.1664 sup: 50
MNIST TRAINING set: iter# 27 [1620000] sz=60000 energy=0.00449746 correct=99.8667% errors=0.133333% rejects=0%
MNIST TESTING set: iter# 27 [1620000] sz=10000 energy=0.0357185 correct=99.01% errors=0.99% rejects=0%
diaghessian inf: 49.2415 sup: 50
MNIST TRAINING set: iter# 28 [1680000] sz=60000 energy=0.00619437 correct=99.8367% errors=0.163333% rejects=0%
MNIST TESTING set: iter# 28 [1680000] sz=10000 energy=0.0376879 correct=98.92% errors=1.08% rejects=0%
diaghessian inf: 49.585 sup: 50
MNIST TRAINING set: iter# 29 [1740000] sz=60000 energy=0.00482012 correct=99.8583% errors=0.141667% rejects=0%
MNIST TESTING set: iter# 29 [1740000] sz=10000 energy=0.0360403 correct=98.97% errors=1.03% rejects=0%
diaghessian inf: 48.43 sup: 50
MNIST TRAINING set: iter# 30 [1800000] sz=60000 energy=0.00421703 correct=99.8683% errors=0.131667% rejects=0%
MNIST TESTING set: iter# 30 [1800000] sz=10000 energy=0.035505 correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 47.8533 sup: 50
MNIST TRAINING set: iter# 31 [1860000] sz=60000 energy=0.00565429 correct=99.845% errors=0.155% rejects=0%
MNIST TESTING set: iter# 31 [1860000] sz=10000 energy=0.0369994 correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 49.3615 sup: 50
MNIST TRAINING set: iter# 32 [1920000] sz=60000 energy=0.00464666 correct=99.8667% errors=0.133333% rejects=0%
MNIST TESTING set: iter# 32 [1920000] sz=10000 energy=0.0368726 correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 49.3121 sup: 50
MNIST TRAINING set: iter# 33 [1980000] sz=60000 energy=0.00354118 correct=99.88% errors=0.12% rejects=0%
MNIST TESTING set: iter# 33 [1980000] sz=10000 energy=0.0347413 correct=99.08% errors=0.92% rejects=0%
diaghessian inf: 49.6058 sup: 50
MNIST TRAINING set: iter# 34 [2040000] sz=60000 energy=0.0034331 correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 34 [2040000] sz=10000 energy=0.034873 correct=99.05% errors=0.95% rejects=0%
diaghessian inf: 49.4398 sup: 50
MNIST TRAINING set: iter# 35 [2100000] sz=60000 energy=0.00347821 correct=99.8833% errors=0.116667% rejects=0%
MNIST TESTING set: iter# 35 [2100000] sz=10000 energy=0.0352945 correct=99.06% errors=0.94% rejects=0%
diaghessian inf: 49.5109 sup: 50
MNIST TRAINING set: iter# 36 [2160000] sz=60000 energy=0.00326626 correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 36 [2160000] sz=10000 energy=0.0349153 correct=99.07% errors=0.93% rejects=0%
diaghessian inf: 49.5637 sup: 50
MNIST TRAINING set: iter# 37 [2220000] sz=60000 energy=0.00321661 correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 37 [2220000] sz=10000 energy=0.0347334 correct=99.1% errors=0.9% rejects=0%
Detailed explanation
Here we go through the main components of the demo step by step so that you can start experimenting with the library yourself as quickly as possible.
1 #include "libeblearn.h"
2
3 using namespace std;
4 using namespace ebl; // all eblearn objects are under the ebl namespace
5
6 // argv[1] is expected to contain the directory of the mnist dataset
7 int main(int argc, char **argv) {
8 cout << endl << "* MNIST demo: learning handwritten digits using the eblearn";
9 cout << " C++ library *" << endl;
10 init_drand(time(NULL)); // initialize random seed
11
12 intg trsize = 2000; // maximum training set size: 60000
13 intg tesize = 1000; // maximum testing set size: 10000
14
Details.
15 // load MNIST datasets: trize for training set and tesize for testing set
16 MnistDataSource<ubyte,ubyte> train_ds, test_ds;
17 load_mnist_dataset(argv[1], train_ds, test_ds, trsize, tesize);
Details.
18
19 // create 1-of-n targets with target 1.0 for shown class, -1.0 for the rest
20 Idx<double> targets = create_target_matrix(1+idx_max(train_ds.labels), 1.0);
21
Details.
22 // create the network weights, network and trainer
23 IdxDim dims(train_ds.data.spec); // get order and dimenions from data
24 parameter theparam(60000); // create trainable parameter
25 lenet5 l5(theparam, 32, 32, 5, 5, 2, 2, 5, 5, 2, 2, 120, targets.dim(0));
26 supervised_euclidean_machine thenet(l5, targets, dims);
27 supervised_trainer<ubyte,ubyte> thetrainer(thenet, theparam);
28
29 // a classifier-meter measures classification errors
30 classifier_meter trainmeter, testmeter;
31
Details.
32 // initialize the network weights
33 forget_param_linear fgp(1, 0.5);
34 thenet.forget(fgp);
35
36 // learning parameters
37 gd_param gdp(/* double leta*/ 0.0001,
38 /* double ln */ 0.0,
39 /* double l1 */ 0.0,
40 /* double l2 */ 0.0,
41 /* int dtime */ 0,
42 /* double iner */0.0,
43 /* double a_v */ 0.0,
44 /* double a_t */ 0.0,
45 /* double g_t*/ 0.0);
46
Details.
47 // estimate second derivative on 100 iterations, using mu=0.02
48 cout << "Computing second derivatives on MNIST dataset: ";
49 thetrainer.compute_diaghessian(train_ds, 100, 0.02);
50
Details.
51 // do training iterations
52 cout << "Training network on MNIST with " << train_ds.size();
53 cout << " training samples and " << test_ds.size() << " test samples" << endl;
54 for (int i = 0; i < 5; ++i) {
55 thetrainer.train(train_ds, trainmeter, gdp, 1);
56 cout << "training: " << flush;
57 thetrainer.test(train_ds, trainmeter);
58 trainmeter.display();
59 cout << " testing: " << flush;
60 thetrainer.test(test_ds, testmeter);
61 testmeter.display();
62 }
63 return 0;
64 }
|