image thumbnail

Markov Chain Monte Carlo (MCMC) for inferring the parameters of an ordinary differential equation (ODE) model

by Soumya Banerjee

 

This function uses a Markov Chain Monte Carlo (MCMC) algorithm to infer the parameters of an ODE model.

analyze_mcmc(parallel)
function analyze_mcmc(parallel, mcmc_filename)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Name - analyze_mcmc
% Creation Date - 17th Feb 2012
% Author - Soumya Banerjee
% Website - www.cs.unm.edu/~soumya
%
%
% Description - 
%   Function to analyze output from MCMC runs: discards the burn-in
%   period, thins the chain, displays posterior means and plots/saves
%   trace plots and histograms of log10 V0, log10 p and log10 delta
%
% Parameters -
%   parallel      - 1 to open a matlabpool for the analysis if none is
%                   open (a pool opened here is closed again on exit),
%                   0 otherwise (default 0)
%   mcmc_filename - MCMC output file whose first three data columns are
%                   log10 V0, log10 p and log10 delta (default
%                   'results_evolution_poc_gibbs_v22_100000_13-Mar-2012.csv')
%
% Outputs - posterior means are printed to the command window; trace
%           plots and histograms are saved as JPEG files in the current
%           directory
%
% License - BSD 
%
% Change History - 
%                   17th Feb 2012 - Creation by Soumya Banerjee
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if nargin < 1 || isempty(parallel)
    parallel = 0;
end
if nargin < 2 || isempty(mcmc_filename)
    mcmc_filename = 'results_evolution_poc_gibbs_v22_100000_13-Mar-2012.csv';
end

tic;

% Test 'parallel' first so matlabpool is only called when a pool is
% requested. Only a pool opened here is closed, and it is closed on any
% exit from this function (normal return or error).
if parallel == 1 && matlabpool('size') == 0
    matlabpool open
    pool_cleanup = onCleanup(@() matlabpool('close')); %#ok<NASGU>
end

%% import output file from MCMC run
mcmc_output = importdata(mcmc_filename);
if isstruct(mcmc_output)
    % importdata returns a struct (numbers in .data) when the file has
    % header text, and a plain numeric matrix otherwise
    mcmc_output = mcmc_output.data;
end
log10V0    = mcmc_output(:,1);
log10p     = mcmc_output(:,2);
log10delta = mcmc_output(:,3);
N = size(log10V0,1);

%% ODE parameters
% NOTE: these and the data below are only used by the (disabled) SSR
% computation further down
target = 2.3e5; k = 4; infected1 = 0; infected2 = 0;
gamma = 44.43;
log10beta = log10(1.388e-4);

%% Data
fid_fileptr1 = importdata('dummymouse.txt');
passed_time_end1 = fid_fileptr1.data(end,1);
v_data1 = fid_fileptr1.data(:,2)';
time_data1 = fid_fileptr1.data(:,1)';

%% other parameters
burnin_period = 10000;
skip = 4; % how many skip? sample every (skip + 1)th sample 
correlation_window = 500; % number of final samples used for correlation checks
% check correlation from 500 samples before end (or from the first
% sample when the chain is shorter than that)
correlation_checks_start = max(1, N - correlation_window);

%% discard burn-in and thin the chain
% Keep every (skip + 1)th draw after the burn-in, the first one being
% (skip + 1) draws after the burn-in
sampled_idx = (burnin_period + 1 + skip):(skip + 1):N;
if isempty(sampled_idx)
    error('analyze_mcmc:tooFewSamples', ...
          ['MCMC output has only %d samples: none remain after a ' ...
           'burn-in of %d and thinning every %d samples.'], ...
          N, burnin_period, skip + 1);
end
log10V0_sampled    = log10V0(sampled_idx)';
log10p_sampled     = log10p(sampled_idx)';
log10delta_sampled = log10delta(sampled_idx)';

%% running (cumulative) means, for convergence checks
% element i is the mean of the first i thinned samples
n_sampled = numel(sampled_idx);
log10V0_sampled_cumul    = cumsum(log10V0_sampled)    ./ (1:n_sampled);
log10p_sampled_cumul     = cumsum(log10p_sampled)     ./ (1:n_sampled);
log10delta_sampled_cumul = cumsum(log10delta_sampled) ./ (1:n_sampled);

% element i is the mean of the first i post-burn-in draws (no thinning)
post_burnin = burnin_period + 1:N;
n_post = numel(post_burnin);
log10V0_sampled_cumul_other    = cumsum(log10V0(post_burnin))'    ./ (1:n_post);
log10p_sampled_cumul_other     = cumsum(log10p(post_burnin))'     ./ (1:n_post);
log10delta_sampled_cumul_other = cumsum(log10delta(post_burnin))' ./ (1:n_post);

% parfor iCount = 1:size(log10V0,1)
%     % parfor iCount = 1:4 % for testing
%     %     log10V0(iCount)
%     
%     %     log10p(iCount)
%     %     log10delta(iCount)
%     ssr = odecall_eclipse_tcl_jv_local_fileptr_mcmc(target,infected1,...
%                     infected2,...
%                     [log10V0(iCount) log10beta log10p(iCount) log10delta(iCount)]',...
%                     gamma,k,...
%                     passed_time_end1,time_data1,fid_fileptr1,0,0,1,1,v_data1);
%     output_array(iCount,:) = [log10V0(iCount) log10beta ...
%                                 log10p(iCount) log10delta(iCount)  ssr];
% end

% %% Save ssr of parameters in file
% % Fields - log10(V0),log10(beta),log10(p)
% %            ,log10(delta),ssr all in log10 scale
% % show and write results to file
% fid_output = fopen(sprintf('ssr_mcmc_%d_%s.csv', N, date),'w');
% fprintf(fid_output,'log10V0,log10beta,log10p,log10delta,ssr\r\n');
% 
% % output result array to file
% for iCount=1:size(log10V0,1)
%     % for iCount = 1:4 % for testing loop
%     fprintf(fid_output,'%12.10f,%12.10f,%12.10f,%12.10f,%12.10f\r\n',...
%                         output_array(iCount,1),output_array(iCount,2),...
%                         output_array(iCount,3),output_array(iCount,4),...
%                         output_array(iCount,5)...
%                         );
% end
% fclose(fid_output);

%% Display posterior means (mean over all post-burn-in draws)
disp('Posterior mean of log10 V0')
mean(log10V0(post_burnin))
disp('Posterior mean of log10 p')
mean(log10p(post_burnin))
disp('Posterior mean of log10 delta')
mean(log10delta(post_burnin))

%% Plot cumulative posterior (uncomment to check convergence)
% figure
% plot(log10V0_sampled_cumul_other)
% figure
% plot(log10p_sampled_cumul_other)
% figure
% plot(log10delta_sampled_cumul_other)

% figure
% plot(log10V0_sampled_cumul)
% figure
% plot(log10p_sampled_cumul)
% figure
% plot(log10delta_sampled_cumul)

%% Plot the final samples of the chain (visual check of autocorrelation)
tail_idx = correlation_checks_start:N;
figure
plot(tail_idx, log10V0(tail_idx))
figure
plot(tail_idx, log10p(tail_idx))
figure
plot(tail_idx, log10delta(tail_idx))

%% save trace plots (against MCMC iteration number) and histograms
save_chain_plots(sampled_idx, log10V0_sampled, '-k', ...
                 'log_1_0 V_0', 'log_1_0 V0', 'V0', N);
save_chain_plots(sampled_idx, log10p_sampled, '-r', ...
                 'log_1_0 p', 'log_1_0 p', 'p', N);
save_chain_plots(sampled_idx, log10delta_sampled, '-k', ...
                 'log_1_0 delta', 'log_1_0 delta', 'delta', N);

toc


function save_chain_plots(iterations, samples, line_spec, trace_label, ...
                          hist_label, file_tag, N)
% Plot one thinned parameter chain against MCMC iteration number and as a
% 1000-bin histogram; saves <file_tag>_iter_<N>_<date>.jpg and
% <file_tag>_hist_<N>_<date>.jpg in the current directory
figID = figure;
plot(iterations, samples, line_spec, 'linewidth', 2)
ylabel(trace_label);
xlabel('Number of MCMC Iterations');
print(figID, '-djpeg', sprintf('%s_iter_%d_%s.jpg', file_tag, N, date));

figID = figure;
hist(samples, 1000)
xlabel(hist_label);
title('Histogram');
print(figID, '-djpeg', sprintf('%s_hist_%d_%s.jpg', file_tag, N, date));


%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function ssr = odecall_eclipse_tcl_jv_local_fileptr_mcmc(target,infected1,...
                            infected2,param_vector,gamma,k,...
                            time_phase,time_vector,fileptr,plotfig,...
                            figorssr,mac,interp,emp_virus_vector)
 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% function to call ODE solver and plot results and calculate SSR
%
% Name - odecall_eclipse_tcl_jv_local_fileptr_mcmc
% Creation Date - 9th Dec 2011
% Author - Soumya Banerjee
% Website - www.cs.unm.edu/~soumya
%
%
% Description - function to call ODE solver with eclipse phase, target cell
%                   limited model, and plot results and 
%                   calculate SSR, for MCMC
%
% Parameters - 
%               target - initial target cell density (/mL)
%               infected1 - initial latently infected cell density (/mL)
%               infected2 - initial productively infected cell density
%               (/mL)
%               param_vector - (4,1) column vector with
%                   virus - initial virus density (/mL) in log 10
%                   beta - infectivity log10
%                   p - replication rate (PFU/day) in log 10
%                   delta - infected cell death rate (/day) in log 10
%               gamma - immune system clearance rate (/day)
%               k - eclipse phase (rate of transition from I1 to I2) (/day)
%               time_phase - duration of simulation(days)
%               time_vector - vector of measured times (days)
%               fileptr - data struct with a .data field (e.g. as returned
%                         by importdata); only used for plotting
%               plotfig - 1 if data and model simulation needs to be
%                          plotted,
%                         0 if no plot needed  
%               figorssr - unused; kept for interface compatibility (the
%                          data are passed in through fileptr)
%               mac - unused; kept for interface compatibility
%               interp - 1 if interpolation of simulation needed
%                        (required to compute the SSR),
%                        0 if not needed
%               emp_virus_vector - empirically measured virus data (1,n)
%                                   row vector, in log10
%
% Returns -     ssr - sum of squared differences between the log10
%                     simulated virus at time_vector and emp_virus_vector;
%                     NaN when interp ~= 1 (nothing to compare against)
%
% Assumptions - 1) all parameters passed in numerical values i.e. not 
%                  logged except virus, beta, p and delta 
%                       (which are logged to base 10)
%               2) Model with correction term, with a latently (I1) and
%                   a productively (I2) infected cell compartment
%               3) File has two columns - first column has time in days
%                   and second column has virus load in PFU/mL in log10
%                   (see attached file)    
%               4) Model parameters passed locally in initial conditions
%                   to ode solver  
%               5) Target cell limited model with eclipse phase
%               6) Data struct (fileptr) provided by the caller
%               7) Returns SSR on logged data and simulation
%
% Comments -    1) Make this function inline to speed up computation 
%               2) Consider tinkering with ode options (using odeset)
%                   to speed up for your specific case
%
% License - BSD
%
% Change History - 
%                   29th Aug 2011 - Creation by Soumya Banerjee
%                    7th Sep 2011 - Modification by Soumya Banerjee
%                                   plots now saved and parameters
%                                   printed in title
%                    7th Nov 2011 - Added testing lines to print 
%                                   parameters
%                    9th Nov 2011 - Added code to handle infinite, complex
%                                   numbers and NaN        
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 
high_virus_value = 20; % a very high value of virus (in log10) that is
% returned when there is a problem with integration

virus = param_vector(1,1); beta  = param_vector(2,1);
p     = param_vector(3,1); delta = param_vector(4,1);

options = odeset('RelTol', .00001, ...
                 'NonNegative', [1 2 3 4]);

% State vector is [T I1 I2 V] followed by the model parameters
% [beta p gamma delta k], which are carried as constant extra states
[t1, z1] = ode15s(@odecore_jv_tcl_eclipse_inline, [0, time_phase], ...
                  [target ...
                   infected1 ...
                   infected2 ...
                   double(10^virus) ...
                   double(10^beta) ...
                   double(10^p) ...
                   gamma ...
                   double(10^delta) ...
                   k], ...
                  options);

% V is the fourth column of the solution array (one row per time point)
% CAUTION - change this if the state ordering changes
V = log10(z1(:,4)');

% Plot the results
if plotfig == 1
    figID = figure;
    plot(t1,V,'-b','linewidth',2)
    xlabel('Time post infection (in days)','fontsize',25);
    ylabel('Virus concentration in blood (log_1_0 (PFU/mL))','fontsize',18);
    hold on
    plot(fileptr.data(:,1),fileptr.data(:,2),'ro','MarkerEdgeColor','k',...
         'MarkerFaceColor','r','MarkerSize',10)

    % sprintf keeps the separators that nested strcat calls stripped
    title(sprintf('T0 = %g, k = %g, V0 = %g, beta = %g, p = %g, delta = %g', ...
                  target, k, 10^virus, 10^beta, 10^p, 10^delta), ...
          'fontsize',18);

    legend('simulated','actual data','Location','NorthEast')
    print(figID, '-djpeg', sprintf('tcl_jv_plot_T0%s_k%s.jpg', ...
                                   num2str(target),num2str(k)));
end

if interp == 1
    % interp1 returns NaN for times outside [0, time_phase]; log10 of a
    % zero virus density gives -Inf and of a negative one a complex value.
    % e.g. for log10virus log10beta log10p log10delta 
    % 2.5000 -3.9612 1.1276 0.6248 (T0 = 2.3e4, k = 3) the returned virus
    % vector is -1.0458 -2.7730 -4.4999 -6.2563 NaN
    sim_virus_vector = interp1(t1, V, time_vector);

    if any(~isfinite(sim_virus_vector(:))) || ~isreal(sim_virus_vector)
        % integration problem: return something that gives a high SSR
        sim_virus_vector = high_virus_value*ones(size(sim_virus_vector));
    end

    ssr = sum((sim_virus_vector - emp_virus_vector).^2);
else
    % no simulated values at the measured times, so no SSR can be formed
    ssr = NaN;
end

 
function dydt = odecore_jv_tcl_eclipse_inline(~,y)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Right-hand side of the target cell limited ODE model with an eclipse
% phase, to be passed to ode45, ode15s, etc.
%
% Name - odecore_jv_tcl_eclipse_inline
% Creation Date - 27th Aug 2011
% Author - Soumya Banerjee
% Website - www.cs.unm.edu/~soumya
%
% Acknowledgements - Drew Levin
%
% Inputs -  (time) - unused, replaced by ~ (the system is autonomous)
%           y - 9 element state vector, all values not logged:
%               y(1) T  - target cells
%               y(2) I1 - infected cells not yet producing virus
%               y(3) I2 - infected cells producing virus
%               y(4) V  - virus
%               y(5) beta, y(6) p, y(7) gamma, y(8) delta,
%               y(9) k (eclipse phase, rate of transition from I1 to I2)
%               - model parameters carried as constant states
%
% Output -  dydt - time derivatives, same size as y; the entries for the
%                  parameters y(5:9) are zero
%
% Equations -
%           dT/dt  = -beta*T*V
%           dI1/dt =  beta*T*V - k*I1
%           dI2/dt =  k*I1 - delta*I2
%           dV/dt  =  p*I2 - gamma*V - beta*T*V  (last term is the
%                                                 correction term)
%
% Comments - Kept self-contained because it is called repeatedly by
%               ode45 or ode15s
%
% License - BSD
%
% Change History - 
%                   27th Aug 2011 - Creation by Soumya Banerjee
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    % Unpack the state variables and the constant model parameters
    T  = y(1);  I1 = y(2);  I2 = y(3);  V = y(4);
    beta = y(5);  p = y(6);  gamma = y(7);  delta = y(8);  k = y(9);

    % Rate of new infections, shared by the T, I1 and V equations
    infection_rate = beta*T*V;

    % Copy y so dydt has the same size and class, then zero the
    % parameter entries: the parameters do not change with time
    dydt = y;
    dydt(5:9) = 0;

    dydt(1) = -infection_rate;
    dydt(2) =  infection_rate - k*I1;
    dydt(3) =  k*I1 - delta*I2;
    dydt(4) =  p*I2 - gamma*V - infection_rate;
    
    

Contact us