Code covered by the BSD License  

from Statistical Learning Toolbox by Dahua Lin
Functions for statistical learning, pattern recognition and computer vision, covering many topics.

Description of slfindnn
slfindnn

PURPOSE

SLFINDNN Finds the nearest neighbors using a specified strategy

SYNOPSIS

function [nnidx, dists] = slfindnn(X0, X, method, varargin)

DESCRIPTION

SLFINDNN Finds the nearest neighbors using a specified strategy

 $ Syntax $
   - [nnidx, dists] = slfindnn(X0, X, method, ...)
 
 $ Arguments $
   - X0:           The reference samples among which the neighbors are found
   - X:            The query samples
   - method:       The method to find nearest neighbors
   - nnidx:        The indices of the nearest neighbors
   - dists:        The distances between the samples and corresponding
                   neighbors

 $ Description $
   - [nnidx, dists] = slfindnn(X0, X, method, ...) finds the nearest
     neighbors for all samples using the specified method. You can specify
     the X0 and X in three different configurations:
       - X0, []:   finds the nearest neighbors for the samples in X0, 
                   each sample itself is not considered as a neighbor
       - X0, X0:   finds the nearest neighbors for the samples in X0,
                   each sample itself is also taken as a neighbor
       - X0, X:    the query samples and the reference samples are not
                   in the same set.
     If there are n query samples, then nnidx is a cell array of size
     1 x n, and each cell contains a column vector of all indices of the
     neighbors of the corresponding sample. dists will be in the same form
     except that the values are distances instead of indices.
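The difference between the (X0, []) and (X0, X0) configurations can be illustrated without the toolbox. The following is a minimal sketch in base MATLAB, not the toolbox code itself: the toy data, the value of K, and all variable names are assumptions made for illustration. Samples are stored as columns, as in slfindnn.

```matlab
% Toy data: 4 reference samples stored as columns (2-D points).
X0 = [0 1 0 5; 0 0 2 5];
K  = 2;                      % number of neighbors to keep (assumed value)
n0 = size(X0, 2);

% Pairwise Euclidean distances: D(i, j) = ||X0(:, i) - X0(:, j)||.
sq = sum(X0 .^ 2, 1);
D  = sqrt(max(bsxfun(@plus, sq', sq) - 2 * (X0' * X0), 0));

% Configuration (X0, X0): each sample may be selected as its own neighbor.
[~, idx_self] = sort(D, 1);
idx_self = idx_self(1:K, :);

% Configuration (X0, []): mask the diagonal so a sample never selects itself.
Dex = D;
Dex(1 : n0 + 1 : end) = inf;
[~, idx_excl] = sort(Dex, 1);
idx_excl = idx_excl(1:K, :);

% Documented output layout: a 1 x n cell array of column vectors.
nnidx = mat2cell(idx_excl, K, ones(1, n0));
```

In the (X0, X0) case the first neighbor of every sample is the sample itself (distance 0); in the (X0, []) case that entry is replaced by the nearest other sample.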
     \*
     \t      Table. The methods for nearest neighbor finding
     \h        name    &           description
              'knn'    & Strict KNN using exhaustive search, having the
                         following properties:
                           - 'K':  The number of neighbors to find for
                                   each query sample (default = 3)
                           - 'maxblk': The maximum number of distances
                                       that can be computed in one batch
                                       (default = 1e7)
                           - 'metric': The metric type used to compute
                                       distances. It can be a string giving
                                       the metric name, a cell array of
                                       parameters for slmetric_pw, or a
                                       function handle of the form:
                                          D = f(X1, X2)
                                       (default = 'eucdist')
              'ann'    & Approximate KNN using KD-tree, having the 
                         following properties:
                           - 'K':  The number of neighbors to find for
                                   each query sample (default = 3)
              'eps'    & Find all neighbors with distance below a
                         threshold, having the following properties:
                           - 'e':  The threshold of the distance
                                   (default = 1)
                           - 'maxblk': The maximum number of distances
                                       that can be computed in one batch
                                       (default = 1e7)
                           - 'metric': The metric type used to compute
                                       distances. It can be a string giving
                                       the metric name, a cell array of
                                       parameters for slmetric_pw, or a
                                       function handle of the form:
                                          D = f(X1, X2)
                                       (default = 'eucdist')
     \*
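Unlike 'knn' and 'ann', the 'eps' method can return a different number of neighbors for each query. The sketch below mimics that selection rule in base MATLAB for 1-D data. It is only an illustration under assumed data and variable names, and it does not call the toolbox.

```matlab
% Reference samples and query samples, stored as columns (1-D points here).
X0 = [0 1 4 9];          % four reference points
X  = [0.5 8];            % two query points
e  = 2;                  % distance threshold (the 'e' property)

% D(i, j) = distance from reference sample i to query sample j.
% For 1-D data the Euclidean distance is the absolute difference.
D = abs(bsxfun(@minus, X0', X));

% Strict "<" comparison against the threshold.
is_selected = D < e;

% Collect indices and distances per query as 1 x n cell arrays.
n = size(X, 2);
nnidx = cell(1, n);
dists = cell(1, n);
for j = 1 : n
    nnidx{j} = find(is_selected(:, j));
    dists{j} = D(is_selected(:, j), j);
end
```

Here the first query has two neighbors within the threshold and the second has one. With the toolbox on the path, the corresponding call would presumably be slfindnn(X0, X, 'eps', 'e', 2), assuming slparseprops accepts name/value pairs.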

 $ Remarks $
   - In the current version, the metric must be a distance, i.e. its value
     must decrease as samples become nearer. Do not use similarity
     metrics. Metric customization applies only to 'knn' and 'eps';
     'ann' always uses Euclidean distances.
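To make the remark concrete, the sketch below defines a custom metric as a function handle of the form D = f(X1, X2) and shows why a similarity ranks neighbors in the wrong order under an ascending sort. The handles l1dist and negsim and the toy data are hypothetical names introduced for illustration; the only property taken from the function is that the handle returns a matrix with one row per reference sample and one column per query sample.

```matlab
% City-block (L1) distance between the columns of X1 (d x n1) and X2 (d x n2).
% Returns an n1 x n2 matrix: rows index X1 samples, columns index X2 samples.
l1dist = @(X1, X2) sum(abs(bsxfun(@minus, ...
    permute(X1, [2 3 1]), permute(X2, [3 2 1]))), 3);

X0 = [1 0 1; 0 1 1];     % three reference samples as columns
X  = [2; 0.1];           % one query sample

% A proper distance: smaller means nearer, so ascending sort ranks correctly.
D = l1dist(X0, X);
[~, order_d] = sort(D, 1);

% A similarity (here the inner product) grows as samples become more alike,
% so an ascending sort puts the least similar sample first.
S = X0' * X;
[~, order_s] = sort(S, 1);

% One possible workaround: negate the similarity so that smaller means nearer.
negsim = @(X1, X2) -(X1' * X2);
[~, order_fixed] = sort(negsim(X0, X), 1);
```

Whether a negated similarity is an appropriate substitute depends on the application; for instance, a threshold for 'eps' would have to be chosen on the negated scale.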

 $ History $
   - Created by Dahua Lin, on Sep 8th, 2006
   - Modified by Dahua Lin, on Sep 18, 2006
       - add the functionality to support various distance metric types
         and user-supplied distances.

CROSS-REFERENCE INFORMATION

This function calls:
  • annsearch ANNSEARCH Approximate Nearest Neighbor Search
  • slmetric_pw SLMETRIC_PW Compute the metric between column vectors pairwisely
  • raise_lackinput RAISE_LACKINPUT Raises an error indicating lack of input argument
  • slparseprops SLPARSEPROPS Parses input parameters
  • slpartition SLPARTITION Partition a range into blocks in a specified manner
This function is called by:
  • slnngraph SLNNGRAPH Constructs a nearest neighborhood based graph

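SUBFUNCTIONS

  • function [nnidx, dists] = find_knn(X0, X, excludediag, varargin)
  • function [nnidx, dists] = find_ann(X0, X, excludediag, varargin)
  • function [nnidx, dists] = find_eps(X0, X, excludediag, varargin)
  • function dists = compute_pwdists(X0, X, fhmetric, sp, ep, excludediag)
  • function fh = get_metricfunc(m)
  • function C = cols_to_cells(M)
  • function nnidx = select_output_indices(is_selected)
  • function vals = select_output_values(vals0, is_selected)
  • function K = getK(opts, X0)
  • function [secs, nsecs] = getparsecs(opts, X0, X)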

SOURCE CODE

function [nnidx, dists] = slfindnn(X0, X, method, varargin)
%SLFINDNN Finds the nearest neighbors using a specified strategy
%
% $ Syntax $
%   - [nnidx, dists] = slfindnn(X0, X, method, ...)
%
% $ Arguments $
%   - X0:           The reference samples among which the neighbors are found
%   - X:            The query samples
%   - method:       The method to find nearest neighbors
%   - nnidx:        The indices of the nearest neighbors
%   - dists:        The distances between the samples and corresponding
%                   neighbors
%
% $ Description $
%   - [nnidx, dists] = slfindnn(X0, X, method, ...) finds the nearest
%     neighbors for all samples using the specified method. You can specify
%     the X0 and X in three different configurations:
%       - X0, []:   finds the nearest neighbors for the samples in X0,
%                   each sample itself is not considered as a neighbor
%       - X0, X0:   finds the nearest neighbors for the samples in X0,
%                   each sample itself is also taken as a neighbor
%       - X0, X:    the query samples and the reference samples are not
%                   in the same set.
%     If there are n query samples, then nnidx is a cell array of size
%     1 x n, and each cell contains a column vector of all indices of the
%     neighbors of the corresponding sample. dists will be in the same form
%     except that the values are distances instead of indices.
%     \*
%     \t      Table. The methods for nearest neighbor finding
%     \h        name    &           description
%              'knn'    & Strict KNN using exhaustive search, having the
%                         following properties:
%                           - 'K':  The number of neighbors to find for
%                                   each query sample (default = 3)
%                           - 'maxblk': The maximum number of distances
%                                       that can be computed in one batch
%                                       (default = 1e7)
%                           - 'metric': The metric type used to compute
%                                       distances. It can be a string giving
%                                       the metric name, a cell array of
%                                       parameters for slmetric_pw, or a
%                                       function handle of the form:
%                                          D = f(X1, X2)
%                                       (default = 'eucdist')
%              'ann'    & Approximate KNN using KD-tree, having the
%                         following properties:
%                           - 'K':  The number of neighbors to find for
%                                   each query sample (default = 3)
%              'eps'    & Find all neighbors with distance below a
%                         threshold, having the following properties:
%                           - 'e':  The threshold of the distance
%                                   (default = 1)
%                           - 'maxblk': The maximum number of distances
%                                       that can be computed in one batch
%                                       (default = 1e7)
%                           - 'metric': The metric type used to compute
%                                       distances. It can be a string giving
%                                       the metric name, a cell array of
%                                       parameters for slmetric_pw, or a
%                                       function handle of the form:
%                                          D = f(X1, X2)
%                                       (default = 'eucdist')
%     \*
%
% $ Remarks $
%   - In the current version, the metric must be a distance, i.e. its value
%     must decrease as samples become nearer. Do not use similarity
%     metrics. Metric customization applies only to 'knn' and 'eps';
%     'ann' always uses Euclidean distances.
%
% $ History $
%   - Created by Dahua Lin, on Sep 8th, 2006
%   - Modified by Dahua Lin, on Sep 18, 2006
%       - add the functionality to support various distance metric types
%         and user-supplied distances.
%

%% parse and verify input

if nargin < 3
    raise_lackinput('slfindnn', 3);
end

if ~ismember(method, {'knn', 'ann', 'eps'})
    error('sltoolbox:invalidarg', ...
        'Invalid method for nearest neighbor finding: %s', method);
end

if isempty(X)
    X = X0;
    excludediag = true;
else
    excludediag = false;
end


%% Main skeleton

switch method
    case 'knn'
        if nargout < 2
            nnidx = find_knn(X0, X, excludediag, varargin{:});
        else
            [nnidx, dists] = find_knn(X0, X, excludediag, varargin{:});
        end
    case 'ann'
        if nargout < 2
            nnidx = find_ann(X0, X, excludediag, varargin{:});
        else
            [nnidx, dists] = find_ann(X0, X, excludediag, varargin{:});
        end
    case 'eps'
        if nargout < 2
            nnidx = find_eps(X0, X, excludediag, varargin{:});
        else
            [nnidx, dists] = find_eps(X0, X, excludediag, varargin{:});
        end
end

%% Core functions

function [nnidx, dists] = find_knn(X0, X, excludediag, varargin)

% parse input
opts.K = 3;
opts.maxblk = 1e7;
opts.metric = 'eucdist';
opts = slparseprops(opts, varargin{:});
fhmetric = get_metricfunc(opts.metric);

n = size(X, 2);
K = getK(opts, X0);
[secs, nsecs] = getparsecs(opts, X0, X);

to_output_dist = (nargout >= 2);

% prepare storage
nnidx = zeros(K, n);
if to_output_dist
    dists = zeros(K, n);
end

% compute and select, one block of query samples at a time
for k = 1 : nsecs
    
    % compute distances
    sp = secs.sinds(k); ep = secs.einds(k);
    curdists = compute_pwdists(X0, X, fhmetric, sp, ep, excludediag);
    
    % sort distances (ascending, along the reference-sample dimension)
    [curdists, curnnidx] = sort(curdists, 1);
    
    % select and record
    curnnidx = curnnidx(1:K, :);
    nnidx(:, sp:ep) = curnnidx;    
    if to_output_dist
        curdists = curdists(1:K, :);
        dists(:, sp:ep) = curdists;
    end
    
    clear curnnidx curdists;        
    
end

% organize output
nnidx = cols_to_cells(nnidx);
if to_output_dist
    dists = cols_to_cells(dists);
end


function [nnidx, dists] = find_ann(X0, X, excludediag, varargin)

% parse input
opts.K = 3;
opts = slparseprops(opts, varargin{:});
K = getK(opts, X0);
to_output_dist = (nargout >= 2);

if excludediag
    X = [];
end

% perform search
if ~to_output_dist
    nnidx = annsearch(X0, X, K);
else
    [nnidx, dists] = annsearch(X0, X, K);
end

% organize output
nnidx = cols_to_cells(nnidx);
if to_output_dist
    dists = cols_to_cells(dists);
end


function [nnidx, dists] = find_eps(X0, X, excludediag, varargin)

% parse input
opts.e = 1;
opts.maxblk = 1e7;
opts.metric = 'eucdist';
opts = slparseprops(opts, varargin{:});
fhmetric = get_metricfunc(opts.metric);
[secs, nsecs] = getparsecs(opts, X0, X);
to_output_dist = (nargout >= 2);

% prepare storage
n = size(X, 2);
nnidx = cell(1, n);
if to_output_dist
    dists = cell(1, n);
end


% compute and select
for k = 1 : nsecs
    
    % compute distances
    sp = secs.sinds(k); ep = secs.einds(k);
    curdists = compute_pwdists(X0, X, fhmetric, sp, ep, excludediag);
    
    % filter
    is_selected = (curdists < opts.e);
    
    % store
    nnidx(sp:ep) = select_output_indices(is_selected);
    if to_output_dist
        dists(sp:ep) = select_output_values(curdists, is_selected);
    end
    
end



%% Auxiliary functions

function dists = compute_pwdists(X0, X, fhmetric, sp, ep, excludediag)

n0 = size(X0, 2);
n = size(X, 2);

if sp == 1 && ep == n
    curX = X;
else
    curX = X(:, sp:ep);
end

% dists = slmetric_pw(X0, curX, 'eucdist');
dists = fhmetric(X0, curX);

if excludediag
    % queries sp:ep are the reference samples sp:ep themselves, so set
    % those self-distances to inf to keep a sample from selecting itself
    curn = ep - sp + 1;
    inds_diag = sub2ind([n0, curn], sp:ep, 1:curn);
    dists(inds_diag) = inf;
end


function fh = get_metricfunc(m)

if ischar(m)
    fh = @(X, Y) slmetric_pw(X, Y, m);
elseif iscell(m)
    fh = @(X, Y) slmetric_pw(X, Y, m{:});
elseif isa(m, 'function_handle')
    fh = m;
else
    error('sltoolbox:invalidarg', 'The metric is specified incorrectly');
end


function C = cols_to_cells(M)

[m, n] = size(M);
C = mat2cell(M, m, ones(1, n));


function nnidx = select_output_indices(is_selected)

n = size(is_selected, 2);
nnidx = cell(1, n);
for i = 1 : n
    nnidx{i} = find(is_selected(:,i));
end

function vals = select_output_values(vals0, is_selected)

n = size(is_selected, 2);
vals = cell(1, n);
for i = 1 : n
    vals{i} = vals0(is_selected(:,i), i);
end


function K = getK(opts, X0)

K = opts.K;
n0 = size(X0, 2);
if K >= n0
    error('sltoolbox:invalidarg', ...
        'The specified K should be less than the number of referenced samples');
end

function [secs, nsecs] = getparsecs(opts, X0, X)

n0 = size(X0, 2);
ss = max(floor(opts.maxblk / n0), 1);
n = size(X, 2);
secs = slpartition(n, 'maxblksize', ss);
nsecs = length(secs.sinds);
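
The batching performed by getparsecs in the listing above bounds the size of each distance matrix by 'maxblk'. The sketch below works through the arithmetic in base MATLAB. The sample counts are assumed values, and the block boundaries are a hypothetical stand-in for slpartition, whose implementation is not shown on this page.

```matlab
n0     = 1000;     % number of reference samples (assumed)
n      = 2500;     % number of query samples (assumed)
maxblk = 1e6;      % cap on the number of distances computed per batch

% Number of query samples per batch, as computed in getparsecs.
ss = max(floor(maxblk / n0), 1);

% Hypothetical stand-in for slpartition(n, 'maxblksize', ss):
% consecutive blocks of at most ss query samples.
sinds = 1 : ss : n;
einds = min(sinds + ss - 1, n);
nsecs = numel(sinds);

% Report the size of the distance matrix computed in each batch.
for k = 1 : nsecs
    fprintf('batch %d: queries %d..%d (%d distances)\n', ...
        k, sinds(k), einds(k), n0 * (einds(k) - sinds(k) + 1));
end
```

With these numbers the queries are processed in three batches, none of which computes more than maxblk distances. The max(..., 1) guard means that when n0 exceeds maxblk a batch still holds one query, so the cap can be exceeded in that case.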

Generated on Wed 20-Sep-2006 12:43:11 by m2html © 2003
