Fitnet: Why do some training functions (trainbfg,traincgf) with default settings fail on the simplefit dataset, while others (trainlm,trainbr) work perfectly well?

I have written a script that compares various training functions with their default parameters, using the data returned by simplefit_dataset. I train the networks on half of the points and evaluate the performance on all points. trainlm works well, trainbr works very well, but trainbfg, traincgf and trainrp do not work at all. What is the reason for this? I would have thought that the default parameters were good enough for this almost trivial problem.
Here is the code:
%% Function fitting with a neural network
close all
rng('default') % choose default random number generator
rng(0) % set seed of RNG
% try several architectures and training/optimization functions
% see fitnet documentation
hidsize=8; % number of hidden neurons
mynets{1}=fitnet(hidsize,'trainlm'); % Levenberg-Marquardt (default)
mynets{2}=fitnet(hidsize,'trainbr'); % Bayesian regularization
mynets{3}=fitnet(hidsize,'trainbfg'); % BFGS quasi-Newton
mynets{4}=fitnet(hidsize,'traincgf'); % Fletcher-Reeves conjugate gradients
mynets{5}=fitnet(hidsize,'trainrp'); % Resilient backpropagation
% load data and targets
[x,t] = simplefit_dataset;
xx=x(1:2:end); % training inputs
tt=t(1:2:end); % training targets
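numnets=numel(mynets);
Mu=zeros(1,numnets); % mean of errors, one entry per network
Sig=zeros(1,numnets); % standard deviation of errors, one entry per network
makeplots=false; % set true to see figures
for inet=1:numnets
    % The next four lines are assumed; the training step is missing from the listing.
    net=train(mynets{inet},xx,tt); % train on every other point
    y=net(x); % evaluate on all points
    e=t-y; % errors on all points (sign convention assumed)
    mu=mean(e); sig=std(e);
    Mu(inet)=mu; Sig(inet)=sig;
    if makeplots
        figure
        ha=plot(x,t,'.-g','displayname','all values'); % test values
        hold on
        ylim([0 12])
        ht=plot(xx,tt,'.-k','displayname','training values');
        hy=plot(x,y,'o-r','displayname','network output');
        titstr=['fitnet with ',num2str(hidsize),...
            ' hidden neurons and training function ',net.trainFcn];
        title(titstr)
        hx=text(1,11,['$\mu=',num2str(mu,2),',\ \sigma=',num2str(sig,2),'$'],...
            'interpreter','latex');
        legend show
    end
end
disp(['Mean of errors: ',num2str(Mu)])
Mean of errors: 0.0075002 5.0619e-06 -0.089574 -0.048282 0.040614
disp(['Stdev of errors: ',num2str(Sig)])
Stdev of errors: 0.022042 2.8171e-05 0.36157 0.57972 0.3814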

Answers (1)

Aditya on 24 Jan 2023
Edited: Aditya on 3 Feb 2023
I understand that you want to know why some network training functions perform poorly compared to others.
This evaluation strategy might not give an accurate picture of how the networks perform. Because you train on half of the points and then evaluate on all of them, half of the evaluation points have already been "seen" by the network. A network that overfits the training data would still score very well under this scheme. There should be no overlap between the training and testing datasets.
Here is the same example with different training and testing splits.
[x,y] = simplefit_dataset;
train_x = x(1:80);
test_x = x(81:end);
train_y = y(1:80);
test_y = y(81:end);
nets = {'trainlm','trainbr','trainbfg','traincgf','trainrp'};
hidsize = 8;
for i = 1:numel(nets)
    net = fitnet(hidsize, nets{i});
    net = train(net, train_x, train_y);
    Y = sim(net, test_x);
    perf = perform(net, test_y, Y); % perform expects targets first, then outputs
    disp(nets{i} + " Mean Squared Error = " + perf);
end
Your observation regarding the performance of the network with 8 hidden neurons is correct. However, if you change the number of hidden neurons to 15, for example, you will notice different results.
With deep learning, the goal is generalization. So, to compare training functions, you should run the comparison on several network sizes before coming to a conclusion.
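A comparison across network sizes could look roughly like the sketch below. It assumes the Deep Learning Toolbox is available; the hidden sizes `[4 8 15]`, the odd/even split, and the variable names are illustrative choices, not something taken from the question. Results will vary with the random seed.

```matlab
% Sketch: held-out MSE for several hidden sizes and training functions.
[x,t] = simplefit_dataset;
trainIdx = 1:2:numel(x);   % odd-indexed points are used for training
testIdx  = 2:2:numel(x);   % even-indexed points are held out (assumed split)
trainFcns = {'trainlm','trainbr','trainbfg','traincgf','trainrp'};
hidSizes = [4 8 15];       % illustrative sizes, not from the original post
testMSE = zeros(numel(hidSizes), numel(trainFcns));
rng(0)                     % fixed seed so the run is repeatable
for i = 1:numel(hidSizes)
    for j = 1:numel(trainFcns)
        net = fitnet(hidSizes(i), trainFcns{j});
        net.trainParam.showWindow = false;   % suppress the training GUI
        net = train(net, x(trainIdx), t(trainIdx));
        y = net(x(testIdx));                 % predictions on held-out points
        testMSE(i,j) = mean((t(testIdx) - y).^2);
    end
end
rowNames = arrayfun(@(h) sprintf('hidden=%d',h), hidSizes, 'UniformOutput', false);
disp(array2table(testMSE, 'VariableNames', trainFcns, 'RowNames', rowNames))
```

Each row of the table is one network size and each column one training function, so a function that only does well at a single size stands out.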
Aditya on 3 Feb 2023
Edited: Aditya on 3 Feb 2023
Checking whether the training loss is going down is a reasonable way to see whether the network is learning.
However, using the training dataset to compare the performance of networks might not give a complete picture, simply because networks that overfit would score better. Generalization matters.
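One way to look at this is to compare the error on the points used for training with the error on points that `train` set aside. The sketch below is only an illustration; it assumes the default random data division of `fitnet` and uses the index fields of the training record `tr`.

```matlab
% Sketch: training error versus held-out error for a single network.
[x,t] = simplefit_dataset;
rng(0)                                  % fixed seed so the run is repeatable
net = fitnet(8, 'trainlm');
net.trainParam.showWindow = false;      % suppress the training GUI
[net,tr] = train(net, x, t);            % tr records which points went where
y = net(x);
trainMSE = mean((t(tr.trainInd) - y(tr.trainInd)).^2);  % points used for fitting
testMSE  = mean((t(tr.testInd)  - y(tr.testInd)).^2);   % points never used in training
fprintf('train MSE = %.3g, test MSE = %.3g\n', trainMSE, testMSE);
```

A test error much larger than the training error would suggest overfitting.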
Yes, you are right that different weight initializations can lead to faster convergence. The initial point can determine whether the algorithm converges at all; some initial points are so unstable that the algorithm encounters numerical difficulties and fails altogether.
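To see how much the initial point matters, the same network can be trained from several random seeds and the spread of the results compared. This is a rough sketch; the number of runs and the choice of `trainbfg` are arbitrary assumptions, and each seed also changes the random data division.

```matlab
% Sketch: sensitivity of one training function to the random seed.
[x,t] = simplefit_dataset;
trainFcn = 'trainbfg';    % any of the training function names can be substituted
numRuns = 10;             % illustrative number of restarts
mseRuns = zeros(1, numRuns);
for k = 1:numRuns
    rng(k)                               % new seed: new initial weights and data division
    net = fitnet(8, trainFcn);
    net.trainParam.showWindow = false;   % suppress the training GUI
    [net,tr] = train(net, x, t);
    y = net(x(tr.testInd));              % predictions on the test subset chosen by train
    mseRuns(k) = mean((t(tr.testInd) - y).^2);
end
fprintf('%s: median test MSE %.3g (min %.3g, max %.3g)\n', ...
    trainFcn, median(mseRuns), min(mseRuns), max(mseRuns));
```

A large gap between the minimum and maximum would indicate that a single run is not enough to judge a training function.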
