image thumbnail
from SmartInv by Rouzaud Denis
Large sparse matrix inversion: returns the block-diagonal, block-tridiagonal or block-pentadiagonal elements of the inverse.

smartinv(varargin)
function Q = smartinv(varargin)
	% ---------------------------------------------------------------------
	%
	% Returns block mono/tri/penta diagonal elements of the inverse
	% of a square matrix.
	%
	% Useful for large sparse matrices whose LU decomposition is easy to
	% compute. Not the fastest way to compute an inverse matrix, but it
	% avoids the memory problem of storing a full matrix.
	%
	% Optional progressive diagonal computation display. Useful to quickly
	% observe an important modification on the diagonal.
	%
	% ---------------------------------------------------------------------
	%
	% Q = smartinv(N) returns the diagonal of N^-1. N is a square matrix n x n.
	%
	% Q = smartinv(N,blocksize) returns the block (blocksize x blocksize)
	%		diagonal elements of N^-1. blocksize must be a positive
	%		integer that divides n.
	%
	% Q = smartinv(N,blocksize,type) returns the block type-diagonal
	%		elements of N^-1; type can be mono, tri or penta. The type
	%		determines the inclusion of diagonal block computation,
	%		i.e. there is redundancy for tri (25%) and for penta (~45%).
	%		tri needs at least 2 blocks and penta at least 3.
	%
	% Q = smartinv(N,blocksize,position) returns the block diagonal elements
	%		of N^-1 at the given block positions (1 .. n/blocksize). Q is
	%		blocksize x blocksize if position is a single value or n x n
	%		if position is a vector (duplicate positions are computed
	%		once). If position is not set, the block values will be
	%		computed in an order so the curve shape will appear
	%		progressively (if graph displaymode is used).
	%
	% Q = smartinv(N,blocksize,...,'displaymode',displaymodevalue) specifies
	%		process graphical display: none (default), waitbar, graph.
	%		If position exists and is scalar, no display is set up at all.
	%		Graph mode will draw a curve for each diagonal element of block
	%		(i.e. a block represents an epoch, and each element of a block
	%		a variable to display).
	%
	% Q = smartinv(N,blocksize,...,'linelegend',linelegendvalue) where
	%		linelegendvalue is a cell array of strings specifying legend
	%		strings to display for each block element (used for graph
	%		display mode).
	%
	% Q = smartinv(N,blocksize,...,'linecolor',linecolorvalue) where
	%		linecolorvalue is a cell array of color strings or RGB values
	%		to display for each block element (used for graph display mode).
	%
	% Q = smartinv(N,blocksize,...,'linefunction',linefunctionname) where
	%		linefunctionname is a function applied to the displayed values
	%		in graph displaymode. It can be a function name or a function
	%		handle (the same function is used for each element in the
	%		block) or a cell array of blocksize function names/handles
	%		(an empty entry means the raw value is plotted).
	%
	% Q is returned as a sparse matrix. In graph mode, pressing Return
	% stops the computation and returns the blocks computed so far;
	% closing the display window aborts and returns Q = [].
	%
	% ---------------------------------------------------------------------
	%
	% Based on LU decomposition
	%		P . A . G = L . U   (G: column permutation, identity for full A)
	%		A^-1 = G . U^-1 . L^-1 . P
	%		Q = G . ( U \ ( L \ ( P . B ) ) )
	%		where B is zeros [n x (inclusion+1)*blocksize]
	%		and I [(inclusion+1)*blocksize]^2 at pind-th block row
	% type is used to define returned block inclusion
	%		if mono:  inclusion = 0
	%		if tri:   inclusion = 1
	%		if penta: inclusion = 2
	%
	% ---------------------------------------------------------------------
	%
	% Examples:
	%		Q = smartinv(N,9,'displaymode','graph','linelegend',linelegend,'linecolor',linecolor);
	%		will return mono block (9x9) diagonal of inverse matrix N^-1
	%		using graph display.
	%
	%		Q = smartinv(N,1,[1:900:size(N,1)],'displaymode','graph','linefunction','sqrt');
	%		will return 1 every 900 diagonal element of inverse matrix N^-1
	%		using graph display. The graph will display the square root of
	%		each element.
	%
	%		Q = smartinv(N,9,'penta','displaymode','waitbar');
	%		will return penta block (9x9) diagonal of inverse matrix N^-1
	%		displaying a waitbar.
	%
	%		Q = smartinv(N,1,1000);
	%		will return the 1000th diagonal element of the inverse matrix
	%		N^-1 (as a 1x1 sparse matrix).
	%
	% ---------------------------------------------------------------------
	%
	% Denis Rouzaud, TOPO, EPFL, 2009
	% Math. idea from Simone Deparis, IACS, EPFL
	%
	% ---------------------------------------------------------------------

	% graphical handles (shared with the nested callbacks below)
	ui = struct;

	% *******************************
	% CHECK ARGS
	% *******************************
	if nargin == 0
		error('No input specified.')
	end

	N = varargin{1};
	n = size(N,1);
	if ndims(N) ~= 2 || size(N,2) ~= n
		error('N must be a square matrix.')
	end
	nleftarg = nargin-1;

	if nargin > 1
		blocksize = varargin{2};
		if ~isnumeric(blocksize) || ~isscalar(blocksize) || blocksize < 1 || mod(blocksize,1) ~= 0
			error('blocksize must be a positive integer.')
		end
		if mod(n,blocksize)~=0
			error('N must be defined by an integer number of block of blocksize')
		end
		nleftarg = nleftarg - 1;
	else
		blocksize = 1;
	end
	% number of blocks along the diagonal
	nb = n/blocksize;

	displaymode = 'none';
	linelegend = cell(0,0);
	linecolor = cell(0,0);
	linefunction = cell(1,blocksize);

	% Name/value options follow blocksize. When a name is matched, its value
	% is skipped so that a value can never be mistaken for an option name.
	iar = 3;
	while iar <= nargin
		opt = varargin{iar};
		if ~ischar(opt) || ~any(strcmpi(opt,{'displaymode','linecolor','linelegend','linefunction'}))
			iar = iar + 1;
			continue
		end
		if iar == nargin
			error(['No value specified for option ' opt '.'])
		end
		val = varargin{iar+1};
		switch lower(opt)
			case 'displaymode'
				dpm = {'none','waitbar','graph'};
				idpm = find(strcmpi(val,dpm), 1);
				if isempty(idpm)
					error('Displaymode should be none, waitbar or graph.')
				end
				displaymode = dpm{idpm};
			case 'linecolor'
				if ~iscell(val) || length(val) ~= blocksize
					error('linecolor must be a cell array of strings of blocksize elements.');
				end
				linecolor = val;
			case 'linelegend'
				if ~iscell(val) || length(val) ~= blocksize
					error('linelegend must be a cell array of strings of blocksize elements.');
				end
				linelegend = val;
			case 'linefunction'
				linefunction = parse_linefunction(val);
		end
		nleftarg = nleftarg - 2;
		iar = iar + 2;
	end

	inclusion = 0;
	pind = [];
	if nleftarg > 0 && ischar(varargin{3})
		% block type: mono / tri / penta
		types = {'mono','tri','penta'};
		inclusion = find(strcmpi(varargin{3},types))-1;
		if isempty(inclusion)
			error('Type must be mono, tri or penta.')
		end
	elseif nleftarg > 0
		% explicit block positions
		pind = varargin{3};
		% isempty is tested first: max([]) would make the || operands invalid
		if isempty(pind) || ~isnumeric(pind) || max(pind(:)) > nb || min(pind(:)) < 1
			error('Specified position must be within N matrix dimension.')
		end
		if any(mod(pind(:),1)~=0)
			error('Position vector values must be integer')
		end
		if ~isvector(pind)
			error('Position must be a vector')
		end
		% sparse() sums duplicate subscripts, so each position is kept once
		% (first occurrence, original order preserved)
		[~,ifirst] = unique(pind,'first');
		pind = pind(sort(ifirst));
		if length(pind) == 1
			displaymode = 'none';
		end
	end

	if isempty(pind)
		% one computed position covers inclusion+1 consecutive blocks
		if nb < inclusion+1
			error('N must contain at least %d block(s) of blocksize for this type.', inclusion+1)
		end
		% blocks to be computed, ordered so the curve appears progressively
		pind = setpind(1,nb-inclusion);
	end


	% *******************************
	% INITIALISE PROCESS
	% *******************************

	% LU decomposition; the column-permuted form is only defined for sparse N
	if issparse(N)
		[L,U,P,G] = lu(N);
	else
		[L,U,P] = lu(N);
		G = speye(n);
	end

	% private timer so a caller's tic/toc is not disturbed
	t0 = tic;
	npind = length(pind);
	if strcmp(displaymode,'waitbar')
		ui.h = waitbar(0,'Inverting matrix ...');
	elseif strcmp(displaymode,'graph')
		scrsz = get(0,'ScreenSize');
		ui.h = figure('Visible','on','Position',[10 50 scrsz(3)-200 scrsz(4)-200],'MenuBar','none','Toolbar','figure',...
		   'Name','SmartInv','NumberTitle','off',...
		   'KeyPressFcn','','WindowButtonDownFcn','',...
		   'WindowButtonUpFcn','','WindowButtonMotionFcn','',...
		   'CloseRequestFcn',@close_window);
		% XLim must be strictly increasing, even for a single block
		ui.ax = axes('Units','normalized','Position',[.02 .15 .96 .83 ],'XLim',[1 max(nb,2)]);
		uipanel('Units','normalized','Position',[.06 .025 .4 .08]);
		ui.step    = uicontrol('Style','text','String','Step: 1 / x'    ,'Units','normalized','Position',[.08 .03 .15 .03],'FontUnits','normalized','FontSize',.8,'HorizontalAlignment','left');
		ui.process = uicontrol('Style','text','String','Process: 0 %'   ,'Units','normalized','Position',[.3 .03 .15 .03],'FontUnits','normalized','FontSize',.8,'HorizontalAlignment','left');
		ui.rtime   = uicontrol('Style','text','String','Time remaining:','Units','normalized','Position',[.08 .07 .3 .03] ,'FontUnits','normalized','FontSize',.8,'HorizontalAlignment','left');
		ui.pause   = uicontrol('Style','pushbutton','String','Pause','Units','normalized','Position',[.55 .025 .16 .08] ,'FontUnits','normalized','FontSize',.3,'Callback',@pause_Callback);
		ui.return  = uicontrol('Style','pushbutton','String','Return','Units','normalized','Position',[.8 .025 .16 .08] ,'FontUnits','normalized','FontSize',.3,'Callback',@return_Callback);
		hold on
		ldat = struct;
		xdat = zeros(1,nb);
		for bi = 1:blocksize
			ui.line(bi) = plot(1,1,'LineStyle','-','LineWidth',2);
			if ~isempty(linecolor)
				set(ui.line(bi),'Color',linecolor{bi})
			end
			ldat(bi).y = zeros(1,nb);
		end
		if ~isempty(linelegend)
			legend(ui.ax,linelegend)
		end
	end

	% *******************************
	% MAIN LOOP
	% *******************************
	% elements to skip in insblock (already stored by the previous position)
	skip = inclusion*blocksize;
	% width of B and length of diagonal part of B
	w = (inclusion+1)*blocksize;
	% sparse triplets; w^2 entries per position is an upper bound
	nel = npind * w^2;
	li = zeros(nel,1);
	co = zeros(nel,1);
	va = zeros(nel,1);
	mi = 0;
	% loop stop flag, set by the Return button callback
	pstop = 0;

	for i=1:npind
		% length of zeros elements in first part of B
		z0 = (pind(i)-1)*blocksize;
		% length of zeros elements in third part of B
		z1 = n-w-z0;
		% B: identity of size w placed at block row pind(i)
		B = [sparse(z0,w); speye(w); sparse(z1,w)];

		% computes block-width inverse columns
		Q0 = G * ( U \ ( L \ ( P * B ) ) );

		% offset in Q: blocks keep their place, unless a single one is asked
		if npind > 1,	q0 = z0;   else q0 = 0;   end
		% extracts specified block
		insblock = Q0(z0+1:z0+w,1:w);

		for c = 1:w
			% For tri/penta, the leading skip x skip part of insblock was
			% already stored by the previous position, except for the
			% first position which is stored in full.
			if pind(i) == 1 || inclusion == 0 || c > skip
				rows = 1:w;
			else
				rows = skip+1:w;
			end
			nr = length(rows);
			li(mi+1:mi+nr) = q0 + rows;
			co(mi+1:mi+nr) = q0 + c;
			va(mi+1:mi+nr) = insblock(rows,c);
			mi = mi + nr;
		end

		% Display update
		if ~strcmp(displaymode,'none')
			if ~ishandle(ui.h)
				% the display window was closed by the user: abort
				Q = [];
				return
			end
			rtimestr = remain_time(i,npind);
		end
		if strcmp(displaymode,'waitbar')
			msg = ['Inverting matrix ... ' sprintf('%3.1f',100*i/npind) '%. Remaining time: ' rtimestr];
			waitbar(i/npind,ui.h,msg)
		elseif strcmp(displaymode,'graph')
			xdat(pind(i)) = pind(i);
			touse = find(xdat ~= 0);
			for bi = 1:blocksize
				ydat = insblock(bi,bi);
				if ~isempty(linefunction{bi})
					ydat = linefunction{bi}(ydat);
				end
				ldat(bi).y(pind(i)) = ydat;
				set(ui.line(bi),'XData',xdat(touse),'YData',ldat(bi).y(touse))
			end
			set(ui.rtime  ,'String',['Remaining time: ' rtimestr]);
			set(ui.step   ,'String',['Step: ' num2str(i) ' / ' num2str(npind)]);
			set(ui.process,'String',['Process: ' num2str(100*i/npind,'%.1f') ' %']);
			%need a pause to access figure tools!
			pause(10^-6)
		end
		if pstop, break, end
	end

	% Only the first mi triplets are filled (fewer after an early stop), and
	% Q gets its documented size rather than the size of the largest index.
	if npind > 1, qn = n; else qn = w; end
	Q = sparse(li(1:mi),co(1:mi),va(1:mi),qn,qn);

	if isfield(ui,'pause') && ishandle(ui.pause)
		set(ui.pause,'String','Done','Enable','off')
		set(ui.return,'String','Done','Enable','off')
	end

	if strcmp(displaymode,'waitbar')
		close_window()
	end

	function fns = parse_linefunction(spec)
		% Normalizes the 'linefunction' option into a 1 x blocksize cell of
		% function handles; empty entries mean "plot the raw value".
		if ischar(spec) || isa(spec,'function_handle')
			fns = cell(1,blocksize);
			fns(:) = {spec};
		elseif iscell(spec) && length(spec) == blocksize
			fns = spec;
		else
			error('linefunction must be a cell array of strings of blocksize elements.');
		end
		for fi = 1:length(fns)
			if ischar(fns{fi}) && ~isempty(fns{fi})
				fns{fi} = str2func(fns{fi});
			elseif ~isempty(fns{fi}) && ~isa(fns{fi},'function_handle')
				error('linefunction elements must be function names or function handles.');
			end
		end
	end

	function p = setpind(v0,v1)
		% Returns each integer of v0..v1 once, ordered coarse-to-fine:
		% v0, v1, the midpoint, the quarter points, ... so that a curve
		% drawn progressively shows its overall shape early. v0 is always
		% first. Requires v1 >= v0 (checked by the caller).
		dv = v1 - v0;
		p = zeros(1,dv+1);
		p(1) = v0;
		p(2) = v1;
		depth = 1;
		nxt = 3;
		while 1
			depth = depth*2;
			% odd fractions of the current subdivision level
			v = v0 + round((1:2:depth)/depth*dv);
			p(nxt:nxt+length(v)-1) = v;
			nxt = nxt + length(v);
			% depth+1 values stored so far: once depth > dv the sampling
			% step dv/depth is < 1, so every integer in v0..v1 was hit
			if nxt > dv+1
				[~,iuniq] = unique(p,'first');
				p = p(sort(iuniq));
				break
			end
		end
	end

	function timestr = remain_time(step,nstep)
		% Returns the remaining-time string of a process of nstep equal
		% steps that has just completed step, extrapolated from the time
		% elapsed since t0. Only the two most significant units are shown.
		% -------------------------------------------------
		% Denis Rouzaud, TOPO, EPFL, 2009
		% -------------------------------------------------
		elapsed = toc(t0);
		% whole seconds left (rounded up, so a running process shows >= 1 sec)
		tleft = ceil(elapsed / step * (nstep-step));

		sm   = 60;
		sh   = 60*sm;
		sday = 24*sh;

		% d h m s, each unit reduced modulo the next larger one
		tv = [floor(tleft/sday), floor(mod(tleft,sday)/sh), floor(mod(tleft,sh)/sm), mod(tleft,sm)];

		%use only 2 values max! i.e. sec are not really interesting if more than 2 days! ;)
		ix = find(tv~=0, 1, 'first');
		istr = {' day(s)', ' h', ' min', ' sec'};

		timestr = [];
		if ~isempty(ix)
			if ix < 4
				timestr = [sprintf('%3.0f', tv(ix)) istr{ix}];
			end
			ixe = min(4,ix+1);
			timestr = [timestr sprintf('%3.0f', tv(ixe)) istr{ixe}];
		end
	end

	function close_window(varargin)
		% Deletes the waitbar / graph window if it still exists.
		if ishandle(ui.h)
			delete(ui.h)
		end
	end

	function return_Callback(varargin)
		% Return button: request the main loop to stop after the current block.
		pstop = 1;
		set(ui.return,'String','Stopped','Enable','off')
	end

	function pause_Callback(varargin)
		% Pause button: blocks until a key is pressed, then re-enables buttons.
		set(ui.pause,'String','Press Key to continue','Enable','off')
		set(ui.return,'Enable','off')
		pause
		if ishandle(ui.pause)
			set(ui.pause,'String','Pause','Enable','on')
			set(ui.return,'Enable','on')
		end
	end
end

Contact us at files@mathworks.com