The Gaussian Mixture Model (GMM)
A Gaussian Mixture Model models the data distribution as a weighted sum of $K$ multivariate Gaussian distributions:
$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$
where $\pi_k \geq 0$ are mixing weights with $\sum_k \pi_k = 1$. Unlike K-Means, each point belongs to all clusters with some probability (soft assignment).
EM Algorithm
The EM algorithm increases the log-likelihood at every iteration, converging to a local maximum, by alternating two steps:
- E-step: Compute soft responsibilities $r_{ik} = P(z_i=k \mid x_i)$ — “how much does cluster k own point i?”
- M-step: Update $\pi_k$, $\mu_k$, $\Sigma_k$ using the weighted counts from the E-step
import torch
import math
def mvn_log_prob(X, mu, cov):
    """
    Log probability of X under multivariate Gaussian N(mu, cov).

    X: (n, d)
    mu: (d,)
    cov: (d, d)
    Returns: (n,) log probabilities

    A jitter of 1e-6 is added to the diagonal of ``cov`` before it is used,
    so the result is the log-density under N(mu, cov + 1e-6 * I).
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)  # (n, d)
    # Add small jitter for numerical stability. The identity is created with
    # cov's dtype and device so it can be added to cov wherever cov lives.
    cov_stable = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_stable)
    # Mahalanobis distance: diag(diff @ cov^-1 @ diff.T). Solving the linear
    # system avoids forming the explicit inverse.
    solved = torch.linalg.solve(cov_stable, diff.T).T  # (n, d)
    mahal = (solved * diff).sum(dim=1)  # (n,)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + mahal)
# Demo: single Gaussian log-probability
torch.manual_seed(42)  # fixed seed so the sampled points are reproducible
X = torch.randn(100, 2)  # 100 samples from a standard 2-D normal
mu = torch.zeros(2)
cov = torch.eye(2)
log_probs = mvn_log_prob(X, mu, cov)
print(f"Log-prob shape: {log_probs.shape}")
print(f"Mean log-prob under N(0,I): {log_probs.mean().item():.4f}")
# Standard 2D Gaussian: the mean log-prob should be approx
# -0.5*d*(log(2*pi) + 1) ≈ -2.84, because the Mahalanobis term averages to d.
# (-0.5*d*log(2*pi) ≈ -1.84 is the maximum log-prob, attained at x = mu.)
E-Step: Computing Responsibilities
import torch
import math
def mvn_log_prob(X, mu, cov):
    """
    Log-density of each row of X under N(mu, cov + 1e-6 * I).

    X: (n, d), mu: (d,), cov: (d, d)
    Returns: (n,) log probabilities
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)  # (n, d)
    # Jitter on the diagonal for numerical stability; the identity follows
    # cov's dtype and device so the addition is valid wherever cov lives.
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Mahalanobis term via a linear solve, not an explicit inverse.
    solved = torch.linalg.solve(cov_s, diff.T).T  # (n, d)
    mahal = (solved * diff).sum(1)  # (n,)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + mahal)
def e_step(X, pi, means, covs):
    """
    E-step: compute soft responsibilities r[i,k] = P(z_i=k | x_i).

    X: (n, d) data
    pi: (k,) mixing weights (tensor)
    means: indexable by component, means[c] is (d,)
    covs: indexable by component, covs[c] is (d, d)

    Returns:
    r: (n, k) responsibility matrix (rows sum to 1)
    log_lik: scalar total log-likelihood (Python float)
    """
    # Log joint: log p(x_i, z_i=k) = log pi_k + log N(x_i | mu_k, cov_k).
    # The columns are stacked instead of written into a preallocated
    # torch.zeros buffer, so the result keeps the dtype of the inputs
    # instead of being cast to the default float32.
    log_joint = torch.stack(
        [
            torch.log(pi[c]) + mvn_log_prob(X, means[c], covs[c])
            for c in range(len(pi))
        ],
        dim=1,
    )  # (n, k)
    # log-sum-exp for numerical stability
    log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)  # (n, 1)
    r = torch.exp(log_joint - log_sum)  # (n, k) normalized responsibilities
    log_lik = log_sum.sum().item()
    return r, log_lik
# Demo: one E-step on random data with randomly placed components
torch.manual_seed(42)  # reproducible data and initial means
X = torch.randn(200, 2)
k = 3
pi = torch.ones(k) / k  # uniform mixing weights
means = torch.randn(k, 2)  # random component means
covs = torch.stack([torch.eye(2) for _ in range(k)])  # identity covariance per component, shape (k, 2, 2)
r, ll = e_step(X, pi, means, covs)
print(f"Responsibilities shape: {r.shape}")
print(f"Row sums (should be 1): {r.sum(1)[:3].tolist()}")
print(f"Log-likelihood: {ll:.3f}")
M-Step: Parameter Updates
import torch
def m_step(X, r):
    """
    M-step: update GMM parameters from responsibilities.

    X: (n, d) data matrix
    r: (n, k) responsibility matrix
    Returns: pi (k,), means (k, d), covs (k, d, d)

    The outputs keep the dtype of the inputs (no cast to float32). A
    component whose total responsibility is exactly zero gets pi = 0 and an
    all-zero mean and covariance instead of NaN from a division by zero.
    """
    n = X.shape[0]
    # Effective counts per cluster
    N_k = r.sum(dim=0)  # (k,)
    # Update mixing weights
    pi = N_k / n  # (k,)
    # Replace zero counts by 1 in the denominators only; the matching
    # numerators are zero, so those components come out as zeros.
    safe_N = torch.where(N_k > 0, N_k, torch.ones_like(N_k))
    # Update means: weighted average
    means = (r.T @ X) / safe_N.unsqueeze(1)  # (k, d)
    # Update covariances: sum_i r[i,c] * (x_i - mu_c)(x_i - mu_c)^T / N_c,
    # computed for all components at once with a batched matmul.
    diff = X.unsqueeze(0) - means.unsqueeze(1)  # (k, n, d)
    weighted = r.T.unsqueeze(2) * diff  # (k, n, d)
    covs = weighted.transpose(1, 2) @ diff / safe_N.view(-1, 1, 1)  # (k, d, d)
    return pi, means, covs
# Demo: one M-step starting from random responsibilities
torch.manual_seed(42)
k = 3
# Three groups of 100 standard-normal points, shifted to different centers
X = torch.cat([
    torch.randn(100, 2) + torch.tensor([0., 0.]),
    torch.randn(100, 2) + torch.tensor([4., 0.]),
    torch.randn(100, 2) + torch.tensor([2., 3.5]),
])
# Random initial responsibilities: softmax of Gaussian noise, so each row
# sums to 1 (they are not uniform)
r_init = torch.softmax(torch.randn(len(X), k), dim=1)
pi, means, covs = m_step(X, r_init)
print(f"Initial pi: {pi.round(decimals=3).tolist()}")
print(f"Initial means shape: {means.shape}")
print(f"Initial covs shape: {covs.shape}")
Full GMM Implementation
import torch
import math
def mvn_log_prob(X, mu, cov):
    """
    Log-density of each row of X under N(mu, cov + 1e-6 * I).

    X: (n, d), mu: (d,), cov: (d, d)
    Returns: (n,) log probabilities
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)  # (n, d)
    # Jitter on the diagonal for numerical stability; the identity follows
    # cov's dtype and device so the addition is valid wherever cov lives.
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Mahalanobis term via a linear solve, not an explicit inverse.
    solved = torch.linalg.solve(cov_s, diff.T).T  # (n, d)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + (solved * diff).sum(1))
class GMM:
    """Gaussian Mixture Model with EM algorithm.

    Parameters:
    k: number of mixture components
    max_iter: maximum number of EM iterations
    tol: stop when the total log-likelihood changes by less than this
    seed: value passed to torch.manual_seed at the start of fit()

    Attributes set by fit():
    pi: (k,) mixing weights
    means: (k, d) component means
    covs: (k, d, d) component covariances
    log_likelihood_: total log-likelihood from the last E-step that ran,
        i.e. evaluated before that iteration's M-step update
    """

    def __init__(self, k, max_iter=200, tol=1e-4, seed=42):
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.pi = None
        self.means = None
        self.covs = None
        self.log_likelihood_ = None

    def _log_joint(self, X):
        """Return the (n, k) matrix log pi_c + log N(x_i | mu_c, cov_c)."""
        # 1e-10 keeps the log finite when a mixing weight is exactly zero.
        return torch.stack(
            [
                torch.log(self.pi[c] + 1e-10)
                + mvn_log_prob(X, self.means[c], self.covs[c])
                for c in range(self.k)
            ],
            dim=1,
        )

    def fit(self, X):
        """Run EM on X (n, d) and return self.

        Reseeds the global torch RNG with ``self.seed``. Prints a message if
        the tolerance is reached before ``max_iter`` iterations.

        Raises:
        ValueError: if k is not between 1 and the number of rows of X.
        """
        torch.manual_seed(self.seed)
        X = X.float()
        n, d = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(
                f"k must be between 1 and the number of samples ({n}), got {self.k}"
            )
        # Initialize: random data points as means, unit covariances, uniform
        # weights. Buffers are created with X's dtype and device.
        self.pi = X.new_ones(self.k) / self.k
        self.means = X[torch.randperm(n)[:self.k].to(X.device)].clone()
        self.covs = torch.eye(d, dtype=X.dtype, device=X.device).repeat(self.k, 1, 1)
        prev_ll = float('-inf')
        ll = None  # stays None when max_iter is 0
        for iteration in range(self.max_iter):
            # E-step
            log_joint = self._log_joint(X)
            log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
            r = torch.exp(log_joint - log_sum)
            ll = log_sum.sum().item()
            # M-step
            N_k = r.sum(0)
            self.pi = N_k / n
            # A component with zero total responsibility would divide by
            # zero and poison every later iteration with NaN; such a
            # component keeps its previous mean and covariance instead.
            alive = N_k > 0
            safe_N = torch.where(alive, N_k, torch.ones_like(N_k))
            new_means = (r.T @ X) / safe_N.unsqueeze(1)
            self.means = torch.where(alive.unsqueeze(1), new_means, self.means)
            diff = X.unsqueeze(0) - self.means.unsqueeze(1)  # (k, n, d)
            weighted = r.T.unsqueeze(2) * diff  # (k, n, d)
            new_covs = weighted.transpose(1, 2) @ diff / safe_N.view(-1, 1, 1)
            self.covs = torch.where(alive.view(-1, 1, 1), new_covs, self.covs)
            if abs(ll - prev_ll) < self.tol:
                print(f"Converged at iteration {iteration + 1}, LL={ll:.3f}")
                break
            prev_ll = ll
        self.log_likelihood_ = ll
        return self

    def predict(self, X):
        """Hard assignment: argmax responsibility."""
        # The per-row normalizer does not change the argmax, so the
        # unnormalized log joint is enough.
        return self._log_joint(X.float()).argmax(dim=1)

    def predict_proba(self, X):
        """Soft assignments (responsibilities)."""
        log_joint = self._log_joint(X.float())
        log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
        return torch.exp(log_joint - log_sum)
# Test on 3-cluster data
torch.manual_seed(42)
# Three groups of 150 points. The second is multiplied by a 2x2 matrix
# before shifting, which makes it a correlated, non-spherical cluster.
X = torch.cat([
    torch.randn(150, 2) + torch.tensor([0., 0.]),
    torch.randn(150, 2) @ torch.tensor([[2., 0.5], [0.5, 0.5]]) + torch.tensor([5., 0.]),
    torch.randn(150, 2) + torch.tensor([2.5, 4.]),
])
gmm = GMM(k=3)
gmm.fit(X)
print(f"Final log-likelihood: {gmm.log_likelihood_:.3f}")
print(f"Mixing weights: {gmm.pi.round(decimals=3).tolist()}")
# range(3) matches k=3 above: count the points hard-assigned to each component
print(f"Cluster sizes: {[(gmm.predict(X) == i).sum().item() for i in range(3)]}")
Model Selection with BIC/AIC
import torch
import math
def mvn_log_prob(X, mu, cov):
    """
    Log-density of each row of X under N(mu, cov + 1e-6 * I).

    X: (n, d), mu: (d,), cov: (d, d)
    Returns: (n,) log probabilities
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)  # (n, d)
    # Jitter on the diagonal for numerical stability; the identity follows
    # cov's dtype and device so the addition is valid wherever cov lives.
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Mahalanobis term via a linear solve, not an explicit inverse.
    solved = torch.linalg.solve(cov_s, diff.T).T  # (n, d)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + (solved * diff).sum(1))
def fit_gmm(X, k, seed=42, max_iter=200):
    """Train GMM and return final log-likelihood.

    Runs exactly ``max_iter`` EM iterations (no early stopping) from an
    initialization of k random data points as means, identity covariances
    and uniform weights. Reseeds the global torch RNG with ``seed``.

    X: (n, d) data
    k: number of components
    Returns: total log-likelihood (Python float) from the last E-step,
        i.e. evaluated before the final M-step update.

    Raises:
    ValueError: if max_iter < 1 or k is not between 1 and n.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    torch.manual_seed(seed)
    X = X.float()
    n, d = X.shape
    if not 1 <= k <= n:
        raise ValueError(
            f"k must be between 1 and the number of samples ({n}), got {k}"
        )
    pi = X.new_ones(k) / k
    means = X[torch.randperm(n)[:k].to(X.device)].clone()
    covs = torch.eye(d, dtype=X.dtype, device=X.device).repeat(k, 1, 1)
    for _ in range(max_iter):
        # E-step: log pi_c + log N(x_i | mu_c, cov_c), normalized per row
        log_joint = torch.stack(
            [
                torch.log(pi[c] + 1e-10) + mvn_log_prob(X, means[c], covs[c])
                for c in range(k)
            ],
            dim=1,
        )
        log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
        r = torch.exp(log_joint - log_sum)
        # M-step
        N_k = r.sum(0)
        pi = N_k / n
        # A component with zero total responsibility keeps its previous mean
        # and covariance, avoiding a 0/0 that would turn everything into NaN.
        alive = N_k > 0
        safe_N = torch.where(alive, N_k, torch.ones_like(N_k))
        means = torch.where(alive.unsqueeze(1), (r.T @ X) / safe_N.unsqueeze(1), means)
        diff = X.unsqueeze(0) - means.unsqueeze(1)  # (k, n, d)
        weighted = r.T.unsqueeze(2) * diff  # (k, n, d)
        new_covs = weighted.transpose(1, 2) @ diff / safe_N.view(-1, 1, 1)
        covs = torch.where(alive.view(-1, 1, 1), new_covs, covs)
    return log_sum.sum().item()
torch.manual_seed(42)
# Three groups of 100 standard-normal points, shifted to different centers
X = torch.cat([
    torch.randn(100, 2) + torch.tensor([0., 0.]),
    torch.randn(100, 2) + torch.tensor([4., 0.]),
    torch.randn(100, 2) + torch.tensor([2., 3.5]),
])
n, d = X.shape
print("K | Log-Lik | AIC | BIC")
for k in range(1, 7):
    ll = fit_gmm(X, k)
    # Number of free parameters: k means (d each) + k covariances
    # (d*(d+1)/2 each, because a covariance matrix is symmetric)
    # + k-1 mixing weights (they sum to 1)
    n_params = k * d + k * d * (d + 1) // 2 + (k - 1)
    aic = 2 * n_params - 2 * ll
    bic = n_params * math.log(n) - 2 * ll
    print(f" {k} | {ll:8.1f} | {aic:10.1f} | {bic:10.1f}")
print("\n(Lower BIC/AIC is better. True K=3)")