from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

A fully-connected layer on a 1-megapixel RGB image needs roughly 3 million weights per output unit. That is wildly wasteful, since pixel correlations are local and the same edge detector should work everywhere.
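A convolutional layer swaps this for two strong inductive biases: locality (each output depends only on a small neighborhood of the input) and translation invariance (the same kernel weights are reused at every position).
The result is thousands of parameters instead of millions, with a prior that matches the structure of natural images.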
Slide a small kernel $\mathbf{K}$ over the input $\mathbf{X}$. At each position, multiply elementwise and sum:
$$Y[i, j] = \sum_{a, b} X[i+a, j+b]\, K[a, b].$$
Cross-correlation: 3×3 input × 2×2 kernel → 2×2 output. Shaded element: $0\cdot0 + 1\cdot1 + 3\cdot2 + 4\cdot3 = 19$.
For a kernel of size k along a given dimension, the output is smaller than the input by k - 1 along that dimension. Padding, covered in the next section, undoes this shrinkage.
Two nested loops over output positions. Each cell is a slice multiplied elementwise with the kernel and summed:
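A minimal sketch of those two loops, assuming the input `X` and kernel `K` are 2-D `jax.numpy` arrays; the function name `corr2d` is chosen here for illustration. JAX arrays are immutable, so each output cell is written with the functional `.at[...].set(...)` update.

```python
from jax import numpy as jnp

def corr2d(X, K):
    """2-D cross-correlation of X with kernel K (no padding, stride 1)."""
    h, w = K.shape
    # The output shrinks by (kernel size - 1) along each dimension.
    Y = jnp.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Window under the kernel, multiplied elementwise and summed.
            Y = Y.at[i, j].set((X[i:i + h, j:j + w] * K).sum())
    return Y
```

The explicit Python loops mirror the definition but are slow; they are meant for reading rather than for real workloads.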
Verify against the figure — 3×3 input × 2×2 kernel → 2×2 output with the worked-out values:
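One way to reproduce these values independently of the loop implementation is sketched below. It assumes that `jax.scipy.signal.correlate2d` in `'valid'` mode computes the same unpadded cross-correlation; the input and kernel are the ones from the figure.

```python
from jax import numpy as jnp
from jax.scipy.signal import correlate2d

X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])

# 'valid' keeps only positions where the kernel fits entirely inside the input.
Y = correlate2d(X, K, mode='valid')
print(Y)
```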
Array([[19., 25.],
[37., 43.]], dtype=float32)
Wrap the operator as a learnable Module. Two parameters: the kernel weights and a scalar bias:
These are the only learnable parameters of a single-channel conv layer. A 3×3 conv has nine weights regardless of input size — that’s the parameter savings the inductive bias buys us.
Build a 6×8 image with two vertical edges: 1s (white) in the two outer columns on each side, 0s (black) in the middle four columns:
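A short sketch of one way to construct this image with `jax.numpy`; the variable name `X` follows the later code in this section.

```python
from jax import numpy as jnp

X = jnp.ones((6, 8))
# JAX arrays are immutable; .at[...].set(...) returns an updated copy.
X = X.at[:, 2:6].set(0)
print(X)
```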
Array([[1., 1., 0., 0., 0., 0., 1., 1.],
[1., 1., 0., 0., 0., 0., 1., 1.],
[1., 1., 0., 0., 0., 0., 1., 1.],
[1., 1., 0., 0., 0., 0., 1., 1.],
[1., 1., 0., 0., 0., 0., 1., 1.],
[1., 1., 0., 0., 0., 0., 1., 1.]], dtype=float32)
Cross-correlate the image with the 1×2 difference kernel $\mathbf{K} = [1, -1]$. The output is +1 at each white→black transition, -1 at each black→white transition, and zero everywhere else:
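To see where the ±1 values come from: with the kernel [1, -1], each output cell is a pixel minus its right-hand neighbor. The sketch below computes that difference with array slices rather than a general cross-correlation routine; the variable names are illustrative.

```python
from jax import numpy as jnp

X = jnp.ones((6, 8)).at[:, 2:6].set(0)

# Y[i, j] = 1 * X[i, j] + (-1) * X[i, j + 1]
Y = X[:, :-1] - X[:, 1:]
print(Y)
```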
Array([[ 0., 1., 0., 0., 0., -1., 0.],
[ 0., 1., 0., 0., 0., -1., 0.],
[ 0., 1., 0., 0., 0., -1., 0.],
[ 0., 1., 0., 0., 0., -1., 0.],
[ 0., 1., 0., 0., 0., -1., 0.],
[ 0., 1., 0., 0., 0., -1., 0.]], dtype=float32)
Transpose the image so the edge is now horizontal — the same kernel detects nothing:
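A sketch of this check, assuming the same 6×8 image and the kernel [1, -1], with `jax.scipy.signal.correlate2d` in `'valid'` mode as the cross-correlation routine. After transposing, every row of the image is constant, so horizontal differences vanish.

```python
from jax import numpy as jnp
from jax.scipy.signal import correlate2d

X = jnp.ones((6, 8)).at[:, 2:6].set(0)
K = jnp.array([[1.0, -1.0]])

# X.T has shape (8, 6); a 1x2 kernel yields an (8, 5) output.
Y_t = correlate2d(X.T, K, mode='valid')
print(Y_t)
```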
Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32)
Filters are direction-sensitive. Real ConvNets use many filters per layer to cover different orientations and patterns.
We don’t have to design kernels by hand. Start from a random initialization and run gradient descent on the squared error between the layer's output and the ground truth $\mathbf{Y}$:
# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here.
# Use a small-stddev normal init so the toy 10-step SGD has time to converge
# (Flax's default lecun_normal yields a much larger initial loss).
conv2d = nn.Conv(1, kernel_size=(1, 2), use_bias=False, padding='VALID',
                 kernel_init=nn.initializers.normal(stddev=0.01))
# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, height, width, channel), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape((1, 6, 8, 1))
Y = Y.reshape((1, 6, 7, 1))
lr = 3e-2 # Learning rate
params = conv2d.init(d2l.get_key(), X)
def loss(params, X, Y):
    Y_hat = conv2d.apply(params, X)
    return ((Y_hat - Y) ** 2).sum()

for i in range(10):
    l, grads = jax.value_and_grad(loss)(params, X, Y)
    # Update the kernel
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l:.3f}')

epoch 2, loss 5.003
epoch 4, loss 0.842
epoch 6, loss 0.142
epoch 8, loss 0.024
epoch 10, loss 0.004
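To inspect what such a run learns, the same experiment can be sketched without Flax, assuming the 6×8 edge image as input and the output of the kernel [1, -1] as target. The helper `corr_1x2`, the PRNG seed, and the 50-step budget are illustrative choices, not part of the original setup.

```python
import jax
from jax import numpy as jnp

X = jnp.ones((6, 8)).at[:, 2:6].set(0)
Y = X[:, :-1] - X[:, 1:]  # target: output of the kernel [1, -1]

def corr_1x2(k, X):
    """Cross-correlate X with a 1x2 kernel k = [k0, k1] using slices."""
    return k[0] * X[:, :-1] + k[1] * X[:, 1:]

def loss(k, X, Y):
    return ((corr_1x2(k, X) - Y) ** 2).sum()

# Small random init, mirroring the stddev=0.01 choice above.
k = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (2,))
lr = 3e-2
for _ in range(50):
    k = k - lr * jax.grad(loss)(k, X, Y)

print(k)  # should end up close to [1, -1]
```

With only two weights the learned kernel can be read off directly, which makes it easy to confirm that gradient descent recovers the hand-designed edge detector.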
The receptive field of an output cell is the set of input positions that can affect its value.
Local kernels + depth = global reach without the parameter cost of large kernels.
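A small arithmetic sketch of this trade-off, assuming stride-1, undilated, single-channel layers with square kernels; the helper names `receptive_field` and `num_weights` are hypothetical.

```python
def receptive_field(kernel_sizes):
    """Side length of the receptive field of one output cell after stacking
    stride-1, undilated conv layers with the given square kernel sizes."""
    size = 1
    for k in kernel_sizes:
        size += k - 1  # each layer widens the field by (k - 1)
    return size

def num_weights(kernel_sizes):
    """Total kernel weights for single-channel layers, ignoring biases."""
    return sum(k * k for k in kernel_sizes)

print(receptive_field([3, 3, 3]), num_weights([3, 3, 3]))  # 7 27
print(receptive_field([7]), num_weights([7]))              # 7 49
```

Under these assumptions, three stacked 3×3 layers see the same 7×7 input window as a single 7×7 kernel while using 27 weights instead of 49.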
Hubel & Wiesel-style filters in the visual cortex. Trained CNN filters look strikingly similar.