Transfer Learning and Fine-Tuning of Convolutional Neural Networks

This example shows how to classify new image data by fine-tuning an existing, pretrained convolutional neural network.

If you have only a small amount of training data, constructing and training a new network from scratch can be time consuming and ineffective. Instead, you can fine-tune an existing pretrained network to solve a new problem. This technique, called transfer learning, reuses the feature-detecting layers of the pretrained network and retrains only the layers at the end of the network, so training usually finishes much faster.

Neural Network Toolbox provides access to several pretrained networks, including the popular AlexNet. This example uses a simple pretrained network to introduce you to transfer learning. It takes a convolutional neural network trained on letter images and fine-tunes it to classify images of digits.

Load Pretrained Network

Load a pretrained network. LettersClassificationNet.mat contains the pretrained network net, which was trained on a large collection of 28-by-28 grayscale letter images. The network classifies images into the three letter classes 'A', 'B', and 'C'.

load LettersClassificationNet.mat

Examine the details of the network architecture contained in the Layers property of net.

net.Layers
ans = 

  7x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv'          Convolution             20 5x5x1 convolutions with stride [1  1] and padding [0  0]
     3   'relu'          ReLU                    ReLU
     4   'maxpool'       Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0]
     5   'fc'            Fully Connected         3 fully connected layer
     6   'softmax'       Softmax                 softmax
     7   'classoutput'   Classification Output   crossentropyex with 'A', 'B', and 1 other classes

The network has seven layers. Layers 2–4 detect features, and layers 5–7 classify those features into letter classes. The 20 outputs (channels) of the convolutional layer correspond to the learned features. Using transfer learning, you can reuse these features to predict the classes of the digits data.
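
To make the layer listing concrete, the following sketch works out the feature map sizes and the number of learnable parameters from the sizes printed above. It uses only arithmetic on the listed filter, stride, and padding values; the variable names are chosen here for illustration.

```matlab
inputSize  = [28 28 1];     % from the image input layer
filterSize = 5;             % 5x5 convolution, stride 1, padding 0
numFilters = 20;
poolSize   = 2;             % 2x2 max pooling, stride 2, padding 0
poolStride = 2;
numLetterClasses = 3;

% Spatial size after an unpadded layer: (in - window)/stride + 1
convOut = (inputSize(1) - filterSize)/1 + 1;        % 24
poolOut = (convOut - poolSize)/poolStride + 1;      % 12
numFeatures = poolOut^2 * numFilters;               % 2880 inputs to 'fc'

% Learnable parameters: weights plus one bias per output
convParams = numFilters*(filterSize^2*inputSize(3)) + numFilters;   % 520
fcParams   = numLetterClasses*numFeatures + numLetterClasses;       % 8643
```

Most of the parameters sit in the fully connected layer, which is the part that transfer learning replaces.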

The last three layers of the pretrained network net are tuned for the letters data. By replacing the last three layers of the pretrained network, you can fine-tune the network to classify digits instead.

Load Training Data

Load the digits sample data as an ImageDatastore object. The imageDatastore function labels the images automatically based on folder names. An ImageDatastore object lets you work with large collections of images, including data that does not fit in memory, and read batches of images efficiently during training.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...

digitData = imageDatastore(digitDatasetPath, ...
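
As a self-contained illustration of folder-based labeling, the following sketch builds a small throwaway image folder and reads it with imageDatastore. The folder name tinyDigits, the class folders '0' and '1', and the random images are hypothetical stand-ins, not the digits sample data.

```matlab
% Build a tiny data set with one subfolder per class (all names are hypothetical).
rootDir = fullfile(tempdir,'tinyDigits');
classNames = {'0','1'};
for c = 1:numel(classNames)
    classDir = fullfile(rootDir,classNames{c});
    if ~exist(classDir,'dir')
        mkdir(classDir);
    end
    for k = 1:3
        img = uint8(255*rand(28,28));      % random 28-by-28 grayscale image
        imwrite(img,fullfile(classDir,sprintf('img%d.png',k)));
    end
end

% Include the subfolders and use their names as the labels.
tinyData = imageDatastore(rootDir, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

labelCounts = countEachLabel(tinyData);    % table with one row per class
```

Because the labels come from the folder names, adding a class only requires adding a folder.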

The datastore contains 10,000 synthetic images of digits 0–9. The images are generated by applying random transformations to digit images created using different fonts. Each digit image is 28-by-28 pixels.

Split the data set into training and test sets so that each datastore has 50% of the images from each category. splitEachLabel splits the image files in digitData into two new datastores: trainDigitData and testDigitData.

[trainDigitData,testDigitData] = splitEachLabel(digitData,0.5,'randomize');
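
To see what a 50% split per category means, the following sketch mimics the behavior on a plain label vector. It is an illustration only, not how splitEachLabel is implemented. The label vector is a stand-in that assumes the 10,000 images are spread evenly over the 10 digit classes.

```matlab
rng(0);                                       % fix the seed so the split is reproducible
labels = categorical(repelem((0:9)',1000));   % stand-in: 1000 labels per class
trainMask = false(size(labels));
classes = categories(labels);
for c = 1:numel(classes)
    members = find(labels == classes{c});     % indices of this class
    numTrain = round(0.5*numel(members));
    pick = members(randperm(numel(members),numTrain));
    trainMask(pick) = true;                   % a random half goes to training
end
trainCounts = countcats(labels(trainMask));   % per-class counts, training half
testCounts  = countcats(labels(~trainMask));  % per-class counts, test half
```

Splitting within each class keeps the class proportions identical in the training and test sets.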

Display 20 sample training digits using imshow.

numImages = numel(trainDigitData.Files);
idx = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i)                          % 4-by-5 grid of axes (layout assumed)
    I = readimage(trainDigitData, idx(i));
    imshow(I)
end

Transfer Layers to Target Network

The last three layers of the pretrained network net are tuned for the letters data. To transfer learning to the new task, replace these layers and fine-tune them on the digits data. Transfer learning is suitable here because the digits data is similar in content to the letters data.

Extract all the layers except the last three from the pretrained network.

layersTransfer = net.Layers(1:end-3);

The digits data set has 10 classes. Add a new fully connected layer for 10 classes, a softmax layer, and a classification output layer. To speed up learning on the new layers only, increase the weight and bias learn rate factors of the fully connected layer.

numClasses = numel(categories(trainDigitData.Labels))
numClasses =

    10

layers = [ ...
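
A layer array matching the description above could be assembled as in the following sketch. The learn rate factor of 20 is an assumed value for illustration. So that the sketch runs on its own, layersTransfer is rebuilt here as an untrained stand-in with the same architecture; in the example, it holds the pretrained layers instead.

```matlab
% Stand-in for net.Layers(1:end-3); these layers are NOT pretrained.
layersTransfer = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)];

numClasses = 10;
lrFactor = 20;      % assumed value, chosen for illustration only

% Transferred layers followed by three new layers for the digits task.
layers = [ ...
    layersTransfer
    fullyConnectedLayer(numClasses, ...
        'WeightLearnRateFactor',lrFactor, ...
        'BiasLearnRateFactor',lrFactor)
    softmaxLayer
    classificationLayer];
```

Only the new fully connected layer carries the raised learn rate factors; the transferred layers keep their own factors.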

If the training images differ in size from the image input layer, then you must resize or crop the image data. In this example, the training images are the same size as the input size of the pretrained network (28-by-28), so you do not need to resize or crop them.
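
For data that does not match the input size, a resize step might look like the following sketch. It assumes Image Processing Toolbox is available for imresize, and the 40-by-32 image is a random stand-in.

```matlab
I = uint8(255*rand(40,32));          % stand-in for a larger grayscale image
inputSize = [28 28];                 % spatial size of the image input layer
Iresized = imresize(I,inputSize);    % requires Image Processing Toolbox
```

One possible way to apply this to every image as it is read is to set the ReadFcn property of the datastore, for example to @(file) imresize(imread(file),inputSize).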

Create the training options. For transfer learning, you want to keep the features from the early layers of the pretrained network (the transferred layer weights), so set 'InitialLearnRate' to a low value. This low initial learn rate slows down learning on the transferred layers. In the previous step, you set the learn rate factors for the fully connected layer higher to speed up learning on the new final layers. This combination results in fast learning on the new layers and only slow changes to the transferred layers.

When performing transfer learning, you do not need to train for as many epochs. To speed up training, reduce the value of the 'MaxEpochs' name-value pair argument in the call to trainingOptions.
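
The learning rate of a layer is the global learning rate multiplied by the learn rate factor of that layer. The following sketch shows the resulting rates, using the base learning rate of 1e-4 reported in the training log and an assumed factor of 20 for the new fully connected layer.

```matlab
baseLearnRate  = 1e-4;   % base learning rate reported in the training log
transferFactor = 1;      % default learn rate factor of the transferred layers
newLayerFactor = 20;     % assumed factor for the new fully connected layer

transferRate = baseLearnRate*transferFactor;   % 1e-4: slow updates
newLayerRate = baseLearnRate*newLayerFactor;   % 2e-3: 20 times faster updates
```

The transferred layers still learn, only more slowly than the new layer.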

optionsTransfer = trainingOptions('sgdm', ...
    'MaxEpochs',5, ...
    'InitialLearnRate',0.0001);    % matches the base learning rate in the training log

Fine-tune the network that consists of the transferred and new layers using trainNetwork.

netTransfer = trainNetwork(trainDigitData,layers,optionsTransfer);
Training on single CPU.
Initializing image normalization.
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|            1 |            1 |         0.75 |      10.5459 |        6.25% |     1.00e-04 |
|            2 |           50 |        19.63 |       3.1897 |       70.31% |     1.00e-04 |
|            3 |          100 |        38.44 |       1.1082 |       80.47% |     1.00e-04 |
|            4 |          150 |        58.13 |       0.1465 |       95.31% |     1.00e-04 |
|            5 |          195 |        75.28 |       0.1504 |       96.88% |     1.00e-04 |
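
The iteration counts in the log can be reproduced with a short calculation. This sketch assumes the default mini-batch size of 128 and that the incomplete final mini-batch of each epoch is discarded; both assumptions are consistent with the 195 iterations shown above.

```matlab
numTrainImages = 5000;    % 50% of the 10,000 images
miniBatchSize  = 128;     % assumed: default 'MiniBatchSize' of trainingOptions
maxEpochs      = 5;

iterationsPerEpoch = floor(numTrainImages/miniBatchSize);   % 39
totalIterations    = iterationsPerEpoch*maxEpochs;          % 195

% Epoch that each logged iteration falls in
loggedIterations = [1 50 100 150 195];
loggedEpochs = ceil(loggedIterations/iterationsPerEpoch);   % [1 2 3 4 5]
```

The computed epochs match the Epoch column of the log.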

Calculate the classification accuracy: the proportion of correctly classified instances in the test data.

YPred = classify(netTransfer,testDigitData);
YTest = testDigitData.Labels;

accuracy = sum(YPred==YTest)/numel(YTest)
accuracy =
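
Overall accuracy can hide weak classes. The following sketch computes both the overall and the per-class accuracy on a small set of stand-in labels; the six labels are made up for illustration and are not results from this example.

```matlab
YTestDemo = categorical([0 1 2 2 1 0]');   % stand-in true labels
YPredDemo = categorical([0 1 2 1 1 0]');   % stand-in predictions (one mistake)

overallAccuracy = mean(YPredDemo == YTestDemo);   % same as sum(...)/numel(...)

classes = categories(YTestDemo);
perClassAccuracy = zeros(numel(classes),1);
for c = 1:numel(classes)
    inClass = (YTestDemo == classes{c});          % test items of this class
    perClassAccuracy(c) = mean(YPredDemo(inClass) == YTestDemo(inClass));
end
```

The same loop applies to YPred and YTest from the example.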


Display sample test images with their predicted labels.

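idx = 501:500:5000;
for i = 1:numel(idx)
    subplot(3,3,i)                          % 9 images in a 3-by-3 grid (layout assumed)
    I = readimage(testDigitData, idx(i));
    label = char(YPred(idx(i)));            % predicted label for this image
    imshow(I)
    title(label)
end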
