Stack L layers and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L — keep it stable by picking \sigma_w^2 \approx 1/n_{\text{in}}.
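A minimal sketch of that scaling rule, assuming purely linear layers whose weights are zero-mean, i.i.d., and independent of the inputs (no nonlinearity). The helper name `variance_gain` is made up for this illustration.

```python
def variance_gain(n_in, sigma_w2, depth):
    """Factor multiplying the signal variance after `depth` linear layers.

    Each layer multiplies the variance by n_in * sigma_w2, so the
    factors compound geometrically with depth.
    """
    return (n_in * sigma_w2) ** depth


n_in, depth = 256, 50
# Slightly too small, exactly 1/n_in, and slightly too large.
for sigma_w2 in (0.5 / n_in, 1.0 / n_in, 2.0 / n_in):
    gain = variance_gain(n_in, sigma_w2, depth)
    print(f"sigma_w^2 = {sigma_w2:.5f} -> variance gain {gain:.3e}")
```

Only the middle choice keeps the gain at 1. A factor of 2 per layer in either direction is enough to vanish or explode over 50 layers.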
Xavier and Kaiming
Xavier (Glorot 2010) — \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. Balances forward variance with backward gradient variance. Designed for \tanh / sigmoid.
Kaiming/He (2015) — \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. Compensates for ReLU killing half the signal. Default for modern CNNs / Transformers.
Bias usually starts at 0.
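The two formulas above can be written out as a small sketch. The function names are hypothetical helpers, not a framework API. The uniform bound uses the fact that U(-a, a) has variance a^2/3.

```python
import math


def xavier_variance(n_in, n_out):
    """Xavier/Glorot: sigma_w^2 = 2 / (n_in + n_out)."""
    return 2.0 / (n_in + n_out)


def kaiming_variance(n_in):
    """Kaiming/He: sigma_w^2 = 2 / n_in (the 2 offsets ReLU zeroing half the units)."""
    return 2.0 / n_in


def uniform_bound(variance):
    """Bound a such that U(-a, a) has the requested variance (a^2 / 3)."""
    return math.sqrt(3.0 * variance)


n_in, n_out = 4, 8
print("Xavier variance: ", xavier_variance(n_in, n_out))
print("Kaiming variance:", kaiming_variance(n_in))
print("Xavier-uniform bound:", uniform_bound(xavier_variance(n_in, n_out)))
```

The "-uniform" and "-normal" variants of each scheme target the same variance and differ only in the distribution they sample from.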
The framework defaults
Each framework picks one of these by default:
| Framework | Default for Linear/Dense |
| --- | --- |
| PyTorch | Kaiming-uniform on weight; uniform \pm 1/\sqrt{\text{fan-in}} on bias |
| Flax (JAX) | LeCun-normal (~Kaiming for \tanh) |
| Keras (TF) | Glorot-uniform |
| MXNet | Uniform \pm 0.07 (legacy; you should override) |
Bottom line: apart from MXNet's legacy default, every modern framework picks something fan-in/fan-out aware. You can usually leave it alone and override only when you need a non-standard scheme.
Setup
from mxnet import init, np, npx
from mxnet.gluon import nn

npx.set_np()

net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
The universal pattern: net.apply(fn)
Override the default by walking the module tree and applying an initializer to each leaf module. In PyTorch, net.apply(fn) calls fn(module) recursively for every submodule. In MXNet Gluon, passing an initializer to net.initialize does the same job:
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
Constants are an anti-pattern (they kill symmetry breaking), but they illustrate the API: pass init.Constant(1) to net.initialize in place of init.Normal(sigma=0.01).
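Why constants kill symmetry breaking can be seen in a framework-free sketch. It assumes a toy one-hidden-layer ReLU network with squared loss, trained by plain SGD on a single example. The function `train_step` and all numbers are made up for illustration. Because every hidden unit starts identical, every unit receives the same gradient and they never diverge.

```python
def train_step(W, v, x, target, lr=0.1):
    """One SGD step for y = sum_j v[j] * relu(W[j] . x) with loss 0.5 * (y - target)^2."""
    pre = [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in W]
    h = [max(0.0, p) for p in pre]
    y = sum(v_j * h_j for v_j, h_j in zip(v, h))
    err = y - target
    # dL/dW[j][i] = err * v[j] * relu'(pre[j]) * x[i]
    new_W = [
        [w_i - lr * err * v_j * (1.0 if p > 0 else 0.0) * x_i
         for w_i, x_i in zip(row, x)]
        for row, v_j, p in zip(W, v, pre)
    ]
    # dL/dv[j] = err * h[j]
    new_v = [v_j - lr * err * h_j for v_j, h_j in zip(v, h)]
    return new_W, new_v


# Constant initialization: three hidden units, all weights 0.5.
W = [[0.5, 0.5] for _ in range(3)]
v = [0.5, 0.5, 0.5]
x = [1.0, 2.0]

for _ in range(20):
    W, v = train_step(W, v, x, target=1.0)

rows_identical = all(row == W[0] for row in W)
print("hidden units still identical after training:", rows_identical)
```

The weights do move during training, but the three units stay copies of each other, so the layer behaves like a single unit no matter how wide it is.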
The pattern: take a (name, module) tuple and decide what to do with it. The same machinery is used for freezing layers (requires_grad = False), discriminative learning rates, and BERT-style “warm up the head, not the backbone” schedules.
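The tree walk behind both apply(fn) and the (name, module) pattern can be sketched without any framework. The `Module` class below is a hypothetical stand-in, not a real library class. Its `apply` visits children before the module itself, and `named_modules` yields dotted names, in the spirit of the PyTorch methods of the same names.

```python
class Module:
    """Hypothetical stand-in for a framework module (illustration only)."""

    def __init__(self, **submodules):
        self.submodules = submodules  # name -> child Module
        self.weight = None            # a real leaf would hold a tensor here
        self.trainable = True

    def apply(self, fn):
        """Call fn on every descendant, then on this module."""
        for child in self.submodules.values():
            child.apply(fn)
        fn(self)
        return self

    def named_modules(self, prefix=""):
        """Yield (name, module) for this module and all descendants."""
        yield prefix, self
        for name, child in self.submodules.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


net = Module(backbone=Module(layer0=Module(), layer1=Module()), head=Module())


def init_leaf(module):
    # Only touch leaves; 0.0 is a placeholder for a real initializer.
    if not module.submodules:
        module.weight = 0.0


net.apply(init_leaf)

# Name-based decision: freeze everything under "backbone".
for name, module in net.named_modules():
    if name.startswith("backbone"):
        module.trainable = False

frozen = [name for name, module in net.named_modules() if not module.trainable]
print(frozen)
```

Deciding by type inside `fn` suits initialization. Deciding by name suits freezing or per-layer learning rates, where the position in the tree matters more than the layer type.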
Custom initialization
For non-standard schemes, write the init function yourself. Here, a wide uniform sample with thresholding, which zeroes every weight whose magnitude is below 5 (about half of them):
w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.
class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        # Sample from U(-10, 10), then zero entries with |w| < 5
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
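The same scheme can be sketched in plain Python to check what it produces. The function `threshold_init`, its parameters, and the seed are made up for this illustration. It uses nested lists in place of framework tensors.

```python
import random


def threshold_init(rows, cols, low=-10.0, high=10.0, cutoff=5.0, seed=0):
    """Sample w ~ U(low, high), then zero every entry with |w| < cutoff."""
    rng = random.Random(seed)
    weights = []
    for _ in range(rows):
        row = []
        for _ in range(cols):
            w = rng.uniform(low, high)
            row.append(w if abs(w) >= cutoff else 0.0)
        weights.append(row)
    return weights


W = threshold_init(100, 100)
flat = [w for row in W for w in row]
zero_fraction = sum(w == 0.0 for w in flat) / len(flat)
print(f"fraction of zeroed weights: {zero_fraction:.3f}")
```

With these bounds, half of the interval (-10, 10) lies below the cutoff in magnitude, so roughly half of the weights end up exactly zero and the rest fall in [5, 10] or [-10, -5].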
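For one-off surgery, such as loading specific weights or replacing a single layer’s tensor, write into the parameter's data directly (in MXNet, index into net[0].weight.data() or call set_data on the parameter):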