from d2l import jax as d2l
from flax import linen as nn
import jax

VGG (Simonyan & Zisserman, 2014) is AlexNet taken seriously: stack more layers, but make them regular.
The contribution wasn't a clever architecture but a design principle: regular blocks of one or more 3×3 conv + ReLU layers, each block ending in a 2×2 max-pool. The whole network is a sequence of such blocks at growing channel counts.
From AlexNet’s hand-tuned layers to VGG’s repeated 3×3 blocks.
Stacking small kernels grows the visible patch (the receptive field) without paying for a large kernel in one step.
For stride 1 and no dilation, the receptive field r_L of L stacked convolutions with kernel sizes k_1 through k_L is

$$r_L = 1 + \sum_{\ell=1}^{L} (k_\ell - 1).$$
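Two 3×3 convolutions see $1 + (3 - 1) + (3 - 1) = 5$ pixels across: the same 5×5 receptive field as one 5×5 conv, but with two ReLUs and fewer weights ($2 \cdot 3^2 C^2 = 18C^2$ versus $5^2 C^2 = 25C^2$ for $C$ input and output channels, ignoring biases).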
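The arithmetic above can be checked with a few lines of plain Python. This is a sketch: receptive_field and conv_weights are hypothetical helpers, and C = 64 is an arbitrary example width.

```python
def receptive_field(kernel_sizes):
    """Receptive field of stacked stride-1, undilated convs: r_L = 1 + sum(k_l - 1)."""
    return 1 + sum(k - 1 for k in kernel_sizes)


def conv_weights(k, c_in, c_out):
    """Number of weights in one k x k conv layer (biases ignored)."""
    return k * k * c_in * c_out


C = 64  # example channel width, chosen only for illustration

for stack in ([3, 3], [5], [3, 3, 3], [7]):
    print(stack, "->", receptive_field(stack))  # 5, 5, 7, 7

two_3x3 = 2 * conv_weights(3, C, C)  # 18 * C^2
one_5x5 = conv_weights(5, C, C)      # 25 * C^2
print(two_3x3, one_5x5)  # 73728 102400
```

Three stacked 3×3 convs likewise match one 7×7 conv (27C² versus 49C² weights), which is the trade VGG makes at every block.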
A reusable subunit: n_convs consecutive 3×3 Conv-ReLU pairs at out_channels (padded by 1, so height and width are unchanged), followed by a 2×2 max-pool with stride 2 that halves height and width:
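A whole VGG-11 (the smallest variant) is just five blocks with 1, 1, 2, 2, and 2 convolutions at growing channel counts (64, 128, 256, 512, 512), plus a 3-layer dense head: 8 conv layers + 3 dense layers = 11 weight layers, hence the name: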
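class VGG(d2l.Classifier):
    arch: list  # sequence of (num_convs, out_channels) pairs, one per block
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        # One vgg_block per (num_convs, out_channels) pair.
        conv_blks = []
        for (num_convs, out_channels) in self.arch:
            conv_blks.append(vgg_block(num_convs, out_channels))
        self.net = nn.Sequential([
            *conv_blks,
            lambda x: x.reshape((x.shape[0], -1)),  # flatten
            # Dense head: two 4096-wide Dense + ReLU + Dropout layers, then logits.
            nn.Dense(4096), nn.relu,
            nn.Dropout(0.5, deterministic=not self.training),
            nn.Dense(4096), nn.relu,
            nn.Dropout(0.5, deterministic=not self.training),
            nn.Dense(self.num_classes)])

Layer summary for a batch of one 224×224 image. Each row is named by the layer's class: "Sequential" is a VGG block, "function" is the flatten lambda, and "custom_jvp" is nn.relu.

Sequential output shape: (1, 112, 112, 64)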
Sequential output shape: (1, 56, 56, 128)
Sequential output shape: (1, 28, 28, 256)
Sequential output shape: (1, 14, 14, 512)
Sequential output shape: (1, 7, 7, 512)
function output shape: (1, 25088)
...
custom_jvp output shape: (1, 4096)
Dropout output shape: (1, 4096)
Dense output shape: (1, 4096)
custom_jvp output shape: (1, 4096)
Dropout output shape: (1, 4096)
Dense output shape: (1, 10)
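To see where these numbers come from, a pure-Python sketch that tracks only shapes. It assumes a 224×224 input (which the first block's 112×112 output implies) and the VGG-11 tuple ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)); block_shapes is a hypothetical helper.

```python
VGG11_ARCH = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))


def block_shapes(arch, height, width, batch=1):
    """NHWC output shape after each VGG block: padded 3x3 convs keep H and W,
    the 2x2 stride-2 max-pool halves them, and channels become out_channels."""
    shapes = []
    for _num_convs, channels in arch:
        height, width = height // 2, width // 2
        shapes.append((batch, height, width, channels))
    return shapes


shapes = block_shapes(VGG11_ARCH, 224, 224)
for s in shapes:
    print(s)
n, h, w, c = shapes[-1]
print("flattened:", (n, h * w * c))  # (1, 25088)
```

Five halvings take 224 down to 7, so the flatten layer hands 7 · 7 · 512 = 25088 features to the first 4096-wide Dense layer.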
The “named architecture” is just a tuple of (n_convs, channels) pairs — passing a different tuple gives you VGG-13/16/19.
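The four variants in the VGG paper keep the same channel widths and differ only in how many convs each block holds. A sketch of those tuples; VGG-16 here is the all-3×3 configuration (configuration D in the paper), since vgg_block only builds 3×3 convs. ARCHS and depth are hypothetical names.

```python
# Channel widths shared by all four variants.
CHANNELS = (64, 128, 256, 512, 512)
# Number of 3x3 convs in each of the five blocks.
CONVS_PER_BLOCK = {
    "VGG-11": (1, 1, 2, 2, 2),
    "VGG-13": (2, 2, 2, 2, 2),
    "VGG-16": (2, 2, 3, 3, 3),
    "VGG-19": (2, 2, 4, 4, 4),
}
# Same (n_convs, channels) tuple format that the VGG class's arch field takes.
ARCHS = {name: tuple(zip(convs, CHANNELS)) for name, convs in CONVS_PER_BLOCK.items()}


def depth(arch, n_dense=3):
    """Weight layers = all conv layers + the dense head."""
    return sum(n for n, _ in arch) + n_dense


for name, arch in ARCHS.items():
    print(name, depth(arch), arch)
```

Each entry can be passed straight to the class above, e.g. VGG(arch=ARCHS["VGG-16"]); the trailing number in each name is exactly depth(arch).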
Full VGG-11 is heavy for a notebook. Train a thinned version (channels 16/32/64/128/128) on Fashion-MNIST as a smoke test:
Validates the block-at-scale design principle without melting your GPU.
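For a sense of scale, a sketch that counts weights plus biases for the VGG class above. It assumes 1-channel 224×224 inputs (as in the layer summary), 10 classes, and the thinned tuple ((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)); vgg_param_count is a hypothetical helper, and Dropout adds no parameters.

```python
def vgg_param_count(arch, in_channels=1, height=224, width=224,
                    hidden=4096, num_classes=10):
    """Return (conv_params, dense_params) for 3x3 conv blocks + a 3-layer dense head."""
    conv, c_in = 0, in_channels
    for n_convs, c_out in arch:
        for _ in range(n_convs):
            conv += 3 * 3 * c_in * c_out + c_out  # kernel + bias
            c_in = c_out
        height, width = height // 2, width // 2  # max-pool after each block
    flat = height * width * c_in
    dense = ((flat * hidden + hidden)                   # first Dense(4096)
             + (hidden * hidden + hidden)               # second Dense(4096)
             + (hidden * num_classes + num_classes))    # logits
    return conv, dense


full = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
thin = ((1, 16), (1, 32), (2, 64), (2, 128), (2, 128))
counts = {name: vgg_param_count(arch) for name, arch in (("VGG-11", full), ("thinned", thin))}
for name, (conv, dense) in counts.items():
    print(f"{name}: conv {conv:,}  dense {dense:,}  total {conv + dense:,}")
```

Under these assumptions the full model has about 128.8M parameters and the thinned one about 43.1M. In both, the dense head dominates, because the first Dense layer alone multiplies the flattened feature count by 4096.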
Writing the architecture as data, a tuple of per-stage settings such as ((1, 64), (1, 128), (2, 256), …), is everywhere: VGG, ResNet, EfficientNet, and ConvNeXt all specify their stages this way.