from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

Network-in-Network (NiN; Lin et al., 2014) introduces two ideas that later architectures widely adopted: a regular convolution followed by two 1×1 convolutions, and global average pooling at the end of the network in place of fully connected layers.
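In each NiN block, the two 1×1 convolutions that follow the regular convolution, with ReLU activations between them, act as a small MLP applied at every spatial position: the “MLP within a conv layer”.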
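The network stacks four NiN blocks with 96, 256, 384, and num_classes output channels. A 3×3 max-pool with stride 2 downsamples after each of the first three blocks, and dropout precedes the last block. Global average pooling followed by a flatten then turns the final feature map into one value per class, so no fully connected layers are needed.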
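To make the role of global average pooling concrete, here is a small sketch using only `jax.numpy`. The feature-map shape is a hypothetical stand-in for the output of the last NiN block, and taking the mean over the two spatial axes is used here in place of a pooling layer whose window covers the whole map.

```python
import jax
from jax import numpy as jnp

# Hypothetical final feature map in NHWC layout:
# batch of 2, 5x5 spatial grid, 10 channels (one per class).
feats = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 5, 10))

# Global average pooling: average each channel over height and width.
# The result already has shape (batch, num_classes), so no dense layer is needed.
logits = jnp.mean(feats, axis=(1, 2))

print(logits.shape)
```

Each class score is just the mean activation of its own channel, which is why the last block must produce exactly num_classes channels.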
import optax
class NiN(d2l.Classifier):
    lr: float = 0.1
    num_classes = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nin_block(96, kernel_size=(11, 11), strides=(4, 4), padding=(0, 0)),
            lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
            nin_block(256, kernel_size=(5, 5), strides=(1, 1), padding=(2, 2)),
            lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
            nin_block(384, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1)),
            lambda x: nn.max_pool(x, (3, 3), strides=(2, 2)),
            # Dropout is active only while training.
            nn.Dropout(0.5, deterministic=not self.training),
            # The last block emits one channel per class.
            nin_block(self.num_classes, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1)),
            lambda x: nn.avg_pool(x, window_shape=x.shape[1:3], strides=x.shape[1:3], padding='valid'),  # global avg pooling
            lambda x: x.reshape((x.shape[0], -1))  # flatten
        ])

    def configure_optimizers(self):
        return optax.sgd(self.lr)

Walking a single-channel 224×224 input of shape (1, 224, 224, 1) through the network shows the spatial dimensions shrinking while the channel count grows, until the final block produces num_classes channels:
Sequential output shape: (1, 54, 54, 96)
function output shape: (1, 26, 26, 96)
Sequential output shape: (1, 26, 26, 256)
function output shape: (1, 12, 12, 256)
Sequential output shape: (1, 12, 12, 384)
function output shape: (1, 5, 5, 384)
Dropout output shape: (1, 5, 5, 384)
Sequential output shape: (1, 5, 5, 10)
function output shape: (1, 1, 1, 10)
function output shape: (1, 10)
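The spatial sizes in the listing above can be checked by hand with the usual output-size formula, floor((n + 2p - k) / s) + 1. The sketch below applies it stage by stage in plain Python; the `stages` table is transcribed from the model definition, and the helper name `out_size` is hypothetical. Dropout and the 1×1 convolutions are omitted because they leave the shape unchanged.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length along one spatial axis for a conv or pooling window."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding, output channels or None to keep channels)
stages = [
    ("nin_block", 11, 4, 0, 96),
    ("max_pool", 3, 2, 0, None),
    ("nin_block", 5, 1, 2, 256),
    ("max_pool", 3, 2, 0, None),
    ("nin_block", 3, 1, 1, 384),
    ("max_pool", 3, 2, 0, None),
    ("nin_block", 3, 1, 1, 10),
]

size, channels = 224, 1  # single-channel 224x224 input
shapes = []
for name, k, s, p, out_ch in stages:
    size = out_size(size, k, s, p)
    channels = out_ch if out_ch is not None else channels
    shapes.append((1, size, size, channels))
    print(f"{name:10s} -> {shapes[-1]}")
```

The final 5×5×10 map is then averaged over its spatial axes to (1, 1, 1, 10) and flattened to (1, 10).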
Training uses the same Trainer as before, with a slightly higher learning rate than for the networks with fully connected layers (there is no dense layer to overfit on small batches).
The important comparison is parameter economy: accuracy comes from richer convolutional blocks, not a large fully connected head.
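A rough parameter count illustrates the point. The sketch below counts weights and biases for the four NiN blocks, assuming a single-channel input and num_classes = 10. The fully connected head it is compared against is hypothetical (an AlexNet-style 6400 → 4096 → 4096 → 10 stack) and is not part of NiN; it is included only to give a sense of scale.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out


def nin_block_params(k, c_in, c_out):
    """A k x k conv followed by two 1x1 convs, all with c_out output channels."""
    return conv_params(k, c_in, c_out) + 2 * conv_params(1, c_out, c_out)


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# (kernel size, input channels, output channels) for the four NiN blocks,
# assuming a single-channel input and num_classes = 10.
blocks = [(11, 1, 96), (5, 96, 256), (3, 256, 384), (3, 384, 10)]
nin_total = sum(nin_block_params(*b) for b in blocks)

# Hypothetical AlexNet-style head for comparison (NOT part of NiN):
# a flattened 5x5x256 feature map -> 4096 -> 4096 -> 10.
fc_head = (dense_params(5 * 5 * 256, 4096)
           + dense_params(4096, 4096)
           + dense_params(4096, 10))

print(f"NiN, all layers:      {nin_total:,}")
print(f"hypothetical FC head: {fc_head:,}")
```

Under these assumptions the entire NiN comes to roughly 2 million parameters, while the hypothetical dense head alone would need about 43 million.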