from d2l import jax as d2l
import functools
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state
import flax
import numpy as np

The previous section did data-parallel training the hard way: manual all_reduce, manual replica management. In practice, every framework wraps it in a one-liner:

- PyTorch: nn.DataParallel(net) (multi-GPU on one host) or nn.parallel.DistributedDataParallel (multi-host).
- MXNet: gluon.Trainer(..., kvstore='device').
- TensorFlow: tf.distribute.MirroredStrategy().

The numerical result is the same, the boilerplate is far smaller, and NCCL all-reduce still runs under the hood.
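The "same numerical result" claim rests on a simple identity: when the shards are equal in size and the loss averages over examples, the mean of the per-shard gradients equals the full-batch gradient. The following is a minimal single-device sketch of that identity. The linear model `loss_fn`, the array shapes, and the shard count are illustrative assumptions, not part of the original code.

```python
import jax
from jax import numpy as jnp

def loss_fn(w, X, y):
    # Mean squared error of a linear model. Averaging over examples is what
    # lets shard-averaged gradients match the full-batch gradient.
    return jnp.mean((X @ w - y) ** 2)

key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (32, 5))
y = jax.random.normal(key_y, (32,))
w = jax.random.normal(key_w, (5,))

# Reference: gradient computed on the whole minibatch.
full_grad = jax.grad(loss_fn)(w, X, y)

# Simulated data parallelism: 4 equal shards, one gradient per shard,
# then an average (the role an all-reduce plays across real devices).
num_shards = 4
shard_grads = [
    jax.grad(loss_fn)(w, Xs, ys)
    for Xs, ys in zip(jnp.split(X, num_shards), jnp.split(y, num_shards))
]
avg_grad = sum(shard_grads) / num_shards

max_diff = float(jnp.max(jnp.abs(full_grad - avg_grad)))
print(f'max |full - averaged| = {max_diff:.2e}')
```

Any remaining difference should be at the level of float32 rounding. Layers that compute statistics per replica, such as BatchNorm, can break exact equivalence unless those statistics are synchronized as well.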
We use a small ResNet for these experiments — the speedup from data parallelism only matters once the per-GPU compute is non-trivial:
class ResNet18(nn.Module):
    """A slightly modified ResNet-18 model."""
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='same'),
            nn.BatchNorm(not self.training),
            nn.relu,
            # ResNet blocks
            d2l.Residual(64, training=self.training),
            d2l.Residual(64, training=self.training),
            d2l.Residual(128, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(128, training=self.training),
            d2l.Residual(256, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(256, training=self.training),
            d2l.Residual(512, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(512, training=self.training),
            # Global average pooling and classifier
            lambda x: x.mean(axis=(1, 2)),
            nn.Dense(self.num_classes),
        ])

    def __call__(self, x):
        return self.net(x)

Wrap the model in the framework’s data-parallel container. Parameters are replicated to each GPU automatically:
Using 4 devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
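One way to express parameter replication in JAX is through `jax.sharding`. This is a sketch under assumptions, not necessarily the code that produced the output above: the `params` dictionary is a hypothetical stand-in for a model's parameter pytree, and the mesh axis name `'data'` is an arbitrary choice.

```python
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
print(f'Using {len(devices)} devices: {devices}')

# One-dimensional device mesh; the axis name 'data' is illustrative.
mesh = Mesh(np.array(devices), ('data',))

# An empty PartitionSpec splits no dimension, so every device holds a full copy.
replicated = NamedSharding(mesh, P())

# Hypothetical stand-in for a model's parameter pytree.
params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros((4,))}
params = jax.device_put(params, replicated)

# Each addressable shard is a complete copy of the array.
print(params['w'].sharding.is_fully_replicated)
print(params['w'].addressable_shards[0].data.shape)
```

The same code runs on a single CPU device, where the mesh has one entry and replication is trivial.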
The wrapper also handles inference by splitting the input minibatch across replicas and gathering the outputs.
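The split-and-gather pattern for inference can be sketched with a batch-sharded input and replicated parameters. The `predict` function here is a hypothetical single dense layer standing in for the network's apply function, and the shapes are illustrative assumptions.

```python
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ('data',))
batch_sharded = NamedSharding(mesh, P('data'))  # split axis 0 across devices
replicated = NamedSharding(mesh, P())           # full copy on every device

def predict(params, x):
    # Hypothetical stand-in for the network's forward pass.
    return x @ params['w'] + params['b']

params = jax.device_put(
    {'w': jnp.ones((16, 10)), 'b': jnp.zeros((10,))}, replicated)

# The batch size must be divisible by the number of devices.
batch_size = 8 * len(devices)
x = jax.device_put(jnp.ones((batch_size, 16)), batch_sharded)

# jit compiles one program; each device processes its own slice of the batch.
y = jax.jit(predict)(params, x)

# Converting to NumPy gathers the per-device pieces on the host.
y_host = np.asarray(y)
print(y_host.shape)
```

The calling code never indexes individual devices: the sharding attached to `x` decides how the work is split.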
The loop looks like ordinary single-GPU training because the wrapper owns the distributed work:
The important lesson is the interface: after wrapping the model, most training code should not need to know how many GPUs are present.
test acc: 0.92, 9.2 sec/epoch on 4 devices
Use this as the throughput baseline before the data-parallel wrapper adds replication and gradient averaging.
test acc: 0.91, 5.7 sec/epoch on 4 devices
The training loop is unchanged; the wrapper splits the minibatch and synchronizes gradients under the hood.
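What "splits the minibatch and synchronizes gradients" amounts to can be sketched with `jax.pmap` and `jax.lax.pmean`. This is an illustration under assumptions rather than the wrapper's actual implementation: the linear model, the learning rate of 0.05, and the axis name `'batch'` are all hypothetical choices.

```python
import functools
import jax
from jax import numpy as jnp

n_dev = jax.local_device_count()

def loss_fn(w, X, y):
    # Hypothetical linear-regression loss standing in for the real model.
    return jnp.mean((X @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def train_step(w, X, y):
    grads = jax.grad(loss_fn)(w, X, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='batch')
    return w - 0.05 * grads

# Synthetic data: 32 examples per device, 5 features.
X = jax.random.normal(jax.random.PRNGKey(0), (n_dev * 32, 5))
y = X @ jnp.arange(1.0, 6.0)

# Replicate parameters along a leading device axis.
w = jnp.zeros((5,))
w_rep = jnp.broadcast_to(w, (n_dev,) + w.shape)

# Split the minibatch: (devices, per-device batch, features).
X_sh = X.reshape(n_dev, -1, 5)
y_sh = y.reshape(n_dev, -1)

initial_loss = float(loss_fn(w, X, y))
for _ in range(100):
    w_rep = train_step(w_rep, X_sh, y_sh)
final_loss = float(loss_fn(w_rep[0], X, y))

print(f'loss: {initial_loss:.3f} -> {final_loss:.3f}')
```

The Python loop is the same one a single-device run would use; only the leading device axis on the inputs and the `pmean` inside the step reflect the number of devices. `jax.pmap` is one of several ways to do this in JAX, and `jax.jit` with sharded inputs is another.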
Framework wrappers (DataParallel, MirroredStrategy) reduce data-parallel SGD to one line of setup.
For multi-host training, DistributedDataParallel and MultiWorkerMirroredStrategy apply the same idea, with NCCL or Gloo carrying the all-reduce across the network.