
Part 3: Training & Optimization

May 3, 2026 Wasil Zafar 30 min read

Master the complete TensorFlow training pipeline — from choosing the right loss function and optimizer, to fine-tuning learning rates, gradient clipping, and mixed precision training for maximum performance.

Table of Contents

  1. The Training Pipeline
  2. Loss Functions
  3. Optimizers
  4. Learning Rate Schedules
  5. Metrics
  6. compile() Deep Dive
  7. Batch Size & Gradient Accumulation
  8. Gradient Clipping
  9. Mixed Precision Training
  10. Complete Example

The Training Pipeline

Every TensorFlow model follows a four-step lifecycle: compile the model with an optimizer and loss function, fit it to training data, evaluate on held-out data, and predict on new inputs. Understanding each step — and the knobs you can turn — is the key to training models that converge quickly and generalize well.

Core Workflow: model.compile() configures the training strategy, model.fit() runs gradient descent, model.evaluate() measures test performance, and model.predict() produces inference outputs. Everything in this article revolves around making these four steps as effective as possible.

Pipeline Workflow

The diagram below shows how the compile → fit → evaluate → predict pipeline connects to the optimization components we'll cover in this article — loss functions, optimizers, learning rate schedules, metrics, and advanced techniques like gradient clipping and mixed precision.

TensorFlow Training Pipeline
flowchart TD
    A[Define Model] --> B[model.compile]
    B --> C[model.fit]
    C --> D[model.evaluate]
    D --> E[model.predict]

    B --> B1[Loss Function]
    B --> B2[Optimizer]
    B --> B3[Metrics]

    B2 --> B2a[Learning Rate Schedule]
    B2 --> B2b[Gradient Clipping]

    C --> C1[Training Loop]
    C1 --> C2[Forward Pass]
    C2 --> C3[Compute Loss]
    C3 --> C4[Backward Pass]
    C4 --> C5[Update Weights]
    C5 --> C1

    style A fill:#132440,stroke:#3B9797,color:#fff
    style B fill:#16476A,stroke:#3B9797,color:#fff
    style C fill:#3B9797,stroke:#132440,color:#fff
    style D fill:#16476A,stroke:#3B9797,color:#fff
    style E fill:#132440,stroke:#3B9797,color:#fff

Let's start with the minimal compile-fit-evaluate-predict workflow and then progressively unlock more powerful training options.

import tensorflow as tf
import numpy as np

# Generate synthetic regression data
np.random.seed(42)
X_train = np.random.randn(1000, 10).astype(np.float32)
y_train = (X_train @ np.random.randn(10, 1) + 0.5).astype(np.float32)
X_test = np.random.randn(200, 10).astype(np.float32)
y_test = (X_test @ np.random.randn(10, 1) + 0.5).astype(np.float32)

# Build a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1)
])

# 1. Compile — configure optimizer, loss, metrics
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae']
)

# 2. Fit — train on data
history = model.fit(X_train, y_train, epochs=10, batch_size=32,
                    validation_split=0.2, verbose=1)

# 3. Evaluate — test performance
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss (MSE): {test_loss:.4f}")
print(f"Test MAE: {test_mae:.4f}")

# 4. Predict — inference
predictions = model.predict(X_test[:5])
print("Predictions shape:", predictions.shape)
print("First 5 predictions:", predictions.flatten())

The history object returned by fit() contains loss and metric values for every epoch, which you can plot to diagnose training behaviour — whether the model is converging, overfitting, or underfitting.
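
As a quick sketch of what that looks like in practice, the snippet below trains a throwaway model for a few epochs and inspects the history.history dictionary (the data and layer sizes here are arbitrary):

```python
import tensorflow as tf
import numpy as np

# Throwaway run just to obtain a History object (sizes are arbitrary)
np.random.seed(0)
X = np.random.randn(200, 4).astype(np.float32)
y = (X @ np.random.randn(4, 1)).astype(np.float32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
history = model.fit(X, y, epochs=3, batch_size=32,
                    validation_split=0.2, verbose=0)

# history.history maps each tracked quantity to one value per epoch
print(sorted(history.history.keys()))
print(len(history.history['loss']))  # one entry per epoch

# A widening train/validation gap is the classic overfitting signature
for epoch, (tr, va) in enumerate(zip(history.history['loss'],
                                     history.history['val_loss']), start=1):
    print(f"epoch {epoch}: train={tr:.4f}  val={va:.4f}")
```

Plotting history.history['loss'] against history.history['val_loss'] is the usual next step: a validation curve that flattens or rises while the training curve keeps falling signals overfitting.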

Loss Functions

The loss function measures how far the model's predictions are from the true labels. Choosing the right loss is critical — it defines what the model optimizes for. TensorFlow provides loss functions for regression, binary classification, and multi-class classification, plus the ability to write your own.

The general cross-entropy loss for multi-class classification is:

$$\mathcal{L} = -\sum_{i} y_i \log(\hat{y}_i)$$

where $y_i$ is the true label (one-hot encoded) and $\hat{y}_i$ is the predicted probability for class $i$.
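
To make the formula concrete, here is a hand computation for a single sample checked against Keras (the probabilities are made up for illustration):

```python
import tensorflow as tf
import numpy as np

# One sample, three classes; the true class is index 1
y_true = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)
y_pred = np.array([[0.1, 0.7, 0.2]], dtype=np.float32)

# By the formula, only the true-class term survives: L = -log(0.7)
manual = -np.sum(y_true * np.log(y_pred))

cce = tf.keras.losses.CategoricalCrossentropy()
keras_loss = cce(y_true, y_pred).numpy()

print(f"manual: {manual:.4f}")  # -log(0.7) ≈ 0.3567
print(f"keras:  {keras_loss:.4f}")
```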

Regression Losses

Mean Squared Error (MSE) penalizes large errors quadratically, making it sensitive to outliers. Mean Absolute Error (MAE) is more robust to outliers. Huber loss combines both — it's quadratic for small errors and linear for large ones.

import tensorflow as tf
import numpy as np

# Sample data
y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([1.1, 2.3, 2.8, 4.5, 4.7])

# Mean Squared Error
mse = tf.keras.losses.MeanSquaredError()
print(f"MSE:   {mse(y_true, y_pred).numpy():.4f}")

# Mean Absolute Error
mae = tf.keras.losses.MeanAbsoluteError()
print(f"MAE:   {mae(y_true, y_pred).numpy():.4f}")

# Huber Loss (delta controls MSE-MAE transition)
huber = tf.keras.losses.Huber(delta=1.0)
print(f"Huber: {huber(y_true, y_pred).numpy():.4f}")

# Huber with smaller delta — more robust to outliers
huber_small = tf.keras.losses.Huber(delta=0.5)
print(f"Huber (δ=0.5): {huber_small(y_true, y_pred).numpy():.4f}")

Classification Losses

Binary Cross-Entropy (BCE) is for two-class problems. Categorical Cross-Entropy (CCE) is for multi-class with one-hot labels. Sparse Categorical Cross-Entropy is the same as CCE but takes integer labels directly, avoiding the memory cost of one-hot encoding.

import tensorflow as tf
import numpy as np

# Binary classification
y_true_binary = np.array([1, 0, 1, 1, 0], dtype=np.float32)
y_pred_binary = np.array([0.9, 0.1, 0.8, 0.7, 0.3], dtype=np.float32)

bce = tf.keras.losses.BinaryCrossentropy()
print(f"BCE: {bce(y_true_binary, y_pred_binary).numpy():.4f}")

# Multi-class with one-hot labels
y_true_onehot = np.array([[1,0,0], [0,1,0], [0,0,1]], dtype=np.float32)
y_pred_probs  = np.array([[0.9,0.05,0.05], [0.1,0.8,0.1], [0.2,0.3,0.5]], dtype=np.float32)

cce = tf.keras.losses.CategoricalCrossentropy()
print(f"CCE: {cce(y_true_onehot, y_pred_probs).numpy():.4f}")

# Multi-class with integer labels (more memory-efficient)
y_true_sparse = np.array([0, 1, 2])
scce = tf.keras.losses.SparseCategoricalCrossentropy()
print(f"Sparse CCE: {scce(y_true_sparse, y_pred_probs).numpy():.4f}")

Guide: Loss Function Selection

Regression: Use MSE for clean data, MAE for outlier-prone data, Huber for a balanced approach.

Binary classification: Use BinaryCrossentropy with a sigmoid output.

Multi-class: Use SparseCategoricalCrossentropy with integer labels (saves memory), or CategoricalCrossentropy with one-hot labels. Use from_logits=True when the output layer has no softmax activation for better numerical stability.
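
The from_logits point is worth a quick sketch. The two calls below are mathematically equivalent, but the first lets TensorFlow fuse the softmax into the loss for better numerical stability (the logits here are arbitrary):

```python
import tensorflow as tf
import numpy as np

logits = np.array([[2.0, 1.0, 0.1]], dtype=np.float32)  # raw scores, no softmax
labels = np.array([0])

# Option A: model outputs logits; tell the loss to apply softmax internally
loss_a = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)

# Option B: apply softmax yourself, then use the default from_logits=False
probs = tf.nn.softmax(logits)
loss_b = tf.keras.losses.SparseCategoricalCrossentropy()(labels, probs)

print(f"from logits: {loss_a.numpy():.6f}")
print(f"from probs:  {loss_b.numpy():.6f}")  # same value; A is the stabler path
```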


Custom Loss Functions

When built-in losses don't fit your problem, you can write a custom loss as a plain function or subclass tf.keras.losses.Loss for statefulness and serialization.

import tensorflow as tf
import numpy as np

# Custom loss as a simple function
def weighted_mse(y_true, y_pred):
    """MSE that penalizes under-predictions 2x more than over-predictions."""
    error = y_true - y_pred
    weights = tf.where(error > 0, 2.0, 1.0)  # under-prediction gets 2x weight
    return tf.reduce_mean(weights * tf.square(error))

# Custom loss as a class (supports serialization)
class FocalLoss(tf.keras.losses.Loss):
    """Focal loss for imbalanced classification."""
    def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        cross_entropy = -y_true * tf.math.log(y_pred)
        weight = self.alpha * y_true * tf.pow(1.0 - y_pred, self.gamma)
        return tf.reduce_mean(weight * cross_entropy)

    def get_config(self):
        config = super().get_config()
        config.update({'gamma': self.gamma, 'alpha': self.alpha})
        return config

# Test custom losses
y_true = tf.constant([1.0, 2.0, 3.0])
y_pred = tf.constant([0.8, 2.5, 2.7])
print(f"Weighted MSE: {weighted_mse(y_true, y_pred).numpy():.4f}")

focal = FocalLoss(gamma=2.0, alpha=0.25)
y_true_cls = tf.constant([[1.0, 0.0], [0.0, 1.0]])
y_pred_cls = tf.constant([[0.9, 0.1], [0.3, 0.7]])
print(f"Focal Loss: {focal(y_true_cls, y_pred_cls).numpy():.4f}")

Optimizers

The optimizer determines how the model updates its weights in response to the gradient. All optimizers perform gradient descent, but they differ in how they scale, adapt, and accumulate gradient information.

SGD & Momentum

Stochastic Gradient Descent (SGD) updates weights proportionally to the gradient. Adding momentum accelerates convergence by accumulating a velocity term from past gradients:

$$v_t = \gamma v_{t-1} + \eta \nabla L$$

$$\theta_t = \theta_{t-1} - v_t$$

where $\gamma$ is the momentum coefficient (typically 0.9), $\eta$ is the learning rate, and $\nabla L$ is the gradient of the loss.
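
As a minimal sanity check of the update rule, the snippet below applies two momentum steps by hand. Keras stores the velocity with the opposite sign, as $v_t = \gamma v_{t-1} - \eta \nabla L$ with $\theta_t = \theta_{t-1} + v_t$, which yields the same trajectory; the constant gradient and hyperparameters are chosen only to make the arithmetic easy to follow:

```python
import tensorflow as tf

# theta_0 = 1.0, constant gradient g = 2.0, eta = 0.1, gamma = 0.9
var = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
g = tf.constant(2.0)

opt.apply_gradients([(g, var)])  # |v| = 0.1*2 = 0.2          -> theta = 0.80
opt.apply_gradients([(g, var)])  # |v| = 0.9*0.2 + 0.2 = 0.38 -> theta = 0.42
print(f"theta after 2 steps: {var.numpy():.2f}")  # 0.42
```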

import tensorflow as tf
import numpy as np

# Generate synthetic data
np.random.seed(42)
X = np.random.randn(500, 5).astype(np.float32)
y = (X @ np.array([1, -2, 3, -1, 0.5]).reshape(-1, 1) + 0.1).astype(np.float32)

# SGD without momentum
model_sgd = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
model_sgd.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                  loss='mse')
history_sgd = model_sgd.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD final loss: {history_sgd.history['loss'][-1]:.4f}")

# SGD with momentum — converges faster
model_mom = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
model_mom.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
    loss='mse'
)
history_mom = model_mom.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD+Momentum final loss: {history_mom.history['loss'][-1]:.4f}")

# SGD with Nesterov momentum — looks ahead before computing gradient
model_nag = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
model_nag.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True),
    loss='mse'
)
history_nag = model_nag.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD+Nesterov final loss: {history_nag.history['loss'][-1]:.4f}")

Adam & AdamW

Adam (Adaptive Moment Estimation) maintains per-parameter adaptive learning rates using exponentially decaying averages of past gradients ($m_t$, first moment) and squared gradients ($v_t$, second moment):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

$$\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

AdamW decouples weight decay from the gradient update, which often produces better generalization than Adam's L2 regularization.
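
One useful consequence of the adaptive denominator: on the very first step, the bias-corrected update reduces to roughly $\eta \cdot \text{sign}(g)$, regardless of the gradient's magnitude. A quick check with illustrative values:

```python
import tensorflow as tf

# First-step updates for gradients spanning six orders of magnitude
steps = []
for g_val in [0.001, 1.0, 1000.0]:
    var = tf.Variable(5.0)
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    opt.apply_gradients([(tf.constant(g_val), var)])
    steps.append(5.0 - var.numpy())
    print(f"g = {g_val:>8}: step = {steps[-1]:.6f}")  # each ≈ 0.001 (= lr)
```

The epsilon in the denominator makes each step fractionally smaller than the learning rate, but the magnitude-invariance is what matters: Adam self-normalizes, which is why lr=1e-3 works across so many problems.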

import tensorflow as tf
import numpy as np

# Generate classification data
np.random.seed(42)
X = np.random.randn(1000, 20).astype(np.float32)
y = (np.sum(X[:, :5], axis=1) > 0).astype(np.float32)

def build_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

# Adam — the most popular optimizer
model_adam = build_model()
model_adam.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
h1 = model_adam.fit(X, y, epochs=15, batch_size=32, verbose=0)
print(f"Adam — Loss: {h1.history['loss'][-1]:.4f}, Acc: {h1.history['accuracy'][-1]:.4f}")

# AdamW — decoupled weight decay (better generalization)
model_adamw = build_model()
model_adamw.compile(
    optimizer=tf.keras.optimizers.AdamW(
        learning_rate=1e-3,
        weight_decay=1e-4   # decoupled weight decay
    ),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
h2 = model_adamw.fit(X, y, epochs=15, batch_size=32, verbose=0)
print(f"AdamW — Loss: {h2.history['loss'][-1]:.4f}, Acc: {h2.history['accuracy'][-1]:.4f}")

RMSProp & Adagrad

RMSProp maintains a moving average of squared gradients to normalize updates, which makes it a reliable choice for recurrent neural networks. Adagrad adapts the learning rate per parameter based on accumulated historical gradients; it excels on sparse features but can decay the learning rate too aggressively over long runs.

import tensorflow as tf
import numpy as np

np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)

def build_regression_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1)
    ])

# RMSProp — good default for RNNs
model_rms = build_regression_model()
model_rms.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3), loss='mse')
h_rms = model_rms.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"RMSProp final loss: {h_rms.history['loss'][-1]:.4f}")

# Adagrad — per-parameter adaptive rates, good for sparse data
model_adagrad = build_regression_model()
model_adagrad.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.01), loss='mse')
h_ada = model_adagrad.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"Adagrad final loss: {h_ada.history['loss'][-1]:.4f}")

Optimizer Cheat Sheet: Start with Adam (lr=1e-3) for most tasks. Use AdamW when you need weight decay regularization (especially for Transformers). Use SGD + momentum for vision models when you have the patience to tune the learning rate. Use RMSProp for RNN/LSTM architectures.

Learning Rate Schedules

The learning rate is the single most important hyperparameter. Starting high helps escape bad local minima; decaying it over time enables fine-grained convergence. TensorFlow provides several built-in schedules and the flexibility to write your own.

Built-in Schedules

The example below samples three built-in schedules (ExponentialDecay, CosineDecay, and PiecewiseConstantDecay) at a few training steps to show how each decays, then plugs one into an optimizer:

import tensorflow as tf
import numpy as np

# ExponentialDecay — multiply LR by decay_rate every decay_steps
exp_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=False  # smooth decay; True for step-wise
)
print(f"ExponentialDecay at step 0:    {exp_schedule(0).numpy():.6f}")
print(f"ExponentialDecay at step 1000: {exp_schedule(1000).numpy():.6f}")
print(f"ExponentialDecay at step 5000: {exp_schedule(5000).numpy():.6f}")

# CosineDecay — smooth annealing to near-zero
cosine_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    alpha=1e-5  # minimum learning rate
)
print(f"\nCosineDecay at step 0:     {cosine_schedule(0).numpy():.6f}")
print(f"CosineDecay at step 5000:  {cosine_schedule(5000).numpy():.6f}")
print(f"CosineDecay at step 10000: {cosine_schedule(10000).numpy():.6f}")

# PiecewiseConstantDecay — manual step boundaries
piece_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1000, 3000, 5000],
    values=[1e-2, 5e-3, 1e-3, 5e-4]
)
for step in [0, 500, 1500, 4000, 6000]:
    print(f"PiecewiseConstant at step {step}: {piece_schedule(step).numpy():.6f}")

# Use a schedule as the optimizer's learning rate
optimizer = tf.keras.optimizers.Adam(learning_rate=cosine_schedule)
print(f"\nOptimizer initial LR: {optimizer.learning_rate(0).numpy():.6f}")

Warmup Strategies

Warmup gradually increases the learning rate from near-zero over the first few thousand steps, preventing early training instability. This is essential for Transformer models and large batch training.

import tensorflow as tf
import numpy as np

class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay."""
    def __init__(self, peak_lr, warmup_steps, total_steps):
        super().__init__()
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = self.peak_lr * (step / self.warmup_steps)
        decay_steps = self.total_steps - self.warmup_steps
        decay_progress = (step - self.warmup_steps) / decay_steps
        decay_progress = tf.minimum(decay_progress, 1.0)  # hold LR at zero past total_steps
        cosine = 0.5 * self.peak_lr * (1 + tf.cos(np.pi * decay_progress))
        return tf.where(step < self.warmup_steps, warmup, cosine)

    def get_config(self):
        return {
            'peak_lr': self.peak_lr,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps
        }

# Create warmup+cosine schedule
schedule = WarmupCosineDecay(peak_lr=1e-3, warmup_steps=1000, total_steps=10000)
for step in [0, 250, 500, 1000, 3000, 5000, 10000]:
    print(f"Step {step:>5d}: LR = {schedule(step).numpy():.6f}")

# Plug it into an optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
print(f"\nOptimizer LR at step 1000: {optimizer.learning_rate(1000).numpy():.6f}")

ReduceLROnPlateau

Instead of a fixed schedule, ReduceLROnPlateau monitors a metric and reduces the learning rate when it plateaus — a reactive approach that adapts to the actual training dynamics.

import tensorflow as tf
import numpy as np

# ReduceLROnPlateau callback
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',   # metric to watch
    factor=0.5,           # multiply LR by this when triggered
    patience=3,           # epochs without improvement before reducing
    min_lr=1e-6,          # floor for the learning rate
    verbose=1             # print when LR is reduced
)

# Quick demo
np.random.seed(42)
X = np.random.randn(500, 5).astype(np.float32)
y = (X @ np.random.randn(5, 1)).astype(np.float32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mse')
model.fit(X, y, epochs=20, batch_size=32, validation_split=0.2,
          callbacks=[reduce_lr], verbose=0)
print(f"Final LR: {model.optimizer.learning_rate.numpy():.6f}")

Metrics

Metrics measure model quality during training and evaluation. Unlike the loss function (which must be differentiable), metrics can be any function — accuracy, AUC, precision, recall, or custom aggregations.

Built-in Metrics

Keras provides ready-made metric classes for most classification and regression needs. The example below tracks accuracy, AUC, precision, and recall on a binary classifier:

import tensorflow as tf
import numpy as np

# Classification model with multiple metrics
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (np.sum(X[:, :3], axis=1) > 0).astype(np.float32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ]
)

history = model.fit(X, y, epochs=10, batch_size=32, verbose=0)
results = model.evaluate(X, y, verbose=0)
for name, value in zip(model.metrics_names, results):
    print(f"{name:>12s}: {value:.4f}")

Custom Metrics

For metrics like F1 score that can't be computed per-batch (they need aggregate TP/FP/FN), subclass tf.keras.metrics.Metric to maintain state across batches.

import tensorflow as tf
import numpy as np

class F1Score(tf.keras.metrics.Metric):
    """Stateful F1 score that accumulates TP, FP, FN across batches."""
    def __init__(self, threshold=0.5, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        self.tp = self.add_weight(name='tp', initializer='zeros')
        self.fp = self.add_weight(name='fp', initializer='zeros')
        self.fn = self.add_weight(name='fn', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred_binary = tf.cast(y_pred >= self.threshold, tf.float32)
        y_true = tf.cast(y_true, tf.float32)
        self.tp.assign_add(tf.reduce_sum(y_pred_binary * y_true))
        self.fp.assign_add(tf.reduce_sum(y_pred_binary * (1 - y_true)))
        self.fn.assign_add(tf.reduce_sum((1 - y_pred_binary) * y_true))

    def result(self):
        precision = self.tp / (self.tp + self.fp + 1e-7)
        recall = self.tp / (self.tp + self.fn + 1e-7)
        return 2 * precision * recall / (precision + recall + 1e-7)

    def reset_state(self):
        self.tp.assign(0)
        self.fp.assign(0)
        self.fn.assign(0)

# Use the custom metric
np.random.seed(42)
X = np.random.randn(500, 8).astype(np.float32)
y = (np.sum(X[:, :3], axis=1) > 0).astype(np.float32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy', F1Score(threshold=0.5)])
model.fit(X, y, epochs=10, batch_size=32, verbose=0)
results = model.evaluate(X, y, verbose=0)
print(f"Accuracy: {results[1]:.4f}")
print(f"F1 Score: {results[2]:.4f}")

compile() Deep Dive

The compile() method has many arguments beyond the standard trio of loss, optimizer, and metrics. Understanding them lets you unlock multi-output training, eager debugging, and XLA compilation.

Arguments & Options

Reference: model.compile() Arguments

loss — Loss function (string, function, or Loss instance). Can be a dict for multi-output models.

optimizer — Optimizer (string or instance). Strings like 'adam' use default hyperparameters.

metrics — List of metrics to track during training. Can be strings or Metric instances.

loss_weights — Dict or list of floats to weight losses in multi-output models.

run_eagerly — If True, disables tf.function tracing for easier debugging (slower).

jit_compile — If True, uses XLA compilation for faster training (requires compatible ops).
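
A brief sketch of run_eagerly in action, using a hypothetical DebugLayer to show that plain Python print() (and debugger breakpoints) fire on every batch when tracing is disabled:

```python
import tensorflow as tf
import numpy as np

class DebugLayer(tf.keras.layers.Layer):
    """Hypothetical pass-through layer used only to inspect activations."""
    def call(self, x):
        # With run_eagerly=True this body runs as plain Python on every
        # batch, so print() and pdb breakpoints behave normally
        print("batch activation mean:", float(tf.reduce_mean(x)))
        return x

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    DebugLayer(),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse', run_eagerly=True)

X = np.random.randn(64, 4).astype(np.float32)
y = np.random.randn(64, 1).astype(np.float32)
model.fit(X, y, epochs=1, batch_size=32, verbose=0)  # prints once per batch
```

Remember to drop run_eagerly=True once the bug is found; eager execution is typically several times slower than the traced graph.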


Multi-Output Models

When a model has multiple output heads, you can assign different loss functions and weights to each.

import tensorflow as tf
import numpy as np

# Multi-output model: regression head + classification head
inputs = tf.keras.Input(shape=(10,))
shared = tf.keras.layers.Dense(64, activation='relu')(inputs)

# Output 1: regression (predict price)
price_output = tf.keras.layers.Dense(1, name='price')(shared)

# Output 2: classification (predict category)
category_output = tf.keras.layers.Dense(3, activation='softmax', name='category')(shared)

model = tf.keras.Model(inputs=inputs, outputs=[price_output, category_output])

# Compile with per-output loss and weights
model.compile(
    optimizer='adam',
    loss={
        'price': 'mse',
        'category': 'sparse_categorical_crossentropy'
    },
    loss_weights={
        'price': 1.0,       # equal weight
        'category': 0.5     # classification loss weighted 50%
    },
    metrics={
        'price': ['mae'],
        'category': ['accuracy']
    },
    jit_compile=True  # XLA compilation for speed
)

# Generate dummy data
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y_price = np.random.randn(500, 1).astype(np.float32)
y_category = np.random.randint(0, 3, size=(500,))

model.fit(X, {'price': y_price, 'category': y_category},
          epochs=5, batch_size=32, verbose=0)
results = model.evaluate(X, {'price': y_price, 'category': y_category}, verbose=0)
print(f"Total loss: {results[0]:.4f}")
model.summary()

Batch Size & Gradient Accumulation

Batch size controls how many samples the model sees before each weight update. It's a fundamental tradeoff between computational efficiency, memory usage, and optimization dynamics.

Memory vs Convergence

Batch Size Tradeoffs: Larger batches give more stable gradients and faster hardware utilization, but may converge to sharper (less generalizable) minima and require more memory. Smaller batches add noise that acts as regularization, often finding flatter minima that generalize better — but train slower due to less hardware parallelism. The linear scaling rule: when you double the batch size, double the learning rate.
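
The linear scaling rule is simple enough to encode directly. The sketch below derives the learning rate from the batch size relative to a reference configuration (the base values are illustrative, not a recommendation):

```python
import tensorflow as tf

# Hypothetical reference configuration tuned at batch size 32
BASE_BATCH, BASE_LR = 32, 1e-3

def scaled_lr(batch_size, base_batch=BASE_BATCH, base_lr=BASE_LR):
    """Linear scaling rule: LR grows proportionally with batch size."""
    return base_lr * batch_size / base_batch

for bs in [32, 64, 128, 256]:
    print(f"batch {bs:>3d} -> lr {scaled_lr(bs):.0e}")

# Plug the scaled rate into the optimizer for the larger-batch run
optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr(128))
```

In practice the rule is usually combined with a warmup phase, since jumping straight to the scaled rate can destabilize the first few hundred steps.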

Gradient Accumulation

Gradient accumulation simulates a large batch size without the memory cost — you accumulate gradients over multiple mini-batches before applying a single weight update. This is the standard technique when your desired batch size exceeds GPU memory.

import tensorflow as tf
import numpy as np

# Gradient accumulation with a custom training step
class GradientAccumulationModel(tf.keras.Model):
    """Model that accumulates gradients over multiple mini-batches."""
    def __init__(self, base_model, accumulation_steps=4):
        super().__init__()
        self.base_model = base_model
        self.accumulation_steps = accumulation_steps
        self.step_count = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.gradient_accumulator = None

    def call(self, inputs, training=False):
        return self.base_model(inputs, training=training)

    def train_step(self, data):
        x, y = data

        # Initialize accumulator on first call
        if self.gradient_accumulator is None:
            self.gradient_accumulator = [
                tf.Variable(tf.zeros_like(v), trainable=False)
                for v in self.base_model.trainable_variables
            ]

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
            # Scale loss by accumulation steps
            scaled_loss = loss / self.accumulation_steps

        grads = tape.gradient(scaled_loss, self.base_model.trainable_variables)

        # Accumulate gradients
        for acc, grad in zip(self.gradient_accumulator, grads):
            if grad is not None:
                acc.assign_add(grad)

        self.step_count.assign_add(1)

        # Apply when we've accumulated enough (AutoGraph converts this
        # tensor-dependent `if` into a tf.cond inside the traced step)
        if self.step_count % self.accumulation_steps == 0:
            self.optimizer.apply_gradients(
                zip(self.gradient_accumulator, self.base_model.trainable_variables)
            )
            for acc in self.gradient_accumulator:
                acc.assign(tf.zeros_like(acc))

        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Demo: effective batch size = 32 * 4 = 128, but only 32 in memory
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)

base = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model = GradientAccumulationModel(base, accumulation_steps=4)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.fit(X, y, epochs=5, batch_size=32, verbose=0)  # effective batch = 128
print(f"Final loss: {model.evaluate(X, y, verbose=0)[0]:.4f}")

Gradient Clipping

Gradient clipping prevents exploding gradients — a common problem in deep networks and RNNs where gradients grow exponentially through layers, causing weight updates that blow up the model.

Clipping Methods

TensorFlow optimizers support three clipping strategies directly via constructor arguments:

Gradient Clipping Methods
flowchart LR
    A[Computed Gradients] --> B{Clipping Method}
    B -->|clipvalue| C["Clip each element\nto [-val, val]"]
    B -->|clipnorm| D["Scale gradient if\n‖g‖ > norm"]
    B -->|global_clipnorm| E["Scale ALL gradients\nby global norm"]

    C --> F[Clipped Gradients]
    D --> F
    E --> F
    F --> G[Apply to Weights]

    style A fill:#132440,stroke:#3B9797,color:#fff
    style B fill:#3B9797,stroke:#132440,color:#fff
    style G fill:#132440,stroke:#3B9797,color:#fff
                            
import tensorflow as tf
import numpy as np

np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)

# Method 1: clipvalue — clip each gradient element to [-0.5, 0.5]
model1 = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model1.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.5),
    loss='mse'
)
model1.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"clipvalue=0.5, loss: {model1.evaluate(X, y, verbose=0):.4f}")

# Method 2: clipnorm — scale gradient if its L2 norm exceeds threshold
model2 = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model2.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
    loss='mse'
)
model2.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"clipnorm=1.0, loss: {model2.evaluate(X, y, verbose=0):.4f}")

# Method 3: global_clipnorm — scale ALL gradients by global norm
model3 = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0),
    loss='mse'
)
model3.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"global_clipnorm=1.0, loss: {model3.evaluate(X, y, verbose=0):.4f}")

Clipping for RNNs

Recurrent networks (LSTM, GRU) are particularly prone to exploding gradients because the same weight matrices are applied at every timestep. global_clipnorm is the recommended strategy — it preserves the relative direction of gradients across all parameters.
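
The operation behind global_clipnorm can be inspected directly with tf.clip_by_global_norm, which makes the direction-preserving property visible: every gradient is multiplied by the same factor, so their relative proportions survive. The gradient values below are chosen so the norms come out to round numbers:

```python
import tensorflow as tf

# Two gradient tensors; global norm = sqrt(3^2 + 4^2 + 0^2 + 12^2) = 13
grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)

print(f"global norm before: {global_norm.numpy():.1f}")  # 13.0
for g, c in zip(grads, clipped):
    # Every element is scaled by the same 1/13 factor, so the
    # relative directions across all parameters are unchanged
    print(g.numpy(), "->", c.numpy())
```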

import tensorflow as tf
import numpy as np

# Simulated sequence data
np.random.seed(42)
X_seq = np.random.randn(200, 50, 10).astype(np.float32)  # 200 samples, 50 timesteps, 10 features
y_seq = np.random.randint(0, 2, size=(200,)).astype(np.float32)

# LSTM with gradient clipping — prevents exploding gradients
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(64, input_shape=(50, 10), return_sequences=True),
    tf.keras.layers.LSTM(32),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=1e-3,
        global_clipnorm=1.0  # recommended for RNNs
    ),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

history = model.fit(X_seq, y_seq, epochs=10, batch_size=32, verbose=0)
print(f"Final loss: {history.history['loss'][-1]:.4f}")
print(f"Final accuracy: {history.history['accuracy'][-1]:.4f}")

When to Clip: Always use gradient clipping for RNNs/LSTMs. For Transformers, global_clipnorm=1.0 is standard practice. For simple feedforward networks with ReLU activations, clipping is usually unnecessary. If your loss suddenly jumps to NaN or infinity during training, try adding global_clipnorm=1.0.

Mixed Precision Training

Mixed precision training uses 16-bit floating point (float16) for forward/backward passes while keeping 32-bit master copies of weights. This can double training throughput on modern GPUs with Tensor Cores (NVIDIA V100, A100, RTX 30xx/40xx) while maintaining model accuracy.

Policy & Setup

Enabling mixed precision is a one-line policy change. The example below sets the global policy, verifies the compute and variable dtypes, and shows why the output layer should stay in float32:

import tensorflow as tf

# Enable mixed precision globally
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Verify the current policy
policy = tf.keras.mixed_precision.global_policy()
print(f"Policy name: {policy.name}")
print(f"Compute dtype: {policy.compute_dtype}")   # float16
print(f"Variable dtype: {policy.variable_dtype}")  # float32 (master weights)

# Build a model — layers automatically use float16 compute
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(100,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    # IMPORTANT: output layer must be float32 for numerical stability
    tf.keras.layers.Dense(10, dtype='float32')
])

# Check dtypes
for layer in model.layers:
    print(f"{layer.name}: compute={layer.dtype_policy.compute_dtype}, "
          f"variable={layer.dtype_policy.variable_dtype}")

# Reset to float32 for subsequent examples
tf.keras.mixed_precision.set_global_policy('float32')

Loss Scaling

Float16 has a smaller dynamic range than float32, which can cause small gradient values to underflow to zero. Loss scaling multiplies the loss by a large factor before backpropagation, then divides the gradients by the same factor — this keeps small gradients in float16's representable range.
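To see the underflow problem directly, compare how the two dtypes represent a tiny gradient value (float16's smallest subnormal is roughly 6e-8):

```python
import tensorflow as tf

# A gradient of 1e-8 is representable in float32...
print(tf.constant(1e-8, dtype=tf.float32).numpy())  # 1e-08

# ...but underflows to exactly zero in float16, so the weight
# update would be silently lost without loss scaling
print(tf.constant(1e-8, dtype=tf.float16).numpy())  # 0.0
```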

import tensorflow as tf
import numpy as np

# Enable mixed precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Build model with float32 output
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, dtype='float32')  # keep output in float32
])

# The optimizer automatically handles loss scaling in TF2
# LossScaleOptimizer wraps the base optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])

# Train with mixed precision
np.random.seed(42)
X = np.random.randn(1000, 20).astype(np.float32)
y = np.random.randn(1000, 1).astype(np.float32)

history = model.fit(X, y, epochs=10, batch_size=64, verbose=0)
print(f"Mixed precision final loss: {history.history['loss'][-1]:.4f}")
print(f"Mixed precision final MAE:  {history.history['mae'][-1]:.4f}")

# Reset policy
tf.keras.mixed_precision.set_global_policy('float32')
Mixed Precision Tips: Always keep the final output layer in float32 with dtype='float32'. On GPUs with Tensor Cores (V100+), expect 1.5–3× speedup. On CPUs, mixed precision offers no benefit. Layer dimensions should be multiples of 8 for optimal Tensor Core utilization.
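For intuition, the mechanics of loss scaling can be sketched with a fixed scale factor and a manual GradientTape step. Keras actually uses dynamic scaling that adjusts the factor during training; the fixed factor here is purely illustrative:

```python
import tensorflow as tf

# Fixed-factor loss scaling, sketched by hand. The scale cancels
# exactly, so the applied gradient matches the unscaled one — the
# only difference is that intermediate float16 values stay large
# enough to avoid underflow.
loss_scale = 1024.0

w = tf.Variable(2.0)
x = tf.constant(3.0)

with tf.GradientTape() as tape:
    loss = tf.square(w * x - 1.0)    # true loss: (wx - 1)²
    scaled_loss = loss * loss_scale  # scale UP before backprop

scaled_grad = tape.gradient(scaled_loss, w)
grad = scaled_grad / loss_scale      # scale DOWN before the update

# Analytic gradient: 2(wx - 1)·x = 2·5·3 = 30
print(float(grad))  # 30.0
```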

Complete Example

Let's combine everything — an end-to-end training pipeline using a cosine learning rate schedule, the Adam optimizer with gradient clipping, multiple metrics, and mixed precision.

End-to-End Training Pipeline

The pipeline below ties together mixed precision, a warmup-plus-cosine schedule, AdamW with gradient clipping, multiple metrics, and callbacks. It is self-contained and can be run as-is:

import tensorflow as tf
import numpy as np

# ── 1. Enable mixed precision ──────────────────────────────
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# ── 2. Generate synthetic multi-class data ──────────────────
np.random.seed(42)
num_classes = 10
X_train = np.random.randn(5000, 50).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(5000,))
X_val = np.random.randn(1000, 50).astype(np.float32)
y_val = np.random.randint(0, num_classes, size=(1000,))

# ── 3. Learning rate schedule: warmup + cosine decay ────────
total_steps = (5000 // 64) * 30   # 78 steps/epoch × 30 epochs = 2340
warmup_steps = total_steps // 10  # 10% linear warmup

# CosineDecay's built-in warmup (warmup_target / warmup_steps)
# requires TF 2.13+
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,    # ramp up from 0 during warmup
    warmup_target=1e-3,           # peak LR reached after warmup
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    alpha=1e-5
)

# ── 4. Optimizer with gradient clipping ─────────────────────
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=lr_schedule,
    weight_decay=1e-4,
    global_clipnorm=1.0   # prevent exploding gradients
)

# ── 5. Build model ─────────────────────────────────────────
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(50,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(64, activation='relu'),
    # Output layer in float32 for mixed precision stability
    tf.keras.layers.Dense(num_classes, dtype='float32')
])

# ── 6. Compile with multiple metrics ───────────────────────
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top3_accuracy'),
    ],
    jit_compile=True  # XLA compilation for extra speed
)

# ── 7. Callbacks ────────────────────────────────────────────
# Note: ReduceLROnPlateau is omitted — it cannot modify a learning
# rate driven by a schedule object, and the cosine schedule already
# handles decay.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, restore_best_weights=True
    ),
]

# ── 8. Train ────────────────────────────────────────────────
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=64,
    callbacks=callbacks,
    verbose=1
)

# ── 9. Evaluate ─────────────────────────────────────────────
results = model.evaluate(X_val, y_val, verbose=0)
for name, value in zip(model.metrics_names, results):
    print(f"{name:>15s}: {value:.4f}")

# ── 10. Predict ─────────────────────────────────────────────
predictions = model.predict(X_val[:5], verbose=0)
predicted_classes = np.argmax(predictions, axis=1)
print(f"\nPredicted classes: {predicted_classes}")
print(f"True classes:      {y_val[:5]}")

# Reset precision policy
tf.keras.mixed_precision.set_global_policy('float32')

Summary & Next Steps

In this article we covered the complete TensorFlow training pipeline:

  • Loss functions — MSE, MAE, Huber for regression; BCE, CCE, SparseCCE for classification; custom losses for specialized problems
  • Optimizers — SGD (with momentum/Nesterov), Adam, AdamW, RMSProp, Adagrad — and when to choose each
  • Learning rate schedules — ExponentialDecay, CosineDecay, PiecewiseConstantDecay, warmup strategies, ReduceLROnPlateau
  • Metrics — built-in metrics (Accuracy, AUC, Precision, Recall) and custom stateful metrics (F1 score)
  • compile() options — loss_weights for multi-output, run_eagerly for debugging, jit_compile for XLA speed
  • Batch size — memory/convergence tradeoffs and gradient accumulation for large effective batch sizes
  • Gradient clipping — clipvalue, clipnorm, global_clipnorm for preventing exploding gradients
  • Mixed precision — float16 compute with float32 master weights for up to 3× speedup on Tensor Core GPUs

Next in the Series

In Part 4: Data Pipelines with tf.data, we'll build efficient input pipelines — from raw files to batched, shuffled, prefetched datasets that keep the GPU fully utilized during training.