---
title: "Module 06: Transformer"
format:
html:
code-fold: false
toc: true
ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
{{< include ../_components/step-control.qmd >}}
## Introduction
The **transformer decoder block** is the building block of GPT-style language models. GPT-2, GPT-3, and LLaMA stack 12 to 96 of these blocks.
This module combines everything built so far:
- **Multi-head attention** from Module 05
- **Feed-forward networks** (mini neural networks for each token)
- **Layer normalization** (stabilizes training)
- **Residual connections** (enables deep networks)
Each block performs two operations:
1. **Multi-head attention**: Tokens communicate with each other
2. **Feed-forward network**: Each token is processed independently
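In the pre-norm arrangement used throughout this module, these two operations each wrap around a LayerNorm and a residual add. The sketch below uses toy stand-ins, a causal-averaging "attention" and a random-weight feed-forward, purely to show the data flow; the real components are built step by step later in this module.

```{python}
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token's vector to zero mean, unit variance
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def attention(x):
    # Toy stand-in for multi-head attention: each token takes a
    # uniform average over itself and all earlier tokens (causal)
    seq = x.shape[0]
    mask = np.tril(np.ones((seq, seq)))
    weights = mask / mask.sum(axis=-1, keepdims=True)
    return weights @ x

def feed_forward(x, hidden_mult=4):
    # Toy stand-in for the FFN: random weights, ReLU in place of GELU
    rng = np.random.default_rng(0)
    dim = x.shape[-1]
    w1 = rng.normal(0, 0.02, (dim, hidden_mult * dim))
    w2 = rng.normal(0, 0.02, (hidden_mult * dim, dim))
    return np.maximum(0, x @ w1) @ w2

def block(x):
    x = x + attention(layer_norm(x))      # 1. tokens communicate
    x = x + feed_forward(layer_norm(x))   # 2. per-token processing
    return x

x = np.random.default_rng(1).normal(size=(4, 8))  # (seq_len, embed_dim)
print(block(x).shape)  # shape is preserved: (4, 8)
```

Note that the block maps `(seq_len, embed_dim)` to the same shape, which is what makes stacking N blocks possible.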
### Decoder-Only vs Encoder-Decoder
This module implements **decoder-only** transformers (GPT-style). The two main transformer architectures are:
| Architecture | Examples | Use Case | Attention |
|-------------|----------|----------|-----------|
| Decoder-only | GPT, LLaMA, Claude | Text generation | Causal (can't see future) |
| Encoder-Decoder | T5, BART, original Transformer | Translation, summarization | Bidirectional encoder + causal decoder |
Most modern LLMs use the decoder-only architecture because it is simpler; we follow that approach.
### What You'll Learn
After completing this module, you will be able to:
- Understand the complete GPT-style transformer architecture
- Implement LayerNorm, GELU, and feed-forward networks from scratch
- Build a full transformer block with residual connections
- Assemble a complete language model from components
- Calculate parameter counts for different model sizes
### Prerequisites
This module requires familiarity with:
- [Module 04: Embeddings](../m04_embeddings/lesson.qmd) — Token and positional embeddings
- [Module 05: Attention](../m05_attention/lesson.qmd) — Multi-head attention mechanism
## Complete Model Architecture
```{ojs}
//| echo: false
// Step definitions for the forward pass walkthrough
architectureSteps = [
{
id: 0,
name: "Input",
description: "Token IDs enter the model as integer indices into the vocabulary.",
shape: "[batch, seq_len]",
example: "[1, 4] → 4 token IDs"
},
{
id: 1,
name: "Token Embedding",
description: "Each token ID is mapped to a dense vector via learned embedding lookup.",
shape: "[batch, seq_len, embed_dim]",
example: "[1, 4, 128] → 4 vectors of dim 128"
},
{
id: 2,
name: "Position Embedding",
description: "Position information is added so the model knows token order.",
shape: "[batch, seq_len, embed_dim]",
example: "pos[0..3] added to each token"
},
{
id: 3,
name: "Transformer Blocks",
description: "N stacked blocks refine representations through attention and FFN.",
shape: "[batch, seq_len, embed_dim]",
example: "6 blocks × (attention + FFN)"
},
{
id: 4,
name: "Final LayerNorm",
description: "Normalize activations before the output projection.",
shape: "[batch, seq_len, embed_dim]",
example: "Stabilize for prediction"
},
{
id: 5,
name: "LM Head → Logits",
description: "Project to vocabulary size. Each position predicts the next token.",
shape: "[batch, seq_len, vocab_size]",
example: "[1, 4, 10000] → scores for each word"
}
]
```
```{ojs}
//| echo: false
// Step control for architecture walkthrough
viewof archStep = stepControl({min: 0, max: 5, value: 0, label: "Architecture Step"})
```
```{ojs}
//| echo: false
// Current step info
currentArchStep = architectureSteps[archStep]
```
```{ojs}
//| echo: false
// Draw the GPT architecture diagram
transformerArchitectureDiagram = {
const width = 680;
const height = 620;
const marginLeft = 40;
const marginRight = 180;
const marginTop = 50;
const marginBottom = 30;
// Component dimensions
const nodeWidth = 140;
const nodeHeight = 44;
const blockWidth = 120;
const blockHeight = 36;
const numBlocks = 6;
const blockGap = 6;
// Vertical layout
const inputY = marginTop + 30;
const tokenEmbY = inputY + 70;
const posEmbY = tokenEmbY;
const addY = tokenEmbY + 65;
const blocksStartY = addY + 75;
const blocksEndY = blocksStartY + numBlocks * (blockHeight + blockGap) - blockGap;
const lnFinalY = blocksEndY + 65;
const lmHeadY = lnFinalY + 60;
const logitsY = lmHeadY + 65;
// Horizontal positions
const centerX = marginLeft + (width - marginLeft - marginRight) / 2;
const tokenEmbX = centerX - 80;
const posEmbX = centerX + 80;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Defs for arrows and filters
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "arch-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "arch-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "arch-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Color scheme for component types
const colors = {
input: diagramTheme.accent,
embedding: "#6366f1", // Indigo
blocks: "#059669", // Emerald
output: "#dc2626" // Red
};
// Helper to draw a node
function drawNode(x, y, w, h, label, sublabel, stepId, colorType) {
const isActive = archStep === stepId;
const baseColor = colors[colorType] || diagramTheme.nodeFill;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", isActive ? baseColor : diagramTheme.nodeFill)
.attr("stroke", isActive ? baseColor : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("opacity", isActive ? 1 : 0.85)
.style("filter", isActive ? "url(#arch-glow)" : "none");
const textColor = isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
if (sublabel) {
g.append("text")
.attr("y", -6)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "9px")
.attr("opacity", 0.8)
.text(sublabel);
} else {
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
}
return g;
}
// Helper to draw an arrow
function drawArrow(x1, y1, x2, y2, isActive) {
svg.append("path")
.attr("d", `M${x1},${y1} L${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", isActive ? "url(#arch-arrow-active)" : "url(#arch-arrow)")
.attr("opacity", isActive ? 1 : 0.6);
}
// Title
svg.append("text")
.attr("x", centerX)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "700")
.text("GPT Architecture: Forward Pass");
// === Draw components ===
// Input tokens
drawNode(centerX, inputY, nodeWidth, nodeHeight, "Token IDs", "[23, 156, 42, 789]", 0, "input");
// Arrows from input to embeddings
const splitY = inputY + nodeHeight/2 + 10;
drawArrow(centerX - 20, inputY + nodeHeight/2, tokenEmbX, tokenEmbY - nodeHeight/2 - 8, archStep === 0 || archStep === 1);
drawArrow(centerX + 20, inputY + nodeHeight/2, posEmbX, posEmbY - nodeHeight/2 - 8, archStep === 0 || archStep === 1);
// Token Embedding
drawNode(tokenEmbX, tokenEmbY, nodeWidth, nodeHeight, "Token Embed", "(vocab, embed_dim)", 1, "embedding");
// Position Embedding
drawNode(posEmbX, posEmbY, nodeWidth, nodeHeight, "Pos Embed", "(max_seq, embed_dim)", 2, "embedding");
// Arrows to Add
drawArrow(tokenEmbX, tokenEmbY + nodeHeight/2, centerX - 10, addY - 18 - 8, archStep === 1 || archStep === 2);
drawArrow(posEmbX, posEmbY + nodeHeight/2, centerX + 10, addY - 18 - 8, archStep === 2);
// Add circle
const addActive = archStep === 1 || archStep === 2;
svg.append("circle")
.attr("cx", centerX)
.attr("cy", addY)
.attr("r", 18)
.attr("fill", addActive ? colors.embedding : diagramTheme.nodeFill)
.attr("stroke", addActive ? colors.embedding : diagramTheme.nodeStroke)
.attr("stroke-width", addActive ? 2 : 1.5)
.style("filter", addActive ? "url(#arch-glow)" : "none");
svg.append("text")
.attr("x", centerX)
.attr("y", addY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", addActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "16px")
.attr("font-weight", "bold")
.text("+");
// Arrow to blocks
drawArrow(centerX, addY + 18, centerX, blocksStartY - blockHeight/2 - 8, archStep === 2 || archStep === 3);
// Transformer Blocks (stacked)
const blocksActive = archStep === 3;
// Background container for blocks
const blocksContainerPad = 12;
const blocksContainerHeight = numBlocks * (blockHeight + blockGap) - blockGap + blocksContainerPad * 2;
svg.append("rect")
.attr("x", centerX - blockWidth/2 - blocksContainerPad)
.attr("y", blocksStartY - blockHeight/2 - blocksContainerPad)
.attr("width", blockWidth + blocksContainerPad * 2)
.attr("height", blocksContainerHeight)
.attr("rx", 8)
.attr("fill", "none")
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.nodeStroke)
.attr("stroke-width", blocksActive ? 2 : 1)
.attr("stroke-dasharray", blocksActive ? "none" : "4,3")
.attr("opacity", blocksActive ? 1 : 0.5);
// Label for blocks section
svg.append("text")
.attr("x", centerX + blockWidth/2 + blocksContainerPad + 10)
.attr("y", blocksStartY + blocksContainerHeight/2 - blockHeight/2 - blocksContainerPad)
.attr("text-anchor", "start")
.attr("dominant-baseline", "central")
.attr("fill", blocksActive ? colors.blocks : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", blocksActive ? "600" : "400")
.attr("opacity", blocksActive ? 1 : 0.6)
.text("× N Blocks");
// Individual blocks
for (let i = 0; i < numBlocks; i++) {
const blockY = blocksStartY + i * (blockHeight + blockGap);
const g = svg.append("g")
.attr("transform", `translate(${centerX}, ${blockY})`);
g.append("rect")
.attr("x", -blockWidth/2)
.attr("y", -blockHeight/2)
.attr("width", blockWidth)
.attr("height", blockHeight)
.attr("rx", 5)
.attr("fill", blocksActive ? colors.blocks : diagramTheme.nodeFill)
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.nodeStroke)
.attr("stroke-width", blocksActive ? 2 : 1)
.attr("opacity", blocksActive ? 1 : 0.7);
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", blocksActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(`Block ${i}`);
// Arrow between blocks (except last)
if (i < numBlocks - 1) {
const nextY = blockY + blockHeight/2 + blockGap/2;
svg.append("line")
.attr("x1", centerX)
.attr("y1", blockY + blockHeight/2)
.attr("x2", centerX)
.attr("y2", blockY + blockHeight + blockGap - blockHeight/2)
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.edgeStroke)
.attr("stroke-width", 1)
.attr("opacity", blocksActive ? 0.8 : 0.4);
}
}
// Arrow to Final LayerNorm
drawArrow(centerX, blocksEndY + blockHeight/2, centerX, lnFinalY - nodeHeight/2 - 8, archStep === 3 || archStep === 4);
// Final LayerNorm
drawNode(centerX, lnFinalY, nodeWidth, nodeHeight, "Final LayerNorm", "Normalize", 4, "output");
// Arrow to LM Head
drawArrow(centerX, lnFinalY + nodeHeight/2, centerX, lmHeadY - nodeHeight/2 - 8, archStep === 4 || archStep === 5);
// LM Head
drawNode(centerX, lmHeadY, nodeWidth, nodeHeight, "LM Head", "Linear → vocab_size", 5, "output");
// Arrow to Logits
drawArrow(centerX, lmHeadY + nodeHeight/2, centerX, logitsY - nodeHeight/2 - 8, archStep === 5);
// Logits
drawNode(centerX, logitsY, nodeWidth, nodeHeight, "Logits", "[batch, seq, vocab]", 5, "output");
// === Info panel on the right ===
const infoX = width - marginRight + 20;
const infoY = marginTop + 50;
const infoWidth = marginRight - 30;
// Info panel background
svg.append("rect")
.attr("x", infoX - 10)
.attr("y", infoY - 10)
.attr("width", infoWidth)
.attr("height", 150)
.attr("rx", 8)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
// Step indicator
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 5)
.attr("fill", diagramTheme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "600")
.attr("text-transform", "uppercase")
.attr("letter-spacing", "0.5px")
.text(`Step ${archStep + 1} of 6`);
// Step name
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 28)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "700")
.text(currentArchStep.name);
// Description (wrap text)
const desc = currentArchStep.description;
const descWords = desc.split(" ");
let descLine = "";
let descY = infoY + 50;
const maxDescWidth = infoWidth - 20;
descWords.forEach((word, i) => {
const testLine = descLine + (descLine ? " " : "") + word;
if (testLine.length > 22) {
svg.append("text")
.attr("x", infoX)
.attr("y", descY)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.8)
.text(descLine);
descY += 14;
descLine = word;
} else {
descLine = testLine;
}
});
if (descLine) {
svg.append("text")
.attr("x", infoX)
.attr("y", descY)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.8)
.text(descLine);
}
// Shape info
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 110)
.attr("fill", diagramTheme.accent)
.attr("font-size", "9px")
.attr("font-weight", "500")
.text("Shape:");
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 124)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-family", "'JetBrains Mono', monospace")
.text(currentArchStep.shape);
return svg.node();
}
```
::: {.callout-tip}
## Interactive Architecture Walkthrough
Use the slider above to step through the forward pass. Each stage shows how tensor shapes transform as data flows through the model:
- **Input**: Raw token IDs (integers)
- **Embeddings**: Dense vectors capturing meaning and position
- **Blocks**: Iterative refinement through attention and FFN
- **Output**: Logits over the vocabulary; softmax turns them into next-token probabilities
:::
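The shape flow in the walkthrough can be checked numerically. This sketch uses random weights and the illustrative sizes from the diagram (vocab 10000, embed_dim 128); it skips the transformer blocks (which preserve shape) and the learnable LayerNorm affine, and is not a trained model.

```{python}
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, vocab_size, embed_dim = 1, 4, 10000, 128
max_seq = 32

token_ids = rng.integers(0, vocab_size, (batch, seq_len))     # [1, 4]
tok_emb_table = rng.normal(0, 0.02, (vocab_size, embed_dim))
pos_emb_table = rng.normal(0, 0.02, (max_seq, embed_dim))

x = tok_emb_table[token_ids]                # [1, 4, 128] token embedding lookup
x = x + pos_emb_table[:seq_len]             # [1, 4, 128] add position info
# ... N transformer blocks would refine x here; shape stays the same ...
x = (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)  # final LayerNorm
lm_head = rng.normal(0, 0.02, (embed_dim, vocab_size))
logits = x @ lm_head                        # [1, 4, 10000]
print(token_ids.shape, x.shape, logits.shape)
```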
## Single Transformer Block (Pre-Norm)
```{ojs}
//| echo: false
// Step definitions for transformer block walkthrough
blockSteps = [
{
id: 0,
name: "Input",
description: "Embeddings enter the transformer block unchanged initially.",
highlight: ["input"],
activeEdges: []
},
{
id: 1,
name: "Attention Path",
description: "Pre-Norm: LayerNorm first, then Multi-Head Attention with causal mask, then Dropout.",
highlight: ["ln1", "attn", "drop1"],
activeEdges: ["e_in_ln1", "e_ln1_attn", "e_attn_drop1"]
},
{
id: 2,
name: "First Residual",
description: "Add original input to attention output: x + Attention(LayerNorm(x)). Skip connection preserves gradients.",
highlight: ["add1", "residual1"],
activeEdges: ["e_in_add1", "e_drop1_add1"]
},
{
id: 3,
name: "FFN Path",
description: "LayerNorm, then Feed-Forward Network (expand 4x, GELU, project back). Each token processed independently.",
highlight: ["ln2", "ffn"],
activeEdges: ["e_add1_ln2", "e_ln2_ffn"]
},
{
id: 4,
name: "Second Residual",
description: "Add intermediate to FFN output: x + FFN(LayerNorm(x)). Two residuals per block enable deep stacking.",
highlight: ["add2", "residual2", "output"],
activeEdges: ["e_add1_add2", "e_ffn_add2", "e_add2_out"]
}
]
```
```{ojs}
//| echo: false
// Step control for transformer block walkthrough
viewof blockStep = stepControl({min: 0, max: 4, value: 0, label: "Block Step"})
```
```{ojs}
//| echo: false
// Current step info
currentBlockStep = blockSteps[blockStep]
```
```{ojs}
//| echo: false
// Transformer block diagram
transformerBlockDiagram = {
const width = 720;
const height = 580;
const theme = diagramTheme;
const step = blockStep;
const stepInfo = currentBlockStep;
const svg = d3.create("svg")
.attr("viewBox", `0 0 ${width} ${height}`)
.attr("width", "100%")
.attr("height", height)
.style("max-width", `${width}px`)
.style("font-family", "'IBM Plex Mono', 'JetBrains Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Defs for filters and markers
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "block-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "block-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.edgeStroke);
defs.append("marker")
.attr("id", "block-arrow-attn")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#22d3ee");
defs.append("marker")
.attr("id", "block-arrow-ffn")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#a78bfa");
defs.append("marker")
.attr("id", "block-arrow-res")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#fb923c");
// Color scheme
const colors = {
attention: "#22d3ee", // Cyan for attention path
ffn: "#a78bfa", // Purple for FFN path
residual: "#fb923c", // Orange for residual connections
norm: "#6b7280", // Gray for LayerNorm
io: theme.accent // Theme accent for input/output
};
// Helper: check if element is active
const isActive = (id) => stepInfo.highlight.includes(id);
const isEdgeActive = (id) => stepInfo.activeEdges.includes(id);
// Layout
const centerX = 280;
const rightX = 500;
const nodeWidth = 140;
const nodeHeight = 44;
// Vertical positions
const inputY = 80;
const ln1Y = 160;
const attnY = 230;
const drop1Y = 300;
const add1Y = 365;
const ln2Y = 430;
const ffnY = 490;
const add2Y = 490;
const outputY = 550;
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "15px")
.attr("font-weight", "700")
.text(`Transformer Block: ${stepInfo.name}`);
svg.append("text")
.attr("x", width / 2)
.attr("y", 50)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.75)
.text(stepInfo.description);
// Helper: draw a node box
function drawNode(x, y, label, sublabel, id, color) {
const active = isActive(id);
const g = svg.append("g").attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -nodeWidth/2)
.attr("y", -nodeHeight/2)
.attr("width", nodeWidth)
.attr("height", nodeHeight)
.attr("rx", 6)
.attr("fill", active ? color : theme.nodeFill)
.attr("fill-opacity", active ? 0.25 : 1)
.attr("stroke", active ? color : theme.nodeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("filter", active ? "url(#block-glow)" : null);
g.append("text")
.attr("y", sublabel ? -7 : 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", active ? "600" : "500")
.text(label);
if (sublabel) {
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "9px")
.attr("opacity", active ? 0.9 : 0.6)
.text(sublabel);
}
}
// Helper: draw a circle node (for + operations)
function drawCircle(x, y, label, id, color) {
const active = isActive(id);
const g = svg.append("g").attr("transform", `translate(${x}, ${y})`);
g.append("circle")
.attr("r", 20)
.attr("fill", active ? color : theme.nodeFill)
.attr("fill-opacity", active ? 0.3 : 1)
.attr("stroke", active ? color : theme.nodeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("filter", active ? "url(#block-glow)" : null);
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "18px")
.attr("font-weight", "bold")
.text(label);
}
// Helper: draw an arrow
function drawArrow(x1, y1, x2, y2, edgeId, color, curved = false, curveDir = 1) {
const active = isEdgeActive(edgeId);
const markerColor = color === colors.attention ? "block-arrow-attn" :
color === colors.ffn ? "block-arrow-ffn" :
color === colors.residual ? "block-arrow-res" : "block-arrow";
let pathD;
if (curved) {
const midX = (x1 + x2) / 2 + curveDir * 40;
pathD = `M${x1},${y1} Q${midX},${(y1+y2)/2} ${x2},${y2}`;
} else {
pathD = `M${x1},${y1} L${x2},${y2}`;
}
svg.append("path")
.attr("d", pathD)
.attr("fill", "none")
.attr("stroke", active ? color : theme.edgeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("stroke-opacity", active ? 1 : 0.5)
.attr("marker-end", active ? `url(#${markerColor})` : "url(#block-arrow)")
.attr("filter", active ? "url(#block-glow)" : null);
}
// === Draw the architecture ===
// Input node
drawNode(centerX, inputY, "Input x", "(batch, seq, embed)", "input", colors.io);
// --- Attention Path (left side) ---
// LayerNorm 1
drawNode(centerX, ln1Y, "LayerNorm", null, "ln1", colors.attention);
// Multi-Head Attention
drawNode(centerX, attnY, "Multi-Head Attention", "(causal mask)", "attn", colors.attention);
// Dropout
drawNode(centerX, drop1Y, "Dropout", null, "drop1", colors.attention);
// Arrows in attention path
drawArrow(centerX, inputY + nodeHeight/2, centerX, ln1Y - nodeHeight/2 - 8, "e_in_ln1", colors.attention);
drawArrow(centerX, ln1Y + nodeHeight/2, centerX, attnY - nodeHeight/2 - 8, "e_ln1_attn", colors.attention);
drawArrow(centerX, attnY + nodeHeight/2, centerX, drop1Y - nodeHeight/2 - 8, "e_attn_drop1", colors.attention);
// First residual add
drawCircle(centerX, add1Y, "+", "add1", colors.residual);
// Arrow from dropout to add1
drawArrow(centerX, drop1Y + nodeHeight/2, centerX, add1Y - 20 - 8, "e_drop1_add1", colors.attention);
// Residual connection 1 (skip from input to add1)
const res1Active = isActive("residual1");
// Draw the skip connection path on the right
svg.append("path")
.attr("d", `M${centerX + nodeWidth/2},${inputY} L${centerX + 90},${inputY} L${centerX + 90},${add1Y} L${centerX + 20},${add1Y}`)
.attr("fill", "none")
.attr("stroke", res1Active ? colors.residual : theme.edgeStroke)
.attr("stroke-width", res1Active ? 2.5 : 1.5)
.attr("stroke-dasharray", "6,3")
.attr("stroke-opacity", res1Active ? 1 : 0.4)
.attr("marker-end", res1Active ? "url(#block-arrow-res)" : "url(#block-arrow)")
.attr("filter", res1Active ? "url(#block-glow)" : null);
// Residual label
if (res1Active) {
svg.append("text")
.attr("x", centerX + 100)
.attr("y", (inputY + add1Y) / 2)
.attr("text-anchor", "start")
.attr("fill", colors.residual)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("residual");
}
// --- FFN Path ---
// LayerNorm 2
drawNode(centerX, ln2Y, "LayerNorm", null, "ln2", colors.ffn);
// Feed-Forward Network
drawNode(rightX, ffnY, "Feed-Forward", "(4x expand, GELU)", "ffn", colors.ffn);
// Arrow from add1 to ln2
drawArrow(centerX, add1Y + 20, centerX, ln2Y - nodeHeight/2 - 8, "e_add1_ln2", colors.ffn);
// Arrow from ln2 to ffn (horizontal then down)
const ln2ToFfnActive = isEdgeActive("e_ln2_ffn");
svg.append("path")
.attr("d", `M${centerX + nodeWidth/2},${ln2Y} L${rightX},${ln2Y} L${rightX},${ffnY - nodeHeight/2 - 8}`)
.attr("fill", "none")
.attr("stroke", ln2ToFfnActive ? colors.ffn : theme.edgeStroke)
.attr("stroke-width", ln2ToFfnActive ? 2.5 : 1.5)
.attr("stroke-opacity", ln2ToFfnActive ? 1 : 0.5)
.attr("marker-end", ln2ToFfnActive ? "url(#block-arrow-ffn)" : "url(#block-arrow)")
.attr("filter", ln2ToFfnActive ? "url(#block-glow)" : null);
// Second residual add
drawCircle(centerX, add2Y, "+", "add2", colors.residual);
// Arrow from ffn to add2
const ffnToAddActive = isEdgeActive("e_ffn_add2");
svg.append("path")
.attr("d", `M${rightX - nodeWidth/2},${ffnY} L${centerX + 20},${add2Y}`)
.attr("fill", "none")
.attr("stroke", ffnToAddActive ? colors.ffn : theme.edgeStroke)
.attr("stroke-width", ffnToAddActive ? 2.5 : 1.5)
.attr("stroke-opacity", ffnToAddActive ? 1 : 0.5)
.attr("marker-end", ffnToAddActive ? "url(#block-arrow-ffn)" : "url(#block-arrow)")
.attr("filter", ffnToAddActive ? "url(#block-glow)" : null);
// Residual connection 2 (from add1 to add2 - the skip)
const res2Active = isActive("residual2");
svg.append("path")
.attr("d", `M${centerX - 20},${add1Y} L${centerX - 60},${add1Y} L${centerX - 60},${add2Y} L${centerX - 20},${add2Y}`)
.attr("fill", "none")
.attr("stroke", res2Active ? colors.residual : theme.edgeStroke)
.attr("stroke-width", res2Active ? 2.5 : 1.5)
.attr("stroke-dasharray", "6,3")
.attr("stroke-opacity", res2Active ? 1 : 0.4)
.attr("marker-end", res2Active ? "url(#block-arrow-res)" : "url(#block-arrow)")
.attr("filter", res2Active ? "url(#block-glow)" : null);
// Residual 2 label
if (res2Active) {
svg.append("text")
.attr("x", centerX - 70)
.attr("y", (add1Y + add2Y) / 2)
.attr("text-anchor", "end")
.attr("fill", colors.residual)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("residual");
}
// Output node
drawNode(centerX, outputY, "Output", "(batch, seq, embed)", "output", colors.io);
// Arrow from add2 to output
drawArrow(centerX, add2Y + 20, centerX, outputY - nodeHeight/2 - 8, "e_add2_out", colors.io);
// === Legend ===
const legendX = 600;
const legendY = 120;
svg.append("text")
.attr("x", legendX)
.attr("y", legendY)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("Legend");
// Attention path
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY + 15)
.attr("width", 14)
.attr("height", 14)
.attr("rx", 3)
.attr("fill", colors.attention)
.attr("fill-opacity", 0.3)
.attr("stroke", colors.attention);
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 25)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Attention path");
// FFN path
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY + 38)
.attr("width", 14)
.attr("height", 14)
.attr("rx", 3)
.attr("fill", colors.ffn)
.attr("fill-opacity", 0.3)
.attr("stroke", colors.ffn);
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 48)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("FFN path");
// Residual
svg.append("line")
.attr("x1", legendX)
.attr("y1", legendY + 68)
.attr("x2", legendX + 14)
.attr("y2", legendY + 68)
.attr("stroke", colors.residual)
.attr("stroke-width", 2)
.attr("stroke-dasharray", "4,2");
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 71)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Residual (skip)");
// Pre-Norm annotation
svg.append("rect")
.attr("x", legendX - 10)
.attr("y", legendY + 95)
.attr("width", 110)
.attr("height", 50)
.attr("rx", 6)
.attr("fill", theme.bgSecondary)
.attr("stroke", theme.nodeStroke)
.attr("stroke-width", 1);
svg.append("text")
.attr("x", legendX)
.attr("y", legendY + 113)
.attr("fill", theme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("Pre-Norm Pattern:");
svg.append("text")
.attr("x", legendX)
.attr("y", legendY + 130)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.attr("font-family", "'IBM Plex Mono', monospace")
.text("x + f(LayerNorm(x))");
return svg.node();
}
```
The crucial detail is the **residual connections** (the + nodes). Instead of computing `y = f(x)`, each sub-layer computes `y = x + f(x)`. This:
- Helps gradients flow through deep networks via the identity path
- Makes the identity easy to learn (the sub-layer only needs to drive `f(x)` toward 0)
- Enables training of networks with 100+ layers
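A back-of-the-envelope check on the gradient claim: stacking n plain layers multiplies their local derivatives, while residual layers multiply `(1 + f')`. With a small per-layer derivative (the value 0.01 here is purely illustrative), the difference after 50 layers is dramatic:

```{python}
f_prime = 0.01   # local derivative of f at each layer (small, near-identity f)
n = 50

plain = f_prime ** n            # y = f(x) stacked: gradient is a product of f'
residual = (1 + f_prime) ** n   # y = x + f(x): gradient is a product of (1 + f')

print(f"plain:    {plain:.3e}")     # vanishes (~1e-100)
print(f"residual: {residual:.3f}")  # stays on the order of 1
```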
## The Components
We build each component from scratch, then show the PyTorch equivalents. The pattern: understand the math, implement it simply, then see how PyTorch optimizes it.
### LayerNorm from Scratch
**The Idea**: Without normalization, activations can drift to extreme values during training, causing gradients to explode or vanish. Layer normalization fixes this by normalizing each token's embedding to zero mean and unit variance, then applying a learnable scale and shift.
**The Formula**:
$$\text{LayerNorm}(x) = \gamma \times \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where:
- $\mu$ = mean across the embedding dimension
- $\sigma^2$ = variance across the embedding dimension
- $\gamma$ (gamma) = learnable scale parameter (initialized to 1)
- $\beta$ (beta) = learnable shift parameter (initialized to 0)
- $\epsilon$ = small constant for numerical stability (typically 1e-5)
**From Scratch Implementation**:
```{python}
import numpy as np
import torch
import torch.nn as nn
class LayerNormScratch:
"""Layer normalization from scratch using NumPy-style operations."""
def __init__(self, dim, eps=1e-5):
# Learnable parameters
self.gamma = np.ones((dim,)) # scale (initialized to 1)
self.beta = np.zeros((dim,)) # shift (initialized to 0)
self.eps = eps
def __call__(self, x):
"""
Args:
x: input array of shape (..., dim)
Returns:
normalized array of same shape
"""
# Step 1: Compute mean across last dimension
mean = x.mean(axis=-1, keepdims=True)
# Step 2: Compute variance across last dimension
var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
# Step 3: Normalize (the "norm" in LayerNorm)
x_norm = (x - mean) / np.sqrt(var + self.eps)
# Step 4: Scale and shift with learnable parameters
return self.gamma * x_norm + self.beta
# Test our from-scratch implementation
x = np.array([[2.0, 4.0, 6.0, 8.0],
[1.0, 2.0, 3.0, 4.0]])
ln_scratch = LayerNormScratch(dim=4)
out_scratch = ln_scratch(x)
print("LayerNorm from Scratch:")
print(f" Input:\n{x}")
print(f" Output:\n{np.round(out_scratch, 4)}")
print(f" Output mean per row: {out_scratch.mean(axis=-1).round(6)}")
print(f" Output std per row: {out_scratch.std(axis=-1).round(4)}")
```
**PyTorch's nn.LayerNorm**:
```{python}
# PyTorch's optimized implementation
ln_pytorch = nn.LayerNorm(4, elementwise_affine=True)
# Set gamma=1, beta=0 to match our scratch version (this is also PyTorch's default init)
nn.init.ones_(ln_pytorch.weight)
nn.init.zeros_(ln_pytorch.bias)
x_torch = torch.tensor(x, dtype=torch.float32)
out_pytorch = ln_pytorch(x_torch)
print("PyTorch LayerNorm:")
print(f" Output:\n{out_pytorch.detach().numpy().round(4)}")
print(f" Matches scratch: {np.allclose(out_scratch, out_pytorch.detach().numpy(), atol=1e-5)}")
```
**Key Insight**: LayerNorm is just normalize-scale-shift. The learnable $\gamma$ and $\beta$ let the network undo the normalization if that helps, but training starts from a stable baseline. Unlike BatchNorm, LayerNorm normalizes across the embedding dimension rather than across the batch, which makes it suitable for variable-length sequences.
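The axis difference is easy to see numerically. A minimal sketch, ignoring BatchNorm's running statistics and both layers' learnable affine parameters:

```{python}
import numpy as np

x = np.array([[2.0, 4.0, 6.0, 8.0],
              [1.0, 2.0, 3.0, 4.0]])   # (batch=2, features=4)

# LayerNorm: statistics per row (per token, across features)
ln = (x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)

# BatchNorm-style: statistics per column (per feature, across the batch)
bn = (x - x.mean(axis=0, keepdims=True)) / x.std(axis=0, keepdims=True)

print(ln.mean(axis=-1).round(6))  # each row has mean ~0
print(bn.mean(axis=0).round(6))   # each column has mean ~0
```

LayerNorm's per-row statistics depend only on a single token's vector, so they work identically whatever the batch size or sequence length.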
### Dropout: Regularization by Noise
**The Idea**: During training, dropout zeros a random subset of activations. The network learns redundant representations rather than depending on any single feature. The key trick: scale remaining values by $\frac{1}{1-p}$ so the expected value stays the same.
**Why it works**:
- Forces the network to learn redundant representations
- Acts like training an ensemble of sub-networks
- At inference time, use all neurons (no dropout)
**From Scratch Implementation**:
```{python}
class DropoutScratch:
"""Dropout from scratch."""
def __init__(self, p=0.1):
"""
Args:
p: probability of dropping each element (not keeping!)
"""
self.p = p
def __call__(self, x, training=True):
"""
Args:
x: input array
training: if False, return x unchanged
Returns:
x with dropout applied (if training)
"""
if not training or self.p == 0:
return x
# Create random mask: True where we KEEP the value
keep_prob = 1 - self.p
mask = np.random.random(x.shape) < keep_prob
# Apply mask and scale by 1/(1-p)
# This keeps the expected value the same:
# E[x * mask / keep_prob] = x * keep_prob / keep_prob = x
return x * mask / keep_prob
# Demonstrate dropout
np.random.seed(42)
x = np.ones((2, 8))
dropout = DropoutScratch(p=0.5)
print("Dropout from Scratch (p=0.5):")
print(f" Input (all 1s): {x[0]}")
# Apply dropout multiple times to see the randomness
for i in range(3):
np.random.seed(i)
out = dropout(x, training=True)
print(f" Trial {i+1}: {out[0].round(2)}")
print(f" Mean: {out[0].mean():.2f} (should be ~1.0 on average)")
```
**The Scaling Trick Explained**:
```{python}
# Why divide by (1-p)?
# Without scaling, dropout reduces expected output
# With scaling, expected output stays the same
p = 0.5
np.random.seed(0)
x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
# Without scaling
mask = np.random.random(x.shape) < (1-p)
out_no_scale = x * mask
print(f"Without scaling: {out_no_scale} -> mean = {out_no_scale.mean():.2f}")
# With scaling (divide by keep probability)
np.random.seed(0)
mask = np.random.random(x.shape) < (1-p)
out_scaled = x * mask / (1-p)
print(f"With scaling: {out_scaled} -> mean = {out_scaled.mean():.2f}")
print(f"\nThe scaling keeps expected value at 1.0 despite dropping 50% of values")
```
**PyTorch's nn.Dropout**:
```{python}
# PyTorch handles train/eval mode automatically
dropout_pytorch = nn.Dropout(p=0.5)
x_torch = torch.ones(2, 8)
# Training mode (dropout active)
dropout_pytorch.train()
torch.manual_seed(42)
out_train = dropout_pytorch(x_torch)
print(f"PyTorch Dropout (training): {out_train[0].numpy()}")
# Inference mode (dropout disabled)
dropout_pytorch.eval()
out_inference = dropout_pytorch(x_torch)
print(f"PyTorch Dropout (inference): {out_inference[0].numpy()}")
```
**Key Insight**: Dropout is random masking with scaling. Scaling by $\frac{1}{1-p}$ during training preserves the expected value, so inference requires no adjustment.
### Residual Connections: The Highway for Gradients
**The Idea**: Replace $y = f(x)$ with $y = x + f(x)$. This skip connection mitigates the vanishing gradient problem by giving gradients a direct path through the network.
**Why it helps**:
{{< include _gradient-flow-viz.qmd >}}
**Pre-Norm vs Post-Norm**:
The original Transformer used "post-norm": normalize after the residual addition.
```python
# Post-Norm (original Transformer)
x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))
```
Modern LLMs use "pre-norm": normalize before each sublayer.
```python
# Pre-Norm (GPT-2, LLaMA, modern LLMs)
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
```
**Why Pre-Norm is Better**:
```{python}
# Demonstrate the stability difference
def simulate_forward_pass(num_layers, prenorm=True):
"""Simulate activation magnitudes through layers."""
x = 1.0 # Starting activation magnitude
for _ in range(num_layers):
if prenorm:
# Pre-norm: normalize first, then residual keeps things bounded
normed = 1.0 # After LayerNorm, magnitude is ~1
sublayer_out = normed * 0.5 # Sublayer output
x = x + sublayer_out # Residual addition
else:
# Post-norm: residual can grow, then we normalize
sublayer_out = x * 0.5
x = x + sublayer_out # Can grow unboundedly before norm
x = 1.0 # LayerNorm resets to ~1
return x
print("Activation stability comparison:")
print(f" Pre-norm after 24 layers: ~{simulate_forward_pass(24, prenorm=True):.1f}")
print(f" Post-norm after 24 layers: ~{simulate_forward_pass(24, prenorm=False):.1f}")
print("\nPre-norm has a cleaner gradient path because the skip connection")
print("bypasses normalization - gradients flow directly from output to input.")
```
**Key Insight**: Residual connections transform `y = f(x)` into `y = x + f(x)`. The gradient of this is `dy/dx = 1 + df/dx`. That `+ 1` is crucial - it means gradients always have a direct path through the network, even if `df/dx` is tiny.
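The `1 + df/dx` claim can be checked directly with autograd. A minimal sketch using a toy sublayer whose local gradient is deliberately tiny:

```{python}
import torch

def sublayer(x):
    return 0.001 * x  # stand-in f(x) with a tiny local gradient df/dx = 0.001

x = torch.tensor(2.0, requires_grad=True)

# Without a residual: dy/dx = df/dx
sublayer(x).backward()
g_plain = x.grad.item()

# With a residual: dy/dx = 1 + df/dx
x.grad = None
(x + sublayer(x)).backward()
g_res = x.grad.item()

print(f"Gradient without residual: {g_plain:.3f}")  # 0.001, nearly vanished
print(f"Gradient with residual:    {g_res:.3f}")    # 1.001, direct path dominates
```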
### The Full Transformer Block from Scratch
**The Idea**: Now we assemble all the pieces into a complete transformer block:
1. LayerNorm + Multi-Head Attention + Residual
2. LayerNorm + Feed-Forward Network + Residual
```{python}
class FeedForwardScratch:
"""Simple feed-forward network from scratch."""
def __init__(self, embed_dim, ff_dim):
# Initialize weights with small random values
scale = 0.02
self.w1 = np.random.randn(embed_dim, ff_dim) * scale
self.b1 = np.zeros(ff_dim)
self.w2 = np.random.randn(ff_dim, embed_dim) * scale
self.b2 = np.zeros(embed_dim)
def gelu(self, x):
"""GELU activation: x * Phi(x) where Phi is standard normal CDF."""
return 0.5 * x * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))
def __call__(self, x):
# Up projection: embed_dim -> ff_dim
h = x @ self.w1 + self.b1
# Activation
h = self.gelu(h)
# Down projection: ff_dim -> embed_dim
return h @ self.w2 + self.b2
# NOTE: This uses a SIMPLIFIED attention (just linear projection) to focus on
# the overall block structure. Real attention with Q, K, V is in attention.py
class TransformerBlockScratch:
"""
A complete transformer block from scratch (with simplified attention).
Architecture (Pre-Norm):
x = x + Attention(LayerNorm(x))
x = x + FeedForward(LayerNorm(x))
WARNING: The attention here is simplified to a linear projection for
demonstration purposes. See m05_attention for full attention implementation.
"""
def __init__(self, embed_dim, num_heads, ff_dim, dropout_p=0.1):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Layer norms
self.ln1 = LayerNormScratch(embed_dim)
self.ln2 = LayerNormScratch(embed_dim)
# Attention projections (simplified: no actual attention computation)
# In a full implementation, this would include Q, K, V projections
scale = 0.02
self.attn_proj = np.random.randn(embed_dim, embed_dim) * scale
# Feed-forward network
self.ff = FeedForwardScratch(embed_dim, ff_dim)
# Dropout
self.dropout = DropoutScratch(dropout_p)
def __call__(self, x, training=True):
"""
Args:
x: input of shape (batch, seq, embed_dim)
training: whether to apply dropout
Returns:
output of shape (batch, seq, embed_dim)
"""
# === Attention sub-block ===
# 1. Layer norm (pre-norm)
normed = self.ln1(x)
# 2. Attention (simplified: just a linear projection for demo)
# Real implementation would compute Q, K, V and attention weights
attn_out = normed @ self.attn_proj
# 3. Dropout
attn_out = self.dropout(attn_out, training=training)
# 4. Residual connection
x = x + attn_out
# === Feed-forward sub-block ===
# 1. Layer norm (pre-norm)
normed = self.ln2(x)
# 2. Feed-forward network
ff_out = self.ff(normed)
# 3. Dropout
ff_out = self.dropout(ff_out, training=training)
# 4. Residual connection
x = x + ff_out
return x
# Test the from-scratch transformer block
np.random.seed(42)
block_scratch = TransformerBlockScratch(
embed_dim=64,
num_heads=4,
ff_dim=256,
dropout_p=0.0 # Disable dropout for reproducibility
)
x = np.random.randn(2, 8, 64) # batch=2, seq=8, embed=64
out = block_scratch(x, training=False)
print("Transformer Block from Scratch:")
print(f" Input shape: {x.shape}")
print(f" Output shape: {out.shape}")
print(f" Input mean: {x.mean():.4f}")
print(f" Output mean: {out.mean():.4f}")
print(f"\nThe block transforms each token while preserving shape.")
print("Residual connections keep the output close to input initially.")
```
**The Complete Picture**:
```{python}
# Visualize the transformer block structure
print("""
Transformer Block (Pre-Norm Architecture):
==========================================
Input x
|
+------------------+
| |
v |
LayerNorm |
| |
v |
Multi-Head Attention |
| |
v |
Dropout |
| |
+--------(+)-------+ <- Residual connection
|
+------------------+
| |
v |
LayerNorm |
| |
v |
Feed-Forward |
| |
v |
Dropout |
| |
+--------(+)-------+ <- Residual connection
|
v
Output
""")
```
**Key Insight**: A transformer block contains only attention, an MLP, residual connections, and layer norms. The performance comes from stacking dozens of these blocks and training them on billions of tokens.
### PyTorch Transformer Modules
PyTorch provides optimized versions of everything we built from scratch.
**Comparison Table**:
| Component | From Scratch | PyTorch |
|-----------|--------------|---------|
| LayerNorm | Manual mean/var | `nn.LayerNorm` |
| Dropout | Random mask + scale | `nn.Dropout` |
| FFN | Two linear layers + GELU | Custom or `nn.Sequential` |
| Full Block | Manual assembly | `nn.TransformerDecoderLayer` |
```{python}
# PyTorch's TransformerDecoderLayer
# Note: This is for encoder-decoder models; for decoder-only like GPT,
# we typically build our own (as in transformer.py)
from torch.nn import TransformerDecoderLayer
# Create a decoder layer similar to our scratch implementation
pytorch_block = TransformerDecoderLayer(
d_model=64,
nhead=4,
dim_feedforward=256,
dropout=0.1,
activation='gelu',
batch_first=True,
norm_first=True # Pre-norm architecture
)
x_torch = torch.randn(2, 8, 64)
# For decoder-only, we use self-attention (memory = x)
pytorch_block.eval()
out_pytorch = pytorch_block(x_torch, x_torch)
print("PyTorch TransformerDecoderLayer:")
print(f" Input shape: {tuple(x_torch.shape)}")
print(f" Output shape: {tuple(out_pytorch.shape)}")
print(f" Parameters: {sum(p.numel() for p in pytorch_block.parameters()):,}")
```
**When to Use What**:
- **Learning**: Build from scratch to understand every step
- **Production**: Use PyTorch's optimized modules
- **Custom architectures**: Mix both - understand the components, then optimize
```{python}
# Our module's TransformerBlock (production quality)
from transformer import TransformerBlock
our_block = TransformerBlock(
embed_dim=64,
num_heads=4,
ff_dim=256,
dropout=0.1
)
x_torch = torch.randn(2, 8, 64)
our_block.eval()
out_ours = our_block(x_torch)
print("Our TransformerBlock (from transformer.py):")
print(f" Input shape: {tuple(x_torch.shape)}")
print(f" Output shape: {tuple(out_ours.shape)}")
print(f" Parameters: {sum(p.numel() for p in our_block.parameters()):,}")
print("\nThis is what we use for training - it includes proper")
print("causal attention, not the simplified version in scratch code.")
```
```
---
## More Component Details
### Layer Normalization (PyTorch Details)
Normalizes activations across the embedding dimension:
$$\text{LayerNorm}(x) = \gamma \times \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where $\mu$ and $\sigma^2$ are the mean and variance across the embedding dimension, and $\gamma$, $\beta$ are learnable parameters.
Why it helps:
- **Stabilizes activations**: Prevents values from exploding or vanishing
- **Faster training**: More stable gradients
- **Independent per token**: Each token normalized separately
### Feed-Forward Network (FFN)
```{ojs}
//| echo: false
// Step definitions for FFN walkthrough
ffnSteps = [
{
id: 0,
name: "Input",
description: "Input tensor arrives with embed_dim features per token",
activeNode: "input",
activeEdge: null
},
{
id: 1,
name: "First Linear (Expand)",
description: "Project from embed_dim to 4x embed_dim - the expansion gives capacity for complex transformations",
activeNode: "linear1",
activeEdge: "e_input_linear1"
},
{
id: 2,
name: "GELU Activation",
description: "Apply GELU nonlinearity - smoother than ReLU, allows gradients to flow",
activeNode: "gelu",
activeEdge: "e_linear1_gelu"
},
{
id: 3,
name: "Second Linear (Project)",
description: "Project back from 4x embed_dim to embed_dim - compress the expanded representation",
activeNode: "linear2",
activeEdge: "e_gelu_linear2"
},
{
id: 4,
name: "Output",
description: "Output tensor has the same shape as input - ready for residual connection",
activeNode: "output",
activeEdge: "e_linear2_output"
}
]
```
```{ojs}
//| echo: false
// Step control for FFN walkthrough
viewof ffnStep = stepControl({min: 0, max: 4, value: 0, label: "FFN Step"})
```
```{ojs}
//| echo: false
currentFFNStep = ffnSteps[ffnStep]
```
```{ojs}
//| echo: false
// Draw the FFN diagram
ffnDiagram = {
const width = 700;
const height = 180;
const marginX = 50;
const marginY = 40;
// Node dimensions - varying heights to show dimension expansion
const baseHeight = 50;
const expandedHeight = 100; // 4x expansion shown visually as 2x height
const nodeWidth = 100;
// Horizontal positions (evenly spaced)
const spacing = (width - 2 * marginX) / 4;
const inputX = marginX;
const linear1X = marginX + spacing;
const geluX = marginX + spacing * 2;
const linear2X = marginX + spacing * 3;
const outputX = marginX + spacing * 4;
const centerY = height / 2;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Defs for markers and filters
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "ffn-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "ffn-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "ffn-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Color for different node types
const colors = {
io: diagramTheme.accent, // Input/Output
linear: "#8b5cf6", // Linear layers (purple)
activation: "#10b981" // Activation (emerald)
};
// Helper to draw a node with variable height
function drawNode(x, y, w, h, label, sublabel, nodeId, colorType) {
const isActive = currentFFNStep.activeNode === nodeId;
const baseColor = colors[colorType] || diagramTheme.nodeFill;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", isActive ? baseColor : diagramTheme.nodeFill)
.attr("stroke", isActive ? baseColor : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("opacity", isActive ? 1 : 0.85)
.style("filter", isActive ? "url(#ffn-glow)" : "none");
const textColor = isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
if (sublabel) {
g.append("text")
.attr("y", -8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
g.append("text")
.attr("y", 8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "9px")
.attr("opacity", isActive ? 0.95 : 0.7)
.text(sublabel);
} else {
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
}
return g;
}
// Helper to draw an edge
function drawEdge(x1, y1, x2, y2, edgeId) {
const isActive = currentFFNStep.activeEdge === edgeId;
const strokeColor = isActive ? diagramTheme.highlight : diagramTheme.edgeStroke;
const markerId = isActive ? "ffn-arrow-active" : "ffn-arrow";
const path = svg.append("path")
.attr("d", `M${x1},${y1} L${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", strokeColor)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("marker-end", `url(#${markerId})`);
if (isActive) {
path.style("filter", `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})`);
}
return path;
}
// Draw edges first (so nodes appear on top)
// Edge offsets to connect to node edges
const edgeGap = 8;
// Input -> Linear1
drawEdge(inputX + nodeWidth/2 + edgeGap, centerY, linear1X - nodeWidth/2 - edgeGap, centerY, "e_input_linear1");
// Linear1 -> GELU
drawEdge(linear1X + nodeWidth/2 + edgeGap, centerY, geluX - nodeWidth/2 - edgeGap, centerY, "e_linear1_gelu");
// GELU -> Linear2
drawEdge(geluX + nodeWidth/2 + edgeGap, centerY, linear2X - nodeWidth/2 - edgeGap, centerY, "e_gelu_linear2");
// Linear2 -> Output
drawEdge(linear2X + nodeWidth/2 + edgeGap, centerY, outputX - nodeWidth/2 - edgeGap, centerY, "e_linear2_output");
// Draw nodes
// Input (embed_dim)
drawNode(inputX, centerY, nodeWidth, baseHeight, "Input", "(embed_dim)", "input", "io");
// Linear1 - expanded height to show 4x
drawNode(linear1X, centerY, nodeWidth, expandedHeight, "Linear", "→ 4x embed", "linear1", "linear");
// GELU - also expanded (operates on 4x)
drawNode(geluX, centerY, nodeWidth, expandedHeight, "GELU", "(4x embed)", "gelu", "activation");
// Linear2 - expanded input, compressed output (show as expanded)
drawNode(linear2X, centerY, nodeWidth, expandedHeight, "Linear", "→ embed_dim", "linear2", "linear");
// Output (embed_dim)
drawNode(outputX, centerY, nodeWidth, baseHeight, "Output", "(embed_dim)", "output", "io");
// Dimension labels below the flow
const labelY = centerY + 60;
const dimLabelStyle = {
"font-size": "9px",
"fill": diagramTheme.nodeText,
"opacity": 0.6
};
// Add "4x expansion" annotation in the middle section
svg.append("text")
.attr("x", (linear1X + linear2X) / 2)
.attr("y", centerY - expandedHeight/2 - 12)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text("4x capacity expansion");
// Small arrows to indicate expansion and compression
const arrowY = centerY - expandedHeight/2 - 6;
svg.append("path")
.attr("d", `M${linear1X - 20},${arrowY} L${linear1X + 20},${arrowY}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
svg.append("path")
.attr("d", `M${linear2X - 20},${arrowY} L${linear2X + 20},${arrowY}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
return svg.node();
}
```
```{ojs}
//| echo: false
// Step description display
html`<div style="
background: ${diagramTheme.bgSecondary};
border: 1px solid ${diagramTheme.nodeStroke};
border-radius: 6px;
padding: 12px 16px;
margin-top: 8px;
font-family: 'JetBrains Mono', 'Fira Code', monospace;
font-size: 13px;
">
<strong style="color: ${diagramTheme.highlight};">${currentFFNStep.name}</strong>
<span style="color: ${diagramTheme.nodeText}; opacity: 0.9;"> — ${currentFFNStep.description}</span>
</div>`
```
The FFN is a mini neural network applied to each token independently:
- **4x expansion**: More capacity to learn complex transformations
- **GELU activation**: Smoother than ReLU, better gradients
- **Same for all tokens**: Unlike attention, no mixing between positions
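In code, the whole FFN is just two `nn.Linear` layers around a GELU; a sketch using `nn.Sequential` with the standard 4x expansion:

```{python}
import torch
import torch.nn as nn

embed_dim = 64
ffn = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),  # expand: embed_dim -> 4x
    nn.GELU(),                            # smooth nonlinearity
    nn.Linear(4 * embed_dim, embed_dim),  # project back: 4x -> embed_dim
)

x = torch.randn(2, 8, embed_dim)  # (batch, seq, embed)
out = ffn(x)
print(out.shape)  # same shape as input: each token processed independently
```

Because the same weights apply at every position, the FFN never mixes information between tokens; that mixing is attention's job.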
### Pre-Norm vs Post-Norm
We use **Pre-Norm** (LayerNorm before attention/FFN) rather than Post-Norm:
```python
# Pre-Norm (GPT-2, LLaMA, modern LLMs)
x = x + Attention(LayerNorm(x))
# Post-Norm (original Transformer paper)
x = LayerNorm(x + Attention(x))
```
**Why Pre-Norm is preferred:**
1. **Cleaner gradient path**: The residual connection bypasses normalization, so gradients flow directly
2. **More stable training**: Especially important for deep networks (24+ layers)
3. **Requires final LayerNorm**: Since the last block's output isn't normalized, we add a final LayerNorm before the output projection
With careful tuning, Post-Norm can achieve marginally better final performance, but Pre-Norm trains far more robustly.
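Point 3 is easy to see numerically: nothing in a pre-norm stack normalizes the residual sum itself, so its magnitude drifts upward until the final LayerNorm restores it. A toy sketch with random stand-ins for the sublayer outputs:

```{python}
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 64
h = torch.randn(2, 8, d)  # residual stream after the embedding

for _ in range(12):
    h = h + 0.5 * torch.randn_like(h)  # stand-in for a residual sublayer output

drifted_std = h.std().item()
print(f"Residual stream std after 12 blocks: {drifted_std:.2f}")  # well above 1

final_ln = nn.LayerNorm(d)  # the final LayerNorm before the output projection
restored_std = final_ln(h).std().item()
print(f"After final LayerNorm:               {restored_std:.2f}")  # back near 1
```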
## Code Walkthrough
Let's build and explore transformer blocks:
```{python}
import sys
import importlib.util
from pathlib import Path
import torch
import torch.nn as nn
print(f"PyTorch version: {torch.__version__}")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Device: {device}")
```
### GELU Activation
GPT uses GELU instead of ReLU. Let's see why:
{{< include _gelu-viz.qmd >}}
**GELU formula**: $\text{GELU}(x) = x \cdot \Phi(x)$ where $\Phi$ is the standard normal CDF.
Approximation used in practice: $0.5x(1 + \tanh(\sqrt{2/\pi}(x + 0.044715x^3)))$
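The two formulas can be compared numerically, using `torch.erf` for the exact normal CDF:

```{python}
import math
import torch

x = torch.linspace(-5, 5, steps=101)

# Exact GELU: x * Phi(x), with Phi the standard normal CDF via erf
exact = x * 0.5 * (1 + torch.erf(x / math.sqrt(2)))

# Tanh approximation (the GPT-2 variant)
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

max_diff = (exact - approx).abs().max().item()
print(f"Max difference over [-5, 5]: {max_diff:.6f}")  # the curves track closely
```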
**Activation function choices in modern LLMs:**
| Model | Activation | Notes |
|-------|-----------|-------|
| GPT-2, BERT | GELU | Smooth, good gradients |
| LLaMA, Mistral | SwiGLU | Gated variant, better performance |
| GPT-3 | GELU | Same as GPT-2 |
**SwiGLU** (used in LLaMA) is a gated linear unit: $\text{SwiGLU}(x) = \text{Swish}(xW_1) \odot xW_2$, where $\odot$ is element-wise multiplication. It requires an extra linear layer but often improves performance. Our implementation uses standard GELU to match GPT-2.
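For reference, a minimal SwiGLU feed-forward sketch (hypothetical and for illustration only, not part of this module's `transformer.py`; LLaMA-style, bias-free):

```{python}
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Hypothetical minimal SwiGLU FFN, for illustration only."""
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.w_gate = nn.Linear(embed_dim, hidden_dim, bias=False)  # gate branch
        self.w_up = nn.Linear(embed_dim, hidden_dim, bias=False)    # value branch
        self.w_down = nn.Linear(hidden_dim, embed_dim, bias=False)  # project back

    def forward(self, x):
        # Swish(x W_gate), element-wise multiplied with x W_up, projected down
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

ffn = SwiGLUFFN(embed_dim=64, hidden_dim=256)
out = ffn(torch.randn(2, 8, 64))
print(out.shape)  # shape preserved, like the GELU FFN
```

The extra (third) weight matrix is why LLaMA sizes its hidden dimension near $\tfrac{8}{3}$ of `embed_dim` rather than 4x, keeping the parameter count comparable to a GELU FFN.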
### Layer Normalization Demo
```{python}
# Manual LayerNorm demonstration
x = torch.tensor([[2.0, 4.0, 6.0, 8.0]])
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
normalized = (x - mean) / torch.sqrt(var + 1e-5)
print("Manual LayerNorm:")
print(f" Input: {x.numpy().tolist()}")
print(f" Mean: {mean.item():.2f}")
print(f" Variance: {var.item():.2f}")
print(f" Normalized: {normalized.numpy().round(2).tolist()}")
print(f" New mean: {normalized.mean().item():.4f}")
print(f" New std: {normalized.std(unbiased=False).item():.4f}")
```
```{python}
# PyTorch LayerNorm (with learnable gamma and beta)
ln = nn.LayerNorm(4)
pytorch_normalized = ln(x)
print(f"PyTorch LayerNorm output: {pytorch_normalized.detach().numpy().round(2).tolist()}")
print("(gamma and beta are learnable parameters)")
```
### Residual Connections Demo
{{< include _residual-effect-viz.qmd >}}
### Weight Initialization
Deep networks require proper initialization. Our implementation uses:
- **Embeddings**: Normal distribution with std=0.02
- **Linear layers in FFN**: Normal distribution with std=0.02, biases initialized to 0
- **Attention projections**: Xavier uniform initialization
```{python}
# Demonstrate the importance of initialization
import torch.nn as nn
# Bad initialization - too large
bad_linear = nn.Linear(768, 768)
nn.init.normal_(bad_linear.weight, std=1.0) # Too large!
# Good initialization - small weights
good_linear = nn.Linear(768, 768)
nn.init.normal_(good_linear.weight, std=0.02) # GPT-2 style
x = torch.randn(1, 10, 768)
bad_out = bad_linear(x)
good_out = good_linear(x)
print("Effect of initialization on output magnitude:")
print(f" Bad init (std=1.0): output std = {bad_out.std().item():.2f}")
print(f" Good init (std=0.02): output std = {good_out.std().item():.2f}")
print("\nLarge outputs can cause exploding gradients and NaN losses!")
```
**GPT-2's initialization trick**: Scale the final projection in each residual block by $1/\sqrt{2N}$ where N is the number of layers. This keeps the variance stable as depth increases.
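The effect of that trick can be simulated by adding $2N$ random residual contributions with and without the scaling factor (random matrices as stand-ins for the residual projections):

```{python}
import math
import torch

torch.manual_seed(0)
num_layers, d = 24, 256
K = 2 * num_layers  # two residual additions per block

def final_std(scale):
    h = torch.randn(256, d)  # residual stream, std ~1
    for _ in range(K):
        w = torch.randn(d, d) * 0.02 * scale  # residual projection weights
        h = h + torch.randn(256, d) @ w       # stand-in sublayer contribution
    return h.std().item()

unscaled = final_std(1.0)
scaled = final_std(1.0 / math.sqrt(K))
print(f"Residual stream std without scaling:  {unscaled:.2f}")
print(f"Residual stream std with 1/sqrt(2N):  {scaled:.2f}")
```

Each residual addition contributes variance, so the unscaled stream's variance grows linearly with depth; dividing the initialization by $\sqrt{2N}$ cancels that growth.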
### Building a Transformer Block
```{python}
from transformer import (
FeedForward,
TransformerBlock,
GPTModel,
create_gpt_tiny,
create_gpt_small,
)
# Create a transformer block
embed_dim = 64
num_heads = 4
ff_dim = 256
block = TransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
ff_dim=ff_dim,
dropout=0.0
)
print(f"Transformer Block:")
print(f" Embed dim: {embed_dim}")
print(f" Num heads: {num_heads}")
print(f" Head dim: {embed_dim // num_heads}")
print(f" FF dim: {ff_dim}")
print(f"\nTotal parameters: {sum(p.numel() for p in block.parameters()):,}")
```
```{python}
# Forward pass
x = torch.randn(1, 8, embed_dim) # batch=1, seq=8
output, attention = block(x, return_attention=True)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")
```
```{python}
#| output: false
# Pass attention weights to OJS for visualization
attention_list = [attention[0, h].detach().tolist() for h in range(4)]
ojs_define(attentionWeightsData=attention_list)
```
{{< include _attention-patterns-viz.qmd >}}
### Complete GPT Model
```{python}
# Create a tiny GPT model
model = create_gpt_tiny(vocab_size=1000)
print("GPT Tiny Model:")
print(f" Vocab size: {model.vocab_size}")
print(f" Embed dim: {model.embed_dim}")
print(f" Num layers: {len(model.blocks)}")
print(f" Max seq len: {model.max_seq_len}")
print(f"\nTotal parameters: {model.num_params:,}")
```
```{python}
# Parameter breakdown
counts = model.count_parameters()
print("Parameter breakdown:")
for name, count in counts.items():
if count > 0:
pct = 100 * count / counts['total']
print(f" {name}: {count:,} ({pct:.1f}%)")
```
```{python}
#| output: false
# Pass parameter counts to OJS for visualization
param_data = [
{"category": name, "value": count}
for name, count in counts.items()
if count > 0 and name != 'total'
]
ojs_define(parameterData=param_data, totalParams=counts['total'])
```
```{ojs}
//| echo: false
// Parameter distribution bar chart
paramDistChart = {
const data = typeof parameterData !== 'undefined' ? parameterData : [];
const total = typeof totalParams !== 'undefined' ? totalParams : 0;
if (data.length === 0) return html`<div>Loading...</div>`;
const width = 600;
const height = 280;
const margin = { top: 45, right: 30, bottom: 60, left: 130 };
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const yScale = d3.scaleBand()
.domain(data.map(d => d.category))
.range([0, innerHeight])
.padding(0.25);
const xScale = d3.scaleLinear()
.domain([0, d3.max(data, d => d.value)])
.range([0, innerWidth]);
// Colors - use theme colors for consistency
const colors = [theme.primary, theme.accent, theme.success, theme.info, theme.error];
// Bars
g.selectAll("rect.bar")
.data(data)
.join("rect")
.attr("class", "bar")
.attr("x", 0)
.attr("y", d => yScale(d.category))
.attr("width", d => xScale(d.value))
.attr("height", yScale.bandwidth())
.attr("rx", 4)
.attr("fill", (d, i) => colors[i % colors.length])
.attr("opacity", 0.85);
// Value labels
g.selectAll("text.value")
.data(data)
.join("text")
.attr("class", "value")
.attr("x", d => xScale(d.value) + 8)
.attr("y", d => yScale(d.category) + yScale.bandwidth() / 2)
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text(d => {
const pct = (d.value / total * 100).toFixed(1);
return `${(d.value / 1000).toFixed(1)}K (${pct}%)`;
});
// Y axis
g.append("g")
.call(d3.axisLeft(yScale))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").remove())
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "11px"));
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(`Parameter Distribution (${(total / 1000).toFixed(0)}K total)`);
return svg.node();
}
```
```{python}
# Forward pass
token_ids = torch.randint(0, 1000, (2, 32)) # batch=2, seq=32
logits = model(token_ids)
print(f"Input token IDs: {token_ids.shape}")
print(f"Output logits: {logits.shape}")
print(f" (batch=2, seq=32, vocab=1000)")
# Get predictions
probs = torch.softmax(logits[0, -1], dim=-1)
top_5 = torch.topk(probs, 5)
print("\nTop 5 predicted next tokens (untrained, so random):")
for i, (idx, prob) in enumerate(zip(top_5.indices, top_5.values)):
print(f" {i+1}. Token {idx.item()}: {prob.item()*100:.2f}%")
```
### Hidden States Through Layers
```{python}
# Get hidden states from all layers
logits, hidden_states = model(token_ids, return_hidden_states=True)
print(f"Number of hidden states: {len(hidden_states)}")
print(f" (1 after embedding + {len(model.blocks)} after each block)")
# Show how representations change through layers
norms = [h.norm(dim=-1).mean().item() for h in hidden_states]
layer_names = ['Embed'] + [f'Block {i}' for i in range(len(model.blocks))]
```
```{python}
#| output: false
# Pass data to OJS for visualization
layer_norms_data = [{"layer": name, "norm": norm} for name, norm in zip(layer_names, norms)]
ojs_define(layerNormsData=layer_norms_data)
```
```{ojs}
//| echo: false
// Embedding norms through network visualization
embeddingNormsChart = {
const data = typeof layerNormsData !== 'undefined' ? layerNormsData : [];
if (data.length === 0) return html`<div>Loading...</div>`;
const width = 650;
const height = 280;
const margin = { top: 45, right: 30, bottom: 70, left: 60 };
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const xScale = d3.scaleBand()
.domain(data.map(d => d.layer))
.range([0, innerWidth])
.padding(0.3);
const yScale = d3.scaleLinear()
.domain([0, d3.max(data, d => d.norm) * 1.15])
.range([innerHeight, 0]);
// Grid
const gridColor = theme.isDark ? "rgba(255,255,255,0.08)" : "rgba(0,0,0,0.08)";
g.selectAll("line.grid")
.data(yScale.ticks(5))
.join("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", d => yScale(d))
.attr("y2", d => yScale(d))
.attr("stroke", gridColor);
// Gradient for bars
const defs = svg.append("defs");
const gradient = defs.append("linearGradient")
.attr("id", "embed-norm-gradient")
.attr("x1", "0%")
.attr("y1", "100%")
.attr("x2", "0%")
.attr("y2", "0%");
gradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", theme.accent);
gradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", theme.highlight);
// Bars
g.selectAll("rect.bar")
.data(data)
.join("rect")
.attr("class", "bar")
.attr("x", d => xScale(d.layer))
.attr("y", d => yScale(d.norm))
.attr("width", xScale.bandwidth())
.attr("height", d => innerHeight - yScale(d.norm))
.attr("rx", 4)
.attr("fill", "url(#embed-norm-gradient)")
.attr("opacity", 0.85);
// Value labels
g.selectAll("text.value")
.data(data)
.join("text")
.attr("class", "value")
.attr("x", d => xScale(d.layer) + xScale.bandwidth() / 2)
.attr("y", d => yScale(d.norm) - 8)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(d => d.norm.toFixed(2));
// Line connecting bars
const lineGen = d3.line()
.x(d => xScale(d.layer) + xScale.bandwidth() / 2)
.y(d => yScale(d.norm))
.curve(d3.curveMonotoneX);
g.append("path")
.datum(data)
.attr("d", lineGen)
.attr("fill", "none")
.attr("stroke", theme.highlight)
.attr("stroke-width", 2)
.attr("stroke-dasharray", "5,3")
.attr("opacity", 0.5);
// X axis
g.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("transform", "rotate(-30)")
.attr("text-anchor", "end")
.attr("dx", "-0.5em")
.attr("dy", "0.5em"));
// Y axis
g.append("g")
.call(d3.axisLeft(yScale).ticks(5))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "11px"));
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.text("Average Embedding Norm");
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Embedding Norms Through the Network");
return svg.node();
}
```
### Weight Tying
GPT shares weights between the token embedding and the output projection (LM head):
```{python}
# Check weight tying
print("Weight Tying:")
print(f" Token embedding weight id: {id(model.token_embedding.weight)}")
print(f" LM head weight id: {id(model.lm_head.weight)}")
print(f" Are they the same object? {model.token_embedding.weight is model.lm_head.weight}")
# This saves parameters!
vocab_size = 1000
embed_dim = 128
saved_params = vocab_size * embed_dim
print(f"\nParameters saved by weight tying: {saved_params:,}")
```
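How the tying is wired up can be sketched in a few lines. This is a standalone toy module for illustration, not the model class built in this lesson: assigning the embedding's weight to the linear head makes both modules share a single parameter tensor.

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy sketch of weight tying (illustrative, not this lesson's model)."""
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        # nn.Linear(embed_dim, vocab_size) stores weight as (vocab, embed),
        # the same shape as the embedding table, so they can share storage.
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight  # tie

model = TinyLM(vocab_size=1000, embed_dim=128)
print(model.lm_head.weight is model.token_embedding.weight)  # True
```

Because `parameters()` deduplicates shared tensors, the tied weight is counted once: the model above has exactly 1000 × 128 = 128,000 parameters.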
## Model Size Comparison
| Model | Layers | Heads | Embed Dim | Params |
|-------|--------|-------|-----------|--------|
| Tiny (ours) | 4 | 4 | 128 | ~1M |
| Small (ours) | 6 | 6 | 384 | ~10M |
| GPT-2 Small | 12 | 12 | 768 | 117M |
| GPT-2 Medium | 24 | 16 | 1024 | 345M |
| GPT-2 Large | 36 | 20 | 1280 | 774M |
| GPT-2 XL | 48 | 25 | 1600 | 1.5B |
### Parameter Counting Formulas
Knowing where parameters come from aids model sizing:
**Per Transformer Block:**
- Attention Q, K, V projections: $3 \times d \times d$ (where $d$ = embed_dim)
- Attention output projection: $d \times d$
- FFN first linear: $d \times 4d$
- FFN second linear: $4d \times d$
- LayerNorm (x2): $2 \times 2d$ (gamma and beta for each)
Total per block: $\approx 12d^2$ parameters
**Full Model:**
- Token embedding: $V \times d$ (V = vocab size)
- Position embedding: $L \times d$ (L = max sequence length)
- N transformer blocks: $N \times 12d^2$
- Final LayerNorm: $2d$
- LM head: 0 (weight-tied with token embedding)
**Approximate formula**: $\text{Params} \approx V \times d + 12Nd^2$
For GPT-2 Small (V=50257, d=768, N=12): $50257 \times 768 + 12 \times 12 \times 768^2 \approx 124M$ (the widely quoted 117M figure undercounts slightly; the actual total is about 124M).
Scaling laws: more layers, heads, and dimensions improve performance, but with diminishing returns and growing compute cost.
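The approximate formula is easy to check with plain arithmetic:

```python
def approx_params(vocab_size, embed_dim, num_layers):
    """Params ~ V*d (tied embedding) + 12*N*d^2 (blocks).
    Ignores biases, position embeddings, and the final LayerNorm."""
    return vocab_size * embed_dim + 12 * num_layers * embed_dim ** 2

# GPT-2 Small: V=50257, d=768, N=12
# Gives ~123.5M; the often-quoted 117M undercounts slightly.
print(f"{approx_params(50257, 768, 12) / 1e6:.1f}M")  # 123.5M
```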
## Architectural Variations
Modern LLMs have evolved beyond the original GPT-2 architecture. Here are key variations:
### Normalization
| Variant | Used By | Description |
|---------|---------|-------------|
| LayerNorm | GPT-2, GPT-3 | Normalize across embedding dimension |
| RMSNorm | LLaMA, Mistral | Simpler: just divide by RMS, no mean subtraction |
| Pre-Norm | Most modern LLMs | Normalize before sublayer (more stable) |
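RMSNorm fits in a few lines. A minimal sketch for illustration (not the LayerNorm class built earlier in this module): it rescales by the root-mean-square of the last axis and keeps a learned gain, but drops both the mean subtraction and the bias.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm: divide by the root-mean-square; no mean subtraction, no beta."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # learned gain only

    def forward(self, x):
        # RMS over the embedding (last) dimension
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.gamma * x / rms

norm = RMSNorm(8)
x = torch.randn(2, 4, 8)
y = norm(x)
print(y.shape)  # torch.Size([2, 4, 8])
```

After normalization (with gamma at its initial value of 1), every token vector has RMS ~1, which is the entire stabilizing effect: LayerNorm's mean subtraction turns out to be largely unnecessary in practice.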
### Position Embeddings
| Variant | Used By | Description |
|---------|---------|-------------|
| Learned absolute | GPT-2 | Separate embedding for each position |
| Rotary (RoPE) | LLaMA, Mistral | Encode position in attention via rotation |
| ALiBi | BLOOM | Add position bias to attention scores |
Our implementation uses learned absolute position embeddings (GPT-2 style), which are simple but limit the model to the maximum trained sequence length.
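For contrast, the core of RoPE is small: each position rotates pairs of query/key dimensions by position-dependent angles, so attention dot products depend only on relative offsets. A simplified single-head sketch (illustrative only; production implementations pair dimensions differently and cache the angles):

```python
import torch

def apply_rope(x, base=10000.0):
    """x: (seq_len, head_dim) with even head_dim. Rotates dimension pairs
    (x[:, i], x[:, i + half]) by angles that grow with position."""
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per dimension pair, geometrically spaced
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    # 2D rotation applied to each (x1, x2) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(16, 64)
q_rot = apply_rope(q)
print(q_rot.shape)  # torch.Size([16, 64])
```

Two properties worth checking: position 0 is left unrotated (all angles are zero), and rotation preserves each token's vector norm, so RoPE adds position information without changing magnitudes.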
### Feed-Forward Networks
| Variant | Used By | Expansion | Activation |
|---------|---------|-----------|------------|
| Standard | GPT-2 | 4x | GELU |
| SwiGLU | LLaMA | 8/3x (after gating) | SiLU (Swish) |
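A hedged sketch of a SwiGLU feed-forward layer (LLaMA-style gating; the hidden size below is illustrative): instead of one up-projection plus GELU, it uses two up-projections, one passed through SiLU as a gate, multiplied elementwise before projecting back down.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Gated FFN: down( silu(gate(x)) * up(x) ). The hidden size is
    roughly (8/3)*d so total params match a standard 4x GELU FFN."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

ffn = SwiGLUFFN(dim=128, hidden_dim=341)  # ~ (8/3) * 128
x = torch.randn(2, 16, 128)
print(ffn(x).shape)  # torch.Size([2, 16, 128])
```

The 8/3 expansion is not arbitrary: a gated FFN has three weight matrices instead of two, so shrinking the hidden size by 2/3 keeps the parameter count comparable to the standard 4x design.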
## Common Pitfalls
When implementing or training transformers, watch out for:
1. **Forgetting the causal mask**: Without it, the model "cheats" by seeing future tokens during training, which cripples generation at inference.
2. **Wrong normalization axis**: LayerNorm should normalize across the embedding dimension (last axis), not the sequence or batch dimensions.
3. **Residual connection placement**: Make sure to add the residual *after* dropout but *before* the next LayerNorm in Pre-Norm architecture.
4. **Large learning rates**: Transformers are sensitive to learning rate. Start with 1e-4 to 3e-4 for Adam, use warmup.
5. **Numerical instability**: Use float32 for training initially. Half precision (fp16/bf16) requires careful scaling.
6. **Forgetting final LayerNorm**: In Pre-Norm, the output of the last block isn't normalized. The final LayerNorm before the LM head is essential.
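Pitfall 1 is easy to check directly: the causal mask should be lower-triangular so that position $i$ attends only to positions $\le i$. A minimal sketch:

```python
import torch

seq_len = 5
# Lower-triangular boolean mask: True = allowed to attend
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = torch.randn(seq_len, seq_len)          # raw attention scores
masked = scores.masked_fill(~mask, float("-inf"))
weights = masked.softmax(dim=-1)
# Every weight above the diagonal is exactly zero: no peeking at the future
print((weights.triu(diagonal=1) == 0).all())  # tensor(True)
```

A quick sanity test like this catches the bug before training: if any weight above the diagonal is nonzero, the model can see the future.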
## Interactive Exploration
Experiment with transformer architecture choices to understand where parameters come from and how they scale.
```{ojs}
//| echo: false
// Parameter counting functions
function countParameters(vocabSize, embedDim, numLayers, numHeads, ffMult, maxSeqLen) {
// Token embeddings (weight-tied with LM head)
const tokenEmbed = vocabSize * embedDim;
// Position embeddings
const posEmbed = maxSeqLen * embedDim;
// Per transformer block
const attnParams = 4 * embedDim * embedDim; // Q, K, V, O projections
const ffnParams = 2 * embedDim * (ffMult * embedDim); // up + down projection
const normParams = 4 * embedDim; // 2 LayerNorms with gamma + beta each
const perBlock = attnParams + ffnParams + normParams;
// Total blocks
const totalBlocks = numLayers * perBlock;
// Final LayerNorm
const finalNorm = 2 * embedDim;
return {
tokenEmbed,
posEmbed,
attention: numLayers * attnParams,
ffn: numLayers * ffnParams,
layerNorm: numLayers * normParams + finalNorm,
total: tokenEmbed + posEmbed + totalBlocks + finalNorm
};
}
// Format large numbers
function formatParams(n) {
if (n >= 1e9) return (n / 1e9).toFixed(2) + "B";
if (n >= 1e6) return (n / 1e6).toFixed(2) + "M";
if (n >= 1e3) return (n / 1e3).toFixed(2) + "K";
return n.toString();
}
// Format memory
function formatMemory(bytes) {
if (bytes >= 1e9) return (bytes / 1e9).toFixed(2) + " GB";
if (bytes >= 1e6) return (bytes / 1e6).toFixed(1) + " MB";
return (bytes / 1e3).toFixed(1) + " KB";
}
```
```{ojs}
//| echo: false
// Architecture presets
presets = ({
"Tiny (Ours)": { vocab: 10000, embed: 128, layers: 4, heads: 4, ff: 4, seq: 512 },
"Small (Ours)": { vocab: 10000, embed: 384, layers: 6, heads: 6, ff: 4, seq: 512 },
"GPT-2 Small": { vocab: 50257, embed: 768, layers: 12, heads: 12, ff: 4, seq: 1024 },
"GPT-2 Medium": { vocab: 50257, embed: 1024, layers: 24, heads: 16, ff: 4, seq: 1024 },
"GPT-2 Large": { vocab: 50257, embed: 1280, layers: 36, heads: 20, ff: 4, seq: 1024 }
})
viewof preset = Inputs.select(Object.keys(presets), {
label: "Load Preset",
value: "Small (Ours)"
})
```
```{ojs}
//| echo: false
// Input controls with preset values
selectedPreset = presets[preset]
viewof vocabSize = Inputs.select([1000, 5000, 10000, 32000, 50257, 100000], {
value: selectedPreset.vocab,
label: "Vocabulary Size"
})
viewof embedDim = Inputs.select([64, 128, 256, 384, 512, 768, 1024, 1280, 2048, 4096], {
value: selectedPreset.embed,
label: "Embedding Dimension"
})
viewof numLayers = Inputs.range([1, 48], {
value: selectedPreset.layers,
step: 1,
label: "Number of Layers"
})
viewof numHeads = Inputs.range([1, 32], {
value: selectedPreset.heads,
step: 1,
label: "Number of Heads"
})
viewof ffMult = Inputs.range([1, 8], {
value: selectedPreset.ff,
step: 1,
label: "FFN Multiplier"
})
viewof maxSeqLen = Inputs.select([256, 512, 1024, 2048, 4096, 8192], {
value: selectedPreset.seq,
label: "Max Sequence Length"
})
```
```{ojs}
//| echo: false
// Theme colors derived from diagramTheme for consistency
theme = {
const t = diagramTheme;
return {
tokenEmbed: t.primary,
posEmbed: t.accent,
attention: t.success,
ffn: t.info,
layerNorm: t.nodeStroke,
warning: t.error
};
}
```
```{ojs}
//| echo: false
// Compute parameters
params = countParameters(vocabSize, embedDim, numLayers, numHeads, ffMult, maxSeqLen)
// Data for pie chart
pieData = [
{ category: "Token Embeddings", value: params.tokenEmbed, color: theme.tokenEmbed },
{ category: "Position Embeddings", value: params.posEmbed, color: theme.posEmbed },
{ category: "Attention", value: params.attention, color: theme.attention },
{ category: "Feed-Forward", value: params.ffn, color: theme.ffn },
{ category: "LayerNorm", value: params.layerNorm, color: theme.layerNorm }
].filter(d => d.value > 0)
// Memory calculations
memoryFp32 = params.total * 4
memoryFp16 = params.total * 2
memoryInt8 = params.total * 1
// Head dimension check
headDim = embedDim / numHeads
headDimValid = embedDim % numHeads === 0
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
parameterDistributionChart = Plot.plot({
title: `Parameter Distribution: ${formatParams(params.total)} total`,
width: 600,
height: 300,
marginLeft: 120,
x: {
label: "Parameters →",
tickFormat: d => formatParams(d)
},
y: {
label: null
},
color: {
domain: pieData.map(d => d.category),
range: pieData.map(d => d.color)
},
marks: [
Plot.barX(pieData, {
y: "category",
x: "value",
fill: "category",
tip: true,
title: d => `${d.category}: ${formatParams(d.value)} (${(d.value / params.total * 100).toFixed(1)}%)`
}),
Plot.ruleX([0])
]
})
```
```{ojs}
//| echo: false
// Summary statistics
md`
| Component | Parameters | Percentage |
|-----------|------------|------------|
| Token Embeddings | ${formatParams(params.tokenEmbed)} | ${(params.tokenEmbed / params.total * 100).toFixed(1)}% |
| Position Embeddings | ${formatParams(params.posEmbed)} | ${(params.posEmbed / params.total * 100).toFixed(1)}% |
| Attention (${numLayers} layers) | ${formatParams(params.attention)} | ${(params.attention / params.total * 100).toFixed(1)}% |
| Feed-Forward (${numLayers} layers) | ${formatParams(params.ffn)} | ${(params.ffn / params.total * 100).toFixed(1)}% |
| LayerNorm | ${formatParams(params.layerNorm)} | ${(params.layerNorm / params.total * 100).toFixed(1)}% |
| **Total** | **${formatParams(params.total)}** | **100%** |
`
```
```{ojs}
//| echo: false
md`**Memory Requirements:** ${formatMemory(memoryFp32)} (fp32) | ${formatMemory(memoryFp16)} (fp16) | ${formatMemory(memoryInt8)} (int8)`
```
```{ojs}
//| echo: false
// Validation warning
headDimValid ? md`` : md`<span style="color: ${theme.warning}">⚠️ Warning: embed_dim (${embedDim}) is not divisible by num_heads (${numHeads}). Head dimension would be ${headDim.toFixed(2)}.</span>`
```
::: {.callout-tip}
## Try This
1. **FFN dominates**: Set embed_dim=768, layers=12. Notice the Feed-Forward bars are ~2x the Attention bars (because FFN has 8d² params vs Attention's 4d²).
2. **Embedding cost at small scale**: With vocab=50257 and embed_dim=768, token embeddings are ~38M params - a large fraction for small models.
3. **Scaling law**: Double embed_dim from 512 to 1024. Total params roughly quadruple (because most params scale with d²).
4. **Load GPT-2 presets** and see how the 117M, 345M, 774M models break down.
5. **Head dimension check**: Try numHeads that doesn't divide embedDim evenly - you'll see a warning.
:::
## Exercises
### Exercise 1: Build a Custom Block
```{python}
# Create a transformer block with different configurations
custom_block = TransformerBlock(
embed_dim=128,
num_heads=8,
ff_dim=512, # 4x expansion
dropout=0.1
)
# Test it
x = torch.randn(4, 16, 128) # batch=4, seq=16
output = custom_block(x)
print(f"Custom block: {x.shape} -> {output.shape}")
print(f"Parameters: {sum(p.numel() for p in custom_block.parameters()):,}")
```
### Exercise 2: Compare Model Scales
```{python}
# Compare tiny vs small model
tiny = create_gpt_tiny(vocab_size=10000)
small = create_gpt_small(vocab_size=10000)
print(f"{'Model':<10} {'Embed':<8} {'Layers':<8} {'Heads':<8} {'Params':<15}")
print("-" * 50)
print(f"{'Tiny':<10} {tiny.embed_dim:<8} {len(tiny.blocks):<8} {tiny.blocks[0].attention.mha.num_heads:<8} {tiny.num_params:,}")
print(f"{'Small':<10} {small.embed_dim:<8} {len(small.blocks):<8} {small.blocks[0].attention.mha.num_heads:<8} {small.num_params:,}")
```
### Exercise 3: Information Flow
```{python}
# See how a single token's representation changes through layers
model = create_gpt_tiny(vocab_size=100)
token_ids = torch.randint(0, 100, (1, 8))
_, hidden = model(token_ids, return_hidden_states=True)
# Track first token through layers
first_token_norms = [h[0, 0].norm().item() for h in hidden]
ex3_layer_names = ['Embed'] + [f'Block {i}' for i in range(len(model.blocks))]
print(f"First token embedding norm through {len(first_token_norms)} layers:")
for name, norm in zip(ex3_layer_names, first_token_norms):
print(f" {name}: {norm:.3f}")
```
```{python}
#| output: false
# Pass data to OJS for visualization
first_token_data = [{"layer": name, "norm": norm} for name, norm in zip(ex3_layer_names, first_token_norms)]
ojs_define(firstTokenNormsData=first_token_data)
```
```{ojs}
//| echo: false
// First token representation through layers
firstTokenChart = {
const data = typeof firstTokenNormsData !== 'undefined' ? firstTokenNormsData : [];
if (data.length === 0) return html`<div>Loading...</div>`;
const width = 600;
const height = 260;
const margin = { top: 45, right: 30, bottom: 70, left: 60 };
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const xScale = d3.scaleBand()
.domain(data.map(d => d.layer))
.range([0, innerWidth])
.padding(0.3);
const yScale = d3.scaleLinear()
.domain([0, d3.max(data, d => d.norm) * 1.15])
.range([innerHeight, 0]);
// Grid
const gridColor = theme.isDark ? "rgba(255,255,255,0.08)" : "rgba(0,0,0,0.08)";
g.selectAll("line.grid")
.data(yScale.ticks(5))
.join("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", d => yScale(d))
.attr("y2", d => yScale(d))
.attr("stroke", gridColor);
// Bars
g.selectAll("rect.bar")
.data(data)
.join("rect")
.attr("class", "bar")
.attr("x", d => xScale(d.layer))
.attr("y", d => yScale(d.norm))
.attr("width", xScale.bandwidth())
.attr("height", d => innerHeight - yScale(d.norm))
.attr("rx", 4)
.attr("fill", theme.accent)
.attr("opacity", 0.8);
// Value labels
g.selectAll("text.value")
.data(data)
.join("text")
.attr("class", "value")
.attr("x", d => xScale(d.layer) + xScale.bandwidth() / 2)
.attr("y", d => yScale(d.norm) - 6)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text(d => d.norm.toFixed(2));
// X axis
g.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("transform", "rotate(-30)")
.attr("text-anchor", "end")
.attr("dx", "-0.5em")
.attr("dy", "0.5em"));
// Y axis
g.append("g")
.call(d3.axisLeft(yScale).ticks(5))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "11px"));
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.text("Embedding Norm (First Token)");
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("First Token Representation Through Layers");
return svg.node();
}
```
## Summary
Key takeaways:
1. **Transformer architecture**: Input embeddings -> N transformer blocks -> Final LayerNorm -> Output projection
2. **Each block has two sublayers**:
- Multi-head attention (tokens communicate)
- Feed-forward network (tokens processed independently)
3. **Pre-Norm architecture**: LayerNorm before each sublayer, with a "clean" residual path for stable gradients
4. **Layer normalization**: Normalizes across the embedding dimension, keeping activations in a stable range
5. **Residual connections**: `x + f(x)` enables gradient flow through very deep networks (100+ layers)
6. **Feed-forward networks**: 4x expansion with GELU activation provides computational capacity
7. **Weight tying**: Sharing token embedding and output projection reduces parameters and improves performance
8. **Initialization matters**: Small initial weights (std=0.02) prevent exploding activations
9. **Parameter scaling**: Total params $\approx V \times d + 12Nd^2$ (dominated by FFN for large models)
10. **Architectural variations**: Modern LLMs (LLaMA, Mistral) use RMSNorm, RoPE, and SwiGLU for better efficiency
## What's Next
[Module 07: Training](../m07_training/lesson.qmd) trains our transformer on actual data using cross-entropy loss, learning rate scheduling, and gradient accumulation.