Code covered by the BSD License  

Highlights from
Neural Network add-in for PSORT


by Tricia Rambharose

28 Nov 2010 (Updated )

This add-in allows a neural network to be trained using the Particle Swarm Optimization (PSO) technique.
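A minimal usage sketch follows. It assumes this add-in and the PSO Research Toolbox are on the MATLAB path; the dataset and network layout (simplefit_dataset, a single hidden layer of 10 neurons) are illustrative choices, not part of the add-in.

[x, t] = simplefit_dataset;   % example dataset shipped with the Neural Network Toolbox
net = newff(x, t, 10);        % feed-forward network with one hidden layer of 10 neurons
net.trainFcn = 'trainpso';    % train with the PSO function provided by this add-in
net.trainParam.epochs = 10;   % defaults defined in the info section below can be overridden here
net = train(net, x, t);       % weights and biases are determined by the particle swarm
y = sim(net, x);              % simulate the trained network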

trainpso(net, tr, trainV, valV, testV, varargin)
%   Created/Modified: 2010/09/18
%   By: Tricia Rambharose
%   info@tricia-rambharose.com

%   This is a modification of existing neural network training algorithms
%   provided in MATLAB's Neural Network Toolbox, using ideas from Brian
%   Birge's PSO toolbox @
%   http://www.mathworks.com/matlabcentral/fileexchange/7506-particle-swarm-optimization-toolbox

% This PSO training function serves as the interface between the MATLAB NN
% toolbox and the PSO Research Toolbox. For consistency, the formatting and
% structure of the other NN training algorithms are followed.
% Training here is based on the idea that all weight and bias values are
% determined using a swarm optimization approach (see the sketch of the
% implied objective function after the code below).

function [net, tr] = trainpso(net, tr, trainV, valV, testV, varargin)

%set global variables to be used by several functions
global first_pso

first_pso = true; % Must not be changed! Flag indicating to the PSO toolbox that this is the first run of the PSO algorithm, to prevent redundant parameter validation and display.

%% Info-Check if network information requested
if strcmp(net,'info')
  % Define function information and default values
  info.function = mfilename;
  info.title = 'Particle Swarm Optimization NN training';
  info.type = 'Training';
  info.version = 6;
  info.training_mode = 'Supervised';
  info.gradient_mode = 'No Gradient'; % This method does not use the gradient of the performance function to adjust the weights and biases; instead, PSO determines the weights and biases using the performance function as the objective function
  info.uses_validation = false; % should NN training be validated?
  info.param_defaults.show = 1000;
  info.param_defaults.epochs = 10; % maximum number of epochs to train the NN
  info.param_defaults.time = inf; % maximum time to train the NN
  info.param_defaults.goal = 0.3; % maximum performance value that is accepted
  info.param_defaults.max_fail = 6; % maximum number of failures of the NN to reach the goal
  info.param_defaults.max_perf_inc = 1.04;
  info.param_defaults.showCommandLine = true;
  info.param_defaults.showWindow = false;
  info.param_defaults.plotPSO = true; % option for a PSO scatter plot
  net = info;
  return
end

%% NNET 5.1 Backward Compatibility
if ischar(net)
  switch (net)
    case 'name', info = feval(mfilename,'info'); net = info.title;
    case 'pnames', info = feval(mfilename,'info'); net = fieldnames(info.param_defaults);
    case 'pdefaults', info = feval(mfilename,'info'); net = info.param_defaults;
    case 'gdefaults', if (tr==0), net='calcgrad'; else net='calcgbtt'; end
    otherwise, error('NNET:Arguments','Unrecognized code.')
  end
  return
end
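
% Example (illustrative, not part of the original file): the string codes handled
% above allow this training function to be queried directly, e.g.
%   trainpso('pdefaults')  % returns the param_defaults structure from the info section
%   trainpso('pnames')     % returns the names of the training parameters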


%% CALCULATION
%Create and initialize variables for network parameters set in info section above
epochs = net.trainParam.epochs; %number of times to train network using given set of inputs
goal = net.trainParam.goal;
max_fail = net.trainParam.max_fail; 
max_perf_inc = net.trainParam.max_perf_inc;
show = net.trainParam.show;
time = net.trainParam.time;

%Parameter Checking - must correspond to Info and Parameters defined above
if (~isa(epochs,'double')) || (~isreal(epochs)) || (any(size(epochs) ~= 1)) || ...
  (epochs < 1) || (round(epochs) ~= epochs)
  error('NNET:Arguments','Epochs is not a positive integer.')
end
if (~isa(goal,'double')) || (~isreal(goal)) || (any(size(goal) ~= 1)) || ...
  (goal < 0)
  error('NNET:Arguments','Goal is not zero or a positive real value.')
end
if (~isa(max_fail,'double')) || (~isreal(max_fail)) || (any(size(max_fail) ~= 1)) || ...
  (max_fail < 1) || (round(max_fail) ~= max_fail)
  error('NNET:Arguments','Max_fail is not a positive integer.')
end
if (~isa(max_perf_inc,'double')) || (~isreal(max_perf_inc)) || (any(size(max_perf_inc) ~= 1)) || ...
  (max_perf_inc < 1)
  error('NNET:Arguments','Max_perf_inc is not a real value greater than or equal to 1.0.')
end
if (~isa(show,'double')) || (~isreal(show)) || (any(size(show) ~= 1)) || ...
  (isfinite(show) && ((show < 1) || (round(show) ~= show)))
  error('NNET:Arguments','Show is not ''NaN'' or a positive integer.')
end
if (~isa(time,'double')) || (~isreal(time)) || (any(size(time) ~= 1)) || ...
  (time < 0)
  error('NNET:Arguments','Time is not zero or a positive real value.')
end

%Initialize training variables
Q = trainV.Q; %Training data parameter
TS = trainV.TS; %Training data parameter
startTime = clock;
NN_wb = getx(net); %NN_wb is all network weight and bias values as a single vector
original_net = net;

%Initialize training record parameters
tr.states = {'epoch','time','perf'}; % Names of the training state variables to track as the network is trained

if net.trainParam.plotPSO % If plot requested, create a figure.
    fig = figure;
end

% Adjust network weights and biases - aim to determine new weights and biases such that performance is minimized
Control_Panel; % Call to the PSO Research Toolbox

NN_wb = g(1, :); % g holds the position vector of the PSO result. Each call to RegPSO_main clears the values in g
net = setx(net, NN_wb); % Set the network with the new weights and biases
[perf,El,trainV.Y,Ac,N,Zb,Zi,Zl] = calcperf2(net,NN_wb,trainV.Pd,trainV.Tl,trainV.Ai,Q,TS); % Calculate network performance using the new weights and biases and the training data
  
tr = tr_clip(tr); % remove unnecessary values at end of tr state values arrays
end
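
The call to Control_Panel above hands control to the PSO Research Toolbox, which searches over candidate weight and bias vectors. The exact interface of that toolbox is configured separately, so the following is only a hedged sketch of the kind of objective function implied by the code above; the function name nn_pso_objective and the way it receives net and trainV are assumptions, not part of the add-in. A candidate vector is loaded into the network with setx and scored with calcperf2, exactly as done after the PSO returns.

function perf = nn_pso_objective(wb, net, trainV)
% Hypothetical objective (fitness) function for the swarm: lower is better.
% wb     - candidate weight and bias vector, same layout as getx(net)
% net    - the neural network being trained
% trainV - training data structure with fields Pd, Tl, Ai, Q and TS, as used above
net = setx(net, wb);                              % load the candidate weights and biases into the network
perf = calcperf2(net, wb, trainV.Pd, trainV.Tl, ...
                 trainV.Ai, trainV.Q, trainV.TS); % network performance on the training data
end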
