Dynamic sequence length for transformer-based model - error when exporting from Python to MATLAB
We developed a simple transformer architecture in Python (see the code below). The model can handle sequences of different lengths, and I want to use it in MATLAB. I tried exporting it to ONNX and to the .pt format, but in both cases I had to fix the input shape to export the model. I used torch.jit.script() in Python to script and export the model in the .pt format; however, I believe the importer from the Deep Learning Toolbox Converter for PyTorch Models (importNetworkFromPyTorch) only works with models produced by torch.jit.trace.
I want to find a way to use the model in MATLAB so that it accepts inputs of any sequence length.
Any help would be much appreciated.
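For reference, this is roughly the kind of ONNX export I am trying to get working with a variable sequence length, declaring the batch and sequence axes as dynamic. The hyperparameter values, tensor shapes, and file name are placeholders, not my actual setup; TransformerModel is the class defined in the Python code below.
import torch
model = TransformerModel(input_dim=16, model_dim=64, n_classes=3,
                         num_heads=4, num_layers=2)
model.eval()
# Dummy inputs with an arbitrary example length; dynamic_axes below is meant
# to mark the batch and sequence dimensions as variable in the exported graph
dummy_x = torch.randn(1, 50, 16)
dummy_mask = torch.ones(1, 50, dtype=torch.bool)
torch.onnx.export(
    model,
    (dummy_x, dummy_mask),
    "transformer.onnx",
    input_names=["x", "padding_mask"],
    output_names=["output"],
    dynamic_axes={
        "x": {0: "batch", 1: "seq_len"},
        "padding_mask": {0: "batch", 1: "seq_len"},
        "output": {0: "batch", 1: "seq_len"},
    },
)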
# Python Code
import torch
import torch.nn as nn

# Model class to export
class TransformerModel(nn.Module):
    def __init__(
        self,
        input_dim,
        model_dim,
        n_classes,
        num_heads,
        num_layers,
    ):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        # Embedding layer: project input features to the model dimension
        self.embedding = nn.Linear(input_dim, model_dim)
        # Transformer encoder (batch_first=True: inputs are (batch, seq, feature))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        # Output layer: per-position class scores
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # Invert the mask: src_key_padding_mask expects True at PADDED positions,
        # while the caller passes True at valid positions
        padding_mask = ~padding_mask
        x = self.embedding(x)
        # Transformer encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        # Model prediction
        output = self.output_layer(x)
        return output
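For completeness, here is a sketch of the two .pt export paths I compared; again, the hyperparameters and file names are placeholder values.
model = TransformerModel(input_dim=16, model_dim=64, n_classes=3,
                         num_heads=4, num_layers=2)
model.eval()
# torch.jit.script compiles the forward code itself, so the saved module
# keeps its ability to handle variable-length inputs
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")
# torch.jit.trace records the operations for one concrete pair of example
# inputs, so variable lengths are not guaranteed in the traced module
dummy_x = torch.randn(1, 50, 16)
dummy_mask = torch.ones(1, 50, dtype=torch.bool)
traced = torch.jit.trace(model, (dummy_x, dummy_mask))
traced.save("model_traced.pt")
The scripted file preserves the dynamic behavior, but as noted above I believe the MATLAB importer only accepts traced models.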