from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

The Transformer started as a translation model. Could it also do vision?
Vision Transformer (ViT; Dosovitskiy et al., 2021): chop the image into 16×16 patches, treat each patch as a token, and run a pure Transformer encoder. With enough data (300M images) ViTs outperform ResNets; at smaller scale they still need CNN-style inductive biases or heavy regularization.
Patchify → embed + <cls> → n encoder blocks → classify from <cls> representation.
“Split into patches, then linearly project” = a single strided convolution with kernel_size = stride = patch_size. For a 96×96 image with 16×16 patches, this gives a sequence of 36 patch tokens, each a num_hiddens-dim vector:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    "Split into patches, then linearly project" is implemented as a single
    convolution whose kernel size and stride both equal the patch size.

    Attributes:
        img_size: Image size, either an int (square) or a (height, width) pair.
        patch_size: Patch size, either an int (square) or a (height, width)
            pair.
        num_hiddens: Number of output channels, i.e. the embedding dimension
            of every patch token.
    """
    img_size: int = 96
    patch_size: int = 16
    num_hiddens: int = 512

    def setup(self):
        def _make_tuple(x):
            # Accept both a scalar size and an explicit (height, width) pair.
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x

        img_size = _make_tuple(self.img_size)
        patch_size = _make_tuple(self.patch_size)
        # Only whole patches are counted (floor division).
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        # 'VALID' (no padding) makes the convolution emit exactly
        # floor(size / patch) positions per axis, matching `num_patches`.
        # 'SAME' would pad and emit ceil(size / patch) positions, so the
        # sequence length would disagree with `num_patches` whenever the
        # image size is not a multiple of the patch size.
        self.conv = nn.Conv(self.num_hiddens, kernel_size=patch_size,
                            strides=patch_size, padding='VALID')

    def __call__(self, X):
        """Embed a batch of images.

        Args:
            X: Image batch indexed as (batch, height, width, channels).

        Returns:
            Array of shape (batch size, no. of patches, no. of channels).
        """
        X = self.conv(X)
        # Flatten the two spatial axes into a single sequence axis.
        return X.reshape((X.shape[0], -1, X.shape[3]))


# The convolution returns one vector per patch. For 96×96 images and 16×16
# patches, the sequence length is (96/16)^2 = 36.
# Sanity check: push a dummy batch of images through the patch embedding and
# verify the resulting sequence shape.
img_size = 96
patch_size = 16
num_hiddens = 512
batch_size = 4

patch_emb = PatchEmbedding(img_size=img_size, patch_size=patch_size,
                           num_hiddens=num_hiddens)
# Channels-last dummy input with 3 channels.
X = d2l.zeros((batch_size, img_size, img_size, 3))
# `init_with_output` returns (output, variables); only the output is needed.
output, _ = patch_emb.init_with_output(d2l.get_key(), X)
d2l.check_shape(
    output,
    (batch_size, (img_size // patch_size) ** 2, num_hiddens),
)

# Two changes vs the original Transformer FFN: GELU is used as the
# activation, and dropout is applied after each dense layer.
class ViTMLP(nn.Module):
    """Two-layer MLP used inside a ViT block.

    Applies Dense -> GELU -> Dropout -> Dense -> Dropout.

    Attributes:
        mlp_num_hiddens: Width of the first (hidden) dense layer.
        mlp_num_outputs: Width of the second (output) dense layer.
        dropout: Dropout rate used after both dense layers.
    """
    mlp_num_hiddens: int
    mlp_num_outputs: int
    dropout: float = 0.5

    @nn.compact
    def __call__(self, x, training=False):
        """Run the MLP; dropout is only active when `training` is True."""
        deterministic = not training
        hidden = nn.gelu(nn.Dense(self.mlp_num_hiddens)(x))
        hidden = nn.Dropout(self.dropout, deterministic=deterministic)(hidden)
        out = nn.Dense(self.mlp_num_outputs)(hidden)
        return nn.Dropout(self.dropout, deterministic=deterministic)(out)


# Original Transformer post-norm: LN(X + sublayer(X)). ViT pre-norm:
# X + sublayer(LN(X)). Pre-norm trains more stably and tolerates much deeper
# stacks — the standard choice in modern Transformers (LLaMA, GPT, etc.).
The framework-specific code differs, but the contract is the same: a ViT block maps (batch, num_patches + 1, num_hiddens) back to the same shape so blocks can stack.
class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block: X + sublayer(LayerNorm(X)).

    Attributes:
        num_hiddens: Token embedding dimension (also the MLP output width).
        mlp_num_hiddens: Hidden width of the block's MLP.
        num_heads: Number of attention heads.
        dropout: Dropout rate passed to both the attention and the MLP.
        use_bias: Passed through to `d2l.MultiHeadAttention`.
    """
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = d2l.MultiHeadAttention(
            self.num_hiddens, self.num_heads, self.dropout, self.use_bias)
        self.mlp = ViTMLP(
            self.mlp_num_hiddens, self.num_hiddens, self.dropout)

    @nn.compact
    def __call__(self, X, valid_lens=None, training=False):
        """Apply self-attention and the MLP, each with a residual connection."""
        # Self-attention: the same normalized input serves as queries, keys
        # and values.
        normed = nn.LayerNorm()(X)
        attn_result = self.attention(normed, normed, normed, valid_lens,
                                     training=training)
        # Only the first element of the attention result is used.
        X = X + attn_result[0]
        mlp_out = self.mlp(nn.LayerNorm()(X), training=training)
        return X + mlp_out


# Patch embed → prepend learnable <cls> token → add learnable positional
# embeddings (not fixed sin/cos) → dropout → N ViT blocks → take <cls>
# representation → LayerNorm → linear head:
class ViT(d2l.Classifier):
    """Vision Transformer.

    Pipeline: patch embedding, prepend a learnable <cls> token, add learnable
    positional embeddings, dropout, a stack of ViT blocks, then a
    LayerNorm + Dense head applied to the <cls> position.
    """
    img_size: int
    patch_size: int
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    num_blks: int
    emb_dropout: float
    blk_dropout: float
    lr: float = 0.1
    use_bias: bool = False
    num_classes: int = 10
    training: bool = False

    def setup(self):
        hidden = self.num_hiddens
        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size,
                                              hidden)
        self.cls_token = self.param(
            'cls_token', nn.initializers.zeros, (1, 1, hidden))
        # One extra position for the <cls> token.
        seq_len = self.patch_embedding.num_patches + 1
        # Positional embeddings are learned parameters.
        self.pos_embedding = self.param(
            'pos_embed', nn.initializers.normal(), (1, seq_len, hidden))
        self.blks = [
            ViTBlock(hidden, self.mlp_num_hiddens, self.num_heads,
                     self.blk_dropout, self.use_bias)
            for _ in range(self.num_blks)
        ]
        self.head = nn.Sequential(
            [nn.LayerNorm(), nn.Dense(self.num_classes)])

    @nn.compact
    def __call__(self, X):
        """Map a batch of images to class scores read from the <cls> token."""
        patches = self.patch_embedding(X)
        # Replicate the single <cls> token once per example in the batch.
        cls_tokens = jnp.tile(self.cls_token, (patches.shape[0], 1, 1))
        tokens = d2l.concat((cls_tokens, patches), 1)
        emb_dropout = nn.Dropout(self.emb_dropout,
                                 deterministic=not self.training)
        tokens = emb_dropout(tokens + self.pos_embedding)
        for blk in self.blks:
            tokens = blk(tokens, training=self.training)
        # Position 0 along the sequence axis holds the <cls> token.
        return self.head(tokens[:, 0])


# Tiny config (2 blocks, 512 hidden, 8 heads). On a small dataset, this won’t
# beat a ResNet — Transformers need scale:
# Model hyperparameters.
img_size = 96
patch_size = 16
num_hiddens = 512
mlp_num_hiddens = 2048
num_heads = 8
num_blks = 2
emb_dropout = 0.1
blk_dropout = 0.1
lr = 0.1

model = ViT(
    img_size=img_size,
    patch_size=patch_size,
    num_hiddens=num_hiddens,
    mlp_num_hiddens=mlp_num_hiddens,
    num_heads=num_heads,
    num_blks=num_blks,
    emb_dropout=emb_dropout,
    blk_dropout=blk_dropout,
    lr=lr,
)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
# Fashion-MNIST images are resized to the model's expected input size.
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)

# <cls> token representation.