%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
def init_adam_states(feature_dim):
    """Create zero-initialised Adam state for a weight and a bias.

    Args:
        feature_dim: Number of rows of the weight state; the weight state
            arrays have shape ``(feature_dim, 1)`` and the bias state
            arrays have shape ``(1,)``.

    Returns:
        A two-element list ``[(v_w, s_w), (v_b, s_b)]`` where each pair
        holds the two zero arrays kept for the weight and the bias
        respectively.
    """
    state_shapes = [(feature_dim, 1), (1,)]
    # One (v, s) pair of independent zero arrays per parameter shape.
    return [(jnp.zeros(shape), jnp.zeros(shape)) for shape in state_shapes]
def adam(params, grads, states, hyperparams):
    """Apply one Adam update step.

    For every parameter the first-moment estimate ``v`` and second-moment
    estimate ``s`` are updated from the gradient, bias-corrected using the
    step counter ``hyperparams['t']``, and used to compute the new
    parameter value.

    Args:
        params: Mutable sequence (e.g. a list) of parameter arrays. It is
            updated in place: ``params[i]`` is rebound to the new value.
        grads: Sequence of gradients, aligned element-wise with ``params``.
        states: Mutable sequence of ``(v, s)`` pairs aligned with
            ``params`` (the layout returned by ``init_adam_states``). It is
            updated in place so the moments persist across calls.
        hyperparams: Dict with the required keys ``'lr'`` (learning rate)
            and ``'t'`` (step counter). ``'t'`` should start at 1: with
            ``t == 0`` both bias-correction denominators are zero. The
            optional keys ``'beta1'``, ``'beta2'`` and ``'eps'`` override
            the defaults 0.9, 0.999 and 1e-6. ``hyperparams['t']`` is
            incremented by one on every call.

    Returns:
        The tuple ``(params[0], params[1])`` of updated parameters.
    """
    # Fall back to the conventional constants when the caller gives none.
    beta1 = hyperparams.get('beta1', 0.9)
    beta2 = hyperparams.get('beta2', 0.999)
    eps = hyperparams.get('eps', 1e-6)
    lr = hyperparams['lr']
    t = hyperparams['t']

    # The bias-correction denominators depend only on the step count, so
    # they are computed once rather than once per parameter.
    v_denom = 1 - beta1 ** t
    s_denom = 1 - beta2 ** t

    for i, (p, (v, s), grad) in enumerate(zip(params, states, grads)):
        v = beta1 * v + (1 - beta1) * grad
        s = beta2 * s + (1 - beta2) * jnp.square(grad)
        v_bias_corr = v / v_denom
        s_bias_corr = s / s_denom
        params[i] = p - lr * v_bias_corr / (jnp.sqrt(s_bias_corr) + eps)
        # JAX arrays are immutable, so store the new moments back into the
        # caller's list to carry them over to the next step.
        states[i] = (v, s)

    # Advance the step counter once per update step, not once per parameter.
    hyperparams['t'] += 1
    return params[0], params[1]