from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

A single RNN layer is already deep in time — but within one time step, the path from input to output passes through just one nonlinearity.
Stacking RNN layers makes the model deep along the layer axis too. Each layer sees the previous layer’s hidden states as its input sequence; the topmost layer feeds the readout.
\mathbf{H}_t^{(l)} = \phi_l(\mathbf{H}_t^{(l-1)} \mathbf{W}_{xh}^{(l)} + \mathbf{H}_{t-1}^{(l)} \mathbf{W}_{hh}^{(l)} + \mathbf{b}_h^{(l)}).
Typical sizes: width 64–2048, depth 1–8.
Layer $l$ at time $t$ depends on layer $l$ at time $t-1$ and on layer $l-1$ at time $t$.
A StackedRNNScratch is just a list of RNNScratch cells — each layer’s input width is num_hiddens (except the bottom layer, which sees raw inputs):
The forward pass walks the layers, feeding each layer’s output sequence into the next:
Two-layer stack on The Time Machine. Lower learning rate (lr=2) — deeper recurrent networks are harder to optimize:
# Load The Time Machine dataset in minibatches of 1024 sequences, each
# 32 time steps long.
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
# Two stacked recurrent layers with 32 hidden units each; the input width
# of the stack equals the vocabulary size.
rnn_block = StackedRNNScratch(num_inputs=len(data.vocab),
num_hiddens=32, num_layers=2)
# Language model built around the stacked block, with learning rate lr=2.
model = d2l.RNNLMScratch(rnn_block, vocab_size=len(data.vocab), lr=2)
# Trainer configured for 100 epochs, gradient_clip_val=1, and one GPU.
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

nn.GRU(..., num_layers=L, dropout=p) collapses the stack into one library call — and adds dropout between layers, which is the standard regularizer for stacked RNNs:
class GRU(d2l.RNN):
    """The multilayer GRU model.

    Stacks ``num_layers`` GRU layers. Each layer scans a GRU cell over the
    leading axis of its input sequence, and the output sequence of one layer
    is the input sequence of the next. Dropout is applied to the output of
    every layer except the last one.

    Attributes:
        num_hiddens: Number of hidden units (``features``) of every GRU cell.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout rate applied between consecutive layers.
    """
    num_hiddens: int
    num_layers: int
    dropout: float = 0

    @nn.compact
    def __call__(self, X, state=None, training=False):
        """Run the stacked GRU over an input sequence.

        Args:
            X: Input sequence. It is scanned along axis 0 and its axis 1 is
                used as the batch size (presumably laid out as
                ``(num_steps, batch_size, num_inputs)`` -- confirm against
                the caller).
            state: Optional per-layer initial carries, indexable by layer.
                When ``None``, every layer starts from the carry returned by
                ``GRUCell.initialize_carry``.
            training: When ``True``, the inter-layer dropout is active;
                otherwise dropout is deterministic (disabled).

        Returns:
            A pair ``(outputs, new_state)``: the output sequence of the top
            layer and the final carries of all layers stacked with
            ``jnp.array``.
        """
        new_state = []
        if state is None:
            batch_size = X.shape[1]
            # The same initial carry is shared by all layers; JAX arrays are
            # immutable, so repeating the reference is safe.
            state = [nn.GRUCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0),
                (batch_size, self.num_hiddens))] * self.num_layers

        # Lift GRUCell so that a single call consumes the whole sequence
        # along axis 0 while its parameters are shared across scan steps.
        scan_gru = nn.scan(nn.GRUCell, variable_broadcast="params",
                           in_axes=0, out_axes=0,
                           split_rngs={"params": False})

        # Introduce a dropout layer after every GRU layer except last.
        # Each layer must consume the previous layer's output (X), not the
        # raw input; otherwise stacks deeper than two layers would bypass
        # their lower layers.
        for i in range(self.num_layers - 1):
            layer_i_state, X = scan_gru(features=self.num_hiddens)(state[i], X)
            new_state.append(layer_i_state)
            X = nn.Dropout(self.dropout, deterministic=not training)(X)

        # Final GRU layer without dropout
        out_state, X = scan_gru(features=self.num_hiddens)(state[-1], X)
        new_state.append(out_state)
        return X, jnp.array(new_state)


# Two-layer GRU LM, same Trainer:
nn.GRU(..., num_layers=L, dropout=p) is the production one-liner.