%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
np.set_printoptions(2) # Simplify printing accuracyA dense detector should not regress arbitrary boxes from scratch at every location. Anchor boxes turn the problem into residual regression around structured candidate boxes.
Anchor boxes are at the core of SSD, RetinaNet, and the region proposal network (RPN) in Faster R-CNN. Some modern detectors are anchor-free, but the same ideas — dense classification, box regression, and non-maximum suppression (NMS) — remain central.
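Keep three coordinate systems separate: (1) pixel coordinates, used only for drawing (boxes are multiplied by `bbox_scale` before plotting); (2) normalized image coordinates in [0, 1], in which anchors and ground-truth boxes are stored as corners (x_min, y_min, x_max, y_max); and (3) anchor-relative offsets, in which the network's regression targets and predictions live.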
In toy code we generate anchors at image pixels for clarity. In production detectors, anchors are usually tied to feature-map locations so each prediction uses the local receptive field.
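At each pixel center, take n scales s_1, …, s_n and m aspect ratios r_1, …, r_m. Using all nm combinations would be expensive, so we keep only the combinations that contain the first scale s_1 or the first ratio r_1 — (s_1, r_1), …, (s_n, r_1), (s_1, r_2), …, (s_1, r_m) — giving n + m - 1 boxes per pixel instead of nm: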
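\text{anchor width} = h s\sqrt{r}, \quad \text{anchor height} = h s / \sqrt{r}.
Both extents are in pixels of an h × w input, so the anchor's width-to-height ratio is exactly r and the scale s is measured relative to the image height. The code stores them normalized by w and h respectively, which is why the normalized width carries the extra factor h / w.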
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes with different shapes centered on each pixel.

    Only the last two dimensions of ``data`` are read; they give the height
    and width of the grid on which anchors are placed. With ``n = len(sizes)``
    and ``m = len(ratios)``, every pixel receives ``n + m - 1`` anchors: each
    size paired with ``ratios[0]``, followed by ``sizes[0]`` paired with each
    of the remaining ratios.

    Args:
        data: array whose trailing two dimensions are ``(height, width)``.
        sizes: anchor scales, relative to the image height.
        ratios: anchor aspect ratios (width / height in pixel space).

    Returns:
        Array of shape ``(1, height * width * (n + m - 1), 4)`` holding corner
        coordinates ``(x_min, y_min, x_max, y_max)`` normalized to the image
        size. Anchors are grouped by pixel, with pixels in row-major order.
    """
    img_h, img_w = data.shape[-2:]
    per_pixel = len(sizes) + len(ratios) - 1
    scales = jnp.array(sizes)
    aspect = jnp.array(ratios)

    # Pixel centers in normalized coordinates: shift each integer index by
    # half a pixel, then divide by the grid extent.
    ys = (jnp.arange(img_h) + 0.5) * (1.0 / img_h)
    xs = (jnp.arange(img_w) + 0.5) * (1.0 / img_w)
    grid_y, grid_x = jnp.meshgrid(ys, xs, indexing='ij')
    grid_y = grid_y.reshape(-1)
    grid_x = grid_x.reshape(-1)

    # Normalized widths and heights of the n + m - 1 shapes. The widths are
    # rescaled by img_h / img_w so that, in pixel space, width / height equals
    # the requested ratio even on non-square inputs.
    sqrt_first = jnp.sqrt(aspect[0])
    sqrt_rest = jnp.sqrt(aspect[1:])
    widths = jnp.concatenate((scales * sqrt_first, sizes[0] * sqrt_rest)) \
        * img_h / img_w
    heights = jnp.concatenate((scales / sqrt_first, sizes[0] / sqrt_rest))

    # Half-extent offsets (-w/2, -h/2, +w/2, +h/2), repeated for every pixel.
    half = jnp.stack((-widths, -heights, widths, heights)).T
    deltas = jnp.tile(half, (img_h * img_w, 1)) / 2

    # One (x, y, x, y) center per anchor: each pixel center is repeated
    # `per_pixel` times so it lines up with the tiled offsets above.
    centers = jnp.repeat(jnp.stack([grid_x, grid_y, grid_x, grid_y], axis=1),
                         per_pixel, axis=0)
    return jnp.expand_dims(centers + deltas, axis=0)


# For an h × w feature map, the total number of anchors is hw(n + m - 1).
# They are stored as a tensor of shape (1, num_anchors, 4):
561 728
(1, 2042040, 4)
Visualize the n + m - 1 anchors centered at a single pixel — different scales and ratios:
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw bounding boxes, optionally labelled, onto ``axes``.

    Args:
        axes: object providing ``add_patch`` and ``text`` (presumably a
            matplotlib Axes).
        bboxes: iterable of boxes; each is converted with ``d2l.numpy`` and
            turned into a rectangle patch by ``d2l.bbox_to_rect``.
        labels: optional label or list/tuple of labels; box ``i`` is
            annotated only if a label exists at position ``i``.
        colors: optional color or list/tuple of colors, cycled over the
            boxes; defaults to ``['b', 'g', 'r', 'm', 'c']``.
    """
    def as_sequence(value, fallback=None):
        # None -> fallback; a bare value (e.g. one string) -> one-item list;
        # lists and tuples pass through unchanged.
        if value is None:
            return fallback
        if isinstance(value, (list, tuple)):
            return value
        return [value]

    labels = as_sequence(labels)
    colors = as_sequence(colors, ['b', 'g', 'r', 'm', 'c'])
    for idx, box in enumerate(bboxes):
        edge_color = colors[idx % len(colors)]
        patch = d2l.bbox_to_rect(d2l.numpy(box), edge_color)
        axes.add_patch(patch)
        if not labels or idx >= len(labels):
            continue
        # Black text on a white tag, white text on any other color.
        text_color = 'w' if edge_color != 'w' else 'k'
        axes.text(patch.xy[0], patch.xy[1], labels[idx],
                  va='center', ha='center', fontsize=9, color=text_color,
                  bbox=dict(facecolor=edge_color, lw=0))


# We need a similarity measure between two boxes to know which anchor matches
# which ground truth.
\text{IoU}(A, B) = \frac{|A \cap B|}{|A \cup B|}.
IoU = intersection area / union area.
def box_iou(boxes1, boxes2):
    """Compute pairwise IoU across two lists of anchor or bounding boxes.

    Both inputs hold boxes in corner format ``(x_min, y_min, x_max, y_max)``,
    one box per row: ``boxes1`` is ``(N, 4)`` and ``boxes2`` is ``(M, 4)``.

    Returns:
        An ``(N, M)`` array whose entry ``[i, j]`` is the IoU of
        ``boxes1[i]`` and ``boxes2[j]``. Pairs whose union has zero area
        (both boxes degenerate) get IoU 0 rather than NaN.
    """
    def box_area(boxes):
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    areas1 = box_area(boxes1)  # (N,)
    areas2 = box_area(boxes2)  # (M,)
    # Broadcast (N, 1, 2) against (M, 2) -> (N, M, 2) corners of the overlap.
    inter_upperlefts = jnp.maximum(boxes1[:, None, :2], boxes2[:, :2])
    inter_lowerrights = jnp.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
    # Disjoint pairs produce negative extents; clamp them to zero.
    inters = jnp.maximum(inter_lowerrights - inter_upperlefts, 0)
    inter_areas = inters[:, :, 0] * inters[:, :, 1]  # (N, M)
    union_areas = areas1[:, None] + areas2 - inter_areas
    # Double-`where` so a 0/0 is never evaluated: otherwise a NaN would leak
    # into the max/argmax that callers run over this matrix.
    positive = union_areas > 0
    safe_union = jnp.where(positive, union_areas, 1)
    return jnp.where(positive, inter_areas / safe_union, 0)


# For each anchor box, decide which ground-truth box (if any) it should learn
# to predict. Common rule: greedy assignment by highest IoU, with a threshold
# (e.g. 0.5) for "positive" matches:
Anchor → GT assignment by IoU.
def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):
    """Assign closest ground-truth bounding boxes to anchor boxes.

    Two passes over the anchor/ground-truth IoU matrix:

    1. Threshold pass: every anchor whose best IoU with any ground-truth box
       is at least ``iou_threshold`` is mapped to that box.
    2. Greedy pass (``num_gt_boxes`` iterations): take the largest remaining
       entry of the matrix, force that anchor onto that ground-truth box, and
       overwrite the entry's row and column with -1 so neither is picked
       again. When there are at least as many anchors as ground-truth boxes
       (and IoUs are non-negative), every ground-truth box gets an anchor.

    Args:
        ground_truth: ground-truth boxes, one per row (corner format).
        anchors: anchor boxes, one per row (corner format).
        device: unused; kept for signature compatibility.
        iou_threshold: minimum IoU for the threshold pass.

    Returns:
        int32 array of length ``num_anchors`` holding the assigned
        ground-truth index per anchor, or -1 where none was assigned.
    """
    n_anc, n_gt = anchors.shape[0], ground_truth.shape[0]
    iou = box_iou(anchors, ground_truth)  # (n_anc, n_gt)

    # Pass 1: threshold on each anchor's best ground-truth match.
    best_iou = jnp.max(iou, axis=1)
    best_gt = jnp.argmax(iou, axis=1)
    unassigned = jnp.full((n_anc,), -1, dtype=jnp.int32)
    assignment = jnp.where(best_iou >= iou_threshold, best_gt, unassigned)

    # Pass 2: greedy global-max matching. lax.fori_loop keeps this a single
    # traced loop instead of one unrolled copy per ground-truth box.
    col_blank = jnp.full((n_anc,), -1.0)
    row_blank = jnp.full((n_gt,), -1.0)

    def claim_best_pair(_, state):
        scores, assign = state
        flat = jnp.argmax(scores)  # index into the row-major flattened matrix
        anc = flat // n_gt
        gt = flat % n_gt
        assign = assign.at[anc].set(gt)
        scores = scores.at[:, gt].set(col_blank).at[anc, :].set(row_blank)
        return (scores, assign)

    _, assignment = jax.lax.fori_loop(
        0, n_gt, claim_best_pair, (iou, assignment))
    return assignment


# A matched anchor learns:
\Big(\frac{(x_b{-}x_a)/w_a - \mu_x}{\sigma_x},\; \frac{(y_b{-}y_a)/h_a - \mu_y}{\sigma_y},\; \frac{\log(w_b/w_a) - \mu_w}{\sigma_w},\; \frac{\log(h_b/h_a) - \mu_h}{\sigma_h}\Big).
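The log scale on width and height keeps gradients stable for both small and large boxes. The implementation below uses \mu_x = \mu_y = \mu_w = \mu_h = 0, \sigma_x = \sigma_y = 0.1, and \sigma_w = \sigma_h = 0.2; this is where the factors 10 and 5 in `offset_boxes` come from.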
def offset_boxes(anchors, assigned_bb, eps=1e-6):
    """Transform for anchor box offsets.

    Encodes each assigned box relative to its anchor with all means set to 0,
    sigma_x = sigma_y = 0.1 and sigma_w = sigma_h = 0.2 (hence the factors 10
    and 5 below).

    Args:
        anchors: anchor boxes in corner format, one per row.
        assigned_bb: the box assigned to each anchor, same layout.
        eps: added inside the log so a zero-size assigned box stays finite.

    Returns:
        One row per anchor: the encoded (x, y, w, h) offsets.
    """
    anchor_c = d2l.box_corner_to_center(anchors)
    target_c = d2l.box_corner_to_center(assigned_bb)
    anchor_wh = anchor_c[:, 2:]
    delta_xy = 10 * (target_c[:, :2] - anchor_c[:, :2]) / anchor_wh
    delta_wh = 5 * d2l.log(eps + target_c[:, 2:] / anchor_wh)
    return d2l.concat([delta_xy, delta_wh], axis=1)


def multibox_target(anchors, labels):
    """Label anchor boxes using ground-truth bounding boxes.

    Args:
        anchors: anchors with a leading singleton axis, ``(1, num_anchors, 4)``.
        labels: per-image ground truth, ``(batch, num_gt, 5)``; each row is
            ``(class, x_min, y_min, x_max, y_max)``.

    Returns:
        Tuple ``(bbox_offset, bbox_mask, class_labels)``:
        ``bbox_offset`` and ``bbox_mask`` are ``(batch, num_anchors * 4)``
        (offsets are zeroed wherever the mask is 0), and ``class_labels`` is
        ``(batch, num_anchors)`` int32 with 0 for background and
        ``class + 1`` for anchors matched to a ground-truth box.
    """
    anchors = anchors.squeeze(axis=0)
    n_anchors = anchors.shape[0]

    def label_one_image(label):
        match = assign_anchor_to_bbox(label[:, 1:], anchors, None)
        is_pos = match >= 0
        # Unmatched anchors carry -1; clamp to a legal row so the gathers are
        # valid, then discard those rows through `is_pos`.
        gather_idx = jnp.maximum(match, 0)
        mask = jnp.tile(is_pos.astype(jnp.float32)[:, None], (1, 4))
        cls_ids = jnp.where(
            is_pos, label[gather_idx, 0].astype(jnp.int32) + 1,
            jnp.zeros(n_anchors, dtype=jnp.int32))
        matched_boxes = jnp.where(
            is_pos[:, None], label[gather_idx, 1:],
            jnp.zeros((n_anchors, 4), dtype=jnp.float32))
        deltas = offset_boxes(anchors, matched_boxes) * mask
        return deltas.reshape(-1), mask.reshape(-1), cls_ids

    # vmap compiles one kernel reused across the batch instead of a Python
    # loop that would unroll once per image.
    offsets, masks, classes = jax.vmap(label_one_image)(labels)
    return (offsets, masks, classes)


# Hand-pick ground truth (dog, cat) and a few anchors; plot them:
# Ground-truth rows are (class, x_min, y_min, x_max, y_max) in normalized
# coordinates (scaled by `bbox_scale` only for drawing): row 0 is the dog,
# row 1 the cat.
ground_truth = d2l.tensor([
    [0, 0.1, 0.08, 0.52, 0.92],
    [1, 0.55, 0.2, 0.9, 0.88],
])
# Five hand-picked anchors in the same corner format, without a class column.
anchors = d2l.tensor([
    [0, 0.1, 0.2, 0.3],
    [0.15, 0.2, 0.4, 0.4],
    [0.63, 0.05, 0.88, 0.98],
    [0.66, 0.45, 0.8, 0.8],
    [0.57, 0.3, 0.92, 0.9],
])
fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')
show_bboxes(fig.axes, anchors * bbox_scale, [str(k) for k in range(5)]);
# Run the labeler:
The returned tensors are easier to read with the contract in mind: class label 0 means background, positive labels are shifted by one, and the offset mask zeros out anchors that should not contribute to the localization loss.
Array([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
1., 1., 1., 1.]], dtype=float32)
At inference time, the network outputs class probabilities and predicted offsets; inverting the offset transform recovers the predicted boxes:
def offset_inverse(anchors, offset_preds):
    """Predict bounding boxes based on anchor boxes with predicted offsets.

    Undoes the encoding of ``offset_boxes``: centers are shifted by
    ``offset * anchor_size / 10`` and sizes scaled by ``exp(offset / 5)``.

    Args:
        anchors: anchor boxes in corner format, one per row.
        offset_preds: one row of 4 predicted offsets per anchor.

    Returns:
        Predicted boxes in corner format, one per anchor.
    """
    anc = d2l.box_corner_to_center(anchors)
    anc_xy, anc_wh = anc[:, :2], anc[:, 2:]
    center_xy = (offset_preds[:, :2] * anc_wh / 10) + anc_xy
    size_wh = d2l.exp(offset_preds[:, 2:] / 5) * anc_wh
    return d2l.box_center_to_corner(
        d2l.concat((center_xy, size_wh), axis=1))


# A single object generates many high-confidence anchors. NMS repeatedly keeps
# the highest-scoring remaining box and suppresses every box whose IoU with it
# exceeds \tau:
def nms(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        boxes: ``(N, 4)`` boxes in corner format (x_min, y_min, x_max, y_max).
        scores: ``(N,)`` confidence scores.
        iou_threshold: boxes with IoU strictly above this are suppressed.

    Returns:
        int32 array of kept indices, in descending-score order.
    """
    # Work in NumPy so the Python while-loop isn't forcing a host/device
    # sync every iteration.
    boxes_np = np.asarray(boxes)
    order = np.argsort(-np.asarray(scores))
    keep = []
    while order.size > 0:
        i = int(order[0])
        keep.append(i)
        rest = order[1:]
        if rest.size == 0:
            break
        # IoU between box i and every remaining candidate.
        box_i, rest_boxes = boxes_np[i], boxes_np[rest]
        lt = np.maximum(box_i[:2], rest_boxes[:, :2])
        rb = np.minimum(box_i[2:], rest_boxes[:, 2:])
        inter = np.clip(rb - lt, 0, None).prod(axis=1)
        area_i = (box_i[2:] - box_i[:2]).prod()
        area_rest = (rest_boxes[:, 2:] - rest_boxes[:, :2]).prod(axis=1)
        iou = inter / (area_i + area_rest - inter)
        # A zero-area union yields NaN, which fails `<=`, so such a box is
        # suppressed rather than kept.
        order = rest[iou <= iou_threshold]
    return jnp.array(keep, dtype=jnp.int32)


def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
                       pos_threshold=0.009999999):
    """Predict bounding boxes using non-maximum suppression.

    Args:
        cls_probs: ``(batch, C, num_anchors)`` class probabilities; row 0 of
            each image is treated as background and excluded from scoring.
        offset_preds: per-image predicted offsets, reshaped to
            ``(num_anchors, 4)``.
        anchors: ``(1, num_anchors, 4)`` anchor boxes.
        nms_threshold: IoU threshold passed to ``nms``.
        pos_threshold: minimum confidence for a positive detection.

    Returns:
        ``(batch, num_anchors, 6)`` array; each row is
        ``(class_id, confidence, x_min, y_min, x_max, y_max)``. Boxes kept by
        NMS come first, in descending-score order. ``class_id`` is -1 for
        suppressed rows and for rows below ``pos_threshold``; the latter also
        report ``1 - confidence``.
    """
    anchors = anchors.squeeze(axis=0)
    num_anchors = cls_probs.shape[2]
    out = []
    for i in range(cls_probs.shape[0]):
        cls_prob = cls_probs[i]
        offset_pred = offset_preds[i].reshape(-1, 4)
        fg_prob = cls_prob[1:]  # drop the background row
        conf = jnp.max(fg_prob, axis=0)
        class_id = jnp.argmax(fg_prob, axis=0)
        predicted_bb = offset_inverse(anchors, offset_pred)
        keep = nms(predicted_bb, conf, nms_threshold)
        # Indices NMS did not keep, in ascending order; they are relabelled
        # as background and placed after the kept ones.
        is_kept = jnp.zeros(num_anchors, dtype=bool).at[keep].set(True)
        non_keep = jnp.nonzero(~is_kept)[0].astype(jnp.int32)
        order = jnp.concatenate((keep, non_keep))
        class_id = class_id.at[non_keep].set(-1)
        class_id = class_id[order].astype(jnp.float32)
        conf, predicted_bb = conf[order], predicted_bb[order]
        # `pos_threshold` is the minimum confidence for a positive
        # (non-background) prediction.
        below_min = conf < pos_threshold
        class_id = jnp.where(below_min, -1, class_id)
        conf = jnp.where(below_min, 1 - conf, conf)
        pred_info = jnp.concatenate(
            (class_id[:, None], conf[:, None], predicted_bb), axis=1)
        out.append(pred_info)
    return jnp.stack(out)


# Four predictions: three overlapping boxes on the dog and one on the cat. NMS
# keeps the top-scoring dog box and suppresses the other two dog boxes, whose
# IoU with it exceeds the threshold; the cat box does not overlap the dog box,
# so it survives:
# Four predicted boxes: three overlapping candidates for the dog, one for the
# cat.
anchors = d2l.tensor([
    [0.1, 0.08, 0.52, 0.92],
    [0.08, 0.2, 0.56, 0.95],
    [0.15, 0.3, 0.62, 0.91],
    [0.55, 0.2, 0.9, 0.88],
])
# All-zero offsets, so the decoded boxes coincide with the anchors.
offset_preds = d2l.tensor([0 for _ in range(d2l.size(anchors))])
cls_probs = d2l.tensor([
    [0, 0, 0, 0],            # Predicted background likelihood
    [0.9, 0.8, 0.7, 0.1],    # Predicted dog likelihood
    [0.1, 0.2, 0.3, 0.9],    # Predicted cat likelihood
])
# Each output row is (class_id, confidence, x1, y1, x2, y2). Rows with class
# -1 have been suppressed or filtered out; the remaining rows are the
# detector's final boxes.
Array([[[ 0. , 0.9 , 0.1 , 0.08, 0.52, 0.92],
[ 1. , 0.9 , 0.55, 0.2 , 0.9 , 0.88],
[-1. , 0.8 , 0.08, 0.2 , 0.56, 0.95],
[-1. , 0.7 , 0.15, 0.3 , 0.62, 0.91]]], dtype=float32)