Is there any documentation on how to build a transformer encoder from scratch in matlab?
I am building a transformer encoder, and I came across the following File Exchange submission: https://www.mathworks.com/matlabcentral/fileexchange/107375-transformer-models
However, in the exchange there are examples on how to use a pretrained transformer model. I just need an example on how to build a model. Something to give a general idea so I can build on it. I have studied the basics of transformers but I am having some difficulty building the model from scratch.
Thank you in advance.
1 Comment
Shubham
on 8 Sep 2023
Hi,
You can refer to this documentation:
This article uses TensorFlow, but you can replicate the approach in MATLAB.
Accepted Answer
Ben
on 18 Sep 2023
The general structure of an intermediate encoder block is:
selfAttentionLayer(numHeads,numKeyChannels) % self attention
additionLayer(2,Name="attention_add") % residual connection around attention
layerNormalizationLayer(Name="attention_norm") % layer norm
fullyConnectedLayer(feedforwardHiddenSize) % feedforward part 1
reluLayer % nonlinear activation
fullyConnectedLayer(attentionHiddenSize) % feedforward part 2; output size must match the attention output so the residual addition works
additionLayer(2,Name="feedforward_add") % residual connection around feedforward
layerNormalizationLayer() % layer norm
You would need to hook up the connections to the addition layers appropriately.
Typically you would have multiple copies of this encoder block in a transformer encoder.
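As a minimal sketch of hooking up the addition layers, the following builds a single encoder block and wires both residual connections. The sizes, the layer names, and the sequenceInputLayer standing in for the previous block's output are assumptions for illustration only.

```matlab
% Sketch: one encoder block with both residual connections wired up.
% All sizes below are illustrative assumptions, not recommended values.
numHeads = 2;
numKeyChannels = 16;          % must be divisible by numHeads
modelHiddenSize = 16;         % channel size entering and leaving the block
feedforwardHiddenSize = 64;

layers = [
    sequenceInputLayer(modelHiddenSize,Name="in")           % stands in for the previous block
    selfAttentionLayer(numHeads,numKeyChannels,Name="attention")
    additionLayer(2,Name="attention_add")                   % in1 is connected automatically
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize,Name="ff1")
    reluLayer(Name="relu")
    fullyConnectedLayer(modelHiddenSize,Name="ff2")
    additionLayer(2,Name="feedforward_add")                 % in1 is connected automatically
    layerNormalizationLayer(Name="block_out")];

net = dlnetwork(layers,Initialize=false);

% Residual connections: feed each sub-block's input to the second
% input of the addition layer that follows that sub-block.
net = connectLayers(net,"in","attention_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = initialize(net);

% Pass a random sequence (16 channels, 1 observation, 5 time steps) through.
X = dlarray(rand(modelHiddenSize,1,5,"single"),"CBT");
Y = predict(net,X);
size(Y)
```

To stack several blocks, repeat the block's layers with unique names and connect each block's output to the second input of the next block's first addition layer.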
You also typically need an embedding at the start of the model. For text data it's common to use wordEmbeddingLayer, whereas for image data you would use patchEmbeddingLayer.
Also the above encoder block makes no use of positional information, so if your training task requires positional information to be used, you would typically inject the position information via a positionEmbeddingLayer or sinusoidalPositionEncodingLayer.
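To give a feel for what a fixed sinusoidal encoding contains, here is a sketch that computes the standard sine/cosine position encodings by hand. The sizes are assumptions, and the built-in sinusoidalPositionEncodingLayer may use a different position offset or channel ordering, so treat this as an illustration of the idea rather than a reproduction of the layer.

```matlab
% Sketch of fixed sinusoidal position encodings (standard transformer formula).
numChannels = 8;    % assumed embedding size (even)
maxPosition = 10;   % assumed maximum sequence length

pos = (0:maxPosition-1)';                       % positions as a column, starting at 0
k = 0:(numChannels/2 - 1);                      % index of each sine/cosine pair
angles = pos ./ (10000 .^ (2*k/numChannels));   % maxPosition-by-numChannels/2

PE = zeros(maxPosition,numChannels);
PE(:,1:2:end) = sin(angles);                    % channels 1,3,5,... hold the sines
PE(:,2:2:end) = cos(angles);                    % channels 2,4,6,... hold the cosines
```

Each row of PE is the vector that gets added to the embedding at that position. Unlike positionEmbeddingLayer, these values are fixed rather than learned.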
Finally, the last encoder block will typically feed into a model "head" to map the encoder output back to the dimensions of the training targets. Often this is just one or more fullyConnectedLayer objects.
Note that for both image and sequence input data the output of the encoder is still an image or sequence, so for image classification and sequence-to-one tasks you need some way to map that sequence of encoder outputs to a fixed-size representation. For this you could use indexing1dLayer or pooling layers like globalMaxPooling1dLayer.
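For a sequence classification task, the head could look something like the sketch below. The number of classes is hypothetical, and the sequenceInputLayer only stands in for the output of the last encoder block so that the head can be run on its own.

```matlab
% Sketch: a sequence-to-one classification head.
numClasses = 3;        % hypothetical number of classes
modelHiddenSize = 20;  % channel size of the encoder output

head = [
    sequenceInputLayer(modelHiddenSize,Name="encoder_out")  % stands in for the encoder output
    globalMaxPooling1dLayer(Name="pool")                    % collapse the time dimension
    fullyConnectedLayer(numClasses,Name="fc")               % map to one score per class
    softmaxLayer(Name="softmax")];                          % convert scores to probabilities

net = dlnetwork(head);

% Four random "encoder outputs", each with 7 time steps.
X = dlarray(rand(modelHiddenSize,4,7,"single"),"CBT");
scores = predict(net,X);   % one column of class probabilities per observation
size(scores)
```

With a head like this you would typically train with trainnet using the "crossentropy" loss and categorical targets instead of "mse".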
Here's a demonstration of the general architecture for a toy task. Given a sequence $x = (x_1, \ldots, x_N)$ where each $x_i \in \{1, \ldots, N\}$, we can specify a task $y = x_{x_1} + x_{x_2}$. For example, $x = (3, 1, 4, 2)$ would have $x_1 = 3$ and $x_2 = 1$, then $y = x_3 + x_1 = 4 + 3 = 7$. This is a toy problem that requires positional information to solve and can be easily implemented in code. You can train a transformer encoder to predict y from x as follows:
% Create model
% We will use 2 encoder layers.
numHeads = 1;
numKeyChannels = 20;
feedforwardHiddenSize = 100;
modelHiddenSize = 20;
% Since the values in the sequence can be 1,2, ..., 10 the "vocabulary" size is 10.
vocabSize = 10;
inputSize = 1;
encoderLayers = [
sequenceInputLayer(1,Name="in") % input
wordEmbeddingLayer(modelHiddenSize,vocabSize,Name="embedding") % embedding
positionEmbeddingLayer(modelHiddenSize,vocabSize) % position embedding; 2nd argument is the maximum position (the sequence length, 10, which equals vocabSize here)
additionLayer(2,Name="embed_add") % add the data and position embeddings
selfAttentionLayer(numHeads,numKeyChannels) % encoder block 1
additionLayer(2,Name="attention_add") %
layerNormalizationLayer(Name="attention_norm") %
fullyConnectedLayer(feedforwardHiddenSize) %
reluLayer %
fullyConnectedLayer(modelHiddenSize) %
additionLayer(2,Name="feedforward_add") %
layerNormalizationLayer(Name="encoder1_out") %
selfAttentionLayer(numHeads,numKeyChannels) % encoder block 2
additionLayer(2,Name="attention2_add") %
layerNormalizationLayer(Name="attention2_norm") %
fullyConnectedLayer(feedforwardHiddenSize) %
reluLayer %
fullyConnectedLayer(modelHiddenSize) %
additionLayer(2,Name="feedforward2_add") %
layerNormalizationLayer() %
indexing1dLayer %
fullyConnectedLayer(inputSize)]; % output head
net = dlnetwork(encoderLayers,Initialize=false);
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"embedding","embed_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = connectLayers(net,"encoder1_out","attention2_add/in2");
net = connectLayers(net,"attention2_norm","feedforward2_add/in2");
net = initialize(net);
% analyze the network to see how data flows through it
analyzeNetwork(net)
% create toy training data
% We will generate 10,000 sequences of length 10
% with values that are random integers 1-10
numObs = 10000;
seqLen = 10;
x = randi([1,10],[seqLen,numObs]);
% Loop over observations to create y(i) = x(x(1,i),i) + x(x(2,i),i)
y = zeros(numObs,1);
for i = 1:numObs
idx = x(1:2,i);
y(i) = sum(x(idx,i));
end
x = num2cell(x,1);
% specify training options and train
opts = trainingOptions("adam", ...
MaxEpochs = 200, ...
MiniBatchSize = numObs/10, ...
Plots="training-progress", ...
Shuffle="every-epoch", ...
InitialLearnRate=1e-2, ...
LearnRateDropFactor=0.9, ...
LearnRateDropPeriod=10, ...
LearnRateSchedule="piecewise");
net = trainnet(x,y,net,"mse",opts);
% test the network on a new input
x = randi([1,10],[seqLen,1]);
ypred = predict(net,x)
yact = x(x(1)) + x(x(2))
Obviously this is a toy task, but I think it demonstrates the parts of the standard transformer architecture. Two additional things you would likely need to deal with in real tasks are:
- For sequence data, the observations often have different sequence lengths. For this you need to pad the data and pass padding masks to the selfAttentionLayer so that no attention is paid to padding elements.
- Often the encoder will be initially pre-trained on a self-supervised task, e.g. masked-language-modeling for natural language encoders.
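For the padding point, a rough sketch of the data side is shown below. The three sequences are made up, and the HasPaddingMaskInput option is taken from the selfAttentionLayer properties; check the documentation for your release before relying on it.

```matlab
% Sketch: pad variable-length sequences and build the matching padding mask.
% Three hypothetical sequences, each time-by-channel (t-by-1).
sequences = {[1;2;3]; [4;5]; [6;7;8;9]};

% Pad along the time dimension (dimension 1).
% mask is true where XPad holds real data and false where it holds padding.
[XPad,mask] = padsequences(sequences,1);
size(XPad)   % padded sequences are concatenated along a new last dimension

% Attention layer that expects the padding mask as an extra input.
numHeads = 1;
numKeyChannels = 20;
attn = selfAttentionLayer(numHeads,numKeyChannels, ...
    HasPaddingMaskInput=true,Name="attention");
```

With HasPaddingMaskInput set to true the layer expects a second input carrying the mask, so the network needs an additional input that supplies it. The exact input names and the required mask format are described in the selfAttentionLayer documentation.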
Hope that helps.
8 Comments
haohaoxuexi1
on 27 Jul 2024
@Ben Hi Ben, could you provide an example of applying a Transformer network to a classification task?
More Answers (1)
Mehernaz Savai
on 6 Dec 2024 at 19:50
In addition to Ben's suggestions, we have new articles that can be a good source for getting started with Transformers in MATLAB:
0 Comments