trainSoftmaxLayer

Train a softmax layer for classification

Syntax

net = trainSoftmaxLayer(X,T)
net = trainSoftmaxLayer(X,T,Name,Value)

Description

net = trainSoftmaxLayer(X,T) trains a softmax layer, net, on the input data X and the targets T.

net = trainSoftmaxLayer(X,T,Name,Value) trains a softmax layer, net, with additional options specified by one or more of the Name,Value pair arguments.

For example, you can specify the loss function.
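For instance, a sketch of a call that sets several options at once (it assumes `X` and `T` have already been loaded, as in the example below):

```matlab
% Train with mean squared error loss, at most 400 training epochs,
% and no training progress window. X and T are assumed to be loaded.
net = trainSoftmaxLayer(X, T, 'LossFunction', 'mse', ...
    'MaxEpochs', 400, 'ShowProgressWindow', false);
```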

Examples

Load the sample data.

[X,T] = iris_dataset;

X is a 4-by-150 matrix of four attributes of iris flowers: sepal length, sepal width, petal length, and petal width.

T is a 3-by-150 matrix of associated class vectors defining which of the three classes each input is assigned to. Each row corresponds to a dummy variable representing one of the three iris species (classes). In each column, a 1 in one of the three rows indicates the class that the sample (observation) belongs to, and the other two rows contain zeros.

Train a softmax layer using the sample data.

net = trainSoftmaxLayer(X,T);

Classify the observations into one of the three classes using the trained softmax layer.

Y = net(X);

Plot the confusion matrix using the targets and the classifications obtained from the softmax layer.

plotconfusion(T,Y);
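To turn the soft class scores in Y into hard labels, take the row index of the largest entry in each column; vec2ind performs exactly this conversion. A short continuation of the example above:

```matlab
% vec2ind returns, for each column, the row index of its largest entry,
% i.e. the predicted class (1, 2, or 3) of each observation.
predictedClasses = vec2ind(Y);
trueClasses = vec2ind(T);

% Fraction of observations assigned to the correct class.
accuracy = sum(predictedClasses == trueClasses) / numel(trueClasses);
```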

Input Arguments

Training data, specified as an m-by-n matrix, where m is the number of variables in the training data and n is the number of observations (examples). Hence, each column of X represents a sample.

Data Types: single | double

Target data, specified as a k-by-n matrix, where k is the number of classes and n is the number of observations. Each row is a dummy variable representing a particular class. In other words, each column represents a sample, and all entries of a column are zero except for a single 1, whose row indicates the class of that sample.

Data Types: single | double
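If class membership is stored as integer labels rather than dummy variables, it can be converted into the one-hot form expected for T. A minimal sketch using ind2vec (the labels variable is illustrative):

```matlab
% Integer class labels for five observations, values in 1..3.
labels = [1 3 2 3 1];

% ind2vec places a 1 in row labels(j) of column j and returns a
% sparse matrix; full() converts it to the dense 3-by-5 matrix T.
T = full(ind2vec(labels));
```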

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside quotes. You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: 'MaxEpochs',400,'ShowProgressWindow',false specifies the maximum number of iterations as 400 and hides the training window.

Maximum number of training iterations, specified as the comma-separated pair consisting of 'MaxEpochs' and a positive integer value.

Example: 'MaxEpochs',500

Data Types: single | double

Loss function for the softmax layer, specified as the comma-separated pair consisting of 'LossFunction' and either 'crossentropy' or 'mse'.

'mse' stands for the mean squared error function, which is given by:

$$E = \frac{1}{n}\sum_{j=1}^{n}\sum_{i=1}^{k}\left(t_{ij} - y_{ij}\right)^2,$$

where n is the number of training examples and k is the number of classes. $t_{ij}$ is the (i,j)th entry of the target matrix T, and $y_{ij}$ is the ith output of the softmax layer when the input vector is $x_j$.

The cross entropy function, 'crossentropy', is given by:

$$E = -\frac{1}{n}\sum_{j=1}^{n}\sum_{i=1}^{k}\left[t_{ij}\ln y_{ij} + (1 - t_{ij})\ln(1 - y_{ij})\right].$$

Example: 'LossFunction','mse'
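Both loss functions can be evaluated directly from a target matrix and a matrix of softmax outputs. A minimal sketch with illustrative data (the variable names are not part of the trainSoftmaxLayer API):

```matlab
% One-hot targets and softmax outputs for 3 classes, 4 observations.
T = [1 0 0 1; 0 1 0 0; 0 0 1 0];
Y = [0.8 0.1 0.2 0.6; 0.1 0.7 0.1 0.3; 0.1 0.2 0.7 0.1];

n = size(T, 2);   % number of observations

% Mean squared error ('mse'): average squared deviation per observation.
mseLoss = sum(sum((T - Y).^2)) / n;

% Cross entropy ('crossentropy'): heavily penalizes confident mistakes.
ceLoss = -sum(sum(T.*log(Y) + (1 - T).*log(1 - Y))) / n;
```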

Indicator to display the training window during training, specified as the comma-separated pair consisting of 'ShowProgressWindow' and either true or false.

Example: 'ShowProgressWindow',false

Data Types: logical

Training algorithm used to train the softmax layer, specified as the comma-separated pair consisting of 'TrainingAlgorithm' and 'trainscg', which stands for scaled conjugate gradient.

Example: 'TrainingAlgorithm','trainscg'

Output Arguments

Softmax layer for classification, returned as a network object. The output of net is the same size as the target T: one row per class and one column per observation.

Introduced in R2015b