PTB corpus

The Dataset for Pretraining Word Embeddings

word2vec Data

Data pipeline for word2vec. Three classical tricks make skip-gram tractable on large corpora:

  • Subsampling — drop frequent words (“the”, “a”) with probability rising with their frequency.
  • Center / context window — each token contributes a few (center, context) pairs.
  • Negative sampling — train a binary classifier with K random negatives per positive pair, cutting the per-pair output cost from $\mathcal{O}(|V|)$ to $\mathcal{O}(K)$.

Output: minibatches of centers, concatenated context words (positives) and noise words (negatives), plus padding masks and labels — ready for the next deck’s skip-gram model.

Setup

import collections
from d2l import tensorflow as d2l
import math
import numpy as np
import os
import random
import tensorflow as tf

The Penn Treebank (PTB) training split is a standard small-corpus NLP benchmark. We load it into memory as a list of sentences, each a list of tokens:

# Register the PTB archive URL and its 40-hex-digit checksum under the key
# 'ptb', which read_ptb passes to d2l.download_extract.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')


def read_ptb():
    """Load the PTB training split as a list of tokenized sentences.

    Downloads (if needed) and extracts the ``'ptb'`` entry of
    ``d2l.DATA_HUB``, reads ``ptb.train.txt``, splits the text into lines
    on ``'\\n'`` and each line into whitespace-separated tokens.

    Returns:
        list[list[str]]: one token list per line of the file. Because the
        split is on ``'\\n'``, a file ending with a newline yields a final
        empty list.
    """
    data_dir = d2l.download_extract('ptb')
    # Read only the training split. Decode explicitly rather than relying on
    # the platform's locale-dependent default encoding.
    with open(os.path.join(data_dir, 'ptb.train.txt'), encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]

# Load the PTB training split as tokenized sentences.
sentences = read_ptb()
f'# sentences: {len(sentences)}'
'# sentences: 42069'
# Vocabulary over all sentences with min_freq=10 (assumed: rarer tokens
# collapse to the unknown token -- confirm against d2l.Vocab).
vocab = d2l.Vocab(sentences, min_freq=10)
f'vocab size: {len(vocab)}'
'vocab size: 6720'

Subsampling frequent words

Drop word $w$ with probability $\max\left(1 - \sqrt{t / f(w)},\ 0\right)$, where $f(w)$ is the fraction of corpus tokens equal to $w$ and $t = 10^{-4}$ is a threshold hyperparameter. Common words get aggressively dropped, rare ones almost never:

def subsample(sentences, vocab):
    """Randomly discard high-frequency tokens (word2vec subsampling).

    Tokens that ``vocab`` maps to ``vocab.unk`` are removed first. Every
    remaining occurrence is then kept with probability
    ``sqrt(1e-4 * num_tokens / count)``, i.e. ``sqrt(t / f(w))`` with
    ``t = 1e-4`` and ``f(w) = count / num_tokens``; a value >= 1 means the
    token is always kept.

    Returns:
        (subsampled sentences, collections.Counter of the known tokens
        counted before subsampling).
    """
    known = []
    for line in sentences:
        known.append([tok for tok in line if vocab[tok] != vocab.unk])
    freq = collections.Counter(tok for line in known for tok in line)
    total = sum(freq.values())

    def _survives(tok):
        # One uniform draw per token occurrence, consumed in sentence order.
        draw = random.uniform(0, 1)
        return draw < math.sqrt(1e-4 / freq[tok] * total)

    kept = [[tok for tok in line if _survives(tok)] for line in known]
    return kept, freq

# Subsample once at module level; `subsampled` and `counter` are reused below.
subsampled, counter = subsample(sentences, vocab)
# Histogram of tokens per sentence: original vs. subsampled corpus.
d2l.show_list_len_pair_hist(['origin', 'subsampled'], '# tokens per sentence',
                            'count', sentences, subsampled);

Subsampling effect

Compare counts before / after — frequent words shrink, rare words mostly survive. The target is not a smaller corpus for its own sake; it is a less dominated training signal for content words:

def compare_counts(token, before=None, after=None):
    """Report how often ``token`` occurs before and after subsampling.

    Args:
        token: the token to count.
        before: tokenized sentences before subsampling; defaults to the
            module-level ``sentences``.
        after: tokenized sentences after subsampling; defaults to the
            module-level ``subsampled``.

    Returns:
        str: ``'# of "<token>": before=<n>, after=<m>'``.
    """
    # Resolve the module-level defaults lazily so callers can pass their own
    # corpora without those globals existing.
    if before is None:
        before = sentences
    if after is None:
        after = subsampled
    n_before = sum(line.count(token) for line in before)
    n_after = sum(line.count(token) for line in after)
    return f'# of "{token}": before={n_before}, after={n_after}'

# In this run the frequent word "the" is heavily thinned while "join" is
# untouched.
compare_counts('the')
'# of "the": before=50770, after=2017'
compare_counts('join')
'# of "join": before=45, after=45'
# Map each subsampled sentence from tokens to vocabulary indices.
corpus = [vocab[line] for line in subsampled]
corpus[:3]
[[], [393, 2116, 407], [5278, 3055, 1581, 96]]

Center + context windows

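For each center token, sample a window radius uniformly from 1 to max_window_size and take every other token within that distance in the same sentence as its context; sentences with fewer than two tokens are skipped: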

def get_centers_and_contexts(corpus, max_window_size):
    """Collect center words and their context words for skip-gram.

    Every token of a sentence that has at least two tokens becomes a center.
    For each center, a window radius is drawn with
    ``random.randint(1, max_window_size)``. Its context is every other token
    of the same sentence within that radius, in sentence order.

    Returns:
        (centers, contexts): a flat list of center tokens and a parallel
        list holding one context list per center.
    """
    centers, contexts = [], []
    for sentence in corpus:
        n = len(sentence)
        if n < 2:
            # A sentence of 0 or 1 tokens has no neighbor to pair with.
            continue
        centers.extend(sentence)
        for pos in range(n):
            radius = random.randint(1, max_window_size)
            lo, hi = max(0, pos - radius), min(n, pos + radius + 1)
            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts
# Sanity check on a toy corpus of two sentences with max_window_size=2.
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)
dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [2, 4]
center 4 has contexts [3, 5]
center 5 has contexts [4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]
# Full corpus with window radius drawn from 1..5.
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'
'# center-context pairs: 1501956'

Negative sampling

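Sample K noise words per context word from $P(w) \propto f(w)^{0.75}$ — the empirical word-frequency distribution dampened by the exponent 0.75 (the original word2vec choice). Index 0 (the unknown token) is never drawn, and a draw that is one of the center’s own context words is rejected and redrawn: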

class RandomGenerator:
    """Randomly draw among {1, ..., n} according to n sampling weights.

    Draws are generated ``cache_size`` at a time with one ``random.choices``
    call and then handed out one by one. This amortizes the per-call cost of
    ``random.choices``, which converts the weights to cumulative weights on
    every call.
    """

    def __init__(self, sampling_weights, cache_size=10000):
        """
        Args:
            sampling_weights: weight of value ``i + 1`` at position ``i``.
                Value 0 is never drawn; callers in this file reserve index 0
                for the unknown token.
            cache_size: number of draws generated per refill (default 10000).

        Raises:
            ValueError: if ``cache_size`` is smaller than 1. An empty cache
                could never be drawn from.
        """
        if cache_size < 1:
            raise ValueError(f'cache_size must be >= 1, got {cache_size}')
        # Exclude 0: the population is 1..n, so weight i belongs to value i + 1.
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.cache_size = cache_size
        # Pre-drawn values and the index of the next one to hand out.
        self.candidates = []
        self.i = 0

    def draw(self):
        """Return the next cached draw, refilling the cache when exhausted."""
        if self.i == len(self.candidates):
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=self.cache_size)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]
# Toy check: values 1, 2, 3 with weights 2, 3, 4, so 3 is the most likely draw.
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]
[2, 3, 1, 3, 3, 2, 3, 3, 3, 1]
def get_negatives(all_contexts, vocab, counter, K):
    """Draw K noise words per context word for negative sampling.

    Noise words are vocabulary indices 1..len(vocab)-1, sampled with weight
    ``counter[token] ** 0.75``. Index 0, the excluded unknown token, is
    never drawn. A draw that appears in the center's own context list is
    rejected and redrawn.

    Returns:
        list[list[int]]: for each entry of ``all_contexts``, a list of
        ``len(contexts) * K`` noise-word indices.
    """
    weights = [counter[vocab.to_tokens(idx)] ** 0.75
               for idx in range(1, len(vocab))]
    sampler = RandomGenerator(weights)
    result = []
    for contexts in all_contexts:
        wanted = len(contexts) * K
        noise = []
        while len(noise) < wanted:
            candidate = sampler.draw()
            # A true context word cannot double as a noise word.
            if candidate in contexts:
                continue
            noise.append(candidate)
        result.append(noise)
    return result

# K = 5 noise words per context word, for every center in the corpus.
all_negatives = get_negatives(all_contexts, vocab, counter, 5)

Padding into minibatches

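Different centers have different numbers of context + negative tokens. Pad each minibatch to its longest example, store a mask so the loss only counts real positions, and store labels marking true context words (1) versus noise words and padding (0):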

def batchify(data):
    """Pad a list of skip-gram examples into one minibatch.

    Each example is ``(center, contexts, negatives)``. Contexts and
    negatives are concatenated and right-padded with 0 to the longest
    example in ``data``.

    Returns:
        (centers reshaped to (batch, 1), contexts_negatives, masks, labels).
        ``masks`` is 1 on real (non-padding) positions, and ``labels`` is 1 on
        context positions and 0 elsewhere.
    """
    width = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, ctx, neg in data:
        used = len(ctx) + len(neg)
        pad = width - used
        centers.append(center)
        contexts_negatives.append(ctx + neg + [0] * pad)
        masks.append([1] * used + [0] * pad)
        labels.append([1] * len(ctx) + [0] * (width - len(ctx)))
    centers = d2l.reshape(d2l.tensor(centers), (-1, 1))
    return (centers, d2l.tensor(contexts_negatives),
            d2l.tensor(masks), d2l.tensor(labels))
# Two toy examples with context+negative lengths 6 and 5, so x_2 receives
# one padding slot.
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

# Print each returned tensor by name.
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)
centers = tf.Tensor(
[[1]
 [1]], shape=(2, 1), dtype=int32)
contexts_negatives = tf.Tensor(
[[2 2 3 3 3 3]
 [2 2 2 3 3 0]], shape=(2, 6), dtype=int32)
masks = tf.Tensor(
[[1 1 1 1 1 1]
 [1 1 1 1 1 0]], shape=(2, 6), dtype=int32)
labels = tf.Tensor(
[[1 1 0 0 0 0]
 [1 1 1 0 0 0]], shape=(2, 6), dtype=int32)

Reusable loader

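The final factory returns the same pieces the skip-gram loss expects: centers, contexts+negatives, masks, and labels. Unlike batchify, it pads every example once to the global maximum length, so all batches share the same width (60 in the run below).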

def _pad_ptb(all_centers, all_contexts, all_negatives):
    """Pre-pad all skip-gram examples to the global max length.

    Returns four NumPy arrays: centers (N,), contexts_negatives (N, L),
    masks (N, L), labels (N, L), where L = max(len(c) + len(n))."""
    import numpy as _np
    n = len(all_centers)
    max_len = max(len(c) + len(neg)
                  for c, neg in zip(all_contexts, all_negatives))
    centers = _np.asarray(all_centers, dtype=_np.int64)
    contexts_negatives = _np.zeros((n, max_len), dtype=_np.int64)
    masks = _np.zeros((n, max_len), dtype=_np.float32)
    labels = _np.zeros((n, max_len), dtype=_np.float32)
    for i, (c, neg) in enumerate(zip(all_contexts, all_negatives)):
        cur_len = len(c) + len(neg)
        contexts_negatives[i, :cur_len] = c + neg
        masks[i, :cur_len] = 1.
        labels[i, :len(c)] = 1.
    return centers, contexts_negatives, masks, labels
def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Build a shuffled, batched tf.data pipeline of PTB skip-gram examples.

    Runs the full preprocessing chain:
    1. Read the sentences.
    2. Build a vocabulary with min_freq=10.
    3. Subsample.
    4. Extract center/context pairs.
    5. Draw noise words.
    Every example is then padded to the global max length via ``_pad_ptb``.

    Before batching, each dataset element is
    ``(center (1,), contexts_negatives (L,), masks (L,), labels (L,))``.

    Returns:
        (dataset, vocab)
    """
    sentences = read_ptb()
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)
    padded = _pad_ptb(all_centers, all_contexts, all_negatives)
    centers, contexts_negatives, masks, labels = padded
    # Add a trailing axis so batched centers are (batch, 1), matching the
    # convention used by batchify.
    tensors = (
        tf.constant(centers[:, None], dtype=tf.int64),
        tf.constant(contexts_negatives, dtype=tf.int64),
        tf.constant(masks, dtype=tf.float32),
        tf.constant(labels, dtype=tf.float32),
    )
    dataset = (tf.data.Dataset.from_tensor_slices(tensors)
               .shuffle(buffer_size=len(centers))
               .batch(batch_size)
               .prefetch(tf.data.AUTOTUNE))
    return dataset, vocab
# Build the loader (batch_size=512, max_window_size=5, 5 noise words per
# context word) and print the shapes of the first batch only.
data_iter, vocab = load_data_ptb(512, 5, 5)
for batch in data_iter:
    for name, data in zip(names, batch):
        print(name, 'shape:', data.shape)
    break
centers shape: (512, 1)
contexts_negatives shape: (512, 60)
masks shape: (512, 60)
labels shape: (512, 60)

Recap

  • Subsample frequent words, sample dynamic-size context windows, draw K negatives per pair.
  • Pad with a binary mask (to the batch max in batchify, to the global max in the loader) — same trick as variable-length sequence batching elsewhere.
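  • This loader feeds the skip-gram model in the next deck; fastText’s skip-gram training still uses the same subsampling + negative-sampling recipe, and GloVe-style pretraining reuses the context-window idea to collect co-occurrence counts.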