%matplotlib inline
from d2l import torch as d2l
import math
import torch
from torch import nn
from torch.nn import functional as F
A character-level language model on The Time Machine, with nothing but tensor ops. Four pieces:
Character-level tokenization keeps the vocabulary tiny (~30 tokens), so embedding-free models still fit in a notebook.
Parameters: $\mathbf{W}_{xh}$, $\mathbf{W}_{hh}$, $\mathbf{b}_h$. Initialize the weights randomly, scaled by `sigma` to keep activations sensible; the bias starts at zero:
class RNNScratch(d2l.Module):
    """A single-layer recurrent network built directly from tensors.

    Holds the input-to-hidden weights ``W_xh``, the hidden-to-hidden weights
    ``W_hh`` and the hidden bias ``b_h`` as trainable parameters.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # Must run before any extra local is created in this frame.
        self.save_hyperparameters()
        # Weights are Gaussian draws scaled by `sigma`; the bias starts at 0.
        input_weights = d2l.randn(num_inputs, num_hiddens) * sigma
        recurrent_weights = d2l.randn(num_hiddens, num_hiddens) * sigma
        self.W_xh = nn.Parameter(input_weights)
        self.W_hh = nn.Parameter(recurrent_weights)
        self.b_h = nn.Parameter(d2l.zeros(num_hiddens))
# Walk a length-T input one step at a time, carrying the hidden state forward:
@d2l.add_to_class(RNNScratch)
def forward(self, inputs, state=None):
    """Run the recurrence over every time step of `inputs`.

    Args:
        inputs: Tensor of shape (num_steps, batch_size, num_inputs). It is
            iterated along its first axis, one time step per iteration.
        state: Initial hidden state. Accepted forms:
            * ``None`` - start from an all-zero state;
            * a 1-element tuple/list wrapping the hidden-state tensor;
            * the hidden-state tensor itself, i.e. exactly the second value
              returned by a previous call, so the result can be fed back.

    Returns:
        ``(outputs, state)`` where `outputs` is a list holding the hidden
        state after each time step and `state` is the last of them (a bare
        tensor, not wrapped in a tuple).
    """
    if state is None:
        # Initial state with shape: (batch_size, num_hiddens)
        state = d2l.zeros((inputs.shape[1], self.num_hiddens),
                          device=inputs.device)
    elif isinstance(state, (tuple, list)):
        # Wrapped form `(H,)`: unwrap the single hidden-state tensor.
        state, = state
    # Otherwise `state` is already a bare tensor (what this method returns).
    # Unpacking it would iterate over the batch axis, which only happens to
    # work for batch size 1 and raises ValueError for anything larger.
    outputs = []
    for X in inputs:  # Shape of inputs: (num_steps, batch_size, num_inputs)
        state = d2l.tanh(d2l.matmul(X, self.W_xh) +
                         d2l.matmul(state, self.W_hh) + self.b_h)
        outputs.append(state)
    return outputs, state
# Sanity check on output shapes:
def check_len(a, n):
    """Assert that the sequence `a` contains exactly `n` items.

    Raises:
        AssertionError: if ``len(a)`` differs from `n`; the message reports
            both the actual and the expected length.
    """
    has_expected_length = len(a) == n
    assert has_expected_length, \
        f'list\'s length {len(a)} != expected length {n}'
def check_shape(a, shape):
    """Assert that `a.shape` equals the expected `shape`.

    Raises:
        AssertionError: if the shapes differ; the message reports both the
            actual and the expected shape.
    """
    has_expected_shape = a.shape == shape
    assert has_expected_shape, \
        f'tensor\'s shape {a.shape} != expected shape {shape}'
# Sanity check: there must be one hidden state per time step, and both the
# first per-step output and the final state must have shape
# (batch_size, num_hiddens).
check_len(outputs, num_steps)
for hidden in (outputs[0], state):
    check_shape(hidden, (batch_size, num_hiddens))
# Add a vocab-sized output projection on top of the RNN’s hidden states.
# This is the LM wrapper we’ll train:
class RNNLMScratch(d2l.Classifier):
    """Language model: a recurrent core followed by a linear vocabulary head."""

    def __init__(self, rnn, vocab_size, lr=0.01):
        super().__init__()
        # init_params reads self.rnn / self.vocab_size afterwards, so this
        # call is presumably what exposes the constructor arguments as
        # attributes; keep it ahead of init_params.
        self.save_hyperparameters()
        self.init_params()

    def init_params(self):
        """Create the output projection `W_hq` and its bias `b_q`."""
        # The head reuses the hidden size and init scale of the wrapped RNN.
        projection = d2l.randn(self.rnn.num_hiddens, self.vocab_size)
        self.W_hq = nn.Parameter(projection * self.rnn.sigma)
        self.b_q = nn.Parameter(d2l.zeros(self.vocab_size))

    def training_step(self, batch):
        """Return the loss on `batch` and plot its exponential as 'ppl'."""
        features, targets = batch[:-1], batch[-1]
        loss = self.loss(self(*features), targets)
        self.plot('ppl', d2l.exp(loss), train=True)
        return loss

    def validation_step(self, batch):
        """Plot exp(loss) on `batch` as validation 'ppl'; returns nothing."""
        features, targets = batch[:-1], batch[-1]
        loss = self.loss(self(*features), targets)
        self.plot('ppl', d2l.exp(loss), train=False)
# At every time step, the model predicts the next character. Flatten batch
# and time, apply cross-entropy, and average:
$$\mathcal{L} = \frac{1}{BT}\sum_{b,t} -\log P(x_{b,t+1} \mid x_{b,\le t}).$$
Here $B$ is the batch size and $T$ the number of time steps. Perplexity is $\exp(\mathcal{L})$; lower means fewer effective choices for the next character.
Tokens come in as integer ids; the RNN expects vectors. One-hot encoding is the simplest input embedding:
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0]])
Gather hidden states across all time steps and project through the head:
@d2l.add_to_class(RNNLMScratch)
def output_layer(self, rnn_outputs):
    """Project every hidden state through the output head.

    Each element of `rnn_outputs` is multiplied by `W_hq` and shifted by
    `b_q`; the per-step results are then stacked along axis 1.
    """
    projected = []
    for hidden in rnn_outputs:
        projected.append(d2l.matmul(hidden, self.W_hq) + self.b_q)
    return d2l.stack(projected, 1)
@d2l.add_to_class(RNNLMScratch)
def forward(self, X, state=None):
    """One-hot encode `X`, run the RNN, and apply the output layer.

    The final RNN state is discarded; only the output layer's result for
    all time steps is returned.
    """
    one_hot_inputs = self.one_hot(X)
    hidden_states, _final_state = self.rnn(one_hot_inputs, state)
    return self.output_layer(hidden_states)
# The recurrence multiplies the gradient by \mathbf{W}_{hh} once per time
# step — a single explosion-prone factor. Clip before each step so its norm
# stays bounded:
$$\mathbf{g} \leftarrow \min\!\left(1, \frac{\theta}{\|\mathbf{g}\|}\right)\mathbf{g},$$
where $\mathbf{g}$ is the gradient and $\theta$ is the clipping threshold.
~32 character window, batch 1024, ~30 epochs. Gradient clipping keeps the loss from going NaN:
Feed in a prompt, then sample the model’s predictions for the next character at each step:
@d2l.add_to_class(RNNLMScratch)
def predict(self, prefix, num_preds, vocab, device=None):
    """Continue `prefix` with `num_preds` greedily chosen tokens.

    The prefix is first fed through the RNN one token at a time to build up
    the hidden state (warm-up). After that, each step feeds the previously
    emitted token back in and appends the arg-max of the output layer.

    Returns:
        The prefix followed by the predictions, decoded through
        `vocab.idx_to_token` and joined into one string.
    """
    state = None
    token_ids = [vocab[prefix[0]]]
    warmup_steps = len(prefix) - 1
    for step in range(warmup_steps + num_preds):
        # Always feed the most recent token as a 1x1 batch.
        last_token = d2l.tensor([[token_ids[-1]]], device=device)
        rnn_outputs, state = self.rnn(self.one_hot(last_token), state)
        if step >= warmup_steps:
            # Generation: take the highest-scoring token.
            scores = self.output_layer(rnn_outputs)
            best = d2l.argmax(scores, axis=2)
            token_ids.append(int(d2l.reshape(best, ())))
        else:
            # Warm-up: the next token is dictated by the prefix.
            token_ids.append(vocab[prefix[step + 1]])
    return ''.join([vocab.idx_to_token[idx] for idx in token_ids])
# perplexity 7.2, 'time traveller a move the travel o'
Output is recognizably English-shaped — but at this size the model hasn’t learned much beyond character-level statistics.
num_steps of unrolled history per batch.