Same model, same data, using the framework’s built-in primitives:
- A built-in linear layer that manages W and b.
- A built-in cross_entropy that fuses softmax + log + NLL with numerical-stability tricks (the LogSumExp trick).
- Same Trainer, same convergence, much less code.

Imports + a one-line linear layer wrapped in our Classifier scaffold:

from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import optax
Computing softmax then log then NLL separately blows up numerically when logits are large (exp(100) overflows in common float32 arithmetic). The framework’s cross_entropy takes raw logits and computes the loss directly via the LogSumExp trick — equivalent math, stable arithmetic:
\log \sum_j e^{o_j} = m + \log \sum_j e^{o_j-m}, \quad m=\max_j o_j.
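
The identity above can be checked numerically. The following is a small sketch with made-up logits (the values are an assumption chosen to trigger float32 overflow). It compares the naive form, the shifted form, and JAX's own `logsumexp`.

```python
from jax import numpy as jnp
from jax.scipy.special import logsumexp

o = jnp.array([100.0, 101.0, 102.0], dtype=jnp.float32)  # large logits

# Naive form: exp(100) already exceeds the float32 range, so the sum is inf.
naive = jnp.log(jnp.sum(jnp.exp(o)))

# Stable form: subtract the max m before exponentiating, add it back after.
m = jnp.max(o)
stable = m + jnp.log(jnp.sum(jnp.exp(o - m)))

# Library routine, for comparison with the hand-written shift.
reference = logsumexp(o)

print(naive, stable, reference)
```

After the shift the largest exponent is exp(0) = 1, so the sum cannot overflow no matter how large the logits are.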
@d2l.add_to_class(d2l.Classifier)
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
    # To be used later (e.g., for batch norm)
    Y_hat = state.apply_fn({'params': params}, *X,
                           mutable=False, rngs=None)
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    fn = optax.softmax_cross_entropy_with_integer_labels
    # The returned empty dictionary is a placeholder for auxiliary data,
    # which will be used later (e.g., for batch norm)
    return (fn(Y_hat, Y).mean(), {}) if averaged else (fn(Y_hat, Y), {})

The model output skips the explicit softmax; the loss handles both pieces.
Same Fashion-MNIST data, same 10 epochs, same Trainer:
The accuracy curve matches the from-scratch version. Built-in loss = cleaner code + better numerics.
cross_entropy(logits, y) ≡ softmax → log → NLL with the LogSumExp stability trick baked in.