def train_ch11(trainer_fn, states, hyperparams, data_iter,
               feature_dim, num_epochs=2):
    """Train a linear-regression model with a from-scratch optimizer.

    Args:
        trainer_fn: Optimizer step, called once per minibatch as
            ``trainer_fn([w, b], [dw, db], states, hyperparams)``; its
            return value is unpacked into the updated ``w, b``.
        states: Optimizer state, passed through to ``trainer_fn`` on every
            step (``trainer_fn`` may update it in place).
        hyperparams: Hyperparameters, passed through to ``trainer_fn`` on
            every step (``trainer_fn`` may update it in place).
        data_iter: Re-iterable of ``(X, y)`` minibatches supporting
            ``len()``. It is iterated once up front to build the
            evaluation set and once per epoch for training.
        feature_dim: Number of input features (rows of ``w``).
        num_epochs: Number of passes over ``data_iter``.

    Returns:
        ``(timer.cumsum(), animator.Y[0])``: cumulative training time at
        each recorded interval, and the full-dataset losses that were
        plotted every 200 examples.
    """
    # Initialization
    w = jnp.array(np.random.normal(scale=0.01, size=(feature_dim, 1)),
                  dtype=jnp.float32)
    b = jnp.zeros(1)

    def loss_fn(w, b, X, y):
        # Mean squared loss of the linear model on (X, y).
        return d2l.squared_loss(d2l.linreg(X, w, b), y).mean()

    # Only the pure gradient computation is JIT-compiled. trainer_fn is
    # deliberately called eagerly: inside jax.jit, Python code runs only
    # while tracing, so in-place updates that trainer_fn makes to `states`
    # or `hyperparams` (momentum buffers, step counters, ...) would happen
    # once and then be frozen, or would leak tracers into those containers.
    grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))

    # Pre-stack the full dataset on device so each periodic evaluation is a
    # single compiled call instead of a Python loop over minibatches. The
    # arrays are passed as arguments (not closed over) so they are not
    # embedded into the compiled program as constants.
    batches = [(jnp.array(X), jnp.array(y)) for X, y in data_iter]
    Xs = jnp.concatenate([X for X, _ in batches], axis=0)
    ys = jnp.concatenate([y for _, y in batches], axis=0)
    del batches
    full_eval = jax.jit(loss_fn)

    # Train
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    # Create the timer only after the one-off setup above, so that
    # materializing the evaluation set is not charged to training time.
    n, timer = 0, d2l.Timer()
    for _ in range(num_epochs):
        for X, y in data_iter:
            X, y = jnp.array(X), jnp.array(y)
            grads = grad_fn(w, b, X, y)
            w, b = trainer_fn([w, b], list(grads), states, hyperparams)
            n += X.shape[0]
            if n % 200 == 0:
                # JAX dispatches work asynchronously; wait for the queued
                # updates to finish so the timer measures real compute.
                jax.block_until_ready((w, b))
                timer.stop()
                animator.add(n/X.shape[0]/len(data_iter),
                             (float(full_eval(w, b, Xs, ys)),))
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.sum()/num_epochs:.3f} sec/epoch')
    return timer.cumsum(), animator.Y[0]