Transfer Learning

What Is Transfer Learning?

Transfer learning is a deep learning approach in which a model that has been trained for one task is used as a starting point for a model that performs a similar task. Updating and retraining a network with transfer learning is usually much faster and easier than training a network from scratch. The approach is commonly used for object detection, image recognition, and speech recognition applications, among others.

Transfer learning is a popular technique because:

  • It enables you to train models with less labeled data by reusing popular models that have already been trained on large datasets.
  • It can reduce training time and computing resources. With transfer learning, the weights are not learned from scratch; training starts from the weights that the pretrained model has already learned on its original task.
  • You can take advantage of model architectures developed by the deep learning research community, including popular architectures such as GoogLeNet and ResNet.

Pretrained Models for Transfer Learning

At the center of transfer learning is the pretrained deep learning model, built by deep learning researchers, that has been trained using thousands or millions of sample training images.

Many pretrained models are available, and each has advantages and drawbacks to consider:

  • Size: What is the desired memory footprint for the model? The importance of your model’s size will vary depending on where and how you intend to deploy it. Will it run on embedded hardware or a desktop? The size of the network is particularly important when deploying to a low-memory system.
  • Accuracy: How well does the model perform prior to retraining? Typically, a model that performs well on ImageNet, a commonly used dataset containing a million images and a thousand classes of images, will also perform well on new, similar tasks. However, a low accuracy score on ImageNet does not necessarily mean the model will perform poorly on all tasks.
  • Prediction speed: How fast can the model predict on new inputs? While prediction speed can vary based on other factors such as hardware and batch size, it also varies with the architecture and size of the chosen model.

Comparing model size, speed, and accuracy for popular pretrained networks.

You can use MATLAB and Deep Learning Toolbox to access pretrained networks from the latest research with a single line of code. The toolbox also provides guidance on selecting the right network for your transfer learning project.

Which Model Is Best for Your Transfer Learning Application?

With many transfer learning models to choose from, it’s important to keep in mind the tradeoffs involved and the overall goals of your specific project. A network with relatively low accuracy, for example, may be perfectly suitable for a new deep learning task. A good approach is to try a variety of models to find the one that fits your application best.

Simple models for getting started. With simple models, such as AlexNet, GoogLeNet, VGG-16, and VGG-19, you can iterate quickly and experiment with different data preprocessing steps and training options. Once you see what settings work well, you can try a more accurate network to see if that improves your results.

Lightweight and computationally efficient models. SqueezeNet, MobileNet-v2, and ShuffleNet are good options when the deployment environment places limitations on model size.
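
The tradeoffs described above can be sketched as a simple shortlist: discard candidates that exceed the deployment limits, then rank the rest by accuracy. This is a minimal illustration in plain Python, not part of Deep Learning Toolbox. The `Candidate` class, the `shortlist` function, the model names, and all numbers are hypothetical placeholders, not measurements of real networks.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    name: str
    size_mb: float        # memory footprint of the model
    accuracy: float       # accuracy before retraining, e.g. on ImageNet
    ms_per_image: float   # prediction time measured on the target hardware


def shortlist(candidates, max_size_mb, max_ms_per_image):
    """Keep models that fit the deployment limits, most accurate first."""
    fits = [c for c in candidates
            if c.size_mb <= max_size_mb and c.ms_per_image <= max_ms_per_image]
    return sorted(fits, key=lambda c: c.accuracy, reverse=True)


# Placeholder values for illustration only, not measurements of real networks.
candidates = [
    Candidate("small_net", size_mb=5, accuracy=0.60, ms_per_image=2),
    Candidate("medium_net", size_mb=25, accuracy=0.70, ms_per_image=5),
    Candidate("large_net", size_mb=100, accuracy=0.76, ms_per_image=12),
]

picked = shortlist(candidates, max_size_mb=30, max_ms_per_image=6)
print([c.name for c in picked])  # ['medium_net', 'small_net']
```

A shortlist like this only narrows the field. Accuracy before retraining is a rough guide, so the remaining candidates would still need to be retrained and compared on the new task.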

You can use Deep Network Designer to quickly evaluate various pretrained models for your project and better understand tradeoffs between different model architectures.

The Transfer Learning Workflow

While there is great variety in transfer learning architectures and applications, most transfer learning workflows follow a common series of steps.

  1. Select a pretrained model. When getting started, it can help to select a relatively simple model. This example uses GoogLeNet, a popular network that is 22 layers deep and has been trained to classify 1000 object categories.
     Pretrained CNN model that can be modified for transfer learning in image classification tasks.
  2. Replace the final layers. To retrain the network to classify a new set of images and classes, you replace the last layers of the GoogLeNet model. The final fully connected layer is modified to contain the same number of nodes as the number of new classes, and the classification layer is replaced with a new one that produces an output based on the probabilities calculated by the softmax layer.
     Replacing the final layers of a pretrained CNN model before retraining the model is essential to transfer learning.
     After the layers are modified, the final fully connected layer specifies the new number of classes the network will learn, and the classification layer determines outputs from the new output categories available. For example, GoogLeNet was originally trained on 1000 categories, but by replacing the final layers you can retrain it to classify only the five (or any other number of) categories of objects you are interested in.
  3. Optionally freeze the weights. You can freeze the weights of earlier layers in the network by setting the learning rates in those layers to zero. During training, the parameters of frozen layers are not updated, which can significantly speed up network training. If the new dataset is small, then freezing weights can also prevent overfitting of the network to the new dataset.
  4. Retrain the model. Retraining will update the network to learn and identify features associated with the new images and categories. In most cases, retraining requires less data than training a model from scratch.
  5. Predict and assess network accuracy. After the model is retrained, you can classify new images and evaluate how well the network performs.
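
The steps above can be sketched end to end on a toy problem. The following is a minimal plain-Python illustration, not the Deep Learning Toolbox API and not GoogLeNet. It assumes a hypothetical two-layer "pretrained" network (one feature layer plus a 4-class head) and an invented 2-class dataset. The names `load_pretrained`, `replace_head`, `freeze`, and `lr_factor` are made up for this example. Freezing is modeled the way the text describes it, by setting a layer's learning rate factor to zero.

```python
import math


def matvec(W, x):
    return [sum(w * a for w, a in zip(row, x)) for row in W]


def softmax(z):
    m = max(z)
    exps = [math.exp(a - m) for a in z]
    total = sum(exps)
    return [e / total for e in exps]


# Step 1: a hypothetical "pretrained" network with a 4-class head.
def load_pretrained():
    return {
        "features": {"W": [[1.0, -1.0], [-1.0, 1.0]], "lr_factor": 1.0},
        "head": {"W": [[0.3, 0.1], [-0.2, 0.4], [0.1, -0.3], [0.0, 0.2]],
                 "lr_factor": 1.0},
    }


# Step 2: replace the final layer so it has one output node per new class.
def replace_head(net, num_classes):
    num_features = len(net["features"]["W"])
    net["head"] = {"W": [[0.0] * num_features for _ in range(num_classes)],
                   "lr_factor": 1.0}


# Step 3 (optional): freeze a layer by setting its learning rate factor to zero.
def freeze(layer):
    layer["lr_factor"] = 0.0


def forward(net, x):
    h = [max(0.0, a) for a in matvec(net["features"]["W"], x)]  # ReLU features
    return h, softmax(matvec(net["head"]["W"], h))


# Step 4: retrain with stochastic gradient descent on cross-entropy loss.
def train(net, data, epochs=20, lr=0.1):
    feat, head = net["features"], net["head"]
    for _ in range(epochs):
        for x, label in data:
            h, p = forward(net, x)
            dz = [pi - (1.0 if i == label else 0.0) for i, pi in enumerate(p)]
            # Backpropagate to the feature layer before the head is updated.
            dh = [sum(head["W"][i][j] * dz[i] for i in range(len(dz)))
                  for j in range(len(h))]
            dpre = [g if a > 0 else 0.0 for g, a in zip(dh, h)]
            for i, row in enumerate(head["W"]):
                for j in range(len(row)):
                    row[j] -= lr * head["lr_factor"] * dz[i] * h[j]
            for j, row in enumerate(feat["W"]):
                for k in range(len(row)):
                    row[k] -= lr * feat["lr_factor"] * dpre[j] * x[k]


# Step 5: predict and assess accuracy.
def accuracy(net, data):
    hits = 0
    for x, label in data:
        _, p = forward(net, x)
        hits += int(p.index(max(p)) == label)
    return hits / len(data)


# Invented toy task: class 0 if the first input is larger, class 1 otherwise.
train_data = [([2.0, 0.0], 0), ([3.0, 1.0], 0), ([1.0, 0.0], 0), ([4.0, 2.0], 0),
              ([0.0, 2.0], 1), ([1.0, 3.0], 1), ([0.0, 1.0], 1), ([2.0, 4.0], 1)]
test_data = [([5.0, 1.0], 0), ([1.0, 5.0], 1), ([3.0, 2.5], 0)]

net = load_pretrained()
replace_head(net, num_classes=2)
freeze(net["features"])
frozen_before = [row[:] for row in net["features"]["W"]]
train(net, train_data)
test_accuracy = accuracy(net, test_data)
print("frozen weights unchanged:", net["features"]["W"] == frozen_before)
print("test accuracy:", test_accuracy)
```

In this sketch only the new head is updated during retraining. Leaving the feature layer's `lr_factor` at 1.0 instead would fine-tune those weights as well, at the cost of more computation per step.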

Transfer learning workflow: load the pretrained network, replace the final layers, retrain the network, and assess network accuracy.

Training from Scratch or Transfer Learning?

The two commonly used approaches for deep learning are training a model from scratch and transfer learning.

Developing and training a model from scratch works better for highly specific tasks for which preexisting models cannot be used. The downside of this approach is that it typically requires a large amount of data to produce accurate results. If you’re performing text analysis, for example, and you don’t have access to a pretrained model for text analysis but you do have access to a large number of data samples, then developing a model from scratch is likely the best approach.

Transfer learning is useful for tasks such as object recognition, for which a variety of popular pretrained models exist. For example, if you need to classify images of flowers and you have a limited number of flower images, you can transfer weights and layers from an AlexNet network, replace the final classification layer, and retrain your model with the images you have. In such cases, it is possible to achieve higher model accuracy in a shorter time with transfer learning.

Comparing the network performance (accuracy) of training from scratch and transfer learning. The performance curve for transfer learning shows a higher start, slope, and asymptote.

An Interactive Approach to Transfer Learning

Using Deep Network Designer, you can interactively complete the entire transfer learning workflow, including importing a pretrained model, modifying the final layers, and retraining the network using new data, with little or no coding.

Deep Network Designer is a point-and-click tool for creating or modifying deep neural networks in MATLAB.

For more information, see Deep Learning Toolbox and Computer Vision Toolbox™.

See also: deep learning, convolutional neural networks, GPU Coder, artificial intelligence, biomedical signal processing