Fully Convolutional Networks

A fully convolutional network (Long, Shelhamer, Darrell 2015) is the simplest path to per-pixel prediction:

  1. Start with a pretrained classification CNN (here, ResNet-18).
  2. Strip the global average pool + final dense layer.
  3. Replace with a 1×1 conv mapping to num_classes.
  4. Upsample back to input resolution via transposed conv.

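No fully connected layers remain, so the network is not tied to one input size and outputs a class-score map at input resolution. With the 32× head used here, the match is exact when height and width are multiples of 32.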
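As a rough illustration of the four steps above, the sketch below builds a toy FCN. The two-layer body is a hypothetical stand-in for a pretrained CNN (it is untrained and downsamples by only 4×), and the name TinyFCN is an assumption made for this example. Only the overall shape of the pipeline matches the recipe.

```python
import torch
from torch import nn


class TinyFCN(nn.Module):
    """Toy FCN: strided conv body -> 1x1 conv -> transposed conv (hypothetical)."""

    def __init__(self, num_classes=5):
        super().__init__()
        # Stand-in backbone: two stride-2 convs, so features are H/4 x W/4.
        self.body = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1), nn.ReLU())
        # 1x1 conv maps feature channels to per-class scores.
        self.classifier = nn.Conv2d(16, num_classes, kernel_size=1)
        # kernel = 2 * stride and padding = stride / 2 give an exact 4x upsample.
        self.upsample = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=8, stride=4, padding=2)

    def forward(self, x):
        return self.upsample(self.classifier(self.body(x)))


model = TinyFCN()
shapes = []
for h, w in [(64, 96), (32, 48)]:
    with torch.no_grad():
        shapes.append(tuple(model(torch.rand(1, 3, h, w)).shape))
print(shapes)  # [(1, 5, 64, 96), (1, 5, 32, 48)]
```

The same weights handle both input sizes because every layer is a convolution. The output always has one score channel per class at the input's height and width.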

Architecture

FCN: pretrained CNN body + 1×1 conv → class scores → transposed conv to upsample.

Setup

%matplotlib inline
from d2l import torch as d2l
import torch
import torchvision
from torch import nn
from torch.nn import functional as F

ResNet-18 pretrained on ImageNet. Drop the head (avg pool + dense) and keep the conv body, which produces an H/32 × W/32 feature map. The last three children show where the cut goes:

pretrained_net = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
list(pretrained_net.children())[-3:]
[Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
...
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ),
 AdaptiveAvgPool2d(output_size=(1, 1)),
 Linear(in_features=512, out_features=1000, bias=True)]

Building the FCN

After removing the classifier head, the backbone produces a low-resolution feature map. The new FCN head must restore the original spatial resolution while changing channels to class logits.

net = nn.Sequential(*list(pretrained_net.children())[:-2])
X = torch.rand(size=(1, 3, 320, 480))
net(X).shape
torch.Size([1, 512, 10, 15])
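The 10 × 15 result can be checked by hand. The sketch below applies the standard convolution output-size formula to the five ResNet-18 layers that reduce resolution. The helper names conv_out and backbone_size are made up for this illustration, and the (kernel, stride, padding) values are those of the torchvision ResNet-18 layers.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a conv or pooling layer along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) of the layers in ResNet-18 that reduce
# resolution. layer1 is omitted because it keeps the spatial size unchanged.
DOWNSAMPLING_LAYERS = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]


def backbone_size(n):
    """Spatial size after all five stride-2 stages."""
    for _name, k, s, p in DOWNSAMPLING_LAYERS:
        n = conv_out(n, k, s, p)
    return n


print(backbone_size(320), backbone_size(480))  # 10 15
```

Five halvings give the overall factor of 32: 320 / 32 = 10 and 480 / 32 = 15.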

The class & upsampling head

1×1 conv: num_features (512) → num_classes (21 for VOC). Then a transposed conv upsamples by 32× to recover input resolution:

num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
                                    kernel_size=64, padding=16, stride=32))
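To see why kernel_size=64, padding=16, stride=32 recovers the input size, the sketch below evaluates the transposed-convolution output-size formula, assuming dilation 1 and no output_padding. The helper name conv_transpose_out is an assumption for this example.

```python
def conv_transpose_out(n, kernel, stride, padding):
    """Output length of a transposed conv (dilation 1, no output_padding)."""
    return (n - 1) * stride - 2 * padding + kernel


# The 10 x 15 feature map from the backbone, upsampled by the FCN head.
height = conv_transpose_out(10, kernel=64, stride=32, padding=16)
width = conv_transpose_out(15, kernel=64, stride=32, padding=16)
print(height, width)  # 320 480

# More generally, kernel = 2s and padding = s // 2 give exactly n * s
# for an even stride s.
for s in (2, 4, 32):
    assert conv_transpose_out(7, kernel=2 * s, stride=s, padding=s // 2) == 7 * s
```

The same kernel = 2s, padding = s/2 pattern appears in the 2× sanity check later (kernel 4, padding 1, stride 2).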

Bilinear init for transposed conv

A randomly initialized 32× upsampler is hard to train. Initialize it as bilinear interpolation instead: a sensible starting point that training can then refine.

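def bilinear_kernel(in_channels, out_channels, kernel_size):
    # Assumes in_channels == out_channels: only the channel-to-same-channel
    # kernels are filled, so channels are upsampled independently.
    factor = (kernel_size + 1) // 2
    # Center of the kernel in index coordinates (half-integer for even sizes)
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (torch.arange(kernel_size).reshape(-1, 1),
          torch.arange(kernel_size).reshape(1, -1))
    # Separable triangle filter: weight falls off linearly from the center
    filt = (1 - torch.abs(og[0] - center) / factor) * \
           (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight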
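For intuition about the values this produces, the sketch below recomputes the filter in one dimension with plain Python. The helper name bilinear_weights_1d is hypothetical; it mirrors the factor and center arithmetic of bilinear_kernel.

```python
def bilinear_weights_1d(kernel_size):
    """1-D version of the filter computed inside bilinear_kernel."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    return [1 - abs(i - center) / factor for i in range(kernel_size)]


w4 = bilinear_weights_1d(4)
w3 = bilinear_weights_1d(3)
print(w4)  # [0.25, 0.75, 0.75, 0.25]
print(w3)  # [0.5, 1.0, 0.5]

# The 2-D filter is the outer product of the 1-D weights with themselves.
filt = [[a * b for b in w4] for a in w4]
print(filt[0])  # [0.0625, 0.1875, 0.1875, 0.0625]
```

The weights form a triangle that peaks at the kernel center, which is the weighting linear interpolation uses between neighboring samples.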

Upsampling sanity check

Apply the initialized transposed convolution to an image. The output should be larger but visually similar, because the kernel starts as bilinear interpolation rather than random noise:

conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
                                bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));
img = torchvision.transforms.ToTensor()(d2l.Image.open('../img/catdog.jpg'))
X = img.unsqueeze(0)
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()
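The claim that this layer interpolates can be checked on a tiny 1-D signal without PyTorch. The sketch below is a hand-written 1-D transposed convolution (the function name conv_transpose_1d is an assumption for this example) applied to a ramp with the kernel-size-4 bilinear weights.

```python
def conv_transpose_1d(x, w, stride, padding):
    """Scatter each input value, scaled by the kernel, into the output."""
    full = [0.0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for k, wk in enumerate(w):
            full[i * stride + k] += xi * wk
    # Padding in a transposed conv crops the borders of the full output.
    return full[padding:len(full) - padding]


ramp = [0.0, 1.0, 2.0, 3.0]
weights = [0.25, 0.75, 0.75, 0.25]  # 1-D bilinear weights for kernel_size=4
up = conv_transpose_1d(ramp, weights, stride=2, padding=1)
print(up)  # [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 2.25]
```

The output is twice as long, and its interior values lie on the straight line through the inputs. The last value dips because no neighbor exists beyond the border to contribute.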

Bilinear init (cont.)

The printed shapes confirm the 2× scale-up in height and width. The same bilinear construction, now with a 64×64 kernel, then initializes the FCN’s final upsampling layer:

d2l.set_figsize()
print('input image shape:', img.permute(1, 2, 0).shape)
d2l.plt.imshow(img.permute(1, 2, 0));
print('output image shape:', out_img.shape)
d2l.plt.imshow(out_img);

input image shape: torch.Size([561, 728, 3])
output image shape: torch.Size([1122, 1456, 3])

W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);

Loading data

batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)
read 1114 examples
read 1078 examples

Training

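Pixel-level cross-entropy: reduction='none' keeps one loss value per pixel, and the two mean(1) calls average over height and width to give one loss per image. The code below fine-tunes all parameters, backbone included; a common alternative is to freeze the backbone and train only the new head. Five epochs already give reasonable results: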

def loss(inputs, targets):
    return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)

num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

loss 0.426, train acc 0.867, test acc 0.851
371.7 examples/sec on [device(type='cuda', index=0)]
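For the freeze-the-backbone variant, one possible approach is sketched below on a hypothetical miniature network. The toy_net layers are stand-ins invented for this example; only the head names final_conv and transpose_conv are taken from the code above.

```python
import torch
from torch import nn

# Hypothetical miniature of the FCN: a "backbone" plus the two head layers.
toy_net = nn.Sequential()
toy_net.add_module('backbone', nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1), nn.ReLU()))
toy_net.add_module('final_conv', nn.Conv2d(8, 4, kernel_size=1))
toy_net.add_module('transpose_conv', nn.ConvTranspose2d(
    4, 4, kernel_size=4, stride=2, padding=1))

# Freeze every parameter that does not belong to the head.
head_prefixes = ('final_conv', 'transpose_conv')
for name, param in toy_net.named_parameters():
    if not name.startswith(head_prefixes):
        param.requires_grad_(False)

# Hand only the trainable parameters to the optimizer.
trainable = [p for p in toy_net.parameters() if p.requires_grad]
trainer = torch.optim.SGD(trainable, lr=0.001, weight_decay=1e-3)

trainable_names = [n for n, p in toy_net.named_parameters() if p.requires_grad]
print(trainable_names)
```

Only the four head tensors (weight and bias of each head layer) remain trainable. One caveat: BatchNorm layers in a frozen backbone still update their running statistics while the module is in training mode.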

Predict

Run the network on test images, take argmax over the class dimension, map class indices back to RGB:

def predict(img):
    X = test_iter.dataset.normalize_image(img).unsqueeze(0)
    pred = net(X.to(devices[0])).argmax(dim=1)
    return pred.reshape(pred.shape[1], pred.shape[2])
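The argmax step is easier to see on a tiny tensor. The scores below are made up for illustration: one image, three classes, a 2 × 2 grid.

```python
import torch

# Hypothetical scores for 1 image, 3 classes, on a 2 x 2 grid: (N, C, H, W).
scores = torch.tensor([[[[0.1, 2.0], [0.3, 0.2]],    # class 0
                        [[1.5, 0.4], [0.1, 0.9]],    # class 1
                        [[0.2, 0.3], [2.2, 0.5]]]])  # class 2

pred = scores.argmax(dim=1)  # (N, H, W): highest-scoring class per pixel
pred = pred.reshape(pred.shape[1], pred.shape[2])  # drop the batch dimension
print(pred.tolist())  # [[1, 0], [2, 1]]
```

Each output entry is the index of the class with the largest score at that pixel, which is what predict returns for a full image.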

Visualize segmentation masks

The output grid is image, prediction, ground truth. Expect coarse boundaries: this plain FCN upsamples from a 32× downsampled feature map and has no skip connections.

def label2image(pred):
    # Look up the RGB color of each predicted class index
    colormap = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    X = pred.long()
    return colormap[X, :]

voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
    crop_rect = (0, 0, 320, 480)  # top, left, height, width
    X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
    pred = label2image(predict(X))
    imgs += [X.permute(1, 2, 0), pred.cpu(),
             torchvision.transforms.functional.crop(
                 test_labels[i], *crop_rect).permute(1, 2, 0)]
# Reorder into three rows: images, predictions, ground-truth labels
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);
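The lookup inside label2image relies on integer-tensor indexing. A minimal sketch follows, using only the first three VOC palette entries and a made-up 2 × 2 prediction:

```python
import torch

# First three entries of the VOC palette: background, aeroplane, bicycle.
colormap = torch.tensor([[0, 0, 0], [128, 0, 0], [0, 128, 0]])
pred = torch.tensor([[1, 0], [2, 1]])  # hypothetical (H, W) class indices

# Indexing with an integer tensor picks one RGB row per pixel.
rgb = colormap[pred.long(), :]
print(rgb.shape)           # torch.Size([2, 2, 3])
print(rgb[1, 0].tolist())  # [0, 128, 0]
```

The result has shape (H, W, 3), which is the channel-last layout that image plotting expects.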

Recap

  • FCN = pretrained classification CNN + 1×1 conv + transposed conv upsampler.
  • All-conv → input size is not fixed; multiples of 32 give an output of exactly the input size.
  • Bilinear-initialized transposed conv is a workable starting point that training then refines.
  • The blueprint behind U-Net (skip connections fix the blur), DeepLab (dilated convs avoid the heavy upsampling), and modern segmentation networks.