%matplotlib inline
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
import optaxEnd-to-end linear regression with nothing but tensor ops:
- a Module with w and b parameters and a forward method;
- the Trainer's fit_epoch, also written from scratch.
The next chapter does the same with nn.LazyLinear + MSELoss + SGD in two lines. This one shows what those two lines hide.
Initialize w randomly (small Gaussian), b at zero:
E0524 02:42:00.143640 36072 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 9.41GiB (10100251136 bytes) of ...
E0524 02:42:00.144041 36072 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 8.47GiB (9090225152 bytes) of ...
class LinearRegressionScratch(d2l.Module):
    """The linear regression model implemented from scratch.

    Declares two parameters: weights ``w`` of shape ``(num_inputs, 1)``,
    drawn from a zero-mean normal with standard deviation ``sigma``, and a
    bias ``b`` of shape ``(1,)`` initialized to zero.

    Attributes:
        num_inputs: number of input features (rows of ``w``).
        lr: learning rate; presumably handed to the optimizer by
            ``configure_optimizers`` (not shown here).
        sigma: standard deviation of the initial weights.
    """
    num_inputs: int
    lr: float
    sigma: float = 0.01

    def setup(self):
        # Registering w and b through self.param is the Flax counterpart of
        # PyTorch's requires_grad=True: it places them in the model's params
        # collection, which is what training_step differentiates with
        # respect to.
        self.w = self.param('w', nn.initializers.normal(self.sigma),
                            (self.num_inputs, 1))
        # (1,) is an explicit one-element shape tuple for the bias.
        self.b = self.param('b', nn.initializers.zeros, (1,))
The model is one matrix–vector product plus a bias, $\hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + b$:
Squared error per example, averaged across the batch:
$$\ell(\hat{y}, y) = \tfrac{1}{2}(\hat{y} - y)^2.$$
The update rule $\theta \leftarrow \theta - \eta \nabla_\theta L$, written out by hand:
class SGD(d2l.HyperParameters):
    """Minibatch stochastic gradient descent, written as an optax-style optimizer.

    Optax models an optimizer as a ``GradientTransformation``: a pair of
    functions ``init(params) -> state`` and
    ``update(updates, state, params) -> (updates, state)``.
    See https://github.com/deepmind/optax/blob/master/optax/_src/transform.py

    Plain SGD keeps no state. ``init`` therefore returns an empty state, and
    ``update`` only scales the gradients by ``-lr``.

    Args:
        lr: learning rate (step size).
    """

    def __init__(self, lr):
        # Records the constructor arguments; ``update`` reads ``self.lr``.
        self.save_hyperparameters()

    def init(self, params):
        """Return the (empty) optimizer state; ``params`` are not needed."""
        del params
        # Return an *instance*: optax expects the state to be a pytree, and
        # EmptyState() is an empty NamedTuple. The bare class object is not a
        # valid JAX value, so it cannot be passed through jax.jit.
        return optax.EmptyState()

    def update(self, updates, state, params=None):
        """Map gradients to parameter deltas ``-lr * g``; the state is unchanged.

        When ``state.apply_gradients`` is called on a flax train_state, it
        internally calls ``optax.apply_updates``, which adds these deltas to
        the parameters.
        """
        del params
        updates = jax.tree_util.tree_map(lambda g: -self.lr * g, updates)
        return updates, state

    def __call__(self):
        """Wrap ``init``/``update`` in an ``optax.GradientTransformation``."""
        return optax.GradientTransformation(self.init, self.update)


# What happens once per minibatch — forward, loss, backward, step:
Once per epoch, the Trainer walks the training loader and then the validation loader, calling the model's step methods on every batch:
# One compiled call per batch: the optimizer step and the dropout-key refresh
# are traced together, so JAX dispatches a single XLA program instead of many
# small, separately dispatched optax operations.
@jax.jit
def _trainer_update(state, grads):
    """Apply ``grads`` to ``state`` and advance its dropout key.

    The new dropout key is the first half of a split of the current one.
    """
    next_key, _ = jax.random.split(state.dropout_rng)
    stepped = state.apply_gradients(grads=grads)
    return stepped.replace(dropout_rng=next_key)
@jax.jit
def _trainer_update_with_bn(state, grads, batch_stats):
    """Compiled optimizer step for models that carry mutable batch statistics.

    Applies ``grads``, advances the dropout key (first half of a split of the
    current key), and stores the freshly mutated ``batch_stats`` in the
    returned state.
    """
    next_key, _ = jax.random.split(state.dropout_rng)
    changes = dict(dropout_rng=next_key, batch_stats=batch_stats)
    return state.apply_gradients(grads=grads).replace(**changes)
def _opt_state_is_jittable(opt_state):
    """Return True if the optimizer state can be passed through ``jax.jit``.

    Optax optimizers typically keep their state in (named) tuples. Some
    hand-rolled optimizers return something that is not a valid JAX value
    (for example a state *class* instead of an instance), which jax.jit
    rejects. Those optimizers fall back to eager updates.
    """
    return isinstance(opt_state, tuple)


def _apply_train_update(state, grads, batch_stats, jit_ok):
    """Apply one optimizer step, refresh the dropout key, and return the new state.

    Uses the fused, jitted helpers when ``jit_ok`` is true, and the
    equivalent eager calls otherwise. ``batch_stats`` is ``None`` for models
    without mutable batch statistics; otherwise it replaces the state's
    ``batch_stats``.
    """
    if jit_ok:
        if batch_stats is None:
            return _trainer_update(state, grads)
        return _trainer_update_with_bn(state, grads, batch_stats)
    changes = {'dropout_rng': jax.random.split(state.dropout_rng)[0]}
    if batch_stats is not None:
        changes['batch_stats'] = batch_stats
    return state.apply_gradients(grads=grads).replace(**changes)


@d2l.add_to_class(d2l.Trainer)
def fit_epoch(self):
    """Run one training epoch, then one validation pass if a loader exists.

    For each training batch, ``model.training_step`` computes gradients.
    They are applied to ``self.state`` together with a dropout-key refresh
    and, for models that carry them, the mutated ``batch_stats``.
    """
    self.model.training = True
    jit_ok = _opt_state_is_jittable(self.state.opt_state)
    # Models with mutable state (e.g. batch norm) return the mutated
    # collections next to the loss; decided once per epoch.
    has_batch_stats = bool(self.state.batch_stats)
    for batch in self.train_dataloader:
        aux, grads = self.model.training_step(
            self.state.params, self.prepare_batch(batch), self.state)
        if has_batch_stats:
            _, mutated_vars = aux
            batch_stats = mutated_vars['batch_stats']
        else:
            batch_stats = None
        self.state = _apply_train_update(self.state, grads, batch_stats,
                                         jit_ok)
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    self.model.training = False
    for batch in self.val_dataloader:
        self.model.validation_step(self.state.params,
                                   self.prepare_batch(batch),
                                   self.state)
        self.val_batch_idx += 1


# We know the true w and b — compare with the learned values:
error in estimating w: [ 0.00061727 -0.00057435]
error in estimating b: [0.00088215]
The small remaining errors come from the finite training set and the label noise; getting closer to the true values would require more data (or a better optimizer).
A Module for linear regression boils down to __init__ (setup in Flax), forward, loss, and configure_optimizers. The Trainer.fit_epoch glue is what the training APIs of PyTorch, TensorFlow, JAX, and MXNet hide.