from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
Decomposable Attention (Parikh et al., 2016) — a small, fast NLI model that beat much more complex recurrence-based architectures on SNLI in 2016. No recurrence, no convolution — pure attention + MLPs.
Three steps: Attend → Compare → Aggregate.
GloVe → attend → compare → aggregate → 3-way classifier.
Align premise/hypothesis tokens, then compare and aggregate.
Compute alignment weights between every premise word and every hypothesis word. Use them to build aligned context vectors:
class MLP(nn.Module):
    """Two stacked stages of dropout -> dense -> ReLU.

    Each stage applies dropout with rate 0.2, projects to `num_hiddens`
    units and applies a ReLU. When `flatten` is true, every axis after the
    first is merged into one after each stage.
    """
    num_hiddens: int
    flatten: bool

    @nn.compact
    def __call__(self, x, training=False):
        # Dropout is only active while training.
        for _ in range(2):
            x = nn.Dropout(0.2)(x, deterministic=not training)
            x = nn.relu(nn.Dense(self.num_hiddens)(x))
            if self.flatten:
                x = x.reshape((x.shape[0], -1))
        return x


class Attend(nn.Module):
    """Softly align the tokens of two sequences with each other.

    Returns `(beta, alpha)`: `beta` holds, for each token of sequence A, a
    weighted combination of the tokens of B; `alpha` holds, for each token of
    sequence B, a weighted combination of the tokens of A.
    """
    num_hiddens: int

    @nn.compact
    def __call__(self, A, B, training=False):
        # The same MLP instance is applied to both sequences.
        project = MLP(self.num_hiddens, flatten=False)
        # Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
        # `embed_size`)
        # Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
        # `num_hiddens`)
        f_A = project(A, training=training)
        f_B = project(B, training=training)
        # Shape of `scores`: (`batch_size`, no. of tokens in sequence A,
        # no. of tokens in sequence B)
        scores = f_A @ jnp.swapaxes(f_B, 1, 2)
        # Normalise over the tokens of B for each token of A, and over the
        # tokens of A for each token of B.
        weights_over_B = jax.nn.softmax(scores, axis=-1)
        weights_over_A = jax.nn.softmax(jnp.swapaxes(scores, 1, 2), axis=-1)
        # Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
        # `embed_size`), where sequence B is softly aligned with each token
        # (axis 1 of `beta`) in sequence A
        beta = weights_over_B @ B
        # Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
        # `embed_size`), where sequence A is softly aligned with each token
        # (axis 1 of `alpha`) in sequence B
        alpha = weights_over_A @ A
        return beta, alpha
# For each premise word a_i, run an MLP on [a_i, \beta_i] where \beta_i is
# the soft-aligned hypothesis context. Same for hypothesis words:
Sum the per-token compared vectors → concat the two sentence summaries → final MLP → 3-way logits:
class Aggregate(nn.Module):
    """Pool two sets of comparison vectors and map them to output logits.

    Both inputs are summed over axis 1, the two sums are concatenated along
    axis 1, passed through an `MLP` with `num_hiddens` units, and projected
    to `num_outputs` values by a final dense layer.
    """
    num_hiddens: int
    num_outputs: int

    @nn.compact
    def __call__(self, V_A, V_B, training=False):
        # Collapse axis 1 of each set of comparison vectors by summation.
        summary_A = V_A.sum(axis=1)
        summary_B = V_B.sum(axis=1)
        # Join both summaries and run them through the hidden MLP.
        joined = jnp.concatenate([summary_A, summary_B], axis=1)
        hidden = MLP(self.num_hiddens, flatten=True)(joined, training=training)
        # Final projection to the output logits.
        return nn.Dense(self.num_outputs)(hidden)
# The final module wires the three stages into one classifier. Inputs are
# premise IDs and hypothesis IDs; output is 3 logits for entailment,
# contradiction, and neutral.
class DecomposableAttention(nn.Module):
    """Attend -> Compare -> Aggregate classifier over a sentence pair.

    `premises` and `hypotheses` are arrays of token indices; both are looked
    up in the same embedding table. The result is the output of `Aggregate`
    with `num_outputs=3`.

    NOTE(review): `Compare` is referenced here but is not defined in this
    section of the file; it must be in scope when the module is applied.
    """
    vocab_size: int
    embed_size: int
    num_hiddens: int

    @nn.compact
    def __call__(self, premises, hypotheses, training=False):
        # Use ONE embedding module for both sentences. Constructing
        # `nn.Embed` twice would register two independent tables (`Embed_0`
        # and `Embed_1`): the sentences would not share word vectors, and
        # code that loads pretrained vectors into `Embed_0` alone would leave
        # the second table randomly initialised.
        embedding = nn.Embed(self.vocab_size, self.embed_size)
        A = embedding(premises)
        B = embedding(hypotheses)
        beta, alpha = Attend(self.num_hiddens)(A, B, training=training)
        V_A, V_B = Compare(self.num_hiddens)(
            A, B, beta, alpha, training=training)
        # There are 3 possible outputs: entailment, contradiction, and neutral
        Y_hat = Aggregate(self.num_hiddens, num_outputs=3)(
            V_A, V_B, training=training)
        return Y_hat
# SNLI examples are padded premise/hypothesis pairs. Initialize the model with
# GloVe embeddings, then train all MLP stages end-to-end:
read 549367 examples
read 9824 examples
Each epoch should run quickly: there is no recurrence, so every token-pair alignment and every MLP comparison is fully parallelizable. The training loss should fall from one epoch to the next.
lr, num_epochs = 0.001, 4
# Initialize model parameters (the dummy inputs only fix the input shapes)
dummy_premises = jnp.ones((1, 50), dtype=jnp.int32)
dummy_hypotheses = jnp.ones((1, 50), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
params = net.init(rng, dummy_premises, dummy_hypotheses, training=False)
# Replace the embedding parameters with pretrained GloVe embeddings. Every
# top-level `Embed_*` entry is overwritten, not just `Embed_0`, so that no
# embedding table registered by the model keeps its random initialisation.
glove_table = jnp.array(embeds)
params = {**params, 'params': {
    **params['params'],
    **{name: {'embedding': glove_table}
       for name in params['params'] if name.startswith('Embed_')}}}
# Build the optimizer state only after the embeddings have been swapped in.
optimizer = optax.adam(lr)
opt_state = optimizer.init(params)
def loss_fn(params, premises, hypotheses, labels, rng):
    """Return the mean cross-entropy of the training-mode logits.

    The network is applied with `training=True`, using `rng` as the
    'dropout' random stream; `labels` are integer class indices.
    """
    logits = net.apply(
        params, premises, hypotheses, training=True, rngs={'dropout': rng})
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits, labels)
    return per_example.mean()
@jax.jit
def train_step(params, opt_state, premises, hypotheses, labels, rng):
    """Run one optimizer update and return `(params, opt_state, loss)`.

    The loss and its gradient with respect to `params` come from `loss_fn`;
    the returned loss is the value computed before the update is applied.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, premises, hypotheses, labels, rng)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
@jax.jit
def eval_step(params, premises, hypotheses, labels):
    """Count how many predictions in the batch equal `labels`.

    The network is applied with `training=False`; the prediction is the
    arg-max over the last axis of the logits.
    """
    logits = net.apply(params, premises, hypotheses, training=False)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.sum(predictions == labels)
for epoch in range(num_epochs):
    # Example-weighted loss total and example count for this epoch.
    train_loss, n_train = 0.0, 0
    for batch in train_iter:
        premises, hypotheses, labels = batch[0], batch[1], batch[2]
        # Split off a fresh key for this step's dropout.
        rng, step_rng = jax.random.split(rng)
        params, opt_state, loss = train_step(
            params, opt_state, premises, hypotheses, labels, step_rng)
        batch_size = len(labels)
        n_train += batch_size
        # `loss` is a batch mean, so weight it by the batch size.
        train_loss += float(loss) * batch_size
    # Evaluate on test set
    n_correct, n_test = 0, 0
    for batch in test_iter:
        premises, hypotheses, labels = batch[0], batch[1], batch[2]
        n_test += len(labels)
        n_correct += int(eval_step(params, premises, hypotheses, labels))
    print(f'epoch {epoch + 1}, loss {train_loss / n_train:.4f}, '
          f'test acc {n_correct / n_test:.4f}')
# epoch 1, loss 0.8864, test acc 0.6559
epoch 2, loss 0.7418, test acc 0.6972
epoch 3, loss 0.6797, test acc 0.7232
epoch 4, loss 0.6332, test acc 0.7384
Read the examples semantically: “he is good” follows from “he is great”, while “he is bad” contradicts it. The model’s predictions should reflect that; the predicted class index is decoded as 0 → entailment, 1 → contradiction, 2 → neutral.
def predict_snli(net, params, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis.

    `vocab[premise]` and `vocab[hypothesis]` are turned into arrays with a
    leading axis of size 1 and fed to `net.apply` with `training=False`. The
    arg-max over axis 1 of the result is decoded as 0 -> 'entailment',
    1 -> 'contradiction', anything else -> 'neutral'.
    """
    premise_ids = jnp.array(vocab[premise]).reshape((1, -1))
    hypothesis_ids = jnp.array(vocab[hypothesis]).reshape((1, -1))
    logits = net.apply(params, premise_ids, hypothesis_ids, training=False)
    label = jnp.argmax(logits, axis=1)
    if label == 0:
        return 'entailment'
    if label == 1:
        return 'contradiction'
    return 'neutral'