---
title: "Module 02: Autograd"
format:
html:
code-fold: false
toc: true
ipynb: default
jupyter: python3
---
{{< include ../_diagram-lib.qmd >}}
## Introduction
The magic that makes neural networks trainable: automatic differentiation computes gradients through any computation, and it is the foundation of backpropagation.
**Autograd** (automatic differentiation) is how we compute gradients automatically. When you call `loss.backward()` in PyTorch, autograd computes the gradient of the loss with respect to every parameter - exactly the information needed to adjust them and reduce the loss.
Why is this essential for LLMs?
- **Millions of parameters**: An LLM has millions (or billions) of numbers to adjust. We can't compute gradients by hand.
- **Complex computations**: Attention, embeddings, layer norms - the gradient must flow through all of them.
- **Training loop**: Every training step computes gradients to update weights.
Without autograd, deep learning wouldn't be practical.
### What You'll Learn
By the end of this module, you will be able to:
- Understand how computational graphs enable automatic gradient computation
- Build a working scalar autograd engine from scratch
- Extend these principles to tensor operations
- Use PyTorch's autograd for gradient computation
- Recognize common autograd pitfalls and how to avoid them
## Intuition: The Computational Graph
Every computation builds a **directed acyclic graph (DAG)** dynamically as operations execute. This is called "define-by-run" - the graph structure is determined by the actual code path taken, which can differ each forward pass (useful for control flow like conditionals and loops).
Think of a computation as a graph of operations:
```{ojs}
//| echo: false
// Interactive step-through of forward pass: x -> a = x*3 -> b = a+1 -> c = b^2
viewof forwardStep = Inputs.range([0, 3], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
forwardPassDiagram = {
const width = 580;
const height = 140;
const nodeW = 120;
const nodeH = 50;
// Computation values
const x = 2;
const a = x * 3; // 6
const b = a + 1; // 7
const c = b * b; // 49
const nodes = [
{ id: 'x', x: 70, y: 70, label: 'x = 2', sublabel: 'input' },
{ id: 'a', x: 210, y: 70, label: 'a = x * 3', sublabel: forwardStep >= 1 ? `= ${a}` : '?' },
{ id: 'b', x: 360, y: 70, label: 'b = a + 1', sublabel: forwardStep >= 2 ? `= ${b}` : '?' },
{ id: 'c', x: 510, y: 70, label: 'c = b^2', sublabel: forwardStep >= 3 ? `= ${c}` : '?' }
];
const edges = [
{ source: 'x', target: 'a', id: 'e1' },
{ source: 'a', target: 'b', id: 'e2' },
{ source: 'b', target: 'c', id: 'e3' }
];
// Active elements based on step
const activeNodes = forwardStep === 0 ? ['x'] :
forwardStep === 1 ? ['x', 'a'] :
forwardStep === 2 ? ['x', 'a', 'b'] :
['x', 'a', 'b', 'c'];
const activeEdges = forwardStep === 0 ? [] :
forwardStep === 1 ? ['e1'] :
forwardStep === 2 ? ['e1', 'e2'] :
['e1', 'e2', 'e3'];
return FlowDiagram({
nodes,
edges,
width,
height,
activeNodes,
activeEdges,
theme: diagramTheme,
nodeWidth: nodeW,
nodeHeight: nodeH
});
}
```
```{ojs}
//| echo: false
// Step descriptions for forward pass
md`**Step ${forwardStep}:** ${
forwardStep === 0 ? "Start with input x = 2" :
forwardStep === 1 ? "Compute a = x * 3 = 2 * 3 = 6" :
forwardStep === 2 ? "Compute b = a + 1 = 6 + 1 = 7" :
"Compute c = b^2 = 7^2 = 49 (output)"
}`
```
**Forward pass**: compute values left to right (x -> a -> b -> c)
**Backward pass**: compute gradients right to left (c -> b -> a -> x)
- dc/dc = 1 (the seed gradient at the output is always 1)
- dc/db = 2b = 14 (derivative of b^2)
- dc/da = dc/db * db/da = 14 * 1 = 14
- dc/dx = dc/da * da/dx = 14 * 3 = 42
The **chain rule** connects everything: multiply local gradients as you go back.
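These hand-derived numbers are easy to sanity-check with a central finite difference - a quick plain-Python sketch (not part of the autograd engine we build later):

```{python}
# Finite-difference check of the hand-derived gradient dc/dx = 42
# for the graph above: a = x * 3, b = a + 1, c = b^2, at x = 2.
def f(x):
    a = x * 3
    b = a + 1
    return b * b

h = 1e-6
numerical = (f(2 + h) - f(2 - h)) / (2 * h)
print(f"numerical dc/dx = {numerical:.4f}")  # ~42.0000
```

The central difference agrees with the chain-rule result because it approximates the same slope the backward pass computes analytically.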
## Computational Graph: Forward and Backward
Here's a more detailed example showing both passes:
```{ojs}
//| echo: false
// Toggle between forward and backward pass
viewof passDirection = Inputs.radio(["Forward Pass", "Backward Pass"], {
value: "Forward Pass",
label: "View"
})
```
```{ojs}
//| echo: false
// Step slider for the selected pass
viewof computeStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
computeGraphDiagram = {
const width = 650;
const height = 280;
const nodeW = 90;
const nodeH = 48;
// Fixed values for the computation
const x = 2.0, w = 0.5, b = 0.1, target = 0.8;
const z1 = w * x; // 1.0
const z2 = z1 + b; // 1.1
const diff = z2 - target; // 0.3
const loss = diff * diff; // 0.09
// Gradients (backward pass)
const dLoss = 1.0;
const dDiff = 2 * diff; // 0.6
const dZ2 = dDiff; // 0.6
const dZ1 = dZ2; // 0.6
const dB = dZ2; // 0.6
const dX = dZ1 * w; // 0.3
const dW = dZ1 * x; // 1.2
const isForward = passDirection === "Forward Pass";
const svg = d3.create('svg')
.attr('width', width)
.attr('height', height)
.attr('viewBox', `0 0 ${width} ${height}`);
// Background
svg.append('rect')
.attr('width', width)
.attr('height', height)
.attr('fill', diagramTheme.bg)
.attr('rx', 8);
// Defs for arrows
const defs = svg.append('defs');
// Normal arrow
defs.append('marker')
.attr('id', 'compute-arrow')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.edgeStroke);
// Highlighted arrow
defs.append('marker')
.attr('id', 'compute-arrow-active')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.highlight);
// Node positions
const positions = {
x: { x: 50, y: 70 },
w: { x: 50, y: 150 },
b: { x: 50, y: 230 },
target: { x: 400, y: 230 },
z1: { x: 180, y: 110 },
z2: { x: 310, y: 140 },
diff: { x: 440, y: 140 },
loss: { x: 570, y: 140 }
};
// Define which nodes/edges are active at each step
const forwardSteps = [
{ nodes: ['x', 'w', 'b', 'target'], edges: [] },
{ nodes: ['x', 'w', 'b', 'target', 'z1'], edges: ['x-z1', 'w-z1'] },
{ nodes: ['x', 'w', 'b', 'target', 'z1', 'z2'], edges: ['x-z1', 'w-z1', 'z1-z2', 'b-z2'] },
{ nodes: ['x', 'w', 'b', 'target', 'z1', 'z2', 'diff'], edges: ['x-z1', 'w-z1', 'z1-z2', 'b-z2', 'z2-diff', 'target-diff'] },
{ nodes: ['x', 'w', 'b', 'target', 'z1', 'z2', 'diff', 'loss'], edges: ['x-z1', 'w-z1', 'z1-z2', 'b-z2', 'z2-diff', 'target-diff', 'diff-loss'] }
];
const backwardSteps = [
{ nodes: ['loss'], edges: [], grads: { loss: dLoss } },
{ nodes: ['loss', 'diff'], edges: ['loss-diff'], grads: { loss: dLoss, diff: dDiff } },
{ nodes: ['loss', 'diff', 'z2', 'target'], edges: ['loss-diff', 'diff-z2'], grads: { loss: dLoss, diff: dDiff, z2: dZ2 } },
{ nodes: ['loss', 'diff', 'z2', 'target', 'z1', 'b'], edges: ['loss-diff', 'diff-z2', 'z2-z1', 'z2-b'], grads: { loss: dLoss, diff: dDiff, z2: dZ2, z1: dZ1, b: dB } },
{ nodes: ['loss', 'diff', 'z2', 'target', 'z1', 'b', 'x', 'w'], edges: ['loss-diff', 'diff-z2', 'z2-z1', 'z2-b', 'z1-x', 'z1-w'], grads: { loss: dLoss, diff: dDiff, z2: dZ2, z1: dZ1, b: dB, x: dX, w: dW } }
];
const currentStep = isForward ? forwardSteps[computeStep] : backwardSteps[computeStep];
const grads = isForward ? {} : (currentStep.grads || {});
// Helper to draw a node
function drawNode(id, label, sublabel, active) {
const pos = positions[id];
const fill = active ? diagramTheme.highlight : diagramTheme.nodeFill;
const stroke = active ? diagramTheme.highlight : diagramTheme.nodeStroke;
const textColor = active ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
const g = svg.append('g').attr('transform', `translate(${pos.x}, ${pos.y})`);
g.append('rect')
.attr('x', -nodeW/2)
.attr('y', -nodeH/2)
.attr('width', nodeW)
.attr('height', nodeH)
.attr('rx', 6)
.attr('fill', fill)
.attr('stroke', stroke)
.attr('stroke-width', active ? 2 : 1.5);
if (active) {
g.select('rect').attr('filter', `drop-shadow(0 0 6px ${diagramTheme.highlightGlow})`);
}
g.append('text')
.attr('y', sublabel ? -6 : 0)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '11px')
.attr('font-weight', '500')
.text(label);
if (sublabel) {
g.append('text')
.attr('y', 10)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '9px')
.attr('opacity', active ? 0.9 : 0.7)
.text(sublabel);
}
}
// Helper to draw an edge
function drawEdge(fromId, toId, active, curved = false, curveFactor = 0) {
const from = positions[fromId];
const to = positions[toId];
const color = active ? diagramTheme.highlight : diagramTheme.edgeStroke;
const markerId = active ? 'compute-arrow-active' : 'compute-arrow';
const dx = to.x - from.x;
const dy = to.y - from.y;
const len = Math.sqrt(dx*dx + dy*dy);
const startX = from.x + (dx/len) * (nodeW/2 + 5);
const startY = from.y + (dy/len) * (nodeH/2 + 5);
const endX = to.x - (dx/len) * (nodeW/2 + 10);
const endY = to.y - (dy/len) * (nodeH/2 + 10);
let pathD;
if (curved) {
const midX = (startX + endX) / 2;
const midY = (startY + endY) / 2;
const cx = midX - dy * curveFactor;
const cy = midY + dx * curveFactor;
pathD = `M${startX},${startY} Q${cx},${cy} ${endX},${endY}`;
} else {
pathD = `M${startX},${startY} L${endX},${endY}`;
}
const path = svg.append('path')
.attr('d', pathD)
.attr('fill', 'none')
.attr('stroke', color)
.attr('stroke-width', active ? 2.5 : 1.5)
.attr('marker-end', `url(#${markerId})`);
if (active) {
path.attr('filter', `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})`);
}
}
// Draw edges first
const edgeMap = {
'x-z1': ['x', 'z1', false, 0],
'w-z1': ['w', 'z1', false, 0],
'z1-z2': ['z1', 'z2', false, 0],
'b-z2': ['b', 'z2', false, 0],
'z2-diff': ['z2', 'diff', false, 0],
'target-diff': ['target', 'diff', false, 0],
'diff-loss': ['diff', 'loss', false, 0],
// Backward edges (reversed)
'loss-diff': ['loss', 'diff', false, 0],
'diff-z2': ['diff', 'z2', false, 0],
'z2-z1': ['z2', 'z1', false, 0],
'z2-b': ['z2', 'b', false, 0],
'z1-x': ['z1', 'x', false, 0],
'z1-w': ['z1', 'w', false, 0]
};
// Draw all relevant edges
const allEdges = isForward ?
['x-z1', 'w-z1', 'z1-z2', 'b-z2', 'z2-diff', 'target-diff', 'diff-loss'] :
['loss-diff', 'diff-z2', 'z2-z1', 'z2-b', 'z1-x', 'z1-w'];
allEdges.forEach(edgeId => {
const [fromId, toId, curved, factor] = edgeMap[edgeId];
const active = currentStep.edges.includes(edgeId);
drawEdge(fromId, toId, active, curved, factor);
});
// Node definitions with values/gradients
const nodeLabels = isForward ? {
x: { label: 'x', sublabel: '= 2.0' },
w: { label: 'w', sublabel: '= 0.5' },
b: { label: 'b', sublabel: '= 0.1' },
target: { label: 'target', sublabel: '= 0.8' },
z1: { label: 'z1 = w*x', sublabel: computeStep >= 1 ? `= ${z1.toFixed(1)}` : '?' },
z2: { label: 'z2 = z1+b', sublabel: computeStep >= 2 ? `= ${z2.toFixed(1)}` : '?' },
diff: { label: 'diff', sublabel: computeStep >= 3 ? `= ${diff.toFixed(1)}` : '?' },
loss: { label: 'loss = diff^2', sublabel: computeStep >= 4 ? `= ${loss.toFixed(2)}` : '?' }
} : {
x: { label: 'x', sublabel: grads.x !== undefined ? `grad = ${grads.x.toFixed(1)}` : '' },
w: { label: 'w', sublabel: grads.w !== undefined ? `grad = ${grads.w.toFixed(1)}` : '' },
b: { label: 'b', sublabel: grads.b !== undefined ? `grad = ${grads.b.toFixed(1)}` : '' },
target: { label: 'target', sublabel: '(no grad)' },
z1: { label: 'z1', sublabel: grads.z1 !== undefined ? `grad = ${grads.z1.toFixed(1)}` : '' },
z2: { label: 'z2', sublabel: grads.z2 !== undefined ? `grad = ${grads.z2.toFixed(1)}` : '' },
diff: { label: 'diff', sublabel: grads.diff !== undefined ? `grad = ${grads.diff.toFixed(1)}` : '' },
loss: { label: 'loss', sublabel: grads.loss !== undefined ? `grad = ${grads.loss.toFixed(1)}` : '' }
};
// Draw all nodes
const allNodes = ['x', 'w', 'b', 'target', 'z1', 'z2', 'diff', 'loss'];
allNodes.forEach(id => {
const active = currentStep.nodes.includes(id);
const { label, sublabel } = nodeLabels[id];
drawNode(id, label, sublabel, active);
});
// Title
svg.append('text')
.attr('x', width / 2)
.attr('y', 20)
.attr('text-anchor', 'middle')
.attr('fill', diagramTheme.nodeText)
.attr('font-size', '13px')
.attr('font-weight', '600')
.text(isForward ? 'Forward Pass: Compute Values' : 'Backward Pass: Compute Gradients');
return svg.node();
}
```
```{ojs}
//| echo: false
// Step explanation
md`**Step ${computeStep}:** ${
passDirection === "Forward Pass" ? (
computeStep === 0 ? "Initialize inputs: x=2.0, w=0.5, b=0.1, target=0.8" :
computeStep === 1 ? "Compute z1 = w * x = 0.5 * 2.0 = 1.0" :
computeStep === 2 ? "Compute z2 = z1 + b = 1.0 + 0.1 = 1.1" :
computeStep === 3 ? "Compute diff = z2 - target = 1.1 - 0.8 = 0.3" :
"Compute loss = diff^2 = 0.3^2 = 0.09"
) : (
computeStep === 0 ? "Start at loss with gradient = 1.0 (seed)" :
computeStep === 1 ? "grad(diff) = d(diff^2)/d(diff) * 1.0 = 2*diff = 2*0.3 = 0.6" :
computeStep === 2 ? "grad(z2) = grad(diff) * d(diff)/d(z2) = 0.6 * 1 = 0.6" :
computeStep === 3 ? "grad(z1) = 0.6, grad(b) = 0.6 (both from z2)" :
"grad(x) = grad(z1) * w = 0.6 * 0.5 = 0.3, grad(w) = grad(z1) * x = 0.6 * 2.0 = 1.2"
)
}`
```
## The Math
### Chain Rule
For composed functions f(g(x)):
```
d/dx[f(g(x))] = f'(g(x)) * g'(x)
```
This extends to any depth - just multiply local derivatives along the path.
### Common Gradients
| Operation | Forward | Local Gradient |
|-----------|---------|----------------|
| `c = a + b` | sum | dc/da = 1, dc/db = 1 |
| `c = a * b` | product | dc/da = b, dc/db = a |
| `c = a ** n` | power | dc/da = n * a^(n-1) |
| `c = exp(a)` | exp | dc/da = exp(a) |
| `c = log(a)` | log | dc/da = 1/a |
| `c = tanh(a)` | tanh | dc/da = 1 - tanh^2(a) |
| `c = relu(a)` | ReLU | dc/da = 1 if a > 0 else 0 |
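Every row of this table can be verified numerically with a central difference - a small plain-Python sketch checking three of them:

```{python}
# Numerically confirm local gradients from the table above
# using central differences.
import math

def numerical_grad(f, a, h=1e-6):
    """Central-difference approximation of df/da."""
    return (f(a + h) - f(a - h)) / (2 * h)

a = 0.7
checks = {
    "exp":  (math.exp,  math.exp(a)),            # dc/da = exp(a)
    "log":  (math.log,  1 / a),                  # dc/da = 1/a
    "tanh": (math.tanh, 1 - math.tanh(a) ** 2),  # dc/da = 1 - tanh^2(a)
}
for name, (f, analytic) in checks.items():
    num = numerical_grad(f, a)
    print(f"{name}: numerical={num:.6f}, analytic={analytic:.6f}")
```

This same trick (gradient checking) is how autograd implementations are tested in practice.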
### Matrix Multiplication Gradients
For `C = A @ B`:
- dL/dA = dL/dC @ B.T
- dL/dB = A.T @ dL/dC
This is why matrix shapes matter so much!
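These two formulas can be checked with NumPy. If we take L = sum(A @ B), then dL/dC is a matrix of ones, and we can compare the analytic gradient against a finite-difference spot check on a single entry (a sketch, assuming NumPy is available):

```{python}
# Verify dL/dA = dL/dC @ B.T with a finite-difference spot check.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))

dC = np.ones((3, 5))   # dL/dC for L = sum(C)
dA = dC @ B.T          # dL/dA, shape (3, 4) - matches A
dB = A.T @ dC          # dL/dB, shape (4, 5) - matches B

# Perturb one entry of A and measure how L changes
h = 1e-6
A_pert = A.copy()
A_pert[1, 2] += h
numerical = ((A_pert @ B).sum() - (A @ B).sum()) / h
print(f"dA[1, 2] analytic={dA[1, 2]:.6f}, numerical={numerical:.6f}")
```

Note that the gradient of each input has the same shape as that input - a useful invariant for debugging.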
## Gradient Accumulation
When a value is used multiple times, gradients ADD:
```{ojs}
//| echo: false
// Step slider for gradient accumulation
viewof accumStep = Inputs.range([0, 3], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
gradAccumDiagram = {
const width = 500;
const height = 200;
const nodeW = 80;
const nodeH = 44;
const a = 3;
const c = a * a; // 9
// Gradient computation
const dC = 1;
const path1Grad = a; // 3 (from left input of *)
const path2Grad = a; // 3 (from right input of *)
const dA = path1Grad + path2Grad; // 6 = 2a
const svg = d3.create('svg')
.attr('width', width)
.attr('height', height)
.attr('viewBox', `0 0 ${width} ${height}`);
svg.append('rect')
.attr('width', width)
.attr('height', height)
.attr('fill', diagramTheme.bg)
.attr('rx', 8);
// Defs
const defs = svg.append('defs');
defs.append('marker')
.attr('id', 'accum-arrow')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.edgeStroke);
defs.append('marker')
.attr('id', 'accum-arrow-active')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.highlight);
// Accent arrow for second path
defs.append('marker')
.attr('id', 'accum-arrow-accent')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.accent);
// Positions
const aPos = { x: 100, y: 100 };
const mulPos = { x: 250, y: 100 };
const cPos = { x: 400, y: 100 };
// Draw edges (two curved paths from a to mul)
const drawCurvedPath = (fromX, fromY, toX, toY, curve, active, pathNum) => {
const dx = toX - fromX;
const dy = toY - fromY;
const len = Math.sqrt(dx*dx + dy*dy);
const startX = fromX + (dx/len) * (nodeW/2 + 5);
const startY = fromY + (dy/len) * (nodeH/2 + 5);
const endX = toX - (dx/len) * (nodeW/2 + 10);
const endY = toY - (dy/len) * (nodeH/2 + 10);
const midX = (startX + endX) / 2;
const midY = (startY + endY) / 2;
const cx = midX;
const cy = midY + curve;
const pathD = `M${startX},${startY} Q${cx},${cy} ${endX},${endY}`;
let color, markerId;
if (active && pathNum === 1) {
color = diagramTheme.highlight;
markerId = 'accum-arrow-active';
} else if (active && pathNum === 2) {
color = diagramTheme.accent;
markerId = 'accum-arrow-accent';
} else {
color = diagramTheme.edgeStroke;
markerId = 'accum-arrow';
}
const path = svg.append('path')
.attr('d', pathD)
.attr('fill', 'none')
.attr('stroke', color)
.attr('stroke-width', active ? 2.5 : 1.5)
.attr('marker-end', `url(#${markerId})`);
if (active) {
path.attr('filter', `drop-shadow(0 0 4px ${pathNum === 1 ? diagramTheme.highlightGlow : diagramTheme.accentGlow})`);
}
// Path label
if (active) {
svg.append('text')
.attr('x', midX)
.attr('y', midY + curve + (curve > 0 ? 15 : -10))
.attr('text-anchor', 'middle')
.attr('fill', color)
.attr('font-size', '10px')
.attr('font-weight', '500')
.text(`path ${pathNum}: grad = ${a}`);
}
};
// Straight edge from mul to c
const drawStraightPath = (fromX, fromY, toX, toY, active) => {
const dx = toX - fromX;
const len = Math.abs(dx);
const startX = fromX + (nodeW/2 + 5);
const endX = toX - (nodeW/2 + 10);
const color = active ? diagramTheme.highlight : diagramTheme.edgeStroke;
const markerId = active ? 'accum-arrow-active' : 'accum-arrow';
const path = svg.append('path')
.attr('d', `M${startX},${fromY} L${endX},${toY}`)
.attr('fill', 'none')
.attr('stroke', color)
.attr('stroke-width', active ? 2.5 : 1.5)
.attr('marker-end', `url(#${markerId})`);
if (active) {
path.attr('filter', `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})`);
}
};
// Draw edges based on step
// Step 0: Show structure (forward view)
// Step 1: Start backward from c
// Step 2: Show path 1 gradient
// Step 3: Show path 2 gradient and accumulation
drawCurvedPath(aPos.x, aPos.y, mulPos.x, mulPos.y, -30, accumStep >= 2, 1);
drawCurvedPath(aPos.x, aPos.y, mulPos.x, mulPos.y, 30, accumStep >= 3, 2);
drawStraightPath(mulPos.x, mulPos.y, cPos.x, cPos.y, accumStep >= 1);
// Draw nodes
const drawNode = (x, y, label, sublabel, active) => {
const fill = active ? diagramTheme.highlight : diagramTheme.nodeFill;
const stroke = active ? diagramTheme.highlight : diagramTheme.nodeStroke;
const textColor = active ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
const g = svg.append('g').attr('transform', `translate(${x}, ${y})`);
g.append('rect')
.attr('x', -nodeW/2)
.attr('y', -nodeH/2)
.attr('width', nodeW)
.attr('height', nodeH)
.attr('rx', 6)
.attr('fill', fill)
.attr('stroke', stroke)
.attr('stroke-width', active ? 2 : 1.5);
if (active) {
g.select('rect').attr('filter', `drop-shadow(0 0 6px ${diagramTheme.highlightGlow})`);
}
g.append('text')
.attr('y', sublabel ? -6 : 0)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '12px')
.attr('font-weight', '500')
.text(label);
if (sublabel) {
g.append('text')
.attr('y', 10)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '10px')
.attr('opacity', active ? 0.9 : 0.7)
.text(sublabel);
}
};
// Node labels/sublabels based on step
const aLabel = accumStep === 0 ? `= ${a}` :
accumStep === 3 ? `grad = ${dA}` : '';
const mulLabel = '*';
const cLabel = accumStep === 0 ? `= ${c}` :
accumStep >= 1 ? `grad = ${dC}` : '';
drawNode(aPos.x, aPos.y, 'a', aLabel, accumStep === 0 || accumStep === 3);
drawNode(mulPos.x, mulPos.y, mulLabel, 'a * a', false);
drawNode(cPos.x, cPos.y, 'c', cLabel, accumStep >= 1);
// Title
svg.append('text')
.attr('x', width / 2)
.attr('y', 25)
.attr('text-anchor', 'middle')
.attr('fill', diagramTheme.nodeText)
.attr('font-size', '13px')
.attr('font-weight', '600')
.text('Gradient Accumulation: c = a * a');
return svg.node();
}
```
```{ojs}
//| echo: false
md`**Step ${accumStep}:** ${
accumStep === 0 ? "Forward pass: a = 3, c = a * a = 9" :
accumStep === 1 ? "Start backward: grad(c) = 1 (seed gradient)" :
accumStep === 2 ? "Path 1 contributes: d(a*a)/da (left input) = a = 3" :
"Path 2 contributes: d(a*a)/da (right input) = a = 3. **Total: grad(a) = 3 + 3 = 6 = 2a**"
}`
```
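The accumulated result grad(a) = 2a can be confirmed numerically: `a` feeds the multiply twice, and the two path contributions must sum to the finite-difference slope (plain-Python sketch):

```{python}
# Confirm grad(a) = 2a = 6 for c = a * a at a = 3.
a = 3.0
h = 1e-6
numerical = ((a + h) * (a + h) - (a - h) * (a - h)) / (2 * h)

# Accumulation view: `a` feeds the multiply twice, so its gradient
# is the SUM of the two path contributions (each equals a).
accumulated = a + a
print(f"numerical={numerical:.4f}, accumulated={accumulated}")  # both 6
```

This is why every autograd engine uses `grad += ...` rather than `grad = ...` in its backward functions.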
## Code Walkthrough
Let's explore autograd interactively:
```{python}
import torch
print(f"PyTorch version: {torch.__version__}")
```
### Building a Simple Autograd Engine
We'll build a simple autograd engine (inspired by Andrej Karpathy's micrograd). This handles scalar values only - PyTorch's autograd extends these same principles to tensors of any shape, which is what makes it powerful for real neural networks.
```{python}
class Value:
"""A scalar value that tracks its gradient."""
def __init__(self, data, children=(), op='', label=''):
self.data = data
self.grad = 0.0
self._backward = lambda: None
self._prev = set(children)
self._op = op
self.label = label
def __repr__(self):
return f"Value(data={self.data}, grad={self.grad})"
def __add__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data + other.data, (self, other), '+')
def _backward():
self.grad += out.grad # d(a+b)/da = 1
other.grad += out.grad # d(a+b)/db = 1
out._backward = _backward
return out
def __mul__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, (self, other), '*')
def _backward():
self.grad += other.data * out.grad # d(a*b)/da = b
other.grad += self.data * out.grad # d(a*b)/db = a
out._backward = _backward
return out
def __pow__(self, other):
assert isinstance(other, (int, float)), "only supporting int/float powers"
out = Value(self.data ** other, (self,), f'**{other}')
def _backward():
self.grad += (other * self.data ** (other - 1)) * out.grad
out._backward = _backward
return out
def __neg__(self):
return self * -1
def __sub__(self, other):
return self + (-other)
def __radd__(self, other):
return self + other
def __rmul__(self, other):
return self * other
def tanh(self):
import math
t = math.tanh(self.data)
out = Value(t, (self,), 'tanh')
def _backward():
self.grad += (1 - t ** 2) * out.grad
out._backward = _backward
return out
def relu(self):
out = Value(max(0, self.data), (self,), 'relu')
def _backward():
self.grad += (self.data > 0) * out.grad
out._backward = _backward
return out
def exp(self):
import math
out = Value(math.exp(self.data), (self,), 'exp')
def _backward():
self.grad += out.data * out.grad
out._backward = _backward
return out
def log(self):
import math
out = Value(math.log(self.data), (self,), 'log')
def _backward():
self.grad += (1.0 / self.data) * out.grad
out._backward = _backward
return out
def __truediv__(self, other):
return self * (other ** -1)
def backward(self):
# Topological sort to process in correct order
topo = []
visited = set()
def build_topo(v):
if v not in visited:
visited.add(v)
for child in v._prev:
build_topo(child)
topo.append(v)
build_topo(self)
# Go backwards, applying chain rule
self.grad = 1.0
for v in reversed(topo):
v._backward()
```
### Testing Our Value Class
```{python}
# Create some Values
a = Value(2.0, label='a')
b = Value(3.0, label='b')
print(f"a = {a}")
print(f"b = {b}")
print(f"\nInitially, gradients are 0.0")
```
```{python}
# Perform a computation
c = a * b
c.label = 'c'
print(f"c = a * b = {c.data}")
print(f"\nNow let's compute gradients with backward():")
c.backward()
print(f"\ndc/da = {a.grad} (should be b = 3.0)")
print(f"dc/db = {b.grad} (should be a = 2.0)")
```
### Verifying Gradients Numerically
```{python}
# Let's verify: increase 'a' by a tiny amount
epsilon = 0.0001
a_original = 2.0
b_val = 3.0
c_original = a_original * b_val
c_perturbed = (a_original + epsilon) * b_val
numerical_gradient = (c_perturbed - c_original) / epsilon
print(f"Numerical gradient dc/da = {numerical_gradient:.4f}")
print(f"Our computed gradient: {a.grad}")
print(f"\nThey match! (The tiny difference is numerical precision)")
```
### The Chain Rule in Action
```{python}
# Let's trace through: c = (a + b) * b
a = Value(2.0, label='a')
b = Value(3.0, label='b')
sum_ab = a + b
sum_ab.label = 'sum'
c = sum_ab * b
c.label = 'c'
c.backward()
print(f"Expression: c = (a + b) * b")
print(f"a = {a.data}, b = {b.data}")
print(f"sum = a + b = {sum_ab.data}")
print(f"c = sum * b = {c.data}")
print(f"\nGradients:")
print(f"dc/da = {a.grad} (expected: b = 3)")
print(f"dc/db = {b.grad} (expected: sum + b = 5 + 3 = 8)")
```
### Training a Neuron
Let's actually train a neuron to output a target value!
```{python}
# Training loop
x_val = 2.0
target_val = 0.8
# Learnable parameters (start with small random values)
w = 0.3
b = 0.1
learning_rate = 0.5
losses = []
print("Training a neuron to output 0.8 when input is 2.0")
print("=" * 50)
for step in range(20):
# Create fresh Values (gradients reset)
x = Value(x_val, label='x')
w_v = Value(w, label='w')
b_v = Value(b, label='b')
target = Value(target_val, label='target')
# Forward pass
y = (w_v * x + b_v).tanh()
loss = (y - target) ** 2
# Backward pass
loss.backward()
# Gradient descent update
w = w - learning_rate * w_v.grad
b = b - learning_rate * b_v.grad
losses.append(loss.data)
if step % 4 == 0:
print(f"Step {step:2d}: y={y.data:.4f}, loss={loss.data:.6f}, w={w:.4f}, b={b:.4f}")
print(f"\nFinal output: {y.data:.4f}")
print(f"Target: {target_val}")
print("Pretty close!")
```
```{python}
#| echo: false
# Pass training losses to OJS for visualization
ojs_define(training_losses = losses)
```
```{ojs}
//| echo: false
trainingLossChart = {
const width = 600;
const height = 280;
const margin = {top: 40, right: 30, bottom: 50, left: 60};
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
const theme = diagramTheme;
// Convert training_losses to array with step numbers
const data = training_losses.map((loss, i) => ({step: i, loss: loss}));
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 8);
const g = svg.append("g")
.attr("transform", `translate(${margin.left},${margin.top})`);
// Scales
const x = d3.scaleLinear()
.domain([0, data.length - 1])
.range([0, innerWidth]);
const y = d3.scaleLinear()
.domain([0, d3.max(data, d => d.loss) * 1.1])
.range([innerHeight, 0]);
// Grid lines
g.append("g")
.attr("class", "grid")
.call(d3.axisLeft(y).ticks(5).tickSize(-innerWidth).tickFormat(""))
.call(g => g.select(".domain").remove())
.call(g => g.selectAll(".tick line")
.attr("stroke", theme.nodeStroke)
.attr("stroke-opacity", 0.3));
// X axis
g.append("g")
.attr("transform", `translate(0,${innerHeight})`)
.call(d3.axisBottom(x).ticks(10).tickFormat(d3.format("d")))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text").attr("fill", theme.nodeText));
// Y axis
g.append("g")
.call(d3.axisLeft(y).ticks(5).tickFormat(d3.format(".4f")))
.call(g => g.select(".domain").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", theme.nodeStroke))
.call(g => g.selectAll(".tick text").attr("fill", theme.nodeText));
// Axis labels
g.append("text")
.attr("x", innerWidth / 2)
.attr("y", innerHeight + 40)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.text("Training Step");
g.append("text")
.attr("transform", "rotate(-90)")
.attr("x", -innerHeight / 2)
.attr("y", -45)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.text("Loss");
// Line generator
const line = d3.line()
.x(d => x(d.step))
.y(d => y(d.loss))
.curve(d3.curveMonotoneX);
// Draw the line
g.append("path")
.datum(data)
.attr("fill", "none")
.attr("stroke", theme.accent)
.attr("stroke-width", 2.5)
.attr("d", line);
// Draw dots
g.selectAll(".dot")
.data(data)
.join("circle")
.attr("class", "dot")
.attr("cx", d => x(d.step))
.attr("cy", d => y(d.loss))
.attr("r", 4)
.attr("fill", theme.accent)
.attr("stroke", theme.bg)
.attr("stroke-width", 1.5);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 24)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Loss Decreasing During Training");
return svg.node();
}
```
The loss decreases because gradients tell us which way to adjust w and b!
## The Evolutionary Leap: Scalar to Tensor
We've now mastered scalar autograd with our `Value` class. The chain rule, computational graphs, and backward passes all work the same way at any scale. But real neural networks operate on tensors with thousands or millions of elements. Let's build a second autograd engine — this time for tensors — to see how the same principles scale up.
### Why Scalars Don't Scale
Our `Value` class works beautifully for understanding autograd. But try to imagine training a real neural network with it:
- A small MLP with 10,000 parameters creates 10,000 `Value` objects
- Each forward pass creates thousands more intermediate `Value` nodes
- Python loops for every dot product (catastrophically slow)
- Memory explodes with millions of tiny objects
The fix isn't "optimize Python." The fix is **stop pretending a neural network is a pile of scalars**.
```{python}
# The problem: dot product with scalar Values
def slow_dot_product(values_a, values_b):
"""This is what we're doing now - Python loop over scalars."""
result = values_a[0] * values_b[0]
for a, b in zip(values_a[1:], values_b[1:]):
result = result + a * b
return result
# With 1000-dimensional vectors, that's 1000 Python operations
# A GPU can do this in ONE operation
```
The solution: upgrade from the scalar `Value` class to a NumPy-backed `Tensor` class.
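The speed gap is easy to measure: the same dot product written as a Python loop over scalars versus one vectorized NumPy call (a rough sketch; exact timings depend on your machine):

```{python}
# Rough timing: scalar loop vs. one vectorized call.
import time
import numpy as np

n = 10_000
rng = np.random.default_rng(0)
a = rng.random(n)
b = rng.random(n)

t0 = time.perf_counter()
loop_result = 0.0
for i in range(n):          # one interpreter round-trip per element
    loop_result += a[i] * b[i]
t_loop = time.perf_counter() - t0

t0 = time.perf_counter()
vec_result = float(a @ b)   # one call into optimized C/BLAS
t_vec = time.perf_counter() - t0

print(f"loop: {t_loop * 1e3:.2f} ms, vectorized: {t_vec * 1e3:.4f} ms")
print(f"results agree: {abs(loop_result - vec_result) < 1e-6}")
```

The vectorized call is typically orders of magnitude faster, and on a GPU the gap grows further still.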
### Building a Tensor Autograd
Here's the evolutionary leap: a `Tensor` class backed by NumPy arrays instead of Python floats.
```{python}
import numpy as np
from typing import Optional, Tuple, Callable, Set
class Tensor:
"""
A NumPy-backed tensor with reverse-mode autodiff.
This is what PyTorch's tensors do internally (but in C++/CUDA).
"""
    def __init__(self, data, requires_grad: bool = False, _prev: Optional[Set["Tensor"]] = None, _op: str = ""):
# Convert to numpy array if needed
if isinstance(data, np.ndarray):
self.data = data.astype(np.float32)
else:
self.data = np.array(data, dtype=np.float32)
self.requires_grad = requires_grad
self.grad: Optional[np.ndarray] = None
self._prev = _prev or set()
self._op = _op
self._backward: Callable[[], None] = lambda: None
def __repr__(self) -> str:
return f"Tensor(shape={self.data.shape}, requires_grad={self.requires_grad})"
@property
def shape(self) -> Tuple[int, ...]:
return self.data.shape
def zero_grad(self) -> None:
self.grad = None
```
Already we see the key difference: `self.data` is a NumPy array, not a float. This means operations work on entire arrays at once.
### The Tricky Part: Broadcasting Gradients
When you add a `(batch, dim)` tensor to a `(dim,)` bias, NumPy broadcasts the bias across all batches. But during backprop, we need to **undo** that broadcasting - the gradient for the bias should be summed across the batch dimension.
```{python}
def _unbroadcast(grad: np.ndarray, target_shape: Tuple[int, ...]) -> np.ndarray:
"""
Undo NumPy broadcasting for gradients.
Example:
Forward: y = x + b where x is (B, D) and b is (D,)
Backward: grad wrt b should be sum over batch axis -> (D,)
Rules:
1. While grad has extra leading dims, sum them away
2. For dims where target had size 1, sum over that axis
"""
g = grad
# Remove leading dims added by broadcasting
while len(g.shape) > len(target_shape):
g = g.sum(axis=0)
# Sum over axes where target had size 1
for axis, (gdim, tdim) in enumerate(zip(g.shape, target_shape)):
if tdim == 1 and gdim != 1:
g = g.sum(axis=axis, keepdims=True)
return g.reshape(target_shape)
# Test it
grad = np.ones((3, 4)) # Gradient has batch dimension
bias_shape = (4,) # Bias was (4,) before broadcasting
result = _unbroadcast(grad, bias_shape)
print(f"Input grad shape: {grad.shape}")
print(f"Target shape: {bias_shape}")
print(f"Result shape: {result.shape}") # (4,) - summed over batch
print(f"Result values: {result}") # [3, 3, 3, 3] - sum of 3 ones per position
```
PyTorch handles this bookkeeping automatically. When you hit a `RuntimeError` complaining about mismatched gradient shapes, it's usually a broadcasting gradient issue.
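You can watch PyTorch do this reduction itself. In this quick check (self-contained, so it imports `torch` here even though we use it more later), the bias gradient comes back already summed over the batch axis:

```python
import torch

# x broadcasts b from (4,) to (3, 4) in the forward pass;
# autograd sums the gradient back over the batch axis.
x = torch.ones(3, 4, requires_grad=True)
b = torch.zeros(4, requires_grad=True)
y = (x + b).sum()
y.backward()

print(b.grad.shape)  # torch.Size([4])
print(b.grad)        # tensor([3., 3., 3., 3.])
```

Each `b[j]` touched three batch rows, so its gradient is 3 - exactly what `_unbroadcast` computes by hand.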
### Tensor Arithmetic with Gradients
Now we add operations. Each operation stores a `_backward` function that knows how to push gradients to its inputs.
```{python}
# Add these methods to our Tensor class (shown separately for clarity)
def tensor_add(self, other) -> "Tensor":
"""
Addition: c = a + b
Gradient: dc/da = 1, dc/db = 1 (with unbroadcasting)
"""
other = other if isinstance(other, Tensor) else Tensor(other)
out = Tensor(
self.data + other.data,
requires_grad=self.requires_grad or other.requires_grad,
_prev={self, other},
_op="+"
)
def _backward():
if out.grad is None:
return
if self.requires_grad:
g = _unbroadcast(out.grad, self.data.shape)
self.grad = g if self.grad is None else (self.grad + g)
if other.requires_grad:
g = _unbroadcast(out.grad, other.data.shape)
other.grad = g if other.grad is None else (other.grad + g)
out._backward = _backward
return out
def tensor_mul(self, other) -> "Tensor":
"""
Multiplication: c = a * b
Gradient: dc/da = b, dc/db = a (with unbroadcasting)
"""
other = other if isinstance(other, Tensor) else Tensor(other)
out = Tensor(
self.data * other.data,
requires_grad=self.requires_grad or other.requires_grad,
_prev={self, other},
_op="*"
)
def _backward():
if out.grad is None:
return
if self.requires_grad:
g = _unbroadcast(out.grad * other.data, self.data.shape)
self.grad = g if self.grad is None else (self.grad + g)
if other.requires_grad:
g = _unbroadcast(out.grad * self.data, other.data.shape)
other.grad = g if other.grad is None else (other.grad + g)
out._backward = _backward
return out
# Attach to Tensor class
Tensor.__add__ = tensor_add
Tensor.__radd__ = lambda self, other: tensor_add(self, other)
Tensor.__mul__ = tensor_mul
Tensor.__rmul__ = lambda self, other: tensor_mul(self, other)
Tensor.__neg__ = lambda self: tensor_mul(self, -1.0)
Tensor.__sub__ = lambda self, other: tensor_add(self, -other)
```
Notice how similar this is to our scalar `Value` - just with arrays and `_unbroadcast`.
### Matrix Multiplication: The Core of Neural Networks
Matrix multiplication is where tensors really shine. This single operation replaces thousands of scalar multiplications and additions.
```{python}
def tensor_matmul(self, other) -> "Tensor":
"""
Matrix multiplication: C = A @ B
Forward: C[i,j] = sum_k A[i,k] * B[k,j]
Backward:
dA = dC @ B.T (gradient flows back through B transposed)
dB = A.T @ dC (gradient flows back through A transposed)
"""
other = other if isinstance(other, Tensor) else Tensor(other)
out = Tensor(
np.matmul(self.data, other.data),
requires_grad=self.requires_grad or other.requires_grad,
_prev={self, other},
_op="matmul"
)
def _backward():
if out.grad is None:
return
if self.requires_grad:
# dA = dC @ B.T
dA = np.matmul(out.grad, np.swapaxes(other.data, -1, -2))
dA = _unbroadcast(dA, self.data.shape)
self.grad = dA if self.grad is None else (self.grad + dA)
if other.requires_grad:
# dB = A.T @ dC
dB = np.matmul(np.swapaxes(self.data, -1, -2), out.grad)
dB = _unbroadcast(dB, other.data.shape)
other.grad = dB if other.grad is None else (other.grad + dB)
out._backward = _backward
return out
Tensor.matmul = tensor_matmul
# Test: simple 2x2 matmul
A = Tensor([[1, 2], [3, 4]], requires_grad=True)
B = Tensor([[5, 6], [7, 8]], requires_grad=True)
C = A.matmul(B)
print(f"A @ B =\n{C.data}")
# Backward pass
C.grad = np.ones_like(C.data) # Seed gradient
C._backward()
print(f"\ndA =\n{A.grad}")
print(f"\ndB =\n{B.grad}")
```
This is the fundamental operation of neural networks. Every linear layer is just `x @ W + b`.
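It's worth convincing yourself that the `dA = dC @ B.T` rule is right. Here is a quick NumPy finite-difference check against the analytic gradient (a standalone sketch, independent of our `Tensor` class):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 2))

def f(A_):
    return (A_ @ B).sum()  # scalar loss, so the seed gradient dC is all ones

# Analytic gradient from the rule above: dA = dC @ B.T with dC = ones
dA = np.ones((3, 2)) @ B.T

# Numerical gradient via central differences
eps = 1e-5
dA_num = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap, Am = A.copy(), A.copy()
        Ap[i, j] += eps
        Am[i, j] -= eps
        dA_num[i, j] = (f(Ap) - f(Am)) / (2 * eps)

print(np.allclose(dA, dA_num, atol=1e-6))  # True
```

This perturb-one-entry-at-a-time check is the standard way to validate any hand-written backward function.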
### Backward Pass: Topological Sort
Just like our scalar `Value`, we need to traverse the graph in reverse order.
```{python}
def tensor_backward(self) -> None:
"""
Reverse-mode autodiff: topologically sort and backprop.
"""
# Build topological order
topo = []
visited = set()
def build(v: Tensor):
if v not in visited:
visited.add(v)
for p in v._prev:
build(p)
topo.append(v)
build(self)
# Seed gradient and propagate
self.grad = np.ones_like(self.data, dtype=np.float32)
for v in reversed(topo):
v._backward()
Tensor.backward = tensor_backward
```
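One caveat: the recursive `build` can hit Python's default recursion limit (roughly 1000 frames) on very deep graphs. An iterative variant with an explicit stack avoids this. Here is a sketch using a minimal stand-in `Node` class (hypothetical, only to keep the example self-contained; in the real engine you'd walk `Tensor._prev` the same way):

```python
# Minimal stand-in node for illustration; only needs a _prev set.
class Node:
    def __init__(self, prev=()):
        self._prev = set(prev)

def topo_sort_iterative(root):
    """Post-order DFS with an explicit stack instead of recursion."""
    topo, visited = [], set()
    stack = [(root, False)]
    while stack:
        node, processed = stack.pop()
        if processed:
            topo.append(node)       # all children already appended
        elif node not in visited:
            visited.add(node)
            stack.append((node, True))  # revisit after children
            for p in node._prev:
                stack.append((p, False))
    return topo

# A chain this deep would overflow the recursive version
n = Node()
for _ in range(5000):
    n = Node(prev=[n])
order = topo_sort_iterative(n)
print(len(order))  # 5001, leaf first, root last
```

The ordering matches the recursive version: children before parents, so iterating in reverse runs each `_backward` only after its output's gradient is complete.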
### Activation Functions
Neural networks need nonlinearities. Here are the common ones:
```{python}
def tensor_relu(self) -> "Tensor":
"""ReLU: max(0, x). Gradient: 1 if x > 0, else 0."""
out = Tensor(
np.maximum(self.data, 0.0),
requires_grad=self.requires_grad,
_prev={self},
_op="relu"
)
def _backward():
if out.grad is None or not self.requires_grad:
return
g = out.grad * (self.data > 0.0)
self.grad = g if self.grad is None else (self.grad + g)
out._backward = _backward
return out
def tensor_tanh(self) -> "Tensor":
"""Tanh: gradient is (1 - tanh^2)."""
t = np.tanh(self.data)
out = Tensor(t, requires_grad=self.requires_grad, _prev={self}, _op="tanh")
def _backward():
if out.grad is None or not self.requires_grad:
return
g = out.grad * (1.0 - t * t)
self.grad = g if self.grad is None else (self.grad + g)
out._backward = _backward
return out
Tensor.relu = tensor_relu
Tensor.tanh = tensor_tanh
```
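The `1 - tanh^2` claim in the docstring is easy to verify numerically with central differences (standalone NumPy, no `Tensor` class needed):

```python
import numpy as np

x = np.linspace(-2, 2, 9)
eps = 1e-6

# Analytic derivative used in tensor_tanh's backward
analytic = 1.0 - np.tanh(x) ** 2
# Central-difference approximation of d/dx tanh(x)
numeric = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-8))  # True
```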
### Reduction Operations
We need `sum` and `mean` for computing losses.
```{python}
def tensor_sum(self, axis=None, keepdims=False) -> "Tensor":
"""Sum over axis. Gradient broadcasts back to input shape."""
out = Tensor(
self.data.sum(axis=axis, keepdims=keepdims),
requires_grad=self.requires_grad,
_prev={self},
_op="sum"
)
def _backward():
if out.grad is None or not self.requires_grad:
return
g = out.grad
# Expand reduced axes back to input shape
if axis is None:
g = np.ones_like(self.data) * g
else:
axes = axis if isinstance(axis, tuple) else (axis,)
if not keepdims:
for ax in sorted(axes):
g = np.expand_dims(g, axis=ax)
g = np.ones_like(self.data) * g
self.grad = g if self.grad is None else (self.grad + g)
out._backward = _backward
return out
def tensor_mean(self, axis=None, keepdims=False) -> "Tensor":
"""Mean = sum / count."""
if axis is None:
denom = self.data.size
else:
axes = axis if isinstance(axis, tuple) else (axis,)
denom = np.prod([self.data.shape[ax] for ax in axes])
return self.sum(axis=axis, keepdims=keepdims) * (1.0 / float(denom))
Tensor.sum = tensor_sum
Tensor.mean = tensor_mean
# Test
x = Tensor([[1, 2, 3], [4, 5, 6]], requires_grad=True)
loss = x.mean()
print(f"Mean: {loss.data}")
loss.backward()
print(f"Gradient (each element contributes 1/6): \n{x.grad}")
```
### Putting It Together: A Tiny Neural Network
Let's use our tensor autograd to train a small network:
```{python}
# Complete example: train a 2-layer network on XOR
np.random.seed(42)
# XOR dataset
X = Tensor([[0, 0], [0, 1], [1, 0], [1, 1]], requires_grad=False)
y = Tensor([[0], [1], [1], [0]], requires_grad=False)
# Weights (small random init)
W1 = Tensor(np.random.randn(2, 8) * 0.5, requires_grad=True)
b1 = Tensor(np.zeros((1, 8)), requires_grad=True)
W2 = Tensor(np.random.randn(8, 1) * 0.5, requires_grad=True)
b2 = Tensor(np.zeros((1, 1)), requires_grad=True)
params = [W1, b1, W2, b2]
lr = 0.5
for step in range(200):
# Forward
h = X.matmul(W1) + b1 # (4, 8)
h = h.tanh() # activation
out = h.matmul(W2) + b2 # (4, 1)
# MSE loss
diff = out - y  # uses the __sub__ we defined above
loss = (diff * diff).mean()
# Backward
for p in params:
p.grad = None
loss.backward()
# SGD update
for p in params:
p.data -= lr * p.grad
if step % 50 == 0:
print(f"Step {step}: loss = {loss.data:.4f}")
print(f"\nFinal predictions:")
print(f" [0,0] -> {out.data[0,0]:.3f} (target: 0)")
print(f" [0,1] -> {out.data[1,0]:.3f} (target: 1)")
print(f" [1,0] -> {out.data[2,0]:.3f} (target: 1)")
print(f" [1,1] -> {out.data[3,0]:.3f} (target: 0)")
```
We just trained a neural network using only NumPy and our ~100 lines of autograd code!
### PyTorch's Tensor Autograd
Now let's see the same XOR problem in PyTorch:
```{python}
import torch
# Same XOR problem
X_pt = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y_pt = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
torch.manual_seed(42)
# Note: create tensor, multiply, THEN set requires_grad to keep as leaf tensors
W1_pt = (torch.randn(2, 8) * 0.5).requires_grad_(True)
b1_pt = torch.zeros(1, 8, requires_grad=True)
W2_pt = (torch.randn(8, 1) * 0.5).requires_grad_(True)
b2_pt = torch.zeros(1, 1, requires_grad=True)
params_pt = [W1_pt, b1_pt, W2_pt, b2_pt]
for step in range(200):
# Forward - identical logic!
h = (X_pt @ W1_pt + b1_pt).tanh()
out = h @ W2_pt + b2_pt
loss = ((out - y_pt) ** 2).mean()
# Backward - one line
loss.backward()
# Update - wrap in no_grad() so the SGD step itself isn't recorded in the graph
with torch.no_grad():
for p in params_pt:
p.data -= 0.5 * p.grad
p.grad = None
if step % 50 == 0:
print(f"Step {step}: loss = {loss.item():.4f}")
```
### Key Insight
The logic is **identical**. PyTorch just:
- Handles `_unbroadcast` automatically
- Uses optimized C++/CUDA kernels
- Provides convenient `loss.backward()` without manual graph traversal
- Has `torch.no_grad()` context for updates
You now understand what happens inside `requires_grad=True` and `backward()`.
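PyTorch also offers a functional alternative to `.backward()`: `torch.autograd.grad` returns the gradients directly as a tuple and leaves `.grad` untouched, which is convenient when you want gradients without mutating state:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
loss = (w * x) ** 2  # (3*2)^2 = 36

# One gradient per input, in order; .grad attributes are not touched
dx, dw = torch.autograd.grad(loss, [x, w])
print(dx.item())  # 2*(w*x)*w = 2*6*3 = 36
print(dw.item())  # 2*(w*x)*x = 2*6*2 = 24
print(x.grad)     # None - grad() does not accumulate into .grad
```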
### PyTorch Autograd (Scalar Examples)
Now let's see how PyTorch does the same thing with scalars:
```{python}
# Create tensors that track gradients
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
# Forward pass
y = w * x + b # Linear function: y = 3*2 + 1 = 7
loss = y ** 2 # Square loss: loss = 49
# Backward pass - PyTorch computes all gradients automatically
loss.backward()
# Derivation using chain rule:
# dloss/dy = 2*y = 2*7 = 14
# dloss/dx = dloss/dy * dy/dx = 14 * w = 14 * 3 = 42
# dloss/dw = dloss/dy * dy/dw = 14 * x = 14 * 2 = 28
# dloss/db = dloss/dy * dy/db = 14 * 1 = 14
print(f"y = {y.item():.1f}, loss = {loss.item():.1f}")
print(f"dloss/dx = {x.grad.item():.1f} (= 2*y*w = 2*7*3)")
print(f"dloss/dw = {w.grad.item():.1f} (= 2*y*x = 2*7*2)")
print(f"dloss/db = {b.grad.item():.1f} (= 2*y*1 = 2*7)")
```
```{python}
# With vectors
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
# Forward pass
y = (x * w).sum() # Dot product
print(f"y = x . w = {y.item():.4f}")
# Backward pass
y.backward()
print(f"\ndy/dx = {x.grad}") # Should be w
print(f"dy/dw = {w.grad}") # Should be x
```
```{python}
# With matrices (like in attention)
Q = torch.randn(2, 3, requires_grad=True) # Query
K = torch.randn(2, 3, requires_grad=True) # Key
# Attention scores: Q @ K^T
scores = Q @ K.T
loss = scores.sum()
loss.backward()
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"scores shape: {scores.shape}")
print(f"\ndloss/dQ shape: {Q.grad.shape}")
print(f"dloss/dK shape: {K.grad.shape}")
print("\nGradients have the same shape as the original tensors!")
```
## Interactive Exploration
Now that you've seen gradient computation in code, let's explore it interactively. The widget below lets you modify input values and see how gradients change in real-time — demonstrating the chain rule in action.
```{ojs}
//| echo: false
// Dark mode detection
isDark = {
const check = () => document.body.classList.contains('quarto-dark');
return check();
}
// Theme colors for light/dark mode
theme = isDark ? {
// Dark mode colors
bgPrimary: '#181a1b',
textPrimary: '#e8e6e3',
textMuted: '#a8a6a2',
// Node colors (dark mode - muted versions)
inputBg: '#1e3a5f',
inputBorder: '#6b8cae',
inputText: '#93c5fd',
opBg: '#3d3520',
opBorder: '#b89a5a',
opText: '#d4b87a',
outputBg: '#1a3d2e',
outputBorder: '#5a9a7a',
outputText: '#86efac',
gradientText: '#fca5a5',
arrowColor: '#a8a6a2'
} : {
// Light mode colors
bgPrimary: '#ffffff',
textPrimary: '#1e293b',
textMuted: '#6b7280',
// Node colors (light mode - Tailwind palette)
inputBg: '#dbeafe',
inputBorder: '#3b82f6',
inputText: '#1e40af',
opBg: '#fef3c7',
opBorder: '#f59e0b',
opText: '#92400e',
outputBg: '#dcfce7',
outputBorder: '#22c55e',
outputText: '#166534',
gradientText: '#dc2626',
arrowColor: '#6b7280'
}
```
```{ojs}
//| echo: false
// Compute forward and backward pass for: c = (a + b) * b
function computeGraph(a, b) {
// Forward pass
const sum = a + b;
const c = sum * b;
// Backward pass (chain rule)
const dc_dc = 1;
const dc_dsum = b; // d(sum * b)/d(sum) = b
const dc_db_direct = sum; // d(sum * b)/d(b) = sum (from multiplication)
const dsum_db = 1; // d(a + b)/d(b) = 1
const dc_db_via_sum = dc_dsum * dsum_db; // path through sum
const dc_db = dc_db_direct + dc_db_via_sum; // total: sum + b = a + 2b
const dc_da = dc_dsum * 1; // d(a + b)/d(a) = 1, so dc/da = b
return {
// Forward values
a, b, sum, c,
// Backward gradients
dc_dc, dc_dsum, dc_db, dc_da,
// For display
dc_db_direct, dc_db_via_sum
};
}
```
```{ojs}
//| echo: false
viewof aValue = Inputs.range([-5, 5], {
value: 2,
step: 0.5,
label: "a ="
})
viewof bValue = Inputs.range([-5, 5], {
value: 3,
step: 0.5,
label: "b ="
})
```
```{ojs}
//| echo: false
graph = computeGraph(aValue, bValue)
```
```{ojs}
//| echo: false
// Visual computation graph using HTML/SVG
computationGraphViz = html`
<div style="font-family: system-ui; margin: 20px 0; color: ${theme.textPrimary};">
<div style="display: flex; align-items: center; justify-content: center; gap: 20px; flex-wrap: wrap;">
<!-- Input nodes -->
<div style="display: flex; flex-direction: column; gap: 40px;">
<div style="background: ${theme.inputBg}; border: 2px solid ${theme.inputBorder}; border-radius: 8px; padding: 12px 20px; text-align: center; min-width: 80px;">
<div style="font-weight: bold; color: ${theme.inputText};">a</div>
<div style="font-size: 20px; color: ${theme.textPrimary};">${graph.a}</div>
<div style="font-size: 12px; color: ${theme.gradientText}; margin-top: 4px;">∂c/∂a = ${graph.dc_da}</div>
</div>
<div style="background: ${theme.inputBg}; border: 2px solid ${theme.inputBorder}; border-radius: 8px; padding: 12px 20px; text-align: center; min-width: 80px;">
<div style="font-weight: bold; color: ${theme.inputText};">b</div>
<div style="font-size: 20px; color: ${theme.textPrimary};">${graph.b}</div>
<div style="font-size: 12px; color: ${theme.gradientText}; margin-top: 4px;">∂c/∂b = ${graph.dc_db}</div>
</div>
</div>
<!-- Arrow -->
<div style="font-size: 24px; color: ${theme.arrowColor};">→</div>
<!-- Sum node -->
<div style="background: ${theme.opBg}; border: 2px solid ${theme.opBorder}; border-radius: 8px; padding: 12px 20px; text-align: center; min-width: 100px;">
<div style="font-weight: bold; color: ${theme.opText};">a + b</div>
<div style="font-size: 20px; color: ${theme.textPrimary};">${graph.sum}</div>
<div style="font-size: 12px; color: ${theme.gradientText}; margin-top: 4px;">∂c/∂(sum) = ${graph.dc_dsum}</div>
</div>
<!-- Arrow -->
<div style="font-size: 24px; color: ${theme.arrowColor};">→</div>
<!-- Multiply node -->
<div style="background: ${theme.opBg}; border: 2px solid ${theme.opBorder}; border-radius: 8px; padding: 12px 20px; text-align: center; min-width: 100px;">
<div style="font-weight: bold; color: ${theme.opText};">× b</div>
<div style="font-size: 20px; color: ${theme.textPrimary};">${graph.c}</div>
<div style="font-size: 12px; color: ${theme.gradientText}; margin-top: 4px;">∂c/∂c = ${graph.dc_dc}</div>
</div>
<!-- Arrow -->
<div style="font-size: 24px; color: ${theme.arrowColor};">→</div>
<!-- Output node -->
<div style="background: ${theme.outputBg}; border: 2px solid ${theme.outputBorder}; border-radius: 8px; padding: 12px 20px; text-align: center; min-width: 80px;">
<div style="font-weight: bold; color: ${theme.outputText};">c</div>
<div style="font-size: 24px; font-weight: bold; color: ${theme.textPrimary};">${graph.c}</div>
</div>
</div>
<!-- b connection to multiply (shown below) -->
<div style="text-align: center; margin-top: 10px; color: ${theme.textMuted}; font-size: 13px;">
<em>Note: b connects to both (+) and (×)</em>
</div>
</div>
`
```
```{ojs}
//| echo: false
// Gradient explanation
gradientExplanation = md`
### Gradient Computation
**Expression:** \`c = (a + b) * b\`
| Variable | Value | Gradient (∂c/∂?) | How It's Computed |
|----------|-------|------------------|-------------------|
| c | ${graph.c} | ${graph.dc_dc} | Output (seed = 1) |
| sum = a + b | ${graph.sum} | ${graph.dc_dsum} | ∂(sum × b)/∂sum = b = ${graph.b} |
| a | ${graph.a} | ${graph.dc_da} | ∂c/∂sum × ∂sum/∂a = ${graph.b} × 1 = ${graph.dc_da} |
| b | ${graph.b} | ${graph.dc_db} | Two paths: ${graph.dc_db_direct} (direct) + ${graph.dc_db_via_sum} (via sum) = ${graph.dc_db} |
**Key insight:** ∂c/∂b = a + 2b = ${graph.a} + 2×${graph.b} = ${graph.dc_db}. The gradient of b has **two contributions** because b appears twice in the expression!
`
```
::: {.callout-tip}
## Try This
1. **Gradient depends on values**: Change a from 2 to 4. Watch ∂c/∂b change (it equals a + 2b).
2. **b has two paths**: Notice ∂c/∂b is the sum of two contributions - one from the multiplication (= sum) and one via the addition (= b).
3. **Zero gradients**: Set b = 0. Now ∂c/∂a = 0 because the multiplication by b kills the gradient.
4. **Negative gradients**: Try negative values. Gradients can be negative, indicating the output decreases when the input increases.
5. **Verify the formula**: For any a, b values: ∂c/∂a should equal b, and ∂c/∂b should equal a + 2b.
:::
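Point 5 of the callout can be checked directly with PyTorch autograd:

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = (a + b) * b
c.backward()

print(a.grad.item())  # dc/da = b = 3
print(b.grad.item())  # dc/db = a + 2b = 8 (two paths through b)
```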
## Exercises
### Exercise 1: Verify the Gradient of x^3
```{python}
# Verify the gradient of x^3 at x=2
# Expected: d/dx[x^3] = 3x^2 = 12
x = Value(2.0, label='x')
y = x ** 3
y.backward()
print(f"x = {x.data}")
print(f"y = x^3 = {y.data}")
print(f"dy/dx = {x.grad} (expected: 12.0)")
```
### Exercise 2: Softmax Gradient
```{python}
# Compute gradient of softmax numerator
# If y = exp(x) / (exp(x) + exp(z)), what is dy/dx?
x = Value(1.0, label='x')
z = Value(2.0, label='z')
exp_x = x.exp()
exp_z = z.exp()
y = exp_x / (exp_x + exp_z)
y.backward()
print(f"x = {x.data}, z = {z.data}")
print(f"y = softmax(x)[0] = {y.data:.4f}")
print(f"dy/dx = {x.grad:.4f}")
print(f"\nThis is the gradient that flows back through softmax!")
```
### Exercise 3: ReLU vs Tanh Gradients
```{python}
# Compare ReLU vs tanh gradients
for x_val in [-2.0, -0.5, 0.5, 2.0]:
x_relu = Value(x_val)
x_tanh = Value(x_val)
y_relu = x_relu.relu()
y_tanh = x_tanh.tanh()
y_relu.backward()
y_tanh.backward()
print(f"x={x_val:5.1f}: ReLU grad={x_relu.grad:.3f}, tanh grad={x_tanh.grad:.3f}")
print("\nNotice: ReLU's gradient is either 0 or 1; tanh's is smooth but saturates toward 0 for large |x|")
```
## Backpropagation in Neural Networks
Here's how gradients flow backward through a neural network layer:
```{ojs}
//| echo: false
// Toggle between forward and backward for NN diagram
viewof nnPassView = Inputs.radio(["Forward Pass", "Backward Pass"], {
value: "Forward Pass",
label: "View"
})
```
```{ojs}
//| echo: false
// Step slider for NN pass
viewof nnStep = Inputs.range([0, 4], {
value: 0,
step: 1,
label: "Step"
})
```
```{ojs}
//| echo: false
nnBackpropDiagram = {
const width = 700;
const height = 200;
const nodeW = 85;
const nodeH = 46;
const isForward = nnPassView === "Forward Pass";
const svg = d3.create('svg')
.attr('width', width)
.attr('height', height)
.attr('viewBox', `0 0 ${width} ${height}`);
svg.append('rect')
.attr('width', width)
.attr('height', height)
.attr('fill', diagramTheme.bg)
.attr('rx', 8);
// Defs
const defs = svg.append('defs');
defs.append('marker')
.attr('id', 'nn-arrow')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.edgeStroke);
defs.append('marker')
.attr('id', 'nn-arrow-active')
.attr('viewBox', '0 -5 10 10')
.attr('refX', 8)
.attr('refY', 0)
.attr('markerWidth', 6)
.attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path')
.attr('d', 'M0,-5L10,0L0,5')
.attr('fill', diagramTheme.highlight);
// Node positions - horizontal layout
const positions = {
x: { x: 50, y: 80 },
W: { x: 50, y: 150 },
z1: { x: 170, y: 115 },
b: { x: 170, y: 180 },
z2: { x: 310, y: 115 },
y: { x: 450, y: 115 },
loss: { x: 590, y: 115 }
};
// Forward steps
const forwardSteps = [
{ nodes: ['x', 'W', 'b'], edges: [], desc: 'Inputs: x (data), W (weights), b (bias)' },
{ nodes: ['x', 'W', 'b', 'z1'], edges: ['x-z1', 'W-z1'], desc: 'z1 = W @ x (matrix multiply)' },
{ nodes: ['x', 'W', 'b', 'z1', 'z2'], edges: ['x-z1', 'W-z1', 'z1-z2', 'b-z2'], desc: 'z2 = z1 + b (add bias)' },
{ nodes: ['x', 'W', 'b', 'z1', 'z2', 'y'], edges: ['x-z1', 'W-z1', 'z1-z2', 'b-z2', 'z2-y'], desc: 'y = tanh(z2) (activation)' },
{ nodes: ['x', 'W', 'b', 'z1', 'z2', 'y', 'loss'], edges: ['x-z1', 'W-z1', 'z1-z2', 'b-z2', 'z2-y', 'y-loss'], desc: 'L = (y - target)^2 (loss)' }
];
// Backward steps with gradient formulas
const backwardSteps = [
{ nodes: ['loss'], edges: [], grads: { loss: '1' }, desc: 'Start: dL/dL = 1' },
{ nodes: ['loss', 'y'], edges: ['loss-y'], grads: { loss: '1', y: '2(y-t)' }, desc: 'dL/dy = 2(y - target)' },
{ nodes: ['loss', 'y', 'z2'], edges: ['loss-y', 'y-z2'], grads: { loss: '1', y: '2(y-t)', z2: 'dL/dy*(1-y^2)' }, desc: 'dL/dz2 = dL/dy * (1 - y^2) [tanh derivative]' },
{ nodes: ['loss', 'y', 'z2', 'z1', 'b'], edges: ['loss-y', 'y-z2', 'z2-z1', 'z2-b'], grads: { loss: '1', y: '2(y-t)', z2: 'dL/dy*(1-y^2)', z1: 'dL/dz2', b: 'dL/dz2' }, desc: 'dL/dz1 = dL/dz2, dL/db = dL/dz2 (UPDATE b!)' },
{ nodes: ['loss', 'y', 'z2', 'z1', 'b', 'x', 'W'], edges: ['loss-y', 'y-z2', 'z2-z1', 'z2-b', 'z1-x', 'z1-W'], grads: { loss: '1', y: '2(y-t)', z2: 'dL/dy*(1-y^2)', z1: 'dL/dz2', b: 'dL/dz2', W: 'dL/dz1 @ x^T' }, desc: 'dL/dW = dL/dz1 @ x^T (UPDATE W!)' }
];
const currentStep = isForward ? forwardSteps[nnStep] : backwardSteps[nnStep];
const grads = isForward ? {} : (currentStep.grads || {});
// Edge definitions
const edgeMap = {
'x-z1': ['x', 'z1'],
'W-z1': ['W', 'z1'],
'z1-z2': ['z1', 'z2'],
'b-z2': ['b', 'z2'],
'z2-y': ['z2', 'y'],
'y-loss': ['y', 'loss'],
// Backward edges (reversed)
'loss-y': ['loss', 'y'],
'y-z2': ['y', 'z2'],
'z2-z1': ['z2', 'z1'],
'z2-b': ['z2', 'b'],
'z1-x': ['z1', 'x'],
'z1-W': ['z1', 'W']
};
// Draw edge helper
const drawEdge = (fromId, toId, active) => {
const from = positions[fromId];
const to = positions[toId];
const color = active ? diagramTheme.highlight : diagramTheme.edgeStroke;
const markerId = active ? 'nn-arrow-active' : 'nn-arrow';
const dx = to.x - from.x;
const dy = to.y - from.y;
const len = Math.sqrt(dx*dx + dy*dy);
const startX = from.x + (dx/len) * (nodeW/2 + 5);
const startY = from.y + (dy/len) * (nodeH/2 + 5);
const endX = to.x - (dx/len) * (nodeW/2 + 10);
const endY = to.y - (dy/len) * (nodeH/2 + 10);
const path = svg.append('path')
.attr('d', `M${startX},${startY} L${endX},${endY}`)
.attr('fill', 'none')
.attr('stroke', color)
.attr('stroke-width', active ? 2.5 : 1.5)
.attr('marker-end', `url(#${markerId})`);
if (active) {
path.attr('filter', `drop-shadow(0 0 4px ${diagramTheme.highlightGlow})`);
}
};
// Draw all relevant edges
const allEdges = isForward ?
['x-z1', 'W-z1', 'z1-z2', 'b-z2', 'z2-y', 'y-loss'] :
['loss-y', 'y-z2', 'z2-z1', 'z2-b', 'z1-x', 'z1-W'];
allEdges.forEach(edgeId => {
const [fromId, toId] = edgeMap[edgeId];
const active = currentStep.edges.includes(edgeId);
drawEdge(fromId, toId, active);
});
// Node definitions
const nodeLabels = isForward ? {
x: { label: 'x', sublabel: '(input)' },
W: { label: 'W', sublabel: '(weights)' },
z1: { label: 'z1 = W@x', sublabel: '' },
b: { label: 'b', sublabel: '(bias)' },
z2: { label: 'z2 = z1+b', sublabel: '' },
y: { label: 'y = tanh(z2)', sublabel: '' },
loss: { label: 'L = (y-t)^2', sublabel: '' }
} : {
x: { label: 'x', sublabel: '' },
W: { label: 'W', sublabel: grads.W ? `dL/dW` : '' },
z1: { label: 'z1', sublabel: grads.z1 ? 'dL/dz1' : '' },
b: { label: 'b', sublabel: grads.b ? 'dL/db' : '' },
z2: { label: 'z2', sublabel: grads.z2 ? 'dL/dz2' : '' },
y: { label: 'y', sublabel: grads.y ? 'dL/dy' : '' },
loss: { label: 'L', sublabel: grads.loss ? 'dL/dL = 1' : '' }
};
// Mark updatable parameters in backward pass
const updatableParams = ['W', 'b'];
// Draw node helper
const drawNode = (id, active) => {
const pos = positions[id];
const { label, sublabel } = nodeLabels[id];
const isUpdatable = !isForward && updatableParams.includes(id) && currentStep.nodes.includes(id) && nnStep >= 3;
const fill = active ? diagramTheme.highlight : diagramTheme.nodeFill;
const stroke = active ? diagramTheme.highlight : (isUpdatable ? diagramTheme.accent : diagramTheme.nodeStroke);
const textColor = active ? diagramTheme.textOnHighlight : diagramTheme.nodeText;
const g = svg.append('g').attr('transform', `translate(${pos.x}, ${pos.y})`);
g.append('rect')
.attr('x', -nodeW/2)
.attr('y', -nodeH/2)
.attr('width', nodeW)
.attr('height', nodeH)
.attr('rx', 6)
.attr('fill', fill)
.attr('stroke', stroke)
.attr('stroke-width', active ? 2 : (isUpdatable ? 2.5 : 1.5));
if (active) {
g.select('rect').attr('filter', `drop-shadow(0 0 6px ${diagramTheme.highlightGlow})`);
} else if (isUpdatable) {
g.select('rect').attr('filter', `drop-shadow(0 0 4px ${diagramTheme.accentGlow})`);
}
g.append('text')
.attr('y', sublabel ? -6 : 0)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '11px')
.attr('font-weight', '500')
.text(label);
if (sublabel) {
g.append('text')
.attr('y', 10)
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'central')
.attr('fill', textColor)
.attr('font-size', '9px')
.attr('opacity', active ? 0.9 : 0.7)
.text(sublabel);
}
// Add "UPDATE!" label for parameters being updated
if (isUpdatable) {
g.append('text')
.attr('y', nodeH/2 + 12)
.attr('text-anchor', 'middle')
.attr('fill', diagramTheme.accent)
.attr('font-size', '9px')
.attr('font-weight', '600')
.text('UPDATE!');
}
};
// Draw all nodes
const allNodes = ['x', 'W', 'z1', 'b', 'z2', 'y', 'loss'];
allNodes.forEach(id => {
const active = currentStep.nodes.includes(id);
drawNode(id, active);
});
// Title
svg.append('text')
.attr('x', width / 2)
.attr('y', 20)
.attr('text-anchor', 'middle')
.attr('fill', diagramTheme.nodeText)
.attr('font-size', '13px')
.attr('font-weight', '600')
.text(isForward ? 'Forward Pass: y = tanh(W @ x + b)' : 'Backward Pass: Compute Gradients');
return svg.node();
}
```
```{ojs}
//| echo: false
// Step explanation for NN diagram
nnStepDesc = {
const forwardDescs = [
"Inputs: x (data), W (weights), b (bias)",
"z1 = W @ x (matrix multiply weights with input)",
"z2 = z1 + b (add bias term)",
"y = tanh(z2) (apply activation function)",
"L = (y - target)^2 (compute loss)"
];
const backwardDescs = [
"Start: dL/dL = 1 (seed gradient)",
"dL/dy = 2(y - target) [derivative of squared loss]",
"dL/dz2 = dL/dy * (1 - y^2) [tanh derivative: 1 - tanh^2]",
"dL/dz1 = dL/dz2, dL/db = dL/dz2. **Update b!**",
"dL/dW = dL/dz1 @ x^T [matmul gradient]. **Update W!**"
];
const descs = nnPassView === "Forward Pass" ? forwardDescs : backwardDescs;
return descs[nnStep];
}
md`**Step ${nnStep}:** ${nnStepDesc}`
```
Key insights:
1. **Gradients flow backward**: Starting from the loss, we trace back through every operation
2. **Chain rule connects layers**: Multiply local gradients along the path
3. **Accumulation**: If a value is used multiple times, gradients add up
4. **Shape matters**: dL/dW must have the same shape as W for the update
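Insight 4 is easy to confirm in PyTorch: after `backward()`, each parameter's gradient has exactly the parameter's shape, so the SGD update `W -= lr * W.grad` is well defined. A small check mirroring the diagram's layer (arbitrary sizes, one sample as a column vector):

```python
import torch

x = torch.randn(4, 1)                       # input
W = torch.randn(3, 4, requires_grad=True)   # weights
b = torch.randn(3, 1, requires_grad=True)   # bias
target = torch.randn(3, 1)

y = torch.tanh(W @ x + b)
loss = ((y - target) ** 2).sum()
loss.backward()

# Gradients match parameter shapes, as required for the update
print(W.grad.shape)  # torch.Size([3, 4])
print(b.grad.shape)  # torch.Size([3, 1])
```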
## Detaching Tensors and Stopping Gradients
Sometimes you want to stop gradient flow. PyTorch provides several mechanisms:
### Using detach()
```{python}
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
z = y.detach() # z has no gradient history
print(f"y.requires_grad: {y.requires_grad}")
print(f"z.requires_grad: {z.requires_grad}")
# z is now a "constant" - no gradients flow through it
loss = z.sum()
# loss.backward() here would raise a RuntimeError: the graph was cut at
# detach(), so loss has no grad_fn and no gradients reach x
```
### Using no_grad()
```{python}
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
# Compute without building graph (saves memory)
with torch.no_grad():
z = y * 3
print(f"Inside no_grad, z.requires_grad: {z.requires_grad}")
# Common use: inference
model_output = y # pretend this is model output
with torch.no_grad():
prediction = model_output.argmax() # no gradients needed
```
### When to Stop Gradients
- **Inference**: No training, no gradients needed
- **Frozen layers**: Transfer learning with some layers fixed
- **Metrics**: Computing accuracy, loss for logging (not backprop)
- **Target values**: In losses like MSE, the target is a constant
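A minimal sketch of the frozen-layers case, assuming a small `nn.Sequential` model (the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first linear layer: its parameters opt out of the graph
for p in model[0].parameters():
    p.requires_grad_(False)

x = torch.randn(5, 4)
loss = model(x).sum()
loss.backward()

print(model[0].weight.grad)        # None - frozen, no gradient computed
print(model[2].weight.grad.shape)  # torch.Size([2, 8]) - still trained
```

An optimizer built over `filter(lambda p: p.requires_grad, model.parameters())` then updates only the unfrozen layers.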
## Memory Implications
### Why Autograd Uses Memory
During the forward pass, autograd stores intermediate values needed for gradient computation:
```{python}
# Each operation stores data for backward pass
x = torch.randn(1000, 1000, requires_grad=True)
y = x @ x.T # Stores x for gradient computation
z = y.relu() # Stores y (to know which elements > 0)
out = z.mean() # Stores z
# All of x, y, z stay in memory until backward() completes
out.backward()
# Now intermediate tensors can be freed
```
### Memory Grows with Depth
For a model with L layers:
- Forward pass: O(L) memory for activations
- Each activation tensor can be large (batch_size x hidden_dim)
- LLMs with 100+ layers and large hidden dims = huge memory
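A back-of-the-envelope estimate makes this concrete. The numbers below are illustrative, not taken from any particular model:

```python
# Rough activation-memory estimate: one float32 tensor saved per layer.
# All sizes are illustrative placeholders.
batch, seq_len, hidden, n_layers = 8, 2048, 4096, 96
bytes_per_float = 4

per_layer = batch * seq_len * hidden * bytes_per_float
total = per_layer * n_layers

print(f"Per layer: {per_layer / 1e9:.2f} GB")   # ~0.27 GB
print(f"All layers: {total / 1e9:.2f} GB")      # ~25.77 GB
```

Real transformers save several activations per layer (attention scores, MLP intermediates), so the true figure is a multiple of this, which is why the checkpointing trick below matters.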
### Gradient Checkpointing
Trade compute for memory by recomputing activations during backward:
```{python}
from torch.utils.checkpoint import checkpoint
def expensive_layer(x):
"""A layer we want to checkpoint."""
return x.relu().pow(2)
x = torch.randn(100, 100, requires_grad=True)
# Without checkpointing: stores intermediate activations
y_normal = expensive_layer(x)
# With checkpointing: discards intermediates, recomputes during backward
y_checkpoint = checkpoint(expensive_layer, x, use_reentrant=False)
print(f"Both produce same result: {torch.allclose(y_normal, y_checkpoint)}")
```
In practice, checkpointing every ~√L layers cuts activation memory from O(L) to O(√L), at the cost of roughly one extra forward pass during backward.
## Common Pitfalls
### 1. Forgetting to Zero Gradients
```{python}
x = torch.tensor(2.0, requires_grad=True)
# First backward
y = x ** 2
y.backward()
print(f"After first backward: x.grad = {x.grad}")
# Second backward without zeroing - gradients ACCUMULATE!
y = x ** 2
y.backward()
print(f"After second backward: x.grad = {x.grad}") # 8.0, not 4.0!
# Always zero gradients before backward in training loops
x.grad.zero_()
y = x ** 2
y.backward()
print(f"After zeroing: x.grad = {x.grad}") # Back to 4.0
```
### 2. In-place Operations
```{python}
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
# This would break the graph (commented to avoid error):
# y.add_(1) # In-place operation on a tensor needed for gradient
# Instead, use out-of-place operations:
y = y + 1 # Creates new tensor, graph preserved
y.sum().backward()
print(f"x.grad: {x.grad}")
```
### 3. Losing requires_grad
```{python}
x = torch.tensor([1.0, 2.0], requires_grad=True)
# Can't call .numpy() directly on a tensor requiring grad:
# y = x.numpy() # RuntimeError: Can't call numpy() on tensor that requires grad
# Must detach first - this explicitly breaks the graph
y = x.detach().numpy()
print(f"y is now numpy array: {type(y)}")
# Converting back loses gradient tracking:
z = torch.from_numpy(y)
print(f"z.requires_grad: {z.requires_grad}") # False
# Be careful when mixing numpy and autograd
```
### 4. Backward Through Non-Scalar
```{python}
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2 # Vector output
# This fails (commented to avoid error):
# y.backward() # RuntimeError: need gradient argument
# For non-scalar outputs, provide a gradient tensor:
y.backward(torch.ones_like(y)) # Equivalent to y.sum().backward()
print(f"x.grad: {x.grad}")
```
## Summary
Key takeaways:
1. **Gradients flow backward** through the computational graph
2. **Chain rule** connects everything: multiply local gradients
3. **Gradients accumulate** when a value is used multiple times
4. **Training = gradient descent**: use gradients to adjust parameters
5. **PyTorch autograd** does this automatically for any computation
6. **Memory matters**: intermediate activations consume memory; use `detach()`, `no_grad()`, or gradient checkpointing
7. **Zero gradients**: always zero gradients before each backward pass in training loops
## What's Next
In [Module 03: Tokenization](../m03_tokenization/lesson.qmd), we'll learn how text gets converted into numbers that our model can process. Autograd will be working behind the scenes during training, but we need input data first!