Back to PyTorch Mastery Series

Q-Learning & Temporal Difference Methods

May 29, 2026 Wasil Zafar 28 min read

Implement Q-learning (off-policy) and SARSA (on-policy) from scratch, understand the TD error update rule, and compare how the two algorithms converge to different policies under the same epsilon-greedy behavior.

Table of Contents

  1. Temporal Difference Error
  2. Q-Learning (Off-Policy)
  3. SARSA (On-Policy)
  4. Q-Learning vs SARSA
  5. Related Articles

Temporal Difference Error

TD learning bootstraps: it updates estimates using other estimates, without waiting for episode completion. The TD error $\delta_t$ is the discrepancy between current and target estimates:

$$\delta_t = R_{t+1} + \gamma \max_a Q(S_{t+1}, a) - Q(S_t, A_t)$$

This is the Q-learning TD error. SARSA replaces $\max_a Q(S_{t+1}, a)$ with $Q(S_{t+1}, A_{t+1})$ — the actual next action taken.

Q-Learning (Off-Policy)

Q-Update Rule

import torch
import random


class GridWorld:
    """Deterministic 4x4 grid: start at state 0, goal at state 15 (bottom-right).

    States are numbered row-major (state = row * 4 + col). Actions index the
    move table: 0 = up, 1 = right, 2 = down, 3 = left. A move that would leave
    the grid keeps the agent on the edge cell.
    """

    def __init__(self):
        self.n_states = 16
        self.n_actions = 4
        # (row delta, col delta) for each action index.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def reset(self):
        """Return the start state (top-left corner)."""
        return 0

    def step(self, s, a):
        """Apply action ``a`` in state ``s``; return ``(next_state, reward, done)``.

        Entering the goal yields reward 1.0 and ends the episode; every other
        move costs -0.01. Stepping from the goal itself is a no-op with
        reward 0.0 and ``done=True``.
        """
        if s == 15:
            return 15, 0.0, True
        row, col = divmod(s, 4)
        d_row, d_col = self._d[a]
        # Clamp to [0, 3] so the agent cannot walk off the grid.
        row = min(3, max(0, row + d_row))
        col = min(3, max(0, col + d_col))
        nxt = row * 4 + col
        reached_goal = nxt == 15
        reward = 1.0 if reached_goal else -0.01
        return nxt, reward, reached_goal


def q_learning(env, n_episodes=2000, alpha=0.1, gamma=0.9,
               eps_start=1.0, eps_end=0.01, seed=42, max_steps=200):
    """
    Q-Learning: off-policy TD control.
    Update: Q(s,a) <- Q(s,a) + alpha * [R + gamma * max_a' Q(s',a') - Q(s,a)]

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()`` and
            ``step(s, a) -> (next_state, reward, done)``.
        n_episodes: Number of training episodes.
        alpha: Learning rate.
        gamma: Discount factor.
        eps_start: Exploration rate used for the first episode.
        eps_end: Floor for the linearly decayed exploration rate.
        seed: Seed applied to both ``random`` and ``torch``.
        max_steps: Maximum number of steps before an episode is cut off.

    Returns:
        ``(Q, episode_rewards)``: the learned ``(n_states, n_actions)`` table
        and the undiscounted reward sum of every episode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    Q = torch.zeros(env.n_states, env.n_actions)
    # Explore over the environment's own action set so the sampled index
    # always matches the width of the Q-table.
    last_action = env.n_actions - 1
    episode_rewards = []

    for ep in range(n_episodes):
        # Linear epsilon decay, clamped at eps_end.
        eps = max(eps_end, eps_start - (eps_start - eps_end) * ep / n_episodes)
        s = env.reset()
        total_r = 0.0

        for _ in range(max_steps):
            # Epsilon-greedy behavior policy.
            if random.random() < eps:
                a = random.randint(0, last_action)
            else:
                a = Q[s].argmax().item()
            ns, r, done = env.step(s, a)

            # Off-policy target: bootstrap from the best next action, and
            # drop the bootstrap term entirely on terminal transitions.
            td_target = r + gamma * Q[ns].max().item() * (1 - float(done))
            td_error = td_target - Q[s, a].item()
            Q[s, a] += alpha * td_error

            total_r += r
            s = ns
            if done:
                break

        episode_rewards.append(total_r)

    return Q, episode_rewards


# Train a tabular Q-learning agent on the 4x4 grid.
env = GridWorld()
Q, rewards = q_learning(env, n_episodes=3000)

# Action indices follow the environment's move table: up, right, down, left.
action_names = ['U', 'R', 'D', 'L']
print("Q-Learning converged policy (greedy):")
for row in range(4):
    # States are row-major, so row `row` covers states row*4 .. row*4 + 3.
    first = row * 4
    print([action_names[Q[state].argmax().item()] for state in range(first, first + 4)])

# Evaluate greedy policy
def evaluate_policy(env, Q, n_eval=100, max_steps=50):
    """Return the fraction of greedy rollouts that reach a terminal state.

    Args:
        env: Environment exposing ``reset()`` and
            ``step(s, a) -> (next_state, reward, done)``.
        Q: Action-value table indexed as ``Q[state]``; the greedy action is
            ``Q[state].argmax()``.
        n_eval: Number of rollouts to run.
        max_steps: Step budget per rollout; a rollout that has not finished
            within this budget counts as a failure.

    Returns:
        Success rate in ``[0.0, 1.0]``. Returns ``0.0`` when ``n_eval`` is not
        positive instead of dividing by zero.
    """
    if n_eval <= 0:
        return 0.0

    wins = 0
    for _ in range(n_eval):
        s = env.reset()
        for _step in range(max_steps):
            a = Q[s].argmax().item()
            s, _reward, done = env.step(s, a)
            if done:
                wins += 1
                break
    return wins / n_eval

# Summarise the trained agent: greedy success rate and recent training reward.
success_pct = evaluate_policy(env, Q) * 100
recent_avg = sum(rewards[-500:]) / 500
print(f"\nSuccess rate: {success_pct:.1f}%")
print(f"Avg reward (last 500 eps): {recent_avg:.4f}")

Convergence Across Learning Rates

import torch
import random


class GridWorld:
    """4x4 deterministic grid with the start at state 0 and the goal at state 15.

    Actions 0..3 move up, right, down and left; moves into a wall leave the
    agent where it is.
    """

    def __init__(self):
        self.n_states, self.n_actions = 16, 4
        # Per-action (row, col) offsets.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def step(self, s, a):
        """Return ``(next_state, reward, done)`` for taking action ``a`` in ``s``."""
        goal = 15
        if s == goal:
            # Already terminal: stay put, no reward.
            return goal, 0.0, True
        d_row, d_col = self._d[a]
        row = min(max(s // 4 + d_row, 0), 3)
        col = min(max(s % 4 + d_col, 0), 3)
        target = 4 * row + col
        if target == goal:
            return target, 1.0, True
        # Small step penalty for every non-goal move.
        return target, -0.01, False

    def reset(self):
        """Return the start state."""
        return 0


env = GridWorld()
print("Learning rate comparison (3000 episodes, last 500 avg reward):")
print(f"{'Alpha':>6} | {'Avg Reward':>12} | {'Success Rate':>14}")

for alpha in [0.01, 0.05, 0.1, 0.3, 0.5, 0.9]:
    # Reset both RNGs so every learning rate starts from the same random state.
    random.seed(42)
    torch.manual_seed(42)
    Q = torch.zeros(16, 4)
    total_rewards = []

    # Tabular Q-learning with gamma = 0.9; epsilon falls linearly from 1.0
    # and is floored at 0.01.
    for ep in range(3000):
        eps = max(0.01, 1.0 - ep / 3000)
        s = 0
        tot = 0.0
        for _ in range(200):
            if random.random() < eps:
                a = random.randint(0, 3)
            else:
                a = Q[s].argmax().item()
            ns, r, done = env.step(s, a)
            # Bootstrap from the best next action; no bootstrap on terminal steps.
            td = r + 0.9 * Q[ns].max().item() * (1 - float(done)) - Q[s, a].item()
            Q[s, a] += alpha * td
            tot += r
            s = ns
            if done:
                break
        total_rewards.append(tot)

    avg_r = sum(total_rewards[-500:]) / 500

    # Greedy evaluation: a rollout succeeds if it reaches the goal within 50 steps.
    ok = 0
    for _ in range(200):
        s2 = 0
        for _ in range(50):
            a2 = Q[s2].argmax().item()
            s2, _, d2 = env.step(s2, a2)
            if d2:
                ok += 1
                break
    print(f"{alpha:>6.2f} | {avg_r:>12.4f} | {ok/200*100:>13.1f}%")

SARSA (On-Policy)

import torch
import random


class GridWorld:
    """Deterministic 4x4 grid world; episodes start in state 0 and end in state 15.

    The state index is ``row * 4 + col``. Action ``a`` applies the offset
    ``self._d[a]`` (up, right, down, left), clipped to the grid.
    """

    def __init__(self):
        self.n_states = 16
        self.n_actions = 4
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def reset(self):
        """Return the initial state."""
        return 0

    def step(self, s, a):
        """Take action ``a`` from state ``s`` and return ``(next_state, reward, done)``."""
        if s == 15:
            # The goal is absorbing and pays nothing further.
            return 15, 0.0, True
        cur_row, cur_col = divmod(s, 4)
        move = self._d[a]
        new_row = max(0, min(3, cur_row + move[0]))
        new_col = max(0, min(3, cur_col + move[1]))
        nxt = new_row * 4 + new_col
        done = nxt == 15
        # +1.0 for arriving at the goal, otherwise a -0.01 step cost.
        return nxt, (1.0 if done else -0.01), done


def sarsa(env, n_episodes=3000, alpha=0.1, gamma=0.9,
          eps_start=1.0, eps_end=0.01, seed=42, max_steps=200):
    """
    SARSA: on-policy TD control.
    Update: Q(s,a) <- Q(s,a) + alpha * [R + gamma * Q(s', a') - Q(s,a)]
    Key difference: a' is the ACTUAL next action (same epsilon-greedy policy)

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()`` and
            ``step(s, a) -> (next_state, reward, done)``.
        n_episodes: Number of training episodes.
        alpha: Learning rate.
        gamma: Discount factor.
        eps_start: Exploration rate used for the first episode.
        eps_end: Floor for the linearly decayed exploration rate.
        seed: Seed applied to both ``random`` and ``torch``.
        max_steps: Maximum number of steps before an episode is cut off.

    Returns:
        ``(Q, episode_rewards)``: the learned ``(n_states, n_actions)`` table
        and the undiscounted reward sum of every episode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    Q = torch.zeros(env.n_states, env.n_actions)
    # Sample exploratory actions from the environment's own action set so the
    # index always matches the width of the Q-table.
    last_action = env.n_actions - 1
    episode_rewards = []

    def eps_greedy(s, eps):
        """Random action with probability ``eps``, else the greedy action."""
        if random.random() < eps:
            return random.randint(0, last_action)
        return Q[s].argmax().item()

    for ep in range(n_episodes):
        # Linear epsilon decay, clamped at eps_end.
        eps = max(eps_end, eps_start - (eps_start - eps_end) * ep / n_episodes)
        s = env.reset()
        a = eps_greedy(s, eps)  # Select FIRST action
        total_r = 0.0

        for _ in range(max_steps):
            ns, r, done = env.step(s, a)
            # The next action is chosen BEFORE the update and is then actually
            # executed on the following step -- this is what makes it on-policy.
            a_next = eps_greedy(ns, eps)

            # Bootstrap from the action that will really be taken; the
            # bootstrap term is dropped on terminal transitions.
            td_target = r + gamma * Q[ns, a_next].item() * (1 - float(done))
            Q[s, a] += alpha * (td_target - Q[s, a].item())

            total_r += r
            s, a = ns, a_next
            if done:
                break

        episode_rewards.append(total_r)

    return Q, episode_rewards


# Train a tabular SARSA agent on the 4x4 grid.
env = GridWorld()
Q_sarsa, rewards_sarsa = sarsa(env, n_episodes=3000)

# Action indices follow the environment's move table: up, right, down, left.
action_names = ['U', 'R', 'D', 'L']
print("SARSA converged policy (greedy):")
for row in range(4):
    # States are row-major, so row `row` covers states row*4 .. row*4 + 3.
    first = row * 4
    print([action_names[Q_sarsa[state].argmax().item()] for state in range(first, first + 4)])

recent_avg = sum(rewards_sarsa[-500:]) / 500
print(f"Avg reward (last 500 eps): {recent_avg:.4f}")

Q-Learning vs SARSA Comparison

Key Difference On-Policy vs Off-Policy

On-Policy vs Off-Policy TD Control

Q-Learning is off-policy: it learns the value of the greedy policy regardless of the behavior policy used for exploration. The target uses $\max_a Q(s', a)$.

SARSA is on-policy: it learns the value of the current epsilon-greedy policy. The target uses $Q(s', a')$, where $a'$ is the action actually taken in $s'$. Because the cost of exploratory actions is included in its estimates, SARSA is more conservative near cliffs or other hazards.

Off-Policy On-Policy Convergence
import torch
import random


class GridWorld:
    """Deterministic 4x4 grid: the agent starts at state 0 and must reach state 15.

    States are laid out row-major. The four actions (0 = up, 1 = right,
    2 = down, 3 = left) shift the agent by one cell, clipped at the borders.
    """

    def __init__(self):
        self.n_states = 16
        self.n_actions = 4
        # Row/column offsets, indexed by action.
        self._d = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def reset(self):
        """Return the start state (0)."""
        return 0

    def step(self, s, a):
        """Return ``(next_state, reward, done)`` after taking ``a`` in ``s``."""
        if s == 15:
            return 15, 0.0, True
        offset_r, offset_c = self._d[a]
        # Keep both coordinates inside 0..3.
        next_r = min(3, max(0, s // 4 + offset_r))
        next_c = min(3, max(0, s % 4 + offset_c))
        next_state = next_r * 4 + next_c
        at_goal = next_state == 15
        if at_goal:
            reward = 1.0
        else:
            reward = -0.01
        return next_state, reward, at_goal


def compare_convergence(env, n_episodes=3000, alpha=0.1, gamma=0.9, seed=42):
    """Run Q-Learning and SARSA and compare convergence curves.

    Both methods share one training loop, the same seed and the same
    epsilon-greedy behavior policy; they differ only in the bootstrap term of
    the TD target. Unlike the textbook Q-learning loop, the next action is
    sampled before the update here so that both methods consume the random
    stream identically.

    Args:
        env: Environment exposing ``n_states``, ``n_actions``, ``reset()`` and
            ``step(s, a) -> (next_state, reward, done)``.
        n_episodes: Number of training episodes per method.
        alpha: Learning rate.
        gamma: Discount factor.
        seed: Seed applied to ``random`` and ``torch`` before each method.

    Returns:
        Dict mapping ``'q-learning'`` and ``'sarsa'`` to a list of average
        episode rewards, one entry per block of up to 100 consecutive
        episodes. A summary table is also printed.
    """
    block = 100
    results = {}

    def pick_action(Q, state, eps):
        """Epsilon-greedy choice over the environment's action set."""
        if random.random() < eps:
            return random.randint(0, env.n_actions - 1)
        return Q[state].argmax().item()

    for method in ['q-learning', 'sarsa']:
        random.seed(seed)
        torch.manual_seed(seed)
        Q = torch.zeros(env.n_states, env.n_actions)
        returns = []

        for ep in range(n_episodes):
            # Linear epsilon decay from 1.0, floored at 0.01.
            eps = max(0.01, 1.0 - ep / n_episodes)
            s = env.reset()
            tot = 0.0
            a = pick_action(Q, s, eps)

            for _ in range(200):
                ns, r, done = env.step(s, a)
                a_next = pick_action(Q, ns, eps)

                if method == 'q-learning':
                    # Off-policy: bootstrap from the greedy next action.
                    bootstrap = Q[ns].max().item()
                else:
                    # On-policy: bootstrap from the action that will be taken.
                    bootstrap = Q[ns, a_next].item()
                target = r + gamma * bootstrap * (1 - float(done))

                Q[s, a] += alpha * (target - Q[s, a].item())
                tot += r
                s, a = ns, a_next
                if done:
                    break
            returns.append(tot)

        # Block averages; divide by the real block length so a trailing
        # partial block (n_episodes not a multiple of 100) is not understated.
        chunks = [returns[i:i + block] for i in range(0, n_episodes, block)]
        results[method] = [sum(chunk) / len(chunk) for chunk in chunks]

    print(f"{'Episode':>8} | {'Q-Learning':>12} | {'SARSA':>10}")
    for i, q_avg in enumerate(results['q-learning']):
        if i % 5 == 0:  # Print every fifth block (every 500 episodes)
            ep = min((i + 1) * block, n_episodes)
            print(f"{ep:>8} | {q_avg:>12.4f} | {results['sarsa'][i]:>10.4f}")

    return results


# Run the side-by-side Q-learning / SARSA comparison on a fresh environment.
env = GridWorld()
compare_convergence(env, n_episodes=3000)