MNIST demo: handwritten digits recognition using the eblearn C++ library

By Pierre Sermanet and Yann LeCun (New York University)



The eblearn (energy-based learning) C++ library, libeblearn, contains machine learning algorithms that can be used for computer vision. The library has a generic and modular architecture, allowing easy prototyping and assembly of different algorithms (supervised or unsupervised learning) and configurations from basic modules. These algorithms have been used in a variety of applications, including robotics in the DARPA Learning Applied to Ground Robots (LAGR) project.

In this demo, we show how to train a convolutional neural network to identify the digit value in images of handwritten digits. For that purpose, we use the MNIST dataset which contains 60,000 images of handwritten digits for training and 10,000 for testing.

Dataset and classification

This is a visualization of the dataset. The right part shows the hardest examples: the 0.9% of samples that the network failed to classify correctly in this demo.

First 100 samples of the MNIST testing dataset, with ground truth on the left, correct and incorrect answers in the middle (incorrect answers are boxed, but there are none among the first 100 samples), and the misclassified samples only on the right.

Internal Neural Network states


The internal states of the network for the first two samples of the full MNIST testing set (10,000 samples), before and after training.

a. before training
b. after training down to 0.9% error rate.

Quick training and testing demo

Here we quickly go through the main components of the demo so that you can experiment yourself with the library as fast as possible.

    #include "libeblearn.h"

    #ifdef __GUI__
    #include "libeblearngui.h"
    #endif

    using namespace std;
    using namespace ebl; // all eblearn objects are under the ebl namespace

    // argv[1] is expected to contain the directory of the mnist dataset
    #ifdef __GUI__
    MAIN_QTHREAD() { // this is the macro replacing main to enable multithreaded gui
    #else
    int main(int argc, char **argv) { // regular main without gui
    #endif
      cout << "* MNIST demo: learning handwritten digits using the eblearn";
      cout << " C++ library *" << endl;
      if (argc != 2) {
        cout << "Usage: ./mnist <my mnist directory>" << endl;
        eblerror("MNIST path not specified");
      }
      init_drand(time(NULL)); // initialize random seed

      intg trsize = 60000; // maximum training set size: 60000
      intg tesize = 10000; // maximum testing set size:  10000

      //! load MNIST datasets: trsize for training set and tesize for testing set
      mnist_datasource<ubyte,ubyte> train_ds, test_ds;
      load_mnist_dataset(argv[1], train_ds, test_ds, trsize, tesize);

      //! create 1-of-n targets with target 1.0 for shown class, -1.0 for the rest
      idx<double> targets = create_target_matrix(1 + idx_max(train_ds.labels), 1.0);

      //! create the network weights, network and trainer
      idxdim dims(train_ds.sample_dims()); // get order and dimensions of sample
      parameter theparam(60000); // create trainable parameter
      lenet5 l5(theparam, 32, 32, 5, 5, 2, 2, 5, 5, 2, 2, 120, targets.dim(0));
      supervised_euclidean_machine thenet(l5, targets, dims);
      supervised_trainer<ubyte,ubyte> thetrainer(thenet, theparam);
      supervised_trainer_gui stgui; // the gui to display supervised_trainer

      //! a classifier-meter measures classification errors
      classifier_meter trainmeter, testmeter;

      //! initialize the network weights
      forget_param_linear fgp(1, 0.5);
      thenet.forget(fgp);

      // learning parameters
      gd_param gdp(/* double leta */ 0.0001,
                   /* double ln   */ 0.0,
                   /* double l1   */ 0.0,
                   /* double l2   */ 0.0,
                   /* int dtime   */ 0,
                   /* double iner */ 0.0,
                   /* double a_v  */ 0.0,
                   /* double a_t  */ 0.0,
                   /* double g_t  */ 0.0);
      infer_param infp;

      // estimate second derivatives on 100 iterations, using mu=0.02
      cout << "Computing second derivatives on MNIST dataset: ";
      thetrainer.compute_diaghessian(train_ds, 100, 0.02);

      // first show classification results without training
      thetrainer.test(train_ds, trainmeter, infp);
      thetrainer.test(test_ds, testmeter, infp);
      stgui.display_datasource(thetrainer, test_ds, infp, 10, 10);
      stgui.display_internals(thetrainer, test_ds, infp, 2);

      // now do training iterations
      cout << "Training network on MNIST with " << train_ds.size();
      cout << " training samples and " << test_ds.size() << " test samples:" << endl;
      for (int i = 0; i < 100; ++i) {
        thetrainer.train(train_ds, trainmeter, gdp, 1);              // train
        thetrainer.test(train_ds, trainmeter, infp);                 // test on train
        thetrainer.test(test_ds, testmeter, infp);                   // test on test
        stgui.display_datasource(thetrainer, test_ds, infp, 10, 10); // display
        stgui.display_internals(thetrainer, test_ds, infp, 2);       // display
        thetrainer.compute_diaghessian(train_ds, 100, 0.02); // recompute 2nd derivatives
      }
      return 0;
    }

Output

The output of the demo. In addition to the graphical displays, the results are printed to the terminal, showing the error rates on both the training and testing sets. Here we reach a 0.9% error rate at the 37th iteration.

/eblearn-trunk/bin$ ./mnist ../../datasets/mnist/
* MNIST demo: learning handwritten digits using the eblearn C++ library *
Computing second derivatives on MNIST dataset: diaghessian inf: 0.987101 sup: 49.7794
MNIST TRAINING set: iter# 0 [0]  sz=60000 energy=5.0819  correct=14.695% errors=85.305% rejects=0%
MNIST TESTING set: iter# 0 [0]  sz=10000 energy=5.08783  correct=15.09% errors=84.91% rejects=0%
Training network on MNIST with 60000 training samples and 10000 test samples:
MNIST TRAINING set: iter# 1 [60000]  sz=60000 energy=0.0701102  correct=98.0467% errors=1.95333% rejects=0%
MNIST TESTING set: iter# 1 [60000]  sz=10000 energy=0.0653707  correct=98.15% errors=1.85% rejects=0%
diaghessian inf: 26.3591 sup: 49.9996
MNIST TRAINING set: iter# 2 [120000]  sz=60000 energy=0.0546717  correct=98.47% errors=1.53% rejects=0%
MNIST TESTING set: iter# 2 [120000]  sz=10000 energy=0.0544861  correct=98.32% errors=1.68% rejects=0%
diaghessian inf: 33.4199 sup: 49.9999
MNIST TRAINING set: iter# 3 [180000]  sz=60000 energy=0.0429966  correct=98.8167% errors=1.18333% rejects=0%
MNIST TESTING set: iter# 3 [180000]  sz=10000 energy=0.0476517  correct=98.62% errors=1.38% rejects=0%
diaghessian inf: 26.9359 sup: 50
MNIST TRAINING set: iter# 4 [240000]  sz=60000 energy=0.0380343  correct=98.97% errors=1.03% rejects=0%
MNIST TESTING set: iter# 4 [240000]  sz=10000 energy=0.0462413  correct=98.7% errors=1.3% rejects=0%
diaghessian inf: 27.4113 sup: 50
MNIST TRAINING set: iter# 5 [300000]  sz=60000 energy=0.0331095  correct=99.13% errors=0.87% rejects=0%
MNIST TESTING set: iter# 5 [300000]  sz=10000 energy=0.0469152  correct=98.72% errors=1.28% rejects=0%
diaghessian inf: 25.9917 sup: 50
MNIST TRAINING set: iter# 6 [360000]  sz=60000 energy=0.0354568  correct=99.0633% errors=0.936667% rejects=0%
MNIST TESTING set: iter# 6 [360000]  sz=10000 energy=0.0507304  correct=98.6% errors=1.4% rejects=0%
diaghessian inf: 31.1135 sup: 49.9999
MNIST TRAINING set: iter# 7 [420000]  sz=60000 energy=0.0300788  correct=99.2267% errors=0.773333% rejects=0%
MNIST TESTING set: iter# 7 [420000]  sz=10000 energy=0.0458589  correct=98.73% errors=1.27% rejects=0%
diaghessian inf: 31.014 sup: 50
MNIST TRAINING set: iter# 8 [480000]  sz=60000 energy=0.0256149  correct=99.34% errors=0.66% rejects=0%
MNIST TESTING set: iter# 8 [480000]  sz=10000 energy=0.0434282  correct=98.72% errors=1.28% rejects=0%
diaghessian inf: 31.8147 sup: 50
MNIST TRAINING set: iter# 9 [540000]  sz=60000 energy=0.0205026  correct=99.485% errors=0.515% rejects=0%
MNIST TESTING set: iter# 9 [540000]  sz=10000 energy=0.0401836  correct=98.82% errors=1.18% rejects=0%
diaghessian inf: 33.4823 sup: 50
MNIST TRAINING set: iter# 10 [600000]  sz=60000 energy=0.0218916  correct=99.4783% errors=0.521667% rejects=0%
MNIST TESTING set: iter# 10 [600000]  sz=10000 energy=0.0430964  correct=98.79% errors=1.21% rejects=0%
diaghessian inf: 31.214 sup: 50
MNIST TRAINING set: iter# 11 [660000]  sz=60000 energy=0.0167645  correct=99.57% errors=0.43% rejects=0%
MNIST TESTING set: iter# 11 [660000]  sz=10000 energy=0.0386548  correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 36.5078 sup: 50
MNIST TRAINING set: iter# 12 [720000]  sz=60000 energy=0.0184436  correct=99.5317% errors=0.468333% rejects=0%
MNIST TESTING set: iter# 12 [720000]  sz=10000 energy=0.0410647  correct=98.86% errors=1.14% rejects=0%
diaghessian inf: 42.0382 sup: 50
MNIST TRAINING set: iter# 13 [780000]  sz=60000 energy=0.0181611  correct=99.5883% errors=0.411667% rejects=0%
MNIST TESTING set: iter# 13 [780000]  sz=10000 energy=0.0421034  correct=98.88% errors=1.12% rejects=0%
diaghessian inf: 44.4364 sup: 50
MNIST TRAINING set: iter# 14 [840000]  sz=60000 energy=0.0196715  correct=99.5183% errors=0.481667% rejects=0%
MNIST TESTING set: iter# 14 [840000]  sz=10000 energy=0.0441488  correct=98.78% errors=1.22% rejects=0%
diaghessian inf: 37.5865 sup: 50
MNIST TRAINING set: iter# 15 [900000]  sz=60000 energy=0.0139162  correct=99.665% errors=0.335% rejects=0%
MNIST TESTING set: iter# 15 [900000]  sz=10000 energy=0.0395798  correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 46.7421 sup: 50
MNIST TRAINING set: iter# 16 [960000]  sz=60000 energy=0.0125451  correct=99.6933% errors=0.306667% rejects=0%
MNIST TESTING set: iter# 16 [960000]  sz=10000 energy=0.0391716  correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 41.7718 sup: 50
MNIST TRAINING set: iter# 17 [1020000]  sz=60000 energy=0.0134095  correct=99.66% errors=0.34% rejects=0%
MNIST TESTING set: iter# 17 [1020000]  sz=10000 energy=0.0409884  correct=98.78% errors=1.22% rejects=0%
diaghessian inf: 31.8297 sup: 50
MNIST TRAINING set: iter# 18 [1080000]  sz=60000 energy=0.00921904  correct=99.7717% errors=0.228333% rejects=0%
MNIST TESTING set: iter# 18 [1080000]  sz=10000 energy=0.0371163  correct=98.95% errors=1.05% rejects=0%
diaghessian inf: 34.3347 sup: 50
MNIST TRAINING set: iter# 19 [1140000]  sz=60000 energy=0.00918398  correct=99.7683% errors=0.231667% rejects=0%
MNIST TESTING set: iter# 19 [1140000]  sz=10000 energy=0.0373996  correct=98.92% errors=1.08% rejects=0%
diaghessian inf: 47.2771 sup: 50
MNIST TRAINING set: iter# 20 [1200000]  sz=60000 energy=0.00964223  correct=99.7617% errors=0.238333% rejects=0%
MNIST TESTING set: iter# 20 [1200000]  sz=10000 energy=0.0383149  correct=98.87% errors=1.13% rejects=0%
diaghessian inf: 45.1764 sup: 50
MNIST TRAINING set: iter# 21 [1260000]  sz=60000 energy=0.00844604  correct=99.7817% errors=0.218333% rejects=0%
MNIST TESTING set: iter# 21 [1260000]  sz=10000 energy=0.0377095  correct=98.86% errors=1.14% rejects=0%
diaghessian inf: 48.9093 sup: 50
MNIST TRAINING set: iter# 22 [1320000]  sz=60000 energy=0.00752063  correct=99.8217% errors=0.178333% rejects=0%
MNIST TESTING set: iter# 22 [1320000]  sz=10000 energy=0.0367743  correct=98.93% errors=1.07% rejects=0%
diaghessian inf: 41.4839 sup: 50
MNIST TRAINING set: iter# 23 [1380000]  sz=60000 energy=0.0063252  correct=99.835% errors=0.165% rejects=0%
MNIST TESTING set: iter# 23 [1380000]  sz=10000 energy=0.0353768  correct=98.99% errors=1.01% rejects=0%
diaghessian inf: 45.3013 sup: 50
MNIST TRAINING set: iter# 24 [1440000]  sz=60000 energy=0.00708443  correct=99.8183% errors=0.181667% rejects=0%
MNIST TESTING set: iter# 24 [1440000]  sz=10000 energy=0.0383037  correct=98.94% errors=1.06% rejects=0%
diaghessian inf: 42.3925 sup: 50
MNIST TRAINING set: iter# 25 [1500000]  sz=60000 energy=0.00498736  correct=99.8517% errors=0.148333% rejects=0%
MNIST TESTING set: iter# 25 [1500000]  sz=10000 energy=0.0368142  correct=99% errors=1% rejects=0%
diaghessian inf: 47.6495 sup: 50
MNIST TRAINING set: iter# 26 [1560000]  sz=60000 energy=0.00692641  correct=99.83% errors=0.17% rejects=0%
MNIST TESTING set: iter# 26 [1560000]  sz=10000 energy=0.0393558  correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 28.1664 sup: 50
MNIST TRAINING set: iter# 27 [1620000]  sz=60000 energy=0.00449746  correct=99.8667% errors=0.133333% rejects=0%
MNIST TESTING set: iter# 27 [1620000]  sz=10000 energy=0.0357185  correct=99.01% errors=0.99% rejects=0%
diaghessian inf: 49.2415 sup: 50
MNIST TRAINING set: iter# 28 [1680000]  sz=60000 energy=0.00619437  correct=99.8367% errors=0.163333% rejects=0%
MNIST TESTING set: iter# 28 [1680000]  sz=10000 energy=0.0376879  correct=98.92% errors=1.08% rejects=0%
diaghessian inf: 49.585 sup: 50
MNIST TRAINING set: iter# 29 [1740000]  sz=60000 energy=0.00482012  correct=99.8583% errors=0.141667% rejects=0%
MNIST TESTING set: iter# 29 [1740000]  sz=10000 energy=0.0360403  correct=98.97% errors=1.03% rejects=0%
diaghessian inf: 48.43 sup: 50
MNIST TRAINING set: iter# 30 [1800000]  sz=60000 energy=0.00421703  correct=99.8683% errors=0.131667% rejects=0%
MNIST TESTING set: iter# 30 [1800000]  sz=10000 energy=0.035505  correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 47.8533 sup: 50
MNIST TRAINING set: iter# 31 [1860000]  sz=60000 energy=0.00565429  correct=99.845% errors=0.155% rejects=0%
MNIST TESTING set: iter# 31 [1860000]  sz=10000 energy=0.0369994  correct=98.89% errors=1.11% rejects=0%
diaghessian inf: 49.3615 sup: 50
MNIST TRAINING set: iter# 32 [1920000]  sz=60000 energy=0.00464666  correct=99.8667% errors=0.133333% rejects=0%
MNIST TESTING set: iter# 32 [1920000]  sz=10000 energy=0.0368726  correct=98.98% errors=1.02% rejects=0%
diaghessian inf: 49.3121 sup: 50
MNIST TRAINING set: iter# 33 [1980000]  sz=60000 energy=0.00354118  correct=99.88% errors=0.12% rejects=0%
MNIST TESTING set: iter# 33 [1980000]  sz=10000 energy=0.0347413  correct=99.08% errors=0.92% rejects=0%
diaghessian inf: 49.6058 sup: 50
MNIST TRAINING set: iter# 34 [2040000]  sz=60000 energy=0.0034331  correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 34 [2040000]  sz=10000 energy=0.034873  correct=99.05% errors=0.95% rejects=0%
diaghessian inf: 49.4398 sup: 50
MNIST TRAINING set: iter# 35 [2100000]  sz=60000 energy=0.00347821  correct=99.8833% errors=0.116667% rejects=0%
MNIST TESTING set: iter# 35 [2100000]  sz=10000 energy=0.0352945  correct=99.06% errors=0.94% rejects=0%
diaghessian inf: 49.5109 sup: 50
MNIST TRAINING set: iter# 36 [2160000]  sz=60000 energy=0.00326626  correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 36 [2160000]  sz=10000 energy=0.0349153  correct=99.07% errors=0.93% rejects=0%
diaghessian inf: 49.5637 sup: 50
MNIST TRAINING set: iter# 37 [2220000]  sz=60000 energy=0.00321661  correct=99.885% errors=0.115% rejects=0%
MNIST TESTING set: iter# 37 [2220000]  sz=10000 energy=0.0347334  correct=99.1% errors=0.9% rejects=0%

Detailed explanation

Here we walk through the main components of the demo step by step, with an explanation after each code fragment.

    #include "libeblearn.h"

    using namespace std;
    using namespace ebl; // all eblearn objects are under the ebl namespace

    // argv[1] is expected to contain the directory of the mnist dataset
    int main(int argc, char **argv) {
      cout << endl << "* MNIST demo: learning handwritten digits using the eblearn";
      cout << " C++ library *" << endl;
      init_drand(time(NULL)); // initialize random seed

      intg trsize = 2000; // maximum training set size: 60000
      intg tesize = 1000; // maximum testing set size:  10000

We include the main library header, libeblearn.h, and use the ebl namespace, under which all eblearn objects live. The random seed is initialized with the current time so that the weight initialization differs between runs. The training and testing sets are capped here at 2,000 and 1,000 samples respectively so that this walk-through runs quickly; increase them (up to 60,000 and 10,000) for better accuracy.

      // load MNIST datasets: trsize for training set and tesize for testing set
      mnist_datasource<ubyte,ubyte> train_ds, test_ds;
      load_mnist_dataset(argv[1], train_ds, test_ds, trsize, tesize);

load_mnist_dataset looks in the directory given on the command line for the four standard MNIST files (training and testing images and labels) and fills the two data sources with at most trsize and tesize samples.

      // create 1-of-n targets with target 1.0 for shown class, -1.0 for the rest
      idx<double> targets = create_target_matrix(1 + idx_max(train_ds.labels), 1.0);

Since the labels range from 0 to 9, idx_max(train_ds.labels) + 1 gives the number of classes. create_target_matrix builds one target vector per class, with +1.0 at the position of the class and -1.0 everywhere else; the euclidean machine below compares the network output against these vectors.

      // create the network weights, network and trainer
      idxdim dims(train_ds.sample_dims()); // get order and dimensions of sample
      parameter theparam(60000); // create trainable parameter
      lenet5 l5(theparam, 32, 32, 5, 5, 2, 2, 5, 5, 2, 2, 120, targets.dim(0));
      supervised_euclidean_machine thenet(l5, targets, dims);
      supervised_trainer<ubyte,ubyte> thetrainer(thenet, theparam);

      // a classifier-meter measures classification errors
      classifier_meter trainmeter, testmeter;

theparam holds all trainable weights in a single vector (allocated here with room for 60,000 values). lenet5 builds the classic convolutional network: 32x32 inputs, two stages of 5x5 convolutions each followed by 2x2 subsampling, a layer of 120 hidden units, and one output per class. supervised_euclidean_machine wraps the network with a euclidean energy between its output and the target vector of the desired class, and supervised_trainer drives training and testing, with one classifier_meter accumulating error statistics for each set.

      // initialize the network weights
      forget_param_linear fgp(1, 0.5);
      thenet.forget(fgp);

      // learning parameters
      gd_param gdp(/* double leta */ 0.0001,
                   /* double ln   */ 0.0,
                   /* double l1   */ 0.0,
                   /* double l2   */ 0.0,
                   /* int dtime   */ 0,
                   /* double iner */ 0.0,
                   /* double a_v  */ 0.0,
                   /* double a_t  */ 0.0,
                   /* double g_t  */ 0.0);

forget re-initializes the weights with random values. gd_param collects the gradient-descent hyper-parameters; here only the global learning rate eta = 0.0001 is non-zero, so no explicit decay, regularization or momentum is applied.

      // estimate second derivatives on 100 iterations, using mu=0.02
      cout << "Computing second derivatives on MNIST dataset: ";
      thetrainer.compute_diaghessian(train_ds, 100, 0.02);

compute_diaghessian runs over 100 samples to estimate the diagonal of the Hessian, i.e. the curvature of the loss with respect to each weight. Each weight then receives its own effective learning rate eta / (mu + h_i), so flat directions are traversed with larger steps and sharply curved ones with smaller steps; mu = 0.02 keeps the rate bounded when the curvature estimate is near zero.

      // do training iterations
      cout << "Training network on MNIST with " << train_ds.size();
      cout << " training samples and " << test_ds.size() << " test samples" << endl;
      for (int i = 0; i < 5; ++i) {
        thetrainer.train(train_ds, trainmeter, gdp, 1);
        cout << "training: " << flush;
        thetrainer.test(train_ds, trainmeter);
        trainmeter.display();
        cout << " testing: " << flush;
        thetrainer.test(test_ds, testmeter);
        testmeter.display();
      }
      return 0;
    }