Momentum

Why Momentum

SGD performs poorly on ill-conditioned problems. In a steep, narrow valley the gradient points mostly across the valley, so the iterates zigzag between the walls instead of moving along the floor. Lowering \eta stops the overshoot, but then progress along the shallow floor stalls.

The momentum update

Keep an exponentially decaying sum of past gradients, the velocity \mathbf{v}_t:

\mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t,\quad \mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t.

Components that consistently point one way accumulate; components that flip sign cancel. Faster progress along the valley, less zigzag.
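A scalar sketch makes the accumulate-versus-cancel behavior concrete. It is pure Python; `velocity_trace` is a hypothetical helper and both gradient sequences are made up. A gradient that always equals 1 drives the velocity toward 1/(1-\beta). A gradient that flips sign every step keeps the velocity bounded by 1 in magnitude.

```python
def velocity_trace(grads, beta):
    """Return [v_1, ..., v_T] for v_t = beta * v_{t-1} + g_t, starting from v_0 = 0."""
    v, trace = 0.0, []
    for g in grads:
        v = beta * v + g
        trace.append(v)
    return trace

beta = 0.9
# Gradient that always points the same way: contributions pile up.
steady = velocity_trace([1.0] * 100, beta)
# Gradient whose sign flips every step: contributions mostly cancel.
flipping = velocity_trace([(-1.0) ** t for t in range(100)], beta)

print(round(steady[-1], 3))            # close to 1 / (1 - beta)
print(max(abs(v) for v in flipping))   # largest magnitude is the very first step
```

With \beta = 0.9 the steady gradient is amplified about tenfold. The alternating one settles near \pm 1/(1+\beta) \approx 0.53.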

\beta \in [0, 1), typically 0.9. Effective averaging window: 1/(1-\beta) steps.
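To see where the 1/(1-\beta) factor comes from, suppose the gradient is constant, \mathbf{g}_t = \mathbf{g}, and unroll the recursion from \mathbf{v}_0 = 0:

```latex
\mathbf{v}_t = \sum_{i=0}^{t-1} \beta^i \mathbf{g}
             = \frac{1-\beta^t}{1-\beta}\,\mathbf{g}
             \;\xrightarrow{\;t\to\infty\;}\; \frac{\mathbf{g}}{1-\beta},
\qquad
\mathbf{x}_t - \mathbf{x}_{t-1} \;\to\; -\frac{\eta}{1-\beta}\,\mathbf{g}.
```

Along a direction where the gradient is steady, momentum therefore behaves like plain gradient descent with effective step size \eta/(1-\beta). For \beta = 0.9 that is 10\eta.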

The ill-conditioned problem

Anisotropic quadratic f(x_1, x_2) = 0.1 x_1^2 + 2 x_2^2. Its gradient is (0.2 x_1, 4 x_2), so the curvature along x_2 is 20× that along x_1:

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np

eta = 0.4
def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2
def gd_2d(x1, x2, s1, s2):
    return (x1 - eta * 0.2 * x1, x2 - eta * 4 * x2, 0, 0)

d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))

epoch 20, x1: -0.943467, x2: -0.000073

A larger \eta diverges in x_2 before making progress in x_1:

eta = 0.6
d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))

epoch 20, x1: -0.387814, x2: -1673.365109
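Because f is separable, plain gradient descent multiplies each coordinate by a fixed factor 1 - \eta c per step, where c is that coordinate's curvature (0.2 or 4). The sketch below evaluates the closed form x_0 (1 - \eta c)^{20} in pure Python. It assumes the start point (-5, -2), which we take to be what d2l.train_2d uses; this is an assumption, although it reproduces both printed endpoints. The x_2 factor satisfies |1 - 4\eta| < 1 only for \eta < 0.5, which is why \eta = 0.6 diverges.

```python
# Plain gradient descent on f(x1, x2) = 0.1*x1**2 + 2*x2**2.
# Gradient is (0.2*x1, 4*x2), so each step multiplies coordinate k by (1 - eta * c_k).
CURVATURES = (0.2, 4.0)
X0 = (-5.0, -2.0)  # assumed starting point of d2l.train_2d

def gd_closed_form(step_size, steps=20):
    """Iterate after `steps` GD updates, computed in closed form per coordinate."""
    return tuple(x0 * (1 - step_size * c) ** steps for x0, c in zip(X0, CURVATURES))

for step_size in (0.4, 0.6):
    factors = [1 - step_size * c for c in CURVATURES]
    print(step_size, factors, gd_closed_form(step_size))
```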

Momentum on the same problem

Keep \eta = 0.6 and add momentum with \beta = 0.5. Both coordinates now converge: the velocity damps the oscillation in x_2 and accumulates the small, consistent gradient in x_1:

def momentum_2d(x1, x2, v1, v2):
    v1 = beta * v1 + 0.2 * x1
    v2 = beta * v2 + 4 * x2
    return x1 - eta * v1, x2 - eta * v2, v1, v2

eta, beta = 0.6, 0.5
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))

epoch 20, x1: 0.007188, x2: 0.002553

Halving \beta to 0.25 weakens the averaging. After 20 steps the iterates are still noticeably far from the optimum:

eta, beta = 0.6, 0.25
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))

epoch 20, x1: -0.126340, x2: -0.186632
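Why does \beta = 0.25 converge more slowly? Take one coordinate with curvature \lambda, so that f = \tfrac{1}{2}\lambda x^2. Momentum maps (v_{t-1}, x_{t-1}) to (v_t, x_t) through the matrix [[\beta, \lambda], [-\eta\beta, 1 - \eta\lambda]], and the error shrinks roughly like that matrix's spectral radius raised to the power t. The minimal pure-Python sketch below computes this radius for both settings; the helper name `momentum_spectral_radius` is our own.

```python
import cmath

def momentum_spectral_radius(eta, beta, lam):
    """Largest |eigenvalue| of [[beta, lam], [-eta*beta, 1 - eta*lam]],
    the linear map (v_{t-1}, x_{t-1}) -> (v_t, x_t) for momentum on f(x) = lam/2 * x**2."""
    trace = 1 + beta - eta * lam
    det = beta  # beta*(1 - eta*lam) + eta*beta*lam simplifies to beta
    disc = cmath.sqrt(trace ** 2 - 4 * det)  # complex when the eigenvalues are complex
    return max(abs((trace + disc) / 2), abs((trace - disc) / 2))

# f(x1, x2) = 0.1*x1**2 + 2*x2**2 has lam = 0.2 (x1) and lam = 4 (x2).
for momentum in (0.5, 0.25):
    radii = [momentum_spectral_radius(0.6, momentum, lam) for lam in (0.2, 4.0)]
    print(momentum, [round(r, 3) for r in radii])
```

For \beta = 0.5 both coordinates have radius \sqrt{0.5} \approx 0.707. For \beta = 0.25 the radii are about 0.828 and 0.859. Over 20 steps, 0.707^{20} \approx 0.001 while 0.859^{20} \approx 0.05, which is in line with the gap between the two printed endpoints.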

Effective sample weight

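Unrolling the recursion from \mathbf{v}_0 = 0 gives \mathbf{v}_t = \sum_{i=0}^{t-1} \beta^i \mathbf{g}_{t-i}, an exponentially weighted sum of past gradients. The weights sum to (1-\beta^t)/(1-\beta) < 1/(1-\beta), so \mathbf{v}_t is a moving average scaled up by roughly 1/(1-\beta), and its effective horizon is about 1/(1-\beta) steps: \beta = 0.9 → ~10 steps; \beta = 0.99 → ~100 steps. The plot shows the weight \beta^i given to a gradient that is i steps old: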

d2l.set_figsize()
betas = [0.95, 0.9, 0.6, 0]
for beta in betas:
    x = d2l.numpy(d2l.arange(40))
    d2l.plt.plot(x, beta ** x, label=f'beta = {beta:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
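The unrolled form can be checked numerically. With \mathbf{v}_0 = 0, the recursion after t steps equals \sum_{i=0}^{t-1} \beta^i \mathbf{g}_{t-i}, and once t is large the weights \beta^i sum to nearly 1/(1-\beta). The small pure-Python sketch below uses an arbitrary, made-up gradient sequence:

```python
def velocity_recursive(grads, beta):
    """v_t from the recursion v_t = beta * v_{t-1} + g_t with v_0 = 0."""
    v = 0.0
    for g in grads:  # grads holds g_1, ..., g_t in order
        v = beta * v + g
    return v

def velocity_unrolled(grads, beta):
    """v_t = sum_{i=0}^{t-1} beta**i * g_{t-i}; grads[t-1-i] stores g_{t-i}."""
    t = len(grads)
    return sum(beta ** i * grads[t - 1 - i] for i in range(t))

grads = [0.3, -1.2, 2.0, 0.7, -0.4]  # made-up gradients g_1..g_5
print(velocity_recursive(grads, 0.9), velocity_unrolled(grads, 0.9))

# Total weight on the past approaches 1 / (1 - beta) = 10 for beta = 0.9.
total_weight = sum(0.9 ** i for i in range(200))
print(total_weight)
```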

From-scratch implementation

Carry one velocity buffer per parameter, initialized to zero. The update uses the same convention as PyTorch's torch.optim.SGD with momentum and no dampening:

def init_momentum_states(feature_dim):
    v_w = d2l.zeros((feature_dim, 1))
    v_b = d2l.zeros(1)
    return [v_w, v_b]
def sgd_momentum(params, grads, states, hyperparams):
    for i in range(len(params)):
        # v <- beta * v + g; the list is updated in place, so velocities persist across calls
        states[i] = hyperparams['momentum'] * states[i] + grads[i]
        # x <- x - eta * v
        params[i] = params[i] - hyperparams['lr'] * states[i]
    return params[0], params[1]
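Since sgd_momentum only uses indexing and arithmetic, it can be exercised on plain Python floats without the d2l training loop. The minimal sketch below uses a hypothetical two-parameter quadratic (w - 3)^2 + \tfrac{1}{2}(b + 1)^2, whose minimum is at w = 3, b = -1. `toy_grads` is our own helper, not part of d2l.

```python
def sgd_momentum(params, grads, states, hyperparams):
    for i in range(len(params)):
        states[i] = hyperparams['momentum'] * states[i] + grads[i]
        params[i] = params[i] - hyperparams['lr'] * states[i]
    return params[0], params[1]

def toy_grads(w, b):
    """Gradient of the hypothetical loss (w - 3)**2 + 0.5 * (b + 1)**2."""
    return [2.0 * (w - 3.0), b + 1.0]

params = [0.0, 0.0]   # w, b
states = [0.0, 0.0]   # one zero-initialized velocity per parameter
hyper = {'lr': 0.1, 'momentum': 0.9}
for _ in range(300):
    # params and states are mutated in place, so each call sees the latest values
    w, b = sgd_momentum(params, toy_grads(*params), states, hyper)
print(w, b)  # close to (3, -1)
```

The in-place list updates are what carry the velocity from one call to the next, which is why the function needs no return value for `states`.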

Training: \beta sweep

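Same airfoil regression, now with (\eta, \beta) = (0.02, 0.5), (0.01, 0.9), and (0.005, 0.9). Because the effective step size is \eta/(1-\beta), the learning rate is reduced as \beta grows: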

def train_momentum(lr, momentum, num_epochs=2):
    d2l.train_ch11(sgd_momentum, init_momentum_states(feature_dim),
                   {'lr': lr, 'momentum': momentum}, data_iter,
                   feature_dim, num_epochs)

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
train_momentum(0.02, 0.5)

loss: 0.243, 1.128 sec/epoch

train_momentum(0.01, 0.9)

loss: 0.250, 0.641 sec/epoch

train_momentum(0.005, 0.9)

loss: 0.278, 0.728 sec/epoch
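The effective step sizes in this sweep follow directly from the constant-gradient steady state \eta/(1-\beta). The helper name `effective_step` is ours:

```python
def effective_step(lr, momentum):
    """Steady-state step per unit gradient under a constant gradient: lr / (1 - momentum)."""
    return lr / (1 - momentum)

runs = [(0.02, 0.5), (0.01, 0.9), (0.005, 0.9)]  # (lr, momentum) pairs used above
for lr, momentum in runs:
    print(f'lr={lr}, momentum={momentum}: effective step {effective_step(lr, momentum):.3f}')
```

The \beta = 0.9 runs take effective steps of 0.1 and 0.05. The \beta = 0.5 run takes 0.04.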

Concise: framework SGD with momentum

Frameworks expose momentum as an optimizer argument. In Optax it is the momentum keyword of optax.sgd:

import optax

trainer = optax.sgd
d2l.train_concise_ch11(trainer, {'learning_rate': 0.005, 'momentum': 0.9},
                       data_iter)

loss: 0.245, 0.720 sec/epoch

Theory: scalar quadratic

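For f(x) = \tfrac{1}{2} \lambda x^2 the gradient is \lambda x, so gradient descent gives x_t = (1 - \eta\lambda)^t x_0 and converges exactly when 0 < \eta\lambda < 2. With momentum the recurrence becomes a 2D linear system in (v_t, x_t). The eigenvalues of its update matrix dictate convergence, and momentum widens the range of step sizes that converge. As a baseline, here is the decay factor (1 - \eta\lambda)^t of plain gradient descent for \eta = 0.1 and several \lambda: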

lambdas = [0.1, 1, 10, 19]
eta = 0.1
d2l.set_figsize((6, 4))
for lam in lambdas:
    t = d2l.numpy(d2l.arange(20))
    d2l.plt.plot(t, (1 - eta * lam) ** t, label=f'lambda = {lam:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
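With momentum, substituting \mathbf{g}_t = \lambda x_{t-1} into the update turns the pair (v_t, x_t) into a linear recurrence:

```latex
\begin{bmatrix} v_t \\ x_t \end{bmatrix}
= \underbrace{\begin{bmatrix} \beta & \lambda \\ -\eta\beta & 1 - \eta\lambda \end{bmatrix}}_{\mathbf{R}(\beta, \eta, \lambda)}
  \begin{bmatrix} v_{t-1} \\ x_{t-1} \end{bmatrix},
\qquad
\operatorname{tr}\mathbf{R} = 1 + \beta - \eta\lambda,
\qquad
\det\mathbf{R} = \beta .
```

Both eigenvalues of a real 2×2 matrix lie strictly inside the unit circle exactly when |\det\mathbf{R}| < 1 and |\operatorname{tr}\mathbf{R}| < 1 + \det\mathbf{R}. For 0 \le \beta < 1 the first condition always holds. The second becomes |1 + \beta - \eta\lambda| < 1 + \beta, that is, 0 < \eta\lambda < 2 + 2\beta. Setting \beta = 0 recovers the gradient-descent condition 0 < \eta\lambda < 2, so momentum enlarges the range of stable step sizes.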

Recap

  • \mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t, \mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t.
  • Smooths zigzag from ill-conditioning; effective averaging window 1/(1-\beta).
  • \beta = 0.9 is the practical default; \beta = 0.99 for very noisy gradients.
  • Standard SGD-with-momentum is the workhorse of computer vision; Adam (coming up) generalizes the idea with per-parameter scaling.