predict

Predict responses using a trained deep learning neural network

You can make predictions using a trained deep learning neural network on either a CPU or a GPU. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. Specify the hardware resource using the 'ExecutionEnvironment' name-value pair argument.

Syntax

YPred = predict(net,X)
YPred = predict(net,C)
YPred = predict(net,X,Name,Value)

Description

YPred = predict(net,X) predicts responses for the image data in X using the trained network net.


YPred = predict(net,C) predicts responses for the sequence or time series data in C using the trained LSTM network net.


YPred = predict(net,X,Name,Value) predicts responses with additional options specified by one or more name-value pair arguments.

Examples


Load the sample data.

[XTrain,TTrain] = digitTrain4DArrayData;

digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array: each image is 28 pixels high and 28 pixels wide with 1 channel, and there are 5000 synthetic images of handwritten digits. TTrain is a categorical vector containing the label for each observation.

Construct the convolutional neural network architecture.

layers = [imageInputLayer([28 28 1]);
          convolution2dLayer(5,20);
          reluLayer();
          maxPooling2dLayer(2,'Stride',2);
          fullyConnectedLayer(10);
          softmaxLayer();
          classificationLayer()];
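To see how the image size changes through this architecture, the sketch below applies the usual output-size formula for a convolution or pooling window, floor((in - filterSize + 2*padding)/stride) + 1, along one spatial dimension. It assumes that convolution2dLayer(5,20) uses a stride of 1 and no padding and that maxPooling2dLayer uses no padding; the ReLU layer does not change the size. The anonymous function outSize is a hypothetical helper, not part of the toolbox.

```matlab
% Output size of a convolution or pooling window along one dimension
% (hypothetical helper; assumes symmetric padding)
outSize = @(in, filterSize, stride, pad) floor((in - filterSize + 2*pad)/stride) + 1;

inputSize  = 28;   % height (and width) of the digit images
numFilters = 20;   % convolution2dLayer(5,20) learns 20 filters

% convolution2dLayer(5,20): 5-by-5 filters, assumed stride 1 and no padding
convSize = outSize(inputSize, 5, 1, 0);

% maxPooling2dLayer(2,'Stride',2): 2-by-2 windows, stride 2, assumed no padding
poolSize = outSize(convSize, 2, 2, 0);

% Number of values that fullyConnectedLayer(10) receives per image
fcInputs = poolSize * poolSize * numFilters;

fprintf('conv output: %d-by-%d-by-%d\n', convSize, convSize, numFilters);
fprintf('pool output: %d-by-%d-by-%d\n', poolSize, poolSize, numFilters);
fprintf('fully connected inputs: %d\n', fcInputs);
```

Under these assumptions each image shrinks from 28-by-28 to 24-by-24 after the convolution and to 12-by-12 after pooling, so the fully connected layer maps 2880 values to 10 class scores.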

Set the training options to the default settings for stochastic gradient descent with momentum.

options = trainingOptions('sgdm');

Train the network.

rng(1)   % Seed the random number generator for repeatable weight initialization
net = trainNetwork(XTrain,TTrain,layers,options);
Training on single CPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         0.38 |       2.3028 |       11.72% |       0.0100 |
|            2 |           50 |         9.96 |       2.2653 |       30.47% |       0.0100 |
|            3 |          100 |        18.54 |       1.5949 |       48.44% |       0.0100 |
|            4 |          150 |        26.66 |       1.2292 |       58.59% |       0.0100 |
|            6 |          200 |        35.55 |       1.0559 |       64.06% |       0.0100 |
|            7 |          250 |        44.03 |       1.0304 |       64.06% |       0.0100 |
|            8 |          300 |        52.91 |       0.7178 |       78.12% |       0.0100 |
|            9 |          350 |        61.73 |       0.6900 |       78.12% |       0.0100 |
|           11 |          400 |        70.14 |       0.5104 |       85.94% |       0.0100 |
|           12 |          450 |        78.91 |       0.4311 |       89.06% |       0.0100 |
|           13 |          500 |        86.82 |       0.2796 |       92.19% |       0.0100 |
|           15 |          550 |        94.75 |       0.2389 |       96.09% |       0.0100 |
|           16 |          600 |       101.75 |       0.2566 |       92.97% |       0.0100 |
|           17 |          650 |       109.46 |       0.1773 |       96.88% |       0.0100 |
|           18 |          700 |       117.29 |       0.1260 |       99.22% |       0.0100 |
|           20 |          750 |       124.93 |       0.1297 |      100.00% |       0.0100 |
|           21 |          800 |       132.63 |       0.1080 |       97.66% |       0.0100 |
|           22 |          850 |       140.13 |       0.1176 |       98.44% |       0.0100 |
|           24 |          900 |       147.61 |       0.0762 |      100.00% |       0.0100 |
|           25 |          950 |       155.19 |       0.0774 |      100.00% |       0.0100 |
|           26 |         1000 |       162.59 |       0.0877 |       99.22% |       0.0100 |
|           27 |         1050 |       170.86 |       0.0645 |       99.22% |       0.0100 |
|           29 |         1100 |       178.18 |       0.0624 |      100.00% |       0.0100 |
|           30 |         1150 |       185.41 |       0.0488 |      100.00% |       0.0100 |
|           30 |         1170 |       188.42 |       0.0816 |       99.22% |       0.0100 |
|=========================================================================================|
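The Epoch and Iteration columns of the log are consistent with simple arithmetic on the training options. The sketch below assumes 5000 training images, the trainingOptions('sgdm') defaults of a mini-batch size of 128 and 30 epochs, and that the final incomplete mini-batch of each epoch is discarded. Under these assumptions it reproduces the 1170 total iterations and the epoch numbers shown in the log; epochOf is a hypothetical helper.

```matlab
% Assumed values: 5000 training images and the trainingOptions('sgdm')
% defaults MiniBatchSize = 128 and MaxEpochs = 30
numObservations = 5000;
miniBatchSize   = 128;
maxEpochs       = 30;

% Assumes the last, incomplete mini-batch of each epoch is discarded
iterationsPerEpoch = floor(numObservations / miniBatchSize);
totalIterations    = iterationsPerEpoch * maxEpochs;

% Epoch in which a given iteration runs (hypothetical helper)
epochOf = @(iteration) ceil(iteration / iterationsPerEpoch);

fprintf('%d iterations per epoch, %d in total\n', iterationsPerEpoch, totalIterations);
fprintf('iteration 200 runs in epoch %d\n', epochOf(200));
```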

Run the trained network on a test set and predict the scores.

[XTest,TTest] = digitTest4DArrayData;
YTestPred = predict(net,XTest);

By default, predict uses a CUDA-enabled GPU with compute capability 3.0 or higher when one is available. To run predict on the CPU instead, specify the 'ExecutionEnvironment','cpu' name-value pair argument.

Display the labels of the first 10 images in the test data and compare them to the predictions from predict.

TTest(1:10,:)
ans = 

  10x1 categorical array

     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 
     0 

YTestPred(1:10,:)
ans =

  10x10 single matrix

  Columns 1 through 7

    0.9993    0.0000    0.0002    0.0003    0.0000    0.0000    0.0001
    0.8578    0.0000    0.0552    0.0003    0.0000    0.0002    0.0139
    0.9999    0.0000    0.0000    0.0000    0.0000    0.0000    0.0000
    0.9558    0.0000    0.0000    0.0000    0.0000    0.0000    0.0060
    0.9616    0.0000    0.0041    0.0001    0.0000    0.0000    0.0004
    0.9915    0.0000    0.0005    0.0000    0.0000    0.0000    0.0016
    0.9733    0.0000    0.0003    0.0000    0.0000    0.0000    0.0247
    1.0000    0.0000    0.0000    0.0000    0.0000    0.0000    0.0000
    0.9126    0.0000    0.0016    0.0002    0.0003    0.0007    0.0001
    0.9408    0.0000    0.0102    0.0020    0.0001    0.0001    0.0278

  Columns 8 through 10

    0.0000    0.0000    0.0002
    0.0001    0.0035    0.0690
    0.0000    0.0000    0.0001
    0.0000    0.0010    0.0372
    0.0002    0.0335    0.0002
    0.0000    0.0044    0.0020
    0.0000    0.0016    0.0001
    0.0000    0.0000    0.0000
    0.0000    0.0012    0.0834
    0.0000    0.0143    0.0047

TTest contains the digits corresponding to the images in XTest. Each column of YTestPred contains predict's estimate of the probability that an image shows a particular digit: the first column corresponds to digit 0, the second column to digit 1, the third column to digit 2, and so on. For each of the first 10 images, the estimated probability of the correct digit is high (between 0.8578 and 1.0000), and the probability of every other digit is low (at most 0.0834). The highest score for each of the first 10 observations therefore corresponds to digit 0, which matches TTest.
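predict returns scores rather than labels. One way to turn the scores into digits is to take the column index of the largest score in each row and subtract 1, because column k corresponds to digit k-1. The sketch below does this for the first three rows of YTestPred, with the values copied from the output above; the variable names are illustrative. The classify function returns class labels directly.

```matlab
% First three rows of YTestPred from the output above (columns = digits 0 to 9)
scores = [0.9993 0.0000 0.0002 0.0003 0.0000 0.0000 0.0001 0.0000 0.0000 0.0002
          0.8578 0.0000 0.0552 0.0003 0.0000 0.0002 0.0139 0.0001 0.0035 0.0690
          0.9999 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0001];

% Highest score and its column index in each row (one row per observation)
[maxScore, colIdx] = max(scores, [], 2);

% Column k holds the score for digit k-1
predictedDigit = colIdx - 1;

disp([predictedDigit maxScore])
```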

Load the pretrained network. JapaneseVowelsNet is a pretrained LSTM network trained on the Japanese Vowels data set as described in [1] and [2]. It was trained on sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet
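Sorting sequences by length before dividing them into mini-batches keeps sequences of similar length together, so less padding is needed when each mini-batch is padded to its longest sequence. The sketch below shows one way to sort a cell array of D-by-S sequences by length with cellfun and sort, and then counts the padded time steps with and without sorting for a set of hypothetical sequence lengths. It uses a mini-batch size of 2 to keep the arithmetic small; the helpers padInBatch, batchIdx, and totalPadding are illustrative, and this is not the code used to train JapaneseVowelsNet.

```matlab
% Sort a hypothetical cell array of 12-feature sequences by number of time steps
C = {rand(12,3); rand(12,9); rand(12,4)};
sequenceLengths = cellfun(@(x) size(x,2), C);
[~, order] = sort(sequenceLengths);
CSorted = C(order);

% Padding needed when each mini-batch is padded to its longest sequence
seqLengths    = [3 9 4 8 5 10];   % hypothetical lengths
miniBatchSize = 2;                % small for illustration; JapaneseVowelsNet used 27
batchStarts   = 1:miniBatchSize:numel(seqLengths);

padInBatch   = @(lens) sum(max(lens) - lens);               % pad to the longest
batchIdx     = @(b, n) b:min(b + miniBatchSize - 1, n);      % indices of one mini-batch
totalPadding = @(lens) sum(arrayfun(@(b) padInBatch(lens(batchIdx(b, numel(lens)))), batchStarts));

unsortedPadding = totalPadding(seqLengths);         % batches [3 9], [4 8], [5 10]
sortedPadding   = totalPadding(sort(seqLengths));   % batches [3 4], [5 8], [9 10]
fprintf('padded steps: %d unsorted, %d sorted\n', unsortedPadding, sortedPadding);
```

Here sorting cuts the padding from 15 to 5 time steps.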

View the network architecture.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1', '2', and 7 other classes

Load the test data.

load JapaneseVowelsTest

Make predictions on the test data.

YPred = predict(net,XTest);

View the prediction scores for the first 10 sequences.

YPred(1:10,:)
ans = 10×9 single matrix

    0.9918    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0006    0.0059
    0.9868    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0010    0.0105
    0.9924    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0006    0.0054
    0.9896    0.0000    0.0000    0.0000    0.0006    0.0009    0.0001    0.0007    0.0080
    0.9965    0.0000    0.0000    0.0000    0.0007    0.0009    0.0000    0.0003    0.0016
    0.9888    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0008    0.0087
    0.9886    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0008    0.0089
    0.9982    0.0000    0.0000    0.0000    0.0006    0.0007    0.0000    0.0001    0.0004
    0.9883    0.0000    0.0000    0.0000    0.0006    0.0010    0.0001    0.0008    0.0093
    0.9959    0.0000    0.0000    0.0000    0.0007    0.0011    0.0000    0.0004    0.0019

Compare these prediction scores to the labels of these sequences, shown below. For each of the first 10 sequences, the function assigns a score above 0.98 to the correct class, class 1.

YTest(1:10)
ans = 10×1 categorical array
     1 
     1 
     1 
     1 
     1 
     1 
     1 
     1 
     1 
     1 

Input Arguments


Trained network net, specified as a SeriesNetwork or DAGNetwork object. You can get a trained network by importing a pretrained network (for example, by using the alexnet function) or by training your own network using trainNetwork.

Image data X, specified as a 3-D array representing a single image, a 4-D array of images, an ImageDatastore, or a table.

  • If X is a 3-D array representing a single image, then X is an h-by-w-by-c array, where h, w, and c are the height, width, and number of channels of the image, respectively.

  • If X is a 4-D array of images, then X is an h-by-w-by-c-by-N array, where N is the number of images.

  • If X is a table, then the first column contains either paths to images or 3-D arrays representing images.

For more information about image datastores, see ImageDatastore.
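To pass several images that are stored as separate h-by-w-by-c arrays, you can concatenate them along the fourth dimension. The sketch below uses hypothetical random 28-by-28 grayscale images; the variable names are illustrative.

```matlab
h = 28; w = 28; c = 1;     % height, width, and number of channels
img1 = rand(h, w, c);      % hypothetical single images
img2 = rand(h, w, c);
img3 = rand(h, w, c);

% Stack the images into an h-by-w-by-c-by-N array (here N = 3)
X = cat(4, img1, img2, img3);

% Index the fourth dimension to recover a single h-by-w-by-c image
secondImage = X(:, :, :, 2);

disp(size(X))
```

A single image such as img1 is itself a valid 3-D input.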

Sequence or time series data C, specified as a matrix representing a single time series or a cell array of matrices representing multiple time series.

  • If C is a matrix representing a single time series, then C is a D-by-S matrix, where D is the number of data points per time step and S is the number of time steps.

  • If C is a cell array of time series, then C is an N-by-1 cell array, where N is the number of observations. Each entry of C is a D-by-S matrix representing one time series, with rows corresponding to data points and columns corresponding to time steps. The number of time steps S can differ between entries.
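Each entry of C must have its time steps along the columns. If your data is stored with one row per time step, transpose each entry first. The sketch below builds a hypothetical N-by-1 cell array from S-by-D matrices and checks that every entry has the same number of rows D; the 12 features match the sequence input layer of JapaneseVowelsNet shown above, and the sizes and names are illustrative.

```matlab
numFeatures = 12;   % D, the number of data points per time step

% Hypothetical raw data: one row per time step (S-by-D)
rawData = {rand(20, numFeatures); rand(7, numFeatures); rand(15, numFeatures)};

% Transpose each entry so rows are data points and columns are time steps (D-by-S)
C = cellfun(@(x) x.', rawData, 'UniformOutput', false);

% Every sequence must have D rows; the number of columns S may differ
numRows  = cellfun(@(x) size(x,1), C);
numSteps = cellfun(@(x) size(x,2), C);
assert(all(numRows == numFeatures), 'Each sequence must be D-by-S.');

disp(numSteps.')
```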

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside single quotes (' ').

Example: 'MiniBatchSize',256 specifies the mini-batch size as 256.


Size of mini-batches to use for prediction, specified as the comma-separated pair consisting of 'MiniBatchSize' and a positive integer. Larger mini-batch sizes require more memory but can lead to faster predictions.

Example: 'MiniBatchSize',256
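Because predict returns a score for every observation, the observations are covered by ceil(N/MiniBatchSize) mini-batches. The sketch below lists the size of each mini-batch for a hypothetical 5000 observations and a mini-batch size of 256. It assumes the final mini-batch simply holds the remaining observations; it illustrates the partitioning only and is not predict's internal implementation.

```matlab
numObservations = 5000;   % hypothetical number of observations
miniBatchSize   = 256;    % value passed as 'MiniBatchSize'

numBatches = ceil(numObservations / miniBatchSize);
batchSizes = zeros(1, numBatches);
for i = 1:numBatches
    firstIdx = (i - 1)*miniBatchSize + 1;               % first observation in mini-batch i
    lastIdx  = min(i*miniBatchSize, numObservations);   % last observation in mini-batch i
    batchSizes(i) = lastIdx - firstIdx + 1;
end

fprintf('%d mini-batches, the last one holds %d observations\n', numBatches, batchSizes(end));
```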

Hardware resource, specified as the comma-separated pair consisting of 'ExecutionEnvironment' and one of the following:

  • 'auto' — Use a GPU if one is available; otherwise, use the CPU.

  • 'gpu' — Use the GPU. Using a GPU requires Parallel Computing Toolbox and a CUDA-enabled NVIDIA GPU with compute capability 3.0 or higher. If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.

  • 'cpu' — Use the CPU.

Example: 'ExecutionEnvironment','cpu'

Option to pad, truncate, or split input sequences, specified as the comma-separated pair consisting of 'SequenceLength' and one of the following:

  • 'longest' — Pad sequences in each mini-batch to have the same length as the longest sequence.

  • 'shortest' — Truncate sequences in each mini-batch to have the same length as the shortest sequence.

  • Positive integer — Pad sequences in each mini-batch to have the same length as the longest sequence, then split into smaller sequences of the specified length. If splitting occurs, then the function creates extra mini-batches.

Example: 'SequenceLength','shortest'

Value by which to pad input sequences, specified as the comma-separated pair consisting of 'SequencePaddingValue' and a scalar. This option is valid only when SequenceLength is 'longest' or a positive integer. Do not pad sequences with NaN, because doing so can propagate errors throughout the network.

Example: 'SequencePaddingValue',-1
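The sketch below mimics the three SequenceLength options for a single hypothetical mini-batch of D-by-S sequences, using a padding value of -1 in the role of SequencePaddingValue. It assumes padding is added at the end of each sequence and that, when splitting, the final chunk is left shorter than the specified length rather than padded further. The helper padTo and the other names are hypothetical; this is an illustration of the options, not predict's implementation.

```matlab
padValue = -1;                               % plays the role of 'SequencePaddingValue'
batch = {ones(2,3); ones(2,5); ones(2,4)};   % hypothetical mini-batch, D = 2
lens  = cellfun(@(x) size(x,2), batch);      % time steps per sequence: 3, 5, 4

% Pad a D-by-S matrix on the right to n time steps (hypothetical helper)
padTo = @(x, n) [x, padValue*ones(size(x,1), n - size(x,2))];

% 'longest': pad every sequence to the longest length in the mini-batch
longestBatch = cellfun(@(x) padTo(x, max(lens)), batch, 'UniformOutput', false);

% 'shortest': truncate every sequence to the shortest length in the mini-batch
shortestBatch = cellfun(@(x) x(:, 1:min(lens)), batch, 'UniformOutput', false);

% Positive integer L: pad to the longest length, then split into chunks of
% L time steps; each chunk becomes an extra mini-batch
L = 2;
S = max(lens);
chunkStarts = 1:L:S;
chunks = cell(1, numel(chunkStarts));
for k = 1:numel(chunkStarts)
    steps = chunkStarts(k):min(chunkStarts(k) + L - 1, S);   % time steps in chunk k
    chunks{k} = cellfun(@(x) x(:, steps), longestBatch, 'UniformOutput', false);
end
```

With L = 2 and a longest sequence of 5 time steps, the padded mini-batch is split into three mini-batches covering time steps 1 to 2, 3 to 4, and 5.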

Output Arguments


Predicted scores YPred, returned as one of the following:

  • For image classification and sequence-to-label classification networks, YPred is an N-by-K matrix, where N is the number of observations and K is the number of classes.

  • For sequence-to-sequence classification networks, YPred is an N-by-1 cell array of matrices, where N is the number of observations. Each entry of YPred is a K-by-S matrix, where K is the number of classes and S is the number of time steps in the corresponding entry of C.

  • For regression networks, YPred is one of the following:

    • N-by-r matrix, where N is the number of observations and r is the number of responses.

    • h-by-w-by-c-by-N numeric array, where N is the number of observations and h-by-w-by-c is the size of a single response.
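For a sequence-to-sequence classification network, each entry of YPred is K-by-S, so the most likely class at each time step is found by taking the maximum down each column (dimension 1) rather than along each row. The sketch below uses a hypothetical 3-by-4 score matrix with K = 3 classes and S = 4 time steps.

```matlab
% Hypothetical scores for one sequence: rows = classes (K = 3),
% columns = time steps (S = 4); each column sums to 1
Y = [0.7 0.2 0.1 0.3
     0.2 0.5 0.1 0.3
     0.1 0.3 0.8 0.4];

% Index of the most likely class at each time step (maximum down each column)
[~, classPerStep] = max(Y, [], 1);

disp(classPerStep)
```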

Algorithms

If the image data contains NaNs, predict propagates them through the network. If the network has ReLU layers, these layers ignore NaNs. However, if the network does not have a ReLU layer, then predict returns NaNs as predictions.
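The difference comes from how NaN interacts with the operations inside the layers. The sketch below assumes a ReLU behaves like max(x,0), which in MATLAB returns 0 for NaN because max ignores NaN values, whereas a fully connected operation W*x + b spreads a NaN in x to every output. W, b, and x are hypothetical values, not taken from a trained network.

```matlab
x = [0.5; NaN; -2];              % hypothetical activations containing a NaN

% ReLU modeled as max(x,0): max ignores NaN, so the NaN becomes 0
reluOut = max(x, 0);

% Fully connected operation: the NaN contaminates every output
W = [1 2 3; 4 5 6];              % hypothetical weights (2 outputs, 3 inputs)
b = [0.1; 0.2];
fcOut = W*x + b;

% The same operation applied after the ReLU gives finite values
fcAfterRelu = W*reluOut + b;

disp([fcOut fcAfterRelu])
```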

Alternatives

You can compute the predicted scores and the predicted classes from a trained network using classify.

You can also compute the activations from a network layer using activations.

For sequence-to-label and sequence-to-sequence classification networks (LSTM networks), you can make predictions and update the network state using classifyAndUpdateState and predictAndUpdateState.

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters, Vol. 20, No. 11–13, 1999, pp. 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Introduced in R2016a
