Main Content

Prepare Network for Transfer Learning Using Deep Network Designer

This example shows how to prepare a network for transfer learning interactively using the Deep Network Designer app.

Transfer learning is the process of taking a pretrained deep learning network and fine-tuning it to learn a new task. Using transfer learning is usually faster and easier than training a network from scratch. You can quickly transfer learned features to a new task using a smaller amount of data.

Extract Data

Extract the MathWorks Merch dataset. This is a small dataset containing 75 images of MathWorks merchandise, belonging to five different classes (cap, cube, playing cards, screwdriver, and torch). The data is arranged such that the images are in subfolders that correspond to these five classes.

folderName = "MerchData";

Create an image datastore. An image datastore enables you to store large collections of image data, including data that does not fit in memory, and efficiently read batches of images during training of a neural network. Specify the folder with the extracted images and indicate that the subfolder names correspond to the image labels.

imds = imageDatastore(folderName, ...
    IncludeSubfolders=true, ...

Display some sample images.

numImages = numel(imds.Labels);
idx = randperm(numImages,16);
I = imtile(imds,Frames=idx);

Extract the class names and the number of classes.

classNames = categories(imds.Labels);
numClasses = numel(classNames)
numClasses = 5

Partition the data into training, validation, and testing datasets. Use 70% of the images for training, 15% for validation, and 15% for testing. The splitEachLabel function splits the image datastore into two new datastores.

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");

Select a Pretrained Network

To open Deep Network Designer, on the Apps tab, under Machine Learning and Deep Learning, click the app icon. Alternatively, you can open the app from the command line:


Deep Network Designer provides a selection of pretrained image classification networks that have learned rich feature representations suitable for a wide range of images. Transfer learning works best if your images are similar to the images originally used to train the network. If your training images are natural images like those in the ImageNet database, then any of the pretrained networks is suitable. For a list of available networks and how to compare them, see Pretrained Deep Neural Networks.

If your data is very different from the ImageNet data—for example, if you have tiny images, spectrograms, or nonimage data—training a new network might be better. For an example showing how to train a network from scratch, see Get Started with Time Series Forecasting.

SqueezeNet does not require an additional support package. For other pretrained networks, if you do not have the required support package installed, then the app provides the Install option.

Select SqueezeNet from the list of pretrained networks and click Open.

Explore Network

Deep Network Designer displays a zoomed-out view of the whole network in the Designer pane.

Explore the network plot. To zoom in with the mouse, use Ctrl+scroll wheel. To pan, use the arrow keys, or hold down the scroll wheel and drag the mouse. Select a layer to view its properties. Deselect all layers to view the network summary in the Properties pane.

Select the image input layer, 'input'. You can see that the input size for this network is 227-by-227-by-3 pixels.

Save the input size in the variable inputSize.

inputSize = [227 227 3];

Prepare Network for Training

To use a pretrained network for transfer learning, you must change the number of classes to match your new data set. First, find the last learnable layer in the network. For SqueezeNet, the last learnable layer is the last convolutional layer, 'conv10'. Select the 'conv10' layer. At the bottom of the Properties pane, click Unlock Layer. In the warning dialog that appears, click Unlock Anyway. This unlocks the layer properties so that you can adapt them to your new task.

Before R2023b: To adapt the network to the new data, you must replace the layers instead of unlocking them. In the new convolutional 2-D layer, set the FilterSize to [1 1].

The NumFilters property defines the number of classes for classification problems. Change NumFilters to the number of classes in the new data, in this example, 5.

Change the learning rates so that learning is faster in the new layer than in the transferred layers by setting WeightLearnRateFactor and BiasLearnRateFactor to 10.

Check Network

To check that the network is ready for training, click Analyze. The Deep Learning Network Analyzer reports zero errors or warnings, therefore, the network is ready for training. To export the network, click Export. The app saves the network in the variable net_1.

Prepare Data for Training

The images in the datastore can have different sizes. To automatically resize the training images, use an augmented image datastore. Data augmentation also helps prevent the network from overfitting and memorizing the exact details of the training images. Specify these additional augmentation operations to perform on the training images: randomly flip the training images along the vertical axis, and randomly translate them up to 30 pixels horizontally and vertically.

pixelRange = [-30 30];

imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...

To automatically resize the validation and testing images without performing further data augmentation, use an augmented image datastore without specifying any additional preprocessing operations.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

Specify Training Options

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

  • Train using the Adam optimizer.

  • Set the initial learn rate to a small value to slow down learning in the transferred layers.

  • Specify a small number of epochs. An epoch is a full training cycle on the entire training data set. For transfer learning, you do not need to train for as many epochs.

  • Specify validation data and a validation frequency so that the accuracy on the validation data is calculated once every epoch.

  • Specify the mini-batch size, that is, how many images to use in each iteration. To ensure the whole data set is used during each epoch, set the mini-batch size to evenly divide the number of training samples.

  • Display the training progress in a plot and monitor the accuracy metric.

  • Disable the verbose output.

options = trainingOptions("adam", ...
    InitialLearnRate=0.0001, ...
    MaxEpochs=8, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=5, ...
    MiniBatchSize=11, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...

Train Neural Network

Train the neural network using the trainnet function. For classification tasks, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

net = trainnet(imdsTrain,net_1,"crossentropy",options);

Test Neural Network

Classify the test images. To make predictions with multiple observations, use the minibatchpredict function. To covert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available.

YTest = minibatchpredict(net,augimdsTest);
YTest = scores2label(YTest,classNames);

Visualize the classification accuracy in a confusion chart.

TTest = imdsTest.Labels;

Make Predictions with New Data

Classify an image. Read an image from a JPEG file, resize it, and convert to the single data type.

im = imread("MerchDataTest.jpg");
im = imresize(im,inputSize(1:2));
X = single(im);

Classify the image. To make prediction with a single observation, use the predict function. To use a GPU, first convert the data to gpuArray.

if canUseGPU
    X = gpuArray(X);
scores = predict(net,X);
[label,score] = scores2label(scores,classNames);

Display the image with the predicted label and the corresponding score.

title(string(label) + " (Score: " + gather(score) + ")")

See Also

Related Topics