Back to PyTorch Mastery Series

Actor-Critic Methods (A2C)

May 29, 2026 Wasil Zafar 38 min read

Advantage Actor-Critic (A2C) combines a learned policy (actor) with a learned value function (critic) to reduce the variance of policy gradient updates. Implement shared backbone networks, Generalized Advantage Estimation (GAE), entropy bonuses, and synchronized multi-step rollouts.

Table of Contents

  1. A2C Overview
  2. Actor-Critic Network
  3. Generalized Advantage Estimation
  4. A2C Loss Function
  5. N-Step Rollout Collection
  6. Full A2C Agent
  7. Related Articles

A2C Overview

A2C improves over REINFORCE in three ways:

  1. Critic (value baseline) — V(s) reduces variance without introducing bias
  2. N-step returns — update every N steps instead of waiting for the episode to end (so it also works in continuing, non-episodic environments)
  3. GAE — interpolates between high-bias 1-step TD and high-variance Monte Carlo returns

The A2C objective:

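$$\mathcal{L} = \underbrace{-\sum_t \log \pi_\theta(a_t|s_t)\, \hat{A}_t}_{\text{policy loss}} + \underbrace{c_1 \sum_t \big(V_\phi(s_t) - G_t\big)^2}_{\text{value loss}} - \underbrace{c_2 \sum_t H[\pi_\theta(\cdot|s_t)]}_{\text{entropy bonus}}$$

Here $\hat{A}_t$ is the advantage estimate, $G_t$ is the return target for the critic, and $c_1$, $c_2$ weight the value loss and the entropy bonus. The code below averages over the batch instead of summing, which only rescales the loss.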

Actor-Critic Network

import torch
import torch.nn as nn
from torch.distributions import Categorical


class ActorCriticNet(nn.Module):
    """
    Actor-critic model with one trunk feeding two linear heads.

    The trunk is a two-layer tanh MLP. The actor head maps its features to
    action logits and the critic head maps the same features to a scalar
    state value, so both objectives train the shared representation.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()

        # Shared trunk (kept as indices 0..3 of the Sequential)
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)

        # Policy head (logits) and value head (scalar V(s))
        self.actor_head = nn.Linear(hidden_dim, n_actions)
        self.critic_head = nn.Linear(hidden_dim, 1)

        # Orthogonal init: gain 0.01 keeps the initial policy near-uniform,
        # gain 1.0 leaves the value head at unit scale.
        for head, gain in ((self.actor_head, 0.01), (self.critic_head, 1.0)):
            nn.init.orthogonal_(head.weight, gain=gain)

    def forward(self, x):
        """Map states to (Categorical policy, value tensor without the last dim)."""
        hidden = self.backbone(x)
        logits = self.actor_head(hidden)
        state_value = self.critic_head(hidden).squeeze(-1)
        return Categorical(logits=logits), state_value

    def act(self, state):
        """Sample one action for a single state; return (action, log_prob, value)."""
        obs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        policy, state_value = self(obs)
        chosen = policy.sample()
        log_prob = policy.log_prob(chosen).squeeze()
        return chosen.item(), log_prob, state_value.squeeze()

    def evaluate(self, states):
        """Run the network on a batch of states (used for the gradient update)."""
        return self(states)


torch.manual_seed(42)
# LunarLander-sized problem: 8 observation features, 4 discrete actions
net = ActorCriticNet(state_dim=8, n_actions=4)
total_params = sum(param.numel() for param in net.parameters())
print(f"Total parameters: {total_params:,}")

# Single-step check: act() takes one unbatched observation
s = torch.randn(8).numpy()
action, log_prob, value = net.act(s)
print(
    f"Action: {action}, log π(a|s): {log_prob.item():.4f}, V(s): {value.item():.4f}"
)

# Batched check: calling the network on a (32, 8) tensor
batch_states = torch.randn(32, 8)
dist_batch, values_batch = net(batch_states)
print(
    f"Batch: dist entropy mean = {dist_batch.entropy().mean():.4f}, "
    f"values shape = {values_batch.shape}"
)

Generalized Advantage Estimation (GAE)

import torch


def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation (Schulman et al., 2015).

        A_t     = sum_{l=0}^{T-1-t} (gamma * lam)^l * delta_{t+l}
        delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)

    lam=1.0  -> Monte Carlo (high variance, low bias)
    lam=0.0  -> 1-step TD (low variance, high bias)
    lam=0.95 -> good empirical tradeoff (PPO paper)

    Args:
        rewards: length-T sequence or 1-D tensor of rewards r_t.
        values: length-T 1-D tensor (or sequence of floats) of V(s_t).
        next_value: bootstrap estimate V(s_T) for the state after the last
            step; a Python float or a single-element tensor.
        dones: length-T sequence or 1-D tensor of termination flags
            (bool or 0/1). A set flag cuts both the bootstrap and the GAE
            accumulation at that step.
        gamma: discount factor.
        lam: GAE smoothing parameter.

    Returns:
        (advantages, returns): two length-T tensors with the dtype and device
        of ``values``. ``returns = advantages + values`` is the critic
        target; both outputs are detached from the autograd graph.

    Raises:
        ValueError: if ``rewards``, ``values`` and ``dones`` differ in length.
    """
    num_steps = len(rewards)
    if len(values) != num_steps or len(dones) != num_steps:
        raise ValueError(
            "rewards, values and dones must have the same length, got "
            f"{num_steps}, {len(values)} and {len(dones)}"
        )

    # Value estimates act as constants here: targets must not carry gradients.
    value_estimates = torch.as_tensor(values).detach()
    if not value_estimates.is_floating_point():
        value_estimates = value_estimates.float()

    # float() accepts both plain numbers and single-element tensors.
    all_values = value_estimates.tolist() + [float(next_value)]

    advantages = torch.zeros(
        num_steps, dtype=value_estimates.dtype, device=value_estimates.device
    )
    gae = 0.0

    for t in reversed(range(num_steps)):
        # float() lets bool flags work (bool tensors do not support `1 - x`).
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * all_values[t + 1] * not_done - all_values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae

    returns = advantages + value_estimates  # V(s) + A(s,a) = Q(s,a) estimate
    return advantages, returns


# Demo: run GAE with three lambda settings on the same synthetic trajectory
torch.manual_seed(42)
T = 20
rewards = torch.rand(T) * 2 - 1        # rewards in [-1, 1]
values = torch.rand(T) * 2             # estimated values
dones = (torch.rand(T) < 0.1).float()  # 10% terminal
next_value = torch.tensor(0.5)

print(f"{'t':>3} | {'reward':>8} | {'A_lam=0.0':>10} | {'A_lam=0.95':>11} | {'A_lam=1.0':>10}")

# Only the advantages (first element of the returned pair) are needed here
adv_td, adv_gae, adv_mc = (
    compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=lam_value)[0]
    for lam_value in (0.0, 0.95, 1.0)
)

for t in range(min(8, T)):
    row = (rewards[t].item(), adv_td[t].item(), adv_gae[t].item(), adv_mc[t].item())
    print(f"{t:>3} | {row[0]:>8.4f} | {row[1]:>10.4f} | {row[2]:>11.4f} | {row[3]:>10.4f}")

print("\nStd (lower = more stable gradient):")
print(f"  lam=0.0  (TD):  {adv_td.std().item():.4f}")
print(f"  lam=0.95 (GAE): {adv_gae.std().item():.4f}")
print(f"  lam=1.0  (MC):  {adv_mc.std().item():.4f}")

A2C Loss Function

Entropy Bonus

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


def a2c_loss(log_probs, entropies, advantages, values, returns,
             value_coeff=0.5, entropy_coeff=0.01):
    """
    A2C combined loss.

    policy_loss: -mean(log_prob * advantage); minimizing it raises the
        probability of actions with positive advantage.
    value_loss:  Huber loss between V(s) and the return targets (behaves like
        MSE for small errors, more robust to outliers).
    entropy_bonus: maximize H[pi] to encourage exploration and prevent
        premature convergence.

    entropy_coeff=0.01 is the standard starting point from the A3C paper.

    Args:
        log_probs: log pi(a_t|s_t) of the taken actions (carries gradients).
        entropies: per-step policy entropies (carries gradients).
        advantages: advantage estimates; detached and treated as constants.
        values: critic predictions V(s_t) (carries gradients).
        returns: critic targets; detached and treated as constants.
        value_coeff: weight of the value loss.
        entropy_coeff: weight of the entropy bonus.

    Returns:
        (total, breakdown): the scalar loss tensor to backpropagate, and a
        dict of Python floats with keys 'policy', 'value' and 'entropy'.
    """
    adv = advantages.detach()
    # Per-batch normalization. The unbiased std of a single sample is NaN and
    # would poison the whole loss, so batches of fewer than 2 stay unnormalized.
    if adv.numel() > 1:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    policy_loss = -(log_probs * adv).mean()
    value_loss = F.huber_loss(values, returns.detach())
    mean_entropy = entropies.mean()
    entropy_loss = -mean_entropy  # Negative: we want to maximize entropy

    total = policy_loss + value_coeff * value_loss + entropy_coeff * entropy_loss
    return total, {
        'policy': policy_loss.item(),
        'value': value_loss.item(),
        'entropy': mean_entropy.item(),
    }


# Demo: show entropy effect
torch.manual_seed(42)
T = 32

# High-entropy policy: all-zero logits give an exactly uniform distribution
logits_uniform = torch.zeros(T, 4)
dist_uniform = Categorical(logits=logits_uniform)
entropy_uniform = dist_uniform.entropy().mean().item()

# Low-entropy policy (peaked on action 0)
logits_peaked = torch.zeros(T, 4)
logits_peaked[:, 0] = 5.0
dist_peaked = Categorical(logits=logits_peaked)
entropy_peaked = dist_peaked.entropy().mean().item()

print(f"Uniform policy entropy: {entropy_uniform:.4f}  (max = {torch.log(torch.tensor(4.0)).item():.4f})")
print(f"Peaked policy entropy:  {entropy_peaked:.4f}  (min = 0.0)")
print("Entropy bonus encourages staying closer to uniform distribution")

# Synthetic A2C loss demo
advantages = torch.randn(T)
returns = torch.rand(T)
values = torch.rand(T)
log_probs = dist_uniform.log_prob(dist_uniform.sample())
entropies = dist_uniform.entropy()

# Pass the log-probs of the sampled actions; previously they were computed
# but unused, and the call recomputed log-probs for a hard-coded action 0.
total, breakdown = a2c_loss(
    log_probs=log_probs,
    entropies=entropies,
    advantages=advantages,
    values=values,
    returns=returns,
)
print(f"\nA2C loss breakdown: {breakdown}")

N-Step Rollout Collection

import torch
import torch.nn as nn
from torch.distributions import Categorical


class ActorCriticNet(nn.Module):
    """Compact shared-trunk actor-critic; `act` also reports the policy entropy."""

    def __init__(self, s, a, h=256):
        super().__init__()
        # s = state size, a = number of actions, h = hidden width
        self.backbone = nn.Sequential(
            nn.Linear(s, h),
            nn.Tanh(),
            nn.Linear(h, h),
            nn.Tanh(),
        )
        self.actor_head = nn.Linear(h, a)
        self.critic_head = nn.Linear(h, 1)

    def forward(self, x):
        """Return (Categorical over actions, value with the last dim squeezed)."""
        features = self.backbone(x)
        logits = self.actor_head(features)
        state_value = self.critic_head(features).squeeze(-1)
        return Categorical(logits=logits), state_value

    def act(self, s):
        """Sample for one state; return (action, log_prob, entropy, value)."""
        obs = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
        policy, state_value = self(obs)
        chosen = policy.sample()
        return (
            chosen.item(),
            policy.log_prob(chosen).squeeze(),
            policy.entropy().squeeze(),
            state_value.squeeze(),
        )


def collect_rollout(env, net, n_steps=128, gamma=0.99, lam=0.95):
    """
    Collect n_steps of transitions and compute GAE advantages.

    The environment is reset at the start of the call and again whenever an
    episode ends inside the rollout. If the rollout stops mid-episode, the
    critic's value of the last state is used as the bootstrap.

    Args:
        env: object with ``reset() -> state`` and
            ``step(action) -> (next_state, reward, done)``.
        net: callable returning ``(dist, value)`` for a batched state, with an
            ``act(state)`` method returning
            ``(action, log_prob, entropy, value)``.
        n_steps: number of transitions to collect (must be >= 1).
        gamma: discount factor.
        lam: GAE smoothing parameter.

    Returns:
        Tuple ``(states, log_probs, entropies, values, advantages, returns)``
        of stacked tensors. ``log_probs``, ``entropies`` and ``values`` keep
        whatever autograd graph ``net.act`` produced; ``advantages`` and
        ``returns`` are constants. ``returns`` is detached on purpose: if the
        target carried the critic's graph, a value loss on
        ``(values, returns)`` would cancel its own gradient.

    Raises:
        ValueError: if ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    states, log_probs, entropies = [], [], []
    values, rewards, dones = [], [], []

    state = env.reset()
    for _ in range(n_steps):
        action, log_prob, entropy, value = net.act(state)
        next_state, reward, done = env.step(action)

        states.append(torch.as_tensor(state, dtype=torch.float32))
        log_probs.append(log_prob)
        entropies.append(entropy)
        values.append(value)
        rewards.append(float(reward))
        dones.append(float(done))

        # Start a fresh episode as soon as the current one terminates
        state = env.reset() if done else next_state

    # Bootstrap value from the state the rollout stopped in
    with torch.no_grad():
        _, next_v = net(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))
    bootstrap_value = next_v.squeeze().item()

    values_t = torch.stack(values)
    value_targets = values_t.detach()

    # GAE, accumulated backwards in plain Python floats
    all_v = value_targets.tolist() + [bootstrap_value]
    advantages = torch.zeros(n_steps)
    gae = 0.0
    for t in reversed(range(n_steps)):
        not_done = 1 - dones[t]
        delta = rewards[t] + gamma * all_v[t + 1] * not_done - all_v[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae

    returns = advantages + value_targets

    return (
        torch.stack(states),
        torch.stack(log_probs),
        torch.stack(entropies),
        values_t,
        advantages,
        returns,
    )


# Demo: collect rollout from simple env
class SimpleEnv:
    """1-D walk exposed as an 8-element observation (position plus 7 zeros)."""

    def reset(self):
        """Put the walker back at x = 0 and return the initial observation."""
        self.x = 0.0
        return [self.x] + [0.0] * 7

    def step(self, a):
        """Move 0.1 right for odd actions, 0.1 left otherwise.

        Returns (observation, reward, done). The episode ends once |x| > 1;
        the reward is +1 past the right edge, -1 past the left edge and 0.01
        on every other step.
        """
        direction = 1 if a % 2 == 1 else -1
        self.x += 0.1 * direction

        if self.x > 1:
            reward = 1.0
        elif self.x < -1:
            reward = -1.0
        else:
            reward = 0.01

        observation = [self.x, *([0] * 7)]
        return observation, reward, abs(self.x) > 1.0

# Seed first so network initialization and action sampling are reproducible
torch.manual_seed(42)
env = SimpleEnv()
net = ActorCriticNet(8, 4)

# One 64-step rollout; unpack the six stacked tensors it returns
states, lps, ents, vals, advs, rets = collect_rollout(env, net, n_steps=64)

print(
    f"Rollout shapes: states={states.shape}, "
    f"advantages={advs.shape}, returns={rets.shape}"
)
print(f"Advantage stats: mean={advs.mean():.4f}, std={advs.std():.4f}")
print(f"Value stats:     mean={vals.mean():.4f}, std={vals.std():.4f}")

Full A2C Agent

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ActorCriticNet(nn.Module):
    """Shared-trunk actor-critic used by the agent (hidden width 128 by default)."""

    def __init__(self, s, a, h=128):
        super().__init__()
        # s = state size, a = number of actions, h = hidden width
        self.backbone = nn.Sequential(
            nn.Linear(s, h),
            nn.Tanh(),
            nn.Linear(h, h),
            nn.Tanh(),
        )
        self.actor = nn.Linear(h, a)
        self.critic = nn.Linear(h, 1)
        # Orthogonal init: tiny gain for the policy logits, unit gain for V(s)
        for head, gain in ((self.actor, 0.01), (self.critic, 1.0)):
            nn.init.orthogonal_(head.weight, gain)

    def forward(self, x):
        """Return (Categorical over actions, value with the last dim squeezed)."""
        features = self.backbone(x)
        policy = Categorical(logits=self.actor(features))
        state_value = self.critic(features).squeeze(-1)
        return policy, state_value


class A2CAgent:
    """
    Full Advantage Actor-Critic agent.

    Owns an ActorCriticNet and an RMSprop optimizer. Each call to
    `collect_and_update` gathers `n_steps` transitions, computes GAE
    advantages, and applies one gradient step on the combined A2C loss.

    The environment must provide ``reset() -> state`` and
    ``step(action) -> (next_state, reward, done)``.
    """

    def __init__(self, state_dim, n_actions, lr=7e-4, gamma=0.99, lam=0.95,
                 n_steps=128, value_coeff=0.5, entropy_coeff=0.01, seed=42):
        """
        Args:
            state_dim: size of the state vector.
            n_actions: number of discrete actions.
            lr: RMSprop learning rate.
            gamma: discount factor.
            lam: GAE smoothing parameter.
            n_steps: transitions collected per update (must be >= 1).
            value_coeff: weight of the value loss.
            entropy_coeff: weight of the entropy bonus.
            seed: passed to torch.manual_seed (seeds the global torch RNG).

        Raises:
            ValueError: if ``n_steps`` is smaller than 1.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        torch.manual_seed(seed)
        self.gamma, self.lam = gamma, lam
        self.n_steps = n_steps
        self.vc, self.ec = value_coeff, entropy_coeff
        self.net = ActorCriticNet(state_dim, n_actions)
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=lr, alpha=0.99, eps=1e-5)

    def _compute_gae(self, rewards, value_list, dones):
        """Return GAE advantages; `value_list` holds V(s_0..s_T) as floats."""
        steps = len(rewards)
        advs = torch.zeros(steps)
        gae = 0.0
        for t in reversed(range(steps)):
            not_done = 1 - dones[t]
            delta = rewards[t] + self.gamma * value_list[t + 1] * not_done - value_list[t]
            gae = delta + self.gamma * self.lam * not_done * gae
            advs[t] = gae
        return advs

    def collect_and_update(self, env):
        """
        Collect one rollout from `env` and take one optimizer step.

        The environment is reset at the start of every call and whenever an
        episode ends inside the rollout; a rollout that stops mid-episode
        bootstraps from the critic's value of the last state.

        Returns:
            (loss, mean_entropy) as Python floats.
        """
        log_probs, entropies, values, rewards, dones = [], [], [], [], []
        state = env.reset()
        for _ in range(self.n_steps):
            obs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            dist, value = self.net(obs)
            action = dist.sample()
            next_state, reward, done = env.step(action.item())
            log_probs.append(dist.log_prob(action).squeeze())
            entropies.append(dist.entropy().squeeze())
            values.append(value.squeeze())
            rewards.append(float(reward))
            dones.append(float(done))
            state = env.reset() if done else next_state

        # Bootstrap value of the state the rollout stopped in
        with torch.no_grad():
            _, next_v = self.net(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))

        all_v = [v.item() for v in values] + [next_v.item()]
        advs = self._compute_gae(rewards, all_v, dones)

        vals_t = torch.stack(values)
        rets = advs + vals_t.detach()  # critic targets are constants
        lps_t = torch.stack(log_probs)
        ents_t = torch.stack(entropies)

        # The unbiased std of a single sample is NaN, which would turn the loss
        # (and then the weights) into NaN; normalize only with >= 2 samples.
        norm_adv = advs
        if advs.numel() > 1:
            norm_adv = (advs - advs.mean()) / (advs.std() + 1e-8)

        p_loss = -(lps_t * norm_adv.detach()).mean()
        v_loss = F.huber_loss(vals_t, rets)
        e_loss = -ents_t.mean()  # negative: entropy is maximized
        total = p_loss + self.vc * v_loss + self.ec * e_loss

        self.optimizer.zero_grad()
        total.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
        self.optimizer.step()
        return total.item(), ents_t.mean().item()


# Training benchmark on simple env
class SimpleEnv:
    """1-D walk exposed as a 4-element observation (position plus 3 zeros)."""

    def reset(self):
        """Put the walker back at x = 0 and return the initial observation."""
        self.x = 0.0
        return [self.x, *([0] * 3)]

    def step(self, a):
        """Move 0.1 right when a == 1, 0.1 left for any other action.

        Returns (observation, reward, done). The episode ends once |x| > 1;
        the reward is +1 past the right edge, -1 past the left edge and 0.01
        on every other step.
        """
        direction = 1 if a == 1 else -1
        self.x += 0.1 * direction

        if self.x > 1:
            reward = 1.0
        elif self.x < -1:
            reward = -1.0
        else:
            reward = 0.01

        observation = [self.x, *([0] * 3)]
        return observation, reward, abs(self.x) > 1.0

torch.manual_seed(42)
env = SimpleEnv()
agent = A2CAgent(4, 2, n_steps=64, seed=42)

# Per-update history of the total loss and the mean policy entropy
losses, entropies = [], []
for update in range(100):
    loss, entropy = agent.collect_and_update(env)
    losses.append(loss)
    entropies.append(entropy)
    # Report every 20th update
    if (update + 1) % 20 == 0:
        print(f"Update {update+1:3d}: loss={loss:.4f}, entropy={entropy:.4f}")

print(f"\nFinal 20-update avg loss:    {sum(losses[-20:])/20:.4f}")
print(f"Final entropy (exploration): {entropies[-1]:.4f}")