The Dataset for Pretraining BERT

BERT Pretraining Data

The previous deck specified BERT’s model. This one specifies the data: how to turn raw text into the seven-field examples (token IDs, segment IDs, valid lengths, MLM prediction positions, MLM weights, MLM labels, NSP label) that the pretraining loop expects.

We use WikiText-2 — a small, readable Wikipedia subset. Real BERT was pretrained on BookCorpus + English Wikipedia (~3.3B tokens); the recipe is identical, just scaled up.

Read WikiText-2

WikiText-2 keeps punctuation, case, and numbers; the loader below lowercases the text and drops lines with fewer than two sentences. It returns each paragraph as a list of sentences so NSP can sample adjacent or random sentence pairs:

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import os
import random
# Location of the WikiText-2 (v1) training split, stored as a single Parquet
# file in the Salesforce/wikitext dataset repository on Hugging Face.
WIKITEXT_2_URL = ('https://huggingface.co/datasets/Salesforce/wikitext/'
                  'resolve/main/wikitext-2-v1/train-00000-of-00001.parquet')


def _read_wiki(data_dir=None):
    """Download WikiText-2 and return it as shuffled, lowercased paragraphs.

    Args:
        data_dir: Directory passed to the download helper as the folder for
            the Parquet file. When ``None`` (the default), ``'../data'`` is
            used.

    Returns:
        A list of paragraphs in random order. Each paragraph is a list of
        lowercased strings obtained by splitting one line of the ``text``
        column on ``' . '``. Lines that split into fewer than two pieces are
        dropped.
    """
    import contextlib
    import io
    import pandas as pd
    # Honor the caller's directory; fall back to the historical default
    folder = '../data' if data_dir is None else data_dir
    # Capture anything the download helper prints to stdout
    with contextlib.redirect_stdout(io.StringIO()):
        fname = d2l.download(WIKITEXT_2_URL, folder=folder)
    lines = pd.read_parquet(fname)['text'].tolist()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

Generating NSP examples

For each sentence, with probability 0.5 pair it with the next sentence (is_next=1); otherwise pair with a random sentence (is_next=0):

def _get_next_sentence(sentence, next_sentence, paragraphs):
    """Choose the second sentence of a next sentence prediction pair.

    With probability 0.5 the given `next_sentence` is kept and the pair is
    labeled as consecutive. Otherwise a sentence is drawn uniformly from a
    uniformly drawn paragraph and the pair is labeled as not consecutive.

    Args:
        sentence: The first sentence of the pair (returned unchanged).
        next_sentence: The sentence that actually follows `sentence`.
        paragraphs: A list of paragraphs, each a list of sentences, used as
            the pool for random replacements.

    Returns:
        A tuple `(sentence, second_sentence, is_next)` with `is_next` a bool.
    """
    keep_true_successor = random.random() < 0.5
    if keep_true_successor:
        return sentence, next_sentence, True
    # `paragraphs` is a list of lists of lists: pick a paragraph first, then
    # one of its sentences
    random_paragraph = random.choice(paragraphs)
    return sentence, random.choice(random_paragraph), False
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Build next sentence prediction examples from one paragraph.

    Every pair of adjacent sentences in `paragraph` is passed through
    `_get_next_sentence`; pairs that are too long for `max_len` are skipped.

    Args:
        paragraph: A list of tokenized sentences (each a list of tokens).
        paragraphs: All paragraphs, used as the pool for random sentences.
        vocab: Not referenced by this function.
        max_len: Maximum length of the combined token sequence.

    Returns:
        A list of `(tokens, segments, is_next)` tuples.
    """
    examples = []
    for first, second in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            first, second, paragraphs)
        # The extra 3 accounts for 1 '<cls>' token and 2 '<sep>' tokens
        combined_len = len(tokens_a) + len(tokens_b) + 3
        if combined_len <= max_len:
            tokens, segments = d2l.get_tokens_and_segments(
                tokens_a, tokens_b)
            examples.append((tokens, segments, is_next))
    return examples

Generating Masked LM labels

Pick 15% of token positions. For those:

  • 80%: replace with <mask>.
  • 10%: replace with a random token.
  • 10%: leave the original (so the model can’t tell which position was selected for prediction).
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    """Corrupt a token sequence for the masked language modeling task.

    Positions are visited in random order until `num_mlm_preds` of them have
    been selected. Each selected position is replaced by '<mask>' 80% of the
    time, left unchanged 10% of the time, and replaced by a token drawn
    uniformly from `vocab.idx_to_token` the remaining 10% of the time.

    Args:
        tokens: The original token sequence. It is not modified.
        candidate_pred_positions: Positions that may be selected. The list
            is not modified; a shuffled copy is used internally.
        num_mlm_preds: Maximum number of positions to select.
        vocab: Object whose `idx_to_token` sequence supplies random
            replacement tokens.

    Returns:
        A tuple `(mlm_input_tokens, pred_positions_and_labels)`: the corrupted
        copy of `tokens`, and a list of `(position, original_token)` pairs in
        the order the positions were selected.
    """
    # For the input of a masked language model, make a new copy of tokens and
    # replace some of them by '<mask>' or random tokens
    mlm_input_tokens = list(tokens)
    pred_positions_and_labels = []
    # Shuffle a private copy so the caller's list keeps its original order
    shuffled_positions = list(candidate_pred_positions)
    random.shuffle(shuffled_positions)
    for mlm_pred_position in shuffled_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        original_token = tokens[mlm_pred_position]
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.8:
            masked_token = '<mask>'
        # 10% of the time: keep the word unchanged
        elif random.random() < 0.5:
            masked_token = original_token
        # 10% of the time: replace the word with a random word
        else:
            masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, original_token))
    return mlm_input_tokens, pred_positions_and_labels
def _get_mlm_data_from_tokens(tokens, vocab):
    """Create masked language modeling inputs and labels for one sequence.

    Args:
        tokens: A list of token strings, including '<cls>' and '<sep>'.
        vocab: Vocabulary; indexed with lists of tokens to obtain indices and
            passed on to `_replace_mlm_tokens`.

    Returns:
        A tuple `(input_ids, pred_positions, label_ids)`: the indices of the
        corrupted sequence, the selected positions in ascending order, and
        the indices of the original tokens at those positions.
    """
    # Special tokens are not predicted in the masked language modeling task
    candidate_pred_positions = [
        position for position, token in enumerate(tokens)
        if token not in ('<cls>', '<sep>')]
    # 15% of the tokens are predicted, but always at least one is requested
    num_mlm_preds = max(1, round(0.15 * len(tokens)))
    corrupted_tokens, position_label_pairs = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    # Report the predictions in order of position
    ordered_pairs = sorted(position_label_pairs, key=lambda pair: pair[0])
    pred_positions, mlm_pred_labels = [], []
    for position, label in ordered_pairs:
        pred_positions.append(position)
        mlm_pred_labels.append(label)
    return vocab[corrupted_tokens], pred_positions, vocab[mlm_pred_labels]

Padding

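Pad every example to the fixed max_len; record valid_lens (the unpadded length) for attention masking; pad the MLM positions and labels with 0 and give those padded slots weight 0 in mlm_weights so the loss ignores them: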

def _pad_bert_inputs(examples, max_len, vocab):
    """Pad BERT pretraining examples to fixed sizes.

    Token and segment sequences are padded to `max_len`; the masked language
    modeling fields are padded to `round(max_len * 0.15)` slots.

    Args:
        examples: Iterable of `(token_ids, pred_positions,
            mlm_pred_label_ids, segments, is_next)` tuples whose first four
            fields are lists.
        max_len: Length that token and segment sequences are padded to.
        vocab: Mapping that yields the padding index for `'<pad>'`.

    Returns:
        A tuple of seven lists with one array per example, in this order:
        token ids (int32), segments (int32), valid lengths (float32 scalars),
        prediction positions (int32), MLM weights (float32), MLM labels
        (int32) and NSP labels (int32 scalars).
    """
    max_num_mlm_preds = round(max_len * 0.15)
    pad_id = vocab['<pad>']
    all_token_ids, all_segments, valid_lens = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments,
         is_next) in examples:
        all_token_ids.append(jnp.array(
            token_ids + [pad_id] * (max_len - len(token_ids)),
            dtype=jnp.int32))
        all_segments.append(jnp.array(
            segments + [0] * (max_len - len(segments)), dtype=jnp.int32))
        # `valid_lens` excludes count of '<pad>' tokens
        valid_lens.append(jnp.array(len(token_ids), dtype=jnp.float32))
        all_pred_positions.append(jnp.array(
            pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)),
            dtype=jnp.int32))
        # Weights and labels share one padding count, so they always line up
        # slot for slot
        num_preds = len(mlm_pred_label_ids)
        num_pad_preds = max_num_mlm_preds - num_preds
        # Predictions of padded tokens will be filtered out in the loss via
        # multiplication of 0 weights
        all_mlm_weights.append(jnp.array(
            [1.0] * num_preds + [0.0] * num_pad_preds, dtype=jnp.float32))
        all_mlm_labels.append(jnp.array(
            mlm_pred_label_ids + [0] * num_pad_preds, dtype=jnp.int32))
        nsp_labels.append(jnp.array(is_next, dtype=jnp.int32))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels, nsp_labels)

Custom Dataset class

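Builds every example once in __init__ using the per-example generators above, then exposes them through __getitem__ and __len__ — the map-style dataset interface familiar from PyTorch and other frameworks: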

class _WikiTextDataset:
    """Pretraining examples built from paragraphs of sentence strings.

    Construction tokenizes the text, builds the vocabulary, generates next
    sentence prediction pairs, applies masked language model corruption and
    pads everything. Indexing returns the seven fields of one example.
    """

    def __init__(self, paragraphs, max_len):
        """Build all padded examples.

        Args:
            paragraphs: A list of paragraphs, each a list of sentence strings.
            max_len: Maximum (and padded) sequence length.
        """
        # After tokenization each paragraph is a list of sentences and each
        # sentence is a list of tokens
        tokenized = [d2l.tokenize(paragraph, token='word')
                     for paragraph in paragraphs]
        all_sentences = [sentence
                         for paragraph in tokenized
                         for sentence in paragraph]
        self.vocab = d2l.Vocab(all_sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Step 1: next sentence prediction pairs for every paragraph
        nsp_examples = []
        for paragraph in tokenized:
            nsp_examples += _get_nsp_data_from_paragraph(
                paragraph, tokenized, self.vocab, max_len)
        # Step 2: masked language model fields, followed by the NSP fields
        full_examples = []
        for tokens, segments, is_next in nsp_examples:
            mlm_fields = _get_mlm_data_from_tokens(tokens, self.vocab)
            full_examples.append(mlm_fields + (segments, is_next))
        # Step 3: pad all fields to fixed sizes
        padded = _pad_bert_inputs(full_examples, max_len, self.vocab)
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = padded

    def __getitem__(self, idx):
        """Return the seven fields of example `idx` as a tuple."""
        fields = (self.all_token_ids, self.all_segments, self.valid_lens,
                  self.all_pred_positions, self.all_mlm_weights,
                  self.all_mlm_labels, self.nsp_labels)
        return tuple(field[idx] for field in fields)

    def __len__(self):
        """Return the number of examples."""
        return len(self.all_token_ids)

Loader factory

Download corpus → tokenize → generate NSP + MLM examples → return a callable that yields shuffled minibatches, plus the vocabulary:

def load_data_wiki(batch_size, max_len):
    """Load the WikiText-2 dataset.

    Args:
        batch_size: Number of examples in every yielded minibatch.
        max_len: Maximum sequence length passed to the dataset.

    Returns:
        A tuple `(data_iter, vocab)`. Calling `data_iter()` gives a fresh
        iterator over minibatches, each a tuple of seven stacked arrays in
        the field order of the dataset. `vocab` is the dataset vocabulary.
    """
    train_set = _WikiTextDataset(_read_wiki(), max_len)
    # The visiting order is drawn once here and reused by every iterator
    order = list(range(len(train_set)))
    random.shuffle(order)

    # A callable is returned instead of a generator object because a
    # generator is exhausted after one pass, while multi-epoch training
    # needs to restart iteration.
    def data_iter():
        for start in range(0, len(order), batch_size):
            chunk = order[start:start + batch_size]
            # An incomplete trailing chunk is not yielded
            if len(chunk) >= batch_size:
                rows = [train_set[idx] for idx in chunk]
                # Regroup the rows field by field and stack each field
                yield tuple(jnp.stack(list(column)) for column in zip(*rows))
    return data_iter, train_set.vocab

Inspect a minibatch

Verify shapes: tokens, segments, valid_lens, pred_positions, mlm_weights, mlm_labels, nsp_labels. mlm_weights is 1 for real prediction slots and 0 for padded ones, so only real predictions contribute to the MLM loss:

# Minibatch size and the fixed length every sequence is padded to
batch_size, max_len = 512, 64
# train_iter is a callable returning a fresh iterator on each call.
train_iter, vocab = load_data_wiki(batch_size, max_len)

# Print the shape of each of the seven arrays in the first minibatch only
for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter():
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    # One minibatch is enough to check the shapes
    break
(512, 64) (512, 64) (512,) (512, 10) (512, 10) (512, 10) (512,)
len(vocab)
20257

Recap

  • A BERT minibatch carries seven tensors: tokens, segments, valid_lens, masked positions, MLM weights, MLM labels, NSP label.
  • 15% MLM masking with the 80/10/10 split is the original recipe; modern variants such as RoBERTa drop NSP and re-sample the mask dynamically during training instead of fixing it once, and later work also experiments with higher masking rates.
  • Pretraining BERT for real takes 16+ TPU/GPU-days; the WikiText-2 demo is small enough to play with.