Concise Implementation of Recurrent Neural Networks

Concise RNNs

The same character-level language model, now built on the framework’s built-in recurrent layer. The from-scratch cell, time-step unrolling, and output projection boil down to a few lines:

  • The built-in layer handles the recurrence. In PyTorch that is nn.RNN(input_size, hidden_size), which can use cuDNN kernels on GPU; the Flax code below unrolls nn.SimpleCell over time steps with nn.scan.
  • Reuse the RNNLMScratch head — it doesn’t care whether the cell is hand-rolled.
  • Same Trainer, same gradient clipping, same data.

End result: faster training, ~5× fewer lines of code, identical mathematics.
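To make "identical mathematics" concrete, here is a minimal sketch of the recurrence that any built-in vanilla RNN layer has to compute, written with jax.lax.scan. The function name rnn_forward, the parameter names W_xh, W_hh, b_h, and the toy sizes are illustrative assumptions, not part of the d2l code.

```python
import jax
from jax import numpy as jnp


def rnn_forward(params, inputs, h0):
    """Unroll a vanilla tanh RNN over the leading (time) axis."""
    W_xh, W_hh, b_h = params

    def step(h, x):
        # One time step: h_t = tanh(x_t W_xh + h_{t-1} W_hh + b_h)
        h = jnp.tanh(x @ W_xh + h @ W_hh + b_h)
        return h, h  # (new carry, per-step output)

    h_final, outputs = jax.lax.scan(step, h0, inputs)
    return outputs, h_final


# Toy sizes, chosen only for illustration
num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 4, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (0.1 * jax.random.normal(k1, (num_inputs, num_hiddens)),
          0.1 * jax.random.normal(k2, (num_hiddens, num_hiddens)),
          jnp.zeros(num_hiddens))
inputs = jax.random.normal(k3, (num_steps, batch_size, num_inputs))
h0 = jnp.zeros((batch_size, num_hiddens))

outputs, h_final = rnn_forward(params, inputs, h0)
print(outputs.shape, h_final.shape)  # (5, 2, 3) (2, 3)
```

The last per-step output equals the final hidden state, which is the same (outputs, state) pair the RNN class in the next section returns.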

The model

The built-in cell replaces the hand-written recurrence; the rest of the language-model scaffold is inherited from the from-scratch base class d2l.RNNLMScratch:

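import jax  # needed for jax.random.PRNGKey below
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp


class RNN(nn.Module):
    """The RNN model implemented with high-level APIs."""
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        # inputs is time-major: (num_steps, batch_size, num_inputs).
        # `training` is unused by this plain RNN; it is accepted so that
        # RNNLM.forward can pass self.training as a third argument.
        if H is None:
            # Initial hidden state of shape (batch_size, num_hiddens)
            batch_size = inputs.shape[1]
            H = nn.SimpleCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), (batch_size, self.num_hiddens))

        # Unroll the cell along axis 0 (time), sharing parameters across steps
        SimpleRNN = nn.scan(nn.SimpleCell, variable_broadcast="params",
                            in_axes=0, out_axes=0,
                            split_rngs={"params": False})

        H, outputs = SimpleRNN(features=self.num_hiddens)(H, inputs)
        return outputs, H


class RNNLM(d2l.RNNLMScratch):
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        # Projects each hidden state to vocabulary logits
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        # (num_steps, batch_size, vocab) -> (batch_size, num_steps, vocab)
        return d2l.swapaxes(self.linear(hiddens), 0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)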

Sanity check

Untrained model still runs — predictions are random characters, but shapes line up. This check isolates API wiring from learning quality:

Training and decoding

Same Trainer, with gradient_clip_val=1 on the optimizer:
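For reference, here is a minimal sketch of what clipping with a threshold of 1 does, assuming gradient_clip_val is a bound on the global L2 norm of all gradients. The helper clip_by_global_norm and the toy gradient values are illustrative; this is not the Trainer's actual code.

```python
import jax
from jax import numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescale a pytree of gradients so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale is 1 when the norm is already within the bound
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, norm


# Toy gradients with a deliberately large norm: sqrt(4 * 9 + 2 * 16) = sqrt(68)
grads = {"W": jnp.full((2, 2), 3.0), "b": jnp.full((2,), 4.0)}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)

new_norm = jnp.sqrt(sum(jnp.sum(g ** 2)
                        for g in jax.tree_util.tree_leaves(clipped)))
print(float(norm), float(new_norm))
```

Clipping rescales all gradients by one common factor, so the update direction is preserved and only its length is bounded.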

Output looks like simple English-shaped text: the same character-level statistics the from-scratch version learned, in much less training time.

Recap

  • The built-in layer bundles the cell, the unrolling over time steps, and (in PyTorch’s nn.RNN, via cuDNN) optimized GPU kernels in one stock component.
  • Reuse the from-scratch LM wrapper — only the cell changes.
  • Same scaffold accepts LSTM and GRU layers (nn.LSTM and nn.GRU in PyTorch, nn.LSTMCell and nn.GRUCell in Flax) as drop-in replacements with better long-range gradient behavior.
  • The framework version trains noticeably faster than the from-scratch version on the same hardware.