Numerical Computation Gotchas
Catastrophic cancellation, the log-sum-exp trick, mixed-precision training, the determinism tax, and how to actually debug a NaN in a 70B model.
Floating-point math is not math. It does not associate, does not distribute, and occasionally returns inf or NaN from operations that look benign. Most modern LLMs are trained in BF16 with FP32 master weights, use the log-sum-exp trick in softmax, and keep a separate code path for "make this run deterministically (slowly)." If you do not know why, you will be stuck the day a checkpoint NaNs at step 47000 and the only clue is one suspicious gradient.
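A quick way to see the "does not associate" claim, using plain Python floats (which are FP64). The values are illustrative; the same effects appear in FP32 and BF16, just much sooner.

```python
# Python floats are IEEE 754 binary64 (FP64).
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)   # 0.6000000000000001 0.6 False

# Absorption: a small term added to a large one can vanish entirely,
# so the order of a sum changes the answer.
# Near 1e16, adjacent FP64 values are 2.0 apart; 1.0 is exactly half that
# spacing, and the tie rounds to even, back to 1e16.
big_first = (1e16 + 1.0) - 1e16
small_first = 1.0 + (1e16 - 1e16)
print(big_first, small_first)       # 0.0 1.0
```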
Floating-point basics
IEEE 754 floats represent numbers as sign * mantissa * 2^exponent. Common formats:
| Format | Bits | Mantissa bits | Exponent bits | Max | Min normal | Used for |
|---|---|---|---|---|---|---|
| FP64 | 64 | 52 | 11 | ~1.8e308 | ~2.2e-308 | Scientific computing |
| FP32 | 32 | 23 | 8 | ~3.4e38 | ~1.2e-38 | Master weights, optimiser state |
| FP16 | 16 | 10 | 5 | 65504 | ~6e-5 | Legacy mixed precision |
| BF16 | 16 | 7 | 8 | ~3.4e38 | ~1.2e-38 | Modern LLM training |
| FP8 (E4M3 / E5M2) | 8 | 3 / 2 | 4 / 5 | 448 / 57344 | ~1.6e-2 / ~6.1e-5 | Inference, some training |
BF16 has the same dynamic range as FP32 (it keeps FP32's 8 exponent bits) but only 7 mantissa bits, roughly 2 to 3 significant decimal digits. Precision is coarse; range is fine. FP16 has more precision but a narrow range (max 65504, smallest normal ~6e-5), which causes silent underflow of small gradients during training - the original motivation for loss scaling.
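The range-versus-precision trade-off can be checked without a GPU. This is a minimal sketch, assuming we emulate the formats in pure Python: `to_bf16` and `to_fp16` are hypothetical helpers, BF16 is emulated by rounding the FP32 bit pattern to its top 16 bits, and FP16 uses the `struct` module's binary16 (`"e"`) format.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to BF16 via FP32: round-to-nearest-even on the low 16 bits
    of the FP32 bit pattern. NaN handling is omitted for brevity."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

def to_fp16(x: float) -> float:
    """Round x to FP16 using struct's IEEE 754 binary16 format.
    Raises OverflowError if x is beyond FP16's range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Range: 1e38 is still finite in BF16, far beyond FP16's max of 65504.
print(to_bf16(1e38))                        # about 9.97e37
try:
    to_fp16(1e38)
except OverflowError:
    print("FP16 overflow")

# Small values: 1e-8 survives in BF16 but rounds to zero in FP16.
print(to_bf16(1e-8), to_fp16(1e-8))

# Precision: BF16 cannot tell 1.0 from 1.0 + 2**-8; FP16 can.
print(to_bf16(1.0 + 2**-8), to_fp16(1.0 + 2**-8))   # 1.0 1.00390625
```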
Catastrophic cancellation
Subtracting two nearly equal numbers obliterates precision. Classic example:
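```python
import numpy as np

a = np.float32(1.0)
b = np.float32(1.0 + 1e-8)  # 1e-8 is below half of FP32's epsilon (~1.19e-7), so b rounds to 1.0
print(b - a)                # 0.0 in FP32, not 1e-8
```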
You started with two numbers good to 7 digits and ended with zero meaningful digits. In ML this shows up in:
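- Computing variance as E[X^2] - E[X]^2 when both terms are large and close. Use Welford's algorithm or the centred formula E[(X - E[X])^2].
- Computing 1 - softmax_max when one logit dominates: p_max rounds to 1.0 and the difference collapses to zero. Compute p stably first, and take the remaining mass as the sum of the other probabilities (or work in log space) rather than subtracting from 1.
- Computing small differences in losses across runs. Trust the direction of the loss curve, not the fourth decimal place.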
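A minimal sketch of the variance case, assuming plain Python floats (FP64) and population variance; `naive_variance` and `welford_variance` are illustrative names. With a mean of about 1e9, E[X^2] is about 1e18, where adjacent FP64 values are 128 apart, so the naive formula cannot resolve a variance of 22.5.

```python
def naive_variance(xs):
    """E[X^2] - E[X]^2: cancels catastrophically when the mean is large."""
    n = len(xs)
    mean = sum(xs) / n
    mean_sq = sum(x * x for x in xs) / n
    return mean_sq - mean * mean

def welford_variance(xs):
    """Welford's single-pass update of the mean and the sum of squared deviations."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)   # uses the updated mean
    return m2 / n                  # population variance

xs = [1e9 + v for v in (4.0, 7.0, 13.0, 16.0)]   # true population variance: 22.5
print(naive_variance(xs), welford_variance(xs))
```

Here Welford returns 22.5, while the naive result is forced onto a multiple of 128 and cannot be right.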
The log-sum-exp trick
Computing softmax naively:
softmax(x_i) = exp(x_i) / sum_j exp(x_j)
If any x_i exceeds ~88 (in FP32) or ~11 (in FP16), exp(x_i) overflows to inf and the output contains NaN (inf / inf). If all x_j are very negative, every exp(x_j) underflows to 0, the denominator becomes 0, and you divide by zero.
The fix: subtract the max before exponentiating.
m = max(x)
softmax(x_i) = exp(x_i - m) / sum_j exp(x_j - m)
The result is mathematically identical (the factor exp(-m) cancels between numerator and denominator), but the largest exponent is now 0, so nothing overflows. At least one term is exp(0) = 1, so the denominator is at least 1 and cannot underflow to zero.
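A minimal NumPy sketch of the max-subtraction fix, assuming FP32 inputs; `naive_softmax` and `stable_softmax` are illustrative names, and `np.errstate` only silences NumPy's overflow and invalid-value warnings so the NaNs are visible.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)                 # overflows to inf for large x
    return e / e.sum()

def stable_softmax(x):
    m = np.max(x)
    e = np.exp(x - m)             # largest term is exp(0) = 1
    return e / e.sum()            # denominator is at least 1

x = np.array([1000.0, 999.0, 0.0], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(x))       # nan, nan, 0.0 (inf / inf, then 1 / inf)
print(stable_softmax(x))          # about 0.731, 0.269, 0.0
```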
The same idea extended:
logsumexp(x) = m + log(sum_j exp(x_j - m))
log_softmax(x_i) = x_i - logsumexp(x)
Every serious framework has a fused log_softmax kernel. Use it. Never write log(softmax(x)) by hand.
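A minimal sketch of why `log(softmax(x))` fails where `log_softmax` does not, assuming FP32 NumPy arrays; `logsumexp` and `log_softmax` here are illustrative re-implementations, not the fused framework kernels (in practice use something like `torch.nn.functional.log_softmax` or `scipy.special.logsumexp`).

```python
import numpy as np

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def log_softmax(x):
    return x - logsumexp(x)

x = np.array([0.0, -200.0], dtype=np.float32)

with np.errstate(divide="ignore"):
    # exp(-200) underflows to 0 in FP32, and log(0) is -inf.
    naive = np.log(np.exp(x) / np.sum(np.exp(x)))
print(naive)           # second entry is -inf
print(log_softmax(x))  # second entry is -200.0, the exact answer
```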
Underflow in softmax and attention
Attention computes QK^T / sqrt(d_k) and then applies softmax. The pre-softmax scores can grow large for long sequences or high d_k. Several stability tricks help:
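- Scaling by sqrt(d_k) keeps the variance of the dot products from growing with dimension: if the components of q and k are independent with zero mean and unit variance, q · k has variance d_k, and dividing by sqrt(d_k) brings it back to 1.
- Computing softmax in FP32 even when the rest of the layer is BF16. The intermediate sum of exponentials is the most overflow-prone op in the model.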
- FlashAttention computes attention in tiles and uses an online softmax that maintains running max and sum without ever materialising the full attention matrix. The numerical care matches the memory care.
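A simplified, single-query sketch of the online-softmax bookkeeping in plain Python. It assumes scalar values and a hypothetical `tile_size`; it is not FlashAttention's kernel (which works on blocks of queries, keys, and value vectors in on-chip memory), only the running max / running sum idea that keeps it stable without materialising the full weight vector.

```python
import math

def online_attention_row(scores, values, tile_size=2):
    """One query row of attention, processed tile by tile. Keeps a running
    max m, a running normaliser s = sum exp(score - m), and a running
    weighted sum acc, rescaling both whenever m increases."""
    m, s, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile_size):
        xs = scores[start:start + tile_size]
        vs = values[start:start + tile_size]
        new_m = max(m, max(xs))
        scale = math.exp(m - new_m)          # exp(-inf) = 0.0 on the first tile
        s = s * scale + sum(math.exp(x - new_m) for x in xs)
        acc = acc * scale + sum(math.exp(x - new_m) * v for x, v in zip(xs, vs))
        m = new_m
    return acc / s

def reference_attention_row(scores, values):
    """Full-vector stable softmax, then the weighted sum, for comparison."""
    m = max(scores)
    w = [math.exp(x - m) for x in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)

scores = [3.0, 1.0, 1000.0, -5.0, 2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(online_attention_row(scores, values), reference_attention_row(scores, values))
```

Both paths agree even with a score of 1000, because every exponent is taken relative to the largest score seen so far.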
For long-context inference (100k+ tokens), even FP32 softmax can lose precision in the tail. This is one of the reasons very long contexts sometimes degrade quality silently.
Why FP32 master weights + BF16 compute
Modern mixed-precision training keeps two copies of every weight:
- A BF16 copy used for forward and backward passes (cheap compute, half the memory bandwidth).
- An FP32 master copy that receives the gradient updates.
Why two copies? Because a single optimiser step might change a weight of 0.1 by only 1e-7. BF16 has 7 mantissa bits, so adjacent representable values near 0.1 are about 5e-4 apart (2^-11), and 0.1 + 1e-7 rounds straight back to 0.1: the update is lost, and the model would stop learning. FP32 values near 0.1 are about 7e-9 apart, so the FP32 master copy keeps these small updates; the BF16 copy is refreshed from it each step. Loss scaling for FP16 training tackles a related problem: multiply the loss by a constant so gradients are large enough not to underflow FP16's narrow range, then divide the gradients by the same constant before the update.
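A small simulation of this argument, assuming a hypothetical update of 1e-7 per step on a weight of 0.1, with FP32 and BF16 rounding emulated in pure Python via `struct` (`to_fp32` and `to_bf16` are hypothetical helpers; `to_bf16` uses round-to-nearest-even and ignores NaN). It compares updating a BF16 weight in place with updating an FP32 master and casting to BF16 afterwards.

```python
import struct

def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value (Python floats are FP64)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to BF16 via FP32: round-to-nearest-even on the low 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

update = 1e-7      # hypothetical per-step change from the optimiser
steps = 10_000

# BF16 only: the weight is stored in BF16 and updated in place.
w_bf16 = to_bf16(0.1)
for _ in range(steps):
    w_bf16 = to_bf16(w_bf16 + update)   # update is below half a BF16 spacing near 0.1

# Mixed precision: FP32 master weight, BF16 copy made for compute.
master = to_fp32(0.1)
for _ in range(steps):
    master = to_fp32(master + update)   # each add is still rounded to FP32
compute_copy = to_bf16(master)

print(w_bf16, master, compute_copy)
```

The BF16-only weight never moves. The FP32 master ends close to 0.101 (slightly short, since each add is rounded to FP32), and its BF16 copy has moved off the starting value.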
The determinism / reproducibility tax
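GPU training is not deterministic by default. Reduction order depends on which kernel and tiling cuBLAS or cuDNN selects (cuDNN autotuning can pick a different algorithm on each run), and atomic adds in kernels such as scatter-style backward passes complete in whatever order threads arrive. Because floating-point addition does not associate, each order gives slightly different bits. Two identical runs can produce loss values that differ by around 1e-6 per step, and those differences compound into significantly different checkpoints after many steps.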
Making training fully deterministic requires:
- `torch.use_deterministic_algorithms(True)`.
- `torch.backends.cudnn.deterministic = True`, `torch.backends.cudnn.benchmark = False`.
- Setting `CUBLAS_WORKSPACE_CONFIG` (for example `:4096:8`).
- Pinning all RNG seeds (Python, NumPy, PyTorch, CUDA).
- Avoiding atomic operations in custom kernels.
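The checklist above can be collected into one setup function. This is a minimal sketch, assuming PyTorch and NumPy are installed; `make_deterministic` is a hypothetical helper name. Call it at the very top of the script, before any CUDA work, because `CUBLAS_WORKSPACE_CONFIG` is read when cuBLAS initialises. It does not cover atomics in custom kernels or seeding of DataLoader workers.

```python
import os
import random

import numpy as np
import torch

def make_deterministic(seed: int = 0) -> None:
    # Must be set before the first cuBLAS call; ":4096:8" and ":16:8" are
    # the values listed in PyTorch's reproducibility notes.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Seed every RNG the training loop might touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds the CPU and all CUDA devices
    torch.cuda.manual_seed_all(seed)   # explicit; silently ignored without CUDA

    # Raise an error on ops that have no deterministic implementation.
    torch.use_deterministic_algorithms(True)

    # Stop cuDNN from autotuning, which can pick different algorithms per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

make_deterministic(seed=1234)
```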
The cost is 10-30% throughput on common workloads. Most production training runs accept non-determinism because the speed difference compounds across months of training. Reproducibility comes from saving checkpoints and configs, not from bit-exact replay.
NaN debugging
A NaN in your loss is a forensic puzzle. The actual NaN-producing op is usually many layers before the symptom. A working procedure:
- Capture the step. Save the input batch, full model state, and optimiser state as soon as a NaN is detected. NaNs are sticky - they propagate forward, so the corrupted state is itself a clue.
- Walk forward. Re-run the forward pass on the saved batch and check each layer's output for NaN or inf; the first layer that produces one is your suspect. `torch.autograd.detect_anomaly()` covers the backward pass: it raises an error at the first backward op that returns NaN and prints the traceback of the forward op that created it.
- Check the inputs to that op. Look for inf, 0 in denominators, negative numbers under `sqrt` or `log`, and `exp` of large positive values.
- Check gradients. A NaN in forward often follows an inf in backward in the previous step. Gradient logging during training catches this earlier.
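For the walk-forward step, one option in PyTorch is a forward hook on every module that records which outputs are non-finite. This is a sketch of that technique, not a built-in PyTorch NaN finder: `install_nonfinite_hooks` and the toy `Log` layer are hypothetical, while `named_modules`, `register_forward_hook`, and `torch.isfinite` are standard PyTorch calls.

```python
import torch
import torch.nn as nn

def install_nonfinite_hooks(model: nn.Module):
    """Attach a forward hook to every module that records the module's name
    when its output contains NaN or inf. Hooks fire when a module finishes,
    so children are recorded before their parents and hits[0] is the
    earliest offending module in execution order."""
    hits, handles = [], []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name or "<root>"):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outputs:
                if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                    hits.append(name)
                    break
        handles.append(module.register_forward_hook(hook))
    return hits, handles

class Log(nn.Module):
    """Hypothetical layer standing in for an unguarded log."""
    def forward(self, x):
        return torch.log(x)   # NaN for negative inputs

model = nn.Sequential(nn.Identity(), Log(), nn.Linear(4, 2))
hits, handles = install_nonfinite_hooks(model)
with torch.no_grad():
    model(-torch.ones(1, 4))
for h in handles:
    h.remove()                # detach the hooks once the culprit is found
print(hits)                   # ['1', '2', '<root>']: the Log layer is first
```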
Common culprits ranked by frequency:
| Symptom | Likely cause |
|---|---|
| NaN in loss | log(0) from softmax tail; zero or NaN attention denominator (for example, a fully masked row) |
| NaN in weights | Exploding gradient at previous step; grad clip threshold too high |
| inf in attention | Pre-softmax score overflowed BF16; missing log-sum-exp |
| Loss spike then NaN | LR too high; LR warmup too short |
| Slow drift to NaN | Numerical accumulation; switch to FP32 in specific layer |
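A common way to hit the attention-denominator problem is a fully masked row (for example, a query position that may attend to nothing): the row max is -inf, and -inf - (-inf) is NaN. A NumPy sketch, assuming a boolean mask where True means "may attend"; `masked_softmax` and `safe_masked_softmax` are illustrative names, and returning zeros for fully masked rows is one common convention, not the only one.

```python
import numpy as np

def masked_softmax(scores, mask):
    x = np.where(mask, scores, -np.inf)
    m = x.max(axis=-1, keepdims=True)         # -inf for a fully masked row
    e = np.exp(x - m)                         # -inf - (-inf) = NaN
    return e / e.sum(axis=-1, keepdims=True)

def safe_masked_softmax(scores, mask):
    x = np.where(mask, scores, -np.inf)
    m = x.max(axis=-1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)      # fully masked row: shift by 0 instead
    e = np.exp(x - m)                         # masked entries become exp(-inf) = 0
    denom = e.sum(axis=-1, keepdims=True)
    # Divide only where the denominator is positive; leave zeros elsewhere.
    return np.divide(e, denom, out=np.zeros_like(e), where=denom > 0)

scores = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
mask = np.array([[True, True, False], [False, False, False]])

with np.errstate(invalid="ignore"):
    print(masked_softmax(scores, mask))       # second row is all NaN
print(safe_masked_softmax(scores, mask))      # second row is all zeros
```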
Loss scaling, gradient clipping, FP32 softmax, and FP32 master weights remove most NaN sources. The remaining ones are usually data (a corrupt batch, a numeric outlier) or genuine architectural bugs.
Common pitfalls
- Comparing floating-point numbers for equality. Almost always wrong. Use a tolerance: `abs(a - b) < tol`, or better a combined relative and absolute check such as `math.isclose` or `torch.allclose`.
- Summing a long list of small floats by naive accumulation. The worst-case rounding error grows linearly with list length. Use Kahan summation or pairwise summation.
- Assuming BF16 inference matches BF16 training. Training uses FP32 accumulation inside matmul; some inference engines accumulate in BF16. Outputs can drift.
- Mixing float types in attention. Promote everything to FP32 inside softmax, then cast back. Forgetting this is a top-five source of "training was fine, why is inference broken" bugs.
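A minimal sketch of Kahan summation with NumPy FP32 scalars, assuming one large value followed by many tiny ones (a hypothetical input chosen to make the loss obvious); `naive_sum` and `kahan_sum` are illustrative names. For Python floats, `math.fsum` gives an accurately rounded sum, and `np.sum` often uses pairwise summation already.

```python
import numpy as np

def naive_sum(xs):
    total = np.float32(0.0)
    for x in xs:
        total = total + x          # each add rounds to FP32
    return total

def kahan_sum(xs):
    total = np.float32(0.0)
    c = np.float32(0.0)            # running compensation for lost low-order bits
    for x in xs:
        y = x - c                  # re-inject what the previous add lost
        t = total + y
        c = (t - total) - y        # negative of what this add rounded away
        total = t
    return total

# 1e-8 is below half of FP32's spacing at 1.0, so naive adds are absorbed.
xs = [np.float32(1.0)] + [np.float32(1e-8)] * 10_000   # exact sum is about 1.0001
print(naive_sum(xs), kahan_sum(xs))   # naive stays at 1.0; Kahan lands close to 1.0001
```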