Softmax + cross-entropy, fused

Concise Implementation of Softmax Regression

Same model, same data — using the framework’s built-in primitives:

  • One linear layer instead of hand-rolled W and b.
  • A built-in cross-entropy loss (here optax.softmax_cross_entropy_with_integer_labels) that fuses softmax + log + NLL and stays numerically stable via the LogSumExp trick.
  • Same Trainer, same convergence, much less code.

The model

Imports + a one-line linear layer wrapped in our Classifier scaffold:

from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import optax


class SoftmaxRegression(d2l.Classifier):
    num_outputs: int
    lr: float

    @nn.compact
    def __call__(self, X):
        X = X.reshape((X.shape[0], -1))  # Flatten
        X = nn.Dense(self.num_outputs)(X)
        return X

Computing softmax, then log, then NLL as separate steps breaks down numerically when logits are large: exp(100) ≈ 2.7e43 already exceeds the float32 maximum of about 3.4e38 and overflows to inf. The framework’s cross-entropy takes raw logits and computes the loss directly via the LogSumExp trick, which is the same math with stable arithmetic:

\log \sum_j e^{o_j} = m + \log \sum_j e^{o_j-m}, \quad m=\max_j o_j.
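
The identity can be checked numerically. The following is a minimal sketch in JAX using float32 arrays; the helper names naive_logsumexp and stable_logsumexp are made up for illustration and are not part of d2l or optax.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # JAX's own stable implementation


def naive_logsumexp(o):
    # Direct evaluation: exp overflows to inf for large float32 logits
    return jnp.log(jnp.sum(jnp.exp(o)))


def stable_logsumexp(o):
    # Subtract the max first, so the largest term is exp(0) = 1
    m = jnp.max(o)
    return m + jnp.log(jnp.sum(jnp.exp(o - m)))


small = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
large = jnp.array([100.0, 101.0, 102.0], dtype=jnp.float32)

naive_small = naive_logsumexp(small)
stable_small = stable_logsumexp(small)    # agrees with the naive result
naive_large = naive_logsumexp(large)      # inf, because exp(100) overflows
stable_large = stable_logsumexp(large)    # finite
reference_large = logsumexp(large)        # library result for comparison
print(naive_small, stable_small, naive_large, stable_large, reference_large)
```

Shifting every logit by the same constant shifts the result by that constant, which is why the large case comes out exactly 99 above the small one.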

@d2l.add_to_class(d2l.Classifier)
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
    # mutable and rngs are not needed yet; they come into play later
    # (e.g., for batch norm)
    Y_hat = state.apply_fn({'params': params}, *X,
                           mutable=False, rngs=None)
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    fn = optax.softmax_cross_entropy_with_integer_labels
    # The returned empty dictionary is a placeholder for auxiliary data,
    # which will be used later (e.g., for batch norm)
    return (fn(Y_hat, Y).mean(), {}) if averaged else (fn(Y_hat, Y), {})

The model output skips the explicit softmax — the loss handles both pieces.

Train

Same Fashion-MNIST data, same 10 epochs, same Trainer:

data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegression(num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)

The accuracy curve should closely match the from-scratch version; small differences can come from random initialization. The built-in loss gives cleaner code and better numerics.

Recap

  • From-scratch taught us softmax and cross-entropy; concise is what we actually use.
  • Built-in cross_entropy(logits, y) = softmax → log → NLL, with the LogSumExp stability trick baked in.
  • The forward pass should output logits, not softmax probabilities — the loss does the rest.