from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

The simplest multilayer perceptron: two affine layers with a ReLU between them, trained end-to-end on Fashion-MNIST (28×28 grayscale, 10 classes).
X (batch, 784)
│ Linear 784 → 256
│ ReLU
│ Linear 256 → 10
▼
logits (batch, 10)
We’ll build it twice — from scratch (manage the weights by hand) and concise (nn.Sequential) — to make concrete what the framework’s abstraction buys you.
For Fashion-MNIST (784 inputs → 10 outputs):
These are hyperparameters — not learned. We set them by hand, train, and see what works.
Two weight matrices, two bias vectors. Init: small Gaussian \mathcal{N}(0, \sigma^2) for weights, zero for biases.
\mathbf{W}^{(1)} \in \mathbb{R}^{784 \times 256},\quad \mathbf{b}^{(1)} \in \mathbb{R}^{256},
\mathbf{W}^{(2)} \in \mathbb{R}^{256 \times 10},\quad \mathbf{b}^{(2)} \in \mathbb{R}^{10}.
Total: 784 \cdot 256 + 256 + 256 \cdot 10 + 10 = 203\,530 parameters.
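The count above can be checked with a few lines of arithmetic. This is only a sanity-check sketch; the variable names mirror the fields of the model class and the sizes are the ones stated in the text.

```python
# Layer sizes taken from the text: 784 inputs, 256 hidden units, 10 outputs.
num_inputs, num_hiddens, num_outputs = 784, 256, 10

# Each affine layer contributes one weight matrix plus one bias vector.
layer1 = num_inputs * num_hiddens + num_hiddens    # W1 and b1
layer2 = num_hiddens * num_outputs + num_outputs   # W2 and b2
total = layer1 + layer2

print(layer1, layer2, total)
```

Almost all of the parameters sit in the first layer, because it is the one that touches the 784-dimensional input.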
class MLPScratch(d2l.Classifier):
    num_inputs: int
    num_outputs: int
    num_hiddens: int
    lr: float
    sigma: float = 0.01  # standard deviation of the Gaussian weight init

    def setup(self):
        # Hidden layer: weights ~ N(0, sigma^2), biases start at zero.
        self.W1 = self.param('W1', nn.initializers.normal(self.sigma),
                             (self.num_inputs, self.num_hiddens))
        self.b1 = self.param('b1', nn.initializers.zeros, self.num_hiddens)
        # Output layer: same initialization scheme.
        self.W2 = self.param('W2', nn.initializers.normal(self.sigma),
                             (self.num_hiddens, self.num_outputs))
        self.b2 = self.param('b2', nn.initializers.zeros, self.num_outputs)

First, our own ReLU, which is just max(X, 0) applied elementwise:
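A minimal sketch of such a ReLU, assuming `jnp.maximum` as the elementwise primitive (the function name `relu` is our choice here):

```python
import jax.numpy as jnp

def relu(X):
    # Elementwise max(X, 0): negative entries become 0, the rest pass through.
    return jnp.maximum(X, 0)

X = jnp.array([[-2.0, 0.0, 3.0]])
print(relu(X))
```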
Then the forward pass:
\mathbf{H} = \mathrm{ReLU}(\mathbf{X}\mathbf{W}^{(1)} + \mathbf{b}^{(1)}),\quad \mathbf{O} = \mathbf{H}\mathbf{W}^{(2)} + \mathbf{b}^{(2)}.
Image pixels are flattened to a 784-vector first — we’re ignoring spatial structure. (CNNs in the next chapter fix this.)
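The forward pass can be sketched in plain JAX, outside the class, to make the shapes explicit. The helper names `init_params` and `forward`, the parameter dictionary layout, and the (batch, 28, 28, 1) input shape are illustrative assumptions, not part of the model class above.

```python
import jax
import jax.numpy as jnp

def relu(X):
    return jnp.maximum(X, 0)

def init_params(key, num_inputs=784, num_hiddens=256, num_outputs=10, sigma=0.01):
    # Small Gaussian weights, zero biases, as described in the text.
    k1, k2 = jax.random.split(key)
    return {
        'W1': sigma * jax.random.normal(k1, (num_inputs, num_hiddens)),
        'b1': jnp.zeros(num_hiddens),
        'W2': sigma * jax.random.normal(k2, (num_hiddens, num_outputs)),
        'b2': jnp.zeros(num_outputs),
    }

def forward(params, X):
    X = X.reshape((X.shape[0], -1))             # flatten each image to 784 values
    H = relu(X @ params['W1'] + params['b1'])   # hidden layer, shape (batch, 256)
    return H @ params['W2'] + params['b2']      # logits, shape (batch, 10)

params = init_params(jax.random.PRNGKey(0))
X = jnp.ones((4, 28, 28, 1))   # a dummy batch of four images
logits = forward(params, X)
print(logits.shape)
```

The reshape is where the spatial structure is discarded: after it, neighboring pixels are just adjacent columns with no special relationship.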
Same Trainer, same Fashion-MNIST loaders, same cross-entropy loss as softmax regression. Only the model class changed:
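To show what the trainer does with the model on each minibatch, here is a hedged sketch of one cross-entropy SGD step in plain JAX. It is not the `d2l.Trainer` implementation; the names `cross_entropy`, `sgd_step`, and `LR`, the learning rate of 0.1, and the random stand-in data are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # assumed learning rate, chosen for illustration only

def forward(params, X):
    H = jnp.maximum(X @ params['W1'] + params['b1'], 0)
    return H @ params['W2'] + params['b2']

def cross_entropy(params, X, y):
    # Mean negative log-probability of the correct class.
    log_probs = jax.nn.log_softmax(forward(params, X), axis=-1)
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])

@jax.jit
def sgd_step(params, X, y):
    loss, grads = jax.value_and_grad(cross_entropy)(params, X, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss

k1, k2, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    'W1': 0.01 * jax.random.normal(k1, (784, 256)), 'b1': jnp.zeros(256),
    'W2': 0.01 * jax.random.normal(k2, (256, 10)), 'b2': jnp.zeros(10),
}
# Random stand-in data with Fashion-MNIST shapes (not the real dataset).
X = jax.random.normal(kx, (32, 784))
y = jax.random.randint(ky, (32,), 0, 10)

params, first_loss = sgd_step(params, X, y)
for _ in range(20):
    params, loss = sgd_step(params, X, y)
print(float(first_loss), float(loss))
```

Nothing in the step is specific to the MLP: swapping `forward` for a single affine map gives softmax regression again, which is why only the model class has to change.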
About 1–2 percentage points better than plain softmax regression on the same data. A nonlinearity earns its keep.
Stack the same architecture using the framework’s container. Lazy linear layers infer input shapes; ReLU is built in:
That’s the whole architecture: four entries in a Sequential (Flatten, Linear, ReLU, Linear), with zero hand-rolled parameter management.
Both versions produce the same model. The framework just removes the bookkeeping.
Identical convergence behavior. The built-in Linear and ReLU compute exactly what the from-scratch version computes; the concise form is just easier to read and harder to get wrong.
We have a working MLP — but the real questions are open:
Each is the topic of one of the next decks.
Sequential(Flatten, Linear, ReLU, Linear) — same model, less bookkeeping.