image thumbnail

Hierarchical Kalman Filter for clinical time series prediction

by

 

This is an implementation of a hierarchical (a.k.a. multi-scale) Kalman filter using belief propagation.

kalmanFilter_MultiScale(msg, varNode, factorNode, belief, ts_data, numLevels)
function [msg, belief, ts_data] = kalmanFilter_MultiScale(msg, varNode, factorNode, belief, ts_data, numLevels)
%KALMANFILTER_MULTISCALE Forward (filtering) pass of a hierarchical Kalman
%filter, run as Gaussian message passing along the level-1 chain.
%
%   [msg, belief, ts_data] = kalmanFilter_MultiScale(msg, varNode, ...
%       factorNode, belief, ts_data, numLevels)
%
%   Inputs
%     msg        - cell array of messages; each entry carries a .toVarNode
%                  and a .toFactorNode Gaussian (structs with fields such
%                  as mu, V, iV).
%     varNode    - varNode.x{level}{t} holds the message IDs
%                  (forwardNeighborMsgID, backwardNeighborMsgID,
%                  upperNeighborMsgID, ...) of variable node t at a level.
%     factorNode - factorNode.x{level}{t} holds the forward/backward
%                  message IDs of transition factor t.
%     belief     - struct whose .x{level}{t} entries receive node beliefs.
%     ts_data    - model/data struct; reads Lv1, Lv2, T, A, V_Q and writes
%                  mu_kf.
%     numLevels  - level-2 messages are updated only when this equals 2.
%
%   Outputs
%     msg        - messages after the forward pass.
%     belief     - belief.x{Lv1}{t} and belief.x{Lv2}{t} for every t.
%     ts_data    - ts_data.mu_kf{Lv}(:,t) = mean of belief.x{Lv}{t}.
%
%   NOTE(review): the level-2 beliefs are assembled even when
%   numLevels ~= 2, in which case no level-2 message was updated by this
%   function -- confirm this is intended.
%
%   Relies on GaussianMultiply, updateMSGToLevel2Forward,
%   updateMSGToLevel2Backward and updateMSGInLevel2 defined elsewhere.
    Lv1 = ts_data.Lv1;
    Lv2 = ts_data.Lv2;
    useLevel2 = (numLevels == 2);

    % ---- t = 1: seed the message v1 -> f2 from the given f1 message ----
    node = varNode.x{Lv1}{1};
    msg{node.forwardNeighborMsgID}.toFactorNode = GaussianMultiply( ...
        msg{node.backwardNeighborMsgID}.toVarNode, ...
        msg{node.upperNeighborMsgID}.toVarNode);
    if useLevel2
        msg = updateMSGToLevel2Forward(msg, varNode, 1, Lv1);
        msg = updateMSGInLevel2(msg, varNode, factorNode, 1, Lv1, Lv2, ts_data);
    end

    % ---- t = 2..T: predict through factor t, then update variable t ----
    T1 = ts_data.T(Lv1);
    for t = 2:T1
        % Transition factor: mu <- A*mu, V <- A*V*A' + V_Q.
        fac = factorNode.x{Lv1}{t};
        A = ts_data.A{Lv1};
        incoming = msg{fac.backwardNeighborMsgID}.toFactorNode;
        msg{fac.forwardNeighborMsgID}.toVarNode.mu = A * incoming.mu;
        msg{fac.forwardNeighborMsgID}.toVarNode.V = A * incoming.V * (A') + ts_data.V_Q{Lv1};
        % Reset iV so no stale value survives (presumably a cached inverse of V -- confirm).
        msg{fac.forwardNeighborMsgID}.toVarNode.iV = [];

        node = varNode.x{Lv1}{t};
        if t < T1
            % Interior node: combine prediction and upper message, pass it forward.
            msg{node.forwardNeighborMsgID}.toFactorNode = GaussianMultiply( ...
                msg{node.backwardNeighborMsgID}.toVarNode, ...
                msg{node.upperNeighborMsgID}.toVarNode);
            if useLevel2
                msg = updateMSGToLevel2Forward(msg, varNode, t, Lv1);
                msg = updateMSGInLevel2(msg, varNode, factorNode, t, Lv1, Lv2, ts_data);
            end
        else
            % Final node: its outgoing message is written on the backward link.
            msg{node.backwardNeighborMsgID}.toFactorNode = msg{node.upperNeighborMsgID}.toVarNode;
            if useLevel2
                msg = updateMSGToLevel2Backward(msg, varNode, t, Lv1);
                msg = updateMSGInLevel2(msg, varNode, factorNode, t, Lv1, Lv2, ts_data);
            end
        end
    end

    % ---- level-1 beliefs: product of the two directions on one link ----
    for t = 1:T1
        node = varNode.x{Lv1}{t};
        if t < T1
            link = msg{node.forwardNeighborMsgID};
        else
            link = msg{node.backwardNeighborMsgID};
        end
        belief.x{Lv1}{t} = GaussianMultiply(link.toFactorNode, link.toVarNode);
        ts_data.mu_kf{Lv1}(:,t) = belief.x{Lv1}{t}.mu;
    end

    % ---- level-2 beliefs: product of the two directions on the upper link ----
    for t = 1:ts_data.T(Lv2)
        link = msg{varNode.x{Lv2}{t}.upperNeighborMsgID};
        belief.x{Lv2}{t} = GaussianMultiply(link.toFactorNode, link.toVarNode);
        ts_data.mu_kf{Lv2}(:,t) = belief.x{Lv2}{t}.mu;
    end
end

function msg = updateMSGToLevel2Forward(msg, varNode, i, Lv1)
%UPDATEMSGTOLEVEL2FORWARD Fold the lower-neighbour messages of variable
%node i (level Lv1) into its forward message, then send each lower
%neighbour that product with the neighbour's own message divided out.
%
%   msg = updateMSGToLevel2Forward(msg, varNode, i, Lv1)
%
%   msg     - cell array of messages (.toVarNode / .toFactorNode).
%   varNode - varNode.x{Lv1}{i} supplies forwardNeighborMsgID and the
%             index array lowerNeighborMsgIDs.
%   i, Lv1  - time index and level of the variable node.
%
%   Leaves msg untouched when the node has no lower neighbours.
%   Relies on GaussianMultiply and GaussianDivision defined elsewhere.
    node = varNode.x{Lv1}{i};
    lowerIDs = node.lowerNeighborMsgIDs;
    nLower = numel(lowerIDs);
    if nLower == 0
        return;
    end
    fwdID = node.forwardNeighborMsgID;

    % Step 1: accumulate every lower neighbour's message into the forward one.
    outgoing = msg{fwdID}.toFactorNode;
    for k = 1:nLower
        outgoing = GaussianMultiply(outgoing, msg{lowerIDs(k)}.toVarNode);
    end
    msg{fwdID}.toFactorNode = outgoing;

    % Step 2: message down to neighbour k = full product / neighbour k's
    % own contribution (the "all but one" rule via division).
    for k = 1:nLower
        msg{lowerIDs(k)}.toFactorNode = GaussianDivision(msg{fwdID}.toFactorNode, ...
            msg{lowerIDs(k)}.toVarNode);
    end
end

Contact us