Part 8: Transformers & Attention

May 3, 2026 Wasil Zafar 30 min read

Beyond recurrence — master the architecture that revolutionized deep learning. Implement scaled dot-product attention, multi-head attention, positional encoding, and full Transformer encoder-decoder models from scratch. Apply Transformers to text classification, use pretrained models, and build a Vision Transformer.

Table of Contents

  1. Why Attention?
  2. Scaled Dot-Product Attention
  3. Multi-Head Attention
  4. Positional Encoding
  5. Transformer Encoder Block
  6. Transformer Decoder Block
  7. Full Transformer Model
  8. Text Classification
  9. Pretrained Transformers
  10. Vision Transformers (ViT)

Why Attention?

Recurrent neural networks process sequences one token at a time, maintaining a hidden state that compresses everything seen so far into a fixed-size vector. For short sequences this works well, but for long sequences (hundreds or thousands of tokens), critical information from early timesteps gets washed out — the infamous information bottleneck. Even LSTM and GRU architectures, designed to mitigate vanishing gradients, struggle to reliably connect tokens separated by 50+ positions.

The attention mechanism solves this by allowing every position in a sequence to directly attend to every other position, regardless of distance. Instead of relying on a single hidden state to carry information forward, attention computes a weighted combination of all positions — giving the model a direct "highway" to any relevant context. This is conceptually similar to a soft lookup table: given a query, find the most relevant keys and retrieve their associated values.
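To make the "soft lookup table" analogy concrete, here is a minimal NumPy sketch with three made-up key/value pairs (the numbers are purely illustrative). A hard lookup returns the value of the single best-matching key. Attention instead returns a softmax-weighted blend of all values, so the best match gets the largest weight but the other entries still contribute:

```python
import numpy as np

# Illustrative keys and values (hypothetical numbers): 3 entries, 2-D keys
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [0.7, 0.7]])
values = np.array([[10.0],
                   [20.0],
                   [30.0]])
query = np.array([1.0, 0.1])

# Compatibility of the query with every key (dot product)
scores = keys @ query

# Hard lookup: the value of the single best-matching key
hard = values[np.argmax(scores)]

# Soft lookup (attention): softmax over the scores, then a weighted blend of values
weights = np.exp(scores - scores.max())
weights /= weights.sum()
soft = weights @ values

print("scores:", scores.round(3))
print("weights:", weights.round(3))
print("hard lookup:", hard, " soft lookup:", soft.round(3))
```

Because the weights are a smooth function of the scores, the soft version is differentiable, which is what allows the projections that produce queries and keys to be learned by gradient descent.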

Key Insight: Attention replaces sequential processing with parallel content-based addressing. Each token asks "what parts of the input are most relevant to me?" and retrieves a weighted mixture of the corresponding values, and this is computed for all tokens at once with a few matrix multiplications. This gives an $O(1)$ path length between any two positions (vs $O(n)$ for RNNs).

Self-Attention vs Cross-Attention

Self-attention (intra-attention) relates positions within the same sequence. Each token in a sentence attends to all other tokens in that same sentence to build contextual representations. Cross-attention connects two different sequences — for example, a decoder token attending to all encoder outputs (like in machine translation). The Transformer uses both: self-attention within encoder/decoder blocks, and cross-attention where the decoder queries the encoder.

import tensorflow as tf
import numpy as np

# Demonstrating why RNNs struggle with long-range dependencies
# vs how attention provides direct access

# Simulate RNN information flow: signal decays over distance
np.random.seed(42)
seq_length = 100
decay_factor = 0.95  # Per-step retention

# Information retained from position 0 at each subsequent position
retention = np.array([decay_factor**t for t in range(seq_length)])
print("RNN information retention from position 0:")
print(f"  After 10 steps: {retention[10]:.4f} ({retention[10]*100:.1f}%)")
print(f"  After 50 steps: {retention[50]:.4f} ({retention[50]*100:.1f}%)")
print(f"  After 99 steps: {retention[99]:.6f} ({retention[99]*100:.2f}%)")

# Attention: direct access regardless of distance
# Every position can attend to every other with equal ease
print("\nAttention mechanism: direct access")
print(f"  Access from position 99 to position 0: O(1) - full strength")
print(f"  Access from position 99 to position 50: O(1) - full strength")
print("  No distance-based decay: attention weights are computed per input")

# Self-attention vs Cross-attention
# Self: Q, K, V all come from the SAME sequence
encoder_input = tf.random.normal([1, 6, 64])  # (batch, seq_len, d_model)
print(f"\nSelf-attention: Q, K, V from same input {encoder_input.shape}")
print("  Each of the 6 tokens attends to all 6 tokens (including itself)")

# Cross: Q from one sequence, K/V from another
decoder_input = tf.random.normal([1, 4, 64])  # Decoder queries
print(f"\nCross-attention: Q from decoder {decoder_input.shape}")
print(f"  K, V from encoder {encoder_input.shape}")
print("  Each of 4 decoder tokens attends to all 6 encoder outputs")

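The key advantage of attention is parallelizability. While RNNs must process tokens sequentially (each step depends on the previous hidden state), attention computes relationships between all pairs of positions simultaneously via matrix operations. This maps well onto GPUs and TPUs and makes Transformers much faster to train than RNNs on long sequences. The trade-off is that the attention score matrix has one entry per pair of positions, so its compute and memory cost grows quadratically with sequence length.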
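A small sketch of why this parallelizes: every pairwise score is independent of the others, so a double loop over positions and a single matrix multiplication produce the same matrix. For simplicity it scores a random sequence against itself, without the learned projections or the scaling introduced in the next section:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
seq_len, d = 6, 16
x = tf.random.normal([seq_len, d])  # one sequence, no batch dimension

# Loop view: each pairwise score is its own dot product.
# No score depends on another, so nothing forces this order.
loop_scores = np.zeros((seq_len, seq_len), dtype=np.float32)
for i in range(seq_len):
    for j in range(seq_len):
        loop_scores[i, j] = tf.reduce_sum(x[i] * x[j]).numpy()

# Parallel view: all seq_len x seq_len scores in one matrix multiplication
matmul_scores = tf.matmul(x, x, transpose_b=True).numpy()

print("max difference:", np.abs(loop_scores - matmul_scores).max())
print("score matrix entries:", matmul_scores.size)  # grows as seq_len ** 2
```

An RNN has no equivalent shortcut: step $t$ needs the hidden state from step $t-1$, so its loop cannot be collapsed into one operation.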

Scaled Dot-Product Attention

The core computation of the Transformer is scaled dot-product attention. Given three matrices — Queries ($Q$), Keys ($K$), and Values ($V$) — we compute attention as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Here's the intuition: each query vector "asks a question," each key vector "advertises what it contains," and the dot product $QK^T$ measures compatibility. We scale by $\sqrt{d_k}$ to prevent the dot products from becoming too large (which would push softmax into regions with vanishing gradients). After softmax normalization, we get attention weights that sum to 1, which we use to compute a weighted average of the value vectors.

Why Scale? If the components of $q$ and $k$ are independent with zero mean and unit variance, the dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$, so its typical magnitude grows as $\sqrt{d_k}$. Without scaling, softmax would produce extremely peaked distributions (near one-hot) for large $d_k$, leading to vanishing gradients. Dividing by $\sqrt{d_k}$ brings the variance back to about 1 regardless of dimensionality.
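A quick numerical check of this argument, under the same assumption of i.i.d. zero-mean, unit-variance components. The sketch measures the variance of raw and scaled dot products for several $d_k$, plus the average largest softmax weight when groups of 8 scores are treated as one attention row (the group size is an arbitrary choice):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_keys = 4000, 8   # n_keys: arbitrary size of one softmax row

def mean_peak(scores):
    # Group consecutive scores into rows of n_keys, softmax each row,
    # and return the average of the largest weight per row
    rows = scores.reshape(-1, n_keys)
    e = np.exp(rows - rows.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).max(axis=1).mean()

results = {}
for d_k in [4, 64, 512]:
    # Assumption from the text: i.i.d. components with zero mean, unit variance
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    raw = np.sum(q * k, axis=1)          # raw scores q . k
    scaled = raw / np.sqrt(d_k)          # scaled scores
    results[d_k] = (raw.var(), scaled.var(), mean_peak(raw), mean_peak(scaled))
    print(f"d_k={d_k:3d}  var raw={results[d_k][0]:7.1f}  var scaled={results[d_k][1]:.2f}  "
          f"peak raw={results[d_k][2]:.2f}  peak scaled={results[d_k][3]:.2f}")
```

For large $d_k$ the raw scores give a nearly one-hot softmax, while the scaled scores keep the distribution soft.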

Implementing from Scratch

The function below implements the formula step by step. Each code example in this article is self-contained and can be run independently:

import tensorflow as tf
import numpy as np

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        query: shape (..., seq_len_q, d_k)
        key: shape (..., seq_len_k, d_k)
        value: shape (..., seq_len_k, d_v)
        mask: optional mask shape (..., seq_len_q, seq_len_k)

    Returns:
        output: weighted values, shape (..., seq_len_q, d_v)
        attention_weights: shape (..., seq_len_q, seq_len_k)
    """
    # Step 1: Compute dot products between queries and keys
    # matmul(Q, K^T) → shape (..., seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(query, key, transpose_b=True)

    # Step 2: Scale by sqrt(d_k)
    d_k = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_scores = matmul_qk / tf.math.sqrt(d_k)

    # Step 3: Apply mask (set masked positions to -infinity before softmax)
    if mask is not None:
        scaled_scores += (mask * -1e9)

    # Step 4: Softmax normalization (along key dimension)
    attention_weights = tf.nn.softmax(scaled_scores, axis=-1)

    # Step 5: Weighted sum of values
    output = tf.matmul(attention_weights, value)

    return output, attention_weights


# Example: 3 tokens, d_k=4
np.random.seed(42)
seq_len, d_k = 3, 4

# Create Q, K, V (in practice, these come from linear projections)
Q = tf.constant(np.random.randn(1, seq_len, d_k), dtype=tf.float32)
K = tf.constant(np.random.randn(1, seq_len, d_k), dtype=tf.float32)
V = tf.constant(np.random.randn(1, seq_len, d_k), dtype=tf.float32)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Query shape:", Q.shape)
print("Key shape:", K.shape)
print("Value shape:", V.shape)
print("\nAttention weights (each row sums to 1):")
print(weights.numpy()[0].round(3))
print("\nOutput shape:", output.shape)
print("Output (weighted sum of values):")
print(output.numpy()[0].round(3))

# Verify weights sum to 1
print(f"\nRow sums: {tf.reduce_sum(weights, axis=-1).numpy()[0].round(6)}")

Notice that each row of the attention weights sums to 1 — each query distributes its "attention budget" across all keys. The output for each query position is a convex combination of the value vectors, weighted by how relevant each key is to that query.

Now let's add masking — essential for preventing the decoder from attending to future tokens during training (causal/look-ahead mask) and for ignoring padding tokens:

import tensorflow as tf
import numpy as np

def scaled_dot_product_attention(query, key, value, mask=None):
    """Scaled dot-product attention with optional mask."""
    matmul_qk = tf.matmul(query, key, transpose_b=True)
    d_k = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_scores = matmul_qk / tf.math.sqrt(d_k)
    if mask is not None:
        scaled_scores += (mask * -1e9)
    attention_weights = tf.nn.softmax(scaled_scores, axis=-1)
    output = tf.matmul(attention_weights, value)
    return output, attention_weights

def create_causal_mask(size):
    """Create upper-triangular mask (1s above diagonal → masked)."""
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # shape: (size, size)

def create_padding_mask(seq):
    """Create mask for padding tokens (0s in input)."""
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]  # (batch, 1, 1, seq_len)

# Causal mask: prevents attending to future positions
causal = create_causal_mask(5)
print("Causal mask (1 = masked, 0 = visible):")
print(causal.numpy().astype(int))

# Padding mask: token IDs with 0 = padding
token_ids = tf.constant([[7, 12, 3, 0, 0]])  # Last 2 are padding
pad_mask = create_padding_mask(token_ids)
print(f"\nToken IDs: {token_ids.numpy()[0]}")
print(f"Padding mask: {pad_mask.numpy()[0, 0, 0].astype(int)}")

# Attention with causal mask — decoder can only see past tokens
tf.random.set_seed(42)  # seeds TF's RNG; np.random.seed would not affect tf.random.normal
seq_len, d_k = 5, 8
Q = tf.random.normal([1, seq_len, d_k])
K = tf.random.normal([1, seq_len, d_k])
V = tf.random.normal([1, seq_len, d_k])

output, weights = scaled_dot_product_attention(Q, K, V, mask=causal)
print("\nCausal attention weights (lower triangular):")
print(weights.numpy()[0].round(3))

The causal mask ensures position $i$ can only attend to positions $\leq i$. This is critical for autoregressive generation — the model cannot "cheat" by looking at the answer while predicting it.
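One way to sanity-check a causal mask, sketched below with the same additive convention (1 = masked): change only the last token and confirm that the outputs at every earlier position stay the same. The attention function is repeated so the snippet runs on its own:

```python
import tensorflow as tf

def scaled_dot_product_attention(query, key, value, mask=None):
    # Same additive-mask convention as above: 1 = masked, 0 = visible
    scores = tf.matmul(query, key, transpose_b=True)
    scores /= tf.math.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))
    if mask is not None:
        scores += mask * -1e9
    weights = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(weights, value), weights

tf.random.set_seed(0)
seq_len, d_k = 5, 8
causal = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

x = tf.random.normal([1, seq_len, d_k])
out_a, _ = scaled_dot_product_attention(x, x, x, mask=causal)

# Replace only the LAST token with new random values
x_changed = tf.concat([x[:, :-1], tf.random.normal([1, 1, d_k])], axis=1)
out_b, _ = scaled_dot_product_attention(x_changed, x_changed, x_changed, mask=causal)

# Positions 0..3 never attend to position 4, so their outputs should not change
diff = tf.reduce_max(tf.abs(out_a - out_b), axis=-1)[0]
print("max change per position:", diff.numpy().round(5))
```

Only the last position should show a non-zero change; if an earlier position changes, future information is leaking through the mask.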

Multi-Head Attention

A single attention function captures one type of relationship between positions. Multi-head attention runs $h$ parallel attention operations (heads), each with different learned linear projections. This allows the model to jointly attend to information from different representation subspaces — one head might focus on syntactic relationships, another on semantic similarity, another on positional proximity.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
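In implementations, the $h$ per-head matrices $W_i^Q$ (each $d_{model} \times d_k$) are usually stored side by side as one $d_{model} \times d_{model}$ matrix, and the projected result is reshaped into heads. This NumPy sketch, with arbitrary small sizes, checks that the two views give identical per-head queries:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, h = 2, 5, 16, 4
d_k = d_model // h

x = rng.standard_normal((batch, seq_len, d_model))

# h separate per-head projection matrices W_i^Q, each (d_model, d_k)
per_head_W = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
heads_separate = np.stack([x @ W for W in per_head_W], axis=1)  # (batch, h, seq, d_k)

# One (d_model, d_model) matrix: the per-head matrices placed side by side
W_big = np.concatenate(per_head_W, axis=1)
projected = x @ W_big                                            # (batch, seq, d_model)
heads_split = projected.reshape(batch, seq_len, h, d_k).transpose(0, 2, 1, 3)

print("shapes:", heads_separate.shape, heads_split.shape)
print("identical:", np.allclose(heads_separate, heads_split))
```

This equivalence is why the from-scratch layer below uses a single Dense(d_model) per input followed by a reshape into heads.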

Multi-Head Attention Architecture
flowchart TD
    subgraph Input["Input Projections"]
        Q["Q"] --> PQ1["W_Q^1"]
        Q --> PQ2["W_Q^2"]
        Q --> PQh["W_Q^h"]
        K["K"] --> PK1["W_K^1"]
        K --> PK2["W_K^2"]
        K --> PKh["W_K^h"]
        V["V"] --> PV1["W_V^1"]
        V --> PV2["W_V^2"]
        V --> PVh["W_V^h"]
    end

    subgraph Heads["Parallel Attention Heads"]
        PQ1 --> H1["Head 1<br/>Attention"]
        PK1 --> H1
        PV1 --> H1
        PQ2 --> H2["Head 2<br/>Attention"]
        PK2 --> H2
        PV2 --> H2
        PQh --> Hh["Head h<br/>Attention"]
        PKh --> Hh
        PVh --> Hh
    end

    H1 --> CAT["Concat"]
    H2 --> CAT
    Hh --> CAT
    CAT --> WO["W_O (Output Projection)"]
    WO --> OUT["Multi-Head Output"]

Manual Implementation & Keras API

Below, we first build multi-head attention from scratch as a Keras layer and then compare it with the built-in tf.keras.layers.MultiHeadAttention. Each code example is self-contained and can be run independently:

import tensorflow as tf
import numpy as np

class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention implemented from scratch."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # d_k per head

        # Linear projections for Q, K, V and output
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.wo = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split last dim into (num_heads, depth) and transpose."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # (batch, heads, seq, depth)

    def call(self, query, key, value, mask=None):
        batch_size = tf.shape(query)[0]

        # Linear projections
        q = self.wq(query)  # (batch, seq_q, d_model)
        k = self.wk(key)
        v = self.wv(value)

        # Split into heads
        q = self.split_heads(q, batch_size)  # (batch, heads, seq_q, depth)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled dot-product attention per head
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        d_k = tf.cast(self.depth, tf.float32)
        scaled = matmul_qk / tf.math.sqrt(d_k)

        if mask is not None:
            scaled += (mask * -1e9)

        weights = tf.nn.softmax(scaled, axis=-1)
        attn_output = tf.matmul(weights, v)  # (batch, heads, seq_q, depth)

        # Concatenate heads
        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
        concat = tf.reshape(attn_output, (batch_size, -1, self.d_model))

        # Final linear projection
        output = self.wo(concat)
        return output, weights


# Test our implementation
d_model, num_heads = 64, 8
mha = MultiHeadAttention(d_model, num_heads)

# Input: batch=2, seq_len=10, d_model=64
x = tf.random.normal([2, 10, d_model])
output, attn_weights = mha(x, x, x)  # Self-attention

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  (batch={x.shape[0]}, heads={num_heads}, seq_q={x.shape[1]}, seq_k={x.shape[1]})")

# Compare with Keras built-in MultiHeadAttention
keras_mha = tf.keras.layers.MultiHeadAttention(
    num_heads=8,
    key_dim=8,  # depth per head = d_model / num_heads
)
keras_output = keras_mha(query=x, key=x, value=x)
print(f"\nKeras MHA output shape: {keras_output.shape}")
print(f"Parameters: {keras_mha.count_params()}")

Each head operates on a subspace of dimension $d_k = d_{model} / h$. With 8 heads and $d_{model} = 64$, each head works in an 8-dimensional space. Concatenating all heads produces a 64-dimensional vector, which the output projection $W^O$ maps back to $d_{model}$ dimensions. The total computation and parameter count are about the same as single-head attention over the full dimensionality, but each head computes its own attention weights, so different heads can learn different patterns.
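A quick way to check the cost claim is to compare parameter counts with the built-in Keras layer: 8 heads of size 8 versus 1 head of size 64 on a 64-dimensional input. This sketch only counts parameters; it says nothing about which variant learns better:

```python
import tensorflow as tf

d_model = 64
x = tf.random.normal([2, 10, d_model])

multi = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=d_model // 8)
single = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=d_model)

# Calling each layer once builds its weights
out_multi = multi(query=x, value=x, key=x)
out_single = single(query=x, value=x, key=x)

print("8 heads x 8 dims :", multi.count_params())
print("1 head  x 64 dims:", single.count_params())
print("output shapes:", out_multi.shape, out_single.shape)
```

Both should report 16,640 parameters: four projections (query, key, value, output), each with $64 \times 64$ weights plus 64 biases.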

Positional Encoding

Self-attention has no built-in notion of order. It is permutation-equivariant: if you shuffle the input tokens, the outputs are shuffled in exactly the same way, because each token still attends to the same set of tokens with the same scores. But word order matters! "The cat sat on the mat" ≠ "mat the the on sat cat." We need to inject positional information into the model.
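The sketch below checks this with the built-in Keras MultiHeadAttention layer and no positional encoding (the sizes and the permutation are arbitrary). Shuffling the tokens and then running attention gives the same result as running attention and then shuffling the outputs:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)

x = tf.random.normal([1, 6, 16])          # 6 tokens, no positional encoding added
perm = tf.constant([3, 0, 5, 1, 4, 2])    # an arbitrary shuffle of the tokens

out = mha(query=x, value=x, key=x)
x_shuffled = tf.gather(x, perm, axis=1)
out_shuffled = mha(query=x_shuffled, value=x_shuffled, key=x_shuffled)

# Shuffling the inputs just shuffles the outputs in the same way
print(np.allclose(tf.gather(out, perm, axis=1).numpy(), out_shuffled.numpy(), atol=1e-5))
```

So without positional information, the model sees the shuffled sentence as the same bag of tokens, merely reordered in its output.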

The original Transformer uses sinusoidal positional encoding:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

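where $pos$ is the position and $i$ indexes a pair of dimensions $(2i, 2i+1)$ that share one frequency. Each pair uses a sinusoid of a different frequency, with wavelengths ranging from $2\pi$ to $10000 \cdot 2\pi$, creating a unique "fingerprint" for each position. The key property: for any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$ (a rotation within each sine/cosine pair), which should make it easier for the model to learn to attend by relative position.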
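The linear relationship can be verified numerically. For each frequency $\omega_i = 1/10000^{2i/d}$, the angle-addition formulas give $\sin((pos+k)\omega_i) = \sin(pos\,\omega_i)\cos(k\omega_i) + \cos(pos\,\omega_i)\sin(k\omega_i)$ and $\cos((pos+k)\omega_i) = \cos(pos\,\omega_i)\cos(k\omega_i) - \sin(pos\,\omega_i)\sin(k\omega_i)$. A block-diagonal matrix $M_k$ of $2 \times 2$ rotations, which depends only on $k$, therefore maps $PE_{pos}$ to $PE_{pos+k}$. A NumPy sketch with $d = 16$ and $k = 7$ (arbitrary choices):

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    # Same formula as in the text: sin on even dims, cos on odd dims
    pos = np.arange(max_len)[:, None]
    i = np.arange(d_model // 2)[None, :]
    angles = pos / np.power(10000, 2 * i / d_model)   # (max_len, d_model/2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

d_model, k = 16, 7
pe = sinusoidal_pe(100, d_model)
freqs = 1.0 / np.power(10000, 2 * np.arange(d_model // 2) / d_model)

# Block-diagonal M_k: one 2x2 rotation per (sin, cos) pair; depends only on k
M = np.zeros((d_model, d_model))
for j, w in enumerate(freqs):
    c, s = np.cos(k * w), np.sin(k * w)
    M[2*j:2*j+2, 2*j:2*j+2] = [[c, s], [-s, c]]

# PE(pos + k) = M_k @ PE(pos), checked for every pos at once (rows are positions)
shifted = pe[:-k] @ M.T
print("max error:", np.abs(shifted - pe[k:]).max())
```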

Both Approaches: Sinusoidal & Learned

The code below implements fixed sinusoidal encodings and learned positional embeddings, then compares them. Each code example is self-contained and can be run independently:

import tensorflow as tf
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """
    Generate sinusoidal positional encoding matrix.

    Returns: shape (max_len, d_model)
    """
    positions = np.arange(max_len)[:, np.newaxis]     # (max_len, 1)
    dims = np.arange(d_model)[np.newaxis, :]          # (1, d_model)

    # Compute angles: pos / 10000^(2i/d_model)
    angles = positions / np.power(10000, (2 * (dims // 2)) / d_model)

    # Apply sin to even indices, cos to odd indices
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])  # Even dimensions
    pe[:, 1::2] = np.cos(angles[:, 1::2])  # Odd dimensions

    return tf.cast(pe[np.newaxis, :, :], dtype=tf.float32)  # (1, max_len, d_model)


class PositionalEncoding(tf.keras.layers.Layer):
    """Adds sinusoidal positional encoding to embeddings."""

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = sinusoidal_positional_encoding(max_len, d_model)

    def call(self, x):
        seq_len = tf.shape(x)[1]
        return x + self.pe[:, :seq_len, :]


class LearnedPositionalEmbedding(tf.keras.layers.Layer):
    """Learned positional embeddings (like BERT)."""

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_embedding = tf.keras.layers.Embedding(max_len, d_model)

    def call(self, x):
        seq_len = tf.shape(x)[1]
        positions = tf.range(seq_len)
        pos_emb = self.pos_embedding(positions)  # (seq_len, d_model)
        return x + pos_emb


# Compare both approaches
max_len, d_model = 100, 64

# Sinusoidal (fixed, no learnable parameters)
sin_pe = PositionalEncoding(max_len, d_model)
x = tf.random.normal([2, 20, d_model])
out_sin = sin_pe(x)
print(f"Sinusoidal PE - Input: {x.shape} → Output: {out_sin.shape}")
print(f"  Parameters: 0 (fixed encoding)")

# Learned (trainable)
learned_pe = LearnedPositionalEmbedding(max_len, d_model)
out_learned = learned_pe(x)
print(f"Learned PE - Input: {x.shape} → Output: {out_learned.shape}")
print(f"  Parameters: {learned_pe.count_params()} (max_len × d_model)")

# Visualize encoding pattern for first few positions
pe_matrix = sinusoidal_positional_encoding(max_len, d_model)
print(f"\nPositional encoding matrix shape: {pe_matrix.shape}")
print(f"PE[0, :8] (position 0): {pe_matrix[0, 0, :8].numpy().round(3)}")
print(f"PE[1, :8] (position 1): {pe_matrix[0, 1, :8].numpy().round(3)}")
print(f"PE[50, :8] (position 50): {pe_matrix[0, 50, :8].numpy().round(3)}")

# Dot-product similarity depends only on the offset and tends to shrink (not monotonically) with distance
dot_0_1 = tf.reduce_sum(pe_matrix[0, 0] * pe_matrix[0, 1]).numpy()
dot_0_10 = tf.reduce_sum(pe_matrix[0, 0] * pe_matrix[0, 10]).numpy()
dot_0_50 = tf.reduce_sum(pe_matrix[0, 0] * pe_matrix[0, 50]).numpy()
print(f"\nDot product similarity to position 0:")
print(f"  Position 1: {dot_0_1:.3f}")
print(f"  Position 10: {dot_0_10:.3f}")
print(f"  Position 50: {dot_0_50:.3f}")

In practice, learned positional embeddings (used by BERT and GPT) perform about as well as sinusoidal encodings for sequences within the training length; the original Transformer paper reported nearly identical results for the two. Sinusoidal encodings can at least be computed for positions beyond the training length, whereas a learned table has no entry past max_len, but neither approach reliably generalizes to sequences much longer than those seen during training.

Transformer Encoder Block

The Transformer encoder block is the fundamental building unit. It consists of two sub-layers: (1) Multi-Head Self-Attention and (2) Position-wise Feed-Forward Network. Each sub-layer is wrapped with a residual connection and layer normalization. The pattern is: output = LayerNorm(x + SubLayer(x)).

Transformer Encoder Block Architecture
flowchart TD
    IN["Input Embeddings + PE"] --> MHA["Multi-Head<br/>Self-Attention"]
    IN --> ADD1["Add (Residual)"]
    MHA --> ADD1
    ADD1 --> LN1["Layer Norm"]
    LN1 --> FFN["Feed-Forward Network<br/>Dense(d_ff, relu) → Dense(d_model)"]
    LN1 --> ADD2["Add (Residual)"]
    FFN --> ADD2
    ADD2 --> LN2["Layer Norm"]
    LN2 --> OUT["Encoder Output"]

Building a Reusable Encoder Block

The original Transformer ("Attention Is All You Need") uses post-norm: normalize after the residual addition. Modern implementations often prefer pre-norm: normalize before the sub-layer, which provides more stable training for deep stacks. We'll implement both:

import tensorflow as tf

class TransformerEncoderBlock(tf.keras.layers.Layer):
    """Single Transformer encoder block with MHA + FFN + residuals."""

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1, pre_norm=False):
        super().__init__()
        self.pre_norm = pre_norm

        # Multi-Head Self-Attention
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model // num_heads,
        )

        # Feed-Forward Network (2 Dense layers with ReLU)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation='relu'),
            tf.keras.layers.Dense(d_model),
        ])

        # Layer Normalization
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Dropout
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

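    def call(self, x, training=False, mask=None):
        # mask: optional boolean attention mask (batch, seq_len, seq_len);
        # Keras convention is True/1 = may attend, False/0 = masked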
        if self.pre_norm:
            # Pre-norm: LayerNorm → SubLayer → Residual
            normed = self.layernorm1(x)
            attn_output = self.mha(normed, normed, normed, attention_mask=mask)
            attn_output = self.dropout1(attn_output, training=training)
            x = x + attn_output

            normed = self.layernorm2(x)
            ffn_output = self.ffn(normed)
            ffn_output = self.dropout2(ffn_output, training=training)
            x = x + ffn_output
        else:
            # Post-norm (original): SubLayer → Residual → LayerNorm
            attn_output = self.mha(x, x, x, attention_mask=mask)
            attn_output = self.dropout1(attn_output, training=training)
            x = self.layernorm1(x + attn_output)

            ffn_output = self.ffn(x)
            ffn_output = self.dropout2(ffn_output, training=training)
            x = self.layernorm2(x + ffn_output)

        return x


# Test encoder block
d_model, num_heads, d_ff = 128, 8, 512
encoder_block = TransformerEncoderBlock(d_model, num_heads, d_ff)

# Input: batch=2, seq_len=20, d_model=128
x = tf.random.normal([2, 20, d_model])
output = encoder_block(x, training=False)

print(f"Encoder block input: {x.shape}")
print(f"Encoder block output: {output.shape}")
print(f"Parameters: {encoder_block.count_params()}")
print(f"\nArchitecture:")
print(f"  MHA: {num_heads} heads × {d_model // num_heads} depth = {d_model}")
print(f"  FFN: {d_model} → {d_ff} (ReLU) → {d_model}")
print(f"  Residual connections: 2")
print(f"  Layer norms: 2")

The feed-forward network expands the representation to a higher dimension ($d_{ff}$ is typically 4× $d_{model}$; the original model used $d_{model} = 512$ and $d_{ff} = 2048$), applies a non-linearity, then projects back. It is applied to each position independently with the same weights, giving every position its own non-linear processing step. The residual connections help gradients flow through deep stacks (the original model stacked 6 encoder and 6 decoder layers), and layer normalization stabilizes training.
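The "position-wise" property can be checked directly. This sketch builds the two FFN Dense layers (sizes chosen arbitrarily, with $d_{ff} = 4 \times d_{model}$) and confirms that running them on the whole sequence equals running them on each position separately:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
d_model, d_ff = 32, 128  # d_ff = 4 * d_model, following the rule of thumb above
expand = tf.keras.layers.Dense(d_ff, activation="relu")
project = tf.keras.layers.Dense(d_model)

def ffn(t):
    # Same two-layer structure as the encoder block's FFN, applied to the last axis
    return project(expand(t))

x = tf.random.normal([2, 10, d_model])          # (batch, seq_len, d_model)
full = ffn(x)                                   # all positions in one call

# Apply the same FFN to one position at a time and stack the results
per_position = tf.stack([ffn(x[:, t, :]) for t in range(x.shape[1])], axis=1)

print(full.shape, np.allclose(full.numpy(), per_position.numpy(), atol=1e-5))
```

All mixing of information between positions therefore happens in the attention sub-layer; the FFN only transforms each position's vector.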

Transformer Decoder Block

The decoder block is more complex than the encoder and has three sub-layers: (1) masked (causal) self-attention, (2) cross-attention over encoder outputs, and (3) a feed-forward network. The causal mask ensures that during training, position $i$ can only attend to positions $\leq i$ in the (right-shifted) target input, preventing information leakage from future tokens.

Causal Masking & Cross-Attention

The decoder block below uses the built-in Keras MultiHeadAttention layer for both the masked self-attention and the cross-attention. Each code example is self-contained and can be run independently:

import tensorflow as tf

class TransformerDecoderBlock(tf.keras.layers.Layer):
    """Single Transformer decoder block with masked self-attn + cross-attn + FFN."""

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()

        # Masked Self-Attention (causal)
        self.masked_mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model // num_heads,
        )

        # Cross-Attention (decoder queries encoder)
        self.cross_mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model // num_heads,
        )

        # Feed-Forward Network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation='relu'),
            tf.keras.layers.Dense(d_model),
        ])

        # Layer Norms and Dropout
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, encoder_output, training=False, padding_mask=None):
        # Sub-layer 1: Masked self-attention
        attn1 = self.masked_mha(
            query=x, key=x, value=x,
            use_causal_mask=True  # Keras builds the look-ahead (causal) mask internally
        )
        attn1 = self.dropout1(attn1, training=training)
        x = self.layernorm1(x + attn1)

        # Sub-layer 2: Cross-attention (Q=decoder, K,V=encoder)
        # padding_mask: optional boolean mask of shape (batch, tgt_len, src_len);
        # Keras convention is True/1 = may attend, False/0 = masked
        attn2 = self.cross_mha(
            query=x,
            key=encoder_output,
            value=encoder_output,
            attention_mask=padding_mask
        )
        attn2 = self.dropout2(attn2, training=training)
        x = self.layernorm2(x + attn2)

        # Sub-layer 3: Feed-forward
        ffn_out = self.ffn(x)
        ffn_out = self.dropout3(ffn_out, training=training)
        x = self.layernorm3(x + ffn_out)

        return x


# Test decoder block
d_model, num_heads, d_ff = 128, 8, 512
decoder_block = TransformerDecoderBlock(d_model, num_heads, d_ff)

# Encoder output (from encoder stack)
encoder_output = tf.random.normal([2, 15, d_model])  # Source seq_len=15

# Decoder input (target tokens so far)
decoder_input = tf.random.normal([2, 10, d_model])   # Target seq_len=10

output = decoder_block(decoder_input, encoder_output, training=False)

print(f"Encoder output: {encoder_output.shape} (source sequence)")
print(f"Decoder input: {decoder_input.shape} (target sequence)")
print(f"Decoder block output: {output.shape}")
print(f"\nSub-layers:")
print(f"  1. Masked self-attention: target attends to past target tokens only")
print(f"  2. Cross-attention: target queries attend to all encoder outputs")
print(f"  3. FFN: position-wise non-linear transformation")
print(f"\nParameters: {decoder_block.count_params()}")

The decoder's causal self-attention ensures autoregressive generation: when predicting token $t$, the model only sees tokens $1, ..., t-1$. The cross-attention layer allows each decoder position to attend to the full encoder output, enabling the decoder to "look at" the entire source sequence when generating each target token.
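Keras handles masks differently from the from-scratch function earlier: for tf.keras.layers.MultiHeadAttention, a boolean attention_mask uses True/1 for "may attend" and False/0 for "masked", the opposite of the additive 1 = masked convention used above. The sketch below (sizes and token ids are hypothetical) inspects the causal self-attention weights and builds a source padding mask of shape (batch, tgt_len, src_len) for cross-attention:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
batch, tgt_len, src_len, d_model = 1, 4, 6, 16
self_attn = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
cross_attn = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)

y = tf.random.normal([batch, tgt_len, d_model])        # decoder states
enc_out = tf.random.normal([batch, src_len, d_model])  # encoder outputs
src_ids = tf.constant([[5, 9, 2, 7, 0, 0]])            # hypothetical ids, 0 = padding

# 1) Causal self-attention: Keras builds the look-ahead mask internally
_, causal_w = self_attn(query=y, value=y, key=y, use_causal_mask=True,
                        return_attention_scores=True)

# 2) Cross-attention padding mask. Keras convention: True = may attend.
#    Shape (batch, tgt_len, src_len): every target position sees the same keys.
keep = tf.not_equal(src_ids, 0)                                # (batch, src_len)
cross_mask = tf.tile(keep[:, tf.newaxis, :], [1, tgt_len, 1])
_, cross_w = cross_attn(query=y, value=enc_out, key=enc_out,
                        attention_mask=cross_mask, return_attention_scores=True)

print("causal weights (head 0):\n", causal_w[0, 0].numpy().round(2))
print("largest weight on padded source positions:", cross_w[..., 4:].numpy().max())
```

The causal weights should be lower triangular, and the padded source positions should receive (numerically) zero weight from every decoder position.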

Full Transformer Model

The complete Transformer stacks multiple encoder and decoder blocks. The full pipeline is: Token Embedding + Positional Encoding → Encoder Stack → Decoder Stack → Linear Head → Softmax. Let's assemble everything into a complete sequence-to-sequence model:

Stacking Blocks & Token Masks

The code below repeats compact (post-norm) versions of the encoder and decoder blocks so that it runs on its own, then stacks them into a full encoder-decoder model. Each code example is self-contained and can be run independently:

import tensorflow as tf
import numpy as np

class TransformerEncoderBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.ffn = tf.keras.Sequential([tf.keras.layers.Dense(d_ff, activation='relu'), tf.keras.layers.Dense(d_model)])
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = tf.keras.layers.Dropout(dropout_rate)
        self.drop2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, training=False, mask=None):
        attn = self.drop1(self.mha(x, x, x, attention_mask=mask), training=training)
        x = self.ln1(x + attn)
        ffn = self.drop2(self.ffn(x), training=training)
        return self.ln2(x + ffn)


class TransformerDecoderBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.mha1 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.mha2 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.ffn = tf.keras.Sequential([tf.keras.layers.Dense(d_ff, activation='relu'), tf.keras.layers.Dense(d_model)])
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.ln3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = tf.keras.layers.Dropout(dropout_rate)
        self.drop2 = tf.keras.layers.Dropout(dropout_rate)
        self.drop3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, enc_out, training=False, padding_mask=None):
        attn1 = self.drop1(self.mha1(x, x, x, use_causal_mask=True), training=training)
        x = self.ln1(x + attn1)
        attn2 = self.drop2(self.mha2(x, enc_out, enc_out, attention_mask=padding_mask), training=training)
        x = self.ln2(x + attn2)
        ffn = self.drop3(self.ffn(x), training=training)
        return self.ln3(x + ffn)


class Transformer(tf.keras.Model):
    """Full Transformer: encoder-decoder for sequence-to-sequence tasks."""

    def __init__(self, src_vocab, tgt_vocab, d_model=256, num_heads=8,
                 d_ff=1024, num_layers=4, max_len=512, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model

        # Embeddings
        self.src_embedding = tf.keras.layers.Embedding(src_vocab, d_model)
        self.tgt_embedding = tf.keras.layers.Embedding(tgt_vocab, d_model)
        self.pos_encoding = self._sinusoidal_pe(max_len, d_model)

        # Encoder stack
        self.encoder_layers = [
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout_rate)
            for _ in range(num_layers)
        ]

        # Decoder stack
        self.decoder_layers = [
            TransformerDecoderBlock(d_model, num_heads, d_ff, dropout_rate)
            for _ in range(num_layers)
        ]

        # Output head
        self.final_layer = tf.keras.layers.Dense(tgt_vocab)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def _sinusoidal_pe(self, max_len, d_model):
        positions = np.arange(max_len)[:, np.newaxis]
        dims = np.arange(d_model)[np.newaxis, :]
        angles = positions / np.power(10000, (2 * (dims // 2)) / d_model)
        pe = np.zeros((max_len, d_model))
        pe[:, 0::2] = np.sin(angles[:, 0::2])
        pe[:, 1::2] = np.cos(angles[:, 1::2])
        return tf.cast(pe[np.newaxis, :, :], tf.float32)

    def call(self, inputs, training=False):
        src, tgt = inputs
        src_len = tf.shape(src)[1]
        tgt_len = tf.shape(tgt)[1]

        # Encode source
        x = self.src_embedding(src) * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[:, :src_len, :]
        x = self.dropout(x, training=training)
        for enc_layer in self.encoder_layers:
            x = enc_layer(x, training=training)
        enc_output = x

        # Decode target
        y = self.tgt_embedding(tgt) * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        y = y + self.pos_encoding[:, :tgt_len, :]
        y = self.dropout(y, training=training)
        for dec_layer in self.decoder_layers:
            y = dec_layer(y, enc_output, training=training)

        # Project to vocabulary
        logits = self.final_layer(y)
        return logits


# Instantiate and test
src_vocab, tgt_vocab = 8000, 6000
model = Transformer(src_vocab, tgt_vocab, d_model=256, num_heads=8,
                    d_ff=1024, num_layers=4)

# Dummy input: source tokens and target tokens (teacher forcing)
src_tokens = tf.random.uniform([2, 20], maxval=src_vocab, dtype=tf.int32)
tgt_tokens = tf.random.uniform([2, 15], maxval=tgt_vocab, dtype=tf.int32)

logits = model((src_tokens, tgt_tokens), training=False)
print(f"Source tokens: {src_tokens.shape}")
print(f"Target tokens: {tgt_tokens.shape}")
print(f"Output logits: {logits.shape}  (batch, tgt_seq_len, tgt_vocab)")
print(f"\nModel config:")
print("  d_model=256, num_heads=8, d_ff=1024, num_layers=4")
print(f"  Source vocab: {src_vocab}, Target vocab: {tgt_vocab}")
print(f"  Total parameters: {model.count_params():,}")

The model uses teacher forcing during training: the decoder receives the ground-truth target sequence shifted right by one position and learns to predict the next token at every position. At inference time, tokens are generated autoregressively: each predicted token is fed back as input for the next position. The embedding scaling factor $\sqrt{d_{model}}$ brings the small, randomly initialized embeddings up to a scale comparable to the positional encodings, whose values lie in $[-1, 1]$.
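A minimal sketch of both regimes follows. It assumes a hypothetical convention in which every target sequence starts with a START id and ends with an END id, and it treats the model as any callable mapping (src, tgt) to logits of shape (batch, tgt_len, vocab), like the Transformer above. The helper names make_teacher_forcing_pair and greedy_decode are illustrative, not part of TensorFlow.

```python
import tensorflow as tf


def make_teacher_forcing_pair(target_ids):
    """Split [START] w1 ... wn [END] into decoder input and labels."""
    decoder_input = target_ids[:, :-1]   # [START] w1 ... wn  (shifted right)
    labels = target_ids[:, 1:]           # w1 ... wn [END]    (what to predict)
    return decoder_input, labels


def greedy_decode(model, src, start_id, end_id, max_len=20):
    """Greedy autoregressive decoding: feed each prediction back as input."""
    # Every sequence in the batch starts with the START token
    tgt = tf.zeros_like(src[:, :1]) + start_id              # (batch, 1)
    for _ in range(max_len - 1):
        logits = model((src, tgt), training=False)           # (batch, cur_len, vocab)
        # Only the prediction at the last position is new
        next_id = tf.argmax(logits[:, -1, :], axis=-1, output_type=src.dtype)
        tgt = tf.concat([tgt, next_id[:, tf.newaxis]], axis=1)
        # Stop once every sequence emitted END in this step
        # (sequences that finish earlier keep generating; a real decoder would mask them)
        if bool(tf.reduce_all(next_id == end_id)):
            break
    return tgt


# Teacher forcing on a toy batch (START=1, END=2 and padding=0 are assumed ids)
targets = tf.constant([[1, 7, 8, 9, 2],
                       [1, 5, 6, 2, 0]])
dec_in, labels = make_teacher_forcing_pair(targets)
print(dec_in.numpy().tolist())   # [[1, 7, 8, 9], [1, 5, 6, 2]]
print(labels.numpy().tolist())   # [[7, 8, 9, 2], [5, 6, 2, 0]]

# With the Transformer above (untrained, so the output is arbitrary):
# generated = greedy_decode(model, src_tokens, start_id=1, end_id=2)
```

Greedy decoding picks the single most likely token at each step; beam search or sampling are common alternatives when output quality matters.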

Text Classification with Transformers

For classification tasks we only need the encoder; no decoder is required. The strategy: pass the input sequence through the Transformer encoder stack, pool the output (global average pooling, or the output vector at a special [CLS] token), and feed the pooled vector to a classification head. This encoder-only design is the one BERT uses.
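The two pooling options can be compared on a stand-in encoder output. This is a sketch under stated assumptions: random tensors replace a real encoder, a boolean mask marks real tokens versus padding, and [CLS] pooling presumes a [CLS] token was placed at position 0 before encoding. Keep in mind that pad_sequences pads at the front by default (padding='pre'), so with that setting position 0 is usually padding rather than a [CLS] token.

```python
import tensorflow as tf

# Stand-in encoder output: 2 sequences, 5 positions, d_model = 4
enc_out = tf.random.normal([2, 5, 4], seed=0)
# True = real token, False = padding (padded at the end in this toy batch)
mask = tf.constant([[True, True, True, False, False],
                    [True, True, True, True, True]])

# Option A: [CLS] pooling, i.e. keep the vector at position 0
# (assumes a [CLS] token was prepended to every sequence)
cls_pooled = enc_out[:, 0, :]                                              # (2, 4)

# Option B: average pooling over real tokens only
avg_pooled = tf.keras.layers.GlobalAveragePooling1D()(enc_out, mask=mask)  # (2, 4)

# The same masked average written out by hand
m = tf.cast(mask, enc_out.dtype)[:, :, tf.newaxis]                         # (2, 5, 1)
manual_avg = tf.reduce_sum(enc_out * m, axis=1) / tf.reduce_sum(m, axis=1)
```

Average pooling needs no special token and uses every position; [CLS] pooling relies on pretraining (as in BERT) to teach that one position to summarize the sequence.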

Encoder-Only Classifier on IMDB

The example below is self-contained and can be run independently. It builds an encoder-only classifier from Keras MultiHeadAttention layers and trains it on the IMDB movie-review sentiment dataset:

import tensorflow as tf
import numpy as np

class TransformerClassifier(tf.keras.Model):
    """Transformer encoder-only model for text classification."""

    def __init__(self, vocab_size, max_len, d_model=128, num_heads=4,
                 d_ff=256, num_layers=2, num_classes=2, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_embedding = tf.keras.layers.Embedding(max_len, d_model)

        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append(
                tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
            )
            self.encoder_layers.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))
            self.encoder_layers.append(tf.keras.layers.Dense(d_ff, activation='relu'))
            self.encoder_layers.append(tf.keras.layers.Dense(d_model))
            self.encoder_layers.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))
            self.encoder_layers.append(tf.keras.layers.Dropout(dropout_rate))

        self.global_pool = tf.keras.layers.GlobalAveragePooling1D()
        self.classifier = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ])

    def call(self, x, training=False):
        seq_len = tf.shape(x)[1]
        positions = tf.range(seq_len)

        # Token + positional embedding
        x = self.embedding(x) * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_embedding(positions)

        # Pass through encoder blocks
        i = 0
        while i < len(self.encoder_layers):
            # MHA + residual + norm
            attn = self.encoder_layers[i](x, x, x)  # MHA
            x = self.encoder_layers[i+1](x + attn)   # Add & Norm
            # FFN + residual + norm
            ffn = self.encoder_layers[i+3](self.encoder_layers[i+2](x))  # Dense→Dense
            x = self.encoder_layers[i+4](x + ffn)    # Add & Norm
            x = self.encoder_layers[i+5](x, training=training)  # Dropout
            i += 6

        # Pool and classify
        pooled = self.global_pool(x)
        return self.classifier(pooled, training=training)


# Load IMDB dataset
vocab_size = 10000
max_len = 200

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
    num_words=vocab_size
)
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_len)
print(f"Training: {x_train.shape}, Test: {x_test.shape}")

# Build Transformer classifier
transformer_clf = TransformerClassifier(
    vocab_size=vocab_size, max_len=max_len,
    d_model=64, num_heads=4, d_ff=128, num_layers=2, num_classes=2
)

transformer_clf.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

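# Train (a few epochs for demonstration)
print("\nTraining Transformer classifier...")
transformer_clf.fit(x_train, y_train, epochs=3, batch_size=64,
                    validation_split=0.2, verbose=1)

# Evaluate on the held-out test set
clf_loss, clf_acc = transformer_clf.evaluate(x_test, y_test, verbose=0)
print(f"Transformer test accuracy: {clf_acc:.4f}")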
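The classifier above lets both attention and average pooling see the padding zeros that pad_sequences adds. A minimal sketch of masking them out, assuming padding id 0 as in the IMDB example: build a boolean mask from the token ids, pass it to MultiHeadAttention as attention_mask with shape (batch, query_len, key_len), and pass the per-token mask to GlobalAveragePooling1D. The toy ids and layer sizes are illustrative only.

```python
import tensorflow as tf

# Two sequences; the first is padded at the front with id 0 (pad_sequences' default)
token_ids = tf.constant([[0, 0, 5, 8, 2],
                         [4, 9, 1, 7, 3]])
seq_len = token_ids.shape[1]
x = tf.keras.layers.Embedding(100, 16)(token_ids)                  # (2, 5, 16)

# True where there is a real token
pad_mask = tf.not_equal(token_ids, 0)                               # (2, 5)
# Every query position may attend only to non-padding keys
attn_mask = tf.logical_and(pad_mask[:, tf.newaxis, :],
                           tf.ones([1, seq_len, 1], dtype=tf.bool))  # (2, 5, 5)

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
out, scores = mha(x, x, attention_mask=attn_mask,
                  return_attention_scores=True)                     # scores: (2, 2, 5, 5)

# Average only over real tokens when pooling
pooled = tf.keras.layers.GlobalAveragePooling1D()(out, mask=pad_mask)  # (2, 16)
print(scores[0, 0].numpy().round(2))   # columns 0 and 1 (padding keys) are zero
```

Inside TransformerClassifier.call, the same two masks would be computed from the integer inputs before embedding and passed to each MultiHeadAttention call and to self.global_pool.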

Let's compare the Transformer classifier against an LSTM baseline on the same IMDB task to see how they stack up:

import tensorflow as tf

# LSTM baseline for comparison
vocab_size = 10000
max_len = 200

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
    num_words=vocab_size
)
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_len)

# LSTM model
lstm_model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, 64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(2, activation='softmax'),
])

lstm_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("LSTM Baseline:")
print(f"  Parameters: {lstm_model.count_params():,}")
lstm_model.fit(x_train, y_train, epochs=3, batch_size=64,
               validation_split=0.2, verbose=1)

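# Evaluate the LSTM on the held-out test set (this script is standalone,
# so the Transformer classifier is not re-evaluated here)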
lstm_loss, lstm_acc = lstm_model.evaluate(x_test, y_test, verbose=0)
print(f"\n--- Results after 3 epochs ---")
print(f"LSTM Accuracy: {lstm_acc:.4f}")
print(f"\nNote: Transformers often need more epochs or larger models to")
print(f"outperform LSTMs on small datasets like IMDB (25k samples).")
print(f"Their advantage shines with more data and longer sequences.")
When Transformers Win: On small datasets like IMDB (25k training samples), an LSTM often matches or beats a small Transformer trained from scratch, because recurrence gives it a strong inductive bias for sequential data. As rough guidelines rather than hard thresholds, Transformers tend to pull ahead when: (1) datasets are large (100k+ examples), (2) sequences are long (512+ tokens), (3) the model is pretrained (BERT, GPT), or (4) multiple tasks benefit from transfer learning.

Using Pretrained Transformers

Training Transformers from scratch requires large amounts of data and compute. In practice, we start from pretrained models (large Transformers trained on billions of tokens) and fine-tune them for specific tasks. TensorFlow offers two main avenues: TensorFlow Hub (legacy) and KerasNLP (modern, recommended; the project has since been renamed KerasHub).

KerasNLP & Fine-Tuning

The example below is self-contained. The KerasNLP and TF Hub snippets are printed rather than executed, since they need extra packages and downloaded weights; the fine-tuning strategy section runs with TensorFlow alone:

import tensorflow as tf

# Approach 1: KerasNLP (modern recommended approach)
# Install: pip install keras-nlp
# KerasNLP provides pretrained backbone + tokenizer as one package

# Example: using a pretrained BERT model for classification
# (printed, not executed: running it requires keras-nlp and a model download)
print("=== KerasNLP Approach (Recommended) ===")
print("""
import keras_nlp
import tensorflow as tf

# Load a pretrained classifier (backbone + preprocessor + head) in one call
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)

# Fine-tune on your data; the classifier outputs logits, hence from_logits=True
classifier.compile(
    optimizer=tf.keras.optimizers.Adam(2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# train_ds: a tf.data.Dataset yielding (raw_text, label) batches
classifier.fit(train_ds, epochs=3)
""")

# Approach 2: TensorFlow Hub (legacy but widely available)
print("\n=== TF Hub Approach (Legacy) ===")
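print("""
import tensorflow as tf
import tensorflow_hub as hub
# Note: with TF 2.16+ (Keras 3), hub.KerasLayer generally needs the legacy
# Keras 2 package: pip install tf_keras and set TF_USE_LEGACY_KERAS=1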

# Load BERT encoder from TF Hub
encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    trainable=True
)
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)

# Build model
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
preprocessed = preprocessor(inputs)
outputs = encoder(preprocessed)
pooled = outputs['pooled_output']  # [CLS] token representation
predictions = tf.keras.layers.Dense(2, activation='softmax')(pooled)
model = tf.keras.Model(inputs, predictions)
""")

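# Approach 3: Two-phase fine-tuning with frozen/unfrozen layers
print("\n=== Fine-Tuning Strategy ===")

# Stand-in for a pretrained backbone (randomly initialized here, for illustration only)
backbone = tf.keras.Sequential([
    tf.keras.layers.Embedding(10000, 128),
    tf.keras.layers.GlobalAveragePooling1D(),
], name="backbone")
inputs = tf.keras.Input(shape=(None,), dtype="int32")
outputs = tf.keras.layers.Dense(2, activation="softmax")(backbone(inputs))
clf = tf.keras.Model(inputs, outputs)

# Phase 1: freeze the backbone and train only the classification head
backbone.trainable = False
clf.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
            loss="sparse_categorical_crossentropy", metrics=["accuracy"])
print("Phase 1: Train classification head (backbone frozen)")
print("  - Learning rate: 1e-3, epochs: ~5")
print(f"  - Trainable weight tensors: {len(clf.trainable_weights)} (head only)")
# clf.fit(train_ds, epochs=5)

# Phase 2: unfreeze, then recompile (required for the change to take effect)
backbone.trainable = True
clf.compile(optimizer=tf.keras.optimizers.Adam(2e-5),
            loss="sparse_categorical_crossentropy", metrics=["accuracy"])
print("\nPhase 2: Fine-tune entire model")
print("  - Learning rate: 2e-5 (10-50x smaller), epochs: 3-5 more")
print(f"  - Trainable weight tensors: {len(clf.trainable_weights)} (all layers)")
# clf.fit(train_ds, epochs=3)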

# Key hyperparameters for Transformer fine-tuning
print("\n=== Fine-Tuning Hyperparameters ===")
print("  Learning rate: 1e-5 to 5e-5 (Adam)")
print("  Batch size: 16-32 (limited by memory)")
print("  Epochs: 2-5 (overfitting risk with small data)")
print("  Warmup: 10% of total steps")
print("  Weight decay: 0.01")
print("  Max sequence length: 128-512 (depends on task)")
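The hyperparameter list above suggests warming up over roughly 10% of training and using weight decay 0.01. A minimal sketch of one common way to wire that up, assuming linear warmup followed by linear decay to zero (a BERT-style shape; the class name LinearWarmupLinearDecay is hypothetical) and tf.keras.optimizers.AdamW, which is available in TF 2.11 and later:

```python
import tensorflow as tf


class LinearWarmupLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Hypothetical schedule: 0 -> peak_lr during warmup, then linearly down to 0."""

    def __init__(self, peak_lr, total_steps, warmup_frac=0.1):
        super().__init__()
        self.peak_lr = peak_lr
        self.total_steps = total_steps
        self.warmup_frac = warmup_frac
        self.warmup_steps = max(1, int(warmup_frac * total_steps))

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = float(self.warmup_steps)
        total_steps = float(self.total_steps)
        # Phase 1: ramp up linearly from 0 to peak_lr
        warmup_lr = self.peak_lr * step / warmup_steps
        # Phase 2: decay linearly from peak_lr to 0 at total_steps
        remaining = tf.maximum(0.0, total_steps - step) / (total_steps - warmup_steps)
        decay_lr = self.peak_lr * remaining
        return tf.where(step < warmup_steps, warmup_lr, decay_lr)

    def get_config(self):
        return {"peak_lr": self.peak_lr, "total_steps": self.total_steps,
                "warmup_frac": self.warmup_frac}


# Example budget: 3 epochs over 25,000 examples at batch size 32
steps_per_epoch = -(-25000 // 32)          # ceiling division
schedule = LinearWarmupLinearDecay(peak_lr=2e-5, total_steps=3 * steps_per_epoch)
optimizer = tf.keras.optimizers.AdamW(learning_rate=schedule, weight_decay=0.01)

for s in [0, schedule.warmup_steps // 2, schedule.warmup_steps, schedule.total_steps]:
    print(f"step {s:5d}: lr = {float(schedule(s)):.2e}")
```

AdamW applies weight decay directly to the weights instead of folding it into the gradient as an L2 penalty, which is why it is usually paired with this kind of fine-tuning recipe.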

The key insight with pretrained Transformers: the backbone already captures general patterns of language, and fine-tuning adapts that knowledge to your specific task with relatively little labeled data. A BERT model pretrained on Wikipedia and BookCorpus (about 3.3B words) can often be fine-tuned for sentiment analysis on far fewer labeled examples than an LSTM trained from scratch needs to reach comparable accuracy.

The next example implements the warm-up learning-rate schedule used to train the original Transformer and plugs it into Adam:

import tensorflow as tf
import numpy as np

# Practical example: the original Transformer's warm-up learning-rate schedule

class WarmupSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by inverse sqrt decay (original Transformer)."""

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model_value = d_model          # plain number, kept for get_config
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)                     # inverse-sqrt decay
        arg2 = step * (self.warmup_steps ** -1.5)      # linear warmup
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        # Needed to save and reload a model whose optimizer uses this schedule
        return {"d_model": self.d_model_value, "warmup_steps": self.warmup_steps}


# Evaluate the schedule for steps 1..20,000 in one vectorized call
schedule = WarmupSchedule(d_model=256, warmup_steps=4000)
lrs = schedule(tf.range(1, 20001)).numpy()

print("Warmup Schedule (d_model=256, warmup=4000 steps):")
print(f"  Step 1000: LR = {schedule(1000).numpy():.6f}")
print(f"  Step 4000: LR = {schedule(4000).numpy():.6f} (peak)")
print(f"  Step 8000: LR = {schedule(8000).numpy():.6f}")
print(f"  Step 20000: LR = {schedule(20000).numpy():.6f}")
print(f"  Maximum over all steps reached at step {int(np.argmax(lrs)) + 1}")

# Using the schedule with Adam optimizer
optimizer = tf.keras.optimizers.Adam(
    learning_rate=schedule,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9
)
print(f"\nOptimizer: Adam with warmup schedule")
print(f"  Beta1=0.9, Beta2=0.98, Epsilon=1e-9")
print(f"  (Original Transformer hyperparameters)")
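Written out, the WarmupSchedule above computes the learning rate used for the original Transformer, with step number $s$ and $w$ = warmup_steps:

$$
\text{lr}(s) = d_{model}^{-0.5} \cdot \min\left(s^{-0.5},\; s \cdot w^{-1.5}\right)
$$

For $s < w$ the ratio of the second term to the first is $(s/w)^{1.5} < 1$, so the second term is the minimum and the rate grows linearly; for $s > w$ the first term wins and the rate decays as $1/\sqrt{s}$. The two terms meet at $s = w$, which gives the peak rate for the settings used above:

$$
\text{lr}_{\max} = (d_{model} \cdot w)^{-0.5} = (256 \cdot 4000)^{-0.5} \approx 9.88 \times 10^{-4}
$$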

Vision Transformers (ViT) Preview

Transformers aren't limited to text: Vision Transformers (ViT) apply the same architecture to images by treating image patches as tokens. The idea: split an image into fixed-size patches (e.g., 16×16 pixels), flatten each patch into a vector, and linearly project it (the patch embedding). Then prepend a learnable [CLS] token, add positional embeddings to every position (including [CLS]), and feed the sequence through a standard Transformer encoder. The encoder's output at the [CLS] position is used for classification.
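To see why ViT code can use a Conv2D with kernel_size = strides = patch size, the sketch below compares it with the literal recipe: cut non-overlapping patches with tf.image.extract_patches, flatten them, and apply one shared linear projection. The sizes are illustrative, and the check relies on extract_patches flattening each patch in (row, column, channel) order, the same layout as a Conv2D kernel of shape (P, P, C, D).

```python
import tensorflow as tf

P, C, D = 4, 3, 8                 # patch size, channels, embedding dim (illustrative)
images = tf.random.normal([2, 32, 32, C], seed=1)

# Route 1: literal recipe. Cut non-overlapping P x P patches and flatten each one
patches = tf.image.extract_patches(images, sizes=[1, P, P, 1], strides=[1, P, P, 1],
                                   rates=[1, 1, 1, 1], padding="VALID")  # (2, 8, 8, P*P*C)
patches = tf.reshape(patches, [2, -1, P * P * C])                         # (2, 64, 48)

# One linear projection shared by all patches
W = tf.random.normal([P * P * C, D], seed=2)
b = tf.random.normal([D], seed=3)
explicit = tf.einsum("bnk,kd->bnd", patches, W) + b                      # (2, 64, D)

# Route 2: Conv2D with kernel_size = strides = P, loaded with the same weights
conv = tf.keras.layers.Conv2D(D, kernel_size=P, strides=P)
conv(images)                                             # call once to build the layer
conv.set_weights([tf.reshape(W, [P, P, C, D]).numpy(), b.numpy()])
via_conv = tf.reshape(conv(images), [2, -1, D])                          # (2, 64, D)

max_diff = float(tf.reduce_max(tf.abs(explicit - via_conv)))
print(f"max |difference| = {max_diff:.2e}")    # float32 round-off only
print(f"tokens per image: {patches.shape[1]} (+1 with [CLS])")
```

The convolution route avoids materializing the flattened patch tensor, which is why it is the usual choice in practice.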

Building a Minimal ViT for CIFAR-10

The example below is self-contained and can be run independently. It extracts and embeds patches with a strided Conv2D, then runs them through a small pre-norm Transformer encoder:

import tensorflow as tf
import numpy as np

class PatchEmbedding(tf.keras.layers.Layer):
    """Split image into patches and embed them."""

    def __init__(self, patch_size, d_model):
        super().__init__()
        self.patch_size = patch_size
        # Conv2D with kernel=stride=patch_size acts as patch extraction + projection
        self.projection = tf.keras.layers.Conv2D(
            d_model, kernel_size=patch_size, strides=patch_size
        )

    def call(self, images):
        # images: (batch, H, W, C)
        patches = self.projection(images)  # (batch, H/P, W/P, d_model)
        batch_size = tf.shape(patches)[0]
        # Flatten spatial dims: (batch, num_patches, d_model)
        patches = tf.reshape(patches, [batch_size, -1, patches.shape[-1]])
        return patches


class VisionTransformer(tf.keras.Model):
    """Minimal Vision Transformer for image classification."""

    def __init__(self, image_size, patch_size, num_classes, d_model=128,
                 num_heads=4, d_ff=256, num_layers=4, dropout_rate=0.1):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbedding(patch_size, d_model)

        # Learnable [CLS] token
        # Pass name= as a keyword: in Keras 3 the first positional
        # argument of add_weight is shape, not name
        self.cls_token = self.add_weight(
            name="cls_token", shape=(1, 1, d_model),
            initializer="random_normal"
        )

        # Learnable positional embedding (num_patches + 1 for CLS)
        self.pos_embed = self.add_weight(
            name="pos_embed", shape=(1, num_patches + 1, d_model),
            initializer="random_normal"
        )

        # Transformer encoder blocks
        self.encoder_blocks = []
        for _ in range(num_layers):
            self.encoder_blocks.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))
            self.encoder_blocks.append(
                tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
            )
            self.encoder_blocks.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))
            self.encoder_blocks.append(tf.keras.Sequential([
                tf.keras.layers.Dense(d_ff, activation='gelu'),
                tf.keras.layers.Dense(d_model),
            ]))
            self.encoder_blocks.append(tf.keras.layers.Dropout(dropout_rate))

        self.final_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, images, training=False):
        batch_size = tf.shape(images)[0]

        # Create patch embeddings
        x = self.patch_embed(images)  # (batch, num_patches, d_model)

        # Prepend [CLS] token
        cls_tokens = tf.broadcast_to(self.cls_token, [batch_size, 1, x.shape[-1]])
        x = tf.concat([cls_tokens, x], axis=1)  # (batch, num_patches+1, d_model)

        # Add positional embedding
        x = x + self.pos_embed

        # Transformer encoder (pre-norm variant)
        i = 0
        while i < len(self.encoder_blocks):
            # MHA block with residual
            normed = self.encoder_blocks[i](x)      # LayerNorm
            attn = self.encoder_blocks[i+1](normed, normed, normed)  # MHA
            x = x + attn
            # FFN block with residual
            normed = self.encoder_blocks[i+2](x)    # LayerNorm
            ffn = self.encoder_blocks[i+3](normed)  # FFN
            x = x + ffn
            x = self.encoder_blocks[i+4](x, training=training)  # Dropout
            i += 5

        # Classification from [CLS] token
        x = self.final_norm(x)
        cls_output = x[:, 0]  # First token is [CLS]
        return self.classifier(cls_output)


# Build ViT for CIFAR-10
vit = VisionTransformer(
    image_size=32,
    patch_size=4,       # 32/4 = 8 → 64 patches
    num_classes=10,
    d_model=128,
    num_heads=4,
    d_ff=256,
    num_layers=4,
)

# Test with dummy CIFAR-10 images
dummy_images = tf.random.normal([4, 32, 32, 3])
logits = vit(dummy_images, training=False)

print(f"Input images: {dummy_images.shape}")
print(f"Patch size: 4×4 → {(32//4)**2} patches per image")
print(f"Sequence: [CLS] + 64 patches = 65 tokens")
print(f"Output logits: {logits.shape} (batch, num_classes)")
print(f"\nViT Architecture:")
print(f"  d_model=128, num_heads=4, d_ff=256, num_layers=4")
print(f"  Total parameters: {vit.count_params():,}")

Let's train this minimal ViT on CIFAR-10 to see it in action:

import tensorflow as tf
import numpy as np

# Quick ViT training on CIFAR-10
# Load data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = y_train.squeeze()
y_test = y_test.squeeze()

print(f"CIFAR-10: {x_train.shape[0]} train, {x_test.shape[0]} test")
print(f"Image shape: {x_train.shape[1:]}")
print("Classes: 10")

# Build a compact ViT
inputs = tf.keras.layers.Input(shape=(32, 32, 3))

# Patch embedding via Conv2D
patch_size = 4
d_model = 64
patches = tf.keras.layers.Conv2D(d_model, kernel_size=patch_size,
                                  strides=patch_size)(inputs)
# Reshape: (batch, 8, 8, 64) → (batch, 64, 64)
patches = tf.keras.layers.Reshape((-1, d_model))(patches)
num_patches = (32 // patch_size) ** 2  # 64

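# Add a learned positional embedding. Wrapping it in a layer keeps the Embedding
# weights inside the model so they are actually trained; calling Embedding on a
# constant tf.range outside a layer would yield a fixed, untrained tensor.
class AddPositionEmbedding(tf.keras.layers.Layer):
    def __init__(self, num_positions, d_model, **kwargs):
        super().__init__(**kwargs)
        self.pos_emb = tf.keras.layers.Embedding(num_positions, d_model)

    def call(self, x):
        positions = tf.range(tf.shape(x)[1])   # 0 .. num_patches-1
        return x + self.pos_emb(positions)     # broadcast over the batch

x = AddPositionEmbedding(num_patches, d_model)(patches)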

# 2 Transformer encoder blocks
for _ in range(2):
    # MHA + residual + norm
    attn = tf.keras.layers.MultiHeadAttention(
        num_heads=4, key_dim=16
    )(x, x)
    x = tf.keras.layers.LayerNormalization()(x + attn)
    # FFN + residual + norm
    ffn = tf.keras.layers.Dense(128, activation='gelu')(x)
    ffn = tf.keras.layers.Dense(d_model)(ffn)
    x = tf.keras.layers.LayerNormalization()(x + ffn)

# Global average pooling → classifier
x = tf.keras.layers.GlobalAveragePooling1D()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.Model(inputs, outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print(f"\nCompact ViT parameters: {model.count_params():,}")
model.fit(x_train, y_train, epochs=5, batch_size=128,
          validation_split=0.1, verbose=1)
ViT Caveat: Vision Transformers typically need far more data than CNNs when trained from scratch. The original ViT paper pretrained on ImageNet-21k (about 14M images) or JFT-300M. On small datasets like CIFAR-10, a well-tuned ResNet will usually outperform a ViT trained from scratch. ViT's strength comes from pretraining at scale and then fine-tuning, much like BERT in NLP.

The script below summarizes common Transformer configurations and design heuristics:
import tensorflow as tf
import numpy as np

# Summary: Key Transformer hyperparameters and their effects
print("=" * 60)
print("TRANSFORMER HYPERPARAMETER GUIDE")
print("=" * 60)

configs = {
    "BERT-Base": {"d": 768, "h": 12, "ff": 3072, "L": 12, "params": "110M"},
    "BERT-Large": {"d": 1024, "h": 16, "ff": 4096, "L": 24, "params": "340M"},
    "GPT-2 Small": {"d": 768, "h": 12, "ff": 3072, "L": 12, "params": "117M"},
    "GPT-2 Large": {"d": 1280, "h": 20, "ff": 5120, "L": 36, "params": "774M"},
    "ViT-Base/16": {"d": 768, "h": 12, "ff": 3072, "L": 12, "params": "86M"},
}

print(f"\n{'Model':<15} {'d_model':<8} {'heads':<6} {'d_ff':<6} {'layers':<7} {'params':<8}")
print("-" * 55)
for name, cfg in configs.items():
    print(f"{name:<15} {cfg['d']:<8} {cfg['h']:<6} {cfg['ff']:<6} {cfg['L']:<7} {cfg['params']:<8}")

print("\n\nKey Design Choices:")
print("  • d_ff = 4 × d_model (the usual expansion ratio)")
print("  • key_dim = d_model / num_heads")
print("  • Depth vs. width: at a fixed parameter budget, extra depth often helps (heuristic)")
print("  • Pre-norm is generally more stable for deep stacks (>12 layers)")
print("  • Warmup + cosine/inverse-sqrt decay for learning rate")
print("  • Dropout: 0.1 in attention/FFN, 0.1-0.3 in embeddings")

print("\n\nWhen to Use What:")
print("  • Classification → Encoder-only (BERT-style)")
print("  • Generation → Decoder-only (GPT-style)")
print("  • Translation → Full Encoder-Decoder")
print("  • Images → ViT (pretrained) or CNN+Transformer hybrid")
print("  • Small data → Fine-tune pretrained > Train from scratch")
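The parameter counts in the table can be sanity-checked by hand. A rough sketch, assuming a standard encoder layer (Q, K, V and output projections with biases, a two-layer FFN with biases, two LayerNorms) plus the published BERT-Base and ViT-B/16 details: a 30,522-token vocabulary, 512 positions, 2 segment types and a pooler for BERT; 224×224 inputs (196 patches plus [CLS]) and a 1,000-class head for ViT. The function names are hypothetical helpers. With d_ff = 4·d_model, each encoder layer holds roughly 12·d_model² weights.

```python
def encoder_layer_params(d_model, d_ff):
    """Weights + biases in one Transformer encoder layer."""
    attention = 4 * (d_model * d_model + d_model)     # Q, K, V, output projections
    ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model
    layer_norms = 2 * (2 * d_model)                   # gamma and beta, twice
    return attention + ffn + layer_norms


def bert_base_params(vocab=30522, max_pos=512, segments=2, d=768, d_ff=3072, layers=12):
    embeddings = (vocab + max_pos + segments) * d + 2 * d   # + embedding LayerNorm
    pooler = d * d + d                                       # dense layer on [CLS]
    return embeddings + layers * encoder_layer_params(d, d_ff) + pooler


def vit_b16_params(image=224, patch=16, channels=3, d=768, d_ff=3072, layers=12,
                   classes=1000):
    num_patches = (image // patch) ** 2                      # 196
    patch_embed = patch * patch * channels * d + d           # conv kernel + bias
    cls_and_pos = d + (num_patches + 1) * d                  # [CLS] token + positions
    head = 2 * d + d * classes + classes                     # final LayerNorm + classifier
    return patch_embed + cls_and_pos + layers * encoder_layer_params(d, d_ff) + head


per_layer = encoder_layer_params(768, 3072)
print(f"One layer: {per_layer:,} vs 12*d^2 = {12 * 768 ** 2:,}")  # 7,087,872 vs 7,077,888
print(f"BERT-Base: {bert_base_params():,}")                      # 109,482,240
print(f"ViT-B/16:  {vit_b16_params():,}")                        # 86,567,656
```

Both totals land on the table's rounded 110M and 86M; the biases and LayerNorms account for the small gap between one layer and 12·d².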

Next in the Series

In Part 9: Deployment, Performance & Best Practices, we'll take your models to production — TensorFlow Serving, TFLite for mobile/edge, SavedModel format, XLA compilation, mixed precision training, distributed strategies, and profiling tools.