kfoldfun

Class: ClassificationPartitionedModel

Cross-validate function

Syntax

vals = kfoldfun(CVMdl,fun)

Description


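vals = kfoldfun(CVMdl,fun) cross-validates the function fun by applying fun to the data stored in the cross-validated model CVMdl. For each fold, kfoldfun calls fun with the compact model trained on that fold and with the fold's training and test data. You must pass fun as a function handle.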
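As a minimal sketch of the calling convention, the anonymous function below ignores the model and returns the number of test observations in each fold. The name countTest is hypothetical, and the example assumes the fisheriris data and the default 10-fold partition that crossval creates.

```matlab
load fisheriris
rng(1); % For reproducibility
CVMdl = crossval(fitctree(meas,species)); % Default 10-fold cross-validation

% Hypothetical fold function: it must accept all seven inputs,
% even though it uses only Ytest.
countTest = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) numel(Ytest);

vals = kfoldfun(CVMdl,countTest); % One scalar per fold, so vals is 10-by-1
totalTested = sum(vals)           % Each observation is tested in exactly one fold, so this is 150
```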

Input Arguments


CVMdl

Cross-validated model, specified as a ClassificationPartitionedECOC model, a ClassificationPartitionedEnsemble model, or a ClassificationPartitionedModel model.

fun

Cross-validated function, specified as a function handle. fun must have the syntax

testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
  • CMP is a compact model stored in one element of the CVMdl.Trained property.

  • Xtrain is the training matrix of predictor values.

  • Ytrain is the training array of response values.

  • Wtrain is the vector of observation weights for the training data.

  • Xtest and Ytest are the test data, with associated weights Wtest.

  • The returned value testvals must have the same size in every fold.

Data Types: function_handle
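The function file below is a sketch of a fold function that uses the observation weights. The name weightedError is hypothetical, and the code assumes that the class labels are a cell array of character vectors, as in fisheriris. The unused training inputs are replaced by ~, which is valid because kfoldfun always passes all seven arguments.

```matlab
function err = weightedError(CMP,~,~,~,Xtest,Ytest,Wtest)
%weightedError Hypothetical fold function: weighted misclassification rate
%   CMP is the compact model trained on the current fold. The training
%   inputs are ignored (~) because only the test data is needed.
Ypredict = predict(CMP,Xtest);        % Predicted labels for the test fold
wrong = ~strcmp(Ypredict,Ytest);      % Assumes cellstr labels
err = sum(Wtest(wrong))/sum(Wtest);   % Scalar, so every fold returns the same size
end
```

With the file on your path, kfoldfun(CVMdl,@weightedError) returns one weighted error rate per fold. When all observation weights are equal (the fitctree default), each value reduces to the plain misclassification rate of that fold.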

Output Arguments


vals

Cross-validation results, returned as a numeric matrix. vals contains the testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a numeric row vector of length N, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.

Data Types: double
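To illustrate the KFold-by-N shape, the sketch below returns a 1-by-2 row vector per fold containing the training and test set sizes. The anonymous function foldSizes is hypothetical; the setup assumes fisheriris with the default 10-fold partition.

```matlab
load fisheriris
rng(1); % For reproducibility
CVMdl = crossval(fitctree(meas,species));

% Hypothetical fold function returning a 1-by-2 row vector per fold
foldSizes = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) [numel(Ytrain) numel(Ytest)];

vals = kfoldfun(CVMdl,foldSizes); % Row k holds the result for fold k
size(vals)                        % CVMdl.KFold-by-2, that is, 10-by-2
rowTotals = sum(vals,2)           % Training + test sizes: 150 in every row
```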

Examples


Train a classification tree, and then cross-validate it using a custom k-fold loss function.

Load Fisher's iris data set.

load fisheriris

Train a classification tree.

Mdl = fitctree(meas,species);

Mdl is a ClassificationTree model.

Cross-validate Mdl using the default 10-fold cross-validation. Compute the classification error (proportion of misclassified observations) for the out-of-fold observations.

rng(1); % For reproducibility
CVMdl = crossval(Mdl);
L = kfoldLoss(CVMdl)
L =

    0.0467

Examine the result when the cost of misclassifying a flower as 'versicolor' is 10 and the cost of any other error is 1. Write a function file named noversicolor.m that assigns a cost of 10 to each flower misclassified as versicolor and a cost of 1 to every other misclassification, and save it on your MATLAB® path.

function averageCost = noversicolor(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
%noversicolor Example custom cross-validation function
%   Assigns a cost of 10 to each iris misclassified as versicolor and a
%   cost of 1 to every other misclassification. This example function
%   requires the |fisheriris| data set.
Ypredict = predict(CMP,Xtest);
misclassified = not(strcmp(Ypredict,Ytest)); % True where the prediction is wrong
classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % True where the prediction is versicolor
cost = sum(misclassified) + ...
    9*sum(misclassified & classifiedAsVersicolor); % 1 per error, plus 9 more per wrong versicolor prediction
averageCost = cost/numel(Ytest); % Average cost per test observation
end
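To check the cost arithmetic in noversicolor without training a model, the lines below apply the same expressions to four hypothetical labels. A wrong prediction of versicolor costs 1 + 9 = 10, and any other error costs 1.

```matlab
% Hypothetical toy labels, used only to check the cost formula
Ytest    = {'setosa'; 'versicolor'; 'virginica';  'virginica'};
Ypredict = {'setosa'; 'virginica';  'versicolor'; 'virginica'};

misclassified = not(strcmp(Ypredict,Ytest));            % [0;1;1;0]
classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % [0;0;1;0]
cost = sum(misclassified) + ...
    9*sum(misclassified & classifiedAsVersicolor)       % 2 + 9*1 = 11
averageCost = cost/numel(Ytest)                         % 11/4 = 2.75
```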


Compute the mean misclassification cost over all folds using noversicolor.

mean(kfoldfun(CVMdl,@noversicolor))
ans =

    0.2267
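The mean above weights every fold equally. A sketch for inspecting the per-fold costs themselves, assuming noversicolor.m from this example is on your path; the same rng seed and default partition are used, so the mean should reproduce the value shown above.

```matlab
load fisheriris
rng(1); % Same seed as above, so the fold partition matches
CVMdl = crossval(fitctree(meas,species));

vals = kfoldfun(CVMdl,@noversicolor); % One average cost per fold (10-by-1)
costRange = [min(vals) max(vals)]     % Spread of the per-fold costs
overallMean = mean(vals)              % Equal-weight mean over the folds
```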