Decision Trees

What Are Decision Trees?

Decision trees, also known as classification trees and regression trees, predict responses to data. To predict a response, follow the decisions in the tree from the root (beginning) node down to a leaf node. The leaf node contains the response. Classification trees give responses that are nominal, such as 'true' or 'false'. Regression trees give numeric responses.

Statistics and Machine Learning Toolbox™ trees are binary. Each step in a prediction involves checking the value of one predictor (variable). For example, here is a simple classification tree:

This tree predicts classifications based on two predictors, x1 and x2. To predict, start at the top node, represented by a triangle (Δ). The first decision is whether x1 is smaller than 0.5. If so, follow the left branch, and see that the tree classifies the data as type 0.

If, however, x1 exceeds 0.5, then follow the right branch to the lower-right triangle node. Here the tree asks if x2 is smaller than 0.5. If so, then follow the left branch to see that the tree classifies the data as type 0. If not, then follow the right branch to see that the tree classifies the data as type 1.
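The following sketch (simulated data, not part of the original example) trains a classification tree with this structure and displays it:

rng(0)                                     % for reproducibility
X = rand(200,2);                           % two predictors, x1 and x2
Y = double(X(:,1) >= 0.5 & X(:,2) >= 0.5); % type 1 only when both predictors are at least 0.5
SimpleTree = fitctree(X,Y);
view(SimpleTree,'Mode','graph')            % the fitted tree splits on x1 near 0.5, then on x2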

To learn how to prepare your data for classification or regression using decision trees, see Steps in Supervised Learning.

Train Classification Tree

This example shows how to train a classification tree.

Create a classification tree using the entire ionosphere data set.

load ionosphere % Contains X and Y variables
Mdl = fitctree(X,Y)
Mdl = 

  ClassificationTree
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: {'b'  'g'}
           ScoreTransform: 'none'
          NumObservations: 351


Train Regression Tree

This example shows how to train a regression tree.

Create a regression tree using all observations in the carsmall data set. Consider the Horsepower and Weight vectors as predictor variables, and the MPG vector as the response.

load carsmall % Contains Horsepower, Weight, MPG
X = [Horsepower Weight];

Mdl = fitrtree(X,MPG)
Mdl = 

  RegressionTree
             ResponseName: 'Y'
    CategoricalPredictors: []
        ResponseTransform: 'none'
          NumObservations: 94


Viewing a Classification or Regression Tree

This example shows how to view a classification or a regression tree. There are two ways to view a tree: view(tree) returns a text description and view(tree,'mode','graph') returns a graphic description of the tree.

Create and view a classification tree.

load fisheriris % load the sample data
ctree = fitctree(meas,species); % create classification tree
view(ctree) % text description
Decision tree for classification
1  if x3<2.45 then node 2 elseif x3>=2.45 then node 3 else setosa
2  class = setosa
3  if x4<1.75 then node 4 elseif x4>=1.75 then node 5 else versicolor
4  if x3<4.95 then node 6 elseif x3>=4.95 then node 7 else versicolor
5  class = virginica
6  if x4<1.65 then node 8 elseif x4>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica

view(ctree,'mode','graph') % graphic description

Now, create and view a regression tree.

load carsmall % load the sample data, contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG,'MinParent',30); % create regression tree
view(rtree) % text description
Decision tree for regression
1  if x2<3085.5 then node 2 elseif x2>=3085.5 then node 3 else 23.7181
2  if x1<89 then node 4 elseif x1>=89 then node 5 else 28.7931
3  if x1<115 then node 6 elseif x1>=115 then node 7 else 15.5417
4  if x2<2162 then node 8 elseif x2>=2162 then node 9 else 30.9375
5  fit = 24.0882
6  fit = 19.625
7  fit = 14.375
8  fit = 33.3056
9  fit = 29

view(rtree,'mode','graph') % graphic description

How the Fit Methods Create Trees

By default, fitctree and fitrtree use the standard CART algorithm [1] to create decision trees. That is, they perform the following steps:

  1. Start with all input data, and examine all possible binary splits on every predictor.

  2. Select a split with the best optimization criterion.

    • A split might lead to a child node having too few observations (less than the MinLeafSize parameter). To avoid this, the software chooses a split that yields the best optimization criterion subject to the MinLeafSize constraint.

  3. Impose the split.

  4. Repeat recursively for the two child nodes.

The explanation requires two more items: a description of the optimization criterion, and the stopping rule.

Stopping rule: Stop splitting when any of the following hold:

  • The node is pure.

    • For classification, a node is pure if it contains only observations of one class.

    • For regression, a node is pure if the mean squared error (MSE) for the observed response in this node drops below the MSE for the observed response in the entire data multiplied by the tolerance on quadratic error per node (QuadraticErrorTolerance parameter).

  • There are fewer than MinParentSize observations in this node.

  • Any split imposed on this node produces children with fewer than MinLeafSize observations.

  • The algorithm splits MaxNumSplits nodes.

Optimization criterion:

  • Regression: mean-squared error (MSE). Choose a split to minimize the MSE of predictions compared to the training data.

  • Classification: One of three measures, depending on the setting of the SplitCriterion name-value pair:

    • 'gdi' (Gini's diversity index, the default)

    • 'twoing'

    • 'deviance'

    For details, see ClassificationTree Definitions.

For a continuous predictor, a tree can split halfway between any two adjacent unique values found for this predictor. For a categorical predictor with L levels, a classification tree needs to consider 2^(L–1) – 1 splits to find the optimal split. Alternatively, you can choose a heuristic algorithm to find a good split, as described in Splitting Categorical Predictors.
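To make the split search concrete, here is a simplified sketch (illustrative data; not the toolbox implementation) that enumerates the candidate cut points for a single continuous predictor and selects the one minimizing the MSE criterion described above:

x = [1 2 3 5 8 13]';               % illustrative predictor values
y = [7 6 8 15 14 16]';             % illustrative responses
u = unique(x);
cuts = (u(1:end-1) + u(2:end))/2;  % candidate cuts: midpoints of adjacent unique values
mse = zeros(size(cuts));
for k = 1:numel(cuts)
    left  = y(x <  cuts(k));
    right = y(x >= cuts(k));
    mse(k) = (sum((left - mean(left)).^2) + sum((right - mean(right)).^2))/numel(y);
end
[~,idx] = min(mse);
bestCut = cuts(idx)                % the split a CART-style search would impose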

For dual-core systems and above, fitctree and fitrtree parallelize the training of decision trees using Intel® Threading Building Blocks (TBB). For details on Intel TBB, see https://software.intel.com/en-us/intel-tbb.

Prediction Using Classification and Regression Trees

This example shows how to predict class labels or responses using trained classification and regression trees.

After creating a tree, you can easily predict responses for new data. Suppose Xnew is new data that has the same number of columns as the original data X. To predict the classification or regression based on the tree (Mdl) and the new data, enter

Ynew = predict(Mdl,Xnew)

For each row of data in Xnew, predict runs through the decisions in Mdl and gives the resulting prediction in the corresponding element of Ynew. For more information on classification tree prediction, see CompactClassificationTree.predict. For regression, see CompactRegressionTree.predict.

For example, find the predicted classification of a point at the mean of the ionosphere data.

load ionosphere
CMdl = fitctree(X,Y);
Ynew = predict(CMdl,mean(X))
Ynew =

  cell

    'g'

Find the predicted MPG of a point at the mean of the carsmall data.

load carsmall
X = [Horsepower Weight];
RMdl = fitrtree(X,MPG);
Ynew = predict(RMdl,mean(X))
Ynew =

   28.7931

Predict Out-of-Sample Responses of Subtrees

This example shows how to predict out-of-sample responses of regression trees, and then plot the results.

Load the carsmall data set. Consider Weight as a predictor of the response MPG.

load carsmall
idxNaN = isnan(MPG + Weight);
X = Weight(~idxNaN);
Y = MPG(~idxNaN);
n = numel(X);

Partition the data into training (50%) and validation (50%) sets.

rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false;                  % Validation set logical indices

Grow a regression tree using the training observations.

Mdl = fitrtree(X(idxTrn),Y(idxTrn));
view(Mdl,'Mode','graph')

Compute fitted values of the validation observations for each of several subtrees.

m = max(Mdl.PruneList);
pruneLevels = 0:2:m; % Pruning levels to consider
z = numel(pruneLevels);
Yfit = predict(Mdl,X(idxVal),'SubTrees',pruneLevels);

Yfit is an n-by-z matrix of fitted values in which the rows correspond to observations and the columns correspond to subtrees.

Plot Yfit and Y against X.

figure;
sortDat = sortrows([X(idxVal) Y(idxVal) Yfit],1); % Sort all data with respect to X
plot(sortDat(:,1),sortDat(:,2),'*');
hold on;
plot(repmat(sortDat(:,1),1,size(Yfit,2)),sortDat(:,3:end));
lev = cellstr(num2str((pruneLevels)','Level %d MPG'));
legend(['Observed MPG'; lev])
title 'Out-of-Sample Predictions'
xlabel 'Weight (lbs)';
ylabel 'MPG';
h = findobj(gcf);
axis tight;
set(h(4:end),'LineWidth',3) % Widen all lines

The values of Yfit for lower pruning levels tend to follow the data more closely than higher levels. Higher pruning levels tend to be flat for large X intervals.

Improving Classification Trees and Regression Trees

You can tune trees by setting name-value pairs in fitctree and fitrtree. The remainder of this section describes how to determine the quality of a tree, how to decide which name-value pairs to set, and how to control the size of a tree.

Examining Resubstitution Error

Resubstitution error is the difference between the response training data and the predictions the tree makes of the response based on the input training data. If the resubstitution error is high, you cannot expect the predictions of the tree to be good. However, having low resubstitution error does not guarantee good predictions for new data. Resubstitution error is often an overly optimistic estimate of the predictive error on new data.
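As a quick illustration (Fisher iris data assumed), resubstitution error is simply the model loss evaluated on the same data used for training:

load fisheriris
Mdl = fitctree(meas,species);
trainingLoss = loss(Mdl,meas,species) % loss measured on the training data
resubError = resubLoss(Mdl)           % the same quantity, computed by resubLoss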

Classification Tree Resubstitution Error  

This example shows how to examine the resubstitution error of a classification tree.

Load Fisher's iris data.

load fisheriris

Train a default classification tree using the entire data set.

Mdl = fitctree(meas,species);

Examine the resubstitution error.

resuberror = resubLoss(Mdl)
resuberror =

    0.0200

The tree classifies nearly all the Fisher iris data correctly.

Cross Validation

To get a better sense of the predictive accuracy of your tree for new data, cross validate the tree. By default, cross validation splits the training data into 10 parts at random. It trains 10 new trees, each one on nine parts of the data. It then examines the predictive accuracy of each new tree on the data not included in training that tree. This method gives a good estimate of the predictive accuracy of the resulting tree, since it tests the new trees on new data.
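The following sketch (Fisher iris data assumed) spells out the 10-fold procedure that crossval and kfoldLoss automate, using cvpartition:

load fisheriris
cvp = cvpartition(species,'KFold',10);  % 10 random folds
foldLoss = zeros(cvp.NumTestSets,1);
for k = 1:cvp.NumTestSets
    trnIdx = training(cvp,k);           % nine parts for training
    tstIdx = test(cvp,k);               % one part held out
    tree = fitctree(meas(trnIdx,:),species(trnIdx));
    foldLoss(k) = loss(tree,meas(tstIdx,:),species(tstIdx));
end
cvEstimate = mean(foldLoss)             % average loss over the held-out folds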

Cross Validate a Regression Tree  

This example shows how to examine the resubstitution and cross-validation accuracy of a regression tree for predicting mileage based on the carsmall data.

Load the carsmall data set. Consider acceleration, displacement, horsepower, and weight as predictors of MPG.

load carsmall
X = [Acceleration Displacement Horsepower Weight];

Grow a regression tree using all of the observations.

rtree = fitrtree(X,MPG);

Compute the in-sample error.

resuberror = resubLoss(rtree)
resuberror =

    4.7188

The resubstitution loss for a regression tree is the mean-squared error. The resulting value indicates that a typical predictive error for the tree is about the square root of 4.7, or a bit over 2.

Estimate the cross-validation MSE.

rng 'default';
cvrtree = crossval(rtree);
cvloss = kfoldLoss(cvrtree)
cvloss =

   23.8065

The cross-validated loss is about 24, meaning a typical predictive error for the tree on new data is about 5. This demonstrates that cross-validated loss is usually higher than simple resubstitution loss.

Choose Split Predictor Selection Technique

The standard CART algorithm tends to select continuous predictors that have many levels. Sometimes, such a selection can be spurious and can also mask more important predictors that have fewer levels, such as categorical predictors. That is, the predictor-selection process at each node is biased. Also, standard CART tends to miss the important interactions between pairs of predictors and the response.

To mitigate selection bias and increase detection of important interactions, you can specify usage of the curvature or interaction tests using the 'PredictorSelection' name-value pair argument. Using the curvature or interaction test has the added advantage of producing better predictor importance estimates than standard CART.

This table summarizes the supported predictor-selection techniques.

Standard CART [1]
  'PredictorSelection' value: default
  Description: Selects the split predictor that maximizes the split-criterion gain over all possible splits of all predictors.
  Training speed: Baseline for comparison.
  Specify if any of these conditions are true:
  • All predictors are continuous
  • Predictor importance is not the analysis goal
  • You are boosting decision trees

Curvature test [35][34]
  'PredictorSelection' value: 'curvature'
  Description: Selects the split predictor that minimizes the p-value of chi-square tests of independence between each predictor and the response.
  Training speed: Comparable to standard CART.
  Specify if any of these conditions are true:
  • The predictor variables are heterogeneous
  • Predictor importance is an analysis goal
  • You want to enhance tree interpretation

Interaction test [2]
  'PredictorSelection' value: 'interaction-curvature'
  Description: Chooses the split predictor that minimizes the p-value of chi-square tests of independence between each predictor and the response (that is, conducts curvature tests), and that minimizes the p-value of a chi-square test of independence between each pair of predictors and the response.
  Training speed: Slower than standard CART, particularly when the data set contains many predictor variables.
  Specify if any of these conditions are true:
  • The predictor variables are heterogeneous
  • You suspect associations between pairs of predictors and the response
  • Predictor importance is an analysis goal
  • You want to enhance tree interpretation

For more details on these predictor-selection techniques, see the descriptions of the 'PredictorSelection' name-value pair argument for fitctree and fitrtree.
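As a brief illustration (carsmall data assumed; not part of the original examples), you can request the curvature test and compare the resulting predictor importance estimates against standard CART:

load carsmall
X = [Acceleration Displacement Horsepower Weight];
MdlCART = fitrtree(X,MPG);                                  % standard CART
MdlCurv = fitrtree(X,MPG,'PredictorSelection','curvature'); % curvature test
impCART = predictorImportance(MdlCART)
impCurv = predictorImportance(MdlCurv)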

Control Depth or "Leafiness"

When you grow a decision tree, consider its simplicity and predictive power. A deep tree with many leaves is usually highly accurate on the training data. However, the tree is not guaranteed to show a comparable accuracy on an independent test set. A leafy tree tends to overtrain (or overfit), and its test accuracy is often far less than its training (resubstitution) accuracy. In contrast, a shallow tree does not attain high training accuracy. But a shallow tree can be more robust — its training accuracy could be close to that of a representative test set. Also, a shallow tree is easy to interpret. If you do not have enough data for training and test, estimate tree accuracy by cross validation.

fitctree and fitrtree have three name-value pair arguments that control the depth of resulting decision trees:

  • MaxNumSplits — The maximal number of branch node splits is MaxNumSplits per tree. Set a large value for MaxNumSplits to get a deep tree. The default is size(X,1) – 1.

  • MinLeafSize — Each leaf has at least MinLeafSize observations. Set small values of MinLeafSize to get deep trees. The default is 1.

  • MinParentSize — Each branch node in the tree has at least MinParentSize observations. Set small values of MinParentSize to get deep trees. The default is 10.

If you specify MinParentSize and MinLeafSize, the learner uses the setting that yields trees with larger leaves (i.e., shallower trees):

MinParent = max(MinParentSize,2*MinLeafSize)

If you supply MaxNumSplits, the software splits a tree until one of the three splitting criteria is satisfied.
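For example, this sketch (ionosphere data assumed) caps the number of branch-node splits to grow a shallow tree:

load ionosphere
ShallowTree = fitctree(X,Y,'MaxNumSplits',7); % at most 7 branch-node splits
view(ShallowTree,'Mode','graph')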

For an alternative method of controlling the tree depth, see Pruning.

Select Appropriate Tree Depth  

This example shows how to control the depth of a decision tree, and how to choose an appropriate depth.

Load the ionosphere data.

load ionosphere

Generate an exponentially spaced set of values from 10 through 100 that represent the minimum number of observations per leaf node.

leafs = logspace(1,2,10);

Create cross-validated classification trees for the ionosphere data. Specify to grow each tree using a minimum leaf size in leafs.

rng('default')
N = numel(leafs);
err = zeros(N,1);
for n=1:N
    t = fitctree(X,Y,'CrossVal','On',...
        'MinLeafSize',leafs(n));
    err(n) = kfoldLoss(t);
end
plot(leafs,err);
xlabel('Min Leaf Size');
ylabel('cross-validated error');

The best leaf size is between about 20 and 50 observations per leaf.

Compare the near-optimal tree, which has at least 40 observations per leaf, with the default tree, which uses 10 observations per parent node and 1 observation per leaf.

DefaultTree = fitctree(X,Y);
view(DefaultTree,'Mode','Graph')

OptimalTree = fitctree(X,Y,'MinLeafSize',40);
view(OptimalTree,'mode','graph')

resubOpt = resubLoss(OptimalTree);
lossOpt = kfoldLoss(crossval(OptimalTree));
resubDefault = resubLoss(DefaultTree);
lossDefault = kfoldLoss(crossval(DefaultTree));
resubOpt,resubDefault,lossOpt,lossDefault
resubOpt =

    0.0883


resubDefault =

    0.0114


lossOpt =

    0.1054


lossDefault =

    0.1111

The near-optimal tree is much smaller and gives a much higher resubstitution error. Yet, it gives similar accuracy for cross-validated data.

Pruning

Pruning optimizes tree depth (leafiness) by merging leaves on the same tree branch. Control Depth or "Leafiness" describes one method for selecting the optimal depth for a tree. Unlike in that section, you do not need to grow a new tree for every node size. Instead, grow a deep tree, and prune it to the level you choose.

Prune a tree at the command line using the prune method (classification) or prune method (regression). Alternatively, prune a tree interactively with the tree viewer:

view(tree,'mode','graph')

To prune a tree, the tree must contain a pruning sequence. By default, both fitctree and fitrtree calculate a pruning sequence for a tree during construction. If you construct a tree with the 'Prune' name-value pair set to 'off', or if you prune a tree to a smaller level, the tree does not contain the full pruning sequence. Generate the full pruning sequence with the prune method (classification) or prune method (regression).
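For example, this sketch (ionosphere data assumed) grows a tree without a pruning sequence and then fills in the full sequence:

load ionosphere
rawTree = fitctree(X,Y,'Prune','off'); % no pruning sequence computed at training
rawTree = prune(rawTree);              % compute the full optimal pruning sequence
max(rawTree.PruneList)                 % deepest pruning level now available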

Prune a Classification Tree  

This example creates a classification tree for the ionosphere data, and prunes it to a good level.

Load the ionosphere data:

load ionosphere

Construct a default classification tree for the data:

tree = fitctree(X,Y);

View the tree in the interactive viewer:

view(tree,'Mode','Graph')

Find the optimal pruning level by minimizing cross-validated loss:

[~,~,~,bestlevel] = cvLoss(tree,...
    'SubTrees','All','TreeSize','min')
bestlevel =

     6

Prune the tree to level 6:

view(tree,'Mode','Graph','Prune',6)

Alternatively, use the interactive window to prune the tree.

The pruned tree is the same as the near-optimal tree in the "Select Appropriate Tree Depth" example.

Set 'TreeSize' to 'SE' (default) to find the maximal pruning level for which the tree error does not exceed the error from the best level plus one standard deviation:

[~,~,~,bestlevel] = cvLoss(tree,'SubTrees','All')
bestlevel =

     6

In this case the level is the same for either setting of 'TreeSize'.

Prune the tree to use it for other purposes:

tree = prune(tree,'Level',6);
view(tree,'Mode','Graph')

Alternative: classregtree

The ClassificationTree and RegressionTree classes were released in MATLAB® R2011a. Previously, you represented both classification trees and regression trees with a classregtree object. The new classes provide all the functionality of the classregtree class, and are more convenient when used with Ensemble Methods.

Statistics and Machine Learning Toolbox software maintains classregtree and its predecessors treefit, treedisp, treeval, treeprune, and treetest for backward compatibility. These functions will be removed in a future release.

Train Classification Trees Using classregtree

This example uses Fisher's iris data in fisheriris.mat to create a classification tree for predicting species using measurements of sepal length, sepal width, petal length, and petal width as predictors. Here, the predictors are continuous and the response is categorical.

Load the data and use the classregtree constructor of the classregtree class to create the classification tree.

load fisheriris

t = classregtree(meas,species,...
                 'Names',{'SL' 'SW' 'PL' 'PW'})
t = 

Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica

t is a classregtree object and can be operated on with any class method.

Use the type method of the classregtree class to show the type of the tree.

treetype = type(t)
treetype =

classification

classregtree creates a classification tree because species is a cell array of character vectors, and the response is assumed to be categorical.

To view the tree, use the view method of the classregtree class.

view(t)

The tree predicts the response values at the circular leaf nodes based on a series of questions about the iris at the triangular branching nodes. A true answer to any question follows the branch to the left. A false follows the branch to the right.

The tree does not use sepal measurements for predicting species. These can go unmeasured in new data, and you can enter them as NaN values for predictions. For example, use the tree to predict the species of an iris with petal length 4.8 and petal width 1.6.

predicted = t([NaN NaN 4.8 1.6])
predicted =

  cell

    'versicolor'

The object allows functional evaluation of the form t(X). This is a shorthand way of calling the eval method of the classregtree class. The predicted species is the left leaf node at the bottom of the tree in the previous view.

You can use a variety of methods of the classregtree class, such as cutvar and cuttype, to get more information about the split at node 6 that makes the final distinction between versicolor and virginica.

var6 = cutvar(t,6) % What variable determines the split?
type6 = cuttype(t,6) % What type of split is it?
var6 =

  cell

    'PW'


type6 =

  cell

    'continuous'

Classification trees fit the original (training) data well, but can do a poor job of classifying new values. Lower branches, especially, can be strongly affected by outliers. A simpler tree often avoids overfitting. You can use the prune method of the classregtree class to find the next largest tree from an optimal pruning sequence.

pruned = prune(t,'Level',1)
view(pruned)
pruned = 

Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  class = versicolor
7  class = virginica

To find the best classification tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.
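A hedged sketch of that workflow (legacy classregtree interface; the test method returns the best pruning level when cross validating) might look like this:

[cost,secost,ntnodes,bestlevel] = test(t,'crossvalidate',meas,species);
tmin = prune(t,'Level',bestlevel)  % prune to the level selected by cross validation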

Train Regression Trees Using classregtree

This example uses the data on cars in carsmall.mat to create a regression tree for predicting mileage using measurements of weight and the number of cylinders as predictors. Here, one predictor (weight) is continuous and the other (cylinders) is categorical. The response (mileage) is continuous.

Load the data and use the classregtree constructor of the classregtree class to create the regression tree:

load carsmall

t = classregtree([Weight, Cylinders],MPG,...
                 'Categorical',2,'MinParent',20,...
                 'Names',{'W','C'})
t = 

Decision tree for regression
 1  if W<3085.5 then node 2 elseif W>=3085.5 then node 3 else 23.7181
 2  if W<2371 then node 4 elseif W>=2371 then node 5 else 28.7931
 3  if C=8 then node 6 elseif C in {4 6} then node 7 else 15.5417
 4  if W<2162 then node 8 elseif W>=2162 then node 9 else 32.0741
 5  if C=6 then node 10 elseif C=4 then node 11 else 25.9355
 6  if W<4381 then node 12 elseif W>=4381 then node 13 else 14.2963
 7  fit = 19.2778
 8  fit = 33.3056
 9  fit = 29.6111
10  fit = 23.25
11  if W<2827.5 then node 14 elseif W>=2827.5 then node 15 else 27.2143
12  if W<3533.5 then node 16 elseif W>=3533.5 then node 17 else 14.8696
13  fit = 11
14  fit = 27.6389
15  fit = 24.6667
16  fit = 16.6
17  fit = 14.3889

t is a classregtree object and can be operated on with any of the methods of the class.

Use the type method of the classregtree class to show the type of the tree:

treetype = type(t)
treetype =

regression

classregtree creates a regression tree because MPG is a numerical vector, and the response is assumed to be continuous.

To view the tree, use the view method of the classregtree class:

view(t)

The tree predicts the response values at the circular leaf nodes based on a series of questions about the car at the triangular branching nodes. A true answer to any question follows the branch to the left; a false follows the branch to the right.

Use the tree to predict the mileage for a 2000-pound car with either 4, 6, or 8 cylinders:

mileage2K = t([2000 4; 2000 6; 2000 8])
mileage2K =

   33.3056
   33.3056
   33.3056

The object allows functional evaluation of the form t(X). This is a shorthand way of calling the eval method of the classregtree class.

The predicted responses computed above are all the same. This is because they follow a series of splits in the tree that depend only on weight, terminating at the leftmost leaf node in the view above. A 4000-pound car, following the right branch from the top of the tree, leads to different predicted responses:

mileage4K = t([4000 4; 4000 6; 4000 8])
mileage4K =

   19.2778
   19.2778
   14.3889

You can use a variety of other methods of the classregtree class, such as cutvar, cuttype, and cutcategories, to get more information about the split at node 3 that distinguishes the 8-cylinder car:

var3 = cutvar(t,3)      % What variable determines the split?
type3 = cuttype(t,3)    % What type of split is it?
c = cutcategories(t,3); % Which classes are sent to the left
                        % child node, and which to the right?
leftChildNode = c{1}
rightChildNode = c{2}
var3 =

  cell

    'C'


type3 =

  cell

    'categorical'


leftChildNode =

     8


rightChildNode =

     4     6

Regression trees fit the original (training) data well, but may do a poor job of predicting new values. Lower branches, especially, may be strongly affected by outliers. A simpler tree often avoids overfitting. To find the best regression tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.
