mnist classification using batch method
Show older comments
Hi. I want to train a neural network with mnist database using batch method. I use below code but my accuracy is very low. but I think the code is correct. can any one help me please?
function [hiddenWeights, outputWeights, error] = train_network_batch(numberOfHiddenUnits, input, target, epochs, batchSize, learningRate,lambda)
% The number of training vectors.
trainingSetSize = size(input, 2);
% Input vector has 784 dimensions.
inputDimensions = size(input, 1);
% We have to distinguish 10 digits.
outputDimensions = size(target, 1);
% Initialize the weights for the hidden layer and the output layer.
% hiddenWeights = randn(NHiddenUnit, inputDimensions)*1/sqrt(size(input, 1));
% outputWeights = randn(outputDimensions, NHiddenUnit)*1/sqrt(size(input, 1));
hiddenWeights = rand(numberOfHiddenUnits, inputDimensions);
outputWeights = rand(outputDimensions, numberOfHiddenUnits);
hiddenWeights = hiddenWeights./size(hiddenWeights, 2);
outputWeights = outputWeights./size(outputWeights, 2);
hiddenWeights_store = hiddenWeights;
outputWeights_store = outputWeights;
n = zeros(batchSize,1);
validation_count=0;
validation_accuracy=0;
figure; hold on;
%batch method
for t = 1: epochs
for k = 1: batchSize
% Select which input vector to train on.
n(k) = floor(rand(1)*trainingSetSize + 1);
% n(k) =k;
% Propagate the input vector through the network.
inputVector = input(:, n(k));
hiddenActualInput = hiddenWeights*inputVector;
hiddenOutputVector = linear_func(hiddenActualInput);
outputActualInput = outputWeights*hiddenOutputVector;
outputVector = linear_func(outputActualInput);
targetVector = target(:, n(k));
% Backpropagate the errors.
outputDelta = dlinear_func(outputActualInput).*(outputVector - targetVector);
hiddenDelta = dlinear_func(hiddenActualInput).*(outputWeights'*outputDelta);
% outputWeights_store = outputWeights_store -(learningRate*lambda/batchSize).*outputWeights- learningRate.*outputDelta*hiddenOutputVector'; hiddenWeights_store = hiddenWeights_store -(learningRate*lambda/batchSize).*hiddenWeights-learningRate.*hiddenDelta*inputVector';
% outputWeights =(1-(learningRate*lambda/batchSize)).*outputWeights - learningRate.*outputDelta*hiddenOutputVector'; % hiddenWeights = (1-(learningRate*lambda/batchSize)).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
end;
outputWeights=outputWeights+(outputWeights_store./batchSize);
hiddenWeights=hiddenWeights+(hiddenWeights_store./batchSize);
outputWeights_store=0;
hiddenWeights_store=0;
% %*********************************end of batch method*************** % Calculate the error for plotting. error = 0; for k = 1: batchSize inputVector = input(:, n(k)); targetVector = target(:, n(k));
error = error + norm(linear_func(outputWeights*linear_func(hiddenWeights*inputVector)) - targetVector, 2);
end;
error = error/batchSize;
plot(t, error,'*');
title(['MSE_ batch','NH= ',num2str(numberOfHiddenUnits),',',' alfa=',num2str(learningRate),' ,epoch=',num2str(epochs)]);
xlabel('epoch');
ylabel('cost');
inputValues=load('validation.mat');
inputValues=inputValues.v;
labels=load('label.mat');
labels=labels.l;
[correctlyClassified, classificationErrors]=validation_network(hiddenWeights,outputWeights,inputValues',labels);
correctlyClassified=correctlyClassified/10000;
if correctlyClassified<= validation_accuracy
validation_count=validation_count+1;
else
validation_count=0;
end
if validation_count>7
break;
end
validation_accuracy=correctlyClassified;
end;
end
Accepted Answer
More Answers (0)
Categories
Find more on Deep Learning Toolbox in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!