Training Infrastructure
Data Parallelism and DDP
How replicating the model and sharding the batch across GPUs scales training, and why AllReduce is the primitive every framework eventually depends on.
The single-GPU training loop hits a wall the moment your batch outgrows one device. Data parallelism (DP) is the cheapest way through that wall: replicate the model on every GPU, hand each replica a different shard of the batch, and average the gradients before the optimiser step. At scale it tends to be bandwidth-bound rather than compute-bound. Every step has to sum the full gradient (hundreds of megabytes to gigabytes, depending on model size) across the cluster, and the whole stack lives or dies on how fast that sum completes.
The basic algorithm
1. Replicate the full model on each of N GPUs.
2. Split the global batch into N shards, one per GPU.
3. Each replica runs forward and backward on its shard independently.
4. Sum gradients across all replicas with AllReduce, then divide by N to get the average.
5. Each replica runs the optimiser step on the identical, averaged gradient.
Steps 4 and 5 are what keep replicas in sync without ever exchanging parameters during training. As long as every rank starts from the same weights (DDP broadcasts rank 0's parameters and buffers when it wraps the model) and applies the same averaged gradient, the replicas stay bit-identical (modulo non-deterministic kernels).
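The five steps can be checked without any GPUs. This is a minimal sketch, assuming a toy 1-D linear model y_hat = w * x + b with mean-squared-error loss and plain SGD; shard_grad and data_parallel_step are illustrative names, and the "AllReduce" is just a Python average. Because the shards are equal-sized, the average of the per-shard mean gradients equals the full-batch gradient up to floating-point rounding, and because every replica applies the same update to the same starting weights, the replicas never diverge.

```python
# Single-process simulation of data-parallel SGD on y_hat = w * x + b.
# Each "replica" is a (w, b) tuple; the AllReduce is a plain Python average.


def shard_grad(w, b, xs, ys):
    """Gradient of the mean squared error over one shard."""
    n = len(xs)
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    dw = sum(2 * e * x for e, x in zip(errs, xs)) / n
    db = sum(2 * e for e in errs) / n
    return dw, db


def data_parallel_step(replicas, xs, ys, lr):
    n = len(replicas)
    shard = len(xs) // n  # assumes the global batch divides evenly by N
    # Steps 2-3: every replica computes gradients on its own shard.
    grads = [
        shard_grad(w, b, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r, (w, b) in enumerate(replicas)
    ]
    # Step 4: "AllReduce": sum across replicas, then divide by N.
    avg_dw = sum(g[0] for g in grads) / n
    avg_db = sum(g[1] for g in grads) / n
    # Step 5: every replica applies the same averaged gradient.
    return [(w - lr * avg_dw, b - lr * avg_db) for w, b in replicas]


xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]
replicas = [(0.0, 0.0)] * 4  # identical starting weights on 4 "GPUs"
for _ in range(3):
    replicas = data_parallel_step(replicas, xs, ys, lr=0.01)
```

After any number of steps all four tuples in replicas are equal, which is the property DDP relies on instead of exchanging weights.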
AllReduce is the only collective that matters
AllReduce takes one tensor per rank and leaves the sum (or average) on every rank. The naive implementation (send everything to rank 0, sum there, broadcast the result back) pushes 2(N-1) full buffers through rank 0's link: N-1 arriving and N-1 leaving. That traffic grows linearly with N, so the root's link saturates as you add GPUs.
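To make the naive cost concrete, here is a single-process sketch that counts how many elements cross rank 0's link. It assumes the simplest possible scheme, where rank 0 receives every buffer and then sends a separate copy of the result to each other rank; naive_allreduce is an illustrative name, not a library call.

```python
def naive_allreduce(buffers):
    """Reduce every buffer at rank 0, then broadcast the result.

    Returns (results, root_traffic), where root_traffic counts the
    elements that pass through rank 0's link in either direction.
    """
    n = len(buffers)
    size = len(buffers[0])
    total = list(buffers[0])
    root_traffic = 0
    # Reduce: N-1 full buffers arrive at rank 0.
    for r in range(1, n):
        for i in range(size):
            total[i] += buffers[r][i]
        root_traffic += size
    # Broadcast: N-1 full copies leave rank 0.
    results = [list(total) for _ in range(n)]
    root_traffic += (n - 1) * size
    return results, root_traffic


for n in (2, 8, 64):
    _, traffic = naive_allreduce([[1.0] * 64 for _ in range(n)])
    print(n, traffic / 64)  # buffers' worth through rank 0: 2.0, 14.0, 126.0
```

The per-buffer traffic at the root is exactly 2(N-1), so doubling the GPU count roughly doubles the time spent on the root's link.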
The two implementations you will actually meet:
- Ring-AllReduce. Ranks form a ring. Each rank sends 1/N of its buffer to its right neighbour while receiving 1/N from its left, repeated 2(N-1) times (N-1 reduce-scatter steps, then N-1 all-gather steps). It is bandwidth-optimal: each rank sends roughly 2(N-1)/N times the buffer size, which approaches 2 and is nearly independent of N. Latency, however, grows linearly with N, and for small messages latency is what dominates.
- Tree / double-tree AllReduce. Builds a reduction tree of depth log(N), so latency grows only logarithmically, which makes it the better choice for small payloads. NCCL's double-tree variant also retains close to ring's bandwidth for large payloads. NCCL picks between ring and tree automatically, depending on message size.
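Ring-AllReduce can be simulated in one process to check both claims: every rank ends up with the full sum, and each rank sends 2(N-1)/N times the buffer size. This is a minimal sketch, assuming the buffer length divides evenly by N and treating all sends within a step as simultaneous; ring_allreduce is an illustrative name, not an NCCL API.

```python
def ring_allreduce(buffers):
    """Simulate Ring-AllReduce (sum) over N equal-length buffers.

    Returns (results, sent), where sent[r] counts the elements rank r
    pushed to its right neighbour.
    """
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "this sketch assumes the buffer length divides by N"
    chunk = size // n
    bufs = [list(b) for b in buffers]  # each rank's working copy
    sent = [0] * n

    def span(c):
        c %= n
        return range(c * chunk, (c + 1) * chunk)

    def run_phase(chunk_for, combine):
        for s in range(n - 1):
            # Snapshot all outgoing chunks first: sends in one step are simultaneous.
            msgs = [(r, chunk_for(r, s), [bufs[r][i] for i in span(chunk_for(r, s))])
                    for r in range(n)]
            for r, c, data in msgs:
                dst = (r + 1) % n
                for i, v in zip(span(c), data):
                    bufs[dst][i] = combine(bufs[dst][i], v)
                sent[r] += len(data)

    # Reduce-scatter: afterwards rank r holds the full sum of chunk (r + 1) % N.
    run_phase(lambda r, s: r - s, lambda old, new: old + new)
    # All-gather: circulate the finished chunks, overwriting stale copies.
    run_phase(lambda r, s: r + 1 - s, lambda old, new: new)
    return bufs, sent


bufs = [[float(10 * r + i) for i in range(8)] for r in range(4)]
out, sent = ring_allreduce(bufs)
print(out[0])  # [60.0, 64.0, 68.0, 72.0, 76.0, 80.0, 84.0, 88.0]
print(sent)    # [12, 12, 12, 12], i.e. 2 * (4 - 1) / 4 * 8 elements per rank
```

Each rank only ever talks to its two neighbours, which is why the per-link load stays flat as N grows while the number of sequential steps (and hence latency) grows linearly.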
NCCL (NVIDIA Collective Communications Library) is the implementation that PyTorch, JAX, TensorFlow and DeepSpeed all call into on NVIDIA hardware. It exploits NVLink within a node and InfiniBand / RoCE between nodes.
PyTorch DDP in practice
torch.nn.parallel.DistributedDataParallel wraps your model and registers autograd hooks on every parameter. As gradients are produced during the backward pass, DDP groups them into buckets (bucket_cap_mb, default 25 MB) and launches an asynchronous AllReduce as soon as every gradient in a bucket is ready. These AllReduces overlap with the rest of the backward pass, so in a well-tuned setup a step costs roughly the larger of compute time and communication time (plus whatever communication is left once backward finishes) rather than their sum.
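```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launched with torchrun, which sets LOCAL_RANK for each process.
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl")

# model, loader and optimizer are assumed to be built elsewhere.
model = DDP(model.to(rank), device_ids=[rank], bucket_cap_mb=25)
for batch in loader:  # DistributedSampler shards the data
    loss = model(batch).loss
    loss.backward()  # AllReduce launches per bucket as gradients become ready
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```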
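The bucketing and overlap described above can be illustrated with a toy timeline. This is a sketch under stated assumptions, not DDP's actual scheduler: parameters are bucketed greedily in reverse order (the order backward tends to produce gradients), each bucket's AllReduce starts once its last gradient is ready and the previous AllReduce has finished, and every size, time and bandwidth is a made-up number. build_buckets and step_time are hypothetical helpers; PyTorch's real bucket assignment differs in details.

```python
def build_buckets(param_mb, cap_mb):
    """Greedily group parameter indices into buckets of at least cap_mb,
    walking in reverse order because backward reaches the last layers first."""
    buckets, current, current_size = [], [], 0.0
    for idx in reversed(range(len(param_mb))):
        current.append(idx)
        current_size += param_mb[idx]
        if current_size >= cap_mb:
            buckets.append(current)
            current, current_size = [], 0.0
    if current:
        buckets.append(current)
    return buckets


def step_time(param_mb, grad_time, cap_mb, bandwidth_mb_per_unit):
    """Toy timeline of one backward pass with overlapped, serialised AllReduces.
    grad_time[i] is the backward compute time attributed to parameter i."""
    clock = 0.0      # time at which backward has produced the current bucket
    link_free = 0.0  # time at which the previous AllReduce finishes
    for bucket in build_buckets(param_mb, cap_mb):
        clock += sum(grad_time[i] for i in bucket)
        start = max(clock, link_free)
        link_free = start + sum(param_mb[i] for i in bucket) / bandwidth_mb_per_unit
    return max(clock, link_free)


sizes = [10.0] * 8  # hypothetical: 8 parameters of 10 MB each
times = [1.0] * 8   # hypothetical: 1 time unit of backward per parameter
print(build_buckets(sizes, 25.0))                 # [[7, 6, 5], [4, 3, 2], [1, 0]]
print(step_time(sizes, times, 25.0, 10.0))        # 11.0
print(sum(times) + sum(sizes) / 10.0)             # 16.0 with no overlap
```

Here compute and communication each take 8.0 units. With overlap the step finishes at 11.0 instead of 16.0; the remaining gap above the ideal 8.0 is communication that had nothing left to hide behind once backward finished.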
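A few things that bite people:
- Unused parameters. If some parameter receives no gradient in an iteration (for example, because a branch was skipped in forward), its bucket never becomes ready, and DDP either raises an error on the next iteration or hangs waiting for the missing AllReduce. Set find_unused_parameters=True (slower, because DDP then traverses the autograd graph every iteration) or restructure the model so every parameter is used.
- Forgetting DistributedSampler. Without it every rank iterates over the full dataset (in the same order, if every rank uses the same seed), which silently wastes N-1 GPUs. With it, call sampler.set_epoch(epoch) at the start of each epoch, or every epoch reuses the same shuffle.
- broadcast_buffers=True is the default. Buffers such as BatchNorm running stats are broadcast from rank 0 at the start of every forward. LayerNorm keeps no running stats, so in LayerNorm-only models any remaining buffers are constants (masks, positional caches) and broadcasting them is wasted bandwidth; pass broadcast_buffers=False.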
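The sharding behind the DistributedSampler bullet can be sketched in pure Python. This mimics the scheme DistributedSampler uses with drop_last=False (shared seed plus epoch, pad by wrapping around, then a strided slice per rank), but it is an assumption-laden sketch: it uses random.Random instead of torch's generator, so the actual permutations differ, and it assumes num_replicas is no larger than the dataset. shard_indices is a hypothetical helper.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Which sample indices `rank` sees in a given epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank seeds with seed + epoch, so all ranks agree on one
        # permutation, and it changes from epoch to epoch (set_epoch's job).
        random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so every rank gets the same number of samples.
    total = math.ceil(dataset_len / num_replicas) * num_replicas
    indices += indices[: total - dataset_len]
    # Strided slice: rank r takes positions r, r + N, r + 2N, ...
    return indices[rank:total:num_replicas]


for r in range(4):
    print(r, shard_indices(10, 4, r, epoch=0, shuffle=False))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Note the padding: with 10 samples on 4 ranks, samples 0 and 1 are seen twice per epoch so that every rank runs the same number of steps.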
Gradient accumulation for effective-batch scaling
When you want a larger effective batch than your GPUs can hold, accumulate gradients across K micro-batches before stepping the optimiser. With DDP, run the first K-1 micro-batches of each group inside model.no_sync(), which skips the AllReduce and just accumulates gradients locally, and let the K-th backward sync the accumulated sum. Otherwise you pay the communication cost K times for no benefit.
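```python
import contextlib

K = 4  # accumulation steps (example value); model, loader, optimizer as above
for i, batch in enumerate(loader):
    is_sync_step = (i + 1) % K == 0
    # Skip the AllReduce on the first K-1 micro-batches; sync on the K-th.
    sync_ctx = contextlib.nullcontext() if is_sync_step else model.no_sync()
    with sync_ctx:
        # Divide by K so the accumulated gradient is the mean over micro-batches.
        (model(batch).loss / K).backward()
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```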
Effective batch = per_gpu_batch * num_gpus * accumulation_steps. Pick the effective batch first, then choose the largest per-GPU batch that fits in memory and make up the rest with accumulation steps.
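Working backwards from a target effective batch is one line of arithmetic; the hypothetical helper below also refuses targets that the per-step batch does not divide, since those would silently change the effective batch.

```python
def accumulation_steps(effective_batch, per_gpu_batch, num_gpus):
    """Micro-batches to accumulate per optimiser step for a target effective batch."""
    per_micro_step = per_gpu_batch * num_gpus
    if effective_batch % per_micro_step != 0:
        raise ValueError(
            f"effective batch {effective_batch} is not a multiple of "
            f"{per_gpu_batch} x {num_gpus} = {per_micro_step}"
        )
    return effective_batch // per_micro_step


# Example: target 1024 samples per optimiser step, 16 fit per GPU, 8 GPUs.
print(accumulation_steps(1024, per_gpu_batch=16, num_gpus=8))  # 8
```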
When DP alone runs out
Under pure data parallelism, the full model and its training state must fit on every GPU. For a model with P parameters trained in mixed precision (FP16 weights and gradients, FP32 master copy) with Adam, you need roughly:
| Component | Bytes per parameter |
|---|---|
| FP16 weights | 2 |
| FP16 gradients | 2 |
| FP32 master weights (mixed precision) | 4 |
| Adam moment 1 (FP32) | 4 |
| Adam moment 2 (FP32) | 4 |
| Total (weights + gradients + optimiser state) | 16 |
A 7B-parameter model needs roughly 112 GB just for state, before activations. An 80 GB H100 cannot hold it. That is where ZeRO, FSDP, and tensor parallelism start.
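The table turns into a two-line capacity check. This sketch assumes decimal gigabytes and ignores activations, temporary buffers and framework overhead, so real headroom is smaller; the function names are illustrative.

```python
# Bytes per parameter from the table (mixed precision + Adam).
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_grads": 2,
    "fp32_master_weights": 4,
    "adam_moment_1": 4,
    "adam_moment_2": 4,
}
BYTES_TOTAL = sum(BYTES_PER_PARAM.values())  # 16


def training_state_gb(num_params):
    """Training state in decimal GB, excluding activations and overhead."""
    return num_params * BYTES_TOTAL / 1e9


def max_params_that_fit(gpu_gb):
    """Largest model whose training state alone fits in gpu_gb of memory."""
    return gpu_gb * 1e9 / BYTES_TOTAL


print(training_state_gb(7e9))   # 112.0
print(max_params_that_fit(80))  # 5000000000.0
```

So with this recipe, pure DP on an 80 GB card tops out around 5B parameters before a single activation is stored.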
Further reading
- PyTorch DistributedDataParallel notes - the canonical reference on bucketing, hook ordering, and gotchas.
- NCCL overview - what NCCL guarantees and which primitives it exposes.
- Scaling Deep Learning Training with NCCL - Sylvain Jeaugey on ring vs tree AllReduce.
- HuggingFace multi-GPU training guide - decision table mapping model size to parallelism strategy.