function [beta, beta0, retval] = tlasso(X, y, nu, varargin)
%TLASSO Robust Lasso Regression with Student-t Residuals.
%   [beta, beta0, retval] = tlasso(...) implements robust lasso regression
%      with t-distributed residuals. This is the main function the user
%      should be calling.
%
%      By default, the function standardises the predictor matrix X before performing
%      any lasso fits, and unstandardises the fitted coefficients before
%      returning them to the user. (use 'standardise' option to disable
%      this if required, though this is not recommended)
%
%      By default, the function determines a grid of 100 values of the regularisation 
%      parameter and produces a path along these values, along with likelihoods and information
%      criterion scores.  (use the 'ngrid' option to change the number of
%      grid points).
%
%      Cross-validation is optionally available. (use the 'CV' and 'MCReps'
%      options to enable and control CV).
%
%   The input arguments are:
%       X          - [n x p] data matrix (without the intercept)
%       y          - [n x 1] target vector
%       nu         - the degrees-of-freedom parameter for the t-distribution 
%                    (1 = cauchy, >1e4 is essentially Gaussian)
%       varargin   - optional arguments described below.
%
%   The following optional arguments are supported in the format 'argument', value:
%       'tau2'      - find the lasso estimates for this single value of the 
%                     regularisation parameter (Default: generate full path)
%       'tau2_max'  - the maximum value of regularisation parameter for the
%                     regularisation path (use only if n <= p)
%       'CV'        - use k-fold cross-validation (Default: not used)
%       'MCReps'    - number of cross-validation repetitions. 
%                     The more repetitions the less variable the CV scores. (Default: 1)
%       'ngrid'     - Number of values of regularisation parameter to use
%                     when generating the regularisation path (Default: 100)
%       'maxiter'   - Maximum number of iterations for EM algorithm. Larger
%                     values can improve smoothness of paths. Setting to
%                     inf runs until a very strict convergence criterion. (Default: 50)
%
%   Return values:
%       beta        - values of coefficients along the lasso path
%       beta0       - values of the intercept along the lasso path
%       retval      - additional information
%
%   The 'retval' return value is a structure with (potentially) the following fields:
%       nu           - The degrees-of-freedom parameter used to run tlasso
%       sigma2       - Values of maximised scale parameter along the lasso path
%       negll        - Values of the negative-log-likelihood along the lasso path
%       tau2         - Values of the regularisation parameter used to define the lasso path
%       CVErr        - CV prediction error estimates along the lasso path
%       IndexMinCV   - Index along the path of model minimising CV error
%       dof          - Model degrees-of-freedom (number of fitted coefficients) along lasso path
%       BIC          - Bayesian information criterion score along the lasso path
%       IndexMinBIC  - Index along the path of model minimising BIC score
%       AICc         - Corrected Akaike information criterion score along the lasso path
%       IndexMinAICc - Index along the path of model minimising AICc score
%
%   Please see examples_tlasso1.m and examples_tlasso2.m for usage examples.
%
%   To cite this code:
%     Schmidt, D. F. & Makalic, E.
%     Robust Lasso Regression with Student-t Residuals
%     Lecture Notes in Artificial Intelligence, to appear, 2017.
%
%   (c) Copyright Daniel F. Schmidt and Enes Makalic, 2017
%
[n, p] = size(X);
ny     = length(y);

%% Parse options
inParser = inputParser;  

%% Default parameter values
defaultTau2        = NaN;
defaultCV          = NaN;
defaultMCReps      = 1;
defaultNgrid       = 100;
defaultTau2_max    = NaN;
defaultStandardise = true;
defaultMaxiter     = 50;

%% Define parameters
% Note: maxiter is constrained to >= 50; floor(inf) == inf, so 'inf' passes
% the integrality check and is a valid setting (run to strict convergence).
addParameter(inParser,'tau2',defaultTau2,@(x) isnumeric(x) && (x > 0) && isscalar(x));
addParameter(inParser,'tau2_max',defaultTau2_max, @(x) isnumeric(x) && isscalar(x));
addParameter(inParser,'CV',defaultCV,@(x) isnumeric(x) && isscalar(x) && x > 0 && x <= n && floor(x) == x);
addParameter(inParser,'MCReps',defaultMCReps,@(x) isnumeric(x) && isscalar(x) && x > 0 && floor(x) == x);
addParameter(inParser,'ngrid',defaultNgrid,@(x) isnumeric(x) && isscalar(x) && x > 0 && floor(x) == x);
addParameter(inParser,'standardise',defaultStandardise,@islogical);
addParameter(inParser,'maxiter',defaultMaxiter,@(x) isscalar(x) && x >= 50 && floor(x) == x);

%% Parse options
parse(inParser, varargin{:});  

tau2        = inParser.Results.tau2; 
CV          = inParser.Results.CV;
ngrid       = inParser.Results.ngrid;
mcreps      = inParser.Results.MCReps;
tau2_max    = inParser.Results.tau2_max;
dostand     = inParser.Results.standardise;
maxiter     = inParser.Results.maxiter;

retval.nu = nu;
retval.sigma2 = [];
retval.negll = [];
retval.tau2 = [];

%% Some error checking
if (ny ~= n)
    error('Target vector and predictor matrix have different number of rows.');
end
% Constant columns cannot be standardised (zero variance => divide-by-zero)
I = find(std(X) == 0);
if (~isempty(I))
    % mat2str renders one or several offending column indices in a single,
    % grammatically sensible message (the old '%d' form concatenated the
    % message once per column with no separator).
    error('Column(s) %s have a variance of zero.', mat2str(I));
end    

%% Logic-check
if (~isnan(tau2) && ~isnan(CV))
    error('A value for regularisation parameter ''tau2'' cannot be used if CV is requested');
end

%% Standardise input data
if (dostand)
    % Only the standardised X and its statistics are needed here; the
    % centred-y outputs of standardise() were previously requested and
    % discarded, so y is deliberately left untouched.
    [X, muX, normX] = standardise(X);
end

%% If a single tau2 was passed, use that only
if (~isnan(tau2))
    [beta, beta0, retval.sigma2, retval.negll] = tlasso_EM(X, y, nu, tau2, maxiter);
   
%% Otherwise
else
    %% Setup tau2 grid
    tau2_min = tlasso_FindMinTau2(X, y, nu);
    % If no tau2_max specified, and n > p, use heuristic based on ML
    if (isnan(tau2_max) && n > p)
        [tau2_max, B_ml, B0_ml, sigma2_ml, ~] = tlasso_FindMaxTau2(X, y, nu);
    
    % Otherwise
    else
        % If no value specified and n <= p, use a simpler heuristic
        if (isnan(tau2_max) && n <= p)
            warning('n <= p; using tau2_max = tau2_min * 1e6. You can use the ''tau2_max'' option to set this value manually as desired.');
            tau2_max = tau2_min * 1e6;
        end
        [B_ml, B0_ml, sigma2_ml, ~] = tlasso_EM(X, y, nu, tau2_max, maxiter);
    end
    
    % Error checking for tau2 grid
    if (tau2_max < tau2_min)
        error('Specified value of tau2_max(=%f) is less than computed value of tau2_min(=%f).', tau2_max, tau2_min);
    end
    
    % Finally, create the tau2 grid; the grid is built in sqrt-space and
    % squared, and an extra near-unregularised point (tau2_min/100)^2 is
    % prepended, so the grid actually contains ngrid+1 values.
    tau2 = [tau2_min/100, logspace(log10(tau2_min)/2, log10(tau2_max)/2, ngrid)].^2;

    %% Perform cross-validation if requested
    % NOTE(review): the tau2 grid above holds ngrid+1 points, but CVErr is
    % sized and accumulated over only the first ngrid of them, so the
    % tau2_max end-point is never cross-validated -- confirm intentional.
    if (~isnan(CV))
        retval.CVErr = zeros(1, ngrid);
        for k = 1:mcreps
            cv = cvpartition(n, 'KFold', CV);
            for j = 1:CV
                [B, B0, sigma2, ~] = tlasso_path(X(cv.training(j),:), y(cv.training(j)), nu, tau2, [B0_ml;B_ml;sigma2_ml], maxiter);
                for i = 1:ngrid
                    retval.CVErr(i) = retval.CVErr(i) + tlasso_negll(X(cv.test(j),:), y(cv.test(j)), B0(i), B(:,i), sigma2(i), nu) / sum(cv.test(j));
                end
            end
        end
        retval.CVErr = retval.CVErr / (mcreps * CV);
    end
    
    %% Run the path
    [beta, beta0, retval.sigma2, retval.negll] = tlasso_path(X, y, nu, tau2, [B0_ml;B_ml;sigma2_ml], maxiter);
    
    % Nominate the point with the smallest CV error
    if (~isnan(CV))
        [~, retval.IndexMinCV] = min(retval.CVErr);
    end
end
retval.tau2 = tau2;

%% Information criteria
% negll is (-log-likelihood), so these are the usual BIC/AICc divided by 2;
% minimisers are unaffected by the common factor.
retval.dof = sum(beta~=0, 1);
retval.BIC = retval.negll + retval.dof/2*log(n);
[~,retval.IndexMinBIC] = min(retval.BIC);
retval.AICc = retval.negll + retval.dof*n./(n-retval.dof-1);
% AICc is undefined/degenerate when the correction denominator vanishes
retval.AICc( (n-retval.dof-1) <= 1 ) = inf;
[~,retval.IndexMinAICc] = min(retval.AICc);

%% Re-scale coefficients
% Map coefficients back to the original X scale and recover the intercept
if (dostand)
    beta = bsxfun(@rdivide, beta, normX');
    beta0 = beta0 - muX*beta;
end

end

%% Standardise the covariates to have zero mean and x_i'x_i = 1
function [X, meanX, stdX, y, meany] = standardise(X, y)
%STANDARDISE Centre each column of X and scale it to unit Euclidean norm.
%   [X, meanX, stdX] = standardise(X) returns the standardised matrix along
%   with the column means and the scale factors that were divided out.
%   [..., y, meany] = standardise(X, y) additionally centres the target
%   vector y about its mean.

% Scale factor per column: population std times sqrt(n) equals the
% Euclidean norm of the centred column, giving x_i'x_i = 1 after division.
n     = size(X, 1);
meanX = mean(X);
stdX  = sqrt(n) * std(X, 1);

% Centre then rescale every column in one pass
X = bsxfun(@rdivide, bsxfun(@minus, X, meanX), stdX);

% Centre the target vector only when one was supplied
if (nargin > 1)
    meany = mean(y);
    y     = y - meany;
end

end