Update BatchNorm layer state in a Siamese network with a custom training loop for triplet and contrastive loss

Hi everyone, I'm trying to implement a Siamese network for face verification. As the subnetwork I'm using a ResNet-18 pretrained on my dataset, and I'm trying to implement the triplet loss and the contrastive loss. The main problem is the batch normalization layers in the subnetwork, which need to be updated during the training phase using
dlnet.State=state;
But in the MathWorks tutorials I only found the state update for the cross-entropy case, with a single dlarray as input to the forward function that returns the state:
function [loss,gradients,state] = modelLoss(net,X,T)
    [Y,state] = forward(net,X);
    ....
end
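Filled in, that tutorial pattern is roughly the following (the softmax/cross-entropy part is my reading of what the dots stand for):
function [loss,gradients,state] = modelLoss(net,X,T)
    % Single forward pass; the second output is the updated batch norm state.
    [Y,state] = forward(net,X);
    % Classification loss against the targets T (assumed to be one-hot encoded).
    loss = crossentropy(softmax(Y),T);
    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end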
At the moment this is my training loop for the contrastive loss; there is a similar one for the triplet loss that takes 3 images at a time.
for iteration = 1:numIterations
    [X1,X2,pairLabels] = GetSiameseBatch(IMGS, miniBatchSize);

    % Convert the mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data.
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    % clear X1 X2

    % Load the pairs into GPU memory.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end

    % Evaluate the model loss, gradients and state using dlfeval and the
    % modelLoss function, then update the network state.
    [loss,gradientsSubnet,state] = dlfeval(@modelLoss,dlnet,dlX1,dlX2,pairLabels);
    dlnet.State = state;

    % Update the Siamese subnetwork parameters. Scope: train the last fc
    % layer for the 128-dimensional feature vector.
    [dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
        adamupdate(dlnet,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);

    % Update the training progress plot.
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    lossValue = double(gather(extractdata(loss)));
    % lossValue = double(loss);
    addpoints(lineLoss,iteration,lossValue);
    title("Elapsed: " + string(D))
    drawnow
end
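GetSiameseBatch just builds the paired mini-batch. A rough sketch of that kind of pair sampler, with a hypothetical signature that assumes an H-by-W-by-C-by-N image array and a parallel label vector (not necessarily how my data is stored):
function [X1,X2,pairLabels] = getPairBatch(images,labels,miniBatchSize)
    % Hypothetical pair sampler: images is H-by-W-by-C-by-N, labels is 1-by-N.
    % Pair k is (X1(:,:,:,k), X2(:,:,:,k)); pairLabels(k) is 1 for a
    % same-identity pair and 0 for a different-identity pair.
    [h,w,c,~] = size(images);
    X1 = zeros(h,w,c,miniBatchSize,'like',images);
    X2 = zeros(h,w,c,miniBatchSize,'like',images);
    pairLabels = zeros(1,miniBatchSize);
    for k = 1:miniBatchSize
        pairLabels(k) = rand > 0.5;        % roughly half similar, half dissimilar
        idx1 = randi(numel(labels));
        if pairLabels(k)
            candidates = find(labels == labels(idx1));
        else
            candidates = find(labels ~= labels(idx1));
        end
        idx2 = candidates(randi(numel(candidates)));
        X1(:,:,:,k) = images(:,:,:,idx1);
        X2(:,:,:,k) = images(:,:,:,idx2);
    end
end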
And the model loss is
function [loss,gradientsSubnet,state] = modelLoss(net,X1,X2,pairLabels)
    % Pass the image pair through the network.
    [F1,F2,state] = ForwardSiamese(net,X1,X2);
    % Calculate the contrastive loss.
    margin = 1;
    loss = ContrastiveLoss(F1,F2,pairLabels,margin);
    % Calculate gradients of the loss with respect to the network learnable
    % parameters.
    gradientsSubnet = dlgradient(loss,net.Learnables);
end
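For context, a minimal ContrastiveLoss with this signature could look something like the following sketch (the exact distance and reduction are just one possible choice):
function loss = ContrastiveLoss(F1,F2,pairLabels,margin)
    % Margin-based contrastive loss over a batch of embedding pairs.
    % pairLabels is assumed to be 1 for similar pairs and 0 for dissimilar pairs.
    Y = reshape(single(pairLabels),1,[]);
    sqDist = sum((F1 - F2).^2,1);              % squared distance per pair
    dist = sqrt(sqDist + 1e-12);               % small offset keeps the gradient finite
    lossPerPair = Y.*sqDist + (1 - Y).*max(margin - dist,0).^2;
    loss = mean(lossPerPair,'all');
end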
But in the ForwardSiamese function I run the forward pass on the two dlarrays X1 and X2 that contain the batch of paired images (i.e. X1 contains 32 images, X2 contains 32 images, and the first image in X1 is paired with the first image in X2, and so on) and compute the loss. So where should the state used to update the batch norm layers come from?
function [Y1,Y2,state] = ForwardSiamese(dlnet,dlX1,dlX2)
    % Pass the first image through the subnetwork and capture the state
    [Y1,state] = forward(dlnet,dlX1);
    Y1 = sigmoid(Y1);
    % Pass the second image through the twin subnetwork
    Y2 = forward(dlnet,dlX2);
    Y2 = sigmoid(Y2);
end
If I also capture the state of the second pass as [Y2,state], I have two states, but which one should be used to update the batch norm TrainedMean and TrainedVariance?

 Accepted Answer

Interesting question! The purpose of batch norm state is to collect statistics about typical inputs. In a normal Siamese workflow, both X1 and X2 are valid inputs, so you ought to be able to update the state with either result.
You could aggregate the state from both, or even do an additional forward pass on both inputs together to compute the aggregated state, although this comes with an extra performance cost:
[~,dlnet.State] = forward(dlnet, cat(4,X1,X2));
You can do this after the call to dlfeval.
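If you prefer to aggregate instead of doing an extra pass, one possibility (assuming it is acceptable to average the running statistics in the State table element-wise) is a variant of ForwardSiamese that captures and combines both states:
function [Y1,Y2,state] = ForwardSiameseAveraged(dlnet,dlX1,dlX2)
    % Run both inputs through the shared subnetwork, capture both states and
    % average the running statistics element-wise.
    [Y1,state1] = forward(dlnet,dlX1);
    Y1 = sigmoid(Y1);
    [Y2,state2] = forward(dlnet,dlX2);
    Y2 = sigmoid(Y2);
    state = state1;
    state.Value = cellfun(@(a,b) (a + b)/2, state1.Value, state2.Value, ...
        'UniformOutput',false);
end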

4 Comments

Thanks for the answer, but the problem is a little different. If you look at the image, the network is branched to obtain both the classification task and the triplet/contrastive loss. This is just a test I'm making at the moment to compare this more difficult network's results with the one that implements only the contrastive/triplet loss.
The network implementing only the triplet/contrastive loss doesn't need the update of the batch norm layers, because the backbone is a ResNet-18 fine-tuned on my dataset and I freeze all the layers up to the average pooling layer of the ResNet-18.
The network with the classification loss, instead, behaves like this (I'll use the triplet loss as an example, since it is the most complicated). Imagine six parallel networks that compute at the same time: three compute the embeddings for anchor, positive and negative and, at the end, the triplet loss; the other three compute the classification loss for anchor, positive and negative and average it. Summing the triplet loss and the average classification loss gives the final loss, which I use to update the network.
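Roughly, the combined loss looks like the sketch below (this uses one branched forward per image instead of separate passes, and TripletLoss and the one-hot targets TA/TP/TN are placeholder names, not code from this thread):
function [loss,gradients,state] = modelLossTripletPlusClassification(dlnet,dlXA,dlXP,dlXN,TA,TP,TN,margin)
    % Assumes dlnet is a branched dlnetwork with two outputs per input:
    % an embedding and a vector of classification scores.
    [embA,scoresA,stateA] = forward(dlnet,dlXA);   % anchor
    [embP,scoresP,stateP] = forward(dlnet,dlXP);   % positive
    [embN,scoresN,stateN] = forward(dlnet,dlXN);   % negative
    % Triplet loss on the embeddings (TripletLoss is a placeholder helper).
    tripletLoss = TripletLoss(embA,embP,embN,margin);
    % Average classification loss over the three forward passes.
    clsLoss = (crossentropy(softmax(scoresA),TA) + ...
               crossentropy(softmax(scoresP),TP) + ...
               crossentropy(softmax(scoresN),TN))/3;
    loss = tripletLoss + clsLoss;
    gradients = dlgradient(loss,dlnet.Learnables);
    % Average the three states so the batch norm statistics reflect all three
    % classification forwards.
    state = stateA;
    state.Value = cellfun(@(a,b,c) (a + b + c)/3, ...
        stateA.Value, stateP.Value, stateN.Value, 'UniformOutput',false);
end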
What I needed was the state of the batch norm layers, so that prediction behaves differently from the forward pass. It can't depend on only one of the three images; it needs to depend on the states obtained from the three classification forwards (not the triplet forwards, because the classification matters most for the backbone net). With your colleague we decided that the best way is to average the states obtained in the classification forwards to obtain a general state that updates the batch norm layers in the backbone network. This works pretty well for the contrastive loss, but I'm training the triplet loss with this method and it is taking a lot of time, because if I branch the network, one of the two branches can't use acceleration in the forward pass due to this error:
Previously accessible file "inmem:///deep_learning/tp513c20df_227f_4816_b6ef_795e1c33371d.m" is now inaccessible.
Error in deep.internal.recording.convert.tapeToFunction>@(varargin)fcnWithConstantsInput(varargin{:},constants) (line 37)
fcn = @(varargin)fcnWithConstantsInput(varargin{:},constants);
Error in deep.internal.AcceleratedOp/backward (line 69)
[varargout{1:op.NumGradients}] = backwardFun(varargin{:});
Error in deep.internal.recording.RecordingArray/backwardPass (line 90)
grad = backwardTape(tm,{y},{initialAdjoint},x,retainData,false,0);
Error in dlarray/dlgradient (line 132)
[grad,isTracedGrad] = backwardPass(y,xc,pvpairs{:});
Error in TrainingPhase>modelGradientsTripletLoss (line 91)
[gradientsSubnet] = dlgradient(finalLoss,dlnet.Learnables);
The lack of acceleration is a huge problem. At the moment I'm training the network on my university's server and it takes a lot of time (I think it will last for some days even with 6 Titan X GPUs and 100 GB of RAM, and I'm training with a batch size of 64).
Thanks for the answer, I will continue in this way using the average state. I will test it in the development phase and, if needed, I will come back and try your code. I hope this answer can help more people.
That bug is now fixed in the latest version of MATLAB R2022b.
If you are only fine-tuning part of the network, then you only need to update the state for the part you are modifying. It looks like in your case that part doesn't contain any batch normalization layers, in which case don't update the State at all!
Great news about the bug fix, maybe I can speed up the process.
As for the state, I do need the update because when I use the classification loss I have to train the whole network, not only the layers after the pooling in the ResNet-18 backbone.


