Parameter Initialization

Initialization Matters

Initialization isn’t cosmetic — it determines whether a deep network trains at all.

  • Zero weights → every neuron in a layer computes the same thing, gets the same gradient (“symmetry breaking” fails).
  • Too large → activations blow up.
  • Too small → activations and gradients vanish through depth.

The fix: choose the scale so signal variance stays roughly constant from layer to layer.

Why scale matters

Consider y = Wx with i.i.d. zero-mean inputs x_i of variance \sigma_x^2 and i.i.d. zero-mean weights W_{ij} of variance \sigma_w^2, independent of x:

\text{Var}(y_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2.
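
The formula follows in a few steps. This is a sketch under the assumption that the weights W_{ij} are zero-mean, independent of each other and of x, so that all cross terms between different j vanish:

```latex
\begin{aligned}
\text{Var}(y_i) &= \text{Var}\Big(\sum_{j=1}^{n_{\text{in}}} W_{ij} x_j\Big) \\
&= \sum_{j=1}^{n_{\text{in}}} \text{Var}(W_{ij} x_j) \\
&= \sum_{j=1}^{n_{\text{in}}} E[W_{ij}^2] \, E[x_j^2] \\
&= n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2 .
\end{aligned}
```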

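Stack L such layers of equal width (ignoring nonlinearities) and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L, which explodes or vanishes exponentially in L unless n_{\text{in}} \sigma_w^2 \approx 1. Keep it stable by picking \sigma_w^2 \approx 1/n_{\text{in}}.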
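
The depth scaling can be checked numerically. The sketch below is illustrative only: the helper name variance_after_depth, the width 256, and the depth 20 are assumptions, and the layers are purely linear with Gaussian weights. The argument gain plays the role of n_{\text{in}} \sigma_w^2.

```python
import jax
from jax import numpy as jnp


def variance_after_depth(key, gain, n=256, depth=20, batch=512):
    """Push unit-variance inputs through `depth` linear layers.

    Each weight matrix has i.i.d. entries with variance gain / n,
    so `gain` corresponds to n_in * sigma_w^2 in the text.
    """
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (batch, n))
    for _ in range(depth):
        key, sub = jax.random.split(key)
        W = jnp.sqrt(gain / n) * jax.random.normal(sub, (n, n))
        x = x @ W
    return float(jnp.var(x))


key = jax.random.PRNGKey(0)
var_small = variance_after_depth(key, gain=0.5)   # shrinks roughly like 0.5 ** 20
var_stable = variance_after_depth(key, gain=1.0)  # stays on the order of 1
var_large = variance_after_depth(key, gain=2.0)   # grows roughly like 2 ** 20
print(var_small, var_stable, var_large)
```

Only gain = 1, i.e. \sigma_w^2 = 1/n_{\text{in}}, keeps the output variance near the input variance.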

Xavier and Kaiming

  • Xavier (Glorot 2010): \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. Balances forward variance with backward gradient variance. Designed for \tanh / sigmoid.
  • Kaiming/He (2015): \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. Compensates for ReLU killing half the signal. Standard choice for ReLU networks such as modern CNNs.

Bias usually starts at 0.
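
To see why the factor of 2 matters under ReLU, here is a minimal sketch comparing the two schemes on a stack of square, bias-free ReLU layers. The helper name, width 512, and depth 20 are assumptions for illustration. With n_{\text{in}} = n_{\text{out}}, Xavier reduces to \sigma_w^2 = 1/n, so each ReLU layer roughly halves the mean squared activation.

```python
import jax
from jax import numpy as jnp


def mean_square_after_relu_stack(key, init_fn, n=512, depth=20, batch=256):
    """Mean squared activation after `depth` bias-free Dense+ReLU layers."""
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (batch, n))
    for _ in range(depth):
        key, sub = jax.random.split(key)
        W = init_fn(sub, (n, n), jnp.float32)  # shape is (fan_in, fan_out)
        x = jax.nn.relu(x @ W)
    return float(jnp.mean(x ** 2))


key = jax.random.PRNGKey(0)
ms_xavier = mean_square_after_relu_stack(key, jax.nn.initializers.glorot_normal())
ms_kaiming = mean_square_after_relu_stack(key, jax.nn.initializers.he_normal())
print(ms_xavier, ms_kaiming)
```

The Xavier run should decay by many orders of magnitude over 20 layers, while the Kaiming run should stay on the order of 1.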

The framework defaults

Each framework ships its own default for dense layers:

Framework     Default for Linear/Dense
PyTorch       Kaiming-uniform on weight; uniform \pm 1/\sqrt{\text{fan-in}} on bias
Flax (JAX)    LeCun-normal (\sigma_w^2 = 1/n_{\text{in}}, i.e. Kaiming without the ReLU factor of 2)
Keras (TF)    Glorot-uniform
MXNet         Uniform \pm 0.07 (legacy; you should override)

Bottom line: every modern framework picks something fan-in/fan-out aware. You can usually leave it alone. Override when you need a non-standard scheme.

Setup

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)

The universal pattern: net.apply(fn)

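Override the default by attaching an initializer to each layer. In PyTorch, net.apply(fn) calls fn(module) recursively for every submodule. In Flax, net.apply(params, X) is the forward pass and has nothing to do with initialization; instead, each layer takes kernel_init and bias_init arguments at construction: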

weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(d2l.get_key(), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([ 0.00390247, -0.00536502,  0.01318502,  0.01437083], dtype=float32),
 Array(0., dtype=float32))

Constant initialization is an anti-pattern (it prevents symmetry breaking), but it illustrates the API:

weight_init = nn.initializers.constant(1)

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(d2l.get_key(), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))

Different scheme per layer

Give each layer its own initializer. Here: Xavier for the first dense layer, constant 42 for the second:

net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])

params = net.init(d2l.get_key(), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([ 0.29768014, -0.38261548,  0.50331783,  0.34027493], dtype=float32),
 Array([[42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.]], dtype=float32))

In PyTorch the pattern is: take a (name, module) tuple and decide what to do with it. The same machinery is used for freezing layers (requires_grad = False), discriminative learning rates, and BERT-style “warm up the head, not the backbone”.

Custom initialization

For non-standard schemes, write the init function yourself. Here, a uniform sample with the small-magnitude half zeroed out:

w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.

def my_init(key, shape, dtype=jnp.float32):
    # Draw from U(-10, 10) in the requested dtype, then zero out entries with |w| < 5.
    data = jax.random.uniform(key, shape, dtype, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[0.        5.5548573]
 [0.        8.096008 ]
 [7.2521663 5.118685 ]
 [0.        0.       ]]

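For one-off surgery (loading specific weights, replacing a single layer’s tensor), PyTorch lets you assign to .data directly. JAX arrays are immutable, so in Flax you build an updated copy of the parameter tree instead, e.g. with kernel.at[...].set(...).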

When to override defaults

Most of the time, don’t. Cases where you should:

  • Loading pretrained weights: load_state_dict (PyTorch) is the ultimate “initialization” override.
  • Custom layers: you wrote a new layer with a different variance budget, e.g. small-residual init that puts ResBlocks at near-identity.
  • Reproducibility / ablations: comparing init schemes systematically.
  • Architecture-specific tricks: e.g. zero-init the last BN \gamma in each ResNet block (the same idea underlies FixUp / Skip-init).

Recap

  • Init scale matters: set it so signal variance stays roughly constant across depth.
  • Xavier: \sigma_w^2 = \frac{2}{n_{\text{in}}+n_{\text{out}}} for \tanh/sigmoid.
  • Kaiming/He: \sigma_w^2 = \frac{2}{n_{\text{in}}} for ReLU.
  • Framework defaults are sane; override via net.apply(init_fn) with per-type rules in PyTorch, or via kernel_init / bias_init on each layer in Flax.
  • One-off tensor surgery: layer.weight.data[...] = ... in PyTorch; an updated copy via .at[...].set(...) in JAX.