%TLASSO_EM Compute the t-lasso estimates for a specified regularisation parameter.
%   Users should refer to the 'tlasso' function to use this software.
%
%   [b, b0, sigma2, L, sigma2_ml, iter] = tlasso_EM(X, y, nu, tau2, theta_start, maxiter)
%   fits a linear regression with Student-t errors and a lasso-type penalty
%   using an EM algorithm. Each M-step is a weighted least-squares fit (lscov)
%   on a design augmented with one pseudo-observation per penalised coefficient.
%
%   Inputs:
%     X           - n-by-p design matrix WITHOUT an intercept column (an
%                   intercept column of ones is prepended internally).
%     y           - response vector of length n (treated as a column).
%     nu          - degrees-of-freedom parameter used in the Student-t
%                   observation weights (nu + 1) ./ (nu + e2).
%     tau2        - regularisation parameter entering the penalty weights
%                   (presumably a positive scalar - TODO confirm with tlasso).
%     theta_start - (optional) starting values [b0; b(1..p); sigma2], a
%                   vector with p + 2 elements. If omitted or empty, b0 = 0,
%                   b is drawn from N(0,1) and sigma2 = 1.
%     maxiter     - maximum number of EM iterations.
%
%   Outputs:
%     b          - p-by-1 coefficient vector. Entries whose final penalty
%                  weight exceeds 1e6 are set to exactly zero.
%     b0         - intercept estimate.
%     sigma2     - scale estimate from the EM iterations.
%     L          - value of tlasso_negll at sigma2_ml (presumably the
%                  negative log-likelihood; see tlasso_negll).
%     sigma2_ml  - scale that minimises tlasso_negll with b0 and b held fixed.
%     iter       - number of EM iterations performed (at most maxiter).
function [b, b0, sigma2, L, sigma2_ml, iter] = tlasso_EM(X, y, nu, tau2, theta_start, maxiter)

%% Setup
seps = sqrt(eps);     % floor for the relative convergence test near zero
convcrit = 1e-6;      % relative tolerance on the change in [b0; b]
[n, p] = size(X);
X = [ones(n,1) X];    % prepend the intercept column
y = y(:);             % ensure a column so [y; ya] below is well formed

options = optimoptions('fminunc','Display', 'off', 'GradObj', 'off', 'Algorithm', 'quasi-newton');

%% Initial parameter estimates
if nargin < 5 || isempty(theta_start)
    b0 = 0;
    b = randn(p,1);   % same draw as normrnd(0,1,p,1), without the toolbox dependency
    sigma2 = 1;
else
    theta_start = theta_start(:);   % accept row or column input
    if numel(theta_start) ~= p + 2
        error('tlasso_EM:badThetaStart', ...
              'theta_start must have p + 2 = %d elements [b0; b; sigma2].', p + 2);
    end
    b0 = theta_start(1);
    b = theta_start(2:end-1);
    sigma2 = theta_start(end);
end
% Measure the first convergence step from the actual starting point, so the
% result is deterministic and a converged warm start stops immediately.
beta_old = [b0; b];

%% Adjust X,y to allow EM for Student-t lasso
% Each extra row of Xa selects one slope coefficient (the intercept is not
% penalised). With a zero response and weight Elambda2inv(j), that row adds
% Elambda2inv(j) * b(j)^2 to the weighted residual sum of squares.
Xa = eye(p+1);
Xa(1,:) = [];
ya = zeros(p,1);
Xadj = [X; Xa];
yadj = [y; ya];

%% EM algorithm
done = false;
iter = 0;
while ~done

    %% E-step
    % Penalty weights for each coefficient. A coefficient at zero gives Inf,
    % so the weights are capped to keep the weighted LS problem finite.
    Elambda2inv = sqrt((2*sigma2)./(b.^2)./tau2); % = 1./tau2 * sqrt((2*tau2*sigma2)./(b.^2));
    Elambda2inv = min(Elambda2inv, 1e10);

    % Student-t observation weights from the standardised squared residuals.
    % They are floored to keep lscov weights strictly positive.
    mu = b0 + X(:,2:end)*b;
    e2 = (y - mu).^2 / sigma2;
    Eu2 = (nu + 1) ./ (nu + e2);
    Eu2 = max(Eu2, 1e-5);

    %% M-step
    beta = lscov(Xadj, yadj, [Eu2; Elambda2inv]);
    b0 = beta(1);
    b = beta(2:end);
    % The scale update must use residuals at the NEW coefficients, consistent
    % with the penalty term, which also uses the new b.
    r = y - X*beta;
    sigma2 = (sum(Eu2 .* r.^2) + sum(Elambda2inv .* b.^2)) / (n+p);

    %% Check convergence (relative change in every coefficient)
    if ~any(abs(beta - beta_old) > convcrit * max(seps, abs(beta_old)))
        done = true;
    end

    %% Update beta and iteration counter
    beta_old = beta;
    iter = iter + 1;
    if iter >= maxiter   % perform at most maxiter iterations
        done = true;
    end
end
% Coefficients whose last penalty weight is very large are treated as exactly zero.
b(Elambda2inv > 1e6) = 0;

%% Find ML estimate of sigma2 (b0 and b held fixed)
% Optimise over log(sigma2) so the search is unconstrained and sigma2 stays positive.
[log_sigma2_ml, L] = fminunc(@(Z) tlasso_negll(X(:,2:end), y, b0, b, exp(Z), nu), log(sigma2), options);
sigma2_ml = exp(log_sigma2_ml);

end