Multi-Head Attention

A single attention head computes one weighted average — one notion of “relevance”. But a sentence has many parallel relations: subject-verb agreement, syntax, coreference, topical similarity.

Multi-head attention runs h independent attention mechanisms in parallel, each with its own learned linear projections of \mathbf{Q}, \mathbf{K}, \mathbf{V}. Modern Transformers use 8, 16, even 96 heads.

The architecture

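\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}), \quad i = 1, \ldots, h, \qquad \text{MHA}(\mathbf{q}, \mathbf{k}, \mathbf{v}) = \mathbf{W}_o\,[\mathbf{h}_1; \ldots; \mathbf{h}_h].

Each head i applies its own projections \mathbf{W}_i^{(q)}, \mathbf{W}_i^{(k)}, \mathbf{W}_i^{(v)} followed by an attention pooling function f (scaled dot-product attention in this section's implementation). The h head outputs are computed in parallel, concatenated, and passed through a final linear map \mathbf{W}_o.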
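To make the formula concrete, here is a minimal NumPy sketch that computes the heads one at a time in a Python loop. It is a reference for the math only, not the implementation used in this section: it assumes f is scaled dot-product attention, uses a row-vector convention (inputs are multiplied on the right by the weights), and the function name `multi_head_attention_loop` and all sizes are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_loop(Q, K, V, W_q, W_k, W_v, W_o):
    """Hypothetical reference: one head per loop iteration.

    Q: (n_q, d); K, V: (n_kv, d)
    W_q, W_k, W_v: lists of h matrices, each of shape (d, p)
    W_o: (h * p, p_o)
    """
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        q, k, v = Q @ Wq_i, K @ Wk_i, V @ Wv_i      # per-head projections
        scores = q @ k.T / np.sqrt(q.shape[-1])     # scaled dot products
        heads.append(softmax(scores) @ v)           # h_i: (n_q, p)
    # Concatenate along the feature axis, then apply the output projection.
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(0)
h, d, p, p_o = 4, 8, 2, 8                           # illustrative sizes
Q = rng.normal(size=(3, d))
K = rng.normal(size=(5, d))
V = rng.normal(size=(5, d))
W_q = [rng.normal(size=(d, p)) for _ in range(h)]
W_k = [rng.normal(size=(d, p)) for _ in range(h)]
W_v = [rng.normal(size=(d, p)) for _ in range(h)]
W_o = rng.normal(size=(h * p, p_o))

out = multi_head_attention_loop(Q, K, V, W_q, W_k, W_v, W_o)
print(out.shape)  # (3, 8)
```

The loop is easy to read but slow for many heads, which is why a practical implementation batches all heads into a single tensor operation.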

Setup

from d2l import mxnet as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()

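Per-head dimension trick

To keep the cost flat as h grows, set p_q = p_k = p_v = p_o/h, where p_q, p_k, and p_v are the per-head query, key, and value dimensions and p_o is the output dimension. The h heads together then need about the same compute and the same number of projection parameters as a single-head attention with hidden size p_o. In the implementation, each of \mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v is one large projection with a p_o-dimensional output (num_hiddens in the code), which is then reshaped into h heads.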
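The claim that one large projection followed by a reshape is the same as h separate per-head projections can be checked numerically. The sketch below uses plain NumPy rather than MXNet, and the sizes and variable names are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, p_o, h = 5, 6, 8, 4        # illustrative sizes: n tokens, input dim d
p = p_o // h                     # per-head dimension p_o / h

X = rng.normal(size=(n, d))
W_big = rng.normal(size=(d, p_o))

# One big projection, then split the last axis into h chunks of size p.
fused = (X @ W_big).reshape(n, h, p)

# Equivalent per-head projections: head i uses columns [i*p, (i+1)*p) of W_big.
per_head = np.stack(
    [X @ W_big[:, i * p:(i + 1) * p] for i in range(h)], axis=1)

same = np.allclose(fused, per_head)
print(same)  # True

# The projection parameter count does not depend on h.
params_fused = d * p_o
params_per_head = h * d * p
print(params_fused, params_per_head)  # 48 48
```

Because the per-head weights are just column blocks of the large matrix, increasing h with p_o fixed changes how the output is sliced, not how many parameters are learned.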

class MultiHeadAttention(d2l.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
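The choice of `repeat` for valid_lens matters because transpose_qkv places all heads of the first sequence before the heads of the second sequence along axis 0. A small NumPy illustration of the ordering follows, with `np.tile` shown as the layout that would not match. NumPy stands in for MXNet here and the values are made up.

```python
import numpy as np

num_heads = 3
valid_lens = np.array([3, 2])    # one valid length per sequence in the batch

# repeat: each item is copied num_heads times before moving to the next one.
repeated = np.repeat(valid_lens, num_heads, axis=0)
print(repeated)  # [3 3 3 2 2 2]

# tile: the whole vector is copied num_heads times (wrong layout here).
tiled = np.tile(valid_lens, num_heads)
print(tiled)     # [3 2 3 2 3 2]
```

With the repeated layout, entry b * num_heads + i of the flattened batch gets the valid length of sequence b, which is what the attention layer needs for masking.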

Reshape trick for parallel heads

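Reshape (batch, len, num_hiddens) → (batch, len, num_heads, num_hiddens / num_heads), transpose to (batch, num_heads, len, num_hiddens / num_heads), then reshape to (batch * num_heads, len, num_hiddens / num_heads), so the attention layer sees all heads as just more batch entries. transpose_output reverses these steps after the attention layer: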

@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = X.transpose(0, 2, 1, 3)
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
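    """Reverse the operation of transpose_qkv."""
    # Shape of input X: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads). Shape of output X: (batch_size, num_heads,
    # no. of queries, num_hiddens / num_heads)
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    # Shape of output X: (batch_size, no. of queries, num_heads,
    # num_hiddens / num_heads)
    X = X.transpose(0, 2, 1, 3)
    # Shape of output: (batch_size, no. of queries, num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)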
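As a sanity check on the two helpers, the following sketch re-implements them as standalone NumPy functions and verifies that they are inverses. Passing num_heads as an argument instead of reading self.num_heads is an adaptation made only for this example.

```python
import numpy as np

def transpose_qkv(X, num_heads):
    # (batch, len, num_hiddens) -> (batch * num_heads, len, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    # (batch * num_heads, len, num_hiddens / num_heads) -> (batch, len, num_hiddens)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

batch, length, num_hiddens, num_heads = 2, 3, 8, 4
X = np.arange(batch * length * num_hiddens, dtype=float).reshape(
    batch, length, num_hiddens)

split = transpose_qkv(X, num_heads)
print(split.shape)  # (8, 3, 2)

restored = transpose_output(split, num_heads)
print(np.array_equal(restored, X))  # True

# Entry b * num_heads + i holds feature slice [i*p, (i+1)*p) of sequence b.
b, i, p = 1, 2, num_hiddens // num_heads
head_ok = np.array_equal(split[b * num_heads + i],
                         X[b, :, i * p:(i + 1) * p])
print(head_ok)  # True
```

The last check shows why the intermediate transpose is needed: it groups each head's feature slice into a contiguous block before the batch and head axes are merged.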

Shape check

5 heads, 100 hidden units (20 per head), batch size 2, 4 queries, 6 key-value pairs. The output has the same shape as the queries, (batch_size, num_queries, num_hiddens):

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
Y = d2l.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
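The intermediate shapes and the parameter count for this configuration can be traced with simple arithmetic. The sketch below is pure Python and assumes, as in the check above, that the input feature size equals num_hiddens and that use_bias=False. The variable names for the intermediate shapes are made up for illustration.

```python
num_hiddens, num_heads = 100, 5
batch_size, num_queries, num_kvpairs = 2, 4, 6

head_dim = num_hiddens // num_heads                            # 20

# Shapes seen by the attention layer after transpose_qkv.
q_shape = (batch_size * num_heads, num_queries, head_dim)      # (10, 4, 20)
kv_shape = (batch_size * num_heads, num_kvpairs, head_dim)     # (10, 6, 20)
score_shape = (batch_size * num_heads, num_queries, num_kvpairs)  # (10, 4, 6)

# Shape after transpose_output and W_o.
out_shape = (batch_size, num_queries, num_hiddens)             # (2, 4, 100)

# Four bias-free weight matrices (W_q, W_k, W_v, W_o), each 100 x 100.
num_params = 4 * num_hiddens * num_hiddens
print(q_shape, kv_shape, score_shape, out_shape, num_params)
# (10, 4, 20) (10, 6, 20) (10, 4, 6) (2, 4, 100) 40000
```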

Recap

  • h heads, each with its own learned \mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v, run in parallel; their outputs are concatenated and then projected by \mathbf{W}_o.
  • Set per-head dim to num_hiddens / num_heads so total compute stays the same as a single-head layer.
  • Reshape (B, L, D) → (B*h, L, D/h) to run all heads as one batched matmul — no Python loop.
  • Multi-head attention is the standard attention block in Transformers; multiple heads let one layer learn many simultaneous notions of relevance.