Update BatchNorm Layer State in Siamese network with custom loop for triplet and contrastive loss
Hi everyone, I'm trying to implement a Siamese network for face verification. As the subnetwork I'm using a ResNet-18 pretrained on my dataset, and I'm trying to implement the triplet loss and the contrastive loss. The main problem is the batch normalization layers in my subnetwork, which need to be updated during the training phase using
dlnet.State=state;
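To see what state actually contains, here is a minimal sketch on a small hypothetical network (layers, net and the 8-by-8 input size are illustrative only, not part of the original setup). In training mode, forward returns the updated TrainedMean and TrainedVariance of every batch normalization layer as a table, and the network itself does not change until that table is assigned back:

```matlab
% Hypothetical toy network with a single batch normalization layer
layers = [
    imageInputLayer([8 8 3],Normalization="none")
    convolution2dLayer(3,4,Padding="same")
    batchNormalizationLayer
    reluLayer];
net = dlnetwork(layers);

% Random mini-batch of 16 RGB 8x8 images in 'SSCB' format
X = dlarray(rand(8,8,3,16,"single"),"SSCB");

% forward runs in training mode and returns the updated state;
% predict would use the stored statistics and return no state
[~,state] = forward(net,X);
disp(state)   % table with Layer, Parameter and Value columns

% The state is not written back automatically: assign it explicitly
net.State = state;
```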
However, the MathWorks tutorials I found only show this update for the cross-entropy loss, where a single dlarray is passed to forward, which returns the state:
function [loss,gradients,state] = modelLoss(net,X,T)
[Y,state] = forward(net,X);
....
end
At the moment, this is my training loop for the contrastive loss; there is a similar one for the triplet loss that takes three images at a time:
for iteration = 1:numIterations
[X1,X2,pairLabels] = GetSiameseBatch(IMGS, miniBatchSize);
% Convert mini-batch of data to dlarray. Specify the dimension labels
% 'SSCB' (spatial, spatial, channel, batch) for image data
dlX1 = dlarray(single(X1),'SSCB');
dlX2 = dlarray(single(X2),'SSCB');
% clear X1 X2
% Move the pairs to GPU memory when a GPU is used
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
dlX1 = gpuArray(dlX1);
dlX2 = gpuArray(dlX2);
end
% Evaluate the loss, the gradients and the updated network state using
% dlfeval and the modelLoss function
[loss,gradientsSubnet,state] = dlfeval(@modelLoss,dlnet,dlX1,dlX2,pairLabels);
dlnet.State = state;
% Update the Siamese subnetwork parameters. Goal: train the last fully
% connected layer to output a 128-dimensional feature vector
[dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
adamupdate(dlnet,gradientsSubnet, ...
trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
D = duration(0,0,toc(start),Format="hh:mm:ss");
lossValue = double(gather(extractdata(loss)));
% lossValue = double(loss);
addpoints(lineLoss,iteration,lossValue);
title("Elapsed: " + string(D))
drawnow
end
And this is the model loss function:
function [loss,gradientsSubnet,state] = modelLoss(net,X1,X2,pairLabels)
% Pass the image pair through the network.
[F1,F2,state] = ForwardSiamese(net,X1,X2);
% Calculate the contrastive loss.
margin = 1;
loss = ContrastiveLoss(F1,F2,pairLabels, margin);
% Calculate gradients of the loss with respect to the network learnable
% parameters.
gradientsSubnet = dlgradient(loss,net.Learnables);
end
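The ContrastiveLoss function itself is not shown in the post. Below is a minimal sketch of a commonly used form, written as a hypothetical stand-in with the same signature. It assumes pairLabels is a numeric or logical 1-by-N vector with 1 for a matching pair and 0 otherwise, and that F1 and F2 are 'CB' dlarrays (features by N):

```matlab
function loss = ContrastiveLoss(F1,F2,pairLabels,margin)
% Hypothetical sketch, not the original implementation.
% Assumes pairLabels(i) = 1 for a matching pair, 0 for a non-matching one.
F1 = stripdims(F1);
F2 = stripdims(F2);
y = single(pairLabels(:)');          % 1-by-N row of labels

% Euclidean distance per pair; the small constant keeps the gradient of
% sqrt finite when two embeddings coincide
d = sqrt(sum((F1 - F2).^2,1) + 1e-12);

% Pull matching pairs together, push non-matching pairs beyond the margin
loss = 0.5*mean(y.*d.^2 + (1 - y).*max(margin - d,0).^2);
end
```

Because it only uses dlarray-supported operations, dlgradient in modelLoss can differentiate through it unchanged.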
In the ForwardSiamese function I run forward on the two dlarray objects X1 and X2, which hold the batch of image pairs (for example, X1 contains 32 images and so does X2; the first image in X1 is paired with the first image in X2, and so on), and then I compute the loss. But where should the state used to update the batch normalization layers come from?
function [Y1,Y2,state] = ForwardSiamese(dlnet,dlX1,dlX2)
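% Pass the first image of each pair through the subnetwork
% (the state is returned only from this call)
[Y1,state] = forward(dlnet,dlX1);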
Y1 = sigmoid(Y1);
% Pass the second image through the twin subnetwork
Y2 = forward(dlnet,dlX2);
Y2 = sigmoid(Y2);
end
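One common way to end up with a single state is to pass both halves of the pair batch through the subnetwork in one forward call, concatenated along the batch dimension, and then split the outputs. The sketch below (ForwardSiameseConcat is a hypothetical name) assumes the subnetwork ends in a fully connected layer, so its output is a 'CB' dlarray:

```matlab
function [Y1,Y2,state] = ForwardSiameseConcat(dlnet,dlX1,dlX2)
% Hypothetical variant of ForwardSiamese: one forward pass over both
% halves of the pair batch, so exactly one state is returned.
N = size(dlX1,4);                              % number of pairs
[Y,state] = forward(dlnet,cat(4,dlX1,dlX2));   % 2N images in one batch
Y = sigmoid(Y);

% Split the batch dimension back into the two twins ('CB' output assumed)
Y1 = Y(:,1:N);
Y2 = Y(:,N+1:end);
end
```

It can replace ForwardSiamese inside modelLoss without other changes. With this approach both twins are normalized with the same batch statistics, and the running statistics receive one update per iteration computed from all 2N images. An alternative that keeps two separate forward calls is to assign dlnet.State = state between them, so that the state returned by the second call already includes the update from the first.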
If I also compute [Y2,state], I get two states. Which one should be used to update the batch normalization TrainedMean and TrainedVariance?
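For the triplet-loss loop the same question arises with three forward passes. A minimal sketch of a model loss function that yields a single state, assuming anchors, positives and negatives arrive as three 'SSCB' dlarrays of equal batch size, and using squared Euclidean distances with a hinge margin (modelLossTriplet is a hypothetical name; the original triplet code is not shown):

```matlab
function [loss,gradients,state] = modelLossTriplet(net,XA,XP,XN,margin)
% Hypothetical sketch. XA, XP, XN: 'SSCB' dlarrays of anchors, positives
% and negatives; sample i of each forms one triplet.
N = size(XA,4);

% One forward pass over all 3N images gives ONE state whose batchnorm
% statistics are computed from every image in the triplets
[F,state] = forward(net,cat(4,XA,XP,XN));

% Features-by-3N matrix; assumes a 'CB' (or 1x1 spatial) network output
F = reshape(stripdims(F),[],3*N);
FA = F(:,1:N);
FP = F(:,N+1:2*N);
FN = F(:,2*N+1:3*N);

% Squared Euclidean distances anchor-positive and anchor-negative
dPos = sum((FA - FP).^2,1);
dNeg = sum((FA - FN).^2,1);

% Hinge: penalize triplets where the negative is not at least margin
% farther from the anchor than the positive
loss = mean(max(dPos - dNeg + margin,0));

gradients = dlgradient(loss,net.Learnables);
end
```

In the training loop it would be called as [loss,gradients,state] = dlfeval(@modelLossTriplet,dlnet,dlXA,dlXP,dlXN,margin); followed by dlnet.State = state; exactly as in the contrastive loop.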