Lloyd’s Algorithm
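K-Means partitions $n$ points into $K$ clusters by minimizing the within-cluster sum of squared distances (inertia). Lloyd’s algorithm alternates between two steps — assigning each point to its nearest centroid, then recomputing each centroid as the mean of its assigned points — until the centroids stop moving. The objective being minimized, with $\mu_k$ denoting the centroid (mean) of cluster $C_k$, is: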
$$J = \sum_{k=1}^{K} \sum_{x \in C_k} \|x - \mu_k\|^2$$
flowchart TD
A[Initialize K centroids] --> B[Assign each point\nto nearest centroid]
B --> C[Update centroids\nas cluster means]
C --> D{Centroids\nmoved?}
D -->|Yes| B
D -->|No| E[Converged — return labels]
import torch
class KMeans:
    """K-Means clustering (Lloyd's algorithm) using pure PyTorch operations.

    Centroids are initialised from ``k`` distinct rows of the data chosen at
    random and then refined by alternating assignment and mean-update steps.
    Every intermediate tensor is created on the input's device, so the
    estimator works for CPU and GPU tensors alike.

    Attributes:
        k: Number of clusters.
        max_iter: Upper bound on the number of Lloyd iterations.
        tol: Convergence threshold on the norm of the centroid shift.
        seed: Seed passed to ``torch.manual_seed`` at the start of ``fit``.
        centroids: ``(k, d)`` tensor of cluster centres; ``None`` before ``fit``.
        labels: ``(n,)`` tensor of cluster indices for the fitted data.
        inertia: Sum of squared distances from each point to its centroid.
    """

    def __init__(self, k, max_iter=300, tol=1e-4, seed=42):
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.centroids = None
        self.labels = None
        self.inertia = None

    def _assign(self, X):
        """Assign each row of ``X`` to its nearest centroid.

        Args:
            X: ``(n, d)`` float tensor.

        Returns:
            Tuple ``(labels, inertia)``: an ``(n,)`` index tensor and the sum
            of squared distances to the assigned centroids as a Python float.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.centroids is None:
            raise RuntimeError("KMeans is not fitted yet; call fit() first.")
        # Pairwise squared distances via broadcasting: (n, 1, d) - (1, k, d).
        diffs = X.unsqueeze(1) - self.centroids.unsqueeze(0)  # (n, k, d)
        dists = (diffs ** 2).sum(dim=2)  # (n, k)
        labels = dists.argmin(dim=1)
        # Reuse the argmin instead of reducing the distance matrix twice.
        nearest = dists.gather(1, labels.unsqueeze(1))
        return labels, nearest.sum().item()

    def fit(self, X):
        """Fit K-Means to ``X``.

        Args:
            X: ``(n, d)`` tensor; it is converted to float32.

        Returns:
            ``self``, with ``centroids``, ``labels`` and ``inertia`` set.

        Raises:
            ValueError: If ``k`` is not between 1 and ``n``.
        """
        X = X.float()
        n, _ = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(
                f"k must be between 1 and the number of points ({n}), got {self.k}"
            )
        torch.manual_seed(self.seed)

        # Random initialization: k distinct rows of X. The permutation is
        # drawn on CPU (keeps the seeded stream) and moved to X's device.
        idx = torch.randperm(n)[:self.k].to(X.device)
        self.centroids = X[idx].clone()

        for iteration in range(self.max_iter):
            old_centroids = self.centroids
            self.labels, self.inertia = self._assign(X)

            # Start from a copy so that empty clusters keep their previous
            # centroid and the new tensor inherits X's device and dtype.
            new_centroids = old_centroids.clone()
            for c in range(self.k):
                mask = self.labels == c
                if mask.any():
                    new_centroids[c] = X[mask].mean(dim=0)
            self.centroids = new_centroids

            shift = (new_centroids - old_centroids).norm().item()
            if shift < self.tol:
                print(f"Converged at iteration {iteration + 1}")
                break

        # Labels and inertia must reflect the final centroids.
        self.labels, self.inertia = self._assign(X)
        return self

    def predict(self, X):
        """Return the index of the nearest fitted centroid for each row of ``X``."""
        labels, _ = self._assign(X.float())
        return labels
# Three Gaussian blobs (100 points each, std 0.7) around fixed centres
torch.manual_seed(42)
centers = torch.tensor([[0., 0.], [4., 0.], [2., 4.]])
blobs = [center + 0.7 * torch.randn(100, 2) for center in centers]
X = torch.cat(blobs)

# Cluster with K=3, then report the fit quality and the cluster populations
km = KMeans(k=3)
km.fit(X)
print(f"Inertia: {km.inertia:.2f}")
cluster_sizes = [int((km.labels == cluster).sum()) for cluster in range(3)]
print(f"Cluster sizes: {cluster_sizes}")
Initialization Strategies
K-Means++ Initialization
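Random initialization can land centroids in suboptimal positions, for example several centroids inside the same true cluster. K-Means++ chooses the first centroid uniformly at random, then samples each subsequent centroid from the data with probability proportional to its squared distance from the nearest already-chosen centroid. This spreads the initial centroids apart and typically yields faster convergence and lower final inertia.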
import torch
def kmeans_plus_plus_init(X, k, seed=42):
    """Pick ``k`` initial centroids from the rows of ``X`` with K-Means++.

    The first centroid is a uniformly random row. Each later centroid is
    sampled with probability proportional to the squared distance from a row
    to its nearest already-chosen centroid.

    Args:
        X: ``(n, d)`` tensor of points.
        k: Number of centroids to select, ``1 <= k <= n``.
        seed: Seed passed to ``torch.manual_seed`` for reproducibility.

    Returns:
        ``(k, d)`` tensor whose rows are rows of ``X``.

    Raises:
        ValueError: If ``k`` is not between 1 and ``n``.
    """
    n, _ = X.shape
    if not 1 <= k <= n:
        raise ValueError(
            f"k must be between 1 and the number of points ({n}), got {k}"
        )
    torch.manual_seed(seed)

    # First centroid: random point
    centroids = [X[torch.randint(n, (1,)).item()]]

    # Running minimum of squared distances to the chosen centroids. Only the
    # newest centroid can lower it, so each round costs one distance pass.
    min_sq = None
    for _ in range(1, k):
        to_newest = ((X - centroids[-1]) ** 2).sum(dim=1)  # (n,)
        min_sq = to_newest if min_sq is None else torch.minimum(min_sq, to_newest)

        total = min_sq.sum()
        if total > 0:
            # Sample proportional to squared distance
            probs = min_sq / total
        else:
            # Every point coincides with a chosen centroid (duplicate data):
            # the weights would be 0/0, so fall back to a uniform draw.
            probs = torch.ones(n, dtype=torch.float32, device=X.device)
        chosen = torch.multinomial(probs, 1).item()
        centroids.append(X[chosen])

    return torch.stack(centroids)
# Compare random vs K-Means++ initialization
def _lloyd_inertia(X, init_centroids, n_iter=100):
    """Run ``n_iter`` Lloyd iterations from ``init_centroids``; return the final inertia.

    The inertia is measured against the centroids *after* the last update, so
    it is consistent with the clustering that is actually returned.
    """
    centroids = init_centroids.clone()  # never mutate the caller's tensor
    for _ in range(n_iter):
        sq_dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)
        labels = sq_dists.argmin(1)
        for c in range(len(centroids)):
            mask = labels == c
            if mask.any():  # an empty cluster keeps its previous centroid
                centroids[c] = X[mask].mean(0)
    sq_dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)
    return sq_dists.min(1).values.sum().item()


torch.manual_seed(42)
centers_true = torch.tensor([[0., 0.], [6., 0.], [3., 6.]])
X = torch.cat([c + torch.randn(150, 2) * 0.8 for c in centers_true])
n = len(X)

# Run 10 times with random init
random_inertias = []
for seed in range(10):
    torch.manual_seed(seed)
    init = X[torch.randperm(n)[:3]]
    random_inertias.append(_lloyd_inertia(X, init))

# Run 10 times with K-Means++ init
pp_inertias = []
for seed in range(10):
    init = kmeans_plus_plus_init(X, 3, seed=seed)
    pp_inertias.append(_lloyd_inertia(X, init))

print(f"Random init — best: {min(random_inertias):.1f}, worst: {max(random_inertias):.1f}")
print(f"K-Means++ init — best: {min(pp_inertias):.1f}, worst: {max(pp_inertias):.1f}")
Choosing K
Elbow Method
import torch
class KMeans:
    """K-Means with K-Means++ initialization.

    Attributes:
        k: Number of clusters.
        max_iter: Upper bound on the number of Lloyd iterations.
        seed: Seed passed to ``torch.manual_seed`` at the start of ``fit``.
        tol: Convergence threshold on the norm of the centroid shift.
        centroids: ``(k, d)`` tensor of cluster centres; ``None`` before ``fit``.
        labels: ``(n,)`` tensor of cluster indices for the fitted data.
        inertia: Sum of squared distances from each point to its centroid.
    """

    def __init__(self, k, max_iter=200, seed=42, tol=1e-4):
        self.k = k
        self.max_iter = max_iter
        self.seed = seed
        self.tol = tol
        self.centroids = None
        self.labels = None
        self.inertia = None

    @staticmethod
    def _sq_dists(X, centroids):
        """Return the ``(n, k)`` matrix of squared point-to-centroid distances."""
        return ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(2)

    def _init_centroids(self, X):
        """Select ``self.k`` rows of ``X`` with K-Means++ seeding."""
        n = X.shape[0]
        chosen = [X[torch.randint(n, (1,)).item()]]
        min_sq = None  # running min squared distance to the chosen centroids
        for _ in range(1, self.k):
            to_newest = ((X - chosen[-1]) ** 2).sum(1)
            min_sq = to_newest if min_sq is None else torch.minimum(min_sq, to_newest)
            total = min_sq.sum()
            if total > 0:
                probs = min_sq / total
            else:
                # All points coincide with a chosen centroid; 0/0 weights
                # would make multinomial raise, so draw uniformly instead.
                probs = torch.ones_like(min_sq)
            chosen.append(X[torch.multinomial(probs, 1).item()])
        return torch.stack(chosen)

    def fit(self, X):
        """Fit K-Means to ``X`` (an ``(n, d)`` tensor, converted to float32).

        Returns:
            ``self``, with ``centroids``, ``labels`` and ``inertia`` set.

        Raises:
            ValueError: If ``k`` is not between 1 and ``n``.
        """
        torch.manual_seed(self.seed)
        X = X.float()
        n, _ = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(
                f"k must be between 1 and the number of points ({n}), got {self.k}"
            )

        # K-Means++ initialization
        self.centroids = self._init_centroids(X)

        for _ in range(self.max_iter):
            old = self.centroids.clone()
            labels = self._sq_dists(X, self.centroids).argmin(1)
            for c in range(self.k):
                mask = labels == c
                if mask.any():  # an empty cluster keeps its previous centroid
                    self.centroids[c] = X[mask].mean(0)
            if (self.centroids - old).norm() < self.tol:
                break

        # Labels and inertia are measured against the final centroids.
        dists = self._sq_dists(X, self.centroids)
        self.labels = dists.argmin(1)
        self.inertia = dists.min(1).values.sum().item()
        return self
# Elbow method: plot inertia vs K
def _elbow_k(ks, inertias):
    """Return the K at the elbow of an inertia-vs-K curve.

    The elbow is the point that lies furthest below the straight line (chord)
    joining the first and last points of the curve. Picking the largest single
    drop instead would almost always select the first step, because going from
    one cluster to two removes the most inertia regardless of the true K.

    Raises:
        ValueError: If fewer than three (K, inertia) pairs are given, or the
            two sequences differ in length.
    """
    if len(ks) != len(inertias):
        raise ValueError("ks and inertias must have the same length")
    if len(ks) < 3:
        raise ValueError("need at least three K values to locate an elbow")
    k_first, j_first = ks[0], inertias[0]
    slope = (inertias[-1] - j_first) / (ks[-1] - k_first)
    # Vertical gap between the chord and the curve; for a fixed chord this is
    # proportional to the perpendicular distance, so the argmax is the same.
    gaps = [j_first + slope * (k - k_first) - j for k, j in zip(ks, inertias)]
    return ks[gaps.index(max(gaps))]


torch.manual_seed(42)
X = torch.cat([
    torch.randn(80, 2) + torch.tensor([0., 0.]),
    torch.randn(80, 2) + torch.tensor([5., 0.]),
    torch.randn(80, 2) + torch.tensor([2.5, 4.]),
    torch.randn(80, 2) + torch.tensor([8., 4.]),
])

ks = list(range(1, 9))
inertias = []
for k in ks:
    km = KMeans(k).fit(X)
    inertias.append(km.inertia)
    print(f"K={k}: inertia={km.inertia:.1f}")

# Per-step inertia decrease (drops[0] = decrease from K=1 to K=2), kept for inspection
drops = [before - after for before, after in zip(inertias, inertias[1:])]

# Detect elbow: the K whose inertia lies furthest below the first-to-last chord
optimal_k = _elbow_k(ks, inertias)
print(f"\nSuggested K (elbow): {optimal_k}")
Silhouette Score
import torch
def silhouette_score_torch(X, labels):
    """Compute the mean silhouette score using PyTorch.

    For each point ``i``: ``s(i) = (b(i) - a(i)) / max(a(i), b(i))`` where
    ``a(i)`` is the mean distance to the other points of its own cluster and
    ``b(i)`` is the smallest mean distance to the points of any other
    non-empty cluster. Points that are alone in their cluster score 0, as do
    points for which ``a(i) = b(i) = 0`` (otherwise 0/0 would turn the mean
    into NaN).

    Args:
        X: ``(n, d)`` tensor of points; converted to float32.
        labels: ``(n,)`` integer tensor of cluster indices, expected to lie
            in ``0 .. labels.max()``.

    Returns:
        The mean silhouette over all points as a Python float.

    Raises:
        ValueError: If ``X`` is empty, or if a cluster with more than one
            point exists but there is no second non-empty cluster.
    """
    X = X.float()
    n = len(X)
    if n == 0:
        raise ValueError("silhouette score is undefined for empty input")
    labels = labels.to(device=X.device, dtype=torch.long)
    k = int(labels.max().item()) + 1

    # Pairwise distances (expensive, O(n^2)). The diagonal is forced to zero
    # because cdist may leave rounding noise there, and self-distances are
    # included in the own-cluster sums below.
    dists = torch.cdist(X, X)  # (n, n)
    dists.fill_diagonal_(0.0)

    # membership[j, c] is True when point j belongs to cluster c.
    membership = labels.unsqueeze(1) == torch.arange(k, device=X.device)  # (n, k)
    counts = membership.sum(dim=0)  # (k,)
    # cluster_sums[i, c] = sum of distances from point i to all points in c.
    cluster_sums = dists @ membership.to(dists.dtype)  # (n, k)

    own_count = counts[labels]  # (n,)
    has_peers = own_count > 1
    if not has_peers.any():
        return 0.0  # every point is alone in its cluster

    # a(i): mean distance to same-cluster points, excluding the point itself.
    own_sum = cluster_sums.gather(1, labels.unsqueeze(1)).squeeze(1)
    a = own_sum / (own_count - 1).clamp(min=1)

    # b(i): min over other non-empty clusters of the mean distance to them.
    other_means = cluster_sums / counts.clamp(min=1)
    excluded = membership | (counts == 0)  # own cluster and empty clusters
    other_means = other_means.masked_fill(excluded, float("inf"))
    b = other_means.min(dim=1).values
    if torch.isinf(b[has_peers]).any():
        raise ValueError(
            "silhouette score requires at least two non-empty clusters"
        )

    denom = torch.maximum(a, b)
    defined = has_peers & (denom > 0)
    scores = torch.where(defined, (b - a) / denom, torch.zeros_like(a))
    return scores.mean().item()
# Sweep K over 2..5 and report the mean silhouette of each clustering
torch.manual_seed(42)
blob_offsets = [(0., 0.), (4., 0.), (2., 3.5)]
X = torch.cat([torch.randn(60, 2) + torch.tensor(offset) for offset in blob_offsets])
n = len(X)

for k in range(2, 6):
    # Plain Lloyd iterations, written inline so this cell needs no KMeans class
    torch.manual_seed(0)
    centroids = X[torch.randperm(n)[:k]].clone()
    for _ in range(200):
        dists_mat = (X[:, None, :] - centroids[None, :, :]).pow(2).sum(dim=2)
        labels = dists_mat.argmin(dim=1)
        for c in range(k):
            members = X[labels == c]
            # A cluster that received no points keeps its previous centroid
            if len(members) > 0:
                centroids[c] = members.mean(dim=0)

    score = silhouette_score_torch(X, labels)
    print(f"K={k}: silhouette={score:.4f}")
Limitations & Alternatives
K-Means Limitations
- Assumes spherical clusters: Struggles with elongated or irregular shapes — use DBSCAN or Gaussian Mixture Models
- Requires K in advance: Must try multiple K values — DBSCAN infers K automatically from density
- Sensitive to outliers: Outliers shift centroids significantly — use K-Medoids (PAM) which uses actual data points as centers
- Scale-dependent: Always standardize features before K-Means — large-scale features dominate distance calculations