from d2l import jax as d2l
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
import optax
import numpy as np
# Minibatch size passed to the IMDb data loader below.
batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

Sentiment classification on IMDb: pretrained word vectors → bidirectional LSTM → linear head. Standard pre-Transformer text-classification recipe.
The encoder reads the review left-to-right and right-to-left; concatenated final hidden states feed a binary classifier. GloVe gives a strong initialization that the LSTM then specializes for sentiment.
GloVe embeddings → BiLSTM → output classifier.
Class definition: embedding -> bidirectional LSTM -> concatenate the forward LSTM's final hidden state with the backward LSTM's final hidden state -> 2-way decoder. The decoder input has width 2h: one final hidden state of size h per direction.
class BiRNN(nn.Module):
    """Bidirectional LSTM text classifier with a 2-way output layer.

    Token ids are embedded, passed through `num_layers` stacked
    bidirectional LSTM layers, and the final hidden state of each direction
    of the top layer is concatenated and fed to a dense layer with two
    outputs.

    Attributes:
        vocab_size: Number of rows in the embedding table.
        embed_size: Width of each embedding vector.
        num_hiddens: Hidden size of every LSTM cell (per direction).
        num_layers: Number of stacked bidirectional layers; must be >= 1.
    """
    vocab_size: int
    embed_size: int
    num_hiddens: int
    num_layers: int

    def setup(self):
        """Create the embedding table, the LSTM stack and the decoder.

        Raises:
            ValueError: If `num_layers` is smaller than 1.
        """
        if self.num_layers < 1:
            raise ValueError(
                f'num_layers must be >= 1, got {self.num_layers}')
        self.embedding = nn.Embed(self.vocab_size, self.embed_size)
        # One forward and one backward LSTM per layer.
        self.forward_rnns = [
            nn.RNN(nn.OptimizedLSTMCell(self.num_hiddens), reverse=False)
            for _ in range(self.num_layers)]
        # `reverse=True` alone leaves the output in reversed (processing)
        # order. `keep_order=True` flips it back to input order so that it
        # is aligned with the forward output at every time step and so that
        # index 0 holds the state reached after the whole sequence was read.
        self.backward_rnns = [
            nn.RNN(nn.OptimizedLSTMCell(self.num_hiddens), reverse=True,
                   keep_order=True)
            for _ in range(self.num_layers)]
        self.decoder = nn.Dense(2)

    def __call__(self, inputs):
        """Return class logits of shape (batch size, 2).

        Args:
            inputs: Token ids of shape (batch size, no. of time steps).
        """
        # Shape: (batch size, no. of time steps, embed_size)
        layer_input = self.embedding(inputs)
        for forward_rnn, backward_rnn in zip(self.forward_rnns,
                                             self.backward_rnns):
            # Each output shape is
            # (batch size, no. of time steps, num_hiddens), in input order.
            forward_out = forward_rnn(layer_input)
            backward_out = backward_rnn(layer_input)
            # The next layer sees both directions at every time step:
            # (batch size, no. of time steps, 2 * num_hiddens).
            layer_input = jnp.concatenate(
                [forward_out, backward_out], axis=-1)
        # Final state of each direction of the top layer: the forward LSTM
        # finishes at the last time step, the backward LSTM at the first.
        # Shape: (batch size, 2 * num_hiddens).
        encoding = jnp.concatenate(
            [forward_out[:, -1, :], backward_out[:, 0, :]], axis=1)
        outs = self.decoder(encoding)
        return outs

# Instantiate a 2-layer BiLSTM with 100-dimensional embeddings and 100 hidden units. Frameworks initialize recurrent weights differently, but the model contract is the same:
Use 100-dim GloVe vectors trained on Wikipedia + Gigaword. Initialize the embedding layer from them; freeze or fine-tune (we fine-tune):
Standard cross-entropy + Adam. Watch validation accuracy, not just training loss; sentiment models overfit quickly on IMDb if the embedding and classifier are too large:
# Hyperparameters: Adam learning rate and number of passes over the data.
lr = 0.01
num_epochs = 5
# Per-example cross-entropy that takes integer class labels.
loss_fn = optax.softmax_cross_entropy_with_integer_labels
# Train on the inner 'params' collection only; the step functions re-wrap
# it as {'params': ...} when they call `net.apply`.
params_p = params['params']
optimizer = optax.adam(lr)
opt_state = optimizer.init(params_p)
@jax.jit
def train_step(params_p, opt_state, X, y):
    """Run one optimizer step on a minibatch.

    Args:
        params_p: Inner parameter dict (the value under 'params').
        opt_state: Current optimizer state.
        X: Minibatch of model inputs.
        y: Integer class labels for `X`.

    Returns:
        Tuple of (updated parameters, updated optimizer state, mean loss
        over the minibatch, logits computed with the pre-update parameters).
    """
    def objective(weights):
        # Logits are returned as auxiliary output so that the caller can
        # compute accuracy without a second forward pass.
        scores = net.apply({'params': weights}, X)
        per_example = loss_fn(scores, y)
        return per_example.mean(), scores

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss, logits), grads = grad_fn(params_p)
    updates, new_opt_state = optimizer.update(grads, opt_state, params_p)
    new_params = optax.apply_updates(params_p, updates)
    return new_params, new_opt_state, loss, logits
@jax.jit
def eval_step(params_p, X):
    """Return the model's logits for minibatch `X`.

    Args:
        params_p: Inner parameter dict (the value under 'params').
        X: Minibatch of model inputs.
    """
    variables = {'params': params_p}
    logits = net.apply(variables, X)
    return logits
for epoch in range(num_epochs):
    # Accumulator slots: summed training loss, correct training
    # predictions, and the number of training examples (stored twice).
    metric = d2l.Accumulator(4)
    for X, y in train_iter:
        params_p, opt_state, l, logits = train_step(
            params_p, opt_state, X, y)
        batch_len = len(y)
        batch_hits = (logits.argmax(axis=-1) == y).sum()
        # `l` is the batch mean, so weight it by the batch length.
        metric.add(float(l) * batch_len, float(batch_hits),
                   batch_len, batch_len)
    # Evaluate on the test iterator with the parameters from this epoch
    correct = 0
    total = 0
    for X, y in test_iter:
        logits = eval_step(params_p, X)
        predicted = logits.argmax(axis=-1)
        correct += int((predicted == y).sum())
        total += len(y)
    train_loss = metric[0] / metric[2]
    train_acc = metric[1] / metric[3]
    test_acc = correct / total
    print(f'epoch {epoch + 1}, loss {train_loss:.3f}, '
          f'train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
# Re-wrap params for downstream use (e.g. predict_sentiment)
params = {'params': params_p}

epoch 1, loss 0.694, train acc 0.505, test acc 0.502
epoch 2, loss 0.679, train acc 0.526, test acc 0.628
epoch 3, loss 0.481, train acc 0.776, test acc 0.789
epoch 4, loss 0.283, train acc 0.889, test acc 0.806
epoch 5, loss 0.190, train acc 0.930, test acc 0.805
The final check should classify clearly positive and clearly negative synthetic reviews differently. This is not a full evaluation, but it catches label/order mistakes in the pipeline.
'negative'