from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
from functools import partial

The same recipe as linear regression, with two new pieces: the softmax operation and the cross-entropy loss.
The model plugs into the same Module / Trainer scaffold from the regression chapter: Classifier adds accuracy reporting, and we inherit the rest.
A quick reminder before defining softmax: summing along chosen axes.
(Array([[5., 7., 9.]], dtype=float32),
 Array([[ 6.],
        [15.]], dtype=float32))
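A minimal sketch of the reduction behind the output above. The matrix `X` is a hypothetical 2×3 example, chosen because it gives the same sums; the original input cell is not shown in this text.

```python
import jax.numpy as jnp

# Hypothetical 2x3 example matrix (an assumption, not the original cell).
X = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=0 collapses the rows, giving one sum per column.
# axis=1 collapses the columns, giving one sum per row.
# keepdims=True keeps the reduced axis with length 1, so the result
# still broadcasts against X in a later division.
col_sums = X.sum(axis=0, keepdims=True)  # shape (1, 3)
row_sums = X.sum(axis=1, keepdims=True)  # shape (2, 1)
```

Keeping the reduced axis matters for softmax: a `(n, 1)` column of row sums divides an `(n, k)` matrix row by row without any reshaping.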
\mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(\mathbf{X}_{ij})}{\sum_k \exp(\mathbf{X}_{ik})}.
Three steps: exponentiate, sum across the class axis, divide.
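The three steps can be sketched as follows. This is a direct transcription of the formula, not necessarily the cell used in the original notebook, and the random input is only a stand-in.

```python
import jax
import jax.numpy as jnp

def softmax(X):
    """Row-wise softmax, written as the three steps in the text."""
    X_exp = jnp.exp(X)                            # 1. exponentiate
    partition = X_exp.sum(axis=1, keepdims=True)  # 2. sum over the class axis
    return X_exp / partition                      # 3. divide (broadcasts per row)

# Stand-in input: 2 examples, 5 classes (values are arbitrary).
X = jax.random.uniform(jax.random.PRNGKey(0), (2, 5))
X_prob = softmax(X)
row_totals = X_prob.sum(axis=1)
```

This naive form can overflow in `jnp.exp` for large inputs; subtracting each row's maximum before exponentiating is a common remedy and leaves the result unchanged.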
Result: every row is non-negative and sums to 1 — a valid probability distribution over classes:
(Array([[0.15581527, 0.19802004, 0.16004387, 0.31561127, 0.17050964],
        [0.2204476 , 0.26683843, 0.1833601 , 0.13127865, 0.19807518]], dtype=float32),
 Array([1.        , 0.99999994], dtype=float32))
Flatten each 28×28 image into a 784-dimensional vector and pass it through a single linear layer that outputs 10 logits, one per class:
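A rough sketch of that forward pass in plain JAX, assuming 28×28 grayscale inputs (the native Fashion-MNIST size). The names `init_params` and `forward`, the parameter dictionary, and the `sigma` initialization scale are illustrative assumptions, not the notebook's Flax module.

```python
import jax
import jax.numpy as jnp

num_inputs, num_outputs = 28 * 28, 10  # assumed: 28x28 pixels, 10 classes

def init_params(key, sigma=0.01):
    # Small Gaussian weights and zero biases (hypothetical initialization).
    W = sigma * jax.random.normal(key, (num_inputs, num_outputs))
    b = jnp.zeros(num_outputs)
    return {'W': W, 'b': b}

def forward(params, X):
    X = X.reshape((X.shape[0], -1))         # flatten each image to a vector
    logits = X @ params['W'] + params['b']  # one linear layer: 10 logits per image
    return jax.nn.softmax(logits, axis=1)   # logits -> class probabilities

params = init_params(jax.random.PRNGKey(0))
batch = jnp.ones((4, 28, 28, 1))            # hypothetical batch of 4 images
probs = forward(params, batch)
```

The whole model is the weight matrix plus the bias vector, 784 × 10 + 10 = 7,850 parameters under these assumptions.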
For label y (an integer class), the loss on one example is just
\ell = -\log \hat{y}_{y}
that is, the negative log of the predicted probability of the correct class. Here are two examples with 3 classes:
Array([0.1, 0.5], dtype=float32)
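The input cell for this output is not shown here. A sketch with hypothetical values chosen to give the same result:

```python
import jax.numpy as jnp

# Hypothetical labels and predicted probabilities (2 examples, 3 classes).
y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6],
                   [0.3, 0.2, 0.5]])

# Pair row index i with column index y[i]: this selects y_hat[0, 0]
# and y_hat[1, 2], the probability assigned to each true class.
picked = y_hat[jnp.arange(len(y_hat)), y]
```

Indexing with two integer arrays of equal length selects one element per pair, so no Python loop over examples is needed.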
The loss takes one line: fancy indexing pulls out y_hat[i, y[i]] for each example, and we then average the negative logs:
Array(1.4978662, dtype=float32)
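A sketch of that loss as a standalone function, reusing the same hypothetical two-example inputs. The result can be checked by hand: -(log 0.1 + log 0.5) / 2 ≈ 1.4979.

```python
import math
import jax.numpy as jnp

def cross_entropy(y_hat, y):
    # Probability of the true class per example, then mean negative log.
    picked = y_hat[jnp.arange(len(y_hat)), y]
    return -jnp.mean(jnp.log(picked))

# Hypothetical inputs (assumed, not taken from the original cell).
y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6],
                   [0.3, 0.2, 0.5]])

loss_value = cross_entropy(y_hat, y)
by_hand = -(math.log(0.1) + math.log(0.5)) / 2
```

The first example dominates the average: the model gave its true class only 0.1 probability, which costs about 2.30, against about 0.69 for the second.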
@partial(jax.jit, static_argnums=(0))
def loss(self, params, X, y, state):
    def cross_entropy(y_hat, y):
        # Tiny clip to keep log finite when softmax outputs underflow to 0.
        p = jnp.clip(y_hat[list(range(len(y_hat))), y], min=1e-12)
        return -d2l.reduce_mean(d2l.log(p))
    y_hat = state.apply_fn({'params': params}, *X)
    # The returned empty dictionary is a placeholder for auxiliary data,
    # which will be used later (e.g., for batch norm)
    return cross_entropy(y_hat, y), {}

We train for 10 epochs on Fashion-MNIST. The base Classifier already handles the validation loop and accuracy reporting:
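The Trainer hides the parameter update. As a rough illustration of what happens inside, here is a plain-JAX sketch of full-batch gradient descent on synthetic data. It is not the d2l Trainer and not Fashion-MNIST: the data, `loss_fn`, `sgd_step`, the learning rate, and the step count are all assumptions. It also swaps the clipped log for `jax.nn.log_softmax`, which computes log-probabilities directly from the logits.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, X, y):
    logits = X @ params['W'] + params['b']
    log_probs = jax.nn.log_softmax(logits, axis=1)
    # Mean negative log-probability of the true class.
    return -jnp.mean(log_probs[jnp.arange(len(y)), y])

@jax.jit
def sgd_step(params, X, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    # Move every parameter a small step against its gradient.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic stand-in data: 256 examples, 20 features, 3 classes.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (256, 20))
y = jnp.argmax(X @ jax.random.normal(key_w, (20, 3)), axis=1)

params = {'W': jnp.zeros((20, 3)), 'b': jnp.zeros(3)}
initial_loss = loss_fn(params, X, y)
for _ in range(100):
    params, loss = sgd_step(params, X, y, 0.1)
final_loss = loss_fn(params, X, y)
```

With all-zero parameters the model predicts a uniform distribution, so the starting loss is log 3 ≈ 1.10; the loss then falls as the weights fit the synthetic labels.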
Pull a fresh validation batch and look at predicted vs. true classes:
(256,)
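A sketch of how predictions and mistakes can be read off the model output. The probabilities and labels are a hypothetical 4-example batch; on a real batch of 256 images, `preds` would have the shape `(256,)` shown above.

```python
import jax.numpy as jnp

# Hypothetical class probabilities (4 examples, 3 classes) and true labels.
probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6],
                   [0.2, 0.5, 0.3],
                   [0.6, 0.3, 0.1]])
y = jnp.array([0, 2, 0, 1])

preds = jnp.argmax(probs, axis=1)   # most probable class per example
wrong = preds != y                  # boolean mask of misclassified examples
accuracy = jnp.mean(preds == y)     # fraction predicted correctly
wrong_idx = jnp.nonzero(wrong)[0]   # positions of the mistakes
```

The same boolean mask can index the image batch and the label arrays to keep only the misclassified examples for display.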
Tile the misclassified images, captioned with predicted / true:
Linear models top out at roughly 83% accuracy on Fashion-MNIST: they get the easy classes right and the ambiguous shirt-vs-pullover cases wrong.