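```python
from d2l import jax as d2l
import jax
from jax import numpy as jnp
```

A recurrent neural network (RNN) carries a hidden state $\mathbf{h}_t$ across time steps, a learned summary of all input seen so far. At each step the state is updated from the current input $\mathbf{x}_t$ and the previous state: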
$$\mathbf{h}_t = \phi(\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}),$$

where $\phi$ is an elementwise activation function such as $\tanh$.
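A single update step can be sketched directly from the equation. This is a minimal illustration, not the implementation used later: it assumes $\phi = \tanh$, the column-vector convention of the equation (so `W_xh @ x_t`), and hypothetical sizes of 3 input features and 4 hidden units; `rnn_step` is a name chosen here for illustration.

```python
import jax
from jax import numpy as jnp

def rnn_step(params, h_prev, x_t):
    """Hypothetical helper: one update h_t = tanh(W_xh x_t + W_hh h_{t-1} + b).

    Column-vector convention, phi assumed to be tanh.
    Shapes: x_t (d,), h_prev (h,), W_xh (h, d), W_hh (h, h), b (h,).
    """
    W_xh, W_hh, b = params
    return jnp.tanh(W_xh @ x_t + W_hh @ h_prev + b)

num_inputs, num_hiddens = 3, 4  # assumed sizes, for illustration only
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (num_hiddens, num_inputs)) * 0.1,   # W_xh
          jax.random.normal(k2, (num_hiddens, num_hiddens)) * 0.1,  # W_hh
          jnp.zeros(num_hiddens))                                    # b

h0 = jnp.zeros(num_hiddens)                 # initial hidden state
x1 = jax.random.normal(k3, (num_inputs,))   # first input
h1 = rnn_step(params, h0, x1)               # new hidden state
print(h1.shape)  # (4,)
```

With a zero initial state and zero bias, the first step reduces to $\tanh(\mathbf{W}_{xh}\mathbf{x}_1)$.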
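The same weights are reused at every time step, so the number of parameters stays constant no matter how long the sequence is. There is also no fixed-size context window as in an n-gram model: in principle, the hidden state can carry information from arbitrarily far back.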
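To make the weight sharing concrete, here is a sketch that unrolls the step over a whole sequence with `jax.lax.scan`, passing the same `params` to every step. The helper names `rnn_step` and `rnn_forward`, the sizes, and the choice $\phi = \tanh$ are assumptions for illustration. The parameter count depends only on the input and hidden sizes, $4 \times 3 + 4 \times 4 + 4 = 32$, whether the sequence has 5 steps or 500.

```python
import jax
from jax import numpy as jnp

def rnn_step(params, h_prev, x_t):
    # Hypothetical helper: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)
    W_xh, W_hh, b = params
    return jnp.tanh(W_xh @ x_t + W_hh @ h_prev + b)

def rnn_forward(params, xs, h0):
    """Hypothetical helper: unroll over xs of shape (T, d) with shared params."""
    def step(h, x_t):
        h_new = rnn_step(params, h, x_t)
        return h_new, h_new              # (next carry, per-step output)
    h_T, hs = jax.lax.scan(step, h0, xs)  # scans over the leading (time) axis
    return h_T, hs                       # final state, all states of shape (T, h)

num_inputs, num_hiddens = 3, 4  # assumed sizes
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(k1, (num_hiddens, num_inputs)) * 0.1,
          jax.random.normal(k2, (num_hiddens, num_hiddens)) * 0.1,
          jnp.zeros(num_hiddens))
num_params = sum(p.size for p in params)  # 12 + 16 + 4

h0 = jnp.zeros(num_hiddens)
short_seq = jax.random.normal(k3, (5, num_inputs))
long_seq = jax.random.normal(k4, (500, num_inputs))
_, hs_short = rnn_forward(params, short_seq, h0)
_, hs_long = rnn_forward(params, long_seq, h0)
print(num_params, hs_short.shape, hs_long.shape)  # 32 (5, 4) (500, 4)
```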
Figure: An RNN with a hidden state.
The naive form: two separate matrix multiplications (the weighted input and hidden terms inside $\phi$, without the bias), summed:
```
Array([[ 4.384239 , -5.8578386 , -2.01427 , 0.28860188],
       [ 0.36385813, 0.1319584 , -0.2936549 , -0.6778386 ],
       [ 2.647139 , -1.2814481 , -1.1859436 , -1.3897879 ]], dtype=float32)
```
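The code that produced this output is not shown. A minimal sketch of the same computation follows, assuming a batch of 3 examples, 1 input feature and 4 hidden units (shapes picked to match the (3, 4) output; the printed values will differ, since they depend on the random draws). It uses the batched row convention, one example per row, so the products are `X @ W_xh` and `H @ W_hh` rather than the column-vector form of the equation.

```python
import jax
from jax import numpy as jnp

# Assumed sizes: batch of 3, 1 input feature, 4 hidden units
kx, kwx, kh, kwh = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(kx, (3, 1))      # inputs, one row per example
W_xh = jax.random.normal(kwx, (1, 4))  # input-to-hidden weights
H = jax.random.normal(kh, (3, 4))      # previous hidden states
W_hh = jax.random.normal(kwh, (4, 4))  # hidden-to-hidden weights

# Naive form: two separate matrix multiplications, then a sum
naive = jnp.matmul(X, W_xh) + jnp.matmul(H, W_hh)
print(naive.shape)  # (3, 4)
```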
Equivalently, concatenate the input and hidden state column-wise, stack the two weight matrices row-wise, and multiply once. The result matches up to float32 rounding in the last digit:
```
Array([[ 4.384239 , -5.8578386 , -2.01427 , 0.28860193],
       [ 0.3638581 , 0.1319584 , -0.2936549 , -0.6778386 ],
       [ 2.647139 , -1.2814481 , -1.1859436 , -1.3897878 ]], dtype=float32)
```
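The equivalence is block matrix multiplication: $[\mathbf{X}\;\mathbf{H}]\begin{bmatrix}\mathbf{W}_{xh}\\ \mathbf{W}_{hh}\end{bmatrix} = \mathbf{X}\mathbf{W}_{xh} + \mathbf{H}\mathbf{W}_{hh}$. The self-contained sketch below checks it numerically, under assumed shapes (batch of 3, 1 input feature, 4 hidden units) and a random setup chosen here for illustration.

```python
import jax
from jax import numpy as jnp

# Assumed sizes: batch of 3, 1 input feature, 4 hidden units
kx, kwx, kh, kwh = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(kx, (3, 1))
W_xh = jax.random.normal(kwx, (1, 4))
H = jax.random.normal(kh, (3, 4))
W_hh = jax.random.normal(kwh, (4, 4))

# Naive form: two matmuls, summed
naive = jnp.matmul(X, W_xh) + jnp.matmul(H, W_hh)

# Concatenate inputs and hidden states along columns: shape (3, 1 + 4)
XH = jnp.concatenate((X, H), axis=1)
# Stack the weight matrices along rows: shape (1 + 4, 4)
W = jnp.concatenate((W_xh, W_hh), axis=0)
fused = jnp.matmul(XH, W)  # a single matmul, shape (3, 4)

# Compare with a tolerance: the two forms accumulate in a different order
same = bool(jnp.allclose(naive, fused, atol=1e-5))
print(same)
```

The comparison uses `jnp.allclose` rather than exact equality because the two forms sum the products in a different order, which is also why the last printed digit differs between the two outputs above.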
Some framework RNN implementations use this “concat then multiply” form, which replaces two smaller matrix multiplications with a single larger one.
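For example, in a character-level language model built on an RNN, the input sequence “machin” is paired with the target sequence “achine”: the target is the input shifted by one character, so at each time step the model is trained to predict the next character.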
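A minimal sketch of how such an input/target pair can be built from a character sequence. The tiny vocabulary and integer encoding here are illustrative assumptions, not the tokenization used in the following sections.

```python
from jax import numpy as jnp

text = "machine"
vocab = sorted(set(text))                       # hypothetical tiny vocabulary
char_to_idx = {c: i for i, c in enumerate(vocab)}
ids = jnp.array([char_to_idx[c] for c in text])  # integer-encoded characters

# Target at step t is the input at step t + 1
inputs, targets = ids[:-1], ids[1:]

input_str = "".join(vocab[int(i)] for i in inputs)
target_str = "".join(vocab[int(i)] for i in targets)
print(input_str, target_str)  # machin achine
```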
The next two sections implement this model end to end, first from scratch and then concisely with high-level APIs.