Assemble Multiple-Output Network for Prediction

This example shows how to assemble a multiple-output network for prediction.

Instead of using the dlnetwork object for prediction, you can use the assembleNetwork function to assemble the network into a DAGNetwork object that is ready for prediction. This lets you use the predict function with other data types, such as datastores.

Load Network and Class Names

Load the trained network from the MAT file dlnetDigits.mat. The MAT file contains a dlnetwork object that predicts both the class scores for the digit labels and the numeric rotation angles of images of digits, along with the corresponding class names.

s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;

Assemble Network for Prediction

Extract the layer graph from the dlnetwork object using the layerGraph function.

lgraph = layerGraph(dlnet);

The layer graph does not include output layers. Use the addLayers and connectLayers functions to add a classification layer connected to the 'softmax' layer and a regression layer connected to the 'fc2' layer.

layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'softmax','coutput');

layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc2','routput');

View a plot of the network.

figure
plot(lgraph)

Assemble the network using the assembleNetwork function.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}

Make Predictions on New Data

Load the test data.

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

To make predictions using the assembled network, use the predict function. To return categorical labels for the classification output, set the 'ReturnCategorical' option to true.

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
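Without 'ReturnCategorical', the classification output of predict is a matrix of class scores instead of labels. As a rough illustration of how scores map to labels, the following sketch takes the highest-scoring class in each row. The score matrix and the variable names toyScores and toyClassNames are hypothetical toy values, not output of the network in this example.

```matlab
% Hypothetical class scores: one row per observation, one column per class.
toyClassNames = {'0';'1';'2'};
toyScores = [0.10 0.20 0.70;
             0.80 0.15 0.05;
             0.30 0.60 0.10];

% Column index of the highest score in each row (argmax over classes).
[~,classIdx] = max(toyScores,[],2);

% Look up the class name for each winning column.
predictedLabels = toyClassNames(classIdx)
```

Here classIdx is [3;1;2], so predictedLabels is {'2';'0';'1'}. Wrapping the result in categorical, for example categorical(predictedLabels,toyClassNames), produces a categorical array like the one 'ReturnCategorical' returns.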

Evaluate the classification accuracy.

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870
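The comparison Y1Pred==Y1Test is element-wise, so the mean of the resulting logical vector is the fraction of test images whose label is predicted correctly. The toy sketch below shows the same computation on made-up labels; it uses cell arrays of character vectors with strcmp as a stand-in for the categorical == comparison, and the variable names are hypothetical.

```matlab
% Hypothetical predicted and true labels for five observations.
toyPred = {'3';'7';'1';'1';'9'};
toyTrue = {'3';'7';'2';'1';'9'};

% Element-wise comparison gives a logical vector (true where labels match).
isCorrect = strcmp(toyPred,toyTrue);

% The mean of a logical vector is the fraction of true entries.
toyAccuracy = mean(isCorrect)
```

Four of the five labels match, so toyAccuracy is 0.8.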

Evaluate the regression accuracy by computing the root-mean-square error (RMSE) of the predicted angles.

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
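The RMSE is the square root of the mean squared difference between the predicted and true angles, so it is in the same units as the angles (degrees, as the use of tand in the plotting code below suggests). A toy computation with hypothetical angles:

```matlab
% Hypothetical predicted and true rotation angles, in degrees.
toyAnglePred = [10; -5; 30; 0];
toyAngleTrue = [12; -5; 26; 0];

% Errors are -2, 0, 4, 0 -> squared errors 4, 0, 16, 0 -> mean 5.
toyErr = toyAnglePred - toyAngleTrue;
toyRMSE = sqrt(mean(toyErr.^2))
```

The result is sqrt(5), about 2.2361. Arithmetic that mixes single and double values returns single in MATLAB, which is why angleRMSE above is displayed as a single value when the predictions are single precision.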

View some of the images with their predictions. Display the predicted angles in red and the correct angles in green, and show the predicted label in each title.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    % Draw each angle as a dashed line through the image center
    sz = size(I,1);
    offset = sz/2;
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    hold off

    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end
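Each dashed line runs from the bottom edge (y = sz) to the top edge (y = 0) of the image; imshow uses image coordinates, with y increasing downward. Its x-coordinates offset*(1-tand(theta)) and offset*(1+tand(theta)) make it pass through the image center and lean theta degrees from vertical. The sketch below checks that geometry for one angle; the image size toySz = 28 and the angle of 45 degrees are assumptions chosen only for illustration.

```matlab
% Assumed square image size and rotation angle (degrees), for illustration.
toySz = 28;
toyOffset = toySz/2;
toyTheta = 45;

% Same endpoint formula as in the plotting loop.
x = toyOffset*[1-tand(toyTheta) 1+tand(toyTheta)];  % x at bottom and top edges
y = [toySz 0];                                      % bottom edge, then top edge

% Horizontal run over vertical rise recovers the angle from vertical.
recoveredTheta = atand((x(2)-x(1))/(y(1)-y(2)))
```

For 45 degrees the line runs corner to corner, from x = 0 at the bottom to x = 28 at the top, and recoveredTheta is 45.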

See Also

Related Topics