Train a Siamese Network to Compare Images
This example shows how to train a Siamese network to identify similar images of handwritten characters.
A Siamese network is a type of deep learning network that uses two or more identical subnetworks that have the same architecture and share the same parameters and weights. Siamese networks are typically used in tasks that involve finding the relationship between two comparable things. Some common applications for Siamese networks include facial recognition, signature verification [1], and paraphrase identification [2]. Siamese networks perform well in these tasks because their shared weights mean there are fewer parameters to learn during training, so they can produce good results with a relatively small amount of training data.
Siamese networks are particularly useful in cases where there are large numbers of classes with small numbers of observations of each. In such cases, there is not enough data to train a deep convolutional neural network to classify images into these classes. Instead, the Siamese network can determine if two images are in the same class.
This example uses the Omniglot dataset [3] to train a Siamese network to compare images of handwritten characters [4]. The Omniglot dataset contains character sets for 50 alphabets, divided into 30 used for training and 20 for testing. Each alphabet contains a number of characters, from 14 for Ojibwe (Canadian Aboriginal Syllabics) to 55 for Tifinagh. Finally, each character has 20 handwritten observations. This example trains a network to identify whether two handwritten observations are different instances of the same character.
You can also use Siamese networks to identify similar images using dimensionality reduction. For an example, see Train a Siamese Network for Dimensionality Reduction.
Load and Preprocess Training Data
Download and extract the Omniglot training dataset.
url = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"images_background.zip");

dataFolderTrain = fullfile(downloadFolder,"images_background");
if ~exist(dataFolderTrain,"dir")
    disp("Downloading Omniglot training data (4.5 MB)...")
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Training data downloaded.")
Training data downloaded.
Load the training data as an image datastore using the imageDatastore function. Specify the labels manually by extracting the labels from the file names and setting the Labels property.
imdsTrain = imageDatastore(dataFolderTrain, ...
    IncludeSubfolders=true, ...
    LabelSource="none");

files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"-");
imdsTrain.Labels = categorical(labels);
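The label extraction can be illustrated on two made-up file paths. This is a minimal sketch: the folder and file names below are hypothetical stand-ins for the alphabet/character/image layout of the dataset, and both paths are assumed to have the same depth, which split requires when it operates on a string array.

```matlab
% Two hypothetical file paths with the layout alphabet/character/image.
files = [ ...
    fullfile("data","images_background","Latin","character01","0709_01.png")
    fullfile("data","images_background","Greek","character05","0398_02.png")];

% Split each path into its components (one row per file).
parts = split(files,filesep);

% Join the alphabet and character folder names to form one label per file.
labels = join(parts(:,(end-2):(end-1)),"-")
```

The result is a 2-by-1 string array containing "Latin-character01" and "Greek-character05", so every character in every alphabet receives a distinct class label.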
The Omniglot training dataset consists of black and white handwritten characters from 30 alphabets, with 20 observations of each character. The images are of size 105-by-105-by-1, and the values of each pixel are between 0 and 1.
Display a random selection of the images.
idx = randperm(numel(imdsTrain.Files),8);

for i = 1:numel(idx)
    subplot(4,2,i)
    imshow(readimage(imdsTrain,idx(i)))
    title(imdsTrain.Labels(idx(i)),Interpreter="none");
end
Create Pairs of Similar and Dissimilar Images
To train the network, the data must be grouped into pairs of images that are either similar or dissimilar. Here, similar images are different handwritten instances of the same character, which have the same label, while dissimilar images are instances of different characters, which have different labels. The function getSiameseBatch (defined in the Supporting Functions section of this example) creates randomized pairs of similar or dissimilar images, pairImage1 and pairImage2. The function also returns the label pairLabel, which identifies whether the pair of images is similar or dissimilar. Similar pairs of images have pairLabel = 1, while dissimilar pairs have pairLabel = 0.
As an example, create a small representative set of ten pairs of images.
batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(imdsTrain,batchSize);
Display the generated pairs of images.
for i = 1:batchSize
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    subplot(2,5,i)
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    title(s)
end
In this example, a new batch of 180 paired images is created for each iteration of the training loop. This ensures that the network is trained on a large number of random pairs of images with approximately equal proportions of similar and dissimilar pairs.
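To make "approximately equal proportions" concrete, the following sketch works out how many similar pairs a batch of 180 is expected to contain. It assumes only what getSiameseBatch does in the Supporting Functions section, namely that each pair is independently chosen to be similar with probability 0.5, so the count follows a binomial distribution.

```matlab
% Each pair is "similar" with probability 0.5 (the rand(1) < 0.5 test
% in getSiameseBatch), so the count per batch is binomial.
miniBatchSize = 180;
pSimilar = 0.5;

expectedSimilar = miniBatchSize*pSimilar;                  % mean count
stdSimilar = sqrt(miniBatchSize*pSimilar*(1 - pSimilar));  % standard deviation

fprintf("Expected similar pairs: %d (std about %.1f)\n",expectedSimilar,stdSimilar)
```

A typical batch therefore contains roughly 90 similar pairs, give or take about 7.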
Define Network Architecture
The Siamese network architecture is illustrated in the following diagram.
To compare two images, each image is passed through one of two identical subnetworks that share weights. The subnetworks convert each 105-by-105-by-1 image to a 4096-dimensional feature vector. Images of the same class have similar 4096-dimensional representations. The output feature vectors from each subnetwork are combined through subtraction (the forwardSiamese function takes the absolute value of the difference), and the result is passed through a fullyconnect operation with a single output. A sigmoid operation converts this value to a probability between 0 and 1, indicating the network prediction of whether the images are similar or dissimilar. The binary cross-entropy loss between the network prediction and the true label is used to update the network during training.
In this example, the two identical subnetworks are defined as a dlnetwork object. The final fullyconnect and sigmoid operations are performed as functional operations on the subnetwork outputs.
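The combination step can be sketched with plain numeric arrays. This is an illustration only: the three-element vectors stand in for the 4096-dimensional feature vectors, and the weights and bias are made-up values rather than learned parameters. The absolute difference mirrors what the forwardSiamese supporting function computes.

```matlab
% Toy stand-ins for the two 4096-dimensional feature vectors.
f1 = [0.2; 0.9; 0.4];
f2 = [0.1; 0.3; 0.4];

weights = [1 -2 0.5];   % hypothetical weights of the single-output fullyconnect step
bias = 0.1;             % hypothetical bias

d = abs(f1 - f2);       % element-wise absolute difference of the features
z = weights*d + bias;   % single output of the fully connected step
p = 1/(1 + exp(-z))     % sigmoid: probability that the pair is similar
```

With these values z is approximately -1, so p is about 0.27 and the pair would be classified as dissimilar.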
Create the subnetwork as a series of layers that accepts 105-by-105-by-1 images and outputs a feature vector of size 4096.
For the convolution2dLayer objects, use the narrow normal distribution to initialize the weights and bias.
For the maxPooling2dLayer objects, set the stride to 2.
For the final fullyConnectedLayer object, specify an output size of 4096 and use the narrow normal distribution to initialize the weights and bias.
layers = [
    imageInputLayer([105 105 1],Normalization="none")
    convolution2dLayer(10,64,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(7,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(4,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(5,256,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    fullyConnectedLayer(4096,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")];

lgraph = layerGraph(layers);
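As a sanity check on the architecture, the spatial size of the feature maps can be traced by hand. The sketch below assumes that the convolution layers use their default stride of 1 and no padding, and that the pooling layers use no padding, since the layer definitions above do not set those options.

```matlab
% Spatial size after each layer, assuming no padding and unit
% convolution stride (the defaults when those options are not set).
sz = 105;
convFilters = [10 7 4 5];   % filter sizes of the four convolution layers
numPools = 3;               % 2-by-2, stride-2 pooling follows the first three

for k = 1:numel(convFilters)
    sz = sz - convFilters(k) + 1;      % convolution without padding
    if k <= numPools
        sz = floor((sz - 2)/2) + 1;    % 2-by-2 max pooling with stride 2
    end
end

numFeatures = sz*sz*256   % inputs to the final fully connected layer
```

Under these assumptions the sizes run 105, 96, 48, 42, 21, 18, 9, 5, so the last convolution produces a 5-by-5-by-256 output and the fully connected layer maps 6400 values to the 4096-dimensional feature vector.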
To train the network with a custom training loop and enable automatic differentiation, convert the layer graph to a dlnetwork object.
net = dlnetwork(lgraph);
Create the weights and bias for the final fullyconnect operation. Initialize them by sampling from a narrow normal distribution with a standard deviation of 0.01.
fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));

fcParams = struct(...
    "FcWeights",fcWeights,...
    "FcBias",fcBias);
To use the network, create the function forwardSiamese (defined in the Supporting Functions section of this example) that defines how the two subnetworks and the subtraction, fullyconnect, and sigmoid operations are combined. The function forwardSiamese accepts the network, the structure containing the parameters for the fullyconnect operation, and two training images. The forwardSiamese function outputs a prediction about the similarity of the two images.
Define Model Loss Function
Create the function modelLoss (defined in the Supporting Functions section of this example). The modelLoss function takes the Siamese subnetwork net, the parameter structure for the fullyconnect operation, and a mini-batch of input data X1 and X2 with their labels pairLabels. The function returns the loss values and the gradients of the loss with respect to the learnable parameters of the network.
The objective of the Siamese network is to discriminate between the two inputs X1 and X2. The output of the network is a probability between 0 and 1, where a value closer to 0 indicates a prediction that the images are dissimilar, and a value closer to 1 that the images are similar. The loss is given by the binary cross-entropy between the predicted score and the true label value:
$$\text{loss} = -t\log(y) - (1 - t)\log(1 - y)$$
where the true label $t$ can be 0 or 1 and $y$ is the predicted label.
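A short numeric check shows how the binary cross-entropy penalizes confident mistakes far more than it rewards confident correct predictions. The score of 0.9 below is a made-up value chosen for illustration.

```matlab
% Binary cross-entropy for a single pair: t is the true label, y the
% predicted probability that the pair is similar.
bce = @(t,y) -t.*log(y) - (1 - t).*log(1 - y);

lossConfidentCorrect = bce(1,0.9)   % similar pair scored 0.9, about 0.105
lossConfidentWrong = bce(0,0.9)     % dissimilar pair scored 0.9, about 2.303
```

The same score of 0.9 costs roughly 22 times more when the pair is actually dissimilar.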
Specify Training Options
Specify the options to use during training. Train for 10000 iterations.
numIterations = 10000;
miniBatchSize = 180;
Specify the options for ADAM optimization:
Set the learning rate to 0.00006.
Set the gradient decay factor to 0.9 and the squared gradient decay factor to 0.99.
learningRate = 6e-5;
gradDecay = 0.9;
gradDecaySq = 0.99;
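To show what these three options control, here is a sketch of one update for a single scalar parameter using the textbook Adam rule. It is not a description of the internals of adamupdate; the epsilon offset, the parameter value, and the gradient value are assumptions made for the illustration.

```matlab
% One Adam step for a single scalar parameter, following the standard
% update rule. epsilon and the gradient value are illustrative assumptions.
learningRate = 6e-5; gradDecay = 0.9; gradDecaySq = 0.99;
epsilon = 1e-8;
theta = 0.2; g = 0.5;          % hypothetical parameter and gradient
avgG = 0; avgSqG = 0;          % moving averages start at zero
iteration = 1;

avgG = gradDecay*avgG + (1 - gradDecay)*g;             % average of gradients
avgSqG = gradDecaySq*avgSqG + (1 - gradDecaySq)*g^2;   % average of squared gradients
avgGHat = avgG/(1 - gradDecay^iteration);              % bias correction
avgSqGHat = avgSqG/(1 - gradDecaySq^iteration);
step = learningRate*avgGHat/(sqrt(avgSqGHat) + epsilon);
theta = theta - step
```

On the first iteration the bias-corrected ratio is close to the sign of the gradient, so the step size is almost exactly the learning rate, regardless of the gradient magnitude.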
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To automatically detect if you have a GPU available and place the relevant data on the GPU, set the value of executionEnvironment to "auto". If you do not have a GPU, or do not want to use one for training, set the value of executionEnvironment to "cpu". To ensure you use a GPU for training, set the value of executionEnvironment to "gpu".
executionEnvironment = "auto";
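The way this setting is used later can be sketched as a single logical flag. The variable useGPU and the placeholder batch X are hypothetical names introduced only for this illustration, and the setting is fixed to "cpu" here so that the sketch runs on any machine.

```matlab
% Decide once whether to move data to the GPU. canUseGPU is only
% evaluated when executionEnvironment is "auto" (short-circuit &&).
executionEnvironment = "cpu";   % example setting for this sketch
useGPU = (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu";

X = rand(105,105,1,4,"single");   % placeholder image batch
if useGPU
    X = gpuArray(X);
end
class(X)
```

With "cpu" the data stays in host memory as a single array. With "gpu" the conversion is attempted unconditionally, so it fails with an error if no supported device is present.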
Train Model
Initialize the training progress plot.
figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on
Initialize the parameters for the ADAM solver.
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];
Train the model using a custom training loop. Loop over the training data and update the network parameters at each iteration.
For each iteration:
Extract a batch of image pairs and labels using the getSiameseBatch function defined in the section Create Batches of Image Pairs.
Convert the image data to dlarray objects with the dimension labels "SSCB" (spatial, spatial, channel, batch). The pair labels stay as a numeric array.
For GPU training, convert the data to gpuArray objects.
Evaluate the model loss and gradients using dlfeval and the modelLoss function.
Update the network parameters using the adamupdate function.
start = tic;

% Loop over mini-batches.
for iteration = 1:numIterations

    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(imdsTrain,miniBatchSize);

    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % "SSCB" (spatial, spatial, channel, batch) for image data
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");

    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end

    % Evaluate the model loss and gradients using dlfeval and the modelLoss
    % function listed at the end of the example.
    [loss,gradientsSubnet,gradientsParams] = dlfeval(@modelLoss,net,fcParams,X1,X2,pairLabels);

    % Update the Siamese subnetwork parameters.
    [net,trailingAvgSubnet,trailingAvgSqSubnet] = adamupdate(net,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);

    % Update the fullyconnect parameters.
    [fcParams,trailingAvgParams,trailingAvgSqParams] = adamupdate(fcParams,gradientsParams, ...
        trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);

    % Update the training loss progress plot.
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    lossValue = double(loss);
    addpoints(lineLossTrain,iteration,lossValue);
    title("Elapsed: " + string(D))
    drawnow
end
Evaluate the Accuracy of the Network
Download and extract the Omniglot test dataset.
url = "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"images_evaluation.zip");

dataFolderTest = fullfile(downloadFolder,"images_evaluation");
if ~exist(dataFolderTest,"dir")
    disp("Downloading Omniglot test data (3.2 MB)...")
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Test data downloaded.")
Test data downloaded.
Load the test data as an image datastore using the imageDatastore function. Specify the labels manually by extracting the labels from the file names and setting the Labels property.
imdsTest = imageDatastore(dataFolderTest, ...
    IncludeSubfolders=true, ...
    LabelSource="none");

files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"_");
imdsTest.Labels = categorical(labels);
The test dataset contains 20 alphabets that are different from those that the network was trained on. In total, there are 659 different classes in the test dataset.
numClasses = numel(unique(imdsTest.Labels))
numClasses = 659
To calculate the accuracy of the network, create a set of five random mini-batches of test pairs. Use the predictSiamese function (defined in the Supporting Functions section of this example) to evaluate the network predictions and calculate the average accuracy over the mini-batches.
accuracy = zeros(1,5);
accuracyBatchSize = 150;

for i = 1:5
    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabelsAcc] = getSiameseBatch(imdsTest,accuracyBatchSize);

    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % "SSCB" (spatial, spatial, channel, batch) for image data.
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");

    % If using a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end

    % Evaluate predictions using trained network
    Y = predictSiamese(net,fcParams,X1,X2);

    % Convert predictions to binary 0 or 1
    Y = gather(extractdata(Y));
    Y = round(Y);

    % Compute average accuracy for the minibatch
    accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end
Compute the average accuracy over all mini-batches.
averageAccuracy = mean(accuracy)*100
averageAccuracy = 89.0667
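The thresholding and accuracy calculation inside the loop can be followed on a tiny hand-made batch. All scores and labels below are invented for the illustration and are not outputs of the trained network.

```matlab
% Threshold predicted probabilities at 0.5 with round, then compare
% against the true pair labels. All values here are made up.
YScore = [0.91 0.18 0.62 0.40];   % hypothetical predicted probabilities
pairLabels = [1 0 0 0];           % hypothetical true labels

YPred = round(YScore);            % gives [1 0 1 0]
batchAccuracy = sum(YPred == pairLabels)/numel(pairLabels)   % 3 of 4 correct, 0.75
```

Because the reported accuracy is averaged over randomly drawn pairs, the value changes from run to run.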
Display a Test Set of Images with Predictions
To visually check if the network correctly identifies similar and dissimilar pairs, create a small batch of image pairs to test. Use the predictSiamese function to get the prediction for each test pair. Display the pair of images with the prediction, the probability score, and a label indicating whether the prediction was correct or incorrect.
testBatchSize = 10;
[XTest1,XTest2,pairLabelsTest] = getSiameseBatch(imdsTest,testBatchSize);
Convert the test batch of data to dlarray. Specify the dimension labels "SSCB" (spatial, spatial, channel, batch) for image data.
XTest1 = dlarray(XTest1,"SSCB");
XTest2 = dlarray(XTest2,"SSCB");
If using a GPU, then convert the data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    XTest1 = gpuArray(XTest1);
    XTest2 = gpuArray(XTest2);
end
Calculate the predicted probability.
YScore = predictSiamese(net,fcParams,XTest1,XTest2);
YScore = gather(extractdata(YScore));
Convert the predictions to binary 0 or 1.
YPred = round(YScore);
Extract the data for plotting.
XTest1 = extractdata(XTest1);
XTest2 = extractdata(XTest2);
Plot images with predicted label and predicted score.
f = figure;
tiledlayout(2,5);
f.Position(3) = 2*f.Position(3);

predLabels = categorical(YPred,[0 1],["dissimilar" "similar"]);
targetLabels = categorical(pairLabelsTest,[0 1],["dissimilar","similar"]);

for i = 1:numel(pairLabelsTest)
    nexttile
    imshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);

    title( ...
        "Target: " + string(targetLabels(i)) + newline + ...
        "Predicted: " + string(predLabels(i)) + newline + ...
        "Score: " + YScore(i))
end
The network is able to compare the test images to determine their similarity, even though none of these images were in the training dataset.
Supporting Functions
Model Functions for Training and Prediction
The function forwardSiamese is used during network training. The function defines how the subnetworks and the fullyconnect and sigmoid operations combine to form the complete Siamese network. forwardSiamese accepts the network, the structure of parameters for the fullyconnect operation, and two training images, and outputs a prediction about the similarity of the two images. Within this example, the function forwardSiamese is introduced in the section Define Network Architecture.
function Y = forwardSiamese(net,fcParams,X1,X2)
% forwardSiamese accepts the network and pair of training images, and
% returns a prediction of the probability of the pair being similar (closer
% to 1) or dissimilar (closer to 0). Use forwardSiamese during training.

% Pass the first image through the twin subnetwork
Y1 = forward(net,X1);
Y1 = sigmoid(Y1);

% Pass the second image through the twin subnetwork
Y2 = forward(net,X2);
Y2 = sigmoid(Y2);

% Subtract the feature vectors
Y = abs(Y1 - Y2);

% Pass the result through a fullyconnect operation
Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);

% Convert to probability between 0 and 1.
Y = sigmoid(Y);

end
The function predictSiamese uses the trained network to make predictions about the similarity of two images. The function is similar to the function forwardSiamese, defined previously. However, predictSiamese uses the predict function with the network instead of the forward function, because some deep learning layers behave differently during training and prediction. Within this example, the function predictSiamese is introduced in the section Evaluate the Accuracy of the Network.
function Y = predictSiamese(net,fcParams,X1,X2)
% predictSiamese accepts the network and pair of images, and returns a
% prediction of the probability of the pair being similar (closer to 1) or
% dissimilar (closer to 0). Use predictSiamese during prediction.

% Pass the first image through the twin subnetwork.
Y1 = predict(net,X1);
Y1 = sigmoid(Y1);

% Pass the second image through the twin subnetwork.
Y2 = predict(net,X2);
Y2 = sigmoid(Y2);

% Subtract the feature vectors.
Y = abs(Y1 - Y2);

% Pass result through a fullyconnect operation.
Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);

% Convert to probability between 0 and 1.
Y = sigmoid(Y);

end
Model Loss Function
The function modelLoss takes the Siamese dlnetwork object net, the parameter structure fcParams for the fullyconnect operation, a pair of mini-batch input data X1 and X2, and the labels indicating whether each pair is similar or dissimilar. The function returns the binary cross-entropy loss between the prediction and the ground truth and the gradients of the loss with respect to the learnable parameters in the network. Within this example, the function modelLoss is introduced in the section Define Model Loss Function.
function [loss,gradientsSubnet,gradientsParams] = modelLoss(net,fcParams,X1,X2,pairLabels)

% Pass the image pair through the network.
Y = forwardSiamese(net,fcParams,X1,X2);

% Calculate binary cross-entropy loss.
loss = binarycrossentropy(Y,pairLabels);

% Calculate gradients of the loss with respect to the network learnable
% parameters.
[gradientsSubnet,gradientsParams] = dlgradient(loss,net.Learnables,fcParams);

end
Binary Cross-Entropy Loss Function
The binarycrossentropy function accepts the network prediction and the pair labels, and returns the binary cross-entropy loss value.
function loss = binarycrossentropy(Y,pairLabels)

% Get precision of prediction to prevent errors due to floating point
% precision.
precision = underlyingType(Y);

% Convert values less than floating point precision to eps.
Y(Y < eps(precision)) = eps(precision);

% Convert values between 1-eps and 1 to 1-eps.
Y(Y > 1 - eps(precision)) = 1 - eps(precision);

% Calculate binary cross-entropy loss for each pair
loss = -pairLabels.*log(Y) - (1 - pairLabels).*log(1 - Y);

% Sum over all pairs in minibatch and normalize.
loss = sum(loss)/numel(pairLabels);

end
Create Batches of Image Pairs
The following functions create randomized pairs of images that are similar or dissimilar, based on their labels. Within this example, the function getSiameseBatch is introduced in the section Create Pairs of Similar and Dissimilar Images.
Get Siamese Batch Function
The getSiameseBatch function returns a randomly selected batch of paired images. On average, this function produces a balanced set of similar and dissimilar pairs.
function [X1,X2,pairLabels] = getSiameseBatch(imds,miniBatchSize)

pairLabels = zeros(1,miniBatchSize);
imgSize = size(readimage(imds,1));
X1 = zeros([imgSize 1 miniBatchSize],"single");
X2 = zeros([imgSize 1 miniBatchSize],"single");

for i = 1:miniBatchSize
    choice = rand(1);

    if choice < 0.5
        [pairIdx1,pairIdx2,pairLabels(i)] = getSimilarPair(imds.Labels);
    else
        [pairIdx1,pairIdx2,pairLabels(i)] = getDissimilarPair(imds.Labels);
    end

    X1(:,:,:,i) = imds.readimage(pairIdx1);
    X2(:,:,:,i) = imds.readimage(pairIdx2);
end

end
Get Similar Pair Function
The getSimilarPair function returns a random pair of indices for two images that are in the same class, together with the similar pair label, which equals 1.
function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)

% Find all unique classes.
classes = unique(classLabel);

% Choose a class randomly which will be used to get a similar pair.
classChoice = randi(numel(classes));

% Find the indices of all the observations from the chosen class.
idxs = find(classLabel==classes(classChoice));

% Randomly choose two different images from the chosen class.
pairIdxChoice = randperm(numel(idxs),2);
pairIdx1 = idxs(pairIdxChoice(1));
pairIdx2 = idxs(pairIdxChoice(2));
pairLabel = 1;

end
Get Dissimilar Pair Function
The getDissimilarPair function returns a random pair of indices for two images that are in different classes, together with the dissimilar pair label, which equals 0.
function [pairIdx1,pairIdx2,label] = getDissimilarPair(classLabel)

% Find all unique classes.
classes = unique(classLabel);

% Choose two different classes randomly which will be used to get a
% dissimilar pair.
classesChoice = randperm(numel(classes),2);

% Find the indices of all the observations from the first and second
% classes.
idxs1 = find(classLabel==classes(classesChoice(1)));
idxs2 = find(classLabel==classes(classesChoice(2)));

% Randomly choose one image from each class.
pairIdx1Choice = randi(numel(idxs1));
pairIdx2Choice = randi(numel(idxs2));
pairIdx1 = idxs1(pairIdx1Choice);
pairIdx2 = idxs2(pairIdx2Choice);
label = 0;

end
References
[1] Bromley, J., I. Guyon, Y. LeCun, E. Säckinger, and R. Shah. "Signature Verification using a "Siamese" Time Delay Neural Network." In Proceedings of the 6th International Conference on Neural Information Processing Systems (NIPS 1993), 1994, pp. 737-744. Available at Signature Verification using a "Siamese" Time Delay Neural Network on the NIPS Proceedings website.
[2] Yin, W., and H. Schütze. "Convolutional Neural Network for Paraphrase Identification." In Proceedings of 2015 Conference of the North American Chapter of the ACL, 2015, pp. 901-911. Available at Convolutional Neural Network for Paraphrase Identification on the ACL Anthology website.
[3] Lake, B. M., R. Salakhutdinov, and J. B. Tenenbaum. "Human-level concept learning through probabilistic program induction." Science, 350(6266), 2015, pp. 1332-1338.
[4] Koch, G., R. Zemel, and R. Salakhutdinov. "Siamese neural networks for one-shot image recognition." In Proceedings of the 32nd International Conference on Machine Learning, 37 (2015). Available at Siamese Neural Networks for One-shot Image Recognition on the ICML'15 website.
See Also
dlarray | dlgradient | dlfeval | dlnetwork | adamupdate