k-Nearest Neighbors: Who's Close to You?

If you go to college, you have probably participated in at least a couple of student organizations. I'm just starting my first semester as a graduate student at Rochester Tech, and there are more than 350 organizations here. They are sorted into different categories based on students' interests. What defines these categories, and who says which org goes into what category? I'm sure if you asked the people running these organizations, they wouldn't say that their org is just like someone else's org, but in some way you know they are similar. Fraternities and sororities share an interest in Greek life. Intramural soccer and club tennis share an interest in sports. The Latino group and the Asian American group share an interest in cultural diversity. Perhaps if you measured the events and meetings run by these orgs, you could automatically figure out what category an organization belongs to. I'll use student organizations to explain some of the concepts of k-Nearest Neighbors, arguably the simplest machine learning algorithm out there. Building the model consists only of storing the training dataset; to make a prediction for a new data point, the algorithm finds the closest data points in the training dataset: its “nearest neighbors.”

How It Works

In its simplest version, the k-NN algorithm considers exactly one nearest neighbor: the closest training data point to the point we want to make a prediction for. The prediction is then simply the known output for this training point. The figure below illustrates this for the case of classification on the forge dataset:

Here, we added three new data points, shown as stars. For each of them, we marked the closest point in the training set. The prediction of the one-nearest-neighbor algorithm is the label of that point (shown by the color of the cross).

[Figure knn-classify-1.png: predictions made by the one-nearest-neighbor model on the forge dataset]
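To make the one-nearest-neighbor idea concrete, here is a minimal sketch in NumPy; the training points, labels, and query point are made up purely for illustration. It computes the Euclidean distance from the query to every training point and returns the label of the single closest one.

import numpy as np

# Tiny made-up training set: four 2-D points with two classes
X_train = np.array([[1.0, 2.0], [2.0, 1.5], [6.0, 5.0], [7.0, 6.5]])
y_train = np.array([0, 0, 1, 1])

def one_nearest_neighbor(x, X_train, y_train):
    # Euclidean distance from x to every training point
    distances = np.sqrt(((X_train - x) ** 2).sum(axis=1))
    # The prediction is the label of the single closest training point
    return y_train[distances.argmin()]

print(one_nearest_neighbor(np.array([6.5, 5.5]), X_train, y_train))  # prints 1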

Instead of considering only the closest neighbor, we can also consider an arbitrary number, k, of neighbors. This is where the name of the k-nearest neighbors algorithm comes from. When considering more than one neighbor, we use voting to assign a label. This means that for each test point, we count how many neighbors belong to class 0 and how many neighbors belong to class 1. We then assign the class that is more frequent: in other words, the majority class among the k-nearest neighbors. The following example uses the five closest neighbors:

[Figure knn-classify-5.png: predictions made by the five-nearest-neighbors model on the forge dataset]

Again, the prediction is shown as the color of the cross. You can see that the prediction for the new data point at the top left is not the same as the prediction when we used only one neighbor.

While this illustration is for a binary classification problem, this method can be applied to datasets with any number of classes. For more classes, we count how many neighbors belong to each class and again predict the most common class.
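As a quick illustration of the voting step, here is a small sketch (the neighbor labels are made up for the example): given the labels of the k nearest neighbors, the prediction is simply the most common one.

from collections import Counter

# Made-up labels of the five nearest neighbors of some test point
neighbor_labels = ["class_1", "class_0", "class_1", "class_1", "class_0"]

# The prediction is the most frequent label among the neighbors
prediction = Counter(neighbor_labels).most_common(1)[0][0]
print(prediction)  # prints "class_1", since it received 3 of the 5 votes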

Implementation From Scratch

Here’s the pseudocode for the kNN algorithm to classify one data point (let’s call it A):

  • calculate the distance between A and every point in the dataset

  • sort the distances in increasing order

  • take the k items with the lowest distances to A

  • find the majority class among these k items

  • return that majority class as our prediction for the class of A

The Python code for the function is here:

import operator
import numpy as np

def knnclassify(A, dataset, labels, k):
    datasetSize = dataset.shape[0]

    # Calculate the Euclidean distance between A and every point in the dataset
    diffMat = np.tile(A, (datasetSize, 1)) - dataset
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5

    # Sort the distances in increasing order
    sortedDistIndices = distances.argsort()

    # Voting with the lowest k distances
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # Sort the vote counts to find the majority class among these items
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

Let’s dig a bit deeper into the code:

  • The function knnclassify takes 4 inputs: the input vector to classify called A, a full matrix of training examples called dataset, a vector of labels called labels, and k, the number of nearest neighbors to use in the voting. The labels vector should have as many elements in it as there are rows in the dataset matrix.

  • We calculate the distance between A and every point in the dataset using the Euclidean distance (the square root of the sum of the squared differences in each feature).

  • Then we sort the distances in increasing order.

  • Next, the lowest k distances are used to vote on the class of A.

  • After that, we take the classCount dictionary and decompose it into a list of (label, count) tuples, then sort the tuples by the second item, the count. The sort is done in reverse order, so we go from largest to smallest.

  • Lastly, we return the label that occurs most frequently among the k neighbors.
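To see the function in action, here is a small usage sketch; the four training points and their labels are invented purely for illustration:

import numpy as np

dataset = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ["A", "A", "B", "B"]

# The query point [0.1, 0.2] sits next to the two "B" points,
# so with k=3 the vote comes out 2-to-1 in favor of "B"
print(knnclassify(np.array([0.1, 0.2]), dataset, labels, 3))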

Implementation Via Scikit-Learn

Now let’s take a look at how we can implement the kNN algorithm using scikit-learn:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Choose the number of neighbors
clf = KNeighborsClassifier(n_neighbors=5)

# Fit the classifier on training set
clf.fit(X_train, y_train)

# Make the predictions on test set
predictions = clf.predict(X_test)
print("Test set predictions: {}".format(predictions))

# Evaluate the model
accuracy = clf.score(X_test, y_test)
print("Test set accuracy: {:.2f}".format(accuracy))

Let’s look into the code:

  • First, we load the iris dataset.

  • Then, we split our data into a training and test set to evaluate generalization performance.

  • Next, we set the number of neighbors (k) to 5.

  • Next, we fit the classifier using the training set.

  • To make predictions on the test data, we call the predict method. For each data point in the test set, the method computes its nearest neighbors in the training set and finds the most common class among them.

  • Lastly, we evaluate how well our model generalizes by calling the score method with test data and test labels.

Running the model should give us a test set accuracy of about 97%, meaning the model predicted the class correctly for 97% of the samples in the test dataset.

[Figure knn.png]
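Once the classifier is fitted, we can also classify a brand-new flower. The sketch below assumes the clf and iris objects from the code above are still in scope; the measurements themselves are made up for illustration:

import numpy as np

# Hypothetical measurements of a new iris flower:
# sepal length, sepal width, petal length, petal width (in cm)
X_new = np.array([[5.0, 2.9, 1.0, 0.2]])

prediction = clf.predict(X_new)
print("Predicted class: {}".format(iris.target_names[prediction][0]))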

Strengths and Weaknesses

In principle, there are two important parameters to the KNeighborsClassifier: the number of neighbors and how you measure the distance between data points.

  • In practice, using a small number of neighbors, such as three or five, often works well, but you should certainly adjust this parameter; a simple sweep over k is sketched after this list.

  • Choosing the right distance measure is somewhat tricky. By default, Euclidean distance is used, which works well in many settings.
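One simple way to adjust the number of neighbors is to sweep a few values of k and compare test accuracy. This sketch reuses KNeighborsClassifier and the train/test split from the scikit-learn example above; in practice you would prefer cross-validation over a single split:

for k in [1, 3, 5, 7, 9, 11]:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    print("k = {:2d}, test accuracy = {:.2f}".format(k, knn.score(X_test, y_test)))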

One of the strengths of k-NN is that the model is very easy to understand, and it often gives reasonable performance without a lot of adjustment. It is a good baseline method to try before considering more advanced techniques. Building the nearest neighbors model is usually very fast, but when your training set is very large (either in number of features or in number of samples) prediction can be slow. When using the k-NN algorithm, it's important to preprocess your data. This approach often does not perform well on datasets with many features (hundreds or more), and it does particularly badly on datasets where most features are 0 most of the time (so-called sparse datasets).
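Because k-NN is distance-based, features measured on larger scales can dominate the distance calculation, which is one reason preprocessing matters. Below is a minimal sketch of one common remedy, standardizing the features before fitting, again reusing the iris split and KNeighborsClassifier from the example above:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Scale each feature to zero mean and unit variance before computing distances
pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
pipe.fit(X_train, y_train)
print("Test set accuracy with scaling: {:.2f}".format(pipe.score(X_test, y_test)))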

In Conclusion

The k-Nearest Neighbors algorithm is a simple and effective way to classify data. It is an example of instance-based learning: you need to keep the training instances on hand in order to make predictions. Because the algorithm has to carry around the full dataset, large datasets imply a large amount of storage. In addition, you need to calculate the distance to every piece of data in the dataset, which can be cumbersome. A further drawback is that kNN doesn't give you any idea of the underlying structure of the data; you have no idea what an “average” or “exemplar” instance from each class looks like.

So, while the k-nearest neighbors algorithm is easy to understand, it is not often used in practice, due to slow prediction and its inability to handle many features.
