```python
import tensorflow as tf
```

Initialization isn't cosmetic: it determines whether a deep network trains at all.
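If the weights are too small, activations (and gradients) shrink toward zero as they pass through the layers; if they are too large, they blow up. The fix: choose the scale so the signal variance stays roughly constant from layer to layer.

Consider $y = Wx$, where the inputs $x_i$ are i.i.d. with zero mean and variance $\sigma_x^2$, and the weights are i.i.d. with zero mean and variance $\sigma_w^2$, drawn independently of $x$. Each output sums $n_{\text{in}}$ such products, so

$$\text{Var}(y_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2.$$

Stack $L$ such layers and the signal variance scales by $(n_{\text{in}} \sigma_w^2)^L$, which shrinks or grows exponentially with depth unless $n_{\text{in}} \sigma_w^2 \approx 1$. Keep it stable by picking $\sigma_w^2 \approx 1/n_{\text{in}}$. With ReLU, which zeroes about half of its inputs and so halves the second moment, the same argument gives $\sigma_w^2 \approx 2/n_{\text{in}}$.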
Bias usually starts at 0.
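The variance argument can be checked numerically. The sketch below (plain Python, no framework) pushes random inputs through `depth` freshly sampled linear layers of width `n`, with no activation, and measures the output variance for $\sigma_w^2 = c/n_{\text{in}}$. The function name and the chosen width, depth, and trial count are illustrative assumptions, not part of any library.

```python
import random
import statistics


def output_variance(n, depth, c, trials=200, seed=0):
    """Empirical Var(y) after `depth` random linear layers of width n.

    Inputs are i.i.d. N(0, 1); every weight is drawn i.i.d. from N(0, c / n),
    i.e. sigma_w^2 = c / n_in. No activation, matching the linear derivation.
    """
    rng = random.Random(seed)
    std = (c / n) ** 0.5
    samples = []
    for _ in range(trials):
        x = [rng.gauss(0.0, 1.0) for _ in range(n)]
        for _ in range(depth):
            # One fresh weight row per output unit: y_j = sum_i W_ji * x_i.
            x = [sum(rng.gauss(0.0, std) * xi for xi in x) for _ in range(n)]
        samples.extend(x)
    return statistics.pvariance(samples)


for c in (0.5, 1.0, 2.0):
    # In expectation, Var(y) = c ** depth when Var(x) = 1.
    print(f"c={c}: Var(y) = {output_variance(n=16, depth=5, c=c):.3f}")
```

With $c = 1$ the variance should stay near 1, while $c = 2$ and $c = 0.5$ should land near $2^5$ and $0.5^5$ respectively, up to sampling noise.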
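The standard schemes differ in which fan they use and whether they correct for ReLU: LeCun ($\sigma_w^2 = 1/n_{\text{in}}$), Xavier/Glorot ($\sigma_w^2 = 2/(n_{\text{in}} + n_{\text{out}})$, a compromise that also keeps the backward-pass variance, which depends on $n_{\text{out}}$, under control), and Kaiming/He ($\sigma_w^2 = 2/n_{\text{in}}$, for ReLU). Each framework picks a default:

| Framework | Default init for Linear/Dense layers |
|---|---|
| PyTorch | Kaiming-uniform with $a = \sqrt{5}$ on weight, i.e. $U(\pm 1/\sqrt{\text{fan-in}})$; uniform $\pm 1/\sqrt{\text{fan-in}}$ on bias |
| Flax (JAX) | LeCun-normal (truncated normal, variance $1/\text{fan-in}$) on kernel; zeros on bias |
| Keras (TF) | Glorot-uniform on kernel; zeros on bias |
| MXNet | Uniform $\pm 0.07$ (legacy and not fan-aware; you should override it) |

Bottom line: apart from legacy MXNet, every mainstream framework picks something fan-in/fan-out aware. You can usually leave the default alone and override it only when you need a non-standard scheme.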
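To compare these defaults on one scale, the sketch below converts each into the weight variance it implies, using $\text{Var}(U(-a, a)) = a^2/3$, for a hypothetical layer with `fan_in=256` and `fan_out=128`. It only re-evaluates the formulas listed in the table; it does not call any framework, and the dictionary keys are arbitrary labels.

```python
import math


def uniform_var(a):
    """Variance of U(-a, a) is a^2 / 3."""
    return a * a / 3.0


def default_weight_variances(fan_in, fan_out):
    """Weight variance implied by each framework default in the table."""
    return {
        "pytorch": uniform_var(1.0 / math.sqrt(fan_in)),            # Kaiming-uniform, a=sqrt(5)
        "flax": 1.0 / fan_in,                                       # LeCun-normal
        "keras": uniform_var(math.sqrt(6.0 / (fan_in + fan_out))),  # Glorot-uniform
        "mxnet": uniform_var(0.07),                                 # fixed Uniform(+-0.07)
    }


fan_in, fan_out = 256, 128  # hypothetical layer shape
for name, var in default_weight_variances(fan_in, fan_out).items():
    # n_in * var == 1 means the forward variance is preserved in the linear case.
    print(f"{name:8s} var={var:.2e}  n_in*var={fan_in * var:.3f}")
```

Working through the formulas, PyTorch's default gives $n_{\text{in}} \sigma_w^2 = 1/3$ at any width, Flax's gives exactly 1, and Glorot's depends on the fan-out as well; only the MXNet value ignores the layer shape entirely.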
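In PyTorch you override the default by walking the module tree: `net.apply(fn)` calls `fn(module)` recursively on every submodule, and `fn` decides from the module's type how to initialize it. In Keras you instead pass `kernel_initializer` and `bias_initializer` when you construct each layer. In the output below, the first Dense layer's kernel was drawn from a zero-mean normal with a small standard deviation (about 0.01) and its bias was set to zero: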
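The code that produced this output is not shown in this section. A minimal Keras sketch of an equivalent network follows; it assumes a hypothetical input batch `X` of shape (2, 4) (four features, matching the (4, 4) kernel) and a standard deviation of exactly 0.01, and its random values will differ from run to run.

```python
import tensorflow as tf

# Hypothetical input: 2 examples with 4 features each.
X = tf.random.uniform((2, 4))

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        # Kernel entries ~ N(0, 0.01^2) (assumed std); bias starts at zero.
        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01),
        bias_initializer=tf.keras.initializers.Zeros()),
    tf.keras.layers.Dense(1),
])
net(X)  # Keras creates the weights lazily, on the first call

kernel, bias = net.layers[1].weights  # layers[0] is Flatten, which has no weights
print(kernel, bias)
```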
```
(<Variable path=sequential_1/dense_2/kernel, shape=(4, 4), dtype=float32, value=[[ 0.01562339 0.01346171 0.00099036 -0.02101069]
 [-0.00057783 0.00708083 0.01856427 0.0180314 ]
 [-0.01718091 -0.02112225 0.00746007 0.00580992]
 [-0.01292037 -0.01140011 -0.00511245 -0.00499784]]>,
 <Variable path=sequential_1/dense_2/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)
```
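Constant initialization is an anti-pattern: every unit in the layer computes the same function and receives the same gradient, so plain gradient descent can never break the symmetry between them. It still illustrates the API. In the output below, every kernel entry is 1 and the bias is 0: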
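Again the producing code is not shown; a minimal Keras sketch that yields a kernel and bias of this form, assuming the same hypothetical (2, 4) input `X`, is:

```python
import tensorflow as tf

# Hypothetical input: 2 examples with 4 features each.
X = tf.random.uniform((2, 4))

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        # Every kernel entry is 1, so all four hidden units start out identical.
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.keras.initializers.Zeros()),
    tf.keras.layers.Dense(1),
])
net(X)  # build the weights

kernel, bias = net.layers[1].weights
print(kernel, bias)
```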
```
(<Variable path=sequential_2/dense_4/kernel, shape=(4, 4), dtype=float32, value=[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]>,
 <Variable path=sequential_2/dense_4/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)
```
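Different layers can also get different initializers. PyTorch code dispatches on layer type or index inside the init function; in Keras you set the initializer per layer at construction time. Here the first Dense layer uses Xavier (Glorot) uniform and the second uses the constant 42: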
```python
# Assumed input (not defined in this section): a batch of 2 examples with 4 features.
X = tf.random.uniform((2, 4))

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42)),
])
net(X)  # weights are created on the first call
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
```

```
<Variable path=sequential_3/dense_6/kernel, shape=(4, 4), dtype=float32, value=[[-0.72431743 0.63194484 -0.5133345 -0.6819028 ]
 [ 0.7903401 -0.3202439 -0.5129437 -0.19256562]
 [ 0.21278745 -0.769619 0.39522105 0.3907023 ]
 [-0.16910183 0.79928845 -0.6808882 -0.68104863]]>
<Variable path=sequential_3/dense_7/kernel, shape=(4, 1), dtype=float32, value=[[42.]
 [42.]
 [42.]
 [42.]]>
```
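The general pattern: iterate over the layers (in PyTorch, the `(name, module)` pairs from `net.named_modules()`) and decide per layer what to do. The same machinery handles freezing layers (`requires_grad = False` on PyTorch parameters, `trainable = False` on a Keras layer), discriminative (per-layer) learning rates, and BERT-style fine-tuning recipes that warm up a freshly initialized head before touching the pretrained backbone.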
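In Keras, the analogous walk is a plain loop over `net.layers` with an `isinstance` check. The sketch below uses a hypothetical three-layer network (the sizes are arbitrary): it re-initializes every Dense kernel with Glorot-uniform and zeros the biases, then freezes every layer except the last, in the spirit of training only the head.

```python
import tensorflow as tf

# Hypothetical network; layer sizes are illustrative only.
net = tf.keras.models.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
net.build(input_shape=(None, 4))  # create the variables without feeding data

glorot = tf.keras.initializers.GlorotUniform()
last = len(net.layers) - 1
for i, layer in enumerate(net.layers):
    # Per-type rule: only Dense layers are re-initialized.
    if isinstance(layer, tf.keras.layers.Dense):
        layer.kernel.assign(glorot(layer.kernel.shape))
        layer.bias.assign(tf.zeros(layer.bias.shape))
    # Per-index rule: freeze the "backbone", leave the head trainable.
    layer.trainable = (i == last)

print([layer.trainable for layer in net.layers])  # [False, False, True]
print(len(net.trainable_weights))  # 2: the head's kernel and bias
```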
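For schemes no built-in initializer covers, write the initializer yourself. As an example, take a deliberately unusual distribution: sample uniformly from $[-10, 10]$, then zero out every weight whose magnitude is below 5:

$$w \sim U(-10, 10),\quad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.$$

Equivalently, $w = 0$ with probability $1/2$, and otherwise $w$ is uniform on $[-10, -5] \cup [5, 10]$. In Keras, subclass `tf.keras.initializers.Initializer` and implement `__call__(shape, dtype=None)`: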
```python
class MyInit(tf.keras.initializers.Initializer):
    """U(-10, 10), then zero out every entry with |w| < 5."""

    def __call__(self, shape, dtype=None):
        if dtype is None:  # dtype is None when the initializer is called directly
            dtype = tf.float32
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        # Boolean mask |w| >= 5, cast to the sample's own dtype so the product type-checks.
        factor = tf.cast(tf.abs(data) >= 5, data.dtype)
        return data * factor


net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1),
])
net(X)
print(net.layers[1].weights[0])
```

```
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[ 0. 7.3628063 0. 0. ]
 [-7.1666217 -6.9643497 0. 0. ]
 [ 0. 7.9453754 -6.659901 0. ]
 [-8.549326 8.907385 0. -9.55571 ]]>
```
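The thresholding rule zeroes any draw with $|w| < 5$, which for $U(-10, 10)$ happens with probability $10/20 = 1/2$; the kernel above has 8 zeros among its 16 entries. A framework-free check of the same rule on scalar samples (the function name, sample size, and seed are arbitrary):

```python
import random


def sample_thresholded(n, seed=0):
    """Draw n weights: w ~ U(-10, 10), then zero out entries with |w| < 5."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        w = rng.uniform(-10.0, 10.0)
        out.append(w if abs(w) >= 5.0 else 0.0)
    return out


ws = sample_thresholded(100_000)
zero_frac = sum(w == 0.0 for w in ws) / len(ws)
print(f"fraction zeroed: {zero_frac:.3f}")  # should be close to 0.5
```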
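For one-off surgery (loading specific weights, or replacing a single layer's tensor), PyTorch code assigns to `.data` directly; in Keras, the equivalent is the variable's `assign` method. In the output below, every kernel entry of the first Dense layer was shifted by 1 and the top-left entry was then set to 42: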
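The code for this step is not shown. Here is a self-contained Keras sketch of the same two edits, on a freshly built network rather than the one above and assuming the same hypothetical (2, 4) input. It copies the kernel out as a NumPy array, edits the copy, and writes the whole tensor back with `assign`, which avoids relying on sliced assignment.

```python
import tensorflow as tf

X = tf.random.uniform((2, 4))  # hypothetical input
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])
net(X)  # build the weights

kernel = net.layers[1].weights[0]
before = kernel.numpy().copy()  # snapshot of the current values
new_value = before + 1          # shift every entry by 1 (a new NumPy array)
new_value[0, 0] = 42            # then overwrite a single entry
kernel.assign(new_value)        # write back; the variable object itself is unchanged
print(kernel)
```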
```
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[42. 8.362806 1. 1. ]
 [-6.1666217 -5.9643497 1. 1. ]
 [ 1. 8.945375 -5.659901 1. ]
 [-7.549326 9.907385 1. -8.55571 ]]>
```
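Most of the time, don't override the defaults. Cases where you should:

- Starting from pretrained weights: `load_state_dict` is the ultimate "initialization" override (in Keras, `model.load_weights(...)`).
- Applying a non-standard scheme across the network: call `net.apply(init_fn)` and write per-type rules in the function (in Keras, pass a `kernel_initializer` to each layer).
- Editing a single tensor: `layer.weight.data[...] = ...` for one-off tensor surgery (in Keras, `variable.assign(...)`).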
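For the first case, the closest in-memory Keras analogue of `load_state_dict` is the pair `get_weights()` / `set_weights()`, which copy a list of NumPy arrays between models with the same architecture. A minimal sketch, where `make_net` is a hypothetical helper and `pretrained` is merely a randomly initialized stand-in for a trained model:

```python
import numpy as np
import tensorflow as tf


def make_net():
    """Hypothetical architecture shared by both models."""
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


pretrained = make_net()  # stand-in for a trained model
pretrained.build((None, 4))

fresh = make_net()
fresh.build((None, 4))

# Replace every freshly initialized variable with the "pretrained" values, in order.
fresh.set_weights(pretrained.get_weights())

same = all(np.array_equal(a, b)
           for a, b in zip(fresh.get_weights(), pretrained.get_weights()))
print("weights identical:", same)
```

Because `set_weights` matches arrays by position, both models must create their variables in the same order with the same shapes.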