The Training Pipeline
Every TensorFlow model follows a four-step lifecycle: compile the model with an optimizer and loss function, fit it to training data, evaluate on held-out data, and predict on new inputs. Understanding each step — and the knobs you can turn — is the key to training models that converge quickly and generalize well.
model.compile() configures the training strategy, model.fit() runs gradient descent, model.evaluate() measures test performance, and model.predict() produces inference outputs. Everything in this article revolves around making these four steps as effective as possible.
Pipeline Workflow
The diagram below shows how the compile → fit → evaluate → predict pipeline connects to the optimization components we'll cover in this article — loss functions, optimizers, learning rate schedules, metrics, and advanced techniques like gradient clipping and mixed precision.
flowchart TD
A[Define Model] --> B[model.compile]
B --> C[model.fit]
C --> D[model.evaluate]
D --> E[model.predict]
B --> B1[Loss Function]
B --> B2[Optimizer]
B --> B3[Metrics]
B2 --> B2a[Learning Rate Schedule]
B2 --> B2b[Gradient Clipping]
C --> C1[Training Loop]
C1 --> C2[Forward Pass]
C2 --> C3[Compute Loss]
C3 --> C4[Backward Pass]
C4 --> C5[Update Weights]
C5 --> C1
style A fill:#132440,stroke:#3B9797,color:#fff
style B fill:#16476A,stroke:#3B9797,color:#fff
style C fill:#3B9797,stroke:#132440,color:#fff
style D fill:#16476A,stroke:#3B9797,color:#fff
style E fill:#132440,stroke:#3B9797,color:#fff
Let's start with the minimal compile-fit-evaluate-predict workflow and then progressively unlock more powerful training options.
import tensorflow as tf
import numpy as np
# Generate synthetic regression data
np.random.seed(42)
X_train = np.random.randn(1000, 10).astype(np.float32)
y_train = (X_train @ np.random.randn(10, 1) + 0.5).astype(np.float32)
X_test = np.random.randn(200, 10).astype(np.float32)
y_test = (X_test @ np.random.randn(10, 1) + 0.5).astype(np.float32)
# Build a simple model
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1)
])
# 1. Compile — configure optimizer, loss, metrics
model.compile(
optimizer='adam',
loss='mse',
metrics=['mae']
)
# 2. Fit — train on data
history = model.fit(X_train, y_train, epochs=10, batch_size=32,
validation_split=0.2, verbose=1)
# 3. Evaluate — test performance
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss (MSE): {test_loss:.4f}")
print(f"Test MAE: {test_mae:.4f}")
# 4. Predict — inference
predictions = model.predict(X_test[:5])
print("Predictions shape:", predictions.shape)
print("First 5 predictions:", predictions.flatten())
The history object returned by fit() contains loss and metric values for every epoch, which you can plot to diagnose training behaviour — whether the model is converging, overfitting, or underfitting.
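For instance, the per-epoch lists in history.history can be inspected directly to find the best epoch and the train/validation gap. The numbers below are made up for illustration; in practice the dict comes straight from fit():

```python
# Hypothetical history.history contents (a plain dict of per-epoch lists)
hist = {
    'loss':     [1.00, 0.60, 0.40, 0.30, 0.25, 0.22],
    'val_loss': [1.05, 0.70, 0.55, 0.52, 0.56, 0.63],
}

# Best epoch by validation loss
best_epoch = min(range(len(hist['val_loss'])), key=hist['val_loss'].__getitem__)
print(f"best epoch: {best_epoch}, val_loss: {hist['val_loss'][best_epoch]:.2f}")

# val_loss rising while train loss keeps falling is the overfitting signature
gap = hist['val_loss'][-1] - hist['loss'][-1]
print(f"final train/val gap: {gap:.2f}")
```

A widening gap after the best epoch is the cue to stop early or regularize harder.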
Loss Functions
The loss function measures how far the model's predictions are from the true labels. Choosing the right loss is critical — it defines what the model optimizes for. TensorFlow provides loss functions for regression, binary classification, and multi-class classification, plus the ability to write your own.
The general cross-entropy loss for multi-class classification is:
$$\mathcal{L} = -\sum_{i} y_i \log(\hat{y}_i)$$
where $y_i$ is the true label (one-hot encoded) and $\hat{y}_i$ is the predicted probability for class $i$.
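Plugging numbers into the formula shows why cross-entropy only "sees" the probability assigned to the true class:

```python
import numpy as np

y = np.array([0.0, 1.0, 0.0])   # one-hot: true class is index 1
p = np.array([0.2, 0.7, 0.1])   # predicted softmax probabilities

loss = -np.sum(y * np.log(p))   # only the true class's term survives
print(round(loss, 4))           # 0.3567, i.e. -log(0.7)
```

Pushing the true-class probability toward 1 drives the loss toward 0; toward 0, the loss blows up.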
Regression Losses
Mean Squared Error (MSE) penalizes large errors quadratically, making it sensitive to outliers. Mean Absolute Error (MAE) is more robust to outliers. Huber loss combines both — it's quadratic for small errors and linear for large ones.
import tensorflow as tf
import numpy as np
# Sample data
y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([1.1, 2.3, 2.8, 4.5, 4.7])
# Mean Squared Error
mse = tf.keras.losses.MeanSquaredError()
print(f"MSE: {mse(y_true, y_pred).numpy():.4f}")
# Mean Absolute Error
mae = tf.keras.losses.MeanAbsoluteError()
print(f"MAE: {mae(y_true, y_pred).numpy():.4f}")
# Huber Loss (delta controls MSE-MAE transition)
huber = tf.keras.losses.Huber(delta=1.0)
print(f"Huber: {huber(y_true, y_pred).numpy():.4f}")
# Huber with smaller delta — more robust to outliers
huber_small = tf.keras.losses.Huber(delta=0.5)
print(f"Huber (δ=0.5): {huber_small(y_true, y_pred).numpy():.4f}")
Classification Losses
Binary Cross-Entropy (BCE) is for two-class problems. Categorical Cross-Entropy (CCE) is for multi-class with one-hot labels. Sparse Categorical Cross-Entropy is the same as CCE but takes integer labels directly, avoiding the memory cost of one-hot encoding.
import tensorflow as tf
import numpy as np
# Binary classification
y_true_binary = np.array([1, 0, 1, 1, 0], dtype=np.float32)
y_pred_binary = np.array([0.9, 0.1, 0.8, 0.7, 0.3], dtype=np.float32)
bce = tf.keras.losses.BinaryCrossentropy()
print(f"BCE: {bce(y_true_binary, y_pred_binary).numpy():.4f}")
# Multi-class with one-hot labels
y_true_onehot = np.array([[1,0,0], [0,1,0], [0,0,1]], dtype=np.float32)
y_pred_probs = np.array([[0.9,0.05,0.05], [0.1,0.8,0.1], [0.2,0.3,0.5]], dtype=np.float32)
cce = tf.keras.losses.CategoricalCrossentropy()
print(f"CCE: {cce(y_true_onehot, y_pred_probs).numpy():.4f}")
# Multi-class with integer labels (more memory-efficient)
y_true_sparse = np.array([0, 1, 2])
scce = tf.keras.losses.SparseCategoricalCrossentropy()
print(f"Sparse CCE: {scce(y_true_sparse, y_pred_probs).numpy():.4f}")
- Regression: use MSE for clean data, MAE for outlier-prone data, and Huber for a balance of the two.
- Binary classification: use BinaryCrossentropy with a sigmoid output.
- Multi-class: use SparseCategoricalCrossentropy with integer labels (saves memory), or CategoricalCrossentropy with one-hot labels. Pass from_logits=True when the output layer has no softmax activation; computing the softmax inside the loss is more numerically stable.
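The from_logits=True advice comes down to floating-point range. The NumPy sketch below contrasts a naive softmax-then-log with the log-sum-exp form, which is roughly what the fused logits path computes:

```python
import numpy as np

logits = np.array([1000.0, 0.0, -1000.0])   # extreme but legal logits

# Naive route: softmax first, then log. exp(1000) overflows to inf
with np.errstate(over='ignore', invalid='ignore'):
    naive = np.exp(logits) / np.sum(np.exp(logits))   # [nan, 0., 0.]

# Stable route (log-sum-exp trick):
# log softmax_i = (logits_i - max) - log(sum(exp(logits - max)))
shifted = logits - logits.max()
log_probs = shifted - np.log(np.sum(np.exp(shifted)))
print(naive)       # contains nan
print(log_probs)   # finite: [0., -1000., -2000.]
```

Applying softmax in the model and log in the loss reproduces the naive route; passing raw logits lets the loss take the stable route.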
Custom Loss Functions
When built-in losses don't fit your problem, you can write a custom loss as a plain function or subclass tf.keras.losses.Loss for statefulness and serialization.
import tensorflow as tf
import numpy as np
# Custom loss as a simple function
def weighted_mse(y_true, y_pred):
"""MSE that penalizes under-predictions 2x more than over-predictions."""
error = y_true - y_pred
weights = tf.where(error > 0, 2.0, 1.0) # under-prediction gets 2x weight
return tf.reduce_mean(weights * tf.square(error))
# Custom loss as a class (supports serialization)
class FocalLoss(tf.keras.losses.Loss):
"""Focal loss for imbalanced classification."""
def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
super().__init__(**kwargs)
self.gamma = gamma
self.alpha = alpha
def call(self, y_true, y_pred):
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
cross_entropy = -y_true * tf.math.log(y_pred)
weight = self.alpha * y_true * tf.pow(1.0 - y_pred, self.gamma)
return tf.reduce_mean(weight * cross_entropy)
def get_config(self):
config = super().get_config()
config.update({'gamma': self.gamma, 'alpha': self.alpha})
return config
# Test custom losses
y_true = tf.constant([1.0, 2.0, 3.0])
y_pred = tf.constant([0.8, 2.5, 2.7])
print(f"Weighted MSE: {weighted_mse(y_true, y_pred).numpy():.4f}")
focal = FocalLoss(gamma=2.0, alpha=0.25)
y_true_cls = tf.constant([[1.0, 0.0], [0.0, 1.0]])
y_pred_cls = tf.constant([[0.9, 0.1], [0.3, 0.7]])
print(f"Focal Loss: {focal(y_true_cls, y_pred_cls).numpy():.4f}")
Optimizers
The optimizer determines how the model updates its weights in response to the gradient. All optimizers perform gradient descent, but they differ in how they scale, adapt, and accumulate gradient information.
SGD & Momentum
Stochastic Gradient Descent (SGD) updates weights proportionally to the gradient. Adding momentum accelerates convergence by accumulating a velocity term from past gradients:
$$v_t = \gamma v_{t-1} + \eta \nabla L$$
$$\theta_t = \theta_{t-1} - v_t$$
where $\gamma$ is the momentum coefficient (typically 0.9), $\eta$ is the learning rate, and $\nabla L$ is the gradient of the loss.
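The update rule is easy to verify directly. This pure-Python sketch applies both equations to a toy 1-D quadratic; note that momentum's real advantage shows on ill-conditioned, noisy losses rather than on a toy like this:

```python
# Minimize f(theta) = theta^2, so grad L = 2 * theta, starting at theta = 5
def sgd(momentum, lr=0.1, steps=200):
    theta, v = 5.0, 0.0
    for _ in range(steps):
        grad = 2.0 * theta
        v = momentum * v + lr * grad   # v_t = gamma * v_{t-1} + eta * grad
        theta = theta - v              # theta_t = theta_{t-1} - v_t
    return theta

plain = sgd(momentum=0.0)   # gamma = 0 reduces to vanilla SGD
heavy = sgd(momentum=0.9)   # heavy-ball momentum
print(abs(plain), abs(heavy))   # both near zero after 200 steps
```

Setting momentum=0 recovers plain SGD, so the two code paths differ only by the velocity term, exactly as in the equations.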
import tensorflow as tf
import numpy as np
# Generate synthetic data
np.random.seed(42)
X = np.random.randn(500, 5).astype(np.float32)
y = (X @ np.array([1, -2, 3, -1, 0.5]).reshape(-1, 1) + 0.1).astype(np.float32)
# SGD without momentum
model_sgd = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
model_sgd.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
loss='mse')
history_sgd = model_sgd.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD final loss: {history_sgd.history['loss'][-1]:.4f}")
# SGD with momentum — converges faster
model_mom = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
model_mom.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
loss='mse'
)
history_mom = model_mom.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD+Momentum final loss: {history_mom.history['loss'][-1]:.4f}")
# SGD with Nesterov momentum — looks ahead before computing gradient
model_nag = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
model_nag.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True),
loss='mse'
)
history_nag = model_nag.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"SGD+Nesterov final loss: {history_nag.history['loss'][-1]:.4f}")
Adam & AdamW
Adam (Adaptive Moment Estimation) maintains per-parameter adaptive learning rates using exponentially decaying averages of past gradients ($m_t$, first moment) and squared gradients ($v_t$, second moment):
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$
AdamW decouples weight decay from the gradient update, which often produces better generalization than Adam's L2 regularization.
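To make the update concrete, here is a single Adam step in plain NumPy, a sketch of the formulas above rather than the TensorFlow implementation. It highlights what bias correction does on the very first step:

```python
import numpy as np

def adam_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * g            # first moment m_t
    v = b2 * v + (1 - b2) * g**2         # second moment v_t
    m_hat = m / (1 - b1**t)              # bias-corrected estimates
    v_hat = v / (1 - b2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# On step t=1, m_hat = g and v_hat = g^2, so the update is roughly
# lr * sign(g): the step size is ~lr regardless of gradient scale
first_steps = [adam_step(0.0, g, 0.0, 0.0, t=1)[0] for g in (0.01, 1.0, 100.0)]
print(first_steps)   # all approximately -0.001
```

This scale invariance is why Adam is so forgiving about the initial learning rate compared with raw SGD.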
import tensorflow as tf
import numpy as np
# Generate classification data
np.random.seed(42)
X = np.random.randn(1000, 20).astype(np.float32)
y = (np.sum(X[:, :5], axis=1) > 0).astype(np.float32)
def build_model():
return tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
# Adam — the most popular optimizer
model_adam = build_model()
model_adam.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss='binary_crossentropy',
metrics=['accuracy']
)
h1 = model_adam.fit(X, y, epochs=15, batch_size=32, verbose=0)
print(f"Adam — Loss: {h1.history['loss'][-1]:.4f}, Acc: {h1.history['accuracy'][-1]:.4f}")
# AdamW — decoupled weight decay (better generalization)
model_adamw = build_model()
model_adamw.compile(
optimizer=tf.keras.optimizers.AdamW(
learning_rate=1e-3,
weight_decay=1e-4 # decoupled weight decay
),
loss='binary_crossentropy',
metrics=['accuracy']
)
h2 = model_adamw.fit(X, y, epochs=15, batch_size=32, verbose=0)
print(f"AdamW — Loss: {h2.history['loss'][-1]:.4f}, Acc: {h2.history['accuracy'][-1]:.4f}")
RMSProp & Adagrad
RMSProp maintains a moving average of squared gradients to normalize updates and is a solid default for recurrent networks. Adagrad adapts the learning rate per parameter based on accumulated historical gradients; it excels on sparse features but can decay the learning rate too aggressively over long training runs.
import tensorflow as tf
import numpy as np
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)
def build_regression_model():
return tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1)
])
# RMSProp — good default for RNNs
model_rms = build_regression_model()
model_rms.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3), loss='mse')
h_rms = model_rms.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"RMSProp final loss: {h_rms.history['loss'][-1]:.4f}")
# Adagrad — per-parameter adaptive rates, good for sparse data
model_adagrad = build_regression_model()
model_adagrad.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.01), loss='mse')
h_ada = model_adagrad.fit(X, y, epochs=20, batch_size=32, verbose=0)
print(f"Adagrad final loss: {h_ada.history['loss'][-1]:.4f}")
Learning Rate Schedules
The learning rate is the single most important hyperparameter. Starting high helps escape bad local minima; decaying it over time enables fine-grained convergence. TensorFlow provides several built-in schedules and the flexibility to write your own.
Built-in Schedules
The example below samples three built-in schedules at a few training steps and then plugs one into an optimizer:
import tensorflow as tf
import numpy as np
# ExponentialDecay — multiply LR by decay_rate every decay_steps
exp_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=1e-2,
decay_steps=1000,
decay_rate=0.9,
staircase=False # smooth decay; True for step-wise
)
print(f"ExponentialDecay at step 0: {exp_schedule(0).numpy():.6f}")
print(f"ExponentialDecay at step 1000: {exp_schedule(1000).numpy():.6f}")
print(f"ExponentialDecay at step 5000: {exp_schedule(5000).numpy():.6f}")
# CosineDecay — smooth annealing to near-zero
cosine_schedule = tf.keras.optimizers.schedules.CosineDecay(
initial_learning_rate=1e-2,
decay_steps=10000,
alpha=1e-5 # minimum learning rate
)
print(f"\nCosineDecay at step 0: {cosine_schedule(0).numpy():.6f}")
print(f"CosineDecay at step 5000: {cosine_schedule(5000).numpy():.6f}")
print(f"CosineDecay at step 10000: {cosine_schedule(10000).numpy():.6f}")
# PiecewiseConstantDecay — manual step boundaries
piece_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=[1000, 3000, 5000],
values=[1e-2, 5e-3, 1e-3, 5e-4]
)
for step in [0, 500, 1500, 4000, 6000]:
print(f"PiecewiseConstant at step {step}: {piece_schedule(step).numpy():.6f}")
# Use a schedule as the optimizer's learning rate
optimizer = tf.keras.optimizers.Adam(learning_rate=cosine_schedule)
# Query the schedule directly: it is a callable mapping step -> LR
print(f"\nOptimizer initial LR: {cosine_schedule(0).numpy():.6f}")
Warmup Strategies
Warmup gradually increases the learning rate from near-zero over the first few thousand steps, preventing early training instability. This is essential for Transformer models and large batch training.
import tensorflow as tf
import numpy as np
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Linear warmup followed by cosine decay."""
def __init__(self, peak_lr, warmup_steps, total_steps):
super().__init__()
self.peak_lr = peak_lr
self.warmup_steps = warmup_steps
self.total_steps = total_steps
def __call__(self, step):
step = tf.cast(step, tf.float32)
warmup = self.peak_lr * (step / self.warmup_steps)
decay_steps = self.total_steps - self.warmup_steps
decay_progress = (step - self.warmup_steps) / decay_steps
cosine = 0.5 * self.peak_lr * (1 + tf.cos(np.pi * decay_progress))
return tf.where(step < self.warmup_steps, warmup, cosine)
def get_config(self):
return {
'peak_lr': self.peak_lr,
'warmup_steps': self.warmup_steps,
'total_steps': self.total_steps
}
# Create warmup+cosine schedule
schedule = WarmupCosineDecay(peak_lr=1e-3, warmup_steps=1000, total_steps=10000)
for step in [0, 250, 500, 1000, 3000, 5000, 10000]:
print(f"Step {step:>5d}: LR = {schedule(step).numpy():.6f}")
# Plug it into an optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
# Query the schedule directly rather than the optimizer attribute
print(f"\nLR at step 1000 (end of warmup): {schedule(1000).numpy():.6f}")
ReduceLROnPlateau
Instead of a fixed schedule, ReduceLROnPlateau monitors a metric and reduces the learning rate when it plateaus — a reactive approach that adapts to the actual training dynamics.
import tensorflow as tf
import numpy as np
# ReduceLROnPlateau callback
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss', # metric to watch
factor=0.5, # multiply LR by this when triggered
patience=3, # epochs without improvement before reducing
min_lr=1e-6, # floor for the learning rate
verbose=1 # print when LR is reduced
)
# Quick demo
np.random.seed(42)
X = np.random.randn(500, 5).astype(np.float32)
y = (X @ np.random.randn(5, 1)).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mse')
model.fit(X, y, epochs=20, batch_size=32, validation_split=0.2,
callbacks=[reduce_lr], verbose=0)
print(f"Final LR: {model.optimizer.learning_rate.numpy():.6f}")
Metrics
Metrics measure model quality during training and evaluation. Unlike the loss function (which must be differentiable), metrics can be any function — accuracy, AUC, precision, recall, or custom aggregations.
Built-in Metrics
The example below compiles a binary classifier with four metrics tracked at once: accuracy, AUC, precision, and recall.
import tensorflow as tf
import numpy as np
# Classification model with multiple metrics
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (np.sum(X[:, :3], axis=1) > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=[
tf.keras.metrics.BinaryAccuracy(name='accuracy'),
tf.keras.metrics.AUC(name='auc'),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),
]
)
history = model.fit(X, y, epochs=10, batch_size=32, verbose=0)
results = model.evaluate(X, y, verbose=0)
for name, value in zip(model.metrics_names, results):
print(f"{name:>12s}: {value:.4f}")
Custom Metrics
For metrics like F1 score that can't be computed by averaging per-batch values (they need aggregate TP/FP/FN counts), subclass tf.keras.metrics.Metric to maintain state across batches.
import tensorflow as tf
import numpy as np
class F1Score(tf.keras.metrics.Metric):
"""Stateful F1 score that accumulates TP, FP, FN across batches."""
def __init__(self, threshold=0.5, name='f1_score', **kwargs):
super().__init__(name=name, **kwargs)
self.threshold = threshold
self.tp = self.add_weight(name='tp', initializer='zeros')
self.fp = self.add_weight(name='fp', initializer='zeros')
self.fn = self.add_weight(name='fn', initializer='zeros')
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred_binary = tf.cast(y_pred >= self.threshold, tf.float32)
y_true = tf.cast(y_true, tf.float32)
self.tp.assign_add(tf.reduce_sum(y_pred_binary * y_true))
self.fp.assign_add(tf.reduce_sum(y_pred_binary * (1 - y_true)))
self.fn.assign_add(tf.reduce_sum((1 - y_pred_binary) * y_true))
def result(self):
precision = self.tp / (self.tp + self.fp + 1e-7)
recall = self.tp / (self.tp + self.fn + 1e-7)
return 2 * precision * recall / (precision + recall + 1e-7)
def reset_state(self):
self.tp.assign(0)
self.fp.assign(0)
self.fn.assign(0)
# Use the custom metric
np.random.seed(42)
X = np.random.randn(500, 8).astype(np.float32)
y = (np.sum(X[:, :3], axis=1) > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(8,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy',
metrics=['accuracy', F1Score(threshold=0.5)])
model.fit(X, y, epochs=10, batch_size=32, verbose=0)
results = model.evaluate(X, y, verbose=0)
print(f"Accuracy: {results[1]:.4f}")
print(f"F1 Score: {results[2]:.4f}")
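To see why per-batch averaging goes wrong, compare averaged batch F1 scores with F1 computed from pooled counts. The counts below are made up, but the discrepancy is general:

```python
import numpy as np

def f1(tp, fp, fn):
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    return 2 * p * r / (p + r) if p + r else 0.0

# Two batches with very different error profiles: (tp, fp, fn)
batches = [(9, 1, 0), (1, 0, 9)]

avg_of_batches = sum(f1(*b) for b in batches) / len(batches)
global_f1 = f1(*map(sum, zip(*batches)))   # pool counts first, then score
print(round(avg_of_batches, 3), round(global_f1, 3))  # 0.565 0.667
```

F1 is a nonlinear function of the counts, so the mean of per-batch scores is not the score of the pooled counts. That is exactly what the stateful metric above avoids.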
compile() Deep Dive
The compile() method has many arguments beyond the standard trio of loss, optimizer, and metrics. Understanding them lets you unlock multi-output training, eager debugging, and XLA compilation.
Arguments & Options
- loss — Loss function (string, function, or Loss instance). Can be a dict for multi-output models.
- optimizer — Optimizer (string or instance). Strings like 'adam' use default hyperparameters.
- metrics — List of metrics to track during training. Can be strings or Metric instances.
- loss_weights — Dict or list of floats to weight losses in multi-output models.
- run_eagerly — If True, disables tf.function tracing for easier debugging (slower).
- jit_compile — If True, uses XLA compilation for faster training (requires compatible ops).
Multi-Output Models
When a model has multiple output heads, you can assign different loss functions and weights to each.
import tensorflow as tf
import numpy as np
# Multi-output model: regression head + classification head
inputs = tf.keras.Input(shape=(10,))
shared = tf.keras.layers.Dense(64, activation='relu')(inputs)
# Output 1: regression (predict price)
price_output = tf.keras.layers.Dense(1, name='price')(shared)
# Output 2: classification (predict category)
category_output = tf.keras.layers.Dense(3, activation='softmax', name='category')(shared)
model = tf.keras.Model(inputs=inputs, outputs=[price_output, category_output])
# Compile with per-output loss and weights
model.compile(
optimizer='adam',
loss={
'price': 'mse',
'category': 'sparse_categorical_crossentropy'
},
loss_weights={
'price': 1.0, # equal weight
'category': 0.5 # classification loss weighted 50%
},
metrics={
'price': ['mae'],
'category': ['accuracy']
},
jit_compile=True # XLA compilation for speed
)
# Generate dummy data
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y_price = np.random.randn(500, 1).astype(np.float32)
y_category = np.random.randint(0, 3, size=(500,))
model.fit(X, {'price': y_price, 'category': y_category},
epochs=5, batch_size=32, verbose=0)
results = model.evaluate(X, {'price': y_price, 'category': y_category}, verbose=0)
print(f"Total loss: {results[0]:.4f}")
model.summary()
Batch Size & Gradient Accumulation
Batch size controls how many samples the model sees before each weight update. It's a fundamental tradeoff between computational efficiency, memory usage, and optimization dynamics.
Memory vs Convergence
Larger batches use more accelerator memory but produce smoother, lower-variance gradient estimates and keep the hardware better utilized; smaller batches inject gradient noise, which slows per-step progress but can act as a regularizer and sometimes improves generalization. A common recipe is to pick the largest batch that fits in memory and retune the learning rate, often scaling it roughly linearly with batch size.
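The optimization side of the tradeoff comes down to gradient noise: a mini-batch gradient is an average of per-example gradients, so its standard deviation shrinks like 1/√B. A quick NumPy check with hypothetical per-example gradient values:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-example gradient values: mean 1.0, std 2.0
per_example = rng.normal(loc=1.0, scale=2.0, size=100_000)

noise = {}
for B in (1, 16, 256):
    n = (len(per_example) // B) * B
    batch_means = per_example[:n].reshape(-1, B).mean(axis=1)
    noise[B] = batch_means.std()          # empirical gradient noise
    print(B, round(noise[B], 3))          # close to 2 / sqrt(B)
```

Going from batch 16 to 256 buys a 4x quieter gradient at 16x the memory, which is why the returns diminish quickly.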
Gradient Accumulation
Gradient accumulation simulates a large batch size without the memory cost — you accumulate gradients over multiple mini-batches before applying a single weight update. This is the standard technique when your desired batch size exceeds GPU memory.
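The trick works because, for a mean-reduced loss, summing the k micro-batch gradients each scaled by 1/k reproduces the full-batch gradient exactly. A NumPy check on a linear-regression gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true
theta = np.zeros(3)

def grad(xb, yb, theta):
    # gradient of the mean squared error 0.5 * mean((x @ theta - y)^2)
    return xb.T @ (xb @ theta - yb) / len(xb)

big = grad(X, y, theta)          # gradient of one full 128-sample batch

acc = np.zeros(3)                # 4 micro-batches of 32, each scaled by 1/4
for i in range(4):
    xb, yb = X[i*32:(i+1)*32], y[i*32:(i+1)*32]
    acc += grad(xb, yb, theta) / 4

print(np.allclose(big, acc))     # True: the updates are identical
```

This is why the training-step code below divides the loss by accumulation_steps before taking gradients.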
import tensorflow as tf
import numpy as np
# Gradient accumulation with a custom training step
class GradientAccumulationModel(tf.keras.Model):
"""Model that accumulates gradients over multiple mini-batches."""
def __init__(self, base_model, accumulation_steps=4):
super().__init__()
self.base_model = base_model
self.accumulation_steps = accumulation_steps
self.step_count = tf.Variable(0, trainable=False, dtype=tf.int32)
self.gradient_accumulator = None
def call(self, inputs, training=False):
return self.base_model(inputs, training=training)
def train_step(self, data):
x, y = data
# Initialize accumulator on first call
if self.gradient_accumulator is None:
self.gradient_accumulator = [
tf.Variable(tf.zeros_like(v), trainable=False)
for v in self.base_model.trainable_variables
]
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred)
# Scale loss by accumulation steps
scaled_loss = loss / self.accumulation_steps
grads = tape.gradient(scaled_loss, self.base_model.trainable_variables)
# Accumulate gradients
for acc, grad in zip(self.gradient_accumulator, grads):
if grad is not None:
acc.assign_add(grad)
        self.step_count.assign_add(1)
        # Apply when we've accumulated enough. Use tf.cond: a plain Python
        # `if` on a tensor condition fails once train_step is traced
        def apply_and_reset():
            self.optimizer.apply_gradients(
                zip(self.gradient_accumulator, self.base_model.trainable_variables)
            )
            for acc in self.gradient_accumulator:
                acc.assign(tf.zeros_like(acc))
            return tf.constant(0)
        tf.cond(
            tf.equal(self.step_count % self.accumulation_steps, 0),
            apply_and_reset,
            lambda: tf.constant(0)
        )
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
# Demo: effective batch size = 32 * 4 = 128, but only 32 in memory
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)
base = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1)
])
model = GradientAccumulationModel(base, accumulation_steps=4)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.fit(X, y, epochs=5, batch_size=32, verbose=0) # effective batch = 128
print(f"Final loss: {model.evaluate(X, y, verbose=0)[0]:.4f}")
Gradient Clipping
Gradient clipping prevents exploding gradients — a common problem in deep networks and RNNs where gradients grow exponentially through layers, causing weight updates that blow up the model.
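To see why gradients explode, the sketch below pushes a gradient through 50 identical linear layers. The weight matrix is a hypothetical construction, 1.5 times an orthogonal matrix, so each backward step scales the gradient norm by exactly 1.5:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
# 1.5 x an orthogonal matrix: every singular value is exactly 1.5
W = 1.5 * np.linalg.qr(rng.normal(size=(n, n)))[0]

g = np.ones(n)
for _ in range(50):              # backprop through 50 layers / timesteps
    g = W.T @ g                  # each step multiplies the gradient by W^T
exploded = np.linalg.norm(g)     # roughly 2.8 * 1.5^50, astronomically large

g = np.ones(n)
for _ in range(50):
    g = W.T @ g
    norm = np.linalg.norm(g)
    if norm > 1.0:               # clipnorm-style rescale after every step
        g = g / norm
clipped = np.linalg.norm(g)
print(exploded, clipped)         # huge vs. exactly 1.0
```

Shrink the 1.5 factor below 1 and the same loop demonstrates vanishing gradients instead.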
Clipping Methods
TensorFlow optimizers support three clipping strategies directly via constructor arguments:
flowchart LR
A[Computed Gradients] --> B{Clipping Method}
B -->|clipvalue| C["Clip each element\nto [-val, val]"]
B -->|clipnorm| D["Scale gradient if\n‖g‖ > norm"]
B -->|global_clipnorm| E["Scale ALL gradients\nby global norm"]
C --> F[Clipped Gradients]
D --> F
E --> F
F --> G[Apply to Weights]
style A fill:#132440,stroke:#3B9797,color:#fff
style B fill:#3B9797,stroke:#132440,color:#fff
style G fill:#132440,stroke:#3B9797,color:#fff
import tensorflow as tf
import numpy as np
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X @ np.random.randn(10, 1)).astype(np.float32)
# Method 1: clipvalue — clip each gradient element to [-0.5, 0.5]
model1 = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1)
])
model1.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.5),
loss='mse'
)
model1.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"clipvalue=0.5, loss: {model1.evaluate(X, y, verbose=0):.4f}")
# Method 2: clipnorm — scale gradient if its L2 norm exceeds threshold
model2 = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1)
])
model2.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
loss='mse'
)
model2.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"clipnorm=1.0, loss: {model2.evaluate(X, y, verbose=0):.4f}")
# Method 3: global_clipnorm — scale ALL gradients by global norm
model3 = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1)
])
model3.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0),
loss='mse'
)
model3.fit(X, y, epochs=5, batch_size=32, verbose=0)
print(f"global_clipnorm=1.0, loss: {model3.evaluate(X, y, verbose=0):.4f}")
Clipping for RNNs
Recurrent networks (LSTM, GRU) are particularly prone to exploding gradients because the same weight matrices are applied at every timestep. global_clipnorm is the recommended strategy — it preserves the relative direction of gradients across all parameters.
import tensorflow as tf
import numpy as np
# Simulated sequence data
np.random.seed(42)
X_seq = np.random.randn(200, 50, 10).astype(np.float32) # 200 samples, 50 timesteps, 10 features
y_seq = np.random.randint(0, 2, size=(200,)).astype(np.float32)
# LSTM with gradient clipping — prevents exploding gradients
model = tf.keras.Sequential([
tf.keras.layers.LSTM(64, input_shape=(50, 10), return_sequences=True),
tf.keras.layers.LSTM(32),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(
optimizer=tf.keras.optimizers.Adam(
learning_rate=1e-3,
global_clipnorm=1.0 # recommended for RNNs
),
loss='binary_crossentropy',
metrics=['accuracy']
)
history = model.fit(X_seq, y_seq, epochs=10, batch_size=32, verbose=0)
print(f"Final loss: {history.history['loss'][-1]:.4f}")
print(f"Final accuracy: {history.history['accuracy'][-1]:.4f}")
global_clipnorm=1.0 is standard practice. For simple feedforward networks with ReLU activations, clipping is usually unnecessary. If your loss suddenly jumps to NaN or infinity during training, try adding global_clipnorm=1.0.
Mixed Precision Training
Mixed precision training uses 16-bit floating point (float16) for forward/backward passes while keeping 32-bit master copies of weights. This can double training throughput on modern GPUs with Tensor Cores (NVIDIA V100, A100, RTX 30xx/40xx) while maintaining model accuracy.
Policy & Setup
Enabling mixed precision is a single global policy switch; the example below sets it, inspects the resulting dtypes, and resets it:
import tensorflow as tf
# Enable mixed precision globally
tf.keras.mixed_precision.set_global_policy('mixed_float16')
# Verify the current policy
policy = tf.keras.mixed_precision.global_policy()
print(f"Policy name: {policy.name}")
print(f"Compute dtype: {policy.compute_dtype}") # float16
print(f"Variable dtype: {policy.variable_dtype}") # float32 (master weights)
# Build a model — layers automatically use float16 compute
model = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation='relu', input_shape=(100,)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
# IMPORTANT: output layer must be float32 for numerical stability
tf.keras.layers.Dense(10, dtype='float32')
])
# Check dtypes
for layer in model.layers:
print(f"{layer.name}: compute={layer.dtype_policy.compute_dtype}, "
f"variable={layer.dtype_policy.variable_dtype}")
# Reset to float32 for subsequent examples
tf.keras.mixed_precision.set_global_policy('float32')
Loss Scaling
Float16 has a smaller dynamic range than float32, which can cause small gradient values to underflow to zero. Loss scaling multiplies the loss by a large factor before backpropagation, then divides the gradients by the same factor — this keeps small gradients in float16's representable range.
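A NumPy illustration of the underflow problem and the scale-then-unscale trick. The 1024 factor here is just an example; Keras's loss-scale optimizer chooses and adjusts the factor dynamically:

```python
import numpy as np

g = np.float32(1e-8)            # a tiny but meaningful gradient value

# float16's smallest subnormal is about 6e-8, so this underflows to zero
print(np.float16(g))            # 0.0

scale = np.float32(1024.0)      # example loss scale factor
scaled = np.float16(g * scale)  # about 1.02e-5: survives in float16
recovered = np.float32(scaled) / scale
print(recovered)                # back to about 1e-8 after unscaling in float32
```

Scaling the loss scales every gradient by the chain rule, so one multiply protects the whole backward pass.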
import tensorflow as tf
import numpy as np
# Enable mixed precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')
# Build model with float32 output
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(20,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, dtype='float32') # keep output in float32
])
# The optimizer automatically handles loss scaling in TF2
# LossScaleOptimizer wraps the base optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
# Train with mixed precision
np.random.seed(42)
X = np.random.randn(1000, 20).astype(np.float32)
y = np.random.randn(1000, 1).astype(np.float32)
history = model.fit(X, y, epochs=10, batch_size=64, verbose=0)
print(f"Mixed precision final loss: {history.history['loss'][-1]:.4f}")
print(f"Mixed precision final MAE: {history.history['mae'][-1]:.4f}")
# Reset policy
tf.keras.mixed_precision.set_global_policy('float32')
Always keep the final layer in float32 by passing dtype='float32'. On GPUs with Tensor Cores (V100+), expect 1.5–3× speedup. On CPUs, mixed precision offers no benefit. Layer dimensions should be multiples of 8 for optimal Tensor Core utilization.
Complete Example
Let's combine everything — an end-to-end training pipeline using a cosine learning rate schedule, the Adam optimizer with gradient clipping, multiple metrics, and mixed precision.
End-to-End Training Pipeline
The example below wires all of these pieces into a single training run:
import tensorflow as tf
import numpy as np
# ── 1. Enable mixed precision ──────────────────────────────
tf.keras.mixed_precision.set_global_policy('mixed_float16')
# ── 2. Generate synthetic multi-class data ──────────────────
np.random.seed(42)
num_classes = 10
X_train = np.random.randn(5000, 50).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(5000,))
X_val = np.random.randn(1000, 50).astype(np.float32)
y_val = np.random.randint(0, num_classes, size=(1000,))
# ── 3. Learning rate schedule: warmup + cosine decay ────────
total_steps = (5000 // 64) * 30   # 78 steps/epoch x 30 epochs = 2340 steps
warmup_steps = total_steps // 10  # 10% linear warmup
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=total_steps - warmup_steps,
    alpha=1e-5,
    warmup_target=1e-3,           # linear warmup from 0 to 1e-3 (TF >= 2.13)
    warmup_steps=warmup_steps
)
# ── 4. Optimizer with gradient clipping ─────────────────────
optimizer = tf.keras.optimizers.AdamW(
learning_rate=lr_schedule,
weight_decay=1e-4,
global_clipnorm=1.0 # prevent exploding gradients
)
# ── 5. Build model ─────────────────────────────────────────
model = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation='relu', input_shape=(50,)),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(64, activation='relu'),
# Output layer in float32 for mixed precision stability
tf.keras.layers.Dense(num_classes, dtype='float32')
])
# ── 6. Compile with multiple metrics ───────────────────────
model.compile(
optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top3_accuracy'),
],
jit_compile=True # XLA compilation for extra speed
)
# ── 7. Callbacks ────────────────────────────────────────────
# Note: ReduceLROnPlateau is omitted here because it cannot modify a
# learning rate that is driven by a schedule like CosineDecay
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, restore_best_weights=True
    ),
]
# ── 8. Train ────────────────────────────────────────────────
history = model.fit(
X_train, y_train,
validation_data=(X_val, y_val),
epochs=30,
batch_size=64,
callbacks=callbacks,
verbose=1
)
# ── 9. Evaluate ─────────────────────────────────────────────
results = model.evaluate(X_val, y_val, verbose=0)
for name, value in zip(model.metrics_names, results):
print(f"{name:>15s}: {value:.4f}")
# ── 10. Predict ─────────────────────────────────────────────
predictions = model.predict(X_val[:5], verbose=0)
predicted_classes = np.argmax(predictions, axis=1)
print(f"\nPredicted classes: {predicted_classes}")
print(f"True classes: {y_val[:5]}")
# Reset precision policy
tf.keras.mixed_precision.set_global_policy('float32')
Summary & Next Steps
In this article we covered the complete TensorFlow training pipeline:
- Loss functions — MSE, MAE, Huber for regression; BCE, CCE, SparseCCE for classification; custom losses for specialized problems
- Optimizers — SGD (with momentum/Nesterov), Adam, AdamW, RMSProp, Adagrad — and when to choose each
- Learning rate schedules — ExponentialDecay, CosineDecay, PiecewiseConstantDecay, warmup strategies, ReduceLROnPlateau
- Metrics — built-in metrics (Accuracy, AUC, Precision, Recall) and custom stateful metrics (F1 score)
- compile() options — loss_weights for multi-output, run_eagerly for debugging, jit_compile for XLA speed
- Batch size — memory/convergence tradeoffs and gradient accumulation for large effective batch sizes
- Gradient clipping — clipvalue, clipnorm, global_clipnorm for preventing exploding gradients
- Mixed precision — float16 compute with float32 master weights for up to 3× speedup on Tensor Core GPUs
Next in the Series
In Part 4: Data Pipelines with tf.data, we'll build efficient input pipelines — from raw files to batched, shuffled, prefetched datasets that keep the GPU fully utilized during training.