%matplotlib inline
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
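import math
A character-level language model on The Time Machine, built from nothing but tensor ops. Four pieces: the recurrent layer, the language-model output head, gradient clipping, and greedy decoding.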
Working at the character level keeps the vocabulary tiny (~30 tokens), so inputs can be plain one-hot vectors with no learned embedding, and the whole model fits comfortably in a notebook.
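Parameters: \mathbf{W}_{xh}, \mathbf{W}_{hh}, \mathbf{b}_h. The weight matrices are drawn from a normal distribution with a small standard deviation sigma so activations stay in a sensible range; the bias starts at zero: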
E0524 02:49:35.971119 12206 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 9.41GiB (10100251136 bytes) of ...
E0524 02:49:35.971521 12206 cuda_executor.cc:1206] [0] Failed to allocate device memory: INTERNAL: [0] Failed to allocate 8.47GiB (9090225152 bytes) of ...
class RNNScratch(nn.Module):
    """A single recurrent layer whose parameters are created by hand.

    Attributes:
        num_inputs: length of each per-step input vector.
        num_hiddens: length of the hidden-state vector.
        sigma: standard deviation of the normal initializer shared by both
            weight matrices; the bias is initialized to zeros.
    """
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        weight_init = nn.initializers.normal(self.sigma)
        # Input-to-hidden weights: (num_inputs, num_hiddens).
        self.W_xh = self.param('W_xh', weight_init,
                               (self.num_inputs, self.num_hiddens))
        # Hidden-to-hidden (recurrent) weights: (num_hiddens, num_hiddens).
        self.W_hh = self.param('W_hh', weight_init,
                               (self.num_hiddens, self.num_hiddens))
        # Hidden bias, one entry per hidden unit.
        self.b_h = self.param('b_h', nn.initializers.zeros,
                              (self.num_hiddens,))
# Walk a length-T input one step at a time, carrying the hidden state forward:
@d2l.add_to_class(RNNScratch)
def __call__(self, inputs, state=None):
    """Run the recurrence over every time step of ``inputs``.

    Args:
        inputs: iterated over its first axis, one element per time step
            (shape noted below as (num_steps, batch_size, num_inputs)).
        state: optional initial hidden state. Either a 1-tuple/list ``(H,)``
            (the original calling convention) or the bare array ``H`` that
            this method itself returns, so the returned state can be fed
            straight back in for any batch size. ``None`` means the first
            step has no recurrent term (equivalent to a zero initial state).

    Returns:
        ``(outputs, state)``: the list of per-step hidden states and the final
        hidden state as a bare array (the unpacked input state when
        ``inputs`` is empty).
    """
    if isinstance(state, (tuple, list)):
        # Unwrap the 1-tuple convention. A bare array is used as-is: unpacking
        # it would iterate over the batch axis and fail for batch_size > 1.
        state, = state
    outputs = []
    for X in inputs:  # Shape of inputs: (num_steps, batch_size, num_inputs)
        recurrent = 0 if state is None else d2l.matmul(state, self.W_hh)
        state = d2l.tanh(d2l.matmul(X, self.W_xh) + recurrent + self.b_h)
        outputs.append(state)
    return outputs, state
# Notebook stderr: E0524 02:49:38.435383 12206 cuda_blas.cc:196] failed to create cublas handle: the resource allocation failed
E0524 02:49:38.435418 12206 cuda_blas.cc:199] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your ...
E0524 02:49:38.860147 12206 cuda_blas.cc:196] failed to create cublas handle: the resource allocation failed
E0524 02:49:38.860183 12206 cuda_blas.cc:199] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your ...
Sanity check on output shapes:
def check_len(a, n):
    """Check the length of a list.

    Raises:
        AssertionError: if ``len(a) != n``. Raised explicitly instead of via
            an ``assert`` statement so the check still runs under ``python -O``.
    """
    if len(a) != n:
        raise AssertionError(f'list\'s length {len(a)} != expected length {n}')
def check_shape(a, shape):
    """Check the shape of a tensor.

    Raises:
        AssertionError: if ``a.shape != shape``. Raised explicitly instead of
            via an ``assert`` statement so the check still runs under
            ``python -O``.
    """
    if a.shape != shape:
        raise AssertionError(
            f'tensor\'s shape {a.shape} != expected shape {shape}')
# Shape sanity checks on an RNN forward pass. `outputs`, `state`, `num_steps`,
# `batch_size` and `num_hiddens` come from an earlier cell not shown here
# (presumably a call to RNNScratch; verify against that cell).
# Expect one hidden state per time step, and both the first per-step state
# and the final carried state to be (batch_size, num_hiddens).
check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
# Add a vocab-sized output projection on top of the RNN’s hidden states. This is the LM wrapper we’ll train:
class RNNLMScratch(d2l.Classifier):
    """The RNN-based language model implemented from scratch.

    Wraps ``rnn`` with a linear output layer (W_hq, b_q) that maps each hidden
    state to ``vocab_size`` logits.

    Attributes:
        rnn: recurrent module; must expose ``sigma`` and ``num_hiddens``.
        vocab_size: number of output tokens.
        lr: learning rate (presumably consumed by d2l.Classifier's optimizer
            setup; confirm).
    """
    rnn: nn.Module
    vocab_size: int
    lr: float = 0.01

    def setup(self):
        # Output projection uses the same init scale as the RNN's weights.
        self.W_hq = self.param('W_hq', nn.initializers.normal(self.rnn.sigma),
                               (self.rnn.num_hiddens, self.vocab_size))
        # Bias shape is the 1-tuple (vocab_size,), like b_h in RNNScratch.
        self.b_q = self.param('b_q', nn.initializers.zeros, (self.vocab_size,))

    def training_step(self, params, batch, state):
        """Return ``((loss, aux), grads)`` for one batch and log train ppl.

        ``batch[:-1]`` are the model inputs and ``batch[-1]`` the labels.
        """
        value, grads = jax.value_and_grad(
            self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
        loss_value, _ = value
        # Perplexity = exp(loss).
        self.plot('ppl', d2l.exp(loss_value), train=True)
        return value, grads

    def validation_step(self, params, batch, state):
        """Log validation perplexity for one batch; returns nothing."""
        loss_value, _ = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('ppl', d2l.exp(loss_value), train=False)
# At every time step, the model predicts the next character. Flatten batch and time, apply cross-entropy, and average:
\mathcal{L} = \frac{1}{BT}\sum_{b,t} -\log P(x_{b,t+1} \mid x_{b,\le t}).
Perplexity is \exp(\mathcal{L}); lower means fewer effective choices for the next character.
Tokens come in as integer ids; the RNN expects vectors. One-hot encoding is the simplest input embedding:
Array([[1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0.]], dtype=float32)
Gather hidden states across all time steps and project through the head:
@d2l.add_to_class(RNNLMScratch)
def output_layer(self, rnn_outputs):
    """Project each per-step hidden state to vocabulary logits.

    The per-step results are stacked along axis 1, so time becomes the
    second axis of the returned tensor.
    """
    def project(hidden):
        # Affine map from hidden units to vocabulary logits.
        return d2l.matmul(hidden, self.W_hq) + self.b_q

    logits_per_step = [project(hidden) for hidden in rnn_outputs]
    return d2l.stack(logits_per_step, 1)
@d2l.add_to_class(RNNLMScratch)
def forward(self, X, state=None):
    """Map token ids ``X`` to per-step vocabulary logits.

    ``X`` is one-hot encoded and run through ``self.rnn``, whose final state
    is discarded. The per-step hidden states then go through
    ``output_layer``.
    """
    hidden_states, _final_state = self.rnn(self.one_hot(X), state)
    return self.output_layer(hidden_states)
# Smoke test — input shape (batch, num_steps), output shape (batch, num_steps, vocab):
E0524 02:49:39.806481 12206 cuda_blas.cc:196] failed to create cublas handle: the resource allocation failed
E0524 02:49:39.806507 12206 cuda_blas.cc:199] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your ...
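Backpropagating through the recurrence multiplies the gradient by \mathbf{W}_{hh}^\top (together with the tanh derivative) once per time step, so over many steps the product of these factors can explode. Before each parameter update, rescale the gradient so its global norm never exceeds a threshold \theta: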
\mathbf{g} \leftarrow \min\!\left(1, \frac{\theta}{\|\mathbf{g}\|}\right)\mathbf{g}.
@d2l.add_to_class(d2l.Trainer)
def clip_gradients(self, grad_clip_val, grads):
    """Rescale ``grads`` so their global L2 norm is at most ``grad_clip_val``.

    Computes g <- min(1, theta / ||g||) * g, where ||g|| is the L2 norm taken
    over all leaves of the ``grads`` pytree together. Returns a pytree with
    the same structure as ``grads``.
    """
    leaves = jax.tree_util.tree_leaves(grads)
    total_norm = jnp.sqrt(sum(jnp.vdot(leaf, leaf) for leaf in leaves))
    # Factor is exactly 1 when already within budget, theta/||g|| otherwise.
    scale = jnp.minimum(1, grad_clip_val / total_norm)
    return jax.tree_util.tree_map(lambda g: g * scale, grads)
# ~32 character window, batch 1024, ~30 epochs. Gradient clipping keeps the loss from going NaN:
Feed in a prompt, then at each step greedily pick the model’s most likely next character (argmax) and feed it back in:
@d2l.add_to_class(RNNLMScratch)
def predict(self, prefix, num_preds, vocab, params):
    """Greedily continue ``prefix`` by ``num_preds`` tokens.

    Warm-up: the prefix is fed through the RNN one token at a time. After its
    last token, each step appends the argmax token of the output layer and
    feeds it back in as the next input.

    Args:
        prefix: non-empty sequence of tokens that ``vocab`` can index.
        num_preds: number of tokens to generate after the prefix.
        vocab: maps tokens to ids via ``vocab[token]`` and ids back to tokens
            via ``vocab.idx_to_token``.
        params: full parameter tree; ``params['rnn']`` holds the RNN's.

    Returns:
        The prefix followed by the predicted tokens, joined into one string.
    """
    state, outputs = None, [vocab[prefix[0]]]
    warmup_steps = len(prefix) - 1
    for i in range(warmup_steps + num_preds):
        X = d2l.tensor([[outputs[-1]]])  # one sequence, one time step
        embs = self.one_hot(X)
        # The RNN's __call__ (RNNScratch here, presumably) takes its incoming
        # state as a 1-tuple (`state, = state`), but returns a bare array.
        # Re-wrap it rather than relying on a batch-of-1 array happening to
        # unpack.
        rnn_outputs, state = self.rnn.apply(
            {'params': params['rnn']}, embs,
            None if state is None else (state,))
        if i < warmup_steps:  # Warm-up period: force the next prefix token
            outputs.append(vocab[prefix[i + 1]])
        else:  # Predict num_preds steps (greedy argmax)
            Y = self.apply({'params': params}, rnn_outputs,
                           method=self.output_layer)
            outputs.append(int(d2l.reshape(d2l.argmax(Y, axis=2), ())))
    return ''.join([vocab.idx_to_token[idx] for idx in outputs])
# Sample output: perplexity 7.4, 'time traveller the this the this t'
Output is recognizably English-shaped — but at this size the model hasn’t learned much beyond character-level statistics.
Each training batch covers only num_steps of unrolled history.