Forward and backward states

Bidirectional Recurrent Neural Networks

Bidirectional RNNs

A language model conditions only on the past, so a left-to-right RNN is all it needs. Many other tasks, such as filling in a blank, need context from both sides:

  • I am ___. → “happy”
  • I am ___ hungry. → “very” / “not”
  • I am ___ hungry, and I can eat half a pig. → “very” only

The words to the right of the blank change the answer. Bidirectional RNNs (Schuster & Paliwal, 1997) run two RNNs, one reading the sequence forward and one reading it backward, and concatenate their hidden states at each time step.

The architecture

Forward + backward, hidden states concatenated.

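Use case: encoding tasks where the whole sequence is available up front (POS tagging, NER, BERT-style masked-LM pretraining). Not for next-token language modeling: the backward state at step t has already read token t+1, the very token the model is supposed to predict.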
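To see the "peeking" problem concretely, here is a minimal sketch, assuming a plain tanh RNN for each direction with small random weights; the helper names rnn_states, birnn and init are hypothetical, not d2l functions. It perturbs only the last token and checks whether the output at the step just before it changes.

```python
import jax
import jax.numpy as jnp

def rnn_states(X, W_xh, W_hh, b_h):
    """Tanh RNN over X of shape (T, n, d); returns all hidden states, (T, n, h)."""
    H = jnp.zeros((X.shape[1], W_hh.shape[0]))
    states = []
    for X_t in X:  # iterate over the time axis
        H = jnp.tanh(X_t @ W_xh + H @ W_hh + b_h)
        states.append(H)
    return jnp.stack(states)

def birnn(X, fwd, bwd):
    """fwd and bwd are separate (W_xh, W_hh, b_h) parameter tuples."""
    H_f = rnn_states(X, *fwd)               # reads left to right
    H_b = rnn_states(X[::-1], *bwd)[::-1]   # reads right to left, then re-aligned to t
    return jnp.concatenate([H_f, H_b], axis=-1)  # (T, n, 2h)

T, n, d, h = 4, 2, 3, 5
k = jax.random.split(jax.random.PRNGKey(0), 5)

def init(k_x, k_h):
    # Hypothetical small-scale initialization for illustration only.
    return (0.3 * jax.random.normal(k_x, (d, h)),
            0.3 * jax.random.normal(k_h, (h, h)),
            jnp.zeros(h))

fwd, bwd = init(k[0], k[1]), init(k[2], k[3])
X = jax.random.normal(k[4], (T, n, d))
X_changed = X.at[T - 1].add(1.0)  # perturb only the last token

H, H_changed = birnn(X, fwd, bwd), birnn(X_changed, fwd, bwd)
t = T - 2  # in a language model, the output at step t would predict token T - 1
fwd_same = jnp.allclose(H[t, :, :h], H_changed[t, :, :h])  # forward half
bwd_same = jnp.allclose(H[t, :, h:], H_changed[t, :, h:])  # backward half
print(fwd_same, bwd_same)
```

The forward half at step t is unaffected because it has not read token T-1 yet; the backward half changes because it already has. A model trained to predict token T-1 from that output could simply read the answer off the backward state.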

Setup

from d2l import jax as d2l
from jax import numpy as jnp

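Two RNNs with separate parameters. For a minibatch input \mathbf{X}_t \in \mathbb{R}^{n \times d} (n examples, d features) and h hidden units per direction, each time step t has both a forward state \overrightarrow{\mathbf{H}}_t and a backward state \overleftarrow{\mathbf{H}}_t in \mathbb{R}^{n \times h}, computed with activation function \phi: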

\overrightarrow{\mathbf{H}}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(f)} + \overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{hh}^{(f)} + \mathbf{b}_h^{(f)}), \overleftarrow{\mathbf{H}}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(b)} + \overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{hh}^{(b)} + \mathbf{b}_h^{(b)}).
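The two recurrences can be transcribed almost literally. This is a sketch under stated assumptions: \phi = tanh, both directions start from a zero state (the forward state before step 0 and the backward state after step T-1), time is 0-indexed, and the names birnn_states and make_params are hypothetical. Note that the forward loop reads the previous state while the backward loop reads the next one.

```python
import jax
import jax.numpy as jnp

def birnn_states(X, params_f, params_b):
    """X: (T, n, d); params_*: (W_xh, W_hh, b_h). Returns forward and backward states, each (T, n, h)."""
    T, n, _ = X.shape
    h = params_f[1].shape[0]
    W_xh_f, W_hh_f, b_f = params_f
    W_xh_b, W_hh_b, b_b = params_b

    fwd = [None] * T
    H_prev = jnp.zeros((n, h))           # forward state before step 0
    for t in range(T):                   # t = 0, 1, ..., T-1 uses H_{t-1}
        H_prev = jnp.tanh(X[t] @ W_xh_f + H_prev @ W_hh_f + b_f)
        fwd[t] = H_prev

    bwd = [None] * T
    H_next = jnp.zeros((n, h))           # backward state after step T-1
    for t in reversed(range(T)):         # t = T-1, ..., 1, 0 uses H_{t+1}
        H_next = jnp.tanh(X[t] @ W_xh_b + H_next @ W_hh_b + b_b)
        bwd[t] = H_next
    return jnp.stack(fwd), jnp.stack(bwd)

T, n, d, h = 6, 2, 3, 4
k = jax.random.split(jax.random.PRNGKey(1), 7)

def make_params(k_x, k_h, k_b):
    # Hypothetical random initialization for illustration only.
    return (0.3 * jax.random.normal(k_x, (d, h)),
            0.3 * jax.random.normal(k_h, (h, h)),
            0.1 * jax.random.normal(k_b, (h,)))

params_f = make_params(k[0], k[1], k[2])
params_b = make_params(k[3], k[4], k[5])
X = jax.random.normal(k[6], (T, n, d))
H_f, H_b = birnn_states(X, params_f, params_b)
print(H_f.shape, H_b.shape)  # (6, 2, 4) (6, 2, 4)
```

Because both initial states are zero, the forward state at step 0 and the backward state at step T-1 each depend on a single input only; every other step mixes in its neighbor from the appropriate direction.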

Concatenate to form the layer output: \mathbf{H}_t = [\overrightarrow{\mathbf{H}}_t, \overleftarrow{\mathbf{H}}_t] \in \mathbb{R}^{n \times 2h}. Any output layer on top therefore reads 2h features per example rather than h.
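Concatenation is along the feature axis, so two n × h states become one n × 2h state. Assuming a linear output layer O_t = H_t W_hq + b_q with q outputs (W_hq, b_q and q are illustrative names here, and the states are random stand-ins), W_hq needs 2h rows. Splitting it in half shows what concatenation means for a linear layer: the top h rows read the forward state and the bottom h rows read the backward state.

```python
import jax
import jax.numpy as jnp

n, h, q = 3, 4, 5  # batch size, hidden units per direction, outputs (illustrative)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(2), 3)
H_fwd = jax.random.normal(k1, (n, h))  # stand-in for the forward state at step t
H_bwd = jax.random.normal(k2, (n, h))  # stand-in for the backward state at step t

H_t = jnp.concatenate([H_fwd, H_bwd], axis=-1)  # [forward, backward]: (n, 2h)
W_hq = 0.1 * jax.random.normal(k3, (2 * h, q))  # needs 2h rows, not h
b_q = jnp.zeros(q)
O_t = H_t @ W_hq + b_q                          # (n, q)
print(H_t.shape, O_t.shape)                     # (3, 8) (3, 5)

# Same result, split view: top h rows of W_hq read the forward state,
# bottom h rows read the backward state.
O_split = H_fwd @ W_hq[:h] + H_bwd @ W_hq[h:] + b_q
print(jnp.allclose(O_t, O_split, atol=1e-4))
```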

From scratch

Two RNNScratch cells, output dim doubled:

class BiRNNScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        self.f_rnn = d2l.RNNScratch(self.num_inputs, self.num_hiddens,
                                    self.sigma)
        self.b_rnn = d2l.RNNScratch(self.num_inputs, self.num_hiddens,
                                    self.sigma)
        # The output dimension will be doubled (forward and backward
        # outputs are concatenated along the last axis).
        self.output_dim = self.num_hiddens * 2

Run the backward cell on the time-reversed input, reverse its outputs back into forward time order, then concatenate the forward and backward outputs at each time step:

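@d2l.add_to_class(BiRNNScratch)
def __call__(self, inputs, Hs=None):
    # Hs packs the (forward, backward) states; None lets each cell use its default start.
    f_H, b_H = Hs if Hs is not None else (None, None)
    f_outputs, f_H = self.f_rnn(inputs, f_H)
    # Feed the time-reversed sequence (flip the time axis) to the backward cell.
    b_outputs, b_H = self.b_rnn(inputs[::-1], b_H)
    # Re-reverse the backward outputs so both lists are indexed by the same t,
    # then concatenate along the feature axis: (batch_size, 2 * num_hiddens).
    outputs = [jnp.concatenate((f, b), -1) for f, b in zip(
        f_outputs, b_outputs[::-1])]
    return outputs, (f_H, b_H)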
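The class above follows d2l's list-of-steps interface, so it flips the input and then flips the backward outputs back. A functional alternative, sketched here under the same tanh-RNN assumption, uses jax.lax.scan: its reverse=True option runs the loop from the last step to the first while still returning the stacked outputs in original time order, so no explicit re-alignment is needed. The helper names cell, bi_scan, bi_flip and init are hypothetical and not part of d2l; bi_flip mirrors the flip-based approach so the two can be compared.

```python
import jax
import jax.numpy as jnp

def cell(params, H, X_t):
    """One tanh-RNN step; returns (new carry, per-step output)."""
    W_xh, W_hh, b_h = params
    H = jnp.tanh(X_t @ W_xh + H @ W_hh + b_h)
    return H, H

def bi_scan(X, p_f, p_b):
    """Bidirectional states via lax.scan; reverse=True handles the backward pass."""
    H0 = jnp.zeros((X.shape[1], p_f[1].shape[0]))
    _, H_f = jax.lax.scan(lambda H, x: cell(p_f, H, x), H0, X)
    # Iterates t = T-1, ..., 0 but stores output t at index t.
    _, H_b = jax.lax.scan(lambda H, x: cell(p_b, H, x), H0, X, reverse=True)
    return jnp.concatenate([H_f, H_b], axis=-1)

def bi_flip(X, p_f, p_b):
    """Same computation in the flip / run / flip-back style."""
    H0 = jnp.zeros((X.shape[1], p_f[1].shape[0]))
    _, H_f = jax.lax.scan(lambda H, x: cell(p_f, H, x), H0, X)
    _, H_b = jax.lax.scan(lambda H, x: cell(p_b, H, x), H0, X[::-1])
    return jnp.concatenate([H_f, H_b[::-1]], axis=-1)

T, n, d, h = 7, 2, 3, 4
k = jax.random.split(jax.random.PRNGKey(3), 7)

def init(k_x, k_h, k_b):
    # Hypothetical random initialization for illustration only.
    return (0.3 * jax.random.normal(k_x, (d, h)),
            0.3 * jax.random.normal(k_h, (h, h)),
            0.1 * jax.random.normal(k_b, (h,)))

p_f, p_b = init(k[0], k[1], k[2]), init(k[3], k[4], k[5])
X = jax.random.normal(k[6], (T, n, d))
out_scan, out_flip = bi_scan(X, p_f, p_b), bi_flip(X, p_f, p_b)
print(out_scan.shape)  # (7, 2, 8)
```

Stacking into a single (T, n, 2h) array rather than a Python list of steps is also friendlier to jax.jit, since the loop becomes one scan instead of T unrolled steps.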

Concise: bidirectional=True

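PyTorch (nn.RNN, nn.GRU, nn.LSTM) and MXNet Gluon expose this as a single flag, bidirectional=True, on their built-in recurrent layers; in TensorFlow/Keras the equivalent is wrapping a recurrent layer in tf.keras.layers.Bidirectional. Either way, each time step's output then has 2h features.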

Recap

  • Bidirectional RNN = forward + backward RNN, hidden states concatenated → output dim is 2h.
  • Use for encoding tasks where both sides are available (POS tagging, NER, masked-LM pretraining); never for next-token language modeling.
  • Roughly 2× compute and 2× parameters vs. a unidirectional RNN.
  • Concise form: bidirectional=True in PyTorch and MXNet; a Bidirectional wrapper layer in Keras.
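To make the 2× parameter claim concrete, here is a quick count for a vanilla tanh RNN layer with d inputs and h hidden units per direction (W_xh, W_hh and b_h per direction), plus an assumed linear output layer with q outputs whose weight grows from h × q to 2h × q. The function name rnn_param_count and the sizes are hypothetical. Frameworks that keep two bias vectors per direction, such as PyTorch's nn.RNN (bias_ih and bias_hh), will report slightly higher totals.

```python
def rnn_param_count(d, h, q, bidirectional):
    """Parameters of one vanilla RNN layer plus a linear output layer (illustrative)."""
    directions = 2 if bidirectional else 1
    recurrent = directions * (d * h + h * h + h)  # W_xh, W_hh, b_h per direction
    output = directions * h * q + q               # W_hq reads directions * h features; b_q
    return recurrent, output

print(rnn_param_count(d=28, h=32, q=28, bidirectional=False))  # (1952, 924)
print(rnn_param_count(d=28, h=32, q=28, bidirectional=True))   # (3904, 1820)
```

The recurrent part doubles exactly; the output layer roughly doubles, since only its weight matrix (not its bias) depends on the number of directions.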