
Training on Multiple GPUs

Scaling beyond one GPU

A single GPU can train ResNet on ImageNet — slowly. Modern large models need many GPUs. Three ways to split work:

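  • Network partitioning: different layers on different GPUs (the idea behind pipeline parallelism). Each GPU stores only its own layers, but the stages are hard to balance and activations must be passed from GPU to GPU; rarely used alone today.
  • Layerwise partitioning: split each layer’s work (e.g. its channels or weight columns) across GPUs (the idea behind tensor parallelism). Suited to giant layers whose weights don’t fit on one GPU, at the cost of exchanging activations after every layer.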
  • Data parallelism — replicate the model on every GPU; each processes a different minibatch chunk; gradients averaged across GPUs.

Data parallelism is the default for everyday training.
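To see that the three strategies compute the same function and differ only in what each device holds and exchanges, here is a minimal NumPy sketch. It assumes two simulated "devices" (plain arrays, no real GPUs) and a hypothetical two-layer MLP standing in for the real network.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))   # minibatch: 8 examples, 4 features
W1 = rng.normal(size=(4, 6))  # layer 1 weights (hypothetical toy MLP)
W2 = rng.normal(size=(6, 4))  # layer 2 weights

def relu(z):
    return np.maximum(z, 0.0)

# Reference: the whole network on one device.
reference = relu(X @ W1) @ W2

# Network partitioning: device A holds layer 1, device B holds layer 2.
h = relu(X @ W1)          # computed on device A
out_network = h @ W2      # activations h are shipped to device B

# Layerwise partitioning: each device holds half of every layer's output units.
W1_a, W1_b = np.split(W1, 2, axis=1)
W2_a, W2_b = np.split(W2, 2, axis=1)
h_full = np.concatenate([relu(X @ W1_a), relu(X @ W1_b)], axis=1)  # exchange activations
out_layerwise = np.concatenate([h_full @ W2_a, h_full @ W2_b], axis=1)

# Data parallelism: each device holds the full model and half of the minibatch.
X_a, X_b = np.split(X, 2, axis=0)
out_data = np.concatenate([relu(X_a @ W1) @ W2, relu(X_b @ W1) @ W2], axis=0)

for name, out in [("network", out_network), ("layerwise", out_layerwise), ("data", out_data)]:
    print(name, np.allclose(out, reference))
```

Layerwise partitioning needs an exchange after every layer (the concatenation that builds h_full), whereas data parallelism exchanges nothing in the forward pass and only synchronizes gradients after the backward pass.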

Strategies side by side

Figure: the original network alongside network partitioning, layerwise partitioning, and data parallelism.

Data parallelism

Each GPU computes a forward and backward pass on its slice of the minibatch. After the backward pass, the gradients are summed across GPUs with all_reduce (dividing by the batch size turns the sum into an average). The optimizer step then runs identically on every GPU, which keeps the replicas in sync.

Data-parallel SGD on 2 GPUs: split data, compute gradients independently, all-reduce, then update.
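Why the replicas stay in sync can be checked numerically. A minimal NumPy sketch, assuming a hypothetical linear model with a summed squared loss instead of LeNet and two simulated replicas: the sum of per-shard gradients equals the full-minibatch gradient, so the identical update on each replica matches a single-device SGD step.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 3))
y = rng.normal(size=6)
w = rng.normal(size=3)
lr = 0.1

def grad_sum_loss(X, y, w):
    """Gradient of sum_i 0.5 * (x_i . w - y_i)^2 with respect to w."""
    return X.T @ (X @ w - y)

# Single device: gradient over the whole minibatch.
g_full = grad_sum_loss(X, y, w)

# Two replicas start from the same weights; each sees half of the minibatch.
replicas = [w.copy(), w.copy()]
shards = zip(np.split(X, 2), np.split(y, 2))
local_grads = [grad_sum_loss(Xs, ys, w_r) for (Xs, ys), w_r in zip(shards, replicas)]

# all_reduce: every replica receives the same summed gradient.
g_sum = sum(local_grads)

# Identical SGD step on each replica; dividing by the batch size gives the mean gradient.
for w_r in replicas:
    w_r -= lr / X.shape[0] * g_sum

w_single = w - lr / X.shape[0] * g_full
print(np.allclose(g_sum, g_full),
      np.allclose(replicas[0], replicas[1]),
      np.allclose(replicas[0], w_single))
```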

%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
import keras

Tiny LeNet for the demo — small enough to fit on each GPU many times over:

# Initialize model parameters (NHWC layout for TF)
scale = 0.01
# W shape: (H, W, in_channels, out_channels) for TF conv2d
W1 = tf.Variable(tf.random.normal(shape=(3, 3, 1, 20)) * scale)
b1 = tf.Variable(tf.zeros(20))
W2 = tf.Variable(tf.random.normal(shape=(5, 5, 20, 50)) * scale)
b2 = tf.Variable(tf.zeros(50))
W3 = tf.Variable(tf.random.normal(shape=(800, 128)) * scale)
b3 = tf.Variable(tf.zeros(128))
W4 = tf.Variable(tf.random.normal(shape=(128, 10)) * scale)
b4 = tf.Variable(tf.zeros(10))
params = [W1, b1, W2, b2, W3, b3, W4, b4]

# Define the model (inputs in NHWC format)
def lenet(X, params):
    h1_conv = tf.nn.conv2d(X, params[0], strides=1, padding='VALID') + params[1]
    h1_activation = tf.nn.relu(h1_conv)
    h1 = tf.nn.avg_pool2d(h1_activation, ksize=2, strides=2, padding='VALID')
    h2_conv = tf.nn.conv2d(h1, params[2], strides=1, padding='VALID') + params[3]
    h2_activation = tf.nn.relu(h2_conv)
    h2 = tf.nn.avg_pool2d(h2_activation, ksize=2, strides=2, padding='VALID')
    h2 = tf.reshape(h2, (tf.shape(h2)[0], -1))
    h3_linear = tf.linalg.matmul(h2, params[4]) + params[5]
    h3 = tf.nn.relu(h3_linear)
    y_hat = tf.linalg.matmul(h3, params[6]) + params[7]
    return y_hat

# Cross-entropy loss function
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                      reduction='none')
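Where the 800 in W3's shape comes from: a short sketch tracing the spatial size of a 28x28 Fashion-MNIST image through the two 'VALID' convolutions and two 2x2 average pools in lenet, using output = (input - kernel) // stride + 1 for each window.

```python
def valid_out(size, kernel, stride=1):
    """Output length of a 'VALID' conv/pool window along one spatial axis."""
    return (size - kernel) // stride + 1

side = 28                            # Fashion-MNIST images are 28x28x1
side = valid_out(side, 3)            # conv 3x3   -> 26
side = valid_out(side, 2, stride=2)  # avg pool 2 -> 13
side = valid_out(side, 5)            # conv 5x5   -> 9
side = valid_out(side, 2, stride=2)  # avg pool 2 -> 4
flat = side * side * 50              # 50 output channels of the second conv
print(side, flat)                    # 4 800
```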

Distribute parameters

Replicate the parameter list onto each device:

def get_params(params, device):
    """Copy model parameters to a specific device and make them trainable."""
    with tf.device(device):
        new_params = [tf.Variable(tf.identity(p)) for p in params]
    return new_params

devices = tf.config.list_logical_devices('GPU')
new_params = get_params(params, devices[0].name)
print('b1 weight:', new_params[1].numpy())
# Unlike PyTorch tensors, a tf.Variable has no .grad attribute:
# gradients are returned by tf.GradientTape.gradient when they are needed.
b1 weight: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

all_reduce

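Sum vectors across GPUs and broadcast the result back. This is the gradient-aggregation primitive of data-parallel SGD: dividing the sum by the batch size gives the average gradient. The version below gathers everything on the device of data[0]; in production, libraries such as NCCL implement the same operation far more efficiently: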

def allreduce(data):
    """Sum tensors from all devices and broadcast the result back."""
    # Accumulate on the device of data[0]
    with tf.device(data[0].device):
        total = tf.add_n([tf.identity(d) for d in data])
    # Broadcast to all devices in-place
    for i in range(len(data)):
        with tf.device(data[i].device):
            data[i].assign(tf.identity(total))
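
devices = tf.config.list_logical_devices('GPU')
# Place one vector on each of two GPUs: [[1, 1]] on GPU 0, [[2, 2]] on GPU 1
data = []
for i in range(2):
    with tf.device(devices[i].name):
        data.append(tf.Variable(tf.ones((1, 2)) * (i + 1)))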
print('before allreduce:\n', data[0].numpy(), '\n', data[1].numpy())
allreduce(data)
print('after allreduce:\n', data[0].numpy(), '\n', data[1].numpy())
before allreduce:
 [[1. 1.]] 
 [[2. 2.]]
after allreduce:
 [[3. 3.]] 
 [[3. 3.]]
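The allreduce above funnels every tensor through one device. A ring avoids that hub. Below is a pure-Python simulation of ring all-reduce, assuming k devices that each hold their vector pre-split into k chunks (nested lists stand in for GPU buffers, and all sends within a step are treated as simultaneous). It is a sketch of the ring algorithm, not NCCL's implementation: a reduce-scatter phase leaves each device with one fully summed chunk, then an all-gather phase passes those chunks around until every device has the whole sum.

```python
def ring_allreduce(buffers):
    """Sum across devices in place. buffers[i][c] is chunk c (a list of floats) on device i."""
    k = len(buffers)
    # Phase 1: reduce-scatter. Each step, device i passes one partial sum to device i + 1.
    for step in range(k - 1):
        sends = [((i + 1) % k, (i - step) % k, list(buffers[i][(i - step) % k]))
                 for i in range(k)]  # snapshot all messages before applying any
        for dst, c, payload in sends:
            buffers[dst][c] = [a + b for a, b in zip(buffers[dst][c], payload)]
    # Device i now holds the complete sum of chunk (i + 1) % k.
    # Phase 2: all-gather. Pass completed chunks around the ring, overwriting stale ones.
    for step in range(k - 1):
        sends = [((i + 1) % k, (i + 1 - step) % k, list(buffers[i][(i + 1 - step) % k]))
                 for i in range(k)]
        for dst, c, payload in sends:
            buffers[dst][c] = payload

# Same example as above: [[1, 1]] on device 0, [[2, 2]] on device 1, one value per chunk.
bufs2 = [[[1.0], [1.0]], [[2.0], [2.0]]]
ring_allreduce(bufs2)
print(bufs2)     # [[[3.0], [3.0]], [[3.0], [3.0]]]

# Four devices; device i holds the value i + 1 in each of its four chunks.
bufs4 = [[[float(i + 1)] for _ in range(4)] for i in range(4)]
ring_allreduce(bufs4)
print(bufs4[0])  # [[10.0], [10.0], [10.0], [10.0]]
```

Each device only ever talks to its ring neighbour, so every link is busy in every step instead of all traffic converging on one device.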

Distribute the minibatch

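Split a tensor evenly across devices. With an integer split count, tf.split requires the batch dimension to be divisible by the number of devices: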

devices = tf.config.list_logical_devices('GPU')
data = tf.range(20, dtype=tf.float32)
data = tf.reshape(data, (4, 5))
split = tf.split(data, len(devices))
print('input :', data.numpy())
print('load into', [d.name for d in devices])
print('output:', [s.numpy() for s in split])
input : [[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]]
load into ['/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3']
output: [array([[0., 1., 2., 3., 4.]], dtype=float32), array([[5., 6., 7., 8., 9.]], dtype=float32), array([[10., 11., 12., 13., 14.]], dtype=float32), ...

def split_batch(X, y, devices):
    """Split `X` and `y` into multiple devices."""
    assert X.shape[0] == y.shape[0]
    return (tf.split(X, len(devices)), tf.split(y, len(devices)))
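Because split_batch passes an integer to tf.split, a batch of 256 cannot be split across 3 GPUs. tf.split also accepts a list of sizes as num_or_size_splits, so one workaround is to compute near-equal shard sizes first. A minimal pure-Python sketch; the helper name shard_sizes is hypothetical:

```python
def shard_sizes(batch_size, num_devices):
    """Near-equal shard sizes that sum to batch_size (the first shards get the extra rows)."""
    base, extra = divmod(batch_size, num_devices)
    return [base + 1 if i < extra else base for i in range(num_devices)]

print(shard_sizes(256, 3))  # [86, 85, 85]
print(shard_sizes(96, 4))   # [24, 24, 24, 24]
```

The list could then be passed as tf.split(X, shard_sizes(X.shape[0], len(devices))). Uneven shards do not change the update in train_batch, since the per-shard losses and gradients are summed and then divided by the full batch size X.shape[0].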

One step of multi-GPU training

Forward + backward on each replica → all_reduce gradients → update parameters identically:

def train_batch(X, y, device_params, devices, lr):
    X_shards, y_shards = split_batch(X, y, devices)
    # Compute loss and gradients independently on each GPU
    grads_list = []
    for X_shard, y_shard, device_W in zip(X_shards, y_shards, device_params):
        with tf.device(device_W[0].device):
            with tf.GradientTape() as tape:
                # SparseCategoricalCrossentropy: loss(y_true, y_pred)
                l = tf.reduce_sum(loss(y_shard, lenet(X_shard, device_W)))
            grads_list.append(tape.gradient(l, device_W))
    # All-reduce: sum gradients from every GPU onto GPU 0, then broadcast.
    # grads_list[c][i] is a plain Tensor (not a Variable), so we sum them
    # with tf.add_n and write the aggregated gradient into every replica.
    num_params = len(device_params[0])
    num_dev = len(devices)
    agg_grads = []
    for i in range(num_params):
        with tf.device(device_params[0][i].device):
            agg_grads.append(tf.add_n([
                tf.identity(grads_list[c][i]) for c in range(num_dev)]))
    # Apply SGD update on each GPU using the aggregated (all-reduced) gradient
    for device_W in device_params:
        for p, g in zip(device_W, agg_grads):
            with tf.device(p.device):
                # Denominator X.shape[0] normalizes for batch size.
                # Invariant: loss uses tf.reduce_sum (not mean) and
                # all-reduce uses tf.add_n (not mean), so dividing by
                # batch size here gives the correct per-sample gradient.
                p.assign_sub(lr / X.shape[0] * tf.identity(g))

def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = tf.config.list_logical_devices('GPU')[:num_gpus]
    # Copy model parameters to each GPU
    device_params = [get_params(params, d.name) for d in devices]
    num_epochs = 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    timer = d2l.Timer()
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            # TF data pipelines yield NHWC tensors — compatible with our lenet
            train_batch(X, y, device_params, devices, lr)
        timer.stop()
        # Evaluate the model on GPU 0
        animator.add(epoch + 1, (d2l.evaluate_accuracy(
            lambda x: lenet(x, device_params[0]), test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str([d.name for d in devices])}')

Single-GPU baseline

train(num_gpus=1, batch_size=256, lr=0.2)

test acc: 0.84, 2.5 sec/epoch on ['/device:GPU:0']

This gives the wall-clock reference point: one model copy, one minibatch stream, no gradient synchronization.

Two GPUs

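Keep the global batch size (256, i.e. 128 examples per GPU) and the learning rate unchanged. The optimization is mathematically the same, so test accuracy stays the same, and so does the number of steps per epoch. Wall-clock time does not improve; here it roughly doubles, because this LeNet is too small for the halved per-GPU compute to pay for the extra Python-level copying and synchronization: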

train(num_gpus=2, batch_size=256, lr=0.2)

test acc: 0.84, 5.1 sec/epoch on ['/device:GPU:0', '/device:GPU:1']
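The step count per epoch depends only on the global batch size, not on the number of GPUs. A quick check, assuming the 60,000-image Fashion-MNIST training set and that the final partial batch is kept rather than dropped:

```python
import math

num_train = 60_000   # Fashion-MNIST training images
batch_size = 256     # global minibatch, as in both train(...) calls

steps = math.ceil(num_train / batch_size)
last_batch = num_train - (steps - 1) * batch_size
for num_gpus in (1, 2):
    print(f'{num_gpus} GPU(s): {steps} steps/epoch, '
          f'{batch_size // num_gpus} examples per GPU per step')
print('last batch size:', last_batch)
```

Two GPUs halve the work per GPU per step but still take the same 235 steps, so any speedup has to come from that halved compute outweighing the added copies and synchronization, which for this small model it does not.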

Recap

  • Data parallelism is the default: replicate model, split minibatch, all-reduce gradients, identical optimizer step on every GPU.
  • all_reduce is the workhorse. NCCL implements it with ring (and tree) algorithms; a ring all-reduce sends 2(k-1)/k times the gradient size per GPU, which is bandwidth-optimal for k GPUs.
  • Model parallelism is for huge models that don’t fit on a single GPU; tensor / pipeline parallelism are modern variants.
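The bandwidth point can be made concrete with the LeNet used here. A back-of-the-envelope sketch, assuming float32 gradients (4 bytes per value) and counting only bytes sent: in a gather-to-one-device scheme like the allreduce function above, the hub sends the sum to k - 1 peers, (k - 1)N values, while in a ring each device sends 2(k - 1)N/k values.

```python
import math

# Parameter shapes of the LeNet above: W1, b1, W2, b2, W3, b3, W4, b4.
shapes = [(3, 3, 1, 20), (20,), (5, 5, 20, 50), (50,),
          (800, 128), (128,), (128, 10), (10,)]
num_params = sum(math.prod(s) for s in shapes)
grad_bytes = 4 * num_params  # assumption: float32 gradients

def hub_bytes_sent(k):
    """Bytes the hub device sends when broadcasting the summed gradient to k - 1 peers."""
    return (k - 1) * grad_bytes

def ring_bytes_sent(k):
    """Bytes each device sends in a ring all-reduce (reduce-scatter + all-gather)."""
    return 2 * (k - 1) * grad_bytes // k

print('params:', num_params, 'gradient bytes:', grad_bytes)
for k in (2, 4, 8):
    print(f'k={k}: hub sends {hub_bytes_sent(k)} B, each ring device sends {ring_bytes_sent(k)} B')
```

The hub's traffic grows linearly with k, while each ring device's traffic stays below twice the gradient size however many GPUs join.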