Why scaled dot product

Attention Scoring Functions

Scoring Functions

Attention pooling needs a scoring function a(\mathbf{q}, \mathbf{k}) whose outputs a softmax turns into attention weights:

\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp\, a(\mathbf{q}, \mathbf{k}_i)}{\sum_j \exp\, a(\mathbf{q}, \mathbf{k}_j)}.

The output is the weighted sum of the values, using the softmax of a as the weights:

f(\mathbf{q}) = \sum_i \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i.
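A minimal NumPy sketch of this pipeline for a single query, assuming three made-up scores and two-dimensional values (`attention_pool` is a hypothetical helper, not part of d2l):

```python
import numpy as np

def attention_pool(scores, values):
    """Softmax the scores a(q, k_i) over the keys, then average the values."""
    # Subtracting the max before exponentiating avoids overflow and
    # leaves the softmax output unchanged.
    w = np.exp(scores - scores.max())
    alpha = w / w.sum()              # attention weights alpha(q, k_i)
    return alpha, alpha @ values     # weighted sum of the value vectors

scores = np.array([1.0, 2.0, 3.0])   # hypothetical a(q, k_i) for 3 keys
values = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])
alpha, out = attention_pool(scores, values)
print(alpha.round(3), out.round(3))
```

The highest-scoring key receives the largest weight, and the weights always sum to 1, so the output is a convex combination of the values.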

Two scoring functions dominate in practice:

  • Scaled dot product: a = \mathbf{q}^\top \mathbf{k}/\sqrt{d}. Cheap and parameter-free, but queries and keys must share dimension d. The Transformer's choice.
  • Additive (Bahdanau): a tiny MLP over [\mathbf{q}; \mathbf{k}]. More expressive, learns the metric, and allows queries and keys of different sizes.

Both feed into the same softmax + value-pooling pipeline.

Setup

import math
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()

Scaling the dot product

For d-dimensional queries and keys whose 2d coordinates are all independent, with mean zero and unit variance,

\operatorname{Var}(\mathbf{q}^\top\mathbf{k}) = \operatorname{Var}\left(\sum_{\ell=1}^d q_\ell k_\ell\right) = d.
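The result follows in two steps, assuming all 2d coordinates are mutually independent. Each product term has zero mean and unit variance, and the variances of the d independent terms add:

\mathbb{E}[q_\ell k_\ell] = \mathbb{E}[q_\ell]\,\mathbb{E}[k_\ell] = 0, \qquad \operatorname{Var}(q_\ell k_\ell) = \mathbb{E}[q_\ell^2]\,\mathbb{E}[k_\ell^2] - 0^2 = 1,

\operatorname{Var}(\mathbf{q}^\top\mathbf{k}) = \sum_{\ell=1}^d \operatorname{Var}(q_\ell k_\ell) = d, \qquad \operatorname{Var}\left(\mathbf{q}^\top\mathbf{k}/\sqrt{d}\right) = \frac{d}{d} = 1.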

As d grows, the standard deviation of raw dot products grows like \sqrt d: logits become large in magnitude, the softmax saturates toward a one-hot vector, and its gradients vanish. Dividing by \sqrt d keeps the logit variance at 1 for every d:

a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}.
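A quick NumPy simulation can check both claims under the stated assumptions (i.i.d. standard normal coordinates, here with 10 keys per query). `logit_stats` is a hypothetical helper written for this illustration; it reports the logit variance and the average largest softmax weight, with and without the 1/\sqrt d factor:

```python
import numpy as np

# Fixed seed so the run is reproducible.
rng = np.random.default_rng(0)

def logit_stats(d, n_queries=1000, n_keys=10, scale=False):
    """Score random queries against random keys (all coordinates N(0, 1)).

    Returns the variance of the logits and the average largest softmax
    weight per query.
    """
    q = rng.standard_normal((n_queries, d))
    k = rng.standard_normal((n_queries, n_keys, d))
    logits = np.einsum('nd,nkd->nk', q, k)   # q^T k_i for every (query, key)
    if scale:
        logits = logits / np.sqrt(d)
    # Numerically stable softmax over the keys of each query.
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w = w / w.sum(axis=1, keepdims=True)
    return logits.var(), w.max(axis=1).mean()

results = {}
for d in (4, 64, 512):
    raw = logit_stats(d)
    scaled = logit_stats(d, scale=True)
    results[d] = (raw, scaled)
    print(f"d={d:3d}  raw: var={raw[0]:7.1f}, max weight={raw[1]:.2f}  "
          f"scaled: var={scaled[0]:.2f}, max weight={scaled[1]:.2f}")
```

Without scaling, the variance should track d and the largest weight should drift toward 1 (a saturated softmax); with scaling, the variance should stay near 1 for every d.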

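A Gaussian-kernel view explains where the dot product comes from. Expanding -\tfrac{1}{2}\|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top\mathbf{k}_i - \tfrac{1}{2}\|\mathbf{k}_i\|^2 - \tfrac{1}{2}\|\mathbf{q}\|^2, the \|\mathbf{q}\|^2 term is the same for every key and cancels in the softmax, and the \|\mathbf{k}_i\|^2 term drops out when all keys have roughly equal norm (e.g., after layer normalization). The variance argument, however, is the operational reason Transformers divide by \sqrt d.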
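Assuming every key has the same norm (3.0 here, an arbitrary choice), Gaussian-kernel scores -\tfrac12\|\mathbf{q}-\mathbf{k}_i\|^2 and plain dot products \mathbf{q}^\top\mathbf{k}_i differ only by a shift that is constant across keys, so their softmaxes should coincide. A minimal NumPy check:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 5
q = rng.standard_normal(d)
keys = rng.standard_normal((n, d))
# Rescale every key to the same norm, as the argument assumes.
keys = 3.0 * keys / np.linalg.norm(keys, axis=1, keepdims=True)

def softmax(z):
    z = z - z.max()          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

gauss_scores = -0.5 * np.sum((q - keys) ** 2, axis=1)   # -||q - k_i||^2 / 2
dot_scores = keys @ q                                    # q^T k_i
print(np.allclose(softmax(gauss_scores), softmax(dot_scores)))
```

If the keys are left with unequal norms, the -\tfrac12\|\mathbf{k}_i\|^2 term no longer cancels and the two weightings generally differ.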

Masked softmax

Minibatches contain padded sequences, and <pad> keys should receive no attention mass. Setting their pre-softmax scores to a large negative number (here -10^6) makes \exp flush them to zero:

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = valid_lens.repeat(shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                              value=-1e6, axis=1)
        return npx.softmax(X).reshape(shape)
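For readers without MXNet, here is a NumPy sketch of the same masking logic. `masked_softmax_np` is a hypothetical name, and building a boolean mask and applying it with `np.where` stands in for `npx.sequence_mask`:

```python
import numpy as np

def masked_softmax_np(X, valid_lens=None, fill=-1e6):
    """Masked softmax over the last axis of X, shape (batch, queries, keys).

    valid_lens: None, shape (batch,) or shape (batch, queries).
    """
    if valid_lens is None:
        mask = np.ones(X.shape, dtype=bool)
    else:
        valid_lens = np.asarray(valid_lens)
        if valid_lens.ndim == 1:
            # One length per example, shared by all of its query rows.
            valid_lens = np.repeat(valid_lens[:, None], X.shape[1], axis=1)
        # Key position j is kept only if j < that row's valid length.
        positions = np.arange(X.shape[-1])
        mask = positions[None, None, :] < valid_lens[:, :, None]
    X = np.where(mask, X, fill)
    X = X - X.max(axis=-1, keepdims=True)    # numerical stability
    e = np.exp(X)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
W = masked_softmax_np(rng.uniform(size=(2, 2, 4)), np.array([2, 3]))
print(W.round(3))
```

Both `valid_lens` layouts follow the MXNet version: a 1D tensor is repeated across query rows, while a 2D tensor supplies one length per row.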

Masked softmax in action

Random scores, with one valid length per example; it applies to every query row of that example:

masked_softmax(np.random.uniform(size=(2, 2, 4)), d2l.tensor([2, 3]))

A 2D valid_lens tensor instead gives one valid length per query row:

masked_softmax(np.random.uniform(size=(2, 2, 4)),
               d2l.tensor([[1, 3], [2, 4]]))

Batched matmul

Attention runs on minibatches, so multiplying weights by values is a batched matrix product. npx.batch_dot multiplies corresponding matrices across the batch axis; confirm the shapes:

Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
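In NumPy, `np.matmul` (the `@` operator) already treats leading axes as batch axes, so the same shape check can be sketched without MXNet. Swapping the last two axes of the second operand plays the role of `batch_dot`'s `transpose_b=True`:

```python
import numpy as np

Q = np.ones((2, 3, 4))
K = np.ones((2, 4, 6))
# One (3x4) @ (4x6) product per example along the leading batch axis.
out = Q @ K
print(out.shape)

# The same contraction spelled out with einsum.
out_einsum = np.einsum('bij,bjk->bik', Q, K)

# Keys stored as (batch, no. of keys, d): swap their last two axes first.
K_t = np.ones((2, 6, 4))
out_t = Q @ np.swapaxes(K_t, 1, 2)
```

With all-ones inputs, every entry of the result equals the shared inner dimension, 4.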

DotProductAttention class

The layer has no learnable parameters: it computes \mathbf{Q}\mathbf{K}^\top/\sqrt d, applies masked softmax, and returns the weighted sum of the values (with dropout on the attention weights):

class DotProductAttention(nn.Block):
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of keys
        scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
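A NumPy sketch of the same forward pass, useful for checking shapes and masking outside MXNet. It assumes one valid length per example and omits dropout (which only matters during training); `scaled_dot_product_attention_np` is a hypothetical name:

```python
import numpy as np

def scaled_dot_product_attention_np(queries, keys, values, valid_lens=None):
    """queries: (b, n_q, d), keys: (b, n_k, d), values: (b, n_k, v).

    valid_lens: None or shape (b,), one valid key count per example.
    """
    d = queries.shape[-1]
    # Swap the last two axes of keys, like transpose_b=True.
    scores = queries @ np.swapaxes(keys, 1, 2) / np.sqrt(d)   # (b, n_q, n_k)
    if valid_lens is not None:
        keep = (np.arange(keys.shape[1])[None, None, :]
                < np.asarray(valid_lens)[:, None, None])
        scores = np.where(keep, scores, -1e6)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ values, weights

rng = np.random.default_rng(0)
queries = rng.standard_normal((2, 1, 2))
keys = rng.standard_normal((2, 10, 2))
values = rng.standard_normal((2, 10, 4))
out, weights = scaled_dot_product_attention_np(queries, keys, values, [2, 6])
print(out.shape)
```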

DotProduct demo

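A batch of 2 examples, each with 1 query (d = 2) and 10 key-value pairs (value dimension 4), with valid lengths (2, 6): only the first 2 keys of the first example and the first 6 keys of the second receive nonzero weight: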

queries = d2l.normal(0, 1, (2, 1, 2))
keys = d2l.normal(0, 1, (2, 10, 2))
values = d2l.normal(0, 1, (2, 10, 4))
valid_lens = d2l.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Visualize the resulting attention matrix:

d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

AdditiveAttention class

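a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q\mathbf{q} + \mathbf{W}_k\mathbf{k}), with learnable \mathbf{W}_q \in \mathbb{R}^{h \times d_q}, \mathbf{W}_k \in \mathbb{R}^{h \times d_k}, and \mathbf{w}_v \in \mathbb{R}^{h}, where h is num_hiddens. Because both are projected to h dimensions, queries and keys can live in different feature spaces (d_q \neq d_k).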
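Stacking the two projection matrices shows why this score is a one-hidden-layer MLP over the concatenation [\mathbf{q}; \mathbf{k}], with h = num_hiddens tanh hidden units, no biases, and output weights \mathbf{w}_v:

\mathbf{W}_q\mathbf{q} + \mathbf{W}_k\mathbf{k} = \begin{bmatrix} \mathbf{W}_q & \mathbf{W}_k \end{bmatrix} \begin{bmatrix} \mathbf{q} \\ \mathbf{k} \end{bmatrix}, \qquad a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh\left( \begin{bmatrix} \mathbf{W}_q & \mathbf{W}_k \end{bmatrix} \begin{bmatrix} \mathbf{q} \\ \mathbf{k} \end{bmatrix} \right).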

class AdditiveAttention(nn.Block):
    """Additive attention."""
    def __init__(self, num_hiddens, dropout):
        super().__init__()
        # Use flatten=False to only transform the last axis so that the
        # shapes for the other axes are kept the same
        self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.w_v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
        # no. of key-value pairs, num_hiddens). Sum them up with
        # broadcasting
        features = np.expand_dims(queries, axis=2) + np.expand_dims(
            keys, axis=1)
        features = np.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores:
        # (batch_size, no. of queries, no. of key-value pairs)
        scores = np.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
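A NumPy sketch of the additive forward pass, assuming small random matrices in place of learned weights and omitting dropout; `additive_attention_np` is a hypothetical name. The weights are stored transposed (`x @ W_q` rather than \mathbf{W}_q\mathbf{q}) so that each projection acts on the last axis, much like `flatten=False`:

```python
import numpy as np

rng = np.random.default_rng(0)

def additive_attention_np(queries, keys, values, valid_lens, W_q, W_k, w_v):
    """queries: (b, n_q, d_q), keys: (b, n_k, d_k), values: (b, n_k, v).

    W_q: (d_q, h), W_k: (d_k, h), w_v: (h,); valid_lens: shape (b,).
    """
    q = queries @ W_q                                     # (b, n_q, h)
    k = keys @ W_k                                        # (b, n_k, h)
    # Broadcast to every (query, key) pair: (b, n_q, n_k, h).
    features = np.tanh(q[:, :, None, :] + k[:, None, :, :])
    scores = features @ w_v                               # (b, n_q, n_k)
    keep = (np.arange(keys.shape[1])[None, None, :]
            < np.asarray(valid_lens)[:, None, None])
    scores = np.where(keep, scores, -1e6)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, weights

d_q, d_k, h = 20, 2, 8   # mismatched query/key sizes, as in the demo
queries = rng.standard_normal((2, 1, d_q))
keys = rng.standard_normal((2, 10, d_k))
values = rng.standard_normal((2, 10, 4))
W_q = 0.1 * rng.standard_normal((d_q, h))   # random stand-ins for learned weights
W_k = 0.1 * rng.standard_normal((d_k, h))
w_v = 0.1 * rng.standard_normal(h)
out, weights = additive_attention_np(queries, keys, values, [2, 6], W_q, W_k, w_v)
print(out.shape)
```

The 4D `features` tensor is what makes additive attention costlier than a dot product: its size grows with queries × keys × h.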

Additive demo

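Reuse keys, values, and valid_lens from the dot-product demo, but give the queries dimension 20 while the keys keep dimension 2; both are projected to num_hiddens = 8: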

queries = d2l.normal(0, 1, (2, 1, 20))

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Visualize:

d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

Recap

  • Scoring function a + softmax \Rightarrow attention weights; pool values with those weights.
  • Scaled dot product is the default — cheap, parameter-free, scales by 1/\sqrt d to control softmax saturation.
  • Additive attention is more flexible (queries and keys may differ in size; the metric is learned) but costlier, since it materializes a (batch_size, no. of queries, no. of key-value pairs, num_hiddens) tensor, and it is less common at modern scale.
  • Masked softmax handles padded sequences in batched training and inference.