Determining Coefficient Weights in a Rational Quadratic GPR model

Hello there, I have a trained GPR model stored in the workspace as "GPRModel". The model has 63 input features, and I am trying to determine which features are weighted the highest and contribute the most to the model's predictions. My training data with the features is stored in a table named "AllFinalTrials". Is there any way to determine the coefficients or weights of the features for this Rational Quadratic GPR model? Thank you so much!

Answers (1)

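If you want to know "which features ... contribute the most to the model's predictions", one way to do that is to use model interpretability techniques. Here are a couple of interpretability techniques that can produce an importance value for each feature of your Rational Quadratic GPR model:
(1) Shapley values (the shapley function), averaged over many query points
(2) Permutation predictor importance (the permutationImportance function)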
For a general introduction to model interpretability, see https://www.mathworks.com/discovery/interpretability.html
If this answer helps you, please remember to accept the answer.

5 Comments

Hi there, I am trying to use Shapley values to produce an importance value for each feature of my Rational Quadratic GPR model. My training data table ("AllFinalTrials") is 5411x83, and 63 of its columns are the features included in the model. Because the table has over 5000 observations, I think the function is having trouble running: it warns me that the computation can be slow because the predictor data has over 1000 observations. How do I randomly choose a smaller set of the training data, or otherwise fix this issue? Here is my code:
explainer = shapley(GPRModel.predictFcn,AllFinalTrials)
queryPoint = AllFinalTrials(1,:)
explainer = fit(explainer,queryPoint,'UseParallel',true);
explainer.ShapleyValues
plot(explainer)
(1) Assuming that your model is a native MATLAB RegressionGP model, the first argument to the shapley function should be the model itself (GPRModel), not its predict function.
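If GPRModel was exported from the Regression Learner app (the GPRModel.predictFcn in your code suggests it was), it is a struct rather than a RegressionGP object, and, as far as I know, the model object is stored in its RegressionGP field. Below is a minimal sketch that picks out the model object in either case. The small table tbl and the struct built from it are hypothetical stand-ins for your data and your exported model.

```matlab
% Hypothetical stand-in data and a struct shaped like a Regression Learner export
rng(0)
tbl = array2table(randn(200, 3), 'VariableNames', {'a', 'b', 'c'});
tbl.y = tbl.a - 2*tbl.b + 0.1*randn(200, 1);
gp = fitrgp(tbl, 'y', 'KernelFunction', 'rationalquadratic');
GPRModel = struct('predictFcn', @(t) predict(gp, t), 'RegressionGP', gp);

% Pass the model object to shapley, not the predictFcn function handle
if isstruct(GPRModel) && isfield(GPRModel, 'RegressionGP')
    mdl = GPRModel.RegressionGP;   % exported struct: take the RegressionGP field
else
    mdl = GPRModel;                % already a RegressionGP object
end

% Background data restricted to the predictors the model was trained on
explainer = shapley(mdl, tbl(1:50, mdl.PredictorNames));
```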
(2) The recommended number of background samples is generally 100 to 1000. Create a background set, bg100, of 100 randomly sampled rows:
bg100 = datasample(AllFinalTrials, 100);
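Note that datasample samples with replacement by default, so bg100 can contain repeated rows. If you prefer 100 distinct rows, here is a short sketch using randperm, or equivalently datasample with 'Replace',false. The table tbl is a hypothetical stand-in with the same number of rows as AllFinalTrials.

```matlab
rng(1)                                      % fix the seed so the subset is reproducible
tbl = array2table(rand(5411, 3));           % hypothetical stand-in for AllFinalTrials
idx = randperm(height(tbl), 100);           % 100 distinct row indices
bg100 = tbl(idx, :);                        % background set without duplicate rows

% Same idea in one call: sample 100 rows without replacement
bg100b = datasample(tbl, 100, 'Replace', false);
```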
(3) With those changes, and switching to UseParallel=false for now (serial instead of parallel computation), your code for calculating and viewing local Shapley values for a single query point becomes:
bg100 = datasample(AllFinalTrials, 100);
explainer = shapley(GPRModel,bg100)
queryPoint = AllFinalTrials(1,:)
explainer = fit(explainer,queryPoint,'UseParallel',false);
explainer.ShapleyValues
plot(explainer)
For R2023b and earlier, you can calculate Shapley values for all query points by looping over them, then compute mean(abs(shapley_value)) for each predictor and create a bar plot from the result. In R2024a, all of this is done for you, as described in the next point.
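For R2023b, that loop could look like the sketch below. It assumes mdl is a RegressionGP model trained on a table (use GPRModel, or GPRModel.RegressionGP if the model came from Regression Learner) and uses hypothetical synthetic data in which only x1 and x2 affect the response. nBackground and nQuery are illustrative settings, not recommendations. The Shapley values are read from the second variable of the ShapleyValues table, which holds them for a regression model.

```matlab
% Hypothetical stand-in data: replace tbl with AllFinalTrials and mdl with your model
rng(0)
tbl = array2table(randn(300, 4), 'VariableNames', {'x1', 'x2', 'x3', 'x4'});
tbl.y = 3*tbl.x1 + 0.5*tbl.x2 + 0.1*randn(300, 1);   % x3 and x4 are irrelevant
mdl = fitrgp(tbl, 'y', 'KernelFunction', 'rationalquadratic');

nBackground = 50;                           % background rows
nQuery = 20;                                % query points to average over

X = tbl(:, mdl.PredictorNames);             % drop the response and unused columns
bg = X(randperm(height(X), nBackground), :);
explainer = shapley(mdl, bg);

queryIdx = randperm(height(X), nQuery);
absShap = zeros(numel(mdl.PredictorNames), nQuery);
for k = 1:nQuery
    explainer = fit(explainer, X(queryIdx(k), :));       % local Shapley values
    absShap(:, k) = abs(explainer.ShapleyValues{:, 2});  % 2nd variable holds the values
end

% Global importance: mean(|Shapley value|) per predictor, largest first
imp = table(string(explainer.ShapleyValues.Predictor), mean(absShap, 2), ...
    'VariableNames', {'Predictor', 'MeanAbsShapley'});
imp = sortrows(imp, 'MeanAbsShapley', 'descend');
disp(imp)

% Horizontal bar plot with the most important predictor at the top
barh(flipud(imp.MeanAbsShapley))
yticks(1:height(imp))
yticklabels(flipud(imp.Predictor))
xlabel('mean(|Shapley value|)')
```

With your data, restricting the table to mdl.PredictorNames keeps only the 63 model features, and nQuery can be increased once you know how long each fit call takes.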
(4) What you really want is a Shapley importance plot, showing mean(abs(shap)) across a large set of query points. R2024a makes this easy by allowing multi-query-point calls to shapley and adding new plot methods for multi-query-point shapley objects, so you can do the following. You could first test a small number of query points (for example, 10) to check the speed, and use that timing to project how long all the points will take. If your parallel pool consists only of the threads or processes on a single machine, I recommend setting the Parallel Environment "Default Profile" to "Threads" (in Preferences => Parallel Computing Toolbox), and then setting UseParallel=true.
% This multi-query-point call to shapley works in R2024a
bg100 = datasample(AllFinalTrials, 100);
explainer = shapley(GPRModel, bg100, QueryPoints=AllFinalTrials, UseParallel=true);
% mean(abs(shap)) shapley importance plot
plot(explainer);
Here is an example of this type of plot, showing the importance of each predictor, given by mean(abs(shapley_value)) over many query points, using the common abalone data (see the example at https://www.mathworks.com/help/stats/fitrgp.html) and a RegressionGP model:
[Figure: bar plot of mean(abs(Shapley value)) for each predictor of a RegressionGP model trained on the abalone data]
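To time a few query points before committing to the whole table, a sketch like the following (R2024a or later, because it uses the QueryPoints argument) measures 10 pilot points and extrapolates linearly. The data are hypothetical, and the estimate assumes the time per query point stays roughly constant, so treat it as a rough guide only.

```matlab
% Requires R2024a or later (QueryPoints name-value argument)
rng(0)
tbl = array2table(randn(500, 5));                    % hypothetical stand-in data
tbl.y = sum(tbl{:, 1:5}, 2) + 0.1*randn(500, 1);
mdl = fitrgp(tbl, 'y', 'KernelFunction', 'rationalquadratic');

X = tbl(:, mdl.PredictorNames);
bg = X(randperm(height(X), 100), :);                 % 100 background rows
pilot = X(randperm(height(X), 10), :);               % 10 pilot query points

tic
explainer = shapley(mdl, bg, QueryPoints=pilot);
secondsPerPoint = toc / height(pilot);

% Linear projection to all rows of the table
estMinutes = secondsPerPoint * height(X) / 60;
fprintf('About %.3f s per query point; all %d points would take roughly %.1f min.\n', ...
    secondsPerPoint, height(X), estMinutes);
```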
Another note: permutation importance is computationally inexpensive, so I recommend looking at permutation importance in addition to Shapley importance.
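In releases without permutationImportance (it was introduced in R2024a), a hand-rolled version of the same idea is easy to write: shuffle one predictor at a time and measure how much the mean squared error increases. The sketch below is not the permutationImportance implementation and its numbers will not match it exactly. It uses hypothetical synthetic data, and ideally you would evaluate it on held-out data rather than on the training set.

```matlab
% Hypothetical stand-in data: replace tbl/mdl with AllFinalTrials and your model
rng(0)
tbl = array2table(randn(300, 4), 'VariableNames', {'x1', 'x2', 'x3', 'x4'});
tbl.y = 3*tbl.x1 + 0.5*tbl.x2 + 0.1*randn(300, 1);
mdl = fitrgp(tbl, 'y', 'KernelFunction', 'rationalquadratic');

predNames = mdl.PredictorNames;
X = tbl(:, predNames);
yTrue = tbl.y;
baseMSE = mean((yTrue - predict(mdl, X)).^2);        % error with intact predictors

nRepeats = 5;                                        % shuffles per predictor
importance = zeros(numel(predNames), 1);
for j = 1:numel(predNames)
    increase = zeros(nRepeats, 1);
    for r = 1:nRepeats
        Xperm = X;
        % Shuffle predictor j to break its link with the response
        Xperm.(predNames{j}) = X.(predNames{j})(randperm(height(X)));
        increase(r) = mean((yTrue - predict(mdl, Xperm)).^2) - baseMSE;
    end
    importance(j) = mean(increase);                  % average MSE increase
end

imp = table(string(predNames(:)), importance, ...
    'VariableNames', {'Predictor', 'MSEIncrease'});
imp = sortrows(imp, 'MSEIncrease', 'descend');
disp(imp)
```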
Thank you so much for the help! Unfortunately, my MATLAB version is R2023b, so the permutationImportance function is not available to me. I will try the first method and let you know how it performs.
If you have a MATLAB license, you can get access to R2024a in two ways:
(1) Access R2024a online at matlab.mathworks.com
(2) R2024a is also available for download from matlab.mathworks.com
Since you wrote "Thank you", I recommend accepting the answer. I see you have asked 9 questions but never accepted an answer. If you are wondering how to accept an answer, there should be an "Accept this answer" button next to the answerer's name, near the top of the answer.


Asked: 2 Apr 2024

Commented: 12 Apr 2024
