Specify Training Options in Custom Training Loop

For most tasks, you can control the training algorithm details using the trainingOptions and trainNetwork functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom learn rate schedule), then you can define your own custom training loop using automatic differentiation.

To specify the same options as the trainingOptions function, use these examples as a guide:

Solver Options

To specify the solver, use the adamupdate, rmspropupdate, and sgdmupdate functions for the update step in your training loop. To implement your own custom solver, update the learnable parameters using the dlupdate function.
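The following is a minimal sketch of a custom solver: a plain stochastic gradient descent step (no momentum) applied to every learnable parameter with dlupdate. It assumes Deep Learning Toolbox is installed and that dlupdate accepts structures of numeric arrays. The structure fields Weights and Bias, their values, and learnRate are hypothetical stand-ins for your own learnable parameters and gradients.

```matlab
% Hypothetical learnable parameters and their gradients, stored as structures
params.Weights = [0.5 -1.0; 2.0 0.0];
params.Bias = [0.1; -0.2];
gradients.Weights = [1 1; -1 0];
gradients.Bias = [0.5; 0.5];

learnRate = 0.1;

% Custom update rule w <- w - learnRate*g; dlupdate applies it to each
% parameter, pairing it with the gradient in the matching field
sgdStep = @(w,g) w - learnRate.*g;
params = dlupdate(sgdStep,params,gradients);
```

For a dlnetwork object, the same call pattern applies with the network and its gradients table in place of the structures.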

Adaptive Moment Estimation (ADAM)

To update your network parameters using Adam, use the adamupdate function. Specify the gradient decay and the squared gradient decay factors using the corresponding input arguments.
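As a minimal sketch, assuming Deep Learning Toolbox is installed, the following performs one Adam step on a plain numeric array. The parameter array, the all-ones gradient, and the factor values are hypothetical; in a training loop, the first argument would typically be your dlnetwork object and the gradient would come from your model gradients function.

```matlab
% Hypothetical parameters and gradients, stored as numeric arrays
params = rand(3,3,4);
grad = ones(3,3,4);
paramsBefore = params;

% Solver state: empty before the first update
averageGrad = [];
averageSqGrad = [];

learnRate = 0.05;
gradDecay = 0.75;     % gradient decay factor
sqGradDecay = 0.95;   % squared gradient decay factor
iteration = 1;        % iteration number, used by Adam's bias correction

[params,averageGrad,averageSqGrad] = adamupdate(params,grad, ...
    averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
```

In a real loop, pass averageGrad and averageSqGrad back in on every later call and increment iteration each time.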

Root Mean Square Propagation (RMSProp)

To update your network parameters using RMSProp, use the rmspropupdate function. Specify the denominator offset (epsilon) value using the corresponding input argument.
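A minimal sketch of one RMSProp step on a numeric array, assuming Deep Learning Toolbox is installed. The parameter array, gradient, and option values are hypothetical. The denominator offset epsilon is the last input argument, after the squared gradient decay factor.

```matlab
% Hypothetical parameters and gradients, stored as numeric arrays
params = rand(3,3,4);
grad = ones(3,3,4);
paramsBefore = params;

averageSqGrad = [];   % solver state: empty before the first update

learnRate = 0.05;
sqGradDecay = 0.9;    % squared gradient decay factor
epsilon = 1e-6;       % denominator offset

[params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad, ...
    learnRate,sqGradDecay,epsilon);
```

With an all-ones gradient and an empty starting state, the moving average of the squared gradient after one step is 1 - sqGradDecay = 0.1.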

Stochastic Gradient Descent with Momentum (SGDM)

To update your network parameters using SGDM, use the sgdmupdate function. Specify the momentum using the corresponding input argument.
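A minimal sketch of one SGDM step on a numeric array, assuming Deep Learning Toolbox is installed. The values of params, grad, learnRate, and momentum are hypothetical. Because the velocity starts empty, the first step reduces to a plain gradient step of size learnRate.

```matlab
% Hypothetical parameters and gradients, stored as numeric arrays
params = rand(3,3,4);
grad = ones(3,3,4);
paramsBefore = params;

vel = [];             % parameter velocity: empty before the first update

learnRate = 0.05;
momentum = 0.95;      % contribution of the previous step

[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);
```

On later iterations, pass vel back in so that the momentum term accumulates.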

Learn Rate

To specify the learn rate, use the learn rate input arguments of the adamupdate, rmspropupdate, and sgdmupdate functions.

To easily adjust the learn rate or use it for custom learn rate schedules, set the initial learn rate before the custom training loop.

learnRate = 0.01;

Piecewise Learn Rate Schedule

To automatically drop the learn rate during training using a piecewise learn rate schedule, multiply the learn rate by a given drop factor after a specified interval.

To easily specify a piecewise learn rate schedule, create the variables learnRate, learnRateSchedule, learnRateDropFactor, and learnRateDropPeriod, where learnRate is the initial learn rate, learnRateSchedule contains either "piecewise" or "none", learnRateDropFactor is a scalar in the range [0, 1] that specifies the factor by which to drop the learn rate, and learnRateDropPeriod is a positive integer that specifies the number of epochs between learn rate drops.

learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

Inside the training loop, at the end of each epoch, drop the learn rate when the learnRateSchedule option is "piecewise" and the current epoch number is a multiple of learnRateDropPeriod. Set the new learn rate to the product of the learn rate and the learn rate drop factor.

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end
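To check the schedule without training, the following sketch runs the same drop rule over 30 epochs and records the learn rate in effect during each epoch. The variable learnRateUsed is introduced here for illustration only, and the mini-batch updates are omitted.

```matlab
learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
maxEpochs = 30;

learnRateUsed = zeros(1,maxEpochs);   % learn rate used during each epoch
for epoch = 1:maxEpochs
    learnRateUsed(epoch) = learnRate;
    % ... mini-batch updates would use learnRate here ...

    % End of epoch: drop the learn rate every learnRateDropPeriod epochs
    if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
        learnRate = learnRate * learnRateDropFactor;
    end
end
```

With these settings, epochs 1 to 10 use 0.01, epochs 11 to 20 use 0.001, and epochs 21 to 30 use 0.0001.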

Plots

To plot the training loss and accuracy during training, calculate the mini-batch loss and either the accuracy (for classification tasks) or the root-mean-squared error (RMSE) (for regression tasks) in the model gradients function, and plot them using animated lines.

To easily specify that the plot should be on or off, create the variable plots that contains either "training-progress" or "none". To also plot validation metrics, use the same options validationData and validationFrequency described in Validation.

plots = "training-progress";

validationData = {XValidation, YValidation};
validationFrequency = 50;

Before training, initialize the animated lines using the animatedline function. For classification tasks, create plots for the training accuracy and the training loss. When validation data is specified, also initialize animated lines for the validation metrics.

if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")

    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")

    % Add the validation lines to the same two axes (dashed, with markers)
    if ~isempty(validationData)
        subplot(2,1,1)
        lineAccuracyValidation = animatedline("LineStyle","--","Marker","o");

        subplot(2,1,2)
        lineLossValidation = animatedline("LineStyle","--","Marker","o");
    end
end

For regression tasks, adjust the code by changing the variable names and labels so that it initializes plots for the training and validation RMSE instead of the training and validation accuracy.

Inside the training loop, at the end of an iteration, update the plot so that it includes the appropriate metrics for the network. For classification tasks, add points corresponding to the mini-batch accuracy and the mini-batch loss. If the validation data is nonempty, and the current iteration is either 1 or a multiple of the validation frequency option, then also add points for the validation data.

if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
where accuracyTrain and lossTrain correspond to the mini-batch accuracy and loss calculated in the model gradients function. For regression tasks, use the mini-batch RMSE instead of the mini-batch accuracy.
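The following self-contained sketch exercises the same pattern with an off-screen figure, plotting only the loss. The loss values are made-up exponential curves for illustration, not training results. The getpoints calls at the end read back the plotted x-coordinates, which shows that validation points appear only at iteration 1 and at multiples of validationFrequency.

```matlab
numIterations = 100;
validationFrequency = 50;

fig = figure("Visible","off");   % off-screen figure, for illustration only
lineLossTrain = animatedline;
lineLossValidation = animatedline("LineStyle","--","Marker","o");
xlabel("Iteration")
ylabel("Loss")

for iteration = 1:numIterations
    lossTrain = exp(-iteration/40);              % hypothetical mini-batch loss
    addpoints(lineLossTrain,iteration,lossTrain)

    if iteration == 1 || mod(iteration,validationFrequency) == 0
        lossValidation = 1.1*exp(-iteration/40); % hypothetical validation loss
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
drawnow

% Read back the x-coordinates of each line
[xTrain,~] = getpoints(lineLossTrain);
[xValidation,~] = getpoints(lineLossValidation);
```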

Tip

The addpoints function requires the data points to have type double. To extract numeric data from dlarray objects, use the extractdata function. To collect data from a GPU, use the gather function.
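For example, the following sketch converts a loss stored as a dlarray into a double scalar that addpoints accepts. It assumes Deep Learning Toolbox is installed; the loss value is hypothetical.

```matlab
% Hypothetical mini-batch loss stored as a single-precision dlarray
loss = dlarray(single(0.25));

% extractdata returns the underlying numeric data (single here), gather
% collects it from the GPU if it is a gpuArray, and double converts the type
lossTrain = double(gather(extractdata(loss)));
```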

To learn how to compute validation metrics, see Validation.

Verbose Output

To display the training loss and accuracy during training in a verbose table, calculate the mini-batch loss and either the accuracy (for classification tasks) or the RMSE (for regression tasks) in the model gradients function and display them using the disp function.

To easily specify that the verbose table should be on or off, create the variables verbose and verboseFrequency, where verbose is true or false and verboseFrequency specifies how many iterations between printing verbose output. To display validation metrics, use the same options validationData and validationFrequency described in Validation.

verbose = true;
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

Before training, display the verbose output table headings and initialize a timer using the tic function.

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

For regression tasks, adjust the headings so that the table displays the training and validation RMSE instead of the training and validation accuracy.

Inside the training loop, at the end of an iteration, print the verbose output when the verbose option is true and it is either the first iteration or the iteration number is a multiple of verboseFrequency.

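if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0)
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    % Leave the validation columns blank when the network was not validated at this iteration
    if isempty(validationData) || ~(iteration == 1 || mod(iteration,validationFrequency) == 0)
        accuracyValidation = "";
        lossValidation = "";
    end

    % pad operates on text, so convert each value using string
    % (extract numeric data from dlarray values first, as described in the Tip)
    disp("| " + ...
        pad(string(epoch),7,'left') + " | " + ...
        pad(string(iteration),11,'left') + " | " + ...
        pad(string(D),14,'left') + " | " + ...
        pad(string(accuracyTrain),12,'left') + " | " + ...
        pad(string(accuracyValidation),12,'left') + " | " + ...
        pad(string(lossTrain),12,'left') + " | " + ...
        pad(string(lossValidation),12,'left') + " | " + ...
        pad(string(learnRate),15,'left') + " |")
end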

For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.
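To check that a row lines up with the 120-character table border, the following sketch builds a single row from hypothetical values without running any training. Because pad operates on text, each value is converted using string first. Empty strings stand in for validation metrics at an iteration without validation.

```matlab
% Hypothetical values for one row of the verbose table
epoch = 1;
iteration = 50;
D = duration(0,0,83,'Format','hh:mm:ss');
accuracyTrain = 0.875;
accuracyValidation = "";   % not validated at this iteration
lossTrain = 0.4321;
lossValidation = "";
learnRate = 0.01;

row = "| " + ...
    pad(string(epoch),7,'left') + " | " + ...
    pad(string(iteration),11,'left') + " | " + ...
    pad(string(D),14,'left') + " | " + ...
    pad(string(accuracyTrain),12,'left') + " | " + ...
    pad(string(accuracyValidation),12,'left') + " | " + ...
    pad(string(lossTrain),12,'left') + " | " + ...
    pad(string(lossValidation),12,'left') + " | " + ...
    pad(string(learnRate),15,'left') + " |";
disp(row)
```

The column widths (7, 11, 14, 12, 12, 12, 12, and 15 characters plus separators) add up to the same width as the headings, provided that no value is longer than its column.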

When training is finished, print the last border of the verbose table.

disp("|======================================================================================================================|")

To learn how to compute validation metrics, see Validation.

Mini-Batch Size

How you set the mini-batch size depends on the format of the data or the type of datastore used.

To easily specify the mini-batch size, create a variable miniBatchSize.

miniBatchSize = 128;

For data in an image datastore, before training, set the ReadSize property of the datastore to the mini-batch size.

imds.ReadSize = miniBatchSize;

For data in an augmented image datastore, before training, set the MiniBatchSize property of the datastore to the mini-batch size.

augimds.MiniBatchSize = miniBatchSize;

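For in-memory data, during training at the start of each iteration, read the observations directly from the array. Because iteration counts iterations across all epochs, index the data using the mini-batch number within the current epoch (here, i), and limit the range so that the last mini-batch does not read past the end of the data.

numObservations = size(XTrain,4);
idx = ((i - 1)*miniBatchSize + 1):min(i*miniBatchSize,numObservations);
X = XTrain(:,:,:,idx);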
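The following sketch runs one epoch over hypothetical in-memory image data (random values, observations along the fourth dimension). It reads mini-batches by their index within the epoch and truncates the final mini-batch so the indices stay in bounds. The variables numIterationsPerEpoch and visited are introduced here for illustration only.

```matlab
% Hypothetical in-memory data: 1000 grayscale 28-by-28 images
XTrain = rand(28,28,1,1000);
miniBatchSize = 128;

numObservations = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservations/miniBatchSize);

visited = [];   % observation indices read during this epoch (for checking only)
for i = 1:numIterationsPerEpoch
    % Index range for mini-batch i, truncated for the last mini-batch
    idx = ((i - 1)*miniBatchSize + 1):min(i*miniBatchSize,numObservations);
    X = XTrain(:,:,:,idx);
    visited = [visited idx]; %#ok<AGROW>
end
```

With 1000 observations and a mini-batch size of 128, the epoch has 8 iterations and the last mini-batch contains 104 observations.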

Number of Epochs

Specify the maximum number of epochs for training in the outer for loop of the training loop.

To easily specify the maximum number of epochs, create the variable maxEpochs that contains the maximum number of epochs.

maxEpochs = 30;

In the outer for loop of the training loop, loop over the epochs 1, 2, …, maxEpochs.

for epoch = 1:maxEpochs
    ...
end

Validation

To validate your network during training, set aside a held-out validation set and evaluate how well the network performs on that data.

To easily specify validation options, create the variables validationData and validationFrequency, where validationData contains the validation data or is empty and validationFrequency specifies how many iterations between validating the network.

validationData = {XValidation,YValidation};
validationFrequency = 50;

During the training loop, after updating the network parameters, test how well the network performs on the held-out validation set using the predict function. Validate the network only when validation data is specified and it is either the first iteration or the current iteration is a multiple of the validationFrequency option.

if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation),YValidation);

    % Index of the highest score for each observation (classes along dimension 1)
    [~,idx] = max(extractdata(dlYPredValidation),[],1);
    labelsPredValidation = classNames(gather(idx));

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
Here, dlXValidation contains the validation predictors XValidation as a dlarray object, and YValidation is a one-hot encoded (dummy) variable corresponding to the labels in classNames. To calculate the accuracy, convert YValidation to the array of labels labelsValidation.
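The label conversion and accuracy calculation need only base MATLAB. The following sketch uses hypothetical class names, one-hot targets, and network scores for four observations, with classes along the rows and observations along the columns.

```matlab
classNames = categorical(["cat" "dog" "fish"]);

% Hypothetical one-hot (dummy) targets: one column per observation
YValidation = [1 0 0 1; 0 1 0 0; 0 0 1 0];

% Hypothetical network scores for the same observations
scores = [0.7 0.2 0.1 0.3; 0.2 0.5 0.3 0.6; 0.1 0.3 0.6 0.1];

% Convert the one-hot targets to labels
[~,idxTrue] = max(YValidation,[],1);
labelsValidation = classNames(idxTrue);

% Convert the scores to predicted labels
[~,idxPred] = max(scores,[],1);
labelsPredValidation = classNames(idxPred);

accuracyValidation = mean(labelsPredValidation == labelsValidation);
```

Three of the four predicted labels match the targets, so accuracyValidation is 0.75.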

For regression tasks, adjust the code so that it calculates the validation RMSE instead of the validation accuracy.
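As a sketch for regression, the following computes a plain RMSE over all elements of hypothetical predictions and targets. That this definition matches the value you want to report is an assumption. If the predictions are a dlarray object, extract the numeric data first.

```matlab
% Hypothetical predictions and targets for a regression task
YPred = [1.0 2.0 3.0];
YTrue = [1.0 2.0 5.0];

% Root-mean-squared error over all elements
rmseValidation = sqrt(mean((YPred(:) - YTrue(:)).^2));
```

Here the errors are 0, 0, and -2, so the RMSE is sqrt(4/3), approximately 1.1547.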

Early Stopping

To stop training early when the loss on the held-out validation set stops decreasing, set a flag to break out of the training loops.

To easily specify the validation patience (the number of times that the validation loss can be larger than or equal to the previously smallest loss before network training stops), create the variable validationPatience.

validationPatience = 5;

Before training, initialize the variables earlyStop and validationLosses, where earlyStop is a flag to stop training early and validationLosses contains the losses to compare. Initialize the early stopping flag with false and the array of validation losses with inf.

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

Inside the training loop, in the loop over mini-batches, add the earlyStop flag to the loop condition.

while hasdata(ds) && ~earlyStop
    ...
end

During the validation step, append the new validation loss to the array validationLosses. If the first element of the array is the smallest, then set the earlyStop flag to true. Otherwise, remove the first element.

if isfinite(validationPatience)
    % lossValidation is the loss computed in the validation step
    validationLosses = [validationLosses lossValidation];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end
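To see when this rule stops training, the following sketch feeds a hypothetical sequence of validation losses through the same logic with a patience of 3. The loss values and the variables lossHistory and stopStep are illustrative only.

```matlab
validationPatience = 3;
lossHistory = [1.0 0.8 0.7 0.72 0.75 0.71 0.9];   % hypothetical validation losses

earlyStop = false;
validationLosses = inf(1,validationPatience);
stopStep = NaN;

for k = 1:numel(lossHistory)
    lossValidation = lossHistory(k);
    validationLosses = [validationLosses lossValidation];
    % Stop when the oldest loss in the window is still the smallest
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
        stopStep = k;
        break
    else
        validationLosses(1) = [];
    end
end
```

The best loss, 0.7, occurs at validation step 3. The next three losses (0.72, 0.75, 0.71) are all larger, so training stops at validation step 6.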

L2 Regularization

To apply L2 regularization to the weights, use the dlupdate function.

To easily specify the L2 regularization factor, create the variable l2Regularization that contains the L2 regularization factor.

l2Regularization = 0.0001;

During training, after computing the model gradients, for each of the weight parameters, add the product of the L2 regularization factor and the weights to the computed gradients using the dlupdate function. To update only the weight parameters, extract the parameters with name "Weights".

idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

After adding the L2 regularization term to the gradients, update the network parameters.
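For reference, adding l2Regularization times the weights to the weight gradients corresponds to minimizing a loss with an L2 penalty on the weights. With E(θ) the unregularized loss, w the weights, and λ the value of l2Regularization (symbols introduced here for illustration), a common way to write this is:

```latex
E_R(\theta) = E(\theta) + \frac{\lambda}{2}\, w^{\mathsf{T}} w
\qquad\Longrightarrow\qquad
\nabla_w E_R(\theta) = \nabla_w E(\theta) + \lambda\, w
```

The factor 1/2 is a convention that makes the gradient of the penalty exactly λw. Because only the "Weights" parameters are selected, other parameters such as biases are not penalized.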

Gradient Clipping

To clip the gradients, use the dlupdate function.

To easily specify gradient clipping options, create the variables gradientThresholdMethod and gradientThreshold, where gradientThresholdMethod contains "global-l2norm", "l2norm", or "absolute-value", and gradientThreshold is a positive scalar containing the threshold or inf.

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

Create functions named thresholdGlobalL2Norm, thresholdL2Norm, and thresholdAbsoluteValue that apply the "global-l2norm", "l2norm", and "absolute-value" threshold methods, respectively.

For the "global-l2norm" option, the function operates on all gradients of the model.

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

For the "l2norm" and "absolute-value" options, the functions operate on each gradient independently.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

During training, after computing the model gradients, apply the appropriate gradient clipping method to the gradients. Because the "global-l2norm" option requires all the model gradients at once, apply the thresholdGlobalL2Norm function directly to the cell array of gradient values, gradients.Value. For the "l2norm" and "absolute-value" options, update the gradients independently using the dlupdate function.

switch gradientThresholdMethod
    case "global-l2norm"
        % gradients is the table returned by dlgradient; its Value variable is a cell array
        gradients.Value = thresholdGlobalL2Norm(gradients.Value,gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g,gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g,gradientThreshold),gradients);
end

After applying the gradient threshold operation, update the network parameters.
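To compare the three methods on concrete numbers, the following base-MATLAB sketch applies the same rules to two hypothetical gradients stored in a cell array. It uses anonymous functions and cellfun instead of the local functions above, so it runs as a single script. The per-gradient L2 rule is written as a scale factor min(1, threshold/norm), which is equivalent to the if-based version.

```matlab
% Hypothetical gradients for two parameters
gradients = {[3 4], [0 0 12]};
gradientThreshold = 2.6;

% "global-l2norm": one norm over all gradients (here 13), then rescale all of them
globalL2Norm = sqrt(sum(cellfun(@(g) sum(g(:).^2), gradients)));
gradsGlobal = gradients;
if globalL2Norm > gradientThreshold
    gradsGlobal = cellfun(@(g) g*(gradientThreshold/globalL2Norm), ...
        gradients,'UniformOutput',false);
end

% "l2norm": rescale each gradient independently when its norm exceeds the threshold
clipL2 = @(g) g * min(1, gradientThreshold/norm(g(:)));
gradsL2 = cellfun(clipL2,gradients,'UniformOutput',false);

% "absolute-value": clamp each element to [-threshold, threshold]
clipAbs = @(g) max(min(g,gradientThreshold),-gradientThreshold);
gradsAbs = cellfun(clipAbs,gradients,'UniformOutput',false);
```

The global method keeps the relative sizes of the two gradients ([0.6 0.8] and [0 0 2.4]). The per-gradient L2 method scales each one separately ([1.56 2.08] and [0 0 2.6]). The absolute-value method changes the direction of the first gradient ([2.6 2.6]).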

Single CPU or GPU Training

By default, the software performs calculations using only the CPU. To train on a single GPU, convert the data to gpuArray objects. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher.

To easily specify the execution environment, create the variable executionEnvironment that contains either "cpu", "gpu", or "auto".

executionEnvironment = "auto";

During training, after reading a mini-batch, check the execution environment option and convert the data to a gpuArray object if necessary. The canUseGPU function checks for usable GPUs.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Checkpoints

To save checkpoint networks during training, save the network using the save function.

To easily specify whether checkpoints should be switched on, create the variable checkpointPath that contains the folder for the checkpoint networks or is empty.

checkpointPath = fullfile(tempdir,"checkpoints");

If the checkpoint folder does not exist, then before training, create the checkpoint folder.

if ~isempty(checkpointPath) && ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

During training, at the end of an epoch, save the network in a MAT file. Specify a file name containing the current iteration number, date, and time.

if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    % Save into the checkpoint folder rather than the current folder
    save(fullfile(checkpointPath,filename),"dlnet")
end
end
where dlnet is the dlnetwork object to be saved.
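The following sketch writes one checkpoint file into the checkpoint folder and loads it back, so you can confirm the file name and contents without training. The structure assigned to dlnet is a hypothetical stand-in for a dlnetwork object, and the iteration number is arbitrary. The sketch creates a file under tempdir.

```matlab
checkpointPath = fullfile(tempdir,"checkpoints");
if ~isempty(checkpointPath) && ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

% Hypothetical stand-in for the network and the current iteration
dlnet = struct("Note","placeholder for a dlnetwork object");
iteration = 1200;

% File name with the iteration number, date, and time
D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
checkpointFile = fullfile(checkpointPath,filename);
save(checkpointFile,"dlnet")

% Load the checkpoint back; S.dlnet holds the saved variable
S = load(checkpointFile,"dlnet");
```

To resume training from a checkpoint, load the file in the same way and continue the training loop with the restored network.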
