Pretraining BERT

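With the BERT model and the pretraining data pipeline from the previous two decks, we can finally pretrain a small BERT end-to-end. This deck works at a tiny scale: 2 encoder blocks, 128 hidden units, and 2 attention heads. The same code covers BERT-Base (12 layers, 768 hidden units, 12 heads) and BERT-Large (24 layers, 1024 hidden units, 16 heads) by changing the config, although those models also need far more data and compute.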

Load pretraining data

Each batch supplies seven arrays: token IDs, segment IDs, valid lengths, masked positions, MLM weights, MLM labels, and NSP labels. The notebook's imports:

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from flax.training import train_state
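To make the seven fields concrete, here is a NumPy sketch that fabricates one random batch with the same field order and shapes. The helper `make_fake_batch` is hypothetical, the vocabulary size of 20000 is a placeholder, and the 15% budget of prediction slots (`round(0.15 * max_len)`, i.e. 10 slots for `max_len=64`) is an assumption chosen to match the `(2, 10)` dummy `pred_positions` used later in `train_bert`. Real batches come from `d2l.load_data_wiki`.

```python
import numpy as np

def make_fake_batch(batch_size, max_len, vocab_size, mask_ratio=0.15, seed=0):
    """Hypothetical stand-in for one pretraining minibatch (random content).

    Field order mirrors the training loop: tokens, segments, valid_lens,
    pred_positions, mlm_weights, mlm_Y, nsp_y. No <mask> replacement is done.
    """
    rng = np.random.default_rng(seed)
    num_preds = round(mask_ratio * max_len)  # prediction slots per sequence
    tokens = np.zeros((batch_size, max_len), dtype=np.int32)  # 0 stands in for <pad>
    segments = np.zeros((batch_size, max_len), dtype=np.int32)
    valid_lens = rng.integers(num_preds + 2, max_len + 1, size=batch_size)
    pred_positions = np.zeros((batch_size, num_preds), dtype=np.int32)
    mlm_weights = np.zeros((batch_size, num_preds), dtype=np.float32)
    mlm_Y = np.zeros((batch_size, num_preds), dtype=np.int32)
    for i in range(batch_size):
        n = int(valid_lens[i])  # real tokens; positions n..max_len-1 are padding
        tokens[i, :n] = rng.integers(1, vocab_size, size=n)
        segments[i, n // 2:n] = 1  # crude split: second half is sentence B
        k = max(1, int(round(mask_ratio * n)))  # real predictions for this row
        pos = np.sort(rng.choice(np.arange(1, n - 1), size=k, replace=False))
        pred_positions[i, :k] = pos
        mlm_weights[i, :k] = 1.0  # remaining slots are padding with weight 0
        mlm_Y[i, :k] = tokens[i, pos]  # labels are the original token IDs
    nsp_y = rng.integers(0, 2, size=batch_size).astype(np.int32)
    return (tokens, segments, valid_lens.astype(np.float32), pred_positions,
            mlm_weights, mlm_Y, nsp_y)

fields = make_fake_batch(batch_size=512, max_len=64, vocab_size=20000)
names = ['tokens', 'segments', 'valid_lens', 'pred_positions',
         'mlm_weights', 'mlm_Y', 'nsp_y']
for name, arr in zip(names, fields):
    print(f'{name:15s} {arr.shape} {arr.dtype}')
```

The key design point is that every row has the same number of prediction slots, so rows with fewer real predictions are padded and their extra slots get weight 0.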

Load WikiText-2 as minibatches of 512 pretraining examples, each padded to max_len = 64 tokens:

batch_size, max_len = 512, 64
# In d2l's JAX version, train_iter is a callable that returns a fresh iterator on each call.
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
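Why a callable rather than an iterator? A Python generator is exhausted after one pass, so a loop that trains for several epochs must be able to ask for a fresh one. A pure-Python sketch of the pattern, using a hypothetical `make_loader` helper (not d2l's implementation):

```python
def make_loader(examples, batch_size):
    """Return a zero-argument callable; each call yields a fresh pass over the data."""
    def loader():
        for start in range(0, len(examples), batch_size):
            yield examples[start:start + batch_size]
    return loader

train_iter_sketch = make_loader(list(range(10)), batch_size=4)
epoch1 = list(train_iter_sketch())
epoch2 = list(train_iter_sketch())  # a new generator, so the data is seen again

one_shot = train_iter_sketch()      # a single generator object...
first_pass = list(one_shot)
leftover = list(one_shot)           # ...is empty on a second pass
```

This is why the training loop calls `train_iter()` at the start of every epoch instead of reusing one iterator.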

Tiny BERT config

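Define a deliberately small BERT so the full pretraining loop runs in class: 2 encoder blocks, 128 hidden units, 2 attention heads, a 256-unit FFN, and dropout 0.2. Scaling to BERT-Base changes the model width and depth, the data size, and the compute budget, not the structure of the code: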

net = d2l.BERTModel(len(vocab), num_hiddens=128,
                    ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)
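A back-of-the-envelope parameter count shows what "changing the config" means in practice. The sketch assumes an original-BERT-style encoder: learned token, position, and segment embeddings followed by one LayerNorm, biases on every projection, and no pretraining heads or pooler. `d2l.BERTModel` may differ in some of these details, so treat the numbers as estimates. Note that `num_heads` does not appear: heads split `num_hiddens` rather than adding weights.

```python
def bert_encoder_params(vocab_size, num_hiddens, ffn_num_hiddens, num_blks,
                        max_len, num_segments=2):
    """Rough parameter count of a BERT encoder (embeddings + blocks, no heads)."""
    h, f = num_hiddens, ffn_num_hiddens
    # Token, position, and segment embeddings, plus one LayerNorm (scale + shift).
    embeddings = (vocab_size + max_len + num_segments) * h + 2 * h
    # Q, K, V, and output projections: each an h x h matrix plus a bias.
    attention = 4 * (h * h + h)
    # Position-wise FFN h -> f -> h, with biases.
    ffn = (h * f + f) + (f * h + h)
    # Two LayerNorms per block (after attention and after the FFN).
    layer_norms = 2 * (2 * h)
    return embeddings + num_blks * (attention + ffn + layer_norms)

# BERT-Base: 30522-token WordPiece vocabulary, 512 positions.
base = bert_encoder_params(30522, 768, 3072, 12, max_len=512)
# Cost of one tiny block (128 hidden, 256 FFN), independent of the vocabulary.
tiny_per_block = (bert_encoder_params(1, 128, 256, 1, max_len=1)
                  - bert_encoder_params(1, 128, 256, 0, max_len=1))
print(f'BERT-Base encoder: {base:,} parameters')
print(f'Tiny BERT, per block: {tiny_per_block:,} parameters')
```

Under these assumptions BERT-Base comes to 108,891,648 encoder parameters, in line with the commonly quoted figure of about 110M. For the tiny model, each block holds only 132,480 weights, so with any vocabulary beyond a few thousand tokens the `vocab_size x 128` embedding table dominates.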

Combined loss

Two heads, one combined loss:

\mathcal{L} = \mathcal{L}_\text{MLM} + \mathcal{L}_\text{NSP}.
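Spelled out term by term to match _get_batch_loss_bert, in our own notation: $B$ sentence pairs per batch, $P$ prediction slots per pair, MLM weights $w_{ij}\in\{0,1\}$ (0 for padding slots), masked-token labels $y_{ij}$, NSP labels $z_i$, and $\epsilon = 10^{-8}$:

\mathcal{L}_\text{MLM} = -\frac{\sum_{i=1}^{B}\sum_{j=1}^{P} w_{ij}\,\log p_\theta\!\left(y_{ij}\mid \mathbf{x}_i\right)}{\sum_{i=1}^{B}\sum_{j=1}^{P} w_{ij} + \epsilon},
\qquad
\mathcal{L}_\text{NSP} = -\frac{1}{B}\sum_{i=1}^{B}\log p_\theta\!\left(z_i\mid \mathbf{x}_i\right).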

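MLM: cross-entropy averaged over the real masked positions, with mlm_weights zeroing out padded prediction slots. NSP: two-class cross-entropy on the <cls> representation, averaged over the batch: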

def _get_batch_loss_bert(params, net, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y, rng):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net.apply(params, tokens_X, segments_X,
                                        valid_lens_x.reshape(-1),
                                        pred_positions_X, training=True,
                                        rngs={'dropout': rng})
    # Compute masked language model loss
    mlm_l = optax.softmax_cross_entropy_with_integer_labels(
        mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / (
        mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = optax.softmax_cross_entropy_with_integer_labels(
        nsp_Y_hat, nsp_y).mean()
    l = mlm_l + nsp_l
    return l, (mlm_l, nsp_l)
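To see why the weighted average matters, here is a NumPy re-implementation of just the MLM term (a sketch mirroring `_get_batch_loss_bert`, not part of d2l). Prediction slots that exist only as padding carry weight 0, so changing their logits leaves the loss unchanged:

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def masked_mlm_loss(mlm_logits, mlm_labels, mlm_weights, eps=1e-8):
    """Weighted mean of per-slot cross-entropy; slots with weight 0 are ignored."""
    vocab_size = mlm_logits.shape[-1]
    logp = log_softmax(mlm_logits.reshape(-1, vocab_size))
    labels = mlm_labels.reshape(-1)
    per_slot = -logp[np.arange(labels.size), labels]  # cross-entropy per slot
    w = mlm_weights.reshape(-1)
    return (per_slot * w).sum() / (w.sum() + eps)

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))                # 2 sequences, 3 slots, vocab of 5
labels = np.array([[1, 4, 0], [2, 2, 0]])
weights = np.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # trailing slots are padding

loss = masked_mlm_loss(logits, labels, weights)
logits_changed = logits.copy()
logits_changed[0, 2] += 100.0                      # perturb a padded slot only
loss_changed = masked_mlm_loss(logits_changed, labels, weights)
```

A plain `.mean()` over all slots would instead let padding dilute (or, with garbage logits, distort) the loss.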

Training loop

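The loop uses Adam with a constant learning rate of 1e-4 (no warmup) and runs for only 50 steps, enough to check that everything trains but not enough for the losses to fall far. The plotted curves are running averages over all steps so far. MLM loss sits well above NSP loss because it classifies over the whole vocabulary rather than predicting a binary label: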

def train_bert(train_iter, net, vocab_size, num_steps):
    # Initialize model parameters using a dummy batch
    dummy_tokens = jnp.ones((2, 64), dtype=jnp.int32)
    dummy_segments = jnp.zeros((2, 64), dtype=jnp.int32)
    dummy_valid_lens = jnp.array([64, 64], dtype=jnp.float32)
    dummy_pred_positions = jnp.zeros((2, 10), dtype=jnp.int32)
    key = jax.random.PRNGKey(0)
    params = net.init(key, dummy_tokens, dummy_segments, dummy_valid_lens,
                      dummy_pred_positions, training=False)
    tx = optax.adam(learning_rate=1e-4)
    state = train_state.TrainState.create(
        apply_fn=net.apply, params=params, tx=tx)

    grad_fn = jax.value_and_grad(_get_batch_loss_bert, has_aux=True)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    rng = jax.random.PRNGKey(1)
    while step < num_steps and not num_steps_reached:
        # train_iter is a callable: invoke it each epoch for a fresh iterator.
        for (tokens_X, segments_X, valid_lens_x, pred_positions_X,
             mlm_weights_X, mlm_Y, nsp_y) in train_iter():
            timer.start()
            rng, step_rng = jax.random.split(rng)
            (l, (mlm_l, nsp_l)), grads = grad_fn(
                state.params, net, vocab_size, tokens_X, segments_X,
                valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y,
                step_rng)
            state = state.apply_gradients(grads=grads)
            metric.add(float(mlm_l), float(nsp_l), tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(jax.devices())}')
    return state

state = train_bert(train_iter, net, len(vocab), 50)

MLM loss 9.454, NSP loss 0.718
408.4 sentence pairs/sec on [CudaDevice(id=0)]
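A useful yardstick for these numbers: a model that predicts uniformly over K classes has cross-entropy ln K. For NSP, K = 2 gives ln 2 ≈ 0.693, so an NSP loss of 0.718 after 50 steps is still at chance level; for MLM, K is the vocabulary size, which explains the much larger value. A small sketch (the vocabulary size in the second call is a placeholder; in the notebook you would pass `len(vocab)`):

```python
import math

def chance_cross_entropy(num_classes):
    """Cross-entropy of a uniform prediction over num_classes labels: ln(K)."""
    return math.log(num_classes)

print(f'NSP chance loss: {chance_cross_entropy(2):.3f}')
# Placeholder vocabulary size; use len(vocab) for the real one.
print(f'MLM chance loss (|V|=20000): {chance_cross_entropy(20000):.3f}')
```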

Using the trained encoder

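After pretraining, the encoder is the part worth keeping: it turns token sequences into contextual representations. The pretraining heads can be discarded for most downstream tasks. The helper below encodes a single sentence or a sentence pair (it reads the global vocab):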

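def get_bert_encoding(net, params, tokens_a, tokens_b=None):
    # Add <cls>/<sep> and build segment IDs (0 for sentence A, 1 for B).
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    # Map tokens to IDs with the global vocab and add a batch axis of size 1.
    token_ids = jnp.array(vocab[tokens], dtype=jnp.int32)[None, :]
    segments = jnp.array(segments, dtype=jnp.int32)[None, :]
    valid_len = jnp.array([len(tokens)], dtype=jnp.float32)
    # Inference mode (no dropout); keep only the encoder output and
    # ignore the two pretraining-head outputs.
    encoded_X, _, _ = net.apply(params, token_ids, segments, valid_len,
                                training=False)
    return encoded_X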
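`get_bert_encoding` relies on `d2l.get_tokens_and_segments` to assemble the input. The pure-Python sketch below reproduces the layout used throughout this deck (`<cls>` A `<sep>`, optionally followed by B `<sep>`, with segment 0 for the first part and 1 for the second); treat it as an illustration rather than d2l's source:

```python
def tokens_and_segments(tokens_a, tokens_b=None):
    """Build '<cls> A <sep> [B <sep>]' and the matching segment IDs."""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * (len(tokens_a) + 2)        # <cls>, A, first <sep>
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)   # B and final <sep>
    return tokens, segments

toks, segs = tokens_and_segments(['a', 'crane', 'driver', 'came'],
                                 ['he', 'just', 'left'])
print(toks)
print(segs)
```

Because `<cls>` occupies position 0, "crane" sits at index 2 in both examples that follow, which is why they slice `[:, 2, :]`.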

Single sentence

“a crane is flying” yields 6 hidden vectors, one per token including <cls> and <sep>. Each is contextual: the representation of “crane” depends on its neighbors:

tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, state.params, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
((1, 6, 128),
 (1, 128),
 Array([-1.8185419 ,  0.563233  ,  0.89780295], dtype=float32))

Sentence pair

“a crane driver came” / “he just left”. Same encoder, two-segment input: segment IDs distinguish the two halves inside one sequence. Because “crane” now has different neighbors, its vector differs from the single-sentence case:

tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, state.params, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
((1, 10, 128),
 (1, 128),
 Array([-1.6992598 ,  0.19103326,  0.9098285 ], dtype=float32))
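The two printed slices already differ, but three coordinates are a weak comparison. One way to quantify how far context moved "crane" is cosine similarity over the full 128-dimensional vectors. A NumPy sketch (the helper name is ours; in the notebook you would pass `np.asarray(encoded_text_crane[0])` and `np.asarray(encoded_pair_crane[0])`):

```python
import numpy as np

def cosine_similarity(u, v, eps=1e-12):
    """Cosine of the angle between two vectors: 1 = same direction, 0 = orthogonal."""
    u = np.asarray(u, dtype=np.float64)
    v = np.asarray(v, dtype=np.float64)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + eps))

# Toy check on hand-made vectors.
parallel = cosine_similarity([1.0, 0.0], [2.0, 0.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 3.0])
print(parallel, orthogonal)
```

Cosine similarity ignores vector length, so it measures a change of direction rather than of scale.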

Recap

  • BERT pretraining is just two losses (MLM + NSP) optimized end-to-end on the encoder + heads.
  • Output of pretraining: a contextual token encoder.
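  • For downstream tasks: load the encoder weights, attach a small task-specific head, and fine-tune. The next chapter discusses fine-tuning for sequence-level and token-level tasks (including SQuAD-style question answering) and fine-tunes BERT for natural language inference.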
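The "attach a small head" step can be sketched as a single linear layer on the `<cls>` vector followed by a softmax. This is a hypothetical NumPy illustration of the forward pass only, with random weights and three made-up classes; real fine-tuning would train the head and update the encoder weights with gradients.

```python
import numpy as np

class ClsHead:
    """Hypothetical linear classifier on top of the <cls> representation."""

    def __init__(self, num_hiddens, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.02, size=(num_hiddens, num_classes))
        self.b = np.zeros(num_classes)

    def __call__(self, encoded_X):
        cls_vec = encoded_X[:, 0, :]                   # (batch, num_hiddens)
        logits = cls_vec @ self.W + self.b             # (batch, num_classes)
        logits -= logits.max(axis=-1, keepdims=True)   # stabilize the softmax
        probs = np.exp(logits)
        return probs / probs.sum(axis=-1, keepdims=True)

# Stand-in for encoded_pair: (batch=1, seq_len=10, num_hiddens=128).
encoded = np.random.default_rng(1).normal(size=(1, 10, 128))
probs = ClsHead(128, num_classes=3)(encoded)
print(probs.shape, probs.sum())
```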