from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
import math
Attention pooling needs a scoring function a(\mathbf{q}, \mathbf{k}) that softmax turns into weights:
\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp\, a(\mathbf{q}, \mathbf{k}_i)}{\sum_j \exp\, a(\mathbf{q}, \mathbf{k}_j)}.
Output = weighted sum of values; weights = softmax of scoring function a.
Both scoring functions covered here — scaled dot-product and additive — feed into the same softmax + value-pooling pipeline.
For d-dimensional queries and keys with independent, zero-mean, unit-variance coordinates,
\operatorname{Var}(\mathbf{q}^\top\mathbf{k}) = \operatorname{Var}\left(\sum_{\ell=1}^d q_\ell k_\ell\right) = d.
As d grows, raw dot products become large in magnitude, softmax saturates, and gradients shrink. Scaling by 1/\sqrt d keeps the logit variance approximately constant:
a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}.
A Gaussian-kernel view gives useful geometric intuition, but the variance argument is the operational reason used in Transformers.
Padded sequences in a minibatch — we don’t want <pad> keys to receive attention mass. Set their pre-softmax scores to a large negative number so \exp flushes them to zero:
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis.

    Args:
        X: 3D array of scores; the softmax is taken over the last axis.
        valid_lens: ``None`` (plain softmax, nothing masked), a 1D array
            holding one valid length per entry of the first axis of ``X``,
            or a 2D array holding one valid length per row of ``X``.

    Returns:
        An array with the shape of ``X``. Along the last axis, positions at
        or beyond the row's valid length receive (numerically) zero weight.
    """
    def _sequence_mask(X, valid_len, value=0):
        # X: (rows, maxlen). Keep the first valid_len[i] entries of row i
        # and overwrite the rest with `value`.
        maxlen = X.shape[1]
        mask = jnp.arange(maxlen,
                          dtype=jnp.float32)[None, :] < valid_len[:, None]
        return jnp.where(mask, X, value)

    if valid_lens is None:
        return jax.nn.softmax(X, axis=-1)

    shape = X.shape
    if valid_lens.ndim == 1:
        # One length per batch entry: reuse it for each of the shape[1] rows
        valid_lens = jnp.repeat(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0. The value is clamped to the
    # dtype's finite range: in float16, -1e6 would overflow to -inf and a
    # fully masked row would then produce NaN instead of uniform weights.
    fill = -1e6
    if jnp.issubdtype(X.dtype, jnp.floating):
        fill = max(fill, float(jnp.finfo(X.dtype).min))
    X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=fill)
    return jax.nn.softmax(X.reshape(shape), axis=-1)
# Random scores; specify a valid length per row:
Array([[[0.44036096, 0.55963904, 0. , 0. ],
[0.3914764 , 0.60852355, 0. , 0. ]],
[[0.2912151 , 0.43938962, 0.26939523, 0. ],
[0.41406792, 0.29180348, 0.29412857, 0. ]]], dtype=float32)
Attention runs in batches; weights × values is a batched matmul. bmm does the right thing — confirm shapes:
Stateless layer — no parameters, just \mathbf{Q}\mathbf{K}^\top/\sqrt d, masked softmax, then weighted sum of values:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention: masked softmax(Q K^T / sqrt(d)) times V.

    Attributes:
        dropout: Rate handed to ``nn.Dropout`` for the attention weights.
    """
    dropout: float

    # queries:    (batch_size, no. of queries, d)
    # keys:       (batch_size, no. of key-value pairs, d)
    # values:     (batch_size, no. of key-value pairs, value dimension)
    # valid_lens: (batch_size,) or (batch_size, no. of queries)
    @nn.compact
    def __call__(self, queries, keys, values, valid_lens=None,
                 training=False):
        """Return ``(pooled values, attention weights)`` for the inputs."""
        scale = math.sqrt(queries.shape[-1])
        # Swap the last two axes of keys so the matmul contracts over d
        keys_t = keys.swapaxes(1, 2)
        scores = queries @ keys_t / scale
        attention_weights = masked_softmax(scores, valid_lens)
        # Dropout is deterministic (disabled) unless training=True
        drop = nn.Dropout(self.dropout, deterministic=not training)
        pooled = drop(attention_weights) @ values
        return pooled, attention_weights
# 2 queries, 10 keys/values, valid lengths (2, 6) — only the first 2 / first 6 keys per batch get nonzero weight:
# Toy inputs: batch of 2, one query per batch entry (d = 2), and 10
# key-value pairs per batch entry (value dimension 4)
queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
values = jax.random.normal(d2l.get_key(), (2, 10, 4))
# One valid length per batch entry: 2 for the first, 6 for the second
valid_lens = d2l.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
# `training` is left at its default (False), so dropout is deterministic here.
# init_with_output yields the call's result together with the variables.
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
[[[-0.04701408 0.61414313 -1.0031724 -0.661477 ]]
[[ 0.83954984 0.3183016 0.52048445 0.48648867]]]
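Additive attention scores each query–key pair with a one-hidden-layer MLP: a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q\mathbf{q} + \mathbf{W}_k\mathbf{k}). \mathbf{W}_q, \mathbf{W}_k, and \mathbf{w}_v are learnable. Because queries and keys are first projected to a common hidden size, they can live in different feature spaces.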
class AdditiveAttention(nn.Module):
    """Additive attention.

    Scores each query-key pair as ``w_v^T tanh(W_q q + W_k k)``, turns the
    scores into weights with ``masked_softmax`` and returns the weighted sum
    of the values.

    Attributes:
        num_hiddens: Output size of the query and key projections.
        dropout: Rate handed to ``nn.Dropout`` for the attention weights.
    """
    num_hiddens: int
    dropout: float

    def setup(self):
        # Bias-free projections of keys and queries to num_hiddens features,
        # plus a bias-free projection of the hidden features to one score
        self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
        self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
        self.w_v = nn.Dense(1, use_bias=False)

    # The Dense layers come from setup(); the Dropout layer is created
    # inline below, which is why __call__ is also marked @nn.compact.
    @nn.compact
    def __call__(self, queries, keys, values, valid_lens=None,
                 training=False):
        """Return ``(pooled values, attention weights)`` for the inputs.

        ``valid_lens`` defaults to ``None`` (no masking), matching the
        signature of ``DotProductAttention``.
        """
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = (jnp.expand_dims(queries, axis=2)
                    + jnp.expand_dims(keys, axis=1))
        features = nn.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        attention_weights = masked_softmax(scores, valid_lens)
        # Dropout is deterministic (disabled) unless training=True
        dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return dropout_layer(attention_weights) @ values, attention_weights
# Same shapes as before, with mismatched query/key dims allowed:
[[[ 0.37564147 0.10681814 -0.39052612 0.30577394]]
[[ 0.45246756 0.02368876 0.20076841 0.56524944]]]