from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax

GoogLeNet (Szegedy et al., 2014), winner of the ImageNet 2014 challenge, introduces a different design axis: width, not just depth.
Its building block is the Inception block, which runs several filter sizes in parallel (1×1, 3×3, 5×5, plus a pooling branch) and concatenates their outputs. The network can then learn, block by block, which filter scale is most useful.
Heavy use of 1×1 convs as bottleneck reductions keeps the parameter count manageable despite the multi-branch design.
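To make the saving concrete, the following back-of-the-envelope sketch counts the parameters of a 5×5 branch with and without a 1×1 reduction. The channel numbers are taken from the first stage-4 Inception block defined later in this section (480 input channels, `c3 = (16, 48)`); the helper `conv_params` is a hypothetical name introduced only for this illustration, and the count assumes each convolution has one bias per output channel.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution (one bias per output channel)."""
    return k * k * c_in * c_out + c_out

c_in = 480              # channels entering stage 4 (see the layer summary)
c_mid, c_out = 16, 48   # c3 of the first stage-4 Inception block

# A 5x5 convolution applied directly to all 480 input channels
direct = conv_params(5, c_in, c_out)

# The Inception version: 1x1 reduction to 16 channels, then the 5x5 convolution
bottleneck = conv_params(1, c_in, c_mid) + conv_params(5, c_mid, c_out)

print(direct, bottleneck)  # 576048 26944
```

Under these assumptions the reduced branch needs roughly 20 times fewer parameters than the direct 5×5 convolution.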
An Inception block has four parallel branches that all preserve the spatial size; their outputs are concatenated along the channel axis:

class Inception(nn.Module):
    # `c1`--`c4` are the number of output channels for each branch
    c1: int
    c2: tuple
    c3: tuple
    c4: int

    def setup(self):
        # Branch 1: 1x1 convolution
        self.b1_1 = nn.Conv(self.c1, kernel_size=(1, 1))
        # Branch 2: 1x1 reduction, then 3x3 convolution
        self.b2_1 = nn.Conv(self.c2[0], kernel_size=(1, 1))
        self.b2_2 = nn.Conv(self.c2[1], kernel_size=(3, 3), padding='same')
        # Branch 3: 1x1 reduction, then 5x5 convolution
        self.b3_1 = nn.Conv(self.c3[0], kernel_size=(1, 1))
        self.b3_2 = nn.Conv(self.c3[1], kernel_size=(5, 5), padding='same')
        # Branch 4: 3x3 max-pooling (stride 1), then 1x1 convolution
        self.b4_1 = lambda x: nn.max_pool(x, window_shape=(3, 3),
                                          strides=(1, 1), padding='same')
        self.b4_2 = nn.Conv(self.c4, kernel_size=(1, 1))

    def __call__(self, x):
        b1 = nn.relu(self.b1_1(x))
        b2 = nn.relu(self.b2_2(nn.relu(self.b2_1(x))))
        b3 = nn.relu(self.b3_2(nn.relu(self.b3_1(x))))
        b4 = nn.relu(self.b4_2(self.b4_1(x)))
        # Channels are the last axis in Flax (NHWC layout)
        return jnp.concatenate((b1, b2, b3, b4), axis=-1)

The network consists of five sequential "stages", each a small stack of convolution, pooling, and Inception layers. The stem and second stage reduce resolution quickly before the Inception blocks take over:

class GoogleNet(d2l.Classifier):
    lr: float = 0.1
    num_classes: int = 10

    def setup(self):
        self.net = nn.Sequential([self.b1(), self.b2(), self.b3(), self.b4(),
                                  self.b5(), nn.Dense(self.num_classes)])

    def b1(self):
        return nn.Sequential([
            nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
                                  padding='same')])

Stage 3 introduces the repeating pattern: two Inception blocks, then pooling. Channel counts are split across branches, then concatenated back together.
Stage 4 is the compute-heavy middle of the network: five Inception blocks before the next spatial downsample.

@d2l.add_to_class(GoogleNet)
def b4(self):
    return nn.Sequential([Inception(192, (96, 208), (16, 48), 64),
                          Inception(160, (112, 224), (24, 64), 64),
                          Inception(128, (128, 256), (24, 64), 64),
                          Inception(112, (144, 288), (32, 64), 64),
                          Inception(256, (160, 320), (32, 128), 128),
                          lambda x: nn.max_pool(x, window_shape=(3, 3),
                                                strides=(2, 2),
                                                padding='same')])

Stage 5 ends with global average pooling before the final classifier. The setup method then simply wires b1 through b5 together, followed by the dense output layer.

@d2l.add_to_class(GoogleNet)
def b5(self):
    return nn.Sequential([Inception(256, (160, 320), (32, 128), 128),
                          Inception(384, (192, 384), (48, 128), 128),
                          # Flax does not provide a GlobalAvgPool2D layer, so
                          # pool with a window covering the whole feature map
                          lambda x: nn.avg_pool(x,
                                                window_shape=x.shape[1:3],
                                                strides=x.shape[1:3],
                                                padding='valid'),
                          # Flatten (N, 1, 1, C) to (N, C)
                          lambda x: x.reshape((x.shape[0], -1))])

For Fashion-MNIST we use 96×96 inputs instead of the 224×224 ImageNet resolution to keep training time reasonable. The layer summary on the smaller input:
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 192)
Sequential output shape: (1, 6, 6, 480)
Sequential output shape: (1, 3, 3, 832)
Sequential output shape: (1, 1024)
Dense output shape: (1, 10)
Notice the pattern: spatial resolution falls at pools, while channel depth grows after concatenating each Inception block’s branches.
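The shapes in the summary can be reproduced by hand. The sketch below is plain arithmetic, not the model itself: it assumes that every 'same'-padded stride-2 layer maps a side length n to ceil(n / 2), that stage 1 contains two such layers (its convolution and its pool), and that stages 2 through 4 each contain one (their final pool). The helper names `same_out` and `inception_channels` are hypothetical and introduced only for this illustration.

```python
import math

def same_out(size, stride):
    """Output side length of a 'same'-padded conv or pool: ceil(size / stride)."""
    return math.ceil(size / stride)

def inception_channels(c1, c2, c3, c4):
    """Channels after concatenation: the final conv of each of the four branches."""
    return c1 + c2[1] + c3[1] + c4

size = 96
sizes = []
# Number of stride-2 layers per stage: b1 has two, b2..b4 have one each
for n_halvings in (2, 1, 1, 1):
    for _ in range(n_halvings):
        size = same_out(size, 2)
    sizes.append(size)
print(sizes)  # [24, 12, 6, 3]

# Output channels of the last Inception block in stage 4 and in stage 5
print(inception_channels(256, (160, 320), (32, 128), 128))  # 832
print(inception_channels(384, (192, 384), (48, 128), 128))  # 1024
```

The spatial sizes match the first four rows of the summary, and the channel sums match the 832 and 1024 channels reported after stages 4 and 5.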
The original GoogLeNet has 22 weighted layers and about 7M parameters, far fewer parameters than VGG (about 138M), yet it achieves better ImageNet accuracy.