from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
LSTMs work, but they are heavy: three gates, an input node, and a separate cell state. Cho et al. (2014) asked whether the gating idea could be kept while collapsing this bookkeeping.
The gated recurrent unit (GRU) keeps only two gates (reset and update), drops the separate cell state, and carries a single hidden state. It often matches LSTM quality at lower compute cost.
Both gates are sigmoid-activated affine functions of the current input \mathbf{X}_t and the previous hidden state \mathbf{H}_{t-1}:
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\quad \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z).
Computing the reset and update gates.
The candidate hidden state is computed like a vanilla RNN cell, except that \mathbf{H}_{t-1} is first multiplied elementwise by \mathbf{R}_t before entering the recurrence:
\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h).
Computing the candidate hidden state \tilde{\mathbf{H}}_t.
The new hidden state is an elementwise convex combination of the old state and the candidate, weighted by \mathbf{Z}_t:
\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
\mathbf{Z}_t \to 1: keep the old state, effectively skipping this time step. \mathbf{Z}_t \to 0: fully replace the state with the candidate. \mathbf{R}_t \to 1 together with \mathbf{Z}_t \to 0 recovers a vanilla RNN.
Computing the hidden state \mathbf{H}_t.
Six weight matrices and three biases, grouped by gate (update, reset, candidate):
class GRUScratch(d2l.Module):
    """Gated recurrent unit implemented from scratch.

    Holds one (input-to-hidden weight, hidden-to-hidden weight, bias) triple
    for each of the update gate ``z``, the reset gate ``r`` and the candidate
    hidden state ``h``: six weight matrices and three bias vectors in total.

    Attributes:
        num_inputs: size of each input vector X_t.
        num_hiddens: size of the hidden state H_t.
        sigma: standard deviation of the normal initializer used for the
            weight matrices; biases are initialized to zero.
    """
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        def init_weight(name, shape):
            # Weight matrix drawn from a zero-mean normal with stddev sigma.
            return self.param(name, nn.initializers.normal(self.sigma), shape)

        def triple(name):
            # Returns (W_x<name>, W_h<name>, b_<name>); creation order is kept
            # fixed so parameter initialization is reproducible.
            return (
                init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
                init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
                # Trailing comma: the bias shape is an explicit 1-tuple.
                self.param(f'b_{name}', nn.initializers.zeros,
                           (self.num_hiddens,)))

        self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state
# One step computes \mathbf{Z}_t, \mathbf{R}_t, the candidate, and the convex
# blend — six matmuls per time step versus eight in an LSTM.
@d2l.add_to_class(GRUScratch)  # attach as a method of GRUScratch
def forward(self, inputs, H=None):
    """Run the scratch GRU over a whole sequence.

    Args:
        inputs: sequence scanned along its leading axis; ``inputs.shape[1]``
            is taken as the batch size, so presumably
            (num_steps, batch_size, num_inputs) — TODO confirm against caller.
        H: optional initial hidden state, presumably of shape
            (batch_size, num_hiddens) to match the default. When omitted, a
            zero state of that shape is used.

    Returns:
        ``(outputs, H)``: the hidden state of every step stacked along the
        leading axis, and the final hidden state.
    """
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation
    def scan_fn(H, X):
        # Update gate Z_t.
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        # Reset gate R_t.
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        # Candidate state: the reset gate scales H before the recurrence.
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        # Convex blend of the old state and the candidate, weighted by Z_t.
        H = Z * H + (1 - Z) * H_tilde
        return H, H  # return carry, y

    if H is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H
    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
# Same RNNLMScratch head, same Trainer. The GRU often converges to about the
# same perplexity as the LSTM, with fewer parameters.
Flax's built-in nn.GRUCell drops into the same RNN scaffold:
class GRU(d2l.RNN):
    """GRU layer built from Flax's ``nn.GRUCell`` scanned over time.

    Attributes:
        num_hiddens: number of hidden units (the cell's ``features``).
    """
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        """Apply the GRU to a sequence scanned along its leading axis.

        Args:
            inputs: time-major sequence; axis 0 is scanned, so each step feeds
                ``inputs[t]`` (presumably (batch_size, num_inputs)) to the cell.
            H: optional initial hidden state. When omitted, it is built with
                ``nn.GRUCell.initialize_carry`` (zeros under Flax's default
                ``carry_init``).
            training: not used here; accepted for interface compatibility.

        Returns:
            ``(outputs, H)``: per-step hidden states stacked along axis 0 and
            the final hidden state.
        """
        if H is None:
            # initialize_carry expects the shape of ONE step's input: it keeps
            # the leading batch dims and swaps the feature dim for
            # ``features``, so this also works for unbatched sequences.
            H = nn.GRUCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), inputs.shape[1:])
        # Lift the cell over time: params are broadcast (shared) across steps,
        # and the params RNG is not split per step.
        ScanGRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                          in_axes=0, out_axes=0, split_rngs={"params": False})
        H, outputs = ScanGRU(features=self.num_hiddens)(H, inputs)
        return outputs, H
# Train it: