Applied LLMs
When Not to Write a Custom Kernel
Writing a CUDA kernel is expensive to maintain and easy to get wrong; this concept maps the decision boundary between writing one and leaning on existing compilers and libraries.
Writing a CUDA kernel from scratch typically takes a senior GPU engineer two to five days for the first correct, performant version, and then ongoing effort every time the NVIDIA architecture changes. The question is not whether custom kernels are powerful - they clearly are - but whether the performance gap over readily available alternatives actually justifies that cost. Most of the time, it does not.
The default stack already covers the common cases
Before reaching for a .cu file, consider what is already optimised for you.
cuBLAS / cuDNN / cuSPARSE. NVIDIA ships hand-tuned kernels for general matrix multiply (GEMM), convolution, attention, and a dozen other primitives. These are tuned separately for each GPU generation (Ampere, Ada, Hopper) and exploit architecture-specific instructions (Tensor Cores, wgmma on Hopper, etc.) that are tedious to drive manually. A custom GEMM written by most teams will be slower than cuBLAS on standard shapes.
torch.compile + TorchInductor. From PyTorch 2.0 onward, torch.compile captures a computation graph with TorchDynamo, fuses point-wise operations, and has TorchInductor emit optimised kernels (Triton on GPU, C++ on CPU) without you writing a line of low-level code. The official tutorial reports a 2.26x speedup on a 4096x4096 workload with zero kernel code. For element-wise chains (layer norm, GELU, dropout in sequence), torch.compile will often emit a fused kernel that is close to hand-written quality.
Triton. When torch.compile does not produce a good kernel for your specific pattern, Triton lets you write tile-level GPU code in Python, and the compiler handles register allocation, shared-memory management, and async copies. The official fused-softmax tutorial shows Triton sustaining ~1,400 GB/s versus ~400 GB/s for a naive PyTorch loop. If your operation fits in a single tile pass, Triton is usually the right tool before CUDA.
XLA / JAX. In JAX, jax.jit lowers through XLA, which applies operator fusion, layout optimisation, and rematerialisation automatically. The resulting HLO (XLA's intermediate representation, not a bytecode) is compiled per device and the executable is cached. Whole-model compilation under XLA can outperform hand-tuned per-operator code because XLA can reason globally across the whole graph.
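Why fusion helps element-wise chains while GEMM is best left to the library can be seen with a rough roofline estimate. The sketch below is plain Python; the hardware figures (PEAK_FLOPS, PEAK_BANDWIDTH) and the per-element FLOP count are illustrative assumptions for an A100-class GPU, not measurements, and all function names are hypothetical.

```python
# Rough roofline estimate. Hardware numbers are ASSUMED (roughly A100-class).
PEAK_FLOPS = 312e12       # assumed peak dense 16-bit throughput, FLOP/s
PEAK_BANDWIDTH = 1.5e12   # assumed HBM bandwidth, bytes/s


def bound(flops: float, bytes_moved: float) -> str:
    """Classify an op by comparing its arithmetic intensity to the ridge point."""
    ridge = PEAK_FLOPS / PEAK_BANDWIDTH        # FLOP per byte at the ridge
    intensity = flops / bytes_moved            # FLOP per byte for this op
    return "compute-bound" if intensity >= ridge else "memory-bound"


def gemm_cost(m: int, n: int, k: int, bytes_per_elem: int = 2):
    """FLOPs and minimum HBM traffic for C[m,n] = A[m,k] @ B[k,n]."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops, bytes_moved


def elementwise_cost(n: int, flops_per_elem: int = 8, bytes_per_elem: int = 2):
    """One element-wise op: read n elements, write n elements."""
    return n * flops_per_elem, 2 * n * bytes_per_elem


def chain_traffic(n: int, num_ops: int, fused: bool, bytes_per_elem: int = 2) -> int:
    """HBM bytes moved by a chain of element-wise ops, fused or unfused."""
    passes = 1 if fused else num_ops
    return passes * 2 * n * bytes_per_elem


if __name__ == "__main__":
    print("4096^3 GEMM:", bound(*gemm_cost(4096, 4096, 4096)))
    print("element-wise op:", bound(*elementwise_cost(1 << 20)))
    n = 1 << 20
    ratio = chain_traffic(n, 3, fused=False) / chain_traffic(n, 3, fused=True)
    print("traffic saved by fusing 3 ops:", ratio, "x")
```

Under these assumptions a large GEMM sits well above the ridge point, so only better use of the math units helps, which is where the vendor library is strongest. An element-wise op sits far below it, so the gain comes from cutting memory traffic, which is the job a fusing compiler does for you.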
The practical hierarchy looks like this:
| Option | Effort | Fits when |
|---|---|---|
| cuBLAS / cuDNN | Zero - use via PyTorch | Standard GEMM, conv, attention |
| torch.compile | One line: model = torch.compile(model) | Element-wise chains, standard architectures |
| Triton | ~50-200 lines Python | Non-standard tile pattern, bandwidth-bound op |
| Custom CUDA | 200-2000 lines C++ | Architecture-specific intrinsics, irregular memory layout, no compiler support |
When a custom kernel is genuinely justified
The case for writing CUDA is strong in exactly three scenarios.
1. The operation has no good primitive and is a training bottleneck. FlashAttention is the canonical example: the standard attention implementation is memory-bandwidth limited because it materialises the full N x N attention matrix to HBM. FlashAttention restructures the computation into tiles that stay in on-chip SRAM, which requires precise control over reads and writes. torch.compile cannot discover this automatically because tiling is only mathematically valid after an algebraic rewrite of softmax (the online, running-maximum formulation), and the compiler does not perform that rewrite.
2. You need architecture-specific intrinsics. Sparse weight kernels using structured 2:4 sparsity, low-precision quantisation (INT4 or FP8 with hand-tuned dequant), or warp-level matrix instructions (e.g. wmma or wgmma) often require code that no compiler generates today. If you are squeezing the last 20% out of inference on H100s at high traffic, this argument is real.
3. Existing frameworks impose a shape or layout assumption you cannot satisfy. cuDNN expects NHWC or NCHW. cuBLAS assumes column-major storage with a regular leading dimension (row-major inputs are handled through transpose flags). If your data layout is genuinely irregular (sparse block structure, ragged sequences, interleaved formats), the overhead of reshaping into a library-friendly layout can negate the kernel speed.
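To put a number on the N x N materialisation in scenario 1, the following sketch compares the size of the attention score matrix with the size of the Q, K, and V inputs. The model configuration (32 heads, head dimension 128, sequence length 8192, 16-bit elements) is a hypothetical example chosen for illustration, and the function names are not from any library.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_elem: int = 2) -> int:
    """Bytes needed to store the full seq_len x seq_len score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def qkv_bytes(batch: int, heads: int, seq_len: int, head_dim: int,
              bytes_per_elem: int = 2) -> int:
    """Bytes needed to store the Q, K and V inputs together."""
    return 3 * batch * heads * seq_len * head_dim * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical configuration, for illustration only.
    scores = attention_scores_bytes(batch=1, heads=32, seq_len=8192)
    inputs = qkv_bytes(batch=1, heads=32, seq_len=8192, head_dim=128)
    print(f"score matrix: {scores / 2**30:.1f} GiB")
    print(f"Q, K, V:      {inputs / 2**20:.0f} MiB")
    print(f"ratio:        {scores / inputs:.1f}x")
```

The score matrix grows quadratically with sequence length while the inputs grow linearly, so the intermediate quickly dwarfs the data the operation actually needs. A tiled kernel that never writes it to HBM removes the dominant memory traffic.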
Notice what is NOT on the list: "I think this op is slow." Profile first. torch.profiler and Nsight Systems will tell you whether the operation actually dominates runtime, and a kernel-level profile will tell you whether it is compute-bound or memory-bandwidth-bound. Writing a custom kernel for an operation that accounts for 2% of total runtime is a maintenance burden with no measurable end-to-end impact, however fast the kernel becomes.
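The 2% case can be checked with Amdahl's law before any kernel is written. This is a minimal sketch; the function name is hypothetical and the runtime fraction is meant to come from your own profile.

```python
def end_to_end_speedup(fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: overall speedup when `fraction` of runtime gets
    `kernel_speedup` times faster and the rest is unchanged."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    if kernel_speedup <= 0:
        raise ValueError("kernel_speedup must be positive")
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)


if __name__ == "__main__":
    # An op that is 2% of runtime, made 10x faster by a custom kernel.
    print(f"{end_to_end_speedup(0.02, 10.0):.4f}x overall")
    # The same op made (almost) infinitely fast: the ceiling is 1 / 0.98.
    print(f"{end_to_end_speedup(0.02, 1e9):.4f}x overall")
    # An op that is 40% of runtime, made 2x faster.
    print(f"{end_to_end_speedup(0.40, 2.0):.4f}x overall")
```

Even an infinitely fast kernel cannot lift end-to-end throughput by more than about 2% when the operation is only 2% of runtime, which is usually within run-to-run noise.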
The hidden costs that make the decision asymmetric
Performance engineering intuition tends to undercount the long-term costs of a custom kernel.
Maintenance against new architectures. A kernel that achieves peak throughput on Ampere (A100) may fall 30-40% short of what Hopper (H100) can deliver, because the memory hierarchy, warp scheduling, and async-copy semantics changed. torch.compile and Triton regenerate code per target; a hand-written .cu file needs manual porting.
Debugging complexity. A race condition in shared memory, an off-by-one in a tile index, or a missing __syncthreads() often produces silently wrong answers rather than a crash. Tracking down CUDA race conditions with compute-sanitizer (the successor to the deprecated cuda-memcheck) is significantly slower than debugging Python.
Graph break cost. If you call a custom CUDA kernel inside a torch.compile region through a binding the compiler cannot trace, you can introduce a graph break: the compiler must split the graph around the call. Even a couple of graph breaks in a training loop can eliminate much of the speedup torch.compile would otherwise provide, and the cost of the break is often larger than the gain from the custom op. Registering the kernel as a proper custom operator (via torch.library) avoids the break, but the kernel remains opaque to the compiler and cannot be fused with its neighbours.
Portability. A .cu file does not run on AMD ROCm, Apple MPS, or Google TPUs without a port. Triton targets multiple backends; torch.compile is backend-agnostic. If your workload will ever run on non-NVIDIA hardware, the custom kernel locks you in.
When it falls down
The "do not write a custom kernel" heuristic has genuine limits.
- torch.compile on highly dynamic shapes. If your batch sizes or sequence lengths vary widely across calls, torch.compile will recompile frequently or fall back to eager mode, and the compilation overhead dominates. In this regime, a hand-tuned kernel with runtime shape logic can outperform a compiled graph.
- Very long compilation times. Large models compiled with torch.compile can take minutes on first run. For latency-sensitive cold-start scenarios (serverless inference, hot-swap model serving), the compilation overhead is unacceptable and pre-compiled CUDA kernels loaded at startup are preferable.
- Triton coverage gaps. Triton's programming model assumes your operation decomposes cleanly into fixed-size blocks (tiles). Operations with irregular data dependencies (e.g., sparse attention with arbitrary masks, graph neural network scatter operations over variable neighbourhoods) do not map cleanly to Triton tiles and may require CUDA.
- Sub-byte precision. INT4, MXFP4, and similar formats require bit-packing and dequantisation logic that current compilers do not generate optimally. At production inference scale (millions of requests per day), 15-20% throughput differences from quantised GEMM quality matter financially.
- Kernel fusion across framework boundaries. torch.compile fuses within a single graph capture. If your pipeline mixes PyTorch, custom C++ extensions, and third-party CUDA calls, the compiler cannot see across boundaries, and a hand-written fused kernel may be the only way to eliminate intermediate buffers.
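To make the sub-byte point above concrete, here is the bit-packing and dequantisation logic for INT4 written out in plain Python. It is a sketch of the arithmetic only, not a kernel: the layout (two signed values per byte, low nibble first) and the single per-tensor scale are assumptions chosen for illustration, and real formats differ in both.

```python
def pack_int4(values):
    """Pack signed 4-bit integers (-8..7) two per byte, low nibble first."""
    if len(values) % 2 != 0:
        raise ValueError("need an even number of values")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        for v in (lo, hi):
            if not -8 <= v <= 7:
                raise ValueError(f"{v} does not fit in signed INT4")
        # Mask to 4 bits (two's complement) and combine into one byte.
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed):
    """Recover the signed 4-bit integers from packed bytes."""
    values = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # Nibbles 8..15 represent the negative values -8..-1.
            values.append(nibble - 16 if nibble >= 8 else nibble)
    return values


def dequantize(packed, scale, zero_point=0):
    """Map packed INT4 values back to floats: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in unpack_int4(packed)]


if __name__ == "__main__":
    weights = [-8, 7, 0, -1, 3, -4]
    packed = pack_int4(weights)
    print(len(weights), "values stored in", len(packed), "bytes")
    print(unpack_int4(packed))
    print(dequantize(packed, scale=0.5))
```

In a real quantised GEMM this unpacking has to happen in registers, interleaved with the matrix instructions and matched to their operand layout, which is the part that is hard for a general-purpose compiler to schedule well.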
The practical takeaway is not "never write CUDA" but "exhaust the compiler/library options first, profile to confirm a real bottleneck, and account for the full maintenance lifecycle before committing."