Softmax + cross-entropy, fused

Concise Implementation of Softmax Regression

Concise softmax regression

Same model, same data — using the framework’s built-in primitives:

  • One linear layer instead of hand-rolled W and b.
  • Built-in cross_entropy that fuses softmax + log + NLL and uses the LogSumExp trick for numerical stability.
  • Same Trainer, same convergence, much less code.

The model

Imports + a one-line linear layer wrapped in our Classifier scaffold:

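from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F


class SoftmaxRegression(d2l.Classifier):
    """The softmax regression model."""
    def __init__(self, num_outputs, lr):
        super().__init__()
        self.save_hyperparameters()  # stores num_outputs and lr on self
        # Flatten (batch, 1, 28, 28) images to (batch, 784); LazyLinear
        # infers its input size from the first batch it sees.
        self.net = nn.Sequential(nn.Flatten(),
                                 nn.LazyLinear(num_outputs))

    def forward(self, X):
        return self.net(X)  # raw logits, no softmax applied here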
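In the model above, nn.Flatten turns each image into one feature vector and nn.LazyLinear computes o = Wx + b without being told the input size up front. The pure-Python sketch below mimics that behaviour for a single image. LazyLinearSketch and flatten are hypothetical names; the Gaussian initialization with sigma=0.01 is an assumption, not PyTorch's default scheme; and unlike nn.Flatten, this flatten handles one image rather than a batch.

```python
import random


class LazyLinearSketch:
    """Hypothetical stand-in for a lazily initialized linear layer.

    The number of inputs is unknown at construction time; W and b are
    created on the first call, once the flattened input length is known.
    (Assumed init: N(0, sigma^2) weights, zero bias.)
    """

    def __init__(self, num_outputs, sigma=0.01, seed=0):
        self.num_outputs = num_outputs
        self.sigma = sigma
        self.rng = random.Random(seed)
        self.W = None  # becomes a num_outputs x num_inputs list of lists
        self.b = None

    def __call__(self, x):
        if self.W is None:  # deferred initialization on first use
            num_inputs = len(x)
            self.W = [[self.rng.gauss(0.0, self.sigma) for _ in range(num_inputs)]
                      for _ in range(self.num_outputs)]
            self.b = [0.0] * self.num_outputs
        # o_k = sum_i W[k][i] * x[i] + b[k]  (one logit per class)
        return [sum(w * xi for w, xi in zip(row, x)) + bk
                for row, bk in zip(self.W, self.b)]


def flatten(image):
    """Flatten a 2-D list of pixel rows into a single feature vector."""
    return [pixel for row in image for pixel in row]


image = [[0.5] * 28 for _ in range(28)]  # a fake 28x28 grayscale image
layer = LazyLinearSketch(num_outputs=10)
logits = layer(flatten(image))
print(len(logits), len(layer.W[0]))  # 10 784
```

The second call reuses the weights created by the first, which is why the input size only has to be discovered once.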

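Computing softmax, then log, then NLL as separate steps is numerically fragile. Large logits overflow: exp(100) is about 2.7 × 10^43, well past the float32 maximum of about 3.4 × 10^38. When the true class's logit is far below the largest one, its softmax probability underflows to 0 and log(0) = -inf. The framework's cross_entropy instead takes raw logits and uses the identity -log softmax(o)_y = log Σ_j exp(o_j) - o_y, so only the log-sum-exp needs care, and it is computed with the LogSumExp trick (equivalent math, stable arithmetic):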

\log \sum_j e^{o_j} = m + \log \sum_j e^{o_j-m}, \quad m=\max_j o_j.
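A quick numeric check of the identity above, as a minimal sketch in plain Python (the function names are illustrative). Python floats are float64, whose exp() overflows only past about exp(709.8) rather than float32's exp(88.7), so the example uses logits of 1000 to trigger the failure.

```python
import math


def logsumexp_naive(logits):
    # Direct evaluation: math.exp raises OverflowError for large logits.
    return math.log(sum(math.exp(o) for o in logits))


def logsumexp_stable(logits):
    # Shift by m = max_j o_j so the largest term is exp(0) = 1.
    m = max(logits)
    return m + math.log(sum(math.exp(o - m) for o in logits))


small = [1.0, 2.0, 3.0]
agree = abs(logsumexp_naive(small) - logsumexp_stable(small)) < 1e-12
print(agree)  # True: same math when nothing overflows

big = [1000.0, 1000.0]
stable_big = logsumexp_stable(big)  # equals 1000 + log 2
naive_failed = False
try:
    logsumexp_naive(big)
except OverflowError:
    naive_failed = True
print(stable_big, naive_failed)
```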

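@d2l.add_to_class(d2l.Classifier)
def loss(self, Y_hat, Y, averaged=True):
    # Collapse leading dims: logits -> (N, num_classes), labels -> (N,)
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    # Raw logits go straight in; softmax, log and NLL happen inside.
    return F.cross_entropy(
        Y_hat, Y, reduction='mean' if averaged else 'none')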
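To make the fused loss concrete, here is a pure-Python sketch of what cross-entropy on raw logits computes per example, logsumexp(o) - o[y], together with the two reduction modes used above. It is an illustration under the assumption of class-index labels, not PyTorch's implementation; cross_entropy here is a hypothetical helper that expects already-flattened inputs.

```python
import math


def logsumexp(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(o - m) for o in logits))


def cross_entropy(logits_batch, labels, averaged=True):
    """Hypothetical mirror of a fused cross-entropy on raw logits.

    Per example: loss = logsumexp(o) - o[y], which equals
    -log softmax(o)[y] without ever forming the probabilities.
    """
    losses = [logsumexp(o) - o[y] for o, y in zip(logits_batch, labels)]
    if averaged:                   # like reduction='mean'
        return sum(losses) / len(losses)
    return losses                  # like reduction='none'


logits = [[2.0, 0.5, -1.0],
          [0.0, 0.0, 0.0],
          [1000.0, 0.0, -1000.0]]  # extreme logits are handled fine
labels = [0, 2, 0]
per_example = cross_entropy(logits, labels, averaged=False)
mean_loss = cross_entropy(logits, labels)
print(per_example, mean_loss)
```

The third row would break a naive softmax in float32, yet its loss comes out as exactly 0.0 here because the correct class already has the largest logit by a wide margin.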

The forward pass therefore returns raw logits with no explicit softmax; cross_entropy applies the softmax and the log itself.
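Dropping the softmax from the forward pass costs nothing at prediction time: softmax preserves the ordering of the logits, so argmax over logits picks the same class as argmax over probabilities. A small pure-Python check, where softmax, argmax and accuracy are hypothetical helpers and the logits and labels are made up:

```python
import math


def softmax(logits):
    m = max(logits)  # shift for stability; does not change the result
    exps = [math.exp(o - m) for o in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits_batch, labels):
    """Fraction of examples whose largest logit is the true class."""
    hits = sum(argmax(o) == y for o, y in zip(logits_batch, labels))
    return hits / len(labels)


logits_batch = [[2.0, 0.5, -1.0], [0.1, 3.0, 0.2], [0.3, 0.2, 0.1]]
labels = [0, 1, 2]

# exp is increasing and dividing by a positive sum keeps the order
same = all(argmax(o) == argmax(softmax(o)) for o in logits_batch)
print(same)                            # True
print(accuracy(logits_batch, labels))  # 0.6666666666666666
```

Probabilities are only needed when you want calibrated scores; then softmax can be applied to the logits after the fact.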

Train

Same Fashion-MNIST data, same 10 epochs, same Trainer:

data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegression(num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
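trainer.fit hides the optimization loop. Roughly, each step runs the forward pass, evaluates the loss, and moves the parameters against the gradient. The sketch below writes that out in pure Python, assuming plain gradient descent with the lr=0.1 passed to the model. It uses the fact that the gradient of cross-entropy with respect to the logits is softmax(o) minus the one-hot label. The toy data, full-batch updates and hand-written gradients are illustrative assumptions; this is not the d2l Trainer's code, which uses minibatches and autograd.

```python
import math


def softmax(o):
    m = max(o)
    e = [math.exp(v - m) for v in o]
    s = sum(e)
    return [v / s for v in e]


def forward(W, b, x):
    # Logits o = W x + b for a single example.
    return [sum(wi * xi for wi, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def mean_loss(W, b, X, Y):
    total = 0.0
    for x, y in zip(X, Y):
        o = forward(W, b, x)
        m = max(o)
        total += m + math.log(sum(math.exp(v - m) for v in o)) - o[y]
    return total / len(X)


def gd_step(W, b, X, Y, lr):
    """One full-batch gradient step; dLoss/dlogits = softmax(o) - onehot(y)."""
    K, d, n = len(W), len(W[0]), len(X)
    gW = [[0.0] * d for _ in range(K)]
    gb = [0.0] * K
    for x, y in zip(X, Y):
        p = softmax(forward(W, b, x))
        for k in range(K):
            delta = p[k] - (1.0 if k == y else 0.0)
            gb[k] += delta / n
            for i in range(d):
                gW[k][i] += delta * x[i] / n
    for k in range(K):
        b[k] -= lr * gb[k]
        for i in range(d):
            W[k][i] -= lr * gW[k][i]


# Hypothetical toy data: 2 features, 3 classes (not Fashion-MNIST).
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [-1.0, -1.0], [-0.9, -1.1]]
Y = [0, 0, 1, 1, 2, 2]
W = [[0.0, 0.0] for _ in range(3)]
b = [0.0, 0.0, 0.0]

losses = [mean_loss(W, b, X, Y)]  # all-zero weights give log(3)
for epoch in range(10):
    gd_step(W, b, X, Y, lr=0.1)
    losses.append(mean_loss(W, b, X, Y))
print(losses[0], losses[-1])
```

With zero-initialized weights every class is equally likely, so the starting loss is log 3; with a step size this small relative to the inputs, the loss falls at every step.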

Accuracy should closely track the from-scratch version (exact curves vary with random initialization and data shuffling). Built-in loss = cleaner code + better numerics.

Recap

  • From-scratch taught us softmax and cross-entropy; concise is what we actually use.
  • Built-in cross_entropy(logits, y) = softmax → log → NLL, with the LogSumExp stability trick baked in.
  • The forward pass should output logits, not softmax probabilities — the loss does the rest.