```python
%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import math
import numpy as np
```

The deep-learning training loss is an average of per-example losses $f_i$:
$$f(\mathbf{x}) = \frac{1}{n} \sum_{i=1}^{n} f_i(\mathbf{x}).$$
A full gradient $\nabla f$ costs $\mathcal{O}(n)$ per step. A million-example dataset → a million per-example gradient evaluations (forward and backward passes) per parameter update. Untenable.
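The $\mathcal{O}(n)$ cost comes from the averaging. The following minimal JAX sketch uses a hypothetical linear-regression loss; the model, `per_example_loss`, and the random data are illustrative assumptions, not part of the text. It shows that the full gradient of the averaged loss equals the mean of the $n$ per-example gradients, so computing it touches every example:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss f_i: squared error of a linear model on example i
def per_example_loss(w, x_i, y_i):
    return 0.5 * (jnp.dot(x_i, w) - y_i) ** 2

# f(w) = (1/n) * sum_i f_i(w), evaluated with vmap over the examples
def full_loss(w, X, y):
    losses = jax.vmap(per_example_loss, in_axes=(None, 0, 0))(w, X, y)
    return jnp.mean(losses)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n, d = 1000, 5
X = jax.random.normal(key_x, (n, d))  # synthetic inputs (illustrative)
y = jax.random.normal(key_y, (n,))    # synthetic targets (illustrative)
w = jnp.ones(d)

# Full gradient of f: one pass over all n examples
full_grad = jax.grad(full_loss)(w, X, y)

# The same quantity built explicitly from n per-example gradients
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(w, X, y)
print(per_example_grads.shape)  # (1000, 5)
print(jnp.allclose(per_example_grads.mean(axis=0), full_grad, atol=1e-4))  # True
```

The identity being checked, "mean of per-example gradients = full gradient", is also why a uniformly sampled $\nabla f_i$ is an unbiased estimate of $\nabla f$.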
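Stochastic gradient descent (SGD) instead picks an index $i$ uniformly at random and steps with $\nabla f_i$ alone. Each step costs $\mathcal{O}(1)$, and the estimate is unbiased: $\mathbb{E}_i \nabla f_i(\mathbf{x}) = \frac{1}{n} \sum_{i=1}^{n} \nabla f_i(\mathbf{x}) = \nabla f(\mathbf{x})$. With learning rate $\eta$, the update is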
$$\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla f_i(\mathbf{x}).$$
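The update above can be sketched in JAX as follows. The sketch assumes a hypothetical noiseless linear-regression problem; `sgd`, the data, and `eta = 0.05` are illustrative choices. Each step draws one index uniformly with `jax.random.randint` and touches only that example, so its cost does not depend on $n$. Because the targets are noiseless, every $f_i$ is minimized at the same `w_true`, which is why even a constant $\eta$ converges here. With noisy targets, a constant $\eta$ would not converge exactly.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss f_i: squared error of a linear model on example i
def per_example_loss(w, x_i, y_i):
    return 0.5 * (jnp.dot(x_i, w) - y_i) ** 2

grad_i = jax.grad(per_example_loss)  # gradient of f_i with respect to w

def sgd(w, X, y, key, eta=0.05, num_steps=2000):
    n = X.shape[0]

    def step(t, carry):
        w, key = carry
        key, subkey = jax.random.split(key)
        i = jax.random.randint(subkey, (), 0, n)  # uniform index in {0, ..., n-1}
        w = w - eta * grad_i(w, X[i], y[i])       # x <- x - eta * grad f_i(x)
        return w, key

    w, _ = jax.lax.fori_loop(0, num_steps, step, (w, key))
    return w

k_x, k_sgd = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k_x, (100, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = X @ w_true  # noiseless targets: every f_i is minimized at w_true

w_hat = sgd(jnp.zeros(3), X, y, k_sgd)
print(w_hat)  # close to w_true
```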
The price: noisy gradients. The noise blurs the optimization trajectory, but it can also help the iterate escape narrow local minima, a double-edged property this chapter unpacks.
We don’t need a dataset to see the effect. Take the same anisotropic $f(x_1, x_2) = x_1^2 + 2x_2^2$ from the gradient descent section, add $\mathcal{N}(0, 1)$ noise to each gradient component, and watch how the trajectory changes.
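Here is a minimal JAX sketch of that experiment. It assumes a starting point of $(-5, -2)$ and $\eta = 0.1$; these values, like the names `noisy_grad` and `sgd_trajectory`, are illustrative choices. The exact gradient $(2x_1, 4x_2)$ comes from `jax.grad`, and `jax.lax.scan` records every iterate so the trajectory can be plotted:

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x1, x2) = x1^2 + 2 * x2^2
    return x[0] ** 2 + 2 * x[1] ** 2

def noisy_grad(x, key):
    # Exact gradient (2 x1, 4 x2) plus independent N(0, 1) noise per component
    return jax.grad(f)(x) + jax.random.normal(key, x.shape)

def sgd_trajectory(x0, eta, num_steps, key):
    def step(x, subkey):
        x_new = x - eta * noisy_grad(x, subkey)
        return x_new, x_new  # new carry, plus the iterate to record
    keys = jax.random.split(key, num_steps)  # one fresh key per step
    _, xs = jax.lax.scan(step, x0, keys)
    return xs  # shape (num_steps, 2)

x0 = jnp.array([-5.0, -2.0])  # assumed starting point
xs = sgd_trajectory(x0, eta=0.1, num_steps=50, key=jax.random.PRNGKey(0))
print(xs[-1])  # final iterate: near, but not at, (0, 0)
```

Plotting `xs[:, 0]` against `xs[:, 1]` shows the jittery path. Dropping the `jax.random.normal` term recovers plain gradient descent.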
With a constant learning rate, SGD never settles: it keeps jittering around the minimum, and the noise variance sets a floor on how close it gets. After 50 steps:
```
epoch 50, x1: 0.125304, x2: -0.081043
```
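Constant $\eta$ → a noise floor of order $\mathcal{O}(\eta)$, proportional to the gradient-noise variance. Decaying $\eta$ over time removes the floor, provided the decay is neither too fast nor too slow. The classical (Robbins–Monro) conditions are $\sum_t \eta(t) = \infty$ and $\sum_t \eta(t)^2 < \infty$.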
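For this quadratic, the floor can be computed exactly. The noisy update for $x_1$ is $x_1 \leftarrow (1 - 2\eta)\,x_1 - \eta\,\xi$ with $\xi \sim \mathcal{N}(0, 1)$. A stationary variance $v_1$ must therefore satisfy $v_1 = (1 - 2\eta)^2 v_1 + \eta^2$, giving $v_1 = \eta / (4(1 - \eta))$. The same argument with $1 - 4\eta$ gives $v_2 = \eta / (8(1 - 2\eta))$ for $x_2$. Both are $\mathcal{O}(\eta)$ for small $\eta$.

The sketch below checks this by simulating many independent chains with `jax.vmap`. The function names and the chain and step counts are illustrative choices. Each chain starts at the minimum, so only the noise contributes:

```python
import jax
import jax.numpy as jnp

def run_chain(key, eta, num_steps):
    # Noisy SGD on f(x1, x2) = x1^2 + 2 x2^2, started at the minimum (0, 0)
    def step(x, key):
        g = jnp.array([2.0, 4.0]) * x + jax.random.normal(key, (2,))  # noisy gradient
        return x - eta * g, None
    x_final, _ = jax.lax.scan(step, jnp.zeros(2), jax.random.split(key, num_steps))
    return x_final

def stationary_variance(eta):
    # Closed-form floor per coordinate: eta / (4 (1 - eta)) and eta / (8 (1 - 2 eta))
    return jnp.array([eta / (4 * (1 - eta)), eta / (8 * (1 - 2 * eta))])

for eta in [0.1, 0.05, 0.01]:
    keys = jax.random.split(jax.random.PRNGKey(0), 10_000)  # 10,000 independent chains
    finals = jax.vmap(lambda k: run_chain(k, eta, 300))(keys)
    print(eta, finals.var(axis=0), stationary_variance(eta))  # empirical vs. predicted
```

Halving $\eta$ roughly halves both variances, which is the $\mathcal{O}(\eta)$ floor in action.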
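Common schedules for the learning rate $\eta(t)$ at step $t$:

- Exponential decay: $\eta(t) = \eta_0 \cdot e^{-\lambda t}$.
- Polynomial decay: $\eta(t) = \eta_0 \cdot (\beta t + 1)^{-\alpha}$; $\alpha = 0.5$ gives inverse-square-root decay.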
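Both schedules are one-liners. This sketch assumes $\eta_0 = 0.1$, $\lambda = 0.1$, $\beta = 0.1$, and $\alpha = 0.5$; these are illustrative defaults, not values fixed by the text. `exponential_lr` computes $\eta_0 e^{-\lambda t}$, and `polynomial_lr` computes $\eta_0 (\beta t + 1)^{-\alpha}$:

```python
import math

def exponential_lr(t, eta0=0.1, lam=0.1):
    # eta(t) = eta0 * exp(-lam * t)
    return eta0 * math.exp(-lam * t)

def polynomial_lr(t, eta0=0.1, beta=0.1, alpha=0.5):
    # eta(t) = eta0 * (beta * t + 1) ** (-alpha); alpha = 0.5 is inverse-square-root decay
    return eta0 * (beta * t + 1) ** (-alpha)

for t in [0, 10, 100, 1000]:
    print(t, exponential_lr(t), polynomial_lr(t))
```

With these constants, by $t = 100$ the exponential schedule has dropped by a factor of $e^{10} \approx 2.2 \times 10^4$. The polynomial one has dropped only by $\sqrt{11} \approx 3.3$.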
With exponential decay, after 1000 steps:

```
epoch 1000, x1: -0.903036, x2: -0.020639
```
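Exponential decay damps the variance quickly, but it shrinks the step size too fast. After 1000 steps, the iterate is still far from the minimum at the origin ($x_1 \approx -0.9$). Polynomial inverse-square-root decay keeps exploring longer and, in this example, lands much closer to the origin after only 50 steps: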
```
epoch 50, x1: -0.114423, x2: -0.057479
```
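One way to see why exponential decay stalls is to look at the cumulative step size $\sum_{\tau < t} \eta(\tau)$. This check uses the illustrative constants $\eta_0 = 0.1$, $\lambda = \beta = 0.1$, and $\alpha = 0.5$, which are assumptions rather than values fixed by the text. Under the exponential schedule, this sum is a geometric series bounded by $\eta_0 / (1 - e^{-\lambda}) \approx 1.05$, which violates $\sum_t \eta(t) = \infty$. After a few dozen steps, the iterate barely moves. The inverse-square-root sum grows like $\sqrt{t}$ without bound. A minimal check in JAX:

```python
import jax.numpy as jnp

eta0, lam, beta, alpha = 0.1, 0.1, 0.1, 0.5  # illustrative constants (assumed)

def cumulative_step(schedule, num_steps):
    # Running total sum_{tau < t} eta(tau) for t = 1, ..., num_steps
    t = jnp.arange(num_steps, dtype=jnp.float32)
    return jnp.cumsum(schedule(t))

exp_schedule = lambda t: eta0 * jnp.exp(-lam * t)
poly_schedule = lambda t: eta0 * (beta * t + 1) ** (-alpha)

for T in [50, 1000, 10_000]:
    print(T, cumulative_step(exp_schedule, T)[-1], cumulative_step(poly_schedule, T)[-1])

print("exponential bound:", eta0 / (1 - jnp.exp(-lam)))  # about 1.05
```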