Minibatch Stochastic Gradient Descent

Minibatch SGD

GD: \mathcal{O}(n) per step, because every update computes the exact gradient over all n examples. SGD: \mathcal{O}(1) per step, but each update follows one example's noisy gradient.

The compromise everyone uses is minibatch SGD: sample a batch \mathcal{B} of b examples and average their gradients:

\mathbf{x} \leftarrow \mathbf{x} - \frac{\eta}{b} \sum_{i \in \mathcal{B}} \nabla f_i(\mathbf{x}).
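A minimal NumPy sketch of this update, assuming a least-squares objective f_i(\mathbf{x}) = (\mathbf{a}_i^\top \mathbf{x} - y_i)^2 / 2, where \mathbf{a}_i is row i of a synthetic feature matrix. The helper minibatch_sgd_step and all data are hypothetical, and the parameter vector \mathbf{x} is called w in code.

```python
import numpy as np

def minibatch_sgd_step(w, X, y, eta, b, rng):
    """One update w <- w - (eta / b) * sum_{i in B} grad f_i(w).

    Assumes f_i(w) = (X[i] @ w - y[i])**2 / 2, so grad f_i(w) = (X[i] @ w - y[i]) * X[i].
    """
    # Sample a minibatch B of b distinct example indices.
    batch = rng.choice(len(X), size=b, replace=False)
    Xb, yb = X[batch], y[batch]
    # Row k of per_example_grads is grad f_i(w) for the k-th sampled example.
    per_example_grads = (Xb @ w - yb)[:, None] * Xb
    return w - eta / b * per_example_grads.sum(axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
w_true = np.array([2.0, -1.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=1000)

w = np.zeros(3)
for _ in range(500):
    w = minibatch_sgd_step(w, X, y, eta=0.1, b=32, rng=rng)
print(w)  # approaches w_true
```

Summing per-example gradients and dividing by b is the same as taking the gradient of the mean loss over the batch, which is how the training loops later in this section compute it.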

Why minibatches win

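  • Variance reduction: averaging b independent noisy gradients cuts the variance by a factor of b (the standard deviation by \sqrt{b}).
  • Hardware efficiency: a GPU evaluates one batched matrix-matrix product over b examples far faster than a loop of b single-example matrix-vector products.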
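A quick NumPy check of the variance bullet, assuming per-example gradients can be modeled as independent draws with variance \sigma^2 = 4 (a hypothetical stand-in, not real gradients). The variance of a b-sample average should track \sigma^2 / b.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma2 = 4.0         # assumed variance of a single per-example gradient
num_batches = 20000  # simulated minibatches per batch size

def minibatch_mean_variance(b):
    # Each row is one minibatch of b independent noisy "gradients" with mean 1.
    grads = rng.normal(loc=1.0, scale=np.sqrt(sigma2), size=(num_batches, b))
    # Average within each minibatch, then measure the spread across minibatches.
    return grads.mean(axis=1).var()

for b in (1, 4, 16, 64):
    print(f'b={b:2d}  empirical={minibatch_mean_variance(b):.4f}  sigma2/b={sigma2 / b:.4f}')
```

Sampling without replacement from a finite dataset, as a data loader does, gives slightly lower variance than this independent-draw model.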

Vectorization and cache reuse, not statistics alone, are the main reason to batch.

Setup

%matplotlib inline
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
import numpy as np
import optax
import time

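# Operands for the matrix-multiplication timings below: B and C are random,
# A is a placeholder for the product.
A = jnp.zeros((256, 256))
B = jnp.array(np.random.normal(0, 1, (256, 256)))
C = jnp.array(np.random.normal(0, 1, (256, 256)))

class Timer: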
    """Record multiple running times."""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()

timer = Timer()

Three loops, three speeds

Compute \mathbf{A} = \mathbf{B}\mathbf{C} on 256 \times 256 matrices in three increasingly vectorized ways:

# Compute A = BC one element at a time. JAX is functionally pure, so a
# literal `A.at[i, j].set(...)` would copy the full matrix on every write
# (O(n^2) memory traffic), turning a demo into a multi-minute run. We
# instead use a NumPy buffer to mirror the eager semantics other frameworks
# get for free; the *point* of this cell is that the elementwise dispatch
# is much slower than vectorized matmul.
A = np.zeros((256, 256), dtype=np.float32)
B_np = np.array(B)
C_np = np.array(C)
timer.start()
for i in range(256):
    for j in range(256):
        A[i, j] = np.dot(B_np[i, :], C_np[:, j])
timer.stop()
0.16472673416137695
# Compute A = BC one column at a time. We keep B/C on device; only the
# Python loop and per-column dispatch cost remain.
A = jnp.zeros((256, 256))
timer.start()
for j in range(256):
    A = A.at[:, j].set(jnp.dot(B, C[:, j]))
A.block_until_ready()
timer.stop()
0.8467981815338135
# Compute A = BC in one go
timer.start()
A = jnp.dot(B, C)
A.block_until_ready()
timer.stop()

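# One 256 x 256 matmul is 2 * 256**3 ≈ 3.4e7 flops ≈ 0.03 GFLOP;
# dividing by the elapsed seconds gives GFLOP/s.
gigaflops = [0.03 / i for i in timer.times]
print(f'performance in Gigaflops: element {gigaflops[0]:.3f}, '
      f'column {gigaflops[1]:.3f}, full {gigaflops[2]:.3f}')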
performance in Gigaflops: element 0.182, column 0.035, full 0.031

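Element-wise → column-wise → matrix-wise should typically get faster by orders of magnitude: the caches and SIMD units do the work, and the loop is overhead. The single, un-warmed runs above do not show that ordering, though. The first call to each JAX operation at a new shape pays a one-time compilation cost, every eager A.at[...].set(...) copies the whole matrix, and the element-wise loop runs in NumPy rather than JAX.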
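One way to keep one-time costs out of such measurements is to call the function once untimed and report the median of several timed repeats. A minimal sketch: benchmark is a hypothetical helper, and it uses its own NumPy operands B_demo and C_demo so it runs anywhere. With JAX, pass something like lambda: jnp.dot(B, C).block_until_ready() so asynchronous dispatch does not stop the clock early.

```python
import statistics
import time

import numpy as np

def benchmark(fn, warmup=1, repeats=5):
    """Median wall-clock seconds of fn(), measured after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # absorbs one-time costs such as compilation or cold caches
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)

rng = np.random.default_rng(0)
B_demo = rng.normal(size=(256, 256)).astype(np.float32)
C_demo = rng.normal(size=(256, 256)).astype(np.float32)
flops = 2 * 256 ** 3  # one multiply and one add per term of each inner product
secs = benchmark(lambda: B_demo @ C_demo)
print(f'full matmul: {flops / secs / 1e9:.1f} GFLOP/s')
```

The median is less sensitive than the mean to an occasional slow repeat caused by other work on the machine.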

The minibatch sweet spot

  • b = 1: high gradient variance, GPU mostly idle.
  • b = n: full GD, exact gradient (no sampling variance) but only one update per epoch.
  • b = 32 \ldots 1024: enough work to fill the GPU, gradient variance reduced by a factor of b (standard deviation by \sqrt{b}).

In practice, batch size is constrained more by memory and parallelism than by statistics. Modern training: 256 to 65k+ on accelerators.

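The same trade-off appears inside the matrix product. Computing \mathbf{A} in blocks of 64 columns issues 4 matrix-matrix products instead of 256 matrix-vector products:

timer.start()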
for j in range(0, 256, 64):
    A = A.at[:, j:j+64].set(jnp.dot(B, C[:, j:j+64]))
A.block_until_ready()
timer.stop()
print(f'performance in Gigaflops: block {0.03 / timer.times[3]:.3f}')
performance in Gigaflops: block 0.038
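The block loop generalizes to any block width. A NumPy sketch with a hypothetical helper blocked_matmul (NumPy in place of JAX so the slice assignment is in place): width 1 is the column-wise loop, width 256 the single full product, and width 64 the version just timed.

```python
import numpy as np

def blocked_matmul(B, C, block=64):
    """Compute B @ C one slab of `block` columns at a time."""
    A = np.zeros((B.shape[0], C.shape[1]), dtype=np.result_type(B, C))
    for j in range(0, C.shape[1], block):
        # One (m x k) @ (k x block) product per iteration; the last slab may be narrower.
        A[:, j:j + block] = B @ C[:, j:j + block]
    return A

rng = np.random.default_rng(0)
B_demo = rng.normal(size=(256, 256))
C_demo = rng.normal(size=(256, 256))
for block in (1, 64, 256):
    err = np.abs(blocked_matmul(B_demo, C_demo, block) - B_demo @ C_demo).max()
    print(f'block={block:3d}  max abs difference from B @ C: {err:.2e}')
```

All widths compute the same product up to floating-point rounding; only the number of calls and the size of each call change, which is exactly the minibatch trade-off.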

Airfoil dataset

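Real regression dataset for the experiments: 1503 examples, 5 features, 1 target. get_data_ch11 standardizes every column to zero mean and unit variance, then keeps the first n = 1500 rows: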

d2l.DATA_HUB['airfoil'] = (d2l.DATA_URL + 'airfoil_self_noise.dat',
                           '76e5be1548fd8222e5074cf0faae75edff8cf93f')


def get_data_ch11(batch_size=10, n=1500):
    data = np.genfromtxt(d2l.download('airfoil'),
                         dtype=np.float32, delimiter='\t')
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data_iter = d2l.load_array(
        (jnp.array(data[:n, :-1]), jnp.array(data[:n, -1])),
        batch_size, is_train=True)
    return data_iter, data.shape[1]-1
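The training loops below consume a stream of shuffled (X, y) minibatches. A NumPy sketch of such an iterator, assuming d2l.load_array(..., is_train=True) reshuffles on every pass and yields consecutive slices; iterate_minibatches is a hypothetical name, and the data are synthetic, standardized the same way get_data_ch11 does.

```python
import numpy as np

def iterate_minibatches(X, y, batch_size, rng, shuffle=True):
    """Yield (X_batch, y_batch) pairs that cover every example once per call."""
    idx = np.arange(len(X))
    if shuffle:
        rng.shuffle(idx)  # fresh order on each call, i.e. each epoch
    for start in range(0, len(X), batch_size):
        batch = idx[start:start + batch_size]
        yield X[batch], y[batch]

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1500, 5)).astype(np.float32)
y_demo = rng.normal(size=1500).astype(np.float32)
# Standardize each column, as get_data_ch11 does.
X_demo = (X_demo - X_demo.mean(axis=0)) / X_demo.std(axis=0)

batches = list(iterate_minibatches(X_demo, y_demo, 100, rng))
print(len(batches), batches[0][0].shape)  # 15 (100, 5)
```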

Manual SGD update

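First isolate the optimizer step. Because the training loop below differentiates the mean loss over a minibatch, grads already hold \nabla f_{\mathcal{B}} = \frac{1}{b} \sum_{i \in \mathcal{B}} \nabla f_i, so sgd only moves each parameter by -\eta \nabla f_{\mathcal{B}}. The states argument is unused by plain SGD: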

def sgd(params, grads, states, hyperparams):
    updated = []
    for param, grad in zip(params, grads):
        updated.append(param - hyperparams['lr'] * grad)
    return updated
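A NumPy sketch that checks what grads contains for the linear model, assuming d2l.linreg computes X @ w + b and d2l.squared_loss is (\hat{y} - y)^2 / 2. linreg_np and mean_loss_and_grads are hypothetical helpers with hand-derived gradients.

```python
import numpy as np

def linreg_np(X, w, b):
    # NumPy stand-in for d2l.linreg: X @ w + b.
    return X @ w + b

def mean_loss_and_grads(X, y, w, b):
    """Mean of (linreg - y)^2 / 2 over the batch, plus its gradients w.r.t. w and b."""
    r = linreg_np(X, w, b) - y.reshape(-1, 1)  # residuals, shape (batch, 1)
    loss = (r ** 2 / 2).mean()
    grad_w = X.T @ r / len(X)                  # (1/|B|) sum_i r_i x_i, shape (d, 1)
    grad_b = r.mean(axis=0)                    # (1/|B|) sum_i r_i, shape (1,)
    return loss, grad_w, grad_b

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(10, 5))
y_demo = rng.normal(size=10)
w = rng.normal(scale=0.01, size=(5, 1))
b = np.zeros(1)

loss, grad_w, grad_b = mean_loss_and_grads(X_demo, y_demo, w, b)
# The batch gradient equals the average of the 10 single-example gradients.
per_example_w = [mean_loss_and_grads(X_demo[i:i + 1], y_demo[i:i + 1], w, b)[1]
                 for i in range(10)]
print(np.allclose(grad_w, np.mean(per_example_w, axis=0)))  # True

# One plain SGD step, the same rule as sgd(): param - lr * grad.
lr = 0.05
w_new, b_new = w - lr * grad_w, b - lr * grad_b
new_loss = mean_loss_and_grads(X_demo, y_demo, w_new, b_new)[0]
```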

Generic training loop

The reusable trainer initializes a tiny linear model, runs forward/backward on each minibatch, and records loss against wall-clock time:

def train_ch11(trainer_fn, states, hyperparams, data_iter,
               feature_dim, num_epochs=2):
    # Initialization
    w = jnp.array(np.random.normal(scale=0.01, size=(feature_dim, 1)),
                  dtype=jnp.float32)
    b = jnp.zeros(1)
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    # Train
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    # JIT-fuse the per-batch step (grad + optimizer update) so we don't
    # round-trip through Python on every minibatch.
    @jax.jit
    def step(w, b, X, y):
        def loss_fn(w, b):
            return d2l.squared_loss(d2l.linreg(X, w, b), y).mean()
        grads = jax.grad(loss_fn, argnums=(0, 1))(w, b)
        return trainer_fn([w, b], list(grads), states, hyperparams)
    # Pre-stack the full dataset on device so the periodic evaluate_loss
    # stays inside one compiled call instead of looping in Python.
    eval_batches = [(jnp.array(X), jnp.array(y)) for X, y in data_iter]
    Xs = jnp.concatenate([X for X, _ in eval_batches], axis=0)
    ys = jnp.concatenate([y for _, y in eval_batches], axis=0)
    @jax.jit
    def full_eval(w, b):
        out = d2l.linreg(Xs, w, b)
        y_r = ys.reshape(out.shape)
        return ((out - y_r) ** 2 / 2).mean()
    for _ in range(num_epochs):
        for X, y in data_iter:
            X, y = jnp.array(X), jnp.array(y)
            w, b = step(w, b, X, y)
            n += X.shape[0]
            if n % 200 == 0:
                timer.stop()
                animator.add(n/X.shape[0]/len(data_iter),
                             (float(full_eval(w, b)),))
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.sum()/num_epochs:.3f} sec/epoch')
    return timer.cumsum(), animator.Y[0]

Full-batch baseline

Set b to the whole dataset (1500 rows), which recovers full-batch GD. Each epoch yields a single update, so the loss curve is smooth but falls slowly per epoch; this run uses \eta = 1 for 10 epochs:

def train_sgd(lr, batch_size, num_epochs=2):
    data_iter, feature_dim = get_data_ch11(batch_size)
    return train_ch11(
        sgd, None, {'lr': lr}, data_iter, feature_dim, num_epochs)

gd_res = train_sgd(1, 1500, 10)

loss: 0.248, 0.077 sec/epoch
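The runs in this section differ sharply in how many parameter updates they make. A small tally, with batch sizes, learning rates, and epoch counts copied from the train_sgd calls in this section; total_updates is a hypothetical helper.

```python
import math

def total_updates(n, batch_size, num_epochs):
    """Parameter updates made in num_epochs passes over n examples (last batch may be partial)."""
    return math.ceil(n / batch_size) * num_epochs

n = 1500  # rows kept by get_data_ch11
# (batch size, learning rate, epochs) as passed to train_sgd in this section.
runs = [(1500, 1, 10), (1, 0.005, 2), (100, 0.4, 2), (10, 0.05, 2)]
for b, lr, epochs in runs:
    print(f'b={b:4d}  lr={lr:<5}  epochs={epochs:2d}  updates={total_updates(n, b, epochs):4d}')
```

Full-batch GD makes 10 updates in its 10 epochs, while b = 1 makes 3000 in 2 epochs; the sec/epoch figure each run prints shows what those updates cost.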

Comparing batch sizes

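b = 1500 (full batch, above), b = 1 (pure SGD), b = 100, and b = 10, all on the same linear model. Each run has its own learning rate, and these use the default 2 epochs rather than the 10 given to full-batch GD: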

sgd_res = train_sgd(0.005, 1)

loss: 0.246, 6.083 sec/epoch
mini1_res = train_sgd(.4, 100)

loss: 0.244, 0.470 sec/epoch

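mini2_res = train_sgd(.05, 10)

loss: 0.247, 1.059 sec/epoch

Pure SGD (b = 1, 6.083 sec/epoch) updates often but leaves the vector hardware mostly idle. b = 100 (0.470 sec/epoch) spreads the fixed per-step overhead over 100 examples and averages away much of the gradient noise; b = 10 (1.059 sec/epoch) sits in between.

Wall-clock view
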
d2l.set_figsize([6, 3])
d2l.plot(*list(map(list, zip(gd_res, sgd_res, mini1_res, mini2_res))),
         'time (sec)', 'loss', xlim=[1e-2, 10],
         legend=['gd', 'sgd', 'batch size=100', 'batch size=10'])
d2l.plt.gca().set_xscale('log')

Read the x-axis as elapsed time, not examples processed: minibatches win because they make each second of compute do more useful linear algebra.
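To reduce each curve to one number, one option is the elapsed time at which it first reaches a target loss. A minimal sketch: time_to_target is a hypothetical helper that takes the (times, losses) pair returned by train_sgd, e.g. time_to_target(*mini1_res, 0.25). The toy numbers below are made up for illustration, not taken from the runs above.

```python
def time_to_target(times, losses, target):
    """Elapsed seconds at the first recorded loss <= target, or None if never reached."""
    for t, loss in zip(times, losses):
        if loss <= target:
            return t
    return None

# Toy (times, losses) pair in the same format train_sgd returns.
toy_times = [0.1, 0.2, 0.3, 0.4]
toy_losses = [0.30, 0.26, 0.249, 0.245]
print(time_to_target(toy_times, toy_losses, 0.25))  # 0.3
print(time_to_target(toy_times, toy_losses, 0.20))  # None
```

Because losses are recorded only every 200 examples, the result is the first checkpoint at or below the target, not the exact crossing time.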

Concise: framework optimizer

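The same b = 10 experiment with Flax's nn.Dense model and Optax's optax.sgd optimizer in place of the hand-written pieces. Fewer lines, and the final loss matches the manual b = 10 run (0.247):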

def train_concise_ch11(trainer_fn, hyperparams, data_iter, num_epochs=2):
    # Initialization
    net = nn.Dense(1)
    key = jax.random.PRNGKey(0)
    X_dummy = jnp.ones((1, 5))
    params = net.init(key, X_dummy)

    optimizer = trainer_fn(**hyperparams)
    opt_state = optimizer.init(params)

    loss = lambda pred, y: jnp.mean((pred - y) ** 2) / 2
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    # JIT-fuse the per-batch optimizer update so per-step Python overhead
    # stays out of the inner loop.
    @jax.jit
    def step(params, opt_state, X, y):
        def loss_fn(params):
            out = net.apply(params, X)
            y_reshaped = y.reshape(out.shape)
            return jnp.mean((out - y_reshaped) ** 2) / 2
        l, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, l

    # Pre-stack the full dataset on device so the periodic full-loss
    # evaluation is a single compiled call.
    eval_batches = [(jnp.array(X), jnp.array(y)) for X, y in data_iter]
    Xs = jnp.concatenate([X for X, _ in eval_batches], axis=0)
    ys = jnp.concatenate([y for _, y in eval_batches], axis=0)
    @jax.jit
    def full_eval(params):
        out = net.apply(params, Xs)
        y_r = ys.reshape(out.shape)
        return jnp.mean((out - y_r) ** 2) / 2
    for _ in range(num_epochs):
        for X, y in data_iter:
            X, y = jnp.array(X), jnp.array(y)
            params, opt_state, _ = step(params, opt_state, X, y)
            n += X.shape[0]
            if n % 200 == 0:
                timer.stop()
                animator.add(n/X.shape[0]/len(data_iter),
                             (float(full_eval(params)),))
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.sum()/num_epochs:.3f} sec/epoch')
data_iter, _ = get_data_ch11(10)
trainer = optax.sgd
train_concise_ch11(trainer, {'learning_rate': 0.05}, data_iter)

loss: 0.247, 0.707 sec/epoch

Recap

  • Minibatch SGD interpolates between GD and pure SGD.
  • Vectorization and cache reuse make b > 1 much cheaper per example, which is why minibatches pay off even before any statistical argument.
  • Gradient variance scales as 1/b (standard deviation as 1/\sqrt{b}); the lower noise typically tolerates larger learning rates, as in the runs above (\eta = 0.4 at b = 100 versus \eta = 0.005 at b = 1).
  • Choose b to fill the accelerator within its memory limits, rather than from a theoretical variance-versus-cost analysis.