---
title: "Module 07: Training"
format:
  html:
    code-fold: false
    toc: true
  ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
## Introduction
Training a language model means teaching it to predict the next token. We do this through an iterative process:
1. **Computing loss**: How wrong are our predictions?
2. **Computing gradients**: In which direction should we adjust each weight?
3. **Updating weights**: Take a small step in that direction
4. **Repeat**: Until the model gets good at prediction
In this module, we'll explore cross-entropy loss, the AdamW optimizer, learning rate scheduling, gradient accumulation, and checkpointing.
### What You'll Learn
By the end of this module, you will be able to:
- Understand cross-entropy loss and perplexity for language models
- Implement learning rate schedules (warmup + cosine decay)
- Use gradient accumulation for larger effective batch sizes
- Apply gradient clipping for training stability
- Save and load model checkpoints
**Note**: This lesson demonstrates concepts interactively. The `training.py` file provides production-ready implementations of the same algorithms.
## The Training Objective
Language models are trained with **next-token prediction**:
```
Input: [The, cat, sat, on, the]
Target: [cat, sat, on, the, mat]
For each position, predict the next token.
```
The loss function measures how well the model predicts: **cross-entropy** between the predicted probability distribution and the actual next token at each position, averaged over the sequence.
$$\text{loss} = -\frac{1}{N}\sum_{t=1}^{N} \log P(\text{correct\_token}_t)$$
Lower loss means the model assigns higher probability to the correct tokens, which means better predictions.
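To make the shifted input/target pairing concrete, here is a minimal sketch in which random logits stand in for the model's output (the token ids and shapes are made up for illustration):
```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([[5, 2, 8, 1, 9, 3]])          # (B=1, T=6) token ids
input_ids, targets = tokens[:, :-1], tokens[:, 1:]   # shift by one position

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)               # stand-in for model(input_ids)

# F.cross_entropy expects (N, C) scores and (N,) class indices,
# so flatten the batch and time dimensions together.
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.reshape(-1))
print(f"loss: {loss.item():.4f}")
```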
## The Training Loop
The training loop is the core of how neural networks learn:
```{ojs}
//| echo: false
// Training loop steps data
trainingSteps = [
{
id: 0,
name: "Zero Gradients",
code: "optimizer.zero_grad()",
description: "Clear accumulated gradients from the previous iteration to start fresh.",
detail: "Gradients accumulate by default in PyTorch. Without zeroing, they add up across iterations."
},
{
id: 1,
name: "Forward Pass",
code: "logits = model(input_ids)",
description: "Pass input tokens through the model to get predicted logits.",
detail: "The model computes attention, embeddings, and projections to produce next-token predictions."
},
{
id: 2,
name: "Compute Loss",
code: "loss = F.cross_entropy(logits, targets)",
description: "Measure how wrong the predictions are compared to actual next tokens.",
detail: "Cross-entropy loss: lower means higher probability assigned to correct tokens."
},
{
id: 3,
name: "Backward Pass",
code: "loss.backward()",
description: "Compute gradients for all parameters via backpropagation.",
detail: "Automatic differentiation traces computation graph backward, computing dLoss/dParam."
},
{
id: 4,
name: "Gradient Clipping",
code: "clip_grad_norm_(params, 1.0)",
description: "Scale gradients if their norm exceeds threshold to prevent instability.",
detail: "Prevents exploding gradients that can cause NaN loss or divergent training."
},
{
id: 5,
name: "Optimizer Step",
code: "optimizer.step()",
description: "Update model weights using the computed (and clipped) gradients.",
detail: "AdamW applies momentum, adaptive learning rates, and weight decay to the update."
},
{
id: 6,
name: "Update LR",
code: "scheduler.step()",
description: "Adjust learning rate according to schedule (warmup + cosine decay).",
detail: "High LR early for exploration, lower LR later for fine-tuning convergence."
}
]
```
```{ojs}
//| echo: false
// Step slider control
viewof trainingStep = Inputs.range([0, 6], {
value: 0,
step: 1,
label: "Training Step"
})
```
```{ojs}
//| echo: false
// Current step info
currentTrainingStep = trainingSteps[trainingStep]
```
```{ojs}
//| echo: false
// Draw the cyclic training loop diagram
{
const width = 650;
const height = 480;
const centerX = width / 2;
const centerY = height / 2 - 20;
const radius = 160;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Defs for arrows
const defs = svg.append("defs");
// Standard arrow
defs.append("marker")
.attr("id", "training-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
// Highlighted arrow
defs.append("marker")
.attr("id", "training-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Calculate node positions in a circle
const nodeCount = 7;
const startAngle = -Math.PI / 2; // Start at top
const nodePositions = trainingSteps.map((step, i) => {
const angle = startAngle + (i * 2 * Math.PI / nodeCount);
return {
...step,
x: centerX + radius * Math.cos(angle),
y: centerY + radius * Math.sin(angle),
angle: angle
};
});
// Draw connecting arrows between nodes
const edgesGroup = svg.append("g").attr("class", "edges");
for (let i = 0; i < nodeCount; i++) {
const from = nodePositions[i];
const to = nodePositions[(i + 1) % nodeCount];
// Calculate edge start/end to not overlap nodes
const nodeRadius = 42;
const dx = to.x - from.x;
const dy = to.y - from.y;
const dist = Math.sqrt(dx * dx + dy * dy);
const startX = from.x + (dx / dist) * (nodeRadius + 2);
const startY = from.y + (dy / dist) * (nodeRadius + 2);
const endX = to.x - (dx / dist) * (nodeRadius + 8);
const endY = to.y - (dy / dist) * (nodeRadius + 8);
// This edge is highlighted when we're on the "from" step
const isActive = trainingStep === i;
edgesGroup.append("path")
.attr("d", `M${startX},${startY} L${endX},${endY}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("marker-end", isActive ? "url(#training-arrow-active)" : "url(#training-arrow)")
.attr("opacity", isActive ? 1 : 0.6)
.style("filter", isActive ? `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})` : "none");
}
// Draw nodes
const nodesGroup = svg.append("g").attr("class", "nodes");
nodePositions.forEach((node, i) => {
const isActive = trainingStep === i;
const nodeSize = 42;
const g = nodesGroup.append("g")
.attr("transform", `translate(${node.x}, ${node.y})`);
// Circle node
g.append("circle")
.attr("r", nodeSize)
.attr("fill", isActive ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.style("filter", isActive ? `drop-shadow(0 0 8px ${diagramTheme.highlightGlow})` : "none");
// Step number
g.append("text")
.attr("y", -10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.attr("opacity", 0.7)
.text(`Step ${i + 1}`);
// Node label (split long names)
const words = node.name.split(" ");
if (words.length > 1) {
g.append("text")
.attr("y", 5)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(words[0]);
g.append("text")
.attr("y", 18)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(words.slice(1).join(" "));
} else {
g.append("text")
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(node.name);
}
});
// Center label
svg.append("text")
.attr("x", centerX)
.attr("y", centerY - 5)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.attr("opacity", 0.8)
.text("Training");
svg.append("text")
.attr("x", centerX)
.attr("y", centerY + 12)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.attr("opacity", 0.8)
.text("Loop");
// Info panel at bottom
const infoY = height - 100;
const infoGroup = svg.append("g")
.attr("transform", `translate(${width / 2}, ${infoY})`);
// Info box background
infoGroup.append("rect")
.attr("x", -280)
.attr("y", -10)
.attr("width", 560)
.attr("height", 85)
.attr("rx", 6)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1);
// Step name
infoGroup.append("text")
.attr("y", 8)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text(`${currentTrainingStep.id + 1}. ${currentTrainingStep.name}`);
// Code
infoGroup.append("text")
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.accent)
.attr("font-size", "12px")
.attr("font-family", "monospace")
.text(currentTrainingStep.code);
// Description (wrap if needed)
const desc = currentTrainingStep.description;
if (desc.length > 70) {
const mid = desc.lastIndexOf(" ", 70);
infoGroup.append("text")
.attr("y", 50)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(desc.substring(0, mid));
infoGroup.append("text")
.attr("y", 64)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(desc.substring(mid + 1));
} else {
infoGroup.append("text")
.attr("y", 55)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(desc);
}
return svg.node();
}
```
```{ojs}
//| echo: false
// Additional detail below the diagram
md`**Why this step matters:** ${currentTrainingStep.detail}`
```
Note: `zero_grad()` can be called either at the start or end of each iteration. Calling it at the start (shown above) is common because it ensures gradients are fresh before the backward pass.
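Putting the seven steps together, one iteration looks roughly like this. This is a minimal sketch: the embedding-plus-linear model and random token ids are toy stand-ins, not the transformer or data pipeline from `training.py`.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, seq_len, batch_size = 100, 16, 4
model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

for step in range(3):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))

    optimizer.zero_grad()                                     # 1. zero gradients
    logits = model(input_ids)                                 # 2. forward pass
    loss = F.cross_entropy(logits.view(-1, vocab_size),       # 3. compute loss
                           targets.view(-1))
    loss.backward()                                           # 4. backward pass
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # 5. clip gradients
    optimizer.step()                                          # 6. update weights
    # 7. scheduler.step() would adjust the learning rate here
    print(f"step {step}: loss = {loss.item():.4f}")
```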
## Setup
```{python}
import sys
import math
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# For reproducibility
torch.manual_seed(42)
# Display device info
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
```
## Cross-Entropy Loss
The loss function measures how wrong our predictions are. Cross-entropy loss penalizes wrong predictions more heavily when the model is confident but incorrect.
**Why cross-entropy?**
1. **Probabilistic interpretation**: It measures the "surprise" when the true token appears
2. **Gradient properties**: Gradients are proportional to the error (predicted - actual)
3. **Information theory**: Minimizing cross-entropy = maximizing likelihood of data
**Mathematical formulation:**
$$\text{CrossEntropy}(p, q) = -\sum_{i} p_i \log(q_i)$$
For language modeling with one-hot targets (only one correct token), this simplifies to:
$$\text{Loss} = -\log(q_{\text{correct}})$$
where $q_{\text{correct}}$ is the probability the model assigns to the correct token.
```{python}
# Example: Model predicting next token
vocab_size = 10
# Model outputs logits (raw scores)
logits = torch.tensor([
[-1.0, 0.5, 2.0, -0.5, 1.0, 0.0, -1.5, 0.3, -0.8, 0.2] # scores for each token
])
# True next token is index 2
target = torch.tensor([2])
# Convert to probabilities
probs = F.softmax(logits, dim=-1)
print("Logits (raw model output):")
print(f" {logits[0].tolist()}")
print(f"\nProbabilities (after softmax):")
print(f" {[f'{p:.3f}' for p in probs[0].tolist()]}")
print(f"\nTarget token: {target.item()}")
print(f"Probability assigned to target: {probs[0, target.item()]:.4f}")
```
```{python}
# Cross-entropy loss
loss = F.cross_entropy(logits, target)
manual_loss = -torch.log(probs[0, target.item()])
print(f"Cross-entropy loss: {loss.item():.4f}")
print(f"Manual calculation: -log({probs[0, target.item()]:.4f}) = {manual_loss.item():.4f}")
# Perplexity
perplexity = math.exp(loss.item())
print(f"\nPerplexity: {perplexity:.2f}")
```
Let's visualize how loss changes with probability:
```{python}
# Loss for different predictions
probs_range = np.linspace(0.01, 0.99, 100)
losses = -np.log(probs_range)
plt.figure(figsize=(10, 4))
plt.plot(probs_range, losses)
plt.xlabel('Probability assigned to correct token')
plt.ylabel('Cross-entropy loss')
plt.title('Loss vs Probability')
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
# Mark some points
for p in [0.1, 0.5, 0.9]:
plt.plot(p, -np.log(p), 'ro', markersize=10)
plt.annotate(f'P={p}\nLoss={-np.log(p):.2f}', (p, -np.log(p)+0.5))
plt.show()
print("Higher probability -> Lower loss -> Better predictions!")
```
### Cross-Entropy from Scratch
Before using `F.cross_entropy`, let's understand what it does internally.
**The Numerical Stability Problem**
Softmax involves `exp(x)`, which explodes for large x:
```{python}
# The problem: exp() overflows easily
logits_big = np.array([1000.0, 1001.0, 1002.0])
print(f"exp(logits) = {np.exp(logits_big)}") # [inf, inf, inf] - overflow!
```
**The Fix: Log-Sum-Exp Trick**
The key insight is that we can compute log-softmax stably by subtracting the maximum:
$$\log \text{softmax}(x_i) = x_i - \log\sum_j e^{x_j} = x_i - \underbrace{(m + \log\sum_j e^{x_j - m})}_{\text{logsumexp}}$$
where $m = \max(x)$. By subtracting the max, all exponents become $\leq 0$, avoiding overflow.
```{python}
def logsumexp(x: np.ndarray, axis: int = -1, keepdims: bool = True) -> np.ndarray:
    """
    Stable log(sum(exp(x))).
    Trick: log(sum(exp(x))) = m + log(sum(exp(x - m)))
    where m = max(x). This keeps exp() arguments <= 0.
    """
    m = x.max(axis=axis, keepdims=True)
    out = m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))
    # m keeps its reduced axis, so drop that axis afterwards if keepdims=False
    return out if keepdims else np.squeeze(out, axis=axis)
# Now it works!
print(f"logsumexp(logits) = {logsumexp(logits_big, keepdims=False)}")
```
**Cross-Entropy Implementation**
```{python}
def cross_entropy_scratch(logits: np.ndarray, targets: np.ndarray) -> float:
"""
Cross-entropy loss from logits.
logits: (B, C) - raw scores for each class
targets: (B,) - integer class labels
Formula: loss = logsumexp(logits) - logits[correct_class]
This is equivalent to: -log(softmax(logits)[correct_class])
but numerically stable.
"""
B, C = logits.shape
# log(sum(exp(logits))) for normalization
    lse = logsumexp(logits, axis=-1, keepdims=False)  # (B,)
# Gather correct class logits
correct_logits = logits[np.arange(B), targets] # (B,)
# Loss per sample, then mean
losses = lse - correct_logits
return float(losses.mean())
# Test
test_logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
test_targets = np.array([0, 1]) # First sample: class 0, second: class 1
print(f"Cross-entropy loss (scratch): {cross_entropy_scratch(test_logits, test_targets):.4f}")
```
**PyTorch Equivalent**
```{python}
# Compare with PyTorch
logits_pt = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
targets_pt = torch.tensor([0, 1])
loss_pt = F.cross_entropy(logits_pt, targets_pt)
print(f"Cross-entropy loss (PyTorch): {loss_pt.item():.4f}")
```
Same result! PyTorch's `F.cross_entropy` does exactly this internally, plus handles gradients automatically.
::: {.callout-note}
## Key Insight
Cross-entropy is just `logsumexp(logits) - logits[correct_class]`. The logsumexp trick prevents numerical overflow by subtracting the max before exponentiating.
:::
## Perplexity
Perplexity is a more intuitive measure than raw loss:
$$\text{Perplexity} = e^{\text{cross\_entropy\_loss}}$$
**Interpretation**: "The model is as confused as if it were choosing uniformly among N options."
| Loss | Perplexity | Interpretation |
|------|------------|----------------|
| 0.0 | 1.0 | Perfect predictions |
| 2.3 | 10 | ~10 equally likely options |
| 4.6 | 100 | ~100 equally likely options |
| 6.9 | 1000 | Random guessing (vocab=1000) |
For reference:
- GPT-2 on WebText: ~20 perplexity
- Human baseline: ~10-20 perplexity (depends on domain)
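The table is easy to verify numerically with the standard library:
```python
import math

for loss in [0.0, 2.3, 4.6, 6.9]:
    print(f"loss = {loss:.1f}  ->  perplexity = {math.exp(loss):,.1f}")
```
This prints 1.0, ~10, ~99, and ~992, matching the rounded values above.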
## Learning Rate Schedule
We don't use a constant learning rate. Instead, we use warmup followed by cosine decay:
```{ojs}
//| echo: false
// Learning Rate Schedule Parameters
viewof lrTotalSteps = Inputs.range([100, 2000], {
value: 1000,
step: 50,
label: "Total Steps"
})
viewof lrWarmupSteps = Inputs.range([0, 500], {
value: 100,
step: 10,
label: "Warmup Steps"
})
viewof lrCurrentStep = Inputs.range([0, lrTotalSteps], {
value: 0,
step: 1,
label: "Current Step"
})
```
```{ojs}
//| echo: false
// LR Schedule calculation function
lrScheduleData = {
const maxLR = 1.0;
const minLR = 0.1;
const data = [];
for (let step = 0; step <= lrTotalSteps; step++) {
let lr;
let phase;
if (step < lrWarmupSteps) {
// Linear warmup
lr = maxLR * step / Math.max(1, lrWarmupSteps);
phase = "warmup";
} else if (step >= lrTotalSteps) {
lr = minLR;
phase = "decay";
} else {
// Cosine decay
const progress = (step - lrWarmupSteps) / Math.max(1, lrTotalSteps - lrWarmupSteps);
const cosine = 0.5 * (1 + Math.cos(Math.PI * progress));
lr = minLR + (maxLR - minLR) * cosine;
phase = "decay";
}
data.push({ step, lr, phase });
}
return data;
}
// Current LR value
currentLR = {
const maxLR = 1.0;
const minLR = 0.1;
const step = lrCurrentStep;
if (step < lrWarmupSteps) {
return maxLR * step / Math.max(1, lrWarmupSteps);
} else if (step >= lrTotalSteps) {
return minLR;
} else {
const progress = (step - lrWarmupSteps) / Math.max(1, lrTotalSteps - lrWarmupSteps);
const cosine = 0.5 * (1 + Math.cos(Math.PI * progress));
return minLR + (maxLR - minLR) * cosine;
}
}
// Current phase
currentPhase = {
if (lrCurrentStep < lrWarmupSteps) return "warmup";
if (lrCurrentStep === lrWarmupSteps) return "peak";
return "decay";
}
```
```{ojs}
//| echo: false
// Learning Rate Schedule Visualization
lrScheduleChart = {
const width = 700;
const height = 380;
const margin = { top: 40, right: 30, bottom: 50, left: 60 };
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', monospace");
// Background with gradient
const defs = svg.append("defs");
const bgGradient = defs.append("linearGradient")
.attr("id", "lr-bg-gradient")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "0%")
.attr("y2", "100%");
bgGradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", theme.isDark ? "#1a1a2e" : "#f8fafc");
bgGradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", theme.isDark ? "#16162a" : "#f1f5f9");
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", "url(#lr-bg-gradient)")
.attr("rx", 12);
// Chart area
const chart = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const xScale = d3.scaleLinear()
.domain([0, lrTotalSteps])
.range([0, innerWidth]);
const yScale = d3.scaleLinear()
.domain([0, 1.1])
.range([innerHeight, 0]);
// Phase background regions
// Warmup region
if (lrWarmupSteps > 0) {
chart.append("rect")
.attr("x", 0)
.attr("y", 0)
.attr("width", xScale(lrWarmupSteps))
.attr("height", innerHeight)
.attr("fill", theme.accent)
.attr("opacity", currentPhase === "warmup" ? 0.15 : 0.05);
}
// Decay region
chart.append("rect")
.attr("x", xScale(lrWarmupSteps))
.attr("y", 0)
.attr("width", innerWidth - xScale(lrWarmupSteps))
.attr("height", innerHeight)
.attr("fill", theme.highlight)
.attr("opacity", currentPhase === "decay" || currentPhase === "peak" ? 0.1 : 0.03);
// Grid lines
const yTicks = [0, 0.25, 0.5, 0.75, 1.0];
yTicks.forEach(tick => {
chart.append("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", yScale(tick))
.attr("y2", yScale(tick))
.attr("stroke", theme.nodeStroke)
.attr("stroke-opacity", 0.3)
.attr("stroke-dasharray", "3,3");
});
// Phase labels
if (lrWarmupSteps > 0) {
chart.append("text")
.attr("x", xScale(lrWarmupSteps / 2))
.attr("y", 15)
.attr("text-anchor", "middle")
.attr("font-size", "11px")
.attr("font-weight", currentPhase === "warmup" ? "600" : "400")
.attr("fill", currentPhase === "warmup" ? theme.accent : theme.nodeText)
.attr("opacity", currentPhase === "warmup" ? 1 : 0.5)
.text("WARMUP");
}
chart.append("text")
.attr("x", xScale(lrWarmupSteps + (lrTotalSteps - lrWarmupSteps) / 2))
.attr("y", 15)
.attr("text-anchor", "middle")
.attr("font-size", "11px")
.attr("font-weight", currentPhase === "decay" ? "600" : "400")
.attr("fill", currentPhase === "decay" || currentPhase === "peak" ? theme.highlight : theme.nodeText)
.attr("opacity", currentPhase === "decay" || currentPhase === "peak" ? 1 : 0.5)
.text("COSINE DECAY");
// Line generator
const lineGen = d3.line()
.x(d => xScale(d.step))
.y(d => yScale(d.lr))
.curve(d3.curveMonotoneX);
// Gradient for the line
const lineGradient = defs.append("linearGradient")
.attr("id", "lr-line-gradient")
.attr("gradientUnits", "userSpaceOnUse")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", 0)
.attr("y2", 0);
lineGradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", theme.accent);
const warmupPct = (lrWarmupSteps / lrTotalSteps * 100).toFixed(1);
lineGradient.append("stop")
.attr("offset", `${warmupPct}%`)
.attr("stop-color", theme.accent);
lineGradient.append("stop")
.attr("offset", `${warmupPct}%`)
.attr("stop-color", theme.highlight);
lineGradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", theme.highlight);
// Area under curve
const areaGen = d3.area()
.x(d => xScale(d.step))
.y0(innerHeight)
.y1(d => yScale(d.lr))
.curve(d3.curveMonotoneX);
// Area gradient
const areaGradient = defs.append("linearGradient")
.attr("id", "lr-area-gradient")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "0%")
.attr("y2", "100%");
areaGradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", theme.highlight)
.attr("stop-opacity", 0.3);
areaGradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", theme.highlight)
.attr("stop-opacity", 0.02);
chart.append("path")
.datum(lrScheduleData)
.attr("d", areaGen)
.attr("fill", "url(#lr-area-gradient)");
// Main line
chart.append("path")
.datum(lrScheduleData)
.attr("d", lineGen)
.attr("fill", "none")
.attr("stroke", "url(#lr-line-gradient)")
.attr("stroke-width", 3)
.attr("stroke-linecap", "round");
// Current step marker
const currentX = xScale(lrCurrentStep);
const currentY = yScale(currentLR);
// Vertical line at current step
chart.append("line")
.attr("x1", currentX)
.attr("x2", currentX)
.attr("y1", 0)
.attr("y2", innerHeight)
.attr("stroke", theme.nodeText)
.attr("stroke-opacity", 0.4)
.attr("stroke-dasharray", "4,4");
// Horizontal line to y-axis
chart.append("line")
.attr("x1", 0)
.attr("x2", currentX)
.attr("y1", currentY)
.attr("y2", currentY)
.attr("stroke", theme.nodeText)
.attr("stroke-opacity", 0.4)
.attr("stroke-dasharray", "4,4");
// Glow effect for marker
const glowFilter = defs.append("filter")
.attr("id", "lr-marker-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "blur");
glowFilter.append("feMerge")
.selectAll("feMergeNode")
.data(["blur", "SourceGraphic"])
.join("feMergeNode")
.attr("in", d => d);
// Current step dot with glow
chart.append("circle")
.attr("cx", currentX)
.attr("cy", currentY)
.attr("r", 12)
.attr("fill", currentPhase === "warmup" ? theme.accent : theme.highlight)
.attr("opacity", 0.3)
.attr("filter", "url(#lr-marker-glow)");
chart.append("circle")
.attr("cx", currentX)
.attr("cy", currentY)
.attr("r", 6)
.attr("fill", currentPhase === "warmup" ? theme.accent : theme.highlight)
.attr("stroke", theme.bg === "transparent" ? "#1c1917" : theme.bg)
.attr("stroke-width", 2);
// X-axis
chart.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale).ticks(8).tickFormat(d => d))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "11px"));
// X-axis label
chart.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 40)
.attr("text-anchor", "middle")
.attr("font-size", "12px")
.attr("fill", theme.nodeText)
.text("Training Steps");
// Y-axis
chart.append("g")
.call(d3.axisLeft(yScale).ticks(5).tickFormat(d => d.toFixed(2)))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", theme.nodeText)
.attr("font-size", "11px"));
// Y-axis label
chart.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("font-size", "12px")
.attr("fill", theme.nodeText)
.text("Learning Rate");
// Info box
const infoBox = svg.append("g")
.attr("transform", `translate(${width - 170}, 50)`);
infoBox.append("rect")
.attr("x", 0)
.attr("y", 0)
.attr("width", 150)
.attr("height", 80)
.attr("rx", 8)
.attr("fill", theme.nodeFill)
.attr("stroke", theme.nodeStroke)
.attr("stroke-width", 1.5);
infoBox.append("text")
.attr("x", 75)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("font-size", "10px")
.attr("font-weight", "600")
.attr("fill", theme.nodeText)
.attr("opacity", 0.6)
.text("CURRENT");
infoBox.append("text")
.attr("x", 75)
.attr("y", 45)
.attr("text-anchor", "middle")
.attr("font-size", "18px")
.attr("font-weight", "700")
.attr("fill", currentPhase === "warmup" ? theme.accent : theme.highlight)
.text(`LR: ${currentLR.toFixed(4)}`);
infoBox.append("text")
.attr("x", 75)
.attr("y", 65)
.attr("text-anchor", "middle")
.attr("font-size", "11px")
.attr("fill", theme.nodeText)
.attr("opacity", 0.7)
.text(`Step ${lrCurrentStep} / ${lrTotalSteps}`);
// Phase indicator badge
const phaseBadge = svg.append("g")
.attr("transform", `translate(${margin.left + 10}, 55)`);
const phaseColor = currentPhase === "warmup" ? theme.accent : theme.highlight;
const phaseLabel = currentPhase.toUpperCase();
phaseBadge.append("rect")
.attr("x", 0)
.attr("y", 0)
.attr("width", 75)
.attr("height", 24)
.attr("rx", 12)
.attr("fill", phaseColor)
.attr("opacity", 0.9);
phaseBadge.append("text")
.attr("x", 37.5)
.attr("y", 16)
.attr("text-anchor", "middle")
.attr("font-size", "10px")
.attr("font-weight", "700")
.attr("fill", theme.textOnHighlight)
.text(phaseLabel);
return svg.node();
}
```
**Why warmup?**
- Early training is unstable with large LR
- Gradients are noisy before weights settle
- Small LR lets the model "get its bearings"
**Why decay?**
- Large LR is good for exploration early
- Small LR is good for fine-tuning later
- Cosine is smooth (no sudden changes)
```{python}
class CosineScheduler:
"""Learning rate scheduler with linear warmup and cosine decay."""
def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr = min_lr
self.base_lr = optimizer.param_groups[0]['lr']
self.current_step = 0
def get_lr(self):
"""Calculate learning rate for current step."""
if self.current_step < self.warmup_steps:
# Linear warmup
return self.base_lr * self.current_step / max(1, self.warmup_steps)
elif self.current_step >= self.total_steps:
return self.min_lr
else:
# Cosine decay
progress = (self.current_step - self.warmup_steps) / max(
1, self.total_steps - self.warmup_steps
)
cosine = 0.5 * (1 + math.cos(math.pi * progress))
return self.min_lr + (self.base_lr - self.min_lr) * cosine
def step(self):
"""Update learning rate."""
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.current_step += 1
return lr
# Create scheduler
model = nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineScheduler(
optimizer,
warmup_steps=100,
total_steps=1000,
min_lr=1e-5
)
# Collect LRs over training
lrs = []
for _ in range(1000):
lrs.append(scheduler.get_lr())
scheduler.step()
plt.figure(figsize=(12, 4))
plt.plot(lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule: Warmup + Cosine Decay')
plt.axvline(x=100, color='r', linestyle='--', label='Warmup ends')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print(f"Initial LR: {lrs[0]:.6f}")
print(f"After warmup (step 100): {lrs[100]:.6f}")
print(f"Final LR: {lrs[-1]:.6f}")
```
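For comparison, roughly the same shape can be composed from PyTorch's built-in schedulers. This is a sketch, not what the rest of the module uses; note that `LinearLR` warms up from a small nonzero factor rather than exactly zero.
```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 100 steps of linear warmup, then cosine decay over the remaining 900 steps
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1.0, total_iters=100)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=900, eta_min=1e-5)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[100])

lrs = []
for _ in range(1000):
    optimizer.step()        # normally preceded by forward/backward
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(f"start {lrs[0]:.6f}, peak {max(lrs):.6f}, end {lrs[-1]:.6f}")
```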
Let's compare different warmup lengths:
```{python}
fig, ax = plt.subplots(figsize=(12, 4))
for warmup in [10, 50, 100, 200]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineScheduler(optimizer, warmup_steps=warmup, total_steps=500)
lrs = []
for _ in range(500):
lrs.append(scheduler.get_lr())
scheduler.step()
ax.plot(lrs, label=f'Warmup={warmup}')
ax.set_xlabel('Step')
ax.set_ylabel('Learning Rate')
ax.set_title('Effect of Warmup Length')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()
```
## AdamW Optimizer
AdamW is Adam with decoupled weight decay: the decay is applied directly to the weights rather than folded into the gradient. It is the standard optimizer for training language models.
**Why AdamW over SGD or Adam?**
- **SGD**: No per-parameter adaptation, so it needs careful learning rate tuning and converges slowly on transformer losses
- **Adam**: The L2 penalty is added to the gradient, so it gets rescaled by the adaptive step and no longer acts as true weight decay
- **AdamW**: Applies weight decay directly to the weights, decoupled from the adaptive update (see the sketch below)
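To make the difference concrete, here is the weight-decay term written both ways (a simplified sketch using the bias-corrected moments $\hat{m}$, $\hat{v}$, with $\epsilon$ omitted):
$$\text{Adam + L2:}\quad g \leftarrow g + \lambda\theta, \qquad \theta \leftarrow \theta - \alpha\,\frac{\hat{m}}{\sqrt{\hat{v}}}$$
$$\text{AdamW:}\quad \theta \leftarrow \theta - \alpha\left(\frac{\hat{m}}{\sqrt{\hat{v}}} + \lambda\theta\right)$$
In the first form the decay term is folded into $\hat{m}$ and divided by $\sqrt{\hat{v}}$, so parameters with large gradient variance are barely regularized; in the second, every weight shrinks at the same rate regardless of its gradient history.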
```{ojs}
//| echo: false
// AdamW step-through visualization
viewof adamwStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
// Step descriptions for AdamW
adamwStepInfo = {
const steps = [
{
title: "Input Gradient",
description: "Receive gradient g from backpropagation",
formula: "g = dL/dθ",
highlight: ["gradient"]
},
{
title: "Momentum Update",
description: "Update first moment (exponential moving average of gradients)",
formula: "m = β₁·m + (1-β₁)·g",
highlight: ["gradient", "momentum"]
},
{
title: "Adaptive Learning Rate",
description: "Update second moment (exponential moving average of squared gradients)",
formula: "v = β₂·v + (1-β₂)·g²",
highlight: ["gradient", "velocity"]
},
{
title: "Bias Correction",
description: "Correct for initialization bias in early timesteps",
formula: "m̂ = m/(1-β₁ᵗ), v̂ = v/(1-β₂ᵗ)",
highlight: ["momentum", "velocity", "bias"]
},
{
title: "Weight Update",
description: "Apply adaptive update with decoupled weight decay",
formula: "θ = θ - lr·(m̂/√v̂ + λ·θ)",
highlight: ["bias", "update"]
}
];
return steps[adamwStep];
}
// Numeric computation for AdamW example
adamwComputation = {
// Initial values and hyperparameters
const g = 0.5; // gradient
const beta1 = 0.9;
const beta2 = 0.999;
const lr = 0.001;
const lambda = 0.01; // weight decay
const t = 5; // timestep
const m_prev = 0.1; // previous momentum
const v_prev = 0.01; // previous velocity
const theta_prev = 0.75; // previous weight
// Step 0: Just the gradient
const step0 = { g };
// Step 1: Momentum update
const m = beta1 * m_prev + (1 - beta1) * g;
const step1 = { ...step0, m, m_prev };
// Step 2: Velocity update
const v = beta2 * v_prev + (1 - beta2) * (g * g);
const step2 = { ...step1, v, v_prev };
// Step 3: Bias correction
const m_hat = m / (1 - Math.pow(beta1, t));
const v_hat = v / (1 - Math.pow(beta2, t));
const step3 = { ...step2, m_hat, v_hat };
// Step 4: Weight update
  const adam_update = m_hat / (Math.sqrt(v_hat) + 1e-8);
const weight_decay = lambda * theta_prev;
const theta = theta_prev - lr * (adam_update + weight_decay);
const step4 = { ...step3, adam_update, weight_decay, theta, theta_prev };
const steps = [step0, step1, step2, step3, step4];
return {
...steps[adamwStep],
beta1, beta2, lr, lambda, t,
step: adamwStep
};
}
```
```{ojs}
//| echo: false
// AdamW flowchart visualization
adamwDiagram = {
const width = 680;
const height = 420;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Mono', 'Fira Code', monospace");
// Background with subtle gradient
const bgGrad = svg.append("defs").append("linearGradient")
.attr("id", "adamw-bg-grad")
.attr("x1", "0%").attr("y1", "0%")
.attr("x2", "100%").attr("y2", "100%");
bgGrad.append("stop").attr("offset", "0%").attr("stop-color", theme.bg);
bgGrad.append("stop").attr("offset", "100%").attr("stop-color", theme.bgSecondary);
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", "url(#adamw-bg-grad)")
.attr("rx", 12);
// Glow filter for highlights
const defs = svg.select("defs");
const glowFilter = defs.append("filter")
.attr("id", "adamw-glow")
.attr("x", "-50%").attr("y", "-50%")
.attr("width", "200%").attr("height", "200%");
glowFilter.append("feGaussianBlur")
.attr("stdDeviation", "4")
.attr("result", "coloredBlur");
const feMerge = glowFilter.append("feMerge");
feMerge.append("feMergeNode").attr("in", "coloredBlur");
feMerge.append("feMergeNode").attr("in", "SourceGraphic");
// Arrow marker
defs.append("marker")
.attr("id", "adamw-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.edgeStroke);
defs.append("marker")
.attr("id", "adamw-arrow-highlight")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.highlight);
// Node definitions
const nodes = [
{ id: "gradient", label: "Gradient", sublabel: "g = dL/dθ", x: 340, y: 60 },
{ id: "momentum", label: "Momentum", sublabel: "m = β₁m + (1-β₁)g", x: 180, y: 160 },
{ id: "velocity", label: "Adaptive LR", sublabel: "v = β₂v + (1-β₂)g²", x: 500, y: 160 },
{ id: "bias", label: "Bias Correction", sublabel: "m̂, v̂", x: 340, y: 260 },
{ id: "update", label: "Weight Update", sublabel: "θ = θ - lr·(...)", x: 340, y: 360 }
];
// Edge definitions
const edges = [
{ from: "gradient", to: "momentum" },
{ from: "gradient", to: "velocity" },
{ from: "momentum", to: "bias" },
{ from: "velocity", to: "bias" },
{ from: "bias", to: "update" }
];
// Determine which nodes/edges are active based on step
const activeNodes = adamwStepInfo.highlight;
const isNodeActive = (id) => activeNodes.includes(id);
const isEdgeActive = (from, to) => {
return activeNodes.includes(from) && activeNodes.includes(to);
};
// Draw edges
const edgesLayer = svg.append("g").attr("class", "edges");
edges.forEach(edge => {
const fromNode = nodes.find(n => n.id === edge.from);
const toNode = nodes.find(n => n.id === edge.to);
const active = isEdgeActive(edge.from, edge.to);
// Calculate shortened path
const dx = toNode.x - fromNode.x;
const dy = toNode.y - fromNode.y;
const len = Math.sqrt(dx*dx + dy*dy);
const startOffset = 30;
const endOffset = 35;
const x1 = fromNode.x + (dx/len) * startOffset;
const y1 = fromNode.y + (dy/len) * startOffset;
const x2 = toNode.x - (dx/len) * endOffset;
const y2 = toNode.y - (dy/len) * endOffset;
edgesLayer.append("path")
.attr("d", `M${x1},${y1} L${x2},${y2}`)
.attr("fill", "none")
.attr("stroke", active ? theme.highlight : theme.edgeStroke)
.attr("stroke-width", active ? 2.5 : 1.5)
.attr("marker-end", active ? "url(#adamw-arrow-highlight)" : "url(#adamw-arrow)")
.attr("opacity", active ? 1 : 0.5)
.style("filter", active ? "url(#adamw-glow)" : "none")
.style("transition", "all 0.3s ease");
});
// Draw nodes
const nodesLayer = svg.append("g").attr("class", "nodes");
nodes.forEach(node => {
const active = isNodeActive(node.id);
const nodeWidth = 140;
const nodeHeight = 54;
const g = nodesLayer.append("g")
.attr("transform", `translate(${node.x}, ${node.y})`);
// Node background
g.append("rect")
.attr("x", -nodeWidth/2)
.attr("y", -nodeHeight/2)
.attr("width", nodeWidth)
.attr("height", nodeHeight)
.attr("rx", 8)
.attr("ry", 8)
.attr("fill", active ? theme.highlight : theme.nodeFill)
.attr("stroke", active ? theme.highlight : theme.nodeStroke)
.attr("stroke-width", active ? 2 : 1.5)
.style("filter", active ? "url(#adamw-glow)" : "none")
.style("transition", "all 0.3s ease");
// Node label
g.append("text")
.attr("y", -8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? theme.textOnHighlight : theme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "600")
.style("transition", "fill 0.3s ease")
.text(node.label);
// Node sublabel
g.append("text")
.attr("y", 12)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", active ? theme.textOnHighlight : theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", active ? 0.9 : 0.7)
.style("transition", "all 0.3s ease")
.text(node.sublabel);
});
return svg.node();
}
```
```{ojs}
//| echo: false
// Info panel showing current step details and numeric values
adamwInfoPanel = {
const theme = diagramTheme;
const comp = adamwComputation;
const info = adamwStepInfo;
const container = htl.html`<div style="
background: ${theme.bgSecondary};
border: 1px solid ${theme.nodeStroke};
border-radius: 8px;
padding: 16px 20px;
margin-top: 12px;
font-family: 'IBM Plex Mono', 'Fira Code', monospace;
">
<div style="
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 12px;
">
<span style="
font-size: 14px;
font-weight: 600;
color: ${theme.highlight};
">Step ${adamwStep}: ${info.title}</span>
<span style="
font-size: 12px;
color: ${theme.nodeText};
opacity: 0.7;
">t = ${comp.t}</span>
</div>
<p style="
font-size: 12px;
color: ${theme.nodeText};
margin: 0 0 12px 0;
line-height: 1.5;
">${info.description}</p>
<div style="
background: ${theme.nodeFill};
border-radius: 6px;
padding: 12px 16px;
font-family: 'IBM Plex Mono', 'Fira Code', monospace;
">
<div style="
font-size: 15px;
color: ${theme.accent};
font-weight: 500;
margin-bottom: 10px;
">${info.formula}</div>
${adamwStep === 0 ? htl.html`
<div style="font-size: 11px; color: ${theme.nodeText}; line-height: 1.8;">
<div><span style="opacity: 0.6;">gradient:</span> g = <span style="color: ${theme.highlight};">${comp.g.toFixed(3)}</span></div>
<div><span style="opacity: 0.6;">hyperparams:</span> β₁=${comp.beta1}, β₂=${comp.beta2}, lr=${comp.lr}, λ=${comp.lambda}</div>
</div>
` : ''}
${adamwStep === 1 ? htl.html`
<div style="font-size: 11px; color: ${theme.nodeText}; line-height: 1.8;">
<div>m = ${comp.beta1} × ${comp.m_prev.toFixed(3)} + ${(1-comp.beta1).toFixed(1)} × ${comp.g.toFixed(3)}</div>
<div>m = <span style="color: ${theme.highlight};">${comp.m.toFixed(4)}</span></div>
</div>
` : ''}
${adamwStep === 2 ? htl.html`
<div style="font-size: 11px; color: ${theme.nodeText}; line-height: 1.8;">
<div>v = ${comp.beta2} × ${comp.v_prev.toFixed(4)} + ${(1-comp.beta2).toFixed(3)} × ${comp.g.toFixed(3)}²</div>
<div>v = <span style="color: ${theme.highlight};">${comp.v.toFixed(6)}</span></div>
</div>
` : ''}
${adamwStep === 3 ? htl.html`
<div style="font-size: 11px; color: ${theme.nodeText}; line-height: 1.8;">
<div>m̂ = ${comp.m.toFixed(4)} / (1 - ${comp.beta1}^${comp.t}) = <span style="color: ${theme.highlight};">${comp.m_hat.toFixed(4)}</span></div>
<div>v̂ = ${comp.v.toFixed(6)} / (1 - ${comp.beta2}^${comp.t}) = <span style="color: ${theme.highlight};">${comp.v_hat.toFixed(6)}</span></div>
</div>
` : ''}
${adamwStep === 4 ? htl.html`
<div style="font-size: 11px; color: ${theme.nodeText}; line-height: 1.8;">
<div>adam = m̂/√v̂ = ${comp.m_hat.toFixed(4)} / √${comp.v_hat.toFixed(6)} = ${comp.adam_update.toFixed(4)}</div>
<div>decay = λ·θ = ${comp.lambda} × ${comp.theta_prev.toFixed(2)} = ${comp.weight_decay.toFixed(5)}</div>
<div>θ = ${comp.theta_prev.toFixed(4)} - ${comp.lr} × (${comp.adam_update.toFixed(4)} + ${comp.weight_decay.toFixed(5)})</div>
<div>θ = <span style="color: ${theme.highlight}; font-weight: 600;">${comp.theta.toFixed(6)}</span></div>
</div>
` : ''}
</div>
</div>`;
return container;
}
```
**Hyperparameters explained:**
| Parameter | Default | Purpose |
|-----------|---------|---------|
| beta1 | 0.9 | Momentum coefficient - smooths gradient direction |
| beta2 | 0.999 | Adaptive LR coefficient - smooths gradient magnitude |
| epsilon | 1e-8 | Numerical stability (prevents division by zero) |
| weight_decay | 0.01 | L2 regularization strength |
**Practical tip:** The LLM community has converged on beta1=0.9, beta2=0.95 for large models (used by LLaMA, GPT-3). The lower beta2 adapts faster to changing gradient magnitudes.
```{python}
# Creating an AdamW optimizer
model = nn.Linear(100, 10)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4, # Learning rate
betas=(0.9, 0.999), # Momentum and adaptive LR
weight_decay=0.01 # Regularization
)
print("AdamW optimizer created")
print(f" Learning rate: {optimizer.param_groups[0]['lr']}")
print(f" Weight decay: {optimizer.param_groups[0]['weight_decay']}")
```
### Optimizers from Scratch
Let's build optimizers from first principles to understand what PyTorch does internally.
#### Plain SGD
The simplest optimizer: move parameters in the opposite direction of the gradient.
```{python}
class SGD_Scratch:
"""
Stochastic Gradient Descent.
Update rule: theta = theta - lr * gradient
"""
def __init__(self, params, lr=0.01):
self.params = list(params)
self.lr = lr
def step(self):
with torch.no_grad():
for p in self.params:
if p.grad is not None:
p -= self.lr * p.grad
def zero_grad(self):
for p in self.params:
if p.grad is not None:
p.grad = None
# Test: compare with PyTorch SGD
torch.manual_seed(42)
model_scratch = nn.Linear(10, 2)
model_pytorch = nn.Linear(10, 2)
model_pytorch.load_state_dict(model_scratch.state_dict())
opt_scratch = SGD_Scratch(model_scratch.parameters(), lr=0.1)
opt_pytorch = torch.optim.SGD(model_pytorch.parameters(), lr=0.1)
# Forward + backward
x = torch.randn(4, 10)
loss_scratch = model_scratch(x).sum()
loss_pytorch = model_pytorch(x).sum()
loss_scratch.backward()
loss_pytorch.backward()
# Update
opt_scratch.step()
opt_pytorch.step()
# Compare weights
print("After one SGD step:")
print(f" Scratch weight[0,0]: {model_scratch.weight[0,0].item():.6f}")
print(f" PyTorch weight[0,0]: {model_pytorch.weight[0,0].item():.6f}")
print(f" Match: {torch.allclose(model_scratch.weight, model_pytorch.weight)}")
```
#### SGD with Momentum
Momentum adds "velocity" to gradient descent. Instead of using the gradient directly, we accumulate a moving average of gradients:
$$v_t = \mu \cdot v_{t-1} + g_t$$
$$\theta_t = \theta_{t-1} - \alpha \cdot v_t$$
This helps:
- Smooth out noisy gradients
- Accelerate through flat regions
- Dampen oscillations in steep valleys
```{python}
class SGD_Momentum_Scratch:
"""
SGD with momentum.
Update rule:
v = momentum * v + gradient
theta = theta - lr * v
"""
def __init__(self, params, lr=0.01, momentum=0.9):
self.params = list(params)
self.lr = lr
self.momentum = momentum
# Velocity buffer for each parameter
self.v = [torch.zeros_like(p) for p in self.params]
def step(self):
with torch.no_grad():
for i, p in enumerate(self.params):
if p.grad is None:
continue
# Update velocity: v = momentum * v + grad
self.v[i] = self.momentum * self.v[i] + p.grad
# Update parameter
p -= self.lr * self.v[i]
def zero_grad(self):
for p in self.params:
if p.grad is not None:
p.grad = None
# Test: compare with PyTorch SGD momentum
torch.manual_seed(42)
model_scratch = nn.Linear(10, 2)
model_pytorch = nn.Linear(10, 2)
model_pytorch.load_state_dict(model_scratch.state_dict())
opt_scratch = SGD_Momentum_Scratch(model_scratch.parameters(), lr=0.1, momentum=0.9)
opt_pytorch = torch.optim.SGD(model_pytorch.parameters(), lr=0.1, momentum=0.9)
# Multiple steps to see momentum accumulate
for step in range(3):
x = torch.randn(4, 10)
loss_scratch = model_scratch(x).sum()
loss_pytorch = model_pytorch(x).sum()
opt_scratch.zero_grad()
opt_pytorch.zero_grad()
loss_scratch.backward()
loss_pytorch.backward()
opt_scratch.step()
opt_pytorch.step()
print("After 3 momentum SGD steps:")
print(f" Scratch weight[0,0]: {model_scratch.weight[0,0].item():.6f}")
print(f" PyTorch weight[0,0]: {model_pytorch.weight[0,0].item():.6f}")
print(f" Match: {torch.allclose(model_scratch.weight, model_pytorch.weight)}")
```
::: {.callout-note}
## Key Insight: Momentum
Momentum is like pushing a ball down a hill - it builds up speed in consistent directions and resists sudden direction changes. This makes optimization faster and more stable.
:::
#### Adam from Scratch
Adam combines momentum with adaptive learning rates. It tracks two quantities:
1. **First moment** $m$ (mean of gradients) - like momentum
2. **Second moment** $v$ (mean of squared gradients) - adapts learning rate per-parameter
$$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$$
$$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
We also need **bias correction** because $m$ and $v$ are initialized to zero:
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
Finally, the update:
$$\theta_t = \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
```{python}
class Adam_Scratch:
"""
Adam optimizer with optional weight decay (AdamW style).
Tracks first moment (mean) and second moment (variance) of gradients.
Uses bias correction to fix initialization bias.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
self.params = list(params)
self.lr = lr
self.b1, self.b2 = betas
self.eps = eps
self.weight_decay = weight_decay
# First moment (mean of gradients)
self.m = [torch.zeros_like(p) for p in self.params]
# Second moment (mean of squared gradients)
self.v = [torch.zeros_like(p) for p in self.params]
# Timestep
self.t = 0
def step(self):
self.t += 1
with torch.no_grad():
for i, p in enumerate(self.params):
if p.grad is None:
continue
g = p.grad
# AdamW: Weight decay applied directly to weights (decoupled)
if self.weight_decay != 0.0:
p -= self.lr * self.weight_decay * p
# Update first moment: m = beta1 * m + (1 - beta1) * g
self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
# Update second moment: v = beta2 * v + (1 - beta2) * g^2
self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * (g * g)
# Bias correction (crucial early in training!)
mhat = self.m[i] / (1 - self.b1 ** self.t)
vhat = self.v[i] / (1 - self.b2 ** self.t)
# Update parameters
p -= self.lr * mhat / (torch.sqrt(vhat) + self.eps)
def zero_grad(self):
for p in self.params:
if p.grad is not None:
p.grad = None
# Test: compare with PyTorch AdamW
torch.manual_seed(42)
model_scratch = nn.Linear(10, 2)
model_pytorch = nn.Linear(10, 2)
model_pytorch.load_state_dict(model_scratch.state_dict())
opt_scratch = Adam_Scratch(model_scratch.parameters(), lr=1e-3, weight_decay=0.01)
opt_pytorch = torch.optim.AdamW(model_pytorch.parameters(), lr=1e-3, weight_decay=0.01)
# Multiple steps
for step in range(5):
x = torch.randn(4, 10)
loss_scratch = model_scratch(x).sum()
loss_pytorch = model_pytorch(x).sum()
opt_scratch.zero_grad()
opt_pytorch.zero_grad()
loss_scratch.backward()
loss_pytorch.backward()
opt_scratch.step()
opt_pytorch.step()
print("After 5 AdamW steps:")
print(f" Scratch weight[0,0]: {model_scratch.weight[0,0].item():.6f}")
print(f" PyTorch weight[0,0]: {model_pytorch.weight[0,0].item():.6f}")
print(f" Close match: {torch.allclose(model_scratch.weight, model_pytorch.weight, atol=1e-6)}")
```
::: {.callout-note}
## Key Insight: Adam
Adam is "momentum + per-parameter learning rates." The second moment $v$ tracks how much each parameter's gradient varies. Parameters with consistently large gradients get smaller effective learning rates (stabilizing training), while those with small gradients get larger rates (speeding up learning).
:::
**Why bias correction matters:**
Without bias correction, the first few steps are biased toward zero because $m$ and $v$ are initialized to zero. Let's see this:
```{python}
# Demonstrate bias correction importance
m, v = 0.0, 0.0
b1, b2 = 0.9, 0.999
true_grad = 1.0 # Pretend gradient is always 1
print("Step | m (biased) | m_hat (corrected)")
print("-" * 45)
for t in range(1, 6):
m = b1 * m + (1 - b1) * true_grad
m_hat = m / (1 - b1 ** t)
print(f" {t} | {m:.4f} | {m_hat:.4f}")
print(f"\nWithout correction, m starts near 0.1 instead of 1.0!")
print(f"Bias correction fixes this, making m_hat ≈ 1.0 from the start.")
```
## Gradient Accumulation
Want a larger effective batch size without more memory? Use gradient accumulation!
**Problem**: Want batch_size=32 but only 8 fits in memory
**Solution**: Accumulate gradients over 4 mini-batches
```{ojs}
//| echo: false
// Step slider for gradient accumulation (0 = initial, 1-4 = mini-batches, 5 = optimizer step)
viewof accumStep = Inputs.range([0, 5], {
value: 0,
step: 1,
label: "Accumulation Step"
})
```
```{ojs}
//| echo: false
// Gradient accumulation diagram data
accumStepInfo = {
const steps = [
{ name: "Ready", description: "Gradients zeroed, ready to accumulate", gradientLevel: 0 },
{ name: "Mini-batch 1", description: "loss.backward() - gradients start accumulating", gradientLevel: 0.25 },
{ name: "Mini-batch 2", description: "loss.backward() - gradients continue accumulating", gradientLevel: 0.5 },
{ name: "Mini-batch 3", description: "loss.backward() - gradients continue accumulating", gradientLevel: 0.75 },
{ name: "Mini-batch 4", description: "loss.backward() - gradients fully accumulated", gradientLevel: 1.0 },
{ name: "Optimizer Step", description: "optimizer.step() - one weight update with effective batch_size=32", gradientLevel: 0 }
];
return steps[accumStep];
}
```
```{ojs}
//| echo: false
// Interactive gradient accumulation visualization
{
const width = 700;
const height = 340;
const batchSize = 8;
const accumSteps = 4;
const effectiveBatch = batchSize * accumSteps;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Defs for arrows and gradients
const defs = svg.append("defs");
// Arrow markers
defs.append("marker")
.attr("id", "accum-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "accum-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Gradient fill for the accumulator bar
const gradientFill = defs.append("linearGradient")
.attr("id", "gradient-fill")
.attr("x1", "0%")
.attr("y1", "100%")
.attr("x2", "0%")
.attr("y2", "0%");
gradientFill.append("stop")
.attr("offset", "0%")
.attr("stop-color", diagramTheme.accent);
gradientFill.append("stop")
.attr("offset", "100%")
.attr("stop-color", diagramTheme.highlight);
// Layout constants
const batchBoxWidth = 100;
const batchBoxHeight = 55;
const batchStartX = 60;
const batchSpacing = 20;
const batchY = 80;
const accumX = 480;
const accumY = 80;
const accumWidth = 70;
const accumHeight = 140;
const optimizerX = 620;
const optimizerY = 150;
// Draw mini-batch boxes
const batches = [1, 2, 3, 4];
batches.forEach((batch, i) => {
const x = batchStartX + i * (batchBoxWidth + batchSpacing);
const isActive = accumStep === batch;
const isProcessed = accumStep > batch;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${batchY})`);
// Box
g.append("rect")
.attr("width", batchBoxWidth)
.attr("height", batchBoxHeight)
.attr("rx", 6)
.attr("fill", isActive ? diagramTheme.highlight : (isProcessed ? diagramTheme.accent : diagramTheme.nodeFill))
.attr("stroke", isActive ? diagramTheme.highlight : (isProcessed ? diagramTheme.accent : diagramTheme.nodeStroke))
.attr("stroke-width", isActive ? 2.5 : 1.5)
.attr("opacity", isProcessed && !isActive ? 0.7 : 1)
.style("filter", isActive ? `drop-shadow(0 0 8px ${diagramTheme.highlightGlow})` : "none");
// Batch label
g.append("text")
.attr("x", batchBoxWidth / 2)
.attr("y", 18)
.attr("text-anchor", "middle")
.attr("fill", isActive || isProcessed ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`Mini-batch ${batch}`);
// Size info
g.append("text")
.attr("x", batchBoxWidth / 2)
.attr("y", 34)
.attr("text-anchor", "middle")
.attr("fill", isActive || isProcessed ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.8)
.text(`size=${batchSize}`);
// backward() call
g.append("text")
.attr("x", batchBoxWidth / 2)
.attr("y", 48)
.attr("text-anchor", "middle")
.attr("fill", isActive || isProcessed ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("font-family", "monospace")
.attr("opacity", 0.7)
.text("loss.backward()");
// Arrow from batch to accumulator
if (accumStep >= batch && accumStep <= 4) {
const arrowActive = isActive;
const startX = x + batchBoxWidth;
const startY = batchY + batchBoxHeight / 2;
const endX = accumX - 5;
const endY = accumY + 30 + i * 25;
// Curved path
const midX = (startX + endX) / 2 + 20;
svg.append("path")
.attr("d", `M${startX + 5},${startY} Q${midX},${startY} ${endX},${endY}`)
.attr("fill", "none")
.attr("stroke", arrowActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", arrowActive ? 2 : 1.5)
.attr("marker-end", arrowActive ? "url(#accum-arrow-active)" : "url(#accum-arrow)")
.attr("opacity", isProcessed && !arrowActive ? 0.5 : (arrowActive ? 1 : 0.7))
.style("filter", arrowActive ? `drop-shadow(0 0 3px ${diagramTheme.highlightGlow})` : "none");
}
});
// Draw accumulator container
const accumG = svg.append("g")
.attr("transform", `translate(${accumX}, ${accumY})`);
// Accumulator background
accumG.append("rect")
.attr("width", accumWidth)
.attr("height", accumHeight)
.attr("rx", 8)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", accumStep >= 1 && accumStep <= 4 ? diagramTheme.accent : diagramTheme.nodeStroke)
.attr("stroke-width", 2);
// Gradient level bar (fills from bottom)
const gradientLevel = accumStepInfo.gradientLevel;
const barPadding = 8;
const barWidth = accumWidth - barPadding * 2;
const barMaxHeight = accumHeight - barPadding * 2 - 20;
const barHeight = barMaxHeight * gradientLevel;
if (barHeight > 0) {
accumG.append("rect")
.attr("x", barPadding)
.attr("y", accumHeight - barPadding - barHeight)
.attr("width", barWidth)
.attr("height", barHeight)
.attr("rx", 4)
.attr("fill", "url(#gradient-fill)")
.attr("opacity", 0.9);
}
// Accumulator label
accumG.append("text")
.attr("x", accumWidth / 2)
.attr("y", 14)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text("Gradients");
// Percentage label
accumG.append("text")
.attr("x", accumWidth / 2)
.attr("y", accumHeight / 2 + 5)
.attr("text-anchor", "middle")
.attr("fill", gradientLevel > 0.3 ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "700")
.text(`${Math.round(gradientLevel * 100)}%`);
// Arrow from accumulator to optimizer
const optimizerActive = accumStep === 5;
svg.append("path")
.attr("d", `M${accumX + accumWidth + 5},${accumY + accumHeight / 2} L${optimizerX - 50},${optimizerY}`)
.attr("fill", "none")
.attr("stroke", optimizerActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", optimizerActive ? 2.5 : 1.5)
.attr("marker-end", optimizerActive ? "url(#accum-arrow-active)" : "url(#accum-arrow)")
.attr("opacity", accumStep < 5 ? 0.4 : 1)
.attr("stroke-dasharray", accumStep < 5 ? "5,3" : "none")
.style("filter", optimizerActive ? `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})` : "none");
// Optimizer box
const optG = svg.append("g")
.attr("transform", `translate(${optimizerX - 45}, ${optimizerY - 30})`);
optG.append("rect")
.attr("width", 90)
.attr("height", 60)
.attr("rx", 6)
.attr("fill", optimizerActive ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", optimizerActive ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", optimizerActive ? 2.5 : 1.5)
.style("filter", optimizerActive ? `drop-shadow(0 0 8px ${diagramTheme.highlightGlow})` : "none");
optG.append("text")
.attr("x", 45)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", optimizerActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("Optimizer");
optG.append("text")
.attr("x", 45)
.attr("y", 38)
.attr("text-anchor", "middle")
.attr("fill", optimizerActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("font-family", "monospace")
.attr("opacity", 0.8)
.text("step()");
optG.append("text")
.attr("x", 45)
.attr("y", 52)
.attr("text-anchor", "middle")
.attr("fill", optimizerActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "8px")
.attr("opacity", 0.7)
.text("1 update");
// Status info panel at bottom
const infoY = 250;
svg.append("rect")
.attr("x", 30)
.attr("y", infoY)
.attr("width", width - 60)
.attr("height", 70)
.attr("rx", 6)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1);
// Step name
svg.append("text")
.attr("x", 50)
.attr("y", infoY + 22)
.attr("fill", diagramTheme.highlight)
.attr("font-size", "13px")
.attr("font-weight", "700")
.text(`Step ${accumStep}: ${accumStepInfo.name}`);
// Description
svg.append("text")
.attr("x", 50)
.attr("y", infoY + 42)
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text(accumStepInfo.description);
// Effective batch size calculation
svg.append("text")
.attr("x", 50)
.attr("y", infoY + 58)
.attr("fill", diagramTheme.accent)
.attr("font-size", "10px")
.attr("font-family", "monospace")
.text(`Effective batch size: ${batchSize} x ${accumSteps} = ${effectiveBatch}`);
return svg.node();
}
```
```{python}
# Demonstrate gradient accumulation
model = nn.Linear(10, 1)
accumulation_steps = 4
# Simulate accumulated gradients
total_loss = 0
for i in range(accumulation_steps):
x = torch.randn(8, 10) # Mini-batch
y = model(x)
loss = y.mean() / accumulation_steps # Scale loss!
loss.backward() # Gradients accumulate
total_loss += loss.item()
print(f"Accumulated loss (4 mini-batches): {total_loss:.4f}")
print(f"Gradient norm before step: {model.weight.grad.norm().item():.4f}")
# Now do one optimizer step
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.step()
optimizer.zero_grad()
print("After optimizer.step() and zero_grad()")
```
## Gradient Clipping
Gradient clipping prevents exploding gradients by scaling down gradients when their norm exceeds a threshold.
```{python}
# Demonstrate gradient clipping
model = nn.Linear(10, 10)
# Create artificial large gradients
for p in model.parameters():
p.grad = torch.randn_like(p) * 100 # Very large!
# Compute gradient norm before clipping
total_norm_before = 0
for p in model.parameters():
total_norm_before += p.grad.norm().item() ** 2
total_norm_before = total_norm_before ** 0.5
print(f"Gradient norm before clipping: {total_norm_before:.2f}")
# Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Compute gradient norm after
total_norm_after = 0
for p in model.parameters():
total_norm_after += p.grad.norm().item() ** 2
total_norm_after = total_norm_after ** 0.5
print(f"Gradient norm after clipping: {total_norm_after:.2f}")
print(f"\nGradients scaled down by {total_norm_before / total_norm_after:.1f}x")
```
### Gradient Clipping from Scratch
Let's implement gradient clipping ourselves to understand the algorithm:
```{python}
def clip_grad_norm_scratch(params, max_norm: float) -> float:
"""
Clip gradients by global norm.
Algorithm:
1. Compute total norm: sqrt(sum of all grad^2)
2. If total_norm > max_norm, scale all grads by (max_norm / total_norm)
Returns the original norm (before clipping).
"""
params = list(params)
# Step 1: Compute total gradient norm
total_sq = 0.0
for p in params:
if p.grad is not None:
total_sq += (p.grad ** 2).sum().item()
total_norm = total_sq ** 0.5
# Step 2: Clip if needed
if total_norm > max_norm:
scale = max_norm / (total_norm + 1e-12) # Small epsilon for numerical stability
for p in params:
if p.grad is not None:
p.grad *= scale
return total_norm
# Test: compare with PyTorch
model_scratch = nn.Linear(10, 10)
model_pytorch = nn.Linear(10, 10)
# Set same large gradients
torch.manual_seed(42)
for p in model_scratch.parameters():
p.grad = torch.randn_like(p) * 100
for ps, pp in zip(model_scratch.parameters(), model_pytorch.parameters()):
pp.grad = ps.grad.clone()
# Clip with both
norm_scratch = clip_grad_norm_scratch(model_scratch.parameters(), max_norm=1.0)
norm_pytorch = torch.nn.utils.clip_grad_norm_(model_pytorch.parameters(), max_norm=1.0)
print(f"Original norm (scratch): {norm_scratch:.4f}")
print(f"Original norm (PyTorch): {norm_pytorch.item():.4f}")
# Check gradients match after clipping
grads_match = all(
torch.allclose(ps.grad, pp.grad)
for ps, pp in zip(model_scratch.parameters(), model_pytorch.parameters())
)
print(f"Gradients match after clipping: {grads_match}")
```
::: {.callout-note}
## Key Insight: Gradient Clipping
Gradient clipping scales ALL gradients by the same factor to preserve their relative magnitudes. This is different from clipping each gradient independently - we want to maintain the direction of the overall update while limiting its magnitude.
:::
**When to use gradient clipping:**
- Always for transformer training (standard practice)
- max_norm=1.0 is a good default
- Monitor gradient norms during training - consistently high norms suggest instability
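A convenient detail: `torch.nn.utils.clip_grad_norm_` returns the total norm *before* clipping, so you can monitor gradient norms essentially for free. A minimal sketch, reusing the imports from earlier in this module:

```{python}
# Log the pre-clip gradient norm that clip_grad_norm_ returns
demo_model = nn.Linear(10, 10)
demo_loss = demo_model(torch.randn(4, 10)).pow(2).mean()
demo_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(demo_model.parameters(), max_norm=1.0)
print(f"Gradient norm this step (before clipping): {grad_norm.item():.4f}")
# In a real loop, log this every N steps; a sustained upward trend signals instability
```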
## Batch Size Considerations
Batch size affects both training dynamics and memory usage:
**Tradeoffs:**
| Aspect | Small Batch | Large Batch |
|--------|-------------|-------------|
| Memory | Less | More |
| Gradient noise | More (regularization effect) | Less (stable gradients) |
| Convergence | Noisier, may generalize better | Faster, more stable |
| LR needed | Lower | Higher (linear scaling rule) |
**The Linear Scaling Rule:** When you double the batch size, you can usually double the learning rate and keep similar training dynamics. It is a heuristic rather than a law: at very large batch sizes it breaks down, and longer warmup helps.
**Effective batch size** = batch_size x gradient_accumulation_steps
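As a sketch of the arithmetic (the reference batch size and learning rate below are arbitrary placeholders, not recommendations):

```{python}
# Effective batch size and the linear scaling heuristic (illustrative numbers)
micro_batch_size = 16        # sequences per forward/backward pass
grad_accum_steps = 8         # mini-batches accumulated per optimizer step
effective_batch = micro_batch_size * grad_accum_steps

reference_batch, reference_lr = 32, 3e-4     # a configuration known to work (placeholder)
scaled_lr = reference_lr * effective_batch / reference_batch

print(f"Effective batch size: {micro_batch_size} x {grad_accum_steps} = {effective_batch}")
print(f"Linearly scaled LR:   {reference_lr:.1e} -> {scaled_lr:.1e}")
```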
```{python}
# Batch size vs memory example (conceptual)
print("Memory usage scales linearly with batch size:")
print()
for batch_size in [8, 16, 32, 64]:
# Simulated memory calculation
tokens_per_batch = batch_size * 512 # sequence length
memory_mb = batch_size * 50 # ~50MB per sample for a small model
print(f" Batch size {batch_size:2d}: ~{tokens_per_batch:,} tokens/batch, ~{memory_mb}MB")
```
## Mixed Precision Training
Modern GPUs/TPUs can perform faster computation with lower precision numbers (fp16/bf16) while maintaining training quality.
::: {.callout-note}
## Conceptual Example
The code below shows the API but doesn't execute training — mixed precision requires specific hardware (CUDA GPUs) to demonstrate speedups.
:::
**Precision types:**
| Type | Bits | Range | Use Case |
|------|------|-------|----------|
| fp32 | 32 | Large | Default, master weights |
| fp16 | 16 | Limited | Faster compute, risk of overflow |
| bf16 | 16 | Large (like fp32) | Best of both worlds |
**How mixed precision works:**
1. Keep master weights in fp32 (full precision)
2. Cast to fp16/bf16 for forward/backward pass (fast)
3. Compute gradients in fp16/bf16
4. Update master weights in fp32 (accurate)
```{python}
# Mixed precision example (conceptual - requires GPU)
print("Mixed Precision Training:")
print()
# Simulated speedup
print("Speedups on modern hardware:")
print(" - A100 GPU with bf16: ~2x faster than fp32")
print(" - H100 GPU with fp8: ~3x faster than fp32")
print()
# PyTorch autocast usage
print("PyTorch usage:")
print("""
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch in dataloader:
optimizer.zero_grad()
# Forward pass in mixed precision
with autocast():
logits = model(input_ids)
loss = F.cross_entropy(logits, targets)
# Backward pass with gradient scaling
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
""")
```
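You can still poke at the mechanics without a GPU: `torch.autocast` supports bf16 on CPU, so the sketch below runs anywhere and shows the weights staying in fp32 while activations computed inside the context drop to lower precision. This is only a minimal illustration, not a full training setup; on CUDA you would use `device_type="cuda"` and, for fp16, the `GradScaler` shown above.

```{python}
# Minimal autocast demo on CPU: fp32 master weights, bf16 compute inside the context
mp_model = nn.Linear(16, 4)
mp_input = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    mp_out = mp_model(mp_input)

print(f"Weight dtype (master):   {mp_model.weight.dtype}")   # stays fp32
print(f"Output dtype (autocast): {mp_out.dtype}")            # typically bfloat16
```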
**Practical advice:**
- Use bf16 if available (A100, H100) - it has the same dynamic range as fp32
- Use fp16 with gradient scaling on older GPUs (V100)
- Apple Silicon (MPS) does not yet fully support mixed precision
## Distributed Training Basics
Training large models requires multiple GPUs. Here's a brief overview:
**Data Parallel (DP/DDP):**
- Same model copied to all GPUs
- Each GPU processes different data
- Gradients are averaged across GPUs
- Memory per GPU = full model size
```{ojs}
//| echo: false
// Data Parallel step descriptions
dpSteps = [
{
id: 0,
name: "Input Data",
description: "Large training batch ready to be distributed across GPUs"
},
{
id: 1,
name: "Split Data",
description: "Batch is divided evenly among available GPUs"
},
{
id: 2,
name: "Forward Pass",
description: "Each GPU computes forward pass on its data shard with full model copy"
},
{
id: 3,
name: "Compute Gradients",
description: "Each GPU computes gradients via backpropagation"
},
{
id: 4,
name: "AllReduce",
description: "Gradients are averaged across all GPUs via collective communication"
}
]
```
```{ojs}
//| echo: false
// Step slider for Data Parallel diagram
viewof dpStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
// Current Data Parallel step info
currentDpStep = dpSteps[dpStep]
```
```{ojs}
//| echo: false
// Data Parallel interactive diagram
{
const width = 700;
const height = 380;
const numGpus = 3;
const batchPerGpu = 8;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Defs for arrows and gradients
const defs = svg.append("defs");
// Arrow markers
defs.append("marker")
.attr("id", "dp-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.edgeStroke);
defs.append("marker")
.attr("id", "dp-arrow-active")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 5)
.attr("markerHeight", 5)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Data flow gradient for animation effect
const flowGradient = defs.append("linearGradient")
.attr("id", "dp-flow-gradient")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "100%")
.attr("y2", "0%");
flowGradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", diagramTheme.highlight)
.attr("stop-opacity", 0.2);
flowGradient.append("stop")
.attr("offset", "50%")
.attr("stop-color", diagramTheme.highlight)
.attr("stop-opacity", 1);
flowGradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", diagramTheme.highlight)
.attr("stop-opacity", 0.2);
// Layout constants
const batchX = 70;
const splitX = 200;
const gpuX = 400;
const reduceX = 580;
const centerY = height / 2;
const gpuSpacing = 90;
// GPU Y positions
const gpuYs = [centerY - gpuSpacing, centerY, centerY + gpuSpacing];
// Helper: draw data block
const drawDataBlock = (g, x, y, w, h, label, isActive, isSmall = false) => {
const block = g.append("g").attr("transform", `translate(${x}, ${y})`);
block.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", 4)
.attr("fill", isActive ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.style("filter", isActive ? `drop-shadow(0 0 6px ${diagramTheme.highlightGlow})` : "none");
if (label) {
block.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", isSmall ? "10px" : "11px")
.attr("font-weight", "500")
.text(label);
}
return block;
};
// Helper: draw GPU box
const drawGpu = (g, x, y, gpuNum, isActive, showGradients = false) => {
const gpu = g.append("g").attr("transform", `translate(${x}, ${y})`);
const boxW = 100;
const boxH = 60;
// GPU container
gpu.append("rect")
.attr("x", -boxW/2)
.attr("y", -boxH/2)
.attr("width", boxW)
.attr("height", boxH)
.attr("rx", 6)
.attr("fill", isActive ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.style("filter", isActive ? `drop-shadow(0 0 8px ${diagramTheme.highlightGlow})` : "none");
// GPU label
gpu.append("text")
.attr("y", -12)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`GPU ${gpuNum}`);
// Full model indicator
gpu.append("text")
.attr("y", 6)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", isActive ? 0.9 : 0.7)
.text("Full Model");
// Gradient indicator (when computing gradients)
if (showGradients) {
gpu.append("text")
.attr("y", 20)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.accent)
.attr("font-size", "9px")
.attr("font-weight", "500")
.text("∇ gradients");
}
return gpu;
};
// Draw based on current step
const mainGroup = svg.append("g");
// Step 0: Show full batch
if (dpStep >= 0) {
const isActive = dpStep === 0;
drawDataBlock(mainGroup, batchX, centerY, 60, 100, null, isActive);
// Data visualization inside batch
const batchGroup = mainGroup.append("g").attr("transform", `translate(${batchX}, ${centerY})`);
for (let i = 0; i < 6; i++) {
const row = Math.floor(i / 2);
const col = i % 2;
batchGroup.append("rect")
.attr("x", -20 + col * 22)
.attr("y", -35 + row * 25)
.attr("width", 18)
.attr("height", 20)
.attr("rx", 2)
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.accent)
.attr("opacity", isActive ? 0.9 : 0.6);
}
// Batch label
mainGroup.append("text")
.attr("x", batchX)
.attr("y", centerY + 65)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text("Data Batch");
mainGroup.append("text")
.attr("x", batchX)
.attr("y", centerY + 80)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(`(${batchPerGpu * numGpus} samples)`);
}
// Step 1+: Show split batches
if (dpStep >= 1) {
const isActive = dpStep === 1;
// Draw split indicator arrows
for (let i = 0; i < numGpus; i++) {
const startX = batchX + 35;
const startY = centerY;
const endX = splitX - 25;
const endY = gpuYs[i];
mainGroup.append("path")
.attr("d", `M${startX},${startY} C${startX + 40},${startY} ${endX - 40},${endY} ${endX},${endY}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", isActive ? "url(#dp-arrow-active)" : "url(#dp-arrow)")
.attr("opacity", isActive ? 1 : 0.6)
.style("filter", isActive ? `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})` : "none");
}
// Draw split batches
for (let i = 0; i < numGpus; i++) {
drawDataBlock(mainGroup, splitX, gpuYs[i], 45, 40, null, isActive, true);
// Mini data visualization
const splitGroup = mainGroup.append("g").attr("transform", `translate(${splitX}, ${gpuYs[i]})`);
for (let j = 0; j < 2; j++) {
splitGroup.append("rect")
.attr("x", -15 + j * 16)
.attr("y", -8)
.attr("width", 12)
.attr("height", 16)
.attr("rx", 2)
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.accent)
.attr("opacity", isActive ? 0.9 : 0.6);
}
// Batch shard label
mainGroup.append("text")
.attr("x", splitX)
.attr("y", gpuYs[i] + 30)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.7)
.text(`Batch ${i}`);
}
}
// Step 2+: Show GPUs with forward pass
if (dpStep >= 2) {
const isForward = dpStep === 2;
const isGradient = dpStep === 3;
const isActive = isForward || isGradient;
// Arrows from split to GPU
for (let i = 0; i < numGpus; i++) {
const startX = splitX + 28;
const endX = gpuX - 55;
const y = gpuYs[i];
mainGroup.append("path")
.attr("d", `M${startX},${y} L${endX},${y}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", isActive ? "url(#dp-arrow-active)" : "url(#dp-arrow)")
.attr("opacity", isActive ? 1 : 0.6)
.style("filter", isActive ? `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})` : "none");
}
// Draw GPUs
for (let i = 0; i < numGpus; i++) {
drawGpu(mainGroup, gpuX, gpuYs[i], i, isActive, isGradient);
}
}
// Step 4: AllReduce
if (dpStep >= 4) {
const isActive = dpStep === 4;
// Arrows from GPU to AllReduce
for (let i = 0; i < numGpus; i++) {
const startX = gpuX + 55;
const startY = gpuYs[i];
const endX = reduceX - 45;
const endY = centerY;
mainGroup.append("path")
.attr("d", `M${startX},${startY} C${startX + 30},${startY} ${endX - 30},${endY} ${endX},${endY}`)
.attr("fill", "none")
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.edgeStroke)
.attr("stroke-width", isActive ? 2 : 1.5)
.attr("marker-end", isActive ? "url(#dp-arrow-active)" : "url(#dp-arrow)")
.attr("opacity", isActive ? 1 : 0.6)
.style("filter", isActive ? `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})` : "none");
}
// AllReduce node
const reduceGroup = mainGroup.append("g").attr("transform", `translate(${reduceX}, ${centerY})`);
reduceGroup.append("rect")
.attr("x", -42)
.attr("y", -35)
.attr("width", 84)
.attr("height", 70)
.attr("rx", 6)
.attr("fill", isActive ? diagramTheme.highlight : diagramTheme.nodeFill)
.attr("stroke", isActive ? diagramTheme.highlight : diagramTheme.nodeStroke)
.attr("stroke-width", isActive ? 2.5 : 1.5)
.style("filter", isActive ? `drop-shadow(0 0 8px ${diagramTheme.highlightGlow})` : "none");
reduceGroup.append("text")
.attr("y", -10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("AllReduce");
reduceGroup.append("text")
.attr("y", 8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.8)
.text("Average");
reduceGroup.append("text")
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isActive ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("opacity", 0.8)
.text("Gradients");
}
// Step indicator bar at top
const stepBarY = 25;
const stepBarWidth = 500;
const stepBarX = (width - stepBarWidth) / 2;
const stepWidth = stepBarWidth / 5;
const stepLabels = ["Input", "Split", "Forward", "Gradients", "AllReduce"];
for (let i = 0; i < 5; i++) {
const isCurrentStep = dpStep === i;
const isPastStep = dpStep > i;
const stepX = stepBarX + i * stepWidth + stepWidth / 2;
// Step circle
mainGroup.append("circle")
.attr("cx", stepX)
.attr("cy", stepBarY)
.attr("r", 12)
.attr("fill", isCurrentStep ? diagramTheme.highlight : (isPastStep ? diagramTheme.accent : diagramTheme.nodeFill))
.attr("stroke", isCurrentStep ? diagramTheme.highlight : (isPastStep ? diagramTheme.accent : diagramTheme.nodeStroke))
.attr("stroke-width", isCurrentStep ? 2 : 1.5)
.style("filter", isCurrentStep ? `drop-shadow(0 0 6px ${diagramTheme.highlightGlow})` : "none");
// Step number
mainGroup.append("text")
.attr("x", stepX)
.attr("y", stepBarY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isCurrentStep || isPastStep ? diagramTheme.textOnHighlight : diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text(i);
// Step label
mainGroup.append("text")
.attr("x", stepX)
.attr("y", stepBarY + 22)
.attr("text-anchor", "middle")
.attr("fill", isCurrentStep ? diagramTheme.highlight : diagramTheme.nodeText)
.attr("font-size", "9px")
.attr("font-weight", isCurrentStep ? "600" : "400")
.attr("opacity", isCurrentStep ? 1 : 0.7)
.text(stepLabels[i]);
// Connecting line (except last)
if (i < 4) {
const lineStartX = stepX + 15;
const lineEndX = stepBarX + (i + 1) * stepWidth + stepWidth / 2 - 15;
mainGroup.append("line")
.attr("x1", lineStartX)
.attr("y1", stepBarY)
.attr("x2", lineEndX)
.attr("y2", stepBarY)
.attr("stroke", isPastStep ? diagramTheme.accent : diagramTheme.edgeStroke)
.attr("stroke-width", 1.5)
.attr("opacity", 0.5);
}
}
// Effective batch size display
const statsY = height - 30;
mainGroup.append("text")
.attr("x", width / 2)
.attr("y", statsY)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.8)
.text(`Effective batch size: ${batchPerGpu} samples/GPU × ${numGpus} GPUs = ${batchPerGpu * numGpus} samples`);
return svg.node();
}
```
```{ojs}
//| echo: false
// Step description panel for Data Parallel
html`<div style="
background: ${diagramTheme.bgSecondary};
border-radius: 6px;
padding: 12px 16px;
margin-top: 8px;
border-left: 3px solid ${diagramTheme.highlight};
">
<div style="font-weight: 600; color: ${diagramTheme.nodeText}; margin-bottom: 4px;">
Step ${currentDpStep.id}: ${currentDpStep.name}
</div>
<div style="color: ${diagramTheme.nodeText}; opacity: 0.8; font-size: 13px;">
${currentDpStep.description}
</div>
</div>`
```
**Fully Sharded Data Parallel (FSDP):**
- Model is sharded across GPUs
- Each GPU holds a fraction of parameters
- Memory per GPU = model_size / num_gpus
- Enables training models larger than single GPU memory
```{python}
# Distributed training concepts
print("Distributed Training Strategies:")
print()
print("1. Data Parallel (DDP):")
print(" - Best for: Models that fit in one GPU")
print(" - Scales: Batch size (effective_batch = batch * num_gpus)")
print()
print("2. Fully Sharded Data Parallel (FSDP):")
print(" - Best for: Large models (>10B parameters)")
print(" - Scales: Model size and batch size")
print()
print("3. Pipeline Parallel:")
print(" - Best for: Very deep models")
print(" - Splits model layers across GPUs")
print()
print("4. Tensor Parallel:")
print(" - Best for: Models with large layers")
print(" - Splits individual layers across GPUs")
```
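Real DDP needs multiple processes (typically launched with `torchrun`), but the core mechanism is easy to simulate in a single process: each replica computes gradients on its own data shard, then the gradients are averaged so every replica applies the same update. A hedged sketch of that averaging step, not actual `DistributedDataParallel` usage:

```{python}
# Single-process simulation of DDP gradient averaging (conceptual, not real distributed code)
import copy

torch.manual_seed(0)
base = nn.Linear(10, 1)
replicas = [copy.deepcopy(base) for _ in range(2)]   # "two GPUs" with identical weights
shards = torch.randn(2, 8, 10)                        # each replica sees a different shard

# Each replica: forward + backward on its own shard
for replica, shard in zip(replicas, shards):
    replica(shard).mean().backward()

# "AllReduce": average corresponding gradients across replicas
for params in zip(*(r.parameters() for r in replicas)):
    avg_grad = torch.stack([p.grad for p in params]).mean(dim=0)
    for p in params:
        p.grad = avg_grad.clone()

same = torch.allclose(replicas[0].weight.grad, replicas[1].weight.grad)
print(f"Replicas hold identical averaged gradients: {same}")
```

Because every replica now takes the same optimizer step, the weights stay synchronized without ever copying parameters between devices.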
## Training Stability and Failure Modes
Understanding common failure modes helps you debug training issues:
**Loss = NaN or Inf**
Causes:
- Learning rate too high
- Gradient explosion
- Numerical overflow in fp16
Solutions:
- Reduce learning rate (try 10x smaller)
- Add gradient clipping
- Use bf16 instead of fp16 or add gradient scaling
**Loss stuck at high value**
Causes:
- Learning rate too low
- Poor weight initialization
- Data loading bug (same batch every time)
Solutions:
- Increase learning rate
- Check data loader with small sample
- Verify model architecture
**Loss oscillates or increases**
Causes:
- Learning rate too high
- Batch size too small
- Bug in loss computation
Solutions:
- Add warmup period
- Reduce learning rate
- Use gradient accumulation
```{python}
# Visualize training pathologies
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
steps = np.arange(100)
# Good training
ax = axes[0]
good_loss = 5.0 * np.exp(-0.03 * steps) + 0.5 + 0.1 * np.random.randn(100)
ax.plot(steps, good_loss)
ax.set_title('Good Training')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_ylim(0, 6)
ax.grid(True, alpha=0.3)
# LR too high - diverges
ax = axes[1]
unstable_loss = 4.0 + 0.5 * np.sin(steps * 0.3) + 0.02 * steps
ax.plot(steps, unstable_loss, 'r')
ax.set_title('LR Too High (Unstable)')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_ylim(0, 8)
ax.grid(True, alpha=0.3)
# LR too low - slow convergence
ax = axes[2]
slow_loss = 5.0 * np.exp(-0.005 * steps) + 0.5
ax.plot(steps, slow_loss, 'orange')
ax.set_title('LR Too Low (Slow)')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_ylim(0, 6)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
**Debugging checklist:**
1. Check initial loss - should be ~log(vocab_size) for untrained model
2. Verify data is being loaded correctly (print a few samples)
3. Monitor gradient norms - should be stable, not growing
4. Check learning rate schedule is working (print LR each step)
5. Test with a tiny dataset first to verify overfitting capability
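A couple of these checks fit into a tiny guard you can drop into any training step. The names and values below are illustrative only:

```{python}
# Illustrative sanity checks for one training step
step_loss = torch.tensor(4.6)      # stand-in for the loss of the current step
step_vocab = 100

# Loss must be finite; NaN/Inf usually means LR too high, exploding gradients, or fp16 overflow
assert torch.isfinite(step_loss), "Loss is NaN/Inf: lower LR, clip gradients, or switch to bf16"

# At step 0, loss should sit near log(vocab_size), i.e. uniform guessing
print(f"Loss {step_loss.item():.2f} vs log(vocab_size) = {math.log(step_vocab):.2f}")
```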
## Text Dataset
Let's create a simple dataset for language modeling:
```{python}
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
"""Simple text dataset for language modeling."""
def __init__(self, tokens, seq_len):
self.tokens = tokens
self.seq_len = seq_len
def __len__(self):
return max(0, len(self.tokens) - self.seq_len)
def __getitem__(self, idx):
input_ids = self.tokens[idx:idx + self.seq_len]
targets = self.tokens[idx + 1:idx + self.seq_len + 1]
return input_ids, targets
# Create a simple dataset
tokens = torch.arange(100) # Token IDs 0-99
seq_len = 8
dataset = TextDataset(tokens, seq_len=seq_len)
print(f"Token IDs: {tokens[:20].tolist()}...")
print(f"Sequence length: {seq_len}")
print(f"Number of samples: {len(dataset)}")
```
```{python}
# Look at a sample
input_ids, targets = dataset[0]
print("Sample 0:")
print(f" Input: {input_ids.tolist()}")
print(f" Target: {targets.tolist()}")
print(f"\n Target is input shifted by 1 position!")
# Another sample
input_ids, targets = dataset[50]
print(f"\nSample 50:")
print(f" Input: {input_ids.tolist()}")
print(f" Target: {targets.tolist()}")
```
## Training a Model
Now let's put it all together and train a tiny model:
```{python}
import sys
sys.path.insert(0, '..')
from m06_transformer.transformer import create_gpt_tiny
# Create model and data
torch.manual_seed(42)
vocab_size = 100
model = create_gpt_tiny(vocab_size=vocab_size)
# Random "training data"
tokens = torch.randint(0, vocab_size, (5000,))
print(f"Model: {model.num_params:,} parameters")
print(f"Training data: {len(tokens):,} tokens")
```
```{python}
# Check initial loss (should be ~log(vocab_size) for random predictions)
dataset = TextDataset(tokens, seq_len=32)
input_ids, targets = dataset[0]
input_ids = input_ids.unsqueeze(0) # Add batch dimension
targets = targets.unsqueeze(0)
model.eval()
with torch.no_grad():
logits = model(input_ids)
# Reshape for loss computation
B, T, V = logits.shape
initial_loss = F.cross_entropy(logits.view(B*T, V), targets.view(B*T))
print(f"Initial loss: {initial_loss.item():.4f}")
print(f"Initial perplexity: {math.exp(initial_loss.item()):.2f}")
print(f"\nExpected for random guessing: loss ~ {np.log(vocab_size):.2f}, ppl ~ {vocab_size}")
```
```{python}
def train_model(model, tokens, num_steps=100, batch_size=16, seq_len=32, learning_rate=3e-4):
"""Simple training loop."""
dataset = TextDataset(tokens, seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = CosineScheduler(optimizer, warmup_steps=10, total_steps=num_steps, min_lr=1e-5)
model.train()
losses = []
step = 0
while step < num_steps:
for input_ids, targets in dataloader:
if step >= num_steps:
break
# Forward pass
logits = model(input_ids)
B, T, V = logits.shape
loss = F.cross_entropy(logits.view(B*T, V), targets.view(B*T))
# Backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Update
optimizer.step()
scheduler.step()
optimizer.zero_grad()
losses.append(loss.item())
if step % 10 == 0:
lr = optimizer.param_groups[0]['lr']
ppl = math.exp(loss.item())
print(f"Step {step:3d} | Loss: {loss.item():.4f} | PPL: {ppl:.2f} | LR: {lr:.2e}")
step += 1
return losses
# Train!
print("Starting training...\n")
losses = train_model(model, tokens, num_steps=100)
print(f"\nFinal loss: {losses[-1]:.4f}")
print(f"Final perplexity: {math.exp(losses[-1]):.2f}")
```
```{python}
# Plot training curve
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
# Loss
ax = axes[0]
ax.plot(losses)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.grid(True, alpha=0.3)
ax.axhline(y=np.log(vocab_size), color='r', linestyle='--', label='Random baseline')
ax.legend()
# Perplexity
ax = axes[1]
ppls = [math.exp(l) for l in losses]
ax.plot(ppls)
ax.set_xlabel('Step')
ax.set_ylabel('Perplexity')
ax.set_title('Training Perplexity')
ax.grid(True, alpha=0.3)
ax.axhline(y=vocab_size, color='r', linestyle='--', label='Random baseline')
ax.legend()
plt.tight_layout()
plt.show()
```
## Effect of Learning Rate
Learning rate is crucial - too high causes instability, too low is slow:
```{python}
# Train with different learning rates
learning_rates = [1e-5, 1e-4, 3e-4, 1e-3, 3e-3]
all_losses = {}
for lr in learning_rates:
torch.manual_seed(42)
model = create_gpt_tiny(vocab_size=100)
tokens = torch.randint(0, 100, (3000,))
# Train silently
dataset = TextDataset(tokens, seq_len=32)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
model.train()
losses = []
step = 0
while step < 50:
for input_ids, targets in dataloader:
if step >= 50:
break
logits = model(input_ids)
B, T, V = logits.shape
loss = F.cross_entropy(logits.view(B*T, V), targets.view(B*T))
loss.backward()
optimizer.step()
optimizer.zero_grad()
losses.append(loss.item())
step += 1
all_losses[lr] = losses
print(f"LR={lr:.0e}: final_loss={losses[-1]:.3f}, final_ppl={math.exp(losses[-1]):.1f}")
```
```{python}
# Plot comparison
plt.figure(figsize=(12, 5))
for lr, losses in all_losses.items():
plt.plot(losses, label=f'LR={lr:.0e}')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss for Different Learning Rates')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print("\nObservations:")
print("- Too low (1e-5): Training is very slow")
print("- Just right (3e-4): Smooth, fast convergence")
print("- Too high (3e-3): Unstable, loss may spike or diverge")
```
## Checkpointing
Save regularly! Training can crash. Here's what to save:
```{python}
# Demonstrate checkpointing
import json
from pathlib import Path
def save_checkpoint(model, optimizer, step, loss, path):
"""Save a training checkpoint."""
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'step': step,
'loss': loss,
}
torch.save(checkpoint, path)
print(f"Checkpoint saved to {path}")
def load_checkpoint(model, optimizer, path):
"""Load a training checkpoint."""
checkpoint = torch.load(path, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"Checkpoint loaded from {path}")
print(f" Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
return checkpoint['step'], checkpoint['loss']
# Save example
model = create_gpt_tiny(vocab_size=100)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
save_checkpoint(model, optimizer, step=50, loss=2.5, path="demo_checkpoint.pt")
# Load example
model2 = create_gpt_tiny(vocab_size=100)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=3e-4)
step, loss = load_checkpoint(model2, optimizer2, "demo_checkpoint.pt")
# Clean up
Path("demo_checkpoint.pt").unlink()
```
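For exact resumption it helps to also store the scheduler and RNG state, so data order and dropout masks continue deterministically. A sketch of a richer payload, assuming your scheduler exposes `state_dict()` (PyTorch's built-in schedulers do; the simple schedulers defined in this module may not):

```{python}
# Richer checkpoint payload (sketch): scheduler and RNG state for exact resumption
full_checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # 'scheduler_state_dict': scheduler.state_dict(),  # if your scheduler supports it
    'rng_state': torch.get_rng_state(),
    'step': 50,
    'loss': 2.5,
}
torch.save(full_checkpoint, "demo_checkpoint_full.pt")

restored = torch.load("demo_checkpoint_full.pt", weights_only=False)
torch.set_rng_state(restored['rng_state'])
print("Saved keys:", sorted(restored.keys()))
Path("demo_checkpoint_full.pt").unlink()
```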
## Validation and Early Stopping
Monitor validation loss to detect overfitting:
```{ojs}
//| echo: false
// Early Stopping Controls
viewof esCurrentEpoch = Inputs.range([1, 50], {
value: 1,
step: 1,
label: "Current Epoch"
})
```
```{ojs}
//| echo: false
// Generate training and validation loss curves
earlyStoppingData = {
const epochs = 50;
const data = [];
// Training loss: exponential decay with noise
// Validation loss: decreases then increases (U-shape)
const bestEpoch = 25; // Where validation loss is lowest
for (let epoch = 1; epoch <= epochs; epoch++) {
// Training loss: smooth exponential decay
const trainLoss = 2.5 * Math.exp(-0.08 * epoch) + 0.3 + 0.05 * Math.sin(epoch * 0.5);
// Validation loss: U-shaped curve
// Decreases initially, then increases (overfitting)
const valBase = 2.5 * Math.exp(-0.06 * epoch) + 0.4;
const overfitComponent = epoch > bestEpoch ? 0.02 * Math.pow(epoch - bestEpoch, 1.3) : 0;
const valLoss = valBase + overfitComponent + 0.03 * Math.sin(epoch * 0.7 + 1);
data.push({
epoch,
trainLoss,
valLoss,
gap: valLoss - trainLoss
});
}
return data;
}
// Find the best epoch (minimum validation loss)
bestModelEpoch = {
let minVal = Infinity;
let bestEpoch = 1;
for (const d of earlyStoppingData) {
if (d.valLoss < minVal) {
minVal = d.valLoss;
bestEpoch = d.epoch;
}
}
return bestEpoch;
}
// Current epoch data
currentEpochData = {
const current = earlyStoppingData.find(d => d.epoch === esCurrentEpoch);
return current || earlyStoppingData[0];
}
// Training phase detection
trainingPhase = {
if (esCurrentEpoch < bestModelEpoch - 5) return "learning";
if (esCurrentEpoch <= bestModelEpoch + 2) return "optimal";
return "overfitting";
}
```
```{ojs}
//| echo: false
// Early Stopping Visualization
{
const theme = diagramTheme;
const width = 700;
const height = 400;
const margin = { top: 40, right: 150, bottom: 60, left: 70 };
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const svg = d3.create("svg")
.attr("viewBox", `0 0 ${width} ${height}`)
.attr("width", "100%")
.attr("height", height)
.style("max-width", `${width}px`)
.style("font-family", "'JetBrains Mono', 'Fira Code', monospace");
const defs = svg.append("defs");
// Background gradient
const bgGradient = defs.append("linearGradient")
.attr("id", "es-bg-gradient")
.attr("x1", "0%")
.attr("y1", "0%")
.attr("x2", "0%")
.attr("y2", "100%");
bgGradient.append("stop")
.attr("offset", "0%")
.attr("stop-color", theme.isDark ? "#1a1a2e" : "#f8fafc");
bgGradient.append("stop")
.attr("offset", "100%")
.attr("stop-color", theme.isDark ? "#16162a" : "#f1f5f9");
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", "url(#es-bg-gradient)")
.attr("rx", 12);
const chart = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const xScale = d3.scaleLinear()
.domain([1, 50])
.range([0, innerWidth]);
const yScale = d3.scaleLinear()
.domain([0, 3])
.range([innerHeight, 0]);
// Overfitting region highlight
chart.append("rect")
.attr("x", xScale(bestModelEpoch))
.attr("y", 0)
.attr("width", innerWidth - xScale(bestModelEpoch))
.attr("height", innerHeight)
.attr("fill", theme.isDark ? "#ff6b6b" : "#ef4444")
.attr("opacity", trainingPhase === "overfitting" ? 0.15 : 0.05);
// Optimal zone highlight
chart.append("rect")
.attr("x", xScale(Math.max(1, bestModelEpoch - 5)))
.attr("y", 0)
.attr("width", xScale(bestModelEpoch + 2) - xScale(Math.max(1, bestModelEpoch - 5)))
.attr("height", innerHeight)
.attr("fill", theme.isDark ? "#4ade80" : "#22c55e")
.attr("opacity", trainingPhase === "optimal" ? 0.15 : 0.05);
// Region labels at top
chart.append("text")
.attr("x", xScale(10))
.attr("y", 15)
.attr("text-anchor", "middle")
.attr("font-size", "10px")
.attr("font-weight", trainingPhase === "learning" ? "600" : "400")
.attr("fill", theme.accent)
.attr("opacity", trainingPhase === "learning" ? 1 : 0.5)
.text("LEARNING");
chart.append("text")
.attr("x", xScale(bestModelEpoch))
.attr("y", 15)
.attr("text-anchor", "middle")
.attr("font-size", "10px")
.attr("font-weight", trainingPhase === "optimal" ? "600" : "400")
.attr("fill", theme.isDark ? "#4ade80" : "#22c55e")
.attr("opacity", trainingPhase === "optimal" ? 1 : 0.5)
.text("OPTIMAL");
chart.append("text")
.attr("x", xScale(40))
.attr("y", 15)
.attr("text-anchor", "middle")
.attr("font-size", "10px")
.attr("font-weight", trainingPhase === "overfitting" ? "600" : "400")
.attr("fill", theme.isDark ? "#ff6b6b" : "#ef4444")
.attr("opacity", trainingPhase === "overfitting" ? 1 : 0.5)
.text("OVERFITTING");
// Grid lines
const yTicks = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
yTicks.forEach(tick => {
chart.append("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", yScale(tick))
.attr("y2", yScale(tick))
.attr("stroke", theme.nodeStroke)
.attr("stroke-opacity", 0.2)
.attr("stroke-dasharray", "3,3");
});
// X-axis
chart.append("g")
.attr("transform", `translate(0, ${innerHeight})`)
.call(d3.axisBottom(xScale).ticks(10))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text").attr("fill", theme.nodeText).attr("font-size", "11px"));
// Y-axis
chart.append("g")
.call(d3.axisLeft(yScale).ticks(6))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text").attr("fill", theme.nodeText).attr("font-size", "11px"));
// Axis labels
chart.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", "500")
.text("Epoch");
chart.append("text")
.attr("x", -innerHeight / 2)
.attr("y", -50)
.attr("transform", "rotate(-90)")
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", "500")
.text("Loss");
// Filter data up to current epoch
const visibleData = earlyStoppingData.filter(d => d.epoch <= esCurrentEpoch);
// Gap fill between curves (overfitting visualization)
const gapArea = d3.area()
.x(d => xScale(d.epoch))
.y0(d => yScale(d.trainLoss))
.y1(d => yScale(d.valLoss))
.curve(d3.curveMonotoneX);
chart.append("path")
.datum(visibleData)
.attr("d", gapArea)
.attr("fill", theme.isDark ? "#ff6b6b" : "#ef4444")
.attr("opacity", 0.1);
// Training loss line
const trainLine = d3.line()
.x(d => xScale(d.epoch))
.y(d => yScale(d.trainLoss))
.curve(d3.curveMonotoneX);
chart.append("path")
.datum(visibleData)
.attr("d", trainLine)
.attr("fill", "none")
.attr("stroke", theme.accent)
.attr("stroke-width", 3)
.attr("stroke-linecap", "round");
// Validation loss line
const valLine = d3.line()
.x(d => xScale(d.epoch))
.y(d => yScale(d.valLoss))
.curve(d3.curveMonotoneX);
chart.append("path")
.datum(visibleData)
.attr("d", valLine)
.attr("fill", "none")
.attr("stroke", theme.highlight)
.attr("stroke-width", 3)
.attr("stroke-linecap", "round");
// Best model marker (vertical line at best epoch)
if (esCurrentEpoch >= bestModelEpoch) {
const bestData = earlyStoppingData.find(d => d.epoch === bestModelEpoch);
chart.append("line")
.attr("x1", xScale(bestModelEpoch))
.attr("x2", xScale(bestModelEpoch))
.attr("y1", 0)
.attr("y2", innerHeight)
.attr("stroke", theme.isDark ? "#4ade80" : "#22c55e")
.attr("stroke-width", 2)
.attr("stroke-dasharray", "6,4");
// Best model point marker
chart.append("circle")
.attr("cx", xScale(bestModelEpoch))
.attr("cy", yScale(bestData.valLoss))
.attr("r", 8)
.attr("fill", theme.isDark ? "#4ade80" : "#22c55e")
.attr("stroke", "#fff")
.attr("stroke-width", 2);
// Star/checkpoint icon
chart.append("text")
.attr("x", xScale(bestModelEpoch))
.attr("y", yScale(bestData.valLoss) + 1)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", "#fff")
.attr("font-size", "10px")
.attr("font-weight", "bold")
.text("★");
// Label for best model
chart.append("text")
.attr("x", xScale(bestModelEpoch))
.attr("y", yScale(bestData.valLoss) - 18)
.attr("text-anchor", "middle")
.attr("fill", theme.isDark ? "#4ade80" : "#22c55e")
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("SAVE CHECKPOINT");
}
// Current epoch marker
const currentX = xScale(esCurrentEpoch);
const currentTrainY = yScale(currentEpochData.trainLoss);
const currentValY = yScale(currentEpochData.valLoss);
// Vertical line at current epoch
chart.append("line")
.attr("x1", currentX)
.attr("x2", currentX)
.attr("y1", 0)
.attr("y2", innerHeight)
.attr("stroke", theme.nodeText)
.attr("stroke-width", 1)
.attr("stroke-opacity", 0.4)
.attr("stroke-dasharray", "4,4");
// Gap indicator arrow
if (currentEpochData.gap > 0.1) {
const midY = (currentTrainY + currentValY) / 2;
// Gap line
chart.append("line")
.attr("x1", currentX + 8)
.attr("x2", currentX + 8)
.attr("y1", currentTrainY)
.attr("y2", currentValY)
.attr("stroke", theme.isDark ? "#ff6b6b" : "#ef4444")
.attr("stroke-width", 2);
// Gap label
chart.append("text")
.attr("x", currentX + 18)
.attr("y", midY)
.attr("dominant-baseline", "central")
.attr("fill", theme.isDark ? "#ff6b6b" : "#ef4444")
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(`Gap: ${currentEpochData.gap.toFixed(2)}`);
}
// Current points
chart.append("circle")
.attr("cx", currentX)
.attr("cy", currentTrainY)
.attr("r", 6)
.attr("fill", theme.accent)
.attr("stroke", "#fff")
.attr("stroke-width", 2);
chart.append("circle")
.attr("cx", currentX)
.attr("cy", currentValY)
.attr("r", 6)
.attr("fill", theme.highlight)
.attr("stroke", "#fff")
.attr("stroke-width", 2);
// Legend
const legendX = innerWidth + 20;
const legendY = 40;
// Training loss legend
chart.append("line")
.attr("x1", legendX)
.attr("x2", legendX + 25)
.attr("y1", legendY)
.attr("y2", legendY)
.attr("stroke", theme.accent)
.attr("stroke-width", 3);
chart.append("text")
.attr("x", legendX + 32)
.attr("y", legendY)
.attr("dominant-baseline", "central")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Train Loss");
// Validation loss legend
chart.append("line")
.attr("x1", legendX)
.attr("x2", legendX + 25)
.attr("y1", legendY + 25)
.attr("y2", legendY + 25)
.attr("stroke", theme.highlight)
.attr("stroke-width", 3);
chart.append("text")
.attr("x", legendX + 32)
.attr("y", legendY + 25)
.attr("dominant-baseline", "central")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Val Loss");
// Best model legend
chart.append("circle")
.attr("cx", legendX + 12)
.attr("cy", legendY + 55)
.attr("r", 6)
.attr("fill", theme.isDark ? "#4ade80" : "#22c55e");
chart.append("text")
.attr("x", legendX + 32)
.attr("y", legendY + 55)
.attr("dominant-baseline", "central")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Best Model");
// Status panel
const statusY = legendY + 90;
chart.append("rect")
.attr("x", legendX - 5)
.attr("y", statusY - 5)
.attr("width", 115)
.attr("height", 75)
.attr("rx", 6)
.attr("fill", theme.bgSecondary)
.attr("stroke", theme.nodeStroke)
.attr("stroke-width", 1);
chart.append("text")
.attr("x", legendX + 5)
.attr("y", statusY + 12)
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(`Epoch: ${esCurrentEpoch}`);
chart.append("text")
.attr("x", legendX + 5)
.attr("y", statusY + 28)
.attr("fill", theme.accent)
.attr("font-size", "10px")
.text(`Train: ${currentEpochData.trainLoss.toFixed(3)}`);
chart.append("text")
.attr("x", legendX + 5)
.attr("y", statusY + 44)
.attr("fill", theme.highlight)
.attr("font-size", "10px")
.text(`Val: ${currentEpochData.valLoss.toFixed(3)}`);
const phaseColor = trainingPhase === "learning" ? theme.accent :
trainingPhase === "optimal" ? (theme.isDark ? "#4ade80" : "#22c55e") :
(theme.isDark ? "#ff6b6b" : "#ef4444");
chart.append("text")
.attr("x", legendX + 5)
.attr("y", statusY + 60)
.attr("fill", phaseColor)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text(trainingPhase.toUpperCase());
return svg.node();
}
```
Tips:
- Monitor validation loss, not just training loss
- Save the model with the best validation loss
- Consider early stopping if validation loss increases consistently
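Computing validation loss reuses the same forward pass with gradients disabled. A minimal sketch (the `evaluate` name and the random held-out split are illustrative, not part of the module's `training.py` API):

```{python}
# Sketch: average cross-entropy on held-out tokens, no gradient tracking
@torch.no_grad()
def evaluate(model, val_tokens, seq_len=32, batch_size=16):
    loader = DataLoader(TextDataset(val_tokens, seq_len),
                        batch_size=batch_size, drop_last=True)
    model.eval()
    losses = []
    for input_ids, targets in loader:
        logits = model(input_ids)
        B, T, V = logits.shape
        losses.append(F.cross_entropy(logits.view(B*T, V), targets.view(B*T)).item())
    model.train()
    return sum(losses) / max(1, len(losses))

val_tokens = torch.randint(0, 100, (1000,))   # illustrative held-out split
print(f"Validation loss: {evaluate(model, val_tokens):.4f}")
```

In a full loop you would call this every N steps, keep the checkpoint with the lowest validation loss, and stop once it has not improved for a while; Exercise 3 at the end of this module asks you to wire this in.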
## Training Tips
### Quick Reference Table
| Symptom | Likely Cause | Solution |
|---------|--------------|----------|
| Loss = NaN | LR too high | Reduce LR by 10x |
| Loss stuck | LR too low | Increase LR by 2-5x |
| Loss oscillates | Batch too small | Use gradient accumulation |
| Overfitting | Not enough data | More data, more dropout |
| Underfitting | Model too small | More layers/heads/dims |
| Slow training | No GPU/MPS | Use hardware acceleration |
| OOM errors | Batch too large | Reduce batch size, use accumulation |
| Training crash | No checkpoints | Save every N steps |
### Hyperparameter Recommendations
Based on published research and common practices:
| Hyperparameter | Small Models (<1B) | Large Models (>1B) |
|----------------|-------------------|-------------------|
| Learning rate | 1e-4 to 6e-4 | 1e-4 to 3e-4 |
| Warmup | 1-2% of steps | 0.1-1% of steps |
| Weight decay | 0.01 - 0.1 | 0.01 - 0.1 |
| Beta1 | 0.9 | 0.9 |
| Beta2 | 0.999 | 0.95 |
| Batch size | 256 - 1024 sequences | 1M - 4M tokens |
| Gradient clip | 1.0 | 1.0 |
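As a sketch, here is how the table's typical values map onto `AdamW` for a small model; the numbers are the rules of thumb above, not tuned settings:

```{python}
# Wiring the table's typical hyperparameters into AdamW (illustrative, not tuned)
cfg_model = create_gpt_tiny(vocab_size=100)
cfg_optimizer = torch.optim.AdamW(
    cfg_model.parameters(),
    lr=3e-4,              # small-model range: 1e-4 to 6e-4
    betas=(0.9, 0.999),   # the table suggests beta2=0.95 for large (>1B) models
    weight_decay=0.1,     # 0.01 - 0.1
)
max_grad_norm = 1.0       # pair with clip_grad_norm_ every step
print(cfg_optimizer.defaults)
```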
### Memory Optimization Strategies
1. **Gradient accumulation**: Simulate larger batches
2. **Mixed precision (fp16/bf16)**: ~50% memory reduction
3. **Gradient checkpointing**: Trade compute for memory (see the sketch after this list)
4. **FSDP/DeepSpeed**: Shard model across GPUs
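Gradient checkpointing is the one item on this list not demonstrated elsewhere in the module. A minimal sketch with `torch.utils.checkpoint`; the tiny `nn.Sequential` block stands in for a transformer layer:

```{python}
from torch.utils.checkpoint import checkpoint

# Activations inside the checkpointed block are recomputed during backward instead of stored
ckpt_block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
ckpt_input = torch.randn(4, 64, requires_grad=True)

ckpt_out = checkpoint(ckpt_block, ckpt_input, use_reentrant=False)  # compute traded for memory
ckpt_out.sum().backward()
print(f"Gradients flow through the checkpointed block: {ckpt_input.grad is not None}")
```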
## Interactive Exploration
Experiment with learning rate schedules in real-time. Adjust the hyperparameters to see how warmup and cosine decay shape the learning rate curve.
```{ojs}
//| echo: false
// Cosine schedule with linear warmup
function computeSchedule(maxLr, minLr, warmupSteps, totalSteps) {
const lrs = [];
const numPoints = Math.min(totalSteps, 500); // Limit points for performance
const stepSize = totalSteps / numPoints;
for (let i = 0; i <= numPoints; i++) {
const step = Math.floor(i * stepSize);
let lr;
if (step < warmupSteps) {
// Linear warmup
lr = maxLr * step / Math.max(1, warmupSteps);
} else if (step >= totalSteps) {
lr = minLr;
} else {
// Cosine decay
const progress = (step - warmupSteps) / Math.max(1, totalSteps - warmupSteps);
const cosine = 0.5 * (1 + Math.cos(Math.PI * progress));
lr = minLr + (maxLr - minLr) * cosine;
}
lrs.push({ step, lr, phase: step < warmupSteps ? "warmup" : "decay" });
}
return lrs;
}
// Get LR at a specific step
function getLrAtStep(step, maxLr, minLr, warmupSteps, totalSteps) {
if (step < warmupSteps) {
return maxLr * step / Math.max(1, warmupSteps);
} else if (step >= totalSteps) {
return minLr;
} else {
const progress = (step - warmupSteps) / Math.max(1, totalSteps - warmupSteps);
const cosine = 0.5 * (1 + Math.cos(Math.PI * progress));
return minLr + (maxLr - minLr) * cosine;
}
}
```
```{ojs}
//| echo: false
// Input controls
viewof maxLr = Inputs.range([1e-5, 1e-2], {
value: 1e-3,
step: 1e-5,
label: "Max Learning Rate",
format: x => x.toExponential(1)
})
viewof minLr = Inputs.range([0, 1e-4], {
value: 1e-5,
step: 1e-6,
label: "Min Learning Rate",
format: x => x.toExponential(1)
})
viewof warmupSteps = Inputs.range([0, 500], {
value: 100,
step: 10,
label: "Warmup Steps"
})
viewof totalSteps = Inputs.range([100, 2000], {
value: 1000,
step: 50,
label: "Total Steps"
})
viewof currentStep = Inputs.range([0, totalSteps], {
value: Math.floor(totalSteps / 2),
step: 1,
label: "Current Step"
})
```
```{ojs}
//| echo: false
// Dark mode detection
isDark = {
const check = () => document.body.classList.contains('quarto-dark');
return check();
}
// Theme colors for light/dark mode
theme = isDark ? {
warmupBg: '#3d3520',
curveStroke: '#6b8cae',
warmupMarker: '#b89a5a',
currentMarker: '#a86b6b',
annotationText: '#b89a5a'
} : {
warmupBg: '#fef3c7',
curveStroke: '#3b82f6',
warmupMarker: '#f59e0b',
currentMarker: '#ef4444',
annotationText: '#f59e0b'
}
```
```{ojs}
//| echo: false
// Compute schedule data
scheduleData = computeSchedule(maxLr, minLr, warmupSteps, totalSteps)
// Current LR
currentLr = getLrAtStep(currentStep, maxLr, minLr, warmupSteps, totalSteps)
// Warmup percentage
warmupPct = ((warmupSteps / totalSteps) * 100).toFixed(1)
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
Plot.plot({
title: "Learning Rate Schedule: Warmup + Cosine Decay",
subtitle: `Warmup: ${warmupSteps} steps (${warmupPct}%) | Peak LR: ${maxLr.toExponential(1)} | Min LR: ${minLr.toExponential(1)}`,
width: 700,
height: 350,
marginLeft: 70,
marginBottom: 50,
x: {
label: "Training Step →",
domain: [0, totalSteps]
},
y: {
label: "↑ Learning Rate",
domain: [0, maxLr * 1.1],
tickFormat: ".1e"
},
marks: [
// Warmup region background
Plot.rectY([{x1: 0, x2: warmupSteps, y: maxLr * 1.1}], {
x1: "x1",
x2: "x2",
y2: "y",
y1: 0,
fill: theme.warmupBg,
fillOpacity: 0.5
}),
// Main LR curve
Plot.line(scheduleData, {
x: "step",
y: "lr",
stroke: theme.curveStroke,
strokeWidth: 2.5
}),
// Warmup end marker
Plot.ruleX([warmupSteps], {
stroke: theme.warmupMarker,
strokeWidth: 2,
strokeDasharray: "5,5"
}),
// Current step indicator
Plot.ruleX([currentStep], {
stroke: theme.currentMarker,
strokeWidth: 2
}),
// Current LR point
Plot.dot([{step: currentStep, lr: currentLr}], {
x: "step",
y: "lr",
fill: theme.currentMarker,
r: 6
}),
// Annotations
Plot.text([{step: warmupSteps, lr: maxLr * 1.05}], {
x: "step",
y: "lr",
text: ["← Warmup ends"],
fill: theme.annotationText,
fontSize: 11,
textAnchor: "start"
}),
Plot.ruleY([0])
]
})
```
```{ojs}
//| echo: false
// Display current step info
md`**Step ${currentStep}:** LR = **${currentLr.toExponential(3)}** ${currentStep < warmupSteps ? "(warming up)" : currentStep >= totalSteps ? "(finished)" : "(decaying)"}`
```
```{ojs}
//| echo: false
// Legend
md`<span style="background: ${theme.warmupBg}; padding: 2px 8px; color: ${isDark ? '#e8e6e3' : '#1e293b'}">Warmup phase</span> <span style="color: ${theme.warmupMarker}">┆</span> Warmup ends <span style="color: ${theme.currentMarker}">│</span> Current step`
```
::: {.callout-tip}
## Try This
1. **Effect of warmup**: Set warmup to 0, then gradually increase to 200. Notice how the curve changes from immediate peak to gradual ramp-up.
2. **Long vs short training**: Compare total_steps=500 vs total_steps=2000 with the same warmup. See how the decay rate changes.
3. **Min LR matters**: Set min_lr to 0, then to 1e-5. The floor prevents the model from completely stopping learning.
4. **Warmup ratio**: Try warmup_steps = 1-2% of total_steps (common in practice). For 1000 steps, that's 10-20 warmup steps.
5. **Drag the current step slider** to see the exact LR at any point in training.
:::
## Exercises
### Exercise 1: Learning Rate Finder
Implement a learning rate finder that trains for a few iterations at exponentially increasing learning rates and plots loss vs learning rate.
```{python}
# Your implementation here
def lr_finder(model, tokens, start_lr=1e-7, end_lr=1e-1, num_steps=100):
"""Find optimal learning rate by training with exponentially increasing LR."""
# TODO: Implement this
pass
```
### Exercise 2: Custom Scheduler
Implement a linear warmup + linear decay scheduler (instead of cosine decay).
```{python}
# Your implementation here
class LinearScheduler:
def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0):
# TODO: Implement this
pass
def step(self):
pass
```
### Exercise 3: Training with Validation
Modify the training loop to:
1. Compute validation loss every N steps
2. Save the best model (lowest validation loss)
3. Implement early stopping if validation loss doesn't improve for M steps
## Summary
In this module, we learned:
1. **Cross-entropy loss** measures prediction quality (lower = better), with mathematical foundations in information theory
2. **Perplexity** provides an intuitive metric: exp(loss) - "choosing among N equally likely options"
3. **Learning rate scheduling** with warmup + cosine decay prevents early instability and enables fine-tuning
4. **AdamW optimizer** combines momentum, adaptive learning rates, and proper weight decay decoupling
5. **Gradient accumulation** enables larger effective batch sizes without more memory
6. **Gradient clipping** (max_norm=1.0) prevents exploding gradients, essential for transformers
7. **Batch size tradeoffs** affect memory, training dynamics, and generalization
8. **Mixed precision training** (fp16/bf16) typically provides ~2x speedup and ~50% memory reduction
9. **Distributed training** (DDP, FSDP) scales training to multiple GPUs
10. **Common failure modes** (NaN loss, stuck training, oscillation) and their solutions
11. **Checkpointing strategies** ensure you never lose training progress
### Key Takeaways
- Always use warmup (at least 1% of steps) to stabilize early training
- Monitor gradient norms alongside loss - they tell you about training stability
- Start with standard hyperparameters (lr=3e-4, wd=0.01, clip=1.0) and adjust from there
- Test your training loop on a tiny dataset first - verify it can overfit
## What's Next
In [Module 08: Generation](../m08_generation/), we'll use our trained model to generate text with various decoding strategies like greedy, sampling, and top-k/top-p.