kfoldfun

Class: ClassificationPartitionedModel

Cross-validate function

Syntax

vals = kfoldfun(obj,fun)

Description

vals = kfoldfun(obj,fun) cross-validates the function fun by applying it, once per fold, to the data stored in the cross-validated model obj. You must pass fun as a function handle.
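A minimal sketch of a call, assuming the Statistics and Machine Learning Toolbox and the fisheriris sample data used in the Examples section. The partition is built explicitly with cvpartition so the per-fold result can be compared with the partition's own test-set sizes; the handle name testSize and the choice of 5 folds are illustrative only.

```matlab
load fisheriris                          % meas (150x4 double), species (150x1 cell)
rng(0,'twister')                         % reproducible partition
cvp = cvpartition(species,'KFold',5);    % stratified 5-fold partition
t = fitctree(meas,species);
cv = crossval(t,'CVPartition',cvp);      % ClassificationPartitionedModel

% fun must accept all seven inputs, even though this one only uses Ytest.
testSize = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) numel(Ytest);

vals = kfoldfun(cv,testSize)             % one row per fold
```

Because every observation falls in exactly one test fold, the entries of vals should add up to the number of observations.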

Input Arguments

obj

Object of class ClassificationPartitionedModel or ClassificationPartitionedEnsemble.
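A sketch of the ensemble case, assuming the Statistics and Machine Learning Toolbox. Passing 'KFold' to fitcensemble returns a ClassificationPartitionedEnsemble, and CMP is then a compact ensemble, so fun can read its NumTrained property. The bagging method, 20 learning cycles, 5 folds, and the name nLearners are arbitrary choices for illustration.

```matlab
load fisheriris
rng(0,'twister')
% 'KFold',5 makes fitcensemble return a ClassificationPartitionedEnsemble.
cvEns = fitcensemble(meas,species,'Method','Bag', ...
    'NumLearningCycles',20,'KFold',5);

% In each fold, CMP is the compact ensemble trained on that fold's training data.
nLearners = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) CMP.NumTrained;

vals = kfoldfun(cvEns,nLearners)         % 5-by-1, one entry per fold
```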

fun

A function handle for a cross-validation function. fun has the syntax

testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)

  • CMP is a compact model stored in one element of the obj.Trained property.

  • Xtrain is the training matrix of predictor values.

  • Ytrain is the training array of response values.

  • Wtrain is the vector of observation weights for the training data.

  • Xtest and Ytest are the test data, with associated weights Wtest.

  • The returned value testvals must have the same size across all folds.
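A sketch of a fun that follows this syntax, assuming the fisheriris data from the Examples section (class labels stored as a cell array of character vectors, so strcmp applies). Here the test weights Wtest are used to compute a weighted misclassification rate per fold; the name weightedErr is hypothetical, and this weighting is one possible choice rather than the rule kfoldLoss uses.

```matlab
load fisheriris
rng(0,'twister')
cv = crossval(fitctree(meas,species));   % default: 10 folds

% Weighted fraction of misclassified test observations in one fold.
% Xtrain, Ytrain, and Wtrain are accepted but not used.
weightedErr = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    sum(Wtest .* ~strcmp(predict(CMP,Xtest),Ytest)) / sum(Wtest);

vals = kfoldfun(cv,weightedErr)          % cv.KFold-by-1, one rate per fold
```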

Output Arguments

vals

The testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a 1-by-N numeric row vector, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.
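The shape of vals depends on the orientation of testvals. A sketch, assuming the default 10-fold tree model from the Examples section and taking the vertical concatenation described above at face value: a 1-by-2 row per fold stacks into a 10-by-2 matrix, while a 2-by-1 column per fold stacks into a single 20-by-1 column. The names rowFun and colFun are illustrative.

```matlab
load fisheriris
rng(0,'twister')
cv = crossval(fitctree(meas,species));   % default: 10 folds

% Return [training size, test size] for each fold, as a row and as a column.
rowFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) [numel(Ytrain) numel(Ytest)];
colFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) [numel(Ytrain); numel(Ytest)];

rowVals = kfoldfun(cv,rowFun);           % 10-by-2: one row per fold
colVals = kfoldfun(cv,colFun);           % 20-by-1: 2-by-1 blocks stacked
```

Returning a row vector keeps one row per fold, which is usually easier to summarize with functions such as mean.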

Examples

Cross-validate a classification tree and obtain the classification error (see kfoldLoss):

load fisheriris
t = fitctree(meas,species);
rng(0,'twister') % for reproducibility
cv = crossval(t);
L = kfoldLoss(cv)

L =
    0.0467

Examine the result when the cost of misclassifying a flower as 'versicolor' is 10 and the cost of any other misclassification is 1:

  1. Write a function file that assigns a cost of 1 to each misclassification, but a cost of 10 to misclassifying a flower as versicolor.

    function averageCost = noversicolor(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
    
    Ypredict = predict(CMP,Xtest);
    misclassified = not(strcmp(Ypredict,Ytest)); % true where the prediction is wrong
    classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % true where versicolor is predicted
    cost = sum(misclassified) + ...
        9*sum(misclassified & classifiedAsVersicolor); % 1 per error, 10 per error predicted as versicolor
    averageCost = cost/numel(Ytest); % average cost per test observation

  2. Save the file as noversicolor.m on your MATLAB® path.

  3. Compute the mean misclassification cost over all folds using noversicolor:

    mean(kfoldfun(cv,@noversicolor))
    
    ans =
        0.1667
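If you prefer not to create a function file, the same cost rule can be sketched with anonymous functions, assuming the same data and random seed as above. The names obsCost, costFun, and errFun are hypothetical. Because each error costs at least 1, the per-fold cost should never be smaller than the plain per-fold error rate.

```matlab
load fisheriris
rng(0,'twister')
cv = crossval(fitctree(meas,species));

% Per-observation cost: 0 if correct, 10 if wrongly predicted as versicolor,
% 1 for any other misclassification (the same rule as noversicolor).
obsCost = @(Ypred,Ytest) ~strcmp(Ypred,Ytest) .* (1 + 9*strcmp(Ypred,'versicolor'));
costFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(obsCost(predict(CMP,Xtest),Ytest));
% Plain misclassification rate per fold, for comparison.
errFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));

costVals = kfoldfun(cv,costFun);
errVals  = kfoldfun(cv,errFun);
mean(costVals)
```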

