Step 1: Attend

Natural Language Inference: Using Attention

Decomposable Attention

Decomposable Attention (Parikh et al., 2016) — a small, fast NLI model that beat much more complex recurrence-based architectures on SNLI in 2016. No recurrence, no convolution — pure attention + MLPs.

Three steps: Attend → Compare → Aggregate.

Pipeline

GloVe → attend → compare → aggregate → 3-way classifier.

The decomposable attention model

Align premise/hypothesis tokens, then compare and aggregate.

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np

Step 1 (Attend): compute alignment weights between every premise word and every hypothesis word, then use them to build aligned context vectors (\beta for each premise word, \alpha for each hypothesis word):

class MLP(nn.Module):
    """Two stacked (dropout -> dense -> ReLU) stages of equal width.

    Attributes:
        num_hiddens: Output width of both dense layers.
        flatten: If true, the activations are reshaped to
            (`x.shape[0]`, -1) after each ReLU.
    """
    num_hiddens: int
    flatten: bool

    @nn.compact
    def __call__(self, x, training=False):
        # Dropout is only active while training
        deterministic = not training
        # Each iteration constructs its own dropout and dense submodules, so
        # the two stages do not share parameters
        for _ in range(2):
            x = nn.Dropout(0.2)(x, deterministic=deterministic)
            x = nn.relu(nn.Dense(self.num_hiddens)(x))
            if self.flatten:
                x = x.reshape((x.shape[0], -1))
        return x
class Attend(nn.Module):
    """Softly align the tokens of two sequences with each other.

    Attributes:
        num_hiddens: Hidden width of the MLP applied to both sequences before
            the alignment scores are computed.
    """
    num_hiddens: int

    @nn.compact
    def __call__(self, A, B, training=False):
        # A single MLP instance is called on both sequences, so they are
        # projected with the same parameters
        project = MLP(self.num_hiddens, flatten=False)
        # `A`/`B`: (`batch_size`, no. of tokens in sequence A/B, `embed_size`)
        # `hidden_A`/`hidden_B`: (`batch_size`, no. of tokens in sequence A/B,
        # `num_hiddens`)
        hidden_A = project(A, training=training)
        hidden_B = project(B, training=training)
        # `scores`: (`batch_size`, no. of tokens in sequence A,
        # no. of tokens in sequence B)
        scores = hidden_A @ jnp.swapaxes(hidden_B, 1, 2)
        # Normalize over the tokens of B for each token of A, and over the
        # tokens of A for each token of B
        weights_over_B = jax.nn.softmax(scores, axis=-1)
        weights_over_A = jax.nn.softmax(jnp.swapaxes(scores, 1, 2), axis=-1)
        # `beta`: (`batch_size`, no. of tokens in sequence A, `embed_size`);
        # sequence B softly aligned with each token (axis 1) of sequence A
        beta = weights_over_B @ B
        # `alpha`: (`batch_size`, no. of tokens in sequence B, `embed_size`);
        # sequence A softly aligned with each token (axis 1) of sequence B
        alpha = weights_over_A @ A
        return beta, alpha

Step 2: Compare

For each premise word a_i, run an MLP on [a_i, \beta_i] where \beta_i is the soft-aligned hypothesis context. Same for hypothesis words:

class Compare(nn.Module):
    """Compare every token with the soft alignment computed for it.

    Attributes:
        num_hiddens: Hidden width of the comparison MLP.
    """
    num_hiddens: int

    @nn.compact
    def __call__(self, A, B, beta, alpha, training=False):
        # The same MLP instance processes both sequences
        compare_mlp = MLP(self.num_hiddens, flatten=False)
        # Pair each token vector with its aligned counterpart along the
        # feature axis (axis 2)
        paired_A = jnp.concatenate((A, beta), axis=2)
        paired_B = jnp.concatenate((B, alpha), axis=2)
        V_A = compare_mlp(paired_A, training=training)
        V_B = compare_mlp(paired_B, training=training)
        return V_A, V_B

Step 3: Aggregate

Sum the per-token compared vectors → concat the two sentence summaries → final MLP → 3-way logits:

class Aggregate(nn.Module):
    """Pool both sets of comparison vectors and map them to class logits.

    Attributes:
        num_hiddens: Hidden width of the MLP applied to the pooled features.
        num_outputs: Number of logits produced by the final dense layer.
    """
    num_hiddens: int
    num_outputs: int

    @nn.compact
    def __call__(self, V_A, V_B, training=False):
        # Collapse the token axis (axis 1) of each set of comparison vectors
        summary_A = jnp.sum(V_A, axis=1)
        summary_B = jnp.sum(V_B, axis=1)
        # Join the two summaries and pass them through an MLP
        joint = jnp.concatenate((summary_A, summary_B), axis=1)
        hidden = MLP(self.num_hiddens, flatten=True)(joint, training=training)
        # Final linear layer producing `num_outputs` logits
        Y_hat = nn.Dense(self.num_outputs)(hidden)
        return Y_hat

Putting it together

The final module wires the three stages into one classifier. Inputs are premise IDs and hypothesis IDs; output is 3 logits for entailment, contradiction, and neutral.

class DecomposableAttention(nn.Module):
    """Decomposable attention classifier: embed, attend, compare, aggregate.

    Attributes:
        vocab_size: Number of rows in the token embedding table.
        embed_size: Width of each token embedding.
        num_hiddens: Hidden width used by the attend, compare and aggregate
            stages.
    """
    vocab_size: int
    embed_size: int
    num_hiddens: int

    @nn.compact
    def __call__(self, premises, hypotheses, training=False):
        """Return 3 logits per (premise, hypothesis) pair of token-ID rows."""
        # Build ONE embedding table and use it for both inputs. Constructing
        # `nn.Embed` twice in a compact method creates two independent tables
        # (`Embed_0` and `Embed_1`), so premise and hypothesis tokens would
        # not share a vector space and pretrained vectors loaded into
        # `Embed_0` would never reach the hypotheses.
        embedding = nn.Embed(self.vocab_size, self.embed_size)
        A = embedding(premises)
        B = embedding(hypotheses)
        beta, alpha = Attend(self.num_hiddens)(A, B, training=training)
        V_A, V_B = Compare(self.num_hiddens)(
            A, B, beta, alpha, training=training)
        # There are 3 possible outputs: entailment, contradiction, and neutral
        Y_hat = Aggregate(self.num_hiddens, num_outputs=3)(
            V_A, V_B, training=training)
        return Y_hat

Loading data + model

SNLI examples are padded premise/hypothesis pairs. Initialize the model with GloVe embeddings, then train all MLP stages end-to-end:

# Minibatch size and the two arguments passed to the SNLI loader
batch_size = 256
num_steps = 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
read 549367 examples
read 9824 examples
# Embedding width, hidden width of the MLP stages, and available devices
embed_size = 100
num_hiddens = 200
devices = d2l.try_all_gpus()
net = DecomposableAttention(len(vocab), embed_size, num_hiddens)
# Look up the pretrained GloVe vector of every token in the vocabulary
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]

Training

Training is fast: there is no recurrence, so every token-pair alignment and every MLP comparison is fully parallelizable. The loss should fall steadily over the four epochs.

lr, num_epochs = 0.001, 4
# Initialize model parameters with dummy token-ID batches whose sequence
# length matches the one used to load the data
dummy_premises = jnp.ones((1, num_steps), dtype=jnp.int32)
dummy_hypotheses = jnp.ones((1, num_steps), dtype=jnp.int32)
# Derive a dedicated key for initialization so that the key later split for
# dropout (`rng`) has not already been consumed by `net.init`
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
params = net.init(init_rng, dummy_premises, dummy_hypotheses, training=False)
# Replace the embedding parameters with pretrained GloVe embeddings. Every
# top-level `nn.Embed` table (Flax auto-names them `Embed_<i>`) is overwritten,
# so no embedding is left at its random initialization.
pretrained_embedding = jnp.array(embeds)
params = {**params, 'params': {
    name: ({'embedding': pretrained_embedding}
           if name.startswith('Embed_') else value)
    for name, value in params['params'].items()}}

optimizer = optax.adam(lr)
opt_state = optimizer.init(params)

def loss_fn(params, premises, hypotheses, labels, rng):
    """Return the mean cross-entropy of `net` on one batch, dropout enabled."""
    # `rng` drives the 'dropout' random stream of this forward pass
    dropout_rngs = {'dropout': rng}
    logits = net.apply(params, premises, hypotheses, training=True,
                       rngs=dropout_rngs)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, labels)
    return per_example_loss.mean()

@jax.jit
def train_step(params, opt_state, premises, hypotheses, labels, rng):
    """Apply one optimizer update; return new params, opt state and loss."""
    # Loss value and its gradient with respect to `params` (first argument)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, premises, hypotheses, labels, rng)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

@jax.jit
def eval_step(params, premises, hypotheses, labels):
    """Return how many examples in the batch are classified correctly."""
    # Dropout is disabled for evaluation
    logits = net.apply(params, premises, hypotheses, training=False)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.sum(predictions == labels)

for epoch in range(num_epochs):
    # Sum of per-example training loss and number of examples seen so far
    train_loss = 0.0
    n_train = 0
    for batch in train_iter:
        premises, hypotheses, labels = batch[0], batch[1], batch[2]
        # Fresh dropout key for every step
        rng, step_rng = jax.random.split(rng)
        params, opt_state, loss = train_step(
            params, opt_state, premises, hypotheses, labels, step_rng)
        # `loss` is a batch mean; weight it by the batch size
        batch_examples = len(labels)
        n_train += batch_examples
        train_loss += float(loss) * batch_examples
    # Evaluate on test set
    n_correct = 0
    n_test = 0
    for batch in test_iter:
        premises, hypotheses, labels = batch[0], batch[1], batch[2]
        n_test += len(labels)
        n_correct += int(eval_step(params, premises, hypotheses, labels))
    print(f'epoch {epoch + 1}, loss {train_loss / n_train:.4f}, '
          f'test acc {n_correct / n_test:.4f}')
epoch 1, loss 0.8864, test acc 0.6559
epoch 2, loss 0.7418, test acc 0.6972
epoch 3, loss 0.6797, test acc 0.7232
epoch 4, loss 0.6332, test acc 0.7384

Predict

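Read the examples semantically: “he is good” follows from “he is great”, while “he is bad” contradicts it. The call below pairs the premise “he is good .” with the hypothesis “he is bad .”, so the expected prediction is contradiction. The function maps class index 0 to entailment, 1 to contradiction, and anything else to neutral.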

def predict_snli(net, params, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis."""
    # Convert each token list to a single-row batch of token IDs
    premise_ids = jnp.array(vocab[premise]).reshape((1, -1))
    hypothesis_ids = jnp.array(vocab[hypothesis]).reshape((1, -1))
    logits = net.apply(params, premise_ids, hypothesis_ids, training=False)
    # The batch holds one example, so read the class index of row 0
    label = int(jnp.argmax(logits, axis=1)[0])
    if label == 0:
        return 'entailment'
    if label == 1:
        return 'contradiction'
    return 'neutral'
# Premise: 'he is good .'; hypothesis: 'he is bad .'
predict_snli(net, params, vocab, ['he', 'is', 'good', '.'],
             ['he', 'is', 'bad', '.'])
'contradiction'

Recap

  • Decomposable Attention does NLI in three small MLP stages: attend, compare, aggregate.
  • No recurrence — completely parallelizable; trains fast even before GPU acceleration was abundant.
  • A precursor to the cross-attention machinery that BERT (next deck) does end-to-end inside one Transformer encoder.