function [nnidx, dists] = annsearch(X0, X, k, varargin)
%ANNSEARCH  k-nearest-neighbor search through annsearch_wrapper.
%
%   nnidx = annsearch(X0, X, k, ...)
%   [nnidx, dists] = annsearch(X0, X, k, ...)
%
%   Inputs:
%     X0  - reference points, d x n0 (one point per column).
%     X   - query points, d x n (same d as X0). Pass [] to use X0 itself
%           as the query set; each query point is then left out of its own
%           neighbor list (one extra neighbor is requested internally).
%     k   - number of neighbors per query, a non-negative integer scalar.
%           k must be less than n0, and less than n0 - 1 when X is empty.
%     ... - options handed to slparseprops (presumably name/value pairs):
%             'errbound' : error bound forwarded to the search (default 0)
%             'split'    : split rule, one of 'suggest' (default), 'std',
%                          'midpt', 'fair', 'sl_midpt', 'sl_fair'
%             'search'   : 'normal' (default) or 'priority'
%
%   Outputs:
%     nnidx - 1-based column indices into X0; one row per neighbor, one
%             column per query point.
%     dists - element-wise square roots of the distances returned by the
%             wrapper (presumably squared Euclidean -- TODO confirm), laid
%             out like nnidx.
%
%   With no output arguments the function returns without searching.
%
%   Errors are raised with identifier 'annerror:invalidarg'.

if nargin < 3
    error('annerror:invalidarg', ...
        'The number of input arguments should not be less than 3 for annsearch');
end
if nargout == 0
    % Nothing was requested, so skip the search entirely.
    return;
elseif nargout > 2
    error('annerror:invalidarg', ...
        'The number of output arguments should not be larger than 2 for annsearch');
end

% A non-integer or negative k would make the self-exclusion bookkeeping
% below (which removes exactly one row) meaningless.
if ~(isnumeric(k) && isscalar(k) && k >= 0 && k == fix(k))
    error('annerror:invalidarg', ...
        'The value k (neighborhood size) should be a non-negative integer scalar');
end

% An empty query set means "search X0 against itself". Each point will find
% itself, so one extra neighbor is requested and that hit is removed below.
exclude_self = isempty(X);
if exclude_self
    X = X0;
    k = k + 1;
end

[d, n0] = size(X0);
if size(X, 1) ~= d
    error('annerror:invalidarg', ...
        'The dimension of training and query points should be consistent');
end
if k >= n0
    error('annerror:invalidarg', ...
        'The value k (neighborhood size) should be less than n0, the size of the whole set');
end

opts.errbound = 0;
opts.split = 'suggest';
opts.search = 'normal';
opts = slparseprops(opts, varargin{:});

% Ask the wrapper for exactly as many outputs as our caller requested, so it
% sees the same nargout (1 or 2) that annsearch itself was called with.
outs = cell(1, nargout);
[outs{:}] = annsearch_wrapper(X0, X, k, opts.errbound, ...
    get_splitrule_id(opts.split), get_searchmethod_id(opts.search));
nnidx = outs{1};
if nargout >= 2
    dists = outs{2};
end

if exclude_self
    % Remove exactly one entry per column: the query's own index when it
    % was returned, otherwise the surplus last neighbor.
    keep = self_exclusion_mask(nnidx);
    [nrows, nq] = size(nnidx);
    nnidx = reshape(nnidx(keep), nrows - 1, nq);
    if nargout >= 2
        dists = reshape(dists(keep), nrows - 1, nq);
    end
end

% The wrapper reports 0-based indices; MATLAB indexing is 1-based.
nnidx = nnidx + 1;

if nargout >= 2
    dists = sqrt(dists);
end


function keep = self_exclusion_mask(nnidx)
% SELF_EXCLUSION_MASK  Logical mask that drops one entry per column of nnidx.
%   nnidx holds 0-based neighbor indices, column j listing the neighbors of
%   point j-1 searched against its own set. The entry equal to j-1 (the
%   point itself) is dropped. With duplicate points another zero-distance
%   neighbor may have displaced it from the list; in that case the last
%   row is dropped instead, assuming rows are ordered nearest-first so that
%   the last row is the extra neighbor requested on top of k.
[nrows, nq] = size(nnidx);
is_self = bsxfun(@eq, nnidx, 0:nq-1);
self_missing = ~any(is_self, 1);
is_self(nrows, self_missing) = true;
keep = ~is_self;
0153
0154
0155
0156
function id = get_splitrule_id(s)
% GET_SPLITRULE_ID  Translate a kd-tree split-rule name into its numeric code.
%   The code is the position of the name in the table below, counted from 0.
%   An unrecognized name raises an 'annerror:invalidarg' error.

rule_names = {'std', 'midpt', 'fair', 'sl_midpt', 'sl_fair', 'suggest'};
pos = find(strcmp(s, rule_names), 1);
if isempty(pos)
    error('annerror:invalidarg', ...
        'Invalid split rule %s for annsearch', s);
end
id = pos - 1;
0176
0177
function id = get_searchmethod_id(s)
% GET_SEARCHMETHOD_ID  Translate a search-method name into its numeric code.
%   'normal' maps to 0 and 'priority' to 1; any other name raises an
%   'annerror:invalidarg' error.

if strcmp(s, 'normal')
    id = 0;
elseif strcmp(s, 'priority')
    id = 1;
else
    error('annerror:invalidarg', ...
        'Invalid search method %s for annsearch', s);
end
0189
0190
0191
0192
0193
0194
0195