import collections
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import math
import numpy as np
import os
import random
Data pipeline for word2vec. Three classical tricks make skip-gram tractable on large corpora: subsampling of frequent words, randomly sized context windows, and negative sampling.
Output: minibatches of (center, context+, context-) ready for the next deck’s skip-gram model.
Standard small-corpus NLP benchmark: the Penn Treebank (PTB). Load its training split as a list of sentences, each a list of tokens:
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')


def read_ptb():
    """Load the PTB training set as a list of tokenized lines.

    Downloads and extracts ``ptb.zip`` with ``d2l.download_extract`` (the
    ``DATA_HUB['ptb']`` entry registered above). It then reads
    ``ptb.train.txt``, splits it into lines, and splits each line on
    whitespace.

    Returns:
        list[list[str]]: one token list per line. The split is on ``'\\n'``,
        so a trailing newline in the file yields a final empty list.
    """
    data_dir = d2l.download_extract('ptb')
    # Read the training set. Pin the encoding so that decoding does not
    # depend on the platform's locale default.
    with open(os.path.join(data_dir, 'ptb.train.txt'), encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]
sentences = read_ptb()
print(f'# sentences: {len(sentences)}')
# Sample output: # sentences: 42069
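Drop each word $w$ with probability $\max\left(0,\ 1 - \sqrt{t / f(w)}\right)$. Here $f(w)$ is the word's relative frequency in the corpus (its count divided by the total number of tokens) and $t = 10^{-4}$ is a threshold. Common words are dropped aggressively, rare ones almost never: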
def subsample(sentences, vocab, threshold=1e-4):
    """Subsample high-frequency words.

    Unknown tokens are removed first. Each remaining token ``w`` is then kept
    independently with probability ``min(1, sqrt(threshold / f(w)))``, where
    ``f(w)`` is its count divided by the total number of remaining tokens.

    Args:
        sentences: list of token lists.
        vocab: vocabulary supporting ``vocab[token]`` and exposing
            ``vocab.unk`` (the index that unknown tokens map to).
        threshold: subsampling constant ``t``. Larger values keep more
            tokens. Defaults to 1e-4.

    Returns:
        A pair ``(subsampled_sentences, counter)``. ``counter`` is a
        ``collections.Counter`` of token counts after unknown-token removal
        but before subsampling.
    """
    # Exclude unknown tokens ('<unk>')
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = collections.Counter(
        token for line in sentences for token in line)
    num_tokens = sum(counter.values())

    # Return True if `token` is kept during subsampling.
    # threshold / (count / num_tokens) == threshold / f(w). Because uniform
    # draws lie in [0, 1), a ratio >= 1 means the token is always kept.
    def keep(token):
        return (random.uniform(0, 1) <
                math.sqrt(threshold / counter[token] * num_tokens))

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)
# Build the vocabulary with the same settings that load_data_ptb uses.
vocab = d2l.Vocab(sentences, min_freq=10)
subsampled, counter = subsample(sentences, vocab)
# Compare counts before / after — frequent words shrink, rare words mostly
# survive. The target is not a smaller corpus for its own sake; it is a less
# dominated training signal for content words:
print(f'# of "the": '
      f'before={sum(line.count("the") for line in sentences)}, '
      f'after={sum(line.count("the") for line in subsampled)}')
# Sample output (random): # of "the": before=50770, after=2023
For each center token, sample a window size uniformly from 1 to max_window_size. Its context is every token of the same sentence within that distance on either side:
def get_centers_and_contexts(corpus, max_window_size):
    """Return center words and context words in skip-gram.

    Every token of every sentence with at least two tokens becomes a center.
    For each center, a window half-width is drawn uniformly from
    ``1..max_window_size`` (one ``random.randint`` call per center). The
    context is every other token of the same sentence within that distance.

    Returns:
        A pair ``(centers, contexts)``. ``centers`` is a flat list of tokens
        and ``contexts[k]`` is the list of context tokens for ``centers[k]``.
    """
    centers, contexts = [], []
    for sentence in corpus:
        length = len(sentence)
        # A one-token sentence cannot form a (center, context) pair
        if length < 2:
            continue
        centers.extend(sentence)
        for pos in range(length):
            half_width = random.randint(1, max_window_size)
            start = max(0, pos - half_width)
            stop = min(length, pos + half_width + 1)
            # Everything in [start, stop) except the center itself
            contexts.append(
                [sentence[j] for j in range(start, stop) if j != pos])
    return centers, contexts


# Sample output on a toy dataset:
# dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
# center 0 has contexts [1]
# center 1 has contexts [0, 2, 3]
# center 2 has contexts [0, 1, 3, 4]
# center 3 has contexts [2, 4]
# center 4 has contexts [3, 5]
# center 5 has contexts [4, 6]
# center 6 has contexts [5]
# center 7 has contexts [8, 9]
# center 8 has contexts [7, 9]
# center 9 has contexts [7, 8]
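Sample $K$ negative (noise) words per positive (center, context) pair from $P(w) \propto f(w)^{0.75}$. This is the empirical-frequency distribution dampened by an exponent (the original word2vec choice). A noise word that is one of the center's true context words is rejected and redrawn: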
class RandomGenerator:
    """Randomly draw among {1, ..., n} according to n sampling weights.

    Draws are pre-sampled in batches of ``cache_size`` with
    ``random.choices`` and handed out one at a time. This amortizes the cost
    of building the cumulative weights over many ``draw()`` calls.
    """

    def __init__(self, sampling_weights, cache_size=10000):
        """Store the weights; ``sampling_weights[j]`` is the weight of j + 1.

        Raises:
            ValueError: if ``cache_size`` is smaller than 1 (``draw()`` could
                never return a value).
        """
        if cache_size < 1:
            raise ValueError(f'cache_size must be >= 1, got {cache_size}')
        # Population starts at 1, so index 0 is never drawn
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.cache_size = cache_size
        self.candidates = []  # cached draws
        self.i = 0  # position of the next unused cached draw

    def draw(self):
        """Return one index from 1..n, sampled with probability ∝ its weight."""
        if self.i == len(self.candidates):
            # Cache exhausted: sample a fresh batch of `cache_size` results
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=self.cache_size)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]


# Sample output of ten draws: [3, 3, 2, 3, 3, 2, 3, 3, 2, 2]
def get_negatives(all_contexts, vocab, counter, K):
    """Return noise words in negative sampling.

    For every context list, draw ``len(contexts) * K`` vocabulary indices
    from a ``RandomGenerator`` weighted by ``count ** 0.75``. Any index that
    is already one of that center's context words is rejected and redrawn.

    Returns:
        A list parallel to ``all_contexts``; element k holds the noise indices
        for the k-th center.
    """
    # Weights for vocabulary indices 1, 2, ... (index 0 is the excluded
    # unknown token); RandomGenerator maps weights[j] to index j + 1.
    weights = [counter[vocab.to_tokens(idx)] ** 0.75
               for idx in range(1, len(vocab))]
    generator = RandomGenerator(weights)
    all_negatives = []
    for contexts in all_contexts:
        wanted = len(contexts) * K
        negatives = []
        while len(negatives) < wanted:
            candidate = generator.draw()
            # Noise words cannot be context words
            if candidate in contexts:
                continue
            negatives.append(candidate)
        all_negatives.append(negatives)
    return all_negatives
# Build the index corpus and the (center, context) pairs the same way
# load_data_ptb does. The maximum window size of 5 is an assumed demo value.
corpus = [vocab[line] for line in subsampled]
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
all_negatives = get_negatives(all_contexts, vocab, counter, 5)
# Different centers have different numbers of context + negative tokens. Pad
# to the batch max, store a mask so loss only counts real positions:
def batchify(data):
    """Return a minibatch of examples for skip-gram with negative sampling.

    ``data`` is a sequence of ``(center, context, negative)`` triples.
    ``context`` and ``negative`` are lists of token indices. Each row of
    ``context + negative`` is right-padded with 0 to the longest one in the
    batch.

    Returns:
        ``(centers, contexts_negatives, masks, labels)``. ``centers`` is
        reshaped into a column of shape (batch, 1). In each of the other
        rows, ``masks`` is 1 on real (unpadded) positions and ``labels`` is 1
        only on the context positions.
    """
    width = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers = []
    contexts_negatives = []
    masks = []
    labels = []
    for center, ctx, neg in data:
        filled = len(ctx) + len(neg)
        pad = width - filled
        centers.append(center)
        contexts_negatives.append(ctx + neg + [0] * pad)
        masks.append([1] * filled + [0] * pad)
        labels.append([1] * len(ctx) + [0] * (width - len(ctx)))
    column = d2l.reshape(d2l.tensor(centers), (-1, 1))
    return (column, d2l.tensor(contexts_negatives), d2l.tensor(masks),
            d2l.tensor(labels))


# Sample output for a batch of two examples:
# centers = [[1]
#  [1]]
# contexts_negatives = [[2 2 3 3 3 3]
#  [2 2 2 3 3 0]]
# masks = [[1 1 1 1 1 1]
#  [1 1 1 1 1 0]]
# labels = [[1 1 0 0 0 0]
#  [1 1 1 0 0 0]]
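The final factory, load_data_ptb, chains these steps and returns a data iterator together with the vocabulary. Each minibatch holds the pieces the skip-gram loss expects: centers, contexts+negatives, masks, and labels.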
def _pad_ptb(all_centers, all_contexts, all_negatives):
"""Pre-pad all skip-gram examples to the global max length.
Returns four NumPy arrays: centers (N,), contexts_negatives (N, L),
masks (N, L), labels (N, L), where L = max(len(c) + len(n))."""
import numpy as _np
n = len(all_centers)
max_len = max(len(c) + len(neg)
for c, neg in zip(all_contexts, all_negatives))
centers = _np.asarray(all_centers, dtype=_np.int64)
contexts_negatives = _np.zeros((n, max_len), dtype=_np.int64)
masks = _np.zeros((n, max_len), dtype=_np.float32)
labels = _np.zeros((n, max_len), dtype=_np.float32)
for i, (c, neg) in enumerate(zip(all_contexts, all_negatives)):
cur_len = len(c) + len(neg)
contexts_negatives[i, :cur_len] = c + neg
masks[i, :cur_len] = 1.
labels[i, :len(c)] = 1.
return centers, contexts_negatives, masks, labelsdef load_data_ptb(batch_size, max_window_size, num_noise_words):
"""Download the PTB dataset and then load it into memory."""
sentences = read_ptb()
vocab = d2l.Vocab(sentences, min_freq=10)
subsampled, counter = subsample(sentences, vocab)
corpus = [vocab[line] for line in subsampled]
all_centers, all_contexts = get_centers_and_contexts(
corpus, max_window_size)
all_negatives = get_negatives(
all_contexts, vocab, counter, num_noise_words)
centers, cn, masks, labels = _pad_ptb(
all_centers, all_contexts, all_negatives)
centers = centers.reshape(-1, 1)
n = len(centers)
class PTBDataIter:
def __len__(self):
return math.ceil(n / batch_size)
def __iter__(self):
import numpy as _np
idx = _np.random.permutation(n)
for i in range(0, n, batch_size):
b = idx[i:i + batch_size]
yield (jnp.asarray(centers[b]), jnp.asarray(cn[b]),
jnp.asarray(masks[b]), jnp.asarray(labels[b]))
return PTBDataIter(), vocab