from d2l import jax as d2l
from flax import linen as nn

We’ve seen a sequence of hand-designed architectures (LeNet → AlexNet → VGG → GoogLeNet → ResNet → DenseNet), each a hypothesis about what makes networks work.
Can we design networks more systematically?
RegNet (Radosavovic et al., 2020) starts from a design space (AnyNet): same template, free hyperparameters.
Simple closed-form rules (“width grows linearly with stage”) outperform years of expert tuning.
The AnyNet design space.
Stem (low-level conv) → 4 stages of residual blocks → head (global pool + linear). Each stage’s depth, width, group count, and bottleneck ratio are free parameters.
The stem is deliberately plain: one stride-2 3×3 convolution, BatchNorm, ReLU. Its job is to halve resolution and create the first feature channels before the repeated stages begin.
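As a quick check of the "halve resolution" claim, the standard convolution output-size formula can be evaluated directly. This is a minimal sketch: the helper name `conv_out_size` and the 96-pixel and 28-pixel inputs are illustrative choices, not part of the original text.

```python
def conv_out_size(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one spatial axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


# Stem settings from the text: 3x3 kernel, stride 2, padding 1.
for n in (96, 28):  # illustrative input sizes
    print(n, "->", conv_out_size(n, kernel=3, stride=2, padding=1))
```

With these settings an even input size is cut exactly in half, and an odd one is rounded up.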
class AnyNet(d2l.Classifier):
    arch: tuple
    stem_channels: int
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def stem(self, num_channels):
        return nn.Sequential([
            # Stride-2 3x3 convolution halves the spatial resolution.
            nn.Conv(num_channels, kernel_size=(3, 3), strides=(2, 2),
                    padding=(1, 1)),
            # First argument is use_running_average: True only at eval time.
            nn.BatchNorm(not self.training),
            nn.relu
        ])

Each stage repeats the same ResNeXt block. The first block uses stride 2 and a 1×1 skip projection to change resolution and channel count; the rest preserve shape.
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
    blk = []
    for i in range(depth):
        if i == 0:
            # First block: downsample and project the skip connection.
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                use_1x1conv=True, strides=(2, 2), training=self.training))
        else:
            # Remaining blocks keep resolution and channel count.
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                training=self.training))
    return nn.Sequential(blk)

The architecture tuple supplies (depth, channels, groups, bottleneck) per stage. The head is the now-standard global average pool + linear classifier.
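To make the tuple format concrete, the following sketch traces feature-map shapes through a stem and a sequence of stages using plain integer arithmetic. The two-stage tuple, the 96×96 input, and the helper names `halve` and `trace_shapes` are assumptions chosen for illustration; no network is built.

```python
def halve(size: int) -> int:
    """Spatial size after one stride-2 layer (3x3 with padding 1, or 1x1)."""
    return (size - 1) // 2 + 1


def trace_shapes(arch, stem_channels, height, width):
    """Return (name, height, width, channels) after the stem and each stage."""
    height, width = halve(height), halve(width)
    shapes = [("stem", height, width, stem_channels)]
    for i, (depth, channels, groups, bot_mul) in enumerate(arch):
        # Only the first block of a stage has stride 2, so each stage
        # halves the resolution once, whatever its depth.
        height, width = halve(height), halve(width)
        shapes.append((f"stage{i + 1}", height, width, channels))
    return shapes


# Hypothetical architecture: (depth, channels, groups, bot_mul) per stage.
arch = ((4, 32, 16, 1), (6, 80, 16, 1))
for row in trace_shapes(arch, stem_channels=32, height=96, width=96):
    print(row)
```

Depth changes how many blocks run at a given resolution, but not the output shape of the stage; only the stage count and the channel entries do.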
@d2l.add_to_class(AnyNet)
def create_net(self):
    net = nn.Sequential([self.stem(self.stem_channels)])
    for s in self.arch:
        net.layers.extend([self.stage(*s)])
    # Head: global average pool, flatten, then a linear classifier.
    net.layers.extend([nn.Sequential([
        lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                              strides=x.shape[1:3], padding='valid'),
        lambda x: x.reshape((x.shape[0], -1)),
        nn.Dense(self.num_classes)])])
    return net

Comparing error empirical distribution functions of design spaces.
RegNet narrows AnyNet with simple constraints: stage widths grow approximately linearly, bottleneck ratios stay fixed, and group widths are shared across stages. The result is a smaller search space in which a randomly sampled model is more likely to be good.
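The linear-width constraint can be sketched in a few lines. The version below follows the quantized linear rule as described in the RegNet paper: a per-block target width u_j = w_0 + w_a·j is snapped to the nearest power of a multiplier w_m, and runs of equal widths become stages. Rounding to multiples of 8 and the parameter values used in the call are assumptions for illustration.

```python
import math


def regnet_widths(w_0, w_a, w_m, depth, q=8):
    """Quantized linear widths; returns (stage_widths, stage_depths)."""
    block_widths = []
    for j in range(depth):
        u_j = w_0 + w_a * j  # linear target width for block j
        # Nearest integer exponent s_j such that w_0 * w_m**s_j ~ u_j.
        s_j = round(math.log(u_j / w_0) / math.log(w_m))
        w_j = w_0 * w_m ** s_j
        block_widths.append(int(round(w_j / q)) * q)  # multiple of q (assumed)

    # Consecutive blocks that share a width form one stage.
    stage_widths, stage_depths = [], []
    for w in block_widths:
        if stage_widths and stage_widths[-1] == w:
            stage_depths[-1] += 1
        else:
            stage_widths.append(w)
            stage_depths.append(1)
    return stage_widths, stage_depths


# Illustrative parameters only.
widths, depths = regnet_widths(w_0=24, w_a=36.44, w_m=2.49, depth=13)
print(widths, depths)
```

Three continuous parameters plus a total depth replace a separate width and depth choice for every stage, which is where the reduction in search space comes from.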
The paper’s empirical findings reduce to three rules: width grows linearly with stage, depth stays roughly constant, and blocks use ResNeXt-style groups. A scaled-down version is enough for Fashion-MNIST.
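One way to write down such a scaled-down configuration and sanity-check it against the constraints is sketched here. The two-stage tuple and its values are illustrative assumptions rather than the original listing, and the helper `satisfies_regnet_constraints` is hypothetical. It checks only that widths do not decrease, which is weaker than the linear rule.

```python
def satisfies_regnet_constraints(arch):
    """Check a tuple of (depth, channels, groups, bot_mul) stages."""
    widths = [channels for _, channels, _, _ in arch]
    groups = {g for _, _, g, _ in arch}
    bot_muls = {b for _, _, _, b in arch}
    shared_groups = len(groups) == 1        # same group setting in every stage
    shared_bottleneck = len(bot_muls) == 1  # same bottleneck ratio everywhere
    # Weak stand-in for "widths grow linearly": widths never shrink.
    widths_grow = all(a <= b for a, b in zip(widths, widths[1:]))
    return shared_groups and shared_bottleneck and widths_grow


# Hypothetical scaled-down architecture with two stages.
small_arch = ((4, 32, 16, 1), (6, 80, 16, 1))
print(satisfies_regnet_constraints(small_arch))
```

A tuple that passes could then be supplied as the `arch` field of `AnyNet`, with a matching `stem_channels` value.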
The architecture is competitive with hand-designed ResNets at similar parameter counts, and the discovery process scales readily with compute.