Policy Gradient Theorem
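The policy gradient theorem (in its REINFORCE form) writes the gradient of the expected return $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[G_0]$ as an expectation over trajectories, so averaging the bracketed quantity over sampled episodes gives an unbiased gradient estimate:
$$\nabla_\theta J(\theta) = \mathbb{E}_\tau\left[\sum_{t=0}^{T} \gamma^{t}\, \nabla_\theta \log \pi_\theta(A_t|S_t) \cdot G_t\right]$$
where $G_t = \sum_{k=t}^{T} \gamma^{k-t} R_{k+1}$ is the discounted return from step $t$. The key insight is the log-derivative (likelihood-ratio) trick, $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$. The gradient only requires differentiating $\log \pi_\theta(A_t|S_t)$ at the action that was actually sampled. We therefore never have to backpropagate through the stochastic sampling step itself, nor through the (unknown) environment dynamics. In practice, including in the code below, the $\gamma^t$ factor is usually dropped. This slightly biases the estimate when $\gamma < 1$, but it is the standard way REINFORCE is implemented.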
Policy Network
Discrete Actions: Categorical Distribution
import torch
import torch.nn as nn
from torch.distributions import Categorical
class DiscretePolicy(nn.Module):
    """
    Categorical (softmax) policy over a finite set of actions.

    A two-hidden-layer MLP with Tanh activations maps a state vector to one
    unnormalized logit per action; ``forward`` wraps those logits in a
    ``torch.distributions.Categorical``.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Layers are built in a fixed order so that, under a fixed torch seed,
        # the initial weights are reproducible. Tanh is a common choice for
        # small policy MLPs; whether it beats ReLU is task-dependent.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of states ``x`` to a Categorical distribution over actions."""
        return Categorical(logits=self.net(x))

    def act(self, state):
        """
        Draw one action for a single (unbatched) state.

        Returns:
            (action, log_prob): ``action`` is a Python int; ``log_prob`` is a
            1-element tensor that still carries the autograd graph, so it can
            be used directly in a policy-gradient loss.
        """
        batch = torch.as_tensor(state, dtype=torch.float32)[None]  # add batch dim
        distribution = self(batch)
        sampled = distribution.sample()
        return sampled.item(), distribution.log_prob(sampled)
torch.manual_seed(42)
policy = DiscretePolicy(state_dim=4, n_actions=2)
# Demo: draw five actions for the same random state to show sampling variability.
state = torch.randn(4)
obs = state.numpy()
for _ in range(5):
    action, log_prob = policy.act(obs)
    prob = log_prob.exp().item()
    print(f"Action: {action}, log π(a|s): {log_prob.item():.4f}, π(a|s): {prob:.4f}")
Continuous Actions: Gaussian Distribution
import torch
import torch.nn as nn
from torch.distributions import Normal
class GaussianPolicy(nn.Module):
    """
    Diagonal-Gaussian policy for continuous actions.

    A shared Tanh MLP feeds two linear heads: one for the per-dimension mean
    and one for the per-dimension log standard deviation, which is clamped to
    ``[log_std_min, log_std_max]`` before being exponentiated.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128,
                 log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_min, log_std_max
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        # Heads are created after the backbone (mean first) so seeded
        # initialisation follows a fixed layer order.
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return a ``Normal`` with independent per-dimension mean/std for states ``x``."""
        h = self.backbone(x)
        mu = self.mean_head(h)
        std = torch.clamp(self.log_std_head(h), self.log_std_min, self.log_std_max).exp()
        return Normal(mu, std)

    def act(self, state):
        """
        Sample one action for a single 1-D state.

        Returns:
            (action, log_prob): ``action`` has the batch dimension removed;
            ``log_prob`` is a 1-element tensor holding the joint log-density,
            i.e. the sum of the per-dimension log-densities.
        """
        obs = torch.as_tensor(state, dtype=torch.float32)[None]  # add batch dim
        dist = self(obs)
        sample = dist.sample()
        joint_log_prob = dist.log_prob(sample).sum(dim=-1)
        return sample[0], joint_log_prob
torch.manual_seed(42)
# state_dim=8 matches LunarLanderContinuous's observation size. That environment
# has a 2-D action space, so action_dim=3 here is just an arbitrary demo value.
policy = GaussianPolicy(state_dim=8, action_dim=3)
state = torch.randn(8).numpy()
action, log_prob = policy.act(state)
print(f"Action shape: {action.shape}")
print(f"Action: {action.tolist()}")
print(f"log π(a|s): {log_prob.item():.4f}")
REINFORCE Algorithm
Computing Discounted Returns
import torch
def compute_returns(rewards, gamma=0.99, normalize=True):
    """
    Compute the discounted return G_t for every timestep of one episode.

    G_t = R_{t+1} + gamma * R_{t+2} + gamma^2 * R_{t+3} + ...
    where ``rewards[t]`` plays the role of R_{t+1}.

    Args:
        rewards: per-step rewards for a single episode (any iterable of numbers).
        gamma: discount factor.
        normalize: if True and there are at least two timesteps, standardise the
            returns to zero mean / unit std (reduces gradient variance). A
            single-step episode is left unnormalised because the sample std of
            one value is undefined (NaN).

    Returns:
        1-D float32 tensor with one return per reward.
    """
    running = 0.0
    returns = []
    # Accumulate backwards (G_t = r_t + gamma * G_{t+1}). Appending and reversing
    # once is O(n), unlike repeated list.insert(0, ...), which is O(n^2).
    for r in reversed(list(rewards)):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)
    if normalize and len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
# Demo: 10-step episode with a single sparse reward (1.0) at the final step,
# showing how discounting propagates it back to earlier timesteps.
rewards = [0.0] * 9 + [1.0]
returns_raw = compute_returns(rewards, gamma=0.99, normalize=False)
returns_norm = compute_returns(rewards, gamma=0.99, normalize=True)
print("t | reward | G_t (raw) | G_t (normalized)")
for t, (r, g_raw, g_norm) in enumerate(zip(rewards, returns_raw, returns_norm)):
    print(f"{t:2d} | {r:6.1f} | {g_raw.item():9.4f} | {g_norm.item():16.4f}")
REINFORCE Gradient Update
import torch
import torch.nn as nn
from torch.distributions import Categorical
class DiscretePolicy(nn.Module):
    """Compact Categorical policy: MLP ``s -> h -> h -> a`` with Tanh activations."""

    def __init__(self, s, a, h=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(s, h),
            nn.Tanh(),
            nn.Linear(h, h),
            nn.Tanh(),
            nn.Linear(h, a),
        )

    def forward(self, x):
        """Return a Categorical over the ``a`` actions for states ``x``."""
        logits = self.net(x)
        return Categorical(logits=logits)

    def act(self, state):
        """Sample one action for a single state; return ``(int action, 1-element log_prob)``."""
        obs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(dim=0)
        dist = self(obs)
        choice = dist.sample()
        log_p = dist.log_prob(choice)
        return choice.item(), log_p
def reinforce_update(log_probs, returns, optimizer, normalize=True, max_grad_norm=0.5):
    """
    Apply one REINFORCE gradient step from a single episode.

        loss = -mean_t [ log pi(a_t|s_t) * G_t ]

    The sign is negative because we want gradient *ascent* on J while PyTorch
    optimizers minimise. Using the mean rather than the sum only rescales the
    step size.

    Args:
        log_probs: per-step log-probabilities of the taken actions; each entry
            may be a 0-d or 1-element tensor (e.g. as returned by ``act``).
        returns: per-step returns G_t, as Python numbers or tensors. They are
            treated as constants (detached) in the loss.
        optimizer: optimizer over the policy parameters.
        normalize: standardise the returns (zero mean / unit std) when there are
            at least two of them. A one-element list is never normalised: its
            sample std is NaN, and centring would zero the gradient anyway.
        max_grad_norm: global gradient-norm clip applied before the step.

    Returns:
        The scalar loss value (float), computed before the optimizer step.

    Raises:
        ValueError: if the episode is empty or the lengths do not match.
    """
    if not log_probs:
        raise ValueError("reinforce_update needs at least one timestep")
    # Flatten to 1-D so (T, 1)-shaped log-probs and (T,) returns multiply
    # element-wise instead of broadcasting into a (T, T) outer product.
    log_probs_t = torch.cat([lp.reshape(-1) for lp in log_probs])
    returns_t = torch.cat(
        [torch.as_tensor(g, dtype=torch.float32).reshape(-1) for g in returns]
    ).detach()
    if log_probs_t.shape != returns_t.shape:
        raise ValueError(
            f"got {log_probs_t.numel()} log-probs but {returns_t.numel()} returns"
        )
    if normalize and returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
    # Policy gradient loss: -E[log pi * G]
    policy_loss = -(log_probs_t * returns_t).mean()
    optimizer.zero_grad()
    policy_loss.backward()
    # Clip across every parameter group, not only the first one.
    params = [p for group in optimizer.param_groups for p in group["params"]]
    nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
    optimizer.step()
    return policy_loss.item()
# Demo: one REINFORCE step on a single sampled action with a positive return.
# Expected behaviour: a positive return should raise the probability of the
# sampled action (compare the before/after probabilities printed below).
torch.manual_seed(42)
policy = DiscretePolicy(4, 2)
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
state = [0.1, -0.2, 0.3, 0.05]
state_batch = torch.tensor([state])  # shape (1, 4): a batch holding the single state
_, log_prob = policy.act(state)
probs_before = policy(state_batch).probs.detach().clone()
loss = reinforce_update([log_prob], [10.0], optimizer)
probs_after = policy(state_batch).probs.detach()
print(f"Loss: {loss:.4f}")
print(f"Probs before update: {probs_before.squeeze().tolist()}")
print(f"Probs after update: {probs_after.squeeze().tolist()}")
Variance Reduction with a Baseline
import torch
import torch.nn as nn
from torch.distributions import Categorical
class PolicyWithBaseline(nn.Module):
    """
    Actor with a learned state-value baseline on a shared trunk.

    A Tanh MLP trunk feeds two heads: ``policy_head`` produces action logits
    (wrapped in a Categorical) and ``value_head`` produces one scalar V(s) per
    state. V(s) is intended as a REINFORCE baseline; how the value loss is
    weighted is left to the training code.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)  # baseline V(s)

    def forward(self, x):
        """Return ``(Categorical over actions, V(x) with its trailing size-1 dim squeezed)``."""
        h = self.backbone(x)
        pi = Categorical(logits=self.policy_head(h))
        v = self.value_head(h).squeeze(-1)
        return pi, v

    def act(self, state):
        """Sample one action for a single state; return ``(int action, log_prob, value)``."""
        obs = torch.as_tensor(state, dtype=torch.float32)[None, ...]
        pi, v = self(obs)
        chosen = pi.sample()
        return chosen.item(), pi.log_prob(chosen), v
def reinforce_with_baseline(log_probs, values, rewards, optimizer, gamma=0.99, beta=0.5,
                            max_grad_norm=0.5):
    """
    REINFORCE with a learned state-value baseline; one update per episode.

        advantage_t = G_t - V(s_t)
        policy_loss = -mean_t [ log pi(a_t|s_t) * advantage_t ]
        value_loss  = mean_t [ (V(s_t) - G_t)^2 ]
        total_loss  = policy_loss + beta * value_loss

    An advantage > 0 means the action did better than the baseline expected.
    Returns are standardised (zero mean / unit std) when the episode has at
    least two steps, so V is fitted on that same standardised scale. A
    one-step episode is left unstandardised (its sample std would be NaN).

    Args:
        log_probs: per-step log-probs of the taken actions (0-d or 1-element tensors).
        values: per-step baseline estimates V(s_t) (0-d or 1-element tensors).
        rewards: per-step rewards; ``rewards[t]`` plays the role of R_{t+1}.
        optimizer: optimizer over policy and value parameters.
        gamma: discount factor.
        beta: weight of the value loss in the total loss.
        max_grad_norm: global gradient-norm clip applied before the step.

    Returns:
        (total_loss, policy_loss, value_loss) as floats, computed before the step.

    Raises:
        ValueError: if the episode is empty or the per-step lengths differ.
    """
    if not rewards:
        raise ValueError("reinforce_with_baseline needs at least one timestep")
    # Discounted returns, accumulated backwards in O(n).
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    returns_t = torch.tensor(returns, dtype=torch.float32)
    if returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
    # Flatten so (T, 1)-shaped entries cannot broadcast into (T, T) products.
    log_probs_t = torch.cat([lp.reshape(-1) for lp in log_probs])
    values_t = torch.cat([v.reshape(-1) for v in values])
    if not (log_probs_t.shape == values_t.shape == returns_t.shape):
        raise ValueError(
            f"length mismatch: {log_probs_t.numel()} log-probs, "
            f"{values_t.numel()} values, {returns_t.numel()} rewards"
        )
    advantages = returns_t - values_t.detach()  # stop gradient through the baseline
    policy_loss = -(log_probs_t * advantages).mean()
    value_loss = nn.functional.mse_loss(values_t, returns_t)
    total_loss = policy_loss + beta * value_loss
    optimizer.zero_grad()
    total_loss.backward()
    # Clip across every parameter group, not only the first one.
    params = [p for group in optimizer.param_groups for p in group["params"]]
    nn.utils.clip_grad_norm_(params, max_grad_norm)
    optimizer.step()
    return total_loss.item(), policy_loss.item(), value_loss.item()
torch.manual_seed(42)
net = PolicyWithBaseline(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=3e-4)
# Synthetic 10-step episode: random states and uniform random rewards in [0, 1).
states = [torch.randn(4) for _ in range(10)]
log_probs, values, rewards = [], [], []
for obs in states:
    _action, step_lp, step_v = net.act(obs.numpy())
    log_probs.append(step_lp.squeeze())
    values.append(step_v.squeeze())
    rewards.append(float(torch.rand(1).item()))
total, p_loss, v_loss = reinforce_with_baseline(log_probs, values, rewards, opt)
print(f"Total loss: {total:.4f} (policy: {p_loss:.4f}, value: {v_loss:.4f})")
Full REINFORCE Agent
import torch
import torch.nn as nn
from torch.distributions import Categorical
class REINFORCEAgent:
    """
    Episodic REINFORCE agent for discrete actions, with an optional learned
    state-value baseline.

    Per episode: call ``act(state)`` each step and ``store_reward(r)`` after
    each step, then call ``finish_episode()`` once to run one gradient update.
    """

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, use_baseline=True, seed=42,
                 hidden_dim=128, beta=0.5, max_grad_norm=0.5):
        # Seeds torch's *global* RNG, so weight init and action sampling are
        # reproducible (this also affects any other torch code in the process).
        torch.manual_seed(seed)
        self.gamma = gamma
        self.use_baseline = use_baseline
        self.beta = beta  # weight of the value (baseline) loss
        self.max_grad_norm = max_grad_norm
        # Layers are created in a fixed order (backbone, policy head, value head)
        # so a given seed always yields the same initial weights.
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1) if use_baseline else None
        params = list(self.backbone.parameters()) + list(self.policy_head.parameters())
        if use_baseline:
            params += list(self.value_head.parameters())
        self.params = params
        self.optimizer = torch.optim.Adam(params, lr=lr)
        self.reset_episode()

    def reset_episode(self):
        """Clear the per-episode buffers (log-probs, baseline values, rewards)."""
        self.ep_log_probs, self.ep_values, self.ep_rewards = [], [], []

    def act(self, state):
        """Sample an action for one state, buffer its log-prob (and V(s) if the baseline
        is enabled), and return the action as a Python int."""
        s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        f = self.backbone(s)
        dist = Categorical(logits=self.policy_head(f))
        a = dist.sample()
        self.ep_log_probs.append(dist.log_prob(a).squeeze())
        if self.use_baseline:
            self.ep_values.append(self.value_head(f).squeeze())
        return a.item()

    def store_reward(self, r):
        """Buffer the reward received after the most recent action."""
        self.ep_rewards.append(float(r))

    def finish_episode(self):
        """
        Run one gradient update on the buffered episode, then clear the buffers.

        Returns:
            The scalar loss (float). An episode with no steps performs no update
            and returns 0.0.

        Raises:
            ValueError: if the number of stored rewards differs from the number
                of actions taken (the buffers are cleared before raising).
        """
        if not self.ep_log_probs and not self.ep_rewards:
            self.reset_episode()
            return 0.0
        if len(self.ep_rewards) != len(self.ep_log_probs):
            n_steps, n_rewards = len(self.ep_log_probs), len(self.ep_rewards)
            self.reset_episode()
            raise ValueError(f"episode has {n_steps} actions but {n_rewards} rewards")
        # Discounted returns, accumulated backwards in O(n).
        G = 0.0
        returns = []
        for r in reversed(self.ep_rewards):
            G = r + self.gamma * G
            returns.append(G)
        returns.reverse()
        R = torch.tensor(returns, dtype=torch.float32)
        # Standardise only with >= 2 steps: the sample std of one value is NaN.
        if len(R) > 1:
            R = (R - R.mean()) / (R.std() + 1e-8)
        lp = torch.stack(self.ep_log_probs)
        if self.use_baseline:
            V = torch.stack(self.ep_values)
            adv = R - V.detach()  # no policy gradient flows into the baseline
            loss = -(lp * adv).mean() + self.beta * nn.functional.mse_loss(V, R)
        else:
            loss = -(lp * R).mean()
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()
        self.reset_episode()
        return loss.item()
# Benchmark on simple env
class SimpleEnv:
    """
    1-D toy walk: action 1 moves ``x`` by +0.1, any other action by -0.1.

    The episode ends once |x| > 1.0. Reward is +1.0 for leaving on the right,
    -1.0 for leaving on the left, and 0.01 for every other step. Observations
    are 4-element lists ``[x, 0, 0, 0]``. ``reset()`` must be called before the
    first ``step()``, since that is where ``x`` is initialised.
    """

    def reset(self):
        self.x = 0.0
        return [self.x, 0.0, 0.0, 0.0]

    def step(self, a):
        delta = 0.1 if a == 1 else -0.1
        self.x += delta
        if self.x > 1.0:
            reward = 1.0
        elif self.x < -1.0:
            reward = -1.0
        else:
            reward = 0.01
        done = abs(self.x) > 1.0
        return [self.x, 0, 0, 0], reward, done
def _run_episode(agent, env, max_steps=50):
    """Roll out one episode of at most ``max_steps`` steps, feeding each reward
    to ``agent``; return the undiscounted episode total."""
    obs = env.reset()
    episode_return = 0.0
    for _ in range(max_steps):
        action = agent.act(obs)
        obs, reward, done = env.step(action)
        agent.store_reward(reward)
        episode_return += reward
        if done:
            break
    return episode_return


torch.manual_seed(42)
for use_bl in [False, True]:
    agent = REINFORCEAgent(4, 2, use_baseline=use_bl, seed=42)
    env = SimpleEnv()
    ep_rewards = []
    for ep in range(400):
        episode_total = _run_episode(agent, env)
        agent.finish_episode()
        ep_rewards.append(episode_total)
    label = "with baseline" if use_bl else "no baseline "
    print(f"REINFORCE {label} — last 100-ep avg: {sum(ep_rewards[-100:])/100:.4f}")