Automatic Differentiation

Hand-deriving gradients for a 100-million-parameter network is a non-starter. Every modern deep learning framework ships an automatic differentiation engine that:

  • Records each operation onto a computational graph.
  • Walks the graph in reverse to apply the chain rule.
  • Returns the gradient with respect to every input you asked about — typically the model parameters.

This chapter teaches the API; the rest of the book leans on it.

A worked example

We’ll differentiate

y = 2\,\mathbf{x}^\top \mathbf{x}

with respect to the column vector \mathbf{x}. The analytic gradient is \nabla_\mathbf{x} y = 4\mathbf{x} — a useful sanity-check target.
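For readers who want the missing step, the analytic gradient can be derived componentwise. This is a short sketch using only the definition of y above; the index symbols i and j are introduced here for illustration:

```latex
y = 2\,\mathbf{x}^\top \mathbf{x} = 2 \sum_{i} x_i^2
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x_j} = 4 x_j
\quad\Longrightarrow\quad
\nabla_\mathbf{x} y = 4\mathbf{x}
```

For \mathbf{x} = [0, 1, 2, 3]^\top this gives [0, 4, 8, 12]^\top.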

import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])

Tracking gradients

We tell the framework to track operations on x and reserve a slot for its gradient:

# Can also create x = torch.arange(4.0, requires_grad=True)
x.requires_grad_(True)
x.grad  # The gradient is None by default

Then run the forward pass — y is built from x, so the engine records the dependency:

y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)
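The recorded dependency can be inspected directly. The following is a minimal sketch, assuming PyTorch's standard tensor attributes is_leaf, requires_grad, and grad_fn; it rebuilds x and y so that it runs on its own:

```python
import torch

x = torch.arange(4.0, requires_grad=True)
y = 2 * torch.dot(x, x)

# x was created directly by the user, so it is a leaf with no backward node.
print(x.is_leaf, x.grad_fn)        # True None

# y was produced by an operation on x, so it carries a backward node.
print(y.requires_grad, y.is_leaf)  # True False
print(type(y.grad_fn).__name__)    # MulBackward0
```

The grad_fn of y is the entry point that backward() starts from when it walks the graph in reverse.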

Backward pass

A single call walks the recorded graph backwards:

y.backward()
x.grad
tensor([ 0.,  4.,  8., 12.])

The result lands in x.grad. Compare with the analytic answer, 4\mathbf{x}:

x.grad == 4 * x
tensor([True, True, True, True])

Resetting & re-using

Gradients accumulate by default: each backward() call adds to x.grad instead of overwriting it. Call x.grad.zero_() (or the framework's equivalent) before computing a fresh gradient:

x.grad.zero_()  # Reset the gradient
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])
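To see what accumulation means in practice, here is a small self-contained sketch. The names first, second, and third are illustrative only; they hold snapshots of x.grad after each backward call:

```python
import torch

x = torch.arange(4.0, requires_grad=True)

x.sum().backward()
first = x.grad.clone()    # gradient of sum(x) is all ones

x.sum().backward()        # no reset: the new gradient is added to the old one
second = x.grad.clone()

x.grad.zero_()            # reset in place, then recompute
x.sum().backward()
third = x.grad.clone()

print(first)   # tensor([1., 1., 1., 1.])
print(second)  # tensor([2., 2., 2., 2.])
print(third)   # tensor([1., 1., 1., 1.])
```

Forgetting the reset silently doubles the gradient here, which is why the reset is usually done once per iteration.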

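For non-scalar y, PyTorch's backward() requires a gradient argument: a tensor of weights with the same shape as y. The call then returns the weighted sum of the per-element gradients (a vector-Jacobian product). Passing all ones is equivalent to differentiating y.sum():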

x.grad.zero_()
y = x * x
y.backward(gradient=torch.ones(len(y)))  # Faster: y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])
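The role of the gradient argument is easier to see with weights other than ones. The sketch below is an illustration, not part of the original example: the weight vector v is an arbitrary choice, and the first half shows that PyTorch raises a RuntimeError when backward() is called on a vector without weights.

```python
import torch

x = torch.arange(4.0, requires_grad=True)

# Without a `gradient` argument, backward() on a non-scalar raises an error.
y = x * x
try:
    y.backward()
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# With weights v, backward() computes the vector-Jacobian product v^T J.
# For y = x * x the Jacobian is diag(2x), so the result is v * 2x.
v = torch.tensor([1.0, 0.0, 0.5, 2.0])  # arbitrary illustrative weights
y = x * x
y.backward(gradient=v)
print(x.grad)  # tensor([ 0.,  0.,  2., 12.])
```

Setting v to all ones recovers the gradient of y.sum().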

Detaching from the graph

Sometimes we want a value treated as a constant in the backward pass. Below, u takes the value of y but is cut off from x, so the gradient of z = u * x with respect to x is u itself, not 3 * x * x:

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x

z.sum().backward()
x.grad == u
tensor([True, True, True, True])

detach() (stop_gradient or lax.stop_gradient in other frameworks) returns a new tensor that is cut off from the graph. y itself still carries its recorded history, so we can still backpropagate through it to x:

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])
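A side-by-side comparison makes the effect of detaching concrete. This is a minimal sketch with illustrative variable names (grad_detached, grad_full); it computes the gradient of the same product once with the detached factor and once without:

```python
import torch

x = torch.arange(4.0, requires_grad=True)

# With detach: u is a constant, so d(u * x)/dx = u = x^2.
u = (x * x).detach()
(u * x).sum().backward()
grad_detached = x.grad.clone()

# Without detach: z = x^3, so dz/dx = 3x^2.
x.grad.zero_()
(x * x * x).sum().backward()
grad_full = x.grad.clone()

print(u.requires_grad)  # False
print(grad_detached)    # tensor([0., 1., 4., 9.])
print(grad_full)        # tensor([ 0.,  3., 12., 27.])
```

The forward values are identical in both cases; only the backward pass differs.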

Gradients through control flow

Autograd doesn’t care about Python ifs and whiles — it records whichever ops actually executed. Here’s a function whose behavior depends on its input:

def f(a):
    b = a * 2
    while b.norm() < 1000:
        b = b * 2
    if b.sum() > 0:
        c = b
    else:
        c = 100 * b
    return c

The number of while iterations and the branch taken both depend on the value of a.

…it just works

Run the function on a random scalar and ask for the gradient:

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

The gradient is correct even though the path through the function is data-dependent. Here f(a) ends up linear in a along whichever branch ran, so f'(a) = f(a) / a:

a.grad == d / a
tensor(True)
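To exercise both branches and different loop counts deterministically, the check can be repeated on a few fixed inputs. The sketch below restates f compactly and uses hypothetical test values chosen for illustration; it compares with torch.allclose rather than == to tolerate floating-point rounding in d / a. An input of exactly 0 is avoided because the while loop would never terminate.

```python
import torch

def f(a):
    b = a * 2
    while b.norm() < 1000:
        b = b * 2
    return b if b.sum() > 0 else 100 * b

slopes = []
for value in [0.3, -0.7, 2.5, -1500.0]:  # hypothetical test inputs
    a = torch.tensor(value, requires_grad=True)
    d = f(a)
    d.backward()
    # f is linear in a along the executed path, so f'(a) = f(a) / a.
    assert torch.allclose(a.grad, (d / a).detach())
    slopes.append(a.grad.item())

print(slopes)  # [4096.0, 204800.0, 512.0, 200.0]
```

Each slope is a power of two, times 100 when the negative branch ran, which reflects how many doublings the loop performed for that input.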

Recap

  • Mark inputs as needing gradients.
  • Run the forward pass — the engine records ops.
  • backward() (or grad()) walks the graph in reverse via the chain rule.
  • Gradients accumulate; reset between iterations.
  • Use detach / stop_gradient to treat a value as a constant in the backward pass.
  • Works through arbitrary Python control flow.