from d2l import tensorflow as d2l
import tensorflow as tf
import keras

The Transformer started as a translation model. Could it also do vision?
Vision Transformer (Dosovitskiy et al., 2021): chop the image into 16×16 patches, treat each patch as a token, run a pure Transformer encoder. With enough data (300M images) they outperform ResNets — at smaller scale they still need CNN-style biases or heavy regularization.
Patchify → embed + <cls> → n encoder blocks → classify from <cls> representation.
“Split into patches, then linearly project” = a single strided convolution with kernel_size = stride = patch_size. For a 96×96 image with 16×16 patches, this gives a sequence of 36 patch tokens, each a num_hiddens-dim vector:
class PatchEmbedding(tf.keras.layers.Layer):
    """Cut an image into patches and project each patch to a vector.

    Splitting into patches and linearly projecting them is done with a single
    convolution whose kernel size and stride both equal the patch size.

    Args:
        img_size: Image height/width, as one int or an (H, W) pair.
        patch_size: Patch height/width, as one int or an (h, w) pair.
        num_hiddens: Number of convolution filters, i.e. the embedding size.
    """

    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        # A scalar size is shorthand for a square image / square patch.
        if not isinstance(img_size, (list, tuple)):
            img_size = (img_size, img_size)
        if not isinstance(patch_size, (list, tuple)):
            patch_size = (patch_size, patch_size)
        patches_high = img_size[0] // patch_size[0]
        patches_wide = img_size[1] // patch_size[1]
        self.num_patches = patches_high * patches_wide
        self.conv = tf.keras.layers.Conv2D(
            num_hiddens, kernel_size=patch_size, strides=patch_size)

    def call(self, X):
        # Input shape: (batch, H, W, C); output: (batch, num_patches, num_hiddens)
        patches = self.conv(X)
        batch_size = tf.shape(patches)[0]
        channels = patches.shape[-1]
        # Collapse the 2-D grid of patches into a single sequence axis.
        return tf.reshape(patches, (batch_size, -1, channels))


# The convolution returns one vector per patch. For 96×96 images and 16×16
# patches, the sequence length is (96/16)^2 = 36.
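Two changes vs the original Transformer FFN: (1) GELU replaces ReLU as the activation, and (2) dropout is applied to the output of each dense layer: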
class ViTMLP(tf.keras.layers.Layer):
    """Two-layer MLP used inside a ViT block.

    A GELU-activated dense layer followed by a linear dense layer, with
    dropout applied to the output of each.

    Args:
        mlp_num_hiddens: Units in the first (GELU) dense layer.
        mlp_num_outputs: Units in the second (linear) dense layer.
        dropout: Drop rate shared by both dropout layers.
    """

    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(mlp_num_hiddens, activation='gelu')
        self.dropout1 = tf.keras.layers.Dropout(dropout)
        self.dense2 = tf.keras.layers.Dense(mlp_num_outputs)
        self.dropout2 = tf.keras.layers.Dropout(dropout)

    def call(self, x, training=False):
        # dense -> dropout, twice; only the dropout layers see `training`.
        hidden = self.dense1(x)
        hidden = self.dropout1(hidden, training=training)
        projected = self.dense2(hidden)
        return self.dropout2(projected, training=training)


# Original Transformer post-norm: LN(X + sublayer(X)). ViT pre-norm:
# X + sublayer(LN(X)). Pre-norm trains more stably and tolerates much deeper
# stacks — the standard choice in modern Transformers (LLaMA, GPT, etc.).
The framework-specific code differs, but the contract is the same: a ViT block maps (batch, num_patches + 1, num_hiddens) back to the same shape so blocks can stack.
class ViTBlock(tf.keras.layers.Layer):
    """Pre-norm Transformer encoder block.

    Applies X + attention(LN(X)) followed by X + MLP(LN(X)). The MLP projects
    back to `num_hiddens`, so the residual additions are shape-compatible.

    Args:
        num_hiddens: Model width; also the MLP output size.
        mlp_num_hiddens: Hidden width of the MLP sublayer.
        num_heads: Number of attention heads.
        dropout: Drop rate for the attention layer and the MLP.
        use_bias: Whether the attention projections use bias terms.
    """

    def __init__(self, num_hiddens, mlp_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        per_head_dim = num_hiddens // num_heads
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=per_head_dim,
            dropout=dropout,
            use_bias=use_bias)
        self.ln2 = tf.keras.layers.LayerNormalization()
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def call(self, X, training=False):
        # Self-attention sublayer: the normalized input is both query and value.
        normed = self.ln1(X, training=training)
        attended = self.attention(normed, normed, training=training)
        X = X + attended
        # MLP sublayer, again normalized first and added back residually.
        normed = self.ln2(X, training=training)
        return X + self.mlp(normed, training=training)


# Patch embed → prepend learnable <cls> token → add learnable positional
# embeddings (not fixed sin/cos) → dropout → N ViT blocks → take <cls>
# representation → LayerNorm → linear head:
class ViT(d2l.Classifier):
    """Vision Transformer.

    Pipeline: patch embedding, prepend a learnable <cls> token, add learnable
    positional embeddings, dropout, `num_blks` ViT blocks, then LayerNorm and
    a dense head applied to the <cls> position only.

    Args:
        img_size: Image size forwarded to `PatchEmbedding`.
        patch_size: Patch size forwarded to `PatchEmbedding`.
        num_hiddens: Token embedding width.
        mlp_num_hiddens: Hidden width of the MLP in each block.
        num_heads: Attention heads per block.
        num_blks: Number of stacked `ViTBlock`s.
        emb_dropout: Drop rate applied to the embedded sequence.
        blk_dropout: Drop rate used inside each block.
        lr: Learning rate; recorded by `save_hyperparameters()` and not used
            directly in this class.
        use_bias: Whether attention projections use bias terms.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        self.num_steps = num_steps
        self.num_hiddens = num_hiddens
        self.emb_dropout = tf.keras.layers.Dropout(emb_dropout)
        self.blks = [ViTBlock(num_hiddens, mlp_num_hiddens, num_heads,
                              blk_dropout, use_bias)
                     for _ in range(num_blks)]
        self.head_norm = tf.keras.layers.LayerNormalization()
        self.head_dense = tf.keras.layers.Dense(num_classes)

    def build(self, input_shape):
        # One shared <cls> vector; it is tiled across the batch in `call`.
        self.cls_token = self.add_weight(
            name='cls_token', shape=(1, 1, self.num_hiddens),
            initializer='zeros', trainable=True)
        # One learnable position vector per sequence slot (patches + <cls>).
        self.pos_embedding = self.add_weight(
            name='pos_embedding', shape=(1, self.num_steps, self.num_hiddens),
            initializer='random_normal', trainable=True)
        super().build(input_shape)

    def call(self, X, training=None):
        """Return class logits for a batch of images.

        Args:
            X: Image batch accepted by `PatchEmbedding`.
            training: Explicit train/inference switch for the dropout layers.
                When left as None, the model's own `training` attribute is
                used if present, otherwise inference mode.
        """
        if training is None:
            # NOTE(review): the d2l TensorFlow trainer appears to set
            # `model.training` and call the model without `training=`; with a
            # hard `False` default dropout would never run during fit.
            # Confirm against the installed d2l version.
            training = bool(getattr(self, 'training', False))
        X = self.patch_embedding(X)
        batch_size = tf.shape(X)[0]
        cls_tokens = tf.tile(self.cls_token, [batch_size, 1, 1])
        # <cls> goes first, so index 0 along the sequence axis is its slot.
        X = tf.concat([cls_tokens, X], axis=1)
        X = self.emb_dropout(X + self.pos_embedding, training=training)
        for blk in self.blks:
            X = blk(X, training=training)
        # Classify from the <cls> position only.
        return self.head_dense(self.head_norm(X[:, 0]))


# Tiny config (2 blocks, 512 hidden, 8 heads). On a small dataset, this won’t
# beat a ResNet — Transformers need scale:
# Input geometry: 96x96 images cut into 16x16 patches.
img_size = 96
patch_size = 16

# Model size: 2 blocks, 512-wide tokens, 2048-wide MLP, 8 heads.
num_hiddens = 512
mlp_num_hiddens = 2048
num_heads = 8
num_blks = 2

# Regularization and optimization settings.
emb_dropout = 0.1
blk_dropout = 0.1
lr = 0.1

data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer = d2l.Trainer(max_epochs=10)
with d2l.try_gpu():
    model = ViT(
        img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
        num_blks, emb_dropout, blk_dropout, lr)
    trainer.fit(model, data)

# The classifier reads its prediction from the <cls> token representation.