Module hooks

Concise Implementation of Linear Regression

Framework linear regression

The same model, same data, same training — using the framework’s high-level layers and built-in losses:

  • Model: a single nn.Dense layer (Flax's counterpart to LazyLinear) instead of hand-rolled w, b.
  • Loss: mean squared error built from library primitives (no factor of ½).
  • Optimizer: built-in SGD (optax.sgd).

End result: ~5 lines of model code instead of 30. Same convergence on synthetic data.

Model setup

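Wrap a single linear layer with the right output dimension. Flax's nn.Dense plays the role of the “lazy” variant: only the output width is specified, and the input dimension is inferred from the first batch when the parameters are initialized: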

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
import optax

class LinearRegression(d2l.Module):
    """The linear regression model implemented with high-level APIs."""
    lr: float

    def setup(self):
        self.net = nn.Dense(1, kernel_init=nn.initializers.normal(0.01))

Hook the layer into our Module interface, starting with forward (loss and configure_optimizers come next):

@d2l.add_to_class(LinearRegression)
def forward(self, X):
    return self.net(X)

Loss and optimizer

Mean squared error, written as the mean of the squared residuals. Note that it omits the 1/2 factor we used by hand:

@d2l.add_to_class(LinearRegression)
def loss(self, params, X, y, state):
    y_hat = state.apply_fn({'params': params}, *X)
    return d2l.reduce_mean(jnp.square(y_hat - y))

Same SGD, instantiated with one call:

@d2l.add_to_class(LinearRegression)
def configure_optimizers(self):
    return optax.sgd(self.lr)

Train

Identical loop — the Trainer doesn’t care that the model is now a thin wrapper around a built-in layer:

model = LinearRegression(lr=0.03)
data = d2l.SyntheticRegressionData(w=d2l.tensor([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)

Compare with ground truth

Pull weights and bias back out of the layer:

@d2l.add_to_class(LinearRegression)
def get_w_b(self, state):
    net = state.params['net']
    return net['kernel'], net['bias']

w, b = model.get_w_b(trainer.state)
print(f'error in estimating w: {data.w - d2l.reshape(w, data.w.shape)}')
print(f'error in estimating b: {data.b - b}')

Output:

error in estimating w: [ 0.01253641 -0.01050758]
error in estimating b: [0.01987886]

Both errors are on the order of 0.01 after three epochs: recovery comparable to the from-scratch version, with less glue code.

Recap

  • From scratch taught us what was happening; concise is what we’ll actually use.
  • The high-level layers / losses / optimizers compose with the same Module / Trainer scaffold.
  • Same minibatch loop, same convergence; one-line layer instead of hand-rolled parameters.