function [ A,D,Bt,X_hat,niter,sse,sse_diff,tot_sparseness,col_sparseness ] = ...
    nnmf_sca(X,k,dchoice,balance,achoice,asparse,schoice,maxiter)
% Nonnegative matrix factorization and/or Sparse component analysis.
%
% Approximates X by A*D*Bt using alternating least squares, with the
% requested constraints re-imposed on Bt and A after every update.
%
% Usage:
% [ A,D,Bt,X_hat,niter,sse,sse_diff,tot_sparseness,col_sparseness ] = ...
%    nnmf_sca(X,k,dchoice,balance,achoice,asparse,schoice,maxiter)
%
% Inputs:
% X is an m x n data matrix.
% k is the number of factors requested.
% Two choices for the D matrix (default 'ident'):
% dchoice = diag|ident
% Two choices for balancing the columns of A and the elements of the D
% diagonal (default 'nobalance'):
% balance = balance|nobalance
% Three choices for the A matrix (default 'nneg'):
% achoice = nneg|sparse|both
% asparse is the level of sparseness for the A matrix (e.g. 0.2 = 20% of
% zero elements in A). Default 0.2. The resulting per-column count
% round(asparse*m) is clamped to the range 0..m; a count of 0 disables
% the sparseness step.
% Two choices for sparseness (default 'random'):
% schoice = random|bycols
%   'random' zeroes the round(asparse*m)*k smallest-magnitude elements of
%            A taken over the whole matrix;
%   'bycols' zeroes the round(asparse*m) smallest-magnitude elements of
%            each column of A.
% maxiter is the maximum number of iterations allowed (default 500). At
% least one iteration is always performed.
% Option strings that match none of the listed values are ignored (the
% corresponding step is skipped).
%
% Outputs:
% A               m x k matrix.
% D               k x k matrix (identity unless dchoice = 'diag').
% Bt              k x n matrix, nonnegative, rows scaled to unit 2-norm.
% X_hat           A*D*Bt, the approximation of X.
% niter           number of iterations performed.
% sse             sum of squared elements of X - X_hat.
% sse_diff        absolute change of sse over the last iteration (Inf if
%                 only one iteration was performed).
% tot_sparseness  fraction of elements of A that are exactly zero.
% col_sparseness  1 x k vector, fraction of zeros in each column of A.

% Defaults:
if nargin < 8, maxiter = 500; end
if nargin < 7, schoice = 'random'; end
if nargin < 6, asparse = 0.2; end
if nargin < 5, achoice = 'nneg'; end
if nargin < 4, balance = 'nobalance'; end
if nargin < 3, dchoice = 'ident'; end

m = size(X,1);

% Sparseness: number of zero elements per column, kept inside 0..m so that
% it is always usable as an index into a sorted column of A.
nsparse = min(max(round(asparse*m),0),m);

% Initialize D matrix
D = eye(k);

% Initialize the A matrix and normalize the columns.
A = normalize_columns(rand(m,k));

% Start from Inf so the first convergence test cannot succeed by accident
% against an arbitrary initial value.
sse = Inf;
sse_diff = Inf;
niter = 0;
tol = eps^(1/3);

% Iteration
while true
    niter = niter + 1;

    % Calculate Bt and implement nonnegativity:
    % Bt = ((A*D)'*A*D)\(A*D)'*X
    % Same as:
    Bt = (A*D)\X;
    Bt(Bt<=0) = 0;

    % Normalize the rows (rows that are entirely zero are left untouched
    % instead of being turned into NaN).
    Bt = normalize_columns(Bt.').';

    % Calculate D:
    % Dt = (Bt*B)\(Bt*X'*Ap')
    % Same as:
    if strcmp(dchoice,'diag')
        D = diag(diag(pinv(A)*(X/Bt)));
    end

    % Calculate A:
    % At = (B*D)\X';
    % Same as:
    A = X/(D*Bt);

    switch achoice
        case 'nneg'     % Implement nonnegativity:
            A(A<=0) = 0;

        case 'sparse'   % Implement sparseness:
            A = zero_smallest(A,nsparse,schoice);

        case 'both'     % Implement nonnegativity and sparseness:
            A(A<=0) = 0;
            A = zero_smallest(A,nsparse,schoice);
    end

    % Normalize A: it provides a more balanced D for the option 'diag' at
    % the expense of a slightly poorer residual. Do not use with the
    % option 'ident'.
    if strcmp(balance,'balance')
        A = normalize_columns(A);
    end

    % Calculate residual:
    X_hat = A*D*Bt;
    R = X - X_hat;
    sse_old = sse;
    sse = R(:)'*R(:);
    sse_diff = abs(sse_old - sse);

    % Stop on convergence, on a NaN residual (the comparison is then
    % false), or once maxiter iterations have been performed.
    if ~(sse_diff > tol) || niter >= maxiter
        break
    end
end

% Calculate sparseness:
% 1. total:
tot_sparseness = sum(A(:)==0)/(m*k);
% 2. by columns:
col_sparseness = sum(A==0,1)/m;

end

function M = normalize_columns(M)
% Scale every column of M to unit 2-norm. Columns whose norm is not
% strictly positive (all zeros, or containing NaN) are left unchanged.
for c = 1:size(M,2)
    len = norm(M(:,c));
    if len > 0
        M(:,c) = M(:,c)/len;
    end
end
end

function A = zero_smallest(A,nsparse,schoice)
% Set the smallest-magnitude elements of A to zero.
% 'random': the nsparse*k smallest elements over the whole matrix;
% 'bycols': the nsparse smallest elements of each column.
% Elements tied with the cutoff magnitude are zeroed as well.
if nsparse < 1 || isempty(A)
    return
end
switch schoice
    case 'random'
        mags = sort(abs(A(:)),'ascend');
        cutoff = mags(nsparse*size(A,2));
        A(abs(A)<=cutoff) = 0;
    case 'bycols'
        for c = 1:size(A,2)
            mags = sort(abs(A(:,c)),'ascend');
            cutoff = mags(nsparse);
            A(abs(A(:,c))<=cutoff,c) = 0;
        end
end
end

