KNN Intuition
K-Nearest Neighbors is one of the simplest classification algorithms, and often a surprisingly strong baseline. The core idea is intuitive: to classify a new point, find the K closest points in the training set and let them vote. There is no training phase and there are no learned parameters; predictions come purely from geometry in feature space.
The algorithm in three steps:
- Store all training examples $(x_i, y_i)$
- Compute the distance from the query point $x_q$ to every training point
- Vote among the K nearest neighbors to determine the predicted class
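Before vectorizing anything, the three steps can be written out literally for a single query. The helper knn_predict_one below is a hypothetical illustration, not used later in the article: it is slow, but each line maps onto one step of the list.

```python
import torch

def knn_predict_one(X_train, y_train, x_q, k=3):
    """Hypothetical helper: literal three-step KNN for ONE query point."""
    # Step 1 is implicit: X_train and y_train are simply kept in memory.
    # Step 2: Euclidean distance from x_q to every training point.
    dists = torch.sqrt(((X_train - x_q) ** 2).sum(dim=1))  # (N,)
    # Step 3: take the k closest points and count their labels.
    nearest = torch.argsort(dists)[:k]
    votes = torch.bincount(y_train[nearest])
    return int(votes.argmax())

X_train = torch.tensor([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                        [5.0, 5.0], [5.1, 4.9], [4.8, 5.2]])
y_train = torch.tensor([0, 0, 0, 1, 1, 1])

print(knn_predict_one(X_train, y_train, torch.tensor([0.3, 0.3])))  # 0
print(knn_predict_one(X_train, y_train, torch.tensor([4.5, 4.5])))  # 1
```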
The Lazy Learning Paradigm
Unlike neural networks, which compress training data into weight matrices, KNN keeps every training example. The cost moves to inference: each prediction compares the query against all N stored points. In PyTorch, GPU-accelerated tensor operations let us batch these comparisons, so inference stays fast despite the brute-force nature of the algorithm.
flowchart LR
A[Query Point x_q] --> B[Compute Distances]
B --> C[Sort / Top-K]
C --> D[Majority Vote]
D --> E[Predicted Class]
F[Training Data] --> B
Distance Metrics
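The choice of distance metric determines what “nearest” means. Euclidean distance is dominated by large per-feature differences (they are squared), Manhattan distance weighs all differences linearly, and cosine distance ignores vector length entirely. PyTorch lets us implement all of them as vectorized tensor operations.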
Euclidean Distance
The most common distance metric, measuring straight-line distance in feature space:
$$d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2} = \|x - y\|_2$$
In PyTorch, torch.cdist computes the full matrix of pairwise distances in a single vectorized call, with no explicit loops:
import torch
# Create sample data: 1000 training points, 50 features
torch.manual_seed(42)
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50) # 5 query points
# Euclidean distance using cdist (most efficient)
distances = torch.cdist(X_query, X_train, p=2)
print(f"Distance matrix shape: {distances.shape}") # (5, 1000)
print(f"Nearest distance for query 0: {distances[0].min():.4f}")
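For p=2, torch.cdist can use the expansion $\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\,x \cdot y$, which turns the work into a matrix multiplication and avoids materializing the full (N_query, N_train, D) difference tensor. By default it switches to this path for larger inputs (controlled by its compute_mode argument). Written out by hand, the same computation looks like this: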
import torch
# Manual Euclidean distance via expansion trick
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)
# Expanded form — memory efficient for large datasets
query_sq = (X_query ** 2).sum(dim=1, keepdim=True) # (5, 1)
train_sq = (X_train ** 2).sum(dim=1, keepdim=True).T # (1, 1000)
cross_term = X_query @ X_train.T # (5, 1000)
distances_sq = query_sq + train_sq - 2 * cross_term
distances = torch.sqrt(torch.clamp(distances_sq, min=0)) # Numerical safety
print(f"Distance matrix shape: {distances.shape}") # (5, 1000)
print(f"Min distance: {distances.min():.4f}")
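As a sanity check, the three ways of computing Euclidean distances should agree. This sketch uses float64 to keep the rounding error of the expansion trick small, and it counts elements to show what the direct broadcasted version materializes:

```python
import torch

torch.manual_seed(0)
# float64 keeps the rounding error of the expansion trick small
X_train = torch.randn(200, 16, dtype=torch.float64)
X_query = torch.randn(4, 16, dtype=torch.float64)

# 1) Direct broadcasting: materializes a (4, 200, 16) difference tensor
diff = X_query[:, None, :] - X_train[None, :, :]
d_broadcast = diff.pow(2).sum(dim=-1).sqrt()

# 2) Expansion trick: only (4, 1), (200,) and (4, 200) intermediates
sq = (X_query ** 2).sum(1, keepdim=True) + (X_train ** 2).sum(1) - 2 * X_query @ X_train.T
d_expand = sq.clamp_min(0).sqrt()

# 3) Library call
d_cdist = torch.cdist(X_query, X_train, p=2)

print(torch.allclose(d_broadcast, d_expand), torch.allclose(d_broadcast, d_cdist))  # True True
print(diff.numel(), d_cdist.numel())  # 12800 800
```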
Manhattan & Minkowski Distances
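Manhattan distance (L1) sums absolute differences along each axis. Because differences are not squared, a large deviation in a single feature inflates L1 distance less than L2 distance, which makes L1 more robust to outliers in individual features: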
$$d_{L1}(x, y) = \sum_{i=1}^{n} |x_i - y_i|$$
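The general Minkowski distance unifies both with a parameter $p \ge 1$:
$$d_p(x, y) = \left(\sum_{i=1}^{n} |x_i - y_i|^p\right)^{1/p}$$
Setting $p = 1$ recovers Manhattan distance and $p = 2$ recovers Euclidean distance; torch.cdist takes $p$ directly: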
import torch
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)
# Manhattan distance (p=1)
dist_l1 = torch.cdist(X_query, X_train, p=1)
print(f"L1 distances shape: {dist_l1.shape}")
# Minkowski with p=3
dist_l3 = torch.cdist(X_query, X_train, p=3)
print(f"L3 distances shape: {dist_l3.shape}")
# Compare: L1 is more robust to outliers
X_outlier = X_train.clone()
X_outlier[0, 0] = 100.0 # Inject outlier in one feature
dist_l2_clean = torch.cdist(X_query[:1], X_train[:1], p=2)
dist_l2_outlier = torch.cdist(X_query[:1], X_outlier[:1], p=2)
dist_l1_clean = torch.cdist(X_query[:1], X_train[:1], p=1)
dist_l1_outlier = torch.cdist(X_query[:1], X_outlier[:1], p=1)
print(f"L2 ratio (outlier/clean): {(dist_l2_outlier / dist_l2_clean).item():.2f}")
print(f"L1 ratio (outlier/clean): {(dist_l1_outlier / dist_l1_clean).item():.2f}")
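To connect the Minkowski distance to torch.cdist, here is a direct broadcasted implementation, including the limit p → ∞, where the Minkowski distance becomes the Chebyshev distance $\max_i |x_i - y_i|$ (torch.cdist accepts p=float('inf')). The function minkowski is a hypothetical helper for illustration only:

```python
import torch

def minkowski(X_query, X_train, p):
    """Hypothetical helper: Minkowski distance via broadcasting."""
    diff = (X_query[:, None, :] - X_train[None, :, :]).abs()  # (Q, N, D)
    if p == float('inf'):
        return diff.amax(dim=-1)  # Chebyshev: the largest per-feature gap
    return diff.pow(p).sum(dim=-1).pow(1.0 / p)

torch.manual_seed(0)
X_train = torch.randn(100, 8, dtype=torch.float64)
X_query = torch.randn(3, 8, dtype=torch.float64)

for p in [1, 2, 3, float('inf')]:
    ok = torch.allclose(minkowski(X_query, X_train, p), torch.cdist(X_query, X_train, p=p))
    print(f"p={p}: matches cdist = {ok}")
```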
Cosine Similarity
For high-dimensional data such as sparse TF-IDF text vectors or dense embeddings, cosine similarity measures the angle between vectors rather than their magnitude:
$$\text{cosine\_sim}(x, y) = \frac{x \cdot y}{\|x\| \|y\|}$$
import torch
import torch.nn.functional as F
# Cosine distance for high-dimensional data
X_train = torch.randn(1000, 768) # Simulating embedding vectors
X_query = torch.randn(5, 768)
# Normalize then compute dot product
X_train_norm = F.normalize(X_train, dim=1)
X_query_norm = F.normalize(X_query, dim=1)
cosine_sim = X_query_norm @ X_train_norm.T # (5, 1000)
cosine_dist = 1 - cosine_sim # Convert similarity to distance
print(f"Cosine similarity range: [{cosine_sim.min():.3f}, {cosine_sim.max():.3f}]")
print(f"Most similar training point for query 0: index {cosine_sim[0].argmax()}")
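Cosine and Euclidean distance are closely linked: for unit-length vectors, $\|x - y\|^2 = 2 - 2\cos(x, y)$, so Euclidean KNN on normalized vectors selects the same neighbors as cosine KNN. This sketch checks both the identity and the resulting neighbor sets on random data:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X_train = F.normalize(torch.randn(500, 64, dtype=torch.float64), dim=1)
X_query = F.normalize(torch.randn(3, 64, dtype=torch.float64), dim=1)

cos_sim = X_query @ X_train.T
sq_euclid = torch.cdist(X_query, X_train, p=2) ** 2

# For unit vectors: ||x - y||^2 = 2 - 2 * cos(x, y)
print(torch.allclose(sq_euclid, 2 - 2 * cos_sim))

# Hence the 5 most similar points are the 5 closest points
idx_cos = cos_sim.topk(5, largest=True).indices
idx_euc = sq_euclid.topk(5, largest=False).indices
print(torch.equal(idx_cos.sort(dim=1).values, idx_euc.sort(dim=1).values))
```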
Vectorized KNN in PyTorch
A naive KNN implementation loops over queries and training points in Python and is painfully slow. By leveraging PyTorch’s tensor operations, we can build a fully vectorized KNN classifier that runs on CPU or GPU and processes thousands of queries in a single batch.
Complete KNN Implementation
import torch
import torch.nn.functional as F
class KNNClassifier:
    """K-Nearest Neighbors classifier using pure PyTorch tensors."""
    def __init__(self, k=5, metric='euclidean'):
        self.k = k
        self.metric = metric
        self.X_train = None
        self.y_train = None
        self.num_classes = None
    def fit(self, X, y):
        """Store training data (no computation during fit)."""
        self.X_train = X.float()
        self.y_train = y.long()
        # Plain Python int, so it can be used as a tensor size in predict()
        self.num_classes = int(self.y_train.max().item()) + 1
        return self
    def _compute_distances(self, X_query):
        """Compute pairwise distances between query and training points."""
        if self.metric == 'euclidean':
            return torch.cdist(X_query, self.X_train, p=2)
        elif self.metric == 'manhattan':
            return torch.cdist(X_query, self.X_train, p=1)
        elif self.metric == 'cosine':
            X_q_norm = F.normalize(X_query, dim=1)
            X_t_norm = F.normalize(self.X_train, dim=1)
            return 1 - (X_q_norm @ X_t_norm.T)
        else:
            raise ValueError(f"Unknown metric: {self.metric}")
    def predict(self, X_query):
        """Predict class labels for query points."""
        # Match the dtype and device of the stored training data
        X_query = X_query.float().to(self.X_train.device)
        distances = self._compute_distances(X_query)
        # Get indices of K nearest neighbors
        _, topk_indices = distances.topk(self.k, largest=False)  # (N_query, K)
        # Get labels of nearest neighbors
        topk_labels = self.y_train[topk_indices]  # (N_query, K)
        # Majority vote: add one vote per neighbor into its class column
        votes = torch.zeros(X_query.shape[0], self.num_classes, device=X_query.device)
        votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
        return votes.argmax(dim=1)
# Demo: 2D classification
torch.manual_seed(42)
# Generate two clusters
X_class0 = torch.randn(100, 2) + torch.tensor([2.0, 2.0])
X_class1 = torch.randn(100, 2) + torch.tensor([-2.0, -2.0])
X_train = torch.cat([X_class0, X_class1])
y_train = torch.cat([torch.zeros(100), torch.ones(100)]).long()
# Classify new points
knn = KNNClassifier(k=5, metric='euclidean')
knn.fit(X_train, y_train)
X_test = torch.tensor([[0.0, 0.0], [3.0, 3.0], [-3.0, -3.0]])
predictions = knn.predict(X_test)
print(f"Predictions: {predictions.tolist()}") # [0 or 1, 0, 1]
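With an even K the vote can tie, and argmax then returns the lowest tied class index (PyTorch documents that argmax returns the first maximal index). For binary problems an odd K avoids this. A minimal sketch of another option, assuming neighbors arrive sorted nearest-first as topk(..., largest=False) returns them, gives the single nearest neighbor an extra half vote. vote_with_tiebreak is a hypothetical helper, not part of KNNClassifier:

```python
import torch

def vote_with_tiebreak(topk_labels, num_classes):
    """Hypothetical helper: majority vote where the nearest neighbor settles ties.

    Assumes topk_labels is sorted nearest-first (topk default sorted=True).
    """
    weights = torch.ones_like(topk_labels, dtype=torch.float)
    # Half a vote extra: it can only settle a tie involving the nearest
    # neighbor's class, never overturn a strict majority
    weights[:, 0] += 0.5
    votes = torch.zeros(topk_labels.shape[0], num_classes)
    votes.scatter_add_(1, topk_labels, weights)
    return votes.argmax(dim=1)

X_train = torch.tensor([[0.1], [0.3]])
y_train = torch.tensor([1, 0])
X_query = torch.tensor([[0.0]])

_, idx = torch.cdist(X_query, X_train).topk(2, largest=False)
labels = y_train[idx]  # tensor([[1, 0]]): a one-to-one tie

plain = torch.zeros(1, 2).scatter_add_(1, labels, torch.ones(1, 2)).argmax(dim=1)
print(plain.tolist(), vote_with_tiebreak(labels, 2).tolist())  # [0] [1]
```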
Binary & Multiclass Classification
Binary Classification with Confidence
Beyond hard predictions, KNN naturally provides probability estimates based on the proportion of neighbors belonging to each class:
import torch
def knn_predict_proba(X_train, y_train, X_query, k=5):
    """KNN with probability estimates for binary classification."""
    distances = torch.cdist(X_query.float(), X_train.float(), p=2)
    _, topk_indices = distances.topk(k, largest=False)
    topk_labels = y_train[topk_indices].float()  # (N_query, K)
    # Probability = proportion of positive neighbors
    prob_positive = topk_labels.mean(dim=1)
    return prob_positive
# Create overlapping binary classification problem
torch.manual_seed(42)
X_class0 = torch.randn(200, 2) + torch.tensor([1.0, 0.0])
X_class1 = torch.randn(200, 2) + torch.tensor([-1.0, 0.0])
X_train = torch.cat([X_class0, X_class1])
y_train = torch.cat([torch.zeros(200), torch.ones(200)]).long()
# Query points at varying distances from boundary
X_query = torch.tensor([[0.0, 0.0], [2.0, 0.0], [-2.0, 0.0], [0.5, 0.0]])
probs = knn_predict_proba(X_train, y_train, X_query, k=11)
for i, (point, prob) in enumerate(zip(X_query, probs)):
    print(f"Point {point.tolist()}: P(class=1) = {prob:.3f}")
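Because these probabilities are neighbor proportions, they only take the K + 1 values 0, 1/K, ..., 1, and a unanimous neighborhood yields exactly 0 or 1, which log-loss penalizes without bound when the prediction is wrong. One common remedy, swapped in here as an assumption rather than part of the method above, is additive (Laplace) smoothing, $(n_c + \alpha)/(K + C\alpha)$ for $C$ classes and neighbor count $n_c$. knn_proba_smoothed is a hypothetical helper:

```python
import torch

def knn_proba_smoothed(topk_labels, num_classes, alpha=1.0):
    """Hypothetical helper: neighbor-count probabilities with additive smoothing."""
    counts = torch.zeros(topk_labels.shape[0], num_classes)
    counts.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
    k = topk_labels.shape[1]
    return (counts + alpha) / (k + alpha * num_classes)

# Unanimous neighborhood with k = 11: the raw estimate would be exactly 1.0
labels = torch.ones(1, 11, dtype=torch.long)
probs = knn_proba_smoothed(labels, num_classes=2)
print(probs)  # tensor([[0.0769, 0.9231]]), i.e. 1/13 and 12/13
```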
Multiclass Voting
import torch
def knn_multiclass(X_train, y_train, X_query, k=7, num_classes=10):
    """Multiclass KNN with weighted voting using inverse distance."""
    distances = torch.cdist(X_query.float(), X_train.float(), p=2)
    topk_dists, topk_indices = distances.topk(k, largest=False)
    topk_labels = y_train[topk_indices]  # (N_query, K)
    # Inverse distance weighting (closer neighbors get more weight)
    weights = 1.0 / (topk_dists + 1e-8)  # Avoid division by zero
    # Weighted vote accumulation
    votes = torch.zeros(X_query.shape[0], num_classes)
    for c in range(num_classes):
        mask = (topk_labels == c).float()
        votes[:, c] = (weights * mask).sum(dim=1)
    predictions = votes.argmax(dim=1)
    confidences = votes / votes.sum(dim=1, keepdim=True)
    return predictions, confidences
# Simulate 5-class problem
torch.manual_seed(42)
centers = torch.tensor([[0, 0], [3, 3], [-3, 3], [3, -3], [-3, -3]], dtype=torch.float)
X_train = torch.cat([torch.randn(50, 2) + c for c in centers])
y_train = torch.cat([torch.full((50,), i) for i in range(5)]).long()
X_query = torch.tensor([[0.0, 0.0], [3.0, 3.0], [0.0, 3.0]])
preds, confs = knn_multiclass(X_train, y_train, X_query, k=7, num_classes=5)
for i, (pred, conf) in enumerate(zip(preds, confs)):
    print(f"Query {i}: predicted class {pred.item()}, confidence {conf[pred]:.3f}")
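The per-class loop in knn_multiclass runs num_classes separate operations. A single scatter_add_ call, the same primitive KNNClassifier uses for unweighted votes, can accumulate the inverse-distance weights directly. This sketch uses random stand-in labels and weights to check that both produce the same vote matrix:

```python
import torch

torch.manual_seed(0)
num_classes, k = 5, 7
topk_labels = torch.randint(0, num_classes, (4, k))  # stand-in neighbor labels
weights = torch.rand(4, k) + 0.1                      # stand-in positive weights

# Loop over classes, as in knn_multiclass
votes_loop = torch.zeros(4, num_classes)
for c in range(num_classes):
    votes_loop[:, c] = (weights * (topk_labels == c).float()).sum(dim=1)

# Single scatter: add each neighbor's weight into its class column
votes_scatter = torch.zeros(4, num_classes).scatter_add_(1, topk_labels, weights)

print(torch.allclose(votes_loop, votes_scatter))  # True
```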
Effect of K & Decision Boundaries
The hyperparameter K controls model complexity. Small K yields complex, jagged decision boundaries (low bias, high variance), while large K smooths boundaries (high bias, low variance).
Visualizing Decision Boundaries
import torch
import matplotlib.pyplot as plt
def plot_knn_boundary(X_train, y_train, k, ax, title):
    """Plot KNN decision boundary on a 2D grid."""
    # Create mesh grid (bounds as Python floats for torch.linspace)
    x_min, x_max = X_train[:, 0].min().item() - 1, X_train[:, 0].max().item() + 1
    y_min, y_max = X_train[:, 1].min().item() - 1, X_train[:, 1].max().item() + 1
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, 100),
        torch.linspace(y_min, y_max, 100),
        indexing='ij'
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1)
    # Predict on grid
    distances = torch.cdist(grid, X_train.float(), p=2)
    _, topk_idx = distances.topk(k, largest=False)
    topk_labels = y_train[topk_idx]
    num_classes = int(y_train.max().item()) + 1  # Python int for the tensor size
    votes = torch.zeros(grid.shape[0], num_classes)
    votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
    preds = votes.argmax(dim=1).reshape(xx.shape)
    ax.contourf(xx.numpy(), yy.numpy(), preds.numpy(), alpha=0.3, cmap='RdYlBu')
    ax.scatter(X_train[:, 0].numpy(), X_train[:, 1].numpy(), c=y_train.numpy(),
               cmap='RdYlBu', edgecolors='k', s=20)
    ax.set_title(title)
# Generate moon-like data
torch.manual_seed(42)
n = 150
t = torch.linspace(0, 3.14, n)
X_top = torch.stack([torch.cos(t), torch.sin(t)], dim=1) + 0.2 * torch.randn(n, 2)
X_bot = torch.stack([1 - torch.cos(t), 1 - torch.sin(t) - 0.5], dim=1) + 0.2 * torch.randn(n, 2)
X_train = torch.cat([X_top, X_bot])
y_train = torch.cat([torch.zeros(n), torch.ones(n)]).long()
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, k in zip(axes, [1, 7, 50]):
    plot_knn_boundary(X_train, y_train, k, ax, f'K = {k}')
plt.tight_layout()
plt.savefig('knn_boundaries.png', dpi=100, bbox_inches='tight')
plt.show()
KNN on MNIST
Let’s put our tensor-based KNN to the test on a real dataset. MNIST handwritten digits is the classic benchmark: 28×28 grayscale images, flattened into 784-dimensional pixel vectors and classified into 10 digit classes.
GPU-Accelerated KNN
import torch
import torchvision
import torchvision.transforms as transforms
import time
# Load MNIST
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
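# Flatten to vectors. .data holds the raw uint8 images, so the ToTensor transform
# is bypassed here and pixels are scaled to [0, 1] by hand.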
X_train = train_dataset.data.float().reshape(-1, 784) / 255.0
y_train = train_dataset.targets
X_test = test_dataset.data.float().reshape(-1, 784) / 255.0
y_test = test_dataset.targets
print(f"Train: {X_train.shape}, Test: {X_test.shape}")
# Use subset for demo (full 60K x 10K is expensive on CPU)
X_train_sub = X_train[:5000]
y_train_sub = y_train[:5000]
X_test_sub = X_test[:1000]
y_test_sub = y_test[:1000]
# KNN prediction
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_tr = X_train_sub.to(device)
y_tr = y_train_sub.to(device)
X_te = X_test_sub.to(device)
k = 5
start = time.time()
# Batched inference to manage memory
batch_size = 200
all_preds = []
for i in range(0, X_te.shape[0], batch_size):
    batch = X_te[i:i+batch_size]
    dists = torch.cdist(batch, X_tr, p=2)  # (batch, N_train)
    _, topk_idx = dists.topk(k, largest=False)
    topk_labels = y_tr[topk_idx]
    votes = torch.zeros(batch.shape[0], 10, device=device)
    votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
    all_preds.append(votes.argmax(dim=1))
predictions = torch.cat(all_preds)
if device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
elapsed = time.time() - start
accuracy = (predictions == y_test_sub.to(device)).float().mean()
print(f"KNN (k={k}) accuracy on 1000 test samples: {accuracy:.4f}")
print(f"Inference time: {elapsed:.3f}s on {device}")
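The batch size above bounds memory, because each batch materializes a dense (batch_size, N_train) distance matrix. A back-of-the-envelope helper, assuming float32 (4 bytes per element) and counting only that matrix (torch.cdist may allocate additional intermediates), shows why batching matters once you move from the 5,000-image subset to all 60,000 training images. distance_matrix_bytes is hypothetical:

```python
def distance_matrix_bytes(n_query, n_train, bytes_per_elem=4):
    """Hypothetical helper: size of a dense (n_query, n_train) distance matrix."""
    return n_query * n_train * bytes_per_elem

print(distance_matrix_bytes(200, 5000) / 1e6)       # 4.0 MB per batch in the demo
print(distance_matrix_bytes(10_000, 60_000) / 1e9)  # 2.4 GB for full MNIST in one shot
print(distance_matrix_bytes(200, 60_000) / 1e6)     # 48.0 MB per batch on full MNIST
```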
Limitations & Failure Modes
Curse of Dimensionality
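As dimensionality grows, distances between random points concentrate: for i.i.d. data such as the Gaussian samples below, the ratio of farthest to nearest neighbor distance approaches 1. When every point is nearly equally far from the query, the neighbor ranking carries little signal and KNN’s distance-based decisions lose their meaning: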
import torch
torch.manual_seed(42)
# Show how distance ratios collapse in high dimensions
for dim in [2, 10, 50, 100, 500, 1000]:
    points = torch.randn(1000, dim)
    query = torch.randn(1, dim)
    dists = torch.cdist(query, points).squeeze()
    ratio = dists.max() / dists.min()
    print(f"Dim={dim:4d}: max/min distance ratio = {ratio:.3f}")
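A short calculation, assuming query and training points are independent standard Gaussian vectors as in the experiment above, shows why the ratio collapses: squared distances concentrate around their mean. For $x, y \sim \mathcal{N}(0, I_d)$, $x - y \sim \mathcal{N}(0, 2I_d)$, so $\|x - y\|^2 = 2Z$ with $Z \sim \chi^2_d$ (mean $d$, variance $2d$). Therefore
$$\mathbb{E}\,\|x - y\|^2 = 2d, \qquad \operatorname{Var}\,\|x - y\|^2 = 8d, \qquad \frac{\sqrt{\operatorname{Var}\,\|x - y\|^2}}{\mathbb{E}\,\|x - y\|^2} = \sqrt{\frac{2}{d}}$$
The relative spread of squared distances shrinks like $1/\sqrt{d}$: it equals $1$ at $d = 2$ and is about $0.045$ at $d = 1000$, so the nearest and farthest of the 1000 points end up at similar distances.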
XOR & Donut Problems
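KNN handles non-linear boundaries naturally (unlike the perceptron) as long as each region is densely sampled, but it fails when the local neighborhood is misleading. The minimal four-point XOR is such a case: every corner’s nearest neighbors belong to the opposite class. Concentric circles (the donut), by contrast, are handled well with an appropriate K: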
import torch
# Minimal XOR: only the four corner points
X_xor = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
y_xor = torch.tensor([0, 1, 1, 0])
# Leave-one-out 1-NN: each point is classified by its nearest OTHER point
dists = torch.cdist(X_xor, X_xor, p=2)
dists.fill_diagonal_(float('inf'))  # Exclude self
nn_idx = dists.argmin(dim=1)
print(f"True labels:            {y_xor.tolist()}")  # [0, 1, 1, 0]
print(f"1-NN (LOO) predictions: {y_xor[nn_idx].tolist()}")  # [1, 0, 0, 1]
# Every prediction is wrong: each corner's neighbors at distance 1 carry the opposite label
# Donut problem — concentric circles
torch.manual_seed(42)
n = 200
angles = torch.rand(n) * 2 * 3.14159
r_inner = 1.0 + 0.2 * torch.randn(n)
r_outer = 3.0 + 0.2 * torch.randn(n)
X_inner = torch.stack([r_inner * torch.cos(angles), r_inner * torch.sin(angles)], dim=1)
X_outer = torch.stack([r_outer * torch.cos(angles), r_outer * torch.sin(angles)], dim=1)
X_donut = torch.cat([X_inner, X_outer])
y_donut = torch.cat([torch.zeros(n), torch.ones(n)]).long()
# KNN handles this well with appropriate K
query = torch.tensor([[0.0, 0.0], [2.0, 0.0], [3.5, 0.0]])
dists = torch.cdist(query, X_donut, p=2)
_, topk = dists.topk(7, largest=False)
preds = y_donut[topk]
for i, (q, p) in enumerate(zip(query, preds)):
    p_outer = p.float().mean()  # Fraction of the 7 neighbors from the outer ring
    print(f"Query {q.tolist()}: neighbors vote = {p.tolist()}, P(outer) = {p_outer:.2f}")
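The four-point XOR failure comes from having only one point per region, not from XOR itself. A sketch assuming 50 noisy samples (standard deviation 0.1) around each corner shows the same XOR labels becoming easy for KNN with k = 5:

```python
import torch

torch.manual_seed(0)
corners = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
corner_labels = torch.tensor([0, 1, 1, 0])

# 50 noisy samples around each corner: the same XOR labels, densely sampled
X = torch.cat([c + 0.1 * torch.randn(50, 2) for c in corners])
y = corner_labels.repeat_interleave(50)

# Classify the four corners themselves with k = 5 (majority vote)
_, idx = torch.cdist(corners, X).topk(5, largest=False)
votes = torch.zeros(4, 2).scatter_add_(1, y[idx], torch.ones(4, 5))
print(votes.argmax(dim=1).tolist())  # [0, 1, 1, 0]
```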
Bias-Variance Tradeoff
KNN offers a clear illustration of the bias-variance tradeoff controlled by a single knob: K.
K vs. Accuracy Curve
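As K increases from 1 to N, the model moves between two extremes. At K = 1 every training point is its own nearest neighbor, so training error is zero (barring duplicate points with different labels): low bias, high variance. At K = N every query sees all N training points and receives the majority class, a constant predictor with high bias and zero variance. The optimal K balances these competing errors.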
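Both extremes can be checked directly. This sketch assumes random labels, so there is nothing real to learn, and predicts the training set from itself (each point counts as one of its own neighbors). compute_mode='donot_use_mm_for_euclid_dist' is a real torch.cdist option that computes differences directly, keeping self-distances exactly zero; train_predictions is a hypothetical helper:

```python
import torch

torch.manual_seed(0)
X = torch.randn(60, 2)
y = (torch.rand(60) < 0.4).long()  # random labels: nothing learnable

def train_predictions(k):
    """Hypothetical helper: predict the training set with itself as the reference set."""
    dists = torch.cdist(X, X, compute_mode='donot_use_mm_for_euclid_dist')
    _, idx = dists.topk(k, largest=False)
    votes = torch.zeros(len(X), 2).scatter_add_(1, y[idx], torch.ones(len(X), k))
    return votes.argmax(dim=1)

acc_k1 = (train_predictions(1) == y).float().mean().item()
preds_kN = train_predictions(len(X))
majority = torch.bincount(y).argmax()

print(acc_k1)                              # 1.0: zero training error
print(bool((preds_kN == majority).all()))  # True: a constant predictor
```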
Finding Optimal K with Cross-Validation
import torch
torch.manual_seed(42)
# Generate noisy 2-class data
X_class0 = torch.randn(300, 2) + torch.tensor([1.5, 0.0])
X_class1 = torch.randn(300, 2) + torch.tensor([-1.5, 0.0])
X_all = torch.cat([X_class0, X_class1])
y_all = torch.cat([torch.zeros(300), torch.ones(300)]).long()
# Shuffle
perm = torch.randperm(600)
X_all, y_all = X_all[perm], y_all[perm]
# K-fold cross-validation
n_folds = 5
fold_size = len(X_all) // n_folds
k_values = [1, 3, 5, 7, 11, 15, 21, 31, 51]
results = {}
for k in k_values:
    fold_accs = []
    for fold in range(n_folds):
        val_start = fold * fold_size
        val_end = val_start + fold_size
        X_val = X_all[val_start:val_end]
        y_val = y_all[val_start:val_end]
        X_tr = torch.cat([X_all[:val_start], X_all[val_end:]])
        y_tr = torch.cat([y_all[:val_start], y_all[val_end:]])
        # Predict
        dists = torch.cdist(X_val, X_tr, p=2)
        _, topk_idx = dists.topk(k, largest=False)
        topk_labels = y_tr[topk_idx]
        votes = torch.zeros(X_val.shape[0], 2)
        votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
        preds = votes.argmax(dim=1)
        acc = (preds == y_val).float().mean().item()
        fold_accs.append(acc)
    results[k] = sum(fold_accs) / len(fold_accs)
    print(f"K={k:2d}: CV accuracy = {results[k]:.4f}")
best_k = max(results, key=results.get)
print(f"\nBest K = {best_k} with accuracy = {results[best_k]:.4f}")
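Because fitting KNN is just storing data, leave-one-out cross-validation is cheap: one distance matrix with an infinite diagonal (so no point votes for itself) serves every held-out point, and since topk returns neighbors nearest-first, a single topk at the largest K also covers every smaller K. loo_accuracy is a hypothetical helper sketched under those assumptions, run on data generated the same way as above:

```python
import torch

def loo_accuracy(X, y, k_values, num_classes):
    """Hypothetical helper: leave-one-out accuracy for several K at once."""
    dists = torch.cdist(X, X)
    dists.fill_diagonal_(float('inf'))         # a point may not vote for itself
    max_k = max(k_values)
    _, idx = dists.topk(max_k, largest=False)  # nearest-first, shared by every K
    neighbor_labels = y[idx]                   # (N, max_k)
    results = {}
    for k in k_values:
        votes = torch.zeros(len(X), num_classes)
        votes.scatter_add_(1, neighbor_labels[:, :k], torch.ones(len(X), k))
        results[k] = (votes.argmax(dim=1) == y).float().mean().item()
    return results

torch.manual_seed(42)
X = torch.cat([torch.randn(300, 2) + torch.tensor([1.5, 0.0]),
               torch.randn(300, 2) + torch.tensor([-1.5, 0.0])])
y = torch.cat([torch.zeros(300), torch.ones(300)]).long()
scores = loo_accuracy(X, y, [1, 3, 5, 7, 11, 15, 21, 31, 51], num_classes=2)
best_k = max(scores, key=scores.get)
print(f"Best K = {best_k} with LOO accuracy = {scores[best_k]:.4f}")
```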