Error using trainNetwork (line 183) Invalid network. Caused by: Layer 'fc7': Input size mismatch. Size of input to this layer is different from the expected input size
Show older comments
clc
clear all; close all;
outputFolder=fullfile('recycle101');
rootFolder=fullfile(outputFolder,'recycle');
categories={'Aluminium Can','PET Bottles','Drink Carton Box'};
imds=imageDatastore(fullfile(rootFolder,categories),'LabelSource','foldernames');
tbl=countEachLabel(imds)
minSetCount=min(tbl{:,2});
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
countEachLabel(imds);
%randomly choose file for aluminium can, PET bottles, and drink carton box
AluminiumCan=find(imds.Labels=='Aluminium Can',1);
PETBottles=find(imds.Labels=='PET Bottles',1);
DrinkCartonBox=find(imds.Labels=='Drink Carton Box',1);
%plot image that was pick randomly
figure
subplot(2,2,1);
imshow(readimage(imds,AluminiumCan));
subplot(2,2,2);
imshow(readimage(imds,PETBottles));
subplot(2,2,3);
imshow(readimage(imds,DrinkCartonBox));
%Load pre-trained network
net = alexnet;
%analyzeNetwork(net)
%net =alexnet('Weights','none');
%lys = net.Layers;
%lys(end-3:end)
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net.Layers);
%Replace the classification layers for new task
newFCLayer = fullyConnectedLayer(3,'Name','new_fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc6',newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
%Resize the image
imageSize=net.Layers(1).InputSize;
augmentedTrainingSet=augmentedImageDatastore(imageSize,...
imdsTrain,'ColorPreprocessing','gray2rgb');
augmentedTestSet=augmentedImageDatastore(imageSize,...
imdsValidation,'ColorPreprocessing','gray2rgb');
options = trainingOptions('sgdm', ...
'MiniBatchSize',4, ...
'MaxEpochs',8, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augmentedTestSet, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'ExecutionEnvironment','cpu', ...
'Plots','training-progress');
trainedNet = trainNetwork(augmentedTrainingSet,lgraph,options);
Accepted Answer
More Answers (0)
Categories
Find more on Deep Learning Toolbox in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!