The add_to_class trick

Object-Oriented Design for Implementation

Reusable training abstractions

Three recurring abstractions appear in every model we’ll build:

  • Module — the model: parameters, forward, loss, optimizer.
  • DataModule — the data: train and val loaders.
  • Trainer — the loop that fits a Module to a DataModule.

This chapter builds the scaffolding once. The rest of the book just subclasses these three.

The class shell

Long class definitions don’t fit on one slide or in one notebook cell. Define the class shell first, then attach methods incrementally:

def add_to_class(Class):
    """Register functions as methods in created class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper

class A:
    def __init__(self):
        self.b = 1

a = A()
@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1

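The decorator simply calls setattr to bind the function (obj in the code) onto Class under its own name. This works because a class’s namespace is mutable, and because methods are looked up on the class at call time, the instance a, created before do existed, can still call it.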
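
As a self-contained check of that behavior, the sketch below repeats add_to_class and attaches a method to a class after an instance already exists. Counter and bump are hypothetical names chosen only for illustration.

```python
def add_to_class(Class):
    """Register the decorated function as a method of Class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper


class Counter:  # hypothetical example class
    def __init__(self):
        self.n = 0


c = Counter()  # created before `bump` exists


@add_to_class(Counter)
def bump(self, by=1):
    self.n += by
    return self.n


# Attribute lookup falls back to the class at call time, so the
# pre-existing instance `c` already has the new method.
result = c.bump(by=2)

# The decorator returns the function unchanged, so the module-level
# name `bump` and the class attribute are the same object.
same_function = Counter.bump is bump
```

Because the decorator returns obj unchanged, the function also remains available under its module-level name, which is why a later cell that reuses the same name silently replaces the method.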

HyperParameters

Almost every class wants self.lr = lr, self.batch_size = … boilerplate in __init__. The HyperParameters mixin auto-saves constructor args as attributes:

class HyperParameters:
    """The base class of hyperparameters."""
    def save_hyperparameters(self, ignore=[]):
        raise NotImplementedError

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True

A single save_hyperparameters() call in __init__ makes every constructor argument available as self.<name>, except those listed in ignore.
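
One way such a mixin could work is to inspect the caller's stack frame. The sketch below is an assumption about the mechanism, not the d2l implementation; HyperParametersSketch is a hypothetical name, and it uses a tuple default for ignore instead of a list.

```python
import inspect


class HyperParametersSketch:
    """Illustrative stand-in; not the d2l implementation."""

    def save_hyperparameters(self, ignore=()):
        # Look one frame up the stack: the __init__ that called us.
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        for name, value in local_vars.items():
            # Skip self, ignored names, and private/compiler-added names.
            if name == 'self' or name in ignore or name.startswith('_'):
                continue
            setattr(self, name, value)


class B(HyperParametersSketch):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])


b = B(a=1, b=2, c=3)
```

In this sketch every local variable visible at the time of the call is saved, so calling save_hyperparameters() as the first statement of __init__ keeps the saved set equal to the constructor arguments.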

ProgressBoard

A live training-loss plot — call draw(x, y, label) from the training loop and the curve appears point-by-point:

class ProgressBoard(d2l.HyperParameters):
    """The board that plots data points in animation."""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplementedError

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)

(The full implementation lives in d2l. We just need the API.)
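
To make the role of every_n concrete, here is a minimal sketch that records points instead of plotting them. It assumes every_n means "average every n consecutive points with the same label into one plotted point"; RecordingBoard is a hypothetical class, not part of d2l.

```python
from collections import defaultdict


class RecordingBoard:
    """Hypothetical stand-in for ProgressBoard: stores points, no plotting."""

    def __init__(self):
        self.raw = defaultdict(list)    # label -> points waiting to be averaged
        self.lines = defaultdict(list)  # label -> averaged (x, y) points

    def draw(self, x, y, label, every_n=1):
        buf = self.raw[label]
        buf.append((x, y))
        if len(buf) < every_n:
            return  # not enough points yet for this label
        mean_x = sum(p[0] for p in buf) / len(buf)
        mean_y = sum(p[1] for p in buf) / len(buf)
        self.lines[label].append((mean_x, mean_y))
        buf.clear()


board = RecordingBoard()
for step in range(10):
    board.draw(step, 2.0 * step, 'fine', every_n=2)
    board.draw(step, 2.0 * step, 'coarse', every_n=5)
```

Under this assumption, ten calls produce five points on the 'fine' curve and two on the 'coarse' curve, so a larger every_n gives a smoother but sparser line.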

Module: models

A Module knows how to forward, compute its loss, and hand back its optimizer. Every model we’ll write is a subclass:

class Module(d2l.nn_Module, d2l.HyperParameters):
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, d2l.numpy(d2l.to(value, d2l.cpu())),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError
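
The arithmetic inside plot for the training branch can be isolated as follows. train_plot_coords is a hypothetical helper that only mirrors the two expressions above; it is not a d2l function.

```python
def train_plot_coords(train_batch_idx, num_train_batches,
                      plot_train_per_epoch):
    """Mirror the train branch of Module.plot (hypothetical helper)."""
    # x is measured in epochs: batches seen so far / batches per epoch.
    x = train_batch_idx / num_train_batches
    # Number of batches that share one plotted point.
    every_n = int(num_train_batches / plot_train_per_epoch)
    return x, every_n


# 100 batches per epoch, 2 plotted points per epoch, 150 batches seen.
x, every_n = train_plot_coords(150, 100, 2)

# Edge case: fewer batches per epoch than requested points per epoch.
_, tiny_every_n = train_plot_coords(0, 1, 2)
```

Here x is 1.5 (halfway through the second epoch) and every_n is 50. In the edge case int() truncates 0.5 to 0, so a dataset with fewer batches per epoch than plot_train_per_epoch deserves a smaller plot_train_per_epoch.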

DataModule: data

A DataModule hands back the training and validation dataloaders. Both go through a single get_dataloader(train) hook, which is the only method subclasses need to override:

class DataModule(d2l.HyperParameters):
    """The base class of data."""
    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
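
A subclass only has to fill in get_dataloader. The sketch below uses a trimmed copy of the base class (without the d2l mixin) and plain Python lists in place of a framework dataloader; SyntheticLine and all of its parameters are hypothetical.

```python
import random


class DataModule:
    """Trimmed copy of the base class above, without the d2l mixin."""

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)


class SyntheticLine(DataModule):
    """Hypothetical dataset: y = w * x + b plus a little Gaussian noise."""

    def __init__(self, w=2.0, b=1.0, num_train=8, num_val=4,
                 batch_size=4, seed=0):
        rng = random.Random(seed)
        n = num_train + num_val
        self.X = [rng.uniform(-1, 1) for _ in range(n)]
        self.y = [w * x + b + rng.gauss(0, 0.01) for x in self.X]
        self.num_train, self.batch_size = num_train, batch_size

    def get_dataloader(self, train):
        # First num_train examples are for training, the rest for validation.
        start, stop = ((0, self.num_train) if train
                       else (self.num_train, len(self.X)))
        bs = self.batch_size
        # Each batch is (features, labels): labels come last, matching
        # the batch[:-1] / batch[-1] split in Module.training_step.
        return [(self.X[s:min(s + bs, stop)], self.y[s:min(s + bs, stop)])
                for s in range(start, stop, bs)]


data = SyntheticLine()
train_batches = data.train_dataloader()
val_batches = data.val_dataloader()
```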

Trainer: the loop

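A Trainer ties them together: it owns the loop over epochs, calls model.training_step and model.validation_step on each batch (inside fit_epoch, left unimplemented here), and keeps the epoch and batch counters that Module.plot reads to position points on the progress board: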

class Trainer(d2l.HyperParameters):
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError
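
To show what a fit_epoch might look like, here is a framework-free sketch. It makes several assumptions that are not in the text above: the model is a hypothetical one-parameter ScalarModel with a hand-derived gradient, the "optimizer" is a plain callable, fit takes lists of batches directly instead of a DataModule, and plotting is omitted.

```python
class ScalarModel:
    """Hypothetical model: predicts w * x with mean squared error."""

    def __init__(self, lr=0.05):
        self.w, self.lr, self.grad = 0.0, lr, 0.0
        self.val_losses = []

    def _mse(self, xs, ys):
        return sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def training_step(self, batch):
        xs, ys = batch
        # Hand-derived gradient of the MSE with respect to w.
        self.grad = sum(2 * (self.w * x - y) * x
                        for x, y in zip(xs, ys)) / len(xs)
        return self._mse(xs, ys)

    def validation_step(self, batch):
        self.val_losses.append(self._mse(*batch))

    def configure_optimizers(self):
        def sgd_step():  # stand-in for a framework optimizer
            self.w -= self.lr * self.grad
        return sgd_step


class Trainer:
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, model, train_batches, val_batches=None):
        self.model = model
        self.train_dataloader = train_batches
        self.val_dataloader = val_batches
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        # One pass over the training batches: compute loss, then update.
        for batch in self.train_dataloader:
            self.model.training_step(batch)
            self.optim()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        # One pass over the validation batches: no parameter updates.
        for batch in self.val_dataloader:
            self.model.validation_step(batch)
            self.val_batch_idx += 1


train = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]  # y = 3x
val = [([5.0], [15.0])]
model = ScalarModel(lr=0.05)
trainer = Trainer(max_epochs=20)
trainer.fit(model, train, val)
```

The counters train_batch_idx and val_batch_idx keep increasing across epochs, which is what lets the training branch of plot report a fractional epoch on the x-axis.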

Recap

  • Three classes (Module, DataModule, Trainer) form the scaffold for every model in the book.
  • add_to_class lets us define a class once and add methods later — friendly to slide-sized cells.
  • HyperParameters removes the constructor-boilerplate noise.
  • ProgressBoard gives us live loss curves with one call.