from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import optax
Dropout (Srivastava, Hinton et al., 2014) is one of the simplest and most widely used regularizers for neural networks:
During training, set each hidden unit to zero independently with probability p. Rescale the survivors by 1/(1-p). Turn it off at test time.
Counterintuitive: we actively damage the network mid-training, yet the trick is robust in practice. It still ships in modern Transformers, where a rate of about 10% is standard.
Modern networks are often overparameterized, with more weights than training examples. Without a regularizer, gradient descent can simply memorize the training set.
Two complementary reasons dropout helps:
On every minibatch we randomly zero a fraction of hidden units; the network on this iteration is a thinned subnetwork. Across iterations we sample many subnetworks:
Two of the five hidden units zeroed by a single dropout draw. Each iteration samples a different subset.
At test time dropout is off and we use the full network. In effect, this approximates averaging the predictions of exponentially many thinned subnetworks (a kind of cheap ensemble).
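A small numerical illustration of the ensemble view, under stated assumptions: five hidden units (as in the figure), a hypothetical random linear readout `W`, and p = 0.5. Averaging the outputs of many sampled thinned subnetworks lands on the output of the single full network. Here the match holds in expectation because the readout is linear in the dropped units. With nonlinear layers after the dropout, the full network is only an approximation to this average.

```python
import jax
import jax.numpy as jnp


def thinned_output(key, h, W, p):
    """Output of one thinned subnetwork: drop units of h, rescale survivors, apply W."""
    keep = jax.random.bernoulli(key, 1.0 - p, h.shape)  # True = unit survives
    return W @ jnp.where(keep, h / (1.0 - p), 0.0)


k_h, k_w, k_masks = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(k_h, (5,))     # five hidden units, as in the figure
W = jax.random.normal(k_w, (3, 5))   # hypothetical linear readout to 3 outputs
p = 0.5

# Sample 50,000 thinned subnetworks (one mask each) and average their outputs.
mask_keys = jax.random.split(k_masks, 50_000)
outputs = jax.vmap(lambda k: thinned_output(k, h, W, p))(mask_keys)
ensemble_mean = outputs.mean(axis=0)

full = W @ h  # the single full network used at test time
print(ensemble_mean)
print(full)
```

The two printed vectors agree up to Monte Carlo noise, which shrinks as more masks are sampled.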
Per hidden unit h, replace with
h' = \begin{cases} 0 & \text{with probability } p, \\ \dfrac{h}{1 - p} & \text{otherwise.} \end{cases}
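Taking the expectation over the Bernoulli draw in the case split above gives the unbiasedness property directly:

```latex
\mathbb{E}[h'] = p \cdot 0 + (1 - p) \cdot \frac{h}{1 - p} = h
```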
The rescaling 1/(1-p) is what makes \mathbb{E}[h'] = h. Without it, the expected activation drops to (1-p)h during training but stays at the full value h at test time, which creates a train/test mismatch.
This is “inverted dropout”, the variant that mainstream frameworks implement.
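A quick numerical check of the train/test mismatch described above, assuming one activation value h = 2.0 repeated many times and p = 0.3 (both arbitrary choices for illustration). The same Bernoulli mask is applied with and without the 1/(1-p) rescaling:

```python
import jax
import jax.numpy as jnp

p = 0.3
h = jnp.full((200_000,), 2.0)  # many copies of one activation h = 2.0
keep = jax.random.bernoulli(jax.random.PRNGKey(0), 1.0 - p, h.shape)

plain = jnp.where(keep, h, 0.0)                 # dropout WITHOUT rescaling
inverted = jnp.where(keep, h / (1.0 - p), 0.0)  # inverted dropout

print(plain.mean())     # close to (1 - p) * h = 1.4, below the test-time value
print(inverted.mean())  # close to h = 2.0, matching the test-time value
```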
Sample a Bernoulli mask, multiply, rescale:
def dropout_layer(X, dropout, key=d2l.get_key()):
    # Note: `key` is bound at function-definition time (mutable default
    # pattern), so this educational from-scratch dropout uses one fixed
    # key for all calls. That keeps the function JIT-traceable — calling
    # `d2l.get_key()` at call time would mutate `d2l._master_key` and
    # leak a tracer when invoked inside a JIT'd loss. Production
    # randomness should go through Flax's `nn.Dropout`, which threads
    # PRNG keys via `rngs={"dropout": ...}`.
    assert 0 <= dropout <= 1
    if dropout == 1: return jnp.zeros_like(X)
    mask = jax.random.uniform(key, X.shape) > dropout
    return jnp.asarray(mask, dtype=jnp.float32) * X / (1.0 - dropout)

Quick check on a 2×8 input:
dropout_p = 0: [[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]]
dropout_p = 0.5: [[ 0. 0. 0. 6. 0. 10. 12. 0.]
[ 0. 18. 0. 0. 24. 0. 28. 30.]]
dropout_p = 1: [[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]]
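The from-scratch `dropout_layer` above reuses one fixed key, so every call draws the same mask. A minimal sketch of the alternative, assuming a hypothetical helper `dropout_with_key` that takes its PRNG key explicitly: the caller splits a fresh subkey per call, and `p` is marked static so the Python branches on it stay legal under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="p")
def dropout_with_key(key, X, p):
    """Inverted dropout with an explicit PRNG key (hypothetical helper)."""
    if p == 1.0:  # p is static, so this Python branch is fine under jit
        return jnp.zeros_like(X)
    if p == 0.0:
        return X
    keep = jax.random.bernoulli(key, 1.0 - p, X.shape)  # True = keep the unit
    return jnp.where(keep, X / (1.0 - p), 0.0)


X = jnp.arange(16, dtype=jnp.float32).reshape(2, 8)
key = jax.random.PRNGKey(0)
for _ in range(2):
    key, subkey = jax.random.split(key)  # fresh subkey per call -> fresh mask
    print(dropout_with_key(subkey, X, p=0.5))
```

Each call recompiles only when `p` changes, since static arguments are part of the jit cache key.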
After the activation, before the next linear layer:
Linear → ReLU → Dropout(p₁) → Linear → ReLU → Dropout(p₂) → Linear
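Common heuristic: use a lower dropout rate on layers close to the input and a higher rate on later layers (the intuition being that low-level features need to stay reliable, while later, high-level features are more prone to overfitting).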
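A functional sketch of the placement above (Linear → ReLU → Dropout after each hidden layer) with per-layer rates. The rates 0.2 and 0.5, the layer sizes, and the names `init_params`, `dropout`, and `mlp_forward` are illustrative assumptions, not values or APIs from this deck. Each dropout layer gets its own subkey so the two masks are independent.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, p):
    # Inverted dropout with an explicit key (assumes 0 <= p < 1).
    keep = jax.random.bernoulli(key, 1.0 - p, x.shape)
    return jnp.where(keep, x / (1.0 - p), 0.0)


def init_params(key, sizes=(784, 256, 256, 10)):
    # Hypothetical small-normal initialization for three Dense layers.
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:]), start=1):
        params[f"W{i}"] = jax.random.normal(k, (n_in, n_out)) * 0.01
        params[f"b{i}"] = jnp.zeros(n_out)
    return params


def mlp_forward(params, x, key, p1=0.2, p2=0.5, train=True):
    """Linear -> ReLU -> Dropout(p1) -> Linear -> ReLU -> Dropout(p2) -> Linear."""
    k1, k2 = jax.random.split(key)  # separate keys -> independent masks per layer
    h1 = jax.nn.relu(x @ params["W1"] + params["b1"])
    if train:
        h1 = dropout(k1, h1, p1)  # lower rate closer to the input
    h2 = jax.nn.relu(h1 @ params["W2"] + params["b2"])
    if train:
        h2 = dropout(k2, h2, p2)  # higher rate deeper in the network
    return h2 @ params["W3"] + params["b3"]


params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 784))
y_train = mlp_forward(params, x, jax.random.PRNGKey(2))               # dropout on
y_eval = mlp_forward(params, x, jax.random.PRNGKey(3), train=False)   # dropout off
```

With `train=False` the key is ignored, so evaluation is deterministic.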
Typical values:
class DropoutMLPScratch(d2l.Classifier):
    num_hiddens_1: int
    num_hiddens_2: int
    num_outputs: int
    dropout_1: float
    dropout_2: float
    lr: float
    training: bool = True

    def setup(self):
        self.lin1 = nn.Dense(self.num_hiddens_1)
        self.lin2 = nn.Dense(self.num_hiddens_2)
        self.lin3 = nn.Dense(self.num_outputs)
        self.relu = nn.relu

    def forward(self, X):
        # Flatten images, then Linear -> ReLU -> (training-only) dropout, twice
        H1 = self.relu(self.lin1(X.reshape(X.shape[0], -1)))
        if self.training:
            H1 = dropout_layer(H1, self.dropout_1)
        H2 = self.relu(self.lin2(H1))
        if self.training:
            H2 = dropout_layer(H2, self.dropout_2)
        return self.lin3(H2)

Two hidden layers (256 units each), with dropout 0.5 after each hidden layer:
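Validation accuracy is better than with the plain MLP from the previous deck, and the gap between training and validation loss shrinks visibly. Dropout helps most when model capacity is large relative to the amount of training data.
Flax provides nn.Dropout(rate) as a stock layer. It handles the train vs. eval switch through its deterministic flag: with deterministic=True the layer is a no-op, and with deterministic=False it draws its mask from the PRNG key passed as rngs={"dropout": key}.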
class DropoutMLP(d2l.Classifier):
    num_hiddens_1: int
    num_hiddens_2: int
    num_outputs: int
    dropout_1: float
    dropout_2: float
    lr: float
    training: bool = True

    @nn.compact
    def __call__(self, X):
        x = nn.relu(nn.Dense(self.num_hiddens_1)(X.reshape((X.shape[0], -1))))
        x = nn.Dropout(self.dropout_1, deterministic=not self.training)(x)
        x = nn.relu(nn.Dense(self.num_hiddens_2)(x))
        x = nn.Dropout(self.dropout_2, deterministic=not self.training)(x)
        return nn.Dense(self.num_outputs)(x)

Several complementary explanations, none complete on its own:
Many modern deep nets, especially CNNs, use little or no dropout and rely on normalization layers such as BatchNorm, whose noisy minibatch statistics provide some regularization as a side effect.
But dropout remains alive and well: