import time
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from scipy import stats
HPO algorithms share a common structure. The next two decks swap out individual pieces of it (parallel scheduling, multi-fidelity); this deck factors out the common skeleton:
It mirrors the structure used by modern HPO libraries such as Optuna, Syne Tune, Vizier, and Ray Tune.
A concrete RandomSearcher:
class RandomSearcher(HPOSearcher):
    """Searcher that samples every hyperparameter independently at random.

    Args:
        config_space: maps each hyperparameter name to an object exposing
            ``.rvs()`` (presumably a frozen ``scipy.stats`` distribution).
        initial_config: optional configuration handed out by the very first
            call to ``sample_configuration``; later calls sample fresh values.
    """

    def __init__(self, config_space: dict, initial_config=None):
        # d2l helper: the methods below read ``self.config_space`` and
        # ``self.initial_config``, which this call stores as attributes.
        self.save_hyperparameters()

    def sample_configuration(self) -> dict:
        """Return the next configuration to evaluate.

        The user-supplied ``initial_config`` (if any) is returned once, as the
        same object that was passed in, and then cleared. Otherwise a new dict
        is built by drawing one sample per entry of ``config_space``.
        """
        seed_config = self.initial_config
        if seed_config is None:
            return {key: dist.rvs() for key, dist in self.config_space.items()}
        # Consume the starting point so that it is only suggested once.
        self.initial_config = None
        return seed_config
# Concrete sequential / FIFO scheduler:
The tuner combines the searcher, scheduler, and objective into a single loop:
class HPOTuner(d2l.HyperParameters):
    """Run a sequential HPO loop: ask the scheduler, evaluate, report back.

    Args:
        scheduler: object providing ``suggest()``, which returns a config dict,
            and ``update(config, error)``.
        objective: callable invoked as ``objective(**config)``. Its return
            value is converted with ``float()`` and recorded as the trial error.
    """

    def __init__(self, scheduler: HPOScheduler, objective: callable):
        # d2l helper: stores ``scheduler`` and ``objective`` as attributes.
        self.save_hyperparameters()
        # Bookkeeping results for plotting
        self.incumbent = None  # best configuration seen so far
        self.incumbent_error = None  # error of that configuration
        self.incumbent_trajectory = []  # incumbent error after each trial
        self.cumulative_runtime = []  # running total of trial runtimes (s)
        self.current_runtime = 0
        self.records = []  # one {"config", "error", "runtime"} dict per trial

    def run(self, number_of_trials):
        """Evaluate ``number_of_trials`` configurations one after another.

        A trial's runtime covers suggesting, evaluating, and updating the
        scheduler. Each result is passed to ``self.bookkeeping``.
        """
        for i in range(number_of_trials):
            # perf_counter is a monotonic clock meant for measuring durations.
            # time.time() follows the system clock, which can be adjusted
            # (e.g. by NTP) mid-trial and corrupt the measured runtime.
            start_time = time.perf_counter()
            config = self.scheduler.suggest()
            print(f"Trial {i}: config = {config}")
            error = self.objective(**config)
            # Objectives may return array scalars; normalise to a Python float.
            error = float(error)
            self.scheduler.update(config, error)
            runtime = time.perf_counter() - start_time
            self.bookkeeping(config, error, runtime)
            print(f" error = {error}, runtime = {runtime}")
# Track wall-clock time and the best objective seen so far, so we can plot any-time performance later:
@d2l.add_to_class(HPOTuner)
def bookkeeping(self, config: dict, error: float, runtime: float):
    """Record one finished trial and update the incumbent and runtime curves.

    Args:
        config: the hyperparameter configuration that was evaluated.
        error: its objective value; lower is better.
        runtime: seconds spent on the trial.

    A NaN error (e.g. from a diverged run) is ranked worse than every real
    number. It therefore never displaces a non-NaN incumbent, and a NaN
    incumbent is replaced by the next non-NaN error.
    """
    import math

    def _rank(value):
        # Map NaN to +inf. NaN compares False against everything, so without
        # this a NaN first trial would stay the incumbent forever.
        return math.inf if math.isnan(value) else value

    self.records.append({"config": config, "error": error, "runtime": runtime})
    # Check if the last hyperparameter configuration performs better
    # than the incumbent (strict comparison: ties keep the earlier one).
    if self.incumbent is None or _rank(error) < _rank(self.incumbent_error):
        self.incumbent = config
        self.incumbent_error = error
    # Add current best observed performance to the optimization trajectory
    self.incumbent_trajectory.append(self.incumbent_error)
    # Update runtime
    self.current_runtime += runtime
    self.cumulative_runtime.append(self.current_runtime)
# Run the abstraction on a real model — a small CNN (LeNet) on Fashion-MNIST. The objective searches over learning rate and batch size:
def hpo_objective_lenet(learning_rate, batch_size, max_epochs=10):
    """Train LeNet on Fashion-MNIST and return its validation error.

    Args:
        learning_rate: learning rate passed to ``d2l.LeNet``.
        batch_size: minibatch size for the Fashion-MNIST loader.
        max_epochs: number of training epochs.

    Returns:
        Whatever ``trainer.validation_error()`` reports. ``HPOTuner.run``
        converts it to ``float`` and treats lower values as better.
    """
    net = d2l.LeNet(lr=learning_rate, num_classes=10)
    hpo_trainer = d2l.HPOTrainer(max_epochs=max_epochs, num_gpus=1)
    fashion_mnist = d2l.FashionMNIST(batch_size=batch_size)
    hpo_trainer.fit(model=net, data=fashion_mnist)
    return hpo_trainer.validation_error()
# The incumbent curve reports the best validation error found so far as the tuner spends more wall-clock time. Downward steps mean a new configuration beat the previous best; flat regions mean the search is still evaluating but has not improved the incumbent.
error = 0.8033770322799683, runtime = 30.238523721694946