MATLAB Answers

K-fold cross-validation neural networks

Vincent on 25 Mar 2013
Answered: Greg Heath on 5 Dec 2014
Hi all,
I’m fairly new to ANNs and I have a question about using k-fold cross-validation to search for the optimal number of hidden neurons.
I already read the very useful ANN FAQ (relevant extract: ) and looked in the archive of this site, but I would like to make sure the following procedure is right! My questions are at the bottom of the post…
Problem description: a mapping problem (input->output, all doubles) solved with an MLP with one hidden layer, trained with the default LM (Levenberg-Marquardt) backpropagation algorithm.
1. Load data: 2 input vectors (“input1” and “input2”) and 1 output vector (“output1”), all containing 600 values;
2. Divide the data set into training and “testing” sets for the cross-validation:
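k = 10;
cv = cvpartition(length(input1),'kfold',k);   % random 10-fold partition of the 600 samples
for i=1:k
    trainIdxs{i} = find(training(cv,i));   % training samples of fold i
    testIdxs{i}  = find(test(cv,i));       % held-out samples of fold i
    % rows = samples, columns = [input1 input2 output1] (assumes column vectors)
    trainMatrix{i} = [input1(trainIdxs{i}) input2(trainIdxs{i}) output1(trainIdxs{i})];
    validMatrix{i} = [input1(testIdxs{i}) input2(testIdxs{i}) output1(testIdxs{i})];
end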
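A quick way to confirm that the partition in step 2 behaves as k-fold cross-validation should is to count how often each sample lands in a test fold. The sketch below rebuilds a partition of the same size (600 samples, k = 10) so that it runs on its own; cvpartition comes from the Statistics and Machine Learning Toolbox, and the name heldOut is just illustrative.

```matlab
% Sanity check for step 2: in k-fold CV every sample is held out exactly once.
% N = 600 and k = 10 mirror the post; substitute your own cv if you already have one.
N = 600;
k = 10;
cv = cvpartition(N,'KFold',k);
heldOut = zeros(N,1);                          % times each sample is in a test fold
for i = 1:k
    mask = test(cv,i);                         % logical mask of fold i's test samples
    heldOut(mask) = heldOut(mask) + 1;         % logical indexing works for any mask orientation
end
assert(all(heldOut == 1))                      % every sample is tested exactly once
disp(cv.TestSize)                              % number of held-out samples per fold
```

With 600 samples and k = 10, each fold should hold out 60 samples.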
3. I create a feedforwardnet with nr_hidden_nodes hidden nodes and change some training parameters to prevent early stopping (the thresholds are set to values that cannot be reached in this case). There’s probably a more elegant way to do this, but I didn’t find it quickly in the documentation.
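net = feedforwardnet(nr_hidden_nodes);   % one hidden layer, trainlm by default
net.divideFcn = '';                      % no internal train/val/test split: use all samples given to train
net.trainParam.epochs = 30;              % stop after 30 epochs
net.trainParam.max_fail = 500;           % allowed validation failures (set out of reach)
net.trainParam.min_grad = 1e-15;         % minimum gradient (set out of reach)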
4. The network is trained, and the performance (MSE) on the training and testing sets is calculated inside a loop over the folds:
for i=1:k
    % inputs as a 2-by-Ntrn matrix, targets as a 1-by-Ntrn row
    [net,tr] = train(net,trainMatrix{i}(:,1:2)',trainMatrix{i}(:,3)');
    % Removed the simple code to calculate the MSE.
end
5. Operations 3 and 4 are repeated for several values of nr_hidden_nodes. The division into training and testing data remains exactly the same.
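Steps 3 to 5 can be combined into one sweep. The sketch below is self-contained, so it generates hypothetical stand-in data with the same shape as in step 1 (600 samples, two inputs, one output). The candidate sizes in hiddenSizes, the use of perform to obtain the MSE, and ranking by mean held-out MSE are assumptions rather than part of the procedure above. It also builds a new feedforwardnet in every fold, which is one way to make sure no fold starts from weights learned on another fold.

```matlab
% Hypothetical stand-in data shaped like step 1: 600 samples, 2 inputs, 1 output.
rng(0);
N = 600;  k = 10;
input1 = rand(N,1);  input2 = rand(N,1);
output1 = sin(2*pi*input1) + input2.^2 + 0.05*randn(N,1);
cv = cvpartition(N,'KFold',k);

hiddenSizes = [2 5 10];                        % hypothetical candidate hidden-layer sizes
mseTrn = zeros(numel(hiddenSizes),k);
mseTst = zeros(numel(hiddenSizes),k);
for h = 1:numel(hiddenSizes)
    for i = 1:k                                % same folds for every candidate size
        trn = training(cv,i);  tst = test(cv,i);
        xTrn = [input1(trn) input2(trn)]';  tTrn = output1(trn)';
        xTst = [input1(tst) input2(tst)]';  tTst = output1(tst)';
        net = feedforwardnet(hiddenSizes(h)); % fresh net, so fresh random initial weights
        net.divideFcn = '';                    % train on every sample of this fold
        net.trainParam.epochs = 30;
        net.trainParam.showWindow = false;     % no training GUI
        net = train(net,xTrn,tTrn);
        mseTrn(h,i) = perform(net,tTrn,net(xTrn));   % performFcn defaults to 'mse'
        mseTst(h,i) = perform(net,tTst,net(xTst));
    end
end
[~,best] = min(mean(mseTst,2));                % rank sizes by mean held-out MSE
fprintf('Lowest mean test MSE: %d hidden nodes\n', hiddenSizes(best));
```

Ranking by mean held-out MSE is a swapped-in criterion; to follow the MSE(training set) + MSE(testing set) rule from question b, minimize mean(mseTrn + mseTst,2) instead.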
Now I have the following questions:
a. Should I re-initialize the network “net” after each training? And how can this be done in MATLAB? The command “init” sets it to random weights…
b. Is the above-described procedure correct in general? I intend to use the nr_hidden_nodes corresponding to the minimum of (MSE(training set) + MSE(testing set)).
c. Since the data is divided randomly into a training and a testing set (cf. “cvpartition”), should I repeat the experiment multiple times?
Thanks in advance!


Angela Guerrero on 4 Dec 2014
I don't know if this is still useful, but I think you should repeat the experiment k times, in this case 10 times, so that all of the data are covered.
Greg Heath on 5 Dec 2014
That should be understood. Ideally, each example is a test example at least once and a validation example at least once.


Accepted Answer

Greg Heath on 5 Dec 2014
For NN training it would be more useful to have a 3-part stratified division into training, validation and test sets.
To make life easy, one could choose Nval = Ntst = M = round(N/k) and Ntrn = N-2*M.
Then choose the indices so that each of k*M examples is in the test set at least once and the validation set at least once.
For N = 94 and k = 10: Nval = Ntst = 9 and Ntrn = 76.
The val and test indices are used k times with the scrambled index vector S = randperm(N) in order to get an unbiased selection of indices for each of the k folds.
If N is sufficiently large, the fact that abs(N-k*M) examples will never be in one of the two nontraining subsets will not be significant. Otherwise, additional code might be desired.
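The division described above can be sketched for the N = 94, k = 10 example. Pairing fold i's test block with the next block as its validation set (wrapping from the last block back to the first) is an assumption; it is one simple way to make each of the k*M covered examples a test example once and a validation example once. The split here is plain random, not stratified, and the cell arrays trnFold, valFold and tstFold are illustrative names.

```matlab
% 3-part k-fold division over a scrambled index vector (N = 94, k = 10).
N = 94;  k = 10;
M = round(N/k);                     % Nval = Ntst = M
Ntrn = N - 2*M;
S = randperm(N);                    % scrambled index vector
trnFold = cell(1,k);  valFold = cell(1,k);  tstFold = cell(1,k);
testCount = zeros(1,N);  valCount = zeros(1,N);
for i = 1:k
    tstFold{i} = S((i-1)*M + (1:M));            % block i of S is the test set
    j = mod(i,k) + 1;                           % next block, wrapping k -> 1
    valFold{i} = S((j-1)*M + (1:M));            % block j of S is the validation set
    trnFold{i} = setdiff(1:N, [tstFold{i} valFold{i}]);   % everything else trains
    testCount(tstFold{i}) = testCount(tstFold{i}) + 1;
    valCount(valFold{i})  = valCount(valFold{i}) + 1;
end
fprintf('M = %d, Ntrn = %d, never held out: %d\n', M, Ntrn, sum(testCount == 0));
```

For N = 94 this prints M = 9, Ntrn = 76 and 4 examples that are never held out (94 - 10*9). To use fold i, set net.divideFcn = 'divideind' and assign trnFold{i}, valFold{i} and tstFold{i} to net.divideParam.trainInd, net.divideParam.valInd and net.divideParam.testInd before calling train.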
Of course, one alternative to using a validation set is to use regularization to avoid overtraining an overfit net. In general, it is done in one of two ways:
1. trainbr with its default form of msereg
2. fitnet or patternnet with the regularization option
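The two regularization routes can be sketched on hypothetical data with two inputs and 600 samples. The 10 hidden nodes, the 100-epoch cap and the regularization weight 0.1 are arbitrary choices for illustration. The regularization option is assumed here to be net.performParam.regularization, which performance functions such as mse provide in recent releases; check the documentation of your version.

```matlab
% Two ways to regularize instead of relying on a validation set (illustrative data).
rng(0);                                        % reproducible example
x = rand(2,600);                               % 2 inputs x 600 samples (hypothetical)
t = sin(2*pi*x(1,:)) + x(2,:).^2 + 0.05*randn(1,600);

% Route 1: Bayesian regularization via the trainbr training function
net1 = fitnet(10,'trainbr');
net1.divideFcn = '';                           % no validation set
net1.trainParam.epochs = 100;                  % cap for a quick example
net1.trainParam.showWindow = false;
net1 = train(net1,x,t);

% Route 2: default trainlm plus a regularized performance function
net2 = fitnet(10);
net2.performParam.regularization = 0.1;        % hypothetical weight in [0,1]; 0 = plain mse
net2.divideFcn = '';
net2.trainParam.epochs = 100;
net2.trainParam.showWindow = false;
net2 = train(net2,x,t);

mse1 = mean((t - net1(x)).^2);                 % plain MSE on the training data
mse2 = mean((t - net2(x)).^2);
fprintf('trainbr: %.4f   regularized trainlm: %.4f\n', mse1, mse2);
```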
Note that even though the default performance function for patternnet is crossentropy, the regularization option should still work.
That brings up the question of whether trainbr can be used with crossentropy.
Hope this helps.
Thank you for formally accepting my answer

