from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
Plain seq2seq squeezes the entire source sentence into one fixed-length vector — a bottleneck. Early tokens tend to be forgotten by the time the encoder finishes.
Bahdanau, Cho & Bengio (2015): instead of one context vector, let the decoder query the encoder at every step.
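$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t,$$
where $\mathbf{s}_{t'-1}$ is the decoder hidden state from the previous step, $\mathbf{h}_t$ is the encoder hidden state at source step $t$, and the weights $\alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)$ are a softmax over $t$ of a learned score.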
The original “soft alignment” mechanism — the query/key/value attention idea that Transformers later built on.
Plain seq2seq: a single state vector is the only bridge between encoder and decoder.
Decoder queries the encoder’s per-step outputs at every decoding step.
The base `AttentionDecoder` interface only adds an `attention_weights` property, so the weights can be pulled out for visualization:
class AttentionDecoder(d2l.Decoder):
    """The base attention-based decoder interface.

    Flax modules are dataclasses, so the base class deliberately omits
    `__init__`; subclasses declare their fields as class-level
    annotations and (optionally) a `setup()` method.

    Concrete decoders override `attention_weights` to expose the
    attention weights they compute (e.g. for visualization).
    """

    @property
    def attention_weights(self):
        """Attention weights exposed by a concrete decoder.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                must override this property.
        """
        raise NotImplementedError


# Per step: take the previous decoder hidden state, run additive attention
# against the encoder outputs (masked by source valid_len), concat the
# resulting context with the embedded input, run one GRU step, project to
# vocab.
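The score is learned (additive attention):
$$a(\mathbf{s}_{t'-1}, \mathbf{h}_t) = \mathbf{w}_v^\top \tanh(\mathbf{W}_s \mathbf{s}_{t'-1} + \mathbf{W}_h \mathbf{h}_t),$$
and the weight $\alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)$ is the softmax of these scores over the source positions $t$.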
class Seq2SeqAttentionDecoder(nn.Module):
    """GRU decoder that attends over the encoder outputs at every step.

    At each decoding step the top-layer GRU hidden state is used as the
    query for additive attention over the encoder outputs. The resulting
    context vector is concatenated with the embedded input token and fed
    through one GRU step. All step outputs are finally projected to
    vocabulary logits.

    Attributes:
        vocab_size: Target vocabulary size (embedding rows / output logits).
        embed_size: Dimension of the token embeddings.
        num_hiddens: Hidden size of the GRU and of the attention layer.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout rate forwarded to the attention and GRU layers.
    """
    vocab_size: int
    embed_size: int
    num_hiddens: int
    num_layers: int
    dropout: float = 0

    def setup(self):
        self.attention = d2l.AdditiveAttention(self.num_hiddens, self.dropout)
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.dense = nn.Dense(self.vocab_size)
        self.rnn = d2l.GRU(self.num_hiddens, self.num_layers,
                           dropout=self.dropout)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state from the encoder's result.

        Args:
            enc_outputs: Pair ``(outputs, hidden_state)`` from the encoder.
                ``outputs`` is time-major (num_steps, batch_size,
                num_hiddens); ``hidden_state`` is (num_layers, batch_size,
                num_hiddens).
            enc_valid_lens: Valid source lengths, forwarded unchanged to
                the attention layer. NOTE(review): ``None`` is passed in
                the shape check; presumably it means "no masking" in
                ``d2l.AdditiveAttention`` — confirm.
            *args: Ignored.

        Returns:
            Tuple ``(enc_outputs, hidden_state, enc_valid_lens)``. The
            encoder outputs are transposed to batch-major
            (batch_size, num_steps, num_hiddens). Attention weights are
            NOT part of the state.
        """
        outputs, hidden_state = enc_outputs
        return (outputs.transpose(1, 0, 2), hidden_state, enc_valid_lens)

    @nn.compact
    def __call__(self, X, state, training=False):
        """Run the decoder over every step of ``X``.

        Args:
            X: Input token ids, (batch_size, num_steps).
            state: ``(enc_outputs, hidden_state, enc_valid_lens)`` as
                returned by ``init_state`` or by a previous call.
            training: Forwarded to the attention and GRU layers
                (presumably toggles dropout).

        Returns:
            ``(logits, new_state)``. ``logits`` has shape
            (batch_size, num_steps, vocab_size). ``new_state`` is the list
            ``[enc_outputs, hidden_state, enc_valid_lens]`` with the
            updated GRU hidden state.

        The per-step attention weights are not returned in the state. They
        are recorded with ``self.sow`` in the ``'intermediates'``
        collection under ``'dec_attention_weights'``.
        """
        # enc_outputs: (batch_size, num_steps, num_hiddens).
        # hidden_state: (num_layers, batch_size, num_hiddens).
        enc_outputs, hidden_state, enc_valid_lens = state
        # Embed, then go time-major so the loop walks over decoding steps:
        # (num_steps, batch_size, embed_size).
        X = self.embedding(X).transpose(1, 0, 2)
        outputs, attention_weights = [], []
        for x in X:
            # Query with the top GRU layer's state: (batch_size, 1, num_hiddens).
            query = jnp.expand_dims(hidden_state[-1], axis=1)
            # context: (batch_size, 1, num_hiddens).
            context, attention_w = self.attention(query, enc_outputs,
                                                  enc_outputs, enc_valid_lens,
                                                  training=training)
            # Concatenate context and current embedding on the feature axis:
            # (batch_size, 1, num_hiddens + embed_size).
            x = jnp.concatenate((context, jnp.expand_dims(x, axis=1)), axis=-1)
            # Transpose to time-major (1, batch_size, num_hiddens + embed_size)
            # and run a single GRU step.
            out, hidden_state = self.rnn(x.transpose(1, 0, 2), hidden_state,
                                         training=training)
            outputs.append(out)
            attention_weights.append(attention_w)
        # Flax sow API is used to capture intermediate variables
        self.sow('intermediates', 'dec_attention_weights', attention_weights)
        # Project every step to vocabulary logits:
        # (num_steps, batch_size, vocab_size).
        outputs = self.dense(jnp.concatenate(outputs, axis=0))
        return outputs.transpose(1, 0, 2), [enc_outputs, hidden_state,
                                            enc_valid_lens]


# Same harness as plain seq2seq — same logit shape, plus a new
# attention-weight tensor of shape (num_steps, batch, src_steps):
# Shape check on a toy configuration: an all-zeros token batch through an
# untrained encoder/decoder pair.
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = jnp.zeros((batch_size, num_steps), dtype=jnp.int32)
# init_with_output returns (output, variables); only the encoder output is
# needed. No valid lengths are supplied (enc_valid_lens=None).
state = decoder.init_state(encoder.init_with_output(d2l.get_key(),
                                                    X, training=False)[0],
                           None)
(output, state), _ = decoder.init_with_output(d2l.get_key(), X,
                                              state, training=False)
# Logits: (batch_size, num_steps, vocab_size).
d2l.check_shape(output, (batch_size, num_steps, vocab_size))
# Encoder outputs are kept batch-major in the decoder state.
d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))
# Hidden state of the first GRU layer.
d2l.check_shape(state[1][0], (batch_size, num_hiddens))
Same hyperparameters as plain seq2seq (embed/hidden 256, 2 layers, dropout 0.2, Adam 0.005, 30 epochs). Gives the model attention; everything else stays the same:
data = d2l.MTFraEng(batch_size=128)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = d2l.Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
# tgt_pad identifies the target padding token (presumably masked out of the
# loss inside d2l.Seq2Seq — confirm).
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                    lr=0.005)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
# Compare BLEU vs. plain seq2seq — attention typically helps more on
# longer/harder sentences:
engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    trainer.state.params, data.build(engs, fras), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    # Keep predicted tokens up to, but not including, the first '<eos>'.
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu,'
          f'{d2l.bleu(" ".join(translation), fr, k=2):.3f}')
# Example output recorded from a training run:
# i lost . => ["j'ai", 'perdu', '.'], bleu,1.000
# i'm calm . => ['je', 'suis', 'malade', '.'], bleu,0.658
# i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000
Pull attention weights from the predict step and plot them — rows are decoder steps, columns are source tokens. The diagonal-ish band is the model learning soft alignment:
valid_len.