---
title: "Module 04: Embeddings"
format:
html:
code-fold: false
toc: true
ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
{{< include ../_components/step-control.qmd >}}
## Introduction
Embeddings convert token IDs into dense vectors; in this vector space, similar tokens cluster together.
The model learns these **embedding** vectors during training. It does not treat token ID 42 as a bare number; it assigns that token a dense vector like `[0.2, -0.5, 0.8, ...]` that encodes semantic content.
Why embeddings matter for LLMs:
- **Similarity**: Similar words have similar vectors ("cat" and "dog" are close)
- **Composition**: Vectors can be combined meaningfully
- **Learning**: The model learns these representations during training
- **Position**: We also embed WHERE tokens are in the sequence
Two types of embeddings in transformers:
1. **Token embeddings**: What the token means
2. **Positional embeddings**: Where the token is in the sequence
### What You'll Learn
After this module, you can:
- Explain how embedding lookups work as matrix multiplication
- Implement token and positional embeddings from scratch
- Understand how gradients flow through embedding layers
- Choose between learned and sinusoidal positional embeddings
- Recognize the role of embeddings in the transformer architecture
### Prerequisites
This module requires familiarity with:
- [Module 01: Tensors](../m01_tensors/lesson.qmd) — Tensor shapes, indexing, and matrix multiplication
- [Module 03: Tokenization](../m03_tokenization/lesson.qmd) — How text becomes token IDs
### Memory and Scale Considerations
In most language models, the embedding table is the largest single weight matrix in the network. Its parameter count equals `vocab_size × embed_dim`:
| Model | Vocab Size | Embed Dim | Embedding Params | Memory (fp32) |
|-------|------------|-----------|------------------|---------------|
| GPT-2 Small | 50,257 | 768 | 38.6M | 147 MB |
| LLaMA 2 7B | 32,000 | 4,096 | 131M | 500 MB |
| LLaMA 3 8B | 128,256 | 4,096 | 525M | 2 GB |
| LLaMA 3 70B | 128,256 | 8,192 | 1.05B | 4 GB |
| GPT-4 (est.) | ~100,000 | ~12,288 | ~1.2B | ~4.7 GB |
Vocabulary size drives embedding memory and is a genuine design trade-off: a larger vocabulary means each token carries more information (fewer tokens per text), but the embedding table grows proportionally.
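The table's entries follow directly from this formula. A quick sketch reproduces the GPT-2 Small row (memory reported in binary MiB, which is how the table's fp32 column was computed):

```python
def embedding_memory(vocab_size: int, embed_dim: int, bytes_per_param: int = 4):
    """Return (param_count, memory_in_bytes) for an embedding table."""
    params = vocab_size * embed_dim
    return params, params * bytes_per_param

# GPT-2 Small: 50,257 tokens x 768 dims
params, mem = embedding_memory(50_257, 768)
print(f"{params / 1e6:.1f}M params, {mem / 1024**2:.0f} MiB in fp32")
# 38.6M params, 147 MiB in fp32
```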
## Intuition: Coordinates in Meaning Space
Embeddings are coordinates in "meaning space":
```
Token: "cat" -> [0.8, 0.1, 0.9, ...] <- captures "animal", "pet", etc.
Token: "dog" -> [0.7, 0.2, 0.8, ...] <- similar to cat
Token: "code" -> [0.1, 0.9, 0.2, ...] <- very different
Distance("cat", "dog") < Distance("cat", "code")
```
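The distance claim can be checked numerically with the toy coordinates from the sketch above (illustrative numbers, not trained embeddings):

```python
import numpy as np

# First three coordinates from the "meaning space" sketch above
cat = np.array([0.8, 0.1, 0.9])
dog = np.array([0.7, 0.2, 0.8])
code = np.array([0.1, 0.9, 0.2])

def euclidean(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.linalg.norm(a - b))

print(f"Distance(cat, dog)  = {euclidean(cat, dog):.3f}")   # small: similar meanings
print(f"Distance(cat, code) = {euclidean(cat, code):.3f}")  # large: unrelated meanings
assert euclidean(cat, dog) < euclidean(cat, code)
```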
Positional embeddings add "where in the sequence" information:
```
Position 0: [1.0, 0.0, 0.5, ...] <- "I'm first"
Position 1: [0.9, 0.1, 0.4, ...] <- "I'm second"
Position 2: [0.8, 0.2, 0.3, ...] <- "I'm third"
```
The model combines them:
```
Final embedding = Token embedding + Positional embedding
```
## Embedding Architecture
The diagram below shows how embeddings work in a transformer. Step through the pipeline to trace token IDs as they transform into dense vector representations:
```{ojs}
//| echo: false
// Embedding pipeline step control
viewof embeddingStep = stepControl({min: 0, max: 4, value: 0, label: "Pipeline Step"})
viewof showTensorShapes = Inputs.toggle({
value: true,
label: "Show tensor shapes"
})
```
```{ojs}
//| echo: false
// Step descriptions for the embedding pipeline
embeddingStepInfo = {
const steps = [
{
title: "Input: Token IDs",
desc: "Raw token indices from the tokenizer. Each integer maps to a vocabulary entry.",
highlight: ["input"]
},
{
title: "Token Embedding Lookup",
desc: "Each token ID selects a row from the embedding table. This is just matrix indexing.",
highlight: ["input", "token-table", "token-vectors"]
},
{
title: "Position Embedding Lookup",
desc: "Each position (0, 1, 2, ...) selects a row from the position table.",
highlight: ["pos-table", "pos-vectors"]
},
{
title: "Combine: Token + Position",
desc: "Add token and position embeddings element-wise. Position broadcasts over batch.",
highlight: ["token-vectors", "pos-vectors", "add-op", "combined"]
},
{
title: "Final Embeddings",
desc: "After dropout, these vectors enter the transformer layers. Each token now has both semantic and positional information.",
highlight: ["combined", "dropout", "output"]
}
];
return steps[embeddingStep];
}
```
```{ojs}
//| echo: false
// Interactive embedding pipeline visualization
embeddingPipelineViz = {
const width = 700;
const height = 400;
const step = embeddingStep;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Defs for effects
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "emb-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "3")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow marker
defs.append("marker")
.attr("id", "emb-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
// Highlighted arrow
defs.append("marker")
.attr("id", "emb-arrow-hl")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(embeddingStepInfo.title);
// Description
svg.append("text")
.attr("x", width / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text(embeddingStepInfo.desc);
const highlighted = embeddingStepInfo.highlight;
// Helper function to check if node is highlighted
const isHighlighted = (id) => highlighted.includes(id);
// Helper to draw a node box
const drawNode = (x, y, w, h, label, sublabel, id) => {
const hl = isHighlighted(id);
const g = svg.append("g").attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", hl ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", hl ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", hl ? 2 : 1.5)
.attr("filter", hl ? "url(#emb-glow)" : null);
const textColor = hl ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
g.append("text")
.attr("y", sublabel ? -8 : 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(label);
if (sublabel && showTensorShapes) {
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "9px")
.attr("opacity", hl ? 0.9 : 0.6)
.text(sublabel);
}
return g;
};
// Helper to draw arrow
const drawArrow = (x1, y1, x2, y2, fromId, toId) => {
const hl = isHighlighted(fromId) && isHighlighted(toId);
svg.append("line")
.attr("x1", x1)
.attr("y1", y1)
.attr("x2", x2)
.attr("y2", y2)
.attr("stroke", hl ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", hl ? 2 : 1.5)
.attr("marker-end", hl ? "url(#emb-arrow-hl)" : "url(#emb-arrow)");
};
// Layout positions
const col1 = 100; // Input column
const col2 = 280; // Tables column
const col3 = 460; // Vectors column
const col4 = 600; // Output column
const row1 = 130; // Token path
const row2 = 250; // Position path
const row3 = 340; // Combined/output
// Draw the pipeline
// Input
drawNode(col1, row1, 120, 55, "[42, 156, 7, 203, 15]", "(batch, seq_len)", "input");
// Token embedding table
drawNode(col2, row1, 140, 55, "Token Embedding Table", "(vocab_size × embed_dim)", "token-table");
drawArrow(col1 + 65, row1, col2 - 75, row1, "input", "token-table");
// Token vectors
drawNode(col3, row1, 130, 55, "Token Vectors", "(batch, seq_len, embed_dim)", "token-vectors");
drawArrow(col2 + 75, row1, col3 - 70, row1, "token-table", "token-vectors");
// Position table
drawNode(col2, row2, 140, 55, "Position Table", "(max_seq_len × embed_dim)", "pos-table");
// Position vectors
drawNode(col3, row2, 130, 55, "Position Vectors", "(seq_len, embed_dim)", "pos-vectors");
drawArrow(col2 + 75, row2, col3 - 70, row2, "pos-table", "pos-vectors");
// Add operation
const addX = col3;
const addY = row3 - 40;
const addHl = isHighlighted("add-op");
svg.append("circle")
.attr("cx", addX)
.attr("cy", addY)
.attr("r", 18)
.attr("fill", addHl ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", addHl ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", addHl ? 2 : 1.5)
.attr("filter", addHl ? "url(#emb-glow)" : null);
svg.append("text")
.attr("x", addX)
.attr("y", addY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", addHl ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "16px")
.attr("font-weight", "bold")
.text("+");
// Arrows to add
drawArrow(col3, row1 + 30, addX, addY - 22, "token-vectors", "add-op");
drawArrow(col3, row2 - 30, addX, addY + 22, "pos-vectors", "add-op");
// Combined output
drawNode(col3, row3 + 15, 100, 45, "Combined", "(batch, seq_len, embed_dim)", "combined");
// Arrow from add to combined
svg.append("line")
.attr("x1", addX)
.attr("y1", addY + 20)
.attr("x2", addX)
.attr("y2", row3 - 10)
.attr("stroke", isHighlighted("add-op") && isHighlighted("combined") ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", 1.5);
// Dropout
drawNode(col4, row3 - 25, 80, 40, "Dropout", "", "dropout");
drawArrow(col3 + 55, row3 + 15, col4 - 45, row3 - 25, "combined", "dropout");
// Final output
drawNode(col4, row3 + 35, 100, 50, "Final Embeddings", "(batch, seq_len, embed_dim)", "output");
drawArrow(col4, row3 - 3, col4, row3 + 7, "dropout", "output");
// Draw embedding vector visualization (mini heatmap style)
if (step >= 1) {
const vecX = col3 + 75;
const vecY = row1 - 20;
const cellW = 6;
const cellH = 25;
const numCells = 8;
// Generate pseudo-random colors for embedding visualization
const colors = d3.scaleSequential(d3.interpolateRdBu).domain([-1, 1]);
const tokenVecHl = isHighlighted("token-vectors");
for (let i = 0; i < numCells; i++) {
const val = Math.sin(i * 0.8 + 1) * 0.7; // Deterministic pattern
svg.append("rect")
.attr("x", vecX + i * cellW)
.attr("y", vecY)
.attr("width", cellW - 1)
.attr("height", cellH)
.attr("fill", colors(val))
.attr("opacity", tokenVecHl ? 1 : 0.4)
.attr("rx", 1);
}
}
if (step >= 2) {
const vecX = col3 + 75;
const vecY = row2 - 12;
const cellW = 6;
const cellH = 25;
const numCells = 8;
const colors = d3.scaleSequential(d3.interpolateRdBu).domain([-1, 1]);
const posVecHl = isHighlighted("pos-vectors");
for (let i = 0; i < numCells; i++) {
const val = Math.cos(i * 0.6) * 0.8; // Different pattern
svg.append("rect")
.attr("x", vecX + i * cellW)
.attr("y", vecY)
.attr("width", cellW - 1)
.attr("height", cellH)
.attr("fill", colors(val))
.attr("opacity", posVecHl ? 1 : 0.4)
.attr("rx", 1);
}
}
return svg.node();
}
```
::: {.callout-tip}
## Try This
Use the slider to step through the embedding pipeline. Notice how token IDs become vectors through table lookup, then get combined with position information.
:::
## Embeddings Are Just Lookup Tables
The key insight: **embedding lookup is sparse matrix multiplication**. When we say "look up embedding for token 3", we're actually doing:
1. Create a one-hot vector: token 3 in vocab of 10 becomes `[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]`
2. Multiply by the weight matrix: `one_hot @ W`
3. Only one row of W participates (the row for token 3)
This is exactly what happens mathematically. PyTorch skips one-hot creation as an optimization, but the matrix multiplication view reveals how gradients flow.
### From Scratch: One-Hot Embedding Lookup
Let's build an embedding layer using explicit one-hot vectors and matrix multiplication:
```{python}
import numpy as np
import torch
import torch.nn as nn
class ScratchEmbedding:
"""Embedding layer using explicit one-hot multiplication.
This shows what's really happening: embedding lookup is
just sparse matrix multiplication.
"""
def __init__(self, vocab_size: int, embed_dim: int):
self.vocab_size = vocab_size
self.embed_dim = embed_dim
# Weight matrix: each row is the embedding for a token
self.W = np.random.randn(vocab_size, embed_dim) * 0.02
def __call__(self, token_ids: np.ndarray) -> np.ndarray:
"""
token_ids: shape (batch, seq_len) - integer token IDs
returns: shape (batch, seq_len, embed_dim)
"""
batch_size, seq_len = token_ids.shape
# Create one-hot encodings: (batch, seq_len, vocab_size)
one_hot = np.zeros((batch_size, seq_len, self.vocab_size))
# Set the appropriate positions to 1
for b in range(batch_size):
for t in range(seq_len):
one_hot[b, t, token_ids[b, t]] = 1.0
# Matrix multiply: (batch, seq_len, vocab_size) @ (vocab_size, embed_dim)
# = (batch, seq_len, embed_dim)
embeddings = one_hot @ self.W
return embeddings
# Test it
vocab_size, embed_dim = 10, 4
scratch_emb = ScratchEmbedding(vocab_size, embed_dim)
# Sample tokens: batch of 2, sequence length 3
token_ids = np.array([[3, 7, 1],
[5, 3, 9]])
result = scratch_emb(token_ids)
print(f"Token IDs shape: {token_ids.shape}")
print(f"Embeddings shape: {result.shape}")
print(f"\nToken 3's embedding (row 3 of W):")
print(f" From lookup: {result[0, 0]}")
print(f" Direct W[3]: {scratch_emb.W[3]}")
print(f" Match: {np.allclose(result[0, 0], scratch_emb.W[3])}")
```
The one-hot multiplication selects exactly one row from W. Watch the math:
```{python}
# Visualize the one-hot multiplication
token_id = 3
one_hot = np.zeros(vocab_size)
one_hot[token_id] = 1.0
print("One-hot vector for token 3:")
print(f" {one_hot}")
print(f"\nWeight matrix W (10 x 4):")
print(f" Row 0: {scratch_emb.W[0]}")
print(f" Row 1: {scratch_emb.W[1]}")
print(f" Row 2: {scratch_emb.W[2]}")
print(f" Row 3: {scratch_emb.W[3]} <-- selected")
print(f" ...")
print(f"\none_hot @ W = {one_hot @ scratch_emb.W}")
print(f"W[3] directly = {scratch_emb.W[3]}")
```
### PyTorch's nn.Embedding
PyTorch provides the same functionality but optimized - it skips creating the one-hot vector entirely:
```{python}
# PyTorch equivalent
torch_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
# Copy weights from scratch version for comparison
with torch.no_grad():
torch_emb.weight.copy_(torch.from_numpy(scratch_emb.W).float())
# Same token IDs
token_ids_torch = torch.tensor(token_ids)
result_torch = torch_emb(token_ids_torch)
print(f"Scratch result (token 3): {result[0, 0]}")
print(f"PyTorch result (token 3): {result_torch[0, 0].detach().numpy()}")
print(f"Match: {np.allclose(result, result_torch.detach().numpy())}")
```
::: {.callout-tip}
## Key Insight
`nn.Embedding` is just an optimized lookup - no one-hot materialization. But mathematically, it's identical to one-hot times weight matrix. Understanding this helps when debugging gradient flow.
:::
## Making Lookups Differentiable
How do gradients flow through an embedding lookup? The answer comes directly from the matrix multiplication view.
### From Scratch: Gradient Flow
When we compute `output = one_hot @ W`, the gradient with respect to W follows standard matrix calculus:
```
dL/dW = one_hot.T @ dL/doutput
```
**Only the selected rows receive gradients.** If we looked up tokens [3, 7, 1], only rows 3, 7, and 1 of W get updated during training.
```{python}
class ScratchEmbeddingWithGrad:
"""Embedding with gradient computation."""
def __init__(self, vocab_size: int, embed_dim: int):
self.vocab_size = vocab_size
self.embed_dim = embed_dim
self.W = np.random.randn(vocab_size, embed_dim) * 0.02
self.grad_W = None
self._last_one_hot = None # Store for backward
def forward(self, token_ids: np.ndarray) -> np.ndarray:
batch_size, seq_len = token_ids.shape
# Create one-hot: (batch, seq_len, vocab_size)
one_hot = np.zeros((batch_size, seq_len, self.vocab_size))
for b in range(batch_size):
for t in range(seq_len):
one_hot[b, t, token_ids[b, t]] = 1.0
self._last_one_hot = one_hot
return one_hot @ self.W
def backward(self, grad_output: np.ndarray):
"""
grad_output: shape (batch, seq_len, embed_dim) - gradient from next layer
"""
# dL/dW = one_hot.T @ grad_output
# Reshape for batch matmul: (batch, vocab_size, seq_len) @ (batch, seq_len, embed_dim)
one_hot_T = self._last_one_hot.transpose(0, 2, 1) # (batch, vocab, seq_len)
# Accumulate gradients across batch
self.grad_W = np.zeros_like(self.W)
for b in range(grad_output.shape[0]):
self.grad_W += one_hot_T[b] @ grad_output[b]
return self.grad_W
# Demonstrate gradient flow
emb = ScratchEmbeddingWithGrad(vocab_size=10, embed_dim=4)
token_ids = np.array([[3, 7, 1]]) # batch=1, seq_len=3
# Forward
output = emb.forward(token_ids)
# Simulate gradient from loss (all ones for simplicity)
grad_from_loss = np.ones_like(output)
# Backward
grad_W = emb.backward(grad_from_loss)
print("Gradient magnitude per row of W:")
for i in range(10):
magnitude = np.abs(grad_W[i]).sum()
marker = " <-- used" if i in [3, 7, 1] else ""
print(f" Row {i}: {magnitude:.4f}{marker}")
print("\nOnly rows 1, 3, 7 received gradients!")
```
### PyTorch: Automatic Gradient Tracking
PyTorch handles this automatically when `requires_grad=True`:
```{python}
# PyTorch does this automatically
torch_emb = nn.Embedding(10, 4)
token_ids = torch.tensor([[3, 7, 1]])
output = torch_emb(token_ids)
# Fake loss: sum of embeddings
loss = output.sum()
loss.backward()
print("PyTorch gradient magnitude per row:")
for i in range(10):
magnitude = torch_emb.weight.grad[i].abs().sum().item()
marker = " <-- used" if i in [3, 7, 1] else ""
print(f" Row {i}: {magnitude:.4f}{marker}")
```
::: {.callout-note}
## Sparse Updates
This "sparse gradient" property is why embedding layers can have millions of parameters but train efficiently - each batch only updates a small subset of rows.
:::
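PyTorch exposes this property directly: passing `sparse=True` to `nn.Embedding` makes the backward pass emit a sparse gradient tensor that stores entries only for the rows actually looked up. A minimal sketch (note that only some optimizers, such as `optim.SGD` and `optim.SparseAdam`, accept sparse gradients):

```python
import torch
import torch.nn as nn

# sparse=True: backward produces a torch.sparse gradient instead of a
# dense (10000, 64) tensor of mostly zeros
emb = nn.Embedding(10_000, 64, sparse=True)
out = emb(torch.tensor([[3, 7, 1]]))
out.sum().backward()

grad = emb.weight.grad
print(grad.is_sparse)                       # True
print(grad.coalesce().indices().shape[1])   # 3 rows stored, not 10,000
```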
## Positional Information
Attention (without positional embeddings) is **permutation-equivariant**: if you reorder the input tokens, the attention scores simply reorder to match. The relationship between "cat" and "sat" is the same regardless of whether they're at positions [0,1] or [1,0]. This means the model can't distinguish "the cat sat" from "sat the cat" — a critical limitation since word order carries meaning.
Position embeddings solve this by giving each position a learnable vector that gets added to the token embedding.
### From Scratch: Learnable Position Embeddings
Position embeddings are just another lookup table, indexed by position instead of token ID:
```{python}
class ScratchPositionEmbedding:
"""Learnable position embeddings - same as token embeddings but indexed by position."""
def __init__(self, max_seq_len: int, embed_dim: int):
self.max_seq_len = max_seq_len
self.embed_dim = embed_dim
# Each position gets its own learnable vector
self.W = np.random.randn(max_seq_len, embed_dim) * 0.02
def __call__(self, seq_len: int) -> np.ndarray:
"""
seq_len: how many positions to return
returns: shape (seq_len, embed_dim)
"""
# Just slice the first seq_len positions
return self.W[:seq_len]
class ScratchCombinedEmbedding:
"""Token embeddings + position embeddings."""
def __init__(self, vocab_size: int, embed_dim: int, max_seq_len: int):
self.token_emb = ScratchEmbedding(vocab_size, embed_dim)
self.pos_emb = ScratchPositionEmbedding(max_seq_len, embed_dim)
def __call__(self, token_ids: np.ndarray) -> np.ndarray:
"""
token_ids: shape (batch, seq_len)
returns: shape (batch, seq_len, embed_dim)
"""
batch_size, seq_len = token_ids.shape
# Get token embeddings: (batch, seq_len, embed_dim)
tok_emb = self.token_emb(token_ids)
# Get position embeddings: (seq_len, embed_dim)
pos_emb = self.pos_emb(seq_len)
# Add position embeddings (broadcasts over batch dimension)
return tok_emb + pos_emb
# Test combined embedding
combined = ScratchCombinedEmbedding(vocab_size=100, embed_dim=8, max_seq_len=32)
# Same token (ID=42) at different positions
tokens = np.array([[42, 42, 42, 42]]) # Same token, 4 positions
embeddings = combined(tokens)
print("Same token (42) at different positions:")
for pos in range(4):
print(f" Position {pos}: {embeddings[0, pos, :4]}...")
print("\nAll different due to position embeddings!")
```
### Why Position Matters: Attention Is Permutation-Equivariant
Without position embeddings, reordering the input tokens simply permutes the attention scores to match; the model effectively sees an unordered set:
```{python}
#| output: false
# Demonstrate permutation invariance
def simple_attention_scores(embeddings):
"""Compute raw attention scores (Q @ K.T) without position."""
# In real attention, Q = emb @ W_q, K = emb @ W_k
# For simplicity, use embeddings directly
return embeddings @ embeddings.T
# Create two orderings of the same tokens
token_emb = ScratchEmbedding(vocab_size=10, embed_dim=8)
# "the cat sat" = tokens [1, 5, 7]
order1 = np.array([[1, 5, 7]])
# "sat the cat" = tokens [7, 1, 5]
order2 = np.array([[7, 1, 5]])
emb1 = token_emb(order1)[0] # (3, 8)
emb2 = token_emb(order2)[0] # (3, 8)
scores1 = simple_attention_scores(emb1)
scores2 = simple_attention_scores(emb2)
# Pass data to OJS
ojs_define(
attention_scores1=scores1.tolist(),
attention_scores2=scores2.tolist()
)
```
```{ojs}
//| echo: false
// Attention scores permutation visualization
attentionPermutationViz = {
const width = 700;
const height = 280;
const margin = { top: 50, right: 80, bottom: 50, left: 70 };
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text("Attention scores are just permuted (without position embeddings, order is lost)");
// Get attention matrices
const scores1 = attention_scores1;
const scores2 = attention_scores2;
// Find global min/max for consistent color scale
const allVals = [...scores1.flat(), ...scores2.flat()];
const minVal = Math.min(...allVals);
const maxVal = Math.max(...allVals);
// Color scale - Blues
const colorScale = d3.scaleSequential(d3.interpolateBlues)
.domain([minVal, maxVal]);
// Draw a single heatmap
const drawHeatmap = (data, x, y, title, subtitle) => {
const cellSize = 50;
const labels = ["0", "1", "2"];
const g = svg.append("g")
.attr("transform", `translate(${x}, ${y})`);
// Title
g.append("text")
.attr("x", cellSize * 1.5)
.attr("y", -25)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(title);
g.append("text")
.attr("x", cellSize * 1.5)
.attr("y", -10)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(subtitle);
// Cells
data.forEach((row, i) => {
row.forEach((val, j) => {
g.append("rect")
.attr("x", j * cellSize)
.attr("y", i * cellSize)
.attr("width", cellSize - 2)
.attr("height", cellSize - 2)
.attr("fill", colorScale(val))
.attr("rx", 4);
g.append("text")
.attr("x", j * cellSize + cellSize / 2 - 1)
.attr("y", i * cellSize + cellSize / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", val > (minVal + maxVal) / 2 ? diagramTheme.textOnHighlight : theme.nodeText)
.attr("font-size", "9px")
.text(val.toFixed(2));
});
});
// Y axis label
g.append("text")
.attr("transform", `rotate(-90)`)
.attr("x", -cellSize * 1.5)
.attr("y", -20)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Query pos");
// X axis label
g.append("text")
.attr("x", cellSize * 1.5)
.attr("y", cellSize * 3 + 20)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Key pos");
};
// Draw both heatmaps
drawHeatmap(scores1, 120, 75, 'Order: [1, 5, 7]', '"the cat sat"');
drawHeatmap(scores2, 420, 75, 'Order: [7, 1, 5]', '"sat the cat"');
// Color legend
const legendWidth = 15;
const legendHeight = 120;
const legendX = width - 45;
const legendY = 75;
const legendGrad = svg.append("defs")
.append("linearGradient")
.attr("id", "attn-legend-grad")
.attr("x1", "0%")
.attr("y1", "100%")
.attr("x2", "0%")
.attr("y2", "0%");
legendGrad.append("stop").attr("offset", "0%").attr("stop-color", colorScale(minVal));
legendGrad.append("stop").attr("offset", "100%").attr("stop-color", colorScale(maxVal));
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY)
.attr("width", legendWidth)
.attr("height", legendHeight)
.attr("fill", "url(#attn-legend-grad)")
.attr("rx", 3);
svg.append("text")
.attr("x", legendX + legendWidth + 5)
.attr("y", legendY + 6)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(maxVal.toFixed(1));
svg.append("text")
.attr("x", legendX + legendWidth + 5)
.attr("y", legendY + legendHeight)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(minVal.toFixed(1));
return svg.node();
}
```
### PyTorch: Combined Token + Position Embedding
```{python}
class PyTorchCombinedEmbedding(nn.Module):
"""Standard transformer embedding: token + position."""
def __init__(self, vocab_size: int, embed_dim: int, max_seq_len: int):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, embed_dim)
self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
batch_size, seq_len = token_ids.shape
# Token embeddings
tok_emb = self.token_emb(token_ids)
# Position embeddings (create position indices)
positions = torch.arange(seq_len, device=token_ids.device)
pos_emb = self.pos_emb(positions)
return tok_emb + pos_emb
# Compare scratch vs PyTorch
pytorch_combined = PyTorchCombinedEmbedding(vocab_size=100, embed_dim=8, max_seq_len=32)
tokens_torch = torch.tensor([[42, 42, 42, 42]])
embeddings_torch = pytorch_combined(tokens_torch)
print("PyTorch: Same token (42) at different positions:")
for pos in range(4):
print(f" Position {pos}: {embeddings_torch[0, pos, :4].tolist()}")
```
::: {.callout-tip}
## Key Insight
Position embeddings are just another embedding table - they work identically to token embeddings but are indexed by position. The "magic" is simply: `final = token_emb[token_id] + pos_emb[position]`.
:::
## The Math
### Token Embeddings
Simple lookup table: `E[token_id] = embedding_vector`
Mathematically equivalent to one-hot multiplication:
```
one_hot = [0, 0, 0, 1, 0, ...] # 1 at position token_id
embedding = one_hot @ E # selects row token_id from E
```
### Positional Embeddings
Several approaches encode position:
**1. Learned positional embeddings (GPT-2, BERT):**
```python
# Position table: (max_seq_len, embed_dim)
P = torch.randn(max_seq_len, embed_dim)
positions = P[:seq_len] # Get positions for current sequence
```
Each position gets a trainable vector. Simple and effective, but cannot generalize to positions beyond `max_seq_len`.
**2. Sinusoidal positional embeddings (original Transformer):**
```
PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
```
This creates a unique pattern for each position using waves of different frequencies. The key insight: `PE(pos+k)` equals a linear transformation of `PE(pos)`, so the model can learn relative positions directly.
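A minimal NumPy sketch of the sinusoidal table, following the formula above (even dimensions get sine, odd dimensions cosine):

```python
import numpy as np

def sinusoidal_pe(max_len: int, d: int) -> np.ndarray:
    """Sinusoidal position encodings, shape (max_len, d)."""
    pos = np.arange(max_len)[:, None]        # (max_len, 1)
    i = np.arange(0, d, 2)[None, :]          # (1, d/2): dimension pair index
    angles = pos / (10000 ** (i / d))        # (max_len, d/2)
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angles)             # even dims: sine
    pe[:, 1::2] = np.cos(angles)             # odd dims: cosine
    return pe

pe = sinusoidal_pe(128, 16)
print(pe.shape)                    # (128, 16)
print(pe[0, :2])                   # position 0: [sin(0), cos(0)] = [0, 1]
assert not np.allclose(pe[3], pe[4])  # every position gets a distinct pattern
```

No parameters are learned here, so the same function extends to any sequence length.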
**3. Rotary Position Embedding - RoPE (LLaMA, Mistral):**
Rather than adding position embeddings to token embeddings, RoPE rotates the query and key vectors based on position. The rotation angle depends on both position and dimension, encoding relative positions naturally in the attention computation. This approach extrapolates well beyond training sequence lengths.
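A sketch of the core RoPE rotation can make the relative-position property concrete. This is simplified for illustration (real implementations rotate per-head query/key tensors, often in fused kernels), but it shows the key fact: dot products between rotated vectors depend only on the position *offset*:

```python
import numpy as np

def rope_rotate(x: np.ndarray, pos: int) -> np.ndarray:
    """Rotate consecutive dimension pairs of x by position-dependent angles."""
    d = x.shape[-1]
    i = np.arange(0, d, 2)
    theta = pos / (10000 ** (i / d))         # one angle per dimension pair
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * np.cos(theta) - x2 * np.sin(theta)
    out[1::2] = x1 * np.sin(theta) + x2 * np.cos(theta)
    return out

np.random.seed(0)
q, k = np.random.randn(8), np.random.randn(8)
# Same offset (2) at different absolute positions -> same attention score
a = rope_rotate(q, 5) @ rope_rotate(k, 3)
b = rope_rotate(q, 12) @ rope_rotate(k, 10)
assert np.isclose(a, b)
```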
**4. ALiBi - Attention with Linear Biases (BLOOM):**
Instead of adding position information to embeddings, ALiBi adds a position-dependent bias directly to the attention scores: the farther a key is from the query, the larger the penalty. Position never touches the embedding layer at all, and the simple linear penalty extrapolates to sequence lengths beyond those seen in training.
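A sketch of the ALiBi bias matrix for a causal model. The head slopes follow the geometric schedule `2^(-8h/num_heads)` from the ALiBi paper; future positions (which a causal mask hides anyway) are left at zero:

```python
import numpy as np

def alibi_bias(seq_len: int, num_heads: int) -> np.ndarray:
    """Per-head linear attention bias, shape (num_heads, seq_len, seq_len)."""
    # Geometric slope per head: head 1 penalizes distance most gently... or
    # most steeply, depending on num_heads; each head gets a distinct rate.
    slopes = 2.0 ** (-8.0 * np.arange(1, num_heads + 1) / num_heads)
    # dist[i, j] = j - i: negative for past keys, positive for future keys
    dist = np.arange(seq_len)[None, :] - np.arange(seq_len)[:, None]
    # Penalize past keys in proportion to distance; leave future keys at 0
    return slopes[:, None, None] * np.minimum(dist, 0)

bias = alibi_bias(4, 2)
print(bias[0])  # row i: 0 on the diagonal, increasingly negative to the left
```

These biases are simply added to the raw `Q @ K.T` scores before the softmax.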
### Combined Embeddings
```python
# Input: token_ids of shape (batch, seq_len)
token_emb = token_embedding[token_ids] # (batch, seq_len, embed_dim)
pos_emb = position_embedding[:seq_len] # (seq_len, embed_dim)
x = token_emb + pos_emb # (batch, seq_len, embed_dim)
```
## Same Token, Different Positions
The same token ("the") appears at multiple positions in a sentence. Even though it has the same token embedding, the final embedding differs because of position.
```{ojs}
//| echo: false
// Controls for position comparison
viewof pos1 = Inputs.range([0, 5], {
value: 0,
step: 1,
label: "First position"
})
viewof pos2 = Inputs.range([0, 5], {
value: 4,
step: 1,
label: "Second position"
})
viewof embedDimVis = Inputs.range([8, 32], {
value: 16,
step: 4,
label: "Embedding dimensions shown"
})
```
```{ojs}
//| echo: false
// Generate deterministic pseudo-embeddings for visualization
posCompareData = {
const tokenEmb = [];
const posEmb1 = [];
const posEmb2 = [];
// Token embedding (same for both positions - "the")
for (let i = 0; i < embedDimVis; i++) {
tokenEmb.push(Math.sin(i * 0.5 + 2.1) * 0.8);
}
// Position embeddings (different for each position - sinusoidal-like)
for (let i = 0; i < embedDimVis; i++) {
const freq = 1 / Math.pow(10, i / embedDimVis * 2);
posEmb1.push(i % 2 === 0 ? Math.sin(pos1 * freq) : Math.cos(pos1 * freq));
posEmb2.push(i % 2 === 0 ? Math.sin(pos2 * freq) : Math.cos(pos2 * freq));
}
// Combined embeddings
const final1 = tokenEmb.map((v, i) => v + posEmb1[i]);
const final2 = tokenEmb.map((v, i) => v + posEmb2[i]);
// Compute difference and similarity
const diff = final1.map((v, i) => Math.abs(v - final2[i]));
const dotProduct = final1.reduce((sum, v, i) => sum + v * final2[i], 0);
const norm1 = Math.sqrt(final1.reduce((sum, v) => sum + v * v, 0));
const norm2 = Math.sqrt(final2.reduce((sum, v) => sum + v * v, 0));
const cosineSim = dotProduct / (norm1 * norm2);
return { tokenEmb, posEmb1, posEmb2, final1, final2, diff, cosineSim };
}
```
```{ojs}
//| echo: false
// Interactive position comparison visualization
positionComparisonViz = {
const width = 700;
const height = 480;
const margin = { top: 60, right: 30, bottom: 40, left: 110 };
const barH = 18;
const rowGap = 65;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(`"the" at position ${pos1} vs position ${pos2}`);
// Similarity indicator - use highlight/accent colors for theme support
const simColor = posCompareData.cosineSim > 0.9 ? diagramTheme.highlight :
posCompareData.cosineSim > 0.7 ? diagramTheme.accent : diagramTheme.edgeStroke;
svg.append("text")
.attr("x", width / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", simColor)
.attr("font-size", "12px")
.text(`Cosine similarity: ${posCompareData.cosineSim.toFixed(3)}`);
// Color scale
const colorScale = d3.scaleSequential(d3.interpolateRdBu).domain([1.5, -1.5]);
// Bar chart width
const chartWidth = width - margin.left - margin.right;
const cellW = Math.min(20, chartWidth / embedDimVis);
// Draw embedding bars helper
const drawEmbeddingBar = (data, y, label, sublabel) => {
// Label
svg.append("text")
.attr("x", margin.left - 10)
.attr("y", y + barH / 2)
.attr("text-anchor", "end")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(label);
if (sublabel) {
svg.append("text")
.attr("x", margin.left - 10)
.attr("y", y + barH / 2 + 14)
.attr("text-anchor", "end")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.6)
.text(sublabel);
}
// Cells
data.forEach((val, i) => {
svg.append("rect")
.attr("x", margin.left + i * cellW)
.attr("y", y)
.attr("width", cellW - 1)
.attr("height", barH)
.attr("fill", colorScale(val))
.attr("rx", 2);
});
};
// Row positions
let currentY = margin.top + 20;
// Token embedding (same for both)
drawEmbeddingBar(posCompareData.tokenEmb, currentY, "Token emb", "(same)");
currentY += rowGap;
// Position embedding 1
drawEmbeddingBar(posCompareData.posEmb1, currentY, `Pos ${pos1} emb`, "");
// Plus sign
svg.append("text")
.attr("x", margin.left - 60)
.attr("y", currentY + barH + 20)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "16px")
.attr("font-weight", "bold")
.text("+");
currentY += rowGap;
// Final embedding 1
svg.append("rect")
.attr("x", margin.left - 5)
.attr("y", currentY - 5)
.attr("width", embedDimVis * cellW + 10)
.attr("height", barH + 10)
.attr("fill", "none")
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 2)
.attr("rx", 4);
drawEmbeddingBar(posCompareData.final1, currentY, `Final @ ${pos1}`, "");
currentY += rowGap + 10;
// Divider
svg.append("line")
.attr("x1", margin.left)
.attr("y1", currentY - 15)
.attr("x2", width - margin.right)
.attr("y2", currentY - 15)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-dasharray", "4,4")
.attr("opacity", 0.5);
// Position embedding 2
drawEmbeddingBar(posCompareData.posEmb2, currentY, `Pos ${pos2} emb`, "");
// Plus sign
svg.append("text")
.attr("x", margin.left - 60)
.attr("y", currentY + barH + 20)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "16px")
.attr("font-weight", "bold")
.text("+");
// Token embedding again (visual reminder it's the same)
svg.append("text")
.attr("x", margin.left - 10)
.attr("y", currentY + barH + 20)
.attr("text-anchor", "end")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.5)
.text("(+ same token emb)");
currentY += rowGap;
// Final embedding 2
svg.append("rect")
.attr("x", margin.left - 5)
.attr("y", currentY - 5)
.attr("width", embedDimVis * cellW + 10)
.attr("height", barH + 10)
.attr("fill", "none")
.attr("stroke", diagramTheme.accent)
.attr("stroke-width", 2)
.attr("rx", 4);
drawEmbeddingBar(posCompareData.final2, currentY, `Final @ ${pos2}`, "");
// Difference indicator
currentY += rowGap - 5;
svg.append("text")
.attr("x", margin.left - 10)
.attr("y", currentY + barH / 2)
.attr("text-anchor", "end")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text("Difference");
// Draw difference magnitude
const maxDiff = Math.max(...posCompareData.diff);
posCompareData.diff.forEach((val, i) => {
const h = Math.min(barH, (val / maxDiff) * barH);
svg.append("rect")
.attr("x", margin.left + i * cellW)
.attr("y", currentY + barH - h)
.attr("width", cellW - 1)
.attr("height", h)
.attr("fill", diagramTheme.highlight)
.attr("opacity", 0.6)
.attr("rx", 1);
});
// Legend for color scale
const legendX = width - 100;
const legendY = margin.top + 30;
svg.append("text")
.attr("x", legendX)
.attr("y", legendY - 10)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.7)
.text("Value scale");
const legendW = 60;
const legendH = 10;
const legendGrad = svg.append("defs")
.append("linearGradient")
.attr("id", "legend-grad");
legendGrad.append("stop").attr("offset", "0%").attr("stop-color", colorScale(-1.5));
legendGrad.append("stop").attr("offset", "50%").attr("stop-color", colorScale(0));
legendGrad.append("stop").attr("offset", "100%").attr("stop-color", colorScale(1.5));
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY)
.attr("width", legendW)
.attr("height", legendH)
.attr("fill", "url(#legend-grad)")
.attr("rx", 2);
svg.append("text")
.attr("x", legendX)
.attr("y", legendY + legendH + 12)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "8px")
.text("-1.5");
svg.append("text")
.attr("x", legendX + legendW)
.attr("y", legendY + legendH + 12)
.attr("text-anchor", "end")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "8px")
.text("+1.5");
return svg.node();
}
```
::: {.callout-tip}
## Try This
1. **Same position**: Set both positions to the same value (e.g., 0 and 0). The similarity becomes 1.0 (identical).
2. **Adjacent positions**: Compare positions 0 and 1. They are very similar (>0.95) because position embeddings change gradually.
3. **Distant positions**: Compare positions 0 and 5. The similarity drops because position embeddings diverge.
The key insight: **same token ID + different position = different final embedding**. This is how the model knows word order matters.
:::
## Code Walkthrough
Let's explore embeddings interactively:
```{python}
import torch
import torch.nn as nn
import numpy as np
print(f"PyTorch version: {torch.__version__}")
```
### Token Embeddings Basics
A token embedding is just a lookup table: token ID -> vector
```{python}
# Create a simple token embedding
vocab_size = 100
embed_dim = 32
# nn.Embedding is PyTorch's lookup table
token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
print(f"Vocabulary size: {vocab_size}")
print(f"Embedding dimension: {embed_dim}")
print(f"Total parameters: {vocab_size * embed_dim:,}")
print(f"Embedding table shape: {token_emb.weight.shape}")
```
```{python}
# Look up embeddings for some tokens
token_ids = torch.tensor([[5, 10, 15, 20]])
embeddings = token_emb(token_ids)
print(f"Input token IDs: {token_ids[0].tolist()}")
print(f"Output shape: {tuple(embeddings.shape)}")
print(f"\nToken 5's embedding (first 8 dims):")
print(f" {embeddings[0, 0, :8].tolist()}")
```
```{python}
# Same token always gets the same embedding
e1 = token_emb(torch.tensor([[42]]))
e2 = token_emb(torch.tensor([[42]]))
print(f"Token 42 embedding (call 1): {e1[0, 0, :4].tolist()}")
print(f"Token 42 embedding (call 2): {e2[0, 0, :4].tolist()}")
print(f"Equal: {torch.allclose(e1, e2)}")
```
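The lookup has a useful mathematical interpretation: indexing row 42 of the table is exactly a matrix multiplication between a one-hot vector and the weight matrix. A quick check with a fresh table (`lookup` here is a local stand-in for `token_emb`):

```{python}
import torch
import torch.nn as nn

# Selecting row 42 of the embedding table equals multiplying a one-hot
# vector (1.0 at index 42, zeros elsewhere) by the weight matrix.
lookup = nn.Embedding(100, 32)
one_hot = torch.zeros(100)
one_hot[42] = 1.0
via_matmul = one_hot @ lookup.weight        # (100,) @ (100, 32) -> (32,)
via_lookup = lookup(torch.tensor(42))       # direct row indexing
print(f"Matmul and lookup agree: {torch.allclose(via_matmul, via_lookup)}")
```

In practice, frameworks implement this as an index operation because the one-hot matmul wastes compute, but the gradient math is identical.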
### Sinusoidal Positional Encoding
The original Transformer uses sin/cos functions to encode position. The key idea is to create a unique "fingerprint" for each position using waves of different frequencies:
- **Low-frequency components** (high dimensions): Change slowly across positions, capturing coarse position
- **High-frequency components** (low dimensions): Change rapidly, capturing fine-grained position
This is analogous to how Fourier series can represent any periodic function as a sum of sines and cosines:
```{python}
import math
def create_sinusoidal_encoding(max_seq_len: int, embed_dim: int) -> torch.Tensor:
"""Create sinusoidal positional encoding matrix."""
pe = torch.zeros(max_seq_len, embed_dim)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    # Compute div_term: 10000^(-2i/d) = exp(2i * -log(10000) / d)
div_term = torch.exp(
torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
)
# Apply sin to even indices, cos to odd indices
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
# Create positional encoding
pe = create_sinusoidal_encoding(max_seq_len=128, embed_dim=64)
print(f"Positional encoding shape: {pe.shape}")
```
```{python}
#| output: false
# Prepare data for OJS visualization
pe_heatmap_data = pe[:50].numpy().tolist()
# Extract specific dimensions for line plot
pe_dim_lines = {
'dim0': pe[:50, 0].numpy().tolist(),
'dim1': pe[:50, 1].numpy().tolist(),
'dim10': pe[:50, 10].numpy().tolist(),
'dim11': pe[:50, 11].numpy().tolist(),
'dim30': pe[:50, 30].numpy().tolist(),
'dim31': pe[:50, 31].numpy().tolist()
}
# Compute cosine similarity between all position pairs
pe_subset = pe[:20]
pe_norm = pe_subset / pe_subset.norm(dim=1, keepdim=True)
similarity = pe_norm @ pe_norm.T
pe_similarity_data = similarity.numpy().tolist()
ojs_define(
pe_heatmap=pe_heatmap_data,
pe_lines=pe_dim_lines,
pe_similarity=pe_similarity_data
)
```
```{ojs}
//| echo: false
// Sinusoidal Positional Encoding Heatmap
peHeatmapViz = {
const width = 700;
const height = 320;
const margin = { top: 45, right: 80, bottom: 50, left: 70 };
const theme = diagramTheme;
const data = pe_heatmap;
const numPositions = data.length;
const numDims = data[0].length;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 24)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Sinusoidal Positional Encoding");
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Cell sizes
const cellW = innerWidth / numDims;
const cellH = innerHeight / numPositions;
// Color scale
const colorScale = d3.scaleSequential(d3.interpolateRdBu).domain([1, -1]);
// Draw cells
data.forEach((row, posIdx) => {
row.forEach((val, dimIdx) => {
g.append("rect")
.attr("x", dimIdx * cellW)
.attr("y", posIdx * cellH)
.attr("width", cellW)
.attr("height", cellH)
.attr("fill", colorScale(val));
});
});
// X axis
g.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 35)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Embedding Dimension");
// X ticks
[0, 16, 32, 48, 63].forEach(d => {
g.append("text")
.attr("x", d * cellW + cellW / 2)
.attr("y", innerHeight + 15)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(d);
});
// Y axis
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Position");
// Y ticks
[0, 10, 20, 30, 40, 49].forEach(d => {
g.append("text")
.attr("x", -10)
.attr("y", d * cellH + cellH / 2 + 3)
.attr("text-anchor", "end")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(d);
});
// Legend
const legendX = width - 55;
const legendY = margin.top;
const legendH = innerHeight;
const legendW = 15;
const legendGrad = svg.append("defs")
.append("linearGradient")
.attr("id", "pe-heatmap-legend")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "0%")
.attr("y2", "100%");
legendGrad.append("stop").attr("offset", "0%").attr("stop-color", colorScale(-1));
legendGrad.append("stop").attr("offset", "50%").attr("stop-color", colorScale(0));
legendGrad.append("stop").attr("offset", "100%").attr("stop-color", colorScale(1));
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY)
.attr("width", legendW)
.attr("height", legendH)
.attr("fill", "url(#pe-heatmap-legend)")
.attr("rx", 3);
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + 8)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("-1");
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + legendH / 2 + 3)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("0");
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + legendH)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("+1");
return svg.node();
}
```
```{ojs}
//| echo: false
// Positional Encoding Line Plot by Dimension
peDimLinesViz = {
const width = 700;
const height = 260;
const margin = { top: 50, right: 150, bottom: 45, left: 55 };
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Positional Encoding by Dimension");
svg.append("text")
.attr("x", width / 2)
.attr("y", 38)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("(Lower dimensions = higher frequency)");
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Data preparation - use theme-aware colors from diagramTheme
// Colors for line chart: accent, highlight, plus additional visually-distinct hues
// These colors work in both light and dark modes due to their saturation levels
const dims = [
{ key: 'dim0', label: 'Dim 0 (sin)', color: diagramTheme.accent },
{ key: 'dim1', label: 'Dim 1 (cos)', color: diagramTheme.highlight },
{ key: 'dim10', label: 'Dim 10 (sin)', color: diagramTheme.primary },
{ key: 'dim11', label: 'Dim 11 (cos)', color: diagramTheme.success },
{ key: 'dim30', label: 'Dim 30 (sin)', color: diagramTheme.info },
{ key: 'dim31', label: 'Dim 31 (cos)', color: diagramTheme.edgeStroke }
];
const numPositions = pe_lines.dim0.length;
// Scales
const xScale = d3.scaleLinear()
.domain([0, numPositions - 1])
.range([0, innerWidth]);
const yScale = d3.scaleLinear()
.domain([-1.1, 1.1])
.range([innerHeight, 0]);
// Grid lines
g.append("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", yScale(0))
.attr("y2", yScale(0))
.attr("stroke", theme.nodeStroke)
.attr("stroke-dasharray", "4,4")
.attr("opacity", 0.5);
// Axes
g.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale).ticks(5))
.call(g => g.selectAll("text").attr("fill", theme.nodeText).attr("font-size", "9px"))
.call(g => g.selectAll("line").attr("stroke", theme.nodeStroke))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke));
g.append("g")
.call(d3.axisLeft(yScale).ticks(5))
.call(g => g.selectAll("text").attr("fill", theme.nodeText).attr("font-size", "9px"))
.call(g => g.selectAll("line").attr("stroke", theme.nodeStroke))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke));
// Axis labels
g.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 35)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Position");
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Value");
// Draw lines
const line = d3.line()
.x((d, i) => xScale(i))
.y(d => yScale(d));
dims.forEach(dim => {
g.append("path")
.datum(pe_lines[dim.key])
.attr("fill", "none")
.attr("stroke", dim.color)
.attr("stroke-width", 2)
.attr("d", line);
});
// Legend
const legendX = innerWidth + 15;
dims.forEach((dim, i) => {
const ly = 10 + i * 22;
g.append("line")
.attr("x1", legendX)
.attr("x2", legendX + 20)
.attr("y1", ly)
.attr("y2", ly)
.attr("stroke", dim.color)
.attr("stroke-width", 2);
g.append("text")
.attr("x", legendX + 28)
.attr("y", ly + 4)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(dim.label);
});
return svg.node();
}
```
```{ojs}
//| echo: false
// Position Similarity Matrix
peSimilarityViz = {
const width = 500;
const height = 420;
const margin = { top: 55, right: 80, bottom: 55, left: 65 };
const theme = diagramTheme;
const data = pe_similarity;
const n = data.length;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Positional Encoding Similarity Matrix");
svg.append("text")
.attr("x", width / 2)
.attr("y", 40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("(Nearby positions are more similar)");
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
const cellSize = Math.min(innerWidth, innerHeight) / n;
// Color scale - Blues
const colorScale = d3.scaleSequential(d3.interpolateBlues)
.domain([0, 1]);
// Draw cells
data.forEach((row, i) => {
row.forEach((val, j) => {
g.append("rect")
.attr("x", j * cellSize)
.attr("y", i * cellSize)
.attr("width", cellSize - 1)
.attr("height", cellSize - 1)
.attr("fill", colorScale(val))
.attr("rx", 1);
});
});
// Axes labels
g.append("text")
.attr("x", n * cellSize / 2)
.attr("y", n * cellSize + 35)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Position");
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -n * cellSize / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Position");
// Tick labels
[0, 5, 10, 15, 19].forEach(t => {
g.append("text")
.attr("x", t * cellSize + cellSize / 2)
.attr("y", n * cellSize + 15)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(t);
g.append("text")
.attr("x", -10)
.attr("y", t * cellSize + cellSize / 2 + 3)
.attr("text-anchor", "end")
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text(t);
});
// Legend
const legendX = width - 55;
const legendY = margin.top;
const legendH = n * cellSize;
const legendW = 15;
const legendGrad = svg.append("defs")
.append("linearGradient")
.attr("id", "sim-legend")
.attr("x1", "0%")
.attr("y1", "100%")
.attr("x2", "0%")
.attr("y2", "0%");
legendGrad.append("stop").attr("offset", "0%").attr("stop-color", colorScale(0));
legendGrad.append("stop").attr("offset", "100%").attr("stop-color", colorScale(1));
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY)
.attr("width", legendW)
.attr("height", legendH)
.attr("fill", "url(#sim-legend)")
.attr("rx", 3);
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + 8)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("1.0");
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + legendH)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("0.0");
return svg.node();
}
```
Notice that:
1. **Nearby positions are similar**: Positions 5 and 6 are more similar than positions 5 and 15
2. **The pattern is symmetric**: sim(i, j) = sim(j, i)
3. **Each position is unique**: No two positions have identical encodings
The sinusoidal encoding also has a key mathematical property: for any fixed offset k, the encoding `PE(pos+k)` can be expressed as a linear transformation of `PE(pos)`. This helps the model learn relative positions (e.g., "this token is 3 positions before that token").
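This property can be checked numerically (a sketch using a small 8-dimensional encoding): each (sin, cos) pair at frequency `w` rotates by angle `w*k` when the position advances by `k`, so the linear map `M_k` is block diagonal with 2x2 rotation blocks:

```{python}
import math
import torch

# Verify PE(pos + k) = M_k @ PE(pos): each (sin, cos) frequency pair
# rotates by angle w*k, so M_k is block-diagonal with 2x2 rotations.
d, k, pos = 8, 3, 7
freqs = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
p = torch.arange(32, dtype=torch.float).unsqueeze(1)
pe_small = torch.zeros(32, d)
pe_small[:, 0::2] = torch.sin(p * freqs)
pe_small[:, 1::2] = torch.cos(p * freqs)

# Build M_k: one rotation block per frequency pair
M = torch.zeros(d, d)
for i, w in enumerate(freqs):
    c, s = math.cos(w * k), math.sin(w * k)
    M[2*i:2*i+2, 2*i:2*i+2] = torch.tensor([[c, s], [-s, c]])

match = torch.allclose(M @ pe_small[pos], pe_small[pos + k], atol=1e-5)
print(f"PE(pos+k) == M_k @ PE(pos): {match}")
```

Crucially, `M_k` depends only on the offset `k`, not on `pos`, which is what makes relative-position reasoning learnable with a single linear map.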
### Combined Transformer Embedding
We add token embeddings and positional embeddings together. The implementation in `embeddings.py` handles:
1. **Scaling by sqrt(embed_dim)**: Token embeddings are multiplied by `sqrt(embed_dim)` before adding positional embeddings. This prevents the positional signal from dominating when embed_dim is large (since embeddings are typically initialized with small values like `std=0.02`).
2. **Initialization**: We initialize embeddings from a normal distribution with small standard deviation (0.02). Small initialization prevents exploding gradients.
3. **Padding token handling**: The embedding for the padding token (usually ID 0) is set to zeros and excluded from gradient updates.
4. **Dropout**: Dropout follows the combination step for regularization.
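Point 3 relies on a feature of `nn.Embedding` itself: passing `padding_idx` initializes that row to zeros and zeroes its gradient, so the padding vector never moves during training. A minimal check:

```{python}
import torch
import torch.nn as nn

# padding_idx zeroes the pad row and blocks its gradient updates.
pad_emb = nn.Embedding(10, 4, padding_idx=0)
print(f"Pad row is zero: {torch.all(pad_emb.weight[0] == 0).item()}")

out = pad_emb(torch.tensor([[0, 3]]))
out.sum().backward()
print(f"Pad row gradient: {pad_emb.weight.grad[0].tolist()}")   # all zeros
print(f"Token 3 gradient: {pad_emb.weight.grad[3].tolist()}")   # all ones
```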
```{python}
class TransformerEmbedding(nn.Module):
"""Combined token + positional embedding."""
def __init__(self, vocab_size: int, embed_dim: int, max_seq_len: int, dropout: float = 0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(embed_dim)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
seq_len = token_ids.shape[1]
# Get token embeddings and scale
token_emb = self.token_embedding(token_ids) * self.scale
# Get positional embeddings
positions = torch.arange(seq_len, device=token_ids.device)
pos_emb = self.position_embedding(positions)
# Combine and apply dropout
return self.dropout(token_emb + pos_emb)
# Create embedding layer
emb = TransformerEmbedding(
vocab_size=1000,
embed_dim=64,
max_seq_len=128,
dropout=0.0 # Disable for visualization
)
# Process some tokens
tokens = torch.randint(0, 1000, (1, 10))
output = emb(tokens)
print(f"Input tokens: {tokens[0].tolist()}")
print(f"Output shape: {tuple(output.shape)}")
```
```{python}
# Show that same token at different positions has different embeddings
# Put token 42 at positions 0, 5, and 9
tokens = torch.tensor([[42, 1, 2, 3, 4, 42, 6, 7, 8, 42]])
output = emb(tokens)
# Get the embeddings for token 42 at each position
pos_0 = output[0, 0].detach()
pos_5 = output[0, 5].detach()
pos_9 = output[0, 9].detach()
print("Token 42 at different positions:")
print(f" Position 0: {pos_0[:4].tolist()}")
print(f" Position 5: {pos_5[:4].tolist()}")
print(f" Position 9: {pos_9[:4].tolist()}")
print(f"\nAll different due to positional encoding!")
```
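We can put a number on "all different" with cosine similarity between the same token's combined embedding at two positions. This sketch uses freshly initialized tables (the names `tok_table` and `pos_table` are illustrative), so the exact value varies from run to run:

```{python}
import torch
import torch.nn as nn

# Same token, two positions: the positional term pulls the combined
# vectors apart, so cosine similarity drops below 1.
tok_table = nn.Embedding(100, 32)
pos_table = nn.Embedding(16, 32)
tok = torch.tensor(42)
v0 = tok_table(tok) + pos_table(torch.tensor(0))
v5 = tok_table(tok) + pos_table(torch.tensor(5))
sim = torch.nn.functional.cosine_similarity(v0, v5, dim=0)
print(f"cos(v0, v5) = {sim.item():.3f}  (same token, different positions)")
```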
```{python}
#| output: false
# Prepare data for OJS visualization
token_only = emb.token_embedding(tokens)[0].detach().numpy()
positions = torch.arange(10)
pos_only = emb.position_embedding(positions).detach().numpy()
combined = output[0].detach().numpy()
ojs_define(
token_emb_only=token_only.tolist(),
pos_emb_only=pos_only.tolist(),
combined_emb=combined.tolist()
)
```
```{ojs}
//| echo: false
// Token + Position + Combined embedding visualization
embeddingCombinationViz = {
const width = 700;
const height = 260;
const margin = { top: 45, right: 20, bottom: 45, left: 50 };
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Data
const datasets = [
{ data: token_emb_only, title: "Token Embeddings Only" },
{ data: pos_emb_only, title: "Positional Embeddings Only" },
{ data: combined_emb, title: "Token + Position (Combined)" }
];
// Find global min/max for consistent color scale
const allVals = [...token_emb_only.flat(), ...pos_emb_only.flat(), ...combined_emb.flat()];
const absMax = Math.max(Math.abs(Math.min(...allVals)), Math.abs(Math.max(...allVals)));
const colorScale = d3.scaleSequential(d3.interpolateRdBu).domain([absMax, -absMax]);
const panelWidth = (width - margin.left - margin.right - 40) / 3;
const panelHeight = height - margin.top - margin.bottom;
datasets.forEach((ds, panelIdx) => {
const data = ds.data;
const numRows = data.length;
const numCols = data[0].length;
const cellW = panelWidth / numCols;
const cellH = panelHeight / numRows;
const panelX = margin.left + panelIdx * (panelWidth + 20);
const panelY = margin.top;
const g = svg.append("g")
.attr("transform", `translate(${panelX}, ${panelY})`);
// Title
svg.append("text")
.attr("x", panelX + panelWidth / 2)
.attr("y", 25)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(ds.title);
// Draw cells
data.forEach((row, i) => {
row.forEach((val, j) => {
g.append("rect")
.attr("x", j * cellW)
.attr("y", i * cellH)
.attr("width", cellW)
.attr("height", cellH)
.attr("fill", colorScale(val));
});
});
// X label
g.append("text")
.attr("x", panelWidth / 2)
.attr("y", panelHeight + 30)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Dimension");
// Y label (only for first panel)
if (panelIdx === 0) {
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -panelHeight / 2)
.attr("y", -35)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Position");
}
});
return svg.node();
}
```
### Embedding Similarity
Embeddings capture meaning: similar tokens should have similar embeddings.
```{python}
#| output: false
# Let's simulate "training" by manually setting some embeddings to be similar
# In practice, these patterns emerge from training on real text
vocab_size = 20
embed_dim = 16
token_emb = nn.Embedding(vocab_size, embed_dim)
# Manually set some tokens to have similar embeddings
# (simulating what would happen after training on related words)
with torch.no_grad():
# Tokens 0-4: "numbers" (similar to each other)
base_number = torch.randn(embed_dim)
for i in range(5):
token_emb.weight[i] = base_number + torch.randn(embed_dim) * 0.1
# Tokens 5-9: "letters" (similar to each other, different from numbers)
base_letter = torch.randn(embed_dim)
for i in range(5, 10):
token_emb.weight[i] = base_letter + torch.randn(embed_dim) * 0.1
# Compute all pairwise similarities
all_embeds = token_emb.weight[:10]
all_embeds_norm = all_embeds / all_embeds.norm(dim=1, keepdim=True)
similarity_matrix = (all_embeds_norm @ all_embeds_norm.T).detach().numpy()
# PCA via SVD for 2D visualization (no sklearn needed)
embeddings_np = all_embeds.detach().numpy()
centered = embeddings_np - embeddings_np.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
emb_2d = U[:, :2] * S[:2] # Project to 2D
variance_explained = (S**2) / (S**2).sum()
ojs_define(
token_similarity=similarity_matrix.tolist(),
pca_coords=emb_2d.tolist(),
pca_variance=[float(variance_explained[0]), float(variance_explained[1])]
)
print("Notice: Numbers (N) are similar to each other, letters (L) are similar to each other,")
print("but numbers and letters are different from each other.")
```
```{ojs}
//| echo: false
// Token Embedding Similarity Matrix
tokenSimilarityViz = {
const width = 500;
const height = 420;
const margin = { top: 55, right: 80, bottom: 55, left: 65 };
const theme = diagramTheme;
const data = token_similarity;
const n = data.length;
const labels = ['N0', 'N1', 'N2', 'N3', 'N4', 'L0', 'L1', 'L2', 'L3', 'L4'];
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Token Embedding Similarity");
svg.append("text")
.attr("x", width / 2)
.attr("y", 40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text('(0-4: "numbers", 5-9: "letters")');
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
const cellSize = Math.min(innerWidth, innerHeight) / n;
// Color scale - RdBu
const colorScale = d3.scaleSequential(d3.interpolateRdBu).domain([1, -1]);
// Draw cells
data.forEach((row, i) => {
row.forEach((val, j) => {
g.append("rect")
.attr("x", j * cellSize)
.attr("y", i * cellSize)
.attr("width", cellSize - 1)
.attr("height", cellSize - 1)
.attr("fill", colorScale(val))
.attr("rx", 2);
});
});
// Tick labels - use theme colors for categories
labels.forEach((label, idx) => {
g.append("text")
.attr("x", idx * cellSize + cellSize / 2)
.attr("y", n * cellSize + 18)
.attr("text-anchor", "middle")
.attr("fill", idx < 5 ? theme.accent : theme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(label);
g.append("text")
.attr("x", -8)
.attr("y", idx * cellSize + cellSize / 2 + 4)
.attr("text-anchor", "end")
.attr("fill", idx < 5 ? theme.accent : theme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(label);
});
// Axis labels
g.append("text")
.attr("x", n * cellSize / 2)
.attr("y", n * cellSize + 38)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Token ID");
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -n * cellSize / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Token ID");
// Legend
const legendX = width - 55;
const legendY = margin.top;
const legendH = n * cellSize;
const legendW = 15;
const legendGrad = svg.append("defs")
.append("linearGradient")
.attr("id", "tok-sim-legend")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "0%")
.attr("y2", "100%");
legendGrad.append("stop").attr("offset", "0%").attr("stop-color", colorScale(-1));
legendGrad.append("stop").attr("offset", "50%").attr("stop-color", colorScale(0));
legendGrad.append("stop").attr("offset", "100%").attr("stop-color", colorScale(1));
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY)
.attr("width", legendW)
.attr("height", legendH)
.attr("fill", "url(#tok-sim-legend)")
.attr("rx", 3);
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + 8)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("-1");
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + legendH / 2 + 3)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("0");
svg.append("text")
.attr("x", legendX + legendW + 5)
.attr("y", legendY + legendH)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.text("+1");
return svg.node();
}
```
### 2D Visualization with PCA
Projecting the 16-dimensional embeddings onto their first two principal components (computed via SVD above) makes the two clusters visible directly:
```{ojs}
//| echo: false
// PCA 2D Visualization
pcaViz = {
const width = 550;
const height = 450;
const margin = { top: 55, right: 100, bottom: 55, left: 65 };
const theme = diagramTheme;
const coords = pca_coords;
const variance = pca_variance;
const labels = ['N0', 'N1', 'N2', 'N3', 'N4', 'L0', 'L1', 'L2', 'L3', 'L4'];
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Token Embeddings in 2D");
svg.append("text")
.attr("x", width / 2)
.attr("y", 40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("(Similar tokens cluster together)");
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Compute scales
const xVals = coords.map(c => c[0]);
const yVals = coords.map(c => c[1]);
const xPad = (Math.max(...xVals) - Math.min(...xVals)) * 0.15;
const yPad = (Math.max(...yVals) - Math.min(...yVals)) * 0.15;
const xScale = d3.scaleLinear()
.domain([Math.min(...xVals) - xPad, Math.max(...xVals) + xPad])
.range([0, innerWidth]);
const yScale = d3.scaleLinear()
.domain([Math.min(...yVals) - yPad, Math.max(...yVals) + yPad])
.range([innerHeight, 0]);
// Grid
g.append("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", yScale(0))
.attr("y2", yScale(0))
.attr("stroke", theme.nodeStroke)
.attr("stroke-dasharray", "4,4")
.attr("opacity", 0.4);
g.append("line")
.attr("x1", xScale(0))
.attr("x2", xScale(0))
.attr("y1", 0)
.attr("y2", innerHeight)
.attr("stroke", theme.nodeStroke)
.attr("stroke-dasharray", "4,4")
.attr("opacity", 0.4);
// Axes
g.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale).ticks(5))
.call(g => g.selectAll("text").attr("fill", theme.nodeText).attr("font-size", "9px"))
.call(g => g.selectAll("line").attr("stroke", theme.nodeStroke))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke));
g.append("g")
.call(d3.axisLeft(yScale).ticks(5))
.call(g => g.selectAll("text").attr("fill", theme.nodeText).attr("font-size", "9px"))
.call(g => g.selectAll("line").attr("stroke", theme.nodeStroke))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke));
// Axis labels
g.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text(`PC1 (${(variance[0] * 100).toFixed(1)}% variance)`);
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text(`PC2 (${(variance[1] * 100).toFixed(1)}% variance)`);
// Draw points - use theme colors for categories
coords.forEach((coord, i) => {
const isNumber = i < 5;
const color = isNumber ? theme.accent : theme.highlight;
g.append("circle")
.attr("cx", xScale(coord[0]))
.attr("cy", yScale(coord[1]))
.attr("r", 8)
.attr("fill", color)
.attr("opacity", 0.8);
g.append("text")
.attr("x", xScale(coord[0]) + 12)
.attr("y", yScale(coord[1]) + 4)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(labels[i]);
});
// Legend - use theme colors for categories
const legendX = innerWidth + 20;
const legendY = 30;
g.append("circle")
.attr("cx", legendX + 8)
.attr("cy", legendY)
.attr("r", 6)
.attr("fill", theme.accent);
g.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 4)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Numbers");
g.append("circle")
.attr("cx", legendX + 8)
.attr("cy", legendY + 24)
.attr("r", 6)
.attr("fill", theme.highlight);
g.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 28)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Letters");
return svg.node();
}
```
## Interactive Exploration
Explore how sinusoidal position encodings create unique patterns for each position. The key insight: low dimensions change rapidly (high frequency), while high dimensions change slowly (low frequency).
```{ojs}
//| echo: false
// Sinusoidal position encoding formula
function positionEncoding(pos, dim, embedDim) {
const i = Math.floor(dim / 2);
const freq = 1 / Math.pow(10000, (2 * i) / embedDim);
if (dim % 2 === 0) {
return Math.sin(pos * freq);
} else {
return Math.cos(pos * freq);
}
}
// Generate full PE matrix
function generatePEMatrix(maxPos, embedDim) {
const matrix = [];
for (let pos = 0; pos < maxPos; pos++) {
for (let dim = 0; dim < embedDim; dim++) {
matrix.push({
pos,
dim,
value: positionEncoding(pos, dim, embedDim)
});
}
}
return matrix;
}
// Generate encoding for a single position
function getPositionVector(pos, embedDim) {
const vec = [];
for (let dim = 0; dim < embedDim; dim++) {
vec.push({
dim,
value: positionEncoding(pos, dim, embedDim),
type: dim % 2 === 0 ? "sin" : "cos"
});
}
return vec;
}
// Cosine similarity between two positions
function cosineSimilarity(pos1, pos2, embedDim) {
let dotProduct = 0;
let norm1 = 0;
let norm2 = 0;
for (let dim = 0; dim < embedDim; dim++) {
const v1 = positionEncoding(pos1, dim, embedDim);
const v2 = positionEncoding(pos2, dim, embedDim);
dotProduct += v1 * v2;
norm1 += v1 * v1;
norm2 += v2 * v2;
}
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
}
```
```{ojs}
//| echo: false
// Theme colors for light/dark mode - use diagramTheme from _diagram-lib.qmd
// This provides consistent theming with CSS variables
theme = {
textPrimary: diagramTheme.nodeText,
textMuted: diagramTheme.edgeStroke,
ruleStroke: diagramTheme.nodeStroke,
highlightStroke: diagramTheme.nodeText,
lineBlue: diagramTheme.accent,
dotSin: diagramTheme.accent,
dotCos: diagramTheme.highlight,
compareSecondary: diagramTheme.highlight,
statusGreen: diagramTheme.highlight,
statusAmber: diagramTheme.accent,
statusGray: diagramTheme.edgeStroke
}
```
```{ojs}
//| echo: false
// Configuration
maxPos = 64
embedDim = 64
viewof selectedPos = Inputs.range([0, maxPos - 1], {
value: 0,
step: 1,
label: "Highlight Position"
})
viewof comparePos = Inputs.range([0, maxPos - 1], {
value: 5,
step: 1,
label: "Compare With Position"
})
```
```{ojs}
//| echo: false
// Generate data
peMatrix = generatePEMatrix(maxPos, embedDim)
selectedVector = getPositionVector(selectedPos, embedDim)
compareVector = getPositionVector(comparePos, embedDim)
similarity = cosineSimilarity(selectedPos, comparePos, embedDim)
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
// Heatmap of full PE matrix
Plot.plot({
title: "Position Encoding Matrix",
subtitle: `${maxPos} positions × ${embedDim} dimensions | Position ${selectedPos} highlighted`,
width: 650,
height: 300,
marginLeft: 60,
marginBottom: 50,
x: {
label: "Dimension →",
tickSpacing: 8
},
y: {
label: "↑ Position",
tickSpacing: 8
},
color: {
scheme: "RdBu",
domain: [-1, 1],
legend: true,
label: "Value"
},
marks: [
Plot.cell(peMatrix, {
x: "dim",
y: "pos",
fill: "value",
tip: true,
title: d => `Position ${d.pos}, Dim ${d.dim}\nValue: ${d.value.toFixed(3)}`
}),
// Highlight selected position row
Plot.ruleY([selectedPos], {
stroke: theme.highlightStroke,
strokeWidth: 2
})
]
})
```
```{ojs}
//| echo: false
// Line plot for selected position
Plot.plot({
title: `Position ${selectedPos} Encoding Vector`,
subtitle: "Each position has a unique pattern of sin/cos values",
width: 650,
height: 200,
marginLeft: 50,
x: {
label: "Dimension →"
},
y: {
label: "↑ Value",
domain: [-1.2, 1.2]
},
marks: [
Plot.ruleY([0], { stroke: theme.ruleStroke }),
Plot.line(selectedVector, {
x: "dim",
y: "value",
stroke: theme.lineBlue,
strokeWidth: 2
}),
Plot.dot(selectedVector, {
x: "dim",
y: "value",
fill: d => d.type === "sin" ? theme.dotSin : theme.dotCos,
r: 3,
tip: true,
title: d => `Dim ${d.dim} (${d.type}): ${d.value.toFixed(3)}`
})
]
})
```
```{ojs}
//| echo: false
// Similarity comparison
md`### Position Similarity
**Position ${selectedPos}** vs **Position ${comparePos}**: Cosine Similarity = **${similarity.toFixed(4)}**
${Math.abs(selectedPos - comparePos) <= 3
? `<span style="color: ${theme.statusGreen}">✓ Nearby positions (distance ${Math.abs(selectedPos - comparePos)}) have high similarity</span>`
: Math.abs(selectedPos - comparePos) >= 20
? `<span style="color: ${theme.statusAmber}">→ Distant positions (distance ${Math.abs(selectedPos - comparePos)}) have lower similarity</span>`
: `<span style="color: ${theme.statusGray}">→ Moderate distance (${Math.abs(selectedPos - comparePos)} apart)</span>`
}`
```
```{ojs}
//| echo: false
// Show both position vectors overlaid for comparison
Plot.plot({
title: `Comparing Position ${selectedPos} vs Position ${comparePos}`,
width: 650,
height: 180,
marginLeft: 50,
x: {
label: "Dimension →"
},
y: {
label: "↑ Value",
domain: [-1.2, 1.2]
},
color: {
domain: [`Position ${selectedPos}`, `Position ${comparePos}`],
range: [theme.lineBlue, theme.compareSecondary]
},
marks: [
Plot.ruleY([0], { stroke: theme.ruleStroke }),
Plot.line(selectedVector.map(d => ({...d, position: `Position ${selectedPos}`})), {
x: "dim",
y: "value",
stroke: "position",
strokeWidth: 2
}),
Plot.line(compareVector.map(d => ({...d, position: `Position ${comparePos}`})), {
x: "dim",
y: "value",
stroke: "position",
strokeWidth: 2,
strokeDasharray: "4,2"
})
]
})
```
::: {.callout-tip}
## Try This
1. **Frequency gradient**: Look at the heatmap from left to right. Low dimensions (left) have rapid oscillation, high dimensions (right) change slowly.
2. **Adjacent positions**: Set the positions to 0 and 1. Notice the high similarity (close to 1); the encodings differ only in the fastest-oscillating low dimensions.
3. **Distant positions**: Compare positions 0 and 32. Similarity drops noticeably because more of the sinusoidal components have moved out of phase.
4. **Unique fingerprints**: Slide through different positions in the line plot. Each position has a unique "fingerprint" pattern.
5. **Sin/Cos pairs**: In the line plot, the dots alternate between sin values (even dims) and cos values (odd dims), colored by type. Each sin/cos pair shares a frequency, with the two waves 90° out of phase.
:::
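The same encoding can be checked outside the browser. This is a minimal NumPy sketch mirroring the `positionEncoding` formula in the OJS cells above (`sinusoidal_pe` and `cos_sim` are illustrative helpers, not part of this module's `embeddings.py`):

```{python}
# Sinusoidal PE in NumPy, mirroring the interactive formula above
import numpy as np

def sinusoidal_pe(max_pos, embed_dim):
    pos = np.arange(max_pos)[:, None]        # (max_pos, 1)
    i = np.arange(embed_dim // 2)[None, :]   # (1, embed_dim // 2)
    freq = 1.0 / (10000 ** (2 * i / embed_dim))
    pe = np.empty((max_pos, embed_dim))
    pe[:, 0::2] = np.sin(pos * freq)         # even dims: sin
    pe[:, 1::2] = np.cos(pos * freq)         # odd dims: cos
    return pe

def cos_sim(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

pe = sinusoidal_pe(64, 64)
print(f"sim(pos 0, pos 1)  = {cos_sim(pe[0], pe[1]):.4f}")   # near 1
print(f"sim(pos 0, pos 32) = {cos_sim(pe[0], pe[32]):.4f}")  # noticeably lower
```

Running it confirms the callout's observations: adjacent positions are nearly parallel, while distant positions are noticeably less similar.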
## Exercises
### Exercise 1: Compare Learned vs Sinusoidal Positional Embeddings
Learned and sinusoidal embeddings have different tradeoffs:
| Aspect | Learned | Sinusoidal |
|--------|---------|------------|
| Training | Updated via backprop | Fixed (no parameters) |
| Extrapolation | Cannot generalize beyond max_seq_len | Can theoretically extrapolate |
| Memory | Adds parameters | Zero parameter overhead |
| Performance | Outperforms sinusoidal in most benchmarks | Good baseline |
```{python}
#| output: false
# Compare learned vs sinusoidal positional embeddings
learned = nn.Embedding(50, 32) # Learned (random initialization)
sinusoidal = create_sinusoidal_encoding(50, 32) # Fixed pattern
ojs_define(
learned_pe=learned.weight.detach().numpy().tolist(),
sinusoidal_pe=sinusoidal.numpy().tolist()
)
print("Learned embeddings start random but are trained to capture position.")
print("Sinusoidal embeddings have a fixed pattern that encodes relative positions.")
```
```{ojs}
//| echo: false
// Learned vs Sinusoidal Positional Embeddings Comparison
learnedVsSinusoidalViz = {
const width = 700;
const height = 280;
const margin = { top: 50, right: 20, bottom: 45, left: 55 };
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Data
const datasets = [
{ data: learned_pe, title: "Learned Positional Embeddings", subtitle: "(Random initialization)" },
{ data: sinusoidal_pe, title: "Sinusoidal Positional Embeddings", subtitle: "(Fixed pattern)" }
];
// Find global min/max for consistent color scale
const allVals = [...learned_pe.flat(), ...sinusoidal_pe.flat()];
const absMax = Math.max(Math.abs(Math.min(...allVals)), Math.abs(Math.max(...allVals)));
const colorScale = d3.scaleSequential(d3.interpolateRdBu).domain([absMax, -absMax]);
const panelWidth = (width - margin.left - margin.right - 30) / 2;
const panelHeight = height - margin.top - margin.bottom;
datasets.forEach((ds, panelIdx) => {
const data = ds.data;
const numRows = data.length;
const numCols = data[0].length;
const cellW = panelWidth / numCols;
const cellH = panelHeight / numRows;
const panelX = margin.left + panelIdx * (panelWidth + 30);
const panelY = margin.top;
const g = svg.append("g")
.attr("transform", `translate(${panelX}, ${panelY})`);
// Title
svg.append("text")
.attr("x", panelX + panelWidth / 2)
.attr("y", 20)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(ds.title);
svg.append("text")
.attr("x", panelX + panelWidth / 2)
.attr("y", 35)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(ds.subtitle);
// Draw cells
data.forEach((row, i) => {
row.forEach((val, j) => {
g.append("rect")
.attr("x", j * cellW)
.attr("y", i * cellH)
.attr("width", cellW)
.attr("height", cellH)
.attr("fill", colorScale(val));
});
});
// X label
g.append("text")
.attr("x", panelWidth / 2)
.attr("y", panelHeight + 30)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Dimension");
// Y label
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -panelHeight / 2)
.attr("y", -40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Position");
});
return svg.node();
}
```
### Exercise 2: Effect of Embedding Dimension
The embedding dimension affects both model capacity and computational cost. Larger dimensions can represent more nuanced semantic distinctions but require more memory and computation in every layer of the model.
The `sqrt(embed_dim)` scaling factor is important: it keeps the magnitude of the token embeddings comparable to that of the position encodings across different dimensions, so neither component drowns out the other when the two are added together.
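As a rough magnitude check (the 0.02 init std below is an assumption matching common GPT-style initialization, not taken from this module's code): a sinusoidal PE row always has norm `sqrt(embed_dim / 2)`, since each sin/cos pair contributes exactly 1.

```{python}
# Compare a small-std token embedding's norm against a sinusoidal PE row's norm,
# before and after the sqrt(embed_dim) scaling. The 0.02 std is an assumption.
import math
import torch

for d in [64, 256, 1024]:
    emb_row = torch.randn(d) * 0.02   # one token's randomly initialized embedding
    raw = emb_row.norm().item()
    scaled = raw * math.sqrt(d)       # after the sqrt(embed_dim) scaling
    pe_norm = math.sqrt(d / 2)        # exact norm of any sinusoidal PE row
    print(f"d={d:4d}: emb norm {raw:6.2f} -> scaled {scaled:7.2f}  (PE row norm {pe_norm:6.2f})")
```

Without the scaling, the embedding's norm stays tens of times smaller than the PE row at every dimension; with it, the two are of the same order.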
```{python}
# What happens with different embedding dimensions?
for dim in [8, 32, 128, 512]:
emb = TransformerEmbedding(
vocab_size=1000,
embed_dim=dim,
max_seq_len=128,
dropout=0.0
)
tokens = torch.randint(0, 1000, (1, 32))
output = emb(tokens)
# Compute variance of output
variance = output.var().item()
print(f"Embed dim {dim:3d}: output variance = {variance:.4f}")
print("\nThe scale factor (sqrt(embed_dim)) helps keep variance stable!")
```
### Exercise 3: Memory Usage of Embeddings
Embedding memory is a major factor in model sizing. With weight tying (sharing one matrix between the input embedding and the output projection), you pay this cost once; without it, you pay it twice.
```{python}
# Memory usage of embeddings
configs = [
{"vocab": 1000, "dim": 64, "name": "Tiny"},
{"vocab": 8000, "dim": 256, "name": "Small"},
{"vocab": 32000, "dim": 512, "name": "Medium"},
{"vocab": 50000, "dim": 768, "name": "Large (GPT-2)"},
{"vocab": 100000, "dim": 4096, "name": "Large (LLaMA)"},
]
print("Embedding Table Memory Usage:")
print("=" * 60)
for cfg in configs:
params = cfg["vocab"] * cfg["dim"]
memory_mb = params * 4 / (1024 * 1024) # 4 bytes per float32
memory_fp16 = memory_mb / 2 # fp16/bf16 halves memory
print(f"{cfg['name']:15s}: {cfg['vocab']:6d} vocab x {cfg['dim']:4d} dim = "
f"{params:>12,} params ({memory_mb:>7.1f} MB fp32, {memory_fp16:>6.1f} MB fp16)")
```
Note: Modern models typically use fp16 or bf16, which halves the memory requirement. Quantization (int8, int4) can reduce it further.
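Extending the calculation above to quantized formats, using the standard bytes-per-parameter figures (1 for int8, 0.5 for int4):

```{python}
# Embedding memory for a GPT-2-sized table under common precisions
vocab, dim = 50257, 768
params = vocab * dim
for name, bytes_per in [("fp32", 4), ("fp16/bf16", 2), ("int8", 1), ("int4", 0.5)]:
    print(f"{name:10s}: {params * bytes_per / 1024**2:8.1f} MB")
```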
**Weight tying** shares the embedding matrix between the input layer and output projection, halving the embedding parameter count:
```{python}
# Weight tying: share embedding weights with output projection
# This is what GPT-2, LLaMA, and most modern LLMs do
import torch.nn as nn
class SimpleLMWithWeightTying(nn.Module):
def __init__(self, vocab_size, embed_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
# Output projection shares weights with embedding (transposed)
self.output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
# Tie weights: output projection uses the same weights as embedding
self.output_proj.weight = self.embedding.weight
def forward(self, x):
emb = self.embedding(x) # (batch, seq, embed_dim)
logits = self.output_proj(emb) # (batch, seq, vocab_size)
return logits
model = SimpleLMWithWeightTying(vocab_size=1000, embed_dim=256)
print(f"Embedding params: {model.embedding.weight.numel():,}")
print(f"Output proj params: {model.output_proj.weight.numel():,}")
print(f"Are weights shared? {model.embedding.weight is model.output_proj.weight}")
```
## Using the Module's Embeddings
The `embeddings.py` file contains production-ready embedding classes:
```{python}
from embeddings import (
TokenEmbedding,
LearnedPositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerEmbedding as ModuleTransformerEmbedding,
demonstrate_embeddings
)
# Run the demonstration
demo_emb = demonstrate_embeddings(
vocab_size=100,
embed_dim=32,
seq_len=8,
verbose=True
)
```
## Summary
Key takeaways:
1. **Token embeddings** are lookup tables that convert token IDs to vectors
2. **Positional embeddings** add information about where tokens are in the sequence
3. **Sinusoidal** positional embeddings use fixed sin/cos patterns - no parameters, can theoretically extrapolate
4. **Learned** positional embeddings are trained like any other parameter - outperform sinusoidal in most benchmarks
5. **Modern approaches** (RoPE, ALiBi) handle position differently and extrapolate better to long sequences
6. **Similar tokens** end up with similar embeddings after training - capturing semantic relationships
7. **Scaling by sqrt(embed_dim)** helps maintain stable gradients when dimensions vary
8. **Weight tying** between input embeddings and output layer is common and reduces parameters
### Common Pitfalls
- **Forgetting the sqrt scale**: Without this scaling, positional embeddings can dominate or be ignored depending on embed_dim
- **Exceeding max_seq_len**: Learned positional embeddings fail outright on sequences longer than those seen in training; the lookup table simply has no rows beyond `max_seq_len`
- **Ignoring padding**: Padding tokens should map to zero vectors and be excluded from gradient updates (e.g. via `padding_idx` in `nn.Embedding`)
- **Poor initialization**: Large initial values cause training instability; use small std (0.01-0.02)
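The padding pitfall above is handled directly by PyTorch's `nn.Embedding` via its `padding_idx` argument, which zeroes that row at initialization and keeps its gradient at zero. (Treating token ID 0 as padding is a convention here, not universal.)

```{python}
# padding_idx zeroes the padding row and freezes it during training
import torch
import torch.nn as nn

PAD = 0  # assume token ID 0 is the padding token
emb = nn.Embedding(100, 16, padding_idx=PAD)
print("pad row is zero:", bool((emb.weight[PAD] == 0).all()))

# Gradients for the padding row stay zero, so it is never updated
tokens = torch.tensor([[PAD, 5, 7, PAD]])
emb(tokens).sum().backward()
print("pad row grad sum:", emb.weight.grad[PAD].abs().sum().item())   # 0.0
print("real row grad sum:", emb.weight.grad[5].abs().sum().item())    # nonzero
```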
## Going Deeper
- [Word2Vec](https://arxiv.org/abs/1301.3781) - Original word embeddings paper (Mikolov et al., 2013)
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - Section 3.5 on positional encoding
- [RoFormer: Rotary Position Embedding (RoPE)](https://arxiv.org/abs/2104.09864) - Used in LLaMA, Mistral
- [ALiBi: Train Short, Test Long](https://arxiv.org/abs/2108.12409) - Position as attention bias
- [Using the Output Embedding to Improve Language Models](https://arxiv.org/abs/1608.05859) - Weight tying paper
## What's Next
[Module 05: Attention](../m05_attention/lesson.qmd) covers the core mechanism that lets tokens "look at" each other. Attention operates directly on the embedding vectors built here, and the position information they carry becomes essential for computing relationships between tokens.