from d2l import mxnet as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()
Bahdanau attention links two sequences (decoder steps to encoder steps). What if we use the same trick within a single sequence and let every token query every token in that sequence? That is self-attention: queries, keys, and values all come from the same input.
The output for each token is a weighted average of all tokens, with weights determined by query-key compatibility. This is the elementary block of every Transformer encoder and decoder layer.
Self-attention can reuse multi-head attention by passing the same input as queries, keys, and values. The output shape matches the input shape: same sequence length, same hidden size.
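To make the shape claim concrete, here is a minimal sketch of single-head scaled dot-product self-attention. It uses plain NumPy rather than the MXNet multi-head layer from this section, and the sizes, weight matrices, and function names are assumptions chosen for illustration.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention; X is (batch, n, d), each W_* is (d, d)."""
    # Queries, keys, and values are all projections of the same input X.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # (batch, n, n)
    weights = softmax(scores)  # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
batch_size, num_steps, num_hiddens = 2, 4, 8  # hypothetical sizes
X = rng.normal(size=(batch_size, num_steps, num_hiddens))
W_q, W_k, W_v = (rng.normal(size=(num_hiddens, num_hiddens)) for _ in range(3))

Y, attn = self_attention(X, W_q, W_k, W_v)
print(Y.shape)  # (2, 4, 8): same shape as X
```

Each output row is a weighted average of the value vectors, with weights given by one row of `attn`.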
Three ways to map a length-n sequence to another length-n sequence with d-dim tokens:
| Layer | Compute | Sequential ops | Max path |
|---|---|---|---|
| CNN (k-wide) | \mathcal{O}(knd^2) | \mathcal{O}(1) | \mathcal{O}(n/k) |
| RNN | \mathcal{O}(nd^2) | \mathcal{O}(n) | \mathcal{O}(n) |
| Self-attention | \mathcal{O}(n^2 d) | \mathcal{O}(1) | \mathcal{O}(1) |
Self-attention matches CNNs on parallelism and has the shortest maximum path length: every token reaches every other in one hop. The price is compute that grows as n^2 with the sequence length.
Figure: CNN, RNN, and self-attention architectures, with maximum path lengths \mathcal{O}(n/k), \mathcal{O}(n), and \mathcal{O}(1), respectively.
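Plugging numbers into the table gives a feel for the trade-off. The helper below is a rough sketch that evaluates the leading-order terms only, with constant factors dropped. The function name and the example sizes are assumptions.

```python
import math

def layer_costs(n, d, k):
    """Leading-order counts from the table above, constant factors dropped."""
    return {
        "cnn": {"compute": k * n * d**2, "sequential_ops": 1,
                "max_path": math.ceil(n / k)},
        "rnn": {"compute": n * d**2, "sequential_ops": n, "max_path": n},
        "self_attention": {"compute": n**2 * d, "sequential_ops": 1,
                           "max_path": 1},
    }

# Hypothetical sizes: 128 tokens, 512-dimensional tokens, kernel width 3.
costs = layer_costs(n=128, d=512, k=3)
for name, cost in costs.items():
    print(name, cost)

# Self-attention compute relative to an RNN is n^2 d / (n d^2) = n / d.
ratio = costs["self_attention"]["compute"] / costs["rnn"]["compute"]
print(ratio)  # 0.25
```

By these counts, self-attention needs less compute than an RNN while n < d, and more once the sequence length exceeds the hidden size.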
Self-attention is permutation-equivariant: shuffle the input tokens, and the outputs shuffle the same way. The model has no idea about word order.
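Permutation equivariance is easy to check numerically. The sketch below uses a single-head NumPy self-attention (an assumed stand-in for the multi-head layer, with hypothetical random weights) and compares the output of a shuffled input with the shuffled output of the original input.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention; X has shape (n, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
n, d = 5, 8  # hypothetical sequence length and hidden size
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
perm = rng.permutation(n)

out = self_attention(X, W_q, W_k, W_v)
out_shuffled = self_attention(X[perm], W_q, W_k, W_v)

# Shuffling the input rows shuffles the output rows in the same way.
print(np.allclose(out_shuffled, out[perm]))  # True
```

Nothing in the computation depends on the row index, so the layer alone cannot tell one ordering of the tokens from another.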
Solution: inject position information into each token’s representation. Vaswani et al. use fixed sine/cosine encodings:
p_{i,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right),\quad p_{i,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right).
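Here i is the position, j indexes the sine/cosine column pair, and d is the hidden size. Each column pair uses a different frequency, so reading position i across all dimensions gives a unique fingerprint for that position.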
Precompute \mathbf{P} once for max_len positions, slice to actual length at forward time, add to inputs:
class PositionalEncoding(nn.Block):
    """Positional encoding."""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough P
        self.P = d2l.zeros((1, max_len, num_hiddens))
        X = d2l.arange(max_len).reshape(-1, 1) / np.power(
            10000, np.arange(0, num_hiddens, 2) / num_hiddens)
        self.P[:, :, 0::2] = np.sin(X)
        self.P[:, :, 1::2] = np.cos(X[:, :num_hiddens // 2])

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
        return self.dropout(X)
Plot four columns of \mathbf{P}. Lower-index columns oscillate fast; higher columns oscillate slow:
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(d2l.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in range(6, 10)])
Binary representations of small integers show the same pattern: low bits flip fast and high bits flip slowly. The sinusoidal encoding is a continuous-valued version of it.
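A short loop makes the binary pattern visible. This is an illustrative sketch in plain Python. The 3-bit width and the flip count are choices made here for the example.

```python
# 3-bit binary codes for the integers 0..7.
codes = [f"{i:03b}" for i in range(8)]
for i, code in enumerate(codes):
    print(f"{i} in binary is {code}")

# Count how often each bit changes between consecutive integers,
# listed from the highest bit (left) to the lowest bit (right).
flips = [sum(a[col] != b[col] for a, b in zip(codes, codes[1:]))
         for col in range(3)]
print(flips)  # [1, 3, 7]
```

The lowest bit changes at every step and the highest bit changes once, which mirrors the fast and slow sinusoids in the columns of the positional encoding.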
A heatmap of \mathbf{P} reveals the multi-frequency structure. Each row is a unique fingerprint for a position.
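The fingerprint claim can be checked without a plot. The sketch below rebuilds the encoding matrix in plain NumPy (a stand-in for the MXNet class above, with a hypothetical function name) and confirms that no two of the first 60 rows coincide.

```python
import numpy as np

def sinusoidal_encoding(max_len, num_hiddens):
    """Return a (max_len, num_hiddens) matrix of sine/cosine encodings."""
    P = np.zeros((max_len, num_hiddens))
    angles = np.arange(max_len).reshape(-1, 1) / np.power(
        10000, np.arange(0, num_hiddens, 2) / num_hiddens)
    P[:, 0::2] = np.sin(angles)                       # even columns
    P[:, 1::2] = np.cos(angles[:, :num_hiddens // 2])  # odd columns
    return P

P = sinusoidal_encoding(max_len=60, num_hiddens=32)

# Pairwise Euclidean distances between all rows (positions).
diffs = P[:, None, :] - P[None, :, :]
dist = np.sqrt((diffs ** 2).sum(axis=-1))
off_diagonal = dist[~np.eye(len(P), dtype=bool)]

# A strictly positive minimum means every position has a distinct row.
print(off_diagonal.min() > 0)  # True
```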
For any fixed offset \delta, the encoding at position i + \delta is a linear function of the encoding at position i. Writing \omega_j = 1/10000^{2j/d}:
\begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \end{bmatrix} = \begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \end{bmatrix}.
The rotation depends on \delta but not on i, so one linear transformation, the same at every position, maps each encoding to the encoding \delta steps later. This makes a relative offset such as “5 tokens ahead” easy for the network to represent.
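The rotation identity can be verified numerically. The following sketch uses only the Python standard library. The helper names and the sampled positions, column pairs, and offset are assumptions chosen for the check.

```python
import math

def omega(j, d):
    """Frequency of column pair j for hidden size d."""
    return 1.0 / 10000 ** (2 * j / d)

def encoding_pair(i, j, d):
    """Return (p_{i,2j}, p_{i,2j+1}) = (sin(i w_j), cos(i w_j))."""
    w = omega(j, d)
    return math.sin(i * w), math.cos(i * w)

def shift(pair, delta, j, d):
    """Apply the 2x2 rotation for offset delta to one encoding pair."""
    s, c = pair
    w = omega(j, d)
    cos_dw, sin_dw = math.cos(delta * w), math.sin(delta * w)
    return cos_dw * s + sin_dw * c, -sin_dw * s + cos_dw * c

def max_error(d, delta, positions, pairs):
    """Largest gap between the rotated pair and the true shifted pair."""
    worst = 0.0
    for i in positions:
        for j in pairs:
            rotated = shift(encoding_pair(i, j, d), delta, j, d)
            target = encoding_pair(i + delta, j, d)
            worst = max(worst, *(abs(a - b) for a, b in zip(rotated, target)))
    return worst

error = max_error(d=32, delta=5, positions=(0, 7, 40), pairs=(0, 3, 15))
print(error < 1e-12)  # True
```

The rotation in `shift` takes only `delta` and the column pair as parameters, never the position `i`, which is the point of the identity.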