test

Class: classregtree

Syntax

cost = test(t,'resubstitution')
cost = test(t,'test',X,y)
cost = test(t,'crossvalidate',X,y)
[cost,secost,ntnodes,bestlevel] = test(...)
[...] = test(...,param1,val1,param2,val2,...)

Description

cost = test(t,'resubstitution') computes the cost of the tree t using a resubstitution method. t is a decision tree as created by classregtree. The cost of the tree is the sum over all terminal nodes of the estimated probability of a node times the cost of that node. If t is a classification tree, the cost of a node is the sum of the misclassification costs of the observations in that node. If t is a regression tree, the cost of a node is the average squared error over the observations in that node. cost is a vector of cost values, one for each subtree in the optimal pruning sequence for t. The resubstitution cost is based on the same sample that was used to create the original tree, so it underestimates the likely cost of applying the tree to new data.
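As a rough illustration of this definition, the following sketch computes the cost of a single subtree from per-node quantities. The vectors nodeProb and nodeCost are hypothetical values for a tree with three terminal nodes; they are not output from classregtree.

```matlab
% Hypothetical values for a subtree with three terminal nodes
% (illustrative numbers, not produced by classregtree).
nodeProb = [0.5 0.3 0.2];    % estimated probability of each terminal node
nodeCost = [0.00 0.10 0.25]; % cost of each terminal node

% Tree cost: sum over terminal nodes of (node probability) * (node cost)
treeCost = sum(nodeProb .* nodeCost);
```

Repeating this calculation for every subtree in the pruning sequence would give one cost value per subtree, which is the shape of the vector that test returns.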

cost = test(t,'test',X,y) uses the matrix of predictors X and the response vector y as a test sample, applies the decision tree t to that sample, and returns a vector cost of cost values computed for the test sample. X and y should not be the same as the learning sample, that is, the sample that was used to fit the tree t.

cost = test(t,'crossvalidate',X,y) uses 10-fold cross-validation to compute the cost vector. X and y should be the learning sample, that is, the sample that was used to fit the tree t. The function partitions the sample into 10 subsamples, chosen randomly but with roughly equal size. For classification trees, the subsamples also have roughly the same class proportions. For each subsample, test fits a tree to the remaining data and uses it to predict the subsample. It pools the information from all subsamples to compute the cost for the whole sample.
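The partitioning step can be sketched in plain MATLAB as follows. This is only an illustration of the general idea of stratified fold assignment, under the assumption that observations are shuffled within each class and dealt out to folds in turn; it is not the code that test runs internally. The labels vector and the variable names are hypothetical.

```matlab
% Sketch of stratified fold assignment (an assumption about the general
% idea, not the internal implementation of test).
labels = [ones(20,1); 2*ones(20,1); 3*ones(10,1)]; % hypothetical class labels
k = 10;                                            % number of folds
fold = zeros(size(labels));

classes = unique(labels);
for i = 1:numel(classes)
    idx = find(labels == classes(i));   % observations in this class
    idx = idx(randperm(numel(idx)));    % shuffle within the class
    % Deal the shuffled observations out to folds 1..k in turn
    fold(idx) = mod((0:numel(idx)-1)', k) + 1;
end

% Number of observations assigned to each fold
foldSizes = accumarray(fold, 1)';
```

For fold j, the observations with fold == j would serve as the held-out subsample, and a tree would be fit to the rest.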

[cost,secost,ntnodes,bestlevel] = test(...) also returns the vector secost containing the standard error of each cost value, the vector ntnodes containing the number of terminal nodes for each subtree, and the scalar bestlevel containing the estimated best level of pruning. A bestlevel of 0 means no pruning. The best level is the one that produces the smallest tree that is within one standard error of the minimum-cost subtree.
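The one-standard-error rule can be sketched as follows. The numbers are made up for illustration. The sketch assumes that the vectors are ordered by pruning level starting at level 0, which is consistent with the c(best+1) indexing used in the example on this page.

```matlab
% Hypothetical outputs, ordered by pruning level 0,1,2,... (level 0 is the
% unpruned tree). These numbers are illustrative only.
cost    = [0.060 0.053 0.047 0.060 0.333 0.667];
secost  = [0.019 0.018 0.017 0.019 0.038 0.038];
ntnodes = [6 5 4 3 2 1];

[mincost, minloc] = min(cost);       % minimum-cost subtree
cutoff = mincost + secost(minloc);   % one standard error above the minimum

% Highest pruning level (smallest tree) whose cost is within the cutoff.
% Subtract 1 because level 0 is stored at index 1.
bestlevel = find(cost <= cutoff, 1, 'last') - 1;
bestsize  = ntnodes(bestlevel + 1);  % terminal nodes in the selected subtree
```

With these values the minimum cost occurs at level 2, but the rule selects level 3, a smaller tree whose cost is still within one standard error of the minimum.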

[...] = test(...,param1,val1,param2,val2,...) specifies optional parameter name/value pairs for methods other than 'resubstitution', chosen from the following:

  • 'weights' — Observation weights.

  • 'nsamples' — The number of cross-validation samples (default is 10).

  • 'treesize' — Either 'se' (default) to choose the smallest tree whose cost is within one standard error of the minimum cost, or 'min' to choose the minimal cost tree.
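For instance, the following sketch combines two of these parameters. It assumes a release in which classregtree is available and uses only the calls and parameter names documented on this page; the choice of 5 folds is arbitrary.

```matlab
load fisheriris;   % meas (predictors) and species (class labels)
t = classregtree(meas,species,'minparent',5);

rng(1); % For reproducibility
% 5-fold cross-validation, selecting the minimum-cost tree instead of
% applying the one-standard-error rule
[c,s,n,best] = test(t,'crossvalidate',meas,species,...
                    'nsamples',5,'treesize','min');
```

Here best is the pruning level of the minimum-cost subtree, so prune(t,'level',best) would return that subtree.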

Examples


Compute the Cost of a Decision Tree

Find the best tree for Fisher's iris data using cross-validation.

Grow a large tree:

load fisheriris;
t = classregtree(meas,species,...
                 'names',{'SL' 'SW' 'PL' 'PW'},...
                 'minparent',5)
view(t)
t = 

Decision tree for classification
 1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
 2  class = setosa
 3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
 4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor
 5  class = virginica
 6  if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor
 7  if PW<1.55 then node 10 elseif PW>=1.55 then node 11 else virginica
 8  class = versicolor
 9  class = virginica
10  class = virginica
11  class = versicolor

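Find the best pruning level by cross-validation and prune the tree to that level. By default, this gives the smallest tree whose cost is within one standard error of the minimum cost: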

rng(1); % For reproducibility
[c,s,n,best] = test(t,'crossvalidate',meas,species);
tmin = prune(t,'level',best)
view(tmin)
tmin = 

Decision tree for classification
1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa
2  class = setosa
3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor
4  class = versicolor
5  class = virginica

Plot the cross-validated cost against tree size and mark the smallest tree within one standard error of the minimum-cost tree:

[mincost,minloc] = min(c);
plot(n,c,'b-o',...
     n(best+1),c(best+1),'bs',...
     n,(mincost+s(minloc))*ones(size(n)),'k--')
xlabel('Tree size (number of terminal nodes)')
ylabel('Cost')

The solid line shows the estimated cost for each tree size, the dashed line marks one standard error above the minimum, and the square marks the smallest tree under the dashed line.

References

[1] Breiman, L., J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees. Boca Raton, FL: CRC Press, 1984.
