
Resume Training from a Checkpoint Network

This example shows how to save checkpoint networks while training a convolutional neural network and resume training from a previously saved network.

Load Sample Data

Load the sample data as a 4-D array.

[XTrain,TTrain] = digitTrain4DArrayData;

Display the size of XTrain.

size(XTrain)
ans =

          28          28           1        5000

digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where 28 is the height and 28 is the width of the images, 1 is the number of channels, and 5000 is the number of synthetic images of handwritten digits. TTrain is a categorical vector containing the label for each observation.

Display some of the images in XTrain.

figure;
perm = randperm(5000,20);
for i = 1:20
   subplot(4,5,i);
   imshow(XTrain(:,:,:,perm(i)));
end

Define Layers of Network

Define the layers of the convolutional neural network.

layers = [imageInputLayer([28 28 1])
          convolution2dLayer(5,20)
          reluLayer()
          maxPooling2dLayer(2,'Stride',2)
          fullyConnectedLayer(10)
          softmaxLayer()
          classificationLayer()];

Specify Training Options with Checkpoint Path and Train Network

Set the training options to the default settings for stochastic gradient descent with momentum, and specify the path for saving the checkpoint networks.

options = trainingOptions('sgdm','CheckpointPath','C:\Temp\cnncheckpoint');

Train the network.

convnet = trainNetwork(XTrain,TTrain,layers,options);
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.02 |       2.3022 |        7.81% |       0.0100 |
|            2 |           50 |         0.63 |       2.2684 |       31.25% |       0.0100 |
|            3 |          100 |         1.22 |       1.5395 |       53.13% |       0.0100 |
|            4 |          150 |         1.81 |       1.4287 |       50.00% |       0.0100 |
|            6 |          200 |         2.43 |       1.1682 |       60.94% |       0.0100 |
|            7 |          250 |         3.04 |       0.7370 |       77.34% |       0.0100 |
|            8 |          300 |         3.65 |       0.8038 |       71.09% |       0.0100 |
|            9 |          350 |         4.26 |       0.6766 |       78.91% |       0.0100 |
|           11 |          400 |         4.88 |       0.4824 |       89.06% |       0.0100 |
|           12 |          450 |         5.50 |       0.3670 |       90.63% |       0.0100 |
|           13 |          500 |         6.11 |       0.3582 |       92.19% |       0.0100 |
|           15 |          550 |         6.75 |       0.2501 |       93.75% |       0.0100 |
|           16 |          600 |         7.37 |       0.2662 |       93.75% |       0.0100 |
|           17 |          650 |         7.97 |       0.1974 |       96.09% |       0.0100 |
|           18 |          700 |         8.59 |       0.2140 |       97.66% |       0.0100 |
|           20 |          750 |         9.21 |       0.1402 |       99.22% |       0.0100 |
|           21 |          800 |         9.82 |       0.1324 |       97.66% |       0.0100 |
|           22 |          850 |        10.43 |       0.1373 |       96.88% |       0.0100 |
|           24 |          900 |        11.07 |       0.0913 |      100.00% |       0.0100 |
|           25 |          950 |        11.70 |       0.0935 |       98.44% |       0.0100 |
|           26 |         1000 |        12.31 |       0.0647 |      100.00% |       0.0100 |
|           27 |         1050 |        12.92 |       0.0678 |       99.22% |       0.0100 |
|           29 |         1100 |        13.55 |       0.1053 |       98.44% |       0.0100 |
|           30 |         1150 |        14.17 |       0.0714 |       99.22% |       0.0100 |
|           30 |         1170 |        14.41 |       0.0497 |      100.00% |       0.0100 |
|=========================================================================================|

trainNetwork uses a GPU if one is available. If no GPU is available, it uses the CPU. For the available hardware options, see the trainingOptions function page.
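If you want to control the hardware explicitly rather than rely on automatic selection, trainingOptions accepts an 'ExecutionEnvironment' name-value pair. A minimal sketch, here forcing training onto the CPU (the checkpoint path matches the one used above):

```matlab
% Force training to run on the CPU regardless of GPU availability.
options = trainingOptions('sgdm', ...
    'ExecutionEnvironment','cpu', ...
    'CheckpointPath','C:\Temp\cnncheckpoint');
```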

Suppose the training was interrupted after the 8th epoch and did not complete. Rather than restarting the training from the beginning, you can load the last checkpoint network and resume training from that point.

Load Checkpoint Network and Resume Training

Load the checkpoint network.

load convnet_checkpoint__351__2016_11_09__12_04_23.mat

trainNetwork automatically assigns unique names to the checkpoint network files. For example, in the file name above, 351 is the iteration number, 2016_11_09 is the date, and 12_04_23 is the time at which trainNetwork saved the network.
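If training is interrupted at an arbitrary point, you might not know the checkpoint file name in advance. One way to locate the latest checkpoint is to list the files in the checkpoint path and pick the most recently saved one. This is a minimal sketch using standard MATLAB functions (dir, fullfile) and the checkpoint directory specified above; the file pattern assumes the default checkpoint naming shown in this example:

```matlab
% List all checkpoint files and load the most recently saved one.
checkpointDir = 'C:\Temp\cnncheckpoint';
files = dir(fullfile(checkpointDir,'convnet_checkpoint__*.mat'));
[~,idx] = max([files.datenum]);                  % newest file by save time
load(fullfile(checkpointDir,files(idx).name));   % loads the variable 'net'
```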

If the data and the training options are not in your workspace, you must load the data and specify the training options manually before you can resume training.

Specify new training options, reducing the maximum number of epochs to 20.

options = trainingOptions('sgdm','MaxEpochs',20, ...
        'CheckpointPath','C:\Temp\cnncheckpoint');

You can also adjust other training options, such as the initial learning rate.
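For example, to resume with a smaller learning rate than the SGDM default of 0.01, pass the 'InitialLearnRate' name-value pair. A minimal sketch; the value 0.001 is illustrative, not a recommendation:

```matlab
% Resume with a reduced learning rate, keeping the same checkpoint path.
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',0.001, ...   % default is 0.01
    'CheckpointPath','C:\Temp\cnncheckpoint');
```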

Resume training using the layers of the checkpoint network you loaded, together with the new training options. By default, the name of the loaded network variable is net.

convnet2 = trainNetwork(XTrain,TTrain,net.Layers,options)
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.02 |       0.5210 |       84.38% |       0.0100 |
|            2 |           50 |         0.62 |       0.4168 |       87.50% |       0.0100 |
|            3 |          100 |         1.23 |       0.4532 |       87.50% |       0.0100 |
|            4 |          150 |         1.83 |       0.3424 |       92.97% |       0.0100 |
|            6 |          200 |         2.45 |       0.3177 |       95.31% |       0.0100 |
|            7 |          250 |         3.06 |       0.2091 |       94.53% |       0.0100 |
|            8 |          300 |         3.67 |       0.1829 |       96.88% |       0.0100 |
|            9 |          350 |         4.27 |       0.1531 |       97.66% |       0.0100 |
|           11 |          400 |         4.91 |       0.1482 |       96.88% |       0.0100 |
|           12 |          450 |         5.58 |       0.1293 |       97.66% |       0.0100 |
|           13 |          500 |         6.31 |       0.1134 |       98.44% |       0.0100 |
|           15 |          550 |         6.94 |       0.1006 |      100.00% |       0.0100 |
|           16 |          600 |         7.55 |       0.0909 |       98.44% |       0.0100 |
|           17 |          650 |         8.16 |       0.0567 |      100.00% |       0.0100 |
|           18 |          700 |         8.76 |       0.0654 |      100.00% |       0.0100 |
|           20 |          750 |         9.41 |       0.0654 |      100.00% |       0.0100 |
|           20 |          780 |         9.77 |       0.0606 |       99.22% |       0.0100 |
|=========================================================================================|

convnet2 = 

  SeriesNetwork with properties:

    Layers: [7×1 nnet.cnn.layer.Layer]
