Train Generalized Additive Model for Binary Classification

This example shows how to train a generalized additive model (GAM) with optimal parameters and how to assess the predictive performance of the trained model. The example first finds the optimal parameter values for a univariate GAM (parameters for linear terms) and then finds the values for a bivariate GAM (parameters for interaction terms). Also, the example explains how to interpret the trained model by examining local effects of terms on a specific prediction and by computing the partial dependence of the predictions on predictors.

Load Sample Data

Load the 1994 census data stored in census1994.mat. The data set consists of demographic data from the US Census Bureau that you can use to predict whether an individual makes over $50,000 per year. The classification task is to fit a model that predicts the salary category of people given their age, working class, education level, marital status, race, and so on.

load census1994

census1994 contains the training data set adultdata and the test data set adulttest. To reduce the running time for this example, subsample 500 training observations and 500 test observations by using the datasample function.

rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);
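With only 500 observations in each set, it can be worth checking how the two salary classes are represented after subsampling. The following is a minimal sketch, assuming that salary is a categorical variable with the same categories in both tables; the names classNames, trainCounts, and testCounts are illustrative.

```matlab
% Count the observations in each salary category after subsampling.
classNames  = categories(adultdata.salary);   % category names (cell array)
trainCounts = countcats(adultdata.salary);    % counts in the training sample
testCounts  = countcats(adulttest.salary);    % counts in the test sample

% Show the counts side by side, one row per class.
table(trainCounts,testCounts,'RowNames',classNames)
```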

Find Optimal Parameters for Univariate GAM

Optimize the parameters for a univariate GAM by minimizing the cross-validation loss with the bayesopt function.

Prepare optimizableVariable objects for the name-value arguments of a univariate GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, and InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Create an objective function that takes an input z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] and returns the cross-validated loss value at the parameters in z.

minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

If you specify the cross-validation option 'CrossVal','on', then the fitcgam function returns a cross-validated model object ClassificationPartitionedGAM. The kfoldLoss function returns the classification loss obtained by the cross-validated model. Therefore, the function handle minfun1 computes the cross-validation loss at the parameters in z.
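Before starting the optimization, you can evaluate the objective function once at a hand-picked point to catch naming mistakes early. This sketch assumes that bayesopt passes each candidate point to the objective as a one-row table whose variable names match the optimizableVariable names. The trial values and the names z0, cvloss0, and s are illustrative only. The code saves and restores the random number generator state so that the check does not change the results of the optimization that follows.

```matlab
s = rng;   % save the random number generator state

% Hypothetical trial point (values chosen only for illustration).
% Variable names must match the names of the optimizableVariable objects.
z0 = table(0.1,3,20,'VariableNames', ...
    {'initialLearnRateForPredictors','maxNumSplitsPerPredictor','numTreesPerPredictor'});

% Evaluate the cross-validated classification loss at the trial point.
cvloss0 = minfun1(z0)

rng(s)     % restore the state so that later results are unaffected
```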

Search for the best parameters using bayesopt. For reproducibility, choose the 'expected-improvement-plus' acquisition function. The default acquisition function depends on run time and, therefore, can give varying results.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |     0.18549 |      5.6957 |     0.18549 |     0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |     0.19145 |      20.383 |     0.18549 |     0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |     0.17703 |      13.412 |     0.17703 |     0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |     0.14955 |       0.402 |     0.14955 |     0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |     0.15999 |      12.363 |     0.14955 |     0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |     0.15158 |      1.5035 |     0.14955 |     0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |     0.16181 |     0.18204 |     0.14955 |     0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |     0.15079 |     0.38418 |     0.14955 |     0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |     0.16102 |     0.55525 |     0.14955 |     0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |     0.19259 |      8.6487 |     0.14955 |     0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |     0.18628 |     0.20681 |     0.14955 |     0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |     0.15653 |     0.24643 |     0.14955 |     0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |     0.14699 |     0.82743 |     0.14699 |     0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |     0.14634 |     0.47528 |     0.14634 |     0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |     0.14312 |     0.34493 |     0.14312 |     0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |     0.14334 |     0.51583 |     0.14312 |     0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |     0.13791 |     0.32248 |     0.13791 |     0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |     0.14875 |      0.3551 |     0.13791 |     0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |      0.1651 |      1.3731 |     0.13791 |     0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |     0.15895 |     0.37324 |     0.13791 |     0.13791 |      0.32985 |            7 |            5 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |     0.13946 |     0.26793 |     0.13791 |     0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |     0.16719 |      1.1276 |     0.13791 |     0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |     0.17017 |        1.35 |     0.13791 |     0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |     0.15519 |     0.46246 |     0.13791 |     0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |     0.15312 |     0.26445 |     0.13791 |     0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |     0.15852 |     0.31045 |     0.13791 |     0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |     0.16691 |     0.50559 |     0.13791 |     0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |     0.14384 |     0.35136 |     0.13791 |     0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |     0.14773 |     0.33296 |     0.13791 |     0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |     0.17604 |     0.85847 |     0.13791 |     0.13791 |      0.36565 |            6 |           15 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084

Obtain the best point from results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          
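The bestPoint function can also return the criterion value, and it accepts a 'Criterion' name-value argument. As a sketch, the following code requests the point with the smallest observed objective value and compares it with the MinObjective and XAtMinObjective properties of the results object. The variable names zObserved, lossObserved, lossFromProperty, and zFromProperty are illustrative. For this run, the observed and estimated best points in the displayed output coincide, so the result should match zbest1.

```matlab
% Point with the smallest observed objective value, and that value.
[zObserved,lossObserved] = bestPoint(results1,'Criterion','min-observed');

% The same information is stored in properties of the results object.
lossFromProperty = results1.MinObjective;
zFromProperty    = results1.XAtMinObjective;
```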

Train Univariate GAM with Optimal Parameters

Train an optimized GAM using the zbest1 values. A recommended practice is to specify the class names.

Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500


  Properties, Methods

Mdl1 is a ClassificationGAM model object. The model display shows a partial list of the model properties. To view the full list of the model properties, double-click the variable name Mdl1 in the Workspace. The Variables editor opens for Mdl1. Alternatively, you can display the properties in the Command Window by using dot notation. For example, display the ReasonForTermination property.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

The PredictorTrees field of the property value indicates that Mdl1 includes the specified number of trees. NumTreesPerPredictor of fitcgam specifies the maximum number of trees per predictor, and the function can stop before training the requested number of trees. You can use the ReasonForTermination property to determine whether the trained model contains the specified number of trees.

If you specify to include interaction terms so that fitcgam trains trees for them, then the InteractionTrees field contains a nonempty value.

Find Optimal Parameters for Bivariate GAM

Find the parameters for interaction terms of a bivariate GAM by using the bayesopt function.

Prepare optimizableVariable objects for the name-value arguments for the interaction terms: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, and Interactions (the number of interaction terms).

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Create an objective function for the optimization. Use the optimal parameter values in zbest1 so that the software finds optimal parameter values for interaction terms based on the zbest1 values.

minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Search for the best parameters using bayesopt. The optimization process trains multiple models and displays warning messages if the models include no interaction terms. Disable all warnings before calling bayesopt and restore the warning state after running bayesopt. Alternatively, leave the warning state unchanged if you want to view the warning messages.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |     0.19671 |      10.999 |     0.19671 |     0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |       0.189 |       30.57 |       0.189 |       0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |     0.16538 |      18.643 |     0.16538 |     0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |     0.15243 |      0.4285 |     0.15243 |     0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |     0.16065 |     0.69005 |     0.15243 |     0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |     0.14831 |     0.36629 |     0.14831 |     0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |     0.14887 |     0.36443 |     0.14831 |     0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |     0.15039 |     0.42139 |     0.14831 |     0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |     0.14787 |     0.42482 |     0.14787 |     0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |     0.13902 |     0.38822 |     0.13902 |     0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |     0.14721 |     0.39532 |     0.13902 |     0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |     0.14586 |     0.39205 |     0.13902 |     0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |     0.15073 |       0.383 |     0.13902 |     0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |     0.14966 |     0.42744 |     0.13902 |     0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |     0.13716 |     0.37599 |     0.13716 |     0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |     0.15094 |     0.38197 |     0.13716 |     0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |     0.13972 |      4.5994 |     0.13716 |     0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |     0.14788 |      31.639 |     0.13716 |     0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |     0.14565 |       1.276 |     0.13716 |     0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |     0.16502 |      28.315 |     0.13716 |     0.13716 |    0.0063353 |            4 |          457 |           16 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |     0.15693 |      4.9653 |     0.13716 |     0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |     0.16312 |      29.942 |     0.13716 |     0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |     0.15719 |      4.7423 |     0.13716 |     0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |       0.129 |      6.4419 |       0.129 |       0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |     0.15118 |      6.6757 |       0.129 |       0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |     0.15343 |      2.2035 |       0.129 |       0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |     0.12879 |      6.8017 |     0.12879 |     0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |     0.19093 |      5.9262 |     0.12879 |     0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |     0.16767 |      6.3779 |     0.12879 |     0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |     0.17636 |      11.026 |     0.12879 |     0.12879 |     0.054697 |            5 |          383 |            7 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
warning(orig_state)

Obtain the best point from results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Train Bivariate GAM with Optimal Parameters

Train an optimized GAM using the zbest1 and zbest2 values.

Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500


  Properties, Methods

Alternatively, you can add interaction terms to the univariate GAM by using the addInteractions function.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

The second input argument specifies the maximum number of interaction terms, and the NumTreesPerInteraction name-value argument specifies the maximum number of trees per interaction term. The addInteractions function can include fewer interaction terms and stop before training the requested number of trees. You can use the Interactions and ReasonForTermination properties to check the actual number of interaction terms and number of trees in the trained model.
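As a minimal sketch of that check, the following code counts the rows of the Interactions property in both models and displays the termination messages for Mdl2. The variable names numTermsMdl and numTermsMdl2 are illustrative. The two models are trained by different calls, so their interaction terms are not guaranteed to be identical.

```matlab
% Number of interaction terms actually included in each model.
numTermsMdl  = size(Mdl.Interactions,1)
numTermsMdl2 = size(Mdl2.Interactions,1)

% Termination messages for the model built with addInteractions.
Mdl2.ReasonForTermination
```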

Display the interaction terms in Mdl.

Mdl.Interactions
ans = 4×2

     7    10
     4     7
     7     9
     5    10

Each row of Interactions represents one interaction term and contains the column indexes of the predictor variables for the interaction term. You can use the Interactions property to check the interaction terms in the model and the order in which fitcgam adds them to the model.

Display the interaction terms in Mdl using the predictor names.

Mdl.PredictorNames(Mdl.Interactions)
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}
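If you prefer one label per interaction term, for example for plot legends, you can join the two names in each row. The following self-contained sketch copies the predictor names and interaction indices from the model display above into plain arrays. The variable names predictorNames, interactions, termNames, and termLabels, and the " x " separator, are illustrative choices.

```matlab
% Predictor names and interaction indices copied from the display of Mdl.
predictorNames = ["age","workClass","education","education_num", ...
    "marital_status","occupation","relationship","race","sex", ...
    "capital_gain","capital_loss","hours_per_week","native_country"];
interactions = [7 10; 4 7; 7 9; 5 10];

% Indexing a vector with a matrix returns an array shaped like the index.
termNames = predictorNames(interactions);   % 4-by-2 string array

% Join the two names in each row into a single label.
termLabels = join(termNames," x ",2)
```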

Display the reason for termination to determine whether the model contains the specified number of trees for each linear term and each interaction term.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Assess Predictive Performance on New Observations

Assess the performance of the trained model by using the test sample adulttest and the object functions predict, loss, edge, and margin. You can use a full or compact model with these functions.

  • predict — Classify observations

  • loss — Compute classification loss (misclassification rate in decimal, by default)

  • margin — Compute classification margins

  • edge — Compute classification edge (average of classification margins)

If you want to assess the performance on the training data set, use the resubstitution object functions: resubPredict, resubLoss, resubMargin, and resubEdge. To use these functions, you must use the full model that contains the training data.
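For example, comparing the resubstitution loss with the test loss gives a rough indication of how much the model overfits the training sample. This is a sketch that uses the full model Mdl; the variable names trainLoss and testLoss are illustrative.

```matlab
% Classification loss on the training data stored in the full model.
trainLoss = resubLoss(Mdl)

% Classification loss on the test data, using the test observation weights.
testLoss = loss(Mdl,adulttest,'Weights',adulttest.fnlwgt)
```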

Create a compact model to reduce the size of the trained model.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               

Predict labels and scores for the test data set (adulttest), and compute model statistics (loss, margin, and edge) using the test data set.

[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);
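To make the margin and edge statistics concrete, the following self-contained sketch computes them by hand for three made-up observations. It assumes the usual binary-classification definitions: the margin is the score for the true class minus the score for the false class, and the edge is the weighted mean of the margins with weights normalized to sum to 1. The scores, class indices, and weights are hypothetical values, not output from CMdl.

```matlab
% Hypothetical scores (posterior probabilities); columns are classes 1 and 2.
scores = [0.9 0.1; 0.4 0.6; 0.3 0.7];
trueClass = [1; 1; 2];            % hypothetical true class index per row
n = size(scores,1);

% Pick the true-class and false-class score in each row.
idxTrue  = sub2ind(size(scores),(1:n)',trueClass);
idxFalse = sub2ind(size(scores),(1:n)',3 - trueClass);
m = scores(idxTrue) - scores(idxFalse)   % margins: 0.8, -0.2, 0.4

% Weighted mean of the margins, with weights normalized to sum to 1.
w = [1; 1; 2];                    % hypothetical observation weights
w = w/sum(w);
e = sum(w.*m)                     % edge: 0.35
```

A negative margin, as in the second row, corresponds to a misclassified observation.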

Predict labels and scores and compute the statistics without including interaction terms in the trained model.

[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);

Compare the results obtained by including both linear and interaction terms to the results obtained by including only linear terms.

Create a table containing the observed labels, predicted labels, and scores. Display the first eight rows of the table.

t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
ans=8×5 table
    True Labels    Predicted Labels           Scores            Predicted Labels without interactions    Scores without interactions
    ___________    ________________    _____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                    <=50K                       0.98005     0.019951    
       <=50K            <=50K                1     8.258e-17                    <=50K                        0.9713     0.028696    
       <=50K            <=50K                1    1.8297e-19                    <=50K                       0.99449    0.0055054    
       <=50K            <=50K          0.87422       0.12578                    <=50K                       0.87729      0.12271    
       <=50K            <=50K                1    3.5643e-07                    <=50K                       0.99882    0.0011769    
       <=50K            <=50K          0.60371       0.39629                    <=50K                       0.77861      0.22139    
       <=50K            >50K           0.49917       0.50083                    >50K                        0.46877      0.53123    
       >50K             >50K            0.3109        0.6891                    <=50K                       0.53571      0.46429    

Create confusion charts from the true labels adulttest.salary and each set of predicted labels.

tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')
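If you need the counts behind the charts as numbers, the confusionmat function returns them as a matrix. The following sketch computes the unweighted accuracy for the model with interaction terms; the variable names C, order, and accuracy are illustrative. Because this calculation ignores the observation weights, 1 - accuracy generally differs from the weighted loss value L.

```matlab
% Confusion matrix: rows are true classes, columns are predicted classes.
[C,order] = confusionmat(adulttest.salary,labels);

% Unweighted fraction of correctly classified test observations.
accuracy = sum(diag(C))/sum(C,'all')
```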

Display the computed loss and edge values.

table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                    0.13852     
    Edge              0.63926                    0.58405     

The model achieves a smaller loss when only linear terms are included, but achieves a higher edge value when both linear and interaction terms are included.

Display the distributions of the margins using box plots.

figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')

Interpret Prediction

Interpret the prediction for the first test observation by using the plotLocalEffects function. Also, create partial dependence plots for some important terms in the model by using the plotPartialDependence function.

Classify the first observation of the test data, and plot the local effects of the terms in CMdl on the prediction. To display an existing underscore in any predictor name, change the TickLabelInterpreter value of the axes to 'none'.

label = predict(CMdl,adulttest(1,:))
label = categorical
     <=50K 

f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';

The predict function classifies the first observation adulttest(1,:) as '<=50K'. The plotLocalEffects function creates a horizontal bar graph that shows the local effects of the 10 most important terms on the prediction. Each local effect value shows the contribution of each term to the classification score for '<=50K', which is the logit of the posterior probability that the classification is '<=50K' for the observation.
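The relationship between a posterior probability and its logit can be sketched in a few lines. The value 0.97921 is the posterior probability for '<=50K' shown for the first test observation in the table of scores above; the variable names p, logitScore, and pRecovered are illustrative.

```matlab
% Posterior probability of '<=50K' for the first test observation.
p = 0.97921;

% Logit of the posterior probability (the scale of the local effects).
logitScore = log(p/(1 - p))

% Inverse transformation: recover the posterior probability from the logit.
pRecovered = 1/(1 + exp(-logitScore))
```

A positive logit corresponds to a posterior probability above 0.5, so terms with positive local effects push the prediction toward '<=50K'.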

Create a partial dependence plot for the term age. Specify both the training and test data sets to compute the partial dependence values using both sets.

figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])

The plotted line represents the averaged partial relationships between the predictor age and the score of the class <=50K in the trained model. The x-axis minor ticks represent the unique values in the predictor age.
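To work with the plotted values numerically, you can compute them without creating a plot. This sketch assumes that the partialDependence function is available in your release and accepts the same model, variable, label, and data arguments as plotPartialDependence. The variable names pd, x, iMin, and ageAtMinScore are illustrative.

```matlab
% Partial dependence values (pd) at the query points (x) for the predictor age.
[pd,x] = partialDependence(CMdl,'age',label,[adultdata; adulttest]);

% Age at which the averaged score for the class in label is smallest.
[~,iMin] = min(pd(:));
ageAtMinScore = x(iMin)
```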

Create partial dependence plots for the terms education_num and relationship.

f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';

The plot shows the partial dependence on education_num, which has a different trend depending on the relationship value.
