
Gated Recurrent Units (GRU)


LSTMs work, but they’re heavy: three gates, an input node, a separate cell state. Cho et al. (2014) asked whether the gating idea could be kept while collapsing the bookkeeping.

GRU = two gates, no separate cell state, a single hidden state. Often matches LSTM quality at lower compute.

  • Reset gate \mathbf{R}_t — how much of the past to mix into the candidate hidden state.
  • Update gate \mathbf{Z}_t — convex blend between old hidden state and new candidate.

Setup

from d2l import torch as d2l
import torch
from torch import nn

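Reset and update gates

Both gates are sigmoid layers applied to the current input \mathbf{X}_t \in \mathbb{R}^{n \times d} (batch size n, d inputs) and the previous hidden state \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} (h hidden units), so every entry lies in (0, 1):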

\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\quad \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z).

Computing the reset and update gates.
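A minimal NumPy sketch of the two gate equations, assuming the row-major minibatch convention above (n examples, d inputs, h hidden units) and the same small-Gaussian initialization as the scratch model later on. The helpers `sigmoid` and `gates` are hypothetical names used only for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gates(X, H, W_xr, W_hr, b_r, W_xz, W_hz, b_z):
    """Reset and update gates for one time step.

    X: (n, d) inputs, H: (n, h) previous hidden state.
    Returns R and Z, each of shape (n, h) with entries in (0, 1).
    """
    R = sigmoid(X @ W_xr + H @ W_hr + b_r)
    Z = sigmoid(X @ W_xz + H @ W_hz + b_z)
    return R, Z

rng = np.random.default_rng(0)
n, d, h, sigma = 4, 5, 3, 0.01
X = rng.normal(size=(n, d))
H = rng.normal(size=(n, h))
# Small Gaussian weights, zero biases (assumed init, mirroring GRUScratch)
W_xr, W_xz = rng.normal(size=(d, h)) * sigma, rng.normal(size=(d, h)) * sigma
W_hr, W_hz = rng.normal(size=(h, h)) * sigma, rng.normal(size=(h, h)) * sigma
b_r, b_z = np.zeros(h), np.zeros(h)

R, Z = gates(X, H, W_xr, W_hr, b_r, W_xz, W_hz, b_z)
print(R.shape, Z.shape)  # (4, 3) (4, 3)
```

Because sigma is small and the biases start at zero, both gates begin close to 0.5, so an untrained GRU starts roughly halfway between keeping and overwriting its state.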

Candidate hidden state

Like a vanilla RNN cell, but \mathbf{H}_{t-1} is filtered by \mathbf{R}_t before entering the recurrence:

\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h).

Computing the candidate hidden state \tilde{\mathbf{H}}_t.
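To see what the reset gate does, this NumPy sketch (the helper name `candidate` is hypothetical) evaluates the candidate at the two extremes: with \mathbf{R}_t = 0 the candidate no longer depends on \mathbf{H}_{t-1} at all, and with \mathbf{R}_t = 1 it equals the plain RNN update.

```python
import numpy as np

def candidate(X, H, R, W_xh, W_hh, b_h):
    """Candidate state: tanh(X W_xh + (R * H) W_hh + b_h), R applied elementwise."""
    return np.tanh(X @ W_xh + (R * H) @ W_hh + b_h)

rng = np.random.default_rng(1)
n, d, h = 2, 3, 4
X = rng.normal(size=(n, d))
W_xh = rng.normal(size=(d, h))
W_hh = rng.normal(size=(h, h))
b_h = np.zeros(h)
H_a = rng.normal(size=(n, h))  # two different "previous" states
H_b = rng.normal(size=(n, h))

# R = 0: the past is masked out, so the candidate is the same for any H.
R_zero = np.zeros((n, h))
same = np.allclose(candidate(X, H_a, R_zero, W_xh, W_hh, b_h),
                   candidate(X, H_b, R_zero, W_xh, W_hh, b_h))

# R = 1: identical to a vanilla RNN update tanh(X W_xh + H W_hh + b_h).
R_one = np.ones((n, h))
vanilla = np.tanh(X @ W_xh + H_a @ W_hh + b_h)
matches_rnn = np.allclose(candidate(X, H_a, R_one, W_xh, W_hh, b_h), vanilla)

print(same, matches_rnn)  # True True
```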

Final hidden state

Convex combination of the old state and the candidate, weighted elementwise by \mathbf{Z}_t:

\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.

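\mathbf{Z}_t \to 1: keep the old state and ignore \mathbf{X}_t at this step. \mathbf{Z}_t \to 0: fully replace the state with the candidate. \mathbf{R}_t \to 0: the candidate depends on \mathbf{X}_t alone. \mathbf{R}_t \to 1 together with \mathbf{Z}_t \to 0 recovers a vanilla RNN.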

Computing the hidden state \mathbf{H}_t.
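A tiny worked example of the blend with hand-picked numbers (NumPy; `blend` is a hypothetical helper). Because \mathbf{Z}_t is applied elementwise, each hidden unit can independently keep its old value, take the candidate, or mix the two.

```python
import numpy as np

def blend(Z, H_prev, H_tilde):
    """Final hidden state: Z * H_prev + (1 - Z) * H_tilde, elementwise."""
    return Z * H_prev + (1 - Z) * H_tilde

H_prev = np.array([[0.8, -0.2, 0.5]])
H_tilde = np.array([[-0.4, 0.6, 0.1]])

keep = blend(np.ones_like(H_prev), H_prev, H_tilde)       # Z = 1: copy old state
replace = blend(np.zeros_like(H_prev), H_prev, H_tilde)   # Z = 0: take candidate
# Per-unit choice: unit 0 keeps, unit 1 averages, unit 2 replaces
mixed = blend(np.array([[1.0, 0.5, 0.0]]), H_prev, H_tilde)
print(mixed)
```

Unit 1 ends up at 0.5 * (-0.2) + 0.5 * 0.6 = 0.2, while units 0 and 2 reproduce the old state and the candidate respectively.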

From scratch: parameters

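Six weight matrices and three bias vectors: one (input weight, hidden weight, bias) triple each for the update gate, the reset gate, and the candidate: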

class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()

        # Weights drawn from a small Gaussian; biases start at zero
        init_weight = lambda *shape: nn.Parameter(d2l.randn(*shape) * sigma)
        # One (input-to-hidden, hidden-to-hidden, bias) triple per block
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(d2l.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
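Counting what the constructor allocates makes the parameter budget concrete. This plain-Python sketch uses hypothetical helper names and assumes an LSTM built the same way, with one (input weight, hidden weight, bias) triple for each of its three gates and its input node; the 28 inputs are only an illustrative size.

```python
def params_per_triple(num_inputs, num_hiddens):
    # W_x*: (num_inputs, num_hiddens), W_h*: (num_hiddens, num_hiddens), b_*: (num_hiddens,)
    return num_inputs * num_hiddens + num_hiddens * num_hiddens + num_hiddens

def gru_scratch_params(num_inputs, num_hiddens):
    # Update gate, reset gate, candidate: three triples
    return 3 * params_per_triple(num_inputs, num_hiddens)

def lstm_scratch_params(num_inputs, num_hiddens):
    # Input, forget, output gates plus input node: four triples (assumed layout)
    return 4 * params_per_triple(num_inputs, num_hiddens)

g, l = gru_scratch_params(28, 32), lstm_scratch_params(28, 32)
print(g, l, g / l)  # 5856 7808 0.75
```

The ratio is exactly 3/4 for any width, since both models repeat the same triple and differ only in how many times.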

Forward pass

One step computes \mathbf{Z}_t, \mathbf{R}_t, the candidate, and the convex blend: six matrix multiplications per time step versus eight in an LSTM.

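@d2l.add_to_class(GRUScratch)  # attach forward as a method of GRUScratch
def forward(self, inputs, H=None):
    # inputs shape: (num_steps, batch_size, num_inputs)
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                      device=inputs.device)
    outputs = []
    for X in inputs:  # one time step at a time
        # Update and reset gates
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        # Candidate: the reset gate filters H before the recurrence
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        # Convex blend of old state and candidate
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H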
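To check the shapes and the per-step matmul count without d2l or a GPU, here is a NumPy port of the loop above. `CountingGRU` and its `_mm` wrapper are hypothetical names for this sketch only; the logic follows the scratch forward pass step by step.

```python
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class CountingGRU:
    """NumPy re-implementation of the scratch GRU forward pass that counts matmuls."""

    def __init__(self, num_inputs, num_hiddens, sigma=0.01, seed=0):
        rng = np.random.default_rng(seed)
        self.num_hiddens = num_hiddens

        def triple():
            return (rng.normal(size=(num_inputs, num_hiddens)) * sigma,
                    rng.normal(size=(num_hiddens, num_hiddens)) * sigma,
                    np.zeros(num_hiddens))

        self.W_xz, self.W_hz, self.b_z = triple()  # update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # candidate
        self.matmuls = 0

    def _mm(self, A, B):
        self.matmuls += 1  # count every matrix product
        return A @ B

    def forward(self, inputs, H=None):
        # inputs: (num_steps, batch_size, num_inputs)
        if H is None:
            H = np.zeros((inputs.shape[1], self.num_hiddens))
        outputs = []
        for X in inputs:  # iterate over time steps
            Z = _sigmoid(self._mm(X, self.W_xz) + self._mm(H, self.W_hz) + self.b_z)
            R = _sigmoid(self._mm(X, self.W_xr) + self._mm(H, self.W_hr) + self.b_r)
            H_tilde = np.tanh(self._mm(X, self.W_xh) + self._mm(R * H, self.W_hh) + self.b_h)
            H = Z * H + (1 - Z) * H_tilde
            outputs.append(H)
        return outputs, H

gru = CountingGRU(num_inputs=5, num_hiddens=3)
inputs = np.random.default_rng(1).normal(size=(7, 2, 5))  # 7 steps, batch of 2
outputs, H = gru.forward(inputs)
print(len(outputs), H.shape, gru.matmuls / len(inputs))  # 7 (2, 3) 6.0
```

Three of the six products (those involving X) do not depend on H, so they could in principle be computed for all time steps at once; only the three H-side products are inherently sequential.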

Training

Same RNNLMScratch head, same Trainer. Often converges to the same perplexity as the LSTM with fewer parameters.

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

Concise: nn.GRU

Library cell drops into the same RNN scaffold:

class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = nn.GRU(num_inputs, num_hiddens)
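One detail worth knowing when swapping in nn.GRU: per the PyTorch documentation, its candidate (called n_t there) applies the reset gate after the hidden-state product, r_t ⊙ (W_hn h_{t-1} + b_hn), rather than to \mathbf{H}_{t-1} before it as in the scratch version; its final update h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1} uses the same convention as this section. The NumPy sketch below (hypothetical helper names, hidden bias omitted) shows that the two placements agree when the reset gate is the same scalar in every unit and differ in general.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h = 2, 4
H = rng.normal(size=(n, h))
W_hh = rng.normal(size=(h, h))

def reset_before(R, H, W_hh):
    # Scratch form used in this section: (R * H) @ W_hh
    return (R * H) @ W_hh

def reset_after(R, H, W_hh):
    # PyTorch nn.GRU form (hidden bias b_hn omitted): R * (H @ W_hh)
    return R * (H @ W_hh)

R_scalar = np.full((n, h), 0.3)      # same gate value everywhere
R_random = rng.uniform(size=(n, h))  # gate varies per unit

agree_scalar = np.allclose(reset_before(R_scalar, H, W_hh), reset_after(R_scalar, H, W_hh))
agree_random = np.allclose(reset_before(R_random, H, W_hh), reset_after(R_random, H, W_hh))
print(agree_scalar, agree_random)
```

Both variants are trainable GRUs, so perplexities from the scratch and concise models need not match exactly even with identical hyperparameters.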

Train it:

gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

Decode:

model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has experifice and the '

Recap

  • GRU = LSTM with one gate fewer and no separate cell state.
  • Reset gate gates the past before the recurrence; update gate convex-blends old vs. new.
  • Recovers vanilla RNN (\mathbf{R}_t = 1, \mathbf{Z}_t = 0) and identity (\mathbf{Z}_t = 1) as limit cases.
  • 25% fewer parameters than an LSTM of the same width (three weight blocks instead of four); often comparable perplexity on language-modeling tasks.