%matplotlib inline
from d2l import mxnet as d2l
from mxnet import np, npx
npx.set_np()
def init_adadelta_states(feature_dim):
    """Create zero-initialized Adadelta state for two parameters.

    Args:
        feature_dim: Number of rows of the first parameter's state arrays.

    Returns:
        A tuple of two ``(s, delta)`` pairs. The first pair holds arrays of
        shape ``(feature_dim, 1)``, the second holds arrays of shape ``(1,)``
        (presumably for a weight column and a scalar bias -- confirm against
        the caller). All four arrays are distinct objects filled with zeros.
    """
    weight_shape = (feature_dim, 1)
    weight_state = (d2l.zeros(weight_shape), d2l.zeros(weight_shape))
    bias_state = (d2l.zeros(1), d2l.zeros(1))
    return (weight_state, bias_state)
def adadelta(params, states, hyperparams):
    """Apply one Adadelta update step to every parameter, in place.

    Args:
        params: Iterable of parameter arrays. Each must expose its gradient
            as ``.grad`` and support slice assignment.
        states: Iterable of ``(s, delta)`` pairs aligned with ``params``.
            ``s`` holds a leaky average of squared gradients and ``delta``
            a leaky average of squared updates; both are overwritten in
            place.
        hyperparams: Dict-like mapping with the decay rate under ``'rho'``
            and, optionally, the numerical-stability constant under
            ``'eps'`` (defaults to ``1e-5`` when the key is absent).

    Returns:
        None. ``params`` and ``states`` are mutated.
    """
    rho = hyperparams['rho']
    # eps used to be a fixed constant; it can now be overridden per call
    # while keeping the previous value as the default.
    eps = hyperparams.get('eps', 1e-5)
    for p, (s, delta) in zip(params, states):
        grad = p.grad
        # Slice assignment writes into the existing buffers, so the caller's
        # state and parameter objects see the update.
        s[:] = rho * s + (1 - rho) * np.square(grad)
        # Rescale the gradient by the ratio of the two running RMS values.
        g = (np.sqrt(delta + eps) / np.sqrt(s + eps)) * grad
        p[:] -= g
        # delta is refreshed only after the step, so g above used the
        # previous delta.
        delta[:] = rho * delta + (1 - rho) * g * g