Vectorize nested loops for performance
I have to solve a dynamic programming problem with a finite horizon and I am trying to vectorize as much as possible for speed. I attach an MWE here so that everybody can run it.
I have run the code with the MATLAB profiler and identified two bottlenecks, marked with the comment line % THIS IS SLOW ACCORDING TO PROFILER
The first bottleneck is the call to the function ReturnFn, which builds the n_a*n_a matrix Ret_mat. I have done this in a vectorized way (inside ReturnFn I do not use loops).
The second bottleneck is the maximization of RetMat along the first dimension.
I'd be very grateful for any comment/suggestion!
P.S. I have an if condition inside the loops to check for bugs. Removing it, now that I am confident about the code, gives only a marginal speed improvement.
clear,clc,close all
% STATE VARIABLES
% V(a,h,z,j)
% a: asset holdings
% h: human capital (female)
% z: labor productivity shocks (eta_m, eta_f are shocks, theta is permanent)
% j: age (from 1 to N_j)
% CHOICE VARIABLES
% d: Female labor supply (only extensive margin: either 0 or 1)
% a': Next-period assets
% h' is implied by (d,a) and consumption is implied by the budget
% constraint
% DYNAMIC PROGRAMMING PROBLEM
% V(a,h,z,j) = max_{d,a'} F(d,a',a,h,z,j)+beta*s_j*E[V(a',h',z',j+1)|z]
% subject to
% h'=G(d,h), law of motion for human capital
verbose = 1;
%% Define grids and grid sizes
N_j = 80;
n_a = 51;
n_h = 11;
n_z = 50;
n_d = 2;
a_grid = linspace(0,450,n_a)';
h_grid = linspace(0,0.72,n_h)';
z_grid = linspace(0.9,1.1,n_z)';
d_grid = [0,1]';
z_grid = repmat(z_grid,[1,3]);
pi_z = rand(n_z,n_z);
pi_z = pi_z./sum(pi_z,2);
aprime_val = a_grid; %(a',1)
a_val = a_grid'; %(1,a)
%% Set parameters that do not depend on age
beta = 0.98;
r = 0.04;
w_m = 1;
w_f = 0.75;
crra = 2;
nu = 0.12;
Jr = 45;
xi_1 = 0.05312;
xi_2 = -0.00188;
del_h = 0.074;
h_l = 0;
p.eff_j = ones(N_j,1);
p.pchild_j = ones(N_j,1);
p.pen_j = ones(N_j,1);
p.nchild_j = ones(N_j,1);
p.s_j = ones(N_j,1);
p.age_j = (1:1:N_j)';
% Initialize output arrays
V = zeros(n_a,n_h,n_z,N_j);
Policy = zeros(2,n_a,n_h,n_z,N_j);
tic
%% Solve problem in the last period
% Set age-dependent parameters
eff_j = p.eff_j(N_j);
pchild_j = p.pchild_j(N_j);
pen_j = p.pen_j(N_j);
nchild_j = p.nchild_j(N_j);
age_j = p.age_j(N_j);
s_j = p.s_j(N_j);
% V(a,h,z,N_j) = max_{d,a'}
V_d = zeros(n_a,n_h,n_z,n_d);
Pol_aprime_d = zeros(n_a,n_h,n_z,n_d);
for d_c = 1:n_d
    d_val = d_grid(d_c);
    for z_c = 1:n_z
        eta_m_val = z_grid(z_c,1);
        eta_f_val = z_grid(z_c,2);
        theta_val = z_grid(z_c,3);
        for h_c = 1:n_h
            h_val = h_grid(h_c);
            % RetMat is (a',a)
            RetMat = ReturnFn(d_val,aprime_val,a_val,h_val,eta_m_val,eta_f_val,theta_val,...
                w_m,eff_j,w_f,pchild_j,pen_j,r,nchild_j,crra,nu,age_j,Jr);
            [max_val,max_ind] = max(RetMat,[],1);
            V_d(:,h_c,z_c,d_c) = max_val;
            Pol_aprime_d(:,h_c,z_c,d_c) = max_ind;
        end %end h
    end %end z
end %end d
[V(:,:,:,N_j),d_max] = max(V_d,[],4);
Policy(1,:,:,:,N_j) = d_max; % Optimal d
for z_c=1:n_z
    for h_c=1:n_h
        for a_c = 1:n_a
            d_star = d_max(a_c,h_c,z_c);
            Policy(2,a_c,h_c,z_c,N_j) = Pol_aprime_d(a_c,h_c,z_c,d_star); % Optimal a'
        end
    end
end
%% Backward iteration over age
for j = N_j-1:-1:1
    if verbose==1; fprintf('Age %d out of %d \n',j,N_j); end
    V_next = V(:,:,:,j+1); %V(a',h',z')
    % Set age-dependent parameters
    eff_j = p.eff_j(j);
    pchild_j = p.pchild_j(j);
    pen_j = p.pen_j(j);
    nchild_j = p.nchild_j(j);
    age_j = p.age_j(j);
    s_j = p.s_j(j);
    for z_c = 1:n_z
        eta_m_val = z_grid(z_c,1);
        eta_f_val = z_grid(z_c,2);
        theta_val = z_grid(z_c,3);
        % Compute EV(a',h'), given z
        % EV = zeros(n_a,n_h);
        z_prob = pi_z(z_c,:)';
        % for zprime_c = 1:n_z
        %     EV = EV+V_next(:,:,zprime_c)*z_prob(zprime_c);
        % end %end z'
        EV = V_next.*shiftdim(z_prob,-2); %V(a',h',z')*Prob(1,1,z')
        EV = sum(EV,3); %EV(a',h')
        for d_c=1:n_d
            d_val = d_grid(d_c);
            for h_c = 1:n_h
                h_val = h_grid(h_c);
                hprime_val = f_HC_accum(d_val,h_val,age_j,xi_1,xi_2,del_h,h_l);
                % Ret_mat is (a',a)
                % THIS IS SLOW ACCORDING TO PROFILER
                Ret_mat = ReturnFn(d_val,aprime_val,a_val,h_val,eta_m_val,eta_f_val,theta_val,...
                    w_m,eff_j,w_f,pchild_j,pen_j,r,nchild_j,crra,nu,age_j,Jr);
                %[ind_l,weight_l] = interp_toolkit(hprime_val,h_grid);
                [ind_l,weight_l] = find_loc(h_grid,hprime_val);
                if ind_l>length(h_grid)-1
                    error('ind_l out of bounds')
                end
                EV_interp = EV(:,ind_l)*weight_l+EV(:,ind_l+1)*(1-weight_l);
                RHS_mat = Ret_mat+beta*s_j*EV_interp;
                % THIS IS SLOW ACCORDING TO PROFILER
                [max_val,max_ind] = max(RHS_mat,[],1);
                % max_val and max_ind are (1,a)
                V_d(:,h_c,z_c,d_c) = max_val; %best V given d
                Pol_aprime_d(:,h_c,z_c,d_c) = max_ind; %best a' given d
            end %end h
        end %end d
    end %end z
    [V(:,:,:,j),d_max] = max(V_d,[],4);
    Policy(1,:,:,:,j) = d_max; % Optimal d
    for z_c=1:n_z
        for h_c=1:n_h
            for a_c = 1:n_a
                d_star = d_max(a_c,h_c,z_c);
                Policy(2,a_c,h_c,z_c,j) = Pol_aprime_d(a_c,h_c,z_c,d_star); % Optimal a'
            end
        end
    end
end %end j
toc
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function F = ReturnFn(l_f,aprime,a,h_f,eta_m,eta_f,theta,w_m,eff_j,w_f,...
pchild_j,pen_j,r,nchild_j,crra,nu,agej,Jr)
% Calculate earnings (incl. child care costs) of men and women
y_m = w_m*eff_j*theta*eta_m;
y_f = w_f*l_f*(exp(h_f)*theta*eta_f - pchild_j);
% l_f can be either 0 or 1
% calculate available resources
cash = (1+r)*a + pen_j*(agej>=Jr) + (y_m + y_f)*(agej<Jr);
cons = cash-aprime;
%pos = cons>0;
F = (cons/(sqrt(2+nchild_j))).^(1-crra)/(1-crra) - nu*l_f;
F(cons<=0) = -inf;
end %end function "ReturnFn"
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [h_f_prime] = f_HC_accum(l_f,h_f,age_j,xi_1,xi_2,del_h,h_l)
% l_f: d variable that affects h'
% h_f: current-period value of h
h_f_prime = h_f + (xi_1 + xi_2*age_j)*l_f - del_h*(1-l_f);
h_f_prime = max(h_f_prime, h_l);
%h_f_prime = h_f;
%h_f_prime = max(h_f_prime, h_l);
end %end function f_HC_accum
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [jl,omega] = find_loc(x_grid,xi)
%-------------------------------------------------------------------------%
% DESCRIPTION
% Find jl s.t. x_grid(jl)<=xi<x_grid(jl+1)
% for jl=1,..,N-1
% omega is the weight on x_grid(jl) so that
% omega*x_grid(jl)+(1-omega)*x_grid(jl+1)=xi
% INPUTS
% x_grid must be a strictly increasing column vector (nx,1)
% xi must be a scalar
% OUTPUTS
% jl: Left point (scalar)
% omega: weight on the left point (scalar)
% NOTES
% See find_loc_vec.m for a vectorized version.
%-------------------------------------------------------------------------%
nx = size(x_grid,1);
jl = max(min(locate(x_grid,xi),nx-1),1);
%Weight on x_grid(j)
omega = (x_grid(jl+1)-xi)/(x_grid(jl+1)-x_grid(jl));
omega = max(min(omega,1),0);
end %end function "find_loc"
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function jl = locate(xx,x)
%function jl = locate(xx,x)
%
% x is between xx(jl) and xx(jl+1)
%
% jl = 0 and jl = n means x is out of range
%
% xx is assumed to be monotone increasing
n = length(xx);
if x<xx(1)
    jl = 0;
elseif x>xx(n)
    jl = n;
else
    jl = 1;
    ju = n;
    while (ju-jl>1)
        jm = floor((ju+jl)/2);
        if x>=xx(jm)
            jl = jm;
        else
            ju = jm;
        end
    end
end
end %end function locate
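A side note on the triple loops that fill Policy(2,...): the same selection can probably be written with linear indexing. The following is only a minimal sketch, assuming the (a,h,z,d) layout of V_d and Pol_aprime_d used above. The random arrays are stand-ins, not model output, and the names n_s, lin and Pol_aprime are introduced here for illustration. These loops were not flagged by the profiler, so any gain may be small.

```matlab
% Stand-in data with the same shapes as in the MWE (not model output)
rng(1);
n_a = 51; n_h = 11; n_z = 50; n_d = 2;
V_d          = rand(n_a,n_h,n_z,n_d);
Pol_aprime_d = randi(n_a,[n_a,n_h,n_z,n_d]);

% Optimal d for every state (a,h,z)
[~,d_max] = max(V_d,[],4);

% Linear index into the 4-D array: the position within the (a,h,z) block
% plus an offset of one full block for each step along the d dimension
n_s = n_a*n_h*n_z;
lin = (1:n_s)' + (d_max(:)-1)*n_s;

% Optimal a' index at the optimal d, shaped (a,h,z)
Pol_aprime = reshape(Pol_aprime_d(lin),[n_a,n_h,n_z]);
```

In the main code the result would be assigned with Policy(2,:,:,:,j) = Pol_aprime; in place of the three nested loops.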
5 Comments
dpb
on 17 Jun 2024
In
Ret_mat = ReturnFn(d_val,aprime_val,a_val,h_val,eta_m_val,eta_f_val,theta_val,...
w_m,eff_j,w_f,pchild_j,pen_j,r,nchild_j,crra,nu,age_j,Jr);
there may be no internal loops but you're calling the function (N_j-1)*n_z*n_d*n_h times with single values for
eta_m_val = z_grid(z_c,1);
eta_f_val = z_grid(z_c,2);
theta_val = z_grid(z_c,3);
...
d_val = d_grid(d_c);
...
h_val = h_grid(h_c);
It's complex enough I've not tried to read the function code in detail, but where you could possibly save some time is vectorizing over those loops. It may (or may not) be possible, but there's where to look.
Or, precompute and save in an ND array, maybe...
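To make the "vectorize over those loops" idea concrete, here is a minimal sketch, not a drop-in replacement, that removes the h loop by putting h in the third dimension. It inlines the formulas from ReturnFn for one illustrative z point, d = 1 and a working age (so the pension term is zero); those choices and the names h3, Ret3, V_h and Pol_h are assumptions made for the example. It relies on implicit expansion, so it needs R2016b or later.

```matlab
% Grid sizes and parameters copied from the MWE
n_a = 51; n_h = 11;
a_grid = linspace(0,450,n_a)';
h_grid = linspace(0,0.72,n_h)';
r = 0.04; w_m = 1; w_f = 0.75; crra = 2; nu = 0.12;
eff_j = 1; pchild_j = 1; nchild_j = 1;

% Illustrative assumptions: a single z point, d = 1, age below Jr
eta_m = 1; eta_f = 1; theta = 1;
l_f = 1;

aprime = a_grid;                     % (a',1,1)
a      = a_grid';                    % (1,a,1)
h3     = reshape(h_grid,1,1,n_h);    % (1,1,h)

% Same formulas as ReturnFn, evaluated for all h at once
y_m  = w_m*eff_j*theta*eta_m;
y_f  = w_f*l_f*(exp(h3)*theta*eta_f - pchild_j);   % (1,1,h)
cons = (1+r)*a + y_m + y_f - aprime;               % (a',a,h)
Ret3 = (cons/sqrt(2+nchild_j)).^(1-crra)/(1-crra) - nu*l_f;
Ret3(cons<=0) = -inf;

% One max call over a' for every (a,h) pair
[max_val,max_ind] = max(Ret3,[],1);   % both are (1,a,h)
V_h   = reshape(max_val,n_a,n_h);     % (a,h)
Pol_h = reshape(max_ind,n_a,n_h);     % (a,h)
```

This matches the last-period problem directly. For earlier ages the continuation value would also have to be built as an (a',1,h) array before adding it, because h' and therefore EV_interp depend on h. Whether the single large call beats the loop is something to time, not assume.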
Alessandro Maria Marco
on 17 Jun 2024
Edited: Alessandro Maria Marco
on 17 Jun 2024
dpb
on 17 Jun 2024
As noted earlier, it may not be possible, or the structure needed to hold the result may become so much more complex that the memory and other overhead cost more than the "dead ahead" solution.
It would take far more time than I have to understand the bowels of the computation, so I can only offer general comments, sorry. But, as the final earlier comment alluded to, perhaps what is now a 2D array can be turned into a 3D array with one plane per value of one of the looping variables. Theoretically it could be a 5D array over the other two as well, but then memory likely becomes the problem instead.
Or, as said, unless the values are variable every time the code is run, the alternative approach could be to precompute externally and then read the precomputed result...the time to do the precomputation would be the same but if something else is what changes from one execution to another, you can make up for it there. But, unless some of it is invariant from one run to another, then likely it's going to be "you simply have to compute what you have to compute".
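For a sense of scale, the storage can be worked out directly. This is a quick sketch using the grid sizes from the MWE and assuming double precision (8 bytes per element). The variable names are introduced here for illustration.

```matlab
% Grid sizes from the MWE
n_a = 51; n_h = 11; n_z = 50; n_d = 2;
bytes_per_double = 8;

bytes_2d = n_a*n_a*bytes_per_double;   % (a',a), as built now
bytes_3d = bytes_2d*n_h;               % (a',a,h)
bytes_5d = bytes_3d*n_z*n_d;           % (a',a,h,z,d)

fprintf('2-D: %.1f KB, 3-D: %.1f KB, 5-D: %.1f MB\n', ...
    bytes_2d/1e3, bytes_3d/1e3, bytes_5d/1e6);
% prints: 2-D: 20.8 KB, 3-D: 228.9 KB, 5-D: 22.9 MB
```

At these sizes the full 5-D array is modest for a single age, but the count grows with n_a^2, and temporaries such as cons and the logical mask cons<=0 add further copies of similar size. Finer grids can change the picture quickly.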
Alessandro Maria Marco
on 17 Jun 2024
Edited: Alessandro Maria Marco
on 17 Jun 2024
dpb
on 18 Jun 2024
I don't know that I'm terribly surprised; for-loop code optimization has advanced by leaps and bounds since the early days of MATLAB. Vectorized code can often be quicker if it can get to a linear addressing mode internally, but if the vectorized code is complex there may be little optimization that can be done, and the combination of memory and code complexity can make it slower than the "dead ahead" solution (as you've just demonstrated).
You do need to run the process monitor and ensure you really are not page faulting with the larger sizes; it may be you have a lot of installed memory but MATLAB isn't using it all...
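One way to check this on a given machine is to time both forms on arrays of the real size. Below is a minimal sketch for the max step only, with random data standing in for RHS_mat stacked over h. The names X, n_rep, V_loop and V_vec are made up for the example, and the timings will vary by machine and release.

```matlab
rng(0);
n_a = 51; n_h = 11;
X = rand(n_a,n_a,n_h);   % stand-in for RHS_mat stacked over h: (a',a,h)
n_rep = 200;             % repetitions, to average out timer noise

% Loop form: one max call per h plane
tic
for rep = 1:n_rep
    V_loop = zeros(n_a,n_h);
    for h_c = 1:n_h
        V_loop(:,h_c) = max(X(:,:,h_c),[],1);
    end
end
t_loop = toc/n_rep;

% Vectorized form: a single max call over the first dimension
tic
for rep = 1:n_rep
    V_vec = reshape(max(X,[],1),n_a,n_h);
end
t_vec = toc/n_rep;

fprintf('loop: %.3g s, vectorized: %.3g s\n',t_loop,t_vec);
```

Both forms return the same values, so the comparison is purely about speed. Rerunning it with larger n_a while watching the process monitor shows whether memory starts to matter.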