
Applied LLMs

Triton: Python-Level GPU Kernels

Triton lets you write GPU kernels in Python by operating on tiles of data rather than individual threads, and its compiler handles shared-memory management, coalescing, and vectorisation automatically.

intermediate · 8 min read

Writing a fast CUDA kernel for a fused softmax or a custom attention variant takes hundreds of lines of C++, explicit shared-memory declarations, warp-level synchronisation, and careful vectorisation. Changing a single tensor dimension can silently halve throughput. Triton offers a different contract: write the algorithm at the tile level in Python, and let the compiler figure out the low-level scheduling.

Why CUDA is Hard to Optimise by Hand

A GPU executes thousands of threads at once, but raw thread count is not what makes kernels fast. Throughput depends on three conditions holding at the same time:

  1. Memory coalescing. Adjacent threads in a warp must access adjacent addresses so the memory system can serve the whole warp's request with as few 128-byte transactions as possible, ideally one.
  2. Shared-memory tiling. Data used repeatedly must be staged in on-chip SRAM (shared memory) to avoid re-fetching from HBM.
  3. Occupancy and latency hiding. Enough warps must be in flight to cover the hundreds of cycles that a global-memory read costs.
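Point 1 can be made concrete with a little address arithmetic. The sketch below is a simplified model, assuming 32 threads per warp, 4-byte float32 loads, and a memory system that serves requests in aligned 128-byte segments; real GPUs add sector-level and cache effects, and the helper name segments_touched is illustrative.

```python
WARP_SIZE = 32        # threads per warp on NVIDIA GPUs
ELEM_BYTES = 4        # float32
SEGMENT_BYTES = 128   # assumed transaction granularity (simplified model)

def segments_touched(base_addr: int, stride_elems: int) -> int:
    """Count distinct aligned 128-byte segments a warp touches when thread i
    loads the element at byte address base_addr + i * stride_elems * 4."""
    addrs = [base_addr + i * stride_elems * ELEM_BYTES for i in range(WARP_SIZE)]
    return len({a // SEGMENT_BYTES for a in addrs})

print(segments_touched(0, 1))   # 1: contiguous loads share one segment
print(segments_touched(0, 32))  # 32: a 32-float stride gives each thread its own segment
print(segments_touched(64, 1))  # 2: a misaligned base straddles two segments
```

In this model the strided warp moves 32 times as many bytes as the contiguous one to deliver the same 128 useful bytes.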

Getting all three right for every kernel you write is the core difficulty. A fused layernorm + linear kernel that is fast on an A100 may need substantial retuning, or a rewrite, for an H100: the warp size is still 32 threads, but shared-memory capacity, tensor-core instructions (Hopper adds wgmma and TMA), and the ratio of compute throughput to memory bandwidth all change.

The Triton Programming Model: Blocked Programs

Triton's key insight, explained clearly in its programming guide, is to raise the abstraction level from individual threads to blocks of data. Rather than asking "what does thread 47 do?", you ask "what does this tile of 128 elements do?".

The model has two layers:

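  • Python + @triton.jit: You write a regular Python function. The first time it is called with a new combination of argument types and constexpr values, Triton's JIT compiles it through its own intermediate representations and LLVM to PTX (NVIDIA's virtual ISA, which ptxas then assembles into a cubin) or to AMDGPU machine code, and caches the result.
  • Tile-level primitives (triton.language): tl.load and tl.store take a block of pointers (a base pointer plus a vector of offsets, optionally masked) and move a whole tile between global memory and registers; tl.dot, tl.sum, and ordinary arithmetic then operate on those tiles as values.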
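The "block of pointers" idea can be emulated on the host with NumPy. The sketch below assumes a row-major matrix stored in a flat buffer; load_tile and its parameters are illustrative, not part of triton.language. The broadcasting pattern offs_m[:, None] * stride_m + offs_n[None, :] * stride_n is the same one Triton's matmul tutorial uses to address a 2-D tile.

```python
import numpy as np

def load_tile(buf, stride_m, stride_n, row0, col0, BLOCK_M, BLOCK_N, n_rows, n_cols):
    """Emulate a masked 2-D tl.load from a flat buffer (illustrative helper)."""
    offs_m = row0 + np.arange(BLOCK_M)            # row indices covered by the tile
    offs_n = col0 + np.arange(BLOCK_N)            # column indices covered by the tile
    # Broadcasting turns two offset vectors into a BLOCK_M x BLOCK_N block of offsets
    offs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    mask = (offs_m[:, None] < n_rows) & (offs_n[None, :] < n_cols)
    # Masked-off lanes are filled with 0 here, as tl.load(..., other=0.0) would do
    tile = np.zeros((BLOCK_M, BLOCK_N), dtype=buf.dtype)
    tile[mask] = buf[offs[mask]]                  # gather only in-bounds elements
    return tile

A = np.arange(5 * 6, dtype=np.float32).reshape(5, 6)  # 5 x 6 row-major matrix
flat = A.ravel()                                       # what the kernel sees: a flat buffer
t = load_tile(flat, stride_m=6, stride_n=1, row0=4, col0=4,
              BLOCK_M=2, BLOCK_N=4, n_rows=5, n_cols=6)
print(t)  # only A[4, 4] and A[4, 5] are in bounds; every other lane is 0
```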

A minimal vector-addition kernel illustrates the model:

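import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid   = tl.program_id(0)               # which tile am I?
    offs  = pid * BLOCK + tl.arange(0, BLOCK)
    mask  = offs < n                       # guard OOB
    x     = tl.load(x_ptr + offs, mask=mask)
    y     = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor, BLOCK: int = 1024) -> torch.Tensor:
    out = torch.empty_like(x)              # x, y, out: same-shape contiguous CUDA tensors
    n = out.numel()
    grid = (triton.cdiv(n, BLOCK),)        # one program per BLOCK-element tile
    add_kernel[grid](x, y, out, n, BLOCK=BLOCK)  # tensors are passed as pointers
    return out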

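tl.program_id(0) corresponds roughly to a CUDA block index, and tl.arange(0, BLOCK) is the vector of offsets within the tile (Triton requires its length to be a power of two). The mask handles the tail when n is not a multiple of BLOCK, so no separate epilogue kernel is needed. Because BLOCK is a tl.constexpr, its value is fixed at compile time, which lets the compiler decide how the tile's elements are distributed across threads and emit vectorised loads and stores.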

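Vector addition is almost as short in CUDA, but the gap widens for kernels that reuse data, such as matrix multiplication, softmax, or attention. There, CUDA requires you to declare __shared__ buffers, insert __syncthreads() barriers, and map every thread index to the data it owns by hand. Triton's compiler derives these details from the block-level program.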
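The launch semantics of add_kernel can be emulated on the CPU with NumPy: one loop iteration per program ID, over a grid of ceil(n / BLOCK) programs, which is what triton.cdiv(n, BLOCK) computes on the host. On a GPU the programs run in parallel and in no guaranteed order; the sequential loop and the name emulate_add are assumptions of this sketch only.

```python
import numpy as np

def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring triton.cdiv."""
    return (a + b - 1) // b

def emulate_add(x: np.ndarray, y: np.ndarray, BLOCK: int = 128) -> np.ndarray:
    n = x.size
    out = np.empty_like(x)
    for pid in range(cdiv(n, BLOCK)):            # each iteration plays one Triton program
        offs = pid * BLOCK + np.arange(BLOCK)    # tl.arange(0, BLOCK) shifted to this tile
        mask = offs < n                          # guard the partial tail tile
        valid = offs[mask]
        out[valid] = x[valid] + y[valid]         # masked load, add, masked store
    return out

rng = np.random.default_rng(0)
x = rng.random(1000, dtype=np.float32)
y = rng.random(1000, dtype=np.float32)
print(cdiv(1000, 128))                           # 8 programs; the last covers only 104 elements
print(np.allclose(emulate_add(x, y), x + y))     # True
```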

What the Compiler Automates

Triton's backend performs several passes that would otherwise be manual in CUDA:

| Compiler pass | What it does |
|---|---|
| Coalescing analysis | Chooses a thread-to-data layout so that consecutive threads in a warp access consecutive addresses |
| Shared-memory allocation | Allocates shared-memory staging buffers for reused tiles |
| Shared-memory swizzling | Permutes where tile elements are stored in shared memory to avoid bank conflicts |
| Vectorisation | Emits 128-bit vector loads and stores where alignment allows |
| Tensor-core selection | Lowers tl.dot to mma instructions on Ampere and wgmma on Hopper |
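The swizzling row can be illustrated with a bank-conflict count. This is a generic XOR-swizzle sketch, not Triton's exact layout algorithm: it assumes 32 shared-memory banks of 4 bytes each, a 32 x 32 float32 tile, and a warp in which thread i reads row i of one column, as happens when a tile is read transposed.

```python
N_BANKS = 32   # shared-memory banks, each 4 bytes wide (assumed)
TILE = 32      # 32 x 32 float32 tile

def bank(row: int, col: int, swizzle: bool) -> int:
    """Bank holding logical element (row, col) of a row-major tile."""
    phys_col = col ^ row if swizzle else col   # XOR swizzle permutes columns per row;
                                               # writes and reads must use the same mapping
    word_index = row * TILE + phys_col         # 4-byte word offset in shared memory
    return word_index % N_BANKS

def worst_conflict(col: int, swizzle: bool) -> int:
    """Max number of threads hitting one bank when thread i reads (i, col)."""
    counts = {}
    for row in range(TILE):
        b = bank(row, col, swizzle)
        counts[b] = counts.get(b, 0) + 1
    return max(counts.values())

print(worst_conflict(col=5, swizzle=False))  # 32: every thread hits the same bank
print(worst_conflict(col=5, swizzle=True))   # 1: each thread hits a different bank
```

Because col ^ row is a permutation of 0..31 for a fixed column, a column read spreads across all 32 banks, while row reads stay conflict-free too.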

This is the practical pay-off: a Triton fused-softmax kernel of roughly 40 lines matches the throughput of PyTorch's native CUDA softmax on an A100, according to benchmarks in the official tutorial suite. The compiler is not magic; it makes the kinds of decisions an expert CUDA programmer would make, but it makes them consistently and portably.
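The algorithm behind such a fused-softmax kernel can be sketched on the CPU with NumPy. Following the tutorial's approach, each row is padded to the next power of two (tl.arange needs a power-of-two length) and the padding lanes are loaded as -inf so that they contribute zero after exp; the sequential row loop and the function names are assumptions of this sketch, standing in for GPU programs that run in parallel.

```python
import numpy as np

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, mirroring triton.next_power_of_2."""
    return 1 << (n - 1).bit_length()

def fused_softmax_rows(x: np.ndarray) -> np.ndarray:
    n_rows, n_cols = x.shape
    BLOCK = next_power_of_2(n_cols)
    out = np.empty_like(x)
    for row in range(n_rows):                        # one "program" per row
        offs = np.arange(BLOCK)
        mask = offs < n_cols
        r = np.full(BLOCK, -np.inf, dtype=x.dtype)   # like tl.load(..., other=-inf)
        r[mask] = x[row, offs[mask]]                 # the single read of the row
        r = r - r.max()                              # subtract the max for numerical stability
        num = np.exp(r)                              # padding lanes become exp(-inf) = 0
        out[row] = (num / num.sum())[mask]           # the single write of the valid lanes
    return out

x = np.random.default_rng(0).standard_normal((4, 781)).astype(np.float32)
y = fused_softmax_rows(x)
ref = np.exp(x - x.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
print(next_power_of_2(781), np.allclose(y, ref, atol=1e-6))  # 1024 True
```

The point of fusion is visible in the loop body: the row is read once and written once, and the max, exp, sum, and divide all happen on data that, on a GPU, would stay in registers or SRAM.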

How torch.compile Uses Triton

PyTorch's torch.compile pipeline (introduced in PyTorch 2.0) uses Triton as its default kernel-generation backend on NVIDIA hardware. The flow is:

Python model
  -> TorchDynamo    (captures the computation graph via bytecode tracing)
  -> TorchInductor  (schedules and fuses ops, emits Triton source)
  -> Triton JIT     (compiles to PTX, then to a cubin, and caches it)
  -> CUDA driver    (loads the compiled kernel onto the GPU)

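TorchInductor decides which ops to fuse and emits Triton kernel source, so you can get custom fused kernels for your model without writing a single line of Triton yourself. The generated kernels are cached on disk (by default in a torchinductor_<username> directory under the system temp directory, overridable with the TORCHINDUCTOR_CACHE_DIR environment variable), so the compilation cost is paid once per unique graph and input-shape specialisation rather than on every run.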
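Why fusion pays off can be shown with back-of-the-envelope arithmetic. The sketch below counts HBM traffic for a hypothetical chain t = x + b followed by y = gelu(t) on an n-element float32 tensor, assuming each unfused op reads its inputs from HBM and writes its output back, that b has the same shape as x, and that caches are ignored. The op chain and the function name are illustrative, not what Inductor emits for any particular model.

```python
ELEM_BYTES = 4  # float32

def hbm_bytes(n: int, fused: bool) -> int:
    """Bytes moved to/from HBM for t = x + b followed by y = gelu(t)."""
    if fused:
        # One kernel: read x and b, write y; t never leaves registers
        return (2 + 1) * n * ELEM_BYTES
    # Kernel 1 reads x and b and writes t; kernel 2 reads t and writes y
    return (2 + 1) * n * ELEM_BYTES + (1 + 1) * n * ELEM_BYTES

n = 4096 * 4096
unfused, fused = hbm_bytes(n, False), hbm_bytes(n, True)
print(unfused / fused)  # 5/3: the fused kernel moves about 1.67x less data
```

For a memory-bound element-wise chain, runtime tracks bytes moved, so every intermediate that stays on-chip removes one write and one read.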

The implication for practitioners: hand-writing Triton makes sense when you need a kernel that Inductor does not generate, such as custom sparse attention, non-standard normalisation, or a kernel that fuses across a dataflow pattern the graph compiler cannot see.

When it Falls Down

Small tensors and high dispatch overhead. Triton's JIT compilation is cached after the first call, but every launch still pays Python-side argument handling plus a few microseconds of driver launch latency, and for a 256-element vector that overhead dwarfs the actual memory traffic. For small operations called in a tight Python loop, capturing the whole sequence in a CUDA graph (for example through torch.compile(mode="reduce-overhead")) removes this overhead more effectively than hand-written Triton can.
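A rough calculation shows why launch overhead dominates at small sizes. Both constants are assumptions for illustration: 2 TB/s of HBM bandwidth (roughly A100-class) and a 5 µs per-launch cost covering Python dispatch plus driver latency; measure on your own system before relying on either.

```python
BANDWIDTH = 2.0e12   # bytes/s, assumed HBM bandwidth
LAUNCH_COST = 5e-6   # seconds, assumed per-launch overhead (Python + driver)

def vector_add_time(n: int, elem_bytes: int = 4) -> float:
    """Ideal memory time for out = x + y: read two vectors, write one."""
    return 3 * n * elem_bytes / BANDWIDTH

for n in (256, 1 << 20, 1 << 26):
    t_mem = vector_add_time(n)
    share = LAUNCH_COST / (LAUNCH_COST + t_mem)   # fraction of wall time spent launching
    print(f"n={n:>9}: memory {t_mem * 1e6:9.3f} us, launch share {share:.1%}")
```

Under these assumptions the 256-element add moves about 3 KB, which takes nanoseconds, so almost all of the wall time is launch overhead; only at millions of elements does the memory traffic start to dominate.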

Reductions over many axes. Triton handles reductions over one or two axes of a tile well, but reductions over three or more axes usually require flattening or manual tiling strategies that quickly become as complex as CUDA. Libraries such as cuDNN, or FlashAttention-3 (built on NVIDIA's CUTLASS), rely on specialised, hardware-specific kernels that go beyond what Triton's current IR expresses cleanly.
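One common workaround, when the reduced axes are contiguous in memory, is to collapse them into a single axis so that each output element needs only a 1-D reduction. The NumPy sketch below shows the index bookkeeping, assuming a C-contiguous 4-D tensor reduced over its last three axes; it illustrates the reshaping idea only and is not a Triton kernel.

```python
import numpy as np

x = np.random.default_rng(0).random((8, 3, 16, 32), dtype=np.float32)  # (batch, C, H, W)

# Direct reduction over three axes
ref = x.sum(axis=(1, 2, 3))

# Collapse (C, H, W) into one axis of length C*H*W. This is a view, not a copy,
# because the reduced axes are the trailing, contiguous ones.
flat = x.reshape(x.shape[0], -1)
out = flat.sum(axis=1)                                # now a 1-D reduction per row

print(flat.shape, np.allclose(out, ref, rtol=1e-5))   # (8, 1536) True
```

When the reduced axes are not trailing (for example axes 0 and 2 of a 4-D tensor), the collapse is no longer a free view; a kernel then needs a transposed copy or hand-written index arithmetic over several strides, which is where the complexity described above comes from.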

Irregular or sparse access patterns. The compiler's coalescing analysis assumes regular stride patterns. Sparse formats with genuinely irregular structure (CSR, or COO with scattered indices) turn loads into gathers that coalesce poorly, and Triton gives you no better tools than CUDA for this case.
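The coalescing problem shows up in the addresses a warp generates in a CSR sparse matrix-vector product, where each thread loads x[col_idx[k]]. The sketch below reuses a simplified memory model, assuming 32 threads, 4-byte elements, and aligned 128-byte segments, with randomly drawn column indices standing in for a genuinely irregular sparsity pattern; the helper name is illustrative.

```python
import numpy as np

WARP, ELEM_BYTES, SEGMENT = 32, 4, 128   # simplified memory model (assumed)

def segments_for_gather(indices: np.ndarray) -> int:
    """Distinct 128-byte segments touched when thread i loads x[indices[i]]."""
    return len(np.unique(indices * ELEM_BYTES // SEGMENT))

n_cols = 1_000_000
rng = np.random.default_rng(0)
dense_cols = np.arange(WARP)                                    # adjacent columns
random_cols = np.sort(rng.choice(n_cols, WARP, replace=False))  # scattered columns

print(segments_for_gather(dense_cols))    # 1: fully coalesced
print(segments_for_gather(random_cols))   # close to 32: roughly one segment per thread
```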

Hopper-specific features. The H100's TMA (Tensor Memory Accelerator) and warp-specialised pipelines require explicit async copy instructions. Triton's Hopper backend (as of 2024-2025) supports some of these, but complex ping-pong pipelines still require dropping to PTX or using CUTLASS.

Debugging is harder than it looks. Compiled Triton kernels run on the GPU, so Python's print and pdb do not work inside them. tl.static_print (at compile time), tl.device_print (at run time), the TRITON_INTERPRET=1 interpreter mode (which runs kernels on the CPU, where Python debugging tools work), and the triton-viz project all help, but diagnosing a race condition or a silent NaN in a fused kernel is substantially harder than in PyTorch eager mode.
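When a fused kernel produces a silent NaN or a wrong value, a practical first step is to compare its output against an eager-mode reference and locate the first divergent element. The NumPy helper below is a hypothetical utility for that comparison, not part of Triton or PyTorch; the tolerances are illustrative defaults.

```python
import numpy as np

def first_divergence(out: np.ndarray, ref: np.ndarray, rtol=1e-4, atol=1e-6):
    """Return (index, out_value, ref_value) of the first mismatch, or None.
    A NaN in out where ref is finite counts as a mismatch (hypothetical helper)."""
    ok = np.isclose(out, ref, rtol=rtol, atol=atol, equal_nan=True)
    bad = np.argwhere(~ok)
    if bad.size == 0:
        return None
    idx = tuple(int(i) for i in bad[0])   # first mismatching position in row-major order
    return idx, float(out[idx]), float(ref[idx])

ref = np.linspace(0, 1, 12, dtype=np.float32).reshape(3, 4)
out = ref.copy()
out[1, 2] = np.nan                    # simulate a NaN produced by a buggy kernel
print(first_divergence(out, ref))     # reports index (1, 2)
print(first_divergence(ref, ref))     # None
```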

Portability across vendors. Triton targets NVIDIA and AMD (via ROCm), but the performance-portability story is incomplete. A kernel whose BLOCK sizes were tuned on an A100 may underperform on an AMD MI300X without re-autotuning, and support for Intel GPUs is still experimental.
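One concrete reason a tuned block size need not transfer is that the work per hardware thread depends on the SIMT width. The sketch below assumes that Triton's num_warps launch parameter maps to 32-thread warps on NVIDIA GPUs and to 64-thread wavefronts on AMD CDNA GPUs such as the MI300X; the function is illustrative arithmetic, not a Triton API.

```python
def elems_per_thread(BLOCK: int, num_warps: int, simt_width: int) -> float:
    """Elements of a 1-D tile each hardware thread handles for a launch config."""
    return BLOCK / (num_warps * simt_width)

BLOCK, NUM_WARPS = 1024, 4
print(elems_per_thread(BLOCK, NUM_WARPS, simt_width=32))  # 8.0 on a 32-wide NVIDIA warp
print(elems_per_thread(BLOCK, NUM_WARPS, simt_width=64))  # 4.0 on a 64-wide AMD wavefront
```

Halving the per-thread work changes vector widths, register pressure, and occupancy, which is one reason configurations are usually re-autotuned per target, for example with triton.autotune.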
