Dense block

Densely Connected Networks (DenseNet)

DenseNet concatenates features

DenseNet (Huang et al., 2017) takes the residual idea one step further: instead of adding a layer's input to its output, it concatenates the two along the channel dimension.

\mathbf{x}_\ell = f_\ell\bigl([\mathbf{x}_0, \mathbf{x}_1, \dots, \mathbf{x}_{\ell-1}]\bigr).

Every layer in a dense block sees the concatenation of all preceding outputs.
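
To make the connectivity pattern concrete, here is a minimal sketch in plain jax.numpy. The helper toy_layer is a hypothetical stand-in for f_\ell (a channel average tiled k times, not a learned convolution), and the growth k = 4 is an arbitrary choice for illustration; only the concatenation logic mirrors the equation above.

```python
import jax.numpy as jnp

def toy_layer(inputs, k):
    """Hypothetical stand-in for f_l: maps any number of input channels to k channels."""
    mean = inputs.mean(axis=-1, keepdims=True)   # (N, H, W, 1)
    return jnp.tile(mean, (1, 1, 1, k))          # (N, H, W, k)

def dense_forward(x0, num_layers, k):
    """Apply x_l = f_l([x_0, ..., x_{l-1}]) and record each layer's input width."""
    outputs = [x0]          # x_0, x_1, ... collected so far
    input_widths = []
    for _ in range(num_layers):
        stacked = jnp.concatenate(outputs, axis=-1)  # [x_0, ..., x_{l-1}]
        input_widths.append(stacked.shape[-1])
        outputs.append(toy_layer(stacked, k))
    return jnp.concatenate(outputs, axis=-1), input_widths

x0 = jnp.ones((2, 8, 8, 3))            # NHWC layout, 3 input channels
y, widths = dense_forward(x0, num_layers=3, k=4)
print(widths)    # [3, 7, 11]: channels seen by layers 1, 2, 3
print(y.shape)   # (2, 8, 8, 15): all feature maps concatenated
```

Each layer's input is k channels wider than the previous one, which is the linear channel growth that transition layers later cut back.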

Dense block + transition

Dense block grows channels by concatenation; transition layers (1×1 conv + pool) reset channels between blocks.

Pros: maximum feature reuse, fewer parameters than ResNet for similar accuracy. Cons: the channel count grows linearly with depth within a block, which drives up memory use; transition layers keep it in check.

Conv block

The basic unit is a small conv block (BN → ReLU → 3×3 conv) that returns its input concatenated with the new features along the channel axis; a DenseBlock stacks several of them.

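import jax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp


class ConvBlock(nn.Module):
    num_channels: int  # number of new channels this block adds
    training: bool = True

    @nn.compact
    def __call__(self, X):
        # BatchNorm's first argument is use_running_average (True at inference)
        Y = nn.relu(nn.BatchNorm(not self.training)(X))
        Y = nn.Conv(self.num_channels, kernel_size=(3, 3), padding=(1, 1))(Y)
        # Concatenate input and output along the channel axis (NHWC layout)
        Y = jnp.concatenate((X, Y), axis=-1)
        return Y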

Now stack the conv blocks. Because each block appends its new features to its own input, later blocks see everything computed so far.

class DenseBlock(nn.Module):
    num_convs: int
    num_channels: int
    training: bool = True

    def setup(self):
        layer = []
        for i in range(self.num_convs):
            layer.append(ConvBlock(self.num_channels, self.training))
        self.net = nn.Sequential(layer)

    def __call__(self, X):
        return self.net(X)

Channel growth

Applying DenseBlock(num_convs=2, num_channels=10) to a 3-channel input adds num_convs * num_channels = 20 channels, giving 3 + 20 = 23 output channels:

blk = DenseBlock(2, 10)
X = jnp.zeros((4, 8, 8, 3))
Y = blk.init_with_output(d2l.get_key(), X)[0]
Y.shape
(4, 8, 8, 23)
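
The same bookkeeping can be done without running the network. The helper below is a hypothetical illustration (its name and signature are not part of d2l or Flax); it assumes every conv block adds exactly growth_rate channels, as ConvBlock does.

```python
def dense_block_channels(in_channels, num_convs, growth_rate):
    """Return the input width of every conv block plus the final output width."""
    # Conv block i sees the original input plus i earlier blocks' new features.
    widths = [in_channels + i * growth_rate for i in range(num_convs)]
    out_channels = in_channels + num_convs * growth_rate
    return widths, out_channels

widths, out_channels = dense_block_channels(3, num_convs=2, growth_rate=10)
print(widths)        # [3, 13]
print(out_channels)  # 23
```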

Transition layer

A transition layer stops the channel explosion between dense blocks: a 1×1 conv shrinks the channel count (DenseNet halves it), and a 2×2 average pool with stride 2 halves the spatial dims:

class TransitionBlock(nn.Module):
    num_channels: int
    training: bool = True

    @nn.compact
    def __call__(self, X):
        X = nn.BatchNorm(not self.training)(X)
        X = nn.relu(X)
        X = nn.Conv(self.num_channels, kernel_size=(1, 1))(X)
        X = nn.avg_pool(X, window_shape=(2, 2), strides=(2, 2))
        return X

blk = TransitionBlock(10)
blk.init_with_output(d2l.get_key(), Y)[0].shape
(4, 4, 4, 10)
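
To see why the shape changes this way, the two shape-changing steps can be written out by hand. This is a sketch under stated assumptions, not the Flax implementation: toy_transition is a hypothetical helper, it omits the batch norm and ReLU, the 1×1 kernel is random, and the pooling trick assumes even height and width.

```python
import jax
import jax.numpy as jnp

def toy_transition(x, weight):
    """1x1 conv as a per-pixel channel mix, followed by 2x2 average pooling."""
    # A 1x1 convolution (no bias) is a linear map over channels at each pixel.
    y = jnp.einsum('nhwc,cd->nhwd', x, weight)
    n, h, w, d = y.shape
    # Non-overlapping 2x2 average pooling (assumes h and w are even).
    y = y.reshape(n, h // 2, 2, w // 2, 2, d).mean(axis=(2, 4))
    return y

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8, 8, 23))                      # same shape as the dense block output
weight = jax.random.normal(key, (23, 10)) * 0.1  # hypothetical 1x1 kernel: 23 -> 10 channels
out = toy_transition(x, weight)
print(out.shape)  # (4, 4, 4, 10)
```

The 1×1 conv only touches the channel axis and the pooling only touches the spatial axes, so the two reductions are independent of each other.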

The DenseNet model

A standard “stem → dense block → transition → dense block → transition → … → global avg-pool → linear” pipeline:

class DenseNet(d2l.Classifier):
    num_channels: int = 64
    growth_rate: int = 32
    arch: tuple = (4, 4, 4, 4)
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def b1(self):
        return nn.Sequential([
            nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            nn.BatchNorm(not self.training),
            nn.relu,
            lambda x: nn.max_pool(x, window_shape=(3, 3),
                                  strides=(2, 2), padding='same')
        ])

    def create_net(self):
        net = self.b1()
        # Running channel count; assumed to match the stem's 64 output channels
        num_channels = self.num_channels
        for i, num_convs in enumerate(self.arch):
            net.layers.extend([DenseBlock(num_convs, self.growth_rate,
                                          training=self.training)])
            # Each dense block adds num_convs * growth_rate channels
            num_channels += num_convs * self.growth_rate
            # A transition layer that halves the number of channels is added
            # between the dense blocks
            if i != len(self.arch) - 1:
                num_channels //= 2
                net.layers.extend([TransitionBlock(num_channels,
                                                   training=self.training)])
        net.layers.extend([
            nn.BatchNorm(not self.training),
            nn.relu,
            # Global average pooling over the full spatial extent
            lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                                  strides=x.shape[1:3], padding='valid'),
            lambda x: x.reshape((x.shape[0], -1)),
            nn.Dense(self.num_classes)
        ])
        return net
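
As a sanity check on the channel bookkeeping, the intended widths can be traced in plain Python. The helper trace_channels is hypothetical and only mirrors the arithmetic: each dense block adds num_convs * growth_rate channels to a running count, and each transition halves that running count. Defaults match the DenseNet fields above.

```python
def trace_channels(num_channels=64, growth_rate=32, arch=(4, 4, 4, 4)):
    """List (stage, channels) pairs, halving the running count at each transition."""
    trace = [('stem', num_channels)]
    for i, num_convs in enumerate(arch):
        num_channels += num_convs * growth_rate
        trace.append((f'dense{i + 1}', num_channels))
        if i != len(arch) - 1:          # no transition after the last dense block
            num_channels //= 2
            trace.append((f'transition{i + 1}', num_channels))
    return trace

for stage, channels in trace_channels():
    print(f'{stage:12s}{channels}')
```

With the defaults this gives 64 → 192 → 96 → 224 → 112 → 240 → 120 → 248, so the final linear layer sees 248 features after global average pooling.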

Training

model = DenseNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)
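
For the 96×96 inputs used here, the spatial resolution can be traced by hand as well. The sketch below is a hypothetical helper, assuming 'same' padding with stride 2 yields ceil(size / 2) and the unpadded 2×2 average pool yields size // 2; three transitions correspond to the default four-block arch.

```python
import math

def trace_resolution(size=96, num_transitions=3):
    """Follow one spatial dimension through the stem and the transition layers."""
    sizes = [size]
    size = math.ceil(size / 2)   # 7x7 conv, stride 2, 'same' padding
    sizes.append(size)
    size = math.ceil(size / 2)   # 3x3 max-pool, stride 2, 'same' padding
    sizes.append(size)
    for _ in range(num_transitions):
        size = size // 2         # 2x2 avg-pool, stride 2, no padding
        sizes.append(size)
    return sizes

print(trace_resolution())  # [96, 48, 24, 12, 6, 3]
```

The last dense block therefore works on 3×3 feature maps, which the global average pool collapses to a single vector per image.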

DenseNet reaches competitive ImageNet accuracy with substantially fewer parameters than ResNets of comparable accuracy (Huang et al., 2017), which the authors attribute to feature reuse through concatenation.

Recap

  • ResNet adds a block's input to its output; DenseNet concatenates them along the channel dimension.
  • Inside a dense block, layer \ell sees all of layers 0, …, \ell-1 — maximum feature reuse.
  • Transition layers between dense blocks rein in the channel-count explosion via 1×1 conv + pool.
  • Same parameter count → typically better accuracy than ResNet; same accuracy → fewer parameters.