Part 4 of 8

The Gradient Descent through Transformers


Attention Part 1 — The Mechanism That Changed Everything

May 6, 2026 · 30 min read

This is Part 4 of The Gradient Descent through Transformers — a series where I walk through every component of the modern transformer stack, how it evolved from 2017 to 2026, and why each piece matters.

Previously: Positional Encoding Part 2 — RoPE, ALiBi, and the Quest for Length Generalization


We've covered how text becomes tokens (Part 1) and how transformers know where those tokens are (Part 2, Part 3). Now we get to the heart of the transformer: self-attention — the mechanism that lets every token look at every other token and decide what's relevant.

This is the single most important idea in the transformer. Everything else — the feedforward layers, the normalization, the residual connections — is supporting infrastructure. Self-attention is the engine.

What Problem Does Attention Solve?

Before transformers, the dominant models for sequence processing were RNNs (and their variants LSTM, GRU). They had a fundamental limitation: information had to flow sequentially.

If you're processing the sentence "The animal didn't cross the street because it was too tired", and you want to figure out what "it" refers to — you need to connect "it" back to "animal", which is 7 tokens earlier. In an RNN, the information about "animal" has to survive through 7 sequential processing steps, each one risking information loss. For longer documents, this problem becomes catastrophic.

"Why Not Just Add Attention to RNNs?"

This is exactly what happened first. Bahdanau et al. (2014) added an attention mechanism on top of an RNN encoder-decoder for machine translation. The decoder could directly look back at all encoder hidden states instead of relying on a single compressed vector. It worked — translation quality improved significantly.

But adding attention to RNNs doesn't fix the core problem. The RNN encoder still processes tokens sequentially, one by one. Token 100 can't be computed until tokens 1-99 are done. This means:

  1. No parallelism during training. You can't use GPUs efficiently because each step depends on the previous. A 1000-token sequence requires 1000 sequential steps — you're using your expensive GPU like a calculator.

  2. Encoder representations are still built through a bottleneck. By the time the RNN reaches token 100, its hidden state has been updated 100 times. The representation of token 1 has been "compressed" through all those updates. Attention helps the decoder look back at all encoder states, but those states themselves were built through the sequential bottleneck.

  3. Gradients still flow through the sequential path. Backpropagation through 1000 RNN steps still risks vanishing or exploding gradients, even with LSTM/GRU. Attention provides a gradient shortcut for the decoder, but doesn't fix the encoder's gradient problem.

The Transformer Insight: Attention Can Replace the RNN Entirely

The breakthrough of "Attention Is All You Need" (Vaswani et al., 2017) wasn't inventing attention — it was showing that attention is powerful enough to be the entire model. No RNN needed. Self-attention replaces both the sequential processing AND the cross-sequence attention in one mechanism.

Every token's representation is computed in parallel, looking at all other tokens simultaneously. No sequential bottleneck:

  • Training parallelizes perfectly: all tokens in a sequence can be processed at once. GPUs love this.
  • No information decay: token 1's representation directly participates in token 1000's computation — no 999 steps of compression.
  • Gradient shortcuts everywhere: attention provides direct gradient paths between any two tokens, regardless of distance.

The name of the paper says it all: you don't need recurrence, you don't need convolution. Attention is all you need.

Building Self-Attention from Scratch

Let's build the mechanism step by step. We have a sequence of $n$ tokens, each represented as a $d$-dimensional embedding vector. The goal: produce a new representation for each token that incorporates information from all other tokens, weighted by relevance.

The Intuition: Questions, Answers, and Information

The naming convention — Query, Key, Value — comes from a database analogy:

  • Query (Q): "What am I looking for?" — what this token wants to know
  • Key (K): "What do I contain?" — what this token advertises about itself
  • Value (V): "What information do I provide?" — the actual content to pass along if selected

When token $i$ wants to gather information from the sequence, it broadcasts its query $q_i$ to all other tokens. Each other token $j$ responds with its key $k_j$. The dot product $q_i \cdot k_j$ measures how well the query matches the key — how relevant token $j$ is to token $i$'s question. The higher the match, the more of token $j$'s value $v_j$ gets passed to token $i$.

This is like searching a library: your query is the topic you're researching, the keys are book titles, and the values are the book contents. You look at all titles (keys), find which match your topic (query), and read those books (values).

Step 1: Project into Q, K, V

Each token embedding $x_i$ gets transformed into three different vectors by three learned weight matrices:

$$q_i = x_i W_Q, \quad k_i = x_i W_K, \quad v_i = x_i W_V$$

where $W_Q, W_K \in \mathbb{R}^{d \times d_k}$ and $W_V \in \mathbb{R}^{d \times d_v}$.

Why project at all? Why not use the raw embeddings?

You could compute attention directly on the token embeddings: $\text{score}(i, j) = x_i \cdot x_j$. This would just measure how similar two tokens are in embedding space. But this is too rigid — it would mean "cat" always attends to "cat" the most (identity has the highest dot product). There's no way for the model to learn that "sat" should attend to "cat" (its subject) rather than to another "sat".

The projection matrices $W_Q$, $W_K$, $W_V$ are learned transformations that let the model decide: "when I'm a query, I look like this; when I'm a key, I advertise like this." They give the model the flexibility to learn arbitrary attention patterns — not just token similarity.

Why THREE separate projections? Why not one or two?

Why not one projection? If we used the same matrix for Q and K ($q_i = x_i W$, $k_j = x_j W$), then $\text{score}(i,j) = (x_i W) \cdot (x_j W)$. This is symmetric — token $i$ attending to token $j$ would give the same score as $j$ attending to $i$. But attention shouldn't be symmetric! In "the cat sat", "sat" needs to strongly attend to its subject "cat" (to know who sat), but "cat" doesn't necessarily need to attend back to "sat" with the same strength.

Separate $W_Q$ and $W_K$ matrices break this symmetry. Each token can independently control what it's looking for (Q) and what it advertises (K).

Why a separate V? The Key says "attend to me because I'm relevant." But the information you actually want to copy might be different from what made the match. Think of the library analogy: you search by title (Key matches Query), but what you read is the content (Value). The word "Paris" might be attended to because its Key signals "I am a location" — but the Value it passes along might encode "capital, France, Europe, Eiffel Tower." The Key is the index; the Value is the data.

If we stacked all tokens as a matrix $X \in \mathbb{R}^{n \times d}$:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

Step 2: Compute Attention Scores

Why dot product?

We need a way to measure "how relevant is token $j$ to token $i$?" The dot product is the simplest operation that does this — it measures alignment between two vectors. If $q_i$ and $k_j$ point in the same direction (large positive dot product), they match well. If they're orthogonal (dot product ≈ 0), they're unrelated. If they point in opposite directions (large negative), they actively repel each other.

Other options exist (additive attention: $v^T \tanh(W_1 q + W_2 k)$, used by Bahdanau), but dot product is faster (just a matrix multiply) and works just as well in practice.

For each pair of tokens $(i, j)$, compute how much token $i$ should attend to token $j$:

$$\text{score}(i, j) = q_i \cdot k_j = \sum_{l=1}^{d_k} q_{i,l} \, k_{j,l}$$

In matrix form, we compute ALL pairwise scores at once:

$$S = QK^T \in \mathbb{R}^{n \times n}$$

This is an $n \times n$ matrix where entry $S_{ij}$ is the attention score from token $i$ to token $j$. This is also where the quadratic cost comes from — but we'll get to that later.

Step 3: Scale

Here's a subtle but critical step that's easy to overlook. We divide the scores by $\sqrt{d_k}$:

$$S_{\text{scaled}} = \frac{QK^T}{\sqrt{d_k}}$$

Why? Without scaling, the dot products grow in magnitude with the dimension $d_k$. If $q$ and $k$ are vectors with entries drawn from a standard normal distribution, their dot product has variance $d_k$. For $d_k = 64$, the dot products can easily reach ±16 (two standard deviations).

When these large values go through softmax (next step), the softmax becomes extremely peaked — almost all the weight goes to one token, and the gradients become vanishingly small. The model can't learn.

Dividing by $\sqrt{d_k}$ normalizes the variance back to 1, keeping the softmax in a regime where it produces smooth distributions and useful gradients.

Let's see this concretely:

Without scaling ($d_k = 64$): scores might be $[12.3, -8.1, 15.7, 1.2]$ → softmax → $[0.03, 0.00, 0.97, 0.00]$ — all attention on one token, barely any gradient for the others.

With scaling (÷ 8): scores become $[1.54, -1.01, 1.96, 0.15]$ → softmax → $[0.35, 0.03, 0.54, 0.09]$ — much smoother distribution, gradients flow to all positions.
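To see the effect numerically, here's a quick sketch in plain NumPy, using the example scores above:

```python
import numpy as np

def softmax(x):
    # subtract the max for numerical stability; doesn't change the result
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([12.3, -8.1, 15.7, 1.2])
d_k = 64

unscaled = softmax(scores)               # one token dominates
scaled = softmax(scores / np.sqrt(d_k))  # much smoother distribution
print(unscaled.round(2), scaled.round(2))
```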

Step 4: Softmax (Normalize)

Why softmax? Why not just use the raw scores?

The raw scores $S_{ij}$ can be any real number — positive, negative, large, small. We need to turn them into weights that:

  • Are all non-negative (you can't attend "negatively" to a token)
  • Sum to 1 (so the output is a weighted average, not an unbounded sum)
  • Preserve the ranking (higher score = more attention)

Softmax does exactly this. It maps any vector of real numbers to a probability distribution:

$$A = \text{softmax}(S_{\text{scaled}}) \quad \text{where} \quad A_{ij} = \frac{e^{S_{ij}/\sqrt{d_k}}}{\sum_{l=1}^{n} e^{S_{il}/\sqrt{d_k}}}$$

Each row $A_i$ is now a probability distribution over all tokens — representing how much token $i$ attends to each other token.

Why not just normalize by dividing by the sum? Because raw scores can be negative, and dividing by a sum of mixed-sign numbers gives meaningless results. The exponential in softmax ensures everything is positive first, then normalizes. It also sharpens differences: a score of 5 vs 3 becomes an $e^5 / e^3 \approx 7.4\times$ difference, making the model's choices more decisive.

Step 5: Weighted Sum of Values

Why a weighted sum? Why not just pick the top-scoring token?

You could do "hard" attention — attend 100% to the highest-scoring token and 0% to everything else. But this has two problems:

  1. It's not differentiable. The argmax operation has zero gradient almost everywhere. The model can't learn through backpropagation which tokens to attend to.

  2. It throws away information. Language is ambiguous. "it" might refer to "animal" (70% likely) or "street" (30% likely). A soft weighted sum preserves this uncertainty — the output is a blend reflecting both possibilities. Hard attention would force a premature commitment.

The soft weighted sum gives us the best of both worlds: it's differentiable (gradients flow smoothly), and it allows the model to hedge when uncertain.

The final output for each token:

$$\text{output}_i = \sum_{j=1}^{n} A_{ij} \, v_j$$

In matrix form:

$$\text{Output} = A V = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

This is the complete self-attention formula. Token $i$'s output is a blend of all value vectors, where the blending weights come from how well token $i$'s query matches each other token's key.

Putting It All Together

For the sentence "The cat sat" with $d = 4$ (embedding dimension) and single-head attention ($d_k = d = 4$):

Input embeddings (each token is a 4-dimensional vector):

$$X = \begin{bmatrix} x_{\text{The}} \\ x_{\text{cat}} \\ x_{\text{sat}} \end{bmatrix} \in \mathbb{R}^{3 \times 4}$$

Project to Q, K, V (no dimension reduction in single-head — $W_Q$, $W_K$, $W_V$ are all $4 \times 4$):

$$Q = XW_Q \in \mathbb{R}^{3 \times 4}, \quad K = XW_K \in \mathbb{R}^{3 \times 4}, \quad V = XW_V \in \mathbb{R}^{3 \times 4}$$

Attention scores ($QK^T$) — a $3 \times 3$ matrix (one score for every token pair):

$$S = \begin{bmatrix} q_{\text{The}} \cdot k_{\text{The}} & q_{\text{The}} \cdot k_{\text{cat}} & q_{\text{The}} \cdot k_{\text{sat}} \\ q_{\text{cat}} \cdot k_{\text{The}} & q_{\text{cat}} \cdot k_{\text{cat}} & q_{\text{cat}} \cdot k_{\text{sat}} \\ q_{\text{sat}} \cdot k_{\text{The}} & q_{\text{sat}} \cdot k_{\text{cat}} & q_{\text{sat}} \cdot k_{\text{sat}} \end{bmatrix}$$

Scale by $\sqrt{d_k} = \sqrt{4} = 2$, softmax each row, multiply by V → each token gets a new 4-dimensional representation that's a weighted blend of all value vectors.

Note: in multi-head attention (covered later), $d_k$ becomes smaller than $d$ because the dimension is split across heads. But for single-head attention, $d_k = d$.
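The whole pipeline fits in a few lines. A minimal NumPy sketch (random weights, so the attention pattern itself is meaningless; it just shows the shapes and the mechanics):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """softmax(Q K^T / sqrt(d_k)) V, following steps 1 through 5."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # step 1: project
    d_k = K.shape[-1]
    S = Q @ K.T                           # step 2: pairwise scores (n, n)
    A = softmax(S / np.sqrt(d_k))         # steps 3-4: scale, then normalize
    return A @ V, A                       # step 5: weighted sum of values

rng = np.random.default_rng(0)
n, d = 3, 4                               # "The cat sat", d = 4
X = rng.normal(size=(n, d))
W_Q, W_K, W_V = rng.normal(size=(3, d, d))
out, A = self_attention(X, W_Q, W_K, W_V)
print(out.shape, A.sum(axis=-1))          # (3, 4), each row of A sums to 1
```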

Bidirectional vs Causal: Two Worlds of Self-Attention

Everything we've described so far is bidirectional self-attention — every token attends to every other token, in both directions. Token 3 can look at token 7, and token 7 can look at token 3. The full $n \times n$ attention matrix is computed without restriction.

This is what encoder models like BERT use. An encoder's job is to understand text that's already complete — you feed it an entire sentence, and it builds a deep representation of every token using context from both the left and the right. When BERT processes "The cat sat on the mat", the word "sat" can look at both "cat" (to its left) and "mat" (to its right). It sees the whole picture.

Why Decoders Can't See Everything

But decoder models (GPT, Llama, Claude) have a fundamentally different job: they generate text, one token at a time. When the model is predicting the 5th word, the 6th word doesn't exist yet — it hasn't been generated. Allowing the model to attend to future tokens during training would be cheating: it would learn to "peek" at the answer instead of learning to predict it.

This is the core difference:

  • Encoder (BERT): understands complete text → sees everything → bidirectional
  • Decoder (GPT): generates text left-to-right → can only see the past → needs restriction

Causal Self-Attention

The solution is simple but fundamental: mask out all future positions in the attention computation. Token $i$ can only attend to tokens at positions $\leq i$. This is called a causal mask (because information flows only in the causal direction — past to present, never future to past).

Mechanically, we add $-\infty$ to the upper-triangular part of the score matrix before softmax:

$$S_{\text{masked}} = \begin{bmatrix} s_{11} & -\infty & -\infty \\ s_{21} & s_{22} & -\infty \\ s_{31} & s_{32} & s_{33} \end{bmatrix}$$

After softmax, $e^{-\infty} = 0$, so those positions get zero attention weight:

  • Token 1 can only see itself
  • Token 2 sees tokens 1 and 2
  • Token 3 sees all three

The key insight: during training, we can still process the entire sequence in parallel (all positions computed at once), but each position's attention is restricted to only look backward. This gives us the parallelism of transformers while maintaining the autoregressive property needed for generation.
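Here's the mask mechanically, as a small NumPy sketch (dummy scores; the point is the triangular structure):

```python
import numpy as np

n = 3
S = np.arange(n * n, dtype=float).reshape(n, n)      # dummy score matrix

future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True where j > i
S_masked = np.where(future, -np.inf, S)              # -inf on future positions

e = np.exp(S_masked - S_masked.max(axis=-1, keepdims=True))
A = e / e.sum(axis=-1, keepdims=True)                # row-wise softmax

# e^{-inf} = 0: token i gets nonzero weight only on positions <= i
print((A > 0).sum(axis=1))  # [1 2 3]
```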

Why Almost All Modern LLMs Use Causal Attention

In 2024-2026, virtually every major LLM — GPT-4, Claude, Llama, Mistral, Gemini — is a decoder-only model with causal self-attention. Encoder-only models (BERT) and encoder-decoder models (T5) still exist for specific tasks, but the "GPT architecture" (decoder-only, causal attention, autoregressive generation) won for general-purpose language models.

Why? Because generation is the hardest and most general task. A model that can generate coherent text can also be prompted to classify, summarize, translate, and reason — without architectural changes. The causal mask is a small constraint that enables this universal capability.

Padding Mask

One other mask worth mentioning: when batching sequences of different lengths, shorter sequences get padded with [PAD] tokens. The padding mask ensures no real token attends to padding tokens — they contain no useful information and shouldn't influence any representation.

See Both in Action

Now that you understand both modes, use the visualizer below. Start in bidirectional mode — notice how every token has arcs going to all other tokens in both directions. Then switch to causal — watch the forward-looking arcs disappear, and the attention matrix becomes triangular:

[Interactive visualizer: click any token in "The cat sat on the mat because it was tired" to see its attention arcs, with a toggle between bidirectional and causal modes. Below the sentence, a full attention matrix is shown (row = query token, column = key token; darker = higher attention weight).]

Note: this uses random Q, K projections for demonstration — the attention patterns here aren't linguistically meaningful. In a trained model, the learned WQ and WK matrices produce patterns where tokens attend to grammatically and semantically relevant tokens (subjects attend to verbs, pronouns attend to their referents, etc).

The Full Transformer: Encoder, Decoder, and Cross-Attention

Now that you understand both bidirectional and causal self-attention, let's zoom out and see where they fit in the original transformer architecture. This context is important — it explains why cross-attention exists and how the field evolved from the 2017 design to what we use today.

The Original 2017 Architecture

The transformer from "Attention Is All You Need" was designed for machine translation (e.g., French → English). It had two halves:

Encoder (left side): Takes the full source sentence (French) and builds a deep representation using bidirectional self-attention. Every French word can attend to every other French word. The encoder's job is to understand the input completely.

Decoder (right side): Generates the target sentence (English) one token at a time using causal self-attention. Each English word can only attend to previously generated English words. The decoder's job is to produce output.

Cross-Attention (the bridge): After the decoder's causal self-attention, there's a second attention layer where the decoder attends to the encoder's output. This is where the two sequences talk to each other.

How Cross-Attention Works

Cross-attention is mechanically identical to self-attention — same Q, K, V, dot product, softmax, weighted sum. The only difference: Q comes from one sequence, while K and V come from a different sequence.

$$Q = X_{\text{decoder}} W_Q \quad \text{(decoder token asks a question)}$$
$$K = X_{\text{encoder}} W_K \quad \text{(encoder tokens provide keys)}$$
$$V = X_{\text{encoder}} W_V \quad \text{(encoder tokens provide values)}$$

The decoder token broadcasts its query: "I'm trying to generate the next English word — which French words should I look at?" The encoder tokens respond with their keys, and the attention mechanism selects the relevant French tokens whose values get passed to the decoder.

Concrete example: translating "Le chat est assis" → "The cat is sitting".

When the decoder is generating "sitting", its query looks for the French verb. Cross-attention produces high weights on "assis" (the French word for sitting) — pulling its meaning into the decoder's representation. The decoder doesn't need to "remember" the French sentence through a bottleneck; it can directly look at any French word at any time.
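Since the formula is unchanged, a sketch differs from self-attention only in its inputs (random weights, hypothetical shapes: 4 encoder tokens for "Le chat est assis", 3 decoder tokens generated so far):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(X_dec, X_enc, W_Q, W_K, W_V):
    Q = X_dec @ W_Q                 # queries from the DECODER
    K = X_enc @ W_K                 # keys from the ENCODER
    V = X_enc @ W_V                 # values from the ENCODER
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return A @ V                    # one output per decoder token

rng = np.random.default_rng(0)
d = 8
X_enc = rng.normal(size=(4, d))     # "Le chat est assis"
X_dec = rng.normal(size=(3, d))     # "The cat is" (generated so far)
W_Q, W_K, W_V = rng.normal(size=(3, d, d))
out = cross_attention(X_dec, X_enc, W_Q, W_K, W_V)
print(out.shape)  # (3, 8): decoder length, not encoder length
```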

Three Types of Attention in One Architecture

The original transformer uses all three:

Attention Type          | Where          | Q source | K, V source | Mode
Encoder self-attention  | Encoder layers | Encoder  | Encoder     | Bidirectional
Decoder self-attention  | Decoder layers | Decoder  | Decoder     | Causal (masked)
Cross-attention         | Decoder layers | Decoder  | Encoder     | Bidirectional over encoder

Each decoder layer has TWO attention sublayers: first causal self-attention (decoder attends to itself), then cross-attention (decoder attends to encoder). This is why encoder-decoder models are more complex: every decoder layer carries twice the attention computation of a decoder-only layer.

The Shift to Decoder-Only

The original transformer was encoder-decoder because it was designed for translation — a task with distinct input and output sequences. But starting with GPT (2018), researchers discovered something surprising:

A decoder alone, prompted with the right text, can do everything.

Instead of an encoder processing the French sentence and a decoder generating English, you just feed the decoder: "Translate French to English: Le chat est assis → " and let it continue generating. The "encoder" functionality is implicit — the causal attention over the prompt serves the same purpose.

This worked for translation, summarization, question answering, classification — every NLP task. One architecture to rule them all. The simplicity was irresistible:

  • Fewer parameters (no separate encoder)
  • One attention type (just causal self-attention)
  • One training objective (predict next token)
  • Scales better (all parameters contribute to one task)

By 2023, virtually every major LLM — GPT-4, Claude, Llama, Mistral — was decoder-only. The encoder-decoder architecture didn't die (T5, Flan-T5 still exist), but it became niche.

Where Cross-Attention Still Lives

Cross-attention didn't disappear — it moved to multimodal and specialized architectures:

  • Vision-language models (LLaVA, Flamingo): An image encoder processes the image, then the text decoder cross-attends to image features. "What's in this image?" — the decoder's query looks at the encoder's visual tokens.
  • Speech models (Whisper): An audio encoder processes the waveform, then a text decoder cross-attends to generate the transcript.
  • Retrieval-augmented generation (RAG): Some architectures use cross-attention to let the generator attend to retrieved document embeddings.
  • Diffusion models (Stable Diffusion): The image generator cross-attends to the text encoder's output to guide image generation from a text prompt.

Anywhere you have two distinct modalities or sequences that need to interact, cross-attention is the mechanism that bridges them.

Multi-Head Attention: Why One Perspective Isn't Enough

A single attention head computes one set of Q, K, V projections and produces one attention pattern. But language has many types of relationships happening simultaneously in the same sentence:

  • Syntactic: subject → verb agreement ("The cats are")
  • Semantic: pronoun → referent ("it" → "animal")
  • Positional: adjacent tokens attending to each other
  • Long-range: a conclusion referencing a premise from paragraphs earlier

One attention pattern can't capture all of these at once. If the single head learns to focus on syntactic relationships, it loses the ability to track semantic ones. Multi-head attention solves this by running multiple attention heads in parallel, each with its own Q, K, V projections, each free to learn different types of relationships.

The Concept: Multiple Parallel Perspectives

Instead of computing one attention pattern over the full $d$-dimensional space, we split the computation into $h$ independent heads, each operating on a smaller $d_k = d/h$ dimensional subspace.

For each head $j \in \{1, ..., h\}$:

$$\text{head}_j = \text{Attention}(Q_j, K_j, V_j)$$

Each head produces its own attention pattern and its own output. All head outputs are concatenated and projected back to the model dimension:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h) \, W_O$$

where $W_O \in \mathbb{R}^{d \times d}$ is a learned output projection that combines the perspectives from all heads.

How It Actually Works: Two Equivalent Views

There are two ways to understand the implementation — they're mathematically identical but one is conceptual and the other is what GPUs actually compute:

View A (conceptual — per-head matrices):

Each head $j$ has its own weight matrices:

$$Q_j = X W_Q^j, \quad K_j = X W_K^j, \quad V_j = X W_V^j$$

where $W_Q^j, W_K^j, W_V^j \in \mathbb{R}^{d \times d_k}$ and $d_k = d/h$.

You do $h$ separate matrix multiplications, each producing smaller Q, K, V matrices. Each head independently computes attention on its subspace.

View B (actual implementation — one big matrix, then split):

In practice, you use ONE big weight matrix $W_Q \in \mathbb{R}^{d \times d}$ and compute the full projection in a single matrix multiply:

$$Q = X W_Q \in \mathbb{R}^{n \times d}$$

Then you reshape this into $h$ heads by splitting the last dimension:

$$Q \in \mathbb{R}^{n \times d} \longrightarrow Q \in \mathbb{R}^{n \times h \times d_k}$$

Same thing for K and V. The math is identical — the big $W_Q$ matrix is effectively $h$ smaller matrices stacked side by side. But doing one big matrix multiply is much faster on GPUs than $h$ small ones.
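The equivalence is easy to check in NumPy (assumed toy shapes, $d = 512$, $h = 8$): the one-big-matrix projection, reshaped, matches a per-head projection with the corresponding slice of the big matrix.

```python
import numpy as np

n, d, h = 6, 512, 8
d_k = d // h                             # 64 dims per head

rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))
W_Q = rng.normal(size=(d, d))            # View B: ONE big matrix

Q = (X @ W_Q).reshape(n, h, d_k)         # one matmul, then split into heads

# View A: the j-th head's matrix is just a (d, d_k) slice of W_Q
W_Q_heads = W_Q.reshape(d, h, d_k)
Q_head_2 = X @ W_Q_heads[:, 2, :]

print(np.allclose(Q[:, 2, :], Q_head_2))  # True: the two views are identical
```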

Why this distinction matters: When we get to Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) in Part 2, the "big matrix, then split" view makes it clear what's happening. In GQA, the Q matrix is still split into 32 heads, but the K and V matrices are split into only 8 groups — each group shared by 4 query heads. This only makes sense if you understand that "separate head parameters" is really "one big matrix, split differently."

Why Split Dimensions Instead of Running Full Attention Multiple Times?

You might ask: why not just run $h$ full-dimensional attention computations? Because that would cost $h\times$ more compute. By splitting the dimensions, multi-head attention uses the same total computation as single-head attention:

For a model with $d = 512$ and $h = 8$ heads:

  • Single-head: one attention computation on 512 dimensions
  • Multi-head: 8 parallel computations on 64 dimensions each
  • Total dimensions processed: $8 \times 64 = 512$ — identical cost

You get $h$ different attention patterns for free (in terms of compute). The only overhead is the output projection $W_O$ that combines the heads.

Do All Heads Need the Same Dimension?

In the standard implementation, yes — all heads have $d_k = d/h$. This is primarily a practical choice: uniform tensor shapes are easy to parallelize on GPUs. Mathematically, nothing prevents heads of different sizes (e.g., some heads with 128 dimensions and others with 32). Some research has explored this, but the gains are marginal and the implementation complexity isn't worth it. Every major LLM uses equal head sizes.

What Different Heads Learn

Research has shown that different heads naturally specialize without being told to:

  • Positional heads: attend primarily to the previous or next token (syntactic structure)
  • Rare word heads: attend strongly to infrequent tokens (they carry more information)
  • Separator heads: attend to punctuation and special tokens (sentence boundaries)
  • Semantic heads: attend to semantically related tokens regardless of distance
  • Duplicate heads: some heads learn nearly identical patterns — this redundancy is one motivation for GQA (reducing K, V heads without losing quality)

This specialization emerges purely from training — it's not programmed. The multi-head architecture provides the capacity for diverse attention patterns, and gradient descent discovers which specializations are useful.

Putting It All Together: The Matrix Flow

The animation below walks through the full multi-head computation step by step. Watch how X gets projected into Q, K, V, how each matrix splits across heads, how each head computes attention with different patterns, and how the outputs get concatenated and projected through $W_O$:

Multi-Head Attention: step-by-step matrix flow from embeddings through parallel heads to final output projection.

Two Fundamental Flaws

Multi-head attention is powerful — but it comes with two costs that become devastating at scale. Understanding both is essential before we can appreciate the solutions in Part 2.

Flaw 1: The Quadratic Wall

Self-attention computes the $n \times n$ score matrix $QK^T$. Both the computation and the storage of this matrix are $O(n^2)$:

Sequence Length  | Score Matrix Size | Memory (fp16) | Compute (GFLOPs)
512              | 262K entries      | 0.5 MB        | 0.03
2,048            | 4.2M entries      | 8 MB          | 0.5
8,192            | 67M entries       | 128 MB        | 8
32,768           | 1.1B entries      | 2 GB          | 134
131,072 (128K)   | 17.2B entries     | 32 GB         | 2,147
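The memory column can be reproduced directly (fp16 = 2 bytes per score; binary megabytes, matching the table's units):

```python
def attn_matrix_mib(n, bytes_per_score=2):
    # one n x n score matrix, per head, per layer
    return n * n * bytes_per_score / 2**20

for n in (512, 2_048, 8_192, 32_768, 131_072):
    print(f"{n:>7} tokens: {attn_matrix_mib(n):>10,.1f} MiB")
```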

At 128K tokens (the context length of GPT-4 and Llama 3), the attention matrix alone takes 32 GB per layer per head in memory. With 32 layers and 32 heads, materializing them all would take over 30 terabytes — obviously impossible.

This is the cost you pay during training and prefill — when the full sequence is processed at once. The quadratic cost was recognized early (it's inherent in the original 2017 design), and it's what motivated sparse attention patterns starting in 2019.

Flaw 2: The KV Cache During Inference

The quadratic wall is a training/prefill problem. But during inference — when the model actually generates text — a different cost dominates.

The Autoregressive Problem

Language models generate text one token at a time. To produce token $t+1$, the model:

  1. Takes all tokens so far: $[x_1, x_2, ..., x_t]$
  2. Runs the full forward pass (attention + feedforward layers)
  3. Produces a probability distribution over the next token
  4. Samples token $t+1$
  5. Repeats with $[x_1, x_2, ..., x_t, x_{t+1}]$

Here's the problem: at step $t$, the attention layer computes:

$$q_t K^T = q_t \, [k_1, k_2, ..., k_t]^T$$

The new token's query must attend to all previous tokens' keys. And the output is a weighted sum of all previous tokens' values. So at every step, you need K and V for the entire history.

The Naive Approach (Wasteful)

The naive implementation recomputes K and V for all $t$ tokens at every generation step. At step 100, you recompute $k_1$ through $k_{100}$. At step 101, you recompute $k_1$ through $k_{101}$ — recalculating all 100 previous keys that haven't changed.

This means step $t$ costs $O(t)$ compute for the projection alone, and the total generation cost for a sequence of length $n$ is $O(n^2)$ — even ignoring the attention computation itself.

The Solution: Cache K and V

The fix is obvious once you see the waste: cache the K and V vectors from previous steps. Token 1's key $k_1$ never changes once computed — it depends only on $x_1$ and the learned weight matrix $W_K$. So compute it once, store it, and reuse it forever.

This is the KV cache:

  • At step 1: compute k_1, v_1, cache them
  • At step 2: compute k_2, v_2, cache them. Attend using [k_1, k_2] from the cache
  • At step t: compute k_t, v_t, append to the cache. Attend using [k_1, ..., k_t] from the cache

Now each generation step computes K and V only for the single new token — O(1) projection cost per step instead of O(t).
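A toy single-head decode loop with a KV cache might look like the following NumPy sketch — W_q, W_k, W_v are random stand-ins for learned projection matrices, and x_t stands in for the new token's embedding:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 16, 8
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

K_cache, V_cache = [], []   # each grows by exactly one row per generated token
outputs = []

for step in range(6):
    x_t = rng.normal(size=d_model)   # stand-in for the new token's embedding
    q_t = x_t @ W_q
    K_cache.append(x_t @ W_k)        # O(1): project only the new token
    V_cache.append(x_t @ W_v)
    K = np.stack(K_cache)            # (t, d_k) — reused, never recomputed
    V = np.stack(V_cache)
    scores = K @ q_t / np.sqrt(d_k)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    outputs.append(w @ V)
```

The cache rows written at earlier steps are never touched again except to read them — that is the entire trick, and also the source of the memory cost discussed next.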

The New Problem: Memory

The KV cache eliminates redundant computation but introduces a memory problem. You're storing K and V for every token, in every layer, for every head:

\text{KV cache size} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times t \times d_k \times \text{bytes per value}

Let's compute this for a realistic model (Llama 3 70B):

  • 80 layers, 64 heads, d_k = 128, fp16 (2 bytes per value). We'll treat all 64 heads as KV heads — vanilla MHA; the actual Llama 3 70B uses GQA with 8 KV heads, which cuts these numbers by 8x
  • At sequence length 8K: 2 × 80 × 64 × 8192 × 128 × 2 ≈ 21.5 GB
  • At sequence length 32K: ≈ 85.9 GB
  • At sequence length 128K: ≈ 343.6 GB

That's the cache for a single sequence. Serve 8 users in parallel and you need about 2.7 TB of memory just for KV caches at 128K context — far more than the roughly 140 GB the fp16 model weights themselves occupy.
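The formula above is easy to script as a sanity check. The helper name below is made up for illustration; it assumes fp16 (2 bytes per value) by default:

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_k, bytes_per_value=2):
    """Per-sequence KV cache size in bytes.

    The leading 2 accounts for storing both K and V.
    """
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_value

# Vanilla-MHA numbers for an 80-layer, 64-head, d_k=128 model in fp16:
# 8K -> ~21.5 GB, 32K -> ~85.9 GB, 128K -> ~343.6 GB
sizes_gb = [kv_cache_bytes(80, 64, t, 128) / 1e9 for t in (8_192, 32_768, 131_072)]
```

Swapping n_kv_heads from 64 down to 8 (GQA, as in the real Llama 3 70B) shrinks every figure by 8x — a preview of why the architectural variants in the next part matter so much.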

Why This Matters

The KV cache is the dominant memory cost during LLM serving. It determines:

  • Maximum batch size: more memory for caches = fewer concurrent users
  • Maximum context length: longer sequences = larger caches
  • Hardware requirements: KV cache often dictates how many GPUs you need, not model weights

This is why reducing the KV cache — through fewer heads (MQA/GQA), compressed representations (MLA), or bounded attention (sliding window) — became the central challenge for efficient LLM deployment. The KV cache problem wasn't felt until models were large enough and sequences long enough for serving to become the bottleneck — which is why solutions like MQA (2019) and GQA (2023) came years after the original transformer.

Why Not Just Remove Attention?

You might wonder: if attention is so expensive, why not replace it with something cheaper? People have tried — linear attention, state space models (Mamba), and other alternatives. But nothing has matched the quality of softmax attention for language modeling. The quadratic cost is the price of a mechanism that genuinely allows any token to interact with any other token. The field's response hasn't been to remove attention, but to make it cheaper.

And "cheaper" comes in two fundamentally different flavors:


What's Next: Two Paths to Efficient Attention

We've now seen how attention works — and why it's expensive. The O(n^2) score matrix and the ever-growing KV cache are real walls that every production model must solve. The solutions split into two distinct categories:

Path 1: Architectural Innovations — Change What Gets Computed

These approaches redesign the attention mechanism itself to reduce memory and compute:

  • Multi-Query & Grouped-Query Attention (MQA/GQA) — share K, V heads across query heads to shrink the KV cache by 8-32x
  • Sliding Window & Sparse Attention — limit which token pairs can interact, breaking the quadratic cost for long sequences
  • Multi-Latent Attention (MLA) — compress KV into low-rank latent vectors (DeepSeek's radical approach)
  • Differential Attention & Native Sparse Attention — the newest (2024-2025) innovations that learn what to attend to and what to ignore

These change the math. A model using GQA computes a fundamentally different operation than vanilla multi-head attention — it just happens to approximate the same result with far less memory.

Path 2: Systems-Level Optimizations — Change How It Runs on Hardware

These keep the attention math identical but exploit GPU memory hierarchy and parallelism:

  • Flash Attention — reorders computation to minimize HBM reads/writes (same result, 2-4x faster)
  • Paged Attention (vLLM) — virtual memory for KV cache, eliminating fragmentation during serving
  • KV Cache Quantization — store cached keys/values in int8/int4 instead of fp16
  • Operator Fusion & Kernel Optimization — fuse softmax, masking, and dropout into a single GPU kernel

These don't change what's computed — they change where data lives and when it moves between SRAM and HBM.

The Road Ahead

The next two posts in this series tackle each path:

Part 2 — Architectural Attention Variants will trace the chronological evolution from MQA (2019) through GQA and sliding window attention, to the cutting-edge MLA and differential attention mechanisms used in today's frontier models. We'll see what Llama, Mistral, DeepSeek, and Gemma actually chose — and why.

Part 3 — GPU-Level Attention Optimization will open the hood on Flash Attention, paged KV caches, and the memory hierarchy tricks that made million-token contexts practical without changing a single weight matrix.

Together, these two posts complete the picture: architecture decides what to compute, and systems engineering decides how fast it runs.


This post is part of The Gradient Descent through Transformers — a series dissecting every component of the modern transformer stack.