import jax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp

The same character-level language model, now built on the framework’s built-in RNN components. The cell, the unroll, and the projection that were written from scratch boil down to a few lines:
In PyTorch, nn.RNN(input_size, hidden_size) handles the recurrence, including hardware-accelerated cuDNN kernels on GPU; the Flax code here scans the built-in nn.SimpleCell over the time axis with nn.scan.
The model reuses the RNNLMScratch head, which doesn’t care whether the cell is hand-rolled.
Same Trainer, same gradient clipping, same data.
End result: faster training, ~5× fewer lines of code, identical mathematics.
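To make "cell + unroll + projection" concrete, here is a minimal pure-JAX sketch of what a stock recurrent layer computes. The names rnn_step and unroll, the weight names, and all sizes are illustrative assumptions, not d2l or Flax API; the recurrence is the standard tanh RNN.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, h, x):
    """Cell: one tanh RNN update, h' = tanh(x W_xh + h W_hh + b_h)."""
    h_new = jnp.tanh(x @ params["W_xh"] + h @ params["W_hh"] + params["b_h"])
    return h_new, h_new  # (carry for the next step, output of this step)

def unroll(params, xs, h0):
    """Unroll: apply the cell along the leading (time) axis with lax.scan."""
    step = lambda h, x: rnn_step(params, h, x)
    return jax.lax.scan(step, h0, xs)

# Illustrative sizes (assumptions, not values from the text)
num_steps, batch_size, vocab_size, num_hiddens = 5, 2, 28, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_xh": 0.01 * jax.random.normal(k1, (vocab_size, num_hiddens)),
    "W_hh": 0.01 * jax.random.normal(k2, (num_hiddens, num_hiddens)),
    "b_h": jnp.zeros(num_hiddens),
    "W_hq": 0.01 * jax.random.normal(k3, (num_hiddens, vocab_size)),
    "b_q": jnp.zeros(vocab_size),
}

tokens = jnp.zeros((num_steps, batch_size), dtype=jnp.int32)
xs = jax.nn.one_hot(tokens, vocab_size)   # (num_steps, batch_size, vocab_size)
h0 = jnp.zeros((batch_size, num_hiddens))

h_last, hiddens = unroll(params, xs, h0)  # hiddens: (num_steps, batch_size, num_hiddens)
# Projection: map every hidden state to vocabulary logits
logits = hiddens @ params["W_hq"] + params["b_q"]
```

A built-in layer replaces rnn_step and unroll; the projection stays a separate dense layer.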
The built-in RNN cell handles the recurrence, and the rest of the LM scaffold is handed off to the from-scratch base class:
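class RNN(nn.Module):
    """The RNN model implemented with high-level APIs."""
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        # inputs: (num_steps, batch_size, num_inputs), time-major.
        # `training` is unused here; it is accepted because RNNLM.forward
        # passes it as a third positional argument.
        if H is None:
            batch_size = inputs.shape[1]
            # Initial hidden state of shape (batch_size, num_hiddens)
            H = nn.SimpleCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), (batch_size, self.num_hiddens))
        # Lift the cell over the leading (time) axis, sharing its parameters
        SimpleRNN = nn.scan(nn.SimpleCell, variable_broadcast="params",
                            in_axes=0, out_axes=0,
                            split_rngs={"params": False})
        H, outputs = SimpleRNN(features=self.num_hiddens)(H, inputs)
        return outputs, H

class RNNLM(d2l.RNNLMScratch):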
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        # (num_steps, batch_size, vocab_size) -> (batch_size, num_steps, vocab_size)
        return d2l.swapaxes(self.linear(hiddens), 0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)

The untrained model still runs: its predictions are random characters, but the shapes line up. This check isolates API wiring from learning quality:
Training uses the same Trainer as before, with gradient_clip_val=1:
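What a clip value of 1 amounts to can be sketched as global-norm clipping over the gradient pytree. The helper names global_norm and clip_gradients_sketch are illustrative assumptions; this is not necessarily how the d2l Trainer implements clipping internally.

```python
import jax
import jax.numpy as jnp

def global_norm(grads):
    """L2 norm over all leaves of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

def clip_gradients_sketch(grads, max_norm):
    """Rescale all gradients together so their global norm is at most max_norm."""
    norm = global_norm(grads)
    # Scale is 1 when the norm is already small enough; tiny epsilon avoids 0/0
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# Made-up gradients with a large norm: they get scaled down
grads = {"W": jnp.full((2, 2), 3.0), "b": jnp.full((2,), 4.0)}
clipped = clip_gradients_sketch(grads, 1.0)

# Made-up gradients with a small norm: they pass through unchanged
small = {"W": jnp.full((2, 2), 0.1)}
small_clipped = clip_gradients_sketch(small, 1.0)
```

Scaling all leaves by one shared factor preserves the gradient direction, which is why clipping by global norm is the usual choice for RNNs.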
The output looks like simple English-shaped text: the model picks up the same character-level statistics that the from-scratch version learned, in much less training time.
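nn.RNN bundles the cell, the unroll over time steps, and (where cuDNN is used) GPU kernels in one stock layer.
Gated layers (nn.LSTM and nn.GRU in PyTorch; the nn.LSTMCell and nn.GRUCell cells in Flax) are drop-in replacements with better long-range gradient behavior.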