Pretraining word2vec

This section trains the skip-gram model with negative sampling (Mikolov et al., 2013). Each word has two embedding vectors:

  • \mathbf{v}_w — its embedding as a center word.
  • \mathbf{u}_w — its embedding as a context word.

For each (center, positive context) pair (w, c+) in a minibatch, together with its K sampled negative context words c-, the loss is:

\mathcal{L} = -\log \sigma(\mathbf{u}_{c+}^\top \mathbf{v}_w) - \sum_{c-} \log \sigma(-\mathbf{u}_{c-}^\top \mathbf{v}_w).

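This is binary classification: the model learns to distinguish real (center, context) pairs from sampled negatives. Because no softmax over the vocabulary is needed, each update is cheap and easy to parallelize, and training is fast even on a CPU. After convergence, \mathbf{v}_w (or \mathbf{u}_w, or their sum) serves as the word embedding.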
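To make the loss concrete, the following minimal sketch evaluates it for one center word, one positive context word, and K = 2 negatives. The vectors are made-up toy values, and the helper names (log_sigmoid, dot, sgns_loss) are illustrative only; they are not part of d2l.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    # log(sigmoid(x)) = -log(1 + exp(-x)); branch on the sign to avoid overflow
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(ai * bi for ai, bi in zip(a, b))

def sgns_loss(v_center, u_pos, u_negs):
    """Negative-sampling loss for one (center, context) pair."""
    # Positive term: push sigma(u_pos . v) toward 1
    loss = -log_sigmoid(dot(u_pos, v_center))
    # Negative terms: push sigma(u_neg . v) toward 0
    for u_neg in u_negs:
        loss -= log_sigmoid(-dot(u_neg, v_center))
    return loss

# Toy 3-dimensional vectors (hypothetical values, not trained)
v_w = [0.5, -0.2, 0.1]
u_pos = [0.4, -0.1, 0.3]
u_negs = [[-0.3, 0.2, 0.0], [0.1, 0.4, -0.5]]

loss_value = sgns_loss(v_w, u_pos, u_negs)
print(f'{loss_value:.4f}')
```

When every dot product is zero, each of the K + 1 terms contributes log 2, so the loss is (K + 1) log 2 (about 2.079 for K = 2). That is the value for a classifier that assigns probability 0.5 to every pair, which makes it a handy reference point.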

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import math
import numpy as np
import optax

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                     num_noise_words)

Embedding layers

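The model needs two nn.Embed layers, one for center words and one for context words, with the same vocabulary and dimension but separate weights. The example here builds a single small layer to show how the lookup works: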

embed = nn.Embed(num_embeddings=20, features=4)
params = embed.init(jax.random.PRNGKey(0), jnp.ones((1,), dtype=jnp.int32))
print(f'Parameter embedding ({params["params"]["embedding"].shape}, '
      f'dtype={params["params"]["embedding"].dtype})')
Parameter embedding ((20, 4), dtype=float32)
x = jnp.array([[1, 2, 3], [4, 5, 6]])
embed.apply(params, x)
Array([[[ 0.5744758 ,  0.42841545,  0.08417441,  0.4318122 ],
        [ 0.42884952, -0.11532515,  0.5799791 ,  0.32131234],
        [-0.4364363 ,  0.24882556,  1.2444718 , -0.99843353]],

       [[ 0.20081298,  0.30229744,  0.82455504, -0.8965778 ],
        [ 0.24884604, -0.35751957, -0.01951474, -0.4571129 ],
        [ 0.6228912 , -0.5813085 , -0.22786202,  0.5639511 ]]],      dtype=float32)
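An embedding lookup is just row indexing into the weight matrix. The sketch below illustrates this with NumPy instead of Flax so it runs without any accelerator setup; the weight matrix W is a stand-in filled with consecutive numbers, not the parameters printed above.

```python
import numpy as np

vocab_size, embed_dim = 20, 4
# Stand-in weight matrix; a trained layer would hold learned values
W = np.arange(vocab_size * embed_dim, dtype=np.float32).reshape(
    vocab_size, embed_dim)

x = np.array([[1, 2, 3], [4, 5, 6]])

# Lookup by integer indexing: output shape is x.shape + (embed_dim,)
looked_up = W[x]

# Equivalent (but wasteful) formulation: one-hot vectors times W
one_hot = np.eye(vocab_size, dtype=np.float32)[x]  # shape (2, 3, 20)
via_matmul = one_hot @ W                           # shape (2, 3, 4)

print(looked_up.shape)
print(np.array_equal(looked_up, via_matmul))
```

The one-hot view explains why an embedding layer is a linear layer without bias. The indexing view explains why it is cheap: only the rows that appear in the batch are read.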

Forward pass

Look up the center and context embeddings; a batched matrix multiplication then gives the dot products that feed the binary cross-entropy loss:

def skip_gram(center, contexts_and_negatives, embed_v, embed_u,
              params_v, params_u):
    v = embed_v.apply(params_v, center)
    u = embed_u.apply(params_u, contexts_and_negatives)
    pred = jnp.matmul(v, jnp.transpose(u, (0, 2, 1)))
    return pred
skip_gram(jnp.ones((2, 1), dtype=jnp.int32),
          jnp.ones((2, 4), dtype=jnp.int32), embed, embed,
          params, params).shape
(2, 1, 4)
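To see what the batched matmul computes, the sketch below reproduces the shape logic in NumPy with random stand-in vectors (the sizes and the names v, u are assumptions for illustration). Entry pred[b, 0, l] is the dot product between the center vector of example b and its l-th context or noise vector.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, max_len, dim = 2, 4, 3
v = rng.normal(size=(batch, 1, dim))        # center-word vectors
u = rng.normal(size=(batch, max_len, dim))  # context + noise vectors

# Batched matmul as in skip_gram: (B, 1, D) @ (B, D, L) -> (B, 1, L)
pred = np.matmul(v, np.transpose(u, (0, 2, 1)))

# The same scores written as explicit per-pair dot products
pred_einsum = np.einsum('bqd,bld->bql', v, u)

print(pred.shape)
print(np.allclose(pred, pred_einsum))
```

The singleton middle axis exists only because each example has exactly one center word; the training code later reshapes the scores to match the label shape.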

Masked binary cross-entropy

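A padding mask zeroes the loss at invalid positions, so examples whose context lists have different lengths can share a batch. Because the mean still divides by the padded length, the result is rescaled by mask.shape[1] / mask.sum(axis=1) to average over valid positions only: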

def loss(inputs, target, mask=None):
    """Binary cross-entropy loss with masking."""
    out = optax.sigmoid_binary_cross_entropy(inputs, target)
    if mask is not None:
        out = out * mask
    return out.mean(axis=1)
pred = d2l.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = d2l.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = d2l.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
Array([0.93521017, 1.8462093 ], dtype=float32)
def sigmd(x):
    return -math.log(1 / (1 + math.exp(-x)))

print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')
0.9352
1.8462
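The masking and rescaling can also be folded into a single step that divides the masked sum by the number of valid positions. The following NumPy sketch is one way to write it, assuming the same pred, label, and mask as above; masked_bce is a hypothetical helper, not a d2l or optax function.

```python
import numpy as np

def masked_bce(logits, labels, mask):
    """Mean binary cross-entropy over valid (mask == 1) positions."""
    # Stable form of -[y log s(x) + (1 - y) log(1 - s(x))] for logits x
    per_elem = (np.maximum(logits, 0) - logits * labels
                + np.log1p(np.exp(-np.abs(logits))))
    # Sum only valid positions, then divide by how many there are
    return (per_elem * mask).sum(axis=1) / mask.sum(axis=1)

pred = np.array([[1.1, -2.2, 3.3, -4.4]] * 2)
label = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]])

per_example = masked_bce(pred, label, mask)
print(per_example)
```

This should agree with the two values checked by hand above (0.9352 and 1.8462). In the second example, the padded positions with logits 3.3 and -4.4 do not influence the result at all.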

Init

Define two embedding tables; their parameters are initialized inside the training function. After training, nearest-neighbor queries usually use the center-word table, but the context table was trained jointly and contains similar information.

embed_size = 100
embed_v = nn.Embed(num_embeddings=len(vocab), features=embed_size)
embed_u = nn.Embed(num_embeddings=len(vocab), features=embed_size)

Training loop

Training uses the Adam optimizer (optax.adam). A CPU is fine because the model is tiny and data loading dominates. The reported loss is the masked binary cross-entropy over real and negative context pairs, averaged over examples:

def train(embed_v, embed_u, data_iter, lr, num_epochs):
    key = jax.random.PRNGKey(42)
    key, key_v, key_u = jax.random.split(key, 3)
    # Initialize parameters
    dummy = jnp.ones((1,), dtype=jnp.int32)
    params_v = embed_v.init(key_v, dummy)
    params_u = embed_u.init(key_u, dummy)
    all_params = {'v': params_v, 'u': params_u}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(all_params)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs])

    @jax.jit
    def train_step(all_params, opt_state, center, context_negative,
                   mask, label):
        def compute_loss(all_params):
            pred = skip_gram(center, context_negative, embed_v, embed_u,
                             all_params['v'], all_params['u'])
            l = (loss(pred.reshape(label.shape), label, mask)
                 / mask.sum(axis=1) * mask.shape[1])
            return l.sum(), l.size
        (loss_val, l_size), grads = jax.value_and_grad(
            compute_loss, has_aux=True)(all_params)
        updates, opt_state = optimizer.update(grads, opt_state, all_params)
        all_params = optax.apply_updates(all_params, updates)
        return all_params, opt_state, loss_val, l_size

    for epoch in range(num_epochs):
        timer, num_batches = d2l.Timer(), len(data_iter)
        # Accumulate on device to avoid per-batch host syncs
        loss_sum, count = jnp.array(0.0), jnp.array(0, dtype=jnp.int32)
        for i, batch in enumerate(data_iter):
            center, context_negative, mask, label = batch
            all_params, opt_state, loss_val, l_size = train_step(
                all_params, opt_state, center, context_negative, mask, label)
            loss_sum = loss_sum + loss_val
            count = count + l_size
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (float(loss_sum / count),))
    total_loss = float(loss_sum)
    total_count = int(count)
    print(f'loss {total_loss / total_count:.3f}, '
          f'{total_count / timer.stop():.1f} tokens/sec')
    return all_params
lr, num_epochs = 0.002, 5
all_params = train(embed_v, embed_u, data_iter, lr, num_epochs)

loss 0.356, 362964.3 tokens/sec

Using the embeddings

Look up similar words by cosine similarity. Trained embeddings cluster semantically related terms; failures often come from rare words or corpus-specific meanings:

def get_similar_tokens(query_token, k, embed_params):
    W = embed_params['params']['embedding']
    x = W[vocab[query_token]]
    # Compute the cosine similarity. Add 1e-9 for numerical stability
    cos = jnp.dot(W, x) / jnp.sqrt(jnp.sum(W * W, axis=1) *
                                    jnp.sum(x * x) + 1e-9)
    topk = jnp.argsort(-cos)[:k + 1]
    for i in topk[1:]:  # Remove the input words
        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(int(i))}')

get_similar_tokens('chip', 3, all_params['v'])
cosine sim=0.584: microprocessor
cosine sim=0.534: mips
cosine sim=0.526: intel
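The query above uses only the center-word table. Summing the two tables, as mentioned at the start of this section, is another option. The NumPy sketch below shows the idea on hypothetical 4-word tables W_v and W_u; top_k_cosine is an illustrative helper that mirrors the cosine computation in get_similar_tokens.

```python
import numpy as np

def top_k_cosine(W, query_idx, k):
    """Return indices of the k rows most cosine-similar to row query_idx."""
    x = W[query_idx]
    # Same cosine formula as get_similar_tokens, with 1e-9 for stability
    cos = W @ x / np.sqrt(np.sum(W * W, axis=1) * np.sum(x * x) + 1e-9)
    order = np.argsort(-cos)
    # Drop the query itself, wherever it ranks
    return [int(i) for i in order if i != query_idx][:k]

# Hypothetical tables standing in for the trained parameters
W_v = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
W_u = np.array([[0.8, 0.2], [1.0, 0.0], [0.1, 0.9], [-0.9, 0.1]])

W_sum = W_v + W_u  # one vector per word, combining both roles
neighbors = top_k_cosine(W_sum, query_idx=0, k=2)
print(neighbors)
```

With the trained model, the summed table would be all_params['v']['params']['embedding'] + all_params['u']['params']['embedding'], and wrapping it as {'params': {'embedding': W_sum}} would let it be passed to get_similar_tokens. Whether this improves the neighbors on PTB is not shown here.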

Recap

  • Skip-gram with negative sampling trains two embeddings per word using binary cross-entropy on (center, context) pairs.
  • Cheap, parallelizable, no softmax over the vocab.
  • Output: dense word vectors with semantic structure (\mathbf{v}_\text{king} - \mathbf{v}_\text{man} + \mathbf{v}_\text{woman} \approx \mathbf{v}_\text{queen}).
  • No longer state of the art (contextual embeddings such as BERT dominate), but still useful as input to small classifiers and as a teaching example.
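The analogy arithmetic in the recap can be sketched as a nearest-neighbor search around the vector b - a + c. The vectors below are hand-built 2-dimensional toys chosen so that the arithmetic works out exactly; they are not trained embeddings, and the analogy helper is hypothetical.

```python
import numpy as np

# Hand-built toy vectors (not trained, for illustration only)
toy = {
    'man':   np.array([1.0, 0.0]),
    'woman': np.array([1.0, 1.0]),
    'king':  np.array([3.0, 0.0]),
    'queen': np.array([3.0, 1.0]),
    'apple': np.array([0.0, 3.0]),
}

def analogy(a, b, c, vectors):
    """Return the word d that best completes a : b :: c : d by cosine."""
    target = vectors[b] - vectors[a] + vectors[c]
    best_word, best_cos = None, -np.inf
    for word, vec in vectors.items():
        if word in (a, b, c):  # skip the query words themselves
            continue
        cos = vec @ target / (
            np.linalg.norm(vec) * np.linalg.norm(target) + 1e-9)
        if cos > best_cos:
            best_word, best_cos = word, cos
    return best_word

result = analogy('man', 'king', 'woman', toy)
print(result)
```

With trained embeddings the match is only approximate, and excluding the three query words from the candidates matters, since one of them is often the closest vector to the target.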