import collections
from d2l import tensorflow as d2l
import math
import numpy as np
import os
import random
import tensorflow as tf
Data pipeline for word2vec. Three classical tricks make skip-gram tractable on large corpora: subsampling of frequent words, randomly sized context windows, and negative sampling.
Output: minibatches of (center, context+, context-) ready for the next deck’s skip-gram model.
Penn Treebank (PTB): a standard small-corpus NLP benchmark. Load it as a list of sentences, each a list of tokens:
# Register the PTB archive under the key 'ptb' in d2l's data hub as a
# (URL, hash string) pair; read_ptb() requests this key through
# d2l.download_extract('ptb').
# NOTE(review): the second element is presumably a checksum of ptb.zip used
# to validate the download -- confirm against d2l.download.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
'319d85e578af0cdc590547f26231e4e31cdf1e42')
def read_ptb():
    """Read the PTB training split and tokenize it on whitespace.

    Returns:
        A list with one entry per newline-separated line of
        ``ptb.train.txt``; each entry is the list of whitespace-separated
        tokens found on that line.
    """
    # Directory that d2l.download_extract reports for the 'ptb' entry
    ptb_dir = d2l.download_extract('ptb')
    train_path = os.path.join(ptb_dir, 'ptb.train.txt')
    # Only the training set is read
    with open(train_path) as train_file:
        text = train_file.read()
    tokenized = []
    for row in text.split('\n'):
        tokenized.append(row.split())
    return tokenized
# Load the corpus once at module level and report how many lines were read.
sentences = read_ptb()
f'# sentences: {len(sentences)}'
# Notebook output: '# sentences: 42069'
Drop word $w$ with probability $P(w) = \max\left(1 - \sqrt{t / f(w)},\ 0\right)$, where $f(w)$ is its relative corpus frequency and $t$ is a threshold ($10^{-4}$ in the code below). Common words get aggressively dropped, rare ones almost never:
def subsample(sentences, vocab, t=1e-4):
    """Subsample high-frequency words.

    Tokens that map to the unknown index are removed first. Every remaining
    token occurrence is then kept with probability
    ``min(1, sqrt(t / f(w)))``, where ``f(w)`` is the token's relative
    frequency among the remaining tokens.

    Args:
        sentences: Iterable of token lists.
        vocab: Object supporting ``vocab[token]`` lookups and exposing the
            unknown-token index as ``vocab.unk``.
        t: Subsampling threshold (non-negative). Defaults to 1e-4, the value
            that used to be hard-coded here.

    Returns:
        A tuple ``(subsampled, counter)``: the sentences after random
        dropping, and a ``collections.Counter`` of token counts taken after
        unknown-token removal but before random dropping.

    Raises:
        ValueError: If ``t`` is negative (raised by ``math.sqrt``).
    """
    # Exclude unknown tokens ('<unk>')
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = collections.Counter(
        token for line in sentences for token in line)
    num_tokens = sum(counter.values())
    # The keep threshold depends only on the token type, so compute it once
    # per distinct token instead of once per occurrence.
    keep_threshold = {token: math.sqrt(t / count * num_tokens)
                      for token, count in counter.items()}

    # Return True if `token` is kept during subsampling
    def keep(token):
        return random.uniform(0, 1) < keep_threshold[token]

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)
subsampled, counter = subsample(sentences, vocab)
Compare counts before / after — frequent words shrink, rare words mostly survive. The goal is not a smaller corpus for its own sake, but a training signal that is less dominated by frequent words and leaves more weight for content words:
'# of "the": before=50770, after=2017'
For each center token, sample a window size up to max_window_size and take the surrounding tokens as context:
def get_centers_and_contexts(corpus, max_window_size):
    """Return center words and context words in skip-gram.

    Args:
        corpus: Iterable of sentences, each an indexable sequence of tokens.
        max_window_size: Upper bound (inclusive) of the window radius that
            is drawn uniformly from ``1..max_window_size`` for every center.

    Returns:
        A tuple ``(centers, contexts)`` of equal length: ``centers`` is the
        flat list of tokens from all sentences with at least two tokens, and
        ``contexts[k]`` is the list of tokens surrounding ``centers[k]``
        inside its sentence, within the sampled radius.
    """
    centers, contexts = [], []
    for line in corpus:
        # To form a "center word--context word" pair, each sentence needs to
        # have at least 2 words
        if len(line) < 2:
            continue
        centers.extend(line)
        for i in range(len(line)):  # Context window centered at `i`
            window_size = random.randint(1, max_window_size)
            # Clip the window to the sentence boundaries
            lo = max(0, i - window_size)
            hi = min(len(line), i + 1 + window_size)
            # Exclude the center word from the context words
            contexts.append([line[j] for j in range(lo, hi) if j != i])
    return centers, contexts
# Notebook output (continues on the following lines):
# dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [2, 4]
center 4 has contexts [3, 5]
center 5 has contexts [4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]
Sample $K$ negative (noise) words per positive pair from $P(w) \propto f(w)^{0.75}$ — the empirical-frequency distribution dampened by an exponent (the original word2vec choice):
class RandomGenerator:
    """Randomly draw among {1, ..., n} according to n sampling weights.

    Draws are produced in cached batches so that ``random.choices`` is
    called once per ``cache_size`` draws rather than once per draw.

    Args:
        sampling_weights: Sequence of n non-negative weights; weight ``j``
            (0-based) belongs to the value ``j + 1``.
        cache_size: Number of draws generated per refill. Defaults to 10000,
            the value that used to be hard-coded in ``draw``.

    Raises:
        ValueError: If ``cache_size`` is smaller than 1.
    """

    def __init__(self, sampling_weights, cache_size=10000):
        if cache_size < 1:
            raise ValueError('cache_size must be at least 1')
        # The population starts at 1, so the value 0 is never drawn
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.cache_size = cache_size
        self.candidates = []
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when empty."""
        if self.i == len(self.candidates):
            # Cache `cache_size` random sampling results
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=self.cache_size)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]
# Notebook output that followed this cell: [2, 3, 1, 3, 3, 2, 3, 3, 3, 1]
def get_negatives(all_contexts, vocab, counter, K):
    """Draw noise words for negative sampling.

    For every context list, ``K`` noise word indices are drawn per context
    word, with sampling weight ``count ** 0.75`` for each vocabulary index
    from 1 upward. Draws that coincide with one of that list's context words
    are rejected and redrawn.

    Returns:
        A list parallel to ``all_contexts``; entry ``k`` holds
        ``len(all_contexts[k]) * K`` noise word indices.
    """
    # Index 0 is the excluded unknown token, so the weights cover the
    # vocabulary indices 1, 2, ...
    weights = []
    for idx in range(1, len(vocab)):
        weights.append(counter[vocab.to_tokens(idx)]**0.75)
    sampler = RandomGenerator(weights)
    all_negatives = []
    for contexts in all_contexts:
        quota = len(contexts) * K
        noise = []
        while len(noise) < quota:
            candidate = sampler.draw()
            # Noise words cannot be context words
            if candidate in contexts:
                continue
            noise.append(candidate)
        all_negatives.append(noise)
    return all_negatives
all_negatives = get_negatives(all_contexts, vocab, counter, 5)
Different centers have different numbers of context + negative tokens. Pad to the batch max and store a mask so the loss only counts real positions:
def batchify(data):
    """Return a minibatch of examples for skip-gram with negative sampling.

    Args:
        data: Non-empty sequence of ``(center, context, negative)`` examples,
            where ``context`` and ``negative`` are lists of word indices.

    Returns:
        A tuple ``(centers, contexts_negatives, masks, labels)`` of tensors.
        ``centers`` is reshaped to one column; the other three are padded
        with zeros to the longest ``len(context) + len(negative)`` in the
        batch. ``masks`` is 1 on real positions and 0 on padding; ``labels``
        is 1 on context positions and 0 elsewhere.

    Raises:
        ValueError: If ``data`` is empty (raised by ``max``).
    """
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        pad = max_len - cur_len
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * pad)
        masks.append([1] * cur_len + [0] * pad)
        # Only the leading context positions are positives
        labels.append([1] * len(context) + [0] * (max_len - len(context)))
    return (d2l.reshape(d2l.tensor(centers), (-1, 1)),
            d2l.tensor(contexts_negatives), d2l.tensor(masks),
            d2l.tensor(labels))
# Notebook output (continues on the following lines): centers = tf.Tensor(
[[1]
[1]], shape=(2, 1), dtype=int32)
contexts_negatives = tf.Tensor(
[[2 2 3 3 3 3]
[2 2 2 3 3 0]], shape=(2, 6), dtype=int32)
masks = tf.Tensor(
[[1 1 1 1 1 1]
[1 1 1 1 1 0]], shape=(2, 6), dtype=int32)
labels = tf.Tensor(
[[1 1 0 0 0 0]
[1 1 1 0 0 0]], shape=(2, 6), dtype=int32)
The final factory returns the same pieces the skip-gram loss expects: centers, contexts+negatives, masks, and labels.
def _pad_ptb(all_centers, all_contexts, all_negatives):
    """Pre-pad all skip-gram examples to the global max length.

    Args:
        all_centers: Sequence of N center word indices.
        all_contexts: Sequence of N lists of context word indices.
        all_negatives: Sequence of N lists of noise word indices.

    Returns:
        Four NumPy arrays: ``centers`` (N,) int64, ``contexts_negatives``
        (N, L) int64, ``masks`` (N, L) float32 and ``labels`` (N, L) float32,
        where L = max(len(c) + len(n)) over all examples (0 when there are
        no examples). ``masks`` is 1 on real positions, ``labels`` is 1 on
        context positions; everything else is 0.
    """
    n = len(all_centers)
    # `default=0` keeps the function usable on an empty example set
    max_len = max((len(c) + len(neg)
                   for c, neg in zip(all_contexts, all_negatives)),
                  default=0)
    centers = np.asarray(all_centers, dtype=np.int64)
    contexts_negatives = np.zeros((n, max_len), dtype=np.int64)
    masks = np.zeros((n, max_len), dtype=np.float32)
    labels = np.zeros((n, max_len), dtype=np.float32)
    for i, (c, neg) in enumerate(zip(all_contexts, all_negatives)):
        cur_len = len(c) + len(neg)
        contexts_negatives[i, :cur_len] = c + neg
        masks[i, :cur_len] = 1.
        labels[i, :len(c)] = 1.
    return centers, contexts_negatives, masks, labels


def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Download the PTB dataset and then load it into memory.

    Args:
        batch_size: Number of examples per minibatch.
        max_window_size: Maximum context window radius passed to
            ``get_centers_and_contexts``.
        num_noise_words: Number of noise words per context word passed to
            ``get_negatives``.

    Returns:
        A tuple ``(dataset, vocab)``. ``dataset`` is a shuffled, batched and
        prefetched ``tf.data.Dataset`` whose elements are
        ``(centers, contexts_negatives, masks, labels)`` with ``centers`` of
        shape (batch, 1) int64, ``contexts_negatives`` int64, and ``masks``
        and ``labels`` float32. ``vocab`` is the ``d2l.Vocab`` built with
        ``min_freq=10``.
    """
    sentences = read_ptb()
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)
    centers, cn, masks, labels = _pad_ptb(
        all_centers, all_contexts, all_negatives)
    # centers shape: (N,) -> (N, 1) to match batchify convention
    centers_t = tf.constant(centers[:, None], dtype=tf.int64)
    cn_t = tf.constant(cn, dtype=tf.int64)
    masks_t = tf.constant(masks, dtype=tf.float32)
    labels_t = tf.constant(labels, dtype=tf.float32)
    dataset = tf.data.Dataset.from_tensor_slices(
        (centers_t, cn_t, masks_t, labels_t))
    # Shuffle over the full example set before batching
    dataset = dataset.shuffle(buffer_size=len(centers)).batch(
        batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset, vocab