---
title: "Module 05: Attention"
format:
html:
code-fold: false
toc: true
ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
## Introduction
**Attention** is the mechanism that made transformers revolutionary: it allows each token to "look at" every other token and gather relevant information.
Instead of processing tokens in isolation, attention lets each token ask: "What other tokens in this sequence are relevant to me?"
Why attention matters for LLMs:
- **Long-range dependencies**: Token at position 100 can directly attend to token at position 1 (RNNs struggle with this due to vanishing gradients)
- **Parallelization**: Unlike RNNs, all positions can be computed simultaneously during training
- **Interpretability**: Attention weights show what the model is "looking at"
- **Dynamic context**: Each token's representation is context-dependent, not fixed
The key innovation: **self-attention** - tokens attend to other tokens in the same sequence. This is distinct from cross-attention (used in encoder-decoder models) where queries come from one sequence and keys/values from another.
### What You'll Learn
By the end of this module, you will be able to:
- Understand Query, Key, Value projections and their roles
- Implement scaled dot-product attention from scratch
- Apply causal masking for autoregressive models
- Build multi-head attention and understand why it's beneficial
- Recognize attention patterns and what they reveal
**Note**: Attention itself is position-agnostic — it treats tokens as an unordered set. We rely on positional embeddings (Module 04) to give the model a sense of token order.
## Attention as Three Questions
Every token in a sequence asks three questions. Understanding these questions is the key to understanding attention.
```{python}
import numpy as np
# A simple sentence
tokens = ["The", "cat", "sat"]
# Each token has an embedding (we'll use random ones for illustration)
np.random.seed(42)
embed_dim = 4
embeddings = {tok: np.random.randn(embed_dim).round(2) for tok in tokens}
print("Each token has an embedding vector:")
for tok, emb in embeddings.items():
print(f" '{tok}': {emb}")
```
**The Three Questions:**
| Vector | Question | Role |
|--------|----------|------|
| **Query (Q)** | "What am I looking for?" | Token seeks relevant context |
| **Key (K)** | "What do I contain?" | Token advertises its content |
| **Value (V)** | "What do I return if matched?" | Token's actual information |
```{python}
# Each token projects its embedding into Q, K, V
# These are learned linear transformations
# For "sat", the query might encode: "I need a subject (who sat?)"
# For "cat", the key might encode: "I'm a noun, a subject candidate"
# For "cat", the value carries: the actual semantic content of "cat"
print("When 'sat' attends to 'cat':")
print(" Q_sat . K_cat = high score (sat is looking for a subject, cat is one)")
print(" The output for 'sat' includes V_cat weighted by this score")
```
The magic: Q, K, V are **learned projections**. The model learns what to look for (Q), how to advertise content (K), and what information to pass along (V).
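To make that concrete, here is a minimal sketch of the three projections applied to the toy embeddings above. In a real model `W_q`, `W_k`, `W_v` are learned during training; here they are random placeholders just to show the shapes.

```{python}
# Minimal sketch: Q, K, V all come from the SAME embeddings, via three
# different projection matrices. Random weights stand in for learned ones.
x = np.stack([embeddings[tok] for tok in tokens])   # (seq_len, embed_dim)
W_q = np.random.randn(embed_dim, embed_dim) * 0.5
W_k = np.random.randn(embed_dim, embed_dim) * 0.5
W_v = np.random.randn(embed_dim, embed_dim) * 0.5

Q, K, V = x @ W_q, x @ W_k, x @ W_v
print("x:", x.shape, "-> Q:", Q.shape, "K:", K.shape, "V:", V.shape)
```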
## Intuition: Query, Key, Value
Think of attention as a "soft lookup" - like a database query, but differentiable:
```
Query: "What information do I need?" (the question)
Keys: "What information do I have?" (index/labels for content)
Values: "Here's my actual information" (the content itself)
Attention = softmax(Query . Keys) x Values
```
**Analogy**: Imagine a library where:
- Your **query** is "books about cats"
- Each book has a **key** (its topic/keywords)
- Each book has a **value** (its actual content)
- You get a weighted average of book contents based on how well they match your query
For the sentence "The cat sat on the mat":
- "sat" might attend strongly to "cat" (who sat?) and weakly to "mat" (where?)
- "mat" might attend strongly to "the" and "on" (which mat? on what?)
```{ojs}
//| echo: false
// Attention weights data - each token's attention distribution over all tokens
// Rows = query (who is looking), Columns = key (what they look at)
attentionData = {
const tokens = ["The", "cat", "sat", "on", "the", "mat"];
// Realistic attention patterns for each token (rows sum to ~1)
const weights = [
[0.45, 0.15, 0.10, 0.05, 0.20, 0.05], // "The" - attends to self and other "the"
[0.15, 0.50, 0.20, 0.05, 0.05, 0.05], // "cat" - attends strongly to self
[0.10, 0.60, 0.15, 0.05, 0.02, 0.08], // "sat" - attends to "cat" (who sat?)
[0.05, 0.10, 0.25, 0.35, 0.05, 0.20], // "on" - attends to self and context
[0.35, 0.10, 0.05, 0.10, 0.30, 0.10], // "the" - attends to other "The"
[0.05, 0.15, 0.30, 0.25, 0.05, 0.20], // "mat" - attends to "sat", "on"
];
return { tokens, weights };
}
// Selected token state
viewof focusToken = Inputs.select(attentionData.tokens, {
label: "Select token to analyze",
value: "sat"
})
```
```{ojs}
//| echo: false
// Interactive attention flow visualization
{
const { tokens, weights } = attentionData;
const focusIdx = tokens.indexOf(focusToken);
const focusWeights = weights[focusIdx];
const width = 650;
const height = 280;
const tokenY = 80;
const tokenSpacing = 95;
const tokenStartX = 60;
const tokenRadius = 32;
const arrowStartY = tokenY + tokenRadius + 15;
const arrowEndY = height - 55;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Defs for gradients and markers
const defs = svg.append("defs");
// Glow filter for focused token
const glowFilter = defs.append("filter")
.attr("id", "attn-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(`"${focusToken}" attends to...`);
// Subtitle
svg.append("text")
.attr("x", width / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.6)
.text("Click any token to change focus");
// Create arrow markers with varying opacity based on attention weight
tokens.forEach((_, i) => {
const weight = focusWeights[i];
const opacity = 0.3 + weight * 0.7;
defs.append("marker")
.attr("id", `attn-arrow-${i}`)
.attr("viewBox", "0 -4 8 8")
.attr("refX", 6)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-4L8,0L0,4Z")
.attr("fill", diagramTheme.highlight)
.attr("opacity", opacity);
});
// Draw attention arrows from focus token to all tokens
const focusX = tokenStartX + focusIdx * tokenSpacing;
tokens.forEach((token, i) => {
const weight = focusWeights[i];
const targetX = tokenStartX + i * tokenSpacing;
// Calculate stroke width based on attention weight (1 to 6 pixels)
const strokeWidth = 1 + weight * 5;
const opacity = 0.2 + weight * 0.8;
// Calculate curve control point
const midX = (focusX + targetX) / 2;
const curveOffset = Math.abs(focusX - targetX) * 0.3;
const controlY = arrowStartY + 40 + curveOffset * 0.5;
// Arrow group with animation
const arrowGroup = svg.append("g")
.attr("class", "attention-arrow")
.style("opacity", 0);
// Draw curved arrow
if (Math.abs(focusIdx - i) <= 1) {
// Straight or slight curve for adjacent tokens
arrowGroup.append("path")
.attr("d", `M${focusX},${arrowStartY} Q${midX},${controlY} ${targetX},${arrowEndY}`)
.attr("fill", "none")
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", strokeWidth)
.attr("stroke-opacity", opacity)
.attr("marker-end", `url(#attn-arrow-${i})`);
} else {
// More pronounced curve for distant tokens
arrowGroup.append("path")
.attr("d", `M${focusX},${arrowStartY} Q${midX},${controlY + 20} ${targetX},${arrowEndY}`)
.attr("fill", "none")
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", strokeWidth)
.attr("stroke-opacity", opacity)
.attr("marker-end", `url(#attn-arrow-${i})`);
}
// Animate arrow appearance with stagger
arrowGroup
.transition()
.delay(i * 80)
.duration(400)
.style("opacity", 1);
});
// Draw tokens as clickable circles
tokens.forEach((token, i) => {
const x = tokenStartX + i * tokenSpacing;
const isFocus = i === focusIdx;
const weight = focusWeights[i];
const tokenGroup = svg.append("g")
.attr("transform", `translate(${x}, ${tokenY})`)
.style("cursor", "pointer")
.on("click", () => {
// Update the select input programmatically
const select = document.querySelector('select[name="Select token to analyze"]');
if (select) {
select.value = token;
select.dispatchEvent(new Event('input', { bubbles: true }));
}
});
// Token circle - highlight if focused
tokenGroup.append("circle")
.attr("r", tokenRadius)
.attr("fill", isFocus ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", isFocus ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", isFocus ? 3 : 2)
.attr("filter", isFocus ? "url(#attn-glow)" : null)
.style("transition", "all 0.3s ease");
// Token label
tokenGroup.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isFocus ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", isFocus ? "700" : "500")
.text(token);
// Attention weight label below each token (in the "target" area)
const weightGroup = svg.append("g")
.attr("transform", `translate(${x}, ${arrowEndY + 25})`)
.style("opacity", 0);
// Weight background pill
const weightText = (weight * 100).toFixed(0) + "%";
const pillWidth = 42;
const pillHeight = 22;
weightGroup.append("rect")
.attr("x", -pillWidth / 2)
.attr("y", -pillHeight / 2)
.attr("width", pillWidth)
.attr("height", pillHeight)
.attr("rx", pillHeight / 2)
.attr("fill", diagramTheme.highlight)
.attr("opacity", 0.15 + weight * 0.6);
weightGroup.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(weightText);
// Animate weight labels
weightGroup
.transition()
.delay(300 + i * 60)
.duration(300)
.style("opacity", 1);
});
return svg.node();
}
```
Each row of attention weights sums to 1 (softmax normalization).
## The Math: Scaled Dot-Product Attention
The attention formula:
```
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V
```
Where:
- **Q** (Query): What am I looking for? Shape: (seq, d_k)
- **K** (Key): What do I have to offer? Shape: (seq, d_k)
- **V** (Value): What information do I carry? Shape: (seq, d_v)
- **d_k**: Dimension of keys (for scaling)
### Step by Step
```{ojs}
//| echo: false
// Step slider control
viewof attentionStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
// Step descriptions for the current step
stepDescriptions = [
{ title: "Overview", desc: "The four steps of scaled dot-product attention" },
{ title: "Step 1: Compute Similarity", desc: "Multiply Q by K transposed to get raw attention scores" },
{ title: "Step 2: Scale", desc: "Divide by sqrt(d_k) to prevent softmax saturation" },
{ title: "Step 3: Softmax", desc: "Normalize each row to get attention weights (sum to 1)" },
{ title: "Step 4: Apply to Values", desc: "Multiply weights by V to get weighted output" }
]
// Scaled dot-product attention visualization
scaledDotProductDiagram = {
const width = 720;
const height = 520;
const step = attentionStep;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("viewBox", `0 0 ${width} ${height}`)
.attr("width", "100%")
.attr("height", height)
.style("max-width", `${width}px`)
.style("font-family", "'IBM Plex Mono', 'JetBrains Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Define gradients and filters
const defs = svg.append("defs");
// Glow filter for active elements
const glowFilter = defs.append("filter")
.attr("id", "sdpa-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "3")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Color palette for matrices
const colors = {
Q: "#22d3ee", // cyan
K: "#a78bfa", // purple
V: "#4ade80", // green
QKT: "#f472b6", // pink
scaled: "#fbbf24", // amber
weights: "#fb923c", // orange
output: "#34d399" // emerald
};
// Matrix dimensions for visualization
const matrixH = 48;
const matrixW = 36;
const cellSize = 12;
// Helper to draw a matrix block
function drawMatrix(g, x, y, rows, cols, color, label, sublabel, isActive, opacity = 1) {
const w = cols * cellSize;
const h = rows * cellSize;
const group = g.append("g")
.attr("transform", `translate(${x}, ${y})`)
.style("opacity", opacity);
// Matrix background
group.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("fill", color)
.attr("fill-opacity", isActive ? 0.3 : 0.15)
.attr("stroke", color)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("rx", 4)
.attr("filter", isActive ? "url(#sdpa-glow)" : null);
// Grid lines
for (let i = 1; i < rows; i++) {
group.append("line")
.attr("x1", -w/2)
.attr("y1", -h/2 + i * cellSize)
.attr("x2", w/2)
.attr("y2", -h/2 + i * cellSize)
.attr("stroke", color)
.attr("stroke-opacity", 0.3)
.attr("stroke-width", 0.5);
}
for (let j = 1; j < cols; j++) {
group.append("line")
.attr("x1", -w/2 + j * cellSize)
.attr("y1", -h/2)
.attr("x2", -w/2 + j * cellSize)
.attr("y2", h/2)
.attr("stroke", color)
.attr("stroke-opacity", 0.3)
.attr("stroke-width", 0.5);
}
// Label above
group.append("text")
.attr("x", 0)
.attr("y", -h/2 - 10)
.attr("text-anchor", "middle")
.attr("fill", isActive ? color : theme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", isActive ? "700" : "500")
.text(label);
// Shape label below
if (sublabel) {
group.append("text")
.attr("x", 0)
.attr("y", h/2 + 16)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.6)
.text(sublabel);
}
return group;
}
// Helper to draw an arrow
function drawArrow(g, x1, y1, x2, y2, color, isActive, label = null) {
const markerId = `sdpa-arrow-${Math.random().toString(36).substr(2, 6)}`;
defs.append("marker")
.attr("id", markerId)
.attr("viewBox", "0 -4 8 8")
.attr("refX", 6)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-4L8,0L0,4Z")
.attr("fill", isActive ? color : theme.edgeStroke);
const group = g.append("g");
group.append("line")
.attr("x1", x1)
.attr("y1", y1)
.attr("x2", x2)
.attr("y2", y2)
.attr("stroke", isActive ? color : theme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", `url(#${markerId})`)
.attr("filter", isActive ? "url(#sdpa-glow)" : null);
if (label) {
const midX = (x1 + x2) / 2;
const midY = (y1 + y2) / 2;
group.append("text")
.attr("x", midX)
.attr("y", midY - 8)
.attr("text-anchor", "middle")
.attr("fill", isActive ? color : theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", isActive ? "600" : "400")
.text(label);
}
return group;
}
// Helper to draw operation box
function drawOp(g, x, y, text, isActive, color = theme.highlight) {
const group = g.append("g")
.attr("transform", `translate(${x}, ${y})`);
const padding = 8;
const textEl = group.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? color : theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", isActive ? "700" : "500")
.text(text);
const bbox = textEl.node().getBBox();
group.insert("rect", "text")
.attr("x", -bbox.width/2 - padding)
.attr("y", -bbox.height/2 - padding/2)
.attr("width", bbox.width + padding * 2)
.attr("height", bbox.height + padding)
.attr("rx", 4)
.attr("fill", isActive ? color : theme.nodeFill)
.attr("fill-opacity", isActive ? 0.2 : 1)
.attr("stroke", isActive ? color : theme.nodeStroke)
.attr("stroke-width", isActive ? 2 : 1)
.attr("filter", isActive ? "url(#sdpa-glow)" : null);
return group;
}
// Title and description
const stepInfo = stepDescriptions[step];
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "15px")
.attr("font-weight", "700")
.text(stepInfo.title);
svg.append("text")
.attr("x", width / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("opacity", 0.7)
.text(stepInfo.desc);
// Main content group
const content = svg.append("g").attr("transform", "translate(0, 60)");
// Layout positions
const row1Y = 50;
const row2Y = 160;
const row3Y = 270;
const row4Y = 380;
// --- STEP 1: Q, K -> QK^T ---
const step1Active = step === 0 || step === 1;
// Q matrix
drawMatrix(content, 140, row1Y, 4, 3, colors.Q, "Q", "seq x d_k", step === 1, step1Active ? 1 : 0.4);
// K matrix
drawMatrix(content, 280, row1Y, 4, 3, colors.K, "K", "seq x d_k", step === 1, step1Active ? 1 : 0.4);
// K^T indicator
if (step1Active) {
content.append("text")
.attr("x", 280)
.attr("y", row1Y + 50)
.attr("text-anchor", "middle")
.attr("fill", step === 1 ? colors.K : theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", step === 1 ? 1 : 0.5)
.text("transpose");
}
// V matrix (shown from start but activated in step 4)
const step4Active = step === 0 || step === 4;
drawMatrix(content, 580, row1Y, 4, 3, colors.V, "V", "seq x d_v", step === 4, step4Active ? 1 : 0.4);
// Matmul operation for Q @ K^T
drawOp(content, 210, row2Y - 20, "@ (matmul)", step === 1, colors.QKT);
// Arrows from Q and K to matmul
if (step1Active) {
drawArrow(content, 140, row1Y + 35, 190, row2Y - 40, colors.Q, step === 1);
drawArrow(content, 280, row1Y + 35, 230, row2Y - 40, colors.K, step === 1);
}
// QK^T result
const step2Active = step === 0 || step === 2;
drawMatrix(content, 210, row2Y + 40, 4, 4, colors.QKT, "QK^T", "seq x seq", step === 1 || step === 2, step1Active || step2Active ? 1 : 0.4);
// --- STEP 2: Scale ---
// Scale operation
drawOp(content, 210, row3Y - 20, "/ sqrt(d_k)", step === 2, colors.scaled);
// Arrow from QK^T to scale
if (step2Active) {
drawArrow(content, 210, row2Y + 75, 210, row3Y - 40, colors.QKT, step === 2);
}
// Scaled scores
const step3Active = step === 0 || step === 3;
drawMatrix(content, 210, row3Y + 40, 4, 4, colors.scaled, "Scaled", "seq x seq", step === 2 || step === 3, step2Active || step3Active ? 1 : 0.4);
// Annotation for why we scale
if (step === 2) {
content.append("text")
.attr("x", 330)
.attr("y", row3Y - 20)
.attr("fill", colors.scaled)
.attr("font-size", "10px")
.text("prevents gradient vanishing");
}
// --- STEP 3: Softmax ---
drawOp(content, 210, row4Y - 20, "softmax(dim=-1)", step === 3, colors.weights);
// Arrow from scaled to softmax
if (step3Active) {
drawArrow(content, 210, row3Y + 75, 210, row4Y - 40, colors.scaled, step === 3);
}
// Attention weights
drawMatrix(content, 210, row4Y + 40, 4, 4, colors.weights, "Weights", "each row sums to 1", step === 3 || step === 4, step3Active || step4Active ? 1 : 0.4);
// --- STEP 4: Apply to Values ---
// Matmul with V
drawOp(content, 400, row4Y + 40, "@ (matmul)", step === 4, colors.output);
// Arrows for final matmul
if (step4Active) {
drawArrow(content, 250, row4Y + 40, 365, row4Y + 40, colors.weights, step === 4);
drawArrow(content, 580, row1Y + 35, 580, row4Y - 10, colors.V, step === 4);
drawArrow(content, 580, row4Y + 10, 435, row4Y + 40, colors.V, step === 4);
}
// Output
drawMatrix(content, 560, row4Y + 40, 4, 3, colors.output, "Output", "seq x d_v", step === 4, step4Active ? 1 : 0.4);
// Arrow from matmul to output
if (step4Active) {
drawArrow(content, 435, row4Y + 40, 520, row4Y + 40, colors.output, step === 4);
}
// Formula at bottom
const formulaY = height - 25;
const formulaText = "Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V";
svg.append("text")
.attr("x", width / 2)
.attr("y", formulaY)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("opacity", 0.8)
.text(formulaText);
// Highlight the active part of formula
if (step > 0) {
const highlights = [
null,
{ text: "QK^T", color: colors.QKT },
{ text: "/ sqrt(d_k)", color: colors.scaled },
{ text: "softmax(...)", color: colors.weights },
{ text: "x V", color: colors.output }
];
const h = highlights[step];
if (h) {
svg.append("text")
.attr("x", width / 2)
.attr("y", formulaY + 18)
.attr("text-anchor", "middle")
.attr("fill", h.color)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`Current: ${h.text}`);
}
}
return svg.node();
}
```
### Building Attention by Hand
Before using PyTorch, let's build attention from scratch with NumPy to see exactly what happens at each step.
```{python}
import numpy as np
def attention_from_scratch(x, W_q, W_k, W_v):
"""
Single-head attention implemented with pure NumPy.
Args:
x: Input embeddings (seq_len, embed_dim)
W_q, W_k, W_v: Projection matrices (embed_dim, head_dim)
Returns:
output: Attended values (seq_len, head_dim)
weights: Attention weights (seq_len, seq_len)
"""
# Step 1: Project input into Q, K, V
Q = x @ W_q # (seq, head_dim) - "What am I looking for?"
K = x @ W_k # (seq, head_dim) - "What do I contain?"
V = x @ W_v # (seq, head_dim) - "What do I return?"
print(f"Input x shape: {x.shape}")
print(f"Q = x @ W_q: {Q.shape}")
print(f"K = x @ W_k: {K.shape}")
print(f"V = x @ W_v: {V.shape}")
# Step 2: Compute attention scores
# Each query attends to all keys: Q @ K.T
d_k = K.shape[-1]
scores = Q @ K.T # (seq, seq) - similarity between every pair
scores = scores / np.sqrt(d_k) # Scale to prevent softmax saturation
print(f"\nScores = Q @ K.T / sqrt({d_k}): {scores.shape}")
print(f"Score matrix (who attends to whom):")
print(scores.round(2))
# Step 3: Softmax to get attention weights
# Each row sums to 1: how much each position attends to others
def softmax(x):
exp_x = np.exp(x - x.max(axis=-1, keepdims=True)) # Numerical stability
return exp_x / exp_x.sum(axis=-1, keepdims=True)
weights = softmax(scores)
print(f"\nAttention weights (each row sums to 1):")
print(weights.round(3))
print(f"Row sums: {weights.sum(axis=-1).round(3)}")
# Step 4: Weighted sum of values
output = weights @ V # (seq, head_dim)
print(f"\nOutput = weights @ V: {output.shape}")
return output, weights
# Demo with a tiny example
np.random.seed(42)
seq_len, embed_dim, head_dim = 3, 4, 2
x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5
print("=" * 50)
print("ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_from_scratch(x, W_q, W_k, W_v)
```
**Key Insight**: At its core, the attention mechanism is just four matrix multiplications plus a softmax:
1. `Q = x @ W_q` - Project to queries
2. `K = x @ W_k` - Project to keys
3. `scores = Q @ K.T / sqrt(d)` - Compute similarities
4. `output = softmax(scores) @ V` - Weighted sum of values
### Why Scale by sqrt(d_k)?
Without scaling, large d_k leads to large dot products, which pushes softmax into regions with tiny gradients (saturation). Here's the mathematical intuition:
**The Problem**: When Q and K have elements drawn independently from a distribution with mean 0 and variance 1, their dot product has variance d_k (it is a sum of d_k unit-variance terms). For d_k = 64 the standard deviation is 8, so dot products like 8 or -10 are routine.
**Why This Matters**: Softmax of large values produces near-one-hot distributions:
- `softmax([10, 0, 0])` = `[0.9999, 0.00005, 0.00005]`
This causes:
1. **Vanishing gradients**: The gradient of softmax approaches 0 at extremes
2. **Loss of information**: We want soft attention, not hard selection
**The Solution**: Dividing by sqrt(d_k) normalizes variance back to ~1.
```{python}
import torch
import torch.nn.functional as F
import math
# Show the effect of scaling
d_k = 64
q = torch.randn(1, d_k)
k = torch.randn(1, d_k)
dot_product = (q @ k.T).item()
scaled = dot_product / math.sqrt(d_k)
print(f"d_k = {d_k}")
print(f"Raw dot product: {dot_product:.2f}")
print(f"Scaled by sqrt({d_k}) = {math.sqrt(d_k):.1f}: {scaled:.2f}")
print(f"\nScaling keeps values in a reasonable range for softmax")
# Demonstrate the gradient problem
scores_large = torch.tensor([[10.0, 1.0, 1.0]], requires_grad=True)
scores_normal = torch.tensor([[1.0, 0.5, 0.5]], requires_grad=True)
weights_large = F.softmax(scores_large, dim=-1)
weights_normal = F.softmax(scores_normal, dim=-1)
print(f"\nLarge scores [10, 1, 1] -> softmax: {weights_large.detach().numpy().round(4)}")
print(f"Normal scores [1, 0.5, 0.5] -> softmax: {weights_normal.detach().numpy().round(3)}")
```
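Two quick numerical checks of these claims, using random tensors for illustration: the variance of raw dot products grows with d_k and scaling restores it to roughly 1, and softmax gradients shrink dramatically once scores are large.

```{python}
# (1) Dot-product variance grows with d_k; dividing by sqrt(d_k) restores ~1.
d_k = 64
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)
dots = (q * k).sum(dim=-1)
print(f"Variance of raw dot products: {dots.var().item():6.2f}  (~ d_k = {d_k})")
print(f"Variance after / sqrt(d_k):   {(dots / math.sqrt(d_k)).var().item():6.2f}  (~ 1)")

# (2) Softmax gradients shrink when scores are large (saturation).
for raw in ([10.0, 1.0, 1.0], [1.0, 0.5, 0.5]):
    s = torch.tensor(raw, requires_grad=True)
    F.softmax(s, dim=-1)[0].backward()   # gradient of the largest probability
    print(f"scores={raw} -> max |grad| = {s.grad.abs().max().item():.5f}")
```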
### Numerical Stability in Softmax
There is a hidden danger in the naive softmax implementation: **overflow**.
```{python}
import numpy as np
# The naive softmax
def naive_softmax(x):
"""This will overflow for large values!"""
exp_x = np.exp(x)
return exp_x / exp_x.sum()
# Try it with large values
large_scores = np.array([1000.0, 1001.0, 1002.0])
print("Large scores:", large_scores)
print("exp(1000) =", np.exp(1000)) # This is inf!
try:
result = naive_softmax(large_scores)
print("Naive softmax:", result) # Will be [nan, nan, nan]
except:
print("Overflow error!")
```
**The Problem**: `exp(1000)` is astronomically large - it overflows to infinity. Even `exp(100)` is about 2.7 x 10^43.
**The Solution**: The max-subtraction trick. Subtract the maximum value before exponentiating.
```{python}
def stable_softmax(x):
"""
Numerically stable softmax using the max-subtraction trick.
Key insight: softmax(x) = softmax(x - c) for any constant c
We choose c = max(x) to keep values small.
"""
# Subtract max for numerical stability
x_shifted = x - x.max()
print(f"Original: {x}")
print(f"After subtracting max ({x.max()}): {x_shifted}")
print(f"Now exp() won't overflow: exp({x_shifted}) = {np.exp(x_shifted)}")
exp_x = np.exp(x_shifted)
return exp_x / exp_x.sum()
print("Stable softmax with max-subtraction trick:")
print("=" * 50)
large_scores = np.array([1000.0, 1001.0, 1002.0])
result = stable_softmax(large_scores)
print(f"\nResult: {result}")
print(f"Sum: {result.sum()}") # Should be 1.0
```
**Why does this work mathematically?**
```
softmax(x)_i = exp(x_i) / sum(exp(x_j))
= exp(x_i - c) / sum(exp(x_j - c)) [multiply by exp(-c)/exp(-c)]
For c = max(x), all exponents are <= 0, so exp() stays bounded.
```
```{python}
# Verify the math: both give same result
normal_scores = np.array([2.0, 1.0, 0.1])
naive_result = naive_softmax(normal_scores)
stable_result = stable_softmax(normal_scores)
print(f"\nNaive: {naive_result}")
print(f"Stable: {stable_result}")
print(f"Same? {np.allclose(naive_result, stable_result)}")
```
::: {.callout-warning}
## Always Use Stable Softmax
PyTorch's `F.softmax` automatically uses the max-subtraction trick. Never implement naive softmax in production code - it will fail silently with NaN values when scores get large.
:::
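As a quick confirmation, the same scores that broke `naive_softmax` pose no problem for `F.softmax`:

```{python}
import torch
import torch.nn.functional as F

# F.softmax applies the max-subtraction trick internally, so the scores that
# overflowed the naive NumPy version above work fine here.
print(F.softmax(torch.tensor([1000.0, 1001.0, 1002.0]), dim=-1))
```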
## Code: Scaled Dot-Product Attention
Let's implement attention step by step. This follows the exact algorithm from the `attention.py` module:
```{python}
def scaled_dot_product_attention(query, key, value, mask=None):
"""
Compute scaled dot-product attention.
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V
Args:
query: (batch, seq, d_k) or (..., seq, d_k)
key: (batch, seq, d_k)
value: (batch, seq, d_v) # d_v can differ from d_k
mask: Optional mask where 0 = masked, 1 = attend
Returns:
output: (batch, seq, d_v)
attention_weights: (batch, seq, seq)
Note: The mask convention (0 = masked) matches the module implementation.
Masked positions get -inf before softmax, becoming 0 after.
"""
d_k = query.size(-1)
# Step 1: Compute similarity scores
# QK^T: (..., seq, d_k) @ (..., d_k, seq) -> (..., seq, seq)
scores = torch.matmul(query, key.transpose(-2, -1))
# Step 2: Scale by sqrt(d_k)
scores = scores / math.sqrt(d_k)
# Step 3: Apply mask (if provided)
# Masked positions get -inf, which becomes 0 after softmax
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 4: Softmax (each row sums to 1)
attention_weights = F.softmax(scores, dim=-1)
# Step 5: Weighted sum of values
output = torch.matmul(attention_weights, value)
return output, attention_weights
# Test it
batch, seq, d_k = 1, 4, 8
Q = torch.randn(batch, seq, d_k)
K = torch.randn(batch, seq, d_k)
V = torch.randn(batch, seq, d_k)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Query shape: {Q.shape}")
print(f"Key shape: {K.shape}")
print(f"Value shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(weights[0].round(decimals=2).numpy())
print(f"\nRow sums: {weights[0].sum(dim=-1).numpy()}")
```
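This also makes the note from the introduction easy to verify: attention by itself has no notion of order. Permuting the input sequence just permutes the output the same way; here is a quick illustrative check using the function above and the Q, K, V tensors we just created.

```{python}
# Attention is permutation-equivariant: shuffling the sequence positions of
# Q, K, V shuffles the output rows identically. Order only enters the model
# through positional embeddings (Module 04).
perm = torch.tensor([2, 0, 3, 1])
out, _ = scaled_dot_product_attention(Q, K, V)
out_perm, _ = scaled_dot_product_attention(Q[:, perm], K[:, perm], V[:, perm])
print("Permuted input -> permuted output:",
      torch.allclose(out[:, perm], out_perm, atol=1e-6))
```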
## Visualizing Attention Patterns
```{python}
import matplotlib.pyplot as plt
def visualize_attention(weights, tokens=None, title="Attention Pattern"):
"""Visualize attention weights as a heatmap."""
if weights.dim() == 3:
weights = weights[0] # Remove batch dim
weights = weights.detach().numpy()
seq_len = weights.shape[0]
plt.figure(figsize=(8, 6))
plt.imshow(weights, cmap='Blues', vmin=0, vmax=weights.max())
plt.colorbar(label='Attention Weight')
if tokens:
plt.xticks(range(seq_len), tokens, rotation=45, ha='right')
plt.yticks(range(seq_len), tokens)
else:
plt.xlabel('Key Position (what we look at)')
plt.ylabel('Query Position (who is looking)')
# Add values in cells
for i in range(seq_len):
for j in range(seq_len):
plt.text(j, i, f'{weights[i, j]:.2f}', ha='center', va='center', fontsize=10)
plt.title(title)
plt.tight_layout()
plt.show()
# Visualize the attention pattern from above
visualize_attention(weights, title="Random Attention Pattern")
```
## Causal Masking for Language Models
In autoregressive models (like GPT), tokens can only attend to **previous** tokens, not future ones. We enforce this with a causal mask:
::: {.callout-note}
## Mask Convention
Two equivalent conventions appear in this lesson:
- The `scaled_dot_product_attention` helper defined earlier takes a **0/1 mask** (1 = attend, 0 = masked) and converts it internally with `masked_fill(mask == 0, -inf)`.
- The from-scratch NumPy code below builds an **additive mask** directly: 0 means attend (score unchanged), -inf means masked.
Either way, the key insight is the same: positions that receive -inf before softmax end up with exactly 0 attention weight.
:::
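A tiny sketch of converting between the two conventions:

```{python}
# Converting a 0/1 "can attend" mask into the equivalent additive mask.
bool_mask = torch.tril(torch.ones(4, 4))                  # 1 = attend, 0 = masked
additive = torch.zeros(4, 4).masked_fill(bool_mask == 0, float('-inf'))
print("0/1 mask:\n", bool_mask)
print("additive mask:\n", additive)
```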
```{ojs}
//| echo: false
// Toggle between attention modes
viewof attentionMode = Inputs.radio(
["Bidirectional (BERT)", "Causal (GPT)"],
{
value: "Causal (GPT)",
label: "Attention Type"
}
)
```
```{ojs}
//| echo: false
// Bidirectional vs Causal attention matrix visualization
{
const tokens = ["The", "cat", "sat", "on", "the", "mat"];
const n = tokens.length;
const isCausal = attentionMode === "Causal (GPT)";
const width = 580;
const height = 480;
const matrixSize = 300;
const cellSize = matrixSize / n;
const matrixX = (width - matrixSize) / 2;
const matrixY = 100;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Mono', 'JetBrains Mono', monospace");
// Defs for filters and gradients
const defs = svg.append("defs");
// Glow filter for active cells
const glowFilter = defs.append("filter")
.attr("id", "cell-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "2")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Masked pattern for blocked cells
const maskedPattern = defs.append("pattern")
.attr("id", "masked-pattern")
.attr("patternUnits", "userSpaceOnUse")
.attr("width", 8)
.attr("height", 8);
maskedPattern.append("rect")
.attr("width", 8)
.attr("height", 8)
.attr("fill", diagramTheme.bgSecondary);
maskedPattern.append("path")
.attr("d", "M-1,1 l2,-2 M0,8 l8,-8 M7,9 l2,-2")
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1)
.attr("stroke-opacity", 0.3);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 12);
// Title
const title = isCausal ? "Causal Attention (GPT)" : "Bidirectional Attention (BERT)";
const subtitle = isCausal
? "Each token can only attend to itself and previous tokens"
: "Each token can attend to all tokens in the sequence";
svg.append("text")
.attr("x", width / 2)
.attr("y", 32)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "16px")
.attr("font-weight", "700")
.text(title);
svg.append("text")
.attr("x", width / 2)
.attr("y", 54)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "12px")
.attr("opacity", 0.7)
.text(subtitle);
// Matrix group
const matrix = svg.append("g")
.attr("transform", `translate(${matrixX}, ${matrixY})`);
// Row labels (Query tokens) - on the left
svg.append("text")
.attr("x", matrixX - 55)
.attr("y", matrixY - 15)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("Query");
tokens.forEach((token, i) => {
svg.append("text")
.attr("x", matrixX - 10)
.attr("y", matrixY + i * cellSize + cellSize / 2)
.attr("text-anchor", "end")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(token);
});
// Column labels (Key tokens) - on top
svg.append("text")
.attr("x", matrixX + matrixSize / 2)
.attr("y", matrixY - 25)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("Key");
tokens.forEach((token, j) => {
svg.append("text")
.attr("x", matrixX + j * cellSize + cellSize / 2)
.attr("y", matrixY - 8)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(token);
});
// Draw attention matrix cells
for (let i = 0; i < n; i++) {
for (let j = 0; j < n; j++) {
const canAttend = !isCausal || j <= i;
const delay = (i * n + j) * 15;
const cell = matrix.append("g")
.attr("transform", `translate(${j * cellSize}, ${i * cellSize})`);
// Cell background
const rect = cell.append("rect")
.attr("width", cellSize - 2)
.attr("height", cellSize - 2)
.attr("x", 1)
.attr("y", 1)
.attr("rx", 4)
.attr("stroke", canAttend ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", canAttend ? 1.5 : 1)
.attr("stroke-opacity", canAttend ? 0.6 : 0.3);
// Animate fill based on attention mode
if (canAttend) {
rect
.attr("fill", diagramTheme.highlight)
.attr("fill-opacity", 0)
.transition()
.delay(delay)
.duration(300)
.attr("fill-opacity", 0.25 + (i === j ? 0.25 : 0));
} else {
rect
.attr("fill", "url(#masked-pattern)")
.attr("fill-opacity", 0)
.transition()
.delay(delay)
.duration(300)
.attr("fill-opacity", 1);
}
// Checkmark or X indicator
const symbol = cell.append("text")
.attr("x", cellSize / 2)
.attr("y", cellSize / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("font-size", "14px")
.attr("font-weight", "600")
.attr("opacity", 0);
if (canAttend) {
symbol
.attr("fill", diagramTheme.highlight)
.text("\u2713") // checkmark
.transition()
.delay(delay + 150)
.duration(200)
.attr("opacity", 0.8);
} else {
symbol
.attr("fill", diagramTheme.nodeStroke)
.text("\u2717") // X mark
.transition()
.delay(delay + 150)
.duration(200)
.attr("opacity", 0.5);
}
}
}
// Matrix border
matrix.append("rect")
.attr("width", matrixSize)
.attr("height", matrixSize)
.attr("fill", "none")
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 2)
.attr("rx", 6);
// Legend
const legendY = matrixY + matrixSize + 30;
// Can attend legend
const legendAttend = svg.append("g")
.attr("transform", `translate(${width / 2 - 100}, ${legendY})`);
legendAttend.append("rect")
.attr("width", 20)
.attr("height", 20)
.attr("rx", 3)
.attr("fill", diagramTheme.highlight)
.attr("fill-opacity", 0.35)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 1.5);
legendAttend.append("text")
.attr("x", 28)
.attr("y", 10)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text("Can attend");
// Cannot attend legend (only show in causal mode)
if (isCausal) {
const legendMasked = svg.append("g")
.attr("transform", `translate(${width / 2 + 40}, ${legendY})`);
legendMasked.append("rect")
.attr("width", 20)
.attr("height", 20)
.attr("rx", 3)
.attr("fill", "url(#masked-pattern)")
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1);
legendMasked.append("text")
.attr("x", 28)
.attr("y", 10)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text("Masked (future)");
}
// Explanation text at bottom
const explanationY = legendY + 45;
const explanation = isCausal
? 'Row i can only see columns 0...i (lower triangle + diagonal)'
: 'Every position can attend to every other position';
svg.append("text")
.attr("x", width / 2)
.attr("y", explanationY)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.6)
.attr("font-style", "italic")
.text(explanation);
return svg.node();
}
```
### The Causal Mask from Scratch
The causal mask is elegantly simple: we add `-inf` to positions we want to mask, and softmax turns those into zeros.
```{python}
import numpy as np
def causal_mask_from_scratch(seq_len):
"""
Create a causal mask using np.triu (upper triangular).
The mask has -inf above the diagonal (future positions)
and 0 on and below the diagonal (past/current positions).
"""
# np.triu with k=1 gives us the strictly upper triangular part
# (everything above the main diagonal)
mask = np.triu(np.ones((seq_len, seq_len)), k=1)
# Convert 1s to -inf (positions to mask)
mask = mask * (-1e9) # Use large negative instead of -inf for visualization
return mask
# Visualize the mask
seq_len = 5
mask = causal_mask_from_scratch(seq_len)
print("Causal Mask (0 = attend, -inf = masked):")
print(mask.round(0))
print("\nHow it works:")
print(" Position 0: sees only position 0")
print(" Position 1: sees positions 0, 1")
print(" Position 4: sees positions 0, 1, 2, 3, 4")
```
```{python}
def attention_with_causal_mask(x, W_q, W_k, W_v):
"""
Causal attention from scratch - each position only attends to past.
"""
Q = x @ W_q
K = x @ W_k
V = x @ W_v
seq_len = x.shape[0]
d_k = K.shape[-1]
# Compute scores
scores = Q @ K.T / np.sqrt(d_k)
print("Scores before masking:")
print(scores.round(2))
# Add causal mask: -inf for future positions
mask = np.triu(np.ones((seq_len, seq_len)), k=1) * (-1e9)
scores = scores + mask
print("\nScores after adding causal mask:")
print(scores.round(2))
# Softmax: -inf becomes 0
def softmax(x):
exp_x = np.exp(x - x.max(axis=-1, keepdims=True))
return exp_x / exp_x.sum(axis=-1, keepdims=True)
weights = softmax(scores)
print("\nAttention weights (upper triangle is 0!):")
print(weights.round(3))
output = weights @ V
return output, weights
# Demo
np.random.seed(42)
seq_len, embed_dim, head_dim = 4, 4, 2
x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5
print("=" * 50)
print("CAUSAL ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_with_causal_mask(x, W_q, W_k, W_v)
```
::: {.callout-note}
## Key Insight
Causal masking is just adding `-inf` before softmax. That is the entire trick.
- `softmax([2.0, 1.0, -inf])` = `[0.73, 0.27, 0.00]`
- The `-inf` position gets exactly 0 weight
- No information flows from future to past
:::
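A one-line check of those numbers:

```{python}
# Verifying the callout: a score of -inf becomes exactly zero attention weight.
print(F.softmax(torch.tensor([2.0, 1.0, float('-inf')]), dim=-1))
```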
```{python}
def create_causal_mask(seq_len):
"""Create a lower triangular causal mask."""
return torch.tril(torch.ones(seq_len, seq_len))
# Show the mask
seq_len = 6
mask = create_causal_mask(seq_len)
print("Causal Mask (1 = can attend, 0 = masked):")
print()
tokens = ["The", "cat", "sat", "on", "the", "mat"]
for i in range(seq_len):
row = ['#' if mask[i, j] == 1 else '.' for j in range(seq_len)]
print(f" {tokens[i]:4s}: {''.join(row)}")
print(f"\nPosition 0 can only see position 0")
print(f"Position 5 can see all previous positions")
```
```{python}
# Apply causal mask
Q = torch.randn(1, 6, 8)
K = torch.randn(1, 6, 8)
V = torch.randn(1, 6, 8)
# Without mask (bidirectional)
output_bi, weights_bi = scaled_dot_product_attention(Q, K, V)
# With causal mask
output_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask=mask)
# Compare
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax = axes[0]
w = weights_bi[0].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
ax.set_title('Bidirectional Attention', fontsize=14)
ax.set_xticks(range(6))
ax.set_yticks(range(6))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)
ax = axes[1]
w = weights_causal[0].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
ax.set_title('Causal Attention (Lower Triangular)', fontsize=14)
ax.set_xticks(range(6))
ax.set_yticks(range(6))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)
plt.colorbar(im, ax=axes, shrink=0.8, label='Attention Weight')
plt.tight_layout()
plt.show()
print("Notice: In causal attention, the upper triangle is 0 (can't attend to future)")
```
## Multi-Head Attention
Instead of one attention head, we use multiple "heads" that can learn different patterns:
```
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) x W_O
where head_i = Attention(Q x W_Q_i, K x W_K_i, V x W_V_i)
```
**Why Multiple Heads?**
A single attention head computes one weighted average, which limits what relationships it can capture. Multiple heads provide:
- **Diverse patterns**: Different heads can focus on different relationships (syntax, semantics, position, coreference)
- **Subspace attention**: Each head operates in a lower-dimensional subspace (`head_dim = embed_dim / num_heads`), allowing specialized representations
- **Computational efficiency**: Despite having multiple heads, the total computation is similar to single-head attention at full dimensionality, and the parameter count is identical (see the quick check after the configuration list below)
**Typical configurations**:
- GPT-2: 12 heads, 768 embed_dim, 64 head_dim
- GPT-3: 96 heads, 12288 embed_dim, 128 head_dim
- Llama 2 (7B): 32 heads, 4096 embed_dim, 128 head_dim
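To back up the efficiency point above, here is a small illustrative calculation: the projection weights W_Q, W_K, W_V, and W_O each stay at `d_model x d_model` no matter how many heads the embedding is split into, so the head count changes only `head_dim`, not the parameter count. The `d_model = 768` value is GPT-2's, used purely as an example.

```{python}
# W_Q, W_K, W_V, W_O are each (d_model x d_model) regardless of head count;
# splitting into heads only changes head_dim, not the number of parameters.
d_model = 768  # GPT-2 size, for illustration
for n_heads in [1, 12]:
    head_dim = d_model // n_heads
    proj_params = 4 * d_model * d_model  # ignoring biases
    print(f"heads={n_heads:2d}  head_dim={head_dim:3d}  projection params={proj_params:,}")
```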
```{ojs}
//| echo: false
// Step slider for multi-head architecture walkthrough
viewof multiHeadStep = Inputs.range([0, 5], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
// Step descriptions for multi-head attention flow
multiHeadStepDescriptions = [
{ title: "Input", desc: "Token embeddings enter the multi-head attention layer" },
{ title: "Linear Projections", desc: "Three learned weight matrices create Q, K, V projections" },
{ title: "Split into Heads", desc: "Reshape tensors to distribute across multiple attention heads" },
{ title: "Parallel Attention", desc: "Each head computes scaled dot-product attention independently" },
{ title: "Concatenate", desc: "Combine all head outputs back into a single tensor" },
{ title: "Output Projection", desc: "Final linear transformation produces the layer output" }
]
// Multi-head attention architecture visualization
multiHeadArchDiagram = {
const width = 780;
const height = 640;
const step = multiHeadStep;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("viewBox", `0 0 ${width} ${height}`)
.attr("width", "100%")
.attr("height", height)
.style("max-width", `${width}px`)
.style("font-family", "'IBM Plex Mono', 'JetBrains Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Define gradients and filters
const defs = svg.append("defs");
// Glow filter for active elements
const glowFilter = defs.append("filter")
.attr("id", "mha-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Color palette - consistent Q/K/V colors
const colors = {
Q: "#22d3ee", // cyan for Query
K: "#a78bfa", // purple for Key
V: "#4ade80", // green for Value
head: "#fb923c", // orange for heads
concat: "#f472b6", // pink for concatenation
output: "#34d399", // emerald for output
input: "#94a3b8" // slate for input
};
// Helper to create arrow marker
function createMarker(id, color) {
defs.append("marker")
.attr("id", id)
.attr("viewBox", "0 -4 8 8")
.attr("refX", 6)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-4L8,0L0,4Z")
.attr("fill", color);
}
createMarker("mha-arrow-default", theme.edgeStroke);
createMarker("mha-arrow-q", colors.Q);
createMarker("mha-arrow-k", colors.K);
createMarker("mha-arrow-v", colors.V);
createMarker("mha-arrow-head", colors.head);
createMarker("mha-arrow-concat", colors.concat);
createMarker("mha-arrow-output", colors.output);
// Helper to draw a box with label
function drawBox(g, x, y, w, h, color, label, sublabel, isActive, opacity = 1) {
const group = g.append("g")
.attr("transform", `translate(${x}, ${y})`)
.style("opacity", opacity);
group.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", isActive ? color : theme.nodeFill)
.attr("fill-opacity", isActive ? 0.25 : 1)
.attr("stroke", isActive ? color : theme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("filter", isActive ? "url(#mha-glow)" : null);
group.append("text")
.attr("y", sublabel ? -6 : 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? color : theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", isActive ? "700" : "500")
.text(label);
if (sublabel) {
group.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.6)
.text(sublabel);
}
return group;
}
// Helper to draw an arrow
function drawArrow(g, x1, y1, x2, y2, color, isActive, markerId = "mha-arrow-default") {
const arrow = g.append("line")
.attr("x1", x1)
.attr("y1", y1)
.attr("x2", x2)
.attr("y2", y2)
.attr("stroke", isActive ? color : theme.edgeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("marker-end", `url(#${isActive ? markerId : "mha-arrow-default"})`)
.attr("filter", isActive ? "url(#mha-glow)" : null);
return arrow;
}
// Helper to draw curved arrow (for parallel paths)
function drawCurvedArrow(g, x1, y1, x2, y2, curve, color, isActive, markerId) {
const midX = (x1 + x2) / 2;
const midY = (y1 + y2) / 2;
const path = g.append("path")
.attr("d", `M${x1},${y1} Q${midX + curve},${midY} ${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", isActive ? color : theme.edgeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("marker-end", `url(#${isActive ? markerId : "mha-arrow-default"})`)
.attr("filter", isActive ? "url(#mha-glow)" : null);
return path;
}
// Title and description
const stepInfo = multiHeadStepDescriptions[step];
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "15px")
.attr("font-weight", "700")
.text(stepInfo.title);
svg.append("text")
.attr("x", width / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("opacity", 0.7)
.text(stepInfo.desc);
// Main content group
const content = svg.append("g").attr("transform", "translate(0, 70)");
// Layout constants
const centerX = width / 2;
const row1Y = 30; // Input
const row2Y = 110; // Projections
const row3Y = 190; // Q, K, V
const row4Y = 270; // Reshape
const row5Y = 350; // Heads
const row6Y = 440; // Concat
const row7Y = 520; // Output projection + output
// Step visibility
const s0 = step === 0;
const s1 = step === 1;
const s2 = step === 2;
const s3 = step === 3;
const s4 = step === 4;
const s5 = step === 5;
// ===== ROW 1: Input =====
drawBox(content, centerX, row1Y, 180, 40, colors.input, "x", "(batch, seq, embed_dim)", s0, s0 ? 1 : 0.5);
// ===== ROW 2: Linear Projections =====
const projSpacing = 140;
const projY = row2Y;
// Arrows from input to projections
if (s0 || s1) {
drawArrow(content, centerX - 50, row1Y + 22, centerX - projSpacing, projY - 22, colors.Q, s1, "mha-arrow-q");
drawArrow(content, centerX, row1Y + 22, centerX, projY - 22, colors.K, s1, "mha-arrow-k");
drawArrow(content, centerX + 50, row1Y + 22, centerX + projSpacing, projY - 22, colors.V, s1, "mha-arrow-v");
}
drawBox(content, centerX - projSpacing, projY, 80, 36, colors.Q, "W_Q", null, s1, s1 ? 1 : 0.5);
drawBox(content, centerX, projY, 80, 36, colors.K, "W_K", null, s1, s1 ? 1 : 0.5);
drawBox(content, centerX + projSpacing, projY, 80, 36, colors.V, "W_V", null, s1, s1 ? 1 : 0.5);
// ===== ROW 3: Q, K, V matrices =====
// Arrows from projections to Q, K, V
if (s1 || s2) {
drawArrow(content, centerX - projSpacing, projY + 20, centerX - projSpacing, row3Y - 20, colors.Q, s1, "mha-arrow-q");
drawArrow(content, centerX, projY + 20, centerX, row3Y - 20, colors.K, s1, "mha-arrow-k");
drawArrow(content, centerX + projSpacing, projY + 20, centerX + projSpacing, row3Y - 20, colors.V, s1, "mha-arrow-v");
}
drawBox(content, centerX - projSpacing, row3Y, 90, 40, colors.Q, "Q", "(b, seq, embed)", s1 || s2, (s1 || s2) ? 1 : 0.5);
drawBox(content, centerX, row3Y, 90, 40, colors.K, "K", "(b, seq, embed)", s1 || s2, (s1 || s2) ? 1 : 0.5);
drawBox(content, centerX + projSpacing, row3Y, 90, 40, colors.V, "V", "(b, seq, embed)", s1 || s2, (s1 || s2) ? 1 : 0.5);
// ===== ROW 4: Reshape =====
// Single reshape operation box
drawBox(content, centerX, row4Y, 280, 40, colors.head, "Reshape + Transpose", "(b, heads, seq, head_dim)", s2, s2 ? 1 : 0.5);
// Arrows from Q, K, V to reshape
if (s2) {
drawArrow(content, centerX - projSpacing, row3Y + 22, centerX - 80, row4Y - 22, colors.Q, true, "mha-arrow-q");
drawArrow(content, centerX, row3Y + 22, centerX, row4Y - 22, colors.K, true, "mha-arrow-k");
drawArrow(content, centerX + projSpacing, row3Y + 22, centerX + 80, row4Y - 22, colors.V, true, "mha-arrow-v");
}
// ===== ROW 5: Parallel Heads =====
const headSpacing = 150;
const heads = [
{ label: "Head 0", x: centerX - headSpacing * 1.5 },
{ label: "Head 1", x: centerX - headSpacing * 0.5 },
{ label: "Head 2", x: centerX + headSpacing * 0.5 },
{ label: "Head 3", x: centerX + headSpacing * 1.5 }
];
// Arrows from reshape to heads
if (s2 || s3) {
heads.forEach((head, i) => {
const startX = centerX + (i - 1.5) * 60;
drawArrow(content, startX, row4Y + 22, head.x, row5Y - 32, colors.head, s3, "mha-arrow-head");
});
}
// Draw heads
heads.forEach((head, i) => {
const headGroup = content.append("g")
.attr("transform", `translate(${head.x}, ${row5Y})`)
.style("opacity", (s3) ? 1 : 0.5);
// Head container
headGroup.append("rect")
.attr("x", -55)
.attr("y", -30)
.attr("width", 110)
.attr("height", 60)
.attr("rx", 8)
.attr("fill", s3 ? colors.head : theme.nodeFill)
.attr("fill-opacity", s3 ? 0.15 : 1)
.attr("stroke", s3 ? colors.head : theme.nodeStroke)
.attr("stroke-width", s3 ? 2 : 1.5)
.attr("filter", s3 ? "url(#mha-glow)" : null);
// Head label
headGroup.append("text")
.attr("y", -10)
.attr("text-anchor", "middle")
.attr("fill", s3 ? colors.head : theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", s3 ? "700" : "500")
.text(head.label);
// Mini attention indicator
headGroup.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.6)
.text("softmax(QK^T/sqrt(d))V");
});
// ===== ROW 6: Concatenate =====
// Arrows from heads to concat
if (s3 || s4) {
heads.forEach((head, i) => {
drawArrow(content, head.x, row5Y + 32, centerX + (i - 1.5) * 40, row6Y - 22, colors.concat, s4, "mha-arrow-concat");
});
}
drawBox(content, centerX, row6Y, 180, 40, colors.concat, "Concatenate", "(b, seq, embed_dim)", s4, s4 ? 1 : 0.5);
// ===== ROW 7: Output Projection and Output =====
// Arrow from concat to output projection
if (s4 || s5) {
drawArrow(content, centerX, row6Y + 22, centerX, row7Y - 22, colors.output, s5, "mha-arrow-output");
}
drawBox(content, centerX, row7Y, 140, 36, colors.output, "W_O", "output projection", s5, s5 ? 1 : 0.5);
// Final output
if (s5) {
drawArrow(content, centerX, row7Y + 20, centerX, row7Y + 55, colors.output, true, "mha-arrow-output");
}
drawBox(content, centerX, row7Y + 75, 180, 40, colors.output, "Output", "(batch, seq, embed_dim)", s5, s5 ? 1 : 0.5);
// Shape annotations on the right side
const annotationX = width - 100;
const annotations = [
{ y: row1Y + 70, text: "embed_dim", active: s0 || s1 },
{ y: row3Y + 70, text: "= heads x head_dim", active: s2 },
{ y: row5Y + 70, text: "parallel computation", active: s3 },
{ y: row6Y + 70, text: "recombine", active: s4 }
];
annotations.forEach(ann => {
if (ann.active) {
svg.append("text")
.attr("x", annotationX)
.attr("y", ann.y)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "10px")
.attr("font-style", "italic")
.attr("opacity", 0.8)
.text(ann.text);
}
});
return svg.node();
}
```
### What Different Heads Learn
In trained models, different heads tend to specialize. The roles below are illustrative examples of the kinds of patterns reported in analyses of trained transformers:
- **Head 0**: "Who did what?" - attends to subject-verb pairs
- **Head 1**: "What comes before?" - attends to previous token
- **Head 2**: "What's similar?" - attends to semantically similar words
- **Head 3**: "Syntax patterns" - attends to grammatical structure
```{python}
import torch.nn as nn
# Simplified implementation for learning - see attention.py for production version
class MultiHeadAttention(nn.Module):
"""Multi-head attention with separate Q, K, V projections (simplified for illustration)."""
def __init__(self, embed_dim, num_heads, dropout=0.0):
super().__init__()
assert embed_dim % num_heads == 0
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Projections for Q, K, V
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Output projection
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None, return_attention=False):
batch_size, seq_len, _ = x.shape
# Project Q, K, V
q = self.q_proj(x) # (batch, seq, embed)
k = self.k_proj(x)
v = self.v_proj(x)
# Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Compute attention
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Apply attention to values
attn_output = torch.matmul(attention_weights, v)
# Reshape back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
# Final projection
output = self.out_proj(attn_output)
if return_attention:
return output, attention_weights
return output
# Test multi-head attention
embed_dim = 64
num_heads = 8
mha = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)
x = torch.randn(2, 10, embed_dim)
output, weights = mha(x, return_attention=True)
print(f"Multi-Head Attention Configuration:")
print(f" Embedding dimension: {embed_dim}")
print(f" Number of heads: {num_heads}")
print(f" Dimension per head: {embed_dim // num_heads}")
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f" (batch, heads, query_pos, key_pos)")
print(f"\nTotal parameters: {sum(p.numel() for p in mha.parameters()):,}")
```
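As a sanity check, our simplified layer has the same parameter count as PyTorch's built-in `nn.MultiheadAttention` for the same configuration:

```{python}
# Both modules use 4 * embed_dim^2 projection weights plus 4 * embed_dim biases.
builtin = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
print(f"Our MultiHeadAttention:  {sum(p.numel() for p in mha.parameters()):,}")
print(f"nn.MultiheadAttention:   {sum(p.numel() for p in builtin.parameters()):,}")
```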
### Visualizing Multi-Head Attention
```{python}
# Visualize attention patterns for different heads
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for head in range(num_heads):
ax = axes[head // 4, head % 4]
w = weights[0, head].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=w.max())
ax.set_title(f'Head {head}', fontsize=12)
ax.set_xlabel('Key')
ax.set_ylabel('Query')
plt.suptitle('Attention Patterns Across 8 Heads\n(Each head learns different patterns)', fontsize=14)
plt.tight_layout()
plt.show()
```
## Using Our Attention Module
The `attention.py` module provides production-ready implementations:
```{python}
from attention import (
CausalMultiHeadAttention,
demonstrate_attention,
demonstrate_causal_attention
)
# Run built-in demonstrations
print("=" * 60)
print("MULTI-HEAD ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_attention(seq_len=6, embed_dim=32, num_heads=4)
```
```{python}
print("\n" + "=" * 60)
print("CAUSAL ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_causal_attention(seq_len=6)
```
```{python}
# Causal multi-head attention (what GPT uses)
causal_mha = CausalMultiHeadAttention(
embed_dim=64,
num_heads=8,
max_seq_len=512,
dropout=0.0
)
x = torch.randn(2, 10, 64)
output, weights = causal_mha(x, return_attention=True)
print(f"\nCausal Multi-Head Attention:")
print(f" Input shape: {x.shape}")
print(f" Output shape: {output.shape}")
print(f" Attention weights shape: {weights.shape}")
```
## PyTorch's Optimized Attention
Now that we understand attention from scratch, let's see how PyTorch provides production-optimized implementations.
### F.scaled_dot_product_attention
PyTorch 2.0+ provides `F.scaled_dot_product_attention` - a single function that replaces our manual implementation and automatically uses the best available backend.
```{python}
import torch
import torch.nn.functional as F
# Our inputs
batch, num_heads, seq_len, head_dim = 2, 8, 64, 32
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)
# The manual way (what we implemented)
def manual_attention(q, k, v, is_causal=False):
d_k = q.size(-1)
scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)
if is_causal:
        mask = torch.triu(torch.ones(q.size(-2), k.size(-2)), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)
return weights @ v
# PyTorch's optimized version
output_manual = manual_attention(query, key, value, is_causal=True)
output_pytorch = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(f"Manual output shape: {output_manual.shape}")
print(f"PyTorch SDPA shape: {output_pytorch.shape}")
print(f"Results match: {torch.allclose(output_manual, output_pytorch, atol=1e-5)}")
```
### Flash Attention: The Speed Revolution
`F.scaled_dot_product_attention` automatically uses **Flash Attention** when available - a breakthrough algorithm that:
1. **Avoids materializing the full attention matrix** (O(n^2) memory -> O(n) memory)
2. **Uses tiling** to keep computation in fast GPU SRAM
3. **Fuses operations** to minimize memory bandwidth bottleneck
```{python}
# Check which backends are available
print("PyTorch Attention Backends:")
print(f" Flash Attention: {torch.backends.cuda.flash_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f" Memory-efficient: {torch.backends.cuda.mem_efficient_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f" Math (fallback): Always available")
# The beautiful thing: same API, automatic optimization
# PyTorch picks the fastest available backend
```
```{python}
# Benchmark: manual vs PyTorch SDPA
import time
def benchmark(fn, name, warmup=5, runs=20):
# Warmup
for _ in range(warmup):
_ = fn()
# Timed runs
start = time.perf_counter()
for _ in range(runs):
_ = fn()
elapsed = (time.perf_counter() - start) / runs * 1000
print(f"{name}: {elapsed:.2f} ms per call")
return elapsed
# Benchmark on CPU (GPU would show more dramatic difference)
batch, num_heads, seq_len, head_dim = 1, 8, 256, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)
print(f"\nBenchmark (seq_len={seq_len}, {num_heads} heads, head_dim={head_dim}):")
t_manual = benchmark(lambda: manual_attention(q, k, v, is_causal=True), "Manual attention")
t_sdpa = benchmark(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True), "PyTorch SDPA")
print(f"\nSpeedup: {t_manual/t_sdpa:.1f}x")
```
::: {.callout-tip}
## From Scratch to Production
| What we learned | What PyTorch provides |
|-----------------|----------------------|
| Q @ K.T / sqrt(d) | Fused kernel, no intermediate storage |
| Causal mask with -inf | Built-in `is_causal=True` flag |
| Stable softmax | Numerically stable implementation |
| Manual loops | Flash Attention tiling |
**Use `F.scaled_dot_product_attention` in production.** Understanding the scratch implementation helps debug and customize, but the optimized version is 2-10x faster on GPU.
:::
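As a concrete example, here's a hedged sketch of how the projections from our `MultiHeadAttention` class could delegate the attention math to `F.scaled_dot_product_attention`. The class name `SDPAMultiHeadAttention` is made up for illustration; it is not the implementation in `attention.py`.

```{python}
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAMultiHeadAttention(nn.Module):
    """Same Q/K/V/output projections as MultiHeadAttention above,
    but the attention math is delegated to F.scaled_dot_product_attention."""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout

    def forward(self, x, is_causal=True):
        b, t, e = x.shape
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # Fused, backend-selected attention (Flash / memory-efficient / math fallback)
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        out = out.transpose(1, 2).contiguous().view(b, t, e)
        return self.out_proj(out)

sdpa_mha = SDPAMultiHeadAttention(embed_dim=64, num_heads=8)
print(sdpa_mha(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```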
## Exercises
### Exercise 1: Verify Attention Row Sums
```{python}
# Verify that each row of attention weights sums to 1
Q = torch.randn(1, 5, 16)
K = torch.randn(1, 5, 16)
V = torch.randn(1, 5, 16)
output, weights = scaled_dot_product_attention(Q, K, V)
print("Attention weights row sums (should all be 1.0):")
print(weights[0].sum(dim=-1))
```
### Exercise 2: Effect of Temperature
```{python}
# Temperature scaling affects attention sharpness
# Higher temperature = more uniform, Lower = more peaked
def attention_with_temperature(Q, K, V, temperature=1.0):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (math.sqrt(d_k) * temperature)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, V)
return output, weights
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, temp in enumerate([0.5, 1.0, 2.0]):
_, weights = attention_with_temperature(Q, K, V, temperature=temp)
ax = axes[i]
w = weights[0].detach().numpy()
ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
ax.set_title(f'Temperature = {temp}\n({"Sharp" if temp < 1 else "Uniform" if temp > 1 else "Normal"})')
for j in range(4):
for k in range(4):
ax.text(k, j, f'{w[j, k]:.2f}', ha='center', va='center', fontsize=9)
plt.tight_layout()
plt.show()
print("Lower temperature = sharper attention (more peaked)")
print("Higher temperature = softer attention (more uniform)")
```
### Exercise 3: Compare Single-Head vs Multi-Head
```{python}
# Single head with full dimension vs multiple heads with smaller dimensions
embed_dim = 64
seq_len = 8
# Single head: one attention over 64 dimensions
single_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=1)
# Multi head: 8 attention heads over 8 dimensions each
multi_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=8)
x = torch.randn(1, seq_len, embed_dim)
out_single, w_single = single_head(x, return_attention=True)
out_multi, w_multi = multi_head(x, return_attention=True)
print(f"Single-head attention:")
print(f" Attention weights shape: {w_single.shape}")
print(f" One pattern to rule them all")
print(f"\nMulti-head attention:")
print(f" Attention weights shape: {w_multi.shape}")
print(f" 8 different patterns, each can specialize")
# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax = axes[0]
ax.imshow(w_single[0, 0].detach().numpy(), cmap='Blues')
ax.set_title('Single Head (64d)')
ax = axes[1]
# Show first 4 heads of multi-head
multi_combined = torch.zeros(seq_len, seq_len)
for h in range(4):
multi_combined += w_multi[0, h].detach()
multi_combined /= 4
ax.imshow(multi_combined.numpy(), cmap='Blues')
ax.set_title('Multi Head (8 heads x 8d)\nAverage of first 4 heads')
plt.tight_layout()
plt.show()
```
## Complexity and Optimizations
Attention has:
- **Time complexity**: O(n^2 * d) where n = sequence length, d = embedding dimension
- **Memory complexity**: O(n^2) for storing the attention matrix
For long sequences (n = 10000), this means 100 million attention entries!
```{python}
# Visualize how memory grows with sequence length
import matplotlib.pyplot as plt
import numpy as np
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]
memory_gb = [(n * n * 4) / (1024**3) for n in seq_lengths]  # float32 = 4 bytes per entry, converted to GB
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar([str(n) for n in seq_lengths], memory_gb, color='steelblue')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Attention Matrix Memory (GB)')
ax.set_title('Quadratic Memory Growth: O(n^2) Attention Matrix Size')
for i, (n, mem) in enumerate(zip(seq_lengths, memory_gb)):
ax.text(i, mem + 0.1, f'{mem:.2f} GB', ha='center', fontsize=9)
plt.tight_layout()
plt.show()
print("This is why 32K context models need special techniques!")
```
### KV Cache for Efficient Inference
During autoregressive generation, we compute attention one token at a time. Without caching, we'd recompute K and V for all previous tokens at each step.
**KV Cache**: Store computed K and V values for previous tokens:
- At step t, only compute K_t and V_t for the new token
- Concatenate with cached K_{1:t-1} and V_{1:t-1}
- Query only needs the new token's Q_t
This reduces generation from O(n^2) to O(n) per token (where n is current sequence length).
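To make this concrete, here's a minimal single-head sketch of one decoding step with a KV cache. The function `decode_step_with_cache` and its argument names are illustrative, not part of `attention.py`.

```{python}
import math
import torch
import torch.nn.functional as F

def decode_step_with_cache(x_new, q_proj, k_proj, v_proj, kv_cache=None):
    """One autoregressive step: project only the new token, reuse cached K/V.

    x_new: (batch, 1, embed_dim) embedding of the newly generated token.
    kv_cache: optional tuple (K_prev, V_prev), each (batch, t-1, embed_dim).
    """
    q_new = q_proj(x_new)                          # (batch, 1, embed)
    k_new = k_proj(x_new)
    v_new = v_proj(x_new)
    if kv_cache is not None:
        k_prev, v_prev = kv_cache
        k_all = torch.cat([k_prev, k_new], dim=1)  # (batch, t, embed)
        v_all = torch.cat([v_prev, v_new], dim=1)
    else:
        k_all, v_all = k_new, v_new
    # The new token attends to everything generated so far; with a single
    # query there is no future to mask.
    scores = q_new @ k_all.transpose(-2, -1) / math.sqrt(q_new.size(-1))
    weights = F.softmax(scores, dim=-1)
    out = weights @ v_all                          # (batch, 1, embed)
    return out, (k_all, v_all)                     # return the updated cache

# Hypothetical usage: generate 5 tokens, carrying the cache forward
embed_dim = 16
q_proj = torch.nn.Linear(embed_dim, embed_dim)
k_proj = torch.nn.Linear(embed_dim, embed_dim)
v_proj = torch.nn.Linear(embed_dim, embed_dim)
cache = None
for step in range(5):
    x_new = torch.randn(1, 1, embed_dim)           # stand-in for the next token's embedding
    out, cache = decode_step_with_cache(x_new, q_proj, k_proj, v_proj, cache)
print("Cached K shape after 5 steps:", cache[0].shape)  # (1, 5, 16)
```

Only the new token's Q, K, V are projected at each step; the cached keys and values simply grow by one row per generated token.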
### Modern Optimizations
**Flash Attention** (Dao et al., 2022):
- Avoids materializing the full n x n attention matrix
- Uses tiling and recomputation to be memory-efficient
- 2-4x faster than standard attention on modern GPUs
- Now the default in most frameworks (PyTorch 2.0+)
**Sparse Attention** patterns:
- **Local attention**: Each token only attends to nearby tokens
- **Strided attention**: Attend to every k-th token
- **Block-sparse**: Combine local and strided patterns
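As a quick sketch of the local pattern, the mask below (window size 2 is an arbitrary choice) marks which key positions each query may attend to, using the same convention as the `mask` argument in our `MultiHeadAttention` (nonzero = allowed):

```{python}
import torch

def local_attention_mask(seq_len, window=2):
    """Boolean mask: position i may attend to j only if j <= i (causal)
    and |i - j| <= window. True = allowed."""
    idx = torch.arange(seq_len)
    dist = idx.unsqueeze(0) - idx.unsqueeze(1)   # dist[i, j] = j - i
    causal = dist <= 0                           # no looking ahead
    local = dist.abs() <= window                 # stay within the window
    return causal & local

mask = local_attention_mask(seq_len=8, window=2)
print(mask.int())
# Each query row allows at most window + 1 keys, so the number of scored
# pairs grows as O(n * window) instead of O(n^2).
```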
**Linear Attention** approximations:
- Replace softmax(QK^T)V with kernel feature maps
- Achieves O(n) complexity but may sacrifice quality
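Here's a minimal, non-causal sketch of that kernel trick, using the elu(x) + 1 feature map from Katharopoulos et al. (2020); it illustrates the O(n) reformulation but is not a drop-in replacement for the softmax attention above.

```{python}
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Non-causal linear attention sketch with feature map phi(x) = elu(x) + 1.
    q, k, v: (batch, heads, seq, head_dim)."""
    q, k = F.elu(q) + 1, F.elu(k) + 1
    # Summarize keys and values once: a (head_dim, head_dim) matrix per head,
    # so the cost is linear in seq_len instead of quadratic.
    kv = torch.einsum('bhnd,bhne->bhde', k, v)
    # Normalizer: phi(q_i) . sum_j phi(k_j)
    z = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)).clamp_min(eps)
    return torch.einsum('bhnd,bhde->bhne', q, kv) / z.unsqueeze(-1)

q = torch.randn(1, 4, 128, 32)
k = torch.randn(1, 4, 128, 32)
v = torch.randn(1, 4, 128, 32)
print(linear_attention(q, k, v).shape)  # torch.Size([1, 4, 128, 32])
```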
## Interactive Exploration
Experiment with attention in real-time. Adjust the temperature to see how it affects the attention distribution:
- **Low temperature** → Sharp, focused attention (nearly one-hot)
- **High temperature** → Soft, diffuse attention (more uniform)
```{ojs}
//| echo: false
// Tokens to visualize
tokens = ["The", "cat", "sat", "on", "the", "mat"]
// Pre-computed similarity scores (simulating Q·K^T)
// Higher values where semantically related
similarityMatrix = [
[1.0, 0.2, 0.1, 0.1, 0.9, 0.1], // "The" - similar to other "the"
[0.2, 1.0, 0.6, 0.1, 0.2, 0.3], // "cat" - relates to "sat"
[0.1, 0.7, 1.0, 0.3, 0.1, 0.4], // "sat" - relates to "cat", "mat"
[0.1, 0.1, 0.2, 1.0, 0.1, 0.5], // "on" - relates to "mat"
[0.9, 0.2, 0.1, 0.1, 1.0, 0.2], // "the" - similar to other "The"
[0.1, 0.4, 0.5, 0.6, 0.2, 1.0], // "mat" - relates to "sat", "on"
]
// Temperature slider
viewof temperature = Inputs.range([0.1, 3.0], {
value: 1.0,
step: 0.1,
label: "Temperature"
})
// Softmax function
function softmax(arr, temp) {
const scaled = arr.map(x => x / temp);
const maxVal = Math.max(...scaled);
const exps = scaled.map(x => Math.exp(x - maxVal));
const sum = exps.reduce((a, b) => a + b, 0);
return exps.map(x => x / sum);
}
// Compute attention weights
attentionWeights = similarityMatrix.map(row => softmax(row, temperature))
// Create heatmap data
heatmapData = {
const data = [];
for (let i = 0; i < tokens.length; i++) {
for (let j = 0; j < tokens.length; j++) {
data.push({
query: tokens[i],
key: tokens[j],
queryIdx: i,
keyIdx: j,
weight: attentionWeights[i][j]
});
}
}
return data;
}
// Dark mode detection
isDark = {
const check = () => document.body.classList.contains('quarto-dark');
return check();
}
// Theme colors for light/dark mode
theme = isDark ? {
  textHigh: '#e8e6e3',
  textLow: '#000000',
  barFill: '#6b8cae'
} : {
  textHigh: '#ffffff',
  textLow: '#000000',
  barFill: '#3b82f6'
}
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
Plot.plot({
title: "Attention Weights",
subtitle: `Temperature: ${temperature.toFixed(1)} — Each row shows where that token "looks"`,
width: 500,
height: 400,
padding: 0,
marginLeft: 60,
marginBottom: 60,
x: {
domain: tokens,
label: "Key (what we look at) →",
tickRotate: -45
},
y: {
domain: tokens,
label: "← Query (who is looking)"
},
color: {
scheme: "blues",
domain: [0, 1],
legend: true,
label: "Attention Weight"
},
marks: [
Plot.cell(heatmapData, {
x: "key",
y: "query",
fill: "weight",
tip: true
}),
Plot.text(heatmapData, {
x: "key",
y: "query",
text: d => d.weight.toFixed(2),
fill: d => d.weight > 0.5 ? theme.textHigh : theme.textLow,
fontSize: 11
})
]
})
```
```{ojs}
//| echo: false
// Show the selected row's attention distribution
viewof selectedToken = Inputs.select(tokens, {
label: "Focus on token",
value: "sat"
})
selectedIdx = tokens.indexOf(selectedToken)
selectedWeights = attentionWeights[selectedIdx]
md`**"${selectedToken}"** attends to:`
Plot.plot({
width: 500,
height: 200,
marginLeft: 60,
x: {
domain: tokens,
label: "Token"
},
y: {
domain: [0, 1],
label: "Attention Weight"
},
marks: [
Plot.barY(tokens.map((t, i) => ({token: t, weight: selectedWeights[i]})), {
x: "token",
y: "weight",
fill: theme.barFill
}),
Plot.text(tokens.map((t, i) => ({token: t, weight: selectedWeights[i]})), {
x: "token",
y: "weight",
text: d => d.weight.toFixed(2),
dy: -8,
fontSize: 11
}),
Plot.ruleY([0])
]
})
```
::: {.callout-tip}
## Try This
1. Set temperature to **0.1** — notice how attention becomes nearly one-hot (picks one token)
2. Set temperature to **3.0** — notice how attention becomes almost uniform
3. Compare how "sat" attends (looks at "cat") vs how "the" attends (looks at other "the")
:::
## Common Pitfalls
When implementing attention, watch out for these issues:
1. **Forgetting to scale**: Without `/sqrt(d_k)`, training becomes unstable with large head dimensions
2. **Wrong mask dimensions**: Mask should broadcast correctly over batch and head dimensions
3. **NaN from all-masked rows**: If an entire row is masked, every score is -inf and softmax computes 0/0, producing NaN. Handle with `nan_to_num` or ensure at least one position per row stays unmasked (see the sketch after this list)
4. **Memory leaks with attention weights**: Storing attention weights for visualization can exhaust memory. Only compute when needed
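Here's a short sketch of pitfall 3; the fully masked row is contrived for illustration:

```{python}
import torch
import torch.nn.functional as F

# A 4x4 score matrix where the last query's row is fully masked
scores = torch.randn(4, 4)
mask = torch.ones(4, 4, dtype=torch.bool)
mask[3, :] = False                         # pretend row 3 is all padding
masked = scores.masked_fill(~mask, float('-inf'))
weights = F.softmax(masked, dim=-1)
print("Row 3 before fix:", weights[3])     # all NaN: softmax of a row of -inf is 0/0
# One pragmatic fix: zero out the NaNs so the row contributes nothing downstream
weights = torch.nan_to_num(weights, nan=0.0)
print("Row 3 after fix: ", weights[3])
```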
## Summary
Key takeaways:
1. **Attention computes weighted sums**: Each position's output is a weighted combination of all (allowed) positions' values
2. **Q, K, V**: Query asks "what do I need?", Key says "what do I have?", Value carries the information
3. **Scaling prevents gradient issues**: Dividing by sqrt(d_k) keeps softmax from saturating
4. **Causal masking enables generation**: In LLMs, we mask future tokens so the model learns to predict the next token
5. **Multiple heads learn different patterns**: Each head can specialize in different linguistic relationships
6. **Complexity is O(n^2)**: Attention's quadratic cost limits sequence length, motivating optimizations like Flash Attention and KV caching
## What's Next
In [Module 06: Transformer](../m06_transformer/lesson.qmd), we'll combine attention with feed-forward networks, layer normalization, and residual connections to build a complete transformer decoder block.