%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
def init_adadelta_states(feature_dim):
    """Create zero-initialised Adadelta state for a linear model.

    Returns a two-element list: the state for the weight of shape
    ``(feature_dim, 1)``, followed by the state for the bias of shape ``(1,)``.
    Each state is a ``(s, delta)`` tuple of zero arrays. ``s`` holds the
    running average of squared gradients and ``delta`` holds the running
    average of squared updates, as consumed by ``adadelta``.
    """
    def _zero_pair(shape):
        # Two independent zero arrays of the same shape: (s, delta).
        return jnp.zeros(shape), jnp.zeros(shape)

    weight_state = _zero_pair((feature_dim, 1))
    bias_state = _zero_pair(1)
    return [weight_state, bias_state]
def adadelta(params, grads, states, hyperparams):
    """Apply one Adadelta update step to every parameter.

    For each parameter ``p`` with gradient ``grad`` and state ``(s, delta)``::

        s     <- rho * s + (1 - rho) * grad**2
        g     <- sqrt(delta + eps) / sqrt(s + eps) * grad
        p     <- p - g
        delta <- rho * delta + (1 - rho) * g**2

    JAX arrays are immutable. The new values are therefore written back into
    the ``params`` and ``states`` lists in place, and the updated parameters
    are also returned. Reusing the same ``states`` list on the next call
    carries the running averages forward.

    Args:
        params: mutable list of parameter arrays.
        grads: gradients, one per parameter, in the same order as ``params``.
        states: mutable list of ``(s, delta)`` pairs, one per parameter.
        hyperparams: mapping with key ``'rho'`` (the decay weight of both
            running averages). An optional ``'eps'`` key overrides the
            default stability term of 1e-5.

    Returns:
        Tuple of the updated parameters, in the same order as ``params``.

    Raises:
        ValueError: if ``params``, ``grads`` and ``states`` differ in length.
    """
    rho = hyperparams['rho']
    eps = hyperparams.get('eps', 1e-5)
    # zip() would silently truncate and leave trailing parameters un-updated.
    if not len(params) == len(grads) == len(states):
        raise ValueError(
            'params, grads and states must have the same length, got '
            f'{len(params)}, {len(grads)} and {len(states)}')
    for i, (p, (s, delta), grad) in enumerate(zip(params, states, grads)):
        s = rho * s + (1 - rho) * jnp.square(grad)
        g = (jnp.sqrt(delta + eps) / jnp.sqrt(s + eps)) * grad
        params[i] = p - g
        # delta uses the freshly computed step g, not the raw gradient.
        states[i] = (s, rho * delta + (1 - rho) * g * g)
    # Return all parameters rather than hard-coding exactly two.
    return tuple(params)