
Concise Implementation for Multiple GPUs

Concise multi-GPU training

The previous section implemented data-parallel training the hard way, with a manual all_reduce and manual replica management. In practice, every major framework wraps this in a few lines of setup:

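  • PyTorch: nn.DataParallel(net) (a single process driving several GPUs on one host) or nn.parallel.DistributedDataParallel (one process per GPU, on one or many hosts; the PyTorch documentation recommends it over DataParallel).
  • MXNet: gluon.Trainer(..., kvstore='device').
  • TensorFlow: tf.distribute.MirroredStrategy().
  • JAX (used for the code below): jax.pmap over a per-device step function, with gradients averaged by jax.lax.pmean.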

The numerical recipe is the same as in the from-scratch version, with far less boilerplate; on NVIDIA GPUs the gradient all-reduce is typically delegated to NCCL.
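To make the hidden all-reduce concrete in this section's framework, here is a minimal JAX sketch, assuming only that jax is installed. jax.pmap runs a function once per local device, and jax.lax.pmean averages a value across those devices, which is the collective the wrappers above perform on gradients. The axis name 'devices' is an arbitrary label chosen for illustration; on a CPU-only machine the sketch runs with a single device.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One stand-in "gradient" per device: device i holds the value i.
per_device_grads = jnp.arange(n, dtype=jnp.float32)

# pmap maps the function over the leading axis, one slice per device;
# lax.pmean is an all-reduce (mean) over the named axis, so every
# device ends up holding the same averaged value.
allreduce_mean = jax.pmap(
    lambda g: jax.lax.pmean(g, axis_name='devices'), axis_name='devices')

averaged = allreduce_mean(per_device_grads)
print(averaged)  # every entry equals (n - 1) / 2
```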

from d2l import jax as d2l
import functools
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state
import flax
import numpy as np

We use a ResNet-18 variant for these experiments. Data parallelism only pays off once each device does enough compute per minibatch to outweigh the cost of synchronizing gradients:

class ResNet18(nn.Module):
    """A slightly modified ResNet-18 model."""
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='same'),
            nn.BatchNorm(not self.training),
            nn.relu,
            # ResNet blocks
            d2l.Residual(64, training=self.training),
            d2l.Residual(64, training=self.training),
            d2l.Residual(128, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(128, training=self.training),
            d2l.Residual(256, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(256, training=self.training),
            d2l.Residual(512, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(512, training=self.training),
            # Global average pooling and classifier
            lambda x: x.mean(axis=(1, 2)),
            nn.Dense(self.num_classes),
        ])

    def __call__(self, x):
        return self.net(x)

Multi-GPU initialization

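In PyTorch, the model is wrapped in a data-parallel container that replicates the parameters to each GPU automatically. JAX has no such container: the parameters are replicated explicitly (each array gains a leading device axis), and the per-device step is mapped over that axis with jax.pmap. Here we only query the available devices; the network itself is initialized inside the training loop: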

net = ResNet18(num_classes=10)
# Count available devices (GPUs/TPUs)
num_devices = jax.local_device_count()
print(f'Using {num_devices} devices: {jax.local_devices()}')
# We will initialize the network inside the training loop
Using 4 devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

Parallel evaluation

The same mechanism handles inference: the input minibatch is split across replicas, each replica runs the forward pass on its shard, and the outputs are gathered back into one batch.

Training loop

The loop looks like ordinary single-GPU training because the wrapper owns the distributed work:

  • scatter each minibatch across devices;
  • run the same model replica on each shard;
  • average gradients across replicas;
  • step one synchronized set of parameters.

The important lesson is the interface: after wrapping the model, most training code should not need to know how many GPUs are present.

Single-GPU baseline

train(num_devices=1, batch_size=256, lr=0.1)

test acc: 0.92, 9.2 sec/epoch on 1 device

This is the throughput baseline: a single device, with no replication and no gradient averaging.

Two GPUs

train(num_devices=2, batch_size=512, lr=0.2)

test acc: 0.91, 5.7 sec/epoch on 2 devices

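The training loop is unchanged; the wrapper splits each 512-example minibatch into two 256-example shards and synchronizes gradients under the hood, so the per-device batch stays at 256 while the learning rate doubles with the global batch. Epoch time drops from 9.2 s to 5.7 s, a speedup of about 1.6x rather than 2x, partly because gradient synchronization is not free.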
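The two runs' hyperparameters are consistent with keeping the per-device batch fixed and scaling the learning rate linearly with the global batch size (often called the linear scaling rule). A small arithmetic sketch; scaled_hparams is a hypothetical helper, not part of d2l, and its defaults are taken from the single-GPU run.

```python
def scaled_hparams(num_devices, per_device_batch=256, base_lr=0.1):
    """Hypothetical helper: fixed per-device batch, linearly scaled learning rate."""
    global_batch = num_devices * per_device_batch
    # base_lr was tuned for a global batch of per_device_batch examples.
    lr = base_lr * global_batch / per_device_batch
    return global_batch, lr


print(scaled_hparams(1))  # (256, 0.1)
print(scaled_hparams(2))  # (512, 0.2)
```

Linear scaling is a heuristic rather than a guarantee; very large global batches usually need a learning-rate warmup as well.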

Recap

  • Framework wrappers (DataParallel, MirroredStrategy) reduce data-parallel SGD to a few lines of setup; in JAX, jax.pmap with jax.lax.pmean plays the same role.
  • Same numerical recipe as the from-scratch version: replicate, split, all-reduce, identical step.
  • For multi-host distributed training, use DistributedDataParallel or MultiWorkerMirroredStrategy: the same idea, with NCCL or Gloo carrying the all-reduce across the network.