Retrieving Layer Activations from bertDocumentClassifier (Text Analytics Toolbox)

Hi,
I started using the Text Analytics Toolbox and successfully trained a bertDocumentClassifier network on my dataset.
In the past I've used the 'activations' function successfully to extract layer activations from dlNetworks.
However, for a bertDocumentClassifier, I cannot get the activations function to work, because it is not like, e.g., image DL network objects: it has a tokenizer first.
So, for example, out=activations(bertTrained,textstring,layername) does not work.
I tried to apply the tokenizer first, as in e.g.:
[a,b]=encode(mdl.Tokenizer,textDataTrain(1,:))
and that gives the token codes and segments fine in a,b.
But how do I "feed" those to the dlnetwork itself from the bertDocumentClassifier object?
This for example does NOT work:
net=mdl.Network;
activations(net,a,b,'out_fc2')
and variations fail as well.
So to sum up - I have a trained BERT classifier object, I can use it to classify just fine, but I can't get the network's layer activations.
Thanks!

Accepted Answer

Malay Agarwal on 28 Mar 2024 at 12:37
Edited: Malay Agarwal on 29 Mar 2024 at 20:51
Hi Tsvi,
I understand that you want to retrieve layer activations from a trained “bertDocumentClassifier” model.
There are two reasons why the “activations” function is not working as expected.
First, the “InputNames” property of the underlying network for the model shows that the model accepts three inputs instead of two. Namely, it expects the input IDs, an attention mask, and the segment IDs.
In your code, you are calling the function with only two inputs, the input IDs and the segment IDs.
Second, the “activations” function only works with networks represented as “DAGNetwork” objects or “SeriesNetwork” objects, as specified in the documentation: https://www.mathworks.com/help/releases/R2023b/deeplearning/ref/seriesnetwork.activations.html#d126e5157.
The underlying network for “bertDocumentClassifier” is a “dlnetwork” object: https://www.mathworks.com/help/releases/R2023b/textanalytics/ref/bertdocumentclassifier.html#mw_2480ef12-2a75-480d-aec3-eefb236d8afe. For such objects, you need to use the “predict” or the “forward” function, based on whether you want the model to output for inference or for training.
Please try the following code. I am assuming you want the outputs the model produces in inference mode, so the code uses the “predict” function. If you want the outputs the model produces in training mode, change the “predict” call to a “forward” call. No other changes are required:
% Extract the network
net = mdl.Network;
% Extract an example and encode it
example = textDataTrain(1, :);
[tokens, segments] = encode(mdl.Tokenizer, example);
% encode returns tokens and segments as cell arrays, each holding one vector,
% so extract the vectors
tokens = tokens{1};
segments = segments{1};
% Get the number of tokens
% (named numTokens to avoid shadowing the dims function for dlarray objects)
numTokens = size(tokens, 2);
% Convert the tokens and segments to dlarray
% BERT expects input in CTB format
tokens = dlarray(tokens, "CTB");
segments = dlarray(segments, "CTB");
% Create an attention mask of all zeros in CTB format
attentionMask = dlarray(zeros(1, numTokens), "CTB");
% Use predict function to get the output of layer 'out_fc2'
output = predict(net, tokens, attentionMask, segments, 'Outputs', 'out_fc2');
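If it helps, here is a small follow-up sketch for checking the input order and for finding other layer names to pass to 'Outputs'. It assumes mdl is your trained bertDocumentClassifier and that output is the dlarray returned by the predict call above; the variable name activationsNumeric is just a placeholder I made up.

```matlab
% Sketch only: assumes mdl is the trained bertDocumentClassifier and that
% output is the dlarray returned by the predict call above.
net = mdl.Network;

% Order in which predict expects its inputs
disp(net.InputNames)

% Layer names that can be passed to the 'Outputs' option
disp({net.Layers.Name}')

% Convert the dlarray activations to an ordinary numeric array
activationsNumeric = extractdata(output);
size(activationsNumeric)
```

The layer list is the quickest way to confirm that a name such as 'out_fc2' exists in your particular network before asking predict for it.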
Hope this helps!

More Answers (1)

tsvi lev on 30 Mar 2024 at 20:37
Excellent answer - logical, informed and with working code.
Thank you!
