How to obtain Shapley Values with a pattern classification neural network?
Margarita Cabrera
on 23 Jun 2021
Answered: Amanjit Dulai
on 18 Jul 2022
We want to obtain the Shapley values for a feature vector (query point) from a neural network that was created with patternnet and trained on our dataset.
Because our network is a “pattern recognition neural network”, we get this error message:
Error using shapley (line XX)
Blackbox model must be a classification model, regression model, or function handle
So, my question is: is there any way to compute the Shapley values for a “pattern recognition neural network” model?
Or, instead, can we convert a “pattern recognition neural network” into a “classification neural network” in order to compute its Shapley values?
Thanks in any case.
Accepted Answer
Amanjit Dulai
on 18 Jul 2022
It is possible to do this by passing a function handle to shapley. The function handle must return the score for the class of interest. Also, shapley passes the predictor data with one observation per row and expects the function to return one score per row (a column vector), whereas a network created with patternnet takes and returns observations as columns, so some transposes are needed. Below is an example using the Fisher Iris data:
% Train a neural network on the iris data
[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
% Choose an observation to explain. We need its class as an index.
x1 = x(:,1);
t1 = find(t(:,1));
% Plot shapley values. For Setosa (the first class) the petal length (x3)
% is usually the most informative feature.
explainer = shapley( ...
    @(X)predictScoreForSpecifiedClass(net,X,t1), ...
    x', "QueryPoint", x1' );
plot(explainer)
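% Helper: shapley passes observations as rows (N-by-4), but patternnet
% expects them as columns, so transpose on the way in and on the way out.
function score = predictScoreForSpecifiedClass(net, x, classIndex)
    Y = net(x');                 % 3-by-N matrix of class scores
    score = Y(classIndex,:)';    % N-by-1 column: one score per observation
end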
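The example above explains a single class. A minimal sketch, assuming the same iris setup and a release with shapley (R2021a or later), that repeats the explanation for every class and first checks that the wrapped function returns one score per row. The helper name classScore is hypothetical and simply mirrors predictScoreForSpecifiedClass; the loop and the size check are illustrative additions, not part of the accepted answer.

```matlab
% Sketch: explain one query point once per class of a patternnet model.
% Assumes Deep Learning Toolbox (patternnet, train) and Statistics and
% Machine Learning Toolbox R2021a or later (shapley).
[x,t] = iris_dataset;          % x: 4-by-150 features, t: 3-by-150 one-hot targets
net = patternnet(10);
net = train(net,x,t);

X = x';                        % shapley wants one observation per row
queryPoint = X(1,:);
numClasses = size(t,1);

% Sanity check: the wrapped function must return an N-by-1 column of scores.
assert(isequal(size(classScore(net,X,1)), [size(X,1) 1]))

explainers = cell(numClasses,1);
for k = 1:numClasses
    % Each explainer covers the score of class k only.
    explainers{k} = shapley(@(Z)classScore(net,Z,k), X, "QueryPoint", queryPoint);
    figure
    plot(explainers{k})
    title("Class " + k)
end

% Hypothetical helper (same idea as predictScoreForSpecifiedClass):
% converts rows-as-observations to the columns-as-observations layout that
% patternnet expects, then returns the chosen class score as a column.
% In a script, local functions like this must come at the end of the file.
function score = classScore(net, X, classIndex)
    Y = net(X');               % numClasses-by-numObservations
    score = Y(classIndex,:)';  % numObservations-by-1
end
```

A separate explainer per class is needed because the function handle has to return a single score per observation.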