from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

GPUs are the reason modern deep learning works at scale. A single 4090 does ~80 TFLOPs of FP16, about a thousand times faster than a CPU on the matmul-heavy ops that convolutions and attention need.
The cost: every tensor and every parameter lives on a specific device. Mix devices in one operation and the operation fails.
Adding tensors that live on different devices raises an error: implicit copies are forbidden, so you must copy explicitly.
RuntimeError: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
(CpuDevice(id=0), CudaDevice(id=0), CudaDevice(id=1))
try_gpu(i) returns GPU i if it exists, else CPU. Same code runs on a laptop, a workstation, or a multi-GPU box — the device object swaps but the code stays the same:
(CudaDevice(id=0),
CpuDevice(id=0),
[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)])
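The fallback behaviour described above could be sketched in JAX roughly as follows. This is a minimal sketch, not the d2l source: the helper names `cpu`, `num_gpus`, `try_gpu` and `try_all_gpus` mirror the names used in the text, but their bodies here are assumptions built only on `jax.devices`.

```python
import jax

def cpu():
    """Return the first CPU device (JAX always exposes a CPU backend)."""
    return jax.devices('cpu')[0]

def num_gpus():
    """Count visible GPUs; jax.devices('gpu') raises RuntimeError if there is no GPU backend."""
    try:
        return len(jax.devices('gpu'))
    except RuntimeError:
        return 0

def try_gpu(i=0):
    """Return GPU i if it exists, otherwise fall back to the CPU."""
    if num_gpus() > i:
        return jax.devices('gpu')[i]
    return cpu()

def try_all_gpus():
    """Return every visible GPU, or an empty list when there is none (an assumption of this sketch)."""
    return [jax.devices('gpu')[i] for i in range(num_gpus())]

print(try_gpu(), try_gpu(10), try_all_gpus())
```

Because the fallback is resolved once inside `try_gpu`, the calling code never has to branch on the hardware it happens to run on.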
Every tensor has a .device attribute:
CudaDevice(id=0)
Create directly on a device — avoids an unnecessary CPU → GPU copy:
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
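One way to place a new array on a chosen device in JAX is sketched below. It assumes only `jax.default_device`, `jax.device_put` and the array method `.devices()`. The variable `device` simply picks the first device of the default backend, so the sketch also runs on a CPU-only machine.

```python
import jax
from jax import numpy as jnp

device = jax.devices()[0]   # first device of the default backend (a GPU if one is visible)

# Option 1: allocate inside a default_device scope, so the array is created there.
with jax.default_device(device):
    X = jnp.ones((2, 3))

# Option 2: build the array on the default device, then commit it to `device` explicitly.
Y = jax.device_put(jnp.ones((2, 3)), device)

print(X.devices(), Y.devices())   # each is a set containing the device that holds the array
```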
Tensors on different devices can’t be combined directly. The fix: explicit copy with .cuda(i) or .to(device):
[[1. 1. 1.]
[1. 1. 1.]]
[[1. 1. 1.]
[1. 1. 1.]]
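The text names the PyTorch-style calls `.cuda(i)` and `.to(device)`. Since the imports at the top of this section are JAX, a rough JAX counterpart is `jax.device_put`. The sketch below assumes nothing beyond that call. `src` and `dst` are illustrative names, and they refer to the same device when only one is visible.

```python
import jax
from jax import numpy as jnp

devices = jax.devices()
src, dst = devices[0], devices[-1]   # identical when only one device is visible

X = jax.device_put(jnp.ones((2, 3)), src)
Y = jax.device_put(jnp.ones((2, 3)), dst)

Z = jax.device_put(X, dst)   # explicit copy of X onto the device that holds Y
out = Y + Z                  # both operands now live on dst, so the add is legal
print(out.devices())
```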
.cuda(i) on a tensor already on GPU i is a no-op — the framework checks first:
False
Why this matters: a .to(device) in your training inner loop adds a cudaMemcpy round trip that can dwarf the actual computation. Copy at the boundary; keep everything inside the loop on one device.
The model is a tree of Parameter tensors. Move them all in one shot with .to(device):
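In JAX the parameter tree is an ordinary pytree, so a single `jax.device_put` call can move every leaf at once. The sketch below uses a hand-written dictionary as a hypothetical stand-in for the parameters a Flax model would return from `init`.

```python
import jax
from jax import numpy as jnp

device = jax.devices()[0]

# Hypothetical parameter tree, shaped like what a small dense layer might hold.
params = {'dense': {'kernel': jnp.ones((3, 4)), 'bias': jnp.zeros((4,))}}

# device_put accepts a whole pytree and places every leaf on `device`.
params = jax.device_put(params, device)

leaf_devices = [leaf.devices() for leaf in jax.tree_util.tree_leaves(params)]
print(leaf_devices)
```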
The training-loop sweet spot:
device = try_gpu(0)
model = MyModel().to(device)            # once, before training
opt = SGD(model.parameters(), …)        # picks up device
for batch in loader:
    X, y = batch[0].to(device), batch[1].to(device)
    # ... forward, loss, backward, step ...
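The same pattern can be sketched in JAX: place the parameters once, move each batch at the loop boundary, and keep the update step on the device. Everything below is an illustrative assumption, including the toy linear-regression data, the names `loss_fn`, `sgd_step` and `LR`, and the plain SGD update. It is not the Trainer implementation.

```python
import jax
from jax import numpy as jnp
import numpy as np

device = jax.devices()[0]
LR = 0.1  # hypothetical learning rate

def loss_fn(params, X, y):
    pred = X @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, X, y):
    # Forward, backward and update all run on the device holding the inputs.
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

# Hypothetical host-side data for y = 2x + 1, split into 4 batches of 32.
rng = np.random.default_rng(0)
X_host = rng.normal(size=(128, 1)).astype(np.float32)
y_host = 2.0 * X_host[:, 0] + 1.0
loader = [(X_host[i:i + 32], y_host[i:i + 32]) for i in range(0, 128, 32)]

# Parameters go to the device once, before training.
params = jax.device_put({'w': jnp.zeros((1,)), 'b': jnp.zeros(())}, device)

for epoch in range(50):
    for X, y in loader:
        # Batches are copied at the loop boundary, one transfer per batch.
        X, y = jax.device_put(X, device), jax.device_put(y, device)
        params, loss = sgd_step(params, X, y)

final_loss = float(loss)  # single conversion to a Python float after training
print(final_loss)
```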
The Trainer baseline does exactly this — patch prepare_batch to call .to(device) and prepare_model to move parameters once:
@d2l.add_to_class(d2l.Trainer)
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]

@d2l.add_to_class(d2l.Trainer)
def prepare_batch(self, batch):
    if self.gpus:
        batch = [d2l.to(a, self.gpus[0]) for a in batch]
    return batch

… register_buffer so they move with .to(device).
torch.zeros((10,)) mid-forward defaults to CPU. Use torch.zeros((10,), device=x.device) to follow the input.
… .to(device), otherwise its state lives on the wrong side.
.numpy() mid-loop forces a sync to CPU. The asynchronous CUDA stream stalls. Defer all conversions to the end of the epoch.
try_gpu(i) keeps code portable across hardware.
… register_buffer so non-trainable state moves alongside parameters.
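The advice about deferring host conversions has a rough JAX analogue: converting a device array with `float(...)` or `np.asarray(...)` waits for the pending computation and copies the result to the host. A minimal sketch, assuming a hypothetical `batch_loss` function and made-up batches, keeps per-batch losses as device arrays and converts them once at the end.

```python
import jax
from jax import numpy as jnp
import numpy as np

@jax.jit
def batch_loss(x):
    # Hypothetical per-batch metric; stands in for a real loss.
    return jnp.mean(x ** 2)

# Made-up batches: batch i is filled with the value i.
batches = [jnp.full((8,), float(i)) for i in range(4)]

losses = []
for x in batches:
    losses.append(batch_loss(x))   # stays a device array, no host conversion here

# One device-to-host conversion for the whole epoch.
epoch_losses = np.asarray(jnp.stack(losses))
print(epoch_losses)
```

Collecting device arrays in a Python list costs nothing on the host side; the only blocking point is the final `np.asarray` call.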