Dynamic sequence length for transformer-based model - error when exporting from Python to MATLAB

We developed a simple transformer architecture in Python (see the code below). The model can handle sequences of different lengths, and I want to use it in MATLAB. I tried exporting the model both to ONNX and to the .pt format, but in both cases I had to fix the input shape for the export to succeed. I used the torch.jit.script() function in Python to compile and save the model in the .pt format; however, I believe importNetworkFromPyTorch from the Deep Learning Toolbox Converter for PyTorch Models only works with models traced by torch.jit.trace. (A sketch of both export paths follows the model code below.)
I want to find a way to use the model in MATLAB while keeping support for inputs of any sequence length.
Any help would be much appreciated.
# Python Code
import torch
import torch.nn as nn

# Model class to export
class TransformerModel(nn.Module):
    def __init__(
        self,
        input_dim,
        model_dim,
        n_classes,
        num_heads,
        num_layers,
    ):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        # Embedding layer: project input features to the model dimension
        self.embedding = nn.Linear(input_dim, model_dim)
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        # Output layer
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # Invert the mask: PyTorch expects True at positions to be ignored
        padding_mask = ~padding_mask
        x = self.embedding(x)
        # Transformer encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        # Model prediction
        output = self.output_layer(x)
        return output
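
For reference, here is a minimal sketch of the two export paths described above. The hyperparameters, input shapes, and file names are illustrative rather than taken from the original post; the dynamic_axes argument of torch.onnx.export is the usual way to mark the batch and sequence dimensions as symbolic, though depending on the PyTorch version the encoder's fast path may still complicate ONNX export with a boolean mask.

# Export sketch (illustrative hyperparameters, shapes, and file names)
import torch

model = TransformerModel(
    input_dim=16, model_dim=64, n_classes=4, num_heads=4, num_layers=2
)
model.eval()

# Example inputs: batch of 2 sequences of length 50; the mask is True at
# valid positions (the model inverts it internally).
x = torch.randn(2, 50, 16)
padding_mask = torch.ones(2, 50, dtype=torch.bool)

# ONNX export: dynamic_axes marks the batch and sequence dimensions as
# symbolic so the exported graph is not pinned to the tracing shapes.
torch.onnx.export(
    model,
    (x, padding_mask),
    "transformer.onnx",
    input_names=["x", "padding_mask"],
    output_names=["output"],
    dynamic_axes={
        "x": {0: "batch", 1: "seq"},
        "padding_mask": {0: "batch", 1: "seq"},
        "output": {0: "batch", 1: "seq"},
    },
)

# TorchScript export: torch.jit.script compiles the model without baking in
# input shapes, whereas torch.jit.trace records one execution with fixed
# example shapes.
scripted = torch.jit.script(model)
scripted.save("transformer_scripted.pt")

The script/trace distinction is the crux here: scripting preserves shape-generic control flow, while tracing fixes the example shapes, which matches the behavior described in the question.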

Answers (0)

Release: R2024b
