
KNearestNeighbors

Nearest Neighbor classifiers

Nearest-neighbor rules for classification are among the most intuitive classifiers. A nearest-neighbor rule simply memorizes (stores) all of the training data; when a new test point is given, it compares the new point to all of the training examples, finds the nearest one, and predicts that the new point has the same class as that nearest example.
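
In code, this rule reduces to a distance computation and a minimum. The following is only a small sketch (separate from the class given later on this page), assuming hypothetical variables Xtrain (Ntrain x Nfeatures), Ytrain (Ntrain x 1), and a single test point xtest (1 x Nfeatures):

 dist = sum( bsxfun(@minus, Xtrain, xtest).^2, 2 );  % squared distance to every training example
 [tmp, nearest] = min(dist);                         % index of the closest training example
 ypred = Ytrain(nearest);                            % predict its class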

A k-nearest-neighbor rule is a simple extension: instead of finding the single nearest training example, we find the nearest k of them and predict the majority class value, i.e., the class to which most of those k examples belong.
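
Continuing the sketch above, the k-nearest-neighbor extension replaces the minimum with a sort and a majority vote (here using Matlab's mode to pick the most common label; the value of K is an arbitrary choice for illustration):

 K = 3;                                      % number of neighbors to use
 [tmp, order] = sort(dist);                  % order training examples by distance to xtest
 neighborY = Ytrain(order(1:K));             % classes of the K nearest training examples
 ypred = mode(neighborY);                    % predict the class held by most of them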

The following figures show several classifiers as a function of k, the number of neighbors used.

(Figures: kNN classifiers for k=1, k=3, and k=5.)

Decision boundary

An alternative representation of any classifier is its decision boundary: the set of locations at which its prediction changes from one class to another.

Since the decision of a kNN classifier is determined by the nearest training examples, its decision boundary consists of those locations at which the set of nearest training examples changes by (at least) one example. A location on the boundary is balanced between the two possible sets: it must be equidistant from two training examples, and since swapping them must change the predicted class, those two examples must belong to different classes. For any two points, the set of equidistant locations is a line perpendicular to the line segment joining the two points.
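
As a quick numerical check of that last claim, the following sketch (with two made-up points x1 and x2) constructs the perpendicular line through their midpoint and verifies that points on it are equidistant from both:

 x1 = [0 0];  x2 = [2 1];                        % two (hypothetical) training points
 m  = (x1 + x2)/2;                               % the midpoint lies on the boundary
 v  = [-(x2(2)-x1(2)), x2(1)-x1(1)];             % direction perpendicular to (x2 - x1)
 t  = linspace(-2, 2, 5)';
 pts = bsxfun(@plus, m, t*v);                    % several points along the candidate line
 d1 = sqrt(sum(bsxfun(@minus, pts, x1).^2, 2));  % distance from each point to x1
 d2 = sqrt(sum(bsxfun(@minus, pts, x2).^2, 2));  % distance from each point to x2
 max(abs(d1 - d2))                               % approximately zero: all are equidistant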

This means the decision boundary is locally linear (piecewise linear); this is easy to see with a small enough set of training data. By a similar argument, the transitions between linear segments must occur at locations equidistant from at least three training points.
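
One way to see the piecewise-linear boundary is to evaluate a 1-nearest-neighbor classifier on a fine grid and plot its decision regions. This is only a sketch, assuming 2-D features X (N x 2) with labels Y (N x 1) taking values 0 and 1:

 [xs, ys] = meshgrid(linspace(min(X(:,1)), max(X(:,1)), 200), ...
                     linspace(min(X(:,2)), max(X(:,2)), 200));
 grid2 = [xs(:), ys(:)];                              % all grid locations, one per row
 yhat = zeros(size(grid2,1), 1);
 for i=1:size(grid2,1),
   d = sum( bsxfun(@minus, X, grid2(i,:)).^2, 2 );    % squared distance to each training point
   [tmp, j] = min(d);                                 % index of the nearest training point
   yhat(i) = Y(j);                                    % predict its class at this location
 end;
 contourf(xs, ys, reshape(yhat, size(xs)));           % filled regions; the edges are piecewise linear
 hold on; plot(X(Y==0,1), X(Y==0,2), 'bo', X(Y==1,1), X(Y==1,2), 'rs'); hold off;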

Complexity and K

The complexity of a k-nearest-neighbor rule is a bit difficult to describe precisely. Nearest-neighbor methods store all of the training examples and do not have a simple notion of complexity like the one we saw in linear regression. However, consider how the training error of a k-nearest-neighbor classifier changes with k. At k=1, evaluating the prediction at any training point finds a point at distance zero (the point itself), so the correct class is guaranteed to be predicted. As k increases, this may no longer be true: a training point of class 0 that is completely surrounded by points of class 1 is unlikely to be classified correctly. As the following images suggest, increasing k further makes the classifier less able to memorize the data, reducing the complexity of the learner.

(Figures: kNN classifiers for k=1, k=3, k=5, k=11, and k=21.)
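
A simple way to see this effect numerically is to compute the training error rate of the k-NN rule for several values of k. The following sketch assumes the same (hypothetical) Xtrain and Ytrain as above; at k=1 the training error is zero, and it typically grows as k increases:

 for K = [1 3 5 11 21],
   nErr = 0;
   for i = 1:size(Xtrain,1),
     d = sum( bsxfun(@minus, Xtrain, Xtrain(i,:)).^2, 2 );  % distances to all training points
     [tmp, order] = sort(d);                                % the point itself sorts first (distance zero)
     yhat = mode( Ytrain(order(1:K)) );                     % majority vote over the K nearest
     nErr = nErr + (yhat ~= Ytrain(i));                     % count mistakes on the training set
   end;
   fprintf('K = %2d : training error rate = %.3f\n', K, nErr/size(Xtrain,1));
 end;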

Example code

Here is some example k-nearest-neighbor code in Matlab.

 classdef knnClassify < supLearn
 % Class implementing K-nearest-neighbors classifier
 properties (SetAccess=private, GetAccess=private)
   K=1;
   Xtrain=[];    % Training features, Ndata x Nfeatures
   Ytrain=[];    % Training classes,  Ndata x 1
 end;
 methods
   % Constructor (takes zero arguments or 3)
   function obj = knnClassify(K,Xtr,Ytr)
     if (nargin > 0) obj.K = K; obj.Xtrain = Xtr; obj.Ytrain = Ytr; end;
   end
   % set parameter K if desired
   function obj=setK(obj, K)
     obj.K = K;
   end
   % Batch training: just memorize data
   function obj=train(obj, Xtr,Ytr)
     obj.Xtrain = Xtr; obj.Ytrain = Ytr;
   end
   % Test function: predict on Xtest
   function Yte = predict(obj,Xte)
     [Ntr,Mtr] = size(obj.Xtrain);           % get size of training data
     [Nte,Mte] = size(Xte);                  % and of test data
     classes = unique(obj.Ytrain);           % figure out how many classes & their labels
     Yte = repmat(obj.Ytrain(1), [Nte,1]);   % make Ytest the same data type as Ytrain
     K = min(obj.K, Ntr);                    % can't have more than Ntrain neighbors
     for i=1:Nte,                            % For each test example:
       dist = sum( bsxfun( @minus, obj.Xtrain, Xte(i,:) ).^2 , 2);  % compute sum of squared differences
       [tmp,idx] = sort(dist);               % find nearest neighbors over Xtrain
       cMax=1; NcMax=0;                      % then find the majority class within that set of neighbors
       for c=1:length(classes),
         Nc = sum(obj.Ytrain(idx(1:K))==classes(c));   % count up how many instances of that class we have
         if (Nc>NcMax), cMax=c; NcMax=Nc; end;         % save the largest count and its class id
       end;
       Yte(i)=classes(cMax);                 % save results
     end;
   end
 end % end methods
 end % end class
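
A short usage sketch for the class above (assuming the supLearn base class is on the path, and that Xtr, Ytr, and Xte hold training features, training labels, and test features):

 learner = knnClassify(5, Xtr, Ytr);   % construct a 5-nearest-neighbor classifier
 Yhat    = learner.predict(Xte);       % predicted classes for the test points
 learner = learner.setK(1);            % change K ...
 Yhat1   = learner.predict(Xte);       % ... and predict again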