from d2l import tensorflow as d2l
import tensorflow as tf
import keras
Decomposable Attention (Parikh et al., 2016) — a small, fast natural language inference (NLI) model that outperformed much more complex recurrence-based architectures on SNLI in 2016. No recurrence, no convolution — pure attention + MLPs.
Three steps: Attend → Compare → Aggregate.
GloVe → attend → compare → aggregate → 3-way classifier.
Align premise/hypothesis tokens, then compare and aggregate.
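Compute alignment weights between every premise word and every hypothesis word, $e_{ij} = f(\mathbf{a}_i)^\top f(\mathbf{b}_j)$, where $f$ is an MLP. Use them to build soft-aligned context vectors: $\boldsymbol{\beta}_i = \sum_{j} \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{ik})} \mathbf{b}_j$ (the hypothesis aligned to premise word $i$) and $\boldsymbol{\alpha}_j = \sum_{i} \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{kj})} \mathbf{a}_i$ (the premise aligned to hypothesis word $j$):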
def mlp(num_hiddens, flatten, dropout=0.2):
    """Build the two-layer MLP used by the Attend, Compare and Aggregate stages.

    The network is ``(Dropout -> Dense(relu) [-> Flatten])`` stacked twice.

    Args:
        num_hiddens: Width of both ``Dense`` layers.
        flatten: If true, append a ``Flatten`` layer after each ``Dense``
            layer (``Aggregate`` builds its MLP with ``flatten=True``).
        dropout: Dropout rate applied before each ``Dense`` layer. The
            default of 0.2 is the value that was previously hard-coded.

    Returns:
        An unbuilt ``keras.Sequential`` model.
    """
    net = keras.Sequential()
    # The same Dropout -> Dense(relu) [-> Flatten] stanza, added twice.
    for _ in range(2):
        net.add(keras.layers.Dropout(dropout))
        net.add(keras.layers.Dense(num_hiddens, activation='relu'))
        if flatten:
            net.add(keras.layers.Flatten())
    return net


class Attend(keras.layers.Layer):
    """Soft-align each token of one sequence with all tokens of the other.

    Scores every token pair with ``e = f(A) @ f(B)^T`` (``f`` is a shared
    MLP), then returns ``beta`` (B softly aligned to each token of A) and
    ``alpha`` (A softly aligned to each token of B).
    """

    def __init__(self, num_hiddens, **kwargs):
        super(Attend, self).__init__(**kwargs)
        # One MLP `f`, shared by both sequences, applied before scoring.
        self.f = mlp(num_hiddens=num_hiddens, flatten=False)

    def call(self, A, B):
        # Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
        # `embed_size`)
        # Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
        # `num_hiddens`)
        f_A = self.f(A)
        f_B = self.f(B)
        # Shape of `e`: (`batch_size`, no. of tokens in sequence A,
        # no. of tokens in sequence B); e[:, i, j] scores token i of A
        # against token j of B.
        e = tf.matmul(f_A, tf.transpose(f_B, perm=[0, 2, 1]))
        # Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
        # `embed_size`), where sequence B is softly aligned with each token
        # (axis 1 of `beta`) in sequence A. Softmax over B's tokens.
        beta = tf.matmul(tf.nn.softmax(e, axis=-1), B)
        # Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
        # `embed_size`), where sequence A is softly aligned with each token
        # (axis 1 of `alpha`) in sequence B. Transposing `e` puts A's tokens
        # on the last axis so the softmax normalizes over them.
        alpha = tf.matmul(
            tf.nn.softmax(tf.transpose(e, perm=[0, 2, 1]), axis=-1), A)
        return beta, alpha


# For each premise word a_i, run an MLP on [a_i, beta_i], where beta_i is the
# soft-aligned hypothesis context. Same for hypothesis words:
class Compare(keras.layers.Layer):
    """Compare every token with its soft-aligned counterpart.

    A single MLP ``g`` is applied to ``[A, beta]`` and to ``[B, alpha]``,
    each concatenated along the feature axis (axis 2).
    """

    def __init__(self, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.g = mlp(num_hiddens=num_hiddens, flatten=False)

    def call(self, A, B, beta, alpha):
        """Return ``(V_A, V_B)``: one comparison vector per token of A and of B."""
        premise_side = tf.concat([A, beta], axis=2)
        hypothesis_side = tf.concat([B, alpha], axis=2)
        return self.g(premise_side), self.g(hypothesis_side)


# Sum the per-token compared vectors -> concat the two sentence summaries ->
# final MLP -> 3-way logits:
class Aggregate(keras.layers.Layer):
    """Pool each sequence's comparison vectors and classify the pair.

    Args:
        num_hiddens: Hidden width of the MLP ``h``.
        num_outputs: Number of logits produced by the final ``Dense`` layer.
    """

    def __init__(self, num_hiddens, num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.h = mlp(num_hiddens=num_hiddens, flatten=True)
        self.linear = keras.layers.Dense(num_outputs)

    def call(self, V_A, V_B):
        # Collapse the token axis (axis 1) of each sequence by summation, then
        # classify the concatenated pair of summaries.
        # NOTE(review): every position along axis 1 is summed, including any
        # padding positions -- no mask is applied.
        summaries = tf.concat(
            [tf.reduce_sum(V_A, axis=1), tf.reduce_sum(V_B, axis=1)], axis=1)
        return self.linear(self.h(summaries))


# The final module wires the three stages into one classifier. Inputs are
# premise IDs and hypothesis IDs; output is 3 logits for entailment,
# contradiction, and neutral.
class DecomposableAttention(keras.Model):
    """Decomposable attention classifier for natural language inference.

    Embeds premise and hypothesis token IDs, then runs the attend, compare
    and aggregate stages to produce one logit per class.

    Args:
        vocab: Vocabulary; only ``len(vocab)`` is used (number of embedding
            rows).
        embed_size: Embedding dimension.
        num_hiddens: Hidden width of every stage's MLP.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, vocab, embed_size, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.embedding = keras.layers.Embedding(len(vocab), embed_size)
        self.attend = Attend(num_hiddens)
        self.compare = Compare(num_hiddens)
        # Three outputs: entailment, contradiction, and neutral
        self.aggregate = Aggregate(num_hiddens, 3)

    def call(self, inputs, training=False, **kwargs):
        """Return class logits for a ``(premises, hypotheses)`` pair of ID tensors.

        ``training`` is not referenced here directly; Keras propagates the
        training mode to nested layers through its call context.
        """
        premise_ids, hypothesis_ids = inputs
        embedded_p = self.embedding(premise_ids)
        embedded_h = self.embedding(hypothesis_ids)
        aligned_h, aligned_p = self.attend(embedded_p, embedded_h)
        compared = self.compare(embedded_p, embedded_h, aligned_h, aligned_p)
        return self.aggregate(*compared)


# SNLI examples are padded premise/hypothesis pairs. Initialize the model with
# GloVe embeddings, then train the whole model end-to-end:
read 549367 examples
read 9824 examples
embed_size = 100
num_hiddens = 200
devices = d2l.try_all_gpus()  # NOTE(review): not used in the code shown here
net = DecomposableAttention(vocab, embed_size, num_hiddens)

# Look up a 100-d GloVe vector for every vocabulary token, in index order.
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]

# The embedding layer creates its weight table in `build`, so build it before
# overwriting that table with the pretrained vectors.
net.embedding.build((None,))
net.embedding.set_weights([embeds])

# Training should be fast: there is no recurrence, so every token-pair
# alignment and every MLP comparison can be computed in parallel.
lr = 0.001
num_epochs = 4


def reformat(premises, hypotheses, labels):
    """Regroup a flat ``(premises, hypotheses, labels)`` element as ``(X, y)``.

    Keras ``fit`` consumes ``(inputs, targets)`` pairs, and the model unpacks
    its inputs as a ``(premises, hypotheses)`` pair.
    """
    features = (premises, hypotheses)
    return features, labels


# Assumes both datasets yield (premises, hypotheses, labels) triples --
# TODO confirm against the data-loading cell, which is not shown here.
train_iter_tf, test_iter_tf = [
    dataset.map(reformat) for dataset in (train_iter, test_iter)]

# The classifier's last Dense layer has no activation, so the loss is told it
# receives raw logits.
net.compile(
    optimizer=keras.optimizers.Adam(lr),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
net.fit(train_iter_tf, validation_data=test_iter_tf, epochs=num_epochs)
# Epoch 1/4
Final epoch metrics: accuracy: 0.3281 - loss: 4.0683
Final epoch metrics: accuracy: 0.3148 - loss: 4.3249
Final epoch metrics: accuracy: 0.3260 - loss: 3.4287
Final epoch metrics: accuracy: 0.3315 - loss: 2.9140
Final epoch metrics: accuracy: 0.3364 - loss: 2.6008
...
Final epoch metrics: accuracy: 0.7942 - loss: 0.5206
Final epoch metrics: accuracy: 0.7942 - loss: 0.5206
Final epoch metrics: accuracy: 0.7942 - loss: 0.5205
Final epoch metrics: accuracy: 0.7942 - loss: 0.5205
Final epoch metrics: accuracy: 0.7949 - loss: 0.5184 - val_accuracy: 0.8150 - val_loss: 0.4777
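As a sanity check, try pairs whose relationship is obvious to a human reader: “he is good” follows from “he is great” (entailment), while “he is bad” contradicts it (contradiction). `predict_snli` maps class index 0 to entailment, 1 to contradiction, and 2 to neutral — the same order as the three outputs listed in `DecomposableAttention`.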
def predict_snli(net, vocab, premise, hypothesis):
    """Predict the logical relationship between the premise and hypothesis.

    Args:
        net: Classifier called as ``net((premises, hypotheses), training=False)``;
            expected to return logits of shape (batch, num_classes).
        vocab: Vocabulary; ``vocab[tokens]`` must yield a sequence of token
            indices.
        premise: Tokenized premise (presumably a list of token strings).
        hypothesis: Tokenized hypothesis (presumably a list of token strings).

    Returns:
        ``'entailment'`` for class 0, ``'contradiction'`` for class 1, and
        ``'neutral'`` for any other class index.
    """
    premise = tf.constant([vocab[premise]], dtype=tf.int32)
    hypothesis = tf.constant([vocab[hypothesis]], dtype=tf.int32)
    logits = net((premise, hypothesis), training=False)
    # Batch of one: extract the single argmax as a Python int rather than
    # branching on the truthiness of a shape-(1,) tensor comparison, which
    # fails outside eager execution.
    label = int(tf.argmax(logits, axis=1)[0])
    return {0: 'entailment', 1: 'contradiction'}.get(label, 'neutral')