
Bidirectional Encoder Representations from Transformers (BERT)

BERT

BERT (Devlin et al., 2018): a bidirectional Transformer encoder, pretrained on a large unlabeled corpus, then fine-tuned for a wide range of downstream NLP tasks. Popularized the “pretrain + fine-tune” recipe for NLP.

Lineage

  • ELMo: bidirectional LSTMs; contextual representations, but plugged into task-specific architectures.
  • GPT: Transformer decoder, unidirectional (left context only), task-agnostic.
  • BERT: Transformer encoder, bidirectional via masked LM, task-agnostic.

ELMo vs GPT vs BERT.

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np

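A BERT input sequence packs one text, or a pair of texts, into a single token sequence:

  • <cls> + tokens of segment A + <sep>, followed by tokens of segment B + <sep> when there is a second text.
  • Three embeddings, summed elementwise: token, segment (A vs B), position.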

Token + segment + position embeddings, all summed.

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 mark segments A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments
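
A quick usage check of `get_tokens_and_segments` on a made-up sentence pair (the example tokens are hypothetical). The function is repeated so the snippet runs on its own:

```python
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 mark segments A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

# Hypothetical two-segment input, e.g. for sentence-pair classification
tokens, segments = get_tokens_and_segments(
    ['this', 'movie', 'is', 'great'], ['i', 'like', 'it'])
print(tokens)
# ['<cls>', 'this', 'movie', 'is', 'great', '<sep>', 'i', 'like', 'it', '<sep>']
print(segments)
# [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]

# Single-segment input: no segment B, so every segment ID is 0
print(get_tokens_and_segments(['hello', 'world']))
# (['<cls>', 'hello', 'world', '<sep>'], [0, 0, 0, 0])
```

Note that <cls> and the first <sep> are counted as segment A, while the trailing <sep> belongs to segment B.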

BERTEncoder

BERTEncoder sums the token, segment, and learnable position embeddings, then runs the result through `num_blks` standard Transformer encoder blocks. It outputs one hidden vector per input position:

class BERTEncoder(nn.Module):
    """BERT encoder."""
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float
    max_len: int = 1000

    def setup(self):
        self.token_embedding = nn.Embed(self.vocab_size, self.num_hiddens)
        self.segment_embedding = nn.Embed(2, self.num_hiddens)
        self.blks = [d2l.TransformerEncoderBlock(
            self.num_hiddens, self.ffn_num_hiddens, self.num_heads,
            self.dropout, True) for _ in range(self.num_blks)]
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = self.param('pos_embedding',
                                        nn.initializers.normal(0.02),
                                        (1, self.max_len, self.num_hiddens))

    def __call__(self, tokens, segments, valid_lens, training=False):
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X, _ = blk(X, valid_lens, training=training)
        return X
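
The input layer at the top of `BERTEncoder.__call__` is just three table lookups and a sum. The sketch below traces it with NumPy arrays standing in for the Flax `nn.Embed` tables and the `pos_embedding` parameter; the toy sizes, the random tables, and the `embed` helper are illustrative assumptions, not trained weights:

```python
import numpy as np

# Toy sizes chosen for illustration only
vocab_size, num_hiddens, max_len = 20, 4, 16
rng = np.random.default_rng(0)

# Stand-ins for the three learnable tables in BERTEncoder
token_table = rng.normal(size=(vocab_size, num_hiddens))
segment_table = rng.normal(size=(2, num_hiddens))
pos_table = rng.normal(scale=0.02, size=(1, max_len, num_hiddens))

def embed(tokens, segments):
    """Mimic the input layer of BERTEncoder.__call__ with NumPy."""
    X = token_table[tokens] + segment_table[segments]   # (batch, seq, hiddens)
    # Slice the position table to the actual sequence length; its leading
    # axis of size 1 broadcasts over the batch
    return X + pos_table[:, :X.shape[1], :]

tokens = np.array([[1, 5, 7, 2, 9, 2], [1, 3, 2, 4, 4, 2]])
segments = np.array([[0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 1]])
X = embed(tokens, segments)
print(X.shape)  # (2, 6, 4)
```

Because the three embeddings are summed rather than concatenated, all tables must share the width `num_hiddens`, and `max_len` caps the sequence length the position table can cover.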

Encoder shape check

The encoder returns one contextual vector of length num_hiddens for every input token, so its output shape is (batch size, sequence length, num_hiddens). There is no separate pooled output here: the <cls> representation is simply position 0 of that tensor.

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
num_blks, dropout = 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                      num_blks, dropout)
# Two sequences of 8 random token IDs; the segment IDs mark where
# segment A ends and segment B begins in each sequence
tokens = jax.random.randint(jax.random.PRNGKey(0), (2, 8), 0, vocab_size)
segments = jnp.array([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
# valid_lens=None: no padding to mask out
params = encoder.init(jax.random.PRNGKey(0), tokens, segments, None)
encoded_X = encoder.apply(params, tokens, segments, None)
encoded_X.shape
(2, 8, 768)

Pretraining task 1: Masked LM

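Randomly select 15% of the input token positions for prediction. At each selected position, replace the token with <mask> 80% of the time, with a random token 10% of the time, and leave it unchanged the remaining 10%. The encoder is trained to predict the original tokens at those positions, which forces it to use both left and right context.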
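
A minimal NumPy sketch of the 15% selection and the 80/10/10 replacement rule. The helper name `mask_tokens`, the `mask_id` value, and skipping special tokens via `special_ids` are illustrative assumptions; a real pipeline would take these from its vocabulary. The returned positions and labels correspond to the `pred_positions` argument of `MaskLM` and to its prediction targets:

```python
import numpy as np

def mask_tokens(token_ids, vocab_size, mask_id, rng,
                special_ids=(), mask_frac=0.15):
    """Hypothetical helper: choose ~mask_frac of the non-special positions
    and corrupt them with the 80/10/10 rule.

    Returns (corrupted_ids, pred_positions, labels)."""
    token_ids = np.asarray(token_ids)
    # Never pick special tokens such as <cls> or <sep> for prediction
    candidates = np.array([i for i, t in enumerate(token_ids)
                           if t not in special_ids])
    num_preds = max(1, round(len(candidates) * mask_frac))
    positions = np.sort(rng.choice(candidates, size=num_preds, replace=False))
    labels = token_ids[positions]          # originals the MLM head must recover
    corrupted = token_ids.copy()
    for p in positions:
        r = rng.random()
        if r < 0.8:
            corrupted[p] = mask_id                    # 80%: <mask>
        elif r < 0.9:
            corrupted[p] = rng.integers(vocab_size)   # 10%: random token
        # remaining 10%: keep the original token unchanged
    return corrupted, positions, labels

rng = np.random.default_rng(0)
ids = np.arange(100, 140)          # 40 made-up token IDs
corrupted, positions, labels = mask_tokens(
    ids, vocab_size=1000, mask_id=3, rng=rng, special_ids=(100, 139))
print(len(positions))  # 6, i.e. round(38 * 0.15)
```

Keeping 10% of the selected tokens unchanged means the model cannot assume that every non-<mask> token is already correct.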

class MaskLM(nn.Module):
    """The masked language model task of BERT."""
    vocab_size: int
    num_hiddens: int

    @nn.compact
    def __call__(self, X, pred_positions):
        num_pred_positions = pred_positions.shape[1]
        pred_positions = pred_positions.reshape(-1)
        batch_size = X.shape[0]
        batch_idx = jnp.arange(0, batch_size)
        # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
        # `batch_idx` is `jnp.array([0, 0, 0, 1, 1, 1])`
        batch_idx = jnp.repeat(batch_idx, num_pred_positions)
        masked_X = X[batch_idx, pred_positions]
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        mlm_Y_hat = nn.Dense(self.num_hiddens)(masked_X)
        mlm_Y_hat = nn.relu(mlm_Y_hat)
        mlm_Y_hat = nn.LayerNorm()(mlm_Y_hat)
        mlm_Y_hat = nn.Dense(self.vocab_size)(mlm_Y_hat)
        return mlm_Y_hat
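
The batch-index trick in `MaskLM` is plain advanced indexing: pairing each repeated row index with each flattened prediction position picks exactly one hidden vector per (sequence, position) pair. This NumPy sketch, with a toy `X` standing in for the encoder output, reproduces the gather step:

```python
import numpy as np

batch_size, seq_len, num_hiddens = 2, 8, 4
# Toy encoder output with unique values, so gathered rows are easy to check
X = np.arange(batch_size * seq_len * num_hiddens, dtype=float).reshape(
    batch_size, seq_len, num_hiddens)
pred_positions = np.array([[1, 5, 2], [6, 1, 5]])

num_pred_positions = pred_positions.shape[1]
# Pair each flattened position with its batch row: [0, 0, 0, 1, 1, 1]
batch_idx = np.repeat(np.arange(batch_size), num_pred_positions)
masked_X = X[batch_idx, pred_positions.reshape(-1)]
masked_X = masked_X.reshape(batch_size, num_pred_positions, -1)
print(batch_idx)       # [0 0 0 1 1 1]
print(masked_X.shape)  # (2, 3, 4)
```

For example, `masked_X[1, 0]` is `X[1, 6]`: sequence 1, position 6.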

MaskLM forward

Gather hidden states at the masked positions; project through an MLP head to vocab logits. The loss is evaluated only on these selected positions, not on every token:

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = jnp.array([[1, 5, 2], [6, 1, 5]])
mlm_params = mlm.init(jax.random.PRNGKey(0), encoded_X, mlm_positions)
mlm_Y_hat = mlm.apply(mlm_params, encoded_X, mlm_positions)
mlm_Y_hat.shape
(2, 3, 10000)
mlm_Y = jnp.array([[7, 8, 9], [10, 20, 30]])
mlm_l = optax.softmax_cross_entropy_with_integer_labels(
    mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape
(6,)
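
`optax.softmax_cross_entropy_with_integer_labels` returns one loss per row: the negative log-softmax probability of the true label. A NumPy re-derivation (the helper name is made up for this sketch) makes that concrete and gives a sanity check: an untrained head with uniform logits should score log(10000) ≈ 9.21 at every masked position.

```python
import numpy as np

def softmax_xent_integer_labels(logits, labels):
    """Per-row cross-entropy: -log softmax(logits)[label]."""
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

vocab_size = 10000
# Uniform logits: every token equally likely, so each loss is log(vocab_size)
logits = np.zeros((6, vocab_size))
labels = np.array([7, 8, 9, 10, 20, 30])
losses = softmax_xent_integer_labels(logits, labels)
print(losses.shape)                              # (6,)
print(np.allclose(losses, np.log(vocab_size)))   # True
```

In training, these per-position values are typically averaged into a single scalar loss.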

Pretraining task 2: Next Sentence Prediction

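Auxiliary binary task: given two segments, predict whether B actually follows A in the corpus. During pretraining, B is the true next sentence half the time and a random sentence otherwise. This trains the <cls> token’s representation to capture sentence-pair relationships (useful for QA and NLI):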

class NextSentencePred(nn.Module):
    """The next sentence prediction task of BERT."""
    @nn.compact
    def __call__(self, X):
        # `X` shape: (batch size, `num_hiddens`)
        return nn.Dense(2)(X)
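
How NSP labels are produced is not shown here. The recipe described for BERT is a 50/50 split: half the time segment B is the actual next sentence, half the time it is a random one. A minimal sketch using Python's `random` module, with a hypothetical helper `make_nsp_pair` and toy tokenized paragraphs:

```python
import random

def make_nsp_pair(sentence, next_sentence, paragraphs, rng):
    """Hypothetical helper: keep the true next sentence with probability 0.5,
    otherwise swap in a random sentence from a random paragraph.

    Returns (tokens_a, tokens_b, is_next)."""
    if rng.random() < 0.5:
        is_next = True
    else:
        next_sentence = rng.choice(rng.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

rng = random.Random(0)
# Each paragraph is a list of sentences; each sentence is a list of tokens
paragraphs = [[['the', 'cat', 'sat'], ['it', 'purred']],
              [['stocks', 'fell'], ['investors', 'worried']]]
pairs = [make_nsp_pair(paragraphs[0][0], paragraphs[0][1], paragraphs, rng)
         for _ in range(1000)]
frac_next = sum(is_next for _, _, is_next in pairs) / len(pairs)
print(frac_next)  # close to 0.5
```

This sketch can occasionally draw the true next sentence as the "random" one and label it False; a stricter version would sample B from a different paragraph.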

NSP head

2-way classifier on the <cls> representation:

# Use the `<cls>` token (index 0) as input to NSP
# input_shape for NSP: (batch size, `num_hiddens`)
cls_X = encoded_X[:, 0, :]
nsp = NextSentencePred()
nsp_params = nsp.init(jax.random.PRNGKey(0), cls_X)
nsp_Y_hat = nsp.apply(nsp_params, cls_X)
nsp_Y_hat.shape
(2, 2)
nsp_y = jnp.array([0, 1])
nsp_l = optax.softmax_cross_entropy_with_integer_labels(nsp_Y_hat, nsp_y)
nsp_l.shape
(2,)

Putting it together

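BERTModel combines the encoder with the MaskLM and NSP heads, both reading from the same encoder output. Pretraining optimizes the two tasks jointly (the sum of the MLM and NSP losses); for downstream fine-tuning, the pretraining heads are replaced by a task-specific output layer: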

class BERTModel(nn.Module):
    """The BERT model."""
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float
    max_len: int = 1000

    def setup(self):
        self.encoder = BERTEncoder(
            self.vocab_size, self.num_hiddens, self.ffn_num_hiddens,
            self.num_heads, self.num_blks, self.dropout,
            max_len=self.max_len)
        self.hidden = nn.Dense(self.num_hiddens)
        self.mlm = MaskLM(self.vocab_size, self.num_hiddens)
        self.nsp = NextSentencePred()

    def __call__(self, tokens, segments, valid_lens=None, pred_positions=None,
                 training=False):
        encoded_X = self.encoder(tokens, segments, valid_lens,
                                 training=training)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The hidden layer of the MLP classifier for next sentence prediction.
        # 0 is the index of the '<cls>' token
        nsp_Y_hat = self.nsp(
            jnp.tanh(self.hidden(encoded_X[:, 0, :])))
        return encoded_X, mlm_Y_hat, nsp_Y_hat
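
BERTModel returns the logits of both heads but no loss. In BERT pretraining the objective is the sum of the MLM and NSP losses. A NumPy sketch of that sum follows; the `mlm_weights` mask for padded prediction positions and the helper names `xent` and `pretraining_loss` are assumptions of this sketch, not part of the code above:

```python
import numpy as np

def xent(logits, labels):
    """Per-row cross-entropy with integer labels."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def pretraining_loss(mlm_Y_hat, mlm_Y, mlm_weights, nsp_Y_hat, nsp_y):
    """Sketch of a joint BERT pretraining objective.

    mlm_weights (an assumption of this sketch) is 1 for real prediction
    positions and 0 for padding, so padded slots do not contribute."""
    vocab_size = mlm_Y_hat.shape[-1]
    mlm_l = xent(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
    w = mlm_weights.reshape(-1)
    mlm_l = (mlm_l * w).sum() / (w.sum() + 1e-8)   # weighted mean over real slots
    nsp_l = xent(nsp_Y_hat, nsp_y).mean()
    return mlm_l + nsp_l, mlm_l, nsp_l

# Untrained heads with uniform logits: MLM loss is log(10000), NSP loss is log(2)
mlm_Y_hat = np.zeros((2, 3, 10000))
mlm_Y = np.array([[7, 8, 9], [10, 20, 0]])
mlm_weights = np.array([[1., 1., 1.], [1., 1., 0.]])   # last slot is padding
nsp_Y_hat = np.zeros((2, 2))
nsp_y = np.array([0, 1])
total, mlm_l, nsp_l = pretraining_loss(
    mlm_Y_hat, mlm_Y, mlm_weights, nsp_Y_hat, nsp_y)
print(round(float(total), 3))  # 9.903
```

Summing the two losses with equal weight keeps the sketch simple; any reweighting would be a design choice beyond what is described here.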

Recap

  • BERT = bidirectional Transformer encoder, pretrained with masked LM + next-sentence prediction.
  • Three additive embeddings (token, segment, position) make up the BERT input representation.
  • Pretrain once on a huge corpus, then fine-tune the whole model plus a small task-specific head for classification, tagging, or QA.
  • Successors: RoBERTa (drop NSP, more data), ELECTRA (replaced-token detection), DeBERTa (disentangled attention). All variations on the same template.