Wrapping as a Module

Batch Normalization

BatchNorm stabilizes deep nets

Batch Normalization (Ioffe & Szegedy, 2015) is the single-biggest stability win in modern deep learning.

At each layer, normalize activations within the minibatch to zero mean / unit variance, then rescale with learned \gamma and \beta:

\text{BN}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \hat\mu_\mathcal{B}}{\sqrt{\hat\sigma_\mathcal{B}^2 + \epsilon}} + \beta.

Why it works

  • Lets you train much deeper nets — gradients stay well-conditioned through the depth.
  • Allows higher learning rates; mildly regularizing.
  • At test time, running estimates of the mean / variance collected during training are used instead, because there may be no minibatch to compute statistics from.
  • Spawned a family — LayerNorm (per-example, used in Transformers), GroupNorm, InstanceNorm.

From scratch

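Normalize with a per-channel mean and variance, then scale + shift. In this TensorFlow version the helper only applies the transform; the statistics themselves (taken over the minibatch and, for conv, the spatial dims) are computed by the layer that follows: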

from d2l import tensorflow as d2l
import tensorflow as tf
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps):
    """Apply the batch-norm affine transform using the supplied statistics.

    Evaluates gamma * (X - moving_mean) / sqrt(moving_var + eps) + beta,
    rearranged into a single multiply-add on X: the per-element scale is
    computed once and the mean is folded into the shift term.

    Args:
        X: Input tensor to normalize.
        gamma: Learned scale.
        beta: Learned shift.
        moving_mean: Mean to subtract (any statistics may be passed in; the
            function does not compute them itself).
        moving_var: Variance to divide by, after adding `eps`.
        eps: Constant added to the variance before the reciprocal sqrt.

    Returns:
        The normalized, scaled and shifted tensor.
    """
    # Reciprocal standard deviation, brought to the dtype of the input.
    rstd = tf.cast(tf.math.rsqrt(moving_var + eps), X.dtype)
    # Fold gamma into the reciprocal std so X is multiplied only once.
    scale = rstd * gamma
    # Fold the mean into the additive term: beta - mean * scale.
    shift = beta - moving_mean * scale
    return X * scale + shift

Non-trainable weights for moving_mean / moving_variance (updated only during training); learnable gamma / beta parameters:

class BatchNorm(tf.keras.layers.Layer):
    """Batch normalization over every axis except the last, from scratch.

    In training mode the layer normalizes with the statistics of the current
    minibatch and folds them into `moving_mean` / `moving_variance`; otherwise
    it normalizes with those moving averages. The result is then scaled by
    `gamma` and shifted by `beta`, one value per entry of the last axis.

    Args:
        momentum: Weight of the current batch statistic in the moving-average
            update, `new = (1 - momentum) * old + momentum * batch`. This is
            the complement of the `momentum` argument of
            `tf.keras.layers.BatchNormalization`, which weights the old value.
        eps: Small constant added to the variance before taking the
            reciprocal square root.
        **kwargs: Forwarded unchanged to `tf.keras.layers.Layer`.
    """

    def __init__(self, *, momentum=0.1, eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum
        self.eps = eps

    def build(self, input_shape):
        """Create one gamma/beta/moving-statistic entry per last-axis unit."""
        weight_shape = [input_shape[-1], ]
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = self.add_weight(name='gamma', shape=weight_shape,
            initializer=tf.initializers.ones, trainable=True)
        self.beta = self.add_weight(name='beta', shape=weight_shape,
            initializer=tf.initializers.zeros, trainable=True)
        # The moving statistics are not model parameters (trainable=False):
        # the moving mean starts at 0 and the moving variance starts at 1
        self.moving_mean = self.add_weight(name='moving_mean',
            shape=weight_shape, initializer=tf.initializers.zeros,
            trainable=False)
        self.moving_variance = self.add_weight(name='moving_variance',
            shape=weight_shape, initializer=tf.initializers.ones,
            trainable=False)
        super().build(input_shape)

    def get_config(self):
        """Return the layer config, including `momentum` and `eps`."""
        config = super().get_config()
        config.update({'momentum': self.momentum, 'eps': self.eps})
        return config

    def assign_moving_average(self, variable, value):
        """Blend `value` into `variable` and return the assign result.

        The variable is overwritten with
        `(1 - momentum) * variable + momentum * value`.
        """
        updated = (1.0 - self.momentum) * variable + self.momentum * value
        return variable.assign(updated)

    @tf.function
    def call(self, inputs, training=False):
        """Normalize `inputs`, then scale by gamma and shift by beta.

        Args:
            inputs: Tensor whose last axis matches the shape the layer was
                built with.
            training: When true, use (and record) minibatch statistics;
                otherwise use the stored moving averages.

        Returns:
            The output of `batch_norm` for the selected statistics.
        """
        if training:
            # Reduce over every axis except the last one.
            axes = list(range(len(inputs.shape) - 1))
            batch_mean = tf.reduce_mean(inputs, axes, keepdims=True)
            # stop_gradient: the variance is computed around the mean treated
            # as a constant, so no gradient flows through the mean here.
            batch_variance = tf.reduce_mean(tf.math.squared_difference(
                inputs, tf.stop_gradient(batch_mean)), axes, keepdims=True)
            # Drop the kept singleton axes so the statistics have the same
            # shape as the moving-average weights.
            batch_mean = tf.squeeze(batch_mean, axes)
            batch_variance = tf.squeeze(batch_variance, axes)
            self.assign_moving_average(
                self.moving_mean, batch_mean)
            self.assign_moving_average(
                self.moving_variance, batch_variance)
            mean, variance = batch_mean, batch_variance
        else:
            mean, variance = self.moving_mean, self.moving_variance
        output = batch_norm(inputs, moving_mean=mean, moving_var=variance,
            beta=self.beta, gamma=self.gamma, eps=self.eps)
        return output

LeNet + BatchNorm

Drop a BatchNorm layer between each conv/linear and its activation:

class BNLeNetScratch(d2l.Classifier):
    """LeNet with the from-scratch BatchNorm layer ahead of every sigmoid.

    Args:
        lr: Learning rate, recorded through `save_hyperparameters`.
        num_classes: Number of units in the final Dense layer.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        # NOTE(review): kept ahead of any extra locals — presumably
        # save_hyperparameters records the constructor's arguments from the
        # calling frame; confirm against d2l.
        self.save_hyperparameters()
        layers = tf.keras.layers
        stack = [tf.keras.Input(shape=(28, 28, 1))]
        # Two convolutional stages: conv -> BatchNorm -> sigmoid -> avg-pool.
        for n_filters in (6, 16):
            stack += [
                layers.Conv2D(filters=n_filters, kernel_size=5),
                BatchNorm(),
                layers.Activation('sigmoid'),
                layers.AvgPool2D(pool_size=2, strides=2),
            ]
        stack.append(layers.Flatten())
        # Two hidden dense stages: dense -> BatchNorm -> sigmoid.
        for n_units in (120, 84):
            stack += [
                layers.Dense(n_units),
                BatchNorm(),
                layers.Activation('sigmoid'),
            ]
        # Output layer: one unit per class, no normalization or activation.
        stack.append(layers.Dense(num_classes))
        self.net = tf.keras.models.Sequential(stack)

Train

Trains noticeably faster than vanilla LeNet — same accuracy in fewer epochs:

# Train the from-scratch BatchNorm LeNet: at most 10 epochs on the
# d2l.FashionMNIST data with a batch size of 128.
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128)
# Model creation and fitting both happen inside the d2l.try_gpu() context
# (presumably this selects a GPU when one is available — confirm in d2l).
with d2l.try_gpu():
    # lr=0.5 overrides the constructor default of 0.1.
    model = BNLeNetScratch(lr=0.5)
    trainer.fit(model, data)

After training, gamma and beta are non-trivial — the layer learned the scale/shift it wants:

# Show the learned scale (gamma) and shift (beta) of model.net.layers[1],
# each flattened to 1-D. Per the 6-element output below, this is the
# BatchNorm layer that follows the first (6-filter) Conv2D.
tf.reshape(model.net.layers[1].gamma, (-1,)), tf.reshape(
    model.net.layers[1].beta, (-1,))
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([2.6595492, 3.8950062, 3.6358182, 1.747875 , 3.273372 , 2.0926003],
       dtype=float32)>,
 <tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([-2.3554864, -2.0791354, -3.5127552,  1.1700865,  1.3041193,
         2.197393 ], dtype=float32)>)

The framework version

tf.keras.layers.BatchNormalization works for both conv and dense layers — same idea, much faster, handles the training/inference mode switch automatically:

class BNLeNet(d2l.Classifier):
    """LeNet using the built-in Keras BatchNormalization before each sigmoid.

    Args:
        lr: Learning rate, recorded through `save_hyperparameters`.
        num_classes: Number of units in the final Dense layer.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        # NOTE(review): kept ahead of any extra locals — presumably
        # save_hyperparameters records the constructor's arguments from the
        # calling frame; confirm against d2l.
        self.save_hyperparameters()

        def norm_then_sigmoid():
            # Fresh BatchNormalization + sigmoid pair for one stage.
            return [tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Activation('sigmoid')]

        def avg_pool():
            # Fresh 2x2 average pooling layer with stride 2.
            return tf.keras.layers.AvgPool2D(pool_size=2, strides=2)

        # The list literal is evaluated left to right, so layers are created
        # in the same order in which they are stacked.
        self.net = tf.keras.models.Sequential([
            tf.keras.Input(shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(filters=6, kernel_size=5),
            *norm_then_sigmoid(),
            avg_pool(),
            tf.keras.layers.Conv2D(filters=16, kernel_size=5),
            *norm_then_sigmoid(),
            avg_pool(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(120),
            *norm_then_sigmoid(),
            tf.keras.layers.Dense(84),
            *norm_then_sigmoid(),
            tf.keras.layers.Dense(num_classes),
        ])
# Train the built-in-BatchNormalization LeNet with the same settings as the
# from-scratch run: at most 10 epochs, batch size 128.
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128)
# Model creation and fitting both happen inside the d2l.try_gpu() context
# (presumably this selects a GPU when one is available — confirm in d2l).
with d2l.try_gpu():
    # lr=0.5 overrides the constructor default of 0.1.
    model = BNLeNet(lr=0.5)
    trainer.fit(model, data)

Recap

  • BatchNorm normalizes activations to zero mean / unit variance within each minibatch, then rescales with learned \gamma, \beta.
  • Track running statistics during training; use them at inference (no minibatch at test time).
  • Enables much deeper networks, higher LRs, faster convergence; mildly regularizing.
  • Spawned a family — LayerNorm (per-example, used in Transformers), GroupNorm, InstanceNorm.