from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
# Switch MXNet to its NumPy-compatible interface before any arrays are built.
npx.set_np()

# With the model (last deck) and the data (deck before that), we can finally
# pretrain a small BERT end-to-end. This deck does it on a tiny scale:
# 2 layers, 128 hidden dim, 2 heads. The recipe scales to BERT-Base
# (12 layers, 768 dim, 12 heads) and BERT-Large by just changing the config.
Each batch supplies tokens, segment IDs, valid lengths, masked positions/labels, MLM weights, and NSP labels:
The notebook uses a deliberately small encoder so the full pretraining loop is runnable in class:
Initialize the optimizer/trainer for this tiny BERT. Scaling to BERT-Base changes only the data size, model width/depth, and compute budget:
Two heads, one combined loss:
$$\mathcal{L} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}}.$$
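The MLM cross-entropy is averaged over the masked positions (the weighted sum divided by the total MLM weight); NSP applies the same cross-entropy loss to the binary is-next prediction from the <cls> head and averages it over the examples: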
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    """Compute the MLM, NSP and combined pretraining loss for every shard.

    The seven ``*_shards`` arguments are parallel sequences: element ``i`` of
    each one belongs to the same shard of the minibatch.

    Args:
        net: Callable invoked as ``net(tokens, segments, valid_lens,
            pred_positions)``; it must return three values, of which the
            second is used as the MLM prediction and the third as the NSP
            prediction.
        loss: Callable invoked as ``loss(pred, label, weight)`` for the MLM
            term and ``loss(pred, label)`` for the NSP term.
        vocab_size: Size of the last axis the MLM predictions are reshaped to.
        tokens_X_shards: Token inputs, one entry per shard.
        segments_X_shards: Segment inputs, one entry per shard.
        valid_lens_x_shards: Valid lengths, one entry per shard (flattened
            before being handed to ``net``).
        pred_positions_X_shards: Masked positions to predict, per shard.
        mlm_weights_X_shards: MLM weights per shard.
        mlm_Y_shards: MLM labels per shard.
        nsp_y_shards: NSP labels per shard.

    Returns:
        A tuple ``(mlm_ls, nsp_ls, ls)`` of three lists with one entry per
        shard: the MLM loss, the NSP loss, and their sum.
    """
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(
             tokens_X_shards, segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
             nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
            pred_positions_X_shard)
        # Compute masked language model loss: flatten predictions, labels and
        # weights so every predicted position is one row.
        mlm_l = loss(
            mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
            mlm_weights_X_shard.reshape((-1, 1)))
        # Normalize by the total weight; the epsilon keeps the division
        # defined when every weight in the shard is zero.
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
    return mlm_ls, nsp_ls, ls


# Training uses Adam with a fixed learning rate (no warmup schedule); on this
# tiny corpus a few hundred steps are enough to see both losses drop. MLM loss
# stays higher than NSP because it predicts over a large vocabulary rather
# than a binary label:
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    """Pretrain ``net`` on the combined MLM + NSP loss for ``num_steps`` steps.

    ``train_iter`` is iterated again from the start each time it is exhausted
    until ``num_steps`` parameter updates have been made. The running mean of
    the MLM and NSP losses is plotted after every step, and a loss and
    throughput summary is printed at the end.

    Args:
        train_iter: Iterable of 7-element batches, in the order tokens,
            segments, valid lengths, prediction positions, MLM weights,
            MLM labels, NSP labels.
        net: Model to train; all of its parameters are optimized with Adam.
        loss: Loss callable forwarded to ``_get_batch_loss_bert``.
        vocab_size: Vocabulary size forwarded to ``_get_batch_loss_bert``.
        devices: Devices every batch element is split across.
        num_steps: Number of parameter updates to perform.

    Raises:
        ValueError: If a pass over ``train_iter`` yields no batches before
            ``num_steps`` is reached (the loop could otherwise never finish).
    """
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 1e-4})
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        saw_batch = False
        for batch in train_iter:
            saw_batch = True
            (tokens_X_shards, segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards) = [gluon.utils.split_and_load(
                 elem, devices, even_split=False) for elem in batch]
            timer.start()
            with autograd.record():
                mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
                    net, loss, vocab_size, tokens_X_shards, segments_X_shards,
                    valid_lens_x_shards, pred_positions_X_shards,
                    mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards)
            for shard_loss in ls:
                shard_loss.backward()
            # The per-shard losses are already normalized, so the step is
            # taken without further batch-size rescaling.
            trainer.step(1)
            # Every shard loss lives on the device it was computed on. Read
            # each one back as a Python float before averaging, so arrays
            # from different devices are never added together.
            mlm_l_mean = sum(float(shard_l) for shard_l in mlm_ls) / len(mlm_ls)
            nsp_l_mean = sum(float(shard_l) for shard_l in nsp_ls) / len(nsp_ls)
            metric.add(mlm_l_mean, nsp_l_mean, batch[0].shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
        if not saw_batch:
            # An empty (or already exhausted one-shot) iterator would make
            # the outer loop spin forever without ever advancing `step`.
            raise ValueError(
                'train_iter yielded no batches before num_steps was reached')

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')


# After pretraining, the encoder is the useful part: it turns token sequences
# into contextual representations. The pretraining heads can be discarded for
# most downstream tasks:
def get_bert_encoding(net, tokens_a, tokens_b=None):
    """Return the encoder output of ``net`` for one sentence or sentence pair.

    Note: this function reads the module-level names ``vocab`` (to map tokens
    to indices) and ``devices`` (the first entry is where the inputs are
    placed); both must be defined before it is called.

    Args:
        net: Model called as ``net(token_ids, segments, valid_len)``; the
            first of its three return values is returned.
        tokens_a: Tokens of the first sentence.
        tokens_b: Optional tokens of the second sentence.

    Returns:
        The first value returned by ``net`` for a batch holding this single
        input sequence.
    """
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    # Each input gets a leading axis of size 1: a batch with one sequence.
    token_ids = np.expand_dims(np.array(vocab[tokens], ctx=devices[0]),
                               axis=0)
    segments = np.expand_dims(np.array(segments, ctx=devices[0]), axis=0)
    valid_len = np.expand_dims(np.array(len(tokens), ctx=devices[0]), axis=0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X


# “a crane is flying” → 6 hidden vectors (one per token, including <cls> and
# <sep>). Each is contextual — the representation of “crane” depends on its
# neighbors:
# Encode a single sentence with the pretrained model.
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
# Position 0 is '<cls>' and position 2 is 'crane' in the sequence above.
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
# Bare expression: shown as the cell output when run in a notebook.
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]

# “a crane driver came” / “he just left”. Same encoder, two-segment input —
# segment IDs distinguish the two halves inside the same sequence:
# Encode a sentence pair: both sentences go through the encoder as one
# sequence.
tokens_a = ['a', 'crane', 'driver', 'came']
tokens_b = ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Token sequence:
#   '<cls>', 'a', 'crane', 'driver', 'came', '<sep>',
#   'he', 'just', 'left', '<sep>'
# so index 0 is '<cls>' and index 2 is 'crane'.
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
# Bare expression: shown as the cell output when run in a notebook.
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]