Determining Coefficient Weights in a Rational Quadratic GPR model
Hello there, I have a trained GPR model stored in the workspace as "GPRModel". My trained model has 63 input features, and I am trying to determine which features are weighted the highest and contribute the most to the model's predictions. My training data with the features is stored in a table named "AllFinalTrials". Please let me know if there is any way to determine the coefficients or weights of the features for this Rational Quadratic GPR model. Thank you so much!
Answers (1)
Drew
on 3 Apr 2024
If you want to know "which features ... contribute the most to the model's predictions", one way to do that is to use model interpretability techniques. Here are a couple of interpretability techniques that can produce an importance value for each feature of your Rational Quadratic GPR model:
- Permutation importance: https://www.mathworks.com/help/stats/permutationimportance.html
- Shapley values: https://www.mathworks.com/help/stats/shapley.html
For a general introduction to model interpretability, see https://www.mathworks.com/discovery/interpretability.html
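Permutation importance can also be computed by hand, which makes the idea concrete: shuffle one predictor at a time and measure how much the model's loss grows. Below is a minimal sketch that relies only on fitrgp and the RegressionGP loss method. It assumes a synthetic table Tbl with hypothetical predictors x1, x2, x3 and response y standing in for AllFinalTrials, and a rational quadratic model fitted on that table standing in for GPRModel. It is not the linked permutationImportance function.

```matlab
% Hand-rolled permutation importance on synthetic data.
% Tbl, x1..x3, y and Mdl are hypothetical stand-ins for AllFinalTrials/GPRModel.
rng(0)                                        % reproducible data and shuffles
n = 200;
X = randn(n, 3);
y = 3*X(:,1) + 0.5*X(:,2) + 0.1*randn(n, 1);  % x3 has no effect on y
Tbl = array2table([X y], 'VariableNames', {'x1','x2','x3','y'});

Mdl = fitrgp(Tbl, 'y', 'KernelFunction', 'rationalquadratic');

baseLoss = loss(Mdl, Tbl, 'y');               % MSE with predictors intact
names = Mdl.PredictorNames;
importance = zeros(numel(names), 1);
for j = 1:numel(names)
    shuffled = Tbl;
    % Break the link between predictor j and the response
    shuffled.(names{j}) = shuffled.(names{j})(randperm(n));
    importance(j) = loss(Mdl, shuffled, 'y') - baseLoss;  % increase in MSE
end

bar(categorical(names(:)), importance)
ylabel('Increase in MSE after shuffling')
```

Predictors whose shuffling barely changes the MSE (x3 here) contribute little. In practice you would shuffle held-out rows rather than training rows and average over several shuffles per predictor; a single shuffle on training data keeps the sketch short but is noisy.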
If this answer helps you, please remember to accept the answer.
5 Comments
Isabelle Museck
on 3 Apr 2024
(1) If your model is a native MATLAB RegressionGP model, the first argument to the shapley function should be the model itself (GPRModel), not the predict function.
(2) Generally, 100 to 1000 background samples are recommended. Create a 100-sample background set, bg100:
bg100 = datasample(AllFinalTrials, 100);
(3) With those changes, and switching to UseParallel=false for now (serial rather than parallel computation), your code for calculating and viewing local Shapley values for a single query point becomes:
bg100 = datasample(AllFinalTrials, 100);
explainer = shapley(GPRModel,bg100)
queryPoint = AllFinalTrials(1,:)
explainer = fit(explainer,queryPoint,'UseParallel',false);
explainer.ShapleyValues
plot(explainer)
For R2023b and earlier, you could calculate shapley values for all query points by looping over the query points, then calculate mean(abs(shapley_value)) for each predictor, and create a bar plot from that info. In R2024a, this is all done easily for you, as described in the next point.
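For R2023b and earlier, the loop described above might look like the sketch below. It uses a synthetic table Tbl and a rationalquadratic fitrgp model standing in for AllFinalTrials and GPRModel. It also assumes two things about the single-query-point regression output: the rows of explainer.ShapleyValues follow the order of Mdl.PredictorNames, and the table's second variable holds the numeric Shapley values. Indexing that variable by position rather than by name avoids depending on its exact name in a given release.

```matlab
% Hand-rolled mean(|Shapley|) importance for releases without multi-query
% shapley. Tbl and Mdl are hypothetical stand-ins for AllFinalTrials/GPRModel.
rng(0)
n = 200;
X = randn(n, 3);
y = 3*X(:,1) + 0.5*X(:,2) + 0.1*randn(n, 1);
Tbl = array2table([X y], 'VariableNames', {'x1','x2','x3','y'});
Mdl = fitrgp(Tbl, 'y', 'KernelFunction', 'rationalquadratic');

names = Mdl.PredictorNames;
bg = Tbl(randperm(n, 100), names);            % 100 background rows, predictors only
explainer = shapley(Mdl, bg);

numQuery = 20;                                % start small to gauge speed
absShap = zeros(numel(names), numQuery);
for k = 1:numQuery
    explainer = fit(explainer, Tbl(k, names));
    % Assumed: column 2 of ShapleyValues holds one value per predictor
    absShap(:, k) = abs(explainer.ShapleyValues{:, 2});
end
meanAbsShap = mean(absShap, 2);               % importance per predictor

bar(categorical(names(:)), meanAbsShap)
ylabel('mean(|Shapley value|)')
```

With your own data, draw the background and query rows from AllFinalTrials, and raise numQuery once you know how long each fit call takes.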
(4) What you really want is a Shapley importance plot, showing mean(abs(shap)) across a large set of query points. R2024a makes this easy by allowing multi-query-point calls to shapley and by adding new plot methods for multi-query-point shapley objects, so you can do the following. You could first test a small number of query points (10) to check the speed, and then use that timing to project how long all the points will take. If your parallel pool consists of just the threads or processes on a single machine, I recommend setting the Parallel Environment "Default Profile" to "Threads" (in Preferences => Parallel Computing Toolbox) and then setting UseParallel=true.
% This multi-query-point call to shapley works in R2024a
bg100 = datasample(AllFinalTrials, 100);
explainer = shapley(GPRModel, bg100, QueryPoints=AllFinalTrials, UseParallel=true);
% mean(abs(shap)) shapley importance plot
plot(explainer);
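The 10-query-point timing check suggested in point (4) can be sketched as follows, again with a synthetic table Tbl and model Mdl as hypothetical stand-ins for AllFinalTrials and GPRModel. It uses the R2024a QueryPoints syntax from the code above and runs serially, so it does not need Parallel Computing Toolbox. If you plan to use UseParallel=true for the full run, time the 10-point call with that same setting.

```matlab
% Time 10 query points, then extrapolate to the whole table (R2024a syntax).
% Tbl and Mdl are hypothetical stand-ins for AllFinalTrials/GPRModel.
rng(0)
n = 200;
X = randn(n, 3);
y = 3*X(:,1) + 0.5*X(:,2) + 0.1*randn(n, 1);
Tbl = array2table([X y], 'VariableNames', {'x1','x2','x3','y'});
Mdl = fitrgp(Tbl, 'y', 'KernelFunction', 'rationalquadratic');

names = Mdl.PredictorNames;
bg = Tbl(randperm(n, 100), names);            % 100 background rows

tic
explainer10 = shapley(Mdl, bg, QueryPoints=Tbl(1:10, names));
secondsPerPoint = toc / 10;

projectedSeconds = secondsPerPoint * height(Tbl);
fprintf('Projected time for %d query points: %.1f s\n', height(Tbl), projectedSeconds)
```

The timed call includes some one-time overhead, so treat the projection as a rough, probably somewhat pessimistic, estimate.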
Here is an example of this type of plot, showing the importance of each predictor, given by mean(abs(shapley_value)) over many query points, using the common abalone data (see the example at https://www.mathworks.com/help/stats/fitrgp.html) and a RegressionGP model:

Another note: permutation importance is computationally inexpensive, so I recommend looking at permutation importance in addition to Shapley importance.
Isabelle Museck
on 3 Apr 2024
Drew
on 3 Apr 2024
If you have a MATLAB license, you can get access to R2024a in two ways:
(1) Access R2024a online at matlab.mathworks.com
(2) R2024a is also available for download from matlab.mathworks.com
Drew
on 12 Apr 2024
Since you wrote "Thank you", I recommend accepting the answer. I see you have asked 9 questions but have never accepted an answer. If you are wondering how to accept an answer, there should be an "Accept this answer" button next to the answerer's name, near the top of the answer.