from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import os
import random
The previous deck specified BERT’s model. This one specifies the data: how to turn raw text into the (token IDs, segment IDs, valid lengths, MLM prediction positions, MLM weights, MLM labels, NSP label) tuples that the pretraining loop expects.
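We use WikiText-2 — a small, readable Wikipedia subset. Real BERT was pretrained on BooksCorpus + English Wikipedia (~3.3B words); the recipe is the same apart from scale and tokenization (BERT uses a 30k-token WordPiece vocabulary, whereas here we use whole-word tokens).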
WikiText-2 preserves the original punctuation, case, and numbers; the loader below lowercases the text and returns each paragraph as a list of sentences (split on ' . ') so that NSP can sample adjacent or random sentence pairs:
# Parquet export of the `wikitext-2-v1` training split on the Hugging Face Hub.
WIKITEXT_2_URL = (
    'https://huggingface.co/datasets/Salesforce/wikitext/resolve/main/'
    'wikitext-2-v1/train-00000-of-00001.parquet'
)
def _read_wiki(data_dir=None):
    """Download WikiText-2 and split it into paragraphs of sentences.

    Args:
        data_dir: Folder to download the parquet file into. ``None`` (the
            default) keeps the original download folder ``'../data'``.

    Returns:
        A randomly shuffled list of paragraphs. Each paragraph is a list of
        lowercased sentence strings, obtained by stripping a text line and
        splitting it on ``' . '``.
    """
    import contextlib
    import io
    import pandas as pd
    folder = '../data' if data_dir is None else data_dir
    # Discard anything d2l.download prints to stdout.
    with contextlib.redirect_stdout(io.StringIO()):
        fname = d2l.download(WIKITEXT_2_URL, folder=folder)
    lines = pd.read_parquet(fname)['text'].tolist()
    # Keep lines containing at least one ' . ' separator. NOTE: the count is
    # taken on the raw (unstripped) line, so a trailing ' . ' before the line
    # end also counts; a kept paragraph may therefore hold a single sentence.
    # Uppercase letters are converted to lowercase ones.
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs


# For each pair of consecutive sentences: with probability 0.5 keep the true
# next sentence (is_next=1); otherwise pair the first sentence with a random
# sentence (is_next=0):
def _get_next_sentence(sentence, next_sentence, paragraphs):
    """Build one next-sentence-prediction (NSP) pair.

    Args:
        sentence: First sentence of the pair.
        next_sentence: The sentence that actually follows ``sentence``.
        paragraphs: All paragraphs; each is a list of sentences.

    Returns:
        ``(sentence, second_sentence, is_next)``. With probability 0.5 the
        real ``next_sentence`` is kept and ``is_next`` is True; otherwise the
        second sentence is a random sentence from a randomly chosen paragraph
        and ``is_next`` is False (even if that draw happens to coincide with
        the true next sentence).
    """
    if random.random() < 0.5:
        is_next = True
    else:
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next


def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Generate NSP examples from consecutive sentences of one paragraph.

    Args:
        paragraph: List of sentences (each a list of tokens).
        paragraphs: All paragraphs, used to draw random second sentences.
        vocab: Not used here; kept for interface compatibility.
        max_len: Maximum pair length, counting one '<cls>' and two '<sep>'.

    Returns:
        List of ``(tokens, segments, is_next)`` tuples. Pairs that would
        exceed ``max_len`` are skipped.
    """
    nsp_data_from_paragraph = []
    for sentence, following in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            sentence, following, paragraphs)
        # Consider 1 '<cls>' token and 2 '<sep>' tokens
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph


# Pick 15% of token positions for MLM prediction. For those:
# 80% of the time replace the token with '<mask>', 10% of the time with a
# random vocabulary token, and 10% of the time leave it unchanged.
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    """Corrupt a token sequence for the masked language modeling (MLM) task.

    Args:
        tokens: List of token strings; it is not modified.
        candidate_pred_positions: Indices eligible for prediction. The
            caller's list is left untouched (it used to be shuffled in place).
        num_mlm_preds: Number of positions to select at most.
        vocab: Vocabulary; ``vocab.idx_to_token`` supplies random tokens.

    Returns:
        ``(mlm_input_tokens, pred_positions_and_labels)``: a corrupted copy of
        ``tokens`` and the list of ``(position, original_token)`` pairs in the
        order they were selected.
    """
    mlm_input_tokens = list(tokens)
    # Shuffle a private copy so the caller's list keeps its order; the RNG
    # draws are the same as shuffling the original list.
    positions = list(candidate_pred_positions)
    random.shuffle(positions)
    pred_positions_and_labels = []
    for mlm_pred_position in positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        if random.random() < 0.8:
            # 80% of the time: replace the word with the '<mask>' token
            masked_token = '<mask>'
        elif random.random() < 0.5:
            # 10% of the time: keep the word unchanged
            masked_token = tokens[mlm_pred_position]
        else:
            # 10% of the time: replace the word with a random word
            masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels


def _get_mlm_data_from_tokens(tokens, vocab):
    """Build MLM inputs and targets for one token sequence.

    Args:
        tokens: List of token strings, possibly including '<cls>'/'<sep>'.
        vocab: Vocabulary mapping a list of tokens to a list of indices.

    Returns:
        ``(input_ids, pred_positions, label_ids)``: indices of the corrupted
        tokens, the selected positions in ascending order, and the indices of
        the original tokens at those positions.
    """
    # Special tokens are not predicted in the masked language modeling task
    candidate_pred_positions = [i for i, token in enumerate(tokens)
                                if token not in ('<cls>', '<sep>')]
    # 15% of the tokens (at least one) are requested for prediction
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    pred_positions_and_labels = sorted(pred_positions_and_labels,
                                       key=lambda x: x[0])
    pred_positions = [pos for pos, _ in pred_positions_and_labels]
    mlm_pred_labels = [label for _, label in pred_positions_and_labels]
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]


# Pad every example to max_len; record valid_lens (the unpadded length) for
# attention masking; pad MLM positions and labels with 0 and give those padded
# slots a weight of 0 so the loss ignores them:
def _pad_bert_inputs(examples, max_len, vocab):
    """Pad BERT pretraining examples to fixed lengths.

    Args:
        examples: Iterable of ``(token_ids, pred_positions,
            mlm_pred_label_ids, segments, is_next)`` tuples of lists.
        max_len: Length every token/segment sequence is padded to.
        vocab: Vocabulary; ``vocab['<pad>']`` is the token padding index.

    Returns:
        Seven lists holding one JAX array per example: token ids and segment
        ids (int32, length ``max_len``), valid lengths (float32 scalars), MLM
        prediction positions (int32), MLM weights (float32) and MLM labels
        (int32), each of length ``round(max_len * 0.15)``, and NSP labels
        (int32 scalars).

    Raises:
        ValueError: If an example has more than ``max_len`` tokens or segment
            ids, or more than ``round(max_len * 0.15)`` MLM predictions. Such
            examples would otherwise yield arrays of inconsistent length.
    """
    max_num_mlm_preds = round(max_len * 0.15)
    pad_id = vocab['<pad>']
    all_token_ids, all_segments, valid_lens = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments,
         is_next) in examples:
        num_tokens = len(token_ids)
        num_preds = len(pred_positions)
        num_labels = len(mlm_pred_label_ids)
        if num_tokens > max_len or len(segments) > max_len:
            raise ValueError(
                f'example has {num_tokens} tokens and {len(segments)} segment '
                f'ids; at most max_len={max_len} are allowed')
        if max(num_preds, num_labels) > max_num_mlm_preds:
            raise ValueError(
                f'example has {num_preds} MLM predictions; at most '
                f'{max_num_mlm_preds} (= round(max_len * 0.15)) are allowed')
        all_token_ids.append(jnp.array(
            token_ids + [pad_id] * (max_len - num_tokens), dtype=jnp.int32))
        all_segments.append(jnp.array(
            segments + [0] * (max_len - len(segments)), dtype=jnp.int32))
        # `valid_lens` excludes count of '<pad>' tokens
        valid_lens.append(jnp.array(num_tokens, dtype=jnp.float32))
        all_pred_positions.append(jnp.array(
            pred_positions + [0] * (max_num_mlm_preds - num_preds),
            dtype=jnp.int32))
        # Predictions of padded slots are filtered out in the loss via
        # multiplication by 0 weights; weights are aligned with the labels.
        all_mlm_weights.append(jnp.array(
            [1.0] * num_labels + [0.0] * (max_num_mlm_preds - num_labels),
            dtype=jnp.float32))
        all_mlm_labels.append(jnp.array(
            mlm_pred_label_ids + [0] * (max_num_mlm_preds - num_labels),
            dtype=jnp.int32))
        nsp_labels.append(jnp.array(is_next, dtype=jnp.int32))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels, nsp_labels)


# Wrap the preprocessed examples in a map-style dataset (`__getitem__` and
# `__len__`), the same indexing interface as a PyTorch `Dataset`:
class _WikiTextDataset:
    """Map-style dataset of padded BERT pretraining examples.

    All preprocessing (tokenization, vocabulary construction, NSP pairing, MLM
    corruption and padding) runs eagerly in the constructor. Indexing returns
    a 7-tuple: token ids, segment ids, valid length, MLM prediction positions,
    MLM weights, MLM labels and the NSP label.
    """

    def __init__(self, paragraphs, max_len):
        # Tokenize: each paragraph goes from a list of sentence strings to a
        # list of sentences, where every sentence is a list of word tokens.
        tokenized = [d2l.tokenize(paragraph, token='word')
                     for paragraph in paragraphs]
        every_sentence = [sentence for paragraph in tokenized
                          for sentence in paragraph]
        self.vocab = d2l.Vocab(every_sentence, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Next sentence prediction pairs: (tokens, segments, is_next)
        nsp_examples = []
        for paragraph in tokenized:
            nsp_examples.extend(_get_nsp_data_from_paragraph(
                paragraph, tokenized, self.vocab, max_len))
        # Attach masked-LM data to every pair, giving
        # (token_ids, pred_positions, mlm_label_ids, segments, is_next)
        combined = []
        for tokens, segments, is_next in nsp_examples:
            combined.append(
                (*_get_mlm_data_from_tokens(tokens, self.vocab),
                 segments, is_next))
        padded = _pad_bert_inputs(combined, max_len, self.vocab)
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = padded

    def __getitem__(self, idx):
        columns = (self.all_token_ids, self.all_segments, self.valid_lens,
                   self.all_pred_positions, self.all_mlm_weights,
                   self.all_mlm_labels, self.nsp_labels)
        return tuple(column[idx] for column in columns)

    def __len__(self):
        return len(self.all_token_ids)


# Download corpus → tokenize → generate NSP + MLM examples → shuffled
# mini-batch iterator:
def load_data_wiki(batch_size, max_len):
    """Load the WikiText-2 dataset.

    Args:
        batch_size: Number of examples per batch. A trailing partial batch is
            dropped, so every yielded batch has exactly ``batch_size`` rows.
        max_len: Maximum (padded) sequence length.

    Returns:
        ``(data_iter, vocab)``. ``data_iter`` is a callable; each call returns
        a fresh generator over batches, each batch being a tuple of seven
        stacked arrays (token ids, segment ids, valid lengths, MLM prediction
        positions, MLM weights, MLM labels, NSP labels). Every call reshuffles
        the example order, so successive epochs see different batches
        (previously the order was fixed once at load time).
    """
    paragraphs = _read_wiki()
    train_set = _WikiTextDataset(paragraphs, max_len)
    num_examples = len(train_set)

    # Return a callable so each call yields a fresh iterator (one-shot
    # generators can't be re-entered for multi-epoch training).
    def data_iter():
        # Reshuffle on every call so each epoch uses a new order.
        indices = list(range(num_examples))
        random.shuffle(indices)
        # Stop before a trailing partial batch.
        for start in range(0, num_examples - batch_size + 1, batch_size):
            batch = [train_set[idx]
                     for idx in indices[start:start + batch_size]]
            # zip(*batch) regroups the per-example 7-tuples field by field.
            yield tuple(jnp.stack(field) for field in zip(*batch))

    return data_iter, train_set.vocab


# Verify shapes: tokens, segments, valid_lens, pred_positions, mlm_weights,
# mlm_labels, nsp_labels. mlm_weights is 1 for real prediction slots and 0 for
# padded ones, so only the real slots contribute to the MLM loss:
batch_size, max_len = 512, 64
# train_iter is a callable returning a fresh iterator on each call.
train_iter, vocab = load_data_wiki(batch_size, max_len)
# Inspect only the first mini-batch (if there is one).
first_batch = next(train_iter(), None)
if first_batch is not None:
    (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) = first_batch
    print(*(field.shape for field in (
        tokens_X, segments_X, valid_lens_x, pred_positions_X,
        mlm_weights_X, mlm_Y, nsp_y)))
# Output captured from a run:
# (512, 64) (512, 64) (512,) (512, 10) (512, 10) (512, 10) (512,)