from d2l import tensorflow as d2l
import tensorflow as tf
import keras
import numpy as np
import json
import multiprocessing
import os
Pretrained BERT reaches near state-of-the-art NLI accuracy after light fine-tuning, with one trick: feed <cls> premise <sep> hypothesis <sep> and put a 3-way classifier on the <cls> token.
This illustrates why BERT mattered: arbitrary sentence-pair classification reduces to a few lines of fine-tuning on a pretrained encoder.
BERT encoder + small MLP head on <cls>.
We use a small pretrained BERT (the one we trained ourselves in the previous chapter, or a downloaded checkpoint). The framework-specific checkpoint-conversion helpers are implementation plumbing, so we show only the interface the rest of this section relies on:
The loaded encoder returns contextual token representations and the <cls> representation. Fine-tuning reuses that backbone and adds only a small task head.
warnings.warn(
Tokenize each (premise, hypothesis) pair into BERT input format: <cls> + premise + <sep> + hypothesis + <sep> with segment IDs distinguishing the two halves:
class SNLIBERTDataset:
    """SNLI sentence pairs encoded as fixed-length BERT inputs.

    Each (premise, hypothesis) pair is lower-cased, tokenized, truncated so
    that it fits, and turned into ``<cls> premise <sep> hypothesis <sep>``
    by ``d2l.get_tokens_and_segments``. The result is padded with ``<pad>``
    up to ``max_len``.

    Args:
        dataset: Indexable whose items 0 and 1 are the premise and hypothesis
            sentence lists and whose item 2 is the label list (the layout
            returned by ``d2l.read_snli`` -- assumed, confirm if reused).
        max_len: Length every token-id / segment sequence is padded to.
        vocab: Token-to-id vocabulary. It is required despite the ``None``
            default, and it must contain ``'<pad>'``.
        num_workers: Number of worker processes used for encoding.

    Raises:
        TypeError: If ``vocab`` is ``None``.

    Attributes:
        all_token_ids: int32 array of shape ``(num_examples, max_len)``.
        all_segments: int32 array of shape ``(num_examples, max_len)``;
            0 for the premise half (and padding), 1 for the hypothesis half
            as assigned by ``d2l.get_tokens_and_segments``.
        valid_lens: Number of non-padding tokens per example.
        labels: Labels aligned with the rows above.
    """

    def __init__(self, dataset, max_len, vocab=None, num_workers=4):
        if vocab is None:
            # Without a vocabulary every worker would crash on
            # `self.vocab[...]` only after tokenizing the whole corpus.
            raise TypeError('SNLIBERTDataset requires a vocab')
        premises, hypotheses = [
            d2l.tokenize([s.lower() for s in sentences])
            for sentences in dataset[:2]]
        all_premise_hypothesis_tokens = [
            [p_tokens, h_tokens]
            for p_tokens, h_tokens in zip(premises, hypotheses)]
        self.labels = np.array(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        self.num_workers = num_workers
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        """Encode all pairs in parallel.

        Returns:
            Tuple ``(token_ids, segments, valid_lens)`` of numpy arrays.
        """
        # The context manager shuts the pool down once `map` has returned,
        # so no worker processes outlive the preprocessing step.
        with multiprocessing.Pool(self.num_workers) as pool:
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [token_ids for token_ids, _, _ in out]
        all_segments = [segments for _, segments, _ in out]
        valid_lens = [valid_len for _, _, valid_len in out]
        # reshape keeps the 2-D (num_examples, max_len) layout even when
        # there are no examples; it is a no-op otherwise.
        return (np.array(all_token_ids, dtype='int32').reshape(-1, self.max_len),
                np.array(all_segments, dtype='int32').reshape(-1, self.max_len),
                np.array(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        """Encode one (premise, hypothesis) token pair as padded BERT input."""
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        num_pad = self.max_len - len(tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * num_pad
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        """Pop from the longer token list, in place, until the pair fits.

        Three slots are reserved for the '<cls>', '<sep>' and '<sep>' tokens
        of the BERT input. On a tie, the hypothesis is shortened.
        """
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        """Return ``((token_ids, segments, valid_len), label)`` for ``idx``."""
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        """Number of encoded examples."""
        return len(self.all_token_ids)


# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len = 512, 128
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
AUTOTUNE = tf.data.AUTOTUNE
train_iter = tf.data.Dataset.from_tensor_slices(
(train_set.all_token_ids, train_set.all_segments,
train_set.valid_lens, train_set.labels)
).shuffle(buffer_size=len(train_set.labels)).batch(batch_size).prefetch(AUTOTUNE)
test_iter = tf.data.Dataset.from_tensor_slices(
(test_set.all_token_ids, test_set.all_segments,
test_set.valid_lens, test_set.labels)
).batch(batch_size).prefetch(AUTOTUNE)read 549367 examples
read 9824 examples
Tiny MLP on the <cls> representation — 3 outputs (entailment, contradiction, neutral). Encoder weights are fine-tuned end-to-end:
class BERTClassifier(keras.Model):
    """Sentence-pair classifier: pretrained BERT encoder plus a head on <cls>.

    Reuses ``bert.encoder`` and ``bert.hidden`` from the pretrained model and
    adds a freshly initialized ``Dense`` output layer. The reused layers are
    attributes of this model, so their weights are trained together with the
    new layer during ``fit``.

    Args:
        bert: Pretrained BERT model exposing ``encoder`` and ``hidden``.
        num_classes: Number of output logits. Defaults to 3 for SNLI
            (entailment, contradiction, neutral).
    """

    def __init__(self, bert, num_classes=3):
        super().__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output_layer = keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        """Return unnormalized class logits.

        ``inputs`` is the tuple ``(tokens_X, segments_X, valid_lens_x)``.
        """
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x,
                                 training=training)
        # Position 0 of every sequence is the <cls> token; its contextual
        # representation is what the head classifies.
        cls_repr = encoded_X[:, 0, :]
        return self.output_layer(self.hidden(cls_repr))


net = BERTClassifier(bert)
# Warm up the classifier with a dummy forward pass
dummy_tokens = tf.ones((2, max_len), dtype=tf.int32)
dummy_segments = tf.zeros((2, max_len), dtype=tf.int32)
dummy_valid_lens = tf.cast(tf.fill((2,), max_len), dtype=tf.float32)
net((dummy_tokens, dummy_segments, dummy_valid_lens), training=False)<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 1.5142541, -1.0775912, 2.4618201],
[ 1.5142541, -1.0775912, 2.4618201]], dtype=float32)>
Standard cross-entropy + Adam with a low learning rate (1e-4 here; the original BERT paper fine-tunes with values between 2e-5 and 5e-5). A few epochs are enough — the model already knows language; we're just teaching it the specific task. Validation accuracy is the main signal, since training loss can keep falling after the classifier starts overfitting SNLI artifacts:
lr, num_epochs = 1e-4, 5
# Wrap tf.data batches for Keras: each batch is
# (tokens_X, segments_X, valid_lens_x, labels)
def reformat(tokens_X, segments_X, valid_lens_x, labels):
return (tokens_X, segments_X, valid_lens_x), labels
train_iter_tf = train_iter.map(reformat)
test_iter_tf = test_iter.map(reformat)
net.compile(
optimizer=keras.optimizers.Adam(lr),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
net.fit(train_iter_tf, validation_data=test_iter_tf, epochs=num_epochs)Epoch 1/5
Final epoch metrics: accuracy: 0.3203 - loss: 1.1621
Final epoch metrics: accuracy: 0.3255 - loss: 1.2187
Final epoch metrics: accuracy: 0.3337 - loss: 1.2087
Final epoch metrics: accuracy: 0.3401 - loss: 1.1982
Final epoch metrics: accuracy: 0.3453 - loss: 1.1884
...
Final epoch metrics: accuracy: 0.8120 - loss: 0.4751
Final epoch metrics: accuracy: 0.8120 - loss: 0.4751
Final epoch metrics: accuracy: 0.8120 - loss: 0.4751
Final epoch metrics: accuracy: 0.8120 - loss: 0.4751
Final epoch metrics: accuracy: 0.8089 - loss: 0.4801 - val_accuracy: 0.7817 - val_loss: 0.5477
Recipe: feed <cls> A <sep> B <sep> through the pretrained encoder and classify the <cls> representation.