
Transformers & Attention Mechanism

January 27, 2026 Wasil Zafar 45 min read

Part 8 of 16: Understand the architecture that revolutionized modern NLP—self-attention, multi-head attention, positional encoding, and encoder-decoder structure.

Table of Contents

  1. Introduction to Transformers
  2. The Attention Mechanism
  3. Self-Attention
  4. Multi-Head Attention
  5. Positional Encoding
  6. Encoder Architecture
  7. Decoder Architecture
  8. Training Transformers
  9. Transformer Variants
  10. Conclusion & Next Steps

Introduction to Transformers

The Transformer architecture, introduced in "Attention Is All You Need" (2017), revolutionized NLP by replacing recurrence with self-attention, enabling parallel processing and better long-range dependency modeling.

Key Insight

Transformers use self-attention to weigh the importance of different parts of the input when encoding each position, allowing direct modeling of relationships regardless of distance.

The Attention Mechanism

The attention mechanism is the core innovation that powers Transformers. Unlike RNNs that process sequences step-by-step, attention allows the model to look at all positions simultaneously and decide which parts of the input are most relevant for each output position. This "soft selection" of relevant information is learned during training and adapts to each specific input.

The fundamental idea comes from human cognition: when reading a sentence, you don't give equal weight to every word. Instead, you focus on the most relevant parts based on context. Attention mechanisms formalize this intuition mathematically, enabling neural networks to dynamically focus on different parts of the input when producing each output element.

Why Attention Matters

Attention solves the information bottleneck problem that plagues sequence-to-sequence models. Instead of compressing an entire input sequence into a single fixed-size vector, attention allows the decoder to access all encoder hidden states directly, weighted by their relevance to the current decoding step.

Scaled Dot-Product Attention

Scaled dot-product attention is the specific attention variant used in Transformers. Given a set of queries Q, keys K, and values V, attention computes, for each query, a weighted sum of the values, where the weights are determined by the compatibility between that query and each key. The "scaled" part refers to dividing the dot products by √d_k, the square root of the key dimension. Without it, the raw dot products grow large in magnitude as d_k increases, which pushes the softmax into extremely peaked distributions where its gradients are very small.
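The original paper motivates the scaling with a simple argument: if the components of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of d_k such terms and therefore has variance d_k. The quick empirical check below is a sketch under exactly that assumption (random Gaussian vectors, d_k = 64). Learned queries and keys are not independent unit-variance vectors, so treat it as intuition rather than proof.

```python
import torch

torch.manual_seed(0)
d_k = 64
n_samples = 100_000

# Assumption: query/key components are independent, mean 0, variance 1
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw_scores = (q * k).sum(dim=-1)          # one dot product per sample
scaled_scores = raw_scores / d_k ** 0.5   # divide by sqrt(d_k)

print(f"variance of raw dot products:    {raw_scores.var().item():.1f}")    # close to d_k = 64
print(f"variance of scaled dot products: {scaled_scores.var().item():.2f}")  # close to 1
```

Dividing by √d_k brings the score variance back to about 1 regardless of d_k, so the softmax sees inputs on a similar scale for any head size.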

The mathematical formula is: Attention(Q, K, V) = softmax(QK^T / √d_k) V, where d_k is the dimension of the keys. The product QK^T measures the similarity between every query and every key, softmax normalizes each row of scores into a probability distribution, and the resulting weights are applied to V. Because the whole computation reduces to two matrix multiplications and a row-wise softmax, it is highly parallelizable on modern hardware.
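To make the formula concrete, here is a hand-sized worked example in plain Python (no PyTorch): one query attending over two key/value pairs with d_k = 2. The numbers are arbitrary and chosen only so the arithmetic is easy to follow.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; the result is unchanged
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# One query attending over two key/value pairs, d_k = 2
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
d_k = len(q)

# Scores: q . k / sqrt(d_k) for each key
scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
weights = softmax(scores)
# Output: weighted sum of the value vectors, one coordinate at a time
output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

print([round(s, 3) for s in scores])   # [0.707, 0.0]
print([round(w, 3) for w in weights])  # [0.67, 0.33]
print([round(o, 3) for o in output])   # [1.66, 2.66]
```

Because the first key matches the query exactly, it receives roughly two thirds of the weight, and the output lands about one third of the way from the first value toward the second.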

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        query: Tensor of shape (batch, num_heads, seq_len, d_k)
        key: Tensor of shape (batch, num_heads, seq_len, d_k)
        value: Tensor of shape (batch, num_heads, seq_len, d_v)
        mask: Optional attention mask
    
    Returns:
        output: Weighted sum of values
        attention_weights: Attention probability distribution
    """
    d_k = query.size(-1)
    
    # Compute attention scores: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax to get attention weights (probabilities)
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

# Example usage
batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64
query = torch.randn(batch_size, num_heads, seq_len, d_k)
key = torch.randn(batch_size, num_heads, seq_len, d_k)
value = torch.randn(batch_size, num_heads, seq_len, d_k)

output, weights = scaled_dot_product_attention(query, key, value)
print(f"Output shape: {output.shape}")  # (2, 4, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # (2, 4, 10, 10)
print(f"Weights sum per query position: {weights.sum(dim=-1)[0, 0]}")  # Should be ~1.0
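As a sanity check, the hand-written function can be compared with PyTorch's fused implementation, torch.nn.functional.scaled_dot_product_attention. This sketch assumes PyTorch 2.0 or later, where that function exists. One convention differs: for a boolean attn_mask, the built-in treats True as "may attend", so the comparison below uses a boolean causal mask and masks with ~mask in the hand-written version. The helper name manual_attention is illustrative.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    # Same computation as scaled_dot_product_attention above
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        # Boolean mask: True = may attend, False = blocked
        scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 10, 64) for _ in range(3))
causal = torch.tril(torch.ones(10, 10, dtype=torch.bool))  # True on/below the diagonal

ours = manual_attention(q, k, v, causal)
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # PyTorch >= 2.0

print(torch.allclose(ours, builtin, atol=1e-5))  # True
```

In practice the built-in is preferable because it can dispatch to fused kernels; the hand-written version is useful for understanding and for returning the attention weights, which the built-in does not.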

Queries, Keys & Values

The Query-Key-Value (QKV) framework provides an intuitive way to understand attention. Think of it like a retrieval system: the query represents what you're looking for, keys represent labels or indices of stored information, and values contain the actual content to retrieve. The attention mechanism compares the query against all keys to determine which values are most relevant.

In practice, Q, K, and V are computed by projecting the input through three separate learned weight matrices. This allows the model to learn different representations for matching (Q and K) versus content (V). For self-attention, all three come from the same input sequence; for cross-attention (encoder-decoder), Q comes from the decoder while K and V come from the encoder output.

import torch
import torch.nn as nn

class QKVProjection(nn.Module):
    """Project input into Query, Key, and Value representations."""
    
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)  # Query projection
        self.W_k = nn.Linear(d_model, d_k)  # Key projection
        self.W_v = nn.Linear(d_model, d_v)  # Value projection
    
    def forward(self, x_query, x_key, x_value):
        """
        For self-attention: x_query = x_key = x_value = input
        For cross-attention: x_query = decoder, x_key = x_value = encoder
        """
        Q = self.W_q(x_query)  # What am I looking for?
        K = self.W_k(x_key)    # What information is available?
        V = self.W_v(x_value)  # What content should I retrieve?
        return Q, K, V

# Example: Self-attention QKV
d_model, d_k, d_v = 512, 64, 64
seq_len, batch_size = 20, 4

qkv_proj = QKVProjection(d_model, d_k, d_v)
input_seq = torch.randn(batch_size, seq_len, d_model)

# Self-attention: same input for all three
Q, K, V = qkv_proj(input_seq, input_seq, input_seq)
print(f"Query shape: {Q.shape}")  # (4, 20, 64)
print(f"Key shape: {K.shape}")    # (4, 20, 64)
print(f"Value shape: {V.shape}")  # (4, 20, 64)

# Cross-attention: different sources
encoder_output = torch.randn(batch_size, 30, d_model)  # Source sequence
decoder_input = torch.randn(batch_size, 10, d_model)   # Target sequence

Q_cross, K_cross, V_cross = qkv_proj(decoder_input, encoder_output, encoder_output)
print(f"\nCross-attention Query: {Q_cross.shape}")  # (4, 10, 64) - from decoder
print(f"Cross-attention Key: {K_cross.shape}")    # (4, 30, 64) - from encoder
print(f"Cross-attention Value: {V_cross.shape}")  # (4, 30, 64) - from encoder

Attention Score Interpretation

Visualization Understanding

Attention weights tell us which input positions the model "focuses on" when processing each position:

  • High weight: Strong relevance (e.g., pronoun attending to its antecedent)
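  • Uniform weights: Broad aggregation of global context (often reported in early layers)
  • Sparse patterns: Targeted relationships such as syntactic dependencies (some heads in deeper layers track grammatical structure)
  • Diagonal patterns: Attention to the token itself, or to its immediate neighbors when the band sits just off the diagonal (local context)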
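One rough way to put numbers on these patterns is the entropy of each query's attention row: close to 0 when nearly all weight sits on a single key, and close to log(seq_len) when the weights are uniform. The helper attention_entropy below is a hypothetical utility written for this illustration, not part of any library; it expects weights shaped like the ones returned by the attention functions in this article.

```python
import math
import torch

def attention_entropy(weights, eps=1e-12):
    """Entropy (in nats) of each query's attention distribution.

    weights: (..., seq_len_q, seq_len_k), each row summing to 1.
    Returns: (..., seq_len_q). 0 = all mass on one key; log(seq_len_k) = uniform.
    """
    # eps avoids log(0) for keys that receive exactly zero weight
    return -(weights * (weights + eps).log()).sum(dim=-1)

seq_len = 6
uniform = torch.full((seq_len, seq_len), 1.0 / seq_len)  # broad, "global" pattern
diagonal = torch.eye(seq_len)                             # each token attends only to itself

print(attention_entropy(uniform))   # every entry is about log(6) = 1.79
print(attention_entropy(diagonal))  # every entry is about 0
```

Averaging this quantity per head and per layer gives a compact summary of which heads are broad and which are sharply focused, before looking at individual heatmaps.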

Self-Attention

Self-attention is the specific application of attention where a sequence attends to itself. Each position in the input sequence can directly interact with every other position, creating rich contextual representations. This is the key mechanism that allows Transformers to capture long-range dependencies without the sequential bottleneck of RNNs—a token at position 1 can directly attend to position 1000 in a single operation.

In self-attention, the same input sequence is used to generate all three components: queries, keys, and values. Each token asks "what other tokens in this sequence are relevant to me?" and receives information weighted by the answers. This bidirectional context (in encoders) or causal context (in decoders) enables sophisticated understanding of language structure, semantics, and relationships.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """Single-head self-attention layer."""
    
    def __init__(self, d_model, d_k=None, d_v=None):
        super().__init__()
        self.d_k = d_k or d_model
        self.d_v = d_v or d_model
        
        # Projection matrices
        self.W_q = nn.Linear(d_model, self.d_k)
        self.W_k = nn.Linear(d_model, self.d_k)
        self.W_v = nn.Linear(d_model, self.d_v)
        
    def forward(self, x, mask=None):
        """
        Self-attention: input attends to itself.
        
        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional attention mask
        Returns:
            output: Contextualized representations
            weights: Attention weight matrix
        """
        # Project input to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_v)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

# Example: Self-attention on a sentence
d_model = 256
self_attn = SelfAttention(d_model)

# Simulated embedded sentence: "The cat sat on the mat"
batch_size, seq_len = 1, 6
sentence_embeddings = torch.randn(batch_size, seq_len, d_model)

contextualized, weights = self_attn(sentence_embeddings)
print(f"Input shape: {sentence_embeddings.shape}")
print(f"Output shape: {contextualized.shape}")  # Same as input
print(f"Attention matrix shape: {weights.shape}")  # (1, 6, 6)
print(f"\nAttention from position 0 to all positions:")
print(weights[0, 0, :])  # How much token 0 attends to each token

Causal (Masked) Self-Attention

In language modeling and decoder architectures, we need causal self-attention, where each position can only attend to previous positions (including itself). This prevents "cheating" by looking at future tokens during training. The causal mask is a lower-triangular matrix marking the allowed positions; the scores for future positions (above the diagonal) are set to negative infinity before the softmax, so they receive exactly zero probability.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def create_causal_mask(seq_len):
    """Create causal attention mask (lower triangular)."""
    # 1s for positions we can attend to, 0s for positions to mask
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

class CausalSelfAttention(nn.Module):
    """Self-attention with causal masking for autoregressive models."""
    
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Register causal mask as buffer (not a parameter)
        mask = create_causal_mask(max_seq_len)
        self.register_buffer('causal_mask', mask)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        
        # Apply causal mask
        mask = self.causal_mask[:seq_len, :seq_len]
        scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

# Example: Causal attention for language modeling
causal_attn = CausalSelfAttention(d_model=256)

# Sequence: positions can only see themselves and earlier positions
sequence = torch.randn(2, 8, 256)  # batch=2, seq_len=8
output, weights = causal_attn(sequence)

print("Causal attention weights (batch 0):")
print(weights[0].round(decimals=2))
# Lower triangular pattern - position i only attends to positions 0..i
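The defining property of causal attention is that the output at position i cannot change when tokens after i change. The standalone sketch below checks that directly. It uses a hypothetical single-head function, causal_attention, with randomly initialized projection matrices (scaled by 1/√d_model to keep the scores moderate), overwrites every token from position t onward, and compares the outputs.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention with projection matrices passed in."""
    seq_len, d = x.shape[-2], w_q.shape[-1]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float("-inf"))  # block future positions
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
d_model, seq_len, t = 16, 8, 5
w_q, w_k, w_v = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(3))

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[:, t:] = torch.randn(1, seq_len - t, d_model)  # overwrite positions t..end

out = causal_attention(x, w_q, w_k, w_v)
out_changed = causal_attention(x_changed, w_q, w_k, w_v)

# Positions before t never see the edited tokens, so their outputs are unchanged
print(torch.allclose(out[:, :t], out_changed[:, :t]))  # True
# Positions from t onward do see them, so their outputs change
print(torch.allclose(out[:, t:], out_changed[:, t:]))  # False
```

This is the same property that makes teacher-forced training valid: every position's prediction depends only on the tokens to its left.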

Self-Attention Complexity

Time complexity: O(n² × d), where n is the sequence length and d is the model dimension; the attention-weight matrix alone also needs O(n²) memory per head. This quadratic scaling with sequence length is the main limitation of Transformers, motivating efficient variants like Linformer, Performer, and sparse attention patterns for very long sequences.
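To see why the quadratic term matters in practice, the back-of-the-envelope calculation below estimates the memory for one layer's attention-weight tensor. It assumes float32 values, 8 heads, batch size 1, and that the full n × n matrix is materialized, as in the code in this article. The function name attention_matrix_bytes is hypothetical.

```python
def attention_matrix_bytes(seq_len, num_heads=1, batch_size=1, bytes_per_elem=4):
    """Memory for one layer's (batch, heads, n, n) attention-weight tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

for n in (512, 2048, 8192):
    mib = attention_matrix_bytes(n, num_heads=8) / 2**20
    print(f"n={n:>5}: {mib:,.0f} MiB per layer (8 heads, batch 1, float32)")
```

Under these assumptions the tensor needs 8 MiB at n = 512, 128 MiB at n = 2,048 and 2,048 MiB (2 GiB) at n = 8,192: each 4× increase in sequence length costs 16× the memory, per layer.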

Multi-Head Attention

Multi-head attention is one of the most important innovations in the Transformer architecture. Instead of performing a single attention function, we run multiple attention operations ("heads") in parallel, each with different learned projections. This allows the model to jointly attend to information from different representation subspaces at different positions—capturing various types of relationships simultaneously.

Each head operates on a reduced dimension (d_model / num_heads), so the total computation is similar to single-head attention at full dimension. The outputs from all heads are concatenated and projected through a final linear layer. Different heads often learn to focus on different linguistic phenomena: some capture syntactic dependencies, others semantic relationships, positional patterns, or coreference chains.
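The claim that splitting into heads keeps the cost roughly constant can be checked with a little arithmetic. The sketch below uses standalone helpers: split_heads and merge_heads mirror the reshaping done inside the MultiHeadAttention class that follows, and score_multiply_adds is a hypothetical counting function. It shows that splitting is a pure reshape and that the multiply-adds for QK^T, summed over heads, equal n² × d_model for any head count that divides d_model.

```python
import torch

def split_heads(x, num_heads):
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
    batch, seq_len, d_model = x.shape
    return x.view(batch, seq_len, num_heads, d_model // num_heads).transpose(1, 2)

def merge_heads(x):
    """Inverse of split_heads: concatenate the heads back into d_model."""
    batch, num_heads, seq_len, d_k = x.shape
    return x.transpose(1, 2).reshape(batch, seq_len, num_heads * d_k)

def score_multiply_adds(seq_len, d_model, num_heads):
    """Multiply-adds for QK^T summed over heads: num_heads * n * n * d_k."""
    d_k = d_model // num_heads
    return num_heads * seq_len * seq_len * d_k

x = torch.randn(2, 10, 512)
for h in (1, 8, 16):
    heads = split_heads(x, h)
    print(h, tuple(heads.shape), score_multiply_adds(10, 512, h))
# 1 (2, 1, 10, 512) 51200
# 8 (2, 8, 10, 64) 51200
# 16 (2, 16, 10, 32) 51200
```

The count is per batch element and covers only the score computation; the Q, K, V and output projections are full d_model × d_model layers whatever the head count.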

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention as described in "Attention Is All You Need".
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # Separate linear projections for Q, K, V (covering all heads at once)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, d_k)."""
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, heads, seq_len, d_k)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # Split into multiple heads
        Q = self.split_heads(Q, batch_size)  # (batch, heads, seq_len, d_k)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)
        
        # Scaled dot-product attention for each head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
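        if mask is not None:
            # Broadcast the mask over the heads dimension:
            #   (seq_q, seq_k), e.g. a causal mask -> (1, 1, seq_q, seq_k)
            #   (batch, seq_q, seq_k)              -> (batch, 1, seq_q, seq_k)
            # 4-D masks such as padding masks (batch, 1, 1, seq_k) are used as-is
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        # Apply dropout only to the copy used for the weighted sum, so the
        # returned weights remain a proper probability distribution
        attention_probs = self.dropout(attention_weights)
        
        # Apply attention to values
        context = torch.matmul(attention_probs, V)  # (batch, heads, seq_len, d_k)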
        
        # Concatenate heads
        context = context.transpose(1, 2).contiguous()  # (batch, seq_len, heads, d_k)
        context = context.view(batch_size, -1, self.d_model)  # (batch, seq_len, d_model)
        
        # Final linear projection
        output = self.W_o(context)
        
        return output, attention_weights

# Example: Multi-head attention
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

batch_size, seq_len = 4, 20
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention (query=key=value)
output, weights = mha(x, x, x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")  # Same as input
print(f"Attention weights shape: {weights.shape}")  # (batch, heads, seq_len, seq_len)
print(f"Each head's d_k: {d_model // num_heads}")  # 64
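PyTorch also ships a built-in layer, torch.nn.MultiheadAttention, which is usually preferable in production code. The sketch below only checks that it produces the same shapes as the class above. Note that batch_first=True is needed for (batch, seq_len, d_model) inputs, average_attn_weights=False returns per-head weights, and its parameters are stored in a different layout (for example, a packed in_proj_weight when query, key, and value share a dimension), so weights cannot be copied between the two implementations without rearranging them.

```python
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
builtin_mha = nn.MultiheadAttention(d_model, num_heads, dropout=0.1, batch_first=True)
builtin_mha.eval()  # disable dropout for inspection

x = torch.randn(4, 20, d_model)
with torch.no_grad():
    # need_weights=True returns the attention weights;
    # average_attn_weights=False keeps one matrix per head
    out, weights = builtin_mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)      # torch.Size([4, 20, 512])
print(weights.shape)  # torch.Size([4, 8, 20, 20])
```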

Visualizing Attention Heads

Different attention heads specialize in capturing different types of patterns. Researchers have found that specific heads often correspond to interpretable linguistic relationships, such as subject-verb agreement, coreference resolution, or positional proximity. Analyzing these patterns helps us understand what Transformers learn.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention_heads(attention_weights, tokens, num_heads_to_show=4):
    """
    Visualize attention patterns for multiple heads.
    
    Args:
        attention_weights: Tensor (num_heads, seq_len, seq_len)
        tokens: List of token strings
        num_heads_to_show: Number of heads to display
    """
    fig, axes = plt.subplots(1, num_heads_to_show, figsize=(4*num_heads_to_show, 4))
    
    for i in range(num_heads_to_show):
        ax = axes[i] if num_heads_to_show > 1 else axes
        weights = attention_weights[i].detach().cpu().numpy()
        
        sns.heatmap(weights, ax=ax, cmap='Blues', 
                   xticklabels=tokens, yticklabels=tokens,
                   cbar=i == num_heads_to_show - 1)
        ax.set_title(f'Head {i+1}')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')
    
    plt.tight_layout()
    plt.savefig('attention_heads.png', dpi=150, bbox_inches='tight')
    plt.show()

# Example: Analyze attention patterns
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat', '.']
seq_len = len(tokens)
num_heads = 8
d_model = 256

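mha = MultiHeadAttention(d_model, num_heads)
mha.eval()  # disable dropout so the displayed weights are not randomly zeroed
x = torch.randn(1, seq_len, d_model)  # random embeddings, untrained layer: the patterns are arbitrary
_, weights = mha(x, x, x)

# weights shape: (1, 8, 7, 7); indexing [0] drops the batch dimension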
print(f"Attention weights per head: {weights.shape}")
visualize_attention_heads(weights[0], tokens, num_heads_to_show=4)

What Different Heads Learn

Research Finding: BERT Analysis

Studies on BERT's attention heads have revealed specialized patterns:

  • Positional heads: Attend to adjacent positions (local context)
  • Separator heads: Focus on [SEP] and [CLS] tokens
  • Syntactic heads: Track dependency parse relationships
  • Coreference heads: Link pronouns to their referents
  • Rare-word heads: Attend to the least frequent tokens in a sentence (described by Voita et al. (2019) for machine-translation encoders rather than for BERT)

Reference: Clark et al. (2019) "What Does BERT Look At?"

Positional Encoding

Unlike RNNs that process tokens sequentially (inherently capturing position), Transformers process all positions in parallel. This parallelism is a strength for training efficiency but means the model has no inherent notion of word order. Without positional information, "dog bites man" and "man bites dog" would be indistinguishable! Positional encoding solves this by injecting position information into the input embeddings.
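The "dog bites man" point can be demonstrated directly: without positional information, self-attention is permutation-equivariant, meaning that shuffling the input tokens merely shuffles the output vectors in the same way. The sketch below uses PyTorch's built-in nn.MultiheadAttention as a stand-in for a self-attention layer and random vectors as stand-ins for word embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()

x = torch.randn(1, 5, 32)             # 5 "word" embeddings, no positional information
perm = torch.tensor([4, 2, 0, 3, 1])  # reorder the words

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])

# Shuffling the input only shuffles the output rows in the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```

Since each token receives the same output vector wherever it appears, information about word order has to be injected separately, which is the job of positional encoding.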

The original Transformer uses sinusoidal positional encoding: fixed patterns of sines and cosines at different frequencies. This approach has useful mathematical properties: the encoding of position pos + k is a fixed linear transformation of the encoding of pos, and the encoding can be computed for sequence lengths longer than any seen during training (whether the model actually generalizes to them is an empirical question). Later models use other schemes: BERT uses learned positional embeddings, T5 uses relative positional encodings, and LLaMA uses rotary embeddings (RoPE).

import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt

class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from "Attention Is All You Need".
    
    PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """
    
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # div_term = 1 / 10000^(2i/d_model), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension and register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        Add positional encoding to input embeddings.
        
        Args:
            x: Input embeddings (batch, seq_len, d_model)
        Returns:
            x + positional encoding
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Create and visualize positional encoding
d_model = 512
max_len = 100
pos_encoding = SinusoidalPositionalEncoding(d_model, max_len, dropout=0.0)

# Get the encoding matrix
pe_matrix = pos_encoding.pe[0, :50, :].numpy()  # First 50 positions

plt.figure(figsize=(12, 6))
plt.imshow(pe_matrix.T, cmap='RdBu', aspect='auto')
plt.colorbar(label='Encoding Value')
plt.xlabel('Position in Sequence')
plt.ylabel('Embedding Dimension')
plt.title('Sinusoidal Positional Encoding (first 50 positions)')
plt.savefig('positional_encoding.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Positional encoding shape: {pos_encoding.pe.shape}")

Learned vs. Fixed Positional Encodings

While the original Transformer used fixed sinusoidal encodings, many later models (BERT, GPT-2) use learned positional embeddings: simply another embedding table indexed by position. The original paper reported nearly identical results for both approaches; learned embeddings are conceptually simpler and may adapt better to specific tasks. The tradeoff is that a learned table has no entries for positions beyond max_seq_len, so it cannot be applied to longer sequences without extending and retraining it.

import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embeddings (used in BERT, GPT-2)."""
    
    def __init__(self, d_model, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize with small random values
        nn.init.normal_(self.position_embeddings.weight, std=0.02)
        
    def forward(self, x):
        """
        Args:
            x: Input embeddings (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        
        # Create position indices: [0, 1, 2, ..., seq_len-1]
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
        position_ids = position_ids.expand(batch_size, -1)  # (batch, seq_len)
        
        # Look up positional embeddings
        position_embeddings = self.position_embeddings(position_ids)
        
        # Add to input embeddings
        x = x + position_embeddings
        return self.dropout(x)

# Compare both approaches
d_model = 256
sinusoidal_pe = SinusoidalPositionalEncoding(d_model)
learned_pe = LearnedPositionalEmbedding(d_model)

x = torch.randn(2, 20, d_model)  # batch=2, seq_len=20

out_sin = sinusoidal_pe(x)
out_learn = learned_pe(x)

print(f"Input shape: {x.shape}")
print(f"After sinusoidal PE: {out_sin.shape}")
print(f"After learned PE: {out_learn.shape}")
print(f"\nLearned PE is trainable: {learned_pe.position_embeddings.weight.requires_grad}")

Why Sinusoidal Encodings Work

The sinusoidal encoding has a key property: relative positions can be computed via a linear transformation. For any fixed offset k, PE(pos+k) can be expressed as a linear function of PE(pos); concretely, each (sin, cos) pair of dimensions is rotated by an angle that depends only on k and that pair's frequency. This makes it easier for the model to learn to attend to relative positions. The different frequencies (wavelengths from 2π to 10000·2π) encode position at different scales, similar to a binary representation but continuous.
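The linear-transformation property can be verified numerically. The sketch below rebuilds the encoding from the formula in SinusoidalPositionalEncoding (in float64, as a standalone function) and constructs, for an offset k, a block-diagonal matrix that rotates each (sin, cos) pair by k times its frequency. The helper names sinusoidal_pe and shift_matrix are illustrative, and d_model = 16 is chosen only to keep the matrix small.

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Same formula as SinusoidalPositionalEncoding, without batch dim or dropout."""
    position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                      * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * freqs)
    pe[:, 1::2] = torch.cos(position * freqs)
    return pe, freqs

def shift_matrix(k, freqs):
    """Block-diagonal M_k with PE(pos + k) = M_k @ PE(pos) for every pos."""
    d_model = 2 * freqs.numel()
    m = torch.zeros(d_model, d_model, dtype=torch.float64)
    for i, w in enumerate(freqs.tolist()):
        c, s = math.cos(k * w), math.sin(k * w)
        # sin(a + b) = sin a cos b + cos a sin b
        m[2 * i, 2 * i], m[2 * i, 2 * i + 1] = c, s
        # cos(a + b) = cos a cos b - sin a sin b
        m[2 * i + 1, 2 * i], m[2 * i + 1, 2 * i + 1] = -s, c
    return m

pe, freqs = sinusoidal_pe(max_len=100, d_model=16)
k = 7
M = shift_matrix(k, freqs)

# One matrix, independent of pos, maps every position to the position k steps later
# (rows are position vectors, so multiply by M transposed)
print(torch.allclose(pe[k:], pe[:-k] @ M.T))  # True
```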

Encoder Architecture

The Transformer encoder is a stack of identical layers, each containing two sub-layers: multi-head self-attention and a position-wise feed-forward network. Each sub-layer is wrapped in a residual connection followed by layer normalization: output = LayerNorm(x + Sublayer(x)). The residual paths let information and gradients flow directly across layers, which mitigates vanishing gradients and makes deep stacks trainable.

The encoder processes the entire input sequence in parallel, producing contextualized representations for each position. These representations capture bidirectional context—each token's representation is influenced by all other tokens in the sequence. The original Transformer uses 6 encoder layers; BERT-base uses 12, and BERT-large uses 24. More layers generally improve capacity but increase computational cost and risk of overfitting.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
    
    Typically d_ff = 4 * d_model (e.g., 2048 for d_model=512)
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # Expand -> ReLU -> Contract
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class EncoderLayer(nn.Module):
    """Single Transformer encoder layer."""
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_output, attn_weights = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        
        return x, attn_weights

# Example: Single encoder layer
d_model, num_heads, d_ff = 512, 8, 2048
encoder_layer = EncoderLayer(d_model, num_heads, d_ff)

batch_size, seq_len = 4, 20
x = torch.randn(batch_size, seq_len, d_model)

output, weights = encoder_layer(x)
print(f"Encoder layer input: {x.shape}")
print(f"Encoder layer output: {output.shape}")  # Same shape
print(f"Attention weights: {weights.shape}")

Complete Encoder Stack

The full encoder stacks multiple encoder layers sequentially. Input tokens are first embedded and combined with positional encoding, then processed through each layer. The output is a sequence of contextualized vectors, one for each input position, that serve as the input's "encoded" representation for downstream tasks or the decoder.

import math
import torch
import torch.nn as nn

class TransformerEncoder(nn.Module):
    """Complete Transformer encoder stack."""
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff, 
                 num_layers, max_seq_len, dropout=0.1):
        super().__init__()
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)  # Scaling factor from paper
        
        # Positional encoding
        self.positional_encoding = SinusoidalPositionalEncoding(
            d_model, max_seq_len, dropout
        )
        
        # Stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Args:
            x: Input token IDs (batch, seq_len)
            mask: Optional padding mask, shape (batch, 1, 1, seq_len), 1 = real token
        Returns:
            Encoded representations (batch, seq_len, d_model)
        """
        # Embed tokens and scale
        x = self.token_embedding(x) * self.scale
        
        # Add positional encoding
        x = self.positional_encoding(x)
        
        # Process through encoder layers
        all_attention_weights = []
        for layer in self.layers:
            x, attn_weights = layer(x, mask)
            all_attention_weights.append(attn_weights)
        
        return x, all_attention_weights

# Create encoder
vocab_size = 30000
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6
max_seq_len = 512

encoder = TransformerEncoder(
    vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len
)

# Process a batch of sequences
input_ids = torch.randint(0, vocab_size, (4, 50))  # batch=4, seq_len=50
encoded, all_weights = encoder(input_ids)

print(f"Input token IDs: {input_ids.shape}")
print(f"Encoded output: {encoded.shape}")
print(f"Number of layer attention weights: {len(all_weights)}")
print(f"\nTotal parameters: {sum(p.numel() for p in encoder.parameters()):,}")
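The example above calls the encoder without a mask, so every position can attend to every other. When a batch contains padded sequences, a padding mask is usually built from the token IDs. The sketch below assumes a hypothetical padding ID of 0 (the actual value depends on the tokenizer) and returns a mask of shape (batch, 1, 1, seq_len), which broadcasts against the (batch, heads, seq_len, seq_len) scores inside MultiHeadAttention; positions marked 0 are filled with -inf before the softmax.

```python
import torch

PAD_ID = 0  # hypothetical padding token ID; depends on the tokenizer

def make_padding_mask(input_ids, pad_id=PAD_ID):
    """(batch, seq_len) token IDs -> (batch, 1, 1, seq_len) mask, 1 = real token.

    The two singleton dims broadcast over heads and query positions when the
    mask is applied to scores of shape (batch, heads, seq_len_q, seq_len_k).
    """
    return (input_ids != pad_id).unsqueeze(1).unsqueeze(2).int()

input_ids = torch.tensor([[5, 8, 2, 0, 0],
                          [7, 3, 9, 4, 1]])
mask = make_padding_mask(input_ids)
print(mask.shape)     # torch.Size([2, 1, 1, 5])
print(mask[0, 0, 0])  # first sequence: three real tokens, then two padding positions
```

With the encoder defined above, such a mask would be passed as encoder(input_ids, mask=make_padding_mask(input_ids)).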

Encoder Layer Composition

Architecture Components

Each encoder layer contains these components in order:

  1. Multi-Head Self-Attention: Contextualizes each position with all others
  2. Add & Norm: Residual connection + Layer Normalization
  3. Position-wise FFN: Two linear layers with ReLU (expands then contracts)
  4. Add & Norm: Another residual connection + Layer Normalization

The FFN can be thought of as processing each position independently (like a 1x1 convolution), adding non-linear capacity between attention layers.
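Both descriptions of the FFN (applied to each position independently, and equivalent to a 1x1 convolution) can be checked in a few lines. This sketch uses an nn.Sequential FFN with ReLU, omits dropout, and copies its weights into two nn.Conv1d layers with kernel_size=1; the sizes are arbitrary small values chosen for the illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, seq_len, d_model)
whole = ffn(x)  # applied to the full (batch, seq_len, d_model) tensor

# 1) Position-wise: running each position on its own gives the same result
per_position = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)

# 2) Equivalent to two kernel-size-1 convolutions over the sequence axis
conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
with torch.no_grad():
    # Linear weight (out, in) becomes Conv1d weight (out, in, 1)
    conv1.weight.copy_(ffn[0].weight.unsqueeze(-1))
    conv1.bias.copy_(ffn[0].bias)
    conv2.weight.copy_(ffn[2].weight.unsqueeze(-1))
    conv2.bias.copy_(ffn[2].bias)
    # Conv1d expects (batch, channels, length): move d_model to dim 1 and back
    conv_out = conv2(torch.relu(conv1(x.transpose(1, 2)))).transpose(1, 2)

print(torch.allclose(whole, per_position, atol=1e-5))  # True
print(torch.allclose(whole, conv_out, atol=1e-5))      # True
```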

Decoder Architecture

The Transformer decoder generates output sequences autoregressively—one token at a time, conditioning on both the encoder output and previously generated tokens. Each decoder layer has three sub-layers: masked self-attention (to prevent looking at future tokens), cross-attention (to attend to encoder output), and a position-wise feed-forward network. This structure allows the decoder to integrate source information while maintaining autoregressive generation.

The key difference from the encoder is the causal mask in self-attention, which ensures that position i can only attend to positions j ≤ i. During training, we can process entire target sequences in parallel using this mask while feeding the ground-truth previous tokens as input ("teacher forcing"). During inference, we generate one token at a time, feeding each generated token back as input for the next step. Because inference remains sequential, making autoregressive generation efficient (for example, by caching the keys and values of earlier positions instead of recomputing them) is a central concern for decoder-only models such as GPT.
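The inference loop described above can be sketched as greedy decoding. Here next_token_logits is a hypothetical callable standing in for a trained decoder plus output projection, and toy_model is a deliberately trivial stand-in that always predicts "previous token + 1"; a real encoder-decoder model would also take the encoder output as input. For simplicity the loop re-runs the model on the whole prefix at every step, which is exactly the redundant work that key/value caching avoids.

```python
import torch
import torch.nn.functional as F

def greedy_decode(next_token_logits, bos_id, eos_id, max_len):
    """Generate token IDs one at a time, feeding each prediction back as input.

    next_token_logits: hypothetical callable mapping (1, cur_len) token IDs to
    (1, cur_len, vocab_size) logits.
    """
    tokens = [bos_id]
    for _ in range(max_len - 1):
        logits = next_token_logits(torch.tensor([tokens]))
        next_id = int(logits[0, -1].argmax())  # greedy: most likely next token
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens

# Toy stand-in for a trained model: always predicts "previous token + 1"
VOCAB_SIZE = 10

def toy_model(token_ids):
    return F.one_hot((token_ids + 1) % VOCAB_SIZE, VOCAB_SIZE).float()

print(greedy_decode(toy_model, bos_id=0, eos_id=5, max_len=20))  # [0, 1, 2, 3, 4, 5]
```

Greedy decoding always picks the single most likely token; beam search or sampling replace only the argmax line while keeping the same feed-back loop.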

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DecoderLayer(nn.Module):
    """Single Transformer decoder layer with masked self-attention + cross-attention."""
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Masked self-attention (causal)
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Cross-attention (decoder attends to encoder)
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        # Layer normalizations
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        # Dropout layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder input (batch, tgt_seq_len, d_model)
            encoder_output: From encoder (batch, src_seq_len, d_model)
            src_mask: Padding mask for source
            tgt_mask: Causal mask for target
        """
        # 1. Masked self-attention
        self_attn_out, self_attn_weights = self.masked_self_attention(
            x, x, x, mask=tgt_mask
        )
        x = self.norm1(x + self.dropout1(self_attn_out))
        
        # 2. Cross-attention (query from decoder, key/value from encoder)
        cross_attn_out, cross_attn_weights = self.cross_attention(
            x, encoder_output, encoder_output, mask=src_mask
        )
        x = self.norm2(x + self.dropout2(cross_attn_out))
        
        # 3. Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(ff_out))
        
        return x, self_attn_weights, cross_attn_weights

# Example: Single decoder layer
d_model, num_heads, d_ff = 512, 8, 2048
decoder_layer = DecoderLayer(d_model, num_heads, d_ff)

batch_size = 4
src_seq_len = 20  # Source sequence length
tgt_seq_len = 15  # Target sequence length

# Simulated inputs
encoder_output = torch.randn(batch_size, src_seq_len, d_model)
decoder_input = torch.randn(batch_size, tgt_seq_len, d_model)

# Create causal mask for decoder
causal_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len))

output, self_attn, cross_attn = decoder_layer(
    decoder_input, encoder_output, tgt_mask=causal_mask
)

print(f"Decoder input: {decoder_input.shape}")
print(f"Encoder output: {encoder_output.shape}")
print(f"Decoder output: {output.shape}")
print(f"Self-attention weights: {self_attn.shape}")  # (batch, heads, tgt, tgt)
print(f"Cross-attention weights: {cross_attn.shape}")  # (batch, heads, tgt, src)
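The decoder layer above relies on the causal mask to keep position i from seeing later tokens. Below is a minimal way to check that property. It uses a deliberately simplified single-head attention with no learned projections, so `x` serves as queries, keys, and values at once. This is an illustrative stand-in, not the article's `MultiHeadAttention`. The check changes the last token and confirms that none of the earlier outputs move.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(x):
    """Single-head causal self-attention without learned projections (simplified sketch)."""
    seq_len, d = x.shape[-2], x.shape[-1]
    scores = x @ x.transpose(-2, -1) / math.sqrt(d)            # (seq, seq)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float("-inf"))       # block positions j > i
    weights = F.softmax(scores, dim=-1)                        # future weights become exactly 0
    return weights @ x, weights

torch.manual_seed(0)
x = torch.randn(6, 8)               # 6 positions, 8 features
out, weights = causal_attention(x)

# Replace the last token: outputs at earlier positions should not change
x_changed = x.clone()
x_changed[-1] = torch.randn(8)
out_changed, _ = causal_attention(x_changed)
```

This is also why teacher forcing can run in parallel. Every position sees the full ground-truth prefix, yet no position can peek at the token it is supposed to predict.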

Complete Decoder Stack

The full decoder stacks multiple decoder layers, each receiving the encoder output for cross-attention. The final decoder output is projected through a linear layer to vocabulary size, and a softmax turns those logits into next-token probabilities. During generation, a decoding strategy selects the next token. Common choices are greedy decoding (take the most likely token), beam search (keep the k best partial sequences), and sampling methods such as nucleus (top-p) sampling.
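Two of these strategies fit in a few lines. The sketch below applies greedy decoding and nucleus (top-p) sampling to a toy logits tensor. The function names and the 5-token vocabulary are illustrative assumptions. Beam search needs bookkeeping across steps, so it is left out here.

```python
import torch
import torch.nn.functional as F

def greedy_next_token(logits):
    """Greedy decoding: pick the highest-scoring token at each batch row."""
    return logits.argmax(dim=-1)

def nucleus_sample(logits, top_p=0.9, generator=None):
    """Nucleus (top-p) sampling: sample only from the smallest set of tokens
    whose cumulative probability reaches top_p."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the probability mass *before* it already reaches top_p
    # (this rule always keeps the single most likely token)
    remove = (cumulative - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)   # map back to vocabulary IDs

# Toy next-token logits: batch of 2, vocabulary of 5
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0],
                       [0.1, 0.2, 3.0, 0.0, -1.0]])
print(greedy_next_token(logits))   # tensor([0, 2])
gen = torch.Generator().manual_seed(0)
print(nucleus_sample(logits, top_p=0.9, generator=gen))
```

Greedy decoding is deterministic. Nucleus sampling adds diversity but still excludes the low-probability tail, so a very small `top_p` makes it behave like greedy decoding.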

import torch
import torch.nn as nn
import torch.nn.functional as F  # used for F.softmax in the example below
import math

class TransformerDecoder(nn.Module):
    """Complete Transformer decoder stack."""
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff,
                 num_layers, max_seq_len, dropout=0.1):
        super().__init__()
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)
        
        # Positional encoding
        self.positional_encoding = SinusoidalPositionalEncoding(
            d_model, max_seq_len, dropout
        )
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
        # Output projection to vocabulary
        self.output_projection = nn.Linear(d_model, vocab_size)
        
    def generate_causal_mask(self, seq_len, device):
        """Generate causal attention mask."""
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return mask
    
    def forward(self, tgt_ids, encoder_output, src_mask=None):
        """
        Args:
            tgt_ids: Target token IDs (batch, tgt_seq_len)
            encoder_output: From encoder (batch, src_seq_len, d_model)
            src_mask: Source padding mask
        Returns:
            logits: Vocabulary logits (batch, tgt_seq_len, vocab_size)
        """
        tgt_seq_len = tgt_ids.size(1)
        
        # Embed and scale
        x = self.token_embedding(tgt_ids) * self.scale
        x = self.positional_encoding(x)
        
        # Generate causal mask
        tgt_mask = self.generate_causal_mask(tgt_seq_len, x.device)
        
        # Process through decoder layers
        for layer in self.layers:
            x, _, _ = layer(x, encoder_output, src_mask, tgt_mask)
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits

# Create decoder
vocab_size = 30000
decoder = TransformerDecoder(
    vocab_size, d_model=512, num_heads=8, d_ff=2048,
    num_layers=6, max_seq_len=512
)

# Process target sequence
tgt_ids = torch.randint(0, vocab_size, (4, 15))  # batch=4, tgt_len=15
encoder_out = torch.randn(4, 20, 512)  # From encoder

logits = decoder(tgt_ids, encoder_out)
print(f"Target token IDs: {tgt_ids.shape}")
print(f"Output logits: {logits.shape}")  # (4, 15, 30000)
print(f"Next token probabilities: {F.softmax(logits[:, -1, :], dim=-1).shape}")

Cross-Attention: The Bridge

Cross-attention connects encoder and decoder. The decoder's queries attend to the encoder's keys and values, allowing each decoder position to selectively access source information. This is where the decoder "reads" the input—for translation, cross-attention might link a German noun to its English translation; for summarization, it identifies salient source sentences.
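The shapes show what "reading" the source means. The sketch below is single-head and unmasked, with illustrative sizes and freshly initialized projections. Queries come from decoder states and keys and values come from encoder output. The result is a (tgt_len × src_len) weight matrix in which each decoder position holds a probability distribution over source positions. No causal mask is needed here, because the whole source sentence is available at every decoding step.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, src_len, tgt_len, d_model = 2, 7, 4, 16   # illustrative sizes

encoder_output = torch.randn(batch, src_len, d_model)  # from the encoder stack
decoder_hidden = torch.randn(batch, tgt_len, d_model)  # decoder states after masked self-attention

# Queries are projected from the decoder, keys/values from the encoder
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))
q = w_q(decoder_hidden)   # (batch, tgt_len, d_model)
k = w_k(encoder_output)   # (batch, src_len, d_model)
v = w_v(encoder_output)   # (batch, src_len, d_model)

scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)  # (batch, tgt_len, src_len)
weights = F.softmax(scores, dim=-1)    # one distribution over the source per target position
context = weights @ v                  # (batch, tgt_len, d_model): source info per target position
```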

Training Transformers

Training Transformers involves several techniques that differ from traditional neural network training. The original paper used the Adam optimizer with a custom learning rate schedule (warmup followed by decay) and label smoothing for regularization. It also applied dropout to the output of each sub-layer (before the residual addition) and to the sums of the embeddings and positional encodings. These choices remain influential in modern large language model training.

The learning rate schedule is particularly important. It starts from a very small learning rate, increases it linearly over a fixed number of warmup steps, and then decays it in proportion to the inverse square root of the step number. This prevents instability early in training while still allowing efficient convergence. Modern practice often uses cosine annealing or other decay schedules, but warmup remains standard because it stabilizes the early training of deep Transformers.

import torch
import torch.nn as nn
import torch.optim as optim
import math

class TransformerLRScheduler:
    """
    Learning rate scheduler from "Attention Is All You Need".
    
    lrate = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
    """
    
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0
        
    def step(self):
        self.current_step += 1
        lr = self.compute_lr(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    
    def compute_lr(self, step):
        # Formula from the paper
        arg1 = step ** (-0.5)
        arg2 = step * (self.warmup_steps ** (-1.5))
        return (self.d_model ** (-0.5)) * min(arg1, arg2)

# Visualize learning rate schedule
import matplotlib.pyplot as plt

d_model = 512
warmup_steps = 4000
steps = list(range(1, 100000))

# compute_lr never touches the optimizer, so passing None is fine for plotting
schedule = TransformerLRScheduler(None, d_model, warmup_steps)
lrs = [schedule.compute_lr(s) for s in steps]

plt.figure(figsize=(10, 5))
plt.plot(steps, lrs)
plt.xlabel('Training Step')
plt.ylabel('Learning Rate')
plt.title('Transformer Learning Rate Schedule (warmup + inverse sqrt decay)')
plt.axvline(x=warmup_steps, color='r', linestyle='--', label=f'Warmup ends ({warmup_steps})')
plt.legend()
plt.savefig('lr_schedule.png', dpi=150, bbox_inches='tight')
plt.show()
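Many modern setups keep warmup but switch from inverse-square-root decay to cosine annealing. The sketch below builds linear warmup followed by cosine decay with PyTorch's `LambdaLR`. The helper name `warmup_cosine_lambda`, the warmup length, the total steps, the base learning rate, and the 10% floor are all illustrative assumptions, not values from the paper.

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine_lambda(warmup_steps, total_steps, min_ratio=0.1):
    """Hypothetical helper: returns a multiplier on the base LR for each step."""
    def lr_lambda(step):
        if step < warmup_steps:                        # linear warmup to the base LR
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # decays from 1 to 0
        return min_ratio + (1.0 - min_ratio) * cosine          # floor at min_ratio * base LR
    return lr_lambda

model = torch.nn.Linear(8, 8)                          # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = LambdaLR(optimizer, warmup_cosine_lambda(warmup_steps=100, total_steps=1000))

lrs = []
for step in range(1000):
    optimizer.step()            # a real loop would call loss.backward() first
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
```

For comparison, the original schedule peaks at step = warmup_steps, where both arguments of `min` are equal. With d_model = 512 and 4,000 warmup steps, the peak is 512^(-0.5) · 4000^(-0.5) ≈ 7.0 × 10⁻⁴.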

Full Training Loop Example

Here's a complete training loop for a Transformer, including teacher forcing (feeding ground truth tokens during training), cross-entropy loss computation, and gradient clipping to prevent exploding gradients—a common issue in deep networks.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_transformer(model, train_loader, num_epochs, d_model, 
                      warmup_steps=4000, max_grad_norm=1.0,
                      label_smoothing=0.1):
    """
    Complete training loop for encoder-decoder Transformer.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Loss with label smoothing
    criterion = nn.CrossEntropyLoss(
        ignore_index=0,  # Padding token
        label_smoothing=label_smoothing
    )
    
    # Optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr=0,  # Will be set by scheduler
        betas=(0.9, 0.98),
        eps=1e-9
    )
    
    # Learning rate scheduler
    scheduler = TransformerLRScheduler(optimizer, d_model, warmup_steps)
    
    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        
        for batch_idx, (src, tgt) in enumerate(train_loader):
            src = src.to(device)
            tgt = tgt.to(device)
            
            # Target input (shifted right) and output
            tgt_input = tgt[:, :-1]  # Remove last token
            tgt_output = tgt[:, 1:]  # Remove first token (usually the start-of-sequence token)
            
            # Forward pass
            optimizer.zero_grad()
            logits = model(src, tgt_input)  # (batch, seq_len, vocab_size)
            
            # Compute loss
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),  # (batch*seq, vocab)
                tgt_output.reshape(-1)                 # (batch*seq,)
            )
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            clip_grad_norm_(model.parameters(), max_grad_norm)
            
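            # Set this step's learning rate first: the optimizer was created with
            # lr=0, so stepping it before the scheduler would waste the first update
            lr = scheduler.step()
            optimizer.step()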
            
            total_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, "
                      f"Loss: {loss.item():.4f}, LR: {lr:.6f}")
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} complete. Average loss: {avg_loss:.4f}")

# Note: This is a training template - you'd need actual data loaders
# and a complete model to run this
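The one-token shift between `tgt_input` and `tgt_output` is easy to get wrong, so here it is on a toy batch. The special-token IDs (`PAD = 0`, `BOS = 1`, `EOS = 2`) are hypothetical, chosen so that padding matches `ignore_index=0` in the loop above. The second row also shows that padded label positions contribute zero loss.

```python
import torch
import torch.nn as nn

PAD, BOS, EOS = 0, 1, 2   # hypothetical special-token IDs

# Two target sentences; the second is padded to the batch length
tgt = torch.tensor([[BOS, 11, 12, 13, EOS],
                    [BOS, 21, 22, EOS, PAD]])

tgt_input = tgt[:, :-1]    # what the decoder sees:   BOS w1 w2 ...
tgt_output = tgt[:, 1:]    # what it must predict:    w1 w2 ... EOS

# At each position t, the label is the token that follows tgt_input[:, t]
print(tgt_input[1].tolist())    # [1, 21, 22, 2]
print(tgt_output[1].tolist())   # [21, 22, 2, 0]

# Padding labels are skipped via ignore_index, so they add no loss
vocab_size = 30
logits = torch.randn(2, 4, vocab_size)            # stand-in for model output
criterion = nn.CrossEntropyLoss(ignore_index=PAD, reduction="none")
per_token = criterion(logits.reshape(-1, vocab_size), tgt_output.reshape(-1)).view(2, 4)
```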

Training Hyperparameters from the Paper

Original Config (Vaswani et al., 2017)
Setting            Value
Optimizer          Adam with β1 = 0.9, β2 = 0.98, ε = 10⁻⁹
Warmup steps       4,000
Dropout            0.1 for the base model (sub-layer outputs; embedding + positional-encoding sums)
Label smoothing    ε_ls = 0.1
Batch size         ~25,000 source tokens and ~25,000 target tokens
Training steps     100,000 (base), 300,000 (big)
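To see what ε_ls = 0.1 does to the training target, the sketch below builds the smoothed distribution by hand and compares the result with PyTorch's built-in `label_smoothing` argument. It uses PyTorch's convention: (1 − ε) on the true class plus ε/K spread uniformly over all K classes. Some implementations spread ε over the K − 1 incorrect classes instead, so treat this as one common variant.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps, num_classes = 0.1, 5
logits = torch.randn(3, num_classes)
targets = torch.tensor([0, 3, 4])

# Smoothed targets: 1 - eps + eps/K on the true class, eps/K everywhere else
one_hot = F.one_hot(targets, num_classes).float()
smoothed = (1 - eps) * one_hot + eps / num_classes

# Cross-entropy against the smoothed distribution, averaged over the batch
manual = -(smoothed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
builtin = F.cross_entropy(logits, targets, label_smoothing=eps)
```

The paper notes that this hurts perplexity, because the model learns to be less certain, but improves accuracy and BLEU.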

Transformer Variants

Since the original Transformer, researchers have developed numerous variants optimizing for different use cases. Encoder-only models like BERT excel at understanding tasks (classification, NER, QA). Decoder-only models like GPT dominate generation tasks (text completion, dialogue). Encoder-decoder models like T5 and BART handle sequence-to-sequence tasks (translation, summarization). Each architecture makes different tradeoffs between bidirectional understanding and autoregressive generation.

Beyond architectural choices, variants address the O(n²) attention complexity that limits sequence length. Sparse attention (Longformer, BigBird) uses local + global patterns. Linear attention (Performer, Linear Transformer) approximates attention in O(n) time. Hierarchical approaches (Hierarchical Transformers) process documents at multiple scales. These innovations enable processing of very long documents that would be prohibitive with standard attention.
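The "local + global" pattern is easiest to see as a mask. The sketch below builds a boolean mask that combines a sliding window of ±`window` positions with a few global tokens, in the spirit of Longformer/BigBird. The function `local_global_mask` is hypothetical. Real implementations use custom kernels so that they never materialize the full n × n matrix. This dense version only illustrates which pairs are kept.

```python
import torch

def local_global_mask(seq_len, window, global_positions=()):
    """Hypothetical helper: True = query i may attend to key j."""
    idx = torch.arange(seq_len)
    mask = (idx[:, None] - idx[None, :]).abs() <= window   # local band of width 2*window + 1
    for g in global_positions:
        mask[g, :] = True    # a global token attends to every position
        mask[:, g] = True    # and every position attends to it
    return mask

mask = local_global_mask(seq_len=12, window=2, global_positions=[0])
print(mask.int())
dense_pairs = 12 * 12
sparse_pairs = int(mask.sum())   # grows roughly linearly with seq_len for a fixed window
```

For a fixed window and a fixed number of global tokens, the number of kept pairs grows linearly with sequence length instead of quadratically.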

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class EncoderOnlyTransformer(nn.Module):
    """
    BERT-style encoder-only Transformer for classification/understanding.
    Uses [CLS] token representation for sentence-level tasks.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff,
                 num_layers, num_classes, max_seq_len=512, dropout=0.1):
        super().__init__()
        
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len, dropout
        )
        
        # Classification head on [CLS] token (position 0)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )
        
    def forward(self, input_ids, mask=None):
        # Encode entire sequence
        encoded, _ = self.encoder(input_ids, mask)
        
        # Use [CLS] token representation (position 0)
        cls_representation = encoded[:, 0, :]  # (batch, d_model)
        
        # Classify
        logits = self.classifier(cls_representation)  # (batch, num_classes)
        return logits

# Example: Sentiment classification
bert_style = EncoderOnlyTransformer(
    vocab_size=30000, d_model=768, num_heads=12, d_ff=3072,
    num_layers=12, num_classes=3  # negative, neutral, positive
)

input_ids = torch.randint(0, 30000, (4, 128))  # batch=4, seq_len=128
logits = bert_style(input_ids)
print(f"Classification logits: {logits.shape}")  # (4, 3)

Decoder-Only Transformer (GPT-style)

Decoder-only models use causal self-attention throughout and are trained with a language modeling objective (predict the next token). This simple architecture scales remarkably well and has become the dominant paradigm for large language models. GPT-3, GPT-4, LLaMA, and other modern LLMs all use this architecture.

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax in generate()
import math

class DecoderOnlyTransformer(nn.Module):
    """
    GPT-style decoder-only Transformer for text generation.
    Uses causal masking throughout - no encoder needed.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff,
                 num_layers, max_seq_len=1024, dropout=0.1):
        super().__init__()
        
        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.scale = math.sqrt(d_model)
        
        # Causal self-attention layers (no cross-attention)
        self.layers = nn.ModuleList([
            CausalDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Output head (often tied with input embedding weights)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying (optional but common)
        self.lm_head.weight = self.token_embedding.weight
        
    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape
        
        # Embeddings
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_embedding(input_ids) * self.scale
        x = x + self.position_embedding(positions)
        x = self.dropout(x)
        
        # Process through layers with causal masking
        for layer in self.layers:
            x = layer(x)
        
        x = self.norm(x)
        
        # Project to vocabulary
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)
        
        return logits
    
    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=1.0):
        """Autoregressive generation."""
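        max_len = self.position_embedding.num_embeddings
        for _ in range(max_new_tokens):
            # Keep at most max_len tokens of context (the learned position table
            # has no rows beyond that), then take the logits at the last position
            logits = self(input_ids[:, -max_len:])[:, -1, :]  # (batch, vocab)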
            
            # Apply temperature
            logits = logits / temperature
            
            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)
        
        return input_ids

class CausalDecoderLayer(nn.Module):
    """Decoder layer with only causal self-attention (no cross-attention)."""
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        seq_len = x.size(1)
        # Causal mask
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        
        # Self-attention with causal masking
        attn_out, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        
        return x

# Example: GPT-style generation
gpt_style = DecoderOnlyTransformer(
    vocab_size=50000, d_model=768, num_heads=12, d_ff=3072, num_layers=12
)

gpt_style.eval()  # disable dropout so generation is not perturbed by random masking
prompt = torch.randint(0, 50000, (1, 10))  # Starting prompt
generated = gpt_style.generate(prompt, max_new_tokens=20, temperature=0.8)
print(f"Prompt length: 10, Generated length: {generated.shape[1]}")
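The class above shows generation but not the training objective. For next-token prediction, the logits at position t are scored against the token at position t + 1. The sketch below shows this shift. It uses a random logits tensor as a stand-in for `gpt_style(input_ids)`, and `next_token_loss` is a hypothetical helper name. It also includes a sanity check: uniform logits should give a loss of ln(vocab_size).

```python
import math
import torch
import torch.nn.functional as F

def next_token_loss(logits, input_ids):
    """Causal LM loss: logits at position t predict input_ids at t + 1.
    logits: (batch, seq_len, vocab), input_ids: (batch, seq_len)."""
    shift_logits = logits[:, :-1, :]   # predictions for positions 0 .. T-2
    shift_labels = input_ids[:, 1:]    # the tokens that actually come next
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )

vocab_size = 50
input_ids = torch.randint(0, vocab_size, (2, 8))
logits = torch.randn(2, 8, vocab_size)       # stand-in for gpt_style(input_ids)
loss = next_token_loss(logits, input_ids)

# Uniform predictions should cost exactly ln(vocab_size) nats per token
uniform_loss = next_token_loss(torch.zeros(2, 8, vocab_size), input_ids)
expected = math.log(vocab_size)
```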

Comparison of Transformer Variants

Architecture Use Cases
Type               Examples                Best For                     Attention
Encoder-only       BERT, RoBERTa, ALBERT   Classification, NER, QA      Bidirectional
Decoder-only       GPT, LLaMA, PaLM        Generation, Completion       Causal (left-to-right)
Encoder-Decoder    T5, BART, mT5           Translation, Summarization   Bidirectional + Causal

Complete Encoder-Decoder Transformer

Finally, here's a complete implementation combining encoder and decoder into the full "Attention Is All You Need" architecture, suitable for machine translation and other sequence-to-sequence tasks.

import torch
import torch.nn as nn
import math

class Transformer(nn.Module):
    """
    Complete Transformer (Encoder-Decoder) from "Attention Is All You Need".
    Suitable for machine translation and other seq2seq tasks.
    """
    
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, d_ff=2048, num_encoder_layers=6,
                 num_decoder_layers=6, max_seq_len=512, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Source embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.src_pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        
        # Target embeddings
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.tgt_pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        
        # Encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def encode(self, src, src_mask=None):
        """Encode source sequence."""
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.src_pos_encoding(x)
        
        for layer in self.encoder_layers:
            x, _ = layer(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """Decode target sequence given encoder output."""
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.tgt_pos_encoding(x)
        
        for layer in self.decoder_layers:
            x, _, _ = layer(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Full forward pass for training.
        
        Args:
            src: Source token IDs (batch, src_len)
            tgt: Target token IDs (batch, tgt_len)
        Returns:
            logits: (batch, tgt_len, tgt_vocab_size)
        """
        # Generate causal mask for target
        tgt_seq_len = tgt.size(1)
        if tgt_mask is None:
            tgt_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt.device))
        
        # Encode source
        encoder_output = self.encode(src, src_mask)
        
        # Decode target
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        
        # Project to vocabulary
        logits = self.output_projection(decoder_output)
        
        return logits

# Example: Machine Translation Transformer
transformer = Transformer(
    src_vocab_size=32000,  # German vocabulary
    tgt_vocab_size=32000,  # English vocabulary
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6
)

# Simulated translation batch
src = torch.randint(0, 32000, (4, 30))  # German sentences
tgt = torch.randint(0, 32000, (4, 25))  # English sentences (shifted)

logits = transformer(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")

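# Count parameters. The paper's base model has about 65M because it shares one weight
# matrix across both embeddings and the output projection; this version keeps them
# separate, so its count is noticeably larger.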
total_params = sum(p.numel() for p in transformer.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"~{total_params / 1e6:.1f}M parameters")
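The example above passes `src_mask=None`. In a real batch, padded source positions should be masked so that neither the encoder nor cross-attention attends to them. The sketch below is self-contained and uses a hypothetical padding ID of 0. It builds a (batch, 1, 1, src_len) mask that broadcasts over heads and query positions, then checks that padded keys receive zero attention. Whether the article's `MultiHeadAttention` accepts this shape and boolean convention depends on its implementation. The sketch assumes the common convention that positions where the mask is False/0 are blocked.

```python
import math
import torch
import torch.nn.functional as F

PAD = 0   # hypothetical padding ID

src = torch.tensor([[5, 6, 7, 8],
                    [9, 4, PAD, PAD]])               # (batch, src_len)

# True where a real token exists; shape (batch, 1, 1, src_len) broadcasts
# against attention scores of shape (batch, heads, q_len, src_len)
src_mask = (src != PAD).unsqueeze(1).unsqueeze(2)

torch.manual_seed(0)
batch, heads, q_len, d_k = 2, 2, 3, 8
q = torch.randn(batch, heads, q_len, d_k)
k = torch.randn(batch, heads, src.size(1), d_k)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)     # (batch, heads, q_len, src_len)
scores = scores.masked_fill(~src_mask, float("-inf"))  # padded keys get -inf
weights = F.softmax(scores, dim=-1)                    # and therefore exactly 0 weight
```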

Scaling Laws

Research has shown that Transformer language-model performance scales predictably with compute, data, and parameters. Kaplan et al. (2020) found that test loss falls as a power law in each of these quantities, as long as the other two are not bottlenecks. Larger models are also more sample-efficient, reaching the same loss with fewer training tokens. This insight drove the development of GPT-3 (175B), PaLM (540B), and other massive models. The key is balancing model size, dataset size, and compute budget according to these scaling laws.
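To make "power law" concrete, the sketch below evaluates the parameter-count law L(N) = (N_c / N)^α_N. It uses the fitted constants reported by Kaplan et al. (α_N ≈ 0.076, N_c ≈ 8.8 × 10¹³ non-embedding parameters), which apply to their specific dataset and tokenizer. The absolute loss values are therefore illustrative only. The useful takeaway is the shape: every 10× increase in N multiplies the predicted loss by the same constant factor.

```python
def kaplan_loss_vs_params(n_params, n_c=8.8e13, alpha_n=0.076):
    """Predicted test loss (nats/token) for a model with n_params non-embedding
    parameters, trained to convergence on ample data: L(N) = (N_c / N) ** alpha_N.
    Constants are Kaplan et al.'s fits for their setup; treat outputs as illustrative."""
    return (n_c / n_params) ** alpha_n

for n in [1e8, 1e9, 1e10, 1e11]:
    print(f"N = {n:.0e}: predicted loss ~ {kaplan_loss_vs_params(n):.3f}")

# A 10x larger model multiplies the predicted loss by 10 ** -alpha_N (about 0.84)
ratio = kaplan_loss_vs_params(1e10) / kaplan_loss_vs_params(1e9)
```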

Conclusion & Next Steps

The Transformer architecture fundamentally changed natural language processing by replacing recurrence with self-attention. This simple but powerful idea—that sequences can be processed in parallel by allowing every position to attend to every other position—has proven remarkably effective. From the original encoder-decoder model for translation to BERT's bidirectional encoding and GPT's autoregressive generation, Transformers have become the foundation of modern NLP.

Key concepts to remember: scaled dot-product attention computes compatibility between queries and keys to weight values; multi-head attention enables learning multiple relationship types simultaneously; positional encoding injects sequence order information; the encoder builds bidirectional representations while the decoder generates outputs autoregressively with causal masking. These components combine into a highly effective architecture that scales well with data and compute.

Understanding Transformers deeply is essential for working with modern NLP. The architecture underlies virtually every state-of-the-art model, and design choices in attention, normalization, and training have cascading effects on model behavior. As you continue through this series, we'll explore how pretraining and fine-tuning leverage this architecture in BERT, GPT, and beyond.

Key Takeaways

Summary Part 8
  • Attention replaces recurrence: O(1) path length between any positions vs. O(n) for RNNs
  • Self-attention: Each position attends to all positions in the same sequence
  • Multi-head attention: Multiple parallel attention heads capture different relationships
  • Positional encoding: Sinusoidal or learned embeddings inject position information
  • Encoder: Stack of self-attention + FFN layers for bidirectional encoding
  • Decoder: Masked self-attention + cross-attention + FFN for generation
  • Training tricks: Warmup learning rate, label smoothing, dropout throughout
  • Variants: Encoder-only (BERT), decoder-only (GPT), encoder-decoder (T5)

What's Next

In Part 9: Pretrained Language Models & Transfer Learning, we'll explore how pretraining on massive text corpora and fine-tuning on specific tasks has revolutionized NLP. You'll learn about BERT's masked language modeling, pretraining objectives, fine-tuning strategies, and how transfer learning enables powerful NLP with limited labeled data.

Practice Exercises

To solidify your understanding of Transformers:

  1. Implement attention visualization: Train a small Transformer and visualize what different heads attend to
  2. Compare architectures: Implement encoder-only, decoder-only, and encoder-decoder variants; compare on appropriate tasks
  3. Experiment with positional encoding: Try learned vs. sinusoidal encodings; what happens without any positional encoding?
  4. Build a mini-GPT: Train a small character-level language model using the decoder-only architecture
  5. Study scaling: How does performance change as you vary d_model, num_heads, num_layers?

Further Reading
