Transformers for Vision

Vision Transformer

The Transformer started as a translation model. Could it also do vision?

Vision Transformer (ViT; Dosovitskiy et al., 2021): chop the image into fixed-size patches (e.g. 16×16 pixels), treat each patch as a token, and run a pure Transformer encoder. Trained on enough data (about 300M images), ViTs outperform ResNets; at smaller scale they still need CNN-style inductive biases or heavy regularization.

Architecture

Patchify → embed + <cls> → n encoder blocks → classify from <cls> representation.

Setup

from d2l import torch as d2l
import torch
from torch import nn

“Split into patches, then linearly project” = a single strided convolution with kernel_size = stride = patch_size. For a 96×96 image with 16×16 patches, this gives a sequence of 36 patch tokens, each a num_hiddens-dim vector:

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        return self.conv(X).flatten(2).transpose(1, 2)
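
The claim that "split into patches, then linearly project" equals one strided convolution can be checked numerically. The following is a minimal sketch, not part of the original code: it assumes a 3-channel 96×96 input and a deliberately small hidden size, cuts patches out explicitly with `torch.nn.functional.unfold`, and copies the convolution's weights into a hypothetical `linear` layer so both routes should compute the same map.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, num_hiddens = 16, 8      # small hidden size, for illustration only
X = torch.randn(2, 3, 96, 96)        # (batch, channels, height, width)

# Route 1: strided convolution, as in PatchEmbedding
conv = nn.Conv2d(3, num_hiddens, kernel_size=patch_size, stride=patch_size)
out_conv = conv(X).flatten(2).transpose(1, 2)          # (2, 36, 8)

# Route 2: cut out patches explicitly, flatten each, apply one linear layer
patches = F.unfold(X, kernel_size=patch_size, stride=patch_size)  # (2, 768, 36)
patches = patches.transpose(1, 2)                      # (2, 36, 768)
linear = nn.Linear(3 * patch_size * patch_size, num_hiddens)
with torch.no_grad():
    # Reuse the convolution's parameters so both routes share weights
    linear.weight.copy_(conv.weight.reshape(num_hiddens, -1))
    linear.bias.copy_(conv.bias)
out_linear = linear(patches)                           # (2, 36, 8)

same = torch.allclose(out_conv, out_linear, atol=1e-4)
print(out_conv.shape, same)
```

The convolution route is the one used in practice because it avoids materializing the unfolded patch tensor.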

Patch embedding shape check

The convolution returns one vector per patch. For 96×96 images and 16×16 patches, the sequence length is (96/16)^2 = 36.

img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = d2l.zeros(batch_size, 3, img_size, img_size)
d2l.check_shape(patch_emb(X),
                (batch_size, (img_size//patch_size)**2, num_hiddens))
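
To see how the patch size drives the sequence length, the arithmetic above can be wrapped in a small function. This is a sketch in plain Python; `vit_seq_len` is a hypothetical helper, not a d2l function, and unlike `PatchEmbedding` (which floor-divides silently) it assumes a square image and rejects patch sizes that do not divide the image size.

```python
def vit_seq_len(img_size, patch_size, add_cls=True):
    """Number of tokens the encoder sees for a square image (hypothetical helper)."""
    if img_size % patch_size != 0:
        raise ValueError("patch_size should divide img_size evenly")
    num_patches = (img_size // patch_size) ** 2
    return num_patches + (1 if add_cls else 0)

for p in (8, 16, 32):
    n = vit_seq_len(96, p)
    # Self-attention compares every token with every token: n * n scores per head
    print(f"patch {p:2d}: {n:3d} tokens, {n * n:5d} attention scores per head")
```

Halving the patch size roughly quadruples the number of tokens, and the attention score matrix grows with the square of that.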

ViT MLP block

Two changes vs the original Transformer FFN:

  • GELU instead of ReLU — smoother, slightly better in practice for Transformers.
  • Dropout after both linear layers, not just the output.

class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))
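
The "smoother" behavior of GELU can be seen on a handful of inputs. The sketch below is an illustration only: it compares `nn.ReLU` with `nn.GELU` (whose default is the exact erf-based form) and checks the result against the closed form x·Φ(x), where Φ is the standard normal CDF.

```python
import math
import torch
from torch import nn

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
relu_out = nn.ReLU()(x)
gelu_out = nn.GELU()(x)    # default: exact (erf-based) GELU

# Closed form: GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
gelu_manual = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))

print(relu_out)   # negative inputs are clipped to exactly zero
print(gelu_out)   # negative inputs give small negative outputs instead
```

Where ReLU has a kink at zero and a flat zero region for negative inputs, GELU passes a small, smoothly varying signal through.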

ViT block: pre-norm

Original Transformer post-norm: LN(X + sublayer(X)). ViT pre-norm: X + sublayer(LN(X)). Pre-norm leaves the residual path unnormalized, so it trains more stably and tolerates much deeper stacks; it is the standard choice in modern Transformers (GPT-2 and later, LLaMA, etc.).

class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = X + self.attention(*([self.ln1(X)] * 3), valid_lens)
        return X + self.mlp(self.ln2(X))
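
The difference between the two orderings is easiest to see in an artificial extreme. The sketch below assumes a single `nn.Linear` as a stand-in sublayer and zeroes its parameters, so the sublayer contributes nothing: pre-norm then passes the input through unchanged along the residual path, while post-norm still rescales it. The names `pre_norm` and `post_norm` are illustrative, not d2l functions.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_hiddens = 24
ln = nn.LayerNorm(num_hiddens)
sublayer = nn.Linear(num_hiddens, num_hiddens)  # stand-in for attention or the MLP
nn.init.zeros_(sublayer.weight)                 # artificial: sublayer outputs zero
nn.init.zeros_(sublayer.bias)

def post_norm(X):
    # Original Transformer: normalize after the residual addition
    return ln(X + sublayer(X))

def pre_norm(X):
    # ViT: normalize only the sublayer input, keep the residual path untouched
    return X + sublayer(ln(X))

X = 5 * torch.randn(2, 100, num_hiddens) + 3    # inputs with non-unit scale and mean
pre_out, post_out = pre_norm(X), post_norm(X)

print(torch.allclose(pre_out, X))    # identity path preserved
print(torch.allclose(post_out, X))   # input was renormalized
```

With pre-norm, stacking many such blocks keeps a direct path from input to output, which is one common explanation for its better training stability.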

JAX/TF variants and shape check

The framework-specific code differs, but the contract is the same: a ViT block maps (batch, num_patches + 1, num_hiddens) back to the same shape so blocks can stack.

X = d2l.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 24, 48, 8, 0.5)
encoder_blk.eval()
d2l.check_shape(encoder_blk(X), X.shape)

Putting it together

Patch embed → prepend learnable <cls> token → add learnable positional embeddings (not fixed sin/cos) → dropout → N ViT blocks → take <cls> representation → LayerNorm → linear head:

class ViT(d2l.Classifier):
    """Vision Transformer."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(
            img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(d2l.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", ViTBlock(
                num_hiddens, num_hiddens, mlp_num_hiddens,
                num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = d2l.concat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        X = self.dropout(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X)
        return self.head(X[:, 0])
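
The token bookkeeping in `forward` can be traced in isolation. This is a minimal sketch that assumes random stand-in patch embeddings in place of a real `PatchEmbedding` output and skips the encoder blocks entirely; it only shows how the `<cls>` token and the positional embeddings broadcast over the batch.

```python
import torch
from torch import nn

batch_size, num_patches, num_hiddens = 4, 36, 512
patch_tokens = torch.randn(batch_size, num_patches, num_hiddens)  # stand-in input

cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))

# expand() broadcasts the single <cls> vector over the batch without copying it
cls = cls_token.expand(batch_size, -1, -1)         # (4, 1, 512)
tokens = torch.cat((cls, patch_tokens), dim=1)     # (4, 37, 512)
tokens = tokens + pos_embedding                    # broadcast over the batch

# In the full model this slice is taken after the encoder blocks
cls_repr = tokens[:, 0]                            # (4, 512)
print(tokens.shape, cls_repr.shape)
```

Because `pos_embedding` has a fixed length of `num_patches + 1`, a model built this way expects the image and patch size it was constructed with.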

Training on Fashion-MNIST

Tiny config (2 blocks, 512 hidden, 8 heads). On a small dataset, this won’t beat a ResNet — Transformers need scale:

img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)

Recap

  • ViT = patchify image → standard Transformer encoder → classify from <cls> token representation.
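  • Patches play the role of tokens; positional embeddings are learned rather than fixed sin/cos, so the model picks up the 2D patch layout from data.
  • Pre-norm trains more stably than post-norm in deep stacks; GELU tends to work slightly better than ReLU in the MLP.
  • ViTs lose to ResNets on small data because they lack the locality and translation-invariance biases of convolutions, but win at large scale (300M+ images). DeiT (a data-efficient training recipe) and Swin Transformers (attention within local shifted windows) narrow the gap.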