Decision trees, or classification trees and regression trees,
predict responses to data. To predict a response, follow the decisions
in the tree from the root (beginning) node down to a leaf node. The
leaf node contains the response. Classification trees give responses
that are nominal, such as `'true'` or `'false'`.
Regression trees give numeric responses.

Statistics and Machine Learning Toolbox™ trees are binary. Each step in a prediction involves checking the value of one predictor (variable). For example, here is a simple classification tree:

This tree predicts classifications based on two predictors, `x1` and `x2`.
To predict, start at the top node, represented by a triangle (Δ).
The first decision is whether `x1` is smaller than `0.5`.
If so, follow the left branch, and see that the tree classifies the
data as type `0`.

If, however, `x1` exceeds `0.5`,
then follow the right branch to the lower-right triangle node. Here
the tree asks if `x2` is smaller than `0.5`.
If so, then follow the left branch to see that the tree classifies
the data as type `0`. If not, then follow the right
branch to see that the tree classifies the data as type `1`.
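The decision path just described can be written as ordinary control flow. The sketch below hard-codes the example tree's two thresholds in base MATLAB, without calling `fitctree`; the three test points are made up so that each branch is exercised once.

```matlab
% Hand-coded version of the two-predictor example tree (not produced by fitctree).
% Rows of X are observations; column 1 is x1, column 2 is x2.
X = [0.2 0.9;    % goes left at the root
     0.7 0.3;    % goes right at the root, then left at the second node
     0.9 0.8];   % goes right at both nodes
labels = zeros(size(X,1), 1);
for i = 1:size(X,1)
    if X(i,1) < 0.5          % root node: x1 < 0.5 follows the left branch
        labels(i) = 0;
    elseif X(i,2) < 0.5      % lower-right node: x2 < 0.5 follows the left branch
        labels(i) = 0;
    else
        labels(i) = 1;       % right branch at both nodes
    end
end
labels
```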

To learn how to prepare your data for classification or regression using decision trees, see Steps in Supervised Learning.

This example shows how to train a classification tree.

Create a classification tree using the entire `ionosphere` data set.

```
load ionosphere % Contains X and Y variables
Mdl = fitctree(X,Y)
```

```
Mdl =
  ClassificationTree
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: {'b'  'g'}
           ScoreTransform: 'none'
          NumObservations: 351
```

This example shows how to train a regression tree.

Create a regression tree using all observations in the `carsmall` data set. Consider the `Horsepower` and `Weight` vectors as predictor variables, and the `MPG` vector as the response.

```
load carsmall % Contains Horsepower, Weight, MPG
X = [Horsepower Weight];
Mdl = fitrtree(X,MPG)
```

```
Mdl =
  RegressionTree
             ResponseName: 'Y'
    CategoricalPredictors: []
        ResponseTransform: 'none'
          NumObservations: 94
```

This example shows how to view a classification or a regression tree. There are two ways to view a tree: `view(tree)` returns a text description and `view(tree,'mode','graph')` returns a graphic description of the tree.

Create and view a classification tree.

```
load fisheriris                 % load the sample data
ctree = fitctree(meas,species); % create classification tree
view(ctree)                     % text description
```

```
Decision tree for classification
1 if x3<2.45 then node 2 elseif x3>=2.45 then node 3 else setosa
2 class = setosa
3 if x4<1.75 then node 4 elseif x4>=1.75 then node 5 else versicolor
4 if x3<4.95 then node 6 elseif x3>=4.95 then node 7 else versicolor
5 class = virginica
6 if x4<1.65 then node 8 elseif x4>=1.65 then node 9 else versicolor
7 class = virginica
8 class = versicolor
9 class = virginica
```

```
view(ctree,'mode','graph') % graphic description
```

Now, create and view a regression tree.

```
load carsmall % load the sample data, contains Horsepower, Weight, MPG
X = [Horsepower Weight];
rtree = fitrtree(X,MPG,'MinParentSize',30); % create regression tree
view(rtree) % text description
```

```
Decision tree for regression
1 if x2<3085.5 then node 2 elseif x2>=3085.5 then node 3 else 23.7181
2 if x1<89 then node 4 elseif x1>=89 then node 5 else 28.7931
3 if x1<115 then node 6 elseif x1>=115 then node 7 else 15.5417
4 if x2<2162 then node 8 elseif x2>=2162 then node 9 else 30.9375
5 fit = 24.0882
6 fit = 19.625
7 fit = 14.375
8 fit = 33.3056
9 fit = 29
```

```
view(rtree,'mode','graph') % graphic description
```

The `fitctree` and `fitrtree` functions perform the following
steps to create decision trees:

1. Start with all input data, and examine all possible binary splits on every predictor.
2. Select a split with the best optimization criterion.
   - A split might lead to a child node having too few observations (less than the `MinLeafSize` parameter). To avoid this, the software chooses a split that yields the best optimization criterion subject to the `MinLeafSize` constraint.
3. Impose the split.
4. Repeat recursively for the two child nodes.
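Steps 1 and 2 can be sketched for a single continuous predictor. This is an illustration under simplifying assumptions, not the toolbox implementation: the data are made up, only one predictor is searched, candidate cuts sit halfway between adjacent unique values, and the criterion is the total squared error of the two children (for a fixed node, this ranks splits the same way as their observation-weighted MSE). The variable `minLeafSize` is a hypothetical stand-in for the `MinLeafSize` constraint.

```matlab
% Exhaustive search for the best binary split of ONE continuous predictor
% for a regression response. All names and data here are illustrative.
x = [1 2 3 4 5 6 7 8]';
y = [5 6 5 6 20 21 20 22]';
minLeafSize = 2;

u = unique(x);                        % sorted unique predictor values
cuts = (u(1:end-1) + u(2:end)) / 2;   % candidate cuts halfway between neighbors
bestCut = NaN;
bestSSE = Inf;
for c = cuts'
    left = x < c;
    if sum(left) < minLeafSize || sum(~left) < minLeafSize
        continue                      % reject splits that violate the leaf-size constraint
    end
    sse = sum((y(left)  - mean(y(left))).^2) + ...
          sum((y(~left) - mean(y(~left))).^2);
    if sse < bestSSE                  % keep the split with the lowest child error
        bestSSE = sse;
        bestCut = c;
    end
end
bestCut    % 4.5 separates the low and high responses
```

Growing a full tree would repeat this search over every predictor and then recurse into the two child nodes, as in step 4.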

The explanation requires two more items: a description of the optimization criterion and the stopping rule.

**Stopping rule:** Stop splitting
when any of the following hold:

- The node is *pure*.
  - For classification, a node is pure if it contains only observations of one class.
  - For regression, a node is pure if the mean squared error (MSE) for the observed response in this node drops below the MSE for the observed response in the entire data multiplied by the tolerance on quadratic error per node (`QuadraticErrorTolerance` parameter).
- There are fewer than `MinParentSize` observations in this node.
- Any split imposed on this node produces children with fewer than `MinLeafSize` observations.
- The algorithm splits `MaxNumSplits` nodes.
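The stopping rule can be read as one Boolean check per node. In this sketch the data and settings are hypothetical: `qeTol` stands in for `QuadraticErrorTolerance`, the `MinLeafSize` condition is approximated by the sufficient test `nNode < 2*minLeafSize`, and `numSplitsSoFar` is an assumed running count of splits already made in the tree.

```matlab
% Hypothetical regression node and settings (not toolbox defaults unless noted).
yAll  = [10 12 14 30 32 34 31 33]';   % responses in the whole training set
yNode = [31 31 31]';                  % responses that reach the current node
qeTol = 1e-6;                         % stand-in for QuadraticErrorTolerance
minParentSize = 10;                   % MinParentSize default
minLeafSize = 1;                      % MinLeafSize default
numSplitsSoFar = 4;                   % assumed splits already made in the tree
maxNumSplits = 7;

mseAll  = mean((yAll  - mean(yAll)).^2);
mseNode = mean((yNode - mean(yNode)).^2);
nNode = numel(yNode);

isPure      = mseNode < qeTol * mseAll;      % regression purity rule
tooSmall    = nNode < minParentSize;         % MinParentSize rule
cannotSplit = nNode < 2*minLeafSize;         % sufficient (not necessary) for the MinLeafSize rule
budgetUsed  = numSplitsSoFar >= maxNumSplits; % MaxNumSplits rule
stop = isPure || tooSmall || cannotSplit || budgetUsed
```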

**Optimization criterion:**

- Regression: mean-squared error (MSE). Choose a split to minimize the MSE of predictions compared to the training data.
- Classification: One of three measures, depending on the setting of the `SplitCriterion` name-value pair:
  - `'gdi'` (Gini's diversity index, the default)
  - `'twoing'`
  - `'deviance'`

  For details, see `ClassificationTree` Definitions.
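Gini's diversity index of a node is 1 minus the sum of the squared class proportions in that node, so a pure node scores 0. The sketch below computes it for a made-up parent node and one candidate split, weighting the two children by their observation counts (an assumption that matches unweighted data); a useful split lowers the weighted impurity.

```matlab
% Class labels coded as positive integers; the data are made up.
parentY = [1 1 1 1 2 2 2 2]';
leftY   = [1 1 1 2]';
rightY  = [1 2 2 2]';

% Gini's diversity index: 1 - sum(p.^2), where p holds the class proportions.
gini = @(y) 1 - sum((accumarray(y, 1) / numel(y)).^2);

giniParent = gini(parentY)                                    % 0.5, an evenly mixed node
nL = numel(leftY);
nR = numel(rightY);
giniSplit = (nL*gini(leftY) + nR*gini(rightY)) / (nL + nR)    % 0.375
impurityDrop = giniParent - giniSplit                         % 0.125
```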

For a continuous predictor, a tree can split halfway between
any two adjacent unique values found for this predictor. For a categorical
predictor with *L* levels, a classification tree
needs to consider $2^{L-1}-1$
splits to find the optimal split. Alternatively, you can choose a
heuristic algorithm to find a good split, as described in Splitting Categorical Predictors.
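One way to see where the count $2^{L-1}-1$ comes from is to enumerate the splits. Fix level 1 in the left child, so that swapping the left and right children is not counted twice, then let any subset of the remaining $L-1$ levels join it, excluding the subset that would leave the right child empty. The loop below does this for a hypothetical predictor with four levels.

```matlab
L = 4;                         % number of levels of a hypothetical categorical predictor
nSplits = 2^(L-1) - 1          % count from the formula: 7 when L = 4

% Each mask chooses which of levels 2..L join level 1 in the left child.
% The all-ones mask (2^(L-1)-1) would empty the right child, so it is excluded.
splits = cell(1, 2^(L-1) - 1);
for mask = 0:2^(L-1) - 2
    joined = find(bitget(mask, 1:L-1)) + 1;   % levels 2..L placed on the left
    splits{mask + 1} = [1 joined];            % levels sent to the left child
end
numel(splits)                  % agrees with nSplits
```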

For dual-core systems and above, `fitctree` and `fitrtree` parallelize training decision
trees using Intel® Threading Building Blocks (TBB). For details
on Intel TBB, see https://software.intel.com/en-us/intel-tbb.

This example shows how to predict class labels or responses using trained classification and regression trees.

After creating a tree, you can easily predict responses for new data. Suppose `Xnew` is new data that has the same number of columns as the original data `X`. To predict the classification or regression based on the tree (`Mdl`) and the new data, enter

`Ynew = predict(Mdl,Xnew)`

For each row of data in `Xnew`, `predict` runs through the decisions in `Mdl` and gives the resulting prediction in the corresponding element of `Ynew`. For more information on classification tree prediction, see `CompactClassificationTree.predict`. For regression, see `CompactRegressionTree.predict`.

For example, find the predicted classification of a point at the mean of the `ionosphere` data.

```
load ionosphere
CMdl = fitctree(X,Y);
Ynew = predict(CMdl,mean(X))
```

Ynew = 'g'

Find the predicted `MPG` of a point at the mean of the `carsmall` data.

```
load carsmall
X = [Horsepower Weight];
RMdl = fitrtree(X,MPG);
Ynew = predict(RMdl,mean(X))
```

Ynew = 28.7931

This example shows how to predict out-of-sample responses of regression trees, and then plot the results.

Load the `carsmall` data set. Consider `Weight` as a predictor of the response `MPG`.

```
load carsmall
idxNaN = isnan(MPG + Weight);
X = Weight(~idxNaN);
Y = MPG(~idxNaN);
n = numel(X);
```

Partition the data into training (50%) and validation (50%) sets.

```
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false;                  % Validation set logical indices
```

Grow a regression tree using the training observations.

```
Mdl = fitrtree(X(idxTrn),Y(idxTrn));
view(Mdl,'Mode','graph')
```

Compute fitted values of the validation observations for each of several subtrees.

```
m = max(Mdl.PruneList);
pruneLevels = 0:2:m; % Pruning levels to consider
z = numel(pruneLevels);
Yfit = predict(Mdl,X(idxVal),'SubTrees',pruneLevels);
```

`Yfit` is a matrix of fitted values with one row per validation observation (`sum(idxVal)` rows) and one column per subtree (`z` columns).

Plot `Yfit` and `Y` against `X`.

```
figure;
sortDat = sortrows([X(idxVal) Y(idxVal) Yfit],1); % Sort all data with respect to X
plot(sortDat(:,1),sortDat(:,2),'*');
hold on;
plot(repmat(sortDat(:,1),1,size(Yfit,2)),sortDat(:,3:end));
lev = cellstr(num2str((pruneLevels)','Level %d MPG'));
legend(['Observed MPG'; lev])
title 'Out-of-Sample Predictions'
xlabel 'Weight (lbs)';
ylabel 'MPG';
h = findobj(gcf);
axis tight;
set(h(4:end),'LineWidth',3) % Widen all lines
```

The values of `Yfit` for lower pruning levels tend to follow the data more closely than higher levels. Higher pruning levels tend to be flat for large `X` intervals.

You can tune trees by setting name-value pairs in `fitctree` and `fitrtree`.
The remainder of this section describes how to determine the quality
of a tree, how to decide which name-value pairs to set, and how to
control the size of a tree.

*Resubstitution error* is the difference
between the response training data and the predictions the tree makes
of the response based on the input training data. If the resubstitution
error is high, you cannot expect the predictions of the tree to be
good. However, having low resubstitution error does not guarantee
good predictions for new data. Resubstitution error is often an overly
optimistic estimate of the predictive error on new data.
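As a minimal sketch of the definition, the resubstitution error can be computed by hand from the training responses and the tree's predictions on the training inputs. The vectors below are made up. The sketch treats classification error as the unweighted fraction misclassified and regression error as the MSE; `resubLoss` itself can also account for observation weights, priors, and other loss functions.

```matlab
% Made-up training responses and a tree's predictions for the same training inputs.
yTrain = {'a'; 'b'; 'b'; 'a'; 'b'};
yFit   = {'a'; 'b'; 'a'; 'a'; 'b'};
resubClassErr = mean(~strcmp(yTrain, yFit))   % 1 of 5 misclassified: 0.2

rTrain = [10; 12; 15];
rFit   = [11; 12; 13];
resubMSE = mean((rTrain - rFit).^2)           % (1 + 0 + 4)/3
```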

**Classification Tree Resubstitution Error**

This example shows how to examine the resubstitution error of a classification tree.

Load Fisher's iris data.

```
load fisheriris
```

Train a default classification tree using the entire data set.

```
Mdl = fitctree(meas,species);
```

Examine the resubstitution error.

```
resuberror = resubLoss(Mdl)
```

resuberror = 0.0200

The tree classifies nearly all the Fisher iris data correctly.

To get a better sense of the predictive accuracy of your tree for new data, cross validate the tree. By default, cross validation splits the training data into 10 parts at random. It trains 10 new trees, each one on nine parts of the data. It then examines the predictive accuracy of each new tree on the data not included in training that tree. This method gives a good estimate of the predictive accuracy of the resulting tree, since it tests the new trees on new data.
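The fold bookkeeping of k-fold cross validation can be sketched without a tree. In the code below the data are hypothetical and the "model" is deliberately trivial (it predicts the mean response of the training folds); substituting a tree gives the procedure described above. This illustrates the idea and is not how `crossval` is implemented.

```matlab
% Hypothetical data and a trivial stand-in model (the training-fold mean).
y = (1:50)';
k = 10;
n = numel(y);
fold = mod(randperm(n)', k) + 1;          % random fold labels 1..k, 5 observations each
sqErr = zeros(n, 1);
for f = 1:k
    inFold = (fold == f);
    model = mean(y(~inFold));             % "train" on the other k-1 folds
    sqErr(inFold) = (y(inFold) - model).^2;   % score only the held-out observations
end
cvMSE = mean(sqErr)                       % varies with the random fold assignment
```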

**Cross Validate a Regression Tree**

This example shows how to examine the resubstitution and cross-validation accuracy of a regression tree for predicting mileage based on the `carsmall` data.

Load the `carsmall` data set. Consider acceleration, displacement, horsepower, and weight as predictors of MPG.

```
load carsmall
X = [Acceleration Displacement Horsepower Weight];
```

Grow a regression tree using all of the observations.

```
rtree = fitrtree(X,MPG);
```

Compute the in-sample error.

```
resuberror = resubLoss(rtree)
```

resuberror = 4.7188

The resubstitution loss for a regression tree is the mean-squared error. The resulting value indicates that a typical predictive error for the tree is about the square root of 4.7, or a bit over 2.

Estimate the cross-validation MSE.

```
rng 'default';
cvrtree = crossval(rtree);
cvloss = kfoldLoss(cvrtree)
```

cvloss = 23.8065

The cross-validated loss is almost 24, meaning a typical predictive error for the tree on new data is about 5. This illustrates that cross-validated loss is usually higher than simple resubstitution loss.
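The "typical predictive error" figures in this example are the square roots of the reported MSE values, which puts them back in the units of the response (MPG):

```matlab
resubMSE = 4.7188;                    % resubLoss(rtree) reported above
cvMSE = 23.8065;                      % kfoldLoss(cvrtree) reported above
typicalErr = sqrt([resubMSE cvMSE])   % roughly 2.17 and 4.88 MPG
```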

When you grow a decision tree, consider its simplicity and predictive power. A deep tree with many leaves is usually highly accurate on the training data. However, the tree is not guaranteed to show a comparable accuracy on an independent test set. A leafy tree tends to overtrain (or overfit), and its test accuracy is often far less than its training (resubstitution) accuracy. In contrast, a shallow tree does not attain high training accuracy. But a shallow tree can be more robust — its training accuracy could be close to that of a representative test set. Also, a shallow tree is easy to interpret. If you do not have enough data for training and test, estimate tree accuracy by cross validation.

`fitctree` and `fitrtree` have three name-value pair arguments
that control the depth of resulting decision trees:

- `MaxNumSplits`: The maximal number of branch node splits is `MaxNumSplits` per tree. Set a large value for `MaxNumSplits` to get a deep tree. The default is `size(X,1) - 1`.
- `MinLeafSize`: Each leaf has at least `MinLeafSize` observations. Set small values of `MinLeafSize` to get deep trees. The default is `1`.
- `MinParentSize`: Each branch node in the tree has at least `MinParentSize` observations. Set small values of `MinParentSize` to get deep trees. The default is `10`.

If you specify `MinParentSize` and `MinLeafSize`,
the learner uses the setting that yields trees with larger leaves
(i.e., shallower trees):

`MinParent = max(MinParentSize,2*MinLeafSize)`
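The factor of 2 reflects that a node with fewer than `2*MinLeafSize` observations cannot be split into two children that each have at least `MinLeafSize` observations. The arithmetic for a few hypothetical `MinLeafSize` values, with `MinParentSize` at its default of `10`:

```matlab
MinParentSize = 10;                 % default value
MinLeafSize = [1 3 5 8];            % hypothetical candidate settings
effectiveMinParent = max(MinParentSize, 2*MinLeafSize)   % [10 10 10 16]
```

Only the last setting is large enough to override `MinParentSize`.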

If you supply `MaxNumSplits`, the software
splits a tree until one of the three splitting criteria is satisfied.

For an alternative method of controlling the tree depth, see Pruning.

**Select Appropriate Tree Depth**

This example shows how to control the depth of a decision tree, and how to choose an appropriate depth.

Load the `ionosphere` data.

```
load ionosphere
```

Generate an exponentially spaced set of values from `10` through `100` that represent the minimum number of observations per leaf node.

```
leafs = logspace(1,2,10);
```

Create cross-validated classification trees for the `ionosphere` data. Specify to grow each tree using a minimum leaf size in `leafs`.

```
rng('default')
N = numel(leafs);
err = zeros(N,1);
for n=1:N
    t = fitctree(X,Y,'CrossVal','On',...
        'MinLeafSize',leafs(n));
    err(n) = kfoldLoss(t);
end
plot(leafs,err);
xlabel('Min Leaf Size');
ylabel('cross-validated error');
```

The best leaf size is between about `20` and `50` observations per leaf.

Compare the near-optimal tree with at least `40` observations per leaf with the default tree, which uses `10` observations per parent node and `1` observation per leaf.

```
DefaultTree = fitctree(X,Y);
view(DefaultTree,'Mode','Graph')
OptimalTree = fitctree(X,Y,'MinLeafSize',40);
view(OptimalTree,'mode','graph')
```

```
resubOpt = resubLoss(OptimalTree);
lossOpt = kfoldLoss(crossval(OptimalTree));
resubDefault = resubLoss(DefaultTree);
lossDefault = kfoldLoss(crossval(DefaultTree));
resubOpt,resubDefault,lossOpt,lossDefault
```

```
resubOpt = 0.0883
resubDefault = 0.0114
lossOpt = 0.1054
lossDefault = 0.1111
```

The near-optimal tree is much smaller and gives a much higher resubstitution error. Yet, it gives similar accuracy for cross-validated data.

Pruning optimizes tree depth (leafiness) by merging leaves on the same tree branch. Control Depth or "Leafiness" describes one method for selecting the optimal depth for a tree. Unlike in that section, you do not need to grow a new tree for every node size. Instead, grow a deep tree, and prune it to the level you choose.

Prune a tree at the command line using the `prune` method
(classification) or `prune` method (regression). Alternatively,
prune a tree interactively with the tree viewer:

```
view(tree,'mode','graph')
```

To prune a tree, the tree must contain a pruning sequence. By
default, both `fitctree` and `fitrtree` calculate a pruning sequence for
a tree during construction. If you construct a tree with the `'Prune'` name-value
pair set to `'off'`, or if you prune a tree to a
smaller level, the tree does not contain the full pruning sequence.
Generate the full pruning sequence with the `prune` method
(classification) or `prune` method (regression).

**Prune a Classification Tree**

This example creates a classification tree for the `ionosphere` data, and prunes it to a good level.

Load the `ionosphere` data:

```
load ionosphere
```

Construct a default classification tree for the data:

```
tree = fitctree(X,Y);
```

View the tree in the interactive viewer:

```
view(tree,'Mode','Graph')
```

Find the optimal pruning level by minimizing cross-validated loss:

```
[~,~,~,bestlevel] = cvLoss(tree,...
    'SubTrees','All','TreeSize','min')
```

bestlevel = 6

Prune the tree to level `6`:

```
view(tree,'Mode','Graph','Prune',6)
```

Alternatively, use the interactive window to prune the tree.

The pruned tree is the same as the near-optimal tree in the "Select Appropriate Tree Depth" example.

Set `'TreeSize'` to `'SE'` (default) to find the maximal pruning level for which the tree error does not exceed the error from the best level plus one standard error:

```
[~,~,~,bestlevel] = cvLoss(tree,'SubTrees','All')
```

bestlevel = 6

In this case the level is the same for either setting of `'TreeSize'`.
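The difference between the `'min'` and `'SE'` settings can be sketched with made-up numbers. The code assumes the usual one-standard-error rule: take the smallest cross-validated loss, add the standard error at that level, and pick the largest pruning level (the smallest tree) whose loss stays within that threshold. The `cvErr` and `se` arrays are hypothetical, not output from `cvLoss`.

```matlab
% Made-up cross-validated losses and standard errors, one per pruning level.
% Level 0 is the full tree; higher levels are smaller trees.
level = 0:5;
cvErr = [0.120 0.112 0.108 0.110 0.125 0.160];
se    = [0.017 0.016 0.016 0.017 0.018 0.020];

[minErr, iMin] = min(cvErr);
bestMin = level(iMin)                       % 'min' rule: level 2
threshold = minErr + se(iMin);              % best loss plus one standard error
bestSE = max(level(cvErr <= threshold))     % 'SE' rule: level 3, a smaller tree
```

With these numbers the two settings disagree; with the `ionosphere` losses they happen to agree.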

Prune the tree to use it for other purposes:

```
tree = prune(tree,'Level',6);
view(tree,'Mode','Graph')
```

The `ClassificationTree` and `RegressionTree` classes
were released in MATLAB® R2011a. Previously, you represented both
classification trees and regression trees with a `classregtree` object.
The new classes provide all the functionality of the `classregtree` class,
and are more convenient when used with Ensemble Methods.

Statistics and Machine Learning Toolbox software maintains `classregtree` and its
predecessors `treefit`, `treedisp`, `treeval`, `treeprune`, and `treetest` for
backward compatibility. These functions will be removed in a future
release.

This example uses Fisher's iris data in `fisheriris.mat` to create a classification tree for predicting species using measurements of sepal length, sepal width, petal length, and petal width as predictors. Here, the predictors are continuous and the response is categorical.

Load the data and use the `classregtree` constructor of the `classregtree` class to create the classification tree.

```
load fisheriris
t = classregtree(meas,species,...
    'Names',{'SL' 'SW' 'PL' 'PW'})
```

```
t =
Decision tree for classification
1 if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2 class = setosa
3 if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4 if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5 class = virginica
6 if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor
7 class = virginica
8 class = versicolor
9 class = virginica
```

`t` is a `classregtree` object and can be operated on with any class method.

Use the `type` method of the `classregtree` class to show the type of the tree.

```
treetype = type(t)
```

treetype = classification

`classregtree` creates a classification tree because `species` is a cell array of strings, and the response is assumed to be categorical.

To view the tree, use the `view` method of the `classregtree` class.

```
view(t)
```

The tree predicts the response values at the circular leaf nodes based on a series of questions about the iris at the triangular branching nodes. A `true` answer to any question follows the branch to the left. A `false` answer follows the branch to the right.

The tree does not use sepal measurements for predicting species. These can go unmeasured in new data, and you can enter them as `NaN` values for predictions. For example, use the tree to predict the species of an iris with petal length `4.8` and petal width `1.6`.

```
predicted = t([NaN NaN 4.8 1.6])
```

predicted = 'versicolor'

The object allows for functional evaluation, of the form `t(X)`. This is a shorthand way of calling the `eval` method of the `classregtree` class. The predicted species is the left leaf node at the bottom of the tree in the previous view.
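To see why the sepal measurements can be `NaN`, the printed tree can be transcribed into nested `if` statements. The sketch below is a hand transcription of the text display above, not a call into `classregtree`; only petal length (`PL`) and petal width (`PW`) appear in it.

```matlab
% Hand transcription of the printed classregtree display.
PL = 4.8;    % petal length of the new iris
PW = 1.6;    % petal width of the new iris
if PL < 2.45                 % node 1 -> node 2
    species = 'setosa';
elseif PW >= 1.75            % node 3 -> node 5
    species = 'virginica';
elseif PL >= 4.95            % node 4 -> node 7
    species = 'virginica';
elseif PW < 1.65             % node 6 -> node 8
    species = 'versicolor';
else                         % node 6 -> node 9
    species = 'virginica';
end
species
```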

You can use a variety of methods of the `classregtree` class, such as `cutvar` and `cuttype`, to get more information about the split at node `6` that makes the final distinction between `versicolor` and `virginica`.

```
var6 = cutvar(t,6)   % What variable determines the split?
type6 = cuttype(t,6) % What type of split is it?
```

```
var6 = 'PW'
type6 = 'continuous'
```

Classification trees fit the original (training) data well, but can do a poor job of classifying new values. Lower branches, especially, can be strongly affected by outliers. A simpler tree often avoids overfitting. You can use the `prune` method of the `classregtree` class to find the next largest tree from an optimal pruning sequence.

```
pruned = prune(t,'Level',1)
view(pruned)
```

```
pruned =
Decision tree for classification
1 if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2 class = setosa
3 if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4 if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
5 class = virginica
6 class = versicolor
7 class = virginica
```

To find the best classification tree, employing the techniques of resubstitution and cross validation, use the `test` method of the `classregtree` class.

This example uses the data on cars in `carsmall.mat` to create a regression tree for predicting mileage using measurements of weight and the number of cylinders as predictors. Here, one predictor (weight) is continuous and the other (cylinders) is categorical. The response (mileage) is continuous.

Load the data and use the `classregtree` constructor of the `classregtree` class to create the regression tree:

```
load carsmall
t = classregtree([Weight, Cylinders],MPG,...
    'Categorical',2,'MinParent',20,...
    'Names',{'W','C'})
```

```
t =
Decision tree for regression
1 if W<3085.5 then node 2 elseif W>=3085.5 then node 3 else 23.7181
2 if W<2371 then node 4 elseif W>=2371 then node 5 else 28.7931
3 if C=8 then node 6 elseif C in {4 6} then node 7 else 15.5417
4 if W<2162 then node 8 elseif W>=2162 then node 9 else 32.0741
5 if C=6 then node 10 elseif C=4 then node 11 else 25.9355
6 if W<4381 then node 12 elseif W>=4381 then node 13 else 14.2963
7 fit = 19.2778
8 fit = 33.3056
9 fit = 29.6111
10 fit = 23.25
11 if W<2827.5 then node 14 elseif W>=2827.5 then node 15 else 27.2143
12 if W<3533.5 then node 16 elseif W>=3533.5 then node 17 else 14.8696
13 fit = 11
14 fit = 27.6389
15 fit = 24.6667
16 fit = 16.6
17 fit = 14.3889
```

`t` is a `classregtree` object and can be operated on with any of the methods of the class.

Use the `type` method of the `classregtree` class to show the type of the tree:

```
treetype = type(t)
```

treetype = regression

`classregtree` creates a regression tree because `MPG` is a numerical vector, and the response is assumed to be continuous.

To view the tree, use the `view` method of the `classregtree` class:

```
view(t)
```

The tree predicts the response values at the circular leaf nodes based on a series of questions about the car at the triangular branching nodes. A `true` answer to any question follows the branch to the left; a `false` answer follows the branch to the right.

Use the tree to predict the mileage for a 2000-pound car with either 4, 6, or 8 cylinders:

```
mileage2K = t([2000 4; 2000 6; 2000 8])
```

```
mileage2K =
   33.3056
   33.3056
   33.3056
```

The object allows for functional evaluation, of the form `t(X)`. This is a shorthand way of calling the `eval` method of the `classregtree` class.

The predicted responses computed above are all the same. This is because they follow a series of splits in the tree that depend only on weight, terminating at the leftmost leaf node in the view above. A 4000-pound car, following the right branch from the top of the tree, leads to different predicted responses:

```
mileage4K = t([4000 4; 4000 6; 4000 8])
```

```
mileage4K =
   19.2778
   19.2778
   14.3889
```
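The printed regression tree can also be transcribed by hand to check these predictions. The sketch copies the thresholds and the rounded leaf values from the text display of `t`, and assumes `C` takes only the values 4, 6, or 8; it is an illustration, not a replacement for `t(X)` or `eval`.

```matlab
% Hand transcription of the classregtree display (W = weight, C = cylinders).
queries = [2000 4; 2000 6; 2000 8; 4000 4; 4000 6; 4000 8];
mpg = NaN(size(queries,1), 1);
for i = 1:size(queries,1)
    W = queries(i,1);
    C = queries(i,2);
    if W < 3085.5                         % node 1 -> node 2
        if W < 2371                       % node 2 -> node 4
            if W < 2162
                mpg(i) = 33.3056;         % node 8
            else
                mpg(i) = 29.6111;         % node 9
            end
        elseif C == 6                     % node 5 -> node 10
            mpg(i) = 23.25;
        elseif C == 4                     % node 5 -> node 11
            if W < 2827.5
                mpg(i) = 27.6389;         % node 14
            else
                mpg(i) = 24.6667;         % node 15
            end
        else
            mpg(i) = 25.9355;             % node 5's own value (no branch for C = 8)
        end
    elseif C == 8                         % node 3 -> node 6
        if W < 4381                       % node 6 -> node 12
            if W < 3533.5
                mpg(i) = 16.6;            % node 16
            else
                mpg(i) = 14.3889;         % node 17
            end
        else
            mpg(i) = 11;                  % node 13
        end
    else                                  % node 3 -> node 7, C in {4 6}
        mpg(i) = 19.2778;
    end
end
mpg
```

The 2000-pound queries never reach a split on `C`, which is why all three predictions match.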

You can use a variety of other methods of the `classregtree` class, such as `cutvar`, `cuttype`, and `cutcategories`, to get more information about the split at node 3 that distinguishes the 8-cylinder car:

```
var3 = cutvar(t,3)      % What variable determines the split?
type3 = cuttype(t,3)    % What type of split is it?
c = cutcategories(t,3); % Which classes are sent to the left
                        % child node, and which to the right?
leftChildNode = c{1}
rightChildNode = c{2}
```

```
var3 = 'C'
type3 = 'categorical'
leftChildNode = 8
rightChildNode = 4 6
```

Regression trees fit the original (training) data well, but may do a poor job of predicting new values. Lower branches, especially, may be strongly affected by outliers. A simpler tree often avoids overfitting. To find the best regression tree, employing the techniques of resubstitution and cross validation, use the `test` method of the `classregtree` class.
