Back to PyTorch Mastery Series

Gaussian Mixture Models & EM Algorithm

May 29, 2026 Wasil Zafar 32 min read

Implement Gaussian Mixture Models from scratch in PyTorch: the full EM algorithm with E-step posterior computation and M-step parameter updates, log-likelihood convergence tracking, and principled K selection with BIC/AIC.

Table of Contents

  1. The GMM Model
  2. EM Algorithm
  3. Full Implementation
  4. Model Selection (BIC/AIC)
  5. Related Articles

The GMM Model

A Gaussian Mixture Model models the data distribution as a weighted sum of $K$ multivariate Gaussian distributions:

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$

where $\pi_k \geq 0$ are mixing weights with $\sum_k \pi_k = 1$. Unlike K-Means, each point belongs to all clusters with some probability (soft assignment).

GMM vs K-Means: K-Means is a limiting case of GMM: with equal mixing weights and all covariances constrained to the same $\sigma^2 I$ (equal spherical Gaussians), letting $\sigma^2 \to 0$ turns the soft responsibilities into hard (0 or 1) assignments. GMM generalizes to elliptical clusters with different sizes and orientations.

EM Algorithm

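The EM algorithm never decreases the log-likelihood from one iteration to the next (so it converges to a local maximum, not necessarily the global one). It alternates two steps: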

  • E-step: Compute soft responsibilities $r_{ik} = P(z_i=k \mid x_i)$ — “how much does cluster k own point i?”
  • M-step: Update $\pi_k$, $\mu_k$, $\Sigma_k$ using the weighted counts from the E-step
import torch
import math


def mvn_log_prob(X, mu, cov):
    """
    Log probability of X under multivariate Gaussian N(mu, cov).

    A jitter of 1e-6 is added to the diagonal of ``cov`` before it is used,
    so the value returned is the log-density under N(mu, cov + 1e-6 * I).

    X:   (n, d)
    mu:  (d,)
    cov: (d, d)
    Returns: (n,) log probabilities
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)  # (n, d)
    # Add small jitter for numerical stability. The identity is created with
    # cov's own dtype and device so the sum also works for tensors that do not
    # live on the default device (a default CPU identity would raise there).
    cov_stable = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_stable)
    # Squared Mahalanobis distance diff_i^T cov^{-1} diff_i for every row.
    # Solving the linear system is preferred over forming the explicit
    # inverse: it is cheaper and numerically more stable.
    solved = torch.linalg.solve(cov_stable, diff.T).T  # (n, d)
    mahal = (solved * diff).sum(dim=1)  # (n,)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + mahal)


# Demo: single Gaussian log-probability
# Draw 100 points from a standard 2-D normal and score them under N(0, I).
torch.manual_seed(42)
X = torch.randn(100, 2)
mu = torch.zeros(2)
cov = torch.eye(2)
log_probs = mvn_log_prob(X, mu, cov)
print(f"Log-prob shape: {log_probs.shape}")
print(f"Mean log-prob under N(0,I): {log_probs.mean().item():.4f}")
# For samples drawn from N(0, I) the expected log-prob is
# -0.5*d*(log(2*pi) + 1) ≈ -2.84 for d=2, because E[||x||^2] = d.
# (-0.5*d*log(2*pi) ≈ -1.84 is only the log-density at the mean itself.)

E-Step: Computing Responsibilities

import torch
import math


def mvn_log_prob(X, mu, cov):
    """
    Log-density of each row of X under the Gaussian N(mu, cov).

    X: (n, d), mu: (d,), cov: (d, d). Returns an (n,) tensor.
    A 1e-6 jitter is added to the diagonal of cov before it is used.
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)
    # Build the jitter identity with cov's dtype/device so inputs that are not
    # on the default device do not trigger a device-mismatch error.
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Solve cov_s @ Z = diff^T instead of inverting cov_s explicitly; the
    # quadratic form diff_i^T cov_s^{-1} diff_i is then a row-wise dot product.
    solved = torch.linalg.solve(cov_s, diff.T).T
    mahal = (solved * diff).sum(1)
    return -0.5 * (d * math.log(2 * math.pi) + log_det + mahal)


def e_step(X, pi, means, covs):
    """
    E-step: compute soft responsibilities r[i,k] = P(z_i=k | x_i).

    X:     (n, d) data
    pi:    (k,) mixing weights
    means: (k, d) component means
    covs:  (k, d, d) component covariances
    Returns:
      r:       (n, k) responsibility matrix (rows sum to 1)
      log_lik: scalar total log-likelihood (Python float)
    """
    # Log joint: log p(x_i, z_i=k) = log pi_k + log N(x_i | mu_k, cov_k).
    # The columns are stacked rather than written into a pre-allocated
    # torch.zeros buffer, so the result keeps the dtype and device of the
    # inputs instead of being forced to default float32 on the CPU.
    log_joint = torch.stack(
        [
            torch.log(pi[c]) + mvn_log_prob(X, means[c], covs[c])
            for c in range(len(pi))
        ],
        dim=1,
    )  # (n, k)

    # log-sum-exp for numerical stability
    log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)  # (n, 1)
    r = torch.exp(log_joint - log_sum)  # (n, k) — normalized responsibilities
    log_lik = log_sum.sum().item()
    return r, log_lik


# Demo
# Score 200 standard-normal points against a 3-component mixture with
# uniform weights, randomly drawn means and identity covariances.
torch.manual_seed(42)
X = torch.randn(200, 2)
k = 3
pi = torch.ones(k) / k
means = torch.randn(k, 2)
covs = torch.stack([torch.eye(2) for _ in range(k)])

# r is the (n, k) responsibility matrix, ll the total log-likelihood.
r, ll = e_step(X, pi, means, covs)
print(f"Responsibilities shape: {r.shape}")
print(f"Row sums (should be 1): {r.sum(1)[:3].tolist()}")
print(f"Log-likelihood: {ll:.3f}")

M-Step: Parameter Updates

import torch


def m_step(X, r):
    """
    M-step: update GMM parameters from responsibilities.

    X: (n, d) data
    r: (n, k) responsibility matrix
    Returns: pi (k,), means (k, d), covs (k, d, d)

    The outputs keep the dtype and device of the inputs. A component whose
    responsibilities sum to zero gets weight 0 and finite (all-zero) mean and
    covariance instead of NaN.
    """
    n, d = X.shape

    # Effective counts per cluster
    N_k = r.sum(dim=0)  # (k,)

    # Update mixing weights (uses the true counts, so they still sum to 1)
    pi = N_k / n  # (k,)

    # Guard the divisions below against empty clusters (0 / 0 -> NaN).
    N_safe = N_k.clamp_min(torch.finfo(N_k.dtype).eps)

    # Update means: weighted average
    means = (r.T @ X) / N_safe.unsqueeze(1)  # (k, d)

    # Update covariances: sum_i r[i,c] * (x_i - mu_c)(x_i - mu_c)^T / N_c,
    # computed for all clusters at once with a batched matmul. This builds a
    # (k, n, d) intermediate.
    diff = X.unsqueeze(0) - means.unsqueeze(1)  # (k, n, d)
    weighted = r.T.unsqueeze(2) * diff  # (k, n, d)
    covs = weighted.transpose(1, 2) @ diff / N_safe.view(-1, 1, 1)  # (k, d, d)
    return pi, means, covs


# Demo
# Three unit-variance blobs of 100 points each, centred at (0,0), (4,0), (2,3.5).
torch.manual_seed(42)
k = 3
X = torch.cat([
    torch.randn(100, 2) + torch.tensor([0., 0.]),
    torch.randn(100, 2) + torch.tensor([4., 0.]),
    torch.randn(100, 2) + torch.tensor([2., 3.5]),
])

# Random (not uniform) soft responsibilities: softmax of Gaussian noise,
# so every row is a valid probability vector that sums to 1.
r_init = torch.softmax(torch.randn(len(X), k), dim=1)
pi, means, covs = m_step(X, r_init)
print(f"Initial pi:    {pi.round(decimals=3).tolist()}")
print(f"Initial means shape: {means.shape}")
print(f"Initial covs shape:  {covs.shape}")

Full GMM Implementation

import torch
import math


def mvn_log_prob(X, mu, cov):
    """
    Row-wise log-density of X (n, d) under N(mu, cov); returns shape (n,).

    mu is (d,), cov is (d, d). cov is regularised with a 1e-6 diagonal jitter.
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)
    # Match cov's dtype/device for the jitter so tensors off the default
    # device work too.
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Linear solve instead of an explicit inverse (faster, more stable).
    solved = torch.linalg.solve(cov_s, diff.T).T
    return -0.5 * (d * math.log(2 * math.pi) + log_det + (solved * diff).sum(1))


class GMM:
    """Gaussian Mixture Model with EM algorithm.

    Parameters:
      k:        number of mixture components
      max_iter: maximum number of EM iterations (must be >= 1)
      tol:      stop when the log-likelihood changes by less than this
      seed:     value passed to torch.manual_seed at the start of fit()

    Attributes set by fit():
      pi:              (k,) mixing weights
      means:           (k, d) component means
      covs:            (k, d, d) component covariances
      log_likelihood_: log-likelihood from the last E-step that was run
      n_iter_:         number of EM iterations performed
      converged_:      True if the tolerance was reached before max_iter
    """

    def __init__(self, k, max_iter=200, tol=1e-4, seed=42):
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.pi = None
        self.means = None
        self.covs = None
        self.log_likelihood_ = None
        self.n_iter_ = 0
        self.converged_ = False

    def _log_joint(self, X):
        """Return the (n, k) matrix log pi_c + log N(x_i | mu_c, cov_c)."""
        if self.pi is None:
            raise RuntimeError("GMM is not fitted yet; call fit() first.")
        # Stacking keeps the dtype/device of the parameters; the 1e-10 guards
        # log(0) for components whose weight has collapsed to zero.
        return torch.stack(
            [
                torch.log(self.pi[c] + 1e-10)
                + mvn_log_prob(X, self.means[c], self.covs[c])
                for c in range(self.k)
            ],
            dim=1,
        )

    def fit(self, X):
        """Fit the mixture to X (n, d) with EM and return self.

        Raises ValueError if X is not 2-D, if k is not in [1, n], or if
        max_iter < 1.
        """
        torch.manual_seed(self.seed)
        X = X.float()
        if X.dim() != 2:
            raise ValueError(f"X must be 2-D (n, d), got shape {tuple(X.shape)}")
        n, d = X.shape
        if not 1 <= self.k <= n:
            raise ValueError(f"k must be between 1 and n={n}, got {self.k}")
        if self.max_iter < 1:
            raise ValueError(f"max_iter must be >= 1, got {self.max_iter}")

        # Initialize: random data points as means, unit covariances, uniform
        # weights. Parameters are created on X's device so the model also
        # works for data that is not on the default device.
        self.pi = torch.ones(self.k, dtype=X.dtype, device=X.device) / self.k
        self.means = X[torch.randperm(n)[:self.k]].clone()
        self.covs = torch.eye(d, dtype=X.dtype, device=X.device).repeat(self.k, 1, 1)

        # Lower bound for the effective counts: avoids 0/0 = NaN when a
        # component ends up with no responsibility mass.
        eps = torch.finfo(X.dtype).eps

        self.converged_ = False
        prev_ll = float('-inf')
        for iteration in range(self.max_iter):
            # E-step
            log_joint = self._log_joint(X)
            log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
            r = torch.exp(log_joint - log_sum)
            ll = log_sum.sum().item()

            # M-step
            N_k = r.sum(0)
            N_safe = N_k.clamp_min(eps)
            self.pi = N_k / n
            self.means = (r.T @ X) / N_safe.unsqueeze(1)
            for c in range(self.k):
                diff = X - self.means[c]
                self.covs[c] = (r[:, c].unsqueeze(1) * diff).T @ diff / N_safe[c]

            if abs(ll - prev_ll) < self.tol:
                self.converged_ = True
                print(f"Converged at iteration {iteration + 1}, LL={ll:.3f}")
                break
            prev_ll = ll

        self.n_iter_ = iteration + 1
        self.log_likelihood_ = ll
        return self

    def predict(self, X):
        """Hard assignment: argmax responsibility."""
        # argmax of the log joint equals argmax of the normalized
        # responsibilities, so no normalization is needed here.
        return self._log_joint(X.float()).argmax(dim=1)

    def predict_proba(self, X):
        """Soft assignments (responsibilities)."""
        log_joint = self._log_joint(X.float())
        log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
        return torch.exp(log_joint - log_sum)


# Test on 3-cluster data: two unit-variance blobs and one blob stretched by a
# linear map, 150 points each.
torch.manual_seed(42)
X = torch.cat([
    torch.randn(150, 2) + torch.tensor([0., 0.]),
    torch.randn(150, 2) @ torch.tensor([[2., 0.5], [0.5, 0.5]]) + torch.tensor([5., 0.]),
    torch.randn(150, 2) + torch.tensor([2.5, 4.]),
])

gmm = GMM(k=3)
gmm.fit(X)
print(f"Final log-likelihood: {gmm.log_likelihood_:.3f}")
print(f"Mixing weights: {gmm.pi.round(decimals=3).tolist()}")
# Predict once and reuse the labels; count per component up to gmm.k rather
# than a hard-coded 3.
labels = gmm.predict(X)
print(f"Cluster sizes: {[(labels == i).sum().item() for i in range(gmm.k)]}")

Model Selection with BIC/AIC

import torch
import math


def mvn_log_prob(X, mu, cov):
    """
    Log N(x_i | mu, cov) for every row x_i of X.

    X: (n, d), mu: (d,), cov: (d, d) -> (n,) tensor. A 1e-6 jitter is added
    to the diagonal of cov for numerical stability.
    """
    d = X.shape[1]
    diff = X - mu.unsqueeze(0)
    # The jitter identity follows cov's dtype and device (a default CPU
    # identity would fail for tensors on another device).
    cov_s = cov + 1e-6 * torch.eye(d, dtype=cov.dtype, device=cov.device)
    log_det = torch.logdet(cov_s)
    # Use a linear solve for the quadratic form rather than inverting cov_s.
    solved = torch.linalg.solve(cov_s, diff.T).T
    return -0.5 * (d * math.log(2 * math.pi) + log_det + (solved * diff).sum(1))


def _gmm_log_joint(X, pi, means, covs):
    """Return the (n, k) matrix log pi_c + log N(x_i | mu_c, cov_c)."""
    return torch.stack(
        [
            torch.log(pi[c] + 1e-10) + mvn_log_prob(X, means[c], covs[c])
            for c in range(len(pi))
        ],
        dim=1,
    )


def fit_gmm(X, k, seed=42, max_iter=200):
    """Train GMM and return final log-likelihood.

    Runs exactly max_iter EM iterations (no early stopping) from a seeded
    initialization: k random data points as means, identity covariances and
    uniform weights. The returned float is the total log-likelihood of X
    under the parameters produced by the last M-step, so it describes the
    fitted model itself; with max_iter=0 it scores the initial parameters.

    Raises ValueError if k is not in [1, n].
    """
    torch.manual_seed(seed)
    X = X.float()
    n, d = X.shape
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and n={n}, got {k}")

    # Create the parameters on X's device/dtype.
    pi = torch.ones(k, dtype=X.dtype, device=X.device) / k
    means = X[torch.randperm(n)[:k]].clone()
    covs = torch.eye(d, dtype=X.dtype, device=X.device).repeat(k, 1, 1)

    # Lower bound for effective counts so an empty component yields finite
    # parameters instead of 0/0 = NaN.
    eps = torch.finfo(X.dtype).eps

    for _ in range(max_iter):
        # E-step: normalized responsibilities
        log_joint = _gmm_log_joint(X, pi, means, covs)
        log_sum = torch.logsumexp(log_joint, dim=1, keepdim=True)
        r = torch.exp(log_joint - log_sum)

        # M-step
        N_k = r.sum(0)
        N_safe = N_k.clamp_min(eps)
        pi = N_k / n
        means = (r.T @ X) / N_safe.unsqueeze(1)
        for c in range(k):
            diff = X - means[c]
            covs[c] = (r[:, c].unsqueeze(1) * diff).T @ diff / N_safe[c]

    # Score the final parameters. The log-sum from inside the loop belongs to
    # the parameters *before* the last M-step and is undefined if the loop
    # never ran.
    final_log_joint = _gmm_log_joint(X, pi, means, covs)
    return torch.logsumexp(final_log_joint, dim=1).sum().item()


# Model selection demo: fit K = 1..6 on three unit-variance blobs and compare
# the information criteria.
torch.manual_seed(42)
X = torch.cat([
    torch.randn(100, 2) + torch.tensor([0., 0.]),
    torch.randn(100, 2) + torch.tensor([4., 0.]),
    torch.randn(100, 2) + torch.tensor([2., 3.5]),
])
n, d = X.shape

print("K  | Log-Lik  |    AIC     |    BIC")
for k in range(1, 7):
    ll = fit_gmm(X, k)
    # Number of free parameters: k means (d each) + k covariances + k-1 mixing
    # weights. A covariance matrix is symmetric, so it has d*(d+1)/2 free
    # entries, not d*d; counting d*d would over-penalize larger K.
    n_params = k * d + k * d * (d + 1) // 2 + (k - 1)
    aic = 2 * n_params - 2 * ll
    bic = n_params * math.log(n) - 2 * ll
    print(f" {k} | {ll:8.1f} | {aic:10.1f} | {bic:10.1f}")

print("\n(Lower BIC/AIC is better. True K=3)")