Table of Contents

  1. An Image is Worth 16x16 Words
  2. Patch Embedding Layer
  3. Position Embeddings & [CLS] Token
  4. Transformer Encoder Block
  5. Building ViT from Scratch
  6. Training ViT on CIFAR-10
  7. Attention Visualization
  8. Transfer Learning with Pre-trained ViT

Deep Dive: Vision Transformer (ViT) in TensorFlow

May 3, 2026 · Wasil Zafar · 35 min read

Build the Vision Transformer from scratch in TensorFlow — patch embeddings, positional encoding, multi-head self-attention, and the [CLS] token classification pipeline that rivals CNNs on image recognition.

An Image is Worth 16×16 Words

In 2020, Google Research introduced a radical idea: what if we threw away convolutions entirely and treated images like text? The Vision Transformer (ViT) does exactly this — it splits an image into fixed-size patches, flattens each patch into a vector, and feeds the resulting sequence into a standard Transformer encoder. No convolutions, no pooling, no inductive bias about local connectivity.

Key Insight: A 224×224 image split into 16×16 patches produces 14×14 = 196 “tokens” — a sequence length comparable to a short paragraph of text. The Transformer processes these patch tokens with global self-attention, allowing every patch to attend to every other patch from layer one.
ViT Pipeline: Image to Classification
flowchart LR
    A[Input Image 224x224x3] --> B[Split into 16x16 Patches]
    B --> C[Flatten & Linear Project]
    C --> D[Add Position Embeddings]
    D --> E[Prepend CLS Token]
    E --> F[Transformer Encoder x N]
    F --> G[Extract CLS Output]
    G --> H[MLP Head Classification]

Splitting and Visualizing Patches

The first step is mechanical: reshape a 224×224×3 image into a sequence of 196 flattened patches, each of size 16×16×3 = 768 values.

import tensorflow as tf
import numpy as np

def extract_patches(image, patch_size=16):
    """Split an image into non-overlapping patches.

    Args:
        image: (H, W, 3) tensor, values in [0, 1]
        patch_size: size of each square patch (default 16)

    Returns:
        patches: (num_patches, patch_size*patch_size*3) flattened patches
    """
    # image shape: (224, 224, 3)
    h, w, c = image.shape
    num_patches_h = h // patch_size  # 224 // 16 = 14
    num_patches_w = w // patch_size  # 224 // 16 = 14
    num_patches = num_patches_h * num_patches_w  # 196

    # Reshape into grid of patches
    patches = tf.image.extract_patches(
        images=tf.expand_dims(image, 0),  # add batch dim
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID"
    )
    # patches shape: (1, 14, 14, 768)
    patches = tf.reshape(patches, (num_patches, patch_size * patch_size * c))
    return patches

# Create a dummy 224x224 RGB image
np.random.seed(42)
dummy_image = np.random.rand(224, 224, 3).astype(np.float32)

# Extract patches
patches = extract_patches(dummy_image, patch_size=16)
print(f"Image shape: (224, 224, 3)")
print(f"Patch size: 16x16")
print(f"Number of patches: {patches.shape[0]}")  # 196
print(f"Each patch flattened: {patches.shape[1]} values")  # 768
print(f"Grid: 14 rows x 14 columns = 196 patches")
print(f"\nFirst patch (top-left corner) stats:")
print(f"  Mean pixel: {patches[0].numpy().mean():.4f}")
print(f"  Std pixel:  {patches[0].numpy().std():.4f}")
Visualizing the 14×14 Patch Grid

When you visualize the 196 patches rearranged on a grid, you can see that each patch captures a local region of the image. Unlike CNN receptive fields that grow gradually, the Transformer can attend to distant patches from the very first layer — enabling truly global reasoning from the start.
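
The patch grid can be rendered directly from the output of extract_patches(). Here is a minimal sketch, assuming matplotlib is available; plot_patch_grid is a helper name introduced for illustration:

import matplotlib.pyplot as plt

def plot_patch_grid(patches, patch_size=16, grid=14):
    """Arrange the 196 flattened patches back into their 14x14 grid."""
    fig, axes = plt.subplots(grid, grid, figsize=(6, 6))
    for idx, ax in enumerate(axes.flat):
        # Undo the flattening: (768,) -> (16, 16, 3)
        patch = patches[idx].numpy().reshape(patch_size, patch_size, 3)
        ax.imshow(patch)
        ax.axis("off")
    plt.suptitle("14x14 grid of 16x16 patches")
    plt.show()

plot_patch_grid(patches)  # `patches` comes from the extract_patches() example above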


Patch Embedding Layer

Each 16×16×3 patch (768 raw values) must be projected into a D-dimensional embedding space. The original ViT paper uses a simple linear projection — but this is mathematically equivalent to a Conv2D with kernel size 16 and stride 16. The convolution slides over the image with no overlap, producing one embedding vector per patch.

Implementation with Conv2D

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchEmbedding(layers.Layer):
    """Convert image into patch embeddings using Conv2D.

    Equivalent to: flatten each patch, then multiply by learned weight matrix.
    Conv2D with kernel_size=stride=patch_size achieves this efficiently on GPU.

    Args:
        patch_size: size of each square patch (default 16)
        embed_dim: output embedding dimension (default 768)
    """
    def __init__(self, patch_size=16, embed_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Conv2D with kernel=stride=patch_size = non-overlapping linear projection
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            name="patch_projection"
        )

    def call(self, images):
        """
        Args:
            images: (batch, H, W, 3) input images

        Returns:
            embeddings: (batch, num_patches, embed_dim)
        """
        # (batch, H/P, W/P, embed_dim) e.g. (batch, 14, 14, 768)
        x = self.projection(images)
        batch_size = tf.shape(x)[0]
        h, w = x.shape[1], x.shape[2]
        # Flatten spatial dims: (batch, 196, 768)
        x = tf.reshape(x, (batch_size, h * w, self.embed_dim))
        return x

# Test PatchEmbedding
patch_embed = PatchEmbedding(patch_size=16, embed_dim=768)

# Dummy batch of 4 images, 224x224x3
dummy_batch = tf.random.normal((4, 224, 224, 3))
embeddings = patch_embed(dummy_batch)

print(f"Input shape:  {dummy_batch.shape}")   # (4, 224, 224, 3)
print(f"Output shape: {embeddings.shape}")    # (4, 196, 768)
print(f"Num patches:  {embeddings.shape[1]}") # 196
print(f"Embed dim:    {embeddings.shape[2]}") # 768
print(f"\nParameters: {patch_embed.count_params():,}")
# 16*16*3*768 + 768 (bias) = 590,592
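
To convince yourself that the Conv2D projection really is a per-patch linear layer, reuse the convolution's kernel as a Dense-style weight matrix on the flattened patches and compare outputs. A minimal sketch (the difference should be at floating-point noise level):

import tensorflow as tf

patch_size, embed_dim = 16, 768
image = tf.random.normal((1, 224, 224, 3))

conv = tf.keras.layers.Conv2D(embed_dim, patch_size, strides=patch_size)
conv_out = tf.reshape(conv(image), (-1, embed_dim))           # (196, 768)

# Flatten the conv kernel into a (768, 768) weight matrix
kernel, bias = conv.get_weights()                             # (16, 16, 3, 768), (768,)
w = kernel.reshape(-1, embed_dim)

# Flatten the image into patches (same element order as the kernel flattening)
patches = tf.image.extract_patches(
    image, sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1], padding="VALID")
patches = tf.reshape(patches, (-1, patch_size * patch_size * 3))  # (196, 768)

dense_out = tf.matmul(patches, w) + bias
print(tf.reduce_max(tf.abs(conv_out - dense_out)).numpy())    # ~1e-5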

Position Embeddings & [CLS] Token

Self-attention is permutation-equivariant — without positional information, the Transformer cannot distinguish patch ordering. ViT uses learnable 1D position embeddings (not the fixed sinusoidal encoding from the original Transformer). Each of the 197 positions (196 patches + 1 [CLS] token) gets a learned vector added to the patch embedding.

The [CLS] Token

Borrowed from BERT, the [CLS] token is a learnable embedding prepended to the patch sequence. After passing through all Transformer layers, the [CLS] output serves as the global image representation for classification — it aggregates information from all patches via self-attention.

import tensorflow as tf
from tensorflow.keras import layers

class CLSTokenAndPositionEmbedding(layers.Layer):
    """Add learnable [CLS] token and position embeddings.

    Sequence flow:
        patch_embeddings (batch, 196, D)
        -> prepend [CLS] token -> (batch, 197, D)
        -> add position embeddings -> (batch, 197, D)

    Args:
        num_patches: number of image patches (e.g., 196)
        embed_dim: embedding dimension (e.g., 768)
    """
    def __init__(self, num_patches, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.embed_dim = embed_dim

        # Learnable [CLS] token: (1, 1, embed_dim)
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="random_normal",
            trainable=True
        )

        # Learnable position embeddings for 197 positions
        # (1 CLS + 196 patches)
        self.position_embedding = self.add_weight(
            name="position_embedding",
            shape=(1, num_patches + 1, embed_dim),
            initializer="random_normal",
            trainable=True
        )

    def call(self, patch_embeddings):
        """
        Args:
            patch_embeddings: (batch, num_patches, embed_dim)

        Returns:
            encoded: (batch, num_patches + 1, embed_dim)
        """
        batch_size = tf.shape(patch_embeddings)[0]

        # Broadcast [CLS] token to batch: (batch, 1, embed_dim)
        cls_tokens = tf.broadcast_to(
            self.cls_token, (batch_size, 1, self.embed_dim)
        )

        # Prepend [CLS]: (batch, 197, embed_dim)
        x = tf.concat([cls_tokens, patch_embeddings], axis=1)

        # Add position embeddings
        x = x + self.position_embedding

        return x

# Test
num_patches = 196
embed_dim = 768

pos_embed_layer = CLSTokenAndPositionEmbedding(num_patches, embed_dim)
dummy_patches = tf.random.normal((4, 196, 768))
output = pos_embed_layer(dummy_patches)

print(f"Input:  {dummy_patches.shape}")  # (4, 196, 768)
print(f"Output: {output.shape}")         # (4, 197, 768)
print(f"  -> [CLS] token added at position 0")
print(f"  -> Position embeddings added to all 197 tokens")
print(f"\nParameters:")
print(f"  CLS token:        {1 * embed_dim:,} = {768}")
print(f"  Position embed:   {197 * embed_dim:,} = {197 * 768:,}")
print(f"  Total:            {pos_embed_layer.count_params():,}")

Transformer Encoder Block

The core building block of ViT is the standard Transformer encoder: LayerNorm, Multi-Head Self-Attention, residual connection, LayerNorm, MLP (with GELU activation), and another residual connection. This block is repeated N times (12 for ViT-Base, 24 for ViT-Large).

Scaled Dot-Product Attention

The attention mechanism computes compatibility scores between all pairs of tokens:

$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where $Q$, $K$, $V$ are the query, key, and value projections of the input, and $d_k$ is the dimension per head. For ViT-Base with 12 heads and $D=768$, we have $d_k = 768 / 12 = 64$.
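
As a sanity check on the formula, here is a minimal single-head sketch with random Q, K, V tensors at ViT-Base shapes (197 tokens, d_k = 64); layers.MultiHeadAttention performs the same computation across 12 heads plus the learned projections:

import tensorflow as tf

batch, seq_len, d_k = 2, 197, 64
q = tf.random.normal((batch, seq_len, d_k))
k = tf.random.normal((batch, seq_len, d_k))
v = tf.random.normal((batch, seq_len, d_k))

scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(d_k, tf.float32))
weights = tf.nn.softmax(scores, axis=-1)   # (2, 197, 197): every token attends to every token
output = tf.matmul(weights, v)             # (2, 197, 64)
print(weights.shape, output.shape)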

Transformer Encoder Block Internals
flowchart TD
    A[Input Tokens batch x 197 x D] --> B[LayerNorm]
    B --> C[Multi-Head Self-Attention]
    C --> D[+ Residual]
    A --> D
    D --> E[LayerNorm]
    E --> F[MLP: Dense 4D -> GELU -> Dense D]
    F --> G[+ Residual]
    D --> G
    G --> H[Output Tokens batch x 197 x D]
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class TransformerBlock(layers.Layer):
    """Single Transformer encoder block for ViT.

    Architecture:
        x -> LayerNorm -> MHSA -> + residual
          -> LayerNorm -> MLP  -> + residual

    Args:
        embed_dim: model dimension D (e.g., 768)
        num_heads: number of attention heads (e.g., 12)
        mlp_ratio: MLP hidden dim expansion factor (default 4)
        dropout_rate: dropout probability (default 0.1)
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4,
                 dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout_rate
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

        mlp_hidden = embed_dim * mlp_ratio
        self.mlp = keras.Sequential([
            layers.Dense(mlp_hidden, activation="gelu"),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),
            layers.Dropout(dropout_rate),
        ])

    def call(self, x, training=False):
        """
        Args:
            x: (batch, seq_len, embed_dim) input tokens
            training: whether in training mode

        Returns:
            (batch, seq_len, embed_dim) transformed tokens
        """
        # MHSA with pre-norm and residual
        x_norm = self.norm1(x)
        attn_output = self.attn(
            query=x_norm, key=x_norm, value=x_norm,
            training=training
        )
        x = x + attn_output

        # MLP with pre-norm and residual
        x_norm = self.norm2(x)
        mlp_output = self.mlp(x_norm, training=training)
        x = x + mlp_output

        return x

# Test TransformerBlock
block = TransformerBlock(embed_dim=768, num_heads=12, mlp_ratio=4)

# Input: batch of 4, sequence of 197 tokens, 768 dims
dummy_input = tf.random.normal((4, 197, 768))
output = block(dummy_input, training=False)

print(f"Input shape:  {dummy_input.shape}")  # (4, 197, 768)
print(f"Output shape: {output.shape}")       # (4, 197, 768)
print(f"\nBlock parameters: {block.count_params():,}")
print(f"  MHSA:  Q,K,V projections + output = ~2.4M")
print(f"  MLP:   768->3072->768 = ~4.7M")
print(f"  Norms: ~3K")

Building ViT from Scratch

Now we assemble all components into the complete Vision Transformer: PatchEmbedding, CLSToken + PositionEmbedding, a stack of TransformerBlocks, final LayerNorm, and a classification head that reads the [CLS] token output.

Model Variants

ViT Configuration Variants
  • ViT-Tiny: 6 layers, 256 hidden, 4 heads, ~5.7M params
  • ViT-Small: 8 layers, 384 hidden, 6 heads, ~22M params
  • ViT-Base: 12 layers, 768 hidden, 12 heads, ~86M params
  • ViT-Large: 24 layers, 1024 hidden, 16 heads, ~307M params
  • ViT-Huge: 32 layers, 1280 hidden, 16 heads, ~632M params
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchEmbedding(layers.Layer):
    """Patch embedding via Conv2D projection."""
    def __init__(self, patch_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.proj = layers.Conv2D(
            embed_dim, kernel_size=patch_size,
            strides=patch_size, padding="valid"
        )

    def call(self, x):
        x = self.proj(x)
        b = tf.shape(x)[0]
        x = tf.reshape(x, (b, -1, x.shape[-1]))
        return x


class VisionTransformer(keras.Model):
    """Complete Vision Transformer (ViT) for image classification.

    Args:
        image_size: input image size (square)
        patch_size: size of each patch
        num_classes: number of output classes
        embed_dim: transformer hidden dimension D
        depth: number of transformer blocks
        num_heads: number of attention heads
        mlp_ratio: MLP expansion factor
        dropout_rate: dropout probability
    """
    def __init__(self, image_size=224, patch_size=16, num_classes=10,
                 embed_dim=768, depth=12, num_heads=12,
                 mlp_ratio=4, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = (image_size // patch_size) ** 2  # 196
        self.embed_dim = embed_dim

        # 1. Patch embedding
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)

        # 2. [CLS] token
        self.cls_token = self.add_weight(
            name="cls_token", shape=(1, 1, embed_dim),
            initializer="zeros", trainable=True
        )

        # 3. Position embeddings (197 = 196 patches + 1 CLS)
        self.pos_embed = self.add_weight(
            name="pos_embed", shape=(1, self.num_patches + 1, embed_dim),
            initializer="random_normal", trainable=True
        )

        self.pos_drop = layers.Dropout(dropout_rate)

        # 4. Transformer encoder blocks
        self.blocks = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate)
            for _ in range(depth)
        ]

        # 5. Final LayerNorm
        self.norm = layers.LayerNormalization(epsilon=1e-6)

        # 6. Classification head
        self.head = layers.Dense(num_classes)

    def call(self, images, training=False):
        """
        Args:
            images: (batch, H, W, 3) input images

        Returns:
            logits: (batch, num_classes) classification logits
        """
        batch_size = tf.shape(images)[0]

        # Patch embedding: (batch, 196, D)
        x = self.patch_embed(images)

        # Prepend [CLS]: (batch, 197, D)
        cls = tf.broadcast_to(self.cls_token, (batch_size, 1, self.embed_dim))
        x = tf.concat([cls, x], axis=1)

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x, training=training)

        # Transformer encoder
        for block in self.blocks:
            x = block(x, training=training)

        # Final norm
        x = self.norm(x)

        # Extract [CLS] token output (position 0)
        cls_output = x[:, 0]

        # Classification head
        logits = self.head(cls_output)
        return logits

# Build ViT-Base
vit_base = VisionTransformer(
    image_size=224, patch_size=16, num_classes=1000,
    embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, dropout_rate=0.1
)

# Forward pass to build the model
dummy_images = tf.random.normal((2, 224, 224, 3))
logits = vit_base(dummy_images, training=False)

print(f"ViT-Base Configuration:")
print(f"  Image size:   224x224")
print(f"  Patch size:   16x16")
print(f"  Num patches:  196")
print(f"  Embed dim:    768")
print(f"  Depth:        12 blocks")
print(f"  Heads:        12")
print(f"  MLP ratio:    4x (768 -> 3072 -> 768)")
print(f"  Num classes:  1000")
print(f"\nOutput shape: {logits.shape}")  # (2, 1000)
print(f"Total parameters: {vit_base.count_params():,}")

Training ViT on CIFAR-10

The original ViT paper demonstrated that Transformers need massive datasets (JFT-300M, ImageNet-21k) to outperform CNNs. On smaller datasets like CIFAR-10 (50k training images), ViT underperforms unless you apply heavy regularization: smaller patch sizes, fewer layers, aggressive augmentation (RandAugment, Mixup, CutMix), and cosine learning rate with warmup.

Data Regime Warning: A ViT-Base trained from scratch on CIFAR-10 will severely overfit. Use ViT-Tiny (6 layers, 256 hidden) for small datasets, or use pre-trained weights. The lack of inductive bias (no locality, no translation equivariance) means ViT must learn these properties from data alone.
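
The augmentation pipeline shown later covers flips, rotations, and zoom; Mixup (mentioned above) can be layered on top of it. A minimal sketch, assuming one-hot labels and the common alpha=0.2 (the Beta samples are built from two Gamma draws, since tf.random has no Beta sampler):

import tensorflow as tf

def mixup(images, labels_onehot, alpha=0.2):
    """Mix each sample with a randomly chosen partner from the same batch."""
    shape = tf.shape(images)[:1]                 # (batch,)
    # lambda ~ Beta(alpha, alpha), built from two Gamma draws
    g1 = tf.random.gamma(shape, alpha)
    g2 = tf.random.gamma(shape, alpha)
    lam = g1 / (g1 + g2)
    lam_img = tf.reshape(lam, (-1, 1, 1, 1))
    lam_lab = tf.reshape(lam, (-1, 1))
    idx = tf.random.shuffle(tf.range(tf.shape(images)[0]))
    mixed_images = lam_img * images + (1.0 - lam_img) * tf.gather(images, idx)
    mixed_labels = lam_lab * labels_onehot + (1.0 - lam_lab) * tf.gather(labels_onehot, idx)
    return mixed_images, mixed_labels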

Training with Heavy Augmentation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

def build_vit_tiny(image_size=32, patch_size=4, num_classes=10):
    """Build ViT-Tiny for CIFAR-10 (small images, small patches).

    CIFAR-10 images are 32x32, so we use patch_size=4:
    32/4 = 8 -> 8x8 = 64 patches per image.

    Args:
        image_size: input size (32 for CIFAR-10)
        patch_size: patch size (4 for 32x32 images)
        num_classes: number of classes (10 for CIFAR-10)

    Returns:
        Compiled Keras model ready for training
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))

    # Patch embedding via Conv2D
    x = layers.Conv2D(256, kernel_size=patch_size,
                      strides=patch_size, padding="valid")(inputs)
    # x shape: (batch, 8, 8, 256)
    x = layers.Reshape((64, 256))(x)  # 64 patches, 256 dims

    # [CLS] token and position embeddings
    num_patches = 64
    embed_dim = 256

    # Reuse the CLSTokenAndPositionEmbedding layer defined earlier: a bare
    # tf.Variable captured by a Lambda is not tracked as a model weight, so the
    # [CLS] token and position embeddings would never be trained.
    x = CLSTokenAndPositionEmbedding(num_patches, embed_dim)(x)  # (batch, 65, 256)
    x = layers.Dropout(0.1)(x)

    # 6 Transformer blocks (ViT-Tiny)
    for _ in range(6):
        # Pre-norm MHSA + residual
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
        attn = layers.MultiHeadAttention(
            num_heads=4, key_dim=64, dropout=0.1
        )(x_norm, x_norm)
        x = x + attn

        # Pre-norm MLP + residual
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
        mlp = layers.Dense(1024, activation="gelu")(x_norm)
        mlp = layers.Dropout(0.1)(mlp)
        mlp = layers.Dense(256)(mlp)
        mlp = layers.Dropout(0.1)(mlp)
        x = x + mlp

    # Final norm and classification from [CLS]
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    cls_output = x[:, 0]  # Extract [CLS] token
    outputs = layers.Dense(num_classes)(cls_output)

    model = keras.Model(inputs, outputs, name="vit_tiny_cifar10")
    return model

# Build and compile
model = build_vit_tiny(image_size=32, patch_size=4, num_classes=10)

print(f"ViT-Tiny for CIFAR-10:")
print(f"  Image: 32x32, Patch: 4x4, Patches: 64")
print(f"  Embed dim: 256, Heads: 4, Depth: 6")
print(f"  Parameters: {model.count_params():,}")
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

def get_cifar10_augmented(batch_size=128):
    """Load CIFAR-10 with heavy augmentation for ViT training.

    Augmentation pipeline:
    - RandAugment (magnitude 9, 2 operations)
    - Random horizontal flip
    - Normalization to [0, 1]

    Returns:
        train_ds, val_ds: tf.data.Dataset pipelines
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

    # Normalize to [0, 1]
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0
    y_train = y_train.squeeze()
    y_test = y_test.squeeze()

    # Training augmentation
    data_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ], name="augmentation")

    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_ds = train_ds.shuffle(10000).batch(batch_size)
    train_ds = train_ds.map(
        lambda x, y: (data_augmentation(x, training=True), y),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds

def cosine_warmup_schedule(total_steps, warmup_steps, peak_lr=1e-3):
    """Cosine decay with linear warmup learning rate schedule.

    Args:
        total_steps: total training steps
        warmup_steps: steps for linear warmup
        peak_lr: maximum learning rate

    Returns:
        tf.keras.optimizers.schedules.LearningRateSchedule
    """
    warmup = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=0.0,
        decay_steps=warmup_steps,
        end_learning_rate=peak_lr,
        power=1.0  # linear
    )
    cosine = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=peak_lr,
        decay_steps=total_steps - warmup_steps,
        alpha=1e-5  # minimum LR ratio
    )

    class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, warmup_sched, cosine_sched, warmup_steps):
            super().__init__()
            self.warmup_sched = warmup_sched
            self.cosine_sched = cosine_sched
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            return tf.cond(
                step < self.warmup_steps,
                lambda: self.warmup_sched(step),
                lambda: self.cosine_sched(step - self.warmup_steps)
            )

    return WarmupCosine(warmup, cosine, warmup_steps)

# Training configuration
epochs = 200
batch_size = 128
steps_per_epoch = 50000 // batch_size  # ~390
total_steps = steps_per_epoch * epochs
warmup_steps = steps_per_epoch * 10  # 10 epoch warmup

lr_schedule = cosine_warmup_schedule(total_steps, warmup_steps, peak_lr=1e-3)

print(f"Training Configuration:")
print(f"  Epochs: {epochs}")
print(f"  Batch size: {batch_size}")
print(f"  Steps/epoch: {steps_per_epoch}")
print(f"  Total steps: {total_steps:,}")
print(f"  Warmup steps: {warmup_steps:,} ({10} epochs)")
print(f"  Peak LR: 1e-3")
print(f"  Schedule: Linear warmup + Cosine decay")
print(f"  Optimizer: AdamW (weight_decay=0.05)")
print(f"  Label smoothing: 0.1")
print(f"\nExpected accuracy: ~92-94% (CIFAR-10)")
print(f"Training time: ~2 hours on single V100")

Attention Visualization

One of the most compelling properties of ViT is that attention maps are directly interpretable. We can extract the attention weights from each head and layer to see which patches the model attends to when making predictions. Early layers tend to show local attention patterns (like convolutions), while deeper layers capture long-range semantic relationships.

Extracting and Visualizing Attention Maps

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

class TransformerBlockWithAttnOutput(layers.Layer):
    """Transformer block that also returns attention weights.

    Same architecture as standard TransformerBlock but captures
    the attention weight matrix for visualization.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = keras.Sequential([
            layers.Dense(embed_dim * mlp_ratio, activation="gelu"),
            layers.Dense(embed_dim),
        ])

    def call(self, x, return_attention=False):
        x_norm = self.norm1(x)
        if return_attention:
            attn_output, attn_weights = self.attn(
                x_norm, x_norm, x_norm,
                return_attention_scores=True
            )
        else:
            attn_output = self.attn(x_norm, x_norm, x_norm)
            attn_weights = None

        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x, attn_weights


def extract_attention_maps(model_blocks, input_tokens, num_patches_h=14):
    """Extract attention maps from all transformer blocks.

    Args:
        model_blocks: list of TransformerBlockWithAttnOutput layers
        input_tokens: (1, seq_len, embed_dim) prepared input
        num_patches_h: patches per row (14 for 224/16)

    Returns:
        attention_maps: list of (num_heads, seq_len, seq_len) arrays
    """
    attention_maps = []
    x = input_tokens

    for i, block in enumerate(model_blocks):
        x, attn_weights = block(x, return_attention=True)
        # attn_weights shape: (batch, num_heads, seq_len, seq_len)
        attention_maps.append(attn_weights[0].numpy())  # remove batch

    return attention_maps


def visualize_cls_attention(attn_weights, num_patches_h=14,
                            layer_idx=0, head_idx=0):
    """Extract [CLS] token attention as a spatial heatmap.

    The [CLS] token is at position 0. Its attention to patches
    [1:197] forms a 14x14 spatial map showing what the model
    'looks at' for classification.

    Args:
        attn_weights: (num_heads, 197, 197) attention matrix
        num_patches_h: patches per row
        layer_idx: which layer (for display)
        head_idx: which attention head

    Returns:
        heatmap: (14, 14) attention heatmap
    """
    # CLS attention to all patches (skip CLS-to-CLS at index 0)
    cls_attn = attn_weights[head_idx, 0, 1:]  # (196,)

    # Reshape to spatial grid
    heatmap = cls_attn.reshape(num_patches_h, num_patches_h)

    # Normalize to [0, 1]
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

    print(f"Layer {layer_idx}, Head {head_idx}:")
    print(f"  CLS attention shape: {cls_attn.shape}")
    print(f"  Heatmap shape: {heatmap.shape}")
    print(f"  Max attention: {cls_attn.max():.4f}")
    print(f"  Min attention: {cls_attn.min():.4f}")
    print(f"  Entropy: {-np.sum(cls_attn * np.log(cls_attn + 1e-8)):.4f}")

    return heatmap

# Demonstrate attention extraction
num_heads = 12
seq_len = 197  # 196 patches + 1 CLS

# Simulate attention weights from different layers
np.random.seed(42)

# Early layer: more uniform/local attention
early_attn = np.random.dirichlet(np.ones(seq_len) * 5, size=(num_heads, seq_len))
# Later layer: more focused/sparse attention
late_attn = np.random.dirichlet(np.ones(seq_len) * 0.5, size=(num_heads, seq_len))

print("Attention Map Analysis:")
print("=" * 50)
print("\nEarly Layer (Layer 1) - Diffuse attention:")
heatmap_early = visualize_cls_attention(early_attn, layer_idx=1, head_idx=0)

print("\nLate Layer (Layer 11) - Focused attention:")
heatmap_late = visualize_cls_attention(late_attn, layer_idx=11, head_idx=0)

print(f"\nObservation:")
print(f"  Early layers: attention spread across many patches (local patterns)")
print(f"  Late layers: attention concentrated on semantically relevant regions")
Attention Patterns by Layer Depth:
  • Layers 1-3: Local attention resembling convolution kernels. Nearby patches attend to each other.
  • Layers 4-8: Mid-range patterns emerge. Some heads specialize in horizontal edges, others in vertical structures.
  • Layers 9-12: Global semantic attention. The [CLS] token focuses on object-relevant patches, ignoring background.

Transfer Learning with Pre-trained ViT

In practice, training ViT from scratch on small datasets yields modest results. The power of ViT comes from pre-training on massive datasets (ImageNet-21k or JFT-300M) and then fine-tuning on the target task. A ViT-Base pre-trained on ImageNet-21k and fine-tuned on a flowers dataset can reach 98%+ accuracy with just a few epochs of training.

Fine-Tuning Pipeline

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

def build_finetuning_model(base_model_name="vit-base-patch16-224",
                           num_classes=5, image_size=224):
    """Build fine-tuning model from pre-trained ViT weights.

    Strategy:
    1. Load pre-trained ViT backbone (frozen initially)
    2. Replace classification head for new task
    3. Unfreeze top layers for fine-tuning
    4. Train with low learning rate

    Args:
        base_model_name: pre-trained model identifier
        num_classes: number of target classes
        image_size: input image size

    Returns:
        Compiled model ready for fine-tuning
    """
    # In practice, load from TensorFlow Hub or HuggingFace:
    # import tensorflow_hub as hub
    # base_url = "https://tfhub.dev/sayakpaul/vit_b16_fe/1"
    # base_model = hub.KerasLayer(base_url, trainable=True)

    # Simulated architecture for demonstration
    inputs = layers.Input(shape=(image_size, image_size, 3))

    # Pre-processing: normalize to [-1, 1] (ViT convention)
    x = layers.Rescaling(scale=1.0/127.5, offset=-1.0)(inputs)

    # Simulated ViT backbone output (768-dim CLS embedding)
    # In real code: x = base_model(x)
    x = layers.Conv2D(768, 16, strides=16, padding="valid")(x)
    x = layers.Reshape((196, 768))(x)
    x = layers.GlobalAveragePooling1D()(x)  # Simplified

    # New classification head
    x = layers.LayerNormalization()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(256, activation="gelu")(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(num_classes)(x)

    model = keras.Model(inputs, outputs, name="vit_finetuned")
    return model

# Build fine-tuning model for 5-class flowers dataset
model = build_finetuning_model(num_classes=5, image_size=224)

# Two-phase training strategy
print("Fine-Tuning Strategy:")
print("=" * 50)
print()
print("Phase 1: Head-only training (5 epochs)")
print("  - Freeze all backbone layers")
print("  - Train only new classification head")
print("  - LR: 1e-3, Optimizer: Adam")
print("  - Purpose: warm up the new head")
print()
print("Phase 2: Full fine-tuning (15 epochs)")
print("  - Unfreeze top 4 transformer blocks")
print("  - LR: 1e-5 (very low to preserve features)")
print("  - Weight decay: 0.01")
print("  - Purpose: adapt backbone to target domain")
print()
print(f"Model parameters: {model.count_params():,}")
import numpy as np

def compare_vit_vs_cnn(dataset_sizes, vit_from_scratch, vit_finetuned,
                       resnet50_finetuned, efficientnet_finetuned):
    """Compare ViT vs CNN accuracy across dataset sizes.

    Shows the data regime where each architecture excels:
    - Small data (<10k): CNNs win due to inductive bias
    - Medium data (10k-100k): Close competition
    - Large data (>100k): ViT matches or exceeds CNNs

    Args:
        dataset_sizes: list of training set sizes
        vit_from_scratch: ViT accuracy at each size
        vit_finetuned: Pre-trained ViT accuracy
        resnet50_finetuned: ResNet-50 fine-tuned accuracy
        efficientnet_finetuned: EfficientNet-B0 accuracy

    Returns:
        Comparison table printed to stdout
    """
    print(f"{'Dataset Size':<15} {'ViT Scratch':<14} {'ViT Pretrain':<14} "
          f"{'ResNet-50':<12} {'EffNet-B0':<12} {'Winner':<15}")
    print("-" * 82)

    for i, size in enumerate(dataset_sizes):
        scores = {
            "ViT Scratch": vit_from_scratch[i],
            "ViT Pretrain": vit_finetuned[i],
            "ResNet-50": resnet50_finetuned[i],
            "EffNet-B0": efficientnet_finetuned[i]
        }
        winner = max(scores, key=scores.get)
        size_str = f"{size:,}" if size < 1000000 else f"{size // 1000}k"

        print(f"{size_str:<15} {vit_from_scratch[i]:<14.1f} "
              f"{vit_finetuned[i]:<14.1f} {resnet50_finetuned[i]:<12.1f} "
              f"{efficientnet_finetuned[i]:<12.1f} {winner:<15}")


# Accuracy comparison across data regimes
dataset_sizes = [1000, 5000, 10000, 50000, 100000, 1000000]

# Simulated accuracy values based on published benchmarks
vit_scratch =      [52.3, 68.7, 76.2, 88.4, 92.1, 97.8]
vit_pretrained =   [89.5, 93.2, 95.1, 97.3, 98.2, 99.1]
resnet50_ft =      [82.1, 88.9, 91.5, 94.8, 96.0, 97.5]
efficientnet_ft =  [84.7, 90.2, 92.8, 95.6, 96.8, 97.9]

print("ViT vs CNN: Accuracy by Dataset Size")
print("=" * 82)
compare_vit_vs_cnn(
    dataset_sizes, vit_scratch, vit_pretrained,
    resnet50_ft, efficientnet_ft
)

print("\nKey Takeaways:")
print("  1. ViT from scratch needs >50k images to compete with CNNs")
print("  2. Pre-trained ViT dominates at ALL dataset sizes")
print("  3. CNNs (ResNet, EfficientNet) have better inductive bias for small data")
print("  4. At scale (>100k), architectures converge in performance")
When ViT Beats CNN vs When CNN Wins:
  • ViT wins with pre-trained weights (always), with large datasets (>100k), when global context matters (scene understanding, multi-object reasoning), and for transfer learning across domains.
  • CNN wins on small datasets without pre-training, for edge deployment (smaller models), when locality is the dominant signal (texture classification), and with limited compute budget for training.
  • Hybrid approaches (ConvNeXt, CoAtNet) combine the best of both — CNN stems for local features plus Transformer blocks for global attention.