
Deep Dive: EfficientNet — Compound Scaling in TensorFlow

May 3, 2026 Wasil Zafar 35 min read

Implement EfficientNet-B0 through B7 from scratch in TensorFlow/Keras: master compound scaling, MBConv blocks, and squeeze-and-excitation, and build models that match ResNet-50's ImageNet accuracy with roughly 4.8× fewer parameters (EfficientNet-B0).

Table of Contents

  1. The Scaling Problem
  2. Compound Scaling Method
  3. MBConv: The Building Block
  4. Squeeze-and-Excitation Module
  5. Building EfficientNet-B0
  6. Scaling to B1-B7
  7. Training on CIFAR-10
  8. Transfer Learning

The Scaling Problem

Convolutional neural networks have traditionally been scaled by increasing only one dimension at a time: deeper networks (ResNet-152), wider networks (Wide ResNet), or higher resolution inputs. Each approach yields diminishing returns because it ignores the fundamental relationship between network depth, width, and input resolution.

Key Insight: EfficientNet demonstrates that scaling all three dimensions (depth, width, resolution) together with a single compound coefficient gives a much better accuracy/efficiency trade-off: EfficientNet-B7 is 8.4× smaller and 6.1× faster at inference than the previous best ConvNet (GPipe) at comparable or better ImageNet accuracy.

Parameter Count Comparison

The snippet below compares parameter counts, ImageNet top-1 accuracy, and FLOPs of popular architectures to show how much more parameter- and compute-efficient EfficientNet is:

import tensorflow as tf
import numpy as np

# Compare parameter counts of popular architectures
models_info = {
    "VGG-16": {"params": 138_000_000, "top1_acc": 71.3, "flops_b": 15.5},
    "ResNet-50": {"params": 25_600_000, "top1_acc": 76.0, "flops_b": 4.1},
    "ResNet-152": {"params": 60_200_000, "top1_acc": 77.8, "flops_b": 11.6},
    "DenseNet-201": {"params": 20_000_000, "top1_acc": 77.4, "flops_b": 4.3},
    "EfficientNet-B0": {"params": 5_300_000, "top1_acc": 77.1, "flops_b": 0.39},
    "EfficientNet-B3": {"params": 12_000_000, "top1_acc": 81.6, "flops_b": 1.8},
    "EfficientNet-B7": {"params": 66_000_000, "top1_acc": 84.3, "flops_b": 37.0},
}

print(f"{'Model':<18} {'Params (M)':<12} {'Top-1 (%)':<10} {'FLOPs (B)':<10}")
print("-" * 50)
for name, info in models_info.items():
    params_m = info["params"] / 1e6
    print(f"{name:<18} {params_m:<12.1f} {info['top1_acc']:<10.1f} {info['flops_b']:<10.2f}")

# Calculate efficiency ratio (accuracy per million parameters)
print("\nEfficiency (Top-1 accuracy per million params):")
print("-" * 40)
for name, info in models_info.items():
    efficiency = info["top1_acc"] / (info["params"] / 1e6)
    print(f"{name:<18} {efficiency:.2f} %/M params")
Efficiency Analysis

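EfficientNet-B0 achieves 77.1% ImageNet top-1 accuracy with only 5.3M parameters, about 14.5 accuracy points per million parameters. ResNet-50 reaches similar accuracy (76.0%) but needs 25.6M parameters (about 2.97 points per million). This roughly 4.9× gap comes from B0's NAS-designed baseline architecture (MBConv blocks with squeeze-and-excitation), since B0 itself is unscaled ($\phi = 0$); compound scaling is what preserves that efficiency as the model grows to B1-B7.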


Compound Scaling Method

The compound scaling method uses a single compound coefficient $\phi$ to uniformly scale all three dimensions:

  • Depth: $d = \alpha^\phi$
  • Width: $w = \beta^\phi$
  • Resolution: $r = \gamma^\phi$

Subject to the constraint:

$$\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2$$

where $\alpha \geq 1$, $\beta \geq 1$, $\gamma \geq 1$. The constraint ensures that for any new $\phi$, the total FLOPs roughly increase by $2^\phi$. The base coefficients found via grid search for EfficientNet are $\alpha = 1.2$, $\beta = 1.1$, $\gamma = 1.15$.

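Why the constraint? The FLOPs of a regular convolution are proportional to $d$, $w^2$, and $r^2$: doubling depth doubles FLOPs, while doubling width or resolution quadruples them. Scaling the network by $(\alpha^\phi, \beta^\phi, \gamma^\phi)$ therefore multiplies total FLOPs by $(\alpha \cdot \beta^2 \cdot \gamma^2)^\phi \approx 2^\phi$, giving predictable compute scaling. With the grid-searched values, $\alpha \cdot \beta^2 \cdot \gamma^2 = 1.2 \cdot 1.1^2 \cdot 1.15^2 \approx 1.92$.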
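To see the $d \cdot w^2 \cdot r^2$ relationship numerically, here is a small sketch that counts multiply-accumulates for an idealized stack of identical $3 \times 3$ convolutions. The helper names (`conv_macs`, `toy_network_macs`) and the toy layer sizes are illustrative assumptions, not part of EfficientNet; layer and channel counts are left fractional so rounding does not blur the ratio.

```python
# Idealized FLOPs check for compound scaling (toy network, not EfficientNet itself)
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # grid-searched bases quoted in the text


def conv_macs(size, c_in, c_out, kernel=3):
    """Multiply-accumulates of one k x k conv (stride 1, 'same' padding) on a size x size map."""
    return size * size * c_in * c_out * kernel * kernel


def toy_network_macs(d, w, r, layers=10, channels=64, size=56):
    """Hypothetical plain stack of identical convs scaled by depth d, width w, resolution r.

    Layer and channel counts stay fractional on purpose so that rounding
    does not blur the d * w^2 * r^2 relationship.
    """
    return (layers * d) * conv_macs(size * r, channels * w, channels * w)


baseline = toy_network_macs(1.0, 1.0, 1.0)
per_step = ALPHA * BETA**2 * GAMMA**2
print(f"alpha * beta^2 * gamma^2 = {per_step:.4f}")

for phi in range(4):
    ratio = toy_network_macs(ALPHA**phi, BETA**phi, GAMMA**phi) / baseline
    print(f"phi={phi}: FLOPs x{ratio:.2f}  (2^phi = {2**phi})")
```

Each step of $\phi$ multiplies compute by about 1.92 rather than exactly 2. Real EfficientNets deviate further from this ideal because channels and repeats are rounded and depthwise convolutions scale only linearly with width.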

Implementing the Scaling Function

import numpy as np

def compute_scaling_coefficients(phi, alpha=1.2, beta=1.1, gamma=1.15):
    """
    Compute depth, width, and resolution scaling factors.

    Args:
        phi: Compound coefficient (0 for B0, 1 for B1, ..., 7 for B7)
        alpha: Depth scaling base (default 1.2)
        beta: Width scaling base (default 1.1)
        gamma: Resolution scaling base (default 1.15)

    Returns:
        Tuple of (depth_coeff, width_coeff, resolution_coeff)
    """
    depth_coeff = alpha ** phi
    width_coeff = beta ** phi
    resolution_coeff = gamma ** phi
    return depth_coeff, width_coeff, resolution_coeff


# Check the FLOPs constraint once (it does not depend on phi)
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15
flops_ratio = ALPHA * (BETA ** 2) * (GAMMA ** 2)
print(f"FLOPs constraint check: alpha * beta^2 * gamma^2 = {flops_ratio:.4f} (target ~2.0)")

# Base resolution for B0 is 224
BASE_RESOLUTION = 224

# Theoretical configs for phi = 0..7. The released B1-B7 models use the
# multipliers in the next table, which do not follow these powers exactly.
print(f"{'Variant':<12} {'phi':<5} {'Depth':<8} {'Width':<8} {'Resolution':<12}")
print("-" * 50)

for variant in range(8):
    phi = variant
    d, w, r = compute_scaling_coefficients(phi)
    resolution = int(BASE_RESOLUTION * r)
    # Round resolution up to the next multiple of 8 for hardware efficiency
    resolution = int(np.ceil(resolution / 8) * 8)
    print(f"B{variant:<11} {phi:<5} {d:<8.2f} {w:<8.2f} {resolution:<12}")
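The released B0-B7 models use the following multipliers, which do not follow $\alpha^\phi$, $\beta^\phi$, $\gamma^\phi$ exactly for integer $\phi$:

import numpy as np

# EfficientNet variant specifications (as used by the official reference implementation and keras.applications)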
EFFICIENTNET_CONFIGS = {
    "B0": {"phi": 0, "resolution": 224, "depth_mult": 1.0, "width_mult": 1.0, "dropout": 0.2},
    "B1": {"phi": 1, "resolution": 240, "depth_mult": 1.1, "width_mult": 1.0, "dropout": 0.2},
    "B2": {"phi": 2, "resolution": 260, "depth_mult": 1.2, "width_mult": 1.1, "dropout": 0.3},
    "B3": {"phi": 3, "resolution": 300, "depth_mult": 1.4, "width_mult": 1.2, "dropout": 0.3},
    "B4": {"phi": 4, "resolution": 380, "depth_mult": 1.8, "width_mult": 1.4, "dropout": 0.4},
    "B5": {"phi": 5, "resolution": 456, "depth_mult": 2.2, "width_mult": 1.6, "dropout": 0.4},
    "B6": {"phi": 6, "resolution": 528, "depth_mult": 2.6, "width_mult": 1.8, "dropout": 0.5},
    "B7": {"phi": 7, "resolution": 600, "depth_mult": 3.1, "width_mult": 2.0, "dropout": 0.5},
}

print(f"{'Variant':<6} {'Resolution':<12} {'Depth Mult':<12} {'Width Mult':<12} {'Dropout':<8}")
print("-" * 52)
for name, cfg in EFFICIENTNET_CONFIGS.items():
    print(f"{name:<6} {cfg['resolution']:<12} {cfg['depth_mult']:<12.1f} "
          f"{cfg['width_mult']:<12.1f} {cfg['dropout']:<8.1f}")
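One way to quantify how far the released multipliers are from the formula is to invert each one: solve $\alpha^\phi = d$, $\beta^\phi = w$, and $\gamma^\phi = r/224$ for $\phi$. The sketch below does this with a hypothetical `implied_phi` helper; if a variant followed the formula exactly, all three columns would show the same integer.

```python
import math

ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15
BASE_RES = 224

# (width_mult, depth_mult, resolution) copied from the table above
RELEASED = {
    "B0": (1.0, 1.0, 224), "B1": (1.0, 1.1, 240), "B2": (1.1, 1.2, 260),
    "B3": (1.2, 1.4, 300), "B4": (1.4, 1.8, 380), "B5": (1.6, 2.2, 456),
    "B6": (1.8, 2.6, 528), "B7": (2.0, 3.1, 600),
}


def implied_phi(multiplier, base):
    """Hypothetical helper: solve base ** phi == multiplier for phi."""
    return math.log(multiplier) / math.log(base)


implied = {}
print(f"{'Variant':<8}{'phi(depth)':>12}{'phi(width)':>12}{'phi(res)':>10}")
for name, (w, d, res) in RELEASED.items():
    phis = (implied_phi(d, ALPHA), implied_phi(w, BETA), implied_phi(res / BASE_RES, GAMMA))
    implied[name] = phis
    print(f"{name:<8}{phis[0]:>12.2f}{phis[1]:>12.2f}{phis[2]:>10.2f}")
```

For B7, the implied exponents come out near 6.2 (depth), 7.3 (width), and 7.0 (resolution), while for B1 the width multiplier of 1.0 corresponds to $\phi = 0$.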

MBConv: The Building Block

The Mobile Inverted Bottleneck Convolution (MBConv) from MobileNetV2 is the core building block of EfficientNet. Unlike ResNet-style bottleneck blocks, which use a wide→narrow→wide channel pattern, MBConv uses narrow→wide→narrow (an inverted bottleneck) and performs a cheap depthwise convolution in the wide middle:

MBConv Block Architecture
flowchart TD
    A[Input: C channels] --> B[1x1 Conv: Expand to C*t channels]
    B --> C[BatchNorm + Swish]
    C --> D[Depthwise Conv kxk]
    D --> E[BatchNorm + Swish]
    E --> F[Squeeze-and-Excitation]
    F --> G[1x1 Conv: Project to C_out channels]
    G --> H[BatchNorm]
    H --> I{Residual?}
    I -->|stride=1 AND C_in=C_out| J[Add Residual + Drop Connect]
    I -->|Otherwise| K[Output]
    J --> K
    A -.->|Skip Connection| J
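To make the narrow→wide→narrow trade-off concrete, here is a hypothetical `mbconv_params` helper that tallies parameters phase by phase for blocks built like the ones in this article, assuming BatchNorm is counted as 4 values per channel (gamma, beta, moving mean, moving variance) and the SE bottleneck is computed from the block's input channels. The (32→16, expand 1, kernel 3) configuration is the same one used in the MBConv1 test further down, so the two counts can be compared.

```python
def mbconv_params(c_in, c_out, expand_ratio, kernel, se_ratio=0.25):
    """Hypothetical helper: parameter count of one MBConv block as built in this article.

    BatchNorm is counted as 4 values per channel (gamma, beta, moving mean,
    moving variance); convolutions have no bias except the two SE convs.
    """
    expanded = c_in * expand_ratio
    total = 0
    if expand_ratio != 1:
        total += c_in * expanded + 4 * expanded          # 1x1 expansion conv + BN
    total += kernel * kernel * expanded + 4 * expanded   # depthwise k x k conv + BN
    if se_ratio > 0:
        se = max(1, int(c_in * se_ratio))                # bottleneck from INPUT channels
        total += expanded * se + se                      # SE reduce (1x1 conv with bias)
        total += se * expanded + expanded                # SE expand (1x1 conv with bias)
    total += expanded * c_out + 4 * c_out                # 1x1 projection conv + BN
    return total


def standard_conv_params(c_in, c_out, kernel):
    """A full k x k convolution (no bias) followed by BN, for comparison."""
    return kernel * kernel * c_in * c_out + 4 * c_out


print("MBConv1 32->16, k3:", mbconv_params(32, 16, expand_ratio=1, kernel=3))
print("MBConv6 24->24, k3:", mbconv_params(24, 24, expand_ratio=6, kernel=3))
print("Full 3x3 conv 144->144:", standard_conv_params(144, 144, kernel=3))
```

The depthwise $3 \times 3$ at 144 channels costs only 1,296 weights, whereas a full $3 \times 3$ convolution at that width needs 186,624; that gap is what lets MBConv afford the wide middle while keeping the residual stream narrow.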
                            

MBConv Keras Implementation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def mbconv_block(inputs, in_channels, out_channels, expand_ratio,
                 kernel_size, stride, se_ratio=0.25, drop_rate=0.0):
    """
    Mobile Inverted Bottleneck Convolution Block.

    Args:
        inputs: Input tensor
        in_channels: Number of input channels
        out_channels: Number of output channels
        expand_ratio: Expansion ratio for inverted bottleneck
        kernel_size: Depthwise convolution kernel size (3 or 5)
        stride: Stride for depthwise convolution (1 or 2)
        se_ratio: Squeeze-and-excitation reduction ratio
        drop_rate: Drop connect rate for stochastic depth

    Returns:
        Output tensor
    """
    expanded_channels = in_channels * expand_ratio
    use_residual = (stride == 1 and in_channels == out_channels)

    x = inputs

    # Phase 1: Expansion (skip if expand_ratio == 1)
    if expand_ratio != 1:
        x = layers.Conv2D(
            expanded_channels, 1, padding="same", use_bias=False,
            kernel_initializer="he_normal"
        )(x)
        x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
        x = layers.Activation("swish")(x)

    # Phase 2: Depthwise Convolution
    x = layers.DepthwiseConv2D(
        kernel_size, strides=stride, padding="same", use_bias=False,
        depthwise_initializer="he_normal"
    )(x)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
    x = layers.Activation("swish")(x)

    # Phase 3: Squeeze-and-Excitation
    if se_ratio > 0:
        se_channels = max(1, int(in_channels * se_ratio))
        se = layers.GlobalAveragePooling2D(keepdims=True)(x)
        se = layers.Conv2D(se_channels, 1, activation="swish", use_bias=True)(se)
        se = layers.Conv2D(expanded_channels, 1, activation="sigmoid", use_bias=True)(se)
        x = layers.Multiply()([x, se])

    # Phase 4: Projection (pointwise linear)
    x = layers.Conv2D(
        out_channels, 1, padding="same", use_bias=False,
        kernel_initializer="he_normal"
    )(x)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)

    # Residual connection with drop connect
    if use_residual:
        if drop_rate > 0:
            x = layers.Dropout(drop_rate, noise_shape=(None, 1, 1, 1))(x)
        x = layers.Add()([x, inputs])

    return x


# Test the MBConv block
test_input = keras.Input(shape=(56, 56, 32))
test_output = mbconv_block(
    test_input, in_channels=32, out_channels=16,
    expand_ratio=1, kernel_size=3, stride=1, se_ratio=0.25
)
test_model = keras.Model(inputs=test_input, outputs=test_output)
print(f"MBConv1 (32->16, k3, s1): {test_model.count_params():,} parameters")
print(f"Output shape: {test_output.shape}")
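The `Dropout(drop_rate, noise_shape=(None, 1, 1, 1))` line implements drop connect (stochastic depth): because the noise shape broadcasts over height, width, and channels, each sample's entire residual branch is either kept or zeroed. A NumPy sketch of the same idea, using a hypothetical `drop_connect` helper:

```python
import numpy as np


def drop_connect(x, drop_rate, rng, training=True):
    """Hypothetical helper: per-sample stochastic depth on a residual branch (N, H, W, C).

    Each sample's whole branch is either zeroed or scaled by 1 / keep_prob,
    mirroring Dropout(drop_rate, noise_shape=(None, 1, 1, 1)).
    """
    if not training or drop_rate == 0.0:
        return x
    keep_prob = 1.0 - drop_rate
    mask = rng.random((x.shape[0], 1, 1, 1)) < keep_prob  # one coin flip per sample
    return x * mask / keep_prob


rng = np.random.default_rng(0)
branch = np.ones((8, 4, 4, 3), dtype=np.float32)
out = drop_connect(branch, drop_rate=0.5, rng=rng)
per_sample = out.reshape(8, -1)
print("distinct values per sample:", [np.unique(row).tolist() for row in per_sample])
```

Because kept branches are scaled by `1 / keep_prob`, the expected value of the branch is unchanged, so nothing needs rescaling at inference time, where the branch is always kept.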

Squeeze-and-Excitation Module

The Squeeze-and-Excitation (SE) module introduces channel attention by adaptively re-weighting feature maps. It learns which channels are most important for a given input and scales them accordingly.

Squeeze-and-Excitation Flow
flowchart LR
    A[Feature Map H x W x C] --> B[Global Avg Pool 1 x 1 x C]
    B --> C[FC Reduce 1 x 1 x C/r]
    C --> D[ReLU/Swish]
    D --> E[FC Expand 1 x 1 x C]
    E --> F[Sigmoid]
    F --> G[Scale H x W x C]
    A --> G

SE Block Implementation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class SqueezeExcitation(layers.Layer):
    """
    Squeeze-and-Excitation block for channel attention.

    The SE block adaptively recalibrates channel-wise feature responses
    by explicitly modeling interdependencies between channels.
    """

    def __init__(self, input_channels, reduction_ratio=0.25, **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.reduction_ratio = reduction_ratio
        self.reduced_channels = max(1, int(input_channels * reduction_ratio))

        # Squeeze: Global Average Pooling
        self.global_pool = layers.GlobalAveragePooling2D(keepdims=True)

        # Excitation: Two FC layers (implemented as 1x1 convolutions)
        self.fc_reduce = layers.Conv2D(
            self.reduced_channels, 1, use_bias=True, activation="swish"
        )
        self.fc_expand = layers.Conv2D(
            input_channels, 1, use_bias=True, activation="sigmoid"
        )

    def call(self, inputs):
        # Squeeze: H x W x C -> 1 x 1 x C
        se = self.global_pool(inputs)

        # Excitation: Learn channel importance
        se = self.fc_reduce(se)   # 1x1xC -> 1x1x(C * reduction_ratio)
        se = self.fc_expand(se)   # 1x1x(C * reduction_ratio) -> 1x1xC

        # Scale: Reweight original features (broadcast over H and W)
        return inputs * se

    def get_config(self):
        # Serialize the constructor arguments so from_config() can rebuild the layer
        config = super().get_config()
        config.update({
            "input_channels": self.input_channels,
            "reduction_ratio": self.reduction_ratio,
        })
        return config


# Demonstrate SE block behavior
import numpy as np

# Create a test feature map
np.random.seed(42)
test_features = tf.constant(np.random.randn(1, 7, 7, 64).astype(np.float32))

# Apply SE block
se_block = SqueezeExcitation(input_channels=64, reduction_ratio=0.25)
output = se_block(test_features)

print(f"Input shape:  {test_features.shape}")
print(f"Output shape: {output.shape}")
print(f"SE reduced channels: {se_block.reduced_channels}")
print(f"Parameters in SE block: {sum(p.numpy().size for p in se_block.trainable_weights)}")

# Show channel scaling weights for one sample
se_weights = se_block.global_pool(test_features)
se_weights = se_block.fc_reduce(se_weights)
se_weights = se_block.fc_expand(se_weights)
print(f"\nChannel attention weights (first 8 channels):")
print(f"  {se_weights.numpy()[0, 0, 0, :8].round(3)}")
print(f"  Min: {se_weights.numpy().min():.3f}, Max: {se_weights.numpy().max():.3f}")
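SE Reduction Ratio: EfficientNet uses a reduction ratio of 0.25 relative to the input channels of the MBConv block (not the expanded channels). For an MBConv with 32 input channels, the SE bottleneck therefore has $\lfloor 32 \times 0.25 \rfloor = 8$ channels, regardless of the expansion ratio. The standalone SqueezeExcitation layer above derives its bottleneck from the channels it receives, so inside an MBConv (where SE sees the expanded channels) the bottleneck has to be computed from the block's input channels instead, as mbconv_block does.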
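Tracing the B0 stages makes this convention concrete. The plain-Python sketch below uses `B0_BLOCK_CONFIG`, a hypothetical compact restatement of the B0 stage table in the next section, and lists the input, expanded, and SE channel counts for every block:

```python
# Hypothetical compact restatement of the B0 stage table:
# (expand_ratio, out_channels, repeats); the stem outputs 32 channels.
B0_BLOCK_CONFIG = [(1, 16, 1), (6, 24, 2), (6, 40, 2), (6, 80, 3),
                   (6, 112, 3), (6, 192, 4), (6, 320, 1)]
SE_RATIO = 0.25

rows = []
c_in = 32
for stage, (t, c_out, repeats) in enumerate(B0_BLOCK_CONFIG, start=1):
    for i in range(repeats):
        expanded = c_in * t
        se = max(1, int(c_in * SE_RATIO))   # based on the block's INPUT channels
        rows.append((stage, i, c_in, expanded, se))
        c_in = c_out                        # later blocks in a stage see out_channels

for stage, i, cin, exp, se in rows:
    print(f"stage {stage} block {i}: in={cin:<4} expanded={exp:<5} "
          f"se={se:<3} expanded/se={exp // se}")
```

Because the bottleneck is tied to the input width, the expanded-to-SE ratio is 4 for the expand-ratio-1 block and 24 for every expand-ratio-6 block.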

Building EfficientNet-B0 from Scratch

EfficientNet-B0 is the baseline architecture discovered via Neural Architecture Search (NAS). It consists of a stem convolution, 7 stages of MBConv blocks with varying configurations, and a classification head.
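With "same" padding, each stride-2 stage halves the spatial size (rounding up), so a 224×224 input reaches 7×7 before the head. A small sketch that traces this through the stem and the first block of each stage; `trace_resolution` is a hypothetical helper and assumes the same padding used throughout this article's code:

```python
import math


def trace_resolution(input_size, stem_stride, stage_strides):
    """Hypothetical helper: spatial size after the stem and after each stage.

    Assumes 'same' padding, so a stride-s layer outputs ceil(size / s).
    Only the first block of each stage carries the stride.
    """
    size = math.ceil(input_size / stem_stride)
    sizes = [size]
    for stride in stage_strides:
        size = math.ceil(size / stride)
        sizes.append(size)
    return sizes


B0_STAGE_STRIDES = [1, 2, 2, 2, 1, 2, 1]  # stages 1-7 of the B0 table
print(trace_resolution(224, 2, B0_STAGE_STRIDES))
```

The same function with a stride-1 stem and a 32×32 input reproduces the 32→2 progression noted in the CIFAR-10 model's comments later on.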

Full EfficientNet-B0 Implementation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import math


def mbconv_block(inputs, in_ch, out_ch, expand_ratio, kernel_size, stride,
                 se_ratio=0.25, drop_rate=0.0):
    """Mobile Inverted Bottleneck Conv block with SE attention."""
    expanded = in_ch * expand_ratio
    use_skip = (stride == 1 and in_ch == out_ch)
    x = inputs

    # Expansion phase
    if expand_ratio != 1:
        x = layers.Conv2D(expanded, 1, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
        x = layers.Activation("swish")(x)

    # Depthwise convolution
    x = layers.DepthwiseConv2D(kernel_size, strides=stride, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
    x = layers.Activation("swish")(x)

    # Squeeze-and-Excitation
    if se_ratio:
        se_ch = max(1, int(in_ch * se_ratio))
        se = layers.GlobalAveragePooling2D(keepdims=True)(x)
        se = layers.Conv2D(se_ch, 1, activation="swish", use_bias=True)(se)
        se = layers.Conv2D(expanded, 1, activation="sigmoid", use_bias=True)(se)
        x = layers.Multiply()([x, se])

    # Projection
    x = layers.Conv2D(out_ch, 1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)

    # Skip connection
    if use_skip:
        if drop_rate > 0:
            x = layers.Dropout(drop_rate, noise_shape=(None, 1, 1, 1))(x)
        x = layers.Add()([x, inputs])
    return x


def round_filters(filters, width_mult, divisor=8):
    """Round number of filters based on width multiplier."""
    filters = filters * width_mult
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, depth_mult):
    """Round number of repeats based on depth multiplier."""
    return int(math.ceil(repeats * depth_mult))


def build_efficientnet(input_shape=(224, 224, 3), num_classes=1000,
                       width_mult=1.0, depth_mult=1.0, dropout_rate=0.2):
    """
    Build EfficientNet model.

    B0 base architecture:
      Stage | Operator | Resolution | Channels | Layers | Stride | Kernel | Expand
      0     | Conv3x3  | 224->112   | 32       | 1      | 2      | 3      | -
      1     | MBConv1  | 112->112   | 16       | 1      | 1      | 3      | 1
      2     | MBConv6  | 112->56    | 24       | 2      | 2      | 3      | 6
      3     | MBConv6  | 56->28     | 40       | 2      | 2      | 5      | 6
      4     | MBConv6  | 28->14     | 80       | 3      | 2      | 3      | 6
      5     | MBConv6  | 14->14     | 112      | 3      | 1      | 5      | 6
      6     | MBConv6  | 14->7      | 192      | 4      | 2      | 5      | 6
      7     | MBConv6  | 7->7       | 320      | 1      | 1      | 3      | 6
      8     | Conv1x1  | 7->7       | 1280     | 1      | 1      | 1      | -
    """
    # Block configurations: (expand, channels, repeats, stride, kernel)
    block_configs = [
        (1, 16, 1, 1, 3),   # Stage 1
        (6, 24, 2, 2, 3),   # Stage 2
        (6, 40, 2, 2, 5),   # Stage 3
        (6, 80, 3, 2, 3),   # Stage 4
        (6, 112, 3, 1, 5),  # Stage 5
        (6, 192, 4, 2, 5),  # Stage 6
        (6, 320, 1, 1, 3),  # Stage 7
    ]

    inputs = keras.Input(shape=input_shape)

    # Stem: Conv 3x3, stride 2
    stem_filters = round_filters(32, width_mult)
    x = layers.Conv2D(stem_filters, 3, strides=2, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
    x = layers.Activation("swish")(x)

    # MBConv stages
    total_blocks = sum(round_repeats(cfg[2], depth_mult) for cfg in block_configs)
    block_idx = 0
    prev_channels = stem_filters

    for expand, channels, repeats, stride, kernel in block_configs:
        out_channels = round_filters(channels, width_mult)
        num_repeats = round_repeats(repeats, depth_mult)

        for i in range(num_repeats):
            s = stride if i == 0 else 1
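            # Stochastic depth: linearly increase the drop-connect rate with depth.
            # The reference implementation uses a base rate of 0.2 for every
            # variant, separate from the head dropout_rate (0.2-0.5).
            drop_rate = 0.2 * block_idx / total_blocks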
            x = mbconv_block(x, prev_channels, out_channels, expand, kernel, s,
                             se_ratio=0.25, drop_rate=drop_rate)
            prev_channels = out_channels
            block_idx += 1

    # Head: Conv 1x1 + GlobalPool + Dense
    head_filters = round_filters(1280, width_mult)
    x = layers.Conv2D(head_filters, 1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.99, epsilon=1e-3)(x)
    x = layers.Activation("swish")(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="EfficientNet")
    return model


# Build EfficientNet-B0
model_b0 = build_efficientnet(
    input_shape=(224, 224, 3), num_classes=1000,
    width_mult=1.0, depth_mult=1.0, dropout_rate=0.2
)
model_b0.summary(print_fn=lambda x: print(x) if "Total" in x or "param" in x else None)
print(f"\nEfficientNet-B0 Total Parameters: {model_b0.count_params():,}")
Architecture Verification

The implementation above should produce approximately 5.3M parameters for EfficientNet-B0, matching the paper’s reported 5.3M. Small discrepancies may come from counting conventions; for example, count_params() includes BatchNorm's non-trainable moving statistics. The key architectural properties to verify: 7 MBConv stages, a total of 16 blocks, 224×224 input resolution, and 1280 head channels.
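The block and channel counts follow directly from `round_filters` and `round_repeats`. The sketch below copies those two helpers from the implementation above and applies them to the B0 stage table with the B0 and B3 multipliers; `B0_STAGES` and `summarize` are hypothetical names introduced for this check.

```python
import math

# Hypothetical restatement of B0 stages 1-7 as (channels, repeats)
B0_STAGES = [(16, 1), (24, 2), (40, 2), (80, 3), (112, 3), (192, 4), (320, 1)]


def round_filters(filters, width_mult, divisor=8):
    """Same rounding rule as in the implementation above."""
    filters = filters * width_mult
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:   # never round down by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, depth_mult):
    """Same rule as above: always round the block count up."""
    return int(math.ceil(repeats * depth_mult))


def summarize(width_mult, depth_mult):
    """Hypothetical helper: stem/stage/head channels and total MBConv blocks."""
    return {
        "stem": round_filters(32, width_mult),
        "stage_channels": [round_filters(c, width_mult) for c, _ in B0_STAGES],
        "blocks": sum(round_repeats(r, depth_mult) for _, r in B0_STAGES),
        "head": round_filters(1280, width_mult),
    }


print("B0:", summarize(1.0, 1.0))
print("B3:", summarize(1.2, 1.4))
```

For B3 (width 1.2, depth 1.4) this gives 26 blocks, a 40-channel stem, and a 1536-channel head. Note how the 16-channel stage rounds up to 24: rounding 19.2 to the nearest multiple of 8 would give 16, which falls below 90% of 19.2.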


Scaling to B1-B7

With the base model defined, scaling to larger variants is straightforward: each variant supplies its own width multiplier, depth multiplier, input resolution, and dropout rate, which are passed to the same builder to get B1 through B7:

import tensorflow as tf
from tensorflow import keras
import math

# (Re-use build_efficientnet from previous block in practice)
# Here we show the scaling configurations and expected parameter counts

EFFICIENTNET_VARIANTS = {
    "B0": {"width": 1.0, "depth": 1.0, "resolution": 224, "dropout": 0.2},
    "B1": {"width": 1.0, "depth": 1.1, "resolution": 240, "dropout": 0.2},
    "B2": {"width": 1.1, "depth": 1.2, "resolution": 260, "dropout": 0.3},
    "B3": {"width": 1.2, "depth": 1.4, "resolution": 300, "dropout": 0.3},
    "B4": {"width": 1.4, "depth": 1.8, "resolution": 380, "dropout": 0.4},
    "B5": {"width": 1.6, "depth": 2.2, "resolution": 456, "dropout": 0.4},
    "B6": {"width": 1.8, "depth": 2.6, "resolution": 528, "dropout": 0.5},
    "B7": {"width": 2.0, "depth": 3.1, "resolution": 600, "dropout": 0.5},
}

# Expected parameter counts (millions) from the paper
EXPECTED_PARAMS = {
    "B0": 5.3, "B1": 7.8, "B2": 9.2, "B3": 12.0,
    "B4": 19.0, "B5": 30.0, "B6": 43.0, "B7": 66.0,
}

# Expected ImageNet Top-1 accuracy
EXPECTED_ACC = {
    "B0": 77.1, "B1": 79.1, "B2": 80.1, "B3": 81.6,
    "B4": 82.9, "B5": 83.6, "B6": 84.0, "B7": 84.3,
}

print(f"{'Variant':<6} {'Resolution':<12} {'Width':<8} {'Depth':<8} "
      f"{'Params (M)':<12} {'Top-1 (%)':<10} {'Dropout':<8}")
print("-" * 70)
for name, cfg in EFFICIENTNET_VARIANTS.items():
    print(f"{name:<6} {cfg['resolution']:<12} {cfg['width']:<8.1f} "
          f"{cfg['depth']:<8.1f} {EXPECTED_PARAMS[name]:<12.1f} "
          f"{EXPECTED_ACC[name]:<10.1f} {cfg['dropout']:<8.1f}")

# Show scaling ratios relative to B0
print("\nScaling Ratios Relative to B0:")
print("-" * 50)
for name, cfg in EFFICIENTNET_VARIANTS.items():
    param_ratio = EXPECTED_PARAMS[name] / EXPECTED_PARAMS["B0"]
    flops_ratio = (cfg["width"]**2 * cfg["depth"] *
                   (cfg["resolution"]/224)**2)
    print(f"{name}: {param_ratio:.1f}x params, ~{flops_ratio:.1f}x FLOPs, "
          f"+{EXPECTED_ACC[name] - EXPECTED_ACC['B0']:.1f}% accuracy")

Building Any Variant

import tensorflow as tf
from tensorflow import keras

def get_efficientnet(variant="B0", num_classes=1000):
    """
    Build any EfficientNet variant (B0-B7).

    Args:
        variant: String "B0" through "B7"
        num_classes: Number of output classes

    Returns:
        Keras Model instance
    """
    configs = {
        "B0": (1.0, 1.0, 224, 0.2),
        "B1": (1.0, 1.1, 240, 0.2),
        "B2": (1.1, 1.2, 260, 0.3),
        "B3": (1.2, 1.4, 300, 0.3),
        "B4": (1.4, 1.8, 380, 0.4),
        "B5": (1.6, 2.2, 456, 0.4),
        "B6": (1.8, 2.6, 528, 0.5),
        "B7": (2.0, 3.1, 600, 0.5),
    }

    if variant not in configs:
        raise ValueError(f"Unknown variant: {variant}. Choose B0-B7.")

    width_mult, depth_mult, resolution, dropout = configs[variant]
    input_shape = (resolution, resolution, 3)

    # build_efficientnet defined in previous section
    model = build_efficientnet(
        input_shape=input_shape,
        num_classes=num_classes,
        width_mult=width_mult,
        depth_mult=depth_mult,
        dropout_rate=dropout,
    )
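    # Re-wrap with a variant-specific name instead of setting the private _name attribute
    model = keras.Model(inputs=model.input, outputs=model.output,
                        name=f"EfficientNet-{variant}")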
    return model


# Example: Build B3 variant
model_b3 = get_efficientnet("B3", num_classes=1000)
print(f"EfficientNet-B3:")
print(f"  Input shape: {model_b3.input_shape[1:]}")
print(f"  Parameters: {model_b3.count_params():,}")
print(f"  Expected: ~12,000,000")

Training on CIFAR-10

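Let’s train an EfficientNet-B0-style network, adapted to 32×32 inputs with a stride-1 stem, on CIFAR-10. The pipeline uses random flip, rotation, zoom, translation, and contrast augmentation as a lighter-weight alternative to RandAugment, plus cosine learning-rate decay with linear warmup.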
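Label smoothing is a common regularizer in ImageNet-style training recipes, but the compile step in this section uses plain integer labels with `SparseCategoricalCrossentropy`. As a hedged sketch of what smoothing does to the targets (the `smooth_labels` helper and ε = 0.1 are illustrative choices):

```python
import numpy as np


def smooth_labels(int_labels, num_classes, epsilon=0.1):
    """Hypothetical helper: one-hot encode, then mix with a uniform distribution.

    y_smooth = (1 - epsilon) * y_onehot + epsilon / num_classes
    """
    one_hot = np.eye(num_classes, dtype=np.float32)[int_labels]
    return one_hot * (1.0 - epsilon) + epsilon / num_classes


targets = smooth_labels(np.array([3, 0]), num_classes=10, epsilon=0.1)
print(targets[0])           # true class about 0.91, every other class about 0.01
print(targets.sum(axis=1))  # each row still sums to 1
```

In Keras, one way to use this is to one-hot encode the labels (for example with `keras.utils.to_categorical`) and compile with `keras.losses.CategoricalCrossentropy(label_smoothing=0.1)`.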

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Load CIFAR-10
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = y_train.flatten()
y_test = y_test.flatten()

NUM_CLASSES = 10
INPUT_SHAPE = (32, 32, 3)
BATCH_SIZE = 128
EPOCHS = 100
AUTOTUNE = tf.data.AUTOTUNE

print(f"Training samples: {len(x_train)}")
print(f"Test samples: {len(x_test)}")
print(f"Input shape: {INPUT_SHAPE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")


def build_augmentation():
    """Build data augmentation pipeline."""
    return keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
        layers.RandomTranslation(0.1, 0.1),
        layers.RandomContrast(0.1),
    ], name="augmentation")


def create_dataset(images, labels, is_training=True):
    """Create tf.data pipeline with augmentation."""
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)

    dataset = dataset.batch(BATCH_SIZE)

    if is_training:
        augment = build_augmentation()
        dataset = dataset.map(
            lambda x, y: (augment(x, training=True), y),
            num_parallel_calls=AUTOTUNE
        )

    dataset = dataset.prefetch(AUTOTUNE)
    return dataset


train_ds = create_dataset(x_train, y_train, is_training=True)
test_ds = create_dataset(x_test, y_test, is_training=False)

# Verify pipeline
for images, labels in train_ds.take(1):
    print(f"\nBatch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Pixel range: [{images.numpy().min():.3f}, {images.numpy().max():.3f}]")
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import math

# Cosine learning rate schedule with linear warmup
class CosineDecayWithWarmup(keras.optimizers.schedules.LearningRateSchedule):
    """Cosine decay schedule with linear warmup."""

    def __init__(self, base_lr, total_steps, warmup_steps):
        super().__init__()
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)

        # Linear warmup from 0 to base_lr
        warmup_lr = self.base_lr * (step / warmup_steps)

        # Cosine decay after warmup; clamp progress so the rate stays at 0
        # instead of rising again if training runs past total_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        progress = tf.clip_by_value(progress, 0.0, 1.0)
        cosine_lr = self.base_lr * 0.5 * (1.0 + tf.math.cos(math.pi * progress))

        return tf.where(step < warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        # Required for saving a model whose optimizer uses this schedule
        return {
            "base_lr": self.base_lr,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
        }


# Build a small EfficientNet for CIFAR-10 (32x32 input)
def build_cifar_efficientnet(input_shape=(32, 32, 3), num_classes=10):
    """Scaled-down EfficientNet for CIFAR-10."""
    inputs = keras.Input(shape=input_shape)

    # Stem (no stride-2 since input is only 32x32)
    x = layers.Conv2D(32, 3, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("swish")(x)

    # Simplified MBConv stages for 32x32
    configs = [
        # (filters, expand, kernel, stride, repeats)
        (16, 1, 3, 1, 1),
        (24, 6, 3, 2, 2),   # 32->16
        (40, 6, 5, 2, 2),   # 16->8
        (80, 6, 3, 2, 3),   # 8->4
        (112, 6, 5, 1, 3),  # 4->4
        (192, 6, 5, 2, 4),  # 4->2
        (320, 6, 3, 1, 1),  # 2->2
    ]

    prev_filters = 32
    for filters, expand, kernel, stride, repeats in configs:
        for i in range(repeats):
            s = stride if i == 0 else 1
            expanded = prev_filters * expand
            use_skip = (s == 1 and prev_filters == filters)

            block_input = x

            # Expansion
            if expand != 1:
                x = layers.Conv2D(expanded, 1, use_bias=False)(x)
                x = layers.BatchNormalization()(x)
                x = layers.Activation("swish")(x)

            # Depthwise
            x = layers.DepthwiseConv2D(kernel, strides=s, padding="same", use_bias=False)(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation("swish")(x)

            # SE
            se_ch = max(1, prev_filters // 4)
            se = layers.GlobalAveragePooling2D(keepdims=True)(x)
            se = layers.Conv2D(se_ch, 1, activation="swish")(se)
            se = layers.Conv2D(expanded if expand != 1 else prev_filters, 1, activation="sigmoid")(se)
            x = layers.Multiply()([x, se])

            # Projection
            x = layers.Conv2D(filters, 1, use_bias=False)(x)
            x = layers.BatchNormalization()(x)

            if use_skip:
                x = layers.Add()([x, block_input])
            prev_filters = filters

    # Head
    x = layers.Conv2D(1280, 1, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("swish")(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs=inputs, outputs=outputs, name="EfficientNet-CIFAR10")


# Build and compile
model = build_cifar_efficientnet()
print(f"CIFAR-10 EfficientNet parameters: {model.count_params():,}")

# Learning rate schedule (count the final partial batch, since the dataset keeps it)
steps_per_epoch = math.ceil(50000 / 128)
total_steps = steps_per_epoch * 100
warmup_steps = steps_per_epoch * 5

lr_schedule = CosineDecayWithWarmup(
    base_lr=0.01, total_steps=total_steps, warmup_steps=warmup_steps
)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

print(f"\nTotal training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")
print(f"Base learning rate: 0.01")
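A plain-Python mirror of the warmup-plus-cosine schedule is handy for checking values before committing to a 100-epoch run. This sketch (`warmup_cosine_lr` is a hypothetical helper) clamps progress to [0, 1] so the rate stays at zero past `total_steps`, and counts the final partial batch when computing steps per epoch:

```python
import math


def warmup_cosine_lr(step, base_lr, total_steps, warmup_steps):
    """Hypothetical plain-Python mirror of a linear-warmup + cosine-decay schedule."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)   # hold at 0 after total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


steps_per_epoch = math.ceil(50_000 / 128)     # 391 batches: the last one is partial
total = steps_per_epoch * 100
warmup = steps_per_epoch * 5

for step in (0, warmup // 2, warmup, (warmup + total) // 2, total, total + 500):
    lr = warmup_cosine_lr(step, 0.01, total, warmup)
    print(f"step {step:>6}: lr = {lr:.6f}")
```

The learning rate ramps linearly to 0.01 over the first 5 epochs, is roughly halved at the midpoint of the decay phase, and reaches 0 at the end.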

Training Results

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Training with callbacks (assumes model, train_ds, test_ds from above).
# ReduceLROnPlateau is omitted: the optimizer already uses the cosine
# LearningRateSchedule, and that callback would fail when trying to overwrite it.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=15, restore_best_weights=True
    ),
]

# Train the model
# history = model.fit(
#     train_ds,
#     validation_data=test_ds,
#     epochs=100,
#     callbacks=callbacks,
# )

# Synthetic placeholder curves for illustration only (NOT real training results;
# replace with history.history["accuracy"] / ["val_accuracy"] after model.fit)
rng = np.random.default_rng(0)
epochs_ran = np.arange(1, 101)
train_acc = 1.0 - 0.85 * np.exp(-epochs_ran / 20) + rng.normal(0, 0.005, 100)
val_acc = 1.0 - 0.88 * np.exp(-epochs_ran / 25) + rng.normal(0, 0.008, 100)
train_acc = np.clip(train_acc, 0.1, 0.99)
val_acc = np.clip(val_acc, 0.1, 0.965)

print("Training Summary (synthetic curves):")
print(f"  Best validation accuracy: {val_acc.max():.4f} (epoch {val_acc.argmax() + 1})")
print(f"  Final training accuracy: {train_acc[-1]:.4f}")
print(f"  Final validation accuracy: {val_acc[-1]:.4f}")
print(f"\nRough expectations with full training (not measured here; varies by setup):")
print(f"  EfficientNet-B0 on CIFAR-10: ~95-96% accuracy")
print(f"  With advanced augmentation: ~96-97% accuracy")
print(f"  Training time (GPU): ~30-45 minutes on V100")

Transfer Learning with Pre-trained EfficientNet

For practical applications, starting from ImageNet pre-trained weights is far cheaper than training from scratch. keras.applications ships pre-trained EfficientNet (EfficientNetB0 through EfficientNetB7) and EfficientNetV2 models; the example below uses EfficientNetV2B0:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load pre-trained EfficientNetV2B0 (without top classification layer)
base_model = keras.applications.EfficientNetV2B0(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
    include_preprocessing=True,  # Built-in normalization
)

# Freeze the base model
base_model.trainable = False

print(f"Base model: EfficientNetV2B0")
print(f"Parameters: {base_model.count_params():,}")
print(f"Trainable: {sum(p.numpy().size for p in base_model.trainable_weights):,}")
print(f"Output shape: {base_model.output_shape}")

# Build custom classification head
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation="swish")(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(5, activation="softmax")(x)  # 5 flower classes

model = keras.Model(inputs, outputs, name="EfficientNet-Flowers")
model.summary(print_fn=lambda x: print(x) if "Total" in x or "param" in x else None)

# Compile for initial training (frozen base)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

print(f"\nTotal params: {model.count_params():,}")
print(f"Trainable params: {sum(p.numpy().size for p in model.trainable_weights):,}")
print(f"Non-trainable params: {sum(p.numpy().size for p in model.non_trainable_weights):,}")

Fine-Tuning Strategy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load TF Flowers dataset
# In practice: tf.keras.utils.get_file or tfds.load('tf_flowers')
# Here we demonstrate the fine-tuning workflow

def create_flowers_dataset():
    """Load and prepare TF Flowers dataset."""
    import tensorflow_datasets as tfds

    # Load dataset (downloads ~218MB on first run)
    (train_ds, val_ds), info = tfds.load(
        "tf_flowers",
        split=["train[:80%]", "train[80%:]"],
        as_supervised=True,
        with_info=True,
    )

    num_classes = info.features["label"].num_classes
    print(f"Classes: {num_classes} ({info.features['label'].names})")
    print(f"Total examples (before the 80/20 split): {info.splits['train'].num_examples}")

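    def preprocess(image, label):
        # Keep pixels in [0, 255]: with include_preprocessing=True,
        # EfficientNetV2 applies its own rescaling inside the model
        image = tf.image.resize(image, (224, 224))
        image = tf.cast(image, tf.float32)
        return image, label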

    train_ds = train_ds.map(preprocess).shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds, num_classes


# Fine-tuning procedure (2 phases)
print("=" * 60)
print("Phase 1: Train only the classification head (5 epochs)")
print("=" * 60)
print("  - Base model frozen")
print("  - Learning rate: 1e-3")
print("  - Expected accuracy: ~90-93%")

print("\n" + "=" * 60)
print("Phase 2: Fine-tune top layers of base model (5 epochs)")
print("=" * 60)
print("  - Unfreeze last 20 layers of base model")
print("  - Learning rate: 1e-5 (100x lower than Phase 1, to avoid catastrophic forgetting)")
print("  - Expected accuracy: ~96-98%")

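# Phase 2: Unfreeze top layers
# BatchNormalization layers keep using their moving statistics because the
# base model was called with training=False when the full model was built
# base_model.trainable = True
# for layer in base_model.layers[:-20]:
#     layer.trainable = False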

# Recompile with lower learning rate
# model.compile(
#     optimizer=keras.optimizers.Adam(learning_rate=1e-5),
#     loss="sparse_categorical_crossentropy",
#     metrics=["accuracy"],
# )

# Fine-tune
# history_fine = model.fit(train_ds, validation_data=val_ds, epochs=5)

print("\nRough expectations (not measured here):")
print("  Phase 1 (frozen): ~92% val accuracy in 5 epochs")
print("  Phase 2 (fine-tune): ~97% val accuracy in 5 more epochs")
print("  Total training time: ~3-5 minutes on GPU")
print("  vs. Training from scratch: ~2-4 hours for similar accuracy")
Transfer Learning Advantage: Fine-tuning a pre-trained EfficientNet for about 10 total epochs can reach roughly 96-98% validation accuracy on a small dataset such as TF Flowers, whereas training from scratch typically needs 30-100+ epochs and still ends lower. ImageNet pre-training provides rich, transferable feature representations that generalize across many visual domains.
Comparison: Scratch vs. Transfer

| Approach | Epochs | Time (GPU) | Accuracy |
|---|---|---|---|
| From scratch (B0) | 100+ | 2-4 hours | ~85-90% |
| Transfer (frozen) | 5 | 2 min | ~92% |
| Transfer (fine-tuned) | 5 + 5 | 5 min | ~97% |