
Bidirectional Encoder Representations from Transformers (BERT)

BERT

BERT (Devlin et al., 2018): a bidirectional Transformer encoder pretrained on a large unlabeled corpus (BooksCorpus plus English Wikipedia in the original paper), then fine-tuned on a wide range of downstream NLP tasks. Popularized the “pretrain + fine-tune” recipe for NLP.

Lineage

  • ELMo: bidirectional LSTMs; contextual representations, but fed as features into a task-specific architecture.
  • GPT: Transformer decoder, task-agnostic, but unidirectional (left context only).
  • BERT: Transformer encoder, bidirectional via masked LM, task-agnostic.

ELMo vs GPT vs BERT.

Setup

from d2l import tensorflow as d2l
import tensorflow as tf
from tensorflow import keras
import numpy as np

A BERT input sequence combines:

  • <cls> + tokens of segment A + <sep>, then tokens of segment B + <sep> if the input is a text pair (a single text stops after the first <sep>).
  • Three additive embeddings: token, segment (A vs B), position.

Token + segment + position embeddings, all summed.

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 are marking segment A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments
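For a concrete text pair, these are the ID sequences that index the three embedding tables. A minimal pure-Python sketch: it repeats `get_tokens_and_segments` so it runs on its own, and the position IDs (just 0, 1, ..., n-1) are an illustrative addition matching how `BERTEncoder` slices its learnable position table.

```python
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 mark segment A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

tokens, segments = get_tokens_and_segments(['this', 'movie', 'is', 'great'],
                                           ['i', 'like', 'it'])
# Position IDs are simply the indices into the position-embedding table.
positions = list(range(len(tokens)))

for tok, seg, pos in zip(tokens, segments, positions):
    print(f'{tok:>6}  segment={seg}  position={pos}')
```

The first <sep> belongs to segment A, so segment B covers only its own tokens and the final <sep>.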

BERTEncoder

A standard Transformer encoder stack on the summed embeddings. The pretrained model exposes one hidden vector per input position:

class BERTEncoder(keras.layers.Layer):
    """BERT encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout, max_len=1000, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = keras.layers.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = keras.layers.Embedding(2, num_hiddens)
        # In BERT, positional embeddings are learnable, thus we create a
        # trainable variable of positional embeddings that are long enough
        self.pos_embedding = self.add_weight(
            name='pos_embedding', shape=(1, max_len, num_hiddens),
            initializer='random_normal', trainable=True)
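        # d2l's TensorFlow `AddNorm` hands `norm_shape` to
        # `keras.layers.LayerNormalization` as its `axis` argument, so this is
        # the axis to normalize over (the feature axis of (batch, seq, hidden)),
        # not the feature size; `[num_hiddens]` would be an invalid axis.
        norm_shape = [2]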
        # BERT's attention sublayers use biased projections; the default for
        # `TransformerEncoderBlock` is `bias=False`, so override here.
        self.blks = [d2l.TransformerEncoderBlock(
            num_hiddens, num_hiddens, num_hiddens, num_hiddens, norm_shape,
            ffn_num_hiddens, num_heads, dropout, bias=True)
            for _ in range(num_blks)]

    def call(self, tokens, segments, valid_lens, training=False, **kwargs):
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding[:, :tf.shape(X)[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens, training=training)
        return X
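The first two lines of `call` are the whole input recipe: look up each ID in its own table and add the three vectors element-wise. A toy pure-Python sketch with hand-picked 2-dimensional tables; the table contents are made-up illustration values, whereas in `BERTEncoder` all three tables are learned.

```python
# Toy embedding tables (hypothetical values), num_hiddens = 2.
token_table = {'<cls>': [0.1, 0.0], 'hi': [0.5, 0.5], '<sep>': [0.0, 0.2]}
segment_table = [[0.0, 1.0],   # segment A
                 [1.0, 0.0]]   # segment B
position_table = [[0.01, 0.01], [0.02, 0.02], [0.03, 0.03]]

def embed(tokens, segments):
    """Sum token, segment, and position embeddings for each input position."""
    out = []
    for pos, (tok, seg) in enumerate(zip(tokens, segments)):
        tok_vec = token_table[tok]
        seg_vec = segment_table[seg]
        pos_vec = position_table[pos]
        out.append([t + s + p for t, s, p in zip(tok_vec, seg_vec, pos_vec)])
    return out

X = embed(['<cls>', 'hi', '<sep>'], [0, 0, 0])
# X holds one 2-d vector per token: the sum of its three lookups.
print(X)
```

Because the three embeddings are summed rather than concatenated, the width stays at num_hiddens, which is what lets the encoder blocks consume the result directly.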

Encoder shape check

The encoder returns one contextual vector per input token, so the output shape is (batch size, sequence length, num_hiddens). There is no separate pooled output: the <cls> representation used for NSP is simply position 0 of this tensor.

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
num_blks, dropout = 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                      num_blks, dropout)
tokens = tf.random.uniform(shape=(2, 8), minval=0, maxval=vocab_size,
                           dtype=tf.int32)
segments = tf.constant([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None, training=False)
encoded_X.shape
TensorShape([2, 8, 768])

Pretraining task 1: Masked LM

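Randomly select 15% of the input token positions as prediction targets. Each selected token is replaced with <mask> 80% of the time, with a random token 10% of the time, and left unchanged 10% of the time. The encoder is trained to predict the original tokens at these positions, which forces it to use both left and right context.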
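The selection rule can be sketched in plain Python. `mask_tokens` is a hypothetical helper, not part of the original code: it skips the special tokens, picks about 15% of the remaining positions, applies the 80/10/10 replacement, and returns the positions and original tokens the MLM head must predict.

```python
import random

def mask_tokens(tokens, vocab, mask_ratio=0.15, rng=random):
    """Hypothetical MLM corruption step.

    Returns (corrupted tokens, prediction positions, original labels).
    """
    # Special tokens are never chosen as prediction targets.
    candidates = [i for i, t in enumerate(tokens)
                  if t not in ('<cls>', '<sep>')]
    if not candidates:
        return list(tokens), [], []
    num_preds = max(1, round(len(candidates) * mask_ratio))
    positions = sorted(rng.sample(candidates, num_preds))
    corrupted = list(tokens)
    for i in positions:
        r = rng.random()
        if r < 0.8:            # 80%: replace with <mask>
            corrupted[i] = '<mask>'
        elif r < 0.9:          # 10%: replace with a random vocabulary token
            corrupted[i] = rng.choice(vocab)
        # remaining 10%: keep the original token
    labels = [tokens[i] for i in positions]
    return corrupted, positions, labels

rng = random.Random(0)
tokens = ['<cls>'] + 'the cat sat on the mat and then it slept'.split() + ['<sep>']
corrupted, positions, labels = mask_tokens(tokens, ['dog', 'ran', 'red'], rng=rng)
print(corrupted, positions, labels)
```

Keeping 10% of targets unchanged means the model cannot assume that every non-<mask> token is correct, which narrows the gap between pretraining inputs and fine-tuning inputs (where <mask> never appears).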

class MaskLM(keras.layers.Layer):
    """The masked language model task of BERT."""
    def __init__(self, vocab_size, num_hiddens, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = keras.Sequential([
            keras.layers.Dense(num_hiddens, activation='relu'),
            keras.layers.LayerNormalization(),
            keras.layers.Dense(vocab_size),
        ])

    def call(self, X, pred_positions, **kwargs):
        num_pred_positions = pred_positions.shape[1]
        pred_positions_flat = tf.reshape(pred_positions, [-1])
        batch_size = tf.shape(X)[0]
        batch_idx = tf.repeat(tf.range(batch_size), num_pred_positions)
        # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
        # `batch_idx` is [0, 0, 0, 1, 1, 1]
        indices = tf.stack([batch_idx, pred_positions_flat], axis=1)
        masked_X = tf.gather_nd(X, indices)
        masked_X = tf.reshape(masked_X, [batch_size, num_pred_positions, -1])
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat
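The `tf.stack` / `tf.gather_nd` pair builds one (batch index, position) pair per prediction and pulls out exactly those rows. A pure-Python sketch of the same indexing, using the positions from the forward example (`[[1, 5, 2], [6, 1, 5]]`); `gather_positions` is a hypothetical name for illustration only.

```python
def gather_positions(X, pred_positions):
    """Pure-Python analogue of MaskLM's index building plus tf.gather_nd.

    X: nested lists of shape (batch_size, seq_len, hidden).
    pred_positions: nested lists of shape (batch_size, num_pred_positions).
    Returns (indices, gathered); gathered has shape
    (batch_size, num_pred_positions, hidden).
    """
    batch_size = len(X)
    num_pred_positions = len(pred_positions[0])
    # Same role as tf.repeat(tf.range(batch_size), num_pred_positions)
    batch_idx = [b for b in range(batch_size) for _ in range(num_pred_positions)]
    # Same role as tf.reshape(pred_positions, [-1])
    flat_positions = [p for row in pred_positions for p in row]
    # Same role as tf.stack([batch_idx, pred_positions_flat], axis=1)
    indices = list(zip(batch_idx, flat_positions))
    flat = [X[b][p] for b, p in indices]  # like tf.gather_nd(X, indices)
    # Regroup the flat rows into (batch_size, num_pred_positions, hidden)
    gathered = [flat[b * num_pred_positions:(b + 1) * num_pred_positions]
                for b in range(batch_size)]
    return indices, gathered

# Toy X with hidden size 1: entry [10 * b + t] encodes (batch b, position t).
X = [[[10 * b + t] for t in range(8)] for b in range(2)]
indices, gathered = gather_positions(X, [[1, 5, 2], [6, 1, 5]])
print(indices)
print(gathered)
```

Only these gathered rows reach the MLP head, so its cost scales with the number of predicted positions rather than the full sequence length.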

MaskLM forward

Gather hidden states at the masked positions; project through an MLP head to vocab logits. The loss is evaluated only on these selected positions, not on every token:

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = tf.constant([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape
TensorShape([2, 3, 10000])
mlm_Y = tf.constant([[7, 8, 9], [10, 20, 30]])
loss = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
mlm_l = loss(tf.reshape(mlm_Y, [-1]),
             tf.reshape(mlm_Y_hat, [-1, vocab_size]))
mlm_l.shape
TensorShape([6])
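Per position, `SparseCategoricalCrossentropy(from_logits=True, reduction='none')` computes the log-sum-exp of the logits minus the logit of the true class, which is why `mlm_l` has one entry per (example, masked position). A pure-Python sketch of that formula (not the Keras implementation itself), with the usual max-shift for numerical stability:

```python
import math

def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one position: logsumexp(logits) - logits[label]."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# One loss value per position, like reduction='none'.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
losses = [sparse_ce_from_logits(z, y) for z, y in zip(logits, labels)]
print(losses)
```

With uniform logits over `vocab_size = 10000` tokens the loss is log(10000) ≈ 9.21 at every position, a handy reference point when checking that MLM training is making progress.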

Pretraining task 2: Next Sentence Prediction

Auxiliary binary task: given two segments, are they consecutive in the corpus? Trains the <cls> token’s representation to capture sentence-pair relationships (useful for QA, NLI):
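In the BERT paper, half of the training pairs use the actual next sentence (IsNext) and half use a random sentence from the corpus (NotNext). A minimal sketch of that sampling step; `make_nsp_pair` and the label convention (1 for IsNext, 0 for NotNext) are assumptions for illustration. In this simplified version the random sentence can occasionally be the true successor while still being labeled 0.

```python
import random

def make_nsp_pair(sentence, next_sentence, paragraphs, rng=random):
    """Hypothetical NSP sampler: returns (tokens_a, tokens_b, is_next)."""
    if rng.random() < 0.5:
        return sentence, next_sentence, 1        # IsNext: the real successor
    # NotNext: a random sentence drawn from a random paragraph
    random_paragraph = rng.choice(paragraphs)
    return sentence, rng.choice(random_paragraph), 0

# Each paragraph is a list of sentences; each sentence is a list of tokens.
paragraphs = [[['a', 'b'], ['c', 'd']], [['e'], ['f', 'g']]]
print(make_nsp_pair(['a', 'b'], ['c', 'd'], paragraphs, rng=random.Random(0)))
```

The resulting (tokens_a, tokens_b) pair is what `get_tokens_and_segments` turns into a single <cls> ... <sep> ... <sep> input.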

class NextSentencePred(keras.layers.Layer):
    """The next sentence prediction task of BERT."""
    def __init__(self, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        # `output` is reserved on Keras Layer (a read-only property), so use
        # `dense` for the head.
        self.dense = keras.layers.Dense(2)

    def call(self, X, **kwargs):
        # `X` shape: (batch size, `num_hiddens`)
        return self.dense(X)

NSP head

2-way classifier on the <cls> representation:

# Use the `<cls>` token (index 0) as input to NSP
# input_shape for NSP: (batch size, `num_hiddens`)
nsp = NextSentencePred()
nsp_Y_hat = nsp(encoded_X[:, 0, :])
nsp_Y_hat.shape
TensorShape([2, 2])
nsp_y = tf.constant([0, 1])
nsp_l = loss(nsp_y, nsp_Y_hat)
nsp_l.shape
TensorShape([2])

Putting it together

Encoder + MaskLM head + NSP head, sharing the same backbone. Pretrain end-to-end on (masked tokens, NSP label) tuples; fine-tune downstream by replacing the heads:

class BERTModel(keras.Model):
    """The BERT model."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, max_len=1000):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,
                                   num_heads, num_blks, dropout,
                                   max_len=max_len)
        self.hidden = keras.layers.Dense(num_hiddens, activation='tanh')
        self.mlm = MaskLM(vocab_size, num_hiddens)
        self.nsp = NextSentencePred()

    def call(self, tokens, segments, valid_lens=None, pred_positions=None,
             training=False, **kwargs):
        encoded_X = self.encoder(tokens, segments, valid_lens,
                                 training=training)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # The hidden layer of the MLP classifier for next sentence prediction.
        # 0 is the index of the '<cls>' token
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat
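Pretraining minimizes the sum of the MLM and NSP losses. One practical detail: batches are usually padded to a fixed number of prediction positions, and a common approach is a 0/1 weight per position so padded slots do not count toward the MLM average. A sketch of that combination over per-position and per-example losses such as `mlm_l` and `nsp_l` above; `pretraining_loss` and its argument names are hypothetical.

```python
def pretraining_loss(mlm_losses, mlm_weights, nsp_losses, eps=1e-8):
    """Weighted-mean MLM loss plus mean NSP loss (hypothetical helper)."""
    # Padded prediction slots carry weight 0, so they contribute nothing.
    weighted = sum(l * w for l, w in zip(mlm_losses, mlm_weights))
    mlm = weighted / (sum(mlm_weights) + eps)   # eps guards an all-padding batch
    nsp = sum(nsp_losses) / len(nsp_losses)
    return mlm + nsp, mlm, nsp

total, mlm, nsp = pretraining_loss(
    mlm_losses=[2.0, 4.0, 9.0],   # the third slot is padding
    mlm_weights=[1.0, 1.0, 0.0],
    nsp_losses=[0.5, 0.7])
print(total, mlm, nsp)
```

Dividing by the weight sum rather than the slot count keeps the MLM term comparable across batches with different amounts of padding.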

Recap

  • BERT = bidirectional Transformer encoder, pretrained with masked LM + next-sentence prediction.
  • Three additive embeddings (token, segment, position) — the “BERT input” recipe.
  • Pretrain once on a huge corpus, then add a small task-specific output layer and fine-tune all parameters on a classification, tagging, or QA task.
  • Successors: RoBERTa (drop NSP, more data), ELECTRA (replaced-token detection), DeBERTa (disentangled attention). All variations on the same template.