Same model, same data, now using the framework's built-in primitives:
- A built-in linear layer replaces our hand-written W and b.
- A built-in cross_entropy fuses softmax + log + NLL with a numerical-stability trick (the LogSumExp trick).
- Same Trainer, same convergence, much less code.

Imports, plus a one-line linear layer wrapped in our Classifier scaffold:
from d2l import mxnet as d2l
from mxnet import gluon, init, npx
from mxnet.gluon import nn
npx.set_np()  # switch MXNet to NumPy-compatible array semantics
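For intuition, here is a framework-free sketch (plain Python, no MXNet) of what a 10-output linear layer computes on one flattened 28×28 Fashion-MNIST image: o = xW + b, with W of shape 784×10 and b of length 10, for 784·10 + 10 = 7850 parameters. The helpers `init_params` and `linear` are hypothetical, and the N(0, 0.01²) weight initialization is an assumption, not necessarily the framework layer's default.

```python
import random

NUM_INPUTS = 28 * 28  # a flattened 28x28 Fashion-MNIST image
NUM_OUTPUTS = 10      # one logit per clothing class


def init_params(num_inputs, num_outputs, sigma=0.01, seed=0):
    """Hypothetical helper: W[i][j] ~ N(0, sigma^2), b = 0 (assumed init scheme)."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, sigma) for _ in range(num_outputs)]
         for _ in range(num_inputs)]
    b = [0.0] * num_outputs
    return W, b


def linear(x, W, b):
    """Hypothetical helper: o_j = sum_i x_i * W[i][j] + b_j, i.e. o = xW + b."""
    return [sum(x_i * row[j] for x_i, row in zip(x, W)) + b[j]
            for j in range(len(b))]


W, b = init_params(NUM_INPUTS, NUM_OUTPUTS)
x = [0.5] * NUM_INPUTS           # dummy image: every pixel set to 0.5
logits = linear(x, W, b)         # raw logits, no softmax applied
num_params = NUM_INPUTS * NUM_OUTPUTS + NUM_OUTPUTS
print(len(logits), num_params)   # 10 7850
```

In the framework version the layer allocates and initializes W and b itself, which is why the model definition shrinks to a single line.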
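Computing softmax, then log, then NLL as separate steps breaks down numerically when logits are large: exp(100) already overflows float32, whose largest finite value is about 3.4 × 10^38. It also fails when the true class's logit sits far below the largest logit, because its probability underflows to 0 and the log becomes -inf. The framework's cross_entropy (in Gluon, gluon.loss.SoftmaxCrossEntropyLoss) takes raw logits and computes the loss directly via the LogSumExp trick, which is the same math with stable arithmetic: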
$$\log \sum_j e^{o_j} = m + \log \sum_j e^{o_j - m}, \quad m = \max_j o_j.$$
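The identity can be checked with a plain-Python sketch (hypothetical function names, not the framework's code). Python floats are float64 (largest finite value about 1.8 × 10^308), so overflow starts around exp(710) rather than exp(89) as in float32; the example therefore uses logits near 1000.

```python
import math


def logsumexp_naive(o):
    # Direct formula log(sum_j exp(o_j)): exp() overflows for large o_j.
    return math.log(sum(math.exp(v) for v in o))


def logsumexp_stable(o):
    # Shift by m = max_j o_j so every exponent is <= 0 and exp() lies in (0, 1].
    m = max(o)
    return m + math.log(sum(math.exp(v - m) for v in o))


small = [1.0, 2.0, 3.0]
big = [1000.0, 999.0, 998.0]

print(logsumexp_naive(small), logsumexp_stable(small))  # agree up to rounding
print(logsumexp_stable(big))                            # about 1000.4076
try:
    logsumexp_naive(big)                                # exp(1000) is out of range
except OverflowError as err:
    print("naive version failed:", err)
```

The shift costs one extra pass to find the maximum, and in exchange the largest term is exactly e^0 = 1, so the sum is at least 1 and its log is finite.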
The model therefore outputs raw logits with no explicit softmax; the loss applies both the softmax and the log itself.
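Why the loss can start from logits: writing the softmax output as $\hat{y}_j = e^{o_j} / \sum_k e^{o_k}$ (the from-scratch notation is assumed here), the negative log-likelihood of the true class $y$ simplifies until only logits remain, and the LogSumExp shift applies directly:

$$-\log \hat{y}_y = -\log \frac{e^{o_y}}{\sum_j e^{o_j}} = \log \sum_j e^{o_j} - o_y = m + \log \sum_j e^{o_j - m} - o_y, \quad m = \max_j o_j.$$

No probability is ever formed, so there is no exp to overflow and no log of an underflowed 0.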
Same Fashion-MNIST data, same 10 epochs, same Trainer:
Essentially the same accuracy curve as the from-scratch version. Built-in loss = cleaner code + better numerics.
cross_entropy(logits, y) ≡ softmax → log → NLL with the LogSumExp stability trick baked in.
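As a check on this equivalence, here is a plain-Python sketch (hypothetical function names; not the framework's implementation) comparing the fused form with the three separate steps. On moderate logits they agree. On extreme logits the unfused pipeline hits log(0), while the fused form still returns the exact loss.

```python
import math


def logsumexp(o):
    # Stable LogSumExp: shift by the max so no exponent is positive.
    m = max(o)
    return m + math.log(sum(math.exp(v - m) for v in o))


def cross_entropy_fused(o, y):
    # -log softmax(o)[y] = logsumexp(o) - o[y]; no probabilities are formed.
    return logsumexp(o) - o[y]


def cross_entropy_unfused(o, y):
    # softmax -> log -> NLL as three separate steps, without the max shift.
    exps = [math.exp(v) for v in o]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -math.log(probs[y])


moderate = [2.0, 1.0, 0.1]
print(cross_entropy_fused(moderate, 0), cross_entropy_unfused(moderate, 0))
# both about 0.4170

extreme = [0.0, -800.0]                  # exp(-800) underflows to 0.0 in float64
print(cross_entropy_fused(extreme, 1))   # 800.0
try:
    cross_entropy_unfused(extreme, 1)    # probs[1] == 0.0, so log(0) fails
except ValueError as err:
    print("unfused version failed:", err)
```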