1D convolution

Sentiment Analysis: Using Convolutional Neural Networks

textCNN

textCNN (Kim, 2014) — a 1D conv net for sentiment. Different architecture, same task as the RNN deck.

Why CNNs on text? Each filter is a learned n-gram detector. Run several filter widths in parallel (3, 4, 5 words) for multi-scale coverage. Max-over-time pool collapses position; concat → linear → softmax. Fast, strong, parallelizable.

Pipeline

GloVe → 1D conv filters of varying widths → max-pool → classifier.
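A rough shape walk-through of this pipeline, as a NumPy sketch. Everything here is an assumption for illustration: the sizes are toy values, random weights stand in for GloVe and for trained filters, windows are taken without padding, and the helper names `conv1d_valid` and `textcnn_forward` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumptions for illustration, not the deck's settings).
vocab_size, embed_dim, num_classes = 50, 8, 2
kernel_sizes, num_channels = [3, 4, 5], [4, 4, 4]

# Random stand-ins for the GloVe table, the conv filters and the classifier.
embedding = rng.normal(size=(vocab_size, embed_dim))
filters = [rng.normal(size=(k, embed_dim, c))
           for k, c in zip(kernel_sizes, num_channels)]
W_out = rng.normal(size=(sum(num_channels), num_classes))

def conv1d_valid(E, W):
    """Slide a width-k filter bank over E (length, dim) without padding."""
    k = W.shape[0]
    # One row of responses per window position: (length - k + 1, channels).
    return np.stack([np.einsum('kd,kdc->c', E[i:i + k], W)
                     for i in range(E.shape[0] - k + 1)])

def textcnn_forward(token_ids):
    E = embedding[token_ids]                                  # (length, embed_dim)
    pooled = [np.maximum(conv1d_valid(E, W), 0).max(axis=0)   # (channels,)
              for W in filters]
    features = np.concatenate(pooled)                         # (sum(num_channels),)
    return features @ W_out                                   # (num_classes,) logits

out_short = textcnn_forward(np.array([1, 2, 3, 4, 5]))
out_long = textcnn_forward(rng.integers(0, vocab_size, size=40))
print(out_short.shape, out_long.shape)  # (2,) (2,)
```

A 5-token and a 40-token input both end up as the same number of logits, because pooling removes the length axis before the classifier.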

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import flax

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

1D convolution

Sliding kernel over a 1D sequence. Output element = elementwise multiply + sum of an n-token window:

1D conv: kernel (1, 2) slides over input; first output is 0 × 1 + 1 × 2 = 2.

def corr1d(X, K):
    w = K.shape[0]
    Y = d2l.zeros((X.shape[0] - w + 1))
    for i in range(Y.shape[0]):
        Y = Y.at[i].set((X[i: i + w] * K).sum())
    return Y
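The loop above can be cross-checked against a library routine. The sketch below uses NumPy's `np.correlate` in `'valid'` mode (`jnp.correlate` takes the same arguments); `corr1d_loop` is a hypothetical NumPy restatement of the loop, included only so the snippet runs on its own.

```python
import numpy as np

def corr1d_loop(X, K):
    """Reference loop: one dot product per window of width len(K)."""
    w = K.shape[0]
    return np.array([(X[i:i + w] * K).sum()
                     for i in range(X.shape[0] - w + 1)])

X = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32)
K = np.array([1, 2], dtype=np.float32)

# 'valid' keeps only the positions where the kernel fits entirely.
Y_fast = np.correlate(X, K, mode='valid')
assert np.allclose(Y_fast, corr1d_loop(X, K))
print(Y_fast)  # values: 2, 5, 8, 11, 14, 17
```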

X, K = d2l.tensor([0, 1, 2, 3, 4, 5, 6]), d2l.tensor([1, 2])
corr1d(X, K)
Array([ 2.,  5.,  8., 11., 14., 17.], dtype=float32)

Multi-channel 1D conv

Embedding dim = input channels. The kernel has one row per input channel; the per-channel results are summed into a single output channel, and each additional kernel adds one more output channel.

3-channel 1D conv.

def corr1d_multi_in(X, K):
    # First, iterate through the 0th dimension (channel dimension) of `X` and
    # `K`. Then, add them together
    return sum(corr1d(x, k) for x, k in zip(X, K))

X = d2l.tensor([[0, 1, 2, 3, 4, 5, 6],
              [1, 2, 3, 4, 5, 6, 7],
              [2, 3, 4, 5, 6, 7, 8]])
K = d2l.tensor([[1, 2], [3, 4], [-1, -3]])
corr1d_multi_in(X, K)
Array([ 2.,  8., 14., 20., 26., 32.], dtype=float32)

Equivalent 2D-conv view

Equivalent to a 2D conv with kernel height = input height: stack the channels as rows, and the kernel can only slide along the sequence axis.
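The 2D view can be checked numerically. This is a minimal NumPy sketch: `corr2d` is a hypothetical helper written for this check, and the input and kernel are the same three-channel values used with `corr1d_multi_in`.

```python
import numpy as np

def corr2d(X, K):
    """Plain 2D cross-correlation without padding."""
    h, w = K.shape
    Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

# Channels stacked as rows: a 3 x 7 "image" and a 3 x 2 kernel.
X = np.array([[0, 1, 2, 3, 4, 5, 6],
              [1, 2, 3, 4, 5, 6, 7],
              [2, 3, 4, 5, 6, 7, 8]], dtype=float)
K = np.array([[1, 2], [3, 4], [-1, -3]], dtype=float)

Y = corr2d(X, K)
# The kernel is as tall as the input, so it can only slide horizontally.
print(Y.shape)  # (1, 6)
print(Y[0])     # values: 2, 8, 14, 20, 26, 32
```

The single output row matches the multi-channel 1D result.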

Max-over-time pooling

Take the max over the time axis for each filter. Resulting feature is independent of where in the sequence the n-gram appeared. One scalar per filter, regardless of sentence length:

Max-over-time = max along the sequence axis.
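A small NumPy sketch of the length and position independence. The activation values are made up for illustration.

```python
import numpy as np

# Feature maps for two filters: (time steps, filters). Values are made up.
short_review = np.array([[0.1, 0.9],
                         [0.7, 0.2],
                         [0.3, 0.4]])
# Same activations, preceded by two steps where neither filter fires.
long_review = np.vstack([np.zeros((2, 2)), short_review])

pooled_short = short_review.max(axis=0)  # max over the time axis
pooled_long = long_review.max(axis=0)

print(pooled_short)  # one value per filter: 0.7 and 0.9
assert pooled_short.shape == pooled_long.shape == (2,)
assert np.allclose(pooled_short, pooled_long)
```

Shifting the pattern two steps later and lengthening the sequence leaves the pooled features unchanged.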

textCNN model

Embedding (frozen GloVe + a fine-tunable copy) → parallel 1D convs at widths 3, 4, 5 → max-over-time → concat → dropout → linear:

class TextCNN(nn.Module):
    vocab_size: int
    embed_size: int
    kernel_sizes: list
    num_channels: list
    training: bool = True

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        # Second embedding table, meant to stay fixed at its pretrained values
        self.constant_embedding = nn.Embed(self.vocab_size, self.embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Dense(2)
        # Create multiple one-dimensional convolutional layers
        self.convs = [nn.Conv(features=c, kernel_size=(k,))
                      for c, k in zip(self.num_channels, self.kernel_sizes)]

    def __call__(self, inputs, deterministic=False):
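        # Concatenate the two embedding outputs, each of shape (batch size,
        # no. of tokens, token vector dimension), along the vector dimension.
        # stop_gradient blocks gradients into the constant table, so its
        # gradient is zero and Adam leaves it unchanged
        embeddings = jnp.concatenate((
            self.embedding(inputs),
            jax.lax.stop_gradient(self.constant_embedding(inputs))), axis=2)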
        # Flax Conv expects (batch, length, channels), which is already the
        # shape of `embeddings`. nn.Conv pads with 'SAME' by default, so each
        # conv output keeps the token length. After max-over-time pooling
        # (max over axis 1), each conv yields a tensor of shape (batch size,
        # no. of channels); concatenate these along the channel axis
        encoding = jnp.concatenate([
            jnp.max(nn.relu(conv(embeddings)), axis=1)
            for conv in self.convs], axis=1)
        outputs = self.decoder(self.dropout(encoding,
                                            deterministic=deterministic))
        return outputs

textCNN instance

The concrete model uses 100 channels at each kernel width. After max-over-time pooling, the classifier sees sum(num_channels) features, independent of review length.

embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
devices = d2l.try_all_gpus()
net = TextCNN(len(vocab), embed_size, kernel_sizes, nums_channels)
# Initialize parameters
dummy_input = jnp.ones((1, 500), dtype=jnp.int32)
params = net.init(jax.random.PRNGKey(0), dummy_input, deterministic=True)
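A back-of-the-envelope parameter count for this configuration. The sketch assumes Flax's conv kernel layout (kernel width, input channels, output channels) with one bias per output channel, and 200 input channels because the two embedding outputs are concatenated. The embedding tables are left out since their size depends on `len(vocab)`.

```python
embed_size = 100
kernel_sizes, num_channels = [3, 4, 5], [100, 100, 100]
in_channels = 2 * embed_size  # two embedding tables concatenated

# Conv kernel (k, in_channels, c) plus c biases.
conv_params = [k * in_channels * c + c
               for k, c in zip(kernel_sizes, num_channels)]
# Linear head: sum(num_channels) inputs, 2 classes, plus 2 biases.
head_params = sum(num_channels) * 2 + 2

print(conv_params)                     # [60100, 80100, 100100]
print(sum(conv_params) + head_params)  # 240902
```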

Loading pretrained GloVe

Both embedding tables start from the same GloVe vectors: one stays fixed as a semantic anchor, the other is fine-tuned for sentiment-specific cues.

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
# Set pretrained embedding weights in the parameters
params = flax.core.unfreeze(params)
params['params']['embedding']['embedding'] = jnp.array(embeds)
params['params']['constant_embedding']['embedding'] = jnp.array(embeds)
params = flax.core.freeze(params)

Training

CNNs train fast because all windows are processed in parallel. Use the metric output to compare with the BiLSTM deck: similar accuracy, less sequential computation.
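What "all windows in parallel" means can be sketched in NumPy: gather every window first, then apply the filters in one contraction, with no step depending on the previous one. Sizes and random values are arbitrary assumptions for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
length, dim, k, channels = 12, 6, 3, 5   # toy sizes, chosen arbitrarily
E = rng.normal(size=(length, dim))       # embedded sequence
W = rng.normal(size=(k, dim, channels))  # one filter bank of width k

# Sequential view: visit windows one at a time.
loop_out = np.stack([np.einsum('kd,kdc->c', E[i:i + k], W)
                     for i in range(length - k + 1)])

# Parallel view: gather all windows, then one contraction over all of them.
windows = sliding_window_view(E, k, axis=0)        # (length - k + 1, dim, k)
batched_out = np.einsum('ldk,kdc->lc', windows, W)

assert np.allclose(loop_out, batched_out)
print(batched_out.shape)  # (10, 5)
```

An RNN cannot be rewritten this way, since its state at step t needs the state at step t - 1.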

lr, num_epochs = 0.001, 5
optimizer = optax.adam(lr)
opt_state = optimizer.init(params['params'])
loss_fn = optax.softmax_cross_entropy_with_integer_labels

@jax.jit
def train_step(params, opt_state, X, y, key):
    def compute_loss(p):
        logits = net.apply({'params': p}, X, deterministic=False,
                           rngs={'dropout': key})
        return loss_fn(logits, y).mean(), logits
    (loss, logits), grads = jax.value_and_grad(
        compute_loss, has_aux=True)(params)
    updates, opt_state_new = optimizer.update(grads, opt_state, params)
    params_new = optax.apply_updates(params, updates)
    return params_new, opt_state_new, loss, logits

key = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    metric = d2l.Accumulator(4)
    for X, y in train_iter:
        key, subkey = jax.random.split(key)
        params_p = params['params']
        params_p, opt_state, l, logits = train_step(
            params_p, opt_state, X, y, subkey)
        params = {'params': params_p}
        metric.add(float(l) * len(y), float((logits.argmax(axis=-1) == y).sum()),
                   len(y), len(y))
    # Evaluate
    correct, total = 0, 0
    for X, y in test_iter:
        logits = net.apply(params, X, deterministic=True,
                           rngs={'dropout': jax.random.PRNGKey(0)})
        correct += int((logits.argmax(axis=-1) == y).sum())
        total += len(y)
    print(f'epoch {epoch + 1}, loss {metric[0] / metric[2]:.3f}, '
          f'train acc {metric[1] / metric[3]:.3f}, '
          f'test acc {correct / total:.3f}')
epoch 1, loss 0.524, train acc 0.742, test acc 0.841
epoch 2, loss 0.320, train acc 0.862, test acc 0.868
epoch 3, loss 0.207, train acc 0.917, test acc 0.873
epoch 4, loss 0.107, train acc 0.962, test acc 0.847
epoch 5, loss 0.050, train acc 0.985, test acc 0.865

Prediction

tokens = jnp.array(vocab['this movie is so great'.split()])
logits = net.apply(params, tokens.reshape(1, -1), deterministic=True,
                   rngs={'dropout': jax.random.PRNGKey(0)})
'positive' if int(jnp.argmax(logits, axis=1)[0]) == 1 else 'negative'
'positive'
tokens = jnp.array(vocab['this movie is so bad'.split()])
logits = net.apply(params, tokens.reshape(1, -1), deterministic=True,
                   rngs={'dropout': jax.random.PRNGKey(0)})
'positive' if int(jnp.argmax(logits, axis=1)[0]) == 1 else 'negative'
'negative'
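The two prediction cells repeat the same steps, so they could be folded into a helper. This is a sketch, not part of the deck: `predict_sentiment`, `toy_vocab` and `toy_apply` are hypothetical names, and the stand-in model exists only so the snippet runs without trained parameters.

```python
import numpy as np

def predict_sentiment(apply_fn, vocab, sentence):
    """Classify one whitespace-tokenized sentence.

    apply_fn maps an int array of shape (1, num_tokens) to logits of shape
    (1, 2); vocab maps a token to its index.
    """
    token_ids = np.array([vocab[token] for token in sentence.split()])
    logits = np.asarray(apply_fn(token_ids.reshape(1, -1)))
    return 'positive' if int(logits.argmax(axis=1)[0]) == 1 else 'negative'

# Stand-in model for a dry run: positive if token id 1 ('great') occurs.
toy_vocab = {'this': 0, 'great': 1, 'bad': 2}

def toy_apply(X):
    has_great = float((X == 1).any())
    return np.array([[1.0 - has_great, has_great]])

print(predict_sentiment(toy_apply, toy_vocab, 'this great'))  # positive
print(predict_sentiment(toy_apply, toy_vocab, 'this bad'))    # negative
```

With the deck's objects this might be called as `predict_sentiment(lambda X: net.apply(params, jnp.asarray(X), deterministic=True), vocab, 'this movie is so great')`.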

Recap

  • textCNN = parallel 1D convs over word embeddings + max-over-time pooling + linear head.
  • Each filter learns an n-gram detector; different widths give multi-scale coverage.
  • Accuracy comparable to the BiLSTM on IMDb (test accuracy of roughly 0.84 to 0.87 in the run above) with zero recurrence, which typically means shorter training time.
  • The shape (parallel filter widths, pooled features) is the template for many text-classification CNNs that followed.