LSTM loss decreasing but task fails

Joshua Diamond on 19 Mar 2025
Answered: Prathamesh on 29 May 2025
I am trying to reproduce the results in this paper: an LSTM that performs path integration. The loss decreases over time, but the task itself, path integration, still fails at test time. I would be grateful if one of you could take a look.
Thank you,
Josh
function resultsOut = rnnGridHd
dbstop if error
set(0,'DefaultFigureWindowStyle','docked')
% Clear workspace
% Generate synthetic data for path integration
% Assume velocity inputs (vx, vy) and corresponding position outputs (x, y)
sequenceTime = 2; % Seconds
dT = .02;
sequenceLength = sequenceTime / dT;
numSamples = 1000; % Number of samples
% sequenceLength = 100; % Length of each sequence
inputSize = 3; % v_t, sin(phi_t),cos(phi_t)
hiddenSize = 128;
linearSize = 512;
forceNew = true;
% load('resultsOut.mat','resultsOut');
if forceNew
nPlace = 256; %
nHd = 12;
sigmaPlace = .1;
kappaHd = 20;
enclosureBoundaries = 1.1; % Meters
% Define place cell centers
% Generate random angles and radii
t = 2 * pi * rand(nPlace,1);
r = (enclosureBoundaries * 1.1) * sqrt(rand(nPlace,1)); % Multiply by 1.1 to give us some outside of the perimeter, to prevent pulling inward
% Calculate x and y coordinates
x = r.*cos(t);
y = r.*sin(t);
placeCenter = [x'; y'];
hdCenter = normalizeToBounds(rand([1 nHd]),[-pi pi],[0 1]);
resultsOut.underlyingParams.placeCenter = placeCenter;
resultsOut.underlyingParams.hdCenter = hdCenter;
resultsOut.underlyingParams.enclosureBoundaries = enclosureBoundaries;
else
placeCenter = resultsOut.underlyingParams.placeCenter;
hdCenter = resultsOut.underlyingParams.hdCenter;
enclosureBoundaries = resultsOut.underlyingParams.enclosureBoundaries;
end
%% Create trajectories
if forceNew
[velocityInputs,positionOutputs,yPlaceCells,yHeadDirection] = deal(cell(1,numSamples));
for ii = 1:numSamples
angleV = nan(1,sequenceLength);
presentAngle = normalizeToBounds(rand,[-pi pi],[0 1]);
v = nan(1,sequenceLength);
actualXY = nan(2,sequenceLength);
t = 2 * pi * rand;
r = enclosureBoundaries * sqrt(rand);
x = r.*cos(t);
y = r.*sin(t);
presentPlace = [x; y];
% Place cell output
activityPlace = nan(nPlace,sequenceLength);
activityHd = nan(nHd,sequenceLength);
for jj = 1:sequenceLength
% Angle
currentAngle = circ_vmrnd(presentAngle,5,1);
currentV = raylrnd(1) * dT;
% Are we close to the wall?
if enclosureBoundaries - norm(presentPlace) < .03
% Are we facing the wall?
presentBearing = atan2(presentPlace(2),presentPlace(1));
facingWall = abs(circ_dist(presentBearing,currentAngle)) < pi/2;
if facingWall
% Turn
if sign(circ_dist(presentBearing,currentAngle)) == 1; currentAngle = unwrap(currentAngle - pi);
else; currentAngle = unwrap(currentAngle + pi);
end
% Slow down
currentV = currentV / 4;
end
end
% Position
actualXY(:,jj) = presentPlace + [cos(currentAngle); sin(currentAngle)] * currentV;
% Define gaussian activity for place cell centers
activityPlace(:,jj) = vecnorm(actualXY(:,jj) - placeCenter) .^ 2 / (2 * sigmaPlace ^2);
% Head direction cell activity
activityHd(:,jj) = cos(currentAngle - hdCenter);
% Load
angleV(jj) = currentAngle - presentAngle; % Take the difference
v(jj) = currentV;
% Update
presentAngle = currentAngle;
presentPlace = actualXY(:,jj);
end
inputNoise = randn([inputSize sequenceLength]) * .01;
thisInput = [v; sin(angleV); cos(angleV)] + inputNoise;
velocityInputs{ii} = thisInput;
positionOutputs{ii} = actualXY;
%% Place cell output
% Softmax across place centers: each center's activity over the summed activity
activityPlace = exp(-activityPlace) ./ sum(exp(-activityPlace));
yPlaceCells{ii} = activityPlace;
%% HD output
activityHd = exp(kappaHd * activityHd) ./ sum(exp(kappaHd * activityHd));
yHeadDirection{ii} = activityHd;
end
resultsOut.groundTruth.yPlaceCells = yPlaceCells;
resultsOut.groundTruth.yHeadDirection = yHeadDirection;
resultsOut.groundTruth.velocityInputs = velocityInputs;
resultsOut.groundTruth.positionOutputs = positionOutputs;
else
yPlaceCells = resultsOut.groundTruth.yPlaceCells;
yHeadDirection = resultsOut.groundTruth.yHeadDirection;
velocityInputs = resultsOut.groundTruth.velocityInputs;
positionOutputs = resultsOut.groundTruth.positionOutputs;
end
%% Set up forked training procedure
forceNew = true;
if forceNew
nHd = size(resultsOut.underlyingParams.hdCenter,2);
nPlace = size(resultsOut.underlyingParams.placeCenter,2);
% Create an empty layer graph
thisLayerGraph = layerGraph();
% Add the common layers to the graph
thisLayerGraph = addLayers(thisLayerGraph, [
sequenceInputLayer(inputSize, 'Name', 'input')
lstmLayer(hiddenSize, 'Name', 'lstm')
fullyConnectedLayer(linearSize, 'Name', 'linear')
dropoutLayer(0.25, 'Name', 'dropout') % Changed to .25 from .5
]);
% Define TWO SEPARATE fully connected layers for outputs
% One for head direction with nHd outputs
headFC = fullyConnectedLayer(nHd, 'Name', 'fcHeadDirection');
headSoftmax = softmaxLayer('Name', 'softmaxHeadDirection');
% One for place cells with nPlace outputs
placeFC = fullyConnectedLayer(nPlace, 'Name', 'fcPlaceCells');
placeSoftmax = softmaxLayer('Name', 'softmaxPlaceCells');
% Add these layers to the graph
thisLayerGraph = addLayers(thisLayerGraph, [headFC, headSoftmax]);
thisLayerGraph = addLayers(thisLayerGraph, [placeFC, placeSoftmax]);
% Connect the branches to the dropout layer
thisLayerGraph = connectLayers(thisLayerGraph, 'dropout', 'fcHeadDirection');
thisLayerGraph = connectLayers(thisLayerGraph, 'dropout', 'fcPlaceCells');
% Create the network
neuralNet = dlnetwork(thisLayerGraph);
%% Define training parameters
numEpochs = 3000;
miniBatchSize = 10;
learningRate = 1 * 10^-4; % Adjust based on the paper if provided
rmsPropDecayRate = 0.9; % Typical value for RMSProp
weightDecayRate = 10^-4; % L2 regularization strength
epsilon = 1e-8; % For numerical stability
gradientClipValue = 10^-4;
% Initialize RMSProp state
squaredGrad = []; squaredGradInit = [];
% Initialize array to store loss values
[lossHistory,hdLossHistory,placeLossHistory] = deal(zeros(numEpochs, 1));
% Initialize
W = createStateInitializationWeights(hiddenSize, nPlace, nHd);
%% Train
% Training loop
for epoch = 1:numEpochs
% Shuffle data indices for this epoch
numObservations = numel(velocityInputs);
indices = randperm(numObservations);
% Initialize epoch loss
numMiniBatches = floor(numObservations / miniBatchSize);
% Mini-batch loop
for ii = 1:numMiniBatches
% Get indices for current mini-batch
batchIndices = indices((ii-1)*miniBatchSize+1:ii*miniBatchSize);
% Create mini-batch
batchX = velocityInputs(batchIndices);
batchYHeadDirection = yHeadDirection(batchIndices);
batchYPlaceCells = yPlaceCells(batchIndices);
% Process into dlarrays
% Assuming each sequence has 100 time steps as mentioned in the paper
dlX = dlarray(cat(3, batchX{:}), 'CTB'); % Channel, Time, Batch
dlYHeadDirection = dlarray(cat(3, batchYHeadDirection{:}), 'CTB'); % Channel, Time, Batch
dlYPlaceCells = dlarray(cat(3, batchYPlaceCells{:}), 'CTB'); % Channel, Time, Batch
% Evaluate model gradients and loss with BPTT
% [gradients, loss, headLoss, placeLoss] = handleChunking(neuralNet, dlX, dlYHeadDirection, dlYPlaceCells);
[gradients, loss, headLoss, placeLoss,...
wGrad] = dlfeval(@modelGradients,neuralNet, dlX, dlYHeadDirection, dlYPlaceCells,W);
% Apply gradient clipping to output projection layers
for jj = 1:height(gradients)
layerName = gradients.Layer{jj};
% Check if this parameter connects dropout to outputs
if contains(layerName, 'fcHeadDirection') || contains(layerName, 'fcPlaceCells')
% Apply gradient clipping to just these parameters
gradValues = extractdata(gradients.Value{jj});
clippedGrads = min(max(gradValues, -gradientClipValue), gradientClipValue);
gradients.Value{jj} = dlarray(clippedGrads);
end
end
% Update network parameters using RMSProp
[neuralNet, squaredGrad,squaredGradInit,W] = customRMSProp(neuralNet, gradients, squaredGrad, squaredGradInit,...
(epoch-1)*numMiniBatches+ii, learningRate, rmsPropDecayRate, epsilon, weightDecayRate,W,wGrad);
% Accumulate loss for epoch
lossHistory(epoch) = lossHistory(epoch) + double(gather(extractdata(loss))) / numMiniBatches;
% epochLoss = epochLoss + loss; % It's already a double here
hdLossHistory(epoch) = hdLossHistory(epoch) + double(gather(extractdata(headLoss))) / numMiniBatches;
placeLossHistory(epoch) = placeLossHistory(epoch) + double(gather(extractdata(placeLoss))) / numMiniBatches;
end
% Display progress
fprintf('Epoch %d/%d: Loss = %.4f. ', epoch, numEpochs, lossHistory(epoch));
fprintf('Loss = %.4f for HD, %.4f for place.\n', hdLossHistory(epoch),placeLossHistory(epoch));
end
% Plot training progress
figure
plot(1:numEpochs, lossHistory)
xlabel('Epoch')
ylabel('Loss')
title('Training Loss')
resultsOut.neuralNet = neuralNet;
resultsOut.W = W;
else
neuralNet = resultsOut.neuralNet;
end
%% Test a prediction round
ii = 1;
activityPlace = yPlaceCells{ii};
actualXY = positionOutputs{ii};
% From place cell ground truth
[xFromPlaceGroundTruth, yFromPlaceGroundTruth] = generateTrajectoryFromPlaceCells(activityPlace, placeCenter');
% From place cell activations
dlX = dlarray(velocityInputs{ii}, 'CT');
placeCellPredictions = extractdata(predict(neuralNet, dlX, 'Outputs', 'softmaxPlaceCells'));
[xFromPlaceDecoded, yFromPlaceDecoded] = generateTrajectoryFromPlaceCells(placeCellPredictions, placeCenter');
figure
subplot(2,1,1);
% Actual
plot(actualXY(1,:),actualXY(2,:),'k','LineWidth',2); hold on
circle(0,0,enclosureBoundaries,[-pi pi],true);
axis equal
plot(xFromPlaceGroundTruth,yFromPlaceGroundTruth,'b:');
% From place cell activations
plot(xFromPlaceDecoded,yFromPlaceDecoded,'r:');
%% Test head direction prediction
activityHd = yHeadDirection{ii};
actualHeadingAngles = atan2(velocityInputs{ii}(2,:),velocityInputs{ii}(3,:));
angleFromHdGroundTruth = decodeHeadDirection(activityHd, hdCenter');
% Decoded
dlX = dlarray(velocityInputs{ii}, 'CT');
hdPredictions = extractdata(predict(neuralNet, dlX, 'Outputs', 'softmaxHeadDirection'));
angleFromHdDecoded = decodeHeadDirection(hdPredictions, hdCenter');
% Visualization
subplot(2,1,2);
% Plot decoded vs. actual angles
plot(unwrap(angleFromHdGroundTruth), 'k', 'LineWidth', 2); hold on
plot(unwrap(actualHeadingAngles), 'b:', 'LineWidth', 1.5);
plot(unwrap(angleFromHdDecoded),'r:','LineWidth',1.5);
legend('Decoded (ground-truth HD)', 'Actual Heading', 'Decoded (network)');
xlabel('Time Step');
ylabel('Heading Angle (rad)');
title('Decoded vs. Actual Heading Direction');
%%
forceNew = true;
if forceNew
% Define spatial bins for the firing rate map
numBins = 50; % Number of bins for the 2D map
maxVal = enclosureBoundaries;
% max(cat(1,positionOutputs{:}),[],'all');
minVal = -enclosureBoundaries;
xEdges = linspace(minVal, maxVal, numBins+1);
yEdges = linspace(minVal, maxVal, numBins+1);
thEdges = -pi:pi/12:pi - pi / 12;
% Initialize firing rate maps for all hidden units
firingRatesMap = zeros(numBins, numBins, linearSize);
firingRatesHd = zeros(length(thEdges),linearSize);
% Compute firing rate maps
parfor unit = 1:linearSize
firingRateMap = zeros(numBins, numBins);
occupancyMap = zeros(numBins, numBins);
firingRateHd = zeros(size(thEdges));
occupancyHd = zeros(size(thEdges));
for ii = 1:length(positionOutputs)
dlInput = dlarray(velocityInputs{ii}, 'CT');
layerActivations = predict(neuralNet, dlInput, 'Outputs', 'linear');
% Reshape hidden activations for analysis
layerActivations = extractdata(layerActivations);
layerActivations = layerActivations'; % Transpose to [sequenceLength x linearSize]
assert(size(layerActivations,2) == linearSize);
% Get the agent's path (x, y positions)
agentPath = positionOutputs{ii}; % Path for this sample
x = agentPath(1, :);
y = agentPath(2, :);
agentHd = atan2(velocityInputs{ii}(2,:),velocityInputs{ii}(3,:));
for t = 1:sequenceLength
% Find the bin for the current position
if abs(x(t)) > enclosureBoundaries || abs(y(t)) > enclosureBoundaries; continue; end
xBin = find(x(t) >= xEdges, 1, 'last');
yBin = find(y(t) >= yEdges, 1, 'last');
% Accumulate activations and occupancy
firingRateMap(xBin, yBin) = firingRateMap(xBin, yBin) + layerActivations(t, unit);
occupancyMap(xBin, yBin) = occupancyMap(xBin, yBin) + 1;
thBin = find(agentHd(t) >= thEdges, 1, 'last');
firingRateHd(thBin) = firingRateHd(thBin) + layerActivations(t, unit);
occupancyHd(thBin) = occupancyHd(thBin) + 1;
end
end
% Normalize by occupancy to get the firing rate
firingRateMap = firingRateMap ./ occupancyMap;
firingRateMap(isnan(firingRateMap)) = 0; % Handle NaNs (empty bins)
firingRatesMap(:, :, unit) = firingRateMap;
firingRateHd = firingRateHd ./ occupancyHd;
firingRateHd(isnan(firingRateHd)) = 0;
firingRatesHd(:, unit) = firingRateHd;
end
resultsOut.firingRatesMap = firingRatesMap;
resultsOut.firingRatesHd = firingRatesHd;
else
firingRatesMap = resultsOut.firingRatesMap;
firingRatesHd = resultsOut.firingRatesHd;
end
%% Plot firing rate maps and their autocorrelations
figure
for unit = 1:size(firingRatesMap,3)
clf
% Firing rate map
subplot(2,3,1)
A = squeeze(firingRatesMap(:, :, unit));
imagesc(A);
colorbar;
title(['Firing Rate Map for Hidden Unit ' num2str(unit)]);
xlabel('X Position');
ylabel('Y Position');
% Spatial autocorrelation
subplot(2,3,4)
B = xcorr2(firingRatesMap(:, :, unit)); % Compute 2D autocorrelation
imagesc(B);
colorbar;
title(['Autocorrelation of Firing Rate Map for Hidden Unit ' num2str(unit)]);
xlabel('X');
ylabel('Y');
% pause
%% Hex symmetry
% Moser Nature 2006
stepSize = 6; % Degrees
% Sargolini, Moser Science 2006
degStep = 0:stepSize:360 - stepSize;
corrResults = nan(size(degStep));
for ii = 1:length(corrResults)
Br = imrotate(B,degStep(ii),'crop');
corrResults(ii) = corr2(B,Br);
end
%
subplot(2,3,2);
plot(degStep,corrResults)
xticks(0:60:360)
xline(xticks);
subplot(2,3,5);
% Assume 360 degrees is 1 second.
% How many samples per second? Rather, samples per cycle.
dT = length(corrResults);
X = fft(corrResults);
N = length(corrResults);
f = (0:N-1)*(dT/N); % Frequency vector
plot(f(1:N/2), abs(X(1:N/2))*2/N); % Plot single-sided spectrum
xline(6)
[~,xInd] = min(abs(f - 6));
% if abs(X(xInd)) > 2; keyboard; end
subplot(2,3,3);
A = firingRatesHd(:,unit);
polarplot(thEdges + pi / 24,A);
end
end
% Define the model gradients function
function [gradients, totalLoss, headDirectionLoss, placeCellLoss, wGrad] = modelGradients(net, X, YHeadDirection, YPlaceCells,W)
%% Initial states
batchDim = X.finddim('B');
timeDim = X.finddim('T');
miniBatchSize = size(X,batchDim); % Batch
% Initial states
initialPlaceActivities = zeros(size(YPlaceCells,1), miniBatchSize);
initialHdActivities = zeros(size(YHeadDirection,1), miniBatchSize);
assert(isequal(batchDim,2)); % Accessing dim 2 below, better be batch
assert(isequal(timeDim,3));
% Extract initial conditions for each sequence in the batch
for jj = 1:miniBatchSize
initialPlaceActivities(:,jj) = YPlaceCells(:,jj,1);
initialHdActivities(:,jj) = YHeadDirection(:,jj,1);
end
dlInitialPlace = dlarray(initialPlaceActivities, 'CB'); % Class, Batch
dlInitialHd = dlarray(initialHdActivities, 'CB'); % Class, Batch
[dlInitialHidden, dlInitialCell] = initializeLSTMStatesBatched(dlInitialPlace, dlInitialHd, W);
%%
% Forward pass
[YPredHeadDirection, YPredPlaceCells] = forwardWithInitialState(net, X,dlInitialHidden,dlInitialCell);
numTimeSteps = size(YHeadDirection, 3);
headDirectionLoss = crossentropy(YPredHeadDirection,YHeadDirection) / numTimeSteps;
placeCellLoss = crossentropy(YPredPlaceCells,YPlaceCells) / numTimeSteps;
% Combine losses (you can adjust weights if needed)
headWeight = 0.05;
placeWeight = 0.95;
totalLoss = headWeight * headDirectionLoss + placeWeight * placeCellLoss;
% Calculate gradients of loss with respect to all learnable parameters
gradients = dlgradient(totalLoss, net.Learnables);
wGrad.W_cp_grad = dlgradient(totalLoss, W.dlW_cp);
wGrad.W_cd_grad = dlgradient(totalLoss, W.dlW_cd);
wGrad.W_hp_grad = dlgradient(totalLoss, W.dlW_hp);
wGrad.W_hd_grad = dlgradient(totalLoss, W.dlW_hd);
end
function [totalGradients, totalLoss, totalHeadLoss, totalPlaceLoss] = handleChunking(net, X, YHeadDirection, YPlaceCells, sequenceLength)
sequenceLength = 100;
% Get total sequence length
totalTimeSteps = size(X, 3);
numChunks = ceil(totalTimeSteps / sequenceLength);
% Initialize variables
totalGradients = [];
totalLoss = 0;
totalHeadLoss = 0;
totalPlaceLoss = 0;
totalProcessedSteps = 0;
% Process each chunk
for chunk = 1:numChunks
% Get the time steps for this chunk
startIdx = (chunk-1)*sequenceLength + 1;
endIdx = min(chunk*sequenceLength, totalTimeSteps);
chunkLength = endIdx - startIdx + 1;
% Track total steps processed
totalProcessedSteps = totalProcessedSteps + chunkLength;
% Extract data for this chunk
chunkX = X(:, :, startIdx:endIdx);
chunkYHead = YHeadDirection(:, :, startIdx:endIdx);
chunkYPlace = YPlaceCells(:, :, startIdx:endIdx);
% Call modelGradients
[chunkGradients, chunkLoss, chunkHeadLoss, chunkPlaceLoss] = dlfeval(@modelGradients, net, chunkX, chunkYHead, chunkYPlace);
% Weight the gradients by the chunk length
if numChunks > 1
for ii = 1:height(chunkGradients)
chunkGradients.Value{ii} = chunkGradients.Value{ii} * (chunkLength / totalTimeSteps);
end
end
% Weight the losses by the chunk length
weightedLoss = chunkLoss * (chunkLength / totalTimeSteps);
weightedHeadLoss = chunkHeadLoss * (chunkLength / totalTimeSteps);
weightedPlaceLoss = chunkPlaceLoss * (chunkLength / totalTimeSteps);
% Accumulate weighted values
totalLoss = totalLoss + double(gather(extractdata(weightedLoss)));
totalHeadLoss = totalHeadLoss + double(gather(extractdata(weightedHeadLoss)));
totalPlaceLoss = totalPlaceLoss + double(gather(extractdata(weightedPlaceLoss)));
% Initialize or accumulate gradients
if isempty(totalGradients)
totalGradients = chunkGradients;
else
% Add the weighted gradients
for ii = 1:height(totalGradients)
totalGradients.Value{ii} = totalGradients.Value{ii} + chunkGradients.Value{ii};
end
end
end
end
% The forward function needs to be modified to use the initial states
function [YPredHeadDirection, YPredPlaceCells] = forwardWithInitialState(net, X, initialHidden, initialCell)
% This implementation depends on how your specific LSTM is structured
% Here's a general approach that would need to be adapted to your network
batchDim = X.finddim('B');
batchSize = size(X,batchDim);
netState = net.State;
hiddenInd = strcmpi(netState.Parameter,"HiddenState");
cellInd = strcmpi(netState.Parameter,"CellState");
% Update the copy
for ii = 1:batchSize
netState.Value{hiddenInd}(:,ii) = initialHidden(:,ii);
netState.Value{cellInd}(:,ii) = initialCell(:,ii);
end
% Update the network with the modified state
net.State = netState;
% Run forward pass with the modified state
[YPredHeadDirection, YPredPlaceCells] = predict(net, X);
end
function [neuralNet, squaredGrad,squaredGradInit,W] = customRMSProp(neuralNet, gradients, squaredGrad, squaredGradInit, iteration, ...
learningRate, rmsPropDecayRate, epsilon, weightDecayRate, W, wGrad)
% neuralNet: The neural network to update
% gradients: Table of gradients from dlgradient
% squaredGrad: Moving average of squared gradients (initialized as empty [] on first call)
% iteration: Current training iteration
% learningRate: Learning rate for parameter updates
% decayRate: Decay rate for moving average of squared gradients (typically 0.9-0.99)
% epsilon: Small constant for numerical stability
% weightDecayRate: Weight decay rate for L2 regularization (only applied to specific layers)
% if mod(iteration, 1000) == 0 && iteration > 0
% learningRate = learningRate * 0.5;
% end
% Initialize squared gradient accumulation if first iteration
if isempty(squaredGrad)
squaredGrad = cell(height(gradients), 1);
for ii = 1:height(gradients)
squaredGrad{ii} = zeros(size(extractdata(gradients.Value{ii})), 'like', extractdata(gradients.Value{ii}));
end
end
% Update each parameter
for ii = 1:height(gradients)
% Extract parameter name and value
layerName = gradients.Layer{ii};
paramName = gradients.Parameter{ii};
paramValue = extractdata(neuralNet.Learnables.Value{ii});
gradValue = extractdata(gradients.Value{ii});
% Apply weight decay only to weights projecting from dropout to output layers
if contains(layerName, 'fcHeadDirection') || contains(layerName, 'fcPlaceCells')
if strcmp(paramName, 'Weights') % Only apply to weights, not biases
% Apply weight decay (L2 regularization) by adding to gradient
gradValue = gradValue + weightDecayRate * paramValue;
end
end
% Update moving average of squared gradient
squaredGrad{ii} = rmsPropDecayRate * squaredGrad{ii} + (1 - rmsPropDecayRate) * (gradValue .^ 2);
% Compute update
update = -learningRate * gradValue ./ (sqrt(squaredGrad{ii}) + epsilon);
% Apply update
newValue = paramValue + update;
% Update parameter
neuralNet.Learnables.Value{ii} = dlarray(newValue);
end
%% Initialization
% Now handle the initialization parameters
if isempty(squaredGradInit)
squaredGradInit = struct();
squaredGradInit.W_cp = zeros(size(extractdata(W.dlW_cp)), 'like', extractdata(W.dlW_cp));
squaredGradInit.W_cd = zeros(size(extractdata(W.dlW_cd)), 'like', extractdata(W.dlW_cd));
squaredGradInit.W_hp = zeros(size(extractdata(W.dlW_hp)), 'like', extractdata(W.dlW_hp));
squaredGradInit.W_hd = zeros(size(extractdata(W.dlW_hd)), 'like', extractdata(W.dlW_hd));
end
gradValue = extractdata(wGrad.W_cp_grad);
squaredGradInit.W_cp = rmsPropDecayRate * squaredGradInit.W_cp + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_cp) + epsilon);
W.dlW_cp = W.dlW_cp + update;
% Update W.dlW_cd
gradValue = extractdata(wGrad.W_cd_grad);
squaredGradInit.W_cd = rmsPropDecayRate * squaredGradInit.W_cd + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_cd) + epsilon);
W.dlW_cd = W.dlW_cd + update;
% Update W.dlW_hp
gradValue = extractdata(wGrad.W_hp_grad);
squaredGradInit.W_hp = rmsPropDecayRate * squaredGradInit.W_hp + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_hp) + epsilon);
W.dlW_hp = W.dlW_hp + update;
% Update W.dlW_hd
gradValue = extractdata(wGrad.W_hd_grad);
squaredGradInit.W_hd = rmsPropDecayRate * squaredGradInit.W_hd + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_hd) + epsilon);
W.dlW_hd = W.dlW_hd + update;
end
function [x, y] = generateTrajectoryFromPlaceCells(placeCellActivations, placeFieldCenters)
% This function generates a 2D trajectory from place cell activations
%
% Inputs:
% placeCellActivations - Matrix of place cell activations (numCells × numTimeSteps)
% placeFieldCenters - 2D coordinates of place field centers (numCells × 2)
%
% Outputs:
% x, y - Estimated position coordinates for each time step
numTimeSteps = size(placeCellActivations, 2);
x = zeros(1, numTimeSteps);
y = zeros(1, numTimeSteps);
distanceFromCenter = vecnorm(placeFieldCenters, 2, 2);
distanceWeights = distanceFromCenter / max(distanceFromCenter);
% For each time step, compute the weighted average of place field centers
for t = 1:numTimeSteps
% Get activations at this time step
activations = placeCellActivations(:, t);
% Normalize activations to sum to 1 (if not already)
% Then when decoding
% activations = activations .* (1 + distanceWeights);
% activations = activations / sum(activations);
% Compute weighted average of place field centers
x(t) = sum(activations .* placeFieldCenters(:, 1));
y(t) = sum(activations .* placeFieldCenters(:, 2));
end
return
end
function decodedHeadingAngle = decodeHeadDirection(hdActivations, hdPreferredAngles)
% hdActivations: Matrix of head direction cell activations (numHDCells × numTimeSteps)
% hdPreferredAngles: Vector of preferred angles for each HD cell (numHDCells × 1)
numTimeSteps = size(hdActivations, 2);
decodedHeadingAngle = zeros(1, numTimeSteps);
for t = 1:numTimeSteps
activations = hdActivations(:, t);
% Normalize activations if not already done
% activations = activations / sum(activations);
% Convert preferred angles to unit vectors
x_components = cos(hdPreferredAngles);
y_components = sin(hdPreferredAngles);
% Compute weighted sum of unit vectors
x_sum = sum(activations .* x_components);
y_sum = sum(activations .* y_components);
% Convert back to angle
decodedHeadingAngle(t) = atan2(y_sum, x_sum);
end
end
%% Initialization
% function [W_cp, W_cd, W_hp, W_hd] = createStateInitializationWeights(hiddenSize, nPlace, nHd)
function W = createStateInitializationWeights(hiddenSize, nPlace, nHd)
% Creates the learnable weight matrices for LSTM state initialization
% Initialize using Glorot (Xavier) initialization
% This is a good default for neural network weights
% Cell state weights
W_cp = initializeGlorot(hiddenSize, nPlace);
W_cd = initializeGlorot(hiddenSize, nHd);
% Hidden state weights
W_hp = initializeGlorot(hiddenSize, nPlace);
W_hd = initializeGlorot(hiddenSize, nHd);
dlW_cp = dlarray(W_cp);
dlW_cd = dlarray(W_cd);
dlW_hp = dlarray(W_hp);
dlW_hd = dlarray(W_hd);
W.dlW_cp = dlW_cp;
W.dlW_cd = dlW_cd;
W.dlW_hp = dlW_hp;
W.dlW_hd = dlW_hd;
end
function W = initializeGlorot(outputSize, inputSize)
% Glorot/Xavier initialization
% Good for weights connecting to tanh activation functions (used in LSTM)
limit = sqrt(6 / (inputSize + outputSize));
W = 2 * limit * rand(outputSize, inputSize) - limit;
end
function [initialHiddenState, initialCellState] = initializeLSTMStatesBatched(placeCellActivity, headDirectionActivity, W)
% placeCellActivity: [nPlace x batchSize]
% headDirectionActivity: [nHd x batchSize]
% W_cp: [hiddenSize x nPlace]
% W_cd: [hiddenSize x nHd]
% W_hp: [hiddenSize x nPlace]
% W_hd: [hiddenSize x nHd]
W_cp = W.dlW_cp;
W_cd = W.dlW_cd;
W_hp = W.dlW_hp;
W_hd = W.dlW_hd;
placeCellActivity = extractdata(placeCellActivity);
headDirectionActivity = extractdata(headDirectionActivity);
% Calculate initial states for entire batch at once
initialCellState = W_cp * placeCellActivity + W_cd * headDirectionActivity;
initialHiddenState = W_hp * placeCellActivity + W_hd * headDirectionActivity;
initialCellState = dlarray(initialCellState,'CB');
initialHiddenState = dlarray(initialHiddenState,'CB');
% Results dimensions: [hiddenSize x batchSize]
end
function dataVals = normalizeToBounds(dataVals,endRange,beginRange)
wasFlipped = false;
if size(dataVals,2) > 1 && size(dataVals,1) == 1 % Row vector
dataVals = dataVals';
wasFlipped = true;
end
for ii = 1:size(dataVals,2)
dataCurrent = dataVals(:,ii);
if nargin <= 2 || all(isnan(beginRange))
beginRange = [min(dataCurrent) max(dataCurrent)];
end
if nargin <= 1
endRange = [0 1];
end
%% Put between 1 and 0
dataCurrent = max(beginRange(1),min(dataCurrent,beginRange(2),'includenan'),'includenan');
% Normalize to bounds
if range(beginRange) > 0
dataCurrent = (dataCurrent - beginRange(1)) / range(beginRange);
else
dataCurrent = dataCurrent - beginRange(1);
end
%% Put into end range
dataCurrent = dataCurrent * range(endRange) + endRange(1);
dataVals(:,ii) = dataCurrent;
end
if wasFlipped; dataVals = dataVals'; end
end
function [xunit,yunit] = circle(x,y,r,th,plotting)
hold on
if nargin <= 4; plotting = false; end
if nargin <= 3; th = [0 2 * pi]; end
th = linspace(th(1),th(2),100);
% th = 0 : pi / 50 : 2 * pi;
xunit = r * cos(th) + x;
yunit = r * sin(th) + y;
% fill(xunit,yunit,[1 0 0]);
if ~plotting; return ;end
plot(xunit,yunit,'k');
end

Answers (1)

Prathamesh on 29 May 2025
I understand that you have an LSTM which performs path integration. The loss decreases over time, but the task itself, path integration, still fails upon testing.
Below are some steps to debug the issue:
1. Adjust Hyperparameters (First Priority):
  • Increase gradientClipValue: Set it to 1 or 5.
  • Increase learningRate: Try 5e-4 or 1e-3.
  • Set Equal Loss Weights: headWeight = 0.5; placeWeight = 0.5;
  • Run the training and observe the loss curves and the test predictions (example values in the sketch below).
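For reference, a minimal sketch of the lines these suggestions touch in rnnGridHd (the values are starting points to experiment with, not prescribed by the paper; learningRate and gradientClipValue sit in the training-parameter block, headWeight and placeWeight inside modelGradients):
% Suggested starting points to experiment with
learningRate = 5e-4;       % was 1e-4
gradientClipValue = 1;     % was 1e-4 (a very aggressive clip)
headWeight = 0.5;          % was 0.05
placeWeight = 0.5;         % was 0.95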
2. Intermediate Predictions during Training:
  • Modify your training loop to periodically (e.g., every 100 epochs) run a predict call on a single sequence,
dlX = dlarray(velocityInputs{1}, 'CT');
and then extract the "softmaxPlaceCells" and "softmaxHeadDirection" outputs, as in the sketch below.
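A minimal sketch of such a check, placed at the end of the epoch loop and reusing the variables and helper functions already defined in rnnGridHd (velocityInputs, positionOutputs, placeCenter, generateTrajectoryFromPlaceCells):
% Periodic sanity check: decode a trajectory from the current network every 100 epochs
if mod(epoch, 100) == 0
    dlX = dlarray(velocityInputs{1}, 'CT');
    placePred = extractdata(predict(neuralNet, dlX, 'Outputs', 'softmaxPlaceCells'));
    [xDec, yDec] = generateTrajectoryFromPlaceCells(placePred, placeCenter');
    figure(99); clf
    plot(positionOutputs{1}(1,:), positionOutputs{1}(2,:), 'k', 'LineWidth', 2); hold on
    plot(xDec, yDec, 'r:');
    axis equal
    title(sprintf('Decoded vs. actual trajectory, epoch %d', epoch));
    drawnow
end
If the decoded path still collapses toward the centre of the enclosure long after the loss has plateaued, that points more towards the targets or decoding than towards insufficient training.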
3. Examine generated data:
  • Before training, plot a few of your positionOutputs and yPlaceCells (as heatmaps, or as scatter plots of centers colored by activation strength). Ensure the synthetic data generation looks consistent; see the sketch below.
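A minimal sketch of such a data check, using the variables already built in rnnGridHd (placeCenter is 2 x nPlace, yPlaceCells{ii} is nPlace x sequenceLength):
% Visual check of one generated sample before training
ii = 1; tStep = 1; % sample and time step to inspect
figure
subplot(1,2,1)
plot(positionOutputs{ii}(1,:), positionOutputs{ii}(2,:), 'k'); hold on
plot(positionOutputs{ii}(1,tStep), positionOutputs{ii}(2,tStep), 'ro')
axis equal
title('Sample trajectory (o = inspected step)')
subplot(1,2,2)
scatter(placeCenter(1,:), placeCenter(2,:), 30, yPlaceCells{ii}(:,tStep), 'filled')
axis equal; colorbar
title('Place cell targets at that step')
The bump of high target activity in the right panel should sit on top of the marked position in the left panel; if it does not, the issue is in the target construction rather than in the network.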
Hope this helps
