from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
X = jax.random.uniform(d2l.get_key(), (2, 4))  # toy input: 2 examples, 4 features
Initialization isn’t cosmetic: it determines whether a deep network trains at all.
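If the weights are badly scaled, activations (and gradients) shrink toward zero or blow up exponentially as depth grows. The fix: choose the scale so signal variance stays roughly constant from layer to layer.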
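Consider a linear layer $y = Wx$ with i.i.d. zero-mean inputs $x_j$ of variance $\sigma_x^2$. Assume the weights are i.i.d., zero-mean, of variance $\sigma_w^2$, and independent of the inputs. Each output $y_i$ is a sum of $n_{\text{in}}$ independent terms $w_{ij} x_j$, and each term has variance $\sigma_w^2 \sigma_x^2$, so
$$\text{Var}(y_i) = n_{\text{in}} \, \sigma_w^2 \, \sigma_x^2.$$
Stack $L$ such layers and the signal variance scales by $(n_{\text{in}} \sigma_w^2)^L$. Keep it stable by picking $\sigma_w^2 \approx 1/n_{\text{in}}$.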
Bias usually starts at 0.
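A quick numerical check of the $(n_{\text{in}} \sigma_w^2)^L$ rule pushes unit-variance inputs through 20 random bias-free linear layers and compares the output variance. This is a sketch in plain JAX. The width (256), depth (20), batch size (1000) and the helper name output_variance are arbitrary choices for illustration. With $\sigma_w^2 = 1/n_{\text{in}}$ the variance should stay close to 1. With $\sigma_w^2 = 2/n_{\text{in}}$ it should grow by roughly $2^{20} \approx 10^6$.

```python
import jax
import jax.numpy as jnp

def output_variance(sigma_w2, width=256, depth=20, batch=1000, seed=0):
    """Variance of activations after `depth` bias-free linear layers (hypothetical helper)."""
    key = jax.random.PRNGKey(seed)
    key, sub = jax.random.split(key)
    h = jax.random.normal(sub, (batch, width))  # unit-variance inputs
    for _ in range(depth):
        key, sub = jax.random.split(key)
        # Zero-mean Gaussian weights with variance sigma_w2
        W = jnp.sqrt(sigma_w2) * jax.random.normal(sub, (width, width))
        h = h @ W
    return float(jnp.var(h))

width = 256
v_stable = output_variance(1.0 / width)   # n_in * sigma_w^2 = 1
v_explode = output_variance(2.0 / width)  # n_in * sigma_w^2 = 2
print(v_stable, v_explode)
```

Dropping the scale to $0.5/n_{\text{in}}$ shows the mirror image: the variance collapses by roughly $2^{-20}$.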
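Common schemes refine this rule:
- LeCun: $\sigma_w^2 = 1/n_{\text{in}}$.
- Xavier/Glorot: $\sigma_w^2 = 2/(n_{\text{in}} + n_{\text{out}})$, which also balances gradient variance on the backward pass.
- Kaiming/He: $\sigma_w^2 = 2/n_{\text{in}}$, which compensates for ReLU zeroing roughly half its inputs.

Each framework picks one of these by default: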
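| Framework | Default for Linear/Dense |
|---|---|
| PyTorch | Kaiming-uniform with $a = \sqrt{5}$ on weight, which works out to $U(\pm 1/\sqrt{\text{fan-in}})$; the same uniform $\pm 1/\sqrt{\text{fan-in}}$ on bias |
| Flax (JAX) | LeCun-normal (truncated normal, variance $1/\text{fan-in}$) on kernel; zeros on bias |
| Keras (TF) | Glorot-uniform on kernel; zeros on bias |
| MXNet | Uniform $\pm 0.07$ (legacy; you should override) |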
Bottom line: apart from MXNet's legacy default, every framework picks something fan-in/fan-out aware. You can usually leave it alone and override it only when you need a non-standard scheme.
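You can check the variance targets behind the table's names directly. This sketch uses the initializers in jax.nn.initializers, which flax.linen.initializers re-exports. The kernel is an arbitrary 512×256 matrix. By default, JAX's initializers treat the second-to-last axis as fan-in and the last axis as fan-out.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 512, 256
key = jax.random.PRNGKey(0)
# name -> (initializer, variance the scheme is designed to produce)
schemes = {
    "lecun_normal": (jax.nn.initializers.lecun_normal(), 1 / fan_in),
    "glorot_uniform": (jax.nn.initializers.glorot_uniform(), 2 / (fan_in + fan_out)),
    "he_normal": (jax.nn.initializers.he_normal(), 2 / fan_in),
}
results = {}
for name, (init_fn, target) in schemes.items():
    # shape[-2] is fan-in, shape[-1] is fan-out
    w = init_fn(key, (fan_in, fan_out), jnp.float32)
    results[name] = (float(jnp.var(w)), target)
    print(f"{name:>15}: empirical {results[name][0]:.2e}, target {target:.2e}")
```

With over 100,000 entries per kernel, the empirical variances should land within a percent or so of the targets.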
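PyTorch overrides the default by walking the module tree. net.apply(fn) calls fn(module) recursively on every submodule, and fn reinitializes the layers it recognizes. Flax modules are declarative instead. Each layer takes its initializers as the kernel_init and bias_init arguments, and net.init calls them when it builds the params tree. Here, weights are drawn from a normal distribution with standard deviation 0.01 and biases start at zero: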
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(d2l.get_key(), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([ 0.00390247, -0.00536502, 0.01318502, 0.01437083], dtype=float32),
 Array(0., dtype=float32))
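Constant initialization is an anti-pattern. Every hidden unit in a layer starts out identical and receives identical gradients, so symmetry is never broken. It still illustrates the API: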
weight_init = nn.initializers.constant(1)
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(d2l.get_key(), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
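The symmetry problem is easy to observe. This sketch uses a hand-written two-layer MLP in plain JAX; mlp, loss and toy_params are hypothetical names, not part of the code above. It starts every weight at 1 and takes a few plain SGD steps. Because all hidden units compute the same function, they receive the same gradient, and the columns of W1 never diverge from one another.

```python
import jax
import jax.numpy as jnp

def mlp(p, x):
    h = jax.nn.relu(x @ p['W1'] + p['b1'])
    return h @ p['W2'] + p['b2']

def loss(p, x, y):
    return jnp.mean((mlp(p, x) - y) ** 2)

x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4))
y = jnp.ones((2, 1))
# Constant init, mirroring nn.initializers.constant(1) for the weights
toy_params = {'W1': jnp.ones((4, 8)), 'b1': jnp.zeros(8),
              'W2': jnp.ones((8, 1)), 'b2': jnp.zeros(1)}

for _ in range(5):  # a few plain SGD steps
    grads = jax.grad(loss)(toy_params, x, y)
    toy_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, toy_params, grads)

# Every column of W1 (one per hidden unit) is still identical to the first
print(jnp.allclose(toy_params['W1'], toy_params['W1'][:, :1]))
```

The weights do move, but the 8 hidden units stay copies of one another, so the layer behaves like a single unit.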
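Different layers can use different initializers. In PyTorch you would dispatch on layer type or index inside the init function. In Flax you pass each layer its own initializers, here Xavier (Glorot) uniform for the first Dense layer and the constant 42 for the second: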
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])
params = net.init(d2l.get_key(), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([ 0.29768014, -0.38261548, 0.50331783, 0.34027493], dtype=float32),
Array([[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.]], dtype=float32))
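The pattern: look at each (name, module) pair and decide what to do. In PyTorch that means iterating over net.named_modules() or net.named_parameters(). In Flax the same information lives in the paths of the params tree, such as params['params']['layers_0']['kernel']. The same machinery drives several other techniques:
- freezing layers (requires_grad = False in PyTorch);
- discriminative (per-layer) learning rates;
- BERT-style fine-tuning that warms up a new head before updating the pretrained backbone.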
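In Flax, the natural place to dispatch by name is the params tree itself. This is a sketch under the assumption that params is a plain nested dict laid out like the one net.init returns for the nn.Sequential above. It is built by hand here so the snippet runs without Flax. jax.tree_util.tree_map_with_path visits every leaf together with its path. reinit, a hypothetical helper, picks a rule from the layer and parameter names.

```python
import jax
import jax.numpy as jnp

# Hand-built stand-in for net.init's output (layers_1 is nn.relu, which has no params)
toy_params = {'params': {
    'layers_0': {'kernel': jnp.ones((4, 8)), 'bias': jnp.ones((8,))},
    'layers_2': {'kernel': jnp.ones((8, 1)), 'bias': jnp.ones((1,))},
}}

def reinit(path, leaf):
    """Choose an initializer from the leaf's path, e.g. params/layers_0/kernel."""
    layer, name = path[-2].key, path[-1].key  # DictKey entries expose .key
    if name == 'bias':
        return jnp.zeros_like(leaf)  # rule 1: every bias -> 0
    # A distinct PRNG key per layer, derived from its index
    key = jax.random.fold_in(jax.random.PRNGKey(0), int(layer.split('_')[-1]))
    if layer == 'layers_0':  # rule 2: first Dense -> Xavier uniform
        return jax.nn.initializers.glorot_uniform()(key, leaf.shape, leaf.dtype)
    return 0.01 * jax.random.normal(key, leaf.shape, leaf.dtype)  # rule 3: N(0, 0.01^2)

new_params = jax.tree_util.tree_map_with_path(reinit, toy_params)
```

The result is a new tree with the same structure. You can pass it to net.apply wherever the output of net.init would go.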
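For non-standard schemes, write the initializer yourself as a function with the signature (key, shape, dtype). Here, weights are drawn uniformly from $[-10, 10]$ and any with magnitude below 5 are zeroed:
$$w \sim U(-10, 10), \qquad w \leftarrow w \cdot \mathbb{1}_{|w| \ge 5}.$$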
def my_init(key, shape, dtype=jnp.float32):
    # Sample from U(-10, 10), then zero out every entry with |w| < 5
    data = jax.random.uniform(key, shape, dtype, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)
net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[0. 5.5548573]
[0. 8.096008 ]
[7.2521663 5.118685 ]
[0. 0. ]]
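Built-in initializers such as nn.initializers.normal(0.01) are factories: calling one returns a function with the (key, shape, dtype) signature. The same closure pattern makes the thresholding scheme reusable with configurable bounds. thresholded_uniform is a hypothetical helper, not a Flax API.

```python
import jax
import jax.numpy as jnp

def thresholded_uniform(limit=10.0, threshold=5.0):
    """Hypothetical factory returning an init fn with the (key, shape, dtype) signature."""
    def init(key, shape, dtype=jnp.float32):
        w = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)
        # Keep entries with |w| >= threshold, zero the rest
        return jnp.where(jnp.abs(w) >= threshold, w, jnp.zeros_like(w))
    return init

init_fn = thresholded_uniform()  # same distribution as my_init
w_sample = init_fn(jax.random.PRNGKey(0), (1000, 8))
print(float(jnp.mean(w_sample == 0)))  # fraction zeroed; 0.5 in expectation
```

Because the factory returns an ordinary function, its result plugs in wherever my_init does, for example nn.Dense(8, kernel_init=thresholded_uniform(10.0, 2.0)).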
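For one-off surgery, such as loading specific weights or replacing a single layer's tensor, PyTorch lets you assign to .data directly. JAX arrays are immutable, so in Flax you build an updated copy of the params tree instead, typically with the functional array.at[idx].set(value) update.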
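A sketch of that functional update, assuming params is a plain nested dict. Recent Flax versions return one from net.init; older versions return a FrozenDict, which flax.core.unfreeze converts. The tree here is a hand-built stand-in, and setting row 0 of the kernel to 42 is an arbitrary example edit.

```python
import jax.numpy as jnp

# Stand-in params tree with the same nesting as net.init's output
toy_params = {'params': {'layers_0': {'kernel': jnp.zeros((4, 8)), 'bias': jnp.zeros((8,))}}}

old_kernel = toy_params['params']['layers_0']['kernel']
new_kernel = old_kernel.at[0, :].set(42.0)  # returns a new array; old_kernel is untouched

# Rebuild only the path down to the changed leaf; other subtrees are shared, not copied
layer_0 = {**toy_params['params']['layers_0'], 'kernel': new_kernel}
new_params = {**toy_params, 'params': {**toy_params['params'], 'layers_0': layer_0}}
```

The original tree survives unchanged. This makes before/after comparisons and rollbacks trivial.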
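Most of the time, don’t override initialization at all. Cases where you should:
- Starting from pretrained weights: load_state_dict is the ultimate “initialization” override in PyTorch. In Flax, pass the restored params tree to net.apply(params, X) in place of the output of net.init.
- A custom scheme across many layers: call PyTorch's net.apply(init_fn) and write per-type rules in the function. In Flax, give each layer its own kernel_init and bias_init.
- One-off tensor surgery: layer.weight.data[...] = ... in PyTorch. In Flax, use an .at[...].set(...) update on a copy of the params tree.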