Temporal Difference Error
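TD learning bootstraps: it updates a value estimate from other value estimates after every step, without waiting for the episode to finish. The TD error $\delta_t$ is the difference between the bootstrapped target $R_{t+1} + \gamma \max_a Q(S_{t+1}, a)$ and the current estimate $Q(S_t, A_t)$: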
$$\delta_t = R_{t+1} + \gamma \max_a Q(S_{t+1}, a) - Q(S_t, A_t)$$
This is the Q-learning TD error. SARSA replaces $\max_a Q(S_{t+1}, a)$ with $Q(S_{t+1}, A_{t+1})$ — the actual next action taken.
Q-Learning (Off-Policy)
Q-Update Rule
import torch
import random
class GridWorld:
    """Deterministic square GridWorld with the goal in the bottom-right cell.

    States are numbered row-major from 0 (top-left, the start state) to
    ``n_states - 1`` (bottom-right, the goal). Actions index a move table:
    0=up, 1=right, 2=down, 3=left. A move that would leave the grid is
    clamped, so the agent stays on the boundary cell.

    Args:
        size: Side length of the grid. The default of 4 gives the 4x4 grid
            with the goal at state 15.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """

    def __init__(self, size=4):
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        self.size = size
        self.n_states = size * size
        self.n_actions = 4
        self.goal = self.n_states - 1
        # (row delta, column delta) for each action: up, right, down, left.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def step(self, s, a):
        """Apply action ``a`` in state ``s``.

        Returns:
            ``(next_state, reward, done)``. Stepping from the goal is a no-op
            that returns reward 0.0 and ``done=True``. Otherwise the reward
            is 1.0 when the move enters the goal and -0.01 for any other
            move.
        """
        if s == self.goal:
            return self.goal, 0.0, True
        r, c = divmod(s, self.size)
        dr, dc = self._d[a]
        last = self.size - 1
        # Clamp each coordinate so moves into a wall leave it unchanged.
        nr = max(0, min(last, r + dr))
        nc = max(0, min(last, c + dc))
        ns = nr * self.size + nc
        done = ns == self.goal
        return ns, (1.0 if done else -0.01), done

    def reset(self):
        """Return the start state (0)."""
        return 0
def q_learning(env, n_episodes=2000, alpha=0.1, gamma=0.9,
               eps_start=1.0, eps_end=0.01, seed=42, max_steps=200):
    """Tabular Q-Learning: off-policy TD control.

    Update rule:
        Q(s,a) <- Q(s,a) + alpha * [R + gamma * max_a' Q(s',a') - Q(s,a)]

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()``
            returning a state index, and ``step(s, a)`` returning
            ``(next_state, reward, done)``.
        n_episodes: Number of training episodes.
        alpha: Step size of the update.
        gamma: Discount factor.
        eps_start: Exploration rate for the first episode.
        eps_end: Lower bound on the exploration rate.
        seed: Seed applied to both ``random`` and ``torch``.
        max_steps: Maximum number of steps per episode.

    Returns:
        ``(Q, episode_rewards)``: the ``n_states x n_actions`` table and the
        undiscounted reward sum of each episode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    Q = torch.zeros(env.n_states, env.n_actions)
    episode_rewards = []
    for ep in range(n_episodes):
        # Linear decay from eps_start towards eps_end over the whole run.
        eps = max(eps_end, eps_start - (eps_start - eps_end) * ep / n_episodes)
        s = env.reset()
        total_r = 0.0
        for _ in range(max_steps):
            # Epsilon-greedy behaviour policy; the random action spans the
            # environment's own action set rather than a fixed 0..3.
            if random.random() < eps:
                a = random.randint(0, env.n_actions - 1)
            else:
                a = Q[s].argmax().item()
            ns, r, done = env.step(s, a)
            # Off-policy target: bootstrap from the best next action, and
            # drop the bootstrap term entirely on terminal transitions.
            bootstrap = 0.0 if done else gamma * Q[ns].max().item()
            td_error = r + bootstrap - Q[s, a].item()
            Q[s, a] += alpha * td_error  # In-place scalar update
            total_r += r
            s = ns
            if done:
                break
        episode_rewards.append(total_r)
    return Q, episode_rewards
# Train on the grid, then show the greedy action of every cell, one printed
# list per grid row.
env = GridWorld()
Q, rewards = q_learning(env, n_episodes=3000)
print("Q-Learning converged policy (greedy):")
action_names = ['U', 'R', 'D', 'L']
for r in range(4):
    row_start = r * 4
    print([action_names[Q[state].argmax().item()]
           for state in range(row_start, row_start + 4)])
# Evaluate greedy policy
def evaluate_policy(env, Q, n_eval=100, max_steps=50):
    """Return the fraction of greedy rollouts that terminate in time.

    Args:
        env: Environment exposing ``reset()`` and ``step(s, a)`` returning
            ``(next_state, reward, done)``.
        Q: Action-value table indexable as ``Q[s]``; the greedy action is
            ``Q[s].argmax().item()``.
        n_eval: Number of evaluation rollouts; must be positive.
        max_steps: Step budget per rollout. A rollout that has not reported
            ``done`` within this budget counts as a failure.

    Returns:
        ``wins / n_eval`` as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``n_eval`` is not positive (previously a bare
            ZeroDivisionError for ``n_eval=0``).
    """
    if n_eval <= 0:
        raise ValueError(f"n_eval must be positive, got {n_eval}")
    wins = 0
    for _ in range(n_eval):
        s = env.reset()
        for _step in range(max_steps):
            s, _reward, done = env.step(s, Q[s].argmax().item())
            if done:
                wins += 1
                break
    return wins / n_eval
# Report the share of greedy evaluation rollouts that reach a terminal state
# and the mean episode reward over the last 500 training episodes.
print(f"\nSuccess rate: {evaluate_policy(env, Q)*100:.1f}%")
print(f"Avg reward (last 500 eps): {sum(rewards[-500:])/500:.4f}")
Convergence Across Learning Rates
import torch
import random
class GridWorld:
    """Deterministic 4x4 grid world.

    States are numbered 0-15 in row-major order and state 15 is terminal.
    Actions index the move table: 0=up, 1=right, 2=down, 3=left.
    """

    def __init__(self):
        self.n_states = 16
        self.n_actions = 4
        # (row delta, column delta) for each action index.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def reset(self):
        """Return the start state, 0."""
        return 0

    def step(self, s, a):
        """Apply action ``a`` in state ``s``; return ``(next_state, reward, done)``."""
        if s == 15:
            # The goal is absorbing and pays nothing further.
            return 15, 0.0, True
        row, col = divmod(s, 4)
        d_row, d_col = self._d[a]
        # Clamp to the grid so a move into a wall leaves the agent in place.
        row = max(0, min(3, row + d_row))
        col = max(0, min(3, col + d_col))
        nxt = 4 * row + col
        if nxt == 15:
            return nxt, 1.0, True
        return nxt, -0.01, False
# Sweep the Q-learning step size on the 4x4 grid. Every alpha trains for 3000
# episodes from the same RNG seed, then is scored by its mean reward over the
# last 500 training episodes and by greedy-rollout success rate.
env = GridWorld()
print("Learning rate comparison (3000 episodes, last 500 avg reward):")
print(f"{'Alpha':>6} | {'Avg Reward':>12} | {'Success Rate':>14}")
for alpha in [0.01, 0.05, 0.1, 0.3, 0.5, 0.9]:
    # Re-seed for each alpha so all runs start from the same RNG state.
    random.seed(42)
    torch.manual_seed(42)
    Q = torch.zeros(16, 4)
    total_rewards = []
    for ep in range(3000):
        # Exploration rate decays linearly from 1.0 and is floored at 0.01.
        eps = max(0.01, 1.0 - ep / 3000)
        s = 0
        tot = 0.0
        for _ in range(200):
            if random.random() < eps:
                a = random.randint(0, 3)
            else:
                a = Q[s].argmax().item()
            ns, r, done = env.step(s, a)
            # Q-learning TD error; the bootstrap term is zeroed when done.
            td = r + 0.9 * Q[ns].max().item() * (1 - float(done)) - Q[s, a].item()
            Q[s, a] += alpha * td
            tot += r
            s = ns
            if done:
                break
        total_rewards.append(tot)
    avg_r = sum(total_rewards[-500:]) / 500
    # Quick success eval: greedy rollouts of at most 50 steps from state 0.
    # (A broken one-liner that passed a lambda to range() used to sit here
    # and raised TypeError; this loop is the working evaluation.)
    ok = 0
    for _ in range(200):
        s2 = 0
        for _ in range(50):
            a2 = Q[s2].argmax().item()
            s2, _, d2 = env.step(s2, a2)
            if d2:
                ok += 1
                break
    print(f"{alpha:>6.2f} | {avg_r:>12.4f} | {ok/200*100:>13.1f}%")
SARSA (On-Policy)
import torch
import random
class GridWorld:
    """4x4 deterministic grid with a terminal goal at state 15.

    A state ``s`` sits at row ``s // 4`` and column ``s % 4``. The four
    actions are 0=up, 1=right, 2=down, 3=left.
    """

    def __init__(self):
        self.n_states = 16
        self.n_actions = 4
        # Per-action (row, column) offsets.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    @staticmethod
    def _clamp(v):
        """Limit a row or column index to the valid range 0..3."""
        if v < 0:
            return 0
        if v > 3:
            return 3
        return v

    def step(self, s, a):
        """Return ``(next_state, reward, done)`` for action ``a`` in state ``s``."""
        if s != 15:
            d_row, d_col = self._d[a]
            new_row = self._clamp(s // 4 + d_row)
            new_col = self._clamp(s % 4 + d_col)
            ns = new_row * 4 + new_col
            at_goal = ns == 15
            # Entering the goal pays 1.0; every other move costs 0.01.
            return ns, (1.0 if at_goal else -0.01), at_goal
        # Already at the goal: stay there, no reward, episode over.
        return 15, 0.0, True

    def reset(self):
        """Return the start state, 0."""
        return 0
def sarsa(env, n_episodes=3000, alpha=0.1, gamma=0.9,
          eps_start=1.0, eps_end=0.01, seed=42, max_steps=200):
    """Tabular SARSA: on-policy TD control.

    Update rule:
        Q(s,a) <- Q(s,a) + alpha * [R + gamma * Q(s', a') - Q(s,a)]
    Key difference from Q-learning: a' is the ACTUAL next action, drawn from
    the same epsilon-greedy policy that is being followed.

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()``
            returning a state index, and ``step(s, a)`` returning
            ``(next_state, reward, done)``.
        n_episodes: Number of training episodes.
        alpha: Step size of the update.
        gamma: Discount factor.
        eps_start: Exploration rate for the first episode.
        eps_end: Lower bound on the exploration rate.
        seed: Seed applied to both ``random`` and ``torch``.
        max_steps: Maximum number of steps per episode.

    Returns:
        ``(Q, episode_rewards)``: the ``n_states x n_actions`` table and the
        undiscounted reward sum of each episode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    Q = torch.zeros(env.n_states, env.n_actions)
    episode_rewards = []

    def eps_greedy(s, eps):
        # Explore uniformly over the environment's own action set (not a
        # fixed 0..3); otherwise act greedily on the current table.
        if random.random() < eps:
            return random.randint(0, env.n_actions - 1)
        return Q[s].argmax().item()

    for ep in range(n_episodes):
        # Linear decay from eps_start towards eps_end over the whole run.
        eps = max(eps_end, eps_start - (eps_start - eps_end) * ep / n_episodes)
        s = env.reset()
        a = eps_greedy(s, eps)  # Select FIRST action
        total_r = 0.0
        for _ in range(max_steps):
            ns, r, done = env.step(s, a)
            # The next action is sampled before `done` is checked; on a
            # terminal transition its value is excluded from the target.
            a_next = eps_greedy(ns, eps)
            bootstrap = 0.0 if done else gamma * Q[ns, a_next].item()
            Q[s, a] += alpha * (r + bootstrap - Q[s, a].item())
            total_r += r
            s, a = ns, a_next
            if done:
                break
        episode_rewards.append(total_r)
    return Q, episode_rewards
# Train SARSA with its default settings, print the greedy action of every
# cell (one list per grid row), then the mean reward of the last 500 episodes.
env = GridWorld()
Q_sarsa, rewards_sarsa = sarsa(env)
print("SARSA converged policy (greedy):")
action_names = ['U', 'R', 'D', 'L']
for r in range(4):
    row_start = r * 4
    print([action_names[Q_sarsa[state].argmax().item()]
           for state in range(row_start, row_start + 4)])
print(f"Avg reward (last 500 eps): {sum(rewards_sarsa[-500:])/500:.4f}")
Q-Learning vs SARSA Comparison
On-Policy vs Off-Policy TD Control
Q-Learning is off-policy: it learns the value of the greedy policy regardless of the behavior policy used for exploration. The target uses $\max_a Q(s', a)$.
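SARSA is on-policy: it learns the value of the current epsilon-greedy policy. The target uses $Q(s', a')$, where $a'$ is the action actually taken in $s'$. Because that target includes the cost of occasional exploratory actions, SARSA learns more conservative behavior near cliffs or other hazards.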
import torch
import random
class GridWorld:
    """Deterministic 4x4 grid; reaching state 15 ends the episode.

    States are laid out row-major (row ``s // 4``, column ``s % 4``).
    Actions: 0=up, 1=right, 2=down, 3=left.
    """

    _SIDE = 4   # cells per row and per column
    _GOAL = 15  # index of the terminal state

    def __init__(self):
        self.n_states, self.n_actions = 16, 4
        # Offsets applied to (row, column) by each action.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def step(self, s, a):
        """Take action ``a`` from state ``s`` and return ``(next_state, reward, done)``."""
        if s == self._GOAL:
            # No further movement or reward once the goal is reached.
            return self._GOAL, 0.0, True
        top = self._SIDE - 1
        coords = (s // self._SIDE, s % self._SIDE)
        # Shift each coordinate by its offset, clamped to stay on the grid.
        nr, nc = (max(0, min(top, pos + delta))
                  for pos, delta in zip(coords, self._d[a]))
        ns = nr * self._SIDE + nc
        finished = ns == self._GOAL
        reward = 1.0 if finished else -0.01
        return ns, reward, finished

    def reset(self):
        """Return the start state, 0."""
        return 0
def compare_convergence(env, n_episodes=3000):
    """Run Q-Learning and SARSA and compare convergence curves.

    Both methods train from the same seed with alpha=0.1, gamma=0.9, at most
    200 steps per episode, and an exploration rate that decays linearly from
    1.0 to a floor of 0.01. A table with one row per 500 episodes is printed.

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()``
            and ``step(s, a)`` returning ``(next_state, reward, done)``.
        n_episodes: Number of training episodes per method.

    Returns:
        Dict mapping ``'q-learning'`` and ``'sarsa'`` to a list of mean
        episode rewards over consecutive blocks of up to 100 episodes.
    """
    n_actions = env.n_actions

    def behave(Q, state, eps):
        # Epsilon-greedy over the environment's own action set.
        if random.random() < eps:
            return random.randint(0, n_actions - 1)
        return Q[state].argmax().item()

    results = {}
    for method in ['q-learning', 'sarsa']:
        # Same seed for both methods so they start from the same RNG state.
        random.seed(42)
        torch.manual_seed(42)
        Q = torch.zeros(env.n_states, n_actions)
        window = []
        for ep in range(n_episodes):
            eps = max(0.01, 1.0 - ep / n_episodes)
            s = env.reset()
            tot = 0.0
            a = behave(Q, s, eps)
            for _ in range(200):
                ns, r, done = env.step(s, a)
                # The next behaviour action is drawn before the update for
                # both methods; only SARSA uses it inside the target.
                a_next = behave(Q, ns, eps)
                if method == 'q-learning':
                    target = r + 0.9 * Q[ns].max().item() * (1 - float(done))
                else:
                    target = r + 0.9 * Q[ns, a_next].item() * (1 - float(done))
                Q[s, a] += 0.1 * (target - Q[s, a].item())
                tot += r
                s = ns
                a = a_next
                if done:
                    break
            window.append(tot)
        # Block averages; a trailing partial block is divided by its real
        # length rather than by 100.
        avg_by_100 = []
        for start in range(0, n_episodes, 100):
            chunk = window[start:start + 100]
            avg_by_100.append(sum(chunk) / len(chunk))
        results[method] = avg_by_100
    print(f"{'Episode':>8} | {'Q-Learning':>12} | {'SARSA':>10}")
    for i, ep in enumerate(range(100, n_episodes + 1, 100)):
        if i % 5 == 0:  # Print every 500 episodes
            print(f"{ep:>8} | {results['q-learning'][i]:>12.4f} | {results['sarsa'][i]:>10.4f}")
    return results
# Build a fresh grid and print the Q-Learning vs SARSA comparison table for
# the default number of episodes.
env = GridWorld()
compare_convergence(env)