def train(embed_v, embed_u, data_iter, lr, num_epochs):
    """Train the two embedding modules with Adam and return their parameters.

    Args:
        embed_v: Module providing ``init(key, dummy_input)``; it is passed to
            ``skip_gram`` together with its parameters.
        embed_u: Second module, used the same way as ``embed_v``.
        data_iter: Sized iterable (``len()`` is called on it) yielding
            ``(center, context_negative, mask, label)`` batches.
        lr: Learning rate handed to ``optax.adam``.
        num_epochs: Number of passes over ``data_iter``.

    Returns:
        A dict ``{'v': params_v, 'u': params_u}`` holding the trained
        parameters of ``embed_v`` and ``embed_u``.
    """
    key = jax.random.PRNGKey(42)
    key, key_v, key_u = jax.random.split(key, 3)
    # Initialize parameters from a one-element integer dummy input.
    dummy = jnp.ones((1,), dtype=jnp.int32)
    params_v = embed_v.init(key_v, dummy)
    params_u = embed_u.init(key_u, dummy)
    all_params = {'v': params_v, 'u': params_u}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(all_params)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs])

    @jax.jit
    def train_step(all_params, opt_state, center, context_negative,
                   mask, label):
        """Run one optimizer update; return new state and loss statistics."""

        def compute_loss(all_params):
            pred = skip_gram(center, context_negative, embed_v, embed_u,
                             all_params['v'], all_params['u'])
            # Rescale each row by (row length / number of unmasked entries).
            # NOTE(review): presumably `loss` averages over the full row
            # including padding -- confirm against its definition.
            l = (loss(pred.reshape(label.shape), label, mask)
                 / mask.sum(axis=1) * mask.shape[1])
            # The element count is returned as auxiliary (non-differentiated)
            # output so the caller can normalise the summed loss.
            return l.sum(), l.size

        (loss_val, l_size), grads = jax.value_and_grad(
            compute_loss, has_aux=True)(all_params)
        updates, opt_state = optimizer.update(grads, opt_state, all_params)
        all_params = optax.apply_updates(all_params, updates)
        return all_params, opt_state, loss_val, l_size

    for epoch in range(num_epochs):
        timer, num_batches = d2l.Timer(), len(data_iter)
        # Plot roughly five points per epoch. Clamp to 1 so that iterators
        # with fewer than five batches do not cause a modulo-by-zero.
        plot_every = max(1, num_batches // 5)
        # Accumulate on device to avoid per-batch host syncs
        loss_sum, count = jnp.array(0.0), jnp.array(0, dtype=jnp.int32)
        for i, batch in enumerate(data_iter):
            center, context_negative, mask, label = batch
            all_params, opt_state, loss_val, l_size = train_step(
                all_params, opt_state, center, context_negative, mask, label)
            loss_sum = loss_sum + loss_val
            count = count + l_size
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (float(loss_sum / count),))
        total_loss = float(loss_sum)
        total_count = int(count)
        # An empty data_iter leaves the count at zero; report NaN rather
        # than raising ZeroDivisionError.
        mean_loss = total_loss / total_count if total_count else float('nan')
        print(f'loss {mean_loss:.3f}, '
              f'{total_count / timer.stop():.1f} tokens/sec')
    return all_params