function [b, muBeta] = sample_beta(X, z, mvnrue, b0, sigma2, tau2, lambda2, delta2prod, omega2, XtX, weights, gprior, b, blocksample, blocksize, blockStart, blockEnd)
%SAMPLE_BETA draws the regression coefficients beta.
%   [b, muBeta] = sample_beta(...) draws beta from its conditional
%   posterior distribution. The draw itself is delegated to fastmvg_rue
%   or fastmvg_bhat; this function prepares their (scaled) inputs and,
%   when requested, visits the coefficients one block at a time.
%
%   The input arguments are:
%       X           - [n x p] data matrix
%       z           - [n x 1] target vector
%       mvnrue      - use Rue's algorithm? {true | false}
%                     (otherwise fastmvg_bhat is used)
%       b0          - [1 x 1] intercept parameter
%       sigma2      - [1 x 1] noise variance
%       tau2        - [1 x 1] global variance hyperparameter
%       lambda2     - [p x 1] local variance hyperparameters
%       delta2prod  - [p x 1] combined group variance hyperparameters
%       omega2      - [n x 1] hyperparameters; the rows of X and the
%                     target are divided by sqrt(omega2) on the weighted
%                     paths
%       XtX         - [p x p] pre-computed X'*X (if available); indexed
%                     as a cell array XtX{k}, one entry per block, when
%                     sampling in blocks
%       weights     - [1 x 1] with Rue's algorithm: true to rescale by
%                     sqrt(omega2) and pass no pre-computed cross-product,
%                     false to pass XtX ./ sigma2 instead. The fastmvg_bhat
%                     path always rescales, whatever this flag is.
%       gprior      - [1 x 1] true for gprior, otherwise false
%       b           - [p x 1] a sample from the posterior distribution
%                     (only read when sampling in blocks)
%       blocksample - [1 x 1] do we sample beta in blocks?
%       blocksize   - [k x 1] size of each beta block
%       blockStart  - [k x 1] start coordinate of each block
%       blockEnd    - [k x 1] end coordinate of each block
%
%   Return values:
%       b           - [p x 1] a sample from the posterior distribution
%       muBeta      - [p x 1] posterior mean
%
%   (c) Copyright Enes Makalic and Daniel F. Schmidt, 2016

noiseSd  = sqrt(sigma2);
priorVar = sigma2 * tau2 * lambda2 .* delta2prod;

%% Joint draw of all coefficients
if(~blocksample)
    resid = z - b0;

    if(mvnrue)
        if(weights)
            % Weighted problem: rescale rows, no usable pre-computed cross-product
            omegaSd = sqrt(omega2);
            Xw = bsxfun(@rdivide, X, omegaSd);
            [b, muBeta] = fastmvg_rue(Xw ./ noiseSd, [], resid ./ noiseSd ./ omegaSd, priorVar, XtX, gprior);
        else
            % Unweighted problem: reuse the pre-computed X'*X
            [b, muBeta] = fastmvg_rue(X ./ noiseSd, XtX ./ sigma2, resid ./ noiseSd, priorVar, XtX, gprior);
        end
    else
        % Bhattacharya et al. sampler, always on the rescaled problem
        omegaSd = sqrt(omega2);
        Xw = bsxfun(@rdivide, X, omegaSd);
        [b, muBeta] = fastmvg_bhat(Xw ./ noiseSd, resid ./ noiseSd ./ omegaSd, priorVar);
    end

    return
end

%% Block-wise draw
p       = length(lambda2);
nBlocks = length(blocksize);
cols    = 1:p;
muBeta  = zeros(p,1);

% Residual with every block's current contribution removed
resid = z - b0 - X*b;

if(mvnrue)
    if(weights)
        % Weighted problem: rescale each block's columns on the fly
        omegaSd = sqrt(omega2);
        for k = 1 : nBlocks
            sel = (cols >= blockStart(k)) & (cols <= blockEnd(k));

            % The per-block cross-product is only handed over for the g-prior
            if(gprior)
                G = XtX{k};
            else
                G = [];
            end

            resid = resid + X(:,sel)*b(sel);        % put this block back into the residual
            Xw = bsxfun(@rdivide, X(:,sel), omegaSd);
            [b(sel), muBeta(sel)] = fastmvg_rue(Xw ./ noiseSd, [], resid ./ noiseSd ./ omegaSd, priorVar(sel), G, gprior);
            resid = resid - X(:,sel)*b(sel);        % take it out again using the new draw
        end
    else
        % Unweighted problem: reuse the pre-computed per-block X'*X
        for k = 1 : nBlocks
            sel = (cols >= blockStart(k)) & (cols <= blockEnd(k));

            resid = resid + X(:,sel)*b(sel);        % put this block back into the residual
            [b(sel), muBeta(sel)] = fastmvg_rue(X(:,sel) ./ noiseSd, XtX{k} ./ sigma2, resid ./ noiseSd, priorVar(sel), XtX{k}, gprior);
            resid = resid - X(:,sel)*b(sel);        % take it out again using the new draw
        end
    end
else
    % Bhattacharya et al. sampler, always on the rescaled problem
    omegaSd = sqrt(omega2);
    for k = 1 : nBlocks
        sel = (cols >= blockStart(k)) & (cols <= blockEnd(k));

        resid = resid + X(:,sel)*b(sel);            % put this block back into the residual
        Xw = bsxfun(@rdivide, X(:,sel), omegaSd);
        [b(sel), muBeta(sel)] = fastmvg_bhat(Xw ./ noiseSd, resid ./ noiseSd ./ omegaSd, priorVar(sel));
        resid = resid - X(:,sel)*b(sel);            % take it out again using the new draw
    end
end

end