This example shows how to classify new image data by fine-tuning an existing, pretrained convolutional neural network.
If you have a small amount of training data, constructing and training a new network can be time consuming and ineffective. Instead, you can fine-tune an existing pretrained network to solve a new problem. This technique, called transfer learning, usually results in faster training. By taking layers from a pretrained network and retraining only the layers at the end of the network, you can finish training much faster.
Neural Network Toolbox provides access to several pretrained networks, including the popular AlexNet. This example uses a simple pretrained network to introduce you to transfer learning. It takes a convolutional neural network trained on letter images and fine-tunes it to classify images of digits.
Load a pretrained network.
LettersClassificationNet.mat contains the pretrained network net, which is trained on a large collection of 28-by-28 grayscale letter images. The network classifies images into three letter classes.
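A minimal sketch of the load step, assuming LettersClassificationNet.mat is on the MATLAB path (the MAT-file stores the network in the variable net, as described above):

```matlab
% Load the MAT-file; this brings the pretrained network net into the workspace
load('LettersClassificationNet.mat')
```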
Examine the details of the network architecture contained in the Layers property of net.
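You can inspect the architecture by displaying the Layers property directly:

```matlab
% Display the layer array stored in the network's Layers property
net.Layers
```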
ans =

  7x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv'          Convolution             20 5x5x1 convolutions with stride [1 1] and padding [0 0 0 0]
     3   'relu'          ReLU                    ReLU
     4   'maxpool'       Max Pooling             2x2 max pooling with stride [2 2] and padding [0 0 0 0]
     5   'fc'            Fully Connected         3 fully connected layer
     6   'softmax'       Softmax                 softmax
     7   'classoutput'   Classification Output   crossentropyex with 'A', 'B', and 1 other classes
The network has seven layers. Layers 2-4 detect features. Layers 5-7 classify the features into letter classes. The 20 outputs (channels) of the convolutional layer correspond to the learned features. You can use these features to predict the classes of the digits data using transfer learning.
The last three layers of the pretrained network net are tuned for the letters data. By replacing these three layers, you can fine-tune the network to classify digits instead.
Load the digits sample data as an ImageDatastore object. The imageDatastore function labels the images automatically based on folder names and stores the data as an ImageDatastore object. An ImageDatastore object lets you store large image data, including data that does not fit in memory, and efficiently read batches of images during training.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
digitData = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
The data store contains 10000 synthetic images of digits 0-9. The images are generated by applying random transformations to digit images created using different fonts. Each digit image is 28-by-28 pixels.
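As an optional sanity check (not part of the original steps), countEachLabel tabulates the number of images per class in the datastore:

```matlab
% Tabulate the number of images per label; with 10000 images split evenly
% across the digits 0-9, each class should contain 1000 images
labelCount = countEachLabel(digitData)
```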
Split the data set into training and test data sets, so that each data store has 50% of the images from each category.
splitEachLabel splits the image files in digitData into two new data stores:
[trainDigitData,testDigitData] = splitEachLabel(digitData,0.5,'randomize');
Display 20 sample training digits using imshow.
numImages = numel(trainDigitData.Files);
idx = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i)
    I = readimage(trainDigitData,idx(i));
    imshow(I)
end
The last three layers of the pretrained network net are tuned for the letters data. Because the digits data is similar in content to the letters data, you can use transfer learning: transfer the learned features to the new task by replacing these layers and fine-tuning them on the digits data.
Extract all the layers except the last three from the pretrained network.
layersTransfer = net.Layers(1:end-3);
The digits data set has 10 classes. Add a new fully connected layer for 10 classes, a softmax layer, and a classification output layer. To speed up learning on the new layers only, increase the weight and bias learn rate factors of the fully connected layer.
numClasses = numel(categories(trainDigitData.Labels))
numClasses = 10
layers = [ ...
    layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
If the training images differ in size from the image input layer, then you must resize or crop the image data. In this example, the training images are the same size as the input size of the pretrained network (28-by-28), so you do not need to resize or crop the new image data.
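If the sizes did differ, one way to resize images on the fly is an augmentedImageDatastore, which resizes each image to the network's input size as it is read. This is an illustrative sketch only; it is not needed for this example, and the variable names are hypothetical:

```matlab
% Hypothetical: resize the training images to the network's input size
inputSize = net.Layers(1).InputSize;                         % [28 28 1]
augTrainData = augmentedImageDatastore(inputSize(1:2),trainDigitData);
```

You would then pass augTrainData to trainNetwork in place of trainDigitData.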
Create the training options. For transfer learning, you want to keep the features from the early layers of the pretrained network (the transferred layer weights). Set 'InitialLearnRate' to a low value. This low initial learn rate slows down learning on the transferred layers. In the previous step, you set the learn rate factors for the fully connected layer higher to speed up learning on the new final layers. This combination results in fast learning only on the new layers while keeping the other layers nearly fixed. When performing transfer learning, you do not need to train for as many epochs. To speed up training, reduce the value of the 'MaxEpochs' name-value pair argument in the call to trainingOptions.
optionsTransfer = trainingOptions('sgdm', ...
    'MaxEpochs',5, ...
    'InitialLearnRate',0.0001);
Fine-tune the network that consists of the transferred and new layers using trainNetwork.
netTransfer = trainNetwork(trainDigitData,layers,optionsTransfer);
Training on single CPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.70 |      10.5459 |        6.25% |     1.00e-04 |
|            2 |           50 |         9.22 |       2.8589 |       71.09% |     1.00e-04 |
|            3 |          100 |        17.35 |       0.7484 |       89.84% |     1.00e-04 |
|            4 |          150 |        25.38 |       0.1280 |       95.31% |     1.00e-04 |
|            5 |          195 |        33.04 |       0.2434 |       93.75% |     1.00e-04 |
|=========================================================================================|
Calculate the classification accuracy: the proportion of correctly classified instances in the test data.
YPred = classify(netTransfer,testDigitData);
YTest = testDigitData.Labels;
accuracy = sum(YPred==YTest)/numel(YTest)
accuracy = 0.9274
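Beyond the single accuracy figure, an optional check is to see which digits are confused with which. One way (a sketch, assuming Statistics and Machine Learning Toolbox is available for confusionmat) is:

```matlab
% Confusion matrix over the test set:
% rows are true classes, columns are predicted classes
C = confusionmat(YTest,YPred)
```

Large off-diagonal entries identify pairs of digits the fine-tuned network tends to mix up.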
Display sample test images with their predicted labels.
idx = 501:500:5000;
figure
for i = 1:numel(idx)
    subplot(3,3,i)
    I = readimage(testDigitData,idx(i));
    label = char(YPred(idx(i)));
    imshow(I)
    title(label)
end