adamupdate

Update parameters using adaptive moment estimation (Adam)

Description

Update the network learnable parameters in a custom training loop using the adaptive moment estimation (Adam) algorithm.

Note

This function applies the Adam optimization algorithm to update network parameters in custom training loops that use networks defined as dlnetwork objects or model functions. To train a network defined as a Layer array or as a LayerGraph, use the trainNetwork function and specify the Adam solver with the trainingOptions function instead.


[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters of the network dlnet using the Adam algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.


[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters in params using the Adam algorithm. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined using functions.
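As an illustrative sketch of this pattern (the field names W and b, and the functions modelGradients and numIterations, are assumptions for illustration, not requirements of adamupdate):

```matlab
% Sketch only: params as a structure of learnable parameters.
params.W = dlarray(0.01*randn(10,784));
params.b = dlarray(zeros(10,1));

averageGrad = [];     % empty state before the first update
averageSqGrad = [];

for iteration = 1:numIterations
    % grad is a structure with the same fields as params, typically
    % computed by dlfeval and dlgradient in a model gradients function.
    [grad,loss] = dlfeval(@modelGradients,params,dlX,T);
    [params,averageGrad,averageSqGrad] = adamupdate(params,grad, ...
        averageGrad,averageSqGrad,iteration);
end
```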


[___] = adamupdate(___,learnRate,gradDecay,sqGradDecay,epsilon) also specifies values to use for the global learning rate, gradient decay factor, squared gradient decay factor, and small constant epsilon, in addition to the input arguments in previous syntaxes.

Examples


Perform a single adaptive moment estimation update step with a global learning rate of 0.05, gradient decay factor of 0.75, and squared gradient decay factor of 0.95.

Create the parameters and parameter gradients as numeric arrays.

params = rand(3,3,4);
grad = ones(3,3,4);

Initialize the iteration counter, average gradient, and average squared gradient for the first iteration.

iteration = 1;
averageGrad = [];
averageSqGrad = [];

Specify custom values for the global learning rate, gradient decay factor, and squared gradient decay factor.

learnRate = 0.05;
gradDecay = 0.75;
sqGradDecay = 0.95;

Update the learnable parameters using adamupdate.

[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);

Update the iteration counter.

iteration = iteration + 1;

Use adamupdate to train a network using the adaptive moment estimation (Adam) algorithm.

Load Training Data

Load the digits training data.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Define the Network

Define the network and specify the average image using the 'Mean' option in the image input layer.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];
lgraph = layerGraph(layers);

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph);

Define the Model Gradients Function

Create the function modelGradients, listed at the end of the example, that takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet.

Specify Training Options

Specify the options to use during training.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.

executionEnvironment = "auto";

Initialize the average gradients and average squared gradients.

averageGrad = [];
averageSqGrad = [];

Initialize iteration counter.

iteration = 1;

Initialize the training progress plot.

plots = "training-progress";
if plots == "training-progress"
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end

Train the Network

Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the adamupdate function. At the end of each epoch, display the training progress.

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the Adam optimizer.
        [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Loss During Training: Epoch - " + epoch + "; Iteration - " + i)
            drawnow
        end
        
        % Increment iteration counter
        iteration = iteration + 1;
    end
end

Test the Network

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

[XTest, YTest] = digitTest4DArrayData;

Convert the data to a dlarray object with dimension format 'SSCB'. For GPU prediction, also convert the data to gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Evaluate the classification accuracy.

accuracy = mean(YPred==YTest)
accuracy = 0.9908

Model Gradients Function

The modelGradients function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet. To compute the gradients automatically, use the dlgradient function.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

    dlYPred = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Input Arguments


dlnet — Network, specified as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object. dlnet.Learnables is a table with three variables:

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.

The input argument grad must be a table of the same form as dlnet.Learnables.

params — Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the following three variables:

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.

You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.

The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.

Data Types: single | double | struct | table | cell
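For example, a minimal sketch (with illustrative values, not from this page) of params as a nested cell array, with grad mirroring its nesting exactly:

```matlab
% params and grad must share the same nesting, ordering, and data types.
params = {dlarray(randn(5,5)), {dlarray(zeros(5,1)), dlarray(1)}};
grad   = {dlarray(ones(5,5)),  {dlarray(ones(5,1)),  dlarray(1)}};

% First iteration: pass empty arrays for the moving-average state.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,[],[],1);
```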

grad — Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to adamupdate.

Input: dlnet
  Learnable parameters — the table dlnet.Learnables, containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray.
  Gradients — a table with the same data type, variables, and ordering as dlnet.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter.

Input: params, specified as one of the following:
  • dlarray — grad must be a dlarray with the same data type and ordering as params.
  • Numeric array — grad must be a numeric array with the same data type and ordering as params.
  • Cell array — grad must be a cell array with the same data types, structure, and ordering as params.
  • Structure — grad must be a structure with the same data types, fields, and ordering as params.
  • Table with Layer, Parameter, and Value variables, where the Value variable consists of cell arrays that contain each learnable parameter as a dlarray — grad must be a table with the same data types, variables, and ordering as params, with a Value variable consisting of cell arrays that contain the gradient of each learnable parameter.

You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient. For more information, see Use Automatic Differentiation In Deep Learning Toolbox.

averageGrad — Moving average of parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of averageGrad depends on the input network or learnable parameters. The following table shows the required format for averageGrad for possible inputs to adamupdate.

Input: dlnet
  Learnable parameters — the table dlnet.Learnables, containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray.
  Average gradients — a table with the same data type, variables, and ordering as dlnet.Learnables. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter.

Input: params, specified as one of the following:
  • dlarray — averageGrad must be a dlarray with the same data type and ordering as params.
  • Numeric array — averageGrad must be a numeric array with the same data type and ordering as params.
  • Cell array — averageGrad must be a cell array with the same data types, structure, and ordering as params.
  • Structure — averageGrad must be a structure with the same data types, fields, and ordering as params.
  • Table with Layer, Parameter, and Value variables, where the Value variable consists of cell arrays that contain each learnable parameter as a dlarray — averageGrad must be a table with the same data types, variables, and ordering as params, with a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter.

If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageGrad output of a previous call to adamupdate as the averageGrad input.
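For example, a minimal sketch of threading the state through successive calls (placeholder values for illustration):

```matlab
params = rand(3,3);
averageGrad = [];      % no history before the first update
averageSqGrad = [];

for iteration = 1:3
    grad = ones(3,3);  % placeholder gradient for illustration
    % Feed the state outputs of each call back in as inputs to the next.
    [params,averageGrad,averageSqGrad] = adamupdate(params,grad, ...
        averageGrad,averageSqGrad,iteration);
end
```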

averageSqGrad — Moving average of squared parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of averageSqGrad depends on the input network or learnable parameters. The following table shows the required format for averageSqGrad for possible inputs to adamupdate.

Input: dlnet
  Learnable parameters — the table dlnet.Learnables, containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray.
  Average squared gradients — a table with the same data type, variables, and ordering as dlnet.Learnables. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter.

Input: params, specified as one of the following:
  • dlarray — averageSqGrad must be a dlarray with the same data type and ordering as params.
  • Numeric array — averageSqGrad must be a numeric array with the same data type and ordering as params.
  • Cell array — averageSqGrad must be a cell array with the same data types, structure, and ordering as params.
  • Structure — averageSqGrad must be a structure with the same data types, fields, and ordering as params.
  • Table with Layer, Parameter, and Value variables, where the Value variable consists of cell arrays that contain each learnable parameter as a dlarray — averageSqGrad must be a table with the same data types, variables, and ordering as params, with a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter.

If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageSqGrad output of a previous call to adamupdate as the averageSqGrad input.

iteration — Iteration number, specified as a positive integer. For the first call to adamupdate, use a value of 1. You must increment iteration by 1 for each subsequent call in a series of calls to adamupdate. The Adam algorithm uses this value to correct for bias in the moving averages at the beginning of a set of iterations.

learnRate — Global learning rate, specified as a positive scalar. The default value of learnRate is 0.001.

If you specify the network parameters as a dlnetwork, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
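For example, a layer created with a nonzero learning rate factor property trains that parameter at a multiple of the global rate (an illustrative sketch; see the individual layer reference pages for the exact property names):

```matlab
% This layer's weights update with an effective rate of 2*learnRate;
% its biases use the default factor of 1, so they update at learnRate.
layer = convolution2dLayer(3,16,'WeightLearnRateFactor',2);
```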

gradDecay — Gradient decay factor, specified as a positive scalar between 0 and 1. The default value of gradDecay is 0.9.

sqGradDecay — Squared gradient decay factor, specified as a positive scalar between 0 and 1. The default value of sqGradDecay is 0.999.

epsilon — Small constant for preventing divide-by-zero errors, specified as a positive scalar. The default value of epsilon is 1e-8.

Output Arguments


dlnet — Updated network, returned as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object.

params — Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.

averageGrad — Updated moving average of parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.

averageSqGrad — Updated moving average of squared parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.

More About


Adam

The function uses the adaptive moment estimation (Adam) algorithm to update the learnable parameters. For more information, see the definition of the Adam algorithm under Stochastic Gradient Descent on the trainingOptions reference page.
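For reference, the standard Adam formulation (Kingma and Ba, 2015) maps onto this function's inputs with learnRate as $\alpha$, gradDecay as $\beta_1$, sqGradDecay as $\beta_2$, and epsilon as $\varepsilon$. At iteration $t$, with gradient $g_t$ and parameters $\theta$, it computes:

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t = \theta_{t-1} - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon}
```

The bias-correction factors $1-\beta_1^t$ and $1-\beta_2^t$ are why adamupdate requires the iteration input; see the trainingOptions page for the exact variant the toolbox documents.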

Introduced in R2019b