%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from PIL import Image
A fully convolutional network (Long, Shelhamer, Darrell 2015) is the simplest path to per-pixel prediction: it transforms an $H \times W \times 3$ image into an $H \times W \times \text{num\_classes}$ tensor of class scores. No FC layers anywhere — it works on any input size and outputs a class-score map at input resolution.
FCN: pretrained CNN body + 1×1 conv → class scores → transposed conv to upsample.
Backbone: ResNet-18 (the usual recipe starts from ImageNet-pretrained weights; the code below builds the architecture and initializes it randomly). Drop the head (avg pool + dense); keep the conv body that produces an $\frac{H}{32} \times \frac{W}{32}$ feature map:
# ResNet building blocks used by the feature extractor below.
class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a shortcut.

    Attributes:
        num_channels: output channels of every convolution in the block.
        strides: stride of the first 3x3 convolution and of the optional
            1x1 shortcut convolution.
        use_1x1conv: if True, the shortcut is a strided 1x1 convolution
            followed by batch norm (so it matches the main path's shape);
            otherwise the input is added back unchanged.
    """

    num_channels: int
    strides: tuple = (1, 1)
    use_1x1conv: bool = False

    @nn.compact
    def __call__(self, x, training=False):
        # BatchNorm normalizes with running averages unless training.
        frozen_stats = not training

        # Main path: conv3x3 (strided) -> BN -> ReLU -> conv3x3 -> BN.
        out = nn.Conv(self.num_channels, kernel_size=(3, 3),
                      strides=self.strides, padding='SAME')(x)
        out = nn.relu(nn.BatchNorm(use_running_average=frozen_stats)(out))
        out = nn.Conv(self.num_channels, kernel_size=(3, 3),
                      strides=(1, 1), padding='SAME')(out)
        out = nn.BatchNorm(use_running_average=frozen_stats)(out)

        # Shortcut. Its submodules are created after the main path on
        # purpose: Flax auto-names submodules in creation order, so moving
        # them would rename the parameters.
        shortcut = x
        if self.use_1x1conv:
            shortcut = nn.Conv(self.num_channels, kernel_size=(1, 1),
                               strides=self.strides)(x)
            shortcut = nn.BatchNorm(
                use_running_average=frozen_stats)(shortcut)
        return nn.relu(out + shortcut)
class ResNetFeatures(nn.Module):
    """ResNet-18 convolutional body with global pooling and the dense head removed.

    Maps an NHWC image batch to a 512-channel feature map that is 32x
    smaller spatially (e.g. 320x480 -> 10x15). Parameters come from
    ``init``; nothing in this class loads pretrained weights.
    """

    @nn.compact
    def __call__(self, x, training=False):
        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        x = nn.Conv(64, kernel_size=(7, 7), strides=(2, 2),
                    padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                        padding='SAME')

        # Four stages of two residual blocks each. Stages 2-4 open with a
        # stride-2 block whose 1x1 projection shortcut matches the new shape.
        for channels, step in ((64, 1), (128, 2), (256, 2), (512, 2)):
            x = ResNetBlock(channels, strides=(step, step),
                            use_1x1conv=(step != 1))(x, training)
            x = ResNetBlock(channels)(x, training)
        return x
# NOTE: despite the name, these weights come from `init` with a fixed PRNG
# key (i.e. random), not from a pretrained checkpoint.
pretrained_net = ResNetFeatures()
# Initialize with a dummy input to see the architecture's output shape.
dummy = jnp.ones((1, 320, 480, 3))
variables = pretrained_net.init(jax.random.PRNGKey(0), dummy)
print('Feature extractor output shape:',
      pretrained_net.apply(variables, dummy).shape)
# Output: Feature extractor output shape: (1, 10, 15, 512)
After removing the classifier head, the backbone produces a low-resolution feature map. The new FCN head must restore the original spatial resolution while changing channels to class logits.
A $1 \times 1$ conv maps the 512 backbone channels to num_classes (21 for VOC). Then a transposed conv upsamples by 32× to recover the input resolution:
num_classes = 21  # output channels of the FCN head: one score per VOC class
class FCN(nn.Module):
    """Fully convolutional segmentation net: ResNet-18 body -> 1x1 conv -> 32x transposed conv.

    Attributes:
        num_classes: number of output channels, i.e. one score map per class.
    """

    num_classes: int

    @nn.compact
    def __call__(self, x, training=False):
        # Backbone features; spatial size is H/32 x W/32 for H, W divisible by 32.
        features = ResNetFeatures()(x, training)
        # Per-location class scores.
        scores = nn.Conv(self.num_classes, kernel_size=(1, 1))(features)
        # Kernel 64 / stride 32 with 'SAME' padding upsamples 32x back to
        # the input resolution.
        upsample = nn.ConvTranspose(self.num_classes, kernel_size=(64, 64),
                                    strides=(32, 32), padding='SAME')
        return upsample(scores)
net = FCN(num_classes=num_classes)
# One sample batch reused for both init and the shape check.
sample = jnp.ones((1, 320, 480, 3))
variables = net.init(jax.random.PRNGKey(0), sample)
print('FCN output shape:', net.apply(variables, sample).shape)
# Output: FCN output shape: (1, 320, 480, 21)
A randomly initialized 32× upsampler is hard to train. Instead, initialize it as bilinear interpolation — a sensible starting point that training then fine-tunes:
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build a transposed-conv kernel that performs bilinear upsampling.

    Args:
        in_channels: number of input channels of the layer.
        out_channels: number of output channels of the layer.
        kernel_size: side length of the square kernel.

    Returns:
        A float32 array of shape (kernel_size, kernel_size, in_channels,
        out_channels), in Flax's HWIO layout for ConvTranspose. Channel i
        maps to channel i with the 2-D bilinear filter for
        i < min(in_channels, out_channels); all cross-channel weights are 0.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels are centred on a tap; even kernels between two taps.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    taps = np.arange(kernel_size)
    profile = 1 - np.abs(taps - center) / factor
    # The 2-D bilinear filter is the outer product of the 1-D tent profile.
    filt = np.outer(profile, profile)
    weight = np.zeros((kernel_size, kernel_size, in_channels, out_channels),
                      dtype=np.float32)
    # Fill only the "diagonal" channel pairs (i -> i) in one vectorized write.
    diag = np.arange(min(in_channels, out_channels))
    weight[:, :, diag, diag] = filt[:, :, None]
    return jnp.array(weight)


# Apply the initialized transposed convolution to an image. The output should
# be larger but visually similar, because the kernel starts as bilinear
# interpolation rather than random noise:
class BilinearConvTranspose(nn.Module):
    """A square transposed conv whose kernel is initialized to bilinear upsampling.

    Attributes:
        channels: number of output channels.
        kernel_size: side length of the square kernel.
        strides: upsampling strides, e.g. (2, 2).
    """

    channels: int
    kernel_size: int
    strides: tuple

    @nn.compact
    def __call__(self, x):
        def bilinear_init(key, shape, dtype=jnp.float32):
            # Flax ConvTranspose kernels are (kh, kw, in_features, features);
            # the kernel is square, so shape[0] is the kernel size.
            del key  # deterministic initializer, no randomness needed
            return bilinear_kernel(shape[2], shape[3], shape[0]).astype(dtype)

        return nn.ConvTranspose(self.channels,
                                kernel_size=(self.kernel_size,
                                             self.kernel_size),
                                strides=self.strides,
                                padding='SAME',
                                kernel_init=bilinear_init)(x)
conv_trans = BilinearConvTranspose(channels=3, kernel_size=4, strides=(2, 2))
dummy_img = jnp.ones((1, 100, 100, 3))
ct_variables = conv_trans.init(jax.random.PRNGKey(0), dummy_img)
# Overwrite the kernel with explicit bilinear weights (4x4 kernel, 3 -> 3
# channels, HWIO layout).
bilinear_w = bilinear_kernel(3, 3, 4)
ct_params = ct_variables['params']
ct_variables = {
    **ct_variables,
    'params': {
        **ct_params,
        'ConvTranspose_0': {**ct_params['ConvTranspose_0'],
                            'kernel': bilinear_w},
    },
}
# NOTE(review): the step that applies `conv_trans` to a real image and prints
# the "input image shape" / "output image shape" lines shown below is not
# present in this file -- TODO restore it.
The printed shapes should confirm the spatial scale-up. Then the same bilinear kernel initializes the FCN’s final upsampling layer:
input image shape: (561, 728, 3)
output image shape: (1122, 1456, 3)
# Initialize the FCN with bilinear weights for the transposed conv layer.
# The 1x1 conv layer keeps the default kernel initializer of `nn.Conv`
# (no Xavier/Glorot initialization is applied here).
W = bilinear_kernel(num_classes, num_classes, 64)


def init_fcn_weights(rng):
    """Initialize ``net`` and overwrite its upsampling kernel with ``W``.

    Args:
        rng: PRNG key passed to ``net.init``.

    Returns:
        The variables dict produced by ``net.init`` (``params`` plus any other
        collections such as ``batch_stats``), with
        ``params['ConvTranspose_0']['kernel']`` replaced by the bilinear
        kernel ``W``.
    """
    variables = net.init(rng, jnp.ones((1, 320, 480, 3)))
    params = variables['params']
    # Shallow copies: only the transposed-conv kernel entry changes.
    new_params = {**params,
                  'ConvTranspose_0': {**params['ConvTranspose_0'],
                                      'kernel': W}}
    return {**variables, 'params': new_params}


variables = init_fcn_weights(jax.random.PRNGKey(42))
# NOTE(review): the two lines below look like output of a VOC data-loading
# step (which would also define `train_iter`) that is missing from this file
# -- TODO confirm/restore.
# read 1114 examples
# read 1078 examples
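Pixel-level cross-entropy: softmax cross-entropy averaged over every pixel. A common trick is to freeze a pretrained backbone and train only the new head, which gets reasonable results in a few epochs; the loop below, however, updates all parameters, backbone included: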
def loss_fn(params, batch_stats, X, Y):
    """Mean per-pixel softmax cross-entropy of ``net`` on one minibatch.

    Args:
        params: model parameters (the differentiated argument).
        batch_stats: current BatchNorm running statistics.
        X: image batch in NHWC layout.
        Y: integer class labels in NHW layout.

    Returns:
        ``(loss, new_state)``: the scalar mean loss and the mutated
        collections from ``net.apply`` (the updated ``batch_stats``),
        returned as auxiliary data for ``value_and_grad(has_aux=True)``.
    """
    model_vars = {'params': params, 'batch_stats': batch_stats}
    # training=True: BatchNorm uses batch statistics and returns its updated
    # running averages in `new_state`.
    logits, new_state = net.apply(model_vars, X, training=True,
                                  mutable=['batch_stats'])
    # logits: (N, H, W, num_classes) -> per-pixel losses of shape (N, H, W).
    per_pixel = optax.softmax_cross_entropy_with_integer_labels(logits, Y)
    return per_pixel.mean(), new_state
num_epochs, lr, wd = 5, 0.001, 1e-3
# SGD with coupled L2 weight decay: `add_decayed_weights` adds wd * param to
# each gradient before the SGD step scales it by the learning rate.
optimizer = optax.chain(optax.add_decayed_weights(wd), optax.sgd(lr))
opt_state = optimizer.init(variables['params'])
batch_stats = variables.get('batch_stats', {})
@jax.jit
def train_step(params, batch_stats, opt_state, X, Y):
    """Run one jitted optimizer step on a minibatch.

    Returns:
        ``(params, batch_stats, opt_state, loss)``: updated parameters, the
        new BatchNorm running statistics, the new optimizer state, and the
        minibatch loss computed before the update.
    """
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss_val, aux), grads = grad_fn(params, batch_stats, X, Y)
    # `params` is passed so transformations that depend on it can use it.
    steps, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, steps)
    return new_params, aux['batch_stats'], new_opt_state, loss_val
params = variables['params']
# NOTE(review): `train_iter` is not defined in this part of the file;
# presumably it is the VOC training loader from a data-loading step not
# shown here -- TODO confirm.
for epoch in range(num_epochs):
    for X, Y in train_iter:
        # The loader yields channels-first images; Flax convs expect NHWC.
        X = jnp.transpose(jnp.array(X), (0, 2, 3, 1))
        Y = jnp.array(Y)
        params, batch_stats, opt_state, loss_val = train_step(
            params, batch_stats, opt_state, X, Y)
    # Reports the loss of the epoch's last minibatch, not an epoch average.
    print(f'epoch {epoch + 1}, loss {float(loss_val):.3f}')
variables = {'params': params, 'batch_stats': batch_stats}
# Output (loss of each epoch's last minibatch):
# epoch 1, loss 1.659
# epoch 2, loss 1.536
# epoch 3, loss 1.192
# epoch 4, loss 1.358
# epoch 5, loss 1.040
Run the network on test images, take argmax over the class dimension, map class indices back to RGB:
def predict(img):
    """Segment one image with the trained ``net`` and ``variables``.

    Args:
        img: an H x W x 3 channels-last image with 0-255 pixel values.

    Returns:
        An (H, W) array of predicted class indices (argmax over class scores).
    """
    rgb_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    rgb_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    X = (img.astype(np.float32) / 255 - rgb_mean) / rgb_std
    X = jnp.expand_dims(jnp.array(X), axis=0)  # add batch axis -> NHWC
    scores = net.apply(variables, X, training=False)
    height, width = img.shape[:2]
    # With 'SAME' striding, sizes that are not multiples of 32 upsample to a
    # map larger than the input; crop it so the labels align with `img`.
    return jnp.argmax(scores[0, :height, :width], axis=-1)


# The output grid is image, prediction, ground truth. Expect coarse
# boundaries: this plain FCN upsamples from a 32x downsampled feature map and
# has no skip connections.
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
# NOTE(review): `label2image` (class-index map -> RGB image) is not defined in
# this part of the file; presumably it comes from a step not shown here.
n, imgs = 4, []
for i in range(n):
    # Top-left 320x480 crop of each channels-last test image and its label.
    X = test_images[i][:320, :480, :]
    pred = label2image(predict(X))
    label_crop = test_labels[i][:320, :480, :]
    imgs.extend([X, np.array(pred), label_crop])
# Regroup into rows: all inputs, then all predictions, then all labels.
grid = imgs[::3] + imgs[1::3] + imgs[2::3]
d2l.show_images(grid, 3, n, scale=2);