Natural Language Inference: Fine-Tuning BERT

BERT for NLI

Fine-tuned BERT handles NLI near state-of-the-art with one trick: feed <cls> premise <sep> hypothesis <sep> and put a 3-way classifier on the <cls> token.

This is a clear illustration of why BERT mattered: classifying an arbitrary sentence pair reduces to a few lines of fine-tuning on a pretrained encoder.
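To make the input layout concrete, here is a minimal pure-Python sketch of how one pair is packed and how segment IDs separate the two sentences. `pack_pair` is a hypothetical helper written for illustration; the dataset code later in this section does the same job with `d2l.get_tokens_and_segments`.

```python
# Hypothetical helper: packs a (premise, hypothesis) pair into BERT's layout
def pack_pair(premise_tokens, hypothesis_tokens):
    """Return (tokens, segments) for <cls> premise <sep> hypothesis <sep>."""
    tokens = ['<cls>'] + premise_tokens + ['<sep>']
    # Segment 0 covers <cls>, the premise, and the first <sep>
    segments = [0] * len(tokens)
    tokens += hypothesis_tokens + ['<sep>']
    # Segment 1 covers the hypothesis and the final <sep>
    segments += [1] * (len(hypothesis_tokens) + 1)
    return tokens, segments


tokens, segments = pack_pair(['a', 'dog', 'runs'], ['an', 'animal', 'moves'])
print(tokens)
# ['<cls>', 'a', 'dog', 'runs', '<sep>', 'an', 'animal', 'moves', '<sep>']
print(segments)
# [0, 0, 0, 0, 0, 1, 1, 1, 1]
```

The segment IDs are what let a single encoder tell which tokens belong to the premise and which to the hypothesis.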

Pipeline

BERT encoder + small MLP head on <cls>.

Setup

from d2l import tensorflow as d2l
import tensorflow as tf
import keras
import numpy as np
import json
import multiprocessing
import os

We use a small pretrained BERT (the one we trained ourselves in the previous chapter, or a downloaded checkpoint). The framework-specific checkpoint conversion helpers are implementation plumbing, so the slide shows only the teaching contract:

  • register a checkpoint URL and checksum;
  • load the vocabulary;
  • instantiate the same BERT architecture;
  • copy pretrained weights into the encoder.
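Registering a checksum alongside the URL lets a corrupted or stale download be caught before its weights are copied into the encoder. A minimal standard-library sketch of that check, assuming SHA-1 digests; the `CHECKPOINTS` registry, `register`, `verify`, and the example URLs are hypothetical, not d2l's API:

```python
import hashlib
import os
import tempfile

# Hypothetical registry: checkpoint name -> (url, expected SHA-1 hex digest)
CHECKPOINTS = {}


def register(name, url, sha1_hex):
    """Record where a checkpoint lives and what its bytes should hash to."""
    CHECKPOINTS[name] = (url, sha1_hex)


def sha1_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so a large checkpoint is never read into memory at once."""
    h = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


def verify(name, path):
    """Return True if the local file matches the registered checksum."""
    _, expected = CHECKPOINTS[name]
    return sha1_of_file(path) == expected


# Usage with a stand-in file instead of a real download
payload = b'stand-in checkpoint bytes'
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(payload)
    path = f.name
register('bert.small', 'https://example.com/bert.small.zip',  # placeholder URL
         hashlib.sha1(payload).hexdigest())
ok = verify('bert.small', path)
register('bert.other', 'https://example.com/other.zip', '0' * 40)  # wrong digest
mismatch = verify('bert.other', path)
os.remove(path)
print(ok, mismatch)  # True False
```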

Instantiate pretrained BERT

The loaded encoder maps every token to a contextual representation; the one at position 0 is the <cls> representation. This step provides the `bert` model and `vocab` used below. Fine-tuning reuses that backbone and adds only a small task head.

Encoding sentence pairs

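Tokenize each (premise, hypothesis) pair into BERT's input format, <cls> + premise + <sep> + hypothesis + <sep>, with segment IDs (0 for <cls>, the premise, and the first <sep>; 1 for the hypothesis and the final <sep>) distinguishing the two halves. Each pair is truncated if needed and padded to max_len: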

class SNLIBERTDataset:
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = np.array(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        # Encode pairs in parallel with 4 worker processes; the `with` block
        # shuts the pool down once all results have been collected
        with multiprocessing.Pool(4) as pool:
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (np.array(all_token_ids, dtype='int32'),
                np.array(all_segments, dtype='int32'),
                np.array(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        # Pad token IDs and segment IDs up to max_len
        pad_len = self.max_len - len(tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * pad_len
        segments = segments + [0] * pad_len
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for the '<cls>' token and the two '<sep>' tokens of
        # the BERT input; trim the longer sequence (the hypothesis on ties)
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len = 512, 128
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
AUTOTUNE = tf.data.AUTOTUNE
train_iter = tf.data.Dataset.from_tensor_slices(
    (train_set.all_token_ids, train_set.all_segments,
     train_set.valid_lens, train_set.labels)
).shuffle(buffer_size=len(train_set.labels)).batch(batch_size).prefetch(AUTOTUNE)
test_iter = tf.data.Dataset.from_tensor_slices(
    (test_set.all_token_ids, test_set.all_segments,
     test_set.valid_lens, test_set.labels)
).batch(batch_size).prefetch(AUTOTUNE)
read 549367 examples
read 9824 examples
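The per-example work in `_mp_worker` can be mirrored in plain Python to watch truncation and padding on a toy pair. This is a sketch under stated assumptions: `truncate_pair`, `encode_pair`, and `toy_vocab` are hypothetical stand-ins for the class methods, `d2l.get_tokens_and_segments`, and the pretrained `vocab`; unlike `_truncate_pair_of_tokens`, it copies the lists instead of trimming them in place.

```python
# Hypothetical toy vocabulary standing in for the pretrained `vocab`
toy_vocab = {'<pad>': 0, '<unk>': 1, '<cls>': 2, '<sep>': 3,
             'a': 4, 'man': 5, 'is': 6, 'sleeping': 7, 'awake': 8}


def truncate_pair(p_tokens, h_tokens, max_len):
    """Trim the longer sequence (the hypothesis on ties) until the pair plus
    <cls> and two <sep> tokens fits in max_len."""
    p_tokens, h_tokens = list(p_tokens), list(h_tokens)
    while len(p_tokens) + len(h_tokens) > max_len - 3:
        if len(p_tokens) > len(h_tokens):
            p_tokens.pop()
        else:
            h_tokens.pop()
    return p_tokens, h_tokens


def encode_pair(p_tokens, h_tokens, vocab, max_len):
    """Return (token_ids, segments, valid_len), padded to max_len."""
    p_tokens, h_tokens = truncate_pair(p_tokens, h_tokens, max_len)
    tokens = ['<cls>'] + p_tokens + ['<sep>'] + h_tokens + ['<sep>']
    segments = [0] * (len(p_tokens) + 2) + [1] * (len(h_tokens) + 1)
    valid_len = len(tokens)  # real tokens, excluding padding
    pad_len = max_len - valid_len
    token_ids = [vocab.get(t, vocab['<unk>']) for t in tokens]
    return (token_ids + [vocab['<pad>']] * pad_len,
            segments + [0] * pad_len,
            valid_len)


premise = ['a', 'man', 'is', 'sleeping']
hypothesis = ['a', 'man', 'is', 'awake']
ids, segs, valid_len = encode_pair(premise, hypothesis, toy_vocab, max_len=12)
print(ids, segs, valid_len)
# [2, 4, 5, 6, 7, 3, 4, 5, 6, 8, 3, 0] [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0] 11
ids8, segs8, valid_len8 = encode_pair(premise, hypothesis, toy_vocab, max_len=8)
print(ids8, segs8, valid_len8)
# [2, 4, 5, 6, 3, 4, 5, 3] [0, 0, 0, 0, 0, 1, 1, 1] 8
```

Padding positions get segment ID 0 and fall beyond valid_len, which the encoder receives so that attention can ignore them.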

Classifier head

Small MLP on the <cls> representation: the pretrained `bert.hidden` layer followed by a new 3-way output layer (entailment, contradiction, neutral). All weights, encoder included, are fine-tuned end-to-end:

class BERTClassifier(keras.Model):
    def __init__(self, bert):
        super().__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output_layer = keras.layers.Dense(3)

    def call(self, inputs, training=False):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x,
                                 training=training)
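        # The <cls> token sits at position 0 of every sequence
        return self.output_layer(self.hidden(encoded_X[:, 0, :]))

net = BERTClassifier(bert)
# Call the model once on dummy inputs so Keras creates the new output
# layer's weights (subclassed-model weights are built lazily on first call)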
dummy_tokens = tf.ones((2, max_len), dtype=tf.int32)
dummy_segments = tf.zeros((2, max_len), dtype=tf.int32)
dummy_valid_lens = tf.cast(tf.fill((2,), max_len), dtype=tf.float32)
net((dummy_tokens, dummy_segments, dummy_valid_lens), training=False)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 1.5142541, -1.0775912,  2.4618201],
       [ 1.5142541, -1.0775912,  2.4618201]], dtype=float32)>
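Stripped of framework code, the head only reads the vector at position 0. A NumPy sketch with random stand-in weights, assuming (as in BERT's original pooler) that the pretrained hidden layer is a dense layer with tanh activation; all sizes and the `head` function are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, num_hiddens, num_classes = 2, 128, 256, 3  # illustration sizes

# Random stand-ins for the pretrained hidden layer and the new output layer
W_h = rng.normal(scale=0.02, size=(num_hiddens, num_hiddens))
b_h = np.zeros(num_hiddens)
W_o = rng.normal(scale=0.02, size=(num_hiddens, num_classes))
b_o = np.zeros(num_classes)


def head(encoded_X):
    """Map encoder output (batch, seq_len, num_hiddens) to (batch, 3) logits."""
    cls_repr = encoded_X[:, 0, :]             # <cls> representation
    hidden = np.tanh(cls_repr @ W_h + b_h)    # assumed tanh hidden layer
    return hidden @ W_o + b_o                 # new 3-way output layer


encoded_X = rng.normal(size=(batch, seq_len, num_hiddens))  # encoder output stand-in
logits = head(encoded_X)
print(logits.shape)  # (2, 3)
```

Because only position 0 is read, the other tokens influence the prediction solely through self-attention inside the encoder.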

Fine-tuning

Standard cross-entropy + Adam with a low learning rate: the BERT paper searches 5e-5, 3e-5, and 2e-5 for full-size models, while the small model here uses 1e-4. A few epochs are enough; the model already knows language, and we’re only teaching it the specific task. Validation accuracy (here measured on the SNLI test split) is the main signal, since training loss can keep falling after the classifier starts overfitting SNLI artifacts:

lr, num_epochs = 1e-4, 5
# Keras expects (inputs, labels) pairs, so regroup each tf.data batch from
# (tokens_X, segments_X, valid_lens_x, labels) into
# ((tokens_X, segments_X, valid_lens_x), labels)
def reformat(tokens_X, segments_X, valid_lens_x, labels):
    return (tokens_X, segments_X, valid_lens_x), labels

train_iter_tf = train_iter.map(reformat)
test_iter_tf = test_iter.map(reformat)
net.compile(
    optimizer=keras.optimizers.Adam(lr),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
net.fit(train_iter_tf, validation_data=test_iter_tf, epochs=num_epochs)
Epoch 1/5
accuracy: 0.3203 - loss: 1.1621
accuracy: 0.3255 - loss: 1.2187
accuracy: 0.3337 - loss: 1.2087
accuracy: 0.3401 - loss: 1.1982
accuracy: 0.3453 - loss: 1.1884
...
accuracy: 0.8120 - loss: 0.4751

Final epoch metrics: accuracy: 0.8089 - loss: 0.4801 - val_accuracy: 0.7817 - val_loss: 0.5477
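The reported loss can be reproduced by hand, which also explains the early epoch-1 numbers (accuracy about 0.33, loss around 1.2): uniform guessing over 3 classes gives accuracy 1/3 and loss ln 3 ≈ 1.10. A NumPy sketch of what SparseCategoricalCrossentropy(from_logits=True) computes with its default batch-mean reduction; `sparse_ce_from_logits` is a hypothetical name:

```python
import numpy as np


def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy from raw logits (batch, classes) and integer labels."""
    # Subtract each row's max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick each example's log-probability of its true class, then average
    return -log_probs[np.arange(len(labels)), labels].mean()


logits = np.array([[2.0, 0.0, 0.0],    # leans toward class 0 (correct)
                   [0.0, 0.0, 0.0]])   # uniform over 3 classes (label 2)
labels = np.array([0, 2])
loss = sparse_ce_from_logits(logits, labels)
print(round(loss, 3))  # 0.669
chance = sparse_ce_from_logits(np.zeros((1, 3)), np.array([0]))
print(round(chance, 4))  # 1.0986, i.e. ln 3
```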

Recap

  • Sentence-pair classification = encode <cls> A <sep> B <sep>, classify the <cls> representation.
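  • Same recipe handles NLI, paraphrase detection, semantic similarity, and other sentence-pair tasks.
  • Fine-tuning hyperparameters: the original BERT paper uses batch size 16 or 32, learning rate 2e-5 to 5e-5, and 2 to 4 epochs for full-size models; the small model here uses batch 512, lr 1e-4, and 5 epochs. Either way, fine-tuning is short and cheap compared with pretraining.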
  • The end of the pre-2018 NLI architecture wars: BERT made per-task model design largely obsolete.