from d2l import tensorflow as d2l
import tensorflow as tf
Plain seq2seq jams the entire source sentence into one fixed-length vector — a bottleneck. Early tokens tend to be forgotten by the time the encoder finishes.
Bahdanau, Cho & Bengio (2015): instead of one context vector, let the decoder query the encoder at every step.
$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t.$$
The original “soft alignment” mechanism — template for every Transformer.
Plain seq2seq: a single state vector is the only bridge between encoder and decoder.
Decoder queries the encoder’s per-step outputs at every decoding step.
`AttentionDecoder`, the base class of the decoder below, just adds an `attention_weights` property so we can pull the weights out for visualization:
Per step: take the previous decoder hidden state, run additive attention against the encoder outputs (masked by source valid_len), concat the resulting context with the embedded input, run one GRU step, project to vocab.
The score is learned:
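$$a(\mathbf{s}_{t'-1}, \mathbf{h}_t) = \mathbf{w}_v^\top \tanh(\mathbf{W}_s \mathbf{s}_{t'-1} + \mathbf{W}_h \mathbf{h}_t).$$
The attention weights $\alpha$ are obtained by applying a softmax to these scores over the source positions $t$.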
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU decoder that attends over the encoder outputs at every step.

    For each decoding step, the hidden state of the last GRU layer is used
    as the query for additive attention over the encoder outputs. The
    resulting context is concatenated with the embedded input token, passed
    through the stacked GRU, and finally projected to vocabulary logits.
    The attention weights of every step are kept in `attention_weights`.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_size)
        # One GRU cell per layer, stacked and wrapped so that a call returns
        # both the output sequence and the final state of every layer.
        gru_cells = [tf.keras.layers.GRUCell(num_hiddens, dropout=dropout)
                     for _ in range(num_layers)]
        self.rnn = tf.keras.layers.RNN(
            tf.keras.layers.StackedRNNCells(gru_cells),
            return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Assemble the initial decoder state from the encoder results.

        Returns a tuple `(outputs, hidden_state, enc_valid_lens)`.
        """
        outputs, hidden_state = enc_outputs
        # Swap the first two axes of the encoder outputs so that `call`
        # receives them as (batch_size, num_steps, num_hiddens).
        # hidden_state is a list of length num_layers whose elements have
        # shape (batch_size, num_hiddens).
        attended_outputs = tf.transpose(outputs, (1, 0, 2))
        return (attended_outputs, hidden_state, enc_valid_lens)

    def call(self, X, state, training=False, **kwargs):
        """Decode the token ids `X` step by step, attending to the encoder.

        Returns the logits of shape (batch_size, num_steps, vocab_size) and
        the updated state `[enc_outputs, hidden_state, enc_valid_lens]`.
        """
        # enc_outputs: (batch_size, num_steps, num_hiddens).
        # hidden_state: list of length num_layers, each element of shape
        # (batch_size, num_hiddens).
        enc_outputs, hidden_state, enc_valid_lens = state
        # X arrives as (batch_size, num_steps); after embedding and the
        # transpose it is (num_steps, batch_size, embed_size), so unstacking
        # yields one (batch_size, embed_size) slice per decoding step.
        embedded = tf.transpose(self.embedding(X), perm=(1, 0, 2))
        step_outputs = []
        self._attention_weights = []
        for step_input in tf.unstack(embedded):
            # Query: last layer's hidden state, (batch_size, 1, num_hiddens)
            query = tf.expand_dims(hidden_state[-1], axis=1)
            # Context: (batch_size, 1, num_hiddens)
            context = self.attention(query, enc_outputs, enc_outputs,
                                     enc_valid_lens, training=training)
            # Join context and embedded token along the feature dimension
            rnn_input = tf.concat(
                (context, tf.expand_dims(step_input, axis=1)), axis=-1)
            rnn_result = self.rnn(rnn_input, hidden_state, training=training)
            # First element is the output sequence; the rest are the new
            # per-layer states carried into the next step.
            step_outputs.append(rnn_result[0])
            hidden_state = rnn_result[1:]
            self._attention_weights.append(self.attention.attention_weights)
        # After the fully connected layer the shape is
        # (batch_size, num_steps, vocab_size)
        logits = self.dense(tf.concat(step_outputs, axis=1))
        return logits, [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """List of attention weights, one entry per step of the last call."""
        return self._attention_weights

# Same harness as plain seq2seq — same logit shape, plus a new attention-weight tensor of shape (num_steps, batch, src_steps):
# Tiny configuration used only to verify tensor shapes.
vocab_size, embed_size = 10, 8
num_hiddens, num_layers = 16, 2
batch_size, num_steps = 4, 7

encoder = d2l.Seq2SeqEncoder(
    vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(
    vocab_size, embed_size, num_hiddens, num_layers)

# Dummy batch of token ids; no valid lengths are supplied (None).
X = tf.zeros((batch_size, num_steps))
state = decoder.init_state(encoder(X, training=False), None)
output, state = decoder(X, state, training=False)

# Logits, encoder outputs kept in the state, and the first layer's hidden
# state respectively.
d2l.check_shape(output, (batch_size, num_steps, vocab_size))
d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))
d2l.check_shape(state[1][0], (batch_size, num_hiddens))

# Same hyperparameters as plain seq2seq (embed/hidden 256, 2 layers, dropout 0.2, Adam 0.005, 30 epochs). Gives the model attention; everything else stays the same:
# Machine-translation dataset and model hyperparameters.
data = d2l.MTFraEng(batch_size=128)
embed_size, num_hiddens = 256, 256
num_layers, dropout = 2, 0.2

# Build the encoder, the attention decoder and the wrapping model inside the
# device context returned by d2l.try_gpu().
with d2l.try_gpu():
    encoder = d2l.Seq2SeqEncoder(
        len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
    decoder = Seq2SeqAttentionDecoder(
        len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
    model = d2l.Seq2Seq(
        encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.005)

# Train for 30 epochs with gradient clipping at 1.
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1)
trainer.fit(model, data)

# Compare BLEU vs. plain seq2seq — attention typically helps more on longer/harder sentences:
# English sources and their French reference translations.
engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']

batch = data.build(engs, fras)
preds, _ = model.predict_step(batch, d2l.try_gpu(), data.num_steps)

for en, fr, p in zip(engs, fras, preds):
    # Keep the predicted tokens up to, but not including, the first '<eos>'.
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    score = d2l.bleu(" ".join(translation), fr, k=2)
    print(f'{en} => {translation}, bleu,{score:.3f}')

# Recorded output:
# i lost . => ['je', 'suis', '<unk>', '.'], bleu,0.000
# i'm calm . => ['je', 'suis', '<unk>', '.'], bleu,0.658
# i'm home . => ['je', 'suis', 'malade', '.'], bleu,0.512
Pull attention weights from the predict step and plot them — rows are decoder steps, columns are source tokens. The diagonal-ish band is the model learning soft alignment:
valid_len.