Description of sllogistreg
sllogistreg
PURPOSE 
SLLOGISTREG Performs Multivariate Logistic Regression
SYNOPSIS 
function [A, b, props, info] = sllogistreg(X, nums, varargin)
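The training samples passed to sllogistreg must be grouped by class. The source locates class k through slnums2bounds(nums), so class k has to occupy the nums(k) columns that follow those of classes 1, ..., k-1. If the labels are available as a per-sample vector instead, the following is a minimal sketch of reordering X and building nums in plain MATLAB. The labels and data are made-up examples, and the labels are assumed to be integers 1..C with every class present.

```matlab
% Example inputs (hypothetical): one class label per sample, and d x n data
labels = [2 1 3 1 2 2];
X = [10 20 30 40 50 60;      % d = 2, n = 6
      1  2  3  4  5  6];
C = max(labels);             % assumes labels are 1..C

% sort the labels so that each class forms one contiguous block of columns
[sorted_labels, order] = sort(labels);
Xg = X(:, order);            % columns grouped by class, class 1 first

% number of samples per class, as a 1 x C row vector
nums = accumarray(labels(:), 1, [C, 1])';
```

The grouped data can then be passed on as [A, b] = sllogistreg(Xg, nums).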
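DESCRIPTION
Trains a C-class linear logistic regression model. The parameters are estimated by minimizing the negative (optionally weighted) log-likelihood of each sample's true-class posterior with fminunc, using the analytic gradient.
Inputs:
- X: d x n sample matrix, one sample per column. Samples of the same class must occupy one contiguous block of columns, in class order.
- nums: 1 x C row vector of class sizes; sum(nums) must equal n, and C must be at least 2.
Options (name/value pairs in varargin, parsed by slparseprops):
- weights: 1 x n sample weights (default: [], all samples weighted equally)
- priors: 1 x C class priors (default: [])
- maxiter: maximum number of fminunc iterations (default: 300)
- tolF: termination tolerance on the objective, passed as TolFun (default: 1e-6)
- tolX: termination tolerance on the parameters, passed as TolX (default: 1e-6)
- display: fminunc display level (default: 'off')
- init: {A0, b0}, an initial solution with A0 of size d x C and b0 of size 1 x C (default: random initialization with rand)
Outputs:
- A: d x C coefficient matrix
- b: 1 x C offset vector; the class scores (logits) of a sample x are A' * x + b'
- props: posterior probability of each training sample's true class under the fitted model (computed only when requested)
- info: struct with fields exitflag and numiters (from fminunc) and fval, the (weighted) log-likelihood at the solution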
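From the source code, the objective minimized by fminunc can be written as follows, with y_i the class of sample i, w_i = 1 when no weights are given, pi_k the class priors, Y the C x n indicator matrix with Y_ki = [y_i = k], and X~ = Xa the augmented sample matrix. The form of P assumes that slposteriori(L, pri, 'log') applies Bayes' rule to log-domain scores (equal priors when pri is empty); that helper's source is not listed on this page.

```latex
\tilde{x}_i = \begin{bmatrix} x_i \\ 1 \end{bmatrix}, \quad
\tilde{A} = \begin{bmatrix} A \\ b \end{bmatrix} = [\tilde{a}_1, \dots, \tilde{a}_C], \quad
l_{ki} = \tilde{a}_k^\top \tilde{x}_i = a_k^\top x_i + b_k

P_{ki} = \frac{\pi_k \exp(l_{ki})}{\sum_{j=1}^{C} \pi_j \exp(l_{ji})}, \qquad
f(\tilde{A}) = -\sum_{i=1}^{n} w_i \log P_{y_i i}

\frac{\partial f}{\partial \tilde{a}_k} = -\sum_{i=1}^{n} w_i \big([y_i = k] - P_{ki}\big)\, \tilde{x}_i,
\qquad
\nabla_{\tilde{A}} f = -\tilde{X} \big((Y - P)\,\mathrm{diag}(w)\big)^\top
```

The priors drop out of the gradient's form because log(pi_k) is constant in A~. In logistic_objfun, M = make_indicatormap(nums, C, n) - P with columns scaled by w is (Y - P) diag(w), and g = -Xa * M' is the gradient above, vectorized column by column.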
CROSS-REFERENCE INFORMATION 
This function calls:
- sladdvec SLADDVEC adds a vector to columns or rows of a matrix
- slmulvec SLMULVEC multiplies the columns or rows of a matrix by a vector
- slposteriori SLPOSTERIORI Computes the posteriors
- slposterioritrue SLPOSTERIORITRUE Computes the posterior probability that samples belong to their true classes
- raise_lackinput RAISE_LACKINPUT Raises an error indicating lack of input argument
- slexpand SLEXPAND Expands a set to multiple instances
- slnums2bounds SLNUMS2BOUNDS Computes the index boundaries from section sizes
- slparseprops SLPARSEPROPS Parses input parameters
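The helpers slposteriori and slposterioritrue are called with the 'log' flag on logit matrices, so they presumably treat L as log-domain scores. Their source is not listed here. The following is a minimal stand-in under that assumption: it combines the scores with the priors by Bayes' rule, normalizes stably by subtracting each column's maximum (the log-sum-exp trick), and then picks each sample's true-class entry for class-grouped samples. It is not the sltoolbox implementation.

```matlab
% Example C x n log-domain scores (C = 2 classes, n = 3 samples) and priors
L = [2.0 0.5 -1.0;
     0.0 0.5  1.0];
pri = [0.25 0.75];

% stand-in for slposteriori(L, pri, 'log') (assumed behaviour):
% P(k, i) = pri(k) * exp(L(k, i)) / sum_j pri(j) * exp(L(j, i))
S = bsxfun(@plus, L, log(pri(:)));       % add log-priors to each row
S = bsxfun(@minus, S, max(S, [], 1));    % shift columns so exp cannot overflow
P = exp(S);
P = bsxfun(@rdivide, P, sum(P, 1));      % each column now sums to 1

% stand-in for slposterioritrue: samples are grouped by class, nums(k) each
nums = [2 1];
y = zeros(1, sum(nums));                 % class label of each sample
first = 1;
for k = 1:numel(nums)
    y(first:first+nums(k)-1) = k;
    first = first + nums(k);
end
pps = P(sub2ind(size(P), y, 1:numel(y)));   % posterior of the true class
```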
This function is called by:
SUBFUNCTIONS 
- function [f, g] = logistic_objfun(v, Xa, nums, d, n, C, pri, w)
- function L = compute_logit(Aa, Xa)
- function M = make_indicatormap(nums, C, n)
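The subfunction logistic_objfun returns both the objective and its gradient, and fminunc relies on that gradient because GradObj is 'on'. Below is a standalone re-implementation without priors that can be checked against central finite differences. It uses a plain softmax in place of slposteriori and a loop in place of slexpand/slnums2bounds, both assumptions about those helpers. The data, weights and test point are arbitrary.

```matlab
% Deterministic example problem: d = 2 features, 3 classes, grouped samples
d = 2; nums = [3 2 4]; C = numel(nums); n = sum(nums);
X = [sin(1:n); cos(2*(1:n))];          % d x n data
Xa = [X; ones(1, n)];                  % augmented with a constant row
w = linspace(0.5, 1.5, n);             % example per-sample weights

% class labels and the C x n indicator map Y (as make_indicatormap builds)
y = zeros(1, n); first = 1;
for k = 1:C
    y(first:first+nums(k)-1) = k;
    first = first + nums(k);
end
Y = zeros(C, n);
Y(sub2ind([C, n], y, 1:n)) = 1;

% column-wise softmax of the logits Aa' * Xa, shifted by the column max
shift = @(L) bsxfun(@minus, L, max(L, [], 1));
softmaxcols = @(L) bsxfun(@rdivide, exp(shift(L)), sum(exp(shift(L)), 1));
post = @(v) softmaxcols(reshape(v, d+1, C)' * Xa);

% objective: negative weighted log-likelihood of the true classes
f = @(v) -sum(w .* log(sum(Y .* post(v), 1)));
% analytic gradient, as in logistic_objfun: -vec(Xa * M') with
% M = Y - P and column i of M scaled by w(i)
g = @(v) reshape(-Xa * bsxfun(@times, Y - post(v), w)', [], 1);

% central finite differences at an arbitrary point
v = 0.1 * (1:(d+1)*C)' - 0.4;
h = 1e-6;
gfd = zeros(size(v));
for j = 1:numel(v)
    ej = zeros(size(v)); ej(j) = h;
    gfd(j) = (f(v + ej) - f(v - ej)) / (2 * h);
end
relerr = norm(g(v) - gfd) / norm(gfd);   % should be tiny (finite-difference noise)
```

A relative error far above the finite-difference noise level would point to a mismatch between the objective and its gradient.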
SOURCE CODE 
function [A, b, props, info] = sllogistreg(X, nums, varargin)
%SLLOGISTREG Performs Multivariate Logistic Regression

% ---- verify inputs ----

if nargin < 2
    raise_lackinput('sllogistreg', 2);
end

[d, n] = size(X);
if ~isvector(nums) || size(nums, 1) ~= 1
    error('sltoolbox:invalidarg', ...
        'nums should be a 1 x C row vector');
end
if sum(nums) ~= n
    error('sltoolbox:sizmismatch', ...
        'The numbers in nums are not consistent with the sample number');
end

C = length(nums);

if C < 2
    error('sltoolbox:invalidarg', ...
        'There should be at least 2 classes.');
end

% ---- options: defaults, overridden by name/value pairs ----

opts.weights = [];
opts.priors = [];
opts.maxiter = 300;
opts.tolF = 1e-6;
opts.tolX = 1e-6;
opts.display = 'off';
opts.init = {};
opts = slparseprops(opts, varargin{:});

% optional per-sample weights (1 x n)
w = opts.weights;
if ~isempty(w)
    if ~isequal(size(w), [1, n])
        error('sltoolbox:sizmismatch', ...
            'w should be a 1 x n row vector');
    end
end

% optional class priors (1 x C)
pri = opts.priors;
if ~isempty(pri)
    if ~isequal(size(pri), [1, C])
        error('sltoolbox:sizmismatch', ...
            'pri should be a 1 x C row vector');
    end
end

% optional initial solution {A0, b0}
if isempty(opts.init)
    is_inited = false;
else
    A0 = opts.init{1};
    b0 = opts.init{2};
    if ~isequal(size(A0), [d, C])
        error('sltoolbox:sizmismatch', 'A0 should be a d x C matrix');
    end
    if ~isequal(size(b0), [1, C])
        error('sltoolbox:sizmismatch', 'b0 should be a 1 x C row vector');
    end
    is_inited = true;
end

% ---- optimization ----

% append a constant row so that the offsets b are estimated jointly
% with A, as the last row of Aa = [A; b]
Xa = [X; ones(1, n)];

% the objective returns the negative (weighted) log-likelihood and its gradient
optimfunc = @(v) logistic_objfun(v, Xa, nums, d, n, C, pri, w);
optimopts = optimset(optimset('fminunc'), ...
    'LargeScale', 'on', ...      % large-scale (trust-region) algorithm
    'GradObj', 'on', ...         % use the analytic gradient
    'MaxIter', opts.maxiter, ...
    'Display', opts.display, ...
    'TolFun', opts.tolF, ...
    'TolX', opts.tolX);

% initial parameter vector: [A0; b0] stacked column by column, or random
if ~is_inited
    v0 = rand((d+1)*C, 1);
else
    v0 = reshape([A0; b0], (d+1)*C, 1);
end

[v, fval, exitflag, optimoutput] = fminunc(optimfunc, v0, optimopts);
clear v0;
clear optimfunc;
clear Xa;

% ---- extract results ----

v = reshape(v, d+1, C);
A = v(1:d, :);
b = v(d+1, :);
clear v;

if nargout >= 3
    % posterior probability of each sample's true class under the fitted model
    L = compute_logit(A, X);
    L = sladdvec(L, b', 1);
    props = slposterioritrue(L, nums, pri, 'log');
end

if nargout >= 4
    info.exitflag = exitflag;
    info.numiters = optimoutput.iterations;
    info.fval = -fval;    % (weighted) log-likelihood at the solution
end


% ---- subfunctions ----

function [f, g] = logistic_objfun(v, Xa, nums, d, n, C, pri, w)
% negative (weighted) log-likelihood f and its gradient g with respect to v

Aa = reshape(v, d+1, C);
L = compute_logit(Aa, Xa);    % C x n logits
clear Aa;

% class posteriors P (C x n); keep each sample's true-class entry in pps
P = slposteriori(L, pri, 'log');
[sidx, eidx] = slnums2bounds(nums);
pps = zeros(1, n);
for k = 1 : C
    sk = sidx(k); ek = eidx(k);
    pps(sk:ek) = P(k, sk:ek);
end

if isempty(w)
    f = -sum(log(pps));
else
    f = -sum(log(pps) .* w);
end

% gradient with respect to Aa: -Xa * (Y - P)', where Y is the indicator map
% and column i of (Y - P) is scaled by w(i) when weights are given
M = make_indicatormap(nums, C, n);
M = M - P;
clear P;
if ~isempty(w)
    M = slmulvec(M, w, 2);
end
g = Xa * M';
clear M;
g = -g(:);


function L = compute_logit(Aa, Xa)
% logits: L(k, i) = Aa(:, k)' * Xa(:, i)

L = Aa' * Xa;


function M = make_indicatormap(nums, C, n)
% C x n 0/1 matrix with M(k, i) = 1 if and only if sample i belongs to class k

M = zeros(C, n);
I = slexpand(nums);
J = 1:n;
inds = sub2ind([C, n], I, J);
clear I J;
M(inds) = 1;
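The outputs define one linear score per class for each sample; the props step of the source computes them as A' * X plus b'. The following is a minimal sketch of classifying new samples by the highest score, assuming equal priors (with priors, adding log(pri') to the scores first would follow the posterior form assumed for slposteriori). The values of A, b and the new samples are made up for illustration.

```matlab
% Made-up parameters standing in for the outputs of sllogistreg
A = [1.0 -1.0  0.0;      % d x C coefficients (d = 2, C = 3)
     0.5  0.5 -1.0];
b = [0.0  0.2  0.1];     % 1 x C offsets
Xnew = [2.0 -1.0  0.0;   % d x m new samples (m = 3), one per column
        0.0  0.0 -2.0];

% class scores L(k, i) = A(:, k)' * Xnew(:, i) + b(k), as in the props step
L = bsxfun(@plus, A' * Xnew, b');

% softmax posteriors (equal priors assumed), shifted by the column max
S = exp(bsxfun(@minus, L, max(L, [], 1)));
P = bsxfun(@rdivide, S, sum(S, 1));

% predicted class of each sample: the row with the largest score
[~, pred] = max(L, [], 1);
```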
Generated on Wed 20-Sep-2006 12:43:11 by m2html © 2003