from d2l import torch as d2l
import torch
from torch import nn

Pooling is a parameter-free downsampling operation: slide a window over the input and replace each window with a single summary value (max or mean).
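Two reasons it’s everywhere: it shrinks spatial resolution, so later layers are cheaper and each unit sees a larger region of the input; and it makes the output less sensitive to small shifts in where a feature appears.
The canonical example is a 2×2 pool with stride 2, which halves the resolution along each spatial dimension.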
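A quick shape check makes the halving concrete. This is a small sketch using `torch.nn.functional.max_pool2d` and `avg_pool2d`; the input size (1, 3, 32, 32) is an arbitrary assumption. Batch and channel dimensions pass through untouched; only height and width shrink.

```python
import torch
import torch.nn.functional as F

# Assumed input: batch of 1, 3 channels, 32x32 spatial size (arbitrary choice)
x = torch.randn(1, 3, 32, 32)

# 2x2 window with stride 2: height and width are each halved
y_max = F.max_pool2d(x, kernel_size=2, stride=2)
y_avg = F.avg_pool2d(x, kernel_size=2, stride=2)

print(x.shape)      # torch.Size([1, 3, 32, 32])
print(y_max.shape)  # torch.Size([1, 3, 16, 16])
print(y_avg.shape)  # torch.Size([1, 3, 16, 16])

# Four input values collapse into one output value per channel
print(x.numel() // y_max.numel())  # 4
```

No weights are involved: the same input always produces the same pooled output.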
Same sliding-window pattern as a convolution, but the operation is max instead of multiply-and-sum:
2×2 max-pool: each output is the max of a 2×2 input window, e.g. max(0, 1, 3, 4) = 4.
Average pooling replaces max with mean. Max is the usual choice for downsampling inside modern nets: it is more selective (“did the feature fire somewhere in this region?”) and better preserves sharp activations, which averaging would dilute.
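To see the dilution, pool a single window that contains one strong activation. The values below are made up for illustration: max keeps the peak intact, while the mean spreads it over all four cells.

```python
import torch
import torch.nn.functional as F

# One 2x2 window (shape (1, 1, 2, 2)) with a single strong activation
window = torch.tensor([[[[0.0, 0.0],
                         [0.0, 8.0]]]])

peak_max = F.max_pool2d(window, kernel_size=2).item()  # 8.0: peak survives
peak_avg = F.avg_pool2d(window, kernel_size=2).item()  # 2.0: peak divided by 4
print(peak_max, peak_avg)
```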
A few lines — no kernel, just a reduction over each window. Two modes: max and avg.
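def pool2d(X, pool_size, mode='max'):
    """Naive 2D pooling over a 2D tensor: stride 1, no padding."""
    p_h, p_w = pool_size
    # One output per window position: each dim shrinks by (window size - 1)
    Y = d2l.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            if mode == 'max':
                Y[i, j] = X[i: i + p_h, j: j + p_w].max()
            elif mode == 'avg':
                Y[i, j] = X[i: i + p_h, j: j + p_w].mean()
    return Y

Max gives 4, 5, 7, 8, matching the diagram:
X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
pool2d(X, (2, 2))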
tensor([[4., 5.],
[7., 8.]])
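The double loop is easy to read but slow. One possible vectorized alternative (a sketch, not the original code; `pool2d_unfold` is a hypothetical name) uses `Tensor.unfold` to view every p_h × p_w window at once and then reduces over the two window dimensions. On the same 3×3 input it reproduces the max result and also gives the avg result.

```python
import torch

def pool2d_unfold(X, pool_size, mode='max'):
    """Stride-1, unpadded 2D pooling without Python loops (sketch)."""
    p_h, p_w = pool_size
    # unfold(dim, size, step) replaces dim by the number of windows and
    # appends a window dimension:
    # (H, W) -> (H - p_h + 1, W, p_h) -> (H - p_h + 1, W - p_w + 1, p_h, p_w)
    windows = X.unfold(0, p_h, 1).unfold(1, p_w, 1)
    if mode == 'max':
        return windows.amax(dim=(-2, -1))
    if mode == 'avg':
        return windows.mean(dim=(-2, -1))
    raise ValueError(f"unknown mode: {mode}")

X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
Y_max = pool2d_unfold(X, (2, 2))         # values [[4, 5], [7, 8]]
Y_avg = pool2d_unfold(X, (2, 2), 'avg')  # values [[2, 3], [5, 6]]
print(Y_max)
print(Y_avg)
```

`unfold` returns a view, so no window data is copied before the reduction.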
A 2×2 max-pool window over [0, 1, 3, 4] returns 4. Suppose these are the only nonzero activations and the input shifts one pixel to the left: the same window now sees [1, 0, 4, 0], and the max is still 4.
A small shift moves which element fires, not whether some element in the window fires. As long as the feature stays inside the window, the output is unchanged.
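The same point in code: a sketch that places a single strong activation on an otherwise empty 4×4 map (`feature_at` is a hypothetical helper) and pools with non-overlapping 2×2 windows. Moving the activation within its window leaves the output unchanged; moving it across a window boundary does not.

```python
import torch
import torch.nn.functional as F

def feature_at(row, col, size=4):
    # Hypothetical helper: a (1, 1, size, size) map, zero except one activation
    x = torch.zeros(1, 1, size, size)
    x[0, 0, row, col] = 9.0
    return x

def pool(x):
    # 2x2 max-pool; stride defaults to kernel_size, so windows don't overlap
    return F.max_pool2d(x, kernel_size=2)

original = pool(feature_at(0, 0))
shifted_inside = pool(feature_at(1, 1))   # moved, but still in the top-left window
shifted_outside = pool(feature_at(0, 2))  # crossed into the top-right window

print(torch.equal(original, shifted_inside))   # True
print(torch.equal(original, shifted_outside))  # False
```

So the invariance is local: it holds only for shifts smaller than the window.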
Modern alternative: a strided convolution does the same downsampling but learns its own “pool” function.
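A side-by-side sketch of the two options, with an arbitrary channel count of 8: a 2×2 max-pool and a 2×2 convolution with stride 2 produce the same output shape, but the convolution carries learnable weights (and, unlike pooling, mixes channels).

```python
import torch
from torch import nn

channels = 8  # arbitrary choice for the sketch
x = torch.randn(1, channels, 16, 16)

pool = nn.MaxPool2d(kernel_size=2)                             # fixed max over each window
conv = nn.Conv2d(channels, channels, kernel_size=2, stride=2)  # learned 2x2 "pool"

print(pool(x).shape, conv(x).shape)  # both torch.Size([1, 8, 8, 8])

# Parameter cost of learning the pool function
n_pool = sum(p.numel() for p in pool.parameters())  # 0
n_conv = sum(p.numel() for p in conv.parameters())  # 8*8*2*2 weights + 8 biases = 264
print(n_pool, n_conv)
```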
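Same knobs as conv (padding and stride), but different defaults: PyTorch’s nn.MaxPool2d sets the stride equal to the window size unless told otherwise (non-overlapping pools), because the goal is to reduce resolution, not preserve it. Start from a 4×4 single-channel input:
X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
X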
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]])
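With only the window size given, the defaults take over. In this sketch a 3×3 max-pool gets stride 3 and no padding, so exactly one window (rows 0 to 2, columns 0 to 2) fits inside the 4×4 input.

```python
import torch
from torch import nn

X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))

# stride defaults to kernel_size (3), padding to 0
pool = nn.MaxPool2d(3)
Y = pool(X)
print(Y)  # a single value, 10., the max of X[..., 0:3, 0:3]
```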
Override the defaults when you want overlapping pools, here a 3×3 window with padding 1 and stride 2:
pool = nn.MaxPool2d(3, padding=1, stride=2)
pool(X)
tensor([[[[ 5., 7.],
[13., 15.]]]])
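Both shapes follow the same size rule as convolution, floor((n + 2p - k) / s) + 1 per spatial dimension (with PyTorch’s default ceil_mode=False and dilation 1). The check below is a sketch; `pooled_size` is a hypothetical helper that compares the formula with what nn.MaxPool2d actually returns for a few settings on the 4×4 input.

```python
import torch
from torch import nn

def pooled_size(n, k, p=0, s=None):
    """Hypothetical helper: output length of one spatial dim (ceil_mode=False)."""
    s = k if s is None else s  # mirror nn.MaxPool2d: stride defaults to kernel size
    return (n + 2 * p - k) // s + 1

X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))

results = []
for k, p, s in [(3, 0, None), (3, 1, 2), (2, 0, 2)]:
    actual = nn.MaxPool2d(k, stride=s, padding=p)(X).shape[-1]
    results.append((pooled_size(4, k, p, s), actual))
    print(f"k={k} p={p} s={s}: formula {results[-1][0]}, actual {actual}")
```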
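Convs combine channels (input channels feed every output channel). Pooling does not: it pools each input channel separately, so the output has as many channels as the input. Stack X and X + 1 into a two-channel input:
X = torch.cat((X, X + 1), 1)
X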
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]],
[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]])
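Pooling that two-channel input with the overlapping 3×3 settings from above keeps two channels. This self-contained sketch rebuilds the input; each channel is pooled on its own, so the second channel’s result is just the first one plus 1.

```python
import torch
from torch import nn

X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
X2 = torch.cat((X, X + 1), 1)  # shape (1, 2, 4, 4)

pool = nn.MaxPool2d(3, padding=1, stride=2)
Y = pool(X2)
print(Y.shape)  # torch.Size([1, 2, 2, 2]): still 2 channels
# channel 0 values [[5, 7], [13, 15]]; channel 1 values [[6, 8], [14, 16]]
```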
In many modern architectures, strided convolutions (stride=2) handle most downsampling: one initial max-pool, then strided convs.
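A hypothetical downsampling path in that spirit, loosely modeled on the ResNet stem (a 7×7 stride-2 conv followed by a 3×3 stride-2 max-pool). The channel widths and the two later stride-2 convs are illustrative assumptions, not a specific published network; the loop prints how each stage shrinks a 224×224 input.

```python
import torch
from torch import nn

# Illustrative layout: one max-pool early, strided convs afterwards
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), nn.ReLU(),   # 224 -> 112
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),                  # 112 -> 56
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),  # 56 -> 28
    nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.ReLU(), # 28 -> 14
)

x = torch.randn(1, 3, 224, 224)
for layer in net:
    x = layer(x)
    print(type(layer).__name__, tuple(x.shape))
```

Only the max-pool is parameter-free; the strided convs learn how to downsample.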