K-Nearest Neighbors in PyTorch

May 29, 2026 Wasil Zafar 30 min read

Implement K-Nearest Neighbors from scratch using pure PyTorch tensor operations — vectorized distance computation, batched inference, decision boundary visualization, and MNIST classification without scikit-learn.

Table of Contents

  1. KNN Intuition
  2. Distance Metrics
  3. Vectorized KNN in PyTorch
  4. Binary & Multiclass Classification
  5. Effect of K & Decision Boundaries
  6. KNN on MNIST
  7. Limitations & Failure Modes
  8. Bias-Variance Tradeoff
  9. Related Articles

KNN Intuition

K-Nearest Neighbors is one of the simplest classification algorithms, and it is often a surprisingly strong baseline. The core idea is intuitive: to classify a new point, find the K closest points in the training set and let them vote. There is no training phase and there are no learned parameters, just raw geometry in feature space.

Key Insight: KNN is a non-parametric, instance-based learner. It stores the entire training set and defers all computation to prediction time. This makes it a “lazy learner” — it does zero work during training but pays the cost at inference.

The algorithm in three steps:

  1. Store all training examples $(x_i, y_i)$
  2. Compute the distance from the query point $x_q$ to every training point
  3. Vote among the K nearest neighbors to determine the predicted class
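The three steps translate almost line by line into code. The sketch below is a deliberately naive, loop-based version for a single query point. The function name `knn_predict_one` and the toy points are illustrative assumptions, and the Python loop is exactly what the vectorized implementation later in this article removes.

```python
import torch
from collections import Counter


def knn_predict_one(X_train, y_train, x_q, k=3):
    """Classify one query point with the three KNN steps (loop-based sketch)."""
    # Step 1: (X_train, y_train) is simply stored; nothing is fitted.
    # Step 2: distance from the query to every training point, one at a time
    dists = [torch.dist(x_q, x_i).item() for x_i in X_train]
    # Step 3: take the K closest points and let their labels vote
    nearest = sorted(range(len(dists)), key=lambda i: dists[i])[:k]
    votes = Counter(int(y_train[i]) for i in nearest)
    return votes.most_common(1)[0][0]


# Two well-separated toy clusters
X_train = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                        [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y_train = torch.tensor([0, 0, 0, 1, 1, 1])

print(knn_predict_one(X_train, y_train, torch.tensor([0.5, 0.5])))  # 0
print(knn_predict_one(X_train, y_train, torch.tensor([5.5, 5.2])))  # 1
```

Because every query scans the whole training set element by element in Python, this version is only practical for tiny datasets.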

The Lazy Learning Paradigm

Unlike neural networks, which compress the training data into weight matrices, KNN keeps every training example. Inference cost therefore grows with the size of the training set, but in PyTorch we can use GPU-accelerated tensor operations to make this brute-force search fast.

KNN Classification Pipeline
flowchart LR
    A[Query Point x_q] --> B[Compute Distances]
    B --> C[Sort / Top-K]
    C --> D[Majority Vote]
    D --> E[Predicted Class]
    F[Training Data] --> B
                            

Distance Metrics

The choice of distance metric defines what “nearest” means. Metrics differ in how they treat feature scale, outliers, and vector magnitude, and PyTorch lets us implement all of them as vectorized tensor operations.

Euclidean Distance

The most common distance metric, measuring straight-line distance in feature space:

$$d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2} = \|x - y\|_2$$

In PyTorch, torch.cdist computes all pairwise distances between two sets of points in a single vectorized call, with no explicit loops:

import torch

# Create sample data: 1000 training points, 50 features
torch.manual_seed(42)
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)  # 5 query points

# Euclidean distance using cdist (most efficient)
distances = torch.cdist(X_query, X_train, p=2)
print(f"Distance matrix shape: {distances.shape}")  # (5, 1000)
print(f"Nearest distance for query 0: {distances[0].min():.4f}")

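For p=2, torch.cdist can compute distances through the expansion $\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\,x \cdot y$ (its compute_mode argument controls when), which avoids materializing the full (N_query, N_train, D) difference tensor. Here is the same trick written out by hand: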

import torch

# Manual Euclidean distance via expansion trick
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)

# Expanded form — memory efficient for large datasets
query_sq = (X_query ** 2).sum(dim=1, keepdim=True)   # (5, 1)
train_sq = (X_train ** 2).sum(dim=1, keepdim=True).T  # (1, 1000)
cross_term = X_query @ X_train.T                       # (5, 1000)

distances_sq = query_sq + train_sq - 2 * cross_term
distances = torch.sqrt(torch.clamp(distances_sq, min=0))  # Numerical safety

print(f"Distance matrix shape: {distances.shape}")  # (5, 1000)
print(f"Min distance: {distances.min():.4f}")
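To check that the expansion matches a direct computation, the sketch below compares three versions on random data: explicit broadcasting, the expansion trick, and `torch.cdist`. The tolerance of `1e-3` is an assumption chosen to absorb float32 round-off in the expansion.

```python
import torch

torch.manual_seed(0)
X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)

# 1) Broadcasting: materializes a (5, 1000, 50) difference tensor
diff = X_query[:, None, :] - X_train[None, :, :]
d_broadcast = diff.pow(2).sum(dim=2).sqrt()

# 2) Expansion trick: only (5, 1000)-sized intermediates
d_expand = (
    (X_query ** 2).sum(dim=1, keepdim=True)   # (5, 1)
    + (X_train ** 2).sum(dim=1)               # (1000,), broadcasts to (5, 1000)
    - 2 * X_query @ X_train.T                 # (5, 1000)
).clamp(min=0).sqrt()

# 3) Library call
d_cdist = torch.cdist(X_query, X_train, p=2)

print(f"Broadcast intermediate elements: {diff.numel():,}")  # 250,000
print(torch.allclose(d_broadcast, d_cdist, atol=1e-3))
print(torch.allclose(d_expand, d_cdist, atol=1e-3))
```

The broadcast version is the one to avoid at scale: for 10,000 MNIST test images against 60,000 training images of 784 pixels, its difference tensor would hold about 4.7 × 10^11 float32 values (roughly 1.9 TB).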

Manhattan & Minkowski Distances

Manhattan distance (L1) sums absolute differences along each axis. Because differences are not squared, a single large deviation in one feature inflates it less than it inflates Euclidean distance:

$$d_{L1}(x, y) = \sum_{i=1}^{n} |x_i - y_i|$$

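The general Minkowski distance unifies both with a parameter $p$; $p = 1$ gives Manhattan distance and $p = 2$ gives Euclidean distance:

$$d_p(x, y) = \left(\sum_{i=1}^{n} |x_i - y_i|^p\right)^{1/p}$$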

import torch

X_train = torch.randn(1000, 50)
X_query = torch.randn(5, 50)

# Manhattan distance (p=1)
dist_l1 = torch.cdist(X_query, X_train, p=1)
print(f"L1 distances shape: {dist_l1.shape}")

# Minkowski with p=3
dist_l3 = torch.cdist(X_query, X_train, p=3)
print(f"L3 distances shape: {dist_l3.shape}")

# Compare: L1 is more robust to outliers
X_outlier = X_train.clone()
X_outlier[0, 0] = 100.0  # Inject outlier in one feature

dist_l2_clean = torch.cdist(X_query[:1], X_train[:1], p=2)
dist_l2_outlier = torch.cdist(X_query[:1], X_outlier[:1], p=2)
dist_l1_clean = torch.cdist(X_query[:1], X_train[:1], p=1)
dist_l1_outlier = torch.cdist(X_query[:1], X_outlier[:1], p=1)

print(f"L2 ratio (outlier/clean): {(dist_l2_outlier / dist_l2_clean).item():.2f}")
print(f"L1 ratio (outlier/clean): {(dist_l1_outlier / dist_l1_clean).item():.2f}")
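The Minkowski formula $d_p(x, y) = \left(\sum_i |x_i - y_i|^p\right)^{1/p}$ can be evaluated directly with broadcasting and compared against `torch.cdist`. The helper `minkowski` below is a hypothetical name used only for this check; it materializes a (Q, N, D) tensor, so it is meant for small inputs.

```python
import torch


def minkowski(X_query, X_train, p):
    """Minkowski distance straight from the formula (small inputs only)."""
    diff = (X_query[:, None, :] - X_train[None, :, :]).abs()  # (Q, N, D)
    return diff.pow(p).sum(dim=2).pow(1.0 / p)                # (Q, N)


torch.manual_seed(0)
X_train = torch.randn(200, 10)
X_query = torch.randn(3, 10)

for p in [1, 2, 3]:
    manual = minkowski(X_query, X_train, p)
    library = torch.cdist(X_query, X_train, p=p)
    print(f"p={p}: max abs difference = {(manual - library).abs().max():.2e}")
```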

Cosine Similarity

For high-dimensional data such as text vectors and embeddings, where the direction of a vector matters more than its magnitude, cosine similarity measures the angle between vectors:

$$\text{cosine\_sim}(x, y) = \frac{x \cdot y}{\|x\| \|y\|}$$

import torch
import torch.nn.functional as F

# Cosine distance for high-dimensional data
X_train = torch.randn(1000, 768)  # Simulating embedding vectors
X_query = torch.randn(5, 768)

# Normalize then compute dot product
X_train_norm = F.normalize(X_train, dim=1)
X_query_norm = F.normalize(X_query, dim=1)

cosine_sim = X_query_norm @ X_train_norm.T  # (5, 1000)
cosine_dist = 1 - cosine_sim  # Convert similarity to distance

print(f"Cosine similarity range: [{cosine_sim.min():.3f}, {cosine_sim.max():.3f}]")
print(f"Most similar training point for query 0: index {cosine_sim[0].argmax()}")
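A useful identity connects cosine and Euclidean distance: for unit-length vectors, $\|a - b\|^2 = 2 - 2\cos(a, b)$, so squared Euclidean distance is exactly twice the cosine distance. The sketch below checks this numerically on random 768-dimensional vectors (the tolerance is an assumption for float32).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X_train = torch.randn(1000, 768)
X_query = torch.randn(5, 768)

X_train_norm = F.normalize(X_train, dim=1)
X_query_norm = F.normalize(X_query, dim=1)

cosine_dist = 1 - X_query_norm @ X_train_norm.T          # (5, 1000)
euclid_sq = torch.cdist(X_query_norm, X_train_norm) ** 2  # (5, 1000)

# For unit vectors: ||a - b||^2 = 2 - 2 cos(a, b) = 2 * cosine_dist
print(torch.allclose(euclid_sq, 2 * cosine_dist, atol=1e-4))
```

In practice this means Euclidean KNN code can serve cosine-based retrieval: normalize the vectors first, and the neighbor ranking is the same.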

Vectorized KNN in PyTorch

A naive KNN implementation that loops over query and training points in Python is slow. By using PyTorch tensor operations instead, we can build a fully vectorized KNN classifier that runs on GPU and processes thousands of queries at once.

Complete KNN Implementation

import torch
import torch.nn.functional as F


class KNNClassifier:
    """K-Nearest Neighbors classifier using pure PyTorch tensors."""

    def __init__(self, k=5, metric='euclidean'):
        self.k = k
        self.metric = metric
        self.X_train = None
        self.y_train = None

    def fit(self, X, y):
        """Store training data (no computation during fit)."""
        self.X_train = X.float()
        self.y_train = y.long()
        return self

    def _compute_distances(self, X_query):
        """Compute pairwise distances between query and training points."""
        if self.metric == 'euclidean':
            return torch.cdist(X_query, self.X_train, p=2)
        elif self.metric == 'manhattan':
            return torch.cdist(X_query, self.X_train, p=1)
        elif self.metric == 'cosine':
            X_q_norm = F.normalize(X_query, dim=1)
            X_t_norm = F.normalize(self.X_train, dim=1)
            return 1 - (X_q_norm @ X_t_norm.T)
        else:
            raise ValueError(f"Unknown metric: {self.metric}")

    def predict(self, X_query):
        """Predict class labels for query points."""
        X_query = X_query.float()
        distances = self._compute_distances(X_query)

        # Get indices of K nearest neighbors
        _, topk_indices = distances.topk(self.k, largest=False)  # (N_query, K)

        # Get labels of nearest neighbors
        topk_labels = self.y_train[topk_indices]  # (N_query, K)

        # Majority vote
        num_classes = int(self.y_train.max().item()) + 1  # Python int for the vote table size
        votes = torch.zeros(X_query.shape[0], num_classes, device=X_query.device)
        votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))

        return votes.argmax(dim=1)


# Demo: 2D classification
torch.manual_seed(42)
# Generate two clusters
X_class0 = torch.randn(100, 2) + torch.tensor([2.0, 2.0])
X_class1 = torch.randn(100, 2) + torch.tensor([-2.0, -2.0])
X_train = torch.cat([X_class0, X_class1])
y_train = torch.cat([torch.zeros(100), torch.ones(100)]).long()

# Classify new points
knn = KNNClassifier(k=5, metric='euclidean')
knn.fit(X_train, y_train)

X_test = torch.tensor([[0.0, 0.0], [3.0, 3.0], [-3.0, -3.0]])
predictions = knn.predict(X_test)
print(f"Predictions: {predictions.tolist()}")  # [0 or 1, 0, 1]

Binary & Multiclass Classification

Binary Classification with Confidence

Beyond hard predictions, KNN gives a simple probability estimate: the fraction of the K neighbors that belong to each class. For binary 0/1 labels, that is just the mean of the neighbor labels:

import torch


def knn_predict_proba(X_train, y_train, X_query, k=5):
    """KNN with probability estimates for binary classification."""
    distances = torch.cdist(X_query.float(), X_train.float(), p=2)
    _, topk_indices = distances.topk(k, largest=False)
    topk_labels = y_train[topk_indices].float()  # (N_query, K)

    # Probability = proportion of positive neighbors
    prob_positive = topk_labels.mean(dim=1)
    return prob_positive


# Create overlapping binary classification problem
torch.manual_seed(42)
X_class0 = torch.randn(200, 2) + torch.tensor([1.0, 0.0])
X_class1 = torch.randn(200, 2) + torch.tensor([-1.0, 0.0])
X_train = torch.cat([X_class0, X_class1])
y_train = torch.cat([torch.zeros(200), torch.ones(200)]).long()

# Query points at varying distances from boundary
X_query = torch.tensor([[0.0, 0.0], [2.0, 0.0], [-2.0, 0.0], [0.5, 0.0]])
probs = knn_predict_proba(X_train, y_train, X_query, k=11)

for i, (point, prob) in enumerate(zip(X_query, probs)):
    print(f"Point {point.tolist()}: P(class=1) = {prob:.3f}")
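`knn_predict_proba` returns only P(class=1). For more than two classes, one option is to one-hot encode the neighbor labels and average them, giving the fraction of neighbors in each class. `knn_class_proportions` is a hypothetical helper name, and the three-cluster data is made up for the demo.

```python
import torch
import torch.nn.functional as F


def knn_class_proportions(X_train, y_train, X_query, k=5, num_classes=None):
    """Per-class neighbor proportions: a multiclass analogue of knn_predict_proba."""
    if num_classes is None:
        num_classes = int(y_train.max().item()) + 1
    distances = torch.cdist(X_query.float(), X_train.float(), p=2)
    _, topk_indices = distances.topk(k, largest=False)
    topk_labels = y_train[topk_indices]                     # (N_query, K)
    one_hot = F.one_hot(topk_labels, num_classes).float()   # (N_query, K, C)
    return one_hot.mean(dim=1)                              # (N_query, C)


torch.manual_seed(42)
centers = torch.tensor([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
X_train = torch.cat([torch.randn(50, 2) + c for c in centers])
y_train = torch.cat([torch.full((50,), i) for i in range(3)]).long()

X_query = torch.tensor([[0.0, 0.0], [2.0, 0.0]])
probs = knn_class_proportions(X_train, y_train, X_query, k=10)
print(probs)              # one row of class proportions per query
print(probs.sum(dim=1))   # each row sums to 1
```

The estimates are multiples of 1/K (here 0.1), so a small K gives coarse probabilities.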

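Multiclass Voting

With more than two classes, each neighbor votes for its own label. Weighting each vote by inverse distance lets closer neighbors count for more and makes exact ties unlikely: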

import torch


def knn_multiclass(X_train, y_train, X_query, k=7, num_classes=10):
    """Multiclass KNN with weighted voting using inverse distance."""
    distances = torch.cdist(X_query.float(), X_train.float(), p=2)
    topk_dists, topk_indices = distances.topk(k, largest=False)

    topk_labels = y_train[topk_indices]  # (N_query, K)

    # Inverse distance weighting (closer neighbors get more weight)
    weights = 1.0 / (topk_dists + 1e-8)  # Avoid division by zero

    # Weighted vote accumulation: add each neighbor's weight to its class column
    votes = torch.zeros(X_query.shape[0], num_classes, device=weights.device)
    votes.scatter_add_(1, topk_labels, weights)

    predictions = votes.argmax(dim=1)
    confidences = votes / votes.sum(dim=1, keepdim=True)
    return predictions, confidences


# Simulate 5-class problem
torch.manual_seed(42)
centers = torch.tensor([[0, 0], [3, 3], [-3, 3], [3, -3], [-3, -3]], dtype=torch.float)
X_train = torch.cat([torch.randn(50, 2) + c for c in centers])
y_train = torch.cat([torch.full((50,), i) for i in range(5)]).long()

X_query = torch.tensor([[0.0, 0.0], [3.0, 3.0], [0.0, 3.0]])
preds, confs = knn_multiclass(X_train, y_train, X_query, k=7, num_classes=5)

for i, (pred, conf) in enumerate(zip(preds, confs)):
    print(f"Query {i}: predicted class {pred.item()}, confidence {conf[pred]:.3f}")
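Inverse-distance weighting changes the outcome only when close and far neighbors disagree. The toy configuration below (made up for illustration) has K = 3 with one class-0 neighbor at distance 0.1 and two class-1 neighbors at distance 1.0: a plain majority vote picks class 1, while the weighted vote (about 10 versus 2) picks class 0.

```python
import torch

# One class-0 point very close to the query, two class-1 points farther away
X_train = torch.tensor([[0.1, 0.0], [1.0, 0.0], [0.0, 1.0]])
y_train = torch.tensor([0, 1, 1])
x_q = torch.tensor([[0.0, 0.0]])

dists = torch.cdist(x_q, X_train)  # distances 0.1, 1.0, 1.0
topk_dists, topk_idx = dists.topk(3, largest=False)
labels = y_train[topk_idx]

# Same scatter_add_ voting, with uniform vs inverse-distance weights
uniform = torch.zeros(1, 2).scatter_add_(1, labels, torch.ones_like(topk_dists))
weighted = torch.zeros(1, 2).scatter_add_(1, labels, 1.0 / (topk_dists + 1e-8))

print(uniform.argmax(dim=1).item())   # 1: two votes beat one
print(weighted.argmax(dim=1).item())  # 0: weight ~10 beats 1 + 1
```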

Effect of K & Decision Boundaries

The hyperparameter K controls model complexity. Small K yields complex, jagged decision boundaries (low bias, high variance), while large K smooths boundaries (high bias, low variance).

Common Pitfall: Using K=1 memorizes the training data perfectly but is extremely sensitive to noise. A single mislabeled point creates a “pocket” in the decision boundary that misclassifies nearby test points.
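The pocket effect is easy to reproduce. This sketch assumes a clean grid of points labeled by the sign of the x-coordinate, flips one label inside class 0, and queries a test point next to the flipped point. The helper `knn_predict` and the grid are illustrative.

```python
import torch


def knn_predict(X_train, y_train, X_query, k):
    """Majority-vote KNN for binary labels (same scatter_add_ voting as above)."""
    dists = torch.cdist(X_query, X_train)
    _, idx = dists.topk(k, largest=False)
    labels = y_train[idx]
    votes = torch.zeros(X_query.shape[0], 2).scatter_add_(1, labels, torch.ones(labels.shape))
    return votes.argmax(dim=1)


# Clean 11 x 11 grid: class 0 where x <= 0, class 1 where x > 0
xs = torch.arange(-5, 6, dtype=torch.float)
X_train = torch.cartesian_prod(xs, xs)  # (121, 2)
y_train = (X_train[:, 0] > 0).long()

# Flip one label deep inside class 0 to simulate a labeling error
noisy = ((X_train[:, 0] == -3) & (X_train[:, 1] == 0)).nonzero().item()
y_train[noisy] = 1

query = torch.tensor([[-3.2, 0.1]])  # test point right next to the mislabeled one
print(knn_predict(X_train, y_train, query, k=1).item())  # 1: falls into the pocket
print(knn_predict(X_train, y_train, query, k=5).item())  # 0: neighbors outvote the error
```

With K = 1 the flipped point alone decides the prediction; with K = 5 its four correctly labeled neighbors outvote it.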

Visualizing Decision Boundaries

import torch
import matplotlib.pyplot as plt


def plot_knn_boundary(X_train, y_train, k, ax, title):
    """Plot KNN decision boundary on a 2D grid."""
    # Create mesh grid
    x_min, x_max = X_train[:, 0].min().item() - 1, X_train[:, 0].max().item() + 1
    y_min, y_max = X_train[:, 1].min().item() - 1, X_train[:, 1].max().item() + 1
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, 100),
        torch.linspace(y_min, y_max, 100),
        indexing='ij'
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1)

    # Predict on grid
    distances = torch.cdist(grid, X_train.float(), p=2)
    _, topk_idx = distances.topk(k, largest=False)
    topk_labels = y_train[topk_idx]
    num_classes = int(y_train.max().item()) + 1
    votes = torch.zeros(grid.shape[0], num_classes)
    votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
    preds = votes.argmax(dim=1).reshape(xx.shape)

    ax.contourf(xx.numpy(), yy.numpy(), preds.numpy(), alpha=0.3, cmap='RdYlBu')
    ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='RdYlBu', edgecolors='k', s=20)
    ax.set_title(title)


# Generate moon-like data
torch.manual_seed(42)
n = 150
t = torch.linspace(0, 3.14, n)
X_top = torch.stack([torch.cos(t), torch.sin(t)], dim=1) + 0.2 * torch.randn(n, 2)
X_bot = torch.stack([1 - torch.cos(t), 1 - torch.sin(t) - 0.5], dim=1) + 0.2 * torch.randn(n, 2)
X_train = torch.cat([X_top, X_bot])
y_train = torch.cat([torch.zeros(n), torch.ones(n)]).long()

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, k in zip(axes, [1, 7, 50]):
    plot_knn_boundary(X_train, y_train, k, ax, f'K = {k}')
plt.tight_layout()
plt.savefig('knn_boundaries.png', dpi=100, bbox_inches='tight')
plt.show()

KNN on MNIST

Let’s put our tensor-based KNN to the test on a real dataset. MNIST is the classic handwritten-digit benchmark: 28×28 grayscale images, flattened here into 784-dimensional pixel vectors and classified into 10 digit classes.

GPU-Accelerated KNN

import torch
import torchvision
import torchvision.transforms as transforms
import time

# Load MNIST
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Flatten to vectors
X_train = train_dataset.data.float().reshape(-1, 784) / 255.0
y_train = train_dataset.targets
X_test = test_dataset.data.float().reshape(-1, 784) / 255.0
y_test = test_dataset.targets

print(f"Train: {X_train.shape}, Test: {X_test.shape}")

# Use subset for demo (full 60K x 10K is expensive on CPU)
X_train_sub = X_train[:5000]
y_train_sub = y_train[:5000]
X_test_sub = X_test[:1000]
y_test_sub = y_test[:1000]

# KNN prediction
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_tr = X_train_sub.to(device)
y_tr = y_train_sub.to(device)
X_te = X_test_sub.to(device)

k = 5
start = time.time()

# Batched inference to manage memory
batch_size = 200
all_preds = []
for i in range(0, X_te.shape[0], batch_size):
    batch = X_te[i:i+batch_size]
    dists = torch.cdist(batch, X_tr, p=2)
    _, topk_idx = dists.topk(k, largest=False)
    topk_labels = y_tr[topk_idx]
    votes = torch.zeros(batch.shape[0], 10, device=device)
    votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
    all_preds.append(votes.argmax(dim=1))

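predictions = torch.cat(all_preds)
if device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the timer
elapsed = time.time() - start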

accuracy = (predictions == y_test_sub.to(device)).float().mean()
print(f"KNN (k={k}) accuracy on 1000 test samples: {accuracy:.4f}")
print(f"Inference time: {elapsed:.3f}s on {device}")
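Performance Note: With the full 60,000-image training set, KNN with k=3 to 5 on raw pixels reaches roughly 97% test accuracy, surprisingly competitive with simple neural networks; the 5,000-image subset above will score lower. On GPU, the 5000×1000 distance matrix computes in milliseconds. The loop above already batches the queries: against all 60,000 training images, each batch of 200 queries produces a 200×60,000 distance matrix (about 48 MB in float32), so keep the batch size small enough to avoid running out of memory.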
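The batched loop above splits only the queries; each batch is still compared against the entire training set at once. If even one (batch, N_train) distance block is too large, one option is to also scan the training set in chunks while keeping a running top-K per query. The function `knn_topk_chunked` and its `chunk_size` default are hypothetical choices for this sketch, and the check uses random data rather than MNIST.

```python
import torch


def knn_topk_chunked(X_query, X_train, k, chunk_size=10_000):
    """Exact K nearest neighbors, scanning the training set in chunks.

    A running top-K is kept per query, so the largest distance block is
    (N_query, chunk_size) rather than (N_query, N_train).
    """
    n_q, dev = X_query.shape[0], X_query.device
    best_d = torch.empty(n_q, 0, device=dev)
    best_i = torch.empty(n_q, 0, dtype=torch.long, device=dev)
    for start in range(0, X_train.shape[0], chunk_size):
        chunk = X_train[start:start + chunk_size]
        d = torch.cdist(X_query, chunk, p=2)  # (N_query, chunk)
        idx = torch.arange(start, start + chunk.shape[0], device=dev).expand_as(d)
        # Merge this chunk's candidates with the current best, keep the K smallest
        cand_d = torch.cat([best_d, d], dim=1)
        cand_i = torch.cat([best_i, idx], dim=1)
        best_d, pos = cand_d.topk(min(k, cand_d.shape[1]), dim=1, largest=False)
        best_i = cand_i.gather(1, pos)
    return best_d, best_i


# Check against brute force on random data standing in for image features
torch.manual_seed(0)
X_train = torch.randn(20_000, 64)
X_query = torch.randn(50, 64)

d_chunk, i_chunk = knn_topk_chunked(X_query, X_train, k=5, chunk_size=3_000)
d_full, i_full = torch.cdist(X_query, X_train, p=2).topk(5, largest=False)
print(torch.allclose(d_chunk, d_full, atol=1e-3))  # distances agree
```

The returned indices plug into the same voting step as before (`y_train[best_i]`, then `scatter_add_`). Because the running top-K is merged exactly, the result matches brute force up to floating-point round-off.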

Limitations & Failure Modes

Curse of Dimensionality

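As dimensionality grows, points drawn independently in every dimension (like the Gaussian samples below) become approximately equidistant from a query: the ratio of the farthest to the nearest neighbor distance approaches 1, so distance-based decisions carry less and less information: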

import torch

torch.manual_seed(42)

# Show how distance ratios collapse in high dimensions
for dim in [2, 10, 50, 100, 500, 1000]:
    points = torch.randn(1000, dim)
    query = torch.randn(1, dim)
    dists = torch.cdist(query, points).squeeze()

    ratio = dists.max() / dists.min()
    print(f"Dim={dim:4d}: max/min distance ratio = {ratio:.3f}")
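The collapse above assumes that every dimension carries independent variation. Data lying near a low-dimensional surface inside a high-dimensional space can behave very differently, which is one commonly cited reason KNN still works on MNIST's 784 pixels. The sketch below assumes the simplest such case: 2D Gaussian points mapped into 1000 dimensions by a random orthonormal basis (from `torch.linalg.qr`), which preserves every pairwise distance.

```python
import torch

torch.manual_seed(42)
ambient_dim, n = 1000, 1000

# Points that vary independently in all 1000 dimensions
full = torch.randn(n, ambient_dim)
q_full = torch.randn(1, ambient_dim)
d_full = torch.cdist(q_full, full).squeeze()

# Points that vary in only 2 directions, embedded in 1000 dimensions
basis, _ = torch.linalg.qr(torch.randn(ambient_dim, 2))  # (1000, 2), orthonormal columns
low = torch.randn(n, 2) @ basis.T
q_low = torch.randn(1, 2) @ basis.T
d_low = torch.cdist(q_low, low).squeeze()

ratio_full = (d_full.max() / d_full.min()).item()
ratio_low = (d_low.max() / d_low.min()).item()
print(f"1000 intrinsic dims: max/min = {ratio_full:.3f}")
print(f"   2 intrinsic dims: max/min = {ratio_low:.3f}")
```

Both point sets live in 1000 dimensions, but only the one with 1000 intrinsic dimensions shows the ratio collapsing toward 1.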

XOR & Donut Problems

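KNN handles non-linear boundaries naturally (unlike the perceptron), but it relies on each region of the input space being sampled densely enough. When data is sparse, neighborhood structure can be misleading: in the minimal four-point XOR problem, every point's nearest other point carries the opposite label, so leave-one-out 1-NN gets every point wrong. Concentric circles (the "donut"), by contrast, are easy for KNN once each ring is well sampled:

import torch

# Minimal XOR problem: opposite corners share a label
X_xor = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
y_xor = torch.tensor([0, 1, 1, 0])

# Leave-one-out 1-NN: classify each point using only the other three
dists = torch.cdist(X_xor, X_xor, p=2)
dists.fill_diagonal_(float('inf'))  # Exclude self
nn_idx = dists.argmin(dim=1)
print(f"True labels: {y_xor.tolist()}")          # [0, 1, 1, 0]
print(f"1-NN (LOO):  {y_xor[nn_idx].tolist()}")  # [1, 0, 0, 1]: every prediction is wrong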

# Donut problem — concentric circles
torch.manual_seed(42)
n = 200
angles = torch.rand(n) * 2 * 3.14159
r_inner = 1.0 + 0.2 * torch.randn(n)
r_outer = 3.0 + 0.2 * torch.randn(n)
X_inner = torch.stack([r_inner * torch.cos(angles), r_inner * torch.sin(angles)], dim=1)
X_outer = torch.stack([r_outer * torch.cos(angles), r_outer * torch.sin(angles)], dim=1)

X_donut = torch.cat([X_inner, X_outer])
y_donut = torch.cat([torch.zeros(n), torch.ones(n)]).long()

# Query the center, the gap between the rings, and a point outside the outer ring
query = torch.tensor([[0.0, 0.0], [2.0, 0.0], [3.5, 0.0]])
dists = torch.cdist(query, X_donut, p=2)
_, topk = dists.topk(7, largest=False)
preds = y_donut[topk]
for i, (q, p) in enumerate(zip(query, preds)):
    majority = p.float().mean()
    print(f"Query {q.tolist()}: neighbors vote = {p.tolist()}, P(outer) = {majority:.2f}")

Bias-Variance Tradeoff

KNN offers a clear illustration of the bias-variance tradeoff controlled by a single knob: K.

K vs. Accuracy Curve

As K increases from 1 to N, the model moves from fitting the training set perfectly (low bias, high variance) to always predicting the training set's majority class (high bias, low variance). The best K balances these two sources of error.
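The two ends of this range can be measured directly. The sketch below (illustrative data and a hypothetical `knn_accuracy` helper) scores KNN on its own training points and on a held-out split as K grows from 1 to the full training-set size of 200.

```python
import torch


def knn_accuracy(X_tr, y_tr, X_eval, y_eval, k):
    """Accuracy of majority-vote KNN (binary labels) fit on X_tr, evaluated on X_eval."""
    # Exact (non-matmul) distances so a point's distance to itself is exactly 0
    dists = torch.cdist(X_eval, X_tr, compute_mode='donot_use_mm_for_euclid_dist')
    _, idx = dists.topk(k, largest=False)
    labels = y_tr[idx]
    votes = torch.zeros(X_eval.shape[0], 2).scatter_add_(1, labels, torch.ones(labels.shape))
    return (votes.argmax(dim=1) == y_eval).float().mean().item()


torch.manual_seed(0)
X = torch.cat([torch.randn(150, 2) + torch.tensor([1.0, 0.0]),
               torch.randn(150, 2) + torch.tensor([-1.0, 0.0])])
y = torch.cat([torch.zeros(150), torch.ones(150)]).long()
perm = torch.randperm(300)
X, y = X[perm], y[perm]
X_tr, y_tr, X_te, y_te = X[:200], y[:200], X[200:], y[200:]

for k in [1, 5, 25, 75, 200]:
    train_acc = knn_accuracy(X_tr, y_tr, X_tr, y_tr, k)
    test_acc = knn_accuracy(X_tr, y_tr, X_te, y_te, k)
    print(f"K={k:3d}: train acc = {train_acc:.3f}, test acc = {test_acc:.3f}")
```

At K = 1 training accuracy is exactly 1.0 because each point is its own nearest neighbor; the exact `compute_mode` keeps that self-distance at 0 rather than a small round-off value. At K = 200 every query sees the whole training set, so the model predicts the training majority class everywhere.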

Finding Optimal K with Cross-Validation

import torch

torch.manual_seed(42)

# Generate noisy 2-class data
X_class0 = torch.randn(300, 2) + torch.tensor([1.5, 0.0])
X_class1 = torch.randn(300, 2) + torch.tensor([-1.5, 0.0])
X_all = torch.cat([X_class0, X_class1])
y_all = torch.cat([torch.zeros(300), torch.ones(300)]).long()

# Shuffle
perm = torch.randperm(600)
X_all, y_all = X_all[perm], y_all[perm]

# K-fold cross-validation
n_folds = 5
fold_size = len(X_all) // n_folds
k_values = [1, 3, 5, 7, 11, 15, 21, 31, 51]

results = {}
for k in k_values:
    fold_accs = []
    for fold in range(n_folds):
        val_start = fold * fold_size
        val_end = val_start + fold_size

        X_val = X_all[val_start:val_end]
        y_val = y_all[val_start:val_end]
        X_tr = torch.cat([X_all[:val_start], X_all[val_end:]])
        y_tr = torch.cat([y_all[:val_start], y_all[val_end:]])

        # Predict
        dists = torch.cdist(X_val, X_tr, p=2)
        _, topk_idx = dists.topk(k, largest=False)
        topk_labels = y_tr[topk_idx]
        votes = torch.zeros(X_val.shape[0], 2)
        votes.scatter_add_(1, topk_labels, torch.ones_like(topk_labels, dtype=torch.float))
        preds = votes.argmax(dim=1)

        acc = (preds == y_val).float().mean().item()
        fold_accs.append(acc)

    results[k] = sum(fold_accs) / len(fold_accs)
    print(f"K={k:2d}: CV accuracy = {results[k]:.4f}")

best_k = max(results, key=results.get)
print(f"\nBest K = {best_k} with accuracy = {results[best_k]:.4f}")