Back to PyTorch Mastery Series

K-Means Clustering in PyTorch

May 29, 2026 Wasil Zafar 28 min read

Build K-Means from the ground up using only PyTorch tensor operations: Lloyd’s algorithm, K-Means++ initialization, inertia tracking, and principled cluster count selection.

Table of Contents

  1. Lloyd’s Algorithm
  2. Initialization
  3. Choosing K
  4. Limitations & Alternatives
  5. Related Articles

Lloyd’s Algorithm

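K-Means partitions $n$ points into $K$ clusters by minimizing the within-cluster sum of squared distances (inertia):

$$J = \sum_{k=1}^{K} \sum_{x \in C_k} \|x - \mu_k\|^2$$

where $C_k$ is the set of points assigned to cluster $k$ and $\mu_k$ is its centroid (mean). Lloyd’s algorithm decreases $J$ by alternating two steps until the centroids stop moving: (1) assign each point to its nearest centroid; (2) move each centroid to the mean of the points assigned to it.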

Lloyd’s Algorithm
flowchart TD
    A[Initialize K centroids] --> B[Assign each point\nto nearest centroid]
    B --> C[Update centroids\nas cluster means]
    C --> D{Centroids\nmoved?}
    D -->|Yes| B
    D -->|No| E[Converged — return labels]
                            
import torch


class KMeans:
    """K-Means clustering using pure PyTorch operations (GPU-compatible).

    Centroids start at ``k`` distinct data points drawn uniformly at random
    and are refined with Lloyd's algorithm until the Frobenius norm of the
    centroid movement drops below ``tol`` or ``max_iter`` iterations have run.

    Attributes set by ``fit``:
        centroids: (k, d) tensor of cluster centers (same device as X).
        labels: (n,) long tensor, index of each point's nearest centroid.
        inertia: float, sum of squared distances of points to their centroid.
    """

    def __init__(self, k, max_iter=300, tol=1e-4, seed=42):
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.centroids = None
        self.labels = None
        self.inertia = None

    def _assign(self, X):
        """Assign each point to its nearest centroid.

        Returns (labels, inertia): an (n,) long tensor of centroid indices and
        the summed squared distance of every point to its assigned centroid.
        """
        # Broadcast (n, 1, d) - (1, k, d) -> (n, k, d); memory is O(n * k * d).
        diffs = X.unsqueeze(1) - self.centroids.unsqueeze(0)
        sq_dists = (diffs ** 2).sum(dim=2)  # (n, k)
        # A single reduction yields both the nearest distance and its index
        # (ties resolve to the first minimum, exactly like argmin).
        min_sq_dists, labels = sq_dists.min(dim=1)
        return labels, min_sq_dists.sum().item()

    def fit(self, X):
        """Fit K-Means. X: (n, d) float tensor.

        Reseeds the global torch RNG with ``self.seed``, prints a message when
        convergence is reached, and returns ``self`` so calls can be chained.

        Raises:
            ValueError: if ``k`` is not between 1 and the number of points.
        """
        torch.manual_seed(self.seed)
        X = X.float()
        n, d = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(
                f"k must be between 1 and the number of points ({n}), got {self.k}"
            )

        # Random initialization: k distinct data points.
        idx = torch.randperm(n)[:self.k]
        self.centroids = X[idx].clone()

        for iteration in range(self.max_iter):
            old_centroids = self.centroids.clone()
            self.labels, self.inertia = self._assign(X)

            # Allocate on X's device and dtype so the update also runs on GPU.
            new_centroids = torch.empty_like(old_centroids)
            for c in range(self.k):
                mask = self.labels == c
                if mask.any():
                    new_centroids[c] = X[mask].mean(dim=0)
                else:
                    # Empty cluster: leave its centroid where it was.
                    new_centroids[c] = old_centroids[c]
            self.centroids = new_centroids

            shift = (self.centroids - old_centroids).norm().item()
            if shift < self.tol:
                print(f"Converged at iteration {iteration + 1}")
                break

        # Final assignment against the final centroids.
        self.labels, self.inertia = self._assign(X)
        return self

    def predict(self, X):
        """Assign new points (m, d) to the nearest fitted centroid; returns (m,) labels."""
        labels, _ = self._assign(X.float())
        return labels


# Three Gaussian blobs of 100 points each (std 0.7) around fixed centers.
torch.manual_seed(42)
centers = torch.tensor([[0., 0.], [4., 0.], [2., 4.]])
blobs = [center + 0.7 * torch.randn(100, 2) for center in centers]
X = torch.cat(blobs)

# fit() returns the estimator itself, so construction and fitting chain.
km = KMeans(k=3).fit(X)
sizes = torch.bincount(km.labels, minlength=3).tolist()
print(f"Inertia: {km.inertia:.2f}")
print(f"Cluster sizes: {sizes}")

Initialization Strategies

K-Means++ Initialization

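Random initialization can place two or more centroids inside the same true cluster, leaving Lloyd’s algorithm stuck in a poor local minimum. K-Means++ picks the first centroid uniformly at random, then selects each subsequent centroid from the data points with probability proportional to its squared distance from the nearest already-chosen centroid. This spreads the initial centroids apart, which typically lowers the final inertia and reduces the number of Lloyd iterations; it also guarantees an expected inertia within a factor of $O(\log K)$ of the optimum.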

import torch


def kmeans_plus_plus_init(X, k, seed=42):
    """
    K-Means++ initialization.

    The first centroid is a uniformly random row of X; each subsequent one is
    drawn with probability proportional to its squared distance from the
    nearest centroid chosen so far (D^2 sampling).

    Args:
        X: (n, d) tensor of data points.
        k: number of centroids to choose, 1 <= k <= n.
        seed: passed to torch.manual_seed (reseeds the global torch RNG).

    Returns:
        (k, d) tensor of centroids copied from rows of X.

    Raises:
        ValueError: if k is outside [1, n].
    """
    torch.manual_seed(seed)
    n, d = X.shape
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and the number of points ({n}), got {k}")

    # First centroid: random point
    first = torch.randint(n, (1,)).item()
    centroids = [X[first]]
    # Squared distance from each point to its nearest chosen centroid. Updated
    # incrementally with the newest centroid only, instead of recomputing the
    # distances to every chosen centroid on each round.
    min_sq_dists = ((X - X[first]) ** 2).sum(dim=1)  # (n,)

    for _ in range(1, k):
        total = min_sq_dists.sum()
        if total > 0:
            # Sample proportional to squared distance
            probs = min_sq_dists / total
        else:
            # Every point coincides with a chosen centroid, so all D^2 weights
            # are zero; sample uniformly instead of dividing by zero.
            probs = torch.ones(n, device=X.device)
        chosen = torch.multinomial(probs, 1).item()
        centroids.append(X[chosen])
        min_sq_dists = torch.minimum(min_sq_dists, ((X - X[chosen]) ** 2).sum(dim=1))

    # stack copies, so the result does not alias X.
    return torch.stack(centroids)


# Compare random vs K-Means++ initialization
torch.manual_seed(42)
centers_true = torch.tensor([[0., 0.], [6., 0.], [3., 6.]])
X = torch.cat([c + torch.randn(150, 2) * 0.8 for c in centers_true])


def run_lloyd(points, centroids, n_iter=100):
    """Refine ``centroids`` (k, d) in place with ``n_iter`` Lloyd iterations.

    Returns the final inertia, measured against the centroids after the last
    update (not the centroids used for the last assignment step).
    """
    k = len(centroids)
    for _ in range(n_iter):
        sq_dists = ((points.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)
        labels = sq_dists.argmin(1)
        for c in range(k):
            mask = labels == c
            if mask.any():  # empty clusters keep their centroid
                centroids[c] = points[mask].mean(0)
    final_sq_dists = ((points.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)
    return final_sq_dists.min(1).values.sum().item()


n = len(X)

# Run 10 times with random init
random_inertias = []
for seed in range(10):
    torch.manual_seed(seed)
    centroids = X[torch.randperm(n)[:3]].clone()
    random_inertias.append(run_lloyd(X, centroids))

# Run 10 times with K-Means++ init
pp_inertias = [run_lloyd(X, kmeans_plus_plus_init(X, 3, seed=seed)) for seed in range(10)]

print(f"Random init  — best: {min(random_inertias):.1f}, worst: {max(random_inertias):.1f}")
print(f"K-Means++ init — best: {min(pp_inertias):.1f}, worst: {max(pp_inertias):.1f}")

Choosing K

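Elbow Method

Inertia generally decreases as $K$ grows (it reaches zero at $K = n$), so minimizing it cannot choose $K$ by itself. Instead, compute inertia for a range of $K$ values and look for the “elbow”: the point after which adding more clusters yields only marginal reductions.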

import torch


class KMeans:
    """K-Means with K-Means++ initialization.

    Args:
        k: number of clusters, 1 <= k <= number of points.
        max_iter: maximum number of Lloyd iterations.
        seed: passed to torch.manual_seed at the start of fit (reseeds the
            global torch RNG).
        tol: stop once the Frobenius norm of the centroid movement in one
            iteration falls below this value.

    After ``fit``: ``centroids`` (k, d), ``labels`` (n,), ``inertia`` (float).
    """

    def __init__(self, k, max_iter=200, seed=42, tol=1e-4):
        self.k = k
        self.max_iter = max_iter
        self.seed = seed
        self.tol = tol
        self.centroids = None
        self.labels = None
        self.inertia = None

    @staticmethod
    def _sq_dists(X, centroids):
        """Squared Euclidean distances between rows of X (n, d) and centroids (k, d) -> (n, k)."""
        return ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)

    def _init_centroids(self, X):
        """Choose k starting centroids from the rows of X by K-Means++ (D^2) sampling."""
        n = X.shape[0]
        first = torch.randint(n, (1,)).item()
        chosen = [X[first]]
        # Running squared distance to the nearest chosen centroid.
        min_sq = ((X - X[first]) ** 2).sum(1)
        for _ in range(1, self.k):
            total = min_sq.sum()
            # If every point already coincides with a centroid, all weights are
            # zero: sample uniformly rather than dividing by zero.
            weights = min_sq / total if total > 0 else torch.ones(n, device=X.device)
            idx = torch.multinomial(weights, 1).item()
            chosen.append(X[idx])
            min_sq = torch.minimum(min_sq, ((X - X[idx]) ** 2).sum(1))
        return torch.stack(chosen)  # stack copies; centroids never alias X

    def fit(self, X):
        """Fit on X, an (n, d) tensor; returns self.

        Raises:
            ValueError: if k is not between 1 and the number of points.
        """
        torch.manual_seed(self.seed)
        X = X.float()
        n, _ = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(
                f"k must be between 1 and the number of points ({n}), got {self.k}"
            )

        self.centroids = self._init_centroids(X)

        for _ in range(self.max_iter):
            old = self.centroids.clone()
            labels = self._sq_dists(X, self.centroids).argmin(1)
            # Labels are fixed before the update, so updating rows in place
            # does not affect the other clusters' means.
            for c in range(self.k):
                mask = labels == c
                if mask.any():
                    self.centroids[c] = X[mask].mean(0)
            if (self.centroids - old).norm() < self.tol:
                break

        min_sq_dists, self.labels = self._sq_dists(X, self.centroids).min(1)
        self.inertia = min_sq_dists.sum().item()
        return self


# Elbow method: inertia vs K
torch.manual_seed(42)
X = torch.cat([
    torch.randn(80, 2) + torch.tensor([0., 0.]),
    torch.randn(80, 2) + torch.tensor([5., 0.]),
    torch.randn(80, 2) + torch.tensor([2.5, 4.]),
    torch.randn(80, 2) + torch.tensor([8., 4.]),
])

k_values = list(range(1, 9))
inertias = []
for k in k_values:
    km = KMeans(k).fit(X)
    inertias.append(km.inertia)
    print(f"K={k}: inertia={km.inertia:.1f}")

# Detect elbow: the K whose inertia lies farthest below the straight line
# joining the first and last points of the curve. Choosing the largest single
# drop instead usually returns K=2, because going from 1 to 2 clusters removes
# the most inertia on most datasets. The vertical gap to a fixed line is
# proportional to the perpendicular distance, so the result does not depend
# on how the two axes are scaled.
slope = (inertias[-1] - inertias[0]) / (k_values[-1] - k_values[0])
gaps = [
    inertias[0] + slope * (k - k_values[0]) - inertia
    for k, inertia in zip(k_values, inertias)
]
optimal_k = k_values[gaps.index(max(gaps))]
print(f"\nSuggested K (elbow): {optimal_k}")

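Silhouette Score

For each point $i$, let $a(i)$ be its mean distance to the other points in its own cluster and $b(i)$ its smallest mean distance to the points of any other cluster. The silhouette $s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}$ lies in $[-1, 1]$. Values near 1 mean the point sits well inside its cluster, values near 0 mean it lies on a boundary, and negative values suggest it was assigned to the wrong cluster. Prefer the $K$ with the highest mean silhouette.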

import torch


def silhouette_score_torch(X, labels):
    """
    Compute mean silhouette score using PyTorch.
    s(i) = (b(i) - a(i)) / max(a(i), b(i))
    where a(i) = mean distance from i to the other points of its own cluster,
    b(i) = smallest mean distance from i to the points of any other cluster.
    Points in singleton clusters score 0, and so do 0/0 cases, where a point
    coincides with points of another cluster.

    Args:
        X: (n, d) tensor of points.
        labels: (n,) tensor of non-negative integer cluster ids (ids may skip
            values; empty ids are ignored).

    Returns:
        Mean silhouette over all points, as a Python float.

    Raises:
        ValueError: if fewer than two non-empty clusters are present.
    """
    X = X.float()
    labels = labels.to(device=X.device, dtype=torch.long)
    k = int(labels.max().item()) + 1

    # (n, k) membership matrix; its column sums are the cluster sizes.
    onehot = (labels.unsqueeze(1) == torch.arange(k, device=X.device)).to(X.dtype)
    counts = onehot.sum(dim=0)  # (k,)
    if int((counts > 0).sum().item()) < 2:
        raise ValueError("silhouette score needs at least 2 non-empty clusters")

    # Pairwise distances (O(n^2) memory). cdist's matrix-multiply path can
    # leave tiny non-zero self-distances, so force the diagonal to exactly 0.
    dists = torch.cdist(X, X)  # (n, n)
    dists.fill_diagonal_(0.0)

    # cluster_sums[i, c] = sum of distances from point i to all points of cluster c
    cluster_sums = dists @ onehot  # (n, k)

    own_counts = counts[labels]  # (n,)
    own_sums = cluster_sums.gather(1, labels.unsqueeze(1)).squeeze(1)
    # a(i): i's own distance is 0, so excluding self means dividing by size - 1
    # (clamped so singletons do not divide by zero; they are zeroed below).
    a = own_sums / (own_counts - 1).clamp(min=1)

    # b(i): mean distance to each other non-empty cluster, then the minimum.
    mean_dists = cluster_sums / counts.clamp(min=1)
    excluded = onehot.bool() | (counts == 0).unsqueeze(0)
    b = mean_dists.masked_fill(excluded, float("inf")).min(dim=1).values

    scores = (b - a) / torch.maximum(a, b)
    scores = torch.where(own_counts > 1, scores, torch.zeros_like(scores))
    scores = torch.nan_to_num(scores, nan=0.0)
    return scores.mean().item()


# Compare different K values using silhouette
torch.manual_seed(42)
X = torch.cat([
    torch.randn(60, 2) + torch.tensor([0., 0.]),
    torch.randn(60, 2) + torch.tensor([4., 0.]),
    torch.randn(60, 2) + torch.tensor([2., 3.5]),
])

n = len(X)
for k in range(2, 6):
    # Simple K-Means (reusing logic inline for independence)
    torch.manual_seed(0)
    centroids = X[torch.randperm(n)[:k]].clone()
    labels = None
    for _ in range(200):
        dists_mat = ((X.unsqueeze(1) - centroids.unsqueeze(0))**2).sum(2)
        new_labels = dists_mat.argmin(1)
        # Unchanged assignments would recompute identical centroids, so every
        # further iteration is a no-op: stop early with the same labels.
        if labels is not None and torch.equal(new_labels, labels):
            break
        labels = new_labels
        for c in range(k):
            mask = labels == c
            if mask.any():
                centroids[c] = X[mask].mean(0)

    score = silhouette_score_torch(X, labels)
    print(f"K={k}: silhouette={score:.4f}")

Limitations & Alternatives

Comparison When to Use Alternatives

K-Means Limitations

  • Assumes roughly spherical, similarly sized clusters: struggles with elongated, irregular, or very unevenly sized clusters — use DBSCAN or Gaussian Mixture Models
  • Requires K in advance: you must try multiple K values — DBSCAN infers the number of clusters automatically from density
  • Sensitive to outliers: outliers pull centroids significantly — use K-Medoids (PAM), which uses actual data points as centers
  • Scale-dependent: standardize features before running K-Means — otherwise features with large numeric ranges dominate the distance calculations
DBSCAN GMM K-Medoids