GPUs

Working on GPUs

GPUs are the reason modern deep learning works at scale. A single RTX 4090 delivers roughly 80 TFLOPS of FP16 throughput, orders of magnitude more than a typical CPU on the matmul-heavy operations that convolutions and attention rely on.

The cost: every tensor and every parameter has a device. Mix devices in one operation and you crash. In PyTorch, for example, the failure reads:

RuntimeError: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!

Adding tensors that live on different devices: implicit copies are forbidden, so you must copy explicitly.

The two-and-a-half rules

  • Tensors live on a device. Cross-device operations require an explicit copy.
  • Model parameters live on a device. Move the model to the GPU before training; the optimizer’s state follows.
  • Cross-device copies are slow. Avoid them in the inner loop — copy at the boundary, keep the loop on one device.

What hardware do we have?

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

def cpu():
    """Get the CPU device."""
    return jax.devices('cpu')[0]

def gpu(i=0):
    """Get a GPU device."""
    return jax.devices('gpu')[i]

cpu(), gpu(), gpu(1)
(CpuDevice(id=0), CudaDevice(id=0), CudaDevice(id=1))
def num_gpus():
    """Get the number of available GPUs."""
    try:
        return jax.device_count('gpu')
    except RuntimeError:
        return 0  # No GPU backend found

num_gpus()
4
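To see everything JAX can reach on the current machine, a short summary per platform can help. The sketch below is an illustration only: `device_summary` is a hypothetical helper name, and it counts the devices of the default backend as reported by `jax.devices()`.

```python
from collections import Counter

import jax


def device_summary():
    """Hypothetical helper: count visible devices per platform name."""
    return dict(Counter(d.platform for d in jax.devices()))


summary = device_summary()
# Prints the default backend name and the per-platform device counts.
print(jax.default_backend(), summary)
```

On a CPU-only laptop the summary contains a single CPU entry; on a multi-GPU box it lists the GPUs instead, because `jax.devices()` reports the default backend.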

Portable device handle

try_gpu(i) returns GPU i if it exists, else CPU. Same code runs on a laptop, a workstation, or a multi-GPU box — the device object swaps but the code stays the same:

def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()."""
    if num_gpus() >= i + 1:
        return gpu(i)
    return cpu()

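def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    devices = [gpu(i) for i in range(num_gpus())]
    # Fall back to the CPU so the result is never an empty list
    return devices if devices else [cpu()]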

try_gpu(), try_gpu(10), try_all_gpus()
(CudaDevice(id=0),
 CpuDevice(id=0),
 [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)])
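As a usage illustration, the same script can run unchanged with or without a GPU once the device is chosen through a fallback helper. This is a minimal sketch: `pick_device` is a hypothetical stand-in for `try_gpu`, written so the block is self-contained, and it assumes that a missing GPU backend surfaces as a `RuntimeError`.

```python
import jax
from jax import numpy as jnp


def pick_device(i=0):
    """Hypothetical helper: GPU i if it exists, else the first CPU device."""
    try:
        gpus = jax.devices('gpu')
    except RuntimeError:  # assumption: raised when no GPU backend is present
        gpus = []
    return gpus[i] if i < len(gpus) else jax.devices('cpu')[0]


device = pick_device()
# Place the input on whichever device was selected, then compute there.
x = jax.device_put(jnp.arange(6.0).reshape(2, 3), device)
y = (x * 2).sum()
print(device.platform, float(y))
```

Only the device object differs between machines; the arithmetic and its result stay the same.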

Tensors carry a device

Every tensor has a .device attribute:

x = jnp.array([1, 2, 3])
x.device
CudaDevice(id=0)

Place an array on a chosen device when you create it, which avoids a later cross-device copy:

# By default JAX places arrays on a GPU or TPU if one is available
X = jax.device_put(jnp.ones((2, 3)), try_gpu())
X
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
Y = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (2, 3)),
                   try_gpu(1))
Y
Array([[0.947667  , 0.9785799 , 0.33229148],
       [0.46866846, 0.5698887 , 0.16550303]], dtype=float32)
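An alternative to wrapping each creation call in `jax.device_put` is the `jax.default_device` context manager, which changes where newly created arrays land inside the `with` block. The sketch below assumes the last visible device is the desired target; on a single-device machine that is simply the only device.

```python
import jax
from jax import numpy as jnp

# Assumption: target the last visible device (the only one on a CPU-only box).
device = jax.devices()[-1]

with jax.default_device(device):
    # Arrays created here are allocated on `device` without a device_put call.
    X = jnp.ones((2, 3))
    Y = jax.random.uniform(jax.random.PRNGKey(0), (2, 3))

print(X.devices(), Y.devices())
```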

Cross-device math: copy, then operate

Tensors on different devices can’t be combined directly. The fix: an explicit copy with jax.device_put(x, device), the JAX counterpart of PyTorch’s .cuda(i) or .to(device):

Z = jax.device_put(X, try_gpu(1))
print(X)
print(Z)
[[1. 1. 1.]
 [1. 1. 1.]]
[[1. 1. 1.]
 [1. 1. 1.]]
Y + Z
Array([[1.947667 , 1.9785799, 1.3322915],
       [1.4686685, 1.5698887, 1.165503 ]], dtype=float32)
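The copy-then-operate pattern can be wrapped in a small function so the target device is always explicit. This is a sketch, not a library API: `add_on` is a hypothetical helper, and the first and last visible devices stand in for two GPUs (they coincide on a single-device machine).

```python
import jax
from jax import numpy as jnp


def add_on(device, a, b):
    """Hypothetical helper: copy both operands to `device`, then add them."""
    return jax.device_put(a, device) + jax.device_put(b, device)


devices = jax.devices()
first, last = devices[0], devices[-1]  # identical if only one device exists

x = jax.device_put(jnp.ones((2, 3)), first)
y = jax.device_put(jnp.full((2, 3), 2.0), last)

out = add_on(last, x, y)
print(out.devices())
```

Passing the device as an argument keeps the decision about where the result lives in one visible place.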

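In PyTorch, .cuda(i) on a tensor already on GPU i is a no-op that returns the same tensor. With jax.device_put, the run below returns a distinct Python object, so the identity check is False even though the device and the values are unchanged: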

Z2 = jax.device_put(Z, try_gpu(1))
Z2 is Z
False
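Because object identity says little about placement, it is safer to compare devices and values directly. The following sketch assumes nothing beyond `jax.device_put`, `Array.devices()`, and `jnp.array_equal`; the variable names are illustrative.

```python
import jax
from jax import numpy as jnp

dev = jax.devices()[0]
z = jax.device_put(jnp.ones((2, 3)), dev)
z2 = jax.device_put(z, dev)  # the target is the device z already lives on

# Compare placement and contents instead of relying on `z2 is z`.
same_device = z2.devices() == z.devices()
same_values = bool(jnp.array_equal(z, z2))
print(same_device, same_values)
```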

Why this matters: a redundant device transfer in your training inner loop (.to(device) in PyTorch, jax.device_put in JAX) adds a cudaMemcpy round trip that can dwarf the actual computation. Copy at the boundary; keep everything inside the loop on one device.

Models on the GPU

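In PyTorch the model is a tree of Parameter tensors that you move in one shot with .to(device). In Flax the parameters are a separate pytree of arrays returned by net.init, and they are created on JAX's default device: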

net = nn.Sequential([nn.Dense(1)])

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (10,))  # Dummy input
params = net.init(key2, x)  # Initialization call
net.apply(params, x)
Array([-0.3464166], dtype=float32)

Check where the parameters live; every input batch must be on the same device before the forward pass:

print(jax.tree_util.tree_map(lambda x: x.device, params))
{'params': {'layers_0': {'bias': CudaDevice(id=0), 'kernel': CudaDevice(id=0)}}}
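If the parameters need to live somewhere other than the default device, `jax.device_put` accepts a whole pytree and places every leaf in one call. The sketch below uses a hand-built dictionary shaped like the Flax output above (an assumption made so the block runs without Flax) and targets the last visible device.

```python
import jax
from jax import numpy as jnp

# Hypothetical parameter tree, shaped like the Flax params printed above.
params = {'params': {'layers_0': {'kernel': jnp.ones((10, 1)),
                                  'bias': jnp.zeros((1,))}}}

# Assumption: move everything to the last visible device.
target = jax.devices()[-1]

# device_put walks the pytree and places each leaf on `target`.
params = jax.device_put(params, target)

# Report the device set of every leaf, mirroring the check above.
print(jax.tree_util.tree_map(lambda p: p.devices(), params))
```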

Where to put the device move

The training-loop sweet spot, in PyTorch-style pseudocode:

device = try_gpu(0)
model = MyModel().to(device)         # once, before training
opt = SGD(model.parameters(), …)     # picks up device

for batch in loader:
    X, y = batch[0].to(device), batch[1].to(device)
    # ... forward, loss, backward, step ...
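A JAX counterpart of the pseudocode above might look like the following sketch. Everything in it is illustrative: the linear model, the synthetic NumPy batches standing in for a data loader, the learning rate, and the names `loss_fn` and `sgd_step` are assumptions, not part of the d2l Trainer.

```python
import jax
import numpy as np
from jax import numpy as jnp

device = jax.devices()[0]  # assumption: train on the first visible device


def loss_fn(params, X, y):
    """Mean squared error of a linear model."""
    pred = X @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, X, y, lr=0.1):
    """One plain SGD update applied to every leaf of the parameter tree."""
    grads = jax.grad(loss_fn)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Move the parameters once, before training.
params = jax.device_put({'w': jnp.zeros((3,)), 'b': jnp.zeros(())}, device)

# Synthetic host-side batches standing in for a data loader.
rng = np.random.default_rng(0)
true_w = np.array([1.0, 2.0, 3.0], dtype=np.float32)

for _ in range(200):
    X_host = rng.normal(size=(32, 3)).astype(np.float32)
    y_host = X_host @ true_w
    # Copy each batch at the loop boundary; the update itself stays on device.
    X, y = jax.device_put((X_host, y_host), device)
    params = sgd_step(params, X, y)

print(params['w'])
```

The only transfers inside the loop are the two batch arrays; the parameters never leave the device between steps.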

The Trainer baseline follows this pattern: __init__ records the GPUs to use, prepare_batch moves every array in a batch to the first of them, and prepare_model (not shown here) moves parameters once:

@d2l.add_to_class(d2l.Trainer)
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]

@d2l.add_to_class(d2l.Trainer)
def prepare_batch(self, batch):
    if self.gpus:
        batch = [d2l.to(a, self.gpus[0]) for a in batch]
    return batch

Common mistakes

  • Forgetting one tensor. Every tensor in the forward pass has to be on the same device. Custom buffers are the usual culprit; in PyTorch, use register_buffer so they move with .to(device).
  • Creating tensors mid-forward. In PyTorch, torch.zeros((10,)) defaults to CPU. Use torch.zeros((10,), device=x.device) to follow the input.
  • Setting up the optimizer before the move. Construct the optimizer after .to(device); otherwise its state can end up on the wrong device.
  • Converting to NumPy mid-loop. .numpy() forces a sync to CPU, and the asynchronous CUDA stream stalls. Defer all conversions to the end of the epoch.
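The last point can be illustrated in JAX by keeping per-batch results as device arrays and converting them to NumPy once, after the loop. This is a sketch under simple assumptions: `batch_loss` and the constant-valued batches are made up for the example.

```python
import jax
import numpy as np
from jax import numpy as jnp


@jax.jit
def batch_loss(w, X):
    """Toy loss: mean of the squared linear outputs."""
    return jnp.mean((X @ w) ** 2)


w = jnp.ones((4,))
# Hypothetical batches: batch i is filled with the constant i.
batches = [jnp.full((8, 4), float(i)) for i in range(5)]

# Inside the loop, keep losses as device arrays; no host conversion here.
losses = [batch_loss(w, X) for X in batches]

# A single conversion to NumPy at the end of the "epoch".
epoch_losses = np.asarray(jnp.stack(losses))
print(epoch_losses)
```

Stacking first means the host waits for the device once rather than once per batch.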

Recap

  • Tensors and parameters carry a device; cross-device operations require an explicit copy.
  • Move the model to the GPU once, before training; the optimizer follows its parameters.
  • try_gpu(i) keeps code portable across hardware.
  • Cross-device copies are expensive — keep the inner loop device-clean.
  • In PyTorch, use register_buffer so non-trainable state moves alongside parameters.