from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
X = jnp.arange(16.).reshape(1, 1, 4, 4)
SSD does it all in one forward pass. The R-CNN family takes a different approach: first propose regions of interest, then classify and refine each one. This is slower per image, but historically more accurate and easier to extend (masks, keypoints).
The lineage: R-CNN, Fast R-CNN, Faster R-CNN, Mask R-CNN.
R-CNN: for each of ~2k selective-search proposals, warp the crop to a fixed size, run a CNN, classify with an SVM, and regress a refined box. Conceptually clear but computationally horrible, at 2k forward passes per image:
R-CNN: per-proposal forward passes.
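The "warp to fixed size" step can be sketched as follows. This is a minimal illustration, not the original R-CNN code: the helper `crop_and_warp`, the toy image, and the proposal boxes are all hypothetical, and nearest-neighbor sampling on plain Python lists is assumed only to keep the sketch dependency-free.

```python
def crop_and_warp(image, box, out_size):
    """Crop box=(x1, y1, x2, y2) from a 2D list and resize it to
    out_size x out_size with nearest-neighbor sampling (an assumption;
    real pipelines usually interpolate)."""
    x1, y1, x2, y2 = box
    h, w = y2 - y1, x2 - x1
    warped = []
    for i in range(out_size):
        # Source row for output row i, clamped to stay inside the box
        src_y = y1 + min(h - 1, i * h // out_size)
        row = []
        for j in range(out_size):
            src_x = x1 + min(w - 1, j * w // out_size)
            row.append(image[src_y][src_x])
        warped.append(row)
    return warped

# Toy 6x6 "image" whose pixel value encodes its position
image = [[r * 6 + c for c in range(6)] for r in range(6)]
# Hypothetical proposals of different sizes
proposals = [(0, 0, 2, 2), (1, 2, 6, 5)]
crops = [crop_and_warp(image, box, out_size=3) for box in proposals]
```

Every crop now has the same 3 x 3 shape regardless of the proposal size, but each one would still need its own CNN forward pass.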
Fast R-CNN: one forward pass on the whole image. Proposals still come from selective search, but they index into the shared feature map via RoI pooling, which crops a variable-size rectangle and resizes it to a fixed-size feature:
Fast R-CNN: shared backbone + RoI pooling per proposal.
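The cost difference between the two designs can be made concrete by counting backbone invocations. The sketch below uses a stand-in `backbone` that only counts calls; the functions `rcnn_style` and `fast_rcnn_style` and the 2000 dummy proposals are hypothetical names and values chosen for illustration.

```python
calls = {"backbone": 0}

def backbone(x):
    """Stand-in for a CNN: it only counts how often it is invoked."""
    calls["backbone"] += 1
    return x

def rcnn_style(image, proposals):
    # One backbone call per cropped proposal
    return [backbone((image, box)) for box in proposals]

def fast_rcnn_style(image, proposals):
    # One backbone call for the whole image; proposals index into the result
    features = backbone(image)
    return [(features, box) for box in proposals]

proposals = [(i, i, i + 10, i + 10) for i in range(2000)]

calls["backbone"] = 0
rcnn_style("image", proposals)
rcnn_calls = calls["backbone"]

calls["backbone"] = 0
fast_rcnn_style("image", proposals)
fast_calls = calls["backbone"]

print(rcnn_calls, fast_calls)  # 2000 1
```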
Variable rectangle in feature space → fixed grid (e.g. $2 \times 2$). Each output cell is the max over its sub-region of the rectangle. Differentiable, fast, batchable:
$2 \times 2$ RoI pooling: max-pool each sub-region of the proposal to a fixed-size output.
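As a worked example of the bin-wise max, the sketch below pools rectangles of a 4 x 4 grid holding the values 0 to 15. The helper `max_pool_region` is hypothetical, works on plain Python lists, and assumes integer bin edges obtained by truncation; other implementations may place bin boundaries differently.

```python
def max_pool_region(feature, x1, y1, x2, y2, ph, pw):
    """Max-pool feature[y1:y2][x1:x2] (a 2D list) into a ph x pw grid."""
    h, w = y2 - y1, x2 - x1
    # Integer bin edges by truncation (an assumption of this sketch)
    ys = [y1 + (i * h) // ph for i in range(ph + 1)]
    xs = [x1 + (j * w) // pw for j in range(pw + 1)]
    out = []
    for i in range(ph):
        row = []
        for j in range(pw):
            # Make every bin cover at least one cell
            y_end = max(ys[i + 1], ys[i] + 1)
            x_end = max(xs[j + 1], xs[j] + 1)
            cells = [feature[y][x]
                     for y in range(ys[i], y_end)
                     for x in range(xs[j], x_end)]
            row.append(max(cells))
        out.append(row)
    return out

feature = [[4 * r + c for c in range(4)] for r in range(4)]  # values 0..15

whole = max_pool_region(feature, 0, 0, 4, 4, 2, 2)  # [[5, 7], [13, 15]]
part = max_pool_region(feature, 0, 1, 3, 3, 2, 2)   # [[4, 6], [8, 10]]
```

The 4 x 4 rectangle and the 2 x 3 rectangle both come out as 2 x 2 grids, which is the property the region heads rely on.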
Array([[[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]]], dtype=float32)
Each RoI row stores (batch_id, x1, y1, x2, y2) in input-image coordinates. spatial_scale maps those coordinates onto the shared feature map before pooling:
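The coordinate mapping can be sketched on its own. The two RoI rows below are hypothetical values (the source does not show the definition of `rois`), and `spatial_scale = 0.1` assumes a 40 x 40 input image over the 4 x 4 feature map; the helper `to_feature_coords` is likewise an illustrative name.

```python
def to_feature_coords(roi, spatial_scale):
    """Map (batch_id, x1, y1, x2, y2) from image coordinates to
    integer feature-map cells by scaling and rounding."""
    batch_id, x1, y1, x2, y2 = roi
    fx1, fy1, fx2, fy2 = (int(round(v * spatial_scale))
                          for v in (x1, y1, x2, y2))
    # Keep the region at least one cell wide and tall
    fx2, fy2 = max(fx2, fx1 + 1), max(fy2, fy1 + 1)
    return int(batch_id), fx1, fy1, fx2, fy2

# Hypothetical proposals in image coordinates
rois = [(0, 0, 0, 20, 20), (0, 0, 10, 30, 30)]
spatial_scale = 0.1  # assumed: 4 / 40

mapped = [to_feature_coords(r, spatial_scale) for r in rois]
print(mapped)  # [(0, 0, 0, 2, 2), (0, 0, 1, 3, 3)]
```

With these assumed values, the first proposal covers feature rows 0 to 1 and columns 0 to 1, and the second covers rows 1 to 2 and columns 0 to 2.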
No matter how large the proposal is, the result has fixed shape (num_rois, channels, pooled_h, pooled_w). That fixed-size tensor is what lets Fast R-CNN batch all region heads:
# JAX does not have a built-in RoI pooling operator, so we implement it
def roi_pool(X, rois, pooled_size, spatial_scale):
    """Max-pool each RoI of X (N, C, H, W) to shape (C, ph, pw)."""
    num_rois = rois.shape[0]
    c = X.shape[1]
    ph, pw = pooled_size
    outputs = []
    for i in range(num_rois):
        roi = rois[i]
        batch_idx = int(roi[0])
        # Map image coordinates to integer feature-map cells
        x1 = int(np.round(float(roi[1]) * spatial_scale))
        y1 = int(np.round(float(roi[2]) * spatial_scale))
        x2 = int(np.round(float(roi[3]) * spatial_scale))
        y2 = int(np.round(float(roi[4]) * spatial_scale))
        # Keep the region at least one cell wide and tall
        x2, y2 = max(x2, x1 + 1), max(y2, y1 + 1)
        roi_feat = X[batch_idx, :, y1:y2, x1:x2]
        # Divide into ph x pw bins and take the max of each
        h, w = roi_feat.shape[1], roi_feat.shape[2]
        bin_h = np.linspace(0, h, ph + 1).astype(int)
        bin_w = np.linspace(0, w, pw + 1).astype(int)
        pooled = jnp.zeros((c, ph, pw))
        for pi in range(ph):
            for pj in range(pw):
                # Guard against empty bins when the region is smaller than the grid
                h_end = max(bin_h[pi + 1], bin_h[pi] + 1)
                w_end = max(bin_w[pj + 1], bin_w[pj] + 1)
                sub = roi_feat[:, bin_h[pi]:h_end, bin_w[pj]:w_end]
                pooled = pooled.at[:, pi, pj].set(sub.max(axis=(1, 2)))
        outputs.append(pooled)
    return jnp.stack(outputs, axis=0).reshape(num_rois, c, ph, pw)
roi_pool(X, rois, pooled_size=(2, 2), spatial_scale=0.1)
Array([[[[ 0.,  1.],
         [ 4.,  5.]]],

       [[[ 4.,  6.],
         [ 8., 10.]]]], dtype=float32)
Faster R-CNN replaces selective search with a learnable Region Proposal Network (RPN). The RPN is a small CNN head on the same backbone: it scores and refines anchor boxes into proposals, which the second-stage head then classifies and refines. The whole pipeline is end-to-end trainable:
Faster R-CNN: RPN replaces selective search; one network does both stages.
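The anchors that an RPN scores are typically laid out on a regular grid over the feature map. The sketch below generates such a grid; the function `make_anchors` and the stride, sizes, and aspect ratios are assumptions chosen for illustration, not values from the source.

```python
def make_anchors(feat_h, feat_w, stride, sizes, ratios):
    """Return (x1, y1, x2, y2) anchors centered on every feature-map cell,
    in image coordinates."""
    anchors = []
    for i in range(feat_h):
        for j in range(feat_w):
            # Center of cell (i, j) projected back to the image
            cx, cy = (j + 0.5) * stride, (i + 0.5) * stride
            for s in sizes:
                for r in ratios:
                    # Area stays s * s; r is the width / height ratio
                    w, h = s * r ** 0.5, s / r ** 0.5
                    anchors.append((cx - w / 2, cy - h / 2,
                                    cx + w / 2, cy + h / 2))
    return anchors

# Hypothetical configuration: 4x4 feature map, stride 10, 2 sizes x 3 ratios
anchors = make_anchors(feat_h=4, feat_w=4, stride=10,
                       sizes=(8, 16), ratios=(0.5, 1.0, 2.0))
print(len(anchors))  # 96 = 4 * 4 * 2 * 3
```

In a full RPN, each of these anchors would receive an objectness score and four box offsets from the head; only the top-scoring refined boxes are passed to the second stage.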
Mask R-CNN adds a third per-RoI head: a small FCN that produces a binary mask. Switching from RoI pooling to RoI align (no quantization rounding) was crucial for getting masks sharp enough to be useful:
Mask R-CNN: Faster R-CNN + per-RoI mask FCN.
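The effect of quantization rounding can be seen at a single sample point. The sketch below compares snapping a fractional feature-map location to the nearest cell with bilinear interpolation at the exact location. It is a simplified illustration: a full RoI align averages several such samples per bin, and the helper `bilinear_sample` and the image point are hypothetical.

```python
import math

def bilinear_sample(feature, y, x):
    """Interpolate a 2D list at a fractional (y, x) location."""
    h, w = len(feature), len(feature[0])
    # Clamp to the valid range of cell centers
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = feature[y0][x0] * (1 - dx) + feature[y0][x1] * dx
    bottom = feature[y1][x0] * (1 - dx) + feature[y1][x1] * dx
    return top * (1 - dy) + bottom * dy

feature = [[float(4 * r + c) for c in range(4)] for r in range(4)]

# Hypothetical image point mapped onto the feature map
x_img, y_img, spatial_scale = 14.0, 6.0, 0.1
fx, fy = x_img * spatial_scale, y_img * spatial_scale  # about (1.4, 0.6)

quantized = feature[round(fy)][round(fx)]   # RoI-pool style: snaps to cell (1, 1)
aligned = bilinear_sample(feature, fy, fx)  # RoI-align style: keeps the fraction
```

Here the quantized lookup returns 5.0 while the interpolated value is about 3.8. For box classification such a shift matters little, but for per-pixel masks it misaligns the prediction with the image.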