What is Image Segmentation?
Image segmentation is the task of assigning a class label to every single pixel in an image. Unlike classification (which outputs one label for the whole image) or object detection (which draws bounding boxes), segmentation produces a pixel-perfect map that outlines every object's exact boundary. Think of it as "coloring in" each pixel according to what object it belongs to.
Real-world applications of image segmentation are everywhere: radiologists use it to outline tumors in MRI scans, self-driving cars segment roads from pedestrians, satellite systems map forest coverage, and AR filters separate your face from the background in real-time.
Types of Segmentation
There are three major types of segmentation you should know about, each more complex than the last:
- Semantic Segmentation: Every pixel gets a class label, but different instances of the same class are not distinguished. Two cars touching each other form one "car blob."
- Instance Segmentation: Not only labels each pixel, but also separates individual object instances. Those two touching cars get unique IDs (car-1, car-2).
- Panoptic Segmentation: Combines both — it segments "stuff" (sky, road, grass) semantically and "things" (cars, people) by instance.
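To make the semantic/instance distinction concrete, here is a minimal sketch using hypothetical toy arrays (NumPy only) for the two touching cars. The semantic map stores a class index per pixel, while the instance map stores a separate ID for each object:

```python
import numpy as np

# Hypothetical 6x8 scene: two touching "cars" (class 1) on background (class 0)
semantic = np.zeros((6, 8), dtype=np.int64)
semantic[1:4, 1:4] = 1   # car A
semantic[1:4, 4:7] = 1   # car B, directly adjacent to car A

# Instance map: 0 = background, 1 = car-1, 2 = car-2
instance = np.zeros_like(semantic)
instance[1:4, 1:4] = 1
instance[1:4, 4:7] = 2

print(semantic)
print(instance)
print("Distinct cars in semantic map:", len(np.unique(semantic[semantic > 0])))  # 1
print("Distinct cars in instance map:", len(np.unique(instance[instance > 0])))  # 2
```

Both maps mark exactly the same foreground pixels; only the instance map can tell the two cars apart.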
U-Net was originally designed for semantic segmentation (specifically, biomedical cell segmentation), and that's what we'll focus on building from scratch. Before building it, let's visualize a simple segmentation task in code:
import torch
import numpy as np
import matplotlib.pyplot as plt
# Simulate a simple segmentation scenario
# Input: 256x256 grayscale image (like a medical scan)
# Output: 256x256 binary mask (foreground vs background)
# Create a synthetic "medical image" with a circular object
image = np.zeros((256, 256), dtype=np.float32)
y, x = np.ogrid[-128:128, -128:128]
circle_mask = (x**2 + y**2) < 60**2
image[circle_mask] = 0.8
image += np.random.normal(0, 0.1, image.shape) # Add noise
# The ground truth segmentation mask
ground_truth = np.zeros((256, 256), dtype=np.float32)
ground_truth[circle_mask] = 1.0
# Visualize the task
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(image, cmap='gray')
axes[0].set_title('Input Image')
axes[1].imshow(ground_truth, cmap='RdYlBu')
axes[1].set_title('Ground Truth Mask')
axes[2].imshow(image, cmap='gray', alpha=0.7)
axes[2].imshow(ground_truth, cmap='Reds', alpha=0.3)
axes[2].set_title('Overlay')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()
print(f"Image shape: {image.shape}")
print(f"Mask shape: {ground_truth.shape}")
print(f"Unique mask values: {np.unique(ground_truth)}")
This simple example illustrates the core challenge: given a noisy input image, our model must predict a clean per-pixel mask. In real scenarios, the shapes are far more complex — irregular tumors, overlapping cells, or winding roads — which is exactly why we need a powerful architecture like U-Net.
The U-Net Architecture
U-Net was introduced in 2015 by Olaf Ronneberger et al. at the University of Freiburg for biomedical image segmentation. The name comes from its distinctive U-shaped architecture: the left side contracts (encodes), the bottom bridges, and the right side expands (decodes). What made it revolutionary was not just the encoder-decoder design (which existed before), but the skip connections that concatenate encoder features directly to the decoder at each level.
The Key Innovation: Concatenation Skip Connections
Plain encoder-decoder networks lose fine spatial details during downsampling, because the decoder only sees the compressed bottleneck. U-Net's genius was to copy encoder feature maps and concatenate them with decoder feature maps at matching resolutions. This gives the decoder access to both high-resolution spatial details (from the encoder) AND high-level semantic context (from the upsampling path). Compared with element-wise addition (as in ResNet's residual connections), concatenation keeps both sets of features intact in separate channels, doubling the channel count, and lets the following convolutions learn how best to combine them.
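A minimal sketch of the difference in tensor shapes, using random stand-in tensors (the sizes are arbitrary and chosen only for illustration). Addition fuses the two feature maps into one with the same channel count, while `torch.cat` along `dim=1` stacks them so both remain available to the next convolution:

```python
import torch

# Hypothetical decoder and encoder features at the same resolution
decoder_feat = torch.randn(1, 64, 32, 32)   # upsampled decoder output
encoder_feat = torch.randn(1, 64, 32, 32)   # skip connection from the encoder

added = decoder_feat + encoder_feat                             # ResNet-style fusion
concatenated = torch.cat([encoder_feat, decoder_feat], dim=1)   # U-Net-style fusion

print(added.shape)         # torch.Size([1, 64, 32, 32])
print(concatenated.shape)  # torch.Size([1, 128, 32, 32])

# With concatenation, both sources survive unchanged in separate channels
assert torch.equal(concatenated[:, :64], encoder_feat)
assert torch.equal(concatenated[:, 64:], decoder_feat)
```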
flowchart TD
    subgraph Encoder["Encoder (Contracting Path)"]
        E1["Conv Block 1<br/>64 channels"] --> P1["MaxPool 2×2"]
        P1 --> E2["Conv Block 2<br/>128 channels"]
        E2 --> P2["MaxPool 2×2"]
        P2 --> E3["Conv Block 3<br/>256 channels"]
        E3 --> P3["MaxPool 2×2"]
        P3 --> E4["Conv Block 4<br/>512 channels"]
        E4 --> P4["MaxPool 2×2"]
    end
    subgraph Bridge["Bottleneck"]
        P4 --> B["Conv Block 5<br/>1024 channels"]
    end
    subgraph Decoder["Decoder (Expanding Path)"]
        B --> U4["UpConv 512"]
        U4 --> C4["Concat + Conv Block<br/>512 channels"]
        C4 --> U3["UpConv 256"]
        U3 --> C3["Concat + Conv Block<br/>256 channels"]
        C3 --> U2["UpConv 128"]
        U2 --> C2["Concat + Conv Block<br/>128 channels"]
        C2 --> U1["UpConv 64"]
        U1 --> C1["Concat + Conv Block<br/>64 channels"]
    end
    C1 --> OUT["1×1 Conv → num_classes"]
    E4 -.->|"Skip Connection"| C4
    E3 -.->|"Skip Connection"| C3
    E2 -.->|"Skip Connection"| C2
    E1 -.->|"Skip Connection"| C1
Notice the symmetric structure: for every downsampling step in the encoder, there's a corresponding upsampling step in the decoder, connected by a skip connection. The dashed arrows represent the feature maps being copied from the encoder and concatenated with the upsampled decoder output. This is the essence of U-Net's power.
Encoder (Contracting Path)
The encoder is the left side of the "U". Its job is to progressively capture higher-level features while reducing spatial resolution. Each encoder level consists of a double convolution block (two 3×3 convolutions, each followed by Batch Normalization and ReLU) followed by 2×2 max pooling for downsampling. The channel count doubles at each level: 64 → 128 → 256 → 512, reaching 1024 at the bottleneck. This is a common modern variant of the design: the original paper used unpadded convolutions and no Batch Normalization.
Let's build the fundamental building block — the double convolution — which is used in both the encoder and decoder:
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    """Double convolution block: (Conv2d -> BN -> ReLU) × 2
    This is the fundamental building block of U-Net, used in both
    the encoder and decoder paths.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)
# Test the double conv block
block = DoubleConv(in_channels=3, out_channels=64)
sample_input = torch.randn(1, 3, 256, 256) # batch=1, RGB, 256x256
output = block(sample_input)
print(f"Input shape: {sample_input.shape}") # [1, 3, 256, 256]
print(f"Output shape: {output.shape}") # [1, 64, 256, 256]
print(f"Parameters: {sum(p.numel() for p in block.parameters()):,}")
Notice that we use padding=1 with a 3×3 kernel, which preserves spatial dimensions. We also set bias=False in the Conv2d layers: BatchNorm subtracts the per-channel mean, which cancels any constant bias the convolution adds, and then applies its own learnable shift (the beta parameter). The conv bias is therefore redundant, and dropping it is a common best practice that saves a few parameters.
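The redundancy of the conv bias can be checked numerically. This sketch (arbitrary sizes, random inputs) gives two convolutions identical weights, gives only one of them a large bias, and passes each through a fresh `BatchNorm2d` in training mode, where normalization uses the batch statistics:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 16, 16)

conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # identical weights
    conv_with_bias.bias.fill_(5.0)                    # large constant bias on one conv only

# Fresh BatchNorm layers start in training mode, so they normalize with batch statistics
bn_a, bn_b = nn.BatchNorm2d(8), nn.BatchNorm2d(8)
out_a = bn_a(conv_with_bias(x))
out_b = bn_b(conv_no_bias(x))

# Subtracting the per-channel batch mean removes the constant bias entirely
print(torch.allclose(out_a, out_b, atol=1e-4))  # True
```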
Implementing the Encoder
The complete encoder applies the DoubleConv block at each level, then halves the spatial dimensions with max pooling. After 4 pooling operations, a 256×256 input becomes 16×16, but with 512 channels of rich feature information. The encoder is essentially the feature extraction backbone — you can think of it like the first half of a VGG or ResNet without the classification head.
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)
class Encoder(nn.Module):
    """U-Net encoder: 4 levels of DoubleConv + MaxPool."""
    def __init__(self, in_channels=3):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    def forward(self, x):
        # Each level: apply convs, save skip, then pool
        skip1 = self.enc1(x)  # 64 channels, full resolution
        skip2 = self.enc2(self.pool(skip1))  # 128 ch, 1/2 resolution
        skip3 = self.enc3(self.pool(skip2))  # 256 ch, 1/4 resolution
        skip4 = self.enc4(self.pool(skip3))  # 512 ch, 1/8 resolution
        # Return pooled output (goes to bottleneck) + skip connections
        pooled = self.pool(skip4)  # 512 ch, 1/16 resolution
        return pooled, [skip1, skip2, skip3, skip4]
# Test the encoder
encoder = Encoder(in_channels=3)
x = torch.randn(1, 3, 256, 256)
bottleneck_input, skips = encoder(x)
print("Encoder feature map progression:")
print(f" Input: {x.shape}")
for i, skip in enumerate(skips, 1):
    print(f" Skip {i}: {skip.shape}")
print(f" To bottleneck: {bottleneck_input.shape}")
The key detail here is that we save the output of each DoubleConv before pooling. These saved feature maps (skip1 through skip4) will be passed to the decoder later. They contain fine-grained spatial information that would otherwise be lost after downsampling.
Bottleneck
The bottleneck sits at the bottom of the "U" — it's the deepest level with the smallest spatial dimensions but the most channels (1024). Its purpose is to capture global context: at this point, the receptive field is large enough that each neuron "sees" a significant portion of the original image. The bottleneck processes the most compressed representation before the decoder starts expanding it back to full resolution.
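The "large receptive field" claim can be made concrete with the standard receptive-field recurrence. The sketch below treats each layer as a (kernel_size, stride) pair and walks the encoder path up to the bottleneck output. It computes the theoretical receptive field, an upper bound on how much of the input can influence one unit; in practice the effective receptive field is concentrated toward its center.

```python
def receptive_field(layers):
    """Theoretical receptive field for a stack of (kernel_size, stride) layers.

    Recurrence: r_out = r_in + (k - 1) * j_in and j_out = j_in * s, where j is the
    cumulative stride (distance in input pixels between adjacent output units).
    Padding shifts the field but does not change its size.
    """
    r, j = 1, 1
    for k, s in layers:
        r += (k - 1) * j
        j *= s
    return r, j

conv3x3 = (3, 1)                  # 3x3 conv, stride 1
maxpool = (2, 2)                  # 2x2 max pool, stride 2
double_conv = [conv3x3, conv3x3]

# enc1..enc4, each followed by a max pool, then the bottleneck DoubleConv
path = []
for _ in range(4):
    path += double_conv + [maxpool]
path += double_conv

r, j = receptive_field(path)
print(f"Bottleneck receptive field: {r}x{r} input pixels (stride {j})")  # 140x140 (stride 16)
```

For a 256×256 input, each bottleneck unit can therefore draw on a 140×140 window, more than half the image in each direction.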
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)
# The bottleneck is simply another DoubleConv — but at the deepest level
bottleneck = DoubleConv(512, 1024)
# Simulating what arrives at the bottleneck after 4 pooling steps
# from a 256x256 input: 256 -> 128 -> 64 -> 32 -> 16
x = torch.randn(1, 512, 16, 16)
output = bottleneck(x)
print(f"Bottleneck input: {x.shape}") # [1, 512, 16, 16]
print(f"Bottleneck output: {output.shape}") # [1, 1024, 16, 16]
print(f"Bottleneck params: {sum(p.numel() for p in bottleneck.parameters()):,}")
print(f"\nAt 16x16 with 1024 channels, each 'pixel' has a receptive field")
print(f"spanning a large portion of the original 256x256 image.")
The bottleneck is architecturally identical to any other DoubleConv block — what makes it special is its position in the network. It bridges the encoder and decoder, holding the most abstract, semantically rich representation. From here, the decoder will progressively upsample while recovering spatial detail via skip connections.
Decoder (Expanding Path)
The decoder is the right side of the "U". Its job is to upsample the feature maps back to the original resolution while combining them with the skip connections from the encoder. Each decoder level uses ConvTranspose2d (also called "transposed convolution" or "deconvolution") to double the spatial dimensions, then concatenates the result with the corresponding encoder skip connection, and finally applies a DoubleConv block to refine the features.
Transposed Convolutions & Crop-and-Concatenate
ConvTranspose2d is essentially the "reverse" of a regular convolution: instead of shrinking the spatial dimensions, it expands them. Think of it as inserting zeros between input pixels and then convolving; the result is a larger feature map. After upsampling, we concatenate the encoder features (which have the same spatial size) along the channel dimension. If the sizes don't match, for example because pooling floored an odd dimension, they must be reconciled first: the original paper center-crops the encoder features, while the DecoderBlock below zero-pads the upsampled decoder features to the size of the skip connection.
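The size arithmetic can be checked directly. With kernel_size=2, stride=2 and no padding, ConvTranspose2d produces H_out = (H_in − 1)·2 + 2 = 2·H_in. The sketch below (arbitrary channel counts, random stand-in tensors) shows how an odd-sized encoder map breaks that symmetry, then compares two ways of reconciling the sizes: padding the decoder features and center-cropping the skip.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

up = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)
pool = nn.MaxPool2d(2)

# kernel_size=2, stride=2, padding=0 exactly doubles H and W
print(up(torch.randn(1, 16, 15, 15)).shape)   # torch.Size([1, 8, 30, 30])

# An odd-sized encoder map: pooling floors 31 -> 15 ...
skip = torch.randn(1, 8, 31, 31)
down = pool(skip)                              # [1, 8, 15, 15]
deeper = torch.randn(1, 16, *down.shape[2:])   # stand-in for deeper decoder features at 15x15
x = up(deeper)                                 # ... so upsampling gives 30, not 31
print(skip.shape[2:], x.shape[2:])             # torch.Size([31, 31]) torch.Size([30, 30])

# Option A (what DecoderBlock does): zero-pad the decoder features up to the skip size
dh, dw = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
x_padded = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])  # [left, right, top, bottom]

# Option B (the original paper's operation): center-crop the skip down to the decoder size
top, left = dh // 2, dw // 2
skip_cropped = skip[:, :, top:top + x.size(2), left:left + x.size(3)]

print(torch.cat([skip, x_padded], dim=1).shape)   # torch.Size([1, 16, 31, 31])
print(torch.cat([skip_cropped, x], dim=1).shape)  # torch.Size([1, 16, 30, 30])
```

Padding keeps the output at the skip's resolution, which is convenient when the final mask must match the input size; cropping discards border pixels instead.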
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)
class DecoderBlock(nn.Module):
    """Single decoder level: UpConv -> Concat with skip -> DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Transposed conv maps in_channels -> out_channels (halved here) and doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # After concat: out_channels (from up) + out_channels (from skip) = 2*out_channels
        self.conv = DoubleConv(out_channels * 2, out_channels)
    def forward(self, x, skip):
        # Step 1: Upsample
        x = self.up(x)
        # Step 2: Handle size mismatch by zero-padding x up to the skip's height and width
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                      diff_h // 2, diff_h - diff_h // 2])
        # Step 3: Concatenate along channel dimension
        x = torch.cat([skip, x], dim=1)
        # Step 4: Refine with double convolution
        return self.conv(x)
# Test one decoder block
decoder_block = DecoderBlock(in_channels=1024, out_channels=512)
# Simulating: bottleneck output and encoder skip connection
bottleneck_out = torch.randn(1, 1024, 16, 16)
encoder_skip = torch.randn(1, 512, 32, 32) # From encoder level 4
output = decoder_block(bottleneck_out, encoder_skip)
print(f"Bottleneck output: {bottleneck_out.shape}")
print(f"Encoder skip: {encoder_skip.shape}")
print(f"Decoder output: {output.shape}")
print(f"\n1024ch@16x16 -> UpConv -> 512ch@32x32 -> Concat(skip) -> 1024ch@32x32 -> DoubleConv -> 512ch@32x32")
The channel flow is important to understand: ConvTranspose2d takes 1024 channels down to 512 (and doubles spatial size). Then we concatenate with the 512-channel skip connection, giving 1024 channels. The DoubleConv then reduces back to 512. This pattern repeats at each level: 1024→512, 512→256, 256→128, 128→64.
Skip Connections: The Secret Sauce
Skip connections are the architectural innovation that made U-Net so effective. They solve a fundamental tension in segmentation: to classify what an object is, you need high-level semantic features (obtained by deep layers with large receptive fields). But to precisely locate object boundaries, you need fine-grained spatial details (which are lost during pooling). Skip connections give the decoder both — simultaneously.
Comparing: With vs Without Skip Connections
To really appreciate what skip connections do, let's build a version without them and compare the information flow. Without skip connections, the decoder must hallucinate fine spatial details from just the bottleneck — a 16×16 feature map trying to reconstruct 256×256 precision. With skip connections, the decoder gets direct access to the full-resolution features it needs.
flowchart LR
    subgraph Without["Without Skip Connections"]
        A1["256×256<br/>Input"] --> B1["16×16<br/>Bottleneck"] --> C1["256×256<br/>Output"]
    end
    subgraph With["With Skip Connections"]
        A2["256×256<br/>Input"] --> B2["16×16<br/>Bottleneck"] --> C2["256×256<br/>Output"]
        A2 -->|"Fine details<br/>at each scale"| C2
    end
import torch
import torch.nn as nn
# Demonstrate the information difference with and without skips
class SimpleEncoderDecoder(nn.Module):
    """Encoder-decoder WITHOUT skip connections."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 1, 2, stride=2), nn.Sigmoid(),
        )
    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)
class SkipEncoderDecoder(nn.Module):
    """Encoder-decoder WITH skip connections (U-Net style)."""
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(1, 64, 3, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(192, 64, 3, padding=1), nn.ReLU())  # 192 = 64 (up) + 128 (skip2)
        self.up2 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1), nn.ReLU())  # 128 = 64 (up) + 64 (skip1)
        self.out = nn.Conv2d(64, 1, kernel_size=1)
        self.final = nn.Sigmoid()
    def forward(self, x):
        skip1 = self.enc1(x)                           # 64 ch, full resolution
        skip2 = self.enc2(self.pool(skip1))            # 128 ch, 1/2 resolution
        bottom = self.pool(skip2)                      # 128 ch, 1/4 resolution (same as the no-skip bottleneck)
        up = self.up1(bottom)                          # 64 ch, 1/2 resolution
        up = self.dec1(torch.cat([up, skip2], dim=1))  # Concatenate skip2, then back to 64 ch
        up = self.up2(up)                              # 64 ch, full resolution
        up = self.dec2(torch.cat([up, skip1], dim=1))  # Concatenate full-resolution skip1!
        return self.final(self.out(up))
# Compare parameter counts and information flow
no_skip = SimpleEncoderDecoder()
with_skip = SkipEncoderDecoder()
x = torch.randn(1, 1, 64, 64)
print(f"Without skips - Output: {no_skip(x).shape}")
print(f"With skips - Output: {with_skip(x).shape}")
print(f"\nWithout skips params: {sum(p.numel() for p in no_skip.parameters()):,}")
print(f"With skips params: {sum(p.numel() for p in with_skip.parameters()):,}")
print(f"\nThe skip-connected version has access to full-resolution encoder")
print(f"features, enabling precise boundary reconstruction.")
In practice, models without skip connections produce "blobby" segmentation masks that capture approximate shape but miss fine edges. With skip connections, the masks closely follow object boundaries because the decoder can reference the encoder's high-resolution edge and texture information directly.
Building the Complete U-Net
Now we have all the pieces, so let's assemble the full U-Net. The architecture combines our encoder, bottleneck, and decoder with a final 1×1 convolution that maps from 64 feature channels to the desired number of output classes. For binary segmentation (foreground vs background), use num_classes=1 and apply a sigmoid, either inside the loss (nn.BCEWithLogitsLoss) or at inference time. For multi-class segmentation, use num_classes=N and return raw logits without a sigmoid; nn.CrossEntropyLoss applies the softmax internally.
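In practice, the binary vs. multi-class choice shows up mostly in the target format each loss expects. A minimal sketch, with random tensors standing in for U-Net outputs (the shapes are arbitrary):

```python
import torch
import torch.nn as nn

B, H, W = 2, 32, 32

# Binary: 1 output channel; float targets shaped like the logits.
# BCEWithLogitsLoss applies the sigmoid internally.
binary_logits = torch.randn(B, 1, H, W)                    # stand-in for UNet(num_classes=1) output
binary_target = (torch.rand(B, 1, H, W) > 0.5).float()
binary_loss = nn.BCEWithLogitsLoss()(binary_logits, binary_target)
binary_pred = (torch.sigmoid(binary_logits) > 0.5).long()  # [B, 1, H, W] of 0/1

# Multi-class: N output channels; int64 class-index targets WITHOUT a channel dim.
# CrossEntropyLoss applies log-softmax over dim=1 internally.
N = 4
multi_logits = torch.randn(B, N, H, W)                     # stand-in for UNet(num_classes=4) output
multi_target = torch.randint(0, N, (B, H, W))              # dtype torch.int64
multi_loss = nn.CrossEntropyLoss()(multi_logits, multi_target)
multi_pred = multi_logits.argmax(dim=1)                    # [B, H, W] of class indices

print(binary_loss.item(), binary_pred.shape)
print(multi_loss.item(), multi_pred.shape)
```

A common source of shape errors is passing a [B, 1, H, W] float mask to CrossEntropyLoss; it expects [B, H, W] integer class indices.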
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """(Conv2d -> BN -> ReLU) × 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)
class UNet(nn.Module):
    """Full U-Net implementation for image segmentation."""
    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Encoder (contracting path)
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder (expanding path)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)  # 512 (up) + 512 (skip) = 1024 in
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)  # 256 + 256 = 512 in
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)  # 128 + 128 = 256 in
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)  # 64 + 64 = 128 in
        # Final 1x1 convolution
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
    def forward(self, x):
        # Encoder
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        s3 = self.enc3(self.pool(s2))
        s4 = self.enc4(self.pool(s3))
        # Bottleneck
        b = self.bottleneck(self.pool(s4))
        # Decoder with skip connections
        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, s4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, s3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, s2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, s1], dim=1))
        return self.final_conv(d1)
# Instantiate and test
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(1, 3, 256, 256)
output = model(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"\nThe output is the same spatial size as input — one prediction per pixel!")
That's the complete U-Net in about 60 lines of code! The forward pass mirrors the symmetric structure: encode through 4 levels, pass through the bottleneck, then decode through 4 levels with skip connections. The final 1×1 convolution collapses the 64-channel feature map into a single-channel prediction (for binary segmentation). Each output value is a raw logit; applying a sigmoid turns it into the model's estimated probability that the corresponding input pixel belongs to the foreground class.
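Because the 4-level U-Net pools four times, it needs input heights and widths divisible by 16; otherwise the upsampled maps come out one pixel short and `torch.cat` fails. The `predict_mask` helper below is a hypothetical sketch (not a PyTorch API): it zero-pads the input up to the next multiple of 16, runs the model in eval mode without gradients, crops the logits back, and thresholds the sigmoid output. To stay self-contained, it is demonstrated on a tiny stand-in model; with the `UNet` above you would call `predict_mask(model, image)` the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def predict_mask(model, image, multiple=16, threshold=0.5):
    """Hypothetical helper: pad H and W up to a multiple of `multiple`, run the model,
    crop the logits back to the input size, and threshold the sigmoid probabilities."""
    model.eval()  # BatchNorm switches to its running statistics
    h, w = image.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    padded = F.pad(image, [0, pad_w, 0, pad_h])  # zero-pad on the right and bottom
    logits = model(padded)[..., :h, :w]          # crop back to the original H x W
    probs = torch.sigmoid(logits)
    return (probs > threshold).float(), probs

# Hypothetical stand-in model: like the 4-level U-Net, its output only matches
# the input size when H and W are divisible by 16
stand_in = nn.Sequential(
    nn.Conv2d(3, 1, kernel_size=3, padding=1),
    nn.MaxPool2d(16),
    nn.Upsample(scale_factor=16),
)
image = torch.randn(1, 3, 250, 250)  # 250 is not a multiple of 16
mask, probs = predict_mask(stand_in, image)
print(mask.shape)  # torch.Size([1, 1, 250, 250])
```

Since sigmoid(0) = 0.5, thresholding probabilities at 0.5 is equivalent to thresholding logits at 0. The helper leaves the model in eval mode, so call `model.train()` before resuming training.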
Parameter Count Analysis
Understanding how parameters are distributed across the network helps with debugging and optimization. The bottleneck and deepest decoder levels hold the most parameters due to their high channel counts:
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)
class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, 1)
    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        s3 = self.enc3(self.pool(s2))
        s4 = self.enc4(self.pool(s3))
        b = self.bottleneck(self.pool(s4))
        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, s4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, s3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, s2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, s1], dim=1))
        return self.final_conv(d1)
# Analyze parameter distribution
model = UNet(in_channels=3, num_classes=1)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total U-Net parameters: {total_params:,}")
print(f"Total U-Net size: {total_params * 4 / 1024**2:.1f} MB (float32)\n")
# Per-component breakdown
components = {
    'Encoder 1 (3→64)': model.enc1,
    'Encoder 2 (64→128)': model.enc2,
    'Encoder 3 (128→256)': model.enc3,
    'Encoder 4 (256→512)': model.enc4,
    'Bottleneck (512→1024)': model.bottleneck,
    'Decoder 4 (1024→512)': model.dec4,
    'Decoder 3 (512→256)': model.dec3,
    'Decoder 2 (256→128)': model.dec2,
    'Decoder 1 (128→64)': model.dec1,
}
for name, module in components.items():
    params = sum(p.numel() for p in module.parameters())
    pct = 100.0 * params / total_params
    print(f" {name:30s} {params:>10,} params ({pct:.1f}%)")
You'll notice that the bottleneck and the deepest encoder/decoder levels dominate the parameter count: the 512→1024 and 1024→512 channel transitions involve the largest weight tensors. The breakdown leaves out the four transposed convolutions and the final 1×1 conv, so the listed percentages add up to less than 100%. This concentration is typical of encoder-decoder architectures and explains why reducing the base channel count (e.g., from 64 to 32) can dramatically reduce model size at the cost of some accuracy.
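To see how strongly model size depends on the base channel count, this sketch rebuilds the same layers as the `UNet` above with a configurable width. The `unet_layers` helper and its `base` argument are hypothetical, and no `forward()` is defined because we only count parameters:

```python
import torch.nn as nn

def double_conv(in_ch, out_ch):
    # Same layout as DoubleConv: (Conv3x3 -> BN -> ReLU) x 2, conv bias disabled
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
    )

def unet_layers(in_channels=3, num_classes=1, base=64):
    """Hypothetical helper: the UNet's layers with widths base, 2*base, ..., 16*base."""
    w = [base * 2**i for i in range(5)]             # e.g. [64, 128, 256, 512, 1024]
    layers = [double_conv(in_channels, w[0])]       # enc1
    for i in range(1, 5):                           # enc2, enc3, enc4, bottleneck
        layers.append(double_conv(w[i - 1], w[i]))
    for i in range(4, 0, -1):                       # (up4, dec4) ... (up1, dec1)
        layers.append(nn.ConvTranspose2d(w[i], w[i - 1], kernel_size=2, stride=2))
        layers.append(double_conv(w[i], w[i - 1]))  # input = w[i-1] (up) + w[i-1] (skip)
    layers.append(nn.Conv2d(w[0], num_classes, kernel_size=1))  # final 1x1 conv
    return nn.ModuleList(layers)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

full = count_params(unet_layers(base=64))
for base in (64, 32, 16):
    n = count_params(unet_layers(base=base))
    print(f"base={base:3d}: {n:>12,} params ({n / full:.1%} of base=64)")
```

Halving the base width cuts the count by almost exactly 4×, because each convolution's weight tensor scales with in_channels × out_channels; only the small BatchNorm, bias, and first-layer terms scale linearly.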
Training on a Segmentation Dataset
Let's train our U-Net on a synthetic dataset of random circles. This lets us validate that the architecture works without downloading large datasets. The concept transfers directly to real datasets like Oxford-IIIT Pets, Cityscapes, or medical imaging datasets. We'll create random images with circular objects and binary masks, then train the model to segment them.
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
class SyntheticCircleDataset(Dataset):
    """Generates random images with circles and their segmentation masks."""
    def __init__(self, num_samples=500, image_size=128):
        self.num_samples = num_samples
        self.image_size = image_size
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        size = self.image_size
        # Create noisy grayscale image
        image = np.random.normal(0.3, 0.1, (size, size)).astype(np.float32)
        mask = np.zeros((size, size), dtype=np.float32)
        # Add 1-3 random circles
        num_circles = np.random.randint(1, 4)
        for _ in range(num_circles):
            cx = np.random.randint(20, size - 20)
            cy = np.random.randint(20, size - 20)
            r = np.random.randint(10, 30)
            y, x = np.ogrid[:size, :size]
            circle = ((x - cx)**2 + (y - cy)**2) < r**2
            image[circle] = np.random.uniform(0.7, 1.0)
            mask[circle] = 1.0
        # Add noise to image
        image += np.random.normal(0, 0.05, image.shape).astype(np.float32)
        image = np.clip(image, 0, 1)
        # Convert to tensors: both image and mask get a channel dim
        image_tensor = torch.from_numpy(image).unsqueeze(0)  # [1, H, W]
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)  # [1, H, W]
        return image_tensor, mask_tensor
# Create dataset and dataloader
dataset = SyntheticCircleDataset(num_samples=500, image_size=128)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# Check one batch
images, masks = next(iter(dataloader))
print(f"Batch images shape: {images.shape}") # [8, 1, 128, 128]
print(f"Batch masks shape: {masks.shape}") # [8, 1, 128, 128]
print(f"Image value range: [{images.min():.2f}, {images.max():.2f}]")
print(f"Mask unique values: {torch.unique(masks).tolist()}")
Our synthetic dataset generates grayscale images (1 channel) with 1-3 circles of varying sizes. The ground truth mask is a binary map where 1 = circle pixels, 0 = background. This is a simplified version of what medical imaging segmentation looks like (e.g., segmenting cells or lesions from background tissue).
Training Loop
Now let's write the training loop. We'll use Binary Cross-Entropy with Logits (which combines sigmoid + BCE for numerical stability) as our loss function. The model outputs raw logits, and we apply sigmoid only for visualization and metric computation:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
# --- Redefine model and dataset (independent block) ---
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)
class UNet(nn.Module):
    def __init__(self, in_ch=1, num_classes=1):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(128, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = DoubleConv(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = DoubleConv(64, 32)
        self.final = nn.Conv2d(32, num_classes, 1)
    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        s3 = self.enc3(self.pool(s2))
        b = self.bottleneck(self.pool(s3))
        d3 = self.dec3(torch.cat([self.up3(b), s3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), s2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))
        return self.final(d1)
class SyntheticCircleDataset(Dataset):
    def __init__(self, n=500, size=128):
        self.n, self.size = n, size
    def __len__(self): return self.n
    def __getitem__(self, idx):
        s = self.size
        img = np.random.normal(0.3, 0.1, (s, s)).astype(np.float32)
        mask = np.zeros((s, s), dtype=np.float32)
        for _ in range(np.random.randint(1, 4)):
            cx, cy, r = np.random.randint(20, s-20), np.random.randint(20, s-20), np.random.randint(10, 30)
            y, x = np.ogrid[:s, :s]
            c = ((x-cx)**2 + (y-cy)**2) < r**2
            img[c] = np.random.uniform(0.7, 1.0)
            mask[c] = 1.0
        img = np.clip(img + np.random.normal(0, 0.05, img.shape).astype(np.float32), 0, 1)
        return torch.from_numpy(img).unsqueeze(0), torch.from_numpy(mask).unsqueeze(0)
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_ch=1, num_classes=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
dataloader = DataLoader(SyntheticCircleDataset(500, 128), batch_size=8, shuffle=True)
# Training loop
print(f"Training on: {device}")
for epoch in range(5):
    model.train()
    epoch_loss = 0.0
    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)
        # Forward pass
        predictions = model(images)
        loss = criterion(predictions, masks)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/5 - Loss: {avg_loss:.4f}")
print("\nTraining complete! Model learned to segment circles.")
Notice we used a smaller U-Net for faster training on this simple task: three encoder levels instead of four, with 32→64→128 channels and a 256-channel bottleneck instead of 64→128→256→512 and 1024. The architecture scales: for complex real-world data, you'd use the full-size version with more training epochs. The loss should decrease steadily, indicating the model is learning to distinguish circle pixels from background pixels.
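The training loop above reports only the training loss. A natural next step is to measure Dice and IoU (both defined formally in the next section) on held-out data. The `evaluate` function below is a hypothetical sketch: it runs the model under `torch.no_grad()` in eval mode, thresholds the sigmoid output at 0.5, and averages per-image scores. It is demonstrated on a tiny stand-in model and hand-made masks so the expected scores are known exactly.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, loader, device, threshold=0.5, eps=1e-6):
    """Hypothetical helper: mean per-image Dice and IoU of thresholded predictions."""
    was_training = model.training
    model.eval()
    dices, ious = [], []
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        preds = (torch.sigmoid(model(images)) > threshold).float()
        preds, masks = preds.flatten(1), masks.flatten(1)   # [B, H*W]
        inter = (preds * masks).sum(dim=1)
        total = preds.sum(dim=1) + masks.sum(dim=1)          # |A| + |B|
        dices.append((2 * inter + eps) / (total + eps))
        ious.append((inter + eps) / (total - inter + eps))   # union = |A| + |B| - |A ∩ B|
    model.train(was_training)  # restore the previous mode
    return torch.cat(dices).mean().item(), torch.cat(ious).mean().item()

class ThresholdModel(nn.Module):
    """Hypothetical stand-in: large positive logits wherever intensity > 0.5."""
    def forward(self, x):
        return 20.0 * (x - 0.5)

images = torch.full((4, 1, 32, 32), 0.2)
images[:, :, 8:24, 8:24] = 0.9                  # a bright 16x16 square per image
masks = (images > 0.5).float()
loader = DataLoader(TensorDataset(images, masks), batch_size=2)
dice, iou = evaluate(ThresholdModel(), loader, torch.device("cpu"))
print(f"Dice: {dice:.3f}  IoU: {iou:.3f}")      # Dice: 1.000  IoU: 1.000
```

With the training setup above, you might call `evaluate(model, DataLoader(SyntheticCircleDataset(100, 128), batch_size=8), device)` after each epoch. Averaging per image weights every image equally, which can differ from pooling all pixels in a batch before dividing.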
Loss Functions for Segmentation
Choosing the right loss function is crucial for segmentation. Binary Cross-Entropy (BCE) works well but struggles with class imbalance — a common issue where background pixels vastly outnumber foreground pixels (imagine segmenting a tiny tumor in a large scan). Dice loss directly optimizes the overlap between prediction and ground truth, making it naturally robust to imbalance.
The Dice coefficient measures the overlap between two sets:
$$\text{Dice} = \frac{2|A \cap B|}{|A| + |B|}$$
For soft predictions (probabilities), the differentiable Dice loss is:
$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum p_i g_i + \epsilon}{\sum p_i + \sum g_i + \epsilon}$$
Where $p_i$ are predicted probabilities and $g_i$ are ground truth labels (0 or 1). The $\epsilon$ term (typically 1e-6) prevents division by zero when both prediction and ground truth are empty.
Similarly, Intersection over Union (IoU), also called the Jaccard index, measures overlap:
$$\text{IoU} = \frac{|A \cap B|}{|A \cup B|} = \frac{|A \cap B|}{|A| + |B| - |A \cap B|}$$
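Dice and IoU measure the same overlap on different scales. Using $|A| + |B| = |A \cup B| + |A \cap B|$ and then dividing numerator and denominator by $|A \cup B|$:

$$\text{Dice} = \frac{2|A \cap B|}{|A \cup B| + |A \cap B|} = \frac{2\,\text{IoU}}{1 + \text{IoU}}$$

For example, with $|A| = |B| = 100$ and $|A \cap B| = 50$: $\text{IoU} = 50/150 = 1/3$ and $\text{Dice} = 100/200 = 0.5$, which matches $\frac{2 \cdot 1/3}{1 + 1/3} = 0.5$. Since $\frac{2x}{1+x} \ge x$ for $0 \le x \le 1$, Dice is never lower than IoU for the same prediction, the two coincide only at 0 and 1, and on a single image they rank predictions identically.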
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Differentiable Dice Loss for binary segmentation.
    Directly optimizes the Dice coefficient, which handles class
    imbalance better than BCE because it measures overlap rather
    than per-pixel accuracy.
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
    def forward(self, predictions, targets):
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(predictions)
        # Flatten to [batch, pixels]
        probs_flat = probs.view(probs.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)
        # Compute Dice coefficient
        intersection = (probs_flat * targets_flat).sum(dim=1)
        total = probs_flat.sum(dim=1) + targets_flat.sum(dim=1)  # |A| + |B| (not the union)
        dice = (2.0 * intersection + self.smooth) / (total + self.smooth)
        # Return loss (1 - dice), averaged over batch
        return 1.0 - dice.mean()
class IoUMetric:
    """Compute Intersection over Union (Jaccard Index) for evaluation."""
    def __call__(self, predictions, targets, threshold=0.5):
        probs = torch.sigmoid(predictions)
        preds_binary = (probs > threshold).float()
        intersection = (preds_binary * targets).sum()
        union = preds_binary.sum() + targets.sum() - intersection
        iou = (intersection + 1e-6) / (union + 1e-6)
        return iou.item()
# Test the loss functions
dice_loss_fn = DiceLoss()
bce_loss_fn = nn.BCEWithLogitsLoss()
iou_metric = IoUMetric()
# Simulated predictions and ground truth
predictions = torch.randn(4, 1, 64, 64) # Raw logits
targets = (torch.rand(4, 1, 64, 64) > 0.7).float() # Binary mask
dice = dice_loss_fn(predictions, targets)
bce = bce_loss_fn(predictions, targets)
iou = iou_metric(predictions, targets)
print(f"Dice Loss: {dice.item():.4f}")
print(f"BCE Loss: {bce.item():.4f}")
print(f"IoU Score: {iou:.4f}")
# Show that Dice handles imbalance better
# Scenario: 95% background, 5% foreground (common in medical imaging)
sparse_targets = torch.zeros(4, 1, 64, 64)
sparse_targets[:, :, 30:34, 30:34] = 1.0 # Tiny foreground region
sparse_preds = torch.zeros(4, 1, 64, 64) - 2.0 # Predicts all background
dice_sparse = dice_loss_fn(sparse_preds, sparse_targets)
bce_sparse = bce_loss_fn(sparse_preds, sparse_targets)
print(f"\nWith class imbalance (5% foreground):")
print(f"Dice Loss: {dice_sparse.item():.4f} (high - correctly penalizes missing foreground)")
print(f"BCE Loss: {bce_sparse.item():.4f} (low - misleadingly good from correct background)")
This demonstrates why Dice loss is preferred for imbalanced segmentation: BCE can achieve low loss by simply predicting "all background" (since background dominates), while Dice loss remains high because it directly measures foreground overlap. In practice, many state-of-the-art models use a combined loss that leverages both.
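The gap between the two numbers can be reproduced by hand. The sketch below recomputes both losses analytically for the sparse scenario above (a 4x4 foreground patch in a 64x64 mask, logit -2.0 everywhere) using only the `math` module. Every sample in that batch is identical, so one sample's values equal the batch averages:

```python
import math

n_pixels = 64 * 64          # 4096 pixels per mask
n_fg = 4 * 4                # 16 foreground pixels
n_bg = n_pixels - n_fg      # 4080 background pixels
p = 1.0 / (1.0 + math.exp(2.0))   # sigmoid(-2), about 0.119

# BCE is a per-pixel average, so the 4080 cheap background terms dominate
bce = (n_bg * -math.log(1 - p) + n_fg * -math.log(p)) / n_pixels

# Dice only rewards overlap with the 16 foreground pixels
eps = 1e-6
intersection = n_fg * p
dice_loss = 1 - (2 * intersection + eps) / (n_pixels * p + n_fg + eps)

print(f"Foreground fraction: {n_fg / n_pixels:.2%}")
print(f"BCE:       {bce:.4f}")
print(f"Dice loss: {dice_loss:.4f}")
```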
Combined Loss Strategy
A common practice in modern segmentation is to combine BCE (which provides stable per-pixel gradients) with Dice loss (which keeps the model focused on foreground overlap). This gives you the best of both worlds:
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        probs = torch.sigmoid(predictions)
        probs_flat = probs.view(probs.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)
        intersection = (probs_flat * targets_flat).sum(dim=1)
        denominator = probs_flat.sum(dim=1) + targets_flat.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice.mean()

class CombinedLoss(nn.Module):
    """BCE + Dice Loss: the go-to combination for segmentation."""
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, predictions, targets):
        bce_loss = self.bce(predictions, targets)
        dice_loss = self.dice(predictions, targets)
        return self.bce_weight * bce_loss + self.dice_weight * dice_loss

# Compare all three loss functions
combined_loss_fn = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
bce_fn = nn.BCEWithLogitsLoss()
dice_fn = DiceLoss()

# Good prediction (every pixel correct)
good_pred = torch.ones(2, 1, 32, 32) * 3.0  # High confidence foreground
target = torch.ones(2, 1, 32, 32)
# Bad prediction (all wrong)
bad_pred = torch.ones(2, 1, 32, 32) * -3.0  # High confidence background

print("Good prediction (correct):")
print(f"  BCE:      {bce_fn(good_pred, target).item():.4f}")
print(f"  Dice:     {dice_fn(good_pred, target).item():.4f}")
print(f"  Combined: {combined_loss_fn(good_pred, target).item():.4f}")
print("\nBad prediction (all wrong):")
print(f"  BCE:      {bce_fn(bad_pred, target).item():.4f}")
print(f"  Dice:     {dice_fn(bad_pred, target).item():.4f}")
print(f"  Combined: {combined_loss_fn(bad_pred, target).item():.4f}")
The combined loss ensures stable training (from BCE's smooth gradients) while maintaining focus on segmentation quality (from Dice's overlap optimization). Adjusting the weights lets you tune the balance: increase dice_weight when class imbalance is severe, or increase bce_weight for more stable early training.
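One way to act on that advice is to change the weights over the course of training. The sketch below is a hypothetical linear schedule; the function name `loss_weights` and the 0.2 to 0.8 range are arbitrary choices for illustration, not values prescribed by U-Net. It starts BCE-heavy for stability and shifts toward Dice:

```python
def loss_weights(epoch, total_epochs, start_dice=0.2, end_dice=0.8):
    """Hypothetical linear schedule: lean on BCE early, shift toward Dice later.

    Returns (bce_weight, dice_weight); the two always sum to 1.
    """
    # Fraction of training completed, clamped to [0, 1]
    progress = min(max(epoch / max(total_epochs - 1, 1), 0.0), 1.0)
    dice_weight = start_dice + (end_dice - start_dice) * progress
    return 1.0 - dice_weight, dice_weight

for epoch in [0, 25, 49]:
    bce_w, dice_w = loss_weights(epoch, total_epochs=50)
    print(f"epoch {epoch:2d}: bce_weight={bce_w:.2f}, dice_weight={dice_w:.2f}")
```

Because `CombinedLoss` stores `bce_weight` and `dice_weight` as plain attributes, you could assign the returned pair to them at the start of each epoch; whether a schedule beats fixed weights is something to verify on your own validation set.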
Practical Improvements
The original U-Net from 2015 has been significantly enhanced over the years. Here are the most impactful improvements that modern U-Net variants incorporate. Each addresses a specific limitation of the original architecture.
- Attention U-Net: Adds attention gates that learn which spatial regions and features are most relevant, reducing false positives in irrelevant areas.
- Residual U-Net: Replaces DoubleConv blocks with residual blocks (like ResNet), enabling deeper networks without degradation.
- U-Net++: Introduces nested skip connections with dense sub-networks between encoder and decoder, capturing features at multiple semantic levels.
- Efficient U-Net: Uses depth-wise separable convolutions to dramatically reduce computation while maintaining quality.
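The depthwise separable idea behind the efficient variants is easy to try in isolation. The sketch below (the class names `SeparableConv2d` and `SeparableDoubleConv` are ours, not from a specific paper) factors each 3x3 convolution into a per-channel 3x3 depthwise conv (`groups=in_ch`) followed by a 1x1 pointwise conv, then compares it against a standard `DoubleConv`:

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Standard U-Net block: two (3x3 conv -> BN -> ReLU) layers."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

class SeparableConv2d(nn.Module):
    """3x3 depthwise conv (one filter per input channel) + 1x1 pointwise conv."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch, bias=False)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class SeparableDoubleConv(nn.Module):
    """Drop-in replacement for DoubleConv built from separable convolutions."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            SeparableConv2d(in_ch, out_ch), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            SeparableConv2d(out_ch, out_ch), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

standard = DoubleConv(64, 128)
separable = SeparableDoubleConv(64, 128)
x = torch.randn(1, 64, 32, 32)
print(f"Standard:  output {tuple(standard(x).shape)}, params {count_params(standard):,}")
print(f"Separable: output {tuple(separable(x).shape)}, params {count_params(separable):,}")
```

In this configuration the separable block has roughly 8x fewer parameters. Whether accuracy holds up depends on the task, so treat it as a trade-off to measure rather than a free win.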
Let's implement an Attention Gate, one of the most popular U-Net enhancements. It learns to focus the decoder's attention on relevant spatial regions from the encoder skip connections, suppressing irrelevant features (like background noise in medical scans):
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """Attention gate for U-Net skip connections.
    Learns to weight spatial regions of the encoder features (skip)
    based on the decoder's context (gate signal). Regions the decoder
    finds relevant pass through (weights near 1); irrelevant regions
    are suppressed (weights near 0).
    """
    def __init__(self, gate_channels, skip_channels, inter_channels):
        super().__init__()
        # Transform gate signal (from decoder)
        self.W_gate = nn.Sequential(
            nn.Conv2d(gate_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels)
        )
        # Transform skip connection (from encoder)
        self.W_skip = nn.Sequential(
            nn.Conv2d(skip_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels)
        )
        # Combine and produce attention map
        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip):
        # gate: decoder feature map, normally already upsampled to the skip's resolution
        # skip: encoder feature map from the matching level
        g = self.W_gate(gate)
        s = self.W_skip(skip)
        # If the gate is still at a lower resolution, resize it to match the skip
        if g.shape[-2:] != s.shape[-2:]:
            g = nn.functional.interpolate(g, size=s.shape[-2:], mode="bilinear", align_corners=False)
        # Additive attention
        attention = self.relu(g + s)
        attention = self.psi(attention)  # [B, 1, H, W] attention weights
        # Apply attention to skip connection
        return skip * attention

# Test the attention gate
attn_gate = AttentionGate(gate_channels=256, skip_channels=256, inter_channels=128)
gate_signal = torch.randn(1, 256, 32, 32)  # From decoder (upsampled)
skip_features = torch.randn(1, 256, 32, 32)  # From encoder
attended_skip = attn_gate(gate_signal, skip_features)
print(f"Gate signal: {gate_signal.shape}")
print(f"Skip features: {skip_features.shape}")
print(f"Attended output: {attended_skip.shape}")
print("\nThe attention gate preserves dimensions but re-weights the skip")
print("connection, emphasizing relevant spatial regions.")
print(f"Attention gate params: {sum(p.numel() for p in attn_gate.parameters()):,}")
The attention gate takes two inputs: the decoder's feature map (gate signal, which carries semantic context about "what we're looking for") and the encoder's skip connection (which carries fine spatial details). It produces an attention map that tells the network which spatial regions of the skip connection are worth paying attention to. This is particularly powerful for medical imaging where the target structure (e.g., a tumor) might occupy a tiny region of the image.
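To see where the gate sits in a full decoder, the sketch below wires it into one up-sampling step. `AttentionUpBlock` is a hypothetical name, not a library class: it upsamples the deeper features, uses them as the gate signal to re-weight the skip connection, then concatenates and convolves as a standard U-Net decoder does. The gate here is a compact copy of the one above, modified to also return its attention map so you can inspect or plot it:

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    """Compact additive attention gate; returns (gated_skip, attention_map)."""
    def __init__(self, gate_channels, skip_channels, inter_channels):
        super().__init__()
        self.W_gate = nn.Sequential(nn.Conv2d(gate_channels, inter_channels, 1, bias=False),
                                    nn.BatchNorm2d(inter_channels))
        self.W_skip = nn.Sequential(nn.Conv2d(skip_channels, inter_channels, 1, bias=False),
                                    nn.BatchNorm2d(inter_channels))
        self.psi = nn.Sequential(nn.Conv2d(inter_channels, 1, 1, bias=False),
                                 nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip):
        attention = self.psi(self.relu(self.W_gate(gate) + self.W_skip(skip)))
        return skip * attention, attention

class AttentionUpBlock(nn.Module):
    """Hypothetical decoder step: upsample -> gate the skip -> concat -> two 3x3 convs."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.attn = AttentionGate(out_ch, skip_ch, max(out_ch // 2, 1))
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                          # now at the skip's resolution
        gated_skip, attn_map = self.attn(x, skip)
        return self.conv(torch.cat([x, gated_skip], dim=1)), attn_map

block = AttentionUpBlock(in_ch=512, skip_ch=256, out_ch=256)
deeper = torch.randn(1, 512, 16, 16)   # e.g. bottleneck features
skip = torch.randn(1, 256, 32, 32)     # encoder features one level up
out, attn_map = block(deeper, skip)
print(f"Decoder output: {tuple(out.shape)}")
print(f"Attention map:  {tuple(attn_map.shape)}, range [{attn_map.min():.3f}, {attn_map.max():.3f}]")
```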
Residual U-Net
Another powerful enhancement is replacing the standard DoubleConv blocks with residual blocks. This combines the benefits of ResNet (trainable to greater depths, better gradient flow) with U-Net's encoder-decoder structure:
import torch
import torch.nn as nn

class ResidualConvBlock(nn.Module):
    """Residual block for U-Net: adds input to output for better gradient flow.
    If input and output channels differ, a 1x1 conv adjusts the shortcut.
    This allows building deeper U-Nets without degradation.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Shortcut: adjust channels if needed
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        ) if in_channels != out_channels else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.conv_block(x)
        return self.relu(out + residual)  # Skip connection!

# Compare standard DoubleConv vs ResidualConvBlock
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(True),
        )

    def forward(self, x):
        return self.net(x)

# Test both
standard = DoubleConv(64, 128)
residual = ResidualConvBlock(64, 128)
x = torch.randn(1, 64, 32, 32)
print(f"Standard DoubleConv output: {standard(x).shape}")
print(f"Residual Block output: {residual(x).shape}")
print(f"\nStandard params: {sum(p.numel() for p in standard.parameters()):,}")
print(f"Residual params: {sum(p.numel() for p in residual.parameters()):,}")
print("\nThe residual block adds a small overhead for the shortcut 1x1 conv,")
print("but enables much deeper networks with better gradient flow.")
To create a Residual U-Net, simply replace all DoubleConv modules with ResidualConvBlock in the original U-Net implementation. The skip connections within each residual block (local shortcuts) work in harmony with U-Net's cross-level skip connections (global shortcuts), creating a network with excellent gradient flow at both micro and macro scales.
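Making the block type a constructor argument keeps that swap to a single line. The sketch below uses a deliberately small two-level U-Net so both variants run quickly on CPU; `TinyUNet` and its `block` parameter are illustrative and not part of the earlier implementation:

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

class ResidualConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        # 1x1 projection only when the channel count changes
        self.shortcut = (nn.Sequential(nn.Conv2d(in_ch, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch))
                         if in_ch != out_ch else nn.Identity())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv_block(x) + self.shortcut(x))

class TinyUNet(nn.Module):
    """Two pooling levels; `block` chooses the conv block used at every stage."""
    def __init__(self, in_ch=3, num_classes=1, base=16, block=DoubleConv):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.enc1 = block(in_ch, base)
        self.enc2 = block(base, base * 2)
        self.bottleneck = block(base * 2, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = block(base * 4, base * 2)   # upsampled + skip channels
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = block(base * 2, base)
        self.final = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        s1 = self.enc1(x)                       # H x W
        s2 = self.enc2(self.pool(s1))           # H/2 x W/2
        b = self.bottleneck(self.pool(s2))      # H/4 x W/4
        d2 = self.dec2(torch.cat([self.up2(b), s2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))
        return self.final(d1)

x = torch.randn(2, 3, 64, 64)
results = {}
for block in (DoubleConv, ResidualConvBlock):
    model = TinyUNet(block=block)
    out = model(x)
    n_params = sum(p.numel() for p in model.parameters())
    results[block.__name__] = (tuple(out.shape), n_params)
    print(f"{block.__name__:>17}: output {tuple(out.shape)}, params {n_params:,}")
```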
Use torchsummary, or print tensor shapes in the forward pass, during development. Off-by-one pixel errors from odd input sizes are among the most common debugging headaches with U-Net.
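Printing shapes does not require editing `forward()`. A sketch using PyTorch forward hooks (`register_forward_hook`) records the output shape of every conv and pooling layer; the helper `add_shape_hooks` and the four-layer stand-in model are illustrative. With a 30x30 input, two poolings give 15 and then 7, and a stride-2 transposed conv brings 7 back to only 14, exactly the kind of one-pixel mismatch that breaks a skip concatenation:

```python
import torch
import torch.nn as nn

def add_shape_hooks(model, module_types=(nn.Conv2d, nn.ConvTranspose2d, nn.MaxPool2d)):
    """Register hooks that append (module_name, output_shape) during forward.

    Returns the hook handles (call .remove() on each to detach) and the record list.
    """
    records, handles = [], []
    for name, module in model.named_modules():
        if isinstance(module, module_types):
            # Bind `name` as a default argument so each hook keeps its own name
            def hook(mod, inputs, output, name=name):
                records.append((name, tuple(output.shape)))
            handles.append(module.register_forward_hook(hook))
    return handles, records

# Small stand-in model; any nn.Module (including the U-Net) works the same way
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.MaxPool2d(2),
    nn.MaxPool2d(2),
    nn.ConvTranspose2d(8, 8, 2, stride=2),
)
handles, records = add_shape_hooks(model)
with torch.no_grad():
    model(torch.randn(1, 3, 30, 30))
for name, shape in records:
    print(f"layer {name}: {shape}")
for h in handles:
    h.remove()
```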
Let's verify our U-Net handles various input sizes correctly and demonstrate a useful debugging technique:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(True),
        )

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, in_ch=3, num_classes=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_ch, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.final = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        # Encoder: each pool halves H and W
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        s3 = self.enc3(self.pool(s2))
        s4 = self.enc4(self.pool(s3))
        b = self.bottleneck(self.pool(s4))
        # Decoder: upsample, concatenate the matching skip, convolve
        d4 = self.dec4(torch.cat([self.up4(b), s4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), s3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), s2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))
        return self.final(d1)

# Test with various input sizes
model = UNet(in_ch=3, num_classes=1)
model.eval()

# This U-Net needs inputs divisible by 16 (4 pooling steps); other sizes make torch.cat fail
test_sizes = [64, 128, 256, 512]
print("Input Size → Output Size (should match):")
print("-" * 45)
for size in test_sizes:
    x = torch.randn(1, 3, size, size)
    with torch.no_grad():
        out = model(x)
    status = "✓" if out.shape[2:] == x.shape[2:] else "✗"
    print(f"  {status} {list(x.shape)} → {list(out.shape)}")

print("\nTip: Input spatial dims should be divisible by 2^(num_pool_layers)")
print("     For 4 pooling layers: divisible by 16")
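This U-Net requires input height and width divisible by $2^n$, where $n$ is the number of pooling layers. Otherwise a max-pool floors an odd size (e.g., 15 → 7), the transposed convolution brings it back one pixel short (7 → 14), and torch.cat fails on the skip connection. With 4 pooling layers, inputs must be divisible by 16. For other sizes, either pad to the nearest valid size or use the crop-and-concatenate approach shown in our earlier DecoderBlock implementation.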
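Padding is usually the simpler of the two fixes. A minimal sketch, assuming reflect padding on the bottom and right edges only (the helper names `pad_to_multiple` and `predict_any_size` are ours): pad up to the next multiple of 16, run the model, then crop the output back to the original size. The stand-in model pools by 16 and upsamples by 16, so without padding it would lose pixels on a 100x75 input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16, mode="reflect"):
    """Pad H and W of an NCHW tensor up to the next multiple (bottom/right only).

    Returns the padded tensor and the original (H, W). Reflect padding requires
    each pad amount to be smaller than the corresponding input dimension.
    """
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad lists padding from the last dim backwards: (left, right, top, bottom)
    padded = F.pad(x, (0, pad_w, 0, pad_h), mode=mode)
    return padded, (h, w)

def predict_any_size(model, x, multiple=16):
    """Pad, run the model, and crop the prediction back to the input size."""
    padded, (h, w) = pad_to_multiple(x, multiple)
    out = model(padded)
    return out[..., :h, :w]

# Stand-in "model" whose output only matches its input size for multiples of 16
model = nn.Sequential(nn.MaxPool2d(16), nn.Upsample(scale_factor=16))

x = torch.randn(1, 3, 100, 75)
padded, orig = pad_to_multiple(x)
with torch.no_grad():
    out = predict_any_size(model, x)
print(f"Input {tuple(x.shape)} -> padded {tuple(padded.shape)} -> output {tuple(out.shape)}")
```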
Visualizing Predictions
Finally, let's bring everything together by visualizing predictions. Good segmentation visualization overlays the predicted mask on the original image, making it easy to assess boundary quality. Here a simulated, slightly imperfect prediction stands in for a trained model's output:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Create a simple trained-like scenario for visualization
# (In practice, use your trained model's predictions)
np.random.seed(42)

# Generate sample image with ground truth
size = 128
image = np.random.normal(0.3, 0.1, (size, size)).astype(np.float32)
mask = np.zeros((size, size), dtype=np.float32)
# Add two circles
for cx, cy, r in [(40, 60, 20), (90, 80, 15)]:
    y, x = np.ogrid[:size, :size]
    circle = ((x - cx)**2 + (y - cy)**2) < r**2
    image[circle] = 0.85
    mask[circle] = 1.0
image = np.clip(image + np.random.normal(0, 0.05, image.shape).astype(np.float32), 0, 1)

# Simulate model prediction (slightly imperfect)
pred_mask = np.zeros((size, size), dtype=np.float32)
for cx, cy, r in [(41, 59, 19), (89, 81, 14)]:  # Slight offset, slightly smaller radius
    y, x = np.ogrid[:size, :size]
    circle = ((x - cx)**2 + (y - cy)**2) < r**2
    pred_mask[circle] = 1.0

# Compute metrics
intersection = (pred_mask * mask).sum()
union = pred_mask.sum() + mask.sum() - intersection
iou = intersection / (union + 1e-6)
dice = 2 * intersection / (pred_mask.sum() + mask.sum() + 1e-6)

# Visualization
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(image, cmap='gray')
axes[0].set_title('Input Image')
axes[1].imshow(mask, cmap='Blues')
axes[1].set_title('Ground Truth')
axes[2].imshow(pred_mask, cmap='Reds')
axes[2].set_title(f'Prediction\nIoU: {iou:.3f} | Dice: {dice:.3f}')

# Overlay: mask out zeros so only labeled pixels are tinted (otherwise the
# colormap's pale low end washes out the whole image)
axes[3].imshow(image, cmap='gray')
axes[3].imshow(np.ma.masked_where(mask == 0, mask), cmap='Greens', alpha=0.4, vmin=0, vmax=1)
axes[3].imshow(np.ma.masked_where(pred_mask == 0, pred_mask), cmap='Reds', alpha=0.4, vmin=0, vmax=1)
axes[3].set_title('Overlay (Green=GT, Red=Pred)')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()

print(f"IoU Score: {iou:.4f}")
print(f"Dice Score: {dice:.4f}")
In a real deployment scenario, you'd run your trained model on unseen test images and generate these visualizations to qualitatively assess performance. Common issues to look for include jagged boundaries (often a sign that more training is needed, or that attention gates could help), systematic under-segmentation (try increasing the Dice loss weight), and false positives in background regions (attention gates or longer training can help).
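A typical first step is turning a model's logits into a binary mask for one image. The helper `predict_mask` below is a hypothetical sketch that assumes a single-output binary-segmentation model (like the U-Net above with `num_classes=1`) already placed on `device`. The 1x1-conv stand-in model simply thresholds intensity so the example runs without trained weights:

```python
import numpy as np
import torch
import torch.nn as nn

@torch.no_grad()
def predict_mask(model, image, threshold=0.5, device="cpu"):
    """Return a uint8 HxW mask (1 = foreground) for an HxW or CxHxW image.

    Assumes `model` outputs [B, 1, H, W] logits and already lives on `device`.
    """
    model.eval()
    x = torch.as_tensor(image, dtype=torch.float32)
    if x.ndim == 2:                  # HxW grayscale -> 1xHxW
        x = x.unsqueeze(0)
    x = x.unsqueeze(0).to(device)    # add batch dim -> 1xCxHxW
    probs = torch.sigmoid(model(x))[0, 0]
    return (probs > threshold).cpu().numpy().astype(np.uint8)

# Stand-in model: logit = 10 * intensity - 5, so bright pixels become foreground
model = nn.Conv2d(1, 1, kernel_size=1)
with torch.no_grad():
    model.weight.fill_(10.0)
    model.bias.fill_(-5.0)

image = np.zeros((32, 32), dtype=np.float32)
image[8:16, 8:16] = 1.0
pred = predict_mask(model, image)
print(pred.shape, pred.dtype, int(pred.sum()))
```

The returned array can be passed straight to the overlay code above in place of the simulated `pred_mask`.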