Why are the gradients not backpropagating into the encoder in this custom loop?

I am building a convolutional autoencoder using a custom training loop. When I attempt to reconstruct the images, the network's output degenerates to guessing the same incorrect value for all inputs. However, training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop. Unfortunately, I need to use the custom training loop for a different task and am prohibited from using TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder's reconstruction performance?
%% Functional 'trainnet' loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", ...
InitialLearnRate=learnRate,...
MaxEpochs=30, ...
Plots="training-progress", ...
TargetDataFormats="SSCB", ...
InputDataFormats="SSCB", ...
MiniBatchSize=miniBatchSize, ...
OutputNetwork="last-iteration", ...
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, 'SSCB'),dlarray(xTrain, 'SSCB'), ...
autoencoder, 'mse', options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, 'SSCB'));
indices = randperm(size(xTest, 4)); % Shuffle YTest & xTest
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFormat="SSCB", ...
MiniBatchFcn=@preprocessMiniBatch,...
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( ...
Metrics="TrainingLoss", ...
Info=["Epoch", "LearningRate"], ...
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Assess validation criterion
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, ...
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, ...
LearningRate=learnRate, ...
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, ...
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
ntest = size(xTest, 4);
indices = randperm(ntest);
xTest = xTest(:,:,:,indices);% Shuffle test data
mbqTest = minibatchqueue(dsTest,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFcn=@preprocessMiniBatch, ...
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, 'NormalizationFactor','all-elements');
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
shuffle(mbq)
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end

2 Comments

training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop.
I don't see why gradient back propagation failure is the only possible culprit. However, it should be something that is easy to test. You can implement a second version of your modelLoss that takes a single stack as input. Then you can run dlfeval on both versions and see if they return the same gradients (within floating point differences).
Hm. The stacked autoencoder trains exactly the same as the separate encoder/decoders when run through the custom training loop. It must be something I have done in the training loop then.


 Accepted Answer

In terms of what may be different from trainnet, I don't see any regularization in your customized loop. You have a function called regularizedLoss(), but it doesn't seem to evaluate any regularization terms or apply any regularization hyperparameters.
Aside from that, I wonder where the parameter initialization is happening. Presumably it is in adamupdate(), but since you call adamupdate separately on netE and netD, I am not sure how that might be affecting the initialization as compared to when trainnet is used on the entire end-to-end network.

19 Comments

There is no regularization in either loop. The trainnet function simply uses 'mse', while the custom loop uses l2loss normalized over all elements, so their losses should differ only by a constant factor on the order of 1; they are on the same scale when they train. The name 'regularizedLoss' is a holdover: this started out as a regularized autoencoder from which I have been gradually stripping parts while debugging.
Parameter initialization for the weights in the convolutional layers should occur when the net = dlnetwork(layers) functions are called.
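The loss-scale claim above is easy to sanity-check directly (a sketch I put together, not from the original post; `mse` here is the Deep Learning Toolbox dlarray loss, and the exact ratio depends on each function's normalization):

```matlab
% Sketch: verify that "mse" and l2loss(...,'all-elements') differ only by
% a constant factor on identical inputs.
rng(0)
Y = dlarray(rand(28,28,1,25,'single'),'SSCB');
T = dlarray(rand(28,28,1,25,'single'),'SSCB');
lossMSE = mse(Y,T);                                        % trainnet-style loss
lossL2  = l2loss(Y,T,NormalizationFactor="all-elements");  % custom-loop loss
% Both are scaled sums of squared errors, so the ratio is data-independent.
disp(extractdata(lossMSE./lossL2))
```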
trainnet uses mse, or whatever lossFcn you specify, for the data loss, but there is always an L2 parameter regularization term added in the background unless you set the trainingOptions parameter L2Regularization to 0 (its default is 1e-4). It doesn't appear you have deactivated L2 regularization in your call to trainnet. It is, in any case, a bad idea to omit regularization in either loop.
Alright, I'll try out the loss function with an L2 penalty on the weights. However, wouldn't a lack of a penalty on the weights in the custom model allow an increase in overfitting behavior, not severe underfitting?
Yes, that's what I would expect. Your posted question doesn't show your training curves or mention any underfitting explicitly. From what I could interpret, the degenerate prediction you mention only happens at inference time, not during training.
It looks like the L2 penalty on the weights is the default structure for the convolutional layers, so it is implicit in both models, not just the trainnet one.
I'm not sure what you mean by "is the default structure for the convolutional layers". Convolutional layers do have a WeightL2Factor and BiasL2Factor property, but I don't think that is enough to conclude that there is implicit L2 regularization in the custom training loop, if that's what you meant.
Note that even in a non-custom loop, e.g., trainnet, these parameters do not determine the regularization loss on their own. They are used in conjunction with the L2Regularization trainingOptions parameter, which globally scales the regularization weights in all layers.
However, wouldn't a lack of a penalty on the weights in the custom model allow an increase in overfitting behavior, not severe underfitting?
Perhaps, but I also wonder if, because your custom loop doesn't have the same overall lossFcn as trainnet, the learning rate you are applying with trainnet is not appropriate for the custom loop. Perhaps it needs to be retuned?
[Attached: training-progress plot from the trainnet run]
The custom training loop does not perform well. It does not overfit, but once it has converged to the degenerate solution, its validation loss does not change.
I would not expect any difference between inference and training performance for this particular model since I am not learning normalization factors or using dropout, but I would like to hear your thought process.
I have tried manipulating the learning rate and batch size for the defective autoencoder to no avail. I suspect that MATLAB's ability to trace the computational graph from the decoder output to the encoder input is broken at the encoding layer, but I do not know what syntax to use to maintain the graph.
I suspect that MATLAB's ability to trace the computational graph from the decoder output to the encoder input is broken at the encoding layer, but I do not know what syntax to use to maintain the graph.
I don't see why that would be happening, but it should be something that is easy to test. You can implement a second version of your modelLoss that takes a single stack as input. Then you can run dlfeval on both versions and see if they return the same gradients (within floating point differences).
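A minimal sketch of that test, assuming the single-stack dlnetwork `netS` was built from the same layer array as the trainnet version and that the learnables tables line up in layer order (both assumptions worth checking; `modelLoss` is the function from the post):

```matlab
% Sketch: compare gradients from the split encoder/decoder against an
% equivalent single-stack dlnetwork (layer ordering assumed to match).
function compareGradients(netE,netD,netS,X)
[~,gE,gD] = dlfeval(@modelLoss,netE,netD,X);   % split version (from the post)
[~,gS]    = dlfeval(@stackedLoss,netS,X);      % single-stack version
gSplit = [gE.Value; gD.Value];
maxDiff = max(cellfun(@(a,b) max(abs(extractdata(a(:))-extractdata(b(:)))), ...
    gSplit,gS.Value));
fprintf("max abs gradient difference: %.3g\n",maxDiff)
end

function [loss,gradients] = stackedLoss(net,X)
% Same loss as the custom loop, applied to the end-to-end network.
Xrecon = forward(net,X);
loss = l2loss(Xrecon,X,NormalizationFactor="all-elements");
gradients = dlgradient(loss,net.Learnables);
end
```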
You are correct about the gradients being able to flow to the encoder. I was so convinced that they were not because I was certain I had missed some crucial bit of syntax, explaining the disparity between two otherwise congruent networks. The gradients are present in the encoder; however, by iteration 13, they are down to a scale of 1e-40. This strikes me as odd because I have a total of four layers, all but the last of which use ReLU activation functions. I am going to try a stochastic gradient descent algorithm, in case the two separate initializations of Adam create a subtle difference between calling 'adam' in the optimizer options and using adamupdate.
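One way to catch vanishing gradients like this early (a hypothetical helper I sketched, not part of the original loop) is to print the largest gradient magnitude in each network during training:

```matlab
% Sketch: log the largest gradient magnitude per network to spot
% vanishing gradients as training progresses.
function logGradientScale(gradientsE,gradientsD,iteration)
% Each gradients table has a Value column holding one dlarray per learnable.
gmax = @(T) max(cellfun(@(g) max(abs(extractdata(g(:)))),T.Value));
fprintf("iter %d: max|gradE| = %.3g, max|gradD| = %.3g\n", ...
    iteration,gmax(gradientsE),gmax(gradientsD));
end
```

Called right after the dlfeval line in the training loop, e.g. `logGradientScale(gradientsE,gradientsD,iteration)`.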
Another observation. In the custom loop, I believe your imageInputLayers apply no normalization, but with trainnet, they will apply zero-center normalization by default. This would be another inconsistency between the scenarios.
How have I deactivated the imageInputLayer's zero-center normalization in the custom training loop?
I noted that the custom loop zero-centers the images at the input of the decoder. Changing this to no normalization has no impact.
How have I deactivated the imageInputLayer's zero-center
You haven't expressly deactivated it, but from the documentation for the imageInputLayer's 'Mean' property, we see that,
  • The trainnet function calculates the mean using the training data and uses the resulting value.
  • The initialize function and the dlnetwork function when the Initialize option is 1 (true) set the property to 0.
In other words, since the custom loop does no pre-analysis of the data, the Mean is set to zero, which is effectively the same as deactivating the zero-center normalization.
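This is easy to confirm (a sketch; the layers and sizes are arbitrary):

```matlab
% Sketch: with dlnetwork, the imageInputLayer's Mean defaults to 0, so its
% zero-center normalization is effectively a no-op in a custom loop.
layers = [imageInputLayer([28 28 1]) convolution2dLayer(3,8)];
net = dlnetwork(layers);       % Initialize defaults to true
disp(net.Layers(1).Mean)       % 0, i.e. no shift is applied
```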
Ah, I see. I assumed it calculated the mean value with the layer initialization and updated it in any training, not just with the trainnet function. Thank you!
In my own tests of the code, I see that trainnet() fails with a similar training curve as your custom loop when L2 regularization is deactivated. That's a pretty decent indication that omitting regularization is what needs to be fixed in the custom loop.
options = trainingOptions("adam", ...
L2Regularization=0,... %<----Matt J deactivated
InitialLearnRate=learnRate,...
MaxEpochs=30, ...
Plots="training-progress", ...
TargetDataFormats="SSCB", ...
InputDataFormats="SSCB", ...
MiniBatchSize=miniBatchSize, ...
OutputNetwork="last-iteration", ...
Shuffle="every-epoch");
This regularized version of the custom loop seems to work. You can of course modify this so that the regularization weight is not hardcoded.
%% Regularized Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 10;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFormat="SSCB", ...
MiniBatchFcn=@preprocessMiniBatch,...
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( ...
Metrics="TrainingLoss", ...
Info=["Epoch", "LearningRate"], ...
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Assess validation criterion
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, ...
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, ...
LearningRate=learnRate, ...
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, ...
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
ntest = size(xTest, 4);
indices = randperm(ntest);
xTest = xTest(:,:,:,indices);% Shuffle test data
mbqTest = minibatchqueue(dsTest,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFcn=@preprocessMiniBatch, ...
MiniBatchFormat="SSCB");
[Xmbq,Ymbq] = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = Ymbq(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = Xmbq(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
%Regularize gradients
regFcn=@(g,p) g+1e-4*p;
gradientsE=dlupdate(regFcn,gradientsE,netE.Learnables);
gradientsD=dlupdate(regFcn,gradientsD,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, 'NormalizationFactor','all-elements');
% Combined loss.
loss = reconstructionLoss;
end
function [Xs,Ys] = modelPredictions(netE,netD,mbq)
[Xs,Ys] = deal(cell(1,1e5));
i=0;
shuffle(mbq)
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq); i=i+1;
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
Xs{i} = extractdata(X);
Ys{i} = extractdata( predict(netD,Z));
end
Xs=cat(4,Xs{:}); Ys=cat(4,Ys{:});
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end
Marvelous! Thank you.
I know you've given me a great deal of your time, but if you could manage, I would like to hear your thoughts on why adding L2 regularization to the weights induces a solution that is a better fit to the training data, not just a more generalizable solution. My understanding of L2 regularization is that it adds a Gaussian prior about 0 for the weights on a given layer, pulling the solution closer towards the W = 0 origin. I can easily see how this favors more generalizable solutions, but I am not certain why adding this allows the solver to escape the local minimum of guessing the average.
My first guess would be that guessing the average reduces the error signal from the predictions, allowing the -const*w term to dominate in the gradient updates, which subsequently pulls the solution out of its degenerate behavior. This vague notion suggests this behavior is not general, but rather is unique to a dataset wherein the optimal solution lies closer to the weight-space origin than the average guess.
This vague notion suggests this behavior is not general, but rather is unique to a dataset wherein the optimal solution lies closer to the weight-space origin than the average guess.
I don't know how general it is, but deep learning data loss functions do tend to have plateaus and local minima at large values of the weights, because with large weights it is easy for the ReLUs to saturate.
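For reference, the plain-SGD form of the L2-regularized update makes that pull toward the origin explicit (a sketch; Adam rescales the gradient per-parameter, but the regularization term behaves similarly):

```latex
w_{t+1} = w_t - \eta \,\nabla_w\!\left( L_{\mathrm{data}} + \tfrac{\lambda}{2}\lVert w \rVert^2 \right)
        = (1 - \eta\lambda)\, w_t - \eta \,\nabla_w L_{\mathrm{data}}
```

On a plateau where \(\nabla_w L_{\mathrm{data}} \approx 0\), the \((1-\eta\lambda)\) factor still shrinks the weights every step, which can move the solver off flat regions that occur at large weight magnitudes.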


More Answers (0)

Release: R2023b
Asked: 16 Jul 2024
Edited: 18 Jul 2024
