Module 05: Attention

Introduction

Attention made transformers revolutionary. Each token examines every other token and gathers relevant information.

Attention enables each token to ask: “Which tokens in this sequence matter to me?”

Why attention matters for LLMs:

  • Long-range dependencies: Token 100 can attend to token 1 directly, in a single step, which sidesteps the vanishing-gradient problem that cripples RNNs over long distances
  • Parallelization: All positions compute simultaneously during training, unlike in RNNs
  • Interpretability: Attention weights offer a partial view of which tokens the model draws on
  • Dynamic context: Each token’s representation is context-dependent, not fixed

Self-attention is the key innovation: tokens attend to other tokens within the same sequence - queries, keys, and values all come from the same input. Cross-attention (used in encoder-decoder models) draws queries from one sequence and keys/values from another.

What You’ll Learn

After this module, you will be able to:

  • Understand Query, Key, Value projections and their roles
  • Implement scaled dot-product attention from scratch
  • Apply causal masking for autoregressive models
  • Build multi-head attention and understand why it’s beneficial
  • Recognize attention patterns and what they reveal

Prerequisites

This module requires familiarity with:

Note: Attention treats tokens as an unordered set. Positional embeddings (Module 04) supply the sense of order.

Attention as Three Questions

Every token in a sequence plays three roles, each framed as a question. Those three questions are the key to attention.

import numpy as np

# A simple sentence
tokens = ["The", "cat", "sat"]

# Each token has an embedding (we'll use random ones for illustration)
np.random.seed(42)
embed_dim = 4
embeddings = {tok: np.random.randn(embed_dim).round(2) for tok in tokens}

print("Each token has an embedding vector:")
for tok, emb in embeddings.items():
    print(f"  '{tok}': {emb}")
Each token has an embedding vector:
  'The': [ 0.5  -0.14  0.65  1.52]
  'cat': [-0.23 -0.23  1.58  0.77]
  'sat': [-0.47  0.54 -0.46 -0.47]

The Three Questions:

Vector      Question                          Role
Query (Q)   “What am I looking for?”          The token seeks relevant context
Key (K)     “What do I contain?”              The token advertises its content
Value (V)   “What do I return if matched?”    The token’s actual information

# Each token projects its embedding into Q, K, V
# These are learned linear transformations

# For "sat", the query might encode: "I need a subject (who sat?)"
# For "cat", the key might encode: "I'm a noun, a subject candidate"
# For "cat", the value carries: the actual semantic content of "cat"

print("When 'sat' attends to 'cat':")
print("  Q_sat . K_cat = high score (sat is looking for a subject, cat is one)")
print("  The output for 'sat' includes V_cat weighted by this score")
When 'sat' attends to 'cat':
  Q_sat . K_cat = high score (sat is looking for a subject, cat is one)
  The output for 'sat' includes V_cat weighted by this score

Q, K, and V are learned projections: the model learns what to seek (Q), how to advertise content (K), and what information to transmit (V).

Intuition: Query, Key, Value

Think of attention as a “soft lookup” - like a database query, but differentiable:

Query:   "What information do I need?"     (the question)
Keys:    "What information do I have?"     (index/labels for content)
Values:  "Here's my actual information"    (the content itself)

Attention = softmax(Query . Keys) x Values

Analogy: Imagine a library where:

  • Your query is “books about cats”
  • Each book has a key (its topic/keywords)
  • Each book has a value (its actual content)
  • You get a weighted average of book contents based on how well they match your query
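A minimal NumPy sketch of the library analogy, contrasting a hard, dictionary-style lookup with attention's soft lookup. The three "books", their 2-d keys and values, and the query are made-up toy numbers (assumptions for illustration only). Scaling the scores up shows the soft lookup approaching the hard one.

```python
import numpy as np

def softmax(x):
    # Stable softmax over a 1-D vector
    e = np.exp(x - x.max())
    return e / e.sum()

# Toy "library" (hypothetical numbers): 3 books with 2-d keys (topics) and 2-d values (contents)
keys = np.array([[1.0, 0.0],    # about cats
                 [0.0, 1.0],    # about dogs
                 [0.7, 0.7]])   # about pets in general
values = np.array([[10.0, 0.0],
                   [0.0, 10.0],
                   [5.0, 5.0]])
query = np.array([1.0, 0.2])    # "books about cats"

scores = keys @ query                      # dot-product match between the query and each key
hard = values[np.argmax(scores)]           # dictionary-style lookup: best match only
soft = softmax(scores) @ values            # attention-style lookup: weighted average of values
sharp = softmax(20 * scores) @ values      # larger scores -> weights closer to one-hot

print("scores:", scores)
print("hard lookup:", hard)
print("soft lookup:", soft.round(2))
print("soft lookup with scores x 20:", sharp.round(2))
```

The weighted average is differentiable with respect to the query and keys, while argmax is not; that is why attention uses the soft version.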

For the sentence “The cat sat on the mat”:

  • “sat” might attend strongly to “cat” (who sat?) and weakly to “mat” (where?)
  • “mat” might attend strongly to “the” and “on” (which mat? on what?)

Softmax turns each row of scores into attention weights that are non-negative and sum to 1.

The Math: Scaled Dot-Product Attention

The attention formula:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V

Where:

  • Q (Query): What am I looking for? Shape: (seq, d_k)
  • K (Key): What do I have to offer? Shape: (seq, d_k)
  • V (Value): What information do I carry? Shape: (seq, d_v)
  • d_k: Dimension of keys (for scaling)
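The shapes above are written for self-attention, where queries, keys, and values share one sequence length. In general Q has one row per query, while K and V share one row per key, so the weight matrix is (n_queries, n_keys) and the output is (n_queries, d_v). The NumPy sketch below checks this for a cross-attention-style case; the sizes and random inputs are hypothetical, and d_v is deliberately different from d_k.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))   # (n_q, n_k): one row per query
    return weights @ V, weights                 # output: (n_q, d_v)

rng = np.random.default_rng(0)
n_q, n_k, d_k, d_v = 3, 5, 4, 6       # hypothetical sizes
Q = rng.normal(size=(n_q, d_k))       # e.g. queries from a decoder sequence
K = rng.normal(size=(n_k, d_k))       # e.g. keys from an encoder sequence
V = rng.normal(size=(n_k, d_v))       # values share the key sequence length

out, w = attention(Q, K, V)
print("weights shape:", w.shape)
print("output shape:", out.shape)
```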


Building Attention by Hand

Before using PyTorch, let’s build attention with NumPy to see what happens at each step.

import numpy as np

def attention_from_scratch(x, W_q, W_k, W_v):
    """
    Single-head attention implemented with pure NumPy.

    Args:
        x: Input embeddings (seq_len, embed_dim)
        W_q, W_k, W_v: Projection matrices (embed_dim, head_dim)

    Returns:
        output: Attended values (seq_len, head_dim)
        weights: Attention weights (seq_len, seq_len)
    """
    # Step 1: Project input into Q, K, V
    Q = x @ W_q  # (seq, head_dim) - "What am I looking for?"
    K = x @ W_k  # (seq, head_dim) - "What do I contain?"
    V = x @ W_v  # (seq, head_dim) - "What do I return?"

    print(f"Input x shape: {x.shape}")
    print(f"Q = x @ W_q: {Q.shape}")
    print(f"K = x @ W_k: {K.shape}")
    print(f"V = x @ W_v: {V.shape}")

    # Step 2: Compute attention scores
    # Each query attends to all keys: Q @ K.T
    d_k = K.shape[-1]
    scores = Q @ K.T  # (seq, seq) - similarity between every pair
    scores = scores / np.sqrt(d_k)  # Scale to prevent softmax saturation

    print(f"\nScores = Q @ K.T / sqrt({d_k}): {scores.shape}")
    print(f"Score matrix (who attends to whom):")
    print(scores.round(2))

    # Step 3: Softmax to get attention weights
    # Each row sums to 1: how much each position attends to others
    def softmax(x):
        exp_x = np.exp(x - x.max(axis=-1, keepdims=True))  # Numerical stability
        return exp_x / exp_x.sum(axis=-1, keepdims=True)

    weights = softmax(scores)

    print(f"\nAttention weights (each row sums to 1):")
    print(weights.round(3))
    print(f"Row sums: {weights.sum(axis=-1).round(3)}")

    # Step 4: Weighted sum of values
    output = weights @ V  # (seq, head_dim)

    print(f"\nOutput = weights @ V: {output.shape}")

    return output, weights

# Demo with a tiny example
np.random.seed(42)
seq_len, embed_dim, head_dim = 3, 4, 2

x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5

print("=" * 50)
print("ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_from_scratch(x, W_q, W_k, W_v)
==================================================
ATTENTION FROM SCRATCH
==================================================
Input x shape: (3, 4)
Q = x @ W_q: (3, 2)
K = x @ W_k: (3, 2)
V = x @ W_v: (3, 2)

Scores = Q @ K.T / sqrt(2): (3, 3)
Score matrix (who attends to whom):
[[ 0.05  0.2   0.4 ]
 [ 0.48  0.72 -0.05]
 [ 0.18  0.22 -0.18]]

Attention weights (each row sums to 1):
[[0.278 0.324 0.397]
 [0.348 0.445 0.206]
 [0.365 0.381 0.255]]
Row sums: [1. 1. 1.]

Output = weights @ V: (3, 2)

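Key Insight: Attention reduces to a few matrix multiplications and one softmax:

  1. Q = x @ W_q, K = x @ W_k, V = x @ W_v - project the input to queries, keys, and values
  2. scores = Q @ K.T / sqrt(d_k) - compute scaled similarities between every query and key
  3. weights = softmax(scores) - normalize each row into weights that sum to 1
  4. output = weights @ V - take the weighted sum of values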

Why Scale by sqrt(d_k)?

Scaling prevents large d_k from producing dot products that saturate softmax and shrink gradients to near-zero. Here’s the intuition:

The Problem: When the elements of q and k are independent with mean 0 and variance 1, their dot product (a sum of d_k such products) has variance d_k, i.e. standard deviation sqrt(d_k). For d_k = 64 the standard deviation is 8, so dot products like 8 or -10 are routine.

Why This Matters: Softmax of large values produces near-one-hot distributions:

  • softmax([10, 0, 0]) = [0.9999, 0.00005, 0.00005]

This causes:

  1. Vanishing gradients: The gradient of softmax approaches 0 at extremes
  2. Loss of information: We want soft attention, not hard selection

The Solution: Dividing by sqrt(d_k) normalizes variance back to ~1.
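A quick empirical check of the variance argument, under the same assumption that q and k have independent standard-normal entries. The sample count and the d_k values are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000
results = {}
for d_k in (4, 64, 256):
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)        # one q . k dot product per sample
    scaled = dots / np.sqrt(d_k)      # the sqrt(d_k) scaling used in attention
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(q.k)={dots.var():7.1f}  var(q.k / sqrt(d_k))={scaled.var():.2f}")
```

The raw variance tracks d_k, while the scaled variance stays near 1 for every d_k.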

import torch
import torch.nn.functional as F
import math

# Show the effect of scaling
d_k = 64
q = torch.randn(1, d_k)
k = torch.randn(1, d_k)

dot_product = (q @ k.T).item()
scaled = dot_product / math.sqrt(d_k)

print(f"d_k = {d_k}")
print(f"Raw dot product: {dot_product:.2f}")
print(f"Scaled by sqrt({d_k}) = {math.sqrt(d_k):.1f}: {scaled:.2f}")
print(f"\nScaling keeps values in a reasonable range for softmax")

# Large scores push softmax toward one-hot, where its gradients vanish
scores_large = torch.tensor([[10.0, 1.0, 1.0]], requires_grad=True)
scores_normal = torch.tensor([[1.0, 0.5, 0.5]], requires_grad=True)

weights_large = F.softmax(scores_large, dim=-1)
weights_normal = F.softmax(scores_normal, dim=-1)

print(f"\nLarge scores [10, 1, 1] -> softmax: {weights_large.detach().numpy().round(4)}")
print(f"Normal scores [1, 0.5, 0.5] -> softmax: {weights_normal.detach().numpy().round(3)}")
d_k = 64
Raw dot product: 4.68
Scaled by sqrt(64) = 8.0: 0.58

Scaling keeps values in a reasonable range for softmax

Large scores [10, 1, 1] -> softmax: [[9.998e-01 1.000e-04 1.000e-04]]
Normal scores [1, 0.5, 0.5] -> softmax: [[0.452 0.274 0.274]]
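The cell above compares softmax outputs only. To see the gradient problem directly, a small sketch can compute the softmax Jacobian (d p_i / d s_j) with torch.autograd.functional.jacobian for the same two score vectors, and check it against the closed form diag(p) - p p^T. The helper softmax_jacobian is a hypothetical name introduced here.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

def softmax_jacobian(scores):
    # Jacobian of softmax for a 1-D score vector: J[i, j] = d p_i / d s_j
    return jacobian(lambda s: F.softmax(s, dim=-1), scores)

large = torch.tensor([10.0, 1.0, 1.0])
normal = torch.tensor([1.0, 0.5, 0.5])

J_large = softmax_jacobian(large)
J_normal = softmax_jacobian(normal)

# Closed form for comparison: diag(p) - outer(p, p)
p = F.softmax(normal, dim=-1)
J_closed = torch.diag(p) - torch.outer(p, p)

print("max |J| for large scores: ", J_large.abs().max().item())
print("max |J| for normal scores:", J_normal.abs().max().item())
print("matches closed form:", torch.allclose(J_normal, J_closed, atol=1e-6))
```

Every entry of the Jacobian is p_i(1 - p_i) on the diagonal or -p_i p_j off it; with near-one-hot weights all of these are close to 0, so almost no gradient reaches the scores.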

Numerical Stability in Softmax

The naive softmax implementation hides a danger: overflow.

import numpy as np

# The naive softmax
def naive_softmax(x):
    """This will overflow for large values!"""
    exp_x = np.exp(x)
    return exp_x / exp_x.sum()

# Try it with large values
large_scores = np.array([1000.0, 1001.0, 1002.0])

print("Large scores:", large_scores)
print("exp(1000) =", np.exp(1000))  # This is inf!

try:
    result = naive_softmax(large_scores)
    print("Naive softmax:", result)  # Will be [nan, nan, nan]
except FloatingPointError:  # NumPy warns rather than raises on overflow by default
    print("Overflow error!")
Large scores: [1000. 1001. 1002.]
exp(1000) = inf
Naive softmax: [nan nan nan]
RuntimeWarning: overflow encountered in exp
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in divide

The Problem: exp(1000) exceeds the largest float64 value (about 1.8 x 10^308; exp overflows for inputs above roughly 709.8), so it becomes inf, and inf / inf then gives NaN. Even exp(100) is about 2.7 x 10^43.

The Solution: The max-subtraction trick. Subtract the maximum value before exponentiating.

def stable_softmax(x):
    """
    Numerically stable softmax using the max-subtraction trick.

    Key insight: softmax(x) = softmax(x - c) for any constant c
    We choose c = max(x) to keep values small.
    """
    # Subtract max for numerical stability
    x_shifted = x - x.max()

    print(f"Original: {x}")
    print(f"After subtracting max ({x.max()}): {x_shifted}")
    print(f"Now exp() won't overflow: exp({x_shifted}) = {np.exp(x_shifted)}")

    exp_x = np.exp(x_shifted)
    return exp_x / exp_x.sum()

print("Stable softmax with max-subtraction trick:")
print("=" * 50)
large_scores = np.array([1000.0, 1001.0, 1002.0])
result = stable_softmax(large_scores)
print(f"\nResult: {result}")
print(f"Sum: {result.sum()}")  # Should be 1.0
Stable softmax with max-subtraction trick:
==================================================
Original: [1000. 1001. 1002.]
After subtracting max (1002.0): [-2. -1.  0.]
Now exp() won't overflow: exp([-2. -1.  0.]) = [0.13533528 0.36787944 1.        ]

Result: [0.09003057 0.24472847 0.66524096]
Sum: 0.9999999999999999

Why does this work mathematically?

softmax(x)_i = exp(x_i) / sum(exp(x_j))
            = exp(x_i - c) / sum(exp(x_j - c))    [multiply by exp(-c)/exp(-c)]

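For c = max(x), all exponents are <= 0, so every exp() term lies in [0, 1] and cannot overflow. The largest term is exp(0) = 1, so the denominator is at least 1 and can never underflow to 0.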
# Verify the math: both give same result
normal_scores = np.array([2.0, 1.0, 0.1])

naive_result = naive_softmax(normal_scores)
stable_result = stable_softmax(normal_scores)

print(f"\nNaive:  {naive_result}")
print(f"Stable: {stable_result}")
print(f"Same? {np.allclose(naive_result, stable_result)}")
Original: [2.  1.  0.1]
After subtracting max (2.0): [ 0.  -1.  -1.9]
Now exp() won't overflow: exp([ 0.  -1.  -1.9]) = [1.         0.36787944 0.14956862]

Naive:  [0.65900114 0.24243297 0.09856589]
Stable: [0.65900114 0.24243297 0.09856589]
Same? True
Warning: Always Use Stable Softmax

PyTorch’s F.softmax applies the max-subtraction trick automatically. Naive softmax fails quietly: NumPy only emits a RuntimeWarning and returns NaN when scores grow large, and that NaN then spreads to everything computed from it.
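A short PyTorch check using the same large scores as the NumPy example: a naive exp-and-normalize in float32 returns NaN, while F.softmax returns finite probabilities that match the stable NumPy result.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive version: exp(1000) overflows to inf in float32, and inf / inf is NaN
naive = torch.exp(scores) / torch.exp(scores).sum()
# F.softmax handles large inputs without overflowing
stable = F.softmax(scores, dim=-1)

print("naive: ", naive)
print("stable:", stable)
```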

Code: Scaled Dot-Product Attention

Let’s implement attention step by step. This follows the exact algorithm from the attention.py module:

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V

    Args:
        query: (batch, seq, d_k) or (..., seq, d_k)
        key: (batch, seq, d_k)
        value: (batch, seq, d_v)  # d_v can differ from d_k
        mask: Optional mask where 0 = masked, 1 = attend

    Returns:
        output: (batch, seq, d_v)
        attention_weights: (batch, seq, seq)

    Note: The mask convention (0 = masked) matches the module implementation.
    Masked positions get -inf before softmax, becoming 0 after.
    """
    d_k = query.size(-1)

    # Step 1: Compute similarity scores
    # QK^T: (..., seq, d_k) @ (..., d_k, seq) -> (..., seq, seq)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (if provided)
    # Masked positions get -inf, which becomes 0 after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax (each row sums to 1)
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Test it
batch, seq, d_k = 1, 4, 8
Q = torch.randn(batch, seq, d_k)
K = torch.randn(batch, seq, d_k)
V = torch.randn(batch, seq, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Query shape: {Q.shape}")
print(f"Key shape: {K.shape}")
print(f"Value shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(weights[0].round(decimals=2).numpy())
print(f"\nRow sums: {weights[0].sum(dim=-1).numpy()}")
Query shape: torch.Size([1, 4, 8])
Key shape: torch.Size([1, 4, 8])
Value shape: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])
Attention weights shape: torch.Size([1, 4, 4])

Attention weights (each row sums to 1):
[[0.42 0.18 0.15 0.25]
 [0.09 0.41 0.25 0.25]
 [0.06 0.68 0.24 0.03]
 [0.29 0.47 0.07 0.16]]

Row sums: [1.         1.         0.9999999  0.99999994]

Visualizing Attention Patterns

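# Extract attention weights for the OJS (Observable JS) visualization.
# ojs_define is supplied by Quarto when the page is rendered; it is not defined in a plain Python session.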
attention_weights_for_viz = weights[0].detach().numpy().tolist()
ojs_define(attention_weights_viz = attention_weights_for_viz)

Causal Masking for Language Models

In autoregressive models (like GPT), each token attends only to itself and earlier tokens, never to later ones. A causal mask enforces this constraint:

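Note: Mask Convention

This lesson uses two mask conventions that mean the same thing:

  • Additive mask (the NumPy from-scratch code): 0 = attend (score unchanged), -inf or a large negative number = masked (softmax converts it to 0); the mask is added to the scores
  • 0/1 mask (the PyTorch functions): 1 = attend, 0 = masked; masked_fill(mask == 0, float('-inf')) turns it into the additive form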

Some libraries use a boolean mask (True = attend, False = masked) which is converted internally. The key insight: positions with -inf before softmax become 0 attention weight.
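As one concrete case, PyTorch's F.scaled_dot_product_attention (PyTorch 2.0+) accepts both forms through its attn_mask argument: a boolean mask where True means the position takes part in attention, or a float mask that is added to the scores. The sketch below, with arbitrary random inputs, checks that both agree with is_causal=True.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 4, 8
q, k, v = (torch.randn(1, seq_len, d) for _ in range(3))

# Boolean convention: True = attend, False = masked
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Additive convention: 0 = attend, -inf = masked
additive_mask = torch.zeros(seq_len, seq_len).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("bool == additive:", torch.allclose(out_bool, out_add, atol=1e-5))
print("bool == is_causal:", torch.allclose(out_bool, out_causal, atol=1e-5))
```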

The Causal Mask from Scratch

The causal mask is elegantly simple: we add -inf to positions we want to mask, and softmax turns those into zeros.

import numpy as np

def causal_mask_from_scratch(seq_len):
    """
    Create an additive causal mask using np.triu (upper triangular).

    The mask has a large negative value (-1e9, standing in for -inf) above
    the diagonal (future positions) and 0 on and below the diagonal
    (past/current positions).
    """
    # np.triu with k=1 gives us the strictly upper triangular part
    # (everything above the main diagonal)
    mask = np.triu(np.ones((seq_len, seq_len)), k=1)

    # Convert 1s to a large negative number (positions to mask).
    # -1e9 rather than -inf: multiplying the 0 entries by -inf would give NaN (0 * -inf = nan)
    mask = mask * (-1e9)

    return mask

# Visualize the mask
seq_len = 5
mask = causal_mask_from_scratch(seq_len)

print("Causal Mask (0 = attend, -inf = masked):")
print(mask.round(0))

print("\nHow it works:")
print("  Position 0: sees only position 0")
print("  Position 1: sees positions 0, 1")
print("  Position 4: sees positions 0, 1, 2, 3, 4")
Causal Mask (0 = attend, -inf = masked):
[[-0.e+00 -1.e+09 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00 -0.e+00]]

How it works:
  Position 0: sees only position 0
  Position 1: sees positions 0, 1
  Position 4: sees positions 0, 1, 2, 3, 4
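The code above uses -1e9 rather than -inf. One practical reason: under IEEE 754 arithmetic, 0 * -inf is NaN, so multiplying the 0/1 triangle by -inf would turn every allowed position into NaN. A minimal NumPy sketch of the problem, and of building a true -inf mask with np.where instead:

```python
import numpy as np

seq_len = 4
upper = np.triu(np.ones((seq_len, seq_len)), k=1)   # 1 above the diagonal, 0 elsewhere

with np.errstate(invalid="ignore"):                 # silence the 0 * inf warning
    broken = upper * -np.inf                        # 0 * -inf -> nan on and below the diagonal

mask = np.where(upper == 1, -np.inf, 0.0)           # -inf above the diagonal, 0 elsewhere

scores = np.zeros((seq_len, seq_len)) + mask        # equal raw scores, then masked
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(broken)
print(weights.round(3))
```

With equal raw scores, each row spreads its weight uniformly over the positions it is allowed to see.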
def attention_with_causal_mask(x, W_q, W_k, W_v):
    """
    Causal attention from scratch - each position only attends to past.
    """
    Q = x @ W_q
    K = x @ W_k
    V = x @ W_v

    seq_len = x.shape[0]
    d_k = K.shape[-1]

    # Compute scores
    scores = Q @ K.T / np.sqrt(d_k)

    print("Scores before masking:")
    print(scores.round(2))

    # Add causal mask: -inf for future positions
    mask = np.triu(np.ones((seq_len, seq_len)), k=1) * (-1e9)
    scores = scores + mask

    print("\nScores after adding causal mask:")
    print(scores.round(2))

    # Softmax: -inf becomes 0
    def softmax(x):
        exp_x = np.exp(x - x.max(axis=-1, keepdims=True))
        return exp_x / exp_x.sum(axis=-1, keepdims=True)

    weights = softmax(scores)

    print("\nAttention weights (upper triangle is 0!):")
    print(weights.round(3))

    output = weights @ V
    return output, weights

# Demo
np.random.seed(42)
seq_len, embed_dim, head_dim = 4, 4, 2
x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5

print("=" * 50)
print("CAUSAL ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_with_causal_mask(x, W_q, W_k, W_v)
==================================================
CAUSAL ATTENTION FROM SCRATCH
==================================================
Scores before masking:
[[-1.08 -0.42  0.22  0.84]
 [-1.26 -0.68  0.22  1.97]
 [ 0.11  0.11 -0.01 -0.41]
 [ 2.12  0.79 -0.44 -1.52]]

Scores after adding causal mask:
[[-1.08000000e+00 -1.00000000e+09 -1.00000000e+09 -9.99999999e+08]
 [-1.26000000e+00 -6.80000000e-01 -1.00000000e+09 -9.99999998e+08]
 [ 1.10000000e-01  1.10000000e-01 -1.00000000e-02 -1.00000000e+09]
 [ 2.12000000e+00  7.90000000e-01 -4.40000000e-01 -1.52000000e+00]]

Attention weights (upper triangle is 0!):
[[1.    0.    0.    0.   ]
 [0.359 0.641 0.    0.   ]
 [0.348 0.345 0.307 0.   ]
 [0.731 0.193 0.057 0.019]]
Note: Key Insight

Causal masking is just adding -inf before softmax. That is the entire trick.

  • softmax([2.0, 1.0, -inf]) = [0.73, 0.27, 0.00]
  • The -inf position gets exactly 0 weight
  • No information flows from future to past
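The example in the note can be checked directly, along with one edge case: if every position in a row is -inf, softmax has nothing finite to normalize and returns NaN. A causal mask alone never triggers this, because each position can always see itself, but combining it with other masks (for example, left padding) can leave a row fully masked. A small PyTorch sketch:

```python
import torch
import torch.nn.functional as F

inf = float("inf")
partial = F.softmax(torch.tensor([2.0, 1.0, -inf]), dim=-1)
fully_masked = F.softmax(torch.tensor([-inf, -inf, -inf]), dim=-1)

print("partially masked row:", partial)       # the -inf position gets exactly 0
print("fully masked row:    ", fully_masked)  # no finite entry: result is NaN
```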
def create_causal_mask(seq_len):
    """Create a lower triangular causal mask."""
    return torch.tril(torch.ones(seq_len, seq_len))

# Show the mask
seq_len = 6
mask = create_causal_mask(seq_len)

print("Causal Mask (1 = can attend, 0 = masked):")
print()
tokens = ["The", "cat", "sat", "on", "the", "mat"]
for i in range(seq_len):
    row = ['#' if mask[i, j] == 1 else '.' for j in range(seq_len)]
    print(f"  {tokens[i]:4s}: {''.join(row)}")

print(f"\nPosition 0 can only see position 0")
print(f"Position 5 can see all previous positions")
Causal Mask (1 = can attend, 0 = masked):

  The : #.....
  cat : ##....
  sat : ###...
  on  : ####..
  the : #####.
  mat : ######

Position 0 can only see position 0
Position 5 can see all previous positions
# Apply causal mask
Q = torch.randn(1, 6, 8)
K = torch.randn(1, 6, 8)
V = torch.randn(1, 6, 8)

# Without mask (bidirectional)
output_bi, weights_bi = scaled_dot_product_attention(Q, K, V)

# With causal mask
output_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask=mask)

# Pass to OJS for visualization
ojs_define(
    weights_bi_viz = weights_bi[0].detach().numpy().tolist(),
    weights_causal_viz = weights_causal[0].detach().numpy().tolist(),
    mask_tokens = tokens
)

print("Notice: In causal attention, the upper triangle is 0 (can't attend to future)")
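A minimal check of the "no information flows from future to past" property: under a causal mask, changing the last tokens must leave the outputs at earlier positions untouched. For brevity, this sketch skips the learned projections and uses the input itself as Q, K, and V (a simplifying assumption); causal_attention is a hypothetical helper defined here.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    seq_len = q.size(-2)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~allowed, float("-inf"))   # block future positions
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
x = torch.randn(1, 6, 8)
x_changed = x.clone()
x_changed[:, 4:] = torch.randn(1, 2, 8)   # overwrite the last two tokens

out = causal_attention(x, x, x)
out_changed = causal_attention(x_changed, x_changed, x_changed)

print("positions 0-3 unchanged:", torch.allclose(out[:, :4], out_changed[:, :4]))
print("positions 4-5 changed:  ", not torch.allclose(out[:, 4:], out_changed[:, 4:]))
```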

Multi-Head Attention

Multi-head attention runs several attention heads in parallel instead of one, each with its own projections, so different heads can learn different patterns:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) x W_O

where head_i = Attention(Q x W_Q_i, K x W_K_i, V x W_V_i)
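The formula can be implemented literally (one head at a time, then concatenate) or in the batched form that the MultiHeadAttention class later in this module uses (project once, then reshape into heads). A sketch, assuming bias-free random projections and treating head i's W_Q_i, W_K_i, W_V_i as column slices of full-width matrices, checks that the two agree:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, embed_dim, num_heads = 5, 16, 4
head_dim = embed_dim // num_heads
x = torch.randn(seq_len, embed_dim)

# Hypothetical full-width projections (no bias); head i owns columns i*head_dim:(i+1)*head_dim
W_Q, W_K, W_V, W_O = (torch.randn(embed_dim, embed_dim) / embed_dim ** 0.5 for _ in range(4))

def attention(q, k, v):
    weights = F.softmax(q @ k.T / head_dim ** 0.5, dim=-1)
    return weights @ v

# 1) Literal formula: each head separately, concatenate, project with W_O
heads = []
for i in range(num_heads):
    cols = slice(i * head_dim, (i + 1) * head_dim)
    heads.append(attention(x @ W_Q[:, cols], x @ W_K[:, cols], x @ W_V[:, cols]))
out_loop = torch.cat(heads, dim=-1) @ W_O

# 2) Batched form: project once, split into heads with view/transpose, merge back
def split_heads(t):
    return t.view(seq_len, num_heads, head_dim).transpose(0, 1)   # (heads, seq, head_dim)

q, k, v = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)
weights = F.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
out_batched = (weights @ v).transpose(0, 1).reshape(seq_len, embed_dim) @ W_O

print("same result:", torch.allclose(out_loop, out_batched, atol=1e-5))
```

The batched form is what implementations typically use, since it replaces a Python loop over heads with a few large matrix multiplications.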

Why Multiple Heads?

A single attention head computes one weighted average, which limits what relationships it can capture. Multiple heads provide:

  • Diverse patterns: Different heads focus on different relationships: syntax, semantics, position, coreference
  • Subspace attention: Each head operates in a lower-dimensional subspace (head_dim = embed_dim / num_heads), allowing specialized representations
  • Computational efficiency: Despite having multiple heads, the total computation is similar to single-head attention with full dimensionality (same number of parameters)
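One way to check the "same number of parameters" point, sketched with PyTorch's built-in nn.MultiheadAttention (a different class from the MultiHeadAttention written later in this module, used here only because it is self-contained): the projection sizes depend on embed_dim, not on num_heads.

```python
import torch.nn as nn

embed_dim = 64
param_counts = {}
for num_heads in (1, 8):
    mha = nn.MultiheadAttention(embed_dim, num_heads)   # Q/K/V and output projections, with biases
    param_counts[num_heads] = sum(p.numel() for p in mha.parameters())
    print(f"num_heads={num_heads}: head_dim={embed_dim // num_heads}, "
          f"parameters={param_counts[num_heads]:,}")
```

Splitting into heads only reshapes the projected tensors; the weight matrices stay embed_dim x embed_dim either way.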

Typical configurations:

  • GPT-2 (small): 12 heads, 768 embed_dim, 64 head_dim
  • GPT-3: 96 heads, 12288 embed_dim, 128 head_dim
  • Llama 2 (7B): 32 heads, 4096 embed_dim, 128 head_dim

What Different Heads Learn

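Analyses of trained models find heads that specialize. The list below is an illustrative sketch of the kinds of roles observed, not a measured assignment (head indices are arbitrary):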

  • Head 0: “Who did what?” - attends to subject-verb pairs
  • Head 1: “What comes before?” - attends to previous token
  • Head 2: “What’s similar?” - attends to semantically similar words
  • Head 3: “Syntax patterns” - attends to grammatical structure
import torch.nn as nn

# Simplified implementation for learning - see attention.py for production version
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate Q, K, V projections (simplified for illustration)."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_len, _ = x.shape

        # Project Q, K, V
        q = self.q_proj(x)  # (batch, seq, embed)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attn_output = torch.matmul(attention_weights, v)

        # Reshape back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)

        # Final projection
        output = self.out_proj(attn_output)

        if return_attention:
            return output, attention_weights
        return output

# Test multi-head attention
embed_dim = 64
num_heads = 8
mha = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

x = torch.randn(2, 10, embed_dim)
output, weights = mha(x, return_attention=True)

print(f"Multi-Head Attention Configuration:")
print(f"  Embedding dimension: {embed_dim}")
print(f"  Number of heads: {num_heads}")
print(f"  Dimension per head: {embed_dim // num_heads}")
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"  (batch, heads, query_pos, key_pos)")
print(f"\nTotal parameters: {sum(p.numel() for p in mha.parameters()):,}")
Multi-Head Attention Configuration:
  Embedding dimension: 64
  Number of heads: 8
  Dimension per head: 8

Input shape: torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])
Attention weights shape: torch.Size([2, 8, 10, 10])
  (batch, heads, query_pos, key_pos)

Total parameters: 16,640

Visualizing Multi-Head Attention

# Pass multi-head attention weights to OJS
multi_head_weights = [weights[0, h].detach().numpy().tolist() for h in range(num_heads)]
ojs_define(mha_weights = multi_head_weights, mha_num_heads = num_heads)

Using Our Attention Module

The attention.py module provides production-ready implementations:

from attention import (
    CausalMultiHeadAttention,
    demonstrate_attention,
    demonstrate_causal_attention
)

# Run built-in demonstrations
print("=" * 60)
print("MULTI-HEAD ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_attention(seq_len=6, embed_dim=32, num_heads=4)
============================================================
MULTI-HEAD ATTENTION DEMONSTRATION
============================================================
============================================================
ATTENTION DEMONSTRATION
============================================================

Multi-Head Attention:
  Embed dim: 32
  Num heads: 4
  Head dim: 8

Input shape: (1, 6, 32)
  (batch=1, seq_len=6, embed_dim=32)

Output shape: (1, 6, 32)
Attention weights shape: (1, 4, 6, 6)
  (batch, heads, seq, seq)

Attention weights for head 0, position 0:
  [0.18868425488471985, 0.17037604749202728, 0.358096182346344, 0.05720369145274162, 0.054688461124897, 0.17095136642456055]
  Sum: 1.0000 (should be 1.0)
MultiHeadAttention(
  (q_proj): Linear(in_features=32, out_features=32, bias=True)
  (k_proj): Linear(in_features=32, out_features=32, bias=True)
  (v_proj): Linear(in_features=32, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=32, bias=True)
  (attention): ScaledDotProductAttention()
  (dropout): Dropout(p=0.0, inplace=False)
)
print("\n" + "=" * 60)
print("CAUSAL ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_causal_attention(seq_len=6)

============================================================
CAUSAL ATTENTION DEMONSTRATION
============================================================
============================================================
CAUSAL ATTENTION DEMONSTRATION
============================================================

Causal mask for seq_len=6:
(1 = can attend, 0 = masked)
  Position 0: █·····
  Position 1: ██····
  Position 2: ███···
  Position 3: ████··
  Position 4: █████·
  Position 5: ██████

Interpretation:
  Position 0: can only see position 0
  Position 1: can see positions 0, 1
  Position 5: can see all positions

Without mask (position 0 attends to all):
  [0.06325218826532364, 0.056782983243465424, 0.11940829455852509, 0.43410682678222656, 0.09652253985404968, 0.2299271523952484]

With causal mask (position 0 only attends to itself):
  [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
# Causal multi-head attention (what GPT uses)
causal_mha = CausalMultiHeadAttention(
    embed_dim=64,
    num_heads=8,
    max_seq_len=512,
    dropout=0.0
)

x = torch.randn(2, 10, 64)
output, weights = causal_mha(x, return_attention=True)

print(f"\nCausal Multi-Head Attention:")
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Attention weights shape: {weights.shape}")

Causal Multi-Head Attention:
  Input shape: torch.Size([2, 10, 64])
  Output shape: torch.Size([2, 10, 64])
  Attention weights shape: torch.Size([2, 8, 10, 10])

PyTorch’s Optimized Attention

Now that we understand attention from scratch, PyTorch’s production-optimized implementations offer a faster path.

F.scaled_dot_product_attention

PyTorch 2.0+ provides F.scaled_dot_product_attention - a single function that replaces our manual implementation and automatically uses the best available backend.

import torch
import torch.nn.functional as F

# Our inputs
batch, num_heads, seq_len, head_dim = 2, 8, 64, 32
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# The manual way (what we implemented)
def manual_attention(q, k, v, is_causal=False):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)
    if is_causal:
        # Build the mask from the input shapes (not a global seq_len), on the inputs' device
        L, S = q.size(-2), k.size(-2)
        mask = torch.triu(torch.ones(L, S, device=q.device), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ v

# PyTorch's optimized version
output_manual = manual_attention(query, key, value, is_causal=True)
output_pytorch = F.scaled_dot_product_attention(query, key, value, is_causal=True)

print(f"Manual output shape: {output_manual.shape}")
print(f"PyTorch SDPA shape: {output_pytorch.shape}")
print(f"Results match: {torch.allclose(output_manual, output_pytorch, atol=1e-5)}")
Manual output shape: torch.Size([2, 8, 64, 32])
PyTorch SDPA shape: torch.Size([2, 8, 64, 32])
Results match: True

Flash Attention: The Speed Revolution

F.scaled_dot_product_attention can dispatch to a Flash Attention kernel when the hardware and inputs support it (for example, a supported CUDA GPU with fp16 or bf16 tensors). Flash Attention is a breakthrough algorithm that:

  1. Avoids materializing the full attention matrix (extra memory drops from O(n^2) to O(n) in sequence length)
  2. Uses tiling to keep computation in fast on-chip GPU SRAM
  3. Fuses operations to reduce the memory-bandwidth bottleneck
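To see why avoiding the n x n matrix matters, here is back-of-the-envelope arithmetic for the memory a single materialized attention matrix would need, assuming fp16 storage (2 bytes per element), one batch element, and one head. The sequence lengths are illustrative, and attn_matrix_bytes is a hypothetical helper.

```python
def attn_matrix_bytes(seq_len, num_heads=1, batch=1, bytes_per_elem=2):
    # One (seq_len x seq_len) score/weight matrix per head and batch element
    return batch * num_heads * seq_len * seq_len * bytes_per_elem

sizes = {}
for n in (1_024, 8_192, 32_768):
    sizes[n] = attn_matrix_bytes(n) / 2**30   # bytes -> GiB
    print(f"seq_len={n:6,d}: {sizes[n]:.3f} GiB per head (fp16)")
```

Memory grows 4x each time the sequence length doubles; at 32,768 tokens, a 32-head layer would need 64 GiB for this one matrix, which is exactly what tiling avoids storing.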
# Check which SDPA backends are enabled (enabled means allowed, not guaranteed to be used for every input)
print("PyTorch Attention Backends:")
print(f"  Flash Attention: {torch.backends.cuda.flash_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f"  Memory-efficient: {torch.backends.cuda.mem_efficient_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f"  Math (fallback): Always available")

# The beautiful thing: same API, automatic optimization
# PyTorch picks the fastest available backend
PyTorch Attention Backends:
  Flash Attention: N/A (no CUDA)
  Memory-efficient: N/A (no CUDA)
  Math (fallback): Always available
# Benchmark: manual vs PyTorch SDPA
import time

def benchmark(fn, name, warmup=5, runs=20):
    # Warmup
    for _ in range(warmup):
        _ = fn()

    # Timed runs
    start = time.perf_counter()
    for _ in range(runs):
        _ = fn()
    elapsed = (time.perf_counter() - start) / runs * 1000

    print(f"{name}: {elapsed:.2f} ms per call")
    return elapsed

# Benchmark on CPU (GPU would show more dramatic difference)
batch, num_heads, seq_len, head_dim = 1, 8, 256, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

print(f"\nBenchmark (seq_len={seq_len}, {num_heads} heads, head_dim={head_dim}):")
t_manual = benchmark(lambda: manual_attention(q, k, v, is_causal=True), "Manual attention")
t_sdpa = benchmark(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True), "PyTorch SDPA")
print(f"\nSpeedup: {t_manual/t_sdpa:.1f}x")

Benchmark (seq_len=256, 8 heads, head_dim=64):
Manual attention: 5.37 ms per call
PyTorch SDPA: 1.99 ms per call

Speedup: 2.7x
Tip: From Scratch to Production

What we learned                   What PyTorch provides
Q @ K.T / sqrt(d)                 Fused kernel, no intermediate storage
Causal mask with -inf             Built-in is_causal=True flag
Stable softmax                    Numerically stable implementation
Full (seq x seq) score matrix     Flash Attention tiling

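Use F.scaled_dot_product_attention in production. The scratch implementation aids understanding and debugging, but the optimized version is faster: 2.7x in the CPU benchmark above, and often more on GPU, depending on hardware, dtype, and sequence length.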

Exercises

Exercise 1: Verify Attention Row Sums

# Verify that each row of attention weights sums to 1
Q = torch.randn(1, 5, 16)
K = torch.randn(1, 5, 16)
V = torch.randn(1, 5, 16)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Attention weights row sums (should all be 1.0):")
print(weights[0].sum(dim=-1))
Attention weights row sums (should all be 1.0):
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

Exercise 2: Effect of Temperature

# Temperature scaling affects attention sharpness
# Higher temperature = more uniform, Lower = more peaked

def attention_with_temperature(Q, K, V, temperature=1.0):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (math.sqrt(d_k) * temperature)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output, weights

# Generate fixed Q, K, V for temperature comparison
torch.manual_seed(42)
Q_temp = torch.randn(1, 4, 8)
K_temp = torch.randn(1, 4, 8)
V_temp = torch.randn(1, 4, 8)

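# Compute attention at different temperatures
temp_weights = {}
for temp in [0.5, 1.0, 2.0]:
    _, w = attention_with_temperature(Q_temp, K_temp, V_temp, temperature=temp)
    temp_weights[f"temp_{str(temp).replace('.', '_')}"] = w[0].detach().numpy().tolist()
    # Largest weight in each row: values near 1.0 mean sharply peaked attention
    print(f"temperature={temp}: max weight per row = {w[0].detach().max(dim=-1).values.numpy().round(2)}")

# Quarto-only hook: ojs_define passes the weights to the interactive chart
# (it is not defined when this code runs outside a Quarto document)
ojs_define(
    temp_weights_05 = temp_weights["temp_0_5"],
    temp_weights_10 = temp_weights["temp_1_0"],
    temp_weights_20 = temp_weights["temp_2_0"]
)

print("Lower temperature = sharper attention (more peaked)")
print("Higher temperature = softer attention (more uniform)")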

Exercise 3: Compare Single-Head vs Multi-Head

# Single head with full dimension vs multiple heads with smaller dimensions
embed_dim = 64
seq_len = 8

# Single head: one attention over 64 dimensions
single_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=1)

# Multi head: 8 attention heads over 8 dimensions each
multi_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=8)

x = torch.randn(1, seq_len, embed_dim)

out_single, w_single = single_head(x, return_attention=True)
out_multi, w_multi = multi_head(x, return_attention=True)

print(f"Single-head attention:")
print(f"  Attention weights shape: {w_single.shape}")
print(f"  One pattern to rule them all")

print(f"\nMulti-head attention:")
print(f"  Attention weights shape: {w_multi.shape}")
print(f"  8 different patterns, each can specialize")

# Compute average of first 4 heads for visualization
multi_combined = torch.zeros(seq_len, seq_len)
for h in range(4):
    multi_combined += w_multi[0, h].detach()
multi_combined /= 4

# Quarto-only hook: pass weights to the Observable JS visualization (undefined outside Quarto)
ojs_define(
    single_head_weights = w_single[0, 0].detach().numpy().tolist(),
    multi_head_avg_weights = multi_combined.numpy().tolist(),
    single_multi_seq_len = seq_len
)

Complexity and Optimizations

Attention has:

  • Time complexity: O(n^2 * d) where n = sequence length, d = embedding dimension
  • Memory complexity: O(n^2) for storing the attention matrix

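For a long sequence (n = 10,000), a single head's attention matrix already holds 10,000^2 = 100 million entries. That is about 400 MB in float32, and every head in every layer needs its own matrix.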
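
The quadratic growth is easy to tabulate. The following sketch assumes float32 weights (4 bytes each) and counts only the attention matrices of a single layer. `attention_matrix_bytes` is a hypothetical helper introduced for illustration:

```python
def attention_matrix_bytes(n, num_heads=1, bytes_per_entry=4):
    """Memory for the n x n attention weights of one layer (float32 by default)."""
    return n * n * num_heads * bytes_per_entry

for n in [1_000, 10_000, 100_000]:
    gib = attention_matrix_bytes(n) / 2**30
    print(f"n={n:>7,}: {n * n:>15,} entries, {gib:8.2f} GiB per head")
```

Multiplying the sequence length by 10 multiplies the memory by 100. This is why materializing the full matrix quickly becomes the bottleneck.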

KV Cache for Efficient Inference

During autoregressive generation, we compute attention one token at a time. Without caching, we’d recompute K and V for all previous tokens at each step.

KV Cache: Store computed K and V values for previous tokens:

  • At step t, only compute K_t and V_t for the new token
  • Concatenate with cached K_{1:t-1} and V_{1:t-1}
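  • Only the new token's query Q_t is needed; it attends over all cached keys and values

This removes the redundant recomputation. At step n, attention costs O(n * d), because one query is compared against n cached keys, instead of O(n^2 * d) for reprocessing the whole prefix. The price is memory: the cache grows linearly with sequence length.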
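
The following minimal NumPy sketch shows that decoding token by token with a cache reproduces full causal attention. It assumes a single head, and the random projection matrices W_q, W_k and W_v are hypothetical stand-ins for learned weights:

```python
import numpy as np

def softmax(x):
    # Subtract the row max for numerical stability
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_model = 6, 8
X = rng.standard_normal((seq_len, d_model))  # one embedding row per generation step
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

# Reference: full causal attention over the whole sequence at once
Q, K, V = X @ W_q, X @ W_k, X @ W_v
scores = Q @ K.T / np.sqrt(d_model)
scores[np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)] = -np.inf
full_out = softmax(scores) @ V

# Incremental decoding with a KV cache: one new token per step
k_cache = np.empty((0, d_model))
v_cache = np.empty((0, d_model))
cached_out = []
for t in range(seq_len):
    x_t = X[t:t + 1]                                   # only the new token, shape (1, d_model)
    k_cache = np.vstack([k_cache, x_t @ W_k])          # append K_t to the cache
    v_cache = np.vstack([v_cache, x_t @ W_v])          # append V_t to the cache
    q_t = x_t @ W_q                                    # only the new query is computed
    w_t = softmax(q_t @ k_cache.T / np.sqrt(d_model))  # (1, t+1) weights over the prefix
    cached_out.append(w_t @ v_cache)
cached_out = np.vstack(cached_out)

print("cache matches full causal attention:", np.allclose(full_out, cached_out))
```

The cached loop needs no explicit mask. The cache only ever contains past and current tokens, so causality holds by construction.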

Modern Optimizations

Flash Attention (Dao et al., 2022):

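  • Avoids materializing the full n × n attention matrix
  • Uses tiling and recomputation so memory grows linearly with sequence length (the compute is still O(n^2))
  • Typically 2-4x faster than standard attention on modern GPUs; exact gains depend on sequence length and hardware
  • PyTorch 2.0+ uses a Flash Attention kernel inside F.scaled_dot_product_attention when the hardware, dtype, and inputs support it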
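
The core of the tiling idea is an online softmax. Keys and values are processed block by block, each query keeps a running maximum and denominator, and earlier partial sums are rescaled whenever the maximum changes. The NumPy sketch below illustrates only that idea. The real Flash Attention kernel also tiles the queries, fuses everything into one GPU kernel, and recomputes during the backward pass. `tiled_attention` and `block` are illustrative names, not a library API:

```python
import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])             # full n x n score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Process K/V in blocks, keeping only running statistics per query row."""
    n, d = Q.shape
    out = np.zeros((n, V.shape[-1]))
    row_max = np.full((n, 1), -np.inf)   # running max of the scores seen so far
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)                      # only an n x block tile
        new_max = np.maximum(row_max, S.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)         # rescales earlier partial results
        P = np.exp(S - new_max)
        row_sum = row_sum * correction + P.sum(axis=-1, keepdims=True)
        out = out * correction + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
print("tiled matches standard:", np.allclose(standard_attention(Q, K, V), tiled_attention(Q, K, V)))
```

The tiled version never stores more than an n × block tile of scores, yet it returns the same result as the standard version up to floating-point rounding.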

Sparse Attention patterns:

  • Local attention: Each token only attends to nearby tokens
  • Strided attention: Attend to every k-th token
  • Block-sparse: Combine local and strided patterns
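
These patterns can be expressed as boolean masks over the n × n score matrix (True = allowed) and applied the same way as the causal mask. The NumPy sketch below assumes a window of 1 on each side, a stride of 4, and the "strided" rule that query i sees key j when (i - j) is a multiple of the stride. The helper names are hypothetical. A real sparse kernel gains its speed by skipping masked blocks, not by masking a dense matrix as done here:

```python
import numpy as np

def local_mask(n, window):
    """Query i may attend to key j when |i - j| <= window."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= window

def strided_mask(n, stride):
    """Query i may attend to key j when (i - j) is a multiple of the stride."""
    idx = np.arange(n)
    return (idx[:, None] - idx[None, :]) % stride == 0

def combined_mask(n, window, stride, causal=True):
    mask = local_mask(n, window) | strided_mask(n, stride)
    if causal:
        mask &= np.tril(np.ones((n, n), dtype=bool))   # no attending to future tokens
    return mask

m = combined_mask(n=8, window=1, stride=4)
print(m.astype(int))
print(f"allowed pairs: {m.sum()} of {m.size}")
```

As n grows, the number of allowed pairs grows roughly as n * (window + n / stride) instead of n^2.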

Linear Attention approximations:

  • Replace softmax(QK^T)V with kernel feature maps
  • Achieves O(n) complexity but may sacrifice quality
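
The key step is associativity. Suppose the softmax kernel exp(q·k) is replaced by a dot product of feature maps φ(q)·φ(k). Then φ(Q)(φ(K)^T V) can be computed without ever forming the n × n matrix. The NumPy sketch below assumes the elu(x) + 1 feature map of Katharopoulos et al. (2020) and the non-causal case. Because the feature map changes the attention function itself, this is where the potential quality loss comes from:

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1: x + 1 for x > 0, exp(x) otherwise; always positive
    return np.where(x > 0, x + 1.0, np.exp(x))

def quadratic_form(Q, K, V):
    """Reference: build the full n x n kernel matrix, then normalize each row."""
    A = feature_map(Q) @ feature_map(K).T               # n x n
    return (A @ V) / A.sum(axis=-1, keepdims=True)

def linear_form(Q, K, V):
    """Same result via associativity: phi(Q) (phi(K)^T V), never forming n x n."""
    phi_q, phi_k = feature_map(Q), feature_map(K)
    kv = phi_k.T @ V                # d x d_v summary, cost O(n * d * d_v)
    z = phi_k.sum(axis=0)           # length-d normalizer
    return (phi_q @ kv) / (phi_q @ z)[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((50, 8)) for _ in range(3))
print("linear form matches quadratic form:", np.allclose(quadratic_form(Q, K, V), linear_form(Q, K, V)))
```

Both functions compute the same kernelized attention. Only the order of the matrix products differs, and that order is what turns O(n^2) into O(n).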

Interactive Exploration

Experiment with attention in real-time. Adjust the temperature to see how it affects the attention distribution:

  • Low temperature → Sharp, focused attention (nearly one-hot)
  • High temperature → Soft, diffuse attention (more uniform)
Tip: Try This
  1. Set temperature to 0.1 — notice how attention becomes nearly one-hot (picks one token)
  2. Set temperature to 3.0 — notice how attention becomes almost uniform
  3. Compare how “sat” attends (looks at “cat”) vs how “the” attends (looks at other “the”)
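
Without the interactive widget, steps 1 and 2 can be reproduced in a few lines. The sketch below assumes made-up raw scores for a single query over four tokens. Dividing by the temperature before the softmax is the same operation used in Exercise 2:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical raw attention scores for one query over four tokens
scores = np.array([2.0, 1.0, 0.5, -1.0])

weights_by_T = {}
for temperature in [0.1, 1.0, 3.0]:
    w = softmax(scores / temperature)
    weights_by_T[temperature] = w
    print(f"T={temperature}: weights={w.round(3)}, max={w.max():.3f}")
```

At T = 0.1, almost all of the weight lands on the highest-scoring token. At T = 3.0, the weights move toward the uniform value of 0.25.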

Common Pitfalls

When implementing attention, watch out for these issues:

  1. Forgetting to scale: Without /sqrt(d_k), training becomes unstable with large head dimensions
  2. Wrong mask dimensions: Mask should broadcast correctly over batch and head dimensions
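  3. NaN from all-masked rows: If an entire row is masked, every score is -inf. A plain softmax then divides 0 by 0, and a stable softmax hits -inf - (-inf) when subtracting the row max; both produce NaN. Handle with nan_to_num or ensure at least one position is unmasked
  4. Memory blowup from stored attention weights: Keeping n × n weights for every head and layer for visualization (especially while they are still attached to the autograd graph) can exhaust memory. Only compute and keep them when needed, and detach them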
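
The NumPy sketch below reproduces the NaN from pitfall 3 and shows one common guard. It assumes that a fully masked row should produce all-zero weights. `masked_softmax` is a hypothetical helper, not a library function:

```python
import numpy as np

def masked_softmax(scores, mask):
    """Softmax over allowed positions (mask=True); fully masked rows return all zeros."""
    scores = np.where(mask, scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)  # avoid -inf - (-inf) = nan
    e = np.exp(scores - row_max)                            # masked entries become exp(-inf) = 0
    denom = e.sum(axis=-1, keepdims=True)
    return np.divide(e, denom, out=np.zeros_like(e), where=denom > 0)

scores = np.array([[1.0, 2.0, 3.0],
                   [1.0, 2.0, 3.0]])
mask = np.array([[True, True, False],
                 [False, False, False]])   # second row: every position masked

# Naive stable softmax: the fully masked row turns into NaN
with np.errstate(invalid="ignore"):
    naive = np.where(mask, scores, -np.inf)
    naive = np.exp(naive - naive.max(axis=-1, keepdims=True))
    naive /= naive.sum(axis=-1, keepdims=True)
print("naive:\n", naive)

safe = masked_softmax(scores, mask)
print("guarded:\n", safe)
```

Returning zeros means a fully masked query contributes a zero output vector. Whether that is the right convention depends on the model, which is why some codebases prefer to guarantee at least one unmasked position instead.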

Summary

Key takeaways:

  1. Attention computes weighted sums: Each position’s output is a weighted combination of all (allowed) positions’ values
  2. Q, K, V: Query asks “what do I need?”, Key says “what do I have?”, Value carries the information
  3. Scaling prevents gradient issues: Dividing by sqrt(d_k) keeps softmax from saturating
  4. Causal masking enables generation: In LLMs, we mask future tokens so the model learns to predict the next token
  5. Multiple heads learn different patterns: Each head can specialize in different linguistic relationships
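  6. Complexity is O(n^2): Attention's quadratic cost limits sequence length. This motivates optimizations such as Flash Attention (linear memory), sparse and linear attention (less compute), and KV caching (no redundant work during generation)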

What’s Next

Module 06: Transformer combines attention with feed-forward networks, layer normalization, and residual connections to build a complete transformer decoder block.