Classification Trees and Regression Trees

What Are Classification Trees and Regression Trees?

Classification trees and regression trees predict responses to data. To predict a response, follow the decisions in the tree from the root (beginning) node down to a leaf node. The leaf node contains the response. Classification trees give responses that are nominal, such as 'true' or 'false'. Regression trees give numeric responses.

Statistics Toolbox™ trees are binary. Each step in a prediction involves checking the value of one predictor (variable). For example, here is a simple classification tree:

This tree predicts classifications based on two predictors, x1 and x2. To predict, start at the top node, represented by a triangle (Δ). The first decision is whether x1 is smaller than 0.5. If so, follow the left branch, and see that the tree classifies the data as type 0.

If, however, x1 is not smaller than 0.5, follow the right branch to the lower-right triangle node. Here the tree asks whether x2 is smaller than 0.5. If so, follow the left branch to see that the tree classifies the data as type 0. If not, follow the right branch to see that the tree classifies the data as type 1.
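The decision path just described can be written out by hand. The following MATLAB sketch hard-codes the two 0.5 thresholds from the example tree as an if/elseif chain. The variables Xtest and labels are hypothetical, and the code does not use any Statistics Toolbox function.

```matlab
% Hand-coded version of the example tree (illustration only).
% Each row of Xtest is one observation [x1 x2].
Xtest = [0.2 0.9;    % x1 < 0.5                 -> class 0
         0.7 0.3;    % x1 >= 0.5 and x2 < 0.5   -> class 0
         0.7 0.8];   % x1 >= 0.5 and x2 >= 0.5  -> class 1
labels = zeros(size(Xtest,1),1);
for i = 1:size(Xtest,1)
    x1 = Xtest(i,1);
    x2 = Xtest(i,2);
    if x1 < 0.5          % root node: take the left branch
        labels(i) = 0;
    elseif x2 < 0.5      % lower-right node: take the left branch
        labels(i) = 0;
    else                 % lower-right node: take the right branch
        labels(i) = 1;
    end
end
labels
```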

To learn how to prepare your data for classification or regression using decision trees, see Steps in Supervised Learning (Machine Learning).

Creating a Classification Tree

To create a classification tree for the ionosphere data:

load ionosphere % contains X and Y variables
ctree = fitctree(X,Y)
ctree = 

  ClassificationTree
           PredictorNames: {1x34 cell}
             ResponseName: 'Y'
               ClassNames: {'b'  'g'}
           ScoreTransform: 'none'
    CategoricalPredictors: []
          NumObservations: 351


  Properties, Methods

Creating a Regression Tree

To create a regression tree for the carsmall data, using the Horsepower and Weight vectors as predictors and the MPG vector as the response:

load carsmall % contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG)
rtree = 

  RegressionTree
           PredictorNames: {'x1'  'x2'}
             ResponseName: 'Y'
        ResponseTransform: 'none'
    CategoricalPredictors: []
          NumObservations: 94


  Properties, Methods

Viewing a Classification or Regression Tree

This example shows how to view a classification or a regression tree. There are two ways to view a tree: view(tree) displays a text description of the tree, and view(tree,'mode','graph') displays a graphical description.

Create and view a classification tree.

load fisheriris % load the sample data
ctree = fitctree(meas,species); % create classification tree
view(ctree) % text description
Decision tree for classification
1  if x3<2.45 then node 2 elseif x3>=2.45 then node 3 else setosa
2  class = setosa
3  if x4<1.75 then node 4 elseif x4>=1.75 then node 5 else versicolor
4  if x3<4.95 then node 6 elseif x3>=4.95 then node 7 else versicolor
5  class = virginica
6  if x4<1.65 then node 8 elseif x4>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica

view(ctree,'mode','graph') % graphic description

Now, create and view a regression tree.

load carsmall % load the sample data, contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG,'MinParent',30); % create regression tree
view(rtree) % text description
Decision tree for regression
1  if x2<3085.5 then node 2 elseif x2>=3085.5 then node 3 else 23.7181
2  if x1<89 then node 4 elseif x1>=89 then node 5 else 28.7931
3  if x1<115 then node 6 elseif x1>=115 then node 7 else 15.5417
4  if x2<2162 then node 8 elseif x2>=2162 then node 9 else 30.9375
5  fit = 24.0882
6  fit = 19.625
7  fit = 14.375
8  fit = 33.3056
9  fit = 29

view(rtree,'mode','graph') % graphic description

How the Fit Methods Create Trees

The fitctree and fitrtree functions perform the following steps to create decision trees:

  1. Start with all input data, and examine all possible binary splits on every predictor.

  2. Select the split with the best optimization criterion.

    • If the split leads to a child node having too few observations (fewer than the MinLeaf parameter), select the split with the best optimization criterion subject to the MinLeaf constraint.

  3. Impose the split.

  4. Repeat recursively for the two child nodes.

The description requires two more items: the stopping rule and the optimization criterion.

Stopping rule: Stop splitting when any of the following hold:

  • The node is pure.

    • For classification, a node is pure if it contains only observations of one class.

    • For regression, a node is pure if the mean squared error (MSE) for the observed response in this node drops below the MSE for the observed response in the entire data multiplied by the tolerance on quadratic error per node (qetoler parameter).

  • There are fewer than MinParent observations in this node.

  • Any split imposed on this node would produce children with fewer than MinLeaf observations.
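The stopping conditions can be checked directly. In this sketch the node contents and parameter values are made up for illustration: MinParent = 10 and MinLeaf = 1 match the defaults mentioned later in this section, while the qetoler value is an assumption. The MinLeaf check ignores ties in predictor values, which in a real tree can also make every split invalid.

```matlab
% Classification node: class labels of the observations that reach it
yNode     = {'b';'g';'g';'b';'g'};
MinParent = 10;     % minimum number of observations needed to split a node
MinLeaf   = 1;      % minimum number of observations required in each child

isPure          = numel(unique(yNode)) == 1;   % only one class present
tooFewToSplit   = numel(yNode) < MinParent;
noValidChildren = numel(yNode) < 2*MinLeaf;    % ignores ties in the predictors
stopClass = isPure || tooFewToSplit || noValidChildren

% Regression node: purity compares the node MSE with the MSE of all the data
yReg    = [20; 21; 20.5];              % responses in this node
yAll    = [10; 20; 21; 20.5; 35];      % responses in the entire training set
qetoler = 1e-6;                        % assumed tolerance, for illustration only
nodeMSE = mean((yReg - mean(yReg)).^2);
allMSE  = mean((yAll - mean(yAll)).^2);
isPureReg = nodeMSE < qetoler * allMSE
```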

Optimization criterion:

  • Regression: mean-squared error (MSE). Choose a split to minimize the MSE of predictions compared to the training data.

  • Classification: One of three measures, depending on the setting of the SplitCriterion name-value pair:

    • 'gdi' (Gini's diversity index, the default)

    • 'twoing'

    • 'deviance'

    For details, see ClassificationTree Definitions.
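As an illustration of the default 'gdi' criterion, the sketch below computes Gini's diversity index (1 minus the sum of the squared class fractions in a node) and scores one candidate split by the decrease from the parent's index to the size-weighted average of the children's indices. The anonymous function gdi and the class counts are hypothetical; how the toolbox accounts for priors and observation weights is described in ClassificationTree Definitions and is not reproduced here.

```matlab
% Gini's diversity index of a node, from its vector of class counts
gdi = @(counts) 1 - sum((counts/sum(counts)).^2);

parentCounts = [50 50];   % 50 observations of each of two classes
leftCounts   = [45  5];   % class counts sent to the left child
rightCounts  = [ 5 45];   % class counts sent to the right child

nL = sum(leftCounts);
nR = sum(rightCounts);
n  = nL + nR;
% Impurity decrease: parent index minus size-weighted child indices
gain = gdi(parentCounts) - (nL/n)*gdi(leftCounts) - (nR/n)*gdi(rightCounts)
```

A pure node, such as one with counts [10 0], has an index of 0, and a node split evenly between two classes has the maximum two-class value of 0.5.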

For a continuous predictor, a tree can split halfway between any two adjacent unique values found for this predictor. For a categorical predictor with L levels, a classification tree needs to consider 2^(L-1) - 1 splits to find the optimal split. Alternatively, you can choose a heuristic algorithm to find a good split, as described in Splitting Categorical Predictors.
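The following sketch illustrates steps 1 and 2 of the procedure for a single continuous predictor in a regression tree. It places candidate cut points halfway between adjacent unique values, skips cuts that violate MinLeaf, and keeps the cut with the smallest total squared error of the two children (for a fixed node, this ranks splits the same way as the MSE of the predictions). The data are made up, and the last lines evaluate the 2^(L-1) - 1 count for a categorical predictor with L = 4 levels.

```matlab
x = [1 2 2 3 5 8]';               % predictor values (illustrative data)
y = [1.0 1.2 0.9 3.1 3.0 2.8]';   % responses
MinLeaf = 2;                      % each child needs at least MinLeaf observations

u = unique(x);                         % sorted unique predictor values
cuts = (u(1:end-1) + u(2:end)) / 2;    % halfway between adjacent unique values
bestSSE = Inf;
bestCut = NaN;
for c = cuts'
    left  = x < c;
    right = ~left;
    if sum(left) < MinLeaf || sum(right) < MinLeaf
        continue                       % violates the MinLeaf constraint
    end
    sse = sum((y(left)  - mean(y(left))).^2) + ...
          sum((y(right) - mean(y(right))).^2);
    if sse < bestSSE
        bestSSE = sse;
        bestCut = c;
    end
end
bestCut

% Number of distinct binary splits of a categorical predictor with L levels
L = 4;
numCategoricalSplits = 2^(L-1) - 1
```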

Predicting Responses With Classification and Regression Trees

After creating a tree, you can easily predict responses for new data. Suppose Xnew is new data that has the same number of columns as the original data X. To predict the class or the response based on the tree and the new data, enter

Ynew = predict(tree,Xnew);

For each row of data in Xnew, predict runs through the decisions in tree and gives the resulting prediction in the corresponding element of Ynew. For more information for classification, see the classification predict reference page; for regression, see the regression predict reference page.

For example, to find the predicted classification of a point at the mean of the ionosphere data:

load ionosphere % contains X and Y variables
ctree = fitctree(X,Y);
Ynew = predict(ctree,mean(X))
Ynew = 

    'g'

To find the predicted MPG of a point at the mean of the carsmall data:

load carsmall % contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG);
Ynew = predict(rtree,mean(X))
Ynew =

   28.7931

Improving Classification Trees and Regression Trees

You can tune trees by setting name-value pairs in fitctree and fitrtree. The remainder of this section describes how to determine the quality of a tree, how to decide which name-value pairs to set, and how to control the size of a tree.

Examining Resubstitution Error

Resubstitution error is the difference between the training responses and the predictions the tree makes for those responses from the training inputs. If the resubstitution error is high, you cannot expect the predictions of the tree to be good. However, low resubstitution error does not guarantee good predictions for new data. Resubstitution error is often an overly optimistic estimate of the predictive error on new data.
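In terms of arithmetic, with default settings and equal observation weights, the classification resubstitution error is the fraction of training observations that the tree misclassifies, and the regression resubstitution error is the mean squared error. The vectors below are invented for illustration rather than produced by a tree. As a check on the scale, for the 150 Fisher iris observations a resubstitution error of 0.0200 corresponds to 3 misclassified flowers.

```matlab
% Classification: training labels and (made-up) tree predictions on the same inputs
Ytrain = {'setosa';'versicolor';'virginica';'virginica'};
Yfit   = {'setosa';'versicolor';'versicolor';'virginica'};
classErr = mean(~strcmp(Ytrain, Yfit))    % fraction misclassified

% Regression: training responses and (made-up) tree predictions
yTrain = [18; 15; 18; 16];
yFit   = [17; 15; 19; 16];
mse = mean((yTrain - yFit).^2)            % mean squared error
```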

Example: Resubstitution Error of a Classification Tree.  Examine the resubstitution error of a default classification tree for the Fisher iris data:

load fisheriris
ctree = fitctree(meas,species);
resuberror = resubLoss(ctree)
resuberror =

    0.0200

The tree classifies nearly all the Fisher iris data correctly.

Cross Validation

To get a better sense of the predictive accuracy of your tree for new data, cross validate the tree. By default, cross validation splits the training data into 10 parts at random. It trains 10 new trees, each one on nine parts of the data. It then examines the predictive accuracy of each new tree on the data not included in training that tree. This method gives a good estimate of the predictive accuracy of the resulting tree, since it tests each new tree on data that tree has not seen.
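The mechanics of 10-fold cross validation can be sketched in base MATLAB. To keep the example self-contained, the "model" trained on each fold is simply the mean of the training responses, standing in for a tree; this substitution and the data are assumptions for illustration, and the toolbox's own partitioning details may differ.

```matlab
y = (1:50)';                      % illustrative responses
n = numel(y);
k = 10;
fold = mod(randperm(n), k) + 1;   % random fold label (1..k) for each observation
sqErr = zeros(n,1);
for f = 1:k
    isTest  = (fold == f);        % observations held out in this fold
    isTrain = ~isTest;
    prediction = mean(y(isTrain));              % "train" on the other k-1 folds
    sqErr(isTest) = (y(isTest) - prediction).^2; % test on the held-out fold
end
cvMSE    = mean(sqErr)                 % each observation is predicted exactly once
resubMSE = mean((y - mean(y)).^2)      % same model trained and tested on all data
```

For this mean-only model, the cross-validated error can never be smaller than the resubstitution error, which mirrors the pattern in the regression tree example that follows.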

Example: Cross Validating a Regression Tree.  Examine the resubstitution and cross-validation accuracy of a regression tree for predicting mileage based on the carsmall data:

load carsmall
X = [Acceleration Displacement Horsepower Weight];
rtree = fitrtree(X,MPG);
resuberror = resubLoss(rtree)
resuberror =

    4.7188

The resubstitution loss for a regression tree is the mean-squared error. The resulting value indicates that a typical predictive error for the tree is about the square root of 4.7, or a bit over 2.

Now calculate the error by cross validating the tree:

rng('default')
cvrtree = crossval(rtree);
cvloss = kfoldLoss(cvrtree)
cvloss =

   23.8065

The cross-validated loss is almost 24, meaning a typical predictive error for the tree on new data is about 5. This illustrates that cross-validated loss is usually higher than simple resubstitution loss.
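Both losses are mean squared errors in squared MPG units, so the "typical error" quoted in the text is their square root. This short check uses the two loss values from the example above:

```matlab
resubRMSE = sqrt(4.7188)    % about 2.17
cvRMSE    = sqrt(23.8065)   % about 4.88
```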

Control Depth or "Leafiness"

When you grow a decision tree, consider its simplicity and predictive power. A deep tree with many leaves is usually highly accurate on the training data. However, the tree is not guaranteed to show a comparable accuracy on an independent test set. A leafy tree tends to overtrain, and its test accuracy is often far less than its training (resubstitution) accuracy. In contrast, a shallow tree does not attain high training accuracy. But a shallow tree can be more robust: its training accuracy could be close to its accuracy on a representative test set. Also, a shallow tree is easy to interpret.

If you do not have enough data for separate training and test sets, estimate tree accuracy by cross validation.

For an alternative method of controlling the tree depth, see Pruning.

Select Appropriate Tree Depth.  

This example shows how to control the depth of a decision tree, and how to choose an appropriate depth.

Load the ionosphere data:

load ionosphere

Generate minimum leaf occupancies for classification trees from 10 to 100, spaced exponentially apart:

leafs = logspace(1,2,10);
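Because the values are spaced exponentially, each one is a constant factor of 10^(1/9), about 1.29, larger than the previous one, and most of them are not integers. This check, which uses only base MATLAB, displays the ratios and the rounded values; the rounding is for display only and is not part of the original example.

```matlab
leafs = logspace(1,2,10);                   % 10 values from 10^1 to 10^2
ratios = leafs(2:end) ./ leafs(1:end-1)     % all equal to 10^(1/9)
roundedLeafs = round(leafs)                 % 10 13 17 22 28 36 46 60 77 100
```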

Create cross validated classification trees for the ionosphere data with minimum leaf occupancies from leafs:

rng('default')
N = numel(leafs);
err = zeros(N,1);
for n=1:N
    t = fitctree(X,Y,'CrossVal','On',...
        'MinLeaf',leafs(n));
    err(n) = kfoldLoss(t);
end
plot(leafs,err);
xlabel('Min Leaf Size');
ylabel('cross-validated error');

The best leaf size is between about 20 and 50 observations per leaf.

Compare the near-optimal tree with at least 40 observations per leaf with the default tree, which uses 10 observations per parent node and 1 observation per leaf.

DefaultTree = fitctree(X,Y);
view(DefaultTree,'Mode','Graph')

OptimalTree = fitctree(X,Y,'MinLeaf',40);
view(OptimalTree,'Mode','Graph')

resubOpt = resubLoss(OptimalTree);
lossOpt = kfoldLoss(crossval(OptimalTree));
resubDefault = resubLoss(DefaultTree);
lossDefault = kfoldLoss(crossval(DefaultTree));
resubOpt,resubDefault,lossOpt,lossDefault
resubOpt =

    0.0883


resubDefault =

    0.0114


lossOpt =

    0.1054


lossDefault =

    0.1111

The near-optimal tree is much smaller and gives a much higher resubstitution error. Yet it gives a similar (here, slightly lower) cross-validated error.

Pruning

Pruning optimizes tree depth (leafiness) by merging leaves on the same tree branch. Control Depth or "Leafiness" describes one method for selecting the optimal depth for a tree. Unlike in that section, you do not need to grow a new tree for every node size. Instead, grow a deep tree, and prune it to the level you choose.

Prune a tree at the command line using the prune method (classification) or prune method (regression). Alternatively, prune a tree interactively with the tree viewer:

view(tree,'mode','graph')

To prune a tree, the tree must contain a pruning sequence. By default, both fitctree and fitrtree calculate a pruning sequence for a tree during construction. If you construct a tree with the 'Prune' name-value pair set to 'off', or if you prune a tree to a smaller level, the tree does not contain the full pruning sequence. Generate the full pruning sequence with the prune method (classification) or prune method (regression).

Prune a Classification Tree.  

This example creates a classification tree for the ionosphere data, and prunes it to a good level.

Load the ionosphere data:

load ionosphere

Construct a default classification tree for the data:

tree = fitctree(X,Y);

View the tree in the interactive viewer:

view(tree,'Mode','Graph')

Find the optimal pruning level by minimizing cross-validated loss:

[~,~,~,bestlevel] = cvLoss(tree,...
    'SubTrees','All','TreeSize','min')
bestlevel =

     6

View the tree pruned to level 6:

view(tree,'Mode','Graph','Prune',6)

Alternatively, use the interactive window to prune the tree.

The pruned tree is the same as the near-optimal tree in the "Select Appropriate Tree Depth" example.

Set 'TreeSize' to 'SE' (default) to find the maximal pruning level for which the tree error does not exceed the error from the best level plus one standard error:

[~,~,~,bestlevel] = cvLoss(tree,'SubTrees','All')
bestlevel =

     6

In this case the level is the same for either setting of 'TreeSize'.
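The two 'TreeSize' rules can be imitated on a small table of cross-validated errors. In this sketch the errors and standard errors are made-up numbers, not cvLoss output: 'min' picks the level with the lowest error, and 'SE' picks the largest pruning level (the smallest tree) whose error is within one standard error of that minimum.

```matlab
% Made-up cross-validated errors and standard errors for pruning levels
% 0 (unpruned tree) through 5 (smallest tree considered here)
levels = 0:5;
err    = [0.120 0.110 0.100 0.104 0.115 0.180];
se     = [0.017 0.016 0.016 0.016 0.017 0.020];

[minErr, iMin] = min(err);
bestMin = levels(iMin)                    % 'min' rule: level with the lowest error
withinOneSE = err <= minErr + se(iMin);   % levels within one SE of the minimum
bestSE = max(levels(withinOneSE))         % 'SE' rule: most-pruned such level
```

With these numbers the two rules disagree (levels 2 and 4), unlike the ionosphere example, where both settings give level 6.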

Prune the tree to use it for other purposes:

tree = prune(tree,'Level',6);
view(tree,'Mode','Graph')

Alternative: classregtree

The ClassificationTree and RegressionTree classes were released in MATLAB® R2011a. Previously, you represented both classification trees and regression trees with a classregtree object. The new classes provide all the functionality of the classregtree class, and are more convenient when used with Ensemble Methods.

Statistics Toolbox software maintains classregtree and its predecessors treefit, treedisp, treeval, treeprune, and treetest for backward compatibility. These functions will be removed in a future release.

Train Classification Trees Using classregtree

This example uses Fisher's iris data in fisheriris.mat to create a classification tree for predicting species using measurements of sepal length, sepal width, petal length, and petal width as predictors. Here, the predictors are continuous and the response is categorical.

Load the data and use the classregtree constructor of the classregtree class to create the classification tree.

load fisheriris

t = classregtree(meas,species,...
                 'Names',{'SL' 'SW' 'PL' 'PW'})
t = 

Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica

t is a classregtree object and can be operated on with any class method.

Use the type method of the classregtree class to show the type of the tree.

treetype = type(t)
treetype =

classification

classregtree creates a classification tree because species is a cell array of strings, and the response is assumed to be categorical.

To view the tree, use the view method of the classregtree class.

view(t)

The tree predicts the response values at the circular leaf nodes based on a series of questions about the iris at the triangular branching nodes. A true answer to any question follows the branch to the left. A false answer follows the branch to the right.

The tree does not use sepal measurements for predicting species. These can go unmeasured in new data, and you can enter them as NaN values for predictions. For example, use the tree to predict the species of an iris with petal length 4.8 and petal width 1.6.

predicted = t([NaN NaN 4.8 1.6])
predicted = 

    'versicolor'

The object allows for functional evaluation, of the form t(X). This is a shorthand way of calling the eval method of the classregtree class. The predicted species is the left leaf node at the bottom of the tree in the previous view.

You can use a variety of methods of the classregtree class, such as cutvar and cuttype, to get more information about the split at node 6 that makes the final distinction between versicolor and virginica.

var6 = cutvar(t,6) % What variable determines the split?
type6 = cuttype(t,6) % What type of split is it?
var6 = 

    'PW'


type6 = 

    'continuous'

Classification trees fit the original (training) data well, but can do a poor job of classifying new values. Lower branches, especially, can be strongly affected by outliers. A simpler tree often avoids overfitting. You can use the prune method of the classregtree class to find the next smaller tree in the optimal pruning sequence.

pruned = prune(t,'Level',1)
view(pruned)
pruned = 

Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  class = versicolor
7  class = virginica

To find the best classification tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.

Train Regression Trees Using classregtree

This example uses the data on cars in carsmall.mat to create a regression tree for predicting mileage using measurements of weight and the number of cylinders as predictors. Here, one predictor (weight) is continuous and the other (cylinders) is categorical. The response (mileage) is continuous.

Load the data and use the classregtree constructor of the classregtree class to create the regression tree:

load carsmall

t = classregtree([Weight, Cylinders],MPG,...
                 'Categorical',2,'MinParent',20,...
                 'Names',{'W','C'})
t = 

Decision tree for regression
 1  if W<3085.5 then node 2 elseif W>=3085.5 then node 3 else 23.7181
 2  if W<2371 then node 4 elseif W>=2371 then node 5 else 28.7931
 3  if C=8 then node 6 elseif C in {4 6} then node 7 else 15.5417
 4  if W<2162 then node 8 elseif W>=2162 then node 9 else 32.0741
 5  if C=6 then node 10 elseif C=4 then node 11 else 25.9355
 6  if W<4381 then node 12 elseif W>=4381 then node 13 else 14.2963
 7  fit = 19.2778
 8  fit = 33.3056
 9  fit = 29.6111
10  fit = 23.25
11  if W<2827.5 then node 14 elseif W>=2827.5 then node 15 else 27.2143
12  if W<3533.5 then node 16 elseif W>=3533.5 then node 17 else 14.8696
13  fit = 11
14  fit = 27.6389
15  fit = 24.6667
16  fit = 16.6
17  fit = 14.3889

t is a classregtree object and can be operated on with any of the methods of the class.

Use the type method of the classregtree class to show the type of the tree:

treetype = type(t)
treetype =

regression

classregtree creates a regression tree because MPG is a numerical vector, and the response is assumed to be continuous.

To view the tree, use the view method of the classregtree class:

view(t)

The tree predicts the response values at the circular leaf nodes based on a series of questions about the car at the triangular branching nodes. A true answer to any question follows the branch to the left; a false answer follows the branch to the right.

Use the tree to predict the mileage for a 2000-pound car with either 4, 6, or 8 cylinders:

mileage2K = t([2000 4; 2000 6; 2000 8])
mileage2K =

   33.3056
   33.3056
   33.3056

The object allows for functional evaluation, of the form t(X). This is a shorthand way of calling the eval method of the classregtree class.

The predicted responses computed above are all the same. This is because they follow a series of splits in the tree that depend only on weight, terminating at the leftmost leaf node in the view above. A 4000-pound car, following the right branch from the top of the tree, leads to different predicted responses:

mileage4K = t([4000 4; 4000 6; 4000 8])
mileage4K =

   19.2778
   19.2778
   14.3889

You can use a variety of other methods of the classregtree class, such as cutvar, cuttype, and cutcategories, to get more information about the split at node 3 that distinguishes the 8-cylinder car:

var3 = cutvar(t,3)      % What variable determines the split?
type3 = cuttype(t,3)    % What type of split is it?
c = cutcategories(t,3); % Which categories are sent to the left
                        % child node, and which to the right?
leftChildNode = c{1}
rightChildNode = c{2}
var3 = 

    'C'


type3 = 

    'categorical'


leftChildNode =

     8


rightChildNode =

     4     6

Regression trees fit the original (training) data well, but may do a poor job of predicting new values. Lower branches, especially, may be strongly affected by outliers. A simpler tree often avoids overfitting. To find the best regression tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.
