This example shows how to fine-tune a pretrained GoogLeNet network to classify a new collection of images. This process is called transfer learning and is usually much faster and easier than training a new network, because you can apply learned features to a new task using a smaller number of training images. To interactively prepare a network for transfer learning, use Deep Network Designer.
Load a pretrained GoogLeNet network. If the required support package is not installed, the function provides a download link.
net = googlenet;
Open Deep Network Designer.
Click Import and select the network from the workspace. Deep Network Designer displays a zoomed out view of the whole network. Explore the network plot. To zoom in with the mouse, use Ctrl+scroll wheel.
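You can also open the app from the command line with the network preloaded; a minimal sketch, assuming a release that supports passing a network to the deepNetworkDesigner command (in older releases, open the app and click Import instead):

```matlab
% Open Deep Network Designer with the pretrained network preloaded
net = googlenet;
deepNetworkDesigner(net)
```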
To retrain a pretrained network to classify new images, replace the final layers with new layers adapted to the new data set. You must change the number of classes to match your data.
Drag a new fullyConnectedLayer from the Layer Library onto the canvas. Set OutputSize to the number of classes in the new data; in this example, 5.
Edit the learning rate factors so that the new layers learn faster than the transferred layers. Set WeightLearnRateFactor and BiasLearnRateFactor to 10. Delete the last fully connected layer and connect your new layer in its place.
Replace the output layer. Scroll to the end of the Layer Library and drag a new classificationLayer onto the canvas. Delete the original output layer and connect your new layer in its place.
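The app edits above can also be sketched programmatically with replaceLayer. This sketch assumes the GoogLeNet layer names 'loss3-classifier' (final fully connected layer) and 'output' (classification output layer); verify the names in your network before relying on them:

```matlab
% Programmatic sketch of the layer replacements made in the app
lgraph = layerGraph(net);

% New fully connected layer sized for 5 classes, learning faster
% than the transferred layers
newFC = fullyConnectedLayer(5, ...
    'Name','new_fc', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newFC);

% New classification output layer; classes are set automatically
% during training
newOutput = classificationLayer('Name','new_output');
lgraph = replaceLayer(lgraph,'output',newOutput);
```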
To make sure your edited network is ready for training, click Analyze, and ensure the Deep Learning Network Analyzer reports zero errors.
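The same check is available from the command line via analyzeNetwork, which opens the Deep Learning Network Analyzer on a layer graph (here, lgraph is assumed to hold your edited layers):

```matlab
% Open the Deep Learning Network Analyzer on the edited layer graph
analyzeNetwork(lgraph)
```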
Return to Deep Network Designer and click Export. Deep Network Designer exports the network to a new variable called lgraph_1 containing the edited network layers. You can now supply this layer variable to the trainNetwork function. You can also generate MATLAB® code that recreates the network architecture and returns it as a layerGraph object or a Layer array in the MATLAB workspace.
Unzip and load the new images as an image datastore. Divide the data into 70% training data and 30% validation data.
unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
Resize images to match the pretrained network input size.
augimdsTrain = augmentedImageDatastore([224 224],imdsTrain);
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
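Rather than hard-coding the input size, you can read it from the network's image input layer; a short sketch:

```matlab
% Query the required input size from the pretrained network's first layer
inputSize = net.Layers(1).InputSize;   % [224 224 3] for GoogLeNet
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
```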
Specify training options.
Specify the mini-batch size, that is, how many images to use in each iteration.
Specify a small number of epochs. An epoch is a full training cycle on the entire training data set. For transfer learning, you do not need to train for as many epochs. Shuffle the data every epoch.
Set InitialLearnRate to a small value to slow down learning in the transferred layers.
Specify validation data and a small validation frequency.
Turn on the training plot to monitor progress while you train.
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',6, ...
    'Verbose',false, ...
    'Plots','training-progress');
To train the network, supply the layers exported from the app, lgraph_1, the training images, and the training options to the trainNetwork function. By default, trainNetwork uses a GPU if one is available (requires Parallel Computing Toolbox™); otherwise, it uses a CPU. Training is fast because the data set is so small.
netTransfer = trainNetwork(augimdsTrain,lgraph_1,options);
Classify the validation images using the fine-tuned network, and calculate the classification accuracy.
[YPred,probs] = classify(netTransfer,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 1
Display four sample validation images with predicted labels and predicted probabilities.
idx = randperm(numel(augimdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end
To learn more and try other pretrained networks, see Deep Network Designer.