Backpropagation and Automatic Differentiation
How reverse-mode autodiff turns the chain rule into an efficient gradient algorithm, and the design choices PyTorch and JAX make to implement it at scale.
Backpropagation is the chain rule applied to a computation graph in a specific order: reverse topological, with intermediate values cached from the forward pass. Every major deep learning framework (PyTorch, JAX, TensorFlow) is built around a reverse-mode autodiff engine, and libraries such as Flax sit on top of JAX's. Understanding the algorithm explains why activation memory often dominates GPU usage when training large models, and why gradient checkpointing trades compute for memory.
The chain rule, mechanically
Given y = f(g(h(x))), the gradient is the product of local Jacobians evaluated at intermediate values:
dy/dx = (df/dg) * (dg/dh) * (dh/dx)
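For a scalar-valued loss L and vector-valued intermediates, each local derivative is a Jacobian matrix, so the gradient becomes a chain of matrix multiplications. Matrix multiplication is associative, so every order gives the same result, but not at the same cost. The question is which order you multiply them in.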
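To see why the order matters, here is a small cost count. It is a sketch: the helper name and the layer widths are illustrative assumptions. Jacobian i has shape dims[i-1] x dims[i], and the leftmost factor is the derivative of a scalar loss, so it is a 1 x d row vector. Multiplying an (a x b) matrix by a (b x c) matrix costs a * b * c multiply-adds.

```python
def chain_cost(dims, order):
    """Multiply-adds to evaluate J_1 @ J_2 @ ... @ J_k (illustrative helper).

    Jacobian J_i has shape dims[i-1] x dims[i]; J_1 is on the output (loss) side.
    """
    k = len(dims) - 1  # number of Jacobians in the chain
    cost = 0
    if order == "left_to_right":
        # Reverse mode: start from the loss side; running product is dims[0] x dims[i]
        for i in range(1, k):
            cost += dims[0] * dims[i] * dims[i + 1]
    elif order == "right_to_left":
        # Forward mode: start from the input side; running product is dims[i] x dims[k]
        for i in range(k - 1, 0, -1):
            cost += dims[i - 1] * dims[i] * dims[k]
    else:
        raise ValueError(f"unknown order: {order}")
    return cost


# Scalar loss (width 1), two hidden widths of 100, and 1000 inputs
dims = [1, 100, 100, 1000]
print(chain_cost(dims, "left_to_right"))  # 110000
print(chain_cost(dims, "right_to_left"))  # 10100000
```

Left to right, the running product stays a 1 x d row vector, so every step is a cheap vector-matrix product. Right to left, it is a full d x n matrix from the first step onward.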
Forward vs reverse mode
Two orderings of the same chain rule:
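- Forward mode propagates derivatives alongside values, from inputs to outputs. Each pass computes one Jacobian-vector product (the derivative along one input direction) for a small constant multiple of the cost of evaluating the function. For a function R^n -> R^m, computing the full Jacobian takes n forward passes.
- Reverse mode evaluates the function first (the forward pass), then propagates derivatives from the output back to the inputs. Each backward pass computes one vector-Jacobian product (the gradient of one output), again for a small constant multiple of the function's cost. For R^n -> R^m, the full Jacobian takes m backward passes.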
ML loss functions are R^n -> R, where n is the parameter count (often billions) and the output is a single scalar. Reverse mode gets all n gradients in a single backward pass; forward mode would need n passes. That is why neural network training uses reverse mode rather than forward mode.
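A minimal forward-mode sketch makes the pass count concrete. The Dual class, the sin helper, and the example function are hypothetical and written only for illustration. Each pass seeds a tangent of 1 on one input and reads off one partial derivative, so a gradient over n inputs needs n passes.

```python
import math


class Dual:
    """A value paired with its derivative (tangent) along one input direction."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(x):
        # Treat plain numbers as constants (zero tangent)
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def sin(x):
    # Chain rule for sin: d sin(u) = cos(u) du
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)


def f(xs):
    # Hypothetical example function f: R^n -> R
    return sum(x * x for x in xs) + sin(xs[0])


def grad_forward_mode(f, xs):
    """Full gradient via forward mode: one pass per input coordinate."""
    grad = []
    for i in range(len(xs)):
        seeded = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grad.append(f(seeded).tan)
    return grad, len(xs)  # (gradient, number of forward passes)


g, passes = grad_forward_mode(f, [1.0, 2.0, 3.0])
print(g, passes)  # g is [2 + cos(1), 4.0, 6.0]; passes is 3
```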
Forward mode is still useful in a few places:
- Hessian-vector products via jvp(grad(f)): one forward and one reverse pass instead of computing the full Hessian.
- Differentiating through ODE solvers and physics simulations where the output dimension exceeds the input dimension.
- Sensitivity analysis when you have a few inputs and many outputs.
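The Hessian-vector product trick can be sketched with dual numbers: push a tangent v through a function that computes the gradient, and the tangent that comes out is H v. In JAX the gradient would come from jax.grad. In this hypothetical sketch, a hand-derived gradient of f(x) = x0^2 * x1 + x1^3 stands in for it, so only the forward-mode half is automated.

```python
class Dual:
    """A value paired with its tangent (directional derivative)."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def grad_f(x):
    # Hand-derived gradient of f(x) = x0^2 * x1 + x1^3 (stand-in for grad(f))
    x0, x1 = x
    return [2 * x0 * x1, x0 * x0 + 3 * x1 * x1]


def hvp(grad_fn, x, v):
    """Hessian-vector product H(x) @ v via one forward-mode pass over grad_fn."""
    seeded = [Dual(xi, vi) for xi, vi in zip(x, v)]
    return [g.tan for g in grad_fn(seeded)]


print(hvp(grad_f, [1.0, 2.0], [1.0, 0.0]))  # [4.0, 2.0]: first column of H
print(hvp(grad_f, [1.0, 2.0], [0.0, 1.0]))  # [2.0, 12.0]: second column of H
```

Each call costs about as much as one extra evaluation of the gradient, and the n x n Hessian is never formed.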
The backprop algorithm
For a feed-forward network with layers f_1, f_2, ..., f_L:
# Forward pass - cache activations
a_0 = x
for l in 1..L:
    a_l = f_l(a_{l-1}, theta_l)
loss = loss_fn(a_L, y)
# Backward pass - apply chain rule in reverse
g = dloss/da_L                                     # row vector
for l in L..1:
    dtheta_l = g @ df_l/dtheta_l(a_{l-1}, theta_l)  # vector-Jacobian product
    g = g @ df_l/da_{l-1}(a_{l-1}, theta_l)         # now g = dloss/da_{l-1}
Two things to notice:
- The activations a_l from the forward pass must be retained until the backward pass uses them. Activation memory scales with depth and sequence length and often dominates GPU memory for large models.
- Each backward step costs roughly twice the forward step. For a linear layer, the forward pass is one matrix multiply, while the backward pass needs two: one for the weight gradient and one for the gradient with respect to the layer's input. The standard rule of thumb: the backward pass is about 2x the forward-pass FLOPs, so a full training step is about 3x.
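A runnable version of the loop above, as a minimal sketch for a hypothetical stack of tanh(W @ a) layers with loss = sum(a_L), written in plain Python lists. It caches every activation on the forward pass and reuses them on the backward pass. It also counts the multiply-adds in the matrix products and checks weight gradients against central finite differences.

```python
import math


def forward(Ws, x):
    """Forward pass for hypothetical layers a_l = tanh(W_l @ a_{l-1}); caches every activation."""
    acts, macs = [x], 0
    for W in Ws:
        z = [sum(w * a for w, a in zip(row, acts[-1])) for row in W]
        macs += len(W) * len(acts[-1])  # one (d_out x d_in) matrix-vector product
        acts.append([math.tanh(v) for v in z])
    return acts, macs


def backward(Ws, acts):
    """Backward pass for loss = sum(a_L), reusing the cached activations."""
    g = [1.0] * len(acts[-1])  # dloss/da_L
    grads, macs = [None] * len(Ws), 0
    for idx in reversed(range(len(Ws))):
        W, a_in, a_out = Ws[idx], acts[idx], acts[idx + 1]
        dz = [gi * (1.0 - ai * ai) for gi, ai in zip(g, a_out)]  # back through tanh
        grads[idx] = [[dzi * aj for aj in a_in] for dzi in dz]  # dloss/dW = outer(dz, a_in)
        g = [sum(W[i][j] * dz[i] for i in range(len(W))) for j in range(len(a_in))]  # W^T dz
        macs += 2 * len(W) * len(a_in)  # weight gradient + input gradient
    return grads, macs


def numeric_grad(Ws, x, layer, i, j, eps=1e-5):
    """Central finite difference of the loss with respect to Ws[layer][i][j]."""
    def loss_with(delta):
        Wp = [[row[:] for row in W] for W in Ws]
        Wp[layer][i][j] += delta
        return sum(forward(Wp, x)[0][-1])
    return (loss_with(eps) - loss_with(-eps)) / (2 * eps)


Ws = [
    [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]],    # layer 1: 3 -> 2
    [[0.7, -0.8], [0.9, 0.1], [-0.2, 0.3]],  # layer 2: 2 -> 3
]
x = [1.0, 0.5, -1.0]
acts, fwd_macs = forward(Ws, x)
grads, bwd_macs = backward(Ws, acts)
print(fwd_macs, bwd_macs)  # 12 24
print(grads[0][1][2], numeric_grad(Ws, x, 0, 1, 2))  # these two should agree closely
```

The multiply-add count comes out at exactly 2x here because every layer computes both a weight gradient and an input gradient. Real frameworks skip the input gradient for the first layer when it is not needed.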
How PyTorch and JAX implement this
The two dominant frameworks make different design choices.
PyTorch: define-by-run. Operations on tensors with requires_grad=True add nodes to a graph that is built dynamically as the operations execute. Each operation records a Function object (reachable from its output tensor as grad_fn) that knows how to compute its local gradient. loss.backward() walks this graph in reverse topological order, calling each Function's backward method.
import torch
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
# x.grad is now [2.0, 4.0]
The graph is rebuilt on every forward pass, so ordinary Python control flow can change it from one iteration to the next. This is also what makes PyTorch easy to debug: you can put print statements and pdb breakpoints between operations.
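The define-by-run idea fits in a few dozen lines. The sketch below is a hypothetical scalar-only engine, not PyTorch's implementation. Each operation records its inputs and a closure that pushes gradients back to them, and backward() sorts the recorded graph topologically and walks it in reverse. Because the graph is recorded as the Python code runs, a data-dependent branch produces a different graph on each call.

```python
class Scalar:
    """A scalar that records the operations applied to it (hypothetical mini-autograd)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        self.parents = parents   # the Scalars this one was computed from
        self.backward_fn = None  # pushes self.grad back to the parents

    def __add__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.value + other.value, (self, other))

        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad

        out.backward_fn = backward_fn
        return out

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.value * other.value, (self, other))

        def backward_fn():
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad

        out.backward_fn = backward_fn
        return out

    __rmul__ = __mul__

    def backward(self):
        # Depth-first search appends every node after the nodes it was computed from
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        # Reverse topological order: a node's grad is complete before it is pushed back
        for node in reversed(order):
            if node.backward_fn is not None:
                node.backward_fn()


def f(xs):
    total = 0
    for x in xs:
        # Data-dependent branch: which ops get recorded depends on the input values
        total = total + (x * x if x.value > 0 else 3 * x)
    return total


xs = [Scalar(1.0), Scalar(2.0)]
f(xs).backward()
print([x.grad for x in xs])  # [2.0, 4.0]

xs = [Scalar(-1.0), Scalar(2.0)]
f(xs).backward()
print([x.grad for x in xs])  # [3.0, 4.0]: a different graph was recorded
```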
JAX: function transformations. JAX traces Python functions into a typed intermediate representation (a jaxpr) and transforms them. jax.grad(f) returns a new function that, when called, computes the gradient of f by running f's forward pass followed by a reverse-mode backward pass. There is no implicit graph state attached to arrays; everything is functional.
import jax
import jax.numpy as jnp
f = lambda x: jnp.sum(x ** 2)
grad_f = jax.grad(f)
# grad_f(jnp.array([1.0, 2.0])) returns [2.0, 4.0]
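JAX composes transformations: jit(vmap(grad(f))) produces a JIT-compiled function that computes per-example gradients of f over a batch. The order matters: grad requires a scalar-valued function, so grad(vmap(f)) fails unless the batched outputs are first reduced, for example by summing. The functional API has a steeper learning curve, but the composition story is arguably cleaner than PyTorch's, which is one reason JAX is widely used in research codebases, including at DeepMind and Anthropic.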
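Composing transformations can be sketched with ordinary higher-order functions. In this hypothetical sketch, grad and vmap each take a function and return a new one. To keep it short, grad is built on forward-mode dual numbers (one pass per input) rather than the reverse mode JAX actually uses. vmap is a Python loop rather than JAX's vectorised batching rules. Composing them as vmap(grad(f)) yields per-example gradients.

```python
class Dual:
    """A value paired with its tangent (used here only to implement a toy grad)."""

    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def grad(f):
    """Hypothetical grad transform: returns a NEW function computing df/dxs."""
    def grad_f(xs):
        out = []
        for i in range(len(xs)):
            y = f([Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)])
            if not isinstance(y, Dual):
                raise TypeError("grad expects f to return a single scalar")
            out.append(y.tan)
        return out
    return grad_f


def vmap(f):
    """Hypothetical vmap transform: applies f to every example in a batch."""
    def batched_f(batch):
        return [f(xs) for xs in batch]
    return batched_f


def f(xs):
    return sum(x * x for x in xs)


per_example_grads = vmap(grad(f))
print(per_example_grads([[1.0, 2.0], [3.0, -1.0]]))  # [[2.0, 4.0], [6.0, -2.0]]
```

Because each transform only wraps a function, they stack in any order that type-checks, which is the property the composition story relies on.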
The activation memory problem
For a transformer with L layers, hidden dim d, sequence length n, batch size b, activation memory is roughly:
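mem = L * b * n * d * (sizeof(dtype) * factor)
where factor is an architecture- and implementation-dependent multiplier: roughly, how many d-sized tensors each layer keeps per token for the backward pass (layer inputs, Q, K and V, attention outputs, MLP intermediates and so on).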
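For Llama-2 70B (80 layers, hidden dim 8192) with sequence length 4096 and batch size 4 in bf16, L * b * n * d alone is about 10.7 billion values, or about 21.5 GB per unit of factor, so any realistic factor puts activation memory well over 100 GB. That is comparable to the 140 GB the bf16 weights occupy, before counting optimiser state. You cannot just buy more memory: a single H100 has 80 GB.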
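The arithmetic behind that estimate, as a sketch: the function just evaluates the formula above. factor is left as an explicit input because its real value depends on the model and implementation, so the values tried here are assumptions.

```python
def activation_bytes(layers, batch, seq_len, d_model, bytes_per_elem, factor):
    """Rough activation memory: L * b * n * d * (sizeof(dtype) * factor).

    `factor` is an assumed, implementation-dependent multiplier.
    """
    return layers * batch * seq_len * d_model * bytes_per_elem * factor


# Llama-2 70B shape (80 layers, d_model 8192), sequence 4096, batch 4, bf16 (2 bytes)
for factor in (1, 5, 10):
    gb = activation_bytes(80, 4, 4096, 8192, 2, factor) / 1e9
    print(f"factor={factor}: {gb:.1f} GB")  # prints 21.5, 107.4 and 214.7 GB
```

A factor of 5 is already enough to pass 100 GB.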
Gradient checkpointing
Introduced by Chen et al. (2016), gradient checkpointing trades compute for memory: in the forward pass, discard intermediate activations; in the backward pass, recompute them on demand by re-running the forward computation.
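The simplest variant ("uniform checkpointing") splits the L layers into segments of about sqrt(L) layers and keeps only the activation at each segment boundary. In the backward pass, each segment is recomputed from its checkpoint just before it is needed, so only about sqrt(L) checkpoints plus one segment's activations (about sqrt(L) more) are alive at any time. Memory becomes O(sqrt(L)) instead of O(L). Total compute becomes ~1.33x the original, because one extra forward pass is added to a step that already costs about three forward passes (a forward plus a 2x backward).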
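import torch
import torch.utils.checkpoint as cp
attn, ffn = torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)  # hypothetical stand-ins for real attention and MLP sublayers
def block(x): return ffn(attn(x))
x = torch.randn(8, 64, requires_grad=True)
y = cp.checkpoint(block, x, use_reentrant=False)  # activations inside `block` are dropped, recomputed in backward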
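To check the sqrt(L) accounting, here is a minimal pure-Python simulation of uniform checkpointing on a hypothetical chain of scalar layers a_l = tanh(w_l * a_{l-1}) with loss = a_L. It is not how torch.utils.checkpoint works internally. It only counts stored values and layer evaluations, and confirms that the gradients match ordinary backprop.

```python
import math


def layer(a, w):
    # Hypothetical scalar layer: a_next = tanh(w * a)
    return math.tanh(w * a)


def layer_backward(a_in, a_out, w, g):
    # Local derivatives of a_out = tanh(w * a_in), given upstream gradient g
    local = 1.0 - a_out * a_out
    return g * a_in * local, g * w * local  # (dloss/dw, dloss/da_in)


def backprop_full(x, ws):
    """Ordinary backprop: stores all L + 1 activations."""
    acts = [x]
    for w in ws:
        acts.append(layer(acts[-1], w))
    g, grads = 1.0, [0.0] * len(ws)  # loss = a_L, so dloss/da_L = 1
    for i in reversed(range(len(ws))):
        grads[i], g = layer_backward(acts[i], acts[i + 1], ws[i], g)
    return grads, len(acts)


def backprop_checkpointed(x, ws, k):
    """Uniform checkpointing: store the input of every k-th layer, recompute segments."""
    n, evals, checkpoints = len(ws), 0, {}
    a = x
    for i in range(n):
        if i % k == 0:
            checkpoints[i] = a  # segment boundary
        a = layer(a, ws[i])
        evals += 1
    g, grads, peak = 1.0, [0.0] * n, 0
    for start in reversed(range(0, n, k)):
        end = min(start + k, n)
        seg = [checkpoints[start]]
        for i in range(start, end):  # recompute this segment from its checkpoint
            seg.append(layer(seg[-1], ws[i]))
            evals += 1
        # seg[0] is the checkpoint itself, so don't count it twice
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        for i in reversed(range(start, end)):
            grads[i], g = layer_backward(seg[i - start], seg[i - start + 1], ws[i], g)
        del checkpoints[start]  # this checkpoint is no longer needed
    return grads, peak, evals


ws = [0.5 + 0.05 * i for i in range(16)]  # L = 16, so sqrt(L) = 4
full_grads, full_stored = backprop_full(0.3, ws)
ck_grads, ck_peak, ck_evals = backprop_checkpointed(0.3, ws, k=4)
print(full_stored, ck_peak, ck_evals)  # 17 8 32
```

Peak storage drops from 17 values to 8 (4 checkpoints plus one 4-layer segment), while layer evaluations double from 16 to 32. If the backward pass costs about two forward passes, that is 4 units of work instead of 3, the ~1.33x quoted above.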
In practice, production LLM training usually checkpoints at transformer-block boundaries. Tuning which blocks to checkpoint is a recurring optimisation in pretraining engineering work.
Trade-offs
- More compute, less memory. A common rule of thumb: checkpointing buys 2-3x larger batch sizes at a ~30% wall-clock penalty.
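- Selective checkpointing. When attention materialises the n x n QK^T score matrix, its activation memory grows quadratically with sequence length and can be much larger than the MLP's at long sequences. Those scores are large to store but relatively cheap to recompute, so checkpointing only attention often gets most of the memory win for a fraction of the recompute cost.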
- Activation offloading. An alternative: copy activations to CPU memory during the forward pass and copy them back during the backward pass. This avoids recomputation at the cost of more PCIe traffic, so it suits cases where GPU memory is the binding constraint but host-device bandwidth is not.
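A back-of-the-envelope count of why attention dominates at long sequences. It assumes a naive attention implementation that stores the full n x n score matrix for every head, and an MLP whose hidden layer is 4x the model width. Both are illustrative assumptions, and the shapes (64 heads, width 8192) are chosen only for the example. Fused attention kernels that never materialise the scores change this picture.

```python
def attention_score_elems(seq_len, heads):
    # Assumed naive attention: one n x n score matrix per head, per sequence, per layer
    return heads * seq_len * seq_len


def mlp_hidden_elems(seq_len, d_model, expansion=4):
    # Assumed MLP: one (expansion * d_model)-wide hidden vector per token, per layer
    return seq_len * expansion * d_model


for n in (1024, 4096, 16384):
    ratio = attention_score_elems(n, heads=64) / mlp_hidden_elems(n, d_model=8192)
    print(n, ratio)  # ratio grows linearly with n: 2.0, 8.0, 32.0
```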
Further reading
- PyTorch Autograd Mechanics - the official explanation of PyTorch's reverse-mode engine.
- JAX Autodiff Cookbook - composable transformations, JVPs and VJPs.
- Training Deep Nets with Sublinear Memory Cost - Chen et al. (2016), the gradient checkpointing paper.