Deep Learning

Deep Learning Examples

Explore deep learning examples, and learn how you can get started in MATLAB. 

Training a Model from Scratch

In this example, we want to train a convolutional neural network (CNN) to identify handwritten digits. We will use data from the MNIST dataset, which contains 60,000 training images and 10,000 test images of the handwritten digits 0-9. Here is a random sample of 25 handwritten digits from the MNIST dataset:

By using a simple dataset, we'll be able to cover all the key steps in the deep learning workflow without dealing with challenges such as processing power or datasets that are too large to fit into memory. You can apply the workflow described here to more complex deep learning problems and larger data sets.

If you are just getting started with applying deep learning, another advantage of this dataset is that you can train a network on it without investing in an expensive GPU.
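If you do want control over where training runs, the trainingOptions function used later in this example accepts an 'ExecutionEnvironment' argument. A minimal sketch that forces CPU training:

% Force training onto the CPU; the default 'auto' uses a supported
% GPU when one is available
options = trainingOptions('sgdm', 'ExecutionEnvironment', 'cpu');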

Even though the dataset is simple, with the right deep learning model and training options, it is possible to achieve over 99% accuracy. So how do we create a model that will get us to that point?

This will be an iterative process in which we build on previous training results to figure out how to approach the training problem. The steps are as follows:

1. Accessing the Data

We begin by downloading the MNIST images into MATLAB. Datasets are stored in many different file types. This data is stored as binary files, which MATLAB can quickly use and reshape into images.
These lines of code read the original binary file (here, the standard MNIST training image file) and create a 4D array of all the training images.

% Open the binary MNIST training image file (big-endian IDX format)
fid = fopen('train-images-idx3-ubyte', 'r', 'b');
header = fread(fid, 4, 'uint32');   % magic number, #images, #rows, #cols
numImgs = header(2); numRows = header(3); numCols = header(4);
rawImgDataTrain = uint8(fread(fid, numImgs * numRows * numCols, 'uint8'));
fclose(fid);

% Reshape the data into a 4D array (height x width x channels x images)
rawImgDataTrain = reshape(rawImgDataTrain, [numRows, numCols, numImgs]);
imgDataTrain = reshape(rawImgDataTrain, [numRows, numCols, 1, numImgs]);

We can check the size and class of the data by typing whos imgDataTrain in the command window.

>> whos imgDataTrain
  Name              Size             Bytes        Class
  imgDataTrain      28x28x1x60000    47040000     uint8

These images are quite small – only 28 x 28 pixels – and there are 60,000 training images in total.
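As a quick sanity check, we can reproduce a random sample like the one shown earlier. This minimal sketch assumes the imgDataTrain array created above:

% Display 25 randomly chosen training digits in a 5-by-5 grid
idx = randperm(size(imgDataTrain, 4), 25);
figure
for ii = 1:25
    subplot(5, 5, ii)
    imshow(imgDataTrain(:, :, 1, idx(ii)))
end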

The next task would be image labeling, but since the MNIST images come with labels, we can skip that tedious step and quickly move on to building our neural network.
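For completeness, the labels live in a companion binary file: an 8-byte header followed by one byte per image. Here is a sketch of reading them, assuming the standard MNIST label file name; note that trainNetwork expects the labels as a categorical array.

% Read the training labels and convert to categorical for trainNetwork
fid = fopen('train-labels-idx1-ubyte', 'r', 'b');  % big-endian IDX file
fread(fid, 2, 'uint32');                           % magic number, label count
labelsTrain = categorical(fread(fid, numImgs, 'uint8'));
fclose(fid);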

2. Creating and Configuring Network Layers

We'll start by building a CNN, the most common kind of deep learning network.

About CNNs

A CNN takes an image, passes it through the network layers, and outputs a final class. The network can have tens or hundreds of layers, with each layer learning to detect different features of an image. Filters are applied to each training image at different resolutions, and the output of each convolved image is used as the input to the next layer. The filters can start as very simple features, such as brightness and edges, and increase in complexity to features that uniquely define the object as the layers progress.


Since we're training the CNN from scratch, we must first specify which layers it will contain and in what order.

layers = [
    imageInputLayer([28 28 1])               % 28x28 grayscale input images

    convolution2dLayer(3,16,'Padding',1)     % 16 filters of size 3x3
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)          % downsample by a factor of 2

    convolution2dLayer(3,32,'Padding',1)     % 32 filters of size 3x3
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,64,'Padding',1)     % 64 filters of size 3x3
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)                  % one output per digit class
    softmaxLayer
    classificationLayer];

You can learn more about all of these layers in the documentation.
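Before training, it is worth sanity-checking the layer array. In recent releases of Deep Learning Toolbox, the analyzeNetwork function reports each layer's output size and learnable parameters and flags configuration errors:

% Inspect the layer array before training
analyzeNetwork(layers)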

3. Training the Network

First, we select training options. There are many options available. The table shows the most commonly used options.

Commonly Used Training Options

Plot of training progress
Definition: The plot shows the mini-batch loss and accuracy. It includes a stop button that lets you halt network training at any point.
Hint: Use ('Plots','training-progress') to plot the progress of the network as it trains.

Max epochs
Definition: An epoch is one full pass of the training algorithm over the entire training set.
Hint: Use ('MaxEpochs',20). The more epochs specified, the longer the network will train, but the accuracy may improve with each epoch.

Mini-batch size
Definition: A mini-batch is a subset of the training data set that is processed at the same time.
Hint: Use ('MiniBatchSize',64). The larger the mini-batch, the faster the training, but the maximum size is determined by the available GPU memory. If you get a memory error when training, reduce the mini-batch size.

Learning rate
Definition: This is a major parameter that controls the speed of training.
Hint: Set it with ('InitialLearnRate',0.01). A lower learning rate can give a more accurate result, but the network may take longer to train.
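For reference, the four options above could be combined in a single trainingOptions call. The values here are just the illustrative ones from the table:

% All four commonly used options together (illustrative values)
options = trainingOptions('sgdm', ...
    'Plots', 'training-progress', ...
    'MaxEpochs', 20, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.01);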

We begin by specifying two options: the training progress plot and the mini-batch size.

miniBatchSize = 8192;
options = trainingOptions('sgdm', ...   % stochastic gradient descent with momentum
    'MiniBatchSize', miniBatchSize, ...
    'Plots', 'training-progress');

net = trainNetwork(imgDataTrain, labelsTrain, layers, options);

We then run the network and monitor its progress.

4. Checking Network Accuracy

Our goal is to have the accuracy of the model increase over time. As the network trains, the progress plot appears.

We'll try altering the training options and the network configuration.

Changing Training Options

First, we'll adjust the learning rate. We set the initial learning rate to be much lower than the default rate of 0.01.

'InitialLearnRate', 0.0001
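In context, the updated call might look like this, keeping the earlier options and adding the lower learning rate:

% Same options as before, plus a much lower initial learning rate
options = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'Plots', 'training-progress', ...
    'InitialLearnRate', 0.0001);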

As a result of changing that one parameter, we get a much better result: nearly 90%!

For some applications, this result would be satisfactory, but you may recall that we're aiming for 99%.

Changing the Network Configuration

Getting from 90% to 99% requires a deeper network and many rounds of trial and error. We add more layers, including batch normalization layers, which help the network converge faster (reach the point at which it responds correctly to new input).

The network is now “deeper”. This time, we'll change the network but leave the training options the same as they were before.
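Retraining uses the same call as before, now with the deeper layer array:

% Retrain with the deeper layer array and unchanged training options
net = trainNetwork(imgDataTrain, labelsTrain, layers, options);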

After the network has trained, we test it on the 10,000 test images (imgDataTest and labelsTest below, read from the MNIST test files in the same way as the training data).

% Classify the test set and compare predictions to the true labels
predLabelsTest = classify(net, imgDataTest);
testAccuracy = sum(predLabelsTest == labelsTest) / numel(labelsTest)

testAccuracy = 0.9913
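A single accuracy number hides which digits the network confuses with one another. In releases that include confusionchart (R2018b and later), a confusion chart gives a per-class breakdown:

% Per-class breakdown of correct and incorrect test predictions
confusionchart(labelsTest, predLabelsTest)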

We can now use it to identify handwritten digits in online images, or even in a live video stream.
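For example, here is a minimal sketch of classifying a single new image. The file name digit.png is hypothetical, and the image must match the network's 28-by-28 grayscale input:

% Classify one new image (file name is hypothetical)
img = imread('digit.png');
if size(img, 3) == 3
    img = rgb2gray(img);          % convert color images to grayscale
end
img = imresize(img, [28 28]);     % match the network input size
label = classify(net, img)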

When creating a network from scratch, you are responsible for determining the network configuration. This approach gives you the most control over the network, and can produce impressive results, but it requires an understanding of the structure of a neural network and the many options for layer types and configuration.


Learn More