Recurrent Neural Network Implementation from Scratch

RNNs from Scratch

A character-level language model trained on The Time Machine, built from nothing but tensor ops. Four pieces:

  1. RNN cell: the recurrence \mathbf{h}_t = \tanh(\mathbf{x}_t \mathbf{W}_{xh} + \mathbf{h}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h), written with row-vector inputs \mathbf{x}_t to match the code below.
  2. LM wrapper: one-hot encode tokens, then project hidden states to vocabulary logits.
  3. Gradient clipping: keep BPTT gradients bounded.
  4. Training, then decoding to generate continuations of a prompt.

Working at the character level keeps the vocabulary tiny (about 30 tokens). One-hot inputs therefore suffice without a learned embedding, and the whole model fits comfortably in a notebook.
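Character-level tokenization is just a lookup table from characters to integer ids. Below is a minimal sketch in plain Python. The sample string and the '<unk>' slot at index 0 are assumptions for illustration. This is not the actual preprocessing inside d2l.TimeMachine.

```python
# Minimal sketch of a character vocabulary. The sample text is illustrative only.
text = "the time traveller for so it will be convenient to speak of him"

# Sorted unique characters form the vocabulary. Index 0 is an assumed '<unk>'
# slot for characters that were not seen when the vocabulary was built.
idx_to_token = ['<unk>'] + sorted(set(text))
token_to_idx = {tok: i for i, tok in enumerate(idx_to_token)}

def encode(s):
    """Map a string to integer ids; unknown characters map to 0."""
    return [token_to_idx.get(ch, 0) for ch in s]

def decode(ids):
    """Map integer ids back to a string."""
    return ''.join(idx_to_token[i] for i in ids)

ids = encode("time travel")
print(decode(ids))  # -> time travel
```

The vocabulary size is the number of distinct characters plus one. That number, not the corpus length, sets the width of the one-hot inputs and of the output logits.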

The RNN cell

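Parameters: \mathbf{W}_{xh} (num_inputs × num_hiddens), \mathbf{W}_{hh} (num_hiddens × num_hiddens), and the bias \mathbf{b}_h (num_hiddens). The weights are drawn from a zero-mean Gaussian with standard deviation sigma (0.01 by default), and the bias starts at zero. This keeps early pre-activations small and keeps tanh out of saturation: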

%matplotlib inline
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
import math
class RNNScratch(nn.Module):
    """The RNN model implemented from scratch."""
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        self.W_xh = self.param('W_xh', nn.initializers.normal(self.sigma),
                               (self.num_inputs, self.num_hiddens))
        self.W_hh = self.param('W_hh', nn.initializers.normal(self.sigma),
                               (self.num_hiddens, self.num_hiddens))
        self.b_h = self.param('b_h', nn.initializers.zeros, (self.num_hiddens,))
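A NumPy mirror of setup above makes the shapes and the role of sigma concrete without needing a GPU. This is a sketch under stated assumptions: the function name init_rnn_params and the seed are hypothetical, and NumPy's random generator stands in for Flax's initializers.

```python
import numpy as np

def init_rnn_params(num_inputs, num_hiddens, sigma=0.01, seed=0):
    """Hypothetical NumPy mirror of RNNScratch.setup: Gaussian weights with
    standard deviation sigma, zero bias."""
    rng = np.random.default_rng(seed)
    return {
        'W_xh': rng.normal(0.0, sigma, (num_inputs, num_hiddens)),
        'W_hh': rng.normal(0.0, sigma, (num_hiddens, num_hiddens)),
        'b_h': np.zeros(num_hiddens),
    }

params = init_rnn_params(num_inputs=16, num_hiddens=32)
for name, p in params.items():
    print(name, p.shape)

# A one-hot input selects one row of W_xh, so each pre-activation is a single
# weight of size roughly sigma. The tanh slope 1 - tanh(z)**2 stays near 1.
x = np.zeros(16)
x[3] = 1.0
z = x @ params['W_xh'] + params['b_h']
min_slope = np.min(1 - np.tanh(z) ** 2)
print(min_slope)
```

With a larger sigma, the pre-activations would land further into tanh's flat tails, where gradients shrink.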

Forward, unrolled

Walk a length-T input one step at a time, carrying the hidden state forward:

@d2l.add_to_class(RNNScratch)
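def __call__(self, inputs, state=None):
    # inputs: (num_steps, batch_size, num_inputs), time-major.
    # state: None (start from zeros) or the (batch_size, num_hiddens) array
    # returned as the final state of a previous call.
    outputs = []
    for X in inputs:  # X: (batch_size, num_inputs)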
        state = d2l.tanh(d2l.matmul(X, self.W_xh) + (
            d2l.matmul(state, self.W_hh) if state is not None else 0)
                         + self.b_h)
        outputs.append(state)
    return outputs, state
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = d2l.ones((num_steps, batch_size, num_inputs))
(outputs, state), _ = rnn.init_with_output(d2l.get_key(), X)

Sanity check on output shapes:

def check_len(a, n):
    """Check the length of a list."""
    assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'
    
def check_shape(a, shape):
    """Check the shape of a tensor."""
    assert a.shape == shape, \
            f'tensor\'s shape {a.shape} != expected shape {shape}'

check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
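Here is the same unrolled recurrence in plain NumPy, as a framework-free sketch of __call__. The rnn_forward name and the explicit zero initial state are assumptions of this sketch. Starting from a zero state is equivalent to skipping the W_hh term at the first step, because 0 @ W_hh = 0.

```python
import numpy as np

def rnn_forward(params, inputs, state=None):
    """Unrolled RNN forward pass (NumPy sketch of RNNScratch.__call__).

    inputs: (num_steps, batch_size, num_inputs), time-major.
    state:  None or (batch_size, num_hiddens).
    Returns the list of per-step hidden states and the final state.
    """
    num_steps, batch_size, _ = inputs.shape
    if state is None:
        # Zeros reproduce the "else 0" branch of the JAX code.
        state = np.zeros((batch_size, params['W_hh'].shape[0]))
    outputs = []
    for X in inputs:  # X: (batch_size, num_inputs)
        state = np.tanh(X @ params['W_xh'] + state @ params['W_hh'] + params['b_h'])
        outputs.append(state)
    return outputs, state

rng = np.random.default_rng(0)
num_steps, batch_size, num_inputs, num_hiddens = 100, 2, 16, 32
params = {
    'W_xh': rng.normal(0, 0.01, (num_inputs, num_hiddens)),
    'W_hh': rng.normal(0, 0.01, (num_hiddens, num_hiddens)),
    'b_h': np.zeros(num_hiddens),
}
outputs, state = rnn_forward(params, np.ones((num_steps, batch_size, num_inputs)))
assert len(outputs) == num_steps
assert outputs[0].shape == (batch_size, num_hiddens)
assert np.array_equal(outputs[-1], state)  # the last output is the final state
```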

Wrapping as a language model

Add a vocab-sized output projection on top of the RNN’s hidden states. This is the LM wrapper we’ll train:

class RNNLMScratch(d2l.Classifier):
    """The RNN-based language model implemented from scratch."""
    rnn: nn.Module
    vocab_size: int
    lr: float = 0.01

    def setup(self):
        self.W_hq = self.param('W_hq', nn.initializers.normal(self.rnn.sigma),
                               (self.rnn.num_hiddens, self.vocab_size))
        self.b_q = self.param('b_q', nn.initializers.zeros, (self.vocab_size,))

    def training_step(self, params, batch, state):
        value, grads = jax.value_and_grad(
            self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
        l, _ = value
        self.plot('ppl', d2l.exp(l), train=True)
        return value, grads

    def validation_step(self, params, batch, state):
        l, _ = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('ppl', d2l.exp(l), train=False)

Training objective

At every time step, the model predicts the next character. Flatten batch and time, apply cross-entropy, and average:

\mathcal{L} = \frac{1}{BT}\sum_{b,t} -\log P(x_{b,t+1} \mid x_{b,\le t}).

Perplexity is \exp(\mathcal{L}); lower means fewer effective choices for the next character.
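A NumPy sketch of the objective follows. It flattens (B, T, V) logits, takes a numerically stable log-softmax, averages the negative log-likelihood of the targets, and exponentiates. The helper name and the toy vocabulary size V = 30 are illustrative assumptions. Uniform logits give a perplexity of exactly V, which is the "fewer effective choices" reading in its extreme form.

```python
import numpy as np

def cross_entropy_and_perplexity(logits, targets):
    """logits: (B, T, V) unnormalized scores; targets: (B, T) next-token ids.
    Flattens batch and time, averages -log p(target), returns (loss, exp(loss))."""
    B, T, V = logits.shape
    flat = logits.reshape(B * T, V)
    # Stable log-softmax: subtract each row's max before exponentiating.
    shifted = flat - flat.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(B * T), targets.reshape(-1)]
    loss = nll.mean()
    return loss, np.exp(loss)

# Uniform logits: every character is equally likely, so perplexity equals V.
V = 30
logits = np.zeros((2, 5, V))
targets = np.random.default_rng(0).integers(0, V, size=(2, 5))
loss, ppl = cross_entropy_and_perplexity(logits, targets)
print(round(float(ppl), 6))  # -> 30.0
```

A model that puts almost all probability on the correct next character drives perplexity toward its floor of 1.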

Inputs as one-hot vectors

Tokens come in as integer ids; the RNN expects vectors. One-hot encoding is the simplest input embedding:

jax.nn.one_hot(jnp.array([0, 2]), 5)
Array([[1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]], dtype=float32)

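A (batch_size, num_steps) minibatch of token ids becomes a (num_steps, batch_size, vocab_size) one-hot tensor. The transpose puts time first, so the RNN's forward loop can iterate over time steps: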

@d2l.add_to_class(RNNLMScratch)
def one_hot(self, X):    
    # Output shape: (num_steps, batch_size, vocab_size)    
    return jax.nn.one_hot(X.T, self.vocab_size)
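Here is the same encoding in NumPy, as a sketch. Indexing rows of an identity matrix with the transposed id matrix yields the time-major one-hot tensor directly. The helper name one_hot_time_major is hypothetical.

```python
import numpy as np

def one_hot_time_major(X, vocab_size):
    """X: (batch_size, num_steps) integer ids ->
    (num_steps, batch_size, vocab_size) float32 one-hot, time-major."""
    # Row i of the identity matrix is the one-hot vector for id i.
    return np.eye(vocab_size, dtype=np.float32)[X.T]

X = np.array([[0, 2, 1],
              [3, 3, 0]])          # batch_size=2, num_steps=3
embs = one_hot_time_major(X, vocab_size=5)
print(embs.shape)                  # -> (3, 2, 5)
print(embs[1, 0])                  # step 1 of sequence 0 is id 2 -> [0. 0. 1. 0. 0.]
```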

Output projection

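Project the hidden state at every time step through the output head (\mathbf{W}_{hq}, \mathbf{b}_q), then stack along axis 1 so the logits come out batch-first as (batch_size, num_steps, vocab_size):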

@d2l.add_to_class(RNNLMScratch)
def output_layer(self, rnn_outputs):
    outputs = [d2l.matmul(H, self.W_hq) + self.b_q for H in rnn_outputs]
    return d2l.stack(outputs, 1)

@d2l.add_to_class(RNNLMScratch)
def forward(self, X, state=None):
    embs = self.one_hot(X)
    rnn_outputs, _ = self.rnn(embs, state)
    return self.output_layer(rnn_outputs)
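Projecting each step separately and then stacking is equivalent to a single contraction over the hidden dimension. The NumPy sketch below uses random hidden states with arbitrary sizes. It checks that the loop in output_layer and one einsum agree, including the swap from time-major to batch-first.

```python
import numpy as np

rng = np.random.default_rng(0)
num_steps, batch_size, num_hiddens, vocab_size = 4, 2, 8, 5
rnn_outputs = [rng.normal(size=(batch_size, num_hiddens)) for _ in range(num_steps)]
W_hq = rng.normal(0, 0.01, (num_hiddens, vocab_size))
b_q = np.zeros(vocab_size)

# Per-step projection stacked on axis 1, mirroring output_layer:
# result is (batch_size, num_steps, vocab_size).
logits_loop = np.stack([H @ W_hq + b_q for H in rnn_outputs], axis=1)

# One batched contraction. H_all is (num_steps, batch_size, num_hiddens);
# 'tbh,hv->btv' contracts h and also moves batch in front of time.
H_all = np.stack(rnn_outputs)
logits_einsum = np.einsum('tbh,hv->btv', H_all, W_hq) + b_q

print(logits_loop.shape)                        # -> (2, 4, 5)
print(np.allclose(logits_loop, logits_einsum))  # -> True
```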

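Smoke test: the input shape is (batch_size, num_steps) and the output shape is (batch_size, num_steps, vocab_size). Here num_inputs doubles as the vocabulary size, because the one-hot input width must equal it: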

model = RNNLMScratch(rnn, num_inputs)
outputs, _ = model.init_with_output(d2l.get_key(),
                                    d2l.ones((batch_size, num_steps),
                                             dtype=d2l.int32))
check_shape(outputs, (batch_size, num_steps, num_inputs))

Gradient clipping

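Backpropagation through time multiplies the gradient by \mathbf{W}_{hh}^\top (and by the tanh derivative) once per time step. Over num_steps steps these factors compound, so the gradient can grow geometrically. Before each parameter update, rescale the gradient so that its global norm over all parameters is at most \theta: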

\mathbf{g} \leftarrow \min\!\left(1, \frac{\theta}{\|\mathbf{g}\|}\right)\mathbf{g}.
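The NumPy illustration below shows why a per-step factor is dangerous. It drops tanh and tracks the Jacobian of h_T with respect to h_0 for the linear recurrence, which is a product of T copies of W_hh. The helper name and the choice W_hh = scale · I are assumptions picked to make the arithmetic obvious: the spectral norm is exactly scale**T.

```python
import numpy as np

def jacobian_norm_after(W_hh, num_steps):
    """Spectral norm of d h_T / d h_0 for the linear recurrence
    h_t = h_{t-1} @ W_hh (tanh dropped to isolate the W_hh factor)."""
    J = np.eye(W_hh.shape[0])
    for _ in range(num_steps):
        J = J @ W_hh          # one factor of W_hh per time step
    return np.linalg.norm(J, ord=2)

n, T = 8, 50
for scale in (0.9, 1.0, 1.1):
    print(scale, jacobian_norm_after(scale * np.eye(n), T))
# scale**T for T = 50: about 0.0052, 1.0 and 117.4
```

The tanh derivative never exceeds 1, so restoring it can damp this growth but does not remove it when W_hh has large singular values. Clipping caps the resulting update size. It does nothing for the opposite, vanishing case (the 0.9 row).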

@d2l.add_to_class(d2l.Trainer)
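def clip_gradients(self, grad_clip_val, grads):
    # Global L2 norm over every gradient leaf in the params pytree.
    grad_leaves, _ = jax.tree_util.tree_flatten(grads)
    norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in grad_leaves))
    # Leave gradients untouched if the norm is already below the threshold;
    # otherwise rescale every leaf by the same factor grad_clip_val / norm.
    clip = lambda grad: jnp.where(norm < grad_clip_val,
                                  grad, grad * (grad_clip_val / norm))
    return jax.tree_util.tree_map(clip, grads)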
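Below is a NumPy sketch of the same global-norm rule over a dict of gradients, with a guard for an all-zero gradient. The dict layout and the helper name are assumptions. The design choice matters: a single shared scale factor preserves the direction of the full gradient, whereas clipping each tensor separately would change it.

```python
import numpy as np

def clip_by_global_norm(grads, theta):
    """grads: dict name -> array. Rescale all arrays by one common factor
    min(1, theta / ||g||), where ||g|| is the L2 norm of all entries together."""
    norm = np.sqrt(sum(np.sum(g * g) for g in grads.values()))
    # Avoid dividing by zero when every gradient entry is zero.
    scale = min(1.0, theta / norm) if norm > 0 else 1.0
    return {name: g * scale for name, g in grads.items()}, norm

grads = {'W_hh': np.full((2, 2), 3.0), 'b_h': np.array([4.0, 0.0])}
clipped, norm_before = clip_by_global_norm(grads, theta=1.0)
norm_after = np.sqrt(sum(np.sum(g * g) for g in clipped.values()))
print(round(float(norm_before), 4), round(float(norm_after), 6))  # -> 7.2111 1.0
```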

Training

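The training setup uses 32-character windows (num_steps=32), batch size 1024, 100 epochs, and learning rate 1. Gradient clipping with threshold \theta = 1 (gradient_clip_val=1) keeps an occasional exploding gradient from blowing the loss up to NaN: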

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNNScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=1)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
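A token stream turns into training pairs as follows. Each input window of num_steps ids is paired with the same window shifted one position ahead, so position t is trained to predict token t+1. This NumPy sketch uses non-overlapping windows and a hypothetical make_windows helper. It illustrates the idea and does not claim to reproduce how d2l.TimeMachine samples its windows.

```python
import numpy as np

def make_windows(corpus_ids, num_steps):
    """Hypothetical helper: slice a 1-D id sequence into non-overlapping
    (input, target) windows, where the target is the input shifted by one."""
    corpus_ids = np.asarray(corpus_ids)
    # Only windows whose next token also exists can be used.
    n = (len(corpus_ids) - 1) // num_steps
    X = np.stack([corpus_ids[i * num_steps:(i + 1) * num_steps] for i in range(n)])
    Y = np.stack([corpus_ids[i * num_steps + 1:(i + 1) * num_steps + 1] for i in range(n)])
    return X, Y

X, Y = make_windows(np.arange(11), num_steps=4)
# X = [[0 1 2 3], [4 5 6 7]] and Y = [[1 2 3 4], [5 6 7 8]]
```

Because each window is processed on its own, gradients never flow further back than num_steps characters. That is the truncation in truncated BPTT.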

Decoding (text generation)

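Feed in the prompt one character at a time to warm up the hidden state. Then decode greedily: at each step, take the argmax of the predicted logits and feed that character back in as the next input: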

@d2l.add_to_class(RNNLMScratch)
def predict(self, prefix, num_preds, vocab, params):
    state, outputs = None, [vocab[prefix[0]]]
    for i in range(len(prefix) + num_preds - 1):
        X = d2l.tensor([[outputs[-1]]])
        embs = self.one_hot(X)
        rnn_outputs, state = self.rnn.apply({'params': params['rnn']},
                                            embs, state)
        if i < len(prefix) - 1:  # Warm-up period
            outputs.append(vocab[prefix[i + 1]])
        else:  # Predict num_preds steps
            Y = self.apply({'params': params}, rnn_outputs,
                           method=self.output_layer)
            outputs.append(int(d2l.reshape(d2l.argmax(Y, axis=2), ())))
    return ''.join([vocab.idx_to_token[i] for i in outputs])
ppl = float(model.board.data['val_ppl'][-1].y)
pred = model.predict('time traveller', 20, data.vocab, trainer.state.params)
print(f'perplexity {ppl:.1f}, {pred!r}')
perplexity 7.4, 'time traveller the this the this t'

The output is recognizably English-shaped (real words, plausible spacing), but it soon settles into a repetitive loop. At this size the model has learned little beyond short-range character statistics.
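The control flow of predict, stripped of the model, can be sketched in NumPy. A warm-up phase forces the true prefix, then a generation phase feeds back the argmax. The one-step interface step_fn and the toy model are hypothetical stand-ins.

```python
import numpy as np

def greedy_decode(step_fn, prefix_ids, num_preds, init_state=None):
    """Warm up on the prefix, then extend it greedily.

    step_fn(token_id, state) -> (logits, new_state) is a hypothetical
    one-step model interface used only for this sketch.
    """
    state, outputs = init_state, [prefix_ids[0]]
    for i in range(len(prefix_ids) + num_preds - 1):
        logits, state = step_fn(outputs[-1], state)
        if i < len(prefix_ids) - 1:   # warm-up: feed the true prefix
            outputs.append(prefix_ids[i + 1])
        else:                         # generation: pick the argmax
            outputs.append(int(np.argmax(logits)))
    return outputs

# Toy "model" over 5 ids that always scores (token + 1) % 5 highest.
# Its state just counts steps, to show the state is threaded through.
def toy_step(token_id, state):
    logits = np.zeros(5)
    logits[(token_id + 1) % 5] = 1.0
    return logits, (0 if state is None else state) + 1

print(greedy_decode(toy_step, [2, 0], num_preds=4))  # -> [2, 0, 1, 2, 3, 4]
```

Greedy decoding is deterministic, which makes repetitive loops like the one above a common failure mode. Sampling the next character from the softmax distribution is the usual alternative.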

Recap

  • Char-level RNN LM: hand-rolled cell + one-hot input + linear head + cross-entropy.
  • Gradient clipping is a standard safeguard against exploding gradients in RNN training.
  • Training uses truncated BPTT: gradients flow back through at most num_steps unrolled steps per minibatch.
  • The same scaffold accepts other cells (LSTM, GRU); only the recurrence changes. These come next.