import jax
import numpy as np
import optax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
from scipy import stats

Hyperparameters are the knobs you tune outside gradient descent: learning rate, batch size, depth, dropout rate. Usually 5–20 of them; the validation loss is non-convex, noisy, and expensive — one full training run per setting.
Hyperparameter optimization (HPO) automates the tuning. Simplest variant: random search — sample configurations from a prior, evaluate, keep the best.
Train multiple models with different hyperparameters; pick the best.
Random search beats grid search and most hand-tuning. Smarter algorithms (Bayesian opt, Hyperband) come next.
Find $\mathbf{x}^* = \arg\min_{\mathbf{x} \in \mathcal{X}} f(\mathbf{x})$, where $f$ is the validation error after training with hyperparameters $\mathbf{x}$, and $\mathcal{X}$ is the configuration space — a structured product of discrete and continuous ranges.
The “function” we’re optimizing is “train a model with this config, return validation error”. Wrap that into a clean callable:
class HPOTrainer(d2l.Trainer):
    """Trainer that can report the validation error of its current model."""

    def validation_error(self):
        """Return one minus the mean per-batch validation accuracy.

        Every validation batch contributes with equal weight: the per-batch
        accuracies are summed and divided by the number of batches.

        Returns:
            ``1 - mean_batch_accuracy``, in whatever scalar/array type
            ``self.model.accuracy`` produces.
        """
        # Clear the model's training flag before scoring.
        self.model.training = False
        accuracy = 0
        val_batch_idx = 0
        for batch in self.val_dataloader:
            batch = self.prepare_batch(batch)
            # Everything but the last batch element is passed as the inputs,
            # the last element as the labels.
            accuracy += self.model.accuracy(
                self.state.params, batch[:-1], batch[-1], self.state)
            val_batch_idx += 1
        return 1 - accuracy / val_batch_idx


def hpo_objective_softmax_classification(config, max_epochs=8):
    """Train softmax regression on Fashion-MNIST and return its validation error.

    Args:
        config: Mapping that provides the ``"learning_rate"`` to train with.
        max_epochs: Number of epochs handed to the trainer.

    Returns:
        The validation error after training, converted to a Python float.
    """
    learning_rate = config["learning_rate"]
    # Use the HPOTrainer class defined above instead of looking it up on the
    # d2l package, which would bypass this definition (or fail if d2l does
    # not export a class of that name).
    trainer = HPOTrainer(max_epochs=max_epochs)
    data = d2l.FashionMNIST(batch_size=16)
    model = d2l.SoftmaxRegression(num_outputs=10, lr=learning_rate)
    trainer.fit(model=model, data=data)
    return float(trainer.validation_error())


# A structured space — log-uniform for learning rate (spans orders of magnitude), uniform integer for layer counts, categorical for activations:
Iterate: draw random config, evaluate, log. Keep the best seen so far. Brutally simple, surprisingly effective:
# Search space: one scipy.stats distribution per hyperparameter; a value is
# drawn with ``.rvs()``. The learning rate is sampled log-uniformly.
# NOTE(review): the bounds (1e-4, 1) are an assumption — the original
# definition of ``config_space`` is not present; confirm the intended range.
config_space = {"learning_rate": stats.loguniform(1e-4, 1)}

# Random search: draw a configuration, evaluate it, and record both the
# sampled value and the resulting validation error for each trial.
errors, values = [], []
num_iterations = 5
for i in range(num_iterations):
    learning_rate = config_space["learning_rate"].rvs()
    print(f"Trial {i}: learning_rate = {learning_rate}")
    y = hpo_objective_softmax_classification({"learning_rate": learning_rate})
    print(f" validation_error = {y}")
    values.append(learning_rate)
    errors.append(y)
# Recorded output from a previous run:
#  validation_error = 0.16739994287490845