The recurrence in code

Recurrent Neural Networks

A recurrent neural network carries a hidden state \mathbf{h}_t across time steps — a learned summary of all input seen so far:

\mathbf{h}_t = \phi(\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}),

where \phi is an elementwise nonlinearity, typically tanh.
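A single step of the recurrence can be sketched in JAX as below. This is a minimal illustration, not the d2l implementation. It assumes tanh for \phi and uses the row-vector convention of the code later in this section (x_t @ W_xh rather than W_xh x_t). The names rnn_step, num_inputs and num_hiddens are hypothetical.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, h_prev, x_t):
    """One step: h_t = tanh(x_t @ W_xh + h_prev @ W_hh + b).

    Rows are batch examples (row-vector convention). tanh stands in
    for phi; any elementwise nonlinearity fits the formula.
    """
    W_xh, W_hh, b = params
    return jnp.tanh(x_t @ W_xh + h_prev @ W_hh + b)

num_inputs, num_hiddens, batch = 3, 4, 2
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    0.1 * jax.random.normal(k1, (num_inputs, num_hiddens)),   # W_xh
    0.1 * jax.random.normal(k2, (num_hiddens, num_hiddens)),  # W_hh
    jnp.zeros(num_hiddens),                                   # b, broadcast over batch
)
h0 = jnp.zeros((batch, num_hiddens))   # initial hidden state
x1 = jnp.ones((batch, num_inputs))     # one time step of input
h1 = rnn_step(params, h0, x1)
print(h1.shape)  # (2, 4)
```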

The same weights are used at every step, so the parameter count stays constant regardless of sequence length. Context is unbounded in principle: there is no fixed-size window as in n-gram models.
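To see weight sharing concretely, the sketch below unrolls the recurrence with jax.lax.scan and runs it on sequences of length 5 and 50 using the same parameters. It assumes tanh as the activation. The helper rnn_forward and the sizes are illustrative choices, not part of the original slides.

```python
import jax
import jax.numpy as jnp

def rnn_forward(params, xs, h0):
    """Run the recurrence over a sequence with jax.lax.scan.

    xs has shape (num_steps, batch, num_inputs). The same (W_xh, W_hh, b)
    are applied at every step; only the carried hidden state changes.
    """
    W_xh, W_hh, b = params

    def step(h, x_t):
        h = jnp.tanh(x_t @ W_xh + h @ W_hh + b)
        return h, h  # (new carry, per-step output)

    h_last, hs = jax.lax.scan(step, h0, xs)
    return h_last, hs

num_inputs, num_hiddens, batch = 3, 4, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.1 * jax.random.normal(k1, (num_inputs, num_hiddens)),
    0.1 * jax.random.normal(k2, (num_hiddens, num_hiddens)),
    jnp.zeros(num_hiddens),
)
num_params = sum(p.size for p in params)  # 3*4 + 4*4 + 4
h0 = jnp.zeros((batch, num_hiddens))

for T in (5, 50):
    xs = jax.random.normal(k3, (T, batch, num_inputs))
    h_last, hs = rnn_forward(params, xs, h0)
    print(T, hs.shape, num_params)
# 5 (5, 2, 4) 32
# 50 (50, 2, 4) 32
```

The output grows with T, but the parameter count does not.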

Stateful by design

[Figure: An RNN with a hidden state.]

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp

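The naive form: two matrix multiplies, summed. Here X is a batch of 3 inputs with 1 feature each and H is a batch of 3 hidden states with 4 units, so both products have shape (3, 4):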

X, W_xh = jax.random.normal(d2l.get_key(), (3, 1)), jax.random.normal(
                                                        d2l.get_key(), (1, 4))
H, W_hh = jax.random.normal(d2l.get_key(), (3, 4)), jax.random.normal(
                                                        d2l.get_key(), (4, 4))
d2l.matmul(X, W_xh) + d2l.matmul(H, W_hh)
Array([[ 4.384239  , -5.8578386 , -2.01427   ,  0.28860188],
       [ 0.36385813,  0.1319584 , -0.2936549 , -0.6778386 ],
       [ 2.647139  , -1.2814481 , -1.1859436 , -1.3897879 ]],      dtype=float32)

Equivalently, concatenate X and H along columns (axis 1) and W_xh and W_hh along rows (axis 0), then do a single matmul. The result matches up to float32 rounding (compare the last digits):

d2l.matmul(d2l.concat((X, H), 1), d2l.concat((W_xh, W_hh), 0))
Array([[ 4.384239  , -5.8578386 , -2.01427   ,  0.28860193],
       [ 0.3638581 ,  0.1319584 , -0.2936549 , -0.6778386 ],
       [ 2.647139  , -1.2814481 , -1.1859436 , -1.3897878 ]],      dtype=float32)

The “concat then multiply” form replaces two matrix multiplies with one larger one, a common way to fuse the input and recurrent products in RNN implementations.
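The equivalence rests on the block-matrix identity [X | H] [W_xh ; W_hh] = X W_xh + H W_hh. A self-contained check in plain jax.numpy (without the d2l helpers) might look like this. The random shapes mirror the example above, and the tolerance is an assumption chosen for float32 on CPU.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(k1, (3, 1))      # batch of 3 inputs, 1 feature each
W_xh = jax.random.normal(k2, (1, 4))
H = jax.random.normal(k3, (3, 4))      # batch of 3 hidden states, 4 units
W_hh = jax.random.normal(k4, (4, 4))

two_matmuls = X @ W_xh + H @ W_hh
# [X | H] has shape (3, 5); [W_xh ; W_hh] has shape (5, 4).
one_matmul = jnp.concatenate((X, H), axis=1) @ jnp.concatenate((W_xh, W_hh), axis=0)

# Equal up to float32 rounding, not necessarily bit-for-bit.
print(jnp.allclose(two_matmuls, one_matmul, atol=1e-5))  # True
```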

As a language model

  • Embedding maps token id → vector \mathbf{x}_t.
  • RNN updates the hidden state \mathbf{h}_t.
  • Linear head projects \mathbf{h}_t to vocab logits; softmax → P(x_{t+1} \mid x_{\le t}).
  • Loss = cross-entropy with the next-token target.
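The four pieces above can be wired together in a short JAX sketch: embedding lookup, tanh RNN via lax.scan, linear head, log-softmax, and next-token cross-entropy. The parameter dictionary keys, the vocabulary size of 28 and the token ids are hypothetical placeholders, and this is not the from-scratch code of the following sections.

```python
import jax
import jax.numpy as jnp

def init_params(key, vocab_size, embed_dim, num_hiddens):
    ks = jax.random.split(key, 4)
    s = 0.1  # small init keeps initial logits near zero
    return {
        "embed": s * jax.random.normal(ks[0], (vocab_size, embed_dim)),
        "W_xh": s * jax.random.normal(ks[1], (embed_dim, num_hiddens)),
        "W_hh": s * jax.random.normal(ks[2], (num_hiddens, num_hiddens)),
        "b_h": jnp.zeros(num_hiddens),
        "W_hq": s * jax.random.normal(ks[3], (num_hiddens, vocab_size)),
        "b_q": jnp.zeros(vocab_size),
    }

def lm_loss(params, inputs, targets):
    """Mean next-token cross-entropy; inputs, targets: int arrays (num_steps,)."""
    xs = params["embed"][inputs]                      # token id -> vector x_t

    def step(h, x_t):                                 # hidden-state update
        h = jnp.tanh(x_t @ params["W_xh"] + h @ params["W_hh"] + params["b_h"])
        return h, h

    h0 = jnp.zeros(params["W_hh"].shape[0])
    _, hs = jax.lax.scan(step, h0, xs)                # (num_steps, num_hiddens)
    logits = hs @ params["W_hq"] + params["b_q"]      # (num_steps, vocab_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # log P(target_t | x_<=t) at each step, negated
    nll = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1)
    return nll.mean()

vocab_size = 28
params = init_params(jax.random.PRNGKey(0), vocab_size, embed_dim=8, num_hiddens=16)
inputs = jnp.array([0, 1, 2, 3])
targets = jnp.array([1, 2, 3, 4])
loss = lm_loss(params, inputs, targets)
grads = jax.grad(lm_loss)(params, inputs, targets)  # gradients for every parameter
print(loss)                   # near log(28) at init, since predictions are nearly uniform
print(grads["W_hh"].shape)    # (16, 16)
```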

Character LM training

Input “machin”, target “achine”: the target is the same text shifted left by one character, and the same RNN predicts each next character.
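Building such a pair is just slicing a token sequence two ways. The sketch assumes a hypothetical character vocabulary built from the single word "machine"; a real character LM would build its vocabulary from the whole corpus.

```python
import jax.numpy as jnp

text = "machine"
vocab = sorted(set(text))                          # hypothetical tiny vocabulary
char_to_id = {c: i for i, c in enumerate(vocab)}
ids = jnp.array([char_to_id[c] for c in text])

inputs, targets = ids[:-1], ids[1:]                # target = input shifted by one
print("".join(vocab[i] for i in inputs.tolist()))   # machin
print("".join(vocab[i] for i in targets.tolist()))  # achine
```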

The next two sections build this end to end, first from scratch and then concisely with high-level APIs.

Recap

  • RNN: \mathbf{h}_t = \phi(\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}).
  • Same parameters at every time step; hidden state carries arbitrarily long context (in theory).
  • Trains by backprop through time — gradients flow from \mathbf{h}_T back to every earlier hidden state.
  • Vanilla RNNs suffer from vanishing and exploding gradients on long sequences. Gradient clipping controls the exploding case; gated architectures (LSTM and GRU, next chapter) mitigate vanishing gradients.
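Backprop through time needs no special code in JAX: reverse-mode differentiation through lax.scan is BPTT. The sketch below uses jax.jacrev to compute the Jacobian d h_T / d h_0 and prints its norm as T grows. The small recurrent weights, zero bias and pre-projected inputs are assumptions chosen to make the vanishing effect visible; larger weights could instead make the norm grow.

```python
import jax
import jax.numpy as jnp

num_hiddens = 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# Assumed small recurrent weights so repeated products shrink.
W_hh = 0.25 * jax.random.normal(k1, (num_hiddens, num_hiddens)) / jnp.sqrt(num_hiddens)
# Assumed inputs already multiplied by W_xh, one row per time step.
xs_all = jax.random.normal(k2, (50, num_hiddens))

def final_state(h0, xs):
    """Unroll h_t = tanh(x_t + h_{t-1} @ W_hh) and return h_T."""
    def step(h, x_t):
        return jnp.tanh(x_t + h @ W_hh), None
    h_T, _ = jax.lax.scan(step, h0, xs)
    return h_T

h0 = jnp.zeros(num_hiddens)
norms = []
for T in (1, 10, 50):
    # Reverse-mode Jacobian through the unrolled scan, i.e. BPTT.
    J = jax.jacrev(final_state)(h0, xs_all[:T])   # d h_T / d h_0, shape (8, 8)
    norms.append(float(jnp.linalg.norm(J)))       # Frobenius norm
    print(T, norms[-1])
```

The printed norms shrink rapidly with T: the gradient signal from h_T barely reaches h_0.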