%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
import keras

A single GPU can train ResNet on ImageNet, but slowly, and modern large models need many GPUs. There are three ways to split the work across them: network partitioning, layerwise partitioning, and data parallelism. Data parallelism is the default for everyday training.

Figure: Original network, network partitioning, layerwise partitioning, data parallelism.
Each GPU runs a forward and backward pass on its slice of the minibatch. After the backward pass, the gradients are aggregated across GPUs with an all-reduce. Here they are summed and then scaled by the batch size, which amounts to averaging. The optimizer step then runs identically on every GPU, which keeps the replicas in sync.
Figure: Data-parallel SGD on two GPUs: split the data, compute gradients independently, all-reduce, then update.
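The following framework-free sketch makes the bookkeeping concrete. It is pure Python and uses a hypothetical toy model y_hat = w * x with squared loss, so no GPUs are involved and each "device" is just a shard of the minibatch. The training code later in this section follows the same pattern. The per-shard losses are summed rather than averaged, the all-reduce sums the shard gradients, and the update divides by the full batch size. With that convention, a two-device step produces exactly the same weight as a single-device step.

```python
def shard_grad(w, xs, ys):
    """d/dw of the *summed* squared loss sum((w*x - y)**2) over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))


def split(seq, n):
    """Split seq into n equal contiguous shards (assumes len(seq) % n == 0)."""
    k = len(seq) // n
    return [seq[i * k:(i + 1) * k] for i in range(n)]


def data_parallel_step(w, xs, ys, num_devices, lr):
    x_shards, y_shards = split(xs, num_devices), split(ys, num_devices)
    # 1. Every replica computes the gradient of its own shard independently.
    local_grads = [shard_grad(w, xb, yb) for xb, yb in zip(x_shards, y_shards)]
    # 2. All-reduce: every replica ends up with the same summed gradient.
    total = sum(local_grads)
    # 3. Identical SGD update on every replica; dividing by the full batch
    #    size turns the summed gradient into the gradient of the mean loss.
    return w - lr / len(xs) * total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # toy data generated by y = 2x
w_single = data_parallel_step(0.0, xs, ys, num_devices=1, lr=0.1)
w_multi = data_parallel_step(0.0, xs, ys, num_devices=2, lr=0.1)
print(w_single == w_multi)  # True
```

If the loss were averaged per shard and the all-reduce also averaged, the division by the batch size would have to be dropped. The convention must be consistent end to end.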
A tiny LeNet serves as the demo model. It is small enough to fit on each GPU many times over:
# Initialize model parameters (NHWC layout for TF)
scale = 0.01
# W shape: (H, W, in_channels, out_channels) for TF conv2d
W1 = tf.Variable(tf.random.normal(shape=(3, 3, 1, 20)) * scale)
b1 = tf.Variable(tf.zeros(20))
W2 = tf.Variable(tf.random.normal(shape=(5, 5, 20, 50)) * scale)
b2 = tf.Variable(tf.zeros(50))
W3 = tf.Variable(tf.random.normal(shape=(800, 128)) * scale)
b3 = tf.Variable(tf.zeros(128))
W4 = tf.Variable(tf.random.normal(shape=(128, 10)) * scale)
b4 = tf.Variable(tf.zeros(10))
params = [W1, b1, W2, b2, W3, b3, W4, b4]
# Define the model (inputs in NHWC format)
def lenet(X, params):
    h1_conv = tf.nn.conv2d(X, params[0], strides=1, padding='VALID') + params[1]
    h1_activation = tf.nn.relu(h1_conv)
    h1 = tf.nn.avg_pool2d(h1_activation, ksize=2, strides=2, padding='VALID')
    h2_conv = tf.nn.conv2d(h1, params[2], strides=1, padding='VALID') + params[3]
    h2_activation = tf.nn.relu(h2_conv)
    h2 = tf.nn.avg_pool2d(h2_activation, ksize=2, strides=2, padding='VALID')
    h2 = tf.reshape(h2, (tf.shape(h2)[0], -1))
    h3_linear = tf.linalg.matmul(h2, params[4]) + params[5]
    h3 = tf.nn.relu(h3_linear)
    y_hat = tf.linalg.matmul(h3, params[6]) + params[7]
    return y_hat
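Where does the 800 in W3's shape come from? The following arithmetic check assumes the 28 x 28 single-channel Fashion-MNIST images that train loads below. Each 'VALID' (unpadded) convolution or pooling layer maps a side length n to (n - k) // s + 1, where k is the kernel size and s is the stride. The helper conv_out is hypothetical and exists only for this check.

```python
def conv_out(n, k, s=1):
    """Side length after a 'VALID' (unpadded) conv/pool with kernel k, stride s."""
    return (n - k) // s + 1


h = w = 28                                    # Fashion-MNIST images are 28 x 28
h, w = conv_out(h, 3), conv_out(w, 3)          # 3x3 conv      -> 26 x 26 x 20
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)    # 2x2 avg pool  -> 13 x 13 x 20
h, w = conv_out(h, 5), conv_out(w, 5)          # 5x5 conv      ->  9 x  9 x 50
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)    # 2x2 avg pool  ->  4 x  4 x 50
flat = h * w * 50                              # 50 output channels from W2
print(flat)  # 800, the input dimension of W3
```

In the second pooling layer, 'VALID' padding silently drops the ninth row and column.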
# Cross-entropy loss function
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                     reduction='none')

Replicate the parameter list onto each device:
b1 weight: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
b1 grad: <tf.Variable 'Variable:0' shape=(20,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)>
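train calls a helper get_params(params, device), but its definition is not shown here. A TensorFlow version would presumably create new tf.Variable copies inside a with tf.device(device): block. The sketch below is framework-free: it uses a hypothetical get_params_sketch with plain lists instead of tensors, and the device name is only a label. It illustrates the one property the rest of the code relies on, which is that every device holds an independent copy.

```python
import copy


def get_params_sketch(params, device):
    """Hypothetical stand-in for get_params: return an independent deep copy
    of the parameter list, tagged with a device name (only a label here)."""
    return {"device": device, "params": copy.deepcopy(params)}


toy_params = [[0.01, -0.02], [0.0, 0.0]]  # toy weight and bias, plain lists
toy_devices = ["/device:GPU:0", "/device:GPU:1"]
replicas = [get_params_sketch(toy_params, d) for d in toy_devices]

# Changing one replica leaves the others (and the original) untouched, which
# is why gradients must be all-reduced before each replica takes the same step.
replicas[0]["params"][1][0] += 1.0
print(replicas[0]["params"][1], replicas[1]["params"][1])  # [1.0, 0.0] [0.0, 0.0]
```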
The all-reduce operation sums vectors across GPUs and broadcasts the result back to every GPU. It is the gradient-aggregation primitive of data-parallel SGD. Libraries such as NCCL implement it efficiently in production. A simple version:
def allreduce(data):
    """Sum tf.Variables from all devices and write the result back into each."""
    # Accumulate on the device of data[0]
    with tf.device(data[0].device):
        total = tf.add_n([tf.identity(d) for d in data])
    # Broadcast to all devices in-place (requires each data[i] to be a tf.Variable)
    for i in range(len(data)):
        with tf.device(data[i].device):
            data[i].assign(tf.identity(total))

before allreduce:
[[1. 1.]]
[[2. 2.]]
after allreduce:
[[3. 3.]]
[[3. 3.]]
Split a tensor evenly across devices:
input : [[ 0. 1. 2. 3. 4.]
[ 5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]]
load into ['/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3']
output: [array([[0., 1., 2., 3., 4.]], dtype=float32), array([[5., 6., 7., 8., 9.]], dtype=float32), array([[10., 11., 12., 13., 14.]], dtype=float32), ...
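train_batch calls a split_batch helper that is not defined in this section; only the output above appears here. The plain-Python sketch below uses a hypothetical split_batch_sketch with integers instead of floats. It reproduces the same even, contiguous split, but over two devices instead of four. In TensorFlow, tf.split(X, len(devices)) splits along axis 0 in the same way and likewise requires the batch dimension to be divisible by the number of pieces. Placing each shard on its device is left out of the sketch.

```python
def split_batch_sketch(X, y, devices):
    """Hypothetical stand-in for split_batch using plain lists.

    Splits the rows of X and the labels y into len(devices) equal,
    contiguous shards; shard i is meant for devices[i]. Assumes the
    batch size is divisible by the number of devices.
    """
    n = len(devices)
    if len(X) != len(y) or len(X) % n != 0:
        raise ValueError("batch size must match y and be divisible by len(devices)")
    k = len(X) // n
    X_shards = [X[i * k:(i + 1) * k] for i in range(n)]
    y_shards = [y[i * k:(i + 1) * k] for i in range(n)]
    return X_shards, y_shards


X_toy = [[5 * r + c for c in range(5)] for r in range(4)]  # same 4x5 layout as the input above
y_toy = [0, 1, 2, 3]                                        # hypothetical labels
X_shards, y_shards = split_batch_sketch(X_toy, y_toy, ["/device:GPU:0", "/device:GPU:1"])
print(X_shards[0])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
print(y_shards[1])  # [2, 3]
```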
Forward + backward on each replica → all_reduce gradients → update parameters identically:
def train_batch(X, y, device_params, devices, lr):
    X_shards, y_shards = split_batch(X, y, devices)
    # Compute loss and gradients independently on each GPU
    grads_list = []
    for X_shard, y_shard, device_W in zip(X_shards, y_shards, device_params):
        with tf.device(device_W[0].device):
            with tf.GradientTape() as tape:
                # SparseCategoricalCrossentropy: loss(y_true, y_pred)
                l = tf.reduce_sum(loss(y_shard, lenet(X_shard, device_W)))
            grads_list.append(tape.gradient(l, device_W))
    # All-reduce: sum the gradients from every GPU on the device holding the
    # first replica. grads_list[c][i] is a plain Tensor (not a Variable), so
    # instead of calling allreduce (which assigns into Variables) we sum with
    # tf.add_n; each replica copies the sum to its own device in the update below.
    num_params = len(device_params[0])
    num_dev = len(devices)
    agg_grads = []
    for i in range(num_params):
        with tf.device(device_params[0][i].device):
            agg_grads.append(tf.add_n([
                tf.identity(grads_list[c][i]) for c in range(num_dev)]))
    # Apply the SGD update on each GPU using the aggregated (all-reduced) gradient
    for device_W in device_params:
        for p, g in zip(device_W, agg_grads):
            with tf.device(p.device):
                # Dividing by X.shape[0] normalizes for batch size.
                # Invariant: the loss uses tf.reduce_sum (not mean) and the
                # all-reduce uses tf.add_n (not mean), so dividing by the
                # batch size here yields the gradient of the mean loss.
                p.assign_sub(lr / X.shape[0] * tf.identity(g))

def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = tf.config.list_logical_devices('GPU')[:num_gpus]
    # Copy model parameters to each GPU
    device_params = [get_params(params, d.name) for d in devices]
    num_epochs = 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    timer = d2l.Timer()
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            # TF data pipelines yield NHWC tensors, which matches lenet's layout
            train_batch(X, y, device_params, devices, lr)
        timer.stop()
        # Evaluate the model on GPU 0
        animator.add(epoch + 1, (d2l.evaluate_accuracy(
            lambda x: lenet(x, device_params[0]), test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str([d.name for d in devices])}')

test acc: 0.84, 2.5 sec/epoch on ['/device:GPU:0']
This single-GPU run gives the wall-clock baseline: one model copy, one minibatch stream, and no gradient synchronization.
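With two GPUs, each GPU processes half of every minibatch. The per-GPU work per step therefore halves, while the number of steps per epoch stays the same. In this run, however, per-epoch time roughly doubled rather than halved (5.1 vs. 2.5 sec/epoch). The most likely reason is that, for a model this small, copying gradients between devices and running the Python-level per-device loops cost more than the compute saved. Test accuracy is unchanged: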
test acc: 0.84, 5.1 sec/epoch on ['/device:GPU:0', '/device:GPU:1']
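all_reduce is the workhorse of data-parallel training. NCCL can implement it as a ring all-reduce, which is bandwidth-optimal. With k GPUs, each GPU sends (and receives) about 2(k - 1)/k times the gradient size, so per-GPU traffic stays nearly constant as k grows.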
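The ring algorithm can be simulated in a few lines of pure Python. This is an illustrative sketch with hypothetical names, not NCCL's implementation. Each of the k buffers is split into k chunks. A reduce-scatter phase of k - 1 steps leaves each rank holding one fully summed chunk. An all-gather phase of k - 1 more steps then circulates those finished chunks. In total, each rank sends 2(k - 1) chunks of size n/k.

```python
def ring_allreduce(buffers):
    """Simulate a ring all-reduce over k equal-length lists, one per rank.

    Returns (reduced copies, elements sent per rank). Assumes the buffer
    length is divisible by k. Illustrative only: real libraries pipeline
    the transfers and run them on actual interconnects.
    """
    k = len(buffers)
    n = len(buffers[0])
    if n % k != 0:
        raise ValueError("buffer length must be divisible by the number of ranks")
    size = n // k
    data = [list(b) for b in buffers]  # work on copies, leave inputs intact
    sent = [0] * k

    def chunk(c):
        # Index range of chunk c inside a buffer.
        return range(c * size, (c + 1) * size)

    # Phase 1, reduce-scatter: in step s, rank r sends chunk (r - s) % k to
    # rank (r + 1) % k, which adds it into its own copy of that chunk.
    for s in range(k - 1):
        msgs = [(r, (r - s) % k) for r in range(k)]
        payloads = [[data[r][i] for i in chunk(c)] for r, c in msgs]
        for (r, c), vals in zip(msgs, payloads):
            dst = (r + 1) % k
            for i, v in zip(chunk(c), vals):
                data[dst][i] += v
            sent[r] += size
    # Now rank r holds the fully summed chunk (r + 1) % k.

    # Phase 2, all-gather: in step s, rank r forwards chunk (r + 1 - s) % k,
    # and the receiver overwrites its copy with the finished sum.
    for s in range(k - 1):
        msgs = [(r, (r + 1 - s) % k) for r in range(k)]
        payloads = [[data[r][i] for i in chunk(c)] for r, c in msgs]
        for (r, c), vals in zip(msgs, payloads):
            dst = (r + 1) % k
            for i, v in zip(chunk(c), vals):
                data[dst][i] = v
            sent[r] += size
    return data, sent


# Three ranks, six gradient entries each.
grads = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
         [100.0, 200.0, 300.0, 400.0, 500.0, 600.0]]
result, sent = ring_allreduce(grads)
expected = [sum(col) for col in zip(*grads)]
print(all(row == expected for row in result))  # True
print(sent)  # [8, 8, 8], i.e. 2 * (k - 1) * n / k elements per rank
```

Every rank talks only to its right-hand neighbour, so all links carry traffic in every step. No single GPU becomes a hotspot the way GPU 0 does in the simple allreduce above.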