import collections
from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import math
import optax
Two RNNs glued together (Sutskever et al., 2014; Cho et al., 2014):
<eos>. The decoder is just a conditional language model.
Seq2seq with RNN encoder and decoder. <eos> ends the sequence; <bos> starts decoding.
The single-vector context is a bottleneck — motivates attention in the next chapter.
Decoder input: <bos>, “Ils”, “regardent”, “.”. Decoder label: “Ils”, “regardent”, “.”, <eos>.
The MTFraEng pipeline already produces this shifted pair. Same self-supervised setup as a language model — only now the encoder output is concatenated as extra context.
Alternative — scheduled sampling: occasionally feed the prediction back. More realistic at inference but harder to optimize.
Embedding layer for input tokens, then a multilayer GRU. Output: per-step hidden states (top layer) and final state (all layers):
class Seq2SeqEncoder(d2l.Encoder):
    """The RNN encoder for sequence-to-sequence learning.

    Embeds the input token ids and feeds the embeddings through a
    multilayer GRU (`d2l.GRU`).
    """
    vocab_size: int   # number of entries in the embedding table
    embed_size: int   # feature size of each token embedding
    num_hiddens: int  # forwarded to d2l.GRU
    num_layers: int   # forwarded to d2l.GRU
    dropout: float = 0

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.rnn = d2l.GRU(self.num_hiddens, self.num_layers, self.dropout)

    def __call__(self, X, *args, training=False):
        """Encode a batch of token ids.

        Args:
            X: token ids of shape (batch_size, num_steps).
            *args: accepted and ignored.
            training: forwarded to the GRU.

        Returns:
            A pair `(outputs, state)` as returned by the GRU.
        """
        # X shape: (batch_size, num_steps)
        # Cast to int32 rather than int64: with JAX's default 32-bit mode an
        # explicit int64 request is truncated to int32 and emits a warning.
        token_ids = d2l.astype(d2l.transpose(X), jnp.int32)
        embs = self.embedding(token_ids)
        # embs shape: (num_steps, batch_size, embed_size)
        outputs, state = self.rnn(embs, training=training)
        # outputs shape: (num_steps, batch_size, num_hiddens)
        # state shape: (num_layers, batch_size, num_hiddens)
        return outputs, state
# Two-layer GRU, hidden 16, batch 4, seq length 9. Confirm shapes:
# Small configuration used only to sanity-check the encoder's output shape.
vocab_size, embed_size = 10, 8
num_hiddens, num_layers = 16, 2
batch_size, num_steps = 4, 9
encoder = Seq2SeqEncoder(vocab_size=vocab_size, embed_size=embed_size,
                         num_hiddens=num_hiddens, num_layers=num_layers)
X = d2l.zeros((batch_size, num_steps))
# init_with_output returns ((outputs, state), variables); only the encoder
# outputs are checked here.
(enc_outputs, enc_state), _ = encoder.init_with_output(d2l.get_key(), X)
d2l.check_shape(enc_outputs, (num_steps, batch_size, num_hiddens))
# Captured notebook warning residue (not code):
#   return lax_numpy.astype(self, dtype, copy=copy, device=device)
Embed the previous target token, concatenate the encoder’s final hidden state at every decoder time step (context broadcast across the sequence), run a GRU, and project to vocab logits:
class Seq2SeqDecoder(d2l.Decoder):
    """The RNN decoder for sequence to sequence learning.

    Embeds the decoder input tokens, concatenates the last encoder output
    step to every decoder time step as context, runs a multilayer GRU
    (`d2l.GRU`) and projects the GRU outputs with a dense layer of width
    `vocab_size`.
    """
    vocab_size: int   # embedding table size and width of the output layer
    embed_size: int   # feature size of each token embedding
    num_hiddens: int  # forwarded to d2l.GRU
    num_layers: int   # forwarded to d2l.GRU
    dropout: float = 0

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.rnn = d2l.GRU(self.num_hiddens, self.num_layers, self.dropout)
        self.dense = nn.Dense(self.vocab_size)

    def init_state(self, enc_all_outputs, *args):
        """Return the encoder outputs unchanged as the initial decoder state."""
        return enc_all_outputs

    def __call__(self, X, state, training=False):
        """Decode one batch of target-side token ids.

        Args:
            X: token ids of shape (batch_size, num_steps).
            state: pair `(enc_output, hidden_state)`.
            training: forwarded to the GRU.

        Returns:
            `(outputs, [enc_output, hidden_state])`, where `outputs` has the
            batch axis first and `hidden_state` is the updated GRU state.
        """
        # X shape: (batch_size, num_steps)
        # Cast to int32 rather than int64: with JAX's default 32-bit mode an
        # explicit int64 request is truncated to int32 and emits a warning.
        token_ids = d2l.astype(d2l.transpose(X), jnp.int32)
        # embs shape: (num_steps, batch_size, embed_size)
        embs = self.embedding(token_ids)
        enc_output, hidden_state = state
        # context shape: (batch_size, num_hiddens)
        context = enc_output[-1]
        # Broadcast context to (num_steps, batch_size, num_hiddens)
        context = jnp.tile(context, (embs.shape[0], 1, 1))
        # Concat at the feature dimension
        embs_and_context = d2l.concat((embs, context), -1)
        outputs, hidden_state = self.rnn(embs_and_context, hidden_state,
                                         training=training)
        outputs = d2l.swapaxes(self.dense(outputs), 0, 1)
        # outputs shape: (batch_size, num_steps, vocab_size)
        # hidden_state shape: (num_layers, batch_size, num_hiddens)
        return outputs, [enc_output, hidden_state]
# End-to-end forward pass produces (batch, num_steps, vocab) logits and a
# state of shape (num_layers, batch, num_hiddens):
# Sanity-check the decoder: build it with the same small configuration, seed
# its state from a fresh encoder pass, and verify the output shapes.
decoder = Seq2SeqDecoder(vocab_size=vocab_size, embed_size=embed_size,
                         num_hiddens=num_hiddens, num_layers=num_layers)
# Element [0] of init_with_output is the encoder's (outputs, state) pair.
state = decoder.init_state(
    encoder.init_with_output(d2l.get_key(), X)[0])
(dec_outputs, state), _ = decoder.init_with_output(
    d2l.get_key(), X, state)
d2l.check_shape(dec_outputs, (batch_size, num_steps, vocab_size))
# state is [enc_output, hidden_state]; index 1 is the decoder's GRU state.
d2l.check_shape(state[1], (num_layers, batch_size, num_hiddens))
# Subclass EncoderDecoder, add the optimizer:
Layers of the RNN encoder–decoder: embedding → encoder GRU → decoder GRU (with broadcast context) → dense.
def _masked_cross_entropy(Y_hat, Y, pad_id):
    """Average softmax cross-entropy over the targets that are not `pad_id`.

    `Y_hat` is flattened to (-1, last_dim) and `Y` to (-1,) before the
    per-token loss is computed; positions where `Y == pad_id` get weight 0.
    """
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    l = optax.softmax_cross_entropy_with_integer_labels(Y_hat, Y)
    mask = d2l.astype(Y != pad_id, d2l.float32)
    return d2l.reduce_sum(l * mask) / d2l.reduce_sum(mask)


class Seq2Seq(d2l.EncoderDecoder):
    """The RNN encoder--decoder for sequence to sequence learning."""
    encoder: nn.Module
    decoder: nn.Module
    tgt_pad: int  # target id excluded from the loss
    lr: float     # learning rate handed to Adam

    @partial(jax.jit, static_argnums=(0, 5))
    def loss(self, params, X, Y, state, averaged=False):
        """Training loss: masked cross-entropy with dropout enabled.

        Args:
            params: parameters passed to `state.apply_fn` under 'params'.
            X: model inputs, unpacked positionally into `state.apply_fn`.
            Y: integer target ids.
            state: provides `apply_fn` and `dropout_rng`.
            averaged: unused here; kept for interface compatibility.

        Returns:
            `(loss, {})` -- the scalar loss and an empty dict.
        """
        Y_hat = state.apply_fn({'params': params}, *X, training=True,
                               rngs={'dropout': state.dropout_rng})
        return _masked_cross_entropy(Y_hat, Y, self.tgt_pad), {}

    def validation_step(self, params, batch, state):
        """Plot the masked loss of `batch` computed with training=False."""
        # Evaluate with dropout disabled (training=False); training=True path
        # is used by self.loss during fit.
        Y_hat = state.apply_fn({'params': params}, *batch[:-1],
                               training=False)
        l = _masked_cross_entropy(Y_hat, batch[-1], self.tgt_pad)
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        """Return an Adam optimizer using `self.lr`."""
        # Adam optimizer is used here
        return optax.adam(learning_rate=self.lr)
# <pad> predictions shouldn’t contribute to the loss. Build a mask
# Y != tgt_pad and average only over real tokens:
\mathcal{L} = \frac{\sum_{b,t} \mathbf{1}\{y_{b,t} \ne \texttt{<pad>}\} \, \ell(\hat{\mathbf{y}}_{b,t}, y_{b,t})} {\sum_{b,t} \mathbf{1}\{y_{b,t} \ne \texttt{<pad>}\}}.
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=False):
    """Masked softmax cross-entropy averaged over non-padding targets.

    NOTE(review): this repeats the `loss` method of `Seq2Seq` and carries no
    `add_to_class` decorator here -- confirm whether it is meant to be
    attached to the class or is shown for illustration only.

    Args:
        self: object providing `tgt_pad` (static under jit).
        params: parameters passed to `state.apply_fn` under 'params'.
        X: model inputs, unpacked positionally into `state.apply_fn`.
        Y: integer target ids.
        state: provides `apply_fn` and `dropout_rng`.
        averaged: unused here; kept for interface compatibility.

    Returns:
        `(loss, {})` -- the scalar loss and an empty dict.
    """
    Y_hat = state.apply_fn({'params': params}, *X, training=True,
                           rngs={'dropout': state.dropout_rng})
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    l = optax.softmax_cross_entropy_with_integer_labels(Y_hat, Y)
    # Y is already flat at this point, so it is compared directly.
    mask = d2l.astype(Y != self.tgt_pad, d2l.float32)
    return d2l.reduce_sum(l * mask) / d2l.reduce_sum(mask), {}
# 2-layer GRU, embed/hidden 256, dropout 0.2, Adam lr=0.005, gradient clip 1,
# 30 epochs:
# Build the dataset, the encoder/decoder pair and the model, then train.
data = d2l.MTFraEng(batch_size=128)
embed_size, num_hiddens = 256, 256
num_layers, dropout = 2, 0.2
encoder = Seq2SeqEncoder(vocab_size=len(data.src_vocab),
                         embed_size=embed_size, num_hiddens=num_hiddens,
                         num_layers=num_layers, dropout=dropout)
decoder = Seq2SeqDecoder(vocab_size=len(data.tgt_vocab),
                         embed_size=embed_size, num_hiddens=num_hiddens,
                         num_layers=num_layers, dropout=dropout)
# The '<pad>' target id is what the loss masks out.
model = Seq2Seq(encoder=encoder, decoder=decoder,
                tgt_pad=data.tgt_vocab['<pad>'], lr=0.005)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
# Run the encoder once, then loop: feed the previous prediction back, take
# argmax over the vocab. Stop after num_steps (or when <eos> appears —
# handled by the caller).
Predicting token by token: feed the previous prediction back, stop on <eos>.
@d2l.add_to_class(d2l.EncoderDecoder)
def predict_step(self, params, batch, num_steps,
                 save_attention_weights=False):
    """Greedily decode `num_steps` tokens for every sequence in `batch`.

    The encoder is applied once; the decoder is then applied `num_steps`
    times, each time on the argmax of its previous output (starting from the
    first target column).

    Args:
        params: dict with 'encoder' and 'decoder' parameter sub-trees.
        batch: 4-tuple `(src, tgt, src_valid_len, _)`.
        num_steps: number of decoding iterations.
        save_attention_weights: when true, collect the attention weights
            stored in the 'intermediates' collections.

    Returns:
        `(predictions, (attention_weights, enc_attention_weights))` where
        `predictions` concatenates the per-step argmax along axis 1.
    """
    src, tgt, src_valid_len, _ = batch
    enc_all_outputs, inter_enc_vars = self.encoder.apply(
        {'params': params['encoder']}, src, src_valid_len, training=False,
        mutable='intermediates')

    # Encoder attention weights live in the intermediates collection; they
    # are only read when requested and when the collection is non-empty.
    # (to be covered later)
    enc_attention_weights = []
    if save_attention_weights and bool(inter_enc_vars):
        enc_intermediates = inter_enc_vars['intermediates']
        enc_attention_weights = enc_intermediates['enc_attention_weights'][0]

    dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)
    decoder_vars = {'params': params['decoder']}
    # The first decoder input is column 0 of tgt, kept as a
    # (batch_size, 1) array.
    prev_tokens = d2l.expand_dims(tgt[:, 0], 1)
    predictions, attention_weights = [], []
    for _ in range(num_steps):
        (Y, dec_state), inter_dec_vars = self.decoder.apply(
            decoder_vars, prev_tokens, dec_state,
            training=False, mutable='intermediates')
        prev_tokens = d2l.argmax(Y, 2)
        predictions.append(prev_tokens)
        # Decoder attention weights (to be covered later)
        if save_attention_weights:
            dec_intermediates = inter_dec_vars['intermediates']
            attention_weights.append(
                dec_intermediates['dec_attention_weights'][0])
    return d2l.concat(predictions, 1), (attention_weights,
                                        enc_attention_weights)
# Compare prediction n-grams against reference. Geometric mean of n-gram
# precisions, with a brevity penalty so the model can’t game it by emitting
# “the the”.
\text{BLEU} = \exp\!\left(\min\!\left(0, 1 - \frac{\text{len}_{\text{label}}}{\text{len}_{\text{pred}}}\right)\right) \prod_{n=1}^k p_n^{1/2^n}.
def bleu(pred_seq, label_seq, k):
    """Compute the BLEU score of `pred_seq` against `label_seq`.

    The score is a brevity penalty, exp(min(0, 1 - len_label / len_pred)),
    multiplied by the clipped n-gram precisions p_n raised to 0.5 ** n for
    n = 1 .. min(k, len_pred).

    Args:
        pred_seq: predicted sentence, tokens separated by whitespace.
        label_seq: reference sentence, tokens separated by whitespace.
        k: longest n-gram order to match.

    Returns:
        The score as a float. An empty prediction scores 1.0 against an
        empty label and 0.0 otherwise.
    """
    # split() without an argument ignores leading/trailing/repeated
    # whitespace, so stray spaces never become empty-string tokens.
    pred_tokens, label_tokens = pred_seq.split(), label_seq.split()
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    if len_pred == 0:
        # Nothing was predicted: avoid dividing by zero below.
        return 1.0 if len_label == 0 else 0.0
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, min(k, len_pred) + 1):
        pred_ngrams = collections.Counter(
            tuple(pred_tokens[i:i + n]) for i in range(len_pred - n + 1))
        label_ngrams = collections.Counter(
            tuple(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        # Counter intersection keeps min(count_pred, count_label) per n-gram,
        # i.e. each reference n-gram can be matched at most once.
        num_matches = sum((pred_ngrams & label_ngrams).values())
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
# Run the model on a handful of English sentences and score each. Short,
# frequent patterns tend to translate cleanly; low BLEU on a sentence usually
# means one missing or misplaced token is enough to break an n-gram match:
# Translate a few English sentences and report each prediction's BLEU.
engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    trainer.state.params, data.build(engs, fras), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    # Keep the predicted tokens that come before the first '<eos>'.
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    score = bleu(" ".join(translation), fr, k=2)
    print(f'{en} => {translation}, bleu,{score:.3f}')
# Captured output:
# i lost . => ['je', 'le', 'refuse', '.'], bleu,0.000
# i'm calm . => ['je', 'suis', '<unk>', '.'], bleu,0.658
# i'm home . => ['je', 'suis', 'chez', 'la', 'la', 'la', 'la', 'la', 'la'], bleu,0.408
<pad> from the loss.