def add_to_class(Class):
    """Register functions as methods in created class.

    Returns a decorator that attaches the decorated object to ``Class``
    under the object's own ``__name__``, so a class can be declared first
    and have methods added to it afterwards.

    Args:
        Class: The class that should receive the new attribute.

    Returns:
        A decorator taking one object. It binds that object on ``Class`` and
        hands the same object back unchanged, so the decorated name stays
        usable at module level as well.
    """
    def wrapper(obj):
        # Instances that already exist see the new attribute too, because
        # attribute lookup goes through the class at call time.
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper


# add_to_class trick
# Three recurring abstractions appear in every model we’ll build:
- Module — the model: parameters, forward, loss, optimizer.
- DataModule — the data: train and val loaders.
- Trainer — the loop that fits a Module to a DataModule.

This chapter builds the scaffolding once. The rest of the book just subclasses these three.
Long class definitions don’t fit one slide / one cell. Define the class shell first, then attach methods incrementally:
Class attribute "b" is 1
The decorator simply attaches the decorated function to Class with setattr — a class’s namespace stays mutable after the class has been created.
HyperParameters

Almost every class wants self.lr = lr, self.batch_size = … boilerplate in __init__. The HyperParameters mixin auto-saves constructor args as attributes:
self.a = 1 self.b = 2
There is no self.c = True
One call (save_hyperparameters()) and every constructor arg is ready as self.<name>.
ProgressBoard

A live training-loss plot — call draw(x, y, label) from the training loop and the curve appears point-by-point:
class ProgressBoard(d2l.HyperParameters):
    """The board that plots data points in animation.

    The training loop feeds it one ``(x, y)`` point at a time through
    ``draw``. In this listing ``draw`` is only a stub.
    """

    # NOTE(review): `ls` and `colors` default to list objects that are built
    # once and shared by every instance; that is only safe while nothing
    # mutates them in place.
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        """Record the plot configuration.

        The body does nothing except call ``save_hyperparameters()``, the
        mixin helper that exposes constructor arguments as attributes
        (``self.xlabel``, ``self.xlim``, ...).
        """
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        """Add the point ``(x, y)`` to the curve named ``label``.

        Args:
            x: Horizontal coordinate of the point.
            y: Vertical coordinate of the point.
            label: Name of the curve the point belongs to.
            every_n: Plotting interval supplied by the caller; its exact
                meaning is fixed by the concrete implementation, which is
                not part of this listing.

        Raises:
            NotImplementedError: Always, in this stub.
        """
        raise NotImplementedError


# Module: models
# A Module knows how to forward, compute its loss, and hand back its
# optimizer. Every model we’ll write is a subclass:
class Module(d2l.nn_Module, d2l.HyperParameters):
    """The base class of models.

    Subclasses provide ``loss``, ``apply_init`` and ``configure_optimizers``
    and define a ``net`` attribute that ``forward`` delegates to. A trainer
    must have attached itself as ``self.trainer`` (``Trainer.prepare_model``
    does this) before ``plot`` is called.
    """
    # No need for save_hyperparam when using Python dataclass
    # Divisors of the per-epoch batch counts used by `plot` to derive the
    # plotting interval for training and validation points respectively.
    plot_train_per_epoch: int = field(default=2, init=False)
    plot_valid_per_epoch: int = field(default=1, init=False)
    # Use default_factory to make sure new plots are generated on each run
    board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                 init=False)

    def loss(self, y_hat, y):
        """Compute the loss; subclasses must override.

        NOTE(review): ``training_step`` and ``validation_step`` invoke this
        as ``loss(params, batch[:-1], batch[-1], state)``, so overriding
        implementations need to accept that four-argument form.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
    # and built-in __call__ magic methods for forward pass. Adding here
    # for consistency
    def forward(self, X, *args, **kwargs):
        """Run ``self.net`` on ``X``, forwarding any extra arguments.

        Raises:
            AssertionError: If no ``net`` attribute has been defined.
        """
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X, *args, **kwargs)

    def __call__(self, X, *args, **kwargs):
        """Make the module callable by delegating to ``forward``."""
        return self.forward(X, *args, **kwargs)

    def plot(self, key, value, train):
        """Plot a point in animation.

        Args:
            key: Metric name; the curve label is ``'train_' + key`` or
                ``'val_' + key``.
            value: Metric value; passed through ``d2l.to(value, d2l.cpu())``
                before being handed to the board.
            train: True for a training point, False for a validation point.

        Raises:
            AssertionError: If no trainer has been attached yet.
        """
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            # x is the batch counter expressed in units of epochs.
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        # Skip device sync unless this point will be plotted; every_n
        # buckets ensure alignment with the board's own filter
        # The `int(n) > 1` guard also short-circuits before the modulo, so an
        # interval of 0 never causes a division by zero here.
        if train and int(n) > 1 and self.trainer.train_batch_idx % int(n):
            return
        self.board.draw(x, d2l.to(value, d2l.cpu()),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, params, batch, state):
        """Compute the loss and its gradient for one batch and plot the loss.

        ``batch[:-1]`` is passed as the inputs and ``batch[-1]`` as the
        targets. ``jax.value_and_grad`` is used with its default settings,
        i.e. it differentiates with respect to the first argument, ``params``.

        Returns:
            The pair ``(l, grads)``.
        """
        l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                 batch[-1], state)
        self.plot("loss", l, train=True)
        return l, grads

    def validation_step(self, params, batch, state):
        """Compute the loss for one validation batch and plot it.

        Returns nothing; the loss is only sent to the board.
        """
        l = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('loss', l, train=False)

    def apply_init(self, dummy_input, key):
        """To be defined later in :numref:`sec_lazy_init`"""
        raise NotImplementedError

    def configure_optimizers(self):
        """Return the optimizer for this model; subclasses must override.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError


# DataModule: data
# A DataModule knows how to give back a train and a val dataloader, and a
# small get_dataloader(train: bool) hook subclasses override:
class DataModule(d2l.HyperParameters):
    """The base class of data.

    Subclasses override ``get_dataloader``; the two public accessors below
    only choose the ``train`` flag.
    """

    def __init__(self, root='../data', num_workers=4):
        """Record ``root`` and ``num_workers`` via ``save_hyperparameters()``."""
        self.save_hyperparameters()

    def get_dataloader(self, train):
        """Return the data loader for the requested split.

        Args:
            train: True for the training split, False for validation.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def train_dataloader(self):
        """Return ``get_dataloader(train=True)``."""
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        """Return ``get_dataloader(train=False)``."""
        return self.get_dataloader(train=False)


# Trainer: the loop
# A Trainer ties them together: it owns the loop over epochs, drives
# model.training_step / validation_step, and updates the progress board:
class Trainer(d2l.HyperParameters):
    """The base class for training models with data."""

    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        """Record the training configuration.

        Args:
            max_epochs: Number of epochs ``fit`` iterates over.
            num_gpus: Must be 0; anything else is rejected.
            gradient_clip_val: Stored as an attribute; not used in this
                listing.

        Raises:
            AssertionError: If ``num_gpus`` is not 0.
        """
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        """Fetch both loaders from ``data`` and cache their batch counts.

        A validation loader of ``None`` is tolerated and counted as 0
        batches; the training loader must support ``len``.
        """
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        """Link ``model`` and this trainer to each other.

        Sets ``model.trainer``, stretches the model's board x-axis to
        ``[0, max_epochs]`` and keeps the model as ``self.model``.
        """
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data, key=None):
        """Initialise the training state and run ``fit_epoch`` per epoch.

        Args:
            model: The model to train; must provide ``configure_optimizers``,
                ``apply_init`` and ``apply``.
            data: The data source passed to ``prepare_data``.
            key: Optional JAX PRNG key. When None, one is obtained from
                ``d2l.get_key()``.
        """
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        if key is None:
            root_key = d2l.get_key()
        else:
            root_key = key
        # One sub-key for parameter initialisation, one for dropout.
        params_key, dropout_key = jax.random.split(root_key)
        key = {'params': params_key, 'dropout': dropout_key}
        # Everything but the last element of the first batch serves as the
        # example input for initialisation.
        dummy_input = next(iter(self.train_dataloader))[:-1]
        variables = model.apply_init(dummy_input, key=key)
        params = variables['params']
        if 'batch_stats' in variables.keys():
            # Here batch_stats will be used later (e.g., for batch norm)
            batch_stats = variables['batch_stats']
        else:
            batch_stats = {}

        # Flax uses optax under the hood for a single state obj TrainState.
        # More will be discussed later in the dropout and batch
        # normalization section
        class TrainState(train_state.TrainState):
            batch_stats: Any
            dropout_rng: jax.Array

        # NOTE(review): configure_optimizers() is called a second time here
        # instead of reusing self.optim from above.
        self.state = TrainState.create(apply_fn=model.apply,
                                       params=params,
                                       batch_stats=batch_stats,
                                       dropout_rng=dropout_key,
                                       tx=model.configure_optimizers())
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        # The loop variable is the attribute itself, so self.epoch always
        # holds the index of the epoch currently being run.
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        """Run one epoch; not implemented in this listing.

        Raises:
            NotImplementedError: Always, in this stub.
        """
        raise NotImplementedError


# The three abstractions (Module, DataModule, Trainer) form the scaffold for
# every model in the book.
# add_to_class lets us define a class once and add methods later — friendly
# to slide-sized cells.
# HyperParameters removes the constructor-boilerplate noise.
# ProgressBoard gives us live loss curves with one call.