Pooling

Pooling summarizes local evidence

Pooling is a parameter-free downsampling operation: slide a window, replace it with a single summary value (max or mean).

Two reasons it’s everywhere:

  • Spatial aggregation: summarize over locations to answer “is there a cat anywhere in the image?”
  • Translation invariance: a 1-pixel shift usually doesn’t change the max of a small window, which makes the output robust to small spatial perturbations.

The canonical example is a 2×2 pool with stride 2, which halves the height and width.
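When the height and width are multiples of the window, a non-overlapping 2×2, stride-2 max-pool needs no loop. You reshape so that each 2×2 block gets its own pair of axes, then reduce over those axes. The following is a minimal sketch in JAX. It assumes a single-channel 2D input with even height and width, and `max_pool_2x2` is a hypothetical helper, not a library function:

```python
import jax.numpy as jnp

def max_pool_2x2(x):
    """Non-overlapping 2x2, stride-2 max-pool of a 2D array with even H and W."""
    h, w = x.shape
    # Axis 1 indexes the row inside a block, axis 3 the column inside a block.
    blocks = x.reshape(h // 2, 2, w // 2, 2)
    return blocks.max(axis=(1, 3))

x = jnp.arange(16.0).reshape(4, 4)
y = max_pool_2x2(x)
print(y.shape)  # (2, 2): height and width are both halved
print(y)
```

Swapping `.max` for `.mean` gives the matching 2×2 average pool.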

Max-pooling at a glance

Same sliding-window pattern as a convolution, but the operation is max instead of multiply-and-sum:

2×2 max-pool: each output is the max of one 2×2 input window. For the top-left window, $\max(0, 1, 3, 4) = 4$.

Average pooling replaces max with mean. Max pooling is the usual choice for intermediate downsampling. It is more selective (“did the feature fire somewhere in this region?”) and better preserves sharp activations.
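The selectivity difference shows up even in a single window. The sketch below uses made-up values: one position in a 2×2 window fired strongly and the rest did not.

```python
import jax.numpy as jnp

# One 2x2 window in which a single position fired strongly (illustrative values).
window = jnp.array([[0.0, 0.0],
                    [0.0, 8.0]])

max_val = window.max()   # keeps the peak: 8.0
avg_val = window.mean()  # spreads it over four positions: 2.0
print(max_val, avg_val)
```

Averaging reports 2.0, which is the same value it would report for four weak activations of 2.0 each. The max keeps those two situations apart.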

Implementation

A few lines: no kernel, just a reduction over each window (stride 1, no padding). Two modes: max and avg.

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

def pool2d(X, pool_size, mode='max'):
    """Pool a 2D array X with stride 1 and no padding; mode is 'max' or 'avg'."""
    p_h, p_w = pool_size
    # Each axis shrinks by (window size - 1).
    Y = jnp.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            window = X[i: i + p_h, j: j + p_w]
            # JAX arrays are immutable: .at[i, j].set(...) returns an updated copy.
            if mode == 'max':
                Y = Y.at[i, j].set(window.max())
            elif mode == 'avg':
                Y = Y.at[i, j].set(window.mean())
    return Y

Verify against the figure

Max pooling gives 4, 5, 7, 8, matching the diagram. Average pooling gives the window means 2, 3, 5, 6:

X = d2l.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
pool2d(X, (2, 2))
Array([[4., 5.],
       [7., 8.]], dtype=float32)
pool2d(X, (2, 2), 'avg')
Array([[2., 3.],
       [5., 6.]], dtype=float32)
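The loop-based `pool2d` rebuilds `Y` once per output element. That is fine for a 3×3 example but slow on real feature maps. A vectorized sketch of the same stride-1, no-padding pooling uses `jax.lax.reduce_window`, the primitive that Flax's pooling functions are built on. Average pooling then becomes a windowed sum divided by the window area. The name `pool2d_fast` is our own:

```python
import jax.numpy as jnp
from jax import lax

def pool2d_fast(X, pool_size, mode='max'):
    """Stride-1, 'VALID' pooling of a 2D array, matching the loop-based pool2d."""
    if mode == 'max':
        # Start each window at -inf and fold in elements with lax.max.
        return lax.reduce_window(X, -jnp.inf, lax.max, pool_size, (1, 1), 'VALID')
    if mode == 'avg':
        # Windowed sum, then divide by the number of elements per window.
        total = lax.reduce_window(X, 0.0, lax.add, pool_size, (1, 1), 'VALID')
        return total / (pool_size[0] * pool_size[1])
    raise ValueError(f"unknown mode: {mode}")

X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
print(pool2d_fast(X, (2, 2)))         # same values as pool2d(X, (2, 2))
print(pool2d_fast(X, (2, 2), 'avg'))  # same values as pool2d(X, (2, 2), 'avg')
```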

Why max gives translation invariance

A 2×2 max-pool window on [0, 1, 3, 4] returns 4. Shift that patch left by one pixel and fill the vacated column with zeros. The window now sees [1, 0, 4, 0], and the max is still 4.

A small shift moves which element fires, not whether some element in the window fires. As long as the strongest activation stays inside the window (and nothing larger moves in), the output is unchanged.
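A 1D toy signal makes both the invariance and its limit visible. With window 2 and stride 2, moving a single activation within its window leaves the pooled output alone. Moving it across a window boundary changes the output. The signal length, spike positions, and helper names below are illustrative choices:

```python
import jax.numpy as jnp

def max_pool_1d(x, size=2):
    """Non-overlapping 1D max-pool; assumes len(x) is a multiple of size."""
    return x.reshape(-1, size).max(axis=1)

def spike(pos, n=8):
    """Length-n signal that is zero except for a 1.0 at index pos."""
    return jnp.zeros(n).at[pos].set(1.0)

base = max_pool_1d(spike(0))    # spike in window 0
within = max_pool_1d(spike(1))  # shifted by 1, still inside window 0
across = max_pool_1d(spike(2))  # shifted by 2, now in window 1

print(base, within, across)
```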

Modern alternative: a strided convolution does the same downsampling but learns its own “pool” function.
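Because a convolution is linear, a strided convolution can reproduce average pooling (with a uniform kernel) but never max pooling. The sketch below uses `jax.lax.conv_general_dilated` with a 2×2 kernel and stride 2. The weights are fixed to 0.25 rather than learned, so the output equals 2×2 average pooling:

```python
import jax.numpy as jnp
from jax import lax

# One image, 4x4, one channel, in NHWC layout.
x = jnp.arange(16.0).reshape(1, 4, 4, 1)

# 2x2 kernel in HWIO layout (height, width, in-channels, out-channels),
# fixed to 1/4 so each output is the mean of its window.
kernel = jnp.full((2, 2, 1, 1), 0.25)

y = lax.conv_general_dilated(
    x, kernel,
    window_strides=(2, 2),  # stride 2 halves height and width
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print(y.shape)  # (1, 2, 2, 1)
print(y[0, :, :, 0])
```

In a trained network these weights are parameters, which is what lets the downsampling adapt to the task.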

Padding and stride for pooling

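Same knobs as conv, but the defaults differ by framework. PyTorch’s nn.MaxPool2d sets the stride equal to the window size by default (non-overlapping pools), because the goal is to reduce resolution, not preserve it. Flax’s nn.max_pool defaults to stride 1, so pass strides equal to window_shape explicitly to get non-overlapping pools: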

X = d2l.reshape(d2l.arange(16, dtype=d2l.float32), (1, 4, 4, 1))
X
Array([[[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.]],

        [[ 4.],
...
         [11.]],

        [[12.],
         [13.],
         [14.],
         [15.]]]], dtype=float32)
# Pooling has no model parameters, hence it needs no initialization
nn.max_pool(X, window_shape=(3, 3), strides=(3, 3))
Array([[[[10.]]]], dtype=float32)
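Pooling output size follows the same rule as convolution. Along an axis of length n, a window k, stride s, and padding p_lo + p_hi give ⌊(n + p_lo + p_hi − k) / s⌋ + 1 outputs. The sketch below checks that formula against `jax.lax.reduce_window` on a 4×4 input. `pool_out_size` is a hypothetical helper:

```python
import jax.numpy as jnp
from jax import lax

def pool_out_size(n, k, s, pad_lo=0, pad_hi=0):
    """Outputs along one axis for window k, stride s, and explicit padding."""
    return (n + pad_lo + pad_hi - k) // s + 1

x = jnp.arange(16.0).reshape(4, 4)

# (window, stride, (pad_lo, pad_hi)) configurations, applied to both axes.
configs = [
    (3, 3, (0, 0)),  # non-overlapping 3x3 pool, as in nn.max_pool above
    (3, 1, (0, 0)),  # same window with stride 1 (overlapping)
    (3, 2, (1, 0)),  # pad top/left only, stride 2
    (3, 2, (1, 1)),  # symmetric padding gives the same output size
]
for k, s, pad in configs:
    y = lax.reduce_window(x, -jnp.inf, lax.max, (k, k), (s, s), (pad, pad))
    expected = pool_out_size(4, k, s, *pad)
    assert y.shape == (expected, expected)
    print(k, s, pad, y.shape)
```

The last two configurations show why padding only the top and left is enough for a stride-2, 3×3 pool on this input. The extra bottom row and right column of symmetric padding are never reached by a window.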

Overlapping and asymmetric pools

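To get overlapping pools, make the stride smaller than the window. Here we use a 3×3 window with stride 2, after padding one row of zeros on top and one column on the left: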

X_padded = jnp.pad(X, ((0, 0), (1, 0), (1, 0), (0, 0)), mode='constant')
nn.max_pool(X_padded, window_shape=(3, 3), padding='VALID', strides=(2, 2))
Array([[[[ 5.],
         [ 7.]],

        [[13.],
         [15.]]]], dtype=float32)
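Zero padding is harmless here only because every entry of X is non-negative. With negative inputs, a padded zero can exceed every real value in its window and leak into the output. `jax.lax.reduce_window` pads with its init value, so combining explicit (low, high) padding pairs with a −inf init keeps the padding from ever winning. PyTorch’s nn.MaxPool2d likewise documents its padding as implicit negative infinity. The sketch below compares both approaches on an all-negative 4×4 input chosen for illustration:

```python
import jax.numpy as jnp
from jax import lax

# All-negative 4x4 input: -1, -2, ..., -16 in row-major order.
x = -(jnp.arange(16.0) + 1).reshape(4, 4)

# Option 1: pad top/left with zeros, then pool without further padding.
x_zero = jnp.pad(x, ((1, 0), (1, 0)), mode='constant')
zero_pad = lax.reduce_window(x_zero, -jnp.inf, lax.max, (3, 3), (2, 2), 'VALID')

# Option 2: let reduce_window pad top/left; padded cells take the init value -inf.
inf_pad = lax.reduce_window(x, -jnp.inf, lax.max, (3, 3), (2, 2), ((1, 0), (1, 0)))

print(zero_pad)  # windows touching the padding report 0.0
print(inf_pad)   # every output is a real input value
```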

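Window, padding, and stride can also differ per axis. Here we use a 2×3 window with stride (2, 3), after padding one column of zeros on each side of the width axis: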

X_padded = jnp.pad(X, ((0, 0), (0, 0), (1, 1), (0, 0)), mode='constant')
nn.max_pool(X_padded, window_shape=(2, 3), strides=(2, 3), padding='VALID')
Array([[[[ 5.],
         [ 7.]],

        [[13.],
         [15.]]]], dtype=float32)

Multi-channel pooling

Convs combine channels (input channels feed every output channel). Pooling does not:

  • Each input channel is pooled independently.
  • Output channel count = input channel count.
  • Pooling has no notion of channel mixing.

# Concatenate along axis 3, the channel axis in JAX's channels-last (NHWC) layout
X = d2l.concat([X, X + 1], 3)
X
Array([[[[ 0.,  1.],
         [ 1.,  2.],
         [ 2.,  3.],
         [ 3.,  4.]],

        [[ 4.,  5.],
...
         [11., 12.]],

        [[12., 13.],
         [13., 14.],
         [14., 15.],
         [15., 16.]]]], dtype=float32)
X_padded = jnp.pad(X, ((0, 0), (1, 0), (1, 0), (0, 0)), mode='constant')
nn.max_pool(X_padded, window_shape=(3, 3), padding='VALID', strides=(2, 2))
Array([[[[ 5.,  6.],
         [ 7.,  8.]],

        [[13., 14.],
         [15., 16.]]]], dtype=float32)
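Channel independence can be checked directly: pooling both channels at once gives the same result as pooling each channel alone and concatenating. The sketch below applies `jax.lax.reduce_window` to NHWC input. A window and stride of 1 on the batch and channel axes keep those axes from ever being combined. `max_pool_nhwc` is a hypothetical helper:

```python
import jax.numpy as jnp
from jax import lax

def max_pool_nhwc(x, window=(2, 2), stride=(2, 2)):
    """Max-pool over H and W only; batch and channel axes use window 1, stride 1."""
    dims = (1, *window, 1)
    strides = (1, *stride, 1)
    return lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, 'VALID')

c0 = jnp.arange(16.0).reshape(1, 4, 4, 1)
x = jnp.concatenate([c0, c0 + 1], axis=3)  # two channels, channels-last

pooled_together = max_pool_nhwc(x)
pooled_separately = jnp.concatenate(
    [max_pool_nhwc(x[..., c:c + 1]) for c in range(x.shape[3])], axis=3)

print(pooled_together.shape)  # (1, 2, 2, 2): channel count unchanged
print(jnp.array_equal(pooled_together, pooled_separately))  # True
```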

Where pooling sits in modern architectures

  • Classic CNNs (LeNet, AlexNet, VGG): pool every few conv layers to halve spatial dims; final stack is fully connected.
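  • ResNet and later designs: pool less often and let stride-2 convolutions handle most downsampling. ResNet has one max-pool right after its first convolution, then downsamples with strided convs.
  • Global average pooling: at the very end, average each channel’s entire feature map to a single value. This replaces the fully connected stack with a single linear classifier and drastically cuts parameters. It is standard in ResNet and most later CNNs. Some Vision Transformer variants likewise average the patch tokens instead of using a class token.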
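Global average pooling is just a mean over the spatial axes, leaving one value per channel. The sketch below shows the resulting parameter savings. It assumes illustrative sizes (a 7×7×512 final feature map and 1000 classes) and counts weights for one linear layer, excluding biases:

```python
import jax.numpy as jnp

n, h, w, c, num_classes = 2, 7, 7, 512, 1000  # illustrative sizes
features = jnp.ones((n, h, w, c))             # NHWC feature maps

# Global average pooling: average over the spatial axes H and W.
gap = features.mean(axis=(1, 2))
print(gap.shape)  # (2, 512): one value per channel per image

# Weight counts (biases excluded) for a single linear layer to the classes.
flatten_weights = h * w * c * num_classes  # flatten all 7*7*512 features
gap_weights = c * num_classes              # only 512 pooled features
print(flatten_weights, gap_weights)        # 25088000 512000
```

The ratio is exactly h × w = 49, because pooling removes the spatial positions from the classifier’s input.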

Recap

  • Pooling = window-slide reduction (max or mean), no learnable parameters.
  • 2×2 max-pool with stride 2 is the classic spatial downsampler.
  • Provides limited translation invariance: the output is unchanged when a feature shifts but stays within its pooling window.
  • Per-channel — no channel mixing.
  • Modern nets mix pooling with strided convs and end with global average pool.