Embedding layers

Pretraining word2vec


This section pretrains the skip-gram model with negative sampling (Mikolov et al., 2013). Each word w has two embedding vectors:

  • \mathbf{v}_w — its embedding as a center word.
  • \mathbf{u}_w — its embedding as a context word.

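For each (center word w, context word c^+) pair in a minibatch, plus K noise words c_1^-, \dots, c_K^- sampled from the vocabulary, the loss is

\mathcal{L} = -\log \sigma(\mathbf{u}_{c^+}^\top \mathbf{v}_w) - \sum_{k=1}^{K} \log \sigma(-\mathbf{u}_{c_k^-}^\top \mathbf{v}_w),

where \sigma(x) = 1 / (1 + e^{-x}) is the logistic sigmoid.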

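Training is therefore binary classification: the model learns to separate observed (center, context) pairs from sampled negatives. Each update touches only \mathbf{v}_w and 1 + K context vectors, with no softmax normalized over the whole vocabulary. Training is therefore cheap, parallelizes easily, and is feasible even on a CPU. After convergence, \mathbf{v}_w (or \mathbf{u}_w, or their sum) serves as the word embedding.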
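The loss can be read in two equivalent ways. The first is the formula itself. The second is a sum of binary cross-entropy terms: label 1 for the observed context word and label 0 for each noise word. The stdlib sketch below computes both. The vectors are made up for illustration, and the helper names are ours, not d2l's.

```python
import math

def log_sigmoid(x):
    # log(sigmoid(x)), computed stably for both signs of x
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def neg_sampling_loss(v_w, u_pos, u_negs):
    """Loss for one center vector v_w, one context vector u_pos, K noise vectors."""
    loss = -log_sigmoid(dot(u_pos, v_w))
    for u_neg in u_negs:
        loss -= log_sigmoid(-dot(u_neg, v_w))
    return loss

def bce_from_score(score, label):
    # Binary cross-entropy on a raw score: -log sigma(s) if label is 1, -log sigma(-s) if 0
    return -log_sigmoid(score) if label == 1 else -log_sigmoid(-score)

# Made-up 3-d vectors: one center word, one context word, K = 2 noise words
v_w = [0.5, -0.2, 0.1]
u_pos = [0.4, 0.1, 0.3]
u_negs = [[-0.3, 0.2, 0.0], [0.1, 0.5, -0.4]]

scores = [dot(u, v_w) for u in [u_pos] + u_negs]
labels = [1] + [0] * len(u_negs)
as_bce = sum(bce_from_score(s, y) for s, y in zip(scores, labels))
print(math.isclose(neg_sampling_loss(v_w, u_pos, u_negs), as_bce))  # True
```

When every vector is zero, every score is 0 and each of the 1 + K terms equals log 2. This gives a quick sanity check for an untrained model.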

Setup

from d2l import torch as d2l
import math
import torch
from torch import nn

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                     num_noise_words)
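Each batch from data_iter holds four aligned tensors, which the training loop later in this section unpacks as center, context_negative, mask, label. The sketch below shows one way to assemble such a batch. It assumes each row lists context words first, then noise words, then padding with id 0. `batchify_sketch` is a hypothetical helper for illustration, not the d2l loader's code.

```python
def batchify_sketch(examples):
    """Hypothetical helper: pad variable-length (context + noise) lists.

    examples: list of (center, contexts, negatives) with integer token ids.
    Returns centers (B x 1), contexts_negatives, masks, labels (each B x max_len).
    """
    max_len = max(len(c) + len(n) for _, c, n in examples)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in examples:
        cur_len = len(context) + len(negative)
        pad = max_len - cur_len
        centers.append([center])
        # Real context ids, then noise ids, then padding id 0
        contexts_negatives.append(context + negative + [0] * pad)
        # 1 marks a real entry (context or noise word), 0 marks padding
        masks.append([1] * cur_len + [0] * pad)
        # 1 for observed context words, 0 for noise words and padding
        labels.append([1] * len(context) + [0] * (max_len - len(context)))
    return centers, contexts_negatives, masks, labels

batch = batchify_sketch([(5, [7, 9], [2, 4, 8]), (6, [3], [1, 2])])
for name, rows in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
    print(name, rows)
```

The padding id does not matter: the mask removes those positions from the loss.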

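The model uses two nn.Embedding layers, one for center words and one for context words. Both have the same vocabulary size and dimension, but separate weights. An embedding layer maps each token index to the corresponding row of its weight matrix. For example, a layer with 20 rows of dimension 4 maps a (2, 3) index tensor to a (2, 3, 4) tensor: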

embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
print(f'Parameter embedding_weight ({embed.weight.shape}, '
      f'dtype={embed.weight.dtype})')
Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)
x = d2l.tensor([[1, 2, 3], [4, 5, 6]])
embed(x)
tensor([[[ 1.3171,  1.5017, -0.4808,  0.4450],
         [ 2.0219, -0.7700, -2.2510, -0.1117],
         [ 1.4166, -0.5063,  0.5164,  1.1640]],

        [[ 0.7697, -1.1720, -0.7227, -1.4750],
         [-0.7597,  0.6493,  1.4758,  1.0520],
         [-0.3410, -1.2402,  0.3950, -1.4047]]], grad_fn=<EmbeddingBackward0>)
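An embedding lookup is plain row selection. It gives the same result as multiplying a one-hot vector by the weight matrix, which is why gradients only reach the rows that were looked up. The stdlib sketch below shows both forms. The helper names and the 5 x 2 table are illustrative only.

```python
def embedding_lookup(weight, indices):
    # Row selection: each index picks one row of the (vocab_size x dim) table
    return [[list(weight[i]) for i in row] for row in indices]

def one_hot_matmul(weight, index):
    # Same row obtained as a one-hot vector times the weight matrix
    vocab_size, dim = len(weight), len(weight[0])
    one_hot = [1.0 if j == index else 0.0 for j in range(vocab_size)]
    return [sum(one_hot[j] * weight[j][d] for j in range(vocab_size))
            for d in range(dim)]

# Hypothetical 5 x 2 table
weight = [[0.1 * r, -0.1 * r] for r in range(5)]
out = embedding_lookup(weight, [[1, 2, 3], [4, 0, 1]])
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 2
```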

Forward pass

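The forward pass has two steps. First, look up the center-word vectors in embed_v and the context and noise vectors in embed_u. Then a batched matrix multiplication (torch.bmm) takes the dot product of each center vector with every context or noise vector in its row. These dot products are the logits for the binary cross-entropy loss. With 2 center words and 4 candidates each, the output has shape (2, 1, 4):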

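def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    # center: (batch_size, 1) token ids -> v: (batch_size, 1, embed_size)
    v = embed_v(center)
    # contexts_and_negatives: (batch_size, max_len) -> u: (batch_size, max_len, embed_size)
    u = embed_u(contexts_and_negatives)
    # (batch_size, 1, embed_size) @ (batch_size, embed_size, max_len)
    # -> (batch_size, 1, max_len): one dot product per context/noise word
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred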
skip_gram(torch.ones((2, 1), dtype=torch.long),
          torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
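The pure-Python analogue below spells out what `torch.bmm(v, u.permute(0, 2, 1))` computes: for each batch element, one dot product between the center vector and each candidate. The function name and toy numbers are ours, for illustration.

```python
def skip_gram_scores(v, u):
    """Pure-Python analogue of torch.bmm(v, u.permute(0, 2, 1)).

    v: batch x 1 x d       (center vectors)
    u: batch x L x d       (context + noise vectors)
    returns batch x 1 x L  (one dot product per candidate)
    """
    return [[[sum(a * b for a, b in zip(v_b[0], u_bl)) for u_bl in u_b]]
            for v_b, u_b in zip(v, u)]

v = [[[1.0, 2.0]], [[0.5, -1.0]]]             # batch=2, one center each, d=2
u = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],   # batch=2, L=3 candidates each
     [[2.0, 2.0], [1.0, -1.0], [0.0, 0.0]]]
scores = skip_gram_scores(v, u)
print(scores)  # [[[1.0, 2.0, 3.0]], [[-1.0, 1.5, 0.0]]]
```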

Masked binary cross-entropy

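Context windows vary in length, so each row of contexts_and_negatives is padded to a common length. The mask is 1 at real (context or noise) positions and 0 at padding. Passing it as the weight of binary_cross_entropy_with_logits zeroes the loss at padded positions, so rows of different lengths can share one batch: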

class SigmoidBCELoss(nn.Module):
    # Binary cross-entropy loss with masking
    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask=None):
        # Per-position BCE on raw scores; weight=mask zeroes padded positions
        out = nn.functional.binary_cross_entropy_with_logits(
            inputs, target, weight=mask, reduction="none")
        # Mean over all max_len positions, padding included (as zeros)
        return out.mean(dim=1)

loss = SigmoidBCELoss()
pred = d2l.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = d2l.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = d2l.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
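As a check, compute the same averages by hand. sigmd(x) returns -log σ(x), which is the BCE for label 1 at score x; label 0 at score x gives sigmd(-x). The second row averages only its two unmasked positions:

def sigmd(x):
    # -log(sigmoid(x))
    return -math.log(1 / (1 + math.exp(-x)))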

print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')
0.9352
1.8462
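The rescaling `* mask.shape[1] / mask.sum(axis=1)` turns a mean over all positions into a mean over valid positions only. The plain mean counts padding, which contributes zero. The stdlib sketch below computes that masked mean directly and reproduces the two numbers above. It is our own reimplementation using the standard numerically stable BCE-with-logits formula, not PyTorch's code.

```python
import math

def bce_with_logits(x, y):
    # Stable form of -[y log sigma(x) + (1 - y) log(1 - sigma(x))]
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

def masked_mean_bce(preds, labels, masks):
    """Average BCE over the unmasked positions of each row."""
    out = []
    for p_row, y_row, m_row in zip(preds, labels, masks):
        total = sum(m * bce_with_logits(p, y)
                    for p, y, m in zip(p_row, y_row, m_row))
        out.append(total / sum(m_row))
    return out

preds = [[1.1, -2.2, 3.3, -4.4]] * 2
labels = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
masks = [[1, 1, 1, 1], [1, 1, 0, 0]]
print([round(v, 4) for v in masked_mean_bce(preds, labels, masks)])  # [0.9352, 1.8462]
```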

Init

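Initialize two embedding tables over the PTB vocabulary, each holding 100-dimensional vectors. net[0] stores center-word vectors and net[1] stores context-word vectors. The similarity queries at the end of this section use the center-word table net[0]. The context table is trained jointly and could be used instead, or summed with it.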

embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),
                                 embedding_dim=embed_size),
                    nn.Embedding(num_embeddings=len(vocab),
                                 embedding_dim=embed_size))
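train below re-initializes both tables with nn.init.xavier_uniform_, which draws each entry uniformly from [-a, a] with a = gain * sqrt(6 / (fan_in + fan_out)). The sketch below computes that bound with our own helper. vocab_size is a hypothetical stand-in for len(vocab). The result shows that the initial vectors start with small entries.

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    # Xavier/Glorot uniform samples from U(-a, a), a = gain * sqrt(6 / (fan_in + fan_out))
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

# For a 2-D weight of shape (num_embeddings, embedding_dim), PyTorch takes
# fan_in = embedding_dim and fan_out = num_embeddings; the sum is symmetric anyway.
vocab_size = 10000  # hypothetical vocabulary size for illustration
print(round(xavier_uniform_bound(fan_in=100, fan_out=vocab_size), 4))  # 0.0244
```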

Training loop

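The loop below uses the Adam optimizer, not plain SGD. It runs on a GPU when d2l.try_gpu() finds one and falls back to the CPU otherwise; the model is small enough that CPU training is also feasible. For each center word, the reported loss is the binary cross-entropy over its real context words and noise words. That loss is averaged over the unpadded positions, hence the rescaling by mask.shape[1] / mask.sum(axis=1):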

def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):
    def init_weights(module):
        if type(module) == nn.Embedding:
            nn.init.xavier_uniform_(module.weight)
    net.apply(init_weights)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs])
    # Sum of normalized losses, no. of normalized losses
    metric = d2l.Accumulator(2)
    for epoch in range(num_epochs):
        timer, num_batches = d2l.Timer(), len(data_iter)
        for i, batch in enumerate(data_iter):
            optimizer.zero_grad()
            center, context_negative, mask, label = [
                data.to(device) for data in batch]

            pred = skip_gram(center, context_negative, net[0], net[1])
            l = (loss(pred.reshape(label.shape).float(), label.float(), mask)
                     / mask.sum(axis=1) * mask.shape[1])
            l.sum().backward()
            optimizer.step()
            metric.add(l.sum(), l.numel())
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, '
          f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
lr, num_epochs = 0.002, 5
train(net, data_iter, lr, num_epochs)

loss 0.410, 438566.7 tokens/sec on cuda:0
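The loop above leaves gradients to autograd and updates to Adam. The stdlib sketch below makes the update concrete instead. It runs plain SGD (swapped in for Adam) on the same per-example loss, with hand-derived gradients: d loss / d score = σ(s) - label. The toy vocabulary, vectors, and noise words are hypothetical. The noise words stay fixed for reproducibility, whereas the real pipeline samples them from the corpus.

```python
import math
import random

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def neg_log_sigmoid(x):
    # -log(sigmoid(x)), computed without overflow
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def example_loss(V, U, w, pos, negs):
    # Negative-sampling loss for one center word w
    loss = neg_log_sigmoid(dot(U[pos], V[w]))
    for n in negs:
        loss += neg_log_sigmoid(-dot(U[n], V[w]))
    return loss

def sgd_step(V, U, w, pos, negs, lr):
    """One plain-SGD update on a single (center, context, noise) example."""
    v = V[w]
    grad_v = [0.0] * len(v)
    # Label 1 for the observed context word, 0 for each noise word
    for c, label in [(pos, 1.0)] + [(n, 0.0) for n in negs]:
        u = U[c]
        g = sigmoid(dot(u, v)) - label  # d loss / d score
        for i in range(len(v)):
            grad_v[i] += g * u[i]      # uses u before its update
            u[i] -= lr * g * v[i]      # v is updated only after the loop
    for i in range(len(v)):
        v[i] -= lr * grad_v[i]

rng = random.Random(0)
vocab_size, dim = 6, 4
V = [[rng.uniform(-0.5, 0.5) for _ in range(dim)] for _ in range(vocab_size)]  # center table
U = [[rng.uniform(-0.5, 0.5) for _ in range(dim)] for _ in range(vocab_size)]  # context table
# Hypothetical toy data: (center id, observed context id, fixed noise ids)
examples = [(0, 1, [3, 4]), (1, 0, [4, 5]), (2, 3, [0, 1]), (3, 2, [1, 5])]

def total_loss():
    return sum(example_loss(V, U, w, p, n) for w, p, n in examples)

before = total_loss()
for epoch in range(200):
    for w, p, n in examples:
        sgd_step(V, U, w, p, n, lr=0.1)
after = total_loss()
print(f'loss before {before:.3f}, after {after:.3f}')
```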

Using the embeddings

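get_similar_tokens ranks every vocabulary word by the cosine similarity of its center-word vector to the query's vector. It prints the top k and skips the first hit, which is the query itself. Trained embeddings tend to place semantically related words close together. Poor neighbors usually involve rare words or meanings specific to the corpus (PTB is Wall Street Journal text):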

def get_similar_tokens(query_token, k, embed):
    W = embed.weight.data
    x = W[vocab[query_token]]
    # Compute the cosine similarity. Add 1e-9 for numerical stability
    cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) *
                                      torch.sum(x * x) + 1e-9)
    topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')
    for i in topk[1:]:  # Skip the top hit, which is the query word itself
        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')

get_similar_tokens('chip', 3, net[0])
cosine sim=0.706: microprocessor
cosine sim=0.692: intel
cosine sim=0.642: mips
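The stdlib sketch below does the same ranking over a small dictionary. It uses the cosine formula from get_similar_tokens, including the 1e-9 term. The 3-d vectors are made up for illustration, and the helper names are ours.

```python
import math

def cosine(a, b, eps=1e-9):
    # Same formula as get_similar_tokens: dot / sqrt(|a|^2 * |b|^2 + eps)
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(
        sum(x * x for x in a) * sum(y * y for y in b) + eps)

def most_similar(query, k, vectors):
    """Rank all other words by cosine similarity to `query`."""
    q = vectors[query]
    scored = [(cosine(q, vec), word) for word, vec in vectors.items() if word != query]
    scored.sort(reverse=True)
    return [(word, round(sim, 3)) for sim, word in scored[:k]]

# Hypothetical 3-d vectors, made up for illustration
vectors = {
    'chip': [0.9, 0.1, 0.0],
    'processor': [0.8, 0.2, 0.1],
    'silicon': [0.7, 0.0, 0.3],
    'banana': [0.0, 0.1, 0.9],
}
print(most_similar('chip', 2, vectors))  # processor first, then silicon
```

This version drops the query by identity rather than discarding the top hit. The difference matters only if another word ties the query at similarity 1.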

Recap

  • Skip-gram with negative sampling trains two embeddings per word using binary cross-entropy. Observed (center, context) pairs are the positives and sampled noise words are the negatives.
  • Cheap and parallelizable: each update touches 1 + K context vectors, with no softmax over the vocabulary.
  • Output: dense word vectors with semantic structure. Embeddings trained on large corpora often approximately satisfy analogies such as \mathbf{v}_\text{king} - \mathbf{v}_\text{man} + \mathbf{v}_\text{woman} \approx \mathbf{v}_\text{queen}.
  • No longer state of the art, since contextual embeddings such as BERT dominate. Still useful as input features for small classifiers and as a teaching example.
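The analogy test in the recap can be sketched as a nearest-neighbor search around b - a + c, excluding the three query words. The 3-d vectors below are hand-built for illustration (loosely: royalty, maleness, femaleness). They are not trained embeddings, and the helper names are ours.

```python
import math

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def analogy(a, b, c, vectors):
    """Return the word d maximizing cos(d, b - a + c), excluding a, b, c."""
    target = normalize([vb - va + vc for va, vb, vc
                        in zip(vectors[a], vectors[b], vectors[c])])
    best, best_sim = None, -float('inf')
    for word, vec in vectors.items():
        if word in (a, b, c):
            continue
        sim = sum(x * y for x, y in zip(normalize(vec), target))
        if sim > best_sim:
            best, best_sim = word, sim
    return best

# Hand-built toy vectors, not trained
vectors = {
    'king':  [0.9, 0.8, 0.1],
    'queen': [0.9, 0.1, 0.8],
    'man':   [0.1, 0.9, 0.1],
    'woman': [0.1, 0.1, 0.9],
    'apple': [0.0, 0.2, 0.2],
}
print(analogy('man', 'king', 'woman', vectors))  # queen
```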