How to calculate sum(A .* (B * C), 'all') [Ed. actually sum(A .* log(B * C), 'all')] efficiently when A is sparse and B*C is full and large?
I have three matrices, A of size [J,I], B of size [J,K], C of size [K,I]. A is a sparse matrix with more than 90% zeros, while B and C are positive double matrices. The typical values are J=1e5, I=1e4, K=50.
The problem is that B*C creates a full matrix of size [J,I], which wastes memory because all I need are the elements (B * C)(find(A)). My current constraint is that I don't have enough memory for a full matrix of size [J,I]. Is there a smart way to evaluate this specific expression without that unnecessary memory usage?
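Written out entrywise, the target only needs (B*C) at the nonzero positions of A, one length-K dot product per nonzero:

```latex
S = \sum_{(j,i):\,A_{ji} \neq 0} A_{ji}\,\log\Bigl(\sum_{k=1}^{K} B_{jk}\,C_{ki}\Bigr)
```

For scale: with J = 1e5 and I = 1e4, a full [J,I] double matrix has 1e9 elements at 8 bytes each, i.e. 8e9 bytes (about 8 GB), while evaluating S entry by entry needs no storage beyond the inputs.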
I have tried converting B into a tall array, but an error like "tall arrays are not allowed to contain sparse data" appears when .* is evaluated. I also tried converting A into a tall array using tall(full(A)), but that's not reasonable because I would first have to restore A as a full matrix, and A is in fact not "tall" at all. Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage.
Thanks in advance!
Accepted Answer
James Tursa, 1 Nov 2021 (edited 9 Nov 2021)
Here is a straightforward mex implementation (single-threaded, no parallel sections) if you want to try it out. It computes the result directly in a loop, without large temporary memory allocations or data copying. You will need a supported C compiler installed. To compile it, use the following at the MATLAB command line:
mex sABC.c -R2018a
If you have an earlier version of MATLAB you can omit the -R2018a option.
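To run it, call it like this (the sizes below are only an example):
A = sprand(1000, 200, 0.1);   % sparse real double, M-by-N
B = rand(1000, 5);            % positive full double, M-by-K
C = rand(5, 200);             % positive full double, K-by-N
s = sABC(A, B, C)
% For small inputs, this should agree (up to rounding) with the direct, memory-hungry expression:
s_ref = full(sum(A .* log(B*C), 'all'))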
The C source code:
/* File sABC.c
 * sABC(A,B,C) returns sum(A.*log(B*C),'all')
 *
 * A = sparse real double MxN
 * B = full real double MxK
 * C = full real double KxN
 * (M and N correspond to J and I in the question)
 *
 * Programmer: James Tursa
 * Date: 10/31/2021
 */
#include "mex.h"
#include <math.h>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double dot, result = 0.0;
    mwSize M, K, N;
    mwSize j, k, nrow;
    double *A, *B, *C, *b, *c;
    mwIndex *Air, *Ajc;

    /* Argument checks */
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs.");
    }
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs.");
    }
    if( !mxIsDouble(prhs[0]) || !mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ) {
        mexErrMsgTxt("A must be real sparse double.");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]) ||
        mxIsSparse(prhs[1]) || mxIsSparse(prhs[2]) ||
        mxIsComplex(prhs[1]) || mxIsComplex(prhs[2]) ) {
        mexErrMsgTxt("B and C must be real full double matrices.");
    }
    if( mxGetNumberOfDimensions(prhs[1]) != 2 || mxGetNumberOfDimensions(prhs[2]) != 2 ) {
        mexErrMsgTxt("B and C must be 2D.");
    }
    M = mxGetM(prhs[0]);
    N = mxGetN(prhs[0]);
    K = mxGetN(prhs[1]);
    if( M != mxGetM(prhs[1]) ||
        N != mxGetN(prhs[2]) ||
        K != mxGetM(prhs[2]) ) {
        mexErrMsgTxt("Dimensions not compatible.");
    }

    /* Calculate result, simple loop no parallel code */
    Air = mxGetIr(prhs[0]);
    Ajc = mxGetJc(prhs[0]);
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(prhs[2]);
    /* Air and A advance together through all nonzeros of A, column by column */
    for( j=0; j<N; j++ ) {
        nrow = Ajc[j+1] - Ajc[j]; /* Number of row elements for this column */
        while( nrow-- ) {
            b = B + *Air++; /* B row pointer */
            c = C + j*K; /* C column pointer */
            dot = 0.0;
            for( k=0; k<K; k++ ) { /* dot product of B row and C column */
                dot += (*b) * (*c);
                b += M;
                c++;
            }
            result += *A++ * log(dot); /* Accumulate in result */
        }
    }
    plhs[0] = mxCreateDoubleScalar(result);
}
More Answers (4)
Matt J, 30 Oct 2021 (edited 30 Oct 2021)
You wrote: "Yes I'm still calculating sum(A .* (log(B*C)),'all')."
I would probably just break C (and the matching columns of A) into a small number of column chunks and loop, e.g.,
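nChunks = 10;                                       % I must be divisible by nChunks
Cr = reshape(C, K, I/nChunks, nChunks);             % Cr(:,:,n) is the n-th block of columns of C
Acell = mat2cell(A, J, ones(1,nChunks)*I/nChunks);  % matching column blocks of A
mysum = 0;
for n = 1:nChunks
    % Only one J-by-(I/nChunks) block of B*C exists at a time (800 MB for J=1e5, I=1e4, nChunks=10)
    mysum = mysum + full(sum(Acell{n} .* log(B*Cr(:,:,n)), 'all'));
end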
You wrote: "Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage."
I'm not sure why you conclude this is not efficient, but regardless, I don't think you're going to be able to avoid it (in the case where you have the log operation in there) unless there is some particular structure to the sparsity pattern in A that you haven't told us about.
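The log is what rules out a cheap rewrite. Without it (the expression in the original title), the sum regroups into small matrix products, so in MATLAB sum((B.'*A).*C,'all') should give the same value without any J-by-I temporary, since B.'*A is only K-by-I:

```latex
\sum_{j,i} A_{ji}\,(BC)_{ji}
  = \sum_{j,i,k} A_{ji}\,B_{jk}\,C_{ki}
  = \sum_{k,i} \bigl(B^{\mathsf T}A\bigr)_{ki}\,C_{ki}
```

With the log applied to each (BC)_{ji}, the sum over k sits inside a nonlinear function and can no longer be pulled out this way.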
It's important to remember that a matrix multiplication like B*C runs in highly optimized, multithreaded BLAS code, with a lot of parallel computation. When that much parallelism is involved, the raw number of operations isn't necessarily what dominates performance.
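A rough operation count with the question's typical sizes (K = 50, and nnz(A) at most 0.1·J·I = 1e8 since A is more than 90% zeros) puts numbers on this:

```latex
\text{full } BC:\quad 2JIK = 2 \cdot 10^{5} \cdot 10^{4} \cdot 50 = 10^{11}\ \text{flops}
\text{nonzeros only}:\quad 2K\,\mathrm{nnz}(A) \le 2 \cdot 50 \cdot 10^{8} = 10^{10}\ \text{flops}
```

So skipping the zeros reduces the arithmetic by a factor of J·I/nnz(A), at least 10 here. That factor need not translate into the same speedup, since the dense product runs in tuned BLAS while the per-nonzero dot products read rows of B with a large stride.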
James Tursa, 29 Oct 2021 (edited 29 Oct 2021)
You could use this loop to avoid the large memory allocation, but it will run slowly: every iteration extracts a value, a row of B and a column of C, and MATLAB copies that data in the background each time. A mex routine could avoid this extra copying if you really needed to recover the speed.
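[row,col,v] = find(A);   % positions and values of the nonzeros of A
mysum = 0;
for k=1:numel(v)
    % (B*C)(row(k),col(k)) as a single length-K dot product; log(...) gives sum(A.*log(B*C),'all')
    mysum = mysum + v(k)*log(B(row(k),:)*C(:,col(k)));
end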
James Tursa, 31 Oct 2021
Do you have a supported C/C++ compiler installed? The code for this would be pretty straightforward.