The Bahdanau Attention Mechanism

Bahdanau Attention

Plain seq2seq jams the entire source into one fixed vector — a bottleneck. Early tokens get forgotten by the time the encoder finishes.

Bahdanau, Cho & Bengio (2015): instead of one context vector, let the decoder query the encoder at every step.

\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t.

This was the original “soft alignment” mechanism, and it is the conceptual template for the cross-attention used in Transformers.
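To make the weighted sum concrete, here is a minimal sketch of one decoder step. It assumes the weights α come from a softmax over alignment scores (the usual formulation); the toy encoder outputs and the scores are invented for illustration.

```python
import jax
from jax import numpy as jnp

# Hypothetical toy sizes: T = 3 source steps, hidden size 2
enc_outputs = jnp.array([[1.0, 0.0],
                         [0.0, 1.0],
                         [1.0, 1.0]])      # rows are h_1 .. h_T
scores = jnp.array([2.0, 0.0, 0.0])        # made-up scores for one decoder step

alpha = jax.nn.softmax(scores)             # nonnegative weights that sum to 1
context = alpha @ enc_outputs              # c_{t'} = sum_t alpha_t * h_t

print(alpha.shape, context.shape)
```

Because the first score is the largest, the context vector is pulled toward h_1; a different decoder step would produce different scores and therefore a different context.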

Without attention

Plain seq2seq: a single state vector is the only bridge between encoder and decoder.

With attention

Decoder queries the encoder’s per-step outputs at every decoding step.

Setup

from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax

AttentionDecoder interface

The base interface only adds an attention_weights property so the weights can be pulled out for visualization:

class AttentionDecoder(d2l.Decoder):
    """The base attention-based decoder interface.

    Flax modules are dataclasses, so the base class deliberately omits
    `__init__`; subclasses declare their fields as class-level
    annotations and (optionally) a `setup()` method.
    """
    @property
    def attention_weights(self):
        raise NotImplementedError

Seq2SeqAttentionDecoder

Per step: take the previous decoder hidden state, run additive attention against the encoder outputs (masked by source valid_len), concat the resulting context with the embedded input, run one GRU step, project to vocab.

The score is learned:

a(\mathbf{s}_{t'-1}, \mathbf{h}_t) = \mathbf{w}_v^\top \tanh(\mathbf{W}_s \mathbf{s}_{t'-1} + \mathbf{W}_h \mathbf{h}_t).
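The scoring function and the masking step can be sketched in plain JAX as follows. This is an illustrative version, not the d2l.AdditiveAttention implementation: the function names, the random parameters, and the toy sizes are all assumptions, and the weight matrices are stored transposed so that row vectors can be multiplied from the left.

```python
import jax
from jax import numpy as jnp

def additive_scores(query, keys, W_s, W_h, w_v):
    # query: (q_dim,), keys: (T, k_dim); W_s and W_h project both to h_dim
    features = jnp.tanh(query @ W_s + keys @ W_h)    # (T, h_dim) via broadcasting
    return features @ w_v                            # (T,), one score per source step

def masked_softmax(scores, valid_len):
    # Positions at or beyond valid_len are padding and get (near) zero weight
    mask = jnp.arange(scores.shape[-1]) < valid_len
    return jax.nn.softmax(jnp.where(mask, scores, -1e6), axis=-1)

# Hypothetical sizes: 5 source steps (3 valid), feature size 4, attention size 8
k_q, k_k, k_ws, k_wh, k_wv = jax.random.split(jax.random.PRNGKey(0), 5)
query = jax.random.normal(k_q, (4,))
keys = jax.random.normal(k_k, (5, 4))
W_s = jax.random.normal(k_ws, (4, 8))
W_h = jax.random.normal(k_wh, (4, 8))
w_v = jax.random.normal(k_wv, (8,))

scores = additive_scores(query, keys, W_s, W_h, w_v)
weights = masked_softmax(scores, valid_len=3)
context = weights @ keys                             # keys double as values here
```

The last two source positions receive zero weight, so padding never contributes to the context vector.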

class Seq2SeqAttentionDecoder(nn.Module):
    vocab_size: int
    embed_size: int
    num_hiddens: int
    num_layers: int
    dropout: float = 0

    def setup(self):
        self.attention = d2l.AdditiveAttention(self.num_hiddens, self.dropout)
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.dense = nn.Dense(self.vocab_size)
        self.rnn = d2l.GRU(self.num_hiddens, self.num_layers, dropout=self.dropout)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # Shape of outputs: (num_steps, batch_size, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        # Attention weights are not stored in the state; __call__ captures them with sow
        return (outputs.transpose(1, 0, 2), hidden_state, enc_valid_lens)

    @nn.compact
    def __call__(self, X, state, training=False):
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        # The state carries only the three entries unpacked below
        enc_outputs, hidden_state, enc_valid_lens = state
        # Shape of the output X: (num_steps, batch_size, embed_size)
        X = self.embedding(X).transpose(1, 0, 2)
        outputs, attention_weights = [], []
        for x in X:
            # Shape of query: (batch_size, 1, num_hiddens)
            query = jnp.expand_dims(hidden_state[-1], axis=1)
            # Shape of context: (batch_size, 1, num_hiddens)
            context, attention_w = self.attention(query, enc_outputs,
                                                  enc_outputs, enc_valid_lens,
                                                  training=training)
            # Concatenate on the feature dimension
            x = jnp.concatenate((context, jnp.expand_dims(x, axis=1)), axis=-1)
            # Reshape x as (1, batch_size, embed_size + num_hiddens)
            out, hidden_state = self.rnn(x.transpose(1, 0, 2), hidden_state,
                                         training=training)
            outputs.append(out)
            attention_weights.append(attention_w)

        # Flax sow API is used to capture intermediate variables
        self.sow('intermediates', 'dec_attention_weights', attention_weights)

        # After fully connected layer transformation, shape of outputs:
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(jnp.concatenate(outputs, axis=0))
        return outputs.transpose(1, 0, 2), [enc_outputs, hidden_state,
                                            enc_valid_lens]

Decoder shape check

Same harness as plain seq2seq, and the logits have the same shape. The attention weights, one (batch_size, 1, src_steps) tensor per decoding step, are not returned; they are sown into the intermediates collection:

vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = jnp.zeros((batch_size, num_steps), dtype=jnp.int32)
state = decoder.init_state(encoder.init_with_output(d2l.get_key(),
                                                    X, training=False)[0],
                           None)
(output, state), _ = decoder.init_with_output(d2l.get_key(), X,
                                              state, training=False)
d2l.check_shape(output, (batch_size, num_steps, vocab_size))
d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))
d2l.check_shape(state[1][0], (batch_size, num_hiddens))

Training

Same hyperparameters as plain seq2seq (embed/hidden 256, 2 layers, dropout 0.2, Adam 0.005, 30 epochs). The only change is the attention decoder; everything else stays the same:

data = d2l.MTFraEng(batch_size=128)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = d2l.Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                    lr=0.005)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

Translate three sentences

Compare BLEU vs. plain seq2seq — attention typically helps more on longer/harder sentences:

engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    trainer.state.params, data.build(engs, fras), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu,'
          f'{d2l.bleu(" ".join(translation), fr, k=2):.3f}')
i lost . => ["j'ai", 'perdu', '.'], bleu,1.000
i'm calm . => ['je', 'suis', 'malade', '.'], bleu,0.658
i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000
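To see where a score such as 0.658 comes from, here is a small pure-Python sketch of BLEU with n-grams up to length k. It assumes that d2l.bleu uses the usual definition, a brevity penalty multiplied by n-gram precisions raised to the power 0.5^n; bleu_sketch is a hypothetical re-implementation for illustration, not the library function.

```python
import collections
import math

def bleu_sketch(pred_seq, label_seq, k):
    """BLEU with n-grams up to length k (illustrative re-implementation)."""
    pred, label = pred_seq.split(' '), label_seq.split(' ')
    # Brevity penalty: predictions shorter than the label are penalized
    score = math.exp(min(0, 1 - len(label) / len(pred)))
    for n in range(1, min(k, len(pred)) + 1):
        # Count label n-grams so that each one can be matched at most once
        label_ngrams = collections.Counter(
            ' '.join(label[i:i + n]) for i in range(len(label) - n + 1))
        matches = 0
        for i in range(len(pred) - n + 1):
            ngram = ' '.join(pred[i:i + n])
            if label_ngrams[ngram] > 0:
                matches += 1
                label_ngrams[ngram] -= 1
        # n-gram precision, with an exponent 0.5 ** n that shrinks as n grows
        score *= (matches / (len(pred) - n + 1)) ** (0.5 ** n)
    return score

print(f"{bleu_sketch('je suis malade .', 'je suis calme .', k=2):.3f}")  # 0.658
```

For the second sentence the lengths match, so the brevity penalty is 1; three of four unigrams and one of three bigrams match, giving 0.75^(1/2) × (1/3)^(1/4) ≈ 0.658.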

Attention heatmap

Pull attention weights from the predict step and plot them — rows are decoder steps, columns are source tokens. The diagonal-ish band is the model learning soft alignment:

_, (dec_attention_weights, _) = model.predict_step(
    trainer.state.params, data.build([engs[-1]], [fras[-1]]),
    data.num_steps, True)
attention_weights = d2l.concat(
    [step[0][0][0] for step in dec_attention_weights], 0)
attention_weights = d2l.reshape(attention_weights, (1, 1, -1, data.num_steps))
# Plus one to include the end-of-sequence token
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1],
                  xlabel='Key positions', ylabel='Query positions')

Recap

  • Bahdanau attention removes the seq2seq bottleneck: at each decoder step, the decoder attends over all encoder outputs.
  • Decoder hidden state = query, encoder outputs = keys and values. Additive scoring; masked softmax with source valid_len.
  • Visualizing weights = soft alignment between source and target tokens.
  • This is the conceptual ancestor of the Transformer’s cross-attention.