
Mathematical Foundations

Numerical Computation Gotchas

Catastrophic cancellation, the log-sum-exp trick, mixed-precision training, the determinism tax, and how to actually debug a NaN in a 70B model.

intermediate · 8 min read

Floating-point math is not math. Addition is not associative, multiplication does not distribute over addition, and operations that look benign occasionally return inf or NaN. Most modern LLMs are trained with BF16 compute and FP32 master weights, use the log-sum-exp trick in softmax, and keep a separate code path for "make this run deterministically (slowly)." If you do not know why, you cannot debug the day a checkpoint NaNs at step 47000 and the only clue is one suspicious gradient.

Floating-point basics

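Binary floating-point formats (IEEE 754, plus the IEEE-style BF16 and FP8 variants) represent a normal number as (-1)^sign * 1.mantissa * 2^(exponent - bias). The exponent bits set the range; the mantissa bits set the precision. Common formats: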

Format | Bits | Mantissa bits | Exponent bits | Max | Min normal | Used for
FP64 | 64 | 52 | 11 | ~1.8e308 | ~2.2e-308 | Scientific computing
FP32 | 32 | 23 | 8 | ~3.4e38 | ~1.2e-38 | Master weights, optimiser state
FP16 | 16 | 10 | 5 | 65504 | ~6.1e-5 | Legacy mixed precision
BF16 | 16 | 7 | 8 | ~3.4e38 | ~1.2e-38 | Modern LLM training
FP8 E4M3 (OCP) | 8 | 3 | 4 | 448 | ~1.6e-2 | Inference, some training
FP8 E5M2 (OCP) | 8 | 2 | 5 | 57344 | ~6.1e-5 | Inference, some training

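BF16 has the same dynamic range as FP32 because it keeps FP32's 8 exponent bits, but with only 7 mantissa bits it carries roughly 2-3 significant decimal digits (adjacent values near 1.0 are 2^-7 ≈ 0.0078 apart). Precision is awful, range is fine. FP16 has more precision (10 mantissa bits) but a tiny range: values below ~6e-5 become subnormal and anything much below ~6e-8 rounds to zero, so small gradients silently underflow during training. That was the original motivation for loss scaling.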
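NumPy's finfo is a quick way to check the IEEE rows of the table and to watch the FP16 underflow that loss scaling works around. This is a sketch assuming NumPy is installed. NumPy has no native BF16 or FP8 dtype, so those rows are not covered. The 1e-8 gradient and the 2^16 loss scale are illustrative values, not recommendations.

```python
import numpy as np

# Format parameters as NumPy reports them (no BF16/FP8: NumPy has no such dtype)
for dtype in (np.float16, np.float32, np.float64):
    info = np.finfo(dtype)
    print(f"{info.dtype}: mantissa={info.nmant} exponent={info.nexp} "
          f"max={float(info.max):.3g} min_normal={float(info.tiny):.3g}")

# A tiny gradient silently underflows to zero in FP16 ...
grad = 1e-8                                    # illustrative value
print(np.float16(grad) == 0)                   # True: below FP16's smallest subnormal
print(np.float32(grad) == 0)                   # False: FP32 has range to spare

# ... but survives if the loss (and hence every gradient) is scaled up first,
# then divided back out in FP32 before the weight update.
scale = 2.0 ** 16                              # illustrative loss scale
scaled_fp16 = np.float16(grad * scale)         # ~6.6e-4, a normal FP16 value
recovered = np.float32(scaled_fp16) / np.float32(scale)
print(recovered)                               # close to 1e-8
```

The scaled value lands in FP16's normal range, so it keeps about 11 bits of precision, and dividing by the scale in FP32 recovers the original gradient to within a small relative error.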

Catastrophic cancellation

Subtracting two nearly equal numbers obliterates precision. Classic example:

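import numpy as np
a = np.float32(1.0)
b = np.float32(1.0) + np.float32(1e-8)  # 1e-8 is below FP32's resolution near 1.0
b - a  # 0.0 in FP32, not 1e-8 (the same code in float64 gives ~1e-8)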

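Each input is good to about 7 significant digits, yet the difference has none. FP32 values near 1.0 are about 1.2e-7 apart, so the 1e-8 was rounded away when b was stored, and subtracting a exposes that rounding error as a 100% relative error. In ML this shows up in: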

  • Computing variance as E[X^2] - E[X]^2 when both terms are large and close. Use Welford's algorithm or the centred formula.
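  • Computing 1 - p_max when one logit dominates, so p_max is within rounding error of 1. The subtraction cancels no matter how carefully p was computed; instead sum the other probabilities directly, or compute -expm1(log_p_max) from a stable log_softmax.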
  • Computing small differences in losses across runs. Trust the loss curve direction, not the absolute fourth decimal.
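The variance case is easy to reproduce. The sketch below compares the one-pass E[X^2] - E[X]^2 formula with Welford's algorithm, both accumulated in FP32 via NumPy scalars (NumPy assumed; the function names and the four-value dataset are illustrative). The data sits near 1e4, so the squares sit near 1e8, where adjacent FP32 values are 8 apart and the true variance of 1.25 disappears in the subtraction.

```python
import numpy as np

def naive_variance(xs):
    """E[X^2] - E[X]^2 with FP32 accumulators: prone to catastrophic cancellation."""
    n = np.float32(len(xs))
    total = np.float32(0.0)
    total_sq = np.float32(0.0)
    for x in xs:
        total += x
        total_sq += x * x            # squares near 1e8: FP32 spacing is 8 here
    mean = total / n
    return total_sq / n - mean * mean

def welford_variance(xs):
    """Welford's one-pass update of the mean and M2 (sum of squared deviations)."""
    n = 0
    mean = np.float32(0.0)
    m2 = np.float32(0.0)
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / np.float32(n)
        m2 += delta * (x - mean)     # uses the updated mean
    return m2 / np.float32(n)        # population variance

xs = np.array([10001, 10002, 10003, 10004], dtype=np.float32)  # true variance 1.25
print(naive_variance(xs), welford_variance(xs))
```

Welford's update never forms the large squares; it only works with deviations from the running mean, which stay small. The naive formula is not even guaranteed to return a non-negative number.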

The log-sum-exp trick

Computing softmax naively:

softmax(x_i) = exp(x_i) / sum_j exp(x_j)

If any x_i exceeds about 88.7 (in FP32) or about 11 (in FP16), exp(x_i) overflows to inf, the sum becomes inf, and inf / inf turns the largest entries into NaN. If all x_j are very negative (below roughly -104 in FP32), every exp(x_j) underflows to 0 and you compute 0 / 0 = NaN.
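The overflow thresholds are just the natural log of each format's largest finite value. A small sketch (NumPy assumed) that derives them and shows FP16's exp overflowing between 11 and 12:

```python
import math
import numpy as np

# exp(x) overflows once x exceeds ln(largest finite value of the format)
overflow_limit = {
    dtype.__name__: math.log(float(np.finfo(dtype).max))
    for dtype in (np.float16, np.float32)
}
print({k: round(v, 2) for k, v in overflow_limit.items()})  # {'float16': 11.09, 'float32': 88.72}

with np.errstate(over="ignore"):
    print(np.exp(np.float16(11.0)))   # finite, just under FP16's max of 65504
    print(np.exp(np.float16(12.0)))   # inf: e^12 is about 1.6e5
```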

The fix: subtract the max before exponentiating.

m = max(x)
softmax(x_i) = exp(x_i - m) / sum_j exp(x_j - m)

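The result is mathematically identical (the factor exp(-m) cancels between numerator and denominator), but the largest exponent is now 0, so nothing overflows. The denominator contains at least one term exp(0) = 1, so it cannot underflow to zero. Individual numerators may still underflow, which only means those probabilities round to 0.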

The same idea extended:

logsumexp(x) = m + log(sum_j exp(x_j - m))
log_softmax(x_i) = x_i - logsumexp(x)

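Every serious framework ships a fused log_softmax kernel (for example torch.log_softmax in PyTorch, jax.nn.log_softmax in JAX). Use it. Never write log(softmax(x)) by hand: any probability that underflows to 0 becomes log(0) = -inf, and the loss and its gradients turn into inf or NaN.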
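The formulas above translate almost line for line into NumPy. This sketch handles a single 1-D vector of logits (NumPy assumed; real code should call the framework's fused kernels) and contrasts the naive softmax, the max-subtracted version, and log(softmax(x)) versus a direct log_softmax on deliberately extreme logits.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)                 # overflows to inf once x > ~709.8 in float64
    return e / e.sum()

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def stable_softmax(x):
    e = np.exp(x - np.max(x))     # largest term is exp(0) = 1
    return e / e.sum()

def log_softmax(x):
    return x - logsumexp(x)       # never takes the log of an underflowed 0

x = np.array([1000.0, 999.0, 0.0])
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    print(naive_softmax(x))           # NaN for the two large logits
    print(stable_softmax(x))          # ~[0.731, 0.269, 0.0]
    print(np.log(stable_softmax(x)))  # last entry is -inf: exp(-1000) underflowed
    print(log_softmax(x))             # ~[-0.313, -1.313, -1000.313], all finite
```

The last two lines are the reason to prefer a fused log_softmax: the stable softmax is correct, but taking its log throws away the -1000.313 that log_softmax keeps.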

Underflow in softmax and attention

Attention computes QK^T / sqrt(d_k) and then applies a softmax over each row. The pre-softmax scores can grow large for long sequences or high d_k. Several stability tricks help:

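  • Scaling by sqrt(d_k) keeps the variance of the scores from growing with dimension: if the components of q and k are independent with zero mean and unit variance, q · k has variance d_k, and dividing by sqrt(d_k) brings it back to 1.
  • Computing softmax in FP32 even when the rest of the layer is BF16. The sum of exponentials is a long reduction over the whole sequence: in FP16 it can overflow, and in BF16 (8 bits of precision) it loses small terms quickly.
  • FlashAttention computes attention in tiles and uses an online softmax that maintains a running max and sum, rescaling earlier partial results whenever the max grows, so it never materialises the full attention matrix. It gets the same stability as max subtraction while also saving memory.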
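The online softmax behind FlashAttention can be sketched for a single query row in NumPy. This illustrates the bookkeeping (running max, running sum, and a running weighted sum of values, all rescaled whenever the max increases); it is not the FlashAttention kernel, which does the same over tiles of many queries in on-chip memory. The function name and block size are illustrative.

```python
import numpy as np

def online_attention_row(scores, values, block=4):
    """Softmax-weighted sum of `values` for one query, one block at a time.

    scores: (n,) pre-softmax scores (already divided by sqrt(d_k))
    values: (n, d) value vectors
    """
    m = -np.inf                       # running max of scores seen so far
    s = 0.0                           # running sum of exp(score - m)
    acc = np.zeros(values.shape[1])   # running sum of exp(score - m) * value
    for start in range(0, len(scores), block):
        sc = scores[start:start + block]
        v = values[start:start + block]
        m_new = max(m, float(np.max(sc)))
        rescale = np.exp(m - m_new)   # shifts old terms to the new max (0.0 on the first block)
        p = np.exp(sc - m_new)        # every exponent is <= 0, so no overflow
        s = s * rescale + p.sum()
        acc = acc * rescale + p @ v
        m = m_new
    return acc / s                    # normalise once at the end

rng = np.random.default_rng(0)
scores = rng.standard_normal(10) + 1000.0   # np.exp(scores) alone would overflow
values = rng.standard_normal((10, 3))
print(online_attention_row(scores, values, block=4))
```

Because every block is rescaled to the current max, the result matches a max-subtracted softmax over the whole row, even though no block ever sees all the scores at once.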

For long-context inference (100k+ tokens), even FP32 softmax can lose precision in the tail. This is one of the reasons very long contexts sometimes degrade quality silently.

Why FP32 master weights + BF16 compute

Modern mixed-precision training keeps two copies of every weight:

  • A BF16 copy used for forward and backward passes (cheap compute, half the memory bandwidth).
  • An FP32 master copy that receives the gradient updates.

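Why two copies? Because a single optimiser step might change a weight by 1e-7 when the weight itself is 0.1. Adjacent BF16 values near 0.1 are 2^-11 ≈ 4.9e-4 apart, so 0.1 + 1e-7 rounds straight back to the same BF16 value and the update is lost. If every update is lost, the model stops learning. Adjacent FP32 values near 0.1 are about 7.5e-9 apart, so FP32 master weights accumulate the small updates faithfully, and the BF16 copy is refreshed from them each step. Loss scaling for FP16 training solves the related range problem rather than this precision problem: multiply the loss by a constant so that small gradients do not underflow in FP16, then divide the gradients by the same constant before the FP32 update.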
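The rounding argument can be checked without a GPU by emulating BF16 in NumPy. A BF16 value is the top 16 bits of an FP32 value, so rounding away the low 16 bits (round-half-to-even) gives the nearest BF16 number. The to_bf16 helper is a hypothetical emulation that ignores NaN handling; the 1e-7 update and 10,000 steps are illustrative.

```python
import numpy as np

def to_bf16(x):
    """Round one FP32 value to the nearest BF16 value, returned as FP32.

    Emulation only: keeps the top 16 bits with round-half-to-even; NaN is not handled.
    """
    bits = np.array([x], dtype=np.float32).view(np.uint32).astype(np.uint64)
    lsb = (bits >> 16) & 1                     # lowest bit that survives
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return rounded.astype(np.uint32).view(np.float32)[0]

update = np.float32(1e-7)        # illustrative per-step change
w_bf16_only = to_bf16(0.1)       # weight stored and updated only in BF16
w_master = np.float32(0.1)       # FP32 master weight

for _ in range(10_000):
    w_bf16_only = to_bf16(w_bf16_only + update)  # update rounded away every step
    w_master = w_master + update                 # small updates accumulate in FP32

w_compute = to_bf16(w_master)    # BF16 copy refreshed from the master weight
print(w_bf16_only, w_master, w_compute)
```

The BF16-only weight never moves. The FP32 master drifts by roughly 1e-3, and its BF16 copy eventually steps to the next representable BF16 value.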

The determinism / reproducibility tax

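Many GPU kernels are not deterministic by default. Floating-point addition is not associative, so any kernel whose reduction order can change between runs can change the result bits: kernels that accumulate with atomic adds (common in backward passes), cuDNN's benchmark-mode algorithm selection, and cuBLAS when it uses multiple streams or workspaces. Reduction order in cuBLAS also depends on tensor shapes and tile selection at kernel launch time, so changing batch size or sequence length changes results too. Two identical runs can produce loss values that differ by something like 1e-6 per step and significantly different checkpoints after many steps.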
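The root cause fits in three lines: in FP32, regrouping the same additions changes the answer. The sketch below (NumPy assumed, values illustrative) shows one exact case and then the same effect on a large random vector summed in two orders, which is what a kernel with a run-dependent reduction order does.

```python
import numpy as np

a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
left = (a + b) + c    # 0 + 1 = 1.0
right = a + (b + c)   # -1e8 + 1 rounds back to -1e8 (FP32 spacing there is 8), giving 0.0
print(left, right)    # 1.0 0.0

# Same numbers, two summation orders: the totals usually disagree in the last bits
rng = np.random.default_rng(0)
x = rng.standard_normal(100_000).astype(np.float32)
perm = rng.permutation(x.size)
print(np.sum(x), np.sum(x[perm]))
```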

Making training fully deterministic requires:

  • torch.use_deterministic_algorithms(True).
  • torch.backends.cudnn.deterministic = True, torch.backends.cudnn.benchmark = False.
  • Setting the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 (or :16:8) before cuBLAS initialises.
  • Pinning all RNG seeds (Python, NumPy, PyTorch, CUDA).
  • Avoiding atomic operations in custom kernels.
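Put together, the setup might look like the sketch below, assuming a recent PyTorch. make_deterministic is a hypothetical helper name. It does not cover DataLoader worker seeding or custom kernels, and it only helps if it runs before cuBLAS is first used in the process.

```python
import os
import random

import numpy as np
import torch

def make_deterministic(seed: int = 0) -> None:
    # cuBLAS reads this when it initialises, so set it before any CUDA work
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # seeds the CPU and CUDA generators
    torch.cuda.manual_seed_all(seed)           # explicit; ignored if no GPU is present
    torch.use_deterministic_algorithms(True)   # error on ops with no deterministic version
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False     # no per-shape algorithm autotuning

make_deterministic(123)
first = torch.rand(3)
make_deterministic(123)
second = torch.rand(3)
print(torch.equal(first, second))              # True
```

With use_deterministic_algorithms(True), PyTorch raises an error when an op without a deterministic implementation is called, which turns silent run-to-run drift into a visible failure.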

The cost is 10-30% throughput on common workloads. Most production training runs accept non-determinism because the speed difference compounds across months of training. Reproducibility comes from saving checkpoints and configs, not from bit-exact replay.

NaN debugging

A NaN in your loss is a forensic puzzle. The actual NaN-producing op is usually many layers before the symptom. A working procedure:

  1. Capture the step. Save the input batch, full model state, and optimiser state immediately when NaN is detected. NaNs are sticky - they propagate forward, so the corrupted state is itself a clue.
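  2. Walk forward. Re-run the captured batch with forward hooks that check each module's output with torch.isfinite; the first module that fails is where the NaN or inf first appears. For the backward pass, torch.autograd.detect_anomaly() raises as soon as a backward function returns NaN and prints the traceback of the forward op that created it.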
  3. Check the inputs to that op. Look for inf, 0 in denominators, negative numbers under sqrt or log, exp of large positive values.
  4. Check gradients. A NaN in forward often follows an inf in backward in the previous step. Gradient logging during training catches this earlier.
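Step 2 can be automated with PyTorch forward hooks. The sketch below (hypothetical helper find_first_nonfinite, assuming a recent PyTorch) runs one forward pass and reports the first module, in completion order, whose tensor output contains NaN or inf. The toy Log module exists only to manufacture a NaN, and modules that return tuples or dicts are skipped for brevity.

```python
import torch
import torch.nn as nn

def find_first_nonfinite(model, *inputs):
    """Name of the first module whose tensor output has NaN/inf, else None."""
    flagged = []

    def make_hook(name):
        def hook(module, args, output):
            # Only plain tensor outputs are checked in this sketch
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                flagged.append(name)
        return hook

    # Hooks fire when a module's forward returns, so children report before parents
    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if n]   # skip the unnamed root
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()
    return flagged[0] if flagged else None

class Log(nn.Module):
    """Toy layer: log of a negative input is NaN."""
    def forward(self, x):
        return torch.log(x)

model = nn.Sequential(nn.Linear(2, 2), Log(), nn.Linear(2, 1))
with torch.no_grad():
    model[0].weight.fill_(-1.0)
    model[0].bias.zero_()
print(find_first_nonfinite(model, torch.ones(1, 2)))    # 1 (the Log layer)
print(find_first_nonfinite(model, -torch.ones(1, 2)))   # None: all outputs finite
```

The NaN also appears in the final Linear layer's output, but that layer is reported second; the first flagged module is the one worth inspecting.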

Common culprits ranked by frequency:

Symptom | Likely cause
NaN in loss | log(0) from an underflowed softmax tail; division by a zero attention denominator
NaN in weights | Exploding gradient at the previous step; grad-clip threshold too high
inf in attention | Pre-softmax score or its exp overflowed (FP16 tops out at 65504); missing max subtraction / log-sum-exp
Loss spike then NaN | LR too high; LR warmup too short
Slow drift to NaN | Numerical accumulation (fix: run the offending layer in FP32)

Loss scaling, gradient clipping, FP32 softmax, and FP32 master weights remove most NaN sources. The remaining ones are usually data (a corrupt batch, a numeric outlier) or genuine architectural bugs.

Common pitfalls

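  • Comparing floating-point numbers for equality. Almost always wrong. Use a tolerance, preferably relative as well as absolute (for example math.isclose or numpy.isclose), rather than a bare abs(a - b) < tol.
  • Summing a long list of small floats by naive accumulation. The worst-case rounding error grows linearly with the list length, because once the running total is large each small addend is rounded against it. Use Kahan (compensated) summation or pairwise summation.
  • Assuming BF16 inference matches BF16 training. Training matmuls typically accumulate in FP32, while some inference kernels accumulate or round intermediates in BF16 or fuse operations in a different order. Outputs can drift.
  • Mixing float types in attention. Promote everything to FP32 inside softmax, then cast back. Forgetting this is a common source of "training was fine, why is inference broken" bugs.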
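The summation pitfall is easy to reproduce. Below is a sketch of naive FP32 accumulation next to Kahan (compensated) summation, using NumPy float32 scalars so every operation rounds to FP32. The 10,000 copies of 0.1 are an illustrative input, and math.fsum provides an accurate reference.

```python
import math
import numpy as np

def naive_sum(xs):
    total = np.float32(0.0)
    for x in xs:
        total += x                 # each add rounds against the growing total
    return total

def kahan_sum(xs):
    total = np.float32(0.0)
    comp = np.float32(0.0)         # running compensation for lost low-order bits
    for x in xs:
        y = x - comp               # re-inject what was lost on the previous add
        t = total + y
        comp = (t - total) - y     # the part of y that did not make it into t
        total = t
    return total

xs = np.full(10_000, 0.1, dtype=np.float32)
reference = math.fsum(float(x) for x in xs)   # accurately rounded sum in float64
print(reference, naive_sum(xs), kahan_sum(xs))
```

The naive loop drifts visibly from the reference, while the compensated sum stays within an FP32 rounding step of it. NumPy's own np.sum often uses partial pairwise summation for the same reason.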
