Implementing it

Dropout

Dropout regularizes by thinning

Dropout (Srivastava et al., 2014) is one of the simplest and most widely used regularizers for neural networks:

During training, set each hidden unit to zero independently with probability p. Rescale the survivors by 1/(1-p). Turn it off at test time.

Counterintuitive: we deliberately damage the network during training, yet the trick works reliably in practice. It still ships in modern Transformers, where 0.1 is a common default rate.

Why we need it

Modern networks are overparameterized — more weights than training examples. Without a regularizer, gradient descent happily memorizes the training set.

Two complementary reasons dropout helps:

  • Noise injection = smoothness regularization (Bishop 1995). Robustness to hidden-unit dropout forces the network to be a smoother function of its inputs.
  • Anti-co-adaptation: a unit can’t rely on any specific upstream unit being present, so it picks up signal from a broader, redundant set of features.

What dropout looks like

On every minibatch we randomly zero a fraction of hidden units; the network on this iteration is a thinned subnetwork. Across iterations we sample many subnetworks:

Two of the five hidden units zeroed by a single dropout draw. Each iteration samples a different subset.

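At test time dropout is off: we use the full network. This single deterministic pass approximates averaging the predictions of exponentially many thinned subnetworks (a cheap ensemble). The approximation is exact when everything downstream of the dropout is linear, and only approximate once nonlinearities intervene.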
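For a single layer the "average over subnetworks" can be checked exactly by enumeration. The sketch below is illustrative only: the toy activations h, weights W, and rate p = 0.3 are assumptions, not values from the deck. With 5 hidden units there are 2^5 = 32 thinned subnetworks; averaging their outputs (weighted by each mask's probability) reproduces the full network when the next layer is linear, but with a ReLU on top the full network only approximates the ensemble. Because ReLU is convex, Jensen's inequality says the ensemble average can only be larger.

```python
import itertools
import jax
import jax.numpy as jnp

p = 0.3                                   # drop probability (toy value)
n = 5                                     # hidden units, so 2**n = 32 subnetworks
k_h, k_w = jax.random.split(jax.random.PRNGKey(0))
h = jax.random.normal(k_h, (n,))          # toy hidden activations
W = jax.random.normal(k_w, (3, n))        # toy weights of the next linear layer

# Enumerate every keep/drop mask (1 = keep) and its probability under dropout.
masks = jnp.array(list(itertools.product([0.0, 1.0], repeat=n)))   # (32, n)
probs = jnp.prod(jnp.where(masks == 1.0, 1.0 - p, p), axis=1)      # (32,), sums to 1

# Inverted-dropout activations of every thinned subnetwork, fed to the next layer.
sub_out = (masks * h / (1.0 - p)) @ W.T                            # (32, 3)

ensemble_linear = probs @ sub_out                 # exact average over all subnetworks
ensemble_relu = probs @ jax.nn.relu(sub_out)      # same average, with a ReLU on top
full_linear = W @ h                               # one pass of the full network
full_relu = jax.nn.relu(full_linear)

print(jnp.abs(ensemble_linear - full_linear).max())  # float32 round-off only
print(ensemble_relu - full_relu)                     # nonnegative gap (Jensen)
```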

The arithmetic: keep the expectation

Per hidden unit h, replace with

h' = \begin{cases} 0 & \text{with probability } p, \\ \dfrac{h}{1 - p} & \text{otherwise.} \end{cases}

The rescaling 1/(1-p) is what makes \mathbb{E}[h'] = h. Without it, expected activations shrink to (1-p)h during training but stay at full scale h at test time, a train/test mismatch.
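Working out the first two moments of h' makes the role of the rescaling explicit, and shows its price: extra variance that grows with p.

\mathbb{E}[h'] = p \cdot 0 + (1-p) \cdot \dfrac{h}{1-p} = h

\operatorname{Var}[h'] = \mathbb{E}[h'^2] - \mathbb{E}[h']^2 = (1-p) \cdot \dfrac{h^2}{(1-p)^2} - h^2 = \dfrac{p}{1-p}\, h^2

Without the rescaling (h' = h with probability 1-p), the same calculation gives \mathbb{E}[h'] = (1-p)\,h.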

This is “inverted dropout”, the variant used by mainstream frameworks such as PyTorch, Keras, and Flax.
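The original paper put the correction on the other side: training activations were left unscaled, and at test time the outgoing weights were multiplied by the retention probability 1 - p (equivalent to scaling the activations). A Monte Carlo sketch with toy activations (h, p, and the sample count are illustrative assumptions) shows that either placement keeps training and test expectations consistent; the mismatch only appears if you drop the rescaling and also skip the test-time correction.

```python
import jax
import jax.numpy as jnp

p = 0.5                                    # drop probability (toy value)
h = jnp.array([1.0, 2.0, 3.0, 4.0])        # toy activations
keys = jax.random.split(jax.random.PRNGKey(0), 50_000)

def keep_mask(key):
    # True (keep) with probability 1 - p
    return jax.random.bernoulli(key, 1.0 - p, h.shape)

# Inverted dropout: rescale survivors while training, identity at test time.
train_inverted = jax.vmap(lambda k: jnp.where(keep_mask(k), h / (1.0 - p), 0.0))(keys)
test_inverted = h

# Original formulation: no rescaling while training, scale by (1 - p) at test time.
train_original = jax.vmap(lambda k: jnp.where(keep_mask(k), h, 0.0))(keys)
test_original = (1.0 - p) * h

# Each training-time mean should sit close to its own test-time value.
print(train_inverted.mean(axis=0), test_inverted)
print(train_original.mean(axis=0), test_original)
```

Inverted dropout is usually preferred because inference needs no special handling: the test-time network is just the plain network.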

Setup

from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import optax

Sample a Bernoulli mask, multiply, rescale:

def dropout_layer(X, dropout, key=d2l.get_key()):
    # Note: the default `key` is evaluated once, at function-definition
    # time, so this educational from-scratch dropout uses one fixed key
    # for all calls (inputs of the same shape get the same mask). That
    # keeps the function JIT-traceable: calling
    # `d2l.get_key()` at call time would mutate `d2l._master_key` and
    # leak a tracer when invoked inside a JIT'd loss. Production
    # randomness should go through Flax's `nn.Dropout`, which threads
    # PRNG keys via `rngs={"dropout": ...}`.
    assert 0 <= dropout <= 1
    if dropout == 1: return jnp.zeros_like(X)
    mask = jax.random.uniform(key, X.shape) > dropout
    return jnp.asarray(mask, dtype=jnp.float32) * X / (1.0 - dropout)
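Because the default key above is fixed, every call on an input of the same shape draws the same mask, so the scratch model reuses one mask pattern across minibatches. A minimal sketch of the usual JAX fix, assuming the caller owns the PRNG state: pass the key explicitly and split a fresh subkey per call. The name dropout_keyed is hypothetical, not part of d2l or Flax.

```python
import jax
import jax.numpy as jnp

def dropout_keyed(X, dropout, key):
    # Hypothetical variant of dropout_layer that takes the PRNG key as an argument.
    assert 0 <= dropout <= 1
    if dropout == 1:
        return jnp.zeros_like(X)
    keep = jax.random.bernoulli(key, 1.0 - dropout, X.shape)  # True with prob 1 - dropout
    return jnp.where(keep, X / (1.0 - dropout), 0.0)          # inverted-dropout rescaling

X = jnp.ones((1000, 100))
key = jax.random.PRNGKey(42)
for step in range(3):
    key, subkey = jax.random.split(key)      # fresh subkey for every call
    Y = dropout_keyed(X, 0.5, subkey)
    print(step, float((Y == 0).mean()))      # fraction of dropped entries, near 0.5
```

Splitting the key outside the function keeps dropout_keyed a pure function of its inputs, which is what jax.jit expects; Flax's nn.Dropout follows the same idea by reading its key from the rngs passed to apply.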

Quick check on a 2×8 input:

X = jnp.arange(16, dtype=jnp.float32).reshape(2, 8)
print('dropout_p = 0:', dropout_layer(X, 0))
print('dropout_p = 0.5:', dropout_layer(X, 0.5))
print('dropout_p = 1:', dropout_layer(X, 1))
dropout_p = 0: [[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]]
dropout_p = 0.5: [[ 0.  0.  0.  6.  0. 10. 12.  0.]
 [ 0. 18.  0.  0. 24.  0. 28. 30.]]
dropout_p = 1: [[0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]
  • p = 0 → identity (no dropout).
  • p = 0.5 → about half the entries zero, the rest doubled.
  • p = 1.0 → all zeros (degenerate).

Where to put dropout

After the activation, before the next linear layer:

Linear → ReLU → Dropout(p₁) → Linear → ReLU → Dropout(p₂) → Linear

Common heuristic: lower rates on early layers (low-level features need to be reliable), higher rates on later layers (high-level features are more prone to overfitting).

Typical values:

  • MLPs / Transformers: 0.1–0.5.
  • CNNs: 0–0.2 (BatchNorm largely supplants dropout).
  • Just before the classifier head: 0.5 is standard.

MLP with dropout

class DropoutMLPScratch(d2l.Classifier):
    num_hiddens_1: int
    num_hiddens_2: int
    num_outputs: int
    dropout_1: float
    dropout_2: float
    lr: float
    training: bool = True

    def setup(self):
        self.lin1 = nn.Dense(self.num_hiddens_1)
        self.lin2 = nn.Dense(self.num_hiddens_2)
        self.lin3 = nn.Dense(self.num_outputs)
        self.relu = nn.relu

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape(X.shape[0], -1)))
        if self.training:
            H1 = dropout_layer(H1, self.dropout_1)
        H2 = self.relu(self.lin2(H1))
        if self.training:
            H2 = dropout_layer(H2, self.dropout_2)
        return self.lin3(H2)

Training

Two hidden layers of 256 units each, with dropout 0.5 after each hidden layer:

hparams = {'num_outputs':10, 'num_hiddens_1':256, 'num_hiddens_2':256,
           'dropout_1':0.5, 'dropout_2':0.5, 'lr':0.1}
model = DropoutMLPScratch(**hparams)
data = d2l.FashionMNIST(batch_size=256)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)

Validation accuracy is better than with the plain MLP from the previous deck, and the gap between training and validation loss visibly narrows. Dropout helps most when model capacity is large relative to the amount of training data.

Framework version

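Flax's nn.Dropout(rate) is a stock layer. The train vs. eval switch is its deterministic flag rather than a model.eval() call: with deterministic=True the layer is the identity; with deterministic=False it samples a mask from the PRNG key supplied via rngs={'dropout': key}. Below, deterministic=not self.training ties the flag to the model's training field: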

class DropoutMLP(d2l.Classifier):
    num_hiddens_1: int
    num_hiddens_2: int
    num_outputs: int
    dropout_1: float
    dropout_2: float
    lr: float
    training: bool = True

    @nn.compact
    def __call__(self, X):
        x = nn.relu(nn.Dense(self.num_hiddens_1)(X.reshape((X.shape[0], -1))))
        x = nn.Dropout(self.dropout_1, deterministic=not self.training)(x)
        x = nn.relu(nn.Dense(self.num_hiddens_2)(x))
        x = nn.Dropout(self.dropout_2, deterministic=not self.training)(x)
        return nn.Dense(self.num_outputs)(x)
model = DropoutMLP(**hparams)
trainer.fit(model, data)

Why dropout works (the modern view)

Several complementary explanations, none complete on its own:

  • Model averaging: training samples a different thinned network each step; the full network at test time approximates the average over \sim 2^n subnetworks (n = number of droppable units), a cheap ensemble.
  • Noise injection: dropout is multiplicative Bernoulli noise. Bishop (1995) showed that training with small input noise is approximately equivalent to Tikhonov regularization, a penalty on the function's derivatives (i.e. a smoothness penalty).
  • Anti-co-adaptation: forces redundant features.
  • Variance bound: caps the variance the network puts into any one direction in feature space.
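The noise-injection view suggests that the Bernoulli mask is not essential: any multiplicative noise with mean 1 and the same variance should act similarly. A sketch of Gaussian dropout, a variant also studied by Srivastava et al. (2014), with the noise variance set to p/(1-p) so its first two moments match inverted Bernoulli dropout. The function names, the constant input, and p = 0.2 are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def bernoulli_dropout(X, p, key):
    # Inverted Bernoulli dropout with drop probability p.
    keep = jax.random.bernoulli(key, 1.0 - p, X.shape)
    return jnp.where(keep, X / (1.0 - p), 0.0)

def gaussian_dropout(X, p, key):
    # Multiplicative Gaussian noise with mean 1 and variance p / (1 - p),
    # matching the mean and variance of bernoulli_dropout at the same p.
    sigma = jnp.sqrt(p / (1.0 - p))
    return X * (1.0 + sigma * jax.random.normal(key, X.shape))

X = jnp.full((200_000,), 3.0)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
Yb = bernoulli_dropout(X, 0.2, k1)
Yg = gaussian_dropout(X, 0.2, k2)

# Theory for both: mean 3.0, variance 0.2 / 0.8 * 3.0**2 = 2.25
print("bernoulli", float(Yb.mean()), float(Yb.var()))
print("gaussian ", float(Yg.mean()), float(Yg.var()))
```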

Dropout in 2026

Many modern deep nets rely on normalization layers (BatchNorm / LayerNorm) instead of dropout; in CNNs in particular, the noise in BatchNorm's minibatch statistics provides some regularization as a side effect.

But dropout remains alive and well:

  • Transformers — rate 0.1 by default in attention and FFN sublayers.
  • Final classifier heads — 0.5 right before the output projection is still a standard recipe.

Recap

  • Dropout: zero each hidden unit with prob p during training; rescale survivors by 1/(1-p) to preserve expectations.
  • Off at test time — full network in use.
  • Place after activation, before next linear layer; rates 0.1–0.5 typical.
  • Can be viewed as (a) noise injection, i.e. smoothness regularization, and (b) an approximate ensemble of exponentially many thinned subnetworks.
  • One of the cheapest, most reliable regularizers — combines well with weight decay, layer norm, and data augmentation.