from d2l import tensorflow as d2l
import tensorflow as tf

A single attention head computes one weighted average — one notion of “relevance”. But a sentence has many parallel relations: subject-verb agreement, syntax, coreference, topical similarity.
Multi-head attention runs h independent attention mechanisms in parallel, each with its own learned linear projections of \mathbf{Q}, \mathbf{K}, \mathbf{V}. Modern Transformers use 8, 16, even 96 heads.
\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}), \text{MHA} = \mathbf{W}_o\,[\mathbf{h}_1; \ldots; \mathbf{h}_h].
h projections in parallel, concatenated and linearly transformed.
To keep cost flat as h grows, set p_q = p_k = p_v = p_o/h. The h heads then have the same total compute as a single-head attention with hidden size p_o. Implementation: apply one large projection \mathbf{W}_q that produces p_o-dimensional outputs, then reshape the result into h heads (and likewise for \mathbf{W}_k and \mathbf{W}_v).
class MultiHeadAttention(d2l.Module):
    """Multi-head attention.

    Queries, keys and values are each projected to ``num_hiddens`` features
    by one dense layer, split into ``num_heads`` heads that are folded into
    the batch axis, passed through ``d2l.DotProductAttention``, and finally
    re-concatenated and projected by ``W_o``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        """Create the projection layers and the attention module.

        Args:
            key_size: Accepted for interface compatibility; not read here.
            query_size: Accepted for interface compatibility; not read here.
            value_size: Accepted for interface compatibility; not read here.
            num_hiddens: Output width of every projection layer. Must be a
                multiple of ``num_heads``.
            num_heads: Number of attention heads. Must be positive.
            dropout: Forwarded to ``d2l.DotProductAttention``.
            bias: Whether the dense projections use a bias term.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not evenly
                divide ``num_hiddens``.
        """
        super().__init__()
        # transpose_qkv splits the feature axis into num_heads equal parts
        # with a -1 reshape; fail here with a clear message instead of with
        # an opaque shape error on the first forward pass.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be positive and evenly "
                f"divide num_hiddens ({num_hiddens})")
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

    def call(self, queries, keys, values, valid_lens, training=False,
             **kwargs):
        """Apply multi-head attention.

        Args:
            queries, keys, values: Tensors of shape
                (batch_size, no. of queries or key-value pairs, num_hiddens).
            valid_lens: ``None``, or a tensor of shape (batch_size,) or
                (batch_size, no. of queries).
            training: Forwarded to the attention module.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape (batch_size, no. of queries, num_hiddens).
        """
        # After projecting and transposing, shape of queries, keys, or
        # values: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on, so that every head
            # of a given example sees that example's valid length
            valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens,
                                training=training)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)


# Reshape (batch, len, num_hiddens) → (batch, len, num_heads, dim/heads) →
# (batch * num_heads, len, dim/heads) so the attention layer sees all heads
# as just more batch entries. transpose_output reverses it after the
# attention layer:
@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Fold the attention heads into the batch axis.

    Takes X of shape (batch_size, no. of queries or key-value pairs,
    num_hiddens) and returns a tensor of shape (batch_size * num_heads,
    no. of queries or key-value pairs, num_hiddens / num_heads), so that
    all heads can be processed in a single batched attention call.
    """
    in_shape = tf.shape(X)
    batch_size, seq_len = in_shape[0], in_shape[1]
    # Split the feature axis: (batch_size, no. of queries or key-value
    # pairs, num_heads, num_hiddens / num_heads)
    per_head = tf.reshape(X, (batch_size, seq_len, self.num_heads, -1))
    # Bring the head axis next to the batch axis: (batch_size, num_heads,
    # no. of queries or key-value pairs, num_hiddens / num_heads)
    heads_first = tf.transpose(per_head, perm=(0, 2, 1, 3))
    # Merge batch and head axes: (batch_size * num_heads, no. of queries or
    # key-value pairs, num_hiddens / num_heads)
    hf_shape = tf.shape(heads_first)
    return tf.reshape(heads_first, (-1, hf_shape[2], hf_shape[3]))
@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
    """Undo transpose_qkv on the attention output.

    The leading axis of X is unfolded into (batch, num_heads), the head axis
    is moved behind the sequence axis, and the last two axes are merged so
    the per-head features of each position are concatenated again.
    """
    folded_shape = tf.shape(X)
    seq_len, head_dim = folded_shape[1], folded_shape[2]
    # Separate the head axis from the batch axis
    unfolded = tf.reshape(X, (-1, self.num_heads, seq_len, head_dim))
    # Put the heads of each position side by side
    seq_first = tf.transpose(unfolded, perm=(0, 2, 1, 3))
    # Concatenate the heads along the feature axis
    sf_shape = tf.shape(seq_first)
    return tf.reshape(seq_first, (sf_shape[0], sf_shape[1], -1))


# 5 heads × 100 hidden, batch 2, 4 queries, 6 key-value pairs. Output shape
# matches input shape:
Each head works with num_hiddens / num_heads features, so total compute stays the same as a single-head layer. Reshape (B, L, D) → (B*h, L, D/h) to run all heads as one batched matmul — no Python loop.