def train_ranking(net, train_iter, test_iter, loss, optimizer, test_seq_iter,
                  num_users, num_items, num_epochs, devices, evaluator,
                  candidates, eval_step=1):
    """Train a pairwise ranking model and periodically evaluate it.

    For every batch ``values`` from ``train_iter`` the model scores the
    positive pair ``net(*values[:-1])`` and the negative pair
    ``net(*values[:-2], values[-1])``; ``loss(p_pos, p_neg)`` is then
    minimized with ``optimizer``. Every ``eval_step`` epochs ``evaluator`` is
    run without gradient tracking and in eval mode, and the resulting hit
    rate / AUC are plotted. Final loss, metrics and throughput are printed.

    Args:
        net: model called as ``net(*inputs)``; presumably a ``torch.nn.Module``
            (``train()``/``eval()`` are called on it).
        train_iter: iterable of training batches. Assumed (TODO confirm) to be
            sequences of tensors whose last element holds negative items.
        test_iter, test_seq_iter, candidates, num_users, num_items: forwarded
            unchanged to ``evaluator``.
        loss: callable ``loss(p_pos, p_neg)`` returning a one-element tensor.
        optimizer: torch optimizer over ``net``'s parameters.
        num_epochs: number of passes over ``train_iter``.
        devices: sequence of devices; training uses only ``devices[0]``, while
            the whole sequence is passed to ``evaluator``.
        evaluator: callable returning ``(hit_rate, auc)``.
        eval_step: run evaluation every ``eval_step`` epochs.
    """
    device = devices[0]
    timer, hit_rate, auc = d2l.Timer(), 0, 0
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['test hit rate', 'test AUC'])
    for epoch in range(num_epochs):
        # Evaluation below switches the model to eval mode; restore training
        # behaviour (dropout, batch-norm statistics) for this epoch.
        net.train()
        # Accumulates: summed batch losses, number of examples (batch dim of
        # values[0]), and number of elements in values[0].
        metric = d2l.Accumulator(3)
        for values in train_iter:
            # Restart the timer for each batch so stop() records only this
            # batch's training time; otherwise every stop() measured from the
            # timer's construction and timer.sum() grew quadratically.
            timer.start()
            input_data = [v.to(device) for v in values]
            p_pos = net(*input_data[:-1])
            p_neg = net(*input_data[:-2], input_data[-1])
            ls = loss(p_pos, p_neg)
            optimizer.zero_grad()
            ls.backward()
            optimizer.step()
            # Record this batch's loss only (not a running total) so the
            # printed train loss is not a quadratic sum. metric[0] / metric[1]
            # is a per-example mean only if `loss` returns the batch *sum* —
            # NOTE(review): confirm the loss reduction.
            metric.add(ls.item(), values[0].shape[0], values[0].numel())
            timer.stop()
        if (epoch + 1) % eval_step == 0:
            # no_grad() does not change module mode; eval() is needed to
            # disable dropout etc. (the counterpart of MXNet's predict_mode).
            net.eval()
            with torch.no_grad():
                hit_rate, auc = evaluator(net, test_iter, test_seq_iter,
                                          candidates, num_users, num_items,
                                          devices)
            animator.add(epoch + 1, (hit_rate, auc))
    print(f'train loss {metric[0] / metric[1]:.3f}, '
          f'test hit rate {float(hit_rate):.3f}, test AUC {float(auc):.3f}')
    # metric holds only the last epoch's counts; scale by num_epochs to
    # approximate the total number of examples processed.
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(devices)}')