Vanilla RNNs hit a ceiling: gradients vanish across long sequences. LSTMs (Hochreiter & Schmidhuber, 1997) fix this by giving each unit a memory cell with a self-loop of weight 1 and three learned gates.
Forget gate \mathbf{F}_t: keep or wipe memory.
Input gate \mathbf{I}_t: let new content in.
Output gate \mathbf{O}_t: expose or hide memory.
For roughly two decades, LSTMs were the go-to sequence model for speech, translation, and language modeling, until Transformers took over (2017).
Gates at a glance
The input \mathbf{X}_t and previous hidden state \mathbf{H}_{t-1} feed three sigmoid gates.
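Written out, the three gates take the following form. This is a sketch that assumes the weight and bias names of the from-scratch implementation (W_xi, W_hi, b_i and so on), with \sigma the elementwise logistic sigmoid:

```latex
\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o).
\end{aligned}
```

Each gate has one entry per hidden unit and per example in the batch, and every entry lies in (0, 1), so a gate acts as a soft elementwise switch.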
If \mathbf{F}_t \approx 1 and \mathbf{I}_t \approx 0, the cell holds its value unchanged across arbitrary horizons. That’s the constant error carousel that fixes vanishing gradients.
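The claim above can be checked numerically. The sketch below isolates the cell update C_t = F_t * C_{t-1} + I_t * C_tilde_t and holds the gate values fixed across steps, which is a simplification (in a real LSTM they are recomputed at every step). The helper names `cell_update` and `run_cell` are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp


def cell_update(C_prev, F, I, C_tilde):
    """One memory-cell update: C_t = F * C_{t-1} + I * C_tilde (elementwise)."""
    return F * C_prev + I * C_tilde


def run_cell(C0, F, I, C_tilde, num_steps):
    """Apply the update num_steps times with fixed gate values (illustrative only)."""
    C = C0
    for _ in range(num_steps):
        C = cell_update(C, F, I, C_tilde)
    return C


C0 = jnp.array([0.5, -1.0, 2.0])
C_tilde = jnp.array([1.0, 1.0, 1.0])  # candidate content, ignored when I = 0

# Forget gate fully open, input gate fully closed: the stored value is preserved.
held = run_cell(C0, 1.0, 0.0, C_tilde, 50)

# Forget gate at 0.9: the stored value shrinks by a factor of 0.9 per step.
decayed = run_cell(C0, 0.9, 0.0, C_tilde, 50)

# Gradient of the final cell w.r.t. the initial cell is F ** num_steps per element:
# exactly 1 when F = 1, and about 0.005 when F = 0.9 over 50 steps.
grad_held = jax.grad(lambda c: run_cell(c, 1.0, 0.0, C_tilde, 50).sum())(C0)
grad_decayed = jax.grad(lambda c: run_cell(c, 0.9, 0.0, C_tilde, 50).sum())(C0)
```

Because the path from C_{t-1} to C_t is a plain elementwise product with F_t, the backpropagated signal along the cell is scaled only by the forget gate, not by a repeated weight matrix and squashing nonlinearity as in a vanilla RNN.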
Walk the sequence; at each step compute the three gates and the candidate cell state \tilde{\mathbf{C}}_t, update \mathbf{C}, then \mathbf{H}. Carry both states forward.
```python
import jax
from jax import numpy as jnp
from d2l import jax as d2l


# Method of the from-scratch LSTM class, which holds W_x*, W_h*, b_* and num_hiddens.
def forward(self, inputs, H_C=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation.
    def scan_fn(carry, X):
        H, C = carry
        # Input, forget, and output gates
        I = d2l.sigmoid(d2l.matmul(X, self.W_xi) +
                        d2l.matmul(H, self.W_hi) + self.b_i)
        F = d2l.sigmoid(d2l.matmul(X, self.W_xf) +
                        d2l.matmul(H, self.W_hf) + self.b_f)
        O = d2l.sigmoid(d2l.matmul(X, self.W_xo) +
                        d2l.matmul(H, self.W_ho) + self.b_o)
        # Candidate cell state (input node)
        C_tilde = d2l.tanh(d2l.matmul(X, self.W_xc) +
                           d2l.matmul(H, self.W_hc) + self.b_c)
        C = F * C + I * C_tilde
        H = O * d2l.tanh(C)
        return (H, C), H  # return carry, y

    if H_C is None:
        # inputs has shape (num_steps, batch_size, num_inputs)
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens)), \
                jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H_C

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
```
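To try the same computation without the surrounding class or the d2l helpers, a standalone sketch in plain JAX might look like the following. The function names `init_lstm_params` and `lstm_forward`, the parameter dictionary layout, and the Gaussian initialisation with `sigma=0.01` are assumptions made for illustration, not part of the original code.

```python
import jax
import jax.numpy as jnp


def init_lstm_params(key, num_inputs, num_hiddens, sigma=0.01):
    """Gaussian weights and zero biases for the four heads: i, f, o, c (assumed init)."""
    params = {}
    for name, k in zip("ifoc", jax.random.split(key, 4)):
        k_x, k_h = jax.random.split(k)
        params[f"W_x{name}"] = sigma * jax.random.normal(k_x, (num_inputs, num_hiddens))
        params[f"W_h{name}"] = sigma * jax.random.normal(k_h, (num_hiddens, num_hiddens))
        params[f"b_{name}"] = jnp.zeros(num_hiddens)
    return params


def lstm_forward(params, inputs, H_C=None):
    """inputs: (num_steps, batch_size, num_inputs). Returns (outputs, (H, C))."""
    p = params

    def step(carry, X):
        H, C = carry
        I = jax.nn.sigmoid(X @ p["W_xi"] + H @ p["W_hi"] + p["b_i"])
        F = jax.nn.sigmoid(X @ p["W_xf"] + H @ p["W_hf"] + p["b_f"])
        O = jax.nn.sigmoid(X @ p["W_xo"] + H @ p["W_ho"] + p["b_o"])
        C_tilde = jnp.tanh(X @ p["W_xc"] + H @ p["W_hc"] + p["b_c"])
        C = F * C + I * C_tilde
        H = O * jnp.tanh(C)
        return (H, C), H  # new carry, per-step output

    if H_C is None:
        batch_size = inputs.shape[1]
        num_hiddens = p["b_i"].shape[0]
        H_C = (jnp.zeros((batch_size, num_hiddens)),
               jnp.zeros((batch_size, num_hiddens)))
    (H, C), outputs = jax.lax.scan(step, H_C, inputs)
    return outputs, (H, C)


# Usage: 5 time steps, batch of 4, 8 input features, 16 hidden units.
params = init_lstm_params(jax.random.PRNGKey(0), num_inputs=8, num_hiddens=16)
X = jax.random.normal(jax.random.PRNGKey(1), (5, 4, 8))
outputs, (H, C) = lstm_forward(params, X)

# Carrying (H, C) forward lets a long sequence be processed in chunks.
out_a, state = lstm_forward(params, X[:2])
out_b, _ = lstm_forward(params, X[2:], state)
```

The stacked outputs have shape (num_steps, batch_size, num_hiddens), and the last output equals the final hidden state H. Processing the sequence in two chunks while passing the carried state should match the single full pass up to floating-point error.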
Training the from-scratch LSTM
Same RNNLMScratch head, same Trainer, same gradient clipping — only the cell changed. Higher learning rate (lr=4) is fine because gates keep activations bounded.
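For reference, gradient clipping by global norm can be sketched in a few lines of JAX. This is a minimal illustration that assumes the trainer clips by the global L2 norm of all gradients; the function name `clip_gradients` and the example values are hypothetical and not taken from the Trainer itself.

```python
import jax
import jax.numpy as jnp


def clip_gradients(grads, clip_val):
    """Rescale a pytree of gradients so its global L2 norm is at most clip_val."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale down only when the norm exceeds the threshold; otherwise leave as is.
    scale = jnp.minimum(1.0, clip_val / norm)
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


# Toy gradients with global norm 5 (a 3-4-5 triangle plus a zero bias gradient).
grads = {"W": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
clipped = clip_gradients(grads, clip_val=1.0)      # rescaled to norm 1
unchanged = clip_gradients(grads, clip_val=10.0)   # norm already below threshold
```

Clipping changes the length of the gradient but not its direction, which is why it can be combined with a large step size without individual updates blowing up.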