Parameter Initialization

Initialization Matters

Initialization isn’t cosmetic — it determines whether a deep network trains at all.

  • Zero (or any constant) weights → every neuron in a layer computes the same thing and gets the same gradient, so the neurons never become different (symmetry is never broken).
  • Too large → activations blow up.
  • Too small → activations and gradients vanish through depth.
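The symmetry problem can be seen directly in a toy network. This minimal NumPy sketch assumes a two-layer network y = w2 · tanh(W1 x) with squared-error loss; the sizes, the constant 0.5, and the helper name `hidden_grads` are illustrative choices, not from the text. It uses tanh so that no unit is accidentally inactive, but the argument is the same for ReLU. With constant weights, every hidden unit has the same activation and the same gradient row. With random weights, they differ.

```python
import numpy as np

def hidden_grads(W1, w2, x, target):
    """Forward/backward pass of y = w2 . tanh(W1 x) with loss 0.5 * (y - target)^2.

    Returns the hidden activations and the gradient of the loss w.r.t. W1.
    """
    z = W1 @ x                       # pre-activations, one per hidden unit
    h = np.tanh(z)                   # hidden activations
    y = w2 @ h                       # scalar output
    dy = y - target                  # dL/dy
    dh = dy * w2                     # dL/dh_j = dL/dy * w2_j
    dz = dh * (1.0 - h ** 2)         # back through tanh
    dW1 = np.outer(dz, x)            # dL/dW1[j, k] = dz_j * x_k
    return h, dW1

rng = np.random.default_rng(0)
x = rng.normal(size=3)
target = 1.0

# Constant init: every hidden unit is an identical copy of the others.
W1_const = np.full((4, 3), 0.5)
w2_const = np.full(4, 0.5)
h_c, g_c = hidden_grads(W1_const, w2_const, x, target)

# Random init breaks the symmetry.
W1_rand = rng.normal(scale=0.5, size=(4, 3))
w2_rand = rng.normal(scale=0.5, size=4)
h_r, g_r = hidden_grads(W1_rand, w2_rand, x, target)

print("constant init, gradient rows identical:", np.allclose(g_c, g_c[0]))
print("random init, gradient rows identical:  ", np.allclose(g_r, g_r[0]))
```

Every gradient step adds the same row update to each unit, so constant-initialized units stay clones forever.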

The fix: choose the scale so signal variance stays roughly constant from layer to layer.

Why scale matters

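Consider y = Wx with i.i.d. zero-mean inputs x_i of variance \sigma_x^2 and i.i.d. zero-mean weights of variance \sigma_w^2, independent of x. Each y_i is a sum of n_{\text{in}} independent terms w_{ij} x_j, each with variance \sigma_w^2 \sigma_x^2, so: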

\text{Var}(y_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2.
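A quick Monte Carlo check of this formula is a minimal sketch assuming Gaussian inputs and weights. The sizes and standard deviations are arbitrary illustrative values. The empirical variance of the outputs should land close to n_in · σ_w² · σ_x².

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, n_samples = 256, 64, 2000
sigma_x, sigma_w = 1.5, 0.1

# Zero-mean inputs and weights, drawn independently of each other.
X = rng.normal(0.0, sigma_x, size=(n_samples, n_in))
W = rng.normal(0.0, sigma_w, size=(n_out, n_in))
Y = X @ W.T                                   # each row is y = W x for one input x

empirical = Y.var()
predicted = n_in * sigma_w**2 * sigma_x**2    # 256 * 0.01 * 2.25 = 5.76
print(f"empirical {empirical:.3f} vs predicted {predicted:.3f}")
```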

Stack L layers and the signal variance scales by (n_{\text{in}} \sigma_w^2)^L — keep it stable by picking \sigma_w^2 \approx 1/n_{\text{in}}.
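The (n_in σ_w²)^L scaling can be checked by pushing unit-variance inputs through a stack of purely linear layers. This is a sketch under assumptions: no activations, no biases, square 256×256 layers, and a depth of 20, all chosen for illustration. `gain` stands for n_in σ_w², so the expected output variance is gain ** depth.

```python
import numpy as np

def variance_through_depth(gain, n=256, depth=20, batch=1000, seed=0):
    """Push unit-variance inputs through `depth` linear layers whose weights have
    variance gain / n, and return the output variance (expected: gain ** depth)."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, n))
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(gain / n), size=(n, n))
        h = h @ W.T                      # one linear layer, no bias, no activation
    return h.var()

for gain in (0.5, 1.0, 2.0):
    print(f"n_in * sigma_w^2 = {gain}: output variance {variance_through_depth(gain):.3g}")
```

With gain 0.5 the signal shrinks by about 0.5^20 ≈ 1e-6. With gain 2 it grows by about 2^20 ≈ 1e6. Only gain 1 keeps it near its starting value.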

Xavier and Kaiming

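  • Xavier (Glorot 2010): \sigma_w^2 = \dfrac{2}{n_{\text{in}} + n_{\text{out}}}. A compromise between keeping forward activation variance constant (which wants 1/n_{\text{in}}) and keeping backward gradient variance constant (which wants 1/n_{\text{out}}). Designed for \tanh / sigmoid.
  • Kaiming/He (2015): \sigma_w^2 = \dfrac{2}{n_{\text{in}}}. The factor 2 compensates for ReLU zeroing half of its inputs. The standard choice for ReLU networks such as modern CNNs.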
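The ReLU factor of 2 can be checked numerically. This minimal NumPy sketch assumes a deep stack of bias-free fully connected ReLU layers (the width and depth are arbitrary). It tracks the mean squared activation rather than the variance, because ReLU outputs are not zero-mean. With σ_w² = 1/n_in, each layer roughly halves the signal. The Kaiming choice 2/n_in keeps it near 1.

```python
import numpy as np

def relu_stack_second_moment(var_scale, n=256, depth=20, batch=1000, seed=0):
    """Mean squared activation after `depth` ReLU layers with weight variance var_scale / n."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, n))
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_scale / n), size=(n, n))
        h = np.maximum(h @ W.T, 0.0)     # linear layer followed by ReLU
    return np.mean(h ** 2)

plain = relu_stack_second_moment(1.0)    # sigma_w^2 = 1 / n_in
kaiming = relu_stack_second_moment(2.0)  # sigma_w^2 = 2 / n_in
print(f"1/n_in: {plain:.3g}   2/n_in (Kaiming): {kaiming:.3g}")
```

A symmetric pre-activation z keeps only half of its second moment after ReLU, so the layer-to-layer factor is n_in σ_w² / 2. Setting σ_w² = 2/n_in makes that factor 1.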

Bias usually starts at 0.

The framework defaults

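Each framework ships its own default, and it is not always one of the two schemes above:

Framework    | Default for Linear/Dense
PyTorch      | Kaiming-uniform with a = \sqrt{5} on weight, which works out to U(\pm 1/\sqrt{\text{fan-in}}); same range on bias
Flax (JAX)   | LeCun-normal on weight (variance 1/\text{fan-in}); zeros on bias
Keras (TF)   | Glorot-uniform on weight; zeros on bias
MXNet        | Uniform \pm 0.07 on weight, zeros on bias (legacy; you should override)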

Bottom line: apart from MXNet's legacy fixed range, every framework's default scales with fan-in (and, for Glorot, with fan-out). You can usually leave it alone. Override it when you need a non-standard scheme.
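To compare these defaults on one axis, this small sketch converts each one into the weight variance it implies for a given layer shape. It uses the fact that U(−a, a) has variance a²/3. The function name `default_weight_variance` is hypothetical. It models only the weight distributions listed in the table and ignores the truncation in Flax's LeCun-normal, whose target variance is 1/fan-in.

```python
import math

def uniform_variance(a):
    """Variance of U(-a, a)."""
    return a * a / 3.0

def default_weight_variance(framework, fan_in, fan_out):
    """Weight variance implied by each framework's default Dense/Linear initializer."""
    if framework == "pytorch":   # U(-1/sqrt(fan_in), 1/sqrt(fan_in))
        return uniform_variance(1.0 / math.sqrt(fan_in))
    if framework == "flax":      # LeCun-normal: variance 1 / fan_in
        return 1.0 / fan_in
    if framework == "keras":     # Glorot-uniform: limit sqrt(6 / (fan_in + fan_out))
        return uniform_variance(math.sqrt(6.0 / (fan_in + fan_out)))
    if framework == "mxnet":     # U(-0.07, 0.07), independent of the layer shape
        return uniform_variance(0.07)
    raise ValueError(f"unknown framework: {framework}")

for fan_in, fan_out in [(4, 4), (1024, 1024)]:
    row = {fw: default_weight_variance(fw, fan_in, fan_out)
           for fw in ("pytorch", "flax", "keras", "mxnet")}
    print(fan_in, fan_out, {k: f"{v:.2e}" for k, v in row.items()})
```

The Glorot-uniform case reduces exactly to 2/(fan-in + fan-out), the Xavier variance. The MXNet value stays the same for both shapes. For the 4×4 Keras kernels in the examples that follow, the Glorot-uniform limit is √(6/8) ≈ 0.87.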

Setup

import tensorflow as tf
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])

Overriding the default: per-layer initializers

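In Keras, override the default by passing initializers to each layer's constructor: kernel_initializer for the weight matrix and bias_initializer for the bias. (The PyTorch equivalent is net.apply(fn), which calls fn(module) recursively on every submodule, where fn checks the module type before initializing.) Here the first Dense layer gets N(0, 0.01^2) weights and a zero bias: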

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
(<Variable path=sequential_1/dense_2/kernel, shape=(4, 4), dtype=float32, value=[[ 0.01562339  0.01346171  0.00099036 -0.02101069]
  [-0.00057783  0.00708083  0.01856427  0.0180314 ]
  [-0.01718091 -0.02112225  0.00746007  0.00580992]
  [-0.01292037 -0.01140011 -0.00511245 -0.00499784]]>,
 <Variable path=sequential_1/dense_2/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)
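PyTorch's net.apply(fn) has no direct Keras counterpart, but the traversal it performs is easy to sketch. The `Module`, `Linear`, and `Sequential` classes below are hypothetical plain-Python stand-ins with NumPy arrays as parameters. They mirror PyTorch's behavior of calling fn on every child recursively and then on the module itself. `init_weights` shows the usual per-type rule.

```python
import numpy as np

class Module:
    """Minimal stand-in for a framework module: owns child modules and can apply a function."""
    def __init__(self, *children):
        self.children = list(children)

    def apply(self, fn):
        # Recurse into children first, then visit self (the order PyTorch uses).
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self

class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = np.zeros((n_out, n_in))
        self.bias = np.zeros(n_out)

class Sequential(Module):
    pass

rng = np.random.default_rng(0)
visited = []

def init_weights(module):
    """Per-type rule: N(0, 0.01^2) weights and zero bias for every Linear; skip other modules."""
    visited.append(type(module).__name__)
    if isinstance(module, Linear):
        module.weight = rng.normal(0.0, 0.01, size=module.weight.shape)
        module.bias = np.zeros_like(module.bias)

net = Sequential(Linear(4, 8), Sequential(Linear(8, 8)), Linear(8, 1))
net.apply(init_weights)
print(visited)  # ['Linear', 'Linear', 'Sequential', 'Linear', 'Sequential']
```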

Constant initialization is an anti-pattern (every unit stays identical, so symmetry is never broken), but it illustrates the API:

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1),
])

net(X)
net.weights[0], net.weights[1]
(<Variable path=sequential_2/dense_4/kernel, shape=(4, 4), dtype=float32, value=[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]>,
 <Variable path=sequential_2/dense_4/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)

Different scheme per layer

Each layer can get its own initializer: Xavier (Glorot-uniform) for the first Dense layer, the constant 42 for the second:

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42)),
])

net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<Variable path=sequential_3/dense_6/kernel, shape=(4, 4), dtype=float32, value=[[-0.72431743  0.63194484 -0.5133345  -0.6819028 ]
 [ 0.7903401  -0.3202439  -0.5129437  -0.19256562]
 [ 0.21278745 -0.769619    0.39522105  0.3907023 ]
 [-0.16910183  0.79928845 -0.6808882  -0.68104863]]>
<Variable path=sequential_3/dense_7/kernel, shape=(4, 1), dtype=float32, value=[[42.]
 [42.]
 [42.]
 [42.]]>

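The general pattern is to walk the layers (net.layers in Keras, named_modules() in PyTorch), inspect each one's type, name, or position, and decide what to do. The same machinery drives freezing layers (layer.trainable = False in Keras, requires_grad = False in PyTorch), discriminative learning rates, and fine-tuning recipes that train the new head before unfreezing the backbone.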
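Here is the same decide-per-layer loop, sketched framework-free. `layers` is a hypothetical list of (name, kind, kernel shape) records standing in for a model's layers, with kernel shapes in Keras's (fan_in, fan_out) order. The loop picks an initializer by kind and position and records which layers stay trainable. The specific rules mirror the Keras example (Glorot-uniform first, constant 42 last) plus an illustrative "freeze all but the head" flag.

```python
import numpy as np

rng = np.random.default_rng(0)

def glorot_uniform(shape):
    """U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    fan_in, fan_out = shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)

# Hypothetical layer records; Flatten has no weights.
layers = [("flatten", "flatten", None), ("dense_0", "dense", (4, 4)), ("dense_1", "dense", (4, 1))]
dense_names = [name for name, kind, _ in layers if kind == "dense"]

params, trainable = {}, {}
for name, kind, shape in layers:
    if kind != "dense":
        continue                                   # nothing to initialize
    if name == dense_names[-1]:
        params[name] = np.full(shape, 42.0)         # last dense layer: constant 42
    else:
        params[name] = glorot_uniform(shape)        # other dense layers: Xavier
    trainable[name] = name == dense_names[-1]       # freeze everything but the head

print({k: v.shape for k, v in params.items()}, trainable)
```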

Custom initialization

For non-standard schemes, write the initializer yourself by subclassing tf.keras.initializers.Initializer. Here, sample uniformly and then zero out every weight whose magnitude is below 5:

w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.
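It helps to work out what this initializer actually produces. Writing u ~ U(−10, 10) for the raw sample, w = u · 1_{|u| ≥ 5} is zero half the time and otherwise uniform on 5 ≤ |w| ≤ 10. Its variance is far above the Kaiming target 2/n_in = 0.5 for the 4-input layer used below, so this scheme is better read as a demonstration of the initializer API than as a recommended scale.

```latex
\begin{aligned}
P(w = 0) &= P(|u| < 5) = \tfrac{10}{20} = \tfrac{1}{2}, \qquad u \sim U(-10, 10),\\
p(w) &= \tfrac{1}{20} \quad \text{for } 5 \le |w| \le 10,\\
\mathbb{E}[w] &= 0 \quad \text{(symmetric)},\\
\text{Var}(w) &= \mathbb{E}[w^2] = 2 \int_{5}^{10} \frac{w^2}{20}\, dw
  = \frac{1}{10}\cdot\frac{10^3 - 5^3}{3} = \frac{875}{30} \approx 29.2.
\end{aligned}
```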

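class MyInit(tf.keras.initializers.Initializer):
    """Sample U(-10, 10), then zero every entry with |w| < 5."""
    def __call__(self, shape, dtype=None):
        dtype = dtype or tf.float32  # fall back to float32 if no dtype is passed
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        # Cast the boolean mask to the samples' dtype so the product is well-typed.
        factor = tf.cast(tf.abs(data) >= 5, data.dtype)
        return data * factor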

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1),
])

net(X)
print(net.layers[1].weights[0])
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[ 0.         7.3628063  0.         0.       ]
 [-7.1666217 -6.9643497  0.         0.       ]
 [ 0.         7.9453754 -6.659901   0.       ]
 [-8.549326   8.907385   0.        -9.55571  ]]>
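The same rule can be written as a framework-free NumPy function, which makes it easy to check the numbers on a large sample. This is a minimal sketch. The name `thresholded_uniform` is hypothetical. It compares the sample statistics against the exact values implied by the definition: P(w = 0) = 1/2, no nonzero entry with |w| < 5, and Var(w) = 875/30 ≈ 29.2.

```python
import numpy as np

def thresholded_uniform(shape, rng, low=-10.0, high=10.0, threshold=5.0):
    """Sample U(low, high) and zero out entries with |w| < threshold (same rule as MyInit)."""
    data = rng.uniform(low, high, size=shape)
    return data * (np.abs(data) >= threshold)

rng = np.random.default_rng(0)
w = thresholded_uniform((1000, 1000), rng)
nonzero = w[w != 0]

print("fraction zero:", np.mean(w == 0))                 # should be close to 1/2
print("min |w| among nonzeros:", np.abs(nonzero).min())  # never below 5
print("variance:", w.var(), "vs exact", 875 / 30)
```

An array produced this way could be loaded into a Keras layer with layer.set_weights, if you prefer to generate weights outside TensorFlow.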

For one-off surgery, such as loading specific weights or overwriting a single layer's tensor, call assign on the variable directly:

w = net.layers[1].weights[0]   # kernel of the first Dense layer
w.assign(w + 1)                # add 1 to every entry
w[0, 0].assign(42)             # overwrite a single entry via sliced assignment
w
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[42.         8.362806   1.         1.       ]
 [-6.1666217 -5.9643497  1.         1.       ]
 [ 1.         8.945375  -5.659901   1.       ]
 [-7.549326   9.907385   1.        -8.55571  ]]>

When to override defaults

Most of the time, don’t. Cases where you should:

  • Loading pretrained weights: Keras's load_weights (PyTorch's load_state_dict) is the ultimate “initialization” override.
  • Custom layers: you wrote a new layer with a different variance budget, e.g. scaling down residual branches so each ResBlock starts near the identity.
  • Reproducibility / ablations: comparing init schemes systematically.
  • Architecture-specific tricks: e.g. zero-init the last BN \gamma in each ResNet block so every residual branch starts at zero (Fixup and SkipInit pursue the same goal without BN).
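Here is a minimal NumPy sketch of the zero-γ trick. It assumes a simplified block y = ReLU(x + BN₂(W₂ ReLU(BN₁(W₁ x)))), where each `batch_norm` normalizes every feature over the batch and then applies a scale γ and a shift β. The block layout and helper names are illustrative. With the last BN's γ and β at zero, the residual branch outputs exactly zero. For non-negative inputs, such as a previous ReLU's output, the whole block is then the identity.

```python
import numpy as np

def batch_norm(z, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then scale by gamma and shift by beta."""
    mean = z.mean(axis=0)
    var = z.var(axis=0)
    return gamma * (z - mean) / np.sqrt(var + eps) + beta

def residual_block(x, W1, W2, bn1, bn2):
    """y = ReLU(x + BN2(W2 ReLU(BN1(W1 x)))); x has shape (batch, features)."""
    h = np.maximum(batch_norm(x @ W1.T, *bn1), 0.0)
    branch = batch_norm(h @ W2.T, *bn2)
    return np.maximum(x + branch, 0.0)

rng = np.random.default_rng(0)
n = 16
x = np.maximum(rng.normal(size=(32, n)), 0.0)          # non-negative, like a previous ReLU's output
W1 = rng.normal(0.0, np.sqrt(2.0 / n), size=(n, n))    # Kaiming-scaled weights
W2 = rng.normal(0.0, np.sqrt(2.0 / n), size=(n, n))
bn1 = (np.ones(n), np.zeros(n))                        # usual BN init: gamma = 1, beta = 0
bn2_zero = (np.zeros(n), np.zeros(n))                  # zero-init gamma (and beta) on the last BN

y = residual_block(x, W1, W2, bn1, bn2_zero)
y_default = residual_block(x, W1, W2, bn1, bn1)        # same block with the usual gamma = 1
print("zero-gamma block is identity at init:", np.allclose(y, x))
print("gamma = 1 block is identity at init: ", np.allclose(y_default, x))
```

Because γ is learnable, training is free to move it away from zero. Only the starting point is the identity, which lets very deep stacks begin as shallow ones.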

Recap

  • Init scale matters: set it so signal variance stays roughly constant across depth.
  • Xavier: \sigma_w^2 = \frac{2}{n_{\text{in}}+n_{\text{out}}} for \tanh/sigmoid.
  • Kaiming/He: \sigma_w^2 = \frac{2}{n_{\text{in}}} for ReLU.
  • Framework defaults are sane (MXNet's legacy one aside). Override them per layer with kernel_initializer / bias_initializer in Keras, or with net.apply(init_fn) and per-type rules in PyTorch.
  • Call variable.assign(...) for one-off tensor surgery.