Linear Regression Implementation from Scratch

End-to-end linear regression with nothing but tensor ops:

  1. Model — a Module with w and b parameters and a forward.
  2. Loss — squared error.
  3. Optimizer — minibatch SGD, written by hand.
  4. Training loop — the Trainer’s fit_epoch, also from scratch.

The next chapter does the same with nn.LazyLinear + MSELoss + SGD in two lines. This one shows what those two lines hide.

Parameters

Initialize w randomly (small Gaussian), b at zero:

%matplotlib inline
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
import optax
class LinearRegressionScratch(d2l.Module):
    """The linear regression model implemented from scratch."""
    num_inputs: int
    lr: float
    sigma: float = 0.01

    def setup(self):
        self.w = self.param('w', nn.initializers.normal(self.sigma),
                            (self.num_inputs, 1))
        self.b = self.param('b', nn.initializers.zeros, (1,))

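Flax registers both arrays in the model's params collection, and JAX differentiates the loss with respect to that collection, so no requires_grad flag (as in PyTorch) is needed for autograd to track them.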
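
To make the initialization concrete outside of Flax, here is a minimal sketch in plain JAX. The helper init_params and the toy objective are hypothetical names introduced only for illustration; they are not part of d2l or Flax. The sketch draws w from a small Gaussian, sets b to zero, and shows that jax.grad returns gradients with the same dictionary structure as the parameters.

```python
import jax
from jax import numpy as jnp

def init_params(key, num_inputs, sigma=0.01):
    """Hypothetical helper: w ~ N(0, sigma^2), b = 0."""
    w = sigma * jax.random.normal(key, (num_inputs, 1))
    b = jnp.zeros((1,))
    return {'w': w, 'b': b}

params = init_params(jax.random.PRNGKey(0), num_inputs=2)
print(params['w'].shape, params['b'].shape)

# A toy objective, used only to show that gradients mirror the params tree.
def toy_objective(p):
    return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

grads = jax.grad(toy_objective)(params)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
```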

Forward pass

The model is one matrix-vector product plus a bias — \hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + b:

@d2l.add_to_class(LinearRegressionScratch)
def forward(self, X):
    return d2l.matmul(X, self.w) + self.b
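
As a quick shape check, the same computation can be sketched with plain jnp arrays. The numbers below are made up for illustration; the weights reuse the ground-truth values that appear later in this section. Note how the bias of shape (1,) broadcasts across all rows.

```python
from jax import numpy as jnp

X = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])      # 3 examples, 2 features
w = jnp.array([[2.0], [-3.4]])   # shape (2, 1)
b = jnp.array([4.2])             # shape (1,), broadcast over the rows

y_hat = jnp.matmul(X, w) + b     # one prediction per example
print(y_hat.shape)
print(y_hat)
```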

Loss

Squared error per example, averaged across the batch:

\ell(\hat{y}, y) = \tfrac{1}{2}(\hat{y} - y)^2.

@d2l.add_to_class(LinearRegressionScratch)
def loss(self, params, X, y, state):
    y_hat = state.apply_fn({'params': params}, *X)  # X unpacked from a tuple
    l = (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2
    return d2l.reduce_mean(l)
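
A small worked example, with made-up numbers, shows both the arithmetic and why the reshape in the loss matters. If y has shape (n,) and y_hat has shape (n, 1), subtracting them directly broadcasts to an (n, n) matrix and silently averages the wrong quantity.

```python
from jax import numpy as jnp

y_hat = jnp.array([[1.0], [2.0], [4.0]])   # predictions, shape (3, 1)
y = jnp.array([1.5, 2.0, 3.0])             # labels, shape (3,)

# Correct: align the shapes first, then average the halved squared errors.
per_example = (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
loss = jnp.mean(per_example)
print(per_example.ravel(), loss)

# Pitfall: without the reshape, broadcasting produces a (3, 3) matrix.
bad = (y_hat - y) ** 2 / 2
print(bad.shape)
```

Here the per-example losses are 0.125, 0, and 0.5, so the mean is 0.625 / 3.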

Optimizer: minibatch SGD

The update rule \theta \leftarrow \theta - \eta \nabla_\theta L written out by hand:

class SGD(d2l.HyperParameters):
    """Minibatch stochastic gradient descent."""
    # The key transformation of Optax is the GradientTransformation
    # defined by two methods, the init and the update.
    # The init initializes the state and the update transforms the gradients.
    # https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
    def __init__(self, lr):
        self.save_hyperparameters()

    def init(self, params):
        # Delete unused params
        del params
        return optax.EmptyState

    def update(self, updates, state, params=None):
        del params
        # When state.apply_gradients method is called to update flax's
        # train_state object, it internally calls optax.apply_updates method
        # adding the params to the update equation defined below.
        updates = jax.tree_util.tree_map(lambda g: -self.lr * g, updates)
        return updates, state

    def __call__(self):
        return optax.GradientTransformation(self.init, self.update)

The model class wires it up in configure_optimizers:

@d2l.add_to_class(LinearRegressionScratch)
def configure_optimizers(self):
    return SGD(self.lr)
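
To see the update rule on concrete numbers, here is a minimal sketch that skips Optax entirely. It scales the gradients by the negative learning rate with tree_map, as update does above, and then adds the result to the parameters, mirroring the addition that the comment above attributes to optax.apply_updates. The parameter and gradient values are invented for illustration.

```python
import jax
from jax import numpy as jnp

lr = 0.1
params = {'w': jnp.array([[1.0], [2.0]]), 'b': jnp.array([0.5])}
grads = {'w': jnp.array([[10.0], [-20.0]]), 'b': jnp.array([2.0])}

# Step 1: turn gradients into updates, -lr * g for every leaf of the tree.
updates = jax.tree_util.tree_map(lambda g: -lr * g, grads)

# Step 2: add the updates to the parameters, leaf by leaf.
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(new_params)
```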

Training step

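Each minibatch goes through a forward pass, a loss computation, a backward pass, and a parameter update. The only piece defined here is the Trainer's prepare_batch hook, which returns the batch unchanged: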

@d2l.add_to_class(d2l.Trainer)
def prepare_batch(self, batch):
    return batch
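
The four stages of a step (forward, loss, backward, update) can be sketched in plain JAX as follows. This is an illustration under simplifying assumptions, not the d2l training_step: the names forward, loss_fn, and train_step are hypothetical, and the data is a tiny hand-made batch.

```python
import jax
from jax import numpy as jnp

def forward(params, X):
    return jnp.matmul(X, params['w']) + params['b']

def loss_fn(params, X, y):
    y_hat = forward(params, X)                        # forward
    return jnp.mean((y_hat - y.reshape(y_hat.shape)) ** 2 / 2)  # loss

@jax.jit
def train_step(params, X, y, lr):
    # backward: value_and_grad differentiates w.r.t. the first argument
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    # step: plain SGD update on every leaf of the params tree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y = jnp.matmul(X, jnp.array([2.0, -3.4])) + 4.2
params = {'w': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}

params, loss_before = train_step(params, X, y, 0.01)
loss_after = loss_fn(params, X, y)
print(float(loss_before), float(loss_after))
```

With a learning rate this small relative to the feature scale, the loss after the step is lower than the loss before it.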

The whole epoch

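In each epoch, fit_epoch makes one pass over the training loader, applying one gradient update per batch, and then one pass over the validation loader (if there is one), calling validation_step on each batch: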

# Fuse the optimizer + state.replace updates into a single JIT'd call so
# JAX dispatches one compiled kernel per batch instead of many Python-level
# optax ops.
@jax.jit
def _trainer_update(state, grads):
    return state.apply_gradients(grads=grads).replace(
        dropout_rng=jax.random.split(state.dropout_rng)[0])


@jax.jit
def _trainer_update_with_bn(state, grads, batch_stats):
    return state.apply_gradients(grads=grads).replace(
        dropout_rng=jax.random.split(state.dropout_rng)[0],
        batch_stats=batch_stats)


@d2l.add_to_class(d2l.Trainer)
def fit_epoch(self):
    self.model.training = True
    # Some hand-rolled optimizers (e.g. the SGD class defined above) are not
    # JIT-traceable: detect that by probing opt_state and skip the JIT path.
    jit_ok = isinstance(self.state.opt_state, tuple)
    if self.state.batch_stats:
        # Mutable states will be used later (e.g., for batch norm)
        for batch in self.train_dataloader:
            (_, mutated_vars), grads = self.model.training_step(
                self.state.params, self.prepare_batch(batch), self.state)
            if jit_ok:
                self.state = _trainer_update_with_bn(
                    self.state, grads, mutated_vars['batch_stats'])
            else:
                self.state = self.state.apply_gradients(grads=grads).replace(
                    dropout_rng=jax.random.split(self.state.dropout_rng)[0],
                    batch_stats=mutated_vars['batch_stats'])
            self.train_batch_idx += 1
    else:
        for batch in self.train_dataloader:
            _, grads = self.model.training_step(
                self.state.params, self.prepare_batch(batch), self.state)
            if jit_ok:
                self.state = _trainer_update(self.state, grads)
            else:
                self.state = self.state.apply_gradients(grads=grads).replace(
                    dropout_rng=jax.random.split(self.state.dropout_rng)[0])
            self.train_batch_idx += 1

    if self.val_dataloader is None:
        return
    self.model.training = False
    for batch in self.val_dataloader:
        self.model.validation_step(self.state.params,
                                   self.prepare_batch(batch),
                                   self.state)
        self.val_batch_idx += 1
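
Stripped of the Trainer machinery (state objects, batch statistics, dropout keys), the training half of an epoch reduces to a shuffled pass over minibatches. The sketch below is a hypothetical stand-in, not the d2l implementation: run_epoch and sgd_step are names invented here, and the dataset size, noise level, and batch size are assumptions chosen for illustration.

```python
import jax
from jax import numpy as jnp

def loss_fn(params, X, y):
    y_hat = jnp.matmul(X, params['w']) + params['b']
    return jnp.mean((y_hat - y.reshape(y_hat.shape)) ** 2 / 2)

@jax.jit
def sgd_step(params, X, y, lr):
    grads = jax.grad(loss_fn)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

def run_epoch(params, X, y, key, batch_size=32, lr=0.03):
    """Hypothetical stand-in for fit_epoch: one shuffled pass over (X, y)."""
    perm = jax.random.permutation(key, X.shape[0])
    for start in range(0, X.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        params = sgd_step(params, X[idx], y[idx], lr)
    return params

# Assumed synthetic data: 1000 examples, label noise with std 0.01.
key = jax.random.PRNGKey(0)
key, k_x, k_noise = jax.random.split(key, 3)
true_w, true_b = jnp.array([2.0, -3.4]), 4.2
X = jax.random.normal(k_x, (1000, 2))
y = jnp.matmul(X, true_w) + true_b + 0.01 * jax.random.normal(k_noise, (1000,))

params = {'w': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}
for epoch in range(10):
    key, k_epoch = jax.random.split(key)   # fresh shuffle every epoch
    params = run_epoch(params, X, y, k_epoch)

print(params['w'].ravel(), params['b'])
```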

Run training on the synthetic dataset:

model = LinearRegressionScratch(2, lr=0.03)
data = d2l.SyntheticRegressionData(w=d2l.tensor([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)

Did it learn the right thing?

We know the true w and b — compare with the learned values:

params = trainer.state.params
print(f"error in estimating w: {data.w - d2l.reshape(params['w'], data.w.shape)}")
print(f"error in estimating b: {data.b - params['b']}")
error in estimating w: [ 0.00061727 -0.00057435]
error in estimating b: [0.00088215]

The small residual errors come from the finite training set and the noise in the labels; shrinking them further requires either more data or a better optimizer.
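
One way to see how much of the error is due to the data rather than to SGD is to solve the same problem in closed form. The sketch below fits ordinary least squares with jnp.linalg.lstsq on freshly generated data; the sample size and noise level are assumptions, not values read from d2l.SyntheticRegressionData. Even this exact solution does not recover the true parameters perfectly, because the labels are noisy and the sample is finite.

```python
import jax
from jax import numpy as jnp

key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
true_w, true_b = jnp.array([2.0, -3.4]), 4.2

# Assumed setup: 1000 examples, Gaussian features, label noise with std 0.01.
X = jax.random.normal(key_x, (1000, 2))
y = jnp.matmul(X, true_w) + true_b + 0.01 * jax.random.normal(key_noise, (1000,))

# Append a column of ones so the bias is estimated as a third coefficient.
X_aug = jnp.concatenate([X, jnp.ones((X.shape[0], 1))], axis=1)
theta, _, _, _ = jnp.linalg.lstsq(X_aug, y)

w_err = true_w - theta[:2]
b_err = true_b - theta[2]
print(f"least-squares error in w: {w_err}")
print(f"least-squares error in b: {b_err}")
```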

Recap

  • A Module for linear regression boils down to setup, forward, loss, and configure_optimizers.
  • A hand-rolled SGD is ~10 lines.
  • The Trainer.fit_epoch glue is what the training APIs of PyTorch, TensorFlow, JAX, and MXNet hide.
  • Synthetic data lets us check that the optimizer recovered the ground-truth parameters.