How to improve speed of calculating trace in a script?

Hi all,
In my project I have to calculate the trace of some matrix products. The following script demonstrates the problem:
clear; clc;
% number of total tests.
nTest = 500;
% part 1. generate nTest*2 random matrices.
nd = 1000;
nt = 100;
dis = cell(nTest, 2);
dis = cellfun(@(v) rand(nd, nt), dis, 'un', 0);
% part 2. perform truncated-SVD on each matrix,
% only leave nRem singular vectors and values.
nRem = 50;
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(dis{isvd, jsvd}, 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end
% part 3. for each SVD result, perform trace to obtain disTrans. disTrans is
% non-symmetric, thus jtr needs to start from 1.
disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTrans(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
I ran the profiler to find out which part is slowest. It turns out to be the trace calculation in part 3, because it is executed so many times (nTest^2). Unfortunately, in my project the number of trace calculations is also very large. Any ideas on how to speed up the trace calculation? The profiler output is shown here: [profiler screenshot]
Many thanks!
Xiaohan Du
Xiaohan Du on 4 Apr 2018
I got it. Opening the entry for testuiTujSortImprove shows: [profiler screenshot]
It seems that the operations inside trace, i.e.
u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}'
account for most of the time.


Accepted Answer

Christine Tobler
Christine Tobler on 3 Apr 2018
You can make the trace computation faster as follows. Currently, the input is two truncated SVDs, A1 = U1 * S1 * V1' and A2 = U2 * S2 * V2', and you are computing
trace(A1' * A2)
correct? After inserting A1 and A2, you can use the property of trace that trace(A*B) == trace(B*A) (note that in general trace(A*B*C) ~= trace(A*C*B); see Wikipedia).
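This means that you can rearrange the product cyclically:
trace(V1*S1'*U1'*U2*S2*V2') == trace( (V2'*V1) * S1 * (U1'*U2) * S2)
S1 is diagonal, so S1' == S1. Make sure that the parentheses are set like this: V2'*V1 and U1'*U2 are then the only products that involve the long dimensions, and all other operations are on nRem-by-nRem matrices.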
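As a rough sketch of what the rearrangement might look like for a single pair of truncated SVDs, the snippet below compares the original ordering with the parenthesised one. The sizes are reduced so the check runs quickly, and the names tOld, tNew and relErr are illustrative only, not taken from the original script.

```matlab
% Sketch only: small sizes for a quick check; variable names are illustrative.
nd = 200; nt = 40; nRem = 10;

% Two truncated SVDs, A1 ~ U1*S1*V1' and A2 ~ U2*S2*V2'.
[U1, S1, V1] = svd(rand(nd, nt), 0);
U1 = U1(:, 1:nRem); S1 = S1(1:nRem, 1:nRem); V1 = V1(:, 1:nRem);
[U2, S2, V2] = svd(rand(nd, nt), 0);
U2 = U2(:, 1:nRem); S2 = S2(1:nRem, 1:nRem); V2 = V2(:, 1:nRem);

% Original ordering: evaluated left to right, this builds an nt-by-nd
% intermediate and ends with an nt-by-nt product.
tOld = trace(V1 * S1' * U1' * U2 * S2 * V2');

% Rearranged ordering: the two parenthesised products reduce to
% nRem-by-nRem first, so every remaining product is nRem-by-nRem.
% S1 is diagonal, so S1' equals S1.
tNew = trace((V2' * V1) * S1 * (U1' * U2) * S2);

% The two results should agree up to floating-point round-off.
relErr = abs(tOld - tNew) / abs(tOld);
fprintf('relative difference: %.2e\n', relErr);
```

In the double loop of part 3, the same expression would replace the argument of trace, with u1{1}, u1{2}, u1{3} playing the roles of U1, S1, V1 and likewise for u2.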
By the way, for real matrices you can also rewrite trace(A'*B) as sum(sum(A.*B)), but I'm not sure whether this will give you a speedup in this case.
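A small numerical check of the elementwise identity, as a sketch under the assumption that all matrices are real (for complex matrices A' is the conjugate transpose, so the elementwise form would need conj(A)). The matrices N and M are random stand-ins for the nRem-by-nRem factors, not values from the original script.

```matlab
% Sketch only: real-valued random matrices; sizes are arbitrary.
A = rand(50, 30);
B = rand(50, 30);

tTrace = trace(A' * B);       % forms the full 30-by-30 product first
tSum   = sum(sum(A .* B));    % only needs the elementwise product

% The same identity applied to a product of two square factors:
% trace(N * M) equals sum(sum(N' .* M)) for real N and M.
% N and M are stand-ins for nRem-by-nRem factors such as V2'*V1
% and S1*(U1'*U2)*S2.
N = rand(10, 10);
M = rand(10, 10);
tTrace2 = trace(N * M);
tSum2   = sum(sum(N' .* M));

fprintf('difference 1: %.2e, difference 2: %.2e\n', ...
    abs(tTrace - tSum), abs(tTrace2 - tSum2));
```

Whether the elementwise form is actually faster for nRem-by-nRem factors would have to be measured, for example with tic and toc or the profiler.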
