from d2l import jax as d2l
import jax
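from jax import numpy as jnp

Once the framework runs operations asynchronously and tracks the dependencies between them, two kinds of parallelism come for free: computations on different devices can run at the same time, and computation on one device can overlap with data transfer to another.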
Figure: a two-layer MLP scheduled across a CPU and two GPUs; independent branches run in parallel.
You do not write any threading code: the dependency tracker schedules independent operations concurrently on its own. This deck quantifies the resulting speedup.
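A minimal sketch of what asynchronous dispatch looks like in JAX, assuming only a standard `jax` install (a CPU is enough to run it). Each call returns an array immediately, even if its value is still being computed; the two branches share no data, so the runtime is free to execute them concurrently, and the program only waits at `block_until_ready()`. The helper `branch` and the 512×512 sizes are hypothetical, chosen for illustration.

```python
import time

import jax
import jax.numpy as jnp


def branch(w, x):
    # Hypothetical MLP-style layer: matrix product followed by ReLU.
    return jax.nn.relu(jnp.dot(x, w))


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key1, (512, 512))
w1 = jax.random.normal(key2, (512, 512))
w2 = jax.random.normal(key3, (512, 512))

start = time.perf_counter()
# h1 and h2 do not depend on each other, so both can be enqueued at once.
h1 = branch(w1, x)
h2 = branch(w2, x)
# out depends on both branches; JAX tracks this dependency automatically.
out = h1 + h2
dispatch_time = time.perf_counter() - start

# Block until the whole chain has actually finished computing.
out.block_until_ready()
total_time = time.perf_counter() - start
print(f"dispatch: {dispatch_time:.4f} sec, total: {total_time:.4f} sec")
```

The gap between `dispatch_time` and `total_time` is the work that ran in the background; it is also why any timing code must block before stopping the clock.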
Run the same matrix-multiplication workload on GPU 0 and on GPU 1 separately (logged below as GPU1 and GPU2), then on both at the same time:
GPU1 time: 0.0792 sec
GPU2 time: 0.0829 sec
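The measurement above can be sketched as follows. This is an illustrative version, not the deck's exact code: `run`, `wait`, `measure`, the 500×500 size and `reps=10` are hypothetical choices, and if the machine has fewer than two devices both inputs fall back to the same device so the sketch still runs. On a machine with two GPUs, one would expect `t_both` to land closer to `max(t0, t1)` than to `t0 + t1`, because work is enqueued on both devices before the program waits on either.

```python
import time

import jax
import jax.numpy as jnp


def run(x, reps=10):
    # Hypothetical workload: repeated matrix products on x's device.
    return [jnp.dot(x, x) for _ in range(reps)]


def wait(results):
    # Async dispatch means we must block before stopping the clock.
    for r in results:
        r.block_until_ready()


def measure(fn):
    start = time.perf_counter()
    wait(fn())
    return time.perf_counter() - start


# Assumption: use the first two devices if available; on a single-device
# machine both inputs land on the same device.
devs = jax.devices()
dev0, dev1 = devs[0], devs[min(1, len(devs) - 1)]

x = jax.random.uniform(jax.random.PRNGKey(0), (500, 500))
x0 = jax.device_put(x, dev0)
x1 = jax.device_put(x, dev1)

# Warm up both devices so one-time compilation is not timed.
wait(run(x0, reps=1) + run(x1, reps=1))

t0 = measure(lambda: run(x0))
t1 = measure(lambda: run(x1))
# Enqueue all work on both devices before waiting on any result.
t_both = measure(lambda: run(x0) + run(x1))
print(f"device 0: {t0:.4f} sec, device 1: {t1:.4f} sec, both: {t_both:.4f} sec")
```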
Compute on GPU 0 and copy the results to the CPU, first sequentially and then overlapped:
Run on GPU1: 0.0785 sec
Copy to CPU: 0.9835 sec
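A hedged sketch of the sequential versus overlapped comparison, assuming `jax.device_get` as the device-to-host copy. The names `run` and `copy_to_host` and the sizes are hypothetical. In the sequential version every product finishes before any copy starts; in the overlapped version there is no barrier, so copying the first result can proceed while later products are still being computed. On a CPU-only machine the "copy" is nearly free, so the sketch runs but shows little difference.

```python
import time

import jax
import numpy as np


def run(x, reps=10):
    # Hypothetical workload: repeated matrix products on x's device.
    return [x @ x for _ in range(reps)]


def copy_to_host(ys):
    # jax.device_get waits for each array, then returns it as a NumPy array.
    return [jax.device_get(y) for y in ys]


dev = jax.devices()[0]  # first GPU if present, otherwise the CPU
x = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (500, 500)), dev)
copy_to_host(run(x, reps=1))  # warm-up so compilation is not timed

# Sequential: finish all computation, then start copying.
start = time.perf_counter()
ys = run(x)
for y in ys:
    y.block_until_ready()
t_compute = time.perf_counter() - start
start = time.perf_counter()
host = copy_to_host(ys)
t_copy = time.perf_counter() - start

# Overlapped: no barrier between compute and copy.
start = time.perf_counter()
host_overlap = copy_to_host(run(x))
t_both = time.perf_counter() - start

# Both strategies produce the same values; only the timing differs.
np.testing.assert_allclose(host[0], host_overlap[0])
print(f"compute: {t_compute:.4f} sec, copy: {t_copy:.4f} sec, both: {t_both:.4f} sec")
```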