Get Started with Transfer Learning

This example shows how to use transfer learning to retrain AlexNet, a pretrained convolutional neural network, to classify a new set of images. Try this example to see how simple it is to get started with deep learning in MATLAB®.

Transfer learning is commonly used in deep learning applications. You can take a pretrained network and use it as a starting point to learn a new task. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch. You can quickly transfer learned features to a new task using a smaller number of training images.

Load Data

Unzip and load the new images as an image datastore. Divide the data into training and validation data sets. Use 70% of the images for training and 30% for validation.

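% Extract the images and create a datastore labeled by folder name.
unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
% Randomly assign 70% of each class to training and the rest to validation.
[imdsTrain,idmsValidation] = splitEachLabel(imds,0.7,'randomized');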
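To check that the randomized split keeps every class represented in both sets, you can count the images per label. This is a minimal sketch: it uses countEachLabel, which returns a table with Label and Count columns, and it assumes MerchData.zip has already been unzipped as above. The names trainCounts, validationCounts, and totalImages are illustrative.

```matlab
% Assumes MerchData.zip has already been unzipped into the current folder.
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,idmsValidation] = splitEachLabel(imds,0.7,'randomized');

% countEachLabel returns a table with one row per class (Label, Count).
trainCounts = countEachLabel(imdsTrain)
validationCounts = countEachLabel(idmsValidation)

% The second output of splitEachLabel holds all remaining files,
% so the two counts together should cover every image.
totalImages = numel(imds.Files)
```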

Load Pretrained Network

Load the pretrained AlexNet network. If Neural Network Toolbox™ Model for AlexNet Network is not installed, then the software provides a download link. AlexNet has been trained on over a million images and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). The network has learned rich feature representations for a wide range of images. The network takes an image as input and outputs a label for the object in the image together with the probabilities for each of the object categories.

net = alexnet;
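Before modifying the network, it can help to confirm the image size it expects and to try it on a single image. The sketch below reads InputSize from the first layer and classifies peppers.png, a sample image that ships with MATLAB. It assumes imresize is available (depending on your release, it may require Image Processing Toolbox); the variable names are illustrative.

```matlab
net = alexnet;

% The first layer is the image input layer; InputSize is [height width channels].
inputSize = net.Layers(1).InputSize

% Read a sample image shipped with MATLAB and resize it to the input size.
I = imread('peppers.png');
I = imresize(I,inputSize(1:2));

% classify returns the predicted label and one probability per object category.
[label,scores] = classify(net,I);
label
topScore = max(scores)
numCategories = numel(scores)
```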

Replace Final Layers

To retrain AlexNet to classify new images, replace the last three layers of the network: the final fully connected layer, the softmax layer, and the classification output layer. Set the new fully connected layer to have the same size as the number of classes in the new data set (5, in this example). To make the new layers learn faster than the transferred layers, increase the learning rate factors of the fully connected layer.

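% Keep every layer except the last three.
layersTransfer = net.Layers(1:end-3);
% Number of classes in the new data set.
numClasses = numel(categories(imdsTrain.Labels));
% Append a fully connected layer sized to the new classes, plus new softmax and output layers.
layers = [
    layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];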
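To see what is being replaced, you can inspect the last three layers of the original network and compare the output size of the old and new fully connected layers. This sketch hard-codes numClasses = 5 (the MerchData class count given above) instead of reading it from imdsTrain; the other variable names are illustrative.

```matlab
net = alexnet;

% The 22 layers that are kept, and the three that the example discards.
numTransferred = numel(net.Layers(1:end-3))
replaced = net.Layers(end-2:end)

% The original final fully connected layer has one output per ImageNet category.
originalOutputs = replaced(1).OutputSize

% The replacement layer has one output per class in the new data set.
numClasses = 5;   % class count for MerchData, as stated in the text
newFc = fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20);
newOutputs = newFc.OutputSize
```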

Train Network

Specify the training options, including mini-batch size and validation data. Set InitialLearnRate to a small value to slow down learning in the transferred layers. In the previous step, you increased the learning rate factors for the fully connected layer to speed up learning in the new final layers. This combination of learning rate settings results in fast learning only in the new layers and slower learning in the other layers.

options = trainingOptions('sgdm',...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',idmsValidation, ...
    'ValidationFrequency',3, ...
    'ValidationPatience',Inf, ...
    'Verbose',false, ...
    'Plots','training-progress');
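Each layer's effective learning rate is the global InitialLearnRate multiplied by that layer's learning rate factor. The worked example below uses the values from this example and assumes that the transferred layers keep the default factor of 1, which is what a newly constructed fullyConnectedLayer reports.

```matlab
initialLearnRate = 1e-4;   % value passed as 'InitialLearnRate'

% New final layer, configured as in the layer array above.
fcNew = fullyConnectedLayer(5,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20);
% A layer created with default settings has a factor of 1
% (assumed here to match the transferred AlexNet layers).
fcDefault = fullyConnectedLayer(5);

% Effective rate = global rate multiplied by the layer factor.
newLayerRate = initialLearnRate * fcNew.WeightLearnRateFactor          % 2e-3
transferredRate = initialLearnRate * fcDefault.WeightLearnRateFactor   % 1e-4
```

Because the default LearnRateSchedule is 'none', these rates stay constant during training, so the new layer keeps learning 20 times faster than the transferred layers.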

Train the network using the training data. By default, trainNetwork uses a GPU if one is available (this requires Parallel Computing Toolbox™ and a CUDA®-enabled GPU with compute capability 3.0 or higher). Otherwise, it uses a CPU. You can also specify the execution environment by using the 'ExecutionEnvironment' name-value pair argument of trainingOptions.

netTransfer = trainNetwork(imdsTrain,layers,options);
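To force training onto the CPU (for example, on a machine without a supported GPU), set 'ExecutionEnvironment' to 'cpu'. This sketch reuses the core settings from above and omits the validation and plot options for brevity; the name optionsCPU is illustrative.

```matlab
% Same solver and core hyperparameters as above, but always train on the CPU.
% Other accepted values include 'auto' (the default), 'gpu', 'multi-gpu', and 'parallel'.
optionsCPU = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ExecutionEnvironment','cpu');
disp(optionsCPU.ExecutionEnvironment)
```

Pass optionsCPU in place of options in the trainNetwork call; training on a CPU is typically slower than on a GPU.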

Classify Validation Images

Classify the validation images using the fine-tuned network, and calculate the classification accuracy.

YPred = classify(netTransfer,idmsValidation);
accuracy = mean(YPred == idmsValidation.Labels)
accuracy =

     1
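Overall accuracy can hide a class that the network handles poorly. The following minimal sketch computes per-class accuracy using only categorical comparisons. It assumes that YPred and idmsValidation from the previous step are still in the workspace, and the names YTrue, classNames, perClassAccuracy, and results are illustrative.

```matlab
% Assumes YPred and idmsValidation from the previous step are in the workspace.
YTrue = idmsValidation.Labels;
classNames = categories(YTrue);
perClassAccuracy = zeros(numel(classNames),1);
for k = 1:numel(classNames)
    % Select the validation images whose true label is class k.
    isClass = YTrue == classNames{k};
    % Fraction of those images that were predicted correctly.
    perClassAccuracy(k) = mean(YPred(isClass) == YTrue(isClass));
end
results = table(classNames,perClassAccuracy)
```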

For a more detailed transfer learning example, see Transfer Learning Using AlexNet.