fitrtree

Fit binary regression decision tree

Syntax

  • tree = fitrtree(tbl,ResponseVarName)
  • tree = fitrtree(tbl,formula)
  • tree = fitrtree(tbl,y)
  • tree = fitrtree(x,y)
  • tree = fitrtree(___,Name,Value)

Description

tree = fitrtree(tbl,ResponseVarName) returns a regression tree based on the input variables (also known as predictors, features, or attributes) in the table tbl and output (response) contained in tbl.ResponseVarName. tree is a binary tree where each branching node is split based on the values of a column of tbl.

tree = fitrtree(tbl,formula) returns a regression tree based on the input variables contained in the table tbl. formula is a formula string that identifies the response and predictor variables in tbl used for training.

tree = fitrtree(tbl,y) returns a regression tree based on the input variables contained in the table tbl and output contained in y.

tree = fitrtree(x,y) returns a regression tree based on the input variables x and output y. tree is a binary tree where each branching node is split based on the values of a column of x.

tree = fitrtree(___,Name,Value) fits a tree with additional options specified by one or more Name,Value pair arguments. For example, you can specify observation weights or train a cross-validated model.

If you use one of the following five options, tree is of class RegressionPartitionedModel: 'CrossVal', 'KFold', 'Holdout', 'Leaveout', or 'CVPartition'. Otherwise, tree is of class RegressionTree.

Examples

Construct a Regression Tree

Load the sample data.

load carsmall;

Construct a regression tree using the sample data.

tree = fitrtree([Weight, Cylinders],MPG,...
                'CategoricalPredictors',2,'MinParentSize',20,...
                'PredictorNames',{'W','C'})
tree = 

  RegressionTree
           PredictorNames: {'W'  'C'}
             ResponseName: 'Y'
    CategoricalPredictors: 2
        ResponseTransform: 'none'
          NumObservations: 94


Predict the mileage of 4,000-pound cars with 4, 6, and 8 cylinders.

mileage4K = predict(tree,[4000 4; 4000 6; 4000 8])
mileage4K =

   19.2778
   19.2778
   14.3889

Control Regression Tree Depth

You can control the depth of trees using the MaxNumSplits, MinLeafSize, or MinParentSize name-value pair arguments. fitrtree grows deep decision trees by default. You can grow shallower trees to reduce model complexity or computation time.

Load the carsmall data set. Consider Displacement, Horsepower, and Weight as predictors of the response MPG.

load carsmall
X = [Displacement Horsepower Weight];

The default values of the tree-depth controllers for growing regression trees are:

  • n - 1 for MaxNumSplits. n is the training sample size.

  • 1 for MinLeafSize.

  • 10 for MinParentSize.

These default values tend to grow deep trees for large training sample sizes.

Train a regression tree using the default values for tree-depth control. Cross-validate the model using 10-fold cross-validation.

rng(1); % For reproducibility
MdlDefault = fitrtree(X,MPG,'CrossVal','on');

Draw a histogram of the number of splits imposed on the trees. The number of imposed splits is one less than the number of leaves. Also, view one of the trees.

numBranches = @(x)sum(x.IsBranch);
mdlDefaultNumSplits = cellfun(numBranches, MdlDefault.Trained);

figure;
histogram(mdlDefaultNumSplits)

view(MdlDefault.Trained{1},'Mode','graph')

The average number of splits is between 14 and 15.

Suppose that you want a regression tree that is not as complex (deep) as the ones trained using the default number of splits. Train another regression tree, but set the maximum number of splits to 7, which is about half the mean number of splits from the default regression tree. Cross-validate the model using 10-fold cross-validation.

Mdl7 = fitrtree(X,MPG,'MaxNumSplits',7,'CrossVal','on');
view(Mdl7.Trained{1},'Mode','graph')

Compare the cross-validation MSEs of the models.

mseDefault = kfoldLoss(MdlDefault)
mse7 = kfoldLoss(Mdl7)
mseDefault =

   27.7277


mse7 =

   28.3833

Mdl7 is much less complex and performs only slightly worse than MdlDefault.

Input Arguments

tbl — Sample data: table

Sample data used to train the model, specified as a table. Each row of tbl corresponds to one observation, and each column corresponds to one predictor variable. Optionally, tbl can contain one additional column for the response variable. Multi-column variables and cell arrays other than cell arrays of strings are not allowed.

If tbl contains the response variable, and you want to use all remaining variables in tbl as predictors, then specify the response variable using ResponseVarName.

If tbl contains the response variable, and you want to use only a subset of the remaining variables in tbl as predictors, then specify a formula string using formula.

If tbl does not contain the response variable, then specify a response variable using y. The length of the response variable and the number of rows of tbl must be equal.

Data Types: table
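The three ways of supplying table data described above can be sketched with the carsmall sample data. This is a minimal example, assuming the Statistics and Machine Learning Toolbox is installed; the table names tbl and tblX are illustrative.

```matlab
load carsmall                          % provides Weight, Horsepower, MPG, ...
tbl = table(Weight,Horsepower,MPG);    % table that includes the response

% Response in tbl; all remaining variables are predictors
tree1 = fitrtree(tbl,'MPG');

% Response in tbl; use only a subset of predictors via a formula
tree2 = fitrtree(tbl,'MPG~Weight');

% Response supplied separately; every variable in tblX is a predictor
tblX = table(Weight,Horsepower);
tree3 = fitrtree(tblX,MPG);
```

With a formula, Horsepower is ignored even though it is stored in tbl.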

x — Predictor values: matrix of floating-point values

Predictor values, specified as a matrix of floating-point values. Each column of x represents one variable, and each row represents one observation.

fitrtree considers NaN values in x as missing values. fitrtree does not use observations with all missing values for x in the fit. fitrtree uses observations with some missing values for x to find splits on variables for which these observations have valid values.

Data Types: single | double

ResponseVarName — Response variable name: name of a variable in tbl

Response variable name, specified as the name of a variable in tbl.

You must specify ResponseVarName as a string. For example, if the response variable y is stored as tbl.y, then specify it as 'y'. Otherwise, the software treats all columns of tbl, including y, as predictors when training the model.

The response variable must be a numeric vector.

formula — Response and predictor variables to use in model training: string in the form of 'Y~X1+X2+X3'

Response and predictor variables to use in model training, specified as a string in the form of 'Y~X1+X2+X3'. In this form, Y represents the response variable, and X1, X2, and X3 represent the predictor variables.

To specify a subset of variables in tbl as predictors for training the model, use a formula string. If you specify a formula string, then any variables in tbl that do not appear in formula are not used to train the model.

y — Response data: numeric column vector

Response data, specified as a numeric column vector with the same number of rows as x. Each entry in y is the response to the data in the corresponding row of x.

fitrtree considers NaN values in y to be missing values. fitrtree does not use observations with missing values for y in the fit.

Data Types: single | double

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside single quotes (' '). You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: 'CrossVal','on','MinParentSize',30 specifies a cross-validated regression tree with a minimum of 30 observations per branch node.

'CategoricalPredictors' — Categorical predictors list: numeric or logical vector | cell array of strings | character matrix | 'all'

Categorical predictors list, specified as the comma-separated pair consisting of 'CategoricalPredictors' and one of the following.

  • A numeric vector with indices from 1 to p, where p is the number of columns of x or tbl.

  • A logical vector of length p, where a true entry means that the corresponding column of x or tbl is a categorical variable.

  • A cell array of strings, where each element in the array is the name of a predictor variable. The names must match entries in the PredictorNames property.

  • A character matrix, where each row of the matrix is a name of a predictor variable. Pad the names with extra blanks so each row of the character matrix has the same length.

  • 'all', meaning all predictors are categorical.

By default, if the predictor data is in a matrix (x), the software assumes that none of the predictors are categorical. If the predictor data is in a table (tbl), the software assumes that a variable is categorical if it contains logical values, values of the unordered data type categorical, or a cell array of strings.

Data Types: single | double | logical | char | cell

'CrossVal' — Cross-validation flag: 'off' (default) | 'on'

Cross-validation flag, specified as the comma-separated pair consisting of 'CrossVal' and either 'on' or 'off'.

If 'on', fitrtree grows a cross-validated decision tree with 10 folds. You can override this cross-validation setting using one of the 'KFold', 'Holdout', 'Leaveout', or 'CVPartition' name-value pair arguments. You can only use one of these four options ('KFold', 'Holdout', 'Leaveout', or 'CVPartition') at a time when creating a cross-validated tree.

Alternatively, cross-validate tree later using the crossval method.

Example: 'CrossVal','on'

'CVPartition' — Partition for cross-validated tree: cvpartition object

Partition for cross-validated tree, specified as the comma-separated pair consisting of 'CVPartition' and an object created using cvpartition.

If you use 'CVPartition', you cannot use any of the 'KFold', 'Holdout', or 'Leaveout' name-value pair arguments.

'Holdout' — Fraction of data for holdout validation: 0 (default) | scalar value in the range [0,1]

Fraction of data used for holdout validation, specified as the comma-separated pair consisting of 'Holdout' and a scalar value in the range [0,1]. Holdout validation tests the specified fraction of the data, and uses the rest of the data for training.

If you use 'Holdout', you cannot use any of the 'CVPartition', 'KFold', or 'Leaveout' name-value pair arguments.

Example: 'Holdout',0.1

Data Types: single | double

'KFold' — Number of folds: 10 (default) | positive integer greater than 1

Number of folds to use in a cross-validated tree, specified as the comma-separated pair consisting of 'KFold' and a positive integer value greater than 1.

If you use 'KFold', you cannot use any of the 'CVPartition', 'Holdout', or 'Leaveout' name-value pair arguments.

Example: 'KFold',8

Data Types: single | double

'Leaveout' — Leave-one-out cross-validation flag: 'off' (default) | 'on'

Leave-one-out cross-validation flag, specified as the comma-separated pair consisting of 'Leaveout' and either 'on' or 'off'. To use leave-one-out cross-validation, set this argument to 'on'.

If you use 'Leaveout', you cannot use any of the 'CVPartition', 'Holdout', or 'KFold' name-value pair arguments.

Example: 'Leaveout','on'
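The mutually exclusive cross-validation options above can be compared in a short sketch, assuming the Statistics and Machine Learning Toolbox and the carsmall sample data. Each call returns a RegressionPartitionedModel; the fold counts and holdout fraction are arbitrary choices for illustration.

```matlab
load carsmall
X = [Displacement Horsepower Weight];
rng(1);                                  % For reproducibility

cvK  = fitrtree(X,MPG,'KFold',5);        % 5-fold cross-validation
cvHO = fitrtree(X,MPG,'Holdout',0.2);    % test on 20% of the data
c    = cvpartition(numel(MPG),'KFold',4);
cvP  = fitrtree(X,MPG,'CVPartition',c);  % reuse an explicit partition object

mseK = kfoldLoss(cvK)                    % cross-validated MSE
```

Passing an explicit cvpartition object is useful when several models must be compared on identical folds.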

'MergeLeaves' — Leaf merge flag: 'on' (default) | 'off'

Leaf merge flag, specified as the comma-separated pair consisting of 'MergeLeaves' and 'on' or 'off'.

If MergeLeaves is 'on', then fitrtree:

  • Merges leaves that originate from the same parent node and that yield a sum of risk values greater than or equal to the risk associated with the parent node

  • Estimates the optimal sequence of pruned subtrees, but does not prune the regression tree

Otherwise, fitrtree does not merge leaves.

Example: 'MergeLeaves','off'

'MinLeafSize' — Minimum number of leaf node observations: 1 (default) | positive integer value

Minimum number of leaf node observations, specified as the comma-separated pair consisting of 'MinLeafSize' and a positive integer value. Each leaf has at least MinLeafSize observations. If you supply both MinParentSize and MinLeafSize, fitrtree uses the setting that gives larger leaves: MinParentSize = max(MinParentSize,2*MinLeafSize).

Example: 'MinLeafSize',3

Data Types: single | double

'MinParentSize' — Minimum number of branch node observations: 10 (default) | positive integer value

Minimum number of branch node observations, specified as the comma-separated pair consisting of 'MinParentSize' and a positive integer value. Each branch node in the tree has at least MinParentSize observations. If you supply both MinParentSize and MinLeafSize, fitrtree uses the setting that gives larger leaves: MinParentSize = max(MinParentSize,2*MinLeafSize).

Example: 'MinParentSize',8

Data Types: single | double
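The rule that reconciles MinParentSize and MinLeafSize is simple arithmetic. A minimal sketch with illustrative values (fitrtree applies the rule internally; the variable names here are hypothetical):

```matlab
% Illustrative settings
minLeafSize   = 6;
minParentSize = 10;
effectiveMinParentSize = max(minParentSize,2*minLeafSize)   % 12

% With the defaults (MinLeafSize = 1, MinParentSize = 10) the rule changes nothing
defaultEffective = max(10,2*1)                              % 10
```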

'NumVariablesToSample' — Number of predictors for split: 'all' (default) | positive integer value

Number of predictors to select at random for each split, specified as the comma-separated pair consisting of 'NumVariablesToSample' and a positive integer value. You can also specify 'all' to use all available predictors.

Example: 'NumVariablesToSample',3

Data Types: single | double

'PredictorNames' — Predictor variable names: {'x1','x2',...} (default) | cell array of strings

Predictor variable names, specified as the comma-separated pair consisting of 'PredictorNames' and a cell array of strings containing the names for the predictor variables, in the order in which they appear in x or tbl.

If you specify the predictors as a table (tbl), PredictorNames must be a subset of the variable names in tbl. In this case, the software uses only the variables in PredictorNames to fit the model. If you use formula to specify the model, then you cannot use the PredictorNames name-value pair.

Data Types: cell

'Prune' — Flag to estimate optimal sequence of pruned subtrees: 'on' (default) | 'off'

Flag to estimate the optimal sequence of pruned subtrees, specified as the comma-separated pair consisting of 'Prune' and 'on' or 'off'.

If Prune is 'on', then fitrtree grows the regression tree and estimates the optimal sequence of pruned subtrees, but does not prune the regression tree. Otherwise, fitrtree grows the regression tree without estimating the optimal sequence of pruned subtrees.

To prune a trained regression tree, pass the regression tree to prune.

Example: 'Prune','off'
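Because Prune set to 'on' only estimates the pruning sequence, applying it is a separate step. A sketch assuming the Statistics and Machine Learning Toolbox and the carsmall data; the pruning level 2 is an arbitrary choice, capped at the largest level stored in the PruneList property.

```matlab
load carsmall
tree = fitrtree([Weight Horsepower],MPG);        % Prune is 'on' by default
numLeaves = @(t) sum(~t.IsBranch);               % count leaf nodes

level  = min(2,max(tree.PruneList));             % requested pruning level
pruned = prune(tree,'Level',level);              % apply the estimated sequence

[numLeaves(tree) numLeaves(pruned)]              % leaves before and after
```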

'PruneCriterion' — Pruning criterion: 'error' (default)

Pruning criterion, specified as the comma-separated pair consisting of 'PruneCriterion' and 'error'.

Example: 'PruneCriterion','error'

'QuadraticErrorTolerance' — Quadratic error tolerance: 1e-6 (default) | positive scalar value

Quadratic error tolerance per node, specified as the comma-separated pair consisting of 'QuadraticErrorTolerance' and a positive scalar value. Node splitting stops when the quadratic error per node drops below QuadraticErrorTolerance*QED, where QED is the quadratic error for the entire data, computed before the decision tree is grown.

Example: 'QuadraticErrorTolerance',1e-4
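The stopping rule above compares each node's quadratic error with a small fraction of the error of the whole data set. The sketch below assumes that "quadratic error" means the weighted sum of squared deviations from the weighted node mean, matching the node-error formula under Algorithms; the helper quadErr and the data are hypothetical.

```matlab
y   = [10; 10; 10; 30; 31; 29];           % responses (illustrative)
w   = ones(size(y))/numel(y);             % default weights, 1/n each
tol = 1e-6;                               % QuadraticErrorTolerance (default)

% Assumed quadratic error of the observations indexed by idx
quadErr = @(idx) sum(w(idx).*(y(idx) - sum(w(idx).*y(idx))/sum(w(idx))).^2);

QED   = quadErr(1:numel(y));              % error of the entire data, before growing
stopA = quadErr(1:3) < tol*QED            % identical responses: splitting stops
stopB = quadErr(4:6) < tol*QED            % spread-out responses: splitting continues
```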

'ResponseName' — Response variable name: 'Y' (default) | string

Response variable name, specified as the comma-separated pair consisting of 'ResponseName' and a string representing the name of the response variable.

This name-value pair is not valid when using the ResponseVarName or formula input arguments.

Example: 'ResponseName','Response'

Data Types: char

'ResponseTransform' — Response transform function: 'none' (default) | function handle

Response transform function for transforming the raw response values, specified as the comma-separated pair consisting of 'ResponseTransform' and either a function handle or 'none'. The function handle must accept a matrix of response values and return a matrix of the same size. The default string 'none' means @(x)x, or no transformation.

Add or change a ResponseTransform function using dot notation:

tree.ResponseTransform = @(y)exp(y); % for example, exponentiate the predicted responses

Data Types: function_handle

'SplitCriterion' — Split criterion: 'MSE' (default)

Split criterion, specified as the comma-separated pair consisting of 'SplitCriterion' and 'MSE', meaning mean squared error.

Example: 'SplitCriterion','MSE'

'Surrogate' — Surrogate decision splits flag: 'off' | 'on' | 'all' | positive integer value

Surrogate decision splits flag, specified as the comma-separated pair consisting of 'Surrogate' and 'on', 'off', 'all', or a positive integer value.

  • When 'on', fitrtree finds at most 10 surrogate splits at each branch node.

  • When set to a positive integer value, fitrtree finds at most the specified number of surrogate splits at each branch node.

  • When set to 'all', fitrtree finds all surrogate splits at each branch node. The 'all' setting can use much time and memory.

Use surrogate splits to improve the accuracy of predictions for data with missing values. The setting also enables you to compute measures of predictive association between predictors.

Example: 'Surrogate','on'

Data Types: single | double
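A sketch of growing a tree with surrogate splits and inspecting the resulting measures of association, assuming the Statistics and Machine Learning Toolbox, the carsmall data, and the surrogateAssociation method of RegressionTree. The query point with a NaN predictor value is hypothetical.

```matlab
load carsmall
X = [Displacement Horsepower Weight];
tree = fitrtree(X,MPG,'Surrogate','on', ...
    'PredictorNames',{'Displacement','Horsepower','Weight'});

% p-by-p matrix of mean predictive measures of association
ma = surrogateAssociation(tree)

% Horsepower is missing here; surrogate splits route the observation
yhat = predict(tree,[200 NaN 3000])
```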

'Weights' — Observation weights: ones(size(x,1),1) (default) | vector of scalar values

Observation weights, specified as the comma-separated pair consisting of 'Weights' and a vector of scalar values. The software weights the observations in each row of x or tbl with the corresponding value in Weights. The size of Weights must equal the number of rows in x or tbl.

If you specify the input data as a table tbl, then Weights can be the name of a variable in tbl that contains a numeric vector. In this case, you must specify Weights as a variable name string. For example, if the weights vector W is stored as tbl.W, then specify it as 'W'. Otherwise, the software treats all columns of tbl, including W, as predictors when training the model.

fitrtree normalizes the weights to sum to 1.

Data Types: single | double
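The normalization amounts to dividing each weight by the total. A minimal illustration with made-up weights:

```matlab
w = [2; 1; 1; 4];        % raw observation weights (illustrative)
wNorm = w/sum(w)         % [0.25; 0.125; 0.125; 0.5], sums to 1
```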

Output Arguments

tree — Regression tree: regression tree object

Regression tree, returned as a regression tree object. Using the 'CrossVal', 'KFold', 'Holdout', 'Leaveout', or 'CVPartition' options results in a tree of class RegressionPartitionedModel. You cannot use a partitioned tree for prediction, so this kind of tree does not have a predict method.

Otherwise, tree is of class RegressionTree, and you can use the predict method to make predictions.

More About

Predictive Measure of Association

The predictive measure of association is a value that indicates the similarity between decision rules that split observations. Among all possible decision splits that are compared to the optimal split (found by growing the tree), the best surrogate decision split yields the maximum predictive measure of association. The second-best surrogate split has the second-largest predictive measure of association.

Suppose $x_j$ and $x_k$ are predictor variables $j$ and $k$, respectively, and $j \neq k$. At node $t$, the predictive measure of association between the optimal split $x_j < u$ and a surrogate split $x_k < v$ is

$$\lambda_{jk} = \frac{\min(P_L,P_R) - \left(1 - P_{L_jL_k} - P_{R_jR_k}\right)}{\min(P_L,P_R)}.$$

  • $P_L$ is the proportion of observations in node $t$ such that $x_j < u$. The subscript $L$ stands for the left child of node $t$.

  • $P_R$ is the proportion of observations in node $t$ such that $x_j \geq u$. The subscript $R$ stands for the right child of node $t$.

  • $P_{L_jL_k}$ is the proportion of observations at node $t$ such that $x_j < u$ and $x_k < v$.

  • $P_{R_jR_k}$ is the proportion of observations at node $t$ such that $x_j \geq u$ and $x_k \geq v$.

  • Observations with missing values for $x_j$ or $x_k$ do not contribute to the proportion calculations.

$\lambda_{jk}$ is a value in $(-\infty,1]$. If $\lambda_{jk} > 0$, then $x_k < v$ is a worthwhile surrogate split for $x_j < u$.
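The formula can be evaluated directly from the predictor values at a node. The sketch below is an illustration under stated assumptions: it uses unweighted proportions, drops rows that are missing either predictor as the last bullet specifies, and uses made-up data with two candidate surrogate predictors.

```matlab
% Values at node t (illustrative); row 7 lacks x_j and row 8 lacks x_k
xj = [1; 2; 3; 4; 5; 6; NaN; 2];
Xk = [10 10; 30 30; 20 50; 60 60; 50 20; 40 40; 5 5; NaN NaN];  % two candidate surrogates
u  = 3.5;                                % optimal split:    xj < u
v  = 35;                                 % surrogate splits: xk < v

lambdas = zeros(1,size(Xk,2));
for c = 1:size(Xk,2)
    xk   = Xk(:,c);
    keep = ~isnan(xj) & ~isnan(xk);      % rows missing xj or xk do not contribute
    a = xj(keep);  b = xk(keep);
    PL  = mean(a <  u);                  % proportion sent left by xj < u
    PR  = mean(a >= u);                  % proportion sent right
    PLL = mean(a <  u & b <  v);         % both splits send the row left
    PRR = mean(a >= u & b >= v);         % both splits send the row right
    m   = min(PL,PR);
    lambdas(c) = (m - (1 - PLL - PRR))/m;
end
lambdas                                  % approximately [1 1/3]
```

The first candidate agrees with the optimal split on every usable row, so its measure is 1; the second disagrees on two of six rows and scores about 0.33, so it would rank below the first.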

Surrogate Decision Splits

A surrogate decision split is an alternative to the optimal decision split at a given node in a decision tree. The optimal split is found by growing the tree; the surrogate split uses a similar or correlated predictor variable and split criterion.

When the value of the optimal split predictor for an observation is missing, the observation is sent to the left or right child node using the best surrogate predictor. When the value of the best surrogate split predictor for the observation is also missing, the observation is sent to the left or right child node using the second-best surrogate predictor, and so on. Candidate splits are sorted in descending order by their predictive measure of association.
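The fallback order can be sketched as a loop over the optimal split and its ranked surrogates. This is a simplified illustration, not fitrtree's implementation: it assumes every split sends values below its cut point to the left child (real surrogate splits can also reverse direction), and the node, cut points, and observation are hypothetical.

```matlab
splitVars = [1 3 2];          % optimal predictor, best surrogate, second-best surrogate
cutPoints = [3.5 40 0.7];     % x < cut goes left (assumed direction for every split)
obs       = [NaN 0.2 NaN];    % values of x1, x2, x3; x1 and x3 are missing

direction = 'unassigned';     % stays unassigned if every split predictor is missing
for s = 1:numel(splitVars)
    val = obs(splitVars(s));
    if ~isnan(val)
        if val < cutPoints(s)
            direction = 'left';
        else
            direction = 'right';
        end
        break                 % the first non-missing predictor decides
    end
end
direction
```

Here x1 (optimal) and x3 (best surrogate) are both missing, so the second-best surrogate x2 sends the observation left.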

Tips

By default, Prune is 'on'. However, this specification does not prune the regression tree. To prune a trained regression tree, pass the regression tree to prune.

Algorithms

Node Splitting Rules

fitrtree follows these steps to determine how to split node $t$. For all predictors $x_i$, $i = 1,\ldots,p$:

  1. fitrtree computes the weighted mean squared error (MSE) of the responses in node $t$ using

    $$\varepsilon_t = \sum_{j \in T} w_j \left(y_j - \bar{y}_t\right)^2.$$

    $w_j$ is the weight of observation $j$, $T$ is the set of all observation indices in node $t$, and $\bar{y}_t$ is the mean response in node $t$. If you do not specify Weights, then $w_j = 1/n$, where $n$ is the sample size.

  2. fitrtree estimates the probability that an observation is in node $t$ using

    $$P(T) = \sum_{j \in T} w_j.$$

  3. fitrtree sorts $x_i$ in ascending order. Each element of the sorted predictor is a splitting candidate or cut point. fitrtree records any indices corresponding to missing values in the set $T_U$, which is the unsplit set.

  4. fitrtree determines the best way to split node $t$ using $x_i$ by maximizing the reduction in MSE ($\Delta I$) over all splitting candidates. That is, for all splitting candidates in $x_i$:

    1. fitrtree splits the observations in node $t$ into left and right child nodes ($t_L$ and $t_R$, respectively).

    2. fitrtree computes $\Delta I$. Suppose that for a particular splitting candidate, $t_L$ and $t_R$ contain observation indices in the sets $T_L$ and $T_R$, respectively.

      • If $x_i$ does not contain any missing values, then the reduction in MSE for the current splitting candidate is

        $$\Delta I = P(T)\,\varepsilon_t - P(T_L)\,\varepsilon_{t_L} - P(T_R)\,\varepsilon_{t_R}.$$

      • If $x_i$ contains missing values, then, assuming that the observations are missing at random, the reduction in MSE is

        $$\Delta I_U = P(T - T_U)\,\varepsilon_t - P(T_L)\,\varepsilon_{t_L} - P(T_R)\,\varepsilon_{t_R}.$$

        $T - T_U$ is the set of all observation indices in node $t$ that are not missing.

      • If you use surrogate decision splits, then:

        1. fitrtree computes the predictive measures of association between the decision split $x_j < u$ and all possible decision splits $x_k < v$, $j \neq k$.

        2. fitrtree sorts the possible alternative decision splits in descending order by their predictive measure of association with the optimal split. The surrogate split is the decision split yielding the largest measure.

        3. fitrtree decides the child node assignments for observations with a missing value for $x_i$ using the surrogate split. If the surrogate predictor also contains a missing value, then fitrtree uses the decision split with the second-largest measure, and so on, until there are no other surrogates. It is possible for fitrtree to split two different observations at node $t$ using two different surrogate splits. For example, suppose the predictors $x_1$ and $x_2$ are the best and second-best surrogates, respectively, for the predictor $x_i$, $i \notin \{1,2\}$, at node $t$. If observation $m$ of predictor $x_i$ is missing (i.e., $x_{mi}$ is missing), but $x_{m1}$ is not missing, then $x_1$ is the surrogate predictor for observation $m$. If observations $x_{(m+1),i}$ and $x_{(m+1),1}$ are missing, but $x_{(m+1),2}$ is not missing, then $x_2$ is the surrogate predictor for observation $m + 1$.

        4. fitrtree uses the appropriate MSE reduction formula. That is, if fitrtree fails to assign all missing observations in node $t$ to children nodes using surrogate splits, then the MSE reduction is $\Delta I_U$. Otherwise, fitrtree uses $\Delta I$ for the MSE reduction.

    3. fitrtree chooses the candidate that yields the largest MSE reduction.

fitrtree splits the predictor variable at the cut point that maximizes the MSE reduction.
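The splitting steps for a single predictor can be condensed into a short sketch that follows the formulas as written above. Assumptions: the node mean is the weighted mean, cut points are midpoints between adjacent distinct sorted values, and no surrogate splits are used, so a missing value of x_i makes the code use the ΔI_U formula. The helper names and data are hypothetical.

```matlab
x = [1; 2; 3; NaN; 4; 5; 6];              % predictor x_i at node t (one value missing)
y = [1; 1.2; 0.9; 3; 5; 5.1; 4.8];        % responses
w = ones(size(y))/numel(y);               % default weights w_j = 1/n

P       = @(idx) sum(w(idx));                              % P(T)
ybar    = @(idx) sum(w(idx).*y(idx))/P(idx);               % weighted node mean (assumed)
epsNode = @(idx) sum(w(idx).*(y(idx) - ybar(idx)).^2);     % epsilon_t

T  = 1:numel(y);                          % all observations in node t
TU = T(isnan(x(T)));                      % unsplit set T_U: missing x_i
TO = setdiff(T,TU);                       % T - T_U: observed x_i
xs = sort(x(TO));                         % sorted candidate values

bestGain = -Inf; bestCut = NaN;
for k = 1:numel(xs)-1
    if xs(k) == xs(k+1), continue, end    % no cut between tied values
    cut = (xs(k) + xs(k+1))/2;            % midpoint cut point (assumption)
    TL = TO(x(TO) <  cut);
    TR = TO(x(TO) >= cut);
    if isempty(TU)
        gain = P(T)*epsNode(T)  - P(TL)*epsNode(TL) - P(TR)*epsNode(TR);  % Delta I
    else
        gain = P(TO)*epsNode(T) - P(TL)*epsNode(TL) - P(TR)*epsNode(TR);  % Delta I_U
    end
    if gain > bestGain
        bestGain = gain; bestCut = cut;
    end
end
bestCut                                   % 3.5 separates the two response clusters
```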

Tree Depth Control

  • If MergeLeaves is 'on' and PruneCriterion is 'error' (which are the default values for these name-value pair arguments), then the software applies pruning only to the leaves, using the mean squared error as the error measure. This specification amounts to merging leaves that originate from the same parent node and whose sum of risk values is greater than or equal to the risk associated with the parent node.

  • To accommodate MaxNumSplits, fitrtree splits all nodes in the current layer, and then counts the number of branch nodes. A layer is the set of nodes that are equidistant from the root node. If the number of branch nodes exceeds MaxNumSplits, fitrtree follows this procedure:

    1. Determine how many branch nodes in the current layer must be unsplit so that there are at most MaxNumSplits branch nodes.

    2. Sort the branch nodes by their impurity gains.

    3. Unsplit that number of the least successful branches.

    4. Return the decision tree grown so far.

    This procedure produces maximally balanced trees.

  • The software splits branch nodes layer by layer until at least one of these events occurs:

    • There are MaxNumSplits branch nodes.

    • A proposed split causes the number of observations in at least one branch node to be fewer than MinParentSize.

    • A proposed split causes the number of observations in at least one leaf node to be fewer than MinLeafSize.

    • The algorithm cannot find a good split within a layer (i.e., the pruning criterion (see PruneCriterion) does not improve for all proposed splits in a layer). A special case is when all nodes are pure (i.e., all observations in the node have the same response value).

    MaxNumSplits and MinLeafSize do not affect splitting at their default values. Therefore, if you set 'MaxNumSplits', splitting might stop because of the value of MinParentSize before MaxNumSplits splits occur.
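Steps 1 through 3 of the MaxNumSplits procedure can be sketched as follows. The branch counts and per-split MSE gains are hypothetical values chosen to show which splits in the current layer would be undone; this is an illustration, not fitrtree's internal code.

```matlab
maxNumSplits   = 5;                        % MaxNumSplits
numBranchAbove = 3;                        % branch nodes in earlier layers (hypothetical)
layerGains     = [0.40 0.05 0.22 0.10];    % MSE gains of the splits in this layer (hypothetical)

numBranch    = numBranchAbove + numel(layerGains);   % after splitting the whole layer: 7
numToUnsplit = max(0,numBranch - maxNumSplits);      % step 1: 2
[~,order]    = sort(layerGains,'ascend');            % step 2: least successful first
unsplit      = order(1:numToUnsplit)                 % step 3: undo splits 2 and 4
keptSplits   = setdiff(1:numel(layerGains),unsplit)  % splits that remain: 1 and 3
```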

Parallelization

For dual-core systems and above, fitrtree parallelizes training decision trees using Intel® Threading Building Blocks (TBB). For details on Intel TBB, see https://software.intel.com/en-us/intel-tbb.
