Classify data using a trained recurrent neural network and update the network state
You can make predictions using a trained deep learning network on either a CPU or a GPU. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox). Specify the hardware using the 'ExecutionEnvironment' name-value pair argument.
[updatedNet,YPred] = classifyAndUpdateState(recNet,sequences) classifies the data in sequences using the trained recurrent neural network recNet and updates the network state.

This function supports recurrent neural networks only. The input recNet must have at least one recurrent layer.
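A minimal sketch of the basic syntax, assuming MATLAB with Deep Learning Toolbox. To stay self-contained, it first trains a tiny, hypothetical LSTM classifier on random data, so the predicted labels carry no meaning; only the call pattern matters. It then feeds a new sequence to classifyAndUpdateState one time step at a time. Because each call returns the updated network, the label at step t depends on all time steps up to t. resetState clears the state before the new sequence starts.

```matlab
% Hypothetical setup: a tiny LSTM classifier trained briefly on random data.
rng(0);
XTrain = arrayfun(@(n) randn(3, randi([5 15])), (1:20)', 'UniformOutput', false);
YTrain = categorical(repmat([1; 2], 10, 1));   % two made-up classes
layers = [
    sequenceInputLayer(3)
    lstmLayer(8, 'OutputMode', 'last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
net = trainNetwork(XTrain, YTrain, layers, ...
    trainingOptions('adam', 'MaxEpochs', 2, 'Verbose', false));

% Classify a new (hypothetical) sequence one time step at a time.
X = randn(3, 10);                          % 3 features, 10 time steps
numTimeSteps = size(X, 2);
net = resetState(net);                     % start from the initial state
labelsPred = categorical(nan(1, numTimeSteps));   % preallocate as <undefined>
for t = 1:numTimeSteps
    % Each call sees only column t; the state stored in the returned
    % network carries information from columns 1..t forward.
    [net, label] = classifyAndUpdateState(net, X(:, t));
    labelsPred(t) = label;
end
disp(labelsPred)
```

Reassigning the output to net on every iteration is what makes the state persist; calling the function without capturing updatedNet would classify each time step from the same starting state.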
[updatedNet,YPred] = classifyAndUpdateState(___,Name,Value) uses any of the arguments in the previous syntaxes and additional options specified by one or more Name,Value pair arguments. For example, 'MiniBatchSize',27 classifies data using mini-batches of size 27.
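The sketch below passes several name-value pairs in one call. The toy network and test sequences are hypothetical. canUseGPU is an assumption not taken from this page (a base MATLAB function in R2019b and later); since the default 'ExecutionEnvironment','auto' already prefers a supported GPU, the explicit branch only makes the choice visible.

```matlab
% Hypothetical setup: a tiny LSTM classifier trained briefly on random data.
rng(0);
XTrain = arrayfun(@(n) randn(3, randi([5 15])), (1:20)', 'UniformOutput', false);
YTrain = categorical(repmat([1; 2], 10, 1));
layers = [
    sequenceInputLayer(3)
    lstmLayer(8, 'OutputMode', 'last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
net = trainNetwork(XTrain, YTrain, layers, ...
    trainingOptions('adam', 'MaxEpochs', 2, 'Verbose', false));

% Three hypothetical test sequences of different lengths.
XTest = {randn(3, 8); randn(3, 40); randn(3, 12)};

% Choose the hardware explicitly (assumes canUseGPU, R2019b or later).
if canUseGPU
    env = 'gpu';
else
    env = 'cpu';
end

net = resetState(net);
[net, YPred] = classifyAndUpdateState(net, XTest, ...
    'MiniBatchSize', 27, ...            % all three sequences fit in one mini-batch
    'SequenceLength', 'longest', ...    % pad each mini-batch to its longest sequence
    'ExecutionEnvironment', env);
disp(YPred)
```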
[updatedNet,YPred,scores] = classifyAndUpdateState(___) uses any of the arguments in the previous syntaxes, returns a matrix of classification scores, and updates the network state.
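A sketch of the three-output syntax with the same kind of hypothetical toy network. It assumes that the Classes property of the output layer (present in recent releases) lists the classes in the same order as the columns of scores, so the column with the largest score should correspond to YPred.

```matlab
% Hypothetical setup: a tiny LSTM classifier trained briefly on random data.
rng(0);
XTrain = arrayfun(@(n) randn(3, randi([5 15])), (1:20)', 'UniformOutput', false);
YTrain = categorical(repmat([1; 2], 10, 1));
layers = [
    sequenceInputLayer(3)
    lstmLayer(8, 'OutputMode', 'last')
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
net = trainNetwork(XTrain, YTrain, layers, ...
    trainingOptions('adam', 'MaxEpochs', 2, 'Verbose', false));

% Classify one hypothetical sequence and keep the scores.
X = randn(3, 10);
net = resetState(net);
[net, YPred, scores] = classifyAndUpdateState(net, X);

% scores has one row per observation and one column per class.
classNames = net.Layers(end).Classes;     % assumed class order of the columns
[~, idx] = max(scores, [], 2);
fprintf('Predicted %s (score %.3f)\n', string(classNames(idx)), scores(idx));
```

Because the last layer before the output is a softmax layer, each row of scores sums to 1 up to single-precision rounding.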
Tip
When making predictions with sequences of different lengths, the mini-batch size can affect the amount of padding added to the input data, which can result in different predicted values. Try using different values to see which works best with your network. To specify the mini-batch size and padding options, use the 'MiniBatchSize' and 'SequenceLength' options, respectively.
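To see why the mini-batch size matters, the following plain-MATLAB sketch counts padded time steps under a simplified model: mini-batches are taken in the given order and, as with 'SequenceLength','longest', every sequence is padded to the longest sequence in its mini-batch. The sequence lengths are made up, and the model is an illustration, not the toolbox implementation. The second row repeats the count after sorting the sequences by length.

```matlab
% Hypothetical sequence lengths (time steps) for six observations.
seqLengths = [5 30 6 28 7 29];
miniBatchSizes = [1 2 3 6];

orders = {seqLengths, sort(seqLengths)};   % original order, then sorted by length
padding = zeros(numel(orders), numel(miniBatchSizes));

for o = 1:numel(orders)
    lens = orders{o};
    for k = 1:numel(miniBatchSizes)
        mbs = miniBatchSizes(k);
        % Walk through consecutive mini-batches of size mbs (the last may be smaller).
        for first = 1:mbs:numel(lens)
            last = min(first + mbs - 1, numel(lens));
            batch = lens(first:last);
            % Pad every sequence up to the longest one in this mini-batch.
            padding(o, k) = padding(o, k) + sum(max(batch) - batch);
        end
    end
end

disp(array2table(padding, ...
    'VariableNames', compose('MiniBatchSize_%d', miniBatchSizes), ...
    'RowNames', {'original order', 'sorted by length'}))
```

With these lengths, a mini-batch size of 3 adds 72 padded time steps in the original order but only 6 after sorting, while a mini-batch size of 1 adds none.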
When you train a network using the trainNetwork function, or when you use prediction or validation functions with DAGNetwork and SeriesNetwork objects, the software performs these computations using single-precision floating-point arithmetic. Functions for training, prediction, and validation include trainNetwork, predict, classify, and activations.

The software uses single-precision arithmetic whether you train networks on CPUs or GPUs.
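As a small illustration of what single precision means in practice (plain MATLAB, independent of Deep Learning Toolbox), this sketch compares the machine epsilon of single and double and measures the rounding error of storing 0.1 in single. Rounding of this size is one reason predictions may not match a double-precision reference bit for bit, so the last lines compare two hypothetical score vectors with a relative tolerance instead of exact equality.

```matlab
% Machine epsilon: spacing between 1 and the next representable number.
epsSingle = eps('single');   % 2^-23, about 1.19e-07
epsDouble = eps('double');   % 2^-52, about 2.22e-16

% 0.1 has no exact binary representation; converting it to single rounds it
% to fewer significant bits than double keeps.
x = 0.1;
roundingErr = abs(double(single(x)) - x);   % about 1.5e-09

fprintf('eps(single) = %g\neps(double) = %g\n', epsSingle, epsDouble);
fprintf('|single(0.1) - 0.1| = %g\n', roundingErr);

% Compare hypothetical scores with a relative tolerance, not exact equality.
scoresRef    = [0.2 0.7 0.1];        % double-precision reference values
scoresSingle = single(scoresRef);    % the same values rounded to single
tol = 4 * epsSingle;                 % a few units of single-precision roundoff
isClose = all(abs(double(scoresSingle) - scoresRef) <= tol * abs(scoresRef));
```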
See Also
bilstmLayer | classify | gruLayer | lstmLayer | predict | predictAndUpdateState | resetState | sequenceInputLayer