from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import json
import os
Pretrained BERT does NLI off the shelf, near state-of-the-art, with one trick: feed <cls> premise <sep> hypothesis <sep> and stick a 3-way classifier on the <cls> token.
This illustrates why BERT mattered: arbitrary sentence-pair classification reduces to a few lines of fine-tuning on a pretrained encoder.
BERT encoder + small MLP head on <cls>.
We use a small pretrained BERT (the one we trained ourselves in the previous chapter, or a downloaded checkpoint). The framework-specific checkpoint conversion helpers are implementation plumbing, so the slide shows only the teaching contract:
The loaded encoder returns contextual token representations and the <cls> representation. Fine-tuning reuses that backbone and adds only a small task head.
Tokenize each (premise, hypothesis) pair into BERT input format: <cls> + premise + <sep> + hypothesis + <sep> with segment IDs distinguishing the two halves:
class SNLIBERTDataset:
    """SNLI premise/hypothesis pairs encoded as fixed-length BERT inputs.

    Every example is truncated to fit `max_len`, combined into one token
    sequence by `d2l.get_tokens_and_segments`, mapped to ids through `vocab`
    and padded to exactly `max_len` entries.

    Args:
        dataset: Indexable container. `dataset[:2]` must yield two sequences
            of sentence strings (premises, then hypotheses) and `dataset[2]`
            the labels.
        max_len: Length of every encoded sequence. Must be at least 3 as
            soon as there is an example to encode, because three slots are
            reserved for special tokens.
        vocab: Object supporting `vocab[list_of_tokens]` and
            `vocab['<pad>']`. The `None` default is only usable with an empty
            dataset, since encoding indexes into it.
    """

    def __init__(self, dataset, max_len, vocab=None):
        # Lower-case and tokenize premises and hypotheses separately, then
        # pair them back up example by example.
        all_premise_hypothesis_tokens = [
            [p_tokens, h_tokens]
            for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]
        self.labels = jnp.array(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print(f'read {len(self.all_token_ids)} examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        """Encode every token pair and stack the results into arrays.

        Returns:
            A tuple `(token_ids, segments, valid_lens)` of JAX arrays; the
            first two are created with dtype int32.
        """
        # JAX arrays cannot be passed across process boundaries, so we use a
        # plain list comprehension instead of multiprocessing.Pool.
        out = [self._mp_worker(tokens)
               for tokens in all_premise_hypothesis_tokens]
        all_token_ids = [token_ids for token_ids, _, _ in out]
        all_segments = [segments for _, segments, _ in out]
        valid_lens = [valid_len for _, _, valid_len in out]
        return (jnp.array(all_token_ids, dtype=jnp.int32),
                jnp.array(all_segments, dtype=jnp.int32),
                jnp.array(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        """Encode one `[premise_tokens, hypothesis_tokens]` pair.

        The two token lists are truncated in place before encoding.

        Returns:
            `(token_ids, segments, valid_len)`, where the two lists are
            padded to `max_len` and `valid_len` is the unpadded length.
        """
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        num_padding = self.max_len - len(tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * num_padding
        # Padding positions are given segment id 0.
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        """Shorten the two token lists in place to fit the BERT input.

        Tokens are dropped one at a time from the end of whichever list is
        currently longer (the hypothesis on ties).

        Raises:
            ValueError: If `max_len` is below 3. No amount of truncation can
                fit then; previously this surfaced as an `IndexError` from
                popping an empty list after both lists had been emptied.
        """
        # Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens for the BERT
        # input
        budget = self.max_len - 3
        if budget < 0:
            raise ValueError(
                'max_len must be at least 3 to leave room for the special '
                f'tokens, got {self.max_len}')
        while len(p_tokens) + len(h_tokens) > budget:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        """Return `((token_ids, segments, valid_len), label)` for `idx`."""
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.all_token_ids)
# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size = 512
max_len = 128
data_dir = d2l.download_extract('SNLI')


def _snli_arrays(snli_dataset):
    """Return a dataset's (token ids, segments, valid lengths, labels)."""
    return (snli_dataset.all_token_ids, snli_dataset.all_segments,
            snli_dataset.valid_lens, snli_dataset.labels)


# Encode both splits with the same vocabulary and sequence length.
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)

# Wrap the encoded arrays in minibatch iterators.
train_iter = d2l.load_array(
    _snli_arrays(train_set), batch_size, is_train=True)
test_iter = d2l.load_array(
    _snli_arrays(test_set), batch_size, is_train=False)
# Output recorded in the source for this step:
# read 549367 examples
# read 9824 examples
Tiny MLP on the <cls> representation — 3 outputs (entailment, contradiction, neutral). Encoder weights are fine-tuned end-to-end:
import copy

# Classifier built around the pretrained encoder object `bert`.
net = BERTClassifier(bert)

# Initialize the classifier with pretrained BERT parameters.
# The dummy batch of two sequences is only used to call `net.init`.
dummy_tokens = jnp.ones((2, max_len), dtype=jnp.int32)
dummy_segments = jnp.zeros((2, max_len), dtype=jnp.int32)
dummy_valid_lens = jnp.full((2,), max_len, dtype=jnp.float32)
rng = jax.random.PRNGKey(0)
params = net.init(
    rng, dummy_tokens, dummy_segments, dummy_valid_lens, training=False)

# Copy pretrained BERT encoder and hidden layer parameters: work on a deep
# copy of the freshly initialized tree and replace its 'bert' subtree.
new_params = copy.deepcopy(dict(params))
new_params['params']['bert'] = bert_params['params']
params = new_params
# Standard cross-entropy + Adam, low learning rate (e.g. 2e-5). Few epochs are
# enough — the model already knows language; we’re just teaching it the
# specific task. Validation accuracy is the main signal, since training loss
# can keep falling after the classifier starts overfitting SNLI artifacts:
# Fine-tuning hyperparameters and the Adam optimizer state for `params`.
lr = 1e-4
num_epochs = 5
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
def loss_fn(params, tokens_X, segments_X, valid_lens_x, labels, rng):
    """Return the mean softmax cross-entropy of `net` on one minibatch.

    The network is applied with `training=True`, and `rng` is supplied as
    the 'dropout' random stream. `labels` are integer class indices.
    """
    dropout_rngs = {'dropout': rng}
    logits = net.apply(
        params, tokens_X, segments_X, valid_lens_x,
        training=True, rngs=dropout_rngs)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, labels)
    return per_example_loss.mean()
@jax.jit
def train_step(params, opt_state, tokens_X, segments_X, valid_lens_x,
               labels, rng):
    """Apply one optimizer update computed from a single minibatch.

    Returns:
        `(params, opt_state, loss)`: the updated parameters, the updated
        optimizer state, and the loss evaluated before the update.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(
        params, tokens_X, segments_X, valid_lens_x, labels, rng)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss
@jax.jit
def eval_step(params, tokens_X, segments_X, valid_lens_x, labels):
    """Return how many examples in the minibatch are classified correctly.

    The network is applied with `training=False`; the predicted class is
    the argmax over the last axis of the logits.
    """
    logits = net.apply(
        params, tokens_X, segments_X, valid_lens_x, training=False)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.sum(predictions == labels)
rng = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    # Training pass: one optimizer step per minibatch, with a fresh dropout
    # key split off the running key each time.
    train_loss = 0.0
    n_train = 0
    for batch in train_iter:
        tokens_X, segments_X = batch[0], batch[1]
        valid_lens_x, labels = batch[2], batch[3]
        rng, step_rng = jax.random.split(rng)
        params, opt_state, loss = train_step(
            params, opt_state, tokens_X, segments_X, valid_lens_x, labels,
            step_rng)
        # Weight the batch-mean loss by the batch size so the epoch figure
        # is a per-example average.
        train_loss += float(loss) * len(labels)
        n_train += len(labels)

    # Evaluate on test set
    n_correct = 0
    n_test = 0
    for batch in test_iter:
        tokens_X, segments_X = batch[0], batch[1]
        valid_lens_x, labels = batch[2], batch[3]
        n_correct += int(eval_step(
            params, tokens_X, segments_X, valid_lens_x, labels))
        n_test += len(labels)

    print(f'epoch {epoch + 1}, loss {train_loss / n_train:.4f}, '
          f'test acc {n_correct / n_test:.4f}')
# Output recorded in the source for this step:
# epoch 1, loss 0.7854, test acc 0.7282
# epoch 2, loss 0.6365, test acc 0.7539
# epoch 3, loss 0.5683, test acc 0.7723
# epoch 4, loss 0.5180, test acc 0.7827
# epoch 5, loss 0.4812, test acc 0.7830
<cls> A <sep> B <sep>, classify the <cls> representation.