Stack L layers and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L — keep it stable by picking \sigma_w^2 \approx 1/n_{\text{in}}.
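The scaling claim can be checked numerically. The following is a minimal sketch, assuming purely linear layers of equal width and unit-variance Gaussian inputs; the helper name output_variance and the chosen width and depth are illustrative, not from the original notes.

```python
import torch

torch.manual_seed(0)

def output_variance(n_in: int, sigma_w2: float, depth: int, batch: int = 1024) -> float:
    """Push unit-variance inputs through `depth` linear layers; return the output variance."""
    x = torch.randn(batch, n_in)
    for _ in range(depth):
        # Weights drawn i.i.d. from N(0, sigma_w2); no bias, no nonlinearity.
        w = torch.randn(n_in, n_in) * sigma_w2 ** 0.5
        x = x @ w
    return x.var().item()

n_in, depth = 256, 10
for scale in (0.5, 1.0, 2.0):
    # scale = n_in * sigma_w^2, so the variance should grow roughly like scale ** depth.
    var = output_variance(n_in, scale / n_in, depth)
    print(f"n_in * sigma_w^2 = {scale}: output variance ~ {var:.3g}")
```

Only the middle setting, where n_in * sigma_w^2 = 1, should keep the variance near 1; the other two shrink or grow geometrically with depth.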
Xavier and Kaiming
Xavier (Glorot & Bengio, 2010): \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. Balances forward signal variance against backward gradient variance. Designed for \tanh / sigmoid.
Kaiming/He (2015): \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. Compensates for ReLU zeroing roughly half of the activations. The standard choice for ReLU networks such as modern CNNs.
Bias usually starts at 0.
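Both formulas can be compared against PyTorch's built-in initializers. This is a small sketch, assuming the normal-distribution variants nn.init.xavier_normal_ and nn.init.kaiming_normal_; the layer sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_in, n_out = 512, 256
layer = nn.Linear(n_in, n_out)

# Xavier/Glorot: target variance 2 / (n_in + n_out).
nn.init.xavier_normal_(layer.weight)
xavier_target = 2.0 / (n_in + n_out)
xavier_var = layer.weight.var().item()

# Kaiming/He for ReLU: target variance 2 / n_in.
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
kaiming_target = 2.0 / n_in
kaiming_var = layer.weight.var().item()

# Bias starts at zero.
nn.init.zeros_(layer.bias)

print(f"Xavier : empirical {xavier_var:.5f} vs target {xavier_target:.5f}")
print(f"Kaiming: empirical {kaiming_var:.5f} vs target {kaiming_target:.5f}")
```

The empirical variances should sit close to the targets, since the weight matrix holds more than 100,000 samples.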
The framework defaults
Most frameworks pick one of these by default:
Framework | Default for Linear/Dense
PyTorch | Kaiming-uniform on weight; uniform \pm 1/\sqrt{\text{fan-in}} on bias
Flax (JAX) | LeCun-normal (~Kaiming for \tanh)
Keras (TF) | Glorot-uniform
MXNet | Uniform \pm 0.07 (legacy; you should override)
Bottom line: nearly every modern framework picks something fan-in/fan-out aware, with MXNet's legacy default the exception. You can usually leave the default alone. Override it when you need a non-standard scheme.
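To see what the PyTorch default looks like in practice, the sketch below inspects a freshly built nn.Linear and then overrides it. It assumes the default behaviour of the PyTorch releases this editor is aware of, where Kaiming-uniform with a = \sqrt{5} also works out to a \pm 1/\sqrt{\text{fan-in}} bound on the weight; treat that bound as an assumption to verify against your installed version.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
fan_in, fan_out = 64, 32
layer = nn.Linear(fan_in, fan_out)

# Assumed default: weight and bias both lie within +/- 1/sqrt(fan_in).
bound = 1.0 / math.sqrt(fan_in)
weight_ok = layer.weight.abs().max().item() <= bound + 1e-6
bias_ok = layer.bias.abs().max().item() <= bound + 1e-6
print("weight within bound:", weight_ok)
print("bias within bound:  ", bias_ok)

# Override: switch this layer to Glorot-uniform weights and a zero bias.
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
glorot_bound = math.sqrt(6.0 / (fan_in + fan_out))
```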
Setup
import torch
from torch import nn
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
The universal pattern: net.apply(fn)
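Override the default by walking the module tree and applying an initializer to each module it should affect. In PyTorch, net.apply(fn) calls fn(module) recursively on every submodule and then on net itself, so fn typically checks the module type and skips anything it does not handle.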
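A minimal sketch of this pattern on the network from the Setup section follows. The function name init_kaiming is illustrative. The forward pass comes first because lazy layers have no weights to initialize until they have seen an input.

```python
import torch
from torch import nn

net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X)  # one forward pass so the lazy layers infer their input sizes

def init_kaiming(module: nn.Module) -> None:
    """Re-initialize Linear layers; leave every other module untouched."""
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

# apply() visits net[0], net[1], net[2], and finally net itself.
net.apply(init_kaiming)
print(net[0].weight.shape, net[0].bias)
```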
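The pattern: fn receives a module and decides what to do with it. When the decision depends on where the module sits in the tree, iterate over net.named_modules() or net.named_parameters() instead, which yield (name, module) or (name, parameter) tuples. The same machinery is used for freezing layers (requires_grad = False), discriminative learning rates, and BERT-style “warm up the head, not the backbone”.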
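As one illustration of the name-based variant, the sketch below freezes a backbone and leaves the head trainable. The two-part model and the names backbone and head are hypothetical, chosen only to make the parameter names readable.

```python
from collections import OrderedDict

import torch
from torch import nn

model = nn.Sequential(OrderedDict([
    ("backbone", nn.Sequential(nn.Linear(4, 8), nn.ReLU())),
    ("head", nn.Linear(8, 1)),
]))

# Walk (name, parameter) pairs and freeze everything under "backbone.".
for name, param in model.named_parameters():
    if name.startswith("backbone."):
        param.requires_grad = False

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)
```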
Custom initialization
For non-standard schemes, write the init function yourself. Here, a wide uniform sample with thresholding, where every weight smaller than 5 in magnitude is set to zero:
w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.
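The rule above can be sketched as an init function passed to net.apply. The function name my_init and the fixed layer sizes are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

def my_init(module: nn.Module) -> None:
    """Sample w ~ U(-10, 10), then zero every entry with |w| < 5."""
    if isinstance(module, nn.Linear):
        nn.init.uniform_(module.weight, -10, 10)
        with torch.no_grad():
            # Indicator mask 1_{|w| >= 5}, cast to the weight dtype.
            mask = (module.weight.abs() >= 5).to(module.weight.dtype)
            module.weight.mul_(mask)

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
net.apply(my_init)
print(net[0].weight[:2])
```

On average about half of the entries end up exactly zero, and the rest lie in [-10, -5] or [5, 10].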