Train a Siamese Network for Dimensionality Reduction

This example shows how to train a Siamese network to compare handwritten digits using dimensionality reduction.

A Siamese network is a type of deep learning network that uses two or more identical subnetworks that have the same architecture and share the same parameters and weights. Siamese networks are typically used in tasks that involve finding the relationship between two comparable things. Some common applications for Siamese networks include facial recognition, signature verification [1], or paraphrase identification [2]. Siamese networks perform well in these tasks because their shared weights mean there are fewer parameters to learn during training and they can produce good results with a relatively small amount of training data.

Siamese networks are particularly useful in cases where there are large numbers of classes with small numbers of observations of each. In such cases, there is not enough data to train a deep convolutional neural network to classify images into these classes. Instead, the Siamese network can determine if two images are in the same class. The network does this by reducing the dimensionality of the training data and using a distance-based cost function to differentiate between the classes.

This example uses a Siamese network for dimensionality reduction of a collection of images of handwritten digits. The Siamese architecture reduces the dimensionality by mapping images with the same class to nearby points in a low-dimensional space. The reduced-feature representation is then used to extract images from the dataset that are most similar to a test image. The training data in this example are images of size 28-by-28-by-1, giving an initial feature dimensionality of 784. The Siamese network reduces the dimensionality of the input images to two features and is trained to output similar reduced features for images with the same label.

You can also use Siamese networks to identify similar images by directly comparing them. For an example, see Train a Siamese Network to Compare Images.

Load and Preprocess Training Data

Load the training data, which consists of images of handwritten digits. The function digitTrain4DArrayData loads the digit images and their labels.

[XTrain,YTrain] = digitTrain4DArrayData;

XTrain is a 28-by-28-by-1-by-5000 array containing 5000 single-channel images, each of size 28-by-28. The values of each pixel are between 0 and 1. YTrain is a categorical vector containing the labels for each observation, which are the numbers from 0 to 9 corresponding to the value of the written digit.

Display a random selection of the images.

perm = randperm(numel(YTrain), 9);
imshow(imtile(XTrain(:,:,:,perm),"ThumbnailSize",[100 100]));

Create Pairs of Similar and Dissimilar Images

To train the network, the data must be grouped into pairs of images that are either similar or dissimilar. Here, similar images are defined as having the same label, while dissimilar images have different labels. The function getSiameseBatch (defined in the Supporting Functions section of this example) creates randomized pairs of similar or dissimilar images, pairImage1 and pairImage2. The function also returns the label pairLabel, which identifies if the pair of images is similar or dissimilar to each other. Similar pairs of images have pairLabel = 1, while dissimilar pairs have pairLabel = 0.

As an example, create a small representative set of ten pairs of images.

batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(XTrain,YTrain,batchSize);

Display the generated pairs of images.

for i = 1:batchSize
    subplot(2,5,i)
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    title(s)
end

In this example, a new batch of 180 paired images is created for each iteration of the training loop. This ensures that the network is trained on a large number of random pairs of images with approximately equal proportions of similar and dissimilar pairs.
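To get a feel for how balanced a batch of 180 pairs is likely to be, the following minimal sketch imitates the fair coin flip that getSiameseBatch uses to choose between a similar and a dissimilar pair. It does not call getSiameseBatch, and the variable names numPairs, isSimilar, fractionSimilar, and expectedStd are illustrative only.

```matlab
% Sketch: simulate the similar/dissimilar choice for one batch of pairs.
rng(0)                                   % fix the seed so the run is repeatable
numPairs = 180;                          % same size as the training mini-batch
isSimilar = rand(1,numPairs) < 0.5;      % true with probability 0.5
fractionSimilar = mean(isSimilar);       % proportion of similar pairs in this batch

% Standard deviation of the proportion for a fair coin over numPairs draws
expectedStd = sqrt(0.5*0.5/numPairs);
fprintf("Similar pairs: %.1f%% (expected 50%% +/- %.1f%%)\n", ...
    100*fractionSimilar,100*expectedStd);
```

Under this assumption, the proportion of similar pairs in a single batch typically stays within a few percentage points of 50%, and the imbalance averages out over many iterations.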

Define Network Architecture

The Siamese network architecture is illustrated in the following diagram.

In this example, the two identical subnetworks are defined as a series of fully connected layers with ReLU layers. Create a network that accepts 28-by-28-by-1 images and outputs the two features used for the reduced feature representation. The network reduces the dimensionality of the input images to two, a value that is easier to plot and visualize than the initial dimensionality of 784.

For the first two fully connected layers, specify an output size of 1024 and use the He weight initializer.

For the final fully connected layer, specify an output size of two and use the He weight initializer.

layers = [
    imageInputLayer([28 28],'Name','input1','Normalization','none')
    fullyConnectedLayer(1024,'Name','fc1','WeightsInitializer','he')
    reluLayer('Name','relu1')
    fullyConnectedLayer(1024,'Name','fc2','WeightsInitializer','he')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2,'Name','fc3','WeightsInitializer','he')];

lgraph = layerGraph(layers);

To train the network with a custom training loop and enable automatic differentiation, convert the layer graph to a dlnetwork object.

dlnet = dlnetwork(lgraph);
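As a rough check on the size of this subnetwork, the following sketch counts the learnable parameters implied by the layer sizes above (784 inputs, two hidden layers of 1024 units, and 2 outputs). The arithmetic uses only base MATLAB and the variable names are illustrative. If Deep Learning Toolbox is available, the total can be compared with the number of elements stored in dlnet.Learnables.

```matlab
% Sketch: count weights and biases of the fully connected subnetwork.
layerSizes = [28*28*1 1024 1024 2];                   % input, fc1, fc2, fc3
numWeights = layerSizes(1:end-1).*layerSizes(2:end);  % one weight per input-output connection
numBiases  = layerSizes(2:end);                       % one bias per output unit
numLearnables = sum(numWeights + numBiases);

fprintf("Total learnable parameters: %d\n",numLearnables);
```

Because both branches of the Siamese network share these parameters, they are counted once, giving 1,855,490 learnable values.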

Define Model Gradients Function

Create the function modelGradients (defined in the Supporting Functions section of this example). The modelGradients function takes the Siamese dlnetwork object dlnet, a mini-batch of input data dlX1 and dlX2 with their labels pairLabels, and the value of margin. The function returns the loss value and the gradients of the loss with respect to the learnable parameters of the network.

The objective of the Siamese network is to output a feature vector for each image such that the feature vectors are similar for similar images, and notably different for dissimilar images. In this way, the network can discriminate between the two inputs.

Find the contrastive loss between the outputs from the last fully connected layer, the feature vectors features1 and features2 from pairImage1 and pairImage2, respectively. The contrastive loss for a pair is given by [3]

$$\text{loss} = \frac{1}{2} y d^2 + \frac{1}{2}(1 - y)\max(\text{margin} - d, 0)^2,$$

where $y$ is the value of the pair label ($y = 1$ for similar images; $y = 0$ for dissimilar images), and $d$ is the Euclidean distance between the two feature vectors $f_1$ and $f_2$: $d = \lVert f_1 - f_2 \rVert_2$.

The margin parameter acts as a constraint: if the two images in a pair are dissimilar, then the distance between their feature vectors should be at least margin; otherwise, the pair incurs a loss.

The contrastive loss has two terms, but only one is ever non-zero for a given image pair. In the case of similar images, the first term can be non-zero and is minimized by reducing the distance between the image features f1 and f2. In the case of dissimilar images, the second term can be non-zero, and is minimized by increasing the distance between the image features, to at least a distance of margin. The smaller the value of margin, the less constraining it is over how close a dissimilar pair can be before a loss is incurred.
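As a worked illustration of the two terms, the following sketch evaluates the contrastive loss for a single hypothetical pair of 2-D feature vectors, once with the pair labeled similar and once with it labeled dissimilar. The feature values and the helper name pairLoss are made up for illustration; the margin of 0.3 matches the value used later in this example.

```matlab
% Sketch: contrastive loss for one pair of feature vectors.
margin = 0.3;
f1 = [0.1; 0.2];
f2 = [0.1; 0.4];
d = norm(f1 - f2);                       % Euclidean distance, here 0.2

% Loss for a pair with label y (1 = similar, 0 = dissimilar) at distance d
pairLoss = @(y,d) 0.5*y*d^2 + 0.5*(1 - y)*max(margin - d,0)^2;

lossIfSimilar    = pairLoss(1,d);        % only the first term contributes
lossIfDissimilar = pairLoss(0,d);        % only the second term contributes

fprintf("similar: %.4f, dissimilar: %.4f\n",lossIfSimilar,lossIfDissimilar);
```

The two vectors are 0.2 apart, which is closer than the margin, so the pair is penalized under either label: 0.02 if it is labeled similar and 0.005 if it is labeled dissimilar. A dissimilar pair at a distance of 0.3 or more would incur no loss.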

Specify Training Options

Specify the value of margin to use during training.

margin = 0.3;

Specify the options to use during training. Train for 3000 iterations.

numIterations = 3000;
miniBatchSize = 180;

Specify the options for Adam optimization:

  • Set the learning rate to 0.0001.

  • Initialize the trailing average of the gradient and the trailing average of the squared gradient with [].

  • Set the gradient decay factor to 0.9 and the squared gradient decay factor to 0.99.

learningRate = 1e-4;
trailingAvg = [];
trailingAvgSq = [];
gradDecay = 0.9;
gradDecaySq = 0.99;
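To show what the trailing averages and decay factors control, the following sketch applies one step of the standard Adam update rule to a single scalar parameter. It is a hand-written illustration of the published rule, not the internal implementation of adamupdate. The parameter and gradient values, the epsilon offset of 1e-8, the zero initialization of the averages, and the variable names are all assumptions made for the illustration.

```matlab
% Sketch: one Adam step for a scalar parameter (illustration only).
learningRate = 1e-4;
gradDecay = 0.9;
gradDecaySq = 0.99;
epsilon = 1e-8;                          % assumed offset to avoid division by zero

param = 1.0;                             % hypothetical parameter value
grad = 0.5;                              % hypothetical gradient of the loss
avgGrad = 0;                             % trailing average of the gradient
avgGradSq = 0;                           % trailing average of the squared gradient
iteration = 1;

% Update the trailing averages with their decay factors
avgGrad   = gradDecay*avgGrad + (1 - gradDecay)*grad;
avgGradSq = gradDecaySq*avgGradSq + (1 - gradDecaySq)*grad^2;

% Correct the bias toward zero introduced by the zero initialization
avgGradHat   = avgGrad/(1 - gradDecay^iteration);
avgGradSqHat = avgGradSq/(1 - gradDecaySq^iteration);

% Take the step
step = learningRate*avgGradHat/(sqrt(avgGradSqHat) + epsilon);
param = param - step;
```

In this sketch, the first step has a magnitude very close to the learning rate whatever the scale of the gradient, because the bias-corrected averages reduce to the gradient and its square.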

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU with compute capability 3.0 or higher. To automatically detect if you have a GPU available and place the relevant data on the GPU, set the value of executionEnvironment to "auto". If you don't have a GPU, or don't want to use one for training, set the value of executionEnvironment to "cpu". To ensure you use a GPU for training, set the value of executionEnvironment to "gpu".

executionEnvironment = "auto";
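The same GPU test appears in several places in this example. As one possible simplification, the following sketch evaluates the condition once and stores the result in a logical variable. The name useGPU and the random stand-in data are hypothetical, and the sketch sets executionEnvironment to "cpu" so that it runs on any machine.

```matlab
% Sketch: decide once whether to move data to the GPU.
executionEnvironment = "cpu";            % set to "auto" or "gpu" to allow GPU use

% canUseGPU is evaluated only when executionEnvironment is "auto"
useGPU = (executionEnvironment == "auto" && canUseGPU) || ...
    executionEnvironment == "gpu";

X = rand(28,28,1,4,'single');            % stand-in for a mini-batch of images
if useGPU
    X = gpuArray(X);                     % requires Parallel Computing Toolbox
end
disp(class(X))
```

With this approach, each later conversion would reduce to a single if useGPU check.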

To monitor the training progress, you can plot the training loss after each iteration. Create the variable plots that contains "training-progress". If you don't want to plot the training progress, set this value to "none".

plots = "training-progress";

Initialize the plot parameters for the training loss progress plot.

plotRatio = 16/9;

if plots == "training-progress"
    trainingPlot = figure;
    trainingPlot.Position(3) = plotRatio*trainingPlot.Position(4);
    trainingPlot.Visible = 'on';
    
    trainingPlotAxes = gca;
    
    lineLossTrain = animatedline(trainingPlotAxes);
    xlabel(trainingPlotAxes,"Iteration")
    ylabel(trainingPlotAxes,"Loss")
    title(trainingPlotAxes,"Loss During Training")
end

To evaluate how well the network is doing at dimensionality reduction, compute and plot the reduced features of a set of test data after each iteration. Load the test data, which consists of images of handwritten digits similar to the training data. Convert the test data to dlarray and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch). If you are using a GPU, convert the test data to gpuArray.

[XTest,YTest] = digitTest4DArrayData;
dlXTest = dlarray(single(XTest),'SSCB');

% If training on a GPU, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);           
end 

Initialize the plot parameters for the reduced-feature plot of the test data.

dimensionPlot = figure;
dimensionPlot.Position(3) = plotRatio*dimensionPlot.Position(4);
dimensionPlot.Visible = 'on';

dimensionPlotAxes = gca;

uniqueGroups = unique(YTest);
colors = hsv(length(uniqueGroups));

Initialize a counter to keep track of the total number of iterations.

iteration = 1;

Train Model

Train the model using a custom training loop. Loop over the training data and update the network parameters at each iteration.

For each iteration:

  • Extract a batch of image pairs and labels using the getSiameseBatch function defined in the section Create Batches of Image Pairs.

  • Convert the image data to dlarray objects with underlying type single and specify the dimension labels 'SSCB' (spatial, spatial, channel, batch).

  • For GPU training, convert the image data to gpuArray objects.

  • Evaluate the model gradients using dlfeval and the modelGradients function.

  • Update the network parameters using the adamupdate function.

% Loop over mini-batches.
for iteration = 1:numIterations
    
    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(XTrain,YTrain,miniBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    
    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end       
    
    % Evaluate the model gradients and the loss using dlfeval and the
    % modelGradients function listed at the end of the example.
    [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX1,dlX2,pairLabels,margin);
    lossValue = double(gather(extractdata(loss)));
    
    % Update the Siamese network parameters.
    [dlnet.Learnables,trailingAvg,trailingAvgSq] = ...
        adamupdate(dlnet.Learnables,gradients, ...
        trailingAvg,trailingAvgSq,iteration,learningRate,gradDecay,gradDecaySq);
    
    % Update the training loss progress plot.
    if plots == "training-progress"
        addpoints(lineLossTrain,iteration,lossValue);
    end
            
    % Update the reduced-feature plot of the test data.        
    % Compute reduced features of the test data:
    dlFTest = predict(dlnet,dlXTest);
    FTest = extractdata(dlFTest);
       
    figure(dimensionPlot);
    for k = 1:length(uniqueGroups)
        % Get indices of each image in test data with the same numeric 
        % label (defined by the unique group):
        ind = YTest==uniqueGroups(k);
        % Plot this group:
        plot(dimensionPlotAxes,gather(FTest(1,ind)'),gather(FTest(2,ind)'),'.','color',...
            colors(k,:));
        hold on
    end
    
    legend(uniqueGroups)
    
    % Update title of reduced-feature plot with training progress information.
    title(dimensionPlotAxes,"2-D Feature Representation of Digits Images. Iteration = " +...
        iteration);
    legend(dimensionPlotAxes,'Location','eastoutside');
    xlabel(dimensionPlotAxes,"Feature 1")
    ylabel(dimensionPlotAxes,"Feature 2")
    
    hold off    
    drawnow    
end

The network has now learned to represent each image as a 2-D vector. As you can see from the reduced-feature plot of the test data, images of similar digits are clustered close to each other in this 2-D representation.
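One way to put a number on how well the classes separate is to classify each point by its nearest class centroid in the 2-D feature space. The following sketch does this on synthetic data standing in for FTest and YTest. The cluster centers, noise level, and variable names are made up for illustration, and nearest-centroid accuracy is a suggested diagnostic rather than part of this example's workflow.

```matlab
% Sketch: nearest-centroid accuracy on synthetic 2-D features.
rng(1)                                          % repeatable random numbers
numPerClass = 50;
centers = [0 3 0; 0 0 3];                       % 2-by-3, one column per class
labels = repelem((1:3)',numPerClass);           % stand-in for YTest
features = centers(:,labels) + 0.2*randn(2,numel(labels));  % stand-in for FTest

% Compute the centroid of each class in feature space
classes = unique(labels);
numClasses = numel(classes);
centroids = zeros(2,numClasses);
for k = 1:numClasses
    centroids(:,k) = mean(features(:,labels == classes(k)),2);
end

% Distance from every point to every centroid (numClasses-by-N)
dists = zeros(numClasses,numel(labels));
for k = 1:numClasses
    dists(k,:) = vecnorm(features - centroids(:,k));
end

% Assign each point to the class of its nearest centroid
[~,nearest] = min(dists,[],1);
predicted = classes(nearest(:));
accuracy = mean(predicted == labels(:));
fprintf("Nearest-centroid accuracy: %.3f\n",accuracy);
```

To apply the same idea to the trained network, replace features with gather(FTest) and labels with YTest. A value well above 0.1 (chance level for ten digits) would indicate that the 2-D representation separates the classes.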

Use the Trained Network to Find Similar Images

You can use the trained network to find a selection of images that are similar to each other out of a group. In this case, use the test data as the group of images. Convert the group of images to dlarray objects and gpuArray objects, if you are using a GPU.

groupX = XTest;

dlGroupX = dlarray(single(groupX),'SSCB');

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlGroupX = gpuArray(dlGroupX);           
end 

Extract a single test image from the group and display it. Remove the test image from the group so that it does not appear in the set of similar images.

testIdx = randi(5000);
testImg = dlGroupX(:,:,:,testIdx);

trialImgDisp = extractdata(testImg);

figure
imshow(trialImgDisp, 'InitialMagnification', 500);

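dlGroupX(:,:,:,testIdx) = [];
% Remove the same image from groupX so that indices computed from
% dlGroupX still point to the matching images in groupX.
groupX(:,:,:,testIdx) = [];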

Find the reduced features of the test image using predict.

trialF = predict(dlnet,testImg);

Find the 2-D reduced feature representation of each of the images in the group using the trained network.

FGroupX = predict(dlnet,dlGroupX);

Use the reduced feature representation to find the nine images in the group that are closest to the test image, using the Euclidean distance metric. Display the images.

distances = vecnorm(extractdata(trialF - FGroupX));
[~,idx] = sort(distances);
sortedImages = groupX(:,:,:,idx);

figure
imshow(imtile(sortedImages(:,:,:,1:9)), 'InitialMagnification', 500);
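The distance computation above relies on implicit expansion: subtracting a 2-by-N matrix from a 2-by-1 vector gives a 2-by-N matrix of differences, and vecnorm returns the norm of each column. The following sketch shows the same steps on a tiny hand-made set of features. The values and the names queryF, groupF, and nearestTwo are illustrative.

```matlab
% Sketch: rank a small group of 2-D features by distance to a query.
queryF = [0; 0];                         % features of the query image
groupF = [1 0.1 3 -0.2; ...
          1 0   4  0.1];                 % one column per group image

distances = vecnorm(queryF - groupF);    % Euclidean distance to each column
[~,idx] = sort(distances);               % closest first
nearestTwo = idx(1:2);

disp(distances)
disp(nearestTwo)
```

The second and fourth columns lie closest to the query, so nearestTwo is [2 4].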

By reducing the images to a lower dimensionality, the network is able to identify images that are similar to the trial image. The reduced feature representation allows the network to discriminate between images that are similar and dissimilar. Siamese networks are often used in the context of facial or signature recognition. For example, you can train a Siamese network to accept an image of a face as an input, and return a set of the most similar faces from a database.

Supporting Functions

Model Gradients Function

The function modelGradients takes the Siamese dlnetwork object dlnet, a pair of mini-batch input data X1 and X2, and the label pairLabels. The function returns the gradients of the loss with respect to the learnable parameters in the network as well as the contrastive loss between the reduced dimensionality features of the paired images. Within this example, the function modelGradients is introduced in the section Define Model Gradients Function.

function [gradients, loss] = modelGradients(net,X1,X2,pairLabel,margin)
% The modelGradients function calculates the contrastive loss between the
% paired images and returns the loss and the gradients of the loss with 
% respect to the network learnable parameters

    % Pass the first image of each pair forward through the network
    F1 = forward(net,X1);
    % Pass the second image of each pair forward through the network
    F2 = forward(net,X2);
    
    % Calculate contrastive loss
    loss = contrastiveLoss(F1,F2,pairLabel,margin);
    
    % Calculate gradients of the loss with respect to the network learnable
    % parameters
    gradients = dlgradient(loss, net.Learnables);

end

function loss = contrastiveLoss(F1,F2,pairLabel,margin)
% The contrastiveLoss function calculates the contrastive loss between
% the reduced features of the paired images 
    
    % Define small value to prevent taking square root of 0
    delta = 1e-6;
    
    % Find Euclidean distance metric
    distances = sqrt(sum((F1 - F2).^2,1) + delta);
    
    % pairLabel(i) = 1 if F1(:,i) and F2(:,i) are features for similar
    % images, and 0 otherwise
    lossSimilar = pairLabel.*(distances.^2);
 
    lossDissimilar = (1 - pairLabel).*(max(margin - distances, 0).^2);
    
    loss = 0.5*sum(lossSimilar + lossDissimilar,'all');
end
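The offset delta in contrastiveLoss matters for the gradient rather than for the loss value itself. The derivative of sqrt(s) is 1/(2*sqrt(s)), which is unbounded as s approaches 0, and s is 0 whenever the two feature vectors of a pair coincide. The following sketch evaluates that derivative with and without the offset; the variable names are illustrative.

```matlab
% Sketch: derivative of sqrt(s + delta) with respect to s at s = 0.
delta = 1e-6;
s = 0;                                    % squared distance of identical feature vectors

gradWithoutDelta = 1/(2*sqrt(s));         % division by zero gives Inf
gradWithDelta    = 1/(2*sqrt(s + delta)); % finite value

fprintf("without delta: %g, with delta: %g\n",gradWithoutDelta,gradWithDelta);
```

Without the offset, a pair with identical features could propagate Inf or NaN values into the gradients; with it, the derivative stays finite.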

Create Batches of Image Pairs

The following functions create randomized pairs of images that are similar or dissimilar, based on their labels. Within this example, the function getSiameseBatch is introduced in the section Create Pairs of Similar and Dissimilar Images.

function [X1,X2,pairLabels] = getSiameseBatch(X,Y,miniBatchSize)
% getSiameseBatch returns a randomly selected batch of paired images. 
% On average, this function produces a balanced set of similar and 
% dissimilar pairs.
    pairLabels = zeros(1, miniBatchSize);
    imgSize = size(X(:,:,:,1));
    X1 = zeros([imgSize 1 miniBatchSize]);
    X2 = zeros([imgSize 1 miniBatchSize]);
    
    for i = 1:miniBatchSize
        choice = rand(1);
        if choice < 0.5
            [pairIdx1, pairIdx2, pairLabels(i)] = getSimilarPair(Y);
        else
            [pairIdx1, pairIdx2, pairLabels(i)] = getDissimilarPair(Y);
        end
        X1(:,:,:,i) = X(:,:,:,pairIdx1);
        X2(:,:,:,i) = X(:,:,:,pairIdx2);
    end
    
end

function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)
% getSimilarPair returns a random pair of indices for images
% that are in the same class and the similar pair label = 1.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose a class randomly which will be used to get a similar pair.
    classChoice = randi(numel(classes));
    
    % Find the indices of all the observations from the chosen class.
    idxs = find(classLabel==classes(classChoice));
    
    % Randomly choose two different images from the chosen class.
    pairIdxChoice = randperm(numel(idxs),2);
    pairIdx1 = idxs(pairIdxChoice(1));
    pairIdx2 = idxs(pairIdxChoice(2));
    pairLabel = 1;
end

function  [pairIdx1,pairIdx2,pairLabel] = getDissimilarPair(classLabel)
% getDissimilarPair returns a random pair of indices for images
% that are in different classes and the dissimilar pair label = 0.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose two different classes randomly which will be used to get a dissimilar pair.
    classesChoice = randperm(numel(classes), 2);
    
    % Find the indices of all the observations from the first and second classes.
    idxs1 = find(classLabel==classes(classesChoice(1)));
    idxs2 = find(classLabel==classes(classesChoice(2)));
    
    % Randomly choose one image from each class.
    pairIdx1Choice = randi(numel(idxs1));
    pairIdx2Choice = randi(numel(idxs2));
    pairIdx1 = idxs1(pairIdx1Choice);
    pairIdx2 = idxs2(pairIdx2Choice);
    pairLabel = 0;
end

References

[1] Bromley, J., I. Guyon, Y. LeCun, E. Säckinger, and R. Shah. "Signature Verification using a "Siamese" Time Delay Neural Network." In Proceedings of the 6th International Conference on Neural Information Processing Systems (NIPS 1993), 1994, pp. 737-744. Available at Signature Verification using a "Siamese" Time Delay Neural Network on the NIPS Proceedings website.

[2] Yin, W., and H. Schütze. "Convolutional Neural Network for Paraphrase Identification." In Proceedings of the 2015 Conference of the North American Chapter of the ACL, 2015, pp. 901-911. Available at Convolutional Neural Network for Paraphrase Identification on the ACL Anthology website.

[3] Hadsell, R., S. Chopra, and Y. LeCun. "Dimensionality Reduction by Learning an Invariant Mapping." In Proceedings of the 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR 2006), 2006, pp. 1735-1742.
