Part 7 of 8

The Gradient Descent through Transformers


Attention Part 4 — Flash Attention: Making GPUs Actually Work

May 13, 2026 · 33 min read

This is Part 7 of The Gradient Descent through Transformers — a series where I walk through every component of the modern transformer stack, how it evolved from 2017 to 2026, and why each piece matters.

Previously: Attention Part 3 — The Sparsity Lineage


A Different Kind of Optimization

In Parts 2 and 3, we changed what attention computes — sharing KV heads, compressing into latents, skipping sparse connections. Every technique modified the attention algorithm itself.

Flash Attention does something fundamentally different: it computes the exact same result as standard attention — same QK^T, same softmax, same weighted sum of values — but rearranges how the GPU executes it. No approximation, no sparsity, no architectural change. Just a deeper understanding of hardware.

The result: 2-4× faster attention, 5-20× less memory, and the ability to scale to much longer sequences. This single optimization is why most models today can use full dense attention at 32-128K context lengths without needing the sparse techniques from Part 3.

To understand why this works, we need to understand why naive attention is slow in the first place. And that requires understanding how GPUs actually move data.


The GPU Memory Hierarchy

Every explanation of Flash Attention starts here, because this is the root cause of everything.

A modern GPU (like an A100 or H100) has two main memory levels:

HBM (High Bandwidth Memory) — Large but Slow

  • Capacity: 40-80 GB
  • Bandwidth: ~2 TB/s (A100), ~3.35 TB/s (H100)
  • Role: stores all your tensors — weights, activations, KV cache, everything

This is what people usually mean by "GPU memory." It's large enough to hold your model and activations, and its bandwidth sounds fast — 2 terabytes per second!

SRAM (On-Chip Memory) — Small but Fast

  • Capacity: ~20 MB total (distributed across streaming multiprocessors)
  • Bandwidth: ~19 TB/s
  • Role: the actual workspace where computation happens

SRAM is where the tensor cores do their matrix multiplies. It's roughly 10× faster than HBM, but 4000× smaller.

The Implication

Every arithmetic operation requires its operands to be in SRAM. So the lifecycle of a computation looks like:

  1. Load data from HBM into SRAM
  2. Compute on data in SRAM (matrix multiply, add, etc.)
  3. Store result back from SRAM to HBM

If the computation is simple (like an element-wise add) but the data is large, you spend most of your time on steps 1 and 3 — just moving bytes between HBM and SRAM. The tensor cores sit idle, waiting for data.

This is the difference between:

  • Compute-bound operations: the arithmetic takes longer than the data movement (e.g., large matrix multiplies with high arithmetic intensity)
  • Memory-bound operations: the data movement takes longer than the arithmetic (e.g., element-wise operations, reductions, softmax)
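
To put rough numbers on this distinction, here's a quick back-of-the-envelope sketch in Python. The bandwidth and FLOP figures are illustrative A100-ish assumptions, not measured values:

# Rough arithmetic-intensity comparison: element-wise add vs. matrix multiply.
# Hardware numbers are illustrative A100-ish figures, not measurements.
HBM_BANDWIDTH = 2e12      # bytes/s  (~2 TB/s)
PEAK_FLOPS    = 312e12    # FLOP/s   (FP16 tensor cores, dense)

def intensity_elementwise_add(n, dtype_bytes=2):
    flops = n                              # one add per element
    bytes_moved = 3 * n * dtype_bytes      # read two inputs, write one output
    return flops / bytes_moved             # FLOPs per byte

def intensity_matmul(n, dtype_bytes=2):
    flops = 2 * n**3                       # n^3 multiply-adds
    bytes_moved = 3 * n * n * dtype_bytes  # read A, B; write C (assuming ideal reuse)
    return flops / bytes_moved

# The GPU needs roughly PEAK_FLOPS / HBM_BANDWIDTH ≈ 156 FLOPs per byte to stay busy.
balance_point = PEAK_FLOPS / HBM_BANDWIDTH
print(f"balance point:    {balance_point:.0f} FLOPs/byte")
print(f"element-wise add: {intensity_elementwise_add(4096**2):.2f} FLOPs/byte  (memory-bound)")
print(f"4096x4096 matmul: {intensity_matmul(4096):.0f} FLOPs/byte  (compute-bound)")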

Why Naive Attention Is Memory-Bound

Here's how standard attention is typically implemented (in PyTorch, cuDNN, etc.):

# Step 1: Compute QK^T scores
S = Q @ K.T / math.sqrt(d_k)      # Shape: (N, N) — WRITTEN to HBM

# Step 2: Apply softmax
P = torch.softmax(S, dim=-1)      # Shape: (N, N) — READ from HBM, WRITTEN to HBM

# Step 3: Multiply by values
O = P @ V                         # Shape: (N, d) — READ from HBM

Count the HBM round-trips for the N×N attention matrix:

Step                HBM Reads          HBM Writes
S = QK^T            Read Q, K          Write S (N×N)
P = softmax(S)      Read S (N×N)       Write P (N×N)
O = PV              Read P (N×N)       Write O

The N×N matrix gets written to HBM, then read back, then written again, then read back again. Four passes over an N^2-sized tensor through the memory bus.

For N = 8192 (a modest context length) with FP16:

  • The attention matrix is 8192 × 8192 × 2 bytes = 128 MB per head
  • With 32 heads: 4 GB of intermediate storage
  • Each byte moves through the HBM bus multiple times

The actual matrix multiplies (QK^T and PV) are fast — they have high arithmetic intensity. But the softmax in between forces us to materialize the entire N×N matrix in HBM, read it back, normalize it, write it back, and then read it again.
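
A rough sanity check of that claim, using the same kind of assumed A100-ish figures: even counting only the traffic for the score matrix, the memory time exceeds the time for both matmuls.

# Rough estimate of where naive attention spends its time for one head at N = 8192.
# Bandwidth/FLOP figures are illustrative A100-ish assumptions, not measurements.
N, d, dtype_bytes = 8192, 128, 2          # sequence length, head dim, FP16
HBM_BANDWIDTH = 2e12                      # bytes/s
PEAK_FLOPS    = 312e12                    # FLOP/s

attn_matrix_bytes = N * N * dtype_bytes   # the N×N score matrix: ~134 MB
hbm_traffic = 4 * attn_matrix_bytes       # written, read, written, read again

matmul_flops = 2 * (2 * N * N * d)        # QK^T plus PV, ~2·N^2·d FLOPs each

time_memory  = hbm_traffic / HBM_BANDWIDTH
time_compute = matmul_flops / PEAK_FLOPS

print(f"HBM traffic for S/P: {hbm_traffic / 1e6:.0f} MB  -> {time_memory * 1e3:.2f} ms")
print(f"matmul compute:      {matmul_flops / 1e9:.0f} GFLOP -> {time_compute * 1e3:.2f} ms")
# Even ignoring the softmax kernel itself, moving the score matrix through HBM
# takes longer than both matmuls combined under these assumptions — memory-bound.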

Softmax is the bottleneck. Not because it's expensive to compute (it's just exp, sum, divide), but because it requires access to the entire row before any output can be produced. You can't compute softmax(x_i) = e^{x_i} / Σ_j e^{x_j} without knowing all x_j values in the row first.

This seems to force us to materialize the full N×N matrix. You have to compute all scores in a row, store them somewhere (HBM — the only place big enough), then read them all back for softmax.

Or do you?


Tiling: A General GPU Optimization

Before jumping to Flash Attention's solution, let's understand tiling — a fundamental GPU optimization technique that's been used in high-performance computing for decades.

The Problem: Data Doesn't Fit in Fast Memory

Imagine you need to multiply two large matrices A and B, each 1024×1024. The full matrices live in slow memory (HBM), but your fast memory (SRAM) can only hold, say, three 64×64 blocks at a time.

Naive approach: for each output element C[i,j], load the entire row i of A and column j of B from slow memory, compute the dot product, store the result. You end up loading the same rows and columns over and over — terrible memory reuse.

Tiled approach: divide all three matrices into small blocks (tiles). Instead of computing one element of C at a time, compute an entire block of C at a time — and accumulate it from multiple partial contributions.

Let's make this concrete. Say we have two 4×4 matrices divided into 2×2 tiles:

Matrix A:                          Matrix B:
[ 1   2 |  3   4 ]                [ 2   0 |  1   3 ]
[ 5   6 |  7   8 ]                [ 1   4 |  0   2 ]
[-------+--------]                [-------+--------]
[ 9  10 | 11  12 ]                [ 3   1 |  2   0 ]
[13  14 | 15  16 ]                [ 0   2 |  1   1 ]

A00 = [1 2]  A01 = [3 4]          B00 = [2 0]  B01 = [1 3]
      [5 6]        [7 8]                [1 4]        [0 2]

A10 = [9  10]  A11 = [11 12]      B10 = [3 1]  B11 = [2 0]
      [13 14]        [15 16]            [0 2]        [1 1]

Block view:
A = [A00 | A01]    B = [B00 | B01]    C = [C00 | C01]
    [----+----]        [----+----]        [----+----]
    [A10 | A11]        [B10 | B11]        [C10 | C11]

Now, recall that matrix multiply works element-wise as C[i,j] = Σ_k A[i,k] · B[k,j]. The exact same formula works at the block level:

C00 = A00 × B00 + A01 × B10

Here k is the index along the shared dimension — it walks along the columns of A and the rows of B simultaneously. It's the dimension that gets summed over and "disappears" in the output. In our 2-tile example, k goes from 0 to 1.

To compute one output block C00, we walk along k, loading one tile from each matrix at each step — while C00 accumulates entirely in SRAM:

  1. k=0: Load A00 and B00 into SRAM → compute A00 × B00 → C00 starts accumulating in SRAM
  2. k=1: Load A01 and B10 into SRAM → compute A01 × B10 → add to C00, still in SRAM
  3. All k-steps done → C00 is fully computed → write final result to HBM once

The partial C00 never leaves SRAM during accumulation. Only the completed block gets written back to slow memory.

In pseudocode for the general case:

# Tiled matrix multiply
For each block row i of C:
    For each block column j of C:
        C_ij = 0  (in SRAM)
        For each block k along the shared dimension:
            Load A_ik, B_kj into SRAM       # one read from HBM
            C_ij += A_ik @ B_kj              # small matmul in SRAM, accumulate
        Write C_ij to HBM                    # one write — only the final result

The arithmetic is identical — same number of multiplies and adds. But data movement drops dramatically because each block loaded into SRAM gets reused across all the output elements it contributes to, and partial results accumulate in fast memory instead of bouncing through HBM.
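
Here's a minimal NumPy version of that tiled loop (matrix and block sizes are arbitrary illustrative choices). It produces exactly the same result as a plain matmul:

import numpy as np

def tiled_matmul(A, B, block=64):
    """Block-tiled matrix multiply: same FLOPs as A @ B, but each C tile is
    accumulated in one small buffer (standing in for SRAM) and written once."""
    n, k = A.shape
    k2, m = B.shape
    assert k == k2
    C = np.zeros((n, m), dtype=A.dtype)
    for i in range(0, n, block):             # block row of C
        for j in range(0, m, block):         # block column of C
            acc = np.zeros((min(block, n - i), min(block, m - j)), dtype=A.dtype)
            for p in range(0, k, block):     # walk the shared dimension
                acc += A[i:i+block, p:p+block] @ B[p:p+block, j:j+block]
            C[i:i+block, j:j+block] = acc    # single "write to HBM"
    return C

A = np.random.randn(256, 256)
B = np.random.randn(256, 256)
assert np.allclose(tiled_matmul(A, B), A @ B)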

Matrix Multiply: Naive vs Tiled
[Interactive demo: step through both approaches to see how tiling reduces data movement between HBM (large, slow, ~2 TB/s) and SRAM (small, fast, ~19 TB/s). Same arithmetic, far fewer memory loads. For two 4×4 matrices, the naive approach takes 16 steps (one per output element) and 128 element loads from HBM, reloading each row and column 4 times; the tiled approach takes 8 steps (one per tile pair) and 64 element loads, loading each tile twice (once per k-step).]

Why Tiling Works

The insight is about arithmetic intensity — the ratio of compute operations to memory operations:

  • Without tiling: you load data, use it once, throw it away. Low arithmetic intensity. Memory-bound.
  • With tiling: you load data, use it many times while it's in fast memory, then throw it away. High arithmetic intensity. Compute-bound (or at least closer to it).

Tiling doesn't reduce total FLOPs. It reduces total bytes moved between slow and fast memory by maximizing data reuse within the fast memory.

The Dream for Attention

So the natural thought is: can we tile attention the same way? Process the N×N attention computation in small blocks that fit in SRAM, accumulate partial results, and avoid ever materializing the full matrix in HBM?

For the QK^T matmul and the PV matmul — yes, standard tiling works fine. But there's an important difference from the matrix multiply tiling we just learned.

Why Attention Tiling Is Simpler Than Matmul Tiling

In our matmul example, C = A × B where both A and B were N×N. Since the shared dimension was also large, we had to tile both matrices along both dimensions to get tiles small enough to fit in GPU SRAM. And because the shared dimension was tiled, we needed more than one iteration (k-steps) to arrive at the final output for each tile:

C00 = A00 × B00 + A01 × B10   ← 2 k-steps to complete one output tile

Visually, our 4×4 matrices with 2×2 tiles looked like this:

A (4×4)                          B (4×4)
┌───────┬───────┐                ┌───────┬───────┐
│ A[0,0]│ A[0,1]│                │ B[0,0]│ B[0,1]│
│  2×2  │  2×2  │                │  2×2  │  2×2  │
├───────┼───────┤                ├───────┼───────┤
│ A[1,0]│ A[1,1]│                │ B[1,0]│ B[1,1]│
│  2×2  │  2×2  │                │  2×2  │  2×2  │
└───────┴───────┘                └───────┴───────┘
   ↑ tiled along rows               ↑ tiled along rows
     AND columns (shared dim)          AND columns (shared dim)

Both dimensions of both matrices were tiled. A single tile from A and a single tile from B were not enough to produce a final output tile — we needed to walk along k and accumulate.

Attention is different. In self-attention, S = Q × K^T where Q is (N×d) and K^T is (d×N). The shared dimension here is d — the head dimension, typically just 64 or 128. The sequence length N can be enormous (128K tokens), but d is small and fixed.

This means we only need to tile along N. The d dimension is small enough to keep intact within each tile:

Q (N×d)                           K^T (d×N)
┌───────┐                        ┌───────┬───────┬───────┐
│ Q[0]  │ ← B×d                  │ K[0]ᵀ │ K[1]ᵀ │ K[2]ᵀ │
│       │                        │  d×B  │  d×B  │  d×B  │
├───────┤                        └───────┴───────┴───────┘
│ Q[1]  │ ← B×d                     ↑ tiled along N only
│       │                           d dimension kept whole
├───────┤
│ Q[2]  │ ← B×d
│       │
└───────┘
   ↑ tiled along N only
     d dimension kept whole

Notice the difference: each Q block is (B×d) — it contains the full d-dimensional representation for B tokens. Each K^T block is (d×B) — again, the full d dimension. Because d is not tiled, a single pair of blocks is enough to produce an output tile with no accumulation:

S[i,j] = Q_i × K_j^T   ← done in one step, no k-iteration

With block size B = 128 and head dimension d = 128, each block is 128 × 128 = 16,384 values = 32 KB in FP16. Two blocks plus the output tile fit comfortably in a single SM's slice of SRAM (a few hundred KB), let alone the ~20 MB total. No partial sums, no k-iteration — each score tile is a self-contained matmul.

Matmul:    (N × N) × (N × N)  → shared dim = N (huge, must tile → k-iteration)
Attention: (N × d) × (d × N)  → shared dim = d (tiny, fits whole → single pass)

This is what makes Flash Attention's tiling practical: we only tile the big dimension (N), and the small dimension (d) stays intact. Each score tile S[i,j] can be fully computed from just one Q block and one K block — no need to revisit them.

The Softmax Problem

And that's exactly the problem: softmax needs the full row.

If you're computing attention in tiles (say, processing K/V block 0 first, then block 1), you can't compute the final softmax denominator until you've seen all K/V blocks. The normalization constant Σ_j e^{x_j} depends on every score in the row.

This is the obstacle that made everyone assume tiling was impossible for attention. Each score tile is easy to compute, but softmax forces you to materialize the full row of scores in HBM before normalizing. And it's what the online softmax trick solves.


Online Softmax: The Trick That Enables Tiling

Standard softmax for a row [x_1, x_2, …, x_N]:

softmax(x_i) = e^{x_i - m} / Σ_{j=1}^{N} e^{x_j - m},   where m = max(x_1, …, x_N)

(The m = max subtraction is for numerical stability — standard practice.)

The problem: you need m (the max over the full row) and the full denominator before producing any output. Both require seeing every element in the row first. That's what forces us to materialize the full row in HBM — we can't tile softmax the way we tiled matrix multiply.

Or can we?

The Intuition: Fix Your Mistakes As You Go

Online softmax's insight is simple: process one chunk at a time, and retroactively correct your past work whenever new data changes the picture.

Let's walk through it with concrete numbers. Say the full row of attention scores is [2, 8, 1, 9, 3, 7] and we can only hold 3 values in SRAM at a time.

Chunk 1: [2, 8, 1]

We haven't seen the full row yet, but we do our best with what we have:

  • Best max so far: m = 8
  • Compute exponentials: e^{2-8}, e^{8-8}, e^{1-8} = [0.0025, 1.0, 0.0009]
  • Running sum: ℓ = 1.0034

For the data we've seen so far, this is perfectly correct. If the row ended here, dividing by ℓ would give us the right softmax.

Chunk 2: [9, 3, 7]

New chunk arrives. Its max is 9 — bigger than our old max of 8. This means all our old exponentials were computed as e^{x - 8} when they should have been e^{x - 9}. They're all too big.

By how much? Each old term should be:

e^{x_j - 9}

But we computed:

e^{x_j - 8}

Using the exponent rule e^{a+b} = e^a · e^b:

e^{x_j - 9} = e^{(x_j - 8) + (8 - 9)} = e^{x_j - 8} · e^{8-9}

= e^{x_j - 8} (what we already have) × e^{-1} (≈ 0.368)

The correction factor e^{m_old - m_new} falls out naturally — it doesn't depend on j, so every term gets the same fix. One multiplication corrects the entire running sum:

ℓ_corrected = ℓ_old × e^{8-9} = 1.0034 × 0.368 = 0.369

Now process the new chunk normally with the updated max:

  • New exponentials: e^{9-9}, e^{3-9}, e^{7-9} = [1.0, 0.0025, 0.135]
  • Add to corrected sum: ℓ = 0.369 + 1.137 = 1.506
  • Update max: m = 9

After both chunks, we have m = 9 and ℓ = 1.506 — exactly what batch softmax over the full row [2, 8, 1, 9, 3, 7] would produce. We just never needed all 6 values in memory at once.

The Same Correction Fixes the Output

In Flash Attention, we're not just computing softmax weights — we're computing the weighted sum O = softmax(S) · V. The trick is to track the unnormalized weighted sum and only normalize at the very end.

After chunk 1, we compute the unnormalized output:

Õ = e^{2-8} · V_1 + e^{8-8} · V_2 + e^{1-8} · V_3

These exponentials were computed with m = 8. When chunk 2 reveals the true max is 9, every one of them needs to be rescaled by e^{8-9}. Same single-multiplication fix:

Õ_corrected = Õ_old × e^{8-9} + e^{9-9} · V_4 + e^{3-9} · V_5 + e^{7-9} · V_6   ← the new chunk's unnormalized contribution

After all chunks are processed, we normalize once: O = Õ / ℓ.

Why unnormalized? Because if we normalized after each chunk (dividing by the local ℓ), the correction would need both the old ℓ and the new ℓ — much messier. With unnormalized Õ, the correction is the exact same e^{m_old - m_new} factor as the sum. One multiplication fixes everything.

The General Update Rules

For each new block b of scores, three running values get updated:

  1. Max: m_new = max(m_old, m_b)

  2. Sum: ℓ_new = ℓ_old · e^{m_old - m_new} + Σ_{j∈b} e^{x_j - m_new}

  3. Unnormalized output: Õ_new = Õ_old · e^{m_old - m_new} + Σ_{j∈b} e^{x_j - m_new} · V_j

  4. Final normalization (after all blocks): O = Õ / ℓ

The correction factor e^{m_old - m_new} is always ≤ 1 (the new max is at least as large as the old). It retroactively scales down all previous results as if they'd been computed with the true max from the start. The division by ℓ happens only once, at the very end — not after each block.

This means we can process the attention matrix one tile at a time, maintaining a small running state (m, ℓ, Õ) in SRAM, without ever writing the full attention matrix to HBM.
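
Here's a small NumPy sketch of these update rules applied to a single query row, processing the scores in chunks (the chunk size and values are arbitrary) and checking the result against ordinary softmax-then-matmul:

import numpy as np

def online_attention_row(scores, V, chunk=3):
    """One query row: process scores/V in chunks, keeping only (m, ℓ, Õ)."""
    m, l = -np.inf, 0.0
    O_tilde = np.zeros(V.shape[1])
    for start in range(0, len(scores), chunk):
        s = scores[start:start+chunk]
        v = V[start:start+chunk]
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)            # rescales everything computed so far
        p = np.exp(s - m_new)               # unnormalized weights for this chunk
        l = l * corr + p.sum()
        O_tilde = O_tilde * corr + p @ v
        m = m_new
    return O_tilde / l                      # normalize once at the end

scores = np.array([2.0, 8.0, 1.0, 9.0, 3.0, 7.0])   # the row from the example above
V = np.random.randn(6, 4)
weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ V
assert np.allclose(online_attention_row(scores, V), reference)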

Note: online softmax wasn't invented by Flash Attention — it was proposed by Milakov & Gimelshein (2018). Flash Attention's contribution is recognizing that this technique enables something much bigger.


The Real Contribution: Kernel Fusion

What Is a GPU Kernel?

A kernel is a function that runs on the GPU. Every time PyTorch executes an operation — a matrix multiply, a softmax, an element-wise add — it launches a kernel: a self-contained job that the GPU picks up, runs, and finishes.

Here's the important part: each kernel starts with a blank SRAM. It reads its inputs from HBM into SRAM, does its computation, writes outputs back to HBM, and then SRAM is cleared for the next kernel. Even if the previous kernel had useful data sitting in SRAM, the next kernel can't see it — it has to re-read everything from HBM.

Think of it like a kitchen with a tiny countertop (SRAM) and a large pantry (HBM). Each kernel is a different cook who comes in, grabs ingredients from the pantry, works on the countertop, puts the result back in the pantry, and then wipes the countertop clean before leaving. The next cook can't reuse anything on the countertop — they have to go back to the pantry for everything.

Why This Matters for Attention

Online softmax is the enabler, but the core contribution of Flash Attention is fusing the entire attention operation into a single kernel — one cook who does everything on the countertop without putting intermediate results back in the pantry.

In standard PyTorch, attention is three separate kernel launches:

S = Q @ K.T / sqrt(d_k)      # Kernel 1: matmul — reads Q, K from HBM, writes S to HBM
P = softmax(S)               # Kernel 2: softmax — reads S from HBM, writes P to HBM
O = P @ V                    # Kernel 3: matmul — reads P, V from HBM, writes O to HBM

Each kernel launch is a pantry round-trip: read inputs from HBM, compute on the countertop, put results back in HBM, wipe the countertop. The N×N attention matrix (S, then P) gets put back in the pantry and retrieved four times across these three kernels — because each cook can't see what the previous one left on the countertop.

Flash Attention replaces all three cooks with one cook who does everything in a single kernel:

O = flash_attention(Q, K, V)  # Single kernel: reads Q,K,V from HBM, writes O to HBM
                               # S and P never exist in HBM — only in SRAM, tile by tile

Inside this single kernel:

  • A tile of scores is computed (Q_i K_j^T) and stays in SRAM
  • Softmax is applied to that tile (via online softmax), still in SRAM
  • The tile is multiplied by V_j, accumulated into the running output, still in SRAM
  • The tile is discarded — it never leaves SRAM

The N×N matrix is never materialized anywhere. Each element is born in SRAM, used once, and forgotten. Only the final output O (size N×d, not N×N) gets written to HBM.

This is the fundamental insight: it's not about a smarter algorithm for softmax — it's about eliminating memory round-trips by fusing operations that are normally separate. Online softmax is what makes the fusion possible (without it, softmax forces you to materialize the full row), but the fusion itself is the contribution.


The Full Algorithm

With kernel fusion and online softmax, the complete algorithm becomes:

The Outer Loop (Over K/V Blocks) — Flash Attention 1

# Initialize: for each Q row, set Õ = 0, m = -∞, l = 0 (stored in HBM)

For each block j of K and V:
    Load K_j, V_j from HBM to SRAM

    For each block i of Q:
        Load Q_i, Õ_i, m_i, l_i from HBM to SRAM
        #  ↑ m_i and l_i are the running stats from previous j iterations

        # Compute attention scores for this tile — stays in SRAM
        S_ij = Q_i @ K_j.T / sqrt(d)

        # Online softmax update — all in SRAM
        m_i_old = m_i
        m_i = max(m_i, rowmax(S_ij))
        correction = exp(m_i_old - m_i)
        l_i = l_i * correction + rowsum(exp(S_ij - m_i))
        Õ_i = Õ_i * correction + exp(S_ij - m_i) @ V_j

        # S_ij is discarded here — never written to HBM

        # Updated running state goes back to HBM (small: block_size × d)
        Write Õ_i, m_i, l_i back to HBM
        #  ↑ next time this Q block is visited (j+1), these become the inputs

# After all j blocks: O_i = Õ_i / l_i  for each Q block (normalize once)

Flash Attention 1: Step by Step
[Interactive demo: watch how attention scores are computed tile-by-tile in SRAM without ever materializing the full N×N matrix in HBM. 6 tokens with block size 2 give a 3×3 grid of 2×2 score tiles; only the active tile ever exists in SRAM, and the running state (Õ_i, m_i, l_i) for each Q block lives in HBM between visits.]

Notice: the running state (Õ_i, m_i, l_i) gets read from and written to HBM in the inner loop — once per (Q block, K/V block) pair. For N/B Q-blocks and N/B KV-blocks, that's (N/B)^2 HBM round-trips for the partial outputs. Each round-trip is small (B×d instead of N×N), but it adds up. Flash Attention 2 fixes this.

What's Different

                       Naive Attention                    Flash Attention
HBM reads/writes       O(N^2) for the attention matrix    O(N^2 d / M), where M = SRAM size
Intermediate storage   Full N×N matrix in HBM             One tile in SRAM at a time
Memory usage           O(N^2)                             O(N) — only stores O, m, l per row
Compute                Same O(N^2 d) FLOPs                Same O(N^2 d) FLOPs

The total FLOPs don't change — we're still computing every entry of QK^T. But the number of HBM accesses drops dramatically because we never materialize the full attention matrix. Each element of the attention matrix is computed in SRAM, used immediately, and discarded.

The Memory Savings

Standard attention stores the N×N attention matrix — that's O(N^2) memory. For a 128K sequence: 128K × 128K × 2 bytes = 32 GB per head. Clearly impossible.

Flash Attention stores only the running statistics per row: O (N×d), m (N), l (N). That's O(N·d) — linear in sequence length. For 128K with d = 128: ~33 MB per head. A 1000× reduction.

This is why Flash Attention unlocked long-context models. Before it, attention at 32K+ tokens was memory-constrained regardless of compute. After it, the memory constraint is lifted entirely.


The Backward Pass: Even Trickier

In training, we need gradients. Standard attention stores the full N×N matrix P for the backward pass (it's needed to compute ∂L/∂S). That's the same O(N^2) memory we just eliminated.

Flash Attention's solution: recompute P from Q and K during the backward pass rather than storing it.

This sounds wasteful — we're doing extra FLOPs. But the tradeoff is:

  • Without recomputation: store N^2 values in HBM (memory-bound, kills max sequence length)
  • With recomputation: extra O(N^2 d) FLOPs (compute, which is cheap relative to memory bandwidth)

Since naive attention is memory-bound (not compute-bound), adding more compute doesn't actually slow things down much — the GPU was already waiting on memory. Trading memory for recomputation is essentially free.


Flash Attention 2: The Loop Swap

Look at Flash Attention 1's pseudocode again. The running state (Õ_i, m_i, ℓ_i) gets read from and written to HBM in the inner loop — once per (Q block, K/V block) pair. For N/B Q-blocks and N/B KV-blocks, that's (N/B)^2 HBM round-trips for partial outputs. Each round-trip is small (B×d instead of N×N), but it adds up.

Why does this happen? Because K/V is in the outer loop. When we fix a K/V block and cycle through all Q blocks in the inner loop, each Q block's partial output Õ_i gets loaded, updated, and written back — then we move on to the next Q block. By the time the next K/V block comes around in the outer loop and we revisit Q block i, its partial output has been evicted from SRAM long ago. So we must write it to HBM after each visit and reload it next time.

An analogy. Imagine you're a student taking notes from 3 textbooks. You have a tiny desk (SRAM) and a big bookshelf across the room (HBM).

Flash Attention 1 (textbook-first — K/V outer loop):

  • Pick up Textbook 1 (K/V block), bring it to desk
  • Get Notebook A (Õ for Q block 0) from shelf, write notes, put Notebook A back on shelf
  • Get Notebook B (Õ for Q block 1) from shelf, write notes, put Notebook B back on shelf
  • Get Notebook C (Õ for Q block 2) from shelf, write notes, put Notebook C back on shelf
  • Put Textbook 1 back. Pick up Textbook 2.
  • Walk to shelf, get Notebook A again, write more notes, put it back...

Every notebook gets picked up and put back 3 times (once per textbook). That's 9 shelf trips for notebooks.

Flash Attention 2 (notebook-first — Q outer loop):

  • Pick up Notebook A, keep it on your desk the whole time
  • Bring Textbook 1 → write notes into A → done with Textbook 1, discard it
  • Bring Textbook 2 → write notes into A → done, discard it
  • Bring Textbook 3 → write notes into A → done, discard it
  • Notebook A is complete → put it on shelf. One trip.
  • Pick up Notebook B, repeat...

Each notebook is picked up once and put back once. 3 shelf trips total instead of 9.

Mapping back: Notebooks = Õ_i (partial output being accumulated). Textbooks = K_j, V_j blocks (read once, used, discarded). Shelf = HBM. Desk = SRAM.

The rule: whatever is being accumulated should stay on the desk. Whatever is consumed once should cycle through. FA1 had it backwards — the accumulating output was in the inner loop, so it kept getting evicted. FA2 makes the output the "resident" and cycles K/V through instead.

In pseudocode:

For each block i of Q:                          # Q in OUTER loop
    Load Q_i from HBM to SRAM
    Initialize Õ_i, m_i, l_i in SRAM            # stays here the whole time

    For each block j of K and V:                 # K/V in INNER loop
        Load K_j, V_j from HBM to SRAM
        # Same online softmax update as before
        # Õ_i stays in SRAM — no HBM write between K/V blocks

    O_i = Õ_i / l_i                              # normalize once
    Write final O_i to HBM                       # only ONCE per Q block

With Q in the outer loop, each thread block "owns" its Q block's partial output for the entire inner loop. Õ_i stays in SRAM while all K/V blocks cycle through — it's only written to HBM once, when the inner loop finishes. This cuts the partial-output HBM round-trips from (N/B)^2 down to just N/B.
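
Putting the pieces together, here's a compact NumPy simulation of this loop order (Q blocks outer, K/V blocks inner, running state resident for the whole inner loop). It's a sketch of the algorithm's structure rather than a real kernel, with no causal mask and tiny illustrative sizes, and it matches naive attention exactly:

import numpy as np

def flash_attention_v2_sim(Q, K, V, block=4):
    """Simulates FA2's loop structure: per Q block, keep (Õ, m, l) resident
    while streaming K/V blocks past it. Never forms the full N×N score matrix."""
    N, d = Q.shape
    O = np.zeros_like(Q)
    scale = 1.0 / np.sqrt(d)
    for i in range(0, N, block):                       # Q blocks: OUTER loop
        Qi = Q[i:i+block]
        m = np.full(Qi.shape[0], -np.inf)              # running row max
        l = np.zeros(Qi.shape[0])                      # running denominator
        O_tilde = np.zeros_like(Qi)                    # running unnormalized output
        for j in range(0, N, block):                   # K/V blocks: INNER loop
            S = Qi @ K[j:j+block].T * scale            # one score tile, "in SRAM"
            m_new = np.maximum(m, S.max(axis=1))
            corr = np.exp(m - m_new)
            P = np.exp(S - m_new[:, None])
            l = l * corr + P.sum(axis=1)
            O_tilde = O_tilde * corr[:, None] + P @ V[j:j+block]
            m = m_new
        O[i:i+block] = O_tilde / l[:, None]            # single write per Q block
    return O

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

N, d = 16, 8
Q, K, V = (np.random.randn(N, d) for _ in range(3))
assert np.allclose(flash_attention_v2_sim(Q, K, V), naive_attention(Q, K, V))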

Additional improvements:

  • Better work partitioning across warps: within a thread block, different warps can process different K/V blocks in parallel for the same Q block, then reduce.
  • ~2× speedup over Flash Attention 1, reaching 50-73% of theoretical peak FLOP/s on A100.

Additional Optimizations

  • Causal masking without waste: for causal attention, roughly half the tiles are entirely masked out (all scores = -∞). Flash Attention 2 skips these tiles entirely — no compute, no memory access. This gives an automatic ~1.7× speedup for autoregressive models.
  • Head dimension parallelism: for GQA/MQA where multiple Q heads share one KV head, Flash Attention 2 parallelizes across the Q heads for the same KV, improving utilization.

Flash Attention 3: Hopper-Specific

Flash Attention 3 (Shah et al., 2024) targets the H100's Hopper architecture with hardware-specific optimizations:

  • Warp specialization: separate warps for data movement (producer) and computation (consumer), overlapping memory loads with tensor core operations
  • FP8 support: using the H100's FP8 tensor cores for the QK^T matmul with FP32 accumulation, then FP16/BF16 for the softmax and output
  • Asynchronous operations: using the Tensor Memory Accelerator (TMA) for async data loads while compute proceeds
  • Block quantization: per-block scaling factors for FP8 to maintain accuracy

Result: 1.5-2× faster than Flash Attention 2 on H100, reaching roughly 75% of peak FP16 throughput (and higher absolute throughput still with FP8).


What Flash Attention Actually Changed

Before Flash Attention (pre-2022):

  • Maximum practical sequence length: ~2-4K tokens (memory-limited)
  • Longer contexts required sparse attention patterns, LSH approximations, or linear attention approximations
  • Even "efficient" transformers couldn't match quality of full attention

After Flash Attention:

  • Full dense attention at 32-128K tokens — no approximation needed
  • Most production LLMs use Flash Attention by default (GPT-4, Claude, Gemma, Llama, Mistral — all of them)
  • Sparse attention reserved for >128K contexts where even Flash Attention's compute cost becomes prohibitive
  • Training throughput improved 2-4× across the board

It's rare for a single systems paper to so completely change what's architecturally feasible. Flash Attention didn't invent a new attention mechanism — it made the existing one fast enough that most of the "efficient attention" literature from 2019-2022 became unnecessary.


Paged Attention: The Serving Side (Kwon et al., 2023)

Flash Attention solves the training and compute problem. But when you're serving an LLM — handling hundreds of concurrent requests of different lengths — there's a different memory bottleneck: KV cache fragmentation.

The Problem: Wasted GPU Memory During Serving

During inference, each new token needs to attend to all previous tokens. To avoid re-running every previous token through the key and value projections at each step, we cache the K and V vectors for every token we've already processed. This is the KV cache — it grows by one K vector and one V vector per layer per token generated.

Now, how do you allocate memory for this cache? You don't know upfront how long each request will be. In a naive serving system, you pre-allocate a contiguous chunk of GPU memory for each request's KV cache based on the maximum possible sequence length. If your model supports 32K tokens, you allocate 32K slots — even if the request ends at 500 tokens.

Imagine you're a hotel manager. A guest checks in and says "I might stay 1 night or 30 nights." You don't know. So you reserve 30 consecutive rooms (contiguous memory). Even if they leave after 2 nights, those 30 rooms stay blocked off. Now multiply this by hundreds of guests (concurrent requests). Most of your hotel is "reserved but empty" — you're turning away new guests because you have no rooms, even though 80% are physically unoccupied.

With hundreds of concurrent requests:

  • Most requests use a fraction of their allocated memory
  • The unused memory can't be given to other requests (it's reserved contiguously)
  • GPU memory utilization drops to 20-40% — most of it is "reserved but empty"

This is the classic problem of internal fragmentation — the same problem that plagued early operating systems before virtual memory.

The Solution: Virtual Memory for KV Cache

Paged Attention (the core of vLLM) borrows directly from operating systems: don't allocate memory contiguously. Use pages.

Instead of reserving consecutive rooms, you give each guest a room key card that works on any room. When they need night 1, you assign whichever room happens to be free. Night 2 — a different room (not next to the first, doesn't matter). You keep a small table mapping "Guest A, night 1 → room 7, night 2 → room 23, ..."

Concretely, the KV cache is divided into fixed-size blocks (pages), and each request gets a page table that maps logical positions to physical memory blocks. Blocks are allocated on demand — a new block is assigned only when the sequence actually grows into it. When a request finishes, its blocks are immediately freed for other requests.

A Concrete Example: 4 Concurrent Requests

Let's say 4 requests arrive on a GPU with 20 free memory blocks. Each block holds 4 tokens worth of KV vectors.

Naive system — must pre-allocate contiguously for max length (say 16 tokens = 4 blocks each):

GPU Memory: [A0][A1][A2][A3] [B0][B1][B2][B3] [C0][C1][C2][C3] [D0][D1][D2][D3] [--][--][--][--]
             ← Request A →   ← Request B →     ← Request C →    ← Request D →     4 free

Request A: "Tell me a joke"     → generates 5 tokens, uses 2 blocks, wastes 2
Request B: "Translate this..."  → generates 14 tokens, uses 4 blocks, wastes 0
Request C: "Hi"                 → generates 3 tokens, uses 1 block, wastes 3
Request D: "Summarize..."       → generates 8 tokens, uses 2 blocks, wastes 2

16 blocks reserved. Only 9 actually used. 7 blocks sitting empty but locked. If a 5th request arrives — rejected, even though there's plenty of physical space.

Paged Attention — allocates one block at a time, on demand:

Step 1: All 4 requests start. Each needs 1 block for their first few tokens.

  GPU Memory: [A0][B0][C0][D0][--][--][--][--][--]...
  Page tables:
    A: [0→block0]    B: [0→block1]    C: [0→block2]    D: [0→block3]
  Free: 16 blocks

Step 2: A and B need a 2nd block. C is done (3 tokens fit in 1 block).

  GPU Memory: [A0][B0][C0][D0][A1][B1][--][--][--]...
  Page tables:
    A: [0→block0, 1→block4]
    B: [0→block1, 1→block5]
    C: [0→block2]                ← done, will be freed
    D: [0→block3]
  Free: 14 blocks

Step 3: C finishes — its block is immediately freed. D needs another block.
         A 5th request E arrives — no problem, plenty of space.

  GPU Memory: [A0][B0][E0][D0][A1][B1][D1][--][--]...
                        ↑ block2 recycled to E
  Page tables:
    A: [0→block0, 1→block4]
    B: [0→block1, 1→block5]
    D: [0→block3, 1→block6]
    E: [0→block2]                ← reusing C's freed block
  Free: 13 blocks

Step 4: B grows long (14 tokens = 4 blocks). A finishes.

  GPU Memory: [--][B0][E0][D0][--][B1][D1][B2][B3][--]...
  Page tables:
    B: [0→block1, 1→block5, 2→block7, 3→block8]
    D: [0→block3, 1→block6]
    E: [0→block2]
  Free: A's blocks 0,4 now freed → 13 blocks available

Notice: no request reserves more than it currently needs. Freed blocks are immediately reusable — when C finished, its block was recycled to E. B's KV cache lives in blocks 1, 5, 7, 8 (scattered, not consecutive), and the page table handles the mapping. The 5th request was accepted where the naive system would have rejected it.
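
The bookkeeping behind this is small. Here's a toy Python sketch of a block manager; the class name, pool size, and methods are invented for illustration (a real system like vLLM tracks much more), but it captures the allocate-on-demand and free-on-finish behavior from the example above:

class KVBlockManager:
    """Toy page-table bookkeeping for a paged KV cache: fixed-size blocks are
    handed out on demand and recycled as soon as a request finishes."""
    def __init__(self, num_blocks, block_tokens=4):
        self.block_tokens = block_tokens
        self.free_blocks = list(range(num_blocks))     # physical block IDs
        self.page_tables = {}                          # request -> [block IDs]
        self.lengths = {}                              # request -> tokens so far

    def append_token(self, request_id):
        """Called once per token; allocates a new block only when the
        sequence grows past the end of its current last block."""
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % self.block_tokens == 0:            # last block full (or none yet)
            if not self.free_blocks:
                raise MemoryError("KV cache exhausted")
            table.append(self.free_blocks.pop())
        self.lengths[request_id] = length + 1

    def finish(self, request_id):
        """Request done: its blocks go straight back to the free pool."""
        self.free_blocks.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

manager = KVBlockManager(num_blocks=20)
for _ in range(5):                 # request "A" generates 5 tokens -> 2 blocks
    manager.append_token("A")
print(manager.page_tables["A"])    # e.g. [19, 18] — physical blocks, not contiguous slots
manager.finish("A")                # both blocks immediately reusable by other requests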

Why This Matters

                              Contiguous Allocation               Paged Attention
Memory allocated per request  max_seq_len × hidden_dim            actual_seq_len × hidden_dim
Fragmentation                 Internal (wasted reserved space)    Near-zero
Memory utilization            20-40% typical                      90%+
Max concurrent requests       Limited by over-allocation          2-4× more requests

In practice, Paged Attention enables 2-4× higher serving throughput simply by fitting more requests in memory simultaneously. The actual attention computation is unchanged — it just reads KV vectors from non-contiguous blocks via the page table.

How It Composes with Flash Attention

Flash Attention and Paged Attention solve orthogonal problems:

  • Flash Attention: makes each attention computation fast (training + inference)
  • Paged Attention: makes memory management efficient when serving many requests

vLLM uses both: Flash Attention kernels for the actual attention math, Paged Attention for managing where KV blocks live in memory. The Flash Attention kernel is modified slightly to read K/V from non-contiguous block locations (a lookup through the page table), but the tiling and online softmax logic is the same.

Bonus: Copy-on-Write for Shared Prefixes

Paged Attention enables another OS-inspired optimization: copy-on-write. When multiple requests share the same prefix (e.g., the same system prompt), their page tables can point to the same physical blocks for the shared portion. The blocks are only copied when a request diverges.

For serving scenarios where many users share the same system prompt (common in API serving), this can save 30-50% of KV cache memory — without any computation overhead.


The Takeaway for Practitioners

  1. Attention is memory-bound, not compute-bound. The expensive part isn't the arithmetic — it's moving data between HBM and SRAM. Any optimization should focus on reducing data movement.

  2. Never materialize large intermediate tensors. If a computation produces an O(N^2) intermediate that's immediately consumed, fuse the operations so it only ever exists in SRAM.

  3. Recomputation can be free. When you're memory-bound, trading extra FLOPs for less memory is a net win. The GPU was idle anyway.

  4. Understand your hardware. Flash Attention isn't a clever algorithm trick — it's a consequence of understanding the GPU's memory hierarchy. The same principle applies to any memory-bound operation.

  5. Use Flash Attention. Don't implement attention yourself unless you're doing research on new attention patterns. Use torch.nn.functional.scaled_dot_product_attention (which uses Flash Attention under the hood), or the flash-attn package directly.
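
For example, in PyTorch 2.x the drop-in call looks like this (it needs a GPU, and the shapes below are arbitrary). It dispatches to a fused Flash-Attention-style kernel when the dtype, shapes, and hardware allow, so the N×N score matrix is never materialized:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) — FP16 on GPU is what the fused kernels want
q = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)

# Fused attention in one call; is_causal=True applies the autoregressive mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)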


References

  • Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  • Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.
  • Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.
  • Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax.
  • Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C. H., Gonzalez, J. E., Zhang, H., & Stoica, I. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023.


Next up: Beyond Attention: Anatomy of a Modern Transformer — every non-attention component and how it evolved.

This post is part of The Gradient Descent through Transformers — a series dissecting every component of the modern transformer stack.