BiRNN classifier

Sentiment Analysis: Using Recurrent Neural Networks

Sentiment RNN

Sentiment classification on IMDb: pretrained word vectors → bidirectional LSTM → linear head. Standard pre-Transformer text-classification recipe.

The encoder reads the review left-to-right and right-to-left; concatenated final hidden states feed a binary classifier. GloVe gives a strong initialization that the LSTM then specializes for sentiment.
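The sketch below shows what "reads left-to-right and right-to-left" means. It uses a plain tanh RNN cell written with `jax.lax.scan`, not the Flax LSTM used in this section. The cell, the sizes, and the helper names are illustrative assumptions. One pass runs over the sequence, one runs over its time reversal, and the two final hidden states are concatenated.

```python
import jax
import jax.numpy as jnp


def rnn_final_state(W_x, W_h, b, xs):
    """Run a tanh RNN over `xs` (steps, features); return the last hidden state."""
    def step(h, x_t):
        h = jnp.tanh(x_t @ W_x + h @ W_h + b)
        return h, h  # (new carry, per-step output)
    h0 = jnp.zeros(W_h.shape[0])
    h_last, _ = jax.lax.scan(step, h0, xs)
    return h_last


def bidirectional_encode(fwd_params, bwd_params, xs):
    """Concatenate the final states of a forward and a backward pass over `xs`."""
    h_fwd = rnn_final_state(*fwd_params, xs)        # has read x_0, ..., x_{T-1}
    h_bwd = rnn_final_state(*bwd_params, xs[::-1])  # has read x_{T-1}, ..., x_0
    return jnp.concatenate([h_fwd, h_bwd])          # shape (2 * hidden,)


def init_params(key, n_in=4, n_h=3):
    """Random (W_x, W_h, b) for a toy cell: 4 input features, 3 hidden units."""
    k1, k2 = jax.random.split(key)
    return (0.5 * jax.random.normal(k1, (n_in, n_h)),
            0.5 * jax.random.normal(k2, (n_h, n_h)),
            jnp.zeros(n_h))


key_x, key_f, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.random.normal(key_x, (7, 4))  # 7 time steps, 4 features
fwd, bwd = init_params(key_f), init_params(key_b)
encoding = bidirectional_encode(fwd, bwd, xs)
print(encoding.shape)  # (6,)

# A batch of sequences: vmap over the leading (batch) axis
batch = jax.random.normal(key_x, (2, 7, 4))
batch_enc = jax.vmap(lambda s: bidirectional_encode(fwd, bwd, s))(batch)
print(batch_enc.shape)  # (2, 6)
```

Each half of the encoding has seen the whole review, but from opposite ends. That is why a single fixed-size vector can summarize context on both sides of every word.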

Pipeline

GloVe embeddings → BiLSTM → output classifier.

Setup

from d2l import jax as d2l
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
import optax
import numpy as np

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

Class definition: embedding -> bidirectional LSTM -> concatenate the final hidden state of each direction -> 2-way decoder. The decoder input therefore has width 2h (h = num_hiddens), with one final hidden state per direction. `nn.Dense` infers this input width at initialization.

class BiRNN(nn.Module):
    vocab_size: int
    embed_size: int
    num_hiddens: int
    num_layers: int

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        # One forward and one backward LSTM per layer. With reverse=True,
        # Flax returns outputs in reversed (processing) order unless
        # keep_order=True, which flips them back to input order
        self.forward_rnns = [
            nn.RNN(nn.OptimizedLSTMCell(self.num_hiddens))
            for _ in range(self.num_layers)]
        self.backward_rnns = [
            nn.RNN(nn.OptimizedLSTMCell(self.num_hiddens),
                   reverse=True, keep_order=True)
            for _ in range(self.num_layers)]
        self.decoder = nn.Dense(2)

    def __call__(self, inputs):
        # The shape of `inputs` is (batch size, no. of time steps)
        x = self.embedding(inputs)
        for forward_rnn, backward_rnn in zip(self.forward_rnns,
                                             self.backward_rnns):
            # Each output shape is (batch size, no. of time steps,
            # num_hiddens); the next layer reads both directions
            # concatenated, i.e. 2 * num_hiddens features per step
            forward_out = forward_rnn(x)
            backward_out = backward_rnn(x)
            x = jnp.concatenate([forward_out, backward_out], axis=-1)
        # Because of keep_order=True, `backward_out[:, 0, :]` is the top
        # backward LSTM's state after consuming the whole reversed sequence
        # (its *final* hidden), and `forward_out[:, -1, :]` is the top forward
        # LSTM's final hidden. Concatenate them: shape (batch size,
        # 2 * num_hiddens).
        encoding = jnp.concatenate(
            [forward_out[:, -1, :], backward_out[:, 0, :]], axis=1)
        return self.decoder(encoding)

BiRNN instance

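Instantiate a two-layer BiLSTM (`num_layers = 2`) with 100-dimensional embeddings and 100 hidden units per direction. The LSTM weights keep Flax's default initializers. Other frameworks use different defaults, but the model takes the same batches of token ids and returns the same two logits per review: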

embed_size, num_hiddens, num_layers, devices = 100, 100, 2, d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)
# Flax creates parameters explicitly: `init` runs the model once on a sample
# batch of token ids to infer every parameter's shape
dummy_input = jnp.ones((1, 500), dtype=jnp.int32)
params = net.init(jax.random.PRNGKey(0), dummy_input)

Loading pretrained GloVe

Use 100-dim GloVe vectors trained on Wikipedia + Gigaword. Initialize the embedding layer from them; freeze or fine-tune (we fine-tune):

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
embeds.shape
(49346, 100)
# Set pretrained embedding weights in the parameters
params = flax.core.unfreeze(params)
params['params']['embedding']['embedding'] = jnp.array(embeds)
params = flax.core.freeze(params)

Training

Standard cross-entropy + Adam. Watch held-out accuracy (the test split here), not just training loss. Sentiment models overfit IMDb quickly, and in the run below training accuracy keeps climbing after test accuracy levels off at about 0.80:

lr, num_epochs = 0.01, 5
# Optimize the inner 'params' collection directly; `train_step` and
# `eval_step` re-wrap it as {'params': ...} when calling `net.apply`
params_p = params['params']
optimizer = optax.adam(lr)
opt_state = optimizer.init(params_p)
loss_fn = optax.softmax_cross_entropy_with_integer_labels

@jax.jit
def train_step(params_p, opt_state, X, y):
    def compute_loss(p):
        logits = net.apply({'params': p}, X)
        return loss_fn(logits, y).mean(), logits
    (loss, logits), grads = jax.value_and_grad(
        compute_loss, has_aux=True)(params_p)
    updates, opt_state = optimizer.update(grads, opt_state, params_p)
    params_p = optax.apply_updates(params_p, updates)
    return params_p, opt_state, loss, logits

@jax.jit
def eval_step(params_p, X):
    return net.apply({'params': params_p}, X)

for epoch in range(num_epochs):
    metric = d2l.Accumulator(4)
    for X, y in train_iter:
        params_p, opt_state, l, logits = train_step(
            params_p, opt_state, X, y)
        metric.add(float(l) * len(y),
                   float((logits.argmax(axis=-1) == y).sum()),
                   len(y), len(y))
    # Evaluate
    correct, total = 0, 0
    for X, y in test_iter:
        logits = eval_step(params_p, X)
        correct += int((logits.argmax(axis=-1) == y).sum())
        total += len(y)
    print(f'epoch {epoch + 1}, loss {metric[0] / metric[2]:.3f}, '
          f'train acc {metric[1] / metric[3]:.3f}, '
          f'test acc {correct / total:.3f}')
# Re-wrap params for downstream use (e.g. predict_sentiment)
params = {'params': params_p}
epoch 1, loss 0.694, train acc 0.505, test acc 0.502
epoch 2, loss 0.679, train acc 0.526, test acc 0.628
epoch 3, loss 0.481, train acc 0.776, test acc 0.789
epoch 4, loss 0.283, train acc 0.889, test acc 0.806
epoch 5, loss 0.190, train acc 0.930, test acc 0.805

Predict on new reviews

As a final sanity check, a clearly positive and a clearly negative hand-written review should receive different labels. This is not an evaluation, but it catches label or ordering mistakes in the pipeline.

def predict_sentiment(net, params, vocab, sequence):
    """Predict the sentiment of a text sequence."""
    sequence = jnp.array(vocab[sequence.split()])
    logits = net.apply(params, sequence.reshape(1, -1))
    label = int(jnp.argmax(logits, axis=1)[0])
    return 'positive' if label == 1 else 'negative'

predict_sentiment(net, params, vocab, 'this movie is so great')
'negative'
predict_sentiment(net, params, vocab, 'this movie is so bad')
'negative'
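`predict_sentiment` feeds only the review's own tokens, five here. The training reviews, by contrast, were presumably fixed to 500 tokens: that is the length of `dummy_input`, and the D2L IMDb loader in the PyTorch version truncates and right-pads every review with `<pad>`. Since the classifier reads the forward LSTM's state at the last position, it may help to shape a new review the same way. The sketch below uses hypothetical helpers `truncate_pad` and `encode_review`. The `pad_id` default of 0 is an assumption; in practice, use `vocab['<pad>']`.

```python
import jax.numpy as jnp


def truncate_pad(ids, num_steps, pad_id):
    """Hypothetical helper: cut `ids` to `num_steps` or right-pad with `pad_id`."""
    if len(ids) > num_steps:
        return ids[:num_steps]
    return ids + [pad_id] * (num_steps - len(ids))


def encode_review(token_ids, num_steps=500, pad_id=0):
    """Hypothetical helper: shape one review like a training row, (1, num_steps)."""
    return jnp.array([truncate_pad(list(token_ids), num_steps, pad_id)],
                     dtype=jnp.int32)


print(truncate_pad([5, 6, 7], 5, 0))       # [5, 6, 7, 0, 0]
print(truncate_pad(list(range(8)), 5, 0))  # [0, 1, 2, 3, 4]
print(encode_review([11, 12, 13]).shape)   # (1, 500)
```

With this section's objects, and assuming the vocabulary reserves a `<pad>` token, the call would be `net.apply(params, encode_review(vocab['this movie is so great'.split()], 500, vocab['<pad>']))`.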

Recap

  • BiLSTM-on-GloVe: a strong pre-Transformer baseline for text classification.
  • Pretrained embeddings carry general-purpose word semantics; LSTM specializes for sentiment.
  • Easily beaten today by fine-tuned BERT, but a clean template for sequence-to-label tasks more broadly.