Softmax Regression Implementation from Scratch

The same recipe as linear regression, with two new pieces:

  1. Softmax turns logits into a probability distribution.
  2. Cross-entropy is the loss for distributions.

The code plugs into the same Module / Trainer scaffold as the linear regression chapter: the Classifier base class adds accuracy reporting, and everything else is inherited.

Sums along an axis

Quick reminder before defining softmax — sum along chosen axes:

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
from functools import partial
X = d2l.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
d2l.reduce_sum(X, 0, keepdims=True), d2l.reduce_sum(X, 1, keepdims=True)
(Array([[5., 7., 9.]], dtype=float32),
 Array([[ 6.],
        [15.]], dtype=float32))
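keepdims=True matters for what comes next. Softmax divides every entry by its row sum, and that only broadcasts correctly if the sum keeps a size-1 column. A small JAX check (the variable names are illustrative):

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

row_sums_kept = X.sum(axis=1, keepdims=True)  # shape (2, 1)
row_sums_flat = X.sum(axis=1)                 # shape (2,)

# (2, 3) / (2, 1): the size-1 column broadcasts across all 3 columns,
# so each row is divided by its own sum.
normalized = X / row_sums_kept
print(normalized.sum(axis=1))  # each row now sums to 1 (up to float rounding)

# (2, 3) / (2,): broadcasting aligns trailing axes, 3 vs 2, which is a shape error.
try:
    X / row_sums_flat
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True
print("dropping keepdims breaks the division:", broadcast_failed)
```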

Softmax

\mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(\mathbf{X}_{ij})}{\sum_k \exp(\mathbf{X}_{ik})}.

Three steps: exponentiate, sum across the class axis, divide.
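A worked example on a single three-class row (numbers rounded to three decimals) shows how exponentiation magnifies the gaps between logits:

```latex
\begin{aligned}
\mathbf{x} &= (1,\ 2,\ 3) \\
\exp(\mathbf{x}) &\approx (2.718,\ 7.389,\ 20.086), \qquad \textstyle\sum_k \exp(x_k) \approx 30.193 \\
\mathrm{softmax}(\mathbf{x}) &\approx \left(\tfrac{2.718}{30.193},\ \tfrac{7.389}{30.193},\ \tfrac{20.086}{30.193}\right) \approx (0.090,\ 0.245,\ 0.665)
\end{aligned}
```

The logits differ by only 1 from class to class, yet the largest one takes about two-thirds of the probability mass.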

def softmax(X):
    X_exp = d2l.exp(X)
    partition = d2l.reduce_sum(X_exp, 1, keepdims=True)
    return X_exp / partition  # The broadcasting mechanism is applied here

Result: every row is non-negative and sums to 1 — a valid probability distribution over classes:

X = jax.random.uniform(d2l.get_key(), (2, 5))
X_prob = softmax(X)
X_prob, d2l.reduce_sum(X_prob, 1)
(Array([[0.15581527, 0.19802004, 0.16004387, 0.31561127, 0.17050964],
        [0.2204476 , 0.26683843, 0.1833601 , 0.13127865, 0.19807518]],      dtype=float32),
 Array([1.        , 0.99999994], dtype=float32))
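This naive softmax can overflow: in float32, exp of a logit above roughly 88 is inf, and inf / inf is nan. Because \mathrm{softmax}(\mathbf{x} - c) = \mathrm{softmax}(\mathbf{x}) for any constant c (the factor \exp(-c) cancels between numerator and denominator), subtracting each row's max first is a standard fix. A sketch comparing the two, with jax.nn.softmax as a reference:

```python
import jax
import jax.numpy as jnp

def softmax_naive(X):
    # Same three steps as the softmax above: exp, row-sum, divide.
    X_exp = jnp.exp(X)
    return X_exp / X_exp.sum(axis=1, keepdims=True)

def softmax_stable(X):
    # Shift each row so its largest entry is 0. The result is unchanged
    # mathematically, but the biggest exponent is now exp(0) = 1, so no overflow.
    X_shifted = X - X.max(axis=1, keepdims=True)
    X_exp = jnp.exp(X_shifted)
    return X_exp / X_exp.sum(axis=1, keepdims=True)

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1000.0, 1001.0, 1002.0]])  # row 2 overflows exp in float32

print(softmax_naive(logits))           # row 2 is all nan (inf / inf)
print(softmax_stable(logits))          # both rows match: same gaps between logits
print(jax.nn.softmax(logits, axis=1))  # JAX's built-in, for comparison
```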

The model

Flatten each 28×28 image into a 784-vector and pass it through one linear layer that outputs 10 logits, one per class:

class SoftmaxRegressionScratch(d2l.Classifier):
    num_inputs: int
    num_outputs: int
    lr: float
    sigma: float = 0.01

    def setup(self):
        self.W = self.param('W', nn.initializers.normal(self.sigma),
                            (self.num_inputs, self.num_outputs))
        self.b = self.param('b', nn.initializers.zeros, self.num_outputs)

The forward pass = flatten → linear → softmax:

@d2l.add_to_class(SoftmaxRegressionScratch)
def forward(self, X):
    X = d2l.reshape(X, (-1, self.W.shape[0]))
    return softmax(d2l.matmul(X, self.W) + self.b)
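To check the shapes outside the Flax / d2l scaffolding, here is a minimal pure-JAX sketch of the same forward pass. The name forward_pure, the params dict layout, and the dummy batch of eight 1×28×28 images are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

def softmax(X):
    # Row-wise softmax: exp, row-sum, divide.
    X_exp = jnp.exp(X)
    return X_exp / X_exp.sum(axis=1, keepdims=True)

def forward_pure(params, X):
    # Hypothetical standalone counterpart of SoftmaxRegressionScratch.forward.
    W, b = params["W"], params["b"]
    X = X.reshape((-1, W.shape[0]))  # (batch, 1, 28, 28) -> (batch, 784)
    return softmax(X @ W + b)        # (batch, 10) class probabilities

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "W": 0.01 * jax.random.normal(key_w, (784, 10)),  # sigma = 0.01, as in the class
    "b": jnp.zeros((10,)),
}
X = jax.random.uniform(key_x, (8, 1, 28, 28))  # dummy batch of 8 images
probs = forward_pure(params, X)
print(probs.shape)  # (8, 10)
```

The model has 784 × 10 + 10 = 7850 parameters.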

Cross-entropy loss

For label y (an integer class), the loss on one example is just

\ell = -\log \hat{y}_{y}

that is, the negative log of the predicted probability of the correct class. For two examples over 3 classes, fancy indexing picks out the probability each example assigns to its true class:

y = d2l.tensor([0, 2])
y_hat = d2l.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
Array([0.1, 0.5], dtype=float32)
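The index form is the one-hot cross-entropy in disguise. Writing the label as a one-hot vector y, the general loss is \ell = -\sum_j y_j \log \hat{y}_j; every term except the true class has y_j = 0, which leaves -\log \hat{y}_y. A small JAX check on the same toy y and y_hat, using jax.nn.one_hot:

```python
import jax
import jax.numpy as jnp

y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

# Index form: -log of the probability assigned to the correct class.
loss_index = -jnp.log(y_hat[jnp.arange(len(y)), y])

# One-hot form: -sum_j y_j log y_hat_j, where all non-target terms vanish.
y_onehot = jax.nn.one_hot(y, num_classes=3)
loss_onehot = -(y_onehot * jnp.log(y_hat)).sum(axis=1)

print(loss_index)   # approx [2.3026, 0.6931]
print(loss_onehot)  # same values
```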

Implementing it

Fancy indexing pulls out y_hat[i, y[i]] for each example; take the negative log and average over the batch:

def cross_entropy(y_hat, y):
    # Tiny clip to keep log finite when softmax outputs underflow to 0.
    p = jnp.clip(y_hat[list(range(len(y_hat))), y], min=1e-12)
    return -d2l.reduce_mean(d2l.log(p))

cross_entropy(y_hat, y)
Array(1.4978662, dtype=float32)
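The clip keeps the log finite, but it also caps the loss: once the correct-class probability underflows to 0, the loss sticks at about -log(1e-12) ≈ 27.6 however wrong the logits are. A common alternative (a swap, not what this chapter's code does) is to compute the loss straight from the logits with a log-softmax, \log \mathrm{softmax}(\mathbf{o})_j = o_j - \log \sum_k \exp(o_k), which JAX provides as jax.nn.log_softmax. A minimal comparison on one deliberately bad prediction, assuming float32:

```python
import jax
import jax.numpy as jnp

# One example whose logits strongly favor class 1 while the true class is 0.
logits = jnp.array([[0.0, 200.0]], dtype=jnp.float32)
y = jnp.array([0])
rows = jnp.arange(len(y))

# Route 1 (this chapter's approach): probabilities, then clip before the log.
probs = jax.nn.softmax(logits, axis=1)   # exp(-200) underflows to 0.0 in float32
p = jnp.maximum(probs[rows, y], 1e-12)   # same effect as jnp.clip(..., min=1e-12)
loss_clipped = -jnp.log(p).mean()        # stuck near -log(1e-12) ≈ 27.63

# Route 2 (swapped-in alternative): stay in log space the whole time.
log_probs = jax.nn.log_softmax(logits, axis=1)  # o_j - logsumexp(o), no underflow here
loss_exact = -log_probs[rows, y].mean()         # 200.0, the true cross-entropy

print(float(loss_clipped), float(loss_exact))
```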
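
Register it as the model's loss method, jitted with self (argument 0) marked static:

@d2l.add_to_class(SoftmaxRegressionScratch)
@partial(jax.jit, static_argnums=(0))
def loss(self, params, X, y, state):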
    def cross_entropy(y_hat, y):
        # Tiny clip to keep log finite when softmax outputs underflow to 0.
        p = jnp.clip(y_hat[list(range(len(y_hat))), y], min=1e-12)
        return -d2l.reduce_mean(d2l.log(p))
    y_hat = state.apply_fn({'params': params}, *X)
    # The returned empty dictionary is a placeholder for auxiliary data,
    # which will be used later (e.g., for batch norm)
    return cross_entropy(y_hat, y), {}
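What gradient does this loss send back? Differentiating -\log \mathrm{softmax}(\mathbf{o})_y with respect to the logits o gives \mathrm{softmax}(\mathbf{o}) - \mathbf{e}_y, where e_y is the one-hot vector of the label; averaging over a batch of N examples divides by N, and the chain rule through o = XW + b turns this into X^\top(\hat{Y} - Y)/N for W. The sketch below checks the logit gradient against jax.grad on arbitrary toy logits. It computes the loss from logits via jax.nn.log_softmax rather than clipped probabilities, an assumption made so both sides agree to float precision:

```python
import jax
import jax.numpy as jnp

def ce_from_logits(logits, y):
    # Mean cross-entropy over the batch, computed from raw logits.
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -log_probs[jnp.arange(len(y)), y].mean()

logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, -1.0]])
y = jnp.array([0, 2])

# Autodiff gradient with respect to the logits (argument 0).
grad_autodiff = jax.grad(ce_from_logits)(logits, y)

# Closed form: (softmax(o) - one_hot(y)) / N.
grad_closed_form = (jax.nn.softmax(logits, axis=1)
                    - jax.nn.one_hot(y, logits.shape[1])) / len(y)

print(jnp.allclose(grad_autodiff, grad_closed_form, atol=1e-6))  # True
```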

Train

10 epochs on Fashion-MNIST. The base Classifier already handles the validation loop and accuracy reporting:

data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
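trainer.fit hides the loop. As a rough mental model only (this is not d2l's Trainer code), each epoch is minibatch SGD: compute the loss and its gradient on a batch, step every parameter against its gradient, repeat. The sketch below does this in plain JAX on synthetic data; the data generator, learning rate of 0.5, batch size of 128, and helper names are all assumptions for illustration:

```python
import jax
import jax.numpy as jnp

LR = 0.5          # assumed learning rate for this toy problem
BATCH_SIZE = 128  # assumed minibatch size

def softmax(O):
    # Row-wise softmax with the max-subtraction shift for stability.
    O = O - O.max(axis=1, keepdims=True)
    E = jnp.exp(O)
    return E / E.sum(axis=1, keepdims=True)

def loss_fn(params, X, y):
    # Mean cross-entropy of a softmax regression model (clipped, as in the text).
    y_hat = softmax(X @ params["W"] + params["b"])
    p = jnp.maximum(y_hat[jnp.arange(len(y)), y], 1e-12)
    return -jnp.log(p).mean()

@jax.jit
def sgd_step(params, X, y):
    # One minibatch SGD update: params <- params - LR * grad.
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

# Synthetic stand-in data (NOT Fashion-MNIST): 512 points with 20 features,
# labeled by a random linear "teacher" so that a linear model can fit them.
k_x, k_t, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(k_x, (512, 20))
y = jnp.argmax(X @ jax.random.normal(k_t, (20, 10)), axis=1)

params = {"W": 0.01 * jax.random.normal(k_w, (20, 10)), "b": jnp.zeros(10)}
initial_loss = float(loss_fn(params, X, y))  # close to log(10) ≈ 2.30
losses = []
for epoch in range(10):
    for start in range(0, len(X), BATCH_SIZE):
        stop = start + BATCH_SIZE
        params, _ = sgd_step(params, X[start:stop], y[start:stop])
    losses.append(float(loss_fn(params, X, y)))  # full-data loss after each epoch

print(initial_loss, losses[-1])
```

The loss starts near log 10, the value for uniform predictions over 10 classes, and falls as training proceeds.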

Predicting

Pull a fresh validation batch and look at predicted vs. true classes:

X, y = next(iter(data.val_dataloader()))
preds = d2l.argmax(model.apply({'params': trainer.state.params}, X), axis=1)
preds.shape
(256,)
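Accuracy, which the Classifier reports during training, is the fraction of examples whose highest-scoring class matches the label. A minimal sketch with a hypothetical accuracy helper, reusing the toy y_hat and y from the cross-entropy section (d2l's own Classifier accuracy method may differ in details):

```python
import jax.numpy as jnp

def accuracy(y_hat, y):
    # Hypothetical helper: fraction of rows whose argmax equals the label.
    preds = jnp.argmax(y_hat, axis=1)
    return (preds.astype(y.dtype) == y).mean()

y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(accuracy(y_hat, y))  # 0.5: both rows predict class 2, only the second is right
```

Because softmax preserves the ordering within each row, the argmax of the probabilities equals the argmax of the raw logits, so prediction alone could skip the softmax.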

Tile the misclassified images, each captioned with its true label (top line) and predicted label (bottom line):

wrong = d2l.astype(preds, y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
    data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)

Linear models like this one top out at roughly 83% accuracy on Fashion-MNIST: they get the easy classes right and stumble on ambiguous pairs such as shirt vs. pullover.
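To see which classes get confused, rather than eyeballing the tiles, count (true, predicted) pairs in a confusion matrix. A minimal sketch with a hypothetical confusion_matrix helper, built on JAX's scatter-add (.at[...].add), which accumulates repeated index pairs:

```python
import jax.numpy as jnp

def confusion_matrix(y_true, y_pred, num_classes):
    # Hypothetical helper: entry [i, j] counts examples of true class i predicted as j.
    cm = jnp.zeros((num_classes, num_classes), dtype=jnp.int32)
    return cm.at[y_true, y_pred].add(1)  # duplicate (i, j) pairs are summed

y_true = jnp.array([0, 0, 1, 2, 2, 2])
y_pred = jnp.array([0, 2, 1, 2, 0, 2])
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)
# [[1 0 1]
#  [0 1 0]
#  [1 0 2]]
```

On the validation batch, calling confusion_matrix(y, preds, 10) before the misclassified-only filter (and assuming integer labels) would put shirts predicted as pullovers in entry [6, 2], since Fashion-MNIST uses label 2 for pullover and 6 for shirt.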

Recap

  • Softmax = exp + row-sum normalization → probabilities.
  • Cross-entropy = -\log p_\text{correct}, the standard classification loss.
  • A 10-output linear layer + softmax + CE loss is the baseline classifier — anything fancier (MLPs, CNNs) just replaces the forward pass.