Sampling and Generation Techniques for Language Models
Overview
In our previous lessons, we explored the architecture of transformer models and their applications in NLP. We now understand how these models can predict the next token in a sequence. But a crucial question remains: once we have these predictions, how do we use them to generate coherent and high-quality text?
This lesson introduces sampling and generation techniques—the methods that transform a language model's raw probability distributions into actual text. These techniques profoundly influence the quality, diversity, and characteristics of the generated content, making them essential knowledge for anyone working with generative language models.
Learning Objectives
After completing this lesson, you will be able to:
- Understand the key challenges in text generation from language models
- Explain different deterministic sampling methods (greedy, beam search)
- Implement various probabilistic sampling approaches (top-k, top-p/nucleus)
- Compare the trade-offs between different sampling techniques
- Apply practical strategies to control generation quality and diversity
- Debug common issues in text generation pipelines
The Text Generation Challenge
From Probabilities to Text
Language models don't directly output text—they produce probability distributions over the vocabulary for each position in a sequence. The fundamental challenge is converting these probabilities into actual token selections.
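To make this concrete, here is a minimal sketch of what a model actually returns, using GPT-2 through Hugging Face Transformers (the model choice and prompt are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load an illustrative small model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer.encode("The weather today is", return_tensors="pt")

# The model outputs one score (logit) per vocabulary token...
logits = model(input_ids=input_ids).logits[0, -1, :]

# ...which softmax turns into a probability distribution; every sampling
# method in this lesson is a different rule for picking a token from it
probs = torch.softmax(logits, dim=-1)
print(probs.shape)  # torch.Size([50257]) — one probability per vocab entry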
Analogy: The Chef's Choice
Imagine a chef deciding what ingredient to add next to a dish. The chef has:
- A recipe (the prompt)
- Experience (the training data)
- Several possible ingredients that could work (token probabilities)
Different selection strategies represent different cooking styles:
- A rigid, by-the-book chef always picks the most obvious ingredient (greedy search)
- A methodical chef considers several alternative combinations before deciding (beam search)
- A creative chef introduces some spontaneity while still following culinary principles (temperature sampling)
- An experimental chef focuses on unusual but still workable ingredients (nucleus sampling)
The Exploration-Exploitation Dilemma
Text generation faces a fundamental trade-off:
- Exploitation: Selecting the most likely tokens to maximize fluency and coherence
- Exploration: Introducing diversity and creativity by considering less likely options
Deterministic Sampling Methods
Greedy Search: Always Choose the Most Likely
The simplest approach is to always select the token with the highest probability at each step.
Algorithm
- Get probability distribution over vocabulary from the model
- Select the token with the highest probability
- Add this token to the generated sequence
- Repeat until end condition is met (e.g., EOS token or maximum length)
Python Implementation
import torch

def greedy_search(model, tokenizer, prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = input_ids[0].tolist()

    for _ in range(max_length):
        # Get model predictions for the current sequence
        outputs = model(input_ids=input_ids)
        next_token_logits = outputs.logits[0, -1, :]

        # Select the token with the highest probability
        next_token_id = torch.argmax(next_token_logits).item()

        # Add the token to the sequence and rebuild the input tensor
        generated.append(next_token_id)
        input_ids = torch.tensor([generated]).long()

        # Stop if we generate the EOS token
        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
Visualization: Greedy Decoding Path
Advantages and Disadvantages
Advantages:
- Simple to implement
- Computationally efficient
- Often produces fluent text for short completions
Disadvantages:
- No exploration of alternative paths
- Prone to repetitive loops ("I'm happy to help you. I'm happy to help you...")
- Can get stuck in local optima
- Generates deterministic output (same prompt always gives same completion)
Beam Search: Exploring Multiple Paths
Beam search improves upon greedy search by keeping track of the top-K most likely sequences at each step.
Algorithm
- Start with the top-K most likely first tokens
- For each candidate sequence, compute probabilities of all possible next tokens
- Keep only the top-K most likely sequences overall
- Repeat until all K beams reach an end condition
- Return the sequence with the highest overall probability
Visualization: Beam Search Tree
Python Implementation
def beam_search(model, tokenizer, prompt, beam_width=5, max_length=50):
    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Initialize beams with the prompt tokens and a log-prob score of 0
    beams = [(input_ids[0].tolist(), 0.0)]
    finished_beams = []

    for _ in range(max_length):
        candidates = []

        # For each existing beam
        for beam_tokens, beam_score in beams:
            # Move finished beams aside instead of expanding them
            if beam_tokens[-1] == tokenizer.eos_token_id:
                finished_beams.append((beam_tokens, beam_score))
                continue

            # Get predictions from the model
            outputs = model(input_ids=torch.tensor([beam_tokens]))
            next_token_logits = outputs.logits[0, -1, :]
            next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

            # Expand each beam with its top-k continuations
            topk_probs, topk_indices = torch.topk(next_token_probs, beam_width)

            for prob, token_id in zip(topk_probs, topk_indices):
                # Accumulate log probabilities for numerical stability
                log_prob = torch.log(prob).item()

                # New sequence and its score
                new_tokens = beam_tokens + [token_id.item()]
                new_score = beam_score + log_prob
                candidates.append((new_tokens, new_score))

        # Stop when enough beams have finished or nothing is left to expand
        if len(finished_beams) >= beam_width or not candidates:
            break

        # Sort candidates by score and keep the best remaining slots
        candidates.sort(key=lambda x: x[1], reverse=True)
        beams = candidates[:beam_width - len(finished_beams)]

    # Add any beams that never produced EOS (avoid double-counting finished ones)
    finished_beams.extend(
        beam for beam in beams if beam[0][-1] != tokenizer.eos_token_id)

    # Sort by score and return the best sequence
    finished_beams.sort(key=lambda x: x[1], reverse=True)
    best_tokens = finished_beams[0][0]

    return tokenizer.decode(best_tokens)
Advantages and Disadvantages
Advantages:
- Considers multiple possible continuations
- Often produces more coherent and fluent text than greedy search
- Can avoid some local optima
Disadvantages:
- Computationally more expensive than greedy search
- Still produces deterministic output
- Favors common, generic language over more interesting text
- Beam search curse: wider beams don't always yield better results
- Prone to length bias (favoring shorter completions)
Comparing Greedy vs. Beam Results
For the same prompt, greedy search typically returns a short, safe completion that can drift into repetition, while beam search returns a more fluent but often more generic continuation. Both are deterministic: re-running either on the same prompt reproduces the same output.
Probabilistic Sampling Methods
Temperature Sampling: Controlling Randomness
Temperature sampling introduces controlled randomness by adjusting the "temperature" parameter of the softmax function.
Mathematical Formulation
Temperature scaling divides each logit by the temperature before the softmax:

P(x_i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

Where:
- z_i is the logit for token i
- T is the temperature parameter
- Lower temperature → more deterministic (as T → 0, approaches greedy search)
- Higher temperature → more random (as T → ∞, approaches a uniform distribution)
Visualization: Effect of Temperature
Python Implementation
def temperature_sampling(model, tokenizer, prompt, temperature=0.7, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = input_ids[0].tolist()

    for _ in range(max_length):
        outputs = model(input_ids=torch.tensor([generated]))
        next_token_logits = outputs.logits[0, -1, :]

        # Apply temperature: divide logits by T before the softmax
        scaled_logits = next_token_logits / temperature

        # Convert to probabilities
        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)

        # Sample from the full distribution
        next_token_id = torch.multinomial(probs, num_samples=1).item()

        generated.append(next_token_id)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
Effects of Different Temperatures
Low temperatures (roughly 0.2-0.7) sharpen the distribution and yield focused, predictable text; a temperature of 1.0 samples from the model's unmodified distribution; high temperatures (above roughly 1.2) flatten it, producing more diverse but less coherent output. The snippet below makes this concrete.
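A toy demonstration with made-up logits (the values are illustrative, not taken from any model) shows how temperature reshapes the distribution:

import torch

# Four hypothetical logits for a four-token vocabulary
logits = torch.tensor([4.0, 3.0, 2.0, 1.0])

for t in [0.5, 1.0, 2.0]:
    probs = torch.softmax(logits / t, dim=-1)
    print(f"T={t}:", [round(p, 3) for p in probs.tolist()])

# T=0.5 concentrates most of the mass on the top token,
# while T=2.0 spreads it much more evenly across the vocabulary.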
Top-K Sampling: Focusing on the Most Likely Options
Top-K sampling limits the sampling pool to only the K most likely next tokens.
Algorithm
- Get probability distribution over vocabulary
- Select the K tokens with highest probability
- Renormalize probabilities among these K tokens
- Sample the next token from this reduced distribution
Visualization: Top-K Filtering
Python Implementation
def top_k_sampling(model, tokenizer, prompt, k=50, temperature=1.0, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = input_ids[0].tolist()

    for _ in range(max_length):
        outputs = model(input_ids=torch.tensor([generated]))
        next_token_logits = outputs.logits[0, -1, :]

        # Apply temperature
        scaled_logits = next_token_logits / temperature

        # Keep only the k highest-scoring tokens
        top_k_logits, top_k_indices = torch.topk(scaled_logits, k)

        # Mask everything outside the top-k with -inf
        filtered_logits = torch.full_like(scaled_logits, float('-inf'))
        filtered_logits.scatter_(0, top_k_indices, top_k_logits)

        # Softmax renormalizes over the surviving tokens
        probs = torch.nn.functional.softmax(filtered_logits, dim=-1)

        # Sample from the filtered distribution
        next_token_id = torch.multinomial(probs, num_samples=1).item()

        generated.append(next_token_id)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
Advantages and Disadvantages
Advantages:
- Controls for extreme randomness
- Prevents selection of very low probability tokens
- Ensures output quality remains above a threshold
Disadvantages:
- Hard cutoff at K can be arbitrary
- Different positions may need different K values
- Hard to choose a universally good K value
Top-p (Nucleus) Sampling: Adaptive Token Selection
Top-p, or nucleus, sampling dynamically selects the smallest set of tokens whose cumulative probability exceeds a threshold p.
Algorithm
- Get probability distribution over vocabulary
- Sort tokens by probability (descending)
- Select the smallest set of tokens whose cumulative probability exceeds p
- Renormalize probabilities among this set
- Sample the next token from this reduced distribution
Visualization: Nucleus Sampling
Python Implementation
def nucleus_sampling(model, tokenizer, prompt, p=0.9, temperature=1.0, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = input_ids[0].tolist()

    for _ in range(max_length):
        outputs = model(input_ids=torch.tensor([generated]))
        next_token_logits = outputs.logits[0, -1, :]

        # Apply temperature
        scaled_logits = next_token_logits / temperature

        # Convert to probabilities
        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)

        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # The nucleus: tokens whose cumulative probability stays within p
        nucleus_indices = cumulative_probs <= p

        # Always include the first token that pushes the total past p,
        # so the nucleus is never empty
        if nucleus_indices.sum() < len(cumulative_probs):
            nucleus_indices[nucleus_indices.sum()] = True

        # Select the nucleus tokens and their probabilities
        nucleus_tokens = sorted_indices[nucleus_indices]
        nucleus_probs = sorted_probs[nucleus_indices]

        # Renormalize probabilities within the nucleus
        nucleus_probs = nucleus_probs / nucleus_probs.sum()

        # Sample from the nucleus distribution
        sample_idx = torch.multinomial(nucleus_probs, num_samples=1).item()
        next_token_id = nucleus_tokens[sample_idx].item()

        generated.append(next_token_id)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
Advantages of Nucleus Sampling
- Adaptive selection: Adjusts to the confidence of the model at each step
- Balance of quality and diversity: Produces interesting text while maintaining coherence
- Works across contexts: Single p value works well for different prompts and generation stages
Advanced Techniques and Considerations
Combining Strategies: The Best of All Worlds
Modern text generation systems often combine multiple techniques:
def advanced_sampling(model, tokenizer, prompt,
                      top_k=50, top_p=0.9, temperature=0.7,
                      repetition_penalty=1.2, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = input_ids[0].tolist()
    past_tokens = set(generated)

    for _ in range(max_length):
        outputs = model(input_ids=torch.tensor([generated]))
        next_token_logits = outputs.logits[0, -1, :]

        # Apply repetition penalty (CTRL-style): shrink the score of
        # already-seen tokens whether the logit is positive or negative
        for token_id in past_tokens:
            if next_token_logits[token_id] > 0:
                next_token_logits[token_id] /= repetition_penalty
            else:
                next_token_logits[token_id] *= repetition_penalty

        # Apply temperature
        scaled_logits = next_token_logits / temperature

        # Filter with top-k
        top_k_logits, top_k_indices = torch.topk(scaled_logits, top_k)
        filtered_logits = torch.full_like(scaled_logits, float('-inf'))
        filtered_logits.scatter_(0, top_k_indices, top_k_logits)

        # Convert to probabilities
        probs = torch.nn.functional.softmax(filtered_logits, dim=-1)

        # Apply nucleus (top-p) filtering on top of the top-k set
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        nucleus_indices = cumulative_probs <= top_p

        if nucleus_indices.sum() < len(cumulative_probs):
            nucleus_indices[nucleus_indices.sum()] = True

        nucleus_tokens = sorted_indices[nucleus_indices]
        nucleus_probs = sorted_probs[nucleus_indices]
        nucleus_probs = nucleus_probs / nucleus_probs.sum()

        # Sample from the combined distribution
        sample_idx = torch.multinomial(nucleus_probs, num_samples=1).item()
        next_token_id = nucleus_tokens[sample_idx].item()

        generated.append(next_token_id)
        past_tokens.add(next_token_id)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated)
Handling Repetition
Repetition is a common problem in text generation. Several techniques address it (a sketch follows the list):
- Repetition penalty: Reduce probability of recently generated tokens
- Frequency penalty: Penalize tokens based on their frequency in the generated text
- Presence penalty: Penalize tokens that have appeared at all in the generated text
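A minimal sketch of frequency and presence penalties applied to logits, in the style of the OpenAI API parameters (the function name and default values here are illustrative):

import torch
from collections import Counter

def apply_penalties(logits, generated_ids,
                    frequency_penalty=0.5, presence_penalty=0.5):
    # Count how often each token has appeared so far
    counts = Counter(generated_ids)
    for token_id, count in counts.items():
        # Frequency penalty scales with how often the token appeared;
        # presence penalty is a flat cost for having appeared at all
        logits[token_id] -= frequency_penalty * count
        logits[token_id] -= presence_penalty
    return logits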
Visualization: Effect of Repetition Penalty
Contrastive Search: Optimizing for Diverse yet Coherent Text
Contrastive search balances model confidence (the candidate token's probability) against a degeneration penalty, the candidate's maximum similarity to the tokens already generated:
- Generate a set of candidate tokens with highest probability
- Select the token that maximizes fluency while minimizing similarity to already generated content
This helps prevent degeneration and promotes more diverse content.
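A per-step sketch of this scoring rule follows. This naive version re-runs the model once per candidate, which is slow in practice; contrastive_search_step is a hypothetical helper, and the alpha and k values are illustrative:

import torch

def contrastive_search_step(model, input_ids, k=5, alpha=0.6):
    outputs = model(input_ids=input_ids, output_hidden_states=True)
    probs = torch.softmax(outputs.logits[0, -1, :], dim=-1)
    topk_probs, topk_ids = torch.topk(probs, k)

    # Last-layer hidden states of everything generated so far
    context_hidden = outputs.hidden_states[-1][0]  # (seq_len, dim)

    best_score, best_id = None, None
    for prob, token_id in zip(topk_probs, topk_ids):
        # Append the candidate and fetch its hidden representation
        candidate_ids = torch.cat([input_ids, token_id.view(1, 1)], dim=-1)
        cand_out = model(input_ids=candidate_ids, output_hidden_states=True)
        cand_hidden = cand_out.hidden_states[-1][0, -1]

        # Degeneration penalty: max cosine similarity to the context
        penalty = torch.nn.functional.cosine_similarity(
            context_hidden, cand_hidden.unsqueeze(0), dim=-1).max()

        # Trade off confidence against self-similarity
        score = ((1 - alpha) * prob - alpha * penalty).item()
        if best_score is None or score > best_score:
            best_score, best_id = score, token_id.item()

    return best_id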
Length Control Strategies
Different applications require different generation lengths; two common controls are sketched after this list:
- Maximum length: Simple hard cutoff
- Minimum length: Prevent early termination by masking EOS token
- Length penalty: Adjust scores based on sequence length (e.g., in beam search)
- Length normalization: Normalize scores by sequence length
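A minimal sketch of minimum-length EOS masking and length normalization (the function names are illustrative):

import torch

def mask_eos_before_min_length(logits, cur_len, min_length, eos_token_id):
    # Forbid the EOS token until the sequence reaches min_length
    if cur_len < min_length:
        logits[eos_token_id] = float('-inf')
    return logits

def length_normalized_score(sum_log_probs, length, length_penalty=1.0):
    # Divide the cumulative log probability by length ** penalty so that
    # beam search does not systematically favor shorter sequences
    return sum_log_probs / (length ** length_penalty)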
Comparison of Sampling Methods
Qualitative Comparison
- Greedy search: fast and fluent for short spans, but deterministic and prone to repetition
- Beam search: more coherent, but generic and still deterministic
- Temperature sampling: tunable randomness over the full vocabulary
- Top-K sampling: bounded randomness, but the cutoff K is fixed and can be arbitrary
- Top-p (nucleus) sampling: adaptive randomness that tracks the model's confidence at each step
Choosing the Right Sampling Method
The best sampling strategy depends on your specific task (illustrative parameter presets follow this list):
- Machine Translation: Beam search with length normalization
- Factual Question Answering: Greedy or narrow beam search
- Creative Writing: Nucleus sampling with temperature tuning
- Story Generation: Combined approaches with repetition penalties
- Dialogue Systems: Nucleus sampling with additional filtering
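These recommendations can be captured as keyword presets for Hugging Face generate()/pipeline calls; the dictionary below is a hypothetical starting point, not tuned values:

# Hypothetical presets; treat the numbers as starting points to tune
TASK_PRESETS = {
    "translation":      dict(num_beams=5, length_penalty=1.0, do_sample=False),
    "factual_qa":       dict(do_sample=False, num_beams=3),
    "creative_writing": dict(do_sample=True, top_p=0.9, temperature=0.9),
    "story_generation": dict(do_sample=True, top_p=0.9, temperature=0.8,
                             repetition_penalty=1.2),
    "dialogue":         dict(do_sample=True, top_p=0.9, top_k=50),
}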
Practical Implementation with Hugging Face
The Hugging Face Transformers library provides a convenient interface for different sampling methods:
from transformers import pipeline, set_seed

# Set up the model
generator = pipeline('text-generation', model='gpt2-medium')
set_seed(42)  # For reproducibility

# Sample prompt
prompt = "In a world where AI has become ubiquitous,"

# Greedy search
greedy_output = generator(prompt, max_length=50, num_return_sequences=1,
                          do_sample=False)

# Beam search
beam_output = generator(prompt, max_length=50, num_return_sequences=1,
                        num_beams=5, early_stopping=True, do_sample=False)

# Temperature sampling
temp_output = generator(prompt, max_length=50, num_return_sequences=3,
                        do_sample=True, temperature=0.7)

# Top-K sampling
topk_output = generator(prompt, max_length=50, num_return_sequences=3,
                        do_sample=True, temperature=1.0, top_k=50)

# Nucleus sampling
nucleus_output = generator(prompt, max_length=50, num_return_sequences=3,
                           do_sample=True, temperature=1.0, top_p=0.9)

# Combined approach
combined_output = generator(prompt, max_length=50, num_return_sequences=3,
                            do_sample=True, temperature=0.7, top_k=50, top_p=0.9,
                            repetition_penalty=1.2)

# Print the results
for method, outputs in [
    ("Greedy", greedy_output),
    ("Beam", beam_output),
    ("Temperature", temp_output),
    ("Top-K", topk_output),
    ("Nucleus", nucleus_output),
    ("Combined", combined_output),
]:
    print(f"\n=== {method} Sampling ===")
    for i, output in enumerate(outputs):
        print(f"Sample {i+1}: {output['generated_text']}")
Evaluating Generation Quality
Evaluating generation quality remains challenging. Common approaches include:
- Human evaluation: Still the gold standard but expensive and subjective
- Perplexity: Measures how well a probability model predicts a sample
- BLEU, ROUGE, METEOR: For tasks with reference outputs like translation
- Diversity metrics: Distinct-n, Self-BLEU, etc.
- Task-specific metrics: Depends on the specific application
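As an example of a diversity metric, here is a minimal Distinct-n implementation (whitespace tokenization is a simplification; real evaluations typically use the model's tokenizer):

def distinct_n(texts, n=2):
    # Fraction of unique n-grams across a set of generated texts;
    # higher values indicate more diverse output
    total, unique = 0, set()
    for text in texts:
        tokens = text.split()
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0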
Summary
In this lesson, we've covered:
- Deterministic methods like greedy search and beam search
- Probabilistic methods including temperature, top-k, and nucleus sampling
- Advanced techniques for handling repetition and improving generation
- Trade-offs between coherence, diversity, and computational cost
- Practical implementations using Python and Hugging Face
Understanding these sampling techniques is crucial for controlling AI-generated content, balancing creativity with coherence, and customizing behavior for specific applications. As language models continue to advance, the methods we use to sample from them will remain a critical area of research and practical importance.
Practice Exercises
1. Sampling Strategy Comparison:
- Generate completions for the same prompt using different sampling methods
- Compare and analyze the outputs for quality, diversity, and creativity
- Determine which method works best for different types of prompts
2. Implementation Challenge:
- Implement top-p sampling from scratch
- Compare your implementation with the Hugging Face version
- Extend the implementation to include repetition penalty
3. Parameter Tuning:
- Experiment with different temperature values (0.2-2.0)
- Try various combinations of top-k and top-p parameters
- Create a graph showing how perplexity and diversity metrics change with different parameters
4. Creative Application:
- Build a simple text generation interface with adjustable sampling parameters
- Create a system that automatically switches sampling methods based on the context
Additional Resources
- How to generate text: using different decoding methods for language generation with Transformers - Hugging Face blog
- The Curious Case of Neural Text Degeneration - Paper introducing nucleus sampling
- A Survey of Methods for Addressing Neural Text Degeneration
- Typical Decoding for Natural Language Generation
- Contrastive Search Is What You Need For Neural Text Generation