from d2l import mxnet as d2l
from mxnet import autograd, np, npx, init
from mxnet.gluon import nn
npx.set_np()
Batch Normalization (Ioffe & Szegedy, 2015) is the single-biggest stability win in modern deep learning.
At each layer, normalize activations within the minibatch to zero mean / unit variance, then rescale with learned \gamma and \beta:
\text{BN}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \hat\mu_\mathcal{B}}{\sqrt{\hat\sigma_\mathcal{B}^2 + \epsilon}} + \beta.
Compute per-channel mean and variance over the minibatch (and spatial dims, for conv); normalize, then scale + shift:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize `X` and return it with the (possibly updated) running stats.

    Prediction mode (``autograd.is_training()`` is false): `X` is normalized
    with `moving_mean` / `moving_var`, which are returned unchanged.

    Training mode: the mean and the (biased) variance of the current
    minibatch are used for normalization, and the running statistics are
    blended as ``(1 - momentum) * running + momentum * batch``.

    Args:
        X: input with 2 dimensions (statistics over axis 0) or 4 dimensions
            (statistics over axes 0, 2, 3, kept for broadcasting).
        gamma: scale applied to the normalized input.
        beta: shift added after scaling.
        moving_mean: running mean used in prediction mode.
        moving_var: running variance used in prediction mode.
        eps: constant added to the variance before the square root.
        momentum: weight of the current batch statistics in the running
            average update.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``.

    Raises:
        AssertionError: in training mode, if `X` is neither 2-D nor 4-D.
    """
    # autograd tells us whether we are training or predicting
    if not autograd.is_training():
        # Prediction: normalize with the running statistics and leave them
        # untouched
        X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
        return gamma * X_hat + beta, moving_mean, moving_var

    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        # Fully connected input: one statistic per feature (column)
        batch_mean = X.mean(axis=0)
        batch_var = ((X - batch_mean) ** 2).mean(axis=0)
    else:
        # Convolutional input: one statistic per channel (axis=1). The
        # reduced axes are kept so the result broadcasts against X
        reduce_axes = (0, 2, 3)
        batch_mean = X.mean(axis=reduce_axes, keepdims=True)
        batch_var = ((X - batch_mean) ** 2).mean(axis=reduce_axes,
                                                 keepdims=True)

    # Training: normalize with the statistics of the current minibatch
    X_hat = (X - batch_mean) / np.sqrt(batch_var + eps)
    # Fold the batch statistics into the running averages
    moving_mean = (1.0 - momentum) * moving_mean + momentum * batch_mean
    moving_var = (1.0 - momentum) * moving_var + momentum * batch_var
    # Scale and shift
    return gamma * X_hat + beta, moving_mean, moving_var

# Buffers for moving_mean / moving_var (updated only during training); learnable gamma / beta parameters:
from mxnet import gluon
class BatchNorm(nn.Block):
    """Batch normalization layer built on the `batch_norm` function.

    Holds the learnable scale (`gamma`) and shift (`beta`) as parameters and
    keeps `moving_mean` / `moving_var` as plain arrays that are overwritten
    with the values returned by `batch_norm` on every forward pass.

    Args:
        num_features: number of outputs of a fully connected layer, or
            number of output channels of a convolutional layer.
        num_dims: 2 for a fully connected layer; any other value selects the
            4-D (convolutional) parameter shape.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        # (1, C) for dense inputs, (1, C, 1, 1) for convolutional inputs
        shape = ((1, num_features) if num_dims == 2
                 else (1, num_features, 1, 1))
        # Model parameters: scale starts at 1, shift starts at 0
        self.gamma = gluon.Parameter('gamma', shape=shape, init=init.One())
        self.beta = gluon.Parameter('beta', shape=shape, init=init.Zero())
        # Running statistics (not model parameters): mean 0, variance 1
        self.moving_mean = np.zeros(shape)
        self.moving_var = np.ones(shape)

    def forward(self, X):
        """Normalize `X` and store the running statistics returned by `batch_norm`."""
        # Move the running statistics to the device `X` lives on if they
        # are somewhere else
        target_ctx = X.ctx
        if self.moving_mean.ctx != target_ctx:
            self.moving_mean = self.moving_mean.copyto(target_ctx)
            self.moving_var = self.moving_var.copyto(target_ctx)
        outputs = batch_norm(
            X, self.gamma.data(), self.beta.data(), self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.1)
        # Keep whatever running statistics `batch_norm` handed back
        Y, self.moving_mean, self.moving_var = outputs
        return Y

# Drop a BatchNorm layer between each conv/linear and its activation:
class BNLeNetScratch(d2l.Classifier):
    """LeNet-style classifier that uses the from-scratch `BatchNorm` layer.

    A `BatchNorm` layer sits between every convolutional / dense layer and
    its sigmoid activation, except for the final output layer.

    Args:
        lr: learning rate (default 0.1).
        num_classes: number of outputs of the last dense layer (default 10).
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        # Called before any extra local variable exists, same as the
        # original ordering
        self.save_hyperparameters()
        self.net = nn.Sequential()
        # Layers are listed (and therefore constructed) in network order
        layers = [
            # Convolutional stage 1
            nn.Conv2D(6, kernel_size=5),
            BatchNorm(6, num_dims=4),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2),
            # Convolutional stage 2
            nn.Conv2D(16, kernel_size=5),
            BatchNorm(16, num_dims=4),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2),
            # Dense stage 1
            nn.Dense(120),
            BatchNorm(120, num_dims=2),
            nn.Activation('sigmoid'),
            # Dense stage 2
            nn.Dense(84),
            BatchNorm(84, num_dims=2),
            nn.Activation('sigmoid'),
            # Output layer
            nn.Dense(num_classes),
        ]
        self.net.add(*layers)
        self.initialize()

# Trains noticeably faster than vanilla LeNet — same accuracy in fewer epochs:
The built-in nn.BatchNorm serves both conv layers and dense layers (no separate 1-D / 2-D variants, and no feature count has to be passed) — same idea, much faster, handles the training/prediction mode switch automatically:
class BNLeNet(d2l.Classifier):
    """LeNet-style classifier that uses the framework's `nn.BatchNorm` layer.

    Same layer layout as the from-scratch variant: a batch normalization
    layer between every convolutional / dense layer and its sigmoid
    activation, except for the final output layer. `nn.BatchNorm()` is
    constructed without arguments at every position.

    Args:
        lr: learning rate (default 0.1).
        num_classes: number of outputs of the last dense layer (default 10).
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential()
        # Convolutional stage 1
        self.net.add(
            nn.Conv2D(6, kernel_size=5), nn.BatchNorm(),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2))
        # Convolutional stage 2
        self.net.add(
            nn.Conv2D(16, kernel_size=5), nn.BatchNorm(),
            nn.Activation('sigmoid'),
            nn.AvgPool2D(pool_size=2, strides=2))
        # Dense stages
        self.net.add(nn.Dense(120), nn.BatchNorm(), nn.Activation('sigmoid'))
        self.net.add(nn.Dense(84), nn.BatchNorm(), nn.Activation('sigmoid'))
        # Output layer
        self.net.add(nn.Dense(num_classes))
        self.initialize()