Creating an actorLossFunction for ContinuousDeterministicActor
Hi, in the example, the actor loss function for an rlDiscreteCategoricalActor is the following:
function loss = actorLossFunction(policy, lossData)
policy = policy{1};
% Create the action indication matrix.
batchSize = lossData.batchSize;
Z = repmat(lossData.actInfo.Elements',1,batchSize);
actionIndicationMatrix = lossData.actionBatch(:,:) == Z;
% Resize the discounted return to the size of policy.
G = actionIndicationMatrix .* lossData.discountedReturn;
G = reshape(G,size(policy));
% Round any policy values less than eps to eps.
policy(policy < eps) = eps;
% Compute the loss.
loss = -sum(G .* log(policy),'all');
end
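To make the masking step in this example loss concrete, here is a small sketch with made-up numbers. The variables `elements`, `actionBatch`, `discountedReturn`, and `policy` are hypothetical stand-ins for the corresponding `lossData` fields and the actor output; they are not taken from the original example, and only base MATLAB is used.

```matlab
% Toy stand-ins for the lossData fields (values are made up for illustration).
elements         = [-10 10];        % assumed discrete action set (actInfo.Elements)
batchSize        = 3;
actionBatch      = [10 -10 10];     % action taken at each step, 1-by-batchSize
discountedReturn = [1.0 0.5 2.0];   % discounted return at each step
policy           = [0.3 0.6 0.2;    % assumed probability of action -10 at each step
                    0.7 0.4 0.8];   % assumed probability of action +10 at each step

% One row per possible action, one column per step.
Z = repmat(elements',1,batchSize);
actionIndicationMatrix = actionBatch(:,:) == Z;   % true where that action was taken

% Keep each return only in the row of the action that was actually taken.
G = actionIndicationMatrix .* discountedReturn;

% Same clamping and loss as in the example function.
policy(policy < eps) = eps;
loss = -sum(G .* log(policy),'all');
```

In this sketch the mask selects exactly one probability per step, so the loss reduces to the return-weighted negative log-probability of the actions that were taken. This relies on the actor output being a probability for each element of a discrete action set.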
Here are my action and observation specifications:
actInfo =
rlNumericSpec with properties:
LowerLimit: [2×1 double]
UpperLimit: [2×1 double]
Name: "CartPole Action"
Description: [0×0 string]
Dimension: [2 1]
DataType: "double"
obsInfo =
rlNumericSpec with properties:
LowerLimit: -Inf
UpperLimit: Inf
Name: "CartPole States"
Description: "pendulum_force, cart position, cart velocity"
Dimension: [4 1501]
DataType: "double"
Here is how I set up my actor:
actor = rlContinuousDeterministicActor(actorNet,obsInfo,actInfo);
actor = accelerate(actor,true);
actorOpts = rlOptimizerOptions('LearnRate',1e-3);
actorOptimizer = rlOptimizer(actorOpts);
To create my loss function, can I do the following (replacing lossData.actInfo.Elements with lossData.actInfo.Dimension(1))?
function loss = actorLossFunction(policy, lossData)
policy = policy{1};
% Create the action indication matrix.
batchSize = lossData.batchSize;
Z = repmat(lossData.actInfo.Dimension(1)',1,batchSize);
actionIndicationMatrix = lossData.actionBatch(:,:) == Z;
% Resize the discounted return to the size of policy.
G = actionIndicationMatrix .* lossData.discountedReturn;
G = reshape(G,size(policy));
% Round any policy values less than eps to eps.
policy(policy < eps) = eps;
% Compute the loss.
loss = -sum(G .* log(policy),'all');
end
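As a rough check of what the modified lines compute, the sketch below runs them on made-up data shaped like a 2-by-1 continuous action. The values in `actionBatch` and `discountedReturn` are hypothetical, and this is only an illustration of the array operations in base MATLAB, not a statement of how a loss for rlContinuousDeterministicActor should be written.

```matlab
% Toy stand-ins (made-up values) for a 2-by-1 continuous action and 3 steps.
actionDimension  = [2 1];           % same shape as actInfo.Dimension shown above
batchSize        = 3;
actionBatch      = reshape([0.4 -1.2 3.1 0.7 -0.5 1.9],2,1,batchSize);
discountedReturn = [1.0 0.5 2.0];

% Dimension(1) is the scalar 2, so Z is a 1-by-3 row filled with the number 2.
Z = repmat(actionDimension(1)',1,batchSize);

% actionBatch(:,:) is 2-by-3; this compares the action VALUES with the number 2.
actionIndicationMatrix = actionBatch(:,:) == Z;

% With continuous actions the mask is almost never true, so G is all zeros.
G = actionIndicationMatrix .* discountedReturn;
```

With these numbers the mask is false everywhere, so `G` is zero and the resulting loss would be zero regardless of the actor output. Separately, the output of a deterministic actor is an action rather than a set of probabilities, so `log(policy)` no longer has the log-likelihood meaning it has in the discrete example, and any output below `eps` (including every negative action) is overwritten with `eps` before the logarithm.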