Single Shot Multibox Detection

Single Shot Detection

Single Shot MultiBox Detector (SSD; Liu et al., 2016) is the prototypical single-stage detector. One forward pass produces class scores and box offsets for every anchor at every scale; non-maximum suppression (NMS) then discards overlapping duplicates.

The architecture: a CNN trunk, then a pyramid of feature maps at decreasing resolutions. Each level has its own pair of 3×3 convolutional heads: one for class scores, one for box offsets. Predictions from all levels are concatenated.

SSD = base network + several multiscale feature blocks; each block has its own anchor predictor.

Scaling to objects in images

Objects appear at different pixel sizes. SSD handles this by predicting from several feature maps at once:

  • early maps: many spatial cells, small receptive fields, small anchor boxes;
  • deeper maps: fewer spatial cells, larger receptive fields, larger anchor boxes.

The model does not crop or resize individual candidate regions. It applies classification and offset heads at each scale, then concatenates the predictions for all anchors into one set before NMS.
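To see how one level tiles the image, here is a minimal NumPy sketch of anchor generation for a single feature map, assuming a square input image and normalized corner coordinates. It follows the "each size at the first ratio, plus the first size at the remaining ratios" scheme, so a cell gets len(sizes) + len(ratios) - 1 boxes. The function name anchors_for_level and the example sizes are illustrative, not part of the original code.

```python
import numpy as np

def anchors_for_level(fmap_h, fmap_w, sizes, ratios):
    """Return (fmap_h * fmap_w * A, 4) corner-format anchors in [0, 1] coords.

    Assumes a square input image. Each cell gets sizes[k] with ratios[0]
    plus sizes[0] with ratios[1:], so A = len(sizes) + len(ratios) - 1.
    """
    s = np.asarray(sizes, dtype=float)
    r = np.asarray(ratios, dtype=float)
    # Ratio r = width / height, and width * height = size**2
    ws = np.concatenate([s * np.sqrt(r[0]), s[0] * np.sqrt(r[1:])])
    hs = np.concatenate([s / np.sqrt(r[0]), s[0] / np.sqrt(r[1:])])
    # Cell centers, normalized to [0, 1]
    cy = (np.arange(fmap_h) + 0.5) / fmap_h
    cx = (np.arange(fmap_w) + 0.5) / fmap_w
    cy, cx = np.meshgrid(cy, cx, indexing='ij')          # (H, W) each
    centers = np.stack([cx, cy, cx, cy], axis=-1).reshape(-1, 1, 4)
    half = np.stack([-ws, -hs, ws, hs], axis=-1) / 2      # (A, 4)
    # Broadcast to (H*W, A, 4); flattening keeps cell-major, anchor-minor order
    return (centers + half).reshape(-1, 4)

# Early map: many cells, small boxes. Deep map: few cells, large boxes.
fine = anchors_for_level(32, 32, sizes=[0.2, 0.272], ratios=[1, 2, 0.5])
coarse = anchors_for_level(4, 4, sizes=[0.71, 0.79], ratios=[1, 2, 0.5])
print(fine.shape, coarse.shape)  # (4096, 4) (64, 4)
```

The cell-major, anchor-minor order is the same order a row-major flatten of an NHWC head output produces, which is what lets class and box predictions line up with anchors later.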

Class and box prediction heads

For a feature map with a anchors per pixel and q classes, the class head is a 3×3 conv with a(q+1) output channels; the box head outputs 4a:

%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from PIL import Image

def cls_predictor(num_anchors, num_classes):
    return nn.Conv(num_anchors * (num_classes + 1), kernel_size=(3, 3),
                   padding='SAME')
def bbox_predictor(num_anchors):
    return nn.Conv(num_anchors * 4, kernel_size=(3, 3), padding='SAME')
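The class head packs a(q+1) numbers into the last axis of each pixel. A minimal NumPy sketch (with a fake NHWC tensor standing in for the conv output) shows that a row-major reshape to (N, H·W·a, q+1) assigns channels i(q+1) through i(q+1)+q to anchor i. This grouping is a convention imposed by the reshape; the conv weights learn to follow it during training.

```python
import numpy as np

N, H, W = 2, 3, 3          # batch size and feature-map size (illustrative)
a, q = 5, 10               # anchors per pixel, foreground classes
C = a * (q + 1)            # 55 channels, as for cls_predictor(5, 10)

# Stand-in for a class-head output in Flax's NHWC layout
cls_out = np.arange(N * H * W * C, dtype=float).reshape(N, H, W, C)

# Row-major reshape groups the channel axis per anchor: (N, H*W*a, q+1)
per_anchor = cls_out.reshape(N, H * W * a, q + 1)

# Anchor i at pixel (y, x) owns channels i*(q+1) .. i*(q+1)+q
y, x, i = 1, 2, 3
row = (y * W + x) * a + i
same = np.array_equal(per_anchor[0, row],
                      cls_out[0, y, x, i * (q + 1):(i + 1) * (q + 1)])
print(per_anchor.shape, same)  # (2, 45, 11) True
```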

Concatenating across scales

Each level produces predictions of a different shape; flatten each one to (batch, H·W·C) and concatenate along axis 1 so the loss can run on a single tensor:

def forward(x, block):
    # Flax uses NHWC format; input shape: (N, H, W, C)
    return block.init_with_output(jax.random.PRNGKey(0), x)[0]

Y1 = forward(jnp.zeros((2, 20, 20, 8)), cls_predictor(5, 10))
Y2 = forward(jnp.zeros((2, 10, 10, 16)), cls_predictor(3, 10))
Y1.shape, Y2.shape
((2, 20, 20, 55), (2, 10, 10, 33))
def flatten_pred(pred):
    # Flax output is NHWC, flatten H*W*C
    return pred.reshape(pred.shape[0], -1)

def concat_preds(preds):
    return jnp.concatenate([flatten_pred(p) for p in preds], axis=1)
concat_preds([Y1, Y2]).shape
(2, 25300)

Downsampling block

Halves the feature-map height and width between scales: two 3×3 conv-BN-ReLU layers followed by a 2×2 max pool with stride 2:

class DownSampleBlk(nn.Module):
    num_channels: int

    @nn.compact
    def __call__(self, x, training=False):
        for _ in range(2):
            x = nn.Conv(self.num_channels, kernel_size=(3, 3),
                        padding='SAME')(x)
            x = nn.BatchNorm(use_running_average=not training)(x)
            x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x
forward(jnp.zeros((2, 20, 20, 3)), DownSampleBlk(num_channels=10)).shape
(2, 10, 10, 10)

Base network

A small CNN that takes the input image down to the first useful resolution:

class BaseNet(nn.Module):
    @nn.compact
    def __call__(self, x, training=False):
        for num_filters in [16, 32, 64]:
            x = DownSampleBlk(num_channels=num_filters)(x, training)
        return x

forward(jnp.zeros((2, 256, 256, 3)), BaseNet()).shape
(2, 32, 32, 64)

Five-block pyramid

Stack the base network, three 128-channel downsampling blocks, and a final global max pool. Each of these five levels exposes its feature map for anchor prediction:

def get_blk(i):
    if i == 0:
        return BaseNet()
    elif i == 4:
        return None  # No module: blk_forward applies global max pooling
    else:
        return DownSampleBlk(num_channels=128)
def blk_forward(X, blk_params, blk_apply, size, ratio, cls_params,
                cls_apply, bbox_params, bbox_apply, training=False,
                batch_stats=None):
    if blk_apply is not None:
        if batch_stats is not None:
            Y, updates = blk_apply({'params': blk_params,
                                    'batch_stats': batch_stats},
                                   X, training=training,
                                   mutable=['batch_stats'])
        else:
            Y = blk_apply({'params': blk_params}, X, training=training)
            updates = {}
    else:
        # Global max pooling
        Y = X.max(axis=(1, 2), keepdims=True)
        updates = {}
    # Convert NHWC to NCHW for multibox_prior
    Y_nchw = jnp.transpose(Y, (0, 3, 1, 2))
    anchors = d2l.multibox_prior(Y_nchw, sizes=size, ratios=ratio)
    # Keep predictions in NHWC; flatten_pred relies on channel-last layout
    # to align each anchor's class and box predictions with multibox_prior.
    cls_preds = cls_apply({'params': cls_params}, Y)
    bbox_preds = bbox_apply({'params': bbox_params}, Y)
    return (Y, anchors, cls_preds, bbox_preds, updates)
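Why deeper levels get larger anchors: their cells see more of the image. A short sketch computes the theoretical receptive field of the stack get_blk describes, assuming each DownSampleBlk is exactly two 3×3 stride-1 convs and one 2×2 stride-2 pool (batch norm and ReLU do not change it). The helper name receptive_fields is hypothetical.

```python
# (kernel, stride) of each layer inside one DownSampleBlk:
# two 3x3 stride-1 convs ('SAME' padding), then a 2x2 stride-2 max pool
DOWNSAMPLE_LAYERS = [(3, 1), (3, 1), (2, 2)]

def receptive_fields(input_size=256, num_blocks=6):
    """Theoretical receptive field after each stacked DownSampleBlk.

    Returns a list of (map_side, receptive_field_px, stride_px).
    """
    rf, jump, side = 1, 1, input_size
    out = []
    for _ in range(num_blocks):
        for kernel, stride in DOWNSAMPLE_LAYERS:
            # A k-wide window spans (k - 1) extra steps of the current jump
            rf += (kernel - 1) * jump
            jump *= stride
        side //= 2
        out.append((side, rf, jump))
    return out

# Blocks 1-3 form BaseNet; blocks 4-6 are the three 128-channel levels
levels = receptive_fields()[2:]
for side, rf, jump in levels:
    print(f'{side}x{side} map: receptive field {rf}px, stride {jump}px')
```

The fifth level (global max pool to 1×1) sees the whole image. These are upper bounds; the effective receptive field of a trained network is usually smaller, and values above 256 px only mean the window reaches into the zero padding.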

Per-level scales

Bigger anchor scales at deeper levels (small feature map → large receptive field → large anchors):

  • five feature levels use progressively larger anchors;
  • each level predicts class logits and box offsets;
  • predictions are concatenated across levels before loss or decoding.
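The per-level sizes and ratios themselves are not listed here. A sketch of one common choice, assuming the SSD paper's rule of linearly spaced base scales plus a second ratio-1 size equal to the geometric mean with the next level's scale; to our knowledge the resulting numbers match the TinySSD configuration in D2L, but treat them as an assumption. The constants NUM_LEVELS and RATIOS are illustrative names.

```python
import math

NUM_LEVELS = 5
RATIOS = [1, 2, 0.5]

# Linearly spaced base sizes (fraction of image side): 0.2, 0.37, ..., 0.88,
# plus one extra value (1.05) used only for the last geometric mean
base = [round(0.2 + 0.17 * k, 2) for k in range(NUM_LEVELS + 1)]
# Per level: the base size, and the geometric mean with the next level's size
sizes = [[base[k], round(math.sqrt(base[k] * base[k + 1]), 3)]
         for k in range(NUM_LEVELS)]
ratios = [RATIOS] * NUM_LEVELS
# Anchors per pixel: every size at ratio 1, plus the first size at each other ratio
num_anchors = len(sizes[0]) + len(ratios[0]) - 1

print(sizes)
print(num_anchors)  # 4
```

Four anchors per pixel is where the factor of 4 in the anchor count for TinySSD comes from.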

TinySSD model

The full model is a feature pyramid plus two lightweight heads per level:

\text{image} \rightarrow \{(\text{anchors}_\ell, \text{class}_\ell, \text{box}_\ell)\}_{\ell=1}^{5}.

Showing the whole class definition on a slide hides the idea; the important contract is the output shape and anchor ordering. Every anchor needs one class vector and one four-number offset vector.

TinySSD output shapes

For a 256 \times 256 image, the five feature maps (32 \times 32, 16 \times 16, 8 \times 8, 4 \times 4, 1 \times 1) with 4 anchors per pixel create (32^2 + 16^2 + 8^2 + 4^2 + 1^2) \times 4 = 5444 anchors. With one foreground class, expect:

  • anchors: (1, 5444, 4), shared by every image in the batch;
  • class logits: (batch, 5444, 2) for background/banana;
  • offsets: (batch, 5444 * 4).
output anchors: (1, 5444, 4)
output class preds: (32, 5444, 2)
output bbox preds: (32, 21776)
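The shape contract can be checked without the model. This NumPy sketch fills each level's head output with random numbers of the right NHWC shape, then flattens, concatenates, and reshapes them as TinySSD is described to do; concat_preds_np is a NumPy stand-in for concat_preds, and the random data is purely illustrative.

```python
import numpy as np

batch, q, a = 32, 1, 4                 # batch size, foreground classes, anchors per pixel
map_sizes = [32, 16, 8, 4, 1]          # feature-map side length at each level

rng = np.random.default_rng(0)
# Random stand-ins for the per-level head outputs (NHWC)
cls_levels = [rng.normal(size=(batch, s, s, a * (q + 1))) for s in map_sizes]
box_levels = [rng.normal(size=(batch, s, s, a * 4)) for s in map_sizes]

def concat_preds_np(preds):
    # Flatten each level to (batch, H*W*C) and join along axis 1
    return np.concatenate([p.reshape(p.shape[0], -1) for p in preds], axis=1)

total_anchors = sum(s * s for s in map_sizes) * a
cls_out = concat_preds_np(cls_levels).reshape(batch, -1, q + 1)
bbox_out = concat_preds_np(box_levels)
print(total_anchors, cls_out.shape, bbox_out.shape)
# 5444 (32, 5444, 2) (32, 21776)
```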

Loading data + init

batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
read 1000 training examples
read 100 validation examples
net = TinySSD(num_classes=1)
dummy_X = jnp.zeros((32, 3, 256, 256))
variables = net.init(jax.random.PRNGKey(0), dummy_X, training=True)
params = variables['params']
batch_stats = variables.get('batch_stats', {})
trainer = optax.sgd(learning_rate=0.2)
opt_state = trainer.init(params)

Multi-task loss

Two loss terms:

  • Classification: cross-entropy over class scores for every anchor.
  • Localization: L_1 loss on box offsets, computed only on positive anchors (the rest are masked out).

\mathcal{L} = \text{CE}(\hat{\mathbf{c}}, \mathbf{c}) + \frac{1}{N_+}\sum_i m_i \lVert \hat{\mathbf{t}}_i - \mathbf{t}_i\rVert_1,

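where m_i = 1 only for anchors matched to an object and N_+ = \sum_i m_i is the number of such positive anchors. The calc_loss below uses a simpler normalization: it averages the cross-entropy over all A anchors and the masked L_1 error over all 4A offset entries of each image, rather than dividing by N_+.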
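The formula leaves the offset targets t_i abstract. A common parameterization encodes the center shift relative to the anchor's width and height, and the log ratio of sizes. The sketch below assumes scale factors 10 and 5 (the inverse of SSD's usual 0.1/0.2 "variances", and, as far as we know, what D2L's multibox_target uses); all function names are illustrative.

```python
import numpy as np

def corner_to_center(boxes):
    """(xmin, ymin, xmax, ymax) -> (cx, cy, w, h)."""
    xy = (boxes[:, :2] + boxes[:, 2:]) / 2
    wh = boxes[:, 2:] - boxes[:, :2]
    return np.concatenate([xy, wh], axis=1)

def center_to_corner(boxes):
    """(cx, cy, w, h) -> (xmin, ymin, xmax, ymax)."""
    half = boxes[:, 2:] / 2
    return np.concatenate([boxes[:, :2] - half, boxes[:, :2] + half], axis=1)

def encode_offsets(anchors, gt, eps=1e-6):
    """Offset targets t for each anchor (scale factors 10 and 5 assumed)."""
    a, g = corner_to_center(anchors), corner_to_center(gt)
    t_xy = 10 * (g[:, :2] - a[:, :2]) / a[:, 2:]
    t_wh = 5 * np.log(eps + g[:, 2:] / a[:, 2:])
    return np.concatenate([t_xy, t_wh], axis=1)

def decode_offsets(anchors, t):
    """Invert encode_offsets: predicted offsets -> corner-format boxes."""
    a = corner_to_center(anchors)
    xy = t[:, :2] * a[:, 2:] / 10 + a[:, :2]
    wh = np.exp(t[:, 2:] / 5) * a[:, 2:]
    return center_to_corner(np.concatenate([xy, wh], axis=1))

anchors = np.array([[0.1, 0.1, 0.5, 0.5], [0.3, 0.2, 0.9, 0.6]])
gt = np.array([[0.15, 0.12, 0.55, 0.48], [0.25, 0.25, 0.85, 0.7]])
t = encode_offsets(anchors, gt)
print(np.allclose(decode_offsets(anchors, t), gt, atol=1e-4))  # True
```

Normalizing by anchor size makes offsets comparable across small and large anchors, which is what lets one L_1 term serve every level of the pyramid.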

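def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
    # Per-anchor cross-entropy, averaged over each image's anchors;
    # optax expects integer class labels
    cls = optax.softmax_cross_entropy_with_integer_labels(
        cls_preds.reshape(-1, num_classes),
        cls_labels.reshape(-1).astype(jnp.int32))
    cls = cls.reshape(batch_size, -1).mean(axis=1)
    # Masked L1: unmatched anchors are zeroed on both sides, then the error
    # is averaged over all 4A offset entries of each image
    bbox = jnp.abs(
        (bbox_preds * bbox_masks) -
        (bbox_labels * bbox_masks)).mean(axis=1)
    return cls + bbox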
def cls_eval(cls_preds, cls_labels):
    # Because the class prediction results are on the final dimension,
    # `argmax` needs to specify this dimension
    return float((cls_preds.argmax(axis=-1).astype(
        cls_labels.dtype) == cls_labels).sum())

def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    return float((jnp.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())
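To make the normalization difference concrete, here is a NumPy sketch of the formula, dividing each image's masked L_1 sum by its N_+ (clamped to at least 1, an assumption for images with no matched anchors), next to the all-entries mean that calc_loss uses. The names log_softmax and ssd_loss are hypothetical helpers.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def ssd_loss(cls_logits, cls_labels, box_pred, box_target, box_mask):
    """Per-image loss: mean CE over anchors + masked L1 divided by N_+.

    cls_logits: (B, A, q+1); cls_labels: (B, A) ints;
    box_pred, box_target, box_mask: (B, 4A), mask is 1 on positive anchors.
    """
    logp = log_softmax(cls_logits)
    ce = -np.take_along_axis(logp, cls_labels[..., None], axis=-1)[..., 0]
    l1 = np.abs((box_pred - box_target) * box_mask).sum(axis=1)
    n_pos = np.maximum(box_mask.sum(axis=1) / 4, 1)  # avoid 0/0 with no positives
    return ce.mean(axis=1) + l1 / n_pos

B, A = 1, 3
logits = np.zeros((B, A, 2))            # uniform predictions: CE = log 2 per anchor
labels = np.array([[1, 0, 0]])          # one positive anchor, two background
mask = np.zeros((B, 4 * A)); mask[:, :4] = 1
target = mask.copy()                    # the positive anchor is off by 1 per coordinate
pred = np.zeros((B, 4 * A))
print(ssd_loss(logits, labels, pred, target, mask))            # log(2) + 4 / 1
print(np.log(2) + np.abs((pred - target) * mask).mean(axis=1))  # log(2) + 4 / 12
```

With thousands of mostly negative anchors, the all-entries mean shrinks the box term relative to the N_+ version, which changes the effective weighting between the two tasks.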

Training

Standard SGD loop, two evaluation metrics (class accuracy, box mean abs error). Read them together: class accuracy is dominated by many background anchors, while box error only makes sense on matched positive anchors.

@jax.jit
def train_step(params, batch_stats, opt_state, X, Y):
    def loss_fn(params):
        variables = {'params': params, 'batch_stats': batch_stats}
        (anchors, cls_preds, bbox_preds), updates = net.apply(
            variables, X, training=True, mutable=['batch_stats'])
        bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
        l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                      bbox_masks)
        return l.mean(), (updates, cls_preds, bbox_preds,
                          bbox_labels, bbox_masks, cls_labels)

    (loss, aux), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(params)
    updates_dict, cls_preds, bbox_preds, bbox_labels, bbox_masks, \
        cls_labels = aux
    new_batch_stats = updates_dict['batch_stats']
    param_updates, opt_state = trainer.update(grads, opt_state, params)
    params = optax.apply_updates(params, param_updates)
    # Compute scalar metrics inside the jit so we don't ship large tensors
    # back to the host every step.
    cls_correct = (cls_preds.argmax(axis=-1).astype(cls_labels.dtype)
                   == cls_labels).sum()
    cls_count = jnp.array(cls_labels.size, dtype=cls_correct.dtype)
    bbox_abs_sum = jnp.abs((bbox_labels - bbox_preds) * bbox_masks).sum()
    bbox_count = jnp.array(bbox_labels.size, dtype=bbox_abs_sum.dtype)
    return (params, new_batch_stats, opt_state, loss,
            cls_correct, cls_count, bbox_abs_sum, bbox_count)

num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['class error', 'bbox mae'])
num_examples = 0
for epoch in range(num_epochs):
    # Running sums: correctly classified anchors, total anchors,
    # summed absolute offset error, total offset entries
    cls_correct_sum = 0.0
    cls_total = 0.0
    bbox_abs_total = 0.0
    bbox_total = 0.0
    timer.start()
    for features, target in train_iter:
        X, Y = jnp.asarray(features), jnp.asarray(target)
        (params, batch_stats, opt_state, loss,
         cls_correct, cls_count, bbox_abs_sum, bbox_count) = train_step(
            params, batch_stats, opt_state, X, Y)
        cls_correct_sum += cls_correct
        cls_total += cls_count
        bbox_abs_total += bbox_abs_sum
        bbox_total += bbox_count
        num_examples += X.shape[0]
    # float() waits for JAX's asynchronous work to finish, so the epoch
    # time is recorded only after all of this epoch's steps have run
    cls_err = float(1 - cls_correct_sum / cls_total)
    bbox_mae = float(bbox_abs_total / bbox_total)
    timer.stop()
    animator.add(epoch + 1, (cls_err, bbox_mae))
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
print(f'{num_examples / timer.sum():.1f} examples/sec on '
      f'{str(jax.devices()[0])}')

class err 3.59e-03, bbox mae 3.56e-03
7979.4 examples/sec on cuda:0
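Both logged numbers need a baseline. A model that predicts background for every anchor is wrong only on the few positive anchors, and train_step divides the masked box error by all 4A offset entries, not by the positive ones. The sketch below uses hypothetical positive-anchor counts and a hypothetical per-coordinate error of 0.5; it does not describe the trained model, only how small these metrics get by construction.

```python
def all_background_class_error(num_positive, num_anchors=5444):
    """Class error of a model that predicts background for every anchor."""
    return num_positive / num_anchors

def diluted_bbox_mae(mae_on_positive_coords, num_positive, num_anchors=5444):
    """Box MAE as train_step reports it: masked error sum / (4 * num_anchors)."""
    return mae_on_positive_coords * (4 * num_positive) / (4 * num_anchors)

# Hypothetical numbers of anchors matched to a banana in one image
for k in [1, 3, 10]:
    print(f'{k:2d} positives: all-background class err '
          f'{all_background_class_error(k):.2e}, bbox mae with 0.5 error '
          f'per positive coordinate {diluted_bbox_mae(0.5, k):.2e}')
```

So a class error in the 1e-3 range is only meaningful when compared with the all-background baseline for the same data.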

Inference

Forward pass → anchors + class scores + offsets → invert offsets → NMS → keep boxes above a confidence threshold:

img = jnp.array(Image.open('../img/banana.jpg'))  # (H, W, 3) uint8, RGB
# HWC -> CHW float, then add a batch axis: NCHW like the training data
X = jnp.transpose(img, (2, 0, 1)).astype(jnp.float32)
X = jnp.expand_dims(X, axis=0)
def predict(X):
    variables = {'params': params, 'batch_stats': batch_stats}
    anchors, cls_preds, bbox_preds = net.apply(variables, X, training=False)
    cls_probs = jax.nn.softmax(cls_preds, axis=2).transpose(0, 2, 1)
    output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)
    idx = [i for i, row in enumerate(output[0]) if row[0] != -1]
    return output[0, idx]

output = predict(X)
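d2l.multibox_detection decodes the offsets and runs NMS internally. As a sketch of the NMS step alone, assuming corner-format boxes of a single class, greedy NMS in NumPy keeps the highest-scoring box, drops everything that overlaps it too much, and repeats. The 0.5 IoU threshold and the helper names are assumptions, not necessarily d2l's defaults.

```python
import numpy as np

def box_area(b):
    return (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])

def box_iou(box, boxes):
    """IoU between one corner-format box (4,) and an array of boxes (n, 4)."""
    lt = np.maximum(box[:2], boxes[:, :2])   # top-left corners of intersections
    rb = np.minimum(box[2:], boxes[:, 2:])   # bottom-right corners
    inter = np.clip(rb - lt, 0, None).prod(axis=1)
    return inter / (box_area(box) + box_area(boxes) - inter)

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the best box, drop boxes overlapping it, repeat."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_threshold]
    return np.array(keep, dtype=int)

boxes = np.array([[0.1, 0.1, 0.5, 0.5],     # strong detection
                  [0.12, 0.1, 0.52, 0.5],   # near-duplicate of the first
                  [0.6, 0.6, 0.9, 0.9]])    # separate object
scores = np.array([0.9, 0.8, 0.7])
print(nms(boxes, scores))  # [0 2]
```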

Detect bananas

Visualize all predictions with confidence ≥ 0.9. The useful thing to notice is not the raw tensor length, but whether NMS leaves one tight high-confidence box per banana:

def display(img, output, threshold):
    d2l.set_figsize((5, 5))
    fig = d2l.plt.imshow(img)
    for row in output:
        score = float(row[1])
        if score < threshold:
            continue
        h, w = img.shape[:2]
        bbox = [row[2:6] * jnp.array((w, h, w, h))]
        d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')

display(img, output, threshold=0.9)

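Loss variants: smooth L1 and focal loss

The original SSD paper uses a smooth L1 loss for box offsets rather than plain L_1: it is quadratic for small errors and linear for large ones. The parameter σ (scalar in the code) sets the switch point |x| = 1/σ^2; large σ approaches plain L_1.

def smooth_l1(data, scalar):
    # Quadratic for |x| < 1/sigma^2, linear beyond; continuous at the switch
    return jnp.where(jnp.abs(data) < 1 / scalar ** 2,
                     (scalar * data) ** 2 / 2,
                     jnp.abs(data) - 0.5 / scalar ** 2)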

sigmas = [10, 1, 0.5]
lines = ['-', '--', '-.']
x = jnp.arange(-2, 2, 0.1)
d2l.set_figsize()

for l, s in zip(lines, sigmas):
    y = smooth_l1(x, scalar=s)
    d2l.plt.plot(x, y, l, label='sigma=%.1f' % s)
d2l.plt.legend();
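Swapping smooth L1 into the box term of calc_loss only changes the elementwise penalty; the masking and the per-image mean over 4A entries stay the same. A NumPy sketch, where smooth_l1_np and masked_smooth_l1 are hypothetical names and sigma=1 is an assumed default:

```python
import numpy as np

def smooth_l1_np(x, sigma=1.0):
    """Elementwise smooth L1: quadratic for |x| < 1/sigma^2, linear beyond."""
    absx = np.abs(x)
    return np.where(absx < 1 / sigma ** 2,
                    0.5 * (sigma * x) ** 2,
                    absx - 0.5 / sigma ** 2)

def masked_smooth_l1(bbox_preds, bbox_labels, bbox_masks, sigma=1.0):
    """Box term like calc_loss's, but smooth L1 instead of L1."""
    # Masked entries have zero difference, and smooth_l1(0) == 0
    diff = (bbox_preds - bbox_labels) * bbox_masks
    return smooth_l1_np(diff, sigma).mean(axis=1)

preds = np.array([[0.5, 3.0, 10.0, 0.0]])
labels = np.zeros_like(preds)
masks = np.array([[1.0, 1.0, 0.0, 1.0]])     # third entry belongs to a negative anchor
print(masked_smooth_l1(preds, labels, masks))  # [0.65625]
```

The linear tail keeps a few badly matched anchors from dominating the gradient the way a squared loss would.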

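Focal loss (Lin et al., 2017, introduced with RetinaNet) targets the imbalance between many easy background anchors and few objects. It scales the cross-entropy -log p of the true class by (1 - p)^γ, so confident predictions contribute little; γ = 0 recovers ordinary cross-entropy:

def focal_loss(gamma, x):
    # x: predicted probability of the true class
    return -(1 - x) ** gamma * jnp.log(x)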

x = jnp.arange(0.01, 1, 0.01)
for l, gamma in zip(lines, [0, 1, 5]):
    y = d2l.plt.plot(x, focal_loss(gamma, x), l, label='gamma=%.1f' % gamma)
d2l.plt.legend();
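To use focal loss in place of the class term of calc_loss, apply it to the softmax probability of each anchor's label. A minimal NumPy sketch, assuming (B, A, C) logits and integer labels, and omitting the class-balancing weight α that the RetinaNet paper also uses; softmax_focal_loss and its gamma=2 default are illustrative.

```python
import numpy as np

def softmax_focal_loss(logits, labels, gamma=2.0):
    """Per-image mean focal loss over anchors: -(1 - p_t)^gamma * log(p_t).

    logits: (B, A, C); labels: (B, A) integer class ids.
    gamma=0 recovers ordinary softmax cross-entropy.
    """
    z = logits - logits.max(axis=-1, keepdims=True)          # numerical stability
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    log_pt = np.take_along_axis(log_p, labels[..., None], axis=-1)[..., 0]
    pt = np.exp(log_pt)
    return (-(1 - pt) ** gamma * log_pt).mean(axis=1)

easy = np.array([[[4.0, 0.0]]])      # confidently correct background anchor
label = np.array([[0]])
ce = softmax_focal_loss(easy, label, gamma=0)
fl = softmax_focal_loss(easy, label, gamma=2)
print(ce, fl)  # the gamma=2 loss is a tiny fraction of the cross-entropy
```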

Recap

  • SSD = base CNN + multiscale feature pyramid + per-level class & offset heads.
  • One forward pass → all anchor predictions; NMS at the end. No region proposal step.
  • Loss = class cross-entropy over all anchors + offset L_1 on positive anchors only.
  • SSD and RetinaNet are anchor-based dense single-stage detectors. YOLO is a related single-stage family, while modern anchor-free detectors remove explicit anchors but keep dense classification/localization over feature maps.