how to define custom loss function like tf.train.A​​damOptimi​z​er().min​im​ize(los​s) ?

6 views (last 30 days)
Yiyan Huang
Yiyan Huang on 23 Oct 2019
Answered: Sai Bhargav Avula on 18 Feb 2020
I put X=[4,100] into LSTM and received the last hidden state[4,1]. I use this state vector to estimate Q. I want to use Adam Optimizer to minimize a custom loss function. But options offered by official tool cannot do it. So I want to define a loss function and use Adam Optimizer to min the loss. How can I do it like python:train = tf.train.AdamOptimizer().minimize(loss)?
z=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.99];
z_normal=norminv(z);
numFeatures = 4;
A=4;
X = reshape(X_test.',4,100,[]);
Y = reshape(Y_test.',1,[]);
dlX = dlarray(X,'CBT');
dlY = dlarray(Y,'BT');
numHiddenUnits = 16;
outputdim=4;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
[outputs,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);
state_h=hiddenState(:,end);
w_out=normrnd(0,1,[outputdim,numHiddenUnits]);
b_out=normrnd(0,1,[outputdim,1]);
out=w_out*state_h+b_out;
params=out+[0;1;1;1];
mu=params(1,:);
sig=params(2,:);
utail=params(3,:);
vtail=params(4,:);
factor1=exp(z_normal.'*utail)/A+1;
factor2=exp(z_normal.'*utail)/A+1;
factor=factor1.*factor2;
Q=factor.*z_normal.'.*sig+mu;
error=dlY-Q;
error1=z.'.*error;
error2=(z.'-1).*error;
loss=mean(max(error1,error2));

Answers (1)

Sai Bhargav Avula
Sai Bhargav Avula on 18 Feb 2020
Hi,
You can create custom loss function by creating a function of the form loss = myLoss(Y,T), where Y is the network predictions, T are the targets. The loss can be used to update the gradients in the modelGradient function.
This link explains how to create custom layers and loss functions in matlab
Hope this helps!

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!