Reading SNLI

Natural Language Inference and the Dataset

NLI Data

Natural Language Inference (NLI) — given a premise sentence and a hypothesis sentence, decide whether the hypothesis is

  • entailed by the premise,
  • contradicted by it, or
  • neutral (independent).

A 3-way sentence-pair classification benchmark. Pre-BERT NLI was a hard task that drove a lot of attention-based research; BERT then crushed it and made NLI a standard fine-tuning probe.

This deck loads SNLI (Stanford NLI), 570k labeled sentence pairs. Output: (premise_ids, hypothesis_ids, label) minibatches.

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import os
import re


# Register the SNLI archive under the key 'SNLI' in d2l's data hub as a
# (download URL, hex digest) pair.
# NOTE(review): the 40-character digest is presumably a SHA-1 checksum that
# d2l uses to validate the download -- confirm against d2l's download helper.
d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

# Fetch/unpack the registered archive; the returned path is the directory
# that the reader below joins the split file names onto.
data_dir = d2l.download_extract('SNLI')

TSV format: gold_label, premise, hypothesis, … (tab-separated columns). Skip rows whose label is "-" (no annotator consensus):

def read_snli(data_dir, is_train):
    """Read the SNLI dataset into premises, hypotheses, and labels.

    Args:
        data_dir: Directory containing ``snli_1.0_train.txt`` and
            ``snli_1.0_test.txt``.
        is_train: If true, read the training split; otherwise the test split.

    Returns:
        A tuple ``(premises, hypotheses, labels)`` of three equally long
        lists. Premises and hypotheses are cleaned strings taken from the
        second and third tab-separated columns; labels are integers with
        0 = entailment, 1 = contradiction, 2 = neutral. Rows whose first
        column is not one of those three label names (e.g. ``-``) are skipped.
    """
    def extract_text(s):
        # Remove the parentheses, which will not be used by us
        s = re.sub(r'[()]', '', s)
        # Substitute two or more consecutive whitespace with space
        s = re.sub(r'\s{2,}', ' ', s)
        return s.strip()

    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    premises, hypotheses, labels = [], [], []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # Skip the header row
        # Stream the file and filter in one pass instead of materializing
        # every row and scanning the row list three times.
        for line in f:
            row = line.split('\t')
            label = label_set.get(row[0])
            if label is None:
                continue
            premises.append(extract_text(row[1]))
            hypotheses.append(extract_text(row[2]))
            labels.append(label)
    return premises, hypotheses, labels
train_data = read_snli(data_dir, is_train=True)
# Preview the first three (premise, hypothesis, label) triples of the
# training split.
preview = zip(*(column[:3] for column in train_data))
for premise, hypothesis, label in preview:
    print('premise:', premise)
    print('hypothesis:', hypothesis)
    print('label:', label)
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is training his horse for a competition .
label: 2
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is at a diner , ordering an omelette .
label: 1
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is outdoors , on a horse .
label: 0
test_data = read_snli(data_dir, is_train=False)
# Print, for each split, how many examples carry label 0, 1 and 2.
for data in [train_data, test_data]:
    labels = data[2]
    # Count directly on the label list; copying it first is redundant.
    print([labels.count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

Custom Dataset

Two parallel sequence inputs, one label per pair:

class SNLIDataset:
    """In-memory SNLI split holding padded token ids and labels.

    ``dataset`` is indexed as ``dataset[0]`` (premise strings),
    ``dataset[1]`` (hypothesis strings) and ``dataset[2]`` (labels).
    Each sentence is tokenized, mapped through the vocabulary and
    truncated/padded to ``num_steps`` entries. When ``vocab`` is None a
    vocabulary is built from the tokens of both sentence columns;
    otherwise the supplied one is reused.
    """
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        premise_tokens = d2l.tokenize(dataset[0])
        hypothesis_tokens = d2l.tokenize(dataset[1])
        # The vocabulary must exist before _pad runs, since _pad looks
        # tokens up in it.
        self.vocab = vocab if vocab is not None else d2l.Vocab(
            premise_tokens + hypothesis_tokens,
            min_freq=5, reserved_tokens=['<pad>'])
        self.premises = self._pad(premise_tokens)
        self.hypotheses = self._pad(hypothesis_tokens)
        self.labels = jnp.array(dataset[2])
        print(f'read {len(self.premises)} examples')

    def _pad(self, lines):
        """Map each token list to ids and truncate/pad it to num_steps."""
        pad_id = self.vocab['<pad>']
        padded_rows = [
            d2l.truncate_pad(self.vocab[tokens], self.num_steps, pad_id)
            for tokens in lines
        ]
        return jnp.array(padded_rows)

    def __getitem__(self, idx):
        """Return ``((premise, hypothesis), label)`` for position ``idx``."""
        sentence_pair = (self.premises[idx], self.hypotheses[idx])
        return sentence_pair, self.labels[idx]

    def __len__(self):
        """Return the number of stored premise rows."""
        return len(self.premises)

Loader factory

def load_data_snli(batch_size, num_steps=50):
    """Download the SNLI dataset and return data iterators and vocabulary.

    Returns ``(train_iter, test_iter, vocab)``. The vocabulary is the one
    built by the training ``SNLIDataset``; the test dataset reuses it.
    Each iterator is created from the (premises, hypotheses, labels)
    arrays of its split.
    """
    data_dir = d2l.download_extract('SNLI')
    raw_train = read_snli(data_dir, True)
    raw_test = read_snli(data_dir, False)
    train_set = SNLIDataset(raw_train, num_steps)
    # Reuse the training vocabulary for the test split.
    test_set = SNLIDataset(raw_test, num_steps, train_set.vocab)

    def make_iter(split, is_train):
        arrays = (split.premises, split.hypotheses, split.labels)
        return d2l.load_array(arrays, batch_size, is_train=is_train)

    train_iter = make_iter(train_set, True)
    test_iter = make_iter(test_set, False)
    return train_iter, test_iter, train_set.vocab
# Build both iterators with batch_size=128 and num_steps=50, then show the
# size of the vocabulary returned alongside them.
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples
18678
# Peek at the first minibatch only: print the shape of its first three
# components, then stop iterating.
for batch in train_iter:
    for field_idx in range(3):
        print(batch[field_idx].shape)
    break
(128, 50)
(128, 50)
(128,)

Recap

  • NLI = 3-way premise/hypothesis classification (entailment, contradiction, neutral).
  • SNLI is the standard corpus; 550k+ labeled pairs.
  • Output: aligned premise/hypothesis token ID pairs + label, ready for the attention model (next deck) and BERT fine-tuning (deck after).