
Part 7: Transformers & Attention Mechanisms

May 3, 2026 Wasil Zafar 35 min read

Understand the architecture that powers modern AI — build attention mechanisms, positional encodings, Transformer encoder & decoder blocks, Vision Transformers, and efficient attention from scratch in PyTorch.

Table of Contents

  1. The Attention Revolution
  2. Self-Attention: Building Intuition
  3. Scaled Dot-Product Attention
  4. Multi-Head Attention
  5. Self-Attention vs Cross-Attention
  6. Positional Encoding
  7. Transformer Encoder Block
  8. Transformer Decoder Block
  9. Full Transformer Model
  10. Attention Masking
  11. Vision Transformers (ViT)
  12. Sparse & Efficient Attention
  13. Conclusion & Next Steps

The Attention Revolution

In Part 6, we built RNNs and LSTMs that process sequences one token at a time, carrying a hidden state forward like a person reading a book word by word while trying to remember everything. This works, but it has a critical bottleneck: the entire context of a 500-word paragraph must be squeezed into a single fixed-size hidden vector. By the time the RNN reaches the end of a long sequence, the information from the beginning has been diluted through hundreds of sequential updates. This is the information bottleneck problem.

In 2017, a team at Google published "Attention Is All You Need" — a paper that introduced the Transformer architecture. The key insight was radical: throw away recurrence entirely. Instead of processing tokens sequentially, let every token directly attend to every other token in the sequence simultaneously. This parallel processing not only solved the information bottleneck but also made training dramatically faster because GPUs excel at parallel matrix operations, not sequential loops.

Key Insight: Transformers replaced the sequential processing of RNNs with parallel attention — every token can directly look at every other token in a single step. This eliminated the information bottleneck and unlocked massive parallelism on GPUs. Every modern large language model (GPT, BERT, LLaMA) and many vision models (ViT, DINO) are built on Transformers.

The impact was seismic. Within two years, Transformers replaced RNNs in virtually every NLP task. Today, the best-known large language models (GPT-4, Claude, LLaMA, Gemini) are all Transformers. The architecture has also conquered computer vision (Vision Transformers), audio (Whisper), protein structure prediction (AlphaFold), and even reinforcement learning (Decision Transformers). Understanding Transformers is now one of the most important skills in deep learning.

Let's build one from scratch, starting with the core mechanism that makes it all work: attention.

Self-Attention: Building Intuition

Before diving into the full scaled dot-product formula, let's build intuition by implementing the simplest possible self-attention — no trainable weights, no scaling, no multi-head splitting. This is the core idea stripped to its essence: each token computes a relevance score with every other token, normalizes those scores with softmax, and then takes a weighted average of all token representations.

Attention Without Trainable Weights

The simplest self-attention computes attention scores as the raw dot product between token embeddings. If two tokens have similar embeddings (pointing in the same direction), their dot product is large, and they attend strongly to each other. This is conceptually equivalent to measuring "how similar are these two tokens?"

$$\omega_{ij} = x_i \cdot x_j = \sum_{k=1}^{d} x_{ik} \, x_{jk}$$ $$\alpha_{ij} = \frac{\exp(\omega_{ij})}{\sum_{k=1}^{n} \exp(\omega_{ik})} \quad \text{(softmax normalization)}$$ $$z_i = \sum_{j=1}^{n} \alpha_{ij} \, x_j \quad \text{(context vector)}$$

Where $x_i$ is the embedding of token $i$, $\omega_{ij}$ is the raw attention score, $\alpha_{ij}$ is the normalized attention weight, and $z_i$ is the output — a weighted combination of all tokens, biased toward the most "relevant" ones.

Worked Example: Step-by-Step Calculation
Tracing the Equations with Real Numbers

Let's work through all three equations manually using a tiny 3-token sequence with 2-dimensional embeddings so every number stays readable:

Given token embeddings ($d = 2$):

$$x_0 = [1.0,\ 0.0] \ \text{("The")} \qquad x_1 = [0.8,\ 0.6] \ \text{("cat")} \qquad x_2 = [0.0,\ 1.0] \ \text{("sat")}$$

Step 1 — Compute raw scores $\omega_{ij} = x_i \cdot x_j$

Pick token $i = 1$ ("cat") and compute its score against every token:

$$\omega_{10} = x_1 \cdot x_0 = (0.8)(1.0) + (0.6)(0.0) = 0.80$$ $$\omega_{11} = x_1 \cdot x_1 = (0.8)(0.8) + (0.6)(0.6) = 0.64 + 0.36 = 1.00$$ $$\omega_{12} = x_1 \cdot x_2 = (0.8)(0.0) + (0.6)(1.0) = 0.60$$

So token 1 attends most to itself (score 1.00), then to token 0 (0.80), then to token 2 (0.60). This makes sense: apart from itself, $x_1$ points closer to $x_0$ than to $x_2$.

Step 2 — Softmax normalization $\alpha_{ij}$

Exponentiate each raw score and divide by the row sum:

$$e^{0.80} \approx 2.226 \qquad e^{1.00} \approx 2.718 \qquad e^{0.60} \approx 1.822$$ $$\text{sum} = 2.226 + 2.718 + 1.822 = 6.766$$ $$\alpha_{10} = \frac{2.226}{6.766} \approx 0.329 \qquad \alpha_{11} = \frac{2.718}{6.766} \approx 0.402 \qquad \alpha_{12} = \frac{1.822}{6.766} \approx 0.269$$

All three weights are positive and sum to exactly 1.0 — a valid probability distribution over the sequence.

Step 3 — Context vector $z_1 = \sum_j \alpha_{1j}\, x_j$

$$z_1 = 0.329 \cdot [1.0,\ 0.0]\ +\ 0.402 \cdot [0.8,\ 0.6]\ +\ 0.269 \cdot [0.0,\ 1.0]$$ $$z_1 = \underbrace{[0.329,\ 0.000]}_{\text{from "The"}}\ +\ \underbrace{[0.322,\ 0.241]}_{\text{from "cat" (self)}}\ +\ \underbrace{[0.000,\ 0.269]}_{\text{from "sat"}}$$ $$z_1 = [0.651,\ 0.510]$$

Token 1's output $z_1 = [0.651,\ 0.510]$ is no longer just its own embedding $[0.8,\ 0.6]$. Let's read what each neighbour actually contributed:

  • "The" contributes $[0.329,\ 0.000]$ — attends with weight 0.329, the second highest. "The" and "cat" are moderately similar (score 0.80), so "cat" borrows a meaningful share of "The"'s direction, adding length to the first dimension and nothing to the second.
  • "cat" contributes $[0.322,\ 0.241]$: self-attention weight 0.402 (the largest). That is expected here because all three embeddings have unit length, so a token's dot product with itself (1.00) is the highest score it can get. This is why $z_1$ stays close to $x_1$. With embeddings of different lengths, a token does not necessarily score highest against itself.
  • "sat" contributes $[0.000,\ 0.269]$: attends with weight 0.269 (the lowest). $x_2 = [0, 1]$ is perpendicular to $x_0 = [1, 0]$ and overlaps only partially with $x_1$ (score 0.60). Its entire contribution lands in the second dimension, so "sat" is the only neighbour feeding the second axis.

Net effect — both dimensions decrease from the original embedding:

  • Dim 1: 0.8 → 0.651 — "cat" self-contributes $0.402 \times 0.8 = 0.322$, "The" adds $0.329 \times 1.0 = 0.329$, "sat" adds $0.269 \times 0.0 = 0.000$. Total = 0.651. Although "The" has a strong dim-1 value (1.0), the blend pulls the result below 0.8 because "sat" contributes nothing to dim 1.
  • Dim 2: 0.6 → 0.510. "cat" self-contributes $0.402 \times 0.6 = 0.241$, "The" adds $0.329 \times 0.0 = 0.000$, "sat" adds $0.269 \times 1.0 = 0.269$. Total = 0.510. "Sat" has the maximum value in dim 2 (1.0), but its attention weight is the lowest (0.269), so its boost ($+0.269$) does not make up for what "cat" gives up by keeping only 40.2% of its own value ($0.6 - 0.241 = 0.359$), and "The" adds nothing. The result ends up below 0.6.

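In short: in this example neither dimension rises above the original, but that is not a general rule. Because the attention weights are positive and sum to 1, each output component is a weighted average and always lies between the smallest and largest value of that component across the sequence, so it can land above or below the token's own value. What changes is the direction: $z_1 = [0.651, 0.510]$ points somewhere between $x_0$, $x_1$, and $x_2$, blending all three tokens' geometry into one enriched vector.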

Effect of similarity on attention weights: Compare $x_0 = [1, 0]$ vs $x_2 = [0, 1]$. These two vectors are perpendicular, so when token 0 attends to token 2 the raw score is $\omega_{02} = (1)(0) + (0)(1) = 0.00$. Softmax never assigns exactly zero, but this is the smallest weight in token 0's row: about 0.17, compared with 0.46 for itself and 0.37 for token 1. Tokens pointing in different directions attend to each other the least.
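The hand calculation above can be double-checked in a few lines of PyTorch. This is a minimal sketch using the same three 2-dimensional embeddings; the variable names are illustrative. Because it keeps full precision instead of rounding the weights to three decimals first, the first component of $z_1$ comes out as roughly 0.650 rather than the hand-rounded 0.651.

```python
import torch
import torch.nn.functional as F

# Same embeddings as the worked example: "The", "cat", "sat"
x = torch.tensor([
    [1.0, 0.0],  # x_0
    [0.8, 0.6],  # x_1
    [0.0, 1.0],  # x_2
])

scores = x @ x.T                     # omega_ij = x_i . x_j, shape [3, 3]
weights = F.softmax(scores, dim=-1)  # alpha_ij, each row sums to 1
context = weights @ x                # z_i = sum_j alpha_ij * x_j, shape [3, 2]

print("omega_1j:", scores[1])
print("alpha_1j:", weights[1])
print("z_1:     ", context[1])
print("alpha_0j:", weights[0])       # token 0's row, for the perpendicular case
```

The larger 4-token example that follows uses the same three steps on embeddings that are not unit length.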

import torch
import torch.nn.functional as F

# Simple sentence: 4 tokens, each represented by a 3-dimensional embedding
# (In practice, embeddings are 512-1024 dims; we use 3 for clarity)
embeddings = torch.tensor([
    [0.43, 0.15, 0.89],  # Token 0: "The"
    [0.55, 0.87, 0.66],  # Token 1: "cat"
    [0.57, 0.85, 0.64],  # Token 2: "sat"
    [0.22, 0.58, 0.33],  # Token 3: "down"
])

# Step 1: Compute all pairwise dot products (similarity scores)
# Each entry [i, j] = how much token i attends to token j
scores = embeddings @ embeddings.T  # [4, 4]
print("Raw attention scores (dot products):")
print(scores.round(decimals=3))

# Step 2: Normalize each row with softmax (weights sum to 1)
attention_weights = F.softmax(scores, dim=-1)
print("\nAttention weights (after softmax):")
print(attention_weights.round(decimals=3))
print("Row sums:", attention_weights.sum(dim=-1))

# Step 3: Compute context vectors (weighted sum of all embeddings)
context_vectors = attention_weights @ embeddings  # [4, 3]
print("\nContext vectors (output):")
print(context_vectors.round(decimals=3))

# Interpretation: context_vectors[1] is "cat" enriched with info from similar tokens
# "cat" and "sat" have similar embeddings, so they attend strongly to each other
print(f"\n'cat' attends to 'sat' with weight: {attention_weights[1, 2]:.3f}")
print(f"'cat' attends to 'down' with weight: {attention_weights[1, 3]:.3f}")
Key Insight: Even without any learnable parameters, this simple mechanism already captures something meaningful: tokens with similar embeddings attend to each other. The word "cat" attends more strongly to "sat" (similar vector direction) than to "down". The output for each token becomes a context-enriched version of itself, blending information from the most relevant neighbors.

Why We Need Q, K, V Projections

The simple version above has a major limitation: each token uses the same vector for three different roles — as the thing doing the searching, the thing being searched for, and the thing being retrieved. Look back at the worked example: "cat" was represented by the single vector $x_1 = [0.8,\ 0.6]$, and that same vector played all three roles simultaneously:

  • Searcher (query): $x_1 = [0.8,\ 0.6]$ was dot-producted against every other token to produce the raw scores $\omega_{10}, \omega_{11}, \omega_{12}$.
  • Searched (key): $x_1 = [0.8,\ 0.6]$ was what other tokens compared against when deciding how much to attend to "cat" — i.e., $\omega_{01} = x_0 \cdot x_1$, $\omega_{11} = x_1 \cdot x_1$, etc.
  • Retrieved (value): $x_1 = [0.8,\ 0.6]$ was the vector that got weighted and summed into the context vectors — the $\alpha \cdot x_1$ term inside every $z_i$.

In the code block, this shows up as embeddings @ embeddings.T (using embeddings as both query and key) and then attention_weights @ embeddings (using the same embeddings as values). One matrix, three jobs. In practice, these roles should be different. Consider the word "it" in "The cat sat on the mat because it was tired" — what "it" is looking for (an antecedent noun) is different from what "it" advertises (a pronoun) and what "it" contains (its semantic meaning).

This is why we introduce three separate learnable linear projections:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

The Query ($Q$) represents "what am I looking for?", the Key ($K$) represents "what do I contain that others might want?", and the Value ($V$) represents "what information do I actually carry?". By learning separate projections, the model can attend based on one criterion (Q·K similarity) while retrieving different information (V). This separation is what gives Transformers their remarkable expressive power.

Worked Example: Step-by-Step Q K V Calculation
Why the Same Vector Can't Play Three Roles — Shown with Numbers

We'll use 2 tokens with 2-dimensional embeddings and fixed (not yet trained) $2\times2$ weight matrices so every multiplication is visible. The sentence fragment is "bank loan" — "bank" is intentionally ambiguous (river bank? financial bank?), which makes it a perfect case to see how separate projections let the model route information differently.

Input embeddings ($d_{in} = 2$):

$$x_0 = [1.0,\ 0.5] \quad \text{("bank")} \qquad x_1 = [0.2,\ 0.9] \quad \text{("loan")}$$

Weight matrices ($d_{in} \times d_k = 2 \times 2$):

$$W_Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \quad W_K = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix} \quad W_V = \begin{bmatrix} 2 & 0 \\ 0 & 2 \end{bmatrix}$$

Note: $W_Q$ is the identity (queries = raw embeddings), $W_K$ swaps dimensions (keys emphasise the opposite direction), and $W_V$ scales up by 2 (values carry amplified content). These three matrices do completely different things — impossible to achieve with one shared matrix.

Step 1 — Compute Queries ($Q = X W_Q$)

$W_Q$ is the identity matrix, so each query equals the raw embedding:

$$q_0 = x_0 \cdot W_Q = [1.0,\ 0.5] \cdot \begin{bmatrix}1&0\\0&1\end{bmatrix} = [(1.0)(1)+(0.5)(0),\ (1.0)(0)+(0.5)(1)] = [1.0,\ 0.5]$$ $$q_1 = x_1 \cdot W_Q = [0.2,\ 0.9] \cdot \begin{bmatrix}1&0\\0&1\end{bmatrix} = [(0.2)(1)+(0.9)(0),\ (0.2)(0)+(0.9)(1)] = [0.2,\ 0.9]$$

Step 2 — Compute Keys ($K = X W_K$)

$W_K$ swaps dimensions — the two columns are exchanged, so the second input feature becomes the first output feature and vice versa:

$$k_0 = x_0 \cdot W_K = [1.0,\ 0.5] \cdot \begin{bmatrix}0&1\\1&0\end{bmatrix} = [(1.0)(0)+(0.5)(1),\ (1.0)(1)+(0.5)(0)] = [0.5,\ 1.0]$$ $$k_1 = x_1 \cdot W_K = [0.2,\ 0.9] \cdot \begin{bmatrix}0&1\\1&0\end{bmatrix} = [(0.2)(0)+(0.9)(1),\ (0.2)(1)+(0.9)(0)] = [0.9,\ 0.2]$$

Step 3 — Compute Values ($V = X W_V$)

$W_V$ is a diagonal matrix that scales every dimension by 2, amplifying the content each token carries:

$$v_0 = x_0 \cdot W_V = [1.0,\ 0.5] \cdot \begin{bmatrix}2&0\\0&2\end{bmatrix} = [(1.0)(2)+(0.5)(0),\ (1.0)(0)+(0.5)(2)] = [2.0,\ 1.0]$$ $$v_1 = x_1 \cdot W_V = [0.2,\ 0.9] \cdot \begin{bmatrix}2&0\\0&2\end{bmatrix} = [(0.2)(2)+(0.9)(0),\ (0.2)(0)+(0.9)(2)] = [0.4,\ 1.8]$$

Step 4 — Scaled dot-product scores for token 0 ("bank")

$q_0 = [1.0, 0.5]$ is compared against every key. Scaling factor: $\sqrt{d_k} = \sqrt{2} \approx 1.414$.

$$s_{00} = \frac{q_0 \cdot k_0}{\sqrt{2}} = \frac{(1.0)(0.5)+(0.5)(1.0)}{1.414} = \frac{0.50 + 0.50}{1.414} = \frac{1.00}{1.414} \approx 0.707$$ $$s_{01} = \frac{q_0 \cdot k_1}{\sqrt{2}} = \frac{(1.0)(0.9)+(0.5)(0.2)}{1.414} = \frac{0.90 + 0.10}{1.414} = \frac{1.00}{1.414} \approx 0.707$$

Step 5 — Softmax to get attention weights

Exponentiate each score, then divide by the row sum:

$$e^{0.707} \approx 2.028 \qquad \text{row sum} = 2.028 + 2.028 = 4.056$$ $$\alpha_{00} = \frac{2.028}{4.056} = 0.50 \qquad \alpha_{01} = \frac{2.028}{4.056} = 0.50$$

"bank" pays equal attention to itself and to "loan" because both scores are identical. After training, $W_Q$ and $W_K$ would shift these scores to resolve the ambiguity.

Step 6 — Context vector $z_0$

Combine the value vectors using the attention weights:

$$z_0 = \alpha_{00} \cdot v_0 + \alpha_{01} \cdot v_1 = 0.50 \cdot [2.0,\ 1.0] + 0.50 \cdot [0.4,\ 1.8]$$ $$z_0 = [1.00,\ 0.50] + [0.20,\ 0.90] = [1.20,\ 1.40]$$

Compare the raw embedding $x_0 = [1.0, 0.5]$ with the context vector $z_0 = [1.2, 1.4]$. The output has been enriched with "loan"'s value vector — the model is now blending financial context into "bank". After training, if "loan" should dominate, $W_K$ would shift so $q_0 \cdot k_1 \gg q_0 \cdot k_0$, pulling the weight toward 1.0.

Why one shared matrix can't do this: In the no-weights version, the score $x_0 \cdot x_0 = 1.0^2 + 0.5^2 = 1.25$ is always larger than $x_0 \cdot x_1 = (1.0)(0.2)+(0.5)(0.9) = 0.65$ — so "bank" always attends more to itself. With separate $W_Q$ and $W_K$, the scores became equal here and can be steered arbitrarily by training. The Value projection further decouples what you retrieve from how you searched.
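The "bank loan" calculation can be reproduced directly. The sketch below plugs in the same fixed (untrained) matrices from the worked example; nothing is learned here, and the variable names are only for illustration.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[1.0, 0.5],    # x_0 = "bank"
                  [0.2, 0.9]])   # x_1 = "loan"

# Fixed (untrained) matrices from the worked example
W_Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # identity
W_K = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # swaps the two dimensions
W_V = torch.tensor([[2.0, 0.0], [0.0, 2.0]])  # scales by 2

Q, K, V = X @ W_Q, X @ W_K, X @ W_V

d_k = K.size(-1)
scores = (Q @ K.T) / d_k ** 0.5      # scaled dot-product scores
weights = F.softmax(scores, dim=-1)  # attention weights per row
Z = weights @ V                      # context vectors

print("scores:\n", scores)
print("weights:\n", weights)
print("z_0:", Z[0])

# For comparison: raw dot products without any projections
raw = X @ X.T
print("no-projection scores:\n", raw)
```

The second row is worth a look as well: with these projections "loan" scores "bank" (1.00 before scaling) above itself (0.36), whereas the raw dot products rank each token's own embedding first.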

import torch
import torch.nn as nn
import torch.nn.functional as F

# Now with LEARNABLE projections (Q, K, V)
d_in = 3    # input embedding dimension
d_out = 2   # output/attention dimension (can differ from input)

# Learnable weight matrices (these are trained via backpropagation)
W_Q = nn.Parameter(torch.randn(d_in, d_out))
W_K = nn.Parameter(torch.randn(d_in, d_out))
W_V = nn.Parameter(torch.randn(d_in, d_out))

# Same embeddings as before
embeddings = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
])

# Project into Q, K, V spaces
Q = embeddings @ W_Q  # [4, 2] — what each token is looking for
K = embeddings @ W_K  # [4, 2] — what each token advertises
V = embeddings @ W_V  # [4, 2] — what each token carries

print(f"Q shape: {Q.shape}")  # [4, 2]
print(f"K shape: {K.shape}")  # [4, 2]
print(f"V shape: {V.shape}")  # [4, 2]

# Attention with scaling (the "Scaled Dot-Product" we'll formalize next)
d_k = Q.size(-1)
scores = (Q @ K.T) / (d_k ** 0.5)
weights = F.softmax(scores, dim=-1)
output = weights @ V

print(f"\nAttention output shape: {output.shape}")  # [4, 2]
print(f"Attention weights:\n{weights.round(decimals=3)}")

This is exactly the Scaled Dot-Product Attention formula — we've just arrived at it by building up from the simplest possible version. Now let's formalize it and handle batching.

Scaled Dot-Product Attention

Attention answers a simple question: "Given what I'm looking for, which parts of the input are most relevant?" Think of searching a library. You arrive with a query (your question: "books about medieval history"), you compare your query against every book's key (its title and description), and when a key matches well, you retrieve that book's value (its actual content). Attention works exactly the same way — with learned linear projections replacing the librarian.

The Attention Formula

The mathematical formulation is elegant. Given matrices Q (queries), K (keys), and V (values), attention is computed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

The division by $\sqrt{d_k}$ (the square root of the key dimension) is the "scaling" part. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, so for large dimensions the raw scores grow large and push softmax into regions where its gradients are tiny, which effectively stalls learning. Dividing by $\sqrt{d_k}$ brings the variance back to approximately 1 and keeps softmax in its healthy gradient zone.
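The variance argument is easy to check empirically. The following is a rough sketch, assuming query and key components are independent standard-normal values (an idealisation; real activations only approximate this). It measures the variance of the dot product before and after dividing by $\sqrt{d_k}$ for a few sizes.

```python
import torch

torch.manual_seed(0)
n_pairs = 10_000

for d_k in (8, 64, 512):
    # Assumption: components are i.i.d. with zero mean and unit variance
    q = torch.randn(n_pairs, d_k)
    k = torch.randn(n_pairs, d_k)

    dots = (q * k).sum(dim=-1)   # one dot product per (q, k) pair
    scaled = dots / d_k ** 0.5   # the "scaled" part of the formula

    print(f"d_k={d_k:4d}  var(q.k)={dots.var().item():8.2f}  "
          f"var(q.k / sqrt(d_k))={scaled.var().item():.3f}")
```

The unscaled variance should track $d_k$, while the scaled variance should stay close to 1 regardless of dimension.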

Implementing Scaled Dot-Product Attention

Let's implement the attention formula from scratch. We create query, key, and value tensors, compute their dot product, scale it, apply softmax to get attention weights, and finally multiply by values to get the output. This is the fundamental building block of every Transformer. Before the code, a short aside on a term we have been using.

What is a "linear projection"?

In deep learning, "projection" is used loosely for any learned linear map that takes a vector from one space into another. (In strict linear algebra a projection is an idempotent map, $P^2 = P$; the usage here is broader.) The word "linear" means the mapping respects addition and scaling, i.e., it is entirely described by multiplying by a matrix. So a linear projection is simply a matrix multiplication:

$$q_i = x_i \cdot W_Q$$

where $x_i \in \mathbb{R}^{d_{in}}$ is the original token embedding and $W_Q \in \mathbb{R}^{d_{in} \times d_k}$ is a learnable weight matrix. The result $q_i \in \mathbb{R}^{d_k}$ is the token "projected" into a new space of dimension $d_k$.
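To connect the formula to PyTorch, here is a small sketch showing that a bias-free nn.Linear layer computes exactly this matrix product. The sizes and the name proj_q are illustrative choices, not taken from the article.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_k, seq_len = 6, 4, 3   # illustrative sizes

x = torch.randn(seq_len, d_in)              # token embeddings
proj_q = nn.Linear(d_in, d_k, bias=False)   # learnable projection (hypothetical name)

q_layer = proj_q(x)                 # idiomatic: call the layer
q_matmul = x @ proj_q.weight.T      # the same thing written as q_i = x_i W_Q

print(proj_q.weight.shape)          # nn.Linear stores its weight as [d_k, d_in]
print(q_layer.shape)                # [seq_len, d_k]
print(torch.allclose(q_layer, q_matmul, atol=1e-6))
```

Note the transpose: $W_Q$ in the formula corresponds to proj_q.weight.T, because nn.Linear keeps the output dimension first.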

Why "projection"? Think of shining a light on a 3D object and looking at its 2D shadow — the shadow is a projection of the object. The weight matrix $W_Q$ acts like the angle of the light: it decides which aspects of the embedding to emphasise and which to compress away. Three different matrices $W_Q, W_K, W_V$ create three different "shadows" of the same embedding, each optimised for a different role:

  • $W_Q$ projects into a space where "what am I searching for?" is most readable.
  • $W_K$ projects into a space where "what do I offer to others?" is most readable.
  • $W_V$ projects into a space that carries "what information should I pass on?"

In the code below, torch.randn(seq_len, d_k) simulates what you would get after that multiplication. In a real Transformer, you would instead create a layer such as proj_q = nn.Linear(d_in, d_k, bias=False) and compute Q = proj_q(x), which is the same as x @ proj_q.weight.T because nn.Linear stores its weight with shape [d_k, d_in]. The projection is learnable: during training, gradients flow back through the matrix multiplication and update $W_Q, W_K, W_V$ so the three projected spaces become increasingly useful.

import torch
import torch.nn.functional as F

# Simulating a sequence of 4 tokens, each with 8 dimensions
seq_len = 4
d_k = 8  # dimension of keys/queries

# Random Q, K, V matrices (in practice, these come from linear projections)
Q = torch.randn(seq_len, d_k)  # [4, 8] — what each token is looking for
K = torch.randn(seq_len, d_k)  # [4, 8] — what each token advertises
V = torch.randn(seq_len, d_k)  # [4, 8] — what each token actually contains

# Step 1: Compute raw attention scores (dot product of Q and K^T)
scores = torch.matmul(Q, K.transpose(-2, -1))  # [4, 4]
print("Raw scores shape:", scores.shape)
print("Raw scores:\n", scores.round(decimals=2))

# Step 2: Scale by sqrt(d_k) to prevent vanishing gradients
scale = d_k ** 0.5
scaled_scores = scores / scale
print("\nScaled scores:\n", scaled_scores.round(decimals=2))

# Step 3: Apply softmax to get attention weights (rows sum to 1)
attention_weights = F.softmax(scaled_scores, dim=-1)
print("\nAttention weights (rows sum to 1):\n", attention_weights.round(decimals=2))
print("Row sums:", attention_weights.sum(dim=-1).round(decimals=2))

# Step 4: Multiply weights by values to get the output
output = torch.matmul(attention_weights, V)  # [4, 8]
print("\nOutput shape:", output.shape)
print("Output:\n", output.round(decimals=2))

Each row in the attention weights matrix tells us how much each token attends to every other token. For example, if row 0 has weights [0.1, 0.6, 0.2, 0.1], token 0 pays 60% of its attention to token 1 and distributes the remaining 40% across the others. The output for token 0 is then a weighted average of all value vectors, with those weights determining the mix.

Experiment
Why Scaling Matters

Try removing the / scale step and increasing d_k to 512. Without scaling, the raw dot products have a standard deviation of about $\sqrt{512} \approx 22.6$, so scores of ±20 or more are common, and softmax saturates, producing weights like [0.0, 0.0, 1.0, 0.0]. This "winner-take-all" behavior destroys gradient flow. With scaling, the scores have roughly unit variance, keeping softmax in its useful range where it can express partial attention across multiple tokens.
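One way to run this experiment is sketched below. It assumes standard-normal queries and keys, and uses the average of each row's largest attention weight as a rough measure of saturation (a value near 1.0 means winner-take-all). The trial count and metric are choices made for this sketch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_queries, seq_len, d_k = 1000, 4, 512

Q = torch.randn(n_queries, 1, d_k)        # one query per trial
K = torch.randn(n_queries, seq_len, d_k)  # four keys per trial

scores = (Q @ K.transpose(-2, -1)).squeeze(1)   # [n_queries, seq_len]

w_unscaled = F.softmax(scores, dim=-1)
w_scaled = F.softmax(scores / d_k ** 0.5, dim=-1)

# Average of the largest weight in each row
peak_unscaled = w_unscaled.max(dim=-1).values.mean().item()
peak_scaled = w_scaled.max(dim=-1).values.mean().item()

print(f"score std (unscaled): {scores.std().item():.1f}")
print(f"mean max weight, unscaled: {peak_unscaled:.3f}")
print(f"mean max weight, scaled:   {peak_scaled:.3f}")
```

The unscaled rows should be dominated by a single key, while the scaled rows spread their weight across several keys.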


Packaging Attention as a Reusable Function

Let's wrap the attention logic into a clean function that also supports an optional mask. The mask is crucial for preventing the model from attending to certain positions (like future tokens in language generation or padding tokens in batched inputs):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Queries [batch, heads, seq_len, d_k]
        K: Keys    [batch, heads, seq_len, d_k]
        V: Values  [batch, heads, seq_len, d_v]
        mask: Optional mask [batch, 1, 1, seq_len] or [1, 1, seq_len, seq_len]

    Returns:
        output: Weighted sum of values [batch, heads, seq_len, d_v]
        weights: Attention weights [batch, heads, seq_len, seq_len]
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # Replace masked positions with -inf so softmax gives them zero weight
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output, weights

# Demo: batch of 2 sequences, 1 head, 5 tokens, 16 dims
batch, heads, seq_len, d_k = 2, 1, 5, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)    # [2, 1, 5, 16]
print("Weights shape:", weights.shape)  # [2, 1, 5, 5]
print("Weights row sum:", weights.sum(dim=-1)[0, 0].round(decimals=2))  # all 1.0

This function handles batched inputs with multiple attention heads. The mask parameter will become essential when we implement causal masking for decoders and padding masking for variable-length sequences later in this article.
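As a preview of how the mask argument is used, here is a minimal sketch with a causal (lower-triangular) mask. To keep the block self-contained it restates the function in compact form under the hypothetical name attention; the mask convention (1 = may attend, 0 = blocked) follows the mask == 0 check in the function above.

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Compact restatement of scaled_dot_product_attention from above
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 1, 5, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Lower-triangular mask: 1 = may attend, 0 = blocked (future position)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

output, weights = attention(Q, K, V, mask=causal_mask)
print(causal_mask[0, 0])
print(weights[0, 0])   # upper triangle is zero: no token sees the future
```

The [1, 1, seq_len, seq_len] mask broadcasts over the batch and head dimensions. One caveat of this approach: a row that is masked everywhere would produce NaN after softmax, so every query needs at least one allowed key.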

Multi-Head Attention

The Core Idea (Plain English)

A single attention head produces one attention pattern per token, so it tends to capture one kind of relationship at a time. But language is multi-dimensional. When processing "bank", you simultaneously need to resolve whether it means finance or a river (semantic head), what the subject is (syntactic head), and where it appears (positional head). Multi-head attention runs several "attention spotlights" in parallel, each learning to focus on different things.

The Best Analogy: A Committee of Experts

Imagine 8 experts reading the same sentence simultaneously:

  • Expert 1 focuses on subject-verb relationships
  • Expert 2 focuses on adjective-noun pairs
  • Expert 3 focuses on coreference ("it" → "the cat")
  • Expert 4 focuses on negation scope
  • ...each expert independently decides what to attend to

After all experts finish, their findings are concatenated and merged into one rich representation per token. That's multi-head attention.

Ultra-compressed version (pseudocode, not runnable as written):

# Multi-head = split → attend separately → concatenate → merge
heads = split(Q, K, V, into=8)            # 64-dim → eight 8-dim slices
outputs = [attention(h.q, h.k, h.v) for h in heads]  # 8 parallel attentions
merged = linear(concat(outputs))           # combine back to 64-dim

To make this concrete, take the word "bank" in "I sat on the river bank": one head might attend to "river" to disambiguate meaning, another might attend to "sat" for grammatical structure, and a third might attend to "I" for subject-verb agreement. Multi-Head Attention (MHA) runs these heads in parallel, each learning to focus on a different type of relationship.

The mechanism is straightforward: split Q, K, V into h heads along the feature dimension, run independent attention on each head, concatenate the results, and project back to the original dimension. Each head gets d_k = d_model / h dimensions, so the total computation is roughly the same as a single attention over the full d_model dimensions; we just partition the representation space.
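The split and merge steps are pure reshapes, which the sketch below traces on a random tensor (no attention is computed here). The sizes match the d_model = 64, 8-head setup used in this section; the variable names are illustrative.

```python
import torch

batch, seq_len, d_model, num_heads = 2, 10, 64, 8
d_k = d_model // num_heads   # 8 dimensions per head

x = torch.randn(batch, seq_len, d_model)

# Split: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
split = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
print("split: ", split.shape)

# Head h sees features h*d_k .. (h+1)*d_k - 1 of every token
print(torch.equal(split[:, 3], x[:, :, 3 * d_k : 4 * d_k]))

# Merge: undo the transpose, then flatten the head axis back into d_model
merged = split.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print("merged:", merged.shape)
print(torch.equal(merged, x))
```

The .contiguous() call is needed because transpose returns a non-contiguous view, and .view requires a compatible memory layout.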

Implementing Multi-Head Attention From Scratch

This implementation mirrors the original paper. We create three linear projections (for Q, K, V), split them into heads, compute attention in parallel, concatenate, and project back. Pay attention to the reshape and transpose operations — they are the key to making multi-head attention work:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head

        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Step 1: Linear projections
        Q = self.W_q(query)  # [batch, seq_len, d_model]
        K = self.W_k(key)
        V = self.W_v(value)

        # Step 2: Split into heads — reshape to [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Step 3: Scaled dot-product attention on each head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # [batch, heads, seq_len, d_k]

        # Step 4: Concatenate heads — reshape back to [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # Step 5: Final linear projection
        return self.W_o(attn_output)

# Demo
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
output = mha(x, x, x)  # self-attention: Q=K=V=x
print("Input shape: ", x.shape)     # [2, 10, 64]
print("Output shape:", output.shape) # [2, 10, 64]
print("Params:", sum(p.numel() for p in mha.parameters()))  # 4 * 64 * 64 + 4 * 64 = 16640

Notice that the input and output shapes are identical — [batch, seq_len, d_model]. This is by design: it means we can stack attention layers on top of each other without any dimension mismatches. Each of the 8 heads operates on 8-dimensional slices (64/8 = 8), learning to attend to different types of relationships within those slices.

Using PyTorch's Built-In nn.MultiheadAttention

PyTorch provides nn.MultiheadAttention as a highly optimized built-in module. It expects inputs in the shape [seq_len, batch, d_model] by default (note the transposed order compared to our implementation). You can set batch_first=True to use [batch, seq_len, d_model] instead. Here's how to use it:

import torch
import torch.nn as nn

# PyTorch's built-in multi-head attention
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

# Input: batch=2, seq_len=10, embed_dim=64
x = torch.randn(2, 10, 64)

# Self-attention: query=key=value=x
# Returns: (attention_output, attention_weights)
attn_output, attn_weights = mha(query=x, key=x, value=x)
print("Output shape:", attn_output.shape)    # [2, 10, 64]
print("Weights shape:", attn_weights.shape)  # [2, 10, 10], averaged over the 8 heads

# Attention weights show how each token attends to others
print("Token 0 attends to all 10 tokens:", attn_weights[0, 0].round(decimals=3))
print("Weights sum per row:", attn_weights[0].sum(dim=-1).round(decimals=2))

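The built-in version also supports key_padding_mask and attn_mask parameters for masking, which we will use when building the full Transformer. For production code, prefer nn.MultiheadAttention: it is well tested and can dispatch to fused attention kernels that are considerably faster than a manual implementation. According to the PyTorch documentation, passing need_weights=False lets it use the optimized scaled_dot_product_attention path.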
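Here is a brief sketch of those two mask parameters in action, assuming a toy batch where the last two tokens of the first sequence are padding. One thing to watch: for boolean masks, nn.MultiheadAttention treats True as "blocked", which is the opposite polarity of the mask == 0 convention in our from-scratch function.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)

# Pretend the last two tokens of sequence 0 are padding.
# For a boolean key_padding_mask, True means "ignore this key".
key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
key_padding_mask[0, 8:] = True

# A boolean attn_mask uses the same polarity: True = not allowed to attend
causal_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)

out, w = mha(x, x, x,
             key_padding_mask=key_padding_mask,
             attn_mask=causal_mask,
             average_attn_weights=False)   # keep one weight matrix per head

print(out.shape)      # [2, 10, 64]
print(w.shape)        # [2, 8, 10, 10]
print(w[0, 0, 9])     # last query of sequence 0: zeros at key positions 8 and 9
```

Setting average_attn_weights=False returns the weights per head instead of the head-averaged matrix, which is handy for inspecting what individual heads attend to.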

Self-Attention vs Cross-Attention

Self-attention is when the query, key, and value all come from the same source. When we wrote mha(x, x, x) above, that was self-attention — each token in the sequence asks "which other tokens in my own sequence should I pay attention to?" This is used in both the encoder (every source token attends to every other source token) and the decoder's first attention layer (every target token attends to previously generated target tokens).

Cross-attention is when the query comes from one source and the key/value come from a different source. In a translation Transformer, the decoder uses cross-attention: the query is the decoder's current state ("what am I looking for in the source?"), while the key and value come from the encoder's output ("here's the source sentence's representation"). This is how the decoder "reads" the input while generating the output.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

# ---- Self-Attention ----
# All from the same source (e.g., encoder input)
encoder_input = torch.randn(2, 10, 64)  # 10-token source
self_attn_out, _ = mha(
    query=encoder_input,
    key=encoder_input,
    value=encoder_input
)
print("Self-attention output:", self_attn_out.shape)  # [2, 10, 64]

# ---- Cross-Attention ----
# Query from decoder, Key/Value from encoder
decoder_state = torch.randn(2, 8, 64)   # 8-token target (being generated)
encoder_output = torch.randn(2, 10, 64) # 10-token source (already encoded)
cross_attn_out, cross_weights = mha(
    query=decoder_state,     # decoder asks the questions
    key=encoder_output,      # encoder provides the keys
    value=encoder_output     # encoder provides the values
)
print("Cross-attention output:", cross_attn_out.shape)   # [2, 8, 64]
print("Cross-attention weights:", cross_weights.shape)    # [2, 8, 10]
# Each of 8 decoder tokens attends to all 10 encoder tokens

The cross-attention weight matrix has shape [batch, target_len, source_len] — you can visualize it as a heatmap showing which source words each target word focuses on. In machine translation, this often reveals interpretable alignment patterns (e.g., the French word "chat" strongly attends to the English word "cat").

Remember: Self-attention asks "what in my own sequence is relevant?" Cross-attention asks "what in the other sequence is relevant?" Encoder blocks use self-attention. Decoder blocks use both — masked self-attention first, then cross-attention to the encoder output.

Positional Encoding

The Core Idea (Plain English)

Attention is order-blind: shuffle the tokens and each one still gets the same output, just in the shuffled order. But "dog bites man" and "man bites dog" have very different meanings! Positional encoding solves this by stamping each token with a position signal before attention sees it.

The Best Analogy: Seat Numbers in a Theater

Imagine a group conversation where everyone speaks simultaneously (like attention):

  • Without positional encoding — you hear all voices at once but can't tell who spoke first, second, third (word salad)
  • With positional encoding — each speaker holds up a seat number card, so you know the order even though they all speak at once

The encoding is added to each token's embedding: input = word_meaning + position_signal. The model learns to use both signals.

Ultra-compressed version:

# Without position: "cat sat mat" == "mat cat sat" (same attention output!)
# With position:   "cat(pos=0) sat(pos=1) mat(pos=2)" ≠ "mat(pos=0) cat(pos=1) sat(pos=2)"

# The fix: add a unique position vector to each token embedding
token_with_position = word_embedding + position_encoding[position_index]

RNNs naturally know word order because they process tokens one at a time: position 0 is processed before position 1, which is processed before position 2, and so on. Transformers have no such built-in ordering. Because attention computes all pairwise interactions in parallel, feeding "mat cat sat" instead of "cat sat mat" gives every word exactly the same output vector as before, only in the new order (attention is permutation-equivariant), unless we explicitly inject position information. Without positional encoding, a Transformer is effectively a bag-of-words model: it cannot distinguish "dog bites man" from "man bites dog".
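The order-blindness is easy to check numerically. The sketch below is illustrative only: the tensor sizes, the permutation, and the random pos tensor (a stand-in for a real positional encoding) are our assumptions. It shuffles a sequence, runs it through nn.MultiheadAttention, and compares the outputs with and without a position signal.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)

x = torch.randn(1, 5, 16)              # 5 tokens, no position information
perm = torch.tensor([4, 2, 0, 3, 1])   # an arbitrary reordering
x_shuffled = x[:, perm]

with torch.no_grad():
    out, _ = mha(x, x, x)
    out_shuffled, _ = mha(x_shuffled, x_shuffled, x_shuffled)

# Shuffling the input just shuffles the output rows in the same way
same_up_to_order = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

# Adding a position-dependent signal breaks that symmetry
pos = torch.randn(1, 5, 16)            # stand-in for a positional encoding
x_pos = x + pos
x_shuffled_pos = x_shuffled + pos      # same tokens, now sitting at new positions
with torch.no_grad():
    out_pos, _ = mha(x_pos, x_pos, x_pos)
    out_shuffled_pos, _ = mha(x_shuffled_pos, x_shuffled_pos, x_shuffled_pos)
still_same = torch.allclose(out_pos[:, perm], out_shuffled_pos, atol=1e-5)

print("Without positions, outputs match up to order:", same_up_to_order)
print("With positions, outputs still match:", still_same)
```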

Sinusoidal Positional Encoding

The original Transformer paper uses sinusoidal functions to encode position. Each position gets a unique vector of sine and cosine values at different frequencies. This scheme has two appealing properties: (1) it is defined for every position, so it can be evaluated at sequence lengths longer than those seen during training (the paper hypothesized that this may help the model extrapolate), and (2) the model can learn to attend to relative positions because the dot product of encodings at positions i and j is a function of i - j. Here's the implementation:

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a matrix of shape [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        pe = pe.unsqueeze(0)  # [1, max_len, d_model] for broadcasting

        # Register as buffer (not a parameter — no gradient updates)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Demo
d_model = 64
pe_module = PositionalEncoding(d_model=d_model, max_len=100)

# Simulate an embedded sequence (batch=1, seq_len=10, d_model=64)
token_embeddings = torch.randn(1, 10, d_model)
encoded = pe_module(token_embeddings)
print("Input shape: ", token_embeddings.shape)  # [1, 10, 64]
print("Output shape:", encoded.shape)           # [1, 10, 64]

# The positional encoding for position 0 vs position 9 are very different
print("\nPosition 0 encoding (first 8 dims):", pe_module.pe[0, 0, :8].round(decimals=3))
print("Position 9 encoding (first 8 dims):", pe_module.pe[0, 9, :8].round(decimals=3))

The positional encoding is added to the token embeddings, not concatenated. This means each dimension of the embedding now carries both semantic meaning (from the learned embedding) and positional information (from the sinusoidal encoding). The register_buffer call ensures the encoding tensor moves to GPU with the model but is not updated by the optimizer — positions are deterministic, not learned.
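To see the relative-position property in numbers, here is a minimal sketch that rebuilds the same sinusoidal table in a standalone helper (sinusoidal_table is our name, not a PyTorch function) and compares dot products between positions. Pairs with the same offset should give the same value no matter where they sit in the sequence.

```python
import math
import torch

def sinusoidal_table(max_len, d_model):
    """Standalone copy of the sinusoidal table, shape [max_len, d_model]."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_table(max_len=100, d_model=64)

# Same offset (4 positions apart), different absolute positions
dot_3_7 = torch.dot(pe[3], pe[7])
dot_50_54 = torch.dot(pe[50], pe[54])

# A different offset (17 positions apart) gives a different value
dot_3_20 = torch.dot(pe[3], pe[20])

# Each (sin, cos) pair contributes sin^2 + cos^2 = 1, so the squared norm is d_model / 2
self_dot = torch.dot(pe[0], pe[0])

print("offset 4:", dot_3_7.item(), dot_50_54.item())
print("offset 17:", dot_3_20.item())
print("squared norm:", self_dot.item())
```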

Experiment
Learned vs Fixed Positional Embeddings

Models like BERT and GPT-2 use learned positional embeddings (nn.Embedding(max_len, d_model)) instead of sinusoidal encodings. Learned embeddings can adapt to the data, but they have no entry for positions beyond max_len, so they cannot handle longer sequences than seen during training. Sinusoidal encodings can be computed for any length because they're based on continuous functions, although model quality on much longer sequences is not guaranteed. For most practical purposes, both approaches perform similarly on in-distribution sequence lengths.
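A learned variant can be sketched in a few lines. The class name LearnedPositionalEmbedding and its defaults are hypothetical, chosen to mirror the PositionalEncoding module above rather than any specific model's implementation.

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    """Hypothetical helper: one trainable vector per position index."""
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        positions = torch.arange(x.size(1), device=x.device)  # [seq_len]
        return self.dropout(x + self.pos_embed(positions))    # broadcasts over batch

learned_pe = LearnedPositionalEmbedding(d_model=64, max_len=100)
tokens = torch.randn(2, 10, 64)
encoded = learned_pe(tokens)
print("Output shape:", encoded.shape)  # [2, 10, 64]

# Unlike the sinusoidal buffer, these positions are trainable parameters
num_pos_params = sum(p.numel() for p in learned_pe.parameters())
print("Trainable position parameters:", num_pos_params)  # 100 * 64 = 6400

# A sequence longer than max_len has no embedding rows to look up
try:
    learned_pe(torch.randn(1, 101, 64))
    too_long_ok = True
except (IndexError, RuntimeError):
    too_long_ok = False
print("Length 101 with max_len=100 works:", too_long_ok)
```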


Transformer Encoder Block

The Core Idea (Plain English)

An encoder block is a two-step refinement loop: first, every token talks to every other token (attention); then, every token processes its own information independently (feed-forward). Both steps have "safety rails" (residual connections + normalization) that prevent things from going wrong in deep networks.

The Best Analogy: A Workshop with Two Stations

Think of each encoder block as a two-station workshop that each token passes through:

  • Station 1 (Attention) — "Group discussion" — every token consults all other tokens and updates itself with relevant info
  • Station 2 (FFN) — "Private study" — each token independently refines its own representation through a small neural network
  • Safety rails (Residual + Norm) — "Always keep your original notes" — each station adds new info rather than replacing old info

Stack 6–12 of these blocks, and tokens get progressively richer representations. That's a Transformer encoder.

Ultra-compressed version:

# One encoder block in pseudocode:
x = x + attention(x)    # Step 1: talk to neighbors, add result
x = x + ffn(x)          # Step 2: think alone, add result
# Repeat 6-12 times (stacked blocks)

Now we have all the ingredients to build a Transformer encoder block. Each block consists of two sub-layers, each wrapped with a residual connection and layer normalization. The first sub-layer is multi-head self-attention (tokens attend to each other). The second sub-layer is a position-wise feed-forward network (each token is independently transformed through two linear layers with a ReLU activation). The residual connections ensure that gradients can flow directly through the network even when it is very deep.

Transformer Encoder Block Architecture
flowchart TD
    Input["Input Embeddings\n+ Positional Encoding"] --> MHA["Multi-Head\nSelf-Attention"]
    MHA --> Add1["Add & Norm\n(Residual + LayerNorm)"]
    Input --> Add1
    Add1 --> FFN["Feed-Forward Network\nLinear → ReLU → Linear"]
    FFN --> Add2["Add & Norm\n(Residual + LayerNorm)"]
    Add1 --> Add2
    Add2 --> Output["Encoder Output"]
                            

Why Residual (Shortcut) Connections Work

Every sub-layer in a Transformer block is wrapped with a residual connection: the output is x + SubLayer(x) rather than just SubLayer(x). This seemingly simple addition is critical for training deep networks (Transformers can be 96+ layers). The key insight: instead of learning the full desired transformation $H(x)$, the sub-layer only needs to learn the residual $F(x) = H(x) - x$ — the difference from the identity. If the optimal transformation is close to identity (often the case), learning a near-zero residual is much easier than learning a full identity mapping from scratch.

$$\text{output} = x + \text{SubLayer}(\text{LayerNorm}(x))$$

During backpropagation, the gradient flows through two paths: through the sub-layer AND directly through the addition, whose local gradient is the identity. This "gradient highway" helps prevent vanishing gradients even in 100+ layer models. Without residual connections, deep Transformer stacks become very difficult to train.
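The gradient highway can be illustrated with an extreme toy case: a sub-layer whose weights are all zero. This is a contrived setup chosen for clarity, not something that occurs in a trained model. Without the residual connection no gradient reaches the input; with it, the identity path still delivers a gradient of 1.

```python
import torch
import torch.nn as nn

# A sub-layer whose weights are all zero: it contributes nothing to the output
sublayer = nn.Linear(4, 4)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, 4, requires_grad=True)

# Without a residual connection, the gradient w.r.t. x vanishes
sublayer(x).sum().backward()
grad_plain = x.grad.clone()
x.grad.zero_()

# With a residual connection, the identity path still carries gradient 1
(x + sublayer(x)).sum().backward()
grad_residual = x.grad.clone()

print("Gradient without residual:\n", grad_plain)    # all zeros
print("Gradient with residual:\n", grad_residual)    # all ones
```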

Pre-Norm vs Post-Norm: The original "Attention Is All You Need" paper placed LayerNorm after the residual addition (Post-LN): LayerNorm(x + SubLayer(x)). Modern architectures like GPT-2, LLaMA, and most current models use Pre-LN: x + SubLayer(LayerNorm(x)). Pre-LN tends to train more stably and is less dependent on the careful learning rate warmup that Post-LN typically needs.
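The two orderings differ only in where the normalization sits. In the sketch below, post_ln and pre_ln are our own helper names, and a single nn.Linear stands in for the attention or feed-forward sub-layer; the input is deliberately scaled up so the difference in output scale is visible.

```python
import torch
import torch.nn as nn

def post_ln(x, sublayer, norm):
    """Original paper: normalize after the residual addition."""
    return norm(x + sublayer(x))

def pre_ln(x, sublayer, norm):
    """GPT-2 style: normalize the sub-layer input, leave the skip path untouched."""
    return x + sublayer(norm(x))

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN

x = 5.0 * torch.randn(2, 4, d_model)    # deliberately large activations
y_post = post_ln(x, sublayer, norm)
y_pre = pre_ln(x, sublayer, norm)

# Post-LN re-normalizes every token; Pre-LN keeps the scale of x on the skip path
print("Post-LN mean per-token std:", y_post.std(dim=-1).mean().item())
print("Pre-LN mean per-token std: ", y_pre.std(dim=-1).mean().item())
```

Because the Pre-LN skip path is never normalized, Pre-LN stacks usually apply one final LayerNorm after the last block.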

Building an Encoder Block From Scratch

Here is a complete, self-contained encoder block. We use pre-layer normalization (applying LayerNorm before each sub-layer), which is the approach used by GPT-2 and most modern Transformers because it stabilizes training. The original paper used post-layer normalization (after each sub-layer), but pre-norm has been shown to train more reliably:

import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )

        # Feed-forward network (two linear layers with ReLU)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

        # Layer normalization (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout for residual connections
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        # Sub-layer 1: Multi-Head Self-Attention + Residual + Norm
        x_norm = self.norm1(x)
        attn_output, _ = self.self_attn(
            x_norm, x_norm, x_norm,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask
        )
        x = x + self.dropout(attn_output)  # residual connection

        # Sub-layer 2: Feed-Forward Network + Residual + Norm
        x_norm = self.norm2(x)
        ffn_output = self.ffn(x_norm)
        x = x + ffn_output  # residual connection

        return x

# Demo: process a batch of sequences through one encoder block
encoder_block = TransformerEncoderBlock(d_model=64, num_heads=8, d_ff=256, dropout=0.1)
x = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
output = encoder_block(x)
print("Input shape: ", x.shape)      # [2, 10, 64]
print("Output shape:", output.shape)  # [2, 10, 64]
print("Parameters:  ", sum(p.numel() for p in encoder_block.parameters()))

The feed-forward network (FFN) is often called the "expansion layer" because d_ff is typically 4× larger than d_model. For example, BERT-base uses d_model=768 and d_ff=3072. This expansion gives each token more capacity to transform its representation before projecting back down. Notice that the FFN is applied independently to each token position — there is no information exchange between tokens in this sub-layer.
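Both points can be verified with a short sketch: running the FFN on a single position gives the same result as running it on the whole sequence, and the parameter count follows directly from the two linear layers. Dropout is left out here (unlike the block above) so that the comparison is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 64, 256
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),   # expand 64 -> 256
    nn.ReLU(),
    nn.Linear(d_ff, d_model),   # project back 256 -> 64
)

x = torch.randn(2, 10, d_model)
full = ffn(x)            # all 10 positions at once
token3 = ffn(x[:, 3])    # position 3 on its own

# The FFN never mixes information across positions
position_wise = torch.allclose(full[:, 3], token3, atol=1e-5)
print("Position 3 matches when processed alone:", position_wise)

# Two weight matrices plus two bias vectors
ffn_params = sum(p.numel() for p in ffn.parameters())
print("FFN parameters:", ffn_params)  # 64*256 + 256 + 256*64 + 64 = 33088
```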

Transformer Decoder Block

The Core Idea (Plain English)

The decoder is like the encoder but with a blindfold: when generating token 5, it's not allowed to peek at tokens 6, 7, 8... (because those haven't been generated yet). It also has an extra step where it "reads" the encoder's output to understand the source.

The Best Analogy: Writing a Translation with a Reference Book

Imagine translating English to French, one word at a time:

  • Masked self-attention — "Look at what I've written so far" (can't see future words you haven't written yet)
  • Cross-attention — "Look at the English source text" (read the original to decide what to write next)
  • FFN — "Think about what word comes next" (private processing)

The causal mask is the "blindfold" that enforces left-to-right generation. Without it, the decoder would cheat by looking at future tokens during training.

The decoder block is similar to the encoder block but with one crucial addition: a cross-attention layer that attends to the encoder's output. The decoder also uses masked self-attention in its first sub-layer, where a causal mask prevents each position from attending to future positions. This is essential for autoregressive generation: when predicting token 5, the model should only see tokens 0–4, not token 5 itself or anything after it.

Building a Decoder Block

The decoder block has three sub-layers instead of two: (1) masked self-attention, (2) cross-attention to encoder output, and (3) feed-forward network. Each sub-layer has its own residual connection and layer normalization:

import torch
import torch.nn as nn

class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Masked self-attention (causal — no future peeking)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )

        # Cross-attention (decoder queries attend to encoder keys/values)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

        # Layer norms (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Sub-layer 1: Masked Self-Attention
        x_norm = self.norm1(x)
        self_attn_out, _ = self.self_attn(
            x_norm, x_norm, x_norm, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        x = x + self.dropout(self_attn_out)

        # Sub-layer 2: Cross-Attention to encoder output
        x_norm = self.norm2(x)
        cross_attn_out, _ = self.cross_attn(
            query=x_norm,                # decoder asks the questions
            key=encoder_output,          # encoder provides context
            value=encoder_output,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        x = x + self.dropout(cross_attn_out)

        # Sub-layer 3: Feed-Forward Network
        x_norm = self.norm3(x)
        ffn_out = self.ffn(x_norm)
        x = x + ffn_out

        return x

# Demo
decoder_block = TransformerDecoderBlock(d_model=64, num_heads=8, d_ff=256)
tgt = torch.randn(2, 8, 64)           # target sequence (8 tokens)
encoder_out = torch.randn(2, 10, 64)  # encoder output (10 tokens)

# Create causal mask: True = blocked, False = allowed.
# nn.MultiheadAttention treats True as "do not attend", so we mark the
# strict upper triangle (future positions) as True.
seq_len = 8
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

output = decoder_block(tgt, encoder_out, tgt_mask=causal_mask)
print("Decoder input shape: ", tgt.shape)          # [2, 8, 64]
print("Encoder output shape:", encoder_out.shape)   # [2, 10, 64]
print("Decoder output shape:", output.shape)        # [2, 8, 64]

The causal mask is an upper triangular matrix of True values — positions that should be blocked. In PyTorch's nn.MultiheadAttention, True in the attention mask means "do not attend here." So the upper triangle (future positions) is True (blocked), and the lower triangle plus diagonal (past and current positions) is False (allowed). This ensures position i can only attend to positions 0 through i.
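A quick way to convince yourself that the mask works is to change a future token and confirm that earlier outputs do not move. The sketch below uses a bare nn.MultiheadAttention layer with illustrative sizes rather than the full decoder block.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
seq_len = 6
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # True = blocked

x = torch.randn(1, seq_len, 16)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(16)   # overwrite only the last token

with torch.no_grad():
    out, weights = mha(x, x, x, attn_mask=causal_mask)
    out_changed, _ = mha(x_changed, x_changed, x_changed, attn_mask=causal_mask)

# Positions 0-4 never look at position 5, so their outputs are unaffected
earlier_unchanged = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5)
# Position 5 attends to itself, so its output does change
last_changed = not torch.allclose(out[:, -1], out_changed[:, -1], atol=1e-5)
# All attention weight on future positions is zero
future_weight = weights[0].triu(diagonal=1).abs().max()

print("Earlier outputs unchanged:", earlier_unchanged)
print("Last output changed:", last_changed)
print("Largest weight on a future position:", future_weight.item())
```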

Full Transformer Model

A complete Transformer stacks multiple encoder and decoder blocks on top of each other. The original paper used 6 of each. The encoder processes the entire source sequence into a rich representation, and the decoder uses that representation (via cross-attention) to generate the output sequence one token at a time. Let's build the full architecture from our blocks:

Full Transformer Architecture (Encoder-Decoder)
flowchart TD
    subgraph Encoder["Encoder (Nx)"]
        SrcEmb["Source\nEmbedding"] --> PE1["+ Positional\nEncoding"]
        PE1 --> EB1["Encoder Block 1"]
        EB1 --> EB2["Encoder Block 2"]
        EB2 --> EBN["... Encoder Block N"]
    end

    subgraph Decoder["Decoder (Nx)"]
        TgtEmb["Target\nEmbedding"] --> PE2["+ Positional\nEncoding"]
        PE2 --> DB1["Decoder Block 1"]
        DB1 --> DB2["Decoder Block 2"]
        DB2 --> DBN["... Decoder Block N"]
    end

    EBN -->|"encoder output\n(cross-attention)"| DB1
    EBN -->|"encoder output"| DB2
    EBN -->|"encoder output"| DBN
    DBN --> Linear["Linear\nProjection"]
    Linear --> Softmax["Softmax\nOutput Probabilities"]
                            

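Building a Full Transformer

We combine token embeddings, positional encoding, encoder and decoder stacks, and a final output projection into a complete sequence-to-sequence Transformer. To keep the listing short, the stacks below use PyTorch's built-in nn.TransformerEncoderLayer and nn.TransformerDecoderLayer with norm_first=True, which follow the same pre-norm layout as the blocks we wrote by hand. This model can be used for machine translation, text summarization, or any encoder-decoder task: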

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=128, num_heads=8,
                 num_layers=3, d_ff=512, dropout=0.1, max_len=200):
        super().__init__()

        # Token embeddings
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        self.scale = math.sqrt(d_model)

        # Encoder and Decoder stacks (using PyTorch's built-in layers)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output projection
        self.output_proj = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, tgt_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        # Embed and encode source
        src_emb = self.pos_enc(self.src_embed(src) * self.scale)
        memory = self.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

        # Embed and decode target
        tgt_emb = self.pos_enc(self.tgt_embed(tgt) * self.scale)
        decoder_out = self.decoder(
            tgt_emb, memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        return self.output_proj(decoder_out)

# Demo: a small Transformer
model = Transformer(src_vocab=1000, tgt_vocab=1000, d_model=128,
                    num_heads=8, num_layers=3, d_ff=512)

src = torch.randint(0, 1000, (2, 15))  # source: batch=2, seq_len=15
tgt = torch.randint(0, 1000, (2, 12))  # target: batch=2, seq_len=12

# Create causal mask for decoder
tgt_mask = nn.Transformer.generate_square_subsequent_mask(12)
logits = model(src, tgt, tgt_mask=tgt_mask)
print("Source shape:", src.shape)     # [2, 15]
print("Target shape:", tgt.shape)     # [2, 12]
print("Logits shape:", logits.shape)  # [2, 12, 1000]: one score per vocab entry at each target position
print("Total params:", sum(p.numel() for p in model.parameters()))

The model outputs logits of shape [batch, target_len, vocab_size]: one unnormalized score per vocabulary entry at each position in the target sequence. Applying softmax over the last dimension turns these scores into a probability distribution. During training, we compare the logits against the ground truth using cross-entropy loss, which applies the softmax internally. During inference, we generate tokens one at a time, feeding each predicted token back as input for the next step (autoregressive generation).
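The training step described above can be sketched as follows. To stay self-contained, the snippet uses random logits as a stand-in for the model call, and the token ids (0 for PAD, 1 for a start token, 2 for an end token) are assumptions made for illustration. The key detail is the one-position shift between decoder input and labels.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, vocab_size = 0, 1000

# Full target sequences: start token (1), content, end token (2), then padding (0)
tgt = torch.tensor([
    [1, 45, 7, 300, 2, 0, 0],
    [1, 12, 99, 4, 8, 15, 2],
])
tgt_in = tgt[:, :-1]    # decoder input: everything except the last token
tgt_out = tgt[:, 1:]    # labels: the same sequence shifted left by one

# Stand-in for model(src, tgt_in, tgt_mask=...): random logits of the right shape
logits = torch.randn(tgt_in.size(0), tgt_in.size(1), vocab_size)

# Softmax turns each row of logits into a probability distribution
probs = logits.softmax(dim=-1)

# cross_entropy expects [N, vocab] logits and [N] class ids; PAD labels are skipped
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    tgt_out.reshape(-1),
    ignore_index=pad_idx,
)
print("Decoder input shape:", tgt_in.shape)   # [2, 6]
print("Label shape:        ", tgt_out.shape)  # [2, 6]
print("Loss:", loss.item())
```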

Using PyTorch's nn.Transformer Directly

If you don't need custom encoder/decoder blocks, PyTorch's nn.Transformer wraps everything into a single module. It is the simplest way to get a working Transformer in just a few lines:

import torch
import torch.nn as nn

# Create a full Transformer with one line
transformer = nn.Transformer(
    d_model=64, nhead=8, num_encoder_layers=3, num_decoder_layers=3,
    dim_feedforward=256, dropout=0.1, batch_first=True
)

# Simulated embeddings (in practice, add token embedding + positional encoding)
src = torch.randn(2, 15, 64)   # source embeddings
tgt = torch.randn(2, 12, 64)   # target embeddings

# Generate causal mask for the target
tgt_mask = nn.Transformer.generate_square_subsequent_mask(12)

# Forward pass
output = transformer(src, tgt, tgt_mask=tgt_mask)
print("Output shape:", output.shape)  # [2, 12, 64]
print("Total params:", sum(p.numel() for p in transformer.parameters()))

This is useful for quick experiments, but for production models you typically want to add your own embedding layers, positional encoding, and output projection on top — exactly as we did in the full model above.

Attention Masking

Masking is one of the most important — and most confusing — parts of Transformers. There are two types of masks, and getting them wrong will silently produce incorrect results.

Causal Masks (No Future Peeking)

A causal mask (also called a "look-ahead mask" or "subsequent mask") prevents each position from attending to future positions. This is mandatory for autoregressive models — when generating token 5, the model must not see tokens 6, 7, 8, etc. Without a causal mask, the model would "cheat" during training by looking at the answer it is supposed to predict. Here's how to create one:

import torch
import torch.nn as nn

def create_causal_mask(seq_len):
    """Create a causal mask where True means 'block this position'."""
    # Upper triangular matrix: True above the diagonal (future positions)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask

# Visualize the mask for a 5-token sequence
mask = create_causal_mask(5)
print("Causal mask (True = blocked):")
print(mask.int())
# tensor([[0, 1, 1, 1, 1],    ← token 0 can see only itself
#         [0, 0, 1, 1, 1],    ← token 1 can see tokens 0-1
#         [0, 0, 0, 1, 1],    ← token 2 can see tokens 0-2
#         [0, 0, 0, 0, 1],    ← token 3 can see tokens 0-3
#         [0, 0, 0, 0, 0]])   ← token 4 can see all tokens

# PyTorch's built-in helper generates the same thing (as float -inf)
pytorch_mask = nn.Transformer.generate_square_subsequent_mask(5)
print("\nPyTorch causal mask (float):")
print(pytorch_mask.round(decimals=1))

In PyTorch's convention, True (or -inf) means "block this position" and False (or 0) means "allow attention." The generate_square_subsequent_mask helper returns a float mask filled with -inf for blocked positions and 0 for allowed positions, which gets added directly to the attention scores before softmax.
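The two conventions are interchangeable, as the sketch below shows: filling the True entries of the boolean mask with -inf reproduces the helper's float mask, and both lead to the same attention weights. The random scores tensor is a stand-in for the scaled query-key products.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len = 4

# Boolean convention: True = blocked
bool_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Float convention: -inf = blocked, 0 = allowed
float_mask = torch.zeros(seq_len, seq_len).masked_fill(bool_mask, float("-inf"))
helper_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

scores = torch.randn(seq_len, seq_len)   # stand-in for QK^T / sqrt(d_k)

# Float masks are added to the scores; boolean masks select entries to overwrite
weights_float = torch.softmax(scores + float_mask, dim=-1)
weights_bool = torch.softmax(scores.masked_fill(bool_mask, float("-inf")), dim=-1)

print("Float mask equals helper output:", torch.equal(float_mask, helper_mask))
print("Attention weights:\n", weights_bool.round(decimals=2))
```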

Padding Masks (Ignore PAD Tokens)

When batching sequences of different lengths, shorter sequences are padded with special PAD tokens to match the length of the longest sequence. We don't want the model to attend to these meaningless padding tokens. A padding mask marks which positions are padding so attention can ignore them:

import torch
import torch.nn as nn

def create_padding_mask(sequences, pad_idx=0):
    """Create a padding mask. True means 'this position is padding — ignore it'."""
    return (sequences == pad_idx)  # [batch, seq_len]

# Example: batch of 3 sequences padded to length 6
# Sequence 1: 4 real tokens + 2 padding
# Sequence 2: 6 real tokens (no padding)
# Sequence 3: 2 real tokens + 4 padding
sequences = torch.tensor([
    [5, 12, 8, 3, 0, 0],   # 0 = PAD token
    [7, 1, 9, 4, 6, 2],
    [11, 3, 0, 0, 0, 0],
])

padding_mask = create_padding_mask(sequences, pad_idx=0)
print("Sequences:\n", sequences)
print("\nPadding mask (True = padding, ignore):")
print(padding_mask.int())
# tensor([[0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 1]])

# Use with nn.MultiheadAttention as key_padding_mask
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(3, 6, 32)  # batch=3, seq_len=6, d_model=32
output, weights = mha(x, x, x, key_padding_mask=padding_mask)
print("\nOutput shape:", output.shape)  # [3, 6, 32]
# Padded key positions receive exactly zero attention weight

The padding mask is a boolean tensor of shape [batch, seq_len] where True marks padding positions. When passed as key_padding_mask to nn.MultiheadAttention, it ensures that no token attends to padding positions. This is critical for correct behavior: without it, the embeddings of PAD tokens would leak into the representation of every real token. Note that the mask only hides padding as keys; the padded positions still produce output rows, so exclude them from the loss as well.
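Inspecting the returned weights makes the effect visible. In this sketch (a single sequence with two PAD positions, pad_idx=0 assumed as before), the columns belonging to padded keys are zero, while each row still sums to 1 over the real tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

tokens = torch.tensor([[5, 12, 8, 3, 0, 0]])  # last two ids are PAD
pad_mask = tokens == 0                         # [1, 6], True = padding
x = torch.randn(1, 6, 32)                      # stand-in for embedded tokens

with torch.no_grad():
    _, weights = mha(x, x, x, key_padding_mask=pad_mask)

print(weights[0].round(decimals=2))   # columns 4 and 5 are all zeros

pad_weight = weights[0][:, 4:].sum()              # total weight on PAD keys
real_weight_rows = weights[0][:, :4].sum(dim=-1)  # weight on real keys, per query

print("Total weight on PAD keys:", pad_weight.item())
print("Weight on real keys per row:", real_weight_rows)
```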

Combining Both Masks

In a decoder, you often need both masks simultaneously: a causal mask (prevent future peeking) and a padding mask (ignore PAD tokens). Here is how to combine them correctly:

import torch
import torch.nn as nn

def create_combined_mask(tgt_seq, pad_idx=0):
    """Create both causal and padding masks for decoder self-attention."""
    seq_len = tgt_seq.size(1)

    # Causal mask: [seq_len, seq_len], shared by every sequence in the batch.
    # Boolean (True = blocked) so it has the same dtype as the padding mask;
    # recent PyTorch versions warn when the two mask dtypes differ.
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

    # Padding mask: [batch, seq_len], different per sequence (True = PAD)
    padding_mask = (tgt_seq == pad_idx)

    return causal_mask, padding_mask

# Target sequences (with padding)
tgt = torch.tensor([
    [1, 5, 8, 3, 0, 0],  # 4 real tokens
    [2, 7, 4, 6, 9, 1],  # 6 real tokens
])

causal_mask, padding_mask = create_combined_mask(tgt, pad_idx=0)
print("Causal mask shape:", causal_mask.shape)    # [6, 6]
print("Padding mask shape:", padding_mask.shape)  # [2, 6]

# Use both masks together in decoder
decoder_layer = nn.TransformerDecoderLayer(
    d_model=32, nhead=4, dim_feedforward=128, batch_first=True
)
tgt_emb = torch.randn(2, 6, 32)
memory = torch.randn(2, 10, 32)  # encoder output

output = decoder_layer(
    tgt_emb, memory,
    tgt_mask=causal_mask,              # prevents future peeking
    tgt_key_padding_mask=padding_mask  # ignores PAD tokens
)
print("Decoder output shape:", output.shape)  # [2, 6, 32]

The causal mask and padding mask work together: the causal mask prevents attention to future positions, while the padding mask prevents attention to padding tokens. Both are essential for correct decoder behavior during training.
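Conceptually, the two masks merge into a single per-sequence mask: a query-key pair is blocked if the key lies in the future or the key is padding. PyTorch performs this merge internally; the sketch below does it by hand with broadcasting, purely to show what the result looks like. The random scores tensor is a stand-in for real attention scores.

```python
import torch

torch.manual_seed(0)
pad_idx = 0
tgt = torch.tensor([
    [1, 5, 8, 3, 0, 0],  # 4 real tokens
    [2, 7, 4, 6, 9, 1],  # 6 real tokens
])
seq_len = tgt.size(1)

causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # [6, 6]
padding = tgt == pad_idx                                               # [2, 6]

# Broadcast to [batch, query, key]: blocked if key is in the future OR is padding
combined = causal.unsqueeze(0) | padding.unsqueeze(1)                  # [2, 6, 6]

scores = torch.randn(2, seq_len, seq_len)   # stand-in for attention scores
weights = scores.masked_fill(combined, float("-inf")).softmax(dim=-1)

print("Combined mask shape:", combined.shape)  # [2, 6, 6]
print("Sequence 0, last query row:", combined[0, 5].int().tolist())  # [0, 0, 0, 0, 1, 1]
print("Sequence 1, last query row:", combined[1, 5].int().tolist())  # [0, 0, 0, 0, 0, 0]
```

Every row keeps at least one allowed key here because padding sits at the end of the sequence. A row with every key blocked would produce NaN after softmax, which is worth checking for if your data is padded on the left.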

Vision Transformers (ViT)

The Core Idea (Plain English)

What if we could use Transformers for images instead of just text? The trick is simple: chop the image into patches and treat each patch as a "word". A 224×224 image split into patches of 16×16 pixels gives a 14×14 grid, or 196 "tokens". Feed those into a standard Transformer encoder, and you get image classification without a convolutional backbone.

The Best Analogy: A Jigsaw Puzzle Reader

Imagine cutting a photo into a 14×14 grid of puzzle pieces:

  • Each puzzle piece = one image patch (treated as a "token")
  • Flatten + project = convert each patch's pixels into an embedding vector
  • Attention between patches = "which other patches are relevant to understanding this patch?"
  • [CLS] token = a special "summary" token that aggregates information from all patches for classification

No convolutions, no pooling — just patches and attention. That's a Vision Transformer.

Ultra-compressed version:

# Vision Transformer in pseudocode:
patches = split_image_into_grid(image, patch_size=16)  # 196 patches
tokens = linear_project(patches)                        # each patch → embedding vector
tokens = prepend([CLS], tokens)                         # add summary token
tokens = tokens + position_embeddings                   # stamp with position
output = transformer_encoder(tokens)                    # standard Transformer!
class_label = classifier(output[CLS])                   # classify from [CLS] output

In 2020, the paper "An Image Is Worth 16x16 Words" showed that you can apply Transformers directly to images by treating image patches as tokens. The idea is beautifully simple: split an image into a grid of non-overlapping patches (e.g., 16×16 pixels each), flatten each patch into a vector, project it to the model dimension, prepend a learnable [CLS] token, add positional embeddings, and feed the whole thing into a standard Transformer encoder. The [CLS] token's output representation is used for classification.

Building a Vision Transformer

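Let's implement a minimal ViT from scratch. The key component is the patch embedding layer, which uses a convolution with kernel_size = stride = patch_size to split the image into patches and project them in one step. Because the kernel never overlaps neighbouring patches, this convolution is just an efficient way of applying one shared linear projection to every flattened patch: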

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=32, patch_size=8, in_channels=3, embed_dim=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # e.g., (32/8)^2 = 16
        # Conv2d with kernel=stride=patch_size acts as a patch extractor
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [batch, channels, height, width]
        x = self.proj(x)             # [batch, embed_dim, H/patch, W/patch]
        x = x.flatten(2)             # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)        # [batch, num_patches, embed_dim]
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=8, in_channels=3,
                 num_classes=10, embed_dim=128, num_heads=8,
                 num_layers=4, d_ff=512, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)

        # Step 1: Patch embedding
        patches = self.patch_embed(x)  # [batch, num_patches, embed_dim]

        # Step 2: Prepend [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        patches = torch.cat([cls_tokens, patches], dim=1)  # [batch, num_patches+1, embed_dim]

        # Step 3: Add positional embeddings
        patches = self.dropout(patches + self.pos_embed)

        # Step 4: Transformer encoder
        encoded = self.encoder(patches)
        encoded = self.norm(encoded)

        # Step 5: Use [CLS] token output for classification
        cls_output = encoded[:, 0]  # [batch, embed_dim]
        return self.classifier(cls_output)

# Demo: classify 32x32 images (like CIFAR-10)
vit = VisionTransformer(
    img_size=32, patch_size=8, in_channels=3, num_classes=10,
    embed_dim=128, num_heads=8, num_layers=4
)
images = torch.randn(4, 3, 32, 32)  # batch of 4 RGB images
logits = vit(images)
print("Input shape: ", images.shape)    # [4, 3, 32, 32]
print("Output shape:", logits.shape)    # [4, 10]
print("Prediction:  ", logits.argmax(dim=-1))
print("Parameters:  ", sum(p.numel() for p in vit.parameters()))

This ViT splits each 32×32 image into 16 patches of 8×8 pixels. Each patch becomes a 128-dimensional token, plus one [CLS] token — giving us 17 tokens total. The Transformer processes these tokens exactly like it would process word tokens in NLP. The final classification uses only the [CLS] token's output, which has learned to aggregate information from all patches through self-attention.
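To make the "convolution as patch extractor" idea concrete, here is a minimal sketch that checks the strided convolution against an explicit two-step version: cut the image into flattened 8×8 patches with F.unfold, then apply one shared linear projection. The variable names are illustrative and not part of the model above; the only assumption is that the convolution's weights are reused as the linear layer's weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, in_channels, embed_dim = 8, 3, 128
conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
images = torch.randn(4, in_channels, 32, 32)

# Path 1: strided convolution (what PatchEmbedding does)
conv_tokens = conv(images).flatten(2).transpose(1, 2)        # [4, 16, 128]

# Path 2: explicit patch extraction followed by a linear projection
patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # [4, 192, 16]
patches = patches.transpose(1, 2)                            # [4, 16, 192]
weight = conv.weight.reshape(embed_dim, -1)                  # [128, 192]
linear_tokens = patches @ weight.T + conv.bias               # [4, 16, 128]

print("Shapes:", conv_tokens.shape, linear_tokens.shape)
print("Same result:", torch.allclose(conv_tokens, linear_tokens, atol=1e-4))
```

Both paths apply the same 192-to-128 projection to every patch; the convolution simply fuses the slicing and the matrix multiply into one call.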

Experiment
When ViT Beats CNNs

Vision Transformers need significantly more training data than CNNs to perform well. On small datasets like CIFAR-10 (50K training images), a well-tuned ResNet typically beats a ViT trained from scratch. But on large datasets (ImageNet-21K with 14M images, JFT-300M with 300M images), ViTs surpass CNNs: self-attention can relate any two patches in a single layer, while convolutions build up global context only gradually through stacked local receptive fields, and ViTs lack the built-in locality bias that helps CNNs when data is scarce. The rule of thumb: use CNNs for small datasets and ViTs for large ones, or use a pretrained ViT and fine-tune it (covered in Part 8).

Sparse & Efficient Attention

Standard self-attention has a fundamental scaling problem: it computes attention between every pair of tokens, giving it O(n²) time and memory complexity, where n is the sequence length. For a 512-token sequence, this means 262,144 attention scores per head, which is manageable. For a 4,096-token sequence, it jumps to 16.7 million. For a 100K-token document, it would require 10 billion attention scores, roughly 20 GB at float16 for a single head in a single layer, which quickly exceeds GPU memory once heads and layers are multiplied in. This quadratic scaling is the single biggest limitation of standard Transformers.
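The arithmetic above is easy to check in a few lines of Python. This sketch assumes 2 bytes per score (float16) and counts a single head in a single layer; the helper names are made up for illustration.

```python
def attention_score_count(seq_len: int) -> int:
    """Number of pairwise attention scores for one head: n * n."""
    return seq_len * seq_len

def attention_matrix_bytes(seq_len: int, bytes_per_score: int = 2) -> int:
    """Memory for one full attention matrix (float16 by default)."""
    return attention_score_count(seq_len) * bytes_per_score

for n in (512, 4096, 100_000):
    scores = attention_score_count(n)
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n:>7,}  scores={scores:>15,}  fp16 size={mib:>10,.1f} MiB")
```

Multiplying the sequence length by 8 (512 to 4,096) multiplies the score count by 64, which is the quadratic growth in action.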

Efficient Attention Strategies

Researchers have developed several strategies to reduce the O(n²) cost. Here are the most important ones, with a practical comparison:

Performance Reality Check: Standard attention with sequence length 8192 requires ~67M attention scores per head (8192² = 67,108,864). At float16, that is ~128 MB of memory just for attention weights, per layer, per head. With 32 heads and 24 layers, materializing every attention matrix takes ~96 GB for a single sequence. This is why efficient attention methods are not optional for long-context models: they are a hard requirement.

Sliding Window (Local) Attention restricts each token to attend only to a fixed window of nearby tokens (e.g., 256 tokens on each side). This reduces complexity from O(n²) to O(n × w) where w is the window size. Models like Longformer and Mistral use this approach, often combined with a few global tokens that can attend to everything:

import torch
import torch.nn.functional as F

def sliding_window_attention(Q, K, V, window_size=3):
    """
    Simple sliding window attention — each token attends only
    to tokens within a local window of size 2*window_size+1.
    """
    seq_len = Q.size(0)
    d_k = Q.size(-1)

    # Create a mask where True = allowed to attend: |i - j| <= window_size.
    # Built on Q's device so the function also works with GPU tensors.
    idx = torch.arange(seq_len, device=Q.device)
    mask = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() <= window_size

    # Compute attention with the window mask
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    scores = scores.masked_fill(~mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)

    return output, weights, mask

# Demo: 10 tokens, window size 2 (each token sees up to 5 positions:
# itself plus 2 on each side; tokens near the edges see fewer)
seq_len, d_k = 10, 16
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

output, weights, mask = sliding_window_attention(Q, K, V, window_size=2)
print("Window mask (1 = can attend):")
print(mask.int())
print("\nAttention weights (notice local pattern):")
print(weights.round(decimals=2))
print("\nNon-zero weights per row:", (weights > 0.001).sum(dim=-1))

Notice how the mask creates a diagonal band pattern: each token only attends to its local neighborhood. An implementation that computes only the scores inside the band scales linearly with sequence length, O(n × w), rather than quadratically. (The demo above still builds the full n×n score matrix and then masks it, so it illustrates the pattern without the savings.) The trade-off is that distant tokens cannot communicate directly in a single layer, but stacking layers lets information propagate across the full sequence, much as the receptive field grows with depth in a CNN.
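To see where the linear scaling comes from, here is a minimal sketch that computes only the scores inside the band, producing an [n, 2w+1] score tensor instead of [n, n]. The function name and the padding-plus-unfold approach are illustrative assumptions, not how Longformer or Mistral implement it; production kernels are considerably more involved.

```python
import torch
import torch.nn.functional as F

def banded_sliding_window_attention(Q, K, V, window_size=2):
    """Sliding window attention that never builds the n x n score matrix."""
    seq_len, d_k = Q.shape
    w = window_size
    # Pad w zero rows before and after so every token has a full window
    K_pad = F.pad(K, (0, 0, w, w))                 # [n + 2w, d_k]
    V_pad = F.pad(V, (0, 0, w, w))                 # [n + 2w, d_v]
    # K_win[i, :, j] holds K[i - w + j] (or a padding row near the edges)
    K_win = K_pad.unfold(0, 2 * w + 1, 1)          # [n, d_k, 2w+1]
    V_win = V_pad.unfold(0, 2 * w + 1, 1)          # [n, d_v, 2w+1]
    scores = torch.einsum("id,idj->ij", Q, K_win) / (d_k ** 0.5)  # [n, 2w+1]
    # Mask window slots that fall outside the sequence (the padding rows)
    offsets = torch.arange(-w, w + 1, device=Q.device)
    pos = torch.arange(seq_len, device=Q.device).unsqueeze(1) + offsets
    valid = (pos >= 0) & (pos < seq_len)
    scores = scores.masked_fill(~valid, float("-inf"))
    weights = F.softmax(scores, dim=-1)            # [n, 2w+1]
    output = torch.einsum("ij,idj->id", weights, V_win)
    return output, weights

# Compare against the dense "compute everything, then mask" approach
torch.manual_seed(0)
n, d, w = 10, 16, 2
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
banded_out, banded_weights = banded_sliding_window_attention(Q, K, V, w)

idx = torch.arange(n)
dense_mask = (idx[:, None] - idx[None, :]).abs() <= w
dense_scores = (Q @ K.T) / (d ** 0.5)
dense_scores = dense_scores.masked_fill(~dense_mask, float("-inf"))
dense_out = F.softmax(dense_scores, dim=-1) @ V

print("Banded weights shape:", banded_weights.shape)
print("Matches dense version:", torch.allclose(banded_out, dense_out, atol=1e-5))
```

The score tensor here grows with n × (2w+1), so doubling the sequence length doubles the work instead of quadrupling it.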

Other efficient attention approaches include:

  • Linear Attention: Replace softmax(QK^T)V with φ(Q)(φ(K)^T V), computing K^T V first to avoid the n×n matrix. Reduces complexity to O(n).
  • Flash Attention: Not a mathematical approximation — it computes exact standard attention but uses tiling and kernel fusion to minimize GPU memory reads/writes. Used in PyTorch 2.0+ via torch.nn.functional.scaled_dot_product_attention.
  • Sparse Attention: Combine local windows with a few global tokens (as in Longformer), optionally adding random long-range connections (as in BigBird).
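As a rough illustration of the linear attention bullet, the sketch below uses φ(x) = elu(x) + 1 as the feature map, which is one common choice and an assumption here, and adds the row normalization that plays the role of the softmax denominator. It then checks the reordered product against the same computation done the quadratic way. This is non-causal, single-head, and meant only to show the reordering trick.

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, eps=1e-6):
    """Kernelized attention: phi(Q) @ (phi(K)^T @ V), normalized per row."""
    phi_Q = F.elu(Q) + 1                           # [n, d_k], strictly positive
    phi_K = F.elu(K) + 1                           # [n, d_k]
    KV = phi_K.transpose(-2, -1) @ V               # [d_k, d_v], no n x n matrix
    K_sum = phi_K.sum(dim=-2, keepdim=True)        # [1, d_k]
    Z = phi_Q @ K_sum.transpose(-2, -1)            # [n, 1] normalizer
    return (phi_Q @ KV) / (Z + eps)

torch.manual_seed(0)
n, d = 64, 16
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
linear_out = linear_attention(Q, K, V)

# Reference: the same kernel computed the quadratic way
A = (F.elu(Q) + 1) @ (F.elu(K) + 1).T              # [n, n]
quadratic_out = (A / A.sum(dim=-1, keepdim=True)) @ V

print("Output shape:", linear_out.shape)
print("Matches quadratic form:", torch.allclose(linear_out, quadratic_out, atol=1e-4))
```

Note that this matches the quadratic form of the kernelized attention, not softmax attention: replacing the softmax with a feature map is what makes the reordering possible, and it changes the model.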

Using PyTorch 2.0's Flash Attention

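PyTorch 2.0+ includes Flash Attention as a backend for scaled_dot_product_attention. It computes exact attention (up to floating-point rounding) but is often 2-4× faster and uses significantly less memory through smart GPU kernel optimization. If your model already calls scaled_dot_product_attention, PyTorch selects the backend automatically whenever the inputs qualify (a supported CUDA GPU and float16 or bfloat16 tensors):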

import torch
import torch.nn.functional as F

# PyTorch 2.0+ scaled_dot_product_attention with Flash Attention backend
batch, heads, seq_len, d_k = 2, 8, 64, 32
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# This automatically uses Flash Attention when available on CUDA
output = F.scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)  # [2, 8, 64, 32]

# With causal mask (for decoder / autoregressive models)
output_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print("Causal output shape:", output_causal.shape)  # [2, 8, 64, 32]

# Check whether the Flash Attention backend can run here (requires CUDA)
if torch.cuda.is_available():
    # torch.nn.attention.sdpa_kernel (PyTorch 2.3+) replaces the deprecated
    # torch.backends.cuda.sdp_kernel context manager
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # The Flash backend needs half-precision inputs
    Q_cuda = Q.cuda().half()
    K_cuda = K.cuda().half()
    V_cuda = V.cuda().half()
    try:
        # Allow only the Flash backend; the call raises if it cannot be used
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(Q_cuda, K_cuda, V_cuda)
        print("Flash Attention is available!")
    except RuntimeError:
        print("Flash Attention not supported on this GPU")
else:
    print("Running on CPU: the CUDA Flash Attention kernel is not used")
    print("scaled_dot_product_attention falls back to a CPU implementation")

Flash Attention achieves its speed by never materializing the full n×n attention matrix in the GPU's main memory. Instead, it tiles the computation into blocks that fit in the GPU's fast on-chip SRAM and computes the softmax incrementally, rescaling partial results as each new block arrives. The result matches standard attention up to floating-point rounding, but the extra memory grows as O(n) instead of O(n²), and it often runs 2-4× faster on modern GPUs.
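The incremental softmax is the least obvious part, so here is a minimal pure-PyTorch sketch of the idea: process keys and values one block at a time while tracking a running row maximum and row sum. The function name and block size are assumptions for illustration. The real Flash Attention kernel also tiles the queries and runs as a fused CUDA kernel, so this shows the math only and none of the speed.

```python
import torch
import torch.nn.functional as F

def tiled_attention(Q, K, V, block_size=16):
    """Exact attention computed one key/value block at a time."""
    n, d_k = Q.shape
    scale = d_k ** -0.5
    out = torch.zeros(n, V.size(-1), dtype=Q.dtype, device=Q.device)
    row_max = torch.full((n, 1), float("-inf"), dtype=Q.dtype, device=Q.device)
    row_sum = torch.zeros(n, 1, dtype=Q.dtype, device=Q.device)

    for start in range(0, K.size(0), block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = (Q @ K_blk.T) * scale                 # [n, block], never [n, n]
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        # Rescale what was accumulated under the old maximum
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)                # unnormalized block weights
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V_blk
        row_max = new_max

    return out / row_sum                               # normalize once at the end

torch.manual_seed(0)
n, d = 64, 32
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

tiled_out = tiled_attention(Q, K, V, block_size=16)
reference_out = F.softmax((Q @ K.T) / (d ** 0.5), dim=-1) @ V

print("Output shape:", tiled_out.shape)
print("Matches standard attention:", torch.allclose(tiled_out, reference_out, atol=1e-5))
```

Only the [n, block_size] slice of scores exists at any moment, yet the final output agrees with the version that builds the whole matrix.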

Conclusion & Next Steps

In this article, we built the Transformer architecture from the ground up. We started with the fundamental scaled dot-product attention mechanism (Q, K, V with scores scaled by 1/√d_k), extended it to multi-head attention (parallel heads capturing different relationships), added positional encoding (injecting order information), and assembled complete encoder and decoder blocks with residual connections and layer normalization. We then combined everything into a full encoder-decoder Transformer, implemented causal and padding masks, built a Vision Transformer for image classification, and explored efficient attention strategies for long sequences.

Key Takeaways:
  • Attention replaces recurrence with parallel pairwise token interactions
  • Multi-head attention lets the model capture multiple types of relationships simultaneously
  • Positional encoding is essential — without it, Transformers are bag-of-words models
  • Encoder blocks use self-attention; decoder blocks add masked self-attention and cross-attention
  • Causal masks prevent future-token leakage; padding masks ignore meaningless PAD tokens
  • Vision Transformers treat image patches as tokens — same architecture, different modality
  • Flash Attention in PyTorch 2.0+ gives you exact attention, often 2-4× faster, simply by calling scaled_dot_product_attention on a supported GPU

Next in the Series

In Part 8: Transfer Learning & Fine-Tuning, we'll learn how to take pretrained Transformer models (like BERT, GPT-2, and ViT) and adapt them to your specific tasks — achieving state-of-the-art results with minimal data and compute.