from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
ResNet (He et al., 2015) is the architecture that finally made very deep networks trainable. The key idea is to add each block's input back onto its output:
$$\mathbf{y} = f(\mathbf{x}) + \mathbf{x}.$$
The function $f$ only needs to learn the residual relative to the identity. Because the identity mapping is always representable, adding more layers should not hurt in principle — and in practice going from 18 to 152 layers genuinely improves accuracy. Gradients also flow through the skip connection unattenuated, which makes very deep networks far easier to train.
Plain block (left) vs residual block (right). Skip-add carries the input around the conv stack.
A 2-conv block with a skip-add. Optional 1×1 conv on the skip path matches channel/stride changes:
class Residual(nn.Module):
    """The Residual block of ResNet models.

    Applies conv -> batch norm -> ReLU -> conv -> batch norm to the input,
    adds the (optionally projected) input back onto the result and applies
    a final ReLU.

    Attributes:
        num_channels: Number of output channels of every convolution.
        use_1x1conv: If True, project the skip path with a 1x1 convolution.
            The projection is also enabled automatically whenever any
            stride differs from 1.
        strides: Strides of the first 3x3 convolution and of the 1x1
            projection on the skip path.
        training: If True, batch norm normalizes with batch statistics;
            otherwise it uses the stored running averages.
    """
    num_channels: int
    use_1x1conv: bool = False
    strides: tuple = (1, 1)
    training: bool = True

    def setup(self):
        self.conv1 = nn.Conv(self.num_channels, kernel_size=(3, 3),
                             padding='same', strides=self.strides)
        self.conv2 = nn.Conv(self.num_channels, kernel_size=(3, 3),
                             padding='same')
        # Auto-enable 1x1 conv when downsampling so the residual shape matches.
        if self.use_1x1conv or any(s != 1 for s in self.strides):
            self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                                 strides=self.strides)
        else:
            self.conv3 = None
        # Pass the flag by keyword so the intent (inference mode uses the
        # running averages) is explicit at the call site.
        self.bn1 = nn.BatchNorm(use_running_average=not self.training)
        self.bn2 = nn.BatchNorm(use_running_average=not self.training)

    def __call__(self, X):
        """Return relu(f(X) + skip(X)) for the input feature map X."""
        Y = nn.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Compare against None explicitly instead of relying on the
        # truthiness of a Module instance.
        if self.conv3 is not None:
            X = self.conv3(X)
        return nn.relu(Y + X)


# Same shape in, same shape out:
(4, 6, 6, 3)
Stages of N residual blocks, with downsampling at the start of every stage except the first (the stem has already reduced the resolution):
ResNet-18: four stages of two residual blocks each, plus stem and head.
The stem does early feature extraction and spatial reduction, similar to AlexNet and GoogLeNet:
class ResNet(d2l.Classifier):
    """ResNet classifier built from a stem, residual stages and a head.

    Attributes:
        arch: One entry per stage. Each entry is unpacked into `block`, so
            it is expected to be `(num_residuals, num_channels)`.
        lr: Learning rate. Not read in this class; presumably consumed by
            `d2l.Classifier` (verify against the base class).
        num_classes: Number of outputs of the final dense layer.
        training: Forwarded to every residual block and, negated, to the
            stem's batch norm as `use_running_average`.
    """
    arch: tuple
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def b1(self):
        """Return the stem: 7x7/2 conv (64 channels), BN, ReLU, 3x3/2 max pool."""
        return nn.Sequential([
            nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            nn.BatchNorm(use_running_average=not self.training),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                                  padding='same')])

    # A stage is a stack of residual blocks. The first block can downsample
    # and project the skip path; later blocks keep shape.
    def block(self, num_residuals, num_channels, first_block=False):
        """Return one stage of `num_residuals` residual blocks.

        Args:
            num_residuals: Number of `Residual` blocks in the stage.
            num_channels: Output channels of every block in the stage.
            first_block: True for the first stage, which does not
                downsample because the stem has already done so.
        """
        return nn.Sequential([
            Residual(num_channels, use_1x1conv=True, strides=(2, 2),
                     training=self.training)
            if i == 0 and not first_block
            else Residual(num_channels, training=self.training)
            for i in range(num_residuals)])

    # After the residual stages, global average pooling collapses the
    # spatial map and the final linear layer predicts classes.
    def create_net(self):
        """Assemble stem, one stage per `arch` entry, and the classifier head."""
        # Collect every layer first and build the Sequential once, rather
        # than mutating the `layers` field of an already constructed module.
        layers = [self.b1()]
        layers.extend(self.block(*stage, first_block=(i == 0))
                      for i, stage in enumerate(self.arch))
        layers.append(nn.Sequential([
            # Flax does not provide a GlobalAvg2D layer
            lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                                  strides=x.shape[1:3], padding='valid'),
            lambda x: x.reshape((x.shape[0], -1)),
            nn.Dense(self.num_classes)]))
        return nn.Sequential(layers)


# Four stages × 2 residual blocks each — same template defines ResNet-34/50/101/152:
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 128)
Sequential output shape: (1, 6, 6, 256)
Sequential output shape: (1, 3, 3, 512)
Sequential output shape: (1, 10)
The notebook trains a compact ResNet-18 variant on Fashion-MNIST; the point is to validate that the residual-stage template plugs into the same Trainer used by earlier CNNs.
ResNeXt is a cleaner variant: each block splits its 3×3 convolution into multiple parallel paths (cardinality C, implemented as a grouped convolution) instead of one wide path — same parameter budget, better accuracy:
class ResNeXtBlock(nn.Module):
    """The ResNeXt block.

    A bottleneck residual block: 1x1 conv -> grouped 3x3 conv -> 1x1 conv,
    each followed by batch norm, with a skip connection that is optionally
    projected by a 1x1 conv plus batch norm.

    Attributes:
        num_channels: Number of output channels of the block.
        groups: Number of groups of the 3x3 convolution (passed as
            `feature_group_count`); must divide the bottleneck width.
        bot_mul: Multiplier applied to `num_channels` to obtain the
            bottleneck width, `int(round(num_channels * bot_mul))`.
        use_1x1conv: If True, project the skip path with a 1x1 convolution
            and batch norm. The projection is also enabled automatically
            whenever any stride differs from 1.
        strides: Strides of the grouped 3x3 convolution and of the skip
            projection.
        training: If True, batch norm normalizes with batch statistics;
            otherwise it uses the stored running averages.
    """
    num_channels: int
    groups: int
    bot_mul: float
    use_1x1conv: bool = False
    strides: tuple = (1, 1)
    training: bool = True

    def setup(self):
        bot_channels = int(round(self.num_channels * self.bot_mul))
        # Fail early with a readable message instead of deep inside the
        # grouped convolution at call time.
        if self.groups < 1 or bot_channels % self.groups != 0:
            raise ValueError(
                f'groups ({self.groups}) must be a positive divisor of the '
                f'bottleneck width ({bot_channels})')
        self.conv1 = nn.Conv(bot_channels, kernel_size=(1, 1),
                             strides=(1, 1))
        self.conv2 = nn.Conv(bot_channels, kernel_size=(3, 3),
                             strides=self.strides, padding='same',
                             feature_group_count=self.groups)
        self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                             strides=(1, 1))
        self.bn1 = nn.BatchNorm(use_running_average=not self.training)
        self.bn2 = nn.BatchNorm(use_running_average=not self.training)
        self.bn3 = nn.BatchNorm(use_running_average=not self.training)
        # Auto-enable the projection when downsampling so the skip path
        # matches the main path's spatial shape, as Residual does.
        if self.use_1x1conv or any(s != 1 for s in self.strides):
            self.conv4 = nn.Conv(self.num_channels, kernel_size=(1, 1),
                                 strides=self.strides)
            self.bn4 = nn.BatchNorm(use_running_average=not self.training)
        else:
            self.conv4 = None

    def __call__(self, X):
        """Return relu(f(X) + skip(X)) for the input feature map X."""
        Y = nn.relu(self.bn1(self.conv1(X)))
        Y = nn.relu(self.bn2(self.conv2(Y)))
        Y = self.bn3(self.conv3(Y))
        # Compare against None explicitly instead of relying on the
        # truthiness of a Module instance.
        if self.conv4 is not None:
            X = self.bn4(self.conv4(X))
        return nn.relu(Y + X)


# Grouped convolution cuts the expensive 3×3 channel mixing by a factor of
# groups, while surrounding 1×1 convolutions let information mix before and
# after the grouped work.
(4, 96, 96, 32)