Code covered by the BSD License

# CoSaMP and OMP for sparse recovery

### Stephen Becker (view profile)

01 Aug 2011 (Updated )

Orthogonal Matching Pursuit (OMP) and Compressive Sampling Matched Pursuit (CoSaMP).

CoSaMP( A, b, k, errFcn, opts )
```function [x,r,normR,residHist, errHist] = CoSaMP( A, b, k, errFcn, opts )
% x = CoSaMP( A, b, k )
%   uses the Compressive Sampling Matched Pursuit (CoSaMP) algorithm
%   (see Needell and Tropp's 2008 paper http://arxiv.org/abs/0803.2392 )
%   to estimate the solution to the equation
%       b = A*x     (or b = A*x + noise )
%   where there is prior information that x is sparse.
%
%   "A" may be a matrix, or it may be a cell array {Af,At}
%   where Af and At are function handles that compute the forward and transpose
%   multiplies, respectively. If the function handles are provided,
%   the the least-squares step is performed using LSQR (use could also use
%   CG on the normal equations, or other special variants).
%
% [x,r,normR,residHist,errHist] = CoSaMP( A, b, k, errFcn, opts )
%   is the full version.
% Outputs:
%   'x' is the k-sparse estimate of the unknown signal
%   'r' is the residual b - A*x
%   'normR' = norm(r)
%   'residHist'     is a vector with normR from every iteration
%   'errHist'       is a vector with the outout of errFcn from every iteration
%
% Inputs:
%   'A'     is the measurement matrix
%   'b'     is the vector of observations
%   'k'     is the estimate of the sparsity (you may wish to purposefully
%              over- or under-estimate the sparsity, depending on noise)
%              N.B. k < size(A,1) is necessary, otherwise we cannot
%                   solve the internal least-squares problem uniquely.
%   'errFcn'    (optional; set to [] to ignore) is a function handle
%              which will be used to calculate the error; the output
%              should be a scalar
%   'opts'  is a structure with more options, including:
%       .printEvery = is an integer which controls how often output is printed
%       .maxiter    = maximum number of iterations
%       .normTol    = desired size of norm(residual). This is also
%                       used to detect convergence when the residual
%                       has stopped decreasing in norm
%       .LSQR_tol   = when "A" is a set of function handles, this controls
%                       the tolerance in the iterative solver. For compatibility
%                       with older versions, the fieldname "cg_tol" is also OK.
%       .LSQR_maxit = maximum number of steps in the iterative solver. For compatibility
%                       with older versions, the fieldname "cg_maxit" is also OK.
%                       N.B. "CG" stands for conjugate gradient, but this code
%                            actually uses the LSQR solver.
%       .HSS        = if true, use the variant of CoSaMP that is similar to
%                       HHS Pursuit (see appendix A.2 of Needell/Tropp paper). Recommended.
%       .two_solves = if true, uses the variant of CoSaMP that re-solves
%                       on the support of size 'k' at every iteration
%                       (see appendix). This can be used with or without HSS variant.
%       .addK       = the number of new entries to add each time. By default
%                       this is 2*k (this was what is used in the paper).
%                       If you experience divergence, try reducing this.
%                       We recommend trying 1*k for most problems.
%                       correlations, the CoSaMP method specifies that you do
%                       not add the atoms if the correlation is exactly zero.
%                       In practice, it is better to not add the atoms of their
%                       correlation is nearly zero. "support_tol" controls
%                       what this 'nearly zero' number is, e.g. 1e-10.
%
%       Note that these field names are case sensitive!
%
%
% Stephen Becker, Aug 1 2011  srbecker@alumni.caltech.edu
%   Updated Dec 12 2012

if nargin < 5, opts = []; end
if ~isempty(opts) && ~isstruct(opts)
error('"opts" must be a structure');
end

function out = setOpts( field, default )
if ~isfield( opts, field )
opts.(field)    = default;
end
out = opts.(field);
end

printEvery  = setOpts( 'printEvery', 1000 );
maxiter     = setOpts( 'maxiter', 1000 );
normTol     = setOpts( 'normTol', 1e-10 );
cg_tol      = setOpts( 'cg_tol', 1e-6 );
cg_maxit    = setOpts( 'cg_maxit', 20 );
% Allow some synonyms
cg_tol      = setOpts( 'LSQR_tol', cg_tol );
cg_maxit    = setOpts( 'LSQR_maxit', cg_maxit );
HSS         = setOpts( 'HSS', false );
TWO_SOLVES  = setOpts( 'two_solves', false );
support_tol = setOpts( 'support_tol', 1e-10 );

if nargin < 5 || isempty(printEvery)
printEvery  = round(k,maxiter);
end

if nargin < 4
errFcn = [];
elseif ~isempty(errFcn) && ~isa(errFcn,'function_handle')
error('errFcn input must be a function handle (or leave the input empty)');
end

if iscell(A)
LARGESCALE  = true;
Af  = A{1};
At  = A{2};     % we don't really need this...
else
LARGESCALE  = false;
Af  = @(x) A*x;
At  = @(x) A'*x;
end

% -- Intitialize --
% start at x = 0, so r = b - A*x = b
r           = b;
Ar          = At(r);
N           = size(Ar,1);       % number of atoms
M           = size(r,1);        % size of atoms
if k > M/3
error('K cannot be larger than the dimension of the atoms');
end
x           = zeros(N,1);
ind_k       = [];

% indx_set    = zeros(k,1);  % created on-the-fly
% A_T         = zeros(M,k);
residHist   = zeros(k,1);
errHist     = zeros(k,1);

fprintf('Iter,   |T|,  Resid');
if ~isempty(errFcn)
fprintf(',   Error');
end
if LARGESCALE
fprintf(',   LSQR iterations and norm(residual)');
end
fprintf('\n');

if LARGESCALE

%     if exist( 'lsqr_wrapper','file')
%         LSQR_ALG    = @lsqr_wrapper;
%         % This is Stephen's wrapper (designed to imitate Matlab's
%         %   syntax ) to the version of LSQR that you can get
%         % at http://www.stanford.edu/group/SOL/software/lsqr/matlab/lsqr.m
%         % This version of LSQR is better than Matlab's implementation.
%     elseif exist( 'lsqr','file' )
%         LSQR_ALG    = @lsqr; % Matlab's version
%     else
%         disp('http://www.stanford.edu/group/SOL/software/lsqr/matlab/lsqr.m');
%         error('Need to have working copy of LSQR');
%     end

% Unfortunately, the Stanford group's version doesn't make it easy
%   to provide a starting value, so we will stick with Matlab's version.

LSQR_ALG    = @(RHS,Afcn,x0) lsqr(Afcn,RHS,cg_tol,cg_maxit,[],[],x0 );
end

for kk = 1:maxiter

% -- Step 1: find new index and atom to add
y_sort      = sort( abs(Ar),'descend');
cutoff      = max( cutoff, support_tol );
ind_new     = find( abs(Ar) >= cutoff );

% -- Merge:
T    = union( ind_new, ind_k );

% -- Step 2: update residual
if HSS
RHS     = r; % where r = b - A*x, so we'll need to add in "x" later
x_warmstart     = zeros(length(T),1);
else
RHS     = b;
x_warmstart     = x(T);
end

% -- solve for x on the suppor set "T"
if LARGESCALE
% Could do CG on the normal equations...
% Or, use LSQR:
%         x_T = LSQR_ALG(Afcn,RHS,cg_tol,cg_maxit);

% use an initial guess to warm-start the solver
Afcn = @(x,mode) partialA( N, T, Af, At, x, mode );
[x_T,flag,relres,CGiter] = LSQR_ALG(RHS,Afcn,x_warmstart);

else
x_T = A(:,T)\RHS;   % more efficient; equivalent to pinv when underdetermined.
%         x_T     = pinv( A(:,T) )*RHS;
end

if HSS
% HSS variation of CoSaMP
x_new       = zeros(N,1);
x_new(T)    = x_T;
x           = x + x_new;    % this is the key extra step in HSS
cutoff      = findCutoff( x, k );
x           = x.*( abs(x) >= cutoff );
ind_k       = find(x);

if TWO_SOLVES
if LARGESCALE
Afcn = @(x,mode) partialA( N, ind_k, Af, At, x, mode );
[x_T2,flag,relres,CGiter] = LSQR_ALG(b,Afcn,x(ind_k)); % not using "r", just using "b"
else
x_T2     = A(:,ind_k)\b;
end
x( ind_k )  = x_T2;
end

% update r
r_old   = r;
r       = b - Af(x);
else
% Standard CoSaMP
% Note: this is implemented *slightly* more efficiently
%   that the HSS variation

% Prune x to keep only "k" entries:
cutoff  = findCutoff(x_T, k);
Tk      = find( abs(x_T) >= cutoff );
% This is assuming there are no ties. If there are,
%    from a practical standpoint, it probably doesn't
%    matter much what you do. So out of laziness, we don't worry about it.
ind_k   = T(Tk);
x       = 0*x;
x( ind_k ) = x_T( Tk );

if TWO_SOLVES
if LARGESCALE
Afcn = @(x,mode) partialA( N, ind_k, Af, At, x, mode );
[x_T2,flag,relres,CGiter] = LSQR_ALG(b,Afcn,x(ind_k));
else
x_T2     = A(:,ind_k)\b;
end
x( ind_k )  = x_T2;
end

% Update x and r
r_old   = r;
if LARGESCALE
r   = b - Af( x );
else
% don't do a full matrix-vector multiply, just use necessary columns
r   = b - A(:,ind_k)*x_T( Tk );
end
end

% -- Print some info --
PRINT   = ( ~mod( kk, printEvery ) || kk == maxiter );
normR   = norm(r);
STOP    = false;
if normR < normTol || norm( r - r_old ) < normTol
STOP    = true;
PRINT   = true;
end

if ~isempty(errFcn)
er  = errFcn(x);
errHist(kk)     = er;
end
if PRINT
fprintf('%4d, %4d, %.2e', kk, length(T), normR);
if ~isempty(errFcn)
fprintf(', %.2e',er);
end
if LARGESCALE
fprintf(',  %d,  %.2e', CGiter,relres);
end
fprintf('\n');
end

residHist(kk)   = normR;

if STOP
disp('Reached stopping criteria');
break;
end

if kk < maxiter
Ar  = At(r); % prepare for next round
end

end

end % -- end of main function

function tau = findCutoff( x, k )
% finds the appropriate cutoff such that after hard-thresholding,
% "x" will be k-sparse
x   = sort( abs(x),'descend');
if k > length(x)
tau = x(end)*.999;
else
tau  = x(k);
end
end

% for LSQR, PCG, etc.
function z = partialA( N, T, Af, At, x, mode )
% Multiplies A_T(x) or A_T'(x) (the transpose)
%   where _T denotes restriction to the set "T"

switch lower(mode)
case 'notransp'
y   = zeros(N,1);
y(T)    = x;
z       = Af(y);

case 'transp'
y   = At(x);
% truncate:
z   = y(T);
end
end % -- end of subfuction

```