model.fit() Deep Dive
The model.fit() method is the standard Keras training API. While most tutorials only show epochs and batch_size, the method accepts more than a dozen parameters for controlling validation, class imbalance, shuffling, verbosity, and more. Understanding these options lets you get the most from high-level training before resorting to custom loops.
model.fit() returns a History object containing all metrics recorded during training. This object is your primary tool for diagnosing overfitting, underfitting, and learning rate issues — always capture it: history = model.fit(...).
import tensorflow as tf
import numpy as np
# Generate synthetic classification data
np.random.seed(42)
X_train = np.random.randn(2000, 20).astype(np.float32)
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(np.float32)
X_val = np.random.randn(500, 20).astype(np.float32)
y_val = (X_val[:, 0] + X_val[:, 1] > 0).astype(np.float32)
# Build a simple model
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# model.fit() with all key parameters
history = model.fit(
X_train, y_train,
epochs=20, # Number of full passes over training data
batch_size=64, # Samples per gradient update
validation_data=(X_val, y_val), # Explicit validation set (preferred)
# validation_split=0.2, # Alternative: split from training data
class_weight={0: 1.0, 1: 2.0}, # Upweight class 1 (e.g. a minority class; this synthetic data is roughly balanced)
# sample_weight=np.ones(2000), # Per-sample weights (alternative)
verbose=1, # 0=silent, 1=progress bar, 2=one line/epoch
shuffle=True, # Shuffle training data each epoch
initial_epoch=0, # Resume from a specific epoch number
)
# History object contains all recorded metrics
print(f"Keys: {list(history.history.keys())}")
print(f"Final train accuracy: {history.history['accuracy'][-1]:.4f}")
print(f"Final val accuracy: {history.history['val_accuracy'][-1]:.4f}")
print(f"Total epochs trained: {len(history.history['loss'])}")
Working with the History Object
The History object is invaluable for plotting training curves and detecting overfitting. When validation loss diverges from training loss, your model is overfitting — a signal to apply regularization, data augmentation, or early stopping.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Quick training for demo
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(X, y, epochs=30, validation_split=0.2, verbose=0)
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(history.history['loss'], label='Train Loss')
axes[0].plot(history.history['val_loss'], label='Val Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss Curves')
axes[0].legend()
axes[1].plot(history.history['accuracy'], label='Train Accuracy')
axes[1].plot(history.history['val_accuracy'], label='Val Accuracy')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Accuracy Curves')
axes[1].legend()
plt.tight_layout()
plt.show()
The gap between the training and validation curves is the key diagnostic: if both converge to a good value, your model generalizes well; if training keeps improving while validation plateaus or worsens, you're overfitting.
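You can also quantify that divergence numerically instead of eyeballing the plot. This snippet assumes the history object from the plotting example above; the 0.01 margin is an illustrative threshold, not a standard value.
import numpy as np
train_loss = np.array(history.history['loss'])
val_loss = np.array(history.history['val_loss'])
best_epoch = int(np.argmin(val_loss))
print(f"Best val_loss {val_loss[best_epoch]:.4f} at epoch {best_epoch + 1}")
print(f"Final train/val loss gap: {val_loss[-1] - train_loss[-1]:.4f}")
# Heuristic check: validation loss drifting above its best value suggests overfitting
if val_loss[-1] > val_loss[best_epoch] + 0.01:
    print("Validation loss has moved past its best value — consider early stopping")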
Custom Training Loops
While model.fit() covers most scenarios, custom training loops with tf.GradientTape give you complete control over the training process. You need them for GANs (alternating generator/discriminator updates), reinforcement learning (policy gradient steps), multi-loss objectives, or any scenario where the standard fit workflow doesn't apply.
flowchart LR
A[Load Batch] --> B[Forward Pass]
B --> C[Compute Loss]
C --> D[Compute Gradients]
D --> E[Update Weights]
E --> F{More Batches?}
F -->|Yes| A
F -->|No| G[Validate]
G --> H{More Epochs?}
H -->|Yes| A
H -->|No| I[Done]
import tensorflow as tf
import numpy as np
# Generate data
np.random.seed(42)
X_train = np.random.randn(1000, 16).astype(np.float32)
y_train = (X_train[:, 0] * 2 + X_train[:, 1] + np.random.randn(1000) * 0.1).astype(np.float32)
# Create tf.data pipeline
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_ds = train_ds.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
# Build model
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(16,)),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1)
])
# Custom training loop with GradientTape
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
# Training loop
for epoch in range(5):
epoch_loss = 0.0
num_batches = 0
for x_batch, y_batch in train_ds:
with tf.GradientTape() as tape:
# Forward pass inside tape context
predictions = model(x_batch, training=True)
loss = loss_fn(y_batch, predictions)
# Compute gradients
gradients = tape.gradient(loss, model.trainable_variables)
# Apply gradients
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
epoch_loss += loss.numpy()
num_batches += 1
avg_loss = epoch_loss / num_batches
print(f"Epoch {epoch + 1}/5, Loss: {avg_loss:.4f}")
Structuring train_step and test_step
For cleaner code, encapsulate the forward pass, loss computation, and gradient update into dedicated functions. Using @tf.function compiles these into a graph for faster execution:
import tensorflow as tf
import numpy as np
# Generate data
np.random.seed(42)
X_train = np.random.randn(800, 10).astype(np.float32)
y_train = (X_train[:, :5].sum(axis=1) > 0).astype(np.float32)
X_test = np.random.randn(200, 10).astype(np.float32)
y_test = (X_test[:, :5].sum(axis=1) > 0).astype(np.float32)
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
train_acc = tf.keras.metrics.BinaryAccuracy()
test_acc = tf.keras.metrics.BinaryAccuracy()
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
preds = model(x, training=True)
loss = loss_fn(y, preds)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
train_acc.update_state(y, preds)
return loss
@tf.function
def test_step(x, y):
preds = model(x, training=False)
loss = loss_fn(y, preds)
test_acc.update_state(y, preds)
return loss
# Full training loop with train/test separation
for epoch in range(10):
train_acc.reset_state()
test_acc.reset_state()
for x_batch, y_batch in train_ds:
train_step(x_batch, y_batch)
for x_batch, y_batch in test_ds:
test_step(x_batch, y_batch)
print(f"Epoch {epoch+1}: Train Acc={train_acc.result():.4f}, "
f"Test Acc={test_acc.result():.4f}")
Overriding train_step()
The middle ground between model.fit() and a fully custom loop is overriding train_step() in a Model subclass. This gives you custom training logic while retaining all the benefits of fit — callbacks, progress bars, and validation handling.
import tensorflow as tf
import numpy as np
class CustomModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(64, activation='relu')
self.dense2 = tf.keras.layers.Dense(32, activation='relu')
self.output_layer = tf.keras.layers.Dense(1, activation='sigmoid')
def call(self, inputs, training=False):
x = self.dense1(inputs)
x = self.dense2(x)
return self.output_layer(x)
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
# Custom loss: binary crossentropy + L2 regularization
bce_loss = self.compute_loss(y=y, y_pred=y_pred)
l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.trainable_variables])
total_loss = bce_loss + 1e-4 * l2_loss
gradients = tape.gradient(total_loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
# Update compiled metrics
for metric in self.metrics:
if metric.name == 'loss':
metric.update_state(total_loss)
else:
metric.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
# Usage — still works with model.fit()!
np.random.seed(42)
X = np.random.randn(1000, 20).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
custom_model = CustomModel()
custom_model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy']
)
history = custom_model.fit(X, y, epochs=10, batch_size=32,
validation_split=0.2, verbose=1)
print(f"Final val accuracy: {history.history['val_accuracy'][-1]:.4f}")
Understanding Compiled Metrics
When you override train_step(), you're responsible for updating metrics. The self.metrics property includes all metrics passed to compile(), plus a built-in loss tracker. Always update them inside your custom step for accurate reporting.
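The same responsibility extends to evaluation: validation inside fit() and calls to evaluate() go through test_step(), so you will usually override it with the same metric-update pattern. Below is a minimal, hedged sketch — the EvalAwareModel name and its single layer are illustrative stand-ins, not part of the example above.
import tensorflow as tf
class EvalAwareModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1, activation='sigmoid')
    def call(self, inputs, training=False):
        return self.dense(inputs)
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the built-in loss tracker and every metric passed to compile()
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}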
Built-in Callbacks
Callbacks are objects that execute at various stages of training — at the start/end of epochs, batches, or the entire training run. They let you monitor progress, save checkpoints, adjust the learning rate, and stop training early without modifying the training loop. TensorFlow provides a rich set of built-in callbacks.
The EarlyStopping condition can be expressed mathematically. Training stops when:
$$\text{val\_loss}_{t} > \text{val\_loss}_{\text{best}} - \delta \quad \text{for } p \text{ consecutive epochs}$$
where $\delta$ is the min_delta threshold and $p$ is the patience parameter.
EarlyStopping
The following self-contained example shows EarlyStopping in action and can be run independently:
import tensorflow as tf
import numpy as np
# Generate data with some noise (will overfit)
np.random.seed(42)
X_train = np.random.randn(2000, 15).astype(np.float32)
y_train = (X_train[:, 0] + 0.5 * X_train[:, 1] > 0).astype(np.float32)
X_val = np.random.randn(500, 15).astype(np.float32)
y_val = (X_val[:, 0] + 0.5 * X_val[:, 1] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(15,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# EarlyStopping: Stop when val_loss doesn't improve for 5 epochs
early_stop = tf.keras.callbacks.EarlyStopping(
monitor='val_loss', # Metric to watch
patience=5, # Epochs with no improvement before stopping
min_delta=1e-4, # Minimum change to qualify as improvement
restore_best_weights=True, # Roll back to best epoch's weights
verbose=1, # Print when stopping
mode='min' # 'min' for loss, 'max' for accuracy
)
history = model.fit(
X_train, y_train,
epochs=100, # Set high — EarlyStopping will terminate
validation_data=(X_val, y_val),
callbacks=[early_stop],
verbose=0
)
print(f"Stopped at epoch: {len(history.history['loss'])}")
print(f"Best val_loss: {min(history.history['val_loss']):.4f}")
ModelCheckpoint & ReduceLROnPlateau
The following self-contained example combines ModelCheckpoint and ReduceLROnPlateau with CSVLogger and a LearningRateScheduler, and can be run independently:
import tensorflow as tf
import numpy as np
import os
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X[:, 0] + X[:, 1] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
loss='binary_crossentropy', metrics=['accuracy'])
# ModelCheckpoint: Save best model based on val_loss
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
filepath='best_model.keras', # Save path (.keras recommended)
monitor='val_loss',
save_best_only=True, # Only save when val_loss improves
save_weights_only=False, # Save full model (architecture + weights)
verbose=1
)
# ReduceLROnPlateau: Reduce learning rate when metric plateaus
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5, # Multiply LR by this factor
patience=3, # Epochs before reduction
min_lr=1e-6, # Lower bound on learning rate
verbose=1
)
# CSVLogger: Save training metrics to CSV
csv_logger = tf.keras.callbacks.CSVLogger('training_log.csv', append=False)
# LearningRateScheduler: Custom LR schedule
def lr_schedule(epoch, lr):
"""Decay LR by 10% every 10 epochs."""
if epoch > 0 and epoch % 10 == 0:
return lr * 0.9
return lr
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=0)
# Combine callbacks (lr_scheduler above is shown for reference but not included in this run)
history = model.fit(
X, y, epochs=30, validation_split=0.2,
callbacks=[checkpoint_cb, reduce_lr, csv_logger],
verbose=0
)
print(f"Final LR: {model.optimizer.learning_rate.numpy():.6f}")
print(f"Log saved to: training_log.csv")
# Clean up demo files
for f in ['best_model.keras', 'training_log.csv']:
if os.path.exists(f):
os.remove(f)
Callback Combinations
In practice, you'll use multiple callbacks together. A common production combination is EarlyStopping + ModelCheckpoint + ReduceLROnPlateau + TensorBoard. Callbacks execute in the order they appear in the callbacks list.
Custom Callbacks
When built-in callbacks don't meet your needs, subclass tf.keras.callbacks.Callback and override the hooks you need. The API provides hooks at every stage of training: on_train_begin/end, on_epoch_begin/end, on_batch_begin/end, and their test/predict counterparts.
flowchart TD
A[on_train_begin] --> B[on_epoch_begin]
B --> C[on_batch_begin]
C --> D[Train Step]
D --> E[on_batch_end]
E --> F{More Batches?}
F -->|Yes| C
F -->|No| G[on_epoch_end]
G --> H{More Epochs?}
H -->|Yes| B
H -->|No| I[on_train_end]
Practical Custom Callback Examples
The following self-contained example implements two practical custom callbacks — one that times each epoch and one that tracks weight norms as a rough proxy for gradient health — and can be run independently:
import tensorflow as tf
import numpy as np
import time
class TimingCallback(tf.keras.callbacks.Callback):
"""Tracks time per epoch and total training time."""
def on_train_begin(self, logs=None):
self.train_start = time.time()
self.epoch_times = []
def on_epoch_begin(self, epoch, logs=None):
self.epoch_start = time.time()
def on_epoch_end(self, epoch, logs=None):
elapsed = time.time() - self.epoch_start
self.epoch_times.append(elapsed)
print(f" ⏱ Epoch {epoch + 1}: {elapsed:.2f}s")
def on_train_end(self, logs=None):
total = time.time() - self.train_start
avg = np.mean(self.epoch_times)
print(f"\n✓ Training complete: {total:.2f}s total, {avg:.2f}s/epoch avg")
class GradientNormCallback(tf.keras.callbacks.Callback):
    """Tracks the global L2 norm of the kernel weights as a lightweight proxy
    for spotting exploding updates. Real gradient monitoring needs the gradients
    themselves — a custom loop with GradientTape or tf.debugging tools."""
    def on_train_begin(self, logs=None):
        self.weight_norms = []
    def on_batch_end(self, batch, logs=None):
        if batch % 10 == 0:  # Sample every 10 batches
            squared_sum = sum(
                tf.reduce_sum(var ** 2).numpy()
                for var in self.model.trainable_variables
                if 'kernel' in var.name  # weight matrices only
            )
            self.weight_norms.append(float(np.sqrt(squared_sum)))
# Usage
np.random.seed(42)
X = np.random.randn(500, 8).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(8,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(X, y, epochs=5, batch_size=32,
callbacks=[TimingCallback()], verbose=0)
Custom callbacks are particularly useful for integration with external logging systems (Weights & Biases, MLflow), sending notifications on training milestones, implementing custom early stopping logic, or dynamically modifying training behavior.
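As one hedged example of dynamically modifying training behavior, the sketch below stops training as soon as a target training accuracy is reached by setting self.model.stop_training — the same flag the built-in EarlyStopping callback uses. The StopAtAccuracy name and the 0.9 target are illustrative choices.
import tensorflow as tf
import numpy as np
class StopAtAccuracy(tf.keras.callbacks.Callback):
    """Stops training once training accuracy reaches a target value."""
    def __init__(self, target=0.95):
        super().__init__()
        self.target = target
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get('accuracy', 0.0) >= self.target:
            print(f"\nReached {self.target:.0%} training accuracy — stopping early.")
            self.model.stop_training = True
np.random.seed(42)
X = np.random.randn(500, 8).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=50, batch_size=32, callbacks=[StopAtAccuracy(0.9)], verbose=0)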
TensorBoard Integration
TensorBoard is TensorFlow's built-in visualization toolkit for monitoring training runs. It provides dashboards for scalars (loss, accuracy), histograms (weight distributions), images, graphs (model architecture), and profiling data — all updated in real time during training.
import tensorflow as tf
import numpy as np
import datetime
import os
np.random.seed(42)
X_train = np.random.randn(1000, 20).astype(np.float32)
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Dense(32, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# TensorBoard callback — logs to timestamped directory
log_dir = os.path.join("logs", "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_cb = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1, # Log weight histograms every epoch
write_graph=True, # Log the computation graph
write_images=False, # Log model weights as images
update_freq='epoch', # 'batch' or 'epoch' or integer
profile_batch='2,5' # Profile batches 2-5 for performance analysis
)
history = model.fit(
X_train, y_train,
epochs=10, validation_split=0.2,
callbacks=[tensorboard_cb],
verbose=0
)
print(f"TensorBoard logs saved to: {log_dir}")
print("Launch with: tensorboard --logdir logs/fit")
print(f"Final val_accuracy: {history.history['val_accuracy'][-1]:.4f}")
Custom Summaries with tf.summary
For custom training loops, use tf.summary to log arbitrary scalars, images, histograms, and text directly to TensorBoard:
import tensorflow as tf
import numpy as np
import datetime
import os
# Custom logging with tf.summary in a training loop
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
# Create summary writer
log_dir = os.path.join("logs", "custom", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
summary_writer = tf.summary.create_file_writer(log_dir)
global_step = 0
for epoch in range(5):
epoch_loss = []
for x_batch, y_batch in dataset:
with tf.GradientTape() as tape:
preds = model(x_batch, training=True)
loss = loss_fn(y_batch, preds)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# Log per-step metrics
with summary_writer.as_default(step=global_step):
tf.summary.scalar('batch_loss', loss)
# Log gradient norms
grad_norm = tf.sqrt(sum(tf.reduce_sum(g**2) for g in grads if g is not None))
tf.summary.scalar('gradient_norm', grad_norm)
epoch_loss.append(loss.numpy())
global_step += 1
# Log per-epoch metrics
with summary_writer.as_default(step=epoch):
tf.summary.scalar('epoch_loss', np.mean(epoch_loss))
# Log weight histograms
for var in model.trainable_variables:
tf.summary.histogram(var.name, var)
print(f"Epoch {epoch+1}: loss={np.mean(epoch_loss):.4f}")
print(f"Custom logs: {log_dir}")
Launch TensorBoard by running tensorboard --logdir logs/ in your terminal, then open http://localhost:6006 in your browser. For Colab, use %load_ext tensorboard followed by %tensorboard --logdir logs/.
Model Saving & Loading
TensorFlow supports three main ways to persist models: the native Keras .keras format (recommended for saving Keras models), the SavedModel format, and the legacy HDF5 format. SavedModel stores the full computation graph, variables, and signatures — ideal for deployment via TensorFlow Serving — while a .keras file bundles the architecture, weights, and compile configuration in a single file.
import tensorflow as tf
import numpy as np
import os
np.random.seed(42)
X = np.random.randn(200, 10).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=5, verbose=0)
# Option 1: Save as .keras (recommended for Keras models)
model.save('my_model.keras')
loaded_model = tf.keras.models.load_model('my_model.keras')
# Option 2: SavedModel format (recommended for deployment/serving)
tf.saved_model.save(model, 'saved_model_dir')
loaded_sm = tf.keras.models.load_model('saved_model_dir')
# Option 3: Save weights only (smaller file, need architecture code)
model.save_weights('model_weights.weights.h5')
# To reload: model.load_weights('model_weights.weights.h5')
# Verify loaded model produces same output
test_input = np.random.randn(5, 10).astype(np.float32)
original_output = model.predict(test_input, verbose=0)
loaded_output = loaded_model.predict(test_input, verbose=0)
print(f"Outputs match: {np.allclose(original_output, loaded_output, atol=1e-6)}")
print(f".keras file size: {os.path.getsize('my_model.keras') / 1024:.1f} KB")
# Clean up
import shutil
os.remove('my_model.keras')
os.remove('model_weights.weights.h5')
if os.path.exists('saved_model_dir'):
shutil.rmtree('saved_model_dir')
Handling Custom Objects
When your model uses custom layers, losses, or metrics, you must register them for serialization. Use the @tf.keras.utils.register_keras_serializable decorator:
import tensorflow as tf
import numpy as np
import os
@tf.keras.utils.register_keras_serializable(package='MyPackage')
class SwishActivation(tf.keras.layers.Layer):
"""Custom Swish activation layer."""
def call(self, inputs):
return inputs * tf.nn.sigmoid(inputs)
def get_config(self):
return super().get_config()
@tf.keras.utils.register_keras_serializable(package='MyPackage')
def custom_loss(y_true, y_pred):
"""Custom focal-like loss."""
bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
return bce * (1 - pt) ** 2
# Build model with custom components
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, input_shape=(10,)),
SwishActivation(),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
# Train briefly
np.random.seed(42)
X = np.random.randn(100, 10).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
model.fit(X, y, epochs=3, verbose=0)
# Save and reload — custom objects automatically resolved
model.save('custom_model.keras')
loaded = tf.keras.models.load_model('custom_model.keras')
preds = loaded.predict(X[:3], verbose=0)
print(f"Predictions from loaded model: {preds.flatten()}")
os.remove('custom_model.keras')
Checkpointing Strategies
For custom training loops (where ModelCheckpoint callback isn't available), TensorFlow provides tf.train.Checkpoint and tf.train.CheckpointManager. These give you fine-grained control over what gets saved, when, and how many checkpoints to retain.
import tensorflow as tf
import numpy as np
import os
np.random.seed(42)
X = np.random.randn(500, 10).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(32)
# Build model and optimizer
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
# Create Checkpoint — tracks model + optimizer state
checkpoint = tf.train.Checkpoint(
model=model,
optimizer=optimizer,
epoch=tf.Variable(0)
)
# CheckpointManager — manages retention policy
checkpoint_dir = './checkpoints'
manager = tf.train.CheckpointManager(
checkpoint,
directory=checkpoint_dir,
max_to_keep=3, # Only keep last 3 checkpoints
checkpoint_name='ckpt'
)
# Restore from latest checkpoint (if exists)
if manager.latest_checkpoint:
checkpoint.restore(manager.latest_checkpoint)
print(f"Restored from {manager.latest_checkpoint}")
else:
print("Starting fresh training")
# Training with periodic checkpointing
for epoch in range(10):
checkpoint.epoch.assign_add(1)
epoch_loss = []
for x_batch, y_batch in dataset:
with tf.GradientTape() as tape:
preds = model(x_batch, training=True)
loss = loss_fn(y_batch, preds)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
epoch_loss.append(loss.numpy())
avg_loss = np.mean(epoch_loss)
# Save checkpoint every 3 epochs
if (epoch + 1) % 3 == 0:
save_path = manager.save()
print(f"Epoch {epoch+1}: loss={avg_loss:.4f} | Saved: {save_path}")
else:
print(f"Epoch {epoch+1}: loss={avg_loss:.4f}")
print(f"\nAll checkpoints: {manager.checkpoints}")
# Clean up
import shutil
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
Restoring from Checkpoints
The power of tf.train.Checkpoint is that it saves both model weights and optimizer state (momentum, adaptive learning rates). This means training resumes exactly where it left off — no warm-up period needed after a restart.
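Here is a hedged sketch of what that restart looks like, assuming the same ./checkpoints directory and Checkpoint structure (model, optimizer, epoch counter) as the example above:
import tensorflow as tf
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
optimizer = tf.keras.optimizers.Adam(1e-3)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=tf.Variable(0))
manager = tf.train.CheckpointManager(checkpoint, directory='./checkpoints', max_to_keep=3)
if manager.latest_checkpoint:
    # expect_partial() silences warnings about saved values not yet consumed
    checkpoint.restore(manager.latest_checkpoint).expect_partial()
    start_epoch = int(checkpoint.epoch.numpy())
    print(f"Resuming from epoch {start_epoch + 1} with optimizer state intact")
else:
    start_epoch = 0
    print("No checkpoint found — starting from epoch 1")
# Continue the training loop from start_epoch instead of 0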
Custom Metrics
TensorFlow's built-in metrics cover common cases, but many tasks require custom metrics. Subclass tf.keras.metrics.Metric and implement three methods: update_state() (accumulate batch statistics), result() (compute the final metric), and reset_state() (clear between epochs).
F1 Score Implementation
The following self-contained example implements a streaming F1 score metric for binary classification and can be run independently:
import tensorflow as tf
import numpy as np
class F1Score(tf.keras.metrics.Metric):
"""Streaming F1 Score metric for binary classification."""
def __init__(self, threshold=0.5, name='f1_score', **kwargs):
super().__init__(name=name, **kwargs)
self.threshold = threshold
self.true_positives = self.add_weight(name='tp', initializer='zeros')
self.false_positives = self.add_weight(name='fp', initializer='zeros')
self.false_negatives = self.add_weight(name='fn', initializer='zeros')
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred_binary = tf.cast(y_pred >= self.threshold, tf.float32)
y_true_f = tf.cast(y_true, tf.float32)
tp = tf.reduce_sum(y_true_f * y_pred_binary)
fp = tf.reduce_sum((1 - y_true_f) * y_pred_binary)
fn = tf.reduce_sum(y_true_f * (1 - y_pred_binary))
self.true_positives.assign_add(tp)
self.false_positives.assign_add(fp)
self.false_negatives.assign_add(fn)
def result(self):
precision = self.true_positives / (self.true_positives + self.false_positives + 1e-7)
recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-7)
f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
return f1
def reset_state(self):
self.true_positives.assign(0)
self.false_positives.assign(0)
self.false_negatives.assign(0)
# Usage with model.compile()
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X[:, 0] + X[:, 1] > 0.5).astype(np.float32) # Imbalanced (~35% positive)
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', F1Score(threshold=0.5)]
)
history = model.fit(X, y, epochs=10, validation_split=0.2, verbose=0)
print(f"Final F1 Score (train): {history.history['f1_score'][-1]:.4f}")
print(f"Final F1 Score (val): {history.history['val_f1_score'][-1]:.4f}")
print(f"Final Accuracy (val): {history.history['val_accuracy'][-1]:.4f}")
Custom metrics are essential for tasks where accuracy is misleading. For imbalanced datasets, F1 score, AUC-PR, or Matthews Correlation Coefficient provide more meaningful signals than raw accuracy.
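Before writing a metric from scratch, note that the built-in AUC metric already supports the precision-recall curve. The sketch below is a hedged illustration: the roughly 16% positive rate simply comes from the arbitrary threshold used to generate this synthetic data.
import tensorflow as tf
import numpy as np
np.random.seed(0)
X = np.random.randn(1000, 10).astype(np.float32)
y = (X[:, 0] > 1.0).astype(np.float32)  # roughly 16% positives — imbalanced
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=[tf.keras.metrics.AUC(curve='PR', name='auc_pr')])
model.fit(X, y, epochs=5, verbose=0)
loss, auc_pr = model.evaluate(X, y, verbose=0)
print(f"AUC-PR: {auc_pr:.4f}")
The next example builds a streaming MAPE metric for regression, where percentage error is more interpretable than raw MSE: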
import tensorflow as tf
import numpy as np
class MeanAbsolutePercentageError(tf.keras.metrics.Metric):
"""Streaming MAPE metric for regression."""
def __init__(self, name='mape', **kwargs):
super().__init__(name=name, **kwargs)
self.total_error = self.add_weight(name='total_error', initializer='zeros')
self.count = self.add_weight(name='count', initializer='zeros')
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.cast(y_pred, tf.float32)
# Avoid division by zero
mask = tf.not_equal(y_true, 0)
safe_true = tf.boolean_mask(y_true, mask)
safe_pred = tf.boolean_mask(y_pred, mask)
ape = tf.abs((safe_true - safe_pred) / safe_true) * 100
self.total_error.assign_add(tf.reduce_sum(ape))
self.count.assign_add(tf.cast(tf.size(ape), tf.float32))
def result(self):
return self.total_error / (self.count + 1e-7)
def reset_state(self):
self.total_error.assign(0)
self.count.assign(0)
# Regression example
np.random.seed(42)
X = np.random.randn(500, 5).astype(np.float32)
y = (3 * X[:, 0] + 2 * X[:, 1] + 10).astype(np.float32) # Linear target
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, activation='relu', input_shape=(5,)),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse',
metrics=[MeanAbsolutePercentageError()])
history = model.fit(X, y, epochs=20, validation_split=0.2, verbose=0)
print(f"Final MAPE (val): {history.history['val_mape'][-1]:.2f}%")
Production Training Template
Here's a complete production-ready training script that combines all best practices covered in this article: data pipeline, model definition, callbacks, logging, checkpointing, and model export. This template is suitable for real-world training jobs and serves as a starting point for any TensorFlow project.
import tensorflow as tf
import numpy as np
import os
import datetime
import json
# ============================================================
# PRODUCTION TRAINING TEMPLATE
# Combines: tf.data pipeline + callbacks + TensorBoard +
# checkpointing + model export + custom metrics
# ============================================================
# --- Configuration ---
CONFIG = {
'batch_size': 64,
'epochs': 50,
'learning_rate': 1e-3,
'patience': 7,
'num_features': 20,
'hidden_units': [128, 64, 32],
'dropout_rate': 0.3,
'checkpoint_dir': './training_output/checkpoints',
'log_dir': './training_output/logs',
'export_dir': './training_output/saved_model',
}
# --- Data Pipeline ---
np.random.seed(42)
X_train = np.random.randn(5000, CONFIG['num_features']).astype(np.float32)
y_train = (X_train[:, 0] + 0.5 * X_train[:, 1] - X_train[:, 2] > 0).astype(np.float32)
X_val = np.random.randn(1000, CONFIG['num_features']).astype(np.float32)
y_val = (X_val[:, 0] + 0.5 * X_val[:, 1] - X_val[:, 2] > 0).astype(np.float32)
train_ds = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
.shuffle(5000)
.batch(CONFIG['batch_size'])
.prefetch(tf.data.AUTOTUNE))
val_ds = (tf.data.Dataset.from_tensor_slices((X_val, y_val))
.batch(CONFIG['batch_size'])
.prefetch(tf.data.AUTOTUNE))
# --- Model Definition ---
def build_model(config):
inputs = tf.keras.Input(shape=(config['num_features'],))
x = inputs
for units in config['hidden_units']:
x = tf.keras.layers.Dense(units, activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(config['dropout_rate'])(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
return tf.keras.Model(inputs, outputs, name='production_model')
model = build_model(CONFIG)
model.compile(
optimizer=tf.keras.optimizers.Adam(CONFIG['learning_rate']),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
# --- Callbacks Stack ---
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(CONFIG['log_dir'], exist_ok=True)
callbacks = [
tf.keras.callbacks.EarlyStopping(
monitor='val_auc', patience=CONFIG['patience'],
mode='max', restore_best_weights=True, verbose=1
),
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(CONFIG['checkpoint_dir'], 'best.keras'),
monitor='val_auc', mode='max',
save_best_only=True, verbose=1
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss', factor=0.5,
patience=3, min_lr=1e-6, verbose=1
),
tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(CONFIG['log_dir'],
datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
histogram_freq=1
),
tf.keras.callbacks.CSVLogger(
os.path.join(CONFIG['log_dir'], 'training_history.csv')
),
]
# --- Training ---
print("=" * 60)
print("STARTING TRAINING")
print(f"Config: {json.dumps(CONFIG, indent=2)}")
print("=" * 60)
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=CONFIG['epochs'],
callbacks=callbacks,
verbose=1
)
# --- Export Model ---
os.makedirs(CONFIG['export_dir'], exist_ok=True)
model.save(os.path.join(CONFIG['export_dir'], 'final_model.keras'))
# Save config alongside model
with open(os.path.join(CONFIG['export_dir'], 'config.json'), 'w') as f:
json.dump(CONFIG, f, indent=2)
# --- Summary ---
best_epoch = np.argmax(history.history['val_auc'])
print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print(f"Best epoch: {best_epoch + 1}")
print(f"Best val_auc: {history.history['val_auc'][best_epoch]:.4f}")
print(f"Best val_accuracy: {history.history['val_accuracy'][best_epoch]:.4f}")
print(f"Model saved to: {CONFIG['export_dir']}")
print("=" * 60)
# Clean up demo files
import shutil
if os.path.exists('./training_output'):
shutil.rmtree('./training_output')
Summary & Next Steps
In this article, you mastered the full spectrum of TensorFlow training workflows:
- model.fit() — the high-level API with History tracking, class weights, and validation
- Custom loops — tf.GradientTape for GANs, RL, and multi-loss scenarios
- train_step() override — the best of both worlds (custom logic + fit benefits)
- Callbacks — EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
- Custom callbacks — timing, logging, gradient monitoring
- TensorBoard — scalars, histograms, profiling with tf.summary
- Model saving — .keras, SavedModel, and weight-only formats
- Checkpointing — tf.train.Checkpoint + CheckpointManager for resumable training
- Custom metrics — F1 Score, MAPE, and streaming metric patterns
Next in the Series
In Part 6: CNNs & Computer Vision, we'll build convolutional neural networks from scratch — Conv2D, pooling, batch normalization, transfer learning with pretrained models, and complete image classification pipelines.