How Diffusion Models Work
Diffusion models generate images by learning to reverse a gradual noising process. The key insight is that destroying structure (adding noise) is trivial, but learning to recover structure (removing noise) produces a powerful generative model.
Forward Process (Adding Noise)
The forward process gradually adds Gaussian noise to a clean image over T timesteps until it becomes pure noise. At each step:
$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} \, x_{t-1}, \, \beta_t I)$$
Where $\beta_t$ is the noise schedule controlling how much noise is added at each step. Using the reparameterization trick, we can jump directly to any timestep:
$$x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
Where $\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$ is the cumulative product of noise retention factors.
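The equivalence between iterating the single-step transition and the closed-form jump is easy to verify numerically: composing the per-step scalings yields exactly $\sqrt{\bar{\alpha}_t}$, and the accumulated noise variance comes out to exactly $1 - \bar{\alpha}_t$. A minimal sketch (the linear schedule values here are purely illustrative):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # illustrative linear schedule
t = 200

# Iterate x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps_t and track
# the scale applied to x_0 and the total variance of the accumulated noise.
scale = 1.0
noise_var = 0.0
for s in range(t):
    scale *= np.sqrt(1.0 - betas[s])
    noise_var = (1.0 - betas[s]) * noise_var + betas[s]

alpha_bar_t = np.prod(1.0 - betas[:t])
print(np.isclose(scale, np.sqrt(alpha_bar_t)))   # True
print(np.isclose(noise_var, 1.0 - alpha_bar_t))  # True
```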
Reverse Process (Denoising)
The reverse process learns to predict and remove noise step by step. A neural network (the U-Net) is trained to predict the noise $\epsilon_\theta(x_t, t)$ present at each timestep; the sampler then uses that prediction to step to a slightly cleaner image.
flowchart LR
A["Clean Image x0"] -->|"Add noise beta_1"| B["x1"]
B -->|"Add noise beta_2"| C["x2"]
C -->|"... T steps ..."| D["xT ~ N(0,I)"]
D -->|"U-Net predicts noise"| E["x(T-1)"]
E -->|"Denoise step"| F["x(T-2)"]
F -->|"... T steps ..."| G["Generated Image"]
style A fill:#3B9797,color:#fff
style D fill:#BF092F,color:#fff
style G fill:#3B9797,color:#fff
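The denoising update itself is a small formula. Here is a minimal NumPy sketch of one ancestral sampling step, with `eps_pred` standing in for the trained network's output $\epsilon_\theta(x_t, t)$; the function name and the variance choice $\sigma_t^2 = \beta_t$ are illustrative assumptions, not the only option:

```python
import numpy as np

def ddpm_reverse_step(x_t, eps_pred, t, betas, alpha_bar, rng):
    """One reverse (denoising) step: x_t -> x_{t-1}.
    eps_pred plays the role of the trained U-Net's noise prediction."""
    alpha_t = 1.0 - betas[t]
    # Posterior mean: remove the predicted noise contribution, then rescale
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_pred) / np.sqrt(alpha_t)
    if t == 0:
        return mean  # final step adds no fresh noise
    sigma = np.sqrt(betas[t])  # one common choice of reverse-process variance
    return mean + sigma * rng.standard_normal(x_t.shape)

# Tiny demo on random data (placeholder noise predictions, a few steps only)
rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)
x = rng.standard_normal((8, 8))  # stands in for a noisy x_t
for t in reversed(range(995, 1000)):
    eps_pred = rng.standard_normal(x.shape)  # placeholder for eps_theta(x_t, t)
    x = ddpm_reverse_step(x, eps_pred, t, betas, alpha_bar, rng)
print(x.shape)  # (8, 8)
```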
Latent Diffusion
Instead of operating on full-resolution pixels (e.g., 512×512×3 = 786,432 dimensions), Stable Diffusion first compresses the image to a latent space (64×64×4 = 16,384 dimensions) using a VAE. The diffusion process operates entirely in this compressed space, then decodes back to pixel space.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Demonstrate the forward diffusion process
def forward_diffusion_demo(num_steps=5):
"""Show progressive noising of a synthetic image."""
# Create a simple synthetic image (gradient pattern)
x = np.zeros((64, 64, 3), dtype=np.float32)
x[:32, :, 0] = 1.0 # Red top half
x[32:, :, 2] = 1.0 # Blue bottom half
x[:, 32:, 1] = 0.5 # Green right half
# Linear noise schedule
betas = np.linspace(0.0001, 0.02, 1000)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)
# Show progressive noising
timesteps = np.linspace(0, 999, num_steps, dtype=int)
fig, axes = plt.subplots(1, num_steps, figsize=(15, 3))
for i, t in enumerate(timesteps):
noise = np.random.randn(*x.shape).astype(np.float32)
x_t = np.sqrt(alpha_bar[t]) * x + np.sqrt(1 - alpha_bar[t]) * noise
x_t_clipped = np.clip(x_t, 0, 1)
axes[i].imshow(x_t_clipped)
axes[i].set_title(f"t={t}")
axes[i].axis("off")
plt.suptitle("Forward Diffusion: Clean Image to Noise")
plt.tight_layout()
plt.show()
forward_diffusion_demo()
Noise Scheduling
The noise schedule $\{\beta_t\}_{t=1}^T$ controls how quickly images are destroyed during the forward process. The choice of schedule significantly impacts generation quality.
Linear Schedule
The original DDPM paper used a linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$. This destroys information too quickly: by the later timesteps the images are already nearly pure noise, so a large fraction of the forward process (and of training) is spent on samples that carry almost no signal.
Cosine Schedule
The cosine schedule provides a more gradual noise increase, preserving structure longer and giving the model more useful signal to learn from at intermediate timesteps.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
"""Linear noise schedule as in original DDPM."""
return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
def cosine_beta_schedule(timesteps, s=0.008):
"""Cosine noise schedule (improved DDPM)."""
steps = timesteps + 1
t = np.linspace(0, timesteps, steps, dtype=np.float64)
alpha_bar = np.cos(((t / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
alpha_bar = alpha_bar / alpha_bar[0]
betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
return np.clip(betas, 0.0001, 0.999).astype(np.float32)
# Compare schedules
T = 1000
linear_betas = linear_beta_schedule(T)
cosine_betas = cosine_beta_schedule(T)
# Compute alpha_bar for both
linear_alpha_bar = np.cumprod(1.0 - linear_betas)
cosine_alpha_bar = np.cumprod(1.0 - cosine_betas)
# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(linear_betas, label="Linear", color="#BF092F")
axes[0].plot(cosine_betas, label="Cosine", color="#3B9797")
axes[0].set_xlabel("Timestep")
axes[0].set_ylabel("Beta (noise level)")
axes[0].set_title("Noise Schedule: Beta over Time")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[1].plot(linear_alpha_bar, label="Linear", color="#BF092F")
axes[1].plot(cosine_alpha_bar, label="Cosine", color="#3B9797")
axes[1].set_xlabel("Timestep")
axes[1].set_ylabel("Alpha Bar (signal retention)")
axes[1].set_title("Signal Retention over Time")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"Linear: signal at t=500: {linear_alpha_bar[500]:.4f}")
print(f"Cosine: signal at t=500: {cosine_alpha_bar[500]:.4f}")
The cosine schedule retains more signal at intermediate timesteps. At t=500, cosine retains roughly half the signal while linear retains under 10%. This means the model spends more training steps learning meaningful structure removal rather than just distinguishing noise from slightly-less noise.
import tensorflow as tf
import numpy as np
class NoiseScheduler:
"""Manages noise scheduling for diffusion models."""
def __init__(self, timesteps=1000, schedule_type="cosine"):
self.timesteps = timesteps
if schedule_type == "linear":
self.betas = tf.constant(
np.linspace(1e-4, 0.02, timesteps, dtype=np.float32)
)
elif schedule_type == "cosine":
steps = timesteps + 1
t = np.linspace(0, timesteps, steps, dtype=np.float64)
alpha_bar = np.cos(((t / timesteps) + 0.008) / 1.008 * np.pi * 0.5) ** 2
alpha_bar = alpha_bar / alpha_bar[0]
betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
self.betas = tf.constant(np.clip(betas, 1e-4, 0.999), dtype=tf.float32)
else:
raise ValueError(f"Unknown schedule: {schedule_type}")
self.alphas = 1.0 - self.betas
self.alpha_bar = tf.math.cumprod(self.alphas)
self.sqrt_alpha_bar = tf.sqrt(self.alpha_bar)
self.sqrt_one_minus_alpha_bar = tf.sqrt(1.0 - self.alpha_bar)
def add_noise(self, x_0, noise, t):
"""Add noise to clean images at timestep t."""
sqrt_ab = tf.gather(self.sqrt_alpha_bar, t)
sqrt_one_minus_ab = tf.gather(self.sqrt_one_minus_alpha_bar, t)
# Reshape for broadcasting with image dimensions
sqrt_ab = tf.reshape(sqrt_ab, [-1, 1, 1, 1])
sqrt_one_minus_ab = tf.reshape(sqrt_one_minus_ab, [-1, 1, 1, 1])
return sqrt_ab * x_0 + sqrt_one_minus_ab * noise
# Test the scheduler
scheduler = NoiseScheduler(timesteps=1000, schedule_type="cosine")
print(f"Beta range: [{scheduler.betas[0].numpy():.6f}, {scheduler.betas[-1].numpy():.6f}]")
print(f"Alpha_bar at t=0: {scheduler.alpha_bar[0].numpy():.4f}")
print(f"Alpha_bar at t=500: {scheduler.alpha_bar[500].numpy():.4f}")
print(f"Alpha_bar at t=999: {scheduler.alpha_bar[999].numpy():.4f}")
# Demo: add noise to a batch of random images
batch = tf.random.normal([4, 64, 64, 4]) # Latent space batch
noise = tf.random.normal(tf.shape(batch))
timesteps = tf.constant([0, 250, 500, 999])
noisy = scheduler.add_noise(batch, noise, timesteps)
print(f"Noisy batch shape: {noisy.shape}")
The U-Net Denoiser
The U-Net is the core neural network in Stable Diffusion. It takes a noisy latent image and a timestep embedding, and predicts the noise to be removed. The architecture uses an encoder-decoder structure with skip connections, residual blocks, and attention mechanisms at multiple resolutions.
Architecture Components
- Time Embedding: Sinusoidal positional encoding of the timestep, projected through two linear layers
- Residual Blocks: GroupNorm → SiLU → Conv → GroupNorm → SiLU → Conv + skip connection
- Attention Blocks: Self-attention for spatial reasoning at 32×32, 16×16, and 8×8 resolutions
- Downsample/Upsample: Strided convolutions (down) and transposed convolutions (up)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
class TimeEmbedding(layers.Layer):
"""Sinusoidal time embedding for conditioning on diffusion timestep."""
def __init__(self, embed_dim, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.dense1 = layers.Dense(embed_dim * 4, activation="swish")
self.dense2 = layers.Dense(embed_dim * 4)
def sinusoidal_embedding(self, t):
half_dim = self.embed_dim // 2
freq = tf.exp(
-tf.math.log(10000.0) * tf.range(0, half_dim, dtype=tf.float32) / half_dim
)
t = tf.cast(t, tf.float32)
args = t[:, None] * freq[None, :]
return tf.concat([tf.sin(args), tf.cos(args)], axis=-1)
def call(self, t):
emb = self.sinusoidal_embedding(t)
emb = self.dense1(emb)
emb = self.dense2(emb)
return emb
class ResBlock(layers.Layer):
"""Residual block with time embedding injection."""
def __init__(self, out_channels, **kwargs):
super().__init__(**kwargs)
self.out_channels = out_channels
self.norm1 = layers.GroupNormalization(groups=32)
self.conv1 = layers.Conv2D(out_channels, 3, padding="same")
self.time_proj = layers.Dense(out_channels)
self.norm2 = layers.GroupNormalization(groups=32)
self.conv2 = layers.Conv2D(out_channels, 3, padding="same")
self.skip_conv = None # Lazy init for channel mismatch
def build(self, input_shape):
if input_shape[-1] != self.out_channels:
self.skip_conv = layers.Conv2D(self.out_channels, 1)
super().build(input_shape)
def call(self, x, time_emb):
h = self.norm1(x)
h = tf.nn.silu(h)
h = self.conv1(h)
# Inject time embedding
t = tf.nn.silu(time_emb)
t = self.time_proj(t)[:, None, None, :]
h = h + t
h = self.norm2(h)
h = tf.nn.silu(h)
h = self.conv2(h)
# Skip connection
if self.skip_conv is not None:
x = self.skip_conv(x)
return x + h
class AttentionBlock(layers.Layer):
"""Self-attention block for spatial reasoning."""
def __init__(self, channels, num_heads=4, **kwargs):
super().__init__(**kwargs)
self.channels = channels
self.num_heads = num_heads
self.norm = layers.GroupNormalization(groups=32)
self.qkv = layers.Dense(channels * 3)
self.proj = layers.Dense(channels)
def call(self, x):
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
residual = x
x = self.norm(x)
x = tf.reshape(x, [b, h * w, c])
qkv = self.qkv(x)
q, k, v = tf.split(qkv, 3, axis=-1)
# Multi-head attention
head_dim = c // self.num_heads
q = tf.reshape(q, [b, h * w, self.num_heads, head_dim])
k = tf.reshape(k, [b, h * w, self.num_heads, head_dim])
v = tf.reshape(v, [b, h * w, self.num_heads, head_dim])
q = tf.transpose(q, [0, 2, 1, 3])
k = tf.transpose(k, [0, 2, 1, 3])
v = tf.transpose(v, [0, 2, 1, 3])
scale = tf.math.rsqrt(tf.cast(head_dim, tf.float32))
attn = tf.matmul(q, k, transpose_b=True) * scale
attn = tf.nn.softmax(attn, axis=-1)
out = tf.matmul(attn, v)
out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [b, h * w, c])
        out = self.proj(out)
        out = tf.reshape(out, [b, h, w, c])
return residual + out
# Test components
time_emb = TimeEmbedding(embed_dim=128)
t = tf.constant([0, 100, 500, 999])
emb = time_emb(t)
print(f"Time embedding shape: {emb.shape}") # (4, 512)
res_block = ResBlock(out_channels=256)
x = tf.random.normal([2, 32, 32, 128])
out = res_block(x, emb[:2])
print(f"ResBlock output shape: {out.shape}") # (2, 32, 32, 256)
attn_block = AttentionBlock(channels=256, num_heads=4)
out = attn_block(out)
print(f"AttentionBlock output shape: {out.shape}") # (2, 32, 32, 256)
Complete U-Net Assembly
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def build_unet(latent_channels=4, base_channels=128, channel_mults=(1, 2, 4, 4),
num_res_blocks=2, attention_resolutions=(2, 4)):
"""Build a simplified U-Net for latent diffusion."""
# Inputs
latent_input = keras.Input(shape=(64, 64, latent_channels), name="latent")
time_input = keras.Input(shape=(), dtype=tf.int32, name="timestep")
# Time embedding
time_emb_layer = TimeEmbedding(embed_dim=base_channels)
t_emb = time_emb_layer(time_input)
# Initial convolution
h = layers.Conv2D(base_channels, 3, padding="same")(latent_input)
# Encoder path
skips = [h]
for level, mult in enumerate(channel_mults):
ch = base_channels * mult
for _ in range(num_res_blocks):
res = ResBlock(ch)
h = res(h, t_emb)
if mult in attention_resolutions:
attn = AttentionBlock(ch)
h = attn(h)
skips.append(h)
if level < len(channel_mults) - 1:
h = layers.Conv2D(ch, 3, strides=2, padding="same")(h)
skips.append(h)
# Middle block
mid_ch = base_channels * channel_mults[-1]
mid_res1 = ResBlock(mid_ch)
mid_attn = AttentionBlock(mid_ch)
mid_res2 = ResBlock(mid_ch)
h = mid_res1(h, t_emb)
h = mid_attn(h)
h = mid_res2(h, t_emb)
# Decoder path
for level, mult in reversed(list(enumerate(channel_mults))):
ch = base_channels * mult
for _ in range(num_res_blocks + 1):
skip = skips.pop()
h = layers.Concatenate()([h, skip])
res = ResBlock(ch)
h = res(h, t_emb)
if mult in attention_resolutions:
attn = AttentionBlock(ch)
h = attn(h)
if level > 0:
h = layers.Conv2DTranspose(ch, 3, strides=2, padding="same")(h)
# Output
h = layers.GroupNormalization(groups=32)(h)
h = tf.nn.silu(h)
noise_pred = layers.Conv2D(latent_channels, 3, padding="same")(h)
model = keras.Model(inputs=[latent_input, time_input], outputs=noise_pred)
return model
# Build and summarize
unet = build_unet()
print(f"U-Net parameters: {unet.count_params():,}")
print(f"Input shapes: latent={unet.input[0].shape}, time={unet.input[1].shape}")
print(f"Output shape: {unet.output.shape}")
Cross-Attention: Conditioning on Text
The key innovation that makes Stable Diffusion controllable by text prompts is cross-attention. Rather than only attending to other spatial positions (self-attention), cross-attention layers allow the U-Net features to attend to text embeddings from a CLIP encoder.
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
In cross-attention: $Q$ comes from the image features (spatial), while $K$ and $V$ come from the text embeddings (77 tokens × 768 dimensions from CLIP).
import tensorflow as tf
from tensorflow.keras import layers
class CrossAttention(layers.Layer):
"""Cross-attention layer for text-conditioned image generation.
Q from image features, K/V from text embeddings (CLIP output).
"""
def __init__(self, channels, context_dim=768, num_heads=8, **kwargs):
super().__init__(**kwargs)
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.norm = layers.GroupNormalization(groups=32)
self.to_q = layers.Dense(channels, use_bias=False)
self.to_k = layers.Dense(channels, use_bias=False)
self.to_v = layers.Dense(channels, use_bias=False)
self.to_out = layers.Dense(channels)
def call(self, x, context):
"""
Args:
x: Image features [batch, height, width, channels]
context: Text embeddings [batch, seq_len, context_dim]
"""
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
residual = x
x = self.norm(x)
x_flat = tf.reshape(x, [b, h * w, c])
# Q from image features, K/V from text context
q = self.to_q(x_flat)
k = self.to_k(context)
v = self.to_v(context)
# Reshape for multi-head attention
q = tf.reshape(q, [b, h * w, self.num_heads, self.head_dim])
k = tf.reshape(k, [b, -1, self.num_heads, self.head_dim])
v = tf.reshape(v, [b, -1, self.num_heads, self.head_dim])
q = tf.transpose(q, [0, 2, 1, 3]) # [b, heads, hw, dim]
k = tf.transpose(k, [0, 2, 1, 3]) # [b, heads, seq, dim]
v = tf.transpose(v, [0, 2, 1, 3])
# Scaled dot-product attention
scale = tf.math.rsqrt(tf.cast(self.head_dim, tf.float32))
attn_weights = tf.matmul(q, k, transpose_b=True) * scale
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_output = tf.matmul(attn_weights, v)
# Reshape back
attn_output = tf.transpose(attn_output, [0, 2, 1, 3])
attn_output = tf.reshape(attn_output, [b, h * w, c])
attn_output = self.to_out(attn_output)
attn_output = tf.reshape(attn_output, [b, h, w, c])
return residual + attn_output
# Test cross-attention
cross_attn = CrossAttention(channels=256, context_dim=768, num_heads=8)
# Simulated inputs
image_features = tf.random.normal([2, 32, 32, 256]) # From U-Net encoder
text_embeddings = tf.random.normal([2, 77, 768]) # From CLIP (77 tokens)
output = cross_attn(image_features, text_embeddings)
print(f"Cross-attention input: {image_features.shape}")
print(f"Text context: {text_embeddings.shape}")
print(f"Cross-attention output: {output.shape}")
# Verify shape preserved: (2, 32, 32, 256)
The VAE: Latent Space Encoding
The Variational Autoencoder (VAE) in Stable Diffusion provides the compression between pixel space and latent space. The encoder maps a 512×512×3 image to a 64×64×4 latent representation (48× compression), and the decoder reverses this mapping.
Why Latent Space?
- Computation: 512×512×3 = 786,432 values vs 64×64×4 = 16,384 values
- Memory: U-Net attention on 64×64 is tractable; on 512×512 it would require 100+ GB
- Quality: The latent space captures semantic features, not pixel noise
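The memory claim above follows from back-of-envelope arithmetic: self-attention materializes a tokens × tokens weight matrix, so moving from pixel space to latent space shrinks that matrix by a factor of $(512^2 / 64^2)^2 = 4096$. A rough float32 estimate for a single attention head at batch size 1:

```python
# Self-attention over N tokens stores an N x N attention-weight matrix.
tokens_latent = 64 * 64    # 4,096 tokens at latent resolution
tokens_pixel = 512 * 512   # 262,144 tokens at pixel resolution

bytes_latent = tokens_latent ** 2 * 4  # float32 = 4 bytes per entry
bytes_pixel = tokens_pixel ** 2 * 4

print(f"Latent attention matrix: {bytes_latent / 1e6:.0f} MB")  # ~67 MB
print(f"Pixel attention matrix:  {bytes_pixel / 1e9:.0f} GB")   # ~275 GB
```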
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def build_vae_encoder(input_shape=(512, 512, 3), latent_channels=4):
"""VAE Encoder: Image -> Latent (with KL regularization)."""
inputs = keras.Input(shape=input_shape)
# Downsampling path: 512 -> 256 -> 128 -> 64
x = layers.Conv2D(64, 3, padding="same")(inputs)
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x) # 256x256
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2D(256, 3, strides=2, padding="same")(x) # 128x128
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2D(512, 3, strides=2, padding="same")(x) # 64x64
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
# Produce mean and log-variance
x = layers.Conv2D(latent_channels * 2, 3, padding="same")(x)
mean, logvar = tf.split(x, 2, axis=-1)
# Reparameterization trick
std = tf.exp(0.5 * logvar)
eps = tf.random.normal(tf.shape(std))
z = mean + std * eps
encoder = keras.Model(inputs, [z, mean, logvar], name="vae_encoder")
return encoder
def build_vae_decoder(latent_shape=(64, 64, 4), output_channels=3):
"""VAE Decoder: Latent -> Image."""
inputs = keras.Input(shape=latent_shape)
# Upsampling path: 64 -> 128 -> 256 -> 512
x = layers.Conv2D(512, 3, padding="same")(inputs)
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2DTranspose(256, 3, strides=2, padding="same")(x) # 128x128
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2DTranspose(128, 3, strides=2, padding="same")(x) # 256x256
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
x = layers.Conv2DTranspose(64, 3, strides=2, padding="same")(x) # 512x512
x = layers.GroupNormalization(groups=32)(x)
x = tf.nn.silu(x)
outputs = layers.Conv2D(output_channels, 3, padding="same", activation="tanh")(x)
decoder = keras.Model(inputs, outputs, name="vae_decoder")
return decoder
# Build and test
encoder = build_vae_encoder()
decoder = build_vae_decoder()
# Test forward pass
test_image = tf.random.normal([1, 512, 512, 3])
z, mean, logvar = encoder(test_image)
reconstructed = decoder(z)
print(f"Input image: {test_image.shape}")
print(f"Latent z: {z.shape}")
print(f"Mean: {mean.shape}, LogVar: {logvar.shape}")
print(f"Reconstructed: {reconstructed.shape}")
print(f"Compression ratio: {512*512*3 / (64*64*4):.1f}x")
# KL divergence loss
kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mean) - tf.exp(logvar))
print(f"KL divergence: {kl_loss.numpy():.4f}")
The Complete Stable Diffusion Pipeline
The full Stable Diffusion pipeline connects all components: a text prompt flows through CLIP to produce embeddings, random noise is iteratively denoised by the U-Net (conditioned on text via cross-attention), and the final latent is decoded by the VAE into a full-resolution image.
flowchart TD
A["Text Prompt"] --> B["CLIP Text Encoder"]
B --> C["Text Embeddings (77x768)"]
D["Random Noise z_T"] --> E["U-Net Denoiser"]
C --> E
F["Timestep t"] --> E
E --> G["Predicted Noise"]
G --> H["Remove Noise (Scheduler)"]
H --> I{"t > 0?"}
I -->|"Yes"| E
I -->|"No"| J["Clean Latent z_0"]
J --> K["VAE Decoder"]
K --> L["Generated Image (512x512)"]
style A fill:#3B9797,color:#fff
style L fill:#3B9797,color:#fff
style E fill:#132440,color:#fff
style K fill:#16476A,color:#fff
import tensorflow as tf
import numpy as np
class StableDiffusionPipeline:
"""Complete Stable Diffusion inference pipeline.
Components: CLIP text encoder, U-Net denoiser, VAE decoder, noise scheduler.
"""
def __init__(self, unet, vae_decoder, text_encoder, scheduler, guidance_scale=7.5):
self.unet = unet
self.vae_decoder = vae_decoder
self.text_encoder = text_encoder
self.scheduler = scheduler
self.guidance_scale = guidance_scale
def encode_text(self, prompt, negative_prompt=""):
"""Encode text prompt to CLIP embeddings."""
# In production: tokenize + encode with CLIP
# Here: simulate with random embeddings for demonstration
cond_embeddings = self.text_encoder(prompt)
uncond_embeddings = self.text_encoder(negative_prompt)
return cond_embeddings, uncond_embeddings
    def denoise_step(self, latents, t, t_prev, text_embeddings, uncond_embeddings):
        """Single denoising step with classifier-free guidance.
        t and t_prev are Python ints: the current timestep and the next,
        less-noisy timestep in the (subsampled) sampling schedule.
        """
        t_tensor = tf.constant([t], dtype=tf.int32)
        # Predict noise for conditional and unconditional inputs
        noise_cond = self.unet([latents, t_tensor, text_embeddings])
        noise_uncond = self.unet([latents, t_tensor, uncond_embeddings])
        # Classifier-free guidance
        noise_pred = noise_uncond + self.guidance_scale * (noise_cond - noise_uncond)
        # Signal levels now and at the step we are moving to. Stepping to the
        # actual previous timestep of the sampling schedule (not t - 1) is what
        # lets 50 sampling steps traverse a 1000-step training schedule.
        alpha_t = self.scheduler.alpha_bar[t]
        alpha_prev = self.scheduler.alpha_bar[t_prev] if t_prev >= 0 else tf.constant(1.0)
        # DDIM-like update (deterministic)
        pred_x0 = (latents - tf.sqrt(1 - alpha_t) * noise_pred) / tf.sqrt(alpha_t)
        direction = tf.sqrt(1 - alpha_prev) * noise_pred
        return tf.sqrt(alpha_prev) * pred_x0 + direction
    def generate(self, prompt, negative_prompt="", num_steps=50, seed=42):
        """Generate an image from a text prompt."""
        tf.random.set_seed(seed)
        # 1. Encode text
        text_emb, uncond_emb = self.encode_text(prompt, negative_prompt)
        # 2. Start with random noise in latent space
        latents = tf.random.normal([1, 64, 64, 4])
        # 3. Iterative denoising over a subsampled timestep sequence
        timesteps = np.linspace(999, 0, num_steps, dtype=int)
        for i, t in enumerate(timesteps):
            t_prev = int(timesteps[i + 1]) if i + 1 < num_steps else -1
            latents = self.denoise_step(latents, int(t), t_prev, text_emb, uncond_emb)
# 4. Decode latent to pixel space
image = self.vae_decoder(latents)
# 5. Post-process: [-1, 1] -> [0, 255]
image = (image + 1.0) * 127.5
image = tf.clip_by_value(image, 0, 255)
image = tf.cast(image, tf.uint8)
return image[0] # Remove batch dimension
# Usage demonstration (with placeholder components)
print("Pipeline components:")
print(" 1. CLIP Text Encoder: prompt -> 77x768 embeddings")
print(" 2. Noise Scheduler: manages beta/alpha schedules")
print(" 3. U-Net: iterative denoising (conditioned on text)")
print(" 4. VAE Decoder: 64x64x4 latent -> 512x512x3 image")
print(f" 5. Guidance scale: 7.5 (default)")
print(f" 6. Denoising steps: 50 (default)")
print(f" 7. Latent dimensions: 64x64x4 = {64*64*4:,} values")
print(f" 8. Output dimensions: 512x512x3 = {512*512*3:,} pixels")
Classifier-Free Guidance
Classifier-Free Guidance (CFG) is the technique that makes text-to-image generation controllable. During training, the text condition is randomly dropped (replaced with an empty embedding) 10% of the time, teaching the model both conditional and unconditional generation.
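During training, this conditioning dropout is a one-line masking operation. A minimal NumPy sketch (the function name and shapes are illustrative; `null_emb` stands in for the CLIP embedding of the empty prompt):

```python
import numpy as np

def maybe_drop_text(text_emb, null_emb, drop_prob=0.1, rng=None):
    """With probability drop_prob per example, replace the text embedding
    with the null (empty-prompt) embedding. Sketch of CFG training dropout."""
    rng = rng if rng is not None else np.random.default_rng()
    keep = (rng.random((text_emb.shape[0], 1, 1)) >= drop_prob).astype(text_emb.dtype)
    return keep * text_emb + (1.0 - keep) * null_emb

# Demo with CLIP-like shapes: batch of 8 prompts, 77 tokens x 768 dims
rng = np.random.default_rng(0)
text_emb = rng.standard_normal((8, 77, 768)).astype(np.float32)
null_emb = np.zeros((77, 768), dtype=np.float32)  # stand-in for the empty-prompt embedding
mixed = maybe_drop_text(text_emb, null_emb, drop_prob=0.1, rng=rng)
print(mixed.shape)  # (8, 77, 768)
```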
The Guidance Formula
At inference time, we compute both conditional and unconditional noise predictions, then extrapolate away from the unconditional prediction toward the conditional one:
$$\hat{\epsilon} = \epsilon_{\text{uncond}} + s \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$$
Where $s$ is the guidance scale:
- $s = 1$: No guidance (just conditional prediction)
- $s = 7.5$: Default — good balance of quality and diversity
- $s = 15+$: Strong guidance — very prompt-adherent but less diverse, can be oversaturated
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def classifier_free_guidance(noise_cond, noise_uncond, guidance_scale):
"""Apply classifier-free guidance to noise predictions.
Args:
noise_cond: Noise predicted with text conditioning
noise_uncond: Noise predicted without text (unconditional)
guidance_scale: How strongly to follow the text prompt
Returns:
Guided noise prediction
"""
return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
# Demonstrate effect of guidance scale
def visualize_guidance_effect():
"""Show how guidance scale affects the generation direction."""
# Simulate noise predictions (1D for visualization)
np.random.seed(42)
noise_uncond = np.random.randn(100)
noise_cond = noise_uncond + 0.3 * np.ones(100) # Conditional shifts toward prompt
scales = [1.0, 3.0, 7.5, 15.0, 25.0]
fig, axes = plt.subplots(1, len(scales), figsize=(18, 3))
for i, s in enumerate(scales):
guided = noise_uncond + s * (noise_cond - noise_uncond)
axes[i].hist(noise_uncond, bins=20, alpha=0.4, label="Unconditional", color="#BF092F")
axes[i].hist(guided, bins=20, alpha=0.6, label=f"Guided (s={s})", color="#3B9797")
axes[i].set_title(f"Scale = {s}")
axes[i].legend(fontsize=7)
axes[i].set_xlim(-4, 6)
plt.suptitle("Classifier-Free Guidance: Effect of Scale on Noise Distribution")
plt.tight_layout()
plt.show()
visualize_guidance_effect()
# Practical implementation
class GuidedDiffusionSampler:
"""Diffusion sampler with classifier-free guidance."""
def __init__(self, model, scheduler, guidance_scale=7.5):
self.model = model
self.scheduler = scheduler
self.guidance_scale = guidance_scale
def predict_noise(self, latents, timestep, text_embeddings, uncond_embeddings):
"""Predict noise with CFG."""
# Batch both predictions together for efficiency
latent_input = tf.concat([latents, latents], axis=0)
text_input = tf.concat([uncond_embeddings, text_embeddings], axis=0)
t_input = tf.concat([timestep, timestep], axis=0)
# Single forward pass for both
noise_pred = self.model([latent_input, t_input, text_input])
noise_uncond, noise_cond = tf.split(noise_pred, 2, axis=0)
# Apply guidance
guided_noise = classifier_free_guidance(
noise_cond, noise_uncond, self.guidance_scale
)
return guided_noise
print("Guidance Scale Effects:")
print(" s=1.0 : Pure conditional (no guidance boost)")
print(" s=3.0 : Mild guidance (more diverse outputs)")
print(" s=7.5 : Default (balanced quality/diversity)")
print(" s=15.0 : Strong guidance (high prompt adherence)")
print(" s=25.0 : Extreme (oversaturated, artifacts likely)")
Choosing Guidance Scale: For photorealistic images, use s=7-9. For artistic/stylized images, use s=10-15. For maximum prompt adherence (e.g., specific compositions), try s=12-20. If you see color oversaturation or strange artifacts, reduce the scale. If outputs ignore parts of your prompt, increase it.
Using KerasCV Stable Diffusion
For production use, KerasCV provides a fully pre-trained Stable Diffusion implementation that generates high-quality images in just a few lines of code. This is built on the same architecture we implemented above but with optimized weights and inference paths.
Text-to-Image Generation
import tensorflow as tf
import keras_cv
import matplotlib.pyplot as plt
# Initialize Stable Diffusion model (downloads weights on first run)
model = keras_cv.models.StableDiffusion(
img_width=512,
img_height=512,
jit_compile=True # XLA compilation for speed
)
# Generate images from text prompt
images = model.text_to_image(
prompt="a serene mountain lake at sunset, photorealistic, 4k",
batch_size=1,
num_steps=50,
unconditional_guidance_scale=7.5,
seed=42
)
# Display result
plt.figure(figsize=(8, 8))
plt.imshow(images[0])
plt.axis("off")
plt.title("Generated: Mountain Lake at Sunset")
plt.tight_layout()
plt.show()
print(f"Image shape: {images[0].shape}")
print(f"Pixel range: [{images[0].min()}, {images[0].max()}]")
Image-to-Image (Modifying Existing Images)
Note: an `image_to_image` method is not exposed by every KerasCV release. If your version of `keras_cv.models.StableDiffusion` lacks it, img2img can be implemented manually: encode the source image to a latent, add noise corresponding to the chosen strength, and run the denoising loop from that timestep.
import tensorflow as tf
import keras_cv
import numpy as np
import matplotlib.pyplot as plt
# Initialize model
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
# Load and preprocess an existing image
# (Using a synthetic gradient image for demonstration)
source_image = np.zeros((512, 512, 3), dtype=np.float32)
source_image[:256, :, 0] = np.linspace(0, 1, 256)[:, None] # Red gradient top
source_image[256:, :, 2] = np.linspace(0, 1, 256)[:, None] # Blue gradient bottom
# img2img: modify existing image with a text prompt
# strength controls how much to change (0=no change, 1=full generation)
images = model.image_to_image(
prompt="a beautiful landscape painting in watercolor style",
image=source_image,
strength=0.7, # 70% modification
num_steps=50,
unconditional_guidance_scale=8.0,
seed=123
)
# Compare
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(np.clip(source_image, 0, 1))
axes[0].set_title("Source Image")
axes[0].axis("off")
axes[1].imshow(images[0])
axes[1].set_title("img2img Result (strength=0.7)")
axes[1].axis("off")
plt.tight_layout()
plt.show()
Performance Optimization
import tensorflow as tf
import keras_cv
import time
# Enable mixed precision for faster inference
tf.keras.mixed_precision.set_global_policy("mixed_float16")
# Build with XLA compilation
model = keras_cv.models.StableDiffusion(
img_width=512,
img_height=512,
jit_compile=True # Enables XLA - major speedup
)
# Warm-up run (XLA compilation happens here)
print("Warm-up run (compiling XLA graph)...")
_ = model.text_to_image("warmup", batch_size=1, num_steps=25)
# Benchmark
prompt = "a cyberpunk cityscape at night with neon lights, detailed"
start = time.time()
images = model.text_to_image(prompt, batch_size=1, num_steps=50, seed=777)
elapsed = time.time() - start
print(f"Generation time: {elapsed:.2f} seconds")
print(f"Steps per second: {50 / elapsed:.1f}")
print(f"Mixed precision: {tf.keras.mixed_precision.global_policy().name}")
print(f"XLA compiled: True")
# Memory usage
gpu_devices = tf.config.list_physical_devices("GPU")
if gpu_devices:
mem_info = tf.config.experimental.get_memory_info("GPU:0")
print(f"GPU memory used: {mem_info['current'] / 1e9:.2f} GB")
When to Use Each Approach
- From-scratch implementation: Research, custom architectures, fine-tuning specific components, educational purposes
- KerasCV pre-built: Production applications, rapid prototyping, standard text-to-image/img2img workflows, when you need optimized inference