An Image is Worth 16×16 Words
In 2020, Google Research introduced a radical idea: what if we threw away convolutions entirely and treated images like text? The Vision Transformer (ViT) does exactly this — it splits an image into fixed-size patches, flattens each patch into a vector, and feeds the resulting sequence into a standard Transformer encoder. No convolutions, no pooling, no inductive bias about local connectivity.
flowchart LR
A[Input Image
224x224x3] --> B[Split into
16x16 Patches]
B --> C[Flatten &
Linear Project]
C --> D[Add Position
Embeddings]
D --> E[Prepend
CLS Token]
E --> F[Transformer
Encoder x N]
F --> G[Extract CLS
Output]
G --> H[MLP Head
Classification]
Splitting and Visualizing Patches
The first step is mechanical: reshape a 224×224×3 image into a sequence of 196 flattened patches, each of size 16×16×3 = 768 values.
import tensorflow as tf
import numpy as np
def extract_patches(image, patch_size=16):
"""Split an image into non-overlapping patches.
Args:
image: (H, W, 3) tensor, values in [0, 1]
patch_size: size of each square patch (default 16)
Returns:
patches: (num_patches, patch_size*patch_size*3) flattened patches
"""
# image shape: (224, 224, 3)
h, w, c = image.shape
num_patches_h = h // patch_size # 224 // 16 = 14
num_patches_w = w // patch_size # 224 // 16 = 14
num_patches = num_patches_h * num_patches_w # 196
# Reshape into grid of patches
patches = tf.image.extract_patches(
images=tf.expand_dims(image, 0), # add batch dim
sizes=[1, patch_size, patch_size, 1],
strides=[1, patch_size, patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID"
)
# patches shape: (1, 14, 14, 768)
patches = tf.reshape(patches, (num_patches, patch_size * patch_size * c))
return patches
# Create a dummy 224x224 RGB image
np.random.seed(42)
dummy_image = np.random.rand(224, 224, 3).astype(np.float32)
# Extract patches
patches = extract_patches(dummy_image, patch_size=16)
print(f"Image shape: (224, 224, 3)")
print(f"Patch size: 16x16")
print(f"Number of patches: {patches.shape[0]}") # 196
print(f"Each patch flattened: {patches.shape[1]} values") # 768
print(f"Grid: 14 rows x 14 columns = 196 patches")
print(f"\nFirst patch (top-left corner) stats:")
print(f" Mean pixel: {patches[0].numpy().mean():.4f}")
print(f" Std pixel: {patches[0].numpy().std():.4f}")
Visualizing the 14×14 Patch Grid
When you lay the 196 patches back out on their 14×14 grid, you can see that each patch captures a small local region of the image (see the sketch below). Unlike CNN receptive fields that grow gradually, the Transformer can attend to distant patches from the very first layer — enabling truly global reasoning from the start.
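To make this concrete, here is a minimal visualization sketch. It assumes the `patches` tensor from the `extract_patches()` example above and uses matplotlib (not imported earlier); with a real photo instead of the random dummy image, each tile would show a recognizable local region.
import matplotlib.pyplot as plt

# Reshape the 196 flattened patches back into 16x16x3 thumbnails and tile
# them on the 14x14 grid they came from. Assumes `patches` from the
# extract_patches() example above (shape (196, 768)).
patch_grid = patches.numpy().reshape(14, 14, 16, 16, 3)

fig, axes = plt.subplots(14, 14, figsize=(7, 7))
for row in range(14):
    for col in range(14):
        axes[row, col].imshow(patch_grid[row, col])
        axes[row, col].axis("off")
plt.suptitle("224x224 image split into 196 patches of 16x16")
plt.tight_layout()
plt.show()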
Patch Embedding Layer
Each 16×16×3 patch (768 raw values) must be projected into a D-dimensional embedding space. The original ViT paper uses a simple linear projection — but this is mathematically equivalent to a Conv2D with kernel size 16 and stride 16. The convolution slides over the image with no overlap, producing one embedding vector per patch.
Implementation with Conv2D
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class PatchEmbedding(layers.Layer):
"""Convert image into patch embeddings using Conv2D.
Equivalent to: flatten each patch, then multiply by learned weight matrix.
Conv2D with kernel_size=stride=patch_size achieves this efficiently on GPU.
Args:
patch_size: size of each square patch (default 16)
embed_dim: output embedding dimension (default 768)
"""
def __init__(self, patch_size=16, embed_dim=768, **kwargs):
super().__init__(**kwargs)
self.patch_size = patch_size
self.embed_dim = embed_dim
# Conv2D with kernel=stride=patch_size = non-overlapping linear projection
self.projection = layers.Conv2D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
padding="valid",
name="patch_projection"
)
def call(self, images):
"""
Args:
images: (batch, H, W, 3) input images
Returns:
embeddings: (batch, num_patches, embed_dim)
"""
# (batch, H/P, W/P, embed_dim) e.g. (batch, 14, 14, 768)
x = self.projection(images)
batch_size = tf.shape(x)[0]
h, w = x.shape[1], x.shape[2]
# Flatten spatial dims: (batch, 196, 768)
x = tf.reshape(x, (batch_size, h * w, self.embed_dim))
return x
# Test PatchEmbedding
patch_embed = PatchEmbedding(patch_size=16, embed_dim=768)
# Dummy batch of 4 images, 224x224x3
dummy_batch = tf.random.normal((4, 224, 224, 3))
embeddings = patch_embed(dummy_batch)
print(f"Input shape: {dummy_batch.shape}") # (4, 224, 224, 3)
print(f"Output shape: {embeddings.shape}") # (4, 196, 768)
print(f"Num patches: {embeddings.shape[1]}") # 196
print(f"Embed dim: {embeddings.shape[2]}") # 768
print(f"\nParameters: {patch_embed.count_params():,}")
# 16*16*3*768 + 768 (bias) = 590,592
Position Embeddings & [CLS] Token
Self-attention is permutation-equivariant — without positional information, the Transformer cannot distinguish patch ordering. ViT uses learnable 1D position embeddings (not the fixed sinusoidal encoding from the original Transformer). Each of the 197 positions (196 patches + 1 [CLS] token) gets a learned vector added to the patch embedding.
The [CLS] Token
Borrowed from BERT, the [CLS] token is a learnable embedding prepended to the patch sequence. After passing through all Transformer layers, the [CLS] output serves as the global image representation for classification — it aggregates information from all patches via self-attention.
import tensorflow as tf
from tensorflow.keras import layers
class CLSTokenAndPositionEmbedding(layers.Layer):
"""Add learnable [CLS] token and position embeddings.
Sequence flow:
patch_embeddings (batch, 196, D)
-> prepend [CLS] token -> (batch, 197, D)
-> add position embeddings -> (batch, 197, D)
Args:
num_patches: number of image patches (e.g., 196)
embed_dim: embedding dimension (e.g., 768)
"""
def __init__(self, num_patches, embed_dim, **kwargs):
super().__init__(**kwargs)
self.num_patches = num_patches
self.embed_dim = embed_dim
# Learnable [CLS] token: (1, 1, embed_dim)
self.cls_token = self.add_weight(
name="cls_token",
shape=(1, 1, embed_dim),
initializer="random_normal",
trainable=True
)
# Learnable position embeddings for 197 positions
# (1 CLS + 196 patches)
self.position_embedding = self.add_weight(
name="position_embedding",
shape=(1, num_patches + 1, embed_dim),
initializer="random_normal",
trainable=True
)
def call(self, patch_embeddings):
"""
Args:
patch_embeddings: (batch, num_patches, embed_dim)
Returns:
encoded: (batch, num_patches + 1, embed_dim)
"""
batch_size = tf.shape(patch_embeddings)[0]
# Broadcast [CLS] token to batch: (batch, 1, embed_dim)
cls_tokens = tf.broadcast_to(
self.cls_token, (batch_size, 1, self.embed_dim)
)
# Prepend [CLS]: (batch, 197, embed_dim)
x = tf.concat([cls_tokens, patch_embeddings], axis=1)
# Add position embeddings
x = x + self.position_embedding
return x
# Test
num_patches = 196
embed_dim = 768
pos_embed_layer = CLSTokenAndPositionEmbedding(num_patches, embed_dim)
dummy_patches = tf.random.normal((4, 196, 768))
output = pos_embed_layer(dummy_patches)
print(f"Input: {dummy_patches.shape}") # (4, 196, 768)
print(f"Output: {output.shape}") # (4, 197, 768)
print(f" -> [CLS] token added at position 0")
print(f" -> Position embeddings added to all 197 tokens")
print(f"\nParameters:")
print(f" CLS token: {1 * embed_dim:,} = {768}")
print(f" Position embed: {197 * embed_dim:,} = {197 * 768:,}")
print(f" Total: {pos_embed_layer.count_params():,}")
Transformer Encoder Block
The core building block of ViT is the standard Transformer encoder: LayerNorm, Multi-Head Self-Attention, residual connection, LayerNorm, MLP (with GELU activation), and another residual connection. This block is repeated N times (12 for ViT-Base, 24 for ViT-Large).
Scaled Dot-Product Attention
The attention mechanism computes compatibility scores between all pairs of tokens:
$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Where $Q$, $K$, $V$ are the query, key, and value projections of the input, and $d_k$ is the dimension per head. For ViT-Base with 12 heads and $D=768$, we have $d_k = 768 / 12 = 64$.
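Before relying on Keras' built-in `MultiHeadAttention`, it can help to compute the formula once by hand. The following is a purely illustrative single-head sketch on random tensors, using the ViT-Base shapes (197 tokens, $d_k = 64$); the model itself uses the built-in layer.
import tensorflow as tf

# Minimal single-head scaled dot-product attention (illustrative only).
seq_len, d_k = 197, 64
q = tf.random.normal((1, seq_len, d_k))  # queries
k = tf.random.normal((1, seq_len, d_k))  # keys
v = tf.random.normal((1, seq_len, d_k))  # values

scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(d_k, tf.float32))
weights = tf.nn.softmax(scores, axis=-1)   # (1, 197, 197) attention matrix
output = tf.matmul(weights, v)             # (1, 197, 64)

print(weights.shape, output.shape)
print(tf.reduce_sum(weights[0, 0]).numpy())  # each row of weights sums to 1.0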
flowchart TD
A[Input Tokens
batch x 197 x D] --> B[LayerNorm]
B --> C[Multi-Head
Self-Attention]
C --> D[+ Residual]
A --> D
D --> E[LayerNorm]
E --> F[MLP: Dense 4D
GELU
Dense D]
F --> G[+ Residual]
D --> G
G --> H[Output Tokens
batch x 197 x D]
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class TransformerBlock(layers.Layer):
"""Single Transformer encoder block for ViT.
Architecture:
x -> LayerNorm -> MHSA -> + residual
-> LayerNorm -> MLP -> + residual
Args:
embed_dim: model dimension D (e.g., 768)
num_heads: number of attention heads (e.g., 12)
mlp_ratio: MLP hidden dim expansion factor (default 4)
dropout_rate: dropout probability (default 0.1)
"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4,
dropout_rate=0.1, **kwargs):
super().__init__(**kwargs)
self.norm1 = layers.LayerNormalization(epsilon=1e-6)
self.attn = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=embed_dim // num_heads,
dropout=dropout_rate
)
self.norm2 = layers.LayerNormalization(epsilon=1e-6)
mlp_hidden = embed_dim * mlp_ratio
self.mlp = keras.Sequential([
layers.Dense(mlp_hidden, activation="gelu"),
layers.Dropout(dropout_rate),
layers.Dense(embed_dim),
layers.Dropout(dropout_rate),
])
def call(self, x, training=False):
"""
Args:
x: (batch, seq_len, embed_dim) input tokens
training: whether in training mode
Returns:
(batch, seq_len, embed_dim) transformed tokens
"""
# MHSA with pre-norm and residual
x_norm = self.norm1(x)
attn_output = self.attn(
query=x_norm, key=x_norm, value=x_norm,
training=training
)
x = x + attn_output
# MLP with pre-norm and residual
x_norm = self.norm2(x)
mlp_output = self.mlp(x_norm, training=training)
x = x + mlp_output
return x
# Test TransformerBlock
block = TransformerBlock(embed_dim=768, num_heads=12, mlp_ratio=4)
# Input: batch of 4, sequence of 197 tokens, 768 dims
dummy_input = tf.random.normal((4, 197, 768))
output = block(dummy_input, training=False)
print(f"Input shape: {dummy_input.shape}") # (4, 197, 768)
print(f"Output shape: {output.shape}") # (4, 197, 768)
print(f"\nBlock parameters: {block.count_params():,}")
print(f" MHSA: Q,K,V projections + output = ~2.4M")
print(f" MLP: 768->3072->768 = ~4.7M")
print(f" Norms: ~3K")
Building ViT from Scratch
Now we assemble all components into the complete Vision Transformer: PatchEmbedding, CLSToken + PositionEmbedding, a stack of TransformerBlocks, final LayerNorm, and a classification head that reads the [CLS] token output.
Model Variants
ViT Configuration Variants
- ViT-Tiny: 6 layers, 256 hidden, 4 heads, ~5.7M params
- ViT-Small: 12 layers, 384 hidden, 6 heads, ~22M params
- ViT-Base: 12 layers, 768 hidden, 12 heads, ~86M params
- ViT-Large: 24 layers, 1024 hidden, 16 heads, ~307M params
- ViT-Huge: 32 layers, 1280 hidden, 16 heads, ~632M params
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class PatchEmbedding(layers.Layer):
"""Patch embedding via Conv2D projection."""
def __init__(self, patch_size, embed_dim, **kwargs):
super().__init__(**kwargs)
self.proj = layers.Conv2D(
embed_dim, kernel_size=patch_size,
strides=patch_size, padding="valid"
)
def call(self, x):
x = self.proj(x)
b = tf.shape(x)[0]
x = tf.reshape(x, (b, -1, x.shape[-1]))
return x
class VisionTransformer(keras.Model):
"""Complete Vision Transformer (ViT) for image classification.
Args:
image_size: input image size (square)
patch_size: size of each patch
num_classes: number of output classes
embed_dim: transformer hidden dimension D
depth: number of transformer blocks
num_heads: number of attention heads
mlp_ratio: MLP expansion factor
dropout_rate: dropout probability
"""
def __init__(self, image_size=224, patch_size=16, num_classes=10,
embed_dim=768, depth=12, num_heads=12,
mlp_ratio=4, dropout_rate=0.1, **kwargs):
super().__init__(**kwargs)
self.num_patches = (image_size // patch_size) ** 2 # 196
self.embed_dim = embed_dim
# 1. Patch embedding
self.patch_embed = PatchEmbedding(patch_size, embed_dim)
# 2. [CLS] token
self.cls_token = self.add_weight(
"cls_token", shape=(1, 1, embed_dim),
initializer="zeros", trainable=True
)
# 3. Position embeddings (197 = 196 patches + 1 CLS)
self.pos_embed = self.add_weight(
"pos_embed", shape=(1, self.num_patches + 1, embed_dim),
initializer="random_normal", trainable=True
)
self.pos_drop = layers.Dropout(dropout_rate)
# 4. Transformer encoder blocks
self.blocks = [
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate)
for _ in range(depth)
]
# 5. Final LayerNorm
self.norm = layers.LayerNormalization(epsilon=1e-6)
# 6. Classification head
self.head = layers.Dense(num_classes)
def call(self, images, training=False):
"""
Args:
images: (batch, H, W, 3) input images
Returns:
logits: (batch, num_classes) classification logits
"""
batch_size = tf.shape(images)[0]
# Patch embedding: (batch, 196, D)
x = self.patch_embed(images)
# Prepend [CLS]: (batch, 197, D)
cls = tf.broadcast_to(self.cls_token, (batch_size, 1, self.embed_dim))
x = tf.concat([cls, x], axis=1)
# Add position embeddings
x = x + self.pos_embed
x = self.pos_drop(x, training=training)
# Transformer encoder
for block in self.blocks:
x = block(x, training=training)
# Final norm
x = self.norm(x)
# Extract [CLS] token output (position 0)
cls_output = x[:, 0]
# Classification head
logits = self.head(cls_output)
return logits
# Build ViT-Base
vit_base = VisionTransformer(
image_size=224, patch_size=16, num_classes=1000,
embed_dim=768, depth=12, num_heads=12,
mlp_ratio=4, dropout_rate=0.1
)
# Forward pass to build the model
dummy_images = tf.random.normal((2, 224, 224, 3))
logits = vit_base(dummy_images, training=False)
print(f"ViT-Base Configuration:")
print(f" Image size: 224x224")
print(f" Patch size: 16x16")
print(f" Num patches: 196")
print(f" Embed dim: 768")
print(f" Depth: 12 blocks")
print(f" Heads: 12")
print(f" MLP ratio: 4x (768 -> 3072 -> 768)")
print(f" Num classes: 1000")
print(f"\nOutput shape: {logits.shape}") # (2, 1000)
print(f"Total parameters: {vit_base.count_params():,}")
Training ViT on CIFAR-10
The original ViT paper demonstrated that Transformers need massive datasets (JFT-300M, ImageNet-21k) to outperform CNNs. On smaller datasets like CIFAR-10 (50k training images), ViT underperforms unless you adapt and regularize it heavily: smaller patches, a shallower model, aggressive augmentation (RandAugment, Mixup, CutMix), and a cosine learning-rate schedule with warmup.
Training with Heavy Augmentation
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
def build_vit_tiny(image_size=32, patch_size=4, num_classes=10):
"""Build ViT-Tiny for CIFAR-10 (small images, small patches).
CIFAR-10 images are 32x32, so we use patch_size=4:
32/4 = 8 -> 8x8 = 64 patches per image.
Args:
image_size: input size (32 for CIFAR-10)
patch_size: patch size (4 for 32x32 images)
num_classes: number of classes (10 for CIFAR-10)
Returns:
Compiled Keras model ready for training
"""
inputs = layers.Input(shape=(image_size, image_size, 3))
# Patch embedding via Conv2D
x = layers.Conv2D(256, kernel_size=patch_size,
strides=patch_size, padding="valid")(inputs)
# x shape: (batch, 8, 8, 256)
x = layers.Reshape((64, 256))(x) # 64 patches, 256 dims
    # [CLS] token and position embeddings
    # Note: creating tf.Variables inside Lambda layers does NOT work in the
    # functional API -- Lambda layers do not track variables, so the CLS token
    # and position embeddings would never be trained or saved. Reuse the
    # CLSTokenAndPositionEmbedding layer defined earlier, which registers its
    # weights as proper layer parameters.
    num_patches = 64
    embed_dim = 256
    x = CLSTokenAndPositionEmbedding(num_patches, embed_dim)(x)  # (batch, 65, 256)
    x = layers.Dropout(0.1)(x)
# 6 Transformer blocks (ViT-Tiny)
for _ in range(6):
# Pre-norm MHSA + residual
x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
attn = layers.MultiHeadAttention(
num_heads=4, key_dim=64, dropout=0.1
)(x_norm, x_norm)
x = x + attn
# Pre-norm MLP + residual
x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
mlp = layers.Dense(1024, activation="gelu")(x_norm)
mlp = layers.Dropout(0.1)(mlp)
mlp = layers.Dense(256)(mlp)
mlp = layers.Dropout(0.1)(mlp)
x = x + mlp
# Final norm and classification from [CLS]
x = layers.LayerNormalization(epsilon=1e-6)(x)
cls_output = x[:, 0] # Extract [CLS] token
outputs = layers.Dense(num_classes)(cls_output)
model = keras.Model(inputs, outputs, name="vit_tiny_cifar10")
return model
# Build and compile
model = build_vit_tiny(image_size=32, patch_size=4, num_classes=10)
print(f"ViT-Tiny for CIFAR-10:")
print(f" Image: 32x32, Patch: 4x4, Patches: 64")
print(f" Embed dim: 256, Heads: 4, Depth: 6")
print(f" Parameters: {model.count_params():,}")
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
def get_cifar10_augmented(batch_size=128):
"""Load CIFAR-10 with heavy augmentation for ViT training.
Augmentation pipeline:
- RandAugment (magnitude 9, 2 operations)
- Random horizontal flip
- Normalization to [0, 1]
Returns:
train_ds, val_ds: tf.data.Dataset pipelines
"""
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Normalize to [0, 1]
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
y_train = y_train.squeeze()
y_test = y_test.squeeze()
# Training augmentation
data_augmentation = keras.Sequential([
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
layers.RandomZoom(0.1),
], name="augmentation")
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(10000).batch(batch_size)
train_ds = train_ds.map(
lambda x, y: (data_augmentation(x, training=True), y),
num_parallel_calls=tf.data.AUTOTUNE
)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
return train_ds, val_ds
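# Mixup (mentioned above) can be layered on top of this pipeline. This is a
# minimal sketch, not part of get_cifar10_augmented(); it assumes labels have
# been one-hot encoded so they can be blended along with the images.
def mixup(images, labels, alpha=0.2):
    """Blend each batch with a shuffled copy of itself (Mixup)."""
    batch_size = tf.shape(images)[0]
    # Sample lambda ~ Beta(alpha, alpha) via the ratio of two Gamma draws
    g1 = tf.random.gamma([], alpha)
    g2 = tf.random.gamma([], alpha)
    lam = g1 / (g1 + g2)
    idx = tf.random.shuffle(tf.range(batch_size))
    mixed_images = lam * images + (1.0 - lam) * tf.gather(images, idx)
    mixed_labels = lam * labels + (1.0 - lam) * tf.gather(labels, idx)
    return mixed_images, mixed_labels
# Example usage (after one-hot encoding the labels):
# train_ds = train_ds.map(lambda x, y: mixup(x, tf.one_hot(tf.cast(y, tf.int32), 10)))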
def cosine_warmup_schedule(total_steps, warmup_steps, peak_lr=1e-3):
"""Cosine decay with linear warmup learning rate schedule.
Args:
total_steps: total training steps
warmup_steps: steps for linear warmup
peak_lr: maximum learning rate
Returns:
tf.keras.optimizers.schedules.LearningRateSchedule
"""
warmup = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=0.0,
decay_steps=warmup_steps,
end_learning_rate=peak_lr,
power=1.0 # linear
)
cosine = tf.keras.optimizers.schedules.CosineDecay(
initial_learning_rate=peak_lr,
decay_steps=total_steps - warmup_steps,
alpha=1e-5 # minimum LR ratio
)
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, warmup_sched, cosine_sched, warmup_steps):
            super().__init__()
            self.warmup_sched = warmup_sched
            self.cosine_sched = cosine_sched
            self.warmup_steps = warmup_steps
def __call__(self, step):
return tf.cond(
step < self.warmup_steps,
lambda: self.warmup_sched(step),
lambda: self.cosine_sched(step - self.warmup_steps)
)
return WarmupCosine(warmup, cosine, warmup_steps)
# Training configuration
epochs = 200
batch_size = 128
steps_per_epoch = 50000 // batch_size # ~390
total_steps = steps_per_epoch * epochs
warmup_steps = steps_per_epoch * 10 # 10 epoch warmup
lr_schedule = cosine_warmup_schedule(total_steps, warmup_steps, peak_lr=1e-3)
print(f"Training Configuration:")
print(f" Epochs: {epochs}")
print(f" Batch size: {batch_size}")
print(f" Steps/epoch: {steps_per_epoch}")
print(f" Total steps: {total_steps:,}")
print(f" Warmup steps: {warmup_steps:,} ({10} epochs)")
print(f" Peak LR: 1e-3")
print(f" Schedule: Linear warmup + Cosine decay")
print(f" Optimizer: AdamW (weight_decay=0.05)")
print(f" Label smoothing: 0.1")
print(f"\nExpected accuracy: ~92-94% (CIFAR-10)")
print(f"Training time: ~2 hours on single V100")
Attention Visualization
One of the most compelling properties of ViT is that attention maps are directly interpretable. We can extract the attention weights from each head and layer to see which patches the model attends to when making predictions. Early layers tend to show local attention patterns (like convolutions), while deeper layers capture long-range semantic relationships.
Extracting and Visualizing Attention Maps
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
class TransformerBlockWithAttnOutput(layers.Layer):
"""Transformer block that also returns attention weights.
Same architecture as standard TransformerBlock but captures
the attention weight matrix for visualization.
"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4, **kwargs):
super().__init__(**kwargs)
self.norm1 = layers.LayerNormalization(epsilon=1e-6)
self.attn = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=embed_dim // num_heads
)
self.norm2 = layers.LayerNormalization(epsilon=1e-6)
self.mlp = keras.Sequential([
layers.Dense(embed_dim * mlp_ratio, activation="gelu"),
layers.Dense(embed_dim),
])
def call(self, x, return_attention=False):
x_norm = self.norm1(x)
if return_attention:
attn_output, attn_weights = self.attn(
x_norm, x_norm, x_norm,
return_attention_scores=True
)
else:
attn_output = self.attn(x_norm, x_norm, x_norm)
attn_weights = None
x = x + attn_output
x = x + self.mlp(self.norm2(x))
return x, attn_weights
def extract_attention_maps(model_blocks, input_tokens, num_patches_h=14):
"""Extract attention maps from all transformer blocks.
Args:
model_blocks: list of TransformerBlockWithAttnOutput layers
input_tokens: (1, seq_len, embed_dim) prepared input
num_patches_h: patches per row (14 for 224/16)
Returns:
attention_maps: list of (num_heads, seq_len, seq_len) arrays
"""
attention_maps = []
x = input_tokens
for i, block in enumerate(model_blocks):
x, attn_weights = block(x, return_attention=True)
# attn_weights shape: (batch, num_heads, seq_len, seq_len)
attention_maps.append(attn_weights[0].numpy()) # remove batch
return attention_maps
def visualize_cls_attention(attn_weights, num_patches_h=14,
layer_idx=0, head_idx=0):
"""Extract [CLS] token attention as a spatial heatmap.
The [CLS] token is at position 0. Its attention to patches
[1:197] forms a 14x14 spatial map showing what the model
'looks at' for classification.
Args:
attn_weights: (num_heads, 197, 197) attention matrix
num_patches_h: patches per row
layer_idx: which layer (for display)
head_idx: which attention head
Returns:
heatmap: (14, 14) attention heatmap
"""
# CLS attention to all patches (skip CLS-to-CLS at index 0)
cls_attn = attn_weights[head_idx, 0, 1:] # (196,)
# Reshape to spatial grid
heatmap = cls_attn.reshape(num_patches_h, num_patches_h)
# Normalize to [0, 1]
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
print(f"Layer {layer_idx}, Head {head_idx}:")
print(f" CLS attention shape: {cls_attn.shape}")
print(f" Heatmap shape: {heatmap.shape}")
print(f" Max attention: {cls_attn.max():.4f}")
print(f" Min attention: {cls_attn.min():.4f}")
print(f" Entropy: {-np.sum(cls_attn * np.log(cls_attn + 1e-8)):.4f}")
return heatmap
# Demonstrate attention extraction
num_heads = 12
seq_len = 197 # 196 patches + 1 CLS
# Simulate attention weights from different layers
np.random.seed(42)
# Early layer: more uniform/local attention
early_attn = np.random.dirichlet(np.ones(seq_len) * 5, size=(num_heads, seq_len))
# Later layer: more focused/sparse attention
late_attn = np.random.dirichlet(np.ones(seq_len) * 0.5, size=(num_heads, seq_len))
print("Attention Map Analysis:")
print("=" * 50)
print("\nEarly Layer (Layer 1) - Diffuse attention:")
heatmap_early = visualize_cls_attention(early_attn, layer_idx=1, head_idx=0)
print("\nLate Layer (Layer 11) - Focused attention:")
heatmap_late = visualize_cls_attention(late_attn, layer_idx=11, head_idx=0)
print(f"\nObservation:")
print(f" Early layers: attention spread across many patches (local patterns)")
print(f" Late layers: attention concentrated on semantically relevant regions")
- Layers 1-3: Local attention resembling convolution kernels. Nearby patches attend to each other.
- Layers 4-8: Mid-range patterns emerge. Some heads specialize in horizontal edges, others in vertical structures.
- Layers 9-12: Global semantic attention. The [CLS] token focuses on object-relevant patches, ignoring background.
Transfer Learning with Pre-trained ViT
In practice, training ViT from scratch on small datasets yields modest results. The power of ViT comes from pre-training on massive datasets (ImageNet-21k or JFT-300M) and then fine-tuning on the target task. A ViT-Base pre-trained on ImageNet-21k and fine-tuned on a flowers dataset can reach 98%+ accuracy with just a few epochs of training.
Fine-Tuning Pipeline
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
def build_finetuning_model(base_model_name="vit-base-patch16-224",
num_classes=5, image_size=224):
"""Build fine-tuning model from pre-trained ViT weights.
Strategy:
1. Load pre-trained ViT backbone (frozen initially)
2. Replace classification head for new task
3. Unfreeze top layers for fine-tuning
4. Train with low learning rate
Args:
base_model_name: pre-trained model identifier
num_classes: number of target classes
image_size: input image size
Returns:
Compiled model ready for fine-tuning
"""
# In practice, load from TensorFlow Hub or HuggingFace:
# import tensorflow_hub as hub
# base_url = "https://tfhub.dev/sayakpaul/vit_b16_fe/1"
# base_model = hub.KerasLayer(base_url, trainable=True)
# Simulated architecture for demonstration
inputs = layers.Input(shape=(image_size, image_size, 3))
# Pre-processing: normalize to [-1, 1] (ViT convention)
x = layers.Rescaling(scale=1.0/127.5, offset=-1.0)(inputs)
# Simulated ViT backbone output (768-dim CLS embedding)
# In real code: x = base_model(x)
x = layers.Conv2D(768, 16, strides=16, padding="valid")(x)
x = layers.Reshape((196, 768))(x)
x = layers.GlobalAveragePooling1D()(x) # Simplified
# New classification head
x = layers.LayerNormalization()(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(256, activation="gelu")(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(num_classes)(x)
model = keras.Model(inputs, outputs, name="vit_finetuned")
return model
# Build fine-tuning model for 5-class flowers dataset
model = build_finetuning_model(num_classes=5, image_size=224)
# Two-phase training strategy
print("Fine-Tuning Strategy:")
print("=" * 50)
print()
print("Phase 1: Head-only training (5 epochs)")
print(" - Freeze all backbone layers")
print(" - Train only new classification head")
print(" - LR: 1e-3, Optimizer: Adam")
print(" - Purpose: warm up the new head")
print()
print("Phase 2: Full fine-tuning (15 epochs)")
print(" - Unfreeze top 4 transformer blocks")
print(" - LR: 1e-5 (very low to preserve features)")
print(" - Weight decay: 0.01")
print(" - Purpose: adapt backbone to target domain")
print()
print(f"Model parameters: {model.count_params():,}")
import numpy as np
def compare_vit_vs_cnn(dataset_sizes, vit_from_scratch, vit_finetuned,
resnet50_finetuned, efficientnet_finetuned):
"""Compare ViT vs CNN accuracy across dataset sizes.
Shows the data regime where each architecture excels:
- Small data (<10k): CNNs win due to inductive bias
- Medium data (10k-100k): Close competition
- Large data (>100k): ViT matches or exceeds CNNs
Args:
dataset_sizes: list of training set sizes
vit_from_scratch: ViT accuracy at each size
vit_finetuned: Pre-trained ViT accuracy
resnet50_finetuned: ResNet-50 fine-tuned accuracy
efficientnet_finetuned: EfficientNet-B0 accuracy
Returns:
        scores: dict of accuracies for the last (largest) dataset size;
        the comparison table itself is printed to stdout
"""
print(f"{'Dataset Size':<15} {'ViT Scratch':<14} {'ViT Pretrain':<14} "
f"{'ResNet-50':<12} {'EffNet-B0':<12} {'Winner':<15}")
print("-" * 82)
for i, size in enumerate(dataset_sizes):
scores = {
"ViT Scratch": vit_from_scratch[i],
"ViT Pretrain": vit_finetuned[i],
"ResNet-50": resnet50_finetuned[i],
"EffNet-B0": efficientnet_finetuned[i]
}
winner = max(scores, key=scores.get)
size_str = f"{size:,}" if size < 1000000 else f"{size // 1000}k"
print(f"{size_str:<15} {vit_from_scratch[i]:<14.1f} "
f"{vit_finetuned[i]:<14.1f} {resnet50_finetuned[i]:<12.1f} "
f"{efficientnet_finetuned[i]:<12.1f} {winner:<15}")
return scores
# Accuracy comparison across data regimes
dataset_sizes = [1000, 5000, 10000, 50000, 100000, 1000000]
# Simulated accuracy values based on published benchmarks
vit_scratch = [52.3, 68.7, 76.2, 88.4, 92.1, 97.8]
vit_pretrained = [89.5, 93.2, 95.1, 97.3, 98.2, 99.1]
resnet50_ft = [82.1, 88.9, 91.5, 94.8, 96.0, 97.5]
efficientnet_ft = [84.7, 90.2, 92.8, 95.6, 96.8, 97.9]
print("ViT vs CNN: Accuracy by Dataset Size")
print("=" * 82)
compare_vit_vs_cnn(
dataset_sizes, vit_scratch, vit_pretrained,
resnet50_ft, efficientnet_ft
)
print("\nKey Takeaways:")
print(" 1. ViT from scratch needs >50k images to compete with CNNs")
print(" 2. Pre-trained ViT dominates at ALL dataset sizes")
print(" 3. CNNs (ResNet, EfficientNet) have better inductive bias for small data")
print(" 4. At scale (>100k), architectures converge in performance")
- ViT wins with pre-trained weights (always), with large datasets (>100k), when global context matters (scene understanding, multi-object reasoning), and for transfer learning across domains.
- CNN wins on small datasets without pre-training, for edge deployment (smaller models), when locality is the dominant signal (texture classification), and with limited compute budget for training.
- Hybrid approaches (ConvNeXt, CoAtNet) combine the best of both — CNN stems for local features plus Transformer blocks for global attention.