Compilers and Interpreters

From Eager to Graph Execution

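PyTorch and MXNet run imperatively by default, and so does JAX without jit. This is eager execution: each Python statement is dispatched as soon as the interpreter reaches it, one operator at a time. Debugging is easy, but every operator pays Python and dispatch overhead. The framework also never sees the whole program, so it cannot optimize across operators.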

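The fix is to trace or script the model into a graph and let the framework JIT-compile it. The mechanisms are TorchScript in PyTorch, hybridize() in MXNet, @tf.function in TensorFlow, and jax.jit in JAX. The compiled graph avoids per-operator Python overhead, which for small, op-heavy models can mean a speedup of an order of magnitude or more. It also lets the compiler fuse operators and optimize memory layout.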
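Tracing can be sketched in a few lines of plain Python. Everything here is hypothetical: Node, trace, and run_graph are our own names, and only + and * between traced values are supported. The function runs once on placeholder objects that record each operation. The recorded graph can then be evaluated without calling the original function again. A real framework would hand that graph to a compiler such as XLA rather than replay it in Python.

```python
# Toy tracer (hypothetical, not any framework's API): run a function once on
# placeholder nodes, record every operation, then replay the recorded graph.

class Node:
    """A symbolic value produced while tracing."""
    def __init__(self, graph, op, inputs):
        self.graph = graph
        self.op = op            # 'input', 'add', or 'mul'
        self.inputs = inputs    # arg position for 'input', node indices otherwise
        self.index = len(graph)
        graph.append(self)      # record this operation in the graph

    def __add__(self, other):
        return Node(self.graph, 'add', (self.index, other.index))

    def __mul__(self, other):
        return Node(self.graph, 'mul', (self.index, other.index))


def trace(fn, num_args):
    """Run fn once on placeholders; return (graph, index of the output node)."""
    graph = []
    args = [Node(graph, 'input', (i,)) for i in range(num_args)]
    out = fn(*args)
    return graph, out.index


def run_graph(graph, out_index, *values):
    """Evaluate the recorded graph on concrete inputs, in recording order."""
    results = []
    for node in graph:
        if node.op == 'input':
            results.append(values[node.inputs[0]])
        elif node.op == 'add':
            results.append(results[node.inputs[0]] + results[node.inputs[1]])
        elif node.op == 'mul':
            results.append(results[node.inputs[0]] * results[node.inputs[1]])
    return results[out_index]


def fancy_func(a, b, c, d):
    return (a + b) + (c + d)

graph, out = trace(fancy_func, 4)          # trace once
print([n.op for n in graph])               # ['input', 'input', 'input', 'input', 'add', 'add', 'add']
print(run_graph(graph, out, 1, 2, 3, 4))   # 10
print(run_graph(graph, out, 10, 20, 30, 40))  # 100, no re-tracing needed
```

Once the whole computation is available as data, a compiler can inspect it, fuse the three additions, and emit a single kernel. Python-level execution never offers that opportunity.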

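Imperative execution: each statement runs as soon as Python reaches it. Below, fancy_func makes three separate add calls, and the intermediates e and f stay alive until the function returns. With framework tensors instead of Python ints, each add would be a separate kernel launch.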

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

print(fancy_func(1, 2, 3, 4))
10
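The following sketch shows what a compiler gains once it sees all three additions together. It can replace them with one fused computation, avoiding three separate calls and two named intermediates. It is plain Python: fancy_func_fused is a hypothetical hand-fused version, and Python ints have no kernels. The timings it prints only illustrate per-call overhead and vary by machine.

```python
import timeit

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    # Imperative: three separate calls; intermediates e and f are materialized.
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

def fancy_func_fused(a, b, c, d):
    # Hypothetical "fused" version: same evaluation order, one expression,
    # no helper calls and no named intermediates.
    return (a + b) + (c + d)

assert fancy_func(1, 2, 3, 4) == fancy_func_fused(1, 2, 3, 4) == 10

# Per-call overhead is what graph compilation removes; numbers vary by machine.
n = 100_000
t_eager = timeit.timeit(lambda: fancy_func(1, 2, 3, 4), number=n)
t_fused = timeit.timeit(lambda: fancy_func_fused(1, 2, 3, 4), number=n)
print(f'eager: {t_eager:.4f} s, fused: {t_fused:.4f} s')
```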

Imperative vs symbolic

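Imperative code is controlled by Python, easy to print and debug, and expensive per operation. Symbolic code is first defined as a whole graph, compiled once, and then executed, possibly as fused kernels. Modern frameworks let you switch between the two modes. The toy example below imitates the symbolic style in plain Python: it assembles the program as a string, compiles it once with compile, and then executes it: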

def add_():
    return '''
def add(a, b):
    return a + b
'''

def fancy_func_():
    return '''
def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
'''

def evoke_():
    return add_() + fancy_func_() + 'print(fancy_func(1, 2, 3, 4))'

prog = evoke_()
print(prog)
y = compile(prog, '', 'exec')
exec(y)

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
print(fancy_func(1, 2, 3, 4))
10
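The payoff of the symbolic style is that compilation happens once and execution can be repeated cheaply. This sketch uses the same source-string approach as add_() and fancy_func_() above. It compiles the program text a single time, executes it into a namespace dictionary (the names src and namespace are ours), and then calls the resulting function for many inputs without recompiling.

```python
# Program text, in the spirit of add_() + fancy_func_() above.
src = '''
def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g
'''

code = compile(src, '<symbolic>', 'exec')  # compile step: runs once
namespace = {}
exec(code, namespace)                      # defines add and fancy_func in namespace
fancy_func = namespace['fancy_func']

# Execution step: reuse the compiled function for many inputs.
results = [fancy_func(i, i + 1, i + 2, i + 3) for i in range(5)]
print(results)  # [6, 10, 14, 18, 22]
```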

Hybridizing a Sequential model

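Build an MLP as a regular Flax module, then opt into graph mode. In JAX this means wrapping the module's apply function with jax.jit. The counterparts elsewhere are torch.jit.script in PyTorch, hybridize() on a HybridSequential in MXNet, and @tf.function in TensorFlow: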

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import numpy as np

# A three-layer MLP defined as a Flax module
class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(2)(x)
        return x

net = Net()
x = jnp.ones((1, 512))
params = net.init(jax.random.PRNGKey(0), x)
net.apply(params, x)
Array([[ 0.5303471, -0.5135098]], dtype=float32)
jitted_apply = jax.jit(net.apply)
jitted_apply(params, x)
Array([[ 0.5303472, -0.5135098]], dtype=float32)
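The two outputs above agree except in the last digit. Compiled code may order or fuse floating-point operations differently, so tiny differences like this are expected. jax.jit traces the function on its first call for a given combination of input shapes and dtypes, caches the compiled result, and reuses it for later calls with matching inputs; a new shape triggers a retrace. The caching behavior can be mimicked in plain Python. The toy_jit decorator below is hypothetical: it keys its cache on each argument's type name and length as a stand-in for shape and dtype, and its "compilation" only counts traces.

```python
import functools

def toy_jit(fn):
    """Hypothetical stand-in for a JIT: 'compile' once per input signature."""
    cache = {}
    stats = {'traces': 0}

    @functools.wraps(fn)
    def wrapper(*args):
        # Signature = (type name, length) of each argument, mimicking the
        # shape/dtype key a real JIT would use.
        key = tuple((type(a).__name__, len(a)) for a in args)
        if key not in cache:
            stats['traces'] += 1   # a real JIT would trace and compile here
            cache[key] = fn
        return cache[key](*args)

    wrapper.stats = stats
    return wrapper

@toy_jit
def pairwise_sum(xs, ys):
    return [x + y for x, y in zip(xs, ys)]

pairwise_sum([1, 2], [3, 4])         # first length-2 call: trace
pairwise_sum([5, 6], [7, 8])         # same signature: cache hit
pairwise_sum([1, 2, 3], [4, 5, 6])   # new "shape": retrace
print(pairwise_sum.stats['traces'])  # 2
```

This is why the benchmark below calls the jitted function once before timing: the first call pays for tracing and compilation, and later calls with the same shapes reuse the cached result.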

Speedup

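Wall-clock benchmark, eager vs. jit-compiled. The jitted function is called once before timing, because the first call traces and compiles it. The exact ratio depends on model size and op count; in the run shown here it is roughly 40×: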

class Benchmark:
    """For measuring running time."""
    def __init__(self, description='Done'):
        self.description = description

    def __enter__(self):
        self.timer = d2l.Timer()
        return self

    def __exit__(self, *args):
        print(f'{self.description}: {self.timer.stop():.4f} sec')

with Benchmark('Without jax.jit'):
    for i in range(1000): net.apply(params, x)

jitted_apply = jax.jit(net.apply)
jitted_apply(params, x)  # Warm-up (triggers compilation)
with Benchmark('With jax.jit'):
    for i in range(1000): jitted_apply(params, x).block_until_ready()
Without jax.jit: 6.4955 sec
With jax.jit: 0.1588 sec
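d2l.Timer comes from the d2l package. Without that package, the same context manager can be sketched with time.perf_counter from the standard library. The class name PerfBenchmark and the sample loop are ours.

```python
import time

class PerfBenchmark:
    """Context manager that prints the wall-clock time of its body."""
    def __init__(self, description='Done'):
        self.description = description

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self.start
        print(f'{self.description}: {self.elapsed:.4f} sec')

with PerfBenchmark('1000 Python additions') as b:
    total = 0
    for i in range(1000):
        total += i
```

JAX dispatches work asynchronously, so as in the jitted loop above, results should be forced with block_until_ready() inside the timed block. Otherwise the timer can stop before the computation has finished.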

Serialization

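A graph is portable. Saved once, it can be loaded and run from C++, mobile, or another language without Python in the loop; TorchScript and MXNet's exported symbol files work this way. Production deployments commonly ship this graph form. In JAX the pieces are separate: flax.serialization saves only the parameters, jax.make_jaxpr shows the traced computation, and exporting that computation for another runtime takes additional tooling such as jax.export: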

from flax import serialization
param_bytes = serialization.to_bytes(params)
print(f'Serialized parameter size: {len(param_bytes)} bytes')
# We can also inspect the computation graph via jaxpr
print(jax.make_jaxpr(net.apply)(params, x))
Serialized parameter size: 658128 bytes
{ lambda ; a:f32[256] b:f32[512,256] c:f32[128] d:f32[256,128] e:f32[2] f:f32[128,2]
    g:f32[1,512]. let
    h:f32[1,256] = dot_general[dimension_numbers=(([1], [0]), ([], []))] g b
    i:f32[1,256] = reshape[dimensions=None new_sizes=(1, 256) sharding=None] a
    j:f32[1,256] = add h i
...
      symbolic_zeros=False
    ] p
    t:f32[1,2] = dot_general[dimension_numbers=(([1], [0]), ([], []))] q f
    u:f32[1,2] = reshape[dimensions=None new_sizes=(1, 2) sharding=None] e
    v:f32[1,2] = add t u
  in (v,) }
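A graph is portable because it is data, not Python code. The following toy sketch stores the fancy_func graph in a made-up JSON format: a list of [op, args] entries, not a real exchange format such as ONNX or StableHLO. It then evaluates the graph with a small interpreter. Any runtime that understands the format could load and run it without the original Python function.

```python
import json

# Hypothetical graph format: each node is [op, args]. For 'input', args holds
# the argument position; for 'add', args holds indices of earlier nodes.
graph = [
    ['input', [0]], ['input', [1]], ['input', [2]], ['input', [3]],
    ['add', [0, 1]],   # e = a + b
    ['add', [2, 3]],   # f = c + d
    ['add', [4, 5]],   # g = e + f
]

blob = json.dumps(graph)  # "save": a plain string that could be written to disk

def run_serialized(blob, *values):
    """'Load' the graph and evaluate it; the last node is the output."""
    results = []
    for op, args in json.loads(blob):
        if op == 'input':
            results.append(values[args[0]])
        elif op == 'add':
            results.append(results[args[0]] + results[args[1]])
        else:
            raise ValueError(f'unknown op {op!r}')
    return results[-1]

print(run_serialized(blob, 1, 2, 3, 4))  # 10
```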

Inspecting the graph

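The traced program can be inspected or handed to further optimization passes. In JAX, jax.make_jaxpr prints the jaxpr shown in the output above, and jitted_apply.lower(params, x).as_text() prints the lowered module that XLA then compiles.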
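Plain Python has a rough analogue. Every function is compiled to a code object, and the standard dis module lists its instructions. This is bytecode for Python's interpreter, not a dataflow graph, so treat it only as an illustration of inspecting what was compiled. Opcode names differ between Python versions, so the checks below rely only on the code object's name tables.

```python
import dis

def add(a, b):
    return a + b

def fancy_func(a, b, c, d):
    e = add(a, b)
    f = add(c, d)
    g = add(e, f)
    return g

code = fancy_func.__code__
print(code.co_names)     # ('add',): global names the function refers to
print(code.co_varnames)  # ('a', 'b', 'c', 'd', 'e', 'f', 'g'): arguments, then locals
dis.dis(fancy_func)      # full bytecode listing (version-dependent)
```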

Recap

  • Eager Python is great for development; graph form is faster for production.
  • Hybridization = trace or script imperative code into a static graph, then JIT-compile.
  • Wins: kernel fusion, no per-op Python overhead, deployable to C++ / mobile.
  • Costs: control flow that depends on tensor values is harder to capture; debugging is less interactive.
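The last point can be seen with a toy tracer. During tracing, a value is symbolic and has no concrete number, so a Python if cannot decide which branch to take. The Tracer class below is hypothetical: it records + and > symbolically and refuses conversion to bool. JAX raises a similar concretization error for Python if statements on traced values, and offers jax.lax.cond and jnp.where for value-dependent choices.

```python
class Tracer:
    """Hypothetical symbolic value: records ops but holds no concrete number."""
    def __init__(self, name):
        self.name = name

    def __add__(self, other):
        other_name = other.name if isinstance(other, Tracer) else repr(other)
        return Tracer(f'({self.name} + {other_name})')

    def __gt__(self, other):
        # Comparisons stay symbolic...
        return Tracer(f'({self.name} > {other!r})')

    def __bool__(self):
        # ...but Python's `if` needs a real True/False, which a tracer lacks.
        raise TypeError(f'cannot branch on symbolic value {self.name}')


def shift(x):
    return x + 1          # no branching: traces fine

def clip_at_zero(x):
    if x > 0:             # branch depends on the value of x
        return x
    return 0

print(shift(Tracer('x')).name)  # (x + 1)

try:
    clip_at_zero(Tracer('x'))
    message = None
except TypeError as err:
    message = str(err)
print(message)  # cannot branch on symbolic value (x > 0)
```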