@d2l.add_to_class(SuccessiveHalvingScheduler)
def update(self, config: dict, error: float, info=None):
    """Record one finished evaluation and promote configurations when a rung fills.

    The rung of the result is taken from ``config["max_epochs"]``.

    Every result is forwarded to ``self.searcher.update`` and logged in
    ``self.all_observed_error_at_rungs``. A result whose rung is below
    ``self.r_max`` is also collected in ``self.observed_error_at_rungs``.

    Once that rung has gathered ``int(prefact * eta ** (K - position))``
    results, promotion happens:

    - The configurations returned by ``get_top_n_configurations`` are
      copied with ``max_epochs`` set to the next rung level.
    - They are placed in front of ``self.queue``.
    - The rung's collected results are cleared.

    Args:
        config: The evaluated configuration; must contain ``"max_epochs"``.
        error: The error observed for ``config`` at that rung.
        info: Optional extra data, passed to the searcher as
            ``additional_info``.
    """
    rung = int(config["max_epochs"])  # r_i, the rung this result belongs to

    # Hand the observation to the searcher; a model-based searcher
    # (e.g. Bayesian optimization) can learn from it.
    self.searcher.update(config, error, additional_info=info)
    self.all_observed_error_at_rungs[rung].append((config, error))

    # Results at or beyond r_max are never considered for promotion.
    if not rung < self.r_max:
        return

    rung_results = self.observed_error_at_rungs[rung]
    rung_results.append((config, error))

    # Position of this rung; `.index` fails if `rung` is not a rung level.
    level = self.rung_levels.index(rung)
    steps_left = self.K - level  # k_i
    # Number of results this rung must collect before anything is promoted.
    budget = int(self.prefact * self.eta ** steps_left)  # n_i
    if len(rung_results) < budget:
        return  # rung not complete yet

    n_promote = int(self.prefact * self.eta ** (steps_left - 1))  # n_{i+1}
    winners = self.get_top_n_configurations(rung_level=rung, n=n_promote)
    next_rung = self.rung_levels[level + 1]  # r_{i+1}

    # The queue may still hold pending work; promoted configs go first.
    # `dict(...)` copies each winner so the originals are not mutated.
    promoted = [dict(winner, max_epochs=next_rung) for winner in winners]
    self.queue = promoted + self.queue

    # Start collecting results for this rung afresh.
    self.observed_error_at_rungs[rung] = []