%matplotlib inline
from d2l import torch as d2l
import torch
def init_adadelta_states(feature_dim):
    """Create zero-initialized Adadelta state tensors.

    Two pairs of tensors are returned. The first pair holds tensors of
    shape ``(feature_dim, 1)`` and the second pair holds tensors of
    shape ``(1,)`` (named for the weight and the bias in the original
    code). Within each pair, the first tensor is the ``s`` accumulator
    and the second is the ``delta`` accumulator.

    Args:
        feature_dim: Number of rows of the first pair's tensors.

    Returns:
        ``((s_w, delta_w), (s_b, delta_b))`` with every tensor all zeros.
    """
    weight_shape = (feature_dim, 1)
    bias_shape = 1
    # One (s, delta) pair per parameter, each starting at zero.
    weight_state = (d2l.zeros(weight_shape), d2l.zeros(weight_shape))
    bias_state = (d2l.zeros(bias_shape), d2l.zeros(bias_shape))
    return (weight_state, bias_state)
def adadelta(params, states, hyperparams):
    """Apply one Adadelta update step to every parameter, in place.

    For each parameter ``p`` with state ``(s, delta)``:

    * ``s`` becomes the leaky average of the squared gradient,
    * the rescaled gradient ``g = sqrt(delta + eps) / sqrt(s + eps) * grad``
      is subtracted from ``p``,
    * ``delta`` becomes the leaky average of ``g * g``,
    * the gradient of ``p`` is reset to zero.

    Args:
        params: Iterable of tensors whose ``.grad`` is already populated.
        states: Iterable of ``(s, delta)`` tensor pairs, aligned with
            ``params``; both tensors are modified in place.
        hyperparams: Mapping with key ``'rho'`` (decay rate of the leaky
            averages). An optional ``'eps'`` key overrides the
            numerical-stability constant, which defaults to ``1e-5``.
    """
    rho = hyperparams['rho']
    # eps used to be hard-coded; keep 1e-5 as the default for callers that
    # only pass 'rho'.
    eps = hyperparams.get('eps', 1e-5)
    one_minus_rho = 1 - rho
    for p, (s, delta) in zip(params, states):
        with torch.no_grad():
            grad = p.grad
            # In-place updates via [:] so the caller's state tensors and
            # parameters are mutated rather than rebound.
            s[:] = rho * s + one_minus_rho * torch.square(grad)
            g = (torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * grad
            p[:] -= g
            delta[:] = rho * delta + one_minus_rho * g * g
        # Clear the gradient so the next backward pass starts from zero.
        p.grad.zero_()