Kernel Fusion
Kernel fusion eliminates redundant memory round-trips by merging multiple GPU operations into a single kernel launch, turning memory-bandwidth bottlenecks into throughput wins.
A naive softmax over a (4096, 4096) matrix launches three separate CUDA kernels: one to compute each row's maximum, one to compute the exponentials and their sum, and one to normalise. Each kernel reads the full matrix from global memory, and the last two also write a full-size result back. The matrix is 64 MB in fp32, so the three passes move roughly 320 MB of DRAM traffic (three full reads, two full writes) for an operation that, in principle, only needs to read the matrix once and write it once (128 MB). The compute is trivial; the bottleneck is entirely the memory bus. Kernel fusion is how you fix this.
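A back-of-envelope count of that traffic in plain Python, under stated assumptions: fp32 data, the max kernel writes only one value per row (negligible next to the matrix), the exp/sum kernel writes the exponentials back out, and the normalise kernel reads them and writes the result. `softmax_traffic_mib` is a hypothetical helper, not a library function. This accounting gives five full-matrix transfers, slightly fewer than a naive "three reads plus three writes" estimate, because the max kernel writes only one value per row.

```python
MIB = 1024 * 1024  # "MB" in the text above means 2**20 bytes


def softmax_traffic_mib(rows: int, cols: int, dtype_bytes: int = 4) -> tuple[float, float]:
    """Approximate DRAM traffic (MiB) for an unfused 3-kernel softmax vs a fused one.

    Per-row vectors (row max, row sum) are ignored: they hold `rows` elements,
    negligible next to the rows * cols matrix.
    """
    matrix = rows * cols * dtype_bytes / MIB
    unfused = (
        matrix          # kernel 1: read X to compute row maxima
        + 2 * matrix    # kernel 2: read X, write exp(X - max); sums are per-row only
        + 2 * matrix    # kernel 3: read the exponentials, write the normalised output
    )
    fused = 2 * matrix  # one kernel: read X once, write Y once
    return unfused, fused


unfused, fused = softmax_traffic_mib(4096, 4096)
print(f"unfused ~{unfused:.0f} MiB, fused ~{fused:.0f} MiB")  # unfused ~320 MiB, fused ~128 MiB
```

Halving `dtype_bytes` (fp16 or bf16) halves both numbers but leaves the 2.5x ratio unchanged: fusion removes whole passes, not bytes per element.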
What "fusion" actually means
A CUDA kernel is a function that runs on the GPU. Every launch carries a fixed overhead on the order of microseconds and, more expensively, any intermediate tensor handed from one kernel to the next must be materialised in global DRAM so that the next kernel can read it back. Fusion collapses a sequence of element-wise or reduction operations into a single kernel that keeps intermediate values in registers or shared memory, so they never touch DRAM.
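To make the distinction concrete, here is a toy model in plain Python rather than GPU code. Lists play the role of tensors in global DRAM, one loop over a list plays the role of a kernel launch, and `Traffic`, `launch` and the `scale_shift_relu_*` functions are hypothetical names for illustration. The chain scale -> shift -> ReLU runs unfused (three launches, two materialised intermediates) and fused (one launch).

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Traffic:
    reads: int = 0   # elements loaded from "DRAM"
    writes: int = 0  # elements stored to "DRAM"


def launch(src: list[float], fn: Callable[[float], float], t: Traffic) -> list[float]:
    """One toy 'kernel': load each element, transform it locally, store the result."""
    out = []
    for v in src:
        t.reads += 1
        out.append(fn(v))
        t.writes += 1
    return out


def scale_shift_relu_unfused(x: list[float], a: float, b: float, t: Traffic) -> list[float]:
    # Three launches; the two intermediates are materialised in "DRAM".
    scaled = launch(x, lambda v: v * a, t)
    shifted = launch(scaled, lambda v: v + b, t)
    return launch(shifted, lambda v: max(v, 0.0), t)


def scale_shift_relu_fused(x: list[float], a: float, b: float, t: Traffic) -> list[float]:
    # One launch; the intermediates exist only as temporaries inside the lambda.
    return launch(x, lambda v: max(v * a + b, 0.0), t)


x = [float(i - 4) for i in range(8)]
t_unfused, t_fused = Traffic(), Traffic()
assert scale_shift_relu_unfused(x, 2.0, 1.0, t_unfused) == scale_shift_relu_fused(x, 2.0, 1.0, t_fused)
print(t_unfused, t_fused)  # Traffic(reads=24, writes=24) Traffic(reads=8, writes=8)
```

The results are identical; only the number of trips to "DRAM" changes, by a factor equal to the length of the fused chain.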
The key insight is the memory hierarchy cost ladder:
| Memory level | Approximate latency | Bandwidth (A100) |
|---|---|---|
| Register | 0 cycles | - |
| Shared memory (SRAM) | ~30 cycles | ~19 TB/s |
| L2 cache | ~200 cycles | ~4 TB/s |
| Global (HBM) | ~600 cycles | 2 TB/s |
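Keeping an intermediate value in a register across fused sub-operations costs essentially nothing. Spilling it to global memory means a DRAM write plus a later DRAM read, each with roughly 600 cycles of latency. GPUs hide much of that latency by keeping many warps in flight, but they cannot hide the bandwidth cost, and at scale that traffic dominates.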
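Because latency is largely hidden, the cost that adds up for large tensors is bandwidth. A rough sketch of how long a single pass over the 64 MB fp32 matrix from the opening example takes at each level, assuming the table's approximate A100 figures are sustained aggregate rates (real kernels achieve less); `stream_time_us` is a hypothetical helper.

```python
# Approximate aggregate A100 bandwidths from the table above, in TB/s.
BANDWIDTH_TB_S = {"shared memory": 19.0, "L2 cache": 4.0, "HBM": 2.0}


def stream_time_us(num_bytes: int, tb_per_s: float) -> float:
    """Best-case time in microseconds to move num_bytes at the given bandwidth."""
    return num_bytes / (tb_per_s * 1e12) * 1e6


tensor_bytes = 4096 * 4096 * 4  # the (4096, 4096) fp32 matrix
for level, bw in BANDWIDTH_TB_S.items():
    print(f"{level:>13}: {stream_time_us(tensor_bytes, bw):6.1f} us")
```

At HBM bandwidth one pass takes about 33.6 µs, so every intermediate that an unfused pipeline writes out and reads back adds roughly 67 µs for this matrix, before any compute happens.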
Horizontal vs. vertical fusion
There are two structural patterns:
Vertical (producer-consumer) fusion chains operations where the output of one feeds immediately into the next: LayerNorm -> GeLU -> Linear can in principle be fused so that normalised, activated values are passed directly to the GEMM without a round-trip to DRAM.
Horizontal fusion merges independent operations that execute over the same data - for example, the query, key, and value projections in an attention layer all read the same input tensor. A single fused kernel reads it once and writes three outputs, compared to three separate kernels each reading it independently.
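A common way to realise this horizontal fusion is to concatenate the three projection weights column-wise, so one matmul reads the input once and its output is split into Q, K and V. A plain-Python sketch with a toy matmul (not a GPU kernel; all function names are hypothetical):

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive (rows of a) x (columns of b) product."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))] for row in a]


def concat_cols(*mats: Matrix) -> Matrix:
    # Row i of the result is row i of every input laid side by side.
    return [[v for m in mats for v in m[i]] for i in range(len(mats[0]))]


def split_cols(m: Matrix, width: int) -> list[Matrix]:
    return [[row[s:s + width] for row in m] for s in range(0, len(m[0]), width)]


def qkv_separate(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix):
    # Three "kernels", each reading x independently.
    return matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)


def qkv_fused(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix):
    # One "kernel" reads x once. In practice w_qkv is built once, not per call.
    w_qkv = concat_cols(w_q, w_k, w_v)
    q, k, v = split_cols(matmul(x, w_qkv), len(w_q[0]))
    return q, k, v


x = [[1.0, 2.0], [3.0, 4.0]]
eye, twice, swap = [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]
assert qkv_fused(x, eye, twice, swap) == qkv_separate(x, eye, twice, swap)
```

Each output element is computed by exactly the same sequence of multiply-adds in both versions, so the results match bit for bit; only the number of reads of `x` differs.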
FlashAttention is the canonical vertical fusion example in modern ML. It fuses the attention score computation (QK^T / sqrt(d_k)), the softmax, and the weighted sum of values into one kernel pass, so the full (sequence length × sequence length) score matrix is never written to HBM. That is the fundamental reason it is 4-5x faster than the unfused baseline. The algorithm is non-trivial precisely because softmax needs a row-wide reduction (the row maximum and the sum of exponentials) before any output can be normalised. FlashAttention gets around this with a tiled, numerically stable online softmax: it streams K and V through SRAM one tile at a time, keeps a running maximum and running sum per query row, and rescales the partial output whenever the maximum changes.
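The online-softmax recurrence behind that rewrite can be sketched for a single query row in plain Python. This illustrates the math only, not FlashAttention's kernel: there is no SRAM and no parallelism, and `tile` is a hypothetical parameter standing in for the K/V block size. Each tile updates a running maximum `m`, a running sum `l` and an unnormalised output `acc`, rescaling the old partials by exp(m_old - m_new) whenever the maximum grows.

```python
import math


def attention_reference(q, keys, values, scale):
    # Unfused: materialise every score, then softmax, then the weighted sum.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


def attention_online(q, keys, values, scale, tile=2):
    # Stream K/V in tiles; only m, l and acc persist, never the full score row.
    d_v = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * d_v
    for start in range(0, len(keys), tile):
        k_tile, v_tile = keys[start:start + tile], values[start:start + tile]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(s))
        correction = math.exp(m - m_new)  # rescale old partials; exp(-inf) = 0 on the first tile
        p = [math.exp(si - m_new) for si in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, v_tile)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalise once, at the end


q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [-0.5, 2.0, 1.0], [3.0, 1.0, -1.0], [0.0, 0.0, 4.0]]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0]]
scale = 1 / math.sqrt(len(q))
ref = attention_reference(q, keys, values, scale)
out = attention_online(q, keys, values, scale, tile=3)
assert all(abs(r - o) < 1e-9 for r, o in zip(ref, out))
```

The final normalisation by `l` is the only point where the whole row's statistics are needed, and by then they have been accumulated incrementally.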
How compilers decide what to fuse
Writing fused CUDA kernels by hand is expensive in engineering time. Several tools automate or simplify the job at different levels of abstraction.
Triton is a Python-embedded DSL that compiles to PTX. A Triton kernel is explicitly written as a single fused operation: the programmer defines the tiling and memory access patterns, and Triton's compiler handles register allocation, vectorisation, and shared-memory management. The fused softmax tutorial in the official Triton docs shows a kernel that is 4x faster than a Torch JIT (TorchScript) baseline at common matrix widths because it avoids the intermediate DRAM round-trips entirely.
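```python
# Triton fused softmax - simplified shape of the kernel
import triton
import triton.language as tl


@triton.jit
def fused_softmax_kernel(X, Y, stride, n_cols, BLOCK: tl.constexpr):
    # One program instance per row. Callers pick BLOCK as a power of two >= n_cols
    # (tl.arange needs a power-of-two range); X and Y share the same row stride.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    # Padding lanes load -inf, so they contribute exp(-inf) = 0 to the sum.
    x = tl.load(X + row * stride + cols, mask=mask, other=-float('inf'))
    # max + exp + sum + divide all happen in registers - no intermediate DRAM writes
    x = x - tl.max(x, axis=0)
    x = tl.exp(x)
    x = x / tl.sum(x, axis=0)
    tl.store(Y + row * stride + cols, x, mask=mask)
```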
torch.compile (introduced in PyTorch 2.0) automates fusion through a compiler stack: TorchDynamo captures Python bytecode into a computation graph, TorchInductor lowers it to Triton or C++ kernels, and the optimiser applies pointwise fusion, reduction fusion, and memory layout transformations without the programmer doing anything explicit. For a typical training forward pass, torch.compile commonly reduces kernel launch overhead by collapsing dozens of small element-wise kernels into a handful of fused ones.
XLA (the compiler behind JAX, also usable from TensorFlow) uses a different fusion algorithm: it analyses the HLO (high-level operations) graph and classifies ops as "fusible producer-consumer pairs" based on element count and access pattern. Snider and Liang (2023) report up to a 10.56x speedup from custom fusion strategies on top of XLA's defaults, illustrating that the compiler's heuristics leave performance on the table.
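The real fusion passes in Inductor and XLA weigh shapes, access patterns and cost models. As a deliberately simplified toy (hypothetical `Op` and `greedy_fuse`, not either compiler's algorithm), producer-consumer fusion over a linear chain reduces to one rule: fuse each op into its producer's kernel unless a matmul intervenes, treating the matmul as a barrier handed to a GEMM library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    kind: str  # "pointwise", "reduction" or "matmul"


def greedy_fuse(chain: list[Op]) -> list[list[Op]]:
    """Toy fusion pass over a linear chain of ops.

    Non-matmul ops (pointwise and reductions alike) join the kernel of their
    producer; a matmul is a fusion barrier and always gets its own kernel.
    """
    groups: list[list[Op]] = []
    for op in chain:
        if op.kind == "matmul":
            groups.append([op])                    # barrier: dedicated GEMM kernel
        elif groups and groups[-1][0].kind != "matmul":
            groups[-1].append(op)                  # fuse into the producer's kernel
        else:
            groups.append([op])                    # start a new fused kernel
    return groups


block = [
    Op("linear1", "matmul"), Op("bias1", "pointwise"), Op("gelu", "pointwise"),
    Op("linear2", "matmul"), Op("bias2", "pointwise"), Op("residual_add", "pointwise"),
    Op("ln_mean_var", "reduction"), Op("ln_scale_shift", "pointwise"),
]
kernels = greedy_fuse(block)
print(len(block), "->", len(kernels))  # 8 -> 4
for k in kernels:
    print("+".join(op.name for op in k))
```

Real compilers go further, for example by fusing bias and activation epilogues into the GEMM itself, and they also refuse fusions that this toy would accept, such as the incompatible reductions discussed below.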
When the hardware runs a fused kernel
Inside a single fused kernel, the GPU's streaming multiprocessors (SMs) each handle a tile of the input. The key resource constraints are:
- Registers per thread: A100 has 65,536 32-bit registers per SM, shared by all resident threads, and a single thread can use at most 255. If a fused kernel needs more live values per thread than the compiler can keep in registers, the excess spills to local memory, which despite its name is backed by global DRAM, partly defeating fusion.
- Shared memory per SM: up to 164 KB on A100. Kernels that need large tiles for their reductions (like FlashAttention with large d_k) must be careful here; the block size must be tuned so the tile fits.
- Occupancy: More resident warps per SM (up to 64 on A100) give the scheduler more independent work with which to hide memory latency. Fused kernels that use many registers per thread reduce how many warps fit, potentially hurting throughput on operations that are not purely compute-bound.
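These constraints create the fusion-or-spill tradeoff: fusing too many operations into one kernel can push register and shared-memory usage past what an SM can comfortably supply. Registers then spill to local memory, fewer warps fit per SM, or tiles must shrink, and the fused kernel can end up slower than the unfused sequence it replaced.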
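The register side of that tradeoff is simple arithmetic. A rough sketch using the A100 limits quoted above plus two further compute-capability 8.0 limits (2,048 resident threads and 32 resident blocks per SM). It deliberately ignores register and shared-memory allocation granularity, so treat the results as upper bounds; `occupancy` is a hypothetical helper, not NVIDIA's occupancy calculator.

```python
# Simplified A100 (compute capability 8.0) per-SM limits.
REGS_PER_SM = 65_536
SMEM_PER_SM = 164 * 1024       # bytes, maximum configurable shared memory
MAX_THREADS_PER_SM = 2_048     # = 64 warps of 32 threads
MAX_BLOCKS_PER_SM = 32


def occupancy(regs_per_thread: int, threads_per_block: int, smem_per_block: int = 0) -> float:
    """Fraction of the SM's 64 warp slots that can be resident at once."""
    limits = [
        REGS_PER_SM // (regs_per_thread * threads_per_block),  # register file
        MAX_THREADS_PER_SM // threads_per_block,               # thread slots
        MAX_BLOCKS_PER_SM,                                     # block slots
    ]
    if smem_per_block:
        limits.append(SMEM_PER_SM // smem_per_block)           # shared memory
    blocks = min(limits)
    return blocks * threads_per_block / MAX_THREADS_PER_SM


for regs in (32, 64, 128, 255):
    print(regs, "regs/thread ->", occupancy(regs, threads_per_block=256))
```

With 256-thread blocks this gives 100%, 50%, 25% and 12.5% occupancy for 32, 64, 128 and 255 registers per thread. Every extra live value a fusion keeps in registers moves a kernel down that list.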
When it falls down
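Ops with incompatible tile shapes. Fusing a row-wise reduction (softmax) with a column-wise one (e.g., a normalisation over the other axis) means each tile would need complete rows for the first and complete columns for the second, effectively the whole matrix on-chip, which exceeds SRAM capacity at realistic sizes. Even a single row-wise reduction that keeps the full row in SRAM can run out of space on very wide models.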
Reductions that need global synchronisation. Batch normalisation must compute a mean (and variance) over the full batch dimension. That requires a reduction across thread blocks running on different SMs, which cannot be completed within a single kernel without atomics or a two-pass scheme. Naive fusion across such ops is either incorrect or needs complex multi-stage kernels.
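The two-pass scheme can be sketched in plain Python. In pass one, each chunk (standing in for one thread block) reduces its slice of the batch to a count, a mean and a sum of squared deviations; in pass two those partials are merged. The merge uses Chan et al.'s parallel variance formula, one standard numerically stable choice; the function names and chunk size are hypothetical.

```python
import statistics


def chunk_partials(chunk: list[float]) -> tuple[int, float, float]:
    """Pass 1, one 'thread block' per chunk: count, mean and M2 (sum of squared
    deviations from the chunk mean)."""
    n = len(chunk)
    mean = sum(chunk) / n
    m2 = sum((x - mean) ** 2 for x in chunk)
    return n, mean, m2


def merge(a: tuple[int, float, float], b: tuple[int, float, float]) -> tuple[int, float, float]:
    """Pass 2: combine two partials with Chan et al.'s parallel variance update."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


def batch_mean_var(values: list[float], chunk_size: int = 4) -> tuple[float, float]:
    partials = [chunk_partials(values[i:i + chunk_size]) for i in range(0, len(values), chunk_size)]
    total = partials[0]
    for p in partials[1:]:
        total = merge(total, p)
    n, mean, m2 = total
    # Biased (population) variance, which batch norm uses to normalise in training.
    return mean, m2 / n


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0, 1.5, -3.0]
mean, var = batch_mean_var(data, chunk_size=3)
assert abs(mean - statistics.fmean(data)) < 1e-9
assert abs(var - statistics.pvariance(data)) < 1e-9
```

On a GPU the boundary between the two passes is where a second kernel launch (or a grid-wide synchronisation) has to go, which is exactly the seam that prevents naive single-kernel fusion.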
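Dynamic shapes. torch.compile specialises the compiled graph on the input shapes it sees, and Triton compiles a separate kernel for each combination of tl.constexpr values such as the block size, so frequent input shape changes force recompilation. By default torch.compile tries to generalise after a shape-triggered recompile; the dynamic=True flag requests shape-generic code up front, which avoids recompilation but produces less-optimised code than the static case.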
Short operator chains on compute-bound ops. If the bottleneck is FLOP throughput (e.g., a large GEMM), fusion with adjacent cheap ops (bias add, ReLU) saves little: the kernel is already compute-limited, so removing DRAM traffic does not move the needle. Profile before fusing.
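Whether a kernel is compute- or memory-bound can be estimated before profiling with a roofline-style calculation: compare its arithmetic intensity (FLOPs per byte of DRAM traffic) with the GPU's ridge point. A sketch, assuming A100 peaks of 312 TFLOP/s dense FP16/BF16 tensor-core math and the table's ~2 TB/s HBM bandwidth; the helper names are hypothetical.

```python
# Assumed A100 peaks: 312 TFLOP/s dense FP16/BF16 tensor-core math and the
# table's ~2 TB/s HBM bandwidth. Sustained rates in real kernels are lower.
PEAK_FLOPS = 312e12
PEAK_BYTES_PER_S = 2.0e12
RIDGE = PEAK_FLOPS / PEAK_BYTES_PER_S  # FLOP/byte where the two roofs meet: 156


def gemm_intensity(m: int, n: int, k: int, dtype_bytes: int = 2) -> float:
    flops = 2 * m * n * k                                # multiply + add per MAC
    bytes_moved = (m * k + k * n + m * n) * dtype_bytes  # read A and B, write C once
    return flops / bytes_moved


def bias_relu_intensity(dtype_bytes: int = 2) -> float:
    # Per element: one add and one max; read one value, write one value
    # (the bias vector is tiny and assumed to stay in cache).
    return 2 / (2 * dtype_bytes)


for name, ai in [("4096^3 GEMM", gemm_intensity(4096, 4096, 4096)), ("bias+ReLU", bias_relu_intensity())]:
    bound = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{name}: {ai:.1f} FLOP/byte, {bound}")

# Lower bounds on time for this example: the GEMM at peak math vs. a separate
# bias+ReLU pass that reads and writes the 4096x4096 fp16 output.
gemm_us = 2 * 4096**3 / PEAK_FLOPS * 1e6
epilogue_us = 2 * 4096 * 4096 * 2 / PEAK_BYTES_PER_S * 1e6
print(f"GEMM >= {gemm_us:.0f} us, separate bias+ReLU pass >= {epilogue_us:.0f} us")
```

The GEMM sits far above the ridge (about 1365 FLOP/byte against 156), so its runtime is set by math; folding the bias+ReLU pass into it saves roughly 34 µs against at least 441 µs of GEMM, under 10% even in this idealised model.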
Debugging difficulty. A fused kernel conflates multiple logical operations. Numerical errors inside a fused reduction (e.g., overflow before subtraction of max) are harder to isolate than bugs in separate kernels with inspectable intermediate tensors.
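The overflow example is easy to reproduce on the CPU. In plain Python the bug surfaces as an exception; inside a GPU kernel the same mistake yields inf and then NaN (inf / inf) with no error, which is what makes it hard to localise once the steps are fused. A minimal sketch with hypothetical function names:

```python
import math


def softmax_no_max(xs: list[float]) -> list[float]:
    # Skips the max subtraction: exp() of a large logit overflows.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(xs: list[float]) -> list[float]:
    # Subtracting the row max keeps every exponent <= 0, so exp() stays <= 1.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 1000.0]
print(softmax_stable(logits))  # [0.5, 0.5]
try:
    softmax_no_max(logits)
except OverflowError as err:
    print("unstable version failed:", err)  # math.exp raises instead of returning inf
```

With separate kernels, the unstable version's intermediate exponentials would be an inspectable tensor full of inf; in a fused kernel only the final NaNs are visible.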
Further reading
- Triton fused softmax tutorial (official Triton documentation) - benchmarks and annotated kernel code; the compiler itself is described in Tillet, Kung & Cox, Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations (MAPL 2019).
- Operator Fusion in XLA: Analysis and Evaluation (Snider & Liang, 2023) - empirical analysis of XLA's fusion decisions and a 10x+ speedup from custom strategies.
- NVIDIA CUDA C++ Best Practices Guide - primary reference for memory coalescing, shared memory bank conflicts, and occupancy tuning; the foundation for understanding why fusion helps.
- Triton language documentation - full reference for writing custom fused kernels in Python-embedded Triton.