Why does my custom SAC agent behave differently from the built-in SAC agent?

I implemented a custom SAC agent, as I am required to, using MATLAB deep learning automatic differentiation. However, when compared with the MATLAB built-in SAC agent on the same task with exactly the same hyperparameters, the custom SAC agent failed to complete the task while the built-in agent succeeded.
Here is the training progress of the built-in agent:
This is the training progress of the custom SAC agent (along with the losses):
Here is the code for the custom SAC agent and training:
1. Implementation of the custom SAC agent
classdef MySACAgent < rl.agent.CustomAgent
properties
%networks
actor
critic1
critic2
critic_target1
critic_target2
log_alpha%entropy weight(log transformed)
%training options
options%Agent options
%optimizers
actorOptimizer
criticOptimizer_1
criticOptimizer_2
entWgtOptimizer
%experience buffers
obsBuffer
actionBuffer
rewardBuffer
nextObsBuffer
isDoneBuffer
rlExpBuffer
bufferIdx
bufferLen
%loss to record
cLoss
aLoss
eLoss
end
properties(Access = private)
Ts
counter
numObs
numAct
end
methods
%constructor
function obj = MySACAgent(numObs,numAct,obsInfo,actInfo,hid_dim,Ts,options)
% options' field:MaxBufferLen WarmUpSteps MiniBatchSize
% LearningFrequency EntropyLossWeight DiscountFactor
% OptimizerOptions(cell) PolicyUpdateFrequency TargetEntropy
% TargetUpdateFrequency TargetSmoothFactor
% base_seed NumGradientStepsPerUpdate
%OptimizerOptions(for actor&critic)
% (required) Call the abstract class constructor.
rng(options.base_seed);%set random seed
obj = obj@rl.agent.CustomAgent();
obj.ObservationInfo = obsInfo;
obj.ActionInfo = actInfo;
% obj.SampleTime = Ts;%explicitly assigned for simulink
obj.Ts = Ts;
%create networks
if isempty(hid_dim)
hid_dim = 256;
end
obj.actor = CreateActor(obj,numObs,numAct,hid_dim,obsInfo,actInfo);
[obj.critic1,obj.critic2,obj.critic_target1,obj.critic_target2] = CreateCritic(obj,numObs,numAct,hid_dim,obsInfo,actInfo);
obj.options = options;
assert(options.WarmUpSteps>options.MiniBatchSize,...
'options.WarmUpSteps must be greater than options.MiniBatchSize');
%set optimizers
obj.actorOptimizer = rlOptimizer(options.OptimizerOptions{1});
obj.criticOptimizer_1 = rlOptimizer(options.OptimizerOptions{2});
obj.criticOptimizer_2 = rlOptimizer(options.OptimizerOptions{3});
obj.entWgtOptimizer = rlOptimizer(options.OptimizerOptions{4});
obj.cLoss=0;
obj.aLoss=0;
obj.eLoss=0;
% (optional) Cache the number of observations and actions.
obj.numObs = numObs;
obj.numAct = numAct;
% (optional) Initialize buffer and counter.
resetImpl(obj);
% obj.rlExpBuffer = rlReplayMemory(obsInfo,actInfo,options.MaxBufferLen);
end
function resetImpl(obj)
% (Optional) Define how the agent is reset before training.
resetBuffer(obj);
obj.counter = 0;
obj.bufferLen=0;
obj.bufferIdx = 0;%base 0
obj.log_alpha = dlarray(log(obj.options.EntropyLossWeight));
end
function resetBuffer(obj)
% Reinitialize observation buffer. Allocate as dlarray to
% support automatic differentiation with dlfeval and
% dlgradient.
%format:CB
obj.obsBuffer = dlarray(...
zeros(obj.numObs,obj.options.MaxBufferLen),'CB');
% Reinitialize action buffer with valid actions.
obj.actionBuffer = dlarray(...
zeros(obj.numAct,obj.options.MaxBufferLen),'CB');
% Reinitialize reward buffer.
obj.rewardBuffer = dlarray(zeros(1,obj.options.MaxBufferLen),'CB');
% Reinitialize nextState buffer.
obj.nextObsBuffer = dlarray(...
zeros(obj.numObs,obj.options.MaxBufferLen),'CB');
% Reinitialize mask buffer.
obj.isDoneBuffer = dlarray(zeros(1,obj.options.MaxBufferLen),'CB');
end
%Create networks
%Actor
function actor = CreateActor(obj,numObs,numAct,hid_dim,obsInfo,actInfo)
% Create the actor network layers.
commonPath = [
featureInputLayer(numObs,Name="obsInLyr")
fullyConnectedLayer(hid_dim)
layerNormalizationLayer
reluLayer
fullyConnectedLayer(hid_dim)
layerNormalizationLayer
reluLayer(Name="comPathOutLyr")
];
meanPath = [
fullyConnectedLayer(numAct,Name="meanOutLyr")
];
stdPath = [
fullyConnectedLayer(numAct,Name="stdInLyr")
softplusLayer(Name="stdOutLyr")
];
% Connect the layers.
actorNetwork = layerGraph(commonPath);
actorNetwork = addLayers(actorNetwork,meanPath);
actorNetwork = addLayers(actorNetwork,stdPath);
actorNetwork = connectLayers(actorNetwork,"comPathOutLyr","meanOutLyr/in");
actorNetwork = connectLayers(actorNetwork,"comPathOutLyr","stdInLyr/in");
actordlnet = dlnetwork(actorNetwork);
actor = initialize(actordlnet);
end
%Critic
function [critic1,critic2,critic_target1,critic_target2] = CreateCritic(obj,numObs,numAct,hid_dim,obsInfo,actInfo)
% Define the network layers.
criticNet = [
featureInputLayer(numObs+numAct,Name="obsInLyr")%input:[obs act]
fullyConnectedLayer(hid_dim)
layerNormalizationLayer
reluLayer
fullyConnectedLayer(hid_dim)
layerNormalizationLayer
reluLayer
fullyConnectedLayer(1,Name="QValueOutLyr")
];
% Connect the layers.
criticNet = layerGraph(criticNet);
criticDLnet = dlnetwork(criticNet,'Initialize',false);
critic1 = initialize(criticDLnet);
critic2 = initialize(criticDLnet);%c1 and c2 use different initializations
critic_target1 = initialize(criticDLnet);
critic_target1.Learnables = critic1.Learnables;
critic_target1.State = critic1.State;
critic_target2 = initialize(criticDLnet);
critic_target2.Learnables = critic2.Learnables;
critic_target2.State = critic2.State;
end
function logP = logProbBoundedAction(obj,boundedAction,mu,sigma)
%used to calculate log probability for tanh(gaussian)
%validated, nothing wrong with this function
eps=1e-10;
logP = sum(log(1/sqrt(2*pi)./sigma.*exp(-0.5*(0.5*...
log((1+boundedAction+eps)./(1-boundedAction+eps))-mu).^2./sigma.^2).*1./(1-boundedAction.^2+eps)),1);
end
%loss functions
function [vLoss_1, vLoss_2, criticGrad_1, criticGrad_2] = criticLoss(obj,batchExperiences,c1,c2)
batchObs = batchExperiences{1};
batchAction = batchExperiences{2};
batchReward = batchExperiences{3};
batchNextObs = batchExperiences{4};
batchIsDone = batchExperiences{5};
batchSize = size(batchObs,2);
gamma = obj.options.DiscountFactor;
y = dlarray(zeros(1,batchSize));%CB(C=1)
y = y + batchReward;
actionNext = getActionWithExploration_dlarray(obj,batchNextObs);%CB
actionNext = actionNext{1};
Qt1=predict(obj.critic_target1,cat(1,batchNextObs,actionNext));%CB(C=1)
Qt2=predict(obj.critic_target2,cat(1,batchNextObs,actionNext));%CB(C=1)
[mu,sigma] = predict(obj.actor,batchNextObs);%CB:numAct*batch
next_action = tanh(mu + sigma.*randn(size(sigma)));
logP = logProbBoundedAction(obj,next_action,mu,sigma);
y = y + (1 - batchIsDone).*(gamma*(min(cat(1,Qt1,Qt2),[],1) - exp(obj.log_alpha)*logP));
critic_input = cat(1,batchObs,batchAction);
Q1 = forward(c1,critic_input);
Q2 = forward(c2,critic_input);
vLoss_1 = 1/2*mean((y - Q1).^2,'all');
vLoss_2 = 1/2*mean((y - Q2).^2,'all');
criticGrad_1 = dlgradient(vLoss_1,c1.Learnables);
criticGrad_2 = dlgradient(vLoss_2,c2.Learnables);
end
function [aLoss,actorGrad] = actorLoss(obj,batchExperiences,actor)
batchObs = batchExperiences{1};
batchSize = size(batchObs,2);
[mu,sigma] = forward(actor,batchObs);%CB:numAct*batch
curr_action = tanh(mu + sigma.*randn(size(sigma)));%reparameterization
critic_input = cat(1,batchObs,curr_action);
Q1=forward(obj.critic1,critic_input);%CB(C=1)
Q2=forward(obj.critic2,critic_input);%CB(C=1)
logP = logProbBoundedAction(obj,curr_action,mu,sigma);
aLoss = mean(-min(cat(1,Q1,Q2),[],1) + exp(obj.log_alpha) * logP,'all');
actorGrad= dlgradient(aLoss,actor.Learnables);
end
function [eLoss,entGrad] = entropyLoss(obj,batchExperiences,logAlpha)
batchObs = batchExperiences{1};
[mu,sigma] = predict(obj.actor,batchObs);%CB:numAct*batch
curr_action = tanh(mu + sigma.*randn(size(sigma)));
ent = mean(-logProbBoundedAction(obj,curr_action,mu,sigma));
eLoss = exp(logAlpha) * (ent - obj.options.TargetEntropy);
entGrad = dlgradient(eLoss,logAlpha);
end
end
methods(Access=protected)
%return SampleTime
function ts = getSampleTime_(obj)
ts = obj.Ts;
end
%get action without exploration
function action = getActionImpl(obj,obs)
%obs:dlarray CB
if ~isa(obs,'dlarray')
if isa(obs,'cell')
obs = dlarray(obs{1},'CB');
else
obs = dlarray(obs,'CB');
end
end
[mu,~] = predict(obj.actor,obs);
mu = extractdata(mu);
action = {tanh(mu)};
end
%get action with exploration
function action = getActionWithExplorationImpl(obj,obs)
%obs:dlarray CB
if ~isa(obs,'dlarray') || size(obs,1)~=obj.numObs
obs = dlarray(randn(obj.numObs,1),'CB');
end
[mu,sigma] = predict(obj.actor,obs);
mu = extractdata(mu);
sigma = extractdata(sigma);
action = {tanh(mu + sigma .* randn(size(sigma)))};
end
function action = getActionWithExploration_dlarray(obj,obs)
[mu,sigma] = predict(obj.actor,obs);
action = {tanh(mu + sigma .* randn(size(sigma)))};
end
%learning
function action = learnImpl(obj,Experience)
% Extract data from experience.
obs = Experience{1};
action = Experience{2};
reward = Experience{3};
nextObs = Experience{4};
isDone = logical(Experience{5});
obj.obsBuffer(:,obj.bufferIdx+1,:) = obs{1};
obj.actionBuffer(:,obj.bufferIdx+1,:) = action{1};
obj.rewardBuffer(:,obj.bufferIdx+1) = reward;
obj.nextObsBuffer(:,obj.bufferIdx+1,:) = nextObs{1};
obj.isDoneBuffer(:,obj.bufferIdx+1) = isDone;
obj.bufferLen = max(obj.bufferLen,obj.bufferIdx+1);
obj.bufferIdx = mod(obj.bufferIdx+1,obj.options.MaxBufferLen);
if obj.bufferLen>=max(obj.options.WarmUpSteps,obj.options.MiniBatchSize)
obj.counter = obj.counter + 1;
if (obj.options.LearningFrequency==-1 && isDone) || ...
(obj.options.LearningFrequency>0 && mod(obj.counter,obj.options.LearningFrequency)==0)
for gstep = 1:obj.options.NumGradientStepsPerUpdate
%sample batch
batchSize = obj.options.MiniBatchSize;
batchInd = randperm(obj.bufferLen,batchSize);
batchExperience = {
obj.obsBuffer(:,batchInd,:),...
obj.actionBuffer(:,batchInd,:),...
obj.rewardBuffer(:,batchInd),...
obj.nextObsBuffer(:,batchInd,:),...
obj.isDoneBuffer(:,batchInd)
};
%update the parameters of each critic
[cLoss1,cLoss2,criticGrad_1,criticGrad_2] = dlfeval(@(x,c1,c2)obj.criticLoss(x,c1,c2),batchExperience,obj.critic1,obj.critic2);
obj.cLoss = min(extractdata(cLoss1),extractdata(cLoss2));
[obj.critic1.Learnables.Value,obj.criticOptimizer_1] = update(obj.criticOptimizer_1,obj.critic1.Learnables.Value,criticGrad_1.Value);
[obj.critic2.Learnables.Value,obj.criticOptimizer_2] = update(obj.criticOptimizer_2,obj.critic2.Learnables.Value,criticGrad_2.Value);
if (mod(obj.counter,obj.options.PolicyUpdateFrequency)==0 && obj.options.LearningFrequency==-1) ||...
(mod(obj.counter,obj.options.LearningFrequency * obj.options.PolicyUpdateFrequency)==0 ...
&& obj.options.LearningFrequency>0)
%update the parameters of actor
[aloss,actorGrad] = dlfeval(...
@(x,actor)obj.actorLoss(x,actor),...
batchExperience,obj.actor);
obj.aLoss = extractdata(aloss);
[obj.actor.Learnables.Value,obj.actorOptimizer] = update(obj.actorOptimizer,obj.actor.Learnables.Value,actorGrad.Value);
%update the entropy weight
[eloss,entGrad] = dlfeval(@(x,alpha)obj.entropyLoss(x,alpha),batchExperience,obj.log_alpha);
obj.eLoss = extractdata(eloss);
% disp(obj.alpha)
[obj.log_alpha,obj.entWgtOptimizer] = update(obj.entWgtOptimizer,{obj.log_alpha},{entGrad});
obj.log_alpha = obj.log_alpha{1};
end
%update critic targets
%1
critic1_params = obj.critic1.Learnables.Value;%cell array network params
critic_target1_params = obj.critic_target1.Learnables.Value;
for i=1:size(critic1_params,1)
obj.critic_target1.Learnables.Value{i} = obj.options.TargetSmoothFactor * critic1_params{i}...
+ (1 - obj.options.TargetSmoothFactor) * critic_target1_params{i};
end
%2
critic2_params = obj.critic2.Learnables.Value;%cell array network params
critic_target2_params = obj.critic_target2.Learnables.Value;
for i=1:size(critic2_params,1)
obj.critic_target2.Learnables.Value{i} = obj.options.TargetSmoothFactor * critic2_params{i}...
+ (1 - obj.options.TargetSmoothFactor) * critic_target2_params{i};
end
% end
end
end
end
action = getActionWithExplorationImpl(obj,nextObs{1});
end
end
end
2. Configuration of the 'options' property (same as those used for the built-in SAC agent)
options.MaxBufferLen = 1e4;
options.WarmUpSteps = 1000;
options.MiniBatchSize = 256;
options.LearningFrequency = -1;%when -1: train after each episode
options.EntropyLossWeight = 1;
options.DiscountFactor = 0.99;
options.PolicyUpdateFrequency = 1;
options.TargetEntropy = -2;
options.TargetUpdateFrequency = 1;
options.TargetSmoothFactor = 1e-3;
options.NumGradientStepsPerUpdate = 10;
%optimizerOptions: actor critic1 critic2 entWgt(alpha)
%encoder decoder
options.OptimizerOptions = {
rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
rlOptimizerOptions("Algorithm","adam",'LearnRate',3e-4),...
rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3)};
options.base_seed=940;
3. Training
clc;
clear;
close all;
run('init_car_params.m');
%create RL env
numObs = 4; % vx vy r beta_user
numAct = 2; % st_angle_ref rw_omega_ref
obsInfo = rlNumericSpec([numObs 1]);
actInfo = rlNumericSpec([numAct 1]);
actInfo.LowerLimit = -1;
actInfo.UpperLimit = 1;
mdl = "prius_sm_model";
blk = mdl + "/RL Agent";
env = rlSimulinkEnv(mdl,blk,obsInfo,actInfo);
params=struct('rw_radius',rw_radius,'a',a,'b',b,'init_vx',init_vx,'init_yaw_rate',init_yaw_rate);
env.ResetFcn = @(in) PriusResetFcn(in,params,mdl);
Ts = 1/10;
Tf = 5;
%create actor
rnd_seed=940;
algorithm = 'MySAC';
switch algorithm
case 'SAC'
agent = createNetworks(rnd_seed,numObs,numAct,obsInfo,actInfo,Ts);
case 'MySAC'
hid_dim = 256;
options=getDWMLAgentOptions();
agent = MySACAgent(numObs,numAct,obsInfo,actInfo,hid_dim,Ts,options);
end
%%
%train agent
close all
maxEpisodes = 6000;
maxSteps = floor(Tf/Ts);
useParallel = false;
run_idx=9;
saveAgentDir = ['savedAgents/',algorithm,'/',num2str(run_idx)];
switch algorithm
case 'SAC'
trainOpts = rlTrainingOptions(...
MaxEpisodes=maxEpisodes, ...
MaxStepsPerEpisode=maxSteps, ...
ScoreAveragingWindowLength=100, ...
Plots="training-progress", ...
StopTrainingCriteria="AverageReward", ...
UseParallel=useParallel,...
SaveAgentCriteria='EpisodeReward',...
SaveAgentValue=35,...
SaveAgentDirectory=saveAgentDir);
% SaveAgentCriteria='EpisodeFrequency',...
% SaveAgentValue=1,...
case 'MySAC'
trainOpts = rlTrainingOptions(...
MaxEpisodes=maxEpisodes, ...
MaxStepsPerEpisode=maxSteps, ...
ScoreAveragingWindowLength=100, ...
Plots="training-progress", ...
StopTrainingCriteria="AverageReward", ...
UseParallel=useParallel,...
SaveAgentCriteria='EpisodeReward',...
SaveAgentValue=35,...
SaveAgentDirectory=saveAgentDir);
end
set_param(mdl,"FastRestart","off");%for random initialization
if trainOpts.UseParallel
% Disable visualization in Simscape Mechanics Explorer
set_param(mdl, SimMechanicsOpenEditorOnUpdate="off");
save_system(mdl);
else
% Enable visualization in Simscape Mechanics Explorer
set_param(mdl, SimMechanicsOpenEditorOnUpdate="on");
end
%load training data
monitor = trainingProgressMonitor();
logger = rlDataLogger(monitor);
logger.EpisodeFinishedFcn = @myEpisodeLoggingFcn;
doTraining = true;
if doTraining
trainResult = train(agent,env,trainOpts,Logger=logger);
end
% %logger callback used for MySACAgent
function dataToLog = myEpisodeLoggingFcn(data)
dataToLog.criticLoss = data.Agent.cLoss;
dataToLog.actorLoss = data.Agent.aLoss;
dataToLog.entLoss = data.Agent.eLoss;
% dataToLog.denoiseLoss = data.Agent.dnLoss;
end
In the Simulink environment used, the action output by the Agent block (in [-1,1]) is denormalized and fed into the environment.
I think possible causes of the problem include:
1. Wrong implementation of the critic loss. As shown in the training progress, the critic loss seemed to diverge. This is unlikely to be caused by the hyperparameters (batch size, learning rate, or target update frequency) because they worked well for the built-in agent, so it is more likely that the critic loss itself is wrong.
2. Wrong implementation of the replay buffer. I implemented the replay buffer as a circular queue and sample from it uniformly to get batches of training data. From the comparison of the training progress shown above, the custom SAC agent did explore states with high reward (around 30) but failed to exploit them, so I suspect there is still a problem with my replay buffer.
3. Broken gradient flow. Learning relies on MATLAB deep learning automatic differentiation. Perhaps part of my implementation violates the rules of automatic differentiation, which broke the gradient flow during the forward computation or backpropagation and led to wrong results.
4. Gradient steps (update frequency). In the current implementation, NumGradientStepsPerUpdate gradient steps are executed after each episode. During each gradient step, the critics and the actor, along with the entropy weight, are each updated once. I am not sure whether the current implementation gets the update frequency right.
5. It could also be a normalization problem, but I am not so sure.
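On point 1, one detail visible in the posted criticLoss is that the target Q-values are evaluated at actionNext, while logP is computed for a separately sampled next_action. The SAC target is normally built from a single sample of the next action that is used for both terms. The sketch below illustrates this in plain MATLAB. The stand-in critics qT1 and qT2, the fixed mu and sigma, and alpha = 0.2 are assumptions for illustration only, not the networks from the post, and whether this difference explains the diverging loss is not established here.

```matlab
% Sketch of a SAC critic target that uses ONE next-action sample for both
% the target Q-value and the log-probability. All values are illustrative.
gamma = 0.99;                 % discount factor
alpha = 0.2;                  % entropy weight (assumed value)
batchSize = 4;
numAct = 2;

reward = [1 0 -1 2];          % 1-by-B
isDone = [0 0 1 0];           % 1-by-B, 1 marks a terminal transition

% Stand-in actor outputs at the next observation (assumed, not a network).
mu    = zeros(numAct, batchSize);
sigma = 0.5 * ones(numAct, batchSize);

u     = mu + sigma .* randn(numAct, batchSize);  % pre-tanh Gaussian sample
aNext = tanh(u);                                 % bounded next action

% Log-probability of the tanh-squashed Gaussian, computed from the pre-tanh
% sample u so that no inverse tanh of the bounded action is needed.
logP = sum(-0.5 * ((u - mu) ./ sigma).^2 - log(sigma) - 0.5 * log(2*pi) ...
           - log(1 - aNext.^2 + 1e-6), 1);       % 1-by-B

% Hypothetical target critics; both are evaluated at the SAME aNext.
qT1 = @(a) sum(a, 1);
qT2 = @(a) sum(a, 1) - 0.1;
qMin = min(qT1(aNext), qT2(aNext));              % 1-by-B

% Terminal transitions keep only the immediate reward.
y = reward + (1 - isDone) .* gamma .* (qMin - alpha * logP);
```

In a dlnetwork-based agent the same structure would apply with predict calls on the target critics; the point of the sketch is only that aNext and logP come from the same draw.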
I plan to debug 3 first.
Please read the code and help find potential causes of the gap between the custom SAC agent and the built-in one.
Finally, I am actually trying to extend the SAC algorithm to a more complex framework. I chose not to inherit from the built-in SAC agent (rlSACAgent); would it be recommended to do my development that way instead?
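For point 2 in the list of possible causes, the circular-queue bookkeeping can be checked in isolation. The following is a minimal sketch with plain arrays; maxLen, numObs and the stored values are assumptions chosen to make the wrap-around easy to see. It uses the same zero-based write index and randperm sampling as the posted learnImpl.

```matlab
% Minimal circular replay buffer check; sizes and contents are illustrative.
maxLen = 5;                       % buffer capacity
numObs = 2;
obsBuf = zeros(numObs, maxLen);
bufIdx = 0;                       % zero-based write position
bufLen = 0;                       % number of valid entries

for t = 1:7                       % write 7 transitions into 5 slots
    obsBuf(:, bufIdx + 1) = [t; -t];
    bufLen = max(bufLen, bufIdx + 1);
    bufIdx = mod(bufIdx + 1, maxLen);
end

% Uniform sampling without replacement from the valid part of the buffer.
batchInd = randperm(bufLen, 3);
batchObs = obsBuf(:, batchInd);
```

After the loop the two oldest entries have been overwritten, so the first row of obsBuf reads 6 7 3 4 5 and the next write goes to slot 3.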

Accepted Answer

Kaustab Pal on 29 Aug 2024
Upon reviewing your critic loss implementation, I'd like to offer some insights.
1. While the overall structure appears sound, there might be subtle dimension mismatches that could affect performance. It's crucial to ensure all operations are elementwise and that dimensions align perfectly across your tensors.
  • You can pay particular attention to the “batchIsDone” variable - verify it's being broadcasted correctly to match other tensors' dimensions.
  • Also consider using MATLAB's “bsxfun” or implicit expansion (MATLAB's broadcasting syntax) to guarantee proper dimension handling. These small details can significantly impact the stability and effectiveness of your learning process.
  • To further diagnose potential issues, I recommend using MATLAB's debugging tools to visualize tensor shapes at each step of the loss calculation. This can help pinpoint any unexpected dimension conflicts.
  • Additionally, consider adding “assert” statements to check tensor dimensions explicitly, which can catch issues early in the training process.
2. Your circular queue implementation seems reasonable. However, ensure that you're not overwriting experiences too quickly.
  • Consider increasing the buffer size or adjusting the sampling strategy. You might want to implement prioritized experience replay to focus on more important transitions.
3. Regarding Gradient Flow, ensure that all operations in the forward pass are differentiable functions.
4. In your current implementation, you are updating the networks after each episode. This might lead to instability. Consider updating after every N steps. Also ensure that the target networks are being updated correctly and at the right frequency.
5. Lastly, normalizing the rewards often leads to more stable training. You can try that out as well.
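The dimension checks suggested in point 1 can be sketched as follows. MATLAB applies implicit expansion to arithmetic on arrays of compatible sizes, so subtracting a B-by-1 array from a 1-by-B array silently produces a B-by-B result instead of raising an error. The variable names and the sameShape helper are illustrative assumptions, not part of the posted agent.

```matlab
% Illustration of a silent shape mismatch and an assert-based guard.
B = 4;
y = ones(1, B);        % TD target, 1-by-B
Q = ones(B, 1);        % mis-shaped critic output, B-by-1

d = y - Q;             % implicit expansion: d is B-by-B, no error is raised

% Hypothetical helper that compares array sizes.
sameShape = @(a, b) isequal(size(a), size(b));

% Guard to place before computing a loss; here it passes only after the fix.
Qfixed = Q.';          % 1-by-B
assert(sameShape(y, Qfixed), 'y and Q must have identical sizes');
loss = 0.5 * mean((y - Qfixed).^2);
```

A mean over the B-by-B array d would still return a number, which is why such a mismatch can go unnoticed without an explicit check.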
Regarding your final question about inheriting from “rlSACAgent”: If you're planning to extend the SAC algorithm significantly, creating a custom implementation as you've done gives you more flexibility. However, if your extensions are minor, inheriting from “rlSACAgent” could save you time and reduce the chances of implementation errors.
Please refer to the following official documentation to learn about these functions in more detail:
  1. rlSACAgent: https://www.mathworks.com/help/reinforcement-learning/ref/rl.agent.rlsacagent.html
  2. bsxfun: https://www.mathworks.com/help/matlab/ref/bsxfun.html
  3. assert: https://www.mathworks.com/help/matlab/ref/assert.html
I hope these suggestions resolve your query.
  1 Comment
一凡 on 2 Sep 2024
Hi Kaustab, thank you for your detailed suggestions, and sorry for my delayed reply.
I tried all your suggestions and managed to get the custom agent working.
I will describe how I debugged it in a separate comment.
Thanks again for your help!


More Answers (1)

一凡 on 2 Sep 2024
I managed to work it out. These are the main points to focus on:
1. The getActionWithExplorationImpl() function. This function is called by the built-in simulation during the warm-up stage of the SAC algorithm, which is why I generated a random Gaussian variable to explore. However, this function is also called after the warm-up stage, where the exploration action should be determined by feeding 'obs' into the actor, which I failed to do in the code shown in the original question.
2. Hyperparameters. As I mentioned in the original question, my update frequency might not match that of the built-in agent, so the hyperparameters should differ from those of the built-in agent, especially the learning rate, LearningFrequency, and TargetSmoothFactor. I have yet to find which of these factors determines the performance. I suggest tuning these factors to make the custom agent work.
3. @Kaustab Pal's answer is quite helpful; with it I found the influence of LearningFrequency. One could refer to that answer to debug one's own custom SAC agent.
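Point 1 can be illustrated with a small sketch. The posted getActionWithExplorationImpl replaces any input that is not a dlarray with random numbers, so the actor never sees the real observation. The version below instead normalizes the input and then samples from the policy. The handles muFcn and sigmaFcn and the weights Wmu are hypothetical stand-ins for the actor's two outputs; with a dlnetwork actor the conversion step would be dlarray(obs,'CB') followed by predict.

```matlab
% Sketch: always feed the real observation to the actor when exploring.
numObs = 4;
numAct = 2;

% Hypothetical stand-ins for the actor's mean and standard deviation outputs.
Wmu      = 0.1 * ones(numAct, numObs);
muFcn    = @(x) Wmu * x;
sigmaFcn = @(x) 0.1 + 0 * (Wmu * x);    % constant std with the mean's shape

% The training loop may hand over a cell or a plain numeric array.
obs = {[0.5; -0.2; 0.1; 0.0]};
if iscell(obs)
    obs = obs{1};
end
obs = reshape(double(obs), numObs, []); % numObs-by-B, never replaced by noise

% Sample a bounded exploratory action from the policy at this observation.
mu     = muFcn(obs);
sigma  = sigmaFcn(obs);
action = tanh(mu + sigma .* randn(size(sigma)));
```

The same conversion also covers the call at the end of learnImpl, which passes nextObs{1} rather than a dlarray.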
  1 Comment
一凡 on 2 Sep 2024
By the way, to speed up agent training, it might help to accelerate the computation of the gradient function, which can be done by following the official MATLAB documentation. I am working on this.
