Successive halving — algorithm

Multi-Fidelity Hyperparameter Optimization

Multi-Fidelity HPO

Random search trains every configuration to convergence, so most of its compute goes to configurations that turn out to be bad. Yet after a few epochs it is often clear that a config is hopeless: its training loss curve barely moves.

Multi-fidelity HPO uses cheap, low-fidelity evaluations (a few epochs, a small subset of the data) to prune bad configs early and spends the saved compute on the promising ones.

Learning curves of random configs: most diverge fast; a few rise to the top.

The classic algorithm is successive halving: run N configs for a small budget, keep the top 1/\eta, train those \eta \times longer, and repeat.

Successive halving: prune bad configs at every rung and give the survivors \eta times the budget (double for \eta = 2).
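The loop described above fits in a few lines. The sketch below runs one successive-halving bracket without any HPO library; `toy_error` is a hypothetical stand-in for "train this config for `epochs` epochs and report validation error", and every config is re-evaluated from scratch at each budget. At least one config always survives a rung.

```python
import math
import random


def toy_error(config, epochs):
    """Hypothetical validation error: a config-dependent floor plus a 1/epochs term.

    The floor is smallest for learning rates near 0.1; the 1/epochs term mimics a
    learning curve that keeps improving with more training.
    """
    floor = abs(math.log10(config["learning_rate"]) + 1.0)
    return floor + 1.0 / epochs


def successive_halving(configs, evaluate, r_min, eta, num_rungs):
    """Run one bracket of successive halving (num_rungs >= 1); return the best survivor."""
    survivors = list(configs)
    budget = r_min
    for rung in range(num_rungs):
        # Evaluate every survivor at the current budget, best (lowest error) first.
        ranked = sorted(survivors, key=lambda c: evaluate(c, budget))
        if rung == num_rungs - 1:
            break
        # Keep the top 1/eta (at least one config) and give them eta times the budget.
        survivors = ranked[: max(1, len(ranked) // eta)]
        budget *= eta
    return ranked[0]


rng = random.Random(0)
# 27 configs with learning rates drawn log-uniformly from [1e-2, 1].
configs = [{"learning_rate": 10 ** rng.uniform(-2, 0)} for _ in range(27)]
# Rungs: 27 configs at 1 epoch, 9 at 3, 3 at 9, 1 at 27.
best = successive_halving(configs, toy_error, r_min=1, eta=3, num_rungs=4)
print(best)
```

Because `toy_error` ranks configs the same way at every budget, this bracket is guaranteed to return the best of the 27. With real learning curves that cross, an early rung can prune a config that would have won in the end; that risk is the price of the speedup.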

Setup

from d2l import torch as d2l
import numpy as np
from scipy import stats
from collections import defaultdict
d2l.set_figsize()

Discretize the budget into rungs r_0 < r_1 < \ldots < r_K, with r_0 = r_{\min} and r_K = r_{\max}. At rung i:

  1. Train all surviving configs for r_i epochs.
  2. Sort by validation error.
  3. Keep the top 1/\eta. Discard the rest.

Move to the next rung with \eta times the budget and 1/\eta as many configs, so the total compute per rung stays roughly constant.

If n_i \approx n_0/\eta^i and r_i=r_0\eta^i, then n_i r_i \approx n_0 r_0: later rungs spend the same budget on fewer, better configurations.
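A quick numeric check of this identity, with values chosen only for illustration: n_0 = 81, r_0 = 1, \eta = 3 and five rungs. Picking n_0 as a power of \eta keeps every n_i an exact integer.

```python
def rung_schedule(n0, r0, eta, num_rungs):
    """Return (n_i, r_i) pairs with n_i = n0 / eta**i configs and r_i = r0 * eta**i epochs."""
    return [(n0 // eta**i, r0 * eta**i) for i in range(num_rungs)]


schedule = rung_schedule(n0=81, r0=1, eta=3, num_rungs=5)
for n_i, r_i in schedule:
    print(n_i, r_i, n_i * r_i)  # the last column is 81 on every rung
```

Every rung costs 81 epochs, so the five rungs together cost 5 × 81 = 405 epochs, while the single survivor is trained 81 times longer than the configs in the first rung.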

Implementation

class SuccessiveHalvingScheduler(d2l.HPOScheduler):
    def __init__(self, searcher, eta, r_min, r_max, prefact=1):
        self.save_hyperparameters()
        # Compute K, which is later used to determine the number of configurations
        self.K = int(np.log(r_max / r_min) / np.log(eta))
        # Define the rungs
        self.rung_levels = [r_min * eta ** k for k in range(self.K + 1)]
        if r_max not in self.rung_levels:
            # The final rung should be r_max
            self.rung_levels.append(r_max)
            self.K += 1
        # Bookkeeping
        self.observed_error_at_rungs = defaultdict(list)
        self.all_observed_error_at_rungs = defaultdict(list)
        # Our processing queue
        self.queue = []

@d2l.add_to_class(SuccessiveHalvingScheduler)
def suggest(self):
    if len(self.queue) == 0:
        # Start a new round of successive halving
        # Number of configurations for the first rung:
        n0 = int(self.prefact * self.eta ** self.K)
        for _ in range(n0):
            config = self.searcher.sample_configuration()
            config["max_epochs"] = self.r_min  # Set r = r_min
            self.queue.append(config)
    # Return an element from the queue
    return self.queue.pop()

Implementation (cont.)

@d2l.add_to_class(SuccessiveHalvingScheduler)
def update(self, config: dict, error: float, info=None):
    ri = int(config["max_epochs"])  # Rung r_i
    # Update our searcher, e.g., if we use Bayesian optimization later
    self.searcher.update(config, error, additional_info=info)
    self.all_observed_error_at_rungs[ri].append((config, error))
    if ri < self.r_max:
        # Bookkeeping
        self.observed_error_at_rungs[ri].append((config, error))
        # Determine how many configurations should be evaluated on this rung
        ki = self.K - self.rung_levels.index(ri)
        ni = int(self.prefact * self.eta ** ki)
        # Once all n_i configurations on rung r_i have been observed, take the
        # top 1 / eta of them and queue them for promotion to the next rung
        # r_{i+1}
        if len(self.observed_error_at_rungs[ri]) >= ni:
            kiplus1 = ki - 1
            niplus1 = int(self.prefact * self.eta ** kiplus1)
            best_performing_configurations = self.get_top_n_configurations(
                rung_level=ri, n=niplus1
            )
            riplus1 = self.rung_levels[self.K - kiplus1]  # r_{i+1}
            # Queue may not be empty: insert new entries at the beginning
            self.queue = [
                dict(config, max_epochs=riplus1)
                for config in best_performing_configurations
            ] + self.queue
            self.observed_error_at_rungs[ri] = []  # Reset

@d2l.add_to_class(SuccessiveHalvingScheduler)
def get_top_n_configurations(self, rung_level, n):
    rung = self.observed_error_at_rungs[rung_level]
    if not rung:
        return []
    sorted_rung = sorted(rung, key=lambda x: x[1])
    return [x[0] for x in sorted_rung[:n]]
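To see what `update` computes for the demo settings (r_min = 2, r_max = 10, \eta = 2, prefact = 1), the standalone sketch below replays its index arithmetic without d2l. The function `promotion_table` is hypothetical: it copies the rung construction from `__init__` (using `math.log` instead of `np.log`) and returns, for every rung below r_max, how many results the scheduler waits for, how many configs it promotes, and to which rung.

```python
import math


def promotion_table(r_min, r_max, eta, prefact=1):
    """Replay the index arithmetic of SuccessiveHalvingScheduler.update.

    Returns one tuple per rung below r_max:
    (r_i, n_i results awaited, n_{i+1} configs promoted, r_{i+1}).
    """
    # Same rung construction as __init__.
    K = int(math.log(r_max / r_min) / math.log(eta))
    rung_levels = [r_min * eta**k for k in range(K + 1)]
    if r_max not in rung_levels:
        rung_levels.append(r_max)
        K += 1
    table = []
    for ri in rung_levels[:-1]:  # no promotion happens from the top rung r_max
        ki = K - rung_levels.index(ri)
        ni = int(prefact * eta**ki)
        niplus1 = int(prefact * eta ** (ki - 1))
        riplus1 = rung_levels[K - (ki - 1)]
        table.append((ri, ni, niplus1, riplus1))
    return table


table = promotion_table(r_min=2, r_max=10, eta=2)
for ri, ni, promoted, r_next in table:
    print(f"rung {ri:2d}: wait for {ni} results, promote {promoted} to rung {r_next}")

# Evaluations and epochs in one round: the first rung plus every promotion.
trials_per_round = table[0][1] + sum(promoted for _, _, promoted, _ in table)
epochs_per_round = sum(ri * ni for ri, ni, _, _ in table) + table[-1][2] * table[-1][3]
```

Because 10 is not of the form 2 · 2^k, `__init__` appends r_max as a truncated final rung, so the top rung receives 10 epochs instead of 16. Under these settings one round consists of 15 evaluations (8 + 4 + 2 + 1) costing 58 epochs, assuming each rung finishes before the next starts.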

Running it

min_number_of_epochs = 2
max_number_of_epochs = 10
eta = 2
num_gpus = 1

config_space = {
    "learning_rate": stats.loguniform(1e-2, 1),
    "batch_size": stats.randint(32, 256),
}
initial_config = {
    "learning_rate": 0.1,
    "batch_size": 128,
}
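The searcher's job is to draw configs from `config_space`. The stdlib-only sketch below performs equivalent sampling; `sample_configuration` here is a hypothetical standalone function, not the d2l searcher's implementation. It draws the learning rate log-uniformly from [10^{-2}, 1] and the batch size uniformly from the integers 32 to 255, because scipy's `randint(32, 256)` excludes its upper bound.

```python
import math
import random


def sample_configuration(rng):
    """Draw one config: log-uniform learning rate in [1e-2, 1], batch size in 32..255."""
    # Uniform in log space, then exponentiate: the same law as stats.loguniform(1e-2, 1).
    log_lr = rng.uniform(math.log(1e-2), math.log(1.0))
    return {
        "learning_rate": math.exp(log_lr),
        # random.randrange excludes its upper bound, like stats.randint(32, 256).
        "batch_size": rng.randrange(32, 256),
    }


rng = random.Random(0)
configs = [sample_configuration(rng) for _ in range(8)]
for config in configs:
    config["max_epochs"] = 2  # first rung r_min, as set in suggest()
```

Sampling the learning rate in log space spreads the draws evenly across orders of magnitude, so values near 0.01 are about as likely as values near 0.1 or 1.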

The run evaluates 30 trials across increasing budgets; each promotion of a config to a higher rung counts as a new trial. For the lecture, keep the code visible but suppress the intermediate animation frames; the next slide shows the resulting incumbent trajectory.

Recap

  • Successive halving = aggressive early stopping based on partial-training scores.
  • The reduction factor \eta (typical: 3 or 4; the demo uses 2) controls how aggressively configs are pruned: only the top 1/\eta advance to the next rung.
  • Total budget per rung stays roughly constant; survivors get exponentially more compute.
  • The asynchronous variant (ASHA, next deck) generalizes this to parallel workers without idle time.