LSTM - Set special loss function

Oliver Köhn on 26 Apr 2018
Edited: Stuart Whipp on 12 Dec 2018
Based on this great MATLAB example, I would like to build a neural network that classifies each timestep of a timeseries (x_i,y_i) (i=1:N) as 1 or 2. For training purposes, I created 500 different timeseries and the corresponding target vectors. In reality, about 85 % of a timeseries is in state 1 and the rest (15 %) in state 2.
During training, the accuracy stagnates at about 85 %, because with the unweighted loss of the classificationLayer (cross-entropy), a policy that classifies every timestep as 1 already reaches a "good" accuracy of 85 %. So nearly every timestep is classified as 1. I am quite sure this can be avoided by using another loss function, but unfortunately I do not know how to create an arbitrary loss function.
I would like to adapt the loss function so that if it falsely classifies a true state 2 as 1, the loss is weighted by a factor f.
How can this be done?
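One common way to write such a loss down is a class-weighted cross-entropy. This is only a sketch of the idea, not a formula taken from the MATLAB documentation. The symbols are assumptions: Y is the softmax output, T the one-hot target, k indexes the 2 classes, n the sequences in a mini-batch of size N_b, s the timesteps, and w_k is the per-class weight.

```latex
L = -\frac{1}{N_b} \sum_{n=1}^{N_b} \sum_{s=1}^{S} \sum_{k=1}^{2}
      w_k \, T_{k,n,s} \, \log Y_{k,n,s},
\qquad w_1 = 1, \quad w_2 = f
```

For a timestep whose true state is 2, only the term -f log Y_2 remains, so a confident but wrong prediction of state 1 (small Y_2) is penalised f times more strongly than the corresponding error on a true state 1.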
This is how the training is set up:
featureDimension = 2;
numHiddenUnits = 100;
numClasses = 2;
layers = [ ...
    sequenceInputLayer(featureDimension)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
options = trainingOptions('adam', ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',20, ...
    'Verbose',0, ...
    'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
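A possible starting value for the factor f is the ratio of the class frequencies in the training labels. The snippet below is a minimal sketch of that heuristic. The small YTrain it builds is synthetic stand-in data (an assumption), laid out as a cell array of 1-by-S categorical sequences like the labels passed to trainNetwork.

```matlab
% Synthetic stand-in for YTrain (hypothetical data, about 85 % state 1).
YTrain = { ...
    categorical([1 1 1 1 2], [1 2]); ...
    categorical([1 1 1 1 1 1 1 2 2 1], [1 2]); ...
    categorical([1 1 1 1 1], [1 2])};

% Concatenate all label sequences into one long row vector.
allLabels = [YTrain{:}];

% Count how often each class occurs (order follows categories(allLabels)).
counts = countcats(allLabels);

% Heuristic: weight errors on the rare state 2 by the frequency ratio.
f = counts(1) / counts(2);
classWeights = [1; f];
```

Whether this ratio is the best choice for f depends on the data, so it is worth treating it as a tunable value.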

Answers (2)

sii on 18 May 2018
If it is still relevant, check the following documentation. As far as I know, it is only available in the R2018a release.

Stuart Whipp on 10 Dec 2018
I think what is needed is a weighted classification output layer, so that you can account for the imbalance in your classes. A custom layer tutorial exists for this, but it seemingly only works for image classification problems.
  1 Comment
Stuart Whipp on 12 Dec 2018
Edited: Stuart Whipp on 12 Dec 2018
Conor Daly (staff) kindly answered my question this morning with a custom output layer that I can confirm has worked for my use case. Please take a look at this link, as I believe it also answers your question.
Regards, Stuart
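For readers who cannot follow the link, a weighted output layer for sequence data could look roughly like the sketch below. It is not the layer from the linked answer. It assumes the custom output layer interface (nnet.layer.ClassificationLayer with forwardLoss and backwardLoss) and that, for sequence-to-sequence problems, Y and T arrive as K-by-N-by-S arrays (classes, observations, timesteps). The class name weightedSeqClassificationLayer and the normalisation by N are illustrative choices. The class has to be saved in a file named weightedSeqClassificationLayer.m.

```matlab
classdef weightedSeqClassificationLayer < nnet.layer.ClassificationLayer
    % Sketch of a class-weighted cross-entropy output layer (hypothetical name).

    properties
        ClassWeights   % K-by-1 vector with one weight per class
    end

    methods
        function layer = weightedSeqClassificationLayer(classWeights, name)
            % Store the weights as a column so they expand over N and S.
            layer.ClassWeights = classWeights(:);
            if nargin > 1
                layer.Name = name;
            end
        end

        function loss = forwardLoss(layer, Y, T)
            % Y: softmax output, T: one-hot targets, both K-by-N-by-S.
            N = size(Y, 2);
            weighted = layer.ClassWeights .* T .* log(Y);
            loss = -sum(weighted(:)) / N;
        end

        function dLdY = backwardLoss(layer, Y, T)
            % Derivative of the weighted cross-entropy with respect to Y.
            N = size(Y, 2);
            dLdY = -(layer.ClassWeights .* T ./ Y) / N;
        end
    end
end
```

To try it, replace classificationLayer in the layers array with weightedSeqClassificationLayer([1; f]), where f is the factor for misclassified state 2 timesteps, and train as before.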