Back to PyTorch Mastery Series

Policy Gradient Methods & REINFORCE

May 29, 2026 Wasil Zafar 32 min read

Policy gradients directly optimize the policy without a Q-table or value function. Implement REINFORCE from the policy gradient theorem, understand how log-probabilities flow gradients through stochastic actions, and reduce variance with a learned baseline.

Table of Contents

  1. Policy Gradient Theorem
  2. Policy Network
  3. REINFORCE Algorithm
  4. Variance Reduction: Baseline
  5. Full REINFORCE Agent
  6. Related Articles

Policy Gradient Theorem

The policy gradient theorem gives us an unbiased gradient estimator for the expected return $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[G_0]$:

$$\nabla_\theta J(\theta) = \mathbb{E}_\tau\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(A_t|S_t) \cdot G_t\right]$$

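where $G_t = \sum_{k=t}^{T} \gamma^{k-t} R_{k+1}$ is the discounted return from step $t$. The key insight: we never differentiate through the sampling step itself. The sampled action $A_t$ is treated as a fixed constant, and gradients flow only through $\log \pi_\theta(A_t|S_t)$, which is an ordinary differentiable function of $\theta$. (Strictly, for $\gamma < 1$ the exact gradient of $J(\theta)$ weights each term by an extra factor $\gamma^t$; like most practical implementations, the code below drops it.)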

Policy Network

Discrete Actions: Categorical Distribution

import torch
import torch.nn as nn
from torch.distributions import Categorical


class DiscretePolicy(nn.Module):
    """
    Categorical (softmax) policy for discrete action spaces.

    An MLP maps a state to one unnormalized logit per action, and those
    logits parameterize a torch ``Categorical`` distribution.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Two Tanh hidden layers. Tanh keeps hidden activations bounded in
        # [-1, 1], a common default for small policy-gradient MLPs.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map state(s) *x* to a Categorical distribution over actions."""
        return Categorical(logits=self.net(x))

    def act(self, state):
        """
        Sample one action for a single, unbatched state.

        Returns:
            (action, log_prob): the action as a Python int, and its
            log-probability as a tensor that keeps the autograd graph, so it
            can be used directly in a policy-gradient loss. For a 1-D state
            the log-prob has shape (1,) because a batch dim is added here.
        """
        batch = torch.as_tensor(state, dtype=torch.float32)[None]
        dist = self(batch)
        sampled = dist.sample()
        return sampled.item(), dist.log_prob(sampled)


torch.manual_seed(42)
policy = DiscretePolicy(state_dim=4, n_actions=2)

# Demo: draw five actions from the same random state to show sampling noise
state = torch.randn(4)
state_np = state.numpy()
for _ in range(5):
    action, log_prob = policy.act(state_np)
    prob = log_prob.exp().item()
    print(f"Action: {action}, log π(a|s): {log_prob.item():.4f}, π(a|s): {prob:.4f}")

Continuous Actions: Gaussian Distribution

import torch
import torch.nn as nn
from torch.distributions import Normal


class GaussianPolicy(nn.Module):
    """
    Diagonal Gaussian policy for continuous action spaces.

    A shared Tanh MLP backbone feeds two linear heads: one for the action
    mean and one for the state-dependent log standard deviation. The
    log-std is clamped to [log_std_min, log_std_max] before being
    exponentiated, which bounds the std away from 0 and from very large
    values.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128,
                 log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_min, log_std_max

        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the Normal distribution pi(.|x) for state(s) *x*."""
        h = self.backbone(x)
        log_std = torch.clamp(self.log_std_head(h), self.log_std_min, self.log_std_max)
        return Normal(self.mean_head(h), torch.exp(log_std))

    def act(self, state):
        """
        Sample a continuous action for one unbatched state.

        Returns:
            (action, log_prob): ``action`` has the added batch dim removed;
            ``log_prob`` is the joint log-density of the action, i.e. the sum
            of the per-dimension log-densities (valid because the covariance
            is diagonal), and it keeps the autograd graph.
        """
        dist = self(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))
        sample = dist.sample()
        joint_log_prob = dist.log_prob(sample).sum(dim=-1)
        return sample.squeeze(0), joint_log_prob


torch.manual_seed(42)
# LunarLanderContinuous: 8-dim observation, 2-dim continuous action.
# (The previous action_dim=3 did not match the environment named here.)
policy = GaussianPolicy(state_dim=8, action_dim=2)

state = torch.randn(8).numpy()
action, log_prob = policy.act(state)
print(f"Action shape: {action.shape}")
print(f"Action: {action.tolist()}")
print(f"log π(a|s): {log_prob.item():.4f}")

REINFORCE Algorithm

Computing Discounted Returns

import torch


def compute_returns(rewards, gamma=0.99, normalize=True):
    """
    Compute the discounted return G_t for every timestep of one episode.

        G_t = R_{t+1} + gamma * R_{t+2} + gamma^2 * R_{t+3} + ...

    Args:
        rewards: sequence of per-step rewards, in time order.
        gamma: discount factor.
        normalize: if True and the episode has more than one step, rescale
            the returns to zero mean and unit std (reduces gradient
            variance). A single-step episode is returned unnormalized,
            because the sample std of one value is undefined (NaN).

    Returns:
        1-D float32 tensor with one return per reward (empty if no rewards).
    """
    running = 0.0
    returns = []
    # Accumulate backwards (G_t = r_t + gamma * G_{t+1}), appending and
    # reversing once at the end: O(n), unlike repeated list.insert(0, ...).
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()

    returns = torch.tensor(returns, dtype=torch.float32)

    if normalize and len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    return returns


# Demo: 10-step episode with a single sparse reward at the final step.
# Discounting propagates that reward back, so G_t grows toward the end.
rewards = [0.0] * 9 + [1.0]

returns_raw = compute_returns(rewards, gamma=0.99, normalize=False)
returns_norm = compute_returns(rewards, gamma=0.99, normalize=True)

print("t  | reward | G_t (raw) | G_t (normalized)")
for t, (r, g_raw, g_norm) in enumerate(zip(rewards, returns_raw, returns_norm)):
    print(f"{t:2d} | {r:6.1f} | {g_raw.item():9.4f} | {g_norm.item():16.4f}")

REINFORCE Gradient Update

import torch
import torch.nn as nn
from torch.distributions import Categorical


class DiscretePolicy(nn.Module):
    """Compact Categorical policy: MLP logits (s -> h -> h -> a) -> Categorical."""

    def __init__(self, s, a, h=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(s, h), nn.Tanh(),
            nn.Linear(h, h), nn.Tanh(),
            nn.Linear(h, a),
        )

    def forward(self, x):
        """Return the action distribution for state(s) *x*."""
        logits = self.net(x)
        return Categorical(logits=logits)

    def act(self, state):
        """Sample an action for one state; return (int action, log-prob tensor)."""
        obs = torch.as_tensor(state, dtype=torch.float32)
        dist = self(obs.unsqueeze(0))
        choice = dist.sample()
        log_p = dist.log_prob(choice)
        return choice.item(), log_p


def reinforce_update(log_probs, returns, optimizer, normalize=True, max_grad_norm=0.5):
    """
    Apply one REINFORCE policy-gradient step.

        loss = -mean_t[ log pi(a_t|s_t) * G_t ]

    Negative because we ascend the gradient (maximize J), but PyTorch
    optimizers minimize. Using a mean rather than a sum over timesteps only
    rescales the gradient.

    Args:
        log_probs: sequence of log-prob tensors, one per timestep, each with
            a single element (0-d or shape (1,), as returned by ``act``).
            They are flattened so that a (T, 1) stack cannot silently
            broadcast against the (T,) returns into a (T, T) loss matrix.
        returns: the returns G_t, one per timestep: a list of floats, a list
            of single-element tensors, or a 1-D tensor. Treated as constants
            (detached).
        optimizer: optimizer over the policy parameters.
        normalize: standardize the returns to zero mean / unit std when there
            is more than one timestep. With a single timestep normalization
            is skipped: the sample std of one value is NaN and would turn the
            loss, gradients and parameters into NaN.
        max_grad_norm: global gradient-norm clip across every parameter group
            of ``optimizer``; ``None`` disables clipping.

    Returns:
        The scalar loss as a Python float.

    Raises:
        ValueError: if there are no timesteps or the lengths differ.
    """
    if len(log_probs) == 0:
        raise ValueError("reinforce_update needs at least one timestep")

    if isinstance(returns, torch.Tensor):
        returns_t = returns.detach().to(torch.float32).reshape(-1)
    elif len(returns) > 0 and isinstance(returns[0], torch.Tensor):
        returns_t = torch.stack(list(returns)).detach().to(torch.float32).reshape(-1)
    else:
        returns_t = torch.tensor(returns, dtype=torch.float32)

    log_probs_t = torch.stack(list(log_probs)).reshape(-1)
    if log_probs_t.numel() != returns_t.numel():
        raise ValueError(
            f"got {log_probs_t.numel()} log-probs but {returns_t.numel()} returns")

    if normalize and returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)

    # Policy gradient loss: -E[log pi * G]
    policy_loss = -(log_probs_t * returns_t).mean()

    optimizer.zero_grad()
    policy_loss.backward()
    if max_grad_norm is not None:
        # Clip over all parameter groups, not just the first one.
        params = [p for group in optimizer.param_groups for p in group['params']]
        nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
    optimizer.step()

    return policy_loss.item()


# Verify gradient direction: a positive return should increase the
# probability of the action that was actually sampled.
torch.manual_seed(42)
policy = DiscretePolicy(4, 2)
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

state = [0.1, -0.2, 0.3, 0.05]
state_batch = torch.tensor([state])  # same state with a batch dim of 1
action, log_prob = policy.act(state)
probs_before = policy(state_batch).probs.detach().clone()

# Positive return: should reinforce the taken action
loss = reinforce_update([log_prob], [10.0], optimizer)
probs_after = policy(state_batch).probs.detach()

print(f"Loss: {loss:.4f}")
print(f"Sampled action: {action}")
print(f"Probs before update: {probs_before.squeeze().tolist()}")
print(f"Probs after  update: {probs_after.squeeze().tolist()}")

Variance Reduction with a Baseline

import torch
import torch.nn as nn
from torch.distributions import Categorical


class PolicyWithBaseline(nn.Module):
    """
    Actor with a learned state-value baseline, sharing one backbone.

    The shared Tanh MLP feeds two linear heads:
      * ``policy_head`` -> action logits for a Categorical distribution;
      * ``value_head``  -> scalar V(s), an estimate of the expected return
        that serves as a baseline to reduce policy-gradient variance.
    Training combines them as policy_loss + baseline_coeff * value_loss.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return (action distribution, V(x) with the trailing unit dim squeezed)."""
        shared = self.backbone(x)
        logits = self.policy_head(shared)
        baseline = self.value_head(shared).squeeze(-1)
        return Categorical(logits=logits), baseline

    def act(self, state):
        """Sample an action for one state; return (int action, log-prob, V(s))."""
        obs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        dist, value = self(obs)
        chosen = dist.sample()
        return chosen.item(), dist.log_prob(chosen), value


def reinforce_with_baseline(log_probs, values, rewards, optimizer, gamma=0.99, beta=0.5,
                            normalize=True, max_grad_norm=0.5):
    """
    One REINFORCE-with-baseline update from a finished episode.

        policy_loss = -mean_t[ log pi(a_t|s_t) * (G_t - V(s_t)) ]
        value_loss  = MSE(V(s_t), G_t)
        total_loss  = policy_loss + beta * value_loss

    (G_t - V(s_t)) is the 'advantage': > 0 means the action did better than
    the baseline expected. V is detached inside the advantage, so the policy
    term sends no gradient into the baseline; V is trained by value_loss only.

    Args:
        log_probs: per-step log-prob tensors, one element each (0-d or (1,)).
        values: per-step V(s_t) tensors, one element each.
        rewards: per-step float rewards, in time order.
        optimizer: optimizer over all network parameters.
        gamma: discount factor.
        beta: weight of the value loss.
        normalize: standardize the returns when the episode has more than
            one step; skipped for a single step, where the sample std is NaN.
        max_grad_norm: global grad-norm clip over every parameter group;
            ``None`` disables clipping.

    Returns:
        (total_loss, policy_loss, value_loss) as Python floats.

    Raises:
        ValueError: on an empty episode or mismatched sequence lengths.
    """
    n = len(rewards)
    if n == 0:
        raise ValueError("reinforce_with_baseline needs at least one step")
    if len(log_probs) != n or len(values) != n:
        raise ValueError(
            f"got {len(log_probs)} log-probs, {len(values)} values, {n} rewards")

    # Discounted returns, accumulated backwards in O(n).
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns_t = torch.tensor(returns, dtype=torch.float32)
    if normalize and n > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)

    # Flatten so (T, 1) stacks cannot broadcast against (T,) returns.
    log_probs_t = torch.stack(list(log_probs)).reshape(-1)
    values_t = torch.stack(list(values)).reshape(-1)

    advantages = returns_t - values_t.detach()  # Stop gradient through baseline

    policy_loss = -(log_probs_t * advantages).mean()
    value_loss = nn.functional.mse_loss(values_t, returns_t)
    total_loss = policy_loss + beta * value_loss

    optimizer.zero_grad()
    total_loss.backward()
    if max_grad_norm is not None:
        params = [p for group in optimizer.param_groups for p in group["params"]]
        nn.utils.clip_grad_norm_(params, max_grad_norm)
    optimizer.step()
    return total_loss.item(), policy_loss.item(), value_loss.item()


torch.manual_seed(42)
net = PolicyWithBaseline(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=3e-4)

# Synthetic 10-step episode: random states and random rewards in [0, 1).
# States are drawn first, then per step an action sample and a reward draw,
# which keeps the RNG consumption order fixed.
states = [torch.randn(4) for _ in range(10)]
log_probs, values, rewards = [], [], []
for s in states:
    a, lp, v = net.act(s.numpy())
    r = float(torch.rand(1).item())
    log_probs.append(lp.squeeze())
    values.append(v.squeeze())
    rewards.append(r)

total, p_loss, v_loss = reinforce_with_baseline(log_probs, values, rewards, opt)
print(f"Total loss: {total:.4f}  (policy: {p_loss:.4f}, value: {v_loss:.4f})")

Full REINFORCE Agent

import torch
import torch.nn as nn
from torch.distributions import Categorical


class REINFORCEAgent:
    """
    Complete REINFORCE agent for discrete actions, with an optional learned
    value baseline.

    Per episode: call ``act`` for each state and ``store_reward`` for each
    resulting reward, then call ``finish_episode`` once. It takes one
    gradient step and clears the episode buffers.
    """

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, use_baseline=True, seed=42,
                 hidden_dim=128, value_coef=0.5, max_grad_norm=0.5):
        # Seeds torch's global RNG, which affects the weight init below and
        # later action sampling in act().
        torch.manual_seed(seed)
        self.gamma = gamma
        self.use_baseline = use_baseline
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        # The V(s) head exists only when the baseline is enabled.
        self.value_head = nn.Linear(hidden_dim, 1) if use_baseline else None
        params = list(self.backbone.parameters()) + list(self.policy_head.parameters())
        if use_baseline:
            params += list(self.value_head.parameters())
        self.optimizer = torch.optim.Adam(params, lr=lr)
        self.reset_episode()

    def reset_episode(self):
        """Clear the per-episode log-prob, value and reward buffers."""
        self.ep_log_probs, self.ep_values, self.ep_rewards = [], [], []

    def act(self, state):
        """Sample an action for one state, buffering its log-prob (and V(s))."""
        s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        f = self.backbone(s)
        dist = Categorical(logits=self.policy_head(f))
        a = dist.sample()
        self.ep_log_probs.append(dist.log_prob(a).squeeze())
        if self.use_baseline:
            self.ep_values.append(self.value_head(f).squeeze())
        return a.item()

    def store_reward(self, r):
        """Buffer the reward received after the most recent action."""
        self.ep_rewards.append(float(r))

    def finish_episode(self):
        """
        Take one policy-gradient step on the buffered episode, then clear it.

        Returns:
            The scalar loss as a float. An episode with no stored rewards is
            a no-op: the buffers are cleared and 0.0 is returned. Previously
            this case crashed in ``torch.stack`` on an empty list.
        """
        if not self.ep_rewards:
            self.reset_episode()
            return 0.0

        # Discounted returns G_t, accumulated backwards in O(n).
        returns = []
        G = 0.0
        for r in reversed(self.ep_rewards):
            G = r + self.gamma * G
            returns.append(G)
        returns.reverse()
        R = torch.tensor(returns, dtype=torch.float32)
        # Normalize only with >1 step: the sample std of one value is NaN.
        if len(R) > 1:
            R = (R - R.mean()) / (R.std() + 1e-8)

        lp = torch.stack(self.ep_log_probs)
        if self.use_baseline:
            V = torch.stack(self.ep_values)
            adv = R - V.detach()  # detached: the policy term must not train V
            loss = -(lp * adv).mean() + self.value_coef * nn.functional.mse_loss(V, R)
        else:
            loss = -(lp * R).mean()

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            params = [p for group in self.optimizer.param_groups for p in group['params']]
            nn.utils.clip_grad_norm_(params, self.max_grad_norm)
        self.optimizer.step()
        self.reset_episode()
        return loss.item()


# Benchmark on simple env
class SimpleEnv:
    """
    Toy 1-D walk used to benchmark the agent.

    Position ``x`` starts at 0 and moves +0.1 for action 1 and -0.1 for any
    other action. The episode ends once |x| > 1.0. The reward is +1.0 for
    exiting on the right, -1.0 for exiting on the left, and 0.01 per step
    otherwise. Observations are ``[x, 0, 0, 0]``, padded to 4 dims.
    """

    def reset(self):
        """Put the walker back at the origin and return the first observation."""
        self.x = 0.0
        return [self.x, 0.0, 0.0, 0.0]

    def step(self, a):
        """Apply action *a*; return (observation, reward, done)."""
        self.x += 0.1 if a == 1 else -0.1
        if self.x > 1.0:
            reward = 1.0
        elif self.x < -1.0:
            reward = -1.0
        else:
            reward = 0.01
        done = abs(self.x) > 1.0
        return [self.x, 0, 0, 0], reward, done

torch.manual_seed(42)
for use_bl in (False, True):
    agent = REINFORCEAgent(4, 2, use_baseline=use_bl, seed=42)
    env = SimpleEnv()
    ep_rewards = []
    for ep in range(400):
        s = env.reset()
        total = 0.0
        # At most 50 steps per episode; stop early when the walker exits.
        for _ in range(50):
            s, r, done = env.step(agent.act(s))
            agent.store_reward(r)
            total += r
            if done:
                break
        agent.finish_episode()
        ep_rewards.append(total)
    window = ep_rewards[-100:]
    label = "no  baseline " if not use_bl else "with baseline"
    print(f"REINFORCE {label} — last 100-ep avg: {sum(window)/100:.4f}")