from d2l import mxnet as d2l
from mxnet import npx, np
from mxnet.gluon import rnn
npx.set_np()

An LM conditions only on the past, so a left-to-right RNN is fine. But many tasks need both sides:
___. → “happy”
___ hungry. → “very” / “not”
___ hungry, and I can eat half a pig. → “very” only
The right context flips the answer. Bidirectional RNNs (Schuster & Paliwal, 1997) run two RNNs, one forward and one backward, and concatenate their hidden states at each step.
Forward + backward, hidden states concatenated.
Use case: encoding tasks where the whole sequence is available (POS tagging, NER, BERT-style pretraining). Not an LM: when predicting the next token, the backward state would be peeking at the target.
Two RNNs with separate parameters; each step has both a forward and a backward hidden state:
$$\begin{aligned}
\overrightarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(f)} + \overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{hh}^{(f)} + \mathbf{b}_h^{(f)}),\\
\overleftarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(b)} + \overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{hh}^{(b)} + \mathbf{b}_h^{(b)}).
\end{aligned}$$
Concatenate to form the layer output: $\mathbf{H}_t = [\overrightarrow{\mathbf{H}}_t, \overleftarrow{\mathbf{H}}_t] \in \mathbb{R}^{n \times 2h}$, where $n$ is the batch size and $h$ is the number of hidden units per direction. The output layer reads from $2h$ features.
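The two recurrences and the concatenation can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the d2l implementation: the names `init_params` and `birnn_forward`, the choice of tanh for \phi, and the zero initial states are all hypothetical.

```python
import numpy as np

def init_params(d, h, rng, sigma=0.01):
    # Hypothetical helper: parameters (W_xh, W_hh, b_h) for one direction.
    return {"W_xh": rng.normal(0, sigma, (d, h)),
            "W_hh": rng.normal(0, sigma, (h, h)),
            "b_h": np.zeros(h)}

def birnn_forward(X, fwd, bwd):
    """X has shape (T, n, d); the result has shape (T, n, 2h)."""
    T, n, _ = X.shape
    h = fwd["b_h"].shape[0]
    H_f = np.zeros((T, n, h))
    H_b = np.zeros((T, n, h))
    # Forward direction: H_t depends on H_{t-1}.
    prev = np.zeros((n, h))  # assumed zero initial state
    for t in range(T):
        prev = np.tanh(X[t] @ fwd["W_xh"] + prev @ fwd["W_hh"] + fwd["b_h"])
        H_f[t] = prev
    # Backward direction: H_t depends on H_{t+1}.
    nxt = np.zeros((n, h))  # assumed zero initial state
    for t in reversed(range(T)):
        nxt = np.tanh(X[t] @ bwd["W_xh"] + nxt @ bwd["W_hh"] + bwd["b_h"])
        H_b[t] = nxt
    # Concatenate along the feature axis: h + h = 2h features per step.
    return np.concatenate([H_f, H_b], axis=-1)

rng = np.random.default_rng(0)
T, n, d, h = 5, 2, 4, 3
fwd, bwd = init_params(d, h, rng), init_params(d, h, rng)
X = rng.normal(size=(T, n, d))
H = birnn_forward(X, fwd, bwd)
print(H.shape)  # (5, 2, 6)
```

Changing a later input such as `X[1]` leaves the forward half of `H[0]` untouched but alters its backward half, which is how the right context enters the representation.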
Two RNNScratch cells, output dim doubled:
class BiRNNScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.f_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
        self.b_rnn = d2l.RNNScratch(num_inputs, num_hiddens, sigma)
        self.num_hiddens *= 2  # The output dimension will be doubled

Run the backward cell on the reversed input, reverse its outputs again so they line up in time, then concatenate forward and backward outputs at each time step.
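The reversal trick can be sketched as follows, in plain NumPy rather than MXNet so that it runs standalone. Here `rnn_scan` is a hypothetical stand-in for a left-to-right cell such as RNNScratch (tanh activation and a zero initial state are assumed), and `birnn` and `make_params` are likewise illustrative names. The backward outputs are flipped before concatenation so that index t in both lists refers to the same time step.

```python
import numpy as np

def rnn_scan(inputs, W_xh, W_hh, b_h):
    # Hypothetical left-to-right RNN: takes a list of (n, d) arrays and
    # returns the list of hidden states plus the final state.
    H = np.zeros((inputs[0].shape[0], W_hh.shape[0]))
    outputs = []
    for X in inputs:
        H = np.tanh(X @ W_xh + H @ W_hh + b_h)
        outputs.append(H)
    return outputs, H

def birnn(inputs, f_params, b_params):
    f_outputs, f_H = rnn_scan(inputs, *f_params)
    # Feed the sequence back to front to the second RNN...
    b_outputs, b_H = rnn_scan(inputs[::-1], *b_params)
    # ...then flip its outputs so position t matches f_outputs[t].
    outputs = [np.concatenate((f, b), axis=-1)
               for f, b in zip(f_outputs, b_outputs[::-1])]
    return outputs, (f_H, b_H)

rng = np.random.default_rng(0)
num_steps, batch_size, num_inputs, num_hiddens = 4, 2, 3, 5

def make_params():
    # Separate parameters for each direction.
    return (rng.normal(0, 0.01, (num_inputs, num_hiddens)),
            rng.normal(0, 0.01, (num_hiddens, num_hiddens)),
            np.zeros(num_hiddens))

f_params, b_params = make_params(), make_params()
inputs = [rng.normal(size=(batch_size, num_inputs))
          for _ in range(num_steps)]
outputs, (f_H, b_H) = birnn(inputs, f_params, b_params)
print(len(outputs), outputs[0].shape)  # 4 (2, 10)
```

In this sketch the final backward state `b_H` ends up as the backward half of `outputs[0]`, since the backward RNN reads the first time step last.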
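PyTorch and MXNet expose this as a one-flag toggle on the built-in recurrent layers (e.g. GRU): bidirectional=True. TensorFlow (Keras) instead wraps a recurrent layer in tf.keras.layers.Bidirectional.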