Fully Convolutional Networks

A fully convolutional network (Long, Shelhamer, Darrell 2015) is the simplest path to per-pixel prediction:

  1. Start with a pretrained classification CNN (here, ResNet-18).
  2. Strip its global average pool and final dense layer.
  3. Replace them with a 1×1 conv that maps the feature channels to num_classes channels.
  4. Upsample the class-score map back to input resolution with a transposed conv.

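There are no FC layers anywhere, so the network accepts any input size and outputs a class-score map at input resolution. The total downsampling factor is 32, so the output matches the input exactly when H and W are multiples of 32; otherwise it comes out slightly larger.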
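The four steps can be traced as spatial shapes. The sketch below assumes, as in the Flax code in this section, that every stride-2 layer uses 'SAME' padding (so it halves the size, rounding up) and that the final transposed conv multiplies the size back by the same factor. `fcn_shapes` is a hypothetical helper for illustration, not part of d2l.

```python
import math

def fcn_shapes(h, w, num_classes=21, backbone_channels=512, num_stride2=5):
    """Trace (H, W, C) through the four FCN steps for one image.

    Assumes 'SAME' padding everywhere: each stride-2 layer maps n -> ceil(n / 2),
    and the final transposed conv multiplies n by 2 ** num_stride2.
    """
    fh, fw = h, w
    for _ in range(num_stride2):  # stem conv, max pool, three stage transitions
        fh, fw = math.ceil(fh / 2), math.ceil(fw / 2)
    factor = 2 ** num_stride2                        # 32 for ResNet-18
    features = (fh, fw, backbone_channels)           # steps 1-2: conv body only
    scores = (fh, fw, num_classes)                   # step 3: 1x1 conv changes C only
    upsampled = (factor * fh, factor * fw, num_classes)  # step 4: transposed conv
    return features, scores, upsampled

print(fcn_shapes(320, 480))  # ((10, 15, 512), (10, 15, 21), (320, 480, 21))
print(fcn_shapes(561, 728))  # ((18, 23, 512), (18, 23, 21), (576, 736, 21))
```

Because of the rounding, a 561×728 input yields a 576×736 score map, so predictions have to be cropped (or inputs padded) to line up with the image; cropping inputs to a multiple of 32, such as 320×480, avoids the mismatch.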

Architecture

FCN: pretrained CNN body + 1×1 conv → class scores → transposed conv to upsample.

Setup

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from PIL import Image

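The recipe calls for ResNet-18 pretrained on ImageNet: drop the head (avg pool + dense) and keep the conv body, which produces a \frac{H}{32} \times \frac{W}{32} feature map. Note that the code below only re-implements the ResNet-18 body in Flax and initializes it randomly; it does not load ImageNet weights, so a pretrained checkpoint would have to be converted and loaded separately: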

# Define ResNet building blocks for the feature extractor
class ResNetBlock(nn.Module):
    num_channels: int
    strides: tuple = (1, 1)
    use_1x1conv: bool = False

    @nn.compact
    def __call__(self, x, training=False):
        residual = x
        y = nn.Conv(self.num_channels, kernel_size=(3, 3),
                    strides=self.strides, padding='SAME')(x)
        y = nn.BatchNorm(use_running_average=not training)(y)
        y = nn.relu(y)
        y = nn.Conv(self.num_channels, kernel_size=(3, 3),
                    strides=(1, 1), padding='SAME')(y)
        y = nn.BatchNorm(use_running_average=not training)(y)
        if self.use_1x1conv:
            residual = nn.Conv(self.num_channels, kernel_size=(1, 1),
                               strides=self.strides)(x)
            residual = nn.BatchNorm(
                use_running_average=not training)(residual)
        return nn.relu(y + residual)

class ResNetFeatures(nn.Module):
    """ResNet-18 feature extractor (without global avg pool and FC)."""
    @nn.compact
    def __call__(self, x, training=False):
        # Initial conv + bn + relu + maxpool
        x = nn.Conv(64, kernel_size=(7, 7), strides=(2, 2),
                    padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                        padding='SAME')
        # Stage 1: 64 channels
        x = ResNetBlock(64)(x, training)
        x = ResNetBlock(64)(x, training)
        # Stage 2: 128 channels, downsample
        x = ResNetBlock(128, strides=(2, 2), use_1x1conv=True)(x, training)
        x = ResNetBlock(128)(x, training)
        # Stage 3: 256 channels, downsample
        x = ResNetBlock(256, strides=(2, 2), use_1x1conv=True)(x, training)
        x = ResNetBlock(256)(x, training)
        # Stage 4: 512 channels, downsample
        x = ResNetBlock(512, strides=(2, 2), use_1x1conv=True)(x, training)
        x = ResNetBlock(512)(x, training)
        return x

pretrained_net = ResNetFeatures()
# Initialize with a dummy input and check the output feature-map shape
dummy = jnp.ones((1, 320, 480, 3))
variables = pretrained_net.init(jax.random.PRNGKey(0), dummy)
print('Feature extractor output shape:',
      pretrained_net.apply(variables, dummy).shape)
Feature extractor output shape: (1, 10, 15, 512)

Building the FCN

After removing the classifier head, the backbone produces a feature map downsampled 32× in each spatial dimension (a 320×480 input becomes 10×15, with 512 channels). The new FCN head must restore the original spatial resolution while changing the channels to class logits.

# The ResNetFeatures module already excludes global avg pool and FC.
# We define the full FCN model below.
X = jnp.ones((1, 320, 480, 3))
pretrained_net.apply(variables, X).shape
(1, 10, 15, 512)
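The head's first layer is a 1×1 convolution, which is simply a linear map applied independently at every pixel: an (H, W, C_in) feature map times a (C_in, C_out) weight matrix, plus a bias. The NumPy sketch below checks this on a random map with the backbone's output shape; the weights are random placeholders, not trained values.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.standard_normal((10, 15, 512))   # backbone output for one image (HWC)
kernel = rng.standard_normal((1, 1, 512, 21))   # 1x1 conv kernel in HWIO layout
bias = rng.standard_normal(21)

# A 1x1 conv contracts the channel axis at each pixel: (H, W, Cin) @ (Cin, Cout)
scores = np.einsum('hwc,co->hwo', features, kernel[0, 0]) + bias

# The same result, computed pixel by pixel
ref = np.empty((10, 15, 21))
for i in range(10):
    for j in range(15):
        ref[i, j] = features[i, j] @ kernel[0, 0] + bias

print(scores.shape)               # (10, 15, 21)
print(np.allclose(scores, ref))   # True
```

Since the 1×1 conv never mixes neighboring pixels, it cannot change the spatial size; all of the upsampling has to come from the transposed conv that follows it.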

The class & upsampling head

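A 1 \times 1 conv maps the num_features channels (512 for ResNet-18) to num_classes channels (21 for VOC: 20 object classes plus background). Then a transposed conv with kernel size 64 and stride 32 upsamples by 32× to recover input resolution: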

num_classes = 21

class FCN(nn.Module):
    """Fully Convolutional Network for semantic segmentation."""
    num_classes: int

    @nn.compact
    def __call__(self, x, training=False):
        # Feature extraction (ResNet-18 backbone)
        x = ResNetFeatures()(x, training)
        # 1x1 conv to map to num_classes channels
        x = nn.Conv(self.num_classes, kernel_size=(1, 1))(x)
        # Transposed conv to upsample by 32x
        x = nn.ConvTranspose(self.num_classes, kernel_size=(64, 64),
                              strides=(32, 32), padding='SAME')(x)
        return x

net = FCN(num_classes=num_classes)
variables = net.init(jax.random.PRNGKey(0), jnp.ones((1, 320, 480, 3)))
print('FCN output shape:',
      net.apply(variables, jnp.ones((1, 320, 480, 3))).shape)
FCN output shape: (1, 320, 480, 21)
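Why kernel 64 and stride 32? With explicit symmetric padding p (the PyTorch convention, e.g. torch.nn.ConvTranspose2d without output_padding), a transposed conv with kernel k and stride s maps length n to (n - 1)s - 2p + k. Choosing p = (k - s)/2 makes this exactly sn, the same scale-up that Flax's padding='SAME' gives in the outputs above. `conv_transpose_out` is a hypothetical helper for illustration only.

```python
def conv_transpose_out(n, kernel, stride, padding):
    """Output length of a transposed conv (explicit symmetric padding, no output_padding)."""
    return (n - 1) * stride - 2 * padding + kernel

# FCN head: 10 x 15 score map, kernel 64, stride 32, padding (64 - 32) / 2 = 16
print(conv_transpose_out(10, 64, 32, 16), conv_transpose_out(15, 64, 32, 16))  # 320 480
# The 2x sanity-check layer: kernel 4, stride 2, padding (4 - 2) / 2 = 1
print(conv_transpose_out(561, 4, 2, 1), conv_transpose_out(728, 4, 2, 1))      # 1122 1456

# With p = (k - s) / 2 the output is always stride * n
for n in (1, 10, 15, 561):
    assert conv_transpose_out(n, 64, 32, 16) == 32 * n
    assert conv_transpose_out(n, 4, 2, 1) == 2 * n
```

Making the kernel twice the stride means each output pixel receives contributions from two neighboring input pixels per axis, which is what bilinear interpolation needs.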

Bilinear init for transposed conv

A randomly initialized 32× upsampler gives the network a poor starting point and is hard to train. Instead, initialize its kernel to perform bilinear interpolation, which is already a sensible upsampler, and let training fine-tune it from there:

def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (np.arange(kernel_size).reshape(-1, 1),
          np.arange(kernel_size).reshape(1, -1))
    filt = (1 - np.abs(og[0] - center) / factor) * \
           (1 - np.abs(og[1] - center) / factor)
    # Flax uses HWIO format for ConvTranspose kernels
    weight = np.zeros((kernel_size, kernel_size, in_channels, out_channels))
    for i in range(min(in_channels, out_channels)):
        weight[:, :, i, i] = filt
    return jnp.array(weight)
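To see why this kernel interpolates, look at a single axis. For kernel_size = 4 the per-axis filter is 1 - |i - 1.5| / 2 for i = 0..3, i.e. [0.25, 0.75, 0.75, 0.25], and `filt` in bilinear_kernel is the outer product of this filter with itself. The NumPy sketch below implements a 1D stride-2 transposed conv by hand (each input value scatter-adds a scaled copy of the filter into the output) and then crops one element per side, mirroring padding 1 in the explicit-padding convention. It is an illustration, not the Flax layer itself; both helper names are hypothetical.

```python
import numpy as np

def bilinear_filter_1d(kernel_size):
    """One axis of bilinear_kernel: 1 - |i - center| / factor."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    return 1 - np.abs(np.arange(kernel_size) - center) / factor

def conv_transpose_1d(x, filt, stride, padding):
    """Naive 1D transposed conv: scatter-add a scaled filter per input element,
    then crop `padding` elements from each end."""
    k = len(filt)
    out = np.zeros((len(x) - 1) * stride + k)
    for i, v in enumerate(x):
        out[i * stride:i * stride + k] += v * filt
    return out[padding:len(out) - padding]

f = bilinear_filter_1d(4)                 # 0.25, 0.75, 0.75, 0.25
x = np.array([0., 4., 8., 12.])           # a linear ramp
y = conv_transpose_1d(x, f, stride=2, padding=1)
# y has 8 entries: interior 1, 3, 5, 7, 9, 11 (the ramp sampled at half-pixel
# offsets), while the two ends (0 and 9) are pulled toward zero by the padding
print(y)
```

The interior outputs lie exactly on the upsampled ramp, which is linear interpolation; applying the same filter along both axes gives bilinear interpolation. Only the borders deviate, because the transposed conv effectively pads the input with zeros.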

Upsampling sanity check

Apply the initialized transposed convolution to an image. The output should be larger but visually similar, because the kernel starts as bilinear interpolation rather than random noise:

class BilinearConvTranspose(nn.Module):
    """A transposed conv layer initialized with bilinear interpolation."""
    channels: int
    kernel_size: int
    strides: tuple

    @nn.compact
    def __call__(self, x):
        return nn.ConvTranspose(self.channels,
                                kernel_size=(self.kernel_size,
                                             self.kernel_size),
                                strides=self.strides,
                                padding='SAME')(x)

conv_trans = BilinearConvTranspose(channels=3, kernel_size=4, strides=(2, 2))
dummy_img = jnp.ones((1, 100, 100, 3))
ct_variables = conv_trans.init(jax.random.PRNGKey(0), dummy_img)
# Replace the kernel with bilinear weights
bilinear_w = bilinear_kernel(3, 3, 4)
ct_variables = {**ct_variables,
    'params': {**ct_variables['params'],
               'ConvTranspose_0': {**ct_variables['params']['ConvTranspose_0'],
                                   'kernel': bilinear_w}}}
img = np.array(Image.open('../img/catdog.jpg')).astype(np.float32) / 255
X = jnp.expand_dims(jnp.array(img), axis=0)  # NHWC
Y = conv_trans.apply(ct_variables, X)
out_img = np.array(Y[0])

Bilinear init (cont.)

The printed shapes should confirm the 2× spatial scale-up (height and width both double). The same bilinear kernel then initializes the FCN's final upsampling layer:

d2l.set_figsize()
print('input image shape:', img.shape)
d2l.plt.imshow(img);
print('output image shape:', out_img.shape)
d2l.plt.imshow(out_img);

input image shape: (561, 728, 3)
output image shape: (1122, 1456, 3)

# Initialize the FCN with bilinear weights for the transposed conv layer;
# the 1x1 conv keeps Flax's default (LeCun normal) initialization
W = bilinear_kernel(num_classes, num_classes, 64)

def init_fcn_weights(rng):
    """Initialize FCN with bilinear upsampling for transposed conv."""
    variables = net.init(rng, jnp.ones((1, 320, 480, 3)))
    params = variables['params']
    # Set bilinear kernel for the transposed conv layer
    flat_params = dict(params)
    flat_params['ConvTranspose_0'] = {
        **params['ConvTranspose_0'], 'kernel': W}
    return {**variables, 'params': flat_params}

variables = init_fcn_weights(jax.random.PRNGKey(42))
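Both the sanity check and init_fcn_weights swap one leaf of a nested parameter dictionary without mutating the original. A generic version of that pattern is sketched below with plain dicts; `replace_leaf` is a hypothetical helper, the path ('ConvTranspose_0', 'kernel') assumes Flax's default auto-naming seen above, and the strings stand in for arrays.

```python
def replace_leaf(tree, path, value):
    """Return a copy of nested dict `tree` with the leaf at `path` set to `value`.

    Only the dicts along `path` are copied; every other subtree is shared.
    """
    key, *rest = path
    new = dict(tree)                       # shallow copy of this level
    new[key] = value if not rest else replace_leaf(tree[key], rest, value)
    return new

params = {'Conv_0': {'kernel': 'k1', 'bias': 'b1'},
          'ConvTranspose_0': {'kernel': 'random', 'bias': 'b2'}}
new_params = replace_leaf(params, ('ConvTranspose_0', 'kernel'), 'bilinear')
print(new_params['ConvTranspose_0'])  # {'kernel': 'bilinear', 'bias': 'b2'}
print(params['ConvTranspose_0'])      # {'kernel': 'random', 'bias': 'b2'}
```

Copying only along the path keeps the update cheap and leaves the original parameters intact, which matches the functional style JAX expects.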

Loading data

batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)
read 1114 examples
read 1078 examples

Training

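The loss is pixel-level cross-entropy. With a genuinely pretrained backbone, a common trick is to freeze it and train only the new head, which can give reasonable results in a few epochs. Because the backbone here is randomly initialized, the code below instead trains all parameters: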

def loss_fn(params, batch_stats, X, Y):
    # X is NHWC, Y is NHW with integer class labels
    logits, updates = net.apply(
        {'params': params, 'batch_stats': batch_stats},
        X, training=True, mutable=['batch_stats'])
    # logits shape: (N, H, W, num_classes)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, Y)
    return loss.mean(), updates

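num_epochs, lr, wd = 5, 0.001, 1e-3
# optax.sgd alone does not apply wd, so add L2 weight decay explicitly
optimizer = optax.chain(optax.add_decayed_weights(wd), optax.sgd(lr))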
opt_state = optimizer.init(variables['params'])
batch_stats = variables.get('batch_stats', {})

@jax.jit
def train_step(params, batch_stats, opt_state, X, Y):
    (loss_val, updates), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(params, batch_stats, X, Y)
    param_updates, opt_state_new = optimizer.update(grads, opt_state, params)
    params_new = optax.apply_updates(params, param_updates)
    return params_new, updates['batch_stats'], opt_state_new, loss_val

params = variables['params']
for epoch in range(num_epochs):
    for X, Y in train_iter:
        # Convert CHW to HWC for JAX
        X = jnp.transpose(jnp.array(X), (0, 2, 3, 1))
        Y = jnp.array(Y)
        params, batch_stats, opt_state, loss_val = train_step(
            params, batch_stats, opt_state, X, Y)
    # loss_val is the last minibatch's loss, not an average over the epoch
    print(f'epoch {epoch + 1}, loss {float(loss_val):.3f}')
variables = {'params': params, 'batch_stats': batch_stats}
epoch 1, loss 1.659
epoch 2, loss 1.536
epoch 3, loss 1.192
epoch 4, loss 1.358
epoch 5, loss 1.040
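The training loss is ordinary softmax cross-entropy applied at every pixel and averaged over N×H×W, which is what optax.softmax_cross_entropy_with_integer_labels followed by .mean() computes in loss_fn. The NumPy sketch below spells this out, and also illustrates the idea behind freezing a backbone: skip the update for every parameter group whose top-level name marks it as backbone (here 'ResNetFeatures_0', assuming Flax's auto-naming). In optax the same masking can be expressed with optax.multi_transform and optax.set_to_zero; the code below is a dependency-free illustration with hypothetical helper names, not that API.

```python
import numpy as np

def pixel_cross_entropy(logits, labels):
    """Mean softmax cross-entropy over all pixels.

    logits: (N, H, W, C) class scores; labels: (N, H, W) integer class ids.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability of the true class at each pixel
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -picked.mean()

def sgd_step_frozen(params, grads, lr, frozen=('ResNetFeatures_0',)):
    """Plain SGD on a flat {name: array} dict, leaving frozen groups unchanged."""
    return {name: p if name in frozen else p - lr * grads[name]
            for name, p in params.items()}

rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 4, 6, 21))
labels = rng.integers(0, 21, size=(2, 4, 6))
print(pixel_cross_entropy(logits, labels))

# Uniform logits give loss log(C) at every pixel: log(21) is about 3.045
print(pixel_cross_entropy(np.zeros((1, 2, 2, 21)), np.zeros((1, 2, 2), dtype=int)))

params = {'ResNetFeatures_0': np.ones(3), 'Conv_0': np.ones(3)}
grads = {'ResNetFeatures_0': np.ones(3), 'Conv_0': np.ones(3)}
new = sgd_step_frozen(params, grads, lr=0.1)
# Frozen group stays at 1.0; the head moves to 1 - 0.1 * 1 = 0.9
print(new['ResNetFeatures_0'], new['Conv_0'])
```

The log(21) baseline is a useful reference for the training log above: a network that knows nothing starts near 3.04, so losses around 1.0 after five epochs reflect real learning, though still far from a converged model.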

Predict

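Normalize a test image the same way as the training data, run the network, and take the argmax over the class (last) dimension to get one class index per pixel; label2image then maps the class indices back to RGB: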

def predict(img):
    rgb_mean = np.array([0.485, 0.456, 0.406])
    rgb_std = np.array([0.229, 0.224, 0.225])
    X = (img.astype(np.float32) / 255 - rgb_mean) / rgb_std
    X = jnp.expand_dims(jnp.array(X), axis=0)  # NHWC
    pred = net.apply(variables, X, training=False)
    return jnp.argmax(pred, axis=-1).reshape(pred.shape[1], pred.shape[2])
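The shape bookkeeping in predict can be traced without the model. The sketch below applies the same ImageNet mean/std normalization (broadcast over the channel axis) and the same argmax-then-reshape; `fake_logits` is a random placeholder standing in for the network output, not a real prediction.

```python
import numpy as np

rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])

img = np.full((320, 480, 3), 128, dtype=np.uint8)        # uint8 HWC image
X = (img.astype(np.float32) / 255 - rgb_mean) / rgb_std  # broadcasts over channels
X = X[None]                                              # add batch axis -> NHWC

rng = np.random.default_rng(0)
fake_logits = rng.standard_normal((1, 320, 480, 21))     # placeholder network output
pred = fake_logits.argmax(axis=-1).reshape(320, 480)     # one class id per pixel

print(X.shape, pred.shape)   # (1, 320, 480, 3) (320, 480)
```

The argmax removes the class axis and the reshape drops the batch axis, leaving an (H, W) map of integer class ids that label2image can color.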

Visualize segmentation masks

The figure has three rows: input images, predictions, and ground truth. Expect coarse boundaries: this plain FCN upsamples from a 32× downsampled feature map and has no skip connections.

def label2image(pred):
    colormap = jnp.array(d2l.VOC_COLORMAP, dtype=jnp.uint8)
    X = pred.astype(jnp.int32)
    return colormap[X, :]

voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
    # Crop HWC arrays: top=0, left=0, height=320, width=480
    X = test_images[i][:320, :480, :]
    pred = label2image(predict(X))
    label_crop = test_labels[i][:320, :480, :]
    imgs += [X, np.array(pred), label_crop]
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);
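label2image relies on integer-array indexing: indexing a (num_classes, 3) colormap with an (H, W) array of class ids returns an (H, W, 3) RGB image. The toy example below uses only three colors, assumed to match the first three entries of d2l.VOC_COLORMAP (background, aeroplane, bicycle).

```python
import numpy as np

colormap = np.array([[0, 0, 0],       # 0: background
                     [128, 0, 0],     # 1: aeroplane
                     [0, 128, 0]],    # 2: bicycle
                    dtype=np.uint8)
pred = np.array([[0, 1],
                 [2, 1]])             # (H, W) class ids
rgb = colormap[pred]                  # (H, W, 3): one color per pixel
print(rgb.shape)                      # (2, 2, 3)
print(rgb[1, 0])                      # bicycle color: 0, 128, 0
```

This is the same lookup as colormap[X, :] in label2image; the trailing `:` is optional.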

Recap

  • FCN = pretrained classification CNN + 1×1 conv + transposed conv upsampler.
  • All-conv → input size is flexible; the output matches the input exactly when H and W are multiples of 32.
  • Bilinear-initialized transposed conv is the workable starting point; fine-tunes from there.
  • The blueprint behind U-Net (skip connections restore spatial detail lost to downsampling), DeepLab (dilated convs keep feature maps at higher resolution, so less upsampling is needed), and modern segmentation networks.