image thumbnail

Hierarchical Kalman Filter for clinical time series prediction

by

 

This is an implementation of a hierarchical (a.k.a. multi-scale) Kalman filter using belief propagation.

EMForMultiScaleKalmanFilter(msg, varNode, belief, ts_data, numLevels)
function ts_data = EMForMultiScaleKalmanFilter(msg, varNode, belief, ts_data, numLevels)
    % EMFORMULTISCALEKALMANFILTER  Re-estimate the model parameters stored in ts_data.
    %
    % Generative model whose parameters are updated here:
    %   x{1}(1)  ~ N(mu_0{1}, V_0{1})
    %   x{1}(t1) = A{1}*x{1}(t1-1)   + q{1},   q{1} ~ N(0, V_Q{1})
    %   y{1}(t1) = B{1}*x{1}(t1)     + w{1},   w{1} ~ N(0, V_W{1})
    %   x{2}(t2) = A{2}*x{1}(parent) + q{2},   q{2} ~ N(0, V_Q{2})
    %   y{2}(t2) = B{2}*x{2}(t2)     + w{2},   w{2} ~ N(0, V_W{2})
    %
    % Inputs
    %   msg, varNode, belief - passed through unchanged to the per-level
    %                          update routines.
    %   ts_data              - struct holding the current parameters and data.
    %   numLevels            - the level-2 parameters are only refreshed when
    %                          this equals 2.
    %
    % Output
    %   ts_data              - the same struct with updated parameters.

    % The first level is always re-estimated.
    ts_data = updateLevel1(msg, varNode, belief, ts_data);

    % The second level is re-estimated only for a two-level hierarchy.
    hasSecondLevel = (numLevels == 2);
    if hasSecondLevel
        ts_data = updateLevel2(msg, varNode, belief, ts_data);
    end
end

function ts_data = updateLevel2(msg, varNode, belief, ts_data)
    % UPDATELEVEL2  Re-estimate A, V_Q, B and V_W of the second level.
    %
    % Each level-2 state x{Lv2}_t is linked to a level-1 parent state whose
    % index comes from ts_data.connectionMap{Lv2}.toParents(t).
    %
    % Reads : ts_data.Lv1, .Lv2, .T, .connectionMap, .A, .V_Q, .y, .alpha,
    %         belief.x_sm{level}{t}.mu / .V, and the message addressed by
    %         varNode.x{Lv2}{t}.upperNeighborMsgID.
    % Writes: ts_data.A{Lv2}, .V_Q{Lv2}, .B{Lv2}, .V_W{Lv2}.
    % NOTE(review): ts_data.alpha{Lv2} is presumably sum_t y_t*y_t' - confirm
    % against the code that fills it in.
    parentLv = ts_data.Lv1;
    childLv = ts_data.Lv2;
    numSteps = ts_data.T(childLv);

    % Parameters from the previous iteration; they stay fixed while the
    % sufficient statistics are accumulated.
    prevA = ts_data.A{childLv};
    prevVQ = ts_data.V_Q{childLv};

    %% Sufficient statistics, summed over t = 1..T{Lv2}
    sumChildParent = 0;   % sum of E[x{Lv2}_t * parent(t)']
    sumChildChild = 0;    % sum of E[x{Lv2}_t * x{Lv2}_t']
    sumParentParent = 0;  % sum of E[parent(t) * parent(t)'] (one term per child)
    sumObsChild = 0;      % sum of y{Lv2}_t * E[x{Lv2}_t]'
    for k = 1:numSteps
        parentIdx = ts_data.connectionMap{childLv}.toParents(k);
        child = belief.x_sm{childLv}{k};
        parent = belief.x_sm{parentLv}{parentIdx};

        % Covariance of the message to the factor above this node; fall back
        % to the pseudo-inverse of iV when V is not stored.
        upMsg = msg{varNode.x{childLv}{k}.upperNeighborMsgID}.toFactorNode;
        upV = upMsg.V;
        if isempty(upV)
            upV = pinv(upMsg.iV);
        end

        crossCov = upV/(prevVQ + upV)*prevA*parent.V;
        sumChildParent = sumChildParent + crossCov + child.mu * parent.mu';

        sumChildChild = sumChildChild + child.V + child.mu * child.mu';
        sumParentParent = sumParentParent + parent.V + parent.mu * parent.mu';
        sumObsChild = sumObsChild + ts_data.y{childLv}(:,k) * child.mu';
    end

    %% Transition parameters of level 2
    ts_data.A{childLv} = sumChildParent/sumParentParent;
    ts_data.V_Q{childLv} = (sumChildChild - ts_data.A{childLv}*sumChildParent')/numSteps;

    %% Observation parameters of level 2
    ts_data.B{childLv} = sumObsChild/sumChildChild;
    ts_data.V_W{childLv} = (ts_data.alpha{childLv} - ts_data.B{childLv}*sumObsChild')/numSteps;
end

function ts_data = updateLevel1(msg, varNode, belief, ts_data)
    % UPDATELEVEL1  Re-estimate the prior, A, V_Q, B and V_W of the first level.
    %
    % Reads : ts_data.Lv1, .T, .A, .V_Q, .y, .alpha,
    %         belief.x_sm{Lv1}{t}.mu / .V, and (for t > 1) the message addressed
    %         by varNode.x{Lv1}{t}.backwardNeighborMsgID.
    % Writes: ts_data.mu_0{Lv1}, .V_0{Lv1}, .B{Lv1}, .V_W{Lv1}, and - when the
    %         series has at least two time steps - .A{Lv1} and .V_Q{Lv1}.
    %
    % With a single time step there is no transition to learn from; the
    % transition statistics are all zero and dividing by them (and by T-1)
    % would produce NaN/Inf. In that case A and V_Q keep their previous values.
    %
    % NOTE(review): ts_data.alpha{Lv1} is presumably sum_t y_t*y_t' - confirm
    % against the code that fills it in.
    Lv1 = ts_data.Lv1;
    T1 = ts_data.T(Lv1);

    %% for prior
    ts_data.mu_0{Lv1} = belief.x_sm{Lv1}{1}.mu;
    ts_data.V_0{Lv1} = belief.x_sm{Lv1}{1}.V;

    %% preCalculation
    beta = 0;  % sum of E[x_t * x_{t-1}'], t = 2..T
    gamma = 0; % sum of E[x_t * x_t'],     t = 1..T
    delta = 0; % sum of y_t * E[x_t]',     t = 1..T
    for t = 1:T1
        if(t > 1)
            % Covariance of the message to the factor behind this node; fall
            % back to the pseudo-inverse of iV when V is not stored.
            Vt_back = msg{varNode.x{Lv1}{t}.backwardNeighborMsgID}.toFactorNode.V;
            if(isempty(Vt_back))
                Vt_back = pinv(msg{varNode.x{Lv1}{t}.backwardNeighborMsgID}.toFactorNode.iV);
            end
            beta = beta + Vt_back/(ts_data.V_Q{Lv1}+ Vt_back)*ts_data.A{Lv1}*belief.x_sm{Lv1}{t-1}.V...
                        + belief.x_sm{Lv1}{t}.mu * belief.x_sm{Lv1}{t-1}.mu';
        end
        gamma = gamma + belief.x_sm{Lv1}{t}.V + belief.x_sm{Lv1}{t}.mu * belief.x_sm{Lv1}{t}.mu';
        delta = delta + ts_data.y{Lv1}(:,t) * belief.x_sm{Lv1}{t}.mu';
    end

    %% for A and V_Q at level 1
    if(T1 > 1)
        % second moments without t = T (regressors x_1..x_{T-1})
        gamma1 = gamma - belief.x_sm{Lv1}{T1}.V - belief.x_sm{Lv1}{T1}.mu * belief.x_sm{Lv1}{T1}.mu';
        % second moments without t = 1 (targets x_2..x_T)
        gamma2 = gamma - belief.x_sm{Lv1}{1}.V - belief.x_sm{Lv1}{1}.mu * belief.x_sm{Lv1}{1}.mu';
        ts_data.A{Lv1} = beta/gamma1;
        ts_data.V_Q{Lv1} = (gamma2 - ts_data.A{Lv1}*beta')/(T1 - 1);
    end

    %% for B and V_W at level 1
    ts_data.B{Lv1} = delta/gamma;
    ts_data.V_W{Lv1} = (ts_data.alpha{Lv1} - ts_data.B{Lv1}*delta')/T1;
end

Contact us