from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import flax
# Number of examples per minibatch for the IMDb train/test iterators.
batch_size = 64
# The loader returns the training iterator, the test iterator and the vocabulary.
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)
# textCNN (Kim, 2014) — a 1D conv net for sentiment. Different architecture, same task as the RNN deck.
Why CNNs on text? Each filter is a learned n-gram detector. Run several filter widths in parallel (3, 4, 5 words) for multi-scale coverage. Max-over-time pooling collapses position; concat → linear → softmax. Fast, strong, parallelizable.
GloVe → 1D conv filters of varying widths → max-pool → classifier.
Slide a kernel over a 1D sequence: each output element is the sum of the elementwise product between the kernel and an $n$-token window:
1D cross-correlation: the kernel $(1, 2)$ slides over the input; the first output is $0 \cdot 1 + 1 \cdot 2 = 2$.
The embedding dimension plays the role of input channels. The kernel has the same number of channels, and the per-channel results are summed, so the output is single-channel (or multi-channel if you use multiple kernels).
3-channel 1D conv.
Array([ 2., 5., 8., 11., 14., 17.], dtype=float32)
This is equivalent to a single-input-channel 2D cross-correlation whose kernel height equals the input height:
def corr1d_multi_in(X, K):
    """Multi-input-channel 1D cross-correlation.

    Walks the leading (channel) axis of ``X`` and ``K`` in lockstep, runs
    ``corr1d`` on each channel pair, and returns the elementwise sum of the
    per-channel results.
    """
    total = 0
    for x_channel, k_channel in zip(X, K):
        # Accumulate channel by channel, starting from 0 like builtin sum().
        total = total + corr1d(x_channel, k_channel)
    return total
# Three input channels (rows); row i holds the integers i, i+1, ..., i+6.
X = d2l.tensor([list(range(start, start + 7)) for start in range(3)])
# A matching 3-channel kernel of width 2 (one row per input channel).
K = d2l.tensor([[1, 2], [3, 4], [-1, -3]])
corr1d_multi_in(X, K)
# Array([ 2., 8., 14., 20., 26., 32.], dtype=float32)
Take the max over the time axis for each filter. The resulting feature is independent of where in the sequence the n-gram appeared, giving one scalar per filter regardless of sentence length:
Max-over-time = max along the sequence axis.
Embedding (frozen GloVe + a fine-tunable copy) → parallel 1D convs at widths 3, 4, 5 → max-over-time → concat → dropout → linear:
class TextCNN(nn.Module):
    """textCNN sentiment classifier (Kim, 2014).

    Token indices are looked up in two embedding tables whose outputs are
    concatenated along the feature axis. ``embedding`` is trainable.
    ``constant_embedding`` is meant to stay fixed: its output is wrapped in
    ``jax.lax.stop_gradient``, so the gradient w.r.t. its table is always
    zero. (Optimizers that apply weight decay would still change it.)

    The concatenated embeddings go through parallel 1D convolutions, one per
    kernel width, each followed by ReLU and max-over-time pooling. The pooled
    features are concatenated, passed through dropout, and fed to a 2-unit
    dense layer.

    Attributes:
        vocab_size: number of rows in each embedding table.
        embed_size: width of each embedding table; the convolutions therefore
            see ``2 * embed_size`` input channels.
        kernel_sizes: convolution widths, one per conv layer.
        num_channels: output features per conv layer, paired element-wise
            with ``kernel_sizes``.
        training: not read by this module; dropout is controlled by the
            ``deterministic`` argument of ``__call__``.
    """
    vocab_size: int
    embed_size: int
    kernel_sizes: list
    num_channels: list
    training: bool = True

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        # The embedding layer not to be trained: __call__ blocks gradients
        # through it with jax.lax.stop_gradient.
        self.constant_embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Dense(2)
        # Create one one-dimensional convolutional layer per
        # (channels, width) pair.
        # NOTE(review): padding is left at the nn.Conv default; Kim (2014)
        # uses narrow ("valid") convolutions — confirm which is intended.
        self.convs = [nn.Conv(features=c, kernel_size=(k,))
                      for c, k in zip(self.num_channels, self.kernel_sizes)]

    def __call__(self, inputs, deterministic=False):
        """Return 2-way class logits for a batch of token-index sequences.

        Args:
            inputs: integer token indices, presumably of shape
                (batch size, no. of tokens).
            deterministic: if True, dropout is disabled. If False, dropout is
                applied; the training code in this file supplies its
                randomness via ``rngs={'dropout': key}``.

        Returns:
            Logits whose last axis has size 2.
        """
        # Concatenate two embedding layer outputs with shape (batch size, no.
        # of tokens, token vector dimension) along vectors. stop_gradient
        # keeps the "constant" table out of training, as intended.
        embeddings = jnp.concatenate((
            self.embedding(inputs),
            jax.lax.stop_gradient(self.constant_embedding(inputs))), axis=2)
        # For Flax Conv, input shape is (batch, length, channels), which is
        # already the shape of embeddings.
        # For each one-dimensional convolutional layer, after max-over-time
        # pooling, a tensor of shape (batch size, no. of channels) is
        # obtained. Concatenate along channels.
        encoding = jnp.concatenate([
            jnp.max(nn.relu(conv(embeddings)), axis=1)
            for conv in self.convs], axis=1)
        outputs = self.decoder(self.dropout(encoding,
                                            deterministic=deterministic))
        return outputs
# The concrete model uses 100 channels at each kernel width. After max-over-time
# pooling, the classifier sees sum(num_channels) features, independent of review
# length.
# Model hyperparameters: 100-d embeddings, kernel widths 3/4/5, and
# 100 output channels for every kernel width.
embed_size = 100
kernel_sizes = [3, 4, 5]
nums_channels = [100] * len(kernel_sizes)
devices = d2l.try_all_gpus()
net = TextCNN(vocab_size=len(vocab), embed_size=embed_size,
              kernel_sizes=kernel_sizes, num_channels=nums_channels)
# Build the parameter tree by tracing the model once on a dummy batch of
# one all-ones int32 sequence of length 500, with dropout disabled.
dummy_input = jnp.ones((1, 500), dtype=jnp.int32)
params = net.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
# Both embedding tables start from the same GloVe vectors: one stays fixed as a
# semantic anchor, the other is fine-tuned for sentiment-specific cues.
# Load the 100-d GloVe vectors and take one row per vocabulary token.
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
# Write the pretrained matrix into both embedding tables. The parameter tree
# is unfrozen for assignment and frozen again afterwards.
params = flax.core.unfreeze(params)
for table_name in ('embedding', 'constant_embedding'):
    params['params'][table_name]['embedding'] = jnp.array(embeds)
params = flax.core.freeze(params)
# CNNs train fast because all windows are processed in parallel. Use the metric
# output to compare with the BiLSTM deck: similar accuracy, less sequential
# computation.
# Optimization settings.
lr = 0.001
num_epochs = 5
optimizer = optax.adam(lr)
# Adam state covers the whole inner parameter tree.
opt_state = optimizer.init(params['params'])
loss_fn = optax.softmax_cross_entropy_with_integer_labels


@jax.jit
def train_step(params, opt_state, X, y, key):
    """Run one jitted Adam step on a minibatch.

    Args:
        params: the inner parameter tree (the value under the 'params' key).
        opt_state: Adam state produced by ``optimizer.init``/``update``.
        X: minibatch of token indices.
        y: integer class labels for ``X``.
        key: PRNG key used for dropout in this step.

    Returns:
        A tuple ``(new_params, new_opt_state, mean_loss, logits)``. The loss
        and logits are computed with the pre-update parameters.
    """
    def batch_loss(trainable):
        # Dropout is active during training; ``key`` drives its randomness.
        logits = net.apply({'params': trainable}, X, deterministic=False,
                           rngs={'dropout': key})
        return jnp.mean(loss_fn(logits, y)), logits

    loss_and_grad = jax.value_and_grad(batch_loss, has_aux=True)
    (loss, logits), grads = loss_and_grad(params)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), new_opt_state, loss, logits
# Root PRNG key; a fresh dropout subkey is split off for every minibatch.
key = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    # Running sums for this epoch: loss * n, #correct, n, n.
    metric = d2l.Accumulator(4)
    for X, y in train_iter:
        key, subkey = jax.random.split(key)
        params_p, opt_state, batch_loss, logits = train_step(
            params['params'], opt_state, X, y, subkey)
        params = {'params': params_p}
        batch_n = len(y)
        batch_correct = float((logits.argmax(axis=-1) == y).sum())
        metric.add(float(batch_loss) * batch_n, batch_correct,
                   batch_n, batch_n)
    # Test-set accuracy with dropout disabled.
    correct = total = 0
    for X, y in test_iter:
        logits = net.apply(params, X, deterministic=True,
                           rngs={'dropout': jax.random.PRNGKey(0)})
        correct += int((logits.argmax(axis=-1) == y).sum())
        total += len(y)
    print(f'epoch {epoch + 1}, loss {metric[0] / metric[2]:.3f}, '
          f'train acc {metric[1] / metric[3]:.3f}, '
          f'test acc {correct / total:.3f}')
# Output recorded from a previous run:
# epoch 1, loss 0.524, train acc 0.742, test acc 0.841
# epoch 2, loss 0.320, train acc 0.862, test acc 0.868
# epoch 3, loss 0.207, train acc 0.917, test acc 0.873
# epoch 4, loss 0.107, train acc 0.962, test acc 0.847
# epoch 5, loss 0.050, train acc 0.985, test acc 0.865
'positive'