Train Generalized Additive Model for Regression

This example shows how to train a generalized additive model (GAM) with optimal parameters and how to assess the predictive performance of the trained model. The example first finds the optimal parameter values for a univariate GAM (parameters for linear terms) and then finds the values for a bivariate GAM (parameters for interaction terms). The example also explains how to interpret the trained model by examining the local effects of terms on a specific prediction and by computing the partial dependence of the predictions on predictors.

Load Sample Data

Load the sample data set NYCHousing2015.

load NYCHousing2015

The data set includes 10 variables with information on the sales of properties in New York City in 2015. This example uses these variables to analyze the sale prices (SALEPRICE).

Preprocess the data set. Assume that a SALEPRICE less than or equal to $1000 indicates ownership transfer without a cash consideration. Remove the samples that have this SALEPRICE. Also, remove the outliers identified by the isoutlier function. Then, convert the datetime array (SALEDATE) to month numbers and move the response variable (SALEPRICE) to the last column. Change zeros in LANDSQUAREFEET, GROSSSQUAREFEET, and YEARBUILT to NaNs.

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

Display the first three rows of the table.

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

Randomly select 1000 samples by using the datasample function, and partition observations into a training set and a test set by using the cvpartition function. Specify a 10% holdout sample for testing.

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

Extract the training and test indices, and create tables for training and test data sets.

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

Find Optimal Parameters for Univariate GAM

Optimize the parameters of a univariate GAM with respect to the cross-validation loss by using the bayesopt function.

Prepare optimizableVariable objects for the name-value arguments of a univariate GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, and InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Create an objective function that takes an input z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] and returns the cross-validated loss value at the parameters in z.

minfun1 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

If you specify the cross-validation option 'CrossVal','on', then the fitrgam function returns a cross-validated model object RegressionPartitionedGAM. The kfoldLoss function returns the regression loss (mean squared error) obtained by the cross-validated model. Therefore, the function handle minfun1 computes the cross-validation loss at the parameters in z.

Search for the best parameters using bayesopt. For reproducibility, choose the 'expected-improvement-plus' acquisition function. The default acquisition function depends on run time and, therefore, can give varying results.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |  8.4558e+10 |      1.5106 |  8.4558e+10 |  8.4558e+10 |      0.36695 |            2 |           30 |
|    2 | Accept |  8.6891e+10 |       12.01 |  8.4558e+10 |  8.4558e+10 |     0.008213 |            5 |          271 |
|    3 | Accept |  9.6521e+10 |      1.9121 |  8.4558e+10 |  8.4558e+10 |      0.22984 |            9 |           37 |
|    4 | Accept |  1.3402e+11 |      14.388 |  8.4558e+10 |  8.4558e+10 |      0.99932 |            3 |          344 |
|    5 | Accept |  8.7852e+10 |      13.595 |  8.4558e+10 |  8.4558e+10 |      0.16575 |            1 |          456 |
|    6 | Accept |  9.3041e+10 |      11.002 |  8.4558e+10 |  8.4558e+10 |      0.49477 |            1 |          360 |
|    7 | Accept |  1.0558e+11 |      7.7647 |  8.4558e+10 |  8.4558e+10 |      0.24562 |            4 |          175 |
|    8 | Accept |  8.8841e+10 |      1.5763 |  8.4558e+10 |  8.4558e+10 |      0.39298 |            2 |           41 |
|    9 | Accept |  9.9227e+10 |      14.377 |  8.4558e+10 |  8.4558e+10 |     0.091879 |            3 |          358 |
|   10 | Accept |  9.8611e+10 |     0.14914 |  8.4558e+10 |  8.4558e+10 |      0.22487 |            2 |            2 |
|   11 | Accept |  1.2998e+11 |      23.962 |  8.4558e+10 |  8.4558e+10 |      0.25341 |            5 |          500 |
|   12 | Accept |  8.8968e+10 |      5.0028 |  8.4558e+10 |  8.4558e+10 |      0.33109 |            1 |          175 |
|   13 | Accept |  1.2018e+11 |      1.8004 |  8.4558e+10 |  8.4558e+10 |    0.0030413 |            6 |           40 |
|   14 | Accept |  8.7503e+10 |     0.79283 |  8.4558e+10 |  8.4558e+10 |      0.33877 |            1 |           25 |
|   15 | Accept |  9.3798e+10 |      2.9578 |  8.4558e+10 |  8.4558e+10 |      0.32926 |            2 |           80 |
|   16 | Accept |  9.5165e+10 |      8.0635 |  8.4558e+10 |  8.4558e+10 |      0.33878 |            1 |          282 |
|   17 | Best   |  8.3549e+10 |     0.24446 |  8.3549e+10 |  8.3549e+10 |       0.3552 |            2 |            5 |
|   18 | Best   |  8.3104e+10 |      1.4534 |  8.3104e+10 |  8.3104e+10 |       0.2526 |            1 |           49 |
|   19 | Accept |  8.6938e+10 |      3.3234 |  8.3104e+10 |  8.3104e+10 |      0.18293 |            1 |          110 |
|   20 | Accept |  8.7531e+10 |      2.8096 |  8.3104e+10 |  8.3104e+10 |       0.2781 |            1 |           93 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |  9.1613e+10 |      13.347 |  8.3104e+10 |  8.3104e+10 |      0.31722 |            1 |          464 |
|   22 | Accept |   8.678e+10 |      10.358 |  8.3104e+10 |  8.3104e+10 |      0.11269 |            1 |          358 |
|   23 | Accept |  8.3614e+10 |     0.47001 |  8.3104e+10 |  8.3104e+10 |      0.22278 |            1 |           14 |
|   24 | Accept |  1.3203e+11 |       1.069 |  8.3104e+10 |  8.3104e+10 |    0.0021552 |            5 |           23 |
|   25 | Accept |    8.66e+10 |       7.233 |  8.3104e+10 |  8.3104e+10 |      0.11469 |            1 |          236 |
|   26 | Accept |  8.4535e+10 |      8.7657 |  8.3104e+10 |  8.3104e+10 |    0.0090628 |            1 |          292 |
|   27 | Accept |  1.0315e+11 |      12.297 |  8.3104e+10 |  8.3104e+10 |    0.0014094 |            1 |          413 |
|   28 | Accept |  9.6736e+10 |      5.8323 |  8.3104e+10 |  8.3104e+10 |    0.0040429 |            1 |          202 |
|   29 | Accept |  8.3651e+10 |      8.4999 |  8.3104e+10 |  8.3104e+10 |      0.09375 |            1 |          295 |
|   30 | Accept |  8.7977e+10 |      13.521 |  8.3104e+10 |  8.3104e+10 |     0.016448 |            6 |          292 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 245.1541 seconds
Total objective function evaluation time: 210.0881

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Observed objective function value = 83103839919.908
Estimated objective function value = 83103840296.3186
Function evaluation time = 1.4534

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Estimated objective function value = 83103840296.3186
Estimated function evaluation time = 1.803

Obtain the best point from results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Train Univariate GAM with Optimal Parameters

Train an optimized GAM using the zbest1 values.

Mdl1 = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9806e+05
          NumObservations: 900


  Properties, Methods

Mdl1 is a RegressionGAM model object. The model display shows a partial list of the model properties. To view the full list of properties, double-click the variable name Mdl1 in the Workspace. The Variables editor opens for Mdl1. Alternatively, you can display the properties in the Command Window by using dot notation. For example, display the ReasonForTermination property.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

The PredictorTrees field of the property value indicates that Mdl1 includes the specified number of trees. NumTreesPerPredictor of fitrgam specifies the maximum number of trees per predictor, and the function can stop before training the requested number of trees. You can use the ReasonForTermination property to determine whether the trained model contains the specified number of trees.

If you specify interaction terms so that fitrgam trains trees for them, the InteractionTrees field contains a nonempty value.

Find Optimal Parameters for Bivariate GAM

Find the parameters for interaction terms of a bivariate GAM by using the bayesopt function.

Prepare optimizableVariable objects for the name-value arguments for the interaction terms: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, and Interactions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Create an objective function for the optimization. Use the optimal parameter values in zbest1 so that the software finds optimal parameter values for interaction terms based on the zbest1 values.

minfun2 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Search for the best parameters using bayesopt. The optimization process trains multiple models and displays warning messages if the models include no interaction terms. Disable all warnings before calling bayesopt and restore the warning state after running bayesopt. If you want to view the warning messages, leave the warning state unchanged.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |  8.4721e+10 |      1.6996 |  8.4721e+10 |  8.4721e+10 |      0.41774 |            1 |          346 |           28 |
|    2 | Accept |  9.1765e+10 |      8.3313 |  8.4721e+10 |  8.4721e+10 |       0.9565 |            3 |          231 |           14 |
|    3 | Accept |  9.2116e+10 |      2.8341 |  8.4721e+10 |  8.4721e+10 |      0.33578 |            9 |           45 |            5 |
|    4 | Accept |   1.784e+11 |      76.237 |  8.4721e+10 |  8.4721e+10 |      0.91186 |           10 |          479 |           27 |
|    5 | Accept |  8.4906e+10 |      1.8275 |  8.4721e+10 |  8.4721e+10 |        0.296 |            4 |            1 |           27 |
|    6 | Best   |  8.4172e+10 |        1.73 |  8.4172e+10 |  8.4172e+10 |      0.68133 |            1 |           86 |            1 |
|    7 | Best   |   8.234e+10 |      1.7164 |   8.234e+10 |   8.234e+10 |      0.13943 |            1 |          228 |           26 |
|    8 | Accept |  8.3488e+10 |      1.6382 |   8.234e+10 |   8.234e+10 |      0.46764 |            1 |            1 |            5 |
|    9 | Accept |  8.7977e+10 |      1.5655 |   8.234e+10 |   8.234e+10 |       0.8385 |           10 |            1 |            5 |
|   10 | Accept |  8.4431e+10 |      1.5744 |   8.234e+10 |   8.234e+10 |      0.95535 |            1 |          261 |            4 |
|   11 | Accept |  8.5784e+10 |      1.7478 |   8.234e+10 |   8.234e+10 |     0.023058 |            7 |            1 |           14 |
|   12 | Accept |  8.6068e+10 |      1.7304 |   8.234e+10 |   8.234e+10 |      0.77118 |            1 |            5 |           28 |
|   13 | Accept |  8.7004e+10 |      1.5903 |   8.234e+10 |   8.234e+10 |     0.016991 |            1 |          263 |            2 |
|   14 | Accept |  8.3325e+10 |      1.5895 |   8.234e+10 |   8.234e+10 |       0.9468 |            4 |            7 |            1 |
|   15 | Accept |  8.4097e+10 |      1.6357 |   8.234e+10 |   8.234e+10 |      0.97988 |            1 |          250 |           28 |
|   16 | Accept |  8.3106e+10 |      1.6081 |   8.234e+10 |   8.234e+10 |     0.024052 |            1 |          121 |           28 |
|   17 | Accept |   8.469e+10 |      1.6235 |   8.234e+10 |   8.234e+10 |     0.047902 |            3 |            3 |           12 |
|   18 | Best   |  8.1641e+10 |      1.5833 |  8.1641e+10 |  8.1641e+10 |      0.99848 |            6 |            1 |            3 |
|   19 | Accept |  8.5957e+10 |      1.6305 |  8.1641e+10 |  8.1641e+10 |      0.99826 |            6 |            1 |           13 |
|   20 | Accept |  8.2486e+10 |      1.6515 |  8.1641e+10 |  8.1641e+10 |      0.36059 |            7 |            2 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |  8.6534e+10 |       1.647 |  8.1641e+10 |  8.1641e+10 |    0.0089186 |            1 |          192 |           18 |
|   22 | Accept |  8.5425e+10 |      1.5316 |  8.1641e+10 |  8.1641e+10 |      0.99842 |            1 |          497 |            1 |
|   23 | Accept |   8.515e+10 |      1.5728 |  8.1641e+10 |  8.1641e+10 |      0.99934 |            1 |            3 |            2 |
|   24 | Accept |   8.593e+10 |      1.6086 |  8.1641e+10 |  8.1641e+10 |    0.0099052 |            1 |            2 |           28 |
|   25 | Accept |  8.7394e+10 |       1.577 |  8.1641e+10 |  8.1641e+10 |      0.96502 |            7 |            5 |            2 |
|   26 | Accept |   8.618e+10 |      1.5714 |  8.1641e+10 |  8.1641e+10 |     0.097871 |            5 |            3 |            1 |
|   27 | Accept |  8.5704e+10 |       1.665 |  8.1641e+10 |  8.1641e+10 |     0.056356 |           10 |            6 |            3 |
|   28 | Accept |  9.5451e+10 |      2.8821 |  8.1641e+10 |  8.1641e+10 |      0.91844 |            3 |           12 |           28 |
|   29 | Accept |  8.4013e+10 |      1.5633 |  8.1641e+10 |  8.1641e+10 |      0.68016 |            6 |            1 |            1 |
|   30 | Accept |  8.3928e+10 |      1.7715 |  8.1641e+10 |  8.1641e+10 |      0.07259 |            5 |            5 |           14 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 155.1459 seconds
Total objective function evaluation time: 132.9347

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Observed objective function value = 81640836929.8637
Estimated objective function value = 81640841484.6238
Function evaluation time = 1.5833

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Estimated objective function value = 81640841484.6238
Estimated function evaluation time = 1.5784
warning(orig_state)

Obtain the best point from results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Train Bivariate GAM with Optimal Parameters

Train an optimized GAM using the zbest1 and zbest2 values.

Mdl = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9741e+05
             Interactions: [3×2 double]
          NumObservations: 900


  Properties, Methods

Alternatively, you can add interaction terms to the univariate GAM by using the addInteractions function.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

The second input argument specifies the maximum number of interaction terms, and the NumTreesPerInteraction name-value argument specifies the maximum number of trees per interaction term. The addInteractions function can include fewer interaction terms and stop before training the requested number of trees. You can use the Interactions and ReasonForTermination properties to check the actual number of interaction terms and number of trees in the trained model.

Display the interaction terms in Mdl.

Mdl.Interactions
ans = 3×2

     3     6
     4     6
     6     8

Each row of Interactions represents one interaction term and contains the column indexes of the predictor variables for the interaction term. You can use the Interactions property to check the interaction terms in the model and the order in which fitrgam adds them to the model.

Display the interaction terms in Mdl using the predictor names.

Mdl.PredictorNames(Mdl.Interactions)
ans = 3×2 cell
    {'BUILDINGCLASSCATEGORY'}    {'LANDSQUAREFEET'}
    {'RESIDENTIALUNITS'     }    {'LANDSQUAREFEET'}
    {'LANDSQUAREFEET'       }    {'YEARBUILT'     }

Display the reason for termination to determine whether the model contains the specified number of trees for each linear term and each interaction term.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Assess Predictive Performance on New Observations

Assess the performance of the trained model by using the test sample tbl_test and the object functions predict and loss. You can use a full or compact model with these functions.

  • predict — Predict responses

  • loss — Compute regression loss (mean squared error, by default)

To assess the performance of the model on the training data set, use the resubstitution object functions resubPredict and resubLoss. To use these functions, you must use the full model that contains the training data.
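As a quick sketch of that resubstitution check (using the full model Mdl trained earlier in this example; the variable name L_train is illustrative):

```matlab
% Resubstitution (training-set) predictions and mean squared error.
% These functions require the full model, not the compact model.
yFitTrain = resubPredict(Mdl);
L_train = resubLoss(Mdl)
```

Because the model is evaluated on the same data it was trained on, L_train is typically an optimistic estimate of the generalization error.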

Create a compact model to reduce the size of the trained model.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size             Bytes  Class                                          Attributes

  CMdl      1x1             370211  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             528102  RegressionGAM                                            

Predict responses and compute mean squared errors for the test data set tbl_test.

yFit = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2855e+11

Find predicted responses and errors without including interaction terms in the trained model.

yFit_nointeraction = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.3031e+11

The model achieves a smaller error for the test data set when both linear and interaction terms are included.
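To quantify that gap, you can compute the relative reduction in test mean squared error from including the interaction terms (a sketch using the L and L_nointeractions values computed above; relativeImprovement is an illustrative name):

```matlab
% Fractional reduction in test MSE attributable to interaction terms
relativeImprovement = (L_nointeractions - L)/L_nointeractions
```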

Compare the results obtained by including both linear and interaction terms with the results obtained by including only linear terms. Create a table containing the observed responses and predicted responses. Display the first eight rows of the table.

t = table(tbl_test.SALEPRICE,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Value','Predicted Response','Predicted Response Without Interactions'});
head(t)
ans=8×3 table
    Observed Value    Predicted Response    Predicted Response Without Interactions
    ______________    __________________    _______________________________________

         3.6e+05          4.9812e+05                      5.2712e+05               
         1.8e+05          2.7349e+05                      2.7415e+05               
         1.9e+05          3.3682e+05                      3.3748e+05               
        4.26e+05            6.15e+05                      5.6542e+05               
        3.91e+05          3.1262e+05                      3.1328e+05               
         2.3e+05          1.0606e+05                      1.0672e+05               
      4.7333e+05          1.0773e+06                      1.1399e+06               
           2e+05          2.9506e+05                       3.305e+05               

Interpret Prediction

Interpret the prediction for the first test observation by using the plotLocalEffects function. Also, create partial dependence plots for some important terms in the model by using the plotPartialDependence function.

Predict a response value for the first observation of the test data, and plot the local effects of the terms in CMdl on the prediction. Specify 'IncludeIntercept',true to include the intercept term in the plot.

yFit = predict(CMdl,tbl_test(1,:))
yFit = 4.9812e+05
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

The predict function returns the sale price for the first observation tbl_test(1,:). The plotLocalEffects function creates a horizontal bar graph that shows the local effects of the terms in CMdl on the prediction. Each local effect value shows the contribution of each term to the predicted sale price for tbl_test(1,:).

Compute the partial dependence values for BUILDINGCLASSCATEGORY and plot the sorted values. Specify both the training and test data sets to compute the partial dependence values using both sets.

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Partial Dependence Plot')

The plotted line represents the averaged partial relationships between the predictor BUILDINGCLASSCATEGORY and the response SALEPRICE in the trained model.

Create a partial dependence plot for the terms RESIDENTIALUNITS and LANDSQUAREFEET.

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],[tbl_training; tbl_test])

The minor ticks on the x-axis (RESIDENTIALUNITS) and y-axis (LANDSQUAREFEET) represent the unique values of the predictors in the specified data. The predictor values include a few outliers; most of the RESIDENTIALUNITS and LANDSQUAREFEET values are less than 10 and 50,000, respectively. The plot shows that the SALEPRICE values do not vary significantly when the RESIDENTIALUNITS and LANDSQUAREFEET values exceed 10 and 50,000, respectively.
