import jax
from jax import numpy as jnp
from flax import linen as nn
from d2l import jax as d2l
import numpy as np
A standard convolution + pooling stack reduces spatial resolution. For dense prediction (semantic segmentation, generative models, super-resolution) we need to go the other way: upsample features back to image resolution.
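The standard tool is the transposed convolution, also called “deconvolution” (a misnomer, since it is not a true inverse of convolution). Each input element multiplies a full copy of the kernel and writes it into the output at its own offset; where copies from neighboring elements overlap, their contributions are summed: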
A 2 \times 2 transposed convolution: each input element scatters its kernel into the output.
Output shape grows: with stride 1, kernel size k, and no padding, n_{\text{out}} = n_{\text{in}} + k - 1. With stride s and no padding, n_{\text{out}} = (n_{\text{in}} - 1) s + k.
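The scatter-and-sum rule can be sketched directly in JAX. This is a minimal hand-written version for a single channel, stride 1, and no padding; the inputs X and K are assumptions, chosen because they reproduce the 3 \times 3 output recorded in this section.

```python
import jax.numpy as jnp

def trans_conv(X, K):
    """Single-channel transposed convolution, stride 1, no padding."""
    h, w = K.shape
    # Output size follows n_out = n_in + k - 1 along each axis.
    Y = jnp.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            # Scatter a scaled copy of the kernel; overlaps accumulate.
            Y = Y.at[i:i + h, j:j + w].add(X[i, j] * K)
    return Y

# Assumed example inputs (not stated in the text).
X = jnp.array([[0.0, 1.0], [2.0, 3.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])
Y = trans_conv(X, K)
print(Y.shape)  # (3, 3), i.e. 2 + 2 - 1 along each axis
print(Y)
```

The Python loops keep the scatter explicit; they are meant for reading, not for speed.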
The hand-written implementation should match the framework operator. If the shape or values differ, the usual culprits are padding semantics or channel layout:
Array([[ 0., 0., 1.],
[ 0., 4., 6.],
[ 4., 12., 9.]], dtype=float32)
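The framework op (Flax nn.ConvTranspose here, the counterpart of PyTorch's ConvTranspose2d) should reproduce this result. The output recorded below is 2 \times 2 with different values, which points to a padding or kernel-orientation mismatch rather than a different operation: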
Array([[[[ 0.],
[ 3.]],
[[ 6.],
[14.]]]], dtype=float32)
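The mismatch between the 3 \times 3 scatter result and the 2 \times 2 framework output can be traced with jax.lax.conv_transpose. The sketch below assumes the same X and K as the hand-written example. It also assumes that the recorded call used 'SAME' padding with an unflipped kernel, which appear to be the defaults of Flax's nn.ConvTranspose; this is an inference from the recorded values and worth checking against the installed version.

```python
import jax.numpy as jnp
from jax import lax

# Assumed inputs, laid out channels-last: X is NHWC, K is HWIO.
X = jnp.array([[0.0, 1.0], [2.0, 3.0]]).reshape(1, 2, 2, 1)
K = jnp.array([[0.0, 1.0], [2.0, 3.0]]).reshape(2, 2, 1, 1)
dn = ('NHWC', 'HWIO', 'NHWC')

# 'VALID' padding plus a flipped kernel gives the scatter-and-sum result.
scatter_like = lax.conv_transpose(
    X, K, strides=(1, 1), padding='VALID',
    dimension_numbers=dn, transpose_kernel=True)

# 'SAME' padding without the flip keeps the 2 x 2 input size.
same_noflip = lax.conv_transpose(
    X, K, strides=(1, 1), padding='SAME',
    dimension_numbers=dn, transpose_kernel=False)

print(scatter_like[0, :, :, 0])  # 3 x 3, matches the hand-written version
print(same_noflip[0, :, :, 0])   # 2 x 2, matches the recorded framework output
```

If the goal is parity with the hand-written version or with PyTorch, setting the padding and kernel orientation explicitly is safer than relying on defaults.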
Padding acts on the output rather than the input: padding p removes p rows and columns from each side of the output instead of adding them, the reverse of its role in a regular convolution.
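Under this interpretation, padding amounts to cropping the unpadded result. The sketch below reuses the 3 \times 3 scatter output from this section and assumes the same X and K as before. It also computes the plain cross-correlation of X with K for comparison, because a framework call that returns 14 for these inputs is computing that sum (unflipped kernel, no cropping) rather than the cropped scatter.

```python
import jax.numpy as jnp

# Assumed inputs and the unpadded 3 x 3 transposed-convolution output.
X = jnp.array([[0.0, 1.0], [2.0, 3.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])
Y_full = jnp.array([[0.0, 0.0, 1.0],
                    [0.0, 4.0, 6.0],
                    [4.0, 12.0, 9.0]])

p = 1  # padding of 1 drops one row and one column from each side
Y_cropped = Y_full[p:Y_full.shape[0] - p, p:Y_full.shape[1] - p]
print(Y_cropped)  # [[4.]]

# Plain cross-correlation of X with K over the single valid window.
xcorr = jnp.sum(X * K)
print(xcorr)  # 14.0
```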
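Array([[[[14.]]]], dtype=float32)
Stride > 1 places the scattered kernel copies s positions apart, which is equivalent to inserting s - 1 zeros between input elements and then applying a stride-1 transposed convolution. This is how a transposed convolution upsamples:
Stride-2 transposed conv: each input element’s kernel copy is placed two positions apart, and overlapping entries are summed.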
Array([[[[0.],
[0.],
[3.],
[2.]],
[[0.],
...
[6.]],
[[2.],
[0.],
[3.],
[0.]]]], dtype=float32)
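Both views of the stride can be checked against each other. The sketch below is a hand-written strided scatter, assuming the same X and K as in the earlier examples, compared with the zero-insertion route. It implements the scatter definition, so its values differ from the framework output recorded above whenever that call used other padding or kernel-orientation settings.

```python
import jax.numpy as jnp

def trans_conv_strided(X, K, stride):
    """Single-channel transposed convolution with a stride, no padding."""
    h, w = K.shape
    out_h = (X.shape[0] - 1) * stride + h
    out_w = (X.shape[1] - 1) * stride + w
    Y = jnp.zeros((out_h, out_w))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            r, c = i * stride, j * stride  # kernel copies sit `stride` apart
            Y = Y.at[r:r + h, c:c + w].add(X[i, j] * K)
    return Y

# Assumed example inputs.
X = jnp.array([[0.0, 1.0], [2.0, 3.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])

Y2 = trans_conv_strided(X, K, stride=2)
print(Y2)  # 4 x 4, since (2 - 1) * 2 + 2 = 4

# Zero-insertion view: put one zero between input elements, then use stride 1.
X_dilated = jnp.zeros((3, 3)).at[::2, ::2].set(X)
Y2_via_zeros = trans_conv_strided(X_dilated, K, stride=1)
print(jnp.array_equal(Y2, Y2_via_zeros))  # True
```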
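Multiple channels work as in a regular convolution: each output channel sums contributions over all input channels, and the output channels are computed independently of each other. The check below tests whether a transposed convolution with the same kernel size, padding, and stride restores the shape of the original input: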
# JAX uses channels-last format: (batch, height, width, channels)
X = jax.random.normal(jax.random.PRNGKey(0), (1, 16, 16, 10))
conv = nn.Conv(20, kernel_size=(5, 5), padding='SAME', strides=(3, 3))
tconv = nn.ConvTranspose(10, kernel_size=(5, 5), padding='SAME', strides=(3, 3))
params_conv = conv.init(jax.random.PRNGKey(1), X)
Y = conv.apply(params_conv, X)
params_tconv = tconv.init(jax.random.PRNGKey(2), Y)
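# With 'SAME' padding and stride 3, the conv maps 16 -> ceil(16 / 3) = 6,
# and the transposed conv maps 6 -> 6 * 3 = 18, so the shapes differ.
tconv.apply(params_tconv, Y).shape == X.shape
False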
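The False result follows from shape arithmetic alone. The sketch below uses the jax.lax primitives as a stand-in for the Flax modules, on the assumption that nn.Conv and nn.ConvTranspose apply the same 'SAME' padding rule. The helper name roundtrip_shapes is hypothetical, and the weights are zeros because only shapes matter here.

```python
import jax.numpy as jnp
from jax import lax

DN = ('NHWC', 'HWIO', 'NHWC')  # channels-last layout

def roundtrip_shapes(n, k=5, s=3, c_in=10, c_out=20):
    """Return the shapes after a 'SAME' conv and a 'SAME' transposed conv."""
    x = jnp.zeros((1, n, n, c_in))
    w_conv = jnp.zeros((k, k, c_in, c_out))
    y = lax.conv_general_dilated(
        x, w_conv, window_strides=(s, s), padding='SAME',
        dimension_numbers=DN)
    w_tconv = jnp.zeros((k, k, c_out, c_in))
    x_back = lax.conv_transpose(
        y, w_tconv, strides=(s, s), padding='SAME',
        dimension_numbers=DN)
    return y.shape, x_back.shape

print(roundtrip_shapes(16))  # ((1, 6, 6, 20), (1, 18, 18, 10))
print(roundtrip_shapes(15))  # ((1, 5, 5, 20), (1, 15, 15, 10))
```

With 'SAME' padding the round trip restores the spatial size only when it is a multiple of the stride; 16 is not a multiple of 3, while 15 is.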
A standard convolution can be written as a sparse matrix multiplication \mathbf{y} = \mathbf{K}\mathbf{x} where \mathbf{K} encodes the kernel + stride + padding.
A transposed convolution multiplies by the transpose: \mathbf{x}' = \mathbf{K}^\top \mathbf{y}. That’s where the name comes from.
Array([[27., 37.],
[57., 67.]], dtype=float32)
Array([[1., 2., 0., 3., 4., 0., 0., 0., 0.],
[0., 1., 2., 0., 3., 4., 0., 0., 0.],
[0., 0., 0., 1., 2., 0., 3., 4., 0.],
[0., 0., 0., 0., 1., 2., 0., 3., 4.]], dtype=float32)
Array([[ True, True],
[ True, True]], dtype=bool)
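The three outputs above can be reproduced with a short sketch. It assumes X is the 3 \times 3 array of 0 through 8 and K = [[1, 2], [3, 4]], which is consistent with the recorded values. The helpers corr2d, kernel2matrix, and trans_conv are hand-written stand-ins for illustration, not library functions.

```python
import jax.numpy as jnp

def corr2d(X, K):
    """Single-channel cross-correlation, stride 1, no padding."""
    h, w = K.shape
    Y = jnp.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y = Y.at[i, j].set(jnp.sum(X[i:i + h, j:j + w] * K))
    return Y

def kernel2matrix(K, in_shape):
    """Build the sparse matrix W such that W @ X.reshape(-1) == corr2d(X, K).reshape(-1)."""
    h, w = K.shape
    H, Wd = in_shape
    oh, ow = H - h + 1, Wd - w + 1
    W = jnp.zeros((oh * ow, H * Wd))
    for i in range(oh):
        for j in range(ow):
            for a in range(h):
                for b in range(w):
                    W = W.at[i * ow + j, (i + a) * Wd + (j + b)].set(K[a, b])
    return W

def trans_conv(Y, K):
    """Single-channel transposed convolution, stride 1, no padding."""
    h, w = K.shape
    Z = jnp.zeros((Y.shape[0] + h - 1, Y.shape[1] + w - 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Z = Z.at[i:i + h, j:j + w].add(Y[i, j] * K)
    return Z

X = jnp.arange(9.0).reshape(3, 3)       # assumed input
K = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # assumed kernel
Y = corr2d(X, K)
W = kernel2matrix(K, X.shape)

# Forward pass as a matrix product.
fwd_ok = jnp.allclose(Y, (W @ X.reshape(-1)).reshape(2, 2))
# Transposed convolution as a product with the transposed matrix.
Z = trans_conv(Y, K)
bwd_ok = jnp.allclose(Z, (W.T @ Y.reshape(-1)).reshape(3, 3))
print(Y, W, fwd_ok, bwd_ok, sep='\n')
```

The second check is the one that gives the operation its name: multiplying by W.T maps the 4-element output back to a 9-element array with the input's shape, although not to the input's values.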