```python
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
```

Pooling is a parameter-free downsampling operation: slide a window over the input and replace each window with a single summary value (max or mean).
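Two reasons it's everywhere:
- It makes the representation less sensitive to small shifts in where a feature appears.
- It downsamples, shrinking the spatial resolution that later layers have to process.

The canonical example is a 2×2 pool with stride 2, which halves the resolution in each spatial dimension.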
Same sliding-window pattern as a convolution, but the operation is max instead of multiply-and-sum:
2×2 max-pool: each output is the max of a 2×2 input window, e.g. $\max(0, 1, 3, 4) = 4$.
Average pooling replaces the max with the mean. Max pooling is the more common choice in modern networks: it is more selective (“did the feature fire somewhere in this region?”) and better preserves sharp activations.
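To make the "sharp activation" point concrete, here is a minimal sketch (not from the original) that compares the two reductions on a hypothetical 4×4 feature map with a single strong activation, using non-overlapping 2×2 windows. The reshape trick assumes height and width are divisible by 2.

```python
import jax.numpy as jnp

# Hypothetical 4x4 feature map: one strong activation at (1, 1), zeros elsewhere.
x = jnp.zeros((4, 4)).at[1, 1].set(8.0)

# Split into non-overlapping 2x2 windows; axes 1 and 3 index positions inside a window.
windows = x.reshape(2, 2, 2, 2)

max_pooled = windows.max(axis=(1, 3))   # [[8, 0], [0, 0]]: the peak survives intact
avg_pooled = windows.mean(axis=(1, 3))  # [[2, 0], [0, 0]]: the peak is diluted by 4
```

Max pooling reports the activation at full strength; average pooling spreads it over the window and reports a quarter of it.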
The implementation takes only a few lines: there is no kernel, just a reduction over each window, in one of two modes, max or avg.
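```python
def pool2d(X, pool_size, mode='max'):
    """Stride-1, unpadded 2-D pooling of an (H, W) array."""
    p_h, p_w = pool_size
    # Like a 'valid' convolution, the output shrinks by (window - 1) per dimension.
    Y = jnp.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # JAX arrays are immutable: .at[...].set() returns an updated copy.
            if mode == 'max':
                Y = Y.at[i, j].set(X[i: i + p_h, j: j + p_w].max())
            elif mode == 'avg':
                Y = Y.at[i, j].set(X[i: i + p_h, j: j + p_w].mean())
    return Y
```

Applied with a 2×2 window to the 3×3 input from the diagram (entries 0 through 8, row by row), max pooling gives 4, 5, 7, 8, matching the diagram: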
```
Array([[4., 5.],
       [7., 8.]], dtype=float32)
```
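The loop version rebuilds `Y` with `.at[i, j].set(...)` once per output element, which is fine for teaching but slow in JAX. A loop-free sketch of the same stride-1, unpadded pooling can use `jax.lax.reduce_window`; the name `pool2d_fast` is hypothetical, and the avg mode assumes a plain mean over the full window.

```python
import jax
import jax.numpy as jnp

def pool2d_fast(X, pool_size, mode='max'):
    """Hypothetical loop-free variant: stride-1, 'VALID' pooling of an (H, W) array."""
    if mode == 'max':
        # Start each window's running max at -inf, combine elements with lax.max.
        return jax.lax.reduce_window(X, -jnp.inf, jax.lax.max, pool_size, (1, 1), 'VALID')
    if mode == 'avg':
        # Windowed sum, then divide by the number of elements in the window.
        s = jax.lax.reduce_window(X, 0.0, jax.lax.add, pool_size, (1, 1), 'VALID')
        return s / (pool_size[0] * pool_size[1])
    raise ValueError(f"unknown mode: {mode}")

X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
print(pool2d_fast(X, (2, 2)))         # max: [[4, 5], [7, 8]]
print(pool2d_fast(X, (2, 2), 'avg'))  # avg: [[2, 3], [5, 6]]
```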
A 2×2 max-pool window on [0, 1, 3, 4] returns 4. Shift the input by one pixel and the window now sees [1, 0, 4, 0]; the max is still 4.
A small shift moves which element fires, not whether some element in the window fires. As long as the feature stays inside the window, the output is unchanged.
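A quick check of this argument, sketched on hypothetical 4×4 inputs with non-overlapping 2×2 max pooling: moving a single active pixel by one position inside its window leaves the pooled output unchanged, while moving it into the neighboring window does not.

```python
import jax.numpy as jnp

def max_pool_2x2(x):
    # Non-overlapping 2x2 max pooling (stride 2); assumes even height and width.
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

base = jnp.zeros((4, 4)).at[0, 0].set(1.0)    # feature at (0, 0)
inside = jnp.zeros((4, 4)).at[0, 1].set(1.0)  # shifted one pixel, same 2x2 window
across = jnp.zeros((4, 4)).at[0, 2].set(1.0)  # shifted two pixels, next window

print(max_pool_2x2(base))    # [[1, 0], [0, 0]]
print(max_pool_2x2(inside))  # identical to base
print(max_pool_2x2(across))  # [[0, 1], [0, 0]]: the feature left its window
```

The invariance is only local: it holds for shifts that stay within the pooling window.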
Modern alternative: a strided convolution does the same downsampling but learns its own “pool” function.
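One way to see the relationship: average pooling is itself a strided convolution whose weights are frozen at 1/(window area). The sketch below (an illustration, not the source's code) checks this with `jax.lax.conv_general_dilated` on a hypothetical 4×4 input. A learnable layer such as Flax's `nn.Conv(..., strides=(2, 2))` performs the same kind of computation but lets training choose the weights.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)  # NHWC

# 2x2 average pooling with stride 2: windowed sum divided by 4.
avg = jax.lax.reduce_window(x, 0.0, jax.lax.add, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID') / 4.0

# Stride-2 convolution with a fixed 2x2 kernel of 1/4 (HWIO: 1 input, 1 output channel).
kernel = jnp.full((2, 2, 1, 1), 0.25, dtype=jnp.float32)
conv = jax.lax.conv_general_dilated(
    x, kernel, window_strides=(2, 2), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

print(avg[0, :, :, 0])   # [[2.5, 4.5], [10.5, 12.5]]
print(conv[0, :, :, 0])  # same values
```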
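Pooling has the same knobs as convolution (window size, stride, padding) but different defaults. Since the goal is to reduce resolution rather than preserve it, PyTorch's MaxPool2d, for example, sets the stride equal to the window size by default (non-overlapping pools). Flax's nn.max_pool is an exception: its stride defaults to 1, so pass strides explicitly. Take a 4×4 single-channel input X holding 0 to 15, in NHWC layout with shape (1, 4, 4, 1):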
```
Array([[[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.]],
        [[ 4.],
...
         [11.]],
        [[12.],
         [13.],
         [14.],
         [15.]]]], dtype=float32)
```
Override the defaults (stride, padding) when you want overlapping pools, where the window is larger than the stride:
```
Array([[[[ 5.],
         [ 7.]],
        [[13.],
         [15.]]]], dtype=float32)
```
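Convolutions combine channels: every input channel feeds every output channel. Pooling does not; it pools each channel independently, so the output has as many channels as the input. Stacking X and X + 1 along the last (channel) axis gives a two-channel input: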
```
Array([[[[ 0., 1.],
         [ 1., 2.],
         [ 2., 3.],
         [ 3., 4.]],
        [[ 4., 5.],
...
         [11., 12.]],
        [[12., 13.],
         [13., 14.],
         [14., 15.],
         [15., 16.]]]], dtype=float32)
```
In many modern CNNs, strided convolutions (stride=2) handle most of the downsampling: one initial max-pool, then strided convs.