Long Short-Term Memory (LSTM)

Vanilla RNNs hit a ceiling: gradients vanish across long sequences. LSTMs (Hochreiter & Schmidhuber, 1997) fix this by giving each unit a memory cell with a self-loop of weight 1 and three learned gates.

  • Forget gate \mathbf{F}_t — keep or wipe memory.
  • Input gate \mathbf{I}_t — let new content in.
  • Output gate \mathbf{O}_t — expose or hide memory.

For much of the two decades after their introduction, LSTMs were the go-to sequence model for speech, translation, and language modeling, until Transformers took over (2017).

Gates at a glance

The input \mathbf{X}_t and previous hidden state \mathbf{H}_{t-1} feed three sigmoid gates.

Computing the input, forget, and output gates.

The three gates

The gates are learned, elementwise switches:

\mathbf{I}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),
\mathbf{F}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),
\mathbf{O}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o).

Each value lies in (0, 1):

  • \mathbf{F}_t decides what to retain.
  • \mathbf{I}_t decides what new content to write.
  • \mathbf{O}_t decides what memory to expose.
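
The gate equations above can be sketched in plain JAX. This is a minimal illustration, not the implementation used later: the helper name gate, the toy sizes, and the random weights are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def gate(X, H_prev, W_x, W_h, b):
    """One sigmoid gate: sigma(X W_x + H_prev W_h + b)."""
    return jax.nn.sigmoid(X @ W_x + H_prev @ W_h + b)

batch, d, h = 2, 5, 4  # toy sizes, assumed for illustration
keys = jax.random.split(jax.random.PRNGKey(0), 8)
X = jax.random.normal(keys[0], (batch, d))        # X_t
H_prev = jax.random.normal(keys[1], (batch, h))   # H_{t-1}

gates = {}
for n, name in enumerate(("i", "f", "o")):
    # Each gate has its own input-to-hidden and hidden-to-hidden weights.
    W_x = 0.1 * jax.random.normal(keys[2 + 2 * n], (d, h))
    W_h = 0.1 * jax.random.normal(keys[3 + 2 * n], (h, h))
    b = jnp.zeros(h)
    gates[name] = gate(X, H_prev, W_x, W_h, b)

I, F, O = gates["i"], gates["f"], gates["o"]
print(F.shape)  # (2, 4): one switch per example and hidden unit
```

Every gate has the same shape as the hidden state, which is what allows it to act as an elementwise switch.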

Plus an input node

A fourth head — the input node — uses \tanh and proposes content to write into memory:

Computing the input node \tilde{\mathbf{C}}_t.

Input-node equation

\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c).

Same algebra four times — only the activation and what each output controls differ.
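
Because the four heads share the same affine form, one common trick is to stack their weights and compute all four pre-activations with a single pair of matrix products. The sketch below assumes a column layout [i | f | o | c] and toy sizes; neither comes from the text, and the from-scratch code in this section keeps the weights separate.

```python
import jax
import jax.numpy as jnp

batch, d, h = 2, 5, 4  # toy sizes, assumed for illustration
k = jax.random.split(jax.random.PRNGKey(1), 4)
X = jax.random.normal(k[0], (batch, d))
H_prev = jax.random.normal(k[1], (batch, h))

# Stacked weights; columns ordered [i | f | o | c] (an assumed layout).
W_x = 0.1 * jax.random.normal(k[2], (d, 4 * h))
W_h = 0.1 * jax.random.normal(k[3], (h, 4 * h))
b = jnp.zeros(4 * h)

Z = X @ W_x + H_prev @ W_h + b                 # shape (batch, 4h)
z_i, z_f, z_o, z_c = jnp.split(Z, 4, axis=1)   # four (batch, h) slices
I, F, O = jax.nn.sigmoid(z_i), jax.nn.sigmoid(z_f), jax.nn.sigmoid(z_o)
C_tilde = jnp.tanh(z_c)                        # input node, in (-1, 1)

# The same input node computed on its own from the matching weight slice.
C_tilde_sep = jnp.tanh(X @ W_x[:, 3 * h:] + H_prev @ W_h[:, 3 * h:] + b[3 * h:])
```

The stacked and separate computations agree up to floating-point error; only the activation applied to each slice differs.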

Memory cell update

The cell state is the LSTM’s long-lived memory.

Computing the cell internal state \mathbf{C}_t.

Cell-update equation

Mix the previous cell with the new proposal, gated elementwise:

\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.

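If \mathbf{F}_t \approx 1 and \mathbf{I}_t \approx 0, the cell holds its value unchanged across arbitrary horizons. Along this direct path the elementwise derivative is \partial \mathbf{C}_t / \partial \mathbf{C}_{t-1} = \mathbf{F}_t \approx 1, so the error signal neither shrinks nor grows. That is the constant error carousel that mitigates vanishing gradients.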
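
A small numerical check of this behavior, under a simplifying assumption: the gates are held constant across steps, whereas in a real LSTM they depend on the inputs and the hidden state. The helper names cell_update and run are made up for the example.

```python
import jax
import jax.numpy as jnp

def cell_update(C_prev, F, I, C_tilde):
    """C_t = F_t * C_{t-1} + I_t * C~_t, all elementwise."""
    return F * C_prev + I * C_tilde

def run(C0, F, I, C_tilde, steps):
    """Apply the same update repeatedly with fixed gates (an assumption)."""
    C = C0
    for _ in range(steps):
        C = cell_update(C, F, I, C_tilde)
    return C

C0 = jnp.array([1.0, -2.0, 0.5])
C_tilde = jnp.ones(3)
zeros, ones = jnp.zeros(3), jnp.ones(3)

# F = 1, I = 0: the cell value survives 100 steps untouched.
held = run(C0, ones, zeros, C_tilde, steps=100)

# Gradient of the final cell w.r.t. the initial cell is F**steps elementwise.
grad_keep = jax.grad(lambda c: run(c, ones, zeros, C_tilde, 100).sum())(C0)
grad_decay = jax.grad(lambda c: run(c, 0.5 * ones, zeros, C_tilde, 100).sum())(C0)
```

With F = 1 the gradient stays at 1 after 100 steps; with F = 0.5 it shrinks to 0.5^100, which illustrates why a forget gate near 1 preserves the error signal.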

Hidden state

Gated, squashed cell:

\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).

Computing \mathbf{H}_t from \mathbf{C}_t and the output gate.
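
The hidden-state equation above can be sketched as follows, with hand-picked gate values (an assumption for illustration) to show that closing the output gate hides the memory without erasing it.

```python
import jax.numpy as jnp

def hidden_state(O, C):
    """H_t = O_t * tanh(C_t), elementwise."""
    return O * jnp.tanh(C)

C = jnp.array([[5.0, -3.0, 0.2]])             # cell values may exceed 1 in magnitude
H_closed = hidden_state(jnp.zeros_like(C), C)  # O = 0: nothing is exposed
H_open = hidden_state(jnp.ones_like(C), C)     # O = 1: the squashed cell is exposed
```

In both cases C itself is unchanged, so a later step can still read it once the output gate opens.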

From scratch: parameters

Eight weight matrices and four biases: the same triple() factory called four times, one per gate/node, each call returning an input-to-hidden matrix, a hidden-to-hidden matrix, and a bias:

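import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

class LSTMScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        # Weights are drawn from a normal with std sigma; biases start at zero.
        init_weight = lambda name, shape: self.param(name,
                                                     nn.initializers.normal(self.sigma),
                                                     shape)
        # One (W_x, W_h, b) triple per gate/node; the bias shape is a 1-tuple.
        triple = lambda name : (
            init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
            init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
            self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens,)))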

        self.W_xi, self.W_hi, self.b_i = triple('i')  # Input gate
        self.W_xf, self.W_hf, self.b_f = triple('f')  # Forget gate
        self.W_xo, self.W_ho, self.b_o = triple('o')  # Output gate
        self.W_xc, self.W_hc, self.b_c = triple('c')  # Input node
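
To make the parameter inventory concrete, here is a sketch that builds the same set of arrays in a plain dictionary and counts them. The function init_lstm_params is hypothetical, and num_inputs=28 is only an assumed vocabulary size for illustration.

```python
import jax
import jax.numpy as jnp

def init_lstm_params(key, num_inputs, num_hiddens, sigma=0.01):
    """Hypothetical helper: one (W_x, W_h, b) triple per gate/node."""
    params = {}
    for name, k in zip("ifoc", jax.random.split(key, 4)):
        kx, kh = jax.random.split(k)
        params[f"W_x{name}"] = sigma * jax.random.normal(kx, (num_inputs, num_hiddens))
        params[f"W_h{name}"] = sigma * jax.random.normal(kh, (num_hiddens, num_hiddens))
        params[f"b_{name}"] = jnp.zeros((num_hiddens,))
    return params

# num_inputs=28 is an assumed vocabulary size, not taken from the text.
params = init_lstm_params(jax.random.PRNGKey(0), num_inputs=28, num_hiddens=32)
num_matrices = sum(p.ndim == 2 for p in params.values())
num_biases = sum(p.ndim == 1 for p in params.values())
total = sum(p.size for p in params.values())
print(num_matrices, num_biases, total)  # 8 4 7808
```

The total follows 4 * (d*h + h*h + h) with d = 28 and h = 32, roughly four times the parameters of a vanilla RNN layer of the same width.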

Walk the sequence; at each step compute the four gate/node heads, update \mathbf{C}, then \mathbf{H}. Carry both states forward.

@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation.
    def scan_fn(carry, X):
        H, C = carry
        I = d2l.sigmoid(d2l.matmul(X, self.W_xi) + (
            d2l.matmul(H, self.W_hi)) + self.b_i)
        F = d2l.sigmoid(d2l.matmul(X, self.W_xf) +
                        d2l.matmul(H, self.W_hf) + self.b_f)
        O = d2l.sigmoid(d2l.matmul(X, self.W_xo) +
                        d2l.matmul(H, self.W_ho) + self.b_o)
        C_tilde = d2l.tanh(d2l.matmul(X, self.W_xc) +
                           d2l.matmul(H, self.W_hc) + self.b_c)
        C = F * C + I * C_tilde
        H = O * d2l.tanh(C)
        return (H, C), H  # return carry, y

    if H_C is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens)), \
                jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H_C

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
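
The scan pattern can be checked in isolation. The sketch below is a standalone version in plain JAX, not the class method above: it stacks the four heads into one weight matrix with an assumed column layout [i | f | o | c], and the names lstm_step and params are made up for the example. It compares jax.lax.scan against an explicit Python loop over time steps.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, X):
    """One LSTM step; returns (new carry, per-step output)."""
    H, C = carry
    W_x, W_h, b = params
    Z = X @ W_x + H @ W_h + b
    z_i, z_f, z_o, z_c = jnp.split(Z, 4, axis=-1)  # assumed layout [i|f|o|c]
    I, F, O = jax.nn.sigmoid(z_i), jax.nn.sigmoid(z_f), jax.nn.sigmoid(z_o)
    C = F * C + I * jnp.tanh(z_c)
    H = O * jnp.tanh(C)
    return (H, C), H

T, batch, d, h = 6, 3, 5, 4  # toy sizes, assumed for illustration
k = jax.random.split(jax.random.PRNGKey(0), 3)
params = (0.1 * jax.random.normal(k[0], (d, 4 * h)),
          0.1 * jax.random.normal(k[1], (h, 4 * h)),
          jnp.zeros(4 * h))
inputs = jax.random.normal(k[2], (T, batch, d))       # time-major
init = (jnp.zeros((batch, h)), jnp.zeros((batch, h)))

# scan: iterates over the leading (time) axis and stacks the per-step outputs.
(H_T, C_T), outputs = jax.lax.scan(
    lambda carry, X: lstm_step(params, carry, X), init, inputs)

# Reference: the same recurrence as a Python loop.
carry, ys = init, []
for X in inputs:
    carry, y = lstm_step(params, carry, X)
    ys.append(y)
outputs_loop = jnp.stack(ys)
```

Both routes give outputs of shape (T, batch, h), and the last output equals the final hidden state in the carry.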

Training the from-scratch LSTM

Same RNNLMScratch head, same Trainer, same gradient clipping; only the cell changed. A high learning rate (lr=4) works here, helped by gradient clipping and by the bounded gate and tanh activations.

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
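
For intuition about gradient_clip_val=1, here is a sketch of clipping by global norm. It assumes the trainer rescales the whole gradient tree whenever its overall L2 norm exceeds the threshold; the function clip_by_global_norm is written for this example and is not the trainer's own code.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm):
    """Rescale a gradient pytree so its overall L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))  # no-op if already small
    return jax.tree_util.tree_map(lambda g: g * scale, grads), norm

# Toy gradients with global norm sqrt(9 + 16 + 144) = 13.
grads = {"W": jnp.array([[3.0, 0.0], [0.0, 4.0]]), "b": jnp.array([12.0])}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
clipped_norm = jnp.sqrt(sum(jnp.sum(g ** 2)
                            for g in jax.tree_util.tree_leaves(clipped)))
```

Clipping changes the length of the update but not its direction, which is one reason a large learning rate can remain stable.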

Concise: nn.LSTM

Library cell, here Flax's nn.LSTMCell lifted over the time axis with nn.scan; usually faster than the from-scratch version:

class LSTM(d2l.RNN):
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H_C=None, training=False):
        # Flax ≥0.8 deprecates OptimizedLSTMCell; use LSTMCell instead.
        if H_C is None:
            batch_size = inputs.shape[1]
            H_C = nn.LSTMCell(features=self.num_hiddens).initialize_carry(
                jax.random.PRNGKey(0), (batch_size, self.num_hiddens))

        LSTM = nn.scan(nn.LSTMCell, variable_broadcast="params",
                       in_axes=0, out_axes=0, split_rngs={"params": False})

        H_C, outputs = LSTM(features=self.num_hiddens)(H_C, inputs)
        return outputs, H_C

Drop into the same LM scaffold and train:

lstm = LSTM(num_hiddens=32)
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)

Decoding

Predict from a prefix:

model.predict('it has', 20, data.vocab, trainer.state.params)
'it has the ped in a mather'
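
Prefix-based prediction typically has two phases: warm up the recurrent state on the prefix, then feed each prediction back in as the next input. The sketch below assumes greedy (argmax) decoding and uses a toy stand-in for the model; greedy_decode and toy_step are hypothetical names and do not mirror model.predict.

```python
import jax
import jax.numpy as jnp

def greedy_decode(step_fn, state, prefix_ids, num_preds):
    """Warm up on the prefix, then append the argmax token num_preds times."""
    outputs = [prefix_ids[0]]
    for i in range(len(prefix_ids) + num_preds - 1):
        logits, state = step_fn(outputs[-1], state)
        if i < len(prefix_ids) - 1:
            outputs.append(prefix_ids[i + 1])        # warm-up: ignore the prediction
        else:
            outputs.append(int(jnp.argmax(logits)))  # generation: feed it back
    return outputs

VOCAB_SIZE = 5  # toy vocabulary, assumed for illustration

def toy_step(token, state):
    """Stand-in model: always predicts the next token id, modulo VOCAB_SIZE."""
    return jax.nn.one_hot((token + 1) % VOCAB_SIZE, VOCAB_SIZE), state

decoded = greedy_decode(toy_step, None, prefix_ids=[0, 1, 2], num_preds=4)
print(decoded)  # [0, 1, 2, 3, 4, 0, 1]
```

With a trained LSTM, step_fn would run one cell step and the output layer, and state would be the (H, C) pair carried between steps.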

Recap

  • LSTM = vanilla RNN cell replaced by a memory cell with three multiplicative gates.
  • The cell update \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t is what mitigates vanishing gradients.
  • Only \mathbf{H}_t leaves the cell; \mathbf{C}_t is internal.
  • Reuse the from-scratch LM scaffold; the library cell (nn.LSTMCell under nn.scan in this Flax version) is the drop-in for production.
  • Dominant sequence model 2011–2017; many ideas (gating, residual paths) carried into Transformers.