In this project I trained two- and three-layer feedforward network (FFN) classifiers on the MNIST dataset and deployed a final model in an iOS app as a practical example. The models were implemented and trained in Python with NNKit (a framework I developed to study neural networks), and the iOS example was written in Objective-C and C++.
The final model achieves 99% accuracy on the MNIST test set and performs fairly well in practice on both handwritten and computer-generated digits.
I trained different combinations of network topologies and batch sizes (Tables 1 and 2) on the whole MNIST train set and used the first half of the test set (5K examples) for validation, setting aside the other half for testing. At the end of each training epoch I measured validation accuracy and kept the five best checkpoints across all epochs based on this performance measure (Algorithm 1).
| Model Output (10-unit layer) | Epochs | Loss | Optimizer | Learning Rate |
|---|---|---|---|---|
| softmax | 100 | cross entropy | gradient descent with momentum | 0.99 |

| Topologies (hidden layer sizes) | Batch Sizes |
|---|---|
| 170, 300, 900, (300, 300), (900, 100), (170, 100, 70), (300, 200, 100) | 16, 32, 182 |
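Algorithm 1 itself is not reproduced in this section, but the checkpoint rule described above (keep the five best checkpoints by validation accuracy across all epochs and all topology/batch-size combinations) can be sketched with a bounded min-heap. This is a minimal sketch, not NNKit code: `TopKCheckpoints` is a hypothetical helper, a "checkpoint" is any object (here a `(config, epoch)` tuple standing in for saved weights), and the accuracies in the usage part are made up.

```python
import heapq
import itertools


class TopKCheckpoints:
    """Keep the k checkpoints with the highest validation accuracy seen so far.

    Hypothetical helper (not part of NNKit). A min-heap ordered by accuracy
    keeps the weakest retained checkpoint at heap[0], so it can be evicted
    in O(log k) when a better one arrives.
    """

    def __init__(self, k=5):
        self.k = k
        self._heap = []  # entries: (accuracy, insertion_order, checkpoint)
        self._order = itertools.count()  # tie-breaker so checkpoints are never compared

    def offer(self, accuracy, checkpoint):
        """Consider a checkpoint at the end of an epoch; return True if it is kept."""
        entry = (accuracy, next(self._order), checkpoint)
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, entry)
            return True
        if accuracy > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)  # drop the weakest, add the new one
            return True
        return False

    def best(self):
        """Kept checkpoints as (accuracy, checkpoint), best first."""
        return [(acc, ckpt) for acc, _, ckpt in sorted(self._heap, reverse=True)]


# Usage: made-up validation accuracies per epoch for two hypothetical configurations.
keeper = TopKCheckpoints(k=5)
fake_runs = {
    "cfg_a": [0.970, 0.981, 0.986, 0.985],
    "cfg_b": [0.975, 0.988, 0.990, 0.989],
}
for config, accuracies in fake_runs.items():
    for epoch, acc in enumerate(accuracies, start=1):
        keeper.offer(acc, (config, epoch))

for acc, (config, epoch) in keeper.best():
    print(f"{config} epoch {epoch}: {acc:.3f}")
```

Because the heap is shared across configurations, several of the kept checkpoints can come from the same run, which matches the situation described next.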
The best models stemmed from two combinations of topologies and batch sizes. In fact, four out of the five models were iterative improvements over a common combination (Table 3).
| Model ID | Topology (excluding output) | Batch Size | Best Epoch | Training Loss | Validation Loss | Validation Accuracy (%) |
|---|---|---|---|---|---|---|
Figures 3 and 4 show training statistics for the best models. They show that the best validation accuracy, which I used to decide whether to keep a checkpoint, did not coincide with the lowest validation loss, which is the usual criterion for early stopping. In a sense, then, I used a different objective for validation than for training, and this turned out to be beneficial.
One possible explanation for why stopping on accuracy worked better is that the loss function measures distance from the target, and the targets in this problem are one-hot vectors: a 100% probability for the correct digit and 0% for all other digits. Matching those 100% probabilities exactly would likely mean overfitting to noise specific to the dataset. Accuracy, by contrast, depends only on the relative probabilities of the digits within the output distribution (i.e., the prediction is the digit with the highest probability, whether or not that probability is 100%).
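To make this concrete, the sketch below compares two made-up softmax outputs for an example whose label is 7. Both count as correct for accuracy, which only looks at the argmax, yet cross-entropy against the one-hot target is still noticeably higher for the less confident one, so training on that loss keeps pushing it toward 100%. The logits are invented for illustration, and the helper names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    # Against a one-hot target, cross-entropy reduces to -log p(correct digit).
    return -math.log(probs[label])


def predicted_digit(probs):
    # Accuracy only cares which digit has the highest probability.
    return max(range(len(probs)), key=probs.__getitem__)


def accuracy(batch_probs, labels):
    hits = sum(predicted_digit(p) == y for p, y in zip(batch_probs, labels))
    return hits / len(labels)


label = 7
moderate = softmax([0.0] * 7 + [3.0] + [0.0] * 2)   # correct, about 69% confident
confident = softmax([0.0] * 7 + [9.0] + [0.0] * 2)  # correct, nearly one-hot

for name, probs in [("moderate", moderate), ("confident", confident)]:
    print(name, predicted_digit(probs) == label,
          round(probs[label], 3), round(cross_entropy(probs, label), 3))

print("accuracy:", accuracy([moderate, confident], [label, label]))
```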
Testing was done by running the remaining 5K test examples through each model as a single batch and measuring accuracy in the same way as for validation. All models scored around 99% (Table 4).
| Model ID | Topology | Batch Size | Best Epoch | Test Accuracy (%) |
|---|---|---|---|---|
For a baseline, I relied on the error rates reported on the MNIST website for similar network architectures, which I partially reproduce in Table 5. The table excludes entries where the network input was augmented in any way, since I used raw features. Compared with these baselines, my test results were very satisfactory.
| Topology (as described in original table) | Test Error (%) | Test Accuracy Inferred from Error (%) | Reference |
|---|---|---|---|
| 3-layer NN, 500+300 HU, softmax, cross entropy, weight decay | 1.53 | 98.47 | Hinton, unpublished, 2005 |
| 2-layer NN, 800 HU, Cross-Entropy Loss | 1.60 | 98.40 | Simard et al., ICDAR 2003 |
| 3-layer NN, 500+150 hidden units | 2.95 | 97.05 | LeCun et al. 1998 |
| 3-layer NN, 300+100 hidden units | 3.05 | 96.95 | LeCun et al. 1998 |
| 2-layer NN, 1000 hidden units | 4.50 | 95.50 | LeCun et al. 1998 |
| 2-layer NN, 300 hidden units, mean square error | 4.70 | 95.30 | LeCun et al. 1998 |
I also tested the models on photographs of handwritten digits taken with an iPhone, to mimic the conditions under which the final model would operate. Before testing, the photographs were preprocessed following an approach similar to the one described on the MNIST website:
The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
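A minimal sketch of this preprocessing, under stated assumptions: the input is a grayscale image given as a list of rows with values in [0, 1], with the ink already bright on a dark background (a photo of dark ink on paper would need inverting first); the digit's bounding box is found with a hypothetical threshold of 0.1; and resizing uses simple box-filter supersampling as a stand-in for the anti-aliasing the MNIST description mentions but does not specify. The function names are illustrative, not the app's.

```python
def _bounding_box(img, threshold):
    """Extent of pixels above `threshold` as (top, bottom, left, right), exclusive ends."""
    rows = [r for r, row in enumerate(img) if any(v > threshold for v in row)]
    if not rows:
        return None
    cols = [c for c in range(len(img[0])) if any(row[c] > threshold for row in img)]
    return rows[0], rows[-1] + 1, cols[0], cols[-1] + 1


def _resize_box_filter(img, out_h, out_w, samples=4):
    """Resize by averaging a samples x samples grid inside each output pixel.

    A simple stand-in for anti-aliased resampling: edges come out as grey levels.
    """
    in_h, in_w = len(img), len(img[0])
    out = [[0.0] * out_w for _ in range(out_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            total = 0.0
            for sy in range(samples):
                for sx in range(samples):
                    # Map each sub-sample position back to a source pixel.
                    y = int((oy + (sy + 0.5) / samples) * in_h / out_h)
                    x = int((ox + (sx + 0.5) / samples) * in_w / out_w)
                    total += img[min(y, in_h - 1)][min(x, in_w - 1)]
            out[oy][ox] = total / (samples * samples)
    return out


def mnist_style(img, box=20, field=28, threshold=0.1):
    """Fit the digit in a box x box square (aspect ratio kept), then center it
    by center of mass in a field x field image."""
    out = [[0.0] * field for _ in range(field)]
    bbox = _bounding_box(img, threshold)
    if bbox is None:
        return out
    top, bottom, left, right = bbox
    crop = [row[left:right] for row in img[top:bottom]]
    h, w = bottom - top, right - left

    # Scale so the longer side becomes `box` pixels.
    scale = box / max(h, w)
    new_h, new_w = max(1, round(h * scale)), max(1, round(w * scale))
    small = _resize_box_filter(crop, new_h, new_w)

    # Intensity-weighted center of mass of the resized digit.
    mass = sum(sum(row) for row in small)
    if mass == 0:
        return out
    cy = sum(y * v for y, row in enumerate(small) for v in row) / mass
    cx = sum(x * v for row in small for x, v in enumerate(row)) / mass

    # Translate by whole pixels so the center of mass lands at the field center.
    center = (field - 1) / 2
    dy, dx = round(center - cy), round(center - cx)
    for y, row in enumerate(small):
        for x, v in enumerate(row):
            ty, tx = y + dy, x + dx
            if 0 <= ty < field and 0 <= tx < field:
                out[ty][tx] = v
    return out
```

Translating by whole pixels keeps the sketch simple; sub-pixel shifts would place the center of mass more precisely at the cost of another resampling step.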
After preprocessing the raw pixels this way, model 5 correctly classified all photographs, and models 1 through 4 correctly classified all but one. Model 5 was also the smallest model (2.2 MB vs. 7.3 MB), which made it more attractive for mobile deployment, so at this point I chose model 5 for deployment.
I implemented an iOS app capable of taking photographs and running prediction on them. To import and run the trained model in the app, I re-implemented a forward-only subset of NNKit (originally written in Python) in C++.
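The core of such a forward-only subset is small. The app's version is C++; for consistency with the other examples here, the sketch below is in Python. Assumptions: weights are stored one row per output unit, hidden layers use ReLU (the hidden activation is not stated in this section), and the 10-unit output layer applies softmax as listed in Table 1. `forward` and `predict` are hypothetical names, not NNKit's API.

```python
import math


def dense(x, weights, bias):
    """Fully connected layer: y = W x + b, one weight row per output unit (assumed layout)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]


def forward(x, layers):
    """Forward pass only: `layers` is a list of (weights, bias) pairs.

    Hidden layers use ReLU (an assumption); the output layer uses softmax.
    """
    for weights, bias in layers[:-1]:
        x = relu(dense(x, weights, bias))
    weights, bias = layers[-1]
    return softmax(dense(x, weights, bias))


def predict(x, layers):
    """Return (predicted class, class probabilities)."""
    probs = forward(x, layers)
    return max(range(len(probs)), key=probs.__getitem__), probs


# Usage with a toy 2-2-3 network and hand-picked weights; a real model would
# load 784-input weights exported from training.
toy = [
    ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),                     # hidden layer
    ([[1.0, -1.0], [-1.0, 1.0], [0.0, 0.0]], [0.0, 0.0, 0.0]),  # 3-class output layer
]
print(predict([2.0, 0.5], toy)[0])
print(predict([-1.0, 3.0], toy)[0])
```

Since inference never needs gradients, the port only has to reproduce matrix-vector products, the activations and the stored weights, which is what keeps a forward-only subset small.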
However, photographs taken in the iOS app could contain other elements besides digits, so running prediction on the whole image did not always make sense. I partially addressed this by adding a region of interest (ROI) extraction step before preprocessing and running prediction on these ROIs instead (Figure 6).
As an example, Figure 7 shows the result of preprocessing an image as described above and running prediction on its ROIs. The 7 is correctly classified, but because ROI extraction is based only on edge detection, several non-digit ROIs were also sent to the model.
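The app's ROI step uses edge detection and contours, which is not reproduced here. As a stand-in, the sketch below uses a different technique, 4-connected component labeling on an already binarized image, to produce bounding boxes of candidate regions, largest bounding-box area first. `extract_rois` is a hypothetical name, and the binarization is assumed to have happened beforehand.

```python
from collections import deque


def extract_rois(binary):
    """Bounding boxes (top, left, bottom, right), exclusive bottom/right, of the
    4-connected foreground regions in a 2-D boolean grid, largest area first."""
    h, w = len(binary), len(binary[0])
    seen = [[False] * w for _ in range(h)]
    boxes = []
    for sy in range(h):
        for sx in range(w):
            if not binary[sy][sx] or seen[sy][sx]:
                continue
            # Flood-fill one component with BFS, tracking its extent.
            top, left, bottom, right = sy, sx, sy, sx
            queue = deque([(sy, sx)])
            seen[sy][sx] = True
            while queue:
                y, x = queue.popleft()
                top, bottom = min(top, y), max(bottom, y)
                left, right = min(left, x), max(right, x)
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < h and 0 <= nx < w and binary[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            boxes.append((top, left, bottom + 1, right + 1))
    # Sort by decreasing bounding-box area (stable, so ties keep scan order).
    boxes.sort(key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)
    return boxes


# Usage: a small grid with two blobs.
rows = [
    "1100000",
    "1100000",
    "0000111",
    "0000111",
    "0000111",
]
binary = [[ch == "1" for ch in row] for row in rows]
print(extract_rois(binary))
```

Like the edge-based approach, this has no notion of what a digit looks like: every sufficiently connected blob becomes an ROI, which is exactly the source of the non-digit ROIs described above.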
Finally, I implemented two additional measures to further mitigate the issue of multiple ROIs per image. First, the app crops photographs to a smaller focus area, reducing the number of ROIs. Second, the ROI algorithm returns contours in decreasing order of bounding-box area, and the app runs prediction only on the largest one. This relies on the assumption that the digit being recognized is likely to have the largest bounding box within the focus area. Figure 8 shows a diagram of the complete processing pipeline in the app.
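The two mitigations can be sketched as small helpers working on `(top, left, bottom, right)` boxes with exclusive bottom/right edges. The centered focus rectangle and its `fraction=0.5` default are hypothetical; this section does not state the size or position of the app's focus area.

```python
def focus_rect(height, width, fraction=0.5):
    """Centered focus rectangle covering `fraction` of each side (hypothetical default)."""
    fh, fw = int(height * fraction), int(width * fraction)
    top, left = (height - fh) // 2, (width - fw) // 2
    return top, left, top + fh, left + fw


def crop_to_focus(img, rect):
    """Crop a 2-D image (list of rows) to (top, left, bottom, right), exclusive ends."""
    top, left, bottom, right = rect
    return [row[left:right] for row in img[top:bottom]]


def box_area(box):
    top, left, bottom, right = box
    return (bottom - top) * (right - left)


def pick_digit_roi(boxes):
    """Keep only the ROI with the largest bounding box, assuming the digit dominates
    the focus area. Returns None when no ROI was found."""
    return max(boxes, key=box_area, default=None)


# Usage with a blank 100x200 frame and made-up ROI boxes.
frame = [[0.0] * 200 for _ in range(100)]
rect = focus_rect(len(frame), len(frame[0]))
focus = crop_to_focus(frame, rect)
print(rect, len(focus), len(focus[0]))
print(pick_digit_roi([(0, 0, 2, 2), (2, 4, 5, 7), (0, 0, 1, 1)]))
```

When the ROI step already returns boxes largest first, taking `boxes[0]` is equivalent; `max` simply avoids depending on that ordering.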