from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np

BERT (Devlin et al., 2018) is a bidirectional Transformer encoder that is pretrained on a giant corpus and then fine-tuned for arbitrary downstream NLP tasks. It started the “pretrain + fine-tune” era of NLP.
ELMo vs GPT vs BERT.
A BERT input sequence packs in a lot:
<cls> + tokens of segment A + <sep> + tokens of segment B + <sep>.
The embedding of each input token is the sum of its token, segment, and position embeddings.
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs.

    Args:
        tokens_a: Tokens of the first segment (segment A).
        tokens_b: Optional tokens of the second segment (segment B). When it
            is not None (even if empty), the tokens are appended together
            with their own trailing '<sep>'.

    Returns:
        A `(tokens, segments)` pair of equal-length lists: the packed token
        sequence and, for every token, its segment ID (0 for A, 1 for B).
    """
    tokens = ['<cls>', *tokens_a, '<sep>']
    # 0 and 1 mark segments A and B, respectively. '<cls>' and the first
    # '<sep>' are counted as part of segment A, hence the `+ 2`.
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += [*tokens_b, '<sep>']
        # The closing '<sep>' is counted as part of segment B.
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments


# A standard Transformer encoder stack on the summed embeddings. The
# pretrained model exposes one hidden vector per input position:
class BERTEncoder(nn.Module):
    """BERT encoder.

    Embeds the input as the sum of token, segment and learnable positional
    embeddings, then refines it with `num_blks` Transformer encoder blocks.

    Attributes:
        vocab_size: Number of rows of the token embedding table.
        num_hiddens: Size of every embedding and hidden vector.
        ffn_num_hiddens: Forwarded to each `d2l.TransformerEncoderBlock`.
        num_heads: Forwarded to each `d2l.TransformerEncoderBlock`.
        num_blks: Number of stacked encoder blocks.
        dropout: Forwarded to each `d2l.TransformerEncoderBlock`.
        max_len: Longest sequence the positional embeddings can cover.
    """
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float
    max_len: int = 1000

    def setup(self):
        """Create the embedding tables, encoder blocks and positions."""
        self.token_embedding = nn.Embed(self.vocab_size, self.num_hiddens)
        # Two rows only: one embedding per segment ID (0 or 1).
        self.segment_embedding = nn.Embed(2, self.num_hiddens)
        # NOTE(review): the trailing positional `True` is presumably the
        # block's bias switch -- confirm against d2l.TransformerEncoderBlock.
        self.blks = [d2l.TransformerEncoderBlock(
            self.num_hiddens, self.ffn_num_hiddens, self.num_heads,
            self.dropout, True) for _ in range(self.num_blks)]
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = self.param('pos_embedding',
                                        nn.initializers.normal(0.02),
                                        (1, self.max_len, self.num_hiddens))

    def __call__(self, tokens, segments, valid_lens, training=False):
        """Encode a batch of BERT input sequences.

        Args:
            tokens: Integer token IDs, looked up in the token embedding.
            segments: Segment IDs (0 or 1), same shape as `tokens`.
            valid_lens: Forwarded unchanged to every encoder block.
            training: Forwarded unchanged to every encoder block.

        Returns:
            The encoded sequence, of shape
            (batch size, max sequence length, `num_hiddens`).

        Raises:
            ValueError: If the sequence is longer than `max_len`.
        """
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        seq_len = X.shape[1]
        # Slicing past `max_len` silently truncates, which would surface as
        # an obscure shape error (or as silent broadcasting when
        # `max_len == 1`), so fail early with a clear message instead.
        if seq_len > self.max_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len {self.max_len}')
        X = X + self.pos_embedding[:, :seq_len, :]
        for blk in self.blks:
            # Each block also returns a second value that is not needed here.
            X, _ = blk(X, valid_lens, training=training)
        return X


# The encoder emits one contextual vector per input token, including the
# <cls> token at position 0, whose vector is pooled for sentence-level tasks.
# Every vector should have size num_hiddens; mismatches usually mean segment
# or position embeddings were not summed correctly.
# Hyperparameters of a small BERT encoder used to exercise the class.
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
# NOTE(review): `ffn_num_input` is not used by the code visible here; it is
# kept in case later cells rely on it.
ffn_num_input, num_blks, dropout = 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                      num_blks, dropout)
# A batch of 2 sequences with 8 token IDs each (all ones), together with the
# segment ID (0 or 1) of every token.
tokens = jnp.ones((2, 8), dtype=jnp.int32)
segments = jnp.array([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
# The trailing `None` is the `valid_lens` argument.
params = encoder.init(jax.random.PRNGKey(0), tokens, segments, None)


# Randomly mask 15% of input tokens (replace with <mask> 80% of the time, a
# random token 10%, leave unchanged 10%). Train the encoder to predict the
# originals. Forces the model to use both left and right context.
class MaskLM(nn.Module):
    """The masked language model task of BERT.

    Picks the encoder outputs at the positions to predict and maps each of
    them to vocabulary logits with a small MLP
    (Dense -> ReLU -> LayerNorm -> Dense).

    Attributes:
        vocab_size: Output size of the final dense layer.
        num_hiddens: Output size of the hidden dense layer.
    """
    vocab_size: int
    num_hiddens: int

    @nn.compact
    def __call__(self, X, pred_positions):
        """Predict tokens at the given positions.

        Args:
            X: Encoder output, indexed as (batch, position, hidden).
            pred_positions: Integer array of shape
                (batch size, number of prediction positions) holding, for
                every sequence, the positions whose tokens are predicted.

        Returns:
            Logits of shape
            (batch size, number of prediction positions, `vocab_size`).
        """
        # A column of batch indices broadcasts against `pred_positions`, so
        # row i of the result holds X[i, pred_positions[i, j]] for every j.
        # Suppose that the batch size is 2 and there are 3 prediction
        # positions, then the broadcast batch index is
        # `jnp.array([[0, 0, 0], [1, 1, 1]])`
        batch_idx = jnp.arange(X.shape[0])[:, None]
        masked_X = X[batch_idx, pred_positions]
        # Layers are created in the original order so that the
        # auto-generated parameter names stay the same.
        mlm_Y_hat = nn.Dense(self.num_hiddens)(masked_X)
        mlm_Y_hat = nn.relu(mlm_Y_hat)
        mlm_Y_hat = nn.LayerNorm()(mlm_Y_hat)
        mlm_Y_hat = nn.Dense(self.vocab_size)(mlm_Y_hat)
        return mlm_Y_hat


# Gather hidden states at the masked positions; project through an MLP head
# to vocab logits. The loss is evaluated only on these selected positions,
# not on every token:
(2, 3, 10000)
Next sentence prediction is an auxiliary binary task: given two segments, are they consecutive in the corpus? It trains the <cls> token’s representation to capture sentence-pair relationships (useful for QA and NLI):
2-way classifier on the <cls> representation:
(2, 2)
Encoder + MaskLM head + NSP head, sharing the same backbone. Pretrain end-to-end on (masked tokens, NSP label) tuples; fine-tune downstream by replacing the heads:
class BERTModel(nn.Module):
    """The BERT model.

    A shared `BERTEncoder` feeds two heads: `MaskLM` for masked language
    modeling and `NextSentencePred` for next sentence prediction. Both head
    classes and the encoder are expected to be defined elsewhere in the file.
    """
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float
    max_len: int = 1000

    def setup(self):
        """Instantiate the encoder, the pooling layer and both heads."""
        self.encoder = BERTEncoder(
            self.vocab_size,
            self.num_hiddens,
            self.ffn_num_hiddens,
            self.num_heads,
            self.num_blks,
            self.dropout,
            max_len=self.max_len,
        )
        # Hidden layer applied to the '<cls>' vector before the NSP head.
        self.hidden = nn.Dense(self.num_hiddens)
        self.mlm = MaskLM(self.vocab_size, self.num_hiddens)
        self.nsp = NextSentencePred()

    def __call__(self, tokens, segments, valid_lens=None, pred_positions=None,
                 training=False):
        """Run the encoder and both pretraining heads.

        Args:
            tokens: Token IDs, passed to the encoder.
            segments: Segment IDs, passed to the encoder.
            valid_lens: Passed to the encoder unchanged.
            pred_positions: Positions to predict with the masked language
                model head; the head is skipped when this is None.
            training: Passed to the encoder unchanged.

        Returns:
            A tuple `(encoded_X, mlm_Y_hat, nsp_Y_hat)`, where `mlm_Y_hat`
            is None when `pred_positions` is None.
        """
        encoded_X = self.encoder(
            tokens, segments, valid_lens, training=training)
        mlm_Y_hat = (None if pred_positions is None
                     else self.mlm(encoded_X, pred_positions))
        # Index 0 along the sequence axis is the '<cls>' token; its vector
        # goes through the tanh-activated hidden layer of the MLP classifier
        # for next sentence prediction.
        cls_hidden = jnp.tanh(self.hidden(encoded_X[:, 0, :]))
        nsp_Y_hat = self.nsp(cls_hidden)
        return encoded_X, mlm_Y_hat, nsp_Y_hat