What are GANs?
Imagine a counterfeiter who tries to produce fake currency that looks identical to real bills, and a detective whose job is to catch the fakes. As the counterfeiter gets better at forging, the detective must sharpen their skills to tell real from fake. Over time, this adversarial competition pushes both to become experts — the counterfeiter produces nearly perfect fakes, and the detective becomes incredibly discerning.
This is exactly how Generative Adversarial Networks (GANs) work. Introduced by Ian Goodfellow and colleagues in 2014, GANs consist of two neural networks — a Generator (the counterfeiter) and a Discriminator (the detective) — locked in a competitive game. The Generator creates synthetic data samples, while the Discriminator tries to distinguish real data from generated fakes. Through this adversarial training, the Generator learns to produce remarkably realistic outputs.
Why Generative Models Matter
Generative models like GANs have transformed numerous fields:
- Data Augmentation: Generate synthetic training data when real data is scarce (medical imaging, rare defects)
- Creative AI: Art generation, style transfer, deepfakes, virtual try-on
- Drug Discovery: Generate novel molecular structures with desired properties
- Super-Resolution: Enhance low-resolution images to high-resolution (ESRGAN)
- Image-to-Image Translation: Convert sketches to photos, day to night, summer to winter
Let's start by verifying our PyTorch setup and understanding the basic building blocks we'll need throughout this deep dive.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Check device availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Basic GAN concept: Generator maps noise to data space
latent_dim = 100  # Dimension of the noise vector z
noise = torch.randn(4, latent_dim, device=device)
print(f"Noise vector shape: {noise.shape}")
print(f"Noise statistics - Mean: {noise.mean():.4f}, Std: {noise.std():.4f}")
```
The code above establishes our working environment and introduces the fundamental concept: GANs map random noise vectors (sampled from a simple distribution like a standard normal) into the data space (e.g., realistic images). The latent dimension of 100 is a common choice — it's large enough to encode diverse outputs but small enough to train efficiently.
The Adversarial Game
The GAN framework is formalized as a two-player minimax game over a single objective: the Discriminator $D$ tries to maximize it, while the Generator $G$ tries to minimize it. The two networks have opposing goals:
- Generator ($G$): Takes random noise $z$ and produces a fake sample $G(z)$. Its goal is to fool the Discriminator into classifying fakes as real.
- Discriminator ($D$): Takes either a real sample $x$ or a fake $G(z)$ and outputs a probability that the input is real. Its goal is to correctly classify real vs fake.
The Minimax Objective Function
The full GAN objective is expressed as:
$$\min_G \max_D \mathbb{E}_{x \sim p_\text{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$
Let's break this down piece by piece:
- $\mathbb{E}_{x \sim p_\text{data}}[\log D(x)]$: the Discriminator wants to output values close to 1 for real data, maximizing $\log D(x)$
- $\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$: for fake data $G(z)$, the Discriminator wants $D(G(z)) \approx 0$, which pushes $\log(1 - D(G(z)))$ up to its maximum value of $\log 1 = 0$
- The Generator wants $D(G(z)) \approx 1$, which drives $\log(1 - D(G(z)))$ toward $-\infty$ and so minimizes the objective
In practice, we use a non-saturating loss for the Generator to avoid vanishing gradients early in training:
$$\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]$$
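This reformulation has the same fixed point as the minimax game but gives much stronger gradients when the Generator is poor (early training). When $D(G(z))$ is near 0, $-\log D(G(z))$ is steep, whereas the original term $\log(1 - D(G(z)))$ is nearly flat with respect to the Discriminator's pre-sigmoid logit.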
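To see the difference concretely, the sketch below (an illustration added here, not part of the original derivation) writes $D(G(z)) = \sigma(a)$ for a logit $a$ and uses autograd to compare the gradients of the two Generator losses with respect to $a$ for a fake that the Discriminator confidently rejects. Analytically, $\partial_a \log(1 - \sigma(a)) = -\sigma(a)$ and $\partial_a[-\log \sigma(a)] = -(1 - \sigma(a))$, so both point the same way but differ hugely in size when $\sigma(a) \approx 0$.

```python
import torch

# Logit for a fake that D confidently rejects: D(G(z)) = sigmoid(-6) ≈ 0.0025
a = torch.tensor(-6.0, requires_grad=True)

# Original (saturating) generator loss: minimize log(1 - D(G(z)))
saturating = torch.log(1 - torch.sigmoid(a))
(grad_sat,) = torch.autograd.grad(saturating, a)

# Non-saturating generator loss: minimize -log D(G(z))
non_saturating = -torch.log(torch.sigmoid(a))
(grad_ns,) = torch.autograd.grad(non_saturating, a)

print(f"D(G(z)) = {torch.sigmoid(a).item():.4f}")
print(f"saturating grad:     {grad_sat.item():.4f}")  # ≈ -0.0025 (nearly flat)
print(f"non-saturating grad: {grad_ns.item():.4f}")   # ≈ -0.9975 (steep)
```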
Training Dynamics
The following diagram shows how the GAN training loop works. Notice the alternating optimization — we update the Discriminator and Generator in separate steps:
```mermaid
flowchart TD
    A[Sample real data x from dataset] --> B["Sample noise z from N(0,1)"]
    B --> C["Generate fake data: G(z)"]
    C --> D{Train Discriminator}
    D --> E["D classifies real x → target: 1"]
    D --> F["D classifies fake G(z) → target: 0"]
    E --> G[Compute D loss and update D weights]
    F --> G
    G --> H{Train Generator}
    H --> I["Generate new fakes G(z)"]
    I --> J["D classifies G(z) → G wants output: 1"]
    J --> K[Compute G loss and update G weights]
    K --> L{Convergence?}
    L -->|No| A
    L -->|Yes| M[Generator produces realistic samples]
```
The ideal endpoint is a Nash equilibrium where the Generator produces samples indistinguishable from real data, and the Discriminator outputs 0.5 for everything (it can't tell real from fake). In practice, reaching this equilibrium is challenging, which is why GAN training requires careful attention to stability.
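Goodfellow et al. show that, for a fixed Generator, the optimal Discriminator is $D^*(x) = p_\text{data}(x) / (p_\text{data}(x) + p_g(x))$, that the resulting value is $V(D^*, G) = -\log 4 + 2\,\mathrm{JSD}(p_\text{data} \,\|\, p_g)$, and that at equilibrium ($p_g = p_\text{data}$) this reaches its minimum of $-\log 4$ with $D^* = 0.5$ everywhere. The sketch below checks these facts numerically on a small discrete distribution; the toy probabilities are arbitrary, and real GANs never have explicit densities.

```python
import numpy as np

def value_function(d, p_data, p_g):
    """V(D, G) for discrete distributions: E_data[log D] + E_g[log(1 - D)]."""
    return np.sum(p_data * np.log(d)) + np.sum(p_g * np.log(1 - d))

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x))."""
    return p_data / (p_data + p_g)

p_data = np.array([0.1, 0.4, 0.3, 0.2])

# Case 1: the generator has not matched the data yet
p_g_bad = np.array([0.4, 0.1, 0.2, 0.3])
d_star = optimal_discriminator(p_data, p_g_bad)
print("D* (mismatched):", np.round(d_star, 3))
print("V(D*, G):", value_function(d_star, p_data, p_g_bad))

# Case 2: equilibrium, p_g == p_data
d_star_eq = optimal_discriminator(p_data, p_data)
print("D* (equilibrium):", d_star_eq)  # every entry is 0.5
print("V at equilibrium:", value_function(d_star_eq, p_data, p_data))
print("-log 4:", -np.log(4))
```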
```python
import torch
import torch.nn as nn

# Demonstrate the GAN loss functions
# Binary Cross-Entropy loss is the standard GAN loss
bce_loss = nn.BCELoss()

# Discriminator perspective:
# Real samples should be classified as 1
d_output_real = torch.tensor([0.9, 0.8, 0.95, 0.7])  # D's outputs for real data
real_labels = torch.ones(4)  # Target: all 1s
loss_real = bce_loss(d_output_real, real_labels)
print(f"D loss on real data: {loss_real.item():.4f}")

# Fake samples should be classified as 0
d_output_fake = torch.tensor([0.3, 0.1, 0.2, 0.4])  # D's outputs for fake data
fake_labels = torch.zeros(4)  # Target: all 0s
loss_fake = bce_loss(d_output_fake, fake_labels)
print(f"D loss on fake data: {loss_fake.item():.4f}")

# Total discriminator loss
d_loss = loss_real + loss_fake
print(f"Total D loss: {d_loss.item():.4f}")

# Generator perspective (non-saturating loss):
# Generator wants D to output 1 for its fakes
g_output = torch.tensor([0.3, 0.1, 0.2, 0.4])  # D's output for G's fakes
g_loss = bce_loss(g_output, real_labels)  # Target: 1 (fool D)
print(f"G loss: {g_loss.item():.4f}")
print("\nLower G loss = Generator is fooling Discriminator better")
```
This example demonstrates the loss computation from both perspectives. The Discriminator's loss penalizes misclassifications in both directions, while the Generator's loss (non-saturating form) rewards producing outputs that the Discriminator thinks are real. Notice we use real_labels (ones) as the target for the Generator's loss — this is because the Generator wants the Discriminator to believe its fakes are real.
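To connect `nn.BCELoss` back to the minimax objective: with target $y$, BCE computes $-[y \log p + (1-y)\log(1-p)]$, averaged over the batch. A target of 1 therefore gives $-\log D(x)$ and a target of 0 gives $-\log(1 - D(G(z)))$, so minimizing the Discriminator's summed BCE loss is the same as maximizing the batch estimate of the GAN objective. The check below reuses the example outputs from the previous block.

```python
import torch
import torch.nn as nn

bce_loss = nn.BCELoss()
d_output_real = torch.tensor([0.9, 0.8, 0.95, 0.7])
d_output_fake = torch.tensor([0.3, 0.1, 0.2, 0.4])

# BCE with target 1 is -mean(log D(x)); with target 0 it is -mean(log(1 - D(G(z))))
manual_real = -torch.log(d_output_real).mean()
manual_fake = -torch.log(1 - d_output_fake).mean()

d_loss = bce_loss(d_output_real, torch.ones(4)) + bce_loss(d_output_fake, torch.zeros(4))
minimax_objective = torch.log(d_output_real).mean() + torch.log(1 - d_output_fake).mean()

print(f"BCE real vs manual: {bce_loss(d_output_real, torch.ones(4)).item():.4f} "
      f"vs {manual_real.item():.4f}")
print(f"D loss vs -objective: {d_loss.item():.4f} vs {-minimax_objective.item():.4f}")
```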
Generator Architecture
The Generator's job is to transform a random noise vector $z \in \mathbb{R}^{100}$ into a realistic image (e.g., 28×28 pixels for MNIST or 64×64 for more complex datasets). This requires upsampling — gradually increasing spatial dimensions while decreasing channel depth.
The key building block is the Transposed Convolution (nn.ConvTranspose2d), sometimes misleadingly called "deconvolution." Unlike regular convolutions, which reduce spatial dimensions, transposed convolutions increase them: with stride $s > 1$, the operation can be viewed as inserting $s - 1$ zeros between input pixels (plus suitable padding) and then applying a standard convolution.
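For `nn.ConvTranspose2d` with no dilation, the PyTorch documentation gives the output size as $H_\text{out} = (H_\text{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size} + \text{output\_padding}$. The short check below compares that formula with real layers for the configurations used in this article: kernel 4, stride 2, padding 1 doubles the size, and kernel 4, stride 1, padding 0 turns a 1×1 input into 4×4. The helper `conv_transpose_out` is our own, not a PyTorch API.

```python
import torch
import torch.nn as nn

def conv_transpose_out(h_in, kernel_size, stride, padding, output_padding=0):
    """Output size of ConvTranspose2d (dilation = 1)."""
    return (h_in - 1) * stride - 2 * padding + kernel_size + output_padding

# (input size, kernel, stride, padding) configurations used in this article
configs = [(7, 4, 2, 1), (14, 4, 2, 1), (1, 4, 1, 0), (4, 4, 2, 1)]
for h_in, k, s, p in configs:
    layer = nn.ConvTranspose2d(8, 8, kernel_size=k, stride=s, padding=p)
    actual = layer(torch.randn(1, 8, h_in, h_in)).shape[-1]
    predicted = conv_transpose_out(h_in, k, s, p)
    print(f"{h_in:>2} -> {actual:>2} (formula: {predicted})")
```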
The standard Generator architecture follows this pattern:
- Reshape noise vector into a small spatial tensor (e.g., 4×4 with many channels)
- Apply transposed convolutions to progressively upsample (4→8→16→32→64)
- Use Batch Normalization after each layer (except the output) for training stability
- Use ReLU activation for all layers except the output (which uses Tanh)
Building the Generator from Scratch
The implementation below targets MNIST-sized images. Rather than starting from a 4×4 map, it uses a linear layer to project the 100-dimensional noise vector z onto a 256×7×7 feature map, then applies two ConvTranspose2d layers to upsample 7×7 → 14×14 → 28×28, with batch normalization and ReLU activations between layers:
```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """
    Generator network for 28x28 grayscale images (e.g., MNIST).
    Maps a latent vector z (100-dim) to a 1x28x28 image.
    """
    def __init__(self, latent_dim=100, img_channels=1):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        # Project and reshape: 100 -> 256*7*7
        self.project = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True)
        )

        # Upsample: 7x7 -> 14x14 -> 28x28
        self.conv_blocks = nn.Sequential(
            # Block 1: 256x7x7 -> 128x14x14
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Block 2: 128x14x14 -> 1x28x28
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in [-1, 1] range
        )

    def forward(self, z):
        # z shape: (batch_size, latent_dim)
        x = self.project(z)          # (batch, 256*7*7)
        x = x.view(-1, 256, 7, 7)    # (batch, 256, 7, 7)
        x = self.conv_blocks(x)      # (batch, 1, 28, 28)
        return x

# Test the generator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = Generator(latent_dim=100).to(device)

# Generate a batch of fake images from random noise
noise = torch.randn(16, 100, device=device)
fake_images = gen(noise)
print(f"Generator output shape: {fake_images.shape}")
print(f"Output range: [{fake_images.min().item():.3f}, {fake_images.max().item():.3f}]")
print(f"Generator parameters: {sum(p.numel() for p in gen.parameters()):,}")
```
Our Generator takes a 100-dimensional noise vector and produces a 28×28 grayscale image. The architecture uses a project-then-reshape strategy: a linear layer first maps the noise to a high-dimensional vector, which is then reshaped into a small spatial feature map (7×7 with 256 channels). Two transposed convolution blocks each double the spatial dimensions (7→14→28) while reducing the channel count (256→128→1). The final Tanh activation keeps output pixel values in [-1, 1], matching normalized image data.
Discriminator Architecture
The Discriminator is a binary classifier that takes an image (real or fake) and outputs a single probability: how likely the image is real. Its architecture is essentially the mirror of the Generator — it uses standard (strided) convolutions to downsample the spatial dimensions while increasing channel depth.
Key design principles for the Discriminator:
- Strided Convolutions: Use stride=2 instead of pooling layers (learned downsampling)
- LeakyReLU: Allows small gradients for negative values (slope=0.2), preventing dead neurons
- No Batch Norm in First Layer: Applying BN to the input layer can destabilize training
- Sigmoid Output: Produces probability in [0, 1] for real/fake classification
Building the Discriminator from Scratch
The implementation below downsamples progressively with strided convolutions (28×28 → 14×14 → 7×7), uses LeakyReLU (slope 0.2) to avoid dead neurons, and ends with a sigmoid output. Note that we deliberately omit batch normalization in the first layer:
```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """
    Discriminator network for 28x28 grayscale images.
    Classifies images as real (1) or fake (0).
    """
    def __init__(self, img_channels=1):
        super(Discriminator, self).__init__()
        self.conv_blocks = nn.Sequential(
            # Block 1: 1x28x28 -> 64x14x14 (NO batch norm in first layer)
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 2: 64x14x14 -> 128x7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Classifier: flatten and output probability
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        # img shape: (batch_size, 1, 28, 28)
        features = self.conv_blocks(img)       # (batch, 128, 7, 7)
        validity = self.classifier(features)   # (batch, 1)
        return validity

# Test the discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
disc = Discriminator().to(device)

# Test with random image-shaped input
fake_input = torch.randn(16, 1, 28, 28, device=device)
prediction = disc(fake_input)
print(f"Discriminator output shape: {prediction.shape}")
print("Predictions (roughly 0.5 for random input at initialization):")
print(f"  Mean: {prediction.mean().item():.4f}")
print(f"  Range: [{prediction.min().item():.4f}, {prediction.max().item():.4f}]")
print(f"Discriminator parameters: {sum(p.numel() for p in disc.parameters()):,}")
```
The Discriminator mirrors the Generator: strided convolutions halve the spatial dimensions (28→14→7) while increasing channels (1→64→128). We intentionally skip Batch Normalization in the first layer: Radford et al. reported that applying batch norm to every layer caused sample oscillation and model instability, which they avoided by leaving it out of the Discriminator's input layer and the Generator's output layer. The output is a single sigmoid-activated value representing the probability that the input is real.
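The downsampling arithmetic mirrors the Generator's. For `nn.Conv2d` with no dilation, the output size is $\lfloor (H_\text{in} + 2 \times \text{padding} - \text{kernel\_size}) / \text{stride} \rfloor + 1$, so kernel 4, stride 2, padding 1 halves an even input size. A quick check against the layers used above (the `conv_out` helper is our own illustration, not a PyTorch API):

```python
import torch
import torch.nn as nn

def conv_out(h_in, kernel_size, stride, padding):
    """Output size of Conv2d (dilation = 1)."""
    return (h_in + 2 * padding - kernel_size) // stride + 1

# Trace the Discriminator's spatial path: 28 -> 14 -> 7
x = torch.randn(1, 1, 28, 28)
layers = [nn.Conv2d(1, 64, 4, 2, 1), nn.Conv2d(64, 128, 4, 2, 1)]
for layer in layers:
    h_in = x.shape[-1]
    x = layer(x)
    print(f"{h_in} -> {x.shape[-1]} (formula: {conv_out(h_in, 4, 2, 1)})")

# Flattened feature size feeding the final Linear layer
print("Flattened features:", x.flatten(1).shape[1])  # 128 * 7 * 7 = 6272
```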
Training Loop
GAN training uses alternating optimization: we train the Discriminator and Generator in separate steps within each iteration. By convention, the Discriminator is updated first, so the Generator step receives its gradient signal from the freshly updated Discriminator.
The training loop follows this structure for each batch:
- Train Discriminator:
  - Pass real images through D → compute loss against label=1
  - Generate fake images, pass through D → compute loss against label=0
  - Backpropagate total D loss, update D weights
- Train Generator:
  - Generate fake images, pass through D → compute loss against label=1 (fool D)
  - Backpropagate G loss, update G weights only
Label Smoothing is a common trick where instead of using hard labels (1.0 for real, 0.0 for fake), we use softened values (e.g., 0.9 for real). This prevents the Discriminator from becoming too confident too early, which would give the Generator poor gradient signals.
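One way to see why smoothing limits overconfidence: for a real example with target $y = 0.9$, the BCE loss $-[0.9 \log p + 0.1 \log(1-p)]$ is minimized at $p = 0.9$ instead of being pushed toward 1, so the Discriminator gains nothing by saturating its sigmoid on real data. The sketch below evaluates both losses on a grid of candidate outputs (grid chosen for illustration).

```python
import torch
import torch.nn.functional as F

# Candidate discriminator outputs for a real image: 0.01, 0.02, ..., 0.99
p = torch.linspace(0.01, 0.99, 99)

hard = F.binary_cross_entropy(p, torch.full_like(p, 1.0), reduction='none')
smooth = F.binary_cross_entropy(p, torch.full_like(p, 0.9), reduction='none')

print(f"Hard labels (1.0): loss minimized at p = {p[hard.argmin()].item():.2f}")
print(f"Smoothed labels (0.9): loss minimized at p = {p[smooth.argmin()].item():.2f}")
```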
Complete Training Loop Implementation
Now we wire the Generator and Discriminator together in an alternating training loop. In each iteration, we first train D on a batch of real images (label ≈ 1) and a batch of fake images from G (label = 0), then train G by generating fakes and trying to fool D into outputting 1. The key insight is that G never sees real images directly — it only learns through the gradient signal from D:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Define Generator and Discriminator (compact versions for demonstration)
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.net(z).view(-1, 256, 7, 7)
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        return self.net(img)

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 100

# Initialize models
gen = Generator(latent_dim).to(device)
disc = Discriminator().to(device)

# Optimizers (Adam with beta1=0.5, as in the DCGAN paper, is a common choice for GANs)
optimizer_G = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss function
criterion = nn.BCELoss()

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])
dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                     transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Training loop (abbreviated: 2 epochs for demonstration)
num_epochs = 2
g_losses, d_losses = [], []

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        # Labels with one-sided smoothing (real = 0.9)
        real_labels = torch.full((batch_size, 1), 0.9, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss on real images
        output_real = disc(real_imgs)
        loss_real = criterion(output_real, real_labels)

        # Loss on fake images
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = gen(noise)
        output_fake = disc(fake_imgs.detach())  # detach: don't backprop through G
        loss_fake = criterion(output_fake, fake_labels)

        # Total D loss
        d_loss = loss_real + loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        # Generate fakes and get D's opinion
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = gen(noise)
        output = disc(fake_imgs)

        # G wants D to output 1 (real) for its fakes
        g_loss = criterion(output, torch.ones(batch_size, 1, device=device))
        g_loss.backward()
        optimizer_G.step()

        # Track losses
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        if i % 200 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")

print(f"\nTraining complete! Final D_loss: {d_losses[-1]:.4f}, G_loss: {g_losses[-1]:.4f}")
```
This complete training loop demonstrates several crucial GAN training practices. Notice the .detach() call when training the Discriminator on fake images — this prevents gradients from flowing back into the Generator during the D update step. We use Adam with $\beta_1 = 0.5$ (instead of the default 0.9) as recommended in the DCGAN paper, which helps stabilize adversarial training. Label smoothing (0.9 instead of 1.0 for real labels) acts as a regularizer for the Discriminator.
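The effect of `.detach()` is easy to check in isolation. In the sketch below, a tiny linear stand-in Generator and Discriminator (hypothetical, not the models above) compute a Discriminator loss on detached fakes; afterwards the Generator's parameters have no gradient at all, while the Discriminator's do. Without the detach, the backward pass would also traverse G and fill its `.grad` buffers, which is wasted computation (the loop above clears them with `optimizer_G.zero_grad()` before the Generator step anyway).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 8)                                # stand-in Generator
disc = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # stand-in Discriminator
criterion = nn.BCELoss()

noise = torch.randn(16, 4)
fake = gen(noise)

# Discriminator step on detached fakes: the graph stops at `fake`
d_loss = criterion(disc(fake.detach()), torch.zeros(16, 1))
d_loss.backward()

print("G weight grad after D step:", gen.weight.grad)                  # None
print("D weight grad is populated:", disc[0].weight.grad is not None)  # True
```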
DCGAN: Deep Convolutional GAN
DCGAN (Deep Convolutional GAN, Radford et al., 2015; published at ICLR 2016) established a set of architectural guidelines that made convolutional GANs stable to train. Before DCGAN, training GANs with convolutional layers was notoriously unreliable. The key guidelines are:
DCGAN Design Rules
- Replace pooling with strided convolutions (D) and transposed convolutions (G)
- Use Batch Normalization in both G and D (except D's input layer and G's output layer)
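- Remove fully connected hidden layers for deeper architectures (D's last convolutional layer is flattened into a single sigmoid output; the authors found global average pooling improved stability but slowed convergence)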
- Use ReLU in Generator (all layers except output, which uses Tanh)
- Use LeakyReLU in Discriminator (all layers, slope=0.2)
```mermaid
flowchart LR
    A["z ∈ R^100<br/>Random Noise"] --> B["Project & Reshape<br/>4×4×1024"]
    B --> C["ConvT 4×4, s2<br/>BN + ReLU<br/>8×8×512"]
    C --> D["ConvT 4×4, s2<br/>BN + ReLU<br/>16×16×256"]
    D --> E["ConvT 4×4, s2<br/>BN + ReLU<br/>32×32×128"]
    E --> F["ConvT 4×4, s2<br/>Tanh<br/>64×64×3"]
```
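The DCGAN architecture progressively doubles spatial dimensions at each layer while halving the number of feature maps, creating a smooth upsampling path from the compact latent space to the full image resolution. The diagram follows the paper's figure (1024 maps at 4×4); the implementation below uses feature_maps=64, so it starts from 512 maps at 4×4.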
Full DCGAN Implementation
Here is the complete DCGAN implementation following the architectural guidelines from Radford et al. (2015). The Generator uses ConvTranspose2d with stride 2 for upsampling, while the Discriminator uses strided Conv2d for downsampling; there are no pooling layers anywhere. Convolution weights are initialized from a normal distribution with mean 0 and std 0.02, and batch-norm scale parameters from a normal distribution with mean 1 and std 0.02:
```python
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    """
    DCGAN Generator for 64x64 RGB images.
    Follows all DCGAN architectural guidelines.
    """
    def __init__(self, latent_dim=100, feature_maps=64, img_channels=3):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            # Output: (feature_maps*8) x 4 x 4
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # State: (feature_maps*8) x 4 x 4
            # Output: (feature_maps*4) x 8 x 8
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # State: (feature_maps*4) x 8 x 8
            # Output: (feature_maps*2) x 16 x 16
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # State: (feature_maps*2) x 16 x 16
            # Output: feature_maps x 32 x 32
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # State: feature_maps x 32 x 32
            # Output: img_channels x 64 x 64
            nn.ConvTranspose2d(feature_maps, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()  # No BN in output layer
        )

    def forward(self, z):
        # Reshape z to (batch, latent_dim, 1, 1) for ConvTranspose2d
        return self.main(z.view(-1, z.size(1), 1, 1))

class DCGANDiscriminator(nn.Module):
    """
    DCGAN Discriminator for 64x64 RGB images.
    Mirror of the Generator with strided convolutions.
    """
    def __init__(self, feature_maps=64, img_channels=3):
        super().__init__()
        self.main = nn.Sequential(
            # Input: img_channels x 64 x 64
            # No BatchNorm in first layer!
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_maps x 32 x 32
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*2) x 16 x 16
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*4) x 8 x 8
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_maps*8) x 4 x 4
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.main(img).view(-1, 1)

# Weight initialization (DCGAN recommendation)
def weights_init(m):
    """Initialize Conv weights from N(0, 0.02) and BN scales from N(1, 0.02)."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Instantiate and initialize (img_channels=1 for grayscale data)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = DCGANGenerator(latent_dim=100, feature_maps=64, img_channels=1).to(device)
disc = DCGANDiscriminator(feature_maps=64, img_channels=1).to(device)
gen.apply(weights_init)
disc.apply(weights_init)

# Verify architectures
noise = torch.randn(8, 100, device=device)
fake = gen(noise)
pred = disc(fake)
print(f"Generator: noise {noise.shape} -> image {fake.shape}")
print(f"Discriminator: image {fake.shape} -> prediction {pred.shape}")
print(f"G parameters: {sum(p.numel() for p in gen.parameters()):,}")
print(f"D parameters: {sum(p.numel() for p in disc.parameters()):,}")
```
This DCGAN implementation follows the architectural guidelines closely. Note the weights_init function: the DCGAN paper initializes all weights from a zero-centered Normal distribution with standard deviation 0.02, and here the batch-norm scale parameters are centered at 1 instead so that each BN layer starts with a scale near 1. This small detail can noticeably affect training stability. Also notice bias=False throughout: for convolutions followed by Batch Normalization, a bias would be redundant, because BN subtracts the batch mean and then applies its own learnable shift.
Here is the data and hyperparameter setup for training this DCGAN, using the paper's recommended values where it specifies them:
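The `bias=False` choice can be checked directly: in training mode, batch norm subtracts each channel's batch mean, so any constant per-channel bias added by the preceding convolution cancels out. The sketch below builds two convolutions with identical weights, one with a (deliberately large) bias and one without, and passes each through a freshly initialized `nn.BatchNorm2d`; the sizes are toy values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_with_bias = nn.Conv2d(3, 16, 4, 2, 1, bias=True)
conv_no_bias = nn.Conv2d(3, 16, 4, 2, 1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # same weights, no bias
    conv_with_bias.bias.fill_(5.0)                    # a large constant offset

bn_a, bn_b = nn.BatchNorm2d(16), nn.BatchNorm2d(16)   # fresh BN layers (train mode)
x = torch.randn(8, 3, 32, 32)

out_a = bn_a(conv_with_bias(x))
out_b = bn_b(conv_no_bias(x))
print("Max difference:", (out_a - out_b).abs().max().item())  # ~0 (floating-point noise)
```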
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# DCGAN Training Configuration
# lr, beta1 and batch size follow the DCGAN paper; the other values are demo choices
config = {
    'latent_dim': 100,
    'batch_size': 128,
    'lr': 0.0002,
    'beta1': 0.5,  # Adam beta1 (0.5 not 0.9!)
    'beta2': 0.999,
    'num_epochs': 5,
    'image_size': 64,
    'feature_maps_g': 64,
    'feature_maps_d': 64,
}

# Dataset: Fashion-MNIST resized to 64x64
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=transform, download=True
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0
)
print(f"Dataset size: {len(dataset)} images")
print(f"Batches per epoch: {len(dataloader)}")
print(f"Training config: {config}")

# Fixed noise for tracking generation quality over time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fixed_noise = torch.randn(64, config['latent_dim'], device=device)
print(f"\nFixed noise for visualization: {fixed_noise.shape}")
print("This same noise will be used each epoch to see how generation improves")
```
Using a fixed noise vector throughout training is a crucial visualization technique. By feeding the same noise to the Generator at different training stages, you can visually track how the generated images improve over time. The Fashion-MNIST dataset provides more visual diversity than regular MNIST (clothing items vs digits), making it easier to evaluate generation quality.
Training Instabilities & Solutions
GAN training is notoriously difficult. The adversarial game creates a delicate balance that can easily break down. Here are the most common failure modes:
GAN Training Failure Modes
- Mode Collapse: Generator produces only a few distinct outputs, ignoring most of the data distribution. For MNIST, it might only generate "1"s and "7"s.
- Vanishing Gradients: If the Discriminator becomes too good too fast, the Generator receives near-zero gradients and stops learning.
- Training Divergence: Losses oscillate wildly without converging. G and D losses don't stabilize.
- Oscillation: G and D take turns "winning" without reaching equilibrium.
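The root cause of many instabilities lies in the original GAN loss function. With an optimal Discriminator, the original objective measures the Jensen-Shannon (JS) divergence between the real and generated distributions, and the JS divergence is constant ($\log 2$) whenever the two distributions have disjoint supports, so it provides no useful gradient. Arjovsky et al. argue that such non-overlap is common because both distributions tend to concentrate on low-dimensional manifolds of a high-dimensional space. This insight led to the development of Wasserstein GANs (WGAN).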
Wasserstein Distance
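The Wasserstein-1 distance (Earth Mover's distance) measures the minimum "cost" of transporting probability mass from one distribution to the other. By the Kantorovich-Rubinstein duality, it can be written as a supremum over 1-Lipschitz functions $f$: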
$$W(p_r, p_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_r}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)]$$
Unlike JS divergence, Wasserstein distance provides meaningful gradients even when distributions don't overlap. The Discriminator (called "Critic" in WGAN terminology) no longer outputs probabilities — it outputs an unconstrained real number representing "realness."
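A toy numerical illustration of this contrast, using point masses on a 1-D grid (our own example): as the generated mass moves farther from the real mass, the JS divergence stays fixed at $\log 2$ and so carries no information about how far apart the distributions are, while the Wasserstein-1 distance grows with the shift. In 1-D, $W_1$ equals the area between the two cumulative distribution functions, which is what the hypothetical helper `wasserstein_1d` computes (assuming unit grid spacing).

```python
import numpy as np

def kl(a, b):
    """KL(a || b), summed over the support of a."""
    mask = a > 0
    return np.sum(a[mask] * np.log(a[mask] / b[mask]))

def js_divergence(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q):
    """W1 on a grid with unit spacing: area between the two CDFs."""
    return np.sum(np.abs(np.cumsum(p) - np.cumsum(q)))

grid_size = 10
p_real = np.zeros(grid_size)
p_real[0] = 1.0  # all real mass at position 0

results = {}
for shift in [1, 3, 5, 9]:
    p_fake = np.zeros(grid_size)
    p_fake[shift] = 1.0  # all generated mass at position `shift`
    results[shift] = (js_divergence(p_real, p_fake), wasserstein_1d(p_real, p_fake))
    js, w = results[shift]
    print(f"shift={shift}: JS = {js:.4f} (log 2 = {np.log(2):.4f}), W1 = {w:.1f}")
```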
WGAN-GP: Wasserstein GAN with Gradient Penalty
The original WGAN enforced the Lipschitz constraint by clipping weights, which was crude and led to capacity underuse. WGAN-GP (Gulrajani et al. 2017) replaced weight clipping with a gradient penalty — penalizing the critic when its gradient norm deviates from 1:
```python
import torch
import torch.nn as nn
import torch.autograd as autograd

def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """
    Compute gradient penalty for WGAN-GP.
    Enforces 1-Lipschitz constraint on the critic.
    """
    batch_size = real_samples.size(0)

    # Random interpolation between real and fake samples
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    # Critic's output for interpolated samples
    critic_interpolated = critic(interpolated)

    # Compute gradients of critic output w.r.t. interpolated input
    gradients = autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,  # Need second-order gradients for backprop
        retain_graph=True,
    )[0]

    # Flatten gradients and compute L2 norm per sample
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalty: (||grad|| - 1)^2
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty

# Demonstrate gradient penalty computation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Simple critic for demonstration
critic = nn.Sequential(
    nn.Conv2d(1, 32, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(32, 64, 4, 2, 1),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 1)  # No sigmoid! Critic outputs unbounded values
).to(device)

# Test with dummy data
real = torch.randn(8, 1, 28, 28, device=device)
fake = torch.randn(8, 1, 28, 28, device=device)
gp = compute_gradient_penalty(critic, real, fake, device)
print(f"Gradient penalty: {gp.item():.4f}")
print("Target: gradient norm should be close to 1.0")
print("Penalty pushes critic to be 1-Lipschitz continuous")
```
The gradient penalty works by sampling random points between real and fake data (via linear interpolation with random $\alpha$), computing the critic's gradient at those points, and penalizing any deviation from a gradient norm of 1. This is more stable than weight clipping because it doesn't limit the critic's capacity — it just constrains how fast the output can change.
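A useful sanity check on the penalty: a linear critic $f(x) = w^\top x + b$ has gradient $w$ everywhere, so its penalty is exactly $(\|w\|_2 - 1)^2$ regardless of the samples. The sketch below re-implements the same penalty compactly (generalizing the `alpha` shape to any input rank, an adaptation of ours) and confirms it is about 0 for a unit-norm $w$ and about 4 when $\|w\|_2 = 3$.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty: mean of (||grad_x critic(x_hat)||_2 - 1)^2 at interpolates."""
    # One alpha per sample, broadcast over the remaining dimensions
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)))
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(x_hat)
    grads = torch.autograd.grad(out, x_hat, grad_outputs=torch.ones_like(out),
                                create_graph=True)[0]
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))  # linear critic
real = torch.randn(8, 1, 28, 28)
fake = torch.randn(8, 1, 28, 28)

with torch.no_grad():
    w = critic[1].weight
    w.div_(w.norm())             # ||w|| = 1
gp_unit = gradient_penalty(critic, real, fake)

with torch.no_grad():
    critic[1].weight.mul_(3.0)   # ||w|| = 3
gp_three = gradient_penalty(critic, real, fake)

print(f"GP with ||w|| = 1: {gp_unit.item():.6f}")  # ~0
print(f"GP with ||w|| = 3: {gp_three.item():.4f}")  # ~4 = (3 - 1)^2
```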
Now let's implement the complete WGAN-GP training loop with the critic (note: no sigmoid, different loss formulation):
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
class WGANCritic(nn.Module):
"""WGAN Critic (no sigmoid output)."""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 64, 4, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 128, 4, 2, 1),
# No BN in WGAN-GP critic (interferes with gradient penalty)
nn.LeakyReLU(0.2, True),
nn.Flatten(),
nn.Linear(128 * 7 * 7, 1),
# No sigmoid! Output is unbounded
)
def forward(self, img):
return self.net(img)
class WGANGenerator(nn.Module):
"""WGAN Generator (same as standard)."""
def __init__(self, latent_dim=100):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, 256 * 7 * 7),
nn.BatchNorm1d(256 * 7 * 7),
nn.ReLU(True),
)
self.conv = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 1, 4, 2, 1),
nn.Tanh(),
)
def forward(self, z):
return self.conv(self.net(z).view(-1, 256, 7, 7))
def gradient_penalty(critic, real, fake, device):
"""Compute GP for WGAN-GP."""
batch_size = real.size(0)
alpha = torch.rand(batch_size, 1, 1, 1, device=device)
interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
critic_out = critic(interpolated)
grads = autograd.grad(critic_out, interpolated,
grad_outputs=torch.ones_like(critic_out),
create_graph=True)[0]
grads = grads.view(batch_size, -1)
return ((grads.norm(2, dim=1) - 1) ** 2).mean()
# WGAN-GP Training Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 100
n_critic = 5 # Train critic 5x per generator step
lambda_gp = 10 # Gradient penalty coefficient
gen = WGANGenerator(latent_dim).to(device)
critic = WGANCritic().to(device)
# WGAN-GP uses Adam with different betas
opt_G = optim.Adam(gen.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_C = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
# Simulated training step (single batch demonstration)
real_imgs = torch.randn(32, 1, 28, 28, device=device) # Simulated real batch
# Train Critic (n_critic steps)
for _ in range(n_critic):
    noise = torch.randn(32, latent_dim, device=device)
    fake_imgs = gen(noise).detach()
    critic_real = critic(real_imgs).mean()
    critic_fake = critic(fake_imgs).mean()
    gp = gradient_penalty(critic, real_imgs, fake_imgs, device)
    # WGAN-GP Critic loss: maximize (real - fake) with GP
    critic_loss = critic_fake - critic_real + lambda_gp * gp
    opt_C.zero_grad()
    critic_loss.backward()
    opt_C.step()
# Train Generator (1 step)
noise = torch.randn(32, latent_dim, device=device)
fake_imgs = gen(noise)
gen_loss = -critic(fake_imgs).mean() # Maximize critic output for fakes
opt_G.zero_grad()
gen_loss.backward()
opt_G.step()
print(f"Critic loss: {critic_loss.item():.4f}")
print(f" - Critic on real: {critic_real.item():.4f}")
print(f" - Critic on fake: {critic_fake.item():.4f}")
print(f" - Gradient penalty: {gp.item():.4f}")
print(f"Generator loss: {gen_loss.item():.4f}")
print(f"\nKey WGAN-GP differences:")
print(f" - No sigmoid in critic (unbounded output)")
print(f" - No BN in critic (interferes with GP)")
print(f" - Train critic {n_critic}x per G step")
print(f" - Lambda_GP = {lambda_gp}")
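As a quick sanity check on the penalty, consider a purely linear critic $f(x) = w^\top x$. Its gradient with respect to the input is $w$ at every point, so the penalty should come out to exactly $(\lVert w \rVert_2 - 1)^2$. The sketch below re-declares the penalty so it runs on its own; unlike the version above, it takes the device from the input tensor instead of a device argument.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(x_hat)
    grads = torch.autograd.grad(out, x_hat, grad_outputs=torch.ones_like(out),
                                create_graph=True)[0]
    return ((grads.view(real.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

torch.manual_seed(0)
# A linear critic: its input gradient is the weight vector w for every sample
linear_critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1, bias=False))
real = torch.randn(8, 1, 28, 28)
fake = torch.randn(8, 1, 28, 28)

gp = gradient_penalty(linear_critic, real, fake)
w_norm = linear_critic[1].weight.norm()
expected = (w_norm - 1) ** 2  # closed-form penalty for a linear critic
print(f"GP = {gp.item():.6f}, expected (||w|| - 1)^2 = {expected.item():.6f}")
```

Because the gradient of a linear critic does not depend on where the interpolate lands, the random alpha has no effect here; for a real convolutional critic the penalty varies from point to point.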
Key differences from standard GAN training:
- No sigmoid output in the critic — it outputs unbounded values (Wasserstein estimate)
- No Batch Normalization in critic — BN creates dependencies between samples in a batch, which violates the per-sample gradient penalty assumption
- Train critic more often (5 updates per 1 Generator update) — a well-trained critic provides better gradients
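- Different Adam betas ($\beta_1 = 0.0, \beta_2 = 0.9$): with $\beta_1 = 0$, Adam's first-moment estimate is just the current gradient, so updates carry no momentum
- The loss is a plain difference of scores, with no log and no BCE: the critic maximizes (real_score - fake_score) minus the gradient penalty, and the generator maximizes fake_score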
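The claim that Batch Normalization couples samples can be checked directly: feed the same sample alongside two different sets of batch neighbours and see whether its output changes. The WGAN-GP paper suggests layer normalization as a replacement in the critic, so the sketch compares nn.BatchNorm2d (in training mode) with nn.LayerNorm. The tensor sizes are arbitrary illustration values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x0 = torch.randn(1, 8, 7, 7)                                  # the sample we track
batch_a = torch.cat([x0, torch.randn(3, 8, 7, 7)])            # x0 with one set of neighbours
batch_b = torch.cat([x0, torch.randn(3, 8, 7, 7) * 5 + 2])    # x0 with very different neighbours

bn = nn.BatchNorm2d(8).train()   # training mode: normalizes with statistics of the whole batch
ln = nn.LayerNorm([8, 7, 7])     # normalizes each sample over its own (C, H, W)

def output_depends_on_batch(layer):
    """True if sample 0's output changes when only the other samples in the batch change."""
    out_a = layer(batch_a)[0]
    out_b = layer(batch_b)[0]
    return not torch.allclose(out_a, out_b, atol=1e-6)

bn_coupled = output_depends_on_batch(bn)
ln_coupled = output_depends_on_batch(ln)
print(f"BatchNorm2d couples samples: {bn_coupled}")
print(f"LayerNorm couples samples:   {ln_coupled}")
```

With BatchNorm, the critic's value for one image (and therefore its input gradient) depends on the rest of the batch, which is exactly what a per-sample gradient penalty does not account for.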
Conditional GANs (cGAN)
Standard GANs generate samples randomly — you can't control what they produce. Conditional GANs (cGAN) (Mirza & Osindero, 2014) solve this by conditioning both the Generator and Discriminator on additional information, typically a class label. This allows directed generation: "generate a shirt" or "generate the digit 7."
The conditioning works by:
- Embedding the label into a dense vector using nn.Embedding
- Concatenating the label embedding with the noise vector (for G) or with the image (for D)
- Letting the networks learn to use this extra information to specialize their outputs
Conditional GAN Implementation
In a Conditional GAN, both the Generator and Discriminator receive additional information, typically a class label. The Generator can then produce images of a specific class on demand (e.g., "generate a 7"), and the Discriminator judges whether the image is a realistic example of that class. In the implementation below, we embed the class label and concatenate it with the noise vector (for G) or attach it to the image as an extra input channel (for D):
import torch
import torch.nn as nn
class ConditionalGenerator(nn.Module):
    """
    Conditional Generator for MNIST (10 classes).
    Concatenates label embedding with noise vector.
    """
    def __init__(self, latent_dim=100, num_classes=10, embed_dim=10):
        super().__init__()
        # Embed class label into dense vector
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        # Generator takes concatenated [noise, label_embed]
        input_dim = latent_dim + embed_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1),
            nn.Tanh(),
        )
    def forward(self, z, labels):
        # z: (batch, latent_dim), labels: (batch,) integers
        label_embed = self.label_embedding(labels)  # (batch, embed_dim)
        gen_input = torch.cat([z, label_embed], dim=1)  # (batch, latent_dim + embed_dim)
        x = self.net(gen_input).view(-1, 256, 7, 7)
        return self.conv(x)
class ConditionalDiscriminator(nn.Module):
    """
    Conditional Discriminator for MNIST.
    Concatenates label as an extra channel to the image.
    """
    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        # Embed label into a full image-sized channel
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.img_size = img_size
        # Discriminator takes image (1 channel) + label map (1 channel) = 2 channels
        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 4, 2, 1),  # 2 input channels!
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )
    def forward(self, img, labels):
        # Embed label and reshape to image dimensions
        label_map = self.label_embedding(labels)  # (batch, img_size*img_size)
        label_map = label_map.view(-1, 1, self.img_size, self.img_size)  # (batch, 1, 28, 28)
        # Concatenate image and label map along channel dimension
        d_input = torch.cat([img, label_map], dim=1)  # (batch, 2, 28, 28)
        return self.net(d_input)
# Test Conditional GAN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 100
cgen = ConditionalGenerator(latent_dim=latent_dim).to(device)
cdisc = ConditionalDiscriminator().to(device)
# Generate specific digits
noise = torch.randn(10, latent_dim, device=device)
labels = torch.arange(10, device=device) # Generate one of each digit (0-9)
fake_images = cgen(noise, labels)
predictions = cdisc(fake_images, labels)
print(f"Conditional Generator: noise {noise.shape} + labels {labels.shape} -> images {fake_images.shape}")
print(f"Conditional Discriminator: images + labels -> predictions {predictions.shape}")
print(f"\nGenerated images for labels: {labels.tolist()}")
print(f"We can now control WHAT the GAN generates!")
The Conditional GAN introduces label information at both ends of the adversarial game. For the Generator, we simply concatenate the label embedding with the noise vector before the first layer. For the Discriminator, we take a slightly different approach: we embed the label into a full spatial map (28×28) and concatenate it as an additional channel to the image. This gives the Discriminator pixel-level access to the label information, allowing it to verify that the generated content matches the specified class.
Let's see a complete conditional generation example where we control the output class:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
class SimpleConditionalGen(nn.Module):
    """Minimal cGAN generator for demonstration."""
    def __init__(self, latent_dim=100, num_classes=10, embed_dim=10):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 784),
            nn.Tanh(),
        )
    def forward(self, z, labels):
        gen_input = torch.cat([z, self.embed(labels)], dim=1)
        return self.net(gen_input).view(-1, 1, 28, 28)
# Demonstrate controlled generation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = SimpleConditionalGen().to(device)
# Generate 5 images of each class
latent_dim = 100
num_classes = 10
images_per_class = 5
all_images = []
all_labels = []
for class_idx in range(num_classes):
    noise = torch.randn(images_per_class, latent_dim, device=device)
    labels = torch.full((images_per_class,), class_idx, dtype=torch.long, device=device)
    with torch.no_grad():
        generated = gen(noise, labels)
    all_images.append(generated)
    all_labels.extend([class_idx] * images_per_class)
all_images = torch.cat(all_images, dim=0)
print(f"Generated {all_images.shape[0]} images total")
print(f"5 images per class × 10 classes = {num_classes * images_per_class}")
print(f"Image shape: {all_images.shape[1:]}")
print(f"\nWith a trained cGAN, each row would show variations of the SAME digit/class")
print(f"Labels generated: {all_labels[:10]}... (first 10)")
Once trained, a Conditional GAN gives you explicit control over generation. You can generate specific digits, specific clothing items, or any class in your dataset by simply specifying the desired label. This is the foundation for many practical GAN applications like text-to-image generation and attribute manipulation.
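Training a conditional GAN follows the same alternating loop as an unconditional one; the only change is that every call to G and D also receives labels. The following is a minimal sketch of one training step under stated assumptions: small hypothetical MLP networks (MLPCondGen, MLPCondDisc) stand in for the convolutional models above, random tensors stand in for a real MNIST batch, and the discriminator outputs logits scored with nn.BCEWithLogitsLoss instead of Sigmoid followed by BCE.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, num_classes, batch = 100, 10, 16

class MLPCondGen(nn.Module):
    """Hypothetical small conditional generator: [noise, label embedding] -> 28x28 image."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(num_classes, 10)
        self.net = nn.Sequential(nn.Linear(latent_dim + 10, 256), nn.ReLU(True),
                                 nn.Linear(256, 784), nn.Tanh())
    def forward(self, z, y):
        return self.net(torch.cat([z, self.embed(y)], dim=1)).view(-1, 1, 28, 28)

class MLPCondDisc(nn.Module):
    """Hypothetical small conditional discriminator: [flattened image, label embedding] -> logit."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(num_classes, 10)
        self.net = nn.Sequential(nn.Linear(784 + 10, 256), nn.LeakyReLU(0.2, True),
                                 nn.Linear(256, 1))
    def forward(self, x, y):
        return self.net(torch.cat([x.flatten(1), self.embed(y)], dim=1))

G, D = MLPCondGen(), MLPCondDisc()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

real_imgs = torch.rand(batch, 1, 28, 28) * 2 - 1          # stand-in for real images in [-1, 1]
real_labels = torch.randint(0, num_classes, (batch,))     # their true class labels

# Discriminator step: (real image, true label) -> 1, (fake image, requested label) -> 0
z = torch.randn(batch, latent_dim)
gen_labels = torch.randint(0, num_classes, (batch,))
fake_imgs = G(z, gen_labels)
d_loss = (bce(D(real_imgs, real_labels), torch.ones(batch, 1))
          + bce(D(fake_imgs.detach(), gen_labels), torch.zeros(batch, 1)))
opt_D.zero_grad()
d_loss.backward()
opt_D.step()

# Generator step: make D accept the fake as a real example *of the requested label*
g_loss = bce(D(fake_imgs, gen_labels), torch.ones(batch, 1))
opt_G.zero_grad()
g_loss.backward()
opt_G.step()
print(f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
```

Because D always sees the label, a fake that looks like a convincing "3" but was requested as a "7" can still be rejected, which is what pushes G to respect the condition.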
Evaluating GANs
Evaluating generative models is fundamentally harder than evaluating classifiers. There's no single ground truth — we want images that are both high quality (realistic) and diverse (covering the full data distribution). The two most widely used metrics are:
Fréchet Inception Distance (FID)
FID measures the distance between the distribution of generated images and real images in the feature space of a pretrained Inception network. Lower FID = better quality and diversity. It computes the Fréchet distance between two multivariate Gaussians fitted to the Inception features of real and generated images.
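Written out with the same symbols used in the code below, for Gaussians $\mathcal{N}(\mu_r, \Sigma_r)$ fitted to real features and $\mathcal{N}(\mu_f, \Sigma_f)$ fitted to generated features:

$$\text{FID} = \lVert \mu_r - \mu_f \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_f - 2\left(\Sigma_r \Sigma_f\right)^{1/2}\right)$$

The first term penalizes a shift in the average features; the trace term penalizes differences in spread and correlation. If the two Gaussians are identical, $(\Sigma_r \Sigma_r)^{1/2} = \Sigma_r$ and both terms vanish, giving $\text{FID} = 0$.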
Inception Score (IS)
IS measures two things simultaneously:
- Quality: Each generated image should be confidently classified (low entropy $p(y|x)$)
- Diversity: The marginal distribution $p(y)$ should be uniform across classes (high entropy)
Higher IS = better. However, IS has known limitations — it doesn't compare against real data and can be fooled by memorization.
import torch
import torch.nn as nn
import numpy as np
def calculate_fid_components(real_features, fake_features):
    """
    Calculate FID components (mean and covariance) from feature vectors.
    FID = ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 * sqrtm(Sigma_r @ Sigma_f))
    where sqrtm is the matrix square root (not an elementwise sqrt).
    In practice, use the 'pytorch-fid' package for the full calculation.
    This demonstrates the concept.
    """
    # Compute statistics for real features (torch.cov expects variables in rows, hence .T)
    mu_real = real_features.mean(dim=0)
    sigma_real = torch.cov(real_features.T)
    # Compute statistics for fake features
    mu_fake = fake_features.mean(dim=0)
    sigma_fake = torch.cov(fake_features.T)
    # Squared difference of means
    mean_diff_sq = ((mu_real - mu_fake) ** 2).sum()
    # Trace of covariance sum (simplified - full FID needs matrix sqrt)
    trace_sum = sigma_real.trace() + sigma_fake.trace()
    print(f"Real features - Mean norm: {mu_real.norm():.4f}")
    print(f"Fake features - Mean norm: {mu_fake.norm():.4f}")
    print(f"Mean difference (squared): {mean_diff_sq.item():.4f}")
    print(f"Covariance trace sum: {trace_sum.item():.4f}")
    return mean_diff_sq.item(), trace_sum.item()
# Simulate Inception features (in practice, extract from Inception-v3 pool3 layer)
# Feature dimension is 2048 for Inception-v3
feature_dim = 64 # Simplified for demonstration
num_samples = 500
# Simulated features: real distribution
real_features = torch.randn(num_samples, feature_dim) * 1.0 + 0.5
# Simulated features: good generator (close to real)
good_fake_features = torch.randn(num_samples, feature_dim) * 1.1 + 0.6
# Simulated features: bad generator (far from real)
bad_fake_features = torch.randn(num_samples, feature_dim) * 2.0 + 3.0
print("=== Good Generator (features similar to real) ===")
good_fid = calculate_fid_components(real_features, good_fake_features)
print("\n=== Bad Generator (features distant from real) ===")
bad_fid = calculate_fid_components(real_features, bad_fake_features)
print(f"\n--- Summary ---")
print(f"Lower mean difference = more realistic images (closer distributions)")
print(f"In practice, use: pip install pytorch-fid")
print(f"Command: python -m pytorch_fid path/real path/generated")
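The function above stops short of the matrix square root term. A sketch of the complete Fréchet distance, assuming NumPy and SciPy are available, uses scipy.linalg.sqrtm for $(\Sigma_r \Sigma_f)^{1/2}$. Random Gaussian features again stand in for Inception-v3 activations, and the feature dimension is kept small so the example runs quickly.

```python
import numpy as np
from scipy import linalg

def frechet_distance(real_feats, fake_feats):
    """Full FID formula on (N, D) feature arrays, including the matrix square root term."""
    mu_r, mu_f = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    sigma_r = np.cov(real_feats, rowvar=False)   # rows are samples, columns are features
    sigma_f = np.cov(fake_feats, rowvar=False)
    # Matrix square root of the product; numerical error can leave a tiny imaginary part
    covmean = np.real(linalg.sqrtm(sigma_r @ sigma_f))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r + sigma_f - 2.0 * covmean))

rng = np.random.default_rng(0)
real = rng.normal(0.5, 1.0, size=(2000, 16))
close = rng.normal(0.6, 1.1, size=(2000, 16))   # "good generator": similar statistics
far = rng.normal(3.0, 2.0, size=(2000, 16))     # "bad generator": shifted and wider

fid_self = frechet_distance(real, real)
fid_close = frechet_distance(real, close)
fid_far = frechet_distance(real, far)
print(f"FID(real, real)  = {fid_self:.4f}")
print(f"FID(real, close) = {fid_close:.4f}")
print(f"FID(real, far)   = {fid_far:.4f}")
```

Comparing a feature set against itself gives (numerically) zero, and the shifted, wider distribution scores far worse than the nearby one, matching the "lower is better" reading of FID.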
In practice, you'll use the pytorch-fid package to compute FID scores. The key takeaway is that FID captures both quality (mean difference) and diversity (covariance matching) in a single number. FID values are only comparable when computed on the same dataset, resolution, and number of samples; with that caveat, a well-trained GAN typically achieves FID scores under 50 on standard benchmarks, with state-of-the-art models reaching single digits.
Practical tips for computing FID:
- Compute FID with at least 10,000 samples (50,000 is a common convention) for statistical reliability
- Use the same preprocessing for real and generated images
- Track FID during training to detect mode collapse (sudden FID increase)
- Visual inspection remains essential — metrics can miss artifacts humans notice
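The Inception Score described earlier can be sketched just as compactly. It is $\exp\left(\mathbb{E}_x\left[\mathrm{KL}\left(p(y|x) \,\|\, p(y)\right)\right]\right)$, where $p(y|x)$ comes from a classifier (Inception-v3 for the standard metric) and $p(y)$ is the average prediction over all generated samples. The sketch below assumes the class probabilities are already computed and uses hand-made arrays to show the two extremes; common implementations also average the score over several splits of the samples, which is omitted here.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp(mean over samples of KL(p(y|x) || p(y))) for an (N, C) probability array."""
    p_y = probs.mean(axis=0, keepdims=True)                              # marginal p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)  # per-sample KL
    return float(np.exp(kl.mean()))

num_classes = 10
confident_diverse = np.tile(np.eye(num_classes), (10, 1))      # confident, all classes equally often
collapsed = np.tile(np.eye(num_classes)[:1], (100, 1))         # confident, but always class 0
uncertain = np.full((100, num_classes), 1.0 / num_classes)     # diverse on average, never confident

is_best = inception_score(confident_diverse)
is_collapsed = inception_score(collapsed)
is_uncertain = inception_score(uncertain)
print(f"confident + diverse: {is_best:.3f}")
print(f"mode collapse:       {is_collapsed:.3f}")
print(f"uncertain:           {is_uncertain:.3f}")
```

The best case reaches the number of classes, while both failure modes (a single class, or no confident predictions) bottom out at 1, which is why IS needs both quality and diversity to score well.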
Advanced GAN Variants
Since the original GAN paper, hundreds of variants have been proposed. Here are the most impactful ones that pushed the boundaries of image generation:
StyleGAN (Karras et al. 2019)
StyleGAN introduced a radically different Generator architecture inspired by style transfer. Instead of feeding noise directly into the first layer, it uses a mapping network to transform $z$ into an intermediate latent space $W$, which then controls generation at each resolution through Adaptive Instance Normalization (AdaIN). This gives unprecedented control over generated attributes — you can mix "coarse" features (pose, face shape) from one latent code with "fine" features (hair color, texture) from another.
CycleGAN (Zhu et al. 2017)
CycleGAN enables unpaired image-to-image translation — converting between domains without requiring matched pairs. Want to turn horses into zebras? Summer landscapes into winter? CycleGAN uses a cycle consistency loss: translating an image from domain A to B and back to A should reconstruct the original.
Pix2Pix (Isola et al. 2017)
Pix2Pix handles paired image-to-image translation where you have aligned input-output pairs. It uses a U-Net Generator (with skip connections) and a PatchGAN Discriminator that classifies whether each $N \times N$ patch is real or fake, rather than the whole image.
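A PatchGAN discriminator is simply a fully convolutional network that stops before collapsing to a single number, so its output is a grid of real/fake logits, one per overlapping patch. The sketch below assumes the layer sizes of the commonly used 70×70 PatchGAN (64-128-256-512 channels, 4×4 convolutions); the class name and the way the input/output pair is stacked on the channel axis are illustrative choices, not the reference code.

```python
import torch
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    """Sketch of a 70x70-style PatchGAN that scores an (input, output) image pair per patch."""
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        def block(c_in, c_out, stride, norm=True):
            layers = [nn.Conv2d(c_in, c_out, 4, stride, 1)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, True))
            return layers
        self.net = nn.Sequential(
            *block(in_channels + out_channels, 64, 2, norm=False),  # condition + image on channels
            *block(64, 128, 2),
            *block(128, 256, 2),
            *block(256, 512, 1),
            nn.Conv2d(512, 1, 4, 1, 1),  # one logit per patch, no sigmoid
        )
    def forward(self, condition, image):
        return self.net(torch.cat([condition, image], dim=1))

disc = PatchDiscriminator()
sketch = torch.randn(2, 3, 256, 256)   # e.g. input edge maps
photo = torch.randn(2, 3, 256, 256)    # e.g. target or generated photos
patch_logits = disc(sketch, photo)
# Every patch gets its own "real" target, and the loss averages over the grid
loss = nn.BCEWithLogitsLoss()(patch_logits, torch.ones_like(patch_logits))
print(f"Patch grid: {tuple(patch_logits.shape)}, loss: {loss.item():.4f}")
```

For a 256×256 input this produces a 30×30 grid. Judging local patches keeps the discriminator small and focuses it on high-frequency texture, while an L1 reconstruction term in Pix2Pix handles overall structure.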
import torch
import torch.nn as nn
class StyleMappingNetwork(nn.Module):
    """
    Simplified StyleGAN Mapping Network.
    Maps z from Z-space to w in W-space (intermediate latent space).
    The original uses 8 FC layers.
    """
    def __init__(self, latent_dim=512, w_dim=512, num_layers=4):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers.extend([
                nn.Linear(latent_dim if i == 0 else w_dim, w_dim),
                nn.LeakyReLU(0.2, True),
            ])
        self.mapping = nn.Sequential(*layers)
    def forward(self, z):
        return self.mapping(z)
class AdaIN(nn.Module):
    """Adaptive Instance Normalization - core of StyleGAN."""
    def __init__(self, w_dim, num_features):
        super().__init__()
        # Learned affine transform from w to scale/shift
        self.to_style = nn.Linear(w_dim, num_features * 2)
        self.norm = nn.InstanceNorm2d(num_features)
    def forward(self, x, w):
        # x: feature map (B, C, H, W), w: style vector (B, w_dim)
        style = self.to_style(w)  # (B, 2*C)
        gamma, beta = style.chunk(2, dim=1)  # Each (B, C)
        gamma = gamma.unsqueeze(2).unsqueeze(3)  # (B, C, 1, 1)
        beta = beta.unsqueeze(2).unsqueeze(3)
        # Normalize then apply learned style
        out = self.norm(x)
        return gamma * out + beta
# Demonstrate StyleGAN concepts
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 512
w_dim = 512
# Mapping network transforms Z -> W
mapping = StyleMappingNetwork(latent_dim, w_dim).to(device)
# Sample two different latent codes
z1 = torch.randn(1, latent_dim, device=device)
z2 = torch.randn(1, latent_dim, device=device)
w1 = mapping(z1)
w2 = mapping(z2)
print(f"Z-space sample: {z1.shape}")
print(f"W-space (intermediate latent): {w1.shape}")
print(f"\nStyle mixing: use w1 for coarse layers, w2 for fine layers")
print(f"This separates high-level attributes from fine details!")
# AdaIN demonstration
adain = AdaIN(w_dim=w_dim, num_features=64).to(device)
feature_map = torch.randn(1, 64, 16, 16, device=device)
styled_features = adain(feature_map, w1)
print(f"\nAdaIN: features {feature_map.shape} + style {w1.shape} -> styled {styled_features.shape}")
StyleGAN's mapping network is a brilliant insight: the raw latent space $Z$ is entangled (changing one dimension affects multiple visual attributes), whereas the intermediate space $W$ learned by the mapping network tends to be much less entangled, so directions in $W$ correspond more closely to individual, meaningful attributes. This enables smoother interpolation and more precise attribute control than earlier GANs achieved.
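Style mixing and $W$-space interpolation both come down to deciding which $w$ each synthesis layer receives. The sketch below uses a hypothetical miniature synthesis network (ToySynthesis: a learned constant input followed by conv + AdaIN layers that upsample from 4×4 to 16×16); it omits most details of the real StyleGAN synthesis network, and the layer count, channel width, and crossover index are arbitrary illustration values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
W_DIM, CH, NUM_LAYERS = 64, 16, 6

class AdaIN(nn.Module):
    """Same idea as the AdaIN module above: per-channel scale and shift predicted from w."""
    def __init__(self, w_dim, channels):
        super().__init__()
        self.to_style = nn.Linear(w_dim, 2 * channels)
        self.norm = nn.InstanceNorm2d(channels)
    def forward(self, x, w):
        gamma, beta = self.to_style(w).chunk(2, dim=1)
        return gamma[:, :, None, None] * self.norm(x) + beta[:, :, None, None]

class ToySynthesis(nn.Module):
    """Hypothetical miniature synthesis network: learned constant input, one AdaIN per layer."""
    def __init__(self):
        super().__init__()
        self.const = nn.Parameter(torch.randn(1, CH, 4, 4))
        self.convs = nn.ModuleList([nn.Conv2d(CH, CH, 3, padding=1) for _ in range(NUM_LAYERS)])
        self.adains = nn.ModuleList([AdaIN(W_DIM, CH) for _ in range(NUM_LAYERS)])
    def forward(self, ws):
        # ws: list with one (B, W_DIM) style vector per layer
        x = self.const.expand(ws[0].size(0), -1, -1, -1)
        for i in range(NUM_LAYERS):
            if i in (2, 4):  # upsample 4 -> 8 -> 16, so early layers work at coarse resolution
                x = F.interpolate(x, scale_factor=2)
            x = self.adains[i](F.leaky_relu(self.convs[i](x), 0.2), ws[i])
        return x

def style_mix(w_a, w_b, crossover):
    """Layers before `crossover` get w_a (coarse styles); the remaining layers get w_b (fine styles)."""
    return [w_a if i < crossover else w_b for i in range(NUM_LAYERS)]

mapping = nn.Sequential(nn.Linear(W_DIM, W_DIM), nn.LeakyReLU(0.2),
                        nn.Linear(W_DIM, W_DIM), nn.LeakyReLU(0.2))
synth = ToySynthesis()

with torch.no_grad():
    w_a = mapping(torch.randn(1, W_DIM))
    w_b = mapping(torch.randn(1, W_DIM))
    mixed = synth(style_mix(w_a, w_b, crossover=2))
    # Interpolation in W: one blended w is broadcast to every layer
    frames = [synth([(1 - t) * w_a + t * w_b] * NUM_LAYERS) for t in torch.linspace(0, 1, 5)]

print(f"Mixed output: {tuple(mixed.shape)}, interpolation frames: {len(frames)}")
```

Moving the crossover later hands more layers (and therefore more of the coarse structure) to w_a; interpolating in $W$ rather than $Z$ is what the disentanglement argument above is meant to make smooth.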
Let's also look at the cycle consistency concept used in CycleGAN:
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Residual block used in CycleGAN generators."""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        )
    def forward(self, x):
        return x + self.block(x)
class CycleGANGenerator(nn.Module):
    """
    Simplified CycleGAN Generator (encoder-transformer-decoder).
    Uses Instance Normalization (not Batch Norm) for style transfer.
    """
    def __init__(self, in_channels=3, num_residuals=6):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
        )
        # Transformer (residual blocks)
        self.transformer = nn.Sequential(
            *[ResidualBlock(128) for _ in range(num_residuals)]
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, in_channels, 7),
            nn.Tanh(),
        )
    def forward(self, x):
        return self.decoder(self.transformer(self.encoder(x)))
# Demonstrate cycle consistency concept
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Two generators: A->B and B->A
G_AB = CycleGANGenerator(in_channels=3).to(device)
G_BA = CycleGANGenerator(in_channels=3).to(device)
# Simulate an image from domain A (e.g., horse photo)
img_A = torch.randn(1, 3, 128, 128, device=device)
# Translate A -> B (horse -> zebra)
fake_B = G_AB(img_A)
# Translate back B -> A (zebra -> horse)
reconstructed_A = G_BA(fake_B)
# Cycle consistency loss: ||G_BA(G_AB(A)) - A||
cycle_loss = nn.L1Loss()(reconstructed_A, img_A)
print(f"Original image A: {img_A.shape}")
print(f"Translated to B: {fake_B.shape}")
print(f"Reconstructed A: {reconstructed_A.shape}")
print(f"Cycle consistency loss: {cycle_loss.item():.4f}")
print(f"\nGoal: minimize this loss so translation is reversible")
print(f"This ensures the GAN preserves content while changing style!")
CycleGAN's genius is the cycle consistency loss: if you translate a horse image to a zebra and then back to a horse, you should get the original image back. This constraint ensures the generator preserves the underlying content (pose, background, shape) while only changing the domain-specific attributes (stripes vs solid color). Notice the use of Instance Normalization instead of Batch Normalization — this is standard for style transfer networks because it normalizes each image independently, preserving per-image style information.
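In the full objective, the cycle term is applied in both directions (A→B→A and B→A→B) and added to adversarial losses from one discriminator per domain. The sketch below follows the setup reported in the CycleGAN paper as we understand it (least-squares adversarial loss, cycle weight $\lambda = 10$), but swaps in tiny hypothetical stand-in networks so it runs quickly; the Adam settings are typical GAN values rather than something this section specifies.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def tiny_generator():
    # Hypothetical stand-in for CycleGANGenerator, kept small so the sketch runs fast
    return nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(16, 3, 3, padding=1), nn.Tanh())

def tiny_discriminator():
    # Hypothetical patch-style discriminator: outputs a grid of real/fake scores
    return nn.Sequential(nn.Conv2d(3, 16, 4, 2, 1), nn.LeakyReLU(0.2),
                         nn.Conv2d(16, 1, 4, 1, 1))

G_AB, G_BA = tiny_generator(), tiny_generator()
D_A, D_B = tiny_discriminator(), tiny_discriminator()
mse, l1 = nn.MSELoss(), nn.L1Loss()
lambda_cyc = 10.0
opt_G = torch.optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()),
                         lr=2e-4, betas=(0.5, 0.999))

real_A = torch.rand(2, 3, 64, 64) * 2 - 1   # simulated domain-A batch in [-1, 1]
real_B = torch.rand(2, 3, 64, 64) * 2 - 1   # simulated domain-B batch (unpaired with real_A)

fake_B = G_AB(real_A)
fake_A = G_BA(real_B)

# Least-squares adversarial loss: each generator wants its target-domain D to output 1
pred_B, pred_A = D_B(fake_B), D_A(fake_A)
adv_loss = mse(pred_B, torch.ones_like(pred_B)) + mse(pred_A, torch.ones_like(pred_A))

# Cycle consistency in both directions: A -> B -> A and B -> A -> B
cyc_loss = l1(G_BA(fake_B), real_A) + l1(G_AB(fake_A), real_B)

g_loss = adv_loss + lambda_cyc * cyc_loss
opt_G.zero_grad()
g_loss.backward()
opt_G.step()  # only the generators are updated here
print(f"adv: {adv_loss.item():.4f}, cycle: {cyc_loss.item():.4f}, total G: {g_loss.item():.4f}")
```

The discriminators D_A and D_B would be trained in a separate step with their own optimizer (not shown); the large cycle weight is what keeps the two generators from drifting into translations that look plausible but discard the input's content.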
Conclusion & Next Steps
In this deep dive, we've built a complete understanding of Generative Adversarial Networks from the ground up:
- Foundations: The minimax game, Generator vs Discriminator roles, and Nash equilibrium
- Architecture: Transposed convolutions for upsampling (G), strided convolutions for downsampling (D)
- Training: Alternating optimization, label smoothing, proper hyperparameters
- DCGAN: Architectural guidelines that made convolutional GANs trainable
- Stability: WGAN-GP gradient penalty for robust training
- Control: Conditional GANs for directed generation
- Evaluation: FID and Inception Score metrics
- Advanced: StyleGAN (disentangled control), CycleGAN (unpaired translation)
Suggested exercises:
- Train a DCGAN on CIFAR-10 (32×32 color images) and track FID over training
- Implement WGAN-GP and compare training stability against standard GAN on MNIST
- Build a Conditional GAN that can generate specific Fashion-MNIST classes on demand
- Experiment with interpolation in latent space — generate images between two noise vectors
GANs remain one of the most active areas of deep learning research. While diffusion models have recently surpassed GANs in image quality for some tasks, GANs still excel at real-time generation (a sample takes a single forward pass, so inference is much faster), video synthesis, and domain adaptation. A solid understanding of GAN architectures carries over to much of modern generative modeling.