%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np

A single GPU can train ResNet on ImageNet — slowly. Modern large models need many GPUs. Three ways to split work:
Data parallelism is the default for everyday training.
Original, network partitioning, layerwise partitioning, data parallelism.
Each GPU computes a forward + backward pass on its slice of the minibatch. After the backward pass, gradients are averaged across GPUs (all_reduce). Optimizer step then runs identically on every GPU, keeping replicas in sync.
Data-parallel SGD on 2 GPUs: split data, compute gradients independently, all-reduce, then update.
Tiny LeNet for the demo — small enough to fit on each GPU many times over:
import functools
import optax
# Initialize model parameters: small Gaussian weights, zero biases.
scale = 0.01
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)  # one subkey per weight tensor

# Each layer's weight is drawn from its own subkey and scaled by `scale`;
# the matching bias starts at zero.
W1, b1 = jax.random.normal(keys[0], (20, 1, 3, 3)) * scale, jnp.zeros(20)
W2, b2 = jax.random.normal(keys[1], (50, 20, 5, 5)) * scale, jnp.zeros(50)
W3, b3 = jax.random.normal(keys[2], (800, 128)) * scale, jnp.zeros(128)
W4, b4 = jax.random.normal(keys[3], (128, 10)) * scale, jnp.zeros(10)

# Flat list ordered as (weight, bias) pairs, layer by layer.
params = [W1, b1, W2, b2, W3, b3, W4, b4]
# Define the model
def _conv_relu_pool(X, W, b):
    """Apply one conv -> bias -> ReLU -> 2x2 average-pool stage.

    `X` is an NCHW activation tensor, `W` an OIHW kernel and `b` a
    per-output-channel bias (layouts fixed by `dimension_numbers` below).
    """
    conv = jax.lax.conv_general_dilated(
        X, W, window_strides=(1, 1), padding='VALID',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
    # Broadcast the bias over the batch and both spatial axes.
    activation = jax.nn.relu(conv + b.reshape(1, -1, 1, 1))
    # Average pooling: sum each non-overlapping 2x2 window, then divide by
    # the window size.
    return jax.lax.reduce_window(
        activation, 0.0, jax.lax.add, (1, 1, 2, 2), (1, 1, 2, 2),
        'VALID') / 4.0


def lenet(params, X):
    """Run the small LeNet-style network and return its output scores.

    Args:
        params: Sequence `[W1, b1, W2, b2, W3, b3, W4, b4]`; the first two
            pairs parameterize convolution stages, the last two pairs
            parameterize dense layers.
        X: Input batch in NCHW layout.

    Returns:
        Array with one row per example, produced by the final dense layer
        (no softmax is applied).
    """
    h1 = _conv_relu_pool(X, params[0], params[1])
    h2 = _conv_relu_pool(h1, params[2], params[3])
    # Flatten everything except the batch axis for the dense layers.
    h2 = h2.reshape(h2.shape[0], -1)
    h3 = jax.nn.relu(jnp.dot(h2, params[4]) + params[5])
    y_hat = jnp.dot(h3, params[6]) + params[7]
    return y_hat


# Replicate the parameter list onto each device:
Sum vectors across GPUs and broadcast the result back — the gradient-averaging primitive of data-parallel SGD. NCCL implements this efficiently in production:
# In JAX, allreduce is done inside pmap via jax.lax.psum/pmean.
# Here we demonstrate with a simple pmap example.
devices = jax.local_devices()[:2]
# Stack two (1, 2) blocks filled with 1 and 2; pmap maps over the leading
# axis, so each block is handled by a different device.
data = jnp.stack([jnp.ones((1, 2)) * (i + 1) for i in range(2)])
print('before allreduce:\n', data[0], '\n', data[1])
# psum adds the blocks across the mapped axis 'i'; every position along
# that axis receives the same total.
summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name='i'),
                  axis_name='i')(data)
print('after allreduce:\n', summed[0], '\n', summed[1])
# Output:
# before allreduce:
[[1. 1.]]
[[2. 2.]]
after allreduce:
[[3. 3.]]
[[3. 3.]]
Split a tensor evenly across devices:
data = jnp.arange(20).reshape(4, 5)
devices = jax.local_devices()[:2]
# One-dimensional device mesh whose single axis is named 'dev'.
mesh = jax.sharding.Mesh(np.array(devices), ('dev',))
# Partition the leading array axis along the 'dev' mesh axis.
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
# Regroup the 4 rows into 2 groups of 2 so the leading axis can be split
# across the mesh.
split = jax.device_put(data.reshape(2, 2, 5), sharding)
print('input :', data)
print('load into', devices)
print('output:', split)
# Output:
# input : [[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]]
load into [CudaDevice(id=0), CudaDevice(id=1)]
output:
[[[ 0 1 2 3 4]
[ 5 6 7 8 9]]
[[10 11 12 13 14]
[15 16 17 18 19]]]
def split_batch(X, y, num_devices):
    """Split `X` and `y` across devices by reshaping.

    Args:
        X: Array of examples whose leading axis is the batch axis.
        y: Array of labels with the same leading (batch) size as `X`.
        num_devices: Number of equal shards to create.

    Returns:
        Tuple `(X_shards, y_shards)`, each reshaped from `(batch, ...)` to
        `(num_devices, batch // num_devices, ...)`.

    Raises:
        ValueError: If `X` and `y` have different batch sizes, or if the
            batch cannot be divided evenly into `num_devices` shards.
    """
    batch_size = X.shape[0]
    if batch_size != y.shape[0]:
        raise ValueError(
            f'X and y must have the same batch size, got {batch_size} '
            f'and {y.shape[0]}')
    if num_devices < 1 or batch_size % num_devices:
        raise ValueError(
            f'batch size {batch_size} is not evenly divisible by '
            f'num_devices={num_devices}')
    per_device = batch_size // num_devices

    # Reshape (batch, ...) -> (num_devices, batch_per_device, ...)
    def _reshape(a):
        return a.reshape(num_devices, per_device, *a.shape[1:])

    return _reshape(X), _reshape(y)


# Forward + backward on each replica → all_reduce gradients → update parameters identically:
@functools.partial(jax.pmap, axis_name='batch',
                   static_broadcasted_argnums=(3,))
def pmap_step(params, X_shard, y_shard, lr):
    """Run a single SGD step on one device's shard of the minibatch.

    The function is wrapped in `jax.pmap` with axis name 'batch'. `lr`
    (positional index 3) is declared static and broadcast rather than
    mapped.

    Args:
        params: This replica's copy of the model parameters.
        X_shard: This replica's slice of the input minibatch.
        y_shard: Integer class labels matching `X_shard`.
        lr: Learning rate.

    Returns:
        The parameters after one gradient-descent update.
    """
    def shard_loss(weights):
        logits = lenet(weights, X_shard)
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            logits, y_shard)
        return jnp.mean(per_example)

    local_grads = jax.grad(shard_loss)(params)
    # All-reduce: average the gradients over the 'batch' axis so every
    # replica applies the same update.
    synced_grads = jax.lax.pmean(local_grads, axis_name='batch')

    # Plain SGD, applied leaf by leaf.
    def sgd_update(weight, grad):
        return weight - lr * grad

    return jax.tree.map(sgd_update, params, synced_grads)
def train_batch(replicated_params, X, y, num_gpus, lr):
    """Run one data-parallel training step on a single minibatch.

    Args:
        replicated_params: Model parameters with a leading per-device axis.
        X: Minibatch of inputs.
        y: Minibatch of labels.
        num_gpus: Number of shards to split the minibatch into.
        lr: Learning rate forwarded to `pmap_step`.

    Returns:
        The updated replicated parameters returned by `pmap_step`.
    """
    X_shards, y_shards = split_batch(X, y, num_gpus)
    return pmap_step(replicated_params, X_shards, y_shards, lr)


def evaluate_accuracy_jax(predict_fn, data_iter):
    """Evaluate accuracy using JAX predict function.

    Args:
        predict_fn: Callable mapping a batch of inputs to per-class scores
            of shape `(batch, num_classes)`.
        data_iter: Iterable of `(X, y)` batches. Each `X` must be 4-D; it is
            transposed with axes `(0, 3, 1, 2)` before prediction
            (presumably channels-last -> channels-first; confirm against
            the data loader).

    Returns:
        Fraction of examples whose arg-max prediction equals the label.
    """
    num_correct, num_total = 0, 0
    for X, y in data_iter:
        X, y = jnp.array(X).transpose(0, 3, 1, 2), jnp.array(y)
        y_hat = predict_fn(X)
        # `.item()` pulls the per-batch count back to a Python number.
        num_correct += jnp.sum(jnp.argmax(y_hat, axis=1) == y).item()
        num_total += y.shape[0]
    return num_correct / num_total
def train(num_gpus, batch_size, lr, num_epochs=10):
    """Train LeNet on Fashion-MNIST with data parallelism and report results.

    Args:
        num_gpus: Number of local devices to replicate the model onto.
        batch_size: Minibatch size passed to the data loader.
        lr: SGD learning rate.
        num_epochs: Number of passes over the training data (default 10).

    Prints the final test accuracy, the average time per epoch, and the
    devices used.
    """
    data = d2l.FashionMNIST(batch_size=batch_size)
    train_iter = data.get_dataloader(train=True)
    test_iter = data.get_dataloader(train=False)
    devices = jax.local_devices()[:num_gpus]
    # Replicate model parameters to `num_gpus` GPUs by stacking a copy of
    # every leaf along a new leading axis.
    replicated_params = jax.tree.map(
        lambda x: jnp.stack([x] * num_gpus), params)
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    timer = d2l.Timer()
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            X = jnp.array(X).transpose(0, 3, 1, 2)
            y = jnp.array(y)
            # Perform multi-GPU training for a single minibatch
            replicated_params = train_batch(
                replicated_params, X, y, num_gpus, lr)
        # Block until computation is done, so the timer measures finished
        # work rather than just dispatch.
        jax.tree.map(lambda x: x.block_until_ready(), replicated_params)
        timer.stop()
        # Evaluate on the first replica's parameters
        host_params = jax.tree.map(lambda x: x[0], replicated_params)
        animator.add(epoch + 1, (evaluate_accuracy_jax(
            lambda x: lenet(host_params, x), test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(devices)}')


# Output:
# test acc: 0.84, 0.9 sec/epoch on [CudaDevice(id=0)]
This gives the wall-clock reference point: one model copy, one minibatch stream, no gradient synchronization.
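Ideally, per-epoch time would roughly halve, because each GPU processes only half of every minibatch (the number of iterations per epoch is unchanged). For a model and dataset this small, however, the overhead of splitting the batch and synchronizing gradients can outweigh the gain, as the measured time below shows: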
test acc: 0.85, 1.3 sec/epoch on [CudaDevice(id=0), CudaDevice(id=1)]
all_reduce is the workhorse — implemented as a ring reduction in NCCL; bandwidth-optimal for k GPUs.