MNIST with MXNet

Tutorial for MNIST with MXNet

NOTE: All tutorials are available in Jupyter Notebook format. To download the tutorials run curl -L | tar xz from a Jupyter Notebook Terminal running in your KUDO Kubeflow installation.

NOTE: These notebook tutorials have been built for and tested on D2iQ's KUDO for Kubeflow. Without the requisite Kubernetes operators and custom Docker images, these notebooks will likely not work.

Training MNIST with MXNet


Recognizing handwritten digits based on the MNIST (Modified National Institute of Standards and Technology) data set is the “Hello, World” example of machine learning. Each (anti-aliased) black-and-white image represents a digit from 0 to 9 and has been fit into a 28x28 pixel bounding box. The problem of recognizing digits from handwriting is, for instance, important to the postal service when automatically reading zip codes from envelopes.

What You’ll Learn

We’ll show you how to use Apache MXNet to build a model with two convolutional layers and two fully connected layers to perform the multi-class classification of images provided. An approach using the Sequential layer type is available here.

What You’ll Need

All you need is this notebook.

How to Load and Inspect the Data

Before we proceed, let’s check that we’re using the right image, that is, MXNet is available:

! pip list | grep mxnet
mxnet-cu102mkl           1.6.0


Let’s import the necessary Python modules to load the data:

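import mxnet as mx
import logging

# Show INFO-level messages so that training progress appears in the output
# (assumed setup; the original logging configuration is not shown)
logging.getLogger().setLevel(logging.INFO)

mnist = mx.test_utils.get_mnist()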

The documentation tells us we receive a dict, so for future reference we can look at the available keys:

mnist.keys()
dict_keys(['train_data', 'train_label', 'test_data', 'test_label'])

How is the data structured? We can see that by grabbing an example and asking for the shape of the array:

example = mnist["train_data"][42]
example.shape
(1, 28, 28)

So, we have (channels, height, width) with channels = 1, because the images are grayscale and MXNet stores images channels-first. For RGB images, we’d have channels = 3, and a batch of images adds a leading dimension: (batch, channels, height, width). What does the image itself look like?

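import numpy as np

from matplotlib import pyplot as plt
%matplotlib inline

# Drop the channel dimension and display the 28x28 image
# (assumed plotting call; the original one is not shown)
plt.imshow(np.reshape(example, (28, 28)), cmap="binary")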


That could be either a 1 or a 7, so let’s check what the label says:

mnist["train_label"][42]
Just to be on the safe side, let’s check that our pixel values have already been scaled into the [0, 1] range:

flattened = example.flatten()
min(flattened), max(flattened)
(0.0, 1.0)
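
Had the values still been raw 8-bit intensities in the range 0 to 255, we would have had to scale them ourselves. The following is a minimal sketch in plain Python; the helper `scale_pixels` and the sample values are hypothetical and not part of the notebook:

```python
def scale_pixels(raw_pixels):
    """Map 8-bit intensities (0..255) to floats in the [0, 1] range."""
    return [p / 255.0 for p in raw_pixels]


# Made-up intensities, for illustration only
raw = [0, 64, 128, 255]
scaled = scale_pixels(raw)
print(min(scaled), max(scaled))
```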

How to Train the Model

We shall make use of the following convenience function to create a single convolutional layer with a certain activation function followed by a pre-defined max pooling layer. Since we intend to have two such layers in our model, it makes sense to package a single layer as a re-usable function.

A Note on Activation Functions
A common choice for activation functions is a ReLU (Rectified Linear Unit). It is linear for non-negative values and zero for negative ones. The main benefits of ReLU as opposed to sigmoidal functions (e.g. logistic or `tanh`) are:
  • ReLU and its gradient are very cheap to compute;
  • Gradients are less likely to vanish, because for positive values its gradient is constant and therefore does not saturate, which for deep neural networks can accelerate convergence;
  • ReLU has a regularizing effect, because it promotes sparse representations (i.e. some nodes' outputs are zero);
  • Empirically it has been found to work well.
ReLU activation functions can cause neurons to 'die': a large, negative (learned) bias value can make the input to the activation negative for every example, which in turn leads to a zero output and a zero gradient. The neuron has thus become incapable of discriminating different input values. So-called leaky ReLU activation functions address that issue; these functions are linear but non-zero for negative values, so that their gradients are small but non-zero. ELUs, or exponential linear units, are another solution to the problem of dying neurons.
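To make the comparison concrete, here is a small sketch of these activation functions and their gradients in plain Python. The function names, the leaky slope of 0.01 and the ELU `alpha` of 1.0 are illustrative assumptions; they are not MXNet APIs.

```python
import math


def relu(x):
    """Rectified linear unit: x for positive inputs, otherwise 0."""
    return max(0.0, x)


def leaky_relu(x, slope=0.01):
    """Leaky ReLU: a small, non-zero slope for negative inputs (assumed slope)."""
    return x if x > 0 else slope * x


def elu(x, alpha=1.0):
    """Exponential linear unit: saturates smoothly towards -alpha for negative inputs."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def relu_grad(x):
    """Gradient of ReLU: constant 1 for positive inputs, 0 otherwise (a 'dead' region)."""
    return 1.0 if x > 0 else 0.0


def leaky_relu_grad(x, slope=0.01):
    """Gradient of leaky ReLU: never exactly zero."""
    return 1.0 if x > 0 else slope


def tanh_grad(x):
    """Gradient of tanh: shrinks towards 0 as |x| grows (saturation)."""
    return 1.0 - math.tanh(x) ** 2


for x in (-5.0, -0.5, 0.5, 5.0):
    print(x, relu_grad(x), leaky_relu_grad(x), round(tanh_grad(x), 4))
```

For large positive inputs the gradient of `tanh` is close to zero, whereas the gradient of ReLU stays at 1. For negative inputs ReLU has a zero gradient, while the leaky variant keeps a small one.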
def conv_layer(input_layer, kernel, num_filters, activation):
    """
    Defines a CNN layer with `activation` function and 2D max pooling with a kernel and stride of (2, 2)

    :param input_layer: input layer (an MXNet symbol)
    :param kernel: 2D convolutional kernel
    :param num_filters: number of filters to use in convolution
    :param activation: activation function (e.g. "tanh" or "relu")
    :rtype: mxnet.symbol.symbol.Symbol
    """
    # Convolution, followed by the activation, followed by 2x2 max pooling
    conv = mx.sym.Convolution(data=input_layer, kernel=kernel, num_filter=num_filters)
    act = mx.sym.Activation(data=conv, act_type=activation)
    pool = mx.sym.Pooling(data=act, pool_type="max", kernel=(2, 2), stride=(2, 2))
    return pool
A Note on CNNs
While it is not our intention to cover the basics of convolutional neural networks (CNNs), there are a few matters worth mentioning. Convolutional layers are spatial feature extractors for images. A series of convolutional kernels (of the same dimensions) is applied to the image to obtain different versions of the same base image (i.e. feature maps). These kernels, also called filters, extract patterns hierarchically. In the first layer, filters typically capture dots, edges, corners, and so on. With each additional layer, these patterns become more complex and turn from basic geometric shapes into constituents of objects and entire objects. That is why the number of filters often increases with each additional convolutional layer: to extract more complex patterns.
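
As an illustration of a filter acting as a feature extractor, the sketch below slides a 2x2 kernel over a tiny image without padding and with a stride of 1. It is written in plain Python rather than MXNet, and the helper `conv2d_valid`, the image and the kernel are made up for this example. Like most deep learning libraries, it computes a cross-correlation, i.e. the kernel is not flipped.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(
                image[i + di][j + dj] * kernel[di][dj]
                for di in range(kh)
                for dj in range(kw)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# A 4x4 image that is dark on the left and bright on the right
image = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]

# A kernel that responds to a dark-to-bright transition from left to right
vertical_edge = [
    [-1, 1],
    [-1, 1],
]

feature_map = conv2d_valid(image, vertical_edge)
for row in feature_map:
    print(row)
```

The resulting 3x3 feature map is non-zero only in the middle column, which is where the vertical edge sits in the image.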

Convolutional layers are often followed by a pooling layer to down-sample the input. This aids in lowering the computational burden as we increase the number of filters. A max pooling layer simply picks the largest pixel value in a small (rectangular) neighbourhood within a single channel (e.g. one of the R, G, or B channels). This has the effect of making features locally translation-invariant, which is often desired: whether a feature of interest is on the left or right edge of a pooling window, which is also referred to as a kernel, is largely irrelevant to the problem of image classification. Note that this may not always be a desired characteristic and depends on the size of the pooling kernel. For instance, the precise location of tissue damage in living organisms or defects on manufactured products may be very significant indeed. Pooling kernels are generally chosen to be relatively small compared to the dimensions of the input, so that the translation invariance remains local.
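
The local nature of this invariance can be seen in a short sketch of 2x2 max pooling with a stride of 2, written in plain Python. The helper `max_pool_2x2` and the toy feature maps are hypothetical.

```python
def max_pool_2x2(feature_map):
    """Max pooling with a 2x2 window and a stride of 2 (no padding)."""
    return [
        [
            max(
                feature_map[i][j],
                feature_map[i][j + 1],
                feature_map[i + 1][j],
                feature_map[i + 1][j + 1],
            )
            for j in range(0, len(feature_map[0]) - 1, 2)
        ]
        for i in range(0, len(feature_map) - 1, 2)
    ]


def single_peak(row, col, size=4):
    """Toy feature map that is zero everywhere except for one strong response."""
    return [[9 if (i, j) == (row, col) else 0 for j in range(size)] for i in range(size)]


top_left = max_pool_2x2(single_peak(0, 0))
shifted_within_window = max_pool_2x2(single_peak(1, 1))
shifted_to_next_window = max_pool_2x2(single_peak(0, 2))

print(top_left)
print(shifted_within_window)
print(shifted_to_next_window)
```

Moving the peak within the same pooling window leaves the pooled output unchanged, whereas moving it into the neighbouring window changes the output.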

Another common component of CNNs is a dropout layer. Dropout provides a mechanism for regularization that has proven successful in many applications. It is surprisingly simple: the outputs of some nodes in a specific layer are set to zero at random, that is, arbitrary nodes are temporarily removed from the network during the training step. This prevents the network from relying on any single node (a.k.a. neuron) for a feature, as each node can be dropped at random. The network therefore has to learn redundant representations of features. This is important because of what is referred to as internal covariate shift (often mentioned in connection with batch normalization): the change in the distribution of a layer's inputs caused by updates to the other layers, which can cause nodes to stop learning (i.e. updating their weights). Thanks to dropout, layers become more robust to changes, although it also limits what can be learnt (as always with regularization). Still, dropout is the neural network's equivalent of the saying that you should never put all your eggs in one basket. Layers with a high risk of overfitting (e.g. layers with many units and lots of inputs) typically have a higher dropout rate.
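
The mechanism can be sketched in a few lines of plain Python. This is the commonly used "inverted" variant, in which the surviving outputs are scaled up during training so that their expected value stays the same; the helper `dropout` is hypothetical, and the model in this tutorial does not use dropout.

```python
import random


def dropout(activations, rate, rng):
    """Zero each activation with probability `rate`; rescale the survivors by 1 / (1 - rate)."""
    keep = 1.0 - rate
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


rng = random.Random(0)  # fixed seed so that the run is repeatable
activations = [1.0] * 10

dropped = dropout(activations, rate=0.5, rng=rng)
print(dropped)

# With a rate of 0 nothing is dropped and nothing is rescaled
unchanged = dropout(activations, rate=0.0, rng=rng)
print(unchanged)
```

At prediction time no nodes are dropped, which is why the rescaling is applied during training only.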

A nice visual explanation of convolutional layers is available here. If you're curious what a CNN "sees" while training, you can have a look here.

With the following function we create an ANN with two convolutional layers (as defined above), two fully connected layers with a different number of hidden units, and an output layer with a softmax function. In each of the layers, we choose the same activation function, although that is not needed and can easily be changed.

def ann(input_layer, kernels, filters, activation, hidden_units):
    """
    Defines a neural network with two convolutional layers and two dense layers.
    To train the model it needs to be wrapped in a module.

    :param input_layer: input layer (an MXNet symbol)
    :param kernels: a list of 2D convolutional kernels
    :param filters: a list of convolutional filters
    :param activation: the activation function for all layers (e.g. "tanh" or "relu")
    :param hidden_units: a list of hidden units for the dense layers
    :rtype: mxnet.symbol.symbol.Symbol
    """
    # Two convolutional layers, each followed by an activation and max pooling
    conv1 = conv_layer(input_layer, kernels[0], filters[0], activation)
    conv2 = conv_layer(conv1, kernels[1], filters[1], activation)

    # Flatten the feature maps into a vector for the dense layers
    flattened = mx.sym.flatten(data=conv2)

    fc1 = mx.sym.FullyConnected(data=flattened, num_hidden=hidden_units[0])
    fc1_out = mx.sym.Activation(data=fc1, act_type=activation)

    fc2 = mx.sym.FullyConnected(data=fc1_out, num_hidden=hidden_units[1])

    # The softmax output turns the final layer into class probabilities
    out = mx.sym.SoftmaxOutput(data=fc2, name="softmax")
    return out
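
It can help to trace how the spatial dimensions shrink on the way through the network. The sketch below does the arithmetic in plain Python, assuming 5x5 kernels with 20 and 50 filters (the defaults of the `train` function further down), convolutions without padding and with a stride of 1, and 2x2 max pooling with a stride of 2. The helper `output_size` is hypothetical.

```python
def output_size(size, kernel, stride=1):
    """Output width (or height) of a convolution or pooling window without padding."""
    return (size - kernel) // stride + 1


size = 28      # MNIST images are 28x28 pixels
channels = 1   # a single grayscale channel

for kernel, num_filters in zip((5, 5), (20, 50)):
    size = output_size(size, kernel)       # convolution: 5x5 kernel, stride 1
    size = output_size(size, 2, stride=2)  # max pooling: 2x2 kernel, stride 2
    channels = num_filters                 # one feature map per filter
    print(f"after conv + pool: {channels} x {size} x {size}")

flattened_size = channels * size * size
print(f"inputs to the first dense layer: {flattened_size}")
```

Under these assumptions the image shrinks from 28 to 12 and then to 4 pixels per side, so that the first fully connected layer receives 50 x 4 x 4 = 800 inputs.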

Now it’s time to define an execution context for the model training. If GPUs are available, the model is trained on a GPU. If not, training defaults to the CPU.

Since we can use the context across training instances (e.g. if we want to see the effect of a different activation function), we can define it outside the main training loop:

context = mx.cpu()
if mx.context.num_gpus() > 0:
    context = mx.gpu()
epochs = 10
batch_size = 100
def train(
    data,
    context,
    epochs,
    batch_size,
    activation="tanh",
    kernels=[(5, 5), (5, 5)],
    filters=[20, 50],
    hidden_units=[500, 10],
    learning_rate=0.1,  # assumed default; the original value is not shown
):
    # Create an iterator for the training data with a fixed `batch_size`
    # We shuffle the data to ensure each batch is representative of the entire data set
    # Note that we can also use `mxnet.test_utils.get_mnist_iterator`
    train_iter = mx.io.NDArrayIter(
        data["train_data"], data["train_label"], batch_size, shuffle=True
    )

    # Create an iterator for the test data with a fixed `batch_size`
    # No need to shuffle the test data!
    eval_iter = mx.io.NDArrayIter(data["test_data"], data["test_label"], batch_size)

    # Define a symbolic (placeholder) variable
    images = mx.sym.Variable("data")

    # Create the artificial neural network based on the symbolic input
    net = ann(images, kernels, filters, activation, hidden_units)

    # To be able to train (and evaluate) a model, we need the execution `context`,
    # and we must wrap the (output) symbol `net` in a module.
    model = mx.module.Module(symbol=net, context=context)

    # Train with stochastic gradient descent and report the accuracy
    # on the test data after each epoch
    model.fit(
        train_iter,
        eval_data=eval_iter,
        optimizer="sgd",
        optimizer_params={"learning_rate": learning_rate},
        eval_metric="acc",
        batch_end_callback=mx.callback.Speedometer(batch_size, 100),
        num_epoch=epochs,
    )

    # Define another iterator and use the `score` method of our `model` module
    # to compute the accuracy.
    # Note: We can see the accuracy on the training and test data set from the logs,
    # but it's convenient to return it from the function, together with the model, so
    # we can call `model.predict(...)` on it.
    test_iter = mx.io.NDArrayIter(data["test_data"], data["test_label"], batch_size)
    acc = mx.metric.Accuracy()
    model.score(test_iter, acc)
    return model, acc
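
What the training iterator does can be sketched without MXNet: it shuffles the example indices once per pass and then hands them out in chunks of `batch_size`. The generator `minibatches` below is a hypothetical stand-in and not the implementation of the MXNet iterator.

```python
import random


def minibatches(num_examples, batch_size, shuffle, seed=0):
    """Yield lists of example indices, one list per mini-batch."""
    indices = list(range(num_examples))
    if shuffle:
        # Shuffling makes every batch a random sample of the whole data set
        random.Random(seed).shuffle(indices)
    for start in range(0, num_examples, batch_size):
        yield indices[start:start + batch_size]


# The MNIST training set has 60,000 examples
batches = list(minibatches(60000, 100, shuffle=True))
print(len(batches), len(batches[0]))
```

With 60,000 training examples and a batch size of 100, one epoch consists of 600 batches.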

Let’s invoke it with the defaults (and our pre-defined parameters):

model, acc = train(mnist, context, epochs, batch_size)
INFO:root:Epoch[0] Batch [0-100]	Speed: 1353.96 samples/sec	accuracy=0.114356
INFO:root:Epoch[0] Batch [100-200]	Speed: 1438.84 samples/sec	accuracy=0.113700
INFO:root:Epoch[0] Batch [200-300]	Speed: 1464.64 samples/sec	accuracy=0.112600
INFO:root:Epoch[0] Batch [300-400]	Speed: 1523.29 samples/sec	accuracy=0.109100
INFO:root:Epoch[0] Batch [400-500]	Speed: 1467.58 samples/sec	accuracy=0.113200
INFO:root:Epoch[0] Train-accuracy=0.112033
INFO:root:Epoch[0] Time cost=41.259
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [0-100]	Speed: 1490.50 samples/sec	accuracy=0.113069
INFO:root:Epoch[1] Batch [100-200]	Speed: 1425.70 samples/sec	accuracy=0.202600
INFO:root:Epoch[1] Batch [200-300]	Speed: 1297.02 samples/sec	accuracy=0.698700
INFO:root:Epoch[1] Batch [300-400]	Speed: 1369.69 samples/sec	accuracy=0.866400
INFO:root:Epoch[1] Batch [400-500]	Speed: 1276.43 samples/sec	accuracy=0.904900
INFO:root:Epoch[1] Train-accuracy=0.616617
INFO:root:Epoch[1] Time cost=44.176
INFO:root:Epoch[1] Validation-accuracy=0.931500
INFO:root:Epoch[2] Batch [0-100]	Speed: 1347.37 samples/sec	accuracy=0.935446
INFO:root:Epoch[2] Batch [100-200]	Speed: 1332.14 samples/sec	accuracy=0.949500
INFO:root:Epoch[2] Batch [200-300]	Speed: 1287.23 samples/sec	accuracy=0.955100
INFO:root:Epoch[2] Batch [300-400]	Speed: 1245.89 samples/sec	accuracy=0.956200
INFO:root:Epoch[2] Batch [400-500]	Speed: 1289.72 samples/sec	accuracy=0.961800
INFO:root:Epoch[2] Train-accuracy=0.953717
INFO:root:Epoch[2] Time cost=46.817
INFO:root:Epoch[2] Validation-accuracy=0.967500
INFO:root:Epoch[3] Batch [0-100]	Speed: 1187.91 samples/sec	accuracy=0.969604
INFO:root:Epoch[3] Batch [100-200]	Speed: 1267.84 samples/sec	accuracy=0.968600
INFO:root:Epoch[3] Batch [200-300]	Speed: 1298.08 samples/sec	accuracy=0.971600
INFO:root:Epoch[3] Batch [300-400]	Speed: 1300.17 samples/sec	accuracy=0.976300
INFO:root:Epoch[3] Batch [400-500]	Speed: 1224.72 samples/sec	accuracy=0.974500
INFO:root:Epoch[3] Train-accuracy=0.973083
INFO:root:Epoch[3] Time cost=47.699
INFO:root:Epoch[3] Validation-accuracy=0.977000
INFO:root:Epoch[4] Batch [0-100]	Speed: 1233.89 samples/sec	accuracy=0.975842
INFO:root:Epoch[4] Batch [100-200]	Speed: 1286.16 samples/sec	accuracy=0.981800
INFO:root:Epoch[4] Batch [200-300]	Speed: 1167.10 samples/sec	accuracy=0.979000
INFO:root:Epoch[4] Batch [300-400]	Speed: 1246.25 samples/sec	accuracy=0.979700
INFO:root:Epoch[4] Batch [400-500]	Speed: 1202.52 samples/sec	accuracy=0.982900
INFO:root:Epoch[4] Train-accuracy=0.980350
INFO:root:Epoch[4] Time cost=49.019
INFO:root:Epoch[4] Validation-accuracy=0.984200
INFO:root:Epoch[5] Batch [0-100]	Speed: 1159.72 samples/sec	accuracy=0.983366
INFO:root:Epoch[5] Batch [100-200]	Speed: 1217.18 samples/sec	accuracy=0.980900
INFO:root:Epoch[5] Batch [200-300]	Speed: 1223.35 samples/sec	accuracy=0.984400
INFO:root:Epoch[5] Batch [300-400]	Speed: 1251.70 samples/sec	accuracy=0.985000
INFO:root:Epoch[5] Batch [400-500]	Speed: 1270.82 samples/sec	accuracy=0.985400
INFO:root:Epoch[5] Train-accuracy=0.984067
INFO:root:Epoch[5] Time cost=48.736
INFO:root:Epoch[5] Validation-accuracy=0.984500
INFO:root:Epoch[6] Batch [0-100]	Speed: 1225.29 samples/sec	accuracy=0.986139
INFO:root:Epoch[6] Batch [100-200]	Speed: 1180.46 samples/sec	accuracy=0.986300
INFO:root:Epoch[6] Batch [200-300]	Speed: 1272.58 samples/sec	accuracy=0.986600
INFO:root:Epoch[6] Batch [300-400]	Speed: 1276.39 samples/sec	accuracy=0.985300
INFO:root:Epoch[6] Batch [400-500]	Speed: 1295.95 samples/sec	accuracy=0.989000
INFO:root:Epoch[6] Train-accuracy=0.986550
INFO:root:Epoch[6] Time cost=48.073
INFO:root:Epoch[6] Validation-accuracy=0.987700
INFO:root:Epoch[7] Batch [0-100]	Speed: 1213.19 samples/sec	accuracy=0.986931
INFO:root:Epoch[7] Batch [100-200]	Speed: 1157.00 samples/sec	accuracy=0.987300
INFO:root:Epoch[7] Batch [200-300]	Speed: 1182.32 samples/sec	accuracy=0.988800
INFO:root:Epoch[7] Batch [300-400]	Speed: 1250.13 samples/sec	accuracy=0.989800
INFO:root:Epoch[7] Batch [400-500]	Speed: 1226.76 samples/sec	accuracy=0.989100
INFO:root:Epoch[7] Train-accuracy=0.988350
INFO:root:Epoch[7] Time cost=49.748
INFO:root:Epoch[7] Validation-accuracy=0.987400
INFO:root:Epoch[8] Batch [0-100]	Speed: 1262.47 samples/sec	accuracy=0.990990
INFO:root:Epoch[8] Batch [100-200]	Speed: 1199.14 samples/sec	accuracy=0.989800
INFO:root:Epoch[8] Batch [200-300]	Speed: 1193.96 samples/sec	accuracy=0.990500
INFO:root:Epoch[8] Batch [300-400]	Speed: 1087.18 samples/sec	accuracy=0.988200
INFO:root:Epoch[8] Batch [400-500]	Speed: 1154.58 samples/sec	accuracy=0.989800
INFO:root:Epoch[8] Train-accuracy=0.989850
INFO:root:Epoch[8] Time cost=50.834
INFO:root:Epoch[8] Validation-accuracy=0.988700
INFO:root:Epoch[9] Batch [0-100]	Speed: 1179.14 samples/sec	accuracy=0.990297
INFO:root:Epoch[9] Batch [100-200]	Speed: 1226.71 samples/sec	accuracy=0.990800
INFO:root:Epoch[9] Batch [200-300]	Speed: 1208.98 samples/sec	accuracy=0.989800
INFO:root:Epoch[9] Batch [300-400]	Speed: 1272.81 samples/sec	accuracy=0.991500
INFO:root:Epoch[9] Batch [400-500]	Speed: 1292.01 samples/sec	accuracy=0.989900
INFO:root:Epoch[9] Train-accuracy=0.990800
INFO:root:Epoch[9] Time cost=48.257
INFO:root:Epoch[9] Validation-accuracy=0.987700
A Note on Accuracy
We can see from the logs that the accuracies on the training and test data are relatively close. A training accuracy that is significantly higher than the test accuracy is an indication of a model that overfits: it picks up on noise rather than the signal that's present in the data. This model, therefore, does a decent job of classifying digits in images.
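
As a small worked example, the sketch below computes an accuracy from predictions and labels, and the gap between the training and validation accuracy taken from the final epoch of the log above. The helper `accuracy` and the toy predictions are hypothetical.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Toy example: three out of four predictions are correct
print(accuracy([1, 2, 3, 4], [1, 2, 0, 4]))

# Values from Epoch[9] in the log above
train_accuracy = 0.990800
validation_accuracy = 0.987700
gap = train_accuracy - validation_accuracy
print(f"gap between training and validation accuracy: {gap:.4f}")
```

A gap of roughly 0.3 percentage points is small, which is consistent with a model that does not overfit much.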

Because we wrapped our training process in a function, we can easily see the impact of a different activation function:

model_relu, acc_relu = train(mnist, context, epochs, batch_size, activation="relu")
INFO:root:Epoch[0] Batch [0-100]	Speed: 1693.27 samples/sec	accuracy=0.111089
INFO:root:Epoch[0] Batch [100-200]	Speed: 1739.83 samples/sec	accuracy=0.118500
INFO:root:Epoch[0] Batch [200-300]	Speed: 1768.59 samples/sec	accuracy=0.105400
INFO:root:Epoch[0] Batch [300-400]	Speed: 1870.50 samples/sec	accuracy=0.112700
INFO:root:Epoch[0] Batch [400-500]	Speed: 1777.87 samples/sec	accuracy=0.111900
INFO:root:Epoch[0] Train-accuracy=0.112017
INFO:root:Epoch[0] Time cost=33.978
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [0-100]	Speed: 1813.15 samples/sec	accuracy=0.117228
INFO:root:Epoch[1] Batch [100-200]	Speed: 1764.89 samples/sec	accuracy=0.107800
INFO:root:Epoch[1] Batch [200-300]	Speed: 1798.58 samples/sec	accuracy=0.113700
INFO:root:Epoch[1] Batch [300-400]	Speed: 1880.55 samples/sec	accuracy=0.235100
INFO:root:Epoch[1] Batch [400-500]	Speed: 1829.45 samples/sec	accuracy=0.713800
INFO:root:Epoch[1] Train-accuracy=0.360967
INFO:root:Epoch[1] Time cost=32.823
INFO:root:Epoch[1] Validation-accuracy=0.935800
INFO:root:Epoch[2] Batch [0-100]	Speed: 1876.07 samples/sec	accuracy=0.928416
INFO:root:Epoch[2] Batch [100-200]	Speed: 1837.19 samples/sec	accuracy=0.942200
INFO:root:Epoch[2] Batch [200-300]	Speed: 1892.24 samples/sec	accuracy=0.952900
INFO:root:Epoch[2] Batch [300-400]	Speed: 1895.88 samples/sec	accuracy=0.961800
INFO:root:Epoch[2] Batch [400-500]	Speed: 1943.46 samples/sec	accuracy=0.962000
INFO:root:Epoch[2] Train-accuracy=0.952467
INFO:root:Epoch[2] Time cost=31.772
INFO:root:Epoch[2] Validation-accuracy=0.970900
INFO:root:Epoch[3] Batch [0-100]	Speed: 1884.76 samples/sec	accuracy=0.965248
INFO:root:Epoch[3] Batch [100-200]	Speed: 1930.55 samples/sec	accuracy=0.974600
INFO:root:Epoch[3] Batch [200-300]	Speed: 1932.68 samples/sec	accuracy=0.973000
INFO:root:Epoch[3] Batch [300-400]	Speed: 1919.08 samples/sec	accuracy=0.975500
INFO:root:Epoch[3] Batch [400-500]	Speed: 1875.23 samples/sec	accuracy=0.977300
INFO:root:Epoch[3] Train-accuracy=0.973900
INFO:root:Epoch[3] Time cost=31.617
INFO:root:Epoch[3] Validation-accuracy=0.981600
INFO:root:Epoch[4] Batch [0-100]	Speed: 1882.73 samples/sec	accuracy=0.976832
INFO:root:Epoch[4] Batch [100-200]	Speed: 1879.33 samples/sec	accuracy=0.978900
INFO:root:Epoch[4] Batch [200-300]	Speed: 1696.34 samples/sec	accuracy=0.981700
INFO:root:Epoch[4] Batch [300-400]	Speed: 1771.81 samples/sec	accuracy=0.982200
INFO:root:Epoch[4] Batch [400-500]	Speed: 1779.80 samples/sec	accuracy=0.982500
INFO:root:Epoch[4] Train-accuracy=0.980533
INFO:root:Epoch[4] Time cost=33.584
INFO:root:Epoch[4] Validation-accuracy=0.983500
INFO:root:Epoch[5] Batch [0-100]	Speed: 1888.09 samples/sec	accuracy=0.984653
INFO:root:Epoch[5] Batch [100-200]	Speed: 1830.69 samples/sec	accuracy=0.987400
INFO:root:Epoch[5] Batch [200-300]	Speed: 1752.42 samples/sec	accuracy=0.985500
INFO:root:Epoch[5] Batch [300-400]	Speed: 1716.19 samples/sec	accuracy=0.983000
INFO:root:Epoch[5] Batch [400-500]	Speed: 1728.71 samples/sec	accuracy=0.985200
INFO:root:Epoch[5] Train-accuracy=0.985150
INFO:root:Epoch[5] Time cost=33.577
INFO:root:Epoch[5] Validation-accuracy=0.983600
INFO:root:Epoch[6] Batch [0-100]	Speed: 1829.37 samples/sec	accuracy=0.988119
INFO:root:Epoch[6] Batch [100-200]	Speed: 1879.84 samples/sec	accuracy=0.988000
INFO:root:Epoch[6] Batch [200-300]	Speed: 1866.12 samples/sec	accuracy=0.988600
INFO:root:Epoch[6] Batch [300-400]	Speed: 1901.34 samples/sec	accuracy=0.988500
INFO:root:Epoch[6] Batch [400-500]	Speed: 1850.14 samples/sec	accuracy=0.986400
INFO:root:Epoch[6] Train-accuracy=0.988083
INFO:root:Epoch[6] Time cost=32.148
INFO:root:Epoch[6] Validation-accuracy=0.985400
INFO:root:Epoch[7] Batch [0-100]	Speed: 1834.51 samples/sec	accuracy=0.990891
INFO:root:Epoch[7] Batch [100-200]	Speed: 1873.71 samples/sec	accuracy=0.989500
INFO:root:Epoch[7] Batch [200-300]	Speed: 1878.99 samples/sec	accuracy=0.989000
INFO:root:Epoch[7] Batch [300-400]	Speed: 1879.32 samples/sec	accuracy=0.989200
INFO:root:Epoch[7] Batch [400-500]	Speed: 1815.73 samples/sec	accuracy=0.989600
INFO:root:Epoch[7] Train-accuracy=0.989600
INFO:root:Epoch[7] Time cost=32.270
INFO:root:Epoch[7] Validation-accuracy=0.987400
INFO:root:Epoch[8] Batch [0-100]	Speed: 1732.95 samples/sec	accuracy=0.991188
INFO:root:Epoch[8] Batch [100-200]	Speed: 1635.98 samples/sec	accuracy=0.988900
INFO:root:Epoch[8] Batch [200-300]	Speed: 1799.76 samples/sec	accuracy=0.992300
INFO:root:Epoch[8] Batch [300-400]	Speed: 1736.55 samples/sec	accuracy=0.992600
INFO:root:Epoch[8] Batch [400-500]	Speed: 1825.21 samples/sec	accuracy=0.990400
INFO:root:Epoch[8] Train-accuracy=0.991017
INFO:root:Epoch[8] Time cost=34.618
INFO:root:Epoch[8] Validation-accuracy=0.987800
INFO:root:Epoch[9] Batch [0-100]	Speed: 1749.84 samples/sec	accuracy=0.991881
INFO:root:Epoch[9] Batch [100-200]	Speed: 1813.36 samples/sec	accuracy=0.991900
INFO:root:Epoch[9] Batch [200-300]	Speed: 1832.93 samples/sec	accuracy=0.992500
INFO:root:Epoch[9] Batch [300-400]	Speed: 1561.59 samples/sec	accuracy=0.994700
INFO:root:Epoch[9] Batch [400-500]	Speed: 1848.26 samples/sec	accuracy=0.992400
INFO:root:Epoch[9] Train-accuracy=0.992583
INFO:root:Epoch[9] Time cost=34.517
INFO:root:Epoch[9] Validation-accuracy=0.987800
print(f"Accuracy: {acc} (tanh) vs {acc_relu} (ReLU)")
Accuracy: EvalMetric: {'accuracy': 0.9877} (tanh) vs EvalMetric: {'accuracy': 0.9878} (ReLU)

This is just a simple example of fiddling with hyperparameters. If we wanted to tune hyperparameters automatically (with Katib), we could simply pass these hyperparameters as arguments to a container that contains a script with all necessary imports and functions to run the train-and-evaluate process.
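
The entry point of such a script could parse the hyperparameters from the command line, for instance as sketched below. The flag names and default values are assumptions made for illustration; they are not prescribed by Katib or by this tutorial.

```python
import argparse


def parse_args(argv=None):
    """Parse hyperparameters passed to the (hypothetical) training script."""
    parser = argparse.ArgumentParser(description="Train an MNIST classifier")
    parser.add_argument("--activation", choices=["tanh", "relu"], default="tanh")
    parser.add_argument("--learning-rate", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=100)
    return parser.parse_args(argv)


# Simulate the arguments a tuning job might pass to the container
args = parse_args(["--activation", "relu", "--learning-rate", "0.05"])
print(args.activation, args.learning_rate, args.epochs, args.batch_size)
```

The parsed values would then be handed to a function such as `train`, and the script would print the resulting accuracy so that the tuning job can collect it.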

How to Predict with a Trained Model

Batch predictions based on a trained model are easy:

def test_iter(data=mnist, batch_size=100):
    # No labels are needed for prediction, hence `None`
    return mx.io.NDArrayIter(data["test_data"], None, batch_size=batch_size)

prob = model.predict(test_iter())
prob_relu = model_relu.predict(test_iter())

If we pick a random example, we can see the probabilities per category (i.e. digit):

prob[24], prob_relu[24]
(
 [3.88936355e-10 5.50869439e-09 1.38218335e-08 1.33717464e-11
  9.99999046e-01 1.08974885e-09 3.25030669e-07 8.28973228e-08
  1.59384072e-07 3.14834097e-07]
 <NDArray 10 @cpu(0)>,
 [6.4347176e-08 7.3004806e-08 3.4015596e-08 5.8196594e-09 9.9998724e-01
  1.2476704e-07 1.9580840e-07 3.0599865e-06 1.2806310e-08 9.1640350e-06]
 <NDArray 10 @cpu(0)>)
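
The ten values per example are non-negative and sum to (approximately) one, because the last layer of the network is a softmax. A minimal sketch of that function in plain Python follows; the helper `softmax` and the input scores are made up for illustration.

```python
import math


def softmax(scores):
    """Turn arbitrary real-valued scores into probabilities that sum to one."""
    largest = max(scores)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]  # hypothetical outputs of the last dense layer
probabilities = softmax(scores)
print(probabilities, sum(probabilities))
```

The largest score always receives the largest probability, so the predicted class does not change when the softmax is applied.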

The highest probability is observed at index 4 (i.e. the digit ‘4’):

np.argmax(prob[24]), np.argmax(prob_relu[24])
(
 [4.]
 <NDArray 1 @cpu(0)>,
 [4.]
 <NDArray 1 @cpu(0)>)

Since we did not shuffle the data for the iterator test_iter that was used to generate the probabilities, we can use the same index to obtain the label and verify that the model predicts the digit correctly:

mnist["test_label"][24]
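
The same check can be sketched without MXNet, using the probabilities printed above. The helpers `predicted_digit` and `is_correct` are hypothetical. The label itself is not shown here, so `is_correct` takes it as an argument.

```python
# Probabilities for test example 24, copied from the output above
prob_tanh = [
    3.88936355e-10, 5.50869439e-09, 1.38218335e-08, 1.33717464e-11,
    9.99999046e-01, 1.08974885e-09, 3.25030669e-07, 8.28973228e-08,
    1.59384072e-07, 3.14834097e-07,
]
prob_relu = [
    6.4347176e-08, 7.3004806e-08, 3.4015596e-08, 5.8196594e-09, 9.9998724e-01,
    1.2476704e-07, 1.9580840e-07, 3.0599865e-06, 1.2806310e-08, 9.1640350e-06,
]


def predicted_digit(probabilities):
    """Index of the largest probability, i.e. the predicted digit."""
    return max(range(len(probabilities)), key=lambda i: probabilities[i])


def is_correct(probabilities, label):
    """Compare the predicted digit with a label (labels are stored as floats)."""
    return predicted_digit(probabilities) == int(label)


print(predicted_digit(prob_tanh), predicted_digit(prob_relu))
```

Both models predict the digit 4 for this example. Passing the label of test example 24 to `is_correct` confirms whether that prediction is right.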