
Pretraining word2vec


We pretrain the skip-gram model with negative sampling (Mikolov et al., 2013). Each word w has two embedding vectors:

  • \mathbf{v}_w — its embedding as a center word.
  • \mathbf{u}_w — its embedding as a context word.

For each (center w, positive context c+) pair in a minibatch, plus K sampled noise words c- as negatives, the loss is:

\mathcal{L} = -\log \sigma(\mathbf{u}_{c+}^\top \mathbf{v}_w) - \sum_{c-} \log \sigma(-\mathbf{u}_{c-}^\top \mathbf{v}_w).

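This is binary classification: distinguish real (center, context) pairs from sampled negatives. Each pair touches only 1 + K context vectors instead of requiring a softmax over the whole vocabulary, so training is cheap, parallelizes easily, and runs quickly on a CPU. After training, \mathbf{v}_w (or \mathbf{u}_w, or their sum) serves as the embedding of word w.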
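To make the loss concrete, here is a minimal framework-free sketch for one center word, in plain Python with toy 2-D vectors (the helper names are hypothetical, not part of d2l or TensorFlow). It relies on the identity 1 - \sigma(x) = \sigma(-x), which is why each negative enters the loss as -\log \sigma(-\mathbf{u}_{c-}^\top \mathbf{v}_w).

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigma(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    # For negative x, log(sigma(x)) = x - log(1 + e^x) avoids overflow in e^{-x}.
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def neg_sampling_loss(v_w, u_pos, u_negs):
    """Skip-gram negative-sampling loss for one (center, positive context) pair."""
    loss = -log_sigmoid(dot(u_pos, v_w))        # real pair: reward a high score
    for u_neg in u_negs:
        loss -= log_sigmoid(-dot(u_neg, v_w))   # noise word: reward a low score
    return loss


# Toy vectors (hypothetical), K = 2 negatives.
v_w = [0.5, -0.2]
u_pos = [1.0, 0.3]
u_negs = [[-0.4, 0.1], [0.2, 0.9]]
print(round(neg_sampling_loss(v_w, u_pos, u_negs), 4))
```

The loss approaches 0 only when the positive score is large and every negative score is very negative.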

Setup

from d2l import tensorflow as d2l
import math
import tensorflow as tf
import keras

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                     num_noise_words)
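Each batch from data_iter unpacks into four parts (center words, context words followed by noise words, a padding mask, and labels). Because different center words have different numbers of context words, rows must be padded to a common length. A plain-Python sketch of that padding, assuming toy token indices and a hypothetical helper batchify_sketch (not the d2l API):

```python
def batchify_sketch(examples):
    """Pad (center, contexts, negatives) examples into rectangular lists.

    Hypothetical helper for illustration: each row holds the real context
    words followed by the noise words, padded with index 0.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in examples)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, ctx, neg in examples:
        cur_len = len(ctx) + len(neg)
        pad = max_len - cur_len
        centers.append([center])
        contexts_negatives.append(ctx + neg + [0] * pad)
        masks.append([1] * cur_len + [0] * pad)                     # 1 = real entry, 0 = padding
        labels.append([1] * len(ctx) + [0] * (max_len - len(ctx)))  # 1 = real context word
    return centers, contexts_negatives, masks, labels


# Two toy examples with different numbers of context and noise words.
batch = batchify_sketch([(1, [2, 2], [3, 3, 3, 3]),
                         (1, [2, 2, 2], [3, 3, 3, 3, 3, 3])])
for name, rows in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
    print(name, rows)
```

Padded slots reuse index 0, but the mask guarantees they never contribute to the loss.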

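The model uses two keras.layers.Embedding layers, one for center words and one for context words: same vocabulary, same dimension, separate weights. An embedding layer maps each token index to a row of its (input_dim, output_dim) weight matrix. A toy layer with 20 rows of 4-dimensional vectors: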

embed = keras.layers.Embedding(input_dim=20, output_dim=4)
# Build the layer so weights are allocated
embed.build((None,))
print(f'Parameter embeddings ({embed.embeddings.shape}, '
      f'dtype={embed.embeddings.dtype})')
Parameter embeddings ((20, 4), dtype=float32)
x = tf.constant([[1, 2, 3], [4, 5, 6]])
embed(x)
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[ 0.04831025,  0.00279317,  0.00059955,  0.0339028 ],
        [-0.04569703,  0.03141973, -0.0368275 , -0.01015358],
        [ 0.00344048,  0.0062824 , -0.02941361,  0.03182969]],

       [[-0.01791795, -0.01939764, -0.02993177, -0.04934167],
        [ 0.02480478,  0.01086305,  0.01374618, -0.00330287],
        [-0.03378527,  0.0324753 , -0.04043093, -0.01626465]]],
      dtype=float32)>
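An embedding layer is a trainable lookup table: index i returns row i of the weight matrix, which equals multiplying a one-hot vector by that matrix. A framework-free sketch with a deterministic toy table (values chosen for readability, not Keras's random initialization; helper names hypothetical):

```python
def embedding_lookup(weights, indices):
    """Gather rows of `weights` for a batch of index lists (B x T -> B x T x d)."""
    return [[weights[i] for i in row] for row in indices]


def one_hot_times_matrix(weights, index):
    """The same row, computed as one_hot(index) @ weights."""
    one_hot = [1.0 if j == index else 0.0 for j in range(len(weights))]
    return [sum(one_hot[j] * weights[j][k] for j in range(len(weights)))
            for k in range(len(weights[0]))]


# Deterministic 20 x 4 toy table: row r is [10r, 10r + 1, 10r + 2, 10r + 3].
W = [[float(10 * r + c) for c in range(4)] for r in range(20)]
out = embedding_lookup(W, [[1, 2, 3], [4, 5, 6]])
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 4, like the (2, 3, 4) tensor above
```

Real layers use the lookup form because the one-hot product wastes work on zeros.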

Forward pass

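Look up the center and context embeddings; a batched matrix multiplication then gives the dot products that feed the binary cross-entropy loss. With center of shape (batch, 1) and contexts_and_negatives of shape (batch, m), the result has shape (batch, 1, m):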

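def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)                    # (batch, 1, embed_size)
    u = embed_u(contexts_and_negatives)    # (batch, m, embed_size)
    # (batch, 1, embed_size) x (batch, embed_size, m) -> (batch, 1, m)
    pred = tf.linalg.matmul(v, tf.transpose(u, perm=[0, 2, 1]))
    return pred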
skip_gram(tf.ones((2, 1), dtype=tf.int32),
          tf.ones((2, 4), dtype=tf.int32), embed, embed).shape
TensorShape([2, 1, 4])
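Each output entry is the dot product of one center vector with one context or noise vector, and the loss treats it as a logit. A plain-Python sketch of the same batched computation with the shapes spelled out (toy vectors with B = 2, m = 4, d = 2; the function name is hypothetical):

```python
def skip_gram_scores(v, u):
    """Batched dot products, mirroring matmul(v, transpose(u, [0, 2, 1])).

    v: center vectors, shape (B, 1, d)
    u: context/noise vectors, shape (B, m, d)
    returns: scores of shape (B, 1, m), where scores[b][0][j] = u[b][j] . v[b][0]
    """
    return [[[sum(a * c for a, c in zip(v_b[0], u_bj)) for u_bj in u_b]]
            for v_b, u_b in zip(v, u)]


v = [[[1.0, 2.0]], [[0.0, 1.0]]]
u = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]],
     [[3.0, 4.0], [0.0, 0.0], [1.0, -1.0], [5.0, 2.0]]]
print(skip_gram_scores(v, u))  # [[[1.0, 2.0, 3.0, 0.0]], [[4.0, 0.0, -1.0, 2.0]]]
```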

Masked binary cross-entropy

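Each center word has a variable number of context and noise words, so rows are padded to a common length. The mask zeroes the loss at padded positions; the result is then rescaled by (row length) / (number of valid positions) so that each row becomes a mean over its valid entries: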

def loss(inputs, target, mask=None):
    """Binary cross-entropy loss with masking."""
    # Use raw per-element sigmoid BCE so we can apply the mask before reducing
    out = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(target, tf.float32),
        logits=tf.cast(inputs, tf.float32))
    if mask is not None:
        out = out * tf.cast(mask, tf.float32)
    return tf.reduce_mean(out, axis=1)
pred = d2l.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = d2l.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = d2l.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=tf.float32)
loss(pred, label, mask) * mask.shape[1] / tf.reduce_sum(mask, axis=1)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.9352101, 1.8462093], dtype=float32)>
def sigmd(x):
    # -log(sigmoid(x)): the BCE at logit x for label 1; label 0 gives sigmd(-x)
    return -math.log(1 / (1 + math.exp(-x)))

print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')
0.9352
1.8462
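The same per-row numbers follow from applying the mask first and dividing by the count of valid entries directly. A framework-free sketch, using the overflow-safe form max(x, 0) - x*y + log(1 + e^{-|x|}) that TensorFlow documents for sigmoid_cross_entropy_with_logits (helper names hypothetical):

```python
import math


def bce_with_logits(logit, label):
    """-[y log sigma(x) + (1 - y) log(1 - sigma(x))], written to avoid overflow."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def masked_mean_bce(logits, labels, mask):
    """Average BCE over the unmasked positions of one row."""
    total = sum(bce_with_logits(x, y) * m for x, y, m in zip(logits, labels, mask))
    return total / sum(mask)


pred = [1.1, -2.2, 3.3, -4.4]
row1 = masked_mean_bce(pred, [1, 0, 0, 0], [1, 1, 1, 1])
row2 = masked_mean_bce(pred, [0, 1, 0, 0], [1, 1, 0, 0])
print(f'{row1:.4f}')  # 0.9352
print(f'{row2:.4f}')  # 1.8462
```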

Init

Create the two embedding tables, each mapping len(vocab) token indices to 100-dimensional vectors. After training, the nearest-neighbor queries below use the center-word table embed_v; the context table embed_u is trained jointly and carries similar information.

embed_size = 100
embed_v = keras.layers.Embedding(input_dim=len(vocab),
                                 output_dim=embed_size)
embed_u = keras.layers.Embedding(input_dim=len(vocab),
                                 output_dim=embed_size)
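By default keras.layers.Embedding uses embeddings_initializer='uniform', which draws values from U(-0.05, 0.05); the small values in the toy output earlier are consistent with this. A plain-Python sketch of that initialization for both tables, plus the \mathbf{v}_w + \mathbf{u}_w combination mentioned at the start of the section (toy vocabulary size and helper names are assumptions):

```python
import random


def init_table(vocab_size, embed_size, rng):
    """U(-0.05, 0.05) init, like Keras's default 'uniform' embedding initializer."""
    return [[rng.uniform(-0.05, 0.05) for _ in range(embed_size)]
            for _ in range(vocab_size)]


def combined_embedding(table_v, table_u, index):
    """One option from the text: use v_w + u_w as the final word vector."""
    return [a + b for a, b in zip(table_v[index], table_u[index])]


rng = random.Random(0)
vocab_size, embed_size = 50, 100     # toy vocabulary size (assumption); embed_size as above
table_v = init_table(vocab_size, embed_size, rng)   # center-word table
table_u = init_table(vocab_size, embed_size, rng)   # context-word table
print('parameters:', 2 * vocab_size * embed_size)   # two tables of vocab_size x embed_size
```

The parameter count scales as 2 × len(vocab) × embed_size, so the real model's size depends on the PTB vocabulary.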

Training loop

The optimizer is Adam; a CPU suffices because the model is just two embedding tables. The reported loss is the binary cross-entropy over real and negative context pairs, averaged per center word over its unmasked positions:

def train(embed_v, embed_u, data_iter, lr, num_epochs):
    optimizer = keras.optimizers.Adam(learning_rate=lr)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs])
    # Sum of normalized losses, no. of normalized losses (over all epochs)
    metric = d2l.Accumulator(2)
    # Count batches once, before training. cardinality() returns a negative
    # sentinel when the size is unknown, so fall back to counting directly.
    num_batches = int(tf.data.experimental.cardinality(data_iter).numpy())
    if num_batches < 0:
        num_batches = sum(1 for _ in data_iter)
    # Plot about 5 points per epoch; max() avoids modulo by zero
    log_every = max(num_batches // 5, 1)
    # Start the timer once so throughput covers every epoch counted in metric
    timer = d2l.Timer()
    for epoch in range(num_epochs):
        for i, batch in enumerate(data_iter):
            center, context_negative, mask, label = batch
            mask = tf.cast(mask, tf.float32)
            with tf.GradientTape() as tape:
                pred = skip_gram(center, context_negative, embed_v, embed_u)
                # Per-row mean BCE over the unmasked positions
                l = (loss(tf.reshape(pred, tf.shape(label)), label, mask)
                     / tf.reduce_sum(mask, axis=1) * tf.cast(
                         tf.shape(mask)[1], tf.float32))
                l_sum = tf.reduce_sum(l)
            # Embedding layers build lazily on first call, so collect the
            # variables after the forward pass
            params = embed_v.trainable_variables + embed_u.trainable_variables
            grads = tape.gradient(l_sum, params)
            optimizer.apply_gradients(zip(grads, params))
            metric.add(float(l_sum), int(tf.size(l)))
            if (i + 1) % log_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, '
          f'{metric[1] / timer.stop():.1f} tokens/sec')
lr, num_epochs = 0.002, 5
train(embed_v, embed_u, data_iter, lr, num_epochs)

loss 0.406, 119554.9 tokens/sec
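The loop delegates gradients to tf.GradientTape and updates with Adam. For intuition, differentiating the loss by hand gives \partial \mathcal{L}/\partial \mathbf{v}_w = (\sigma(\mathbf{u}_{c+}^\top \mathbf{v}_w) - 1)\mathbf{u}_{c+} + \sum_{c-} \sigma(\mathbf{u}_{c-}^\top \mathbf{v}_w)\mathbf{u}_{c-}, \partial \mathcal{L}/\partial \mathbf{u}_{c+} = (\sigma(\mathbf{u}_{c+}^\top \mathbf{v}_w) - 1)\mathbf{v}_w, and \partial \mathcal{L}/\partial \mathbf{u}_{c-} = \sigma(\mathbf{u}_{c-}^\top \mathbf{v}_w)\mathbf{v}_w. A plain-Python sketch that applies these with plain SGD instead of Adam (a deliberate simplification; helper names and toy vectors are hypothetical):

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def ns_loss(v, u_pos, u_negs):
    """-log sigma(u+ . v) - sum over negatives of log sigma(-u- . v)."""
    return (-math.log(sigmoid(dot(u_pos, v)))
            - sum(math.log(sigmoid(-dot(u, v))) for u in u_negs))


def grads(v, u_pos, u_negs):
    """Hand-derived gradients of ns_loss w.r.t. v, u_pos and each u_neg."""
    s_pos = sigmoid(dot(u_pos, v))
    s_negs = [sigmoid(dot(u, v)) for u in u_negs]
    g_v = [(s_pos - 1.0) * u_pos[k] + sum(s * u[k] for s, u in zip(s_negs, u_negs))
           for k in range(len(v))]
    g_pos = [(s_pos - 1.0) * vk for vk in v]
    g_negs = [[s * vk for vk in v] for s in s_negs]
    return g_v, g_pos, g_negs


def sgd_step(v, u_pos, u_negs, lr):
    """Plain SGD (not Adam): update only the vectors this example touches."""
    g_v, g_pos, g_negs = grads(v, u_pos, u_negs)
    new_v = [a - lr * g for a, g in zip(v, g_v)]
    new_pos = [a - lr * g for a, g in zip(u_pos, g_pos)]
    new_negs = [[a - lr * g for a, g in zip(u, gu)] for u, gu in zip(u_negs, g_negs)]
    return new_v, new_pos, new_negs


v, u_pos, u_negs = [0.5, -0.2], [1.0, 0.3], [[-0.4, 0.1], [0.2, 0.9]]
before = ns_loss(v, u_pos, u_negs)
for _ in range(20):
    v, u_pos, u_negs = sgd_step(v, u_pos, u_negs, lr=0.5)
print(f'loss {before:.4f} -> {ns_loss(v, u_pos, u_negs):.4f}')
```

Only the center vector and its 1 + K context vectors receive a gradient, which is why each update is cheap.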

Using the embeddings

Find the words whose center-word vectors have the highest cosine similarity to the query word's vector. Semantically related terms tend to cluster, but neighbors of rare words, or of words with a corpus-specific sense, are less reliable:

def get_similar_tokens(query_token, k, embed):
    W = embed.embeddings
    x = W[vocab[query_token]]
    # Compute the cosine similarity. Add 1e-9 for numerical stability
    cos = tf.linalg.matvec(W, x) / tf.sqrt(
        tf.reduce_sum(W * W, axis=1) * tf.reduce_sum(x * x) + 1e-9)
    topk = tf.argsort(-cos)[:k + 1]
    for i in topk[1:]:  # Skip the top hit, which is the query word itself
        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(int(i))}')

get_similar_tokens('chip', 3, embed_v)
cosine sim=0.652: microprocessor
cosine sim=0.617: intel
cosine sim=0.614: desktop
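The same query can be sketched without a framework. This version excludes the query by token rather than dropping the first sorted result, which stays correct even if another word ties with the query. The toy vectors and helper names are hypothetical, not trained PTB embeddings:

```python
import math


def cosine(a, b):
    # The small epsilon guards against division by zero for all-zero vectors.
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(
        sum(x * x for x in a) * sum(y * y for y in b) + 1e-9)


def most_similar(query, k, vectors):
    """Return the k tokens with highest cosine similarity to `query`, excluding it."""
    q = vectors[query]
    scored = [(cosine(q, vec), tok) for tok, vec in vectors.items() if tok != query]
    scored.sort(reverse=True)
    return [tok for _, tok in scored[:k]]


# Hand-built toy vectors (hypothetical).
vectors = {'chip': [1.0, 0.9, 0.0], 'intel': [0.9, 1.0, 0.1],
           'desktop': [0.8, 0.6, 0.3], 'banana': [0.0, 0.1, 1.0]}
print(most_similar('chip', 2, vectors))  # ['intel', 'desktop']
```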

Recap

  • Skip-gram + negative sampling = train two embeddings per word with binary cross-entropy that separates real (center, context) pairs from sampled noise words.
  • Cheap, parallelizable, no softmax over the vocab.
  • Output: dense word vectors with semantic structure. On large corpora, word2vec vectors famously support analogies such as \mathbf{v}_\text{king} - \mathbf{v}_\text{man} + \mathbf{v}_\text{woman} \approx \mathbf{v}_\text{queen}; a corpus as small as PTB may not reproduce them.
  • No longer state of the art (contextual embeddings such as BERT dominate), but still useful as input features for small classifiers and as a teaching example.
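Analogies like the king/queen example are usually tested with vector arithmetic followed by a nearest-neighbor search that excludes the three query words. A sketch on hand-built 2-D toy vectors (hypothetical, constructed so the analogy holds; not trained embeddings):

```python
import math


def cosine(a, b):
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(
        sum(x * x for x in a) * sum(y * y for y in b) + 1e-9)


def analogy(a, b, c, vectors):
    """Answer a : b :: c : ? with the word nearest to v_b - v_a + v_c."""
    target = [vb - va + vc for va, vb, vc in zip(vectors[a], vectors[b], vectors[c])]
    # Exclude the three query words, which are otherwise often the nearest hits.
    candidates = [(cosine(target, vec), tok) for tok, vec in vectors.items()
                  if tok not in (a, b, c)]
    return max(candidates)[1]


# Toy axes (hypothetical): axis 0 ~ "royalty", axis 1 ~ "male vs. female".
vectors = {'man': [0.0, 1.0], 'woman': [0.0, -1.0],
           'king': [1.0, 1.0], 'queen': [1.0, -1.0], 'apple': [-1.0, 0.2]}
print(analogy('man', 'king', 'woman', vectors))  # queen
```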