%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import optax

The simplest regularization technique in the book adds a penalty on the squared norm of the weights to the loss:
L_{\text{reg}}(\mathbf{w}, b) = L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|_2^2.
The gradient gains a +\lambda\mathbf{w} term, so each SGD step with learning rate \eta subtracts an extra \eta\lambda\mathbf{w}, shrinking the weights toward zero. A single hyperparameter \lambda (lambd in the from-scratch code, wd in the concise version) controls how strongly.
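A minimal JAX sketch of that update, under illustrative assumptions (a toy squared loss on random data, \eta = 0.1, \lambda = 3, and the hypothetical helpers data_loss and penalized_loss): one gradient step on the penalized loss matches first shrinking \mathbf{w} by the factor (1 - \eta\lambda) and then taking the ordinary gradient step on the unpenalized loss.

```python
import jax
import jax.numpy as jnp


def data_loss(w, X, y):
    # Toy data-fit term (an assumption): mean squared error / 2
    return jnp.mean((X @ w - y) ** 2) / 2


def penalized_loss(w, X, y, lambd):
    # L(w) + lambd / 2 * ||w||_2^2, as in the formula above
    return data_loss(w, X, y) + lambd / 2 * jnp.sum(w ** 2)


key_w, key_X, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (5,))
X = jax.random.normal(key_X, (8, 5))
y = jax.random.normal(key_y, (8,))
eta, lambd = 0.1, 3.0

# Ordinary gradient step on the penalized loss
w_penalized = w - eta * jax.grad(penalized_loss)(w, X, y, lambd)
# Decay form: shrink by (1 - eta * lambd), then step on the data loss alone
w_decayed = (1 - eta * lambd) * w - eta * jax.grad(data_loss)(w, X, y)

print(jnp.max(jnp.abs(w_penalized - w_decayed)))  # tiny: float32 rounding only
```

This is why the technique is called weight decay: the penalty's only effect on the update is a multiplicative shrink of \mathbf{w} before the usual step.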
Why? An overparameterized model fit to a tiny dataset memorizes the noise. Penalizing large weights discourages that, keeping the learned function simpler and less sensitive to individual noisy examples.
Generate a tiny dataset (20 train, 100 val) whose true model uses all 200 inputs, each with a small weight of 0.01:
y = 0.05 + \sum_{i=1}^{200} 0.01\,x_i + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 0.01^2).
Far more parameters (200 weights plus a bias) than training examples (20): a perfect overfitting setup.
class Data(d2l.DataModule):
    def __init__(self, num_train, num_val, num_inputs, batch_size):
        self.save_hyperparameters()
        n = num_train + num_val
        key_X, key_noise = jax.random.split(jax.random.PRNGKey(0))
        self.X = jax.random.normal(key_X, (n, num_inputs))
        noise = jax.random.normal(key_noise, (n, 1)) * 0.01
        # True weights are all 0.01 and the true bias is 0.05
        w, b = jnp.ones((num_inputs, 1)) * 0.01, 0.05
        self.y = jnp.matmul(self.X, w) + b + noise

    def get_dataloader(self, train):
        # First num_train rows are for training, the rest for validation
        i = slice(0, self.num_train) if train else slice(self.num_train, None)
        return self.get_tensorloader([self.X, self.y], train, i)

The penalty itself is one line:

def l2_penalty(w):
    # ||w||_2^2 / 2, matching the lambda / 2 scaling in the formula above
    return jnp.sum(w ** 2) / 2
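Subclass the from-scratch linear regression so that its loss adds \lambda times the penalty:

class WeightDecayScratch(d2l.LinearRegressionScratch):
    lambd: float = 0

    def loss(self, params, X, y, state):
        # Data loss from the parent class plus lambd * ||w||^2 / 2
        return (super().loss(params, X, y, state) +
                self.lambd * l2_penalty(params['w']))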
data = Data(num_train=20, num_val=100, num_inputs=200, batch_size=5)
trainer = d2l.Trainer(max_epochs=10)
def train_scratch(lambd):
    model = WeightDecayScratch(num_inputs=200, lambd=lambd, lr=0.01)
    model.board.yscale = 'log'
    trainer.fit(model, data)
    # l2_penalty returns ||w||^2 / 2, so the printed value is the penalty
    # term rather than the L2 norm itself
    print('L2 norm of w:',
          float(l2_penalty(trainer.state.params['w'])))

\lambda = 0: the model fits the 20 training examples almost perfectly while validation loss stays high:

train_scratch(0)
L2 norm of w: 0.008094558492302895
\lambda = 3: training loss is higher, but validation loss is lower. Generalization wins:

train_scratch(3)
L2 norm of w: 0.0014479696983471513
The smaller gap between training and validation loss is the payoff of regularization.
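d2l.Trainer hides the training loop. As a rough, self-contained sketch of the same experiment without the d2l library, the code below generates data following the formula above and trains the same linear model with plain full-batch gradient descent. The full-batch loop, the 100 steps, and the helper names mse, penalized, and train are assumptions, not the book's code (which uses minibatch SGD with batch size 5 for 10 epochs). It prints training loss, validation loss, and ||w||_2 for \lambda = 0 and \lambda = 3 so you can compare the two settings directly.

```python
import jax
import jax.numpy as jnp

# Synthetic data following the formula above (20 train, 100 val, 200 inputs)
num_train, num_val, num_inputs = 20, 100, 200
key_X, key_noise, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_X, (num_train + num_val, num_inputs))
noise = jax.random.normal(key_noise, (num_train + num_val, 1)) * 0.01
y = X @ (jnp.ones((num_inputs, 1)) * 0.01) + 0.05 + noise
X_train, y_train = X[:num_train], y[:num_train]
X_val, y_val = X[num_train:], y[num_train:]


def mse(params, X, y):
    # Squared error / 2, averaged over examples
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2) / 2


def penalized(params, X, y, lambd):
    # Data loss plus lambd * ||w||^2 / 2; the bias is not penalized
    return mse(params, X, y) + lambd * jnp.sum(params['w'] ** 2) / 2


grad_fn = jax.jit(jax.grad(penalized))


def train(lambd, lr=0.01, num_steps=100):
    # Identical initialization for every lambd: w ~ N(0, 0.01^2), b = 0
    params = {'w': jax.random.normal(key_w, (num_inputs, 1)) * 0.01,
              'b': jnp.zeros(1)}
    for _ in range(num_steps):
        grads = grad_fn(params, X_train, y_train, lambd)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


results = {}
for lambd in (0.0, 3.0):
    params = train(lambd)
    results[lambd] = (float(mse(params, X_train, y_train)),
                      float(mse(params, X_val, y_val)),
                      float(jnp.linalg.norm(params['w'])))
    train_loss, val_loss, w_norm = results[lambd]
    print(f'lambd={lambd}: train={train_loss:.2e} val={val_loss:.2e} '
          f'||w||={w_norm:.4f}')
```

Because both runs share the same initialization and data, any difference between them comes from the penalty alone.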
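Many optimizers (e.g. PyTorch's torch.optim.SGD) accept a weight_decay argument that adds the \lambda\mathbf{w} gradient term automatically, so no manual penalty code is needed. optax.sgd has no such argument, but chaining optax.add_decayed_weights in front of it has the same effect: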
class WeightDecay(d2l.LinearRegression):
    wd: float = 0

    def configure_optimizers(self):
        # Weight decay is not available directly within optax.sgd, but
        # optax allows chaining several transformations together. We
        # mask the decay so it applies to the kernel only (not bias),
        # matching the per-parameter-group convention in PyTorch / MXNet.
        def kernel_mask(params):
            return jax.tree_util.tree_map_with_path(
                lambda path, _: path[-1].key != 'bias', params)
        return optax.chain(
            optax.masked(optax.add_decayed_weights(self.wd), kernel_mask),
            optax.sgd(self.lr))

model = WeightDecay(wd=3, lr=0.01)
model.board.yscale = 'log'
trainer.fit(model, data)
print('L2 norm of w:',
      float(l2_penalty(trainer.state.params['net']['kernel'])))

L2 norm of w: 0.0018107470823451877
(Note: a framework's weight_decay argument typically applies to every parameter it is given, biases included. To exclude the bias, use separate parameter groups in PyTorch, or a mask as in the optax code above.)
Weight decay (\lambda, called lambd or wd in code) trades training fit for generalization; tune it on a validation set. Many optimizers expose it directly through a weight_decay= argument.