from d2l import jax as d2l
from functools import partial
from jax import numpy as jnp
import jax
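import optax

Classifier class

A small Classifier base class that every classification model in the book inherits from. It plays the same role that d2l.Module plays for regression, but adds classification-specific defaults: a training_step and a validation_step that log the loss, plus an accuracy metric that the validation step reports.
Subclasses only need to supply forward (and a custom loss if plain cross-entropy is not what they want).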
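The division of labor (the base class owns the metrics, the subclass owns forward) can be sketched without d2l or Flax. This is a hypothetical miniature: `MiniClassifier`, `LinearClassifier` and the `params` layout are illustrative names, not the book's classes, and the default loss is assumed to be mean cross-entropy.

```python
import jax
import jax.numpy as jnp


class MiniClassifier:
    """Hypothetical base class: shared loss/metric, subclasses supply forward."""

    def forward(self, params, X):
        raise NotImplementedError

    def loss(self, params, X, y):
        # Assumed default loss: mean cross-entropy over the batch.
        log_p = jax.nn.log_softmax(self.forward(params, X), axis=-1)
        return -jnp.mean(jnp.take_along_axis(log_p, y[:, None], axis=1))

    def accuracy(self, params, X, y):
        # argmax over the class axis, compare with labels, average.
        preds = jnp.argmax(self.forward(params, X), axis=-1)
        return jnp.mean((preds == y).astype(jnp.float32))


class LinearClassifier(MiniClassifier):
    # The only thing the subclass writes: how scores are computed.
    def forward(self, params, X):
        return X @ params["W"] + params["b"]


model = LinearClassifier()
params = {"W": jnp.eye(2), "b": jnp.zeros(2)}
X = jnp.array([[2.0, 0.0], [0.0, 3.0], [1.0, 0.0]])
y = jnp.array([0, 1, 1])
acc = model.accuracy(params, X, y)   # predictions 0, 1, 0 -> 2 of 3 correct
```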
Classifiers usually produce a vector of scores $\mathbf{o}\in\mathbb{R}^q$, one per class. The training loss may turn scores into probabilities, but the deployed decision is often just

$$\hat{y}=\arg\max_j o_j.$$
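A quick check of that decision rule, with made-up scores for two examples over three classes: because softmax is monotonic, taking the argmax of the probabilities gives the same prediction as taking it of the raw scores.

```python
import jax
import jax.numpy as jnp

# Hypothetical scores for a batch of 2 examples over q = 3 classes.
o = jnp.array([[1.2, -0.3, 3.1],
               [0.5,  2.0, 0.1]])

probs = jax.nn.softmax(o, axis=-1)        # what a training loss might use
y_hat_scores = jnp.argmax(o, axis=-1)     # decision from raw scores
y_hat_probs = jnp.argmax(probs, axis=-1)  # decision from probabilities
print(y_hat_scores)  # [2 1]
```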
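Keep the roles separate: the training loss (for example, cross-entropy on the softmax probabilities) supplies gradients, while accuracy measures how often the argmax decision is right.
Accuracy is what many benchmarks report, but it provides no useful gradient: a tiny change to a score usually leaves the argmax unchanged, so accuracy stays flat almost everywhere and jumps only when the top class switches.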
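To see the difference, nudge one score slightly and compare the two quantities. The `accuracy` and `cross_entropy` helpers below are illustrative standalone functions, not the book's methods.

```python
import jax
import jax.numpy as jnp


def accuracy(o, y):
    # Fraction of rows whose argmax matches the label.
    return jnp.mean((jnp.argmax(o, axis=-1) == y).astype(jnp.float32))


def cross_entropy(o, y):
    # Mean negative log-probability of the true class.
    log_p = jax.nn.log_softmax(o, axis=-1)
    return -jnp.mean(log_p[jnp.arange(y.shape[0]), y])


o = jnp.array([[1.2, -0.3, 3.1],
               [0.5,  2.0, 0.1]])
y = jnp.array([2, 0])                 # first row correct, second row wrong
o_nudged = o.at[1, 0].add(0.01)       # raise the true-class score of row 2

acc_before, acc_after = accuracy(o, y), accuracy(o_nudged, y)  # both 0.5
ce_before, ce_after = cross_entropy(o, y), cross_entropy(o_nudged, y)
grad_o = jax.grad(cross_entropy)(o, y)  # nonzero: the loss "sees" the nudge
```

The argmax of row 2 is still class 1, so accuracy does not move, while cross-entropy drops slightly and has a nonzero gradient with respect to every score.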
class Classifier(d2l.Module):
    """The base class of classification models."""
    def training_step(self, params, batch, state):
        # Here value is a tuple since models with BatchNorm layers require
        # the loss to return auxiliary data
        value, grads = jax.value_and_grad(
            self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
        l, _ = value
        self.plot("loss", l, train=True)
        return value, grads

    def validation_step(self, params, batch, state):
        # Discard the second returned value. It is used for training models
        # with BatchNorm layers since loss also returns auxiliary data
        l, _ = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('loss', l, train=False)
        self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
                  train=False)

Module also provides a default configure_optimizers, so subclasses don't have to write their own.
The accuracy method takes the argmax along the class axis, compares it with the true labels element-wise, and averages. The result is the fraction of correctly classified examples in the batch:
@d2l.add_to_class(Classifier)
@partial(jax.jit, static_argnums=(0, 5))
def accuracy(self, params, X, Y, state, averaged=True):
    """Compute the fraction of correct predictions (per-example 0/1 if not averaged)."""
    Y_hat = state.apply_fn({'params': params,
                            'batch_stats': state.batch_stats},  # BatchNorm only
                           *X)
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    preds = d2l.astype(d2l.argmax(Y_hat, axis=1), Y.dtype)
    compare = d2l.astype(preds == d2l.reshape(Y, (-1,)), d2l.float32)
    return d2l.reduce_mean(compare) if averaged else compare

The validation step then reports both the loss (lower is better) and the accuracy (higher is better) every epoch.
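The argmax, compare, mean pipeline can be checked in isolation. A minimal sketch in plain jax.numpy, assuming the scores are already computed (so no `apply_fn` or `batch_stats`); this standalone `accuracy` is hypothetical, not the method above.

```python
import jax.numpy as jnp


def accuracy(Y_hat, Y, averaged=True):
    """Hypothetical standalone version of the accuracy pipeline."""
    Y_hat = Y_hat.reshape(-1, Y_hat.shape[-1])   # (examples, classes)
    preds = jnp.argmax(Y_hat, axis=1).astype(Y.dtype)
    compare = (preds == Y.reshape(-1)).astype(jnp.float32)
    return compare.mean() if averaged else compare


Y_hat = jnp.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1],
                   [0.3, 0.3, 0.4],
                   [0.2, 0.5, 0.3]])
Y = jnp.array([1, 0, 0, 1])
print(accuracy(Y_hat, Y))                  # 0.75
print(accuracy(Y_hat, Y, averaged=False))  # [1. 1. 0. 1.]
```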
Two models can have the same accuracy but different confidence. Cross-entropy still notices whether the correct class received probability 0.51 or 0.99.
Use both during training: cross-entropy as the objective to minimize, accuracy as the metric to monitor.
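The 0.51 versus 0.99 case, worked out for two hypothetical binary predictions whose true class is 0: both count as correct, but their cross-entropy differs by a factor of about 67.

```python
import jax.numpy as jnp

y = 0  # true class in both hypothetical cases
probs = {"A": jnp.array([0.51, 0.49]),   # barely right
         "B": jnp.array([0.99, 0.01])}   # confidently right

results = {}
for name, p in probs.items():
    correct = int(jnp.argmax(p) == y)    # accuracy contribution: 1 for both
    ce = float(-jnp.log(p[y]))           # cross-entropy: about 0.673 vs 0.010
    results[name] = (correct, ce)
```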
Classifier(d2l.Module) adds accuracy reporting to the base scaffold from the regression chapter. Accuracy itself is a three-step pipeline: argmax → == y → mean.