---
title: "Module 06: Transformer"
format:
  html:
    code-fold: false
    toc: true
  ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
## Introduction
The **transformer decoder block** is the building block of GPT-style language models. Stack 12-96 of these blocks, and you get models like GPT-2, GPT-3, or LLaMA.
In this module, we'll combine everything we've built so far:
- **Multi-head attention** from Module 05
- **Feed-forward networks** (mini neural networks for each token)
- **Layer normalization** (stabilizes training)
- **Residual connections** (enables deep networks)
Each block performs two main operations:
1. **Multi-head attention**: Tokens communicate with each other
2. **Feed-forward network**: Each token is processed independently
### Decoder-Only vs Encoder-Decoder
This module implements **decoder-only** transformers (GPT-style). There are two main transformer architectures:
| Architecture | Examples | Use Case | Attention |
|-------------|----------|----------|-----------|
| Decoder-only | GPT, LLaMA, Claude | Text generation | Causal (can't see future) |
| Encoder-Decoder | T5, BART, original Transformer | Translation, summarization | Bidirectional encoder + causal decoder |
We focus on decoder-only because it's simpler and powers most modern LLMs.
### What You'll Learn
By the end of this module, you will be able to:
- Understand the complete GPT-style transformer architecture
- Implement LayerNorm, GELU, and feed-forward networks from scratch
- Build a full transformer block with residual connections
- Assemble a complete language model from components
- Calculate parameter counts for different model sizes
## Complete Model Architecture
```{ojs}
//| echo: false
// Step definitions for the forward pass walkthrough
architectureSteps = [
{
id: 0,
name: "Input",
description: "Token IDs enter the model as integer indices into the vocabulary.",
shape: "[batch, seq_len]",
example: "[1, 4] → 4 token IDs"
},
{
id: 1,
name: "Token Embedding",
description: "Each token ID is mapped to a dense vector via learned embedding lookup.",
shape: "[batch, seq_len, embed_dim]",
example: "[1, 4, 128] → 4 vectors of dim 128"
},
{
id: 2,
name: "Position Embedding",
description: "Position information is added so the model knows token order.",
shape: "[batch, seq_len, embed_dim]",
example: "pos[0..3] added to each token"
},
{
id: 3,
name: "Transformer Blocks",
description: "N stacked blocks refine representations through attention and FFN.",
shape: "[batch, seq_len, embed_dim]",
example: "6 blocks × (attention + FFN)"
},
{
id: 4,
name: "Final LayerNorm",
description: "Normalize activations before the output projection.",
shape: "[batch, seq_len, embed_dim]",
example: "Stabilize for prediction"
},
{
id: 5,
name: "LM Head → Logits",
description: "Project to vocabulary size. Each position predicts the next token.",
shape: "[batch, seq_len, vocab_size]",
example: "[1, 4, 10000] → scores for each word"
}
]
```
```{ojs}
//| echo: false
// Step slider control
viewof archStep = Inputs.range([0, 5], {
value: 0,
step: 1,
label: "Forward Pass Step"
})
```
```{ojs}
//| echo: false
// Current step info
currentArchStep = architectureSteps[archStep]
```
```{ojs}
//| echo: false
// Draw the GPT architecture diagram
{
const width = 680;
const height = 620;
const marginLeft = 40;
const marginRight = 180;
const marginTop = 50;
const marginBottom = 30;
// Component dimensions
const nodeWidth = 140;
const nodeHeight = 44;
const blockWidth = 120;
const blockHeight = 36;
const numBlocks = 6;
const blockGap = 6;
// Vertical layout
const inputY = marginTop + 30;
const tokenEmbY = inputY + 70;
const posEmbY = tokenEmbY;
const addY = tokenEmbY + 65;
const blocksStartY = addY + 75;
const blocksEndY = blocksStartY + numBlocks * (blockHeight + blockGap) - blockGap;
const lnFinalY = blocksEndY + 65;
const lmHeadY = lnFinalY + 60;
const logitsY = lmHeadY + 65;
// Horizontal positions
const centerX = marginLeft + (width - marginLeft - marginRight) / 2;
const tokenEmbX = centerX - 80;
const posEmbX = centerX + 80;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Defs for arrows and filters
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "arch-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "arch-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "arch-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Color scheme for component types
const colors = {
input: diagramTheme.accent,
embedding: "#6366f1", // Indigo
blocks: "#059669", // Emerald
output: "#dc2626" // Red
};
// Helper to draw a node
function drawNode(x, y, w, h, label, sublabel, stepId, colorType) {
const isActive = archStep === stepId;
const baseColor = colors[colorType] || diagramTheme.nodeFill;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", isActive ? baseColor : diagramTheme.nodeFill)
.attr("stroke", isActive ? baseColor : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("opacity", isActive ? 1 : 0.85)
.style("filter", isActive ? "url(#arch-glow)" : "none");
const textColor = isActive ? "#ffffff" : diagramTheme.nodeText;
if (sublabel) {
g.append("text")
.attr("y", -6)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "9px")
.attr("opacity", 0.8)
.text(sublabel);
} else {
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
}
return g;
}
// Helper to draw an arrow
function drawArrow(x1, y1, x2, y2, isActive) {
svg.append("path")
.attr("d", `M${x1},${y1} L${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", isActive ? "url(#arch-arrow-active)" : "url(#arch-arrow)")
.attr("opacity", isActive ? 1 : 0.6);
}
// Title
svg.append("text")
.attr("x", centerX)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "700")
.text("GPT Architecture: Forward Pass");
// === Draw components ===
// Input tokens
drawNode(centerX, inputY, nodeWidth, nodeHeight, "Token IDs", "[23, 156, 42, 789]", 0, "input");
// Arrows from input to embeddings
const splitY = inputY + nodeHeight/2 + 10;
drawArrow(centerX - 20, inputY + nodeHeight/2, tokenEmbX, tokenEmbY - nodeHeight/2 - 8, archStep === 0 || archStep === 1);
drawArrow(centerX + 20, inputY + nodeHeight/2, posEmbX, posEmbY - nodeHeight/2 - 8, archStep === 0 || archStep === 1);
// Token Embedding
drawNode(tokenEmbX, tokenEmbY, nodeWidth, nodeHeight, "Token Embed", "(vocab, embed_dim)", 1, "embedding");
// Position Embedding
drawNode(posEmbX, posEmbY, nodeWidth, nodeHeight, "Pos Embed", "(max_seq, embed_dim)", 2, "embedding");
// Arrows to Add
drawArrow(tokenEmbX, tokenEmbY + nodeHeight/2, centerX - 10, addY - 18 - 8, archStep === 1 || archStep === 2);
drawArrow(posEmbX, posEmbY + nodeHeight/2, centerX + 10, addY - 18 - 8, archStep === 2);
// Add circle
const addActive = archStep === 1 || archStep === 2;
svg.append("circle")
.attr("cx", centerX)
.attr("cy", addY)
.attr("r", 18)
.attr("fill", addActive ? colors.embedding : diagramTheme.nodeFill)
.attr("stroke", addActive ? colors.embedding : diagramTheme.nodeStroke)
.attr("stroke-width", addActive ? 2 : 1.5)
.style("filter", addActive ? "url(#arch-glow)" : "none");
svg.append("text")
.attr("x", centerX)
.attr("y", addY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", addActive ? "#ffffff" : diagramTheme.nodeText)
.attr("font-size", "16px")
.attr("font-weight", "bold")
.text("+");
// Arrow to blocks
drawArrow(centerX, addY + 18, centerX, blocksStartY - blockHeight/2 - 8, archStep === 2 || archStep === 3);
// Transformer Blocks (stacked)
const blocksActive = archStep === 3;
// Background container for blocks
const blocksContainerPad = 12;
const blocksContainerHeight = numBlocks * (blockHeight + blockGap) - blockGap + blocksContainerPad * 2;
svg.append("rect")
.attr("x", centerX - blockWidth/2 - blocksContainerPad)
.attr("y", blocksStartY - blockHeight/2 - blocksContainerPad)
.attr("width", blockWidth + blocksContainerPad * 2)
.attr("height", blocksContainerHeight)
.attr("rx", 8)
.attr("fill", "none")
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.nodeStroke)
.attr("stroke-width", blocksActive ? 2 : 1)
.attr("stroke-dasharray", blocksActive ? "none" : "4,3")
.attr("opacity", blocksActive ? 1 : 0.5);
// Label for blocks section
svg.append("text")
.attr("x", centerX + blockWidth/2 + blocksContainerPad + 10)
.attr("y", blocksStartY + blocksContainerHeight/2 - blockHeight/2 - blocksContainerPad)
.attr("text-anchor", "start")
.attr("dominant-baseline", "central")
.attr("fill", blocksActive ? colors.blocks : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", blocksActive ? "600" : "400")
.attr("opacity", blocksActive ? 1 : 0.6)
.text("× N Blocks");
// Individual blocks
for (let i = 0; i < numBlocks; i++) {
const blockY = blocksStartY + i * (blockHeight + blockGap);
const g = svg.append("g")
.attr("transform", `translate(${centerX}, ${blockY})`);
g.append("rect")
.attr("x", -blockWidth/2)
.attr("y", -blockHeight/2)
.attr("width", blockWidth)
.attr("height", blockHeight)
.attr("rx", 5)
.attr("fill", blocksActive ? colors.blocks : diagramTheme.nodeFill)
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.nodeStroke)
.attr("stroke-width", blocksActive ? 2 : 1)
.attr("opacity", blocksActive ? 1 : 0.7);
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", blocksActive ? "#ffffff" : diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(`Block ${i}`);
// Arrow between blocks (except last)
if (i < numBlocks - 1) {
const nextY = blockY + blockHeight/2 + blockGap/2;
svg.append("line")
.attr("x1", centerX)
.attr("y1", blockY + blockHeight/2)
.attr("x2", centerX)
.attr("y2", blockY + blockHeight + blockGap - blockHeight/2)
.attr("stroke", blocksActive ? colors.blocks : diagramTheme.edgeStroke)
.attr("stroke-width", 1)
.attr("opacity", blocksActive ? 0.8 : 0.4);
}
}
// Arrow to Final LayerNorm
drawArrow(centerX, blocksEndY + blockHeight/2, centerX, lnFinalY - nodeHeight/2 - 8, archStep === 3 || archStep === 4);
// Final LayerNorm
drawNode(centerX, lnFinalY, nodeWidth, nodeHeight, "Final LayerNorm", "Normalize", 4, "output");
// Arrow to LM Head
drawArrow(centerX, lnFinalY + nodeHeight/2, centerX, lmHeadY - nodeHeight/2 - 8, archStep === 4 || archStep === 5);
// LM Head
drawNode(centerX, lmHeadY, nodeWidth, nodeHeight, "LM Head", "Linear → vocab_size", 5, "output");
// Arrow to Logits
drawArrow(centerX, lmHeadY + nodeHeight/2, centerX, logitsY - nodeHeight/2 - 8, archStep === 5);
// Logits
drawNode(centerX, logitsY, nodeWidth, nodeHeight, "Logits", "[batch, seq, vocab]", 5, "output");
// === Info panel on the right ===
const infoX = width - marginRight + 20;
const infoY = marginTop + 50;
const infoWidth = marginRight - 30;
// Info panel background
svg.append("rect")
.attr("x", infoX - 10)
.attr("y", infoY - 10)
.attr("width", infoWidth)
.attr("height", 150)
.attr("rx", 8)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
// Step indicator
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 5)
.attr("fill", diagramTheme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "600")
.attr("text-transform", "uppercase")
.attr("letter-spacing", "0.5px")
.text(`Step ${archStep + 1} of 6`);
// Step name
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 28)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "700")
.text(currentArchStep.name);
// Description (wrap text)
const desc = currentArchStep.description;
const descWords = desc.split(" ");
let descLine = "";
let descY = infoY + 50;
const maxDescWidth = infoWidth - 20;
descWords.forEach((word, i) => {
const testLine = descLine + (descLine ? " " : "") + word;
if (testLine.length > 22) {
svg.append("text")
.attr("x", infoX)
.attr("y", descY)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.8)
.text(descLine);
descY += 14;
descLine = word;
} else {
descLine = testLine;
}
});
if (descLine) {
svg.append("text")
.attr("x", infoX)
.attr("y", descY)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.8)
.text(descLine);
}
// Shape info
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 110)
.attr("fill", diagramTheme.accent)
.attr("font-size", "9px")
.attr("font-weight", "500")
.text("Shape:");
svg.append("text")
.attr("x", infoX)
.attr("y", infoY + 124)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-family", "'JetBrains Mono', monospace")
.text(currentArchStep.shape);
return svg.node();
}
```
::: {.callout-tip}
## Interactive Architecture Walkthrough
Use the slider above to step through the forward pass. Each stage shows how tensor shapes transform as data flows through the model:
- **Input**: Raw token IDs (integers)
- **Embeddings**: Dense vectors capturing meaning and position
- **Blocks**: Iterative refinement through attention and FFN
- **Output**: Probability distribution over vocabulary
:::
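The same shape flow can be traced in a few lines of PyTorch. This is a minimal sketch with illustrative sizes (vocabulary of 10,000, embedding dimension 128) and an `nn.Identity` standing in for the block stack; the real model built later in this module replaces that placeholder with actual transformer blocks.
```python
# Minimal sketch of the forward-pass shapes. Sizes are illustrative and the
# block stack is an identity placeholder, so only the tensor shapes matter here.
import torch
import torch.nn as nn

vocab_size, embed_dim, max_seq_len = 10_000, 128, 64

tok_emb = nn.Embedding(vocab_size, embed_dim)      # token embedding table
pos_emb = nn.Embedding(max_seq_len, embed_dim)     # learned position embeddings
blocks = nn.Identity()                             # stand-in for N transformer blocks
ln_f = nn.LayerNorm(embed_dim)                     # final LayerNorm
lm_head = nn.Linear(embed_dim, vocab_size)         # project to vocabulary logits

token_ids = torch.tensor([[23, 156, 42, 789]])     # [batch=1, seq_len=4]
x = tok_emb(token_ids)                             # [1, 4, 128]
x = x + pos_emb(torch.arange(token_ids.size(1)))   # add position info, still [1, 4, 128]
x = blocks(x)                                      # every block preserves this shape
logits = lm_head(ln_f(x))                          # [1, 4, 10000]
print(token_ids.shape, x.shape, logits.shape)
```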
## Single Transformer Block (Pre-Norm)
```{ojs}
//| echo: false
// Step definitions for transformer block walkthrough
blockSteps = [
{
id: 0,
name: "Input",
description: "Embeddings enter the transformer block unchanged initially.",
highlight: ["input"],
activeEdges: []
},
{
id: 1,
name: "Attention Path",
description: "Pre-Norm: LayerNorm first, then Multi-Head Attention with causal mask, then Dropout.",
highlight: ["ln1", "attn", "drop1"],
activeEdges: ["e_in_ln1", "e_ln1_attn", "e_attn_drop1"]
},
{
id: 2,
name: "First Residual",
description: "Add original input to attention output: x + Attention(LayerNorm(x)). Skip connection preserves gradients.",
highlight: ["add1", "residual1"],
activeEdges: ["e_in_add1", "e_drop1_add1"]
},
{
id: 3,
name: "FFN Path",
description: "LayerNorm, then Feed-Forward Network (expand 4x, GELU, project back). Each token processed independently.",
highlight: ["ln2", "ffn"],
activeEdges: ["e_add1_ln2", "e_ln2_ffn"]
},
{
id: 4,
name: "Second Residual",
description: "Add intermediate to FFN output: x + FFN(LayerNorm(x)). Two residuals per block enable deep stacking.",
highlight: ["add2", "residual2", "output"],
activeEdges: ["e_add1_add2", "e_ffn_add2", "e_add2_out"]
}
]
```
```{ojs}
//| echo: false
// Step slider control for transformer block
viewof blockStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
// Current step info
currentBlockStep = blockSteps[blockStep]
```
```{ojs}
//| echo: false
// Transformer block diagram
{
const width = 720;
const height = 580;
const theme = diagramTheme;
const step = blockStep;
const stepInfo = currentBlockStep;
const svg = d3.create("svg")
.attr("viewBox", `0 0 ${width} ${height}`)
.attr("width", "100%")
.attr("height", height)
.style("max-width", `${width}px`)
.style("font-family", "'IBM Plex Mono', 'JetBrains Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Defs for filters and markers
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "block-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "block-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.edgeStroke);
defs.append("marker")
.attr("id", "block-arrow-attn")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#22d3ee");
defs.append("marker")
.attr("id", "block-arrow-ffn")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#a78bfa");
defs.append("marker")
.attr("id", "block-arrow-res")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", "#fb923c");
// Color scheme
const colors = {
attention: "#22d3ee", // Cyan for attention path
ffn: "#a78bfa", // Purple for FFN path
residual: "#fb923c", // Orange for residual connections
norm: "#6b7280", // Gray for LayerNorm
io: theme.accent // Theme accent for input/output
};
// Helper: check if element is active
const isActive = (id) => stepInfo.highlight.includes(id);
const isEdgeActive = (id) => stepInfo.activeEdges.includes(id);
// Layout
const centerX = 280;
const rightX = 500;
const nodeWidth = 140;
const nodeHeight = 44;
// Vertical positions
const inputY = 80;
const ln1Y = 160;
const attnY = 230;
const drop1Y = 300;
const add1Y = 365;
const ln2Y = 430;
const ffnY = 490;
const add2Y = 490;
const outputY = 550;
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "15px")
.attr("font-weight", "700")
.text(`Transformer Block: ${stepInfo.name}`);
svg.append("text")
.attr("x", width / 2)
.attr("y", 50)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.75)
.text(stepInfo.description);
// Helper: draw a node box
function drawNode(x, y, label, sublabel, id, color) {
const active = isActive(id);
const g = svg.append("g").attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -nodeWidth/2)
.attr("y", -nodeHeight/2)
.attr("width", nodeWidth)
.attr("height", nodeHeight)
.attr("rx", 6)
.attr("fill", active ? color : theme.nodeFill)
.attr("fill-opacity", active ? 0.25 : 1)
.attr("stroke", active ? color : theme.nodeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("filter", active ? "url(#block-glow)" : null);
g.append("text")
.attr("y", sublabel ? -7 : 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", active ? "600" : "500")
.text(label);
if (sublabel) {
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "9px")
.attr("opacity", active ? 0.9 : 0.6)
.text(sublabel);
}
}
// Helper: draw a circle node (for + operations)
function drawCircle(x, y, label, id, color) {
const active = isActive(id);
const g = svg.append("g").attr("transform", `translate(${x}, ${y})`);
g.append("circle")
.attr("r", 20)
.attr("fill", active ? color : theme.nodeFill)
.attr("fill-opacity", active ? 0.3 : 1)
.attr("stroke", active ? color : theme.nodeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("filter", active ? "url(#block-glow)" : null);
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? color : theme.nodeText)
.attr("font-size", "18px")
.attr("font-weight", "bold")
.text(label);
}
// Helper: draw an arrow
function drawArrow(x1, y1, x2, y2, edgeId, color, curved = false, curveDir = 1) {
const active = isEdgeActive(edgeId);
const markerColor = color === colors.attention ? "block-arrow-attn" :
color === colors.ffn ? "block-arrow-ffn" :
color === colors.residual ? "block-arrow-res" : "block-arrow";
let pathD;
if (curved) {
const midX = (x1 + x2) / 2 + curveDir * 40;
pathD = `M${x1},${y1} Q${midX},${(y1+y2)/2} ${x2},${y2}`;
} else {
pathD = `M${x1},${y1} L${x2},${y2}`;
}
svg.append("path")
.attr("d", pathD)
.attr("fill", "none")
.attr("stroke", active ? color : theme.edgeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("stroke-opacity", active ? 1 : 0.5)
.attr("marker-end", active ? `url(#${markerColor})` : "url(#block-arrow)")
.attr("filter", active ? "url(#block-glow)" : null);
}
// === Draw the architecture ===
// Input node
drawNode(centerX, inputY, "Input x", "(batch, seq, embed)", "input", colors.io);
// --- Attention Path (left side) ---
// LayerNorm 1
drawNode(centerX, ln1Y, "LayerNorm", null, "ln1", colors.attention);
// Multi-Head Attention
drawNode(centerX, attnY, "Multi-Head Attention", "(causal mask)", "attn", colors.attention);
// Dropout
drawNode(centerX, drop1Y, "Dropout", null, "drop1", colors.attention);
// Arrows in attention path
drawArrow(centerX, inputY + nodeHeight/2, centerX, ln1Y - nodeHeight/2 - 8, "e_in_ln1", colors.attention);
drawArrow(centerX, ln1Y + nodeHeight/2, centerX, attnY - nodeHeight/2 - 8, "e_ln1_attn", colors.attention);
drawArrow(centerX, attnY + nodeHeight/2, centerX, drop1Y - nodeHeight/2 - 8, "e_attn_drop1", colors.attention);
// First residual add
drawCircle(centerX, add1Y, "+", "add1", colors.residual);
// Arrow from dropout to add1
drawArrow(centerX, drop1Y + nodeHeight/2, centerX, add1Y - 20 - 8, "e_drop1_add1", colors.attention);
// Residual connection 1 (skip from input to add1)
const res1Active = isActive("residual1");
// Draw the skip connection path on the right
svg.append("path")
.attr("d", `M${centerX + nodeWidth/2},${inputY} L${centerX + 90},${inputY} L${centerX + 90},${add1Y} L${centerX + 20},${add1Y}`)
.attr("fill", "none")
.attr("stroke", res1Active ? colors.residual : theme.edgeStroke)
.attr("stroke-width", res1Active ? 2.5 : 1.5)
.attr("stroke-dasharray", "6,3")
.attr("stroke-opacity", res1Active ? 1 : 0.4)
.attr("marker-end", res1Active ? "url(#block-arrow-res)" : "url(#block-arrow)")
.attr("filter", res1Active ? "url(#block-glow)" : null);
// Residual label
if (res1Active) {
svg.append("text")
.attr("x", centerX + 100)
.attr("y", (inputY + add1Y) / 2)
.attr("text-anchor", "start")
.attr("fill", colors.residual)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("residual");
}
// --- FFN Path ---
// LayerNorm 2
drawNode(centerX, ln2Y, "LayerNorm", null, "ln2", colors.ffn);
// Feed-Forward Network
drawNode(rightX, ffnY, "Feed-Forward", "(4x expand, GELU)", "ffn", colors.ffn);
// Arrow from add1 to ln2
drawArrow(centerX, add1Y + 20, centerX, ln2Y - nodeHeight/2 - 8, "e_add1_ln2", colors.ffn);
// Arrow from ln2 to ffn (horizontal then down)
const ln2ToFfnActive = isEdgeActive("e_ln2_ffn");
svg.append("path")
.attr("d", `M${centerX + nodeWidth/2},${ln2Y} L${rightX},${ln2Y} L${rightX},${ffnY - nodeHeight/2 - 8}`)
.attr("fill", "none")
.attr("stroke", ln2ToFfnActive ? colors.ffn : theme.edgeStroke)
.attr("stroke-width", ln2ToFfnActive ? 2.5 : 1.5)
.attr("stroke-opacity", ln2ToFfnActive ? 1 : 0.5)
.attr("marker-end", ln2ToFfnActive ? "url(#block-arrow-ffn)" : "url(#block-arrow)")
.attr("filter", ln2ToFfnActive ? "url(#block-glow)" : null);
// Second residual add
drawCircle(centerX, add2Y, "+", "add2", colors.residual);
// Arrow from ffn to add2
const ffnToAddActive = isEdgeActive("e_ffn_add2");
svg.append("path")
.attr("d", `M${rightX - nodeWidth/2},${ffnY} L${centerX + 20},${add2Y}`)
.attr("fill", "none")
.attr("stroke", ffnToAddActive ? colors.ffn : theme.edgeStroke)
.attr("stroke-width", ffnToAddActive ? 2.5 : 1.5)
.attr("stroke-opacity", ffnToAddActive ? 1 : 0.5)
.attr("marker-end", ffnToAddActive ? "url(#block-arrow-ffn)" : "url(#block-arrow)")
.attr("filter", ffnToAddActive ? "url(#block-glow)" : null);
// Residual connection 2 (from add1 to add2 - the skip)
const res2Active = isActive("residual2");
svg.append("path")
.attr("d", `M${centerX - 20},${add1Y} L${centerX - 60},${add1Y} L${centerX - 60},${add2Y} L${centerX - 20},${add2Y}`)
.attr("fill", "none")
.attr("stroke", res2Active ? colors.residual : theme.edgeStroke)
.attr("stroke-width", res2Active ? 2.5 : 1.5)
.attr("stroke-dasharray", "6,3")
.attr("stroke-opacity", res2Active ? 1 : 0.4)
.attr("marker-end", res2Active ? "url(#block-arrow-res)" : "url(#block-arrow)")
.attr("filter", res2Active ? "url(#block-glow)" : null);
// Residual 2 label
if (res2Active) {
svg.append("text")
.attr("x", centerX - 70)
.attr("y", (add1Y + add2Y) / 2)
.attr("text-anchor", "end")
.attr("fill", colors.residual)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("residual");
}
// Output node
drawNode(centerX, outputY, "Output", "(batch, seq, embed)", "output", colors.io);
// Arrow from add2 to output
drawArrow(centerX, add2Y + 20, centerX, outputY - nodeHeight/2 - 8, "e_add2_out", colors.io);
// === Legend ===
const legendX = 600;
const legendY = 120;
svg.append("text")
.attr("x", legendX)
.attr("y", legendY)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("Legend");
// Attention path
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY + 15)
.attr("width", 14)
.attr("height", 14)
.attr("rx", 3)
.attr("fill", colors.attention)
.attr("fill-opacity", 0.3)
.attr("stroke", colors.attention);
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 25)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Attention path");
// FFN path
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY + 38)
.attr("width", 14)
.attr("height", 14)
.attr("rx", 3)
.attr("fill", colors.ffn)
.attr("fill-opacity", 0.3)
.attr("stroke", colors.ffn);
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 48)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("FFN path");
// Residual
svg.append("line")
.attr("x1", legendX)
.attr("y1", legendY + 68)
.attr("x2", legendX + 14)
.attr("y2", legendY + 68)
.attr("stroke", colors.residual)
.attr("stroke-width", 2)
.attr("stroke-dasharray", "4,2");
svg.append("text")
.attr("x", legendX + 20)
.attr("y", legendY + 71)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.text("Residual (skip)");
// Pre-Norm annotation
svg.append("rect")
.attr("x", legendX - 10)
.attr("y", legendY + 95)
.attr("width", 110)
.attr("height", 50)
.attr("rx", 6)
.attr("fill", theme.bgSecondary)
.attr("stroke", theme.nodeStroke)
.attr("stroke-width", 1);
svg.append("text")
.attr("x", legendX)
.attr("y", legendY + 113)
.attr("fill", theme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("Pre-Norm Pattern:");
svg.append("text")
.attr("x", legendX)
.attr("y", legendY + 130)
.attr("fill", theme.nodeText)
.attr("font-size", "9px")
.attr("font-family", "'IBM Plex Mono', monospace")
.text("x + f(LayerNorm(x))");
return svg.node();
}
```
The key ingredients are the **residual connections** (the + nodes). Instead of `y = f(x)`, we compute `y = x + f(x)`. This:
- Helps gradients flow through deep networks
- Makes it easy to learn identity (just set `f(x) = 0`)
- Enables training of 100+ layer networks
## The Components
In this section, we build each component from scratch before showing the PyTorch equivalents. The pattern is: understand the math, implement it simply, then see how PyTorch optimizes it.
### LayerNorm from Scratch
**The Idea**: Neural network activations can drift to extreme values during training, causing gradients to explode or vanish. Layer normalization fixes this by normalizing each token's embedding to have mean 0 and variance 1, then applying learnable scale and shift parameters.
**The Formula**:
$$\text{LayerNorm}(x) = \gamma \times \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where:
- $\mu$ = mean across the embedding dimension
- $\sigma^2$ = variance across the embedding dimension
- $\gamma$ (gamma) = learnable scale parameter (initialized to 1)
- $\beta$ (beta) = learnable shift parameter (initialized to 0)
- $\epsilon$ = small constant for numerical stability (typically 1e-5)
**From Scratch Implementation**:
```{python}
import numpy as np
import torch
import torch.nn as nn
class LayerNormScratch:
"""Layer normalization from scratch using NumPy-style operations."""
def __init__(self, dim, eps=1e-5):
# Learnable parameters
self.gamma = np.ones((dim,)) # scale (initialized to 1)
self.beta = np.zeros((dim,)) # shift (initialized to 0)
self.eps = eps
def __call__(self, x):
"""
Args:
x: input array of shape (..., dim)
Returns:
normalized array of same shape
"""
# Step 1: Compute mean across last dimension
mean = x.mean(axis=-1, keepdims=True)
# Step 2: Compute variance across last dimension
var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
# Step 3: Normalize (the "norm" in LayerNorm)
x_norm = (x - mean) / np.sqrt(var + self.eps)
# Step 4: Scale and shift with learnable parameters
return self.gamma * x_norm + self.beta
# Test our from-scratch implementation
x = np.array([[2.0, 4.0, 6.0, 8.0],
[1.0, 2.0, 3.0, 4.0]])
ln_scratch = LayerNormScratch(dim=4)
out_scratch = ln_scratch(x)
print("LayerNorm from Scratch:")
print(f" Input:\n{x}")
print(f" Output:\n{np.round(out_scratch, 4)}")
print(f" Output mean per row: {out_scratch.mean(axis=-1).round(6)}")
print(f" Output std per row: {out_scratch.std(axis=-1).round(4)}")
```
**PyTorch's nn.LayerNorm**:
```{python}
# PyTorch's optimized implementation
ln_pytorch = nn.LayerNorm(4, elementwise_affine=True)
# Initialize to match our scratch version (gamma=1, beta=0)
nn.init.ones_(ln_pytorch.weight)
nn.init.zeros_(ln_pytorch.bias)
x_torch = torch.tensor(x, dtype=torch.float32)
out_pytorch = ln_pytorch(x_torch)
print("PyTorch LayerNorm:")
print(f" Output:\n{out_pytorch.detach().numpy().round(4)}")
print(f" Matches scratch: {np.allclose(out_scratch, out_pytorch.detach().numpy(), atol=1e-5)}")
```
**Key Insight**: LayerNorm is just normalize-scale-shift. The learnable $\gamma$ and $\beta$ let the network undo the normalization if needed, but start from a stable baseline. Unlike BatchNorm, LayerNorm normalizes across features (embedding dimension) rather than across batch, making it suitable for variable-length sequences.
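To make that last contrast concrete, here is a small sketch on an illustrative `(batch, seq, embed)` tensor showing which axis each normalization reduces over (the transpose is only there because `nn.BatchNorm1d` expects features in the second dimension).
```python
# Sketch: LayerNorm normalizes each token over the embedding axis;
# BatchNorm normalizes each feature over the batch (and sequence) axis.
import torch
import torch.nn as nn

x = torch.randn(2, 8, 16)                        # (batch, seq, embed)

ln = nn.LayerNorm(16)
print(ln(x).mean(dim=-1).abs().max())            # ~0: every token normalized on its own

bn = nn.BatchNorm1d(16)                          # expects (batch, features, seq)
bn_out = bn(x.transpose(1, 2)).transpose(1, 2)
print(bn_out.mean(dim=(0, 1)).abs().max())       # ~0: each feature normalized across batch+seq
```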
### Dropout: Regularization by Noise
**The Idea**: During training, randomly "drop" (zero out) some activations. This prevents the network from relying too heavily on any single feature and encourages redundancy. The key trick: scale remaining values by $\frac{1}{1-p}$ so the expected value stays the same.
**Why it works**:
- Forces the network to learn redundant representations
- Acts like training an ensemble of sub-networks
- At inference time, use all neurons (no dropout)
**From Scratch Implementation**:
```{python}
class DropoutScratch:
"""Dropout from scratch."""
def __init__(self, p=0.1):
"""
Args:
p: probability of dropping each element (not keeping!)
"""
self.p = p
def __call__(self, x, training=True):
"""
Args:
x: input array
training: if False, return x unchanged
Returns:
x with dropout applied (if training)
"""
if not training or self.p == 0:
return x
# Create random mask: True where we KEEP the value
keep_prob = 1 - self.p
mask = np.random.random(x.shape) < keep_prob
# Apply mask and scale by 1/(1-p)
# This keeps the expected value the same:
# E[x * mask / keep_prob] = x * keep_prob / keep_prob = x
return x * mask / keep_prob
# Demonstrate dropout
np.random.seed(42)
x = np.ones((2, 8))
dropout = DropoutScratch(p=0.5)
print("Dropout from Scratch (p=0.5):")
print(f" Input (all 1s): {x[0]}")
# Apply dropout multiple times to see the randomness
for i in range(3):
np.random.seed(i)
out = dropout(x, training=True)
print(f" Trial {i+1}: {out[0].round(2)}")
print(f" Mean: {out[0].mean():.2f} (should be ~1.0 on average)")
```
**The Scaling Trick Explained**:
```{python}
# Why divide by (1-p)?
# Without scaling, dropout reduces expected output
# With scaling, expected output stays the same
p = 0.5
np.random.seed(0)
x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
# Without scaling
mask = np.random.random(x.shape) < (1-p)
out_no_scale = x * mask
print(f"Without scaling: {out_no_scale} -> mean = {out_no_scale.mean():.2f}")
# With scaling (divide by keep probability)
np.random.seed(0)
mask = np.random.random(x.shape) < (1-p)
out_scaled = x * mask / (1-p)
print(f"With scaling: {out_scaled} -> mean = {out_scaled.mean():.2f}")
print(f"\nThe scaling keeps expected value at 1.0 despite dropping 50% of values")
```
**PyTorch's nn.Dropout**:
```{python}
# PyTorch handles training/mode automatically
dropout_pytorch = nn.Dropout(p=0.5)
x_torch = torch.ones(2, 8)
# Training mode (dropout active)
dropout_pytorch.train()
torch.manual_seed(42)
out_train = dropout_pytorch(x_torch)
print(f"PyTorch Dropout (training): {out_train[0].numpy()}")
# Inference mode (dropout disabled)
dropout_pytorch.eval()
out_inference = dropout_pytorch(x_torch)
print(f"PyTorch Dropout (inference): {out_inference[0].numpy()}")
```
**Key Insight**: Dropout is just random masking with scaling. The $\frac{1}{1-p}$ factor during training means we do not need to modify anything at inference time - the expected value is already correct.
### Residual Connections: The Highway for Gradients
**The Idea**: Instead of computing $y = f(x)$, compute $y = x + f(x)$. This "skip connection" lets gradients flow directly through the network, solving the vanishing gradient problem in deep networks.
**Why it helps**:
```{python}
import matplotlib.pyplot as plt
# Visualize gradient flow with and without residuals
def gradient_flow_demo():
"""Show how residuals help gradients in deep networks."""
# Simulate a function that shrinks gradients (common in deep nets)
def layer_gradient(g, shrink=0.8):
return g * shrink
# Without residual: gradients multiply
# d(f(f(f(x))))/dx = f'(x) * f'(f(x)) * f'(f(f(x)))
gradients_no_residual = [1.0]
for _ in range(20):
gradients_no_residual.append(layer_gradient(gradients_no_residual[-1]))
# With residual: gradients add
# d(x + f(x))/dx = 1 + f'(x) (the 1 always flows through!)
gradients_with_residual = [1.0]
for _ in range(20):
# Gradient through residual = 1 (skip) + shrink (through f)
g = 1.0 + layer_gradient(gradients_with_residual[-1]) * 0.1
gradients_with_residual.append(min(g, gradients_with_residual[-1] * 1.05))
plt.figure(figsize=(10, 5))
plt.semilogy(gradients_no_residual, 'b-o', label='Without residual', markersize=4)
plt.semilogy(gradients_with_residual, 'r-o', label='With residual', markersize=4)
plt.xlabel('Layer depth')
plt.ylabel('Gradient magnitude (log scale)')
plt.title('Residual Connections Prevent Vanishing Gradients')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print("Without residuals: gradients vanish exponentially")
print(f" After 20 layers: {gradients_no_residual[-1]:.6f}")
print("With residuals: gradients stay healthy")
print(f" After 20 layers: {gradients_with_residual[-1]:.2f}")
gradient_flow_demo()
```
**Pre-Norm vs Post-Norm**:
The original Transformer used "post-norm": normalize after the residual addition.
```python
# Post-Norm (original Transformer)
x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))
```
Modern LLMs use "pre-norm": normalize before each sublayer.
```python
# Pre-Norm (GPT-2, LLaMA, modern LLMs)
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
```
**Why Pre-Norm is Better**:
```{python}
# Demonstrate the stability difference
def simulate_forward_pass(num_layers, prenorm=True):
"""Simulate activation magnitudes through layers."""
x = 1.0 # Starting activation magnitude
for _ in range(num_layers):
if prenorm:
# Pre-norm: normalize first, then residual keeps things bounded
normed = 1.0 # After LayerNorm, magnitude is ~1
sublayer_out = normed * 0.5 # Sublayer output
x = x + sublayer_out # Residual addition
else:
# Post-norm: residual can grow, then we normalize
sublayer_out = x * 0.5
x = x + sublayer_out # Can grow unboundedly before norm
x = 1.0 # LayerNorm resets to ~1
return x
print("Activation stability comparison:")
print(f" Pre-norm after 24 layers: ~{simulate_forward_pass(24, prenorm=True):.1f}")
print(f" Post-norm after 24 layers: ~{simulate_forward_pass(24, prenorm=False):.1f}")
print("\nPre-norm has a cleaner gradient path because the skip connection")
print("bypasses normalization - gradients flow directly from output to input.")
```
**Key Insight**: Residual connections transform `y = f(x)` into `y = x + f(x)`. The gradient of this is `dy/dx = 1 + df/dx`. That `+ 1` is crucial - it means gradients always have a direct path through the network, even if `df/dx` is tiny.
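A quick autograd check makes that `+ 1` visible. The toy sublayer `f(x) = 0.1x` below is purely illustrative:
```python
# Autograd check of dy/dx = 1 + df/dx for a residual connection,
# using a toy sublayer f(x) = 0.1 * x so that df/dx = 0.1.
import torch

x = torch.tensor(2.0, requires_grad=True)

def f(t):               # toy sublayer with df/dt = 0.1
    return 0.1 * t

y = f(x)                # y = f(x)      -> dy/dx = 0.1
y.backward()
print(x.grad)           # tensor(0.1000)

x.grad = None
y = x + f(x)            # y = x + f(x)  -> dy/dx = 1 + 0.1
y.backward()
print(x.grad)           # tensor(1.1000): the skip path contributes the 1
```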
### The Full Transformer Block from Scratch
**The Idea**: Now we assemble all the pieces into a complete transformer block:
1. LayerNorm + Multi-Head Attention + Residual
2. LayerNorm + Feed-Forward Network + Residual
```{python}
class FeedForwardScratch:
"""Simple feed-forward network from scratch."""
def __init__(self, embed_dim, ff_dim):
# Initialize weights with small random values
scale = 0.02
self.w1 = np.random.randn(embed_dim, ff_dim) * scale
self.b1 = np.zeros(ff_dim)
self.w2 = np.random.randn(ff_dim, embed_dim) * scale
self.b2 = np.zeros(embed_dim)
def gelu(self, x):
"""GELU activation: x * Phi(x) where Phi is standard normal CDF."""
return 0.5 * x * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))
def __call__(self, x):
# Up projection: embed_dim -> ff_dim
h = x @ self.w1 + self.b1
# Activation
h = self.gelu(h)
# Down projection: ff_dim -> embed_dim
return h @ self.w2 + self.b2
# NOTE: This uses a SIMPLIFIED attention (just linear projection) to focus on
# the overall block structure. Real attention with Q, K, V is in attention.py
class TransformerBlockScratch:
"""
A complete transformer block from scratch (with simplified attention).
Architecture (Pre-Norm):
x = x + Attention(LayerNorm(x))
x = x + FeedForward(LayerNorm(x))
WARNING: The attention here is simplified to a linear projection for
demonstration purposes. See m05_attention for full attention implementation.
"""
def __init__(self, embed_dim, num_heads, ff_dim, dropout_p=0.1):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Layer norms
self.ln1 = LayerNormScratch(embed_dim)
self.ln2 = LayerNormScratch(embed_dim)
# Attention projections (simplified: no actual attention computation)
# In a full implementation, this would include Q, K, V projections
scale = 0.02
self.attn_proj = np.random.randn(embed_dim, embed_dim) * scale
# Feed-forward network
self.ff = FeedForwardScratch(embed_dim, ff_dim)
# Dropout
self.dropout = DropoutScratch(dropout_p)
def __call__(self, x, training=True):
"""
Args:
x: input of shape (batch, seq, embed_dim)
training: whether to apply dropout
Returns:
output of shape (batch, seq, embed_dim)
"""
# === Attention sub-block ===
# 1. Layer norm (pre-norm)
normed = self.ln1(x)
# 2. Attention (simplified: just a linear projection for demo)
# Real implementation would compute Q, K, V and attention weights
attn_out = normed @ self.attn_proj
# 3. Dropout
attn_out = self.dropout(attn_out, training=training)
# 4. Residual connection
x = x + attn_out
# === Feed-forward sub-block ===
# 1. Layer norm (pre-norm)
normed = self.ln2(x)
# 2. Feed-forward network
ff_out = self.ff(normed)
# 3. Dropout
ff_out = self.dropout(ff_out, training=training)
# 4. Residual connection
x = x + ff_out
return x
# Test the from-scratch transformer block
np.random.seed(42)
block_scratch = TransformerBlockScratch(
embed_dim=64,
num_heads=4,
ff_dim=256,
dropout_p=0.0 # Disable dropout for reproducibility
)
x = np.random.randn(2, 8, 64) # batch=2, seq=8, embed=64
out = block_scratch(x, training=False)
print("Transformer Block from Scratch:")
print(f" Input shape: {x.shape}")
print(f" Output shape: {out.shape}")
print(f" Input mean: {x.mean():.4f}")
print(f" Output mean: {out.mean():.4f}")
print(f"\nThe block transforms each token while preserving shape.")
print("Residual connections keep the output close to input initially.")
```
**The Complete Picture**:
```{python}
# Visualize the transformer block structure
print("""
Transformer Block (Pre-Norm Architecture):
==========================================
Input x
|
+------------------+
| |
v |
LayerNorm |
| |
v |
Multi-Head Attention |
| |
v |
Dropout |
| |
+--------(+)-------+ <- Residual connection
|
+------------------+
| |
v |
LayerNorm |
| |
v |
Feed-Forward |
| |
v |
Dropout |
| |
+--------(+)-------+ <- Residual connection
|
v
Output
""")
```
**Key Insight**: A Transformer block is just attention + MLP + residuals + norms. That is it. The magic comes from stacking many of these simple blocks and training on lots of data.
### PyTorch Transformer Modules
PyTorch provides optimized versions of everything we built from scratch.
**Comparison Table**:
| Component | From Scratch | PyTorch |
|-----------|--------------|---------|
| LayerNorm | Manual mean/var | `nn.LayerNorm` |
| Dropout | Random mask + scale | `nn.Dropout` |
| FFN | Two linear layers + GELU | Custom or `nn.Sequential` |
| Full Block | Manual assembly | `nn.TransformerDecoderLayer` |
```{python}
# PyTorch's TransformerDecoderLayer
# Note: This is for encoder-decoder models; for decoder-only like GPT,
# we typically build our own (as in transformer.py)
from torch.nn import TransformerDecoderLayer
# Create a decoder layer similar to our scratch implementation
pytorch_block = TransformerDecoderLayer(
d_model=64,
nhead=4,
dim_feedforward=256,
dropout=0.1,
activation='gelu',
batch_first=True,
norm_first=True # Pre-norm architecture
)
x_torch = torch.randn(2, 8, 64)
# For decoder-only, we use self-attention (memory = x)
pytorch_block.eval()
out_pytorch = pytorch_block(x_torch, x_torch)
print("PyTorch TransformerDecoderLayer:")
print(f" Input shape: {tuple(x_torch.shape)}")
print(f" Output shape: {tuple(out_pytorch.shape)}")
print(f" Parameters: {sum(p.numel() for p in pytorch_block.parameters()):,}")
```
**When to Use What**:
- **Learning**: Build from scratch to understand every step
- **Production**: Use PyTorch's optimized modules
- **Custom architectures**: Mix both - understand the components, then optimize
```{python}
# Our module's TransformerBlock (production quality)
from transformer import TransformerBlock
our_block = TransformerBlock(
embed_dim=64,
num_heads=4,
ff_dim=256,
dropout=0.1
)
x_torch = torch.randn(2, 8, 64)
our_block.eval()
out_ours = our_block(x_torch)
print("Our TransformerBlock (from transformer.py):")
print(f" Input shape: {tuple(x_torch.shape)}")
print(f" Output shape: {tuple(out_ours.shape)}")
print(f" Parameters: {sum(p.numel() for p in our_block.parameters()):,}")
print("\nThis is what we use for training - it includes proper")
print("causal attention, not the simplified version in scratch code.")
```
---
## More Component Details
### Layer Normalization (PyTorch Details)
Normalizes activations across the embedding dimension:
$$\text{LayerNorm}(x) = \gamma \times \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where $\mu$ and $\sigma^2$ are the mean and variance across the embedding dimension, and $\gamma$, $\beta$ are learnable parameters.
Why it helps:
- **Stabilizes activations**: Prevents values from exploding or vanishing
- **Faster training**: More stable gradients
- **Independent per token**: Each token normalized separately
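A short sketch with illustrative shapes: the learnable $\gamma$ and $\beta$ live in `ln.weight` and `ln.bias`, and every `(batch, position)` slot is normalized on its own.
```python
# Sketch: gamma/beta are per-feature parameters; each token is normalized independently.
import torch
import torch.nn as nn

ln = nn.LayerNorm(32)
print(ln.weight.shape, ln.bias.shape)     # gamma and beta, both torch.Size([32])

x = torch.randn(2, 5, 32) * 10 + 3        # (batch, seq, embed), shifted and rescaled
out = ln(x)
print(out.mean(dim=-1).abs().max())       # ~0 for every token, regardless of input scale
```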
### Feed-Forward Network (FFN)
```{ojs}
//| echo: false
// Step definitions for FFN walkthrough
ffnSteps = [
{
id: 0,
name: "Input",
description: "Input tensor arrives with embed_dim features per token",
activeNode: "input",
activeEdge: null
},
{
id: 1,
name: "First Linear (Expand)",
description: "Project from embed_dim to 4x embed_dim - the expansion gives capacity for complex transformations",
activeNode: "linear1",
activeEdge: "e_input_linear1"
},
{
id: 2,
name: "GELU Activation",
description: "Apply GELU nonlinearity - smoother than ReLU, allows gradients to flow",
activeNode: "gelu",
activeEdge: "e_linear1_gelu"
},
{
id: 3,
name: "Second Linear (Project)",
description: "Project back from 4x embed_dim to embed_dim - compress the expanded representation",
activeNode: "linear2",
activeEdge: "e_gelu_linear2"
},
{
id: 4,
name: "Output",
description: "Output tensor has the same shape as input - ready for residual connection",
activeNode: "output",
activeEdge: "e_linear2_output"
}
]
```
```{ojs}
//| echo: false
// Step slider for FFN
viewof ffnStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "FFN Step"
})
```
```{ojs}
//| echo: false
currentFFNStep = ffnSteps[ffnStep]
```
```{ojs}
//| echo: false
// Draw the FFN diagram
{
const width = 700;
const height = 180;
const marginX = 50;
const marginY = 40;
// Node dimensions - varying heights to show dimension expansion
const baseHeight = 50;
const expandedHeight = 100; // 4x expansion shown visually as 2x height
const nodeWidth = 100;
// Horizontal positions (evenly spaced)
const spacing = (width - 2 * marginX) / 4;
const inputX = marginX;
const linear1X = marginX + spacing;
const geluX = marginX + spacing * 2;
const linear2X = marginX + spacing * 3;
const outputX = marginX + spacing * 4;
const centerY = height / 2;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', 'SF Mono', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Defs for markers and filters
const defs = svg.append("defs");
// Glow filter
const glowFilter = defs.append("filter")
.attr("id", "ffn-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const glowMerge = glowFilter.append("feMerge");
glowMerge.append("feMergeNode").attr("in", "coloredBlur");
glowMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow markers
defs.append("marker")
.attr("id", "ffn-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "ffn-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Color for different node types
const colors = {
io: diagramTheme.accent, // Input/Output
linear: "#8b5cf6", // Linear layers (purple)
activation: "#10b981" // Activation (emerald)
};
// Helper to draw a node with variable height
function drawNode(x, y, w, h, label, sublabel, nodeId, colorType) {
const isActive = currentFFNStep.activeNode === nodeId;
const baseColor = colors[colorType] || diagramTheme.nodeFill;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${y})`);
g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 6)
.attr("fill", isActive ? baseColor : diagramTheme.nodeFill)
.attr("stroke", isActive ? baseColor : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("opacity", isActive ? 1 : 0.85)
.style("filter", isActive ? "url(#ffn-glow)" : "none");
const textColor = isActive ? "#ffffff" : diagramTheme.nodeText;
if (sublabel) {
g.append("text")
.attr("y", -8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
g.append("text")
.attr("y", 8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "9px")
.attr("opacity", isActive ? 0.95 : 0.7)
.text(sublabel);
} else {
g.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(label);
}
return g;
}
// Helper to draw an edge
function drawEdge(x1, y1, x2, y2, edgeId) {
const isActive = currentFFNStep.activeEdge === edgeId;
const strokeColor = isActive ? diagramTheme.highlight : diagramTheme.edgeStroke;
const markerId = isActive ? "ffn-arrow-active" : "ffn-arrow";
const path = svg.append("path")
.attr("d", `M${x1},${y1} L${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", strokeColor)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("marker-end", `url(#${markerId})`);
if (isActive) {
path.style("filter", `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})`);
}
return path;
}
// Draw edges first (so nodes appear on top)
// Edge offsets to connect to node edges
const edgeGap = 8;
// Input -> Linear1
drawEdge(inputX + nodeWidth/2 + edgeGap, centerY, linear1X - nodeWidth/2 - edgeGap, centerY, "e_input_linear1");
// Linear1 -> GELU
drawEdge(linear1X + nodeWidth/2 + edgeGap, centerY, geluX - nodeWidth/2 - edgeGap, centerY, "e_linear1_gelu");
// GELU -> Linear2
drawEdge(geluX + nodeWidth/2 + edgeGap, centerY, linear2X - nodeWidth/2 - edgeGap, centerY, "e_gelu_linear2");
// Linear2 -> Output
drawEdge(linear2X + nodeWidth/2 + edgeGap, centerY, outputX - nodeWidth/2 - edgeGap, centerY, "e_linear2_output");
// Draw nodes
// Input (embed_dim)
drawNode(inputX, centerY, nodeWidth, baseHeight, "Input", "(embed_dim)", "input", "io");
// Linear1 - expanded height to show 4x
drawNode(linear1X, centerY, nodeWidth, expandedHeight, "Linear", "→ 4x embed", "linear1", "linear");
// GELU - also expanded (operates on 4x)
drawNode(geluX, centerY, nodeWidth, expandedHeight, "GELU", "(4x embed)", "gelu", "activation");
// Linear2 - expanded input, compressed output (show as expanded)
drawNode(linear2X, centerY, nodeWidth, expandedHeight, "Linear", "→ embed_dim", "linear2", "linear");
// Output (embed_dim)
drawNode(outputX, centerY, nodeWidth, baseHeight, "Output", "(embed_dim)", "output", "io");
// Dimension labels below the flow
const labelY = centerY + 60;
const dimLabelStyle = {
"font-size": "9px",
"fill": diagramTheme.nodeText,
"opacity": 0.6
};
// Add "4x expansion" annotation in the middle section
svg.append("text")
.attr("x", (linear1X + linear2X) / 2)
.attr("y", centerY - expandedHeight/2 - 12)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text("4x capacity expansion");
// Small arrows to indicate expansion and compression
const arrowY = centerY - expandedHeight/2 - 6;
svg.append("path")
.attr("d", `M${linear1X - 20},${arrowY} L${linear1X + 20},${arrowY}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
svg.append("path")
.attr("d", `M${linear2X - 20},${arrowY} L${linear2X + 20},${arrowY}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 1)
.attr("opacity", 0.5);
return svg.node();
}
```
```{ojs}
//| echo: false
// Step description display
html`<div style="
background: ${diagramTheme.bgSecondary};
border: 1px solid ${diagramTheme.nodeStroke};
border-radius: 6px;
padding: 12px 16px;
margin-top: 8px;
font-family: 'JetBrains Mono', 'Fira Code', monospace;
font-size: 13px;
">
<strong style="color: ${diagramTheme.highlight};">${currentFFNStep.name}</strong>
<span style="color: ${diagramTheme.nodeText}; opacity: 0.9;"> — ${currentFFNStep.description}</span>
</div>`
```
The FFN is a mini neural network applied to each token independently:
- **4x expansion**: More capacity to learn complex transformations
- **GELU activation**: Smoother than ReLU, better gradients
- **Same for all tokens**: Unlike attention, no mixing between positions
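In code, the whole FFN is just two linear layers around a GELU. Here is a minimal `nn.Sequential` sketch (with an illustrative 128-dimensional embedding) mirroring the expand-activate-project pattern of the scratch implementation above.
```python
# Minimal FFN sketch: expand 4x, apply GELU, project back. The same weights are
# applied at every position, so the sequence dimension is untouched.
import torch
import torch.nn as nn

embed_dim = 128
ffn = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),   # expand
    nn.GELU(),                             # nonlinearity
    nn.Linear(4 * embed_dim, embed_dim),   # project back
)

x = torch.randn(2, 8, embed_dim)           # (batch, seq, embed)
print(ffn(x).shape)                        # torch.Size([2, 8, 128]), shape preserved
```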
### Pre-Norm vs Post-Norm
We use **Pre-Norm** (LayerNorm before attention/FFN) rather than Post-Norm:
```python
# Pre-Norm (GPT-2, LLaMA, modern LLMs)
x = x + Attention(LayerNorm(x))
# Post-Norm (original Transformer paper)
x = LayerNorm(x + Attention(x))
```
**Why Pre-Norm is preferred:**
1. **Cleaner gradient path**: The residual connection bypasses normalization, so gradients flow directly
2. **More stable training**: Especially important for deep networks (24+ layers)
3. **One caveat, a final LayerNorm is needed**: Since the last block's output isn't normalized, we add a final LayerNorm before the output projection
Post-Norm can achieve slightly better final performance with careful hyperparameter tuning, but Pre-Norm is more robust and easier to train.
## Code Walkthrough
Let's build and explore transformer blocks:
```{python}
import sys
import importlib.util
from pathlib import Path
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
print(f"PyTorch version: {torch.__version__}")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Device: {device}")
```
### GELU Activation
GPT uses GELU instead of ReLU. Let's see why:
```{python}
import torch.nn.functional as F
x = torch.linspace(-3, 3, 100)
gelu_out = F.gelu(x)
relu_out = torch.relu(x)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(x.numpy(), relu_out.numpy(), 'b-', label='ReLU', linewidth=2)
plt.plot(x.numpy(), gelu_out.numpy(), 'r-', label='GELU', linewidth=2)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.xlabel('x')
plt.ylabel('Activation')
plt.legend()
plt.title('GELU vs ReLU')
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
x_zoom = torch.linspace(-1, 1, 100)
plt.plot(x_zoom.numpy(), torch.relu(x_zoom).numpy(), 'b-', label='ReLU', linewidth=2)
plt.plot(x_zoom.numpy(), F.gelu(x_zoom).numpy(), 'r-', label='GELU', linewidth=2)
plt.xlabel('x')
plt.title('Zoomed: GELU is smooth at 0')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Key difference: GELU is smooth everywhere!")
print("ReLU has a sharp corner at x=0, which can cause gradient issues.")
```
**GELU formula**: $\text{GELU}(x) = x \cdot \Phi(x)$ where $\Phi$ is the standard normal CDF.
Approximation used in practice: $0.5x(1 + \tanh(\sqrt{2/\pi}(x + 0.044715x^3)))$
**Activation function choices in modern LLMs:**
| Model | Activation | Notes |
|-------|-----------|-------|
| GPT-2, BERT | GELU | Smooth, good gradients |
| LLaMA, Mistral | SwiGLU | Gated variant, better performance |
| GPT-3 | GELU | Same as GPT-2 |
**SwiGLU** (used in LLaMA) is a gated linear unit: $\text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes xW_2$. It requires an extra linear layer but often improves performance. Our implementation uses standard GELU to match GPT-2.
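For reference, here is a hedged sketch of a SwiGLU feed-forward layer; the class name, the `w1`/`w2`/`w3` naming, and the hidden width are illustrative assumptions, not part of our `transformer.py`.
```python
# Sketch of a SwiGLU feed-forward layer: Swish(x W1) * (x W3), then project back
# with W2. Swish is SiLU in PyTorch. All names and sizes here are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False)   # gate projection
        self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False)   # value projection
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False)   # down projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(2, 8, 128)
print(SwiGLUFeedForward(128, 256)(x).shape)   # torch.Size([2, 8, 128])
```
The third projection adds parameters per block, which is why LLaMA shrinks the hidden width (to roughly two-thirds of the usual 4x) to keep the parameter count comparable.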
### Layer Normalization Demo
```{python}
# Manual LayerNorm demonstration
x = torch.tensor([[2.0, 4.0, 6.0, 8.0]])
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
normalized = (x - mean) / torch.sqrt(var + 1e-5)
print("Manual LayerNorm:")
print(f" Input: {x.numpy().tolist()}")
print(f" Mean: {mean.item():.2f}")
print(f" Variance: {var.item():.2f}")
print(f" Normalized: {normalized.numpy().round(2).tolist()}")
print(f" New mean: {normalized.mean().item():.4f}")
print(f" New std: {normalized.std(unbiased=False).item():.4f}")
```
```{python}
# PyTorch LayerNorm (with learnable gamma and beta)
ln = nn.LayerNorm(4)
pytorch_normalized = ln(x)
print(f"PyTorch LayerNorm output: {pytorch_normalized.detach().numpy().round(2).tolist()}")
print("(gamma and beta are learnable parameters)")
```
### Residual Connections Demo
```{python}
# Demonstrate gradient flow with and without residuals
def simple_layer(x):
"""A simple transformation that shrinks values."""
return x * 0.5 + 0.1
# Stack 10 layers WITHOUT residual
x = torch.tensor([1.0])
outputs_no_residual = [x.item()]
for _ in range(10):
x = simple_layer(x)
outputs_no_residual.append(x.item())
# Stack 10 layers WITH residual
x = torch.tensor([1.0])
outputs_with_residual = [x.item()]
for _ in range(10):
    x = x + simple_layer(x) * 0.1  # residual update: x + 0.1 * f(x)
outputs_with_residual.append(x.item())
plt.figure(figsize=(10, 4))
plt.plot(outputs_no_residual, 'b-o', label='Without residual')
plt.plot(outputs_with_residual, 'r-o', label='With residual')
plt.xlabel('Layer')
plt.ylabel('Value')
plt.title('Effect of Residual Connections')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print(f"Without residual: value shrinks to {outputs_no_residual[-1]:.4f}")
print(f"With residual: value stays near {outputs_with_residual[-1]:.4f}")
```
### Weight Initialization
Proper initialization is critical for training deep networks. Our implementation uses:
- **Embeddings**: Normal distribution with std=0.02
- **Linear layers in FFN**: Normal distribution with std=0.02, biases initialized to 0
- **Attention projections**: Xavier uniform initialization
```{python}
# Demonstrate the importance of initialization
import torch.nn as nn
# Bad initialization - too large
bad_linear = nn.Linear(768, 768)
nn.init.normal_(bad_linear.weight, std=1.0) # Too large!
# Good initialization - small weights
good_linear = nn.Linear(768, 768)
nn.init.normal_(good_linear.weight, std=0.02) # GPT-2 style
x = torch.randn(1, 10, 768)
bad_out = bad_linear(x)
good_out = good_linear(x)
print("Effect of initialization on output magnitude:")
print(f" Bad init (std=1.0): output std = {bad_out.std().item():.2f}")
print(f" Good init (std=0.02): output std = {good_out.std().item():.2f}")
print("\nLarge outputs can cause exploding gradients and NaN losses!")
```
**GPT-2's initialization trick**: scale the initialization std of each residual branch's output projection by $1/\sqrt{2N}$, where $N$ is the number of layers (each layer has two residual branches: attention and FFN). This keeps the variance of the residual stream stable as depth increases.
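A minimal sketch of that scaling, using a hypothetical depth of 12 layers (our actual initialization lives in the `transformer` module):

```{python}
import math
import torch.nn as nn

num_layers = 12  # hypothetical depth, for illustration only
out_proj = nn.Linear(768, 768)

# Shrink the init std of a residual output projection as depth grows (GPT-2 style).
scaled_std = 0.02 / math.sqrt(2 * num_layers)
nn.init.normal_(out_proj.weight, std=scaled_std)
nn.init.zeros_(out_proj.bias)

print(f"Scaled init std: {scaled_std:.4f}")
print(f"Actual weight std: {out_proj.weight.std().item():.4f}")
```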
### Building a Transformer Block
```{python}
from transformer import (
FeedForward,
TransformerBlock,
GPTModel,
create_gpt_tiny,
create_gpt_small,
)
# Create a transformer block
embed_dim = 64
num_heads = 4
ff_dim = 256
block = TransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
ff_dim=ff_dim,
dropout=0.0
)
print(f"Transformer Block:")
print(f" Embed dim: {embed_dim}")
print(f" Num heads: {num_heads}")
print(f" Head dim: {embed_dim // num_heads}")
print(f" FF dim: {ff_dim}")
print(f"\nTotal parameters: {sum(p.numel() for p in block.parameters()):,}")
```
```{python}
# Forward pass
x = torch.randn(1, 8, embed_dim) # batch=1, seq=8
output, attention = block(x, return_attention=True)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")
```
```{python}
# Visualize attention patterns from the block
fig, axes = plt.subplots(1, 4, figsize=(14, 3))
for head in range(4):
ax = axes[head]
w = attention[0, head].detach().numpy()
ax.imshow(w, cmap='Blues', vmin=0, vmax=w.max())
ax.set_title(f'Head {head}')
ax.set_xlabel('Key')
ax.set_ylabel('Query')
plt.suptitle('Attention Patterns in Transformer Block (Causal Masked)', fontsize=12)
plt.tight_layout()
plt.show()
```
### Complete GPT Model
```{python}
# Create a tiny GPT model
model = create_gpt_tiny(vocab_size=1000)
print("GPT Tiny Model:")
print(f" Vocab size: {model.vocab_size}")
print(f" Embed dim: {model.embed_dim}")
print(f" Num layers: {len(model.blocks)}")
print(f" Max seq len: {model.max_seq_len}")
print(f"\nTotal parameters: {model.num_params:,}")
```
```{python}
# Parameter breakdown
counts = model.count_parameters()
print("Parameter breakdown:")
for name, count in counts.items():
if count > 0:
pct = 100 * count / counts['total']
print(f" {name}: {count:,} ({pct:.1f}%)")
# Visualize
labels = [k for k, v in counts.items() if v > 0 and k != 'total']
sizes = [counts[k] for k in labels]
plt.figure(figsize=(8, 6))
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
plt.title(f'Parameter Distribution ({model.num_params:,} total)')
plt.show()
```
```{python}
# Forward pass
token_ids = torch.randint(0, 1000, (2, 32)) # batch=2, seq=32
logits = model(token_ids)
print(f"Input token IDs: {token_ids.shape}")
print(f"Output logits: {logits.shape}")
print(f" (batch=2, seq=32, vocab=1000)")
# Get predictions
probs = torch.softmax(logits[0, -1], dim=-1)
top_5 = torch.topk(probs, 5)
print("\nTop 5 predicted next tokens (untrained, so random):")
for i, (idx, prob) in enumerate(zip(top_5.indices, top_5.values)):
print(f" {i+1}. Token {idx.item()}: {prob.item()*100:.2f}%")
```
### Hidden States Through Layers
```{python}
# Get hidden states from all layers
logits, hidden_states = model(token_ids, return_hidden_states=True)
print(f"Number of hidden states: {len(hidden_states)}")
print(f" (1 after embedding + {len(model.blocks)} after each block)")
# Show how representations change through layers
norms = [h.norm(dim=-1).mean().item() for h in hidden_states]
plt.figure(figsize=(10, 4))
plt.plot(range(len(norms)), norms, 'b-o')
plt.xlabel('Layer')
plt.ylabel('Average Embedding Norm')
plt.title('Embedding Norms Through the Network')
plt.xticks(range(len(norms)), ['Embed'] + [f'Block {i}' for i in range(len(model.blocks))])
plt.grid(True, alpha=0.3)
plt.show()
```
### Weight Tying
GPT shares weights between token embedding and output projection:
```{python}
# Check weight tying
print("Weight Tying:")
print(f" Token embedding weight id: {id(model.token_embedding.weight)}")
print(f" LM head weight id: {id(model.lm_head.weight)}")
print(f" Are they the same object? {model.token_embedding.weight is model.lm_head.weight}")
# This saves parameters!
vocab_size = 1000
embed_dim = 128
saved_params = vocab_size * embed_dim
print(f"\nParameters saved by weight tying: {saved_params:,}")
```
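Tying is typically a single assignment at construction time. Here is a minimal sketch of the common idiom (our `GPTModel` already ties its weights internally):

```{python}
import torch.nn as nn

vocab_size, embed_dim = 1000, 128
token_embedding = nn.Embedding(vocab_size, embed_dim)   # weight: [vocab, embed]
lm_head = nn.Linear(embed_dim, vocab_size, bias=False)  # weight: [vocab, embed]

# Point the LM head at the embedding matrix: both modules now share one tensor.
lm_head.weight = token_embedding.weight

print(f"Tied? {lm_head.weight is token_embedding.weight}")
print(f"Shared parameters: {vocab_size * embed_dim:,}")
```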
## Model Sizes Comparison
| Model | Layers | Heads | Embed Dim | Params |
|-------|--------|-------|-----------|--------|
| Tiny (ours) | 4 | 4 | 128 | ~1M |
| Small (ours) | 6 | 6 | 384 | ~10M |
| GPT-2 Small | 12 | 12 | 768 | 117M |
| GPT-2 Medium | 24 | 16 | 1024 | 345M |
| GPT-2 Large | 36 | 20 | 1280 | 774M |
| GPT-2 XL | 48 | 25 | 1600 | 1.5B |
### Parameter Counting Formulas
Understanding where parameters come from helps with model sizing:
**Per Transformer Block:**
- Attention Q, K, V projections: $3 \times d \times d$ (where $d$ = embed_dim)
- Attention output projection: $d \times d$
- FFN first linear: $d \times 4d$
- FFN second linear: $4d \times d$
- LayerNorm (x2): $2 \times 2d$ (gamma and beta for each)
Total per block: $\approx 12d^2$ parameters
**Full Model:**
- Token embedding: $V \times d$ (V = vocab size)
- Position embedding: $L \times d$ (L = max sequence length)
- N transformer blocks: $N \times 12d^2$
- Final LayerNorm: $2d$
- LM head: 0 (weight-tied with token embedding)
**Approximate formula**: $\text{Params} \approx V \times d + 12Nd^2$
For GPT-2 Small (V=50257, d=768, N=12): $50257 \times 768 + 12 \times 12 \times 768^2 \approx 124M$, matching the released checkpoint's ~124M parameters (the original paper reported this as 117M).
Scaling laws: adding layers, heads, and dimensions generally improves performance, but with diminishing returns and rising compute cost.
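As a sanity check, here is a small sketch that evaluates the approximate formula for the GPT-2 sizes in the table above (it ignores biases and position embeddings, so the numbers are rough):

```{python}
def approx_params(vocab_size: int, embed_dim: int, num_layers: int) -> int:
    """Approximate count: V*d for embeddings plus 12*N*d^2 for the blocks."""
    return vocab_size * embed_dim + 12 * num_layers * embed_dim**2

configs = {
    "GPT-2 Small":  (50257, 768, 12),
    "GPT-2 Medium": (50257, 1024, 24),
    "GPT-2 Large":  (50257, 1280, 36),
}
for name, (V, d, N) in configs.items():
    print(f"{name:<14} ~{approx_params(V, d, N) / 1e6:.0f}M parameters")
```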
## Architectural Variations
Modern LLMs have evolved beyond the original GPT-2 architecture. Here are key variations:
### Normalization
| Variant | Used By | Description |
|---------|---------|-------------|
| LayerNorm | GPT-2, GPT-3 | Normalize across embedding dimension |
| RMSNorm | LLaMA, Mistral | Simpler: just divide by RMS, no mean subtraction |
| Pre-Norm | Most modern LLMs | Normalize before sublayer (more stable) |
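RMSNorm is simple enough to sketch in a few lines. This is illustrative only (our model uses standard LayerNorm): it rescales by the root mean square and skips both the mean subtraction and the bias term.

```{python}
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Sketch of RMSNorm: divide by the root mean square; no mean subtraction, no bias."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

x = torch.tensor([[2.0, 4.0, 6.0, 8.0]])
print(f"LayerNorm: {nn.LayerNorm(4)(x).detach().numpy().round(2).tolist()}")
print(f"RMSNorm:   {RMSNorm(4)(x).detach().numpy().round(2).tolist()}")
```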
### Position Embeddings
| Variant | Used By | Description |
|---------|---------|-------------|
| Learned absolute | GPT-2 | Separate embedding for each position |
| Rotary (RoPE) | LLaMA, Mistral | Encode position in attention via rotation |
| ALiBi | BLOOM | Add position bias to attention scores |
Our implementation uses learned absolute position embeddings (GPT-2 style), which are simple but limit the model to the maximum trained sequence length.
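A minimal sketch of how learned absolute positions are added, and why inputs longer than the trained maximum fail (the sizes here are hypothetical, much smaller than our model's):

```{python}
import torch
import torch.nn as nn

max_len, dim = 16, 32                       # hypothetical sizes for illustration
pos_embedding = nn.Embedding(max_len, dim)  # one learned vector per position

tok_emb = torch.randn(1, 8, dim)               # pretend token embeddings, seq_len=8
x = tok_emb + pos_embedding(torch.arange(8))   # broadcast position vectors over the batch
print(f"With positions added: {x.shape}")

# Positions >= max_len have no learned embedding, so longer sequences fail outright:
try:
    pos_embedding(torch.arange(32))
except IndexError as e:
    print(f"Sequence longer than max_len raises {type(e).__name__}")
```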
### Feed-Forward Networks
| Variant | Used By | Expansion | Activation |
|---------|---------|-----------|------------|
| Standard | GPT-2 | 4x | GELU |
| SwiGLU | LLaMA | 8/3x (after gating) | SiLU (Swish) |
## Common Pitfalls
When implementing or training transformers, watch out for:
1. **Forgetting the causal mask**: Without it, the model can "cheat" by looking at future tokens during training, leading to poor generation at inference time (see the sketch after this list).
2. **Wrong normalization axis**: LayerNorm should normalize across the embedding dimension (last axis), not the sequence or batch dimensions.
3. **Residual connection placement**: Make sure to add the residual *after* dropout but *before* the next LayerNorm in a Pre-Norm architecture.
4. **Large learning rates**: Transformers are sensitive to learning rate. Start with 1e-4 to 3e-4 for Adam, use warmup.
5. **Numerical instability**: Use float32 for training initially. Half precision (fp16/bf16) requires careful scaling.
6. **Forgetting final LayerNorm**: In Pre-Norm, the output of the last block isn't normalized. The final LayerNorm before the LM head is essential.
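For pitfall 1, here is a standalone sketch of what the causal mask looks like and how it is applied to attention scores (not our attention module, just the core idea):

```{python}
import torch

seq_len = 5
# Lower-triangular mask: position i may attend only to positions 0..i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask.int())

scores = torch.randn(seq_len, seq_len)                    # pretend attention scores
masked = scores.masked_fill(~causal_mask, float("-inf"))  # block future positions
weights = torch.softmax(masked, dim=-1)

print(f"Row sums (should all be 1): {weights.sum(dim=-1).tolist()}")
print(f"Upper triangle zeroed? {bool((weights.triu(1) == 0).all())}")
```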
## Interactive Exploration
Experiment with transformer architecture choices to understand where parameters come from and how they scale.
```{ojs}
//| echo: false
// Parameter counting functions
function countParameters(vocabSize, embedDim, numLayers, numHeads, ffMult, maxSeqLen) {
// Token embeddings (weight-tied with LM head)
const tokenEmbed = vocabSize * embedDim;
// Position embeddings
const posEmbed = maxSeqLen * embedDim;
// Per transformer block
const attnParams = 4 * embedDim * embedDim; // Q, K, V, O projections
const ffnParams = 2 * embedDim * (ffMult * embedDim); // up + down projection
const normParams = 4 * embedDim; // 2 LayerNorms with gamma + beta each
const perBlock = attnParams + ffnParams + normParams;
// Total blocks
const totalBlocks = numLayers * perBlock;
// Final LayerNorm
const finalNorm = 2 * embedDim;
return {
tokenEmbed,
posEmbed,
attention: numLayers * attnParams,
ffn: numLayers * ffnParams,
layerNorm: numLayers * normParams + finalNorm,
total: tokenEmbed + posEmbed + totalBlocks + finalNorm
};
}
// Format large numbers
function formatParams(n) {
if (n >= 1e9) return (n / 1e9).toFixed(2) + "B";
if (n >= 1e6) return (n / 1e6).toFixed(2) + "M";
if (n >= 1e3) return (n / 1e3).toFixed(2) + "K";
return n.toString();
}
// Format memory
function formatMemory(bytes) {
if (bytes >= 1e9) return (bytes / 1e9).toFixed(2) + " GB";
if (bytes >= 1e6) return (bytes / 1e6).toFixed(1) + " MB";
return (bytes / 1e3).toFixed(1) + " KB";
}
```
```{ojs}
//| echo: false
// Architecture presets
presets = ({
"Tiny (Ours)": { vocab: 10000, embed: 128, layers: 4, heads: 4, ff: 4, seq: 512 },
"Small (Ours)": { vocab: 10000, embed: 384, layers: 6, heads: 6, ff: 4, seq: 512 },
"GPT-2 Small": { vocab: 50257, embed: 768, layers: 12, heads: 12, ff: 4, seq: 1024 },
"GPT-2 Medium": { vocab: 50257, embed: 1024, layers: 24, heads: 16, ff: 4, seq: 1024 },
"GPT-2 Large": { vocab: 50257, embed: 1280, layers: 36, heads: 20, ff: 4, seq: 1024 }
})
viewof preset = Inputs.select(Object.keys(presets), {
label: "Load Preset",
value: "Small (Ours)"
})
```
```{ojs}
//| echo: false
// Input controls with preset values
selectedPreset = presets[preset]
viewof vocabSize = Inputs.select([1000, 5000, 10000, 32000, 50257, 100000], {
value: selectedPreset.vocab,
label: "Vocabulary Size"
})
viewof embedDim = Inputs.select([64, 128, 256, 384, 512, 768, 1024, 1280, 2048, 4096], {
value: selectedPreset.embed,
label: "Embedding Dimension"
})
viewof numLayers = Inputs.range([1, 48], {
value: selectedPreset.layers,
step: 1,
label: "Number of Layers"
})
viewof numHeads = Inputs.range([1, 32], {
value: selectedPreset.heads,
step: 1,
label: "Number of Heads"
})
viewof ffMult = Inputs.range([1, 8], {
value: selectedPreset.ff,
step: 1,
label: "FFN Multiplier"
})
viewof maxSeqLen = Inputs.select([256, 512, 1024, 2048, 4096, 8192], {
value: selectedPreset.seq,
label: "Max Sequence Length"
})
```
```{ojs}
//| echo: false
// Dark mode detection
isDark = {
const check = () => document.body.classList.contains('quarto-dark');
return check();
}
// Theme colors for light/dark mode
theme = isDark ? {
tokenEmbed: '#6b8cae',
posEmbed: '#9a8bc0',
attention: '#5a9a7a',
ffn: '#b89a5a',
layerNorm: '#7a7c7d',
warning: '#fca5a5'
} : {
tokenEmbed: '#3b82f6',
posEmbed: '#8b5cf6',
attention: '#10b981',
ffn: '#f59e0b',
layerNorm: '#6b7280',
warning: '#ef4444'
}
```
```{ojs}
//| echo: false
// Compute parameters
params = countParameters(vocabSize, embedDim, numLayers, numHeads, ffMult, maxSeqLen)
// Data for pie chart
pieData = [
{ category: "Token Embeddings", value: params.tokenEmbed, color: theme.tokenEmbed },
{ category: "Position Embeddings", value: params.posEmbed, color: theme.posEmbed },
{ category: "Attention", value: params.attention, color: theme.attention },
{ category: "Feed-Forward", value: params.ffn, color: theme.ffn },
{ category: "LayerNorm", value: params.layerNorm, color: theme.layerNorm }
].filter(d => d.value > 0)
// Memory calculations
memoryFp32 = params.total * 4
memoryFp16 = params.total * 2
memoryInt8 = params.total * 1
// Head dimension check
headDim = embedDim / numHeads
headDimValid = embedDim % numHeads === 0
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
Plot.plot({
title: `Parameter Distribution: ${formatParams(params.total)} total`,
width: 600,
height: 300,
marginLeft: 120,
x: {
label: "Parameters →",
tickFormat: d => formatParams(d)
},
y: {
label: null
},
color: {
domain: pieData.map(d => d.category),
range: pieData.map(d => d.color)
},
marks: [
Plot.barX(pieData, {
y: "category",
x: "value",
fill: "category",
tip: true,
title: d => `${d.category}: ${formatParams(d.value)} (${(d.value / params.total * 100).toFixed(1)}%)`
}),
Plot.ruleX([0])
]
})
```
```{ojs}
//| echo: false
// Summary statistics
md`
| Component | Parameters | Percentage |
|-----------|------------|------------|
| Token Embeddings | ${formatParams(params.tokenEmbed)} | ${(params.tokenEmbed / params.total * 100).toFixed(1)}% |
| Position Embeddings | ${formatParams(params.posEmbed)} | ${(params.posEmbed / params.total * 100).toFixed(1)}% |
| Attention (${numLayers} layers) | ${formatParams(params.attention)} | ${(params.attention / params.total * 100).toFixed(1)}% |
| Feed-Forward (${numLayers} layers) | ${formatParams(params.ffn)} | ${(params.ffn / params.total * 100).toFixed(1)}% |
| LayerNorm | ${formatParams(params.layerNorm)} | ${(params.layerNorm / params.total * 100).toFixed(1)}% |
| **Total** | **${formatParams(params.total)}** | **100%** |
`
```
```{ojs}
//| echo: false
md`**Memory Requirements:** ${formatMemory(memoryFp32)} (fp32) | ${formatMemory(memoryFp16)} (fp16) | ${formatMemory(memoryInt8)} (int8)`
```
```{ojs}
//| echo: false
// Validation warning
headDimValid ? md`` : md`<span style="color: ${theme.warning}">⚠️ Warning: embed_dim (${embedDim}) is not divisible by num_heads (${numHeads}). Head dimension would be ${headDim.toFixed(2)}.</span>`
```
::: {.callout-tip}
## Try This
1. **FFN dominates**: Set embed_dim=768, layers=12. Notice the Feed-Forward bars are ~2x the Attention bars (because FFN has 8d² params vs Attention's 4d²).
2. **Embedding cost at small scale**: With vocab=50257 and embed_dim=768, token embeddings are ~38M params - a large fraction for small models.
3. **Scaling law**: Double embed_dim from 512 to 1024. Total params roughly quadruple (because most params scale with d²).
4. **Load GPT-2 presets** and see how the 117M, 345M, 774M models break down.
5. **Head dimension check**: Try numHeads that doesn't divide embedDim evenly - you'll see a warning.
:::
## Exercises
### Exercise 1: Build a Custom Block
```{python}
# Create a transformer block with different configurations
custom_block = TransformerBlock(
embed_dim=128,
num_heads=8,
ff_dim=512, # 4x expansion
dropout=0.1
)
# Test it
x = torch.randn(4, 16, 128) # batch=4, seq=16
output = custom_block(x)
print(f"Custom block: {x.shape} -> {output.shape}")
print(f"Parameters: {sum(p.numel() for p in custom_block.parameters()):,}")
```
### Exercise 2: Compare Model Scales
```{python}
# Compare tiny vs small model
tiny = create_gpt_tiny(vocab_size=10000)
small = create_gpt_small(vocab_size=10000)
print(f"{'Model':<10} {'Embed':<8} {'Layers':<8} {'Heads':<8} {'Params':<15}")
print("-" * 50)
print(f"{'Tiny':<10} {tiny.embed_dim:<8} {len(tiny.blocks):<8} {tiny.blocks[0].attention.mha.num_heads:<8} {tiny.num_params:,}")
print(f"{'Small':<10} {small.embed_dim:<8} {len(small.blocks):<8} {small.blocks[0].attention.mha.num_heads:<8} {small.num_params:,}")
```
### Exercise 3: Information Flow
```{python}
# See how a single token's representation changes through layers
model = create_gpt_tiny(vocab_size=100)
token_ids = torch.randint(0, 100, (1, 8))
_, hidden = model(token_ids, return_hidden_states=True)
# Track first token through layers
first_token_norms = [h[0, 0].norm().item() for h in hidden]
plt.figure(figsize=(8, 4))
plt.bar(range(len(first_token_norms)), first_token_norms)
plt.xlabel('Layer')
plt.ylabel('Embedding Norm (first token)')
plt.title('First Token Representation Through Layers')
plt.xticks(range(len(first_token_norms)), ['Embed'] + [f'Block {i}' for i in range(len(model.blocks))])
plt.show()
```
## Summary
Key takeaways:
1. **Transformer architecture**: Input embeddings -> N transformer blocks -> Final LayerNorm -> Output projection
2. **Each block has two sublayers**:
- Multi-head attention (tokens communicate)
- Feed-forward network (tokens processed independently)
3. **Pre-Norm architecture**: LayerNorm before each sublayer, with a "clean" residual path for stable gradients
4. **Layer normalization**: Normalizes across the embedding dimension, keeping activations in a stable range
5. **Residual connections**: `x + f(x)` enables gradient flow through very deep networks (100+ layers)
6. **Feed-forward networks**: 4x expansion with GELU activation provides computational capacity
7. **Weight tying**: Sharing token embedding and output projection reduces parameters and improves performance
8. **Initialization matters**: Small initial weights (std=0.02) prevent exploding activations
9. **Parameter scaling**: Total params $\approx V \times d + 12Nd^2$ (dominated by FFN for large models)
10. **Architectural variations**: Modern LLMs (LLaMA, Mistral) use RMSNorm, RoPE, and SwiGLU for better efficiency
## What's Next
In [Module 07: Training](../m07_training/lesson.qmd), we'll train our transformer on actual data using cross-entropy loss, learning rate scheduling, and gradient accumulation.