Natural Language Inference: Using Attention

Decomposable Attention

Decomposable Attention (Parikh et al., 2016) is a small, fast NLI model that outperformed much more complex recurrence-based architectures on SNLI in 2016. It uses no recurrence and no convolution, only attention and MLPs.

Three steps: Attend, Compare, Aggregate.

Pipeline

GloVe → attend → compare → aggregate → 3-way classifier.

The decomposable attention model

Align premise/hypothesis tokens, then compare and aggregate.

Setup

from d2l import mxnet as d2l
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn

npx.set_np()

Step 1: Attend

Compute alignment weights between every premise word and every hypothesis word, then use them to build aligned context vectors:

def mlp(num_hiddens, flatten):
    net = nn.Sequential()
    net.add(nn.Dropout(0.2))
    net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
    net.add(nn.Dropout(0.2))
    net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
    return net

class Attend(nn.Block):
    def __init__(self, num_hiddens):
        super().__init__()
        self.f = mlp(num_hiddens=num_hiddens, flatten=False)

    def forward(self, A, B):
        # Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
        # `embed_size`)
        # Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
        # `num_hiddens`)
        f_A = self.f(A)
        f_B = self.f(B)
        # Shape of `e`: (`batch_size`, no. of tokens in sequence A,
        # no. of tokens in sequence B)
        e = npx.batch_dot(f_A, f_B, transpose_b=True)
        # Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
        # `embed_size`), where sequence B is softly aligned with each token
        # (axis 1 of `beta`) in sequence A
        beta = npx.batch_dot(npx.softmax(e), B)
        # Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
        # `embed_size`), where sequence A is softly aligned with each token
        # (axis 1 of `alpha`) in sequence B
        alpha = npx.batch_dot(npx.softmax(e.transpose(0, 2, 1)), A)
        return beta, alpha
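
To make the soft alignment concrete, here is a minimal sketch for a single premise/hypothesis pair. It uses only the Python standard library rather than MXNet, and it assumes the MLP f is replaced by the identity, so each score is a plain dot product. The helper names `softmax` and `soft_align` are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [x / total for x in exps]

def soft_align(A, B):
    """Return (beta, alpha) for one premise A and one hypothesis B.

    A and B are lists of token vectors. Assumption: the MLP f is the
    identity, so e[i][j] is the dot product of A[i] and B[j].
    """
    e = [[sum(x * y for x, y in zip(a, b)) for b in B] for a in A]
    # beta[i]: weighted average of hypothesis tokens, weights from row i of e
    beta = []
    for row in e:
        w = softmax(row)
        beta.append([sum(w[j] * B[j][k] for j in range(len(B)))
                     for k in range(len(B[0]))])
    # alpha[j]: weighted average of premise tokens, weights from column j of e
    alpha = []
    for j in range(len(B)):
        w = softmax([e[i][j] for i in range(len(A))])
        alpha.append([sum(w[i] * A[i][k] for i in range(len(A)))
                      for k in range(len(A[0]))])
    return beta, alpha

A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 premise tokens
B = [[1.0, 0.0], [0.0, 2.0]]              # 2 hypothesis tokens
beta, alpha = soft_align(A, B)
print(len(beta), len(alpha))  # 3 2
```

beta has one vector per premise token and alpha has one per hypothesis token, matching the shape comments in `Attend.forward`.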

Step 2: Compare

For each premise word a_i, run an MLP on [a_i, \beta_i], where \beta_i is the soft-aligned hypothesis context. Do the same for each hypothesis word b_j with [b_j, \alpha_j], where \alpha_j is the soft-aligned premise context:

class Compare(nn.Block):
    def __init__(self, num_hiddens):
        super().__init__()
        self.g = mlp(num_hiddens=num_hiddens, flatten=False)

    def forward(self, A, B, beta, alpha):
        V_A = self.g(np.concatenate([A, beta], axis=2))
        V_B = self.g(np.concatenate([B, alpha], axis=2))
        return V_A, V_B
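
In equation form, the comparison step can be written roughly as follows. The symbols m and n (the number of premise and hypothesis tokens) and the names v_{A,i}, v_{B,j} are notation assumed here to mirror `V_A` and `V_B` in the code.

```latex
\begin{aligned}
\mathbf{v}_{A,i} &= g([\mathbf{a}_i, \boldsymbol{\beta}_i]), \quad i = 1, \ldots, m \\
\mathbf{v}_{B,j} &= g([\mathbf{b}_j, \boldsymbol{\alpha}_j]), \quad j = 1, \ldots, n
\end{aligned}
```

The same MLP g is applied on both sides, as in `Compare.forward`, and its input width is twice `embed_size` because each token vector is concatenated with its aligned context.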

Step 3: Aggregate

Sum the per-token compared vectors → concat the two sentence summaries → final MLP → 3-way logits:

class Aggregate(nn.Block):
    def __init__(self, num_hiddens, num_outputs):
        super().__init__()
        self.h = mlp(num_hiddens=num_hiddens, flatten=True)
        self.h.add(nn.Dense(num_outputs))

    def forward(self, V_A, V_B):
        # Sum up both sets of comparison vectors
        V_A = V_A.sum(axis=1)
        V_B = V_B.sum(axis=1)
        # Feed the concatenation of both summarization results into an MLP
        Y_hat = self.h(np.concatenate([V_A, V_B], axis=1))
        return Y_hat
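
The pooling that feeds the final MLP can be sketched in plain Python. This is an illustration under assumed toy values, not the MXNet code path; `sum_pool` and `aggregate_input` are hypothetical helper names.

```python
def sum_pool(V):
    """Sum a list of token vectors into a single vector."""
    return [sum(column) for column in zip(*V)]

def aggregate_input(V_A, V_B):
    """Concatenate the pooled premise and hypothesis summaries."""
    return sum_pool(V_A) + sum_pool(V_B)

# Hypothetical comparison vectors with num_hiddens = 2
V_A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 premise tokens
V_B = [[0.5, 0.5], [1.5, 2.5]]              # 2 hypothesis tokens
features = aggregate_input(V_A, V_B)
print(features)  # [9.0, 12.0, 2.0, 3.0]
```

The result always has 2 × `num_hiddens` entries, whatever the two sequence lengths are, and it does not depend on token order. That is what lets a fixed-size MLP consume it.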

Putting it together

The final module wires the three stages into one classifier. Inputs are premise IDs and hypothesis IDs; output is 3 logits for entailment, contradiction, and neutral.

class DecomposableAttention(nn.Block):
    def __init__(self, vocab, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab), embed_size)
        self.attend = Attend(num_hiddens)
        self.compare = Compare(num_hiddens)
        # There are 3 possible outputs: entailment, contradiction, and neutral
        self.aggregate = Aggregate(num_hiddens, 3)

    def forward(self, X):
        premises, hypotheses = X
        A = self.embedding(premises)
        B = self.embedding(hypotheses)
        beta, alpha = self.attend(A, B)
        V_A, V_B = self.compare(A, B, beta, alpha)
        Y_hat = self.aggregate(V_A, V_B)
        return Y_hat
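
As a quick sanity check on how the stages fit together, the sketch below traces tensor shapes through the forward pass. It does not run the model: `trace_shapes` is a hypothetical helper, and the sizes in the usage line are illustrative values only.

```python
def trace_shapes(batch_size, m, n, embed_size, num_hiddens, num_classes=3):
    """Return the tensor shape after each stage; m and n are token counts."""
    return {
        'A': (batch_size, m, embed_size),        # embedded premises
        'B': (batch_size, n, embed_size),        # embedded hypotheses
        'e': (batch_size, m, n),                 # attention scores
        'beta': (batch_size, m, embed_size),     # hypothesis aligned to premise
        'alpha': (batch_size, n, embed_size),    # premise aligned to hypothesis
        'V_A': (batch_size, m, num_hiddens),     # compared premise tokens
        'V_B': (batch_size, n, num_hiddens),     # compared hypothesis tokens
        'pooled': (batch_size, 2 * num_hiddens), # summed, then concatenated
        'Y_hat': (batch_size, num_classes),      # logits
    }

shapes = trace_shapes(batch_size=2, m=5, n=7, embed_size=100, num_hiddens=200)
for name, shape in shapes.items():
    print(name, shape)
```

Only `e` depends on both sequence lengths at once; after aggregation neither length appears in the shape.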

Loading data + model

SNLI examples are padded premise/hypothesis pairs. Initialize the model with GloVe embeddings, then train all MLP stages end-to-end:

batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
net.initialize(init.Xavier(), ctx=devices)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.set_data(embeds)
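
The lookup `glove_embedding[vocab.idx_to_token]` builds one row per vocabulary index, so row i of the weight matrix lines up with token index i. A minimal plain-Python sketch of that idea follows. It assumes that tokens missing from the pretrained table fall back to an all-zero vector; the function name and the tiny 3-dimensional vectors are made up for illustration.

```python
def build_embedding_matrix(idx_to_token, pretrained, embed_size):
    """Row i holds the pretrained vector for idx_to_token[i].

    Assumption: tokens missing from `pretrained` get an all-zero row.
    """
    zero = [0.0] * embed_size
    return [list(pretrained.get(token, zero)) for token in idx_to_token]

# Hypothetical 3-dimensional "pretrained" vectors, for illustration only
pretrained = {'he': [0.1, 0.2, 0.3], 'is': [0.4, 0.5, 0.6]}
idx_to_token = ['<unk>', 'he', 'is', 'zyzzyva']
matrix = build_embedding_matrix(idx_to_token, pretrained, embed_size=3)
print(len(matrix), len(matrix[0]))  # 4 3
```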

Training

Training should be fast: there is no recurrence, so every token-pair alignment and every MLP comparison can be computed in parallel.

def split_batch_multi_inputs(X, y, devices):
    """Split multi-input `X` and `y` into multiple devices."""
    X = list(zip(*[gluon.utils.split_and_load(
        feature, devices, even_split=False) for feature in X]))
    return (X, gluon.utils.split_and_load(y, devices, even_split=False))

# Adam's update is approximately invariant to the gradient scale, so the
# learning rate is not divided by the batch size
lr, num_epochs = 0.001, 4
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
               split_batch_multi_inputs)
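
The `zip(*...)` in `split_batch_multi_inputs` regroups the data from "one list of shards per input" to "one (premises, hypotheses) pair per device". The sketch below shows that regrouping in plain Python. `split_even` is a hypothetical stand-in for the device splitter and only cuts contiguous chunks; the real `split_and_load` also moves data to devices and may divide uneven batches differently.

```python
def split_even(items, num_devices):
    """Hypothetical stand-in for a device splitter: contiguous chunks."""
    size = -(-len(items) // num_devices)  # ceiling division
    return [items[i * size:(i + 1) * size] for i in range(num_devices)]

def split_multi_inputs(X, num_devices):
    """Turn [premises, hypotheses] into one (premises, hypotheses) pair per device."""
    per_feature = [split_even(feature, num_devices) for feature in X]
    return list(zip(*per_feature))

premises = ['p0', 'p1', 'p2', 'p3']
hypotheses = ['h0', 'h1', 'h2', 'h3']
shards = split_multi_inputs([premises, hypotheses], num_devices=2)
print(shards)
# [(['p0', 'p1'], ['h0', 'h1']), (['p2', 'p3'], ['h2', 'h3'])]
```

Each shard keeps premise k paired with hypothesis k, which is what the model's `forward` expects when it unpacks `X`.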

Predict

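Read the examples semantically: “he is good” follows from “he is great” (entailment), while “he is bad” contradicts “he is good” (contradiction), which is the pair predicted below. The label indices 0, 1, and 2 map to entailment, contradiction, and neutral, in that order.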

def predict_snli(net, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis."""
    premise = np.array(vocab[premise], ctx=d2l.try_gpu())
    hypothesis = np.array(vocab[hypothesis], ctx=d2l.try_gpu())
    label = np.argmax(net([premise.reshape((1, -1)),
                           hypothesis.reshape((1, -1))]), axis=1)
    return 'entailment' if label == 0 else 'contradiction' if label == 1 \
            else 'neutral'

predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.'])
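
The last step of `predict_snli` is an argmax over three logits followed by an index-to-name lookup. A minimal standalone sketch of that decoding follows, using the index order from `predict_snli`. The logits are made-up numbers, not model output, and `decode_label` is a hypothetical helper.

```python
LABELS = ('entailment', 'contradiction', 'neutral')

def decode_label(logits):
    """Return the label whose logit is largest (first one wins on ties)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return LABELS[best]

print(decode_label([0.2, 2.5, -1.0]))  # contradiction
```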

Recap

  • Decomposable Attention does NLI in three small MLP stages: attend, compare, aggregate.
  • No recurrence, so the computation is fully parallelizable and training is fast.
  • A precursor to the cross-attention machinery that BERT (next deck) does end-to-end inside one Transformer encoder.