Speed Up Deep Neural Network Training

Training a network is commonly the most time-consuming step in a deep learning workflow. This page describes methods you can use to speed up training. Some, like optimizing training hyperparameters, can be used with almost all workflows and often have a large impact on training time, while others, like accelerating custom layers, are only relevant to some workflows. For information about how to improve the accuracy of your network, see Deep Learning Tips and Tricks.

Before attempting to speed up your code, profile the code using the Profiler app to determine where MATLAB® spends the most time. You can then focus your attention on the parts of your code that are slower than expected.
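
If you prefer to collect the same information programmatically, a minimal sketch using the profile function looks like this. The matrix product is only a stand-in workload (an assumption for illustration); replace it with your own training code.

```matlab
% Start collecting profiling data.
profile on

% Stand-in workload; replace these lines with your training code.
X = rand(1000);
Y = X*X';

% Stop profiling and retrieve the collected statistics.
profile off
p = profile("info");

% To browse the results interactively, call:
%   profile viewer
```

The FunctionTable field of p lists the functions that were called and the time spent in each.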

Optimize Training Hyperparameters

While every training option can affect training speed, this section discusses several important options. To investigate the effects of varying your training options, use Experiment Manager. For an example showing how to use Experiment Manager to find optimal hyperparameters, see Tune Experiment Hyperparameters by Using Bayesian Optimization.

Choose a Solver

When you train a network using the trainnet function, you must specify a solver using the trainingOptions function. The solvers are algorithms that update the network parameters (weights and biases) to minimize the loss function. Each solver performs differently and each might be an appropriate choice, depending on your network and data. This table discusses how the solvers affect training speed.


Solver: Adaptive moment estimation (Adam)
trainingOptions solverName argument: "adam"
Impact on training speed:
  • Usually converges faster than the other solvers [1], leading to decreased training time
  • A good default choice
For more information, see Adaptive Moment Estimation.

Solver: Root mean square propagation (RMSProp)
trainingOptions solverName argument: "rmsprop"
Impact on training speed:
  • Usually converges faster than SGDM but slower than Adam [1]
For more information, see Root Mean Square Propagation.

Solver: Stochastic gradient descent with momentum (SGDM)
trainingOptions solverName argument: "sgdm"
Impact on training speed:
  • The slowest to converge, leading to increased training time
  • Generalizes better than Adam in some cases [2]
For more information, see Stochastic Gradient Descent with Momentum.

Solver: Limited-memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS)
trainingOptions solverName argument: "lbfgs"
Impact on training speed:
  • Converges fastest for small networks and small data sets
  • Full-batch solver that processes the entire training data set in a single iteration
For more information, see Limited-Memory BFGS.

Choose Training Hyperparameters

This table shows how some training options affect training speed.

Hyperparameter: Learning rate – Scales the amount by which the learnable parameters of the network can be updated in a single training iteration.
trainingOptions argument names: InitialLearnRate, LearnRateSchedule, LearnRateDropPeriod, LearnRateDropFactor
Impact on training speed:
If your learning rate is too low, training takes a long time. If your learning rate is too high, training might fail to converge.
You can use the LearnRateSchedule, LearnRateDropPeriod, and LearnRateDropFactor options to start training with a high learning rate and then gradually reduce the learning rate during training.

Hyperparameter: Mini-batch size – The size of the subset of the training data that is used to evaluate the gradient of the loss function and update the weights.
trainingOptions argument names: MiniBatchSize
Impact on training speed:
Increasing the mini-batch size usually results in a decrease in training time. However, larger mini-batch sizes can negatively impact the final accuracy of the trained network [3].
A mini-batch size that is too large might result in out-of-memory errors or slower training. For more information, see Resolve GPU Memory Issues.

Hyperparameter: Gradient clipping – Rescales the gradients during training to prevent the gradients from exploding [4].
trainingOptions argument names: GradientThreshold, GradientThresholdMethod
Impact on training speed:
Gradient clipping reduces training time by helping the training to converge.
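
As a rough illustration of how these hyperparameters fit together, the following sketch sets a piecewise learning rate schedule, a mini-batch size, and a gradient threshold in one call to trainingOptions. All numeric values are illustrative assumptions, not recommendations; suitable values depend on your network and data.

```matlab
options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...           % start with a relatively high learning rate
    LearnRateSchedule="piecewise", ...   % drop the learning rate in steps
    LearnRateDropPeriod=5, ...           % every 5 epochs
    LearnRateDropFactor=0.5, ...         % multiply the learning rate by 0.5
    MiniBatchSize=256, ...               % larger batches, fewer iterations per epoch
    GradientThreshold=1, ...             % clip gradients above this threshold
    GradientThresholdMethod="l2norm");   % clip based on the L2 norm
```

With these settings, the learning rate is 0.01 for the first five epochs, 0.005 for the next five, and so on.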

Use Transfer Learning

Transfer learning is the process of taking a pretrained deep learning network and fine-tuning it to learn a new task. If a suitable pretrained model exists for your task, using transfer learning is usually faster than training a network from scratch. You can quickly transfer learned features to a new task using a smaller amount of data. For more information about the available pretrained networks, see Pretrained Deep Neural Networks. For an example showing how to fine-tune a pretrained network and use it to classify a new collection of images, see Retrain Neural Network to Classify New Images. For an example showing how using transfer learning can be faster, see Compare Deep Learning Models Using ROC Curves.
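
A minimal sketch of the first step of transfer learning, assuming a release that provides the imagePretrainedNetwork function and using SqueezeNet as an arbitrary example model. The value of numClasses is a hypothetical placeholder for the number of classes in your new task.

```matlab
% Number of classes in the new task (hypothetical value).
numClasses = 5;

% Load a pretrained network and adapt its classification head
% to the new number of classes.
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);

% Query the image size that the network expects, so that you can
% resize your training images to match.
inputSize = net.Layers(1).InputSize;
```

You can then fine-tune net on your own data using the trainnet function, typically for far fewer epochs than training from scratch would require.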

Optimize Network Architecture

Networks with a lot of learnable parameters require more computation for each training iteration. Therefore, reducing the number of learnable parameters in your network can speed up training. Many networks contain significantly more learnable parameters than are required for the network to perform well and can therefore be scaled down without negatively impacting network performance.

For an example showing how to choose an LSTM network with an optimum number of hidden units, see Choose Training Configurations for LSTM Using Bayesian Optimization.
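
To compare candidate architectures, it can help to count their learnable parameters. The following sketch builds a small, hypothetical feature-input network and sums the number of elements in its Learnables table; the layer sizes are assumptions chosen only for illustration.

```matlab
% Small hypothetical network: 10 input features, 3 output classes.
layers = [
    featureInputLayer(10)
    fullyConnectedLayer(32)
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer];

net = dlnetwork(layers);

% Each row of net.Learnables holds one parameter array in its Value
% variable. Sum the element counts over all rows.
% Here: (32*10 + 32) + (3*32 + 3) = 451 parameters.
numLearnables = sum(cellfun(@numel,net.Learnables.Value));
```

Halving the number of hidden units in this example roughly halves the parameter count, which reduces the computation required in each iteration.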

Normalize Data

Like gradient clipping, normalizing your data speeds up training by stabilizing the training process.

Normalize Input Data

Most built-in input layers include normalization properties that you can set when you create the layer. For example, by default, an imageInputLayer centers the data around zero by subtracting the mean, but you can choose another normalization method, turn off normalization, or write your own function to control the normalization. For a list of built-in input layers, see Input Layers.
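
For instance, this sketch creates three image input layers that differ only in their Normalization option. The 28-by-28-by-1 input size is an arbitrary assumption.

```matlab
% Default behavior: subtract the mean ("zerocenter").
layerDefault = imageInputLayer([28 28 1]);

% Subtract the mean and divide by the standard deviation.
layerZscore = imageInputLayer([28 28 1],Normalization="zscore");

% Turn off normalization, for example when the data is already normalized.
layerNone = imageInputLayer([28 28 1],Normalization="none");
```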

Alternatively, you can manually normalize your input data, provided that your input layer is not also performing normalization. For examples showing how to manually normalize input data, see Time Series Forecasting Using Deep Learning and Sequence-to-Sequence Regression Using Deep Learning.
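
A minimal sketch of manual z-score normalization, assuming numeric predictors with observations in rows. The data here is synthetic and the variable names are placeholders. The statistics are computed on the training data only and then reused for the test data.

```matlab
% Synthetic stand-in data: 3 features, observations in rows.
rng(0)
XTrain = 5 + 2*randn(100,3);
XTest = 5 + 2*randn(20,3);

% Compute per-feature statistics from the training data only.
mu = mean(XTrain,1);
sigma = std(XTrain,0,1);

% Apply the same transformation to the training and test data.
XTrainNorm = (XTrain - mu) ./ sigma;
XTestNorm = (XTest - mu) ./ sigma;
```

When you normalize the data this way, set the Normalization option of the input layer to "none" so that the data is not normalized twice.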

Normalize Data Within Networks

You can also normalize data within your network. You can include built-in normalization layers at one or more points in your network to perform different kinds of normalization, such as layer normalization and batch normalization. For a list of built-in normalization layers, see Normalization Layers.
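
As an illustration, this hypothetical layer array places a batch normalization layer after the first convolution and a layer normalization layer after the second. The layer sizes and the placement are assumptions for the sake of the example, not a recommended architecture.

```matlab
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer      % normalize across the mini-batch
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    layerNormalizationLayer      % normalize across each observation
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];
```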

To implement normalization methods not supported by the built-in normalization layers, use a functionLayer or define your own custom layer. For more information on defining custom layers, see Define Custom Deep Learning Layers.

Stop Training Early

By default, the trainnet function trains a network until the number of epochs specified by the MaxEpochs training option has elapsed.

Stop Training Manually

If you plot training progress during training by specifying the Plots training option as "training-progress", you can manually stop training by clicking the stop button next to the training progress bar in the top right-hand corner of the plot window.

Screenshot of the training progress plot, highlighting the stop button in the top right-hand corner

Stop Training Based on Validation Statistics

If you are performing validation during training, you can stop training early if the validation loss does not improve for a specified number of validation passes. To specify the patience of this validation stopping, set the value of the ValidationPatience option using the trainingOptions function. For example, these training options cause training to stop if the loss on the validation data fails to improve for three consecutive validation passes, and then return the network with the lowest validation loss.

options = trainingOptions("adam", ...
    ValidationData=myValData, ...
    ValidationPatience=3, ...
    OutputNetwork="best-validation");

To stop training early based on a different metric, specify the ObjectiveMetricName option using the trainingOptions function.

Stop Training Based on Other Criteria

To stop training based on other criteria, define an output function using the OutputFcn option of the trainingOptions function. The software calls this function once before the start of training, after each iteration, and once when training is complete. The trainnet function passes the output function the structure info, which contains fields describing the current epoch, iteration, time elapsed, learning rate, training loss, validation loss, and state (specified as "start", "iteration", or "done"). If the output function returns true (1), then the software stops training. For example, the training options defined in this code cause training to stop if more than 10 minutes have elapsed and the training loss is less than 0.5.

minTime = minutes(10);
maxLoss = 0.5;

options = trainingOptions("adam", ...
    OutputFcn=@(info) stopEarly(info,minTime,maxLoss));

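function stop = stopEarly(info,minTime,maxLoss)

% Stop only during training iterations, once the minimum time has
% elapsed and the training loss is below the threshold.
if (info.State == "iteration") && (info.TimeElapsed > minTime) ...
        && (info.TrainingLoss < maxLoss)
    stop = true;
else
    stop = false;
end

end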


Disable Optional Visualizations

Visualizations of the training process allow you to assess the training progress and take appropriate action. For example, if training is proceeding as expected, you might allow training to continue, or, if the loss is not decreasing, you might stop training and change your network or training options. However, these visualizations add processing overhead to the training and therefore increase training time.

If you have established that your network is training as expected, for example by plotting the training progress for several epochs, consider disabling these visualizations.

Disable Training Progress Plotting

To plot training progress while training a network using the trainnet function, specify the Plots option as "training-progress" using the trainingOptions function. To turn off plotting, either do not specify the Plots option or specify the Plots option as "none".

Even if you do not plot training progress during training, you can still open the training progress plot after training by using the show function.

Disable or Reduce Verbose Output

By default, verbose training progress information is output to the command window when you train a network using the trainnet function. To suppress verbose output, specify the Verbose option as false (0) using the trainingOptions function. To increase the number of training iterations between printing training progress information, increase the VerboseFrequency training option from the default value of 50. Verbose output impacts training speed much less than plotting training progress.
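
For example, this sketch turns off the training progress plot and prints progress information only every 200 iterations. The value 200 is an arbitrary assumption; choose a value that suits the length of your training run.

```matlab
options = trainingOptions("adam", ...
    Plots="none", ...            % do not plot training progress
    Verbose=true, ...            % keep command window output ...
    VerboseFrequency=200);       % ... but print it less often
```

To suppress the command window output entirely, set Verbose to false instead.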

Reduce Validation Time

Testing your network on a validation data set during training provides a useful measure of how your network is performing on previously unseen data. Validation can reveal problems in the training process, such as underfitting and overfitting. However, calculating validation statistics can slow down training, particularly if your validation data set is large.

Reduce Size of Validation Data Set

Your validation data set should be large enough to represent your data, but a validation data set that is larger than necessary slows down training. If your data set is large, you can allocate a smaller proportion for validation.

Reduce Validation Frequency

By default, validation statistics are calculated every 50 iterations. To reduce the frequency of validation, specify the number of iterations between validation passes by setting the ValidationFrequency option using the trainingOptions function. For example, the training options defined in the following code cause validation statistics to be evaluated every 400 iterations.

options = trainingOptions("adam", ...
    ValidationData=myValData, ...
    ValidationFrequency=400);

Preprocess Data in Advance

If your data requires significant preprocessing, then you can avoid repeatedly processing the images at runtime by instead applying the preprocessing in advance. For example, this code applies the transformation function preprocess that resizes, converts to grayscale, and normalizes images as they are read from an imageDatastore.

imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");

imageResize = [32, 32];

tds = transform(imds,@(x)preprocess(x,imageResize));

function [data] = preprocess(img,finalSize)
    % Resize the image.
    img = imresize(img,finalSize);
    % Convert to grayscale if the image is RGB.
    if numel(size(img)) > 2
        img = rgb2gray(img);
    end
    % Normalize the image by rescaling it to single precision.
    img = im2single(img);

    % Return the image in a cell array.
    data = {img};
end

If you use the TransformedDatastore tds for training, the software applies the preprocess function each time an image is used. As each image is used many times during training, the preprocessing is repeated many times. To avoid this repeated processing, save the processed images to a folder by using the writeall function. You can then use these preprocessed images for training.
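
A sketch of this step, continuing from the tds datastore defined above. The output folder preprocDataFolder and the "png" output format are assumptions for illustration. Depending on your release and on what your transformation function returns, you might instead need to supply your own writing function using the WriteFcn option of writeall.

```matlab
% Hypothetical output location for the preprocessed images.
preprocDataFolder = fullfile(tempdir,"preprocessedData");

% Apply the transformation once and write the results to disk.
writeall(tds,preprocDataFolder,OutputFormat="png");

% Create a new datastore over the preprocessed files for training.
% Labels are taken from the folder names, assuming that writeall
% preserved the folder layout of the original data.
imdsPreprocessed = imageDatastore(preprocDataFolder, ...
    IncludeSubfolders=true,LabelSource="foldernames");
```

Note that writing to an 8-bit image format stores the pixel values as integers, so any rescaling to floating point must be reapplied when the images are read back.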


Use Uniformly Sized Data

When you train a network, the software applies performance optimizations including generating optimized underlying code. To take full advantage of these optimizations, ensure that your training data has a consistent size.

The trainnet function pads all sequences in a mini-batch to have the same length. However, the sequence length can vary between different mini-batches. If your training data includes sequences of different lengths, consider padding or trimming all of the sequences to have the same length before training. For more information on the default sequence padding behavior of the trainnet function, see Sequence Padding and Truncation.
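
One way to do this by hand is sketched below, assuming that the sequences are stored in a cell array with channels in rows and time steps in columns. The data is synthetic, and targetLength is a hypothetical value that you would choose based on your own data. Longer sequences are trimmed on the right and shorter sequences are padded on the right with zeros.

```matlab
% Synthetic stand-in data: three 2-channel sequences of different lengths.
sequences = {rand(2,5); rand(2,8); rand(2,3)};

% Hypothetical common length for all sequences.
targetLength = 6;

uniform = cell(size(sequences));
for i = 1:numel(sequences)
    X = sequences{i};
    numSteps = size(X,2);
    if numSteps >= targetLength
        % Trim extra time steps on the right.
        X = X(:,1:targetLength);
    else
        % Pad missing time steps on the right with zeros.
        X = [X zeros(size(X,1),targetLength - numSteps)];
    end
    uniform{i} = X;
end
```

Trimming discards data and padding adds artificial values, so choose the target length with the distribution of your sequence lengths in mind.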

Use GPUs and Parallel Computing

Training deep networks is computationally intensive and can take many hours of computing time; however, neural networks are inherently parallel algorithms. You can take advantage of this parallelism by running in parallel using high-performance GPUs and computer clusters. By default, the trainnet function uses a GPU if one is available. For more information, see Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud.
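
To make the choice of hardware explicit rather than relying on the default, you can set the ExecutionEnvironment training option. This sketch requests a GPU when one is usable and falls back to the CPU otherwise; the variable name env is a placeholder.

```matlab
% Check whether a supported GPU is available for computation.
if canUseGPU
    env = "gpu";
else
    env = "cpu";
end

options = trainingOptions("adam",ExecutionEnvironment=env);
```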

Accelerate Custom Layers

If your network uses custom layers, you can speed up training by indicating that your custom layers support acceleration. Not all custom layers support acceleration. For more information, see Custom Layer Function Acceleration.
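
As a minimal sketch, a custom layer indicates support for acceleration by also inheriting from nnet.layer.Acceleratable. The layer below, scalingLayer, is a hypothetical example that multiplies its input by a learnable scalar; save it in its own file named scalingLayer.m.

```matlab
classdef scalingLayer < nnet.layer.Layer & nnet.layer.Acceleratable
    % Hypothetical custom layer that scales its input by a learnable
    % scalar and declares support for acceleration.

    properties (Learnable)
        Scale
    end

    methods
        function layer = scalingLayer(name)
            % Set the layer name and initialize the learnable parameter.
            layer.Name = name;
            layer.Scale = 1;
        end

        function Z = predict(layer,X)
            % Forward pass: elementwise scaling of the input.
            Z = layer.Scale .* X;
        end
    end
end
```

Whether a layer can be accelerated depends on the operations that its functions perform, so check the requirements before adding the mixin to an existing layer.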

Optimize Custom Training Code

This page discusses built-in training using the trainnet function. However, most of the methods described can also be applied to training using a custom training loop. For information about training options and visualizations with custom training loops, see Specify Training Options in Custom Training Loop and Monitor Custom Training Loop Progress.

The following methods are either unique to custom training or are much more important when you write a custom training loop.

Accelerate Functions

When you use the dlfeval function in a custom training loop, the software traces each input dlarray object of the model loss function to determine the computation graph used for automatic differentiation. This tracing process can take some time, partly because it can recompute the same traces. By optimizing, caching, and reusing the traces, you can speed up gradient computation in deep learning functions.

To speed up calls to deep learning functions, you can use the dlaccelerate function to create an AcceleratedFunction object that automatically optimizes, caches, and reuses the traces. For example, to accelerate the model loss function and evaluate the accelerated function, use the dlaccelerate function and evaluate the returned AcceleratedFunction object.

accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state)

For more information about function acceleration, see Deep Learning Function Acceleration for Custom Training Loops.

Train Inside Function

When you run code in a script, MATLAB creates a copy before updating any variable that exists in the workspace so that no data is lost if there is an error. As custom training loops usually update the network parameters many times, for example using adamupdate or sgdmupdate, this can greatly impact training speed and memory usage.

To accelerate your custom training loop, run the training loop inside a function instead of a script.
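
A sketch of this structure is shown below. The function names trainModel and modelLoss, the minibatchqueue input mbq, and the cross-entropy loss are assumptions for illustration; adapt them to your own network and task.

```matlab
function net = trainModel(net,mbq,numEpochs)
% Run the whole training loop inside a function so that the network
% parameters are updated without the extra copies made in a script.

trailingAvg = [];
trailingAvgSq = [];
iteration = 0;

for epoch = 1:numEpochs
    shuffle(mbq);
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read the next mini-batch of predictors and targets.
        [X,T] = next(mbq);

        % Evaluate the loss and gradients.
        [~,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the Adam optimizer.
        [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
            trailingAvg,trailingAvgSq,iteration);
    end
end

end

function [loss,gradients] = modelLoss(net,X,T)
% Compute the cross-entropy loss and its gradients with respect to
% the learnable parameters.
Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);
end
```

You would then call net = trainModel(net,mbq,numEpochs) from your script, keeping only setup and evaluation code outside the function.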

Improve Code Performance

If you use a custom training loop, then the speed of training strongly depends on your own code. For more information about improving the performance of your MATLAB code, see Profile Your Code to Improve Performance, Techniques to Improve Performance, and Measure and Improve GPU Performance (Parallel Computing Toolbox).


[1] Kingma, Diederik P., and Jimmy Ba. “Adam: A Method for Stochastic Optimization.” Preprint, submitted in 2014.

[2] Wilson, Ashia C., Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. “The Marginal Value of Adaptive Gradient Methods in Machine Learning.” Preprint, submitted in 2017.

[3] Mishkin, Dmytro, Nikolay Sergievskiy, and Jiri Matas. “Systematic Evaluation of CNN Advances on the ImageNet.” Preprint, submitted in 2016.

[4] Pascanu, Razvan, Tomas Mikolov, and Yoshua Bengio. “On the Difficulty of Training Recurrent Neural Networks.” Preprint, submitted in 2012.
