
The Bahdanau Attention Mechanism

Bahdanau Attention

Plain seq2seq jams the entire source into one fixed vector — a bottleneck. Early tokens get forgotten by the time the encoder finishes.

Bahdanau, Cho & Bengio (2015): instead of one context vector, let the decoder query the encoder at every step.

\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t.
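The context vector is a weighted average of the encoder outputs. Below is a minimal NumPy sketch for one decoder step t'. The encoder outputs are made up, and the weights are picked by hand. In the model, the weights come from the learned score function, not from hand.

```python
import numpy as np

# Toy encoder outputs h_1..h_T for one example: T=4 source steps, 3 features each.
H = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 1.0, 1.0]])
# Hypothetical attention weights alpha(s_{t'-1}, h_t) for one decoder step t':
# nonnegative and summing to 1 over the source positions t.
alpha = np.array([0.1, 0.6, 0.2, 0.1])

# c_{t'} = sum_t alpha_t * h_t, i.e. a convex combination of the rows of H.
context = (alpha[:, None] * H).sum(axis=0)
print(context)  # approximately [0.2 0.7 0.3]
```

The sum is the same as the matrix-vector product alpha @ H. The decoder recomputes alpha at every step, so each step gets its own context vector.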

The original “soft alignment” mechanism, and the conceptual template for the attention layers in Transformers.

Without attention

Plain seq2seq: a single state vector is the only bridge between encoder and decoder.

With attention

Decoder queries the encoder’s per-step outputs at every decoding step.

Setup

from d2l import mxnet as d2l
from mxnet import init, np, npx
from mxnet.gluon import rnn, nn
npx.set_np()

The base interface only adds an attention_weights property, so that the weights can be pulled out for visualization:

class AttentionDecoder(d2l.Decoder):
    """The base attention-based decoder interface."""
    def __init__(self):
        super().__init__()

    @property
    def attention_weights(self):
        raise NotImplementedError

Seq2SeqAttentionDecoder

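Each decoding step does the following:

1. Use the previous decoder hidden state (top layer) as the query.
2. Run additive attention against the encoder outputs, masked by the source valid lengths.
3. Concatenate the resulting context with the embedded input token.
4. Run one GRU step.

After the loop, a dense layer projects every step's output to vocabulary logits.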

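The score is learned (additive attention). The weights α in the context-vector formula are the softmax of this score over the source positions, and padded positions are masked out:

a(\mathbf{s}_{t'-1}, \mathbf{h}_t) = \mathbf{w}_v^\top \tanh(\mathbf{W}_s \mathbf{s}_{t'-1} + \mathbf{W}_h \mathbf{h}_t),

\alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t) = \frac{\exp(a(\mathbf{s}_{t'-1}, \mathbf{h}_t))}{\sum_{j=1}^{T} \exp(a(\mathbf{s}_{t'-1}, \mathbf{h}_j))}.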
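Below is a minimal NumPy sketch of one additive-attention step for a single example. It assumes the following:

- The parameters W_s, W_h and w_v are random here; a trained model learns them.
- Only the first valid_len source positions are real tokens.
- The functions masked_softmax and additive_attention are standalone illustrations, not the d2l API.
- The dropout and batching that d2l.AdditiveAttention adds are left out.

```python
import numpy as np

def masked_softmax(scores, valid_len):
    """Softmax over the last axis that gives zero weight to positions >= valid_len."""
    positions = np.arange(scores.shape[-1])
    # Replace padded scores with a large negative value so exp() underflows to 0.
    masked = np.where(positions < valid_len, scores, -1e6)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(masked)
    return e / e.sum(axis=-1, keepdims=True)

def additive_attention(s_prev, H, W_s, W_h, w_v, valid_len):
    """Score each encoder output h_t against the previous decoder state s_prev.

    s_prev: (num_hiddens,), H: (T, num_hiddens); returns (context, alpha).
    """
    # a(s, h_t) = w_v^T tanh(W_s s + W_h h_t), evaluated for all t at once.
    features = np.tanh(W_s @ s_prev + H @ W_h.T)  # (T, attn_hiddens)
    scores = features @ w_v                        # (T,)
    alpha = masked_softmax(scores, valid_len)      # (T,), sums to 1
    context = alpha @ H                            # (num_hiddens,)
    return context, alpha

rng = np.random.default_rng(0)
T, num_hiddens, attn_hiddens = 5, 4, 8
H = rng.normal(size=(T, num_hiddens))          # encoder outputs h_1..h_T
s_prev = rng.normal(size=num_hiddens)          # decoder state s_{t'-1}
W_s = rng.normal(size=(attn_hiddens, num_hiddens))
W_h = rng.normal(size=(attn_hiddens, num_hiddens))
w_v = rng.normal(size=attn_hiddens)
context, alpha = additive_attention(s_prev, H, W_s, W_h, w_v, valid_len=3)
print(alpha.round(3))  # the last two (padded) positions get weight 0
```

The two matrices project the query and the keys into a shared attention space before the tanh. Because of this, the query and keys may have different sizes, unlike a plain dot-product score.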

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)
        # MXNet 2.0: `init.Xavier()` is applied per-Block (so it doesn't
        # propagate into the GRU's fused 1D bias slice, which would raise
        # "Xavier initializer cannot be applied to vector"). The GRU keeps
        # its built-in Orthogonal/Xavier/Zero initialization.
        self.embedding.initialize(init.Xavier())
        self.dense.initialize(init.Xavier())
        self.attention.initialize(init.Xavier())
        self.rnn.initialize()

    def init_state(self, enc_outputs, enc_valid_lens):
        # Shape of outputs: (num_steps, batch_size, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # Shape of the output X: (num_steps, batch_size, embed_size)
        X = self.embedding(X).swapaxes(0, 1)
        outputs, self._attention_weights = [], []
        for x in X:
            # Shape of query: (batch_size, 1, num_hiddens)
            query = np.expand_dims(hidden_state[-1], axis=1)
            # Shape of context: (batch_size, 1, num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate on the feature dimension
            x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)
            # Reshape x as (1, batch_size, embed_size + num_hiddens)
            out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
            hidden_state = hidden_state[0]
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # After fully connected layer transformation, shape of outputs:
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(np.concatenate(outputs, axis=0))
        return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,
                                        enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights
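To follow the data flow without MXNet, here is a NumPy sketch of the same loop using random parameters. It assumes the following:

- A single-layer toy tanh RNN cell stands in for the multi-layer GRU.
- There is no dropout.
- All sizes and weights are hypothetical.

The shapes mirror the comments in forward:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, src_steps, tgt_steps = 2, 5, 3
vocab, embed, hiddens = 10, 4, 6

def masked_softmax(scores, valid_lens):
    # scores: (batch, src_steps); valid_lens: (batch,)
    mask = np.arange(scores.shape[-1])[None, :] < valid_lens[:, None]
    scores = np.where(mask, scores, -1e6)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical random parameters (a real model learns these).
E = rng.normal(size=(vocab, embed))                      # embedding table
W_s, W_h = rng.normal(size=(2, hiddens, hiddens))        # attention projections
w_v = rng.normal(size=hiddens)
W_x = rng.normal(size=(hiddens + embed, hiddens)) * 0.1  # toy RNN input weights
W_r = rng.normal(size=(hiddens, hiddens)) * 0.1          # toy RNN recurrent weights
W_out = rng.normal(size=(hiddens, vocab))                # output projection

enc_outputs = rng.normal(size=(batch, src_steps, hiddens))  # keys and values
valid_lens = np.array([5, 3])                               # source lengths
s = rng.normal(size=(batch, hiddens))                       # initial decoder state
tokens = rng.integers(0, vocab, size=(batch, tgt_steps))    # decoder inputs

outputs, weights = [], []
for t in range(tgt_steps):
    x = E[tokens[:, t]]                                     # (batch, embed)
    # Query = previous state; additive features: (batch, src_steps, hiddens)
    feats = np.tanh((s @ W_s.T)[:, None, :] + enc_outputs @ W_h.T)
    alpha = masked_softmax(feats @ w_v, valid_lens)         # (batch, src_steps)
    context = np.einsum('bt,bth->bh', alpha, enc_outputs)   # (batch, hiddens)
    # Toy tanh RNN step on [context; x] (the MXNet code runs a GRU here).
    s = np.tanh(np.concatenate([context, x], axis=-1) @ W_x + s @ W_r)
    outputs.append(s)
    weights.append(alpha)

logits = np.stack(outputs, axis=1) @ W_out   # (batch, tgt_steps, vocab)
attn = np.stack(weights, axis=0)             # (tgt_steps, batch, src_steps)
print(logits.shape, attn.shape)              # (2, 3, 10) (3, 2, 5)
```

The query is the state from the previous step, so each step's attention weights are computed before the recurrent update, in the same order as in forward.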

Decoder shape check

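Same harness as plain seq2seq, with the same logit shape. In addition, decoder.attention_weights now holds a list of num_steps tensors, one per decoding step, each of shape (batch_size, 1, num_steps) over the source positions: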

vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = d2l.zeros((batch_size, num_steps))
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
d2l.check_shape(output, (batch_size, num_steps, vocab_size))
d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))
d2l.check_shape(state[1][0], (batch_size, num_hiddens))

Training

Same hyperparameters as plain seq2seq:

- embedding and hidden size 256
- 2 layers
- dropout 0.2
- Adam with learning rate 0.005
- gradient clipping at 1
- 30 epochs

Only the decoder gains attention; everything else stays the same:

data = d2l.MTFraEng(batch_size=128)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = d2l.Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                    lr=0.005)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

Translate three sentences

Compare the BLEU scores with those of plain seq2seq. Attention tends to help most on longer, harder sentences:

engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    data.build(engs, fras), d2l.try_gpu(), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu '
          f'{d2l.bleu(" ".join(translation), fr, k=2):.3f}')
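The k=2 argument limits the score to unigram and bigram matches. Below is a pure-Python sketch of a BLEU score of this kind. It assumes the form exp(min(0, 1 - len_label/len_pred)) times the product of clipped n-gram precisions p_n raised to 1/2**n, for n up to k. This is our reconstruction, so check d2l.bleu's source before relying on exact agreement. The name bleu_sketch is hypothetical.

```python
import collections
import math

def bleu_sketch(pred_seq, label_seq, k):
    """Brevity penalty times clipped n-gram precisions p_n ** (1 / 2**n), n <= k."""
    pred, label = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred), len(label)
    score = math.exp(min(0, 1 - len_label / len_pred))  # penalize short outputs
    for n in range(1, min(k, len_pred) + 1):
        # Count the label's n-grams so each one can be matched at most once.
        label_counts = collections.Counter(
            ' '.join(label[i:i + n]) for i in range(len_label - n + 1))
        matches = 0
        for i in range(len_pred - n + 1):
            gram = ' '.join(pred[i:i + n])
            if label_counts[gram] > 0:
                matches += 1
                label_counts[gram] -= 1
        score *= (matches / (len_pred - n + 1)) ** (0.5 ** n)
    return score

print(bleu_sketch('je suis chez moi .', 'je suis chez moi .', k=2))  # 1.0
print(round(bleu_sketch('je suis .', 'je suis chez moi .', k=2), 3))  # 0.432
```

In the second call every unigram matches but only one of the two bigrams does. The short output also pays the brevity penalty exp(-2/3).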

Attention heatmap

Pull the attention weights out of the prediction step and plot them. Rows are decoder steps (queries) and columns are source positions (keys). A roughly diagonal band means the model has learned a soft alignment between source and target tokens:

_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)
attention_weights = d2l.concat(
    [step[0][0][0] for step in dec_attention_weights], 0)
attention_weights = d2l.reshape(attention_weights, (1, 1, -1, data.num_steps))
# Plus one to include the end-of-sequence token
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1],
    xlabel='Key positions', ylabel='Query positions')
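Each heatmap row is a probability distribution over source positions, so taking the argmax of every row turns the soft alignment into a hard one. Below is a NumPy sketch on a made-up 3×4 weight matrix. The values and target tokens are hypothetical, not output of the trained model.

```python
import numpy as np

src = ["i'm", 'home', '.', '<eos>']   # source tokens plus end-of-sequence
tgt = ['je', 'suis', 'chez']          # hypothetical decoder outputs, one per row
# Hypothetical weights: rows = decoder steps (queries), columns = source positions (keys).
W = np.array([[0.7, 0.2, 0.05, 0.05],
              [0.6, 0.3, 0.05, 0.05],
              [0.1, 0.8, 0.05, 0.05]])
assert np.allclose(W.sum(axis=1), 1.0)   # each row is a softmax output
hard = W.argmax(axis=1)                  # most-attended source position per step
for word, j in zip(tgt, hard):
    print(f'{word} -> {src[j]}')
```

The hard alignment discards how the weight is spread across positions. That spread is exactly what the heatmap shows, which is why the soft weights are plotted instead.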

Recap

  • Bahdanau attention replaces the seq2seq bottleneck: at each decoder step, attend over all encoder outputs.
  • Decoder hidden state = query, encoder outputs = keys and values. Additive scoring; masked softmax with source valid_len.
  • Visualizing weights = soft alignment between source and target tokens.
  • This is the conceptual ancestor of the Transformer’s cross-attention.