%matplotlib inline
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
A single GPU can train ResNet on ImageNet, but slowly, and modern large models need many GPUs. There are three ways to split the work: network partitioning, layerwise partitioning, and data parallelism.
Data parallelism is the default for everyday training.
Figure: splitting work across GPUs (panels: original, network partitioning, layerwise partitioning, data parallelism).
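Under data parallelism, each GPU runs a forward and backward pass on its slice of the minibatch. After the backward pass, the gradients are combined across GPUs with an all-reduce (a sum; dividing by the full minibatch size in the update turns it into an average). The optimizer step then runs identically on every GPU, which keeps the replicas in sync.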
Data-parallel SGD on 2 GPUs: split data, compute gradients independently, all-reduce, then update.
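Why summing per-GPU gradients and dividing by the full batch size reproduces single-device SGD: for a loss that is a mean over examples, the full-minibatch gradient equals the sum of the per-shard gradient sums divided by the total batch size. The following is a minimal pure-Python sketch (no GPUs needed), assuming a hypothetical one-parameter linear model with squared loss and two equal-size shards; the data values and helper names are made up for illustration.

```python
# Hypothetical 1-D linear model y_hat = w * x with squared loss (w*x - y)^2.
# Per-example gradient: d/dw = 2 * (w*x - y) * x.

def grad_sum(w, xs, ys):
    """Sum of per-example gradients over one shard (illustrative helper)."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))

def split(seq, k):
    """Split seq into k contiguous equal-length shards (len divisible by k)."""
    n = len(seq) // k
    return [seq[i * n:(i + 1) * n] for i in range(k)]

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 3.0, 5.0, 8.0]

# Single device: mean gradient over the whole minibatch.
full = grad_sum(w, xs, ys) / len(xs)

# Two simulated "GPUs": each sums gradients over its own shard ...
partials = [grad_sum(w, sx, sy) for sx, sy in zip(split(xs, 2), split(ys, 2))]
# ... an all-reduce leaves the total on every replica ...
reduced = sum(partials)
# ... and each replica divides by the FULL batch size, not its shard size.
parallel = reduced / len(xs)

print(full, parallel)
```

Dividing by the shard size instead would double the effective learning rate on two GPUs, which is why the update in the training code passes the full batch size.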
A tiny LeNet serves as the demo model; it is small enough to fit on each GPU many times over:
# Initialize model parameters
scale = 0.01
W1 = torch.randn(size=(20, 1, 3, 3)) * scale
b1 = torch.zeros(20)
W2 = torch.randn(size=(50, 20, 5, 5)) * scale
b2 = torch.zeros(50)
W3 = torch.randn(size=(800, 128)) * scale
b3 = torch.zeros(128)
W4 = torch.randn(size=(128, 10)) * scale
b4 = torch.zeros(10)
params = [W1, b1, W2, b2, W3, b3, W4, b4]
# Define the model
def lenet(X, params):
    h1_conv = F.conv2d(input=X, weight=params[0], bias=params[1])
    h1_activation = F.relu(h1_conv)
    h1 = F.avg_pool2d(input=h1_activation, kernel_size=(2, 2), stride=(2, 2))
    h2_conv = F.conv2d(input=h1, weight=params[2], bias=params[3])
    h2_activation = F.relu(h2_conv)
    h2 = F.avg_pool2d(input=h2_activation, kernel_size=(2, 2), stride=(2, 2))
    h2 = h2.reshape(h2.shape[0], -1)
    h3_linear = torch.mm(h2, params[4]) + params[5]
    h3 = F.relu(h3_linear)
    y_hat = torch.mm(h3, params[6]) + params[7]
    return y_hat
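The 800 rows of W3 come from the feature-map size left after the two convolution/pooling stages. Assuming 28×28 single-channel inputs (the Fashion-MNIST images the training code loads), no padding, and the kernel sizes used in lenet, the usual unpadded output-size formula floor((n − k) / s) + 1 gives the following; out_size is an illustrative helper, not part of d2l.

```python
def out_size(n, kernel, stride=1):
    """Side length after an unpadded convolution or pooling window (illustrative)."""
    return (n - kernel) // stride + 1

n = 28                    # assumed input side length (Fashion-MNIST)
n = out_size(n, 3)        # conv with 3x3 kernel (W1)     -> 26
n = out_size(n, 2, 2)     # 2x2 average pool, stride 2    -> 13
n = out_size(n, 5)        # conv with 5x5 kernel (W2)     -> 9
n = out_size(n, 2, 2)     # 2x2 average pool, stride 2    -> 4
flat = 50 * n * n         # 50 output channels from W2, flattened
print(n, flat)            # 4 800
```

This matches W3's shape of (800, 128); changing the input resolution or kernel sizes would require resizing W3 accordingly.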
# Cross-entropy loss function
loss = nn.CrossEntropyLoss(reduction='none')
Replicate the parameter list onto each device:
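A minimal sketch of the get_params helper that train calls, assuming it only needs to place a fresh copy of each tensor on the target device and mark it as requiring gradients, so that backward() fills in .grad on every replica independently. The copy=True argument is a choice made here so the replica is a separate tensor even when source and target device coincide; toy_params is a hypothetical stand-in for the real parameter list.

```python
import torch

def get_params(params, device):
    """Return copies of `params` on `device`, each a leaf that tracks gradients."""
    # detach() drops any autograd history; copy=True forces a new tensor even if
    # the parameter already lives on `device`.
    new_params = [p.detach().to(device, copy=True) for p in params]
    for p in new_params:
        p.requires_grad_()
    return new_params

# Usage: replicate a tiny parameter list onto the CPU (stand-in for a GPU).
toy_params = [torch.randn(3, 2) * 0.01, torch.zeros(2)]
replica = get_params(toy_params, torch.device('cpu'))
print([p.shape for p in replica], replica[0].requires_grad)
```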
The allreduce operation sums vectors across GPUs and broadcasts the result back to every GPU; it is the gradient-averaging primitive of data-parallel SGD. In production, libraries such as NCCL implement it efficiently:
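A naive sketch of the allreduce helper used in train_batch, assuming its argument is a list of same-shaped tensors, one per device: accumulate everything into the first tensor, then copy the total back out. Unlike NCCL it funnels all traffic through the first device, but the result is the same. The usage lines pick cuda:0 and cuda:1 when present and otherwise fall back to CPU tensors, so with two GPUs they should print tensors of the same form as the output that follows.

```python
import torch

def allreduce(data):
    """In-place sum over a list of tensors (one per device); every entry ends with the total."""
    # Reduce: add every tensor into data[0], moving it to data[0]'s device first.
    for i in range(1, len(data)):
        data[0][:] += data[i].to(data[0].device)
    # Broadcast: overwrite every other tensor with the total, on its own device.
    for i in range(1, len(data)):
        data[i][:] = data[0].to(data[i].device)

# Usage: use two GPUs when available, otherwise CPU stand-ins.
devices = [torch.device(f'cuda:{i}') if torch.cuda.device_count() > i
           else torch.device('cpu') for i in range(2)]
data = [torch.ones((1, 2), device=devices[i]) * (i + 1) for i in range(2)]
print('before allreduce:\n', data[0], '\n', data[1])
allreduce(data)
print('after allreduce:\n', data[0], '\n', data[1])
```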
before allreduce:
tensor([[1., 1.]], device='cuda:0')
tensor([[2., 2.]], device='cuda:1')
after allreduce:
tensor([[3., 3.]], device='cuda:0')
tensor([[3., 3.]], device='cuda:1')
Split a tensor evenly across devices:
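A portable sketch of the splitting step, assuming the batch size is divisible by the number of devices (torch.chunk can return fewer chunks otherwise). It uses torch.chunk plus .to(device) rather than a GPU-only scatter primitive, so it also runs on CPU; split_data is a hypothetical helper name, and split_batch applies the same split to features and labels, matching how train_batch calls it.

```python
import torch

def split_data(data, devices):
    """Split `data` along dim 0 into len(devices) chunks, one per device."""
    chunks = torch.chunk(data, len(devices), dim=0)
    return tuple(c.to(d) for c, d in zip(chunks, devices))

def split_batch(X, y, devices):
    """Split features X and labels y consistently across devices."""
    assert X.shape[0] == y.shape[0], 'X and y must have the same batch size'
    return split_data(X, devices), split_data(y, devices)

# Usage: two CPU devices stand in for two GPUs.
devices = [torch.device('cpu'), torch.device('cpu')]
data = torch.arange(20).reshape(4, 5)
print('input :', data)
print('load into', devices)
print('output:', split_data(data, devices))
```

Contiguous chunks keep each example's features and label on the same device, since both are split at the same row boundaries.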
input : tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
load into [device(type='cuda', index=0), device(type='cuda', index=1)]
output: (tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]], device='cuda:0'), tensor([[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]], device='cuda:1'))
Training on one minibatch: forward and backward pass on each replica → allreduce the gradients → update the parameters identically on every GPU:
def train_batch(X, y, device_params, devices, lr):
    X_shards, y_shards = split_batch(X, y, devices)
    # Loss is calculated separately on each GPU
    ls = [loss(lenet(X_shard, device_W), y_shard).sum()
          for X_shard, y_shard, device_W in zip(
              X_shards, y_shards, device_params)]
    for l in ls:  # Backpropagation is performed separately on each GPU
        l.backward()
    # Sum all gradients from each GPU and broadcast them to all GPUs
    with torch.no_grad():
        for i in range(len(device_params[0])):
            allreduce([device_params[c][i].grad for c in range(len(devices))])
    # The model parameters are updated separately on each GPU
    for param in device_params:
        # Divide by the full batch size: the all-reduced gradients are sums over all shards
        d2l.sgd(param, lr, X.shape[0])
def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = [d2l.try_gpu(i) for i in range(num_gpus)]
    # Copy model parameters to `num_gpus` GPUs
    device_params = [get_params(params, d) for d in devices]
    num_epochs = 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    timer = d2l.Timer()
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            # Perform multi-GPU training for a single minibatch
            train_batch(X, y, device_params, devices, lr)
            torch.cuda.synchronize()
        timer.stop()
        # Evaluate the model on GPU 0
        animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(
            lambda x: lenet(x, device_params[0]), test_iter, devices[0]),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(devices)}')
test acc: 0.83, 1.7 sec/epoch on [device(type='cuda', index=0)]
This gives the wall-clock reference point: one model copy, one minibatch stream, no gradient synchronization.
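With two GPUs and the same batch size and learning rate, each GPU processes half of every minibatch while the number of steps per epoch stays the same, and test accuracy is essentially unchanged. Per-epoch time does not improve in this run (2.3 s versus 1.7 s): the model is so small that gradient synchronization and Python overhead likely outweigh the halved per-GPU compute: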
test acc: 0.84, 2.3 sec/epoch on [device(type='cuda', index=0), device(type='cuda', index=1)]
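all_reduce is the workhorse of data-parallel training. NCCL can run it as a ring all-reduce, which is bandwidth-optimal: with k GPUs, each GPU sends and receives about 2(k − 1)/k times the gradient size per step, so per-GPU traffic barely grows with k.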
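To make the ring idea concrete, here is a pure-Python simulation of the textbook ring all-reduce; it is not NCCL's code, and NCCL may choose other algorithms (such as trees) depending on topology and message size. Each of k workers splits its vector into k chunks. In k − 1 reduce-scatter steps every worker passes one chunk to its right neighbour, which adds it in; in k − 1 all-gather steps the fully reduced chunks circulate around the ring. The function name ring_allreduce is hypothetical, and the sketch assumes the vector length is divisible by k.

```python
def ring_allreduce(vectors):
    """Simulate ring all-reduce; return per-worker results and elements sent per worker."""
    k = len(vectors)
    n = len(vectors[0])
    assert n % k == 0, 'sketch assumes vector length divisible by worker count'
    size = n // k
    # buf[w][c] is worker w's current copy of chunk c.
    buf = [[list(v[c * size:(c + 1) * size]) for c in range(k)] for v in vectors]
    sent = [0] * k

    # Reduce-scatter: at step s, worker w sends chunk (w - s) % k to worker w + 1,
    # which adds it to its own copy. Messages are snapshotted before delivery.
    for s in range(k - 1):
        msgs = [(w, (w - s) % k, list(buf[w][(w - s) % k])) for w in range(k)]
        for w, c, payload in msgs:
            dst = (w + 1) % k
            buf[dst][c] = [a + b for a, b in zip(buf[dst][c], payload)]
            sent[w] += len(payload)
    # Now worker w holds the complete sum of chunk (w + 1) % k.

    # All-gather: at step s, worker w forwards chunk (w + 1 - s) % k to worker w + 1,
    # which overwrites its copy with the fully reduced chunk.
    for s in range(k - 1):
        msgs = [(w, (w + 1 - s) % k, list(buf[w][(w + 1 - s) % k])) for w in range(k)]
        for w, c, payload in msgs:
            buf[(w + 1) % k][c] = payload
            sent[w] += len(payload)

    results = [[x for chunk in b for x in chunk] for b in buf]
    return results, sent

# Usage: 3 workers, vectors of length 6.
vectors = [[w * 10 + i for i in range(6)] for w in range(3)]
results, sent = ring_allreduce(vectors)
print(results[0], sent)   # [30, 33, 36, 39, 42, 45] [8, 8, 8]
```

Each worker sends 2(k − 1) chunks of n/k elements, i.e. 2(k − 1)n/k in total, matching the traffic figure in the paragraph above.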