Reset and update gates

Gated Recurrent Units (GRU)

LSTMs work, but they’re heavy: three gates, an input node, a separate cell state. Cho et al. (2014) asked whether the gating idea could be kept while collapsing the bookkeeping.

GRU = two gates, no separate cell state, a single hidden state. Often matches LSTM quality at lower compute.

  • Reset gate \mathbf{R}_t — how much of the past to mix into the candidate hidden state.
  • Update gate \mathbf{Z}_t — convex blend between old hidden state and new candidate.

Setup

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

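Both gates are fully connected layers with sigmoid activations, applied to the current input \mathbf{X}_t \in \mathbb{R}^{n \times d} (batch size n, d inputs) and the previous hidden state \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} (h hidden units). The sigmoid squashes every entry into (0, 1):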

\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\quad \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z).

Figure: Computing the reset and update gates.
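A minimal JAX sketch of the two gate equations for a single time step. The function name `gru_gates`, the toy sizes (batch 2, 3 inputs, 4 hidden units) and the random weights are illustrative assumptions, not part of the d2l code; the sketch only confirms that both gates take the hidden-state shape and lie strictly between 0 and 1.

```python
import jax
import jax.numpy as jnp

def gru_gates(X, H_prev, W_xr, W_hr, b_r, W_xz, W_hz, b_z):
    """Hypothetical helper: reset and update gates for one time step.

    X: (batch, num_inputs), H_prev: (batch, num_hiddens).
    Returns R and Z, each (batch, num_hiddens), entries in (0, 1).
    """
    R = jax.nn.sigmoid(X @ W_xr + H_prev @ W_hr + b_r)
    Z = jax.nn.sigmoid(X @ W_xz + H_prev @ W_hz + b_z)
    return R, Z

# Toy sizes and random weights, chosen only for illustration.
n, d, h = 2, 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 6)
X = jax.random.normal(keys[0], (n, d))
H_prev = jax.random.normal(keys[1], (n, h))
W_xr = 0.1 * jax.random.normal(keys[2], (d, h))
W_hr = 0.1 * jax.random.normal(keys[3], (h, h))
W_xz = 0.1 * jax.random.normal(keys[4], (d, h))
W_hz = 0.1 * jax.random.normal(keys[5], (h, h))
b_r = jnp.zeros(h)
b_z = jnp.zeros(h)

R, Z = gru_gates(X, H_prev, W_xr, W_hr, b_r, W_xz, W_hz, b_z)
print(R.shape, Z.shape)  # (2, 4) (2, 4)
```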

Candidate hidden state

Like a vanilla RNN cell, but \mathbf{H}_{t-1} is filtered by \mathbf{R}_t before entering the recurrence:

\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h).

Figure: Computing the candidate hidden state \tilde{\mathbf{H}}_t.
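The candidate equation can be sketched the same way, again with a hypothetical helper (`gru_candidate`) and toy random tensors. The two checks exercise the extremes of the reset gate: \mathbf{R}_t = 0 drops \mathbf{H}_{t-1}, so the candidate depends on \mathbf{X}_t alone, while \mathbf{R}_t = 1 reproduces a vanilla RNN step exactly.

```python
import jax
import jax.numpy as jnp

def gru_candidate(X, H_prev, R, W_xh, W_hh, b_h):
    """Hypothetical helper: a vanilla RNN step, except that H_prev is
    first multiplied elementwise by the reset gate R."""
    return jnp.tanh(X @ W_xh + (R * H_prev) @ W_hh + b_h)

n, d, h = 2, 3, 4
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(1), 4)
X = jax.random.normal(k1, (n, d))
H_prev = jax.random.normal(k2, (n, h))
W_xh = 0.1 * jax.random.normal(k3, (d, h))
W_hh = 0.1 * jax.random.normal(k4, (h, h))
b_h = jnp.zeros(h)

# R = 0 erases the past: the candidate is a function of X only.
cand_closed = gru_candidate(X, H_prev, jnp.zeros((n, h)), W_xh, W_hh, b_h)
reset_closed_ok = jnp.allclose(cand_closed, jnp.tanh(X @ W_xh + b_h))

# R = 1 keeps the full past: identical to a vanilla RNN step.
cand_open = gru_candidate(X, H_prev, jnp.ones((n, h)), W_xh, W_hh, b_h)
reset_open_ok = jnp.allclose(cand_open, jnp.tanh(X @ W_xh + H_prev @ W_hh + b_h))

print(reset_closed_ok, reset_open_ok)  # True True
```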

Final hidden state

Convex combination of old state and candidate, controlled elementwise by \mathbf{Z}_t:

\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.

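\mathbf{Z}_t \to 1: the old state is copied through and \mathbf{X}_t is ignored for this step. \mathbf{Z}_t \to 0: the state is fully replaced by the candidate. \mathbf{R}_t \to 1 together with \mathbf{Z}_t \to 0 recovers a vanilla RNN; \mathbf{R}_t \to 0 makes the candidate a function of \mathbf{X}_t alone.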

Figure: Computing the hidden state \mathbf{H}_t.
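A small numerical check of the update rule, assuming a hypothetical helper `gru_update` and random toy tensors: \mathbf{Z}_t = 1 returns \mathbf{H}_{t-1} unchanged, \mathbf{Z}_t = 0 returns the candidate, and any intermediate gate value places each entry between the old and candidate values, as a convex combination should.

```python
import jax
import jax.numpy as jnp

def gru_update(Z, H_prev, H_tilde):
    """Hypothetical helper: Z near 1 keeps H_prev, Z near 0 takes H_tilde."""
    return Z * H_prev + (1 - Z) * H_tilde

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(2), 3)
H_prev = jax.random.normal(k1, (2, 4))
H_tilde = jnp.tanh(jax.random.normal(k2, (2, 4)))
ones, zeros = jnp.ones((2, 4)), jnp.zeros((2, 4))

kept = jnp.allclose(gru_update(ones, H_prev, H_tilde), H_prev)       # step skipped
replaced = jnp.allclose(gru_update(zeros, H_prev, H_tilde), H_tilde)  # state replaced

# Intermediate Z: every entry lies between the old and candidate values
# (small tolerance for floating-point rounding).
Z = jax.nn.sigmoid(jax.random.normal(k3, (2, 4)))
H = gru_update(Z, H_prev, H_tilde)
lower = jnp.minimum(H_prev, H_tilde) - 1e-6
upper = jnp.maximum(H_prev, H_tilde) + 1e-6
between = jnp.all((H >= lower) & (H <= upper))

print(kept, replaced, between)  # True True True
```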

From scratch: parameters

Six weight matrices and three bias vectors, one (W_x*, W_h*, b_*) triple each for the update gate, the reset gate, and the candidate state:

class GRUScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        init_weight = lambda name, shape: self.param(name,
                                                     nn.initializers.normal(self.sigma),
                                                     shape)
        # One (W_x*, W_h*, b_*) triple per gate / candidate
        triple = lambda name: (
            init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
            init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
            self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens,)))

        self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state
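To see where the parameter savings come from, here is a back-of-the-envelope count in plain Python. The helper names and the sizes d = 28, h = 32 are illustrative assumptions. Each triple contributes d·h + h·h + h entries; the GRU has three triples and an LSTM four, so at equal width the GRU's recurrent layer is exactly 3/4 the size of the LSTM's. The language model's output layer is not counted.

```python
def gru_param_count(num_inputs, num_hiddens):
    # Three triples (W_x*, W_h*, b_*): update gate, reset gate, candidate.
    per_triple = num_inputs * num_hiddens + num_hiddens * num_hiddens + num_hiddens
    return 3 * per_triple

def lstm_param_count(num_inputs, num_hiddens):
    # Four triples: input, forget and output gates plus the input node.
    per_triple = num_inputs * num_hiddens + num_hiddens * num_hiddens + num_hiddens
    return 4 * per_triple

d, h = 28, 32  # illustrative sizes only
n_gru, n_lstm = gru_param_count(d, h), lstm_param_count(d, h)
print(n_gru, n_lstm, n_gru / n_lstm)  # 5856 7808 0.75
```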

Forward pass

One step computes \mathbf{Z}_t, \mathbf{R}_t, the candidate, and the convex blend: six matrix multiplications per time step (two for each gate and two for the candidate) versus eight in an LSTM.

@d2l.add_to_class(GRUScratch)  # register forward as a method of GRUScratch
def forward(self, inputs, H=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation
    def scan_fn(H, X):
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        return H, H  # return carry, y

    if H is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
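For readers without the d2l package, here is a self-contained sketch of the same scan-based forward pass in plain JAX. The parameters live in a dict instead of a Flax module, and `init_gru_params` and `gru_forward` are hypothetical names; the update equations are the ones above. Inputs are laid out as (num_steps, batch, num_inputs), matching the `inputs.shape[1]` convention used for the batch size.

```python
import jax
import jax.numpy as jnp

def init_gru_params(key, num_inputs, num_hiddens, sigma=0.01):
    """Hypothetical helper: the nine arrays of GRUScratch, stored in a dict."""
    params = {}
    keys = jax.random.split(key, 6)
    for i, g in enumerate('zrh'):  # update gate, reset gate, candidate
        params[f'W_x{g}'] = sigma * jax.random.normal(keys[2 * i], (num_inputs, num_hiddens))
        params[f'W_h{g}'] = sigma * jax.random.normal(keys[2 * i + 1], (num_hiddens, num_hiddens))
        params[f'b_{g}'] = jnp.zeros((num_hiddens,))
    return params

def gru_forward(params, inputs, H=None):
    """inputs: (num_steps, batch, num_inputs). Returns (outputs, final H)."""
    p = params

    def step(H, X):
        Z = jax.nn.sigmoid(X @ p['W_xz'] + H @ p['W_hz'] + p['b_z'])
        R = jax.nn.sigmoid(X @ p['W_xr'] + H @ p['W_hr'] + p['b_r'])
        H_tilde = jnp.tanh(X @ p['W_xh'] + (R * H) @ p['W_hh'] + p['b_h'])
        H = Z * H + (1 - Z) * H_tilde
        return H, H  # (new carry, per-step output)

    if H is None:
        H = jnp.zeros((inputs.shape[1], p['W_hh'].shape[0]))
    H, outputs = jax.lax.scan(step, H, inputs)
    return outputs, H

num_steps, batch, d, h = 5, 2, 3, 4
params = init_gru_params(jax.random.PRNGKey(0), d, h)
inputs = jax.random.normal(jax.random.PRNGKey(1), (num_steps, batch, d))
outputs, H = jax.jit(gru_forward)(params, inputs)
print(outputs.shape, H.shape)  # (5, 2, 4) (2, 4)
```

Because the step function both carries H and emits it, the last entry of `outputs` equals the final carry.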

Training

Training reuses the RNNLMScratch language-model head and the Trainer unchanged. In practice the GRU often reaches perplexity comparable to the LSTM's while using fewer parameters.

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

Concise: nn.GRUCell

Flax's nn.GRUCell, lifted over the time axis with nn.scan, drops into the same RNN scaffold:

class GRU(d2l.RNN):
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        if H is None:
            batch_size = inputs.shape[1]
            H = nn.GRUCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), (batch_size, self.num_hiddens))

        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        H, outputs = GRU(features=self.num_hiddens)(H, inputs)
        return outputs, H
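One caveat when swapping in a library cell: implementations do not all place the reset gate where this section does. PyTorch's `torch.nn.GRU` documents its candidate as \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})), so the reset gate multiplies the hidden-to-hidden product rather than \mathbf{H}_{t-1} itself; to our knowledge Flax's `GRUCell` docstring uses the same placement, but check the documentation for the version you use. The sketch below uses hypothetical function names and toy random tensors to compare the two candidates: they generally differ, yet coincide when \mathbf{R}_t = 1 and the hidden bias is zero.

```python
import jax
import jax.numpy as jnp

def candidate_reset_before(X, H, R, W_xh, W_hh, b_h):
    # Placement used in this section: reset scales H before the matmul.
    return jnp.tanh(X @ W_xh + (R * H) @ W_hh + b_h)

def candidate_reset_after(X, H, R, W_xh, W_hh, b_xh, b_hh):
    # Alternative placement: reset scales (H @ W_hh + b_hh) after the matmul.
    return jnp.tanh(X @ W_xh + b_xh + R * (H @ W_hh + b_hh))

k = jax.random.split(jax.random.PRNGKey(3), 5)
n, d, h = 2, 3, 4
X = jax.random.normal(k[0], (n, d))
H = jax.random.normal(k[1], (n, h))
W_xh = jax.random.normal(k[2], (d, h))
W_hh = jax.random.normal(k[3], (h, h))
R = jax.nn.sigmoid(jax.random.normal(k[4], (n, h)))
b = jnp.zeros(h)

before = candidate_reset_before(X, H, R, W_xh, W_hh, b)
after = candidate_reset_after(X, H, R, W_xh, W_hh, b, b)
print(jnp.allclose(before, after))  # typically False: the placements differ

# With the reset gate fully open (and zero hidden bias) they coincide.
ones = jnp.ones((n, h))
same_when_open = jnp.allclose(
    candidate_reset_before(X, H, ones, W_xh, W_hh, b),
    candidate_reset_after(X, H, ones, W_xh, W_hh, b, b))
print(same_when_open)  # True
```

The difference matters mainly when comparing numbers or porting weights between the scratch model and a library cell.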

Train it:

gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

Decode:

model.predict('it has', 20, data.vocab, trainer.state.params)
'it has of the travell the '

Recap

  • GRU = LSTM with one gate fewer and no separate cell state.
  • The reset gate scales \mathbf{H}_{t-1} before it enters the candidate's recurrence; the update gate convex-blends the old state with the candidate.
  • Recovers vanilla RNN (\mathbf{R}_t = 1, \mathbf{Z}_t = 0) and identity (\mathbf{Z}_t = 1) as limit cases.
  • 25% fewer recurrent-layer parameters than an LSTM of the same width (three weight/bias triples instead of four); perplexity is often comparable on language-modeling tasks.