Module 05: Attention

Introduction

Attention is the mechanism that made transformers revolutionary: it allows each token to “look at” every other token and gather relevant information.

Instead of processing tokens in isolation, attention lets each token ask: “What other tokens in this sequence are relevant to me?”

Why attention matters for LLMs:

  • Long-range dependencies: Token at position 100 can directly attend to token at position 1 (RNNs struggle with this due to vanishing gradients)
  • Parallelization: Unlike RNNs, all positions can be computed simultaneously during training
  • Interpretability: Attention weights show what the model is “looking at”
  • Dynamic context: Each token’s representation is context-dependent, not fixed

The key innovation: self-attention - tokens attend to other tokens in the same sequence. This is distinct from cross-attention (used in encoder-decoder models) where queries come from one sequence and keys/values from another.
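
To make the distinction concrete, here is a minimal, shape-only sketch (an illustration, not code from this module): in self-attention the score matrix is square because queries and keys come from the same sequence, while in cross-attention it is rectangular.

import numpy as np

np.random.seed(0)
d = 8
decoder_tokens = np.random.randn(5, d)   # sequence producing the queries
encoder_tokens = np.random.randn(7, d)   # a different sequence providing keys/values

self_scores = decoder_tokens @ decoder_tokens.T    # (5, 5): every token vs. every token in the same sequence
cross_scores = decoder_tokens @ encoder_tokens.T   # (5, 7): each decoder token vs. every encoder token

print(self_scores.shape, cross_scores.shape)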

What You’ll Learn

By the end of this module, you will be able to:

  • Understand Query, Key, Value projections and their roles
  • Implement scaled dot-product attention from scratch
  • Apply causal masking for autoregressive models
  • Build multi-head attention and understand why it’s beneficial
  • Recognize attention patterns and what they reveal

Note: Attention itself is position-agnostic — it treats tokens as an unordered set. We rely on positional embeddings (Module 04) to give the model a sense of token order.
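
A quick way to see this claim in action (a sketch using PyTorch's F.scaled_dot_product_attention, which this module covers later): shuffling the input tokens simply shuffles the outputs the same way, so without positional embeddings attention cannot distinguish the two orderings.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 5, 8)                   # (batch, heads, seq, head_dim)
perm = torch.tensor([2, 0, 4, 1, 3])          # an arbitrary reordering of the 5 tokens

out = F.scaled_dot_product_attention(x, x, x)
out_shuffled = F.scaled_dot_product_attention(x[:, :, perm], x[:, :, perm], x[:, :, perm])

# Shuffled input -> identically shuffled output: attention sees a set, not a sequence
print(torch.allclose(out[:, :, perm], out_shuffled, atol=1e-5))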

Attention as Three Questions

Every token in a sequence asks three questions. Understanding these questions is the key to understanding attention.

import numpy as np

# A simple sentence
tokens = ["The", "cat", "sat"]

# Each token has an embedding (we'll use random ones for illustration)
np.random.seed(42)
embed_dim = 4
embeddings = {tok: np.random.randn(embed_dim).round(2) for tok in tokens}

print("Each token has an embedding vector:")
for tok, emb in embeddings.items():
    print(f"  '{tok}': {emb}")

The Three Questions:

  Vector       What it asks                        Role
  Query (Q)    “What am I looking for?”            Token seeks relevant context
  Key (K)      “What do I contain?”                Token advertises its content
  Value (V)    “What do I return if matched?”      Token’s actual information
# Each token projects its embedding into Q, K, V
# These are learned linear transformations

# For "sat", the query might encode: "I need a subject (who sat?)"
# For "cat", the key might encode: "I'm a noun, a subject candidate"
# For "cat", the value carries: the actual semantic content of "cat"

print("When 'sat' attends to 'cat':")
print("  Q_sat . K_cat = high score (sat is looking for a subject, cat is one)")
print("  The output for 'sat' includes V_cat weighted by this score")

The magic: Q, K, V are learned projections. The model learns what to look for (Q), how to advertise content (K), and what information to pass along (V).

Intuition: Query, Key, Value

Think of attention as a “soft lookup” - like a database query, but differentiable:

Query:   "What information do I need?"     (the question)
Keys:    "What information do I have?"     (index/labels for content)
Values:  "Here's my actual information"    (the content itself)

Attention = softmax(Query . Keys) x Values

Analogy: Imagine a library where:

  • Your query is “books about cats”
  • Each book has a key (its topic/keywords)
  • Each book has a value (its actual content)
  • You get a weighted average of book contents based on how well they match your query

For the sentence “The cat sat on the mat”:

  • “sat” might attend strongly to “cat” (who sat?) and weakly to “mat” (where?)
  • “mat” might attend strongly to “the” and “on” (which mat? on what?)

Each row of attention weights sums to 1 (softmax normalization).
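
Here is that soft lookup in miniature, with made-up numbers (an illustrative sketch, not code from this module): one query is scored against three keys, softmax turns the scores into weights that sum to 1, and the result is a weighted average of the values.

import numpy as np

query = np.array([1.0, 0.0])                    # "books about cats"
keys = np.array([[0.9, 0.1],                    # book about cats (good match)
                 [0.1, 0.9],                    # book about cars (poor match)
                 [0.5, 0.5]])                   # book about pets (partial match)
values = np.array([[10.0], [20.0], [30.0]])     # each book's "content"

scores = keys @ query                            # how well each key matches the query
weights = np.exp(scores) / np.exp(scores).sum()  # softmax: weights sum to 1
print("Weights:", weights.round(3))              # largest weight on the cat book
print("Output:", (weights @ values).round(2))    # weighted average of the contents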

The Math: Scaled Dot-Product Attention

The attention formula:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V

Where:

  • Q (Query): What am I looking for? Shape: (seq, d_k)
  • K (Key): What do I have to offer? Shape: (seq, d_k)
  • V (Value): What information do I carry? Shape: (seq, d_v)
  • d_k: Dimension of keys (for scaling)

Step by Step: Building Attention by Hand

Before using PyTorch, let’s build attention from scratch with NumPy to see exactly what happens at each step.

import numpy as np

def attention_from_scratch(x, W_q, W_k, W_v):
    """
    Single-head attention implemented with pure NumPy.

    Args:
        x: Input embeddings (seq_len, embed_dim)
        W_q, W_k, W_v: Projection matrices (embed_dim, head_dim)

    Returns:
        output: Attended values (seq_len, head_dim)
        weights: Attention weights (seq_len, seq_len)
    """
    # Step 1: Project input into Q, K, V
    Q = x @ W_q  # (seq, head_dim) - "What am I looking for?"
    K = x @ W_k  # (seq, head_dim) - "What do I contain?"
    V = x @ W_v  # (seq, head_dim) - "What do I return?"

    print(f"Input x shape: {x.shape}")
    print(f"Q = x @ W_q: {Q.shape}")
    print(f"K = x @ W_k: {K.shape}")
    print(f"V = x @ W_v: {V.shape}")

    # Step 2: Compute attention scores
    # Each query attends to all keys: Q @ K.T
    d_k = K.shape[-1]
    scores = Q @ K.T  # (seq, seq) - similarity between every pair
    scores = scores / np.sqrt(d_k)  # Scale to prevent softmax saturation

    print(f"\nScores = Q @ K.T / sqrt({d_k}): {scores.shape}")
    print(f"Score matrix (who attends to whom):")
    print(scores.round(2))

    # Step 3: Softmax to get attention weights
    # Each row sums to 1: how much each position attends to others
    def softmax(x):
        exp_x = np.exp(x - x.max(axis=-1, keepdims=True))  # Numerical stability
        return exp_x / exp_x.sum(axis=-1, keepdims=True)

    weights = softmax(scores)

    print(f"\nAttention weights (each row sums to 1):")
    print(weights.round(3))
    print(f"Row sums: {weights.sum(axis=-1).round(3)}")

    # Step 4: Weighted sum of values
    output = weights @ V  # (seq, head_dim)

    print(f"\nOutput = weights @ V: {output.shape}")

    return output, weights

# Demo with a tiny example
np.random.seed(42)
seq_len, embed_dim, head_dim = 3, 4, 2

x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5

print("=" * 50)
print("ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_from_scratch(x, W_q, W_k, W_v)

Key Insight: The entire attention mechanism is just a handful of matrix multiplications:

  1. Q = x @ W_q - project to queries
  2. K = x @ W_k - project to keys
  3. V = x @ W_v - project to values
  4. scores = Q @ K.T / sqrt(d) - compute similarities
  5. output = softmax(scores) @ V - weighted sum of values

Why Scale by sqrt(d_k)?

Without scaling, large d_k leads to large dot products, which pushes softmax into regions with tiny gradients (saturation). Here’s the mathematical intuition:

The Problem: When Q and K have elements drawn from a distribution with mean 0 and variance 1, their dot product has variance proportional to d_k. For d_k = 64, dot products can easily reach values like 8 or -10.

Why This Matters: Softmax of large values produces near-one-hot distributions:

  • softmax([10, 0, 0]) = [0.9999, 0.00005, 0.00005]

This causes:

  1. Vanishing gradients: The gradient of softmax approaches 0 at extremes
  2. Loss of information: We want soft attention, not hard selection

The Solution: Dividing by sqrt(d_k) normalizes variance back to ~1.

import torch
import torch.nn.functional as F
import math

# Show the effect of scaling
d_k = 64
q = torch.randn(1, d_k)
k = torch.randn(1, d_k)

dot_product = (q @ k.T).item()
scaled = dot_product / math.sqrt(d_k)

print(f"d_k = {d_k}")
print(f"Raw dot product: {dot_product:.2f}")
print(f"Scaled by sqrt({d_k}) = {math.sqrt(d_k):.1f}: {scaled:.2f}")
print(f"\nScaling keeps values in a reasonable range for softmax")

# Demonstrate the gradient problem
scores_large = torch.tensor([[10.0, 1.0, 1.0]], requires_grad=True)
scores_normal = torch.tensor([[1.0, 0.5, 0.5]], requires_grad=True)

weights_large = F.softmax(scores_large, dim=-1)
weights_normal = F.softmax(scores_normal, dim=-1)

print(f"\nLarge scores [10, 1, 1] -> softmax: {weights_large.detach().numpy().round(4)}")
print(f"Normal scores [1, 0.5, 0.5] -> softmax: {weights_normal.detach().numpy().round(3)}")

Numerical Stability in Softmax

There is a hidden danger in the naive softmax implementation: overflow.

import numpy as np

# The naive softmax
def naive_softmax(x):
    """This will overflow for large values!"""
    exp_x = np.exp(x)
    return exp_x / exp_x.sum()

# Try it with large values
large_scores = np.array([1000.0, 1001.0, 1002.0])

print("Large scores:", large_scores)
print("exp(1000) =", np.exp(1000))  # This is inf!

try:
    result = naive_softmax(large_scores)
    print("Naive softmax:", result)  # Will be [nan, nan, nan]
except:
    print("Overflow error!")

The Problem: exp(1000) is astronomically large - it overflows to infinity. Even exp(100) is about 2.7 x 10^43.

The Solution: The max-subtraction trick. Subtract the maximum value before exponentiating.

def stable_softmax(x):
    """
    Numerically stable softmax using the max-subtraction trick.

    Key insight: softmax(x) = softmax(x - c) for any constant c
    We choose c = max(x) to keep values small.
    """
    # Subtract max for numerical stability
    x_shifted = x - x.max()

    print(f"Original: {x}")
    print(f"After subtracting max ({x.max()}): {x_shifted}")
    print(f"Now exp() won't overflow: exp({x_shifted}) = {np.exp(x_shifted)}")

    exp_x = np.exp(x_shifted)
    return exp_x / exp_x.sum()

print("Stable softmax with max-subtraction trick:")
print("=" * 50)
large_scores = np.array([1000.0, 1001.0, 1002.0])
result = stable_softmax(large_scores)
print(f"\nResult: {result}")
print(f"Sum: {result.sum()}")  # Should be 1.0

Why does this work mathematically?

softmax(x)_i = exp(x_i) / sum(exp(x_j))
            = exp(x_i - c) / sum(exp(x_j - c))    [multiply by exp(-c)/exp(-c)]

For c = max(x), all exponents are <= 0, so exp() stays bounded.
# Verify the math: both give same result
normal_scores = np.array([2.0, 1.0, 0.1])

naive_result = naive_softmax(normal_scores)
stable_result = stable_softmax(normal_scores)

print(f"\nNaive:  {naive_result}")
print(f"Stable: {stable_result}")
print(f"Same? {np.allclose(naive_result, stable_result)}")
Warning: Always Use Stable Softmax

PyTorch’s F.softmax automatically uses the max-subtraction trick. Never implement naive softmax in production code - it will fail silently with NaN values when scores get large.

Code: Scaled Dot-Product Attention

Let’s implement attention step by step. This follows the exact algorithm from the attention.py module:

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V

    Args:
        query: (batch, seq, d_k) or (..., seq, d_k)
        key: (batch, seq, d_k)
        value: (batch, seq, d_v)  # d_v can differ from d_k
        mask: Optional mask where 0 = masked, 1 = attend

    Returns:
        output: (batch, seq, d_v)
        attention_weights: (batch, seq, seq)

    Note: The mask convention (0 = masked) matches the module implementation.
    Masked positions get -inf before softmax, becoming 0 after.
    """
    d_k = query.size(-1)

    # Step 1: Compute similarity scores
    # QK^T: (..., seq, d_k) @ (..., d_k, seq) -> (..., seq, seq)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (if provided)
    # Masked positions get -inf, which becomes 0 after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax (each row sums to 1)
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Test it
batch, seq, d_k = 1, 4, 8
Q = torch.randn(batch, seq, d_k)
K = torch.randn(batch, seq, d_k)
V = torch.randn(batch, seq, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Query shape: {Q.shape}")
print(f"Key shape: {K.shape}")
print(f"Value shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(weights[0].round(decimals=2).numpy())
print(f"\nRow sums: {weights[0].sum(dim=-1).numpy()}")

Visualizing Attention Patterns

import matplotlib.pyplot as plt

def visualize_attention(weights, tokens=None, title="Attention Pattern"):
    """Visualize attention weights as a heatmap."""
    if weights.dim() == 3:
        weights = weights[0]  # Remove batch dim

    weights = weights.detach().numpy()
    seq_len = weights.shape[0]

    plt.figure(figsize=(8, 6))
    plt.imshow(weights, cmap='Blues', vmin=0, vmax=weights.max())
    plt.colorbar(label='Attention Weight')

    if tokens:
        plt.xticks(range(seq_len), tokens, rotation=45, ha='right')
        plt.yticks(range(seq_len), tokens)
    else:
        plt.xlabel('Key Position (what we look at)')
        plt.ylabel('Query Position (who is looking)')

    # Add values in cells
    for i in range(seq_len):
        for j in range(seq_len):
            plt.text(j, i, f'{weights[i, j]:.2f}', ha='center', va='center', fontsize=10)

    plt.title(title)
    plt.tight_layout()
    plt.show()

# Visualize the attention pattern from above
visualize_attention(weights, title="Random Attention Pattern")

Causal Masking for Language Models

In autoregressive models (like GPT), tokens can only attend to previous tokens, not future ones. We enforce this with a causal mask:

Note: Mask Convention

This lesson uses two equivalent conventions:

  • The NumPy from-scratch code below uses an additive mask: 0 = attend (score unchanged), -inf = masked (softmax converts it to 0)
  • The PyTorch scaled_dot_product_attention above uses a binary mask (1 = attend, 0 = masked), converted internally with masked_fill

Some libraries instead use a boolean mask (True = attend, False = masked). The key insight is the same everywhere: positions with -inf before softmax become 0 attention weight.
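
The two conventions produce identical attention weights; a small check (a sketch, not part of attention.py) makes that concrete.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1, 4, 4)

# Binary convention: 1 = attend, 0 = masked, converted with masked_fill
binary_mask = torch.tril(torch.ones(4, 4))
weights_binary = F.softmax(scores.masked_fill(binary_mask == 0, float('-inf')), dim=-1)

# Additive convention: 0 = attend, -inf = masked, simply added to the scores
additive_mask = torch.triu(torch.full((4, 4), float('-inf')), diagonal=1)
weights_additive = F.softmax(scores + additive_mask, dim=-1)

print(torch.allclose(weights_binary, weights_additive))  # True: same attention weights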

The Causal Mask from Scratch

The causal mask is elegantly simple: we add -inf to positions we want to mask, and softmax turns those into zeros.

import numpy as np

def causal_mask_from_scratch(seq_len):
    """
    Create a causal mask using np.triu (upper triangular).

    The mask has -inf above the diagonal (future positions)
    and 0 on and below the diagonal (past/current positions).
    """
    # np.triu with k=1 gives us the strictly upper triangular part
    # (everything above the main diagonal)
    mask = np.triu(np.ones((seq_len, seq_len)), k=1)

    # Convert 1s to -inf (positions to mask)
    mask = mask * (-1e9)  # Use large negative instead of -inf for visualization

    return mask

# Visualize the mask
seq_len = 5
mask = causal_mask_from_scratch(seq_len)

print("Causal Mask (0 = attend, -inf = masked):")
print(mask.round(0))

print("\nHow it works:")
print("  Position 0: sees only position 0")
print("  Position 1: sees positions 0, 1")
print("  Position 4: sees positions 0, 1, 2, 3, 4")
def attention_with_causal_mask(x, W_q, W_k, W_v):
    """
    Causal attention from scratch - each position only attends to past.
    """
    Q = x @ W_q
    K = x @ W_k
    V = x @ W_v

    seq_len = x.shape[0]
    d_k = K.shape[-1]

    # Compute scores
    scores = Q @ K.T / np.sqrt(d_k)

    print("Scores before masking:")
    print(scores.round(2))

    # Add causal mask: -inf for future positions
    mask = np.triu(np.ones((seq_len, seq_len)), k=1) * (-1e9)
    scores = scores + mask

    print("\nScores after adding causal mask:")
    print(scores.round(2))

    # Softmax: -inf becomes 0
    def softmax(x):
        exp_x = np.exp(x - x.max(axis=-1, keepdims=True))
        return exp_x / exp_x.sum(axis=-1, keepdims=True)

    weights = softmax(scores)

    print("\nAttention weights (upper triangle is 0!):")
    print(weights.round(3))

    output = weights @ V
    return output, weights

# Demo
np.random.seed(42)
seq_len, embed_dim, head_dim = 4, 4, 2
x = np.random.randn(seq_len, embed_dim)
W_q = np.random.randn(embed_dim, head_dim) * 0.5
W_k = np.random.randn(embed_dim, head_dim) * 0.5
W_v = np.random.randn(embed_dim, head_dim) * 0.5

print("=" * 50)
print("CAUSAL ATTENTION FROM SCRATCH")
print("=" * 50)
output, weights = attention_with_causal_mask(x, W_q, W_k, W_v)
Note: Key Insight

Causal masking is just adding -inf before softmax. That is the entire trick.

  • softmax([2.0, 1.0, -inf]) = [0.73, 0.27, 0.00]
  • The -inf position gets exactly 0 weight
  • No information flows from future to past
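
You can check those numbers directly (a one-liner using the F.softmax imported earlier):

print(F.softmax(torch.tensor([2.0, 1.0, float('-inf')]), dim=-1))
# tensor([0.7311, 0.2689, 0.0000])
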
def create_causal_mask(seq_len):
    """Create a lower triangular causal mask."""
    return torch.tril(torch.ones(seq_len, seq_len))

# Show the mask
seq_len = 6
mask = create_causal_mask(seq_len)

print("Causal Mask (1 = can attend, 0 = masked):")
print()
tokens = ["The", "cat", "sat", "on", "the", "mat"]
for i in range(seq_len):
    row = ['#' if mask[i, j] == 1 else '.' for j in range(seq_len)]
    print(f"  {tokens[i]:4s}: {''.join(row)}")

print(f"\nPosition 0 can only see position 0")
print(f"Position 5 can see all previous positions")
# Apply causal mask
Q = torch.randn(1, 6, 8)
K = torch.randn(1, 6, 8)
V = torch.randn(1, 6, 8)

# Without mask (bidirectional)
output_bi, weights_bi = scaled_dot_product_attention(Q, K, V)

# With causal mask
output_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask=mask)

# Compare
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
w = weights_bi[0].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
ax.set_title('Bidirectional Attention', fontsize=14)
ax.set_xticks(range(6))
ax.set_yticks(range(6))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)

ax = axes[1]
w = weights_causal[0].detach().numpy()
im = ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
ax.set_title('Causal Attention (Lower Triangular)', fontsize=14)
ax.set_xticks(range(6))
ax.set_yticks(range(6))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)

plt.colorbar(im, ax=axes, shrink=0.8, label='Attention Weight')
plt.tight_layout()
plt.show()

print("Notice: In causal attention, the upper triangle is 0 (can't attend to future)")

Multi-Head Attention

Instead of one attention head, we use multiple “heads” that can learn different patterns:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) x W_O

where head_i = Attention(Q x W_Q_i, K x W_K_i, V x W_V_i)

Why Multiple Heads?

A single attention head computes one weighted average, which limits what relationships it can capture. Multiple heads provide:

  • Diverse patterns: Different heads can focus on different relationships (syntax, semantics, position, coreference)
  • Subspace attention: Each head operates in a lower-dimensional subspace (head_dim = embed_dim / num_heads), allowing specialized representations
  • Computational efficiency: Despite having multiple heads, the total computation is similar to single-head attention with full dimensionality, and the parameter count is identical (see the quick check after the configurations below)

Typical configurations:

  • GPT-2: 12 heads, 768 embed_dim, 64 head_dim
  • GPT-3: 96 heads, 12288 embed_dim, 128 head_dim
  • Llama 2 (7B): 32 heads, 4096 embed_dim, 128 head_dim
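
Two quick sanity checks on the points above (illustrative arithmetic, not model code): head_dim is simply embed_dim divided by num_heads, and the Q/K/V/output projections have the same parameter count no matter how many heads they are split into.

# head_dim = embed_dim / num_heads for the configurations listed above
for name, heads, embed in [("GPT-2", 12, 768), ("GPT-3", 96, 12288), ("Llama 2 7B", 32, 4096)]:
    print(f"{name}: {embed} / {heads} = {embed // heads} per head")

# The q/k/v/out projections are each (embed_dim x embed_dim) weights plus embed_dim biases,
# regardless of how many heads the result is reshaped into
embed_dim = 64
params = 4 * (embed_dim * embed_dim + embed_dim)
print(f"Projection parameters at embed_dim={embed_dim}: {params:,} (same for 1 head or 8 heads)")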

What Different Heads Learn

In trained models, different heads specialize:

  • Head 0: “Who did what?” - attends to subject-verb pairs
  • Head 1: “What comes before?” - attends to previous token
  • Head 2: “What’s similar?” - attends to semantically similar words
  • Head 3: “Syntax patterns” - attends to grammatical structure
import torch.nn as nn

# Simplified implementation for learning - see attention.py for production version
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate Q, K, V projections (simplified for illustration)."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_len, _ = x.shape

        # Project Q, K, V
        q = self.q_proj(x)  # (batch, seq, embed)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attn_output = torch.matmul(attention_weights, v)

        # Reshape back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)

        # Final projection
        output = self.out_proj(attn_output)

        if return_attention:
            return output, attention_weights
        return output

# Test multi-head attention
embed_dim = 64
num_heads = 8
mha = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

x = torch.randn(2, 10, embed_dim)
output, weights = mha(x, return_attention=True)

print(f"Multi-Head Attention Configuration:")
print(f"  Embedding dimension: {embed_dim}")
print(f"  Number of heads: {num_heads}")
print(f"  Dimension per head: {embed_dim // num_heads}")
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"  (batch, heads, query_pos, key_pos)")
print(f"\nTotal parameters: {sum(p.numel() for p in mha.parameters()):,}")

Visualizing Multi-Head Attention

# Visualize attention patterns for different heads
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for head in range(num_heads):
    ax = axes[head // 4, head % 4]
    w = weights[0, head].detach().numpy()
    im = ax.imshow(w, cmap='Blues', vmin=0, vmax=w.max())
    ax.set_title(f'Head {head}', fontsize=12)
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.suptitle('Attention Patterns Across 8 Heads\n(Each head learns different patterns)', fontsize=14)
plt.tight_layout()
plt.show()

Using Our Attention Module

The attention.py module provides production-ready implementations:

from attention import (
    CausalMultiHeadAttention,
    demonstrate_attention,
    demonstrate_causal_attention
)

# Run built-in demonstrations
print("=" * 60)
print("MULTI-HEAD ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_attention(seq_len=6, embed_dim=32, num_heads=4)
print("\n" + "=" * 60)
print("CAUSAL ATTENTION DEMONSTRATION")
print("=" * 60)
demonstrate_causal_attention(seq_len=6)
# Causal multi-head attention (what GPT uses)
causal_mha = CausalMultiHeadAttention(
    embed_dim=64,
    num_heads=8,
    max_seq_len=512,
    dropout=0.0
)

x = torch.randn(2, 10, 64)
output, weights = causal_mha(x, return_attention=True)

print(f"\nCausal Multi-Head Attention:")
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Attention weights shape: {weights.shape}")

PyTorch’s Optimized Attention

Now that we understand attention from scratch, let’s see how PyTorch provides production-optimized implementations.

F.scaled_dot_product_attention

PyTorch 2.0+ provides F.scaled_dot_product_attention - a single function that replaces our manual implementation and automatically uses the best available backend.

import torch
import torch.nn.functional as F

# Our inputs
batch, num_heads, seq_len, head_dim = 2, 8, 64, 32
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# The manual way (what we implemented)
def manual_attention(q, k, v, is_causal=False):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)
    if is_causal:
        n = q.size(-2)  # take the sequence length from the tensor, not the outer variable
        mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ v

# PyTorch's optimized version
output_manual = manual_attention(query, key, value, is_causal=True)
output_pytorch = F.scaled_dot_product_attention(query, key, value, is_causal=True)

print(f"Manual output shape: {output_manual.shape}")
print(f"PyTorch SDPA shape: {output_pytorch.shape}")
print(f"Results match: {torch.allclose(output_manual, output_pytorch, atol=1e-5)}")

Flash Attention: The Speed Revolution

F.scaled_dot_product_attention automatically uses Flash Attention when available - a breakthrough algorithm that:

  1. Avoids materializing the full attention matrix (O(n^2) memory -> O(n) memory)
  2. Uses tiling to keep computation in fast GPU SRAM
  3. Fuses operations to minimize memory bandwidth bottleneck
# Check which backends are available
print("PyTorch Attention Backends:")
print(f"  Flash Attention: {torch.backends.cuda.flash_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f"  Memory-efficient: {torch.backends.cuda.mem_efficient_sdp_enabled() if torch.cuda.is_available() else 'N/A (no CUDA)'}")
print(f"  Math (fallback): Always available")

# The beautiful thing: same API, automatic optimization
# PyTorch picks the fastest available backend
# Benchmark: manual vs PyTorch SDPA
import time

def benchmark(fn, name, warmup=5, runs=20):
    # Warmup
    for _ in range(warmup):
        _ = fn()

    # Timed runs
    start = time.perf_counter()
    for _ in range(runs):
        _ = fn()
    elapsed = (time.perf_counter() - start) / runs * 1000

    print(f"{name}: {elapsed:.2f} ms per call")
    return elapsed

# Benchmark on CPU (GPU would show more dramatic difference)
batch, num_heads, seq_len, head_dim = 1, 8, 256, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

print(f"\nBenchmark (seq_len={seq_len}, {num_heads} heads, head_dim={head_dim}):")
t_manual = benchmark(lambda: manual_attention(q, k, v, is_causal=True), "Manual attention")
t_sdpa = benchmark(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True), "PyTorch SDPA")
print(f"\nSpeedup: {t_manual/t_sdpa:.1f}x")
Tip: From Scratch to Production

  What we learned            What PyTorch provides
  Q @ K.T / sqrt(d)          Fused kernel, no intermediate storage
  Causal mask with -inf      Built-in is_causal=True flag
  Stable softmax             Numerically stable implementation
  Manual loops               Flash Attention tiling

Use F.scaled_dot_product_attention in production. Understanding the scratch implementation helps debug and customize, but the optimized version is 2-10x faster on GPU.

Exercises

Exercise 1: Verify Attention Row Sums

# Verify that each row of attention weights sums to 1
Q = torch.randn(1, 5, 16)
K = torch.randn(1, 5, 16)
V = torch.randn(1, 5, 16)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Attention weights row sums (should all be 1.0):")
print(weights[0].sum(dim=-1))

Exercise 2: Effect of Temperature

# Temperature scaling affects attention sharpness
# Higher temperature = more uniform, Lower = more peaked

def attention_with_temperature(Q, K, V, temperature=1.0):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (math.sqrt(d_k) * temperature)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, V)
    return output, weights

Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, temp in enumerate([0.5, 1.0, 2.0]):
    _, weights = attention_with_temperature(Q, K, V, temperature=temp)
    ax = axes[i]
    w = weights[0].detach().numpy()
    ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Temperature = {temp}\n({"Sharp" if temp < 1 else "Uniform" if temp > 1 else "Normal"})')
    for j in range(4):
        for k in range(4):
            ax.text(k, j, f'{w[j, k]:.2f}', ha='center', va='center', fontsize=9)

plt.tight_layout()
plt.show()
print("Lower temperature = sharper attention (more peaked)")
print("Higher temperature = softer attention (more uniform)")

Exercise 3: Compare Single-Head vs Multi-Head

# Single head with full dimension vs multiple heads with smaller dimensions
embed_dim = 64
seq_len = 8

# Single head: one attention over 64 dimensions
single_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=1)

# Multi head: 8 attention heads over 8 dimensions each
multi_head = MultiHeadAttention(embed_dim=embed_dim, num_heads=8)

x = torch.randn(1, seq_len, embed_dim)

out_single, w_single = single_head(x, return_attention=True)
out_multi, w_multi = multi_head(x, return_attention=True)

print(f"Single-head attention:")
print(f"  Attention weights shape: {w_single.shape}")
print(f"  One pattern to rule them all")

print(f"\nMulti-head attention:")
print(f"  Attention weights shape: {w_multi.shape}")
print(f"  8 different patterns, each can specialize")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

ax = axes[0]
ax.imshow(w_single[0, 0].detach().numpy(), cmap='Blues')
ax.set_title('Single Head (64d)')

ax = axes[1]
# Show first 4 heads of multi-head
multi_combined = torch.zeros(seq_len, seq_len)
for h in range(4):
    multi_combined += w_multi[0, h].detach()
multi_combined /= 4
ax.imshow(multi_combined.numpy(), cmap='Blues')
ax.set_title('Multi Head (8 heads x 8d)\nAverage of first 4 heads')

plt.tight_layout()
plt.show()

Complexity and Optimizations

Attention has:

  • Time complexity: O(n^2 * d) where n = sequence length, d = embedding dimension
  • Memory complexity: O(n^2) for storing the attention matrix

For long sequences (n = 10000), this means 100 million attention entries!

# Visualize how memory grows with sequence length
import matplotlib.pyplot as plt
import numpy as np

seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]
memory_gb = [(n * n * 4) / (1024**3) for n in seq_lengths]  # float32 = 4 bytes per entry, converted to GB

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar([str(n) for n in seq_lengths], memory_gb, color='steelblue')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Attention Matrix Memory (GB)')
ax.set_title('Quadratic Memory Growth: O(n^2) Attention Matrix Size')
for i, (n, mem) in enumerate(zip(seq_lengths, memory_gb)):
    ax.text(i, mem + 0.1, f'{mem:.2f} GB', ha='center', fontsize=9)
plt.tight_layout()
plt.show()

print("This is why 32K context models need special techniques!")

KV Cache for Efficient Inference

During autoregressive generation, we compute attention one token at a time. Without caching, we’d recompute K and V for all previous tokens at each step.

KV Cache: Store computed K and V values for previous tokens:

  • At step t, only compute K_t and V_t for the new token
  • Concatenate with cached K_{1:t-1} and V_{1:t-1}
  • Query only needs the new token’s Q_t

This reduces generation from O(n^2) to O(n) per token (where n is current sequence length).
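
Here is a minimal sketch of the idea (an illustration under simplified assumptions: single head, no batching, random weights; not the module's generation code): cache K and V across steps and project only the newest token.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

k_cache, v_cache, outputs = [], [], []
for step in range(5):
    x_t = torch.randn(1, d)              # embedding of the newly generated token
    q_t = x_t @ W_q                      # only the new token's query is needed
    k_cache.append(x_t @ W_k)            # compute this token's K and V once, keep them
    v_cache.append(x_t @ W_v)

    K = torch.cat(k_cache, dim=0)        # (step + 1, d): all keys so far
    V = torch.cat(v_cache, dim=0)
    scores = (q_t @ K.T) / d ** 0.5      # (1, step + 1): new token vs. all past tokens
    outputs.append(F.softmax(scores, dim=-1) @ V)

print(torch.cat(outputs).shape)          # (5, 8): one attention output per generated token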

Modern Optimizations

Flash Attention (Dao et al., 2022):

  • Avoids materializing the full n x n attention matrix
  • Uses tiling and recomputation to be memory-efficient
  • 2-4x faster than standard attention on modern GPUs
  • Now the default in most frameworks (PyTorch 2.0+)

Sparse Attention patterns:

  • Local attention: Each token only attends to nearby tokens
  • Strided attention: Attend to every k-th token
  • Block-sparse: Combine local and strided patterns
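
As a concrete example of the local pattern (a sketch, assuming the same 1 = attend / 0 = masked convention as create_causal_mask above), a sliding-window causal mask keeps only a fixed band below the diagonal:

import torch

def sliding_window_causal_mask(seq_len, window):
    """Each position attends to itself and the previous (window - 1) positions."""
    causal = torch.tril(torch.ones(seq_len, seq_len))
    too_far = torch.tril(torch.ones(seq_len, seq_len), diagonal=-window)
    return causal - too_far  # 1 = attend, 0 = masked

mask = sliding_window_causal_mask(seq_len=6, window=3)
print(mask.int())
# Each row has at most 3 ones, so scoring costs O(n * window) instead of O(n^2)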

Linear Attention approximations:

  • Replace softmax(QK^T)V with kernel feature maps
  • Achieves O(n) complexity but may sacrifice quality

Interactive Exploration

Experiment with attention in real-time. Adjust the temperature to see how it affects the attention distribution:

  • Low temperature → Sharp, focused attention (nearly one-hot)
  • High temperature → Soft, diffuse attention (more uniform)
Tip: Try This
  1. Set temperature to 0.1 — notice how attention becomes nearly one-hot (picks one token)
  2. Set temperature to 3.0 — notice how attention becomes almost uniform
  3. Compare how “sat” attends (looks at “cat”) vs how “the” attends (looks at other “the”)

Common Pitfalls

When implementing attention, watch out for these issues:

  1. Forgetting to scale: Without /sqrt(d_k), training becomes unstable with large head dimensions
  2. Wrong mask dimensions: Mask should broadcast correctly over batch and head dimensions
  3. NaN from all-masked rows: If an entire row is masked, softmax produces NaN (0/0, since every exp(-inf) is 0). Handle with nan_to_num or ensure at least one position is unmasked (see the short demo after this list)
  4. Memory leaks with attention weights: Storing attention weights for visualization can exhaust memory. Only compute when needed
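
Pitfall 3 is easy to reproduce (a small illustration, assuming the -inf masking used throughout this module):

import torch
import torch.nn.functional as F

scores = torch.tensor([[float('-inf'), float('-inf'), float('-inf')],  # fully masked row
                       [1.0, 2.0, float('-inf')]])                     # normal row

weights = F.softmax(scores, dim=-1)
print(weights)                     # first row is all NaN (0 / 0)
print(torch.nan_to_num(weights))   # one possible fix: replace NaN with 0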

Summary

Key takeaways:

  1. Attention computes weighted sums: Each position’s output is a weighted combination of all (allowed) positions’ values
  2. Q, K, V: Query asks “what do I need?”, Key says “what do I have?”, Value carries the information
  3. Scaling prevents gradient issues: Dividing by sqrt(d_k) keeps softmax from saturating
  4. Causal masking enables generation: In LLMs, we mask future tokens so the model learns to predict the next token
  5. Multiple heads learn different patterns: Each head can specialize in different linguistic relationships
  6. Complexity is O(n^2): Attention’s quadratic cost limits sequence length, motivating optimizations like Flash Attention and KV caching

What’s Next

In Module 06: Transformer, we’ll combine attention with feed-forward networks, layer normalization, and residual connections to build a complete transformer decoder block.