Back to PyTorch Mastery Series

Deep Q-Network (DQN) in PyTorch

May 29, 2026 Wasil Zafar 35 min read

Build DQN from scratch: neural network Q-function approximator, experience replay buffer, separate target network, and the training loop. Understand why each of the three key innovations (replay, target net, neural Q) was necessary to stabilize deep RL.

Table of Contents

  1. DQN Architecture
  2. Experience Replay Buffer
  3. Q-Network Implementation
  4. Training Loop
  5. Full DQN Agent
  6. Related Articles

DQN Architecture

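DQN (Mnih et al., 2013; extended in the 2015 *Nature* paper) overcame the instability of using neural networks as Q-function approximators by combining three ingredients (the separate target network was added in the 2015 version):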

  1. Experience Replay — store and randomly sample transitions to break temporal correlations
  2. Target Network — a slowly-updated copy of the Q-network for stable TD targets
  3. Neural Q-Function — approximate $Q(s, \cdot)$ with one forward pass instead of per-action lookups

Experience Replay Buffer

import torch
from collections import deque
import random


class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (s, a, r, s', done) transitions.

    Backed by ``deque(maxlen=capacity)``: pushing into a full buffer evicts the
    oldest transition. Uniform random sampling breaks the temporal correlation
    between consecutive updates.
    """

    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    @staticmethod
    def _to_tensor(x):
        """Return a float32 tensor that owns its memory (never a view of ``x``)."""
        # torch.as_tensor shares memory with float32 numpy arrays and tensors;
        # clone so a caller that later mutates its array in place cannot
        # silently rewrite transitions already stored in the buffer.
        return torch.as_tensor(x, dtype=torch.float32).detach().clone()

    def push(self, state, action, reward, next_state, done):
        """Store one transition. The oldest entry is discarded when full."""
        self.buffer.append((
            self._to_tensor(state),
            int(action),
            float(reward),
            self._to_tensor(next_state),
            float(done),
        ))

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones): states/next_states
            are the stored state tensors stacked along a new first dim;
            actions is a (B,) long tensor; rewards and dones are (B,) float32.

        Raises:
            ValueError: if batch_size is < 1 or larger than the buffer.
        """
        if not 1 <= batch_size <= len(self.buffer):
            raise ValueError(
                f"cannot sample {batch_size} transitions from a buffer "
                f"holding {len(self.buffer)}"
            )
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (
            torch.stack(states),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.stack(next_states),
            torch.tensor(dones, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.buffer)


# Demo: populate the buffer with 200 random CartPole-shaped transitions
# (4-dim states, 2 actions, ~5% terminal), then draw one minibatch.
random.seed(42)
buf = ReplayBuffer(capacity=1000)

for _ in range(200):
    # Arguments are evaluated left to right, so the RNG draws happen in the
    # order: state, action, reward, next_state, done.
    buf.push(
        torch.randn(4).numpy(),     # state
        random.randint(0, 1),       # action
        random.uniform(-1, 1),      # reward
        torch.randn(4).numpy(),     # next_state
        random.random() < 0.05,     # done
    )

print(f"Buffer size: {len(buf)}")
states, actions, rewards, next_states, dones = buf.sample(32)
print(f"Batch shapes — states: {states.shape}, actions: {actions.shape}, rewards: {rewards.shape}")
low, high = rewards.min(), rewards.max()
print(f"Reward range: [{low:.2f}, {high:.2f}]")

Q-Network Implementation

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """
    Multi-layer perceptron that scores every action of a state at once.

    Input:  state tensor shaped (batch, state_dim)
    Output: one Q-value per action, shaped (batch, n_actions)
    Producing all action values in a single forward pass makes greedy action
    selection a simple argmax over the last dimension.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        q_values = self.net(x)
        return q_values


# Demo: a CartPole-sized Q-network (4-dim state, 2 actions)
torch.manual_seed(42)
q_net = QNetwork(state_dim=4, n_actions=2, hidden_dim=128)
n_params = sum(param.numel() for param in q_net.parameters())
print(f"Q-Network parameters: {n_params:,}")

# A single forward pass scores both actions for all 32 states.
batch_states = torch.randn(32, 4)
q_values = q_net(batch_states)
greedy = q_values.argmax(dim=1)
print(f"Q-values shape: {q_values.shape}")  # (32, 2)
print(f"Greedy actions: {greedy[:8].tolist()}")

Training Loop

Target Network

import torch
import torch.nn as nn
import copy


class QNetwork(nn.Module):
    """MLP Q-function: (batch, state_dim) -> (batch, n_actions), two ReLU hidden layers."""

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        widths = [state_dim, hidden_dim, hidden_dim, n_actions]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth:  # a ReLU sits between consecutive Linear layers only
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-value of every action for each state in ``x``."""
        return self.net(x)


def soft_update(online_net, target_net, tau=0.005):
    """
    Soft (Polyak) target network update, applied in place to ``target_net``:

        theta_target = tau * theta_online + (1 - tau) * theta_target

    The target is an exponential moving average of the online weights. With
    tau=0.005 its effective averaging horizon is about 1 / tau = 200 updates,
    which is why it lags behind the online network.

    Args:
        online_net: module whose parameters are being trained.
        target_net: module updated in place; must have the same parameter layout.
        tau: interpolation weight in [0, 1]. 1.0 is a hard copy; 0.0 is a no-op.

    Raises:
        ValueError: if tau is outside [0, 1] or the two modules expose a
            different number of parameter tensors.
    """
    if not 0.0 <= tau <= 1.0:
        raise ValueError(f"tau must be in [0, 1], got {tau}")
    target_params = list(target_net.parameters())
    online_params = list(online_net.parameters())
    # zip() would silently stop at the shorter list and leave tensors stale.
    if len(target_params) != len(online_params):
        raise ValueError(
            f"parameter count mismatch: target has {len(target_params)}, "
            f"online has {len(online_params)}"
        )
    # In-place edits of parameters must not be recorded by autograd.
    # no_grad + lerp_ replaces the discouraged `.data` access and avoids
    # allocating temporaries. lerp_(end, w) computes (1 - w) * self + w * end.
    with torch.no_grad():
        for tp, op in zip(target_params, online_params):
            tp.lerp_(op, tau)


def hard_update(online_net, target_net):
    """
    Overwrite ``target_net``'s state dict with a snapshot of ``online_net``'s.

    This is the periodic "hard" sync; the original DQN paper does it every
    10k steps. Returns None.
    """
    snapshot = online_net.state_dict()
    target_net.load_state_dict(snapshot)


# Check that one soft update only nudges the target toward the online network.
torch.manual_seed(42)
online_net = QNetwork(4, 2)
target_net = copy.deepcopy(online_net)  # identical weights at the start

probe = torch.randn(1, 4)
with torch.no_grad():
    online_q_before = online_net(probe)

    # Stand-in for a gradient step: jitter every online parameter.
    for param in online_net.parameters():
        param.add_(0.1 * torch.randn_like(param))

    online_q_after = online_net(probe)
    target_q_before = target_net(probe)

soft_update(online_net, target_net, tau=0.005)

with torch.no_grad():
    target_q_after = target_net(probe)

online_shift = (online_q_after - online_q_before).abs().mean()
target_shift = (target_q_after - target_q_before).abs().mean()
print(f"Online Q change:  {online_shift:.4f}")
print(f"Target Q change:  {target_shift:.6f}  (much smaller: lag)")

DQN Loss Computation

import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    """Feed-forward Q-value estimator with two hidden ReLU layers."""

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Layers are created in input -> output order (this fixes the order in
        # which their weights are randomly initialised).
        input_layer = nn.Linear(state_dim, hidden_dim)
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, n_actions)
        self.net = nn.Sequential(
            input_layer, nn.ReLU(), hidden_layer, nn.ReLU(), output_layer
        )

    def forward(self, x):
        """Map a batch of states to per-action Q-values."""
        return self.net(x)


def compute_dqn_loss(online_net, target_net, batch, gamma=0.99):
    """
    DQN Huber loss on one minibatch of transitions.

    Args:
        online_net: network being trained; maps states to per-action Q-values.
        target_net: network used for the bootstrap target; evaluated without
            gradient tracking.
        batch: (states, actions, rewards, next_states, dones). actions,
            rewards and dones may each be shaped (B,) or (B, 1); dones may be
            a bool or a float tensor.
        gamma: discount factor.

    Returns:
        Scalar loss (mean Huber loss over the batch), differentiable w.r.t.
        online_net's parameters.
    """
    states, actions, rewards, next_states, dones = batch

    # Q(s, a) from the online net, only for the actions actually taken.
    actions = actions.reshape(-1).long()
    q_current = online_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Flatten per-transition fields to (B,). Otherwise a (B, 1) reward/done
    # tensor silently broadcasts against the (B,) Q-values into a (B, B)
    # target. Casting also makes bool dones work (`1 - bool_tensor` raises).
    rewards = rewards.reshape(-1).to(q_current.dtype)
    not_done = 1.0 - dones.reshape(-1).to(q_current.dtype)

    # max_a' Q(s', a') from the TARGET net, with gradients stopped.
    with torch.no_grad():
        q_next = target_net(next_states).max(dim=1).values
        q_target = rewards + gamma * q_next * not_done

    # Huber loss is more robust to outlier TD errors than MSE.
    return F.huber_loss(q_current, q_target)


# Demo with synthetic batch
import copy  # needed only by this snippet; its import block above lacks it

torch.manual_seed(42)
online_net = QNetwork(4, 2)
# As in DQN proper, the target starts as an exact copy of the online network
# and is frozen: only explicit syncs (soft/hard updates) should move it.
target_net = copy.deepcopy(online_net)
target_net.requires_grad_(False)
optimizer = torch.optim.Adam(online_net.parameters(), lr=1e-3)

# Synthetic batch (32 transitions)
batch = (
    torch.randn(32, 4),                  # states
    torch.randint(0, 2, (32,)),          # actions
    torch.randn(32),                     # rewards
    torch.randn(32, 4),                  # next states
    (torch.rand(32) < 0.05).float(),     # dones (~5% terminal)
)

# Repeated steps on one fixed batch. The frozen target net keeps the TD
# targets fixed, so the loss typically shrinks as the online net fits them.
for step in range(5):
    loss = compute_dqn_loss(online_net, target_net, batch)
    optimizer.zero_grad()
    loss.backward()
    # Gradient clipping — important for DQN stability
    nn.utils.clip_grad_norm_(online_net.parameters(), max_norm=10.0)
    optimizer.step()
    print(f"Step {step+1}: loss = {loss.item():.4f}")

Full DQN Agent

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import random
from collections import deque


class QNetwork(nn.Module):
    """Two-hidden-layer MLP mapping an s-dim state to a Q-value for each of a actions."""

    def __init__(self, s, a, h=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(s, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU(), nn.Linear(h, a))

    def forward(self, x): return self.net(x)


class DQNAgent:
    """
    Full DQN agent with:
    - Neural Q-function (online + frozen target networks)
    - Experience replay buffer (bounded deque of transitions)
    - Soft (Polyak) target network updates after every optimisation step
    - Epsilon-greedy action selection (the caller supplies and decays epsilon)
    """

    def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99,
                 buffer_size=10_000, batch_size=64, tau=0.005, seed=42):
        # Note: seeds the process-wide `random` and torch RNGs.
        random.seed(seed)
        torch.manual_seed(seed)
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.tau = tau

        self.online = QNetwork(state_dim, n_actions)
        # The target starts as an exact copy and is never optimised directly.
        self.target = copy.deepcopy(self.online)
        for p in self.target.parameters():
            p.requires_grad = False

        self.optimizer = torch.optim.Adam(self.online.parameters(), lr=lr)
        # Oldest transitions are evicted once buffer_size is reached.
        self.buffer = deque(maxlen=buffer_size)

    def select_action(self, state, epsilon=0.1):
        """Return a uniformly random action with prob. epsilon, else argmax_a Q_online(state, a)."""
        if random.random() < epsilon:
            return random.randint(0, self.n_actions - 1)
        with torch.no_grad():
            s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            return self.online(s).argmax(1).item()

    @staticmethod
    def _to_tensor(x):
        """Copy ``x`` into a float32 tensor that owns its memory."""
        # torch.as_tensor aliases float32 numpy arrays and tensors; without the
        # clone, a caller mutating its state array would corrupt the buffer.
        return torch.as_tensor(x, dtype=torch.float32).detach().clone()

    def push(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) transition to the replay buffer."""
        self.buffer.append((
            self._to_tensor(state), int(action),
            float(reward), self._to_tensor(next_state), float(done),
        ))

    def _sample_batch(self):
        """Uniformly sample batch_size transitions and collate them into tensors."""
        transitions = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (
            torch.stack(states),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.stack(next_states),
            torch.tensor(dones, dtype=torch.float32),
        )

    def _soft_update(self):
        """Move target weights a fraction tau of the way toward the online weights."""
        # no_grad + lerp_ instead of `.data` arithmetic: same update,
        # (1 - tau) * target + tau * online, computed in place.
        with torch.no_grad():
            for tp, op in zip(self.target.parameters(), self.online.parameters()):
                tp.lerp_(op, self.tau)

    def update(self):
        """
        Run one DQN optimisation step on a replay minibatch.

        Returns:
            The Huber loss as a Python float, or None while the buffer holds
            fewer than batch_size transitions.
        """
        if len(self.buffer) < self.batch_size:
            return None
        s, a, r, ns, d = self._sample_batch()

        q = self.online(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_tgt = r + self.gamma * self.target(ns).max(1).values * (1 - d)

        loss = F.huber_loss(q, q_tgt)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.online.parameters(), 10.0)
        self.optimizer.step()

        self._soft_update()
        return loss.item()


# Benchmark on a synthetic 1D "walk left / right" environment
class SimpleEnv:
    """
    Toy 1D environment for exercising DQN.

    Each step moves a point by +0.1 (action 1) or -0.1 (any other action).
    The episode ends once |pos| exceeds 1.0. Leaving on the right pays +1.0,
    leaving on the left pays -1.0, and every other step pays 0.01. The
    observation is [pos, 0.0, 0.0, 0.0], padded to 4 dims.
    """

    def _observe(self):
        return [self.pos, 0.0, 0.0, 0.0]

    def reset(self):
        self.pos = 0.0
        return self._observe()

    def step(self, a):
        self.pos += 0.1 if a == 1 else -0.1
        if self.pos > 1.0:
            reward = 1.0
        elif self.pos < -1.0:
            reward = -1.0
        else:
            reward = 0.01
        done = abs(self.pos) > 1.0
        return self._observe(), reward, done

torch.manual_seed(42)
env = SimpleEnv()
agent = DQNAgent(state_dim=4, n_actions=2)

losses = []
rewards_history = []
for ep in range(300):
    # Linear epsilon decay from 1.0, reaching the 0.05 floor at episode 190.
    eps = max(0.05, 1.0 - ep / 200)
    s = env.reset()
    total_r = 0.0
    for _ in range(50):  # per-episode step cap
        a = agent.select_action(s, eps)
        ns, r, done = env.step(a)
        agent.push(s, a, r, ns, done)
        step_loss = agent.update()
        if step_loss is not None:
            losses.append(step_loss)
        total_r += r
        s = ns
        if done:
            break
    rewards_history.append(total_r)

recent_rewards = rewards_history[-100:]
recent_losses = losses[-100:]
print(f"Final 100-ep avg reward: {sum(recent_rewards)/100:.4f}")
print(f"Final 100-step avg loss: {sum(recent_losses)/max(1,len(recent_losses)):.4f}")