from d2l import torch as d2l
import math
import pandas as pd
import torch
from torch import nn
The 2017 paper “Attention Is All You Need” threw out RNNs entirely and built sequence models from self-attention, positionwise MLPs, residual connections, and layer normalization.
The Transformer is now the dominant architecture for language, vision, speech, and beyond. The same code scales from 6 layers (the original model) to 96+ layers (frontier LLMs).
Bottom-up assembly:
The Transformer architecture.
Encoder: N identical blocks (self-attention → Add & Norm → FFN → Add & Norm). Decoder: same, plus masked self-attention and encoder-decoder cross-attention. Embedding + positional encoding before the first block.
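The same multi-head attention operator is used in three different roles: encoder self-attention (queries, keys, and values all come from the encoder input), masked decoder self-attention (each target position attends only to itself and earlier positions), and encoder-decoder cross-attention (decoder queries attend to the encoder outputs, which serve as keys and values).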
The positionwise feed-forward network (FFN) is a two-layer MLP applied independently at every sequence position, with the same weights everywhere. It lets each position transform its attention output through a nonlinear feature mixer:
class PositionWiseFFN(nn.Module):
    """The positionwise feed-forward network.

    A two-layer MLP (Linear -> ReLU -> Linear) that acts only on the last
    dimension of its input, so every position is transformed with the same
    weights.

    Args:
        ffn_num_hiddens: width of the hidden layer.
        ffn_num_outputs: size of the output's last dimension.
    """
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        # Lazy layers infer their input width on the first forward pass.
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        """Map the last dimension of X to ``ffn_num_outputs`` features."""
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)
# Shape check: rank-3 input, only the last dim changes:
tensor([[ 0.2018, 0.4420, -0.4201, 0.1596, 0.4288, -0.2823, -0.1116, 0.3560],
[ 0.2018, 0.4420, -0.4201, 0.1596, 0.4288, -0.2823, -0.1116, 0.3560],
[ 0.2018, 0.4420, -0.4201, 0.1596, 0.4288, -0.2823, -0.1116, 0.3560]],
grad_fn=<SelectBackward0>)
BatchNorm normalizes across the batch, which makes it fragile with variable-length sequences and small batches. LayerNorm normalizes across the feature dimension of a single example, so it is independent of batch size and sequence length. That is why NLP models adopted LayerNorm.
layer norm: tensor([[-1.0000, 1.0000],
[-1.0000, 1.0000]], grad_fn=<NativeLayerNormBackward0>)
batch norm: tensor([[-1.0000, -1.0000],
[ 1.0000, 1.0000]], grad_fn=<NativeBatchNormBackward0>)
The repeating motif (AddNorm): apply dropout to the sublayer output, add the residual input (X + Dropout(Sublayer(X))), then apply LayerNorm. The two inputs must have the same shape:
One block = MultiHead self-attention → AddNorm → FFN → AddNorm. Shape in = shape out, so blocks stack without any projection in between:
class TransformerEncoderBlock(nn.Module):
    """The Transformer encoder block.

    Multi-head self-attention followed by AddNorm, then a positionwise FFN
    followed by AddNorm. The FFN projects back to ``num_hiddens``, so the
    output has the same shape as the input and blocks can be stacked.

    Args:
        num_hiddens: model width (input and output feature size).
        ffn_num_hiddens: hidden width of the positionwise FFN.
        num_heads: number of attention heads.
        dropout: dropout rate for attention and both AddNorm sublayers.
        use_bias: forwarded to the multi-head attention projections.
    """
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(
            num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        """Encode X; ``valid_lens`` masks padded key positions in attention."""
        # Queries, keys and values all come from X (self-attention).
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)
# Embed tokens, scale by sqrt(d) to balance against the positional encoding,
# add positions, then run N blocks. Save attention weights per block for
# later visualization:
class TransformerEncoder(d2l.Encoder):
    """The Transformer encoder.

    Token embedding (scaled by sqrt(num_hiddens)) plus positional encoding,
    followed by ``num_blks`` stacked TransformerEncoderBlocks.

    Args:
        vocab_size: size of the source vocabulary.
        num_hiddens: model width (embedding size and block width).
        ffn_num_hiddens: hidden width of each block's positionwise FFN.
        num_heads: attention heads per block.
        num_blks: number of stacked encoder blocks.
        dropout: dropout rate for the positional encoding and the blocks.
        use_bias: forwarded to each block's multi-head attention.
    """
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for idx in range(num_blks):
            block = TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias)
            self.blks.add_module(f"block{idx}", block)

    def forward(self, X, valid_lens):
        """Encode token indices X; ``valid_lens`` masks source padding.

        Side effect: ``self.attention_weights[k]`` is set to the attention
        weights recorded by block k during this call.
        """
        # Positional-encoding values lie in [-1, 1]; scaling the embeddings
        # by sqrt(num_hiddens) keeps them from being swamped when added.
        scale = math.sqrt(self.num_hiddens)
        X = self.pos_encoding(self.embedding(X) * scale)
        self.attention_weights = [None] * len(self.blks)
        for idx, block in enumerate(self.blks):
            X = block(X, valid_lens)
            self.attention_weights[idx] = (
                block.attention.attention.attention_weights)
        return X
# Three sublayers, each wrapped in AddNorm:
class TransformerDecoderBlock(nn.Module):
    """The i-th block in the Transformer decoder.

    Three sublayers, each wrapped in AddNorm: masked self-attention over the
    target tokens seen so far, encoder-decoder (cross) attention, and a
    positionwise feed-forward network.

    Args:
        num_hiddens: model width; also the output size of the FFN.
        ffn_num_hiddens: hidden width of the positionwise FFN.
        num_heads: number of heads in each attention sublayer.
        dropout: dropout rate for the attention and AddNorm sublayers.
        i: index of this block within the decoder; selects this block's
            slot in the per-block cache ``state[2]``.
    """
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    @staticmethod
    def _causal_valid_lens(batch_size, num_steps, num_cached, device=None):
        """Return per-query valid key lengths for causal self-attention.

        Query j (0-based) of the current chunk sits at absolute position
        ``num_cached + j`` and may attend to keys ``0 .. num_cached + j``,
        so its valid length is ``num_cached + j + 1``.

        Returns:
            Integer tensor of shape (batch_size, num_steps) whose every row
            is ``[num_cached + 1, ..., num_cached + num_steps]``.
        """
        lens = torch.arange(num_cached + 1, num_cached + num_steps + 1,
                            device=device)
        return lens.repeat(batch_size, 1)

    def forward(self, X, state):
        """Run this block on X and extend its key/value cache in ``state``.

        Args:
            X: decoder input for the current chunk, rank 3
                (batch_size, num_steps, num_hiddens). This is the whole
                target sequence when the cache is empty, or the newest
                token(s) during step-by-step prediction.
            state: ``[enc_outputs, enc_valid_lens, caches]``.
                ``caches[self.i]`` holds this block's inputs from earlier
                steps (or None). It is replaced in place with those inputs
                concatenated with X.

        Returns:
            ``(output, state)``, where output has the same shape as X.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # state[2][self.i] is None on a fresh state, e.g. during training,
        # when the whole target sequence is processed at once. During
        # token-by-token prediction it holds this block's inputs from all
        # earlier steps, so the new token can attend to them.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values
        batch_size, num_steps, _ = X.shape
        # Apply the causal mask in every mode. Gating it on self.training
        # would let eval-mode calls on a full target sequence (e.g. the
        # validation loss) attend to future tokens. It would also truncate
        # the cache in train-mode step-by-step decoding. With an empty
        # cache this is [1, 2, ..., num_steps] per row. For a single new
        # token it admits every cached key, i.e. nothing is masked.
        num_cached = key_values.shape[1] - num_steps
        dec_valid_lens = self._causal_valid_lens(
            batch_size, num_steps, num_cached, X.device)
        # Masked self-attention over the current and earlier target tokens.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention; enc_valid_lens masks source padding.
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
# Run the decoder with fake encoder outputs and source valid_lens. The output
# is target-position logits; the state carries encoder outputs plus per-block
# caches used during autoregressive prediction:
Token embedding + positional encoding -> N decoder blocks -> vocab projection. During training, causal masks are built from target positions; during prediction, the cache grows one token at a time.
class TransformerDecoder(d2l.AttentionDecoder):
    """The Transformer decoder.

    Token embedding (scaled by sqrt(num_hiddens)) plus positional encoding,
    ``num_blks`` TransformerDecoderBlocks, and a final projection to
    target-vocabulary logits.

    Args:
        vocab_size: size of the target vocabulary.
        num_hiddens: model width.
        ffn_num_hiddens: hidden width of each block's positionwise FFN.
        num_heads: attention heads per attention sublayer.
        num_blks: number of stacked decoder blocks.
        dropout: dropout rate for the positional encoding and the blocks.
    """
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.LazyLinear(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Return ``[enc_outputs, enc_valid_lens, per-block caches]``."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def _add_positions(self, X, offset):
        """Add positional encodings for positions offset, offset+1, ... to X.

        ``d2l.PositionalEncoding.forward`` takes only X and always starts at
        position 0, so the slice of its sinusoid table ``P`` is taken here
        directly. Its ``dropout`` layer is then applied, as forward would.

        Raises:
            ValueError: if the requested positions exceed the table length.
        """
        table = self.pos_encoding.P
        end = offset + X.shape[1]
        if end > table.shape[1]:
            raise ValueError(
                f'positional encoding supports {table.shape[1]} positions, '
                f'but position {end - 1} was requested')
        X = X + table[:, offset:end, :].to(X.device)
        return self.pos_encoding.dropout(X)

    def forward(self, X, state):
        """Decode a chunk of target token indices.

        Args:
            X: target token indices for the current chunk (the whole target
                sequence on a fresh state, one token per call during
                step-by-step prediction).
            state: as returned by ``init_state``; each block extends its
                cache in place.

        Returns:
            ``(logits, state)``: vocabulary logits for every position in X,
            and the updated state.
        """
        caches = state[2]
        # Position-encode the new tokens at their true offset: the number of
        # tokens already cached by the first block (0 on a fresh state).
        # Re-applying P[0:1] at every prediction step would not match the
        # encodings seen at training time.
        if caches and caches[0] is not None:
            pos_offset = caches[0].shape[1]
        else:
            pos_offset = 0
        X = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self._add_positions(X, pos_offset)
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][i] = (
                blk.attention1.attention.attention_weights)
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = (
                blk.attention2.attention.attention_weights)
        return self.dense(X), state

    @property
    def attention_weights(self):
        """``[self-attention weights, cross-attention weights]`` per block,
        as recorded by the most recent forward call."""
        return self._attention_weights
# Same MTFraEng dataset as the seq2seq chapter. 2 layers, 256 hidden,
# 4 heads, dropout 0.2. Adam lr=0.001, gradient clip 1, 30 epochs:
# English-French translation pairs in minibatches of 128.
data = d2l.MTFraEng(batch_size=128)

# Hyperparameters shared by encoder and decoder.
num_hiddens = 256
num_blks = 2
dropout = 0.2
ffn_num_hiddens = 64
num_heads = 4

encoder = TransformerEncoder(
    vocab_size=len(data.src_vocab), num_hiddens=num_hiddens,
    ffn_num_hiddens=ffn_num_hiddens, num_heads=num_heads,
    num_blks=num_blks, dropout=dropout)
decoder = TransformerDecoder(
    vocab_size=len(data.tgt_vocab), num_hiddens=num_hiddens,
    ffn_num_hiddens=ffn_num_hiddens, num_heads=num_heads,
    num_blks=num_blks, dropout=dropout)

# Padding tokens are identified by tgt_pad (the '<pad>' index).
model = d2l.Seq2Seq(
    encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.001)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
# This is a tiny model on a tiny dataset. Look for good short translations and
# BLEU differences across examples; errors usually reflect data/model-size
# limits, not a flaw in the architecture.
engs = ['i lost .', 'i\'m calm .', 'i\'m home .']
fras = ['j\'ai perdu .', 'je suis calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    data.build(engs, fras), d2l.try_gpu(), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    tokens = list(data.tgt_vocab.to_tokens(p))
    # Keep only the tokens before the first '<eos>', or all if there is none.
    cut = tokens.index('<eos>') if '<eos>' in tokens else len(tokens)
    translation = tokens[:cut]
    score = d2l.bleu(" ".join(translation), fr, k=2)
    print(f'{en} => {translation}, bleu,{score:.3f}')
# i lost . => ["j'ai", 'perdu', '.'], bleu,1.000
i'm calm . => ['je', 'suis', 'calme', '.'], bleu,1.000
i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000
Collect the encoder’s stored attention weights, reshape them into (blocks × heads × queries × keys), and plot them as heatmaps. Different heads attend to different patterns:
# Translate the last example again, this time recording attention weights;
# the decoder's per-step weights are unpacked in the next cell.
_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps,
    save_attention_weights=True)
# Stack the per-block encoder weights along dim 0, then view them as
# (blocks, heads, queries, keys); check_shape asserts the final layout.
enc_attention_weights = torch.cat(model.encoder.attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = enc_attention_weights.reshape(shape)
d2l.check_shape(enc_attention_weights,
                (num_blks, num_heads, data.num_steps, data.num_steps))
# The decoder has two attention sublayers per block: masked self-attention and
# encoder-decoder cross-attention. Pull both from the prediction trace and
# reshape them into (blocks × heads × queries × keys):
# Flatten the trace to one row per (step, sublayer, block, head). The nesting
# order below fixes the axis order assumed by the reshape further down.
dec_attention_weights_2d = [
    head_weights[0].tolist()
    for step_weights in dec_attention_weights      # decoding steps
    for sublayer_weights in step_weights           # [self-attn, cross-attn]
    for block_weights in sublayer_weights          # decoder blocks
    for head_weights in block_weights              # attention heads
]
# Rows have different lengths (self-attention sees more keys at each step),
# so pad them to a rectangle with zeros via a DataFrame.
dec_attention_weights_filled = d2l.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
shape = (-1, 2, num_blks, num_heads, data.num_steps)
dec_attention_weights = dec_attention_weights_filled.reshape(shape)
# Move the step axis next to the key axis, then split the two sublayers.
(dec_self_attention_weights,
 dec_inter_attention_weights) = dec_attention_weights.permute(1, 2, 3, 0, 4)
# The self-attention heatmap must be lower triangular: query position t can
# attend only to keys at positions <= t. That is what makes the decoder a
# language model during generation.
Cross-attention from decoder queries to encoder keys: note the zero weight on source padding tokens. The valid_lens masking inside attention is what enforces this: