Block variants

Residual Networks (ResNet) and ResNeXt

ResNet learns residuals

ResNet (He et al., 2015) made networks with well over 100 layers practical to train. The key idea is a skip connection that adds a block's input to its output:

\mathbf{y} = f(\mathbf{x}) + \mathbf{x}.

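Instead of fitting the desired mapping g(\mathbf{x}) directly, the stacked layers only need to fit the residual f(\mathbf{x}) = g(\mathbf{x}) - \mathbf{x}. The identity mapping is always representable (drive f to zero), so in principle adding layers does not shrink what the network can express. In He et al.'s ImageNet experiments, going from 18 to 152 layers did improve accuracy. The skip path also gives gradients a direct identity route backward, which makes very deep stacks much easier to optimize than plain stacks of the same depth.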
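A scalar toy model makes both points concrete. Assume, purely for illustration, that every layer is the linear map x ↦ w·x with the same small weight w. A plain stack of L such layers has end-to-end derivative w^L, while a residual stack y = f(x) + x has derivative (1 + w)^L, and setting w = 0 turns every residual layer into the identity. Real blocks are nonlinear and batch-normalized, so this only illustrates the mechanism; the function names are hypothetical.

```python
import math

def plain_stack_grad(w: float, depth: int) -> float:
    """d(output)/d(input) for `depth` plain toy layers x -> w * x."""
    return w ** depth

def residual_stack_grad(w: float, depth: int) -> float:
    """d(output)/d(input) for `depth` residual toy layers x -> w * x + x."""
    return (1.0 + w) ** depth

def residual_forward(x: float, w: float, depth: int) -> float:
    """Apply `depth` toy residual layers y = f(x) + x with f(x) = w * x."""
    for _ in range(depth):
        x = w * x + x
    return x

w, depth = 0.1, 10
print(plain_stack_grad(w, depth))      # about 1e-10: the signal has almost vanished
print(residual_stack_grad(w, depth))   # about 2.59: the identity path keeps it alive
print(residual_forward(3.0, 0.0, 50))  # 3.0: with f = 0 every block is the identity
```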

Residual block

Plain block (left) vs residual block (right). Skip-add carries the input around the conv stack.

Block in code

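A block with two 3×3 convolutions (each followed by batch normalization) and a skip-add before the final ReLU. An optional 1×1 convolution on the skip path matches changes in channel count or stride: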

from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax

class Residual(nn.Module):
    """The Residual block of ResNet models."""
    num_channels: int
    use_1x1conv: bool = False
    strides: tuple = (1, 1)
    training: bool = True

    def setup(self):
        # Main path: two 3x3 convs; only the first one may downsample.
        self.conv1 = nn.Conv(self.num_channels, kernel_size=(3, 3),
                             padding='same', strides=self.strides)
        self.conv2 = nn.Conv(self.num_channels, kernel_size=(3, 3),
                             padding='same')
        # Auto-enable 1x1 conv when downsampling so the residual shape matches.
        # A channel change at stride 1 still needs use_1x1conv=True.
        if self.use_1x1conv or any(s != 1 for s in self.strides):
            self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                                 strides=self.strides)
        else:
            self.conv3 = None
        # First positional arg is use_running_average: True in eval mode.
        self.bn1 = nn.BatchNorm(not self.training)
        self.bn2 = nn.BatchNorm(not self.training)

    def __call__(self, X):
        Y = nn.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)  # project the skip path to Y's shape
        Y = Y + X
        return nn.relu(Y)

Same shape in, same shape out:

blk = Residual(3)
X = jax.random.normal(d2l.get_key(), (4, 6, 6, 3))
blk.init_with_output(d2l.get_key(), X)[0].shape
(4, 6, 6, 3)

Halve spatial dims and double channels (transition between stages):

blk = Residual(6, use_1x1conv=True, strides=(2, 2))
blk.init_with_output(d2l.get_key(), X)[0].shape
(4, 3, 3, 6)
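The shape bookkeeping behind these two examples fits in a few lines. The sketch assumes Flax's `padding='same'` rule, under which each spatial size becomes ceil(size / stride); `residual_out_shape` and `skip_needs_projection` are hypothetical helpers for illustration only. It also shows why the auto-enable condition in `setup` covers strides but not channel changes: a Flax module's `setup` does not see its input, so a channel change at stride 1 still needs `use_1x1conv=True`.

```python
import math

def residual_out_shape(in_shape, num_channels, strides=(1, 1)):
    """NHWC output shape of a residual block, assuming 'same' padding: ceil(size / stride)."""
    n, h, w, _ = in_shape
    return (n, math.ceil(h / strides[0]), math.ceil(w / strides[1]), num_channels)

def skip_needs_projection(in_shape, num_channels, strides=(1, 1)):
    """True when the raw input cannot be added to f(x) without a 1x1 conv."""
    return residual_out_shape(in_shape, num_channels, strides) != tuple(in_shape)

x_shape = (4, 6, 6, 3)
print(residual_out_shape(x_shape, 3))                  # (4, 6, 6, 3)
print(skip_needs_projection(x_shape, 3))               # False: identity skip works
print(residual_out_shape(x_shape, 6, strides=(2, 2)))  # (4, 3, 3, 6)
print(skip_needs_projection(x_shape, 6))               # True: channels change even at stride 1
```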

The ResNet model

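The body is a sequence of stages, each a stack of residual blocks. Every stage after the first opens with a block that halves the spatial size and changes the channel count; the first stage keeps the stem's resolution: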

ResNet-18: four stages of two residual blocks each, plus stem and head.

ResNet stem

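The stem follows the first layers of GoogLeNet: a 7×7 convolution with 64 output channels and stride 2, then 3×3 max-pooling with stride 2, with batch normalization added after the convolution. Together they cut each spatial side by 4× before the residual stages: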

class ResNet(d2l.Classifier):
    arch: tuple
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def b1(self):
        return nn.Sequential([
            nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            nn.BatchNorm(not self.training), nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                                  padding='same')])
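The stem's spatial reduction can be checked by hand. Assuming the same 'same'-padding rule (a stride-2 layer maps a side length s to ceil(s / 2)), the strided convolution and the strided max-pool together shrink each side by 4×. `stem_out_size` is a hypothetical helper for this arithmetic, not part of d2l.

```python
import math

def stem_out_size(size: int, conv_stride: int = 2, pool_stride: int = 2) -> int:
    """Side length after the stem: strided 7x7 conv, then strided 3x3 max-pool ('same' padding)."""
    after_conv = math.ceil(size / conv_stride)
    return math.ceil(after_conv / pool_stride)

print(stem_out_size(96))   # 24, the 24x24 map in the layer summary
print(stem_out_size(224))  # 56, for the usual ImageNet input size
```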

Residual stages

A stage is a stack of residual blocks. In every stage but the first, the first block uses stride 2 and a 1×1 projection on the skip path; the remaining blocks keep the shape.

@d2l.add_to_class(ResNet)  # attach as a method of ResNet
def block(self, num_residuals, num_channels, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(num_channels, use_1x1conv=True,
                                strides=(2, 2), training=self.training))
        else:
            blk.append(Residual(num_channels, training=self.training))
    return nn.Sequential(blk)
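To see which blocks downsample, the loop in `block` can be mirrored in plain Python. This hypothetical `stage_plan` helper records, for each residual block, its output channels, stride, whether it gets the 1×1 projection, and the resulting spatial size, starting from the stem's 24×24 output for a 96×96 input.

```python
import math

def stage_plan(arch, in_size):
    """List (stage, block, channels, stride, projection, out_size) following `block`'s rule:
    the first block of every stage except the first uses stride 2 and a 1x1 projection."""
    plan, size = [], in_size
    for stage, (num_residuals, num_channels) in enumerate(arch):
        for i in range(num_residuals):
            downsample = (i == 0 and stage != 0)
            stride = 2 if downsample else 1
            size = math.ceil(size / stride)
            plan.append((stage, i, num_channels, stride, downsample, size))
    return plan

for row in stage_plan(((2, 64), (2, 128), (2, 256), (2, 512)), in_size=24):
    print(row)
# first row (0, 0, 64, 1, False, 24), last row (3, 1, 512, 1, False, 3)
```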

ResNet head

After the residual stages, global average pooling collapses the spatial map and the final linear layer predicts classes.

@d2l.add_to_class(ResNet)  # attach as a method of ResNet
def create_net(self):
    net = nn.Sequential([self.b1()])
    for i, b in enumerate(self.arch):
        net.layers.extend([self.block(*b, first_block=(i==0))])
    net.layers.extend([nn.Sequential([
        # Flax does not provide a GlobalAvg2D layer
        lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                              strides=x.shape[1:3], padding='valid'),
        lambda x: x.reshape((x.shape[0], -1)),
        nn.Dense(self.num_classes)])])
    return net
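Global average pooling with a window equal to the whole spatial map is just a per-channel mean over height and width. The pure-Python sketch below applies it to a tiny nested-list feature map (a hypothetical stand-in for a JAX array, laid out H×W×C) and counts the parameters of the final dense layer, assuming `nn.Dense` keeps its default bias. Both helper names are illustrative.

```python
def global_avg_pool(fmap):
    """Mean over H and W of an H x W x C nested list; returns a length-C list."""
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [sum(fmap[i][j][k] for i in range(h) for j in range(w)) / (h * w)
            for k in range(c)]

def dense_param_count(in_features: int, out_features: int, use_bias: bool = True) -> int:
    """Weights plus (optional) biases of a fully connected layer."""
    return in_features * out_features + (out_features if use_bias else 0)

# A 2 x 2 map with 2 channels: channel 0 holds 1..4, channel 1 holds 10..40.
fmap = [[[1.0, 10.0], [2.0, 20.0]],
        [[3.0, 30.0], [4.0, 40.0]]]
print(global_avg_pool(fmap))       # [2.5, 25.0]
print(dense_param_count(512, 10))  # 5130: a 512 -> 10 head with bias
```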

ResNet-18 assembly

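Four stages of two residual blocks each. Changing the number of blocks per stage (and, for ResNet-50 and deeper, switching to a three-convolution bottleneck block) gives the deeper ResNet variants: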

class ResNet18(ResNet):
    arch: tuple = ((2, 64), (2, 128), (2, 256), (2, 512))
    lr: float = 0.1
    num_classes: int = 10

ResNet18(training=False).layer_summary((1, 96, 96, 1))
Sequential output shape:     (1, 24, 24, 64)
Sequential output shape:     (1, 24, 24, 64)
Sequential output shape:     (1, 12, 12, 128)
Sequential output shape:     (1, 6, 6, 256)
Sequential output shape:     (1, 3, 3, 512)
Sequential output shape:     (1, 10)
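The "18" counts weight layers: the stem convolution, two 3×3 convolutions in each of the eight residual blocks, and the final dense layer (1×1 projections and batch norm are not counted). Applying the same count to the standard stage configurations from He et al. recovers the other names: ResNet-34 keeps the two-convolution block, while ResNet-50/101/152 use three-convolution bottleneck blocks. `resnet_depth` is a hypothetical helper.

```python
def resnet_depth(blocks_per_stage, convs_per_block):
    """Weight layers = stem conv + convs in all residual blocks + final dense layer."""
    return 1 + convs_per_block * sum(blocks_per_stage) + 1

configs = {
    "ResNet-18": ((2, 2, 2, 2), 2),    # basic block: two 3x3 convs
    "ResNet-34": ((3, 4, 6, 3), 2),
    "ResNet-50": ((3, 4, 6, 3), 3),    # bottleneck block: 1x1, 3x3, 1x1
    "ResNet-101": ((3, 4, 23, 3), 3),
    "ResNet-152": ((3, 8, 36, 3), 3),
}
for name, (stages, convs) in configs.items():
    print(name, resnet_depth(stages, convs))
```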

Training

model = ResNet18(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)

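This trains ResNet-18, with a 10-class head and single-channel input, on Fashion-MNIST images resized to 96×96; the 32× total downsampling then leaves a 3×3 map before global pooling. The point is to check that the residual-stage template plugs into the same Trainer used by earlier CNNs, not to tune accuracy.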

ResNeXt: width via cardinality

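ResNeXt (Xie et al., 2017) splits the block into C parallel, narrower paths (the cardinality, `groups` in the code) instead of one wide path, implemented with a grouped 3×3 convolution. At a matched parameter and compute budget, Xie et al. report better ImageNet accuracy than the corresponding ResNet: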

class ResNeXtBlock(nn.Module):
    """The ResNeXt block."""
    num_channels: int
    groups: int          # number of groups in the 3x3 conv (cardinality)
    bot_mul: float       # bottleneck width as a fraction of num_channels
    use_1x1conv: bool = False
    strides: tuple = (1, 1)
    training: bool = True

    def setup(self):
        # bot_channels must be divisible by groups.
        bot_channels = int(round(self.num_channels * self.bot_mul))
        # 1x1 conv into the bottleneck width.
        self.conv1 = nn.Conv(bot_channels, kernel_size=(1, 1),
                             strides=(1, 1))
        # Grouped 3x3 conv: each group sees only bot_channels / groups inputs.
        self.conv2 = nn.Conv(bot_channels, kernel_size=(3, 3),
                             strides=self.strides, padding='same',
                             feature_group_count=self.groups)
        # 1x1 conv back to num_channels, mixing information across groups.
        self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                             strides=(1, 1))
        self.bn1 = nn.BatchNorm(not self.training)
        self.bn2 = nn.BatchNorm(not self.training)
        self.bn3 = nn.BatchNorm(not self.training)
        if self.use_1x1conv:
            # Projection on the skip path when channels or stride change.
            self.conv4 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                                 strides=self.strides)
            self.bn4 = nn.BatchNorm(not self.training)
        else:
            self.conv4 = None

    def __call__(self, X):
        Y = nn.relu(self.bn1(self.conv1(X)))
        Y = nn.relu(self.bn2(self.conv2(Y)))
        Y = self.bn3(self.conv3(Y))
        if self.conv4 is not None:
            X = self.bn4(self.conv4(X))
        return nn.relu(Y + X)
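The matched-budget claim can be checked with the template comparison from the ResNeXt paper: a ResNet bottleneck block on 256 channels with a 64-channel bottleneck versus a ResNeXt block with C = 32 paths, each d = 4 channels wide. The sketch counts convolution weights only (biases and batch norm ignored); `bottleneck_weights` and `resnext_weights` are hypothetical helpers.

```python
def bottleneck_weights(c_in: int, width: int) -> int:
    """1x1 reduce + 3x3 + 1x1 expand, one path of the given width (weights only)."""
    return c_in * width + 3 * 3 * width * width + width * c_in

def resnext_weights(c_in: int, cardinality: int, path_width: int) -> int:
    """`cardinality` independent bottleneck paths, each `path_width` channels wide."""
    return cardinality * bottleneck_weights(c_in, path_width)

print(bottleneck_weights(256, 64))  # 69632
print(resnext_weights(256, 32, 4))  # 70144
```

The 32 paths of width 4 are equivalent to one block with 128 bottleneck channels whose 3×3 layer is a grouped convolution with 32 groups, which is the form `ResNeXtBlock` builds with `feature_group_count`.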

Grouped-conv savings

Splitting the 3×3 convolution into g groups divides its weights and FLOPs by g, because each output channel sees only c/g of the input channels. The 1×1 convolutions before and after the grouped layer let information flow across groups.

blk = ResNeXtBlock(32, 16, 1)
X = jnp.zeros((4, 96, 96, 32))
blk.init_with_output(d2l.get_key(), X)[0].shape
(4, 96, 96, 32)
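For the block just built (bottleneck width 32 × 1 = 32, groups = 16), the saving can be counted directly. A grouped k×k convolution has k·k·(c_in / g)·c_out weights, since each output channel connects to only c_in / g inputs. `conv_weights` is a hypothetical helper; biases are excluded.

```python
def conv_weights(k: int, c_in: int, c_out: int, groups: int = 1) -> int:
    """Weight count of a k x k convolution split into `groups` groups (no bias)."""
    assert c_in % groups == 0 and c_out % groups == 0, "channels must divide evenly"
    return k * k * (c_in // groups) * c_out

dense_3x3 = conv_weights(3, 32, 32)               # ordinary 3x3 conv
grouped_3x3 = conv_weights(3, 32, 32, groups=16)  # what feature_group_count=16 gives
print(dense_3x3, grouped_3x3, dense_3x3 // grouped_3x3)  # 9216 576 16
```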

Recap

  • Residual connection: \mathbf{y} = f(\mathbf{x}) + \mathbf{x}, so a block can represent the identity by driving f to zero.
  • Makes very deep networks trainable: 152 layers on ImageNet, and a 1,202-layer network on CIFAR-10 in He et al.'s experiments, although that model overfit and tested worse than a 110-layer one.
  • The residual block + stage template recurs widely: in vision (ResNet, ResNeXt; DenseNet generalizes the idea by concatenating instead of adding) and in language (Transformers wrap every sublayer in a residual connection with LayerNorm).
  • ResNet-50 remains a widely used ImageNet-pretrained backbone for transfer learning.