Loading pretrained BERT

Natural Language Inference: Fine-Tuning BERT

BERT for NLI

A full-size pretrained BERT reaches near state-of-the-art NLI accuracy after brief fine-tuning, with one trick: feed <cls> premise <sep> hypothesis <sep> and put a 3-way classifier on the <cls> token.

This illustrates why BERT mattered: arbitrary sentence-pair classification reduces to a few lines of fine-tuning code on top of a pretrained encoder.

Pipeline

BERT encoder + small MLP head on <cls>.

Setup

from d2l import mxnet as d2l
import json
import multiprocessing
from mxnet import gluon, np, npx
from mxnet.gluon import nn
import os

npx.set_np()

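We use a small pretrained BERT (the one we trained ourselves in the previous chapter, or a downloaded checkpoint). The framework-specific checkpoint conversion helpers are implementation plumbing, so the slide shows only the teaching contract. The code in later sections assumes this step has already produced three names: bert (the loaded model), vocab (its vocabulary), and devices (the list of devices the parameters live on).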

  • register a checkpoint URL and checksum;
  • load the vocabulary;
  • instantiate the same BERT architecture;
  • copy pretrained weights into the encoder.
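
The "load the vocabulary" step can be sketched in plain Python. This is a minimal illustration, not the d2l helper: it assumes the checkpoint ships its vocabulary as a JSON list of tokens in index order, and the names TinyVocab, load_vocab, and demo_vocab are hypothetical.

```python
import json
import os
import tempfile


class TinyVocab:
    """Hypothetical stand-in for a vocabulary object (not the d2l class)."""

    def __init__(self, idx_to_token):
        self.idx_to_token = list(idx_to_token)
        # Invert the list so tokens can be mapped back to integer ids
        self.token_to_idx = {
            token: idx for idx, token in enumerate(self.idx_to_token)}
        self.unk = self.token_to_idx['<unk>']

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        # Accept a single token or a list of tokens
        if isinstance(tokens, (list, tuple)):
            return [self[token] for token in tokens]
        return self.token_to_idx.get(tokens, self.unk)


def load_vocab(path):
    """Read a JSON list of tokens (assumed file layout) into a TinyVocab."""
    with open(path, encoding='utf-8') as f:
        return TinyVocab(json.load(f))


# Demo with a throwaway file standing in for the checkpoint's vocabulary
with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = os.path.join(tmp_dir, 'vocab.json')
    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(['<unk>', '<pad>', '<cls>', '<sep>', 'a', 'dog'], f)
    demo_vocab = load_vocab(vocab_path)

# 'cat' is not in the toy vocabulary, so it maps to the '<unk>' id
demo_ids = demo_vocab[['<cls>', 'a', 'cat', '<sep>']]
print(demo_ids)  # [2, 4, 0, 3]
```

The model must then be instantiated with a vocabulary size equal to len(demo_vocab) so that the pretrained embedding table lines up with these ids.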

Instantiate pretrained BERT

The loaded encoder returns contextual token representations and the <cls> representation. Fine-tuning reuses that backbone and adds only a small task head.

Encoding sentence pairs

Tokenize each (premise, hypothesis) pair into BERT input format: <cls> + premise + <sep> + hypothesis + <sep> with segment IDs distinguishing the two halves:

class SNLIBERTDataset(gluon.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        
        self.labels = np.array(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        # Use 4 worker processes; the context manager releases them afterwards
        with multiprocessing.Pool(4) as pool:
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (np.array(all_token_ids, dtype='int32'),
                np.array(all_segments, dtype='int32'), 
                np.array(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve three slots for the '<cls>', '<sep>', and '<sep>' tokens of
        # the BERT input
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

# Reduce `batch_size` if there is an out-of-memory error. In the original BERT
# model, `max_len` = 512.
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)
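
To make the truncate, concatenate, and pad steps concrete, here is a small framework-free sketch on token strings. The function encode_pair is a hypothetical helper written for this illustration; it is meant to mirror what _mp_worker does, except that it keeps tokens as strings instead of looking up vocabulary ids.

```python
def encode_pair(p_tokens, h_tokens, max_len, pad='<pad>'):
    """Build a BERT-style input from a premise and a hypothesis."""
    # Copy so the caller's lists are not modified by truncation
    p_tokens, h_tokens = list(p_tokens), list(h_tokens)
    # Leave room for '<cls>' and two '<sep>' tokens; trim the longer side,
    # and the hypothesis when both sides have equal length
    while len(p_tokens) + len(h_tokens) > max_len - 3:
        if len(p_tokens) > len(h_tokens):
            p_tokens.pop()
        else:
            h_tokens.pop()
    tokens = ['<cls>'] + p_tokens + ['<sep>'] + h_tokens + ['<sep>']
    # Segment 0 covers '<cls>' + premise + '<sep>'; segment 1 covers the rest
    segments = [0] * (len(p_tokens) + 2) + [1] * (len(h_tokens) + 1)
    valid_len = len(tokens)
    num_pad = max_len - valid_len
    return tokens + [pad] * num_pad, segments + [0] * num_pad, valid_len


premise = 'a man sleeps'.split()
hypothesis = 'a man rests'.split()

# Short pair: nothing is truncated, three padding positions are appended
tokens, segments, valid_len = encode_pair(premise, hypothesis, max_len=12)
print(tokens)
print(segments)   # [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0]
print(valid_len)  # 9

# Tight budget: one hypothesis token is dropped and no padding is needed
short_tokens, short_segments, short_len = encode_pair(
    premise, hypothesis, max_len=8)
print(short_tokens)
```

Padding positions reuse segment id 0, so valid_len is what tells the encoder which positions to ignore.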

Classifier head

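A tiny MLP on the <cls> representation with 3 outputs (entailment, contradiction, neutral). The head reuses the pretrained bert.hidden layer and adds one new 3-way output layer, which is the only part initialized from scratch. Encoder weights are fine-tuned end-to-end together with the head: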

class BERTClassifier(nn.Block):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Dense(3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
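        # Position 0 of every sequence holds the '<cls>' representation
        return self.output(self.hidden(encoded_X[:, 0, :]))

net = BERTClassifier(bert)
# Only the new output layer needs initialization; the encoder and hidden
# layer already carry pretrained weights
net.output.initialize(ctx=devices)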
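
What the head computes for a single example can be sketched without any framework. The snippet below is illustrative only: the weights are random rather than pretrained, all sizes are toy values, and it assumes the hidden layer is a dense layer followed by tanh.

```python
import math
import random


def dense(x, weights, bias):
    """Affine layer: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)]
            for _ in range(rows)]


rng = random.Random(0)
seq_len, num_hiddens, num_classes = 6, 4, 3  # toy sizes, not BERT's

# Stand-in for the encoder output of one example: (seq_len, num_hiddens)
encoded = rand_matrix(seq_len, num_hiddens, rng)
cls_vec = encoded[0]  # the '<cls>' position summarizes the whole pair

W_hidden, b_hidden = rand_matrix(num_hiddens, num_hiddens, rng), [0.0] * num_hiddens
W_out, b_out = rand_matrix(num_classes, num_hiddens, rng), [0.0] * num_classes

hidden = [math.tanh(h) for h in dense(cls_vec, W_hidden, b_hidden)]
logits = dense(hidden, W_out, b_out)  # one score per NLI class
probs = softmax(logits)
print(len(logits), round(sum(probs), 6))
```

During training the softmax is folded into the cross-entropy loss, so the network itself only needs to return the three logits.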

Fine-tuning

Standard cross-entropy loss with Adam and a low learning rate. Full-size BERT is usually fine-tuned at about 2e-5 to 5e-5; the small model here uses 1e-4, divided by the batch size of 512 in the code. A few epochs are enough: the model already knows language, and we are only teaching it the specific task. Validation accuracy is the main signal, since training loss can keep falling after the classifier starts overfitting SNLI artifacts:

# lr is 1e-4 divided by batch_size (512): gluon Trainer no longer rescales
# (issue 7 fix in d2l.train_batch_ch13)
lr, num_epochs = 1.953125e-7, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
               d2l.split_batch_multi_inputs)
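
The odd-looking learning rate is easier to read once the batch-size division is undone. A quick arithmetic check, assuming the value in the code is an undivided rate split evenly across the 512 examples of a batch:

```python
import math

batch_size = 512
lr_in_code = 1.953125e-7

# Undo the division by batch_size to recover the undivided rate
undivided_lr = lr_in_code * batch_size
assert math.isclose(undivided_lr, 1e-4)

# Compare with the 2e-5 commonly quoted for full-size BERT
ratio_to_2e5 = undivided_lr / 2e-5
print(undivided_lr, ratio_to_2e5)
```

So the small model is trained at roughly five times the rate commonly quoted for full-size BERT.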

Recap

  • Sentence-pair classification = encode <cls> A <sep> B <sep>, classify the <cls> representation.
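  • Same recipe handles NLI, paraphrase detection, semantic similarity, and other sentence-pair tasks.
  • Typical fine-tuning hyperparameters for full-size BERT: batch size 16-32, lr 2e-5 to 5e-5, 2-4 epochs. The small model in this section uses batch size 512, lr 1e-4, and 5 epochs. Either way, fine-tuning is short and cheap compared with pretraining.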
  • The end of the pre-2018 NLI architecture wars: BERT made per-task model design largely obsolete.