```python
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
from types import FunctionType
```

LeNet-5 (Yann LeCun et al., 1989; productionized 1998) was one of the first convolutional neural networks deployed at production scale, reading handwritten digits on U.S. bank checks. Some ATMs reportedly still run code descended from those 1990s systems.
It defined the architectural template that every later CNN refines: a convolutional encoder (spatial dimensions shrink, channels grow) feeding a dense head. ResNet and EfficientNet keep this skeleton with different components, and even ViT, which is not a CNN, retains the encoder-plus-head split.
LeNet-5 data flow on a 28×28 handwritten digit. Spatial dims shrink; channels grow.
Two conv→sigmoid→avgpool blocks, three FC layers, 10 logits.
Same network, vertical schematic (the textbook version):
Compact LeNet-5 schematic.
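The shrinking spatial dimensions follow from the usual output-size rule for convolution and pooling: out = (n + 2p - k) // s + 1. The minimal pure-Python sketch below traces a 28×28 input through LeNet-5. It assumes padding 2 for the first 5×5 conv (which is what `padding='SAME'` amounts to here) and no padding elsewhere. The helper name `out_size` is hypothetical.

```python
def out_size(n: int, k: int, p: int = 0, s: int = 1) -> int:
    """Spatial output size of a conv or pooling window (hypothetical helper)."""
    return (n + 2 * p - k) // s + 1

# (label, kernel, padding, stride, output channels) for each spatial stage
stages = [
    ("conv1 5x5, pad 2", 5, 2, 1, 6),
    ("avgpool 2x2", 2, 0, 2, 6),
    ("conv2 5x5, no pad", 5, 0, 1, 16),
    ("avgpool 2x2", 2, 0, 2, 16),
]

n = 28  # input height = width
shapes = []
for label, k, p, s, c in stages:
    n = out_size(n, k, p, s)
    shapes.append((n, n, c))
    print(f"{label:18s} -> {n}x{n}x{c}")

h, w, c = shapes[-1]
flat = h * w * c  # features handed to the dense head
print("flatten ->", flat)
```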
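Flattening the last pooling output gives 16 × 5 × 5 = 400 features. The connection from the conv block to the first dense layer therefore needs 400 × 120 = 48,000 weights, about 78% of the network's parameters. Modern CNNs replace this dense stack with global average pooling, which is much cheaper.

The model is an almost mechanical translation of the figure into an `nn.Sequential`. Xavier initialization keeps the sigmoid layers from saturating early in training: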
```python
class LeNet(d2l.Classifier):
    """The LeNet-5 model."""
    lr: float = 0.1
    num_classes: int = 10
    kernel_init: FunctionType = nn.initializers.xavier_uniform

    def setup(self):
        self.net = nn.Sequential([
            # Block 1: 5x5 conv, 6 channels; SAME padding keeps 28x28
            nn.Conv(features=6, kernel_size=(5, 5), padding='SAME',
                    kernel_init=self.kernel_init()),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            # Block 2: 5x5 conv, 16 channels; VALID padding shrinks 14x14 -> 10x10
            nn.Conv(features=16, kernel_size=(5, 5), padding='VALID',
                    kernel_init=self.kernel_init()),
            nn.sigmoid,
            lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
            lambda x: x.reshape((x.shape[0], -1)),  # flatten to (batch, 400)
            # Dense head: 400 -> 120 -> 84 -> num_classes logits
            nn.Dense(features=120, kernel_init=self.kernel_init()),
            nn.sigmoid,
            nn.Dense(features=84, kernel_init=self.kernel_init()),
            nn.sigmoid,
            nn.Dense(features=self.num_classes,
                     kernel_init=self.kernel_init())
        ])
```

A useful debugging tool is to pass a dummy input through the layers one at a time and print the shape after each. Flax uses NHWC layout (batch, height, width, channels), so a single 28×28 grayscale image has shape (1, 28, 28, 1). Compare the printed shapes against the figure to verify that the architecture is wired correctly:
```python
@d2l.add_to_class(d2l.Classifier)
def layer_summary(self, X_shape, key=d2l.get_key()):
    X = jnp.zeros(X_shape)
    # Initialize parameters, then bind them so submodules can be called directly
    params = self.init(key, X)
    bound_model = self.clone().bind(params, mutable=['batch_stats'])
    _ = bound_model(X)  # one full forward pass on the bound model
    # Apply each layer in turn and report its output shape
    for layer in bound_model.net.layers:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)

model = LeNet()
model.layer_summary((1, 28, 28, 1))  # NHWC: batch, height, width, channels
```

```
Conv output shape: (1, 28, 28, 6)
PjitFunction output shape: (1, 28, 28, 6)
function output shape: (1, 14, 14, 6)
Conv output shape: (1, 10, 10, 16)
PjitFunction output shape: (1, 10, 10, 16)
function output shape: (1, 5, 5, 16)
function output shape: (1, 400)
Dense output shape: (1, 120)
PjitFunction output shape: (1, 120)
Dense output shape: (1, 84)
PjitFunction output shape: (1, 84)
Dense output shape: (1, 10)
```
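This confirms the pyramid in the diagram. Spatially the path is 28 → 28 → 14 → 10 → 5, followed by a flatten to 400 (16 × 5 × 5) → 120 → 84 → 10. In the printout, `PjitFunction` is the jit-compiled `nn.sigmoid`, and the plain `function` entries are the pooling and flatten lambdas.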
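With the shapes fixed, the parameter count follows directly:

* A k×k conv with c_in input and c_out output channels has k·k·c_in·c_out + c_out parameters.
* A dense layer has n_in·n_out + n_out parameters.

The sketch below tallies both for LeNet-5. It then compares the dense head with a hypothetical global-average-pooling head that maps the 16 channels straight to 10 logits. Helper names are illustrative.

```python
def conv_params(k: int, c_in: int, c_out: int) -> int:
    # one k x k kernel per (input, output) channel pair, plus one bias per output channel
    return k * k * c_in * c_out + c_out

def dense_params(n_in: int, n_out: int) -> int:
    # full weight matrix plus one bias per output unit
    return n_in * n_out + n_out

layers = {
    "conv1": conv_params(5, 1, 6),
    "conv2": conv_params(5, 6, 16),
    "dense1": dense_params(16 * 5 * 5, 120),
    "dense2": dense_params(120, 84),
    "dense3": dense_params(84, 10),
}
total = sum(layers.values())
for name, count in layers.items():
    print(f"{name:7s} {count:6d}")
print("total  ", total)

# Hypothetical GAP head: average each of the 16 channels over 5x5, then one linear layer
dense_head = layers["dense1"] + layers["dense2"] + layers["dense3"]
gap_head = dense_params(16, 10)
print("dense head:", dense_head, "GAP head:", gap_head)
```

The dense head holds about 96% of the parameters. The sketch only counts parameters; it does not show whether a GAP head would match LeNet's accuracy.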
Training uses cross-entropy loss and SGD for 10 epochs, through the same `Trainer` API as every previous chapter; only the model changes.
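Conceptually, each training iteration computes the softmax cross-entropy of the logits against integer labels and then takes one SGD step. The sketch below shows that loop in plain JAX on a hypothetical linear classifier over random data. It does not go through d2l's `Trainer`, whose internals are not shown here. `cross_entropy`, `sgd_step`, `LR` and the data shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # same learning rate as LeNet's default lr

def cross_entropy(params, X, y):
    # Mean softmax cross-entropy for integer class labels y
    logits = X @ params["W"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))

@jax.jit
def sgd_step(params, X, y):
    # One plain SGD update: params <- params - LR * grad
    loss, grads = jax.value_and_grad(cross_entropy)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (64, 8))        # 64 examples, 8 features
y = jax.random.randint(key_y, (64,), 0, 10)  # labels in [0, 10)
params = {"W": jnp.zeros((8, 10)), "b": jnp.zeros(10)}

losses = []
for step in range(10):
    params, loss = sgd_step(params, X, y)  # loss is measured before the update
    losses.append(float(loss))
print(losses[0], losses[-1])
```

With all-zero weights every class gets probability 1/10, so the first loss is log 10.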
On the same data, LeNet's convolutional inductive bias lets it clearly outperform the dense MLP from the previous chapter, even with 1990s components (sigmoid, average pooling).
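Part of that inductive bias is locality plus weight sharing. This pure-Python sketch compares the weights of the first conv layer with those of a hypothetical fully connected layer producing the same 28×28×6 output from the 28×28×1 input. It counts weights only and ignores biases.

```python
in_units = 28 * 28 * 1   # input pixels
out_units = 28 * 28 * 6  # conv1 output activations

# Conv: one 5x5 kernel per output channel, each output sees only a 5x5 patch,
# and the same kernel is reused at every spatial position
conv_weights = 5 * 5 * 1 * 6
# Dense: every output unit has its own weight to every input pixel
dense_weights = in_units * out_units

print("conv1 weights:", conv_weights)
print("equivalent dense weights:", dense_weights)
print("ratio:", dense_weights // conv_weights)
```

The count says nothing about accuracy by itself. It only shows how much the hypothesis space shrinks when weights are local and shared.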
LeNet’s 1998 architecture vs. modern best practice:
| LeNet (1998) | Modern (2020s) |
|---|---|
| sigmoid activation | ReLU / GELU |
| average pooling | max pool / strided conv |
| no normalization | BatchNorm / LayerNorm |
| Xavier init | He init |
| 5 layers, ~60k params | 50+ layers, millions of params |
| dense head | global average pool + 1 linear |
Each substitution is the subject of a section in the next chapter (Modern CNNs). The skeleton, a conv encoder plus a head, is unchanged.
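The init row of the table comes down to how the weight scale depends on fan-in and fan-out:

* Xavier/Glorot uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
* He normal targets variance 2 / fan_in. The factor 2 compensates for ReLU zeroing about half its inputs.

The minimal sketch below uses `jax.nn.initializers` (Flax's `nn.initializers.xavier_uniform` comes from there). The 400×120 shape is LeNet's first dense layer; the comparison is illustrative.

```python
import math

import jax
import jax.numpy as jnp

fan_in, fan_out = 400, 120  # LeNet's first dense layer: 400 -> 120
key_x, key_h = jax.random.split(jax.random.PRNGKey(0))

# Xavier/Glorot uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
xavier = jax.nn.initializers.xavier_uniform()(key_x, (fan_in, fan_out))
xavier_bound = math.sqrt(6 / (fan_in + fan_out))

# He normal: variance 2 / fan_in (JAX samples a rescaled truncated normal)
he = jax.nn.initializers.he_normal()(key_h, (fan_in, fan_out))
he_std = math.sqrt(2 / fan_in)

print("xavier bound", xavier_bound, "max |w|", float(jnp.abs(xavier).max()))
print("he target std", he_std, "sample std", float(he.std()))
```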