
Multi-Head Attention


A single attention head computes one weighted average, so it captures one notion of “relevance”. A sentence, however, carries many relations at once: subject-verb agreement, other syntactic dependencies, coreference, topical similarity.

Multi-head attention runs h independent attention mechanisms in parallel, each with its own learned linear projections of \mathbf{Q}, \mathbf{K}, \mathbf{V}. Modern Transformers use 8, 16, even 96 heads.

The architecture

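\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v}, \quad i = 1, \ldots, h,

\text{MHA} = \mathbf{W}_o\,[\mathbf{h}_1; \ldots; \mathbf{h}_h] \in \mathbb{R}^{p_o},

where f is attention pooling (scaled dot-product attention in the code below), \mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}, \mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}, \mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}, and \mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}. In words: h heads run in parallel, their outputs are concatenated, and one learned linear map \mathbf{W}_o combines them.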
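To make the formula concrete, here is a minimal unbatched sketch in plain jax.numpy that loops over the heads explicitly. It assumes f is scaled dot-product attention without masking or dropout and uses random weights in place of learned ones; the helper names scaled_dot_product and multi_head_loop are hypothetical, not part of d2l or Flax. Because each row of q, k, v is a token, q @ W_q[i] plays the role of \mathbf{W}_i^{(q)}\mathbf{q} in the equation.

```python
import jax
import jax.numpy as jnp

def scaled_dot_product(q, k, v):
    # q: (num_queries, d), k: (num_kv, d), v: (num_kv, d_v)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

def multi_head_loop(q, k, v, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v: stacked per-head matrices, shape (h, d_model, d_head)
    # W_o: (h * d_head, d_model)
    heads = [scaled_dot_product(q @ W_q[i], k @ W_k[i], v @ W_v[i])
             for i in range(W_q.shape[0])]
    # Concatenate the h head outputs, then apply the output projection
    return jnp.concatenate(heads, axis=-1) @ W_o

h, d_model = 5, 100
d_head = d_model // h
rngs = jax.random.split(jax.random.PRNGKey(0), 6)
W_q, W_k, W_v = (jax.random.normal(rngs[i], (h, d_model, d_head)) / jnp.sqrt(d_model)
                 for i in range(3))
W_o = jax.random.normal(rngs[3], (h * d_head, d_model)) / jnp.sqrt(h * d_head)
queries = jax.random.normal(rngs[4], (4, d_model))  # 4 queries
kv = jax.random.normal(rngs[5], (6, d_model))       # 6 key-value pairs (keys = values here)
out = multi_head_loop(queries, kv, kv, W_q, W_k, W_v, W_o)
print(out.shape)  # (4, 100)
```

The Python loop mirrors the math one head at a time; the class in this section computes the same thing without a loop.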

Setup

from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax

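To keep the cost from growing with h, set p_q = p_k = p_v = p_o/h. The h heads together then have roughly the same parameter count and compute as a single-head attention layer with hidden size p_o. Implementation: use one large \mathbf{W}_q (likewise \mathbf{W}_k, \mathbf{W}_v) that produces p_o-dimensional outputs, then reshape those outputs into h heads of size p_o/h. In the code, p_o is num_hiddens.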
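The claim that one large projection followed by a reshape equals h separate per-head projections can be checked directly. A minimal sketch in plain jax.numpy, assuming random weights and no bias (the class defaults to bias=False); the standalone matrix W_q here plays the role of the kernel of self.W_q, and the sizes match the shape check later in the section.

```python
import jax
import jax.numpy as jnp

p_o, h = 100, 5        # num_hiddens and num_heads
p_head = p_o // h      # p_q = p_k = p_v = p_o / h = 20

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# One fused (p_o, p_o) projection, applied as X @ W_q like a Dense kernel
W_q = jax.random.normal(key_w, (p_o, p_o)) / jnp.sqrt(p_o)
X = jax.random.normal(key_x, (2, 4, p_o))  # (batch, num_queries, p_o)

# Fused: project once, then split the last axis into h heads
fused = (X @ W_q).reshape((2, 4, h, p_head))

# Per head: head i uses the contiguous column block [i*p_head, (i+1)*p_head)
per_head = jnp.stack(
    [X @ W_q[:, i * p_head:(i + 1) * p_head] for i in range(h)], axis=2)

print(bool(jnp.allclose(fused, per_head, atol=1e-5)))  # True
# The query projection's parameter count does not depend on h
print(W_q.size, h * p_o * p_head)                       # 10000 10000
```

So stacking the h per-head matrices side by side gives exactly one p_o × p_o matrix, which is why a single nn.Dense(num_hiddens) per role is enough.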

class MultiHeadAttention(nn.Module):
    num_hiddens: int
    num_heads: int
    dropout: float
    bias: bool = False

    def setup(self):
        self.attention = d2l.DotProductAttention(self.dropout)
        self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias)

    # All submodules are created in setup(), so __call__ needs no @nn.compact
    def __call__(self, queries, keys, values, valid_lens, training=False):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output, attention_weights = self.attention(
            queries, keys, values, valid_lens, training=training)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat), attention_weights

Reshape trick for parallel heads

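Reshape (batch, len, num_hiddens) → (batch, len, num_heads, num_hiddens/num_heads), swap the two middle axes to get (batch, num_heads, len, num_hiddens/num_heads), then merge the first two axes into (batch * num_heads, len, num_hiddens/num_heads). The attention layer then treats every head as just another batch entry. transpose_output reverses these steps after the attention layer: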

@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = jnp.transpose(X, (0, 2, 1, 3))
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape((-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
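    # Shape of input X: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads). Shape of output X: (batch_size, num_heads,
    # no. of queries, num_hiddens / num_heads)
    X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
    # Shape of output X: (batch_size, no. of queries, num_heads,
    # num_hiddens / num_heads)
    X = jnp.transpose(X, (0, 2, 1, 3))
    # Shape of output: (batch_size, no. of queries, num_hiddens)
    return X.reshape((X.shape[0], X.shape[1], -1))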
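Two properties make this trick safe: transpose_output exactly undoes transpose_qkv, and row b * num_heads + j of the reshaped tensor holds head j of batch example b. The second is why __call__ expands valid_lens with jnp.repeat (each entry copied num_heads times in a row) rather than tiling the whole vector. A small check in plain jax.numpy, using module-level copies of the two methods that read num_heads from a global (a simplification for illustration, not the d2l API):

```python
import jax.numpy as jnp

num_heads = 5

def transpose_qkv(X):
    # (B, L, D) -> (B, L, h, D/h) -> (B, h, L, D/h) -> (B*h, L, D/h)
    X = X.reshape((X.shape[0], X.shape[1], num_heads, -1))
    X = jnp.transpose(X, (0, 2, 1, 3))
    return X.reshape((-1, X.shape[2], X.shape[3]))

def transpose_output(X):
    # (B*h, L, D/h) -> (B, h, L, D/h) -> (B, L, h, D/h) -> (B, L, D)
    X = X.reshape((-1, num_heads, X.shape[1], X.shape[2]))
    X = jnp.transpose(X, (0, 2, 1, 3))
    return X.reshape((X.shape[0], X.shape[1], -1))

B, L, D = 2, 4, 100
d_head = D // num_heads
X = jnp.arange(B * L * D, dtype=jnp.float32).reshape((B, L, D))
Z = transpose_qkv(X)
print(Z.shape)                                        # (10, 4, 20)
print(bool(jnp.array_equal(transpose_output(Z), X)))  # True

# Row b * num_heads + j of Z is head j (columns j*d_head ...) of example b
b, j = 1, 3
print(bool(jnp.array_equal(Z[b * num_heads + j],
                           X[b, :, j * d_head:(j + 1) * d_head])))  # True

# So each example's valid length must be copied num_heads times in a row
valid_lens = jnp.array([3, 2])
print(jnp.repeat(valid_lens, num_heads).tolist())  # [3, 3, 3, 3, 3, 2, 2, 2, 2, 2]
```

Tiling instead ([3, 2, 3, 2, ...]) would pair the heads of example 0 with the wrong valid lengths.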

Shape check

Shape check with 5 heads and num_hiddens = 100, batch size 2, 4 queries, and 6 key-value pairs. The output has the same shape as the queries X, namely (batch_size, num_queries, num_hiddens):

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
Y = d2l.ones((batch_size, num_kvpairs, num_hiddens))
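# init_with_output returns (outputs, variables), and outputs is
# (attention output, attention weights), so [0][0] is the attention output
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, Y, Y, valid_lens,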
                                           training=False)[0][0],
                (batch_size, num_queries, num_hiddens))

Recap

  • h heads, each with its own learned \mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v, run in parallel; their outputs are concatenated, then projected by \mathbf{W}_o.
  • Set the per-head size to num_hiddens / num_heads so the total cost stays roughly that of a single-head layer of width num_hiddens.
  • Reshape (B, L, D) → (B*h, L, D/h) to run all heads as one batched matmul, with no Python loop over heads.
  • Multi-head attention is the core block of Transformers; multiple heads let one layer learn several notions of relevance at once.
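The third bullet's claim can be checked by comparing the batched reshape against an explicit loop over heads. This is a sketch in plain jax.numpy that assumes the inputs are already projected and leaves out masking, dropout, and \mathbf{W}_o, since those act identically on both paths; split_heads and merge_heads are hypothetical standalone copies of transpose_qkv and transpose_output.

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Scaled dot-product attention over the last two axes; batches freely
    scores = q @ jnp.swapaxes(k, -2, -1) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

def split_heads(X, h):
    # (B, L, D) -> (B*h, L, D/h), same steps as transpose_qkv
    B, L, D = X.shape
    X = jnp.transpose(X.reshape((B, L, h, D // h)), (0, 2, 1, 3))
    return X.reshape((B * h, L, D // h))

def merge_heads(X, h):
    # (B*h, L, D/h) -> (B, L, D), same steps as transpose_output
    BH, L, d = X.shape
    X = jnp.transpose(X.reshape((BH // h, h, L, d)), (0, 2, 1, 3))
    return X.reshape((BH // h, L, h * d))

B, Lq, Lk, D, h = 2, 4, 6, 100, 5
d = D // h
kq, kk = jax.random.split(jax.random.PRNGKey(0))
Q = jax.random.normal(kq, (B, Lq, D))  # already-projected queries
K = jax.random.normal(kk, (B, Lk, D))  # already-projected keys, reused as values

# Batched: every head is one more batch entry, so a single attention call
Kh = split_heads(K, h)
batched = merge_heads(attention(split_heads(Q, h), Kh, Kh), h)

# Looped: slice each head's d columns, attend, then concatenate
looped = jnp.concatenate(
    [attention(Q[..., i * d:(i + 1) * d], K[..., i * d:(i + 1) * d],
               K[..., i * d:(i + 1) * d]) for i in range(h)], axis=-1)

print(batched.shape, bool(jnp.allclose(batched, looped, atol=1e-4)))  # (2, 4, 100) True
```

Both paths give the same numbers; the batched version just hands all heads to one matmul.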