Recurrent Networks: RNN, LSTM, GRU
How gating fixed the vanishing-gradient problem in RNNs, and why transformers displaced them everywhere except streaming and on-device workloads.
RNNs process sequences one step at a time, carrying a hidden state forward. Until 2017 they were the default for language, speech, and time series. Transformers have since replaced them for most tasks, but RNNs still own a few niches where their fixed-size state, and hence constant memory per step, is exactly what you need.
The vanilla RNN
At each step t, the network mixes the previous hidden state with the current input:
```
h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)
y_t = W_y @ h_t
```
Train it with backpropagation through time (BPTT): unroll the recurrence, treat the unrolled graph as a deep feed-forward network whose layers all share the same weights, then backpropagate. The trouble is what happens to gradients across many steps.
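A minimal NumPy sketch of the forward unroll and BPTT for the recurrence above. It assumes, for illustration only, a squared-error loss on the final output y_T; the function names `rnn_forward`, `loss_fn`, and `bptt` are hypothetical. The backward loop walks the unrolled steps in reverse and accumulates into the shared W_h, W_x, and b.

```python
import numpy as np

def rnn_forward(W_h, W_x, b, xs, h0):
    """Unroll h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b); hs[0] is h0, hs[t] is h_t."""
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(W_h @ hs[-1] + W_x @ x + b))
    return hs

def loss_fn(W_h, W_x, W_y, b, xs, h0, target):
    # Illustrative loss (an assumption): squared error on the last output y_T = W_y @ h_T.
    hs = rnn_forward(W_h, W_x, b, xs, h0)
    y = W_y @ hs[-1]
    return 0.5 * np.sum((y - target) ** 2)

def bptt(W_h, W_x, W_y, b, xs, h0, target):
    """Backpropagation through time for loss_fn; returns dL/dW_h, dL/dW_x, dL/db."""
    hs = rnn_forward(W_h, W_x, b, xs, h0)
    dy = W_y @ hs[-1] - target            # dL/dy_T
    dh = W_y.T @ dy                       # dL/dh_T
    dW_h, dW_x, db = np.zeros_like(W_h), np.zeros_like(W_x), np.zeros_like(b)
    for t in range(len(xs), 0, -1):       # walk the unrolled graph backwards
        da = dh * (1.0 - hs[t] ** 2)      # through tanh: d tanh(a)/da = 1 - tanh(a)^2
        dW_h += np.outer(da, hs[t - 1])   # same W_h at every step, so contributions add up
        dW_x += np.outer(da, xs[t - 1])
        db += da
        dh = W_h.T @ da                   # hand the gradient back to h_{t-1}
    return dW_h, dW_x, db
```

A central finite-difference check on `loss_fn` is a cheap way to confirm a hand-written backward pass like this one.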
Vanishing and exploding gradients
The gradient of the hidden state at step t with respect to the hidden state at step t-k is a product of k Jacobians, one per step of the recurrence. With tanh activations and weights initialised in the usual way:
- If the spectral radius of W_h is < 1, gradients shrink geometrically. After a few dozen steps they are vanishingly small (0.5**30 is about 1e-9), and the network cannot learn long-range dependencies.
- If the spectral radius is > 1, gradients can instead grow geometrically and explode. Training then diverges within a few steps.
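To make the geometric trend concrete: one step's Jacobian is dh_t/dh_{t-1} = diag(1 - h_t**2) @ W_h, and the tanh factor has entries in (0, 1], so W_h sets the trend. The sketch below isolates that linear part by computing the spectral norm of W_h**k. As an assumption chosen so the spectral radius is exact, W_h is a random orthogonal matrix scaled by `radius` (not a standard RNN initialisation); `scaled_orthogonal` and `jacobian_product_norm` are hypothetical helper names.

```python
import numpy as np

def scaled_orthogonal(n, radius, seed=0):
    """Random orthogonal matrix times `radius`: every eigenvalue has modulus `radius`."""
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    return radius * q

def jacobian_product_norm(W_h, k):
    """Spectral norm of W_h**k, the linear part of k chained recurrence Jacobians."""
    return np.linalg.norm(np.linalg.matrix_power(W_h, k), ord=2)

for radius in (0.5, 0.9, 1.1, 1.5):
    W_h = scaled_orthogonal(32, radius)
    print(radius, [f"{jacobian_product_norm(W_h, k):.1e}" for k in (1, 10, 30)])
```

Because a power of an orthogonal matrix is still orthogonal, the norm here is exactly radius**k. In this construction, inserting the diag(1 - h_t**2) factors can only lower the norm, so tanh makes the vanishing cases worse, not better.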
Gradient clipping fixes the explosion. Vanishing is structural - you cannot clip your way out of a zero.
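One common form is clipping by global norm: if the combined L2 norm of all gradient arrays exceeds a threshold, rescale every array by the same factor, which caps the step size while keeping its direction. A NumPy sketch; `clip_by_global_norm` is a hypothetical name and the threshold of 1.0 is an arbitrary example value. PyTorch ships an equivalent utility as `torch.nn.utils.clip_grad_norm_`.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm  # one shared factor, so the update direction is unchanged
    return [g * scale for g in grads], total_norm

exploded = [np.full((2, 2), 300.0), np.full(3, -400.0)]
clipped, norm_before = clip_by_global_norm(exploded, max_norm=1.0)
```

Clipping only ever scales gradients down, which is exactly why it cannot rescue gradients that have already vanished.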
The LSTM gate
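Hochreiter and Schmidhuber's 1997 LSTM introduced a separate cell state c_t that flows through time mostly unchanged. In the now-standard form, which adds the forget gate introduced by Gers, Schmidhuber and Cummins (2000), it is modulated by three gates: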
```
f_t = sigmoid(W_f @ [h_{t-1}, x_t]) # forget gate
i_t = sigmoid(W_i @ [h_{t-1}, x_t]) # input gate
o_t = sigmoid(W_o @ [h_{t-1}, x_t]) # output gate
c_tilde = tanh(W_c @ [h_{t-1}, x_t])
c_t = f_t * c_{t-1} + i_t * c_tilde
h_t = o_t * tanh(c_t)
```
The crucial term is f_t * c_{t-1} in the cell update. Along this direct cell-state path, ∂c_t/∂c_{t-1} is just f_t (elementwise), so if the forget gate stays near 1 the cell state is carried across time by approximately the identity map and gradients propagate largely unimpeded. The gates learn when to forget, when to write, and when to read.
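A runnable NumPy version of one LSTM step, following the equations above. Two assumptions go beyond the text: each gate gets a bias term (the equations omit biases), and the four weight matrices are stacked into one W of shape (4H, H + D) in the order [forget, input, output, candidate], an arbitrary choice for this sketch so one matmul computes every pre-activation. The demo sets the forget-gate bias to 5.0 (an illustrative value, not a recommended initialisation) and tracks the product of f_t, which is the gradient gain along the direct cell-state path only; paths through h are ignored.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(W, b, h_prev, c_prev, x):
    """One LSTM step. W: (4H, H + D), rows stacked as [forget, input, output, candidate]."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x]) + b  # all four pre-activations at once
    f = sigmoid(z[0:H])                      # forget gate
    i = sigmoid(z[H:2 * H])                  # input gate
    o = sigmoid(z[2 * H:3 * H])              # output gate
    c_tilde = np.tanh(z[3 * H:])             # candidate cell update
    c = f * c_prev + i * c_tilde             # additive cell-state update
    h = o * np.tanh(c)
    return h, c, f

rng = np.random.default_rng(0)
H, D, T = 8, 4, 100
W = rng.normal(scale=0.01, size=(4 * H, H + D))  # small random weights
b = np.zeros(4 * H)
b[0:H] = 5.0                                     # illustrative: push forget gates towards 1
h, c = np.zeros(H), np.ones(H)
path_gain = np.ones(H)  # running product of f_t = direct-path d c_T / d c_0
for _ in range(T):
    h, c, f = lstm_step(W, b, h, c, rng.normal(size=D))
    path_gain *= f
```

With the forget bias at 5.0 each f_t is about 0.993, so after 100 steps the direct-path gain is on the order of 0.5. A vanilla RNN whose per-step Jacobian norm is 0.9 would be down to 0.9**100, about 3e-5, over the same span.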
GRU
Cho et al.'s GRU (2014) collapses the input and forget gates into a single update gate and merges the cell state into the hidden state. That leaves two gates instead of three and roughly 25% fewer parameters (three weight blocks per layer instead of the LSTM's four):
```
z_t = sigmoid(W_z @ [h_{t-1}, x_t]) # update gate
r_t = sigmoid(W_r @ [h_{t-1}, x_t]) # reset gate
h_tilde = tanh(W @ [r_t * h_{t-1}, x_t])
h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde
```
On most benchmarks GRU and LSTM are within noise of each other. GRU trains slightly faster. LSTM has slightly more expressive cell-state dynamics. The choice rarely matters in practice.
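A NumPy sketch of one GRU step that mirrors the equations above (biases omitted, as in the text), plus a parameter count that checks the "roughly 25% fewer" figure. `gru_step` and `recurrent_param_count` are hypothetical names. The count assumes one bias vector per weight block; PyTorch's `nn.LSTM` and `nn.GRU` keep two (input-side and hidden-side), which changes the absolute numbers but not the 3:4 ratio.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(W_z, W_r, W, h_prev, x):
    """One GRU step; each weight matrix has shape (H, H + D)."""
    hx = np.concatenate([h_prev, x])
    z = sigmoid(W_z @ hx)                                   # update gate
    r = sigmoid(W_r @ hx)                                   # reset gate
    h_tilde = np.tanh(W @ np.concatenate([r * h_prev, x]))  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde                 # blend old state and candidate

def recurrent_param_count(kind, H, D):
    """Weights + one bias per block: LSTM has 4 blocks (3 gates + candidate), GRU has 3."""
    blocks = {"lstm": 4, "gru": 3}[kind]
    return blocks * (H * (H + D) + H)

ratio = recurrent_param_count("gru", 128, 64) / recurrent_param_count("lstm", 128, 64)
```

The ratio is exactly 0.75 for any H and D, since both layers are built from identical blocks and only the block count differs.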
Why transformers won
Three reasons, in order of importance:
- Parallel training. Self-attention computes all positions in parallel. An RNN must wait for step t-1 to finish before starting step t. On modern GPUs this is roughly the difference between training on 10B tokens and on 10T tokens.
- Effective context. LSTMs were credited with handling "hundreds of tokens" in 2016-era papers, but in practice the signal decays well before that. Attention reaches every prior token directly, with O(1) path length, whereas in an RNN information from k steps back must pass through k recurrent updates.
- Scaling laws. Transformers scale predictably with parameters, data, and compute. RNNs plateau.
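A NumPy sketch of the parallel-training contrast. It assumes a single-head causal self-attention with no learned projections (queries, keys, and values are the raw inputs, a simplification), and `causal_self_attention` and `rnn_hidden_states` are hypothetical names. Attention produces every position's output from a few matrix products; the RNN needs a loop in which step t cannot start until step t-1 has produced its hidden state.

```python
import numpy as np

def causal_self_attention(X):
    """X: (T, d). Each position attends to itself and all earlier positions, in one pass."""
    T, d = X.shape
    scores = X @ X.T / np.sqrt(d)                     # (T, T): every pair in one matmul
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)  # True above the diagonal = future
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)     # row-wise softmax
    return weights @ X                                # (T, d)

def rnn_hidden_states(X, W_h, W_x):
    """Same input through a vanilla RNN: an inherently sequential loop over time."""
    h = np.zeros(W_h.shape[0])
    out = []
    for x in X:                                       # step t needs h from step t-1
        h = np.tanh(W_h @ h + W_x @ x)
        out.append(h)
    return np.stack(out)
```

The attention weights also show the path-length point: the last row mixes the first token in directly, in one hop, whereas in `rnn_hidden_states` the first input reaches the last hidden state only through T - 1 further tanh updates.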
Where RNNs still win
- Streaming inference. A transformer keeps keys and values for the whole prefix in its KV cache, so memory grows linearly with sequence length. An RNN holds a fixed-size hidden state. For an always-on speech recogniser running for hours, this is decisive.
- Tiny models on device. A 1M-parameter GRU on a microcontroller doing keyword spotting is hard to beat with anything attention-based at that size budget.
- State-space hybrids. Mamba and RWKV reintroduce recurrent computation with modern training tricks (parallel scan, selective gating) and match transformers on some long-context benchmarks. The recurrent skeleton is being rehabilitated.
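Both points can be sketched with a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t (elementwise). This is a heavily simplified stand-in, assumed for illustration: it omits Mamba's input-dependent discretisation and hardware-aware kernel, and RWKV's formulation differs. Step by step, it needs only a fixed-size state, which is the streaming case. Because composing two steps is associative, the same outputs can also come from a parallel prefix scan, here the simple (work-inefficient) Hillis-Steele variant, which takes about log2(T) vectorised rounds. Function names are hypothetical.

```python
import numpy as np

def recurrence_sequential(a, b, h0):
    """Streaming form: holds one state of shape (N,) no matter how long the input. a, b: (T, N)."""
    h = h0
    out = np.empty(b.shape)
    for t in range(len(b)):
        h = a[t] * h + b[t]
        out[t] = h
    return out

def recurrence_parallel_scan(a, b, h0):
    """Same outputs via a Hillis-Steele inclusive scan over (A, B) pairs.

    Each pair stands for the map h -> A * h + B. Composing an earlier pair with a later
    one gives (A_late * A_early, A_late * B_early + B_late), and that is associative.
    """
    A, B = a.astype(float), b.astype(float)
    T = len(A)
    offset = 1
    while offset < T:                      # about log2(T) rounds
        A_new, B_new = A.copy(), B.copy()
        # every position t absorbs the segment ending at t - offset, all at once
        B_new[offset:] = A[offset:] * B[:-offset] + B[offset:]
        A_new[offset:] = A[offset:] * A[:-offset]
        A, B = A_new, B_new
        offset *= 2
    return A * h0 + B                      # A[t] is now a_0 * ... * a_t
```

The sequential version is what an always-on device would run, with memory independent of T. The scan version trades extra work for parallelism at training time, which is the property that lets these recurrent models train on GPUs the way transformers do.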
Further reading
- Understanding LSTM Networks - Christopher Olah, 2015. The diagrams are still the clearest explanation.
- The Unreasonable Effectiveness of Recurrent Neural Networks - Andrej Karpathy, 2015. Char-RNN intuition.
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation - Cho et al., 2014. The GRU paper.