A2C Overview
A2C improves over REINFORCE in three ways:
- Critic (value baseline) — V(s) reduces variance without introducing bias
- N-step returns — update every N steps instead of waiting for the episode to end (so it also works for very long or continuing, never-terminating tasks)
- GAE — interpolates between high-bias 1-step TD and high-variance Monte Carlo returns
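The A2C objective (to be minimized; in code the sums are replaced by per-batch means and the squared value error by a Huber loss):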
$$\mathcal{L} = \underbrace{-\sum_t \log \pi_\theta(a_t|s_t)\, \hat{A}_t}_{\text{policy loss}} + \underbrace{c_1 \sum_t \big(V_\phi(s_t) - G_t\big)^2}_{\text{value loss}} - \underbrace{c_2 \sum_t H\big[\pi_\theta(\cdot|s_t)\big]}_{\text{entropy bonus}}$$
Actor-Critic Network
import torch
import torch.nn as nn
from torch.distributions import Categorical
class ActorCriticNet(nn.Module):
    """
    Actor-critic network whose policy and value heads sit on one shared trunk.

    A single two-layer Tanh MLP feeds both heads, so the same features must
    serve action selection and value prediction, and only one trunk is stored.

    Args:
        state_dim: size of the flat observation vector.
        n_actions: number of discrete actions (width of the logit head).
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)
        # Policy head -> unnormalized action logits.
        self.actor_head = nn.Linear(hidden_dim, n_actions)
        # Value head -> one scalar V(s) per state.
        self.critic_head = nn.Linear(hidden_dim, 1)
        # Orthogonal init of the head weights: a tiny gain (0.01) on the actor
        # keeps the initial logits near zero (near-uniform policy); the critic
        # gets a unit gain. Biases keep PyTorch's default init.
        for head, gain in ((self.actor_head, 0.01), (self.critic_head, 1.0)):
            nn.init.orthogonal_(head.weight, gain=gain)

    def forward(self, x):
        """Map states to (Categorical over actions, value tensor with last dim squeezed)."""
        shared = self.backbone(x)
        logits = self.actor_head(shared)
        policy = Categorical(logits=logits)
        state_value = self.critic_head(shared).squeeze(-1)
        return policy, state_value

    def act(self, state):
        """Sample one action for an unbatched state; return (int action, log-prob, value)."""
        batch = torch.as_tensor(state, dtype=torch.float32)[None]
        policy, value = self(batch)
        sampled = policy.sample()
        log_prob = policy.log_prob(sampled).squeeze()
        return sampled.item(), log_prob, value.squeeze()

    def evaluate(self, states):
        """Batched forward pass used at update time; same result as calling the module."""
        return self(states)
# --- Smoke test: LunarLander-sized inputs (8-dim state, 4 discrete actions) ---
torch.manual_seed(42)
net = ActorCriticNet(state_dim=8, n_actions=4)
total_params = sum(param.numel() for param in net.parameters())
print(f"Total parameters: {total_params:,}")

# One unbatched state through act(): sampled action, its log-prob and V(s).
s = torch.randn(8).numpy()
action, log_prob, value = net.act(s)
print(f"Action: {action}, log π(a|s): {log_prob.item():.4f}, V(s): {value.item():.4f}")

# A batch of 32 random states through forward().
batch_states = torch.randn(32, 8)
dist_batch, values_batch = net(batch_states)
mean_entropy = dist_batch.entropy().mean()
print(f"Batch: dist entropy mean = {mean_entropy:.4f}, values shape = {values_batch.shape}")
Generalized Advantage Estimation (GAE)
import torch
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation (Schulman et al., 2015).

        A_t     = sum_{l=0}^{T-t-1} (gamma * lam)^l * delta_{t+l}
        delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)

    The (gamma * lam) recursion is also cut at every done flag, so advantages
    never leak across episode boundaries.

    lam=1.0  -> Monte Carlo-style returns (high variance, low bias)
    lam=0.0  -> 1-step TD (low variance, high bias)
    lam=0.95 -> common empirical trade-off (used in the PPO paper)

    Args:
        rewards: length-T sequence (tensor or list) of rewards.
        values: length-T sequence (tensor or list) of V(s_t) estimates.
        next_value: V(s_T) used to bootstrap past the last step; a
            one-element tensor or a plain Python number.
        dones: length-T sequence of 0/1 terminal flags.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (advantages, returns): advantages is a float32 tensor of length T;
        returns = advantages + values is the critic's regression target and
        is returned detached from any autograd graph that `values` carries.
    """
    T = len(rewards)
    values_t = torch.as_tensor(values)
    # Work with plain floats in the recursion; V(s_T) is appended for the
    # bootstrap. float() accepts Python numbers and one-element tensors alike.
    all_values = values_t.detach().reshape(-1).tolist() + [float(next_value)]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - float(dones[t])
        delta = float(rewards[t]) + gamma * all_values[t + 1] * not_done - all_values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    # V(s) + A(s, a) ~= Q(s, a); detached because it is a target, not a prediction.
    returns = advantages + values_t.detach()
    return advantages, returns
# Demo: how lambda trades bias against variance on one synthetic trajectory.
torch.manual_seed(42)
T = 20
rewards = torch.rand(T) * 2 - 1          # uniform in [-1, 1)
values = torch.rand(T) * 2               # fake critic outputs in [0, 2)
dones = (torch.rand(T) < 0.1).float()    # roughly 10% of steps are terminal
next_value = torch.tensor(0.5)           # bootstrap value for the state after step T-1

print(f"{'t':>3} | {'reward':>8} | {'A_lam=0.0':>10} | {'A_lam=0.95':>11} | {'A_lam=1.0':>10}")
adv_td, adv_gae, adv_mc = (
    compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=lam_value)[0]
    for lam_value in (0.0, 0.95, 1.0)
)
for t in range(min(8, T)):
    row = (rewards[t].item(), adv_td[t].item(), adv_gae[t].item(), adv_mc[t].item())
    print(f"{t:>3} | {row[0]:>8.4f} | {row[1]:>10.4f} | {row[2]:>11.4f} | {row[3]:>10.4f}")

print(f"\nStd (lower = more stable gradient):")
for label, adv in ((" lam=0.0 (TD): ", adv_td), (" lam=0.95 (GAE): ", adv_gae), (" lam=1.0 (MC): ", adv_mc)):
    print(f"{label}{adv.std().item():.4f}")
A2C Loss Function
Entropy Bonus
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
def a2c_loss(log_probs, entropies, advantages, values, returns,
             value_coeff=0.5, entropy_coeff=0.01, normalize_advantages=True):
    """
    Combined A2C loss (to be minimized).

    policy_loss: -mean(log_prob * advantage); minimizing it performs gradient
        ascent on E[log pi(a|s) * A].
    value_loss: Huber (smooth-L1) loss between V(s) and the return targets,
        which is less sensitive to outlier returns than plain MSE.
    entropy term: -mean(H[pi]); with a positive coefficient, minimizing it
        maximizes entropy, encouraging exploration and discouraging premature
        convergence. entropy_coeff=0.01 is the usual starting point (A3C paper).

    Args:
        log_probs: log pi(a_t|s_t) of the taken actions (carries grad).
        entropies: per-step policy entropies (carries grad).
        advantages: advantage estimates; detached and used only as weights.
        values: critic predictions V(s_t) (carries grad).
        returns: critic regression targets; detached.
        value_coeff: weight c1 of the value loss.
        entropy_coeff: weight c2 of the entropy bonus.
        normalize_advantages: standardize advantages per batch. Skipped for
            batches with fewer than two elements, whose sample standard
            deviation is undefined (NaN) and would poison the loss.

    Returns:
        (total_loss tensor, dict of Python floats with keys
        'policy', 'value', 'entropy').
    """
    adv = advantages.detach()
    if normalize_advantages and adv.numel() > 1:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    policy_loss = -(log_probs * adv).mean()
    value_loss = F.huber_loss(values, returns.detach())
    mean_entropy = entropies.mean()
    entropy_loss = -mean_entropy  # negated: minimizing the total maximizes entropy
    total = policy_loss + value_coeff * value_loss + entropy_coeff * entropy_loss
    return total, {
        'policy': policy_loss.item(),
        'value': value_loss.item(),
        'entropy': mean_entropy.item(),
    }
# Demo: show entropy effect
torch.manual_seed(42)
T = 32
# High-entropy policy: all-zero logits give an exactly uniform distribution.
logits_uniform = torch.zeros(T, 4)
dist_uniform = Categorical(logits=logits_uniform)
entropy_uniform = dist_uniform.entropy().mean().item()
# Low-entropy policy: one logit far above the others.
logits_peaked = torch.zeros(T, 4)
logits_peaked[:, 0] = 5.0
dist_peaked = Categorical(logits=logits_peaked)
entropy_peaked = dist_peaked.entropy().mean().item()
print(f"Uniform policy entropy: {entropy_uniform:.4f} (max = {torch.log(torch.tensor(4.0)).item():.4f})")
print(f"Peaked policy entropy: {entropy_peaked:.4f} (min = 0.0)")
print("Entropy bonus encourages staying closer to uniform distribution")

# Synthetic A2C loss demo: random advantages/targets, and log-probs of
# actions actually sampled from the uniform policy (previously these were
# computed and then discarded in favour of a hard-coded action 0).
advantages = torch.randn(T)
returns = torch.rand(T)
values = torch.rand(T)
log_probs = dist_uniform.log_prob(dist_uniform.sample())
entropies = dist_uniform.entropy()
total, breakdown = a2c_loss(
    log_probs=log_probs,
    entropies=entropies,
    advantages=advantages,
    values=values,
    returns=returns,
)
print(f"\nA2C loss breakdown: {breakdown}")
N-Step Rollout Collection
import torch
import torch.nn as nn
from torch.distributions import Categorical
class ActorCriticNet(nn.Module):
    """Compact shared-trunk actor-critic (PyTorch default initialization).

    Args:
        s: observation size.
        a: number of discrete actions.
        h: hidden width of the two Tanh layers.
    """

    def __init__(self, s, a, h=256):
        super().__init__()
        layers = (nn.Linear(s, h), nn.Tanh(), nn.Linear(h, h), nn.Tanh())
        self.backbone = nn.Sequential(*layers)
        self.actor_head = nn.Linear(h, a)    # action logits
        self.critic_head = nn.Linear(h, 1)   # scalar V(s)

    def forward(self, x):
        """Return (Categorical over actions, values with the last dim squeezed)."""
        feats = self.backbone(x)
        policy = Categorical(logits=self.actor_head(feats))
        return policy, self.critic_head(feats).squeeze(-1)

    def act(self, s):
        """Sample for one unbatched state: (int action, log-prob, entropy, value)."""
        obs = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
        policy, value = self(obs)
        choice = policy.sample()
        return (
            choice.item(),
            policy.log_prob(choice).squeeze(),
            policy.entropy().squeeze(),
            value.squeeze(),
        )
def collect_rollout(env, net, n_steps=128, gamma=0.99, lam=0.95):
    """
    Collect n_steps transitions from a freshly reset env and compute GAE advantages.

    The environment is reset at the start of the call and again whenever an
    episode ends inside the rollout, so exactly n_steps transitions are
    gathered. Expected interfaces: `env.step(a)` returns
    (next_state, reward, done); `net.act(s)` returns
    (action, log_prob, entropy, value); `net(x)` on a batched input returns
    (dist, value).

    Returns:
        (states, log_probs, entropies, values, advantages, returns), each
        stacked along a leading n_steps dimension. log_probs, entropies and
        values keep their autograd history (they feed the loss); advantages
        and returns are targets with no gradient.
    """
    states, log_probs, entropies = [], [], []
    values, rewards, dones = [], [], []
    s = env.reset()
    for _ in range(n_steps):
        a, lp, ent, v = net.act(s)
        ns, r, done = env.step(a)
        states.append(torch.as_tensor(s, dtype=torch.float32))
        log_probs.append(lp)
        entropies.append(ent)
        values.append(v)
        rewards.append(float(r))
        dones.append(float(done))
        # Start the next episode immediately so the rollout stays n_steps long.
        s = env.reset() if done else ns

    # Bootstrap V of the state after the last step; the (1 - done) mask
    # below zeroes it when that last step was terminal.
    with torch.no_grad():
        _, next_v = net(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))

    values_t = torch.stack(values)
    all_v = values_t.detach().tolist() + [next_v.squeeze().item()]
    advantages = torch.zeros(n_steps)
    gae = 0.0
    for t in reversed(range(n_steps)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * all_v[t + 1] * not_done - all_v[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    # Returns are the critic's regression target: detach so a caller that
    # regresses `values` onto them does not backpropagate through the target.
    returns = advantages + values_t.detach()
    return (
        torch.stack(states),
        torch.stack(log_probs),
        torch.stack(entropies),
        values_t,
        advantages,
        returns,
    )
# Demo: collect rollout from simple env
class SimpleEnv:
    """1-D walk toy env with an 8-entry observation whose first entry is position x.

    Odd actions move x by +0.1, even actions by -0.1. The episode ends once
    |x| > 1.0; the reward is +1 past the right edge, -1 past the left edge,
    and 0.01 on every other step.
    """

    def reset(self):
        self.x = 0.0
        return [self.x] + [0.0] * 7

    def step(self, a):
        direction = 1 if a % 2 == 1 else -1
        self.x += 0.1 * direction
        done = abs(self.x) > 1.0
        if self.x > 1:
            reward = 1.0
        elif self.x < -1:
            reward = -1.0
        else:
            reward = 0.01
        return [self.x] + [0] * 7, reward, done
# Roll out 64 steps with an untrained network and summarize the results.
torch.manual_seed(42)
env = SimpleEnv()
net = ActorCriticNet(8, 4)
states, lps, ents, vals, advs, rets = collect_rollout(env, net, n_steps=64)
print(f"Rollout shapes: states={states.shape}, advantages={advs.shape}, returns={rets.shape}")
for label, series in (("Advantage", advs), ("Value", vals)):
    print(f"{label} stats: mean={series.mean():.4f}, std={series.std():.4f}")
Full A2C Agent
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
class ActorCriticNet(nn.Module):
    """Shared-trunk actor-critic used by A2CAgent.

    Two Tanh hidden layers of width h feed a logit head (`actor`) and a scalar
    value head (`critic`). Head weights are orthogonally initialized with gain
    0.01 on the actor (near-uniform initial policy) and 1.0 on the critic.

    Args:
        s: observation size.
        a: number of discrete actions.
        h: hidden width.
    """

    def __init__(self, s, a, h=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(s, h), nn.Tanh(),
            nn.Linear(h, h), nn.Tanh(),
        )
        self.actor = nn.Linear(h, a)
        self.critic = nn.Linear(h, 1)
        for layer, gain in ((self.actor, 0.01), (self.critic, 1.0)):
            nn.init.orthogonal_(layer.weight, gain)

    def forward(self, x):
        """Return (Categorical over actions, state values with the last dim squeezed)."""
        hidden = self.backbone(x)
        action_dist = Categorical(logits=self.actor(hidden))
        state_values = self.critic(hidden).squeeze(-1)
        return action_dist, state_values
class A2CAgent:
    """
    Synchronous Advantage Actor-Critic agent for a single environment.

    Each call to `collect_and_update` gathers `n_steps` transitions, computes
    GAE(gamma, lam) advantages, and takes one RMSprop step (gradient norm
    clipped to 0.5) on policy + value_coeff * value - entropy_coeff * entropy.

    The environment is NOT reset at the start of every update: the agent
    keeps the last observation and continues the in-progress episode, so
    episodes longer than `n_steps` can still finish and the bootstrap value
    covers the truncated tail. The env is reset on the first call, when a
    different env object is passed, and after every terminal step.
    """

    def __init__(self, state_dim, n_actions, lr=7e-4, gamma=0.99, lam=0.95,
                 n_steps=128, value_coeff=0.5, entropy_coeff=0.01, seed=42):
        torch.manual_seed(seed)  # NOTE: seeds torch's global RNG
        self.gamma, self.lam = gamma, lam
        self.n_steps = n_steps
        self.vc, self.ec = value_coeff, entropy_coeff
        self.net = ActorCriticNet(state_dim, n_actions)
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=lr, alpha=0.99, eps=1e-5)
        # Rollout state carried across updates (see class docstring).
        self._env = None
        self._obs = None

    def collect_and_update(self, env):
        """
        Collect one n-step rollout from `env` and apply one gradient update.

        `env.reset()` must return an observation and `env.step(a)` must
        return (next_obs, reward, done).

        Returns:
            (total_loss, mean_entropy) as Python floats.
        """
        log_probs, entropies, values, rewards, dones, next_v = self._rollout(env)
        advs = self._gae(values, rewards, dones, next_v)
        return self._update(log_probs, entropies, values, advs)

    def _rollout(self, env):
        """Step `env` n_steps times with the current policy, keeping autograd history."""
        if getattr(self, "_env", None) is not env or getattr(self, "_obs", None) is None:
            self._env, self._obs = env, env.reset()
        s = self._obs
        log_probs, entropies, values, rewards, dones = [], [], [], [], []
        for _ in range(self.n_steps):
            dist, v = self.net(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))
            a = dist.sample()
            ns, r, done = env.step(a.item())
            log_probs.append(dist.log_prob(a).squeeze())
            entropies.append(dist.entropy().squeeze())
            values.append(v.squeeze())
            rewards.append(float(r))
            dones.append(float(done))
            s = env.reset() if done else ns
        self._obs = s  # resume from here on the next call
        # Bootstrap value of the state after the last step (masked by done in _gae).
        with torch.no_grad():
            _, next_v = self.net(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))
        return log_probs, entropies, values, rewards, dones, next_v.item()

    def _gae(self, values, rewards, dones, next_v):
        """Return a length-n_steps tensor of GAE(gamma, lam) advantages (no grad)."""
        all_v = [v.item() for v in values] + [next_v]
        advs = torch.zeros(self.n_steps)
        gae = 0.0
        for t in reversed(range(self.n_steps)):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * all_v[t + 1] * not_done - all_v[t]
            gae = delta + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        return advs

    def _update(self, log_probs, entropies, values, advs):
        """Take one clipped RMSprop step on the A2C loss; return (loss, mean entropy)."""
        vals_t = torch.stack(values)
        rets = advs + vals_t.detach()  # critic targets carry no gradient
        lps_t = torch.stack(log_probs)
        ents_t = torch.stack(entropies)
        # Standardize advantages; the sample std is NaN for a single element.
        norm_adv = (advs - advs.mean()) / (advs.std() + 1e-8) if advs.numel() > 1 else advs
        p_loss = -(lps_t * norm_adv).mean()
        v_loss = F.huber_loss(vals_t, rets)
        e_loss = -ents_t.mean()
        total = p_loss + self.vc * v_loss + self.ec * e_loss
        self.optimizer.zero_grad()
        total.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
        self.optimizer.step()
        return total.item(), ents_t.mean().item()
# Training benchmark on simple env
class SimpleEnv:
    """Two-action 1-D walk with a 4-entry observation whose first entry is position x.

    Action 1 moves x by +0.1; any other action moves it by -0.1. The episode
    ends once |x| > 1.0, with reward +1 past the right edge, -1 past the left
    edge, and 0.01 otherwise.
    """

    def reset(self):
        self.x = 0.0
        return [self.x, 0, 0, 0]

    def step(self, a):
        self.x += 0.1 if a == 1 else -0.1
        x = self.x
        reward = 1.0 if x > 1 else (-1.0 if x < -1 else 0.01)
        return [x, 0, 0, 0], reward, abs(x) > 1.0
# Train for 100 updates, reporting every 20th, then summarize the tail.
torch.manual_seed(42)
env = SimpleEnv()
agent = A2CAgent(4, 2, n_steps=64, seed=42)
losses, entropies = [], []
for update in range(100):
    loss, entropy = agent.collect_and_update(env)
    losses.append(loss)
    entropies.append(entropy)
    step_no = update + 1
    if step_no % 20 == 0:
        print(f"Update {step_no:3d}: loss={loss:.4f}, entropy={entropy:.4f}")
tail = losses[-20:]
print(f"\nFinal 20-update avg loss: {sum(tail)/20:.4f}")
print(f"Final entropy (exploration): {entropies[-1]:.4f}")