
Part 3: Training, Evaluation & Checkpointing

May 3, 2026 · Wasil Zafar · 30 min read

From the 5-step training loop to GPU acceleration, mixed precision, and distributed training — learn how to train, evaluate, save, and scale PyTorch models like a professional.

Table of Contents

  1. Training Loop Anatomy
  2. Complete Training Loop
  3. Model Evaluation & Metrics
  4. Validation During Training
  5. Saving & Loading Models
  6. GPU Acceleration
  7. Mixed Precision Training
  8. Distributed Training
  9. Early Stopping
  10. Gradient Clipping
  11. Text Generation Strategies
  12. Conclusion & Next Steps

The Training Loop Anatomy

Training a neural network in PyTorch revolves around a deceptively simple loop. Unlike frameworks such as Keras where model.fit() hides everything, PyTorch gives you explicit control over every step. This is both its greatest strength (maximum flexibility) and its biggest source of bugs (easy to get the order wrong).

The core training loop has exactly five steps, and the order matters:

  1. Forward pass — Feed input through the model to get predictions
  2. Compute loss — Compare predictions against ground truth
  3. Backward pass — Compute gradients of the loss with respect to all parameters
  4. Optimizer step — Update parameters using computed gradients
  5. Zero gradients — Reset gradients to prevent accumulation

The PyTorch Training Loop
flowchart TD
    A["🔄 Start Epoch"] --> B["1️⃣ Forward Pass
predictions = model(inputs)"]
    B --> C["2️⃣ Compute Loss
loss = criterion(predictions, targets)"]
    C --> D["3️⃣ Backward Pass
loss.backward()"]
    D --> E["4️⃣ Optimizer Step
optimizer.step()"]
    E --> F["5️⃣ Zero Gradients
optimizer.zero_grad()"]
    F --> G{"More batches?"}
    G -->|Yes| B
    G -->|No| H{"More epochs?"}
    H -->|Yes| A
    H -->|No| I["✅ Training Complete"]

Let's walk through each step to understand why it exists and what happens internally.

The 5-Step Loop in Code

The following code demonstrates the minimal training loop with all five steps clearly labeled. We create a tiny model, generate synthetic data, and run three epochs to watch the loss decrease. Every import and data creation step is included so you can run this independently:

import torch
import torch.nn as nn
import torch.optim as optim

# Create synthetic dataset: y = 2x + 1 with noise
torch.manual_seed(42)
X = torch.randn(100, 1)
y = 2 * X + 1 + 0.1 * torch.randn(100, 1)

# Simple linear model
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Training loop — 5 steps
for epoch in range(3):
    # Step 1: Forward pass
    predictions = model(X)

    # Step 2: Compute loss
    loss = criterion(predictions, y)

    # Step 3: Backward pass (compute gradients)
    loss.backward()

    # Step 4: Update parameters
    optimizer.step()

    # Step 5: Zero gradients for next iteration
    optimizer.zero_grad()

    print(f"Epoch {epoch+1}/3, Loss: {loss.item():.4f}")

# Check learned parameters (should be close to w=2, b=1)
w = model.weight.item()
b = model.bias.item()
print(f"\nLearned: y = {w:.3f}x + {b:.3f}")

Notice that we call optimizer.zero_grad() after the step. This is perfectly valid — many tutorials place it at the beginning of the loop body instead. Both orderings work because gradients are zeroed before the next backward() call either way.

Key Insight: PyTorch accumulates gradients by default. If you forget zero_grad(), each backward() call adds to the existing gradients instead of replacing them. This is occasionally useful (e.g., for gradient accumulation with large effective batch sizes), but in most cases it's a bug.
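
Here's a minimal sketch of deliberate gradient accumulation, where several small batches contribute gradients to a single parameter update (the accum_steps value and the synthetic data are illustrative):

import torch
import torch.nn as nn
import torch.optim as optim

# Sketch: 4 small batches act as one large effective batch.
torch.manual_seed(0)
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

accum_steps = 4
optimizer.zero_grad()
for step in range(accum_steps):
    X = torch.randn(8, 1)
    y = 2 * X + 1
    loss = criterion(model(X), y) / accum_steps  # scale so gradients average
    loss.backward()  # gradients ADD to the ones already stored

optimizer.step()       # one update from 4 batches' worth of gradients
optimizer.zero_grad()  # reset before the next accumulation cycle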

Common Ordering Mistakes

Getting the five steps in the wrong order is one of the most frequent PyTorch bugs. Here are the mistakes that cause silent failures — the code runs without error but the model never learns:

Common Pitfalls
Training Loop Anti-Patterns

Bug 1: Calling step() before backward() — The optimizer updates parameters using whatever stale gradients exist from the previous iteration (on the first iteration there are no gradients at all, so nothing is updated). The model never learns from the current loss.

Bug 2: Calling zero_grad() after backward() but before step() — You compute fresh gradients, immediately erase them, then the optimizer has nothing to work with.

Bug 3: Forgetting zero_grad() entirely — Gradients accumulate across iterations, causing increasingly unstable updates and divergence.
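
To see why Bug 1 fails silently, here's a minimal sketch on the same synthetic data as before. Because zero_grad() runs after backward(), step() never sees fresh gradients and the loss stays flat:

import torch
import torch.nn as nn
import torch.optim as optim

# Sketch of Bug 1: step() runs before backward(), so it only ever sees
# gradients already cleared by the previous iteration's zero_grad().
torch.manual_seed(42)
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
X = torch.randn(100, 1)
y = 2 * X + 1

for epoch in range(3):
    loss = criterion(model(X), y)
    optimizer.step()       # BUG: no fresh gradients exist yet
    loss.backward()        # gradients computed too late to be used
    optimizer.zero_grad()  # ...and then discarded
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")  # identical every epoch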


Implementing a Complete Training Loop

Real training loops do more than the minimal five steps. You need to track loss history, count epochs, print progress, and often work with batches of data rather than the full dataset at once. The code below shows a production-style training loop for a classification task with all these features:

import torch
import torch.nn as nn
import torch.optim as optim

# Generate synthetic classification data (3 classes, 4 features)
torch.manual_seed(42)
num_samples = 300
X = torch.randn(num_samples, 4)
y = torch.randint(0, 3, (num_samples,))

# Define a simple classifier
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.net(x)

model = Classifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Track loss history
loss_history = []
num_epochs = 20

for epoch in range(num_epochs):
    model.train()  # Set to training mode

    # Forward pass
    outputs = model(X)
    loss = criterion(outputs, y)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Track and report
    loss_history.append(loss.item())
    if (epoch + 1) % 5 == 0:
        # Compute training accuracy
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == y).float().mean().item()
        print(f"Epoch [{epoch+1:3d}/{num_epochs}]  "
              f"Loss: {loss.item():.4f}  "
              f"Accuracy: {accuracy:.2%}")

print(f"\nFinal loss: {loss_history[-1]:.4f}")
print(f"Loss reduction: {loss_history[0]:.4f} → {loss_history[-1]:.4f}")

There are a few important details here. First, we call model.train() at the start of each epoch — this ensures layers like Dropout and BatchNorm behave correctly during training. Second, we compute accuracy alongside loss to get a more interpretable measure of progress. Third, we print every 5 epochs to avoid flooding the console during long training runs.

Training with Mini-Batches

In practice, you rarely train on the entire dataset in one forward pass. Mini-batch training breaks the data into smaller chunks, which provides noisy but useful gradient estimates, uses less memory, and often converges faster. Here we manually implement batching without a DataLoader (Part 4 covers DataLoader in depth):

import torch
import torch.nn as nn
import torch.optim as optim

# Synthetic regression data
torch.manual_seed(42)
X = torch.randn(200, 3)
y = X @ torch.tensor([1.5, -2.0, 0.5]) + 0.3 + 0.1 * torch.randn(200)

# Model, loss, optimizer
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Mini-batch training
batch_size = 32
num_epochs = 10
n_samples = X.size(0)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    # Shuffle indices each epoch
    indices = torch.randperm(n_samples)

    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        batch_idx = indices[start:end]

        X_batch = X[batch_idx]
        y_batch = y[batch_idx]

        # Forward + backward + step
        predictions = model(X_batch).squeeze()
        loss = criterion(predictions, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    if (epoch + 1) % 2 == 0:
        print(f"Epoch {epoch+1:2d}/{num_epochs}  Avg Loss: {avg_loss:.4f}")

Note how we shuffle indices each epoch with torch.randperm(). Without shuffling, the model sees data in the same order every epoch, which can create systematic biases in the gradient updates. Shuffling ensures each mini-batch is a random sample of the full dataset.

Model Evaluation & Metrics

model.eval() vs model.train()

PyTorch models have two modes: training mode (model.train()) and evaluation mode (model.eval()). This distinction matters because certain layers behave differently depending on the mode:

  • Dropout — Active in training (randomly zeros neurons), disabled in eval (uses all neurons)
  • BatchNorm — Uses batch statistics in training, uses running mean/variance in eval

Additionally, you should wrap evaluation code in torch.no_grad() to disable gradient computation. This saves memory and speeds up inference since PyTorch doesn't need to track operations for backpropagation:

import torch
import torch.nn as nn

# Model with Dropout (behaves differently in train vs eval)
class DropoutNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.5)
        self.output = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.layer(x))
        x = self.dropout(x)
        return self.output(x)

model = DropoutNet()
sample = torch.randn(1, 4)

# Training mode — Dropout is active (outputs vary)
model.train()
train_outputs = [model(sample).detach() for _ in range(3)]
print("Training mode outputs (vary due to Dropout):")
for i, out in enumerate(train_outputs):
    print(f"  Run {i+1}: {out.numpy().flatten()}")

# Evaluation mode — Dropout disabled (outputs are deterministic)
model.eval()
with torch.no_grad():
    eval_outputs = [model(sample) for _ in range(3)]
    print("\nEval mode outputs (deterministic):")
    for i, out in enumerate(eval_outputs):
        print(f"  Run {i+1}: {out.numpy().flatten()}")

The outputs in training mode differ each time because Dropout randomly zeros different neurons. In eval mode, every run produces the identical result. Always remember to switch back to model.train() before resuming training.

Classification Metrics

Loss alone doesn't tell the full story. For classification tasks, you need metrics like accuracy, precision, recall, and F1 score to understand how your model is failing. The following code computes all these metrics from scratch using pure PyTorch — no scikit-learn required:

import torch
import torch.nn as nn

# Simulate a trained model's predictions
torch.manual_seed(42)
num_classes = 3
num_samples = 100

# Simulated logits (raw model output) and true labels
logits = torch.randn(num_samples, num_classes)
true_labels = torch.randint(0, num_classes, (num_samples,))

# Get predicted classes
predicted = torch.argmax(logits, dim=1)

# Overall accuracy
correct = (predicted == true_labels).sum().item()
accuracy = correct / num_samples
print(f"Accuracy: {accuracy:.2%} ({correct}/{num_samples})")

# Per-class precision, recall, F1
print(f"\n{'Class':<8} {'Precision':<12} {'Recall':<10} {'F1':<10} {'Support':<10}")
print("-" * 50)

for cls in range(num_classes):
    tp = ((predicted == cls) & (true_labels == cls)).sum().item()
    fp = ((predicted == cls) & (true_labels != cls)).sum().item()
    fn = ((predicted != cls) & (true_labels == cls)).sum().item()
    support = (true_labels == cls).sum().item()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    print(f"  {cls:<6} {precision:<12.4f} {recall:<10.4f} {f1:<10.4f} {support:<10}")

# Confusion matrix
confusion = torch.zeros(num_classes, num_classes, dtype=torch.int64)
for t, p in zip(true_labels, predicted):
    confusion[t, p] += 1

print(f"\nConfusion Matrix:")
print(f"{'':>8}", end="")
for c in range(num_classes):
    print(f"Pred {c:>3}", end="  ")
print()
for i in range(num_classes):
    print(f"True {i:>2} ", end=" ")
    for j in range(num_classes):
        print(f"{confusion[i,j].item():>6}", end="  ")
    print()

The confusion matrix is especially useful — it shows exactly where your model confuses one class for another. Diagonal entries are correct predictions; off-diagonal entries are errors. High off-diagonal values in a specific cell tell you which class pair the model struggles with most.

Precision vs Recall: Precision asks "of everything the model predicted as class X, how many were actually X?" Recall asks "of everything that was actually class X, how many did the model find?" F1 is the harmonic mean of both — it penalises models that sacrifice one for the other.
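
A quick worked example with made-up counts: suppose class X has tp = 8, fp = 2, fn = 4. Then precision = 8 / (8 + 2) = 0.80, recall = 8 / (8 + 4) ≈ 0.67, and F1 = 2 × (0.80 × 0.67) / (0.80 + 0.67) ≈ 0.73, sitting between the two but pulled toward the weaker score.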

Validation During Training

Train/Val Split Strategy

Monitoring only training loss is like grading your own homework — you need an independent measure to know if your model actually generalises. A validation set is a portion of data held out from training, used solely to evaluate the model at the end of each epoch. The key pattern is: train on training data, evaluate on validation data, compare both losses.

import torch
import torch.nn as nn
import torch.optim as optim

# Synthetic dataset
torch.manual_seed(42)
X = torch.randn(500, 4)
y = (X[:, 0] + X[:, 1] > 0).long()  # Binary classification

# 80/20 train/val split
n_train = int(0.8 * len(X))
indices = torch.randperm(len(X))
train_idx, val_idx = indices[:n_train], indices[n_train:]

X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]

print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")

# Model
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 8), nn.ReLU(),
    nn.Linear(8, 2)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training with validation tracking
train_losses, val_losses = [], []

for epoch in range(30):
    # --- Training phase ---
    model.train()
    train_out = model(X_train)
    train_loss = criterion(train_out, y_train)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # --- Validation phase ---
    model.eval()
    with torch.no_grad():
        val_out = model(X_val)
        val_loss = criterion(val_out, y_val)
        val_losses.append(val_loss.item())

    if (epoch + 1) % 10 == 0:
        train_acc = (train_out.argmax(1) == y_train).float().mean()
        val_acc = (val_out.argmax(1) == y_val).float().mean()
        print(f"Epoch {epoch+1:2d}  "
              f"Train Loss: {train_loss.item():.4f}  Val Loss: {val_loss.item():.4f}  "
              f"Train Acc: {train_acc:.2%}  Val Acc: {val_acc:.2%}")

The critical detail is wrapping the validation phase in model.eval() and torch.no_grad(). You must switch back to model.train() at the start of the next epoch (which we do at the top of the loop body).

Detecting Overfitting

Overfitting occurs when training loss keeps decreasing but validation loss starts increasing. This means the model is memorising the training data rather than learning generalisable patterns. By tracking both losses, you can catch this early:

import torch
import torch.nn as nn
import torch.optim as optim

# Intentionally small dataset to trigger overfitting
torch.manual_seed(42)
X_train = torch.randn(30, 4)
y_train = torch.randint(0, 2, (30,))
X_val = torch.randn(100, 4)
y_val = torch.randint(0, 2, (100,))

# Overly complex model for the data size
model = nn.Sequential(
    nn.Linear(4, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 2)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

print(f"{'Epoch':<8} {'Train Loss':<14} {'Val Loss':<14} {'Gap':<10} {'Status'}")
print("-" * 60)

for epoch in range(50):
    model.train()
    train_loss = criterion(model(X_train), y_train)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val)

    gap = val_loss.item() - train_loss.item()

    if (epoch + 1) % 10 == 0:
        status = "⚠️  Overfitting" if gap > 0.3 else "✅ OK"
        print(f"{epoch+1:<8} {train_loss.item():<14.4f} {val_loss.item():<14.4f} {gap:<10.4f} {status}")

When you see the gap between validation loss and training loss growing, that's the classic overfitting signal. Remedies include adding dropout, reducing model complexity, collecting more training data, or using early stopping (covered later in this article).
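
As a quick illustration of the first remedy, here's the same over-sized architecture with Dropout layers inserted. The p=0.3 value is an assumption to tune against validation loss, not a recommendation:

import torch.nn as nn

# Sketch: the same classifier with Dropout added as a regulariser.
# p=0.3 is an illustrative starting value.
regularised_model = nn.Sequential(
    nn.Linear(4, 64), nn.ReLU(), nn.Dropout(p=0.3),
    nn.Linear(64, 64), nn.ReLU(), nn.Dropout(p=0.3),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 2)
)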

Saving & Loading Models

The state_dict Pattern

PyTorch offers two approaches for saving models. The recommended approach saves only the state_dict() — a dictionary mapping each layer name to its parameter tensor. This is portable, version-resistant, and compact:

import torch
import torch.nn as nn
import os

# Define and train a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 32)
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = SimpleModel()

# Inspect what state_dict contains
print("state_dict keys:")
for key, tensor in model.state_dict().items():
    print(f"  {key}: shape={tensor.shape}")

# Save state_dict (recommended)
save_path = "model_weights.pth"
torch.save(model.state_dict(), save_path)
print(f"\nSaved to {save_path} ({os.path.getsize(save_path)} bytes)")

# Load into a new model instance
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(save_path, weights_only=True))
loaded_model.eval()

# Verify weights match
for key in model.state_dict():
    match = torch.equal(model.state_dict()[key], loaded_model.state_dict()[key])
    print(f"  {key} match: {match}")

# Clean up
os.remove(save_path)

The weights_only=True argument in torch.load() is a security best practice — it prevents arbitrary code execution from malicious checkpoint files. Always use it when loading weights from untrusted sources.

Full Checkpoints for Resumed Training

If you need to resume training after an interruption, saving just the model weights isn't enough. You also need the optimizer state (momentum buffers, adaptive learning rates), the current epoch, the loss history, and any scheduler state. Here's the full checkpoint pattern:

import torch
import torch.nn as nn
import torch.optim as optim
import os

# Model and optimizer
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Simulate some training
for i in range(5):
    x = torch.randn(8, 5)
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Save full checkpoint
checkpoint_path = "checkpoint.pth"
checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
    'train_losses': [1.2, 0.9, 0.7, 0.5, 0.4],
}
torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved at epoch {checkpoint['epoch']}")

# --- Restore from checkpoint ---
loaded_checkpoint = torch.load(checkpoint_path, weights_only=False)

# Recreate model and optimizer, then load state
restored_model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))
restored_optimizer = optim.Adam(restored_model.parameters(), lr=0.001)

restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
start_epoch = loaded_checkpoint['epoch']
prev_losses = loaded_checkpoint['train_losses']

print(f"Resumed from epoch {start_epoch}")
print(f"Previous losses: {prev_losses}")

# Continue training from where we left off
for epoch in range(start_epoch, start_epoch + 3):
    x = torch.randn(8, 5)
    loss = restored_model(x).sum()
    restored_optimizer.zero_grad()
    loss.backward()
    restored_optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

os.remove(checkpoint_path)

Without restoring the optimizer state, Adam's internal momentum and variance estimates would reset to zero, causing a spike in the loss that can take many epochs to recover from. Always save the full checkpoint for long training runs.

Save Best Model Pattern

Instead of saving after every epoch (which wastes disk space), a common pattern is to save only when the validation loss improves. This ensures you always have the best-performing model on disk:

import torch
import torch.nn as nn
import torch.optim as optim
import os

# Setup
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

X_train = torch.randn(200, 4)
y_train = torch.randint(0, 2, (200,))
X_val = torch.randn(50, 4)
y_val = torch.randint(0, 2, (50,))

best_val_loss = float('inf')
best_model_path = "best_model.pth"

for epoch in range(20):
    model.train()
    loss = criterion(model(X_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    # Save if this is the best so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        saved_msg = " ← saved best"
    else:
        saved_msg = ""

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:2d}  Val Loss: {val_loss:.4f}{saved_msg}")

# Load the best model for inference
model.load_state_dict(torch.load(best_model_path, weights_only=True))
model.eval()
print(f"\nLoaded best model (val loss = {best_val_loss:.4f})")

os.remove(best_model_path)

This is especially important for models that overfit later in training — the best checkpoint might be from epoch 15 even if you trained for 50 epochs.

GPU Acceleration & Device Management

Device-Agnostic Training Code

Writing code that runs seamlessly on both CPU and GPU is essential. The standard pattern is to define a device variable once and use .to(device) for all tensors and the model. This way, your code works everywhere without modification:

import torch
import torch.nn as nn
import torch.optim as optim

# Device-agnostic: prefer CUDA, then Apple Silicon MPS, else CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Extra info if a CUDA GPU is available
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Model and data — both must be on the same device
model = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(),
    nn.Linear(32, 3)
).to(device)

# Move data to device
X = torch.randn(64, 10).to(device)
y = torch.randint(0, 3, (64,)).to(device)

# Training works identically regardless of device
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(5):
    outputs = model(X)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # .item() moves scalar back to CPU automatically
    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

The critical rule is: all tensors involved in a computation must be on the same device. If your model is on GPU but your data is on CPU, you get a runtime error. The .to(device) call is a no-op if the tensor is already on the target device, so it's always safe to call.

Common GPU Pitfall: Don't forget to move your labels to GPU too. A frequent bug is model(X.to(device)) with criterion(output, y) where y is still on CPU. The error message ("expected all tensors on same device") can be confusing the first time.
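
A minimal sketch of that mismatch and its one-line fix (it only reproduces when a CUDA device is present; on CPU-only machines everything is already co-located):

import torch
import torch.nn as nn

# Sketch of the labels-on-CPU bug.
if torch.cuda.is_available():
    device = torch.device('cuda')
    model = nn.Linear(4, 2).to(device)
    X = torch.randn(8, 4).to(device)
    y = torch.randint(0, 2, (8,))  # BUG: labels still on CPU
    criterion = nn.CrossEntropyLoss()
    try:
        criterion(model(X), y)
    except RuntimeError as err:
        print(f"RuntimeError: {err}")  # "Expected all tensors to be on the same device..."
    loss = criterion(model(X), y.to(device))  # fix: move the labels too
    print(f"Loss after fix: {loss.item():.4f}")
else:
    print("No CUDA device found; the mismatch cannot occur on CPU alone.")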

Full GPU Training Pipeline

Here's a complete training pipeline that handles device placement for model, data, checkpointing, and inference. This is the pattern you'll use in every real project:

import torch
import torch.nn as nn
import torch.optim as optim
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

# Create model ON device from the start
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Data on device
torch.manual_seed(42)
X_train = torch.randn(256, 8).to(device)
y_train = torch.randn(256, 1).to(device)

# Train
model.train()
for epoch in range(10):
    pred = model(X_train)
    loss = criterion(pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

# Save checkpoint (state_dicts are always CPU-friendly)
checkpoint_path = "gpu_checkpoint.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, checkpoint_path)

# Load on potentially different device (e.g., trained on GPU, deploy on CPU)
deploy_device = torch.device('cpu')
loaded_model = MLP().to(deploy_device)
checkpoint = torch.load(checkpoint_path, map_location=deploy_device, weights_only=False)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()

# Inference on CPU
test_input = torch.randn(4, 8)  # CPU tensor
with torch.no_grad():
    predictions = loaded_model(test_input)
    print(f"\nPredictions (on {deploy_device}): {predictions.squeeze().tolist()}")

os.remove(checkpoint_path)

The map_location argument in torch.load() is crucial for cross-device loading. If you trained on GPU and want to deploy on a CPU-only machine, map_location='cpu' remaps all tensors to CPU during loading.

Mixed Precision Training

Mixed precision training uses a combination of 32-bit (FP32) and 16-bit (FP16) floating-point arithmetic to train models faster with less GPU memory. Modern GPUs have specialised hardware (Tensor Cores) that perform FP16 operations 2-8x faster than FP32. PyTorch's Automatic Mixed Precision (AMP) makes this nearly transparent:

import torch
import torch.nn as nn
import torch.optim as optim

# Mixed precision requires CUDA — fall back gracefully on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = torch.cuda.is_available()

# Large model to see AMP benefits
model = nn.Sequential(
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 10)
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# AMP components
scaler = torch.amp.GradScaler(enabled=use_amp)

# Training with mixed precision
torch.manual_seed(42)
for epoch in range(5):
    X = torch.randn(128, 512).to(device)
    y = torch.randint(0, 10, (128,)).to(device)

    optimizer.zero_grad()

    # autocast: forward pass uses FP16 where safe
    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
        outputs = model(X)
        loss = criterion(outputs, y)

    # scaler: scales loss to prevent FP16 underflow, then unscales before step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}  (AMP: {use_amp})")

if use_amp:
    print(f"\nScale factor: {scaler.get_scale()}")
else:
    print("\nAMP disabled (no CUDA). Code still works on CPU with full precision.")

The GradScaler is the key to making FP16 work reliably. FP16 has a narrower range than FP32, so very small gradients can become zero ("underflow"). The scaler multiplies the loss by a large factor before backward() so that gradients stay in the representable range, then divides them back before the optimizer step. It dynamically adjusts the scale factor if it detects infinities or NaNs.

When to Use Mixed Precision

Performance
Mixed Precision: When It Helps

Use AMP when: You have an NVIDIA GPU with Tensor Cores (Volta, Turing, Ampere, Hopper — basically any GPU from 2017+). Large batch sizes and large models see the biggest speedups (1.5-3x). Memory savings let you double your batch size.

Skip AMP when: You're training on CPU (no benefit). You're training a tiny model where overhead exceeds savings. You need bit-exact reproducibility (FP16 introduces small numerical differences). Your loss function is numerically sensitive and produces NaN with scaling.

Typical results: 1.5-2x speedup on V100, 2-3x on A100, 40-50% memory reduction on most architectures.


Distributed Training Fundamentals

When a single GPU isn't enough — either because the model doesn't fit in memory or because training takes too long — you need distributed training across multiple GPUs. PyTorch offers two main approaches:

Distributed Training Architecture
flowchart TD
    subgraph DP["DataParallel (Simple)"]
        A["Full Model on GPU 0"] --> B["Replicate to GPU 1"]
        A --> C["Replicate to GPU 2"]
        D["Data Batch"] --> E["Split into 3"]
        E --> A
        E --> B
        E --> C
        A --> F["Gather Outputs on GPU 0"]
        B --> F
        C --> F
    end

    subgraph DDP["DistributedDataParallel (Recommended)"]
        G["Process 0 — GPU 0
Full Model Copy"] H["Process 1 — GPU 1
Full Model Copy"] I["Process 2 — GPU 2
Full Model Copy"] G <--> J["All-Reduce
Gradient Sync"] H <--> J I <--> J end

DataParallel (Quick Start)

nn.DataParallel is the simplest way to use multiple GPUs — wrap your model in one line and PyTorch handles splitting data across GPUs. However, it has a significant limitation: all gradient computation funnels through GPU 0, creating a bottleneck. The following code demonstrates the pattern (it works on a single GPU or CPU too, just without parallelism):

import torch
import torch.nn as nn

# Check available GPUs
num_gpus = torch.cuda.device_count()
print(f"Available GPUs: {num_gpus}")

# Model definition
model = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 10)
)

# Wrap with DataParallel if multiple GPUs available
if num_gpus > 1:
    model = nn.DataParallel(model)
    print(f"Using DataParallel across {num_gpus} GPUs")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Training works exactly the same — DataParallel is transparent
X = torch.randn(64, 100).to(device)
output = model(X)
print(f"Input: {X.shape} → Output: {output.shape}")

# Access the underlying model (important for saving)
actual_model = model.module if hasattr(model, 'module') else model
print(f"Model type: {type(actual_model)}")
print(f"Parameters: {sum(p.numel() for p in actual_model.parameters()):,}")

Notice the model.module pattern at the end. When you wrap a model in DataParallel, the original model is stored as .module. You need to use model.module.state_dict() when saving, not model.state_dict(), to avoid saving the wrapper.

DistributedDataParallel (Production)

DistributedDataParallel (DDP) is the recommended approach for multi-GPU training. Unlike DataParallel, each GPU runs in its own process with its own copy of the model, and gradients are synchronised via an efficient all-reduce operation. The code below shows the DDP setup pattern — it must be launched as a separate script with torchrun:

# ddp_training.py — launch with: torchrun --nproc_per_node=2 ddp_training.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def setup(rank, world_size):
    """Initialize the distributed process group."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    # Each process creates the model on its own GPU
    model = nn.Sequential(
        nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2)
    ).to(rank)

    # Wrap with DDP
    ddp_model = DDP(model, device_ids=[rank])

    optimizer = optim.Adam(ddp_model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(5):
        # Each process gets different data (in practice, use DistributedSampler)
        X = torch.randn(32, 10).to(rank)
        y = torch.randint(0, 2, (32,)).to(rank)

        output = ddp_model(X)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()  # Gradients are all-reduced automatically
        optimizer.step()

        if rank == 0:  # Only print from process 0
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    # Save only from rank 0
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), "ddp_model.pth")
        print("Model saved from rank 0")

    cleanup()

# Entry point — in practice, torchrun handles spawning
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print(f"DDP requires 2+ GPUs. Found: {world_size}")
        print("Launch with: torchrun --nproc_per_node=2 ddp_training.py")
    else:
        torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)

DDP is faster than DataParallel because each process computes gradients independently and synchronises them in the background while the next forward pass runs. The all-reduce operation is highly optimised and scales near-linearly with the number of GPUs.
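
The "in practice, use DistributedSampler" comment in the code above deserves a sketch. DistributedSampler shards a dataset's indices across processes so each GPU sees a disjoint subset. Here's a minimal, hedged sketch of the pattern, assuming the process group from setup() is already initialised:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Sketch: per-process data sharding for DDP. Requires
# dist.init_process_group() to have run first (as in setup() above).
X = torch.randn(320, 10)
y = torch.randint(0, 2, (320,))
dataset = TensorDataset(X, y)

sampler = DistributedSampler(dataset)  # each process gets a disjoint shard
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(5):
    sampler.set_epoch(epoch)  # reshuffle the shards differently each epoch
    for X_batch, y_batch in loader:
        ...  # forward/backward/step exactly as in the DDP loop above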

When to Scale: Start with a single GPU. Move to DDP when (1) training takes longer than you can afford, (2) your model or batch size doesn't fit in one GPU's memory, or (3) you're training foundation models. DataParallel is fine for quick multi-GPU experiments but use DDP for anything serious.

Early Stopping

Early stopping is one of the most practical regularisation techniques: stop training when the validation loss hasn't improved for a set number of epochs (called "patience"). This prevents overfitting and saves compute time. PyTorch has no built-in callback system the way Keras does, so you implement early stopping explicitly — giving you full control:

import torch
import torch.nn as nn
import torch.optim as optim

class EarlyStopping:
    """Stop training when validation loss stops improving."""

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.should_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

# Setup
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

X_train = torch.randn(200, 4)
y_train = torch.randint(0, 2, (200,))
X_val = torch.randn(50, 4)
y_val = torch.randint(0, 2, (50,))

early_stopping = EarlyStopping(patience=5, min_delta=0.001)
max_epochs = 100

for epoch in range(max_epochs):
    model.train()
    loss = criterion(model(X_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    stopped = early_stopping(val_loss)

    if (epoch + 1) % 5 == 0 or stopped:
        print(f"Epoch {epoch+1:3d}  Train: {loss.item():.4f}  "
              f"Val: {val_loss:.4f}  Patience: {early_stopping.counter}/{early_stopping.patience}")

    if stopped:
        print(f"\n🛑 Early stopping at epoch {epoch+1}! "
              f"Best val loss: {early_stopping.best_loss:.4f}")
        break
else:
    print(f"\nCompleted all {max_epochs} epochs.")

The min_delta parameter prevents the model from triggering false "improvements" — tiny random fluctuations in validation loss don't reset the patience counter. In practice, combine early stopping with the "save best model" pattern from the previous section: save a checkpoint whenever validation loss improves, and load that checkpoint after training stops.
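
Here's a hedged sketch of that combination, reusing the EarlyStopping class and the training setup above (the best_model.pth filename is illustrative):

import torch

# Sketch: early stopping + save-best together. Assumes model, optimizer,
# criterion, the data tensors, and EarlyStopping are defined as above.
best_val_loss = float('inf')
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

for epoch in range(100):
    model.train()
    loss = criterion(model(X_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    if val_loss < best_val_loss:  # checkpoint every improvement
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

    if early_stopping(val_loss):  # patience exhausted, stop
        break

# Restore the best checkpoint, not whatever the last epoch produced
model.load_state_dict(torch.load("best_model.pth", weights_only=True))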

Gradient Clipping

Gradient clipping prevents the exploding gradient problem, where gradients grow so large that parameter updates overshoot dramatically, causing the loss to spike or become NaN. This is especially common in recurrent networks (RNNs, LSTMs) and deep transformers. PyTorch provides two clipping strategies:

  • Norm clipping (clip_grad_norm_) — Scales all gradients so the total norm stays below a threshold. Preserves gradient direction.
  • Value clipping (clip_grad_value_) — Clamps each individual gradient element to [-value, value]. Can change gradient direction.

import torch
import torch.nn as nn
import torch.optim as optim

# Model and synthetic data
torch.manual_seed(42)
model = nn.Sequential(
    nn.Linear(4, 32), nn.ReLU(),
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 1)
)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

X = torch.randn(32, 4)
y = torch.randn(32, 1) * 100  # Large targets → large gradients

# Without clipping — observe the gradient norm
loss = criterion(model(X), y)
optimizer.zero_grad()
loss.backward()

total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float('inf'))
print(f"Gradient norm (no clipping): {total_norm_before:.2f}")

# With norm clipping (max_norm=1.0)
optimizer.zero_grad()
loss = criterion(model(X), y)
loss.backward()

total_norm_clipped = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"Gradient norm (clipped to 1.0): {total_norm_clipped:.2f}")

# Check that gradients were actually scaled
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        print(f"  {name}: grad norm = {grad_norm:.4f}")
    break  # Just show first layer

# Value clipping alternative
optimizer.zero_grad()
loss = criterion(model(X), y)
loss.backward()

torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
print(f"\nAfter value clipping (±0.5):")
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"  {name}: grad min={param.grad.min().item():.4f}, "
              f"max={param.grad.max().item():.4f}")
    break

Norm clipping is almost always the better choice. It scales the entire gradient vector proportionally, so the direction of the update is preserved — only the magnitude is capped. Value clipping can distort the gradient direction because it clamps each element independently.

Practical Tip: A max_norm of 1.0 is a common starting point. Monitor your gradient norms during training — if they regularly exceed 10-100x the clip threshold, your learning rate may be too high or your model has a deeper architectural issue. Gradient clipping is a safety net, not a permanent fix.

Here's where gradient clipping fits in the training loop — it goes between backward() and step():

import torch
import torch.nn as nn
import torch.optim as optim

# Complete training loop with gradient clipping
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(5, 20), nn.Tanh(), nn.Linear(20, 1))
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
max_grad_norm = 1.0

X = torch.randn(64, 5)
y = torch.randn(64, 1)

for epoch in range(5):
    predictions = model(X)
    loss = criterion(predictions, y)

    optimizer.zero_grad()
    loss.backward()

    # Clip AFTER backward, BEFORE step
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

    optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}, Grad Norm = {grad_norm:.4f}")

The placement is crucial: you must clip after loss.backward() populates the gradients but before optimizer.step() uses them. clip_grad_norm_ returns the original norm before clipping, which is useful for monitoring whether clipping is actually being triggered.
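
One interaction worth knowing: under mixed precision (see the earlier section), gradients are scaled, so clipping them directly would compare scaled values against max_norm. PyTorch's documented pattern is to call scaler.unscale_() first. A minimal sketch combining both techniques:

import torch
import torch.nn as nn
import torch.optim as optim

# Sketch: gradient clipping under AMP. Gradients must be unscaled before
# clipping so max_norm applies to the true gradient magnitudes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = torch.cuda.is_available()

model = nn.Linear(5, 1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
scaler = torch.amp.GradScaler(enabled=use_amp)

X = torch.randn(32, 5, device=device)
y = torch.randn(32, 1, device=device)

optimizer.zero_grad()
with torch.amp.autocast(device_type=device.type, enabled=use_amp):
    loss = criterion(model(X), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # undo loss scaling so clipping sees real gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # skips the step if gradients contain inf/NaN
scaler.update()
print(f"Loss: {loss.item():.4f}")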

Text Generation Strategies

When using models for text generation (language models, chatbots, code completion), the model outputs a probability distribution over the vocabulary at each step. The decoding strategy determines how we select the next token from this distribution. The choice dramatically affects output quality — greedy decoding produces repetitive text, while sophisticated sampling produces creative, diverse outputs.

Greedy Decoding vs Sampling

Greedy decoding always picks the highest-probability token. It's deterministic and fast but often produces boring, repetitive text ("I think that the the the..."). Sampling randomly draws from the full distribution, producing diverse outputs but potentially incoherent ones. The key is finding the middle ground.

import torch
import torch.nn.functional as F

# Simulate model output: logits over a vocabulary of 10 tokens
torch.manual_seed(42)
vocab_size = 10
logits = torch.randn(vocab_size)  # raw model output (unnormalized)

# Convert to probabilities
probs = F.softmax(logits, dim=-1)
print("Token probabilities:", probs.round(decimals=3).tolist())
print(f"Sum: {probs.sum().item():.4f}")

# Greedy decoding: always pick the most likely token
greedy_token = torch.argmax(logits)
print(f"\nGreedy choice: token {greedy_token.item()} (prob={probs[greedy_token]:.3f})")

# Sampling: draw randomly from the distribution
sampled_token = torch.multinomial(probs, num_samples=1)
print(f"Sampled choice: token {sampled_token.item()} (prob={probs[sampled_token]:.3f})")

# Sample 20 times to see the diversity
samples = torch.multinomial(probs, num_samples=20, replacement=True)
print(f"\n20 samples: {samples.tolist()}")
print(f"Most frequent: token {samples.mode().values.item()}")

Temperature Scaling

Temperature controls the "sharpness" of the probability distribution. Before applying softmax, we divide the logits by a temperature parameter $T$:

$$P(\text{token}_i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

When $T = 1.0$ (default), the distribution is unchanged. When $T < 1.0$ (e.g., 0.1), the distribution becomes sharper — high-probability tokens get even higher probability, approaching greedy behavior. When $T > 1.0$ (e.g., 2.0), the distribution becomes flatter — all tokens become more equally likely, increasing randomness and creativity.

import torch
import torch.nn.functional as F

# Same logits, different temperatures
logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -0.5, -1.0])
tokens = ["great", "good", "fine", "okay", "bad", "terrible"]

temperatures = [0.1, 0.5, 1.0, 1.5, 3.0]
print(f"{'Token':<10} {'Logit':<7} | " + " | ".join(f"T={t}" for t in temperatures))
print("-" * 75)

for i, token in enumerate(tokens):
    row = f"{token:<10} {logits[i]:<7.1f} | "
    for T in temperatures:
        prob = F.softmax(logits / T, dim=-1)[i].item()
        row += f"{prob:.3f}   | "
    print(row)

print("\nNotice:")
print("  T=0.1 -> almost greedy (99%+ on 'great')")
print("  T=1.0 -> standard softmax")
print("  T=3.0 -> nearly uniform (even 'terrible' has ~10% chance)")

Top-k and Top-p (Nucleus) Sampling

Top-k sampling restricts the candidate pool to the $k$ highest-probability tokens and redistributes probability among them. This prevents sampling from the "long tail" of unlikely tokens that could produce nonsensical outputs. Top-p (nucleus) sampling is more adaptive — it includes the smallest set of tokens whose cumulative probability exceeds $p$ (e.g., 0.9). When the model is confident, this might be just 2-3 tokens; when uncertain, it could be dozens.

import torch
import torch.nn.functional as F

def top_k_sampling(logits, k=5, temperature=1.0):
    """Sample from the top-k most likely tokens."""
    scaled_logits = logits / temperature
    
    # Keep only top-k logits, set rest to -inf
    top_k_values, top_k_indices = torch.topk(scaled_logits, k)
    filtered_logits = torch.full_like(scaled_logits, float('-inf'))
    filtered_logits.scatter_(0, top_k_indices, top_k_values)
    
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

def top_p_sampling(logits, p=0.9, temperature=1.0):
    """Sample from the smallest set of tokens with cumulative prob >= p."""
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)
    
    # Sort by probability (descending)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    
    # Remove tokens with cumulative prob above threshold
    # Keep at least one token
    sorted_mask = cumulative_probs - sorted_probs >= p
    sorted_probs[sorted_mask] = 0.0
    sorted_probs /= sorted_probs.sum()  # renormalize
    
    # Sample from filtered distribution
    sampled_idx = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices[sampled_idx]

# Demo with a vocabulary
torch.manual_seed(0)
logits = torch.tensor([5.0, 3.5, 2.0, 1.5, 1.0, 0.5, 0.1, -1.0, -2.0, -5.0])
vocab = ["the", "a", "cat", "dog", "sat", "ran", "big", "on", "was", "xyz"]

probs = F.softmax(logits, dim=-1)
print("Full distribution:")
for t, p in zip(vocab, probs):
    print(f"  {t:<6} {p:.4f} {'***' if p > 0.05 else ''}")

print(f"\nTop-k (k=3): only samples from [{', '.join(vocab[:3])}]")
print(f"Top-p (p=0.9): includes tokens until cumulative prob >= 0.9")

# Generate 10 samples with each method
top_k_results = [vocab[top_k_sampling(logits, k=3).item()] for _ in range(10)]
top_p_results = [vocab[top_p_sampling(logits, p=0.9).item()] for _ in range(10)]
print(f"\nTop-k samples: {top_k_results}")
print(f"Top-p samples: {top_p_results}")

Practical Guide
Choosing Generation Parameters

Factual tasks (summarization, translation): Low temperature (0.1–0.5) + low top-k (5–10). You want deterministic, accurate output.

Creative tasks (stories, poetry): Higher temperature (0.7–1.2) + top-p (0.9–0.95). You want diversity without complete randomness.

Code generation: Temperature 0.2–0.4 + top-p 0.95. Code needs to be correct but you still want exploration of different valid solutions.

Chatbots: Temperature 0.7 + top-p 0.9 is a common starting point for natural-sounding responses.
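
To tie the section together, here's a hedged sketch of a single helper that chains temperature scaling, top-k, and top-p filtering before sampling. The function name, default values, and toy vocabulary are illustrative, not a standard API:

import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, k=50, p=0.95):
    """Illustrative helper: temperature -> top-k -> top-p -> sample."""
    logits = logits / temperature  # creates a copy; the input is untouched

    # Top-k: mask everything outside the k most likely tokens
    top_k_values, _ = torch.topk(logits, min(k, logits.size(-1)))
    logits[logits < top_k_values[-1]] = float('-inf')

    # Top-p: drop the low-probability tail of what remains
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs >= p] = 0.0
    sorted_probs /= sorted_probs.sum()  # renormalize the survivors

    sampled = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices[sampled].item()

logits = torch.tensor([5.0, 3.5, 2.0, 1.5, 1.0, 0.5, 0.1, -1.0, -2.0, -5.0])
vocab = ["the", "a", "cat", "dog", "sat", "ran", "big", "on", "was", "xyz"]
print([vocab[sample_next_token(logits, temperature=0.7, k=5, p=0.9)]
       for _ in range(10)])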


Conclusion & Next Steps

In this article, we covered the complete lifecycle of training a PyTorch model — from the fundamental 5-step training loop to advanced techniques like mixed precision and distributed training. Here's a summary of the key patterns:

  • Training loop: Forward → Loss → Backward → Step → Zero Grad (order matters!)
  • Evaluation: Always use model.eval() + torch.no_grad() for validation and inference
  • Checkpointing: Save state_dict() for weights only; save full checkpoints (model + optimizer + epoch) for resumable training
  • GPU: Define device once, use .to(device) everywhere, use map_location for cross-device loading
  • Mixed precision: autocast + GradScaler for 1.5-3x speedup on modern GPUs
  • Early stopping: Monitor validation loss with patience to prevent overfitting
  • Gradient clipping: Use clip_grad_norm_ between backward() and step()

Next in the Series

In Part 4: Datasets & Data Pipelines, we'll build custom Dataset classes, use DataLoader for efficient batching, apply transforms, handle image and text data, and create production-quality data pipelines.