Back to PyTorch Mastery Series

Proximal Policy Optimization (PPO)

May 29, 2026 Wasil Zafar 42 min read

PPO is one of the most widely used deep RL algorithms today and was the algorithm behind ChatGPT's RLHF fine-tuning. The clipped surrogate objective limits how much the policy can change in one update, giving stable training without the complexity of trust-region methods like TRPO.

Table of Contents

  1. Clipped Surrogate Objective
  2. PPO Network Architecture
  3. Rollout Buffer
  4. PPO Loss Computation
  5. Full PPO Training Loop
  6. Diagnostics & Monitoring
  7. Related Articles

Clipped Surrogate Objective

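PPO's core insight is to reuse the same rollout data for multiple gradient updates (unlike A2C, which uses it once). The danger is that repeated updates on stale data let the policy drift too far from the behavior policy that collected it. PPO prevents this by clipping the importance-sampling ratio inside the objective. Once the new policy has moved more than $\varepsilon$ away from the behavior policy on a sample, moving further in the advantageous direction earns no additional objective, so the gradient stops pushing it.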

Importance Sampling Ratio

When we update the policy $\pi_\theta$ using data from old policy $\pi_{\theta_\text{old}}$, we correct for the distribution mismatch:

$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_\text{old}}(a_t|s_t)} = \exp\left(\log\pi_\theta(a_t|s_t) - \log\pi_{\theta_\text{old}}(a_t|s_t)\right)$$

Clipping Mechanism

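The clipped objective removes any incentive to push $r_t(\theta)$ outside $[1-\varepsilon, 1+\varepsilon]$. Here $\hat{A}_t$ is the advantage estimate and $\varepsilon$ (typically 0.2) is the clip range: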

$$L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t, \;\mathrm{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon)\hat{A}_t\right)\right]$$

import torch
import torch.nn.functional as F


def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """
    Negated PPO clipped surrogate objective, averaged over the batch.

    The per-sample objective is min(r * A, clip(r, 1-eps, 1+eps) * A), where
    r = exp(new_log_prob - old_log_prob) is the importance-sampling ratio.

    Positive advantage (action better than expected):
      - r above 1+eps: the clipped term wins, so gains are capped.
      - r below 1+eps: the unclipped term is used.

    Negative advantage (action worse than expected):
      - r below 1-eps: the clipped term wins, so there is no incentive to
        shrink the probability further.
      - r above 1-eps: the unclipped term is used (the full penalty applies).

    Returns:
        (loss, mean_ratio): the scalar loss to minimise and the batch-mean
        ratio as a Python float.
    """
    # The rollout policy's log-probs are constants of the objective
    log_ratio = new_log_probs - old_log_probs.detach()
    ratio = log_ratio.exp()

    lower, upper = 1.0 - clip_eps, 1.0 + clip_eps
    unclipped = advantages * ratio
    clipped = advantages * torch.clamp(ratio, lower, upper)

    # Element-wise minimum = pessimistic bound on the improvement
    objective = torch.minimum(unclipped, clipped)
    return objective.mean().neg(), ratio.mean().item()


# Visualize clipping: tabulate both surrogate terms and the one PPO keeps
import torch

clip_eps = 0.2  # same clip range as ppo_clipped_loss's default
ratios_to_test = [0.7, 0.9, 1.0, 1.1, 1.3]

print("ratio | adv   | surr1  | surr2  | min (PPO loss contribution)")
for r_val in ratios_to_test:
    # One row per advantage sign: the clipping rule differs for A > 0 and A < 0
    for adv_val in [-1.0, 1.0]:
        r = torch.tensor(r_val)
        adv = torch.tensor(adv_val)
        s1 = r * adv
        s2 = r.clamp(1.0 - clip_eps, 1.0 + clip_eps) * adv
        print(f"{r_val:.1f}  | {adv_val:+.1f}  | {s1.item():+.4f} | {s2.item():+.4f} | {torch.min(s1,s2).item():+.4f}")
    print()

PPO Network Architecture

import torch
import torch.nn as nn
from torch.distributions import Categorical


class PPOActorCritic(nn.Module):
    """
    Shared-backbone actor-critic network for discrete-action PPO.

    Same architecture as the A2C network: a two-layer tanh MLP feeding a policy
    head (logits over ``n_actions``) and a scalar value head. The network keeps
    no rollout state. The old log-probabilities needed for the PPO ratio are
    produced by ``get_action_and_value`` during collection and stored by the
    caller.

    Every linear layer gets an orthogonal weight init and a zero bias, as in
    CleanRL's ``layer_init``. The gains are sqrt(2) for hidden layers, 0.01
    for the policy head (near-uniform initial policy) and 1.0 for the value
    head.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.actor = nn.Linear(hidden_dim, n_actions)
        self.critic = nn.Linear(hidden_dim, 1)

        for layer in self.backbone:
            if isinstance(layer, nn.Linear):
                self._init_layer(layer, gain=2 ** 0.5)
        # Heads need zero biases too; otherwise the initial policy and value
        # are shifted by PyTorch's random default bias init.
        self._init_layer(self.actor, gain=0.01)
        self._init_layer(self.critic, gain=1.0)

    @staticmethod
    def _init_layer(layer, gain):
        """Orthogonal weight init with ``gain``; bias set to zero."""
        nn.init.orthogonal_(layer.weight, gain=gain)
        nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Return (Categorical policy distribution, value estimate with last dim squeezed)."""
        f = self.backbone(x)
        dist = Categorical(logits=self.actor(f))
        value = self.critic(f).squeeze(-1)
        return dist, value

    def get_action_and_value(self, x, action=None):
        """Evaluate ``x``; sample an action unless one is given.

        Used during rollout (``action=None``: sample a new action) and during
        the update phase (pass the stored actions to re-score them under the
        current policy).

        Returns:
            (action, log_prob(action), entropy, value)
        """
        dist, value = self(x)
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value


torch.manual_seed(42)
net = PPOActorCritic(state_dim=4, n_actions=2)
n_params = sum(map(torch.Tensor.numel, net.parameters()))
print(f"Parameters: {n_params:,}")

# Rollout step: sample an action for a single state (batch of one)
s = torch.randn(4)
action, log_prob, entropy, value = net.get_action_and_value(s[None, :])
print(
    f"action={action.item()}, log_prob={log_prob.item():.4f}, "
    f"entropy={entropy.item():.4f}, V={value.item():.4f}"
)

# Update step: re-score stored (state, action) pairs under the current policy
states_batch = torch.randn(32, 4)
actions_batch = torch.randint(0, 2, (32,))
_, new_lps, new_ents, new_vals = net.get_action_and_value(states_batch, actions_batch)
print(f"Batch update shapes: log_probs={new_lps.shape}, values={new_vals.shape}")

Rollout Buffer

import torch
import torch.nn as nn
from torch.distributions import Categorical


class RolloutBuffer:
    """
    Fixed-size buffer for PPO rollouts.

    Stores (state, action, old_log_prob, reward, done, value) for ``n_steps``
    transitions in pre-allocated tensors. After collection,
    ``compute_advantages`` computes GAE-Lambda advantages and returns once, and
    ``get_minibatches`` yields shuffled mini-batches for one PPO epoch. Call
    ``reset`` before collecting the next rollout.
    """

    def __init__(self, n_steps, state_dim, gamma=0.99, lam=0.95):
        self.n, self.state_dim = n_steps, state_dim
        self.gamma, self.lam = gamma, lam
        self.reset()

    def reset(self):
        """Allocate zeroed storage and rewind the write pointer."""
        self.states        = torch.zeros(self.n, self.state_dim)
        self.actions       = torch.zeros(self.n, dtype=torch.long)
        self.old_log_probs = torch.zeros(self.n)
        self.rewards       = torch.zeros(self.n)
        self.dones         = torch.zeros(self.n)
        self.values        = torch.zeros(self.n)
        self.ptr           = 0

    def add(self, state, action, log_prob, reward, done, value):
        """Append one transition.

        ``log_prob`` and ``value`` may be tensors (detached before storing) or
        plain numbers.

        Raises:
            IndexError: if the buffer already holds ``n_steps`` transitions.
        """
        if self.ptr >= self.n:
            raise IndexError(f"RolloutBuffer is full ({self.n} steps); call reset() first")
        self.states[self.ptr]        = torch.as_tensor(state)
        self.actions[self.ptr]       = action
        self.old_log_probs[self.ptr] = torch.as_tensor(log_prob).detach()
        self.rewards[self.ptr]       = reward
        self.dones[self.ptr]         = float(done)
        self.values[self.ptr]        = torch.as_tensor(value).detach()
        self.ptr += 1

    def compute_advantages(self, next_value):
        """GAE-Lambda advantage estimation over the whole buffer.

        Args:
            next_value: value estimate of the state following the last stored
                step, used to bootstrap. A scalar tensor or a plain number.

        Returns:
            (advantages, returns). ``advantages`` is normalised to zero mean and
            unit std; normalisation is skipped for a single-step buffer, where
            the std is undefined (NaN). ``returns`` is the raw advantages plus
            the stored values.
        """
        advs = torch.zeros(self.n)
        gae = 0.0
        all_v = self.values.tolist() + [float(next_value)]
        for t in reversed(range(self.n)):
            # A terminal step neither bootstraps nor carries GAE backwards
            not_done = 1.0 - self.dones[t]
            delta = self.rewards[t] + self.gamma * all_v[t + 1] * not_done - all_v[t]
            gae = delta + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        returns = advs + self.values
        if self.n > 1:
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        return advs, returns

    def get_minibatches(self, advantages, returns, batch_size=64):
        """Yield shuffled mini-batches that cover the buffer once (one PPO epoch).

        Call this once per epoch to get a fresh permutation. Each item is
        ``(states, actions, old_log_probs, advantages, returns)``. The last
        batch is smaller when ``batch_size`` does not divide ``n_steps``.
        """
        perm = torch.randperm(self.n)
        for start in range(0, self.n, batch_size):
            idx = perm[start:start + batch_size]
            yield (
                self.states[idx], self.actions[idx],
                self.old_log_probs[idx], advantages[idx], returns[idx],
            )


# Demo: populate the buffer with synthetic transitions, then batch it
torch.manual_seed(42)
buf = RolloutBuffer(n_steps=256, state_dim=4)
for _ in range(256):
    # Draw in the same order as a keyword call would: state, action, reward
    obs = torch.randn(4)
    act = torch.randint(0, 2, ()).item()
    logp = torch.tensor(-0.7)
    rew = float(torch.randn(1).item())
    buf.add(obs, act, logp, rew, False, torch.tensor(0.5))

advs, rets = buf.compute_advantages(next_value=torch.tensor(0.4))
adv_mean, adv_std = advs.mean(), advs.std()
print(f"Advantages: mean={adv_mean:.4f}, std={adv_std:.4f} (should be ≈0, ≈1 after normalize)")
print(f"Returns:    mean={rets.mean():.4f}, std={rets.std():.4f}")

mb_count = 0
for _batch in buf.get_minibatches(advs, rets, batch_size=64):
    mb_count += 1
print(f"Mini-batches per epoch: {mb_count} (256 / 64 = 4)")

PPO Loss Computation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


def ppo_loss(net, states, actions, old_log_probs, advantages, returns,
             clip_eps=0.2, value_coeff=0.5, entropy_coeff=0.01, old_values=None):
    """
    Full PPO loss for one mini-batch.
    Called multiple times per rollout (K epochs).

    Args:
        net: module whose ``forward(states)`` returns ``(dist, values)``.
        states, actions: stored rollout inputs, re-scored under ``net``.
        old_log_probs: log-probs of ``actions`` under the rollout policy.
        advantages, returns: per-sample targets computed after the rollout.
        clip_eps: clip range for the policy ratio (and for the value
            prediction when ``old_values`` is given).
        value_coeff, entropy_coeff: weights of the value loss and entropy bonus.
        old_values: optional value predictions recorded at rollout time. When
            given, the value loss is clipped around them (PPO2/CleanRL style):
            mean(max((V - R)^2, (V_clip - R)^2)) with
            V_clip = V_old + clip(V - V_old, -eps, eps). When None (the
            default), plain MSE against ``returns`` is used.

    Returns:
        (total_loss, info). ``total_loss`` is a differentiable scalar; ``info``
        is a dict of floats: policy_loss, value_loss, entropy, clip_fraction,
        approx_kl.
    """
    # Re-evaluate stored (state, action) under current policy
    dist_new, values_new = net(states)
    new_log_probs = dist_new.log_prob(actions)
    entropy = dist_new.entropy()

    # Rollout-time quantities are constants of the objective: never backprop into them
    old_log_probs = old_log_probs.detach()
    advantages = advantages.detach()
    returns = returns.detach()

    # Importance sampling ratio
    ratios = torch.exp(new_log_probs - old_log_probs)

    # Clipped surrogate loss
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # Value loss: plain MSE, or clipped around the rollout-time predictions
    if old_values is None:
        value_loss = F.mse_loss(values_new, returns)
    else:
        old_values = old_values.detach()
        values_clipped = old_values + (values_new - old_values).clamp(-clip_eps, clip_eps)
        value_loss = torch.max((values_new - returns) ** 2,
                               (values_clipped - returns) ** 2).mean()

    # Entropy bonus (negated so that minimising the loss raises entropy)
    entropy_loss = -entropy.mean()

    total_loss = policy_loss + value_coeff * value_loss + entropy_coeff * entropy_loss

    # Diagnostics
    with torch.no_grad():
        clip_frac = ((ratios - 1).abs() > clip_eps).float().mean()
        approx_kl = ((old_log_probs - new_log_probs) ** 2).mean() * 0.5  # Second-order approx

    return total_loss, {
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'entropy': entropy.mean().item(),
        'clip_fraction': clip_frac.item(),
        'approx_kl': approx_kl.item(),
    }


# Demo: PPO loss sanity check
torch.manual_seed(42)


class SimpleNet(nn.Module):
    """Minimal actor-critic (one tanh hidden layer) used only to exercise ``ppo_loss``."""

    def __init__(self):
        super().__init__()
        self.l = nn.Linear(4, 64)
        self.l2 = nn.Linear(64, 2)
        self.v = nn.Linear(64, 1)

    def forward(self, x):
        f = torch.tanh(self.l(x))
        return Categorical(logits=self.l2(f)), self.v(f).squeeze(-1)


net = SimpleNet()
B = 64
states = torch.randn(B, 4); actions = torch.randint(0, 2, (B,))
# On the first update the rollout policy IS the current network, so the stored
# log-probs come from it. Random numbers are not valid log-probs (they can be
# > 0), and they would make ratio != 1 before any update has happened.
with torch.no_grad():
    old_dist = net(states)[0]
    old_lps = old_dist.log_prob(actions)
advs = torch.randn(B); rets = torch.rand(B)

loss, info = ppo_loss(net, states, actions, old_lps, advs, rets)
print("PPO loss breakdown:")
for k, v in info.items():
    print(f"  {k:20s}: {v:.4f}")
print(f"\nClip fraction {info['clip_fraction']:.4f} (should be near 0 on first update; near 0.1-0.2 indicates active clipping)")

Full PPO Training Loop

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class PPOAgent:
    """Clean PPO implementation following CleanRL/SB3 conventions.

    Discrete-action agent with a shared tanh MLP backbone, a policy head and a
    value head. Each call to ``collect_and_update`` does three things:
    gathers ``n_steps`` transitions, computes GAE(lambda) advantages, and runs
    ``n_epochs`` of clipped-surrogate updates over shuffled mini-batches.

    The env must provide ``reset() -> obs`` and
    ``step(action) -> (next_obs, reward, done)``.
    """

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, lam=0.95,
                 n_steps=256, n_epochs=4, batch_size=64,
                 clip_eps=0.2, value_coeff=0.5, entropy_coeff=0.01, seed=42):
        torch.manual_seed(seed)
        self.gamma, self.lam = gamma, lam
        self.n_steps, self.n_epochs = n_steps, n_epochs
        self.batch_size = batch_size
        self.clip_eps, self.vc, self.ec = clip_eps, value_coeff, entropy_coeff

        # Network: shared backbone + policy head + value head
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh())
        self.actor = nn.Linear(64, n_actions)
        self.critic = nn.Linear(64, 1)
        # Orthogonal weights and zero biases on every layer (CleanRL ``layer_init``)
        for layer, gain in ((self.net[0], 2 ** 0.5), (self.net[2], 2 ** 0.5),
                            (self.actor, 0.01), (self.critic, 1.0)):
            nn.init.orthogonal_(layer.weight, gain)
            nn.init.constant_(layer.bias, 0.0)

        self._params = (list(self.net.parameters()) + list(self.actor.parameters())
                        + list(self.critic.parameters()))
        self.optimizer = torch.optim.Adam(self._params, lr=lr, eps=1e-5)

        # Env and last observation, carried across calls so an episode can
        # span several rollouts instead of being cut at n_steps.
        self._env = None
        self._obs = None

    def _forward(self, x):
        """Return (Categorical policy, value estimate with last dim squeezed) for ``x``."""
        f = self.net(x)
        return Categorical(logits=self.actor(f)), self.critic(f).squeeze(-1)

    @torch.no_grad()
    def _act(self, state):
        """Sample an action for one observation; return (action int, log_prob, value)."""
        s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        dist, v = self._forward(s)
        a = dist.sample()
        return a.item(), dist.log_prob(a).squeeze(), v.squeeze()

    def _compute_gae(self, rewards, values, dones, next_v):
        """GAE(lambda) over a trajectory of any length.

        Args:
            rewards, values, dones: 1-D tensors of equal length.
            next_v: value of the state after the last step (tensor or number).

        Returns:
            (advantages, returns) where returns = advantages + values.
        """
        n = len(rewards)
        advs = torch.zeros(n)
        gae = 0.0
        all_v = values.tolist() + [float(next_v)]
        for t in reversed(range(n)):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * all_v[t + 1] * not_done - all_v[t]
            gae = delta + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        return advs, advs + values

    def _collect_rollout(self, env):
        """Step ``env`` for ``n_steps``; return (states, actions, old_log_probs, advantages, returns)."""
        # Reset only on the first call, for a new env, or (below) after ``done``
        if self._obs is None or self._env is not env:
            self._env, self._obs = env, env.reset()
        s = self._obs

        buf_s, buf_a, buf_lp = [], [], []
        buf_r, buf_d, buf_v = [], [], []
        for _ in range(self.n_steps):
            # Snapshot the observation BEFORE stepping: an env may mutate and
            # return the same object, which would overwrite the stored state.
            s_t = torch.as_tensor(s, dtype=torch.float32).clone()
            a, lp, v = self._act(s_t)
            ns, r, done = env.step(a)
            buf_s.append(s_t)
            buf_a.append(a)
            buf_lp.append(lp)
            buf_r.append(float(r))
            buf_d.append(float(done))
            buf_v.append(v)
            s = env.reset() if done else ns
        self._obs = s

        with torch.no_grad():
            _, next_v = self._forward(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))

        values_t = torch.stack(buf_v)
        advs, rets = self._compute_gae(
            torch.tensor(buf_r), values_t, torch.tensor(buf_d), next_v.squeeze()
        )
        return (torch.stack(buf_s), torch.tensor(buf_a, dtype=torch.long),
                torch.stack(buf_lp), advs, rets)

    def collect_and_update(self, env):
        """Collect one rollout from ``env`` and run the PPO update on it.

        Consecutive calls with the same env continue the current episode
        rather than resetting it.

        Returns:
            dict with the mean ``policy_loss``, ``value_loss``, ``entropy`` and
            ``clip_frac`` over all mini-batch updates.
        """
        states_t, actions_t, old_lps_t, advs, rets = self._collect_rollout(env)
        # std of a single sample is NaN, so only normalise real batches
        if len(advs) > 1:
            advs_norm = (advs - advs.mean()) / (advs.std() + 1e-8)
        else:
            advs_norm = advs

        # PPO epochs on collected data
        stats = {'policy_loss': [], 'value_loss': [], 'entropy': [], 'clip_frac': []}
        for _ in range(self.n_epochs):
            perm = torch.randperm(self.n_steps)
            for start in range(0, self.n_steps, self.batch_size):
                idx = perm[start:start + self.batch_size]
                dist, vals = self._forward(states_t[idx])
                new_lps = dist.log_prob(actions_t[idx])
                ents = dist.entropy()
                ratios = torch.exp(new_lps - old_lps_t[idx])
                adv_mb = advs_norm[idx]
                s1 = ratios * adv_mb
                s2 = ratios.clamp(1 - self.clip_eps, 1 + self.clip_eps) * adv_mb
                p_loss = -torch.min(s1, s2).mean()
                v_loss = F.mse_loss(vals, rets[idx])
                total = p_loss + self.vc * v_loss - self.ec * ents.mean()
                self.optimizer.zero_grad()
                total.backward()
                nn.utils.clip_grad_norm_(self._params, 0.5)
                self.optimizer.step()
                with torch.no_grad():
                    clip_frac = ((ratios - 1).abs() > self.clip_eps).float().mean()
                stats['policy_loss'].append(p_loss.item())
                stats['value_loss'].append(v_loss.item())
                stats['entropy'].append(ents.mean().item())
                stats['clip_frac'].append(clip_frac.item())

        return {k: sum(v) / len(v) for k, v in stats.items()}


# Training benchmark
class CartPoleEnv:
    """Simplified CartPole-like env (toy dynamics, not the real CartPole).

    The state is four floats drawn from a standard normal on reset.
    Action 1 moves every component up and any other action moves it down,
    by 0.05 * (i + 1) / 4 for component i. The episode ends when either of the
    first two components leaves [-2.4, 2.4]. Reward is 1.0 per surviving step
    and 0.0 on the terminal step.

    ``reset`` and ``step`` return fresh lists, so an observation a caller holds
    does not change when the env steps again.
    """

    def reset(self):
        self.s = torch.randn(4).tolist()
        return list(self.s)

    def step(self, a):
        direction = 1 if a == 1 else -1
        for i in range(4):
            self.s[i] += 0.05 * direction * (i + 1) / 4
        done = any(abs(x) > 2.4 for x in self.s[:2])
        return list(self.s), (0.0 if done else 1.0), done

torch.manual_seed(42)
env = CartPoleEnv()
agent = PPOAgent(4, 2, n_steps=128, n_epochs=4, batch_size=32, seed=42)

# 40 rollout/update cycles, logging every 10th
for update in range(40):
    info = agent.collect_and_update(env)
    step = update + 1
    if step % 10 != 0:
        continue
    print(
        f"Update {step:2d}: p_loss={info['policy_loss']:+.4f}  "
        f"v_loss={info['value_loss']:.4f}  "
        f"entropy={info['entropy']:.4f}  "
        f"clip_frac={info['clip_frac']:.3f}"
    )

Diagnostics & Monitoring

import torch


def ppo_diagnostics(ratios, old_log_probs, new_log_probs, advantages, clip_eps=0.2,
                    values=None, returns=None, entropy=None):
    """
    Key PPO training diagnostics.
    These help detect instability early — monitor during every update.

    Args:
        ratios: pi_new(a|s) / pi_old(a|s) for the sampled actions.
        old_log_probs, new_log_probs: log-probs of the same actions under the
            rollout policy and the current policy.
        advantages: advantage estimates for the batch (kept for API
            compatibility; no metric uses them).
        clip_eps: PPO clip range used for ``clip_fraction``.
        values, returns: optional critic predictions and empirical returns.
            Explained variance needs both and is NaN when either is missing.
        entropy: optional per-sample policy entropies (e.g. ``dist.entropy()``).
            When omitted, the current policy's entropy is estimated from the
            sampled actions as ``-mean(ratios * new_log_probs)``. This is the
            importance-weighted Monte Carlo estimate, valid when the actions
            were sampled from the old policy.

    Returns:
        dict of floats: clip_fraction, approx_kl, explained_variance, entropy.
    """
    with torch.no_grad():
        # Clip fraction: >20% means the policy moved a lot this update;
        # usually lr too high or too many epochs per rollout
        clip_frac = ((ratios - 1).abs() > clip_eps).float().mean()

        # Approximate KL: >0.02 suggests excessive policy change → early stop
        approx_kl = ((old_log_probs - new_log_probs) ** 2).mean() * 0.5

        # Explained variance: EV = 1 - Var(returns - values) / Var(returns)
        # Rule of thumb: EV near 0 or below means the critic explains little
        # of the return variance; EV close to 1 means it tracks returns well.
        if values is not None and returns is not None:
            var_y = returns.var()
            explained_var = 1 - (returns - values).var() / (var_y + 1e-8)
        else:
            explained_var = torch.tensor(float('nan'))

        # Entropy: should stay positive; collapse to 0 = mode collapse
        if entropy is None:
            ent = -(ratios * new_log_probs).mean()
        else:
            ent = torch.as_tensor(entropy, dtype=torch.float32).mean()

    return {
        'clip_fraction': clip_frac.item(),
        'approx_kl': approx_kl.item(),
        'explained_variance': explained_var.item(),
        'entropy': ent.item(),
    }


# Simulate healthy vs unhealthy training
torch.manual_seed(42)


def _report(diagnostics):
    """Print each diagnostic, flagging values past the rule-of-thumb thresholds."""
    for k, v in diagnostics.items():
        flag = " ⚠️" if (k == 'clip_fraction' and v > 0.2) or (k == 'approx_kl' and v > 0.02) else ""
        print(f"  {k}: {v:.4f}{flag}")


print("=== Healthy training (ratios near 1, KL small) ===")
old_lps = torch.randn(256) - 0.7
new_lps_healthy = old_lps + torch.randn(256) * 0.05  # Small change
ratios_healthy = torch.exp(new_lps_healthy - old_lps)
advs = torch.randn(256)
diag = ppo_diagnostics(ratios_healthy, old_lps, new_lps_healthy, advs)
_report(diag)

print("\n=== Problematic training (large policy change, high KL) ===")
new_lps_bad = old_lps + torch.randn(256) * 0.5  # Too large
ratios_bad = torch.exp(new_lps_bad - old_lps)
diag_bad = ppo_diagnostics(ratios_bad, old_lps, new_lps_bad, advs)
_report(diag_bad)

# A smaller clip_eps would RAISE the measured clip fraction, so the remedy for
# a high value is a smaller step (lower lr / fewer epochs), not a tighter eps.
print("\nRule of thumb thresholds:")
print("  clip_fraction > 0.2 → reduce lr or n_epochs")
print("  approx_kl > 0.02   → early stop epoch, reduce lr")
print("  entropy → 0        → increase entropy coefficient")