from d2l import tensorflow as d2l
import tensorflow as tf
Batch Normalization (Ioffe & Szegedy, 2015) is the single biggest stability win in modern deep learning.
At each layer, normalize the activations within the minibatch to zero mean and unit variance, then rescale with the learned parameters $\gamma$ and $\beta$:
$$\text{BN}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \hat\mu_\mathcal{B}}{\sqrt{\hat\sigma_\mathcal{B}^2 + \epsilon}} + \beta.$$
Compute per-channel mean and variance over the minibatch (and spatial dims, for conv); normalize, then scale + shift:
The layer keeps non-trainable buffers moving_mean / moving_variance (updated only during training) and learnable gamma / beta parameters:
class BatchNorm(tf.keras.layers.Layer):
    """Batch normalization layer written from scratch.

    Statistics are computed over every axis except the last one, so there is
    one mean/variance (and one gamma/beta) per entry of the last input axis.
    With ``training=True`` the minibatch statistics are used and blended into
    the moving averages; otherwise the stored moving averages are used.

    Args:
        momentum: Weight given to the newest minibatch statistic when the
            moving averages are updated; the previous moving value keeps
            weight ``1 - momentum``. The default (0.1) matches the value that
            used to be hard-coded.
        eps: Value forwarded to ``batch_norm`` as ``eps``. The default (1e-5)
            matches the value that used to be hard-coded.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, *, momentum=0.1, eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum
        self.eps = eps

    def build(self, input_shape):
        """Create the per-channel weights, sized by the last input dimension."""
        weight_shape = [input_shape[-1], ]
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = self.add_weight(
            name='gamma', shape=weight_shape,
            initializer=tf.initializers.ones, trainable=True)
        self.beta = self.add_weight(
            name='beta', shape=weight_shape,
            initializer=tf.initializers.zeros, trainable=True)
        # The variables that are not model parameters: the moving mean starts
        # at 0 and the moving variance starts at 1
        self.moving_mean = self.add_weight(
            name='moving_mean', shape=weight_shape,
            initializer=tf.initializers.zeros, trainable=False)
        self.moving_variance = self.add_weight(
            name='moving_variance', shape=weight_shape,
            initializer=tf.initializers.ones, trainable=False)
        super().build(input_shape)

    def assign_moving_average(self, variable, value):
        """Blend ``value`` into ``variable`` in place and return the assign result.

        The new content is
        ``(1 - momentum) * variable + momentum * value``.
        """
        blended = (1.0 - self.momentum) * variable + self.momentum * value
        return variable.assign(blended)

    def get_config(self):
        """Return the layer config including ``momentum`` and ``eps``."""
        config = super().get_config()
        config.update({'momentum': self.momentum, 'eps': self.eps})
        return config

    @tf.function
    def call(self, inputs, training=False):
        """Normalize ``inputs`` and apply the learned scale and shift.

        Args:
            inputs: Tensor whose last axis matches the shape used in ``build``.
            training: When truthy, use minibatch statistics and update the
                moving averages; otherwise use the stored moving averages.

        Returns:
            Whatever ``batch_norm`` returns for the selected statistics.
        """
        if training:
            # Reduce over every axis except the last one
            axes = list(range(len(inputs.shape) - 1))
            batch_mean = tf.reduce_mean(inputs, axes, keepdims=True)
            # The mean is wrapped in stop_gradient, so no gradient flows
            # through it when the variance is differentiated
            batch_variance = tf.reduce_mean(
                tf.math.squared_difference(
                    inputs, tf.stop_gradient(batch_mean)),
                axes, keepdims=True)
            # Drop the kept singleton axes so the statistics have the same
            # shape as the moving-average variables
            batch_mean = tf.squeeze(batch_mean, axes)
            batch_variance = tf.squeeze(batch_variance, axes)
            self.assign_moving_average(self.moving_mean, batch_mean)
            self.assign_moving_average(self.moving_variance, batch_variance)
            mean, variance = batch_mean, batch_variance
        else:
            mean, variance = self.moving_mean, self.moving_variance
        # NOTE(review): `batch_norm` is not defined inside this class; it is
        # presumably a module-level helper defined elsewhere - confirm it is
        # in scope before this layer is called.
        output = batch_norm(inputs, moving_mean=mean, moving_var=variance,
                            beta=self.beta, gamma=self.gamma, eps=self.eps)
        return output


# Drop a BatchNorm layer between each conv/linear and its activation:
class BNLeNetScratch(d2l.Classifier):
    """LeNet-style classifier using the from-scratch ``BatchNorm`` layer.

    A ``BatchNorm`` layer sits between every convolution / hidden dense layer
    and its sigmoid activation.

    Args:
        lr: Not used directly in this constructor; presumably the learning
            rate recorded by ``save_hyperparameters`` - confirm against d2l.
        num_classes: Number of units in the final ``Dense`` layer.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        # NOTE(review): helper locals are introduced only after
        # save_hyperparameters() has run, in case it inspects the caller's
        # local variables - confirm against d2l.
        layers = tf.keras.layers

        def bn_sigmoid():
            # A fresh BatchNorm followed by a sigmoid; new objects per call
            return [BatchNorm(), layers.Activation('sigmoid')]

        def avg_pool():
            return layers.AvgPool2D(pool_size=2, strides=2)

        # List elements are evaluated left to right, so layers are created in
        # the same order in which they appear in the network
        self.net = tf.keras.models.Sequential([
            tf.keras.Input(shape=(28, 28, 1)),
            layers.Conv2D(filters=6, kernel_size=5),
            *bn_sigmoid(),
            avg_pool(),
            layers.Conv2D(filters=16, kernel_size=5),
            *bn_sigmoid(),
            avg_pool(),
            layers.Flatten(),
            layers.Dense(120),
            *bn_sigmoid(),
            layers.Dense(84),
            *bn_sigmoid(),
            layers.Dense(num_classes),
        ])


# Trains noticeably faster than vanilla LeNet — same accuracy in fewer epochs:
After training, gamma and beta are non-trivial — the layer learned the scale/shift it wants:
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([2.6595492, 3.8950062, 3.6358182, 1.747875 , 3.273372 , 2.0926003],
dtype=float32)>,
<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([-2.3554864, -2.0791354, -3.5127552, 1.1700865, 1.3041193,
2.197393 ], dtype=float32)>)
tf.keras.layers.BatchNormalization works for both conv and dense layers — same idea, much faster, and it handles the training/inference mode switch automatically:
class BNLeNet(d2l.Classifier):
    """LeNet-style classifier using ``tf.keras.layers.BatchNormalization``.

    A ``BatchNormalization`` layer sits between every convolution / hidden
    dense layer and its sigmoid activation.

    Args:
        lr: Not used directly in this constructor; presumably the learning
            rate recorded by ``save_hyperparameters`` - confirm against d2l.
        num_classes: Number of units in the final ``Dense`` layer.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        # NOTE(review): helper locals are introduced only after
        # save_hyperparameters() has run, in case it inspects the caller's
        # local variables - confirm against d2l.
        keras_layers = tf.keras.layers
        stack = [tf.keras.Input(shape=(28, 28, 1))]

        # Two convolutional stages: conv -> batch norm -> sigmoid -> avg pool
        for n_filters in (6, 16):
            stack += [
                keras_layers.Conv2D(filters=n_filters, kernel_size=5),
                keras_layers.BatchNormalization(),
                keras_layers.Activation('sigmoid'),
                keras_layers.AvgPool2D(pool_size=2, strides=2),
            ]

        stack.append(keras_layers.Flatten())

        # Two hidden dense stages: dense -> batch norm -> sigmoid
        for n_units in (120, 84):
            stack += [
                keras_layers.Dense(n_units),
                keras_layers.BatchNormalization(),
                keras_layers.Activation('sigmoid'),
            ]

        # Output layer: one unit per class, no activation
        stack.append(keras_layers.Dense(num_classes))
        self.net = tf.keras.models.Sequential(stack)