DQN Architecture
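DQN (Mnih et al., 2013, 2015) tackled the instability of training neural networks as Q-function approximators with three ideas (experience replay appeared in the 2013 paper; the target network was added in the 2015 Nature version):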
- Experience Replay — store and randomly sample transitions to break temporal correlations
- Target Network — a slowly-updated copy of the Q-network for stable TD targets
- Neural Q-Function — approximate $Q(s, \cdot)$ with one forward pass instead of per-action lookups
Experience Replay Buffer
import torch
from collections import deque
import random
class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (s, a, r, s', done) transitions.

    Uniform random sampling breaks the temporal correlation between
    consecutive transitions. Once ``capacity`` entries are stored, every
    further push evicts the oldest one (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    @staticmethod
    def _owned_float_tensor(array_like):
        """Return a float32 tensor that shares no memory with ``array_like``."""
        # torch.as_tensor is zero-copy when the input is already a float32
        # tensor / ndarray, so the stored transition would change whenever the
        # caller (e.g. an env that reuses its observation buffer) mutates the
        # original. detach().clone() guarantees a private, graph-free copy.
        return torch.as_tensor(array_like, dtype=torch.float32).detach().clone()

    def push(self, state, action, reward, next_state, done):
        """
        Store one transition; the oldest entry is discarded when full.

        ``state`` / ``next_state`` are copied into float32 tensors, ``action``
        is coerced to ``int`` and ``reward`` / ``done`` to ``float``.
        """
        self.buffer.append((
            self._owned_float_tensor(state),
            int(action),
            float(reward),
            self._owned_float_tensor(next_state),
            float(done),
        ))

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` where
            states / next_states are stacked float32 tensors, actions is a
            long tensor, and rewards / dones are float32 tensors.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        transitions = random.sample(self.buffer, batch_size)
        # Transpose a list of 5-tuples into five per-field tuples.
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (
            torch.stack(states),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.stack(next_states),
            torch.tensor(dones, dtype=torch.float32),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
# Demo: populate the buffer with synthetic CartPole-like transitions, then draw a batch
random.seed(42)
buf = ReplayBuffer(capacity=1000)
for _ in range(200):
    # Arguments are evaluated left to right, which fixes the RNG draw order:
    # 4-dim state, binary action, reward in [-1, 1], 4-dim next state, ~5% terminal flag.
    buf.push(
        torch.randn(4).numpy(),
        random.randint(0, 1),
        random.uniform(-1, 1),
        torch.randn(4).numpy(),
        random.random() < 0.05,
    )

print(f"Buffer size: {len(buf)}")

batch = buf.sample(32)
states, actions, rewards, next_states, dones = batch
print(f"Batch shapes — states: {states.shape}, actions: {actions.shape}, rewards: {rewards.shape}")

lowest_reward = rewards.min()
highest_reward = rewards.max()
print(f"Reward range: [{lowest_reward:.2f}, {highest_reward:.2f}]")
Q-Network Implementation
import torch
import torch.nn as nn
class QNetwork(nn.Module):
    """
    MLP Q-function approximator with two hidden ReLU layers.

    Input:  state batch of shape (batch, state_dim)
    Output: one Q-value per action, shape (batch, n_actions)

    A single forward pass yields the Q-values of every action at once, so
    greedy action selection is just an argmax over the last dimension.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Layers are created in input -> output order; the Sequential indices
        # (0, 2, 4 for the Linear layers) define the state_dict key names.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values of all actions for the given state batch."""
        q_values = self.net(x)
        return q_values
# Demo: Q-network sized for CartPole (4-dim state, 2 actions)
torch.manual_seed(42)
q_net = QNetwork(state_dim=4, n_actions=2, hidden_dim=128)
n_parameters = sum(p.numel() for p in q_net.parameters())
print(f"Q-Network parameters: {n_parameters:,}")

# Forward a random batch and read off the greedy action per row
batch_states = torch.randn(32, 4)
q_values = q_net(batch_states)
print(f"Q-values shape: {q_values.shape}")  # (32, 2)
first_greedy_actions = q_values.argmax(dim=1)[:8].tolist()
print(f"Greedy actions: {first_greedy_actions}")
Training Loop
Target Network
import torch
import torch.nn as nn
import copy
class QNetwork(nn.Module):
    """Two-hidden-layer ReLU MLP that outputs one Q-value per action."""

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Create the Linear layers in input -> output order so parameter
        # initialisation consumes the RNG in the same sequence.
        input_layer = nn.Linear(state_dim, hidden_dim)
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, n_actions)
        self.net = nn.Sequential(
            input_layer,
            nn.ReLU(),
            hidden_layer,
            nn.ReLU(),
            output_layer,
        )

    def forward(self, x):
        """Map a state batch to Q-values for every action."""
        return self.net(x)
def soft_update(online_net, target_net, tau=0.005):
    """
    Soft (Polyak) target-network update, applied in place to ``target_net``:

        theta_target <- tau * theta_online + (1 - tau) * theta_target

    Parameters are paired positionally via ``parameters()``, so both networks
    must have the same architecture. With tau=0.005 the target is an
    exponential moving average that lags roughly 1 / tau = 200 updates
    behind the online network.

    Args:
        online_net: network whose parameters are being tracked (not modified).
        target_net: network whose parameters are updated in place.
        tau: interpolation weight in [0, 1]; 0 leaves the target untouched,
            1 copies the online parameters.

    Raises:
        ValueError: if ``tau`` is outside [0, 1] (this also rejects NaN).
    """
    if not 0.0 <= tau <= 1.0:
        raise ValueError(f"tau must be in [0, 1], got {tau!r}")
    # no_grad() is the supported way to edit parameters in place; unlike the
    # legacy ``.data`` attribute it keeps autograd's version tracking intact.
    with torch.no_grad():
        for target_param, online_param in zip(target_net.parameters(), online_net.parameters()):
            target_param.copy_(tau * online_param + (1 - tau) * target_param)
def hard_update(online_net, target_net):
    """
    Hard target update: overwrite the target network with the online weights.

    Meant to be called every N steps (the original DQN paper used 10k steps).
    """
    online_weights = online_net.state_dict()
    target_net.load_state_dict(online_weights)
torch.manual_seed(42)
online_net = QNetwork(4, 2)
target_net = copy.deepcopy(online_net)  # Starts as an exact copy of the online net

# Record both networks' outputs on a fixed probe state before anything changes.
# (Forward passes draw no random numbers, so taking the target snapshot here
# is equivalent to taking it after the perturbation below.)
dummy_state = torch.randn(1, 4)
q_online_before = online_net(dummy_state).detach().clone()
q_target_before = target_net(dummy_state).detach().clone()

# Stand-in for a training step: randomly shift every online parameter
for param in online_net.parameters():
    noise = 0.1 * torch.randn_like(param.data)
    param.data += noise
q_online_after = online_net(dummy_state).detach()

# One soft update should move the target only a small fraction of that shift
soft_update(online_net, target_net, tau=0.005)
q_target_after = target_net(dummy_state).detach()

online_change = (q_online_after - q_online_before).abs().mean()
target_change = (q_target_after - q_target_before).abs().mean()
print(f"Online Q change: {online_change:.4f}")
print(f"Target Q change: {target_change:.6f} (much smaller: lag)")
DQN Loss Computation
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
    """ReLU MLP (two hidden layers) mapping a state to per-action Q-values."""

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Hidden stack: state_dim -> hidden_dim -> hidden_dim, each followed by ReLU.
        widths = [state_dim, hidden_dim, hidden_dim]
        modules = []
        for in_features, out_features in zip(widths, widths[1:]):
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU())
        # Linear read-out: one unit per action, no activation.
        modules.append(nn.Linear(hidden_dim, n_actions))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return Q-values for every action given a batch of states."""
        return self.net(x)
def compute_dqn_loss(online_net, target_net, batch, gamma=0.99):
    """
    DQN temporal-difference loss (Huber) for one minibatch.

    Target:  y = r + gamma * max_a' Q_target(s', a') * (1 - done)
    Loss:    huber(Q_online(s, a), y) with ``F.huber_loss`` defaults.

    Args:
        online_net: network being trained; gradients flow through it.
        target_net: network used for the bootstrap target (no gradient).
        batch: tuple ``(states, actions, rewards, next_states, dones)``.
            ``actions`` may be any integer dtype; ``dones`` may be bool or
            numeric; ``actions`` / ``rewards`` / ``dones`` may be shaped
            ``(B,)`` or ``(B, 1)``.
        gamma: discount factor.

    Returns:
        Scalar loss tensor attached to ``online_net``'s graph.
    """
    states, actions, rewards, next_states, dones = batch

    # Q(s, a) from the online net, only for the actions actually taken.
    # gather() requires an int64 index of shape (B, 1).
    action_index = actions.reshape(-1, 1).long()
    q_current = online_net(states).gather(1, action_index).squeeze(1)

    # max_a' Q(s', a') from the TARGET net; no gradient flows into the target.
    with torch.no_grad():
        # Flatten to (B,) so a (B, 1) column cannot silently broadcast the
        # target to (B, B); cast so bool ``dones`` supports the subtraction.
        rewards = rewards.reshape(-1).to(q_current.dtype)
        not_terminal = 1 - dones.reshape(-1).to(q_current.dtype)
        q_next = target_net(next_states).max(dim=1).values
        q_target = rewards + gamma * q_next * not_terminal

    # Huber loss is more robust to outlier TD errors than MSE
    return F.huber_loss(q_current, q_target)
# Demo: a few optimisation steps on one fixed synthetic batch
torch.manual_seed(42)
online_net = QNetwork(4, 2)
target_net = QNetwork(4, 2)
optimizer = torch.optim.Adam(online_net.parameters(), lr=1e-3)

# Synthetic batch of 32 transitions; fields are drawn in tuple order
demo_states = torch.randn(32, 4)
demo_actions = torch.randint(0, 2, (32,))
demo_rewards = torch.randn(32)
demo_next_states = torch.randn(32, 4)
demo_dones = (torch.rand(32) < 0.05).float()  # roughly 5% terminal
batch = (demo_states, demo_actions, demo_rewards, demo_next_states, demo_dones)

for step_number in range(1, 6):
    optimizer.zero_grad()
    loss = compute_dqn_loss(online_net, target_net, batch)
    loss.backward()
    # Gradient clipping — important for DQN stability
    nn.utils.clip_grad_norm_(online_net.parameters(), max_norm=10.0)
    optimizer.step()
    print(f"Step {step_number}: loss = {loss.item():.4f}")
Full DQN Agent
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import random
from collections import deque
class QNetwork(nn.Module):
    """
    Compact Q-network: two hidden ReLU layers, one output per action.

    ``s`` is the state dimension, ``a`` the number of actions and ``h`` the
    hidden-layer width.
    """

    def __init__(self, s, a, h=128):
        super().__init__()
        state_dim, n_actions, hidden = s, a, h
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        """Return Q-values for all actions."""
        return self.net(x)
class DQNAgent:
    """
    Full DQN agent with:
    - Neural Q-function (online + target networks)
    - Experience replay buffer
    - Soft target network updates
    - Epsilon-greedy action selection (epsilon is supplied by the caller,
      so any decay schedule lives in the training loop)
    """

    def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99,
                 buffer_size=10_000, batch_size=64, tau=0.005, seed=42):
        """
        Build the networks, optimizer and replay buffer.

        NOTE: seeds the *global* ``random`` and ``torch`` generators with
        ``seed``, so constructing an agent resets both RNG streams.
        """
        random.seed(seed)
        torch.manual_seed(seed)
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.tau = tau
        self.online = QNetwork(state_dim, n_actions)
        # Target starts as an exact copy and is never trained by the optimizer.
        self.target = copy.deepcopy(self.online)
        for param in self.target.parameters():
            param.requires_grad = False
        self.optimizer = torch.optim.Adam(self.online.parameters(), lr=lr)
        self.buffer = deque(maxlen=buffer_size)

    @staticmethod
    def _owned_float_tensor(array_like):
        """Return a float32 tensor that shares no memory with ``array_like``."""
        # torch.as_tensor is zero-copy for float32 tensors / ndarrays; without
        # the clone, later in-place changes by the caller would rewrite the
        # transition already stored in the buffer.
        return torch.as_tensor(array_like, dtype=torch.float32).detach().clone()

    def select_action(self, state, epsilon=0.1):
        """Epsilon-greedy: random action with probability ``epsilon``, else argmax Q."""
        if random.random() < epsilon:
            return random.randint(0, self.n_actions - 1)
        with torch.no_grad():
            # Add a batch dimension: the network expects (batch, state_dim).
            s = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            return self.online(s).argmax(1).item()

    def push(self, state, action, reward, next_state, done):
        """Append one transition; the oldest is evicted once the buffer is full."""
        self.buffer.append((
            self._owned_float_tensor(state),
            int(action),
            float(reward),
            self._owned_float_tensor(next_state),
            float(done),
        ))

    def update(self):
        """
        Run one gradient step on a random minibatch, then soft-update the target.

        Returns:
            The Huber loss as a Python float, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions (no update is performed).
        """
        if len(self.buffer) < self.batch_size:
            return None

        transitions = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        s = torch.stack(states)
        a = torch.tensor(actions, dtype=torch.long)
        r = torch.tensor(rewards, dtype=torch.float32)
        ns = torch.stack(next_states)
        d = torch.tensor(dones, dtype=torch.float32)

        # Q(s, a) for the taken actions vs. the bootstrapped target from the target net.
        q = self.online(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_tgt = r + self.gamma * self.target(ns).max(1).values * (1 - d)
        loss = F.huber_loss(q, q_tgt)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.online.parameters(), 10.0)
        self.optimizer.step()

        # Soft (Polyak) target update. no_grad() replaces the legacy ``.data``
        # access so the in-place write stays visible to autograd's version tracking.
        with torch.no_grad():
            for target_param, online_param in zip(self.target.parameters(), self.online.parameters()):
                target_param.copy_(self.tau * online_param + (1 - self.tau) * target_param)
        return loss.item()
# Benchmark on a small synthetic 1D environment
class SimpleEnv:
    """
    Minimal 1D continuous environment for smoke-testing DQN.

    The observation is ``[pos, 0.0, 0.0, 0.0]``. Action 1 moves +0.1, any
    other action moves -0.1. The episode ends once ``|pos| > 1.0``, paying
    +1.0 past the right edge and -1.0 past the left edge; every other step
    pays 0.01.
    """

    def reset(self):
        """Put the agent back at the origin and return the initial observation."""
        self.pos = 0.0
        return [self.pos, 0.0, 0.0, 0.0]

    def step(self, a):
        """Apply action ``a`` and return ``(observation, reward, done)``."""
        direction = 1 if a == 1 else -1
        self.pos += direction * 0.1

        if self.pos > 1.0:
            reward = 1.0
        elif self.pos < -1.0:
            reward = -1.0
        else:
            reward = 0.01

        finished = abs(self.pos) > 1.0
        observation = [self.pos, 0.0, 0.0, 0.0]
        return observation, reward, finished
torch.manual_seed(42)
env = SimpleEnv()
agent = DQNAgent(state_dim=4, n_actions=2)
losses = []
rewards_history = []

for episode in range(300):
    # Linear epsilon decay from 1.0, floored at 0.05
    epsilon = max(0.05, 1.0 - episode / 200)
    state = env.reset()
    episode_return = 0.0

    # Cap each episode at 50 environment steps
    for _ in range(50):
        action = agent.select_action(state, epsilon)
        next_state, reward, done = env.step(action)
        agent.push(state, action, reward, next_state, done)

        # update() returns None until the buffer holds a full batch
        step_loss = agent.update()
        if step_loss is not None:
            losses.append(step_loss)

        episode_return += reward
        state = next_state
        if done:
            break

    rewards_history.append(episode_return)

recent_rewards = rewards_history[-100:]
recent_losses = losses[-100:]
print(f"Final 100-ep avg reward: {sum(recent_rewards)/100:.4f}")
print(f"Final 100-step avg loss: {sum(recent_losses)/max(1,len(recent_losses)):.4f}")