```python
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
```

DenseNet (Huang et al., 2017) takes the residual idea one step further. ResNet adds a block's input to its output; DenseNet instead concatenates the two along the channel axis.
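Layer $\ell$ of a dense block computes

$$\mathbf{x}_\ell = f_\ell\bigl([\mathbf{x}_0, \mathbf{x}_1, \dots, \mathbf{x}_{\ell-1}]\bigr),$$

where $[\cdot]$ denotes concatenation along the channel axis and $f_\ell$ is the $\ell$-th conv block. Every layer in a dense block therefore sees the block input $\mathbf{x}_0$ together with the outputs of all preceding layers.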
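To make the recurrence concrete, here is a minimal sketch in plain `jax.numpy`. It assumes channels-last arrays. Each $f_\ell$ is a hypothetical stand-in rather than DenseNet's real conv block: a random 1×1 linear map over channels followed by ReLU. Each layer reads the concatenation of all earlier outputs and contributes `growth` new channels.

```python
import jax
import jax.numpy as jnp


def dense_forward(x0, weights):
    """Apply x_l = f_l([x_0, ..., x_{l-1}]) for each layer; return [x_0, ..., x_L].

    Each f_l is a hypothetical stand-in: a 1x1 linear map over channels plus ReLU.
    """
    features = [x0]  # x_0, x_1, ... kept in order
    for w in weights:
        inp = jnp.concatenate(features, axis=-1)  # [x_0, ..., x_{l-1}] along channels
        features.append(jax.nn.relu(inp @ w))     # x_l has w.shape[1] channels
    return jnp.concatenate(features, axis=-1)


def init_weights(key, in_channels, growth, num_layers):
    """One (inputs, growth) matrix per layer; layer l sees in_channels + l * growth inputs."""
    keys = jax.random.split(key, num_layers)
    return [jax.random.normal(k, (in_channels + layer * growth, growth)) * 0.1
            for layer, k in enumerate(keys)]


x0 = jnp.ones((2, 8, 8, 3))  # (batch, height, width, channels)
ws = init_weights(jax.random.PRNGKey(0), in_channels=3, growth=4, num_layers=3)
out = dense_forward(x0, ws)
print(out.shape)  # (2, 8, 8, 15): 3 input channels + 3 layers * 4 new channels
```

The input width of each weight matrix grows by `growth` per layer. This is the linear channel growth that the rest of the section has to manage.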
A DenseNet alternates two components. Dense blocks grow the channel count through concatenation. Transition layers (a 1×1 convolution plus pooling) reduce the channel count again between blocks.
Pros: strong feature reuse, and fewer parameters than ResNet for similar accuracy. Cons: the channel count grows linearly with depth inside a block, which drives up memory use. The transition layers keep this growth in check.
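The channel bookkeeping can be checked by hand. The plain-Python sketch below assumes the defaults of the `DenseNet` class in this section: a 64-channel stem, growth rate 32, and four dense blocks of four conv blocks each. It also assumes each transition halves the channel count with integer division.

```python
def channel_schedule(stem_channels, growth_rate, arch):
    """Return (channels after each dense block, channels after each transition)."""
    channels = stem_channels
    after_blocks, after_transitions = [], []
    for i, num_convs in enumerate(arch):
        channels += num_convs * growth_rate  # each conv block concatenates growth_rate channels
        after_blocks.append(channels)
        if i != len(arch) - 1:               # no transition after the last block
            channels //= 2                   # the 1x1 conv halves the channel count
            after_transitions.append(channels)
    return after_blocks, after_transitions


blocks, transitions = channel_schedule(64, 32, (4, 4, 4, 4))
print(blocks)       # [192, 224, 240, 248]
print(transitions)  # [96, 112, 120]
```

Without the transitions, the same four blocks would end at 64 + 16 × 32 = 576 channels instead of 248.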
The basic unit is a small conv block (batch norm → ReLU → 3×3 convolution). A DenseBlock applies it repeatedly.
Now stack the conv blocks. After each one, concatenate its output onto the running input along the channel axis, so that later blocks see everything computed so far.
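For example, apply DenseBlock(num_convs=2, num_channels=10) to a batch of shape (4, 8, 8, 3), laid out as (batch, height, width, channels) since Flax is channels-last. The block adds num_convs * num_channels = 2 * 10 = 20 channels, so the output has 3 + 20 = 23 channels: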
(4, 8, 8, 23)
A transition layer stops the channel explosion between dense blocks. A 1×1 convolution reduces the channel count (DenseNet halves it), and a 2×2 average pool with stride 2 halves the height and width.
The full model follows the standard pipeline: stem → dense block → transition → dense block → transition → … → global average pool → linear layer.
```python
class DenseNet(d2l.Classifier):
    num_channels: int = 64      # channels produced by the stem
    growth_rate: int = 32       # channels added by each conv block
    arch: tuple = (4, 4, 4, 4)  # number of conv blocks in each dense block
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def b1(self):
        # Stem: 7x7 conv, batch norm, ReLU, 3x3 max pool.
        # Use num_channels so the channel bookkeeping in create_net stays correct
        return nn.Sequential([
            nn.Conv(self.num_channels, kernel_size=(7, 7), strides=(2, 2),
                    padding='same'),
            nn.BatchNorm(not self.training),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3),
                                  strides=(2, 2), padding='same')
        ])

    def create_net(self):
        # DenseBlock and TransitionBlock are the blocks described above
        net = self.b1()
        num_channels = self.num_channels
        for i, num_convs in enumerate(self.arch):
            net.layers.extend([DenseBlock(num_convs, self.growth_rate,
                                          training=self.training)])
            # Accumulate: the number of output channels of this dense block
            num_channels += num_convs * self.growth_rate
            # A transition layer that halves the number of channels is added
            # between the dense blocks
            if i != len(self.arch) - 1:
                num_channels //= 2
                net.layers.extend([TransitionBlock(num_channels,
                                                   training=self.training)])
        # Head: batch norm, ReLU, global average pooling, linear classifier
        net.layers.extend([
            nn.BatchNorm(not self.training),
            nn.relu,
            lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                                  strides=x.shape[1:3], padding='valid'),
            lambda x: x.reshape((x.shape[0], -1)),
            nn.Dense(self.num_classes)
        ])
        return net
```

DenseNet reaches ImageNet accuracy competitive with ResNets while using far fewer parameters, which suggests that reusing features through concatenation genuinely helps.
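As a check on the pipeline, the plain-Python sketch below traces tensor shapes through the `DenseNet` defined above for one input size. The batch dimension is omitted. Several choices are assumptions for illustration:

- The 96×96 input size is arbitrary.
- The stride-2 `'same'` stem layers give ⌈H/2⌉.
- The dense blocks preserve height and width.
- Each transition ends in an unpadded 2×2, stride-2 average pool, giving ⌊H/2⌋.

```python
import math


def densenet_shapes(height, width, num_channels=64, growth_rate=32,
                    arch=(4, 4, 4, 4), num_classes=10):
    """Trace (height, width, channels) through stem, dense blocks, transitions, head."""
    trace = []
    # Stem: 7x7 conv, stride 2, 'same' padding -> ceil(size / 2)
    h, w, c = math.ceil(height / 2), math.ceil(width / 2), num_channels
    trace.append(('stem conv', (h, w, c)))
    # 3x3 max pool, stride 2, 'same' padding -> ceil(size / 2)
    h, w = math.ceil(h / 2), math.ceil(w / 2)
    trace.append(('stem pool', (h, w, c)))
    for i, num_convs in enumerate(arch):
        c += num_convs * growth_rate  # 'same' 3x3 convs keep h and w
        trace.append((f'dense block {i + 1}', (h, w, c)))
        if i != len(arch) - 1:
            # 1x1 conv halves channels; 2x2 unpadded avg pool, stride 2 -> floor(size / 2)
            h, w, c = h // 2, w // 2, c // 2
            trace.append((f'transition {i + 1}', (h, w, c)))
    trace.append(('global avg pool', (c,)))
    trace.append(('linear', (num_classes,)))
    return trace


for name, shape in densenet_shapes(96, 96):
    print(f'{name:16s} {shape}')
```

The trace shows both jobs of the transition layers. Each dense block adds 128 channels, and each transition then halves both the channel count and the spatial size before the next block.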