Subword Embedding

Word-level embeddings have a problem: morphologically related words (“happy”, “happily”, “happiness”) get independent vectors, so nothing learned about one transfers to the others. Rare words get poorly trained vectors, and out-of-vocabulary words get none at all.

Two responses:

  • fastText (Bojanowski et al., 2017): represent each word as the sum of its character n-gram embeddings. Unseen words still get a vector from the n-grams they share with known words.
  • Byte pair encoding (BPE): learn a vocabulary of variable-length subword units from the training data. Frequent words become single tokens; rare words split into morpheme-like pieces. BPE and its close relatives are the default in modern Transformers (BPE in GPT, WordPiece in BERT, SentencePiece in T5).
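
To make the fastText idea concrete, here is a minimal sketch of how a word breaks into character n-grams and how a word vector could be summed from them. The helper names `char_ngrams` and `word_vector` and the tiny hand-written vector table are assumptions for illustration only; the real fastText library also adds the whole word as its own unit and hashes n-grams into buckets, which this sketch leaves out.

```python
def char_ngrams(word, n_min=3, n_max=6):
    """Return the character n-grams of `word` (hypothetical helper)."""
    marked = f'<{word}>'  # '<' and '>' mark the start and end of the word
    ngrams = []
    for n in range(n_min, n_max + 1):
        for i in range(len(marked) - n + 1):
            ngrams.append(marked[i:i + n])
    return ngrams


def word_vector(word, ngram_vectors, dim):
    """Sum the vectors of the n-grams we know; unknown n-grams add nothing."""
    total = [0.0] * dim
    for gram in char_ngrams(word):
        vec = ngram_vectors.get(gram)
        if vec is not None:
            total = [t + v for t, v in zip(total, vec)]
    return total


# Trigrams of "where", the example used in the fastText paper
print(char_ngrams('where', 3, 3))

# Toy 2-dimensional n-gram table (made-up numbers, illustration only)
ngram_vectors = {'<ha': [1.0, 0.0], 'hap': [0.0, 1.0], 'app': [1.0, 1.0]}
# "happily" is not stored anywhere, yet it gets a vector from shared n-grams
print(word_vector('happily', ngram_vectors, dim=2))
```

Because “happy” and “happily” share the n-grams `<ha`, `hap`, and `app`, both words pick up the same contributions from this toy table.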

This deck implements BPE in pure Python.

BPE: greedy merging

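Start with a character-level vocabulary. Repeatedly count adjacent symbol pairs and merge the most frequent pair into a new symbol. Stop after k merges; k sets the final vocabulary size (the initial symbols plus k merged ones). In the code, '_' marks the end of a word and '[UNK]' stands for unknown symbols.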

import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
           'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '_', '[UNK]']
raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
    # Separate the characters with spaces so each one starts as its own symbol
    token_freqs[' '.join(token)] = freq
print(token_freqs)

def get_max_freq_pair(token_freqs):
    pairs = collections.defaultdict(int)
    for token, freq in token_freqs.items():
        symbols = token.split()
        for i in range(len(symbols) - 1):
            # Key of `pairs` is a tuple of two consecutive symbols
            pairs[symbols[i], symbols[i + 1]] += freq
    return max(pairs, key=pairs.get)  # Key of `pairs` with the max value
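
To see why a particular pair wins the first round, it helps to look at the full table of pair counts rather than only the maximum. The sketch below recomputes the counts for the same toy corpus with `collections.Counter`; it is a standalone illustration, not a replacement for `get_max_freq_pair`.

```python
from collections import Counter

# Same toy corpus as above, already split into space-separated characters
token_freqs = {'f a s t _': 4, 'f a s t e r _': 3,
               't a l l _': 5, 't a l l e r _': 4}

pair_counts = Counter()
for token, freq in token_freqs.items():
    symbols = token.split()
    # Each adjacent pair is weighted by the frequency of the word it occurs in
    for left, right in zip(symbols, symbols[1:]):
        pair_counts[left, right] += freq

# Show the most frequent pairs with their weighted counts
for pair, count in pair_counts.most_common(4):
    print(pair, count)
```

Three pairs tie at a count of 9: ('t', 'a'), ('a', 'l'), and ('l', 'l'). Since `max` returns the first maximal key in dictionary insertion order, the first merge is ('t', 'a').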

Building the merge list

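import re

def merge_symbols(max_freq_pair, token_freqs, symbols):
    # Register the merged pair as a new symbol (mutates `symbols` in place)
    merged = ''.join(max_freq_pair)
    symbols.append(merged)
    # Match the pair only at whole-symbol boundaries; a plain str.replace
    # could also match the tail of a longer symbol, e.g. 'xa b' for ('a', 'b')
    pattern = re.compile(
        r'(?<!\S)' + re.escape(' '.join(max_freq_pair)) + r'(?!\S)')
    new_token_freqs = dict()
    for token, freq in token_freqs.items():
        new_token = pattern.sub(lambda m: merged, token)
        new_token_freqs[new_token] = freq
    return new_token_freqs

num_merges = 10
for i in range(num_merges):
    max_freq_pair = get_max_freq_pair(token_freqs)
    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)
    print(f'merge #{i + 1}:', max_freq_pair)
print(symbols)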
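
The loop above prints each merge but keeps only the resulting symbols. If the ordered merge rules are needed later, they can be recorded explicitly. A minimal sketch, assuming a hypothetical `learn_bpe` helper that stores each word as a tuple of symbols instead of a space-joined string:

```python
from collections import Counter


def learn_bpe(word_freqs, num_merges):
    """Return (ordered merge list, segmented words). Hypothetical helper."""
    # Every word starts as a tuple of single characters
    words = {tuple(word): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in words.items():
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += freq
        if not pairs:
            break  # every word is already a single symbol
        best = max(pairs, key=pairs.get)  # ties: first pair seen wins
        merges.append(best)
        new_words = {}
        for symbols, freq in words.items():
            merged, i = [], 0
            while i < len(symbols):
                # Merge only when both whole symbols match the chosen pair
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_words[tuple(merged)] = freq
        words = new_words
    return merges, words


raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
merges, words = learn_bpe(raw_token_freqs, num_merges=10)
for rank, pair in enumerate(merges, start=1):
    print(f'merge #{rank}:', pair)
print(list(words))
```

Working on tuples means a merge can never match across a symbol boundary, and the returned list preserves the order in which the merges were learned.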

Tokenizing new text

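After learning, segment a new word by greedy longest match: at each position, take the longest substring that appears in the learned symbol vocabulary. Out-of-vocabulary words still work because they are broken into in-vocabulary subwords; if nothing matches at some position, the rest of the word is emitted as [UNK]: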

print(list(token_freqs.keys()))

def segment_BPE(tokens, symbols):
    outputs = []
    for token in tokens:
        start, end = 0, len(token)
        cur_output = []
        # Segment token with the longest possible subwords from symbols
        while start < len(token) and start < end:
            if token[start: end] in symbols:
                cur_output.append(token[start: end])
                start = end
                end = len(token)
            else:
                end -= 1
        if start < len(token):
            cur_output.append('[UNK]')
        outputs.append(' '.join(cur_output))
    return outputs

tokens = ['tallest_', 'fatter_']
print(segment_BPE(tokens, symbols))
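
`segment_BPE` matches the longest known symbol at each position. Another common way to encode with BPE is to replay the learned merges in the order they were learned. The sketch below does that with a hypothetical `apply_merges` helper; the hard-coded `merges` list is assumed to be the ten merges that the training loop above prints for the toy corpus. Unlike `segment_BPE`, this sketch leaves unseen characters as they are instead of mapping them to '[UNK]'.

```python
def apply_merges(word, merges):
    """Encode `word` by replaying merges in learned order (hypothetical helper)."""
    symbols = list(word)  # start from single characters
    for left, right in merges:
        merged, i = [], 0
        while i < len(symbols):
            if (i + 1 < len(symbols)
                    and symbols[i] == left and symbols[i + 1] == right):
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Merges from the toy corpus, in the order they were learned (assumed)
merges = [('t', 'a'), ('ta', 'l'), ('tal', 'l'), ('f', 'a'), ('fa', 's'),
          ('fas', 't'), ('e', 'r'), ('er', '_'), ('tall', '_'), ('fast', '_')]

for word in ['tallest_', 'fatter_']:
    print(word, '->', ' '.join(apply_merges(word, merges)))
```

On these two words both strategies give the same segmentation, but they can differ in general, because longest match only looks at the final vocabulary while merge replay follows the training order.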

Recap

  • Subword tokenization sits between character-level (universal but long sequences) and word-level (compact but OOV-prone).
  • BPE greedily merges the most frequent symbol pair each iteration.
  • Related methods: WordPiece (BERT) picks merges by their gain in training-data likelihood rather than raw pair frequency; SentencePiece (T5, LLaMA) works directly on raw text in a language-agnostic way and treats whitespace as a regular symbol.
  • Nearly every widely used modern LM tokenizer is a subword tokenizer of some flavor.