```python
from d2l import tensorflow as d2l
import tensorflow as tf
```

LSTMs work, but they're heavy: three gates, an input node, a separate cell state. Cho et al. (2014) asked whether the gating idea could be kept while collapsing the bookkeeping.
GRU = two gates (reset and update), no separate cell state, a single hidden state. It often matches LSTM accuracy while being cheaper to compute.
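Both gates are fully connected layers with a sigmoid activation, applied to \mathbf{X}_t \in \mathbb{R}^{n \times d} and \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} (batch size n, d inputs, h hidden units), so every gate entry lies in (0, 1):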
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\quad \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z).
Figure: Computing the reset and update gates.
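A minimal NumPy sketch of the two gate equations for one time step, assuming small random parameters; `sigmoid` and `gru_gates` are hypothetical helper names, not part of d2l. It shows that both gates have shape (batch, num_hiddens) and stay strictly between 0 and 1.

```python
import numpy as np

def sigmoid(x):
    # Logistic function, applied elementwise
    return 1.0 / (1.0 + np.exp(-x))

def gru_gates(X, H_prev, W_xr, W_hr, b_r, W_xz, W_hz, b_z):
    """Reset and update gates for one step.
    X: (batch, num_inputs), H_prev: (batch, num_hiddens)."""
    R = sigmoid(X @ W_xr + H_prev @ W_hr + b_r)  # reset gate
    Z = sigmoid(X @ W_xz + H_prev @ W_hz + b_z)  # update gate
    return R, Z

rng = np.random.default_rng(0)
n, d, h = 2, 3, 4  # batch size, input size, hidden size (arbitrary)
X = rng.normal(size=(n, d))
H_prev = rng.normal(size=(n, h))
# Shapes in order: W_xr, W_hr, b_r, W_xz, W_hz, b_z
params = [rng.normal(size=s) * 0.1 for s in [(d, h), (h, h), (h,)] * 2]
R, Z = gru_gates(X, H_prev, *params)
print(R.shape, Z.shape)  # (2, 4) (2, 4)
```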
The candidate hidden state looks like a vanilla RNN update, except that \mathbf{H}_{t-1} is first scaled elementwise by \mathbf{R}_t:
\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h).
Figure: Computing the candidate hidden state \tilde{\mathbf{H}}_t.
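A small NumPy check of the two limiting cases of the candidate equation, assuming random weights and a zero bias; `candidate` is a hypothetical helper. With \mathbf{R}_t = 1 the candidate equals a vanilla RNN update; with \mathbf{R}_t = 0 the previous state drops out and only a one-layer tanh network on \mathbf{X}_t remains.

```python
import numpy as np

def candidate(X, H_prev, R, W_xh, W_hh, b_h):
    # Reset gate scales H_prev elementwise before the recurrent matmul
    return np.tanh(X @ W_xh + (R * H_prev) @ W_hh + b_h)

rng = np.random.default_rng(1)
n, d, h = 2, 3, 4
X, H_prev = rng.normal(size=(n, d)), rng.normal(size=(n, h))
W_xh, W_hh, b_h = rng.normal(size=(d, h)), rng.normal(size=(h, h)), np.zeros(h)

# R = 1: identical to a vanilla RNN update
vanilla = np.tanh(X @ W_xh + H_prev @ W_hh + b_h)
print(np.allclose(candidate(X, H_prev, np.ones((n, h)), W_xh, W_hh, b_h), vanilla))  # True

# R = 0: H_prev is ignored, leaving a function of X alone
mlp = np.tanh(X @ W_xh + b_h)
print(np.allclose(candidate(X, H_prev, np.zeros((n, h)), W_xh, W_hh, b_h), mlp))  # True
```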
The new hidden state is an elementwise convex combination of the old state and the candidate, weighted by \mathbf{Z}_t:
\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
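\mathbf{Z}_t \to 1: keep \mathbf{H}_{t-1}, effectively skipping this step. \mathbf{Z}_t \to 0: fully replace it with the candidate. \mathbf{R}_t \to 1 turns the candidate into a vanilla RNN update, so \mathbf{R}_t \to 1 together with \mathbf{Z}_t \to 0 recovers a vanilla RNN; \mathbf{R}_t \to 0 reduces the candidate to a function of \mathbf{X}_t alone.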
Figure: Computing the hidden state \mathbf{H}_t.
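A NumPy sketch of the update rule on one example with three hidden units, using made-up state values; `update_state` is a hypothetical helper. The last call mixes the cases per unit: unit 0 keeps the old state (Z = 1), unit 1 averages (Z = 0.5), and unit 2 takes the candidate (Z = 0).

```python
import numpy as np

def update_state(Z, H_prev, H_tilde):
    # Elementwise convex blend: Z weights the old state, 1 - Z the candidate
    return Z * H_prev + (1 - Z) * H_tilde

H_prev = np.array([[0.5, -0.2, 0.9]])
H_tilde = np.array([[-0.7, 0.4, 0.1]])

print(update_state(np.ones_like(H_prev), H_prev, H_tilde))   # Z = 1: equals H_prev
print(update_state(np.zeros_like(H_prev), H_prev, H_tilde))  # Z = 0: equals H_tilde
print(update_state(np.array([[1.0, 0.5, 0.0]]), H_prev, H_tilde))  # per-unit mix
```

Because each output entry is a convex combination of two values in [-1, 1], the hidden state stays in that range.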
Six weight matrices and three biases, grouped into one triple per gate plus one for the candidate:
```python
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        # Small Gaussian initialization for the weight matrices
        init_weight = lambda *shape: tf.Variable(d2l.normal(shape) * sigma)
        # One (W_x, W_h, b) triple per gate and for the candidate
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          tf.Variable(d2l.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
```

One step computes \mathbf{Z}_t, \mathbf{R}_t, the candidate, and the convex blend: six matrix multiplications per time step, versus eight in an LSTM.

```python
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    # inputs: (num_steps, batch_size, num_inputs), time-major
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = tf.zeros((tf.shape(inputs)[1], self.num_hiddens))
    outputs = []
    for X in tf.unstack(inputs):  # one time step at a time
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H
```

Training reuses the same RNNLMScratch head and the same Trainer. The GRU often reaches a perplexity comparable to the LSTM's while using fewer recurrent parameters (three weight triples instead of four).
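To put numbers on the parameter and compute savings, here is a small counting sketch; `recurrent_params` and `matmuls_per_step` are hypothetical helpers, and the sizes d = 28, h = 32 are illustrative only. It counts only the recurrent layer (no output layer) and assumes the parameterization used here: one W_x, one W_h, and one bias per block.

```python
def recurrent_params(num_inputs, num_hiddens, num_blocks):
    # Each block (a gate, the candidate, or the LSTM input node) owns
    # W_x: (num_inputs, num_hiddens), W_h: (num_hiddens, num_hiddens), b: (num_hiddens,)
    per_block = num_inputs * num_hiddens + num_hiddens ** 2 + num_hiddens
    return num_blocks * per_block

def matmuls_per_step(num_blocks):
    # One X @ W_x and one H @ W_h per block
    return 2 * num_blocks

d, h = 28, 32  # illustrative sizes, not tied to a specific experiment
gru = recurrent_params(d, h, num_blocks=3)   # update gate, reset gate, candidate
lstm = recurrent_params(d, h, num_blocks=4)  # input/forget/output gates + input node
print(gru, lstm, gru / lstm)                     # 5856 7808 0.75
print(matmuls_per_step(3), matmuls_per_step(4))  # 6 8
```

The 0.75 ratio does not depend on d or h, since both models use identically shaped blocks.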
Library cell drops into the same RNN scaffold:
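One caveat when comparing the library cell with GRUScratch: in TensorFlow 2, `tf.keras.layers.GRU` defaults to `reset_after=True`, a variant that applies \mathbf{R}_t after multiplying \mathbf{H}_{t-1} by \mathbf{W}_{hh} and keeps a separate recurrent bias. Passing `reset_after=False` matches the candidate equation used in this section. The NumPy sketch below is not Keras code; the helper names are hypothetical. It shows that the two variants differ in general and coincide when \mathbf{R}_t = 1 and the recurrent bias is zero.

```python
import numpy as np

def candidate_reset_before(X, H, R, W_xh, W_hh, b_x):
    # Equation in this section: reset H, then multiply by W_hh
    return np.tanh(X @ W_xh + b_x + (R * H) @ W_hh)

def candidate_reset_after(X, H, R, W_xh, W_hh, b_x, b_h):
    # reset_after variant: multiply by W_hh, add recurrent bias, then reset
    return np.tanh(X @ W_xh + b_x + R * (H @ W_hh + b_h))

rng = np.random.default_rng(2)
n, d, h = 2, 3, 4
X, H_prev = rng.normal(size=(n, d)), rng.normal(size=(n, h))
R = 1.0 / (1.0 + np.exp(-rng.normal(size=(n, h))))  # reset-gate values in (0, 1)
W_xh, W_hh = rng.normal(size=(d, h)), rng.normal(size=(h, h))
b_x, b_h = rng.normal(size=h), rng.normal(size=h)

before = candidate_reset_before(X, H_prev, R, W_xh, W_hh, b_x)
after = candidate_reset_after(X, H_prev, R, W_xh, W_hh, b_x, b_h)
print(np.abs(before - after).max())  # nonzero in general
```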
Train it: