Why Momentum
SGD on ill-conditioned problems is dreadful. In a steep narrow valley, gradients zigzag across the walls instead of moving along the floor. Drop \eta to stop overshooting → progress along the valley dies.
Momentum
Accumulate past gradients into a velocity \mathbf{v}_t , a geometrically decaying (leaky) sum of everything seen so far:
\mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t,\quad
\mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t.
Components that consistently point one way accumulate; components that flip sign cancel. Faster progress along the valley, less zigzag.
\beta \in [0, 1) , typically 0.9 . Effective averaging window: 1/(1-\beta) steps.
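The accumulate-versus-cancel behaviour can be checked on scalar gradients. The following is a minimal sketch, not part of the original notes: `run_velocity` is a hypothetical helper, and the two gradient sequences (constant sign, alternating sign) are illustrative assumptions.

```python
def run_velocity(grads, beta):
    """Apply v <- beta * v + g over a gradient sequence, starting from v = 0."""
    v = 0.0
    for g in grads:
        v = beta * v + g
    return v

beta = 0.9
steps = 200
consistent = [1.0] * steps                      # same sign at every step
flipping = [(-1.0) ** t for t in range(steps)]  # sign flips at every step

v_consistent = run_velocity(consistent, beta)
v_flipping = run_velocity(flipping, beta)

# Geometric series: consistent gradients approach 1 / (1 - beta),
# alternating gradients approach magnitude 1 / (1 + beta).
print(f"consistent: {v_consistent:.4f}  (limit {1 / (1 - beta):.4f})")
print(f"flipping:   {abs(v_flipping):.4f}  (limit {1 / (1 + beta):.4f})")
```

With \beta = 0.9 the consistent component is amplified roughly tenfold while the flipping one is roughly halved, which is the mechanism behind the reduced zigzag.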
The ill-conditioned problem
Anisotropic quadratic f(x_1, x_2) = 0.1 x_1^2 + 2 x_2^2 . The gradient is (0.2 x_1, 4 x_2) , so the curvature along x_2 is 20× that along x_1 :
%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np

eta = 0.4

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

def gd_2d(x1, x2, s1, s2):
    # Plain gradient descent; the two state slots (s1, s2) are unused
    return (x1 - eta * 0.2 * x1, x2 - eta * 4 * x2, 0, 0)

d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))
epoch 20, x1: -0.943467, x2: -0.000073
A larger \eta diverges in x_2 before making progress in x_1 :
eta = 0.6
d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))
epoch 20, x1: -0.387814, x2: -1673.365109
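Both runs can be explained per coordinate: on a quadratic, each gradient descent step multiplies a coordinate by (1 - \eta \cdot curvature), so 20 steps give a closed form. The sketch below assumes the trajectory starts at (-5, -2), which is not stated in these notes; `gd_closed_form` is a hypothetical helper.

```python
def gd_closed_form(x0, curvature, eta, steps):
    """x_t = (1 - eta * curvature)^t * x0 for f(x) = 0.5 * curvature * x^2."""
    return (1 - eta * curvature) ** steps * x0

x1_0, x2_0 = -5.0, -2.0   # assumed starting point
steps = 20
results = {}
for eta in (0.4, 0.6):
    # Curvatures of 0.1 * x1^2 and 2 * x2^2 are 0.2 and 4
    x1 = gd_closed_form(x1_0, 0.2, eta, steps)
    x2 = gd_closed_form(x2_0, 4.0, eta, steps)
    results[eta] = (x1, x2)
    print(f"eta={eta}: factors ({1 - eta * 0.2:.2f}, {1 - eta * 4:.2f}) "
          f"-> x1={x1:.6f}, x2={x2:.6f}")
```

At \eta = 0.6 the x_2 factor is -1.4, so that coordinate grows in magnitude at every step, while the x_1 factor of 0.88 still shrinks slowly.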
Momentum on the same problem
Keep the diverging \eta = 0.6 and add momentum with \beta = 0.5 . The trajectory now converges:
def momentum_2d(x1, x2, v1, v2):
    # Velocity update, then parameter step; eta and beta are globals set below
    v1 = beta * v1 + 0.2 * x1
    v2 = beta * v2 + 4 * x2
    return x1 - eta * v1, x2 - eta * v2, v1, v2

eta, beta = 0.6, 0.5
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))
epoch 20, x1: 0.007188, x2: 0.002553
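Halving \beta to 0.25 weakens the averaging: the trajectory barely converges within 20 steps, though it still beats plain gradient descent, which diverged at this \eta :
eta, beta = 0.6, 0.25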
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))
epoch 20, x1: -0.126340, x2: -0.186632
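The same comparison can be run without d2l. This is a rough sketch under stated assumptions: the starting point (-5, -2), zero initial velocity, and 20 steps are assumed rather than taken from these notes, and `momentum_step` and `run` are hypothetical names.

```python
def momentum_step(x, v, curvature, eta, beta):
    """One momentum update for a coordinate of f(x) = 0.5 * curvature * x^2."""
    v = beta * v + curvature * x   # gradient is curvature * x
    return x - eta * v, v

def run(eta, beta, steps=20, start=(-5.0, -2.0)):
    x1, x2 = start
    v1 = v2 = 0.0
    for _ in range(steps):
        x1, v1 = momentum_step(x1, v1, 0.2, eta, beta)
        x2, v2 = momentum_step(x2, v2, 4.0, eta, beta)
    return x1, x2

# beta = 0 recovers plain gradient descent
finals = {beta: run(0.6, beta) for beta in (0.0, 0.25, 0.5)}
for beta, (x1, x2) in finals.items():
    print(f"beta={beta}: x1={x1:.6f}, x2={x2:.6f}")
```

If the assumed starting point matches the one used for the plots, the final values for \beta = 0.5 and \beta = 0.25 should be close to the outputs shown above.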
Effective sample weight
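Unrolling the recurrence from \mathbf{v}_0 = 0 gives \mathbf{v}_t = \sum_{i=0}^{t-1} \beta^i \mathbf{g}_{t-i} , an exponentially weighted sum of past gradients. The weights sum to 1/(1-\beta) in the limit, so the effective horizon is about 1/(1-\beta) steps and the effective step size is about \eta/(1-\beta) . \beta = 0.9 → ~10 steps; \beta = 0.99 → ~100 steps.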
d2l.set_figsize()
betas = [0.95, 0.9, 0.6, 0]
for beta in betas:
    x = d2l.numpy(d2l.arange(40))
    d2l.plt.plot(x, beta ** x, label=f'beta = {beta:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
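The horizon figure can be checked numerically by summing the weights \beta^i . The sketch below is illustrative only; `ewm_weights` is a hypothetical helper and the truncation at 5000 terms is an arbitrary choice that makes the tail negligible.

```python
def ewm_weights(beta, n):
    """Weights beta^0, ..., beta^(n-1) applied to g_t, g_{t-1}, ..."""
    return [beta ** i for i in range(n)]

summary = {}
for beta in (0.9, 0.99):
    horizon = round(1 / (1 - beta))    # 10 and 100 steps
    w = ewm_weights(beta, 5000)        # long enough that the tail is negligible
    total = sum(w)                     # approaches 1 / (1 - beta)
    inside = sum(w[:horizon]) / total  # share of weight within the horizon
    summary[beta] = (horizon, total, inside)
    print(f"beta={beta}: horizon={horizon}, total weight={total:.3f}, "
          f"share inside horizon={inside:.3f}")
```

Roughly 63 to 65 percent of the total weight falls inside the 1/(1-\beta) window (the remainder is \beta^{horizon} , about 1/e), so the "window" is a scale rather than a hard cutoff.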
From-scratch implementation
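Carry a velocity buffer per parameter. The update follows the common SGD-with-momentum convention (as in PyTorch with zero dampening): the velocity accumulates raw gradients, with no (1 - \beta) factor: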
def init_momentum_states(feature_dim):
    # One velocity buffer per parameter (weights and bias), initialized to zero
    v_w = d2l.zeros((feature_dim, 1))
    v_b = d2l.zeros(1)
    return [v_w, v_b]

def sgd_momentum(params, grads, states, hyperparams):
    for i in range(len(params)):
        # v <- beta * v + g, then x <- x - lr * v
        states[i] = hyperparams['momentum'] * states[i] + grads[i]
        params[i] = params[i] - hyperparams['lr'] * states[i]
    return params[0], params[1]
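The choice of convention matters when comparing learning rates across sources. An alternative normalized form, \mathbf{v}_t = \beta \mathbf{v}_{t-1} + (1 - \beta) \mathbf{g}_t , produces the same iterates if its learning rate is scaled by 1/(1-\beta) . The sketch below checks this on a scalar quadratic; the function names, the test function, and the hyperparameters are illustrative assumptions.

```python
def run_standard(grad_fn, x, lr, beta, steps):
    """Convention used above: v <- beta * v + g."""
    v = 0.0
    for _ in range(steps):
        v = beta * v + grad_fn(x)
        x = x - lr * v
    return x

def run_normalized(grad_fn, x, lr, beta, steps):
    """Alternative convention: v <- beta * v + (1 - beta) * g."""
    v = 0.0
    for _ in range(steps):
        v = beta * v + (1 - beta) * grad_fn(x)
        x = x - lr * v
    return x

def grad(x):
    return 4.0 * x   # gradient of 2 * x^2

lr, beta = 0.02, 0.9
x_std = run_standard(grad, -2.0, lr, beta, 50)
x_norm = run_normalized(grad, -2.0, lr / (1 - beta), beta, 50)
print(x_std, x_norm)   # agree up to floating-point rounding
```

The two velocities differ only by the constant factor (1 - \beta) , which the rescaled learning rate absorbs.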
Training: \beta sweep
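Same airfoil regression. Start with \beta = 0.5 at learning rate 0.02, then raise \beta to 0.9 and lower the learning rate to 0.01 and 0.005 to offset the larger effective step size \eta/(1-\beta) :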
def train_momentum(lr, momentum, num_epochs=2):
    d2l.train_ch11(sgd_momentum, init_momentum_states(feature_dim),
                   {'lr': lr, 'momentum': momentum}, data_iter,
                   feature_dim, num_epochs)

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
train_momentum(0.02, 0.5)
loss: 0.243, 1.128 sec/epoch
train_momentum(0.01, 0.9)
loss: 0.250, 0.641 sec/epoch
train_momentum(0.005, 0.9)
loss: 0.278, 0.728 sec/epoch
Concise: framework SGD with momentum
Most frameworks take momentum=0.9 as a one-line argument:
import optax

trainer = optax.sgd
d2l.train_concise_ch11(trainer, {'learning_rate': 0.005, 'momentum': 0.9},
                       data_iter)
loss: 0.245, 0.720 sec/epoch
Theory: scalar quadratic
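For f(x) = \tfrac{1}{2} \lambda x^2 , plain gradient descent gives x_t = (1 - \eta\lambda)^t x_0 , which converges only when |1 - \eta\lambda| < 1 ; the plot below shows this factor for several \lambda . With momentum, (v_t, x_t) evolves as a 2D linear system whose update-matrix eigenvalues dictate convergence, and the stable range widens from 0 < \eta\lambda < 2 to 0 < \eta\lambda < 2 + 2\beta :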
lambdas = [0.1, 1, 10, 19]
eta = 0.1
d2l.set_figsize((6, 4))
for lam in lambdas:
    t = d2l.numpy(d2l.arange(20))
    d2l.plt.plot(t, (1 - eta * lam) ** t, label=f'lambda = {lam:.2f}')
d2l.plt.xlabel('time')
d2l.plt.legend();
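For the momentum case, substituting the gradient \lambda x into the update gives v_{t+1} = \beta v_t + \lambda x_t and x_{t+1} = -\eta\beta v_t + (1 - \eta\lambda) x_t , a 2×2 linear map with trace 1 + \beta - \eta\lambda and determinant \beta . The sketch below computes its spectral radius from those two quantities; the function names are hypothetical and the sample values of \eta , \lambda , and \beta are illustrative.

```python
import cmath

def momentum_spectral_radius(eta, lam, beta):
    """Largest eigenvalue modulus of [[beta, lam], [-eta*beta, 1 - eta*lam]]."""
    trace = 1 + beta - eta * lam
    det = beta
    root = cmath.sqrt(trace * trace - 4 * det)
    return max(abs((trace + root) / 2), abs((trace - root) / 2))

def in_stable_range(eta, lam, beta):
    """Condition 0 < eta * lam < 2 + 2 * beta for a spectral radius below 1."""
    return 0 < eta * lam < 2 + 2 * beta

cases = [
    (0.6, 4.0, 0.0),    # plain gradient descent, eta * lam = 2.4
    (0.6, 4.0, 0.5),    # same step size with momentum
    (0.775, 4.0, 0.5),  # eta * lam = 3.1, just past 2 + 2 * beta = 3
]
for eta, lam, beta in cases:
    rho = momentum_spectral_radius(eta, lam, beta)
    print(f"eta*lam={eta * lam:.2f}, beta={beta}: radius={rho:.4f}, "
          f"stable={in_stable_range(eta, lam, beta)}")
```

When the eigenvalues are complex, their modulus equals \sqrt{\beta} regardless of \lambda , which is why a step size that diverges under plain gradient descent can still contract once momentum is added.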
Recap
\mathbf{v}_t = \beta \mathbf{v}_{t-1} + \mathbf{g}_t , \mathbf{x}_t = \mathbf{x}_{t-1} - \eta \mathbf{v}_t .
Smooths zigzag from ill-conditioning; effective averaging window 1/(1-\beta) .
\beta = 0.9 is the practical default; \beta = 0.99 for very noisy gradients.
Standard SGD-with-momentum is the workhorse of computer vision; Adam (coming up) generalizes the idea with per-parameter scaling.