Part 6 of 8

The Gradient Descent through Transformers


Attention Part 3 — The Sparsity Lineage: Making Attention Sub-Quadratic

May 12, 2026 · 22 min read

This is Part 6 of The Gradient Descent through Transformers — a series where I walk through every component of the modern transformer stack, how it evolved from 2017 to 2026, and why each piece matters.

Previously: Attention Part 2 — The Sharing Lineage


The Problem: Quadratic Compute

In Part 2, we solved the memory problem — making the KV cache smaller through sharing and compression. But there's a second cost that grows just as painfully: compute.

Standard attention computes a score between every query and every key:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

That $QK^T$ is an $n \times n$ matrix. For a sequence of 128K tokens, that's 16.4 billion entries — per head, per layer. Even with MLA's compressed cache, the attention computation itself remains quadratic.
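A quick back-of-the-envelope check of that scaling, in plain Python, counting the entries of the $QK^T$ matrix per head, per layer:

```python
# Entries in the QK^T score matrix: one score per (query, key) pair.
for n in (4_096, 32_768, 128_000):
    scores = n * n
    print(f"{n:>7,} tokens -> {scores / 1e6:>10,.0f}M scores")
#   4,096 tokens ->         17M scores
#  32,768 tokens ->      1,074M scores
# 128,000 tokens ->     16,384M scores   (the ~16.4 billion figure above)
```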

This cost doesn't just burn FLOPs. It means:

  • Prefill latency scales quadratically (processing a 128K prompt takes 16× longer than 32K)
  • Training cost for long-context models is dominated by attention
  • Hardware utilization drops as the attention matrix overflows fast GPU memory tiers

The sharing lineage asked: "can we store less?" The sparsity lineage asks a different question: "can we compute less?"


The Core Insight: Most Attention Is Wasted

Before diving into solutions, let's understand why sparsity might work — and more importantly, let's convince ourselves with evidence that it should work.

What Actually Happens Inside Attention

Recall the attention formula: for each query token $q_t$, we compute a score against every key token, apply softmax, and produce a weighted sum of values. Softmax guarantees all weights sum to 1. That means if a token at position 5000 finds 10 tokens genuinely relevant, it still has to spread some weight across the other 4990 tokens. The distribution can be peaky (concentrated on a few keys), but it can never be truly sparse — softmax always assigns some probability mass everywhere.
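A tiny numerical illustration of that point: with made-up logits where 10 of 5,000 keys are relevant, softmax concentrates nearly all of its mass on those 10 keys yet still leaves a nonzero sliver everywhere else.

```python
import numpy as np

n = 5_000
scores = np.zeros(n)
scores[:10] = 10.0           # pretend 10 keys are genuinely relevant; the rest sit at 0

weights = np.exp(scores - scores.max())
weights /= weights.sum()     # softmax over all 5,000 positions

print(round(weights[:10].sum(), 3))   # ~0.978: almost all the mass on 10 of 5,000 keys
print(weights.min())                   # ~4.4e-06: tiny, but never exactly zero
```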

In practice, trained attention patterns look remarkably concentrated. Studies of BERT's attention heads (Clark et al., 2019) found that:

  • Many heads develop highly stereotyped patterns (attending to previous token, sentence start, or specific syntactic positions)
  • In most heads, over 90% of attention weight concentrates on fewer than 10% of available positions
  • The remaining 90% of positions each receive tiny slivers of weight (0.01-0.5% each) that contribute almost nothing to the output

This makes intuitive sense when you think about language structure:

  • Local dominance: the next word depends primarily on its immediate context. "The cat sat on the ___" — the answer depends on "sat on the", not on a sentence from 3 paragraphs ago.
  • Sparse long-range dependencies: occasionally a token genuinely needs distant context (pronoun resolution, callback to a prior topic), but this is the exception, not the rule.
  • Positional biases: some positions (sentence starts, punctuation, function words like "the") attract attention from many positions regardless of content — a handful of "anchor" tokens that everything attends to.

The result: we compute an $n \times n$ attention matrix, but the information content is concentrated in a tiny fraction of those scores. The rest is wasted arithmetic.

Attention Weight Sparsity: a simulated causal attention heatmap for a 12-token sentence ("The cat sat on the mat because it was tired and cold"), with a 5% threshold for "meaningful" weight. Of the 78 causal attention scores computed, only 46 clear the threshold; the remaining 32 scores (41.0% of the compute) contribute almost nothing, while the meaningful scores capture 96.7% of the total attention weight.

At this 12-token scale, 41.0% of attention scores contribute almost nothing. Now imagine a 128K-token sequence — that's billions of near-zero scores computed and discarded every forward pass.

Quantifying the Waste

The visualization above uses a 12-token sentence. At this scale, the waste is already visible. But the problem grows quadratically:

| Sequence Length | Attention Scores | Meaningful Scores (~5-10%) | Wasted Compute |
|---|---|---|---|
| 512 tokens | 131K | ~10K | ~92% |
| 4K tokens | 8M | ~500K | ~94% |
| 32K tokens | 512M | ~25M | ~95% |
| 128K tokens | 8.2B | ~400M | ~95%+ |

At 128K tokens, we're computing 8.2 billion attention scores per head per layer — and more than 95% of them produce near-zero weights that barely influence the output. Every one of those near-zero scores still costs a multiply-accumulate operation, a memory read for the corresponding key vector, and a contribution to the softmax denominator.

This isn't just a theoretical inefficiency. It translates directly into:

  • Wasted GPU cycles: computing scores no one will use
  • Wasted memory bandwidth: loading K vectors from memory for positions that will get 0.01% weight
  • Wasted value aggregation: multiplying V vectors by near-zero weights and adding them to the output

If we could skip even 90% of these useless computations, attention would go from quadratic to near-linear. That's the prize. The question is: how do you decide what to skip?


Sparse Transformer (Child et al., OpenAI, 2019)

The Motivation: Can We Skip What Doesn't Matter?

We established that most attention weights end up near zero — language is inherently sparse in its dependencies. So here's the tempting thought: what if we just… don't compute the scores we know will be near-zero?

But there's an immediate chicken-and-egg problem. You can't know which scores are near-zero until you compute them. The whole point of attention is that relevance is data-dependent — it's determined by the content of Q and K, which changes every forward pass. So how do you skip computation without first doing the computation?

The Sparse Transformer's answer: don't try to be smart about it. Instead of predicting which specific tokens matter (a hard, content-dependent problem), design a fixed geometric pattern that guarantees coverage of both local and long-range dependencies regardless of content. Accept that you'll miss some relevant tokens, but ensure that information can still flow between any two positions through multi-hop paths.

This is the key insight: you don't need every token to directly see every other token. You just need to guarantee that information from any position can reach any other position within a bounded number of layers — like a relay network where no single runner covers the full distance, but the message still arrives.

With that framing, the design becomes a two-step decision. Let's walk through it.

Step 1: Choose a Factorization Pattern

The paper proposes two ways to factorize the full attention matrix into two smaller, complementary subsets. You choose one pattern based on your data type — these are alternatives, not things you combine together.

Each pattern splits attention into two subsets per head. Instead of every token attending to all previous tokens, each token attends to two carefully chosen subsets: one captures local context, the other captures long-range context. Together, they cover the full sequence with only $O(n\sqrt{n})$ connections instead of $O(n^2)$.

Pattern A: Strided (for data with periodic structure)

Best for data with natural periodicity — images (pixel rows), music (beat patterns), or any sequence where positions separated by a fixed stride tend to be related.

For a stride $\ell = \sqrt{n}$:

  • Subset 1 (local): token at position $t$ attends to the $\ell$ most recent positions, including itself: $\{t-\ell+1, \ldots, t\}$
  • Subset 2 (strided): token at position $t$ attends to every $\ell$-th position going back: $\{t, t-\ell, t-2\ell, \ldots\}$

Concrete example with $n = 16$, $\ell = 4$, token at position 14:

  • Subset 1 (local): positions $\{11, 12, 13, 14\}$ — the 4 nearest neighbors
  • Subset 2 (strided): positions $\{14, 10, 6, 2\}$ — every 4th position going back

Every token has its own grid of distant positions. Long-range connectivity is distributed — democratic, every token gets sparse but direct access to distant context.
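A minimal sketch of the strided pattern (an illustrative helper, not the paper's code) that reproduces the position-14 example:

```python
import math

def strided_subsets(t: int, n: int):
    """Sparse Transformer strided pattern: local window + strided lookback."""
    l = int(math.sqrt(n))                              # stride/window size: l = sqrt(n)
    local = list(range(max(0, t - l + 1), t + 1))      # the l positions ending at t
    strided = list(range(t, -1, -l))                   # t, t-l, t-2l, ... back to the start
    return local, strided

print(strided_subsets(14, 16))
# ([11, 12, 13, 14], [14, 10, 6, 2])
```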

Pattern B: Fixed (for data without clear periodicity)

Best for text, code, and other sequential data where there's no natural stride. Instead of giving every token long-range access, it designates specific "summary" positions that act as information relays.

The sequence is divided into blocks of size $\ell = \sqrt{n}$:

  • Subset 1 (block-local): each token attends to all tokens within its own block of $\ell$ tokens
  • Subset 2 (summary): each token attends to the last $c$ positions (summary positions) of every preceding block

The last $c$ positions of each block are special — they serve as "relay" tokens that aggregate and broadcast information from their block.

Concrete example with $n = 16$, $\ell = 4$, $c = 1$ (last position of each block is summary):

  • Blocks: $[0,1,2,\mathbf{3}]\ [4,5,6,\mathbf{7}]\ [8,9,10,\mathbf{11}]\ [12,13,14,\mathbf{15}]$ (bold = summary positions)
  • Token at position 9:
    • Subset 1 (block-local): $\{8, 9, 10, 11\}$ — its own block
    • Subset 2 (summary): $\{3, 7\}$ — summary positions from all prior blocks

Regular tokens can't directly see distant regular tokens. But summary tokens see other summary tokens, so information routes through them.
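The analogous sketch for the fixed pattern, reproducing the position-9 example (the causal mask would additionally hide any future positions within the block):

```python
import math

def fixed_subsets(t: int, n: int, c: int = 1):
    """Sparse Transformer fixed pattern: block-local + summary (relay) positions."""
    l = int(math.sqrt(n))                                # block size: l = sqrt(n)
    block_start = (t // l) * l
    block_local = list(range(block_start, min(block_start + l, n)))   # its own block
    # last c positions of every *preceding* block act as summary/relay tokens
    summary = [b * l + j for b in range(t // l) for j in range(l - c, l)]
    return block_local, summary

print(fixed_subsets(9, 16))
# ([8, 9, 10, 11], [3, 7])
```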

The Key Difference

| | Strided | Fixed |
|---|---|---|
| Long-range access | Every token gets it (sparse, evenly-spaced) | Only summary tokens get it |
| Information flow | Direct (sparse) from any token to any distant token | Hierarchical — through relay positions |
| Best for | Data with periodic structure (images, audio) | Sequential data (text, code) |
| Tradeoff | Democratic but noisy | Structured but creates bottlenecks at relays |

Step 2: Combine the Two Subsets

Once you've chosen your pattern (strided OR fixed), you have two attention subsets per position. The paper proposes three methods to combine these subsets into a functioning attention mechanism:

Method 1: Interleaved Heads (merged attention)

Use a single attention head but merge both subsets into one set of attended positions. Token $t$ attends to $A_1(t) \cup A_2(t)$ — the union of both subsets — in a single attention operation.

Simple but limits per-head capacity. Each head must handle both local and long-range patterns simultaneously.

Method 2: Multi-Head Split

Different heads within the same layer handle different subsets. For example, in a layer with 8 heads:

  • Heads 1-4 attend using subset 1 (local context)
  • Heads 5-8 attend using subset 2 (long-range context)

This lets each head specialize — some heads focus purely on local syntax, others on long-range dependencies. The model learns to route information appropriately through head specialization.

Method 3: Interleaved Layers (the paper's primary approach)

Alternate which subset each layer uses:

  • Odd layers: every head attends using subset 1
  • Even layers: every head attends using subset 2

This is what the paper calls "factorized attention" and is their primary design. The idea: information from any position can reach any other position in exactly 2 layers (one hop through each subset).

The 2-Hop Connectivity Guarantee

This is the elegant insight that makes factorized sparse patterns work despite each token seeing only $O(\sqrt{n})$ positions in any individual layer:

Any token can reach any other token in at most 2 hops.

With the strided pattern using interleaved layers:

  • Layer $\ell$ (subset 1, local): token $t$ attends to its local neighborhood $[t-\sqrt{n},\ t]$
  • Layer $\ell+1$ (subset 2, strided): token $t$ attends to positions $t, t-\sqrt{n}, t-2\sqrt{n}, \ldots$

Now consider: can token at position 100 reach token at position 7?

  • Layer $\ell$: position 7 is attended to by positions in $[7, 7+\sqrt{n}]$ (some intermediate token picks it up through its local window)
  • Layer $\ell+1$: that intermediate token can reach position 100 via the strided subset

In the fixed pattern, the 2-hop works through relay positions:

  • Layer $\ell$ (subset 1): a distant token's information reaches its block's summary position
  • Layer $\ell+1$ (subset 2): the query token attends to that summary position

This means: despite each layer only computing $O(n\sqrt{n})$ attention scores, full connectivity is preserved — just routed through a 2-layer path instead of being direct.
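The claim is easy to check mechanically. The sketch below (re-declaring the strided subsets so it stands alone) verifies that every earlier position is reachable from every later position in two hops: one local hop, then one strided hop.

```python
import math

def local_subset(t: int, n: int) -> set:
    l = int(math.sqrt(n))
    return set(range(max(0, t - l + 1), t + 1))          # the l positions ending at t

def strided_subset(t: int, n: int) -> set:
    l = int(math.sqrt(n))
    return set(range(t, -1, -l))                         # t, t-l, t-2l, ...

n = 64                      # l = sqrt(64) = 8 exactly
ok = True
for src in range(n):
    for target in range(src + 1):                        # causal: only earlier positions
        # layer L: some relay gathers `target` through its local window;
        # layer L+1: `src` reaches that relay through its strided subset
        ok &= any(target in local_subset(mid, n) for mid in strided_subset(src, n))
print("all pairs reachable in 2 hops:", ok)              # True
```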

Why It Matters Historically

Sparse Transformer proved three foundational ideas:

  1. You can drop most attention connections without destroying model quality
  2. Factorized patterns with $O(\sqrt{n})$ connections per position per layer can maintain full connectivity via multi-hop
  3. Information can route through "relay" positions — not every token needs direct access to every other token

These ideas echo through everything that followed: sliding window's receptive field growth across layers is essentially the same multi-hop insight, and NSA's compressed global tokens are a learned version of fixed pattern's summary positions.

Why It Didn't Dominate

Despite the elegant theory, practical issues held it back:

  • Fixed patterns are content-blind: the strided pattern might skip the one token that actually matters. A pronoun resolution might need exactly the token that falls between stride positions.
  • Irregular memory access: GPUs are optimized for dense, contiguous operations. Sparse indexing creates scattered memory reads that kill throughput — you save FLOPs but not necessarily wall-clock time.
  • The quadratic cost wasn't yet painful: in 2019, typical sequence lengths were 512-2048 tokens. At these lengths, full attention is fast enough that sparse patterns add complexity without meaningful speedup.
  • No standard implementation: each paper rolled its own CUDA kernels, making adoption hard.

The ideas lived on as foundational intuition, but the specific patterns didn't become standard. The field needed simpler, more hardware-friendly approaches.

Want the full details? The paper has additional nuances — how they handle the autoregressive mask within each pattern, training with mixed-precision, and their specific image/music generation results. Read the full paper: Generating Long Sequences with Sparse Transformers


Sliding Window Attention (Mistral, 2023)

The simplest form of sparse attention, and the one that actually shipped at scale in decoder-only LLMs: each token only attends to its $w$ nearest neighbors.

The idea itself originated in Longformer (2020) for encoder models — BERT-style architectures processing long documents. But encoder sliding window is a different beast: bidirectional attention within windows, no autoregressive constraints, no KV cache concerns. It took three years for someone to realize this could work beautifully for autoregressive generation when combined with a rolling KV buffer. That someone was Mistral.

Want to go deeper on Longformer? We cover the full paper — including its dilated windows and global token design — in: Longformer Paper Explained →

$$\text{Attention}(q_t) = \text{softmax}\left(\frac{q_t \cdot K_{[t-w:t]}^T}{\sqrt{d_k}}\right) \cdot V_{[t-w:t]}$$

Instead of attending to all $n$ tokens, each token sees only a window of $w$ tokens. Compute becomes $O(n \cdot w)$ — linear in sequence length when $w$ is fixed.
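As a concrete reference, here is a minimal single-head version in NumPy (an illustrative sketch, not Mistral's implementation). For clarity it builds the full score matrix and masks it; a real kernel would only compute the $w$ scores each query actually needs.

```python
import numpy as np

def sliding_window_attention(Q, K, V, w):
    """Single-head causal attention where each query sees only the last w keys."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                      # (n, n) — dense here for clarity
    pos = np.arange(n)
    # key j is visible to query t only if t - w < j <= t
    mask = (pos[None, :] <= pos[:, None]) & (pos[None, :] > pos[:, None] - w)
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

# toy usage
rng = np.random.default_rng(0)
n, d, w = 16, 8, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
print(sliding_window_attention(Q, K, V, w).shape)      # (16, 8)
```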

Why This Works: The Locality Assumption

Sliding window attention bets that most useful information is local. For a token at position 1000, tokens at positions 997-1003 are usually more relevant than the token at position 3.

This is empirically well-supported. If you analyze trained attention patterns in full-attention models, the vast majority of attention weight concentrates on nearby tokens. Only a handful of heads in a handful of layers consistently attend to distant positions. Sliding window sacrifices those rare long-range connections in exchange for dramatic efficiency gains.

The Effective Context: Larger Than You Think

Here's the crucial insight that makes sliding window viable: information propagates across layers.

A single layer with window $w$ can only "see" $w$ tokens. But with $L$ stacked layers, information flows through intermediate positions. Each layer passes information forward by $w$ positions, giving an effective receptive field of $L \times w$ tokens.

With Mistral 7B's configuration (32 layers, $w = 4096$):

  • Per-layer receptive field: 4096 tokens
  • Effective receptive field: $32 \times 4096 = 131{,}072$ tokens

A token at position 128K can theoretically access information from position 0, just through a chain of intermediate relays rather than direct attention. It's like a bucket brigade — no single person reaches across the room, but the message still gets there.

Sliding window receptive field growth: a single token at Layer 5 can reach 11 input tokens through 5 layers with window size 3. At Mistral 7B scale (32 layers, w=4096), this covers the full 128K context.
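Both receptive-field figures are easy to verify with a two-line sketch, assuming each layer sees the current token plus the previous $w-1$ positions (the convention used in the caption above). Under the looser $L \times w$ count used earlier, Mistral's number comes out as 131,072 instead; either way it is roughly 128K.

```python
def receptive_field(layers: int, window: int) -> int:
    """Input positions one output token can draw on after stacking sliding-window
    layers, each seeing itself plus the previous window-1 positions."""
    return 1 + layers * (window - 1)

print(receptive_field(5, 3))        # 11, the toy example in the figure
print(receptive_field(32, 4096))    # 131041, roughly Mistral 7B's 128K effective reach
```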

The KV Cache Bonus

Sliding window doesn't just save compute — it caps KV cache growth. Instead of the cache growing linearly with sequence length, it's bounded at $w$ entries per head per layer.

| Sequence Length | Full Attention Cache | Sliding Window Cache (w=4096) |
|---|---|---|
| 4K tokens | 4K entries | 4K entries |
| 32K tokens | 32K entries | 4K entries |
| 128K tokens | 128K entries | 4K entries |

Mistral implements this as a "rolling buffer" — once the cache reaches size ww, new entries overwrite the oldest ones. Memory usage is constant regardless of how long the conversation gets.
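A minimal sketch of that rolling buffer (illustrative, not Mistral's actual cache code): once the buffer is full, token $t$ simply overwrites slot $t \bmod w$.

```python
import numpy as np

class RollingKVCache:
    """Fixed-size KV cache for sliding window attention: only the last w entries survive."""

    def __init__(self, w: int, d: int):
        self.w, self.t = w, 0                  # window size; total tokens seen so far
        self.keys = np.zeros((w, d))
        self.values = np.zeros((w, d))

    def append(self, k_t, v_t):
        slot = self.t % self.w                 # once full, overwrite the oldest slot
        self.keys[slot], self.values[slot] = k_t, v_t
        self.t += 1

    def contents(self):
        """K/V of the (at most) w most recent tokens; slot order, not temporal order,
        so positional info (e.g. RoPE) must be applied before caching."""
        n = min(self.t, self.w)
        return self.keys[:n], self.values[:n]

cache = RollingKVCache(w=4, d=2)
for t in range(10):
    cache.append(np.full(2, t), np.full(2, t))
print(cache.contents()[0][:, 0])   # [8. 9. 6. 7.] — only the last 4 tokens remain
```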

Who Uses Sliding Window

  • Mistral 7B (September 2023) — popularized it for decoder-only LLMs. Window size 4096.
  • Mistral Small / Medium — same approach at larger scale
  • Gemma 3 — in alternating layers (more on this next)
  • Phi-3 — uses sliding window in conjunction with other techniques
  • Mistral Large — later versions moved to full attention, suggesting sliding window has limits at frontier scale

The Limitation

Sliding window has one clear failure mode: tasks requiring precise long-range retrieval.

"What was the 3rd item in the list from 10,000 tokens ago?" — if that information has degraded through the bucket-brigade propagation across layers, the model will fail. Information doesn't pass through relays perfectly; it gets mixed, compressed, and potentially lost at each hop.

For summarization, code generation, and conversational tasks, this rarely matters. For needle-in-a-haystack retrieval over very long contexts, it does. This limitation motivated the next evolution.


Global + Local Hybrid (Gemma 3, Google DeepMind, 2025)

The fix for sliding window's long-range blindness is intuitive: keep most layers local, but let a few layers see everything.

The Idea

Instead of choosing between "full attention everywhere" (expensive) and "sliding window everywhere" (limited), alternate:

  • Local layers: sliding window attention, bounded cache, fast
  • Global layers: full attention, unbounded cache, provides long-range "highways"

Most information flow in language is local — the next word depends mainly on recent context. So local layers handle the common case efficiently. But for the 10% of cases that need long-range connections, the occasional global layer provides direct access.

Gemma 3's Design

Gemma 3 implements this with a clean ratio: 1 global attention layer for every 5 sliding-window layers.

In a 26-layer model:

  • 22 layers use sliding window (local)
  • 4 layers use full attention (global)

The math works out beautifully (a quick sanity check follows this list):

  • Memory: 22/26 layers have bounded cache → ~85% of total cache savings preserved
  • Long-range ability: every ~5 layers, information can jump to any position
  • No special tokens needed: global layers attend over the full sequence normally
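Here is that quick check, assuming a 128K-token sequence, Gemma 3's 26-layer split, a 1,024-token local window, and cache size measured in entries per layer:

```python
seq_len, window = 128_000, 1_024
local_layers, global_layers = 22, 4

full_cache = (local_layers + global_layers) * seq_len            # every layer caches everything
hybrid_cache = global_layers * seq_len + local_layers * window   # only global layers keep growing

print(f"hybrid cache = {hybrid_cache / full_cache:.1%} of full-attention cache "
      f"({1 - hybrid_cache / full_cache:.1%} saved)")
# hybrid cache = 16.1% of full-attention cache (83.9% saved)
```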

Why the Highway Analogy Holds

Think of it like a road network:

  • Sliding window layers are local streets — fast, direct connections between nearby positions. High throughput for short distances.
  • Global attention layers are highways — they let information jump from any position to any other position, but you only build a few of them.

Most traffic (information flow) is local — it moves a few positions forward per layer. But when a token at position 50K needs to attend to something at position 100, the next global layer provides a direct highway for that.

The ratio (1:5 in Gemma 3) is tunable. More global layers improve long-range tasks but increase memory. Fewer global layers save more but risk information degradation over very long ranges.

The Implementation

From an engineering perspective, hybrid attention is simple:

```python
# One global layer for every five local layers (illustrative pseudocode).
for i in range(num_layers):
    if i % 6 == 0:
        # Global layer: full attention over the entire sequence, growing KV cache
        x = full_attention(x, full_kv_cache[i])
    else:
        # Local layer: sliding window attention over a fixed-size rolling buffer
        x = sliding_window_attention(x, rolling_cache[i], window_size=4096)
```

The main complexity is managing two different cache types — rolling buffers for local layers, growing caches for global layers. But the global layers are few enough that their cache growth is manageable.


What Models Actually Use

| Model | Mechanism | Sparsity Pattern / Window | Max Context |
|---|---|---|---|
| Mistral 7B (2023) | Pure SWA | Fixed 4,096-token sliding window | 32K |
| Jamba 1.5 (2024) | Hybrid SSM/Attention | 1:7 ratio (Transformer to Mamba layers) | 256K |
| Gemma 3 (2025) | SWA/Global Hybrid | 5:1 ratio (local window of 1,024 tokens) | 128K |
| Phi-4-mini-flash (2025) | Hybrid SambaY | Mamba + Sliding Window Attention | 64K |
| Gemma 4 (2025) | Hybrid SWA/Global | Sliding window + global layers (variant-dependent window: 512 or 1024) | 256K |
| GLM-5 / 5.1 (2025) | DeepSeek Sparse Attention | Learned Top-K selection (adopted from DeepSeek) | 128K+ |
| Qwen 3 Next (2025) | Hybrid Linear/Dense | Gated DeltaNet (linear) + Gated Attention (3:1 ratio) | 1M |
| DeepSeek-V4 (2025) | Hybrid CSA/HCA | Compressed Sparse Attention + Heavily Compressed Attention | 1M |

Notice the pattern: early models picked one technique (Mistral → pure sliding window). But from 2025 onward, every major model is a hybrid — mixing multiple sparsity mechanisms, often combining both lineages we've covered in this series.


The Composability Picture: Two Lineages Become One

Throughout this series, we've traced two independent evolutionary paths:

  • Part 2 — The Sharing Lineage: reducing what we store (MHA → MQA → GQA → MLA)
  • Part 3 — The Sparsity Lineage: reducing what we compute (Full → Sparse → Sliding Window → Global/Local)

But here's what makes the current moment exciting: these lineages are not mutually exclusive. They attack orthogonal dimensions of the same bottleneck — one compresses the representations, the other reduces the connections. A model can (and increasingly does) use both simultaneously.

| Lineage | What it reduces | Example |
|---|---|---|
| Sharing/Compression | KV cache memory | MLA compresses KV into a low-rank latent |
| Sparsity/Routing | Attention compute | SWA skips distant scores; global layers provide highways |

When you combine them, the savings multiply:

  • MLA means each layer's KV cache is small (fewer bytes to store and load from memory)
  • Sliding window means most layers only attend locally (fewer scores to compute)
  • Global layers ensure long-range info still flows (quality preserved)

This is exactly what we see in the latest architectures.


The Latest Trend: Combining Everything

The 2025 frontier models aren't choosing between sharing and sparsity — they're layering every trick together:

DeepSeek-V4 introduces a hybrid attention architecture with two new mechanisms: Compressed Sparse Attention (CSA) for local compression and Heavily Compressed Attention (HCA) for aggressive global compression. Together they reduce attention FLOPs to ~27% and KV cache to ~10% of V3's cost — enabling 1M context with manageable compute and memory.

Gemma 4 builds on Gemma 3's hybrid pattern — sliding window in most layers, full attention in a few global layers — now supporting 256K context. Different model sizes use different window sizes (512 for smaller variants, 1024 for larger), tuning the efficiency/quality tradeoff per scale.

Qwen 3 Next goes further by replacing most attention layers entirely with linear attention (Gated DeltaNet) at a 3:1 ratio to standard gated attention layers — a hybrid that's sub-quadratic by construction while retaining dense attention where it matters most. This achieves ~1M context with YaRN scaling.

GLM-5 adopts DeepSeek's sparse attention mechanism (DSA) — content-aware sparsity where a learned router selects which tokens are worth attending to at runtime, rather than relying on fixed patterns.

The pattern is clear: the best architectures of 2025 are those that combine compressed KV representations (sharing lineage) with intelligent sparse patterns (sparsity lineage), achieving long contexts (200K–1M tokens) that would be unthinkable with either approach alone.

Want to see how all these pieces fit together in one model? We take Gemma 4 apart layer by layer — its hybrid attention, MoE routing, and every architectural decision — in: Dissecting Gemma 4: Architecture from the Ground Up →

The Natural Question: How Do You Design Learned Sparse Attention?

All of these hybrid approaches share a common thread: they don't just use fixed patterns (like Sparse Transformer's strided/fixed) or simple heuristics (like sliding window's locality assumption). They use learned routing — letting the model itself decide which tokens matter for each query.
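To make "learned routing" concrete, here is a deliberately simplified sketch. It is not NSA's actual mechanism (NSA scores contiguous blocks through compressed keys and adds compression and sliding-window branches); it just shows the core idea of scoring keys per query and attending only to the top-k.

```python
import numpy as np

def topk_sparse_attention(q, K, V, k=64):
    """Content-aware sparsity for one query: select the k most relevant keys and attend
    only to them. Real systems make the selection step itself cheap (block-level or
    compressed keys); here it is done naively for clarity."""
    scores = K @ q / np.sqrt(q.shape[-1])           # (n,) relevance score per key
    keep = np.argpartition(scores, -k)[-k:]         # indices of the k largest scores
    w = np.exp(scores[keep] - scores[keep].max())
    w /= w.sum()                                    # softmax over the selected keys only
    return w @ V[keep]                              # value aggregation skips the other n-k keys

rng = np.random.default_rng(0)
n, d = 4096, 64
q, K, V = rng.normal(size=d), rng.normal(size=(n, d)), rng.normal(size=(n, d))
print(topk_sparse_attention(q, K, V).shape)         # (64,)
```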

This idea was formalized most cleanly in Native Sparse Attention (NSA) by DeepSeek (2025). NSA combines three paths — compressed global context, learned token selection, and sliding window — into a single hardware-efficient mechanism. It's arguably the blueprint that the models above are all iterating on.

We'll dissect NSA in full detail — the compression module, the routing mechanism, the hardware-aligned design, and why it works so well — in a dedicated paper explanation.

Deep dive available: Native Sparse Attention: Hardware-Aligned Learned Sparsity → — the full paper dissection with code implementations.


The Remaining Bottleneck: Systems-Level Optimization

We've now covered two lineages of architectural innovation:

  • Part 2: Reducing what we store (sharing → compression)
  • Part 3: Reducing what we compute (sparsity → learned routing)

But there's a third axis we haven't touched. Even with full, dense, quadratic attention — can we make it faster by understanding how GPUs actually work?

Flash Attention doesn't change the attention algorithm at all. It computes the exact same $QK^T$ matrix and same softmax. But by reordering operations to exploit the GPU's memory hierarchy — keeping data in fast SRAM instead of slow HBM — it achieves 2-4× speedup with zero approximation.

This is a fundamentally different kind of optimization: not changing what we compute, but how the hardware executes it. It's the reason most models can still afford full attention at 32-128K context lengths — and it's what we'll cover next.



Next in the series: Attention Part 4 — Flash Attention: Making GPUs Actually Work — how to make attention 2-4× faster without changing a single weight, by understanding SRAM, HBM, and the art of tiling.

This post is part of The Gradient Descent through Transformers — a series dissecting every component of the modern transformer stack.