Momentum

Why Momentum

Plain SGD struggles on ill-conditioned problems. In a steep, narrow valley the gradient points mostly across the valley, so the iterates zigzag between the walls instead of moving along the floor. Lowering \eta far enough to stop the overshooting also stalls progress along the valley.

The momentum method

Keep an exponentially decaying running sum of past gradients (a leaky average), the velocity \mathbf{v}_t, starting from \mathbf{v}_0 = \mathbf{0}:

\mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t,\quad \mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t.

Components that consistently point one way accumulate; components that flip sign cancel. Faster progress along the valley, less zigzag.

\beta \in [0, 1), typically 0.9. Effective averaging window: 1/(1-\beta) steps.
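
The accumulate-versus-cancel behaviour can be checked on scalar gradient sequences. The following is a minimal sketch in plain Python; the helper name `final_velocity` and the two test sequences are illustrative assumptions, not part of the original code.

```python
def final_velocity(grads, beta):
    """Run v_t = beta * v_{t-1} + g_t from v_0 = 0 and return the last v_t."""
    v = 0.0
    for g in grads:
        v = beta * v + g
    return v


beta = 0.9
steady = [1.0] * 50                           # gradient always points the same way
flipping = [(-1.0) ** t for t in range(50)]   # gradient flips sign every step

v_steady = final_velocity(steady, beta)
v_flipping = final_velocity(flipping, beta)

# The steady sequence builds up to nearly 1 / (1 - beta) = 10,
# while the flipping sequence stays below 1 in magnitude.
print(f"steady:   {v_steady:.3f}")
print(f"flipping: {v_flipping:.3f}")
```

The steady sequence approaches the 1/(1-\beta) limit, while the sign-flipping one largely cancels itself out.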

The ill-conditioned problem

Anisotropic quadratic f(x_1, x_2) = 0.1 x_1^2 + 2 x_2^2. The gradient is (0.2 x_1, 4 x_2), so at equal distance from the minimum the x_2 component is 20× larger than the x_1 component:

%matplotlib inline
from d2l import mxnet as d2l
from mxnet import np, npx
npx.set_np()

eta = 0.4

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

def gd_2d(x1, x2, s1, s2):
    # Gradient of f_2d is (0.2 * x1, 4 * x2); s1, s2 are unused optimizer state.
    return (x1 - eta * 0.2 * x1, x2 - eta * 4 * x2, 0, 0)

d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))

With \eta = 0.6 the x_2 iterate overshoots by more than it corrects on every step and diverges, long before x_1 gets near the minimum:

eta = 0.6
d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))
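
Because this f is separable, each gradient descent step multiplies x_1 by 1 - 0.2\eta and x_2 by 1 - 4\eta, and a coordinate converges only if its factor has magnitude below 1. A small sketch of that check in plain Python (the helper `gd_factors` is a hypothetical name introduced here):

```python
def gd_factors(eta):
    """Per-step multipliers for x1 and x2 under gradient descent on
    f(x1, x2) = 0.1 * x1**2 + 2 * x2**2 (gradients 0.2 * x1 and 4 * x2)."""
    return 1 - eta * 0.2, 1 - eta * 4


for eta in (0.4, 0.6):
    f1, f2 = gd_factors(eta)
    status = "converges" if abs(f2) < 1 else "diverges"
    print(f"eta = {eta}: x1 factor {f1:.2f}, x2 factor {f2:.2f} -> x2 {status}")
```

At \eta = 0.4 the x_2 factor is -0.6 (oscillating but shrinking) while x_1 shrinks by only 8% per step; at \eta = 0.6 the x_2 factor is -1.4, so its amplitude grows.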

Momentum on the same problem

Keep \eta = 0.6 and add momentum with \beta = 0.5. The x_2 oscillations now damp out and the trajectory converges:

def momentum_2d(x1, x2, v1, v2):
    v1 = beta * v1 + 0.2 * x1
    v2 = beta * v2 + 4 * x2
    return x1 - eta * v1, x2 - eta * v2, v1, v2

eta, beta = 0.6, 0.5
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))

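Halving the momentum to \beta = 0.25 weakens the damping: the trajectory barely converges, though it still beats plain gradient descent, which diverges at this \eta: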

eta, beta = 0.6, 0.25
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))
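
To compare the settings numerically rather than visually, the loop below replays the same update in plain Python. It is a sketch under assumptions: the starting point (-5, -2) and the 20 steps are chosen here for illustration and may not match what d2l.train_2d does internally, and `run_momentum` is a hypothetical helper.

```python
def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2


def run_momentum(eta, beta, steps=20, start=(-5.0, -2.0)):
    """Apply v = beta * v + grad, x = x - eta * v and return the final f value."""
    x1, x2 = start
    v1 = v2 = 0.0
    for _ in range(steps):
        v1 = beta * v1 + 0.2 * x1
        v2 = beta * v2 + 4 * x2
        x1, x2 = x1 - eta * v1, x2 - eta * v2
    return f_2d(x1, x2)


for beta in (0.0, 0.25, 0.5):
    print(f"beta = {beta}: f after 20 steps = {run_momentum(0.6, beta):.3g}")
```

With \beta = 0 this reduces to plain gradient descent and the objective grows; both momentum settings end below the starting value of 10.5, with \beta = 0.5 ending lower than \beta = 0.25.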

Effective sample weight

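Unrolling the recurrence with \mathbf{v}_0 = \mathbf{0} gives \mathbf{v}_t = \sum_{i=0}^{t-1} \beta^i \mathbf{g}_{t-i}, an exponentially weighted sum of past gradients. The weights \beta^i add up to at most 1/(1-\beta), which is why 1/(1-\beta) acts as the effective horizon: \beta = 0.9 → ~10 steps; \beta = 0.99 → ~100 steps. The plot shows how fast the weights \beta^i decay for several \beta: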

d2l.set_figsize()
betas = [0.95, 0.9, 0.6, 0]
for beta in betas:
    x = d2l.numpy(d2l.arange(40))
    d2l.plt.plot(x, beta ** x, label=f'beta = {beta:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
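
The horizon 1/(1-\beta) can also be read off numerically: the weights \beta^i sum to (almost) 1/(1-\beta), and the weight at lag 1/(1-\beta) has dropped to roughly 1/e of the newest one. A small sketch in plain Python, with `total_weight` as an illustrative helper name:

```python
def total_weight(beta, num_terms):
    """Sum of the weights beta**i for i = 0 .. num_terms - 1."""
    return sum(beta ** i for i in range(num_terms))


for beta in (0.9, 0.99):
    horizon = 1 / (1 - beta)
    weight_at_horizon = beta ** horizon
    print(f"beta = {beta}: horizon = {horizon:.0f}, "
          f"total weight = {total_weight(beta, 5000):.2f}, "
          f"weight at horizon = {weight_at_horizon:.3f}")
```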

From-scratch implementation

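Carry one velocity buffer per parameter, initialized to zero. The update accumulates raw gradients into the velocity and applies the learning rate at the parameter step, which is the same form PyTorch's SGD uses for its momentum buffer: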

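def init_momentum_states(feature_dim):
    # One zero-initialized velocity buffer per parameter (weights, bias).
    v_w = d2l.zeros((feature_dim, 1))
    v_b = d2l.zeros(1)
    return (v_w, v_b)

def sgd_momentum(params, states, hyperparams):
    for p, v in zip(params, states):
        # In-place updates so the caller's buffers and parameters change.
        v[:] = hyperparams['momentum'] * v + p.grad
        p[:] -= hyperparams['lr'] * v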
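
Some references write momentum with the learning rate folded into the velocity: u_t = \beta u_{t-1} - \eta g_t, x_t = x_{t-1} + u_t. For a constant learning rate this traces the same trajectory as the form above, since u_t = -\eta v_t. The sketch below checks that on a toy quadratic; the function names and the quadratic itself are assumptions made for the example.

```python
def grad(x, lam=2.0):
    """Gradient of the toy objective 0.5 * lam * x**2."""
    return lam * x


def run_lr_at_step(x, lr, beta, steps):
    """v = beta * v + g; x -= lr * v (learning rate applied at the step)."""
    v, path = 0.0, []
    for _ in range(steps):
        v = beta * v + grad(x)
        x -= lr * v
        path.append(x)
    return path


def run_lr_in_velocity(x, lr, beta, steps):
    """u = beta * u - lr * g; x += u (learning rate folded into the velocity)."""
    u, path = 0.0, []
    for _ in range(steps):
        u = beta * u - lr * grad(x)
        x += u
        path.append(x)
    return path


a = run_lr_at_step(1.0, lr=0.1, beta=0.9, steps=30)
b = run_lr_in_velocity(1.0, lr=0.1, beta=0.9, steps=30)
print(max(abs(p - q) for p, q in zip(a, b)))  # difference is at floating-point level
```

The two forms stop agreeing once the learning rate changes during training, because the second one bakes earlier learning rates into the velocity.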

Training: \beta sweep

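Same airfoil regression, three settings: (\eta, \beta) = (0.02, 0.5), (0.01, 0.9), (0.005, 0.9). Raising \beta to 0.9 enlarges the effective step, so the learning rate is lowered to compensate: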

def train_momentum(lr, momentum, num_epochs=2):
    d2l.train_ch11(sgd_momentum, init_momentum_states(feature_dim),
                   {'lr': lr, 'momentum': momentum}, data_iter,
                   feature_dim, num_epochs)

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
train_momentum(0.02, 0.5)
train_momentum(0.01, 0.9)
train_momentum(0.005, 0.9)
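
One way to see why the learning rate drops as \beta rises: under a constant gradient g the velocity settles at g/(1-\beta), so the step per iteration approaches \eta g/(1-\beta). The sketch below tabulates this effective step size \eta/(1-\beta) for the three runs; `effective_lr` is a name introduced here for illustration.

```python
def effective_lr(lr, beta):
    """Steady-state step size per unit of (constant) gradient."""
    return lr / (1 - beta)


for lr, beta in [(0.02, 0.5), (0.01, 0.9), (0.005, 0.9)]:
    print(f"lr = {lr}, beta = {beta}: effective step = {effective_lr(lr, beta):.3f}")
```

The values are 0.04, 0.1 and 0.05, so the third run brings the effective step back near that of the first.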

Concise: framework SGD with momentum

Most frameworks expose momentum as an argument of their built-in SGD optimizer, so enabling it is a one-line change:

d2l.train_concise_ch11('sgd', {'learning_rate': 0.005, 'momentum': 0.9},
                       data_iter)

Theory: scalar quadratic

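For f(x) = \tfrac{1}{2} \lambda x^2 the gradient is \lambda x, so the momentum recurrence is a 2D linear system in (v_t, x_t):

\begin{bmatrix} v_{t+1} \\ x_{t+1} \end{bmatrix} = \begin{bmatrix} \beta & \lambda \\ -\eta\beta & 1 - \eta\lambda \end{bmatrix} \begin{bmatrix} v_t \\ x_t \end{bmatrix}.

The eigenvalues of the update matrix dictate convergence. The iteration is stable for 0 < \eta\lambda < 2 + 2\beta, a wider range than the 0 < \eta\lambda < 2 of plain gradient descent, whose update is x_{t+1} = (1 - \eta\lambda) x_t. The plot shows that gradient descent baseline, (1 - \eta\lambda)^t, for \eta = 0.1 and several \lambda; at \lambda = 19 the factor is -0.9, so the iterate flips sign every step while shrinking slowly: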

lambdas = [0.1, 1, 10, 19]
eta = 0.1
d2l.set_figsize((6, 4))
for lam in lambdas:
    t = d2l.numpy(d2l.arange(20))
    d2l.plt.plot(t, (1 - eta * lam) ** t, label=f'lambda = {lam:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
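
The effect of momentum on this scalar problem can be checked by computing the eigenvalues of the 2×2 update matrix [[\beta, \lambda], [-\eta\beta, 1 - \eta\lambda]] that maps (v_t, x_t) to (v_{t+1}, x_{t+1}). This is a sketch in plain Python; `spectral_radius` and the choice \lambda = 25 are assumptions made for illustration.

```python
import cmath


def spectral_radius(beta, eta, lam):
    """Largest eigenvalue magnitude of [[beta, lam], [-eta*beta, 1 - eta*lam]]."""
    trace = beta + 1 - eta * lam
    det = beta * (1 - eta * lam) + eta * beta * lam  # simplifies to beta
    root = cmath.sqrt(trace ** 2 - 4 * det)
    return max(abs((trace + root) / 2), abs((trace - root) / 2))


eta, lam = 0.1, 25  # eta * lam = 2.5, outside the gradient descent range (0, 2)
for beta in (0.0, 0.5):
    rho = spectral_radius(beta, eta, lam)
    print(f"beta = {beta}: spectral radius = {rho:.3f} ->",
          "converges" if rho < 1 else "diverges")
```

With \beta = 0 the matrix reduces to plain gradient descent and the radius is |1 - \eta\lambda| = 1.5; with \beta = 0.5 the eigenvalues become a complex pair of magnitude \sqrt{\beta} ≈ 0.707, so the same \eta now converges.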

Recap

  • \mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t, \mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t.
  • Smooths zigzag from ill-conditioning; effective averaging window 1/(1-\beta).
  • \beta = 0.9 is the practical default; \beta = 0.99 for very noisy gradients.
  • Standard SGD-with-momentum is the workhorse of computer vision; Adam (coming up) generalizes the idea with per-parameter scaling.