FlashAttention and the Memory Wall: Why Attention Was Never Compute-Bound
June 07, 2026 · 23 min read
In 2022, Tri Dao and his collaborators published a paper with a claim that sounded almost too modest to matter: attention could be made up to 3x faster without changing a single output value (Dao et al., 2022, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, arXiv:2205.14135). No approximation, no sparsity pattern, no low-rank trick. The algorithm computed exactly the same softmax attention as the 2017 Transformer, identical up to floating-point reassociation. The entire speedup came from changing where intermediate values lived during the computation. Three years of research into approximate attention had attacked the wrong constraint. Attention was never short on arithmetic. It was drowning in memory traffic.
Why this matters: Every model you run today, through PyTorch's scaled_dot_product_attention, vLLM, or Hugging Face Transformers, almost certainly executes a FlashAttention-family kernel. Understanding why it wins tells you how to reason about any GPU workload: the question is rarely "how many FLOPs?" and almost always "how many bytes?"
TL;DR
- GPU compute has outgrown GPU memory bandwidth for two decades: peak server FLOPs scaled roughly 3.0x every two years while DRAM bandwidth scaled about 1.6x (Gholami et al., 2024, AI and Memory Wall, arXiv:2403.14123). Attention sits on the wrong side of that gap.
- Standard attention materializes an \(N \times N\) score matrix in GPU main memory (HBM), reads it back for softmax, writes the result, and reads it again. For a 4K sequence at head dimension 128 that is over 100 MB of traffic per head for about 4 MB of actual inputs and output.
- FlashAttention fuses the whole computation into one kernel that streams tiles through on-chip SRAM, which on an A100 is roughly 10x faster than HBM but thousands of times smaller (about 20 MB in aggregate against 40-80 GB). The \(N \times N\) matrix never exists in HBM.
- The enabling trick is the online softmax of Milakov and Gimelshein, 2018, arXiv:1805.02867: softmax can be computed incrementally over blocks by carrying a running maximum and a running normalizer.
- The backward pass recomputes attention scores from Q, K, V rather than storing them. Spending extra FLOPs to avoid bytes is a winning trade on modern hardware, and that inversion is the deepest lesson of the paper.
- Each generation chased the hardware: FlashAttention-2 reorganized work partitioning to hit up to 72% model FLOPs utilization on A100 (Dao, 2023, arXiv:2307.08691); FlashAttention-3 used Hopper's asynchronous tensor cores and FP8 to reach about 1.2 PFLOPs/s on H100 (Shah et al., 2024, arXiv:2407.08608).
- Memory complexity drops from \(O(N^2)\) to \(O(N)\), which is what made 32K-128K context windows trainable at all. The attention time is still quadratic in sequence length.
At a Glance
flowchart LR
subgraph STD["Standard attention"]
A1["Q, K, V in HBM"] --> A2["Write S = QKᵀ to HBM"]
A2 --> A3["Read S, softmax, write P"]
A3 --> A4["Read P, multiply by V"]
A4 --> A5["Output"]
end
subgraph FA["FlashAttention"]
B1["Q, K, V in HBM"] --> B2["Stream tiles into SRAM"]
B2 --> B3["Fused matmul + online softmax"]
B3 --> B4["Output written once"]
end
classDef blue fill:#1e40af,stroke:#3b82f6,stroke-width:1px,color:#fff
classDef rose fill:#be123c,stroke:#fb7185,stroke-width:1px,color:#fff
classDef purple fill:#6d28d9,stroke:#a78bfa,stroke-width:1px,color:#fff
classDef teal fill:#0e7490,stroke:#22d3ee,stroke-width:1px,color:#fff
class A1,B1 blue
class A2,A3,A4 rose
class B2,B3 purple
class A5,B4 teal
The red boxes are the problem: every one of them is a full pass over an \(N \times N\) matrix through memory that moves bytes 100x slower than the tensor cores consume them. FlashAttention's purple path keeps those intermediates in on-chip SRAM and touches HBM only for the inputs and the final output.
[IMAGE: Roofline plot for an A100, x-axis arithmetic intensity in FLOPs/byte (log), y-axis attainable TFLOPs/s (log); mark the ridge point near 195 FLOPs/byte, place standard attention (~65 FLOPs/byte) left of the ridge and FlashAttention to the right, with arrows showing the move]
Before FlashAttention
When the Transformer arrived (Vaswani et al., 2017, Attention Is All You Need, arXiv:1706.03762), nobody worried much about the quadratic score matrix; sequences were a few hundred tokens and the matrices were small. As context lengths pushed past 1K, the field's diagnosis was that attention had too many FLOPs and too much memory, both \(O(N^2)\), and that the cure was approximation. The years 2019-2021 produced a zoo of "efficient Transformers": Reformer's locality-sensitive hashing (Kitaev et al., 2020, arXiv:2001.04451), Longformer's sliding windows (Beltagy et al., 2020, arXiv:2004.05150), Linformer's low-rank projections (Wang et al., 2020, arXiv:2006.04768), and Performer's kernel feature maps (Choromanski et al., 2020, arXiv:2009.14794).
Almost none of them displaced dense attention in production models. The reason was awkward: many "linear" attention variants were not actually faster at the sequence lengths people used, because their FLOP savings did not translate into wall-clock savings. The bottleneck was somewhere else.
Two papers saw it early. Ivanov et al. profiled Transformer training and found it memory-bound, cutting data movement by about 23% for a 1.3x speedup on a BERT encoder layer without touching the math (Ivanov et al., 2020, Data Movement Is All You Need, arXiv:2007.00072). And Rabe and Staats showed that self-attention needs only \(O(\log n)\) memory if you process it in chunks, separating the memory question from the compute question (Rabe and Staats, 2021, arXiv:2112.05682). FlashAttention combined the chunked computation with a hardware-conscious kernel design and added the piece nobody had done end to end: an IO-optimal backward pass.
timeline
title From approximation back to exactness
2017 : Transformer, attention on short sequences
2018 : Online softmax shows incremental normalization is possible
2019-2021 : Approximate era, Reformer, Longformer, Linformer, Performer
2020 : Data Movement Is All You Need profiles the real bottleneck
2021 : Rabe and Staats, chunked attention in O(log n) memory
2022 : FlashAttention, exact and IO-aware, 3x GPT-2 speedup
2023 : FlashAttention-2, work partitioning, 72% MFU on A100
2024 : FlashAttention-3, Hopper asynchrony and FP8, ~1.2 PFLOPs/s
How FlashAttention Actually Works
The memory hierarchy is the algorithm's real input
A GPU is not one memory and one processor. An A100 exposes 40-80 GB of HBM2e at 1.6-2.0 TB/s (NVIDIA A100 datasheet), but each of its 108 streaming multiprocessors also has 192 KB of on-chip SRAM with an aggregate bandwidth around 19 TB/s, roughly an order of magnitude faster than HBM (figures from the FlashAttention paper's hardware characterization). The compute units are faster still: 312 TFLOPs/s of dense FP16/BF16 tensor-core throughput.
graph TD
HOST["Host DRAM, ~25 GB/s over PCIe"] --> HBM["HBM2e: 40-80 GB at ~2 TB/s"]
HBM --> L2["L2 cache: 40 MB"]
L2 --> SRAM["SM SRAM: 192 KB x 108 SMs, ~19 TB/s"]
SRAM --> REG["Registers, feeding 312 TFLOPs/s of tensor cores"]
classDef slate fill:#334155,stroke:#64748b,stroke-width:1px,color:#e2e8f0
classDef blue fill:#1e40af,stroke:#3b82f6,stroke-width:1px,color:#fff
classDef purple fill:#6d28d9,stroke:#a78bfa,stroke-width:1px,color:#fff
classDef emerald fill:#047857,stroke:#34d399,stroke-width:1px,color:#fff
class HOST slate
class HBM blue
class L2 purple
class SRAM,REG emerald
Divide the numbers and you get the constraint that governs everything: 312 TFLOPs/s against 1.6 TB/s means a kernel must perform roughly 195 floating-point operations per byte of HBM traffic to keep the tensor cores busy. Below that arithmetic intensity, the chip waits on memory no matter how clever the math is.
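The division above is the roofline model in miniature. A minimal sketch, using the A100 figures quoted in this section; the function names are illustrative, not from any library:

```python
def ridge_point(peak_flops: float, bandwidth_bytes: float) -> float:
    """Arithmetic intensity (FLOPs/byte) above which a kernel is compute-bound."""
    return peak_flops / bandwidth_bytes


def attainable_flops(intensity: float, peak_flops: float, bandwidth_bytes: float) -> float:
    """Roofline: throughput is capped by compute or by memory, whichever is lower."""
    return min(peak_flops, intensity * bandwidth_bytes)


A100_PEAK = 312e12  # dense FP16/BF16 tensor-core FLOPs/s, as quoted above
A100_HBM = 1.6e12   # bytes/s, the lower HBM bandwidth figure quoted above

ridge = ridge_point(A100_PEAK, A100_HBM)
print(f"ridge point: {ridge:.0f} FLOPs/byte")
for intensity in (65, 195, 400):
    tflops = attainable_flops(intensity, A100_PEAK, A100_HBM) / 1e12
    print(f"{intensity:>4} FLOPs/byte -> at most {tflops:.0f} TFLOPs/s")
```

A kernel at 65 FLOPs/byte tops out near 104 TFLOPs/s under this model, a third of the peak, however well its inner loops are written.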
Standard attention is far below it. The computation
\[\text{Attention}(Q,K,V)=\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]
is implemented as three separate kernels: a matmul that writes \(S = QK^\top\) to HBM, a softmax that reads \(S\) and writes \(P\), and a second matmul that reads \(P\). The matmuls individually are compute-friendly, but the softmax and the round trips between kernels are pure memory traffic over \(N^2\) values. The aggregate arithmetic intensity lands around 60-70 FLOPs per byte for typical head dimensions (a back-of-envelope figure; the worked example below derives it), a third of what the hardware needs.
[IMAGE: Annotated GPU die schematic showing one SM enlarged, with the 192 KB SRAM block highlighted and arrows to HBM stacks at the die edge, labeled with the 19 TB/s vs 2 TB/s bandwidths]
Tiling: never materialize the matrix
The fix is conceptually old (it is how every fast matrix multiply has worked since the 1990s): split the inputs into blocks small enough to live in SRAM, and do all the work on a block while it is resident. FlashAttention splits \(Q\) into row blocks and \(K, V\) into blocks that each cover a band of columns of the score matrix. For each pair of blocks it computes a small tile of scores, applies softmax logic, and accumulates the contribution to the output, all inside one kernel, all in SRAM.
The obstacle that kept attention from being tiled for five years is the softmax denominator. Softmax over a row needs the row's maximum and the sum of exponentials across all \(N\) columns, but a tile only sees a slice of the row. You appear to need the whole row before you can normalize anything.
Online softmax: the trick that unlocks fusion
Milakov and Gimelshein showed in 2018 that softmax can be computed in a single streaming pass by maintaining two running statistics per row: the maximum \(m\) seen so far and the normalizer \(\ell\) rescaled to that maximum (arXiv:1805.02867). When a new block arrives with block-local statistics \(\tilde{m}\) and \(\tilde{\ell}\), the running values update as:
\[m^{(\text{new})} = \max\big(m, \tilde{m}\big), \qquad \ell^{(\text{new})} = e^{\,m - m^{(\text{new})}}\,\ell + e^{\,\tilde{m} - m^{(\text{new})}}\,\tilde{\ell}\]
FlashAttention extends the same rescaling to the output accumulator: the partial output computed under an old maximum is multiplied by \(e^{\,m - m^{(\text{new})}}\) before the new block's contribution is added. Every intermediate stays exact; the final result equals the monolithic softmax up to floating-point rounding. This is why FlashAttention is exact attention, not an approximation.
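The update rule is short enough to check by hand. A minimal sketch in plain Python, where the block size and the test row are arbitrary choices for illustration; it streams over one row in blocks and compares the result with an ordinary two-pass softmax:

```python
import math


def online_softmax_stats(row, block_size):
    """Stream over `row` in blocks, carrying the running max m and normalizer l."""
    m, l = -math.inf, 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_blk = max(block)                               # block-local maximum
        l_blk = sum(math.exp(x - m_blk) for x in block)  # block-local normalizer
        m_new = max(m, m_blk)
        # Rescale both normalizers to the new maximum before adding them.
        l = math.exp(m - m_new) * l + math.exp(m_blk - m_new) * l_blk
        m = m_new
    return m, l


row = [1.0, 3.0, -2.0, 8.0, 0.5, 7.5, 2.0]
m, l = online_softmax_stats(row, block_size=3)
streamed = [math.exp(x - m) / l for x in row]

# Reference: the usual two-pass softmax that sees the whole row at once.
m_ref = max(row)
l_ref = sum(math.exp(x - m_ref) for x in row)
reference = [math.exp(x - m_ref) / l_ref for x in row]

print(max(abs(a - b) for a, b in zip(streamed, reference)))
```

The printed difference should sit at the level of floating-point rounding, which is the sense in which the streamed result is exact.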
The payoff in IO terms is dramatic. Standard attention performs \(\Theta(Nd + N^2)\) HBM accesses. FlashAttention performs \(\Theta(N^2 d^2 M^{-1})\), where \(M\) is the SRAM size; for typical \(d\) of 64-128 and \(M\) of about 100K elements, that is many times fewer accesses, and the paper proves no exact-attention algorithm can do asymptotically better over a range of SRAM sizes (Dao et al., 2022).
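To get a feel for the two bounds, they can be evaluated with the constants dropped. The sketch below is only that: asymptotic expressions plugged in as if they were exact counts, with \(M\) set to the 100K elements mentioned above. The ratios are indicative, not measured.

```python
def standard_hbm_accesses(n: int, d: int) -> int:
    """Theta(N*d + N^2), constants dropped."""
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, sram_elems: int) -> float:
    """Theta(N^2 * d^2 / M), constants dropped."""
    return n * n * d * d / sram_elems


M = 100_000  # SRAM capacity in elements, the rough figure used in the text

for n in (1024, 4096, 16384):
    for d in (64, 128):
        ratio = standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, M)
        print(f"N={n:>5}  d={d:>3}  standard/flash ~ {ratio:.1f}x")
```

For large \(N\) the ratio tends to \(M / d^2\), which is why smaller head dimensions and larger SRAM both widen the gap.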
The backward pass: pay FLOPs, save bytes
Training needs gradients, and the gradient of attention wants the \(N \times N\) probability matrix \(P\) that the forward pass deliberately never stored. The standard solution would be to store it, which reinstates the \(O(N^2)\) memory and the HBM traffic. FlashAttention instead stores only the per-row statistics \((m, \ell)\), which are \(O(N)\), and recomputes the score tiles from \(Q\) and \(K\) during the backward pass.
That is more arithmetic, roughly a third more FLOPs for the attention layer. It is still faster, because the recomputation happens in SRAM at tensor-core speed while the alternative is reading tens of megabytes per head back through HBM. On a machine where compute outruns bandwidth by 100x, FLOPs are the cheap currency and bytes are the expensive one.
flowchart TD
A["Backward pass needs the score matrix"] --> B{"Was it stored in HBM?"}
B -- "Standard: yes" --> C["Read two N x N matrices back from HBM"]
C --> D["Memory-bound, slow despite fewer FLOPs"]
B -- "FlashAttention: no" --> E["Recompute tiles in SRAM from Q and K"]
E --> F["~33% more FLOPs, far fewer bytes, net faster"]
classDef slate fill:#334155,stroke:#64748b,stroke-width:1px,color:#e2e8f0
classDef rose fill:#be123c,stroke:#fb7185,stroke-width:1px,color:#fff
classDef emerald fill:#047857,stroke:#34d399,stroke-width:1px,color:#fff
class A,B slate
class C,D rose
class E,F emerald
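The recomputation itself is small. A minimal single-row sketch in plain Python, with made-up inputs and hypothetical helper names: the forward pass keeps only \((m, \ell)\), and the backward pass rebuilds each probability tile from the query, a slice of the keys, and those two scalars.

```python
import math


def score_row(q, keys):
    """Scaled dot-product scores of one query against a list of keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]


def forward_row_stats(q, keys):
    """What the forward pass keeps for the backward pass: two scalars per row."""
    s = score_row(q, keys)
    m = max(s)
    l = sum(math.exp(x - m) for x in s)
    return m, l


def recompute_prob_tile(q, key_tile, m, l):
    """Backward pass: rebuild one tile of P from Q, K and the saved (m, l)."""
    return [math.exp(x - m) / l for x in score_row(q, key_tile)]


q = [0.5, -1.0, 2.0, 0.25]
keys = [[0.1 * (i + j) - 0.3 for j in range(4)] for i in range(6)]
m, l = forward_row_stats(q, keys)

# Rebuild the row two keys at a time, as a tiled backward pass would.
rebuilt = []
for j0 in range(0, len(keys), 2):
    rebuilt += recompute_prob_tile(q, keys[j0:j0 + 2], m, l)

# What standard attention would have stored in HBM instead.
stored = [math.exp(x - m) / l for x in score_row(q, keys)]
print(max(abs(a - b) for a, b in zip(rebuilt, stored)))
```

Each tile needs one extra matmul-shaped computation (the scores) and nothing from HBM beyond inputs the backward pass was going to read anyway.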
FlashAttention-2 and 3: chasing the hardware
The original kernel left throughput on the table: it reached 30-50% of peak on A100 because too much time went into non-matmul operations and because work was parallelized only over batch and heads, starving the GPU at small batch sizes and long sequences. FlashAttention-2 restructured the loops (each thread block now owns a block of queries and iterates over keys), parallelized additionally over the sequence dimension, and minimized shared-memory shuffles between warps. The result was about 2x over the original, 50-73% of theoretical peak in the attention kernel, and 225 TFLOPs/s (72% model FLOPs utilization) in end-to-end GPT-style training (Dao, 2023, arXiv:2307.08691).
Hopper moved the goalposts again. An H100 SXM offers roughly 989 TFLOPs/s of dense BF16 and 3.35 TB/s of HBM3 (NVIDIA datasheet figures), but extracting it requires using the chip's asynchronous machinery: the Tensor Memory Accelerator for bulk copies and warpgroup-level matmuls that run while other warps compute. FlashAttention-2, written for Ampere, hit only about 35% utilization on H100. FlashAttention-3 was redesigned around asynchrony, overlapping softmax with matmuls through warp specialization and pipelining, and added block-quantized FP8 with incoherent processing, reaching about 740 TFLOPs/s in FP16 (75% utilization) and close to 1.2 PFLOPs/s in FP8, with 2.6x lower numerical error than a baseline FP8 attention (Shah et al., 2024, arXiv:2407.08608).
The pattern across the three versions is worth noticing: the algorithm barely changed. Online softmax, tiling, and recomputation are identical in all three. What changed is the mapping of that algorithm onto each generation's execution model. Kernel engineering is now co-evolving with silicon, version by version.
[IMAGE: Three-panel diagram of FlashAttention v1/v2/v3 loop structures: v1 outer loop over K/V blocks, v2 outer loop over Q blocks with sequence parallelism, v3 warp-specialized producer/consumer pipeline with TMA copies overlapping matmuls]
Seeing It in Motion
The forward pass for one block of queries, as a conversation between memory levels:
sequenceDiagram
participant HBM as HBM (global memory)
participant SRAM as SM SRAM (192 KB)
participant TC as Tensor cores
HBM->>SRAM: Load Q block (once)
loop For each K, V block j
HBM->>SRAM: Load K_j and V_j tiles
SRAM->>TC: Compute tile S = Q K_jᵀ
TC->>SRAM: Update running max m, normalizer l
TC->>SRAM: Rescale accumulator, add P_j V_j
end
SRAM->>HBM: Write output block (once)
Note over HBM,SRAM: The N x N score matrix never touches HBM
Each query block reads every key and value block exactly once, updates its running statistics, and writes its output exactly once. The quadratic object exists only as a sequence of small tiles, each living for microseconds in SRAM.
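The loop in the diagram can be written out in a few lines. This is a plain-Python sketch of the idea, not the CUDA kernel: one query row stands in for a query block, the block size is arbitrary, and lists stand in for SRAM tiles. It checks the streamed result against a reference that materializes every score row.

```python
import math


def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, then weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Stream K/V in blocks, keeping only a running max, normalizer and accumulator."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for q in Q:
        m, l, acc = -math.inf, 0.0, [0.0] * dv
        for j0 in range(0, len(K), block):
            Kj, Vj = K[j0:j0 + block], V[j0:j0 + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kj]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)  # rescales everything accumulated so far
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(p[i] * Vj[i][c] for i in range(len(Vj)))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, at the end
    return out


Q = [[0.1 * (i + j) for j in range(4)] for i in range(3)]
K = [[0.2 * (i - j) for j in range(4)] for i in range(5)]
V = [[float(2 * i + j) for j in range(2)] for i in range(5)]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=2)
print(max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t)))
```

Note that the tiled version never holds more than `block` scores at a time; the full row of five scores exists only in the reference.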
[IMAGE: Animated-style grid showing the N x N score matrix divided into tiles, with one row-band of tiles lighting up in sequence as the inner loop progresses, and the running (m, l) statistics updating beside it]
By the Numbers
The bandwidth hierarchy that motivates everything (A100 SXM, 80 GB):
| Memory level | Size | Bandwidth | Relative speed |
|---|---|---|---|
| SM SRAM (aggregate) | ~20 MB | ~19 TB/s | ~10x HBM |
| HBM2e | 80 GB | ~2.0 TB/s | baseline |
| Host DRAM over PCIe 4.0 | TBs | ~25 GB/s | ~1/80x HBM |
Sources: NVIDIA A100 datasheet; SRAM figures from Dao et al., 2022. The structural trend behind the gap: peak hardware FLOPs scaled ~3.0x per two years over two decades while DRAM bandwidth scaled ~1.6x (Gholami et al., 2024).
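Compounding makes the two growth rates look very different. A quick sketch, assuming both rates held steadily for the full twenty years (a simplification of the trend lines, not a claim from the paper):

```python
def growth(factor_per_two_years: float, years: float) -> float:
    """Total growth after `years` at a fixed factor per two-year period."""
    return factor_per_two_years ** (years / 2)


flops_growth = growth(3.0, 20)
bandwidth_growth = growth(1.6, 20)
gap = flops_growth / bandwidth_growth

print(f"FLOPs: {flops_growth:,.0f}x   bandwidth: {bandwidth_growth:,.0f}x   gap: {gap:,.0f}x")
```

Under that assumption compute grows by a factor in the tens of thousands while bandwidth grows by roughly a hundred, so the ratio between them widens by several hundred times.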
Measured gains across the FlashAttention generations:
| Version | Hardware | Headline result | Utilization |
|---|---|---|---|
| FlashAttention (2022) | A100 | 15% end-to-end over the BERT-large MLPerf 1.1 record; 3x GPT-2 (seq 1K); 2.4x Long Range Arena | 30-50% of peak (attention kernel) |
| FlashAttention-2 (2023) | A100 | ~2x over v1; 225 TFLOPs/s end-to-end GPT training | up to 72% MFU |
| FlashAttention-3 (2024) | H100 | 1.5-2x over v2; ~740 TFLOPs/s FP16; ~1.2 PFLOPs/s FP8 | ~75% (FP16 kernel) |
All figures from the respective papers (arXiv:2205.14135, arXiv:2307.08691, arXiv:2407.08608).
Complexity, the part that survives hardware churn:
| Property | Standard attention | FlashAttention |
|---|---|---|
| Time (FLOPs) | \(O(N^2 d)\) | \(O(N^2 d)\), ~1.3x in training due to recomputation |
| Extra memory | \(O(N^2)\) | \(O(N)\) |
| HBM accesses | \(\Theta(Nd + N^2)\) | \(\Theta(N^2 d^2 M^{-1})\) |
The memory column is what changed practice. Quadratic activation memory made a 64K context physically impossible to train on a single accelerator; linear memory made context length a budget item instead of a hard wall.
[IMAGE: Log-scale chart of peak FLOPs vs DRAM bandwidth for server hardware 2003-2024, two diverging trend lines (3.0x vs 1.6x per two years), with the widening gap shaded and labeled "the memory wall"]
A Concrete Example
Take one attention head in a 4K-context model: \(N = 4096\), head dimension \(d = 128\), FP16 everywhere. All figures below are back-of-envelope, which is exactly how a kernel engineer would first size this problem.
Inputs. \(Q\), \(K\), \(V\) are each \(4096 \times 128\) FP16 matrices: 1 MB apiece, 3 MB total, plus 1 MB for the output.
Standard attention traffic. The score matrix \(S\) is \(4096 \times 4096\): 32 MB in FP16. Kernel one writes it (32 MB). Kernel two reads it and writes the probability matrix \(P\) (64 MB). Kernel three reads \(P\) (32 MB). With inputs and output, total HBM traffic is roughly 132 MB to process 4 MB of real data; 97% of the bytes are intermediate scaffolding.
Arithmetic. Two matmuls of \(2N^2d\) FLOPs each: about 8.6 GFLOPs. At the A100's 312 TFLOPs/s, the math takes ~28 microseconds. Moving 132 MB at 1.6 TB/s takes ~82 microseconds. The tensor cores idle two-thirds of the time; arithmetic intensity is 8.6 GFLOPs / 132 MB, about 65 FLOPs per byte, against the ~195 the chip needs.
FlashAttention traffic. With 192 KB of SRAM per SM, the kernel can hold a \(128 \times 128\) query tile (32 KB), matching key and value tiles (32 KB each), and its accumulators comfortably. Each query block streams all key/value blocks once; total HBM traffic is the inputs, the output, and the \(O(N)\) statistics, on the order of 5-10 MB. That is a >13x reduction in bytes for the same 8.6 GFLOPs, pushing arithmetic intensity past the ridge point. Memory time drops to a few microseconds and the kernel becomes compute-bound, which is the goal state. The paper's measured numbers agree in spirit: up to 9x fewer HBM accesses on GPT-2, yielding 7.6x speedup on the attention computation itself (Dao et al., 2022).
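The same sizing fits in a short script. It treats 1 MB as \(2^{20}\) bytes, so the exact figures land a few percent away from the rounded ones in the prose, and it uses 10 MB as a deliberately pessimistic stand-in for FlashAttention's traffic (the upper end of the range above, not a measured value).

```python
N, D, FP16 = 4096, 128, 2       # sequence length, head dimension, bytes per value
MIB = 2 ** 20
PEAK, HBM_BW = 312e12, 1.6e12   # A100 figures used in this example

io_bytes = 4 * N * D * FP16     # Q, K, V read once; output written once
score_bytes = N * N * FP16      # one N x N matrix in FP16
# Standard path: write S, read S, write P, read P.
standard_bytes = io_bytes + 4 * score_bytes
flops = 2 * (2 * N * N * D)     # two matmuls of 2*N^2*d FLOPs each

compute_us = flops / PEAK * 1e6
memory_us = standard_bytes / HBM_BW * 1e6
intensity = flops / standard_bytes

flash_bytes = 10 * MIB          # assumed upper estimate from the text
flash_intensity = flops / flash_bytes

print(f"standard traffic: {standard_bytes / MIB:.0f} MiB")
print(f"compute time: {compute_us:.1f} us, memory time: {memory_us:.1f} us")
print(f"standard intensity: {intensity:.0f} FLOPs/byte")
print(f"flash intensity (10 MiB assumed): {flash_intensity:.0f} FLOPs/byte")
print(f"byte reduction: {standard_bytes / flash_bytes:.1f}x")
```

Even with the pessimistic 10 MB figure, the arithmetic intensity comes out several times above the ~195 FLOPs/byte ridge point.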
The memory cliff. At 4K context the standard approach also needs that 32 MB score matrix per head, per layer, held for the backward pass. A 32-head, 32-layer model would need ~32 GB just for attention scores at batch size 1. FlashAttention stores 4096 row statistics per head instead: at one FP32 value per row that is 16 KB per head, about 0.5 MB per layer and 16 MB across the whole model. That single change is why long-context training stopped being exotic.
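The cliff gets steeper with context length. A small sketch of the two footprints, assuming FP16 scores and a single FP32 statistic per row (an assumption about the storage format; storing two statistics per row would double the second column):

```python
def score_matrix_bytes(n, heads, layers, bytes_per_elem=2):
    """Activation memory if every head in every layer keeps its N x N matrix."""
    return n * n * bytes_per_elem * heads * layers


def row_stat_bytes(n, heads, layers, bytes_per_stat=4):
    """Activation memory if each row keeps one scalar statistic instead."""
    return n * bytes_per_stat * heads * layers


GIB, MIB = 2 ** 30, 2 ** 20
HEADS, LAYERS = 32, 32

for n in (4096, 16384, 65536):
    scores = score_matrix_bytes(n, HEADS, LAYERS) / GIB
    stats = row_stat_bytes(n, HEADS, LAYERS) / MIB
    print(f"N={n:>6}: score matrices {scores:>8,.0f} GiB   row statistics {stats:>4.0f} MiB")
```

At 4K the gap is 32 GiB against 16 MiB; quadrupling the context multiplies the first number by sixteen and the second by four.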
Where It Breaks
FlashAttention is not a free lunch, and knowing its failure modes is more useful than knowing its benchmarks.
It does not fix quadratic time. The FLOPs are still \(O(N^2 d)\). FlashAttention moved the crossover point where attention dominates a model's runtime, but a 1M-token context still pays a million-squared price in arithmetic. Long-context methods like ring attention, sliding windows, and sparse attention exist precisely because IO-awareness alone runs out.
Decode inference is a different problem. During autoregressive generation, each new token attends to the whole KV cache with a single query. There is no \(N \times N\) matrix to avoid; the workload is reading the cache itself, and it is memory-bound in a way no tiling can fix. This is why decode-specific kernels (Flash-Decoding, PagedAttention-style cache management in vLLM) exist as separate engineering efforts.
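The same FLOPs-per-byte arithmetic shows why. A back-of-envelope sketch for one head and one new token, assuming an FP16 KV cache read once per step and ignoring the query and output, which are negligible next to the cache:

```python
def decode_intensity(n_cached: int, d: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for one query attending to a cache of n_cached tokens."""
    flops = 2 * (2 * n_cached * d)                  # q @ K^T, then p @ V
    bytes_read = 2 * n_cached * d * bytes_per_elem  # K cache and V cache
    return flops / bytes_read


RIDGE = 312e12 / 1.6e12  # the A100 ridge point derived earlier

for n in (1024, 32768):
    print(f"N={n:>6}: {decode_intensity(n, 128):.1f} FLOPs/byte (ridge ~{RIDGE:.0f})")
```

The intensity is about 1 FLOP per byte regardless of cache length, two orders of magnitude below the ridge point, so the only levers are reading fewer bytes or sharing each read across more queries.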
Portability is a real cost. Each version is hand-tuned to one vendor's execution model: v2 to Ampere, v3 to Hopper's TMA and warpgroup matmuls. AMD and Intel GPUs need reimplementations (vLLM ships a Triton backend partly for this reason), and every new NVIDIA generation resets some of the work. The algorithm is portable; the performance is not.
Numerics are equivalent, not bit-identical. Online softmax reorders floating-point operations, so outputs differ from the naive kernel at the level of rounding. For training this is noise. For bit-exact reproducibility across kernel versions, it is an audit headache that has bitten more than one team's regression suite.
Small sequences gain little. Below a few hundred tokens the score matrix fits in cache anyway, kernel-launch overheads dominate, and the fused kernel's advantage shrinks toward zero. The win grows with \(N\); it was never about making short attention fast.
Alternative Designs
| Approach | Strengths | Weaknesses | Best when |
|---|---|---|---|
| FlashAttention (exact, IO-aware) | Exact outputs; linear memory; production-proven | Still \(O(N^2)\) time; per-architecture tuning | Dense attention up to ~128K tokens |
| Sparse / windowed (Longformer, BigBird) | Sub-quadratic FLOPs; good for structured locality | Quality depends on the sparsity pattern matching the task | Documents with strong local structure |
| Low-rank / kernel (Linformer, Performer) | Linear time in theory | Approximation error; often slower in wall-clock at moderate N | Very long sequences where exactness is negotiable |
| Chunked exact (Rabe-Staats) | Simple; framework-level; \(O(\log n)\) memory | No IO-optimal backward; slower than fused kernels | Memory relief without custom kernels |
| Architectural exits (Mamba/SSMs, linear attention) | \(O(N)\) time and memory by construction | Different model class; quality tradeoffs still being mapped | When you can choose the architecture, not just the kernel |
The honest comparison is that FlashAttention won the kernel argument so decisively that the approximate-attention literature largely pivoted: the interesting alternatives today are architectural (state-space models, hybrid attention layouts) rather than approximations of softmax attention, because approximating a now-fast exact kernel buys little.
How It Is Used in Practice
Adoption is the quiet kind: invisible and near-total. PyTorch integrated FlashAttention-2 into scaled_dot_product_attention as of version 2.2, where it is selected automatically when inputs qualify. vLLM has used FlashAttention-family kernels since v0.1.4 and exposes attention backends (FlashAttention, FlashInfer, Triton) per hardware target. Hugging Face Transformers enables it with attn_implementation="flash_attention_2", and Meta lists FlashAttention-3 among its research deployments. The original repository (Dao-AILab/flash-attention) remains the reference implementation.
The operational consequences for an inference platform: prefill (processing the prompt) is where FlashAttention earns its keep, since that phase has real \(N \times N\) work to eliminate; decode throughput is governed by KV-cache bandwidth and benefits more from paging and batching strategies; and kernel choice is now a deployment decision (a config flag per GPU generation) rather than a modeling decision. Teams running mixed fleets of A100s and H100s routinely run different kernel versions for the same model weights.
Insights Worth Remembering
- The FLOP count of an algorithm tells you almost nothing about its speed on a modern accelerator; the byte count usually tells you everything. Count memory accesses first.
- "Exact but IO-aware" beat three years of "approximate but FLOP-light" research. When a field converges on attacking one constraint, check whether the binding constraint is actually a different one.
- Recomputation is a trade, not a sin. On hardware where compute outruns bandwidth 100x, spending 33% more FLOPs to avoid touching HBM is buying speed with a depreciating currency.
- The online softmax sat in the literature for four years before anyone used it to fuse attention. The bottleneck was not mathematics; it was someone reading a numerics paper with a kernel engineer's eyes.
- FlashAttention's memory result (\(O(N^2) \to O(N)\)) mattered more than its speed result. Speed made training cheaper; linear memory made long context possible.
- An algorithm's asymptotic story can be unchanged while its systems story changes completely. v1, v2, and v3 are the same algorithm with three different relationships to silicon.
- The memory wall is not an attention problem. KV-cache reads in decoding, optimizer state movement, expert routing in MoE: the same FLOPs-vs-bytes analysis decides performance in each. FlashAttention is the worked example for a general method.
Open Questions
Does IO-awareness become a compiler's job? Hand-written kernels chase each hardware generation at considerable engineering cost. Triton, PyTorch's FlexAttention, and ongoing compiler research aim to generate fused, tiled attention variants automatically. Reported benchmarks suggest compiled kernels can approach hand-tuned performance for many attention variants; whether they can absorb each new generation's asynchronous machinery (TMA-style engines, warp specialization) without human kernel work remains open.
How far does FP8 (and below) attention go? FlashAttention-3 showed block quantization and incoherent processing hold FP8 error to 2.6x below a baseline FP8 attention (Shah et al., 2024). Whether 4-bit attention computation (not just 4-bit weights) can be made training-safe is an active question, and the answer likely differs for prefill and decode.
Does the memory wall bend? The FLOPs-vs-bandwidth gap that created this entire line of work is structural (Gholami et al., 2024). HBM4 and processing-in-memory designs attack it from the hardware side; if bandwidth ever catches up, IO-aware algorithm design loses some of its premium. Betting on that reversal has been wrong for twenty years, which is evidence, though not proof, that it stays wrong.
Attention versus its successors. State-space models and hybrid architectures promise \(O(N)\) time, not just \(O(N)\) memory. Whether they displace dense attention or settle into hybrid layouts alongside it is unresolved; what FlashAttention guarantees is that they must beat a far faster baseline than the one they were originally pitched against.
Sources and Further Reading
Foundational Papers
- Dao, Fu, Ermon, Rudra, Ré, 2022, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, arXiv:2205.14135
- Milakov, Gimelshein, 2018, Online Normalizer Calculation for Softmax, arXiv:1805.02867
- Vaswani et al., 2017, Attention Is All You Need, arXiv:1706.03762
- Rabe, Staats, 2021, Self-attention Does Not Need O(n²) Memory, arXiv:2112.05682
Important Follow-up Work
- Dao, 2023, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, arXiv:2307.08691
- Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao, 2024, FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision, arXiv:2407.08608
- Ivanov et al., 2020, Data Movement Is All You Need: A Case Study on Optimizing Transformers, arXiv:2007.00072
- Gholami, Yao, Kim, Hooper, Mahoney, Keutzer, 2024, AI and Memory Wall, arXiv:2403.14123
The Approximate-Attention Era
- Kitaev, Kaiser, Levskaya, 2020, Reformer: The Efficient Transformer, arXiv:2001.04451
- Beltagy, Peters, Cohan, 2020, Longformer: The Long-Document Transformer, arXiv:2004.05150
- Wang et al., 2020, Linformer: Self-Attention with Linear Complexity, arXiv:2006.04768
- Choromanski et al., 2020, Rethinking Attention with Performers, arXiv:2009.14794
Technical Resources
- Dao-AILab/flash-attention, reference implementation
- NVIDIA A100 Tensor Core GPU datasheet
- Colfax Research, FlashAttention-3 deep dive
[IMAGE: Side-by-side memory footprint bar chart at N = 4K/16K/64K/128K: standard attention's quadratic activation memory per layer vs FlashAttention's linear statistics, log y-axis, with the single-GPU 80 GB line marked]
[IMAGE: Stacked-area timing breakdown of one transformer layer forward pass before and after FlashAttention: matmul time vs memory-stall time, showing the stall region collapsing]
[IMAGE: Diagram of the online softmax rescaling step: two tiles with local maxima m1 and m2, showing the accumulator being multiplied by exp(m1 - m_new) when the larger maximum arrives]