image thumbnail

Hierarchical Kalman Filter for clinical time series prediction

by

 

This is an implementation of a hierarchical (a.k.a. multi-scale) Kalman filter using belief propagation.

initializeMessage(ts_data)
function [msg, varNode, factorNode, belief] = initializeMessage(ts_data) 
%INITIALIZEMESSAGE Wire up the two-level factor graph and set initial messages.
%
%   [msg, varNode, factorNode, belief] = initializeMessage(ts_data)
%
%   Level Lv1 is a chain of variable nodes x{Lv1}{1..T(Lv1)} linked by
%   transition factors x{Lv1}{i}; each variable also has an observation
%   factor y{Lv1}{i}. Every level-2 variable x{Lv2}{k} hangs under one
%   level-1 parent through a transition factor x{Lv2}{k} and has its own
%   observation factor y{Lv2}{k}. Every edge gets a unique integer ID that
%   indexes into MSG.
%
%   TS_DATA fields read here:
%     Lv1, Lv2        - level indices into the per-level cell arrays
%     T               - number of time steps per level
%     connectionMap   - {Lv1}.toChildren{i}: level-2 children of level-1 node i
%                       {Lv2}.toParents(k):  level-1 parent of level-2 node k
%     mu_0, V_0       - prior mean / covariance of x{Lv1}{1}
%     dim_x           - state dimension per level
%     A, B, V_Q, V_W  - model matrices (A and V_Q are only read at Lv2 here)
%     y               - observations; y{level}(:, t) is the t-th observation
%
%   Outputs:
%     msg        - cell(total_edges, 1); msg{id}.toVarNode / .toFactorNode
%                  are Gaussian messages with fields mu, V (covariance) and
%                  iV (precision); an empty field means that form is unused.
%     varNode    - message IDs of each variable node's edges
%     factorNode - message IDs of each factor node's edges
%     belief     - empty per-level cell arrays x and x_sm, filled elsewhere
%
%   Uses the helper EYEINF (defined outside this function) for the zero-mean,
%   infinite-std initial messages.

    tol = 1E-10;   % tolerance of the pinv sanity checks (not called 'eps' so the built-in stays visible)
    Lv1 = ts_data.Lv1; Lv2 = ts_data.Lv2;
    T1 = ts_data.T(Lv1); T2 = ts_data.T(Lv2);

    % Each level-1 node owns a backward edge and an observation edge, plus a
    % forward edge for all but the last node (3*T1 - 1). Each level-2 node adds
    % a parent edge, a factor-to-variable edge and an observation edge (3*T2).
    total_edges = 3*T1 - 1 + 3*T2;
    msg = cell(total_edges, 1);
    varNode.x{Lv1} = cell(T1, 1);
    varNode.x{Lv2} = cell(T2, 1);

    factorNode.y{Lv1} = cell(T1, 1);  % observation factors, level 1
    factorNode.y{Lv2} = cell(T2, 1);  % observation factors, level 2
    factorNode.x{Lv1} = cell(T1, 1);  % transition factors, level 1
    factorNode.x{Lv2} = cell(T2, 1);  % transition factors, level 2

    belief.x{Lv1} = cell(T1, 1);      % belief for forward and downstream passing
    belief.x{Lv2} = cell(T2, 1);      % belief for forward and downstream passing
    belief.x_sm{Lv1} = cell(T1, 1);   % belief for backward and upstream passing
    belief.x_sm{Lv2} = cell(T2, 1);   % belief for backward and upstream passing

    %% connect variable and factor nodes with message edge IDs
    count = 1;
    for i = 1:T1
        % variable node x{Lv1}{i}
        varNode.x{Lv1}{i}.backwardNeighborMsgID = count; count = count + 1;     % to transition factor i
        if i < T1  % the last variable has no forward neighbor
            varNode.x{Lv1}{i}.forwardNeighborMsgID = count; count = count + 1;  % to transition factor i+1
        end
        varNode.x{Lv1}{i}.upperNeighborMsgID = count; count = count + 1;        % to observation factor y{Lv1}{i}
        % transition factor x{Lv1}{i}
        factorNode.x{Lv1}{i}.forwardNeighborMsgID = varNode.x{Lv1}{i}.backwardNeighborMsgID;
        if i > 1  % factor 1 carries the prior and has no backward neighbor
            factorNode.x{Lv1}{i}.backwardNeighborMsgID = varNode.x{Lv1}{i-1}.forwardNeighborMsgID;
        end
        % observation factor y{Lv1}{i}
        factorNode.y{Lv1}{i}.lowerNeighborMsgID = varNode.x{Lv1}{i}.upperNeighborMsgID;

        % Create the child-link fields even when node i has no children;
        % otherwise the consistency checks below fail on a missing field.
        children = ts_data.connectionMap{Lv1}.toChildren{i};
        varNode.x{Lv1}{i}.lowerNeighborFactorNodeIDs = reshape(children, 1, []);
        varNode.x{Lv1}{i}.lowerNeighborMsgIDs = zeros(1, numel(children));
        for j = 1:numel(children)
            tLv2 = children(j);
            % level-1 variable i <-> level-2 transition factor tLv2
            varNode.x{Lv1}{i}.lowerNeighborMsgIDs(j) = count; count = count + 1;
            factorNode.x{Lv2}{tLv2}.upperNeighborVariableNodeID = i;
            factorNode.x{Lv2}{tLv2}.upperNeighborMsgID = varNode.x{Lv1}{i}.lowerNeighborMsgIDs(j);
            % level-2 transition factor tLv2 <-> level-2 variable tLv2
            factorNode.x{Lv2}{tLv2}.lowerNeighborMsgID = count; count = count + 1;
            varNode.x{Lv2}{tLv2}.upperNeighborMsgID = factorNode.x{Lv2}{tLv2}.lowerNeighborMsgID;
            % level-2 variable tLv2 <-> level-2 observation factor tLv2
            varNode.x{Lv2}{tLv2}.lowerNeighborMsgID = count; count = count + 1;
            factorNode.y{Lv2}{tLv2}.upperNeighborMsgID = varNode.x{Lv2}{tLv2}.lowerNeighborMsgID;
        end
    end

    %% sanity-check the wiring
    assert(count - 1 == total_edges, 'count mismatch');
    for i = 1:T1
        assert(varNode.x{Lv1}{i}.backwardNeighborMsgID == factorNode.x{Lv1}{i}.forwardNeighborMsgID);
        assert(varNode.x{Lv1}{i}.upperNeighborMsgID == factorNode.y{Lv1}{i}.lowerNeighborMsgID);
        if i < T1
            assert(varNode.x{Lv1}{i}.forwardNeighborMsgID == factorNode.x{Lv1}{i+1}.backwardNeighborMsgID);
        end
        assert(numel(varNode.x{Lv1}{i}.lowerNeighborFactorNodeIDs) == numel(varNode.x{Lv1}{i}.lowerNeighborMsgIDs));
        assert(numel(ts_data.connectionMap{Lv1}.toChildren{i}) == numel(varNode.x{Lv1}{i}.lowerNeighborMsgIDs));
    end
    for i = 1:T2
        t_Lv1 = ts_data.connectionMap{Lv2}.toParents(i);
        assert(t_Lv1 == factorNode.x{Lv2}{i}.upperNeighborVariableNodeID);
        assert(ismember(factorNode.x{Lv2}{i}.upperNeighborMsgID, varNode.x{Lv1}{t_Lv1}.lowerNeighborMsgIDs));
        assert(factorNode.x{Lv2}{i}.lowerNeighborMsgID == varNode.x{Lv2}{i}.upperNeighborMsgID);
        assert(varNode.x{Lv2}{i}.lowerNeighborMsgID == factorNode.y{Lv2}{i}.upperNeighborMsgID);
    end

    %% initialize messages
    d1 = ts_data.dim_x(Lv1);

    % transition factor x{Lv1}{1} carries the prior of x{Lv1}{1}
    id = factorNode.x{Lv1}{1}.forwardNeighborMsgID;
    msg{id}.toVarNode.mu = ts_data.mu_0{Lv1};
    msg{id}.toVarNode.V = ts_data.V_0{Lv1};
    msg{id}.toVarNode.iV = [];
    msg{id}.toFactorNode.mu = [];
    msg{id}.toFactorNode.V = [];
    msg{id}.toFactorNode.iV = [];

    % transition factors x{Lv1}{2..T}: zero mean and infinite std in both directions
    for i = 2:T1
        % edge factor i <-> variable i
        fwd = factorNode.x{Lv1}{i}.forwardNeighborMsgID;
        msg{fwd}.toVarNode.mu = zeros(d1, 1);
        msg{fwd}.toVarNode.V  = eyeInf(d1);
        msg{fwd}.toVarNode.iV = [];
        msg{fwd}.toFactorNode.mu = zeros(d1, 1);
        msg{fwd}.toFactorNode.V  = eyeInf(d1);
        msg{fwd}.toFactorNode.iV = [];
        % edge variable i-1 <-> factor i
        bwd = factorNode.x{Lv1}{i}.backwardNeighborMsgID;
        msg{bwd}.toFactorNode.mu = zeros(d1, 1);
        msg{bwd}.toFactorNode.V  = eyeInf(d1);
        msg{bwd}.toFactorNode.iV = [];
        msg{bwd}.toVarNode.mu = zeros(d1, 1);
        msg{bwd}.toVarNode.V  = eyeInf(d1);
        msg{bwd}.toVarNode.iV = [];
    end

    % Observation factors y{Lv1}{i}. pinv(B) is expected to equal B'*inv(B*B'),
    % i.e. B has full row rank. Element-wise absolute errors are compared (NaN
    % fails too): a signed sum could cancel, or be -Inf for a singular B*B'.
    B1 = ts_data.B{Lv1};
    pinvB1 = pinv(B1);
    pinvErr = pinvB1 - B1'/(B1*B1');
    assert(all(abs(pinvErr(:)) <= tol), 'pinv check failed at level 1');
    iV1 = (B1')/ts_data.V_W{Lv1}*B1;   % precision B'*inv(V_W)*B, shared by every step
    for i = 1:T1
        id = factorNode.y{Lv1}{i}.lowerNeighborMsgID;
        msg{id}.toVarNode.mu = pinvB1*ts_data.y{Lv1}(:,i);   % B' * inv(B*B') * y
        msg{id}.toVarNode.V = [];    % the precision can be singular, so no covariance is stored
        msg{id}.toVarNode.iV = iV1;  % the distribution is represented by its precision
    end

    % level-2 factors (only touched when level 2 has nodes)
    if T2 > 0
        d2 = ts_data.dim_x(Lv2);
        B2 = ts_data.B{Lv2};
        pinvB2 = pinv(B2);
        pinvErr = pinvB2 - B2'/(B2*B2');
        assert(all(abs(pinvErr(:)) <= tol), 'pinv check failed at level 2');
        iV2 = (B2')/ts_data.V_W{Lv2}*B2;
        % The upward message maps y{Lv2}(:,k) back through pinv(B*A); its
        % covariance does not depend on k, so it is computed once.
        pinvBA2 = pinv(B2*ts_data.A{Lv2});
        V_up = pinvBA2 * (B2*ts_data.V_Q{Lv2}*(B2') + ts_data.V_W{Lv2}) * (pinvBA2');
        for i = 1:T2
            yId  = factorNode.y{Lv2}{i}.upperNeighborMsgID;   % variable x{Lv2}{i} <-> y{Lv2}{i}
            upId = factorNode.x{Lv2}{i}.upperNeighborMsgID;   % level-1 parent <-> factor x{Lv2}{i}
            loId = factorNode.x{Lv2}{i}.lowerNeighborMsgID;   % factor x{Lv2}{i} <-> variable x{Lv2}{i}

            % observation factor -> x{Lv2}{i}, in precision form
            msg{yId}.toVarNode.mu = pinvB2*ts_data.y{Lv2}(:,i);   % B' * inv(B*B') * y
            msg{yId}.toVarNode.V = [];
            msg{yId}.toVarNode.iV = iV2;

            % transition factor -> level-1 parent, and an uninformative reply
            msg{upId}.toVarNode.mu = pinvBA2*ts_data.y{Lv2}(:,i);
            msg{upId}.toVarNode.V = V_up;
            msg{upId}.toVarNode.iV = [];
            msg{upId}.toFactorNode.mu = zeros(d1, 1);
            msg{upId}.toFactorNode.V = eyeInf(d1);
            msg{upId}.toFactorNode.iV = [];

            % transition factor -> x{Lv2}{i}: uninformative
            msg{loId}.toVarNode.mu = zeros(d2, 1);
            msg{loId}.toVarNode.V = eyeInf(d2);
            msg{loId}.toVarNode.iV = [];
            % x{Lv2}{i} -> transition factor: its only other neighbour is the
            % observation factor, so that message is forwarded unchanged.
            msg{loId}.toFactorNode.mu = msg{yId}.toVarNode.mu;
            msg{loId}.toFactorNode.V = msg{yId}.toVarNode.V;
            msg{loId}.toFactorNode.iV = msg{yId}.toVarNode.iV;
        end
    end
end

Contact us