Pooling

Pooling summarizes local evidence

Pooling is a parameter-free downsampling operation: slide a window, replace it with a single summary value (max or mean).

Two reasons it’s everywhere:

  • Spatial aggregation — summarize over locations to answer “is there a cat anywhere in the image?”.
  • Translation invariance — a 1-pixel shift doesn’t usually change the max of a small window. Robust to small spatial perturbations.

2×2 pool with stride 2 — halves resolution, the canonical example.

Max-pooling at a glance

Same sliding-window pattern as a convolution, but the operation is max instead of multiply-and-sum:

2×2 max-pool: each output = max of a 2×2 input window. For the top-left window, max(0, 1, 3, 4) = 4.

Average pooling replaces max with mean. Max is the usual choice for intermediate layers: it’s more selective (“did the feature fire somewhere in this region?”) and better preserves sharp activations.
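
To make the max-versus-mean contrast concrete, here is a minimal sketch using PyTorch's functional pooling calls. The 4×4 feature map and its single activation value are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical feature map: one strong activation, zeros elsewhere.
X = torch.zeros(1, 1, 4, 4)
X[0, 0, 1, 2] = 8.0

max_out = F.max_pool2d(X, kernel_size=2)  # stride defaults to kernel_size
avg_out = F.avg_pool2d(X, kernel_size=2)

print(max_out[0, 0])  # the 8.0 survives intact in the top-right cell
print(avg_out[0, 0])  # the same cell holds 8.0 / 4 = 2.0
```

Max keeps the peak at full strength; mean spreads it over the window and shrinks it by the window area.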

Implementation

A few lines: no kernel, just a reduction over each window. Two modes, max and avg. Stride is fixed at 1 and there is no padding.

import torch
from torch import nn
from d2l import torch as d2l

def pool2d(X, pool_size, mode='max'):
    """Pool a 2D tensor X with stride 1 and no padding."""
    if mode not in ('max', 'avg'):
        raise ValueError(f"unknown mode: {mode!r}")
    p_h, p_w = pool_size
    # One output element per valid window position
    Y = d2l.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            window = X[i: i + p_h, j: j + p_w]
            Y[i, j] = window.max() if mode == 'max' else window.mean()
    return Y

Verify against the figure

Max gives 4, 5, 7, 8 — matches the diagram:

X = d2l.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
pool2d(X, (2, 2))
tensor([[4., 5.],
        [7., 8.]])
pool2d(X, (2, 2), 'avg')
tensor([[2., 3.],
        [5., 6.]])
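
As an extra sanity check, the same numbers can be reproduced with the framework's built-in pooling. This sketch uses plain torch rather than the d2l wrappers, and sets stride=1 explicitly because the hand-written loop moves one pixel at a time.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
X4d = X[None, None]  # framework pooling expects (batch, channel, height, width)

max_ref = F.max_pool2d(X4d, kernel_size=2, stride=1)[0, 0]
avg_ref = F.avg_pool2d(X4d, kernel_size=2, stride=1)[0, 0]
print(max_ref)  # tensor([[4., 5.], [7., 8.]])
print(avg_ref)  # tensor([[2., 3.], [5., 6.]])
```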

Why max gives translation invariance

A 2×2 max-pool window on [0, 1, 3, 4] returns 4. Shift the input one pixel to the left, with zeros filling in at the right edge; the window now sees [1, 0, 4, 0], and the max is still 4.

A small shift moves which element fires, not whether some element in the window fires. As long as the feature stays inside the window, the output is unchanged.
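
The "stays inside the window" condition can be checked directly. The sketch below is illustrative only: it places a single activation on an otherwise empty 4×4 map (a made-up input) and compares the 2×2, stride-2 max-pool output for three positions. The helper name pooled is hypothetical.

```python
import torch
import torch.nn.functional as F

def pooled(row, col):
    """Max-pool a 4x4 map that has a single activation at (row, col)."""
    X = torch.zeros(1, 1, 4, 4)
    X[0, 0, row, col] = 1.0
    return F.max_pool2d(X, kernel_size=2)[0, 0]

base = pooled(0, 0)
inside = pooled(0, 1)   # shifted one pixel, same 2x2 window
outside = pooled(0, 2)  # shifted two pixels, lands in the next window

print(torch.equal(base, inside))   # True
print(torch.equal(base, outside))  # False
```

The invariance is local: once the activation crosses a window boundary, the pooled output moves too.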

Modern alternative: a strided convolution does the same downsampling but learns its own “pool” function.
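
One way to see the connection is that 2×2 average pooling is exactly a 2×2, stride-2 convolution whose weights are all 1/4. The sketch below sets those weights by hand on a single-channel input to show the equivalence; in a real network the weights would be learned. Max pooling is not linear, so a convolution cannot reproduce it exactly.

```python
import torch
from torch import nn
import torch.nn.functional as F

X = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# A 2x2 convolution with stride 2 downsamples exactly like a 2x2 pool.
conv = nn.Conv2d(1, 1, kernel_size=2, stride=2, bias=False)
with torch.no_grad():
    conv.weight.fill_(0.25)  # hand-set weights; training would learn them

conv_out = conv(X)
avg_out = F.avg_pool2d(X, kernel_size=2)
print(conv_out.shape)                     # torch.Size([1, 1, 2, 2])
print(torch.allclose(conv_out, avg_out))  # True
```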

Padding and stride for pooling

Same knobs as conv, but different defaults: PyTorch's nn.MaxPool2d sets the stride equal to the window size (non-overlapping windows) and uses no padding, because the goal is to reduce resolution, not preserve it.

X = d2l.reshape(d2l.arange(16, dtype=d2l.float32), (1, 1, 4, 4))
X
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
pool2d = nn.MaxPool2d(3)
# Pooling has no model parameters, hence it needs no initialization
pool2d(X)
tensor([[[[10.]]]])

Overlapping and asymmetric pools

Override the defaults when you want overlapping pools:

pool2d = nn.MaxPool2d(3, padding=1, stride=2)
pool2d(X)
tensor([[[[ 5.,  7.],
          [13., 15.]]]])

Or asymmetric pools per axis:

pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1))
pool2d(X)
tensor([[[[ 5.,  7.],
          [13., 15.]]]])
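
The output shapes in the three examples above follow the usual sliding-window formula, floor((n + 2·padding − k) / stride) + 1 per axis. The helper below is a hypothetical convenience function, not part of PyTorch; it assumes dilation 1 and floor rounding, which are the MaxPool2d defaults.

```python
def pool_out_size(n, k, stride=None, padding=0):
    """Output length along one axis: floor((n + 2*padding - k) / stride) + 1."""
    if stride is None:
        stride = k  # mirrors the MaxPool2d default
    return (n + 2 * padding - k) // stride + 1

# nn.MaxPool2d(3) on a 4x4 input
print(pool_out_size(4, 3))                       # 1
# nn.MaxPool2d(3, padding=1, stride=2)
print(pool_out_size(4, 3, stride=2, padding=1))  # 2
# nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1)): height, then width
print(pool_out_size(4, 2, stride=2), pool_out_size(4, 3, stride=3, padding=1))  # 2 2
```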

Multi-channel pooling

Convs combine channels (input channels feed every output channel). Pooling does not:

  • Each input channel is pooled independently.
  • Output channel count = input channel count.
  • Pooling has no notion of channel mixing.

X = d2l.concat((X, X + 1), 1)
X
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
pool2d(X)
tensor([[[[ 5.,  7.],
          [13., 15.]],

         [[ 6.,  8.],
          [14., 16.]]]])
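
Channel independence can be verified by pooling each channel on its own and comparing against the joint result. A small sketch in plain torch, rebuilding the same two-channel input:

```python
import torch
from torch import nn

X = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
X2 = torch.cat((X, X + 1), dim=1)  # shape (1, 2, 4, 4)

pool = nn.MaxPool2d(3, padding=1, stride=2)
together = pool(X2)
# Pool each channel separately, then stack the results along the channel axis.
separate = torch.cat([pool(X2[:, c:c + 1]) for c in range(X2.shape[1])], dim=1)

print(together.shape)                   # torch.Size([1, 2, 2, 2])
print(torch.equal(together, separate))  # True
```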

Where pooling sits in modern architectures

  • Classic CNNs (LeNet, AlexNet, VGG): pool every few conv layers to halve spatial dims; final stack is fully connected.
  • ResNet / modern: pool less often — strided convs (stride=2) handle most downsampling. One initial max-pool, then strided convs.
  • Global average pooling: at the very end, average the entire feature map per channel. Replaces the fully-connected stack with a tiny linear classifier; drastically cuts parameters. Standard in ResNet and most later CNNs; some ViT variants also use it in place of the class token.
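
A rough sketch of the parameter savings from global average pooling. The feature-map size (512 channels at 7×7) and the 10-class output are assumed for illustration, and the comparison head is a single linear layer on the flattened map, which understates the cost of a multi-layer fully-connected stack.

```python
import torch
from torch import nn

# Hypothetical final feature map: batch of 2, 512 channels, 7x7 spatial.
feats = torch.randn(2, 512, 7, 7)
num_classes = 10  # assumed for illustration

gap_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # (2, 512, 7, 7) -> (2, 512, 1, 1)
    nn.Flatten(),             # -> (2, 512)
    nn.Linear(512, num_classes),
)
fc_head = nn.Sequential(
    nn.Flatten(),             # -> (2, 512 * 7 * 7)
    nn.Linear(512 * 7 * 7, num_classes),
)

def count(module):
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

print(gap_head(feats).shape)            # torch.Size([2, 10])
print(count(gap_head), count(fc_head))  # 5130 250890
```

The pooling step itself adds no parameters; all of the savings come from the linear layer seeing 512 inputs instead of 512 × 7 × 7.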

Recap

  • Pooling = window-slide reduction (max or mean), no learnable parameters.
  • 2×2 max-pool with stride 2 is the classic spatial downsampler.
  • Provides invariance to small translations: the output is unchanged as long as the feature stays inside the same window.
  • Per-channel — no channel mixing.
  • Modern nets mix pooling with strided convs and end with global average pool.