%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from jax import grad, vmap
Why was deep learning hard before 2012? Nobody could train deep networks reliably: gradients either died at zero or blew up to infinity.
Three ingredients fixed it:
This deck makes the failure modes concrete.
For an L-layer network with hidden states \mathbf{h}^{(1)}, \mathbf{h}^{(2)}, \ldots, the gradient of the loss with respect to a weight in layer \ell is
\frac{\partial \mathcal{L}}{\partial \mathbf{W}^{(\ell)}} = \underbrace{\frac{\partial \mathcal{L}}{\partial \mathbf{h}^{(L)}}}_{\text{loss}} \cdot \underbrace{\frac{\partial \mathbf{h}^{(L)}}{\partial \mathbf{h}^{(L-1)}}}_{\mathbf{M}_L} \cdots \underbrace{\frac{\partial \mathbf{h}^{(\ell+1)}}{\partial \mathbf{h}^{(\ell)}}}_{\mathbf{M}_{\ell+1}} \cdot \frac{\partial \mathbf{h}^{(\ell)}}{\partial \mathbf{W}^{(\ell)}}.
It’s a product of L - \ell Jacobian matrices. Two ways this product can misbehave: it can shrink toward zero (vanishing gradients) or grow without bound (exploding gradients).
The sigmoid’s derivative peaks at \sigma'(0) = 0.25 and collapses to zero at the tails. In a 10-layer stack the activation derivatives alone contribute a factor of at most 0.25^{10} \approx 10^{-6}, so gradients at layer 1 can be a millionth of those near the output. ReLU fixes this: its derivative is exactly 1 wherever the unit is active.
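The numbers quoted for the sigmoid and ReLU derivatives can be checked with automatic differentiation. This is a minimal sketch; the sample points in `x` are arbitrary choices for illustration, not values from the deck.

```python
import jax
from jax import numpy as jnp
from jax import grad, vmap

# Derivatives of the scalar activations, vectorized over a batch of inputs.
sigmoid_grad = vmap(grad(jax.nn.sigmoid))
relu_grad = vmap(grad(jax.nn.relu))

x = jnp.array([-8.0, -2.0, 0.0, 2.0, 8.0])  # illustrative sample points
sig_g = sigmoid_grad(x)
relu_g = relu_grad(x)
print("sigmoid'(x):", sig_g)
print("relu'(x):   ", relu_g)

# Best case for a 10-layer sigmoid stack: every unit sits at the peak x = 0.
best_case = sig_g[2] ** 10
print("0.25 ** 10 =", best_case)
```

The sigmoid derivative is largest at 0 and tiny at the tails, while the ReLU derivative is 1 on the active side and 0 on the inactive side.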
Multiply 100 random 4\times4 Gaussian matrices and watch the entries:
a single matrix
[[-1.2574776 -0.4016044 -1.1213601 0.87837774]
[-0.86175495 0.34651348 0.9404431 -0.12234341]
[-1.1891836 0.07152013 -1.7588533 -1.0268484 ]
[ 1.0308878 -0.00312373 0.01677057 0.04222 ]]
after multiplying 100 matrices
[[-1.3234898e+23 2.3477504e+22 -1.3833368e+23 -1.0161722e+23]
[-2.1248584e+22 3.7692705e+21 -2.2209315e+22 -1.6314578e+22]
[-1.6165937e+23 2.8676842e+22 -1.6896933e+23 -1.2412165e+23]
[ 7.9361379e+22 -1.4077955e+22 8.2949937e+22 6.0933451e+22]]
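A Gaussian matrix with unit-variance entries typically stretches the vectors it acts on, so the entries of the product grow exponentially with the number of factors (here to about 10^{23}). The same effect hits gradients in a deep net with poorly scaled weights: the loss can go to NaN within a few hundred steps.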
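The experiment above can be reproduced along the following lines. This is a sketch, not the deck's original cell: the seed `PRNGKey(0)` is an assumption, so the printed entries will differ from the ones shown, although their magnitude should be similarly large.

```python
import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(0)  # assumed seed; the original run's seed is unknown
key, subkey = jax.random.split(key)
M = jax.random.normal(subkey, (4, 4))
print("a single matrix\n", M)

# Right-multiply by 100 fresh 4x4 standard Gaussian matrices.
for _ in range(100):
    key, subkey = jax.random.split(key)
    M = M @ jax.random.normal(subkey, (4, 4))
print("after multiplying 100 matrices\n", M)
```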
Forward pass through a linear layer with n_{\text{in}} inputs:
o_i = \sum_{j=1}^{n_{\text{in}}} w_{ij}\, x_j.
If w_{ij} \sim \mathcal{N}(0, \sigma^2) and the inputs x_j are i.i.d. with zero mean and variance \gamma^2, independent of the weights:
\mathbb{E}[o_i] = 0,\quad \mathrm{Var}[o_i] = n_{\text{in}}\, \sigma^2\, \gamma^2.
For variance to be preserved layer-to-layer (\mathrm{Var}[o] = \gamma^2):
\boxed{\sigma^2 = \frac{1}{n_{\text{in}}}}.
The same argument for the backward pass gives \sigma^2 = 1/n_{\text{out}}. Both conditions hold only when n_{\text{in}} = n_{\text{out}}, so Xavier averages the two fan sizes.
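The forward-pass formula \mathrm{Var}[o_i] = n_{\text{in}}\, \sigma^2\, \gamma^2 is easy to check empirically. The sketch below assumes illustrative layer sizes and an input standard deviation of 1.5; the helper `output_variance` is hypothetical and exists only for this check.

```python
import jax
from jax import numpy as jnp

n_in, n_out, batch = 256, 512, 4096  # illustrative sizes (assumed)
gamma = 1.5                          # input standard deviation (assumed)
kx, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = gamma * jax.random.normal(kx, (batch, n_in))  # zero-mean inputs

def output_variance(key, sigma):
    # W has i.i.d. N(0, sigma^2) entries; o = x W^T has shape (batch, n_out).
    W = sigma * jax.random.normal(key, (n_out, n_in))
    return jnp.var(x @ W.T)

var_unit = output_variance(k1, 1.0)             # predicted: n_in * gamma^2
var_scaled = output_variance(k2, n_in ** -0.5)  # predicted: gamma^2

print("sigma = 1:            ", var_unit, "vs predicted", n_in * gamma ** 2)
print("sigma^2 = 1 / n_in:   ", var_scaled, "vs predicted", gamma ** 2)
```

With unit-variance weights the output variance is inflated by a factor of n_in; scaling the weights by 1/\sqrt{n_{\text{in}}} brings it back to the input variance.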
Xavier / Glorot (2010):
\sigma^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}}.
A compromise that keeps variance approximately preserved both forward and backward. Designed for \tanh / sigmoid.
Kaiming / He (2015):
\sigma^2 = \frac{2}{n_{\text{in}}}.
Same idea, but the extra factor of 2 compensates for ReLU zeroing half of the pre-activations, which halves the second moment passed to the next layer. A common choice for ReLU networks such as modern CNNs.
Both ship as built-in initializers in the major frameworks. Biases start at 0.
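As a rough illustration, both schemes can be written by hand in a few lines. The helpers `xavier_normal`, `he_normal`, and `relu_stack_mean_square` below are hypothetical names defined here, and the width, depth, and batch size are assumed values. JAX also provides initializers under `jax.nn.initializers` (for example `glorot_normal` and `he_normal`), which are not used in this sketch.

```python
import jax
from jax import numpy as jnp

def xavier_normal(key, n_in, n_out):
    """Gaussian weights with variance 2 / (n_in + n_out)."""
    return jnp.sqrt(2.0 / (n_in + n_out)) * jax.random.normal(key, (n_in, n_out))

def he_normal(key, n_in, n_out):
    """Gaussian weights with variance 2 / n_in."""
    return jnp.sqrt(2.0 / n_in) * jax.random.normal(key, (n_in, n_out))

def relu_stack_mean_square(init_fn, key, width=512, depth=10, batch=256):
    """Mean squared activation after `depth` bias-free ReLU layers."""
    key, kx = jax.random.split(key)
    h = jax.random.normal(kx, (batch, width))  # unit-variance inputs
    for _ in range(depth):
        key, kw = jax.random.split(key)
        h = jax.nn.relu(h @ init_fn(kw, width, width))
    return jnp.mean(h ** 2)

key = jax.random.PRNGKey(0)  # assumed seed
W_xavier = xavier_normal(key, 300, 500)
W_he = he_normal(key, 300, 500)
print("Xavier weight variance:", jnp.var(W_xavier), "target", 2 / 800)
print("He weight variance:    ", jnp.var(W_he), "target", 2 / 300)

ms_he = relu_stack_mean_square(he_normal, key)
ms_xavier = relu_stack_mean_square(xavier_normal, key)
print("mean square after 10 ReLU layers, He:    ", ms_he)
print("mean square after 10 ReLU layers, Xavier:", ms_xavier)
```

In a square ReLU stack the Xavier variance is 1/n_in, so the signal roughly halves at every layer, while the He variance keeps it near its starting scale.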
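Set every weight to the same constant c: every hidden unit then computes the same value and receives the same gradient, so the units stay identical throughout training.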
Initialize randomly: even tiny noise breaks the permutation symmetry between hidden units. (SGD alone doesn’t.)
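The symmetry argument can be seen directly in the gradients of a small network. The sketch below assumes a one-hidden-layer tanh MLP with 4 hidden units, random regression data, a constant c = 0.5, and plain full-batch gradient descent; all of these are illustrative choices, not the deck's setup.

```python
import jax
from jax import numpy as jnp
from jax import grad

def mlp(params, x):
    W1, W2 = params
    h = jnp.tanh(x @ W1)  # hidden layer, shape (batch, n_hidden)
    return h @ W2         # output, shape (batch, 1)

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

kx, ky, kn = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (32, 3))
y = jax.random.normal(ky, (32, 1))

# Constant initialization: each column of W1 belongs to one hidden unit.
c = 0.5
W1, W2 = jnp.full((3, 4), c), jnp.full((4, 1), c)
g_W1, _ = grad(loss)((W1, W2), x, y)
same_grads = bool(jnp.allclose(g_W1, g_W1[:, :1]))

# Ten gradient-descent steps do not separate the hidden units.
for _ in range(10):
    g1, g2 = grad(loss)((W1, W2), x, y)
    W1, W2 = W1 - 0.1 * g1, W2 - 0.1 * g2
still_symmetric = bool(jnp.allclose(W1, W1[:, :1]))

# Tiny random noise on the first layer gives each unit its own gradient.
W1_noisy = jnp.full((3, 4), c) + 0.01 * jax.random.normal(kn, (3, 4))
g_noisy, _ = grad(loss)((W1_noisy, jnp.full((4, 1), c)), x, y)
noise_breaks_symmetry = not bool(jnp.allclose(g_noisy, g_noisy[:, :1]))

print(same_grads, still_symmetric, noise_breaks_symmetry)
```

With constant weights the four hidden units behave as a single unit throughout training; the noisy start gives every unit a different gradient from the first step.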
Init alone gets you “trains a 10-layer net without NaN”. Modern best practice for hundreds of layers stacks more on top:
The chapter on Modern CNNs revisits these.