from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax

Bahdanau attention links two sequences (decoder steps to encoder steps). What if we use the same trick within a single sequence, letting every token query every other token? That’s self-attention: queries, keys, and values all come from the same input.
The output for each token is a weighted average of the value vectors of all tokens, with weights determined by query-key compatibility. This is the basic building block of every Transformer encoder and decoder layer.
Reuse multi-head attention with the same input fed in as queries, keys, and values. The output shape matches the input shape: same sequence length, same hidden size.
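As a minimal sketch of that shape claim, the block below implements single-head scaled dot-product self-attention in plain NumPy rather than the chapter's multi-head module. The function names, the random projection matrices `W_q`, `W_k`, `W_v`, and the toy sizes are assumptions chosen for illustration, not part of the original code.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.

    X has shape (batch, n, d). Queries, keys, and values are all
    projections of the same X.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # (batch, n, n)
    weights = softmax(scores)  # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
batch_size, num_steps, num_hiddens = 2, 4, 8  # hypothetical toy sizes
X = rng.normal(size=(batch_size, num_steps, num_hiddens))
W_q, W_k, W_v = (rng.normal(size=(num_hiddens, num_hiddens)) for _ in range(3))

Y, weights = self_attention(X, W_q, W_k, W_v)
print(Y.shape)        # (2, 4, 8)
print(weights.shape)  # (2, 4, 4)
```

The output `Y` has the same shape as `X`, and each token's attention weights form one row of an n-by-n matrix per sequence.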
Three ways to map a length-n sequence to another length-n sequence with d-dim tokens:
| Architecture | Compute | Sequential ops | Max path |
|---|---|---|---|
| CNN (k-wide) | \mathcal{O}(knd^2) | \mathcal{O}(1) | \mathcal{O}(n/k) |
| RNN | \mathcal{O}(nd^2) | \mathcal{O}(n) | \mathcal{O}(n) |
| Self-attention | \mathcal{O}(n^2 d) | \mathcal{O}(1) | \mathcal{O}(1) |
Self-attention matches CNNs on parallelism and has the shortest maximum path: every token reaches every other token in one hop. The price is compute that grows quadratically with the sequence length n.
Maximum path lengths for CNN, RNN, and self-attention are \mathcal{O}(n/k), \mathcal{O}(n), and \mathcal{O}(1), respectively.
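To make the compute column concrete, here is a small sketch that plugs numbers into the three leading-order expressions. Constant factors are dropped, and the helper `op_counts` and the chosen values of n, d, and k are hypothetical.

```python
def op_counts(n, d, k):
    """Leading-order operation counts from the table (constants dropped)."""
    return {
        "cnn": k * n * d**2,
        "rnn": n * d**2,
        "self_attention": n**2 * d,
    }

d, k = 512, 3  # hypothetical hidden size and kernel width
for n in (128, 512, 4096):
    print(n, op_counts(n, d, k))
```

Under these expressions, self-attention is cheaper than an RNN layer when n < d, ties at n = d, and becomes the most expensive of the three once n > kd.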
Self-attention is permutation-equivariant: shuffle the input tokens, and the outputs shuffle the same way. The model has no idea about word order.
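The equivariance claim can be checked numerically. The sketch below uses a single-head NumPy self-attention without positional information. The function, the random weights, and the toy sizes are assumptions for illustration.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention on one sequence (n, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
n, d = 5, 8  # hypothetical toy sizes
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(n)  # a random reordering of the token positions
out_then_shuffle = self_attention(X, W_q, W_k, W_v)[perm]
shuffle_then_out = self_attention(X[perm], W_q, W_k, W_v)

print(np.allclose(out_then_shuffle, shuffle_then_out))  # True
```

Shuffling the inputs and then attending gives the same rows as attending and then shuffling the outputs, so nothing in the computation itself depends on token order.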
Solution: inject position information into each token’s representation. Vaswani et al. use fixed sine/cosine encodings:
p_{i,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right),\quad p_{i,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right).
Each column pair uses a different frequency along the embedding dimension, so the row for position i, read across all dimensions, acts as a fingerprint for that position.
Precompute \mathbf{P} once for max_len positions, slice to actual length at forward time, add to inputs:
class PositionalEncoding(nn.Module):
    """Positional encoding."""
    num_hiddens: int
    dropout: float
    max_len: int = 1000

    def setup(self):
        # Create a long enough P
        P = d2l.zeros((1, self.max_len, self.num_hiddens))
        # X[i, j] = i / 10000^(2j / num_hiddens)
        X = d2l.arange(self.max_len, dtype=jnp.float32).reshape(
            -1, 1) / jnp.power(10000, jnp.arange(
            0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens)
        P = P.at[:, :, 0::2].set(jnp.sin(X))
        # With odd `num_hiddens` there is one fewer cosine column than sine
        P = P.at[:, :, 1::2].set(jnp.cos(X[:, :self.num_hiddens // 2]))
        self.P = P

    @nn.compact
    def __call__(self, X, training=False, offset=0):
        # Flax sow API is used to capture intermediate variables
        self.sow('intermediates', 'P', self.P)
        # `offset` lets autoregressive decoders advance the encoding position
        # past tokens already emitted, instead of always slicing from 0.
        X = X + self.P[:, offset:offset + X.shape[1], :]
        return nn.Dropout(self.dropout)(X, deterministic=not training)

Plot four columns of \mathbf{P}. Lower-index columns oscillate fast; higher-index columns oscillate slowly:
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
params = pos_encoding.init(d2l.get_key(), d2l.zeros((1, num_steps, encoding_dim)))
X, inter_vars = pos_encoding.apply(params, d2l.zeros((1, num_steps, encoding_dim)),
                                   mutable='intermediates')
P = inter_vars['intermediates']['P'][0]  # retrieve intermediate value P
P = P[:, :X.shape[1], :]
d2l.plot(d2l.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in range(6, 10)])

Compare with binary representations of small integers: the same “low bits flip fast, high bits flip slow” pattern, but in continuous values:
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
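A listing like the one above can be reproduced with a format specifier, and the flip rates can be counted directly. This is a small stdlib-only sketch. The variable names are hypothetical.

```python
num_bits = 3
codes = [f'{i:0{num_bits}b}' for i in range(8)]
for i, code in enumerate(codes):
    print(f'{i} in binary is {code}')

# Count how often each bit changes between consecutive integers.
# Index 0 is the leftmost (highest) bit, index num_bits - 1 the lowest.
flips = [sum(a[b] != c[b] for a, c in zip(codes, codes[1:]))
         for b in range(num_bits)]
print(flips)  # [1, 3, 7]
```

The lowest bit changes at every step while the highest changes only once, which is the discrete analogue of the fast and slow sinusoids in \mathbf{P}.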
A heatmap of \mathbf{P} reveals the multi-frequency structure: each row is a distinct fingerprint for a position.
For any fixed offset \delta, the encoding at position i + \delta is a linear function of the encoding at position i. Writing \omega_j = 1/10000^{2j/d}, each (sine, cosine) pair is rotated by a 2 \times 2 matrix:
\begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \end{bmatrix} = \begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \end{bmatrix}.
The rotation depends on \delta but not on i, so a single linear transformation is enough to express a shift such as “5 tokens ahead”. This makes relative positions easy for the network to learn to use.
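The identity can be verified numerically for one column pair. The sketch below is plain NumPy. The helper `encoding_pair` and the particular values of d, j, and \delta are assumptions for illustration.

```python
import numpy as np

def encoding_pair(i, j, d):
    """Return (p_{i,2j}, p_{i,2j+1}) and omega_j = 1 / 10000^(2j/d)."""
    omega = 1.0 / 10000 ** (2 * j / d)
    return np.array([np.sin(i * omega), np.cos(i * omega)]), omega

d, j = 32, 3   # hypothetical encoding size and column-pair index
delta = 5      # hypothetical offset

_, omega = encoding_pair(0, j, d)
# Rotation matrix built from delta and omega_j only; i does not appear.
R = np.array([[np.cos(delta * omega), np.sin(delta * omega)],
              [-np.sin(delta * omega), np.cos(delta * omega)]])

# The same R maps position i to position i + delta for every i tried.
checks = []
for i in range(20):
    p_i, _ = encoding_pair(i, j, d)
    p_shifted, _ = encoding_pair(i + delta, j, d)
    checks.append(np.allclose(R @ p_i, p_shifted))

print(all(checks))  # True
```

One matrix `R` per frequency serves every starting position, which is why the shift can be expressed without knowing i.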