Parameter Initialization

Initialization Matters

Initialization isn’t cosmetic — it determines whether a deep network trains at all.

  • Zero (or any constant) weights → every neuron in a layer computes the same output and receives the same gradient, so the units stay identical under gradient descent (symmetry is never broken).
  • Too large → activations and gradients explode through depth.
  • Too small → activations and gradients vanish through depth.

The fix: choose the scale so signal variance stays roughly constant from layer to layer.
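A minimal pure-Python sketch of the first bullet. It assumes a hypothetical one-hidden-layer tanh network without biases, trained with squared loss; the input and target values are arbitrary. When both hidden units start with the same weights, their gradients come out bit-for-bit identical, so no gradient step can ever make them differ. Random weights break the tie.

```python
import math
import random

def grads(W, v, x, target):
    """Gradients of 0.5 * (y - target)^2 for the hypothetical toy net
    y = sum_j v[j] * tanh(W[j] . x)  (no biases)."""
    h = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W]
    y = sum(vj * hj for vj, hj in zip(v, h))
    dy = y - target                      # dL/dy
    dv = [dy * hj for hj in h]           # dL/dv[j]
    # dL/dW[j][i] = dy * v[j] * (1 - h_j^2) * x[i]   (tanh' = 1 - tanh^2)
    dW = [[dy * vj * (1.0 - hj * hj) * xi for xi in x] for vj, hj in zip(v, h)]
    return dW, dv

x, target = [1.0, -2.0, 0.5], 1.0

# Constant init: the two hidden units are exact copies of each other
W_const = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
v_const = [0.5, 0.5]
dW_const, dv_const = grads(W_const, v_const, x, target)
print("constant init, gradient rows equal:", dW_const[0] == dW_const[1])

# Random init breaks the symmetry
rng = random.Random(0)
W_rand = [[rng.gauss(0.0, 0.5) for _ in x] for _ in range(2)]
v_rand = [rng.gauss(0.0, 0.5) for _ in range(2)]
dW_rand, dv_rand = grads(W_rand, v_rand, x, target)
print("random init, gradient rows equal:", dW_rand[0] == dW_rand[1])
```

In this bias-free tanh example, all-zero init is worse still: h = 0 and v = 0 make every gradient exactly zero, so nothing moves at all.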

Why scale matters

Consider y = Wx with i.i.d. zero-mean inputs x_i of variance \sigma_x^2, and i.i.d. zero-mean weights of variance \sigma_w^2 that are independent of x:

\text{Var}(y_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2.

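Stack L such linear layers (all of width n_{\text{in}}) and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L, so it explodes or vanishes geometrically unless n_{\text{in}} \sigma_w^2 \approx 1, i.e. \sigma_w^2 \approx 1/n_{\text{in}}.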
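To see the (n_{\text{in}} \sigma_w^2)^L factor numerically, here is a small pure-Python simulation. It assumes purely linear layers (no activation) of equal width 64, depth 10, and i.i.d. N(0, 1) inputs; forward_variance is a hypothetical helper, not a library function. All three runs share a seed, so they use the same underlying random draws and differ only in scale: a gain of 1 keeps the output variance near 1, while gains of 0.5 and 2 shrink or inflate it by 2^{-10} and 2^{10}.

```python
import random
import statistics

def forward_variance(gain, n=64, depth=10, batch=16, seed=0):
    """Hypothetical helper: output variance after `depth` random linear
    layers of width n whose weights have variance sigma_w^2 = gain / n."""
    rng = random.Random(seed)
    sigma_w = (gain / n) ** 0.5
    # Each entry of `xs` is one input vector with i.i.d. N(0, 1) components
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(batch)]
    for _ in range(depth):
        W = [[rng.gauss(0.0, sigma_w) for _ in range(n)] for _ in range(n)]
        # y = W x for every vector in the batch (no activation function)
        xs = [[sum(w * xi for w, xi in zip(row, x)) for row in W] for x in xs]
    return statistics.pvariance([v for x in xs for v in x])

results = {gain: forward_variance(gain) for gain in (0.5, 1.0, 2.0)}
for gain, var in results.items():
    print(f"n_in * sigma_w^2 = {gain}: output variance ~ {var:.4g}")
```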

Xavier and Kaiming

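  • Xavier (Glorot 2010): \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. A compromise between the forward condition \sigma_w^2 = 1/n_{\text{in}} and the backward (gradient) condition \sigma_w^2 = 1/n_{\text{out}}. Designed for \tanh / sigmoid.
  • Kaiming/He (2015): \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. The factor 2 compensates for ReLU zeroing half its inputs, which halves the second moment of the signal. Standard choice for ReLU networks such as modern CNNs.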

Bias usually starts at 0.
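Both formulas, plus the matching uniform bound (U(-a, a) has variance a^2/3), fit in a few lines of pure Python. The sketch below also pushes one random input through a hypothetical 10-layer ReLU MLP of width 100 (relu_second_moment is an illustrative helper, not a library API) to compare \sigma_w^2 = 2/n_{\text{in}} with \sigma_w^2 = 1/n_{\text{in}}. With ReLU, the Kaiming scale keeps the mean squared pre-activation roughly constant, while 1/n_{\text{in}} halves it at every layer. Both runs reuse the same seed, so they differ only by the weight scale and the final ratio is exactly 2^{10} up to rounding.

```python
import math
import random

def xavier_std(n_in, n_out):
    # Glorot: sigma_w^2 = 2 / (n_in + n_out)
    return math.sqrt(2.0 / (n_in + n_out))

def kaiming_std(n_in):
    # He: sigma_w^2 = 2 / n_in
    return math.sqrt(2.0 / n_in)

def uniform_bound(std):
    # U(-a, a) has variance a^2 / 3, so a = sqrt(3) * std matches the target variance
    return math.sqrt(3.0) * std

def relu_second_moment(std_fn, n=100, depth=10, seed=0):
    """Hypothetical helper: mean of z^2 over the last layer's
    pre-activations of a random ReLU MLP (no biases)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    for _ in range(depth):
        s = std_fn(n)
        W = [[rng.gauss(0.0, s) for _ in range(n)] for _ in range(n)]
        z = [sum(w * xi for w, xi in zip(row, x)) for row in W]
        x = [max(0.0, zi) for zi in z]   # ReLU
    return sum(zi * zi for zi in z) / n

he = relu_second_moment(kaiming_std)
lecun = relu_second_moment(lambda n: math.sqrt(1.0 / n))  # sigma_w^2 = 1 / n_in
print(f"Kaiming: {he:.3f}  1/n_in: {lecun:.5f}  ratio: {he / lecun:.1f}")
```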

The framework defaults

Each framework ships its own default for dense layers:

Framework | Default weight init for Linear/Dense | Default bias init
PyTorch | Kaiming-uniform with a = \sqrt{5}, i.e. U(\pm 1/\sqrt{\text{fan-in}}) | U(\pm 1/\sqrt{\text{fan-in}})
Flax (JAX) | LeCun-normal, \sigma_w^2 = 1/\text{fan-in} | 0
Keras (TF) | Glorot-uniform | 0
MXNet | U(\pm 0.07), independent of fan-in (legacy; you should override) | 0

Bottom line: every modern framework except legacy MXNet picks something fan-in/fan-out aware, so you can usually leave the default alone. Override it when you need a non-standard scheme.
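The table is easier to compare in terms of the per-layer gain n_{\text{in}} \sigma_w^2 from the variance argument above. This pure-Python sketch recomputes each default's weight variance from its distribution, assuming square layers (n_{\text{in}} = n_{\text{out}}); the helper names are hypothetical. PyTorch's Kaiming-uniform with a = \sqrt{5} works out to a gain of 1/3, Flax and Keras (for square layers) give 1, and MXNet's fixed \pm 0.07 gives a gain that grows with width.

```python
import math

def pytorch_linear_weight_var(n_in):
    # kaiming_uniform_ with a = sqrt(5): gain = sqrt(2 / (1 + a^2)) = sqrt(1/3)
    gain = math.sqrt(2.0 / (1.0 + 5.0))
    bound = math.sqrt(3.0) * gain / math.sqrt(n_in)   # simplifies to 1 / sqrt(n_in)
    return bound ** 2 / 3.0                            # variance of U(-bound, bound)

def flax_lecun_normal_var(n_in):
    # LeCun-normal targets variance 1 / n_in
    return 1.0 / n_in

def keras_glorot_uniform_var(n_in, n_out):
    limit = math.sqrt(6.0 / (n_in + n_out))
    return limit ** 2 / 3.0                            # = 2 / (n_in + n_out)

def mxnet_uniform_var(scale=0.07):
    return scale ** 2 / 3.0                            # independent of layer width

for n in (64, 256, 1024):
    print(f"n_in = n_out = {n}:",
          f"PyTorch {n * pytorch_linear_weight_var(n):.3f}",
          f"Flax {n * flax_lecun_normal_var(n):.3f}",
          f"Keras {n * keras_glorot_uniform_var(n, n):.3f}",
          f"MXNet {n * mxnet_uniform_var():.3f}")
```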

Setup

from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Default init: U(-0.07, 0.07) weights, zero bias (drawn lazily on the first forward pass)

X = np.random.uniform(size=(2, 4))
net(X).shape

The universal pattern: net.apply(fn)

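Override the default by walking the module tree and applying an initializer to each leaf. In PyTorch, net.apply(fn) calls fn(module) recursively on every submodule and then on net itself. In MXNet Gluon, used here, you pass an Initializer to net.initialize(), which walks every parameter for you: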

# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
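For comparison with the Gluon call, here is a minimal pure-Python sketch of what PyTorch's net.apply(fn) does: recurse into every child, then call fn on the module itself. The Module, Linear and Sequential classes are hypothetical stand-ins, not a real framework API. The init function dispatches on module type, drawing N(0, 0.01^2) weights and zero biases for Linear layers and leaving containers alone.

```python
import random

class Module:
    """Hypothetical stand-in for a framework module tree (not a real API)."""
    def __init__(self, *children):
        self.children = list(children)

    def apply(self, fn):
        # Same order as PyTorch's Module.apply: children first, then self
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self

class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = [[1.0] * n_in for _ in range(n_out)]  # placeholder values
        self.bias = [1.0] * n_out

class Sequential(Module):
    pass

rng = random.Random(0)
visited = []

def init_fn(m):
    visited.append(type(m).__name__)
    # Dispatch on module type: only Linear layers carry parameters here
    if isinstance(m, Linear):
        m.weight = [[rng.gauss(0.0, 0.01) for _ in row] for row in m.weight]
        m.bias = [0.0 for _ in m.bias]

net = Sequential(Linear(4, 8), Sequential(Linear(8, 1)))
net.apply(init_fn)
print("visit order:", visited)
print("first layer, row 0:", net.children[0].weight[0])
```

In real PyTorch code, the body of init_fn would typically test isinstance(m, nn.Linear) and call torch.nn.init.normal_(m.weight, std=0.01) and torch.nn.init.zeros_(m.bias).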

Constant initialization is an anti-pattern (it prevents symmetry breaking), but it illustrates the API:

net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]

Different scheme per layer

Index into the network (or, inside an init function, dispatch on layer type) to give each layer its own scheme: Xavier for the first Dense layer's weight, constant 42 for the second layer:

net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())

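The general pattern: iterate over (name, module) or (name, parameter) pairs (PyTorch's named_modules() / named_parameters(), Gluon's collect_params()) and decide what to do with each. The same machinery is used for freezing layers (requires_grad = False in PyTorch, grad_req = 'null' in Gluon), discriminative learning rates, and fine-tuning schemes that train a new head while the pretrained backbone stays frozen.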
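A framework-agnostic sketch of that per-name dispatch. It assumes a hypothetical dictionary of parameters keyed by dotted names, standing in for the name-to-parameter mapping a framework exposes, and made-up rules: freeze everything under backbone, re-initialize the head.

```python
import re

# Hypothetical parameter table; real frameworks map names to tensor/Parameter objects
params = {
    "backbone.0.weight": {"init": "pretrained", "trainable": True},
    "backbone.0.bias":   {"init": "pretrained", "trainable": True},
    "head.weight":       {"init": "default", "trainable": True},
    "head.bias":         {"init": "default", "trainable": True},
}

def configure(name, p):
    """Decide per (name, parameter) pair what to do; the rules are illustrative."""
    if re.match(r"backbone\.", name):
        p["trainable"] = False   # freeze (cf. requires_grad = False / grad_req = 'null')
    elif name.endswith(".weight"):
        p["init"] = "xavier"     # re-initialize the head weights
    else:
        p["init"] = "zeros"      # and zero the head biases

for name, p in params.items():
    configure(name, p)

for name, p in params.items():
    print(f"{name:20s} init={p['init']:10s} trainable={p['trainable']}")
```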

Custom initialization

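For non-standard schemes, write the initializer yourself by subclassing init.Initializer and overriding _init_weight. Here: sample uniformly from [-10, 10], then zero out every weight whose magnitude is below 5: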

w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
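The resulting distribution is easy to check without MXNet. This pure-Python sketch (my_init is a hypothetical helper mirroring MyInit, not part of any library) draws a 100 × 100 matrix: w is 0 with probability 1/2 and uniform on [5, 10] or [-10, -5] with probability 1/4 each, so roughly half the entries come out zero.

```python
import random

def my_init(shape, rng):
    """Hypothetical pure-Python version of MyInit:
    w ~ U(-10, 10), then zero out every entry with |w| < 5."""
    rows, cols = shape
    out = []
    for _ in range(rows):
        row = []
        for _ in range(cols):
            w = rng.uniform(-10.0, 10.0)
            row.append(w if abs(w) >= 5.0 else 0.0)
        out.append(row)
    return out

rng = random.Random(0)
W = my_init((100, 100), rng)
flat = [w for row in W for w in row]
zero_frac = sum(w == 0.0 for w in flat) / len(flat)
print(f"fraction of zeros: {zero_frac:.3f}")
print("nonzero magnitudes in [5, 10]:",
      all(5.0 <= abs(w) <= 10.0 for w in flat if w != 0.0))
```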

For one-off surgery, such as loading specific weights or replacing a single layer's tensor, write into the array returned by .data() directly:

net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]

When to override defaults

Most of the time, don’t. Cases where you should:

  • Loading pretrained weights: load_state_dict (PyTorch) or load_parameters (Gluon) is the ultimate “initialization” override.
  • Custom layers — you wrote a new layer with a different variance budget, e.g. small-residual init that puts ResBlocks at near-identity.
  • Reproducibility / ablations — comparing init schemes systematically.
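  • Architecture-specific tricks: e.g. zero-init the last BN \gamma in each ResNet block so every residual branch starts at zero and the block starts near the identity. Fixup and SkipInit pursue the same near-identity start in networks without BN.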
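A simplified pure-Python sketch of the zero-\gamma trick. It assumes a post-activation ResNet block with the BN normalization step left out (only its \gamma/\beta affine parameters are kept, with \beta = 0), and the helper names are hypothetical. With \gamma = 0 the residual branch contributes nothing, so on the non-negative output of a previous block the whole block is exactly the identity at initialization.

```python
import random

def relu(v):
    return [max(0.0, t) for t in v]

def residual_block(x, W1, W2, gamma, beta):
    """Hypothetical simplified block: y = relu(x + gamma * (W2 relu(W1 x)) + beta).
    gamma/beta stand in for the last BN layer's affine parameters; the
    normalization itself is omitted in this sketch."""
    h = relu([sum(w * xi for w, xi in zip(row, x)) for row in W1])
    f = [sum(w * hi for w, hi in zip(row, h)) for row in W2]
    return relu([xi + g * fi + b for xi, fi, g, b in zip(x, f, gamma, beta)])

n = 8
rng = random.Random(0)
std = (2.0 / n) ** 0.5  # Kaiming scale for the branch weights
W1 = [[rng.gauss(0.0, std) for _ in range(n)] for _ in range(n)]
W2 = [[rng.gauss(0.0, std) for _ in range(n)] for _ in range(n)]
x = relu([rng.gauss(0.0, 1.0) for _ in range(n)])  # non-negative, like a previous block's output

beta_zero = [0.0] * n
y_zero = residual_block(x, W1, W2, [0.0] * n, beta_zero)  # zero-init gamma
y_one = residual_block(x, W1, W2, [1.0] * n, beta_zero)   # usual gamma = 1
print("gamma = 0, block is identity:", y_zero == x)
print("gamma = 1, block is identity:", y_one == x)
```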

Recap

  • Init scale matters: set it so signal variance stays roughly constant across depth.
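  • Xavier: \sigma_w^2 = \frac{2}{n_{\text{in}}+n_{\text{out}}} for \tanh/sigmoid.
  • Kaiming/He: \sigma_w^2 = \frac{2}{n_{\text{in}}} for ReLU.
  • Framework defaults are usually sane (legacy MXNet excepted); override via net.apply(init_fn) in PyTorch, with per-type rules inside init_fn, or via net.initialize(init, force_reinit=True) in Gluon.
  • Write directly into layer.weight.data()[...] (Gluon) for one-off tensor surgery.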