%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, np, npx
npx.set_np()

End-to-end linear regression with nothing but tensor ops: a Module with w and b parameters and a forward method, plus the Trainer's fit_epoch, also written from scratch.

The next chapter does the same with nn.LazyLinear + MSELoss + SGD in two lines. This one shows what those two lines hide.
Initialize w randomly (small Gaussian), b at zero:
class LinearRegressionScratch(d2l.Module):
    """The linear regression model implemented from scratch."""
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        # Weights drawn from N(0, sigma^2); bias starts at zero.
        self.w = d2l.normal(0, sigma, (num_inputs, 1))
        self.b = d2l.zeros(1)
        # Allocate gradient buffers so autograd tracks both parameters.
        self.w.attach_grad()
        self.b.attach_grad()

Both parameters need requires_grad=True (or the framework equivalent, here attach_grad()) so autograd tracks them.
The model is one matrix-vector product plus a bias, $\hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + b$:
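As a framework-free illustration of that formula, here is a minimal sketch using plain Python lists instead of MXNet tensors (an assumption made so it runs without any library). The function name `forward` and the toy values are illustrative only, not the module's actual method.

```python
def forward(X, w, b):
    """Return X w + b for a batch X given as a list of feature rows."""
    # One dot product per example, plus the shared scalar bias.
    return [sum(x_j * w_j for x_j, w_j in zip(row, w)) + b for row in X]


X = [[1.0, 2.0], [3.0, 4.0]]  # two examples, two features (toy values)
w = [2.0, -1.0]
b = 0.5
y_hat = forward(X, w, b)
print(y_hat)
```

In the tensor version the list comprehension collapses into a single matrix product, with the bias broadcast across the batch.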
Squared error per example, averaged across the batch:
$$\ell(\hat{y}, y) = \tfrac{1}{2}(\hat{y} - y)^2.$$
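The loss can be sketched the same way. This is a plain-Python illustration under the assumption that predictions and targets are flat lists of floats; `squared_loss` is a hypothetical name chosen for this example.

```python
def squared_loss(y_hat, y):
    """Mean over the batch of 0.5 * (y_hat_i - y_i) ** 2."""
    per_example = [0.5 * (p - t) ** 2 for p, t in zip(y_hat, y)]
    return sum(per_example) / len(per_example)


# First prediction is off by 1, second is exact (toy values).
loss = squared_loss([0.5, 2.5], [1.5, 2.5])
print(loss)
```

The factor of one half is a convenience: it cancels the 2 that appears when the square is differentiated, so the gradient with respect to the prediction is simply the residual.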
The update rule $\theta \leftarrow \theta - \eta \nabla_\theta L$, written out by hand:
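A minimal sketch of that rule, assuming the parameters and their gradients are available as parallel lists of floats. The helper `sgd_step` is hypothetical and returns new values, whereas a tensor-based optimizer would normally update the parameters in place.

```python
def sgd_step(params, grads, lr):
    """Move each parameter one step against its gradient: p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


params = [1.0, -2.0]
grads = [0.5, -1.0]  # toy gradient values
new_params = sgd_step(params, grads, lr=0.1)
print(new_params)
```

A positive gradient lowers the parameter and a negative gradient raises it, which is all that "descent" means here.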
What happens once per minibatch — forward, loss, backward, step:
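The four phases can be put together in one function. The sketch below is an assumption-laden stand-in: it uses plain Python lists, and because there is no autograd here, the backward pass is the analytic gradient of the mean squared loss (the residual times the input for each weight, and the residual itself for the bias, both averaged over the batch). The name `training_step` mirrors the method called by the Trainer, but this standalone version is illustrative only.

```python
def training_step(X, y, w, b, lr):
    """Run forward, loss, backward and an SGD step on one minibatch."""
    n = len(X)
    # Forward: predictions X w + b.
    y_hat = [sum(x_j * w_j for x_j, w_j in zip(row, w)) + b for row in X]
    # Loss: mean of half the squared residuals.
    errors = [p - t for p, t in zip(y_hat, y)]
    loss = sum(0.5 * e ** 2 for e in errors) / n
    # Backward: analytic gradients, standing in for autograd.
    grad_w = [sum(e * row[j] for e, row in zip(errors, X)) / n
              for j in range(len(w))]
    grad_b = sum(errors) / n
    # Step: plain SGD update.
    w = [w_j - lr * g for w_j, g in zip(w, grad_w)]
    b = b - lr * grad_b
    return w, b, loss


X = [[1.0, 0.0], [0.0, 1.0]]  # toy minibatch
y = [2.0, -1.0]
w, b, first_loss = training_step(X, y, [0.0, 0.0], 0.0, lr=0.1)
_, _, second_loss = training_step(X, y, w, b, lr=0.1)
print(first_loss, second_loss)
```

Running the step twice on the same minibatch should give a smaller loss the second time, which is a quick sanity check that the gradient signs are right.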
The Trainer walks the train and val loaders once per epoch, calling the steps:
@d2l.add_to_class(d2l.Trainer)
def fit_epoch(self):
    for batch in self.train_dataloader:
        # Record the forward pass so autograd can differentiate the loss.
        with autograd.record():
            loss = self.model.training_step(self.prepare_batch(batch))
        loss.backward()
        if self.gradient_clip_val > 0:
            self.clip_gradients(self.gradient_clip_val, self.model)
        self.optim.step(1)
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    # Validation: forward passes only, no parameter updates.
    for batch in self.val_dataloader:
        self.model.validation_step(self.prepare_batch(batch))
        self.val_batch_idx += 1

We know the true w and b, so we can compare them with the learned values:
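The comparison can be reproduced end to end without any framework. The sketch below is not the chapter's training run: the true parameters, noise level, dataset size, learning rate, batch size and epoch count are all arbitrary assumptions chosen for illustration, and the gradients are computed analytically rather than by autograd.

```python
import random

random.seed(0)
true_w, true_b = [2.0, -3.4], 4.2  # assumed ground truth for this sketch

# Synthetic data: y = X w + b + Gaussian noise with standard deviation 0.01.
X = [[random.gauss(0, 1) for _ in true_w] for _ in range(500)]
y = [sum(x_j * w_j for x_j, w_j in zip(row, true_w)) + true_b
     + random.gauss(0, 0.01) for row in X]

# Small Gaussian initial weights, zero bias.
w = [random.gauss(0, 0.01) for _ in true_w]
b = 0.0
lr, batch_size = 0.1, 20

for epoch in range(10):
    order = list(range(len(X)))
    random.shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        # Forward pass and residuals on this minibatch.
        errors = [sum(x_j * w_j for x_j, w_j in zip(X[i], w)) + b - y[i]
                  for i in idx]
        # Gradients of the mean squared loss, then one SGD step.
        grad_w = [sum(e * X[i][j] for e, i in zip(errors, idx)) / len(idx)
                  for j in range(len(w))]
        grad_b = sum(errors) / len(idx)
        w = [w_j - lr * g for w_j, g in zip(w, grad_w)]
        b -= lr * grad_b

w_error = [t - l for t, l in zip(true_w, w)]
b_error = true_b - b
print(f"error in estimating w: {w_error}")
print(f"error in estimating b: {b_error}")
```

With these settings the errors should be small but not exactly zero, since the labels carry noise and the dataset is finite.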
The small remaining differences come from the finite training set and the noise in the labels; a tighter fit requires either more data or a better optimizer.
A Module for linear regression boils down to __init__, forward, loss, and configure_optimizers.
The Trainer.fit_epoch glue is what the training APIs of PyTorch, TensorFlow, JAX, and MXNet hide.