from d2l import jax as d2l
from flax import linen as nn
import jax
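from jax import numpy as jnp

AlexNet (Krizhevsky, Sutskever, and Hinton, 2012) is the network that made deep learning the dominant approach to computer vision. It won the ImageNet 2012 challenge by a large margin (15.3% top-5 error versus 26.2% for the runner-up) and started the modern era.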
Figure: AlexNet alongside LeNet from more than a decade earlier.
The architecture itself is straightforward; what changed was the scale.
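Five convolutional layers (11×11, then 5×5, then three 3×3) interleaved with three max-pooling layers, followed by three fully connected layers. The original network ends in 1000 ImageNet classes; the implementation below defaults to num_classes = 10 for Fashion-MNIST: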
class AlexNet(d2l.Classifier):
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            # Block 1: large 11x11 kernels with stride 4 on the full-size input
            nn.Conv(features=96, kernel_size=(11, 11), strides=4, padding=1),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)),
            # Block 2: 5x5 kernels, more channels
            nn.Conv(features=256, kernel_size=(5, 5)),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)),
            # Block 3: three consecutive 3x3 convolutions, then the last pooling
            nn.Conv(features=384, kernel_size=(3, 3)), nn.relu,
            nn.Conv(features=384, kernel_size=(3, 3)), nn.relu,
            nn.Conv(features=256, kernel_size=(3, 3)), nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)),
            lambda x: x.reshape((x.shape[0], -1)),  # flatten
            # Classifier head: two 4096-unit layers with dropout, then the logits
            nn.Dense(features=4096),
            nn.relu,
            nn.Dropout(0.5, deterministic=not self.training),
            nn.Dense(features=4096),
            nn.relu,
            nn.Dropout(0.5, deterministic=not self.training),
            nn.Dense(features=self.num_classes)
        ])

Walk a single 224×224 grayscale image (shape (1, 224, 224, 1) in Flax's channels-last layout) through the network and print each block's output shape. The spatial size shrinks from 224×224 down to 5×5 while the channel count grows from 1 to 256. In the listing, custom_jvp is the type name reported for nn.relu, and function denotes the pooling and flatten lambdas:
Conv output shape: (1, 54, 54, 96)
custom_jvp output shape: (1, 54, 54, 96)
function output shape: (1, 26, 26, 96)
Conv output shape: (1, 26, 26, 256)
custom_jvp output shape: (1, 26, 26, 256)
function output shape: (1, 12, 12, 256)
...
custom_jvp output shape: (1, 4096)
Dropout output shape: (1, 4096)
Dense output shape: (1, 4096)
custom_jvp output shape: (1, 4096)
Dropout output shape: (1, 4096)
Dense output shape: (1, 10)
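The spatial sizes in this listing can be checked by hand. The sketch below applies the usual output-size formula, floor((n + 2p - k) / s) + 1, to each convolution and pooling layer. It assumes the Flax defaults for the layers defined above: `nn.Conv` pads with "SAME" unless told otherwise (so stride-1 convolutions keep the spatial size), and `nn.max_pool` uses no padding. The helper name `out_size` and the `layers` table are illustrative, not part of d2l or Flax.

```python
def out_size(n, k, s=1, p=0):
    """Output length along one axis for a window of size k, stride s, padding p."""
    return (n + 2 * p - k) // s + 1

# (description, window arguments); None marks a stride-1 "SAME" convolution,
# which leaves the spatial size unchanged (assumed Flax default padding).
layers = [
    ("conv 11x11, stride 4, padding 1", dict(k=11, s=4, p=1)),
    ("max-pool 3x3, stride 2", dict(k=3, s=2)),
    ("conv 5x5, SAME", None),
    ("max-pool 3x3, stride 2", dict(k=3, s=2)),
    ("conv 3x3, SAME", None),
    ("conv 3x3, SAME", None),
    ("conv 3x3, SAME", None),
    ("max-pool 3x3, stride 2", dict(k=3, s=2)),
]

size = 224
sizes = []
for name, args in layers:
    if args is not None:
        size = out_size(size, **args)
    sizes.append(size)
    print(f"{name:32s} -> {size}x{size}")

# The last convolution has 256 channels, so this is the flattened feature length.
flattened = size * size * 256
print("flattened features:", flattened)
```

With these settings the final feature map is 5×5×256, so the first dense layer receives 6400 features. The 6×6×256 map often quoted for AlexNet comes from the original paper's slightly different input size and padding.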
For demonstration, upsample the 28×28 Fashion-MNIST images to the 224×224 input that AlexNet expects, then train with a learning rate of 0.01 instead of the class default of 0.1.
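The upsampling step can be sketched directly with `jax.image.resize`. This is a minimal illustration, not the data pipeline the text uses: the helper `upsample_batch` is a hypothetical name, the batch is random data standing in for Fashion-MNIST, and bilinear interpolation is an assumed choice.

```python
import jax

def upsample_batch(images, size=(224, 224)):
    """Resize a batch of NHWC images to `size` with bilinear interpolation."""
    n, _, _, c = images.shape
    return jax.image.resize(images, (n, size[0], size[1], c), method="bilinear")

# Stand-in batch: 4 random grayscale 28x28 images in channels-last layout.
batch = jax.random.uniform(jax.random.PRNGKey(0), (4, 28, 28, 1))
upsampled = upsample_batch(batch)
print(upsampled.shape)  # (4, 224, 224, 1)
```

Upsampling adds no information; it only makes the images large enough for the stride-4 first layer and the three pooling stages. It also multiplies the pixel count by 64, which is part of why training is slow.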
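Training is slow even on a GPU: AlexNet has roughly ten times the convolution channels of LeNet, tens of millions of parameters rather than tens of thousands, and operates on much larger images. The architecture's lasting contribution was to show that scaling up the network pays off when paired with enough data and compute.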
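To put a number on the size gap, the parameter counts can be tallied by hand. The sketch below counts weights and biases for the AlexNet variant defined above (one input channel, 10 classes, 5×5×256 features entering the first dense layer under Flax's default paddings). The LeNet figures assume the classic layout with 6 and 16 convolution channels and dense layers of 120, 84, and 10 units, which is not shown in this section. The helper names are illustrative.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# AlexNet as defined above: 1 input channel, 10 output classes.
alexnet = (
    conv_params(11, 1, 96)
    + conv_params(5, 96, 256)
    + conv_params(3, 256, 384)
    + conv_params(3, 384, 384)
    + conv_params(3, 384, 256)
    + dense_params(5 * 5 * 256, 4096)
    + dense_params(4096, 4096)
    + dense_params(4096, 10)
)

# Assumed classic LeNet layout (not taken from this section).
lenet = (
    conv_params(5, 1, 6)
    + conv_params(5, 6, 16)
    + dense_params(16 * 5 * 5, 120)
    + dense_params(120, 84)
    + dense_params(84, 10)
)

print(f"AlexNet: {alexnet:,} parameters")
print(f"LeNet:   {lenet:,} parameters")
print(f"ratio:   {alexnet / lenet:.0f}x")
```

Under these assumptions the two dense layers of width 4096 account for most of AlexNet's parameters, which is why the flattened feature size matters so much for memory.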