MATLAB Examples

Resume Training from a Checkpoint Network

This example shows how to save checkpoint networks while training a convolutional neural network and resume training from a previously saved network.

Load Sample Data

Load the sample data as a 4-D array.

[XTrain,TTrain] = digitTrain4DArrayData;

Display the size of XTrain.

size(XTrain)
ans =

          28          28           1        5000

digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where 28 is the height and 28 is the width of the images, 1 is the number of channels, and 5000 is the number of synthetic images of handwritten digits. TTrain is a categorical vector containing the labels for each observation.
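To confirm that all ten digit classes are present (and roughly balanced), you can summarize the categorical label vector:

```matlab
% Tabulate the number of observations per digit class
summary(TTrain)
```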

Display some of the images in XTrain.

figure;
perm = randperm(5000,20);
for i = 1:20
   subplot(4,5,i);
   imshow(XTrain(:,:,:,perm(i)));
end

Define Layers of Network

Define the layers of the convolutional neural network.

layers = [imageInputLayer([28 28 1])
          convolution2dLayer(5,20)
          reluLayer()
          maxPooling2dLayer(2,'Stride',2)
          fullyConnectedLayer(10)
          softmaxLayer()
          classificationLayer()];
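As a quick sanity check on the architecture, you can trace the spatial dimensions through the network by hand (this sketch assumes the default 'Padding' of 0 for the convolution layer and stride 1):

```matlab
% Trace the spatial size of the activations through the network
inputSize = 28;
convOut   = inputSize - 5 + 1;   % 5-by-5 filters, stride 1 -> 24-by-24
poolOut   = floor(convOut / 2);  % 2-by-2 pooling, stride 2 -> 12-by-12
fprintf('Conv output: %d-by-%d, pool output: %d-by-%d\n', ...
    convOut, convOut, poolOut, poolOut)
```

The fully connected layer then maps the pooled activations to the 10 digit classes.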

Specify Training Options with Checkpoint Path and Train Network

Set the training options to the default settings for stochastic gradient descent with momentum (SGDM), and specify the path for saving the checkpoint networks.

options = trainingOptions('sgdm','CheckpointPath','C:\Temp\cnncheckpoint');
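The folder specified by 'CheckpointPath' must already exist, so create it first if necessary (the path shown is the one used in this example):

```matlab
% Create the checkpoint folder if it does not exist yet
checkpointDir = 'C:\Temp\cnncheckpoint';
if ~exist(checkpointDir,'dir')
    mkdir(checkpointDir);
end
```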

Train the network.

convnet = trainNetwork(XTrain,TTrain,layers,options);
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.02 |       2.3022 |        7.81% |       0.0100 |
|            2 |           50 |         0.63 |       2.2684 |       31.25% |       0.0100 |
|            3 |          100 |         1.22 |       1.5395 |       53.13% |       0.0100 |
|            4 |          150 |         1.81 |       1.4287 |       50.00% |       0.0100 |
|            6 |          200 |         2.43 |       1.1682 |       60.94% |       0.0100 |
|            7 |          250 |         3.04 |       0.7370 |       77.34% |       0.0100 |
|            8 |          300 |         3.65 |       0.8038 |       71.09% |       0.0100 |
|            9 |          350 |         4.26 |       0.6766 |       78.91% |       0.0100 |
|           11 |          400 |         4.88 |       0.4824 |       89.06% |       0.0100 |
|           12 |          450 |         5.50 |       0.3670 |       90.63% |       0.0100 |
|           13 |          500 |         6.11 |       0.3582 |       92.19% |       0.0100 |
|           15 |          550 |         6.75 |       0.2501 |       93.75% |       0.0100 |
|           16 |          600 |         7.37 |       0.2662 |       93.75% |       0.0100 |
|           17 |          650 |         7.97 |       0.1974 |       96.09% |       0.0100 |
|           18 |          700 |         8.59 |       0.2140 |       97.66% |       0.0100 |
|           20 |          750 |         9.21 |       0.1402 |       99.22% |       0.0100 |
|           21 |          800 |         9.82 |       0.1324 |       97.66% |       0.0100 |
|           22 |          850 |        10.43 |       0.1373 |       96.88% |       0.0100 |
|           24 |          900 |        11.07 |       0.0913 |      100.00% |       0.0100 |
|           25 |          950 |        11.70 |       0.0935 |       98.44% |       0.0100 |
|           26 |         1000 |        12.31 |       0.0647 |      100.00% |       0.0100 |
|           27 |         1050 |        12.92 |       0.0678 |       99.22% |       0.0100 |
|           29 |         1100 |        13.55 |       0.1053 |       98.44% |       0.0100 |
|           30 |         1150 |        14.17 |       0.0714 |       99.22% |       0.0100 |
|           30 |         1170 |        14.41 |       0.0497 |      100.00% |       0.0100 |
|=========================================================================================|

trainNetwork uses a GPU if one is available; otherwise, it uses the CPU. For the available hardware options, see the trainingOptions function page.

Suppose the training was interrupted after the 8th epoch and did not complete. Rather than restarting the training from the beginning, you can load the last checkpoint network and resume training from that point.

Load Checkpoint Network and Resume Training

Load the checkpoint network.

load convnet_checkpoint__351__2016_11_09__12_04_23.mat

trainNetwork automatically assigns unique names to the checkpoint network files. In this example, 351 is the iteration number, 2016_11_09 is the date, and 12_04_23 is the time at which trainNetwork saved the network.
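If you are not sure which checkpoint file is the most recent, one way to find it programmatically is to sort the files by save date (a sketch, assuming all checkpoints are in the folder used in this example):

```matlab
% Find and load the most recent checkpoint in the checkpoint folder
checkpointDir = 'C:\Temp\cnncheckpoint';
files = dir(fullfile(checkpointDir,'convnet_checkpoint__*.mat'));
[~,idx] = max([files.datenum]);                 % newest file by save date
load(fullfile(checkpointDir, files(idx).name))  % loads the variable 'net'
```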

If the training data and the training options are no longer in your workspace, you must load or specify them manually before you can resume training.
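For example, to reload the sample data used in this example:

```matlab
% Reload the training data if it is no longer in the workspace
[XTrain,TTrain] = digitTrain4DArrayData;
```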

Specify the training options to reduce the maximum number of epochs.

options = trainingOptions('sgdm','MaxEpochs',20, ...
        'CheckpointPath','C:\Temp\cnncheckpoint');

You can also adjust other training options, such as the initial learning rate.
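For example, to resume with a smaller initial learning rate (the value 0.005 here is illustrative, not a recommendation):

```matlab
% Resume with a reduced learning rate; the SGDM default is 0.01
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',0.005, ...
    'CheckpointPath','C:\Temp\cnncheckpoint');
```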

Resume the training using the layers of the checkpoint network you loaded, together with the new training options. By default, the name of the network variable in the checkpoint file is net.

convnet2 = trainNetwork(XTrain,TTrain,net.Layers,options)
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.02 |       0.5210 |       84.38% |       0.0100 |
|            2 |           50 |         0.62 |       0.4168 |       87.50% |       0.0100 |
|            3 |          100 |         1.23 |       0.4532 |       87.50% |       0.0100 |
|            4 |          150 |         1.83 |       0.3424 |       92.97% |       0.0100 |
|            6 |          200 |         2.45 |       0.3177 |       95.31% |       0.0100 |
|            7 |          250 |         3.06 |       0.2091 |       94.53% |       0.0100 |
|            8 |          300 |         3.67 |       0.1829 |       96.88% |       0.0100 |
|            9 |          350 |         4.27 |       0.1531 |       97.66% |       0.0100 |
|           11 |          400 |         4.91 |       0.1482 |       96.88% |       0.0100 |
|           12 |          450 |         5.58 |       0.1293 |       97.66% |       0.0100 |
|           13 |          500 |         6.31 |       0.1134 |       98.44% |       0.0100 |
|           15 |          550 |         6.94 |       0.1006 |      100.00% |       0.0100 |
|           16 |          600 |         7.55 |       0.0909 |       98.44% |       0.0100 |
|           17 |          650 |         8.16 |       0.0567 |      100.00% |       0.0100 |
|           18 |          700 |         8.76 |       0.0654 |      100.00% |       0.0100 |
|           20 |          750 |         9.41 |       0.0654 |      100.00% |       0.0100 |
|           20 |          780 |         9.77 |       0.0606 |       99.22% |       0.0100 |
|=========================================================================================|

convnet2 = 

  SeriesNetwork with properties:

    Layers: [7×1 nnet.cnn.layer.Layer]