
Deep Learning

Backpropagation and Automatic Differentiation

How reverse-mode autodiff turns the chain rule into an efficient gradient algorithm, and the design choices PyTorch and JAX make to implement it at scale.

intermediate · 8 min read

Backpropagation is the chain rule applied to a computation graph in a specific order: reverse topological, with intermediate values cached from the forward pass. The autodiff at the core of PyTorch, JAX and TensorFlow (and of libraries built on them, such as Flax on JAX) is reverse-mode. Understanding the algorithm tells you why activation memory often dominates GPU memory during training and why gradient checkpointing trades compute for memory.

The chain rule, mechanically

Given y = f(g(h(x))), the gradient is the product of local Jacobians evaluated at intermediate values:

dy/dx = (df/dg) * (dg/dh) * (dh/dx)

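For a scalar-valued loss L and vector-valued intermediates, each local derivative is a Jacobian matrix, so the chain rule becomes a product of matrices. Matrix multiplication is associative, so every ordering gives the same result, but not at the same cost. The question is which order you multiply them in.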
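A small cost model makes the ordering question concrete. The sketch below counts multiply-adds for forming the full Jacobian of a chain of dense steps, starting either from the input side or from the output side. The helper name `chain_cost` and the layer sizes are illustrative assumptions, not anything from a library.

```python
# Toy cost model; `chain_cost` and the sizes below are hypothetical, for illustration only.

def chain_cost(dims, order):
    """Multiply-adds to form the full Jacobian of a chain of K = len(dims) - 1 steps.

    dims[0] is the input size n, dims[-1] the output size m, and step k (k = 1..K)
    has a Jacobian J_k of shape (dims[k], dims[k-1]). A (p x q) @ (q x r) product
    costs p * q * r multiply-adds.
    """
    K = len(dims) - 1
    total = 0
    if order == "forward":
        # Start at the input: A = J_1, then A = J_k @ A for k = 2..K.
        # Before step k, A has shape (dims[k-1], dims[0]).
        for k in range(2, K + 1):
            total += dims[k] * dims[k - 1] * dims[0]
    elif order == "reverse":
        # Start at the output: B = J_K, then B = B @ J_k for k = K-1..1.
        # Before step k, B has shape (dims[K], dims[k]).
        for k in range(K - 1, 0, -1):
            total += dims[K] * dims[k] * dims[k - 1]
    else:
        raise ValueError(f"unknown order: {order}")
    return total

# 1000 inputs, two hidden sizes of 1000, one scalar output.
dims = [1000, 1000, 1000, 1]
print(chain_cost(dims, "forward"))   # 1001000000
print(chain_cost(dims, "reverse"))   # 2000000
```

With a scalar output, multiplying from the output side keeps every intermediate product a single row, which is where the roughly 500x gap in this example comes from.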

Forward vs reverse mode

Two orderings of the same chain rule:

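  • Forward mode propagates derivatives alongside values, from the inputs toward the output. Each pass computes one Jacobian-vector product (the derivative along one input direction) for roughly the cost of one extra evaluation. For a function R^n -> R^m, computing the full Jacobian takes n forward passes.
  • Reverse mode evaluates the function first (the forward pass), then propagates derivatives from the output back toward the inputs. Each backward pass computes one vector-Jacobian product (the gradient of one output with respect to every input) for a small constant multiple of the forward cost, but it needs the intermediate values from the forward pass. For R^n -> R^m, the full Jacobian takes m backward passes.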

ML loss functions are R^n -> R, where n is the parameter count (billions, for large models) and the output is a single scalar. Reverse mode gets all n partial derivatives in a single backward pass; forward mode would need n passes. That is why neural network training uses reverse mode.
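Forward mode is easy to sketch with dual numbers, one common way to implement it. Each evaluation carries a single tangent direction, so the full gradient of an R^n -> R function needs n evaluations. The `Dual` class, the toy loss `f`, and `forward_mode_grad` are hypothetical names for this illustration.

```python
# Toy forward-mode sketch; all names are hypothetical, not a library API.

class Dual:
    """A number val + tan*eps with eps**2 = 0; `tan` carries the derivative."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a' eps)(b + b' eps) = ab + (a'b + ab') eps
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)
    __rmul__ = __mul__

def f(xs):
    # Toy loss R^n -> R: sum_i x_i^2 + x_0 * x_1
    total = 0.0
    for x in xs:
        total = total + x * x
    return total + xs[0] * xs[1]

def forward_mode_grad(f, xs):
    """Full gradient by forward mode: one evaluation per input, seeding one tangent."""
    grad = []
    for i in range(len(xs)):
        seeded = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grad.append(f(seeded).tan)   # directional derivative along basis vector e_i
    return grad

print(forward_mode_grad(f, [1.0, 2.0, 3.0]))   # [4.0, 5.0, 6.0]
```

Three inputs cost three evaluations here; with billions of parameters the same loop would need billions of passes, while one reverse pass gets the whole gradient.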

Forward mode is still useful in a few places:

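  • Hessian-vector products via jvp(grad(f)) (forward-over-reverse): one forward-mode pass over a reverse-mode gradient computation gives H v for a small constant multiple of the cost of one gradient, instead of computing the full n x n Hessian.
  • Differentiating through ODE solvers and physics simulations with few parameters but many outputs (for example, a whole trajectory); forward mode also avoids storing the intermediate states that a reverse pass would need.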
  • Sensitivity analysis when you have a few inputs and many outputs.
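Forward-over-reverse can be sketched in plain Python by running a tiny reverse-mode engine on dual numbers: the reverse pass computes the gradient, and the dual parts carried through it give H v. This is a toy under stated assumptions (`Dual`, `Var`, `grad`, and `hvp` are hypothetical names, and only + and * are supported), not how JAX implements `jvp(grad(f))`.

```python
# Toy forward-over-reverse sketch; all names are hypothetical, not a library API.

class Dual:
    """val + tan*eps with eps**2 = 0: the forward-mode layer."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)
    __rmul__ = __mul__


class Var:
    """Reverse-mode node: a value plus edges (parent, local derivative)."""
    def __init__(self, value, parents=()):
        self.value, self.parents, self.grad = value, parents, 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))
    __rmul__ = __mul__


def grad(f):
    """Return a function that computes the gradient of f with one reverse pass."""
    def grad_f(xs):
        inputs = [Var(x) for x in xs]
        out = f(inputs)
        order, seen = [], set()
        def visit(v):                        # depth-first topological sort
            if v not in seen:
                seen.add(v)
                for parent, _ in v.parents:
                    visit(parent)
                order.append(v)
        visit(out)
        out.grad = 1.0
        for v in reversed(order):            # every consumer runs before its inputs
            for parent, local in v.parents:
                parent.grad = parent.grad + v.grad * local
        return [v.grad for v in inputs]
    return grad_f


def hvp(f, xs, vs):
    """H(xs) @ vs: run the reverse pass on dual numbers seeded with direction vs."""
    duals = [Dual(x, v) for x, v in zip(xs, vs)]
    return [g.tan if isinstance(g, Dual) else 0.0 for g in grad(f)(duals)]


def f(x):
    # f(x0, x1) = x0^2 * x1 + x1^3, with Hessian [[2*x1, 2*x0], [2*x0, 6*x1]]
    return x[0] * x[0] * x[1] + x[1] * x[1] * x[1]

print(hvp(f, [1.0, 2.0], [1.0, 0.0]))   # [4.0, 2.0]
print(hvp(f, [1.0, 2.0], [0.0, 1.0]))   # [2.0, 12.0]
```

The Hessian itself is never built; each call costs one reverse pass carried out in dual arithmetic.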

The backprop algorithm

For a feed-forward network with layers f_1, f_2, ..., f_L:

# Forward pass - cache activations
a_0 = x
for l in 1..L:
    a_l = f_l(a_{l-1}, theta_l)        # a_{l-1} is kept for the backward pass
loss = loss_fn(a_L, y)

# Backward pass - apply chain rule in reverse
g = dloss/da_L
for l in L down to 1:
    dtheta_l = g * df_l/dtheta_l(a_{l-1}, theta_l)   # vector-Jacobian product
    g = g * df_l/da_{l-1}(a_{l-1}, theta_l)          # gradient handed down to layer l-1
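The pseudocode above can be made concrete for a toy network of scalar layers a_l = tanh(w_l * a_{l-1}) with loss 0.5 * (a_L - y)^2. This is a sketch under those assumptions: the layer form, the loss, and the function names are illustrative choices, and a finite-difference check confirms the gradients.

```python
import math

# Toy scalar network; the layer form, loss and helper names are hypothetical.

def forward(x, weights):
    """Forward pass: a_l = tanh(w_l * a_{l-1}). Returns the cached activations a_0..a_L."""
    acts = [x]
    for w in weights:
        acts.append(math.tanh(w * acts[-1]))
    return acts

def backward(acts, weights, y):
    """Backward pass for loss = 0.5 * (a_L - y)**2, walking the layers in reverse."""
    g = acts[-1] - y                      # dloss/da_L
    grads = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        a_prev, a_out = acts[i], acts[i + 1]
        local = 1.0 - a_out ** 2          # tanh'(z) = 1 - tanh(z)**2, from the cached output
        grads[i] = g * local * a_prev     # dloss/dw for this layer
        g = g * local * weights[i]        # dloss/d(input of this layer), handed down
    return grads

def loss(x, weights, y):
    return 0.5 * (forward(x, weights)[-1] - y) ** 2

def numeric_grad(x, weights, y, i, eps=1e-6):
    """Central finite difference in weight i, used only as a check."""
    up, down = list(weights), list(weights)
    up[i] += eps
    down[i] -= eps
    return (loss(x, up, y) - loss(x, down, y)) / (2 * eps)

x, y, weights = 0.5, 0.2, [1.5, -0.8, 2.0]
grads = backward(forward(x, weights), weights, y)
for i, g in enumerate(grads):
    print(i, g, numeric_grad(x, weights, y, i))   # the two columns should agree closely
```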

Two things to notice:

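  1. The activations a_l from the forward pass must be retained until the backward pass uses them. Activation memory grows with depth, batch size, sequence length and hidden width, and for large models it often dominates GPU memory during training.
  2. Each backward step costs roughly twice its forward step. For a linear layer, the forward pass is one matrix multiply, while the backward pass needs two: one for the gradient with respect to the input (the new g) and one for the gradient with respect to the weights (dtheta_l). The standard rule of thumb: the backward pass is 2x the forward pass FLOPs, so a full training step is about 3x a forward pass.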
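The 2x figure can be checked by counting multiply-adds for one linear layer Y = X @ W. This sketch uses naive pure-Python matrix products with arbitrary small sizes; `matmul` and `transpose` are hypothetical helpers, not a library API.

```python
# Hypothetical helpers for illustration; counts multiply-adds of a naive matmul.

def matmul(A, B):
    """Naive (p x q) @ (q x r) product; returns the result and its multiply-add count."""
    p, q, r = len(A), len(B), len(B[0])
    C = [[sum(A[i][k] * B[k][j] for k in range(q)) for j in range(r)] for i in range(p)]
    return C, p * q * r

def transpose(M):
    return [list(row) for row in zip(*M)]

b, d_in, d_out = 4, 3, 5
X = [[float(i + j) for j in range(d_in)] for i in range(b)]         # layer input, b x d_in
W = [[0.1 * (i - j) for j in range(d_out)] for i in range(d_in)]    # weights, d_in x d_out

Y, fwd = matmul(X, W)                          # forward: Y = X @ W

dY = [[1.0] * d_out for _ in range(b)]         # upstream gradient dL/dY (all ones for the demo)
dX, cost_dx = matmul(dY, transpose(W))         # dL/dX = dY @ W^T, passed to the layer below
dW, cost_dw = matmul(transpose(X), dY)         # dL/dW = X^T @ dY, the weight gradient
bwd = cost_dx + cost_dw

print(fwd, bwd, bwd / fwd)                     # 60 120 2.0
```

Each of the three products has the same b * d_in * d_out cost, so the ratio is exactly 2 regardless of the sizes chosen.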

How PyTorch and JAX implement this

The two dominant frameworks make different design choices.

PyTorch: define-by-run. The graph is built dynamically as operations execute: every operation on a tensor with requires_grad=True records a Function object (exposed on the result as grad_fn) that knows how to compute its local gradient. loss.backward() walks this graph in reverse topological order, calling each Function's backward method.

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
# x.grad is now tensor([2., 4.])

The graph is rebuilt every iteration. This is what makes PyTorch's debugging story good - you can put print statements and pdb breakpoints between operations.
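PyTorch's real machinery lives in C++, but the define-by-run idea fits in a few dozen lines of plain Python. In this toy, each operation records its inputs and a closure that pushes gradients back to them, and `backward()` runs those closures in reverse topological order. The `Value` class and its methods are hypothetical, scalar-only, and support just + and *.

```python
# Toy define-by-run engine; `Value` is hypothetical and not part of any library.

class Value:
    """A scalar that records how it was computed: a toy graph node."""
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # the Values this one was computed from
        self._backward = lambda: None    # pushes self.grad into the parents' .grad

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad        # d(a + b)/da = 1
            other.grad += out.grad       # d(a + b)/db = 1
        out._backward = backward_fn
        return out
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad   # d(a * b)/da = b
            other.grad += self.data * out.grad   # d(a * b)/db = a
        out._backward = backward_fn
        return out
    __rmul__ = __mul__

    def backward(self):
        # Sort the recorded graph topologically, then run each node's local
        # backward in reverse, so a node's grad is complete before it is used.
        order, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()


xs = [Value(1.0), Value(2.0)]
y = sum(x * x for x in xs)    # the graph is recorded as these operations execute
y.backward()
print([x.grad for x in xs])   # [2.0, 4.0]
```

Because the graph is just Python objects created as the code runs, control flow such as loops and conditionals needs no special handling: whatever executed is what gets differentiated.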

JAX: function transformations. Code is traced into a typed intermediate representation (jaxpr), then transformed. jax.grad(f) returns a new function that computes the gradient of f. There is no implicit graph state; everything is functional.

import jax
import jax.numpy as jnp
f = lambda x: jnp.sum(x ** 2)
grad_f = jax.grad(f)
# grad_f(jnp.array([1.0, 2.0])) returns [2.0, 4.0]

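JAX composes transformations: jit(vmap(grad(f))) produces a JIT-compiled function that computes per-example gradients over a batch. The order matters, because grad requires a scalar-valued function: grad(vmap(f)) fails when vmap(f) returns one value per example. The functional API has a steeper learning curve, but the composition story is cleaner than in PyTorch's traditional autograd API, which is one reason JAX is widely used in research codebases such as Google DeepMind's.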

The activation memory problem

For a transformer with L layers, hidden dim d, sequence length n, batch size b, activation memory is roughly:

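mem = L * b * n * d * (sizeof(dtype) * factor)

where factor is the number of d-sized values each layer saves per token for the backward pass (inputs to its matrix multiplies, normalisation and activation-function inputs, and so on). It depends on the architecture and implementation, and at long sequences the attention score matrices add a term that grows with n.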

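For Llama-2 70B (80 layers, hidden size 8192) with sequence 4096 and batch 4 in bf16, a single d-sized tensor saved per layer already costs 80 * 4 * 4096 * 8192 * 2 bytes = 20 GiB, and a transformer block saves many such tensors, so activations run well past 100 GB. That comes on top of about 140 GB for the bf16 weights alone, before gradients and optimiser state. You cannot just buy more memory: an 80 GB H100 holds only a fraction of this.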
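Plugging the Llama-2 70B shape into the formula shows the scale. This sketch treats `factor` as a free parameter; `activation_bytes` is a hypothetical helper, and the factor values in the loop are illustrative, not measurements of any real implementation.

```python
# Hypothetical helper evaluating the rough activation-memory formula.

def activation_bytes(layers, batch, seq_len, hidden, bytes_per_elem, factor):
    """mem = L * b * n * d * (sizeof(dtype) * factor), the rough formula from the text."""
    return layers * batch * seq_len * hidden * bytes_per_elem * factor

GiB = 2 ** 30

# Llama-2 70B shape (80 layers, hidden size 8192), sequence 4096, batch 4, bf16 = 2 bytes.
one_tensor = activation_bytes(80, 4, 4096, 8192, 2, factor=1)
print(one_tensor / GiB)   # 20.0: one d-sized tensor saved per layer

# Illustrative factors only; the real value depends on the model and implementation.
for factor in (5, 10, 20):
    print(factor, activation_bytes(80, 4, 4096, 8192, 2, factor) / GiB)
```

Because every term is a plain multiplier, halving the batch or the sequence length halves the total, which is why long-context training is where activation memory bites hardest.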

Gradient checkpointing

The technique, described for deep networks by Chen et al. (2016), trades compute for memory: in the forward pass, discard intermediate activations; in the backward pass, recompute them on demand by re-running the forward computation from the nearest saved activation.

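The simplest variant ("uniform checkpointing") keeps activations at every sqrt(L)-th layer, about sqrt(L) checkpoints in all, and recomputes the rest one segment at a time during the backward pass. Peak memory is the checkpoints plus one segment of about sqrt(L) layers, so it becomes O(sqrt(L)) instead of O(L). Total compute becomes ~1.33x the original: one extra forward pass on top of the usual forward (1x) plus backward (2x), i.e. 4/3.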
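Uniform checkpointing can be simulated on a toy chain of scalar layers a_l = tanh(w_l * a_{l-1}), an assumed model with hypothetical function names. The sketch stores only every k-th activation with k = isqrt(L), recomputes each segment during the backward pass, and compares the result with plain backprop while counting layer evaluations and stored values.

```python
import math

# Toy simulation; the layer form and all function names are hypothetical.

def layer(w, a):
    return math.tanh(w * a)

def grads_full(x, ws, y):
    """Plain backprop for loss 0.5 * (a_L - y)**2, storing all L + 1 activations."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g, grads = acts[-1] - y, [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        local = 1.0 - acts[i + 1] ** 2                 # tanh' from the cached output
        grads[i], g = g * local * acts[i], g * local * ws[i]
    return grads, len(acts)

def grads_checkpointed(x, ws, y):
    """Keep every k-th activation (k = isqrt(L)); recompute each segment in backward."""
    L = len(ws)
    k = max(1, math.isqrt(L))
    evals = 0
    checkpoints = {0: x}                               # layer index -> saved activation
    a = x
    for i, w in enumerate(ws):                         # forward, discarding most activations
        a = layer(w, a)
        evals += 1
        if (i + 1) % k == 0 and i + 1 < L:
            checkpoints[i + 1] = a
    g, grads, peak = a - y, [0.0] * L, 0
    for start in sorted(checkpoints, reverse=True):    # backward, one segment at a time
        end = min(start + k, L)
        seg = [checkpoints[start]]
        for i in range(start, end):                    # recompute from the checkpoint
            seg.append(layer(ws[i], seg[-1]))
            evals += 1
        peak = max(peak, len(checkpoints) + len(seg))  # checkpoints plus the live segment
        for i in reversed(range(start, end)):
            local = 1.0 - seg[i - start + 1] ** 2
            grads[i], g = g * local * seg[i - start], g * local * ws[i]
    return grads, evals, peak

ws = [0.9 + 0.01 * i for i in range(16)]              # 16 arbitrary layer weights
full, stored_full = grads_full(0.3, ws, 0.1)
ckpt, evals, peak = grads_checkpointed(0.3, ws, 0.1)
print(max(abs(p - q) for p, q in zip(full, ckpt)))    # the gradients match
print(stored_full, evals, peak)                       # 17 32 9
```

With 16 layers the forward work doubles (16 recomputed layer evaluations, i.e. one extra forward pass) while peak storage drops from 17 values to 9; the gap widens as L grows.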

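import torch
import torch.utils.checkpoint as cp
attn, ffn = torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)  # stand-ins for the attention and MLP sublayers
def block(x): return ffn(attn(x))
x = torch.randn(8, 64, requires_grad=True)
y = cp.checkpoint(block, x, use_reentrant=False)   # activations inside `block` are dropped, recomputed in backward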

In practice, production LLM training usually checkpoints at transformer-block boundaries. Tuning which blocks to checkpoint is a recurring optimisation in pretraining engineering work.

Trade-offs

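  • More compute, less memory. Full checkpointing costs about one extra forward pass (the ~1.33x above). A rough rule of thumb is that it buys 2-3x larger batch sizes at a ~30% wall-clock penalty, though the exact numbers depend on the model and hardware.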
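  • Selective checkpointing. At long sequences, attention's saved activations are much larger than the MLP's, because the QK^T score matrix (and its softmax) is n x n per head. These tensors are cheap to recompute relative to the large matrix multiplies, so checkpointing only the attention part often gets most of the memory win for much less extra compute than full checkpointing. Fused attention kernels such as FlashAttention avoid storing the n x n matrix in the first place.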
  • Activation offloading. An alternative: copy activations to CPU memory during the forward pass and copy them back during the backward pass. There is no recomputation, but PCIe traffic rises, so it helps when GPU memory is the binding constraint and host-device bandwidth is not.
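A back-of-the-envelope ratio shows why the score matrix dominates at long sequences. Per sequence and layer, the attention scores hold heads * n * n elements and a d-sized activation holds n * d, so one score tensor is worth heads * n / d activation tensors. The sketch assumes a Llama-2 70B-like shape (64 query heads, hidden size 8192) and one materialised score matrix per head; `scores_vs_hidden` is a hypothetical helper.

```python
# Hypothetical helper; assumes one n x n score matrix per head is materialised.

def scores_vs_hidden(heads, seq_len, hidden):
    """Elements in one layer's attention scores (heads * n * n) divided by the
    elements in one d-sized activation tensor (n * d). Simplifies to heads * n / d."""
    return (heads * seq_len * seq_len) / (seq_len * hidden)

for n in (512, 2048, 4096, 32768):
    print(n, scores_vs_hidden(64, n, 8192))   # 4.0, 16.0, 32.0, 256.0
```

The ratio grows linearly in n, so at 4096 tokens a single set of scores already outweighs dozens of hidden-size activations, which is what makes recomputing (or never storing) them attractive.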

Further reading