Training

Adam

Adam (Kingma & Ba, 2014) combines the two best ideas of the chapter:

  • Momentum — first moment EMA \mathbf{v}_t = \beta_1 \mathbf{v}_{t-1} + (1-\beta_1)\mathbf{g}_t.
  • RMSProp scaling — second moment EMA \mathbf{s}_t = \beta_2 \mathbf{s}_{t-1} + (1-\beta_2)\mathbf{g}_t^2.

Adam has been the default optimizer for deep learning since around 2015. Variants (AdamW, RAdam, NAdam, Lion) iterate on the recipe.

Bias correction + update

Both EMAs are initialized at zero, so they are biased toward zero during the first steps. Dividing by 1-\beta^t removes that bias:

\hat{\mathbf{v}}_t = \mathbf{v}_t / (1-\beta_1^t),\quad \hat{\mathbf{s}}_t = \mathbf{s}_t / (1-\beta_2^t).
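To see why the division helps, here is a minimal sketch in plain Python (no PyTorch) that feeds a constant gradient of 1 into a zero-initialized EMA. The helper name `ema_with_bias_correction` is hypothetical and used only for illustration.

```python
def ema_with_bias_correction(grads, beta):
    """Return the raw and bias-corrected EMA trajectories for scalar inputs."""
    v = 0.0  # EMA state starts at zero, as in Adam
    raw, corrected = [], []
    for t, g in enumerate(grads, start=1):
        v = beta * v + (1 - beta) * g
        raw.append(v)
        corrected.append(v / (1 - beta ** t))  # divide out the startup bias
    return raw, corrected


raw, corrected = ema_with_bias_correction([1.0] * 5, beta=0.9)
for t, (r, c) in enumerate(zip(raw, corrected), start=1):
    print(f"t={t}  raw={r:.4f}  corrected={c:.4f}")
```

For a constant input g the raw EMA equals (1-\beta^t) g, so it starts at 0.1 and climbs slowly, while the corrected value recovers g = 1 at every step (up to floating-point error).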

Update:

\mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \frac{\eta}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon} \odot \hat{\mathbf{v}}_t.

Defaults that just work most of the time: \beta_1 = 0.9, \beta_2 = 0.999, \epsilon = 10^{-8}.
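One consequence of the update rule is that the step size is largely independent of the gradient's scale. The sketch below applies a single scalar Adam step from zero state for gradients spanning six orders of magnitude. The function `adam_step` is a hypothetical helper written for this illustration, not part of d2l or PyTorch.

```python
import math


def adam_step(x, g, v, s, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam step; returns the new parameter and both EMA states."""
    v = beta1 * v + (1 - beta1) * g          # first-moment EMA
    s = beta2 * s + (1 - beta2) * g * g      # second-moment EMA
    v_hat = v / (1 - beta1 ** t)             # bias-corrected first moment
    s_hat = s / (1 - beta2 ** t)             # bias-corrected second moment
    x = x - lr * v_hat / (math.sqrt(s_hat) + eps)
    return x, v, s


steps = []
for g in (1e-3, 1.0, 1e3):
    x_new, _, _ = adam_step(x=0.0, g=g, v=0.0, s=0.0, t=1)
    steps.append(-x_new)  # distance moved against the gradient
    print(f"g={g:g}  step={-x_new:.6f}")
```

At t = 1 the corrected moments are \hat{v} = g and \hat{s} = g^2, so the step is \eta g / (|g| + \epsilon), which is approximately \eta for all three gradients.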

From-scratch Adam

Two state buffers per parameter (v, s); track the step counter for bias correction:

%matplotlib inline
from d2l import torch as d2l
import torch

def init_adam_states(feature_dim):
    v_w, v_b = d2l.zeros((feature_dim, 1)), d2l.zeros(1)
    s_w, s_b = d2l.zeros((feature_dim, 1)), d2l.zeros(1)
    return ((v_w, s_w), (v_b, s_b))

def adam(params, states, hyperparams):
    # eps is 1e-6 here, larger than the 1e-8 default quoted above
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    t, lr = hyperparams['t'], hyperparams['lr']
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            # In-place EMAs of the gradient and the squared gradient
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)
            # Undo the bias caused by initializing v and s at zero
            v_bias_corr = v / (1 - beta1 ** t)
            s_bias_corr = s / (1 - beta2 ** t)
            p[:] -= lr * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)
        p.grad.zero_()  # clear the gradient for the next minibatch
    hyperparams['t'] += 1  # advance the step counter used for bias correction

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);

loss: 0.242, 0.095 sec/epoch

The key comparison is not syntax but behavior: Adam reaches a useful loss with little learning-rate tuning because the step is both smoothed and coordinate-scaled.
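The effect of coordinate scaling can be illustrated on a toy problem that is not part of the chapter's experiments: a separable quadratic f(x) = \frac{1}{2}\sum_i a_i x_i^2 with very different curvatures a_i. This is a plain-Python sketch. The helpers `adam_minimize` and `gd_minimize` and the curvature values are assumptions chosen for illustration.

```python
import math


def adam_minimize(curvatures, x0, lr, steps, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run Adam on f(x) = 0.5 * sum(a_i * x_i**2), one scalar state per coordinate."""
    x = list(x0)
    v = [0.0] * len(x)
    s = [0.0] * len(x)
    for t in range(1, steps + 1):
        for i, a in enumerate(curvatures):
            g = a * x[i]  # gradient of the i-th term
            v[i] = beta1 * v[i] + (1 - beta1) * g
            s[i] = beta2 * s[i] + (1 - beta2) * g * g
            v_hat = v[i] / (1 - beta1 ** t)
            s_hat = s[i] / (1 - beta2 ** t)
            x[i] -= lr * v_hat / (math.sqrt(s_hat) + eps)
    return x


def gd_minimize(curvatures, x0, lr, steps):
    """Plain gradient descent on the same quadratic, for comparison."""
    x = list(x0)
    for _ in range(steps):
        for i, a in enumerate(curvatures):
            x[i] -= lr * a * x[i]
    return x


curvatures = (50.0, 0.01)  # steep and nearly flat directions
x_adam = adam_minimize(curvatures, (1.0, 1.0), lr=0.01, steps=200)
x_gd = gd_minimize(curvatures, (1.0, 1.0), lr=0.01, steps=200)
print("Adam:", x_adam)
print("GD:  ", x_gd)
```

With the same learning rate, gradient descent barely moves the flat coordinate (it shrinks by a factor of 1 - 10^{-4} per step), whereas Adam moves both coordinates along nearly identical trajectories because each gradient is divided by its own running magnitude.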

Concise: framework Adam

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer, {'lr': 0.01}, data_iter)

loss: 0.245, 0.075 sec/epoch

Yogi: a robustness fix

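Adam's second-moment update can be rewritten as \mathbf{s}_t = \mathbf{s}_{t-1} + (1-\beta_2)(\mathbf{g}_t^2 - \mathbf{s}_{t-1}), so the size of each change depends on how far \mathbf{g}_t^2 is from \mathbf{s}_{t-1}. Yogi (Zaheer et al., 2018) keeps only the sign of that difference, so each change has magnitude (1-\beta_2)\mathbf{g}_t^2 regardless of the gap. This is more stable for sparse or heavy-tailed gradient distributions:

\mathbf{s}_t = \mathbf{s}_{t-1} + (1-\beta_2)\, \text{sign}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) \odot \mathbf{g}_t^2.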

def yogi(params, states, hyperparams):
    # Yogi uses a much larger eps (1e-3) than the Adam code above
    beta1, beta2, eps = 0.9, 0.999, 1e-3
    t, lr = hyperparams['t'], hyperparams['lr']
    for p, (v, s) in zip(params, states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            # Sign-based second-moment update: the only change from Adam
            s[:] = s + (1 - beta2) * torch.sign(
                torch.square(p.grad) - s) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** t)
            s_bias_corr = s / (1 - beta2 ** t)
            p[:] -= lr * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)
        p.grad.zero_()  # clear the gradient for the next minibatch
    hyperparams['t'] += 1  # advance the step counter used for bias correction

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);

loss: 0.244, 0.077 sec/epoch
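The final losses are nearly identical on this problem, so the difference between the two second-moment rules is easier to see in isolation. The sketch below is a plain-Python scalar illustration with hypothetical helper names and an assumed scenario: s has reached 1.0 after a burst of large gradients, and is then followed through 1000 steps of small gradients (g = 0.1).

```python
def adam_second_moment(s, g, beta2=0.999):
    """Adam's EMA update of the squared gradient."""
    return beta2 * s + (1 - beta2) * g * g


def yogi_second_moment(s, g, beta2=0.999):
    """Yogi's sign-based update: the change always has magnitude (1 - beta2) * g**2."""
    g2 = g * g
    sign = (g2 > s) - (g2 < s)  # +1, 0, or -1
    return s + (1 - beta2) * sign * g2


s_adam = s_yogi = 1.0  # state left behind by earlier large gradients
for _ in range(1000):
    s_adam = adam_second_moment(s_adam, 0.1)
    s_yogi = yogi_second_moment(s_yogi, 0.1)

print(f"Adam s after 1000 small steps: {s_adam:.4f}")
print(f"Yogi s after 1000 small steps: {s_yogi:.4f}")
```

Adam's estimate decays geometrically toward g^2 = 0.01 and loses most of its memory of the burst, while Yogi's shrinks by only 10^{-5} per step and stays near 0.99. Under this assumed scenario, Yogi keeps the effective step size damped for longer after rare large gradients.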

Recap

  • Adam = momentum + RMSProp + bias correction. The default for transformers / language models.
  • Defaults \beta_1 = 0.9, \beta_2 = 0.999, \epsilon = 10^{-8} rarely need tuning.
  • Failure modes: instability with rare-but-large gradients (Yogi addresses this); generalization gap vs. SGD on certain vision tasks (which motivated AdamW and its decoupled weight decay).