Training Infrastructure

Data Parallelism and DDP

How replicating the model and sharding the batch across GPUs scales training, and why AllReduce is the primitive every framework eventually depends on.

The single-GPU training loop hits a wall the moment your batch outgrows one device. Data parallelism (DP) is the cheapest way through that wall: replicate the model on every GPU, hand each replica a different shard of the batch, and average the gradients before the optimiser step. Compute scales almost for free; what limits you is communication, and the entire stack lives or dies on how fast you can sum the gradient buffers (hundreds of megabytes, or gigabytes for large models) across the cluster.

The basic algorithm

  1. Replicate the full model on each of N GPUs.
  2. Split the global batch into N shards, one per GPU.
  3. Each replica runs forward and backward on its shard independently.
  4. Average gradients across all replicas (AllReduce: sum, then divide by N).
  5. Each replica runs the optimiser step on the identical, averaged gradient.

Steps 4 and 5 together keep replicas in sync without ever exchanging parameters. As long as every rank starts from the same weights and applies the same averaged gradient, they stay bit-identical (modulo non-deterministic kernels).
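The five steps can be checked on a toy problem. The sketch below is pure Python with a hypothetical one-parameter model (y = w * x, mean-squared-error loss) and plain SGD; none of these names come from a real framework. It assumes the batch splits evenly, which is what makes the average of the per-shard gradients equal the full-batch gradient.

```python
# Toy data-parallel step: 1-parameter model y = w * x, MSE loss, plain SGD.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_step(w, xs, ys, num_replicas, lr):
    """Shard the batch, compute local gradients, average them, update every replica."""
    shard = len(xs) // num_replicas              # assumes the batch divides evenly
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(num_replicas)
    ]
    avg_grad = sum(local_grads) / num_replicas   # stands in for AllReduce + divide by N
    # Every replica applies the same update to the same starting weight.
    new_weights = [w - lr * avg_grad for _ in range(num_replicas)]
    return new_weights, avg_grad

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
replica_weights, avg_grad = data_parallel_step(0.5, xs, ys, num_replicas=4, lr=0.01)
full_batch_grad = grad_mse(0.5, xs, ys)          # what a single big GPU would compute
```

Here avg_grad matches full_batch_grad, and all four entries of replica_weights are the same number, which is the whole point: no parameter ever crosses the wire.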

AllReduce is the only collective that matters

AllReduce takes one tensor per rank and leaves the sum (or average) on every rank. The naive implementation - send everything to rank 0, sum, broadcast back - pushes N-1 full copies of the buffer into rank 0 and N-1 full copies back out, so rank 0's link carries 2(N-1) times the data and becomes the bottleneck as you add GPUs.

The two implementations you will actually meet:

  • Ring-AllReduce. Each rank sends 1/N of its buffer to its right neighbour while receiving 1/N from its left, repeated 2(N-1) times (N-1 steps to reduce, N-1 to redistribute). Bandwidth-optimal: each rank sends about 2(N-1)/N times the buffer size, which approaches 2x as N grows and is therefore almost independent of N. The number of steps grows linearly with N, so for small messages per-step latency dominates and ring is a poor fit.
  • Tree / double-tree AllReduce. Builds a reduction tree with log(N) depth. Latency-optimal for small payloads; NCCL's double-tree variant retains close to ring's bandwidth for large payloads. NCCL chooses between ring and tree automatically, based on message size and topology.
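To make the ring schedule concrete, here is a small single-process simulation in pure Python. It is a sketch of the textbook two-phase algorithm (reduce-scatter, then all-gather), not NCCL's implementation; the function name and the sent counter are illustrative, and it assumes the buffer length is a multiple of N.

```python
def ring_allreduce(buffers):
    """Simulate ring AllReduce (sum) over equal-length per-rank lists.

    Returns the reduced buffers and the number of elements each rank sent.
    """
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "sketch assumes the buffer splits evenly into N chunks"
    chunk = size // n
    data = [list(b) for b in buffers]            # do not mutate the inputs
    sent = [0] * n

    def span(c):
        return slice(c * chunk, (c + 1) * chunk)

    def step(chunk_of_rank, reduce):
        # Snapshot outgoing chunks first so all sends happen "simultaneously".
        outgoing = [data[r][span(chunk_of_rank(r))] for r in range(n)]
        for r in range(n):
            dst = (r + 1) % n                    # right neighbour
            s = span(chunk_of_rank(r))
            if reduce:
                data[dst][s] = [a + b for a, b in zip(data[dst][s], outgoing[r])]
            else:
                data[dst][s] = outgoing[r]
            sent[r] += chunk

    # Phase 1, reduce-scatter: after N-1 steps rank r holds the full sum of chunk (r+1) % N.
    for s in range(n - 1):
        step(lambda r, s=s: (r - s) % n, reduce=True)
    # Phase 2, all-gather: circulate the finished chunks for another N-1 steps.
    for s in range(n - 1):
        step(lambda r, s=s: (r + 1 - s) % n, reduce=False)
    return data, sent

buffers = [[float(rank + 1)] * 8 for rank in range(4)]   # rank r holds 8 copies of r+1
reduced, sent = ring_allreduce(buffers)
```

With 4 ranks and 8 elements per buffer, every rank ends up holding the sum and has sent 12 elements, i.e. 2(N-1)/N = 1.5 times its buffer, in 2(N-1) = 6 steps.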

NCCL (NVIDIA Collective Communications Library) is the implementation that PyTorch, JAX, TensorFlow and DeepSpeed all call into on NVIDIA hardware. It exploits NVLink within a node and InfiniBand / RoCE between nodes.

PyTorch DDP in practice

torch.nn.parallel.DistributedDataParallel wraps your model and registers autograd hooks on every parameter. As gradients are produced during the backward pass, DDP groups them into buckets (default 25 MB) and kicks off an AllReduce on each bucket as soon as it is full. The AllReduce overlaps with the rest of the backward, so for a well-tuned setup you mostly pay the larger of compute time and communication time rather than their sum.

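import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes launch via torchrun (which sets LOCAL_RANK) and that
# model, loader and optimizer are defined elsewhere.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
model = DDP(model.to(local_rank), device_ids=[local_rank], bucket_cap_mb=25)

for batch in loader:                      # DistributedSampler shards the data
    loss = model(batch).loss
    loss.backward()                       # AllReduce kicks off here, per bucket
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)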
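The bucketing idea can be pictured with a few lines of pure Python. This is an illustrative sketch, not DDP's actual assignment logic (which differs in details such as handling of dtypes and devices): it walks parameters last-to-first, roughly the order in which backward produces gradients, and closes a bucket once the next gradient would push it past the cap. The function name and the sizes are hypothetical.

```python
def assign_buckets(param_sizes_bytes, bucket_cap_mb=25):
    """Greedily group gradients into buckets in reverse parameter order.

    A bucket is closed once adding the next gradient would exceed the cap;
    a single oversized gradient gets a bucket of its own.
    """
    cap = bucket_cap_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for index in reversed(range(len(param_sizes_bytes))):
        nbytes = param_sizes_bytes[index]
        if current and current_bytes + nbytes > cap:
            buckets.append(current)              # this bucket is ready to AllReduce
            current, current_bytes = [], 0
        current.append(index)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

MB = 1024 * 1024
sizes = [10 * MB, 10 * MB, 10 * MB, 30 * MB, 5 * MB]   # hypothetical gradient sizes
buckets = assign_buckets(sizes)
```

For these sizes the result is four buckets, [[4], [3], [2, 1], [0]]: the last layers' gradients go out first while earlier layers are still being differentiated, which is where the overlap comes from.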

A few things that bite people:

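  • Unused parameters. If a parameter receives no gradient in an iteration (for example, because a branch of the model was skipped), its bucket never completes. DDP then raises an error on the next forward pass, or hangs if ranks disagree about which parameters were used. Set find_unused_parameters=True (slower, because the autograd graph is traversed every iteration) or restructure the model.
  • Forgetting DistributedSampler. Without it every rank sees the same data, so N GPUs do the work of one and nothing warns you. Call sampler.set_epoch(epoch) at the start of each epoch as well, or the shuffle order repeats.
  • broadcast_buffers=True is the default. BatchNorm running stats and similar buffers get broadcast from rank 0 every forward. For models without BatchNorm, any remaining buffers are usually constants, so the broadcast is wasted work and can be turned off with broadcast_buffers=False.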
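What sharding the data means is easy to show without any framework. The sketch below is a pure-Python stand-in for the non-shuffled case: rank r takes every N-th index starting at r, with wrap-around padding so all ranks get the same number of samples. To my understanding DistributedSampler behaves similarly when shuffling is off, but treat the function here as a hypothetical illustration, not its implementation. It assumes the dataset has at least as many items as there are ranks.

```python
def shard_indices(dataset_len, num_replicas, rank):
    """Strided shard: rank r takes indices r, r+N, r+2N, ... (no shuffling).

    Pads by wrapping around so every rank gets the same number of samples.
    """
    per_rank = -(-dataset_len // num_replicas)           # ceiling division
    total = per_rank * num_replicas
    indices = list(range(dataset_len))
    indices += indices[: total - dataset_len]            # wrap-around padding
    return indices[rank:total:num_replicas]

shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
```

With 10 samples on 4 ranks the shards are [0, 4, 8], [1, 5, 9], [2, 6, 0] and [3, 7, 1]: disjoint apart from the two padded indices. Equal shard lengths matter because a rank that runs out of batches early would leave the others waiting in AllReduce.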

Gradient accumulation for effective-batch scaling

When you want a larger effective batch than your GPUs can hold, accumulate gradients across K micro-batches before stepping the optimiser. With DDP, wrap the inner K-1 iterations in model.no_sync() to skip the AllReduce on those steps and only sync on the final one - otherwise you pay the comms cost K times for no benefit.

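import contextlib

K = 4  # micro-batches per optimiser step (example value)

for i, batch in enumerate(loader):
    # Skip the AllReduce on all but the last micro-batch of each group of K.
    sync_ctx = model.no_sync() if (i + 1) % K != 0 else contextlib.nullcontext()
    with sync_ctx:
        (model(batch).loss / K).backward()   # divide so the accumulated gradient is an average
    if (i + 1) % K == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)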

Effective batch = per_gpu_batch * num_gpus * accumulation_steps. Pick the effective batch first, then choose the three factors to fit your memory budget and GPU count.
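As a worked example of that arithmetic, the hypothetical helper below solves for the number of accumulation steps given a target effective batch. It is a sketch that insists on an exact fit; the numbers are made up for illustration.

```python
def accumulation_steps(target_effective_batch, per_gpu_batch, num_gpus):
    """Micro-batches to accumulate per optimiser step to hit the target exactly."""
    per_step = per_gpu_batch * num_gpus          # samples processed per micro-step
    if target_effective_batch % per_step != 0:
        raise ValueError("target must be a multiple of per_gpu_batch * num_gpus")
    return target_effective_batch // per_step

# Example: 8 GPUs that each fit 16 samples, aiming for an effective batch of 1024.
k = accumulation_steps(target_effective_batch=1024, per_gpu_batch=16, num_gpus=8)
effective = 16 * 8 * k
```

Here k comes out at 8. With no_sync() on the first seven micro-batches, each optimiser step then costs one round of gradient AllReduce instead of eight.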

When DP alone runs out

The model and its training state must fit on one GPU. For a model with P parameters in FP16 plus an Adam optimiser, you need roughly:

Component                                Bytes per parameter
FP16 weights                             2
FP16 gradients                           2
FP32 master weights (mixed precision)    4
Adam moment 1 (FP32)                     4
Adam moment 2 (FP32)                     4
Total training state                     16

A 7B-parameter model needs roughly 112 GB just for state, before activations. An 80 GB H100 cannot hold it. That is where ZeRO, FSDP, and tensor parallelism start.
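The same estimate as a few lines of Python, using the per-parameter byte counts from the table. The dictionary keys and function name are illustrative, and GB here means 10^9 bytes.

```python
# Bytes per parameter for FP16 mixed-precision training with Adam.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,
    "adam_moment_1": 4,
    "adam_moment_2": 4,
}

def training_state_gb(num_params):
    """Training-state footprint in decimal GB (1 GB = 1e9 bytes), activations excluded."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9

state_7b = training_state_gb(7e9)     # 7B parameters at 16 bytes each
fits_on_80gb = state_7b <= 80
```

For 7B parameters this gives 112 GB, so fits_on_80gb is False. Run the other way, 80 GB of state corresponds to about 5B parameters, and that is before any activations are stored.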
