
Deep Dive: Stable Diffusion — Image Generation in TensorFlow

May 3, 2026 Wasil Zafar 40 min read

Implement the complete Stable Diffusion pipeline in TensorFlow — from the denoising U-Net and VAE to the CLIP text encoder, noise scheduling, and classifier-free guidance for text-to-image generation.

Table of Contents

  1. Diffusion Fundamentals
  2. U-Net Architecture
  3. VAE & Latent Space
  4. Guidance & Production

How Diffusion Models Work

Diffusion models generate images by learning to reverse a gradual noising process. The key insight is that destroying structure (adding noise) is trivial, but learning to recover structure (removing noise) produces a powerful generative model.

Core Idea: Stable Diffusion operates in a compressed latent space rather than directly on pixels. A 64×64×4 latent holds 48× fewer values than a 512×512×3 image, which makes every denoising step much cheaper than in pixel-space diffusion models.

Forward Process (Adding Noise)

The forward process gradually adds Gaussian noise to a clean image over T timesteps until it becomes pure noise. At each step:

$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} \, x_{t-1}, \, \beta_t I)$$

Where $\beta_t$ is the noise schedule controlling how much noise is added at each step. Because a composition of Gaussian steps is itself Gaussian, $q(x_t \mid x_0)$ has a closed form, and with the reparameterization trick we can jump directly to any timestep:

$$x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

Where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s = \prod_{s=1}^{t} (1 - \beta_s)$ is the cumulative product of these per-step signal retention factors.
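As a sanity check of the closed form, the NumPy sketch below (NumPy is used only for brevity) applies the per-step update $x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon_t$ for all 1,000 steps of the linear schedule, tracking only the coefficient on $x_0$ and the accumulated noise variance, and compares them with $\sqrt{\bar{\alpha}_t}$ and $1-\bar{\alpha}_t$.

```python
import numpy as np

# Linear schedule from the DDPM paper (beta_1 = 1e-4 ... beta_T = 0.02)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

# Apply x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps_t step by step,
# representing x_t as coef * x_0 + (zero-mean Gaussian noise with variance noise_var)
coef, noise_var = 1.0, 0.0
coefs, noise_vars = [], []
for beta in betas:
    coef *= np.sqrt(1.0 - beta)
    # Old noise is scaled by sqrt(1 - beta); independent Gaussian variances add
    noise_var = (1.0 - beta) * noise_var + beta
    coefs.append(coef)
    noise_vars.append(noise_var)

coefs, noise_vars = np.array(coefs), np.array(noise_vars)
print(np.allclose(coefs, np.sqrt(alpha_bar)))    # True
print(np.allclose(noise_vars, 1.0 - alpha_bar))  # True
```

The variance recursion $v_t = (1-\beta_t)\,v_{t-1} + \beta_t$ with $v_0 = 0$ collapses to $1 - \bar{\alpha}_t$, which is why training can sample any timestep in a single shot.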

Reverse Process (Denoising)

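The reverse process learns to remove noise step by step. A neural network (the U-Net) is trained to predict the noise $\epsilon$ that was mixed into $x_t$, written $\epsilon_\theta(x_t, t)$. At sampling time, a scheduler uses this prediction to compute a slightly less noisy $x_{t-1}$, and the process repeats until it reaches a clean sample.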
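To make the reverse step concrete, here is a NumPy sketch of the standard DDPM ancestral update with the variance choice $\sigma_t^2 = \beta_t$ (one of the options in the DDPM paper). It uses an oracle that returns the true noise: with that oracle, the DDPM mean coincides exactly with the mean of the true posterior $q(x_{t-1} \mid x_t, x_0)$, and a trained U-Net's noise prediction is an approximation of that oracle.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def ddpm_mean(x_t, eps_pred, t):
    """Mean of p(x_{t-1} | x_t) given a noise prediction (0-indexed t)."""
    return (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_pred) / np.sqrt(alphas[t])

def ddpm_step(x_t, eps_pred, t, rng):
    """One ancestral DDPM step x_t -> x_{t-1}, using sigma_t^2 = beta_t."""
    mean = ddpm_mean(x_t, eps_pred, t)
    if t == 0:
        return mean  # the final step adds no noise
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

# Noise a toy "latent" to timestep t with the closed-form forward process
x0 = rng.standard_normal((4, 4))
eps = rng.standard_normal((4, 4))
t = 500
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

# True posterior mean of q(x_{t-1} | x_t, x_0)
posterior_mean = (np.sqrt(alpha_bar[t - 1]) * betas[t] / (1.0 - alpha_bar[t]) * x0
                  + np.sqrt(alphas[t]) * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t]) * x_t)

# With the oracle noise, the DDPM mean matches the posterior mean
print(np.allclose(ddpm_mean(x_t, eps, t), posterior_mean))  # True

x_prev = ddpm_step(x_t, eps, t, rng)
print(x_prev.shape)  # (4, 4)
```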

Diffusion Forward and Reverse Processes
flowchart LR
    A["Clean Image x0"] -->|"Add noise beta_1"| B["x1"]
    B -->|"Add noise beta_2"| C["x2"]
    C -->|"... T steps ..."| D["xT ~ N(0,I)"]
    D -->|"U-Net predicts noise"| E["x(T-1)"]
    E -->|"Denoise step"| F["x(T-2)"]
    F -->|"... T steps ..."| G["Generated Image"]

    style A fill:#3B9797,color:#fff
    style D fill:#BF092F,color:#fff
    style G fill:#3B9797,color:#fff
                            

Latent Diffusion

Instead of operating on full-resolution pixels (e.g., 512×512×3 = 786,432 dimensions), Stable Diffusion first compresses the image to a latent space (64×64×4 = 16,384 dimensions) using a VAE. The diffusion process operates entirely in this compressed space, and the VAE decoder maps the final latent back to pixel space.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Demonstrate the forward diffusion process
def forward_diffusion_demo(num_steps=5):
    """Show progressive noising of a synthetic image."""
    # Create a simple synthetic image (solid color blocks) in [0, 1]
    x = np.zeros((64, 64, 3), dtype=np.float32)
    x[:32, :, 0] = 1.0  # Red top half
    x[32:, :, 2] = 1.0  # Blue bottom half
    x[:, 32:, 1] = 0.5  # Green right half
    x = 2.0 * x - 1.0   # Diffusion models assume data scaled to [-1, 1]

    # Linear noise schedule
    betas = np.linspace(0.0001, 0.02, 1000)
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)

    # Show progressive noising
    timesteps = np.linspace(0, 999, num_steps, dtype=int)
    fig, axes = plt.subplots(1, num_steps, figsize=(15, 3))

    for i, t in enumerate(timesteps):
        noise = np.random.randn(*x.shape).astype(np.float32)
        # Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = np.sqrt(alpha_bar[t]) * x + np.sqrt(1 - alpha_bar[t]) * noise
        x_t_display = np.clip((x_t + 1.0) / 2.0, 0, 1)  # Back to [0, 1] for imshow
        axes[i].imshow(x_t_display)
        axes[i].set_title(f"t={t}")
        axes[i].axis("off")

    plt.suptitle("Forward Diffusion: Clean Image to Noise")
    plt.tight_layout()
    plt.show()

forward_diffusion_demo()

Noise Scheduling

The noise schedule $\{\beta_t\}_{t=1}^T$ controls how quickly images are destroyed during the forward process. The choice of schedule significantly impacts generation quality.

Linear Schedule

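The original DDPM paper used a linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$. This destroys information faster than necessary: well before $t = T$ the signal is almost entirely gone, so the last stretch of timesteps is spent on nearly pure noise and contributes little to sample quality (the observation that motivated the cosine schedule in Improved DDPM).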

Cosine Schedule

The cosine schedule provides a more gradual noise increase, preserving structure longer and giving the model more useful signal to learn from at intermediate timesteps.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule as in original DDPM."""
    return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (improved DDPM)."""
    steps = timesteps + 1
    t = np.linspace(0, timesteps, steps, dtype=np.float64)
    alpha_bar = np.cos(((t / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    return np.clip(betas, 0.0001, 0.999).astype(np.float32)

# Compare schedules
T = 1000
linear_betas = linear_beta_schedule(T)
cosine_betas = cosine_beta_schedule(T)

# Compute alpha_bar for both
linear_alpha_bar = np.cumprod(1.0 - linear_betas)
cosine_alpha_bar = np.cumprod(1.0 - cosine_betas)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(linear_betas, label="Linear", color="#BF092F")
axes[0].plot(cosine_betas, label="Cosine", color="#3B9797")
axes[0].set_xlabel("Timestep")
axes[0].set_ylabel("Beta (noise level)")
axes[0].set_title("Noise Schedule: Beta over Time")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(linear_alpha_bar, label="Linear", color="#BF092F")
axes[1].plot(cosine_alpha_bar, label="Cosine", color="#3B9797")
axes[1].set_xlabel("Timestep")
axes[1].set_ylabel("Alpha Bar (signal retention)")
axes[1].set_title("Signal Retention over Time")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Linear: signal at t=500: {linear_alpha_bar[500]:.4f}")
print(f"Cosine: signal at t=500: {cosine_alpha_bar[500]:.4f}")
Observation

The cosine schedule retains more signal at intermediate timesteps. At t=500, cosine keeps roughly 49% of the signal variance ($\bar{\alpha}_t \approx 0.49$) while linear keeps only about 8%. As a result, more training steps are spent on inputs where meaningful structure remains to be recovered, rather than on distinguishing pure noise from slightly-less-pure noise.
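The "wasted steps" argument can be quantified with a short NumPy sketch. The 1% threshold is an arbitrary choice for illustration: the code counts how many of the 1,000 timesteps leave less than 1% of the signal variance ($\bar{\alpha}_t < 0.01$) under each schedule.

```python
import numpy as np

T = 1000
linear_betas = np.linspace(1e-4, 0.02, T)

# Cosine schedule, same formula as cosine_beta_schedule above
s = 0.008
steps = np.linspace(0, T, T + 1)
f = np.cos(((steps / T) + s) / (1 + s) * np.pi * 0.5) ** 2
cosine_betas = np.clip(1 - f[1:] / f[:-1], 1e-4, 0.999)

results = {}
for name, betas in [("linear", linear_betas), ("cosine", cosine_betas)]:
    alpha_bar = np.cumprod(1.0 - betas)
    near_noise = int(np.sum(alpha_bar < 0.01))  # steps keeping < 1% signal variance
    results[name] = (float(alpha_bar[500]), near_noise)
    print(f"{name}: alpha_bar[500] = {alpha_bar[500]:.3f}, "
          f"steps with alpha_bar < 0.01: {near_noise}")
```

The linear schedule spends several times as many steps in the near-pure-noise regime as the cosine schedule does.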

import tensorflow as tf
import numpy as np

class NoiseScheduler:
    """Manages noise scheduling for diffusion models."""

    def __init__(self, timesteps=1000, schedule_type="cosine"):
        self.timesteps = timesteps

        if schedule_type == "linear":
            self.betas = tf.constant(
                np.linspace(1e-4, 0.02, timesteps, dtype=np.float32)
            )
        elif schedule_type == "cosine":
            steps = timesteps + 1
            t = np.linspace(0, timesteps, steps, dtype=np.float64)
            alpha_bar = np.cos(((t / timesteps) + 0.008) / 1.008 * np.pi * 0.5) ** 2
            alpha_bar = alpha_bar / alpha_bar[0]
            betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
            self.betas = tf.constant(np.clip(betas, 1e-4, 0.999), dtype=tf.float32)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        self.alphas = 1.0 - self.betas
        self.alpha_bar = tf.math.cumprod(self.alphas)
        self.sqrt_alpha_bar = tf.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = tf.sqrt(1.0 - self.alpha_bar)

    def add_noise(self, x_0, noise, t):
        """Add noise to clean images at timestep t."""
        sqrt_ab = tf.gather(self.sqrt_alpha_bar, t)
        sqrt_one_minus_ab = tf.gather(self.sqrt_one_minus_alpha_bar, t)
        # Reshape for broadcasting with image dimensions
        sqrt_ab = tf.reshape(sqrt_ab, [-1, 1, 1, 1])
        sqrt_one_minus_ab = tf.reshape(sqrt_one_minus_ab, [-1, 1, 1, 1])
        return sqrt_ab * x_0 + sqrt_one_minus_ab * noise

# Test the scheduler
scheduler = NoiseScheduler(timesteps=1000, schedule_type="cosine")
print(f"Beta range: [{scheduler.betas[0].numpy():.6f}, {scheduler.betas[-1].numpy():.6f}]")
print(f"Alpha_bar at t=0: {scheduler.alpha_bar[0].numpy():.4f}")
print(f"Alpha_bar at t=500: {scheduler.alpha_bar[500].numpy():.4f}")
print(f"Alpha_bar at t=999: {scheduler.alpha_bar[999].numpy():.4f}")

# Demo: add noise to a batch of random images
batch = tf.random.normal([4, 64, 64, 4])  # Latent space batch
noise = tf.random.normal(tf.shape(batch))
timesteps = tf.constant([0, 250, 500, 999])
noisy = scheduler.add_noise(batch, noise, timesteps)
print(f"Noisy batch shape: {noisy.shape}")
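With add_noise in place, the training objective of DDPM-style models (including the Stable Diffusion v1 U-Net) reduces to a regression onto the sampled noise. The NumPy sketch below mirrors that loop outside TensorFlow. model is any callable taking (x_t, t), and zero_model is a hypothetical baseline whose loss should be close to 1, the variance of the target noise.

```python
import numpy as np

T = 1000
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))  # linear schedule

def diffusion_loss(model, x0, rng):
    """Simplified DDPM objective: mean squared error between true and predicted noise."""
    t = rng.integers(0, T, size=x0.shape[0])             # one random timestep per example
    noise = rng.standard_normal(x0.shape)
    ab = alpha_bar[t].reshape(-1, 1, 1, 1)               # broadcast over H, W, C
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise   # same formula as add_noise
    return np.mean((noise - model(x_t, t)) ** 2)

def zero_model(x_t, t):
    """Hypothetical baseline that always predicts zero noise."""
    return np.zeros_like(x_t)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 64, 64, 4))                 # toy batch of "latents"
loss = diffusion_loss(zero_model, x0, rng)
print(f"all-zeros predictor loss: {loss:.3f}")
```

A real U-Net drives this loss below the baseline by learning to recognize the noise component of x_t at every timestep.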

The U-Net Denoiser

The U-Net is the core neural network in Stable Diffusion. It takes a noisy latent image and a timestep embedding, and predicts the noise to be removed. The architecture uses an encoder-decoder structure with skip connections, residual blocks, and attention mechanisms at multiple resolutions.

Architecture Components

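  • Time Embedding: Sinusoidal positional encoding of the timestep, projected through two linear layers
  • Residual Blocks: GroupNorm → SiLU → Conv → (+ time embedding) → GroupNorm → SiLU → Conv, plus a skip connection
  • Attention Blocks: Self-attention for spatial reasoning; in the simplified U-Net below they sit at the 32×32, 16×16, and 8×8 feature maps (Stable Diffusion v1 places its attention blocks at the 64×64, 32×32, and 16×16 latent resolutions and in the 8×8 middle block)
  • Downsample/Upsample: Strided convolutions (down) and transposed convolutions (up) in this implementation; the original Stable Diffusion U-Net upsamples with nearest-neighbor interpolation followed by a convolution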
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

class TimeEmbedding(layers.Layer):
    """Sinusoidal time embedding for conditioning on diffusion timestep."""

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense1 = layers.Dense(embed_dim * 4, activation="swish")
        self.dense2 = layers.Dense(embed_dim * 4)

    def sinusoidal_embedding(self, t):
        half_dim = self.embed_dim // 2
        freq = tf.exp(
            -tf.math.log(10000.0) * tf.range(0, half_dim, dtype=tf.float32) / half_dim
        )
        t = tf.cast(t, tf.float32)
        args = t[:, None] * freq[None, :]
        return tf.concat([tf.sin(args), tf.cos(args)], axis=-1)

    def call(self, t):
        emb = self.sinusoidal_embedding(t)
        emb = self.dense1(emb)
        emb = self.dense2(emb)
        return emb


class ResBlock(layers.Layer):
    """Residual block with time embedding injection."""

    def __init__(self, out_channels, **kwargs):
        super().__init__(**kwargs)
        self.out_channels = out_channels
        self.norm1 = layers.GroupNormalization(groups=32)
        self.conv1 = layers.Conv2D(out_channels, 3, padding="same")
        self.time_proj = layers.Dense(out_channels)
        self.norm2 = layers.GroupNormalization(groups=32)
        self.conv2 = layers.Conv2D(out_channels, 3, padding="same")
        self.skip_conv = None  # Lazy init for channel mismatch

    def build(self, input_shape):
        if input_shape[-1] != self.out_channels:
            self.skip_conv = layers.Conv2D(self.out_channels, 1)
        super().build(input_shape)

    def call(self, x, time_emb):
        h = self.norm1(x)
        h = tf.nn.silu(h)
        h = self.conv1(h)
        # Inject time embedding
        t = tf.nn.silu(time_emb)
        t = self.time_proj(t)[:, None, None, :]
        h = h + t
        h = self.norm2(h)
        h = tf.nn.silu(h)
        h = self.conv2(h)
        # Skip connection
        if self.skip_conv is not None:
            x = self.skip_conv(x)
        return x + h


class AttentionBlock(layers.Layer):
    """Self-attention block for spatial reasoning."""

    def __init__(self, channels, num_heads=4, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_heads = num_heads
        self.norm = layers.GroupNormalization(groups=32)
        self.qkv = layers.Dense(channels * 3)
        self.proj = layers.Dense(channels)

    def call(self, x):
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        residual = x
        x = self.norm(x)
        x = tf.reshape(x, [b, h * w, c])
        qkv = self.qkv(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        # Multi-head attention
        head_dim = c // self.num_heads
        q = tf.reshape(q, [b, h * w, self.num_heads, head_dim])
        k = tf.reshape(k, [b, h * w, self.num_heads, head_dim])
        v = tf.reshape(v, [b, h * w, self.num_heads, head_dim])
        q = tf.transpose(q, [0, 2, 1, 3])
        k = tf.transpose(k, [0, 2, 1, 3])
        v = tf.transpose(v, [0, 2, 1, 3])
        scale = tf.math.rsqrt(tf.cast(head_dim, tf.float32))
        attn = tf.matmul(q, k, transpose_b=True) * scale
        attn = tf.nn.softmax(attn, axis=-1)
        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [b, h, w, c])
        out = self.proj(tf.reshape(out, [b, h * w, c]))
        out = tf.reshape(out, [b, h, w, c])
        return residual + out


# Test components
time_emb = TimeEmbedding(embed_dim=128)
t = tf.constant([0, 100, 500, 999])
emb = time_emb(t)
print(f"Time embedding shape: {emb.shape}")  # (4, 512)

res_block = ResBlock(out_channels=256)
x = tf.random.normal([2, 32, 32, 128])
out = res_block(x, emb[:2])
print(f"ResBlock output shape: {out.shape}")  # (2, 32, 32, 256)

attn_block = AttentionBlock(channels=256, num_heads=4)
out = attn_block(out)
print(f"AttentionBlock output shape: {out.shape}")  # (2, 32, 32, 256)

Complete U-Net Assembly

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

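def build_unet(latent_channels=4, base_channels=128, channel_mults=(1, 2, 4, 4),
               num_res_blocks=2, attention_resolutions=(2, 4)):
    """Build a simplified (unconditional) U-Net for latent diffusion.

    Reuses TimeEmbedding, ResBlock, and AttentionBlock from the previous block.
    Note: attention_resolutions lists channel multipliers, not pixel sizes; with
    the defaults, attention runs at the 32x32 (mult 2), 16x16 and 8x8 (mult 4) levels.
    """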

    # Inputs
    latent_input = keras.Input(shape=(64, 64, latent_channels), name="latent")
    time_input = keras.Input(shape=(), dtype=tf.int32, name="timestep")

    # Time embedding
    time_emb_layer = TimeEmbedding(embed_dim=base_channels)
    t_emb = time_emb_layer(time_input)

    # Initial convolution
    h = layers.Conv2D(base_channels, 3, padding="same")(latent_input)

    # Encoder path
    skips = [h]
    for level, mult in enumerate(channel_mults):
        ch = base_channels * mult
        for _ in range(num_res_blocks):
            res = ResBlock(ch)
            h = res(h, t_emb)
            if mult in attention_resolutions:
                attn = AttentionBlock(ch)
                h = attn(h)
            skips.append(h)
        if level < len(channel_mults) - 1:
            h = layers.Conv2D(ch, 3, strides=2, padding="same")(h)
            skips.append(h)

    # Middle block
    mid_ch = base_channels * channel_mults[-1]
    mid_res1 = ResBlock(mid_ch)
    mid_attn = AttentionBlock(mid_ch)
    mid_res2 = ResBlock(mid_ch)
    h = mid_res1(h, t_emb)
    h = mid_attn(h)
    h = mid_res2(h, t_emb)

    # Decoder path
    for level, mult in reversed(list(enumerate(channel_mults))):
        ch = base_channels * mult
        for _ in range(num_res_blocks + 1):
            skip = skips.pop()
            h = layers.Concatenate()([h, skip])
            res = ResBlock(ch)
            h = res(h, t_emb)
            if mult in attention_resolutions:
                attn = AttentionBlock(ch)
                h = attn(h)
        if level > 0:
            h = layers.Conv2DTranspose(ch, 3, strides=2, padding="same")(h)

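    # Output (Keras layers rather than raw tf ops keep the functional model Keras 3 compatible)
    h = layers.GroupNormalization(groups=32)(h)
    h = layers.Activation("swish")(h)  # swish is the same function as SiLU
    noise_pred = layers.Conv2D(latent_channels, 3, padding="same")(h)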

    model = keras.Model(inputs=[latent_input, time_input], outputs=noise_pred)
    return model

# Build and summarize
unet = build_unet()
print(f"U-Net parameters: {unet.count_params():,}")
print(f"Input shapes: latent={unet.input[0].shape}, time={unet.input[1].shape}")
print(f"Output shape: {unet.output.shape}")
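The decoder's num_res_blocks + 1 loop is easiest to see by tracking shapes only. The sketch below uses a hypothetical helper, plan_unet_skips, written only for illustration: it replays the encoder/decoder bookkeeping of build_unet with (resolution, channels) tuples, so you can confirm that every skip tensor pushed by the encoder is popped exactly once, at a matching resolution.

```python
def plan_unet_skips(latent_size=64, base_channels=128,
                    channel_mults=(1, 2, 4, 4), num_res_blocks=2):
    """Replay build_unet's skip bookkeeping with (resolution, channels) tuples."""
    res = latent_size
    skips = [(res, base_channels)]           # output of the initial convolution
    for level, mult in enumerate(channel_mults):
        ch = base_channels * mult
        for _ in range(num_res_blocks):
            skips.append((res, ch))          # each ResBlock output is saved
        if level < len(channel_mults) - 1:
            res //= 2                        # strided conv halves the resolution
            skips.append((res, ch))
    pushed = len(skips)

    # Decoder: pop num_res_blocks + 1 skips per level, then upsample
    for level, mult in reversed(list(enumerate(channel_mults))):
        for _ in range(num_res_blocks + 1):
            skip_res, _skip_ch = skips.pop()
            assert skip_res == res, "skip resolution must match decoder resolution"
        if level > 0:
            res *= 2                         # transposed conv doubles the resolution
    return pushed, len(skips), res

pushed, left_over, final_res = plan_unet_skips()
print(pushed, left_over, final_res)  # 12 0 64
```

The extra pop per level consumes the skip saved right after each downsampling convolution, which is why the decoder runs one more ResBlock per level than the encoder.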

Cross-Attention: Conditioning on Text

The key innovation that makes Stable Diffusion controllable by text prompts is cross-attention. Rather than only attending to other spatial positions (self-attention), cross-attention layers allow the U-Net features to attend to text embeddings from a CLIP encoder.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In cross-attention, $Q$ is computed from the image features (one query per spatial position), while $K$ and $V$ are computed from the text embeddings (77 tokens × 768 dimensions from the CLIP ViT-L/14 text encoder used in Stable Diffusion v1).
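A single-head NumPy sketch of the formula makes the shapes explicit. Random matrices stand in for the learned projections, and the 32×32×256 feature map and 77×768 context are illustrative sizes: each of the 1,024 spatial positions ends up with a softmax distribution over the 77 text tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
hw, c = 32 * 32, 256          # flattened image positions, feature channels
seq_len, ctx_dim = 77, 768    # text tokens, text embedding size

image_feats = rng.standard_normal((hw, c))
text_emb = rng.standard_normal((seq_len, ctx_dim))

# Random stand-ins for the learned projections W_q, W_k, W_v
W_q = rng.standard_normal((c, c)) / np.sqrt(c)
W_k = rng.standard_normal((ctx_dim, c)) / np.sqrt(ctx_dim)
W_v = rng.standard_normal((ctx_dim, c)) / np.sqrt(ctx_dim)

Q = image_feats @ W_q         # (1024, 256): queries from the image
K = text_emb @ W_k            # (77, 256): keys from the text
V = text_emb @ W_v            # (77, 256): values from the text

scores = Q @ K.T / np.sqrt(c)                  # (1024, 77), d_k = c for one head
scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V                              # (1024, 256)

print(weights.shape, out.shape)  # (1024, 77) (1024, 256)
```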

import tensorflow as tf
from tensorflow.keras import layers

class CrossAttention(layers.Layer):
    """Cross-attention layer for text-conditioned image generation.

    Q from image features, K/V from text embeddings (CLIP output).
    """

    def __init__(self, channels, context_dim=768, num_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        self.norm = layers.GroupNormalization(groups=32)
        self.to_q = layers.Dense(channels, use_bias=False)
        self.to_k = layers.Dense(channels, use_bias=False)
        self.to_v = layers.Dense(channels, use_bias=False)
        self.to_out = layers.Dense(channels)

    def call(self, x, context):
        """
        Args:
            x: Image features [batch, height, width, channels]
            context: Text embeddings [batch, seq_len, context_dim]
        """
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        residual = x

        x = self.norm(x)
        x_flat = tf.reshape(x, [b, h * w, c])

        # Q from image features, K/V from text context
        q = self.to_q(x_flat)
        k = self.to_k(context)
        v = self.to_v(context)

        # Reshape for multi-head attention
        q = tf.reshape(q, [b, h * w, self.num_heads, self.head_dim])
        k = tf.reshape(k, [b, -1, self.num_heads, self.head_dim])
        v = tf.reshape(v, [b, -1, self.num_heads, self.head_dim])

        q = tf.transpose(q, [0, 2, 1, 3])  # [b, heads, hw, dim]
        k = tf.transpose(k, [0, 2, 1, 3])  # [b, heads, seq, dim]
        v = tf.transpose(v, [0, 2, 1, 3])

        # Scaled dot-product attention
        scale = tf.math.rsqrt(tf.cast(self.head_dim, tf.float32))
        attn_weights = tf.matmul(q, k, transpose_b=True) * scale
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
        attn_output = tf.matmul(attn_weights, v)

        # Reshape back
        attn_output = tf.transpose(attn_output, [0, 2, 1, 3])
        attn_output = tf.reshape(attn_output, [b, h * w, c])
        attn_output = self.to_out(attn_output)
        attn_output = tf.reshape(attn_output, [b, h, w, c])

        return residual + attn_output


# Test cross-attention
cross_attn = CrossAttention(channels=256, context_dim=768, num_heads=8)

# Simulated inputs
image_features = tf.random.normal([2, 32, 32, 256])   # From U-Net encoder
text_embeddings = tf.random.normal([2, 77, 768])       # From CLIP (77 tokens)

output = cross_attn(image_features, text_embeddings)
print(f"Cross-attention input: {image_features.shape}")
print(f"Text context: {text_embeddings.shape}")
print(f"Cross-attention output: {output.shape}")
# Verify shape preserved: (2, 32, 32, 256)
Key Insight: The text embeddings act as a “steering signal” for the denoising process. At each denoising step, cross-attention lets each spatial region attend to different prompt tokens. Because CLIP's text encoder produces position-dependent token embeddings, prompt wording and word order change what each region attends to, which is one reason they affect the output.

The VAE: Latent Space Encoding

The Variational Autoencoder (VAE) in Stable Diffusion provides the compression between pixel space and latent space. The encoder maps a 512×512×3 image to a 64×64×4 latent representation (48× compression), and the decoder reverses this mapping.

Why Latent Space?

  • Computation: 512×512×3 = 786,432 values vs 64×64×4 = 16,384 values
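  • Memory: a self-attention map over a 64×64 latent has 4,096² ≈ 16.8 million entries per head, while one over 512×512 pixels would have 262,144² ≈ 6.9 × 10^10 entries, about 137 GB per head in float16
  • Quality: the autoencoder discards imperceptible high-frequency detail, so the diffusion model spends its capacity on perceptually and semantically meaningful structure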
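The memory figures follow from simple arithmetic: a self-attention map over an H×W feature map has (H·W)² entries. The sketch below assumes float16 storage (2 bytes per entry), a single head, and a single image, and ignores everything else the U-Net keeps in memory.

```python
def attention_map_bytes(height, width, bytes_per_entry=2):
    """Size of one (H*W) x (H*W) attention map for a single head and image."""
    tokens = height * width
    return tokens * tokens * bytes_per_entry

for size in (64, 512):
    gb = attention_map_bytes(size, size) / 1e9
    print(f"{size}x{size}: {size * size:,} tokens -> {gb:,.3f} GB per head")
# 64x64: 4,096 tokens -> 0.034 GB per head
# 512x512: 262,144 tokens -> 137.439 GB per head
```

An 8× reduction per side is a 64× reduction in tokens and therefore a 4,096× reduction in attention memory.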
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Reparameterization trick: z = mean + exp(0.5 * logvar) * eps."""

    def call(self, inputs):
        mean, logvar = inputs
        eps = tf.random.normal(tf.shape(mean))
        return mean + tf.exp(0.5 * logvar) * eps


def build_vae_encoder(input_shape=(512, 512, 3), latent_channels=4):
    """Simplified VAE encoder: image -> latent (with KL regularization).

    Uses Keras layers (not raw tf ops) on symbolic tensors so the functional
    model also builds under Keras 3.
    """
    inputs = keras.Input(shape=input_shape)

    # Downsampling path: 512 -> 256 -> 128 -> 64
    x = layers.Conv2D(64, 3, padding="same")(inputs)
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)  # 256x256
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    x = layers.Conv2D(256, 3, strides=2, padding="same")(x)  # 128x128
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    x = layers.Conv2D(512, 3, strides=2, padding="same")(x)  # 64x64
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    # Produce mean and log-variance (two heads instead of splitting one conv)
    mean = layers.Conv2D(latent_channels, 3, padding="same")(x)
    logvar = layers.Conv2D(latent_channels, 3, padding="same")(x)

    # Reparameterization trick, sampled inside a layer
    z = Sampling()([mean, logvar])

    encoder = keras.Model(inputs, [z, mean, logvar], name="vae_encoder")
    return encoder


def build_vae_decoder(latent_shape=(64, 64, 4), output_channels=3):
    """VAE Decoder: Latent -> Image."""
    inputs = keras.Input(shape=latent_shape)

    # Upsampling path: 64 -> 128 -> 256 -> 512
    x = layers.Conv2D(512, 3, padding="same")(inputs)
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)  # SiLU as a Keras layer (Keras 3 compatible)

    x = layers.Conv2DTranspose(256, 3, strides=2, padding="same")(x)  # 128x128
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    x = layers.Conv2DTranspose(128, 3, strides=2, padding="same")(x)  # 256x256
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    x = layers.Conv2DTranspose(64, 3, strides=2, padding="same")(x)   # 512x512
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.Activation("swish")(x)

    outputs = layers.Conv2D(output_channels, 3, padding="same", activation="tanh")(x)

    decoder = keras.Model(inputs, outputs, name="vae_decoder")
    return decoder


# Build and test
encoder = build_vae_encoder()
decoder = build_vae_decoder()

# Test forward pass
test_image = tf.random.normal([1, 512, 512, 3])
z, mean, logvar = encoder(test_image)
reconstructed = decoder(z)

print(f"Input image: {test_image.shape}")
print(f"Latent z: {z.shape}")
print(f"Mean: {mean.shape}, LogVar: {logvar.shape}")
print(f"Reconstructed: {reconstructed.shape}")
print(f"Compression ratio: {512*512*3 / (64*64*4):.1f}x")

# KL divergence loss
kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mean) - tf.exp(logvar))
print(f"KL divergence: {kl_loss.numpy():.4f}")
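The KL term above is the closed form of $D_{KL}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)) = \frac{1}{2}(\mu^2 + \sigma^2 - 1 - \log \sigma^2)$ per latent element, averaged over elements. A quick NumPy Monte Carlo estimate confirms it; the specific $\mu$ and $\log \sigma^2$ values and the sample size are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
mean, logvar = 0.5, np.log(0.3)   # arbitrary example values
std = np.exp(0.5 * logvar)

# Closed form used in the VAE code: -0.5 * (1 + logvar - mean^2 - exp(logvar))
kl_closed = -0.5 * (1 + logvar - mean**2 - np.exp(logvar))

# Monte Carlo: E_q[log q(z) - log p(z)] with q = N(mean, std^2), p = N(0, 1)
z = mean + std * rng.standard_normal(1_000_000)
log_q = -0.5 * (np.log(2 * np.pi) + logvar + ((z - mean) / std) ** 2)
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
kl_mc = np.mean(log_q - log_p)

print(f"closed form: {kl_closed:.4f}, Monte Carlo: {kl_mc:.4f}")
```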

The Complete Stable Diffusion Pipeline

The full Stable Diffusion pipeline connects all components: a text prompt flows through CLIP to produce embeddings, random noise is iteratively denoised by the U-Net (conditioned on text via cross-attention), and the final latent is decoded by the VAE into a full-resolution image.
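A sampler that visits only a subset of the 1,000 training timesteps has to step from each sampled timestep to the next sampled timestep, not to t - 1. A minimal sketch, assuming evenly spaced timesteps (as in the np.linspace call of the pipeline below) and using -1 as a hypothetical marker for the final jump to the clean latent:

```python
import numpy as np

def timestep_pairs(num_steps=50, num_train_steps=1000):
    """(t, t_prev) pairs for a strided sampler; t_prev = -1 means 'go to x_0'."""
    timesteps = np.linspace(num_train_steps - 1, 0, num_steps).astype(int)
    prev = np.append(timesteps[1:], -1)
    return list(zip(timesteps.tolist(), prev.tolist()))

pairs = timestep_pairs()
print(len(pairs), pairs[:3], pairs[-2:])
```

With 50 steps, consecutive timesteps are about 20 apart, so each update must move the latent across roughly 20 training steps' worth of noise.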

Complete Stable Diffusion Pipeline
flowchart TD
    A["Text Prompt"] --> B["CLIP Text Encoder"]
    B --> C["Text Embeddings (77x768)"]
    D["Random Noise z_T"] --> E["U-Net Denoiser"]
    C --> E
    F["Timestep t"] --> E
    E --> G["Predicted Noise"]
    G --> H["Remove Noise (Scheduler)"]
    H --> I{"t > 0?"}
    I -->|"Yes"| E
    I -->|"No"| J["Clean Latent z_0"]
    J --> K["VAE Decoder"]
    K --> L["Generated Image (512x512)"]

    style A fill:#3B9797,color:#fff
    style L fill:#3B9797,color:#fff
    style E fill:#132440,color:#fff
    style K fill:#16476A,color:#fff
                            
import tensorflow as tf
import numpy as np

class StableDiffusionPipeline:
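    """Complete Stable Diffusion inference pipeline.

    Components: CLIP text encoder, text-conditioned U-Net denoiser, VAE decoder,
    noise scheduler. The U-Net is assumed to take [latents, timesteps, context]
    (e.g., build_unet extended with CrossAttention blocks), and the scheduler to
    expose an alpha_bar tensor like NoiseScheduler above.
    """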

    def __init__(self, unet, vae_decoder, text_encoder, scheduler, guidance_scale=7.5):
        self.unet = unet
        self.vae_decoder = vae_decoder
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.guidance_scale = guidance_scale

    def encode_text(self, prompt, negative_prompt=""):
        """Encode the prompt and negative prompt to CLIP embeddings."""
        # self.text_encoder is assumed to be a callable that tokenizes a string
        # and returns CLIP embeddings of shape [1, 77, 768] (e.g., a CLIP text model)
        cond_embeddings = self.text_encoder(prompt)
        uncond_embeddings = self.text_encoder(negative_prompt)
        return cond_embeddings, uncond_embeddings

    @tf.function
    def denoise_step(self, latents, t, t_prev, text_embeddings, uncond_embeddings):
        """Single deterministic DDIM step (eta = 0) with classifier-free guidance.

        t is the current timestep; t_prev is the next, less noisy sampled
        timestep, with t_prev = -1 meaning "step to the clean latent".
        """
        # The U-Net expects one timestep per batch element
        t_batch = tf.fill([tf.shape(latents)[0]], t)

        # Predict noise for conditional and unconditional inputs
        noise_cond = self.unet([latents, t_batch, text_embeddings])
        noise_uncond = self.unet([latents, t_batch, uncond_embeddings])

        # Classifier-free guidance
        noise_pred = noise_uncond + self.guidance_scale * (noise_cond - noise_uncond)

        # Cumulative signal level at the current and the next sampled timestep
        alpha_t = tf.gather(self.scheduler.alpha_bar, t)
        alpha_prev = tf.where(
            t_prev >= 0,
            tf.gather(self.scheduler.alpha_bar, tf.maximum(t_prev, 0)),
            tf.constant(1.0),
        )

        # DDIM update: estimate x_0, then re-noise it to the t_prev noise level
        pred_x0 = (latents - tf.sqrt(1.0 - alpha_t) * noise_pred) / tf.sqrt(alpha_t)
        direction = tf.sqrt(1.0 - alpha_prev) * noise_pred
        return tf.sqrt(alpha_prev) * pred_x0 + direction

    def generate(self, prompt, negative_prompt="", num_steps=50, seed=42):
        """Generate an image from a text prompt."""
        tf.random.set_seed(seed)

        # 1. Encode text
        text_emb, uncond_emb = self.encode_text(prompt, negative_prompt)

        # 2. Start with random noise in latent space
        latents = tf.random.normal([1, 64, 64, 4])

        # 3. Iterative denoising over evenly spaced timesteps (999 -> 0)
        timesteps = np.linspace(999, 0, num_steps).astype(np.int32)
        for i, t in enumerate(timesteps):
            # Step to the next sampled timestep, not t - 1; -1 marks the final step
            t_prev = timesteps[i + 1] if i + 1 < num_steps else -1
            latents = self.denoise_step(
                latents,
                tf.constant(t, dtype=tf.int32),
                tf.constant(t_prev, dtype=tf.int32),
                text_emb,
                uncond_emb,
            )

        # 4. Decode latent to pixel space (real Stable Diffusion v1 weights expect
        #    latents / 0.18215 here, undoing the latent scaling used in training)
        image = self.vae_decoder(latents)

        # 5. Post-process: [-1, 1] -> [0, 255]
        image = (image + 1.0) * 127.5
        image = tf.clip_by_value(image, 0, 255)
        image = tf.cast(image, tf.uint8)

        return image[0]  # Remove batch dimension


# Pipeline summary (placeholder components; no model weights are loaded here)
print("Pipeline components:")
print("  1. CLIP Text Encoder: prompt -> 77x768 embeddings")
print("  2. Noise Scheduler: manages beta/alpha schedules")
print("  3. U-Net: iterative denoising (conditioned on text)")
print("  4. VAE Decoder: 64x64x4 latent -> 512x512x3 image")
print("  5. Guidance scale: 7.5 (default)")
print("  6. Denoising steps: 50 (default)")
print(f"  7. Latent dimensions: 64x64x4 = {64*64*4:,} values")
print(f"  8. Output dimensions: 512x512x3 = {512*512*3:,} values")

Classifier-Free Guidance

Classifier-Free Guidance (CFG) is the technique that makes text-to-image generation controllable. During training, the text condition is randomly dropped (replaced with the embedding of an empty prompt) for about 10% of examples, so a single network learns both conditional and unconditional noise prediction.
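The training-time half of CFG can be sketched in a few lines of NumPy. The names drop_text_condition and null_embedding are hypothetical, the all-zeros null embedding stands in for the encoding of the empty prompt, and a toy embedding width of 16 replaces 768 to keep the example small.

```python
import numpy as np

def drop_text_condition(text_emb, null_embedding, p_uncond=0.1, rng=None):
    """Replace each example's text embedding with the null embedding w.p. p_uncond."""
    if rng is None:
        rng = np.random.default_rng()
    drop = rng.random(text_emb.shape[0]) < p_uncond  # one coin flip per example
    cond = np.where(drop[:, None, None], null_embedding[None], text_emb)
    return cond, drop

rng = np.random.default_rng(0)
text_emb = rng.standard_normal((2000, 77, 16))   # toy embedding width (16, not 768)
null_embedding = np.zeros((77, 16))              # hypothetical "empty prompt" embedding
cond, dropped = drop_text_condition(text_emb, null_embedding, p_uncond=0.1, rng=rng)
print(f"fraction unconditional: {dropped.mean():.3f}")
```

The batch produced this way is then fed to the usual noise-prediction loss, so the same weights learn both the conditional and the unconditional predictions that CFG combines at inference time.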

The Guidance Formula

At inference time, we compute both conditional and unconditional noise predictions, then move from the unconditional prediction toward the conditional one and, for $s > 1$, beyond it:

$$\hat{\epsilon} = \epsilon_{\text{uncond}} + s \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$$

Where $s$ is the guidance scale:

  • $s = 1$: No guidance (the result equals the conditional prediction; $s = 0$ would give the unconditional one)
  • $s = 7.5$: Default — good balance of quality and diversity
  • $s = 15+$: Strong guidance — very prompt-adherent but less diverse, can be oversaturated
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def classifier_free_guidance(noise_cond, noise_uncond, guidance_scale):
    """Apply classifier-free guidance to noise predictions.

    Args:
        noise_cond: Noise predicted with text conditioning
        noise_uncond: Noise predicted without text (unconditional)
        guidance_scale: How strongly to follow the text prompt

    Returns:
        Guided noise prediction
    """
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


# Demonstrate effect of guidance scale
def visualize_guidance_effect():
    """Show how guidance scale affects the generation direction."""
    # Simulate noise predictions (1D for visualization)
    np.random.seed(42)
    noise_uncond = np.random.randn(100)
    noise_cond = noise_uncond + 0.3 * np.ones(100)  # Conditional shifts toward prompt

    scales = [1.0, 3.0, 7.5, 15.0, 25.0]
    fig, axes = plt.subplots(1, len(scales), figsize=(18, 3))

    for i, s in enumerate(scales):
        guided = noise_uncond + s * (noise_cond - noise_uncond)
        axes[i].hist(noise_uncond, bins=20, alpha=0.4, label="Unconditional", color="#BF092F")
        axes[i].hist(guided, bins=20, alpha=0.6, label=f"Guided (s={s})", color="#3B9797")
        axes[i].set_title(f"Scale = {s}")
        axes[i].legend(fontsize=7)
        axes[i].set_xlim(-4, 6)

    plt.suptitle("Classifier-Free Guidance: Effect of Scale on Noise Distribution")
    plt.tight_layout()
    plt.show()

visualize_guidance_effect()

# Practical implementation
class GuidedDiffusionSampler:
    """Diffusion sampler with classifier-free guidance."""

    def __init__(self, model, scheduler, guidance_scale=7.5):
        self.model = model
        self.scheduler = scheduler
        self.guidance_scale = guidance_scale

    def predict_noise(self, latents, timestep, text_embeddings, uncond_embeddings):
        """Predict noise with CFG."""
        # Batch both predictions together for efficiency
        latent_input = tf.concat([latents, latents], axis=0)
        text_input = tf.concat([uncond_embeddings, text_embeddings], axis=0)
        t_input = tf.concat([timestep, timestep], axis=0)

        # Single forward pass for both
        noise_pred = self.model([latent_input, t_input, text_input])
        noise_uncond, noise_cond = tf.split(noise_pred, 2, axis=0)

        # Apply guidance
        guided_noise = classifier_free_guidance(
            noise_cond, noise_uncond, self.guidance_scale
        )
        return guided_noise


print("Guidance Scale Effects:")
print("  s=1.0  : Pure conditional (no guidance boost)")
print("  s=3.0  : Mild guidance (more diverse outputs)")
print("  s=7.5  : Default (balanced quality/diversity)")
print("  s=15.0 : Strong guidance (high prompt adherence)")
print("  s=25.0 : Extreme (oversaturated, artifacts likely)")
Practical Guide

Choosing Guidance Scale: As a rough starting point, s=7-9 often works for photorealistic images and s=10-15 for artistic/stylized images; for strict prompt adherence (e.g., specific compositions), try s=12-20. These ranges are heuristics that vary by model and prompt. If you see color oversaturation or strange artifacts, reduce the scale. If outputs ignore parts of your prompt, increase it.


Using KerasCV Stable Diffusion

For production use, KerasCV provides a pretrained Stable Diffusion implementation that generates high-quality images in a few lines of code. It follows the same overall design as the simplified components above (CLIP text encoder, text-conditioned U-Net, VAE decoder), but at full size, with pretrained weights and an optimized inference path.

Production Ready: KerasCV’s Stable Diffusion exposes text-to-image generation (text_to_image, or encode_text followed by generate_image) and an inpaint method. With mixed precision and XLA compilation, a 512×512 image typically takes seconds on a modern GPU, though exact timings depend on the hardware.

Text-to-Image Generation

import tensorflow as tf
import keras_cv
import matplotlib.pyplot as plt

# Initialize Stable Diffusion model (downloads weights on first run)
model = keras_cv.models.StableDiffusion(
    img_width=512,
    img_height=512,
    jit_compile=True  # XLA compilation for speed
)

# Generate images from text prompt
images = model.text_to_image(
    prompt="a serene mountain lake at sunset, photorealistic, 4k",
    batch_size=1,
    num_steps=50,
    unconditional_guidance_scale=7.5,
    seed=42
)

# Display result
plt.figure(figsize=(8, 8))
plt.imshow(images[0])
plt.axis("off")
plt.title("Generated: Mountain Lake at Sunset")
plt.tight_layout()
plt.show()

print(f"Image shape: {images[0].shape}")
print(f"Pixel range: [{images[0].min()}, {images[0].max()}]")

Image-to-Image (Modifying Existing Images)

import tensorflow as tf
import keras_cv
import numpy as np
import matplotlib.pyplot as plt

# Initialize model
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

# Load and preprocess an existing image
# (Using a synthetic gradient image for demonstration)
source_image = np.zeros((512, 512, 3), dtype=np.float32)
source_image[:256, :, 0] = np.linspace(0, 1, 256)[:, None]  # Red gradient top
source_image[256:, :, 2] = np.linspace(0, 1, 256)[:, None]  # Blue gradient bottom

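# img2img: modify an existing image with a text prompt.
# Caution: image_to_image is not among the methods documented for KerasCV's
# StableDiffusion (text_to_image, encode_text, generate_image, inpaint); this
# call assumes an img2img-capable wrapper with this signature.
# strength controls how much to change (0 = no change, 1 = full generation)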
images = model.image_to_image(
    prompt="a beautiful landscape painting in watercolor style",
    image=source_image,
    strength=0.7,       # 70% modification
    num_steps=50,
    unconditional_guidance_scale=8.0,
    seed=123
)

# Compare
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(np.clip(source_image, 0, 1))
axes[0].set_title("Source Image")
axes[0].axis("off")
axes[1].imshow(images[0])
axes[1].set_title("img2img Result (strength=0.7)")
axes[1].axis("off")
plt.tight_layout()
plt.show()

Performance Optimization

import tensorflow as tf
import keras_cv
import time

# Enable mixed precision for faster inference
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Build with XLA compilation
model = keras_cv.models.StableDiffusion(
    img_width=512,
    img_height=512,
    jit_compile=True  # Enables XLA - major speedup
)

# Warm-up run (XLA compilation happens here)
print("Warm-up run (compiling XLA graph)...")
_ = model.text_to_image("warmup", batch_size=1, num_steps=25)

# Benchmark
prompt = "a cyberpunk cityscape at night with neon lights, detailed"
start = time.time()
images = model.text_to_image(prompt, batch_size=1, num_steps=50, seed=777)
elapsed = time.time() - start

print(f"Generation time: {elapsed:.2f} seconds")
print(f"Steps per second: {50 / elapsed:.1f} (includes text encoding and VAE decoding)")
print(f"Mixed precision: {tf.keras.mixed_precision.global_policy().name}")
print(f"XLA compiled: True")

# Memory usage
gpu_devices = tf.config.list_physical_devices("GPU")
if gpu_devices:
    mem_info = tf.config.experimental.get_memory_info("GPU:0")
    print(f"GPU memory used: {mem_info['current'] / 1e9:.2f} GB")
When to Use Each Approach:
  • From-scratch implementation: Research, custom architectures, fine-tuning specific components, educational purposes
  • KerasCV pre-built: Production applications, rapid prototyping, standard text-to-image/img2img workflows, when you need optimized inference