import torch
from torch import nn
from d2l import torch as d2l

A standard stack of convolution and pooling layers reduces spatial resolution. Dense prediction tasks (semantic segmentation, generative models, super-resolution) need to go the other way and upsample features back to the resolution of the input image.
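The standard tool is the transposed convolution, sometimes called a “deconvolution” (a misnomer, since it is not the mathematical inverse of a convolution). Each input element multiplies the full kernel and writes the scaled copy into the output at its own position; where the copies from neighboring elements overlap, their contributions are summed: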
A $2 \times 2$ transposed convolution: each input element scatters a scaled copy of the kernel into the output.
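The scatter-and-sum rule can be sketched directly for a single-channel 2-D input with stride 1 and no padding. The function name trans_conv and the values of X and K are assumptions chosen for this illustration, not read off the figure.

```python
import torch

def trans_conv(X, K):
    """Transposed convolution (stride 1, no padding) by explicit scatter-and-sum."""
    h, w = K.shape
    # Each spatial dimension grows by (kernel size - 1).
    Y = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            # Scale the whole kernel by one input element and add it into
            # the output window whose top-left corner is (i, j).
            Y[i:i + h, j:j + w] += X[i, j] * K
    return Y

X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
Y = trans_conv(X, K)
print(Y)  # values: [[0, 0, 1], [0, 4, 6], [4, 12, 9]]
```

The explicit Python loops are only for clarity; they are far slower than the framework operator.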
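The output grows. With stride 1, a $k \times k$ kernel, and no padding, each spatial dimension becomes $n_{\text{out}} = n_{\text{in}} + k - 1$. More generally, with stride $s$ and padding $p$ (and no dilation or extra output padding), PyTorch's nn.ConvTranspose2d produces $n_{\text{out}} = (n_{\text{in}} - 1)s + k - 2p$.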
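One way to check the shape formula is to compare it against PyTorch's nn.ConvTranspose2d for a few kernel, stride, and padding settings. The helper trans_conv_out_size and the tested settings are illustrative assumptions; the formula assumes dilation 1 and output_padding 0.

```python
import torch
from torch import nn

def trans_conv_out_size(n_in, k, s=1, p=0):
    """Output side length of a transposed convolution (dilation 1, output_padding 0)."""
    return (n_in - 1) * s + k - 2 * p

X = torch.zeros(1, 1, 5, 5)  # (batch, channels, height, width)
results = []
for k, s, p in [(2, 1, 0), (3, 1, 1), (4, 2, 1), (3, 3, 0)]:
    tconv = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
    n_out = tconv(X).shape[-1]          # measured from the framework operator
    n_formula = trans_conv_out_size(5, k, s, p)
    results.append((n_out, n_formula))
    print(f"k={k} s={s} p={p}: framework {n_out}, formula {n_formula}")
```

For a 5-wide input these settings give widths 6, 5, 10, and 15: stride multiplies the spacing between kernel copies, while padding trims the result.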
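A hand-written implementation should agree with the framework operator. If the shapes or values differ, the usual culprits are padding semantics or tensor layout: PyTorch's nn.ConvTranspose2d expects 4-D (batch, channel, height, width) inputs and stores its weight as (in_channels, out_channels, kH, kW), so single-channel 2-D tensors must be reshaped before comparing: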
tensor([[ 0., 0., 1.],
[ 0., 4., 6.],
[ 4., 12., 9.]])
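A sketch of that comparison, assuming the same illustrative X and K as a running example. The scatter loop (named trans_conv here, an illustrative name) is repeated so the snippet runs on its own; the kernel is copied into a bias-free nn.ConvTranspose2d after reshaping to the 4-D layout.

```python
import torch
from torch import nn

def trans_conv(X, K):
    # Stride-1, unpadded transposed convolution by scatter-and-sum.
    h, w = K.shape
    Y = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            Y[i:i + h, j:j + w] += X[i, j] * K
    return Y

X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])

# bias=False so only the kernel contributes; weight layout is (in, out, kH, kW).
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
with torch.no_grad():
    tconv.weight.copy_(K.reshape(1, 1, 2, 2))
    Y_framework = tconv(X.reshape(1, 1, 2, 2))  # input as (batch, channel, H, W)

same = torch.allclose(trans_conv(X, K), Y_framework.reshape(3, 3))
print(same)  # True
```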
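Padding works the opposite way from a regular convolution: in a transposed convolution it removes rows and columns from the output instead of adding them around the input. With padding 1, the $3 \times 3$ result above loses its one-element border, leaving only the center value:
tensor([[[[4.]]]], grad_fn=<ConvolutionBackward0>)
With stride $s > 1$, the scaled kernel copies are placed $s$ positions apart in the output instead of 1 apart, which is how a transposed convolution upsamples. Equivalently, one can insert $s - 1$ zeros between neighboring input elements, pad by $k - 1$, and run an ordinary convolution with the flipped kernel:
Stride-2 transposed convolution: each input element's scaled kernel is placed two positions away from its neighbors' copies, and any overlaps are summed.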
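Both the padding and the stride rules can be checked with a small extension of the scatter loop. trans_conv_strided is a hypothetical helper written for this illustration: it places kernel copies stride positions apart, then crops padding rows and columns from every border. X and K are assumed example values.

```python
import torch

def trans_conv_strided(X, K, stride=1, padding=0):
    """Hypothetical helper: scatter-and-sum transposed convolution with
    stride (spacing of kernel copies) and padding (cropping of the output)."""
    h, w = K.shape
    n, m = X.shape
    # Uncropped output size: (n - 1) * stride + kernel size.
    Y = torch.zeros(((n - 1) * stride + h, (m - 1) * stride + w))
    for i in range(n):
        for j in range(m):
            r, c = i * stride, j * stride  # copies land `stride` apart
            Y[r:r + h, c:c + w] += X[i, j] * K
    # Padding p removes p rows/columns from each border of the output.
    if padding > 0:
        Y = Y[padding:-padding, padding:-padding]
    return Y

X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
Y_pad = trans_conv_strided(X, K, padding=1)
Y_s2 = trans_conv_strided(X, K, stride=2)
print(Y_pad)  # values: [[4]]
print(Y_s2)   # values: [[0, 0, 0, 1], [0, 0, 2, 3], [0, 2, 0, 3], [4, 6, 6, 9]]
```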
tensor([[[[0., 0., 0., 1.],
[0., 0., 2., 3.],
[0., 2., 0., 3.],
[4., 6., 6., 9.]]]], grad_fn=<ConvolutionBackward0>)
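Stride can also be understood through an equivalent direct computation: insert $s - 1$ zeros between neighboring input elements, zero-pad by $k - 1$ on every side, and run an ordinary convolution with the kernel flipped in both spatial dimensions. The sketch below checks this against torch.nn.functional.conv_transpose2d; it assumes padding 0, dilation 1, output_padding 0, and groups 1, and trans_conv_via_zero_insertion is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def trans_conv_via_zero_insertion(x, weight, stride):
    """Stride-s transposed convolution rewritten as an ordinary convolution.

    x: (N, C_in, H, W); weight: (C_in, C_out, kH, kW), the ConvTranspose2d layout.
    Assumes padding=0, dilation=1, output_padding=0, groups=1.
    """
    N, C, H, W = x.shape
    kH, kW = weight.shape[-2:]
    # 1. Insert (stride - 1) zeros between neighboring input elements.
    x_dil = x.new_zeros((N, C, (H - 1) * stride + 1, (W - 1) * stride + 1))
    x_dil[:, :, ::stride, ::stride] = x
    # 2. Zero-pad by (k - 1) on every side so border elements get full coverage.
    x_pad = F.pad(x_dil, (kW - 1, kW - 1, kH - 1, kH - 1))
    # 3. conv2d is cross-correlation, so flip the kernel spatially and swap
    #    the channel axes to the conv2d layout (C_out, C_in, kH, kW).
    w_conv = torch.flip(weight.transpose(0, 1), dims=(-2, -1))
    return F.conv2d(x_pad, w_conv)

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)
weight = torch.randn(3, 6, 3, 2)  # 3 input channels, 6 output channels
expected = F.conv_transpose2d(x, weight, stride=2)
result = trans_conv_via_zero_insertion(x, weight, stride=2)
print(result.shape, torch.allclose(result, expected, atol=1e-5))
```

The zero-insertion form is wasteful (most multiplications hit zeros), which is one reason frameworks implement the operator differently, but it makes the upsampling role of stride explicit.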
Multiple channels work as in a regular convolution: each output channel sums the contributions of all input channels, each passed through its own kernel, and the output channels are computed in parallel:
True
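The channel bookkeeping can be seen by pairing a convolution with a transposed convolution that uses the same kernel size, stride, and padding but swapped channel counts. The sizes (10 and 20 channels, a $16 \times 16$ input, kernel 5, padding 2, stride 3) are illustrative assumptions. Note that PyTorch stores the transposed layer's weight as (in_channels, out_channels, kH, kW), the reverse of nn.Conv2d's (out_channels, in_channels, kH, kW).

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.rand(size=(1, 10, 16, 16))  # (batch, channels, height, width)
conv = nn.Conv2d(10, 20, kernel_size=5, padding=2, stride=3)
tconv = nn.ConvTranspose2d(20, 10, kernel_size=5, padding=2, stride=3)

print(conv.weight.shape)   # torch.Size([20, 10, 5, 5]), read as (out, in, kH, kW)
print(tconv.weight.shape)  # torch.Size([20, 10, 5, 5]), read as (in, out, kH, kW)
Y = conv(X)
print(Y.shape)             # torch.Size([1, 20, 6, 6])
print(tconv(Y).shape == X.shape)  # True: the shape is restored, the values are not
```

Shape restoration is not guaranteed for every input size: with stride greater than 1, several input sizes can map to the same output size, and nn.ConvTranspose2d's output_padding argument selects which one to produce.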
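A standard convolution can be written as multiplication by a sparse matrix, $\mathbf{y} = \mathbf{K}\mathbf{x}$, where $\mathbf{x}$ and $\mathbf{y}$ are the flattened input and output and $\mathbf{K}$ encodes the kernel weights, stride, and padding.
A transposed convolution multiplies by the transpose, $\mathbf{x}' = \mathbf{K}^\top \mathbf{y}$, which maps a tensor with the output's shape back to a tensor with the input's shape. That is where the name comes from.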
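A minimal sketch that materializes this matrix (called W in code) for a $3 \times 3$ input and a $2 \times 2$ kernel, then checks both directions against PyTorch's functional operators. kernel2matrix is a hypothetical helper that handles only square kernels, stride 1, and no padding; the input values 0 through 8 and the kernel are assumed examples.

```python
import torch
import torch.nn.functional as F

def kernel2matrix(K, n=3):
    """Hypothetical helper: dense matrix W with W @ x.flatten() equal to a
    stride-1, unpadded 2-D convolution of an n x n input with square kernel K."""
    k = K.shape[0]
    m = n - k + 1                    # output side length
    W = torch.zeros(m * m, n * n)
    for oi in range(m):
        for oj in range(m):
            row = oi * m + oj        # the output element this row computes
            for a in range(k):
                for b in range(k):
                    W[row, (oi + a) * n + (oj + b)] = K[a, b]
    return W

X = torch.arange(9.0).reshape(3, 3)
K = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
W = kernel2matrix(K)

# Convolution as y = W x (F.conv2d computes cross-correlation).
Y = F.conv2d(X.reshape(1, 1, 3, 3), K.reshape(1, 1, 2, 2)).reshape(2, 2)
print(Y)  # values: [[27, 37], [57, 67]]
conv_ok = torch.allclose(Y, (W @ X.reshape(-1)).reshape(2, 2))

# Transposed convolution as x' = W^T y.
Z = F.conv_transpose2d(Y.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)).reshape(3, 3)
tconv_ok = torch.allclose(Z, (W.T @ Y.reshape(-1)).reshape(3, 3))
print(conv_ok, tconv_ok)  # True True
```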
tensor([[27., 37.],
[57., 67.]])
tensor([[1., 2., 0., 3., 4., 0., 0., 0., 0.],
[0., 1., 2., 0., 3., 4., 0., 0., 0.],
[0., 0., 0., 1., 2., 0., 3., 4., 0.],
[0., 0., 0., 0., 1., 2., 0., 3., 4.]])
tensor([[True, True],
[True, True]])
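Because $\mathbf{y} = \mathbf{K}\mathbf{x}$ is linear in $\mathbf{x}$, backpropagating an upstream gradient $\mathbf{g} = \partial L / \partial \mathbf{y}$ gives $\partial L / \partial \mathbf{x} = \mathbf{K}^\top \mathbf{g}$: a transposed convolution is exactly the input gradient of a convolution with the same weights. A sketch checking this with PyTorch autograd; the shapes, stride 2, and padding 1 are assumed values. output_padding=1 is needed because this stride-2 convolution maps both 7 and 8 rows to 4, so the transposed call must be told to produce 8.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8, requires_grad=True)
weight = torch.randn(5, 3, 3, 3)                  # conv2d layout: (out, in, kH, kW)
y = F.conv2d(x, weight, stride=2, padding=1)      # y = K x, shape (1, 5, 4, 4)

g = torch.randn_like(y)                           # arbitrary upstream gradient dL/dy
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=g)  # dL/dx = K^T g

# conv_transpose2d reads the same tensor as (in=5, out=3, kH, kW) and applies K^T.
tx = F.conv_transpose2d(g, weight, stride=2, padding=1, output_padding=1)
print(tx.shape, torch.allclose(grad_x, tx, atol=1e-4))
```

This is also why transposed convolutions appear in the backward pass of every convolutional network, independently of their use as upsampling layers.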