Parameter Initialization

Initialization Matters

Initialization isn’t cosmetic — it determines whether a deep network trains at all.

  • Zero weights → every neuron in a layer computes the same thing, gets the same gradient (“symmetry breaking” fails).
  • Too large → activations blow up.
  • Too small → activations and gradients vanish through depth.

The fix: choose the scale so signal variance stays roughly constant from layer to layer.

Why scale matters

Consider y = Wx with i.i.d. zero-mean inputs x_i of variance \sigma_x^2 and i.i.d. zero-mean weights of variance \sigma_w^2, independent of the inputs:

\text{Var}(y_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2.

Stack L such layers of equal width and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L, which explodes or vanishes unless the base is close to 1. Keep it stable by picking \sigma_w^2 \approx 1/n_{\text{in}}.
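
The scaling law above can be checked numerically. The sketch below is an illustration, not part of the original notes: the helper name final_variance and the width, depth, and batch values are arbitrary assumptions. It pushes unit-variance inputs through a stack of purely linear layers and reports the variance at the end.

```python
import torch

def final_variance(sigma_w2, n=256, depth=20, batch=1024, seed=0):
    """Empirical variance after `depth` linear layers y = W x.

    Hypothetical helper: weights are i.i.d. N(0, sigma_w2), inputs N(0, 1).
    """
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(batch, n, generator=gen)  # sigma_x^2 = 1
    for _ in range(depth):
        W = torch.randn(n, n, generator=gen) * sigma_w2 ** 0.5
        x = x @ W.T  # one layer, no nonlinearity
    return x.var().item()

n = 256
stable = final_variance(1.0 / n, n=n)     # n_in * sigma_w^2 = 1
too_big = final_variance(2.0 / n, n=n)    # n_in * sigma_w^2 = 2
too_small = final_variance(0.5 / n, n=n)  # n_in * sigma_w^2 = 0.5
print(f"stable={stable:.3g} too_big={too_big:.3g} too_small={too_small:.3g}")
```

With 20 layers, doubling or halving \sigma_w^2 changes the output variance by a factor of about 2^{20}, while \sigma_w^2 = 1/n_{\text{in}} keeps it near 1.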

Xavier and Kaiming

  • Xavier (Glorot 2010): \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. Balances forward variance with backward gradient variance. Designed for \tanh / sigmoid.
  • Kaiming/He (2015): \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. Compensates for ReLU zeroing roughly half the activations. The usual choice for ReLU networks such as modern CNNs; Transformer codebases often use other schemes (e.g. a small fixed-std normal).

Bias usually starts at 0.
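
Both formulas can be compared against what torch.nn.init produces. This is a minimal sketch; the layer sizes are arbitrary assumptions, and the empirical standard deviation only approximates the target because it is estimated from a finite sample.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
n_in, n_out = 512, 256
w = torch.empty(n_out, n_in)  # PyTorch stores Linear weights as (out, in)

# Xavier/Glorot: variance 2 / (n_in + n_out)
nn.init.xavier_normal_(w)
xavier_std = w.std().item()
xavier_target = math.sqrt(2.0 / (n_in + n_out))

# Kaiming/He with fan-in mode and ReLU gain: variance 2 / n_in
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
kaiming_std = w.std().item()
kaiming_target = math.sqrt(2.0 / n_in)

print(f"Xavier  std {xavier_std:.4f} vs target {xavier_target:.4f}")
print(f"Kaiming std {kaiming_std:.4f} vs target {kaiming_target:.4f}")
```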

The framework defaults

Each framework ships its own default:

Framework     Default for Linear/Dense
PyTorch       Kaiming-uniform on weight; uniform \pm 1/\sqrt{\text{fan-in}} on bias
Flax (JAX)    LeCun-normal (\sigma_w^2 = 1/\text{fan-in}, i.e. Kaiming without the ReLU factor of 2)
Keras (TF)    Glorot-uniform
MXNet         Uniform \pm 0.07 (legacy; you should override)

Bottom line: apart from the legacy MXNet default, every framework above picks something fan-in/fan-out aware. You can usually leave it alone; override only when you need a non-standard scheme.
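
The PyTorch default can be inspected directly. The sketch below assumes a recent PyTorch release, where nn.Linear calls kaiming_uniform_ with a = \sqrt{5}; that choice works out to weights drawn from the same \pm 1/\sqrt{\text{fan-in}} range as the bias. The layer sizes are arbitrary.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(in_features=400, out_features=300)  # default init, untouched

bound = 1.0 / math.sqrt(layer.in_features)
w_max = layer.weight.abs().max().item()
b_max = layer.bias.abs().max().item()
w_std = layer.weight.std().item()
uniform_std = bound / math.sqrt(3.0)  # std of U(-bound, bound)

print(f"bound={bound:.4f} max|w|={w_max:.4f} max|b|={b_max:.4f}")
print(f"std(w)={w_std:.4f} vs bound/sqrt(3)={uniform_std:.4f}")
```

If the printed numbers disagree on your installation, the default has changed and the framework documentation is the authority.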

Setup

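import torch
from torch import nn

# LazyLinear infers in_features from the first batch it sees.
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
# The first forward pass materializes the weights (the layers become plain nn.Linear).
net(X).shape
torch.Size([2, 1])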

The universal pattern: net.apply(fn)

Override the default by walking the module tree and applying an initializer to each module. In PyTorch, net.apply(fn) calls fn(module) on every submodule recursively and then on net itself, so fn should check the module type before touching any parameters:

def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([ 0.0070, -0.0096,  0.0008,  0.0114]), tensor(0.))
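
To see which modules net.apply actually visits, a callback can simply record the types it is given. This is a small illustrative sketch; the record function and the non-lazy layer sizes are assumptions made for the example.

```python
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

visited = []

def record(module):
    # Called once per module: children first, then the container itself.
    visited.append(type(module).__name__)

net.apply(record)
print(visited)
```

The ReLU and the Sequential container are visited too, which is why the initializers in this section guard on the module type.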

Constant weights are an anti-pattern (every unit in a layer starts identical, so symmetry is never broken), but they illustrate the API:

def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
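
Why constants fail can be shown with one backward pass. In this sketch (the squared-output loss and layer sizes are arbitrary assumptions), every hidden unit starts with the same weights, so every row of the first layer's gradient comes out the same and gradient descent can never make the units differ.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def init_constant(module):
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, 1.0)
        nn.init.zeros_(module.bias)

net.apply(init_constant)

X = torch.rand(16, 4)
loss = net(X).pow(2).mean()  # arbitrary scalar loss for the demonstration
loss.backward()

grad = net[0].weight.grad  # shape (8, 4): one row per hidden unit
rows_identical = all(torch.allclose(grad[0], grad[i]) for i in range(1, 8))
print(rows_identical)
```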

Different scheme per layer

Dispatch on layer type or layer index — Xavier for the first linear, constant 42 for the second:

def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([ 0.1768,  0.0041,  0.6029, -0.2218])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])

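The pattern: pick out a submodule (by index as here, or by name via net.named_modules(), which yields (name, module) pairs) and decide what to do with it. The same machinery is used for freezing layers (requires_grad = False), discriminative learning rates, and BERT-style “warm up the head, not the backbone” fine-tuning.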
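
As a sketch of the same selection machinery applied to freezing, the example below walks the named submodules and turns off gradients for the first layer only. The choice of layer "0" as the frozen "backbone" and the learning rate are assumptions for illustration.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# In a Sequential the submodule names are "0", "1", "2" (and "" for the root).
for name, module in net.named_modules():
    if name == "0":  # hypothetical "backbone" layer to freeze
        for param in module.parameters():
            param.requires_grad = False

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)

# Hand only the still-trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in net.parameters() if p.requires_grad], lr=0.1)
```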

Custom initialization

For non-standard schemes, write the init function yourself. Here, weights are drawn from a wide uniform distribution and every weight with magnitude below 5 is zeroed:

w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.

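def my_init(module):
    if type(module) == nn.Linear:
        # Log the first parameter's name and shape (the weight).
        print("Init", *[(name, param.shape)
                        for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        # Zero every entry with |w| < 5; the boolean mask acts as 0/1.
        module.weight.data *= module.weight.data.abs() >= 5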

net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[ 0.0000, -9.7206,  8.7272, -7.9359],
        [-0.0000, -0.0000,  0.0000, -0.0000]], grad_fn=<SliceBackward0>)
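
Since |w| < 5 covers half of the interval (-10, 10), about half of the weights should end up exactly zero. The sketch below checks this on a larger layer and uses torch.no_grad() with in-place ops as an alternative to .data; the helper name thresholded_uniform_ and the layer size are assumptions for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

def thresholded_uniform_(weight, bound=10.0, threshold=5.0):
    """Hypothetical helper: U(-bound, bound), then zero entries with |w| < threshold."""
    with torch.no_grad():  # keep the init out of the autograd graph
        weight.uniform_(-bound, bound)
        weight.mul_(weight.abs() >= threshold)  # boolean mask acts as 0/1
    return weight

layer = nn.Linear(500, 400)
thresholded_uniform_(layer.weight)

zero_fraction = (layer.weight == 0).float().mean().item()
min_nonzero = layer.weight[layer.weight != 0].abs().min().item()
print(f"zero fraction {zero_fraction:.3f}, smallest surviving |w| {min_nonzero:.3f}")
```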

For one-off surgery — loading specific weights, replacing a single layer’s tensor — assign to .data directly:

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42.0000, -8.7206,  9.7272, -6.9359])
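
The same surgery can also be written with torch.no_grad() and in-place ops, which many codebases prefer over .data because autograd's in-place version tracking stays active. This is a sketch of that alternative on a fresh layer (sizes are arbitrary), not a change to the cell above.

```python
import torch
from torch import nn

layer = nn.Linear(4, 8)

with torch.no_grad():          # edits are not recorded by autograd
    layer.weight.add_(1.0)     # shift every weight by 1
    layer.weight[0, 0] = 42.0  # overwrite a single entry

print(layer.weight[0])
```

The parameter stays a leaf tensor with requires_grad=True, so training continues normally afterwards.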

When to override defaults

Most of the time, don’t. Cases where you should:

  • Loading pretrained weights: load_state_dict is the ultimate “initialization” override.
  • Custom layers: you wrote a new layer with a different variance budget, e.g. small-residual init that puts ResBlocks at near-identity.
  • Reproducibility / ablations: comparing init schemes systematically.
  • Architecture-specific tricks: e.g. zero-init the last BN \gamma in each ResNet block so the block starts as an identity map (the same idea underlies Fixup / SkipInit).
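
The zero-\gamma trick from the last bullet can be sketched as follows. The ResidualBlock class here is a hypothetical minimal block written for this illustration, not a torchvision class; with the final BatchNorm scale set to zero, the residual branch outputs zeros and the block passes non-negative inputs through unchanged.

```python
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """Hypothetical minimal residual block: conv-BN-ReLU-conv-BN plus skip."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        nn.init.zeros_(self.bn2.weight)  # gamma = 0 on the last BN of the branch

    def forward(self, x):
        branch = self.bn2(self.conv2(torch.relu(self.bn1(self.conv1(x)))))
        return torch.relu(x + branch)

torch.manual_seed(0)
block = ResidualBlock(8)
x = torch.rand(2, 8, 16, 16)  # non-negative, so the final ReLU is a no-op
is_identity = torch.allclose(block(x), x)
print(is_identity)
```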

Recap

  • Init scale matters: set it so signal variance stays roughly constant across depth.
  • Xavier: \sigma_w^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} for \tanh/sigmoid.
  • Kaiming/He: \sigma_w^2 = \frac{2}{n_{\text{in}}} for ReLU.
  • Framework defaults are sane; override via net.apply(init_fn) and write per-type rules in the function.
  • Direct layer.weight.data[...] = ... for one-off tensor surgery.