
trainNetwork

Train neural network for deep learning

Use trainNetwork to train a deep learning network. For image classification and regression problems, you can train a convolutional neural network (ConvNet, CNN) or a directed acyclic graph (DAG) network. For sequence and time series classification problems, you can train a long short-term memory (LSTM) network.

You can train a network on a CPU, a GPU, multiple GPUs, or in parallel. Using the GPU, multi-GPU, and parallel options requires Parallel Computing Toolbox™. To use a GPU, you must also have a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. Specify training options, including the execution environment, using trainingOptions.
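As a sketch, you could select the execution environment like this ('ExecutionEnvironment' is a trainingOptions name-value pair):

```matlab
% Select the execution environment. 'gpu', 'multi-gpu', and 'parallel'
% require Parallel Computing Toolbox and a supported NVIDIA GPU;
% 'cpu' and 'auto' are also available.
options = trainingOptions('sgdm', ...
    'ExecutionEnvironment','gpu');
```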

Syntax

trainedNet = trainNetwork(imds,layers,options)
trainedNet = trainNetwork(mbs,layers,options)
trainedNet = trainNetwork(X,Y,layers,options)
trainedNet = trainNetwork(C,Y,layers,options)
trainedNet = trainNetwork(tbl,layers,options)
trainedNet = trainNetwork(tbl,responseName,layers,options)
trainedNet = trainNetwork(tbl,responseNames,layers,options)
[trainedNet,traininfo] = trainNetwork(___)

Description

example

trainedNet = trainNetwork(imds,layers,options) trains a network for image classification problems. imds stores the input image data, layers defines the network architecture, and options defines the training options.

trainedNet = trainNetwork(mbs,layers,options) trains a network for image classification and regression problems. mbs is an augmented image source, denoising image source, or pixel label image source that preprocesses images for deep learning.

example

trainedNet = trainNetwork(X,Y,layers,options) trains a network for image classification and regression problems. X contains the predictor variables and Y contains the categorical labels or numeric responses.

example

trainedNet = trainNetwork(C,Y,layers,options) trains an LSTM network for sequence-to-label and sequence-to-sequence classification problems. C is a cell array containing sequence or time series predictors and Y contains the categorical labels or categorical sequences.

trainedNet = trainNetwork(tbl,layers,options) trains a network for classification and regression problems. tbl contains the predictors and the targets or response variables. The predictors must be in the first column of tbl. For information on the targets or response variables, see the tbl argument description.

trainedNet = trainNetwork(tbl,responseName,layers,options) trains a network for classification and regression problems. The predictors must be in the first column of tbl. The responseName argument specifies the response variable in the table tbl.

trainedNet = trainNetwork(tbl,responseNames,layers,options) trains a network for regression problems. The predictors must be in the first column of tbl. The responseNames argument specifies the response variables in the table tbl.

[trainedNet,traininfo] = trainNetwork(___) also returns information on the training for any of the input arguments.
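For example, you can capture the training information and inspect the recorded loss (a minimal sketch; imds, layers, and options are assumed to be defined as in the examples below):

```matlab
% Capture the training information along with the trained network.
[trainedNet,traininfo] = trainNetwork(imds,layers,options);

% Plot the mini-batch loss recorded at each training iteration.
plot(traininfo.TrainingLoss)
xlabel('Iteration')
ylabel('Mini-batch loss')
```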

Examples

Load the data as an ImageDatastore object.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');
digitData = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,'LabelSource','foldernames');

The datastore contains 10,000 synthetic images of the digits 0–9. The images are generated by applying random transformations to digit images created with different fonts. Each digit image is 28-by-28 pixels.

Display some of the images in the datastore.

figure;
perm = randperm(10000,20);
for i = 1:20
    subplot(4,5,i);
    imshow(digitData.Files{perm(i)});
end

Check the number of images in each digit category.

digitData.countEachLabel
ans =

  10×2 table

    Label    Count
    _____    _____

    0        1000 
    1        1000 
    2        1000 
    3        1000 
    4        1000 
    5        1000 
    6        1000 
    7        1000 
    8        1000 
    9        1000 

The data contains an equal number of images per category.

Divide the data set so that each category in the training set has 750 images and the testing set has the remaining images from each label.

trainingNumFiles = 750;
rng(1) % For reproducibility
[trainDigitData,testDigitData] = splitEachLabel(digitData,...
				trainingNumFiles,'randomize');

splitEachLabel splits the image files in digitData into two new datastores, trainDigitData and testDigitData.

Define the convolutional neural network architecture.

layers = [imageInputLayer([28 28 1]);
          convolution2dLayer(5,20);
          reluLayer();
          maxPooling2dLayer(2,'Stride',2);
          fullyConnectedLayer(10);
          softmaxLayer();
          classificationLayer()];

Set the options to the default settings for stochastic gradient descent with momentum. Set the maximum number of epochs to 20, and start the training with an initial learning rate of 0.0001.

options = trainingOptions('sgdm','MaxEpochs',20,...
	'InitialLearnRate',0.0001);

Train the network.

convnet = trainNetwork(trainDigitData,layers,options);
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.06 |       3.0845 |       13.28% |       0.0001 |
|            1 |           50 |         0.72 |       1.0945 |       65.63% |       0.0001 |
|            2 |          100 |         1.43 |       0.7276 |       74.22% |       0.0001 |
|            3 |          150 |         2.14 |       0.4743 |       83.59% |       0.0001 |
|            4 |          200 |         2.85 |       0.3086 |       91.41% |       0.0001 |
|            5 |          250 |         3.56 |       0.2324 |       92.97% |       0.0001 |
|            6 |          300 |         4.25 |       0.1542 |       97.66% |       0.0001 |
|            7 |          350 |         4.95 |       0.1315 |       97.66% |       0.0001 |
|            7 |          400 |         5.63 |       0.0944 |       96.09% |       0.0001 |
|            8 |          450 |         6.33 |       0.0668 |       98.44% |       0.0001 |
|            9 |          500 |         7.02 |       0.0458 |       99.22% |       0.0001 |
|           10 |          550 |         7.73 |       0.0544 |      100.00% |       0.0001 |
|           11 |          600 |         8.43 |       0.0660 |       99.22% |       0.0001 |
|           12 |          650 |         9.12 |       0.0338 |      100.00% |       0.0001 |
|           13 |          700 |         9.82 |       0.0340 |      100.00% |       0.0001 |
|           13 |          750 |        10.51 |       0.0370 |       99.22% |       0.0001 |
|           14 |          800 |        11.21 |       0.0264 |      100.00% |       0.0001 |
|           15 |          850 |        11.91 |       0.0182 |      100.00% |       0.0001 |
|           16 |          900 |        12.61 |       0.0234 |      100.00% |       0.0001 |
|           17 |          950 |        13.32 |       0.0224 |      100.00% |       0.0001 |
|           18 |         1000 |        14.01 |       0.0160 |      100.00% |       0.0001 |
|           19 |         1050 |        14.70 |       0.0233 |      100.00% |       0.0001 |
|           19 |         1100 |        15.39 |       0.0245 |      100.00% |       0.0001 |
|           20 |         1150 |        16.09 |       0.0154 |      100.00% |       0.0001 |
|           20 |         1160 |        16.23 |       0.0146 |      100.00% |       0.0001 |
|=========================================================================================|

Run the trained network on the test set that was not used to train the network and predict the image labels (digits).

YTest = classify(convnet,testDigitData);
TTest = testDigitData.Labels;

Calculate the accuracy.

accuracy = sum(YTest == TTest)/numel(TTest)
accuracy =

    0.9852

Accuracy is the fraction of test images whose labels predicted by classify match the true labels. In this case, about 98.5% of the predicted digits match the true digit values in the test set.

Load the training data.

load lettersTrainSet

XTrain contains 1500 28-by-28 grayscale images of the letters A, B, and C in a 4-D array. There are equal numbers of each letter in the data set. TTrain contains the categorical array of the letter labels.

Display some of the letter images.

figure;
perm = randperm(1500,20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end

Define the convolutional neural network architecture.

layers = [imageInputLayer([28 28 1]);
          convolution2dLayer(5,16);
          reluLayer();
          maxPooling2dLayer(2,'Stride',2);
          fullyConnectedLayer(3);
          softmaxLayer();
          classificationLayer()];

Set the options to the default settings for stochastic gradient descent with momentum.

options = trainingOptions('sgdm');

Train the network.

rng('default') % For reproducibility
net = trainNetwork(XTrain,TTrain,layers,options);
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         3.55 |       1.0994 |       27.34% |       0.0100 |
|            5 |           50 |         4.66 |       0.2175 |       98.44% |       0.0100 |
|           10 |          100 |         5.42 |       0.0238 |      100.00% |       0.0100 |
|           14 |          150 |         6.18 |       0.0108 |      100.00% |       0.0100 |
|           19 |          200 |         6.93 |       0.0088 |      100.00% |       0.0100 |
|           23 |          250 |         7.68 |       0.0048 |      100.00% |       0.0100 |
|           28 |          300 |         8.44 |       0.0035 |      100.00% |       0.0100 |
|           30 |          330 |         8.88 |       0.0052 |      100.00% |       0.0100 |
|=========================================================================================|

Run the trained network on a test set that was not used to train the network and predict the image labels (letters).

load lettersTestSet;

XTest contains 1500 28-by-28 grayscale images of the letters A, B, and C in a 4-D array. There are again equal numbers of each letter in the data set. TTest contains the categorical array of the letter labels.

YTest = classify(net,XTest);

Calculate the accuracy.

accuracy = sum(YTest == TTest)/numel(TTest)
accuracy =

    0.9273

Load the Japanese Vowels data set, as described in [1] and [2]. X is a cell array containing 270 sequences of varying length, each with 12 features. Y is a categorical vector of labels "1","2",...,"9".

load JapaneseVowelsTrain

The entries in X are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step). Visualize the first time series in a plot. Each subplot corresponds to a feature.

figure
for i = 1:12
    subplot(12,1,13-i)
    plot(X{1}(i,:));
    ylabel(i) 
    xticklabels('')
    yticklabels('')
    box off
end
title("Training Observation 1")
subplot(12,1,12)
xticklabels('auto')
xlabel("Time Step")

Define the LSTM network architecture. Specify the input size to be 12 (the number of features of the input data). Specify an LSTM layer with an output size of 100 that outputs the last element of the sequence. Finally, specify nine classes by including a fully connected layer of size 9, followed by a softmax layer and a classification layer.

inputSize = 12;
outputSize = 100;
outputMode = 'last';
numClasses = 9;
layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(outputSize,'OutputMode',outputMode)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   LSTM                    LSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

Specify the training options. Choose a mini-batch size of 27 and set the maximum number of epochs to 150.

maxEpochs = 150;
miniBatchSize = 27;
options = trainingOptions('sgdm', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize);

Train the LSTM network with the specified training options using trainNetwork.

net = trainNetwork(X,Y,layers,options);
Training on single GPU.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.14 |       2.1972 |        7.41% |       0.0100 |
|            5 |           50 |         7.88 |       2.1972 |       11.11% |       0.0100 |
|           10 |          100 |        15.68 |       2.1967 |       14.81% |       0.0100 |
|           15 |          150 |        23.43 |       2.1961 |       14.81% |       0.0100 |
|           20 |          200 |        31.09 |       2.1951 |       22.22% |       0.0100 |
|           25 |          250 |        38.62 |       2.1936 |       22.22% |       0.0100 |
|           30 |          300 |        46.28 |       2.1909 |       22.22% |       0.0100 |
|           35 |          350 |        53.98 |       2.1862 |       29.63% |       0.0100 |
|           40 |          400 |        61.72 |       2.1767 |       29.63% |       0.0100 |
|           45 |          450 |        69.31 |       2.1541 |       25.93% |       0.0100 |
|           50 |          500 |        77.09 |       2.0688 |       33.33% |       0.0100 |
|           55 |          550 |        84.82 |       1.7619 |       25.93% |       0.0100 |
|           60 |          600 |        92.41 |       1.6602 |       33.33% |       0.0100 |
|           65 |          650 |       100.15 |       1.5891 |       33.33% |       0.0100 |
|           70 |          700 |       108.15 |       1.5246 |       29.63% |       0.0100 |
|           75 |          750 |       115.94 |       1.4408 |       33.33% |       0.0100 |
|           80 |          800 |       123.73 |       1.3130 |       44.44% |       0.0100 |
|           85 |          850 |       131.65 |       1.0004 |       70.37% |       0.0100 |
|           90 |          900 |       139.56 |       0.7564 |       81.48% |       0.0100 |
|           95 |          950 |       147.32 |       0.5176 |       88.89% |       0.0100 |
|          100 |         1000 |       155.05 |       0.4757 |       88.89% |       0.0100 |
|          105 |         1050 |       162.86 |       0.4270 |       88.89% |       0.0100 |
|          110 |         1100 |       170.70 |       0.2822 |       92.59% |       0.0100 |
|          115 |         1150 |       178.57 |       0.2446 |       92.59% |       0.0100 |
|          120 |         1200 |       186.20 |       0.3388 |       88.89% |       0.0100 |
|          125 |         1250 |       193.92 |       0.2865 |       88.89% |       0.0100 |
|          130 |         1300 |       201.75 |       0.4839 |       74.07% |       0.0100 |
|          135 |         1350 |       209.50 |       0.2142 |      100.00% |       0.0100 |
|          140 |         1400 |       217.22 |       0.1650 |       92.59% |       0.0100 |
|          145 |         1450 |       225.12 |       0.1278 |       96.30% |       0.0100 |
|          150 |         1500 |       232.84 |       0.0906 |      100.00% |       0.0100 |
|=========================================================================================|

Load the test set and classify the sequences into speakers.

load JapaneseVowelsTest

Classify the test data. Set the mini-batch size to 27.

miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize);

Calculate the classification accuracy of the predictions.

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.8892

Input Arguments

Images with labels, specified as an ImageDatastore object with categorical labels. You can use an ImageDatastore only for image classification networks.

ImageDatastore allows batch reading of JPEG or PNG image files using prefetching. If you use a custom function for reading the images, then prefetching does not happen.

Image source, specified as one of the following:

  • An augmentedImageSource that preprocesses images for deep learning. For example, an augmented image source can resize, rotate, and reflect input images.

  • A denoisingImageSource that preprocesses images for use in training denoising networks. For example, a denoising image source can add Gaussian noise to input images.

  • A pixelLabelImageSource that specifies inputs and responses when training semantic segmentation networks.

Images, specified as a 4-D numeric array. The first three dimensions must be the height, width, and channels, and the last dimension must index the individual images.

If the array contains NaNs, then they are propagated through training; however, in most cases the training fails to converge.

Data Types: single | double

Sequences or time series data, specified as a cell array of matrices. C is an N-by-1 cell array, where N is the number of observations. Each entry of C is a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps.

For sequence-to-label and sequence-to-sequence classification problems, layers must define an LSTM network that begins with a sequence input layer.

Data Types: cell
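As a sketch of this layout, a cell array for three observations with 12 features and different numbers of time steps might look like this (synthetic values, for illustration only):

```matlab
% Three observations: 12 features each, with 20, 17, and 25 time steps.
C = {rand(12,20); rand(12,17); rand(12,25)};

% For a sequence-to-label problem, one categorical label per observation.
Y = categorical([1; 2; 1]);
```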

Responses, specified as a categorical vector of labels, a matrix, a 4-D numeric array, or a cell array of categorical row vectors.

For image and sequence-to-label classification problems, Y is a categorical vector of labels.

For sequence-to-sequence classification problems, Y is a cell array of categorical row vectors. Each row vector represents a sequence of labels corresponding to the input sequence. The row vectors must have the same number of time steps as the corresponding input sequences.

For sequence-to-label and sequence-to-sequence classification problems, layers must define an LSTM network that begins with a sequence input layer.

For regression problems, Y can be one of the following:

  • N-by-r matrix, where N is the number of observations and r is the number of responses

  • h-by-w-by-c-by-N numeric array, where N is the number of observations and h-by-w-by-c is the size of a single response.

Responses must not contain NaNs.

Data Types: cell | categorical | double

Input data, specified as a table containing predictors in the first column and responses in the remaining column or columns. Each row in the table corresponds to an observation.

The arrangement of the predictors and the responses in the columns of the table depends on the type of problem, as follows.

Predictors

For image classification and regression problems:

  • Absolute or relative file path to an image, specified as a character vector

  • Image specified as a 3-D numeric array

For sequence-to-label and sequence-to-sequence classification problems:

  • Absolute or relative file paths to MAT files containing sequence or time series data. Each MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps.

Responses

For image and sequence-to-label classification problems:

  • A single column of scalar categorical labels

For sequence-to-sequence classification problems:

  • A single column of absolute or relative file paths to MAT files. Each MAT file must contain a time series represented by a vector with entries corresponding to the labels for each time step.

For regression problems:

  • One or more columns of scalar values

  • A single column of numeric vectors

  • A single column of 1-by-1 cell arrays containing a 3-D numeric array

For classification problems, if you do not specify the name of the response variable in the call to trainNetwork, then the function uses the responses in the second column by default. To use the responses in a different column of tbl, use the responseName positional argument. For sequence-to-label and sequence-to-sequence classification problems, layers must define an LSTM network that begins with a sequence input layer.

For regression problems, if you do not specify the names of the response variables in the call to trainNetwork, then the function uses the remaining columns of tbl by default. To use the responses in different columns of tbl, use the responseNames positional argument.

Responses must not contain NaNs. If the predictor data contains NaNs, then they are propagated through training; however, in most cases the training fails to converge.

Data Types: table
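A hypothetical table for image classification might be assembled like this (the file names and labels are placeholders for illustration, not files shipped with the toolbox):

```matlab
% First column: image file paths. Second column: categorical labels.
imageFile = {'digit_0.png'; 'digit_1.png'; 'digit_2.png'};
label = categorical([0; 1; 2]);
tbl = table(imageFile,label);

% By default, trainNetwork uses the second column as the response:
% trainedNet = trainNetwork(tbl,layers,options);
% Or name the response variable explicitly:
% trainedNet = trainNetwork(tbl,'label',layers,options);
```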

Name of response variable for classification and regression problems, specified as a character vector that shows the name of the variable containing the responses in tbl.

Data Types: char

Names of the response variables for regression problems, specified as a cell array of character vectors that show the names of the variables containing the responses in tbl.

Data Types: cell

Network layers, specified as a Layer array or a LayerGraph object.

To train a network with all layers connected sequentially, you can use a Layer array as the input argument. In this case, the returned trained network is a SeriesNetwork object.

A directed acyclic graph (DAG) network is a network with a more complex structure, in which layers can have multiple inputs or outputs. To train a DAG network, specify the network architecture as a LayerGraph object, and then use that layer graph as the input argument to trainNetwork.
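As a sketch, a layer graph with a shortcut (skip) connection could be built like this, assuming the layerGraph, additionLayer, and connectLayers functions available in releases that support DAG networks:

```matlab
% Main branch of the network; layerGraph connects these sequentially.
layers = [
    imageInputLayer([28 28 1],'Name','input')
    convolution2dLayer(3,16,'Padding',1,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,16,'Padding',1,'Name','conv2')
    additionLayer(2,'Name','add')   % two inputs: conv2 and the shortcut
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')];

lgraph = layerGraph(layers);

% Route the relu1 output directly into the second input of the
% addition layer, creating the skip connection.
lgraph = connectLayers(lgraph,'relu1','add/in2');

% trainedNet = trainNetwork(imds,lgraph,options);  % returns a DAGNetwork
```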

Training options, specified as a TrainingOptionsSGDM object returned by the trainingOptions function. SGDM stands for the stochastic gradient descent with momentum solver.

Output Arguments

Trained network, returned as a SeriesNetwork object or a DAGNetwork object.

If you train the network using a Layer array as the layers input argument, then trainedNet is a SeriesNetwork object. If you train the network using a LayerGraph object as the input argument, then trainedNet is a DAGNetwork object.

Training information for each iteration, returned as a structure with a combination of the following fields.

  • TrainingLoss — Loss function value at each iteration

  • TrainingAccuracy — Training accuracy at each iteration

  • TrainingRMSE — Training RMSE at each iteration

  • ValidationLoss — Loss function value for validation data

  • ValidationAccuracy — Validation accuracy

  • ValidationRMSE — Validation RMSE

  • BaseLearnRate — Learning rate at each iteration

trainNetwork returns accuracy values for classification networks, RMSE values for regression networks, and validation metrics when you are validating the network during training. Each field is a numeric vector with one element per training iteration. Values that have not been calculated at a specific iteration are represented by NaN.
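For example, for a classification network trained with validation data specified in trainingOptions, you might inspect where validation metrics were computed (a sketch; XTrain, TTrain, layers, and options are assumed to be defined):

```matlab
[net,traininfo] = trainNetwork(XTrain,TTrain,layers,options);

% Validation fields contain NaN except at validation iterations.
valIterations = find(~isnan(traininfo.ValidationAccuracy));

% Overlay validation accuracy on the per-iteration training accuracy.
plot(traininfo.TrainingAccuracy)
hold on
plot(valIterations,traininfo.ValidationAccuracy(valIterations),'o')
hold off
```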

More About

Save Checkpoint Networks and Resume Training

trainNetwork enables you to save checkpoint networks as .mat files during training. You can then resume training from any checkpoint network. If you want trainNetwork to save checkpoint networks, then you must specify the path by using the CheckpointPath name-value pair argument of trainingOptions. If the path you specify is incorrect, then trainingOptions returns an error.

trainNetwork automatically assigns a unique name to each checkpoint network file, for example, convnet_checkpoint__351__2016_11_09__12_04_23.mat. In this example, 351 is the iteration number, 2016_11_09 is the date, and 12_04_23 is the time at which trainNetwork saves the network. You can load a checkpoint network file by double-clicking it or entering the load command at the command line. For example:

load convnet_checkpoint__351__2016_11_09__12_04_23.mat

You can then resume training by using the layers of this network in the call to trainNetwork. For example:

trainNetwork(Xtrain,Ytrain,net.Layers,options)

You must manually specify the training options and the input data, because the checkpoint network does not contain this information.

References

[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pp. 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Introduced in R2016a
