Long Short-Term Memory (LSTM)

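Vanilla RNNs hit a ceiling: gradients vanish across long sequences. LSTMs (Hochreiter & Schmidhuber, 1997) mitigate this by giving each unit a memory cell. The cell's state is carried forward additively (in the original design, along a self-loop of fixed weight 1), and learned multiplicative gates control it. The standard modern LSTM uses three gates. The forget gate was added a few years after the original paper (Gers et al., 2000).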

  • Forget gate \mathbf{F}_t — keep or wipe memory.
  • Input gate \mathbf{I}_t — let new content in.
  • Output gate \mathbf{O}_t — expose or hide memory.

For years LSTMs were the default sequence model for speech recognition, machine translation, and language modeling, until Transformers (2017) took over.

Gates at a glance

The input \mathbf{X}_t and previous hidden state \mathbf{H}_{t-1} feed three sigmoid gates.

Computing the input, forget, and output gates.

The three gates

The gates are learned, elementwise switches:

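\mathbf{I}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),
\mathbf{F}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),
\mathbf{O}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o).

Here \sigma is the logistic sigmoid. Take a batch size n, d inputs, and h hidden units. Then \mathbf{X}_t \in \mathbb{R}^{n \times d}, \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}, \mathbf{W}_{x\cdot} \in \mathbb{R}^{d \times h}, \mathbf{W}_{h\cdot} \in \mathbb{R}^{h \times h}, and \mathbf{b}_\cdot \in \mathbb{R}^{1 \times h}, which is broadcast over the batch.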

Each value lies in (0, 1):

  • \mathbf{F}_t decides what to retain.
  • \mathbf{I}_t decides what new content to write.
  • \mathbf{O}_t decides what memory to expose.
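
The three gate equations can be sketched in NumPy as follows. This is an illustration only. The sizes, the random initialization, and the helper `init_gate` are our own choices, not taken from the notes. It checks that every gate value has shape (n, h) and lies strictly between 0 and 1.

```python
import numpy as np

def sigmoid(z):
    # Logistic function: maps any real number into (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n, d, h = 2, 5, 3                    # batch size, input size, hidden size (arbitrary)
X = rng.normal(size=(n, d))          # X_t: one row per example
H_prev = rng.normal(size=(n, h))     # H_{t-1}

def init_gate():
    # Hypothetical helper: (W_x, W_h, b) for one gate, small weights, zero bias
    return (0.1 * rng.normal(size=(d, h)),
            0.1 * rng.normal(size=(h, h)),
            np.zeros(h))

params = {name: init_gate() for name in ("i", "f", "o")}

gates = {}
for name, (W_x, W_h, b) in params.items():
    # Same affine map for every gate, followed by a sigmoid
    gates[name] = sigmoid(X @ W_x + H_prev @ W_h + b)

I_t, F_t, O_t = gates["i"], gates["f"], gates["o"]
print(I_t.shape)   # (2, 3): one gate value per example and hidden unit
```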

Plus an input node

A fourth head, the input node, uses \tanh (range (-1, 1)) instead of a sigmoid and proposes candidate content to write into memory:

Computing the input node \tilde{\mathbf{C}}_t.

Input-node equation

\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c).

The same affine map appears four times. Only the activation and what each output controls differ.
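
The four heads share one form, so an implementation can stack their weights column-wise. One pair of matrix products then gives all four pre-activations. This is the kind of stacked layout library cells such as nn.LSTM use. The NumPy sketch below checks that fused and separate computations agree. The [i | f | o | c] block order here is an arbitrary choice for the example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
n, d, h = 4, 6, 3
X = rng.normal(size=(n, d))
H_prev = rng.normal(size=(n, h))

# Four separate (W_x, W_h, b) triples: input gate, forget gate, output gate, input node
W_x = [rng.normal(size=(d, h)) for _ in range(4)]
W_h = [rng.normal(size=(h, h)) for _ in range(4)]
b = [rng.normal(size=h) for _ in range(4)]

# Separate computation: four affine maps
Z_sep = [X @ W_x[k] + H_prev @ W_h[k] + b[k] for k in range(4)]

# Fused computation: stack the weights column-wise, then one pair of matmuls
W_x_all = np.concatenate(W_x, axis=1)            # (d, 4h)
W_h_all = np.concatenate(W_h, axis=1)            # (h, 4h)
b_all = np.concatenate(b)                        # (4h,)
Z_all = X @ W_x_all + H_prev @ W_h_all + b_all   # (n, 4h)
Z_fused = np.split(Z_all, 4, axis=1)             # four (n, h) blocks

# Only the activation differs: sigmoid for the gates, tanh for the input node
I_t, F_t, O_t = (sigmoid(z) for z in Z_fused[:3])
C_tilde = np.tanh(Z_fused[3])
```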

Memory cell update

The cell state is the LSTM’s long-lived memory.

Computing the cell internal state \mathbf{C}_t.

Cell-update equation

Mix the previous cell with the new proposal, gated elementwise:

\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.

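If \mathbf{F}_t \approx 1 and \mathbf{I}_t \approx 0, the cell holds its value unchanged for as long as the gates stay in that regime. Ignore the indirect paths through the gates. Along the direct cell-to-cell path, \partial \mathbf{C}_t / \partial \mathbf{C}_{t-1} = \mathrm{diag}(\mathbf{F}_t). The gradient is therefore not repeatedly multiplied by a weight matrix and a saturating nonlinearity. That is the constant error carousel that mitigates vanishing gradients.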
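
A small NumPy simulation makes the carousel concrete. It uses hand-set, constant gate values. In a trained LSTM the gates change at every step, so these numbers are purely illustrative. With the input gate shut the update is linear, C_T = f^T C_0. The direct-path derivative is therefore f^T. It stays at 1 when f = 1 and shrinks geometrically otherwise.

```python
import numpy as np

def run_cell(f, i, c_tilde, c0, T):
    """Apply C_t = f * C_{t-1} + i * c_tilde for T steps with constant gate values."""
    c = c0.copy()
    for _ in range(T):
        c = f * c + i * c_tilde
    return c

T = 1000
c0 = np.array([0.5, -2.0])    # initial cell state of two hidden units

# Forget gate open, input gate shut: the cell is copied unchanged for 1000 steps
c_keep = run_cell(f=1.0, i=0.0, c_tilde=np.array([3.0, 3.0]), c0=c0, T=T)

# Forget gate slightly below 1, nothing written: the memory decays geometrically
c_decay = run_cell(f=0.99, i=0.0, c_tilde=np.zeros(2), c0=c0, T=T)

# Along the direct path dC_T/dC_0 = f**T:
# 1.0 when f = 1, but 0.99**1000 (about 4.3e-5) when f = 0.99
direct_grad_decay = 0.99 ** T
```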

Hidden state

Gated, squashed cell:

\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).

Computing \mathbf{H}_t from \mathbf{C}_t and the output gate.
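
Putting the gate, input-node, cell, and hidden-state equations together, one time step can be sketched in NumPy. The parameter names mirror the from-scratch model. The dictionary layout and the helper `lstm_step` are our own illustrative choices. The sketch also checks two properties. Every entry of H_t lies in (-1, 1). Forcing the output gate shut (here via a large negative bias, an arbitrary trick for the demo) hides the memory without changing it.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(X, H, C, p):
    """One LSTM time step; p maps names like 'W_xi' to weight arrays."""
    I = sigmoid(X @ p["W_xi"] + H @ p["W_hi"] + p["b_i"])        # input gate
    F = sigmoid(X @ p["W_xf"] + H @ p["W_hf"] + p["b_f"])        # forget gate
    O = sigmoid(X @ p["W_xo"] + H @ p["W_ho"] + p["b_o"])        # output gate
    C_tilde = np.tanh(X @ p["W_xc"] + H @ p["W_hc"] + p["b_c"])  # input node
    C = F * C + I * C_tilde        # cell update
    H = O * np.tanh(C)             # hidden state: gated, squashed cell
    return H, C

rng = np.random.default_rng(2)
n, d, h = 3, 4, 5
p = {}
for g in "ifoc":
    p[f"W_x{g}"] = rng.normal(size=(d, h))   # (d, h)
    p[f"W_h{g}"] = rng.normal(size=(h, h))   # (h, h)
    p[f"b_{g}"] = np.zeros(h)                # (h,)

X = rng.normal(size=(n, d))
H0, C0 = np.zeros((n, h)), np.zeros((n, h))
H1, C1 = lstm_step(X, H0, C0, p)

# Push the output gate towards 0: H vanishes, but C is unaffected
p_closed = dict(p, b_o=np.full(h, -50.0))
H_hidden, C_same = lstm_step(X, H0, C0, p_closed)
```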

From scratch: parameters

Eight weight matrices and four bias vectors: the same triple() factory called four times, once per gate/node:

from torch import nn
from d2l import torch as d2l

class LSTMScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()

        # Small Gaussian weights; biases start at zero
        init_weight = lambda *shape: nn.Parameter(d2l.randn(*shape) * sigma)
        # One (W_x, W_h, b) triple: (num_inputs, num_hiddens),
        # (num_hiddens, num_hiddens), (num_hiddens,)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(d2l.zeros(num_hiddens)))
        self.W_xi, self.W_hi, self.b_i = triple()  # Input gate
        self.W_xf, self.W_hf, self.b_f = triple()  # Forget gate
        self.W_xo, self.W_ho, self.b_o = triple()  # Output gate
        self.W_xc, self.W_hc, self.b_c = triple()  # Input node
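
The parameter count follows directly from the four triples: 4 (d h + h^2 + h). The sketch below computes it. The 28-symbol vocabulary is a hypothetical size chosen for illustration, and the real value is whatever len(data.vocab) returns. PyTorch's nn.LSTM keeps two bias vectors per layer (bias_ih and bias_hh), so the helper takes an optional second bias per head for comparison.

```python
def lstm_param_count(num_inputs, num_hiddens, biases_per_head=1):
    """Parameters in a single-layer LSTM: four heads, each with
    W_x (num_inputs x num_hiddens), W_h (num_hiddens x num_hiddens)
    and biases_per_head bias vectors of length num_hiddens."""
    per_head = (num_inputs * num_hiddens
                + num_hiddens * num_hiddens
                + biases_per_head * num_hiddens)
    return 4 * per_head

# Hypothetical sizes: a 28-symbol character vocabulary, 32 hidden units
scratch = lstm_param_count(28, 32)                        # one bias per head
torch_style = lstm_param_count(28, 32, biases_per_head=2) # bias_ih + bias_hh
print(scratch, torch_style)   # 7808 7936
```

The second bias adds no expressive power, because the two are always summed, but it does add 4h parameters.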

Walk the sequence; at each step compute the four gate/node heads, update \mathbf{C}, then \mathbf{H}. Carry both states forward.

@d2l.add_to_class(LSTMScratch)  # register forward as a method of LSTMScratch
def forward(self, inputs, H_C=None):
    # inputs: (num_steps, batch_size, num_inputs), iterated over time steps
    if H_C is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                      device=inputs.device)
        C = d2l.zeros((inputs.shape[1], self.num_hiddens),
                      device=inputs.device)
    else:
        H, C = H_C
    outputs = []
    for X in inputs:  # X: (batch_size, num_inputs) at one time step
        I = d2l.sigmoid(d2l.matmul(X, self.W_xi) +
                        d2l.matmul(H, self.W_hi) + self.b_i)  # Input gate
        F = d2l.sigmoid(d2l.matmul(X, self.W_xf) +
                        d2l.matmul(H, self.W_hf) + self.b_f)  # Forget gate
        O = d2l.sigmoid(d2l.matmul(X, self.W_xo) +
                        d2l.matmul(H, self.W_ho) + self.b_o)  # Output gate
        C_tilde = d2l.tanh(d2l.matmul(X, self.W_xc) +
                           d2l.matmul(H, self.W_hc) + self.b_c)  # Input node
        C = F * C + I * C_tilde   # Cell update
        H = O * d2l.tanh(C)       # Hidden state
        outputs.append(H)
    return outputs, (H, C)
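
Returning (H, C) is what lets a caller split a long sequence into chunks without changing the result. The NumPy sketch below mirrors the loop above with one change: it uses a single stacked weight set in a [i | f | o | c] order of our own choosing. It checks that one pass over 10 steps equals two 5-step passes when the state is carried across the split. The helper `lstm_forward` is illustrative, not part of d2l.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(inputs, params, H_C=None):
    """Time-major LSTM loop: inputs has shape (num_steps, batch, num_inputs).
    params = (W_x, W_h, b) with the four heads stacked as [i | f | o | c]."""
    W_x, W_h, b = params
    h = W_h.shape[0]
    batch = inputs.shape[1]
    if H_C is None:
        H, C = np.zeros((batch, h)), np.zeros((batch, h))
    else:
        H, C = H_C
    outputs = []
    for X in inputs:                        # one time step at a time
        Z = X @ W_x + H @ W_h + b           # (batch, 4h)
        I = sigmoid(Z[:, :h])
        F = sigmoid(Z[:, h:2 * h])
        O = sigmoid(Z[:, 2 * h:3 * h])
        C_tilde = np.tanh(Z[:, 3 * h:])
        C = F * C + I * C_tilde
        H = O * np.tanh(C)
        outputs.append(H)
    return outputs, (H, C)

rng = np.random.default_rng(3)
T, n, d, h = 10, 2, 4, 3
params = (0.5 * rng.normal(size=(d, 4 * h)),
          0.5 * rng.normal(size=(h, 4 * h)),
          np.zeros(4 * h))
seq = rng.normal(size=(T, n, d))

# One pass over all 10 steps ...
out_full, (H_full, C_full) = lstm_forward(seq, params)
# ... equals two 5-step passes when (H, C) is carried across the split
out_a, state_a = lstm_forward(seq[:5], params)
out_b, (H_b, C_b) = lstm_forward(seq[5:], params, state_a)
```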

Training the from-scratch LSTM

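Same RNNLMScratch wrapper, same Trainer, same gradient clipping; only the cell changed. A fairly large learning rate (lr=4) works here together with gradient clipping (gradient_clip_val=1). The sigmoid and tanh activations also keep gate and state values bounded.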

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
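
Gradient clipping with gradient_clip_val=1 bounds the size of each update. Here is a NumPy sketch of global-norm clipping. We assume this matches what the d2l Trainer does: when the combined L2 norm of all gradients exceeds the threshold, every gradient is rescaled by the same factor. The helper `clip_gradients` below is our own.

```python
import numpy as np

def clip_gradients(grads, max_norm):
    """Rescale a list of gradient arrays in place so that their
    combined L2 norm is at most max_norm (global-norm clipping)."""
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total > max_norm:
        for g in grads:
            g *= max_norm / total   # same factor for every parameter
    return total

grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]   # global norm 5
norm_before = clip_gradients(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in grads))
print(norm_before, round(norm_after, 6))   # 5.0 1.0
```

Scaling all gradients jointly preserves the update's direction and only shrinks its length.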

Concise: nn.LSTM

The library cell dispatches to fused cuDNN kernels on GPU. It is typically 5–10× faster than the Python loop, although the exact speedup depends on hardware and model size:

class LSTM(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = nn.LSTM(num_inputs, num_hiddens)

    def forward(self, inputs, H_C=None):
        return self.rnn(inputs, H_C)

Drop into the same LM scaffold and train:

lstm = LSTM(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
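
Moving weights between nn.LSTM and LSTMScratch requires PyTorch's documented layout. weight_ih_l0 has shape (4*hidden_size, input_size), with row blocks in the order input gate, forget gate, cell candidate (called g), output gate. weight_hh_l0 follows the same order. Each is applied as x W^T, and there are two biases, bias_ih_l0 and bias_hh_l0. The NumPy sketch below only simulates that layout and does not import torch. It checks that the hypothetical helper `split_torch_lstm_weights` produces scratch-style parameters giving the same step.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def split_torch_lstm_weights(weight_ih, weight_hh, bias_ih, bias_hh):
    """Hypothetical helper: convert stacked nn.LSTM-style weights (row blocks
    i, f, g, o; applied as x @ W.T) into the X @ W form of LSTMScratch."""
    h = weight_hh.shape[1]
    out = {}
    for k, g in enumerate(["i", "f", "c", "o"]):   # PyTorch's g is our C~ node
        rows = slice(k * h, (k + 1) * h)
        out[f"W_x{g}"] = weight_ih[rows].T               # (num_inputs, h)
        out[f"W_h{g}"] = weight_hh[rows].T               # (h, h)
        out[f"b_{g}"] = bias_ih[rows] + bias_hh[rows]    # two biases fold into one
    return out

def step_stacked(X, H, C, w_ih, w_hh, b_ih, b_hh):
    # Stacked form: one affine map, then chunk into i, f, g, o
    z = X @ w_ih.T + b_ih + H @ w_hh.T + b_hh
    i, f, g, o = np.split(z, 4, axis=1)
    C = sigmoid(f) * C + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(C), C

def step_scratch(X, H, C, p):
    I = sigmoid(X @ p["W_xi"] + H @ p["W_hi"] + p["b_i"])
    F = sigmoid(X @ p["W_xf"] + H @ p["W_hf"] + p["b_f"])
    O = sigmoid(X @ p["W_xo"] + H @ p["W_ho"] + p["b_o"])
    C_tilde = np.tanh(X @ p["W_xc"] + H @ p["W_hc"] + p["b_c"])
    C = F * C + I * C_tilde
    return O * np.tanh(C), C

rng = np.random.default_rng(4)
n, d, h = 2, 5, 3
w_ih, w_hh = rng.normal(size=(4 * h, d)), rng.normal(size=(4 * h, h))
b_ih, b_hh = rng.normal(size=4 * h), rng.normal(size=4 * h)
X, H, C = rng.normal(size=(n, d)), rng.normal(size=(n, h)), rng.normal(size=(n, h))

p = split_torch_lstm_weights(w_ih, w_hh, b_ih, b_hh)
H_a, C_a = step_stacked(X, H, C, w_ih, w_hh, b_ih, b_hh)
H_b, C_b = step_scratch(X, H, C, p)
```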

Decoding

Generate 20 characters after the prefix 'it has':

model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has of the time the tim'
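
The repetitive output is typical of greedy decoding. We assume predict works this way, as far as we know for the d2l helper. It first runs the model over the prefix to warm up the state, then repeatedly feeds back the most likely next token. The NumPy sketch below shows that loop. `toy_step` is a hypothetical stand-in for a trained model, not an LSTM.

```python
import numpy as np

def greedy_decode(prefix, num_preds, vocab, step, state):
    """Warm up on a non-empty prefix, then repeatedly feed back the argmax token.
    step(token_index, state) -> (logits over vocab, new state)."""
    idx = [vocab.index(ch) for ch in prefix]
    logits = None
    for t in idx:                    # warm-up: consume the prefix, keep the state
        logits, state = step(t, state)
    out = list(idx)
    for _ in range(num_preds):       # prediction: pick the most likely next token
        nxt = int(np.argmax(logits))
        out.append(nxt)
        logits, state = step(nxt, state)
    return "".join(vocab[i] for i in out)

# Toy stand-in for a trained model: always predicts the next letter in "abc"
vocab = ["a", "b", "c"]
def toy_step(t, state):
    logits = np.zeros(len(vocab))
    logits[(t + 1) % len(vocab)] = 1.0
    return logits, state

print(greedy_decode("ab", 4, vocab, toy_step, state=None))   # abcabc
```

Because argmax is deterministic, a model that assigns high probability to a common phrase tends to loop on it. Sampling from the softmax is the usual alternative.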

Recap

  • LSTM = vanilla RNN cell replaced by a memory cell with three multiplicative gates.
  • The additive cell update \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t is what mitigates vanishing gradients.
  • Only \mathbf{H}_t is passed on to the output layer; \mathbf{C}_t stays internal and is carried only from one time step to the next.
  • Reuse the from-scratch LM scaffold; nn.LSTM is the cuDNN drop-in for production.
  • Dominant sequence model 2011–2017; many ideas (gating, residual paths) carried into Transformers.