Decision trees, or classification trees and regression trees,
predict responses to data. To predict a response, follow the decisions
in the tree from the root (beginning) node down to a leaf node. The
leaf node contains the response. Classification trees give responses
that are nominal, such as 'true' or 'false'.
Regression trees give numeric responses.
Statistics and Machine Learning Toolbox™ trees are binary. Each step in a prediction involves checking the value of one predictor (variable). For example, here is a simple classification tree:
This tree predicts classifications based on two predictors, x1 and x2.
To predict, start at the top node, represented by a triangle (Δ).
The first decision is whether x1 is smaller than 0.5.
If so, follow the left branch, and see that the tree classifies the
data as type 0.
If, however, x1 is not smaller than 0.5, then follow the right branch to the lower-right triangle node. Here
the tree asks if x2 is smaller than 0.5.
If so, then follow the left branch to see that the tree classifies
the data as type 0. If not, then follow the right
branch to see that the tree classifies the data as type 1.
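As a minimal sketch, the two-predictor tree just described can be written by hand as a MATLAB anonymous function. It assumes, as in the text, that an observation goes left when a predictor is smaller than 0.5 and right otherwise; the name exampleTree is hypothetical and is not part of the toolbox.

```matlab
% Hypothetical hand-coded version of the two-predictor example tree.
% Left branch when the predictor is < 0.5, right branch otherwise.
% Class 1 is reached only by going right at both branch nodes.
exampleTree = @(x1,x2) double(x1 >= 0.5 & x2 >= 0.5);

% Follow the three possible paths through the tree.
pathLeft       = exampleTree(0.2, 0.9)   % x1 < 0.5             -> type 0
pathRightLeft  = exampleTree(0.8, 0.1)   % x1 >= 0.5, x2 < 0.5  -> type 0
pathRightRight = exampleTree(0.8, 0.9)   % x1 >= 0.5, x2 >= 0.5 -> type 1
```

Because the function uses element-wise operators, it also accepts vectors of x1 and x2 values and returns one class per observation.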
To learn how to prepare your data for classification or regression using decision trees, see Steps in Supervised Learning.
This example shows how to train a classification tree.
Create a classification tree using the entire ionosphere data set.
load ionosphere % Contains X and Y variables
Mdl = fitctree(X,Y)
Mdl = 
  ClassificationTree
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: {'b'  'g'}
           ScoreTransform: 'none'
          NumObservations: 351
This example shows how to train a regression tree.
Create a regression tree using all observations in the carsmall
data set. Consider the Horsepower and Weight vectors as predictor variables, and the MPG
vector as the response.
load carsmall % Contains Horsepower, Weight, MPG
X = [Horsepower Weight];
Mdl = fitrtree(X,MPG)
Mdl = 
  RegressionTree
             ResponseName: 'Y'
    CategoricalPredictors: []
        ResponseTransform: 'none'
          NumObservations: 94
This example shows how to view a classification or a regression tree. There are two ways to view a tree: view(tree)
returns a text description and view(tree,'mode','graph')
returns a graphic description of the tree.
Create and view a classification tree.
load fisheriris % load the sample data
ctree = fitctree(meas,species); % create classification tree
view(ctree) % text description
Decision tree for classification
1  if x3<2.45 then node 2 elseif x3>=2.45 then node 3 else setosa
2  class = setosa
3  if x4<1.75 then node 4 elseif x4>=1.75 then node 5 else versicolor
4  if x3<4.95 then node 6 elseif x3>=4.95 then node 7 else versicolor
5  class = virginica
6  if x4<1.65 then node 8 elseif x4>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica
view(ctree,'mode','graph') % graphic description
Now, create and view a regression tree.
load carsmall % load the sample data, contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG,'MinParentSize',30); % create regression tree
view(rtree) % text description
Decision tree for regression
1  if x2<3085.5 then node 2 elseif x2>=3085.5 then node 3 else 23.7181
2  if x1<89 then node 4 elseif x1>=89 then node 5 else 28.7931
3  if x1<115 then node 6 elseif x1>=115 then node 7 else 15.5417
4  if x2<2162 then node 8 elseif x2>=2162 then node 9 else 30.9375
5  fit = 24.0882
6  fit = 19.625
7  fit = 14.375
8  fit = 33.3056
9  fit = 29
view(rtree,'mode','graph') % graphic description
By default, fitctree and fitrtree use the standard CART algorithm [1] to create decision trees. That is, they perform the following
steps:
1. Start with all input data, and examine all possible binary splits on every predictor.
2. Select a split with the best optimization criterion.
   A split might lead to a child node having too few observations (less than the MinLeafSize parameter).
   To avoid this, the software chooses a split that yields the best optimization
   criterion subject to the MinLeafSize constraint.
3. Impose the split.
4. Repeat recursively for the two child nodes.
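The steps above can be sketched for the simplest case: one continuous predictor and a regression response, scoring each candidate cut by the total squared error of the two children (equivalent to their size-weighted MSE). This is a toy illustration in base MATLAB under those assumptions, not the toolbox's internal implementation; x, y, and minLeafSize are hypothetical names.

```matlab
% Minimal sketch of one CART split search on a single continuous predictor
% for a regression tree. Toy data; the variable names are hypothetical.
x = [1 2 3 4 5 6 7 8]';           % predictor values
y = [5 6 5 6 20 21 19 20]';       % response values
minLeafSize = 2;                  % same role as the MinLeafSize parameter

u = unique(x);                            % sorted unique predictor values
cuts = (u(1:end-1) + u(2:end)) / 2;       % candidate cut points (midpoints)
bestCut = NaN;
bestSSE = Inf;
for k = 1:numel(cuts)
    left = x < cuts(k);                   % observations sent to the left child
    nL = sum(left);
    nR = sum(~left);
    if nL < minLeafSize || nR < minLeafSize
        continue                          % skip splits that violate MinLeafSize
    end
    % Total squared error when each child predicts its own mean
    % (n times the size-weighted MSE of the two children)
    sse = sum((y(left) - mean(y(left))).^2) + ...
          sum((y(~left) - mean(y(~left))).^2);
    if sse < bestSSE
        bestSSE = sse;
        bestCut = cuts(k);
    end
end
bestCut   % cut point with the smallest total squared error
```

A full tree would repeat this search over every predictor, impose the winning split, and recurse into each child until a stopping rule applies.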
The explanation requires two more items: a description of the optimization criterion, and the stopping rule.
Stopping rule: Stop splitting when any of the following hold:
- The node is pure.
  - For classification, a node is pure if it contains only observations of one class.
  - For regression, a node is pure if the mean squared error (MSE) for the observed response in this node drops below the
    MSE for the observed response in the entire data multiplied by the
    tolerance on quadratic error per node (QuadraticErrorTolerance parameter).
- There are fewer than MinParentSize observations in this node.
- Any split imposed on this node produces children with fewer than MinLeafSize observations.
- The algorithm splits MaxNumSplits nodes.
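The purity and size checks in this stopping rule reduce to a few comparisons. The following base-MATLAB sketch evaluates them on toy values; the tolerance 1e-6 and all variable names are assumptions chosen for illustration, not values read from the toolbox.

```matlab
% Minimal sketch of the stopping-rule checks (toy values, hypothetical names).
% Classification: a node is pure when all its labels belong to one class.
nodeLabels = {'g'; 'g'; 'g'};
isPureClass = numel(unique(nodeLabels)) == 1

% Regression: a node is pure when its MSE drops below the tolerance
% times the MSE of the whole training response.
yAll  = [10 12 30 31 29 11]';         % full training response
yNode = [30 30 30]';                  % responses in the current node
tol   = 1e-6;                         % assumed tolerance value for this sketch
nodeMSE = mean((yNode - mean(yNode)).^2);
allMSE  = mean((yAll - mean(yAll)).^2);
isPureRegr = nodeMSE < tol * allMSE

% Size rule: stop when the node has fewer than MinParentSize observations.
minParentSize = 10;
tooSmall = numel(yNode) < minParentSize
```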
Optimization criterion:
- Regression: mean-squared error (MSE). Choose a split to minimize the MSE of predictions compared to the training data.
- Classification: One of three measures, depending on the setting of the SplitCriterion name-value pair:
  - 'gdi' (Gini's diversity index, the default)
  - 'twoing'
  - 'deviance'
For details, see ClassificationTree Definitions.
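Two of these classification measures can be computed directly from the class fractions in a node. This base-MATLAB sketch assumes Gini's diversity index is 1 minus the sum of squared class fractions and that deviance is the base-2 entropy of those fractions; check ClassificationTree Definitions for the exact forms the toolbox uses. Twoing compares two candidate children and is not shown.

```matlab
% Minimal sketch: node impurity from class counts (toy numbers).
counts = [40 10];                 % observations of each class in a node
p = counts / sum(counts);         % class fractions p(i)

gdi = 1 - sum(p.^2)               % Gini's diversity index

pPos = p(p > 0);                  % drop empty classes to avoid 0*log2(0)
dev = -sum(pPos .* log2(pPos))    % deviance (entropy, base-2 logarithm)

% A pure node scores 0 under both measures.
pPure = [1 0];
gdiPure = 1 - sum(pPure.^2)
```

A split is scored by how much it lowers the size-weighted impurity of the children relative to the parent.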
For a continuous predictor, a tree can split halfway between any two adjacent unique values found for this predictor. For a categorical predictor with L levels, a classification tree needs to consider 2^(L-1) - 1 splits to find the optimal split. Alternatively, you can choose a heuristic algorithm to find a good split, as described in Splitting Categorical Predictors.
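To make the two counting rules concrete, the following base-MATLAB sketch lists the candidate cut points for a toy continuous predictor and enumerates the 2^(L-1) - 1 binary splits of a categorical predictor with L = 4 levels. The enumeration scheme (keep level 1 on the left, let each nonzero bit pattern choose the right child) is an illustrative choice, not necessarily how the toolbox enumerates splits.

```matlab
% Minimal sketch: candidate splits for one predictor (toy values).
% Continuous predictor: cut halfway between adjacent unique values.
x = [2.0 3.5 3.5 5.0 8.0];
u = unique(x);                          % [2.0 3.5 5.0 8.0]
cuts = (u(1:end-1) + u(2:end)) / 2      % [2.75 4.25 6.5]

% Categorical predictor with L levels: 2^(L-1) - 1 distinct binary splits.
L = 4;
numSplits = 2^(L-1) - 1                 % 7 for four levels

% Enumerate them: level 1 always stays on the left, and each nonzero
% bit pattern k picks which of levels 2..L go to the right child.
splits = false(numSplits, L);           % true = level goes to the right child
for k = 1:numSplits
    splits(k, 2:L) = logical(bitget(k, 1:L-1));
end
splits
```

Fixing one level on the left avoids counting each split twice (as {A | B} and {B | A}), which is where the exponent L-1 comes from.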
On dual-core systems and above, fitctree and fitrtree
parallelize training of decision trees using Intel® Threading Building Blocks (TBB). For details
on Intel TBB, see https://software.intel.com/en-us/intel-tbb.
This example shows how to predict class labels or responses using trained classification and regression trees.
After creating a tree, you can easily predict responses for new data. Suppose Xnew
is new data that has the same number of columns as the original data X.
To predict the classification or regression based on the tree (Mdl) and the new data, enter
Ynew = predict(Mdl,Xnew)
For each row of data in Xnew, predict runs through the decisions in Mdl
and gives the resulting prediction in the corresponding element of Ynew.
For more information on classification tree prediction, see CompactClassificationTree.predict.
For regression, see CompactRegressionTree.predict.
For example, find the predicted classification of a point at the mean of the ionosphere data.
load ionosphere
CMdl = fitctree(X,Y);
Ynew = predict(CMdl,mean(X))
Ynew =
  1x1 cell array
    {'g'}
Find the predicted MPG of a point at the mean of the carsmall data.
load carsmall
X = [Horsepower Weight];
RMdl = fitrtree(X,MPG);
Ynew = predict(RMdl,mean(X))
Ynew = 28.7931
This example shows how to predict out-of-sample responses of regression trees, and then plot the results.
Load the carsmall data set. Consider Weight as a predictor of the response MPG.
load carsmall
idxNaN = isnan(MPG + Weight);
X = Weight(~idxNaN);
Y = MPG(~idxNaN);
n = numel(X);
Partition the data into training (50%) and validation (50%) sets.
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false; % Validation set logical indices
Grow a regression tree using the training observations.
Mdl = fitrtree(X(idxTrn),Y(idxTrn)); view(Mdl,'Mode','graph')
Compute fitted values of the validation observations for each of several subtrees.
m = max(Mdl.PruneList);
pruneLevels = 0:2:m; % Pruning levels to consider
z = numel(pruneLevels);
Yfit = predict(Mdl,X(idxVal),'SubTrees',pruneLevels);
Yfit is a sum(idxVal)-by-z matrix of fitted values in which the rows correspond to the validation observations and each column corresponds to a subtree.
Plot Yfit and Y against X.
figure;
sortDat = sortrows([X(idxVal) Y(idxVal) Yfit],1); % Sort all data with respect to X
plot(sortDat(:,1),sortDat(:,2),'*');
hold on;
plot(repmat(sortDat(:,1),1,size(Yfit,2)),sortDat(:,3:end));
lev = cellstr(num2str((pruneLevels)','Level %d MPG'));
legend(['Observed MPG'; lev])
title 'Out-of-Sample Predictions'
xlabel 'Weight (lbs)';
ylabel 'MPG';
h = findobj(gcf);
axis tight;
set(h(4:end),'LineWidth',3) % Widen all lines
The values of Yfit for lower pruning levels tend to follow the data more closely than those for higher levels. Higher pruning levels tend to be flat over large X intervals.
You can tune trees by setting name-value pairs in fitctree and fitrtree.
The remainder of this section describes how to determine the quality
of a tree, how to decide which name-value pairs to set, and how to
control the size of a tree.
Resubstitution error is the difference between the response training data and the predictions the tree makes of the response based on the input training data. If the resubstitution error is high, you cannot expect the predictions of the tree to be good. However, having low resubstitution error does not guarantee good predictions for new data. Resubstitution error is often an overly optimistic estimate of the predictive error on new data.
Classification Tree Resubstitution Error
This example shows how to examine the resubstitution error of a classification tree.
Load Fisher's iris data.
load fisheriris
Train a default classification tree using the entire data set.
Mdl = fitctree(meas,species);
Examine the resubstitution error.
resuberror = resubLoss(Mdl)
resuberror = 0.0200
The tree classifies nearly all the Fisher iris data correctly.
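To connect the number above with the definition of resubstitution error, the following sketch recomputes it by hand as the fraction of training observations the tree misclassifies. It assumes the default settings used here (equal observation weights and empirical class priors), under which resubLoss reports exactly that fraction.

```matlab
% Sketch: recompute the classification resubstitution error by hand.
% Assumes default equal weights and empirical priors, so resubLoss
% is the fraction of misclassified training observations.
load fisheriris
Mdl = fitctree(meas,species);
labels = resubPredict(Mdl);                   % predictions on the training data
manualError = mean(~strcmp(labels,species))   % fraction misclassified
resubError = resubLoss(Mdl)
```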
To get a better sense of the predictive accuracy of your tree for new data, cross-validate the tree. By default, cross-validation splits the training data into 10 parts at random. It trains 10 new trees, each one on nine parts of the data. It then examines the predictive accuracy of each new tree on the data not included in training that tree. This method gives a good estimate of the predictive accuracy of the resulting tree, since it tests the new trees on new data.
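The partitioning step can be sketched in base MATLAB: give every observation a random fold label from 1 to 10, then let each fold serve once as the held-out set while the other nine train a tree. This is an illustration under simple assumptions; the toolbox builds its folds internally (for example through crossval or the 'CrossVal' option) and may, for instance, stratify classification folds by class. The names fold, testIdx, and trainIdx are hypothetical.

```matlab
% Minimal sketch of a random 10-fold partition (hypothetical names).
rng(0)                               % reset the generator so the sketch is repeatable
n = 351;                             % number of observations
k = 10;                              % number of folds
fold = mod(randperm(n) - 1, k) + 1;  % random fold label 1..k for each observation

for i = 1:k
    testIdx  = (fold == i);          % held-out part for the i-th tree
    trainIdx = ~testIdx;             % the other nine parts train the i-th tree
    % ...train a tree on trainIdx and measure its error on testIdx...
end

foldSizes = accumarray(fold(:), 1)'  % each fold holds 35 or 36 observations
```

The cross-validated loss is then the error accumulated over all held-out folds, so every observation is predicted exactly once by a tree that never saw it.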
Cross Validate a Regression Tree
This example shows how to examine the resubstitution and cross-validation accuracy of a regression tree for predicting mileage based on the carsmall data.
Load the carsmall data set. Consider acceleration, displacement, horsepower, and weight as predictors of MPG.
load carsmall
X = [Acceleration Displacement Horsepower Weight];
Grow a regression tree using all of the observations.
rtree = fitrtree(X,MPG);
Compute the in-sample error.
resuberror = resubLoss(rtree)
resuberror = 4.7188
The resubstitution loss for a regression tree is the mean-squared error. The resulting value indicates that a typical predictive error for the tree is about the square root of 4.7, or a bit over 2.
Estimate the cross-validation MSE.
rng 'default';
cvrtree = crossval(rtree);
cvloss = kfoldLoss(cvrtree)
cvloss = 23.8065
The cross-validated loss is almost 24, meaning a typical predictive error for the tree on new data is about 5. This demonstrates that cross-validated loss is usually higher than simple resubstitution loss.
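Because both losses are mean-squared errors, taking square roots converts them to a typical prediction error in MPG units. The values below are copied from the example output above.

```matlab
% Square roots of the two mean-squared errors reported above.
resubMSE = 4.7188;
cvMSE = 23.8065;
typicalResubError = sqrt(resubMSE)   % about 2.17 MPG
typicalCVError = sqrt(cvMSE)         % about 4.88 MPG
```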
The standard CART algorithm tends to select continuous predictors that have many levels. Sometimes, such a selection can be spurious and can also mask more important predictors that have fewer levels, such as categorical predictors. That is, the predictor-selection process at each node is biased. Also, standard CART tends to miss the important interactions between pairs of predictors and the response.
To mitigate selection bias and increase detection of important
interactions, you can specify usage of the curvature or interaction
tests using the 'PredictorSelection' name-value pair argument. Using the curvature or interaction test has the added
advantage of producing better predictor importance estimates than
standard CART.
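As a sketch of how the three techniques are requested, the following grows one classification tree per technique on the ionosphere data and compares their 10-fold cross-validated error. The choice of data set and the comparison itself are illustrative assumptions; the results depend on the data and on the random folds.

```matlab
% Sketch: one classification tree per predictor-selection technique.
load ionosphere
rng('default')                                   % reproducible folds
cartTree  = fitctree(X,Y);                       % standard CART
curvTree  = fitctree(X,Y,'PredictorSelection','curvature');
interTree = fitctree(X,Y,'PredictorSelection','interaction-curvature');

% Cross-validated classification error for each tree
cvErr = [kfoldLoss(crossval(cartTree)), ...
         kfoldLoss(crossval(curvTree)), ...
         kfoldLoss(crossval(interTree))]

% Predictor importance estimates from the curvature-test tree
imp = predictorImportance(curvTree);
```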
This table summarizes the supported predictor-selection techniques.

| Technique | 'PredictorSelection' Value | Description | Training speed | When to specify |
|---|---|---|---|---|
| Standard CART [1] | 'allsplits' (default) | Selects the split predictor that maximizes the split-criterion gain over all possible splits of all predictors. | Baseline for comparison | Specify if any of these conditions are true: |
| Curvature test [35][34] | 'curvature' | Selects the split predictor that minimizes the p-value of chi-square tests of independence between each predictor and the response. | Comparable to standard CART | Specify if any of these conditions are true: |
| Interaction test [2] | 'interaction-curvature' | Chooses the split predictor that minimizes the p-value of chi-square tests of independence between each predictor and the response (that is, conducts curvature tests), and that minimizes the p-value of a chi-square test of independence between each pair of predictors and response. | Slower than standard CART, particularly when the data set contains many predictor variables. | Specify if any of these conditions are true: |
For more details on predictor selection techniques:
- For classification trees, see PredictorSelection and Node Splitting Rules.
- For regression trees, see PredictorSelection and Node Splitting Rules.
When you grow a decision tree, consider its simplicity and predictive power. A deep tree with many leaves is usually highly accurate on the training data. However, the tree is not guaranteed to show a comparable accuracy on an independent test set. A leafy tree tends to overtrain (or overfit), and its test accuracy is often far less than its training (resubstitution) accuracy. In contrast, a shallow tree does not attain high training accuracy. But a shallow tree can be more robust: its training accuracy could be close to that of a representative test set. Also, a shallow tree is easy to interpret. If you do not have enough data for training and test, estimate tree accuracy by cross-validation.
fitctree and fitrtree have three name-value pair arguments
that control the depth of resulting decision trees:
- MaxNumSplits: The maximal number of branch node splits per tree. Set a large value for MaxNumSplits to get
  a deep tree. The default is size(X,1) - 1.
- MinLeafSize: Each leaf has at least MinLeafSize observations. Set small
  values of MinLeafSize to get deep trees. The default is 1.
- MinParentSize: Each branch node in the tree has at least MinParentSize observations.
  Set small values of MinParentSize to get deep trees.
  The default is 10.
If you specify MinParentSize and MinLeafSize,
the learner uses the setting that yields trees with larger leaves
(i.e., shallower trees):
MinParent = max(MinParentSize,2*MinLeafSize)
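The formula reflects that a node must hold at least 2*MinLeafSize observations before it can be split into two children that each satisfy MinLeafSize. A minimal sketch, with a hypothetical helper name effectiveMinParent:

```matlab
% Hypothetical helper that evaluates the formula above.
effectiveMinParent = @(minParentSize,minLeafSize) ...
    max(minParentSize, 2*minLeafSize);

defaults  = effectiveMinParent(10,1)    % default settings: MinParentSize governs (10)
bigLeaves = effectiveMinParent(10,40)   % MinLeafSize 40 forces parents of at least 80
```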
If you supply MaxNumSplits, the software
splits a tree until one of the three splitting criteria (MaxNumSplits, MinLeafSize, or MinParentSize) is satisfied.
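As a sketch, the following caps a classification tree at seven branch-node splits and counts the splits through the IsBranchNode property. The value 7 and the ionosphere data are arbitrary choices for illustration.

```matlab
% Sketch: limit the number of branch-node splits with MaxNumSplits.
load ionosphere
shallowTree = fitctree(X,Y,'MaxNumSplits',7);
numSplits = sum(shallowTree.IsBranchNode)   % at most 7 branch nodes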
For an alternative method of controlling the tree depth, see Pruning.
Select Appropriate Tree Depth
This example shows how to control the depth of a decision tree, and how to choose an appropriate depth.
Load the ionosphere data.
load ionosphere
Generate an exponentially spaced set of values from 10 through 100 that represent the minimum number of observations per leaf node.
leafs = logspace(1,2,10);
Create cross-validated classification trees for the ionosphere
data. Specify to grow each tree using a minimum leaf size in leafs.
rng('default')
N = numel(leafs);
err = zeros(N,1);
for n=1:N
    t = fitctree(X,Y,'CrossVal','On',...
        'MinLeafSize',leafs(n));
    err(n) = kfoldLoss(t);
end
plot(leafs,err);
xlabel('Min Leaf Size');
ylabel('cross-validated error');
The best leaf size is between about 20 and 50 observations per leaf.
Compare the near-optimal tree with at least 40 observations per leaf with the default tree, which uses 10
observations per parent node and 1 observation per leaf.
DefaultTree = fitctree(X,Y);
view(DefaultTree,'Mode','Graph')
OptimalTree = fitctree(X,Y,'MinLeafSize',40);
view(OptimalTree,'mode','graph')
resubOpt = resubLoss(OptimalTree);
lossOpt = kfoldLoss(crossval(OptimalTree));
resubDefault = resubLoss(DefaultTree);
lossDefault = kfoldLoss(crossval(DefaultTree));
resubOpt,resubDefault,lossOpt,lossDefault
resubOpt = 0.0883
resubDefault = 0.0114
lossOpt = 0.1054
lossDefault = 0.1111
The near-optimal tree is much smaller and gives a much higher resubstitution error. Yet, it gives similar accuracy for cross-validated data.
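One way to quantify "much smaller" is the NumNodes property of each tree. This sketch rebuilds both trees so that it runs on its own.

```matlab
% Sketch: compare tree sizes through the NumNodes property.
load ionosphere
DefaultTree = fitctree(X,Y);
OptimalTree = fitctree(X,Y,'MinLeafSize',40);
nodeCounts = [DefaultTree.NumNodes OptimalTree.NumNodes]   % [default, near-optimal]
```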
Pruning optimizes tree depth (leafiness) by merging leaves on the same tree branch. Control Depth or "Leafiness" describes one method for selecting the optimal depth for a tree. Unlike in that section, you do not need to grow a new tree for every node size. Instead, grow a deep tree, and prune it to the level you choose.
Prune a tree at the command line using the prune method (classification) or prune method (regression). Alternatively,
prune a tree interactively with the tree viewer:
view(tree,'mode','graph')
To prune a tree, the tree must contain a pruning sequence. By
default, both fitctree and fitrtree calculate a pruning sequence for
a tree during construction. If you construct a tree with the 'Prune'
name-value pair set to 'off', or if you prune a tree to a
smaller level, the tree does not contain the full pruning sequence.
Generate the full pruning sequence with the prune method (classification) or prune method (regression).
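As a sketch of that workflow, the following grows a tree with 'Prune' set to 'off', calls prune with no level argument to fill in the pruning sequence, and then reads the available levels from the PruneList property.

```matlab
% Sketch: grow a tree without a pruning sequence, then generate one.
load ionosphere
unprunedTree = fitctree(X,Y,'Prune','off');   % no pruning sequence computed
fullTree = prune(unprunedTree);               % fill in the optimal pruning sequence
maxLevel = max(fullTree.PruneList)            % deepest available pruning level
smallerTree = prune(fullTree,'Level',1);      % prune one level from the full tree
```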
Prune a Classification Tree
This example creates a classification tree for the ionosphere data, and prunes it to a good level.
Load the ionosphere data:
load ionosphere
Construct a default classification tree for the data:
tree = fitctree(X,Y);
View the tree in the interactive viewer:
view(tree,'Mode','Graph')
Find the optimal pruning level by minimizing cross-validated loss:
[~,~,~,bestlevel] = cvLoss(tree,...
    'SubTrees','All','TreeSize','min')
bestlevel = 6
Prune the tree to level 6:
view(tree,'Mode','Graph','Prune',6)
Alternatively, use the interactive window to prune the tree.
The pruned tree is the same as the near-optimal tree in the "Select Appropriate Tree Depth" example.
Set 'TreeSize' to 'SE' (default) to find the maximal pruning level for which the tree error does not exceed the error from the best level plus one standard error:
[~,~,~,bestlevel] = cvLoss(tree,'SubTrees','All')
bestlevel = 6
In this case the level is the same for either setting of 'TreeSize'.
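The difference between the two 'TreeSize' settings can be illustrated with made-up numbers. In this base-MATLAB sketch, err and se are hypothetical stand-ins for the loss and its standard error at each pruning level (index j corresponds to level j-1), in the spirit of what cvLoss returns; none of the values come from the ionosphere example.

```matlab
% Minimal sketch of the 'min' and 'SE' rules with hypothetical results.
err = [0.110 0.108 0.105 0.104 0.107 0.112 0.130];  % loss per pruning level
se  = [0.016 0.016 0.016 0.016 0.017 0.017 0.018];  % its standard error
levels = 0:numel(err)-1;                            % level 0 is the full tree

[minErr,iMin] = min(err);               % 'min' rule: level with the smallest loss
minLevel = levels(iMin)

threshold = minErr + se(iMin);          % best loss plus one standard error
seLevel = max(levels(err <= threshold)) % 'SE' rule: most-pruned level within threshold
```

Here the 'SE' rule accepts a slightly higher loss in exchange for a more heavily pruned (smaller) tree.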
Prune the tree to use it for other purposes:
tree = prune(tree,'Level',6); view(tree,'Mode','Graph')
The ClassificationTree and RegressionTree classes
were released in MATLAB® R2011a. Previously, you represented both
classification trees and regression trees with a classregtree object.
The new classes provide all the functionality of the classregtree class,
and are more convenient when used with Ensemble Methods.
Statistics and Machine Learning Toolbox software maintains classregtree and its
predecessors treefit, treedisp, treeval, treeprune, and treetest for
backward compatibility. These functions will be removed in a future
release.
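For readers migrating code, the following is a sketch of how the classregtree classification example below might be written with fitctree. The mapping ('PredictorNames' for 'Names', predict for functional evaluation, the CutPredictor and CutType properties for cutvar and cuttype) is an interpretation of equivalent functionality, not an official conversion table. The node numbering matches the fitctree tree shown earlier for the same data.

```matlab
% Sketch: a fitctree counterpart of the classregtree classification example.
load fisheriris
Mdl = fitctree(meas,species,'PredictorNames',{'SL' 'SW' 'PL' 'PW'});
view(Mdl)                                  % text description of the tree

% predict plays the role of functional evaluation t(X).
predicted = predict(Mdl,[NaN NaN 4.8 1.6])

% Node properties play the roles of cutvar and cuttype.
var6 = Mdl.CutPredictor{6}
type6 = Mdl.CutType{6}
```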
This example uses Fisher's iris data in fisheriris.mat to create a classification tree for predicting species using measurements of sepal length, sepal width, petal length, and petal width as predictors. Here, the predictors are continuous and the response is categorical.
Load the data and use the classregtree constructor of the classregtree class to create the classification tree.
load fisheriris
t = classregtree(meas,species,...
    'Names',{'SL' 'SW' 'PL' 'PW'})
t = 
Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor
7  class = virginica
8  class = versicolor
9  class = virginica
t is a classregtree object and can be operated on with any class method.
Use the type method of the classregtree class to show the type of the tree.
treetype = type(t)
treetype = classification
classregtree creates a classification tree because species is a cell array of character vectors, and the response is assumed to be categorical.
To view the tree, use the view method of the classregtree class.
view(t)
The tree predicts the response values at the circular leaf nodes based on a series of questions about the iris at the triangular branching nodes. A true
answer to any question follows the branch to the left. A false answer follows the branch to the right.
The tree does not use sepal measurements for predicting species. These can go unmeasured in new data, and you can enter them as NaN
values for predictions. For example, use the tree to predict the species of an iris with petal length 4.8 and petal width 1.6.
predicted = t([NaN NaN 4.8 1.6])
predicted =
  1x1 cell array
    {'versicolor'}
The object allows for functional evaluation, of the form t(X). This is a shorthand way of calling the eval method of the classregtree class. The predicted species is the left leaf node at the bottom of the tree in the previous view.
You can use a variety of methods of the classregtree class, such as cutvar and cuttype, to get more information about the split at node 6
that makes the final distinction between versicolor and virginica.
var6 = cutvar(t,6) % What variable determines the split?
type6 = cuttype(t,6) % What type of split is it?
var6 =
  1x1 cell array
    {'PW'}
type6 =
  1x1 cell array
    {'continuous'}
Classification trees fit the original (training) data well, but can do a poor job of classifying new values. Lower branches, especially, can be strongly affected by outliers. A simpler tree often avoids overfitting. You can use the prune method of the classregtree class to find the next largest tree from an optimal pruning sequence.
pruned = prune(t,'Level',1)
view(pruned)
pruned = 
Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5  class = virginica
6  class = versicolor
7  class = virginica
To find the best classification tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.
This example uses the data on cars in carsmall.mat to create a regression tree for predicting mileage using measurements of weight and the number of cylinders as predictors. Here, one predictor (weight) is continuous and the other (cylinders) is categorical. The response (mileage) is continuous.
Load the data and use the classregtree constructor of the classregtree class to create the regression tree:
load carsmall
t = classregtree([Weight, Cylinders],MPG,...
    'Categorical',2,'MinParent',20,...
    'Names',{'W','C'})
t = 
Decision tree for regression
1  if W<3085.5 then node 2 elseif W>=3085.5 then node 3 else 23.7181
2  if W<2371 then node 4 elseif W>=2371 then node 5 else 28.7931
3  if C=8 then node 6 elseif C in {4 6} then node 7 else 15.5417
4  if W<2162 then node 8 elseif W>=2162 then node 9 else 32.0741
5  if C=6 then node 10 elseif C=4 then node 11 else 25.9355
6  if W<4381 then node 12 elseif W>=4381 then node 13 else 14.2963
7  fit = 19.2778
8  fit = 33.3056
9  fit = 29.6111
10 fit = 23.25
11 if W<2827.5 then node 14 elseif W>=2827.5 then node 15 else 27.2143
12 if W<3533.5 then node 16 elseif W>=3533.5 then node 17 else 14.8696
13 fit = 11
14 fit = 27.6389
15 fit = 24.6667
16 fit = 16.6
17 fit = 14.3889
t is a classregtree object and can be operated on with any of the methods of the class.
Use the type method of the classregtree class to show the type of the tree:
treetype = type(t)
treetype = regression
classregtree creates a regression tree because MPG is a numerical vector, and the response is assumed to be continuous.
To view the tree, use the view method of the classregtree class:
view(t)
The tree predicts the response values at the circular leaf nodes based on a series of questions about the car at the triangular branching nodes. A true
answer to any question follows the branch to the left; a false answer follows the branch to the right.
Use the tree to predict the mileage for a 2000-pound car with either 4, 6, or 8 cylinders:
mileage2K = t([2000 4; 2000 6; 2000 8])
mileage2K =
   33.3056
   33.3056
   33.3056
The object allows for functional evaluation, of the form t(X). This is a shorthand way of calling the eval method of the classregtree class.
The predicted responses computed above are all the same. This is because they follow a series of splits in the tree that depend only on weight, terminating at the leftmost leaf node in the view above. A 4000-pound car, following the right branch from the top of the tree, leads to different predicted responses:
mileage4K = t([4000 4; 4000 6; 4000 8])
mileage4K =
   19.2778
   19.2778
   14.3889
You can use a variety of other methods of the classregtree class, such as cutvar, cuttype, and cutcategories, to get more information about the split at node 3 that distinguishes the 8-cylinder car:
var3 = cutvar(t,3) % What variable determines the split?
type3 = cuttype(t,3) % What type of split is it?
c = cutcategories(t,3); % Which classes are sent to the left
                        % child node, and which to the right?
leftChildNode = c{1}
rightChildNode = c{2}
var3 =
  1x1 cell array
    {'C'}
type3 =
  1x1 cell array
    {'categorical'}
leftChildNode =
     8
rightChildNode =
     4     6
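A possible fitrtree counterpart of this regression example is sketched below, mapping 'Categorical' to 'CategoricalPredictors', 'MinParent' to 'MinParentSize', 'Names' to 'PredictorNames', functional evaluation to predict, and cutcategories to the CutCategories property. This mapping is an interpretation, and the resulting tree is not guaranteed to reproduce the classregtree output above node for node, so no specific predictions are claimed here.

```matlab
% Sketch: a fitrtree counterpart of the classregtree regression example.
load carsmall
Mdl = fitrtree([Weight Cylinders],MPG,'CategoricalPredictors',2,...
    'MinParentSize',20,'PredictorNames',{'W','C'});

% predict plays the role of functional evaluation t(X).
mileage2K = predict(Mdl,[2000 4; 2000 6; 2000 8])
mileage4K = predict(Mdl,[4000 4; 4000 6; 4000 8])

% Row j of CutCategories holds the categories sent to the left and
% right children of node j (empty for continuous or leaf nodes).
cats = Mdl.CutCategories;
```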
Regression trees fit the original (training) data well, but may do a poor job of predicting new values. Lower branches, especially, may be strongly affected by outliers. A simpler tree often avoids overfitting. To find the best regression tree, employing the techniques of resubstitution and cross validation, use the test method of the classregtree class.