
Training Infrastructure

Mixed-Precision Training (FP16, BF16, FP8)

How lower-precision formats halve memory and double throughput on tensor cores, why BF16 displaced FP16 for training, and what FP8 changes on H100 and Blackwell.

intermediate · 8 min read

Tensor cores on NVIDIA GPUs from Volta onward run 16-bit matmuls at several times the rate of plain FP32. On Ampere and newer, 16-bit throughput is roughly 2x that of TF32 (the 32-bit tensor-core mode), and FP8 on Hopper and Blackwell doubles it again, to roughly 4x. Memory bandwidth scales the same way: a 16-bit tensor is half the bytes of an FP32 one to move. Mixed-precision training is therefore effectively mandatory for any modern run; the only questions are which 16-bit format to use and whether you can risk pushing to 8-bit.

FP16 vs BF16: range vs precision

Both formats are 16 bits. They split the budget differently.

Format  Sign bits  Exponent bits  Mantissa bits  Dynamic range (normal)  Smallest normal
FP32    1          8              23             ~1e-38 to ~3e38         ~1.2e-38
FP16    1          5              10             ~6e-5 to 65504          ~6.1e-5
BF16    1          8              7              ~1e-38 to ~3e38         ~1.2e-38

FP16 has more mantissa bits (better precision) but a much narrower range. BF16 keeps FP32's exponent range at the cost of mantissa bits.

For training, range matters far more than precision. Gradients during early training can swing across many orders of magnitude. FP16's smallest normal value is about 6.1e-5; below that, values become subnormal and lose precision, and anything smaller than about 6e-8 flushes to zero. BF16 keeps normal values all the way down to about 1.2e-38, the same as FP32.
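The range gap is easy to check numerically. The sketch below uses only Python's standard library: `struct`'s `"e"` format packs IEEE FP16 with round-to-nearest-even, and BF16 is emulated by rounding an FP32 bit pattern to its top 16 bits. `to_fp16` and `to_bf16` are hypothetical helpers written for this illustration, and `to_bf16` assumes finite inputs within FP32 range.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest FP16 value (subnormals included)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct rejects magnitudes that round above 65504; a hardware cast gives inf.
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round a finite x to BF16: keep the top 16 bits of its FP32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped bits
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


for grad in (1e-3, 1e-5, 1e-7, 1e-8):
    print(f"{grad:.0e}: fp16={to_fp16(grad):.3e}  bf16={to_bf16(grad):.3e}")
```

A gradient of 1e-8 comes back from `to_fp16` as exactly 0.0, while BF16 keeps it to within half a percent; 1e-7 survives in FP16 only as a subnormal, roughly 19% off.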

Loss scaling: the FP16 workaround

To use FP16 without losing tiny gradients, multiply the loss by a large constant S (typically 2^15 or higher) before the backward pass. By the chain rule every gradient is then multiplied by S as well, so the scaled gradients sit higher in the FP16 range and survive; divide them by S again before the optimiser step.

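import torch

# One training step; model, optimizer and batch are assumed to be defined elsewhere.
scaler = torch.cuda.amp.GradScaler()  # torch.amp.GradScaler("cuda") on recent PyTorch
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(batch).loss
scaler.scale(loss).backward()  # backward on loss * S
scaler.unscale_(optimizer)     # divide grads by S in place so clipping sees true magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)         # skips the step if any gradient is inf/NaN
scaler.update()                # adjust S for the next iteration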
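The arithmetic behind the scaler can be checked with a standard-library FP16 emulation. This sketch assumes a fixed S = 2^16 (the default starting scale of PyTorch's GradScaler) and a single illustrative gradient value; real training scales whole tensors, but the per-element arithmetic is the same. `to_fp16` and `to_fp32` are hypothetical helpers.

```python
import struct

S = 2.0 ** 16  # assumed loss scale


def to_fp16(x: float) -> float:
    """Round x to FP16; struct raises OverflowError past 65504."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


true_grad = 1e-8                 # below FP16's smallest subnormal (~6e-8)
unscaled = to_fp16(true_grad)    # without scaling: flushed to 0.0
scaled = to_fp16(true_grad * S)  # ~6.55e-4: a normal FP16 value
recovered = to_fp32(scaled / S)  # unscale in FP32 before the optimiser step

# Too large a scale overflows instead, which is what dynamic scaling guards against.
try:
    to_fp16(1.0 * S)  # a gradient of 1.0 times 2**16 exceeds FP16's max of 65504
    overflowed = False
except OverflowError:
    overflowed = True

print(unscaled, scaled, recovered, overflowed)
```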

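Dynamic loss scaling adjusts S automatically: on overflow it halves S and skips that optimiser step, and after a long run of overflow-free steps it doubles S to probe for more headroom. It works, but it is one more component that can fail: a badly behaved scale schedule causes silent training stalls or NaN spirals.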
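The schedule is a small state machine. `DynamicLossScaler` below is a hypothetical sketch, not PyTorch's class; its defaults mirror those of `GradScaler` (`growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). On any non-finite gradient it skips the step and halves the scale; after `growth_interval` consecutive clean steps it doubles the scale.

```python
import math
from dataclasses import dataclass, field
from typing import Sequence


@dataclass
class DynamicLossScaler:
    """Hypothetical dynamic loss scaler; defaults mirror torch GradScaler's."""
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    _clean_steps: int = field(default=0, init=False)

    def update(self, scaled_grads: Sequence[float]) -> bool:
        """Adjust the scale; return True if the optimiser step should run."""
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # overflow: back off and skip the step
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor  # long clean run: probe a larger scale
            self._clean_steps = 0
        return True


# Short growth interval so the behaviour is visible in a few steps.
scaler = DynamicLossScaler(growth_interval=3)
history = [(scaler.update(g), scaler.scale)
           for g in ([0.1], [0.2], [0.3], [math.inf], [0.1])]
print(history)
```

The scale doubles after three clean steps, then halves on the step with an infinite gradient, which is skipped.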

Why BF16 won for training

BF16's exponent range means no loss scaling is required - you autocast and forget. That alone removed an entire class of bugs from frontier training. Specifically:

  • No GradScaler. No tuning of initial scale, no overflow recovery cycles.
  • Same dynamic range as FP32. Gradients, activations, and optimiser states convert without thinking.
  • Same bit width as FP16. All the memory and bandwidth wins remain.
  • Slightly worse precision. In practice this is invisible for transformer training; the optimisation is noisy enough that 7 mantissa bits suffice.
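To make "slightly worse precision" concrete: BF16 keeps 8 significant bits (7 stored plus the implicit leading 1), FP16 keeps 11. The standard-library sketch below (hypothetical helpers `to_fp16` and `to_bf16`, finite inputs assumed) shows a step of 2^-9 above 1.0 vanishing in BF16 but not in FP16, while BF16 still handles magnitudes across FP32's normal range with bounded relative error.

```python
import struct


def to_fp16(x: float) -> float:
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]


# Precision: FP16 resolves 1 + 2**-9; BF16 rounds it back to 1.0.
fine_step = 1.0 + 2.0 ** -9
print(to_fp16(fine_step), to_bf16(fine_step))

# Range: relative error stays within 2**-8 across FP32's normal range.
samples = (3.0e38, 2.0e-38, 1.0 / 3.0, 3.14159265)
for x in samples:
    y = to_bf16(x)
    print(f"{x:.6e} -> {y:.6e}  (relative error {abs(y - x) / x:.1e})")
```

FP16 could not hold the first two samples at all: 3.0e38 overflows and 2.0e-38 flushes to zero.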

Every frontier LLM since roughly 2021 (PaLM, GPT-4, Llama 2 onward, Mistral, Gemini) trains in BF16. FP16 lingers only on Volta / Turing where BF16 is not natively supported.

Master weights stay in FP32

The standard "mixed precision" recipe keeps two copies of the model:

  • Compute copy in BF16 or FP16. Forward, backward, and gradients run here.
  • Master copy in FP32. The optimiser step reads master weights, applies the gradient, and writes the updated FP32 master back. The BF16/FP16 compute copy is re-derived from the master.

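Why? Optimiser updates can be much smaller than the weights themselves. BF16 carries only 8 significant bits, so in weight -= lr * grad any update more than a few hundred times smaller than the weight is rounded away entirely and the weight never changes; an update six orders of magnitude smaller has no chance. Keeping FP32 master weights (24 significant bits) costs an extra 4 bytes per parameter and removes the problem for any realistic update size. Frameworks (DeepSpeed, FSDP with MixedPrecision) handle this automatically.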
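The effect is easy to reproduce. The sketch below applies 1,000 identical updates of 1e-4 (an assumed, illustrative `lr * grad`, only four orders of magnitude below the weight) to a weight of 1.0: once with the weight stored only in BF16, and once with an FP32 master copy from which the BF16 compute copy is re-derived. `to_fp32` and `to_bf16` are hypothetical standard-library emulations.

```python
import struct


def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]


update = 1e-4   # assumed lr * grad, identical every step for simplicity
w_bf16 = 1.0    # weight kept only in BF16
w_master = 1.0  # FP32 master weight

for _ in range(1000):
    w_bf16 = to_bf16(w_bf16 - update)      # update rounded away every time
    w_master = to_fp32(w_master - update)  # update accumulates in FP32

w_compute = to_bf16(w_master)  # BF16 compute copy re-derived from the master
print(w_bf16, w_master, w_compute)
```

The BF16-only weight never moves off 1.0, because 1e-4 is well under half the gap between 1.0 and the next BF16 value below it (that gap is 2^-8, about 0.0039). The FP32 master ends near 0.9, and the compute copy is the nearest BF16 value to it.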

FP8 on Hopper and Blackwell

H100 introduced two FP8 formats targeting different parts of the pipeline:

Format  Sign bits  Exponent bits  Mantissa bits  Use
E4M3    1          4              3              forward activations and weights
E5M2    1          5              2              gradients (needs the range)

You cannot just cast to FP8 the way you can with BF16. FP8 has a representable range so small (E4M3 max ~448, E5M2 max ~57344) that every tensor needs per-tensor scaling: pick a scaling factor that maps the tensor's actual range into FP8's representable window, store the scale alongside the tensor, and unscale on the way back to higher precision for accumulation.
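A minimal pure-Python sketch of per-tensor scaling into E4M3 follows. It assumes the common E4M3 variant (exponent bias 7, no infinities, max 448) and a saturating cast; `to_e4m3`, `quantize_per_tensor`, and `dequantize` are hypothetical helpers, not Transformer Engine APIs. For simplicity the scale comes from the current tensor's own maximum absolute value (amax).

```python
import math

E4M3_MAX = 448.0  # largest finite E4M3 value


def to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (3 mantissa bits, exponent bias 7).

    Assumption: out-of-range inputs saturate at +/-448 rather than becoming NaN.
    """
    a = min(abs(x), E4M3_MAX)
    _, exp = math.frexp(a)  # a = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)    # binade exponent; -6 is the smallest normal exponent
    step = 2.0 ** (e - 3)   # spacing of representable values in that binade
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def quantize_per_tensor(t: list[float]) -> tuple[list[float], float]:
    """Scale t so its largest magnitude maps to E4M3_MAX, then cast."""
    amax = max(abs(v) for v in t)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [to_e4m3(v * scale) for v in t], scale


def dequantize(q: list[float], scale: float) -> list[float]:
    """Undo the scale on the way back to higher precision."""
    return [v / scale for v in q]


acts = [1.0e-4, -3.0e-4, 2.5e-4, 5.0e-5]  # small activations, illustrative values
naive = [to_e4m3(v) for v in acts]         # direct cast: every value rounds to zero
q, scale = quantize_per_tensor(acts)
restored = dequantize(q, scale)
print(naive, q, scale, restored)
```

Cast directly, all four activations become zero; after per-tensor scaling each survives with at most 6.25% relative error (half of E4M3's 2^-3 relative spacing within a binade).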

NVIDIA's Transformer Engine library handles the scaling factors automatically, tracking a history of each tensor's maximum absolute value (amax) and updating the scales every step. Hand-rolling FP8 without it is a fast path to NaNs.
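Conceptually, history-based ("delayed") scaling computes the scale for the next step from the largest amax seen over a window of recent steps, instead of from the current tensor. `DelayedScaler` below is a hypothetical sketch of that idea, not Transformer Engine's implementation; the formula scale = fp8_max / max(history) / 2**margin is an assumption chosen for illustration.

```python
from collections import deque


class DelayedScaler:
    """Hypothetical per-tensor scaler driven by a window of past amax values."""

    def __init__(self, fp8_max: float = 448.0, history_len: int = 16, margin: int = 0):
        self.fp8_max = fp8_max
        self.margin = margin
        self.history = deque(maxlen=history_len)  # oldest amax drops out automatically
        self.scale = 1.0  # used until the first amax has been recorded

    def record(self, amax: float) -> float:
        """Log this step's amax and return the scale to use on the next step."""
        self.history.append(amax)
        peak = max(self.history)
        if peak > 0.0:
            self.scale = self.fp8_max / peak / 2.0 ** self.margin
        return self.scale


scaler = DelayedScaler(history_len=2)
scales = [scaler.record(amax) for amax in (2.0, 4.0, 1.0, 1.0)]
print(scales)
```

With a two-step window, a single amax spike (4.0) keeps the scale conservative for one extra step before it ages out; a longer window trades more headroom against using less of FP8's range.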

The convergence-stability trade-off

Lower precision means more noise in every matmul. For training:

  • BF16 matches FP32 final loss almost exactly across well-tested architectures.
  • FP8 matches BF16 final loss for models up to at least 175B parameters when done with per-tensor scaling and selective higher-precision keepers (LayerNorm, softmax, residual accumulation). The original NVIDIA FP8 paper reports this on GPT-3-class runs.
  • INT8 and below: training does not work reliably yet; these formats are for inference quantisation, not training.

The keepers matter. LayerNorm reductions, softmax, and the residual accumulator are noise-sensitive enough that running them in FP8 destabilises training. The "FP8 mixed precision" recipe is really "FP8 matmuls, BF16 everything else, FP32 master weights and norm statistics."
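The failure mode for reductions shows up even at BF16 precision, before dropping to FP8. The sketch below sums 10,000 contributions of 0.001 (illustrative values) with the running sum rounded to BF16 after every add, and separately rounded to FP32; `to_fp32` and `to_bf16` are hypothetical standard-library emulations standing in for a real low-precision accumulator.

```python
import struct


def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]


terms = [0.001] * 10_000  # true sum is 10.0
acc_bf16 = 0.0
acc_fp32 = 0.0
for t in terms:
    acc_bf16 = to_bf16(acc_bf16 + t)  # once BF16 spacing exceeds 2 * t, adds are lost
    acc_fp32 = to_fp32(acc_fp32 + t)

print(acc_bf16, acc_fp32)
```

The BF16 accumulator stalls at 0.5, where the spacing between BF16 values (2^-8) is more than twice each term, while the FP32 sum ends close to 10.0.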

When to keep master weights in FP32

Always, unless you have proven empirically that you can drop them. The extra 4 bytes per parameter are cheap compared with losing a multi-week training run to optimiser updates that BF16 silently rounded away. FSDP's MixedPrecision config exposes param_dtype, reduce_dtype, and buffer_dtype separately; set param_dtype=torch.bfloat16 for compute. Outside forward and backward, FSDP keeps the sharded parameters in their original FP32 dtype, so the optimiser steps on an FP32 shard.
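In code, the FSDP side of that recipe looks roughly like the fragment below. `MixedPrecision` and its `param_dtype`, `reduce_dtype`, and `buffer_dtype` fields are PyTorch FSDP API; choosing FP32 for `reduce_dtype` and BF16 for `buffer_dtype` is an assumption, and `model` stands for an already-built module inside an initialised process group.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Compute in BF16; FSDP keeps the sharded parameters (and so the optimiser step) in FP32.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype used for forward/backward compute
    reduce_dtype=torch.float32,   # assumption: gradient reduction in FP32
    buffer_dtype=torch.bfloat16,  # assumption: buffers follow the compute dtype
)

# Assumed context: `model` exists and torch.distributed is already initialised.
# model = FSDP(model, mixed_precision=bf16_policy)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
```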

Further reading