Why are the gradients not backpropagating into the encoder in this custom loop?
I am building a convolutional autoencoder using a custom training loop. When I attempt to reconstruct the images, the network's output degenerates to guessing the same incorrect value for all inputs. However, training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop. Unfortunately, I need to use the custom training loop for a different task and am prohibited from using TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder's reconstruction performance?
%% Functional 'trainnet' loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Full autoencoder: encoder and decoder in a single stack
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", ...
InitialLearnRate=learnRate,...
MaxEpochs=30, ...
Plots="training-progress", ...
TargetDataFormats="SSCB", ...
InputDataFormats="SSCB", ...
MiniBatchSize=miniBatchSize, ...
OutputNetwork="last-iteration", ...
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, 'SSCB'),dlarray(xTrain, 'SSCB'), ...
autoencoder, 'mse', options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, 'SSCB'));
indices = randperm(size(xTest, 4)); % Shuffle YTest & xTest with the same indices
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name='Output')
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFormat="SSCB", ...
MiniBatchFcn=@preprocessMiniBatch,...
PartialMiniBatch="return");
% Initialize the Adam solver state (trailing moment estimates).
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
% Calculate the total number of iterations for the training progress monitor.
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
% Initialize the training progress monitor.
monitor = trainingProgressMonitor( ...
Metrics="TrainingLoss", ...
Info=["Epoch", "LearningRate"], ...
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, ...
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, ...
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, ...
LearningRate=learnRate, ...
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, ...
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
MiniBatchSize = miniBatchSize, ...
MiniBatchFcn=@preprocessMiniBatch, ...
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Shuffle originals and reconstructions with the same indices so the tiles stay paired
ntest = size(xTest, 4);
indices = randperm(ntest);
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, 'NormalizationFactor','all-elements');
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
% Iterate in datastore order so reconstructions stay paired with the test images
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end
2 Comments
Matt J on 16 Jul 2024 (edited 16 Jul 2024)
"training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop."
I don't see why a failure of gradient backpropagation is the only possible culprit. In any case, it should be easy to test: implement a second version of your modelLoss that takes the single-stack network as input, run dlfeval on both versions, and check whether they return the same gradients (up to floating-point differences).
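For illustration, a minimal sketch of that check follows. modelLossStacked is a hypothetical single-network counterpart of modelLoss, and the comparison is only meaningful if the stacked network (autoencoder) holds exactly the same weights as netE and netD combined.
% Compare gradients from the split encoder/decoder against the single stack.
reset(mbq);                                                  % rewind the training queue
X = next(mbq);                                               % one mini-batch to probe
[~,gradE,gradD] = dlfeval(@modelLoss,netE,netD,X);           % split networks
[~,gradStacked] = dlfeval(@modelLossStacked,autoencoder,X);  % single stack
gradSplit = [gradE; gradD]; % assumes the rows line up with gradStacked's row order
for k = 1:height(gradSplit)
d = max(abs(extractdata(gradSplit.Value{k} - gradStacked.Value{k})),[],"all");
fprintf("%s %s: max abs gradient difference %g\n",gradSplit.Layer(k),gradSplit.Parameter(k),d);
end
function [loss,gradients] = modelLossStacked(net,X)
% Same reconstruction loss as modelLoss, but through one end-to-end network.
Xrecon = forward(net,X);
loss = l2loss(Xrecon,X,NormalizationFactor="all-elements");
gradients = dlgradient(loss,net.Learnables);
end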
Accepted Answer
Matt J on 16 Jul 2024 (edited 16 Jul 2024)
In terms of what may be different from trainnet, I don't see any regularization in your custom loop. You have a function called regularizedLoss(), but it doesn't seem to evaluate any regularization terms or apply any regularization hyperparameters.
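For reference, trainingOptions defaults L2Regularization to 1e-4, so trainnet applies a small weight decay that the custom loop never does. Below is a rough sketch of an explicit equivalent; the extra arguments and the l2Regularization factor are assumptions, and modelLoss would need to pass netE, netD, and the factor through to it.
function loss = regularizedLoss(Xrecon,X,netE,netD,l2Regularization)
% Image reconstruction loss.
reconstructionLoss = l2loss(Xrecon,X,NormalizationFactor="all-elements");
% Explicit L2 penalty on the convolution weights (biases left unregularized).
allLearnables = [netE.Learnables; netD.Learnables];
weightValues = allLearnables.Value(allLearnables.Parameter == "Weights");
weightPenalty = 0;
for k = 1:numel(weightValues)
weightPenalty = weightPenalty + sum(weightValues{k}.^2,"all");
end
% Combined loss: reconstruction term plus weight decay.
loss = reconstructionLoss + l2Regularization*weightPenalty;
end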
Aside from that, I wonder where the parameter initialization is happening. Presumably it is in adamupdate(), but since you call adamupdate separately on netE and netD, I am not sure how that might be affecting the initialization as compared to when trainnet is used on the entire end-to-end network.
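If initialization is the concern, one hypothetical way to take it out of the equation is to copy the stacked network's freshly initialized weights into netE and netD before training, so both runs start from the same point. This assumes the auto-generated layer names (conv_1, transposed-conv_1, ...) match between the stacked and split networks.
% Copy the stacked network's initial weights into the split encoder and decoder.
lrnE = netE.Learnables;
for k = 1:height(lrnE)
idx = autoencoder.Learnables.Layer == lrnE.Layer(k) & autoencoder.Learnables.Parameter == lrnE.Parameter(k);
lrnE.Value(k) = autoencoder.Learnables.Value(idx);
end
netE.Learnables = lrnE;
lrnD = netD.Learnables;
for k = 1:height(lrnD)
idx = autoencoder.Learnables.Layer == lrnD.Layer(k) & autoencoder.Learnables.Parameter == lrnD.Parameter(k);
lrnD.Value(k) = autoencoder.Learnables.Value(idx);
end
netD.Learnables = lrnD;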
19 Comments
Matt J on 18 Jul 2024 (edited 18 Jul 2024)
"This vague notion suggests this behavior is not general, but rather is unique to a dataset wherein the optimal solution lies closer to the weight-space origin than the average guess."
I don't know how general it is, but deep learning data loss functions do tend to have plateaus and local minima at large values of the weights, because with large weights it is easy for the ReLUs to saturate.
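One way to probe that, sketched here with an arbitrary assumed shrink factor, is to scale the freshly initialized weights toward the origin and see whether the custom loop still collapses to a constant output.
% Hypothetical experiment: shrink all initial learnables toward zero, then retrain.
shrink = 0.1; % assumed scale factor
netE = dlupdate(@(w) shrink*w, netE);
netD = dlupdate(@(w) shrink*w, netD);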