Clipped Surrogate Objective
PPO's core insight: reuse the same rollout data for multiple gradient updates (unlike A2C, which uses each batch only once). The danger is that repeated updates on increasingly stale data push the policy too far from the behavior policy that collected it. PPO clips the importance-sampling ratio to prevent this.
Importance Sampling Ratio
When we update the policy $\pi_\theta$ using data from old policy $\pi_{\theta_\text{old}}$, we correct for the distribution mismatch:
$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_\text{old}}(a_t|s_t)} = \exp\left(\log\pi_\theta(a_t|s_t) - \log\pi_{\theta_\text{old}}(a_t|s_t)\right)$$
Clipping Mechanism
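The clipped objective removes any incentive to push $r_t(\theta)$ outside $[1-\varepsilon, 1+\varepsilon]$. Taking the minimum with the clipped term makes it a pessimistic lower bound on the unclipped objective: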
$$L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t, \;\mathrm{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon)\hat{A}_t\right)\right]$$
import torch
import torch.nn.functional as F
def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """
    Compute the PPO clipped surrogate loss (a quantity to minimize).

    The per-sample objective is ``min(r * A, clip(r, 1 - eps, 1 + eps) * A)``
    with ``r = exp(new_log_probs - old_log_probs)``. Taking the minimum means:

    * A > 0: once r exceeds 1 + eps the objective stops growing, so nothing
      pushes the action's probability up any further; for r below 1 the
      unclipped term is the smaller one and is used unchanged.
    * A < 0: once r falls below 1 - eps the objective stops improving, so
      nothing pushes the probability down any further; for r above 1 the
      unclipped (more negative) term is used unchanged.

    Args:
        new_log_probs: log pi_theta(a|s) under the current policy.
        old_log_probs: log pi_old(a|s) recorded at rollout time (detached here).
        advantages: advantage estimates, broadcastable against the ratios.
        clip_eps: half-width of the allowed band around ratio 1.

    Returns:
        ``(loss, mean_ratio)``: the negated mean objective as a tensor, and the
        mean ratio as a Python float for logging.
    """
    # r = pi_new / pi_old computed in log space for numerical stability
    ratio = (new_log_probs - old_log_probs.detach()).exp()
    low, high = 1.0 - clip_eps, 1.0 + clip_eps
    # Elementwise pessimistic choice between unclipped and clipped objectives
    objective = torch.minimum(advantages * ratio,
                              advantages * torch.clamp(ratio, low, high))
    return -objective.mean(), ratio.mean().item()
# Visualize clipping: tabulate both surrogate terms and their minimum for a
# grid of ratios, once with a negative and once with a positive advantage.
import torch
CLIP_EPS = 0.2  # same default clip range as ppo_clipped_loss
ratios_to_test = [0.7, 0.9, 1.0, 1.1, 1.3]
print("ratio | adv | surr1 | surr2 | min (PPO loss contribution)")
for r_val in ratios_to_test:
    for adv_val in [-1.0, 1.0]:
        r = torch.tensor(r_val)
        adv = torch.tensor(adv_val)
        s1 = r * adv
        # Clip bounds derived from CLIP_EPS instead of hard-coded 0.8 / 1.2
        s2 = r.clamp(1.0 - CLIP_EPS, 1.0 + CLIP_EPS) * adv
        print(f"{r_val:.1f} | {adv_val:+.1f} | {s1.item():+.4f} | {s2.item():+.4f} | {torch.min(s1,s2).item():+.4f}")
    print()
PPO Network Architecture
import torch
import torch.nn as nn
from torch.distributions import Categorical
class PPOActorCritic(nn.Module):
    """
    Actor-critic network with a shared tanh MLP backbone (same layout as A2C).

    The actor head produces categorical logits over ``n_actions``; the critic
    head produces one scalar value per state. The network keeps no rollout
    state: the old log-probs needed for the PPO ratio must be stored by the
    caller (e.g. a rollout buffer) from ``get_action_and_value`` output.

    Initialization follows the CleanRL / Stable-Baselines3 convention:
    orthogonal weights with zero biases, gain sqrt(2) for hidden layers,
    0.01 for the policy head (near-uniform initial policy) and 1.0 for the
    value head.
    """
    def __init__(self, state_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.actor = nn.Linear(hidden_dim, n_actions)
        self.critic = nn.Linear(hidden_dim, 1)
        # Same initialization order as before, so weights are reproducible under a seed
        for layer in self.backbone:
            if isinstance(layer, nn.Linear):
                self._init_layer(layer, gain=2 ** 0.5)
        self._init_layer(self.actor, gain=0.01)
        self._init_layer(self.critic, gain=1.0)

    @staticmethod
    def _init_layer(layer, gain):
        """Orthogonal weight init with the given gain and a zero bias (CleanRL ``layer_init``)."""
        nn.init.orthogonal_(layer.weight, gain=gain)
        nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Return ``(Categorical over actions, state values)``; the value's trailing singleton dim is squeezed."""
        f = self.backbone(x)
        dist = Categorical(logits=self.actor(f))
        value = self.critic(f).squeeze(-1)
        return dist, value

    def get_action_and_value(self, x, action=None):
        """
        Shared entry point for rollout and update phases.

        If ``action`` is None a new action is sampled (rollout); otherwise the
        given actions are re-scored under the current policy (update phase).

        Returns:
            ``(action, log_prob(action), entropy, value)``.
        """
        dist, value = self(x)
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value
# Smoke test: build the network, take one single-state rollout step, then
# re-score a stored batch the way the update phase does.
torch.manual_seed(42)
net = PPOActorCritic(state_dim=4, n_actions=2)
print(f"Parameters: {sum(w.numel() for w in net.parameters()):,}")
# Rollout step: sample a fresh action for one observation (batch of size 1)
s = torch.randn(4)
action, log_prob, entropy, value = net.get_action_and_value(s[None, :])
print(
    f"action={action.item()}, log_prob={log_prob.item():.4f}, "
    f"entropy={entropy.item():.4f}, V={value.item():.4f}"
)
# Update step: pass the stored actions back in so nothing new is sampled
states_batch, actions_batch = torch.randn(32, 4), torch.randint(0, 2, (32,))
_, new_lps, new_ents, new_vals = net.get_action_and_value(states_batch, actions_batch)
print(f"Batch update shapes: log_probs={new_lps.shape}, values={new_vals.shape}")
Rollout Buffer
import torch
import torch.nn as nn
from torch.distributions import Categorical
class RolloutBuffer:
    """
    Fixed-size buffer for PPO rollouts.

    Stores (state, action, old_log_prob, reward, done, value) for exactly
    ``n_steps`` transitions. Once full, ``compute_advantages`` runs GAE(lambda)
    over the whole buffer and ``get_minibatches`` yields one shuffled pass of
    mini-batches (call it once per PPO epoch). Call ``reset`` before
    collecting the next rollout.

    ``dones[t] == 1`` means the episode ended after step t: both the bootstrap
    value and the carried-over GAE term are masked at that step.
    """
    def __init__(self, n_steps, state_dim, gamma=0.99, lam=0.95):
        self.n, self.state_dim = n_steps, state_dim
        self.gamma, self.lam = gamma, lam
        self.reset()

    def reset(self):
        """Allocate zeroed storage for ``n`` transitions and rewind the write pointer."""
        self.states = torch.zeros(self.n, self.state_dim)
        self.actions = torch.zeros(self.n, dtype=torch.long)
        self.old_log_probs = torch.zeros(self.n)
        self.rewards = torch.zeros(self.n)
        self.dones = torch.zeros(self.n)
        self.values = torch.zeros(self.n)
        self.ptr = 0  # index of the next slot to write

    def add(self, state, action, log_prob, reward, done, value):
        """
        Store one transition; ``log_prob`` and ``value`` are detached from autograd.

        Raises:
            IndexError: if the buffer already holds ``n`` transitions.
        """
        if self.ptr >= self.n:
            raise IndexError(f"RolloutBuffer is full ({self.n} transitions); call reset() first")
        self.states[self.ptr] = torch.as_tensor(state)
        self.actions[self.ptr] = action
        self.old_log_probs[self.ptr] = log_prob.detach()
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = float(done)
        self.values[self.ptr] = value.detach()
        self.ptr += 1

    def _check_full(self):
        """Raise unless exactly ``n`` transitions were added since the last reset."""
        if self.ptr != self.n:
            # Unfilled slots are zeros; computing over them would silently corrupt GAE
            raise RuntimeError(
                f"RolloutBuffer holds {self.ptr}/{self.n} transitions; "
                "fill it completely before computing advantages or mini-batches"
            )

    def compute_advantages(self, next_value):
        """
        GAE-Lambda advantage estimation over the full buffer.

        Args:
            next_value: V(s) of the state following the last stored step, as a
                one-element tensor or a Python number.

        Returns:
            ``(advantages, returns)``: advantages normalized to zero mean / unit
            std, and unnormalized returns = raw advantages + stored values.

        Raises:
            RuntimeError: if the buffer is not full.
        """
        self._check_full()
        advs = torch.zeros(self.n)
        gae = 0.0
        all_v = self.values.tolist() + [float(next_value)]
        for t in reversed(range(self.n)):
            not_done = 1 - self.dones[t]
            delta = self.rewards[t] + self.gamma * all_v[t + 1] * not_done - all_v[t]
            gae = delta + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        returns = advs + self.values
        # Normalize advantages (returns keep their raw scale as value targets)
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        return advs, returns

    def get_minibatches(self, advantages, returns, batch_size=64):
        """
        Yield one shuffled pass of mini-batches (call once per PPO epoch).

        Each item is ``(states, actions, old_log_probs, advantages, returns)``;
        the last batch is smaller when ``n`` is not a multiple of ``batch_size``.

        Raises:
            RuntimeError: (on first iteration) if the buffer is not full.
        """
        self._check_full()
        perm = torch.randperm(self.n)
        for start in range(0, self.n, batch_size):
            idx = perm[start:start + batch_size]
            yield (
                self.states[idx], self.actions[idx],
                self.old_log_probs[idx], advantages[idx], returns[idx],
            )
# Demo: populate the buffer with synthetic transitions, then inspect GAE
# output and count the mini-batches one epoch produces.
torch.manual_seed(42)
buf = RolloutBuffer(n_steps=256, state_dim=4)
for _ in range(buf.n):
    obs = torch.randn(4)
    act = torch.randint(0, 2, ()).item()
    rew = float(torch.randn(1).item())
    buf.add(obs, act, torch.tensor(-0.7), rew, False, torch.tensor(0.5))
advs, rets = buf.compute_advantages(next_value=torch.tensor(0.4))
print(f"Advantages: mean={advs.mean():.4f}, std={advs.std():.4f} (should be ≈0, ≈1 after normalize)")
print(f"Returns: mean={rets.mean():.4f}, std={rets.std():.4f}")
mb_count = len(list(buf.get_minibatches(advs, rets, batch_size=64)))
print(f"Mini-batches per epoch: {mb_count} (256 / 64 = 4)")
PPO Loss Computation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
def ppo_loss(net, states, actions, old_log_probs, advantages, returns,
             clip_eps=0.2, value_coeff=0.5, entropy_coeff=0.01, old_values=None):
    """
    Full PPO loss for one mini-batch (called many times per rollout: K epochs
    times the number of mini-batches).

    Args:
        net: callable mapping ``states`` to ``(distribution, values)``; the
            distribution must provide ``log_prob`` and ``entropy``.
        states, actions: stored rollout inputs, re-evaluated under the current policy.
        old_log_probs: log pi_old(a|s) recorded at collection time (treated as constant).
        advantages: advantage estimates for the clipped surrogate.
        returns: value-function targets.
        clip_eps: clip range for the probability ratio (and for the value
            prediction when ``old_values`` is given).
        value_coeff, entropy_coeff: weights of the value loss and entropy bonus.
        old_values: optional V_old(s) recorded at collection time. When given,
            the value loss is clipped as in CleanRL / SB3: the prediction is
            also evaluated clipped to ``old_values +/- clip_eps``, and the larger
            (pessimistic) squared error is used. When None (default), plain MSE.

    Returns:
        ``(total_loss, info)`` where ``info`` holds float diagnostics:
        policy_loss, value_loss, entropy, clip_fraction, approx_kl.
    """
    # Re-evaluate stored (state, action) under current policy
    dist_new, values_new = net(states)
    new_log_probs = dist_new.log_prob(actions)
    entropy = dist_new.entropy()
    # Importance sampling ratio
    ratios = torch.exp(new_log_probs - old_log_probs.detach())
    # Clipped surrogate loss
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    # Value loss: plain MSE by default; optional clipping around the old predictions
    if old_values is None:
        value_loss = F.mse_loss(values_new, returns)
    else:
        old_values = old_values.detach()
        values_clipped = old_values + (values_new - old_values).clamp(-clip_eps, clip_eps)
        value_loss = torch.maximum((values_new - returns) ** 2,
                                   (values_clipped - returns) ** 2).mean()
    # Entropy bonus (negated: minimizing the loss maximizes entropy)
    entropy_loss = -entropy.mean()
    total_loss = policy_loss + value_coeff * value_loss + entropy_coeff * entropy_loss
    # Diagnostics (no gradient needed)
    with torch.no_grad():
        clip_frac = ((ratios - 1).abs() > clip_eps).float().mean()
        approx_kl = ((old_log_probs - new_log_probs) ** 2).mean() * 0.5  # Second-order approx
    return total_loss, {
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'entropy': entropy.mean().item(),
        'clip_fraction': clip_frac.item(),
        'approx_kl': approx_kl.item(),
    }
# Demo: PPO loss sanity check on the very first update of a rollout.
torch.manual_seed(42)
class SimpleNet(nn.Module):
    """Minimal actor-critic: one tanh hidden layer feeding a 2-way policy head and a value head."""
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(4, 64)
        self.l2 = nn.Linear(64, 2)
        self.v = nn.Linear(64, 1)
    def forward(self, x):
        f = torch.tanh(self.l(x))
        return Categorical(logits=self.l2(f)), self.v(f).squeeze(-1)
net = SimpleNet()
B = 64
states = torch.randn(B, 4); actions = torch.randint(0, 2, (B,))
# Stored log-probs must be genuine log-probabilities (<= 0) of the behavior
# policy (randn - 0.7 is positive about a quarter of the time). Before any
# gradient step the behavior policy IS the current network, so record them
# from it: every ratio is exactly 1 on this first update.
with torch.no_grad():
    old_dist, _ = net(states)
    old_lps = old_dist.log_prob(actions)
advs = torch.randn(B); rets = torch.rand(B)
loss, info = ppo_loss(net, states, actions, old_lps, advs, rets)
print("PPO loss breakdown:")
for k, v in info.items():
    print(f" {k:20s}: {v:.4f}")
print(f"\nClip fraction {info['clip_fraction']:.4f} (should be near 0 on first update; near 0.1-0.2 indicates active clipping)")
Full PPO Training Loop
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
class PPOAgent:
    """
    Clean PPO implementation following CleanRL/SB3 conventions.

    Each ``collect_and_update`` call gathers ``n_steps`` transitions starting
    from ``env.reset()``. It then computes GAE(lambda) advantages and runs
    ``n_epochs`` passes of shuffled mini-batch updates. Each update combines
    the clipped surrogate, a value MSE and an entropy bonus.

    The environment is used through ``reset() -> state`` and
    ``step(action) -> (next_state, reward, done)``.
    """
    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, lam=0.95,
                 n_steps=256, n_epochs=4, batch_size=64,
                 clip_eps=0.2, value_coeff=0.5, entropy_coeff=0.01, seed=42):
        torch.manual_seed(seed)
        self.gamma, self.lam = gamma, lam
        self.n_steps, self.n_epochs = n_steps, n_epochs
        self.batch_size = batch_size
        self.clip_eps, self.vc, self.ec = clip_eps, value_coeff, entropy_coeff
        # Shared tanh trunk + separate policy (actor) and value (critic) heads
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh())
        self.actor = nn.Linear(64, n_actions)
        self.critic = nn.Linear(64, 1)
        # CleanRL-style layer init: orthogonal weights, zero biases. Gain 0.01 on
        # the policy head keeps the initial policy close to uniform.
        for layer, gain in ((self.net[0], 2 ** 0.5), (self.net[2], 2 ** 0.5),
                            (self.actor, 0.01), (self.critic, 1.0)):
            nn.init.orthogonal_(layer.weight, gain)
            nn.init.constant_(layer.bias, 0.0)
        self.optimizer = torch.optim.Adam(self._all_params(), lr=lr, eps=1e-5)

    def _all_params(self):
        """Return every trainable parameter (trunk + actor + critic) as one list."""
        return [*self.net.parameters(), *self.actor.parameters(), *self.critic.parameters()]

    def _forward(self, x):
        """Map a batch of states to ``(action distribution, state values)``; values have shape (B,)."""
        f = self.net(x)
        return Categorical(logits=self.actor(f)), self.critic(f).squeeze(-1)

    @torch.no_grad()
    def _act(self, state):
        """
        Sample one action for a single state during rollout collection.

        Runs without autograd: rollout log-probs and values are stored as
        constants, so building a graph here would only waste memory and time.

        Returns:
            ``(action as int, 0-d log-prob tensor, 0-d value tensor)``.
        """
        s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        dist, v = self._forward(s)
        a = dist.sample()
        return a.item(), dist.log_prob(a).squeeze(), v.squeeze()

    def _compute_gae(self, rewards, values, dones, next_v):
        """
        GAE(lambda) over one rollout of ``n_steps`` transitions.

        ``dones[t] == 1`` means the episode ended after step t, which masks both
        the bootstrap value and the carried-over GAE term at that step.

        Returns:
            ``(advantages, returns)`` with returns = advantages + values.
        """
        advs = torch.zeros(self.n_steps)
        gae = 0.0
        all_v = values.tolist() + [next_v.item()]
        for t in reversed(range(self.n_steps)):
            not_done = 1 - dones[t]
            d = rewards[t] + self.gamma * all_v[t + 1] * not_done - all_v[t]
            gae = d + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        return advs, advs + values

    def collect_and_update(self, env):
        """
        Collect one rollout from ``env`` and run the PPO update epochs on it.

        Returns:
            Per-mini-batch statistics averaged over all epochs:
            'policy_loss', 'value_loss', 'entropy', 'clip_frac'.
        """
        # --- Rollout collection (each call starts from env.reset()) ---
        buf_s, buf_a, buf_lp = [], [], []
        buf_r, buf_d, buf_v = [], [], []
        s = env.reset()
        for _ in range(self.n_steps):
            a, lp, v = self._act(s)  # already gradient-free
            ns, r, done = env.step(a)
            buf_s.append(torch.as_tensor(s, dtype=torch.float32))
            buf_a.append(a)
            buf_lp.append(lp)
            buf_r.append(float(r))
            buf_d.append(float(done))
            buf_v.append(v)
            s = env.reset() if done else ns
        with torch.no_grad():
            # Bootstrap value for the state following the last collected step
            _, next_v = self._forward(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))
        states_t = torch.stack(buf_s)
        actions_t = torch.tensor(buf_a, dtype=torch.long)
        old_lps_t = torch.stack(buf_lp)
        values_t = torch.stack(buf_v)
        advs, rets = self._compute_gae(
            torch.tensor(buf_r), values_t, torch.tensor(buf_d), next_v.squeeze()
        )
        advs_norm = (advs - advs.mean()) / (advs.std() + 1e-8)

        # --- PPO epochs on the collected data ---
        params = self._all_params()
        stats = {'policy_loss': [], 'value_loss': [], 'entropy': [], 'clip_frac': []}
        for _ in range(self.n_epochs):
            perm = torch.randperm(self.n_steps)
            for start in range(0, self.n_steps, self.batch_size):
                idx = perm[start:start + self.batch_size]
                dist, vals = self._forward(states_t[idx])
                new_lps = dist.log_prob(actions_t[idx])
                ents = dist.entropy()
                ratios = torch.exp(new_lps - old_lps_t[idx])
                adv_mb = advs_norm[idx]
                s1 = ratios * adv_mb
                s2 = ratios.clamp(1 - self.clip_eps, 1 + self.clip_eps) * adv_mb
                p_loss = -torch.min(s1, s2).mean()
                v_loss = F.mse_loss(vals, rets[idx])
                total = p_loss + self.vc * v_loss - self.ec * ents.mean()
                self.optimizer.zero_grad()
                total.backward()
                nn.utils.clip_grad_norm_(params, 0.5)
                self.optimizer.step()
                with torch.no_grad():
                    # Ratios were computed before this step, matching the loss above
                    clip_frac = ((ratios - 1).abs() > self.clip_eps).float().mean()
                stats['policy_loss'].append(p_loss.item())
                stats['value_loss'].append(v_loss.item())
                stats['entropy'].append(ents.mean().item())
                stats['clip_frac'].append(clip_frac.item())
        return {k: sum(v) / len(v) for k, v in stats.items()}
# Training benchmark
class CartPoleEnv:
    """
    Simplified CartPole-like toy env (not the real cart-pole physics).

    State: 4 floats drawn from N(0, 1) at reset. Action 1 shifts component i
    up by 0.05 * (i+1)/4, any other action shifts it down by the same amount.
    The episode ends once |s[0]| or |s[1]| exceeds 2.4. Reward is 1.0 per
    surviving step and 0.0 on the terminating step.
    """
    def reset(self):
        # Tensor.tolist() already yields Python floats; no NumPy round-trip needed
        self.s = torch.randn(4).tolist()
        return list(self.s)  # copy: callers must not alias the internal state

    def step(self, a):
        for i in range(4):
            self.s[i] += 0.05 * (1 if a == 1 else -1) * (i + 1) / 4
        done = any(abs(x) > 2.4 for x in self.s[:2])
        # Return a copy so earlier observations are not mutated by later steps
        return list(self.s), (0.0 if done else 1.0), done
# Train for 40 rollout/update cycles, reporting averaged stats every 10th one.
torch.manual_seed(42)
env = CartPoleEnv()
agent = PPOAgent(4, 2, n_steps=128, n_epochs=4, batch_size=32, seed=42)
for update in range(40):
    info = agent.collect_and_update(env)
    if (update + 1) % 10 != 0:
        continue
    print(
        f"Update {update+1:2d}: p_loss={info['policy_loss']:+.4f} "
        f"v_loss={info['value_loss']:.4f} "
        f"entropy={info['entropy']:.4f} "
        f"clip_frac={info['clip_frac']:.3f}"
    )
Diagnostics & Monitoring
import torch
def ppo_diagnostics(ratios, old_log_probs, new_log_probs, advantages, clip_eps=0.2,
                    values=None, returns=None, entropy=None):
    """
    Key PPO training diagnostics; monitor them every update to catch instability early.

    Args:
        ratios: probability ratios pi_new(a|s) / pi_old(a|s).
        old_log_probs, new_log_probs: log-probs of the taken actions under the
            behavior policy and the current policy.
        advantages: accepted for backward compatibility; no metric uses it.
        clip_eps: PPO clip range used to count clipped ratios.
        values: optional critic predictions V(s) for the batch.
        returns: optional value targets for the same batch.
        entropy: optional per-sample policy entropies.

    Returns:
        Dict with float entries:
        - ``clip_fraction``
        - ``approx_kl``
        - ``explained_variance``: NaN unless both ``values`` and ``returns`` are given.
        - ``entropy``: NaN unless ``entropy`` is given.
        Nothing is sampled, so the global RNG state is not touched.
    """
    nan = float("nan")
    with torch.no_grad():
        # Clip fraction: share of ratios outside [1-eps, 1+eps]. Values above ~0.2
        # mean the policy moves too far per update (lr too high / too many epochs).
        # Note: shrinking clip_eps *raises* this number.
        clip_frac = ((ratios - 1).abs() > clip_eps).float().mean().item()
        # Approximate KL (second-order estimator): >0.02 suggests early stopping
        approx_kl = (((old_log_probs - new_log_probs) ** 2).mean() * 0.5).item()
        # Explained variance: EV = 1 - Var(returns - values) / Var(returns).
        # Rough heuristic: EV near 1 means the critic tracks returns well; EV <= 0
        # means it does no better than predicting the mean return.
        if values is not None and returns is not None:
            var_y = returns.var()
            explained_var = (1 - (returns - values).var() / (var_y + 1e-8)).item()
        else:
            explained_var = nan
        # Entropy: should stay positive; a collapse towards 0 means a near-deterministic policy
        mean_entropy = entropy.mean().item() if entropy is not None else nan
    return {
        'clip_fraction': clip_frac,
        'approx_kl': approx_kl,
        'explained_variance': explained_var,
        'entropy': mean_entropy,
    }
# Simulate healthy vs unhealthy training
torch.manual_seed(42)
print("=== Healthy training (ratios near 1, KL small) ===")
old_lps = torch.randn(256) - 0.7
new_lps_healthy = old_lps + torch.randn(256) * 0.05  # Small change
ratios_healthy = torch.exp(new_lps_healthy - old_lps)
advs = torch.randn(256)
diag = ppo_diagnostics(ratios_healthy, old_lps, new_lps_healthy, advs)
for k, v in diag.items():
    print(f" {k}: {v:.4f}")
print("\n=== Problematic training (large policy change, high KL) ===")
new_lps_bad = old_lps + torch.randn(256) * 0.5  # Too large
ratios_bad = torch.exp(new_lps_bad - old_lps)
diag_bad = ppo_diagnostics(ratios_bad, old_lps, new_lps_bad, advs)
for k, v in diag_bad.items():
    flag = " ⚠️" if (k == 'clip_fraction' and v > 0.2) or (k == 'approx_kl' and v > 0.02) else ""
    print(f" {k}: {v:.4f}{flag}")
print("\nRule of thumb thresholds:")
# A high clip fraction means the policy moves too far per update; shrinking
# clip_eps would only raise the metric, so the remedy is a smaller step.
print(" clip_fraction > 0.2 → reduce lr or n_epochs")
print(" approx_kl > 0.02 → early stop epoch, reduce lr")
print(" entropy → 0 → increase entropy coefficient")