How can I obtain Shapley values from a convolutional neural network?
clear all;
close all;
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
[XTest,YTest,anglesTest] = digitTest4DArrayData;
% Define the layers of the CNN
layers = [
imageInputLayer([28 28 1]) % 28x28 grayscale digit images
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10) % Assuming 10 classes
softmaxLayer
classificationLayer];
% Define the training options
options = trainingOptions('adam', ...
'MaxEpochs',10, ...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% Train the CNN on the training data
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% Predict on the test data
YPred = classify(net, XTest);
accuracy = mean(YPred == categorical(YTest));
% Display the confusion matrix
confusionchart(categorical(YTest), YPred);
explainer = shapley( ...
@(XTest)PredictCNN(net,XTest,YTest(1)), ...
reshape(XTest,[5000,28*28]), "QueryPoint", reshape(XTest(:,:,1,1),[1,28*28]) );
function score = PredictCNN(net,XTest,YTest)
YPred = predict(net,XTest);
score = YPred(:,double(YTest));
end
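You can check the reshape in the shapley call above on small random data. The sketch below uses synthetic 28-by-28-by-1-by-5 data in place of XTest (N, X, rowsWrong and rowsRight are illustrative names). It compares reshaping straight to N-by-784 with reshaping to 784-by-N and then transposing. It also confirms that the second layout converts back to image form exactly:

```matlab
% Hypothetical check: two ways of flattening a 28x28x1xN image stack into rows.
rng(0);                                  % reproducible random data
N = 5;                                   % small stand-in for the number of test images
X = rand(28, 28, 1, N);                  % synthetic images in place of XTest

% Reshaping straight to N-by-784: MATLAB fills column-major, so rows mix images
rowsWrong = reshape(X, [N, 28*28]);

% One image per row: make each image a column of length 784, then transpose
rowsRight = reshape(X, 28*28, N)';

% Row k should equal image k flattened in column-major order
k = 3;
target = reshape(X(:,:,1,k), 1, 28*28);
isRowK = isequal(rowsRight(k,:), target);
isRowKWrong = isequal(rowsWrong(k,:), target);

% Round trip back to 28x28x1xN, as a prediction function would need
Xback = reshape(rowsRight', 28, 28, 1, []);
roundTripOK = isequal(Xback, X);

fprintf('row %d correct: %d, naive row correct: %d, round trip OK: %d\n', ...
    k, isRowK, isRowKWrong, roundTripOK);
```

The same ordering rule applies to the query point. reshape(XTest(:,:,1,1),[1,28*28]) matches row 1 of the transposed layout, but it does not match row 1 of the naive layout.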
Answer by Taylor on 24 Oct 2024
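The shapley function expects the predictor data as a matrix with one observation per row and one predictor (here, one pixel) per column. XTest is a 28-by-28-by-1-by-N array, so reshape(XTest,[5000,28*28]) has two problems. First, it hard-codes the number of samples. Second, MATLAB stores arrays in column-major order, so this call spreads the pixels of each image down a column instead of along a row, and each row ends up mixing pixels from different images. reshape(XTest,28*28,[])' puts image k in row k, in the same pixel order as reshape(XTest(:,:,1,k),[1,28*28]). That keeps the query point consistent with the data.
shapley calls the function handle with a single argument: a numeric matrix whose rows are observations. The anonymous function in the question already captures net and the class to explain, so the extra input of PredictCNN is not the problem. The problem is that PredictCNN passes the N-by-784 matrix straight to predict, and this network expects 28-by-28-by-1-by-N images. PredictCNN must therefore reshape its input back to image form. It must also return one column of scores (the class being explained) rather than the full N-by-10 score matrix, because shapley expects one prediction per observation. With 784 predictors, shapley can take a long time when all test images are used as background data. Passing a random subset of rows as the second argument reduces the cost.
Build the predictor matrix and the explainer:
numSamples = size(XTest, 4); % number of test images
XRows = reshape(XTest, 28*28, numSamples)'; % one image per row
classIdx = double(YTest(1)); % explain the score of the first test image's true class
explainer = shapley( ...
@(X)PredictCNN(net, X, classIdx), ...
XRows, ...
"QueryPoint", XRows(1,:) );
Define PredictCNN so that it converts the rows back to images before calling predict. As a local function, it must go at the end of the script:
function score = PredictCNN(net, XRows, classIdx)
% shapley passes an N-by-784 matrix; restore the 28-by-28-by-1-by-N layout
X4 = reshape(XRows', 28, 28, 1, []);
scores = predict(net, X4); % N-by-10 matrix of class scores
score = scores(:, classIdx); % one column: the score of the explained class
end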
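To see the function-handle contract in isolation, here is a minimal sketch with a toy linear function instead of the CNN. It assumes Statistics and Machine Learning Toolbox with shapley (R2021a or later). It also assumes that shapley treats a function handle returning a column vector as a regression-type model, so that ShapleyValues has a Predictor variable followed by one column of values. Xbg, f, q and phi are hypothetical names. The toy function is linear and has only three predictors, so the values should sum to f(q) minus the mean of f over the background data, and the unused third predictor should get a value of zero:

```matlab
% Toy check of shapley with a function handle (not the CNN).
% Assumes Statistics and Machine Learning Toolbox, R2021a or later.
rng(0);                                  % reproducible background data
Xbg = rand(200, 3);                      % background data: 200 rows, 3 predictors
f = @(X) 2*X(:,1) + X(:,2);              % takes N-by-3, returns N-by-1; ignores column 3
q = [1 1 1];                             % query point: one row, same columns as Xbg

explainer = shapley(f, Xbg, "QueryPoint", q);

% Assumption: variable 1 is Predictor, variable 2 holds the Shapley values
phi = explainer.ShapleyValues{:, 2};

% Efficiency property: the values add up to f(q) - mean(f(Xbg))
gap = sum(phi) - (f(q) - mean(f(Xbg)));
disp(explainer.ShapleyValues)
fprintf('sum(phi) - (f(q) - mean(f(Xbg))) = %.2e\n', gap);
```

Under the same column assumption, the CNN explainer works the same way with 784 predictors. Because XRows flattens each image in column-major order, reshape(explainer.ShapleyValues{:,2}, 28, 28) puts the values back at their pixel positions, and imagesc can display them as a map.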