mnist classification using batch method

6 views (last 30 days)
Hi. I want to train a neural network with mnist database using batch method. I use below code but my accuracy is very low. but I think the code is correct. can any one help me please?
function [hiddenWeights, outputWeights, error] = train_network_batch(numberOfHiddenUnits, input, target, epochs, batchSize, learningRate,lambda)
% The number of training vectors.
trainingSetSize = size(input, 2);
% Input vector has 784 dimensions.
inputDimensions = size(input, 1);
% We have to distinguish 10 digits.
outputDimensions = size(target, 1);
% Initialize the weights for the hidden layer and the output layer.
% hiddenWeights = randn(NHiddenUnit, inputDimensions)*1/sqrt(size(input, 1));
% outputWeights = randn(outputDimensions, NHiddenUnit)*1/sqrt(size(input, 1));
hiddenWeights = rand(numberOfHiddenUnits, inputDimensions);
outputWeights = rand(outputDimensions, numberOfHiddenUnits);
hiddenWeights = hiddenWeights./size(hiddenWeights, 2);
outputWeights = outputWeights./size(outputWeights, 2);
hiddenWeights_store = hiddenWeights;
outputWeights_store = outputWeights;
n = zeros(batchSize,1);
validation_count=0;
validation_accuracy=0;
figure; hold on;
%batch method
for t = 1: epochs
for k = 1: batchSize
% Select which input vector to train on.
n(k) = floor(rand(1)*trainingSetSize + 1);
% n(k) =k;
% Propagate the input vector through the network.
inputVector = input(:, n(k));
hiddenActualInput = hiddenWeights*inputVector;
hiddenOutputVector = linear_func(hiddenActualInput);
outputActualInput = outputWeights*hiddenOutputVector;
outputVector = linear_func(outputActualInput);
targetVector = target(:, n(k));
% Backpropagate the errors.
outputDelta = dlinear_func(outputActualInput).*(outputVector - targetVector);
hiddenDelta = dlinear_func(hiddenActualInput).*(outputWeights'*outputDelta);
% outputWeights_store = outputWeights_store -(learningRate*lambda/batchSize).*outputWeights- learningRate.*outputDelta*hiddenOutputVector'; hiddenWeights_store = hiddenWeights_store -(learningRate*lambda/batchSize).*hiddenWeights-learningRate.*hiddenDelta*inputVector';
% outputWeights =(1-(learningRate*lambda/batchSize)).*outputWeights - learningRate.*outputDelta*hiddenOutputVector'; % hiddenWeights = (1-(learningRate*lambda/batchSize)).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
end;
outputWeights=outputWeights+(outputWeights_store./batchSize);
hiddenWeights=hiddenWeights+(hiddenWeights_store./batchSize);
outputWeights_store=0;
hiddenWeights_store=0;
% %*********************************end of batch method*************** % Calculate the error for plotting. error = 0; for k = 1: batchSize inputVector = input(:, n(k)); targetVector = target(:, n(k));
error = error + norm(linear_func(outputWeights*linear_func(hiddenWeights*inputVector)) - targetVector, 2);
end;
error = error/batchSize;
plot(t, error,'*');
title(['MSE_ batch','NH= ',num2str(numberOfHiddenUnits),',',' alfa=',num2str(learningRate),' ,epoch=',num2str(epochs)]);
xlabel('epoch');
ylabel('cost');
inputValues=load('validation.mat');
inputValues=inputValues.v;
labels=load('label.mat');
labels=labels.l;
[correctlyClassified, classificationErrors]=validation_network(hiddenWeights,outputWeights,inputValues',labels);
correctlyClassified=correctlyClassified/10000;
if correctlyClassified<= validation_accuracy
validation_count=validation_count+1;
else
validation_count=0;
end
if validation_count>7
break;
end
validation_accuracy=correctlyClassified;
end;
end

Accepted Answer

Greg Heath
Greg Heath on 5 Dec 2015
Edited: Greg Heath on 5 Dec 2015
1. I don't think that anyone wants to wade through all of that code when you can just use MATLAB classification functions
help PATTERNNET
doc PATTERNNET
2. If none of your hidden or output functions is nonlinear, then all you have is a complicated linear classifier which can be implemented with BACKSLASH.
Hope this helps.
Thank you for formally accepting my answer
Greg

More Answers (0)

Tags

No tags entered yet.

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!