def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))

PyTorch / MXNet imperative mode means eager execution: each Python statement dispatches its operation as soon as it is reached. Easy to debug, but every op pays Python interpreter overhead, and because the framework only sees one op at a time, it cannot optimize the computation as a whole.
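The fix: trace or script the model into a graph, then let the framework JIT-compile it (TorchScript, MXNet hybridize(), TF @tf.function, JAX jax.jit). Python overhead is then paid once, when the graph is captured, instead of on every op, and the compiler can fuse operators and optimize memory layout. The size of the gain depends on the model; when per-op overhead dominates, it can exceed an order of magnitude.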
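Operator fusion can be illustrated without any framework. The sketch below uses plain Python lists as stand-ins for tensors and compares an unfused `relu(x + b)`, which materializes an intermediate list, with a fused version that does both steps in one pass. The function names are illustrative only; real compilers such as XLA fuse at the kernel level, where the saving is mainly memory traffic rather than Python work.

```python
# Plain-Python illustration of operator fusion (lists stand in for tensors).

def add_unfused(x, b):
    # First "kernel": writes a full intermediate result.
    return [xi + bi for xi, bi in zip(x, b)]

def relu_unfused(h):
    # Second "kernel": reads the intermediate back and writes the output.
    return [max(hi, 0.0) for hi in h]

def add_relu_fused(x, b):
    # Fused "kernel": one pass, no intermediate list is materialized.
    return [max(xi + bi, 0.0) for xi, bi in zip(x, b)]

x = [1.0, -2.0, 3.0, -4.0]
b = [0.5, 0.5, -5.0, 5.0]
unfused = relu_unfused(add_unfused(x, b))
fused = add_relu_fused(x, b)
print(unfused)           # [1.5, 0.0, 0.0, 1.0]
print(fused == unfused)  # True
```

Both versions compute the same values; the fused one simply never stores `x + b` as a separate buffer.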
Imperative execution: each line dispatches a separate kernel.
10
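To make "each line dispatches a separate op" concrete, the sketch below logs every call to `add`. The log shows that `fancy_func` issues three independent calls, each returning control to the Python interpreter before the next one starts; in a framework, each would be a separate kernel launch. The `calls` list is purely illustrative.

```python
# Imperative execution: every op is dispatched individually by Python.
calls = []

def add(a, b):
    calls.append(('add', a, b))  # record each dispatch as it happens
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))  # 10
print(calls)  # [('add', 1, 2), ('add', 3, 4), ('add', 3, 7)]
```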
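Imperative: Python-controlled, easy to print/debug, expensive per op. Symbolic: the graph is captured first, compiled once, and then runs as fused kernels. Modern frameworks let you switch between the two modes. To see what the symbolic style means, the plain-Python code below simulates it: it assembles the whole program as a string, compiles it once with compile(), and only then executes it: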
def add_():
    return '''
def add(a, b):
    return a + b
'''

def fancy_func_():
    return '''
def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
'''

def evoke_():
    # Concatenate the pieces into one complete program
    return add_() + fancy_func_() + 'print(fancy_func(1, 2, 3, 4))'

prog = evoke_()
print(prog)
y = compile(prog, '', 'exec')  # compile the whole program once
exec(y)                        # then run the compiled code object
def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
print(fancy_func(1, 2, 3, 4))
10
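Building source strings is one way to get the whole program before running it. Framework compilers such as jax.jit and torch.jit.trace get there by tracing instead: they call the Python function once with placeholder values that record each operation, then reuse the recorded graph. The toy below mimics that idea for scalar additions only; `Var`, `trace`, and `run` are hypothetical names, and real tracers also handle shapes, dtypes, control flow, and many more primitives.

```python
# Toy tracer: capture fancy_func as a graph once, then replay the graph.

class Var:
    """Symbolic placeholder that records additions instead of computing them."""
    def __init__(self, name, graph):
        self.name, self.graph = name, graph

    def __add__(self, other):  # assumes `other` is also a Var
        out = Var(f"t{len(self.graph)}", self.graph)
        self.graph.append((out.name, "add", self.name, other.name))
        return out

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

def trace(fn, arg_names):
    # Run fn once on placeholders; the graph fills up as a side effect.
    graph = []
    out = fn(*[Var(n, graph) for n in arg_names])
    return arg_names, graph, out.name

def run(traced, *values):
    # Replay the recorded graph on concrete inputs, without calling fn.
    arg_names, graph, out_name = traced
    env = dict(zip(arg_names, values))
    for dst, op, lhs, rhs in graph:  # only "add" is supported
        env[dst] = env[lhs] + env[rhs]
    return env[out_name]

traced = trace(fancy_func, ["a", "b", "c", "d"])
print(traced[1])  # [('t0', 'add', 'a', 'b'), ('t1', 'add', 'c', 'd'), ('t2', 'add', 't0', 't1')]
print(run(traced, 1, 2, 3, 4))  # 10
```

Once traced, `run` can be called with new inputs any number of times; the Python function itself is never re-executed.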
Build an MLP as a regular Flax module, then opt into graph mode by compiling it (JAX: jax.jit; PyTorch: torch.jit.script; MXNet: HybridSequential.hybridize(); TF: @tf.function):
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import numpy as np

# A three-layer MLP: 512 -> 256 -> 128 -> 2
class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(2)(x)
        return x

net = Net()
x = jnp.ones((1, 512))
# init runs the module on a sample input to infer parameter shapes
params = net.init(jax.random.PRNGKey(0), x)
net.apply(params, x)

Array([[ 0.5303471, -0.5135098]], dtype=float32)
Wall-clock benchmark, eager vs. compiled (here, the same network with and without jax.jit). The exact ratio depends on model size, op count, and hardware, but the gain is usually substantial:
Without jax.jit: 6.4955 sec
With jax.jit: 0.1588 sec
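The same eager-versus-compiled contrast can be reproduced with the standard library alone, reusing the compile/exec idea from earlier: re-parsing the program string on every call plays the role of per-op dispatch, and reusing one compiled code object plays the role of the cached graph. The `Benchmark` context manager here is a hypothetical helper written for this sketch, not the timer that produced the numbers above, and absolute times will vary by machine.

```python
import time

class Benchmark:
    """Minimal wall-clock timer used as a context manager (hypothetical helper)."""
    def __init__(self, description):
        self.description = description
        self.elapsed = 0.0

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.description}: {self.elapsed:.4f} sec")
        return False  # do not swallow exceptions

prog = """
def add(a, b):
    return a + b
def fancy_func(a, b, c, d):
    return add(add(a, b), add(c, d))
result = fancy_func(1, 2, 3, 4)
"""
n = 2000

with Benchmark("Compile on every call") as slow:
    for _ in range(n):
        exec(prog, {})  # exec on a str parses and compiles it each time

code = compile(prog, "<prog>", "exec")  # compile once
with Benchmark("Compile once, reuse") as fast:
    for _ in range(n):
        ns = {}
        exec(code, ns)  # only executes; no parsing or compiling
```

On typical machines the second loop is noticeably faster, because parsing and compiling dominate the cost of this tiny program.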
A graph is portable: save it once, then load and run it from C++, on mobile, or from another language without Python in the loop, which is why production deployments commonly ship the graph form. The line below reports the size of the network's serialized parameters:
Serialized parameter size: 658128 bytes
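The reported size can be checked by hand. Each `nn.Dense(k)` layer in `Net` stores an `in × k` kernel and a length-`k` bias, all float32 (4 bytes each). The arithmetic below, assuming the 512 → 256 → 128 → 2 layer sizes of `Net`, accounts for all but 200 bytes of the figure; attributing the remainder to the serialization format's own metadata is an assumption.

```python
# Parameter count for Net: Dense(256) -> Dense(128) -> Dense(2) on 512-dim inputs.
layer_sizes = [512, 256, 128, 2]
BYTES_PER_FLOAT32 = 4

total = 0
for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
    kernel = fan_in * fan_out  # weight matrix, shape (fan_in, fan_out)
    bias = fan_out             # bias vector, shape (fan_out,)
    print(f"{fan_in:>3} -> {fan_out:<3}: {kernel + bias} parameters")
    total += kernel + bias

raw_bytes = total * BYTES_PER_FLOAT32
print(total)               # 164482
print(raw_bytes)           # 657928
print(658128 - raw_bytes)  # 200 bytes not accounted for by the raw weights
```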
{ lambda ; a:f32[256] b:f32[512,256] c:f32[128] d:f32[256,128] e:f32[2] f:f32[128,2]
g:f32[1,512]. let
h:f32[1,256] = dot_general[dimension_numbers=(([1], [0]), ([], []))] g b
i:f32[1,256] = reshape[dimensions=None new_sizes=(1, 256) sharding=None] a
j:f32[1,256] = add h i
...
symbolic_zeros=False
] p
t:f32[1,2] = dot_general[dimension_numbers=(([1], [0]), ([], []))] q f
u:f32[1,2] = reshape[dimensions=None new_sizes=(1, 2) sharding=None] e
v:f32[1,2] = add t u
in (v,) }
The listing above is the computation graph (a jaxpr) that JAX captures when it traces the network; exposing the graph this way allows inspection and further optimization.
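A `let ... in` listing like this is just data: a list of equations, each naming an output, a primitive, and its inputs. The sketch below stores the `fancy_func` graph in a comparable but much simpler format, round-trips it through JSON to mimic saving and reloading a graph without the original Python function, prints it in a jaxpr-like style, and evaluates it with a small interpreter. The format and the helpers `PRIMITIVES`, `pretty`, and `evaluate` are hypothetical; real jaxprs also carry types, shapes, and primitive parameters.

```python
import json
import operator

# A captured graph as plain data: inputs, equations (out, primitive, args), outputs.
graph = {
    "inputs": ["a", "b", "c", "d"],
    "eqns": [["e", "add", ["a", "b"]],
             ["f", "add", ["c", "d"]],
             ["g", "add", ["e", "f"]]],
    "outputs": ["g"],
}

# Primitive table the interpreter dispatches on.
PRIMITIVES = {"add": operator.add, "mul": operator.mul}

def pretty(g):
    # Render the graph in a jaxpr-like "{ lambda ; ... let ... in ... }" form.
    lines = [f"{{ lambda ; {' '.join(g['inputs'])}. let"]
    for out, prim, args in g["eqns"]:
        lines.append(f"    {out} = {prim} {' '.join(args)}")
    lines.append(f"  in ({', '.join(g['outputs'])},) }}")
    return "\n".join(lines)

def evaluate(g, *values):
    # Walk the equations in order, storing each result in an environment.
    env = dict(zip(g["inputs"], values))
    for out, prim, args in g["eqns"]:
        env[out] = PRIMITIVES[prim](*(env[a] for a in args))
    return tuple(env[o] for o in g["outputs"])

saved = json.dumps(graph)   # "export": the graph is now a string
loaded = json.loads(saved)  # "import" elsewhere; fancy_func is not needed
print(pretty(loaded))
print(evaluate(loaded, 1, 2, 3, 4))  # (10,)
```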