Binary decision tree for regression

`tree = fitrtree(x,y,Name,Value)` fits a tree with additional options specified by one or more name-value pair arguments. For example, you can grow a cross-validated tree, hold out a fraction of data for validation, or specify observation weights.

Load the sample data.

```
load carsmall;
```

Construct a regression tree using the sample data.

```
tree = fitrtree([Weight, Cylinders],MPG,...
    'CategoricalPredictors',2,'MinParentSize',20,...
    'PredictorNames',{'W','C'})
```

```
tree = 

  RegressionTree
           PredictorNames: {'W'  'C'}
             ResponseName: 'Y'
        ResponseTransform: 'none'
    CategoricalPredictors: 2
          NumObservations: 94
```

Predict the mileage of 4,000-pound cars with 4, 6, and 8 cylinders.

```
mileage4K = predict(tree,[4000 4; 4000 6; 4000 8])
```

```
mileage4K =

   19.2778
   19.2778
   14.3889
```

You can control the depth of trees using the `MaxNumSplits`, `MinLeafSize`, or `MinParentSize` name-value pair parameters. `fitrtree` grows deep decision trees by default. You can grow shallower trees to reduce model complexity or computation time.

Load the `carsmall` data set. Consider `Displacement`, `Horsepower`, and `Weight` as predictors of the response `MPG`.

```
load carsmall
X = [Displacement Horsepower Weight];
```

The default values of the tree-depth controllers for growing regression trees are:

- `n - 1` for `MaxNumSplits`, where `n` is the training sample size.
- `1` for `MinLeafSize`.
- `10` for `MinParentSize`.

These default values tend to grow deep trees for large training sample sizes.

Train a regression tree using the default values for tree-depth control. Cross validate the model using 10-fold cross validation.

```
rng(1); % For reproducibility
MdlDefault = fitrtree(X,MPG,'CrossVal','on');
```

Draw a histogram of the number of splits imposed on the trees. The number of imposed splits is one less than the number of leaves. Also, view one of the trees.

```
numBranches = @(x)sum(x.IsBranch);
mdlDefaultNumSplits = cellfun(numBranches, MdlDefault.Trained);
figure;
histogram(mdlDefaultNumSplits)
view(MdlDefault.Trained{1},'Mode','graph')
```

The average number of splits is between 14 and 15.

Suppose that you want a regression tree that is not as complex (deep) as the ones trained using the default number of splits. Train another regression tree, but set the maximum number of splits at 7, which is about half the mean number of splits from the default regression tree. Cross validate the model using 10-fold cross validation.

```
Mdl7 = fitrtree(X,MPG,'MaxNumSplits',7,'CrossVal','on');
view(Mdl7.Trained{1},'Mode','graph')
```

Compare the cross validation MSEs of the models.

```
mseDefault = kfoldLoss(MdlDefault)
mse7 = kfoldLoss(Mdl7)
```

```
mseDefault = 27.7277
mse7 = 28.3833
```

`Mdl7` is much less complex and performs only slightly worse than `MdlDefault`.
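`MinLeafSize` limits depth in a similar way. The following sketch is an illustrative comparison, not part of the original example. It assumes the `carsmall` sample data and Statistics and Machine Learning Toolbox. It grows one uncross-validated tree per candidate leaf size and counts the splits through the `IsBranch` property. The leaf sizes chosen here are arbitrary.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Displacement Horsepower Weight];   % same predictors as the example above

leafSizes = [1 5 20];                   % illustrative MinLeafSize values
numSplits = zeros(size(leafSizes));
for k = 1:numel(leafSizes)
    t = fitrtree(X,MPG,'MinLeafSize',leafSizes(k));
    numSplits(k) = sum(t.IsBranch);     % branch nodes = number of splits
end
disp([leafSizes; numSplits])            % row 1: MinLeafSize, row 2: number of splits
```

With 94 usable observations, a minimum of 20 observations per leaf leaves room for at most four leaves. That tree can therefore have at most three splits.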

`x` — Predictor values: matrix of scalar values

Predictor values, specified as a matrix of scalar values. Each column of `x` represents one variable, and each row represents one observation.

`fitrtree` considers `NaN` values in `x` as missing values. `fitrtree` does not use observations with all missing values for `x` in the fit. `fitrtree` uses observations with some missing values for `x` to find splits on variables for which these observations have valid values.

**Data Types:** `single` | `double`

`y` — Response values: vector of scalar values

Response values, specified as a vector of scalar values with the same number of rows as `x`. Each entry in `y` is the response to the data in the corresponding row of `x`.

`fitrtree` considers `NaN` values in `y` to be missing values. `fitrtree` does not use observations with missing values for `y` in the fit.

**Data Types:** `single` | `double`
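Plain logical indexing can mirror the missing-value rules for `x` and `y`. The sketch below uses small hypothetical arrays. It reproduces only the row-exclusion rules stated above and is not `fitrtree`'s internal code.

```matlab
% Hypothetical predictors and responses with missing values.
x = [1 2; NaN NaN; 3 NaN; 4 5];
y = [10; 20; 30; NaN];

dropX = all(isnan(x), 2);        % every predictor missing: row not used
dropY = isnan(y);                % response missing: row not used
usedRows = find(~dropX & ~dropY) % rows that can take part in the fit
```

Row 3 stays in the fit even though its second predictor is missing. It can still help choose splits on the first predictor.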

Specify optional comma-separated pairs of `Name,Value` arguments. `Name` is the argument name and `Value` is the corresponding value. `Name` must appear inside single quotes (`' '`). You can specify several name and value pair arguments in any order as `Name1,Value1,...,NameN,ValueN`.

For example, `'CrossVal','on','MinParentSize',30` specifies a cross-validated regression tree with a minimum of 30 observations per branch node.

`'CategoricalPredictors'` — Categorical predictors list: numeric or logical vector | cell array of strings | character matrix | `'all'`

Categorical predictors list, specified as the comma-separated pair consisting of `'CategoricalPredictors'` and one of the following:

- A numeric vector with indices from `1` through `p`, where `p` is the number of columns of `x`.
- A logical vector of length `p`, where a `true` entry means that the corresponding column of `x` is a categorical variable.
- A cell array of strings, where each element in the array is the name of a predictor variable. The names must match entries in the `PredictorNames` property.
- A character matrix, where each row of the matrix is a name of a predictor variable. Pad the names with extra blanks so each row of the character matrix has the same length.
- `'all'`, meaning all predictors are categorical.

**Data Types:** `single` | `double` | `logical` | `char` | `cell`
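Several of these forms describe the same choice. The sketch below marks the `Cylinders` column as categorical in three ways: an index, a logical mask, and a name. It then checks that the resulting trees agree. It assumes the `carsmall` data and reuses the predictor names `'W'` and `'C'` from the first example.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];
names = {'W','C'};

t1 = fitrtree(X,MPG,'PredictorNames',names,'CategoricalPredictors',2);            % index
t2 = fitrtree(X,MPG,'PredictorNames',names,'CategoricalPredictors',[false true]); % logical mask
t3 = fitrtree(X,MPG,'PredictorNames',names,'CategoricalPredictors',{'C'});        % name

% The property stores column indices, whichever form was used to specify them
disp([t1.CategoricalPredictors t2.CategoricalPredictors t3.CategoricalPredictors])
```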

`'CrossVal'` — Cross-validation flag: `'off'` (default) | `'on'`

Cross-validation flag, specified as the comma-separated pair consisting of `'CrossVal'` and either `'on'` or `'off'`.

If `'on'`, `fitrtree` grows a cross-validated decision tree with 10 folds. You can override this cross-validation setting using one of the `'KFold'`, `'Holdout'`, `'Leaveout'`, or `'CVPartition'` name-value pair arguments. You can use only one of these arguments at a time when creating a cross-validated tree.

Alternatively, cross-validate `tree` later using the `crossval` method.

**Example:** `'CrossVal','on'`
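The two routes to a cross-validated tree are compared below. The first sets `'CrossVal','on'` at training time. The second trains an ordinary tree and passes it to `crossval`. This is a sketch assuming the `carsmall` data. The fold count of 5 for the second model is an arbitrary choice.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];
rng(1);                                   % fold assignment is random

cvOn = fitrtree(X,MPG,'CrossVal','on');   % cross-validated at training time, 10 folds
tree = fitrtree(X,MPG);                   % ordinary RegressionTree
cvLater = crossval(tree,'KFold',5);       % cross-validated after training

fprintf('%d folds vs. %d folds\n', cvOn.KFold, cvLater.KFold)
```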

`'CVPartition'` — Partition for cross-validated tree: `cvpartition` object

Partition for cross-validated tree, specified as the comma-separated pair consisting of `'CVPartition'` and an object created using `cvpartition`.

If you use `'CVPartition'`, you cannot use any of the `'KFold'`, `'Holdout'`, or `'Leaveout'` name-value pair arguments.
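A common reason to pass a `cvpartition` object is to evaluate two models on exactly the same folds. The sketch below assumes the `carsmall` data. It removes rows with a missing response first, so the partition size matches the observations the tree uses. The shallow model's `'MaxNumSplits',3` is an arbitrary illustrative value.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
ok = ~isnan(MPG);                 % keep rows with an observed response
X = [Weight Cylinders];
X = X(ok,:);
Y = MPG(ok);

rng(1);
c = cvpartition(numel(Y),'KFold',5);   % one partition shared by both models
cvDeep = fitrtree(X,Y,'CVPartition',c);
cvShallow = fitrtree(X,Y,'CVPartition',c,'MaxNumSplits',3);

fprintf('MSE deep %.3f, shallow %.3f\n', kfoldLoss(cvDeep), kfoldLoss(cvShallow))
```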

`'Holdout'` — Fraction of data for holdout validation: `0` (default) | scalar value in the range `[0,1]`

Fraction of data used for holdout validation, specified as the comma-separated pair consisting of `'Holdout'` and a scalar value in the range `[0,1]`. Holdout validation tests the specified fraction of the data and uses the rest of the data for training.

If you use `'Holdout'`, you cannot use any of the `'CVPartition'`, `'KFold'`, or `'Leaveout'` name-value pair arguments.

**Example:** `'Holdout',0.1`

**Data Types:** `single` | `double`
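With `'Holdout'`, the returned partitioned model holds a single tree trained on the retained data. `kfoldLoss` then reports the error on the held-out rows. This sketch assumes the `carsmall` data.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];
rng(1);                                    % holdout rows are chosen at random

holdTree = fitrtree(X,MPG,'Holdout',0.1);  % hold out 10% of the data for testing
trainedTree = holdTree.Trained{1};         % the single tree trained on the other 90%
testMSE = kfoldLoss(holdTree)              % MSE on the held-out rows
```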

`'KFold'` — Number of folds: `10` (default) | positive integer value

Number of folds to use in a cross-validated tree, specified as the comma-separated pair consisting of `'KFold'` and a positive integer value.

If you use `'KFold'`, you cannot use any of the `'CVPartition'`, `'Holdout'`, or `'Leaveout'` name-value pair arguments.

**Example:** `'KFold',8`

**Data Types:** `single` | `double`

`'Leaveout'` — Leave-one-out cross-validation flag: `'off'` (default) | `'on'`

Leave-one-out cross-validation flag, specified as the comma-separated pair consisting of `'Leaveout'` and either `'on'` or `'off'`. Specify `'on'` to use leave-one-out cross-validation.

If you use `'Leaveout'`, you cannot use any of the `'CVPartition'`, `'Holdout'`, or `'KFold'` name-value pair arguments.

**Example:** `'Leaveout','on'`
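Leave-one-out cross-validation trains one tree per observation, so it becomes expensive for large data sets. The sketch below assumes the `carsmall` data. It drops rows with a missing response first, so the expected number of folds equals `numel(Y)`.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
ok = ~isnan(MPG);
X = [Weight Cylinders];
X = X(ok,:);
Y = MPG(ok);

looTree = fitrtree(X,Y,'Leaveout','on');   % one fold per observation
looMSE = kfoldLoss(looTree)                % average squared error over left-out rows
```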

`'MaxNumSplits'` — Maximal number of decision splits: `size(x,1) - 1` (default) | positive integer

Maximal number of decision splits (or branch nodes), specified as the comma-separated pair consisting of `'MaxNumSplits'` and a positive integer. `fitrtree` splits `MaxNumSplits` or fewer branch nodes. For more details on splitting behavior, see Algorithms.

**Example:** `'MaxNumSplits',5`

**Data Types:** `single` | `double`

`'MergeLeaves'` — Leaf merge flag: `'on'` (default) | `'off'`

Leaf merge flag, specified as the comma-separated pair consisting of `'MergeLeaves'` and either `'on'` or `'off'`.

If `MergeLeaves` is `'on'`, then `fitrtree` merges leaves that originate from the same parent node and that give a sum of risk values greater than or equal to the risk associated with the parent node. Otherwise, `fitrtree` does not merge leaves.

**Example:** `'MergeLeaves','off'`
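The toy illustration below shows the merge rule. It assumes that a node's risk equals the sum of squared deviations from the node mean divided by the total number of training observations, with equal weights. This definition of risk is an assumption made for the sketch. The responses are hypothetical.

```matlab
% Hypothetical responses in a parent node and in its two leaves.
parentY = [3 3 5 5];
leftY   = [3 5];     % a split that separates nothing useful
rightY  = [3 5];

sse = @(v) sum((v - mean(v)).^2);   % within-node sum of squared errors
N = numel(parentY);                 % assumed total number of training observations

parentRisk = sse(parentY) / N;
childRisk  = (sse(leftY) + sse(rightY)) / N;
mergeLeaves = childRisk >= parentRisk   % true: the leaves would be merged
```

On the training data, a split can never increase the total within-node squared error. Under this assumption, the merge condition therefore holds only when the split gives no reduction at all, as with the two identical leaves here.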

`'MinLeafSize'` — Minimum number of leaf node observations: `1` (default) | positive integer value

Minimum number of leaf node observations, specified as the comma-separated pair consisting of `'MinLeafSize'` and a positive integer value. Each leaf has at least `MinLeafSize` observations. If you supply both `MinParentSize` and `MinLeafSize`, `fitrtree` uses the setting that gives larger leaves: `MinParentSize = max(MinParentSize,2*MinLeafSize)`.

**Example:** `'MinLeafSize',3`

**Data Types:** `single` | `double`
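The `max` rule follows from simple counting. A node with fewer than `2*MinLeafSize` observations cannot be split into two leaves that each hold `MinLeafSize` observations. The values below are illustrative.

```matlab
minLeafSize = 8;      % requested MinLeafSize (illustrative)
minParentSize = 10;   % requested MinParentSize (here, the default)

% The larger of the two bounds is the one that can actually bind
effectiveMinParentSize = max(minParentSize, 2*minLeafSize)
```

Here `MinLeafSize` dominates, so the effective minimum branch-node size is 16 rather than 10.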

`'MinParentSize'` — Minimum number of branch node observations: `10` (default) | positive integer value

Minimum number of branch node observations, specified as the comma-separated pair consisting of `'MinParentSize'` and a positive integer value. Each branch node in the tree has at least `MinParentSize` observations. If you supply both `MinParentSize` and `MinLeafSize`, `fitrtree` uses the setting that gives larger leaves: `MinParentSize = max(MinParentSize,2*MinLeafSize)`.

**Example:** `'MinParentSize',8`

**Data Types:** `single` | `double`

`'NumVariablesToSample'` — Number of predictors to select at random for each split: `'all'` (default) | positive integer value

Number of predictors to select at random for each split, specified as the comma-separated pair consisting of `'NumVariablesToSample'` and a positive integer value. You can also specify `'all'` to use all available predictors.

**Example:** `'NumVariablesToSample',3`

**Data Types:** `single` | `double`
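When you sample predictors at random, the fitted tree can change from run to run. The sketch below assumes that `fitrtree` draws from MATLAB's global random number stream, so seeding with `rng` makes the result repeatable. It also assumes the `carsmall` data. The seed value is arbitrary.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Displacement Horsepower Weight];

rng(10); tA = fitrtree(X,MPG,'NumVariablesToSample',2);  % 2 of 3 predictors per split
rng(10); tB = fitrtree(X,MPG,'NumVariablesToSample',2);  % same seed, same draws

sameTree = isequaln(predict(tA,X), predict(tB,X))  % isequaln treats NaN outputs as equal
```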

`'PredictorNames'` — Predictor variable names: `{'x1','x2',...}` (default) | cell array of strings

Predictor variable names, specified as the comma-separated pair consisting of `'PredictorNames'` and a cell array of strings containing the names for the predictor variables, in the order in which they appear in `x`.

**Data Types:** `cell`

`'Prune'` — Flag to estimate the optimal sequence of pruned subtrees: `'on'` (default) | `'off'`

Flag to estimate the optimal sequence of pruned subtrees, specified as the comma-separated pair consisting of `'Prune'` and either `'on'` or `'off'`.

If `Prune` is `'on'`, then `fitrtree` grows the regression tree and estimates the optimal sequence of pruned subtrees, but does not prune the regression tree. Otherwise, `fitrtree` grows the regression tree without estimating the optimal sequence of pruned subtrees.

To prune a trained `RegressionTree` model, pass it to `prune`.

**Example:** `'Prune','off'`
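With the default `'Prune','on'`, the estimated pruning sequence is stored in the tree's `PruneList` property. You can then prune the tree to a chosen level. This sketch assumes the `carsmall` data. Level 1 is an arbitrary illustrative choice.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];

tree = fitrtree(X,MPG);                 % pruning sequence estimated, tree left unpruned
maxLevel = max(tree.PruneList)          % deepest pruning level available
pruned = prune(tree,'Level',1);         % remove the weakest branches
fprintf('%d splits before, %d after\n', sum(tree.IsBranch), sum(pruned.IsBranch))
```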

`'PruneCriterion'` — Pruning criterion: `'mse'`

Pruning criterion, specified as the comma-separated pair consisting of `'PruneCriterion'` and `'mse'`.

**Example:** `'PruneCriterion','mse'`

`'QuadraticErrorTolerance'` — Quadratic error tolerance: `1e-6` (default) | positive scalar value

Quadratic error tolerance per node, specified as the comma-separated pair consisting of `'QuadraticErrorTolerance'` and a positive scalar value. Splitting a node stops when its quadratic error drops below `QuadraticErrorTolerance*QED`, where `QED` is the quadratic error for the entire data, computed before the decision tree is grown.

**Example:** `'QuadraticErrorTolerance',1e-4`

`'ResponseName'` — Response variable name: `'Y'` (default) | string

Response variable name, specified as the comma-separated pair consisting of `'ResponseName'` and a string containing the name of the response variable in `y`.

**Example:** `'ResponseName','Response'`

**Data Types:** `char`

`'ResponseTransform'` — Response transform function: `'none'` (default) | function handle

Response transform function for transforming the raw response values, specified as the comma-separated pair consisting of `'ResponseTransform'` and either a function handle or `'none'`. The function handle should accept a matrix of response values and return a matrix of the same size. The default string `'none'` means `@(x)x`, or no transformation.

Add or change a `ResponseTransform` function using dot notation, where `myTransform` is a placeholder for the name of your function:

```
tree.ResponseTransform = @myTransform
```

**Data Types:** `function_handle`
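The sketch below sets a transform with dot notation after training and shows its effect on `predict`. The transform is a hypothetical choice: an approximate conversion from miles per gallon to liters per 100 km. It uses element-wise `./`, so it returns a matrix the same size as its input. The sketch assumes the `carsmall` data.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
tree = fitrtree([Weight Cylinders],MPG);

mpgPred = predict(tree,[4000 4]);               % prediction in miles per gallon

% Hypothetical transform: approximate mpg -> liters per 100 km, element-wise
tree.ResponseTransform = @(y) 235.21 ./ y;
litersPer100km = predict(tree,[4000 4])         % predict now applies the transform
```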

`'SplitCriterion'` — Split criterion: `'MSE'`

Split criterion, specified as the comma-separated pair consisting of `'SplitCriterion'` and `'MSE'`, meaning mean squared error.

**Example:** `'SplitCriterion','MSE'`
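The following sketch illustrates what the `'MSE'` criterion measures. It tries every cut point of a single numeric predictor and keeps the one with the smallest total squared error in the two child nodes. The data are hypothetical and equal observation weights are assumed. Placing the cut at the midpoint between neighboring values is an illustrative choice, not a description of `fitrtree`'s internals.

```matlab
% Hypothetical single-predictor node data.
x = [1 2 3 10 11 12];
y = [5 5 5 20 20 20];

[xs, order] = sort(x);
ys = y(order);
bestSSE = Inf;
bestCut = NaN;
for k = 1:numel(xs)-1
    if xs(k) == xs(k+1)
        continue                 % no cut point exists between tied predictor values
    end
    left = ys(1:k);
    right = ys(k+1:end);
    % Total squared error of the two children (equal weights assumed)
    sse = sum((left - mean(left)).^2) + sum((right - mean(right)).^2);
    if sse < bestSSE
        bestSSE = sse;
        bestCut = (xs(k) + xs(k+1)) / 2;   % midpoint between neighboring values
    end
end
fprintf('best cut %.2f, child SSE %.2f\n', bestCut, bestSSE)
```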

`'Surrogate'` — Surrogate decision splits flag: `'off'` | `'on'` | `'all'` | positive integer value

Surrogate decision splits flag, specified as the comma-separated pair consisting of `'Surrogate'` and one of `'on'`, `'off'`, `'all'`, or a positive integer value.

- When set to `'on'`, `fitrtree` finds at most 10 surrogate splits at each branch node.
- When set to a positive integer value, `fitrtree` finds at most the specified number of surrogate splits at each branch node.
- When set to `'all'`, `fitrtree` finds all surrogate splits at each branch node. The `'all'` setting can use considerable time and memory.

Use surrogate splits to improve the accuracy of predictions for data with missing values. The setting also lets you compute measures of predictive association between predictors.

**Example:** `'Surrogate','on'`

**Data Types:** `single` | `double`
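The sketch below grows a tree with surrogate splits and predicts for an observation whose `Weight` value is missing. It then computes the mean predictive association between predictors with `surrogateAssociation`. It assumes the `carsmall` data. The query point is hypothetical.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Horsepower Displacement];
surrTree = fitrtree(X,MPG,'Surrogate','on', ...
    'PredictorNames',{'Weight','Horsepower','Displacement'});

yhat = predict(surrTree,[NaN 150 250])  % Weight missing: surrogate splits route the row
ma = surrogateAssociation(surrTree)     % 3-by-3 mean predictive association matrix
```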

`'Weights'` — Observation weights: `ones(size(x,1),1)` (default) | vector of scalar values

Observation weights, specified as the comma-separated pair consisting of `'Weights'` and a vector of scalar values. The length of `Weights` must equal the number of rows in `x`.

**Data Types:** `single` | `double`
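The sketch below builds a weight vector with one entry per row of `x` and passes it to `fitrtree`. Doubling the weight of the newest model year in `carsmall` is a hypothetical choice made only for illustration.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];

w = ones(size(X,1),1);        % default: equal weights, one per row of X
w(Model_Year == 82) = 2;      % hypothetical: emphasize the newest model year
wTree = fitrtree(X,MPG,'Weights',w);
```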

`tree` — Regression tree: regression tree object

Regression tree, returned as a regression tree object. Using the `'CrossVal'`, `'KFold'`, `'Holdout'`, `'Leaveout'`, or `'CVPartition'` options results in a tree of class `RegressionPartitionedModel`. You cannot use a partitioned tree for prediction, so this kind of tree does not have a `predict` method.

Otherwise, `tree` is of class `RegressionTree`, and you can use the `predict` method to make predictions.
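The sketch below contrasts the two output classes. It assumes the `carsmall` data. For out-of-fold predictions on the training data, a partitioned model provides `kfoldPredict` instead of `predict`.

```matlab
% Assumes the carsmall sample data and Statistics and Machine Learning Toolbox.
load carsmall
X = [Weight Cylinders];

tree = fitrtree(X,MPG);                % RegressionTree
cvtree = fitrtree(X,MPG,'KFold',5);    % partitioned model
disp(class(tree))
disp(class(cvtree))                    % package-qualified RegressionPartitionedModel

yNew = predict(tree,[4000 4]);         % predict works on the full tree
yOOF = kfoldPredict(cvtree);           % out-of-fold predictions for the training rows
```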

If `MergeLeaves` is `'on'` and `PruneCriterion` is `'mse'` (the default values for these name-value pair arguments), then the software applies pruning only to the leaves, using MSE. This specification amounts to merging two leaves that share a parent node when the parent's MSE risk is at most the sum of the two leaves' MSE risks.

To accommodate `MaxNumSplits`, `fitrtree` splits all nodes in the current *layer*, and then counts the number of branch nodes. A layer is the set of nodes that are equidistant from the root node. If the number of branch nodes exceeds `MaxNumSplits`, `fitrtree` follows this procedure:

1. Determine how many branch nodes in the current layer must be unsplit so that there are at most `MaxNumSplits` branch nodes.
2. Sort the branch nodes by their impurity gains.
3. Unsplit that number of least successful branches.
4. Return the decision tree grown so far.

This procedure produces maximally balanced trees.
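Steps 1 through 3 of this procedure can be mimicked on made-up numbers. The sketch below is a hypothetical illustration, not `fitrtree`'s code. It assumes the number of branch nodes from earlier layers and the impurity gains of the current layer's proposed splits are already known.

```matlab
% Hypothetical inputs for one layer.
numPriorBranches = 3;            % branch nodes fixed in earlier layers
layerGains = [0.5 2.0 0.1 1.2];  % impurity gains of the current layer's splits
maxNumSplits = 5;

% Step 1: how many current-layer splits must be undone
numExcess = max(0, numPriorBranches + numel(layerGains) - maxNumSplits);

% Steps 2 and 3: sort by gain, least successful first, and unsplit those
[~, order] = sort(layerGains, 'ascend');
keep = true(size(layerGains));
keep(order(1:numExcess)) = false
```

Here two splits must go, and the ones with gains 0.1 and 0.5 are undone. That leaves exactly five branch nodes.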

The software splits branch nodes layer by layer until at least one of these events occurs:

- There are `MaxNumSplits` branch nodes.
- A proposed split causes the number of observations in at least one branch node to be fewer than `MinParentSize`.
- A proposed split causes the number of observations in at least one leaf node to be fewer than `MinLeafSize`.
- The algorithm cannot find a good split within a layer (that is, the pruning criterion (see `PruneCriterion`) does not improve for any proposed split in the layer). A special case is when all nodes are pure (that is, all observations in each node have the same response value).

`MaxNumSplits` and `MinLeafSize` do not affect splitting at their default values. Therefore, if you set `'MaxNumSplits'`, splitting might stop due to the value of `MinParentSize` before `MaxNumSplits` splits occur.

For dual-core systems and above, `fitrtree` parallelizes training decision trees using Intel® Threading Building Blocks (TBB). For details on Intel TBB, see https://software.intel.com/en-us/intel-tbb.


See also: `predict` | `prune` | `RegressionPartitionedModel` | `RegressionTree`
