The nn.Module Foundation
Every neural network in PyTorch is built by subclassing torch.nn.Module. This base class gives you parameter management, GPU transfer, serialization, and a clean API for composing layers. You override two methods: __init__ (define layers) and forward (define computation).
Any nn.Module or nn.Parameter assigned as an attribute inside __init__ is automatically registered. PyTorch tracks these so .parameters(), .to(device), and .state_dict() all work seamlessly.
Subclassing nn.Module
Here is the minimal pattern — a two-layer network with a ReLU activation in between:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
    """A minimal two-layer network."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x
# Create the model
model = SimpleNet(input_dim=10, hidden_dim=32, output_dim=3)
print(model)
# Forward pass with random input
sample = torch.randn(4, 10) # batch of 4, 10 features
output = model(sample)
print("Output shape:", output.shape) # torch.Size([4, 3])
Model Inspection
PyTorch provides several methods to introspect your model — list parameters, count them, and examine sub-modules:
import torch
import torch.nn as nn
class InspectableNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(32, 5)

    def forward(self, x):
        return self.classifier(self.features(x))
model = InspectableNet()
# named_parameters: yields (name, parameter) tuples
print("=== Named Parameters ===")
for name, param in model.named_parameters():
    print(f" {name:30s} shape={str(param.shape):15s} requires_grad={param.requires_grad}")
# Total parameter count
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal params: {total:,}")
print(f"Trainable params: {trainable:,}")
# named_modules: full module hierarchy
print("\n=== Named Modules ===")
for name, module in model.named_modules():
    print(f" {name or '(root)':30s} -> {module.__class__.__name__}")
nn.Parameter — Custom Learnable Tensors
nn.Parameter is a special tensor wrapper that tells PyTorch: "this tensor is a learnable weight — include it in .parameters(), track its gradients, and save/load it with the model." Every nn.Linear, nn.Conv2d, etc. uses nn.Parameter internally for their weights and biases. You use it directly when building custom layers that need learnable values not covered by standard modules.
A plain torch.Tensor assigned as an attribute is not registered — it won't appear in .parameters(), won't be moved by .to(device), and won't be saved in .state_dict(). Wrapping it in nn.Parameter solves all three problems. For non-learnable constants (e.g., a fixed positional encoding), use self.register_buffer() instead.
The following example builds a CustomScaleShift layer that learns a per-feature scale ($\gamma$) and shift ($\beta$) — essentially the same learnable affine transform used inside BatchNorm and LayerNorm. We use nn.Parameter for the learnable tensors and register_buffer for a non-learnable running mean. After construction, we inspect the model to see how PyTorch distinguishes the two:
import torch
import torch.nn as nn
class CustomScaleShift(nn.Module):
    """A layer that learns per-feature scale and shift (like simplified LayerNorm)."""

    def __init__(self, num_features):
        super().__init__()
        # nn.Parameter: automatically tracked, requires_grad=True by default
        self.scale = nn.Parameter(torch.ones(num_features))   # learnable
        self.shift = nn.Parameter(torch.zeros(num_features))  # learnable
        # register_buffer: tracked for .to(device) and state_dict, but NOT learnable
        self.register_buffer('running_mean', torch.zeros(num_features))

    def forward(self, x):
        return x * self.scale + self.shift
# Create and inspect
layer = CustomScaleShift(num_features=64)
# Parameters are registered automatically
print("=== Parameters (learnable) ===")
for name, p in layer.named_parameters():
    print(f" {name}: shape={p.shape}, requires_grad={p.requires_grad}")
# Buffers are tracked but not learnable
print("\n=== Buffers (non-learnable) ===")
for name, b in layer.named_buffers():
    print(f" {name}: shape={b.shape}, requires_grad={b.requires_grad}")
# Both appear in state_dict (for saving/loading)
print("\n=== State Dict Keys ===")
for key in layer.state_dict():
    print(f" {key}")
# .to(device) moves both parameters AND buffers
print(f"\nTotal learnable params: {sum(p.numel() for p in layer.parameters())}")
Notice how scale and shift appear under Parameters with requires_grad=True, while running_mean appears under Buffers with requires_grad=False. All three are saved in state_dict() and move together when you call .to('cuda'), but only the parameters receive gradient updates during training.
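To see the practical consequence, here is a minimal sketch reusing the CustomScaleShift class above, with a made-up squared-activation loss: after one optimizer step the parameters have moved, but the buffer is untouched because it never receives gradients.
import torch
import torch.nn as nn

layer = CustomScaleShift(num_features=4)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(8, 4)
loss = layer(x).pow(2).mean()  # arbitrary loss, just to produce gradients
loss.backward()
optimizer.step()

print(layer.scale[:2])         # moved away from 1.0 — updated by the optimizer
print(layer.running_mean[:2])  # still 0.0 — buffers receive no gradient updates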
A common real-world use case is learnable positional encodings — models like BERT, GPT, and Vision Transformers (ViT) add a learnable position vector to each token/patch embedding. Since these positions must be updated during training, they are stored as nn.Parameter:
import torch
import torch.nn as nn
# Common pattern: using nn.Parameter for attention/positional embeddings
class LearnablePositionalEncoding(nn.Module):
    """Learnable position embeddings (used in ViT, BERT, etc.)."""

    def __init__(self, max_seq_len, embed_dim):
        super().__init__()
        # Each position gets a learnable vector
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, embed_dim) * 0.02)

    def forward(self, x):
        # x shape: (batch, seq_len, embed_dim)
        seq_len = x.size(1)
        return x + self.pos_embed[:, :seq_len, :]
# Example
pe = LearnablePositionalEncoding(max_seq_len=512, embed_dim=256)
x = torch.randn(4, 100, 256) # batch=4, seq_len=100
out = pe(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
print(f"pos_embed shape: {pe.pos_embed.shape}")
print(f"Learnable params: {pe.pos_embed.numel():,}") # 512 * 256 = 131,072
The pos_embed parameter has shape (1, 512, 256) — one learnable vector per position, broadcast across the batch. During the forward pass we slice [:, :seq_len, :] to handle variable-length inputs. The leading dimension of 1 allows broadcasting across any batch size. During training, backpropagation updates each position vector to encode useful position-dependent patterns.
If you assign a plain tensor where you meant nn.Parameter, the optimizer will never update it:
self.weights = torch.randn(10, 10) ✘ (invisible to optimizer)
self.weights = nn.Parameter(torch.randn(10, 10)) ✔ (learnable)
Similarly, nn.ParameterList and nn.ParameterDict exist for dynamic collections of parameters.
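For instance, here is a minimal sketch of nn.ParameterList holding a dynamic number of learnable vectors — the layer name and dimensions are illustrative, not from a real library:
import torch
import torch.nn as nn

class MultiHeadScaler(nn.Module):
    """Illustrative layer with a variable number of learnable vectors."""
    def __init__(self, num_heads, dim):
        super().__init__()
        # A plain Python list of Parameters would NOT be registered;
        # nn.ParameterList makes every element visible to .parameters()
        self.scales = nn.ParameterList(
            [nn.Parameter(torch.ones(dim)) for _ in range(num_heads)]
        )

    def forward(self, x):
        # x: (batch, num_heads, dim) — scale each head independently
        return torch.stack([x[:, i] * s for i, s in enumerate(self.scales)], dim=1)

layer = MultiHeadScaler(num_heads=4, dim=16)
print(sum(p.numel() for p in layer.parameters()))  # 4 * 16 = 64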
Linear Layers & Sequential Models
nn.Linear(in_features, out_features) applies the transformation $y = xW^T + b$. The weight matrix has shape (out_features, in_features) and the bias vector has shape (out_features,).
import torch
import torch.nn as nn
# Create a single linear layer: 5 inputs -> 3 outputs
layer = nn.Linear(in_features=5, out_features=3)
print("Weight shape:", layer.weight.shape) # torch.Size([3, 5])
print("Bias shape: ", layer.bias.shape) # torch.Size([3])
# Forward pass
x = torch.randn(2, 5) # batch of 2
y = layer(x)
print("Input shape: ", x.shape) # torch.Size([2, 5])
print("Output shape:", y.shape) # torch.Size([2, 3])
# Disable bias
layer_no_bias = nn.Linear(5, 3, bias=False)
print("Has bias?", layer_no_bias.bias is None) # True
nn.Sequential
nn.Sequential chains layers in order — you don't need to write a forward method. It's perfect for simple feed-forward architectures:
import torch
import torch.nn as nn
# Build a classifier using nn.Sequential
classifier = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10),
)
print(classifier)
# Forward pass
x = torch.randn(8, 784) # batch of 8 flattened 28x28 images
logits = classifier(x)
print("Logits shape:", logits.shape) # torch.Size([8, 10])
# Access individual layers by index
print("First layer:", classifier[0])
print("Last layer: ", classifier[6])
Understanding Parameter Counts
Knowing how many parameters your model has is essential for estimating memory requirements and detecting configuration mistakes. A nn.Linear(in, out) layer has in × out + out parameters (weights plus biases). Here’s a utility function that breaks this down layer by layer:
import torch.nn as nn
def count_parameters(model):
    """Print parameter count per layer and total."""
    total = 0
    for name, param in model.named_parameters():
        count = param.numel()
        total += count
        print(f" {name:30s} {str(param.shape):20s} {count:>8,} params")
    print(f" {'TOTAL':30s} {'':20s} {total:>8,} params")
    return total

# Example: 784 -> 256 -> 128 -> 10
model = nn.Sequential(
    nn.Linear(784, 256),  # 784*256 + 256 = 200,960
    nn.ReLU(),
    nn.Linear(256, 128),  # 256*128 + 128 = 32,896
    nn.ReLU(),
    nn.Linear(128, 10),   # 128*10 + 10 = 1,290
)
print("Layer-by-layer parameter breakdown:")
count_parameters(model)
# Total: 235,146
Activation Functions Deep Dive
Activation functions introduce non-linearity, enabling networks to learn complex patterns. Without them, stacking linear layers collapses to a single linear transformation.
flowchart LR
A["Input\nx ∈ ℝⁿ"] --> B["Linear\nW₁x + b₁"]
B --> C["Activation\nσ(·)"]
C --> D["Linear\nW₂h + b₂"]
D --> E["Activation\nσ(·)"]
E --> F["Output\nŷ ∈ ℝᵐ"]
style A fill:#3B9797,stroke:#132440,color:#fff
style B fill:#16476A,stroke:#132440,color:#fff
style C fill:#BF092F,stroke:#132440,color:#fff
style D fill:#16476A,stroke:#132440,color:#fff
style E fill:#BF092F,stroke:#132440,color:#fff
style F fill:#3B9797,stroke:#132440,color:#fff
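A quick numerical check of that collapse claim: two stacked Linear layers with no activation in between are exactly equivalent to a single Linear layer whose weight is the product of the two. This sketch uses arbitrary dimensions:
import torch
import torch.nn as nn

torch.manual_seed(0)
f1 = nn.Linear(8, 16)
f2 = nn.Linear(16, 4)

# Fold the two layers into one equivalent Linear layer:
# f2(f1(x)) = x (W2 W1)^T + (W2 b1 + b2)
combined = nn.Linear(8, 4)
with torch.no_grad():
    combined.weight.copy_(f2.weight @ f1.weight)
    combined.bias.copy_(f2.weight @ f1.bias + f2.bias)

x = torch.randn(5, 8)
print(torch.allclose(f2(f1(x)), combined(x), atol=1e-6))  # True — no extra expressive power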
import torch
import torch.nn as nn
x = torch.linspace(-3, 3, 7)
print("Input:", x.tolist())
# ReLU: max(0, x) — most popular for hidden layers
relu = nn.ReLU()
print("ReLU: ", relu(x).tolist())
# LeakyReLU: allows small negative slope to prevent dead neurons
leaky = nn.LeakyReLU(negative_slope=0.01)
print("LeakyReLU: ", leaky(x).tolist())
# ELU: smooth for negative values, faster convergence
elu = nn.ELU(alpha=1.0)
print("ELU: ", [round(v, 4) for v in elu(x).tolist()])
# GELU: used in Transformers (BERT, GPT)
gelu = nn.GELU()
print("GELU: ", [round(v, 4) for v in gelu(x).tolist()])
# SiLU (Swish): x * sigmoid(x), used in EfficientNet
silu = nn.SiLU()
print("SiLU: ", [round(v, 4) for v in silu(x).tolist()])
# Sigmoid: squashes to (0, 1) — used for binary output
sigmoid = nn.Sigmoid()
print("Sigmoid: ", [round(v, 4) for v in sigmoid(x).tolist()])
# Tanh: squashes to (-1, 1) — used in RNNs
tanh = nn.Tanh()
print("Tanh: ", [round(v, 4) for v in tanh(x).tolist()])
# Softmax: probability distribution over classes (dim matters!)
softmax = nn.Softmax(dim=0)
print("Softmax: ", [round(v, 4) for v in softmax(x).tolist()])
Which Activation to Use?
Hidden layers: Start with ReLU. If you see dead neurons, try LeakyReLU or ELU. For Transformer architectures, use GELU.
Output layer — binary classification: Sigmoid (or skip it and use BCEWithLogitsLoss).
Output layer — multi-class: No activation (use CrossEntropyLoss, which includes LogSoftmax internally).
Output layer — regression: No activation (raw linear output).
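As a minimal sketch of those output-layer conventions (the dimensions and random targets are arbitrary):
import torch
import torch.nn as nn

x = torch.randn(8, 16)

# Binary classification: one logit, no sigmoid in the model
binary_head = nn.Linear(16, 1)
binary_loss = nn.BCEWithLogitsLoss()(binary_head(x), torch.randint(0, 2, (8, 1)).float())

# Multi-class: raw logits, no softmax — CrossEntropyLoss handles it
multi_head = nn.Linear(16, 5)
multi_loss = nn.CrossEntropyLoss()(multi_head(x), torch.randint(0, 5, (8,)))

# Regression: raw linear output
reg_head = nn.Linear(16, 1)
reg_loss = nn.MSELoss()(reg_head(x), torch.randn(8, 1))

print(binary_loss.item(), multi_loss.item(), reg_loss.item())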
GELU: The Transformer Activation
The Gaussian Error Linear Unit (GELU) has become the default activation in Transformer architectures (BERT, GPT, LLaMA, ViT). Unlike ReLU, which makes a hard binary decision (pass or block), GELU makes a soft, probabilistic decision — it multiplies each input by the probability that a standard Gaussian random variable falls below that input (the Gaussian CDF, $\Phi$):
$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$

The practical approximation used in most implementations is:

$$\text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}(x + 0.044715 x^3)\right]\right)$$

Why is this better than ReLU for Transformers? ReLU has a sharp corner at zero and completely kills negative activations (gradient = 0). GELU instead provides a smooth transition: small negative values are dampened (not killed), allowing a small gradient to flow. This smoothness aligns with the stochastic regularization used in Transformers (dropout, attention masking) and empirically produces better optimization landscapes for deep attention networks.
import torch
import torch.nn as nn
import math
# Compare ReLU, GELU, and the GELU approximation
x = torch.linspace(-4, 4, 200)
# Exact GELU: x * Phi(x) where Phi is the standard Gaussian CDF
def gelu_exact(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation (used in GPT-2)
def gelu_approx(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
relu_out = torch.relu(x)
gelu_out = gelu_exact(x)
gelu_approx_out = gelu_approx(x)
pytorch_gelu = nn.GELU()(x)
print("Comparison at x = -1.0:")
print(f" ReLU(-1.0) = {torch.relu(torch.tensor(-1.0)).item():.4f}") # 0.0 (killed!)
print(f" GELU(-1.0) = {gelu_exact(torch.tensor(-1.0)).item():.4f}") # -0.1587 (dampened, not killed)
print(f" Approx(-1.0)= {gelu_approx(torch.tensor(-1.0)).item():.4f}") # close to exact
print("\nComparison at x = 2.0:")
print(f" ReLU(2.0) = {torch.relu(torch.tensor(2.0)).item():.4f}") # 2.0
print(f" GELU(2.0) = {gelu_exact(torch.tensor(2.0)).item():.4f}") # ~1.9545 (slightly less)
print(f"\nPyTorch nn.GELU matches exact: {torch.allclose(pytorch_gelu, gelu_out, atol=1e-4)}")
The Dead Neuron Problem
With standard ReLU, neurons that receive consistently negative inputs output zero and get zero gradients — they "die" and stop learning. LeakyReLU and ELU mitigate this by allowing a small gradient for negative inputs.
import torch
import torch.nn as nn
# Simulate dead neuron problem
torch.manual_seed(42)
x = torch.randn(1000)
relu = nn.ReLU()
leaky = nn.LeakyReLU(0.01)
relu_out = relu(x)
leaky_out = leaky(x)
dead_relu = (relu_out == 0).sum().item()
dead_leaky = (leaky_out == 0).sum().item()
print(f"ReLU: {dead_relu} / {len(x)} outputs are exactly zero ({dead_relu/len(x)*100:.1f}%)")
print(f"LeakyReLU: {dead_leaky} / {len(x)} outputs are exactly zero ({dead_leaky/len(x)*100:.1f}%)")
# ReLU: ~500 zeros, LeakyReLU: 0 zeros
Visual Comparison of Activation Functions
Seeing activation functions plotted side-by-side makes their behavior intuitive. Notice how ReLU has a sharp corner at zero (causing dead neurons), while GELU and SiLU/Swish have smooth transitions that preserve small negative gradients:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
x = torch.linspace(-4, 4, 200)
activations = {
    'ReLU': nn.ReLU(),
    'LeakyReLU': nn.LeakyReLU(0.1),
    'ELU': nn.ELU(),
    'GELU': nn.GELU(),
    'SiLU (Swish)': nn.SiLU(),
    'Tanh': nn.Tanh(),
}
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes = axes.flatten()
for ax, (name, fn) in zip(axes, activations.items()):
    y = fn(x)
    ax.plot(x.numpy(), y.detach().numpy(), linewidth=2, color='#3B9797')
    ax.axhline(0, color='gray', linewidth=0.5, linestyle='--')
    ax.axvline(0, color='gray', linewidth=0.5, linestyle='--')
    ax.set_title(name, fontsize=13, fontweight='bold', color='#132440')
    ax.set_xlim(-4, 4)
    ax.grid(True, alpha=0.3)
plt.suptitle('Activation Functions Comparison', fontsize=16, fontweight='bold', color='#132440')
plt.tight_layout()
plt.show()
Loss Functions by Task
The loss function measures how far the model's predictions are from the targets. Choosing the right loss is critical — it defines the optimization landscape your model navigates.
Classification Losses
nn.CrossEntropyLoss is the standard choice for multi-class classification. It combines LogSoftmax and NLLLoss internally, so your model should output raw logits (no softmax). For binary classification, use BCEWithLogitsLoss which similarly includes the sigmoid:
import torch
import torch.nn as nn
torch.manual_seed(0)
# === CrossEntropyLoss (multi-class classification) ===
# Combines LogSoftmax + NLLLoss — do NOT add Softmax to your model output
criterion_ce = nn.CrossEntropyLoss()
logits = torch.randn(4, 5) # batch of 4, 5 classes (raw scores)
targets = torch.tensor([1, 0, 4, 2]) # class indices (not one-hot)
loss_ce = criterion_ce(logits, targets)
print(f"CrossEntropyLoss: {loss_ce.item():.4f}")
# === BCEWithLogitsLoss (binary classification — preferred) ===
# Combines Sigmoid + BCELoss — numerically stable
criterion_bce_logits = nn.BCEWithLogitsLoss()
logits_binary = torch.randn(4, 1) # raw scores
targets_binary = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
loss_bce_logits = criterion_bce_logits(logits_binary, targets_binary)
print(f"BCEWithLogitsLoss: {loss_bce_logits.item():.4f}")
# === BCELoss (binary — requires sigmoid output) ===
criterion_bce = nn.BCELoss()
probs = torch.sigmoid(logits_binary) # must pass through sigmoid first
loss_bce = criterion_bce(probs, targets_binary)
print(f"BCELoss: {loss_bce.item():.4f}")
# === NLLLoss (when you apply LogSoftmax yourself) ===
criterion_nll = nn.NLLLoss()
log_probs = nn.functional.log_softmax(logits, dim=1)
loss_nll = criterion_nll(log_probs, targets)
print(f"NLLLoss: {loss_nll.item():.4f}")
print(f"CE == NLL? {torch.allclose(loss_ce, loss_nll)}") # True
Regression Losses
For regression tasks (predicting continuous values), the most common losses are MSE (penalizes large errors quadratically), L1/MAE (more robust to outliers), and Smooth L1/Huber (combines the best of both — MSE for small errors, L1 for large ones):
import torch
import torch.nn as nn
torch.manual_seed(0)
predictions = torch.tensor([2.5, 0.0, 2.1, 7.8])
targets = torch.tensor([3.0, -0.5, 2.0, 7.5])
# MSELoss: Mean Squared Error (L2) — penalizes large errors heavily
mse = nn.MSELoss()
print(f"MSELoss: {mse(predictions, targets).item():.4f}")
# L1Loss: Mean Absolute Error — more robust to outliers
l1 = nn.L1Loss()
print(f"L1Loss: {l1(predictions, targets).item():.4f}")
# SmoothL1Loss (Huber Loss): MSE for small errors, L1 for large errors
smooth_l1 = nn.SmoothL1Loss()
print(f"SmoothL1: {smooth_l1(predictions, targets).item():.4f}")
# KLDivLoss: KL divergence (for comparing probability distributions)
kl_div = nn.KLDivLoss(reduction='batchmean')
log_pred = torch.log_softmax(torch.randn(4, 5), dim=1)
target_dist = torch.softmax(torch.randn(4, 5), dim=1)
print(f"KLDivLoss: {kl_div(log_pred, target_dist).item():.4f}")
Handling Class Imbalance with Weighted Loss
Real-world datasets are rarely balanced. In medical diagnosis, fraud detection, and defect inspection, the minority class may be 1-5% of the data. Without correction, the model will learn to always predict the majority class. Inverse-frequency weighting gives rare classes a louder voice in the loss:
import torch
import torch.nn as nn
torch.manual_seed(42)
# Simulated dataset: 90% class 0, 8% class 1, 2% class 2
class_counts = torch.tensor([900, 80, 20], dtype=torch.float32)
total = class_counts.sum()
# Inverse frequency weighting
weights = total / (len(class_counts) * class_counts)
print("Class weights:", weights.tolist()) # [0.37, 4.17, 16.67]
# Pass weights to CrossEntropyLoss
criterion = nn.CrossEntropyLoss(weight=weights)
logits = torch.randn(8, 3)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2])
loss = criterion(logits, targets)
print(f"Weighted CrossEntropyLoss: {loss.item():.4f}")
# Unweighted comparison
criterion_unweighted = nn.CrossEntropyLoss()
loss_unweighted = criterion_unweighted(logits, targets)
print(f"Unweighted CrossEntropyLoss: {loss_unweighted.item():.4f}")
Never add nn.Softmax to your model's output layer when using nn.CrossEntropyLoss. It already applies LogSoftmax internally. Adding softmax will double-apply it and produce incorrect gradients.
Weight Initialization
The initial values of a network's weights significantly affect training dynamics. Poor initialization can lead to vanishing gradients (activations shrink to zero) or exploding gradients (activations blow up), both of which prevent learning.
Initialization Strategies
PyTorch provides several initialization methods in torch.nn.init. Each is designed for a specific activation function. Using the wrong init can cause gradients to either vanish (shrink to zero) or explode (grow unboundedly) during training:
import torch
import torch.nn as nn
# Create a linear layer to experiment with
layer = nn.Linear(256, 128)
# Xavier/Glorot Uniform — good for Sigmoid/Tanh
nn.init.xavier_uniform_(layer.weight)
print(f"Xavier Uniform | mean={layer.weight.mean():.4f}, std={layer.weight.std():.4f}")
# Xavier/Glorot Normal
nn.init.xavier_normal_(layer.weight)
print(f"Xavier Normal | mean={layer.weight.mean():.4f}, std={layer.weight.std():.4f}")
# Kaiming/He Uniform — designed for ReLU
nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
print(f"Kaiming Uniform | mean={layer.weight.mean():.4f}, std={layer.weight.std():.4f}")
# Kaiming/He Normal — recommended default for ReLU networks
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
print(f"Kaiming Normal | mean={layer.weight.mean():.4f}, std={layer.weight.std():.4f}")
# Orthogonal — preserves gradient norms (great for RNNs)
nn.init.orthogonal_(layer.weight)
print(f"Orthogonal | mean={layer.weight.mean():.4f}, std={layer.weight.std():.4f}")
# Zeros (bad for weights — all neurons identical — but fine for biases)
nn.init.zeros_(layer.bias)
print(f"Zeros (bias) | bias mean={layer.bias.mean():.4f}")
Custom Initialization with apply()
Use model.apply(fn) to walk every sub-module and apply a custom init function:
import torch
import torch.nn as nn
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        return self.net(x)

def init_weights(m):
    """Apply Kaiming init to Linear layers, zero biases."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
model = MLP()
# Check weights BEFORE init
print("Before custom init:")
print(f" layer0 weight std: {model.net[0].weight.std():.4f}")
# Apply custom init to ALL sub-modules
model.apply(init_weights)
print("After custom init:")
print(f" layer0 weight std: {model.net[0].weight.std():.4f}")
print(f" layer0 bias mean: {model.net[0].bias.mean():.4f}")
# Verify with a forward pass
x = torch.randn(4, 100)
out = model(x)
print(f"Output shape: {out.shape}")
print(f"Output std: {out.std():.4f}")
When to Use Which Init?
Kaiming (He): Default choice for networks using ReLU or LeakyReLU. Accounts for the fact that ReLU zeros out half the inputs.
Xavier (Glorot): Designed for Sigmoid and Tanh activations where inputs are symmetrically distributed around zero.
Orthogonal: Excellent for RNNs and LSTMs — preserves gradient magnitude across time steps.
PyTorch defaults: nn.Linear uses Kaiming Uniform by default, which is already a good choice for ReLU networks.
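A quick sketch of why this matters: pushing a random batch through a deep ReLU stack and comparing activation scales under Kaiming init versus a naive small-normal init. The helper name, depth, and widths here are arbitrary choices for illustration:
import torch
import torch.nn as nn

torch.manual_seed(0)

def activation_std(init_fn, depth=20, width=256):
    """Forward a random batch through a deep ReLU MLP and report output std."""
    x = torch.randn(64, width)
    for _ in range(depth):
        layer = nn.Linear(width, width, bias=False)
        init_fn(layer.weight)
        x = torch.relu(layer(x))
    return x.std().item()

naive = activation_std(lambda w: nn.init.normal_(w, std=0.01))
kaiming = activation_std(lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'))
print(f"Naive N(0, 0.01) init: output std = {naive:.2e}")   # collapses toward zero
print(f"Kaiming init:          output std = {kaiming:.2f}") # stays O(1)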
Normalization Layers
Normalization layers stabilize training by reducing internal covariate shift — the phenomenon where each layer’s input distribution changes as preceding layers update their weights. By normalizing activations, these layers allow higher learning rates, faster convergence, and act as mild regularizers.
Batch Normalization (BatchNorm)
BatchNorm normalizes across the batch dimension for each feature/channel independently. For a mini-batch of activations, it computes:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
$$y_i = \gamma \, \hat{x}_i + \beta$$

Where $\mu_B$ and $\sigma_B^2$ are the mean and variance computed over the batch, $\gamma$ (scale) and $\beta$ (shift) are learnable parameters that allow the network to undo the normalization if needed, and $\epsilon$ is a small constant for numerical stability.
BatchNorm uses batch statistics during training and its running estimates during inference — call model.train() and model.eval() to switch modes.
import torch
import torch.nn as nn
# BatchNorm1d: for fully-connected layers (normalizes over batch for each feature)
bn1d = nn.BatchNorm1d(num_features=64)
# BatchNorm2d: for convolutional layers (normalizes over batch for each channel)
bn2d = nn.BatchNorm2d(num_features=128)
# Example: batch of 8 samples, 64 features
x_fc = torch.randn(8, 64)
out_fc = bn1d(x_fc)
print(f"BatchNorm1d input: {x_fc.shape} -> output: {out_fc.shape}")
print(f"Output mean per feature: {out_fc.mean(dim=0)[:5].tolist()}")
print(f"Output std per feature: {out_fc.std(dim=0)[:5].tolist()}")
# Example: batch of 8, 128 channels, 16x16 spatial
x_conv = torch.randn(8, 128, 16, 16)
out_conv = bn2d(x_conv)
print(f"\nBatchNorm2d input: {x_conv.shape} -> output: {out_conv.shape}")
print(f"Learnable params: gamma={bn2d.weight.shape}, beta={bn2d.bias.shape}")
print(f"Running mean shape: {bn2d.running_mean.shape}")
Layer Normalization (LayerNorm)
LayerNorm normalizes across the feature dimension for each sample independently (no dependence on batch). This makes it ideal for sequence models (Transformers, RNNs) where batch statistics are unreliable due to variable sequence lengths:
$$\hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}, \quad \mu_L = \frac{1}{H}\sum_{j=1}^H x_j, \quad \sigma_L^2 = \frac{1}{H}\sum_{j=1}^H (x_j - \mu_L)^2$$

Where the statistics are computed over the last $H$ dimensions (the normalized_shape). Unlike BatchNorm, LayerNorm behaves identically during training and inference.
import torch
import torch.nn as nn
# LayerNorm: normalizes over the last N dimensions
# For a Transformer with hidden_dim=512
ln = nn.LayerNorm(normalized_shape=512)
# Single sample: each token is normalized independently
x = torch.randn(1, 20, 512) # (batch, seq_len, hidden_dim)
out = ln(x)
print(f"LayerNorm input: {x.shape} -> output: {out.shape}")
print(f"Token 0 mean: {out[0, 0].mean().item():.6f}")
print(f"Token 0 std: {out[0, 0].std().item():.4f}")
# LayerNorm over multiple dimensions (e.g., for CNN feature maps)
ln_2d = nn.LayerNorm(normalized_shape=[64, 8, 8]) # channels, H, W
x_2d = torch.randn(4, 64, 8, 8)
out_2d = ln_2d(x_2d)
print(f"\nLayerNorm 2D input: {x_2d.shape} -> output: {out_2d.shape}")
print(f"Sample 0 mean: {out_2d[0].mean().item():.6f}")
print(f"Params: weight={ln.weight.shape}, bias={ln.bias.shape}")
Group Normalization & Instance Normalization
GroupNorm divides channels into groups and normalizes within each group. It bridges the gap between LayerNorm (1 group = all channels) and InstanceNorm (groups = channels). It’s batch-size independent, making it ideal for small-batch training (detection, segmentation):
$$\hat{x}_i = \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}}$$

Where $\mu_g$ and $\sigma_g^2$ are computed over each group of channels per sample.
InstanceNorm normalizes each channel per sample independently (equivalent to GroupNorm with num_groups = num_channels). It’s the default for style transfer because it removes instance-specific contrast:
import torch
import torch.nn as nn
# GroupNorm: divide 64 channels into 8 groups of 8 channels each
gn = nn.GroupNorm(num_groups=8, num_channels=64)
# InstanceNorm: each channel normalized independently (groups = channels)
ins_norm = nn.InstanceNorm2d(num_features=64, affine=True)
# Input: batch of 2, 64 channels, 32x32 spatial
x = torch.randn(2, 64, 32, 32)
out_gn = gn(x)
out_in = ins_norm(x)
print(f"GroupNorm (8 groups of 8ch): {out_gn.shape}")
print(f" Sample 0, Group 0 (ch 0-7) mean: {out_gn[0, :8].mean().item():.6f}")
print(f"\nInstanceNorm (64 groups of 1ch): {out_in.shape}")
print(f" Sample 0, Channel 0 mean: {out_in[0, 0].mean().item():.6f}")
print(f" Sample 0, Channel 0 std: {out_in[0, 0].std().item():.4f}")
# GroupNorm special cases:
# num_groups=1 -> equivalent to LayerNorm
# num_groups=num_channels -> equivalent to InstanceNorm
ln_equiv = nn.GroupNorm(num_groups=1, num_channels=64) # LayerNorm
in_equiv = nn.GroupNorm(num_groups=64, num_channels=64) # InstanceNorm
print(f"\nGroupNorm(1 group) ~ LayerNorm: {ln_equiv}")
print(f"GroupNorm(64 groups) ~ InstanceNorm: {in_equiv}")
When to Use Which Normalization
BatchNorm: Default for CNNs and MLPs with batch size ≥ 16. Best convergence speed. Avoid with small batches or sequence models.
LayerNorm: Default for Transformers, RNNs, and any model with variable-length inputs. Batch-size independent.
GroupNorm: Use when batch size is small (1–8), e.g., object detection, segmentation. Set groups=32 as a solid default.
InstanceNorm: Style transfer and generative models. Removes per-instance contrast information.
import torch
import torch.nn as nn
# Side-by-side comparison: how each norm computes statistics differently
x = torch.randn(4, 8, 6, 6) # batch=4, channels=8, spatial=6x6
norms = {
    "BatchNorm2d": nn.BatchNorm2d(8),
    "LayerNorm": nn.LayerNorm([8, 6, 6]),
    "GroupNorm(2gr)": nn.GroupNorm(num_groups=2, num_channels=8),
    "InstanceNorm2d": nn.InstanceNorm2d(8, affine=True),
}
print(f"Input shape: {x.shape} (batch=4, channels=8, H=6, W=6)")
print(f"{'Norm Type':<18} {'Output mean':>12} {'Output std':>11} {'Params':>8}")
print("-" * 55)
for name, norm in norms.items():
    out = norm(x)
    n_params = sum(p.numel() for p in norm.parameters())
    print(f"{name:<18} {out.mean().item():>12.6f} {out.std().item():>11.4f} {n_params:>8}")
Set bias=False in Conv/Linear layers preceding BatchNorm, since BN already has a learnable shift ($\beta$), and don't use BatchNorm in the first layer of a GAN discriminator.
Optimizers
Optimizers update model parameters based on gradients computed during the backward pass. The choice of optimizer and its hyperparameters (learning rate, momentum, weight decay) can make or break training.
SGD with Momentum
Stochastic Gradient Descent is the simplest optimizer — it subtracts the gradient scaled by the learning rate. Momentum adds a velocity term that accumulates past gradients, helping the optimizer roll through flat regions and dampen oscillations. Nesterov momentum is a slight improvement that computes the gradient at the “look-ahead” position.
Vanilla SGD update rule:
$$\theta_{t+1} = \theta_t - \eta \, \nabla L(\theta_t)$$

SGD with Momentum:

$$v_t = \mu \cdot v_{t-1} + \nabla L(\theta_t)$$
$$\theta_{t+1} = \theta_t - \eta \cdot v_t$$

Nesterov Momentum (look-ahead gradient):

$$v_t = \mu \cdot v_{t-1} + \nabla L(\theta_t - \eta \, \mu \cdot v_{t-1})$$
$$\theta_{t+1} = \theta_t - \eta \cdot v_t$$

Where $\theta$ are the parameters, $\eta$ is the learning rate, $\mu$ is the momentum coefficient (typically 0.9), $v_t$ is the velocity (accumulated gradient) at step $t$, and $\nabla L$ is the gradient of the loss.
import torch
import torch.nn as nn
# Simple model
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
# SGD — vanilla
optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01)
# SGD with momentum — accelerates convergence
optimizer_momentum = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# SGD with Nesterov momentum — look-ahead gradient
optimizer_nesterov = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
# Simulate one optimization step
x = torch.randn(8, 10)
y = torch.randn(8, 1)
optimizer_momentum.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer_momentum.step()
print(f"Loss after 1 step: {loss.item():.4f}")
print(f"Optimizer: {optimizer_momentum}")
Adam & AdamW
Adam (Adaptive Moment Estimation) maintains per-parameter learning rates based on both the mean and variance of past gradients. It’s the default choice for most projects because it converges quickly with minimal tuning. AdamW decouples weight decay from the gradient update, which produces better generalization and is the standard optimizer for Transformer-based models.
Adam update rule:
$$m_t = \beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t \quad \text{(first moment estimate)}$$
$$v_t = \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2 \quad \text{(second moment estimate)}$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \quad \text{(bias correction)}$$
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \, \hat{m}_t$$

AdamW (decoupled weight decay):

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \, \hat{m}_t - \eta \lambda \, \theta_t$$

Where $g_t = \nabla L(\theta_t)$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ (prevents division by zero), and $\lambda$ is the weight decay coefficient. The bias correction compensates for the zero-initialized moments during early training steps.
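To connect the formulas to code, here is a hedged sketch that implements a single Adam step by hand for one parameter tensor (step $t = 1$, no weight decay, default hyperparameters) and checks it against torch.optim.Adam:
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

# A single parameter with a known gradient
p = torch.randn(5, requires_grad=True)
loss = (p ** 2).sum()
loss.backward()
g = p.grad.clone()

# Manual Adam step (t = 1, moments initialized to zero)
m = (1 - beta1) * g                # m_1
v = (1 - beta2) * g ** 2           # v_1
m_hat = m / (1 - beta1 ** 1)       # bias correction
v_hat = v / (1 - beta2 ** 1)
manual = p.detach() - lr * m_hat / (v_hat.sqrt() + eps)

# PyTorch's Adam step on the same parameter
opt = torch.optim.Adam([p], lr=lr, betas=(beta1, beta2), eps=eps)
opt.step()

print(torch.allclose(p.detach(), manual, atol=1e-6))  # True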
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
criterion = nn.MSELoss()
# Adam: adaptive learning rates per parameter
optimizer_adam = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
# AdamW: decoupled weight decay (recommended for fine-tuning and Transformers)
optimizer_adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# RMSprop: adaptive, good for non-stationary problems
optimizer_rmsprop = torch.optim.RMSprop(model.parameters(), lr=1e-3, alpha=0.99)
# Compare all optimizers with a quick training loop
for name, opt in [("Adam", optimizer_adam), ("AdamW", optimizer_adamw), ("RMSprop", optimizer_rmsprop)]:
    torch.manual_seed(0)
    m = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
    opt_instance = type(opt)(m.parameters(), **opt.defaults)  # same class, same hyperparameters
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)
    for step in range(50):
        opt_instance.zero_grad()
        loss = criterion(m(x), y)
        loss.backward()
        opt_instance.step()
    print(f"{name:8s} final loss: {loss.item():.4f}")
Parameter Groups with Different Learning Rates
Different parts of a model often benefit from different learning rates — e.g., a pretrained backbone with a small LR and a new head with a larger LR:
import torch
import torch.nn as nn
class TwoPartModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        self.head = nn.Linear(32, 10)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

model = TwoPartModel()
# Different LR per parameter group
optimizer = torch.optim.Adam([
    {'params': model.backbone.parameters(), 'lr': 1e-4},  # pretrained: slow LR
    {'params': model.head.parameters(), 'lr': 1e-2},      # new head: fast LR
], weight_decay=1e-5)
print("Parameter groups:")
for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group['params'])
    print(f" Group {i}: lr={group['lr']}, params={n_params:,}")
Learning Rate Scheduling
Training with a fixed learning rate is rarely optimal. Schedulers adjust the LR during training — typically decaying it so the model takes smaller steps as it approaches convergence.
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import (
    StepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR
)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# StepLR: multiply LR by gamma every step_size epochs
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.5)
print(f"StepLR — LR at epoch 0: {scheduler_step.get_last_lr()[0]:.4f}")
# ExponentialLR: multiply LR by gamma every epoch
optimizer2 = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler_exp = ExponentialLR(optimizer2, gamma=0.95)
# CosineAnnealingLR: cosine decay from initial LR to eta_min
optimizer3 = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler_cosine = CosineAnnealingLR(optimizer3, T_max=50, eta_min=1e-5)
# ReduceLROnPlateau: reduce when a metric stops improving
optimizer4 = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler_plateau = ReduceLROnPlateau(optimizer4, mode='min', factor=0.5, patience=5)
# OneCycleLR: warmup then cosine decay (super-convergence)
optimizer5 = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler_onecycle = OneCycleLR(optimizer5, max_lr=0.1, total_steps=100)
print("Schedulers created successfully!")
print(" StepLR: decay by 0.5 every 10 epochs")
print(" ExponentialLR: decay by 0.95 every epoch")
print(" CosineAnnealingLR: cosine to 1e-5 over 50 epochs")
print(" ReduceLROnPlateau: halve after 5 stagnant epochs")
print(" OneCycleLR: warmup to 0.1, cosine to 0")
Visualizing Learning Rate Schedules
Seeing how the learning rate changes across epochs makes it much easier to choose and tune a schedule. The plot below compares all four major strategies — notice how OneCycleLR ramps up first (warmup) before decaying, while CosineAnnealing follows a smooth cosine curve:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR, OneCycleLR
import matplotlib.pyplot as plt
epochs = 100
def track_lr(scheduler_class, **kwargs):
    """Track LR across epochs for a given scheduler."""
    model = nn.Linear(10, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = scheduler_class(opt, **kwargs)
    lrs = []
    for epoch in range(epochs):
        lrs.append(opt.param_groups[0]['lr'])
        # Simulate a training step
        opt.step()
        scheduler.step()
    return lrs
# Track each scheduler
schedules = {
    'StepLR (step=20, γ=0.5)': track_lr(StepLR, step_size=20, gamma=0.5),
    'ExponentialLR (γ=0.97)': track_lr(ExponentialLR, gamma=0.97),
    'CosineAnnealingLR (T=100)': track_lr(CosineAnnealingLR, T_max=100, eta_min=1e-4),
}
# OneCycleLR needs total_steps
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler_oc = OneCycleLR(opt, max_lr=0.1, total_steps=epochs)
lrs_oc = []
for _ in range(epochs):
    lrs_oc.append(opt.param_groups[0]['lr'])
    opt.step()
    scheduler_oc.step()
schedules['OneCycleLR (max=0.1)'] = lrs_oc
# Plot
fig, ax = plt.subplots(figsize=(12, 5))
colors = ['#3B9797', '#BF092F', '#16476A', '#D4A843']
for (name, lrs), color in zip(schedules.items(), colors):
    ax.plot(range(epochs), lrs, label=name, linewidth=2, color=color)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Learning Rate', fontsize=12)
ax.set_title('Learning Rate Schedule Comparison', fontsize=14, fontweight='bold', color='#132440')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Always call optimizer.step() before scheduler.step(). In PyTorch 1.1+, calling the scheduler first triggers a warning and can produce incorrect learning rates on the first epoch.
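A minimal sketch of the correct ordering inside an epoch loop (the model, data, and schedule here are placeholders):
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(30):
    optimizer.zero_grad()
    loss = criterion(model(torch.randn(16, 10)), torch.randn(16, 1))
    loss.backward()
    optimizer.step()    # 1) update parameters first
    scheduler.step()    # 2) then advance the schedule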
Building a Complete MLP
Let's bring everything together — define a Multi-Layer Perceptron, initialize its weights, choose a loss and optimizer, schedule the learning rate, train on synthetic data, and evaluate the results.
Model Definition
We’ll build a configurable MLP class that accepts a list of hidden layer sizes and automatically constructs the network with BatchNorm, ReLU, and Dropout between each layer. This pattern is reusable — change the dimensions and you have a model for any tabular dataset:
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Multi-Layer Perceptron with configurable architecture."""

    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.2):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
# Create a model: 20 features -> [128, 64, 32] hidden -> 4 classes
model = MLP(input_dim=20, hidden_dims=[128, 64, 32], output_dim=4, dropout=0.3)
print(model)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")
# Test forward pass
x = torch.randn(16, 20)
out = model(x)
print(f"Output shape: {out.shape}") # [16, 4]
Training & Evaluation
Now we’ll wire everything together into a complete training pipeline: create synthetic data with sklearn, split into train/validation sets, run a training loop with loss tracking, schedule the learning rate, and plot the results. This is the canonical PyTorch training pattern you’ll reuse in every project:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# --- 1. Generate synthetic dataset ---
X, y = make_classification(
    n_samples=2000, n_features=20, n_classes=4,
    n_informative=12, n_redundant=4, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
# --- 2. Define model ---
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(20, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Linear(32, 4),
        )

    def forward(self, x):
        return self.network(x)
torch.manual_seed(42)
model = MLP()
# --- 3. Custom weight init ---
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
model.apply(init_weights)
# --- 4. Loss, optimizer, scheduler ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=80, eta_min=1e-5)
# --- 5. Training loop ---
epochs = 80
train_losses, test_losses, test_accs = [], [], []
for epoch in range(epochs):
    # Train
    model.train()
    optimizer.zero_grad()
    logits = model(X_train)
    loss = criterion(logits, y_train)
    loss.backward()
    optimizer.step()
    scheduler.step()
    train_losses.append(loss.item())
    # Evaluate
    model.eval()
    with torch.no_grad():
        test_logits = model(X_test)
        test_loss = criterion(test_logits, y_test)
        preds = test_logits.argmax(dim=1)
        acc = (preds == y_test).float().mean()
    test_losses.append(test_loss.item())
    test_accs.append(acc.item())
    if (epoch + 1) % 20 == 0:
        lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1:3d} | Train Loss: {loss.item():.4f} | "
              f"Test Loss: {test_loss.item():.4f} | Acc: {acc.item():.4f} | LR: {lr:.6f}")
# --- 6. Plot results ---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(train_losses, label='Train Loss', color='#3B9797', linewidth=2)
ax1.plot(test_losses, label='Test Loss', color='#BF092F', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Test Loss', fontweight='bold', color='#132440')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax2.plot(test_accs, label='Test Accuracy', color='#16476A', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Test Accuracy', fontweight='bold', color='#132440')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
nn.BatchNorm1d behaves differently during training and evaluation. Always call model.train() before training and model.eval() before inference. In eval mode, BatchNorm uses running statistics instead of batch statistics.
Conclusion & Next Steps
You now have a complete toolkit for building neural networks in PyTorch:
- nn.Module — the base class for all models, with automatic parameter tracking
- Linear layers & Sequential — the building blocks and a quick way to compose them
- Activation functions — ReLU for most cases, GELU for Transformers, and how to avoid dead neurons
- Loss functions — CrossEntropyLoss for multi-class, BCEWithLogitsLoss for binary, MSELoss for regression
- Weight initialization — Kaiming for ReLU, Xavier for Sigmoid/Tanh, and how to apply custom init
- Optimizers — Adam/AdamW as the default, SGD+momentum when you need more control, parameter groups for differential LRs
- LR Schedulers — CosineAnnealing and OneCycleLR for strong baselines, ReduceLROnPlateau for adaptive decay
Next in the Series
In Part 3: Training, Evaluation & Checkpointing, we'll build a robust training loop with mini-batch DataLoaders, gradient clipping, early stopping, model checkpointing, and evaluation metrics like precision, recall, and F1-score.