From Training to Production
You've spent days (or weeks) training a model that achieves excellent metrics on your validation set. You open a Jupyter notebook, call model(x), and the predictions look great. Then someone asks: "Can we put this in our app?" — and suddenly you realize that research code and production code are fundamentally different beasts.
The deployment gap is the distance between a model that works in a notebook and a model that serves reliable predictions at scale. Research code can afford to be messy — it only needs to run once, on your machine, with your exact setup. Production code must handle arbitrary inputs, run fast, use minimal memory, recover from errors, and operate 24/7 without crashing. Bridging this gap is what this article is all about.
Deployment Pipeline Overview
Before we dive into individual techniques, let's visualize the full journey a model takes from a training notebook to a production endpoint. Each box represents a stage we'll cover in detail throughout this article.
flowchart LR
A[Train Model] --> B[Optimize]
B --> C{Export Format}
C -->|TorchScript| D[JIT Module]
C -->|ONNX| E[ONNX Model]
D --> F[Quantize / Prune]
E --> F
F --> G{Deploy Target}
G -->|Server| H[Flask / FastAPI]
G -->|Mobile| I[PyTorch Mobile]
G -->|Cloud| J[TorchServe]
H --> K[Monitor & Iterate]
I --> K
J --> K
The pipeline starts with a trained model and proceeds through optimization (quantization, pruning), export to a portable format (TorchScript or ONNX), and finally deployment to a target environment. Let's start with the first export option: TorchScript.
TorchScript: JIT Compilation
TorchScript is PyTorch's way of converting your Python model into a serialized, optimizable format that can run without a Python interpreter. This is critical for production because it removes the Python GIL bottleneck, enables C++ inference, and allows the JIT compiler to apply graph-level optimizations like operator fusion.
There are two ways to create a TorchScript model: tracing and scripting. Tracing records the operations that happen when you pass a sample input through the model. Scripting analyzes your Python source code and compiles it directly. Each approach has trade-offs.
Tracing a Model
Tracing is the simplest approach — you provide a sample input and PyTorch records every operation. The result is a static computation graph. This works perfectly for models with no data-dependent control flow (no if statements that depend on input values).
import torch
import torch.nn as nn
# Define a simple model
class SimpleClassifier(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.features = nn.Sequential(
nn.Linear(input_dim, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.ReLU(),
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
features = self.features(x)
return self.classifier(features)
# Create and set to eval mode (critical for tracing!)
model = SimpleClassifier(input_dim=784, num_classes=10)
model.eval()
# Trace with a sample input — shape must match real inputs
sample_input = torch.randn(1, 784)
traced_model = torch.jit.trace(model, sample_input)
# Test that traced model produces identical output
with torch.no_grad():
original_out = model(sample_input)
traced_out = traced_model(sample_input)
print("Outputs match:", torch.allclose(original_out, traced_out))
# Outputs match: True
# Save the traced model
traced_model.save("classifier_traced.pt")
print("Traced model saved successfully")
Notice we call model.eval() before tracing — this is essential because it disables dropout and switches batch normalization to use its stored running statistics, ensuring the traced graph reflects inference behavior, not training behavior.
Scripting: Handling Control Flow
Tracing has a fundamental limitation: it can only record the operations that actually execute for the sample input you provide. If your model contains if/else branches that depend on the input data, tracing will only capture one branch — the one that ran for your sample. torch.jit.script solves this by analyzing the Python source code directly.
import torch
import torch.nn as nn
# Model with data-dependent control flow
class DynamicModel(nn.Module):
def __init__(self):
super().__init__()
self.linear_small = nn.Linear(10, 5)
self.linear_large = nn.Linear(10, 5)
def forward(self, x):
# This branch depends on the INPUT — tracing would miss one path!
if x.sum() > 0:
return self.linear_small(x)
else:
return self.linear_large(x)
model = DynamicModel()
model.eval()
# Script (not trace) — handles both branches correctly
scripted_model = torch.jit.script(model)
# Test both branches
positive_input = torch.ones(1, 10) # sum > 0 → linear_small
negative_input = -torch.ones(1, 10) # sum < 0 → linear_large
with torch.no_grad():
out1 = scripted_model(positive_input)
out2 = scripted_model(negative_input)
print("Positive branch output shape:", out1.shape) # torch.Size([1, 5])
print("Negative branch output shape:", out2.shape) # torch.Size([1, 5])
# Save scripted model
scripted_model.save("dynamic_scripted.pt")
print("Scripted model saved successfully")
Use tracing when your model is a straightforward sequence of operations with no data-dependent branches. Use scripting when your model contains if statements, loops over variable-length inputs, or any logic that changes based on the actual data values. When in doubt, scripting is the safer choice.
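To make the failure mode concrete, here is a small sketch that reuses the DynamicModel pattern from above but traces it instead of scripting it. The branch taken by the sample input gets baked into the graph:
import torch
import torch.nn as nn
import warnings
class DynamicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_small = nn.Linear(10, 5)
        self.linear_large = nn.Linear(10, 5)
    def forward(self, x):
        if x.sum() > 0:
            return self.linear_small(x)
        return self.linear_large(x)
model = DynamicModel().eval()
# Tracing with a positive-sum sample records ONLY the linear_small path
# (PyTorch emits a TracerWarning about the tensor-to-bool conversion)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, torch.ones(1, 10))
negative_input = -torch.ones(1, 10)
with torch.no_grad():
    eager_out = model(negative_input)    # eager model uses linear_large
    traced_out = traced(negative_input)  # traced model still uses linear_small
print("Traced matches eager on the untraced branch:",
      torch.allclose(eager_out, traced_out))  # almost certainly False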
Loading TorchScript Models (No Python Required)
One of TorchScript's biggest advantages is that saved models can be loaded in any environment — including C++ — without the original Python class definition. This means your deployment server doesn't need your training codebase at all.
import torch
# Load a TorchScript model — no class definition needed!
loaded_model = torch.jit.load("classifier_traced.pt")
loaded_model.eval()
# Run inference
test_input = torch.randn(5, 784) # batch of 5
with torch.no_grad():
predictions = loaded_model(test_input)
predicted_classes = predictions.argmax(dim=1)
print("Predicted classes:", predicted_classes)
# Predicted classes: tensor([3, 7, 1, 0, 5]) (example output)
# Inspect the graph (useful for debugging)
print("\nTorchScript graph (first 500 chars):")
print(str(loaded_model.graph)[:500])
This decoupling of model definition from model execution is what makes TorchScript ideal for production: your ML team trains in Python, exports a .pt file, and your backend team loads it in their C++ or Python serving infrastructure without needing any ML dependencies beyond libtorch.
ONNX Export
ONNX (Open Neural Network Exchange) is an open standard for representing machine learning models. Think of it as a universal file format — like PDF for documents, but for neural networks. You can export a PyTorch model to ONNX and then run it with any ONNX-compatible runtime: ONNX Runtime (Microsoft), TensorRT (NVIDIA), OpenVINO (Intel), Core ML (Apple), and more.
This interoperability is ONNX's superpower. You train in PyTorch but deploy on whatever hardware or framework your production environment requires — without rewriting the model.
Exporting with Dynamic Axes
By default, ONNX export fixes the input shape to whatever sample you provide. In production, you almost always want dynamic axes — the ability to handle variable batch sizes (and sometimes variable sequence lengths). The dynamic_axes parameter tells ONNX which dimensions can change at runtime.
import torch
import torch.nn as nn
# Define a model
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.pool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(embed_dim, num_classes)
def forward(self, x):
# x: (batch, seq_len) of token IDs
embedded = self.embedding(x) # (batch, seq_len, embed_dim)
pooled = self.pool(embedded.transpose(1, 2)).squeeze(-1) # (batch, embed_dim)
return self.fc(pooled) # (batch, num_classes)
model = TextClassifier(vocab_size=10000, embed_dim=128, num_classes=5)
model.eval()
# Sample input — batch=1, seq_len=20
dummy_input = torch.randint(0, 10000, (1, 20))
# Export with dynamic batch size AND sequence length
torch.onnx.export(
model,
dummy_input,
"text_classifier.onnx",
input_names=["input_ids"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "seq_length"},
"logits": {0: "batch_size"}
},
opset_version=17
)
print("ONNX model exported with dynamic axes")
print(f" - input_ids: dynamic batch_size (dim 0) + seq_length (dim 1)")
print(f" - logits: dynamic batch_size (dim 0)")
The dynamic_axes dictionary maps axis indices to human-readable names. Axis 0 is typically the batch dimension. Setting opset_version=17 ensures compatibility with the latest ONNX operators.
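After export it's worth a quick structural sanity check. A minimal sketch, assuming the onnx package is installed (pip install onnx):
import onnx
# Load the exported file and run the structural checker
onnx_model = onnx.load("text_classifier.onnx")
onnx.checker.check_model(onnx_model)
# Inspect the declared input shape: dynamic dims show their symbolic names
for inp in onnx_model.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(f"{inp.name}: {dims}")  # e.g. input_ids: ['batch_size', 'seq_length']
print("ONNX model is structurally valid")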
Running Inference with ONNX Runtime
Once exported, you can load the ONNX model with onnxruntime — Microsoft's high-performance inference engine. ONNX Runtime applies graph optimizations (constant folding, operator fusion) and can target CPU, GPU, or specialized hardware accelerators. Here's how to load and run the exported model.
import numpy as np
# pip install onnxruntime
import onnxruntime as ort
# Create an inference session
session = ort.InferenceSession("text_classifier.onnx")
# Print input/output metadata
for inp in session.get_inputs():
print(f"Input: {inp.name}, shape: {inp.shape}, type: {inp.type}")
for out in session.get_outputs():
print(f"Output: {out.name}, shape: {out.shape}, type: {out.type}")
# Run inference with numpy arrays (not PyTorch tensors)
# Dynamic axes: test with batch=4, seq_len=15
test_input = np.random.randint(0, 10000, size=(4, 15)).astype(np.int64)
outputs = session.run(None, {"input_ids": test_input})
logits = outputs[0]
predictions = np.argmax(logits, axis=1)
print(f"\nBatch predictions: {predictions}")
print(f"Logits shape: {logits.shape}") # (4, 5)
ONNX Runtime accepts NumPy arrays, not PyTorch tensors. This is by design — ONNX Runtime is framework-agnostic. In many benchmarks, ONNX Runtime inference is 2–5× faster than native PyTorch eager mode because of its aggressive graph optimizations.
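If you want explicit control over those optimizations, ONNX Runtime exposes session options. A short sketch (the defaults already enable most graph optimizations; the thread count here is only an example):
import numpy as np
import onnxruntime as ort
# Configure the session explicitly: full graph optimizations, limited CPU threads
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 4
session = ort.InferenceSession(
    "text_classifier.onnx",
    sess_options,
    providers=["CPUExecutionProvider"],  # "CUDAExecutionProvider" on GPU builds
)
test_input = np.random.randint(0, 10000, size=(2, 10)).astype(np.int64)
logits = session.run(None, {"input_ids": test_input})[0]
print("Logits shape:", logits.shape)  # (2, 5)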
Model Quantization
Quantization reduces the numerical precision of model weights and activations — typically from 32-bit floating point (FP32) to 8-bit integer (INT8). This sounds like it would destroy accuracy, but modern quantization techniques are remarkably good at preserving model quality while delivering dramatic improvements in speed and memory usage.
flowchart TD
A[Trained FP32 Model] --> B{Quantization Strategy}
B -->|Simplest| C[Dynamic Quantization]
B -->|Better accuracy| D[Static Quantization]
B -->|Highest quality| E[Quantization-Aware Training]
C --> F[Weights: INT8 at load time<br/>Activations: INT8 at runtime]
D --> G[Calibration dataset needed<br/>Both weights + activations pre-quantized]
E --> H[Fake quantization during training<br/>Model learns to be robust to INT8]
F --> I[Deploy INT8 Model]
G --> I
H --> I
Dynamic Quantization
Dynamic quantization is the easiest quantization method — it requires no calibration data and can be applied in a single line of code. Weights are quantized to INT8 ahead of time, while activations are quantized dynamically at runtime. This works especially well for models dominated by Linear layers (transformers, LSTMs, fully connected networks).
import torch
import torch.nn as nn
import os
# Define a model (simulating a small transformer-like architecture)
class SmallModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10),
)
def forward(self, x):
return self.layers(x)
model = SmallModel()
model.eval()
# One-line dynamic quantization
quantized_model = torch.ao.quantization.quantize_dynamic(
model,
{nn.Linear}, # which layer types to quantize
dtype=torch.qint8 # target dtype
)
# Compare model sizes
torch.save(model.state_dict(), "original.pt")
torch.save(quantized_model.state_dict(), "quantized.pt")
orig_size = os.path.getsize("original.pt") / 1024
quant_size = os.path.getsize("quantized.pt") / 1024
print(f"Original model size: {orig_size:.1f} KB")
print(f"Quantized model size: {quant_size:.1f} KB")
print(f"Compression ratio: {orig_size / quant_size:.2f}x")
# Verify outputs are close
test_input = torch.randn(1, 512)
with torch.no_grad():
orig_output = model(test_input)
quant_output = quantized_model(test_input)
diff = (orig_output - quant_output).abs().mean().item()
print(f"\nMean absolute difference: {diff:.6f}")
# Clean up temp files
os.remove("original.pt")
os.remove("quantized.pt")
The key line is quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8). You tell PyTorch which layer types to quantize (typically nn.Linear and nn.LSTM) and the target integer type. The resulting model runs entirely on the CPU with INT8 arithmetic for the specified layers.
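The same one-liner also covers recurrent models. As a quick sketch, a small hypothetical LSTM classifier gets both its nn.LSTM and nn.Linear layers quantized:
import torch
import torch.nn as nn
# Hypothetical LSTM-based classifier used only to illustrate the call
class TinyLSTMTagger(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
        self.head = nn.Linear(128, 10)
    def forward(self, x):
        out, _ = self.lstm(x)          # (batch, seq, hidden)
        return self.head(out[:, -1])   # classify from the last time step
model = TinyLSTMTagger()
model.eval()
# Quantize both layer types in one call
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
x = torch.randn(2, 16, 64)  # batch=2, seq_len=16
with torch.no_grad():
    print(quantized(x).shape)  # torch.Size([2, 10])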
Post-Training Static Quantization
Static quantization goes one step further: it pre-computes optimal quantization parameters (scale and zero-point) for both weights and activations using a representative calibration dataset. This typically yields better accuracy than dynamic quantization because the activation ranges are determined in advance rather than estimated at runtime.
import torch
import torch.nn as nn
# Model with QuantStub/DeQuantStub markers
class QuantizableModel(nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.layers = nn.Sequential(
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 10),
)
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x) # float → quantized
x = self.layers(x)
x = self.dequant(x) # quantized → float
return x
model = QuantizableModel()
model.eval()
# Step 1: Specify quantization config (fbgemm for x86, qnnpack for ARM)
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
# Step 2: Prepare — inserts observer modules that track value ranges
prepared_model = torch.ao.quantization.prepare(model)
# Step 3: Calibrate — run representative data through the model
# The observers collect activation statistics
print("Calibrating with representative data...")
for _ in range(100):
calibration_data = torch.randn(32, 256)
prepared_model(calibration_data)
# Step 4: Convert — replace observers with actual quantized operations
quantized_model = torch.ao.quantization.convert(prepared_model)
# Test inference
test_input = torch.randn(1, 256)
with torch.no_grad():
output = quantized_model(test_input)
print(f"Output shape: {output.shape}")
print(f"Prediction: class {output.argmax(dim=1).item()}")
print("Static quantization complete!")
The four steps are always the same: (1) set qconfig, (2) prepare to insert observers, (3) calibrate by running representative data, (4) convert to produce the final INT8 model. Use "fbgemm" for Intel/AMD CPUs and "qnnpack" for ARM processors (mobile).
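One related detail: the runtime kernel backend should match the qconfig. A minimal sketch of setting torch.backends.quantized.engine alongside the qconfig string:
import torch
# "fbgemm" targets x86 servers/desktops; "qnnpack" targets ARM (mobile/edge).
# The value here is an example; pick it based on your deployment hardware.
backend = "fbgemm"
torch.backends.quantized.engine = backend
qconfig = torch.ao.quantization.get_default_qconfig(backend)
print(f"Quantized engine: {torch.backends.quantized.engine}")
print(f"QConfig: {qconfig}")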
Pruning
Pruning removes unnecessary weights from a neural network — setting them to zero so they don't contribute to the output. The insight behind pruning is that most trained networks are over-parameterized: a large fraction of their weights are very close to zero and contribute almost nothing to the model's predictions. By removing these weights, you get a smaller, faster model with minimal accuracy loss.
There are two main types of pruning. Unstructured pruning zeroes out individual weights wherever they're smallest, creating a sparse weight matrix. Structured pruning removes entire neurons, filters, or channels, creating a genuinely smaller model that runs faster on standard hardware without specialized sparse computation libraries.
Unstructured Magnitude Pruning
The simplest pruning strategy is magnitude pruning: sort all weights by absolute value and set the smallest ones to zero. PyTorch provides torch.nn.utils.prune with several built-in pruning methods. Here we prune 40% of weights in each Linear layer based on their L1 magnitude.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Create a model
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
model = MLP()
# Apply L1 unstructured pruning — remove 40% of weights
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
prune.l1_unstructured(module, name="weight", amount=0.4)
# Check sparsity
def compute_sparsity(model):
total = 0
zeros = 0
for name, param in model.named_parameters():
if "weight" in name:
total += param.numel()
zeros += (param == 0).sum().item()
return zeros / total * 100
print(f"Model sparsity: {compute_sparsity(model):.1f}%")
# Inspect pruning mask for fc1
print(f"\nfc1 weight mask shape: {model.fc1.weight_mask.shape}")
print(f"fc1 mask zeros: {(model.fc1.weight_mask == 0).sum().item()}")
print(f"fc1 mask ones: {(model.fc1.weight_mask == 1).sum().item()}")
# Make pruning permanent (remove the mask, apply it to weight)
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
prune.remove(module, "weight")
print(f"\nFinal sparsity after making permanent: {compute_sparsity(model):.1f}%")
The prune.l1_unstructured call doesn't actually delete weights — it creates a mask (weight_mask) that multiplies with the weight tensor, effectively zeroing out pruned weights while keeping the tensor shape the same. Calling prune.remove makes the pruning permanent by applying the mask directly to the weight and removing the mask parameter.
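You can verify this bookkeeping directly. A small sketch showing that after pruning the module holds weight_orig (a parameter) and weight_mask (a buffer), and that the effective weight is their product:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Prune half the weights of a tiny layer
layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)
# The original weights move to weight_orig; the mask is registered as a buffer
print("Parameters:", [n for n, _ in layer.named_parameters()])  # ['bias', 'weight_orig']
print("Buffers:", [n for n, _ in layer.named_buffers()])        # ['weight_mask']
# The effective weight equals weight_orig * weight_mask
recomputed = layer.weight_orig * layer.weight_mask
print("weight == weight_orig * weight_mask:",
      torch.allclose(layer.weight, recomputed))  # True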
Structured Pruning
Structured pruning removes entire rows or columns from weight matrices — equivalent to removing whole neurons or filters. Unlike unstructured pruning, this actually reduces the model's computation graph and speeds up inference on standard (non-sparse) hardware.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Create a model
model = nn.Sequential(
nn.Linear(100, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10),
)
# Structured pruning: remove 30% of output neurons (rows) from first layer
# based on L2 norm of each row
prune.ln_structured(model[0], name="weight", amount=0.3, n=2, dim=0)
# Check which neurons were pruned
mask = model[0].weight_mask
neurons_kept = mask.any(dim=1).sum().item()
neurons_pruned = (~mask.any(dim=1)).sum().item()
print(f"Neurons kept: {neurons_kept}, pruned: {neurons_pruned}")
print(f"Original shape: {model[0].weight.shape}") # Still (64, 100)
# Global pruning — prune 50% of ALL weights across the model
parameters_to_prune = [
(model[0], "weight"),
(model[2], "weight"),
(model[4], "weight"),
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.5, # 50% of ALL weights globally
)
# Check global sparsity
total_zeros = sum(
(getattr(m, "weight_mask", torch.ones(1)) == 0).sum().item()
for m in [model[0], model[2], model[4]]
)
total_params = sum(m.weight.numel() for m in [model[0], model[2], model[4]])
print(f"\nGlobal sparsity: {total_zeros / total_params * 100:.1f}%")
Global pruning (prune.global_unstructured) is often more effective than per-layer pruning because it prunes the globally least important weights regardless of which layer they're in. This means layers with more redundancy contribute more pruned weights, while critical layers keep more of their capacity.
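A short sketch makes this visible: after global pruning, per-layer sparsity typically deviates from the 50% global target (the exact split depends on the random initialization):
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
model = nn.Sequential(
    nn.Linear(100, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 10),
)
layers = [model[0], model[2], model[4]]
# Prune 50% of all weights, chosen globally by L1 magnitude
prune.global_unstructured(
    [(m, "weight") for m in layers],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)
# Per-layer sparsity: layers with more near-zero weights absorb more of the pruning
for i, m in enumerate(layers):
    sparsity = (m.weight == 0).float().mean().item() * 100
    print(f"Layer {i}: {sparsity:.1f}% of weights pruned")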
Knowledge Distillation
Knowledge distillation is a technique where a small student model learns to mimic a large teacher model. Instead of training the student on hard labels (one-hot vectors like [0, 0, 1, 0]), you train it on the teacher's soft predictions — the full probability distribution over all classes. These soft targets carry much more information than hard labels because they encode the teacher's knowledge about inter-class relationships.
Why Soft Targets Work
Consider a digit classifier. A hard label for "7" is simply [0,0,0,0,0,0,0,1,0,0]. But a teacher's soft prediction might be [0.001, 0.01, 0.05, 0.001, 0.001, 0.001, 0.001, 0.88, 0.001, 0.05]. This tells the student that "7" looks a little like "2" and "9" — valuable structural knowledge that hard labels completely miss. The temperature parameter controls how "soft" these probabilities become.
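A quick illustration with made-up logits shows the effect of temperature: dividing the logits by a larger T flattens the softmax output:
import torch
import torch.nn.functional as F
# Hypothetical teacher logits for a 4-class problem
logits = torch.tensor([[2.0, 5.0, 0.5, 9.0]])
for T in [1.0, 4.0, 10.0]:
    probs = F.softmax(logits / T, dim=1)
    print(f"T={T:>4}: {[round(p, 3) for p in probs[0].tolist()]}")
# Higher T spreads probability mass across classes, exposing how similar
# the non-target classes are to the prediction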
Distillation Training Loop
The distillation loss combines two objectives: (1) the standard cross-entropy between the student's predictions and the true labels, and (2) the KL divergence between the student's and teacher's softened probability distributions. The temperature parameter controls how much to soften the distributions — higher temperatures produce softer probabilities that reveal more inter-class structure.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Teacher: large model (pretrained, frozen)
teacher = nn.Sequential(
nn.Linear(784, 512), nn.ReLU(),
nn.Linear(512, 256), nn.ReLU(),
nn.Linear(256, 10),
)
# Student: small model (to be trained)
student = nn.Sequential(
nn.Linear(784, 64), nn.ReLU(),
nn.Linear(64, 10),
)
def distillation_loss(student_logits, teacher_logits, true_labels,
temperature=4.0, alpha=0.7):
"""
Combined distillation loss:
- alpha * KL divergence between soft predictions (teacher knowledge)
- (1-alpha) * cross-entropy with true labels (ground truth)
"""
# Soft targets from teacher and student
soft_student = F.log_softmax(student_logits / temperature, dim=1)
soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
# KL divergence (multiply by T^2 to match gradient magnitudes)
distill_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean")
distill_loss *= temperature ** 2
# Standard cross-entropy with true labels
hard_loss = F.cross_entropy(student_logits, true_labels)
return alpha * distill_loss + (1 - alpha) * hard_loss
# Training example
teacher.eval() # Teacher is frozen
student.train()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
# Simulate training
for epoch in range(3):
# Fake batch: 64 samples of 784 features, 10 classes
inputs = torch.randn(64, 784)
labels = torch.randint(0, 10, (64,))
# Get teacher's predictions (no gradients needed)
with torch.no_grad():
teacher_logits = teacher(inputs)
# Get student's predictions
student_logits = student(inputs)
# Compute combined loss
loss = distillation_loss(student_logits, teacher_logits, labels,
temperature=4.0, alpha=0.7)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Distillation Loss: {loss.item():.4f}")
# Compare model sizes
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"\nTeacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"Compression ratio: {teacher_params / student_params:.1f}x")
The alpha parameter balances the two loss components. Setting alpha=0.7 means 70% of the learning signal comes from mimicking the teacher and 30% from the true labels. Temperature values between 2 and 10 are typical — higher temperatures reveal more subtle inter-class relationships but can make the signal too diffuse if set too high.
Model Profiling
Before optimizing, you need to know where your model is slow. PyTorch's built-in profiler (torch.profiler) measures the execution time and memory consumption of every operation in your model, broken down by CPU and GPU. This is essential for making informed optimization decisions — don't guess, measure.
Using torch.profiler
The profiler wraps your inference code and records detailed timing information. You can then print a summary table sorted by CPU or GPU time to identify the most expensive operations. Here's how to profile a model and read the results.
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
# Define a model
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.classifier = nn.Linear(64, 10)
def forward(self, x):
with record_function("FEATURE_EXTRACTION"):
x = self.features(x)
x = x.view(x.size(0), -1)
with record_function("CLASSIFICATION"):
x = self.classifier(x)
return x
model = ConvNet()
model.eval()
# Profile CPU operations
inputs = torch.randn(8, 3, 32, 32)
with profile(
activities=[ProfilerActivity.CPU],
record_shapes=True,
profile_memory=True,
) as prof:
with torch.no_grad():
for _ in range(10):
output = model(inputs)
# Print top operations sorted by CPU time
print(prof.key_averages().table(
sort_by="cpu_time_total",
row_limit=10
))
# The custom-labeled regions (FEATURE_EXTRACTION, CLASSIFICATION) appear as
# rows in the same table; a shorter view makes them easier to spot
print("\nCustom regions:")
print(prof.key_averages().table(
sort_by="cpu_time_total",
row_limit=5
))
The record_function context manager lets you label arbitrary code regions with custom names (like "FEATURE_EXTRACTION" and "CLASSIFICATION"), so you can see their aggregate cost in the profiler output. The profile_memory=True flag also tracks memory allocations, helping you find memory-hungry operations.
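The profiler can also export a timeline for visual inspection. A minimal sketch (the file name is arbitrary); the resulting JSON opens in chrome://tracing or Perfetto:
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity
# Small toy model just to generate some profiler events
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
x = torch.randn(16, 256)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        for _ in range(5):
            model(x)
# Write a Chrome-trace timeline of the recorded operations
prof.export_chrome_trace("inference_trace.json")
print("Trace written to inference_trace.json")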
Memory Snapshot
Understanding GPU memory usage is critical when deploying large models. PyTorch tracks every allocation and deallocation, and you can query the allocator's statistics at any point to see how much memory your model's weights and activations are consuming.
import torch
import torch.nn as nn
# Check if CUDA is available (skip GPU parts if not)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
# Reset memory stats
torch.cuda.reset_peak_memory_stats()
model = nn.Sequential(
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Linear(1024, 10),
).to(device)
# Run a forward pass
x = torch.randn(32, 1024, device=device)
with torch.no_grad():
output = model(x)
# Memory statistics
allocated = torch.cuda.memory_allocated(device) / 1024**2
reserved = torch.cuda.memory_reserved(device) / 1024**2
peak = torch.cuda.max_memory_allocated(device) / 1024**2
print(f"Currently allocated: {allocated:.2f} MB")
print(f"Currently reserved: {reserved:.2f} MB")
print(f"Peak allocated: {peak:.2f} MB")
else:
# CPU memory tracking alternative
model = nn.Sequential(
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Linear(1024, 10),
)
# Count parameters by layer
for name, param in model.named_parameters():
mem_mb = param.numel() * param.element_size() / 1024**2
print(f"{name}: {param.shape} = {param.numel():,} params ({mem_mb:.3f} MB)")
total_params = sum(p.numel() for p in model.parameters())
total_mem = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
print(f"\nTotal: {total_params:,} params ({total_mem:.2f} MB)")
The difference between memory_allocated and memory_reserved is important: PyTorch's CUDA memory allocator reserves large blocks from the GPU and sub-allocates from them. reserved is the total GPU memory held by PyTorch's allocator, while allocated is the memory actually used by tensors. The gap is the allocator's free pool.
Serving with Flask/FastAPI
The most common way to deploy a PyTorch model is to wrap it in a REST API — an HTTP endpoint that accepts input data (typically JSON or image bytes), runs the model, and returns predictions. This lets any application (web, mobile, IoT) consume your model over the network without needing PyTorch installed on the client.
FastAPI Model Server
Here's a production-ready pattern for serving a PyTorch model with FastAPI. The model is loaded once at startup and reused for every request. Input validation ensures malformed requests are rejected before they reach the model. The torch.no_grad() context prevents gradient tracking, which would waste memory during inference.
# save as: app.py
# run with: uvicorn app:app --host 0.0.0.0 --port 8000
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, validator
from typing import List
# --- Model Definition ---
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(4, 32),
nn.ReLU(),
nn.Linear(32, 3),
)
def forward(self, x):
return self.layers(x)
# --- Load model at startup ---
app = FastAPI(title="PyTorch Model API")
model = SimpleNet()
model.eval()
# In production: model = torch.jit.load("model.pt")
# --- Request/Response schemas ---
class PredictionRequest(BaseModel):
features: List[float]
# Pydantic v1-style validator (Pydantic v2 uses field_validator instead)
@validator("features")
def check_length(cls, v):
if len(v) != 4:
raise ValueError("features must have exactly 4 values")
return v
class PredictionResponse(BaseModel):
predicted_class: int
confidence: float
probabilities: List[float]
# --- Endpoints ---
@app.get("/health")
def health_check():
return {"status": "healthy", "model_loaded": True}
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
try:
# Convert input to tensor
input_tensor = torch.tensor([request.features], dtype=torch.float32)
# Run inference
with torch.no_grad():
logits = model(input_tensor)
probs = torch.softmax(logits, dim=1)
predicted_class = probs.argmax(dim=1).item()
confidence = probs[0, predicted_class].item()
return PredictionResponse(
predicted_class=predicted_class,
confidence=round(confidence, 4),
probabilities=[round(p, 4) for p in probs[0].tolist()],
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Test locally (without running the server)
if __name__ == "__main__":
req = PredictionRequest(features=[5.1, 3.5, 1.4, 0.2])
result = predict(req)
print(f"Predicted class: {result.predicted_class}")
print(f"Confidence: {result.confidence}")
print(f"Probabilities: {result.probabilities}")
Key production patterns in this code: Pydantic validation rejects invalid inputs before they reach the model; the health endpoint lets load balancers check if the server is alive; the response schema guarantees a consistent API contract; and try/except catches unexpected errors without crashing the server.
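For completeness, here is a hypothetical client call against this endpoint, assuming the server is running locally on port 8000 and the requests package is installed:
# Hypothetical client; assumes `uvicorn app:app --port 8000` is running
import requests
payload = {"features": [5.1, 3.5, 1.4, 0.2]}
response = requests.post("http://localhost:8000/predict", json=payload, timeout=5)
if response.status_code == 200:
    result = response.json()
    print(f"Class: {result['predicted_class']}, confidence: {result['confidence']}")
else:
    # Pydantic validation failures come back as 422 with details in the body
    print(f"Request failed ({response.status_code}): {response.text}")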
Batch Inference Pattern
In production, you often want to process multiple inputs in a single request for efficiency. Batching amortizes the overhead of tensor creation and GPU kernel launches across many samples. Here's how to add a batch endpoint.
import torch
import torch.nn as nn
from typing import List
# Reuse model from above
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 3),
)
def forward(self, x):
return self.layers(x)
model = SimpleNet()
model.eval()
def batch_predict(batch_features: List[List[float]]) -> List[dict]:
"""Process multiple inputs in a single forward pass."""
# Validate all inputs
for i, features in enumerate(batch_features):
if len(features) != 4:
raise ValueError(f"Sample {i}: expected 4 features, got {len(features)}")
# Stack into a single tensor (batch inference)
input_tensor = torch.tensor(batch_features, dtype=torch.float32)
with torch.no_grad():
logits = model(input_tensor)
probs = torch.softmax(logits, dim=1)
classes = probs.argmax(dim=1)
# Format results
results = []
for i in range(len(batch_features)):
results.append({
"sample_index": i,
"predicted_class": classes[i].item(),
"confidence": round(probs[i, classes[i]].item(), 4),
})
return results
# Test batch inference
batch = [
[5.1, 3.5, 1.4, 0.2],
[6.2, 2.9, 4.3, 1.3],
[7.7, 3.0, 6.1, 2.3],
]
results = batch_predict(batch)
for r in results:
print(f"Sample {r['sample_index']}: class={r['predicted_class']}, "
f"confidence={r['confidence']}")
Batch inference is significantly faster than processing one sample at a time because modern hardware (especially GPUs) is designed for parallel computation. A single forward pass with a batch of 32 is typically much faster than 32 individual forward passes.
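A rough, hardware-dependent sketch of that difference, timing 256 single-sample calls against one batched call on the same toy model:
import time
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 3)).eval()
samples = [torch.randn(1, 4) for _ in range(256)]
batch = torch.cat(samples, dim=0)  # (256, 4)
with torch.no_grad():
    # One forward pass per sample
    start = time.perf_counter()
    for s in samples:
        model(s)
    loop_time = time.perf_counter() - start
    # Single batched forward pass
    start = time.perf_counter()
    model(batch)
    batch_time = time.perf_counter() - start
print(f"256 single calls: {loop_time * 1000:.2f} ms")
print(f"1 batched call:   {batch_time * 1000:.2f} ms")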
Mobile & Edge Deployment
Running models on mobile phones and edge devices (Raspberry Pi, NVIDIA Jetson, microcontrollers) has unique constraints: limited RAM, no GPU (or a weak one), limited battery, and strict latency requirements. PyTorch Mobile provides tools to optimize and package models specifically for these environments.
The key challenge is model size. A typical ResNet-50 is ~100 MB in FP32 — too large for a mobile app download. Combining TorchScript export with quantization and the mobile optimizer can shrink this to under 10 MB while maintaining acceptable accuracy.
Mobile Optimization Pipeline
PyTorch provides torch.utils.mobile_optimizer which applies a series of mobile-specific optimizations to a TorchScript model: fusing Conv+BN+ReLU operations, removing dropout layers, and optimizing memory layout for ARM processors.
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
# Create a CNN suitable for mobile
class MobileClassifier(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.classifier = nn.Linear(32, 10)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
model = MobileClassifier()
model.eval()
# Step 1: Trace the model
sample = torch.randn(1, 3, 64, 64)
traced = torch.jit.trace(model, sample)
# Step 2: Apply mobile optimizations
optimized = optimize_for_mobile(traced)
# Step 3: Save for mobile deployment
optimized._save_for_lite_interpreter("model_mobile.ptl")
# Compare file sizes
import os
traced.save("model_full.pt")
full_size = os.path.getsize("model_full.pt") / 1024
mobile_size = os.path.getsize("model_mobile.ptl") / 1024
print(f"Full TorchScript: {full_size:.1f} KB")
print(f"Mobile optimized: {mobile_size:.1f} KB")
print(f"Size reduction: {(1 - mobile_size/full_size)*100:.1f}%")
# Verify output matches
with torch.no_grad():
orig_out = traced(sample)
opt_out = optimized(sample)
print(f"\nOutputs match: {torch.allclose(orig_out, opt_out, atol=1e-5)}")
# Clean up
os.remove("model_full.pt")
os.remove("model_mobile.ptl")
The _save_for_lite_interpreter method produces a .ptl file optimized for PyTorch's lightweight mobile interpreter. This file excludes unnecessary metadata and is designed to be loaded by the lite interpreter runtime in the PyTorch Mobile SDKs for iOS and Android. Combined with INT8 quantization, this pipeline can reduce a 100 MB model to under 5 MB.
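The pieces compose. Here is a sketch, assuming a Linear-heavy model so that dynamic quantization applies, that combines INT8 quantization with the mobile pipeline:
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
# Sketch only: a Linear-heavy toy model; conv models would use static quantization
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.eval()
# 1. Quantize the Linear layers to INT8
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
# 2. Convert to TorchScript, then apply mobile optimizations
scripted = torch.jit.script(quantized)
mobile_ready = optimize_for_mobile(scripted)
# 3. Save for the lite interpreter
mobile_ready._save_for_lite_interpreter("model_int8_mobile.ptl")
print("Saved INT8 mobile model")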
Production Best Practices
Deploying a model isn't a one-time event — it's the beginning of a lifecycle. Models degrade over time as real-world data drifts away from the training distribution. Infrastructure fails. Dependencies update. This section covers the practices that keep production ML systems reliable.
Reproducibility
Reproducibility means getting the exact same results every time you run your code. This is critical for debugging, auditing, and regulatory compliance. PyTorch's randomness comes from multiple sources (Python's random module, NumPy, PyTorch's CPU generator, and CUDA), and you need to seed all of them.
import torch
import numpy as np
import random
import os
def set_seed(seed: int = 42):
"""Set all random seeds for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Make CuDNN deterministic (may reduce performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Enable deterministic algorithms (PyTorch 1.8+)
# With warn_only=True this warns (rather than raising) when a non-deterministic algorithm is used
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True, warn_only=True)
# Set seed and verify reproducibility
set_seed(42)
a = torch.randn(3, 3)
set_seed(42)
b = torch.randn(3, 3)
print(f"Tensors identical: {torch.equal(a, b)}") # True
print(f"Tensor values:\n{a}")
# Model weight initialization is also reproducible
set_seed(42)
model1 = torch.nn.Linear(10, 5)
set_seed(42)
model2 = torch.nn.Linear(10, 5)
print(f"\nModel weights identical: {torch.equal(model1.weight, model2.weight)}")
Setting torch.backends.cudnn.deterministic = True forces cuDNN to use deterministic convolution algorithms. This guarantees bit-exact results but may be 10–20% slower. For training, you might want to leave this off for speed; for evaluation and production, determinism is usually worth the cost.
Monitoring & Model Drift Detection
Once deployed, your model will encounter data it has never seen during training. Data drift occurs when the distribution of incoming data changes over time — and when it does, your model's predictions become unreliable. Monitoring is your early warning system. Track prediction distributions, input statistics, and confidence scores to detect when the model needs retraining.
import torch
import numpy as np
from collections import deque
from datetime import datetime
class ModelMonitor:
"""Track prediction statistics for drift detection."""
def __init__(self, window_size=1000):
self.predictions = deque(maxlen=window_size)
self.confidences = deque(maxlen=window_size)
self.input_means = deque(maxlen=window_size)
self.baseline_conf = None
self.baseline_dist = None
def log_prediction(self, input_tensor, logits):
"""Record a prediction for monitoring."""
probs = torch.softmax(logits, dim=1)
pred_class = probs.argmax(dim=1).item()
confidence = probs.max(dim=1).values.item()
self.predictions.append(pred_class)
self.confidences.append(confidence)
self.input_means.append(input_tensor.mean().item())
def set_baseline(self):
"""Snapshot current statistics as the baseline."""
self.baseline_conf = np.mean(list(self.confidences))
counts = np.bincount(list(self.predictions))
self.baseline_dist = counts / counts.sum()
print(f"Baseline set: avg_confidence={self.baseline_conf:.4f}")
print(f"Baseline class distribution: {self.baseline_dist}")
def check_drift(self, threshold=0.1):
"""Compare current stats to baseline."""
if self.baseline_conf is None:
return {"status": "no_baseline"}
current_conf = np.mean(list(self.confidences))
conf_drift = abs(current_conf - self.baseline_conf)
alert = conf_drift > threshold
return {
"timestamp": datetime.now().isoformat(),
"baseline_confidence": round(self.baseline_conf, 4),
"current_confidence": round(current_conf, 4),
"confidence_drift": round(conf_drift, 4),
"drift_detected": alert,
"samples_tracked": len(self.confidences),
}
# Demo usage
monitor = ModelMonitor(window_size=500)
# Simulate a simple model
model = torch.nn.Linear(10, 3)
model.eval()
# Phase 1: Normal data (establish baseline)
print("=== Phase 1: Baseline Data ===")
for _ in range(200):
x = torch.randn(1, 10)
with torch.no_grad():
logits = model(x)
monitor.log_prediction(x, logits)
monitor.set_baseline()
# Phase 2: Drifted data (shifted distribution)
print("\n=== Phase 2: Drifted Data ===")
for _ in range(200):
x = torch.randn(1, 10) + 5.0 # shift input distribution
with torch.no_grad():
logits = model(x)
monitor.log_prediction(x, logits)
report = monitor.check_drift(threshold=0.05)
for key, value in report.items():
print(f" {key}: {value}")
This ModelMonitor tracks three key signals: prediction distribution (are certain classes being predicted much more or less than during baseline?), confidence levels (dropping confidence often signals out-of-distribution data), and input statistics (shifted feature means/variances indicate data drift). In production, these metrics would be reported to a dashboard like Grafana or Datadog.
Model Versioning & A/B Testing
Production systems need to track which model version produced each prediction (for debugging and auditing) and safely roll out new models without disrupting service. Here's a simple model registry pattern that supports loading specific versions and comparing performance.
import torch
import os
import json
from datetime import datetime
class ModelRegistry:
"""Simple file-based model version registry."""
def __init__(self, registry_dir="model_registry"):
self.registry_dir = registry_dir
os.makedirs(registry_dir, exist_ok=True)
def register(self, model, version, metrics=None):
"""Save a model version with metadata."""
version_dir = os.path.join(self.registry_dir, f"v{version}")
os.makedirs(version_dir, exist_ok=True)
# Save model
model_path = os.path.join(version_dir, "model.pt")
torch.save(model.state_dict(), model_path)
# Save metadata
metadata = {
"version": version,
"registered_at": datetime.now().isoformat(),
"parameters": sum(p.numel() for p in model.parameters()),
"metrics": metrics or {},
}
meta_path = os.path.join(version_dir, "metadata.json")
with open(meta_path, "w") as f:
json.dump(metadata, f, indent=2)
print(f"Registered model v{version}: {metadata['parameters']:,} params")
return metadata
def list_versions(self):
"""List all registered versions."""
versions = sorted(os.listdir(self.registry_dir))
for v in versions:
meta_path = os.path.join(self.registry_dir, v, "metadata.json")
if os.path.exists(meta_path):
with open(meta_path) as f:
meta = json.load(f)
print(f" {v}: {meta['parameters']:,} params, "
f"registered {meta['registered_at'][:10]}")
# Demo: register two model versions
import torch.nn as nn
registry = ModelRegistry("./demo_registry")
# Version 1: small model
model_v1 = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
registry.register(model_v1, version="1.0",
metrics={"accuracy": 0.85, "latency_ms": 2.1})
# Version 2: larger model
model_v2 = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 3))
registry.register(model_v2, version="2.0",
metrics={"accuracy": 0.91, "latency_ms": 3.8})
# List all versions
print("\nRegistered models:")
registry.list_versions()
# Clean up demo files
import shutil
shutil.rmtree("./demo_registry")
In practice, you'd use a dedicated model registry like MLflow, DVC, or Weights & Biases. The pattern above demonstrates the core concept: every model deployment is versioned, tagged with metrics, and traceable back to the training run that produced it.
Conclusion & Series Recap
Congratulations — you've completed the PyTorch Mastery series! In this final part, you learned how to bridge the gap between a trained model and a production system. Let's recap the key tools in your deployment toolkit:
Your Production Arsenal
- TorchScript — Export models for Python-free inference; choose tracing for simple models, scripting for models with control flow
- ONNX — Universal model format for cross-framework deployment; pair with ONNX Runtime for optimized inference
- Quantization — Shrink models 4× with minimal accuracy loss; dynamic quantization is the easiest starting point
- Pruning — Remove unimportant weights; global pruning is more effective than per-layer
- Knowledge Distillation — Train a small student model to mimic a large teacher using soft targets
- Profiling — Measure before optimizing; torch.profiler reveals exactly where time and memory are spent
- REST APIs — Serve models with FastAPI for production-grade endpoints with validation and documentation
- Mobile — Combine TorchScript + quantization + mobile optimizer for on-device inference
- Best Practices — Reproducibility, monitoring, versioning, and drift detection keep systems reliable