Code covered by the BSD License  

Highlights from
General Least Squares Regression

image thumbnail

General Least Squares Regression

by

 

02 Feb 2009

Multidimensional, multivariable least-squares regression

[fm,C9]=mreg(x,y,m)
function [fm,C9]=mreg(x,y,m)
% MREG Multi dimensional multivariable least squares regression
%
% [fm,C9]=mreg(x,y,m) fits a polynomial of total degree m in the k
% variables x1,...,xk to n data points by least squares. The normal
% equations are derived symbolically (Symbolic Math Toolbox), evaluated on
% the data and solved with SOLVE. The points do not need to be equally
% spaced.
%
% Input:
%
% x: k by n array, where k is the number of variables and n is the number
% of data points. Row j holds the values of variable xj at the n points.
%
% y: 1 by n row vector of observed values; y(i) belongs to the point x(:,i).
%
% m: order (maximum total degree) of the approximating polynomial, a
% non-negative integer. The fit has nchoosek(m+k,k) coefficients. If they
% cannot be determined (e.g. too few points), an error asks for more points.
%
% Output:
%
% fm: symbolic expression for the fitted polynomial in x1,...,xk, with
% coefficients shown to 7 significant digits (VPA).
%
% C9: cell array holding, as MATLAB text, the normal equations that were
% evaluated and solved in order to obtain the coefficients.
%
% Side effect: for k = 1 or k = 2 the data and the fit are plotted.
%
% Example 1:
%
% x1=[0.54849       2.0229       1.8291       3.1395       4.6643      0.96014,...
% 3.4813        2.627       4.3057       1.9673      3.7063       1.7386       2.9305,...
% 0.22227];
%
% x2=[ 0.31796       2.2419       3.8175       3.8599       4.863    0.69437 ,...
% 0.4691       2.6517       2.4243       3.3572       2.6003     0.74999       1.3107 ,...
% 3.7747];
%
% y=[       4.7835       3.8262       2.9801       4.5111       4.5333,...
% 4.7531       4.4505       6.6162       1.4625        4.965       1.6536,...
% 0.83577       4.2511        3.151];
%
% [y,b]=mreg([x1;x2],y,3)
%
% y =  
% 8.531340+16.75951*x2-5.988564*x2^2+.3241113*x2^3-18.96994*x1-.5539259*x1*x2+1.189775*x1*x2^2+6.907194*x1^2-1.018538*x1^2*x2-.5316869*x1^3
%  
% b = 
%     [1x335 char]
%     [1x384 char]
%     [1x399 char]
%     [1x399 char]
%     [1x384 char]
%     [1x433 char]
%     [1x448 char]
%     [1x399 char]
%     [1x448 char]
%     [1x399 char]
% 
% (also plots the answer in 3D)
%
%
% Example 2:
% 
% x1=[6 7 8 5 3 2 1 2 5 8];
% 
% x2=[9 7 0 8 5 6 4 3 2 10];
% 
% y=[9 7 6 4 3 2 5 6 8 1];
% 
% mreg([x1;x2],y,1)
%  
% ans =
% 5.542073-.2829251*x2+.2310048*x1
%  
%  
% mreg([x1;x2],y,2)
%  
% ans =
% 3.797949-.6754073*x2-.2393447e-1*x2^2+2.256645*x1+.8248780e-1*x1*x2-.2478672*x1^2
%  
%  
% mreg([x1;x2],y,3)
%  
% ans =
% -24.62814+33.15517*x2-9.837215*x2^2+.8316257*x2^3-17.31119*x1+7.450905*x1*x2-.6438069*x1*x2^2+.3614982*x1^2-.4148167*x1^2*x2+.2851207*x1^3
% 
%
% Example 3:
% 
% x1=[6 7 8 5 3 2 1 2 5 8 5 4];
% 
% x2=[9 7 0 8 5 6 4 3 2 10 5 7];
% 
% x3=[8 9 7 6 5 4 3  5 6 7 5 3]; 
% 
% y=[9 7 6 4 3 2 5 6 8 1 8 0];
% 
% mreg([x1;x2;x3],y,2)
%  
% ans = 
% 32.48905-7.049319*x3+.1384845*x3^2-6.246954*x2+1.259510*x2*x3-.4498889e-1*x2^2+5.668323*x1-.4038119*x1*x3-.3257156*x1*x2-.1031725*x1^2
%
% numandina@gmail.com

% The number of points n is assigned further down and read by the nested
% function SSS. Nested functions share variables with this function, so n
% is deliberately not declared GLOBAL. A GLOBAL declaration would also
% overwrite any global n in the caller's workspace.
if size(x,2)~=length(y)
	error('MSG:ID3','lengths of matrices x and y need to be the same')
end
if size(y,1)>1
	error('MSG:ID2','y matrix needs to be a vector')
end
% The order is used by nchoosek to build the monomial basis, so it must be
% a non-negative integer scalar.
if ~isnumeric(m) || ~isscalar(m) || m<0 || m~=fix(m)
	error('MSG:ID7','order m needs to be a non-negative integer')
end
	function [vct,p,x]=eiy(var,m)
		% EIY Monomial basis of total degree <= m in var variables.
		%
		% vct: exponent table, one row per monomial and one column per
		%      variable. Row r gives the powers of x1..xvar, and the powers
		%      in each row sum to at most m.
		% p:   column of the corresponding symbolic monomials,
		%      p(r) = x1^vct(r,1) * ... * xvar^vct(r,var).
		% x:   row of the symbolic variables x1..xvar.
		%
		% credits to Roger Stafford for the exponent-table line below
		vct = diff([zeros(nchoosek(m+var,var),1),nchoosek(1:var+m,var)],1,2)-1;
		x=sym(zeros(1,var));
		for iv=1:var
			x(iv)=sym(['x',num2str(iv)]);
		end
		% Multiply in one variable at a time: after step iv, p holds the
		% product of x1^vct(:,1) ... xiv^vct(:,iv) for every row.
		p=1;
		for iv=1:var
			p=p.*x(iv).^vct(:,iv);
		end
	end
% --- Monomial basis and symbolic least-squares objective ---
% b is the column of monomials of total degree <= m; x2p holds x1..xk.
var=size(x,1);
[a,b,x2p]=eiy(var,m);
n=length(y);
nc=length(b);
% One unknown coefficient per monomial, named a000, a001, ...
% sprintf('a%03d',k) produces exactly the names 'a00'/'a0'/'a' + num2str(k)
% would. The fixed width means no name is a prefix of another (below 1001
% coefficients), which the strrep substitutions further down rely on.
c=sym(zeros(1,nc));
for h=1:nc
	c(h)=sym(sprintf('a%03d',h-1));
end
% Squared residual of one generic point: (sum_j b_j*a_j - y)^2.
s=-sym('y');
for kj=1:nc
	s=s+b(kj).*c(kj);
end
s=s^2;
% Normal equations d/da_j. The factor SSS marks every term as "summed over
% the data points"; strb2 moves that sum inside each term.
C9=cell(nc,1);
for ks=1:nc
	C9{ks}=diff(s,c(ks))*sym('SSS');
end
ce=sym(zeros(1,nc));
for g=1:nc
	C9{g}=sym2str(expand(C9{g}));
	C9{g}=strb2(C9{g});
	% Make the text element-wise so it evaluates over whole data rows.
	C9{g}=strrep(C9{g},'*','.*');
	C9{g}=strrep(C9{g},'^','.^');
	% NOTE(review): SSS(n) evaluates to n*n through the scalar branch of SSS,
	% whereas SSS() would give n. Confirm that this replacement is unreachable
	% (every term appears to carry a numeric factor such as 2).
	C9{g}=strrep(C9{g},'SSS()','SSS(n)');
	C9{g}=strrep(C9{g},'..','.');
	% Replace the symbol xj with the j-th data row. Go from the highest index
	% down, so that replacing 'x1' cannot corrupt 'x10', 'x11', ...
	for jh=length(x2p):-1:1
		C9{g}=strrep(C9{g},['x',num2str(jh)],['x(',num2str(jh),',:)']);
	end
	% Keep the coefficients symbolic so that SOLVE can find them.
	for y3=1:nc
		nm=sprintf('a%03d',y3-1);
		C9{g}=strrep(C9{g},nm,['sym(''',nm,''')']);
	end
	% The text was generated above (not supplied by the caller). The names
	% x, y, n and SSS in it resolve in this workspace.
	ce(g)=eval(C9{g});
end
% Solve all normal equations together: solve(ce(1),ce(2),...).
stt='ce(1)';
for kl=2:nc
	stt=[stt,',ce(',num2str(kl),')'];
end
ce2=eval(['solve(',stt,')']);
% With a single unknown (m = 0), SOLVE returns the solution itself instead
% of a struct of named solutions. Wrap it so both cases are read alike.
if ~isstruct(ce2)
	ce2=struct('a000',ce2);
end
ce5=zeros(nc,1);
try
	for lll=1:nc
		ce5(lll)=ce2.(sprintf('a%03d',lll-1));
	end
catch
	% A missing, empty or non-numeric solution means the coefficients are
	% not determined by the data.
	error('MSG:ID6','Increase number of points please')
end
% Fitted polynomial; coefficients are shown to 7 significant digits.
fm=sum(b.*ce5);
fm=vpa(fm,7);
	function c=SSS(a)
		% SSS Summation over the data points used in the normal equations.
		% n (assigned in MREG as length(y)) is shared with this function.
		%   SSS()   returns n
		%   SSS(v)  returns v*n for a scalar v (a constant summed n times)
		%   SSS(v)  returns sum(v) for a vector v of per-point values
		if nargin==0
			c=n;
		elseif isscalar(a)
			c=a*n;
		else
			c=sum(a);
		end
	end
	function ll=strb2(ll)
		% STRB2 Move the SSS summation marker inside every term of an
		% expanded normal equation.
		%
		% ll is the text of an expanded sum in which every term contains the
		% bare token SSS, e.g. '2*SSS*a000+4*SSS*a001'. Each term becomes
		% SSS(<rest of term>)*<coefficient>. The coefficient is the
		% 4-character token that starts with 'a' (a000..a999). A '+' is then
		% put in front of every SSS, so the example becomes
		% '+SSS(2)*a000+SSS(+4)*a001'.
		%
		% NOTE: the scan index is named pos, not c. A variable named c here
		% would be shared with (and overwrite) the coefficient vector c of
		% MREG, because nested functions share every variable name they have
		% in common with the enclosing function.
		pos=1;		% scan position in ll
		ci=1;		% start index of the current term
		while pos<length(ll)
			% A bare SSS token is 'SSS' not already followed by '('. The
			% bounds checks let a trailing 'SSS' be processed instead of
			% indexing past the end of the string.
			if pos+2<=length(ll) && strcmp(ll(pos:pos+2),'SSS') && ...
					(pos+3>length(ll) || ~strcmp(ll(pos+3),'('))
				% The term runs from ci up to (not including) the next '+'/'-'.
				tend=ci;
				while tend<length(ll) && ~strcmp(ll(tend+1),'+') && ~strcmp(ll(tend+1),'-')
					tend=tend+1;
				end
				bite=ll(ci:tend);
				% Coefficient token: the last 'a' plus the three characters after it.
				tout1='';
				tout2='';
				sa=1;
				while sa<length(bite)
					if strcmp(bite(sa),'a')
						tout1=bite(sa:sa+3);
					end
					sa=sa+1;
				end
				% Bare SSS token inside the term.
				sa=1;
				while sa<length(bite)
					if sa+2<=length(bite) && strcmp(bite(sa:sa+2),'SSS') && ...
							(sa+3>length(bite) || ~strcmp(bite(sa+3),'('))
						tout2=bite(sa:sa+2);
					end
					sa=sa+1;
				end
				% Strip both tokens and tidy the leftover '*' separators.
				bite=strrep(bite,tout1,'');
				bite=strrep(bite,tout2,'');
				bite=['(',bite,')'];
				bite=strrep(bite,'**','*');
				bite=strrep(bite,'*)',')');
				bite=strrep(bite,'(*','(');
				if ~isempty(tout1)
					bite=['SSS',bite,'*',tout1];
				else
					bite=['SSS',bite];
				end
				ll=strrep(ll,ll(ci:tend),bite);
				% ci is computed from indices before the replacement, so the
				% next term's extracted text may start with its sign. That
				% sign ends up inside SSS(...), and the final pass below puts
				% a '+' back in front of every SSS.
				ci=tend+2;
			end
			pos=pos+1;
		end
		ll=strrep(ll,'SSS','+SSS');
	end
% Visual check of the fit; only possible for one or two variables.
% Styling is applied through the handles just created, so other figures
% and axes are left alone. The caller's HOLD state is restored afterwards.
if var==2
	wasHeld=ishold;
	hdata=plot3(x(1,:),x(2,:),y,'*');
	set(hdata,'marker','.','color','k','markersize',25)
	h=axis;		% data-based limits, restored after EZSURF
	grid on
	hold on
	colormap([.5 .5 .5])
	hsurf=ezsurf(fm,[min(x(1,:)) max(x(1,:)) min(x(2,:)) max(x(2,:))]);
	set(hsurf,'facealpha',0.5)
	axis(h)
	title('')
	if ~wasHeld
		hold off
	end
elseif var==1
	wasHeld=ishold;
	hdata=plot(x,y,'*');
	set(hdata,'marker','.','color','k','markersize',25)
	h=axis;		% data-based limits, restored after EZPLOT
	hold on
	hline=ezplot(fm,[min(x),max(x)]);
	set(hline,'color',[.5 .5 .5],'linewidth',2)
	axis(h)
	title('')
	if ~wasHeld
		hold off
	end
end
function s=sym2str(a)
% SYM2STR Text form of an array: char(a(k)) for k = 1..length(a), stacked
% by CHAR into a char matrix (one blank-padded row per element).
	parts=cell(1,length(a));
	for idx=1:numel(parts)
		parts{idx}=char(a(idx));
	end
	s=char(parts);
end
end

Contact us