Applied LLMs
Why GEMMs Dominate
Almost every compute-heavy operation in a neural network reduces to a matrix multiply, which is why hardware and compilers optimise almost exclusively for GEMM throughput.
Whenever a GPU runs a transformer, roughly 80-90% of its floating-point operations happen inside a single primitive: the General Matrix Multiply, or GEMM. That share is not incidental. It is the result of a decade of careful co-design between neural network architectures and silicon.
What a GEMM actually is
The canonical form is:
C = α·A·B + β·C
where A is [M × K], B is [K × N], C is [M × N]
α and β are scalars (usually 1 and 0). The total floating-point work is 2·M·N·K operations: each of the M·N elements of C is an inner product of length K, costing K multiplies and K adds.
A fully-connected layer with input dimension d_in and output dimension d_out, batched over B samples, maps exactly to this (in a transformer the samples are tokens, so B is batch size × sequence length):
| Layer quantity | GEMM dimension |
|---|---|
| Batch size B | M |
| Input features d_in | K |
| Output features d_out | N |
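A minimal sketch of the mapping in the table, assuming the y = xW^T convention used later in this article. `GemmShape` and `linear_layer_gemm` are hypothetical helpers written for illustration, not part of any library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmShape:
    """Problem size of C[M x N] = A[M x K] @ B[K x N]."""
    m: int
    n: int
    k: int

    def flops(self) -> int:
        # M*N output elements, each an inner product of length K:
        # K multiplies + K adds = 2*K FLOPs per output element.
        return 2 * self.m * self.n * self.k


def linear_layer_gemm(batch: int, d_in: int, d_out: int) -> GemmShape:
    """Map a fully-connected layer y = x W^T onto GEMM dimensions."""
    return GemmShape(m=batch, n=d_out, k=d_in)


# Hypothetical example: 4096 tokens through a 4096 -> 16384 projection.
shape = linear_layer_gemm(batch=4096, d_in=4096, d_out=16384)
print(shape, f"{shape.flops() / 1e12:.2f} TFLOP")
```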
Attention is no different. The query-key product QK^T is a batched GEMM: [B·heads × T × d_k] times [B·heads × d_k × T]. The attention-weighted sum of values, softmax(QK^T)·V, is another. The four weight projections (Q, K, V, O) are four more GEMMs. Convolutions, once unrolled via im2col, collapse to a GEMM too: each input patch the kernel slides over becomes one row of a matrix, and the flattened filters become the columns of the other.
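To make the im2col step concrete, here is a minimal single-image sketch in plain Python, assuming valid padding, stride 1 and nested lists in place of tensors (all function names are hypothetical). Each output position contributes one row, so the convolution becomes a GEMM with M = output positions, K = C·R·S and N = C_out; the result is checked against a direct sliding-window loop.

```python
import random


def im2col(x, R, S):
    """x: [C][H][W] nested lists. Returns a [P x (C*R*S)] matrix in which each
    row is one flattened receptive-field patch (P = number of output positions)."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    rows = []
    for i in range(H - R + 1):
        for j in range(W - S + 1):
            rows.append([x[c][i + r][j + s]
                         for c in range(C) for r in range(R) for s in range(S)])
    return rows


def matmul(a, b):
    """Plain triple-loop GEMM: a [M x K] times b [K x N]."""
    K = len(b)
    return [[sum(a_row[k] * b[k][n] for k in range(K)) for n in range(len(b[0]))]
            for a_row in a]


def conv2d_as_gemm(x, w):
    """w: [C_out][C][R][S]. Valid cross-correlation, stride 1, no padding."""
    C_out, C, R, S = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    cols = im2col(x, R, S)  # [P x K], K = C*R*S
    # Flatten each filter in the same (c, r, s) order; filters become columns.
    w_mat = [[w[o][c][r][s] for o in range(C_out)]
             for c in range(C) for r in range(R) for s in range(S)]  # [K x C_out]
    return matmul(cols, w_mat)  # [P x C_out]


def conv2d_direct(x, w):
    """Reference: explicit sliding windows, same output layout as conv2d_as_gemm."""
    C_out, C, R, S = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    H, W = len(x[0]), len(x[0][0])
    out = []
    for i in range(H - R + 1):
        for j in range(W - S + 1):
            out.append([sum(x[c][i + r][j + s] * w[o][c][r][s]
                            for c in range(C) for r in range(R) for s in range(S))
                        for o in range(C_out)])
    return out


random.seed(0)
x = [[[random.random() for _ in range(5)] for _ in range(4)] for _ in range(2)]  # C=2, H=4, W=5
w = [[[[random.random() for _ in range(3)] for _ in range(2)] for _ in range(2)]
     for _ in range(3)]  # C_out=3, C=2, R=2, S=3
print(conv2d_as_gemm(x, w)[0], conv2d_direct(x, w)[0])
```

Production libraries often build these patches on the fly ("implicit GEMM") rather than materialising the unrolled matrix, because the explicit copy duplicates each input value up to R·S times.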
Why hardware loves GEMMs: arithmetic intensity
A GPU's performance is bounded by two ceilings: peak compute (FLOP/s) and peak memory bandwidth (bytes/s). Their ratio defines the ops:byte ratio, i.e. how many arithmetic operations the chip can sustain per byte it fetches from DRAM. On an NVIDIA H100 SXM5, peak dense FP16 tensor-core throughput is roughly 989 TFLOP/s and HBM3 bandwidth is roughly 3.35 TB/s, which gives an ops:byte ratio near 295 FLOP/B.
Any operation whose arithmetic intensity (FLOPs performed per byte of DRAM traffic) falls below that threshold is memory-bound: the compute units sit idle waiting for data. Any operation above it is compute-bound: the memory subsystem has time to prefetch while arithmetic churns.
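The two ceilings combine into what is usually called the roofline model. A minimal sketch using the H100 SXM5 peak figures quoted above; real kernels reach somewhat less than either ceiling, so treat the output as an upper bound.

```python
# Approximate H100 SXM5 peaks from the text (dense FP16 tensor cores, HBM3).
PEAK_FLOPS = 989e12        # FLOP/s
PEAK_BANDWIDTH = 3.35e12   # bytes/s
RIDGE = PEAK_FLOPS / PEAK_BANDWIDTH  # ops:byte ratio, ~295 FLOP/B


def attainable_flops(intensity: float) -> float:
    """Roofline: throughput is capped by memory (intensity * bandwidth) or by compute."""
    return min(PEAK_FLOPS, intensity * PEAK_BANDWIDTH)


def bound(intensity: float) -> str:
    return "memory-bound" if intensity < RIDGE else "compute-bound"


for ai in (1, 50, 300, 2730):
    print(f"{ai:>5} FLOP/B -> {attainable_flops(ai) / 1e12:7.1f} TFLOP/s ({bound(ai)})")
```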
A large GEMM with M=N=K=8192 has arithmetic intensity (counting A and B read once and C written once, with β = 0):
2·M·N·K / (bytes read + bytes written)
= 2·8192³ / ((8192² + 8192² + 8192²) · 2 bytes for FP16)
= 8192 / 3
≈ 2730 FLOP/B
That is more than nine times the H100's ops:byte ratio of about 295 FLOP/B, so the GEMM is firmly compute-bound and the tensor cores, not memory, set the pace. Contrast that with a pointwise activation function (e.g. GELU over a 4096-element vector): it reads 4096 values, does a handful of FLOPs on each, and writes 4096 values back. With 2 bytes read and 2 bytes written per FP16 element, its arithmetic intensity is effectively 1-2 FLOP/B: entirely memory-bound, so the GPU delivers a small fraction of its rated throughput.
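The two intensity estimates can be reproduced directly. This is a sketch assuming each operand crosses DRAM exactly once and, for GELU, a hypothetical cost of about 8 FLOPs per element (the real count depends on the approximation used).

```python
def gemm_intensity(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    """FLOP per byte of DRAM traffic for C = A @ B, assuming A and B are read
    once and C is written once (beta = 0, so C is never read)."""
    flops = 2 * m * n * k
    traffic = (m * k + k * n + m * n) * bytes_per_elem
    return flops / traffic


def pointwise_intensity(flops_per_elem: float, bytes_per_elem: int = 2) -> float:
    """Elementwise op: read one value and write one value per element."""
    return flops_per_elem / (2 * bytes_per_elem)


RIDGE = 989e12 / 3.35e12  # H100 SXM5 ops:byte ratio from the text

big = gemm_intensity(8192, 8192, 8192)
gelu = pointwise_intensity(flops_per_elem=8)  # assumption: ~8 FLOPs per GELU
print(f"GEMM 8192^3: {big:.0f} FLOP/B ({big / RIDGE:.1f}x the ridge point)")
print(f"GELU:        {gelu:.1f} FLOP/B")
```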
This is the core reason GEMMs dominate the profiling trace: they are the only common operations with enough data reuse (arithmetic intensity) to be limited by the compute units rather than the memory bus.
How tensor cores exploit the structure
A tensor core (introduced in Volta, refined through Hopper) is a unit that completes a small matrix multiply-accumulate every clock. On Volta each tensor core computes a 4×4×4 FP16 product with FP32 accumulation, i.e. 64 fused multiply-adds per clock; on Ampere the figure rises to 256 FP16 multiply-adds per clock. Software does not address individual tensor cores: a warp issues one matrix instruction over a larger tile, for example Ampere's mma.sync.m16n8k16:
D[16×8] = A[16×16] × B[16×8] + C[16×8] (one warp instruction, FP16 inputs, FP32 accumulate)
These tiles are stacked in the K dimension: the inner loop advances K in steps of 16, accumulating partial products into FP32 accumulator registers. The library (cuBLAS, or kernels built with CUTLASS) tiles the M and N dimensions across thread blocks, keeping shared memory full and the tensor cores fed.
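The loop structure can be sketched in plain Python. This shows only the blocking, not any hardware semantics: each output tile owns an accumulator, and the K loop advances in fixed chunks. The tile sizes (16 × 8 outputs, K steps of 16) are an assumption loosely modelled on a warp-level MMA tile; real cuBLAS/CUTLASS kernels use much larger block tiles and several levels of tiling.

```python
import random

TILE_M, TILE_N, K_STEP = 16, 8, 16  # assumed tile shape, for illustration only


def naive_gemm(a, b):
    M, K, N = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(K)) for j in range(N)] for i in range(M)]


def tiled_gemm(a, b):
    """C = A @ B computed one output tile at a time. Each (i0, j0) tile stands in
    for the work one warp or thread block owns; the inner loop walks K in
    K_STEP chunks, accumulating partial products into the tile."""
    M, K, N = len(a), len(b), len(b[0])
    c = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, TILE_M):
        for j0 in range(0, N, TILE_N):
            # Accumulator for this output tile (FP32 registers on real hardware).
            acc = [[0.0] * TILE_N for _ in range(TILE_M)]
            for k0 in range(0, K, K_STEP):
                # One "mma" step: acc += A[i0:i0+TILE_M, k0:k0+K_STEP] @ B[k0:k0+K_STEP, j0:j0+TILE_N]
                for i in range(min(TILE_M, M - i0)):
                    for j in range(min(TILE_N, N - j0)):
                        acc[i][j] += sum(a[i0 + i][k] * b[k][j0 + j]
                                         for k in range(k0, min(k0 + K_STEP, K)))
            # Write the finished tile back; edge tiles are only partly valid.
            for i in range(min(TILE_M, M - i0)):
                for j in range(min(TILE_N, N - j0)):
                    c[i0 + i][j0 + j] = acc[i][j]
    return c


random.seed(1)
M, K, N = 40, 48, 20  # deliberately not multiples of the tile sizes in M and N
A = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(K)]
max_err = max(abs(x - y) for rc, rn in zip(tiled_gemm(A, B), naive_gemm(A, B))
              for x, y in zip(rc, rn))
print(f"max |tiled - naive| = {max_err:.2e}")
```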
The efficiency argument depends on the matrices being large enough to hide the overhead of tiling and to amortise the cost of loading operands from DRAM, and on their dimensions mapping cleanly onto tiles. NVIDIA's performance guide accordingly recommends dimensions that are multiples of 8 in FP16 (or 16 in INT8): misaligned dimensions leave tensor-core lanes idle on partial tiles.
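A rough rule of thumb: if M, N, or K is not a multiple of 8, you may lose 20-50% of tensor-core throughput to padding overhead. A related effect, tile quantisation, appears at the thread-block level: when a dimension is not divisible by the block tile size (e.g. 128), the last row or column of tiles is mostly padding yet costs as much as a full tile.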
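Tile quantisation is easy to estimate. This sketch assumes a hypothetical 128 × 128 thread-block tile; the tile sizes cuBLAS actually selects vary by problem shape and GPU. Going from N = 256 to N = 257 adds one column of useful work but a whole extra column of tiles.

```python
import math


def tile_efficiency(dim: int, tile: int) -> float:
    """Fraction of launched tile work along one dimension that is useful."""
    return dim / (math.ceil(dim / tile) * tile)


def gemm_tile_efficiency(m: int, n: int, tile_m: int = 128, tile_n: int = 128) -> float:
    # Output tiles cover an M x N grid; partial edge tiles cost as much as full
    # ones, so the overall efficiency is the product over both dimensions.
    return tile_efficiency(m, tile_m) * tile_efficiency(n, tile_n)


for n in (256, 257, 384):
    print(n, f"{gemm_tile_efficiency(4096, n):.1%}")
```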
The three GEMMs in every linear layer
A forward pass through y = xW^T + b creates one GEMM. Backpropagation creates two more:
- Forward: Y = X · W^T, shape [M×N]
- Input gradient: dX = dY · W, shape [M×K]
- Weight gradient: dW = dY^T · X, shape [N×K]
All three GEMMs perform the same 2·M·N·K FLOPs; they differ only in which dimension is reduced. Training a linear layer therefore costs about 3× its inference cost (in FLOP terms, ignoring memory traffic). The weight-gradient GEMM can also become a bottleneck during training: dW has the same shape as W itself, so its accumulation must finish before the optimiser step can run.
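The three shapes can be checked numerically. A minimal plain-Python sketch with small random matrices, assuming the scalar loss L = sum(Y ⊙ dY) so that the upstream gradient is exactly dY; it computes all three GEMMs and verifies one entry of dW by finite differences.

```python
import random


def matmul(a, b):
    """[M x K] @ [K x N] with nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


random.seed(2)
M, K, N = 4, 3, 5
X = [[random.gauss(0, 1) for _ in range(K)] for _ in range(M)]   # [M x K]
W = [[random.gauss(0, 1) for _ in range(K)] for _ in range(N)]   # [N x K]
dY = [[random.gauss(0, 1) for _ in range(N)] for _ in range(M)]  # upstream gradient [M x N]

Y = matmul(X, transpose(W))    # forward:         [M x N]
dX = matmul(dY, W)             # input gradient:  [M x K]
dW = matmul(transpose(dY), X)  # weight gradient: [N x K]


def loss(w):
    # Scalar loss whose gradient with respect to Y is exactly dY: L = sum(Y * dY).
    y = matmul(X, transpose(w))
    return sum(y[i][j] * dY[i][j] for i in range(M) for j in range(N))


# Finite-difference check of one entry of dW (L is linear in W, so this is tight).
eps = 1e-6
W_plus = [row[:] for row in W]
W_plus[1][2] += eps
fd = (loss(W_plus) - loss(W)) / eps
print(f"analytic dW[1][2] = {dW[1][2]:.6f}, finite difference = {fd:.6f}")
```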
When it falls down
Small batch / single-token inference. When a single sequence is decoded one token at a time, M = 1 and the GEMM degenerates to a matrix-vector product (GEMV). With FP16 weights, arithmetic intensity collapses to roughly 2·N·K / (N·K·2 bytes) = 1 FLOP/B, squarely memory-bound: every weight is fetched from HBM to be used only once. A 70B-parameter model performing single-token decoding saturates HBM bandwidth, not tensor cores. This is why speculative decoding and continuous batching (which raise the effective M) and quantisation (which reduces the bytes moved) matter so much at inference time.
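A rough sketch of both effects, under stated assumptions: a square 8192 × 8192 weight matrix, and a bandwidth floor that treats 3.35 TB/s (one H100 SXM5) as the rate at which weights stream. A 70B model at FP16 (140 GB) would not fit on a single 80 GB device, so read the FP16 figure as illustrative; sharding across GPUs raises the aggregate bandwidth and lowers the floor accordingly. `decode_step_lower_bound` is a hypothetical helper.

```python
HBM_BANDWIDTH = 3.35e12  # bytes/s, the H100 SXM5 figure quoted earlier


def gemm_intensity(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    """FLOP/B for an [m x k] @ [k x n] product, counting each operand's DRAM traffic once."""
    return 2 * m * n * k / ((m * k + k * n + m * n) * bytes_per_elem)


def decode_step_lower_bound(params: float, bytes_per_param: float,
                            bandwidth: float = HBM_BANDWIDTH) -> float:
    """Seconds per decoding step if every weight is streamed from HBM once and
    nothing else costs time: a bandwidth floor, not a latency prediction."""
    return params * bytes_per_param / bandwidth


# M is the number of tokens processed per step: 1 for a single sequence,
# larger when requests are batched or draft tokens are verified together.
for m in (1, 8, 64, 512):
    print(f"M={m:>3}: {gemm_intensity(m, 8192, 8192):6.1f} FLOP/B")

# 70B parameters at 2 bytes (FP16), 1 byte (INT8) and 0.5 bytes (4-bit) per weight.
for bpp in (2, 1, 0.5):
    t = decode_step_lower_bound(70e9, bpp)
    print(f"{bpp} B/param: >= {t * 1e3:.1f} ms per token")
```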
Tiny hidden dimensions. Models with d_model = 128 or feed-forward dimensions below 256 produce GEMMs in which N and K are small (and M is small too unless the batch is large). Such a GEMM splits into too few output tiles to occupy the GPU's 100-plus streaming multiprocessors (132 on an H100 SXM5), so most SMs sit idle. This limits how small a model can be before a different hardware target (e.g. an NPU or a CPU SIMD unit) becomes more efficient.
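A back-of-the-envelope occupancy check, assuming a hypothetical 128 × 128 output tile per thread block, one tile per SM at a time, and no split-K. Libraries partly compensate with smaller tiles or by splitting the K dimension across blocks, so these fractions are pessimistic, but the trend holds.

```python
import math

NUM_SMS = 132  # H100 SXM5


def output_tiles(m: int, n: int, tile_m: int = 128, tile_n: int = 128) -> int:
    """Number of thread-block output tiles for an M x N result (assumed tile size)."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


def busy_sm_fraction(m: int, n: int) -> float:
    # Assumes each tile occupies one SM and no split-K decomposition is used.
    return min(output_tiles(m, n), NUM_SMS) / NUM_SMS


for m, n in ((256, 128), (256, 512), (4096, 4096)):
    print(f"M={m}, N={n}: {output_tiles(m, n)} tiles, {busy_sm_fraction(m, n):.0%} of SMs busy")
```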
Sparse or structured-sparse weights. GEMMs assume dense storage: a weight matrix with 90% zeros still costs the same FLOP count in a standard GEMM. NVIDIA's 2:4 structured sparsity (Ampere onwards) stores only the 2 non-zero values in every group of 4 and can at most halve the math cost of the GEMM, but it requires the weights to follow that exact pattern, which means pruning (and usually fine-tuning) the model for it. Unstructured pruning offers no hardware speedup without specialised sparse kernels, and those rarely outperform dense GEMMs until sparsity exceeds 90-95%.
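A minimal sketch of the 2:4 storage idea, assuming magnitude-based selection (a common choice; the pruning criterion is not specified above). It illustrates only the compressed format and the halved multiply count; the real speedup comes from Sparse Tensor Core hardware consuming the compressed values and their position metadata directly. All function names are hypothetical.

```python
def prune_2_4(row):
    """Keep the 2 largest-magnitude values in every group of 4.
    Returns (values, indices): 2 values plus their positions (0-3) per group."""
    assert len(row) % 4 == 0, "row length must be a multiple of 4"
    values, indices = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        keep = sorted(sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2])
        values.extend(group[i] for i in keep)
        indices.extend(keep)
    return values, indices


def decompress_2_4(values, indices):
    """Expand back to a dense row with zeros in the pruned slots."""
    dense = []
    for g in range(0, len(values), 2):
        block = [0.0] * 4
        for v, i in zip(values[g:g + 2], indices[g:g + 2]):
            block[i] = v
        dense.extend(block)
    return dense


def sparse_dot(values, indices, x):
    """Dot product touching only the kept weights: half the multiplies of a dense dot."""
    return sum(v * x[(p // 2) * 4 + i] for p, (v, i) in enumerate(zip(values, indices)))


w = [0.9, -0.1, 0.05, -1.2, 0.3, 0.2, -0.7, 0.0]
vals, idx = prune_2_4(w)
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(vals, idx, sparse_dot(vals, idx, x))
```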
Non-linear bottlenecks in long-context attention. The softmax in attention is not a GEMM. At sequence length T = 32k, the attention matrix is T × T ≈ 1.07 billion elements per head, each requiring a non-GEMM exponentiation and normalisation step. FlashAttention fuses the softmax into the GEMM kernel, processing the matrix tile by tile with a running (online) softmax so that the full T × T matrix is never written to HBM. That removes the quadratic memory footprint, but the quadratic compute remains, including exponentials that tensor cores cannot perform, and at long context it forces algorithm-level workarounds that no amount of tensor-core optimisation can fully paper over.
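The trick that lets the softmax live inside a tiled kernel is an online (running) softmax. A plain-Python sketch for a single query row, assuming the scores are already computed: keys are consumed in blocks, and whenever a larger score appears, the running sum and output accumulator are rescaled. This shows the arithmetic only; FlashAttention also tiles the queries, computes the score tiles with tensor-core GEMMs, and keeps the working set on chip.

```python
import math
import random


def attention_row_direct(scores, values):
    """softmax(scores) @ values for one query row, materialising every weight."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def attention_row_online(scores, values, block=4):
    """Same result, consuming keys in blocks with a running max and running sum."""
    run_max = -math.inf
    run_sum = 0.0
    acc = [0.0] * len(values[0])
    for b0 in range(0, len(scores), block):
        s_blk = scores[b0:b0 + block]
        v_blk = values[b0:b0 + block]
        new_max = max(run_max, max(s_blk))
        scale = math.exp(run_max - new_max)  # rescale everything accumulated so far
        run_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            run_sum += w
            acc = [a + w * vi for a, vi in zip(acc, v)]
        run_max = new_max
    return [a / run_sum for a in acc]


random.seed(3)
T, D = 10, 3
scores = [random.gauss(0, 3) for _ in range(T)]
values = [[random.gauss(0, 1) for _ in range(D)] for _ in range(T)]
print(attention_row_direct(scores, values))
print(attention_row_online(scores, values))
```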
Further reading
- NVIDIA Deep Learning Performance Guide: Matrix Multiplication - authoritative treatment of tile quantisation, arithmetic intensity, and the roofline model as applied to GEMMs.
- NVIDIA Deep Learning Performance Guide: Fully Connected Layers - maps layer hyperparameters to GEMM dimensions M, N, K and gives concrete alignment rules.
- LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Dettmers et al., NeurIPS 2022) - shows what happens when GEMM assumptions break at scale: emergent outlier features force mixed-precision strategies.