# Training configuration, wall-clock timer and the live plot of both metrics.
num_epochs = 20
timer = d2l.Timer()
animator = d2l.Animator(
    xlabel='epoch',
    xlim=[1, num_epochs],
    legend=['class error', 'bbox mae'],
)

# Pre-compute the anchor boxes a single time. The first output of `net` is
# reused as `anchors_static` for every later `multibox_target` call, so the
# training loop does not need a separate forward pass just to obtain anchors.
# NOTE(review): this relies on the anchors being a function of the input
# shape only (not of the pixel values or batch size) and on every batch
# having the same spatial size -- confirm against the model definition.
sample_features, _ = next(iter(train_iter))
sample_features = tf.cast(sample_features, tf.float32)
# Inference mode: only the anchors are kept, both prediction outputs are
# discarded.
anchors_static, _, _ = net(sample_features, training=False)
# `multibox_target` mixes TF ops with NumPy / `.numpy()` calls, which
# can't run inside `@tf.function`. We keep it eager and only put the
# pure-TF gradient step under @tf.function. The wrapper closes over
# `net` so its trainable variables aren't @tf.function arguments
# (which would re-trace each call) — same pattern that worked for
# gan / dcgan.
@tf.function(reduce_retracing=True)
def _grad_step(features, bbox_labels, bbox_masks, cls_labels):
    """Apply one optimizer update to `net` and return raw metric sums.

    The forward pass and loss run under a gradient tape, the gradients are
    applied through `net.optimizer`, and un-normalised evaluation counts are
    computed from the same predictions so the caller can aggregate them.

    Args:
        features: Input batch passed to `net` with `training=True`.
        bbox_labels: Box regression targets for the batch.
        bbox_masks: Mask multiplied into both box predictions and targets.
        cls_labels: Class targets; cast to int64 for the accuracy check.

    Returns:
        A 5-tuple `(loss, cls_correct, cls_total, bbox_abs, bbox_total)`:
        the loss from `net._compute_ssd_loss`, the number of positions whose
        arg-max class equals the label (int64), the number of elements in
        `cls_labels` (int64), the summed absolute masked box error, and the
        number of elements in `bbox_labels` (float32; masked-out entries
        are included in this count).
    """
    with tf.GradientTape() as tape:
        _, cls_preds, bbox_preds = net(features, training=True)
        loss = net._compute_ssd_loss(
            cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks)

    # Differentiate and update outside the tape context.
    gradients = tape.gradient(loss, net.trainable_variables)
    net.optimizer.apply_gradients(zip(gradients, net.trainable_variables))

    # Classification: count positions where the top-scoring class matches.
    predicted_classes = tf.argmax(cls_preds, axis=-1)
    hits = predicted_classes == tf.cast(cls_labels, tf.int64)
    cls_correct = tf.reduce_sum(tf.cast(hits, tf.int64))
    cls_total = tf.cast(tf.size(cls_labels), tf.int64)

    # Regression: absolute error restricted to entries the mask keeps.
    masked_preds = bbox_preds * bbox_masks
    masked_labels = bbox_labels * bbox_masks
    bbox_abs = tf.reduce_sum(tf.abs(masked_preds - masked_labels))
    bbox_total = tf.cast(tf.size(bbox_labels), tf.float32)

    return loss, cls_correct, cls_total, bbox_abs, bbox_total
# Time the whole run from one starting point. Restarting the timer at the
# top of every batch would make the final `timer.stop()` measure only the
# span since the *last* batch began, which wildly overstates throughput.
# NOTE: the measured span also includes the per-epoch `animator.add` calls.
timer.start()
for epoch in range(num_epochs):
    # Keep the running totals as TF scalars; they are converted to Python
    # floats only once per epoch (below), so the metric bookkeeping adds no
    # per-batch tensor-to-Python conversion.
    cls_correct_sum = tf.zeros((), dtype=tf.int64)
    cls_total = tf.zeros((), dtype=tf.int64)
    bbox_abs_total = tf.zeros((), dtype=tf.float32)
    bbox_total = tf.zeros((), dtype=tf.float32)
    for features, target in train_iter:
        features = tf.cast(features, tf.float32)
        target = tf.cast(target, tf.float32)
        # Build the per-anchor training targets from the precomputed
        # anchors, then run the compiled gradient step on them.
        bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(
            anchors_static, target)
        _, cc, ct, ba, bt = _grad_step(features, bbox_labels,
                                       bbox_masks, cls_labels)
        cls_correct_sum += cc
        cls_total += ct
        bbox_abs_total += ba
        bbox_total += bt
    # Epoch-level metrics: classification error rate and mean absolute
    # box error over all label entries seen this epoch.
    cls_err = float(1 - cls_correct_sum / cls_total)
    bbox_mae = float(bbox_abs_total / bbox_total)
    animator.add(epoch + 1, (cls_err, bbox_mae))
# The float() conversions above force the last epoch's results to be
# materialised, so the elapsed time covers the completed computation.
elapsed = timer.stop()
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
# `elapsed` spans every epoch, so the example count must as well.
# NOTE(review): `len(train_iter) * batch_size` over-counts slightly if the
# final batch of an epoch is partial -- confirm against the data loader.
print(f'{num_epochs * len(train_iter) * batch_size / elapsed:.1f} '
      f'examples/sec')