MATLAB Examples

Train a Convolutional Neural Network for Regression

This example shows how to fit a regression model using convolutional neural networks to predict the angles of rotation of handwritten digits.

Convolutional neural networks (CNNs or ConvNets) are essential tools for deep learning, and are especially suited for analyzing image data. For example, you can use CNNs to classify images. To predict continuous data such as angles and distances, you can include a regression layer at the end of the network.

The example constructs a convolutional neural network architecture, trains a network, and uses the trained network to predict angles of rotated, handwritten digits. These predictions are useful for optical character recognition.

Optionally, you can use imrotate (Image Processing Toolbox™) to rotate the images, and boxplot (Statistics and Machine Learning Toolbox™) to create a residual box plot.

Load Training Data

The network is trained on a collection of 5000 synthetic images of handwritten digits, each with a corresponding angle of rotation.

Load the digit training set as 4-D array data using digitTrain4DArrayData.

[trainImages,~,trainAngles] = digitTrain4DArrayData;
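
The 4-D array layout is height-by-width-by-channels-by-observations, so the fourth dimension indexes the individual images. The following sketch illustrates that layout with a placeholder array; demoImages and the other variable names are made up for illustration and are not part of the digit data set.

```matlab
% Placeholder array with the same layout as the digit data:
% height x width x channels x observations.
% (demoImages is a made-up stand-in, not the real data set.)
demoImages = zeros(28,28,1,50);

numDemoImages = size(demoImages,4);   % number of observations: 50
oneImage = demoImages(:,:,:,7);       % 7th image
imageSize = size(oneImage);           % [28 28], trailing singleton dropped
```

The same indexing pattern, trainImages(:,:,:,k), selects the k-th training digit.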

Display 20 random sample training digits using imshow.

numTrainImages = size(trainImages,4);

figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)

    imshow(trainImages(:,:,:,idx(i)))
    drawnow
end

Create Network Layers

To solve the regression problem, create the layers of the network and include a regression layer at the end of the network.

The first layer defines the size and type of the input data. The input images are 28-by-28-by-1. Create an image input layer of the same size as the training images.

The middle layers of the network define the core architecture of the network. Create a 2-D convolutional layer with 25 filters of size 12-by-12, followed by a ReLU layer.

The final layers define the size and type of output data. For regression problems, a fully connected layer must precede the regression layer at the end of the network. Create a fully connected output layer of size 1 and a regression layer.

Combine all the layers together in a Layer array.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];
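
To see how these layer sizes fit together, the following sketch works out the convolution output size and the number of learnable parameters by hand. It assumes a stride of 1 and no padding, which are the defaults when convolution2dLayer is called with only a filter size and a filter count; the variable names are illustrative.

```matlab
% Sizes taken from the layer array in this example.
inputSize  = [28 28 1];
filterSize = 12;
numFilters = 25;
stride     = 1;   % assumed default stride
padding    = 0;   % assumed default padding

% Spatial output size of the convolution: (in - filter + 2*pad)/stride + 1
convOut = (inputSize(1:2) - filterSize + 2*padding)/stride + 1;   % [17 17]

% Number of activations that feed the fully connected layer
numFeatures = prod(convOut)*numFilters;                           % 7225

% Learnable parameters (weights + biases)
convParams = filterSize^2*inputSize(3)*numFilters + numFilters;   % 3625
fcParams   = numFeatures*1 + 1;                                   % 7226
```

Under these assumptions, the single fully connected output is a weighted sum of 7225 activations, which is the predicted angle.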

Train Network

Create the network training options. Set the initial learn rate to 0.001. To reduce training time, lower the value of 'MaxEpochs'.

options = trainingOptions('sgdm','InitialLearnRate',0.001, ...
    'MaxEpochs',15);
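
The value of 'MaxEpochs' determines the total number of training iterations. As a rough sketch, assuming the default mini-batch size of 128 (the options above do not set 'MiniBatchSize') and assuming that an incomplete final mini-batch in each epoch is dropped, the iteration count can be estimated as follows. The result agrees with the final iteration count of 585 shown in the training log for this example.

```matlab
numObservations = 5000;   % training images in this example
miniBatchSize   = 128;    % assumed: the trainingOptions default
maxEpochs       = 15;

% Only complete mini-batches are counted in each epoch
iterationsPerEpoch = floor(numObservations/miniBatchSize);   % 39
totalIterations    = iterationsPerEpoch*maxEpochs;           % 585
```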

Train the network using trainNetwork. This command uses a compatible GPU if one is available; otherwise, it uses the CPU. Training on a GPU requires a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. Training can take a few minutes, especially on a CPU.

net = trainNetwork(trainImages,trainAngles,layers,options)
Training on single CPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |     RMSE     |     Rate     |
|=========================================================================================|
|            1 |            1 |         2.17 |     392.5879 |        28.02 |       0.0010 |
|            2 |           50 |        88.96 |      96.0257 |        13.86 |       0.0010 |
|            3 |          100 |       146.67 |      55.5571 |        10.54 |       0.0010 |
|            4 |          150 |       212.20 |      64.4336 |        11.35 |       0.0010 |
|            6 |          200 |       268.02 |      36.3027 |         8.52 |       0.0010 |
|            7 |          250 |       329.30 |      53.0679 |        10.30 |       0.0010 |
|            8 |          300 |       387.20 |      35.7902 |         8.46 |       0.0010 |
|            9 |          350 |       443.90 |      34.8994 |         8.35 |       0.0010 |
|           11 |          400 |       507.63 |      24.8728 |         7.05 |       0.0010 |
|           12 |          450 |       569.95 |      33.0423 |         8.13 |       0.0010 |
|           13 |          500 |       633.77 |      27.0265 |         7.35 |       0.0010 |
|           15 |          550 |       732.29 |      18.1709 |         6.03 |       0.0010 |
|           15 |          585 |       787.83 |      19.0721 |         6.18 |       0.0010 |
|=========================================================================================|

net = 

  SeriesNetwork with properties:

    Layers: [5x1 nnet.cnn.layer.Layer]
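
The Mini-batch Loss and Mini-batch RMSE columns in the training log are related. The values are consistent with the loss being half the mean squared error, that is, loss = RMSE^2/2. The following sketch checks this against three rows copied from the log; the small differences are expected because the RMSE is displayed to only two decimal places.

```matlab
% Values copied from the first, second, and last rows of the training log
miniBatchRMSE = [28.02 13.86 6.18];
miniBatchLoss = [392.5879 96.0257 19.0721];

% Half mean squared error implied by the displayed RMSE
halfMSE = miniBatchRMSE.^2/2;

% Largest gap between the implied and the reported loss
maxDifference = max(abs(halfMSE - miniBatchLoss));
```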

Examine the details of the network architecture contained in the Layers property of net.

net.Layers
ans = 

  5x1 Layer array with layers:

     1   'imageinput'         Image Input         28x28x1 images with 'zerocenter' normalization
     2   'conv'               Convolution         25 12x12x1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu'               ReLU                ReLU
     4   'fc'                 Fully Connected     1 fully connected layer
     5   'regressionoutput'   Regression Output   mean-squared-error with response 'Response'

Test Network

Test the performance of the network by evaluating its prediction accuracy on held-out test data.

Load the digit test set.

[testImages,~,testAngles] = digitTest4DArrayData;

Use predict to predict the angles of rotation of the test images.

predictedTestAngles = predict(net,testImages);

Evaluate Performance

Evaluate the performance of the model by calculating:

  1. The percentage of predictions within an acceptable error margin
  2. The root-mean-square error (RMSE) between the predicted and actual angles of rotation

Calculate the prediction error between the predicted and actual angles of rotation.

predictionError = testAngles - predictedTestAngles;

Calculate the number of predictions within an acceptable error margin of the true angles. Set the threshold to 10 degrees, then calculate the fraction of predictions within this threshold (multiply by 100 to express it as a percentage).

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numTestImages = size(testImages,4);

accuracy = numCorrect/numTestImages
accuracy =

    0.8484
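
The value 0.8484 is a fraction, so about 85% of the test predictions fall within 10 degrees of the true angle. The following sketch repeats the same calculation on a handful of made-up angles so that each step can be checked by hand; the data and variable names are illustrative only.

```matlab
% Made-up angles, in degrees, for illustration only
trueAngles      = [10; -20; 35;  0; -5];
predictedAngles = [12; -35; 30; 11; -5];

errors = trueAngles - predictedAngles;   % [-2; 15; 5; -11; 0]
withinThreshold = abs(errors) < 10;      % [1; 0; 1; 0; 1]

% The mean of a logical vector is the fraction of true entries
fractionCorrect = mean(withinThreshold);   % 0.6
percentCorrect  = 100*fractionCorrect;     % 60
```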

Use the root-mean-square error (RMSE) to measure the differences between the predicted and actual angles of rotation.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse =

  single

    7.3450
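
The word single in the output indicates that the result is stored in single precision. This happens because arithmetic on single-precision values returns single-precision results. The following sketch shows the effect on made-up error values and one way to convert the result if later code expects double precision.

```matlab
% Made-up single-precision errors, for illustration only
demoError = single([3; -4]);

demoRMSE = sqrt(mean(demoError.^2));   % sqrt((9 + 16)/2), about 3.5355
resultClass = class(demoRMSE);         % 'single'

% Convert explicitly if later code expects double precision
demoRMSEDouble = double(demoRMSE);
```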

If the accuracy is too low, or the RMSE is too high, then try increasing the value of 'MaxEpochs' in the call to trainingOptions.

Display Box Plot of Residuals for Each Digit Class

Calculate the residuals.

residuals = testAngles - predictedTestAngles;

The boxplot function requires a matrix where each column corresponds to the residuals for each digit class.

The test data groups images by digit classes 0–9 with 500 examples of each. Use reshape to group the residuals by digit class.

residualMatrix = reshape(residuals,500,10);
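
This grouping works because reshape fills the output column by column and the test data is ordered by digit class. The following sketch shows the same idea on a small made-up vector with 3 classes and 2 examples per class.

```matlab
% Made-up residuals for 3 classes with 2 examples each, ordered by class
demoResiduals = [1; 2; 10; 20; 100; 200];

% reshape fills column by column, so each column holds one class
demoMatrix = reshape(demoResiduals,2,3);
% demoMatrix =
%      1    10   100
%      2    20   200
```

If the data were not ordered by class, the residuals would need to be sorted by label before reshaping.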

Each column of residualMatrix contains the residuals for one digit class. Create a residual box plot for each digit using boxplot (Statistics and Machine Learning Toolbox).

figure
boxplot(residualMatrix, ...
    'Labels',{'0','1','2','3','4','5','6','7','8','9'})

xlabel('Digit Class')
ylabel('Degrees Error')
title('Residuals')

The digit classes with the highest accuracy have residuals with a mean close to zero and little variance.
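
A numeric summary can complement the box plot and does not require Statistics and Machine Learning Toolbox. The following sketch computes the per-class mean, standard deviation, and RMSE column by column on a small made-up matrix; the same calls could be applied to residualMatrix.

```matlab
% Made-up residual matrix: 4 examples (rows) for 2 classes (columns)
demoMatrix = [ 1 -2
              -1  2
               2 -4
              -2  4];

classMean = mean(demoMatrix,1);            % [0 0]
classStd  = std(demoMatrix,0,1);           % about [1.8257 3.6515]
classRMSE = sqrt(mean(demoMatrix.^2,1));   % about [1.5811 3.1623]
```

Both made-up classes have a mean of zero, but the second has twice the spread, so its predictions are less reliable.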

Correct Digit Rotations

You can use functions from Image Processing Toolbox to straighten the digits and display them together. Rotate 49 sample digits according to their predicted angles of rotation using imrotate (Image Processing Toolbox).

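idx = randperm(numTestImages,49);
% Preallocate the output with the same data type as the test images
imagesRotated = zeros(28,28,1,numel(idx),'like',testImages);
for i = 1:numel(idx)
    % Use img rather than image to avoid shadowing the built-in image function
    img = testImages(:,:,:,idx(i));
    predictedAngle = predictedTestAngles(idx(i));

    imagesRotated(:,:,:,i) = imrotate(img,predictedAngle,'bicubic','crop');
end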
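
The 'crop' option matters here because every rotated image must keep the 28-by-28 size to fit into imagesRotated(:,:,:,i). The following sketch, which requires Image Processing Toolbox, compares 'crop' with 'loose' on a made-up test pattern; demoImage and the 30-degree angle are illustrative choices.

```matlab
% Requires Image Processing Toolbox. demoImage is a made-up test pattern.
demoImage = zeros(28,28);
demoImage(5:24,13:16) = 1;   % vertical bar

cropped = imrotate(demoImage,30,'bicubic','crop');    % same size as the input
loose   = imrotate(demoImage,30,'bicubic','loose');   % enlarged to fit the rotated image

sizeCropped = size(cropped);   % [28 28]
sizeLoose   = size(loose);     % larger than [28 28]
```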

Display the original digits with their corrected rotations. You can use montage (Image Processing Toolbox) to display the digits together in a single image.

figure
subplot(1,2,1)
montage(testImages(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')