@jax.jit
def train_step(params, batch_stats, opt_state, X, Y):
    """Run one jit-compiled optimisation step and reduce the batch metrics.

    Args:
        params: Trainable parameters passed to ``net.apply`` under 'params'.
        batch_stats: Collection passed to ``net.apply`` under 'batch_stats'
            and returned updated.
        opt_state: Optimizer state consumed and produced by
            ``trainer.update``.
        X: Input batch forwarded to the network.
        Y: Targets handed to ``d2l.multibox_target`` together with the
            anchors produced by the network.

    Returns:
        A tuple ``(params, batch_stats, opt_state, loss, cls_correct,
        cls_count, bbox_abs_sum, bbox_count)`` where the first three entries
        are the updated training state, ``loss`` is the mean loss of the
        batch and the last four are scalar sums/counts that the caller can
        accumulate into a class error and a bounding-box MAE.
    """

    def loss_fn(p):
        # Forward pass in training mode; only 'batch_stats' may be mutated.
        outputs, mutated = net.apply(
            {'params': p, 'batch_stats': batch_stats},
            X, training=True, mutable=['batch_stats'])
        anchors, cls_out, bbox_out = outputs
        bbox_tgt, bbox_msk, cls_tgt = d2l.multibox_target(anchors, Y)
        raw_loss = calc_loss(cls_out, cls_tgt, bbox_out, bbox_tgt, bbox_msk)
        # Everything besides the scalar loss travels through the aux slot so
        # the metrics below can reuse the forward-pass results.
        aux_out = (mutated, cls_out, bbox_out, bbox_tgt, bbox_msk, cls_tgt)
        return raw_loss.mean(), aux_out

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, aux), grads = grad_fn(params)
    (mutated_vars, cls_preds, bbox_preds,
     bbox_labels, bbox_masks, cls_labels) = aux
    new_batch_stats = mutated_vars['batch_stats']

    # Optimizer step.
    param_updates, opt_state = trainer.update(grads, opt_state, params)
    params = optax.apply_updates(params, param_updates)

    # Reduce the metrics to scalars inside the jit so that only a handful of
    # numbers, not the full prediction tensors, leave the compiled function.
    predicted_cls = cls_preds.argmax(axis=-1).astype(cls_labels.dtype)
    cls_correct = jnp.sum(predicted_cls == cls_labels)
    cls_count = jnp.asarray(cls_labels.size, dtype=cls_correct.dtype)

    masked_diff = (bbox_labels - bbox_preds) * bbox_masks
    bbox_abs_sum = jnp.sum(jnp.abs(masked_diff))
    bbox_count = jnp.asarray(bbox_labels.size, dtype=bbox_abs_sum.dtype)

    return (params, new_batch_stats, opt_state, loss,
            cls_correct, cls_count, bbox_abs_sum, bbox_count)
# Train for a fixed number of epochs, plotting the per-epoch class error and
# bounding-box MAE, then report the final metrics and the overall throughput.
num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['class error', 'bbox mae'])
for epoch in range(num_epochs):
    # Running totals for this epoch: number of correct class predictions,
    # number of class predictions, sum of absolute bbox errors, and number
    # of bbox values that sum is taken over.
    cls_correct_sum = 0.0
    cls_total = 0.0
    bbox_abs_total = 0.0
    bbox_total = 0.0
    # Time the epoch as a single interval. Restarting the timer on every
    # batch while stopping it only once at the very end would measure just
    # the last batch, which makes the reported examples/sec far too high.
    timer.start()
    for features, target in train_iter:
        X, Y = jnp.asarray(features), jnp.asarray(target)
        (params, batch_stats, opt_state, loss,
         cls_correct, cls_count, bbox_abs_sum, bbox_count) = train_step(
            params, batch_stats, opt_state, X, Y)
        # Accumulate the scalar metrics without converting them to Python
        # numbers, so no per-step host round trip is forced here.
        cls_correct_sum += cls_correct
        cls_total += cls_count
        bbox_abs_total += bbox_abs_sum
        bbox_total += bbox_count
    # float() needs the concrete values, so the epoch's pending computations
    # are finished by the time the timer is stopped.
    cls_err = float(1 - cls_correct_sum / cls_total)
    bbox_mae = float(bbox_abs_total / bbox_total)
    timer.stop()
    animator.add(epoch + 1, (cls_err, bbox_mae))
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
# Throughput over the whole run: every epoch's examples divided by the summed
# epoch times (the first epoch also includes jit compilation).
total_examples = len(train_iter) * batch_size * num_epochs
print(f'{total_examples / timer.sum():.1f} examples/sec on '
      f'{str(jax.devices()[0])}')