Code covered by the BSD License  

Highlights from
Statistical Learning Toolbox

from Statistical Learning Toolbox by Dahua Lin
Functions for statistical learning, pattern recognition and computer vision, covering many topics.

slfindnn(X0, X, method, varargin)
function [nnidx, dists] = slfindnn(X0, X, method, varargin)
%SLFINDNN Finds the nearest neighbors using specified strategy
%
% $ Syntax $
%   - [nnidx, dists] = slfindnn(X0, X, method, ...)
%
% $ Arguments $
%   - X0:           The reference samples among which neighbors are
%                   searched (one sample per column)
%   - X:            The query samples (one sample per column), or [] to
%                   query X0 against itself
%   - method:       The name of the neighbor-finding method:
%                   'knn', 'ann' or 'eps'
%   - nnidx:        The indices of the neighbors found for each query
%   - dists:        The distances from each query sample to its
%                   corresponding neighbors
%
% $ Description $
%   - [nnidx, dists] = slfindnn(X0, X, method, ...) searches X0 for the
%     nearest neighbors of every query sample, using the given method.
%     The pair (X0, X) can be supplied in three configurations:
%       - X0, []:   the queries are the samples of X0, and a sample is
%                   never reported as its own neighbor
%       - X0, X0:   the queries are the samples of X0, and a sample may
%                   be reported as its own neighbor
%       - X0, X:    the queries come from a set different from the
%                   reference set.
%     For n query samples, nnidx is a 1 x n cell array whose i-th cell
%     holds a column vector with the indices of the neighbors of the
%     i-th query. dists has the same layout, holding distances in place
%     of indices.
%     \*
%     \t      Table. The methods for nearest neighbor finding
%     \h        name    &           description
%              'knn'    & Strict KNN using exhaustive search, with the
%                         following properties:
%                           - 'K':  The number of neighbors to find for
%                                   each query sample (default = 3)
%                           - 'maxblk': The maximum number of distances
%                                       computed in a single batch
%                                       (default = 1e7)
%                           - 'metric': The metric used to compute
%                                       distances. It can be the name
%                                       of a metric as a string, a cell
%                                       array of parameters for
%                                       slmetric_pw, or a function
%                                       handle of the form:
%                                          D = f(X1, X2)
%                                       (default = 'eucdist')
%              'ann'    & Approximate KNN using KD-tree, with the
%                         following properties:
%                           - 'K':  The number of neighbors to find for
%                                   each query sample (default = 3)
%              'eps'    & Find all neighbors whose distance is below a
%                         threshold, with the following properties:
%                           - 'e':  The distance threshold
%                                   (default = 1)
%                           - 'maxblk': The maximum number of distances
%                                       computed in a single batch
%                                       (default = 1e7)
%                           - 'metric': The metric used to compute
%                                       distances. It can be the name
%                                       of a metric as a string, a cell
%                                       array of parameters for
%                                       slmetric_pw, or a function
%                                       handle of the form:
%                                          D = f(X1, X2)
%                                       (default = 'eucdist')
%     \*
%
% $ Remarks $
%   - In the current version, the metric must be a true distance: its
%     value has to decrease as samples get nearer to each other. Do not
%     use similarity metrics. Metric customization is available only
%     for 'knn' and 'eps'; 'ann' always uses Euclidean distances.
%
% $ History $
%   - Created by Dahua Lin, on Sep 8th, 2006
%   - Modified by Dahua Lin, on Sep 18, 2006
%       - add the functionality to support various distance metric types
%         and user-supplied distances.
%

%% validate arguments

if nargin < 3
    raise_lackinput('slfindnn', 3);
end

if ~ismember(method, {'knn', 'ann', 'eps'})
    error('sltoolbox:invalidarg', ...
        'Invalid method for nearest neighbor finding: %s', method);
end

% An empty X means "query X0 against itself, skipping self matches"
excludediag = isempty(X);
if excludediag
    X = X0;
end


%% dispatch to the selected search routine

switch method
    case 'knn'
        searchfun = @find_knn;
    case 'ann'
        searchfun = @find_ann;
    case 'eps'
        searchfun = @find_eps;
end

% Request distances from the routine only when the caller asked for them,
% so that the routine can skip the extra work otherwise.
if nargout >= 2
    [nnidx, dists] = searchfun(X0, X, excludediag, varargin{:});
else
    nnidx = searchfun(X0, X, excludediag, varargin{:});
end

%% Core functions

function [nnidx, dists] = find_knn(X0, X, excludediag, varargin)
%FIND_KNN Exhaustive K-nearest-neighbor search, processed in column batches
%
%   X0 holds the reference samples and X the query samples (columns).
%   When excludediag is true, the distance of each query to the reference
%   sample with the same index is forced to inf before ranking.
%   Options (name/value pairs): 'K', 'maxblk', 'metric'.
%   Returns 1 x n cell arrays with K x 1 index (and distance) vectors.

want_dists = (nargout >= 2);

% defaults, overridden by the caller's name/value pairs
opts = struct('K', 3, 'maxblk', 1e7, 'metric', 'eucdist');
opts = slparseprops(opts, varargin{:});
metricfun = get_metricfunc(opts.metric);

nquery = size(X, 2);
K = getK(opts, X0);
[secs, nsecs] = getparsecs(opts, X0, X);

% results are accumulated as K x n matrices and split into cells at the end
nnidx = zeros(K, nquery);
if want_dists
    dists = zeros(K, nquery);
end

for isec = 1 : nsecs

    % query columns covered by this batch
    first = secs.sinds(isec);
    last = secs.einds(isec);

    % distances from every reference sample (rows) to the batch (columns)
    D = compute_pwdists(X0, X, metricfun, first, last, excludediag);

    % rank reference samples per query, nearest first
    [D, order] = sort(D, 1);

    % select and record the K best of each column
    nnidx(:, first:last) = order(1:K, :);
    if want_dists
        dists(:, first:last) = D(1:K, :);
    end

end

% one cell per query sample
nnidx = cols_to_cells(nnidx);
if want_dists
    dists = cols_to_cells(dists);
end


function [nnidx, dists] = find_ann(X0, X, excludediag, varargin)
%FIND_ANN K-nearest-neighbor search delegated to annsearch
%
%   Options (name/value pairs): 'K' (default 3).
%   Returns 1 x n cell arrays built from the columns of the matrices
%   produced by annsearch.

opts = struct('K', 3);
opts = slparseprops(opts, varargin{:});
K = getK(opts, X0);

% In self-query mode annsearch is given an empty query set.
% NOTE(review): presumably annsearch treats [] as "query X0 against itself
% without self matches" -- confirm against annsearch.
if excludediag
    X = [];
end

% search, then convert each requested output to one cell per query
if nargout >= 2
    [nnidx, dists] = annsearch(X0, X, K);
    dists = cols_to_cells(dists);
else
    nnidx = annsearch(X0, X, K);
end
nnidx = cols_to_cells(nnidx);
    

function [nnidx, dists] = find_eps(X0, X, excludediag, varargin)
%FIND_EPS Finds all reference samples closer than a threshold to each query
%
%   A reference sample is a neighbor when its distance to the query is
%   strictly less than opts.e. When excludediag is true, the distance of
%   each query to the reference sample with the same index is forced to
%   inf and therefore never selected.
%   Options (name/value pairs): 'e', 'maxblk', 'metric'.
%   Returns 1 x n cell arrays of column vectors (possibly empty).

want_dists = (nargout >= 2);

% defaults, overridden by the caller's name/value pairs
opts = struct('e', 1, 'maxblk', 1e7, 'metric', 'eucdist');
opts = slparseprops(opts, varargin{:});
metricfun = get_metricfunc(opts.metric);
[secs, nsecs] = getparsecs(opts, X0, X);

% one cell per query sample
nquery = size(X, 2);
nnidx = cell(1, nquery);
if want_dists
    dists = cell(1, nquery);
end

for isec = 1 : nsecs

    % query columns covered by this batch
    first = secs.sinds(isec);
    last = secs.einds(isec);

    % distances from every reference sample (rows) to the batch (columns)
    D = compute_pwdists(X0, X, metricfun, first, last, excludediag);

    % keep the entries strictly below the threshold
    within = (D < opts.e);

    nnidx(first:last) = select_output_indices(within);
    if want_dists
        dists(first:last) = select_output_values(D, within);
    end

end



%% Auxiliary functions

function dists = compute_pwdists(X0, X, fhmetric, sp, ep, excludediag)
%COMPUTE_PWDISTS Distances between X0 and the query columns sp..ep of X
%
%   dists = fhmetric(X0, X(:, sp:ep)). When excludediag is true, the
%   entries at (row sp+j-1, column j) -- the pairs where the reference
%   index equals the query index -- are overwritten with inf.

% take the requested slice; the whole matrix is passed through unsliced
wholeset = (sp == 1) && (ep == size(X, 2));
if wholeset
    queries = X;
else
    queries = X(:, sp:ep);
end

dists = fhmetric(X0, queries);

if excludediag
    nref = size(X0, 2);
    nbatch = ep - sp + 1;
    % query column j of the batch is sample sp+j-1 of the full set
    selfpos = sub2ind([nref, nbatch], sp:ep, 1:nbatch);
    dists(selfpos) = inf;
end


function fh = get_metricfunc(m)
%GET_METRICFUNC Turns a metric specification into a handle fh(X, Y)
%
%   m may be a char metric name or a cell array of arguments, both of
%   which are forwarded to slmetric_pw, or a function handle, which is
%   returned unchanged. Any other type raises sltoolbox:invalidarg.

if isa(m, 'function_handle')
    fh = m;
    return;
end

if ischar(m)
    fh = @(X, Y) slmetric_pw(X, Y, m);
    return;
end

if iscell(m)
    fh = @(X, Y) slmetric_pw(X, Y, m{:});
    return;
end

error('sltoolbox:invalidarg', 'The metric is specified incorrectly');


    
function C = cols_to_cells(M)
%COLS_TO_CELLS Splits a matrix into a 1 x n cell array of its columns

nrows = size(M, 1);
ncols = size(M, 2);

% a single row block of full height, and n column blocks of width one
colwidths = ones(1, ncols);
C = mat2cell(M, nrows, colwidths);


function nnidx = select_output_indices(is_selected)
%SELECT_OUTPUT_INDICES Row indices of the true entries, column by column
%
%   Returns a 1 x n cell array; cell j holds find(is_selected(:, j)).

ncols = size(is_selected, 2);
nnidx = arrayfun(@(j) find(is_selected(:, j)), 1:ncols, ...
    'UniformOutput', false);

function vals = select_output_values(vals0, is_selected)
%SELECT_OUTPUT_VALUES Values of vals0 at the true entries, column by column
%
%   Returns a 1 x n cell array; cell j holds vals0(is_selected(:, j), j).

ncols = size(is_selected, 2);
vals = arrayfun(@(j) vals0(is_selected(:, j), j), 1:ncols, ...
    'UniformOutput', false);


function K = getK(opts, X0)
%GETK Returns opts.K after checking it against the reference set size
%
%   Raises sltoolbox:invalidarg unless K is smaller than the number of
%   columns of X0.

K = opts.K;
nref = size(X0, 2);
if ~(K < nref)
    error('sltoolbox:invalidarg', ...
        'The specified K should be less than the number of referenced samples');
end

function [secs, nsecs] = getparsecs(opts, X0, X)
%GETPARSECS Partitions the query columns into batches bounded by opts.maxblk
%
%   Each batch holds at most floor(maxblk / n0) query columns (never fewer
%   than one), where n0 is the number of columns of X0, so that one batch
%   needs about maxblk pairwise distances at most. The partition itself is
%   produced by slpartition; nsecs is the number of entries in secs.sinds.

nref = size(X0, 2);
nquery = size(X, 2);

% largest number of query columns per batch, at least one
blksize = max(floor(opts.maxblk / nref), 1);

secs = slpartition(nquery, 'maxblksize', blksize);
nsecs = length(secs.sinds);





 

Contact us at files@mathworks.com