Stacking from scratch

Deep Recurrent Neural Networks

Deep RNNs

A single RNN layer is already deep in time, but within one time step the path from input to output passes through just one nonlinearity.

Stacking RNN layers makes the model deep along the layer axis too. Each layer sees the previous layer’s hidden states as its input sequence, and the topmost layer feeds the output layer (readout).

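\mathbf{H}_t^{(l)} = \phi_l(\mathbf{H}_t^{(l-1)} \mathbf{W}_{xh}^{(l)} + \mathbf{H}_{t-1}^{(l)} \mathbf{W}_{hh}^{(l)} + \mathbf{b}_h^{(l)}), \quad l = 1, \ldots, L,

where \mathbf{H}_t^{(0)} = \mathbf{X}_t is the input. The output layer reads only the top layer:

\mathbf{O}_t = \mathbf{H}_t^{(L)} \mathbf{W}_{hq} + \mathbf{b}_q.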
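To make the recurrence concrete, here is a minimal NumPy sketch of one time step through all layers. It assumes tanh for every \phi_l and stores parameters as a plain list of (W_xh, W_hh, b_h) tuples, which is our own layout and not d2l's. Layers are 0-indexed in the code, so params[0] holds the bottom layer.

```python
import numpy as np

def stacked_rnn_step(X_t, H_prev, params, phi=np.tanh):
    """Advance every layer of a deep RNN by one time step.

    X_t:    (batch, num_inputs) input at time t, playing the role of H_t^(0)
    H_prev: list of (batch, num_hiddens) states H_{t-1}^(l), bottom layer first
    params: list of (W_xh, W_hh, b_h) tuples, bottom layer first
    """
    below = X_t  # H_t^(0) = X_t
    H_new = []
    for (W_xh, W_hh, b_h), H_l_prev in zip(params, H_prev):
        # H_t^(l) = phi(H_t^(l-1) W_xh^(l) + H_{t-1}^(l) W_hh^(l) + b_h^(l))
        H_l = phi(below @ W_xh + H_l_prev @ W_hh + b_h)
        H_new.append(H_l)
        below = H_l  # this layer's state is the next layer's input
    return H_new

# Usage: batch of 5, 3 input features, 2 layers of width 4
rng = np.random.default_rng(0)
params = [(rng.normal(size=(3 if l == 0 else 4, 4)), rng.normal(size=(4, 4)), np.zeros(4))
          for l in range(2)]
H = stacked_rnn_step(rng.normal(size=(5, 3)), [np.zeros((5, 4))] * 2, params)
print([h.shape for h in H])  # [(5, 4), (5, 4)]
```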

Typical sizes: width (hidden units h per layer) 64–2048, depth (number of layers L) 1–8.
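Both knobs drive model size. As a worked example, assuming vanilla tanh layers as in the equation above and ignoring the output layer, the bottom layer has d·h + h·h + h parameters (d input features), and every layer above it has 2h² + h because its input is h wide. The helper name is ours, and 28 input features is an illustrative assumption (a small one-hot character vocabulary).

```python
def stacked_rnn_param_count(num_inputs, num_hiddens, num_layers):
    """Recurrent parameters of a vanilla deep RNN, output layer excluded."""
    d, h = num_inputs, num_hiddens
    bottom = d * h + h * h + h                  # W_xh^(1), W_hh^(1), b_h^(1)
    upper = (num_layers - 1) * (2 * h * h + h)  # layers 2..L read h-wide inputs
    return bottom + upper

print(stacked_rnn_param_count(28, 32, 2))    # 4032
print(stacked_rnn_param_count(28, 1024, 8))  # 15765504
```

Width enters quadratically and depth only linearly, so moving toward the top of the width range costs far more than adding layers. A GRU layer has three gated blocks, so expect roughly three times these counts.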

Architecture

The hidden state of layer l at time t depends on layer l at time t-1 and on layer l-1 at time t.
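One consequence of this dependency pattern: cell (l, t) needs only cells whose indices sum to l + t - 1, so all cells with the same l + t are independent of each other. The sketch below groups cells into such "waves" using our own 0-indexed (layer, time) convention. The function is hypothetical and purely illustrative; the d2l code in this section simply runs the layers one after another.

```python
def wavefront_schedule(num_layers, num_steps):
    """Group (layer, time) cells so that each group only depends on earlier groups.

    Cell (l, t) needs (l, t-1) and (l-1, t); both have index sum l + t - 1,
    so all cells with the same l + t can be computed together.
    """
    waves = []
    for k in range(num_layers + num_steps - 1):
        waves.append([(l, k - l) for l in range(num_layers)
                      if 0 <= k - l < num_steps])
    return waves

for wave in wavefront_schedule(2, 3):
    print(wave)
# [(0, 0)]
# [(0, 1), (1, 0)]
# [(0, 2), (1, 1)]
# [(1, 2)]
```

With L layers and T steps there are L + T - 1 waves, compared with L·T cells in total.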

Setup

from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

StackedRNNScratch is just a list of num_layers RNNScratch layers. The bottom layer reads the raw inputs (num_inputs features); every layer above it reads the num_hiddens-wide outputs of the layer below:

class StackedRNNScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    num_layers: int
    sigma: float = 0.01

    def setup(self):
        self.rnns = [d2l.RNNScratch(self.num_inputs if i==0 else self.num_hiddens,
                                    self.num_hiddens, self.sigma)
                     for i in range(self.num_layers)]
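The width rule in setup is easy to check outside Flax. Below is a NumPy mirror of it. It assumes, as we believe d2l.RNNScratch does, Gaussian weights scaled by sigma and zero biases; the function name is ours.

```python
import numpy as np

def init_stacked_rnn(num_inputs, num_hiddens, num_layers, sigma=0.01, seed=0):
    """NumPy mirror of StackedRNNScratch.setup: one (W_xh, W_hh, b_h) per layer."""
    rng = np.random.default_rng(seed)
    params = []
    for i in range(num_layers):
        # Only the bottom layer reads raw inputs; upper layers read hidden states
        in_width = num_inputs if i == 0 else num_hiddens
        W_xh = rng.normal(0.0, sigma, size=(in_width, num_hiddens))
        W_hh = rng.normal(0.0, sigma, size=(num_hiddens, num_hiddens))
        b_h = np.zeros(num_hiddens)
        params.append((W_xh, W_hh, b_h))
    return params

for W_xh, W_hh, b_h in init_stacked_rnn(28, 32, 3):
    print(W_xh.shape, W_hh.shape, b_h.shape)
# (28, 32) (32, 32) (32,)
# (32, 32) (32, 32) (32,)
# (32, 32) (32, 32) (32,)
```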

The forward pass walks the layers, feeding each layer’s output sequence into the next:

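@d2l.add_to_class(StackedRNNScratch)  # attach forward as a method of the class
def forward(self, inputs, Hs=None):
    # inputs: (num_steps, batch_size, num_inputs); Hs: one state per layer
    outputs = inputs
    if Hs is None: Hs = [None] * self.num_layers
    for i in range(self.num_layers):
        # Layer i returns a list of per-step hidden states and its final state
        outputs, Hs[i] = self.rnns[i](outputs, Hs[i])
        # Stack into a (num_steps, batch_size, num_hiddens) array
        outputs = d2l.stack(outputs, 0)
    return outputs, Hs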
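Running each layer over the whole sequence before moving up, as forward does, is legal because layer l at time t only needs layer l at t-1 and layer l-1 at t. The NumPy sketch below (tanh layers, our own parameter layout and function names) checks that this layer-major order gives the same result as stepping all layers in lockstep, one time step at a time.

```python
import numpy as np

def rnn_layer(inputs, W_xh, W_hh, b_h, H=None):
    """One tanh RNN layer over inputs of shape (num_steps, batch, in_width)."""
    if H is None:
        H = np.zeros((inputs.shape[1], W_hh.shape[0]))
    outputs = []
    for X in inputs:  # walk the time axis
        H = np.tanh(X @ W_xh + H @ W_hh + b_h)
        outputs.append(H)
    return np.stack(outputs, 0), H

def forward_layer_major(inputs, params):
    """Mirror of StackedRNNScratch.forward: finish layer i, then feed it to layer i+1."""
    outputs, Hs = inputs, []
    for W_xh, W_hh, b_h in params:
        outputs, H = rnn_layer(outputs, W_xh, W_hh, b_h)
        Hs.append(H)
    return outputs, Hs

def forward_time_major(inputs, params):
    """Same network, but advance all layers together one time step at a time."""
    Hs = [np.zeros((inputs.shape[1], W_hh.shape[0])) for _, W_hh, _ in params]
    tops = []
    for X in inputs:
        below = X
        for i, (W_xh, W_hh, b_h) in enumerate(params):
            Hs[i] = np.tanh(below @ W_xh + Hs[i] @ W_hh + b_h)
            below = Hs[i]
        tops.append(below)
    return np.stack(tops, 0), Hs

rng = np.random.default_rng(0)
params = [(rng.normal(size=(6 if i == 0 else 4, 4)), rng.normal(size=(4, 4)),
           rng.normal(size=4)) for i in range(3)]
X = rng.normal(size=(7, 2, 6))  # (num_steps, batch, num_inputs)
out_a, Hs_a = forward_layer_major(X, params)
out_b, Hs_b = forward_time_major(X, params)
print(out_a.shape, np.allclose(out_a, out_b))  # (7, 2, 4) True
```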

Training the stacked RNN

Two-layer stack on The Time Machine, trained with a lower learning rate (lr=2) and gradient clipping, since deeper recurrent stacks are harder to optimize:

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn_block = StackedRNNScratch(num_inputs=len(data.vocab),
                              num_hiddens=32, num_layers=2)
model = d2l.RNNLMScratch(rnn_block, vocab_size=len(data.vocab), lr=2)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
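Gradient clipping (gradient_clip_val=1) guards against the exploding gradients that deep recurrent stacks are prone to. Below is a minimal sketch of global-norm clipping, assuming that is the rule the Trainer applies: all gradients are rescaled jointly when their combined L2 norm exceeds the threshold. The function name is ours.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if norm > max_norm:
        # One shared factor keeps the update direction unchanged
        grads = [g * (max_norm / norm) for g in grads]
    return grads, norm

grads, norm = clip_by_global_norm([np.array([3.0]), np.array([4.0])], max_norm=1.0)
print(norm)                                         # 5.0
print(np.sqrt(sum(np.sum(g ** 2) for g in grads)))  # ~1.0 after clipping
```

Scaling every array by the same factor, rather than clipping each one separately, shortens the step without changing its direction.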

Concise: multilayer GRU

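PyTorch’s nn.GRU(..., num_layers=L, dropout=p) collapses the stack into one library call and adds dropout between layers, a standard regularizer for stacked RNNs. Flax linen has no such num_layers option, so the class below builds the stack by hand: it scans an nn.GRUCell over time once per layer and inserts nn.Dropout between layers: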

class GRU(d2l.RNN):
    """The multilayer GRU model."""
    num_hiddens: int
    num_layers: int
    dropout: float = 0

    @nn.compact
    def __call__(self, X, state=None, training=False):
        new_state = []
        if state is None:
            # One initial carry of shape (batch_size, num_hiddens) per layer;
            # X has shape (num_steps, batch_size, num_features)
            batch_size = X.shape[1]
            state = [nn.GRUCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0),
                (batch_size, self.num_hiddens))] * self.num_layers

        # Lift GRUCell over the time axis (axis 0), sharing params across steps
        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        # Every GRU layer except the last is followed by dropout; each layer
        # consumes the (dropped-out) output sequence X of the layer below
        for i in range(self.num_layers - 1):
            layer_i_state, X = GRU(features=self.num_hiddens)(state[i], X)
            new_state.append(layer_i_state)
            X = nn.Dropout(self.dropout, deterministic=not training)(X)

        # Final GRU layer without dropout
        out_state, X = GRU(features=self.num_hiddens)(state[-1], X)
        new_state.append(out_state)
        return X, jnp.array(new_state)
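Stripped of Flax, the dropout in the loop above works like this: inverted dropout (zero each activation with probability p, scale survivors by 1/(1-p)) is applied to the output of every layer except the last, and only while training. The NumPy sketch below treats layers as plain callables and uses names of our own; we understand nn.Dropout to use the same inverted scaling.

```python
import numpy as np

def dropout(X, p, rng, training=True):
    """Inverted dropout: zero entries with probability p, rescale the rest by 1/(1-p)."""
    assert 0.0 <= p < 1.0
    if not training or p == 0.0:
        return X
    keep = rng.random(X.shape) >= p  # True with probability 1 - p
    return np.where(keep, X / (1.0 - p), 0.0)

def stacked_with_dropout(X, layers, p, rng, training=True):
    """Run layers bottom-up, with dropout between layers but not after the last one."""
    for i, layer in enumerate(layers):
        X = layer(X)
        if i < len(layers) - 1:
            X = dropout(X, p, rng, training)
    return X

def identity(X):  # stand-in for a recurrent layer
    return X

rng = np.random.default_rng(0)
X = np.ones(8)
print(stacked_with_dropout(X, [identity] * 2, 0.5, rng))                  # entries are 0 or 2
print(stacked_with_dropout(X, [identity] * 2, 0.5, rng, training=False))  # all ones
```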

Training and decoding

Two-layer GRU LM, same Trainer:

gru = GRU(num_hiddens=32, num_layers=2)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=2)
trainer.fit(model, data)

Decode from a prefix:

model.predict('it has', 20, data.vocab, trainer.state.params)
'it has go back for any the'
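The sample above comes from warming up the state on the prefix and then decoding. Assuming predict follows the usual d2l recipe (feed the prefix one token at a time, then repeatedly append the highest-scoring next token), the loop looks like the sketch below. The step function is a hypothetical one-token interface, not part of d2l.

```python
import numpy as np

def greedy_decode(prefix_ids, num_preds, step, state=None):
    """Warm up on the prefix, then append num_preds argmax tokens.

    step(token_id, state) -> (logits, new_state) is a hypothetical one-step
    interface standing in for the trained language model.
    """
    outputs = [prefix_ids[0]]
    for i in range(len(prefix_ids) + num_preds - 1):
        logits, state = step(outputs[-1], state)
        if i < len(prefix_ids) - 1:
            outputs.append(prefix_ids[i + 1])  # warm-up: ignore logits, force the prefix
        else:
            outputs.append(int(np.argmax(logits)))  # greedy next token
    return outputs

def toy_step(token_id, state, vocab_size=5):
    """Toy 'model' that always predicts token_id + 1 (mod vocab_size)."""
    logits = np.zeros(vocab_size)
    logits[(token_id + 1) % vocab_size] = 1.0
    return logits, state

print(greedy_decode([0, 3], 3, toy_step))  # [0, 3, 4, 0, 1]
```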

Recap

  • Deep RNNs stack L recurrent layers; layer l’s input is layer l-1’s output sequence.
  • The same idea applies to vanilla RNN, LSTM, or GRU cells.
  • Use a lower learning rate, since vertical depth makes optimization noticeably trickier, and (usually) inter-layer dropout to regularize.
  • In PyTorch, nn.GRU(..., num_layers=L, dropout=p) is the production one-liner; in Flax, stack scanned nn.GRUCell layers as in the GRU class above.