function resultsOut = rnnGridHd
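% RNNGRIDHD  Train an LSTM on simulated 2-D foraging trajectories to predict
% place-cell and head-direction-cell activity from velocity inputs, then map
% the spatial and directional tuning of the hidden (linear-layer) units.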
set(0,'DefaultFigureWindowStyle','docked')
sequenceLength = sequenceTime / dT;
enclosureBoundaries = 1.1;
% Scatter place-field centres uniformly over a disc slightly larger than the enclosure
t = 2 * pi * rand(nPlace,1);
r = (enclosureBoundaries * 1.1) * sqrt(rand(nPlace,1));
placeCenter = [r .* cos(t), r .* sin(t)]';   % 2 x nPlace, Cartesian coordinates of the field centres
% Head-direction preferred angles drawn uniformly from [-pi, pi]
hdCenter = normalizeToBounds(rand([1 nHd]),[-pi pi],[0 1]);
resultsOut.underlyingParams.placeCenter = placeCenter;
resultsOut.underlyingParams.hdCenter = hdCenter;
resultsOut.underlyingParams.enclosureBoundaries = enclosureBoundaries;
placeCenter = resultsOut.underlyingParams.placeCenter;
hdCenter = resultsOut.underlyingParams.hdCenter;
enclosureBoundaries = resultsOut.underlyingParams.enclosureBoundaries;
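% Simulate random foraging trajectories and the corresponding velocity inputs
% and place-cell / head-direction targets for each sample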
[velocityInputs,positionOutputs,yPlaceCells,yHeadDirection] = deal(cell(1,numSamples));
angleV = nan(1,sequenceLength);
presentAngle = normalizeToBounds(rand,[-pi pi],[0 1]);
v = nan(1,sequenceLength);
actualXY = nan(2,sequenceLength);
r = enclosureBoundaries * sqrt(rand);
startTheta = 2 * pi * rand;
presentPlace = r * [cos(startTheta); sin(startTheta)];   % random start position inside the arena
activityPlace = nan(nPlace,sequenceLength);
activityHd = nan(nHd,sequenceLength);
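% Step the agent forward: sample a heading (von Mises around the previous heading) and a Rayleigh speed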
for jj = 1:sequenceLength
currentAngle = circ_vmrnd(presentAngle,5,1);
currentV = raylrnd(1) * dT;
% Near the wall: if the agent is heading toward it, turn the heading around
if enclosureBoundaries - norm(presentPlace) < .03
presentBearing = atan2(presentPlace(2),presentPlace(1));
facingWall = abs(circ_dist(presentBearing,currentAngle)) < pi/2;
if facingWall && sign(circ_dist(presentBearing,currentAngle)) == 1; currentAngle = currentAngle - pi;
elseif facingWall; currentAngle = currentAngle + pi;
actualXY(:,jj) = presentPlace + [cos(currentAngle); sin(currentAngle)] * currentV;
activityPlace(:,jj) = vecnorm(actualXY(:,jj) - placeCenter) .^ 2 / (2 * sigmaPlace ^2);
activityHd(:,jj) = cos(currentAngle - hdCenter);
v(jj) = currentV;                          % linear speed for this step
angleV(jj) = currentAngle - presentAngle;  % angular velocity for this step
presentAngle = currentAngle;
presentPlace = actualXY(:,jj);
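% Network inputs: linear speed and sin/cos of angular velocity, plus small Gaussian noise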
inputNoise = randn([inputSize sequenceLength]) * .01;
thisInput = [v; sin(angleV); cos(angleV)] + inputNoise;
velocityInputs{ii} = thisInput;
positionOutputs{ii} = actualXY;
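% Normalise target activity across cells: Gaussian place fields and von Mises-like head-direction tuning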
activityPlace = exp(-activityPlace) ./ sum(exp(-activityPlace));
yPlaceCells{ii} = activityPlace;
activityHd = exp(kappaHd * activityHd) ./ sum(exp(kappaHd * activityHd));
yHeadDirection{ii} = activityHd;
resultsOut.groundTruth.yPlaceCells = yPlaceCells;
resultsOut.groundTruth.yHeadDirection = yHeadDirection;
resultsOut.groundTruth.velocityInputs = velocityInputs;
resultsOut.groundTruth.positionOutputs = positionOutputs;
yPlaceCells = resultsOut.groundTruth.yPlaceCells;
yHeadDirection = resultsOut.groundTruth.yHeadDirection;
velocityInputs = resultsOut.groundTruth.velocityInputs;
positionOutputs = resultsOut.groundTruth.positionOutputs;
nHd = size(resultsOut.underlyingParams.hdCenter,2);
nPlace = size(resultsOut.underlyingParams.placeCenter,2);
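% Network: velocity input -> LSTM -> linear layer -> dropout -> separate softmax read-outs
% for head direction and place cells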
thisLayerGraph = layerGraph();
thisLayerGraph = addLayers(thisLayerGraph, [
sequenceInputLayer(inputSize, 'Name', 'input')
lstmLayer(hiddenSize, 'Name', 'lstm')
fullyConnectedLayer(linearSize, 'Name', 'linear')
dropoutLayer(0.25, 'Name', 'dropout')
]);
headFC = fullyConnectedLayer(nHd, 'Name', 'fcHeadDirection');
headSoftmax = softmaxLayer('Name', 'softmaxHeadDirection');
placeFC = fullyConnectedLayer(nPlace, 'Name', 'fcPlaceCells');
placeSoftmax = softmaxLayer('Name', 'softmaxPlaceCells');
thisLayerGraph = addLayers(thisLayerGraph, [headFC, headSoftmax]);
thisLayerGraph = addLayers(thisLayerGraph, [placeFC, placeSoftmax]);
thisLayerGraph = connectLayers(thisLayerGraph, 'dropout', 'fcHeadDirection');
thisLayerGraph = connectLayers(thisLayerGraph, 'dropout', 'fcPlaceCells');
neuralNet = dlnetwork(thisLayerGraph);
learningRate = 1e-4;
gradientClipValue = 1e-4;   % clipping is applied only to the read-out (fcHeadDirection/fcPlaceCells) gradients
squaredGrad = []; squaredGradInit = [];
[lossHistory,hdLossHistory,placeLossHistory] = deal(zeros(numEpochs, 1));
W = createStateInitializationWeights(hiddenSize, nPlace, nHd);
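% Training loop: custom RMSProp, with weight decay and gradient clipping restricted to the read-out layers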
numObservations = numel(velocityInputs);
for epoch = 1:numEpochs
indices = randperm(numObservations);   % reshuffle the training sequences every epoch
numMiniBatches = floor(numObservations / miniBatchSize);
for ii = 1:numMiniBatches
batchIndices = indices((ii-1)*miniBatchSize+1:ii*miniBatchSize);
batchX = velocityInputs(batchIndices);
batchYHeadDirection = yHeadDirection(batchIndices);
batchYPlaceCells = yPlaceCells(batchIndices);
dlX = dlarray(cat(3, batchX{:}), 'CTB');
dlYHeadDirection = dlarray(cat(3, batchYHeadDirection{:}), 'CTB');
dlYPlaceCells = dlarray(cat(3, batchYPlaceCells{:}), 'CTB');
[gradients, loss, headLoss, placeLoss,...
wGrad] = dlfeval(@modelGradients,neuralNet, dlX, dlYHeadDirection, dlYPlaceCells,W);
for jj = 1:height(gradients)
layerName = gradients.Layer{jj};
if contains(layerName, 'fcHeadDirection') || contains(layerName, 'fcPlaceCells')
gradValues = extractdata(gradients.Value{jj});
clippedGrads = min(max(gradValues, -gradientClipValue), gradientClipValue);
gradients.Value{jj} = dlarray(clippedGrads);
[neuralNet, squaredGrad,squaredGradInit,W] = customRMSProp(neuralNet, gradients, squaredGrad, squaredGradInit,...
(epoch-1)*numMiniBatches+ii, learningRate, rmsPropDecayRate, epsilon, weightDecayRate,W,wGrad);
lossHistory(epoch) = lossHistory(epoch) + double(gather(extractdata(loss))) / numMiniBatches;
hdLossHistory(epoch) = hdLossHistory(epoch) + double(gather(extractdata(headLoss))) / numMiniBatches;
placeLossHistory(epoch) = placeLossHistory(epoch) + double(gather(extractdata(placeLoss))) / numMiniBatches;
fprintf('Epoch %d/%d: Loss = %.4f. ', epoch, numEpochs, lossHistory(epoch));
fprintf('Loss = %.4f for HD, %.4f for place.\n', hdLossHistory(epoch),placeLossHistory(epoch));
plot(1:numEpochs, lossHistory)
resultsOut.neuralNet = neuralNet;
neuralNet = resultsOut.neuralNet;
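% Decode an example trajectory (index ii) from ground-truth and network-predicted place-cell activity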
activityPlace = yPlaceCells{ii};
actualXY = positionOutputs{ii};
[xFromPlaceGroundTruth, yFromPlaceGroundTruth] = generateTrajectoryFromPlaceCells(activityPlace, placeCenter');
dlX = dlarray(velocityInputs{ii}, 'CT');
placeCellPredictions = extractdata(predict(neuralNet, dlX, 'Outputs', 'softmaxPlaceCells'));
[xFromPlaceDecoded, yFromPlaceDecoded] = generateTrajectoryFromPlaceCells(placeCellPredictions, placeCenter');
plot(actualXY(1,:),actualXY(2,:),'k','LineWidth',2); hold on
circle(0,0,enclosureBoundaries,[-pi pi],true);
plot(xFromPlaceGroundTruth,yFromPlaceGroundTruth,'b:');
plot(xFromPlaceDecoded,yFromPlaceDecoded,'r:');
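% Decode heading from ground-truth and network-predicted head-direction activity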
activityHd = yHeadDirection{ii};
% The inputs carry sin/cos of the per-step angular velocity, so this recovers the
% turn at each step rather than the absolute heading
inputTurnAngles = atan2(velocityInputs{ii}(2,:),velocityInputs{ii}(3,:));
angleFromHdGroundTruth = decodeHeadDirection(activityHd, hdCenter');
dlX = dlarray(velocityInputs{ii}, 'CT');
hdPredictions = extractdata(predict(neuralNet, dlX, 'Outputs', 'softmaxHeadDirection'));
angleFromHdDecoded = decodeHeadDirection(hdPredictions, hdCenter');
plot(unwrap(angleFromHdGroundTruth), 'k', 'LineWidth', 2); hold on
plot(unwrap(inputTurnAngles), 'b:', 'LineWidth', 1.5);
plot(unwrap(angleFromHdDecoded),'r:','LineWidth',1.5);
legend('Ground-truth HD decode', 'Turn angle from inputs', 'Network-decoded HD');
ylabel('Heading Angle (rad)');
title('Decoded vs. Actual Heading Direction');
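% Bin linear-layer activations along the simulated trajectories to estimate each unit's spatial and directional tuning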
maxVal = enclosureBoundaries;
minVal = -enclosureBoundaries;
xEdges = linspace(minVal, maxVal, numBins+1);
yEdges = linspace(minVal, maxVal, numBins+1);
thEdges = -pi:pi/12:pi - pi / 12;
firingRatesMap = zeros(numBins, numBins, linearSize);
firingRatesHd = zeros(length(thEdges),linearSize);
parfor unit = 1:linearSize
firingRateMap = zeros(numBins, numBins);
occupancyMap = zeros(numBins, numBins);
firingRateHd = zeros(size(thEdges));
occupancyHd = zeros(size(thEdges));
for ii = 1:length(positionOutputs)
dlInput = dlarray(velocityInputs{ii}, 'CT');
layerActivations = predict(neuralNet, dlInput, 'Outputs', 'linear');
layerActivations = extractdata(layerActivations);
layerActivations = layerActivations';
assert(size(layerActivations,2) == linearSize);
agentPath = positionOutputs{ii};
agentHd = atan2(velocityInputs{ii}(2,:),velocityInputs{ii}(3,:));   % per-step turn angle recovered from the inputs
x = agentPath(1,:); y = agentPath(2,:);
for t = 1:size(agentPath,2)
if abs(x(t)) > enclosureBoundaries || abs(y(t)) > enclosureBoundaries; continue; end
xBin = find(x(t) >= xEdges, 1, 'last');
yBin = find(y(t) >= yEdges, 1, 'last');
firingRateMap(xBin, yBin) = firingRateMap(xBin, yBin) + layerActivations(t, unit);
occupancyMap(xBin, yBin) = occupancyMap(xBin, yBin) + 1;
thBin = find(agentHd(t) >= thEdges, 1, 'last');
firingRateHd(thBin) = firingRateHd(thBin) + layerActivations(t, unit);
occupancyHd(thBin) = occupancyHd(thBin) + 1;
firingRateMap = firingRateMap ./ occupancyMap;
firingRateMap(isnan(firingRateMap)) = 0;
firingRatesMap(:, :, unit) = firingRateMap;
firingRateHd = firingRateHd ./ occupancyHd;
firingRateHd(isnan(firingRateHd)) = 0;
firingRatesHd(:, unit) = firingRateHd;
resultsOut.firingRatesMap = firingRatesMap;
resultsOut.firingRatesHd = firingRatesHd;
firingRatesMap = resultsOut.firingRatesMap;
firingRatesHd = resultsOut.firingRatesHd;
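% Per-unit analysis: rate map, its 2-D autocorrelation, rotational symmetry of the autocorrelation
% (power at 6 cycles per revolution suggests hexagonal, grid-like structure), and directional tuning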
for unit = 1:size(firingRatesMap,3)
A = squeeze(firingRatesMap(:, :, unit));
title(['Firing Rate Map for Hidden Unit ' num2str(unit)]);
B = xcorr2(firingRatesMap(:, :, unit));
title(['Autocorrelation of Firing Rate Map for Hidden Unit ' num2str(unit)]);
degStep = 0:stepSize:360 - stepSize;
corrResults = nan(size(degStep));
for ii = 1:length(corrResults)
Br = imrotate(B,degStep(ii),'crop');
corrResults(ii) = corr2(B,Br);
plot(degStep,corrResults)
dT = length(corrResults);   % note: this reuses the name dT (the simulation time step) for the rotation-curve length
plot(f(1:N/2), abs(X(1:N/2))*2/N);
[~,xInd] = min(abs(f - 6)); % power at 6 cycles per revolution indicates hexagonal (grid-like) symmetry
A = firingRatesHd(:,unit);
polarplot(thEdges + pi / 24,A);
function [gradients, totalLoss, headDirectionLoss, placeCellLoss, wGrad] = modelGradients(net, X, YHeadDirection, YPlaceCells,W)
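% Losses and gradients for one mini-batch. The LSTM state is initialised from the
% first-time-step place/HD targets via the trainable projections in W, so gradients
% with respect to W are returned as well.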
batchDim = X.finddim('B');
timeDim = X.finddim('T');
miniBatchSize = size(X,batchDim);
initialPlaceActivities = zeros(size(YPlaceCells,1), miniBatchSize);
initialHdActivities = zeros(size(YHeadDirection,1), miniBatchSize);
% Formatted 'CTB' dlarrays are stored internally in C x B x T order
assert(isequal(batchDim,2));
assert(isequal(timeDim,3));
for jj = 1:miniBatchSize
initialPlaceActivities(:,jj) = YPlaceCells(:,jj,1);   % first-time-step targets seed the LSTM state
initialHdActivities(:,jj) = YHeadDirection(:,jj,1);
dlInitialPlace = dlarray(initialPlaceActivities, 'CB');
dlInitialHd = dlarray(initialHdActivities, 'CB');
[dlInitialHidden, dlInitialCell] = initializeLSTMStatesBatched(dlInitialPlace, dlInitialHd, W);
[YPredHeadDirection, YPredPlaceCells] = forwardWithInitialState(net, X,dlInitialHidden,dlInitialCell);
numTimeSteps = size(YHeadDirection, 3);
headDirectionLoss = crossentropy(YPredHeadDirection,YHeadDirection) / numTimeSteps;
placeCellLoss = crossentropy(YPredPlaceCells,YPlaceCells) / numTimeSteps;
totalLoss = headWeight * headDirectionLoss + placeWeight * placeCellLoss;
% Take all gradients in a single dlgradient call; separate calls on the same trace
% would require the 'RetainData' option
[gradients, wGrad.W_cp_grad, wGrad.W_cd_grad, wGrad.W_hp_grad, wGrad.W_hd_grad] = ...
    dlgradient(totalLoss, net.Learnables, W.dlW_cp, W.dlW_cd, W.dlW_hp, W.dlW_hd);
function [totalGradients, totalLoss, totalHeadLoss, totalPlaceLoss] = handleChunking(net, X, YHeadDirection, YPlaceCells, sequenceLength)
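% Optional helper for long sequences: split into chunks (truncated backpropagation through time)
% and accumulate length-weighted gradients and losses. Note that this helper calls modelGradients
% without the W argument.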
totalTimeSteps = size(X, 3);
numChunks = ceil(totalTimeSteps / sequenceLength);
[totalLoss, totalHeadLoss, totalPlaceLoss, totalProcessedSteps] = deal(0);
totalGradients = [];
for chunk = 1:numChunks
startIdx = (chunk-1)*sequenceLength + 1;
endIdx = min(chunk*sequenceLength, totalTimeSteps);
chunkLength = endIdx - startIdx + 1;
totalProcessedSteps = totalProcessedSteps + chunkLength;
chunkX = X(:, :, startIdx:endIdx);
chunkYHead = YHeadDirection(:, :, startIdx:endIdx);
chunkYPlace = YPlaceCells(:, :, startIdx:endIdx);
[chunkGradients, chunkLoss, chunkHeadLoss, chunkPlaceLoss] = dlfeval(@modelGradients, net, chunkX, chunkYHead, chunkYPlace);
for ii = 1:height(chunkGradients)
chunkGradients.Value{ii} = chunkGradients.Value{ii} * (chunkLength / totalTimeSteps);
weightedLoss = chunkLoss * (chunkLength / totalTimeSteps);
weightedHeadLoss = chunkHeadLoss * (chunkLength / totalTimeSteps);
weightedPlaceLoss = chunkPlaceLoss * (chunkLength / totalTimeSteps);
totalLoss = totalLoss + double(gather(extractdata(weightedLoss)));
totalHeadLoss = totalHeadLoss + double(gather(extractdata(weightedHeadLoss)));
totalPlaceLoss = totalPlaceLoss + double(gather(extractdata(weightedPlaceLoss)));
if isempty(totalGradients)
totalGradients = chunkGradients;
else
for ii = 1:height(totalGradients)
totalGradients.Value{ii} = totalGradients.Value{ii} + chunkGradients.Value{ii};
function [YPredHeadDirection, YPredPlaceCells] = forwardWithInitialState(net, X, initialHidden, initialCell)
% Seed the LSTM hidden and cell states with the provided per-sequence initial values, then run the network
batchDim = X.finddim('B');
batchSize = size(X,batchDim);
netState = net.State;
hiddenInd = strcmpi(netState.Parameter,"HiddenState");
cellInd = strcmpi(netState.Parameter,"CellState");
for ii = 1:batchSize
netState.Value{hiddenInd}(:,ii) = initialHidden(:,ii);
netState.Value{cellInd}(:,ii) = initialCell(:,ii);
end
net.State = netState;
[YPredHeadDirection, YPredPlaceCells] = predict(net, X);
function [neuralNet, squaredGrad,squaredGradInit,W] = customRMSProp(neuralNet, gradients, squaredGrad, squaredGradInit, iteration, ...
learningRate, rmsPropDecayRate, epsilon, weightDecayRate, W, wGrad)
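% RMSProp updates for the network learnables and the state-initialisation weights W;
% weight decay is applied only to the read-out layer weights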
if isempty(squaredGrad)   % initialise the running squared-gradient accumulators on the first call
squaredGrad = cell(height(gradients), 1);
for ii = 1:height(gradients)
squaredGrad{ii} = zeros(size(extractdata(gradients.Value{ii})), 'like', extractdata(gradients.Value{ii}));
end
end
for ii = 1:height(gradients)
layerName = gradients.Layer{ii};
paramName = gradients.Parameter{ii};
paramValue = extractdata(neuralNet.Learnables.Value{ii});
gradValue = extractdata(gradients.Value{ii});
if contains(layerName, 'fcHeadDirection') || contains(layerName, 'fcPlaceCells')
if strcmp(paramName, 'Weights')
gradValue = gradValue + weightDecayRate * paramValue;
squaredGrad{ii} = rmsPropDecayRate * squaredGrad{ii} + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGrad{ii}) + epsilon);
newValue = paramValue + update;
neuralNet.Learnables.Value{ii} = dlarray(newValue);
if isempty(squaredGradInit)
squaredGradInit = struct();
squaredGradInit.W_cp = zeros(size(extractdata(W.dlW_cp)), 'like', extractdata(W.dlW_cp));
squaredGradInit.W_cd = zeros(size(extractdata(W.dlW_cd)), 'like', extractdata(W.dlW_cd));
squaredGradInit.W_hp = zeros(size(extractdata(W.dlW_hp)), 'like', extractdata(W.dlW_hp));
squaredGradInit.W_hd = zeros(size(extractdata(W.dlW_hd)), 'like', extractdata(W.dlW_hd));
gradValue = extractdata(wGrad.W_cp_grad);
squaredGradInit.W_cp = rmsPropDecayRate * squaredGradInit.W_cp + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_cp) + epsilon);
W.dlW_cp = W.dlW_cp + update;
gradValue = extractdata(wGrad.W_cd_grad);
squaredGradInit.W_cd = rmsPropDecayRate * squaredGradInit.W_cd + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_cd) + epsilon);
W.dlW_cd = W.dlW_cd + update;
gradValue = extractdata(wGrad.W_hp_grad);
squaredGradInit.W_hp = rmsPropDecayRate * squaredGradInit.W_hp + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_hp) + epsilon);
W.dlW_hp = W.dlW_hp + update;
gradValue = extractdata(wGrad.W_hd_grad);
squaredGradInit.W_hd = rmsPropDecayRate * squaredGradInit.W_hd + (1 - rmsPropDecayRate) * (gradValue .^ 2);
update = -learningRate * gradValue ./ (sqrt(squaredGradInit.W_hd) + epsilon);
W.dlW_hd = W.dlW_hd + update;
function [x, y] = generateTrajectoryFromPlaceCells(placeCellActivations, placeFieldCenters)
% Population-vector decode: position is the activation-weighted mean of the place-field centres
numTimeSteps = size(placeCellActivations, 2);
x = zeros(1, numTimeSteps);
y = zeros(1, numTimeSteps);
distanceFromCenter = vecnorm(placeFieldCenters, 2, 2);
distanceWeights = distanceFromCenter / max(distanceFromCenter);   % computed but currently unused
for t = 1:numTimeSteps
activations = placeCellActivations(:, t);
x(t) = sum(activations .* placeFieldCenters(:, 1));
y(t) = sum(activations .* placeFieldCenters(:, 2));
function decodedHeadingAngle = decodeHeadDirection(hdActivations, hdPreferredAngles)
% Population-vector decode of heading: angle of the activation-weighted sum of preferred-direction unit vectors
numTimeSteps = size(hdActivations, 2);
decodedHeadingAngle = zeros(1, numTimeSteps);
x_components = cos(hdPreferredAngles);
y_components = sin(hdPreferredAngles);
for t = 1:numTimeSteps
activations = hdActivations(:, t);
x_sum = sum(activations .* x_components);
y_sum = sum(activations .* y_components);
decodedHeadingAngle(t) = atan2(y_sum, x_sum);
function W = createStateInitializationWeights(hiddenSize, nPlace, nHd)
% Trainable projections from the initial place/HD activities to the LSTM cell and hidden states
W_cp = initializeGlorot(hiddenSize, nPlace);
W_cd = initializeGlorot(hiddenSize, nHd);
W_hp = initializeGlorot(hiddenSize, nPlace);
W_hd = initializeGlorot(hiddenSize, nHd);
W.dlW_cp = dlarray(W_cp);
W.dlW_cd = dlarray(W_cd);
W.dlW_hp = dlarray(W_hp);
W.dlW_hd = dlarray(W_hd);
function W = initializeGlorot(outputSize, inputSize)
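% Glorot/Xavier uniform initialisation: samples from [-limit, limit] with limit = sqrt(6/(fanIn+fanOut))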
limit = sqrt(6 / (inputSize + outputSize));
W = 2 * limit * rand(outputSize, inputSize) - limit;
function [initialHiddenState, initialCellState] = initializeLSTMStatesBatched(placeCellActivity, headDirectionActivity, W)
% Map the first-time-step place/HD activities to initial LSTM cell and hidden states.
% The W.dlW_* projections are used as dlarrays so their gradients remain traceable.
placeCellActivity = extractdata(placeCellActivity);
headDirectionActivity = extractdata(headDirectionActivity);
initialCellState = W.dlW_cp * placeCellActivity + W.dlW_cd * headDirectionActivity;
initialHiddenState = W.dlW_hp * placeCellActivity + W.dlW_hd * headDirectionActivity;
initialCellState = dlarray(initialCellState,'CB');
initialHiddenState = dlarray(initialHiddenState,'CB');
function dataVals = normalizeToBounds(dataVals,endRange,beginRange)
% Rescale data column-wise from beginRange (default: each column's own min/max) to endRange
wasFlipped = size(dataVals,2) > 1 && size(dataVals,1) == 1;   % operate on columns; transpose row vectors
if wasFlipped; dataVals = dataVals'; end
for ii = 1:size(dataVals,2)
dataCurrent = dataVals(:,ii);
if nargin <= 2 || all(isnan(beginRange))
beginRange = [min(dataCurrent) max(dataCurrent)];
dataCurrent = max(beginRange(1),min(dataCurrent,beginRange(2),'includenan'),'includenan');   % clip to beginRange
dataCurrent = (dataCurrent - beginRange(1)) / range(beginRange);
dataCurrent = dataCurrent * range(endRange) + endRange(1);
dataVals(:,ii) = dataCurrent;
if wasFlipped; dataVals = dataVals'; end
function [xunit,yunit] = circle(x,y,r,th,plotting)
% Return (and optionally draw) points on a circular arc of radius r centred at (x,y)
if nargin <= 4; plotting = false; end
if nargin <= 3; th = [0 2 * pi]; end
th = linspace(th(1),th(2),100);
xunit = r * cos(th) + x;
yunit = r * sin(th) + y;
if ~plotting; return; end
plot(xunit,yunit);   % draw the boundary when plotting is requested