This example shows how to save checkpoint networks while training a deep learning network and resume training from a previously saved network.
Load the sample data as a 4-D array.
digitTrain4DArrayData loads the digit training set as 4-D array data.
XTrain is a 28-by-28-by-1-by-5000 array, where 28 is the height and 28 is the width of the images, 1 is the number of channels, and 5000 is the number of synthetic images of handwritten digits.
YTrain is a categorical vector containing the labels for each observation.
[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)

ans = 1×4

    28    28     1  5000
Display some of the images in XTrain.
figure;
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end
Define the neural network architecture.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(7)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
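The size of averagePooling2dLayer(7) follows from the input size and the two pooling layers. A short sketch of that arithmetic, assuming the default pooling padding of 0 and the default averagePooling2dLayer stride of 1; the helper poolOut is hypothetical and only illustrates the standard output-size formula. Because the 'same' padding keeps each convolution output the same height and width as its input, only the pooling layers change the spatial size.

```matlab
% Output height/width of a pooling layer with no padding:
% floor((inputSize - poolSize)/stride) + 1
poolOut = @(in,poolSize,stride) floor((in - poolSize)/stride) + 1;

sz = 28;                 % imageInputLayer([28 28 1])
sz = poolOut(sz,2,2);    % first maxPooling2dLayer(2,'Stride',2):  28 -> 14
sz = poolOut(sz,2,2);    % second maxPooling2dLayer(2,'Stride',2): 14 -> 7
sz = poolOut(sz,7,1);    % averagePooling2dLayer(7), stride 1:      7 -> 1
disp(sz)                 % displays 1
```

The result is a 1-by-1-by-32 activation, so fullyConnectedLayer(10) receives one value per filter of the last convolutional layer.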
Specify training options for stochastic gradient descent with momentum (SGDM) and specify the path for saving the checkpoint networks.
checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);
Train the network.
trainNetwork uses a GPU if one is available. Otherwise, it uses the CPU. Using a GPU requires Parallel Computing Toolbox and a supported GPU device.
trainNetwork saves one checkpoint network each epoch and automatically assigns unique names to the checkpoint files.
net1 = trainNetwork(XTrain,YTrain,layers,options);
Suppose that training was interrupted and did not complete. Rather than restarting the training from the beginning, you can load the last checkpoint network and resume training from that point.
trainNetwork saves the checkpoint files with file names of the form net_checkpoint__195__2018_07_13__11_59_10.mat, where 195 is the iteration number, 2018_07_13 is the date, and 11_59_10 is the time trainNetwork saved the network. The checkpoint network has the variable name net.
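If you need to identify a checkpoint programmatically, you can split its file name into the iteration number, date, and time. A minimal sketch using regexp and datetime, assuming the file names follow the pattern described above exactly; the variable names fileName, tok, iteration, and savedAt are illustrative.

```matlab
% Example checkpoint file name of the form described above.
fileName = 'net_checkpoint__195__2018_07_13__11_59_10.mat';

% Capture the iteration number, the date, and the time as separate tokens.
tok = regexp(fileName, ...
    '^net_checkpoint__(\d+)__(\d{4}_\d{2}_\d{2})__(\d{2}_\d{2}_\d{2})\.mat$', ...
    'tokens','once');

iteration = str2double(tok{1});   % 195

% Underscores are literal characters in the datetime input format.
savedAt = datetime([tok{2} ' ' tok{3}], ...
    'InputFormat','yyyy_MM_dd HH_mm_ss');

fprintf('Iteration %d, saved %s\n', iteration, char(savedAt));
```

Sorting checkpoints by savedAt rather than by iteration can be more reliable when several training runs write to the same folder, because each new run starts counting iterations again.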
Load the checkpoint network into the workspace.
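If you know the file name, you can load it directly with load(fileName,'net'). As an alternative, the sketch below picks the most recently modified checkpoint file in the folder, assuming checkpointPath is the folder specified in the training options above and that training has saved at least one checkpoint. "Most recent" here means the latest file modification time reported by dir, which is an assumption about how you want to choose the checkpoint.

```matlab
% List the checkpoint files that trainNetwork saved in checkpointPath.
files = dir(fullfile(checkpointPath,'net_checkpoint__*.mat'));
assert(~isempty(files),'No checkpoint files found in %s.',checkpointPath);

% Select the file with the latest modification time.
[~,idx] = max([files.datenum]);
latestFile = fullfile(files(idx).folder,files(idx).name);

% Load only the variable named net from the checkpoint file.
data = load(latestFile,'net');
net = data.net;
```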
Specify the training options and reduce the maximum number of epochs. You can also adjust other training options, such as the initial learning rate.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);
Resume training with the new training options, using the layers of the checkpoint network you loaded. If the checkpoint network is a DAG network, then use layerGraph(net) as the argument instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
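The series and DAG cases can be handled in one place by checking the class of the loaded network. A minimal sketch, assuming net, XTrain, YTrain, and options are the variables defined earlier in this example; layersToTrain is an illustrative name.

```matlab
% Choose the training input that matches the type of the checkpoint network.
if isa(net,'DAGNetwork')
    % DAG network: pass the layer graph so that the layer connections are kept.
    layersToTrain = layerGraph(net);
else
    % Series network: the layer array describes the network completely.
    layersToTrain = net.Layers;
end

net2 = trainNetwork(XTrain,YTrain,layersToTrain,options);
```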