Update parameters using adaptive moment estimation (Adam)
Update the network learnable parameters in a custom training loop using the adaptive moment estimation (Adam) algorithm.
Note
This function applies the Adam optimization algorithm to update network parameters in custom training loops that use networks defined as dlnetwork objects or model functions. If you want to train a network defined as a Layer array or as a LayerGraph, use the following functions (a brief sketch follows the list):
Create a TrainingOptionsADAM object using the trainingOptions function.
Use the TrainingOptionsADAM object with the trainNetwork function.
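For reference, a minimal sketch of that built-in workflow, assuming layers is a layer array and XTrain and YTrain are training images and labels (the option values shown are the Adam defaults):
options = trainingOptions('adam', ...
    'InitialLearnRate',0.001, ...
    'GradientDecayFactor',0.9, ...
    'SquaredGradientDecayFactor',0.999);
net = trainNetwork(XTrain,YTrain,layers,options);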
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters of the network dlnet using the Adam algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters in params using the Adam algorithm. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined using functions.
[___] = adamupdate(___,learnRate,gradDecay,sqGradDecay,epsilon) also specifies values to use for the global learning rate, gradient decay, square gradient decay, and small constant epsilon, in addition to the input arguments in previous syntaxes.
Update Learnable Parameters Using adamupdate
Perform a single adaptive moment estimation update step with a global learning rate of 0.05, gradient decay factor of 0.75, and squared gradient decay factor of 0.95.
Create the parameters and parameter gradients as numeric arrays.
params = rand(3,3,4);
grad = ones(3,3,4);
Initialize the iteration counter, average gradient, and average squared gradient for the first iteration.
iteration = 1;
averageGrad = [];
averageSqGrad = [];
Specify custom values for the global learning rate, gradient decay factor, and squared gradient decay factor.
learnRate = 0.05;
gradDecay = 0.75;
sqGradDecay = 0.95;
Update the learnable parameters using adamupdate.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
Update the iteration counter.
iteration = iteration + 1;
Train Network Using adamupdate
Use adamupdate to train a network using the Adam algorithm.
Load Training Data
Load the digits training data.
[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);
Define Network
Define the network and specify the average image value using the 'Mean' option in the image input layer.
layers = [
imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
convolution2dLayer(5,20,'Name','conv1')
reluLayer('Name', 'relu1')
convolution2dLayer(3,20,'Padding',1,'Name','conv2')
reluLayer('Name','relu2')
convolution2dLayer(3,20,'Padding',1,'Name','conv3')
reluLayer('Name','relu3')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);
Create a dlnetwork object from the layer graph.
dlnet = dlnetwork(lgraph);
Define Model Gradients Function
Create the helper function modelGradients, listed at the end of the example. The function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).
executionEnvironment = "auto";
Visualize the training progress in a plot.
plots = "training-progress";
Train Network
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the adamupdate function. At the end of each epoch, display the training progress.
Initialize the training progress plot.
if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end
Initialize the average gradients and squared average gradients.
averageGrad = [];
averageSqGrad = [];
Train the network.
iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);

    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to a dlarray.
        dlX = dlarray(single(X),'SSCB');

        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);

        % Update the network parameters using the Adam optimizer.
        [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest, YTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format 'SSCB'. For GPU prediction, also convert the data to a gpuArray.
dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YPred==YTest)
accuracy = 0.9896
Model Gradients Function
The modelGradients helper function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet. To compute the gradients automatically, use the dlgradient function.
function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end
Input Arguments

dlnet — Network
dlnetwork object

Network, specified as a dlnetwork object.
The function updates the dlnet.Learnables property of the dlnetwork object. dlnet.Learnables is a table with three variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
The input argument grad must be a table of the same form as dlnet.Learnables.
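For example, you can inspect this table directly. A sketch of what the first rows would contain for the network trained in the example above (exact sizes depend on your layers):
dlnet.Learnables
%   Layer      Parameter    Value
%   "conv1"    "Weights"    {5x5x1x20 dlarray}
%   "conv1"    "Bias"       {1x1x20 dlarray}
%   ...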
params — Network learnable parameters
dlarray | numeric array | cell array | structure | table

Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:
Layer — Layer name, specified as a string scalar.
Parameter — Parameter name, specified as a string scalar.
Value — Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Data Types: single | double | struct | table | cell
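For example, a minimal sketch of a valid params structure for a model defined using functions (the field name fc1 and the sizes are illustrative, not required):
% Learnable parameters for a hypothetical one-layer model function.
params = struct;
params.fc1.Weights = dlarray(0.01*randn(10,784));
params.fc1.Bias = dlarray(zeros(10,1));
A grad input for this params must then be a structure with the same fc1.Weights and fc1.Bias fields.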
grad — Gradients of the loss
dlarray | numeric array | cell array | structure | table

Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to adamupdate.
| Input | Learnable Parameters | Gradients |
|---|---|---|
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
You can obtain grad from a call to dlfeval that
evaluates a function that contains a call to dlgradient.
For more information, see Use Automatic Differentiation In Deep Learning Toolbox.
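As a minimal sketch of that pattern (exampleLoss is a hypothetical stand-in for your model loss function, not part of this toolbox):
params = dlarray(rand(3,3));
[loss,grad] = dlfeval(@exampleLoss,params);

function [loss,grad] = exampleLoss(p)
    % dlgradient must be called inside a function evaluated by dlfeval.
    loss = sum(p.^2,'all');
    grad = dlgradient(loss,p);
end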
averageGrad — Moving average of parameter gradients
[] | dlarray | numeric array | cell array | structure | table

Moving average of parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of averageGrad depends on the input network or learnable parameters. The following table shows the required format for averageGrad for possible inputs to adamupdate.
| Input | Learnable Parameters | Average Gradients |
|---|---|---|
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |
If you specify averageGrad and averageSqGrad
as empty arrays, the function assumes no previous gradients and runs in the same way as
for the first update in a series of iterations. To update the learnable parameters
iteratively, use the averageGrad output of a previous call to
adamupdate as the averageGrad input.
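A sketch of threading this state through successive calls (numIterations is illustrative, and the gradient computation, typically a dlfeval call, is elided):
averageGrad = [];
averageSqGrad = [];
for iteration = 1:numIterations
    % grad is recomputed here on each iteration, for example using dlfeval.
    [params,averageGrad,averageSqGrad] = adamupdate(params,grad, ...
        averageGrad,averageSqGrad,iteration);
end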
averageSqGrad — Moving average of squared parameter gradients
[] | dlarray | numeric array | cell array | structure | table

Moving average of squared parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of averageSqGrad depends on the input network or learnable parameters. The following table shows the required format for averageSqGrad for possible inputs to adamupdate.
| Input | Learnable Parameters | Average Squared Gradients |
|---|---|---|
| dlnet | Table dlnet.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as dlnet.Learnables. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |
If you specify averageGrad and averageSqGrad
as empty arrays, the function assumes no previous gradients and runs in the same way as
for the first update in a series of iterations. To update the learnable parameters
iteratively, use the averageSqGrad output of a previous call to
adamupdate as the averageSqGrad input.
iteration — Iteration number
positive integer

Iteration number, specified as a positive integer. For the first call to adamupdate, use a value of 1. You must increment iteration by 1 for each subsequent call in a series of calls to adamupdate. The Adam algorithm uses this value to correct for bias in the moving averages at the beginning of a set of iterations.
learnRate — Global learning rate
0.001 (default) | positive scalar

Global learning rate, specified as a positive scalar. The default value of learnRate is 0.001.
If you specify the network parameters as a dlnetwork object, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
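For example, the following layer definition doubles the learning rate applied to that layer's weights relative to the global learnRate (WeightLearnRateFactor is a built-in layer property; the filter sizes are illustrative):
layer = convolution2dLayer(3,16,'WeightLearnRateFactor',2);
% With learnRate = 0.001, adamupdate applies an effective learning rate
% of 0.002 to this layer's weights.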
gradDecay — Gradient decay factor
0.9 (default) | positive scalar between 0 and 1

Gradient decay factor, specified as a positive scalar between 0 and 1. The default value of gradDecay is 0.9.
sqGradDecay — Squared gradient decay factor
0.999 (default) | positive scalar between 0 and 1

Squared gradient decay factor, specified as a positive scalar between 0 and 1. The default value of sqGradDecay is 0.999.
epsilon — Small constant
1e-8 (default) | positive scalar

Small constant for preventing divide-by-zero errors, specified as a positive scalar. The default value of epsilon is 1e-8.
Output Arguments

dlnet — Updated network
dlnetwork object

Updated network, returned as a dlnetwork object.
The function updates the dlnet.Learnables property of the dlnetwork object.
params — Updated network learnable parameters
dlarray | numeric array | cell array | structure | table

Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
averageGrad — Updated moving average of parameter gradients
dlarray | numeric array | cell array | structure | table

Updated moving average of parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
averageSqGrad — Updated moving average of squared parameter gradients
dlarray | numeric array | cell array | structure | table

Updated moving average of squared parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Algorithms

The function uses the adaptive moment estimation (Adam) algorithm to update the learnable parameters. For more information, see the definition of the Adam algorithm under Stochastic Gradient Descent on the trainingOptions reference page.
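For reference, a standard statement of the Adam update, written with this function's argument names (learnRate = $\alpha$, gradDecay = $\beta_1$, sqGradDecay = $\beta_2$, epsilon = $\epsilon$; see the trainingOptions page for the exact form the toolbox uses). With gradient $g_t$ at iteration $t$, and with averageGrad ($m$) and averageSqGrad ($v$) starting from zero:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\theta_t &= \theta_{t-1} - \alpha\, \frac{m_t/(1-\beta_1^{\,t})}{\sqrt{v_t/(1-\beta_2^{\,t})} + \epsilon}
\end{aligned}
$$

The bias-correction terms $1-\beta_1^{\,t}$ and $1-\beta_2^{\,t}$ are why adamupdate requires the iteration argument.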
Usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU.
grad
averageGrad
averageSqGrad
params
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
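A minimal sketch of triggering GPU execution (requires Parallel Computing Toolbox and a supported device):
params = dlarray(gpuArray(rand(3,3)));
grad = dlarray(gpuArray(ones(3,3)));
% Empty averages and iteration 1 indicate the first update in a series.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,[],[],1);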
See Also
dlarray | dlfeval | dlgradient | dlnetwork | dlupdate | forward | rmspropupdate | sgdmupdate