Table of Contents

  1. Why BERT Changed NLP
  2. BERT Architecture
  3. Pre-training Task 1: Masked Language Modeling (MLM)
  4. Pre-training Task 2: Next Sentence Prediction (NSP)
  5. Building BERT from Scratch in TensorFlow
  6. Tokenization: WordPiece
  7. Fine-tuning BERT for Classification
  8. Fine-tuning for NER and QA

Deep Dive: BERT — Bidirectional Transformers in TensorFlow

May 3, 2026 · Wasil Zafar · 40 min read

Build BERT from scratch in TensorFlow — implement masked language modeling, next sentence prediction, tokenization, and fine-tune for text classification, NER, and question answering.

1. Why BERT Changed NLP

Before BERT, language models processed text in a single direction. GPT-1 read left-to-right, capturing only preceding context. ELMo introduced bidirectionality but operated with shallow concatenation of separately trained forward and backward LSTMs. Neither approach could deeply integrate context from both directions simultaneously.

BERT’s Key Insight: By masking random tokens and training the model to predict them using both left and right context simultaneously, BERT achieves deep bidirectional representations that capture full sentence semantics.

When BERT was released in 2018, it achieved state-of-the-art results on 11 NLP benchmarks simultaneously, including SQuAD (question answering), GLUE (general language understanding), and NER tasks. This was unprecedented — a single pre-trained model dominating across diverse tasks.

Unidirectional vs Bidirectional Context

Consider the word “bank” in two sentences: “I deposited money at the bank” vs “I sat by the river bank.” When a left-to-right model reaches “bank”, it can only use the words that came before it (“I deposited money at the” or “I sat by the river”); any disambiguating context that appears after the word is invisible to it. BERT sees the entire sentence at once, so every token’s representation draws on context from both sides.

import tensorflow as tf
import numpy as np

# Demonstrate how context changes word meaning
sentences = [
    "I deposited money at the bank",
    "I sat by the river bank"
]

# In a unidirectional model (left-to-right):
# "bank" in sentence 1 sees: "I deposited money at the"
# "bank" in sentence 2 sees: "I sat by the river"

# In BERT (bidirectional):
# "bank" in sentence 1 sees: "I deposited money at the [MASK]"
#   + full surrounding context from both sides
# "bank" in sentence 2 sees: "I sat by the river [MASK]"
#   + full surrounding context from both sides

# Simple demonstration of masking
def mask_token(sentence, target_word):
    tokens = sentence.lower().split()
    masked = ["[MASK]" if t == target_word else t for t in tokens]
    return " ".join(masked)

for sent in sentences:
    masked = mask_token(sent, "bank")
    print(f"Original: {sent}")
    print(f"Masked:   {masked}")
    print(f"BERT predicts 'bank' using ALL surrounding tokens")
    print()

2. BERT Architecture

BERT uses only the encoder portion of the Transformer architecture — no decoder is needed since BERT is not generating text autoregressively. Two model sizes were released:

Model Configurations

Parameter             BERT-Base    BERT-Large
Layers (L)            12           24
Hidden Size (H)       768          1024
Attention Heads (A)   12           16
Parameters            110M         340M
Feed-Forward Size     3072         4096
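
As a rough sanity check on these numbers, the parameter counts can be estimated directly from the dimensions in the table. The helper below is an illustrative back-of-the-envelope calculation (not part of the reference implementation); it covers the embeddings, encoder blocks, and pooler and ignores the MLM/NSP heads.

def bert_param_estimate(layers, hidden, ff_size, vocab=30522, max_len=512):
    """Approximate parameter count from the architecture dimensions."""
    embeddings = (vocab + max_len + 2) * hidden        # token + position + segment
    attention = 4 * (hidden * hidden + hidden)         # Wq, Wk, Wv, Wo with biases
    ffn = hidden * ff_size + ff_size + ff_size * hidden + hidden
    layer_norms = 2 * 2 * hidden                       # two LayerNorms per block
    per_layer = attention + ffn + layer_norms
    pooler = hidden * hidden + hidden
    return embeddings + layers * per_layer + pooler

print(f"BERT-Base  ~{bert_param_estimate(12, 768, 3072) / 1e6:.0f}M parameters")
print(f"BERT-Large ~{bert_param_estimate(24, 1024, 4096) / 1e6:.0f}M parameters")
# Prints roughly 109M and 335M, close to the published 110M and 340M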

Input Representation

BERT’s input is the sum of three embeddings: token embeddings (WordPiece vocabulary), segment embeddings (distinguishing sentence A from sentence B), and position embeddings (encoding sequential position up to 512 tokens).

BERT Architecture Overview
flowchart TD
    A[Input Tokens] --> B[Token Embeddings]
    A --> C[Segment Embeddings]
    A --> D[Position Embeddings]
    B --> E[Sum of Embeddings]
    C --> E
    D --> E
    E --> F[Transformer Encoder Block 1]
    F --> G[Transformer Encoder Block 2]
    G --> H[...]
    H --> I[Transformer Encoder Block L]
    I --> J["[CLS] Output"]
    I --> K[Token Outputs]
    J --> L[NSP Head]
    K --> M[MLM Head]
                            
import tensorflow as tf
import numpy as np

class BERTEmbedding(tf.keras.layers.Layer):
    """BERT Input Embedding: Token + Segment + Position"""

    def __init__(self, vocab_size, max_len, hidden_size, **kwargs):
        super(BERTEmbedding, self).__init__(**kwargs)
        self.token_embedding = tf.keras.layers.Embedding(
            vocab_size, hidden_size, name="token_embedding"
        )
        self.segment_embedding = tf.keras.layers.Embedding(
            2, hidden_size, name="segment_embedding"
        )
        self.position_embedding = tf.keras.layers.Embedding(
            max_len, hidden_size, name="position_embedding"
        )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(0.1)

    def call(self, token_ids, segment_ids, training=False):
        seq_len = tf.shape(token_ids)[1]
        positions = tf.range(seq_len)

        # Sum three embeddings
        x = self.token_embedding(token_ids)
        x = x + self.segment_embedding(segment_ids)
        x = x + self.position_embedding(positions)

        x = self.layer_norm(x)
        x = self.dropout(x, training=training)
        return x

# Test the embedding layer
vocab_size = 30522
max_len = 512
hidden_size = 768

embedding = BERTEmbedding(vocab_size, max_len, hidden_size)

# Simulated input: batch_size=2, seq_len=10
token_ids = tf.random.uniform((2, 10), 0, vocab_size, dtype=tf.int32)
segment_ids = tf.concat([
    tf.zeros((2, 5), dtype=tf.int32),
    tf.ones((2, 5), dtype=tf.int32)
], axis=1)

output = embedding(token_ids, segment_ids)
print(f"Embedding output shape: {output.shape}")
# Expected: (2, 10, 768)

3. Pre-training Task 1: Masked Language Modeling (MLM)

MLM is BERT’s primary pre-training objective. During training, 15% of input tokens are selected for prediction. Of those selected tokens:

  • 80% are replaced with [MASK]
  • 10% are replaced with a random token from the vocabulary
  • 10% are left unchanged

This strategy prevents the model from simply learning that [MASK] is the signal for prediction, and encourages robust representations for all tokens.

import tensorflow as tf
import numpy as np

def create_mlm_data(token_ids, vocab_size, mask_token_id=103,
                    mask_prob=0.15):
    """
    Create masked language modeling training data.

    Args:
        token_ids: numpy array of shape (seq_len,) with token IDs
        vocab_size: size of the vocabulary
        mask_token_id: ID of the [MASK] token (103 in BERT)
        mask_prob: probability of selecting a token for masking

    Returns:
        masked_ids: token IDs with masking applied
        labels: original token IDs at masked positions, -100 elsewhere
    """
    labels = np.full_like(token_ids, -100)  # -100 = ignore in loss
    masked_ids = token_ids.copy()

    # Select 15% of positions for masking. As in real BERT pre-training,
    # special tokens ([CLS]=101, [SEP]=102) are never selected.
    candidates = [i for i, t in enumerate(token_ids) if t not in (101, 102)]
    num_to_mask = max(1, int(len(candidates) * mask_prob))
    mask_indices = np.random.choice(candidates, size=num_to_mask, replace=False)

    for idx in mask_indices:
        labels[idx] = token_ids[idx]  # Store original for prediction
        rand = np.random.random()

        if rand < 0.8:
            # 80%: replace with [MASK]
            masked_ids[idx] = mask_token_id
        elif rand < 0.9:
            # 10%: replace with random token
            masked_ids[idx] = np.random.randint(0, vocab_size)
        # else 10%: keep original (do nothing)

    return masked_ids, labels

# Example usage
np.random.seed(42)
original_tokens = np.array([101, 2023, 2003, 1037, 3231, 6251, 102])
# [CLS] this  is   a    test  sentence [SEP]

masked, labels = create_mlm_data(original_tokens, vocab_size=30522)

print("Original tokens:", original_tokens)
print("Masked tokens:  ", masked)
print("Labels:         ", labels)
print("(Labels -100 = not masked, other values = predict this token)")

MLM Prediction Head

import tensorflow as tf

class MLMHead(tf.keras.layers.Layer):
    """Masked Language Modeling prediction head."""

    def __init__(self, hidden_size, vocab_size, **kwargs):
        super(MLMHead, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            hidden_size, activation="gelu", name="transform_dense"
        )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.output_projection = tf.keras.layers.Dense(
            vocab_size, name="output_projection"
        )

    def call(self, encoder_output):
        # encoder_output: (batch, seq_len, hidden_size)
        x = self.dense(encoder_output)
        x = self.layer_norm(x)
        logits = self.output_projection(x)
        # logits: (batch, seq_len, vocab_size)
        return logits

# Test MLM head
hidden_size = 768
vocab_size = 30522

mlm_head = MLMHead(hidden_size, vocab_size)
dummy_encoder_out = tf.random.normal((2, 10, hidden_size))
logits = mlm_head(dummy_encoder_out)
print(f"MLM logits shape: {logits.shape}")
# Expected: (2, 10, 30522)

# Compute loss only at masked positions
labels = tf.constant([[  -100, -100, 2003, -100, -100, -100, -100, -100, -100, -100],
                      [  -100, -100, -100, -100, 3231, -100, -100, -100, -100, -100]])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
loss = loss_fn(tf.maximum(labels, 0), logits)
mask = tf.cast(labels != -100, tf.float32)
masked_loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
print(f"MLM loss (random init): {masked_loss.numpy():.4f}")

4. Pre-training Task 2: Next Sentence Prediction (NSP)

NSP trains BERT to understand relationships between sentences. Given two segments A and B, the model predicts whether B actually follows A in the original text (label: IsNext) or is a randomly sampled sentence (label: NotNext). The [CLS] token’s final hidden state is used for this binary classification.

BERT Pre-training Flow
flowchart LR
    A[Raw Text Corpus] --> B[Sentence Pairs]
    B --> C{50% Real Pairs\n50% Random}
    C -->|IsNext| D[Positive Examples]
    C -->|NotNext| E[Negative Examples]
    D --> F[Token + Segment IDs]
    E --> F
    F --> G[Apply 15% Masking]
    G --> H[BERT Encoder]
    H --> I["[CLS] -> NSP Loss"]
    H --> J["Masked Tokens -> MLM Loss"]
    I --> K[Total Loss = MLM + NSP]
    J --> K
                            
import tensorflow as tf
import numpy as np

def create_nsp_data(sentences, num_examples=100):
    """
    Generate Next Sentence Prediction training pairs.

    Args:
        sentences: list of sentence strings
        num_examples: number of training pairs to generate

    Returns:
        pairs: list of (sentence_a, sentence_b, is_next) tuples
    """
    pairs = []
    num_sentences = len(sentences)

    for _ in range(num_examples):
        idx_a = np.random.randint(0, num_sentences - 1)

        if np.random.random() < 0.5:
            # 50% chance: B actually follows A
            idx_b = idx_a + 1
            is_next = 1
        else:
            # 50% chance: B is a random sentence (not A itself and
            # not the sentence that actually follows A)
            idx_b = np.random.randint(0, num_sentences)
            while idx_b in (idx_a, idx_a + 1):
                idx_b = np.random.randint(0, num_sentences)
            is_next = 0

        pairs.append((sentences[idx_a], sentences[idx_b], is_next))

    return pairs

# Example corpus
corpus = [
    "The cat sat on the mat.",
    "It was a warm sunny day.",
    "Birds were singing in the trees.",
    "The dog chased the ball.",
    "Rain started to fall heavily.",
    "She opened her umbrella quickly.",
]

np.random.seed(42)
nsp_pairs = create_nsp_data(corpus, num_examples=6)

print("NSP Training Pairs:")
print("-" * 60)
for sent_a, sent_b, label in nsp_pairs:
    label_str = "IsNext" if label == 1 else "NotNext"
    print(f"A: {sent_a}")
    print(f"B: {sent_b}")
    print(f"Label: {label_str}")
    print()

NSP Classification Head

import tensorflow as tf

class NSPHead(tf.keras.layers.Layer):
    """Next Sentence Prediction classification head."""

    def __init__(self, hidden_size, **kwargs):
        super(NSPHead, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            hidden_size, activation="tanh", name="pooler_dense"
        )
        self.classifier = tf.keras.layers.Dense(
            2, name="nsp_classifier"
        )

    def call(self, encoder_output):
        # Use [CLS] token output (first token)
        cls_output = encoder_output[:, 0, :]  # (batch, hidden_size)
        pooled = self.dense(cls_output)
        logits = self.classifier(pooled)  # (batch, 2)
        return logits

# Test NSP head
hidden_size = 768
nsp_head = NSPHead(hidden_size)
dummy_encoder_out = tf.random.normal((4, 10, hidden_size))
nsp_logits = nsp_head(dummy_encoder_out)
print(f"NSP logits shape: {nsp_logits.shape}")
# Expected: (4, 2) -- binary classification per sample

# Compute NSP loss
nsp_labels = tf.constant([1, 0, 1, 0])  # IsNext, NotNext, ...
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
nsp_loss = loss_fn(nsp_labels, nsp_logits)
print(f"NSP loss (random init): {nsp_loss.numpy():.4f}")

5. Building BERT from Scratch in TensorFlow

Now we combine all components into a complete BERT implementation. The core building blocks are: Multi-Head Attention, Feed-Forward Network, and the Transformer Encoder Block (attention + FFN + LayerNorm + residual connections).

The scaled dot-product attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where $Q$, $K$, $V$ are query, key, and value matrices, and $d_k$ is the dimension of the keys.

import tensorflow as tf
import numpy as np

class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-Head Self-Attention mechanism."""

    def __init__(self, hidden_size, num_heads, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size

        self.wq = tf.keras.layers.Dense(hidden_size)
        self.wk = tf.keras.layers.Dense(hidden_size)
        self.wv = tf.keras.layers.Dense(hidden_size)
        self.wo = tf.keras.layers.Dense(hidden_size)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x, mask=None):
        batch_size = tf.shape(x)[0]

        q = self.split_heads(self.wq(x), batch_size)
        k = self.split_heads(self.wk(x), batch_size)
        v = self.split_heads(self.wv(x), batch_size)

        # Scaled dot-product attention
        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        scores = tf.matmul(q, k, transpose_b=True) / scale

        if mask is not None:
            scores += (mask * -1e9)

        weights = tf.nn.softmax(scores, axis=-1)
        attention = tf.matmul(weights, v)

        # Concatenate heads
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        attention = tf.reshape(attention, (batch_size, -1, self.hidden_size))

        return self.wo(attention)


class TransformerEncoderBlock(tf.keras.layers.Layer):
    """Single Transformer Encoder Block."""

    def __init__(self, hidden_size, num_heads, ff_size, dropout=0.1, **kwargs):
        super(TransformerEncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(hidden_size, num_heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_size, activation="gelu"),
            tf.keras.layers.Dense(hidden_size)
        ])
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout)
        self.dropout2 = tf.keras.layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        # Multi-head attention + residual + norm
        attn_out = self.attention(x, mask=mask)
        attn_out = self.dropout1(attn_out, training=training)
        x = self.norm1(x + attn_out)

        # Feed-forward + residual + norm
        ffn_out = self.ffn(x)
        ffn_out = self.dropout2(ffn_out, training=training)
        x = self.norm2(x + ffn_out)
        return x


class BERTModel(tf.keras.Model):
    """Complete BERT Model with pre-training heads."""

    def __init__(self, vocab_size=30522, max_len=512, hidden_size=768,
                 num_heads=12, num_layers=12, ff_size=3072, **kwargs):
        super(BERTModel, self).__init__(**kwargs)
        self.embedding = BERTEmbedding(vocab_size, max_len, hidden_size)
        self.encoder_blocks = [
            TransformerEncoderBlock(hidden_size, num_heads, ff_size)
            for _ in range(num_layers)
        ]
        self.mlm_head = MLMHead(hidden_size, vocab_size)
        self.nsp_head = NSPHead(hidden_size)

    def call(self, token_ids, segment_ids, mask=None, training=False):
        x = self.embedding(token_ids, segment_ids, training=training)

        for block in self.encoder_blocks:
            x = block(x, mask=mask, training=training)

        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(x)
        return mlm_logits, nsp_logits

# Instantiate BERT-Base (use fewer layers for demo)
bert = BERTModel(
    vocab_size=30522,
    max_len=512,
    hidden_size=768,
    num_heads=12,
    num_layers=2,  # Use 2 layers for quick demo (12 for full)
    ff_size=3072
)

# Forward pass
token_ids = tf.random.uniform((2, 32), 0, 30522, dtype=tf.int32)
segment_ids = tf.zeros((2, 32), dtype=tf.int32)

mlm_out, nsp_out = bert(token_ids, segment_ids)
print(f"MLM output shape: {mlm_out.shape}")  # (2, 32, 30522)
print(f"NSP output shape: {nsp_out.shape}")  # (2, 2)

# Count parameters
total_params = sum(
    tf.reduce_prod(v.shape).numpy() for v in bert.trainable_variables
)
print(f"Total trainable parameters: {total_params:,}")
Parameter Count Note: The 2-layer demo model has significantly fewer parameters than BERT-Base (110M). Set num_layers=12 for the full BERT-Base configuration. Training the full model requires large-scale compute (16 TPU chips, 4 days in the original paper).
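
The pre-training flow in Section 4 sums the MLM and NSP losses into a single objective. The demo above stops at the forward pass, so here is a minimal sketch of one joint training step; the helper name pretrain_step, the Adam learning rate, and the dummy batch are illustrative assumptions, and the padding-mask conversion follows the additive-mask convention used in MultiHeadAttention above.

import tensorflow as tf

mlm_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
nsp_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
pretrain_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

@tf.function
def pretrain_step(token_ids, segment_ids, input_mask, mlm_labels, nsp_labels):
    # Convert the 1/0 padding mask into additive form: 1.0 marks positions
    # to suppress, broadcast to (batch, 1, 1, seq_len)
    attn_mask = (1.0 - tf.cast(input_mask, tf.float32))[:, tf.newaxis, tf.newaxis, :]

    with tf.GradientTape() as tape:
        mlm_logits, nsp_logits = bert(
            token_ids, segment_ids, mask=attn_mask, training=True
        )

        # MLM loss only at masked positions (labels == -100 are ignored)
        per_token = mlm_loss_fn(tf.maximum(mlm_labels, 0), mlm_logits)
        masked = tf.cast(mlm_labels != -100, tf.float32)
        mlm_loss = tf.reduce_sum(per_token * masked) / tf.maximum(
            tf.reduce_sum(masked), 1.0
        )

        # NSP loss from the [CLS] head
        nsp_loss = nsp_loss_fn(nsp_labels, nsp_logits)
        total_loss = mlm_loss + nsp_loss  # Total Loss = MLM + NSP

    grads = tape.gradient(total_loss, bert.trainable_variables)
    pretrain_optimizer.apply_gradients(zip(grads, bert.trainable_variables))
    return total_loss

# Exercise the step on a dummy batch
tok = tf.random.uniform((2, 32), 0, 30522, dtype=tf.int32)
seg = tf.zeros((2, 32), dtype=tf.int32)
pad_mask = tf.ones((2, 32), dtype=tf.int32)
mlm_lbl = tf.fill((2, 32), -100)
mlm_lbl = tf.tensor_scatter_nd_update(mlm_lbl, [[0, 3], [1, 7]], [2023, 3231])
nsp_lbl = tf.constant([1, 0])

loss = pretrain_step(tok, seg, pad_mask, mlm_lbl, nsp_lbl)
print(f"Total pre-training loss (random init): {loss.numpy():.4f}")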

6. Tokenization: WordPiece

BERT uses WordPiece tokenization with a vocabulary of approximately 30,000 tokens. WordPiece handles out-of-vocabulary (OOV) words by splitting them into known subword units. Subword continuations are prefixed with ##.

For example: “playing” → “play” + “##ing”, “unhappiness” → “un” + “##hap” + “##pi” + “##ness”

import tensorflow as tf
import numpy as np

class SimpleWordPieceTokenizer:
    """
    Simplified WordPiece tokenizer for demonstration.
    In production, use tensorflow_text or HuggingFace tokenizers.
    """

    def __init__(self, vocab):
        self.vocab = set(vocab)
        self.special_tokens = {
            "[PAD]": 0, "[UNK]": 1, "[CLS]": 2,
            "[SEP]": 3, "[MASK]": 4
        }

    def tokenize(self, text):
        """Tokenize text into WordPiece tokens."""
        tokens = []
        for word in text.lower().split():
            if word in self.vocab:
                tokens.append(word)
            else:
                # Try to split into subwords
                subwords = self._wordpiece_split(word)
                tokens.extend(subwords)
        return tokens

    def _wordpiece_split(self, word):
        """Split a word into WordPiece subwords."""
        subwords = []
        start = 0
        while start < len(word):
            end = len(word)
            found = False
            while start < end:
                substr = word[start:end]
                if start > 0:
                    substr = "##" + substr
                if substr in self.vocab:
                    subwords.append(substr)
                    found = True
                    break
                end -= 1
            if not found:
                subwords.append("[UNK]")
                start += 1
            else:
                start = end
        return subwords

# Example vocabulary (subset)
vocab = {
    "the", "a", "is", "was", "in", "on", "at",
    "play", "##ing", "##ed", "##er", "##s",
    "un", "##hap", "##pi", "##ness",
    "transform", "##er", "deep", "learn",
    "model", "train", "##able", "pre",
    "##train", "bi", "##direction", "##al",
    "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"
}

tokenizer = SimpleWordPieceTokenizer(vocab)

# Tokenization examples
test_words = ["playing", "transformer", "pretrain", "bidirectional"]
print("WordPiece Tokenization Examples:")
print("-" * 40)
for word in test_words:
    tokens = tokenizer.tokenize(word)
    print(f"  {word:20s} -> {tokens}")

Using TensorFlow Official Tokenizer

import tensorflow as tf

# Using tf-models-official BERT tokenizer (production approach)
# pip install tf-models-official tensorflow-text

# For demonstration, here is how you would use it:
"""
import tensorflow_text as text
from official.nlp.tools import tokenization

# Load pre-trained vocabulary
vocab_file = "path/to/bert/vocab.txt"
tokenizer = tokenization.FullTokenizer(
    vocab_file=vocab_file, do_lower_case=True
)

# Tokenize
sentence = "BERT uses WordPiece tokenization"
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f"Tokens: {tokens}")
print(f"IDs:    {token_ids}")
"""

# Alternative: using TF Hub preprocessing
# This is the recommended approach for fine-tuning
print("Production tokenization with TF Hub:")
print("  import tensorflow_hub as hub")
print("  preprocess = hub.load(")
print("    'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'")
print("  )")
print("  text_input = tf.constant(['Hello TensorFlow!'])")
print("  encoder_inputs = preprocess(text_input)")
print()
print("Keys in encoder_inputs:")
print("  - input_word_ids: token IDs (int32)")
print("  - input_mask: attention mask (int32)")
print("  - input_type_ids: segment IDs (int32)")

7. Fine-tuning BERT for Classification

Fine-tuning BERT for text classification involves adding a simple Dense layer on top of the [CLS] token output. The standard approach: load pre-trained weights, add a task-specific head, and fine-tune all layers with a small learning rate.

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

# Load pre-trained BERT from TF Hub
def build_bert_classifier(num_classes=2, trainable_bert=True):
    """Build BERT classifier using TF Hub."""

    # Input layers
    text_input = tf.keras.layers.Input(
        shape=(), dtype=tf.string, name="text_input"
    )

    # BERT preprocessing
    preprocess_url = (
        "https://tfhub.dev/tensorflow/"
        "bert_en_uncased_preprocess/3"
    )
    preprocessor = hub.KerasLayer(preprocess_url, name="preprocess")

    # BERT encoder
    encoder_url = (
        "https://tfhub.dev/tensorflow/"
        "bert_en_uncased_L-12_H-768_A-12/4"
    )
    encoder = hub.KerasLayer(
        encoder_url, trainable=trainable_bert, name="bert_encoder"
    )

    # Forward pass
    encoder_inputs = preprocessor(text_input)
    outputs = encoder(encoder_inputs)
    pooled_output = outputs["pooled_output"]  # [CLS] representation

    # Classification head
    x = tf.keras.layers.Dropout(0.1)(pooled_output)
    x = tf.keras.layers.Dense(
        num_classes, activation="softmax", name="classifier"
    )(x)

    model = tf.keras.Model(inputs=text_input, outputs=x)
    return model

# Build model
model = build_bert_classifier(num_classes=2)
print("Model built successfully!")
print(f"Total layers: {len(model.layers)}")

# Learning rate schedule with warmup
class WarmupSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by linear decay."""

    def __init__(self, initial_lr, warmup_steps, total_steps):
        super().__init__()
        self.initial_lr = initial_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)

        warmup = self.initial_lr * (step / warmup_steps)
        decay = self.initial_lr * (
            1.0 - (step - warmup_steps) / (total_steps - warmup_steps)
        )
        # Clamp so the learning rate never goes negative past total_steps
        return tf.maximum(tf.where(step < warmup_steps, warmup, decay), 0.0)

# Compile with warmup schedule
total_steps = 3000  # ~3 epochs on IMDB
warmup_steps = 300  # 10% warmup

lr_schedule = WarmupSchedule(
    initial_lr=2e-5,
    warmup_steps=warmup_steps,
    total_steps=total_steps
)

optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule, epsilon=1e-8
)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

print("\nModel compiled with warmup schedule")
print(f"  Initial LR: 2e-5")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Total steps: {total_steps}")
Expected Results: Fine-tuning BERT-Base on IMDB sentiment classification typically achieves 93-94% accuracy within 3 epochs, compared to ~88% for simpler models like bidirectional LSTMs without pre-training.

8. Fine-tuning for NER and QA

Beyond sequence classification, BERT excels at token-level tasks. For Named Entity Recognition (NER), each token’s output representation is classified independently. For extractive Question Answering (QA), the model predicts the start and end positions of the answer span within the context passage.

Token Classification (NER)

import tensorflow as tf
import numpy as np

def build_bert_ner_model(num_labels=9, max_len=128):
    """
    Build BERT model for Named Entity Recognition.
    Uses all token outputs (not just [CLS]).

    Labels follow BIO scheme:
      O, B-PER, I-PER, B-ORG, I-ORG,
      B-LOC, I-LOC, B-MISC, I-MISC
    """
    # Input layers
    input_ids = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="input_ids"
    )
    attention_mask = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="attention_mask"
    )
    token_type_ids = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="token_type_ids"
    )

    # Simulated BERT encoder output (in practice, load the encoder from
    # TF Hub as in the classification example and feed it all three inputs)
    # Shape: (batch_size, max_len, hidden_size=768)
    embedding = tf.keras.layers.Embedding(30522, 768, name="simulated_bert")
    sequence_output = embedding(input_ids)

    # NER classification head - applied to ALL tokens
    x = tf.keras.layers.Dropout(0.1)(sequence_output)
    logits = tf.keras.layers.Dense(
        num_labels, name="ner_classifier"
    )(x)
    # logits shape: (batch_size, max_len, num_labels)

    model = tf.keras.Model(
        inputs=[input_ids, attention_mask, token_type_ids],
        outputs=logits
    )
    return model

# Build NER model
ner_model = build_bert_ner_model(num_labels=9)

# Example: tag a sentence
# "John lives in New York"
# Tokens: [CLS] john lives in new york [SEP] [PAD]...
sample_ids = np.zeros((1, 128), dtype=np.int32)
sample_ids[0, :7] = [101, 2198, 3268, 1999, 2047, 2259, 102]
sample_mask = np.zeros((1, 128), dtype=np.int32)
sample_mask[0, :7] = 1
sample_segments = np.zeros((1, 128), dtype=np.int32)

logits = ner_model.predict(
    [sample_ids, sample_mask, sample_segments], verbose=0
)
predictions = np.argmax(logits[0, :7], axis=-1)

labels_map = {
    0: "O", 1: "B-PER", 2: "I-PER", 3: "B-ORG",
    4: "I-ORG", 5: "B-LOC", 6: "I-LOC", 7: "B-MISC", 8: "I-MISC"
}
tokens = ["[CLS]", "john", "lives", "in", "new", "york", "[SEP]"]

print("NER Predictions (random init - will improve after training):")
print("-" * 40)
for token, pred in zip(tokens, predictions):
    print(f"  {token:10s} -> {labels_map[pred]}")

Extractive Question Answering

import tensorflow as tf
import numpy as np

def build_bert_qa_model(max_len=384):
    """
    Build BERT model for extractive Question Answering.
    Predicts start and end positions of the answer span.
    """
    input_ids = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="input_ids"
    )
    attention_mask = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="attention_mask"
    )
    token_type_ids = tf.keras.layers.Input(
        shape=(max_len,), dtype=tf.int32, name="token_type_ids"
    )

    # Simulated BERT sequence output
    embedding = tf.keras.layers.Embedding(30522, 768)
    sequence_output = embedding(input_ids)

    # QA head: predict start and end positions
    # Project to 2 outputs per token (start score, end score)
    qa_outputs = tf.keras.layers.Dense(
        2, name="qa_outputs"
    )(sequence_output)

    start_logits = qa_outputs[:, :, 0]  # (batch, seq_len)
    end_logits = qa_outputs[:, :, 1]    # (batch, seq_len)

    model = tf.keras.Model(
        inputs=[input_ids, attention_mask, token_type_ids],
        outputs=[start_logits, end_logits]
    )
    return model

# Build QA model
qa_model = build_bert_qa_model(max_len=384)

# Example: Question + Context
# Q: "Where does John live?"
# C: "John lives in New York City."
# Expected answer span: "New York City" (token positions 9-11 in the packed sequence)

sample_ids = np.zeros((1, 384), dtype=np.int32)
# [CLS] where does john live [SEP] john lives in new york city [SEP]
sample_ids[0, :13] = [101, 2073, 2515, 2198, 2444, 102,
                       2198, 3268, 1999, 2047, 2259, 2103, 102]
sample_mask = np.zeros((1, 384), dtype=np.int32)
sample_mask[0, :13] = 1
sample_segments = np.zeros((1, 384), dtype=np.int32)
sample_segments[0, 6:13] = 1  # Context is segment B

start_logits, end_logits = qa_model.predict(
    [sample_ids, sample_mask, sample_segments], verbose=0
)

# Get predicted span
start_pos = np.argmax(start_logits[0, :13])
end_pos = np.argmax(end_logits[0, :13])

context_tokens = ["[CLS]", "where", "does", "john", "live", "[SEP]",
                  "john", "lives", "in", "new", "york", "city", "[SEP]"]
print(f"Predicted start position: {start_pos}")
print(f"Predicted end position: {end_pos}")
print(f"Answer span: {' '.join(context_tokens[start_pos:end_pos+1])}")
print("(Random init - answers will be correct after fine-tuning on SQuAD)")
Fine-tuning Best Practices Checklist (a configuration sketch follows the list):
  • Use learning rate 2e-5 to 5e-5 (lower than pre-training)
  • Apply linear warmup for 10% of total steps
  • Train for 2-4 epochs (more risks overfitting on small datasets)
  • Use batch size 16 or 32 (gradient accumulation if memory-limited)
  • Set max sequence length to 128 for classification, 384 for QA
  • Freeze BERT layers initially on very small datasets (<1000 examples)
  • Use weight decay 0.01 (AdamW optimizer)
  • Monitor validation loss for early stopping
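
The checklist translates into only a few lines of configuration. The sketch below is illustrative rather than canonical: it assumes TF 2.11+ (where tf.keras.optimizers.AdamW is available) and reuses the model and lr_schedule objects from Section 7.

import tensorflow as tf

# AdamW with weight decay 0.01; LayerNorm and bias parameters are
# conventionally excluded from decay
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=lr_schedule, weight_decay=0.01, epsilon=1e-8
)
optimizer.exclude_from_weight_decay(
    var_names=["LayerNorm", "layer_norm", "bias"]
)

# On very small datasets (<1000 examples), freeze the encoder first and
# train only the classification head, then unfreeze for full fine-tuning
for layer in model.layers:
    if layer.name == "bert_encoder":
        layer.trainable = False

model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Early stopping on validation loss keeps training within 2-4 epochs
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=1, restore_best_weights=True
)
# model.fit(train_ds, validation_data=val_ds, epochs=4, callbacks=[early_stop])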