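```python
import time

from d2l import torch as d2l
from scipy import stats
```

HPO algorithms share a common structure. The next two decks swap out individual pieces of it (parallel scheduling, multi-fidelity). This deck factors out the common skeleton into three roles: a searcher (`HPOSearcher`) that samples candidate configurations, a scheduler (`HPOScheduler`) that decides which configuration to evaluate next, and a tuner (`HPOTuner`) that runs the evaluation loop and records the results.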
Most modern HPO libraries follow a similar decomposition, though each names the pieces differently (for example Optuna, Syne Tune, Vizier, and Ray Tune).
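A minimal sketch of that skeleton as abstract interfaces, written with Python's `abc` module instead of `d2l.HyperParameters` so that it runs on its own. The method names `sample_configuration`, `suggest`, and `update` match the calls made by the classes in this deck; the class names `SketchSearcher`, `SketchScheduler`, and `ConstantSearcher`, and the searcher-side `update` hook, are hypothetical.

```python
from abc import ABC, abstractmethod


class SketchSearcher(ABC):
    """Hypothetical searcher interface: proposes configurations."""

    @abstractmethod
    def sample_configuration(self) -> dict:
        """Return the next hyperparameter configuration to try."""

    def update(self, config: dict, error: float) -> None:
        """Hypothetical feedback hook; random search can ignore it."""


class SketchScheduler(ABC):
    """Hypothetical scheduler interface: decides what runs next."""

    @abstractmethod
    def suggest(self) -> dict:
        """Return the configuration that should be evaluated next."""

    @abstractmethod
    def update(self, config: dict, error: float) -> None:
        """Receive the error observed for `config`."""


class ConstantSearcher(SketchSearcher):
    """Toy subclass: always proposes the same configuration."""

    def sample_configuration(self) -> dict:
        return {"learning_rate": 0.1}
```

Because the abstract methods are enforced by `ABC`, a subclass that forgets to implement `suggest` or `sample_configuration` fails when it is instantiated rather than halfway through a tuning run. Keeping the scheduler separate from the searcher is what lets later decks change how trials are run (in parallel, or at several fidelities) without touching the sampling logic.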
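A concrete searcher, `RandomSearcher`, draws each hyperparameter independently from its domain (any object with an `.rvs()` method, such as a `scipy.stats` distribution) and returns an optional `initial_config` on the first call: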
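```python
class RandomSearcher(HPOSearcher):
    def __init__(self, config_space: dict, initial_config=None):
        self.save_hyperparameters()

    def sample_configuration(self) -> dict:
        if self.initial_config is not None:
            # Evaluate the user-supplied starting point first, exactly once
            result = self.initial_config
            self.initial_config = None
        else:
            # Draw each hyperparameter independently from its domain
            result = {
                name: domain.rvs()
                for name, domain in self.config_space.items()
            }
        return result
```

The simplest concrete scheduler is a sequential (FIFO) one: it evaluates configurations one at a time, in the order the searcher proposes them.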
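A self-contained sketch of such a FIFO scheduler, paired with a random searcher that mirrors `RandomSearcher` but drops the `d2l` base class. It assumes every domain in the config space is a frozen `scipy.stats` distribution; the search space, the seeding through `random_state`, and the class names `RandomSearcherSketch` and `FIFOSchedulerSketch` are illustrative assumptions, not the deck's own code.

```python
import numpy as np
from scipy import stats


class RandomSearcherSketch:
    """Hypothetical stand-in for RandomSearcher without the d2l base class."""

    def __init__(self, config_space: dict, initial_config=None, seed=None):
        self.config_space = config_space
        self.initial_config = initial_config
        # Seeded generator so runs are reproducible (an illustrative addition)
        self.rng = np.random.default_rng(seed)

    def sample_configuration(self) -> dict:
        # Return the user-supplied starting point first, exactly once
        if self.initial_config is not None:
            result, self.initial_config = self.initial_config, None
            return result
        # Otherwise draw each hyperparameter independently from its domain
        return {name: domain.rvs(random_state=self.rng)
                for name, domain in self.config_space.items()}

    def update(self, config: dict, error: float) -> None:
        pass  # random search ignores feedback


class FIFOSchedulerSketch:
    """Hypothetical sequential scheduler: one trial at a time, in proposal order."""

    def __init__(self, searcher):
        self.searcher = searcher

    def suggest(self) -> dict:
        # No queueing or early stopping: just ask the searcher
        return self.searcher.sample_configuration()

    def update(self, config: dict, error: float) -> None:
        # Forward the result so a model-based searcher could learn from it
        self.searcher.update(config, error)


# Illustrative search space (an assumption, not taken from this deck)
config_space = {
    "learning_rate": stats.loguniform(1e-2, 1),
    "batch_size": stats.randint(32, 256),  # integers in [32, 256)
}
searcher = RandomSearcherSketch(
    config_space, initial_config={"learning_rate": 0.1, "batch_size": 128}, seed=0
)
scheduler = FIFOSchedulerSketch(searcher)
first = scheduler.suggest()    # the initial configuration
second = scheduler.suggest()   # a random draw from config_space
scheduler.update(second, 0.5)  # the error value here is made up
```

In this sequential setting the scheduler only forwards calls. Its role grows once it has to manage several workers or decide how much training budget each trial gets, which is where parallel and multi-fidelity variants differ.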
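`HPOTuner` ties the scheduler (which wraps the searcher) to the objective in a single loop: ask for a configuration, evaluate it, report the error back, and record the result: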
```python
class HPOTuner(d2l.HyperParameters):
    def __init__(self, scheduler: HPOScheduler, objective: callable):
        self.save_hyperparameters()
        # Bookkeeping results for plotting
        self.incumbent = None
        self.incumbent_error = None
        self.incumbent_trajectory = []
        self.cumulative_runtime = []
        self.current_runtime = 0
        self.records = []

    def run(self, number_of_trials):
        for i in range(number_of_trials):
            start_time = time.time()
            config = self.scheduler.suggest()
            print(f"Trial {i}: config = {config}")
            error = self.objective(**config)
            # The objective returns a (possibly GPU) tensor; convert to a float
            error = float(d2l.numpy(error.cpu()))
            self.scheduler.update(config, error)
            runtime = time.time() - start_time
            self.bookkeeping(config, error, runtime)
            print(f"    error = {error}, runtime = {runtime}")
```

`bookkeeping` records every trial, keeps the best configuration seen so far (the incumbent), and tracks cumulative wall-clock time, so that any-time performance can be plotted later:
```python
@d2l.add_to_class(HPOTuner)
def bookkeeping(self, config: dict, error: float, runtime: float):
    self.records.append({"config": config, "error": error, "runtime": runtime})
    # Check if the last hyperparameter configuration performs better
    # than the incumbent
    if self.incumbent is None or self.incumbent_error > error:
        self.incumbent = config
        self.incumbent_error = error
    # Add current best observed performance to the optimization trajectory
    self.incumbent_trajectory.append(self.incumbent_error)
    # Update runtime
    self.current_runtime += runtime
    self.cumulative_runtime.append(self.current_runtime)
```

Now apply the abstraction to a real model: LeNet, a small CNN, trained on Fashion-MNIST. The objective exposes two hyperparameters to the search, learning rate and batch size, and returns the validation error after `max_epochs` epochs:
```python
def hpo_objective_lenet(learning_rate, batch_size, max_epochs=10):
    model = d2l.LeNet(lr=learning_rate, num_classes=10)
    trainer = d2l.HPOTrainer(max_epochs=max_epochs, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=batch_size)
    # Pass one training batch through the network so lazy layers get their
    # shapes, then apply the CNN weight initialization
    model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
    trainer.fit(model=model, data=data)
    validation_error = trainer.validation_error()
    return validation_error
```

The incumbent curve reports the best validation error found so far as the tuner spends more wall-clock time. Downward steps mean a new configuration beat the previous best; flat regions mean the search is still evaluating but has not improved the incumbent.
Log line printed by `HPOTuner.run` for one trial (runtime in seconds):

```
    error = 0.4261544346809387, runtime = 54.515828132629395
```
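The incumbent curve is a running minimum of the per-trial errors plotted against cumulative runtime, which is exactly what `bookkeeping` accumulates in `incumbent_trajectory` and `cumulative_runtime`. The small sketch below recomputes both from made-up errors and runtimes (illustrative numbers, not results from this deck); the helper name `incumbent_curve` is hypothetical.

```python
from itertools import accumulate


def incumbent_curve(errors, runtimes):
    """Return (cumulative runtime, best error so far) after each trial."""
    cumulative_runtime = list(accumulate(runtimes))       # x-axis
    incumbent_trajectory = list(accumulate(errors, min))  # y-axis: running minimum
    return cumulative_runtime, incumbent_trajectory


# Made-up values for five trials
errors = [0.42, 0.55, 0.31, 0.35, 0.28]
runtimes = [50.0, 48.0, 55.0, 52.0, 49.0]
times, best = incumbent_curve(errors, runtimes)
# times -> [50.0, 98.0, 153.0, 205.0, 254.0]
# best  -> [0.42, 0.42, 0.31, 0.31, 0.28]
```

Trials 2 and 4 leave the curve flat because they do not beat the incumbent; trials 3 and 5 produce the downward steps. Plotting `best` against `times` as a step function (for example with matplotlib's `plt.step(times, best, where="post")`) reproduces the shape described above.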