---
title: "Module 08: Generation"
format:
html:
code-fold: false
toc: true
ipynb: default
jupyter: python3
---
## Introduction
After training a language model, we want to generate text. The model outputs a probability distribution over the next token - but how do we choose which token to use? This module explores the major decoding strategies and shows how that choice shapes the generated text.
**Text generation** is the process of producing new text from a trained model. The model predicts probability distributions over the vocabulary, and we must decide how to select the next token from these probabilities.
Why does the decoding strategy matter?
- **Different strategies, different outputs**: Greedy decoding gives deterministic results; sampling gives variety
- **Control creativity vs coherence**: Temperature and filtering parameters let us tune this tradeoff
- **Application-specific needs**: Code generation wants precision; creative writing wants diversity
- **Avoid degenerate outputs**: Poor settings lead to repetition, incoherence, or nonsense
Understanding generation is essential for building LLM applications.
### What You'll Learn
By the end of this module, you will be able to:
- Implement the autoregressive generation loop from scratch
- Apply and combine decoding strategies (greedy, temperature, top-k, top-p)
- Use repetition penalties to prevent degenerate outputs
- Understand KV-caching for efficient generation
- Choose appropriate generation parameters for different use cases
## The Generation Loop
Text generation is autoregressive - we generate one token at a time, feeding previous tokens back to the model:
{{< include ../_diagram-lib.qmd >}}
```{ojs}
//| echo: false
// Generation step data - each step shows the state of generation
generationSteps = [
{
step: 0,
tokens: ["BOS", "The", "cat"],
newToken: null,
description: "Start with prompt tokens",
phase: "input",
probabilities: null
},
{
step: 1,
tokens: ["BOS", "The", "cat"],
newToken: null,
description: "Feed tokens to transformer model",
phase: "forward",
probabilities: null
},
{
step: 2,
tokens: ["BOS", "The", "cat"],
newToken: null,
description: "Model outputs probability distribution",
phase: "output",
probabilities: [
{token: "sat", prob: 0.42},
{token: "is", prob: 0.18},
{token: "ran", prob: 0.15},
{token: "...", prob: 0.25}
]
},
{
step: 3,
tokens: ["BOS", "The", "cat"],
newToken: "sat",
description: "Apply decoding strategy, select 'sat'",
phase: "select",
probabilities: [
{token: "sat", prob: 0.42, selected: true},
{token: "is", prob: 0.18},
{token: "ran", prob: 0.15},
{token: "...", prob: 0.25}
]
},
{
step: 4,
tokens: ["BOS", "The", "cat", "sat"],
newToken: "sat",
description: "Append selected token to sequence",
phase: "append",
probabilities: null
},
{
step: 5,
tokens: ["BOS", "The", "cat", "sat"],
newToken: null,
description: "Feed extended sequence to model",
phase: "forward",
probabilities: null
},
{
step: 6,
tokens: ["BOS", "The", "cat", "sat"],
newToken: null,
description: "Model outputs new probabilities",
phase: "output",
probabilities: [
{token: "on", prob: 0.51},
{token: "down", prob: 0.22},
{token: "still", prob: 0.12},
{token: "...", prob: 0.15}
]
},
{
step: 7,
tokens: ["BOS", "The", "cat", "sat"],
newToken: "on",
description: "Select 'on' from distribution",
phase: "select",
probabilities: [
{token: "on", prob: 0.51, selected: true},
{token: "down", prob: 0.22},
{token: "still", prob: 0.12},
{token: "...", prob: 0.15}
]
},
{
step: 8,
tokens: ["BOS", "The", "cat", "sat", "on"],
newToken: "on",
description: "Sequence grows with each iteration",
phase: "append",
probabilities: null
},
{
step: 9,
tokens: ["BOS", "The", "cat", "sat", "on", "..."],
newToken: null,
description: "Continue until max_len or EOS token",
phase: "done",
probabilities: null
}
]
// Step slider control
viewof genStep = Inputs.range([0, generationSteps.length - 1], {
value: 0,
step: 1,
label: "Generation Step"
})
// Current state
currentGenState = generationSteps[genStep]
```
```{ojs}
//| echo: false
// Autoregressive generation loop visualization
{
const width = 700;
const height = 420;
const state = currentGenState;
const theme = diagramTheme;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "15px")
.attr("font-weight", "600")
.attr("letter-spacing", "0.5px")
.text("The Autoregressive Generation Loop");
// ===== TOKEN SEQUENCE (TOP) =====
const tokenY = 70;
const tokenWidth = 58;
const tokenHeight = 32;
const tokenGap = 6;
const tokens = state.tokens;
const totalTokenWidth = tokens.length * (tokenWidth + tokenGap) - tokenGap;
const tokenStartX = (width - totalTokenWidth) / 2;
// Token sequence label
svg.append("text")
.attr("x", tokenStartX - 10)
.attr("y", tokenY + tokenHeight / 2)
.attr("text-anchor", "end")
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("Sequence:");
// Draw tokens
tokens.forEach((tok, i) => {
const x = tokenStartX + i * (tokenWidth + tokenGap);
const isNew = state.newToken && tok === state.newToken && i === tokens.length - 1;
const isPrompt = i < 3 && state.step < 4;
let fill = theme.nodeFill;
let stroke = theme.nodeStroke;
let textColor = theme.nodeText;
if (isNew) {
fill = theme.accent;
stroke = theme.accent;
textColor = theme.textOnAccent;
} else if (isPrompt && state.phase === "input") {
fill = theme.highlight;
stroke = theme.highlight;
textColor = theme.textOnHighlight;
}
const g = svg.append("g")
.attr("transform", `translate(${x}, ${tokenY})`);
g.append("rect")
.attr("width", tokenWidth)
.attr("height", tokenHeight)
.attr("rx", 6)
.attr("fill", fill)
.attr("stroke", stroke)
.attr("stroke-width", isNew ? 2 : 1.5);
if (isNew) {
g.select("rect")
.attr("filter", `drop-shadow(0 0 8px ${theme.accentGlow})`);
}
g.append("text")
.attr("x", tokenWidth / 2)
.attr("y", tokenHeight / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", textColor)
.attr("font-size", "12px")
.attr("font-weight", "500")
.text(tok);
});
// ===== TRANSFORMER MODEL (MIDDLE) =====
const modelY = 160;
const modelWidth = 180;
const modelHeight = 70;
const modelX = width / 2 - modelWidth / 2;
const modelActive = state.phase === "forward";
const modelFill = modelActive ? theme.highlight : theme.nodeFill;
const modelStroke = modelActive ? theme.highlight : theme.nodeStroke;
const modelText = modelActive ? theme.textOnHighlight : theme.nodeText;
const modelG = svg.append("g")
.attr("transform", `translate(${modelX}, ${modelY})`);
modelG.append("rect")
.attr("width", modelWidth)
.attr("height", modelHeight)
.attr("rx", 10)
.attr("fill", modelFill)
.attr("stroke", modelStroke)
.attr("stroke-width", modelActive ? 2.5 : 1.5);
if (modelActive) {
modelG.select("rect")
.attr("filter", `drop-shadow(0 0 10px ${theme.highlightGlow})`);
}
modelG.append("text")
.attr("x", modelWidth / 2)
.attr("y", modelHeight / 2 - 8)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", modelText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Transformer");
modelG.append("text")
.attr("x", modelWidth / 2)
.attr("y", modelHeight / 2 + 12)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", modelText)
.attr("font-size", "11px")
.attr("opacity", modelActive ? 0.9 : 0.6)
.text("Language Model");
// Arrow from tokens to model
const arrowInY1 = tokenY + tokenHeight + 8;
const arrowInY2 = modelY - 8;
svg.append("path")
.attr("d", `M${width/2},${arrowInY1} L${width/2},${arrowInY2 - 6}`)
.attr("stroke", state.phase === "forward" || state.phase === "input" ? theme.highlight : theme.edgeStroke)
.attr("stroke-width", 2)
.attr("fill", "none")
.attr("marker-end", `url(#arrowhead-${state.phase === "forward" ? "highlight" : "normal"})`);
// ===== PROBABILITY OUTPUT (BOTTOM) =====
const probY = 270;
const probBoxWidth = 280;
const probBoxHeight = 90;
const probX = width / 2 - probBoxWidth / 2;
const probActive = state.phase === "output" || state.phase === "select";
const probFill = probActive ? theme.bgSecondary : theme.nodeFill;
const probG = svg.append("g")
.attr("transform", `translate(${probX}, ${probY})`);
probG.append("rect")
.attr("width", probBoxWidth)
.attr("height", probBoxHeight)
.attr("rx", 8)
.attr("fill", probFill)
.attr("stroke", probActive ? theme.highlight : theme.nodeStroke)
.attr("stroke-width", probActive ? 2 : 1.5)
.attr("stroke-dasharray", probActive ? "none" : "4,2");
probG.append("text")
.attr("x", probBoxWidth / 2)
.attr("y", 18)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text("P(next token)");
// Show probability bars if available
if (state.probabilities) {
const barWidth = 50;
const barMaxHeight = 45;
const barGap = 12;
const probs = state.probabilities;
const totalBarWidth = probs.length * (barWidth + barGap) - barGap;
const barStartX = (probBoxWidth - totalBarWidth) / 2;
probs.forEach((p, i) => {
const bx = barStartX + i * (barWidth + barGap);
const barHeight = p.prob * barMaxHeight;
const by = 70 - barHeight;
let barFill = theme.edgeStroke;
if (p.selected) {
barFill = theme.accent;
}
probG.append("rect")
.attr("x", bx)
.attr("y", by)
.attr("width", barWidth)
.attr("height", barHeight)
.attr("rx", 3)
.attr("fill", barFill)
.attr("opacity", p.selected ? 1 : 0.6);
if (p.selected) {
probG.append("rect")
.attr("x", bx)
.attr("y", by)
.attr("width", barWidth)
.attr("height", barHeight)
.attr("rx", 3)
.attr("fill", "none")
.attr("stroke", theme.accent)
.attr("stroke-width", 2)
.attr("filter", `drop-shadow(0 0 6px ${theme.accentGlow})`);
}
probG.append("text")
.attr("x", bx + barWidth / 2)
.attr("y", 78)
.attr("text-anchor", "middle")
.attr("fill", p.selected ? theme.accent : theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", p.selected ? "600" : "400")
.text(p.token);
});
} else {
probG.append("text")
.attr("x", probBoxWidth / 2)
.attr("y", 50)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.5)
.text("Waiting for model output...");
}
// Arrow from model to probabilities
const arrowOutY1 = modelY + modelHeight + 8;
const arrowOutY2 = probY - 8;
svg.append("path")
.attr("d", `M${width/2},${arrowOutY1} L${width/2},${arrowOutY2 - 6}`)
.attr("stroke", state.phase === "output" ? theme.highlight : theme.edgeStroke)
.attr("stroke-width", 2)
.attr("fill", "none")
.attr("marker-end", `url(#arrowhead-${state.phase === "output" ? "highlight" : "normal"})`);
// ===== LOOP ARROW (right side) =====
if (state.phase !== "done") {
const loopX = width - 80;
const loopStartY = probY + probBoxHeight / 2;
const loopEndY = tokenY + tokenHeight / 2;
const loopActive = state.phase === "append";
const loopPath = `M${probX + probBoxWidth + 10},${loopStartY}
L${loopX},${loopStartY}
L${loopX},${loopEndY}
L${tokenStartX + totalTokenWidth + 15},${loopEndY}`;
svg.append("path")
.attr("d", loopPath)
.attr("stroke", loopActive ? theme.accent : theme.edgeStroke)
.attr("stroke-width", loopActive ? 2.5 : 1.5)
.attr("fill", "none")
.attr("stroke-dasharray", loopActive ? "none" : "6,4")
.attr("marker-end", `url(#arrowhead-${loopActive ? "accent" : "normal"})`);
svg.append("text")
.attr("x", loopX + 8)
.attr("y", (loopStartY + loopEndY) / 2)
.attr("fill", loopActive ? theme.accent : theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", loopActive ? 1 : 0.6)
.text("append");
}
// ===== STEP DESCRIPTION =====
const descY = 395;
svg.append("text")
.attr("x", width / 2)
.attr("y", descY)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "13px")
.attr("font-weight", "500")
.text(`Step ${state.step}: ${state.description}`);
// ===== ARROWHEAD MARKERS =====
const defs = svg.append("defs");
// Normal arrowhead
defs.append("marker")
.attr("id", "arrowhead-normal")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.edgeStroke);
// Highlight arrowhead
defs.append("marker")
.attr("id", "arrowhead-highlight")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.highlight);
// Accent arrowhead
defs.append("marker")
.attr("id", "arrowhead-accent")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.accent);
return svg.node();
}
```
Each step:
1. Feed current tokens to the model
2. Get probability distribution over vocabulary for next position
3. Apply decoding strategy to select next token
4. Append selected token to sequence
5. Repeat until a stopping criterion is met (an EOS token or a maximum length)
```{ojs}
//| echo: false
// Step slider for step-by-step generation
viewof stepGenStep = Inputs.range([1, 4], {
step: 1,
value: 1,
label: "Generation Step"
})
// Step data for the visualization
stepGenData = [
{
step: 1,
tokens: ["BOS", "The", "cat"],
probs: [
{ token: "sat", prob: 0.40, selected: true },
{ token: "is", prob: 0.22 },
{ token: "ran", prob: 0.18 },
{ token: "...", prob: 0.20 }
],
selected: "sat"
},
{
step: 2,
tokens: ["BOS", "The", "cat", "sat"],
probs: [
{ token: "on", prob: 0.50, selected: true },
{ token: "down", prob: 0.22 },
{ token: "still", prob: 0.13 },
{ token: "...", prob: 0.15 }
],
selected: "on"
},
{
step: 3,
tokens: ["BOS", "The", "cat", "sat", "on"],
probs: [
{ token: "the", prob: 0.55, selected: true },
{ token: "a", prob: 0.25 },
{ token: "my", prob: 0.10 },
{ token: "...", prob: 0.10 }
],
selected: "the"
},
{
step: 4,
tokens: ["BOS", "The", "cat", "sat", "on", "the"],
probs: [
{ token: "mat", prob: 0.38, selected: true },
{ token: "floor", prob: 0.28 },
{ token: "bed", prob: 0.19 },
{ token: "...", prob: 0.15 }
],
selected: "mat"
}
]
currentStepGen = stepGenData[stepGenStep - 1]
```
```{ojs}
//| echo: false
// Step-by-step generation diagram
{
const width = 700;
const height = 260;
const theme = diagramTheme;
const data = currentStepGen;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 24)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(`Step ${data.step}: Generate next token`);
// === INPUT TOKENS SECTION ===
const tokenY = 55;
const tokenWidth = 52;
const tokenHeight = 28;
const tokenGap = 5;
const tokens = data.tokens;
const totalTokenWidth = tokens.length * (tokenWidth + tokenGap) - tokenGap;
const tokenStartX = 30;
// Label
svg.append("text")
.attr("x", tokenStartX)
.attr("y", tokenY - 8)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.attr("opacity", 0.8)
.text("Input sequence:");
// Draw tokens
tokens.forEach((tok, i) => {
const x = tokenStartX + i * (tokenWidth + tokenGap);
const isNew = i === tokens.length - 1 && data.step > 1;
const g = svg.append("g")
.attr("transform", `translate(${x}, ${tokenY})`);
g.append("rect")
.attr("width", tokenWidth)
.attr("height", tokenHeight)
.attr("rx", 5)
.attr("fill", isNew ? theme.accent : theme.nodeFill)
.attr("stroke", isNew ? theme.accent : theme.nodeStroke)
.attr("stroke-width", isNew ? 2 : 1.5);
if (isNew) {
g.select("rect")
.attr("filter", `drop-shadow(0 0 6px ${theme.accentGlow})`);
}
g.append("text")
.attr("x", tokenWidth / 2)
.attr("y", tokenHeight / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", isNew ? theme.textOnAccent : theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text(tok);
});
// === ARROW TO PROBABILITY ===
const arrowStartX = tokenStartX + totalTokenWidth + 15;
const arrowEndX = 380;
const arrowY = tokenY + tokenHeight / 2;
svg.append("path")
.attr("d", `M${arrowStartX},${arrowY} L${arrowEndX - 8},${arrowY}`)
.attr("stroke", theme.highlight)
.attr("stroke-width", 2)
.attr("fill", "none")
.attr("marker-end", "url(#stepgen-arrow)");
svg.append("text")
.attr("x", (arrowStartX + arrowEndX) / 2)
.attr("y", arrowY - 10)
.attr("text-anchor", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text("Model forward pass");
// === PROBABILITY DISTRIBUTION ===
const probX = 390;
const probY = 40;
const probWidth = 280;
const probHeight = 100;
// Label
svg.append("text")
.attr("x", probX)
.attr("y", probY)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.attr("opacity", 0.8)
.text("P(next token):");
// Probability bars
const barWidth = 55;
const barMaxHeight = 55;
const barGap = 10;
const probs = data.probs;
const barStartX = probX;
const barBaseY = probY + 75;
probs.forEach((p, i) => {
const bx = barStartX + i * (barWidth + barGap);
const barHeight = p.prob * barMaxHeight * 1.8;
const by = barBaseY - barHeight;
// Bar
svg.append("rect")
.attr("x", bx)
.attr("y", by)
.attr("width", barWidth)
.attr("height", barHeight)
.attr("rx", 4)
.attr("fill", p.selected ? theme.accent : theme.edgeStroke)
.attr("opacity", p.selected ? 1 : 0.5);
if (p.selected) {
svg.append("rect")
.attr("x", bx)
.attr("y", by)
.attr("width", barWidth)
.attr("height", barHeight)
.attr("rx", 4)
.attr("fill", "none")
.attr("stroke", theme.accent)
.attr("stroke-width", 2)
.attr("filter", `drop-shadow(0 0 5px ${theme.accentGlow})`);
}
// Probability value
svg.append("text")
.attr("x", bx + barWidth / 2)
.attr("y", by - 5)
.attr("text-anchor", "middle")
.attr("fill", p.selected ? theme.accent : theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", p.selected ? "600" : "400")
.text((p.prob * 100).toFixed(0) + "%");
// Token label
svg.append("text")
.attr("x", bx + barWidth / 2)
.attr("y", barBaseY + 14)
.attr("text-anchor", "middle")
.attr("fill", p.selected ? theme.accent : theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", p.selected ? "600" : "400")
.text(p.token);
});
// === SELECTION RESULT ===
const resultY = 175;
const resultBoxWidth = 200;
const resultBoxHeight = 36;
const resultX = width / 2 - resultBoxWidth / 2;
// Arrow down to result
svg.append("path")
.attr("d", `M${probX + 130},${barBaseY + 25} L${probX + 130},${resultY - 5}`)
.attr("stroke", theme.accent)
.attr("stroke-width", 2)
.attr("fill", "none")
.attr("marker-end", "url(#stepgen-arrow-accent)");
// Result box
svg.append("rect")
.attr("x", resultX)
.attr("y", resultY)
.attr("width", resultBoxWidth)
.attr("height", resultBoxHeight)
.attr("rx", 6)
.attr("fill", theme.bgSecondary)
.attr("stroke", theme.accent)
.attr("stroke-width", 2);
svg.append("text")
.attr("x", resultX + resultBoxWidth / 2)
.attr("y", resultY + resultBoxHeight / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", theme.accent)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text(`Select: "${data.selected}"`);
// === STEP INDICATOR ===
const stepIndicatorY = 235;
const stepIndicatorWidth = 280;
const stepStartX = width / 2 - stepIndicatorWidth / 2;
// Step dots
for (let i = 1; i <= 4; i++) {
const dotX = stepStartX + (i - 1) * 70 + 35;
const isActive = i === data.step;
const isPast = i < data.step;
svg.append("circle")
.attr("cx", dotX)
.attr("cy", stepIndicatorY)
.attr("r", isActive ? 10 : 7)
.attr("fill", isActive ? theme.accent : isPast ? theme.highlight : theme.nodeFill)
.attr("stroke", isActive ? theme.accent : isPast ? theme.highlight : theme.nodeStroke)
.attr("stroke-width", isActive ? 2 : 1.5);
if (isActive) {
svg.append("circle")
.attr("cx", dotX)
.attr("cy", stepIndicatorY)
.attr("r", 10)
.attr("fill", "none")
.attr("stroke", theme.accent)
.attr("stroke-width", 2)
.attr("filter", `drop-shadow(0 0 4px ${theme.accentGlow})`);
}
svg.append("text")
.attr("x", dotX)
.attr("y", stepIndicatorY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", isActive ? theme.textOnAccent : isPast ? theme.textOnHighlight : theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text(i);
// Connecting line
if (i < 4) {
svg.append("line")
.attr("x1", dotX + 12)
.attr("y1", stepIndicatorY)
.attr("x2", dotX + 58)
.attr("y2", stepIndicatorY)
.attr("stroke", i < data.step ? theme.highlight : theme.nodeStroke)
.attr("stroke-width", 1.5)
.attr("stroke-dasharray", i < data.step ? "none" : "4,3");
}
}
// === ARROW MARKERS ===
const defs = svg.append("defs");
defs.append("marker")
.attr("id", "stepgen-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.highlight);
defs.append("marker")
.attr("id", "stepgen-arrow-accent")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", theme.accent);
return svg.node();
}
```
### From Scratch: The Generation Loop
Let's build generation from the ground up. At its core, generation is just **repeated next-token prediction with sampling**.
```{python}
import numpy as np
def softmax(x: np.ndarray) -> np.ndarray:
"""Stable softmax: subtract max to prevent overflow."""
x_max = x.max(axis=-1, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / exp_x.sum(axis=-1, keepdims=True)
def generate_scratch(get_logits, context: np.ndarray, max_new_tokens: int = 10, temperature: float = 1.0) -> np.ndarray:
"""
Generate tokens autoregressively from scratch.
Args:
get_logits: Function that takes context (1, seq_len) and returns logits (1, vocab_size)
context: Starting token ids, shape (1, seq_len)
max_new_tokens: How many tokens to generate
temperature: Sampling temperature (higher = more random)
Returns:
Extended context with generated tokens
"""
ctx = context.copy()
for _ in range(max_new_tokens):
# 1. Get logits for last position
logits = get_logits(ctx) # (1, vocab_size)
# 2. Apply temperature (scale before softmax)
logits = logits / temperature
# 3. Convert to probabilities
probs = softmax(logits)[0] # (vocab_size,)
# 4. Sample next token
next_token = np.random.choice(len(probs), p=probs)
# 5. Append to context
ctx = np.concatenate([ctx, [[next_token]]], axis=1)
return ctx
```
**Key insight:** Generation is surprisingly simple. The model predicts, we sample, and we feed the result back in. That's it.
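To see the loop run end to end, we can plug in a hypothetical toy "model": a fixed lookup table that maps the last token id to next-token logits (a bigram model, essentially). This cell is self-contained - the softmax is repeated inline so it runs on its own, and `logit_table` is an illustrative stand-in, not a trained model:

```{python}
import numpy as np

def _softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical "model": next-token logits depend only on the last token id
rng = np.random.default_rng(0)
logit_table = rng.normal(size=(6, 6))  # toy vocabulary of 6 tokens

ctx = np.array([[0, 1]])  # starting prompt
for _ in range(8):
    logits = logit_table[ctx[0, -1]][None, :]            # 1. predict (1, vocab)
    probs = _softmax(logits / 1.0)[0]                    # 2. temperature + softmax
    next_token = rng.choice(6, p=probs)                  # 3. sample
    ctx = np.concatenate([ctx, [[next_token]]], axis=1)  # 4. append
print(ctx[0].tolist())  # 2 prompt tokens + 8 generated
```

Swapping in a real transformer only changes step 1; the predict-sample-append cycle stays the same.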
### From Logits to Tokens
The logits-to-token pipeline is the heart of generation:
```{python}
# Step-by-step: logits -> probabilities -> token
# Simulate model output (logits for 8-token vocabulary)
logits = np.array([[2.0, 1.5, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0]])
token_names = ["the", "cat", "sat", "on", "mat", "dog", "ran", "fast"]
print("Step 1: Raw logits from model")
for name, logit in zip(token_names, logits[0]):
    print(f" {name:>4}: {logit:+.1f}")
print("\nStep 2: Apply softmax to get probabilities")
probs = softmax(logits)[0]
for name, prob in zip(token_names, probs):
bar = "█" * int(prob * 40)
print(f" {name:>4}: {prob:.3f} {bar}")
print("\nStep 3: Sample from the distribution")
np.random.seed(42)
sampled_idx = np.random.choice(len(probs), p=probs)
print(f" Sampled token: '{token_names[sampled_idx]}' (index {sampled_idx})")
```
**Temperature** controls the randomness by scaling logits before softmax:
```{python}
print("Effect of temperature on the same logits:\n")
for temp in [0.5, 1.0, 2.0]:
scaled_logits = logits / temp
probs = softmax(scaled_logits)[0]
print(f"Temperature = {temp}:")
for name, prob in zip(token_names[:4], probs[:4]): # Show top 4
bar = "█" * int(prob * 30)
print(f" {name:>4}: {prob:.3f} {bar}")
print()
print("Lower temp = sharper (more deterministic)")
print("Higher temp = flatter (more random)")
```
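The sharp-vs-flat effect can be quantified with Shannon entropy: 0 bits means one token is certain, while a uniform distribution over 8 tokens reaches log2(8) = 3 bits. A small check on the same eight-token logits (the softmax is defined inline here so the cell stands alone):

```{python}
import numpy as np

def softmax_1d(x):
    e = np.exp(x - x.max())
    return e / e.sum()

demo_logits = np.array([2.0, 1.5, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0])
for temp in [0.5, 1.0, 2.0]:
    p = softmax_1d(demo_logits / temp)
    entropy = -np.sum(p * np.log2(p))  # bits; higher = closer to uniform
    print(f"T={temp}: entropy = {entropy:.2f} bits")
```

Entropy rises monotonically with temperature, which is the creativity-vs-coherence dial in a single number.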
Now let's see it in action with a real model.
## Code Walkthrough
Let's explore generation interactively:
```{python}
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
print(f"PyTorch version: {torch.__version__}")
```
### Setting Up
```{python}
import sys
sys.path.insert(0, '..')
from generation import (
top_k_filtering,
top_p_filtering,
apply_repetition_penalty,
generate,
generate_greedy,
generate_sample,
get_token_probabilities,
get_top_tokens,
)
from m06_transformer.transformer import create_gpt_tiny
# Create a small model for demonstration
vocab_size = 50
model = create_gpt_tiny(vocab_size=vocab_size)
# Create a sample prompt
prompt = torch.randint(0, vocab_size, (1, 5))
print(f"Prompt tokens: {prompt[0].tolist()}")
```
### Understanding Model Output
A language model outputs logits (unnormalized scores) that become probabilities after softmax:
```{python}
# Get probability distribution for next token
probs = get_token_probabilities(model, prompt)
print(f"Probability distribution shape: {probs.shape}")
print(f"Sum of probabilities: {probs.sum().item():.4f}")
# Visualize the distribution
plt.figure(figsize=(12, 4))
plt.bar(range(vocab_size), probs[0].numpy())
plt.xlabel('Token ID')
plt.ylabel('Probability')
plt.title('Next Token Probability Distribution')
plt.grid(True, alpha=0.3)
plt.show()
# Show top tokens
top = get_top_tokens(probs, k=5)
print("\nTop 5 most likely next tokens:")
for token_id, prob in top:
print(f" Token {token_id}: {prob*100:.2f}%")
```
## Decoding Strategies
### 1. Greedy Decoding
Always pick the token with the highest probability - simple but can be repetitive.
```{ojs}
//| echo: false
// Greedy decoding interactive visualization
{
const width = 580;
const height = 300;
const margin = { top: 45, right: 30, bottom: 70, left: 55 };
const chartWidth = width - margin.left - margin.right;
const chartHeight = height - margin.top - margin.bottom;
// Token probabilities as specified
const tokenData = [
{ token: "the", prob: 0.40 },
{ token: "a", prob: 0.30 },
{ token: "this", prob: 0.20 },
{ token: "it", prob: 0.07 },
{ token: "an", prob: 0.03 }
];
// Find max probability token
const maxProb = Math.max(...tokenData.map(d => d.prob));
const maxToken = tokenData.find(d => d.prob === maxProb);
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Sans', system-ui, sans-serif");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 10);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 24)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.attr("letter-spacing", "0.3px")
.text("Greedy Decoding: Always Pick Maximum Probability");
const chart = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const xScale = d3.scaleBand()
.domain(tokenData.map(d => d.token))
.range([0, chartWidth])
.padding(0.25);
const yScale = d3.scaleLinear()
.domain([0, 0.5])
.range([chartHeight, 0]);
// Grid lines
chart.append("g")
.attr("class", "grid")
.selectAll("line")
.data([0.1, 0.2, 0.3, 0.4, 0.5])
.join("line")
.attr("x1", 0)
.attr("x2", chartWidth)
.attr("y1", d => yScale(d))
.attr("y2", d => yScale(d))
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-opacity", 0.3)
.attr("stroke-dasharray", "3,3");
// Y axis
chart.append("g")
.call(d3.axisLeft(yScale).ticks(5).tickFormat(d => `${(d * 100).toFixed(0)}%`))
.call(g => g.select(".domain").attr("stroke", diagramTheme.nodeStroke))
.call(g => g.selectAll(".tick line").attr("stroke", diagramTheme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px"));
// Y axis label
chart.append("text")
.attr("transform", "rotate(-90)")
.attr("y", -42)
.attr("x", -chartHeight / 2)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text("Probability");
// Bars
chart.selectAll(".bar")
.data(tokenData)
.join("rect")
.attr("class", "bar")
.attr("x", d => xScale(d.token))
.attr("y", d => yScale(d.prob))
.attr("width", xScale.bandwidth())
.attr("height", d => chartHeight - yScale(d.prob))
.attr("rx", 4)
.attr("fill", d => d.prob === maxProb ? diagramTheme.accent : diagramTheme.nodeStroke)
.attr("opacity", d => d.prob === maxProb ? 1 : 0.5);
// Glow effect on max bar
chart.selectAll(".bar-glow")
.data(tokenData.filter(d => d.prob === maxProb))
.join("rect")
.attr("class", "bar-glow")
.attr("x", d => xScale(d.token))
.attr("y", d => yScale(d.prob))
.attr("width", xScale.bandwidth())
.attr("height", d => chartHeight - yScale(d.prob))
.attr("rx", 4)
.attr("fill", "none")
.attr("stroke", diagramTheme.accent)
.attr("stroke-width", 2.5)
.attr("filter", `drop-shadow(0 0 8px ${diagramTheme.accentGlow})`);
// Probability labels on bars
chart.selectAll(".prob-label")
.data(tokenData)
.join("text")
.attr("class", "prob-label")
.attr("x", d => xScale(d.token) + xScale.bandwidth() / 2)
.attr("y", d => yScale(d.prob) - 8)
.attr("text-anchor", "middle")
.attr("fill", d => d.prob === maxProb ? diagramTheme.accent : diagramTheme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", d => d.prob === maxProb ? "700" : "500")
.text(d => `${(d.prob * 100).toFixed(0)}%`);
// X axis (token labels)
chart.append("g")
.attr("transform", `translate(0, ${chartHeight})`)
.call(d3.axisBottom(xScale).tickSize(0))
.call(g => g.select(".domain").attr("stroke", diagramTheme.nodeStroke))
.call(g => g.selectAll(".tick text")
.attr("fill", (d, i) => tokenData[i].prob === maxProb ? diagramTheme.accent : diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", (d, i) => tokenData[i].prob === maxProb ? "700" : "500")
.attr("font-family", "monospace")
.attr("dy", "0.8em"));
// "MAX" indicator arrow pointing down at the selected bar
const maxBarX = xScale(maxToken.token) + xScale.bandwidth() / 2;
const maxBarY = yScale(maxToken.prob);
// Arrow from label to bar
svg.append("path")
.attr("d", `M${margin.left + maxBarX},${margin.top + maxBarY - 38} L${margin.left + maxBarX},${margin.top + maxBarY - 22}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 2)
.attr("marker-end", "url(#greedy-arrow)");
// Arrow marker definition
const defs = svg.append("defs");
defs.append("marker")
.attr("id", "greedy-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// "ALWAYS PICK MAX" label
svg.append("text")
.attr("x", margin.left + maxBarX)
.attr("y", margin.top + maxBarY - 48)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "11px")
.attr("font-weight", "700")
.attr("letter-spacing", "0.5px")
.text("ALWAYS PICK MAX");
// Selected token indicator at bottom
const selectedY = height - 22;
svg.append("rect")
.attr("x", width / 2 - 85)
.attr("y", selectedY - 14)
.attr("width", 170)
.attr("height", 28)
.attr("rx", 6)
.attr("fill", diagramTheme.accent)
.attr("filter", `drop-shadow(0 0 6px ${diagramTheme.accentGlow})`);
svg.append("text")
.attr("x", width / 2)
.attr("y", selectedY + 2)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.textOnAccent)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text(`Selected: "${maxToken.token}"`);
return svg.node();
}
```
- **Pros**: Deterministic, coherent output
- **Cons**: Boring, repetitive, can get stuck in loops
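The `generate_greedy` helper used below is defined earlier in the module. As a standalone reference, here is a minimal sketch of the same idea, assuming a `model` that maps token ids of shape `[B, T]` to logits of shape `[B, T, V]` (the function and model here are illustrative, not the module's exact implementation):

```python
import torch

def greedy_step(logits: torch.Tensor) -> torch.Tensor:
    """Pick the highest-probability token: argmax over the vocabulary dimension."""
    return torch.argmax(logits, dim=-1, keepdim=True)

def generate_greedy_sketch(model, input_ids, max_new_tokens):
    """Minimal autoregressive loop: predict, take argmax, append, repeat."""
    for _ in range(max_new_tokens):
        logits = model(input_ids)                   # [B, T, V]
        next_token = greedy_step(logits[:, -1, :])  # only the last position matters
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids
```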
```{python}
# Generate with greedy decoding
output_greedy = generate_greedy(model, prompt, max_new_tokens=15)
print(f"Prompt: {prompt[0].tolist()}")
print(f"Generated: {output_greedy[0, 5:].tolist()}")
```
```{python}
# Greedy is deterministic - same output every time
print("Multiple greedy generations (should all be identical):")
for i in range(3):
out = generate_greedy(model, prompt, max_new_tokens=10)
print(f" Run {i+1}: {out[0, 5:].tolist()}")
```
### 2. Temperature Sampling
Temperature controls the "sharpness" of the probability distribution before sampling:
$$P_{\text{new}} = \text{softmax}(\text{logits} / T)$$
```{ojs}
//| echo: false
// Temperature slider input
viewof temperature = Inputs.range([0.1, 3.0], {
value: 1.0,
step: 0.1,
label: "Temperature (T)"
})
```
```{ojs}
//| echo: false
// Base logits and token labels
tempBaseLogits = [2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0]
tempTokenLabels = ["the", "a", "an", "this", "that", "one", "some"]
// Softmax with temperature calculation
tempScaledLogits = tempBaseLogits.map(l => l / temperature)
tempMaxLogit = Math.max(...tempScaledLogits)
tempExpLogits = tempScaledLogits.map(l => Math.exp(l - tempMaxLogit))
tempSumExp = tempExpLogits.reduce((a, b) => a + b, 0)
tempProbs = tempExpLogits.map(e => e / tempSumExp)
// Calculate entropy: -sum(p * log2(p))
tempEntropy = -tempProbs.reduce((sum, p) => {
if (p > 0) return sum + p * Math.log2(p);
return sum;
}, 0)
// Description based on temperature range
tempDescription = temperature < 0.5
? "Very sharp - nearly deterministic. The model strongly favors the highest-probability token."
: temperature < 0.9
? "Sharp distribution - more focused on high-probability tokens with some variety."
: temperature <= 1.1
? "Original distribution - balanced between coherence and diversity."
: temperature < 2.0
? "Flatter distribution - increased randomness, more creative outputs."
: "Very flat - nearly uniform. High chance of unexpected or incoherent tokens."
```
```{ojs}
//| echo: false
// Create the temperature effects visualization
{
const width = 540;
const height = 280;
const margin = {top: 40, right: 30, bottom: 50, left: 50};
const plotWidth = width - margin.left - margin.right;
const plotHeight = height - margin.top - margin.bottom;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
const g = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Scales
const x = d3.scaleBand()
.domain(tempTokenLabels)
.range([0, plotWidth])
.padding(0.2);
const y = d3.scaleLinear()
.domain([0, 1])
.range([plotHeight, 0]);
// X axis
g.append("g")
.attr("transform", `translate(0, ${plotHeight})`)
.call(d3.axisBottom(x))
.selectAll("text")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-family", "monospace");
g.selectAll(".domain, .tick line")
.attr("stroke", diagramTheme.nodeStroke);
// Y axis
g.append("g")
.call(d3.axisLeft(y).ticks(5).tickFormat(d3.format(".0%")))
.selectAll("text")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px");
g.selectAll(".domain, .tick line")
.attr("stroke", diagramTheme.nodeStroke);
// Y axis label
g.append("text")
.attr("transform", "rotate(-90)")
.attr("y", -40)
.attr("x", -plotHeight / 2)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text("Probability");
// Bars
g.selectAll(".bar")
.data(tempProbs)
.join("rect")
.attr("class", "bar")
.attr("x", (d, i) => x(tempTokenLabels[i]))
.attr("y", d => y(d))
.attr("width", x.bandwidth())
.attr("height", d => plotHeight - y(d))
.attr("fill", diagramTheme.accent)
.attr("rx", 3)
.attr("opacity", 0.85);
// Probability labels on bars
g.selectAll(".prob-label")
.data(tempProbs)
.join("text")
.attr("class", "prob-label")
.attr("x", (d, i) => x(tempTokenLabels[i]) + x.bandwidth() / 2)
.attr("y", d => y(d) - 6)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(d => (d * 100).toFixed(1) + "%");
// Title with temperature value
svg.append("text")
.attr("x", width / 2)
.attr("y", 22)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text(`Temperature = ${temperature.toFixed(1)}`);
// Entropy indicator
svg.append("text")
.attr("x", width - margin.right - 5)
.attr("y", 22)
.attr("text-anchor", "end")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "11px")
.text(`Entropy: ${tempEntropy.toFixed(2)} bits`);
return svg.node();
}
```
```{ojs}
//| echo: false
// Description panel
html`<div style="
background: ${diagramTheme.bgSecondary};
border: 1px solid ${diagramTheme.nodeStroke};
border-radius: 6px;
padding: 12px 16px;
margin-top: 8px;
max-width: 540px;
font-size: 13px;
color: ${diagramTheme.nodeText};
">
<strong style="color: ${diagramTheme.accent};">T = ${temperature.toFixed(1)}:</strong> ${tempDescription}
</div>`
```
- **Temperature < 1.0**: Sharper distribution (more like greedy)
- **Temperature = 1.0**: Original distribution
- **Temperature > 1.0**: Flatter distribution (more random)
**Note**: Temperature = 0 would cause division by zero. In practice, very low temperatures (e.g., 0.01) approximate greedy decoding, and many implementations treat temperature = 0 as an alias for greedy mode.
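The claim in the note can be checked directly. The sampling step itself is one softmax and one multinomial draw; `sample_with_temperature` is a hypothetical helper for illustration, not part of the module's API:

```python
import torch
import torch.nn.functional as F

def sample_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Scale logits by 1/T, softmax, then draw one token. Requires T > 0."""
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

logits = torch.tensor([[2.0, 1.0, 0.0]])
torch.manual_seed(0)
cold = sample_with_temperature(logits, 0.01)  # so sharp it is effectively greedy
```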
```{python}
# Visualize temperature effects
logits = torch.tensor([2.0, 1.0, 0.5, 0.0, -0.5, -1.0, -2.0])
fig, axes = plt.subplots(1, 4, figsize=(16, 3))
for ax, temp in zip(axes, [0.3, 0.7, 1.0, 2.0]):
scaled_logits = logits / temp
probs = F.softmax(scaled_logits, dim=0)
ax.bar(range(7), probs.numpy())
ax.set_xlabel('Token')
ax.set_ylabel('Probability')
ax.set_title(f'Temperature = {temp}')
ax.set_ylim(0, 1)
plt.suptitle('Effect of Temperature on Probability Distribution', fontsize=14)
plt.tight_layout()
plt.show()
```
```{python}
# Generate with different temperatures
print("Generating with different temperatures:\n")
for temp in [0.3, 0.7, 1.0, 1.5]:
print(f"Temperature = {temp}:")
for i in range(3):
torch.manual_seed(42 + i)
out = generate(model, prompt, max_new_tokens=10, temperature=temp, do_sample=True)
print(f" Sample {i+1}: {out[0, 5:].tolist()}")
print()
```
### 3. Top-k Sampling
Top-k sampling restricts the choice to the k most likely tokens, cutting off the long tail of unlikely ones:
```{ojs}
//| echo: false
// K slider for controlling top-k filtering
viewof topKValue = {
const container = document.createElement("div");
container.style.cssText = "display: flex; align-items: center; gap: 16px; margin-bottom: 16px; font-family: system-ui, sans-serif;";
const label = document.createElement("span");
label.style.cssText = `font-weight: 600; color: ${diagramTheme.nodeText}; font-size: 14px;`;
label.textContent = "K value:";
const slider = document.createElement("input");
slider.type = "range";
slider.min = 1;
slider.max = 7;
slider.value = 4;
slider.style.cssText = "width: 180px; cursor: pointer;";
const valueDisplay = document.createElement("span");
valueDisplay.style.cssText = `font-weight: 700; color: ${diagramTheme.highlight}; font-size: 16px; min-width: 24px;`;
valueDisplay.textContent = slider.value;
slider.oninput = () => {
valueDisplay.textContent = slider.value;
container.value = +slider.value;
container.dispatchEvent(new CustomEvent("input"));
};
container.appendChild(label);
container.appendChild(slider);
container.appendChild(valueDisplay);
container.value = +slider.value;
return container;
}
```
```{ojs}
//| echo: false
// Top-K filtering visualization
topKDiagram = {
const width = 700;
const height = 320;
const margin = { top: 50, right: 30, bottom: 60, left: 50 };
const chartWidth = (width - margin.left - margin.right - 40) / 2;
const chartHeight = height - margin.top - margin.bottom;
// Base logits and tokens as specified
const baseLogits = [3.0, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5];
const tokens = ["the", "a", "an", "this", "that", "one", "some"];
// Softmax function
const softmax = (logits) => {
const maxLogit = Math.max(...logits);
const exps = logits.map(l => Math.exp(l - maxLogit));
const sum = exps.reduce((a, b) => a + b, 0);
return exps.map(e => e / sum);
};
// Original probabilities
const originalProbs = softmax(baseLogits);
// Top-K filtering: keep only top K, renormalize
const k = topKValue;
const indexed = originalProbs.map((p, i) => ({ prob: p, idx: i }));
const sorted = [...indexed].sort((a, b) => b.prob - a.prob);
const keptIndices = new Set(sorted.slice(0, k).map(x => x.idx));
const filteredProbs = originalProbs.map((p, i) => keptIndices.has(i) ? p : 0);
const filteredSum = filteredProbs.reduce((a, b) => a + b, 0);
const renormalizedProbs = filteredProbs.map(p => p > 0 ? p / filteredSum : 0);
// Create SVG
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Scales
const xScale = d3.scaleBand()
.domain(tokens)
.range([0, chartWidth])
.padding(0.25);
const yScale = d3.scaleLinear()
.domain([0, Math.max(...originalProbs) * 1.15])
.range([chartHeight, 0]);
// Left chart: Original distribution
const leftChart = svg.append("g")
.attr("transform", `translate(${margin.left}, ${margin.top})`);
// Left title
leftChart.append("text")
.attr("x", chartWidth / 2)
.attr("y", -25)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("Original Distribution");
// Left bars
leftChart.selectAll(".bar-orig")
.data(originalProbs)
.join("rect")
.attr("class", "bar-orig")
.attr("x", (d, i) => xScale(tokens[i]))
.attr("y", d => yScale(d))
.attr("width", xScale.bandwidth())
.attr("height", d => chartHeight - yScale(d))
.attr("fill", (d, i) => keptIndices.has(i) ? diagramTheme.accent : diagramTheme.nodeStroke)
.attr("rx", 3)
.attr("opacity", (d, i) => keptIndices.has(i) ? 1 : 0.4);
// Left probability labels
leftChart.selectAll(".prob-orig")
.data(originalProbs)
.join("text")
.attr("class", "prob-orig")
.attr("x", (d, i) => xScale(tokens[i]) + xScale.bandwidth() / 2)
.attr("y", d => yScale(d) - 6)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(d => d.toFixed(2));
// Left X axis
leftChart.append("g")
.attr("transform", `translate(0, ${chartHeight})`)
.call(d3.axisBottom(xScale).tickSize(0))
.call(g => g.select(".domain").attr("stroke", diagramTheme.nodeStroke))
.call(g => g.selectAll("text")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.attr("dy", "1em"));
// Arrow between charts
const arrowX = margin.left + chartWidth + 20;
const arrowY = margin.top + chartHeight / 2;
svg.append("path")
.attr("d", `M${arrowX - 5},${arrowY} L${arrowX + 15},${arrowY}`)
.attr("stroke", diagramTheme.highlight)
.attr("stroke-width", 2.5)
.attr("marker-end", "url(#topk-arrow)");
// Arrow marker
svg.append("defs")
.append("marker")
.attr("id", "topk-arrow")
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", diagramTheme.highlight);
// Filter label
svg.append("text")
.attr("x", arrowX + 5)
.attr("y", arrowY - 12)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`top-${k}`);
// Right chart: Filtered distribution
const rightChart = svg.append("g")
.attr("transform", `translate(${margin.left + chartWidth + 40}, ${margin.top})`);
// Right title
rightChart.append("text")
.attr("x", chartWidth / 2)
.attr("y", -25)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text("After Top-K (Renormalized)");
// Right Y scale (for renormalized)
const yScaleRight = d3.scaleLinear()
.domain([0, Math.max(...renormalizedProbs) * 1.15])
.range([chartHeight, 0]);
// Right bars
rightChart.selectAll(".bar-filtered")
.data(renormalizedProbs)
.join("rect")
.attr("class", "bar-filtered")
.attr("x", (d, i) => xScale(tokens[i]))
.attr("y", d => d > 0 ? yScaleRight(d) : chartHeight)
.attr("width", xScale.bandwidth())
.attr("height", d => d > 0 ? chartHeight - yScaleRight(d) : 0)
.attr("fill", (d, i) => d > 0 ? diagramTheme.highlight : "transparent")
.attr("rx", 3);
// Filtered out indicators (X marks)
rightChart.selectAll(".filtered-out")
.data(renormalizedProbs)
.join("text")
.attr("class", "filtered-out")
.attr("x", (d, i) => xScale(tokens[i]) + xScale.bandwidth() / 2)
.attr("y", chartHeight - 15)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeStroke)
.attr("font-size", "16px")
.attr("opacity", d => d === 0 ? 0.6 : 0)
.text("×");
// Right probability labels
rightChart.selectAll(".prob-filtered")
.data(renormalizedProbs)
.join("text")
.attr("class", "prob-filtered")
.attr("x", (d, i) => xScale(tokens[i]) + xScale.bandwidth() / 2)
.attr("y", d => d > 0 ? yScaleRight(d) - 6 : chartHeight + 12)
.attr("text-anchor", "middle")
.attr("fill", d => d > 0 ? diagramTheme.nodeText : diagramTheme.nodeStroke)
.attr("font-size", "10px")
.attr("font-weight", "500")
.attr("opacity", d => d > 0 ? 1 : 0.5)
.text(d => d > 0 ? d.toFixed(2) : "0");
// Right X axis
rightChart.append("g")
.attr("transform", `translate(0, ${chartHeight})`)
.call(d3.axisBottom(xScale).tickSize(0))
.call(g => g.select(".domain").attr("stroke", diagramTheme.nodeStroke))
.call(g => g.selectAll("text")
.attr("fill", (d, i) => keptIndices.has(tokens.indexOf(d)) ? diagramTheme.nodeText : diagramTheme.nodeStroke)
.attr("font-size", "11px")
.attr("font-weight", "500")
.attr("opacity", (d, i) => keptIndices.has(tokens.indexOf(d)) ? 1 : 0.5)
.attr("dy", "1em"));
// Legend
const legend = svg.append("g")
.attr("transform", `translate(${width / 2}, ${height - 18})`);
legend.append("rect")
.attr("x", -120)
.attr("y", -8)
.attr("width", 12)
.attr("height", 12)
.attr("fill", diagramTheme.accent)
.attr("rx", 2);
legend.append("text")
.attr("x", -104)
.attr("y", 0)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text("Kept tokens");
legend.append("rect")
.attr("x", 20)
.attr("y", -8)
.attr("width", 12)
.attr("height", 12)
.attr("fill", diagramTheme.nodeStroke)
.attr("opacity", 0.4)
.attr("rx", 2);
legend.append("text")
.attr("x", 36)
.attr("y", 0)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.text("Filtered out");
return svg.node();
}
```
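The `top_k_filtering` helper used in the demos below is defined earlier in the module. An equivalent minimal sketch, assuming (as the demos imply) that it masks everything outside the top k to `-inf` so softmax assigns those tokens zero probability:

```python
import torch
import torch.nn.functional as F

def top_k_filter_sketch(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Set every logit below the k-th largest to -inf, so softmax zeroes it out.
    (Ties with the k-th value are all kept in this simple version.)"""
    kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth, float("-inf"))

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
probs = F.softmax(top_k_filter_sketch(logits, 2), dim=-1)  # only tokens 1 and 2 survive
```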
```{python}
# Demonstrate top-k filtering
logits = torch.tensor([[1.0, 3.0, 0.5, 2.5, 0.0, 2.0, -1.0, 1.5]])
original_probs = F.softmax(logits, dim=-1)
print("Original probabilities:")
for i, p in enumerate(original_probs[0]):
print(f" Token {i}: {p.item():.3f}")
# Apply top-k filtering
for k in [3, 5]:
filtered = top_k_filtering(logits.clone(), k)
filtered_probs = F.softmax(filtered, dim=-1)
print(f"\nAfter top-k = {k}:")
for i, p in enumerate(filtered_probs[0]):
if p > 0:
print(f" Token {i}: {p.item():.3f}")
```
```{python}
# Visualize top-k effect
fig, axes = plt.subplots(1, 4, figsize=(16, 3))
logits = torch.randn(1, 20) * 2 # Random logits
for ax, k in zip(axes, [1, 3, 5, 20]):
filtered = top_k_filtering(logits.clone(), k)
probs = F.softmax(filtered, dim=-1)[0]
ax.bar(range(20), probs.numpy())
ax.set_xlabel('Token')
ax.set_ylabel('Probability')
ax.set_title(f'Top-k = {k}')
plt.suptitle('Effect of Top-k Filtering', fontsize=14)
plt.tight_layout()
plt.show()
```
### 4. Top-p (Nucleus) Sampling
Top-p keeps the smallest set of tokens whose cumulative probability exceeds p. Because the set size adapts to the distribution, more tokens survive when the model is uncertain and fewer when it is confident.
```{ojs}
//| echo: false
// Interactive slider for top-p threshold
viewof topPValue = Inputs.range([0.1, 1.0], {
step: 0.05,
value: 0.9,
label: "Top-p threshold"
})
```
```{ojs}
//| echo: false
// Top-p (nucleus) sampling visualization
topPDiagram = {
const width = 700;
const height = 400;
const theme = diagramTheme;
// Token data - sorted by probability (descending)
const tokens = ["the", "a", "an", "this", "that"];
const probs = [0.40, 0.30, 0.15, 0.10, 0.05];
// Calculate cumulative probabilities
const cumulative = [];
let sum = 0;
for (let i = 0; i < probs.length; i++) {
sum += probs[i];
cumulative.push(sum);
}
// Find cutoff index: first token where cumulative > p
let cutoffIndex = probs.length;
for (let i = 0; i < cumulative.length; i++) {
if (cumulative[i] > topPValue) {
cutoffIndex = i + 1; // Keep tokens up to and including this one
break;
}
}
// Kept tokens and renormalized probabilities
const keptProbs = probs.slice(0, cutoffIndex);
const keptSum = keptProbs.reduce((a, b) => a + b, 0);
const renormalized = keptProbs.map(p => p / keptSum);
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Mono', 'Fira Code', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 12);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "15px")
.attr("font-weight", "600")
.attr("letter-spacing", "0.5px")
.text(`Top-p (Nucleus) Sampling, p = ${topPValue.toFixed(2)}`);
// Chart dimensions
const chartLeft = 100;
const chartRight = width - 50;
const chartWidth = chartRight - chartLeft;
const chartTop = 70;
const chartBottom = 260;
const chartHeight = chartBottom - chartTop;
// Bar dimensions
const barWidth = 50;
const barGap = 25;
const totalBarsWidth = tokens.length * barWidth + (tokens.length - 1) * barGap;
const barsStartX = chartLeft + (chartWidth - totalBarsWidth) / 2;
// Y scale
const yScale = d3.scaleLinear().domain([0, 1]).range([chartBottom, chartTop]);
// Draw y-axis
svg.append("line")
.attr("x1", chartLeft - 10)
.attr("y1", chartTop)
.attr("x2", chartLeft - 10)
.attr("y2", chartBottom)
.attr("stroke", theme.edgeStroke)
.attr("stroke-width", 1);
// Y-axis ticks
[0, 0.25, 0.5, 0.75, 1.0].forEach(tick => {
const y = yScale(tick);
svg.append("line")
.attr("x1", chartLeft - 15)
.attr("y1", y)
.attr("x2", chartLeft - 10)
.attr("y2", y)
.attr("stroke", theme.edgeStroke)
.attr("stroke-width", 1);
svg.append("text")
.attr("x", chartLeft - 20)
.attr("y", y)
.attr("text-anchor", "end")
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(tick.toFixed(2));
});
// Y-axis label
svg.append("text")
.attr("transform", `rotate(-90)`)
.attr("x", -(chartTop + chartHeight / 2))
.attr("y", 25)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.8)
.text("Probability");
// Draw threshold line (top-p cutoff)
const thresholdY = yScale(topPValue);
svg.append("line")
.attr("x1", chartLeft - 10)
.attr("y1", thresholdY)
.attr("x2", chartRight)
.attr("y2", thresholdY)
.attr("stroke", theme.highlight)
.attr("stroke-width", 2)
.attr("stroke-dasharray", "8,4")
.attr("opacity", 0.9);
svg.append("text")
.attr("x", chartRight + 5)
.attr("y", thresholdY)
.attr("dominant-baseline", "middle")
.attr("fill", theme.highlight)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`p=${topPValue.toFixed(2)}`);
// Draw probability bars
tokens.forEach((token, i) => {
const x = barsStartX + i * (barWidth + barGap);
const prob = probs[i];
const barHeight = chartBottom - yScale(prob);
const isKept = i < cutoffIndex;
// Probability bar
svg.append("rect")
.attr("x", x)
.attr("y", yScale(prob))
.attr("width", barWidth)
.attr("height", barHeight)
.attr("fill", isKept ? theme.accent : theme.nodeFill)
.attr("stroke", isKept ? theme.accent : theme.nodeStroke)
.attr("stroke-width", 1.5)
.attr("rx", 4)
.attr("opacity", isKept ? 1 : 0.4);
// Probability value on bar
svg.append("text")
.attr("x", x + barWidth / 2)
.attr("y", yScale(prob) - 8)
.attr("text-anchor", "middle")
.attr("fill", isKept ? theme.accent : theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", isKept ? "600" : "400")
.attr("opacity", isKept ? 1 : 0.5)
.text(prob.toFixed(2));
// Token label
svg.append("text")
.attr("x", x + barWidth / 2)
.attr("y", chartBottom + 18)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", isKept ? "600" : "400")
.attr("opacity", isKept ? 1 : 0.4)
.text(token);
// Cumulative probability marker (small diamond)
const cumY = yScale(cumulative[i]);
svg.append("path")
.attr("d", `M${x + barWidth / 2},${cumY - 5} l5,5 l-5,5 l-5,-5 z`)
.attr("fill", theme.highlight)
.attr("opacity", 0.9);
});
// Draw cumulative probability line
const linePoints = tokens.map((_, i) => {
return [barsStartX + i * (barWidth + barGap) + barWidth / 2, yScale(cumulative[i])];
});
const lineGenerator = d3.line().curve(d3.curveMonotoneX);
svg.append("path")
.attr("d", lineGenerator(linePoints))
.attr("fill", "none")
.attr("stroke", theme.highlight)
.attr("stroke-width", 2.5)
.attr("opacity", 0.8);
// Legend
const legendY = chartBottom + 50;
// Kept tokens indicator
svg.append("rect")
.attr("x", chartLeft)
.attr("y", legendY)
.attr("width", 16)
.attr("height", 16)
.attr("fill", theme.accent)
.attr("rx", 3);
svg.append("text")
.attr("x", chartLeft + 24)
.attr("y", legendY + 8)
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Kept tokens");
// Filtered tokens indicator
svg.append("rect")
.attr("x", chartLeft + 120)
.attr("y", legendY)
.attr("width", 16)
.attr("height", 16)
.attr("fill", theme.nodeFill)
.attr("stroke", theme.nodeStroke)
.attr("opacity", 0.4)
.attr("rx", 3);
svg.append("text")
.attr("x", chartLeft + 144)
.attr("y", legendY + 8)
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.6)
.text("Filtered out");
// Cumulative line indicator
svg.append("line")
.attr("x1", chartLeft + 250)
.attr("y1", legendY + 8)
.attr("x2", chartLeft + 280)
.attr("y2", legendY + 8)
.attr("stroke", theme.highlight)
.attr("stroke-width", 2.5);
svg.append("path")
.attr("d", `M${chartLeft + 265},${legendY + 3} l5,5 l-5,5 l-5,-5 z`)
.attr("fill", theme.highlight);
svg.append("text")
.attr("x", chartLeft + 290)
.attr("y", legendY + 8)
.attr("dominant-baseline", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Cumulative probability");
// Renormalized probabilities section
const renormY = legendY + 45;
svg.append("text")
.attr("x", width / 2)
.attr("y", renormY)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text("Renormalized Probabilities");
// Draw renormalized probability bars (smaller)
const renormBarWidth = 60;
const renormBarHeight = 24;
const renormGap = 15;
const totalRenormWidth = cutoffIndex * (renormBarWidth + renormGap) - renormGap;
const renormStartX = (width - totalRenormWidth) / 2;
const renormBarY = renormY + 20;
for (let i = 0; i < cutoffIndex; i++) {
const x = renormStartX + i * (renormBarWidth + renormGap);
svg.append("rect")
.attr("x", x)
.attr("y", renormBarY)
.attr("width", renormBarWidth)
.attr("height", renormBarHeight)
.attr("fill", theme.accent)
.attr("rx", 6)
.attr("filter", `drop-shadow(0 0 6px ${theme.accentGlow})`);
svg.append("text")
.attr("x", x + renormBarWidth / 2)
.attr("y", renormBarY + renormBarHeight / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle")
.attr("fill", theme.textOnAccent)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`${tokens[i]}: ${renormalized[i].toFixed(2)}`);
}
// Summary stats
const statsY = renormBarY + renormBarHeight + 25;
svg.append("text")
.attr("x", width / 2)
.attr("y", statsY)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text(`Kept ${cutoffIndex} of ${tokens.length} tokens (cumulative sum: ${keptSum.toFixed(2)} ≥ ${topPValue.toFixed(2)})`);
return svg.node();
}
```
**Key advantage**: Top-p adapts to the distribution shape:
- Peaked (confident): Keeps fewer tokens
- Flat (uncertain): Keeps more tokens
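The `top_p_filtering` helper used below is defined earlier in the module. A minimal sketch consistent with the cutoff rule in the diagram (sort descending, accumulate probability, keep the token that crosses the threshold); `top_p_filter_sketch` is illustrative only:

```python
import torch
import torch.nn.functional as F

def top_p_filter_sketch(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Mask tokens outside the smallest prefix of the sorted distribution
    whose cumulative probability exceeds p; the crossing token is kept."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > p
    remove[..., 1:] = remove[..., :-1].clone()       # shift right: keep the crossing token
    remove[..., 0] = False                           # always keep the top token
    remove = remove.scatter(-1, sorted_idx, remove)  # map mask back to vocab order
    return logits.masked_fill(remove, float("-inf"))

# Same distribution as the diagram above: p = 0.9 keeps 4 of 5 tokens
logits = torch.log(torch.tensor([[0.40, 0.30, 0.15, 0.10, 0.05]]))
probs = F.softmax(top_p_filter_sketch(logits, 0.9), dim=-1)
```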
```{python}
# Demonstrate top-p filtering
logits = torch.tensor([[3.0, 2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0]])
probs = F.softmax(logits, dim=-1)[0]
# Sort and show cumulative probabilities
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=0)
print("Tokens sorted by probability:")
print(f"{'Token':<8} {'Prob':<10} {'Cumulative':<10}")
print("-" * 28)
for i, (idx, p, c) in enumerate(zip(sorted_idx, sorted_probs, cumulative)):
marker = " <- cutoff (p=0.9)" if c.item() > 0.9 and (i == 0 or cumulative[i-1].item() <= 0.9) else ""
print(f"{idx.item():<8} {p.item():<10.3f} {c.item():<10.3f}{marker}")
```
```{python}
# Compare top-p on peaked vs flat distributions
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
# Peaked distribution (confident model)
peaked_logits = torch.tensor([[5.0, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0]])
# Flat distribution (uncertain model)
flat_logits = torch.tensor([[1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]])
for row, (logits, name) in enumerate([(peaked_logits, "Peaked (confident)"), (flat_logits, "Flat (uncertain)")]):
for col, p in enumerate([0.5, 0.7, 0.9]):
filtered = top_p_filtering(logits.clone(), p)
probs = F.softmax(filtered, dim=-1)[0]
ax = axes[row, col]
ax.bar(range(8), probs.numpy())
ax.set_xlabel('Token')
ax.set_ylabel('Probability')
ax.set_title(f'{name}\np={p}, tokens kept: {(probs > 0).sum().item()}')
plt.suptitle('Top-p Adapts to Distribution Shape', fontsize=14)
plt.tight_layout()
plt.show()
print("Notice: Top-p keeps more tokens when the model is uncertain (flat dist)")
print("and fewer tokens when confident (peaked dist)!")
```
## Combining Strategies
Each strategy has trade-offs: temperature affects the overall distribution shape, top-k provides a hard cutoff, and top-p adapts to model confidence. In practice, combining them often works better than any single approach:
```{python}
# The typical generation pipeline
logits_example = torch.randn(1, vocab_size) * 2
# Step 1: Apply temperature
temperature = 0.7
logits_temp = logits_example / temperature
# Step 2: Apply top-k filtering
logits_topk = top_k_filtering(logits_temp.clone(), top_k=20)
# Step 3: Apply top-p filtering
logits_topp = top_p_filtering(logits_topk.clone(), top_p=0.9)
# Step 4: Sample from the distribution
probs = F.softmax(logits_topp, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print("Combined filtering pipeline:")
print(f" Original vocab: {vocab_size} tokens")
print(f" After top-k=20: {(F.softmax(logits_topk, dim=-1) > 0).sum().item()} tokens")
print(f" After top-p=0.9: {(probs > 0).sum().item()} tokens")
print(f" Sampled token: {next_token.item()}")
```
```{python}
# Compare different strategy combinations
strategies = [
("Greedy", {"do_sample": False}),
("Temperature=0.5", {"temperature": 0.5, "do_sample": True}),
("Temperature=1.0", {"temperature": 1.0, "do_sample": True}),
("Top-k=5", {"top_k": 5, "do_sample": True}),
("Top-p=0.9", {"top_p": 0.9, "do_sample": True}),
("Combined (T=0.7, k=20, p=0.9)", {"temperature": 0.7, "top_k": 20, "top_p": 0.9, "do_sample": True}),
]
print("Comparing strategies (3 samples each):\n")
for name, kwargs in strategies:
print(f"{name}:")
for i in range(3):
torch.manual_seed(100 + i)
out = generate(model, prompt, max_new_tokens=10, **kwargs)
tokens = out[0, 5:].tolist()
print(f" {tokens}")
print()
```
## Measuring Output Diversity
Let's quantify how different strategies affect output diversity:
```{python}
def measure_diversity(model, prompt, num_samples=20, **kwargs):
"""Measure how diverse the generated outputs are."""
outputs = []
for i in range(num_samples):
torch.manual_seed(i)
out = generate(model, prompt, max_new_tokens=15, **kwargs)
outputs.append(tuple(out[0].tolist()))
unique = len(set(outputs))
return unique / num_samples
# Compare diversity across settings
settings = [
("Greedy", {"do_sample": False}),
("Temp=0.3", {"temperature": 0.3, "do_sample": True}),
("Temp=0.7", {"temperature": 0.7, "do_sample": True}),
("Temp=1.0", {"temperature": 1.0, "do_sample": True}),
("Temp=1.5", {"temperature": 1.5, "do_sample": True}),
]
diversities = []
for name, kwargs in settings:
div = measure_diversity(model, prompt, num_samples=20, **kwargs)
diversities.append((name, div))
print(f"{name}: {div*100:.0f}% unique outputs")
```
```{python}
# Visualize diversity
names = [d[0] for d in diversities]
values = [d[1] * 100 for d in diversities]
plt.figure(figsize=(10, 5))
colors = ['gray'] + list(plt.cm.viridis(np.linspace(0.2, 0.8, len(names)-1)))
plt.bar(names, values, color=colors)
plt.ylabel('% Unique Outputs')
plt.title('Output Diversity vs Temperature')
plt.ylim(0, 105)
for i, v in enumerate(values):
plt.text(i, v + 2, f'{v:.0f}%', ha='center')
plt.show()
```
## Choosing Parameters
Recommended settings for different use cases:
| Goal | Temperature | Top-k | Top-p |
|------|-------------|-------|-------|
| Code generation | 0.2-0.4 | 10-20 | 0.8-0.9 |
| Factual/deterministic | 0.3-0.5 | 5-10 | 0.5-0.7 |
| Coherent responses | 0.7-0.9 | 20-50 | 0.85-0.92 |
| Creative writing | 0.8-1.2 | 40-100 | 0.9-0.95 |
```{python}
# Example settings for different applications
use_cases = {
"Code generation": {"temperature": 0.2, "top_p": 0.9, "do_sample": True},
"Balanced chat": {"temperature": 0.7, "top_p": 0.9, "do_sample": True},
"Creative writing": {"temperature": 1.0, "top_p": 0.95, "do_sample": True},
"Brainstorming": {"temperature": 1.5, "top_p": 0.95, "do_sample": True},
}
print("Sample outputs for different use cases:\n")
for name, kwargs in use_cases.items():
print(f"{name}:")
for i in range(2):
torch.manual_seed(42 + i)
out = generate(model, prompt, max_new_tokens=12, **kwargs)
print(f" {out[0, 5:].tolist()}")
print()
```
## Repetition Penalty
A common problem with text generation is **repetition** - the model gets stuck repeating the same tokens or phrases. Repetition penalties address this by reducing the probability of tokens that have already appeared.
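One common formulation, the one the interactive diagram below applies, divides a repeated token's logit by the penalty when positive and multiplies it when negative, so the logit moves down in either case. A minimal sketch for a batch of one:

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor, generated_ids, penalty: float):
    """For each already-generated token id: divide its logit by the penalty
    if positive, multiply if negative; both push it toward -inf."""
    logits = logits.clone()
    for tok in set(int(t) for t in generated_ids):
        score = logits[0, tok].item()
        logits[0, tok] = score / penalty if score > 0 else score * penalty
    return logits

logits = torch.tensor([[2.0, -0.5, 1.0, 1.5]])   # "the", "cat", "dog", "sat"
penalized = apply_repetition_penalty(logits, [0, 1], penalty=1.25)  # "the", "cat" seen
```

A penalty of 1.0 leaves the logits unchanged; values around 1.1 to 1.3 are typical starting points.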
```{ojs}
//| echo: false
// Penalty slider with distinctive styling
viewof repetitionPenalty = {
const container = d3.create("div")
.style("font-family", "'IBM Plex Mono', 'Fira Code', monospace")
.style("margin-bottom", "16px");
container.append("label")
.style("display", "block")
.style("font-size", "13px")
.style("font-weight", "600")
.style("color", diagramTheme.nodeText)
.style("margin-bottom", "8px")
.text("Repetition Penalty");
const sliderRow = container.append("div")
.style("display", "flex")
.style("align-items", "center")
.style("gap", "12px");
const input = sliderRow.append("input")
.attr("type", "range")
.attr("min", "1.0")
.attr("max", "2.0")
.attr("step", "0.05")
.attr("value", "1.2")
.style("width", "200px")
.style("accent-color", diagramTheme.highlight);
const valueDisplay = sliderRow.append("span")
.style("font-size", "14px")
.style("font-weight", "700")
.style("color", diagramTheme.highlight)
.style("min-width", "40px")
.text("1.2");
const effectLabel = sliderRow.append("span")
.style("font-size", "11px")
.style("color", diagramTheme.nodeText)
.style("opacity", "0.7")
.style("margin-left", "8px")
.text("(discourages repetition)");
const node = container.node();
node.value = 1.2;
// d3 replaces listeners registered under the same typename, so a single
// handler updates the display, the exposed value, and notifies Observable
input.on("input", function() {
node.value = parseFloat(this.value);
valueDisplay.text(node.value.toFixed(2));
effectLabel.text(node.value === 1.0 ? "(no effect)" : "(discourages repetition)");
node.dispatchEvent(new Event("input"));
});
return node;
}
```
```{ojs}
//| echo: false
// Data for the repetition penalty diagram
repetitionData = {
const tokens = ["the", "cat", "dog", "sat"];
const originalLogits = [2.0, -0.5, 1.0, 1.5];
const isPenalized = [true, true, false, false]; // "the" and "cat" appeared before
const penalty = repetitionPenalty;
// Apply penalty: positive logits divided, negative logits multiplied
const penalizedLogits = originalLogits.map((logit, i) => {
if (!isPenalized[i]) return logit;
return logit > 0 ? logit / penalty : logit * penalty;
});
return {
tokens,
originalLogits,
penalizedLogits,
isPenalized,
penalty
};
}
// Interactive repetition penalty visualization
repetitionPenaltyDiagram = {
const width = 650;
const height = 320;
const data = repetitionData;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'IBM Plex Mono', 'Fira Code', monospace");
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", diagramTheme.bg)
.attr("rx", 8);
// Title
svg.append("text")
.attr("x", width / 2)
.attr("y", 28)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "14px")
.attr("font-weight", "600")
.text(`Repetition Penalty Effect (penalty = ${data.penalty.toFixed(2)})`);
// Previous tokens indicator
const prevBox = svg.append("g").attr("transform", "translate(325, 55)");
prevBox.append("rect")
.attr("x", -130)
.attr("y", -12)
.attr("width", 260)
.attr("height", 24)
.attr("rx", 4)
.attr("fill", diagramTheme.bgSecondary)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1);
prevBox.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "11px")
.html(`Previous tokens: `);
prevBox.append("text")
.attr("x", 45)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "11px")
.attr("font-weight", "600")
.text(`["the", "cat"]`);
// Bar chart settings
const barWidth = 50;
const barSpacing = 110;
const maxBarHeight = 100;
const baseY = 200;
const startX = 95;
// Scale for logits to bar height
const maxLogit = Math.max(...data.originalLogits.map(Math.abs), ...data.penalizedLogits.map(Math.abs));
const scale = maxBarHeight / (maxLogit + 0.5);
// Draw bars for each token
data.tokens.forEach((token, i) => {
const x = startX + i * barSpacing;
const originalLogit = data.originalLogits[i];
const penalizedLogit = data.penalizedLogits[i];
const penalized = data.isPenalized[i];
// Token label at bottom
const labelY = baseY + 55;
svg.append("text")
.attr("x", x + barWidth / 2)
.attr("y", labelY)
.attr("text-anchor", "middle")
.attr("fill", penalized ? diagramTheme.highlight : diagramTheme.nodeText)
.attr("font-size", "13px")
.attr("font-weight", penalized ? "700" : "500")
.text(`"${token}"`);
// Penalized indicator
if (penalized) {
svg.append("text")
.attr("x", x + barWidth / 2)
.attr("y", labelY + 16)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "9px")
.attr("opacity", 0.8)
.text("(penalized)");
}
// Original bar (left, semi-transparent)
const origHeight = Math.abs(originalLogit) * scale;
const origY = originalLogit >= 0 ? baseY - origHeight : baseY;
svg.append("rect")
.attr("x", x)
.attr("y", origY)
.attr("width", barWidth / 2 - 2)
.attr("height", origHeight)
.attr("fill", diagramTheme.accent)
.attr("opacity", 0.4)
.attr("rx", 3);
// Penalized bar (right, full opacity)
const penHeight = Math.abs(penalizedLogit) * scale;
const penY = penalizedLogit >= 0 ? baseY - penHeight : baseY;
const barColor = penalized ? diagramTheme.highlight : diagramTheme.accent;
svg.append("rect")
.attr("x", x + barWidth / 2 + 2)
.attr("y", penY)
.attr("width", barWidth / 2 - 2)
.attr("height", penHeight)
.attr("fill", barColor)
.attr("rx", 3);
// Value labels above/below bars
// Original value
const origLabelY = originalLogit >= 0 ? origY - 8 : origY + origHeight + 14;
svg.append("text")
.attr("x", x + barWidth / 4)
.attr("y", origLabelY)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.accent)
.attr("font-size", "10px")
.attr("opacity", 0.7)
.text(originalLogit.toFixed(1));
// Penalized value
const penLabelY = penalizedLogit >= 0 ? penY - 8 : penY + penHeight + 14;
svg.append("text")
.attr("x", x + barWidth * 3 / 4)
.attr("y", penLabelY)
.attr("text-anchor", "middle")
.attr("fill", barColor)
.attr("font-size", "10px")
.attr("font-weight", "600")
.text(penalizedLogit.toFixed(2));
// Show math formula for penalized tokens
if (penalized && data.penalty > 1.0) {
const formula = originalLogit > 0
? `${originalLogit.toFixed(1)} ÷ ${data.penalty.toFixed(2)}`
: `${originalLogit.toFixed(1)} × ${data.penalty.toFixed(2)}`;
svg.append("text")
.attr("x", x + barWidth / 2)
.attr("y", baseY + 35)
.attr("text-anchor", "middle")
.attr("fill", diagramTheme.highlight)
.attr("font-size", "9px")
.attr("opacity", 0.9)
.text(formula);
}
});
// Zero line
svg.append("line")
.attr("x1", startX - 20)
.attr("y1", baseY)
.attr("x2", startX + 4 * barSpacing - 30)
.attr("y2", baseY)
.attr("stroke", diagramTheme.nodeStroke)
.attr("stroke-width", 1)
.attr("stroke-dasharray", "4,3");
svg.append("text")
.attr("x", startX - 35)
.attr("y", baseY + 4)
.attr("text-anchor", "end")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.attr("opacity", 0.6)
.text("0");
// Legend
const legendY = height - 25;
const legendX = width / 2 - 100;
// Original legend item
svg.append("rect")
.attr("x", legendX)
.attr("y", legendY - 6)
.attr("width", 12)
.attr("height", 12)
.attr("fill", diagramTheme.accent)
.attr("opacity", 0.4)
.attr("rx", 2);
svg.append("text")
.attr("x", legendX + 18)
.attr("y", legendY)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.text("Original");
// After penalty legend item
svg.append("rect")
.attr("x", legendX + 80)
.attr("y", legendY - 6)
.attr("width", 12)
.attr("height", 12)
.attr("fill", diagramTheme.highlight)
.attr("rx", 2);
svg.append("text")
.attr("x", legendX + 98)
.attr("y", legendY)
.attr("dominant-baseline", "central")
.attr("fill", diagramTheme.nodeText)
.attr("font-size", "10px")
.text("After penalty");
return svg.node();
}
```
The repetition penalty works as follows:
- For tokens that have appeared before:
- If the logit is **positive**, divide by the penalty (reduces probability)
- If the logit is **negative**, multiply by the penalty (makes it more negative)
- Penalty = 1.0 means no change
- Penalty > 1.0 discourages repetition (common values: 1.1 - 1.5)
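The rule above fits in a few lines. The sketch below is a standalone, dependency-free version of the logic (the helper name `penalize_repeats` is illustrative, not part of any library):

```{python}
def penalize_repeats(logits, previous_ids, penalty=1.3):
    """Positive logits are divided by the penalty, negative ones multiplied --
    both moves push previously seen tokens toward lower probability."""
    out = list(logits)
    for tok in set(previous_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out

# "the" (logit 2.0) and "cat" (logit -0.5) already appeared
print(penalize_repeats([2.0, -0.5, 1.0, 1.5], previous_ids=[0, 1], penalty=1.25))
# [1.6, -0.625, 1.0, 1.5]
```

Note that the division/multiplication split exists precisely because logits can be negative: dividing a negative logit by a penalty > 1 would *raise* its probability.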
```{python}
# Demonstrate repetition penalty
logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, -0.5, -1.0]])
previous_tokens = torch.tensor([[0, 1, 4]]) # Tokens 0, 1, and 4 appeared
print("Original logits:")
for i, l in enumerate(logits[0]):
marker = " (appeared)" if i in [0, 1, 4] else ""
print(f" Token {i}: {l.item():.2f}{marker}")
# Apply penalty
penalized = apply_repetition_penalty(logits, previous_tokens, penalty=1.5)
print("\nAfter repetition penalty (1.5):")
for i, l in enumerate(penalized[0]):
marker = " (appeared)" if i in [0, 1, 4] else ""
print(f" Token {i}: {l.item():.2f}{marker}")
# Compare probabilities
orig_probs = F.softmax(logits, dim=-1)
new_probs = F.softmax(penalized, dim=-1)
print("\nProbability changes:")
for i in [0, 1, 2]:
print(f" Token {i}: {orig_probs[0,i].item():.3f} -> {new_probs[0,i].item():.3f}")
```
```{python}
# Generate with and without repetition penalty
print("Generation without repetition penalty:")
out = generate_greedy(model, prompt, max_new_tokens=30)
tokens = out[0].tolist()
from collections import Counter
counts = Counter(tokens)
print(f" Tokens: {tokens[5:]}")
print(f" Most common: {counts.most_common(3)}")
print("\nGeneration with repetition penalty (1.3):")
out = generate(model, prompt, max_new_tokens=30, do_sample=False, repetition_penalty=1.3)
tokens = out[0].tolist()
counts = Counter(tokens)
print(f" Tokens: {tokens[5:]}")
print(f" Most common: {counts.most_common(3)}")
```
**When to use repetition penalty:**
- Always for open-ended generation (stories, chat)
- Less critical for short, structured outputs (classification, extraction)
- Typical values: 1.1 for mild effect, 1.3-1.5 for stronger effect
- Too high (> 2.0) can make outputs incoherent
## Stop Conditions
Generation needs to know when to stop. There are two main stopping conditions:
1. **Maximum length** (`max_new_tokens`) - Hard limit on generated tokens
2. **EOS token** (`eos_token_id`) - Stop when a special end-of-sequence token is generated
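Both conditions combine into one loop. The sketch below isolates the control flow; `next_token` is a stand-in (an assumption, not this module's API) for any function that picks the next token id from the sequence so far:

```{python}
def generate_until(next_token, prompt_ids, max_new_tokens, eos_token_id=None):
    """Append tokens until the budget runs out or EOS appears."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        tok = next_token(ids)
        ids.append(tok)
        if eos_token_id is not None and tok == eos_token_id:
            break  # EOS stops generation even with budget remaining
    return ids

# Toy "model" that always emits the next integer, so 7 plays the role of EOS
toy = lambda ids: ids[-1] + 1
print(generate_until(toy, [4], max_new_tokens=10, eos_token_id=7))  # [4, 5, 6, 7]
```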
```{ojs}
//| echo: false
// Configuration for the stopping conditions simulation
stoppingConfig = ({
maxTokens: 8,
eosToken: "EOS",
eosPosition: 6, // index of EOS in tokens (so it is generated at step 7)
tokens: ["The", "cat", "sat", "on", "the", "mat", "EOS", ".", "and"]
})
// Current step in the simulation (0 = start, then each step is a generation cycle)
viewof stoppingStep = Inputs.range([0, stoppingConfig.maxTokens], {
step: 1,
value: 0,
label: "Generation step"
})
// Determine what state we're in at the current step
stoppingState = {
const step = stoppingStep;
const config = stoppingConfig;
// Build token sequence up to current step
const generatedTokens = config.tokens.slice(0, step);
const currentToken = step > 0 ? config.tokens[step - 1] : null;
// Check conditions
const hitMaxTokens = step >= config.maxTokens;
const hitEOS = currentToken === config.eosToken;
const shouldStop = hitMaxTokens || hitEOS;
// Determine which node is active in the flowchart
let activeNode;
let phase;
if (step === 0) {
activeNode = "start";
phase = "ready";
} else if (hitEOS) {
activeNode = "eos-check";
phase = "eos-stop";
} else if (hitMaxTokens) {
activeNode = "max-check";
phase = "max-stop";
} else {
// Still generating - cycle through the flow
activeNode = "generate";
phase = "generating";
}
return {
step,
generatedTokens,
currentToken,
hitMaxTokens,
hitEOS,
shouldStop,
activeNode,
phase,
tokenCount: step,
maxTokens: config.maxTokens
};
}
// Render the interactive stopping conditions diagram
stoppingDiagram = {
const state = stoppingState;
const theme = diagramTheme;
const width = 650;
const height = 420;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`);
// Background
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", theme.bg)
.attr("rx", 8);
// Defs for markers and filters
const defs = svg.append("defs");
// Arrow markers
const createMarker = (id, color) => {
defs.append("marker")
.attr("id", id)
.attr("viewBox", "0 -5 10 10")
.attr("refX", 8)
.attr("refY", 0)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,-5L10,0L0,5")
.attr("fill", color);
};
createMarker("stop-arrow-normal", theme.edgeStroke);
createMarker("stop-arrow-highlight", theme.highlight);
createMarker("stop-arrow-accent", theme.accent);
// Glow filter
const filter = defs.append("filter")
.attr("id", "stop-glow")
.attr("x", "-50%")
.attr("y", "-50%")
.attr("width", "200%")
.attr("height", "200%");
filter.append("feGaussianBlur")
.attr("stdDeviation", "3")
.attr("result", "blur");
filter.append("feMerge")
.selectAll("feMergeNode")
.data(["blur", "SourceGraphic"])
.enter()
.append("feMergeNode")
.attr("in", d => d);
// Node positions
const nodes = {
start: { x: 100, y: 60, label: "Start", sublabel: "Begin generation" },
generate: { x: 100, y: 160, label: "Generate Token", sublabel: "Get next prediction" },
maxCheck: { x: 280, y: 160, label: "max_tokens?", sublabel: `${state.tokenCount}/${state.maxTokens}`, isDecision: true },
eosCheck: { x: 450, y: 160, label: "EOS token?", sublabel: state.currentToken || "—", isDecision: true },
continue: { x: 450, y: 280, label: "Continue", sublabel: "Loop back" },
stop: { x: 280, y: 340, label: "STOP", sublabel: "Return sequence", isTerminal: true }
};
// Determine active states
const getNodeState = (nodeKey) => {
if (state.phase === "ready" && nodeKey === "start") return "active";
if (state.phase === "generating" && nodeKey === "generate") return "active";
if (state.phase === "generating" && nodeKey === "maxCheck") return "pending";
if (state.phase === "generating" && nodeKey === "eosCheck") return "pending";
if (state.phase === "generating" && nodeKey === "continue") return "pending";
if (state.phase === "max-stop" && nodeKey === "maxCheck") return "triggered";
if (state.phase === "max-stop" && nodeKey === "stop") return "active";
if (state.phase === "eos-stop" && nodeKey === "eosCheck") return "triggered";
if (state.phase === "eos-stop" && nodeKey === "stop") return "active";
return "inactive";
};
// Draw edges first
const edges = [
{ from: "start", to: "generate" },
{ from: "generate", to: "maxCheck" },
{ from: "maxCheck", to: "eosCheck", label: "No" },
{ from: "maxCheck", to: "stop", label: "Yes", condition: "max" },
{ from: "eosCheck", to: "continue", label: "No" },
{ from: "eosCheck", to: "stop", label: "Yes", condition: "eos" },
{ from: "continue", to: "generate", curved: true }
];
const edgeGroup = svg.append("g").attr("class", "edges");
edges.forEach(edge => {
const from = nodes[edge.from];
const to = nodes[edge.to];
// Determine if this edge should be highlighted
let isHighlighted = false;
let isTriggered = false;
if (state.phase === "max-stop" && edge.condition === "max") {
isTriggered = true;
} else if (state.phase === "eos-stop" && edge.condition === "eos") {
isTriggered = true;
} else if (state.phase === "generating") {
// Highlight the normal flow
if (edge.from === "start" || edge.from === "generate" ||
(edge.from === "maxCheck" && edge.label === "No") ||
(edge.from === "eosCheck" && edge.label === "No") ||
edge.from === "continue") {
isHighlighted = true;
}
}
const color = isTriggered ? theme.highlight : (isHighlighted ? theme.accent : theme.edgeStroke);
const markerId = isTriggered ? "stop-arrow-highlight" : (isHighlighted ? "stop-arrow-accent" : "stop-arrow-normal");
const strokeWidth = isTriggered ? 3 : (isHighlighted ? 2 : 1.5);
// Calculate path
let pathD;
const nodeWidth = 110;
const nodeHeight = 50;
if (edge.curved) {
// Loop back arrow
const startX = from.x;
const startY = from.y - nodeHeight/2;
const endX = to.x;
const endY = to.y + nodeHeight/2;
const cpX = (startX + endX) / 2 - 80;
pathD = `M${startX},${startY} C${cpX},${startY - 40} ${cpX},${endY + 40} ${endX},${endY}`;
} else {
// Straight lines with offsets
let x1 = from.x, y1 = from.y, x2 = to.x, y2 = to.y;
// Adjust start/end points based on direction
if (Math.abs(x2 - x1) > Math.abs(y2 - y1)) {
// Horizontal-ish
x1 += (x2 > x1 ? nodeWidth/2 : -nodeWidth/2);
x2 -= (x2 > x1 ? nodeWidth/2 + 8 : -nodeWidth/2 - 8);
} else {
// Vertical-ish
y1 += (y2 > y1 ? nodeHeight/2 : -nodeHeight/2);
y2 -= (y2 > y1 ? nodeHeight/2 + 8 : -nodeHeight/2 - 8);
}
pathD = `M${x1},${y1} L${x2},${y2}`;
}
const path = edgeGroup.append("path")
.attr("d", pathD)
.attr("fill", "none")
.attr("stroke", color)
.attr("stroke-width", strokeWidth)
.attr("marker-end", `url(#${markerId})`);
if (isTriggered) {
path.attr("filter", "url(#stop-glow)");
}
// Edge label
if (edge.label) {
const midX = (from.x + to.x) / 2;
const midY = (from.y + to.y) / 2;
let labelX = midX, labelY = midY;
// Offset labels
if (edge.from === "maxCheck" && edge.to === "stop") {
labelX -= 15;
labelY -= 10;
} else if (edge.from === "eosCheck" && edge.to === "stop") {
labelX += 15;
labelY += 10;
} else if (edge.label === "No") {
labelY -= 12;
}
edgeGroup.append("text")
.attr("x", labelX)
.attr("y", labelY)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isTriggered ? theme.highlight : theme.nodeText)
.attr("font-size", "10px")
.attr("font-weight", isTriggered ? "600" : "400")
.text(edge.label);
}
});
// Draw nodes
const nodeGroup = svg.append("g").attr("class", "nodes");
Object.entries(nodes).forEach(([key, node]) => {
const nodeState = getNodeState(key);
const g = nodeGroup.append("g")
.attr("transform", `translate(${node.x}, ${node.y})`);
const w = node.isDecision ? 100 : 110;
const h = 50;
// Determine colors based on state
let fill, stroke, textColor;
if (nodeState === "active") {
fill = theme.accent;
stroke = theme.accent;
textColor = theme.textOnAccent;
} else if (nodeState === "triggered") {
fill = theme.highlight;
stroke = theme.highlight;
textColor = theme.textOnHighlight;
} else if (nodeState === "pending") {
fill = theme.nodeFillHover;
stroke = theme.nodeStroke;
textColor = theme.nodeText;
} else {
fill = theme.nodeFill;
stroke = theme.nodeStroke;
textColor = theme.nodeText;
}
// Draw shape (diamond for decision, rounded rect for others)
if (node.isDecision) {
// Diamond shape
const points = [
[0, -h/2 - 5],
[w/2 + 10, 0],
[0, h/2 + 5],
[-w/2 - 10, 0]
].map(p => p.join(",")).join(" ");
const diamond = g.append("polygon")
.attr("points", points)
.attr("fill", fill)
.attr("stroke", stroke)
.attr("stroke-width", nodeState === "triggered" ? 2.5 : 1.5);
if (nodeState === "active" || nodeState === "triggered") {
diamond.attr("filter", "url(#stop-glow)");
}
} else {
// Rounded rectangle
const rect = g.append("rect")
.attr("x", -w/2)
.attr("y", -h/2)
.attr("width", w)
.attr("height", h)
.attr("rx", node.isTerminal ? 25 : 6)
.attr("fill", fill)
.attr("stroke", stroke)
.attr("stroke-width", nodeState === "triggered" || nodeState === "active" ? 2.5 : 1.5);
if (nodeState === "active" || nodeState === "triggered") {
rect.attr("filter", "url(#stop-glow)");
}
}
// Main label
g.append("text")
.attr("x", 0)
.attr("y", node.sublabel ? -6 : 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "12px")
.attr("font-weight", "600")
.text(node.label);
// Sublabel
if (node.sublabel) {
g.append("text")
.attr("x", 0)
.attr("y", 10)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "10px")
.attr("opacity", 0.85)
.text(node.sublabel);
}
});
// Token sequence display
const tokenY = 390;
const tokenGroup = svg.append("g").attr("transform", `translate(20, ${tokenY})`);
tokenGroup.append("text")
.attr("x", 0)
.attr("y", 0)
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("font-weight", "500")
.text("Generated:");
state.generatedTokens.forEach((token, i) => {
const isEOS = token === "EOS";
const tokenX = 75 + i * 55;
tokenGroup.append("rect")
.attr("x", tokenX - 22)
.attr("y", -12)
.attr("width", 48)
.attr("height", 24)
.attr("rx", 4)
.attr("fill", isEOS ? theme.highlight : theme.bgSecondary)
.attr("stroke", isEOS ? theme.highlight : theme.nodeStroke)
.attr("stroke-width", 1);
tokenGroup.append("text")
.attr("x", tokenX)
.attr("y", 0)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", isEOS ? theme.textOnHighlight : theme.nodeText)
.attr("font-size", "11px")
.attr("font-family", "monospace")
.text(token);
});
// Status message
let statusText = "";
let statusColor = theme.nodeText;
if (state.phase === "ready") {
statusText = "Ready to generate. Use the slider to step through.";
} else if (state.phase === "generating") {
statusText = `Step ${state.step}: Generated "${state.currentToken}" — checking stop conditions...`;
statusColor = theme.accent;
} else if (state.phase === "max-stop") {
statusText = `STOPPED: Reached max_tokens limit (${state.maxTokens})`;
statusColor = theme.highlight;
} else if (state.phase === "eos-stop") {
statusText = `STOPPED: EOS token generated at step ${state.step}`;
statusColor = theme.highlight;
}
svg.append("text")
.attr("x", width / 2)
.attr("y", 25)
.attr("text-anchor", "middle")
.attr("fill", statusColor)
.attr("font-size", "13px")
.attr("font-weight", "500")
.text(statusText);
return svg.node();
}
```
```{python}
# Demonstrate EOS stopping
# In real models, EOS is a special token. Here we'll use token 42 as our "EOS"
eos_id = 42
print(f"Generating with EOS token = {eos_id}")
print(f"(Generation stops early if token {eos_id} is produced)")
# Without EOS
out_no_eos = generate_greedy(model, prompt, max_new_tokens=20)
print(f"\nWithout EOS check: {len(out_no_eos[0]) - 5} new tokens generated")
print(f" Tokens: {out_no_eos[0, 5:].tolist()}")
# With EOS (may stop early if 42 is generated)
out_with_eos = generate_greedy(model, prompt, max_new_tokens=20, eos_token_id=eos_id)
print(f"\nWith EOS check: {len(out_with_eos[0]) - 5} new tokens generated")
print(f" Tokens: {out_with_eos[0, 5:].tolist()}")
if len(out_with_eos[0]) < len(out_no_eos[0]):
print(f" (Stopped early due to EOS token)")
```
**Practical notes on stopping:**
- Always set a reasonable `max_new_tokens` to prevent runaway generation
- EOS tokens are essential for chat/instruction models to indicate response completion
- Batched generation continues until ALL sequences hit a stop condition
- Some APIs support multiple stop sequences (not just EOS)
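A stop-sequence check typically runs on the *decoded* text rather than token ids, because a stop string can span token boundaries. A minimal sketch (function name illustrative):

```{python}
def hits_stop_sequence(text, stop_sequences):
    """True once the decoded text ends with any stop string."""
    return any(text.endswith(s) for s in stop_sequences)

stops = ["\n\n", "###", "</answer>"]
print(hits_stop_sequence("Final answer: 42</answer>", stops))  # True
print(hits_stop_sequence("Still going", stops))                # False
```

In practice the matched stop string is usually trimmed from the returned text as well.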
## KV-Cache Optimization
In generation, we process one new token at a time. Without optimization, we'd recompute attention for ALL previous tokens every step - wasting computation!
```{ojs}
//| echo: false
// Step control for KV-cache comparison
viewof kvStep = Inputs.range([1, 3], {
value: 1,
step: 1,
label: "Generation Step"
})
// Token sequence
kvTokens = ["A", "B", "C", "D"]
// Calculate what's computed at each step for both approaches
kvStepData = {
const step = kvStep;
// Naive approach - recomputes everything up to current position
// Step 1: compute A,B (2 ops)
// Step 2: compute A,B,C (3 ops)
// Step 3: compute A,B,C,D (4 ops)
const naiveCurrentOps = step + 1; // How many computed this step
const naiveTotalOps = step === 1 ? 2 : step === 2 ? 5 : 9; // Running total: 2, 2+3=5, 5+4=9
// Cached approach - only compute new tokens
// Step 1: compute A,B (2 ops) -> cache
// Step 2: compute C (1 op) -> append
// Step 3: compute D (1 op) -> append
const cachedCurrentOps = step === 1 ? 2 : 1;
const cachedTotalOps = step === 1 ? 2 : step === 2 ? 3 : 4; // Running total: 2, 2+1=3, 3+1=4
return {
step,
naive: {
currentOps: naiveCurrentOps,
totalOps: naiveTotalOps,
tokensComputed: kvTokens.slice(0, step + 1),
redundant: step > 1 ? kvTokens.slice(0, step) : []
},
cached: {
currentOps: cachedCurrentOps,
totalOps: cachedTotalOps,
tokensComputed: step === 1 ? ["A", "B"] : [kvTokens[step]],
cached: step > 1 ? kvTokens.slice(0, step) : []
}
};
}
// KV-Cache comparison visualization
kvCacheComparison = {
const width = 700;
const height = 320;
const theme = diagramTheme;
const data = kvStepData;
const svg = d3.create("svg")
.attr("width", width)
.attr("height", height)
.attr("viewBox", `0 0 ${width} ${height}`)
.style("font-family", "'JetBrains Mono', 'Fira Code', monospace");
// Background with subtle gradient
const defs = svg.append("defs");
const bgGrad = defs.append("linearGradient")
.attr("id", "kv-bg-grad")
.attr("x1", "0%").attr("y1", "0%")
.attr("x2", "100%").attr("y2", "100%");
bgGrad.append("stop").attr("offset", "0%").attr("stop-color", theme.bg);
bgGrad.append("stop").attr("offset", "100%").attr("stop-color", theme.bgSecondary);
svg.append("rect")
.attr("width", width)
.attr("height", height)
.attr("fill", "url(#kv-bg-grad)")
.attr("rx", 8);
// Colors for states
const colors = {
redundant: theme.isDark ? "#ef4444" : "#dc2626", // Red for wasted computation
redundantBg: theme.isDark ? "rgba(239,68,68,0.15)" : "rgba(220,38,38,0.1)",
efficient: theme.isDark ? "#22c55e" : "#16a34a", // Green for efficient
efficientBg: theme.isDark ? "rgba(34,197,94,0.15)" : "rgba(22,163,74,0.1)",
cached: theme.isDark ? "#3b82f6" : "#2563eb", // Blue for cached
cachedBg: theme.isDark ? "rgba(59,130,246,0.15)" : "rgba(37,99,235,0.1)",
current: theme.highlight, // Orange for current computation
currentBg: theme.highlightGlow
};
// Layout
const leftX = 175;
const rightX = 525;
const headerY = 30;
const tokenY = 90;
const statsY = 240;
const boxSize = 50;
const boxGap = 8;
// Helper to draw a computation box
function drawBox(g, x, y, label, state) {
const fill = state === "redundant" ? colors.redundantBg :
state === "efficient" ? colors.efficientBg :
state === "cached" ? colors.cachedBg :
state === "current" ? colors.currentBg : theme.nodeFill;
const stroke = state === "redundant" ? colors.redundant :
state === "efficient" ? colors.efficient :
state === "cached" ? colors.cached :
state === "current" ? colors.current : theme.nodeStroke;
const textColor = state === "redundant" ? colors.redundant :
state === "efficient" ? colors.efficient :
state === "cached" ? colors.cached :
state === "current" ? colors.current : theme.nodeText;
const box = g.append("g").attr("transform", `translate(${x}, ${y})`);
box.append("rect")
.attr("x", -boxSize/2)
.attr("y", -boxSize/2)
.attr("width", boxSize)
.attr("height", boxSize)
.attr("rx", 6)
.attr("fill", fill)
.attr("stroke", stroke)
.attr("stroke-width", state === "current" ? 2.5 : 1.5);
if (state === "current") {
box.append("rect")
.attr("x", -boxSize/2)
.attr("y", -boxSize/2)
.attr("width", boxSize)
.attr("height", boxSize)
.attr("rx", 6)
.attr("fill", "none")
.attr("stroke", colors.current)
.attr("stroke-width", 2)
.style("filter", `drop-shadow(0 0 8px ${colors.currentBg})`);
}
box.append("text")
.attr("text-anchor", "middle")
.attr("dominant-baseline", "central")
.attr("fill", textColor)
.attr("font-size", "16px")
.attr("font-weight", "600")
.text(label);
// State indicator for cached boxes
if (state === "cached") {
box.append("text")
.attr("y", boxSize/2 + 12)
.attr("text-anchor", "middle")
.attr("fill", colors.cached)
.attr("font-size", "9px")
.attr("font-weight", "500")
.text("cached");
}
return box;
}
// --- LEFT SIDE: Naive (Without Cache) ---
const naiveGroup = svg.append("g");
// Header
naiveGroup.append("text")
.attr("x", leftX)
.attr("y", headerY)
.attr("text-anchor", "middle")
.attr("fill", colors.redundant)
.attr("font-size", "14px")
.attr("font-weight", "700")
.text("Without KV-Cache");
naiveGroup.append("text")
.attr("x", leftX)
.attr("y", headerY + 18)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("Recomputes everything each step");
// Draw token boxes for naive
const naiveTokens = data.naive.tokensComputed;
const naiveStartX = leftX - ((naiveTokens.length - 1) * (boxSize + boxGap)) / 2;
naiveTokens.forEach((token, i) => {
const x = naiveStartX + i * (boxSize + boxGap);
const isRedundant = data.naive.redundant.includes(token);
const state = isRedundant ? "redundant" : "current";
drawBox(naiveGroup, x, tokenY, token, state);
});
// Step indicator
naiveGroup.append("text")
.attr("x", leftX)
.attr("y", tokenY + 55)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text(`Step ${data.step}: Computing K,V for ${naiveTokens.length} token${naiveTokens.length > 1 ? 's' : ''}`);
if (data.step > 1) {
naiveGroup.append("text")
.attr("x", leftX)
.attr("y", tokenY + 72)
.attr("text-anchor", "middle")
.attr("fill", colors.redundant)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(`${data.naive.redundant.length} redundant computation${data.naive.redundant.length > 1 ? 's' : ''}!`);
}
// --- RIGHT SIDE: Cached (With Cache) ---
const cachedGroup = svg.append("g");
// Header
cachedGroup.append("text")
.attr("x", rightX)
.attr("y", headerY)
.attr("text-anchor", "middle")
.attr("fill", colors.efficient)
.attr("font-size", "14px")
.attr("font-weight", "700")
.text("With KV-Cache");
cachedGroup.append("text")
.attr("x", rightX)
.attr("y", headerY + 18)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.attr("opacity", 0.7)
.text("Only computes new tokens");
// Draw all tokens up to current step
const allCachedTokens = kvTokens.slice(0, data.step + 1);
const cachedStartX = rightX - ((allCachedTokens.length - 1) * (boxSize + boxGap)) / 2;
allCachedTokens.forEach((token, i) => {
const x = cachedStartX + i * (boxSize + boxGap);
const isCached = data.cached.cached.includes(token);
const isCurrent = data.cached.tokensComputed.includes(token);
const state = isCached ? "cached" : (isCurrent ? "efficient" : "current");
drawBox(cachedGroup, x, tokenY, token, state);
});
// Step indicator
cachedGroup.append("text")
.attr("x", rightX)
.attr("y", tokenY + 55)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text(`Step ${data.step}: Computing K,V for ${data.cached.currentOps} token${data.cached.currentOps > 1 ? 's' : ''}`);
if (data.step > 1) {
cachedGroup.append("text")
.attr("x", rightX)
.attr("y", tokenY + 72)
.attr("text-anchor", "middle")
.attr("fill", colors.cached)
.attr("font-size", "10px")
.attr("font-weight", "500")
.text(`${data.cached.cached.length} token${data.cached.cached.length > 1 ? 's' : ''} retrieved from cache`);
}
// --- STATS COMPARISON ---
const statsGroup = svg.append("g");
// Divider line
statsGroup.append("line")
.attr("x1", 50)
.attr("x2", width - 50)
.attr("y1", statsY - 40)
.attr("y2", statsY - 40)
.attr("stroke", theme.nodeStroke)
.attr("stroke-opacity", 0.3)
.attr("stroke-dasharray", "4,4");
// Stats header
statsGroup.append("text")
.attr("x", width / 2)
.attr("y", statsY - 20)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "12px")
.attr("font-weight", "600")
.text("Operation Count Comparison");
// Left stats (Naive)
statsGroup.append("text")
.attr("x", leftX)
.attr("y", statsY + 10)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("This step:");
statsGroup.append("text")
.attr("x", leftX)
.attr("y", statsY + 28)
.attr("text-anchor", "middle")
.attr("fill", colors.redundant)
.attr("font-size", "18px")
.attr("font-weight", "700")
.text(`${data.naive.currentOps} ops`);
statsGroup.append("text")
.attr("x", leftX)
.attr("y", statsY + 52)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Running total:");
statsGroup.append("text")
.attr("x", leftX)
.attr("y", statsY + 70)
.attr("text-anchor", "middle")
.attr("fill", colors.redundant)
.attr("font-size", "18px")
.attr("font-weight", "700")
.text(`${data.naive.totalOps} ops`);
// Right stats (Cached)
statsGroup.append("text")
.attr("x", rightX)
.attr("y", statsY + 10)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("This step:");
statsGroup.append("text")
.attr("x", rightX)
.attr("y", statsY + 28)
.attr("text-anchor", "middle")
.attr("fill", colors.efficient)
.attr("font-size", "18px")
.attr("font-weight", "700")
.text(`${data.cached.currentOps} ops`);
statsGroup.append("text")
.attr("x", rightX)
.attr("y", statsY + 52)
.attr("text-anchor", "middle")
.attr("fill", theme.nodeText)
.attr("font-size", "11px")
.text("Running total:");
statsGroup.append("text")
.attr("x", rightX)
.attr("y", statsY + 70)
.attr("text-anchor", "middle")
.attr("fill", colors.efficient)
.attr("font-size", "18px")
.attr("font-weight", "700")
.text(`${data.cached.totalOps} ops`);
// Savings indicator
const savings = data.naive.totalOps - data.cached.totalOps;
const savingsPercent = Math.round((savings / data.naive.totalOps) * 100);
if (savings > 0) {
statsGroup.append("text")
.attr("x", width / 2)
.attr("y", statsY + 50)
.attr("text-anchor", "middle")
.attr("fill", colors.efficient)
.attr("font-size", "13px")
.attr("font-weight", "600")
.text(`${savingsPercent}% fewer operations`);
}
// Center divider
svg.append("line")
.attr("x1", width / 2)
.attr("x2", width / 2)
.attr("y1", 60)
.attr("y2", height - 20)
.attr("stroke", theme.nodeStroke)
.attr("stroke-opacity", 0.2)
.attr("stroke-width", 1);
return svg.node();
}
// Legend
kvLegend = {
const theme = diagramTheme;
const colors = {
redundant: theme.isDark ? "#ef4444" : "#dc2626",
efficient: theme.isDark ? "#22c55e" : "#16a34a",
cached: theme.isDark ? "#3b82f6" : "#2563eb",
current: theme.highlight
};
const container = htl.html`<div style="
display: flex;
justify-content: center;
gap: 24px;
margin-top: 8px;
font-family: 'JetBrains Mono', 'Fira Code', monospace;
font-size: 11px;
">
<span style="color: ${colors.current}">● Current computation</span>
<span style="color: ${colors.redundant}">● Redundant (wasted)</span>
<span style="color: ${colors.cached}">● Retrieved from cache</span>
<span style="color: ${colors.efficient}">● Efficient (new only)</span>
</div>`;
return container;
}
```
**KV-Cache** stores Key and Value projections from previous tokens:
- Without cache: $O(n^2)$ work per token, $O(n^3)$ total for $n$ tokens
- With cache: $O(n)$ work per token, $O(n^2)$ total for $n$ tokens
This is a crucial optimization for fast inference!
### How KV-Cache Works
In attention, we compute:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
For autoregressive generation:
1. **First forward pass (prompt)**: Compute K, V for all prompt tokens and cache them
2. **Each new token**: Only compute Q, K, V for the new token
3. **Attention**: New Q attends to cached K, V plus new K, V
4. **Update cache**: Append new K, V to the cache
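The steps above can be sketched as a single-head attention step. This is a minimal illustration, not the codebase's implementation; `attend_with_cache` and its tensor layout are assumptions made for the example:

```{python}
import torch
import torch.nn.functional as F

def attend_with_cache(q_new, k_new, v_new, cache):
    """One decoding step of single-head attention with a KV-cache (sketch).

    q_new, k_new, v_new: projections for the NEW token only, shape (batch, 1, d_k).
    cache: dict holding "k" and "v" tensors of shape (batch, t, d_k); starts empty.
    """
    if cache:
        k = torch.cat([cache["k"], k_new], dim=1)  # new K appended to cached K
        v = torch.cat([cache["v"], v_new], dim=1)
    else:
        k, v = k_new, v_new
    cache["k"], cache["v"] = k, v                  # step 4: update the cache
    d_k = q_new.size(-1)
    scores = q_new @ k.transpose(-2, -1) / d_k ** 0.5  # (batch, 1, t+1)
    return F.softmax(scores, dim=-1) @ v               # (batch, 1, d_k)
```

Note that no causal mask is needed here: during cached decoding the newest token is allowed to attend to every position before it, which is exactly what attending over the full cache does.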
### Memory Tradeoff
KV-cache trades memory for speed:
| Aspect | Without Cache | With Cache |
|--------|---------------|------------|
| Computation per token | $O(n^2)$ | $O(n)$ |
| Memory | $O(1)$ extra | $O(n \cdot \text{layers} \cdot d)$ |
| Total time for $n$ tokens | $O(n^3)$ | $O(n^2)$ |
For a model with:
- 32 layers, d_model = 4096, 8K context
- KV cache size = 2 (K and V) × 32 × 4096 × 8192 × 2 bytes (float16) = ~4GB per sequence
This is why long-context models need significant GPU memory!
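The arithmetic is easy to check in code. This assumes the keys and values use the full `d_model` width per layer (i.e. no grouped-query attention, which would shrink the cache):

```{python}
def kv_cache_bytes(n_layers, d_model, context_len, bytes_per_elem=2):
    """KV-cache size for one sequence: 2 tensors (K and V) per layer."""
    return 2 * n_layers * d_model * context_len * bytes_per_elem

size = kv_cache_bytes(n_layers=32, d_model=4096, context_len=8192)
print(f"{size / 2**30:.1f} GiB per sequence")  # → 4.0 GiB
```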
### Practical Considerations
- **Prompt processing**: First pass processes entire prompt (can be batched efficiently)
- **Generation**: Subsequent tokens are generated one at a time (memory-bound)
- **Batch size tradeoff**: Larger batches amortize overhead but need more KV-cache memory
- **Context length**: Longer contexts need more cache memory per sequence
Note: Our simple `generate()` function recomputes everything each step for clarity. Production implementations use KV-caching for efficiency.
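For reference, that naive loop fits in a few lines. This is a sketch, not the module's `generate()`: `model` is assumed to be a callable returning logits of shape `(batch, seq_len, vocab_size)`:

```{python}
import torch

@torch.no_grad()
def generate_simple(model, tokens, max_new_tokens, temperature=1.0):
    """Naive sampling loop: re-runs the model on the FULL sequence each step."""
    for _ in range(max_new_tokens):
        logits = model(tokens)                        # (batch, seq_len, vocab_size)
        next_logits = logits[:, -1, :] / temperature  # only the last position matters
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)  # feed it back as input
    return tokens
```

Every iteration discards the K/V projections it just computed for the prefix; a KV-cached version would keep them and feed the model only `next_token`.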
## Interactive Exploration
Experiment with sampling strategies in real-time. Adjust temperature, top-k, and top-p to see how each reshapes the probability distribution before sampling.
```{ojs}
//| echo: false
// Token labels (simulating a small vocabulary)
tokenLabels = ["the", "a", "to", "of", "and", "in", "is", "it", "for", "that", "on", "was", "with", "as", "be"]
// Pre-defined logits that create a realistic distribution
// High values for common tokens, tapering off for less likely ones
baseLogits = [2.8, 2.1, 1.6, 1.3, 1.0, 0.7, 0.4, 0.1, -0.2, -0.5, -0.9, -1.3, -1.7, -2.2, -2.8]
// Softmax function
function softmax(logits, temperature) {
const scaled = logits.map(x => x / temperature);
const maxVal = Math.max(...scaled);
const exps = scaled.map(x => Math.exp(x - maxVal));
const sum = exps.reduce((a, b) => a + b, 0);
return exps.map(x => x / sum);
}
// Top-k filtering: keep only k highest probability tokens
function applyTopK(probs, k) {
if (k <= 0 || k >= probs.length) return probs.slice();
// Find the k-th highest probability
const sorted = [...probs].sort((a, b) => b - a);
const threshold = sorted[k - 1];
// Zero out probabilities below threshold
const filtered = probs.map(p => p >= threshold ? p : 0);
// Renormalize
const sum = filtered.reduce((a, b) => a + b, 0);
return sum > 0 ? filtered.map(p => p / sum) : filtered;
}
// Top-p (nucleus) filtering: keep smallest set with cumulative prob >= p
function applyTopP(probs, p) {
if (p >= 1.0) return probs.slice();
// Sort indices by probability (descending)
const indexed = probs.map((prob, idx) => ({prob, idx}));
indexed.sort((a, b) => b.prob - a.prob);
// Find cumulative probability cutoff
let cumSum = 0;
const keepIndices = new Set();
for (const item of indexed) {
keepIndices.add(item.idx);
cumSum += item.prob;
if (cumSum >= p) break;
}
// Zero out tokens not in the nucleus
const filtered = probs.map((prob, idx) => keepIndices.has(idx) ? prob : 0);
// Renormalize
const sum = filtered.reduce((a, b) => a + b, 0);
return sum > 0 ? filtered.map(prob => prob / sum) : filtered;
}
// Compute entropy (measure of diversity)
function entropy(probs) {
return -probs.reduce((sum, p) => p > 0 ? sum + p * Math.log2(p) : sum, 0);
}
```
```{ojs}
//| echo: false
// Dark mode detection
isDark = {
const check = () => document.body.classList.contains('quarto-dark');
return check();
}
// Theme colors for light/dark mode
theme = isDark ? {
barKept: '#6b8cae',
barFiltered: '#3a3c3d',
refDot: '#7a8a9a'
} : {
barKept: '#3b82f6',
barFiltered: '#e5e7eb',
refDot: '#94a3b8'
}
```
```{ojs}
//| echo: false
// Input controls
viewof temperature = Inputs.range([0.1, 2.0], {
value: 1.0,
step: 0.1,
label: "Temperature"
})
viewof topK = Inputs.range([0, 15], {
value: 0,
step: 1,
label: "Top-k (0 = off)"
})
viewof topP = Inputs.range([0.1, 1.0], {
value: 1.0,
step: 0.05,
label: "Top-p (1.0 = off)"
})
```
```{ojs}
//| echo: false
// Apply the sampling pipeline: temperature → top-k → top-p
baseProbs = softmax(baseLogits, temperature)
afterTopK = topK > 0 ? applyTopK(baseProbs, topK) : baseProbs
finalProbs = topP < 1.0 ? applyTopP(afterTopK, topP) : afterTopK
// Build data for visualization
chartData = tokenLabels.map((label, i) => ({
token: label,
probability: finalProbs[i],
originalProb: baseProbs[i],
kept: finalProbs[i] > 0,
index: i
}))
// Compute metrics
tokensKept = finalProbs.filter(p => p > 0).length
currentEntropy = entropy(finalProbs)
```
```{ojs}
//| echo: false
Plot = import("https://esm.sh/@observablehq/plot@0.6")
Plot.plot({
title: "Token Probability Distribution After Filtering",
subtitle: `Temperature: ${temperature.toFixed(1)} | Top-k: ${topK || "off"} | Top-p: ${topP.toFixed(2)}`,
width: 650,
height: 350,
marginBottom: 60,
marginLeft: 50,
x: {
label: "Token →",
tickRotate: -45
},
y: {
label: "↑ Probability",
domain: [0, Math.max(0.5, Math.max(...finalProbs) * 1.1)]
},
color: {
domain: [true, false],
range: [theme.barKept, theme.barFiltered]
},
marks: [
// Bars showing filtered probabilities
Plot.barY(chartData, {
x: "token",
y: "probability",
fill: "kept",
tip: true,
title: d => `${d.token}\nProb: ${(d.probability * 100).toFixed(1)}%\nOriginal: ${(d.originalProb * 100).toFixed(1)}%`
}),
// Reference line showing original distribution (before filtering)
Plot.dot(chartData, {
x: "token",
y: "originalProb",
fill: "none",
stroke: theme.refDot,
strokeWidth: 2,
r: 4,
title: d => `Original: ${(d.originalProb * 100).toFixed(1)}%`
}),
Plot.ruleY([0])
]
})
```
```{ojs}
//| echo: false
// Metrics display
md`**Tokens kept:** ${tokensKept} of ${tokenLabels.length} | **Entropy:** ${currentEntropy.toFixed(2)} bits ${currentEntropy < 1.5 ? "(focused)" : currentEntropy > 3 ? "(diverse)" : "(balanced)"}`
```
```{ojs}
//| echo: false
// Legend
md`<span style="color: ${theme.barKept}">■</span> Probability after filtering <span style="color: ${theme.refDot}">○</span> Original probability (before top-k/top-p)`
```
::: {.callout-tip}
## Try This
1. **Temperature effect**: Set top-k=0, top-p=1.0, then slide temperature from 0.1 to 2.0. Watch how low temperature makes "the" dominate, while high temperature flattens the distribution.
2. **Top-k sharpness**: Set temperature=1.0, top-p=1.0, then increase top-k from 1 to 10. Notice how exactly k tokens are kept, regardless of their probabilities.
3. **Top-p adaptation**: Set temperature=1.0, top-k=0, then decrease top-p from 1.0 to 0.5. Notice how top-p keeps more tokens when the distribution is flat, fewer when peaked.
4. **Combined filtering**: Try temperature=0.7, top-k=10, top-p=0.9. This is a realistic configuration for balanced text generation.
:::
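The temperature → top-k → top-p pipeline the demo applies can be mirrored in PyTorch. This is a self-contained sketch over a single 1-D logits vector; it is not the `top_k_filtering`/`top_p_filtering` helpers used in the exercises below:

```{python}
import torch
import torch.nn.functional as F

def filter_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Apply temperature, then top-k, then top-p to a 1-D logits tensor."""
    probs = F.softmax(logits / temperature, dim=-1)
    if top_k > 0:
        cutoff = torch.topk(probs, top_k).values[-1]      # k-th largest probability
        probs = torch.where(probs >= cutoff, probs, torch.zeros_like(probs))
        probs = probs / probs.sum()                       # renormalize survivors
    if top_p < 1.0:
        sorted_probs, order = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        # smallest prefix whose cumulative mass reaches p (keep at least 1 token)
        n_keep = int(torch.searchsorted(cum, top_p).item()) + 1
        mask = torch.zeros_like(probs)
        mask[order[:n_keep]] = 1.0
        probs = probs * mask
        probs = probs / probs.sum()
    return probs

probs = filter_probs(torch.tensor([2.8, 2.1, 1.6, 1.0, 0.4]),
                     temperature=0.7, top_k=4, top_p=0.9)
next_id = torch.multinomial(probs, num_samples=1).item()
```

Because the filters run sequentially, top-p operates on the already-renormalized top-k distribution — the same ordering the interactive demo uses.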
## Exercises
### Exercise 1: Temperature Exploration
Experiment with extreme temperatures and observe the output behavior:
```{python}
# Try very low and very high temperatures
print("Extreme temperature exploration:\n")
for temp in [0.1, 0.5, 1.0, 2.0, 5.0]:
print(f"Temperature = {temp}:")
outputs = set()
for i in range(5):
torch.manual_seed(i)
out = generate(model, prompt, max_new_tokens=8, temperature=temp, do_sample=True)
outputs.add(tuple(out[0, 5:].tolist()))
print(f" {len(outputs)}/5 unique sequences")
# Show one sample
torch.manual_seed(42)
sample = generate(model, prompt, max_new_tokens=8, temperature=temp, do_sample=True)
print(f" Sample: {sample[0, 5:].tolist()}")
print()
```
### Exercise 2: Top-k vs Top-p
Compare how top-k and top-p behave differently:
```{python}
# Compare filtering approaches
print("Top-k vs Top-p filtering:\n")
# Create a bimodal distribution (two likely options)
bimodal_logits = torch.tensor([[3.0, 3.0, -1.0, -1.0, -2.0, -2.0, -3.0, -3.0, -4.0, -4.0]])
print("Original probabilities (bimodal - two equally likely tokens):")
bimodal_probs = F.softmax(bimodal_logits, dim=-1)
for i, p in enumerate(bimodal_probs[0][:5]):
print(f" Token {i}: {p.item():.3f}")
# Top-k=2 keeps exactly 2 tokens
topk_filtered = top_k_filtering(bimodal_logits.clone(), 2)
topk_probs = F.softmax(topk_filtered, dim=-1)
# Top-p=0.5 adapts to distribution
topp_filtered = top_p_filtering(bimodal_logits.clone(), 0.5)
topp_probs = F.softmax(topp_filtered, dim=-1)
print(f"\nTop-k=2 keeps: {(topk_probs > 0).sum().item()} tokens")
print(f"Top-p=0.5 keeps: {(topp_probs > 0).sum().item()} tokens")
print("\n** Key insight: Top-k always keeps exactly k tokens.")
print(" Top-p adapts: it may keep fewer tokens if one dominates.")
```
### Exercise 3: Repetition Penalty Effects
Explore how different repetition penalty values affect generation:
```{python}
# Compare different repetition penalties
print("Repetition penalty comparison (greedy decoding, 40 tokens):\n")
for penalty in [1.0, 1.1, 1.3, 1.5, 2.0]:
out = generate(model, prompt, max_new_tokens=40, do_sample=False, repetition_penalty=penalty)
tokens = out[0, 5:].tolist()
# Count unique tokens
unique_ratio = len(set(tokens)) / len(tokens)
print(f"Penalty = {penalty}:")
print(f" Unique tokens: {len(set(tokens))}/{len(tokens)} ({unique_ratio*100:.0f}%)")
print(f" First 15: {tokens[:15]}")
print()
```
### Exercise 4: Observing Repetition in Long Generation
Observe how greedy decoding can lead to repetition:
```{python}
# Generate longer sequences to see repetition patterns
print("Long greedy generation (may show repetition):\n")
# Generate more tokens
long_output = generate_greedy(model, prompt, max_new_tokens=50)
tokens = long_output[0].tolist()
print(f"Generated sequence ({len(tokens)} tokens):")
print(tokens)
# Count token frequency
from collections import Counter
token_counts = Counter(tokens)
print(f"\nMost common tokens:")
for token, count in token_counts.most_common(5):
print(f" Token {token}: {count} times ({count/len(tokens)*100:.1f}%)")
```
## The Complete Generation Function
Here's the main generation function from our codebase:
```{python}
# Display the signature and key parts
import inspect
from generation import generate
print("generate() function signature:")
print(inspect.signature(generate))
print()
print("Key parameters:")
print(" - model: The language model")
print(" - prompt_tokens: Starting sequence (batch, seq_len)")
print(" - max_new_tokens: How many tokens to generate")
print(" - temperature: Distribution sharpness (default 1.0)")
print(" - top_k: Filter to top k tokens (optional)")
print(" - top_p: Nucleus sampling threshold (optional)")
print(" - do_sample: If False, use greedy decoding")
print(" - eos_token_id: Stop token (optional)")
print(" - repetition_penalty: Penalize repeated tokens (default 1.0)")
```
## Summary
Key takeaways from this module:
1. **Autoregressive generation**: Produce tokens one at a time, feeding each back as input
2. **Greedy decoding**: Always pick the max - deterministic but can be repetitive
3. **Temperature**: Controls randomness - lower is more focused, higher is more diverse
4. **Top-k sampling**: Limits choices to k most likely tokens
5. **Top-p (nucleus) sampling**: Adapts to distribution shape - keeps more tokens when uncertain
6. **Repetition penalty**: Reduces probability of previously-generated tokens to prevent loops
7. **Stop conditions**: Use EOS tokens and max length to control when generation ends
8. **Combine strategies**: Temperature + top-p + repetition penalty is common in practice
9. **KV-cache**: Essential optimization - trades memory for compute, cutting per-token cost from $O(n^2)$ to $O(n)$
### Common Pitfalls
| Problem | Cause | Solution |
|---------|-------|----------|
| Repetitive output | Greedy decoding or low temperature | Use sampling, repetition penalty |
| Incoherent nonsense | Temperature too high | Lower temperature, use top-p |
| Cuts off mid-sentence | max_new_tokens too low | Increase limit, ensure EOS handling |
| Slow generation | No KV-cache | Implement caching (production) |
| Out of memory | Long context + large batch | Reduce batch size or context |
## Conclusion
Congratulations! You've completed the Learn LLM series. You now understand all the building blocks of a language model:
1. **Tensors**: The fundamental data structure
2. **Autograd**: Automatic differentiation for training
3. **Tokenization**: Converting text to numbers
4. **Embeddings**: Learned vector representations
5. **Attention**: The mechanism that lets tokens interact
6. **Transformer**: The complete architecture
7. **Training**: How models learn from data
8. **Generation**: How to produce text from trained models
### What's Next?
- Check out the **minigpt** directory to see everything assembled into a working model
- Train your own small language model on real data
- Explore the [Going Deeper](#going-deeper) resources for advanced topics
### Going Deeper
**Core Papers:**
- [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - Top-p (nucleus) sampling paper
- [Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833) - Sampling strategies for creative text
- [CTRL: A Conditional Transformer Language Model](https://arxiv.org/abs/1909.05858) - Repetition penalty
**Advanced Topics (not covered here):**
- **Beam Search**: Maintain k best partial sequences; better for translation, worse for open-ended generation
- **Speculative Decoding**: Use a small draft model to propose tokens, verify with large model in parallel
- **Structured Generation**: Constrain outputs to valid JSON, code syntax, or grammar rules
- **Contrastive Decoding**: Compare probabilities from expert and amateur models
**Practical Resources:**
- [Transformers Generation Strategies](https://huggingface.co/docs/transformers/generation_strategies) - HuggingFace generation docs
- [vLLM](https://github.com/vllm-project/vllm) - High-performance inference with PagedAttention
- [Outlines](https://github.com/outlines-dev/outlines) - Structured generation library