Assemble Multiple-Output Network for Prediction

Instead of calling the model function directly for prediction, you can use the functionToLayerGraph and assembleNetwork functions to assemble the network into a DAGNetwork object that is ready for prediction. You can then make predictions with the predict function.

Load Model Function and Parameters

Load the model parameters, model state, and class names from the MAT file digitsMIMO.mat. The file stores the model parameters in a struct named parameters, the model state in a struct named state, and the class names in a variable named classNames.

s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;

The model function model, listed at the end of the example, defines the model given the model parameters and state.

Assemble Network for Prediction

Define an anonymous function that fixes the model parameters and the model state and sets the doTraining flag to false.

doTraining = false;
fun = @(dlX) model(dlX,parameters,doTraining,state);
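When MATLAB creates an anonymous function, it captures the current values of the variables used in its body. As a result, fun keeps the values of parameters, doTraining, and state that existed when it was defined, and reassigning those variables later does not change fun. The following minimal sketch illustrates this capture behavior with a hypothetical variable scale standing in for the model parameters.

```matlab
% Hypothetical stand-in for the model parameters
scale = 2;
f = @(x) x*scale;   % f captures scale = 2 when it is created

% Reassigning scale afterwards does not affect f
scale = 10;
y = f(3);           % uses the captured value, so y = 6
disp(y)
```

To make fun use updated parameters or state, redefine it after changing those variables.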

Convert the model function to a layer graph using the functionToLayerGraph function. This function requires an example input, so first create a variable dlX that contains a mini-batch of data in the expected format.

X = rand(28,28,1,128,'single');
dlX = dlarray(X,'SSCB');
lgraph = functionToLayerGraph(fun,dlX);
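In the format 'SSCB', each character labels one dimension of the array: S for a spatial dimension, C for the channel dimension, and B for the batch (observation) dimension. The sketch below pairs each format character with the matching dimension of X. It uses only base MATLAB, so it does not depend on dlarray, and it is an illustration rather than a step of this workflow.

```matlab
% Example mini-batch: 128 grayscale images of size 28-by-28
X = rand(28,28,1,128,'single');
fmt = 'SSCB';     % spatial, spatial, channel, batch

% Print the size of each dimension next to its format label
for d = 1:numel(fmt)
    fprintf('%c: %d\n', fmt(d), size(X,d));
end
```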

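The layer graph output by the functionToLayerGraph function does not include input and output layers. Add an input layer, a classification layer, and a regression layer to the layer graph using the addLayers function, and connect them using the connectLayers function. Connect the input layer to the first convolution layer ('conv_1'), the classification layer to the softmax layer ('sm_1'), and the regression layer to the fully connected layer that outputs the angles ('fc_1').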

layers = imageInputLayer([28 28 1],'Name','input','Normalization','none');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'input','conv_1');
layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'sm_1','coutput');
layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc_1','routput');

View a plot of the network.

figure
plot(lgraph)

Assemble the network using the assembleNetwork function.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [17×2 table]
     InputNames: {'input'}
    OutputNames: {'coutput'  'routput'}

Make Predictions on New Data

Load the test data.

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

To make predictions using the assembled network, use the predict function. To return categorical labels for the classification output, set the 'ReturnCategorical' option to true.

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);

Evaluate the classification accuracy.

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9954
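The expression Y1Pred==Y1Test is a logical vector that is true where the predicted label matches the true label, so its mean is the fraction of correctly classified images. A minimal sketch, using small hypothetical numeric label vectors in place of the categorical test labels:

```matlab
% Hypothetical predicted and true labels for five images
labelsPred = [1 2 3 3 0];
labelsTrue = [1 2 3 4 0];

isCorrect = labelsPred == labelsTrue;   % logical: [1 1 1 0 1]
accuracy = mean(isCorrect);             % 4 correct out of 5 = 0.8
disp(accuracy)
```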

Evaluate the regression accuracy by computing the root-mean-square error (RMSE) of the predicted angles, in degrees.

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
    5.6085
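The RMSE expression squares each angle error, averages the squared errors, and takes the square root, so the result has the same units as the angles. For example, errors of -2, 0, and 4 degrees give an RMSE of sqrt((4 + 0 + 16)/3), about 2.58 degrees. The angle values in this sketch are hypothetical.

```matlab
% Hypothetical predicted and true angles, in degrees
anglesPred = [10 -5 20];
anglesTrue = [12 -5 16];

errors = anglesPred - anglesTrue;    % [-2 0 4]
angleRMSE = sqrt(mean(errors.^2));   % sqrt(20/3), about 2.582
fprintf('RMSE: %.4f degrees\n', angleRMSE)
```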

View some of the images with their predictions. Display the predicted angles in red and the correct angles in green. The title of each image shows the predicted label.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end
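Each dashed line runs from the bottom edge of the image (y = sz) to the top edge (y = 0), with x-coordinates offset*(1 - tand(theta)) and offset*(1 + tand(theta)). The segment therefore passes through the image center and is tilted by theta degrees from the vertical, assuming a square image as in this data set. The following sketch checks this geometry for a hypothetical angle of 30 degrees without plotting.

```matlab
sz = 28;          % image height (and width) in pixels
offset = sz/2;    % horizontal center of the image
theta = 30;       % hypothetical angle, in degrees

% Endpoints as computed in the plotting loop
x = offset*[1-tand(theta) 1+tand(theta)];
y = [sz 0];

% Tilt of the segment relative to the vertical axis, in degrees
tilt = atand(abs(diff(x))/abs(diff(y)));
% Midpoint of the segment
midpoint = [mean(x) mean(y)];
fprintf('tilt = %.2f degrees, midpoint = (%.1f, %.1f)\n', tilt, midpoint)
```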

Model Function

The function model takes the input data dlX, the model parameters parameters, the flag doTraining, which specifies whether the model should return outputs for training or prediction, and the network state state. The function returns the predictions for the labels, the predictions for the angles, and the updated network state.
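The doTraining flag affects only the batch normalization operations. In training mode, batchnorm normalizes each channel with the statistics of the current mini-batch and also returns updated running estimates of the mean and variance, which model stores in state. In prediction mode, it normalizes with the stored statistics. The following base MATLAB sketch imitates these two modes for a single channel. It is a simplified illustration, not the toolbox batchnorm implementation: it omits the learnable Scale and Offset, and the decay factor, epsilon, and use of the population variance are assumptions.

```matlab
% One channel with four observations (hypothetical data)
x = [1 2 3 4];
runningMean = 0;        % stored statistics, like the values in state
runningVar = 1;
decay = 0.1;            % assumed weight given to the new mini-batch statistics
epsilon = 1e-5;         % assumed small constant for numerical stability

% Training mode: normalize with mini-batch statistics, then update the running statistics
mu = mean(x);                                      % 2.5
sigma2 = var(x,1);                                 % population variance: 1.25
yTrain = (x - mu)./sqrt(sigma2 + epsilon);
runningMean = (1-decay)*runningMean + decay*mu;    % 0.25
runningVar = (1-decay)*runningVar + decay*sigma2;  % 1.025

% Prediction mode: normalize with the stored running statistics only
yPredict = (x - runningMean)./sqrt(runningVar + epsilon);
```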

function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

% Convolution
W = parameters.conv1.Weights;
B = parameters.conv1.Bias;
dlY = dlconv(dlX,W,B,'Padding',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm1.Offset;
Scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution (Skip connection)
W = parameters.convs.Weights;
B = parameters.convs.Bias;
YSkip = dlconv(dlY,W,B,'Stride',2);

% Convolution
W = parameters.conv2.Weights;
B = parameters.conv2.Bias;
dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm2.Offset;
Scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution
W = parameters.conv3.Weights;
B = parameters.conv3.Bias;
dlY = dlconv(dlY,W,B,'Padding',1);

% Batch normalization, ReLU
Offset = parameters.batchnorm3.Offset;
Scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Addition
dlY = YSkip + dlY;

% Fully connect (angles)
W = parameters.fc1.Weights;
B = parameters.fc1.Bias;
dlY2 = fullyconnect(dlY,W,B);

% Fully connect, softmax (labels)
W = parameters.fc2.Weights;
B = parameters.fc2.Bias;
dlY1 = fullyconnect(dlY,W,B);
dlY1 = softmax(dlY1);

end
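The addition dlY = YSkip + dlY requires both branches to produce the same spatial size. The standard output-size formula for a convolution along one dimension, floor((n + 2p - k)/s) + 1 for input length n, filter size k, padding p, and stride s, shows how the two branches can line up for 28-by-28 inputs. The filter sizes below are assumptions, because the weights are loaded from digitsMIMO.mat and this example does not state their sizes.

```matlab
% Output length of a convolution along one spatial dimension
convOut = @(n,k,p,s) floor((n + 2*p - k)/s) + 1;

n0 = 28;                     % input height/width
n1 = convOut(n0,5,2,1);      % conv1: assumed 5x5 filter, padding 2 -> 28
nSkip = convOut(n1,1,0,2);   % convs: assumed 1x1 filter, stride 2 -> 14
n2 = convOut(n1,3,1,2);      % conv2: assumed 3x3 filter, padding 1, stride 2 -> 14
n3 = convOut(n2,3,1,1);      % conv3: assumed 3x3 filter, padding 1 -> 14

fprintf('skip branch: %d, main branch: %d\n', nSkip, n3)
```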
