trainNetwork

Train neural network for deep learning

Use trainNetwork to train a convolutional neural network (ConvNet, CNN), a long short-term memory (LSTM) network, or a bidirectional LSTM (BiLSTM) network for deep learning classification and regression problems. You can train a network on either a CPU or a GPU. For image classification and image regression, you can train using multiple GPUs or in parallel. Using GPU, multi-GPU, and parallel options requires Parallel Computing Toolbox™. To use a GPU for deep learning, you must also have a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher. Specify training options, including options for the execution environment, by using trainingOptions.

Syntax

trainedNet = trainNetwork(imds,layers,options)
trainedNet = trainNetwork(mbds,layers,options)
trainedNet = trainNetwork(X,Y,layers,options)
trainedNet = trainNetwork(sequences,Y,layers,options)
trainedNet = trainNetwork(tbl,layers,options)
trainedNet = trainNetwork(tbl,responseName,layers,options)
[trainedNet,traininfo] = trainNetwork(___)

Description

trainedNet = trainNetwork(imds,layers,options) trains a network for image classification problems. imds stores the input image data, layers defines the network architecture, and options defines the training options.

trainedNet = trainNetwork(mbds,layers,options) trains a network using the mini-batch datastore mbds. Use a mini-batch datastore to read out-of-memory data or to perform specific operations when reading batches of data.

trainedNet = trainNetwork(X,Y,layers,options) trains a network for image classification and regression problems. X contains the predictor variables and Y contains the categorical labels or numeric responses.

trainedNet = trainNetwork(sequences,Y,layers,options) trains an LSTM or BiLSTM network for classification and regression problems. sequences is a cell array containing sequence or time series predictors and Y contains the responses. For classification problems, Y is a categorical vector or a cell array of categorical sequences. For regression problems, Y is a matrix of targets or a cell array of numeric sequences.

trainedNet = trainNetwork(tbl,layers,options) trains a network for classification and regression problems. tbl contains numeric data or file paths to the data. The predictors must be in the first column of tbl. For information on the targets or response variables, see tbl.

trainedNet = trainNetwork(tbl,responseName,layers,options) trains a network for classification and regression problems. The predictors must be in the first column of tbl. The responseName argument specifies the response variables in tbl.

[trainedNet,traininfo] = trainNetwork(___) also returns information on the training using any of the input arguments in the previous syntaxes.

Examples

Load the data as an ImageDatastore object.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

The datastore contains 10,000 synthetic images of digits from 0 to 9. The images are generated by applying random transformations to digit images created with different fonts. Each digit image is 28-by-28 pixels. The datastore contains an equal number of images per category.

Display some of the images in the datastore.

figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
end

Divide the datastore so that each category in the training set has 750 images and the testing set has the remaining images from each label.

numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

splitEachLabel splits the image files in imds into two new datastores, imdsTrain and imdsTest.
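To make the per-label split concrete, the following sketch reproduces the idea in plain MATLAB on a small, hypothetical label vector. It only illustrates the behavior described above (a fixed number of randomly chosen observations per label goes to the first set, and the rest go to the second). It is not the implementation of splitEachLabel, and all variable names and sizes are made up for illustration.

```matlab
% Hypothetical labels: 3 categories with 5 observations each.
labels = categorical(repelem(["a";"b";"c"],5));
numPerLabel = 3;                       % hypothetical training count per label

trainMask = false(size(labels));
cats = categories(labels);
for k = 1:numel(cats)
    idx = find(labels == cats{k});                  % observations with this label
    pick = idx(randperm(numel(idx),numPerLabel));   % random subset of fixed size
    trainMask(pick) = true;
end

trainLabels = labels(trainMask);       % 3 observations per label
testLabels  = labels(~trainMask);      % remaining 2 observations per label
trainCounts = countcats(trainLabels);
testCounts  = countcats(testLabels);
```

Every category contributes the same number of observations to the training set, which keeps the class balance of the original data.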

Define the convolutional neural network architecture.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
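As a rough check of this architecture, the activation sizes can be worked out by hand. The sketch below assumes the default stride of 1 and padding of 0 for convolution2dLayer and uses the usual output-size formula floor((in - filter + 2*pad)/stride) + 1. The variable names are illustrative only.

```matlab
inputSize = [28 28 1];

% convolution2dLayer(5,20): twenty 5-by-5 filters,
% assuming the default stride of 1 and padding of 0.
convHW  = floor((inputSize(1:2) - 5 + 2*0)/1) + 1;   % [24 24]
convOut = [convHW 20];

% maxPooling2dLayer(2,'Stride',2): 2-by-2 regions, stride 2, no padding.
poolHW  = floor((convOut(1:2) - 2)/2) + 1;           % [12 12]
poolOut = [poolHW convOut(3)];

% fullyConnectedLayer(10) operates on the flattened pooling output.
numFCInputs     = prod(poolOut);                     % 12*12*20 = 2880
numFCLearnables = numFCInputs*10 + 10;               % weights plus biases = 28810
```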

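Specify the training options. Use stochastic gradient descent with momentum (SGDM) as the solver, set the maximum number of epochs to 20, and start training with an initial learning rate of 0.0001. Turn off the command-window output and display the training progress plot. All other options keep their default values.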

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

Train the network.

net = trainNetwork(imdsTrain,layers,options);

Run the trained network on the test set, which was not used to train the network, and predict the image labels (digits).

YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

Calculate the accuracy. The accuracy is the fraction of test images for which the label predicted by classify matches the true label.

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9896

Train a convolutional neural network using augmented image data. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

Load the sample data, which consists of synthetic images of handwritten digits.

[XTrain,YTrain] = digitTrain4DArrayData;

digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where:

  • 28 is the height and width of the images.

  • 1 is the number of channels.

  • 5000 is the number of synthetic images of handwritten digits.

YTrain is a categorical vector containing the labels for each observation.

Set aside 1000 of the images for network validation.

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Create an imageDataAugmenter object that specifies preprocessing options for image augmentation, such as resizing, rotation, translation, and reflection. Randomly translate the images up to three pixels horizontally and vertically, and rotate the images with an angle up to 20 degrees.

imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3])
imageAugmenter = 
  imageDataAugmenter with properties:

           FillValue: 0
     RandXReflection: 0
     RandYReflection: 0
        RandRotation: [-20 20]
           RandScale: [1 1]
          RandXScale: [1 1]
          RandYScale: [1 1]
          RandXShear: [0 0]
          RandYShear: [0 0]
    RandXTranslation: [-3 3]
    RandYTranslation: [-3 3]

Create an augmentedImageDatastore object to use for network training and specify the image output size. During training, the datastore performs image augmentation and resizes the images. The datastore augments the images without saving any images to memory. trainNetwork updates the network parameters and then discards the augmented images.

imageSize = [28 28 1];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);

Specify the convolutional neural network architecture.

layers = [
    imageInputLayer(imageSize)
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Specify training options for stochastic gradient descent with momentum.

opts = trainingOptions('sgdm', ...
    'MaxEpochs',15, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XValidation,YValidation});

Train the network. Because the validation images are not augmented, the validation accuracy is higher than the training accuracy.

net = trainNetwork(augimds,layers,opts);

Load the sample data, which consists of synthetic images of handwritten digits. The third output contains the corresponding angles in degrees by which each image has been rotated.

Load the training images as 4-D arrays using digitTrain4DArrayData. The output XTrain is a 28-by-28-by-1-by-5000 array, where:

  • 28 is the height and width of the images.

  • 1 is the number of channels.

  • 5000 is the number of synthetic images of handwritten digits.

YTrain contains the rotation angles in degrees.

[XTrain,~,YTrain] = digitTrain4DArrayData;

Display 20 random training images using imshow.

figure
numTrainImages = numel(YTrain);
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Specify the convolutional neural network architecture. For regression problems, include a regression layer at the end of the network.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

Specify the network training options. Set the initial learn rate to 0.001.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'Plots','training-progress');

Train the network.

net = trainNetwork(XTrain,YTrain,layers,options);

Test the performance of the network by evaluating the prediction accuracy on the test data. Use predict to predict the angles of rotation of the test images.

[XTest,~,YTest] = digitTest4DArrayData;
YPred = predict(net,XTest);

Evaluate the performance of the model by calculating the root-mean-square error (RMSE) of the predicted and actual angles of rotation.

rmse = sqrt(mean((YTest - YPred).^2))
rmse = single
    6.4026
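The RMSE calculation can be checked on a handful of values. The following sketch uses small, hypothetical angle vectors (not output of the network above) and also reports the fraction of predictions that fall within a tolerance; the 5-degree threshold is an arbitrary choice for illustration.

```matlab
% Hypothetical true and predicted rotation angles in degrees.
YTest = single([10; -20; 35; 0]);
YPred = single([12; -26; 30; 1]);

predictionError = YTest - YPred;             % [-2; 6; 5; -1]
rmse = sqrt(mean(predictionError.^2));       % sqrt(66/4) = sqrt(16.5)

thr = 5;                                     % hypothetical tolerance in degrees
fracWithinThr = mean(abs(predictionError) <= thr);   % 3 of 4 predictions
```

Because the inputs are single, rmse is also returned as single, which matches the `rmse = single` display in the example output.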

Train a deep learning LSTM network for sequence-to-label classification.

Load the Japanese Vowels data set as described in [1] and [2]. XTrain is a cell array containing 270 sequences of varying length with a feature dimension of 12. YTrain is a categorical vector of labels 1,2,...,9. The entries in XTrain are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).

[XTrain,YTrain] = japaneseVowelsTrainData;

Visualize the first time series in a plot. Each line corresponds to a feature.

figure
plot(XTrain{1}')
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Define the LSTM network architecture. Specify the input size as 12 (the number of features of the input data). Specify an LSTM layer to have 100 hidden units and to output the last element of the sequence. Finally, specify nine classes by including a fully connected layer of size 9, followed by a softmax layer and a classification layer.

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   LSTM                    LSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

Specify the training options. Specify the solver as 'adam' and 'GradientThreshold' as 1. Set the mini-batch size to 27 and set the maximum number of epochs to 100.

Because the mini-batches are small with short sequences, the CPU is better suited for training. Set 'ExecutionEnvironment' to 'cpu'. To train on a GPU, if available, set 'ExecutionEnvironment' to 'auto' (the default value).

maxEpochs = 100;
miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');
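With these settings, the length of the training run follows from simple arithmetic. The sketch below assumes that one iteration processes one complete mini-batch; the variable names other than miniBatchSize and maxEpochs are illustrative.

```matlab
numObservations = 270;     % training sequences in the Japanese Vowels data set
miniBatchSize   = 27;
maxEpochs       = 100;

% 27 divides 270 evenly, so every epoch uses all observations.
iterationsPerEpoch = floor(numObservations/miniBatchSize);   % 10
totalIterations    = iterationsPerEpoch*maxEpochs;           % 1000
```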

Train the LSTM network with the specified training options.

net = trainNetwork(XTrain,YTrain,layers,options);

Load the test set and classify the sequences into speakers.

[XTest,YTest] = japaneseVowelsTestData;

Classify the test data. Specify the same mini-batch size used for training.

YPred = classify(net,XTest,'MiniBatchSize',miniBatchSize);

Calculate the classification accuracy of the predictions.

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9270

Input Arguments

Images with labels, specified as an ImageDatastore object with categorical labels. You can store data in ImageDatastore for image classification networks only.

ImageDatastore allows batch reading of JPG or PNG image files using prefetching. If you use a custom function for reading the images, then ImageDatastore does not prefetch.

Tip

Use augmentedImageDatastore for efficient preprocessing of images for deep learning including image resizing.

Avoid using the ReadFcn option of imageDatastore for preprocessing, because a custom read function is usually significantly slower.

Mini-batch datastore for out-of-memory data and preprocessing, specified as one of the following:

| Type of Mini-Batch Datastore | Description |
| --- | --- |
| augmentedImageDatastore | Apply random affine geometric transformations, including resizing, rotation, reflection, shear, and translation, for training deep neural networks. |
| pixelLabelImageDatastore | Apply identical affine geometric transformations to images and corresponding ground truth labels for training semantic segmentation networks (requires Computer Vision System Toolbox™). |
| randomPatchExtractionDatastore | Extract pairs of random patches from images or pixel label images (requires Image Processing Toolbox™). You optionally can apply identical random affine geometric transformations to the pairs of patches. |
| denoisingImageDatastore | Apply randomly generated Gaussian noise for training denoising networks (requires Image Processing Toolbox). |
| Custom mini-batch datastore | Specify your own preprocessing options or create mini-batches of sequence data. For details, see Develop Custom Mini-Batch Datastore. |

Use a mini-batch datastore to read out-of-memory data or to perform specific operations when reading batches of data. For more information about using mini-batch datastores for image processing, see Advanced Image Preprocessing.

Images, specified as a 4-D numeric array. The first three dimensions are the height, width, and channels, and the last dimension indexes the individual images.

If the array contains NaNs, then they are propagated through the training. However, in most cases, the training fails to converge.
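Because NaNs propagate through training, it can be worth checking the image array before calling trainNetwork. The following is a minimal sketch in plain MATLAB on a small, hypothetical array; the variable names and the choice to drop the affected observations (rather than repair them) are assumptions for illustration.

```matlab
% Hypothetical 28-by-28-by-1-by-5 image array with one corrupted pixel.
X = rand(28,28,1,5,'single');
Y = categorical([1; 2; 3; 1; 2]);      % hypothetical labels, one per image
X(3,4,1,2) = NaN;

hasNaN = any(isnan(X(:)));

% Reduce over height, width, and channels to flag whole observations.
badObs = find(squeeze(any(any(any(isnan(X),1),2),3)));

% Drop the flagged observations together with their responses.
X(:,:,:,badObs) = [];
Y(badObs) = [];
```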

Data Types: single | double | uint8 | int8 | uint16 | int16 | uint32 | int32

Sequences or time series data, specified as a cell array of matrices or a matrix. For cell array input, sequences is an N-by-1 cell array, where N is the number of observations. Each entry of sequences is a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps.

For sequence-to-sequence problems with one observation, sequences can be a D-by-S matrix, where D is the number of features and S is the number of time steps. If sequences is a matrix, then Y must be a categorical sequence of labels or a matrix of responses.

For sequence classification and regression problems, layers must begin with a sequence input layer.

Data Types: cell | single | double

Responses, specified as a categorical vector of labels, matrix, 4-D numeric array, cell array of categorical row vectors, or cell array of numeric sequences. The format of Y depends on the type of problem.

For classification problems, the format depends on the task.

| Task | Format |
| --- | --- |
| Image classification | N-by-1 categorical vector of labels, where N is the number of observations. |
| Sequence-to-label classification | N-by-1 categorical vector of labels, where N is the number of observations. |
| Sequence-to-sequence classification | N-by-1 cell array of categorical sequences of labels, where N is the number of observations. Each sequence has the same number of time steps as the corresponding input sequence. For sequence-to-sequence classification problems with one observation, sequences can be a matrix. In this case, Y must be a categorical sequence of labels. |

For regression problems, the format depends on the task.

| Task | Format |
| --- | --- |
| Image regression | N-by-R matrix, where N is the number of observations and R is the number of responses; or h-by-w-by-c-by-N numeric array, where N is the number of observations and h-by-w-by-c is the image size of a single response. |
| Sequence-to-one regression | N-by-R matrix, where N is the number of observations and R is the number of responses. |
| Sequence-to-sequence regression | N-by-1 cell array of numeric sequences, where N is the number of observations. The sequences are matrices with R rows, where R is the number of responses. Each sequence has the same number of time steps as the corresponding input sequence. For sequence-to-sequence regression problems with one observation, sequences can be a matrix. In this case, Y must be a matrix of responses. |
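The sequence-to-sequence regression format described above can be checked before training. The sketch below builds small, hypothetical predictors and responses and verifies that each response sequence has R rows and the same number of time steps as its input sequence. The sizes and names are assumptions chosen for illustration.

```matlab
numFeatures  = 3;    % hypothetical D, rows of each input sequence
numResponses = 2;    % hypothetical R, rows of each response sequence

% Two observations with 10 and 6 time steps.
sequences = {rand(numFeatures,10); rand(numFeatures,6)};
Y         = {rand(numResponses,10); rand(numResponses,6)};

% Each response must match its input sequence in number of time steps.
stepsMatch = cellfun(@(x,y) size(x,2) == size(y,2), sequences, Y);

% Each response must have R rows.
rowsMatch = cellfun(@(y) size(y,1) == numResponses, Y);

formatOK = iscolumn(Y) && all(stepsMatch) && all(rowsMatch);
```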

Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.
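One common way to normalize responses is to standardize them with the mean and standard deviation of the training responses, and to undo the scaling on the predictions. This is a sketch of that idea with hypothetical values; the page does not prescribe a particular normalization, so treat the choice of z-scoring as an assumption.

```matlab
% Hypothetical training responses (for example, rotation angles in degrees).
YTrain = [-40; -10; 0; 15; 35];

mu    = mean(YTrain);
sigma = std(YTrain);
YTrainNorm = (YTrain - mu)/sigma;      % zero mean, unit standard deviation

% After training on YTrainNorm, map predictions back to the original units.
YPredNorm = [0.5; -1];                 % hypothetical normalized predictions
YPred     = YPredNorm*sigma + mu;
```

Keep mu and sigma from the training set and reuse them for validation and test data, so that all responses are on the same scale.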

Responses cannot contain NaNs.

Data Types: cell | categorical | double

Input data, specified as a table containing predictors in the first column and responses in the remaining column or columns. Each row in the table corresponds to an observation.

The arrangement of predictors and responses in the table columns depends on the type of problem.

Classification

| Task | Predictors | Responses |
| --- | --- | --- |
| Image classification | Absolute or relative file path to an image, specified as a character vector; or image specified as a 3-D numeric array | Categorical label |
| Sequence-to-label classification | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | Categorical label |
| Sequence-to-sequence classification | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | Absolute or relative file path to a MAT file. The MAT file must contain a time series represented by a categorical vector, with entries corresponding to labels for each time step. |

For classification problems, if you do not specify responseName, then the function, by default, uses the responses in the second column of tbl.

Regression

| Task | Predictors | Responses |
| --- | --- | --- |
| Image regression | Absolute or relative file path to an image, specified as a character vector; or image specified as a 3-D numeric array | One or more columns of scalar values; numeric vector; or 1-by-1 cell array containing a 3-D numeric array |
| Sequence-to-one regression | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | One or more columns of scalar values; or numeric vector |
| Sequence-to-sequence regression | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | Absolute or relative file path to a MAT file. The MAT file must contain a time series represented by a matrix, where rows correspond to responses and columns correspond to time steps. |

For regression problems, if you do not specify responseName, then the function, by default, uses the remaining columns of tbl. Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.

For sequence classification and regression problems, layers must begin with a sequence input layer.

Responses cannot contain NaNs. If the predictor data contains NaNs, then they are propagated through the training. However, in most cases, the training fails to converge.

Data Types: table

Names of the response variables in the input table, specified as a character vector or cell array of character vectors. For problems with one response, responseName is the corresponding variable name in tbl. For regression problems with multiple response variables, responseName is a cell array of the corresponding variable names in tbl.
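As an illustration of the table input and responseName, the sketch below builds a small image-regression table with the predictors in the first column and two response columns. The file paths, variable names, and values are hypothetical, and the trainNetwork call is shown only as a comment because layers and options are not defined here.

```matlab
% Hypothetical image files and two numeric responses per image.
imagePath     = {'images/img001.png'; 'images/img002.png'; 'images/img003.png'};
rotationAngle = [12.5; -30; 4];
zoomFactor    = [1.0; 0.8; 1.2];

% Predictors go in the first column; each row is one observation.
tbl = table(imagePath,rotationAngle,zoomFactor);

% Multiple response variables are named with a cell array of character vectors.
responseName = {'rotationAngle','zoomFactor'};
isResponse = ismember(tbl.Properties.VariableNames,responseName);

% The table would then be passed to trainNetwork as:
% net = trainNetwork(tbl,responseName,layers,options);
```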

Data Types: char | cell

Network layers, specified as a Layer array or a LayerGraph object.

To create a network with all layers connected sequentially, you can use a Layer array as the input argument. In this case, the returned network is a SeriesNetwork object.

A directed acyclic graph (DAG) network has a complex structure in which layers can have multiple inputs and outputs. To create a DAG network, specify the network architecture as a LayerGraph object and then use that layer graph as the input argument to trainNetwork.
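A minimal sketch of a DAG architecture follows: a small image classification network with one skip connection, built with layerGraph and connectLayers from Deep Learning Toolbox. The layer names and sizes are arbitrary choices for illustration, not a recommended design.

```matlab
% Main branch. Layers in a layer graph need unique names.
layers = [
    imageInputLayer([28 28 1],'Name','input')
    convolution2dLayer(3,8,'Padding','same','Name','conv_1')
    reluLayer('Name','relu_1')
    convolution2dLayer(3,8,'Padding','same','Name','conv_2')
    reluLayer('Name','relu_2')
    additionLayer(2,'Name','add')          % two inputs: add/in1 and add/in2
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')];

% layerGraph connects the array sequentially, so relu_2 feeds add/in1.
lgraph = layerGraph(layers);

% Skip connection: relu_1 also feeds the second input of the addition layer.
% Both inputs are 28-by-28-by-8, so they can be added element-wise.
lgraph = connectLayers(lgraph,'relu_1','add/in2');

% Passing lgraph to trainNetwork would return a DAGNetwork object:
% net = trainNetwork(XTrain,YTrain,lgraph,options);
```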

For a list of built-in layers, see List of Deep Learning Layers.

Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions function. To specify solver and other options for network training, use trainingOptions.

Output Arguments

Trained network, returned as a SeriesNetwork object or a DAGNetwork object.

If you train the network using a Layer array as the layers input argument, then trainedNet is a SeriesNetwork object. If you train the network using a LayerGraph object as the input argument, then trainedNet is a DAGNetwork object.

Training information for each iteration, returned as a structure with a combination of the following fields:

  • TrainingLoss — Loss function value at each iteration

  • TrainingAccuracy — Training accuracy at each iteration

  • TrainingRMSE — Training RMSE at each iteration

  • ValidationLoss — Loss function value for validation data

  • ValidationAccuracy — Validation accuracy

  • ValidationRMSE — Validation RMSE

  • BaseLearnRate — Learning rate at each iteration

trainNetwork returns accuracy values for classification networks, RMSE values for regression networks, and validation metrics when you validate the network during training. Each field is a numeric vector with one element per training iteration. Values that have not been calculated at a specific iteration are represented by NaN.
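Because validation metrics are NaN at iterations where no validation was run, they usually need to be filtered before use. The sketch below works on a hypothetical structure shaped like the training information described above (the values are made up, not produced by trainNetwork) and finds the iteration with the lowest validation loss.

```matlab
% Hypothetical training record for 6 iterations, validated at 3 of them.
traininfo.TrainingLoss   = [2.30 1.90 1.40 1.10 0.90 0.80];
traininfo.ValidationLoss = [2.40 NaN  NaN  1.20 NaN  0.90];

% Keep only the iterations where a validation loss was computed.
valIterations = find(~isnan(traininfo.ValidationLoss));
valLoss       = traininfo.ValidationLoss(valIterations);

% Iteration with the lowest validation loss.
[bestLoss,k]  = min(valLoss);
bestIteration = valIterations(k);
```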

More About

Save Checkpoint Networks and Resume Training

Deep Learning Toolbox™ enables you to save networks as .mat files after each epoch during training. This periodic saving is especially useful when you have a large network or a large data set, and training takes a long time. If training is interrupted, you can resume from the last saved checkpoint network. To have trainNetwork save checkpoint networks, specify the path by using the 'CheckpointPath' name-value pair argument of trainingOptions. If the path that you specify does not exist, then trainingOptions returns an error.

trainNetwork automatically assigns unique names to checkpoint network files. In the example name, net_checkpoint__351__2018_04_12__18_09_52.mat, 351 is the iteration number, 2018_04_12 is the date, and 18_09_52 is the time at which trainNetwork saves the network. You can load a checkpoint network file by double-clicking it or using the load command at the command line. For example:

load net_checkpoint__351__2018_04_12__18_09_52.mat

You can then resume training by using the layers of the network as an input argument to trainNetwork. For example:

trainNetwork(XTrain,YTrain,net.Layers,options)

You must manually specify the training options and the input data, because the checkpoint network does not contain this information. For an example, see Resume Training from Checkpoint Network.
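When several checkpoint files exist, the iteration number and timestamp encoded in the file name can help pick the one to resume from. The following sketch parses the example name shown above in plain MATLAB. It assumes that all checkpoint names follow exactly this pattern, and the variable names are illustrative.

```matlab
name = 'net_checkpoint__351__2018_04_12__18_09_52.mat';

% Read the iteration number followed by year, month, day, hour, minute, second.
parts = sscanf(name,'net_checkpoint__%d__%d_%d_%d__%d_%d_%d.mat')';

iteration = parts(1);                 % 351
savedAt   = datetime(parts(2:7));     % date vector [2018 4 12 18 9 52]
```

Sorting checkpoints by iteration or savedAt then gives the most recent network to load.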

Floating-Point Arithmetic

All functions for deep learning training, prediction, and validation in Deep Learning Toolbox perform computations using single-precision, floating-point arithmetic. Functions for deep learning include trainNetwork, predict, classify, and activations. The software uses single-precision arithmetic when you train networks using both CPUs and GPUs.
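The practical effect of single precision can be seen with ordinary MATLAB arithmetic. This sketch does not call any deep learning function; it only shows how single values behave, which is why results such as the RMSE in the regression example are displayed as single.

```matlab
a = single(0.1);

% Mixing single and double operands produces a single result.
s = a + 0.2;
resultClass = class(s);                  % 'single'

% Spacing of single-precision numbers near 1, compared with double.
epsSingle = eps('single');               % 2^-23
epsDouble = eps('double');               % 2^-52

% Rounding error introduced by storing 0.1 in single precision.
roundoff = abs(double(a) - 0.1);
```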

References

[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pp. 1103–1111.

[2] Kudo, M., J. Toyama, and M. Shimbo. Japanese Vowels Data Set. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Introduced in R2016a