Add stuck initiatives audit report

This commit is contained in:
Ezra
2026-04-03 22:42:06 +00:00
parent dc3d975c2f
commit 56aa692d1c
1267 changed files with 1263232 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
---
description: Model evaluation benchmarks, experiment tracking, data curation, tokenizers, and interpretability tools.
---

View File

@@ -0,0 +1,519 @@
---
name: huggingface-tokenizers
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tokenizers, transformers, datasets]
metadata:
hermes:
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
---
# HuggingFace Tokenizers - Fast Tokenization for NLP
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
## When to use HuggingFace Tokenizers
**Use HuggingFace Tokenizers when:**
- Need extremely fast tokenization (<20s per GB of text)
- Training custom tokenizers from scratch
- Want alignment tracking (token → original text position)
- Building production NLP pipelines
- Need to tokenize large corpora efficiently
**Performance**:
- **Speed**: <20 seconds to tokenize 1GB on CPU
- **Implementation**: Rust core with Python/Node.js bindings
- **Efficiency**: 10-100× faster than pure Python implementations
**Use alternatives instead**:
- **SentencePiece**: Language-independent, used by T5/ALBERT
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
## Quick start
### Installation
```bash
# Install tokenizers
pip install tokenizers
# With transformers integration
pip install tokenizers transformers
```
### Load pretrained tokenizer
```python
from tokenizers import Tokenizer
# Load from HuggingFace Hub
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
# Encode text
output = tokenizer.encode("Hello, how are you?")
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
# Decode back
text = tokenizer.decode(output.ids)
print(text) # "hello, how are you?"
```
### Train custom BPE tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize tokenizer with BPE model
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
min_frequency=2
)
# Train on files
files = ["train.txt", "validation.txt"]
tokenizer.train(files, trainer)
# Save
tokenizer.save("my-tokenizer.json")
```
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
### Batch encoding with padding
```python
# Enable padding
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
# Encode batch
texts = ["Hello world", "This is a longer sentence"]
encodings = tokenizer.encode_batch(texts)
for encoding in encodings:
print(encoding.ids)
# [101, 7592, 2088, 102, 3, 3, 3]
# [101, 2023, 2003, 1037, 2936, 6251, 102]
```
## Tokenization algorithms
### BPE (Byte-Pair Encoding)
**How it works**:
1. Start with character-level vocabulary
2. Find most frequent character pair
3. Merge into new token, add to vocabulary
4. Repeat until vocabulary size reached
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=50257,
special_tokens=["<|endoftext|>"],
min_frequency=2
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Handles OOV words well (breaks into subwords)
- Flexible vocabulary size
- Good for morphologically rich languages
**Trade-offs**:
- Tokenization depends on merge order
- May split common words unexpectedly
### WordPiece
**How it works**:
1. Start with character vocabulary
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
3. Merge highest scoring pair
4. Repeat until vocabulary size reached
**Used by**: BERT, DistilBERT, MobileBERT
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import BertNormalizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = Whitespace()
trainer = WordPieceTrainer(
vocab_size=30522,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##"
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
```
**Advantages**:
- Prioritizes meaningful merges (high score = semantically related)
- Used successfully in BERT (state-of-the-art results)
**Trade-offs**:
- Unknown words become `[UNK]` if no subword match
- Saves vocabulary, not merge rules (larger files)
### Unigram
**How it works**:
1. Start with large vocabulary (all substrings)
2. Compute loss for corpus with current vocabulary
3. Remove tokens with minimal impact on loss
4. Repeat until vocabulary size reached
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>"
)
tokenizer.train(files=["data.txt"], trainer=trainer)
```
**Advantages**:
- Probabilistic (finds most likely tokenization)
- Works well for languages without word boundaries
- Handles diverse linguistic contexts
**Trade-offs**:
- Computationally expensive to train
- More hyperparameters to tune
## Tokenization pipeline
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
### Normalization
Clean and standardize text:
```python
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
tokenizer.normalizer = Sequence([
NFD(), # Unicode normalization (decompose)
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Héllo WORLD"
# After normalization: "hello world"
```
**Common normalizers**:
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
- `Lowercase()` - Convert to lowercase
- `StripAccents()` - Remove accents (é → e)
- `Strip()` - Remove whitespace
- `Replace(pattern, content)` - Regex replacement
### Pre-tokenization
Split text into word-like units:
```python
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
# Split on whitespace and punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation()
])
# Input: "Hello, world!"
# After pre-tokenization: ["Hello", ",", "world", "!"]
```
**Common pre-tokenizers**:
- `Whitespace()` - Split on spaces, tabs, newlines
- `ByteLevel()` - GPT-2 style byte-level splitting
- `Punctuation()` - Isolate punctuation
- `Digits(individual_digits=True)` - Split digits individually
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
### Post-processing
Add special tokens for model input:
```python
from tokenizers.processors import TemplateProcessing
# BERT-style: [CLS] sentence [SEP]
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 1),
("[SEP]", 2),
],
)
```
**Common patterns**:
```python
# GPT-2: sentence <|endoftext|>
TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)]
)
# RoBERTa: <s> sentence </s>
TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[("<s>", 0), ("</s>", 2)]
)
```
## Alignment tracking
Track token positions in original text:
```python
output = tokenizer.encode("Hello, world!")
# Get token offsets
for token, offset in zip(output.tokens, output.offsets):
start, end = offset
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
# Output:
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
```
**Use cases**:
- Named entity recognition (map predictions back to text)
- Question answering (extract answer spans)
- Token classification (align labels to original positions)
## Integration with transformers
### Load with AutoTokenizer
```python
from transformers import AutoTokenizer
# AutoTokenizer automatically uses fast tokenizers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Convert custom tokenizer to transformers
```python
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
# ... train tokenizer ...
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
# Use like any transformers tokenizer
outputs = transformers_tokenizer(
"Hello world",
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
```
## Common patterns
### Train from iterator (large datasets)
```python
from datasets import load_dataset
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
# Create batch iterator
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
# Train tokenizer
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # For progress bar
)
```
**Performance**: Processes 1GB in ~10-20 minutes
### Enable truncation and padding
```python
# Enable truncation
tokenizer.enable_truncation(max_length=512)
# Enable padding
tokenizer.enable_padding(
pad_id=tokenizer.token_to_id("[PAD]"),
pad_token="[PAD]",
length=512 # Fixed length, or None for batch max
)
# Encode with both
output = tokenizer.encode("This is a long sentence that will be truncated...")
print(len(output.ids)) # 512
```
### Multi-processing
```python
from tokenizers import Tokenizer
from multiprocessing import Pool
# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")
def encode_batch(texts):
return tokenizer.encode_batch(texts)
# Process large corpus in parallel
with Pool(8) as pool:
# Split corpus into chunks
chunk_size = 1000
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
# Encode in parallel
results = pool.map(encode_batch, chunks)
```
**Speedup**: 5-8× with 8 cores
## Performance benchmarks
### Training speed
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|-------------|-----------------|-----------------|--------------|
| 10 MB | 15 sec | 18 sec | 25 sec |
| 100 MB | 1.5 min | 2 min | 4 min |
| 1 GB | 15 min | 20 min | 40 min |
**Hardware**: 16-core CPU, tested on English Wikipedia
### Tokenization speed
| Implementation | 1 GB corpus | Throughput |
|----------------|-------------|---------------|
| Pure Python | ~20 minutes | ~50 MB/min |
| HF Tokenizers | ~15 seconds | ~4 GB/min |
| **Speedup** | **80×** | **80×** |
**Test**: English text, average sentence length 20 words
### Memory usage
| Task | Memory |
|-------------------------|---------|
| Load tokenizer | ~10 MB |
| Train BPE (30k vocab) | ~200 MB |
| Encode 1M sentences | ~500 MB |
## Supported models
Pre-trained tokenizers available via `from_pretrained()`:
**BERT family**:
- `bert-base-uncased`, `bert-large-cased`
- `distilbert-base-uncased`
- `roberta-base`, `roberta-large`
**GPT family**:
- `gpt2`, `gpt2-medium`, `gpt2-large`
- `distilgpt2`
**T5 family**:
- `t5-small`, `t5-base`, `t5-large`
- `google/flan-t5-xxl`
**Other**:
- `facebook/bart-base`, `facebook/mbart-large-cc25`
- `albert-base-v2`, `albert-xlarge-v2`
- `xlm-roberta-base`, `xlm-roberta-large`
Browse all: https://huggingface.co/models?library=tokenizers
## References
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
## Resources
- **Docs**: https://huggingface.co/docs/tokenizers
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
- **Version**: 0.20.0+
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)

View File

@@ -0,0 +1,653 @@
# Tokenization Algorithms Deep Dive
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
## Byte-Pair Encoding (BPE)
### Algorithm overview
BPE iteratively merges the most frequent pair of tokens in a corpus.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all adjacent token pairs
3. Merge most frequent pair into new token
4. Add new token to vocabulary
5. Update corpus with new token
6. Repeat until vocabulary size reached
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
'l' + 'o': 7
'o' + 'w': 7
...
Merge: 'e' + 's' → 'es'
Updated corpus:
low: 5
lower: 2
newest: 6 → newes|t: 6
widest: 3 → wides|t: 3
Vocabulary: [a-z] + ['es']
```
**Iteration 2**:
```
Count pairs:
'es' + 't': 9 ← most frequent
'l' + 'o': 7
...
Merge: 'es' + 't' → 'est'
Updated corpus:
low: 5
lower: 2
newest: 6 → new|est: 6
widest: 3 → wid|est: 3
Vocabulary: [a-z] + ['es', 'est']
```
**Continue until desired vocabulary size...**
### Tokenization with trained BPE
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
Tokenize "lowest":
```
Step 1: Split into characters
['l', 'o', 'w', 'e', 's', 't']
Step 2: Apply merges in order learned during training
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
- Merge 'lo' + 'w' → 'low' (if learned)
- Merge 'e' + 's' → 'es' (learned)
- Merge 'es' + 't' → 'est' (learned)
Final: ['low', 'est']
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
# Initialize
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Configure trainer
trainer = BpeTrainer(
vocab_size=1000,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
# Train
corpus = [
"This is a sample corpus for BPE training.",
"BPE learns subword units from the training data.",
# ... more sentences
]
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("This is tokenization")
print(output.tokens) # ['This', 'is', 'token', 'ization']
```
### Byte-level BPE (GPT-2 variant)
**Problem**: Standard BPE has limited character coverage (256+ Unicode chars)
**Solution**: Operate on byte level (256 bytes)
```python
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
# This handles ALL possible characters, including emojis
text = "Hello 🌍 世界"
tokens = tokenizer.encode(text).tokens
```
**Advantages**:
- Handles any Unicode character (256 byte coverage)
- No unknown tokens (worst case: bytes)
- Used by GPT-2, GPT-3, BART
**Trade-offs**:
- Slightly worse compression (bytes vs characters)
- More tokens for non-ASCII text
### BPE variants
**SentencePiece BPE**:
- Language-independent (no pre-tokenization)
- Treats input as raw byte stream
- Used by T5, ALBERT, XLNet
**Robust BPE**:
- Dropout during training (randomly skip merges)
- More robust tokenization at inference
- Reduces overfitting to training data
## WordPiece
### Algorithm overview
WordPiece is similar to BPE but uses a different merge selection criterion.
**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all token pairs
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
4. Merge pair with highest score
5. Repeat until vocabulary size reached
### Why different scoring?
**BPE**: Merges most frequent pairs
- "aa" appears 100 times → high priority
- Even if 'a' appears 1000 times alone
**WordPiece**: Merges pairs that are semantically related
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
- Prioritizes pairs that appear together more than expected
### Step-by-step example
**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```
**Iteration 1**:
```
Count frequencies:
'e': 11 (lower: 2, newest: 6, widest: 3)
's': 9
't': 9
...
Count pairs:
'e' + 's': 9 (newest: 6, widest: 3)
'es' + 't': 9 (newest: 6, widest: 3)
...
Compute scores:
score('e' + 's') = 9 / (11 × 9) = 0.091
score('es' + 't') = 9 / (9 × 9) = 0.111 ← highest score
score('l' + 'o') = 7 / (7 × 9) = 0.111 ← tied
Choose: 'es' + 't' → 'est' (or 'lo' if tied)
```
**Key difference**: WordPiece prioritizes rare combinations over frequent ones.
### Tokenization with WordPiece
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', 'est', 'low']`
Tokenize "lowest":
```
Step 1: Find longest matching prefix
'lowest' → 'low' (matches)
Step 2: Find longest match for remainder
'est' → 'est' (matches)
Final: ['low', 'est']
```
**If no match**:
```
Tokenize "unknownword":
'unknownword' → no match
'unknown' → no match
'unkn' → no match
'un' → no match
'u' → no match
→ [UNK]
```
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
# Initialize BERT-style tokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization (lowercase, accent stripping)
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization (whitespace + punctuation)
tokenizer.pre_tokenizer = BertPreTokenizer()
# Configure trainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT vocab size
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##" # BERT uses ##
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization works great!")
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
```
### Subword prefix
**BERT uses `##` prefix**:
```
"unbelievable" → ['un', '##believ', '##able']
```
**Why?**
- Indicates token is a continuation
- Allows reconstruction: remove ##, concatenate
- Helps model distinguish word boundaries
### WordPiece advantages
**Semantic merges**:
- Prioritizes meaningful combinations
- "qu" has high score (always together)
- "qx" has low score (rare combination)
**Better for morphology**:
- Captures affixes: un-, -ing, -ed
- Preserves word stems
**Trade-offs**:
- Slower training than BPE
- More memory (stores vocabulary, not merges)
- Original implementation not open-source (HF reimplementation)
## Unigram
### Algorithm overview
Unigram works backward: start with large vocabulary, remove tokens.
**Training process**:
1. Initialize with large vocabulary (all substrings)
2. Estimate probability of each token (frequency-based)
3. For each token, compute loss increase if removed
4. Remove 10-20% of tokens with lowest loss impact
5. Re-estimate probabilities
6. Repeat until desired vocabulary size
### Probabilistic tokenization
**Unigram assumption**: Each token is independent.
Given vocabulary with probabilities:
```
P('low') = 0.02
P('l') = 0.01
P('o') = 0.015
P('w') = 0.01
P('est') = 0.03
P('e') = 0.02
P('s') = 0.015
P('t') = 0.015
```
Tokenize "lowest":
```
Option 1: ['low', 'est']
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
Option 2: ['l', 'o', 'w', 'est']
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
Option 3: ['low', 'e', 's', 't']
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.0000009
Choose option 1 (highest probability)
```
### Viterbi algorithm
Finding best tokenization is expensive (exponential possibilities).
**Viterbi algorithm** (dynamic programming):
```python
def tokenize_viterbi(word, vocab, probs):
n = len(word)
# dp[i] = (best_prob, best_tokens) for word[:i]
dp = [{} for _ in range(n + 1)]
dp[0] = (0.0, []) # log probability
for i in range(1, n + 1):
best_prob = float('-inf')
best_tokens = []
# Try all possible last tokens
for j in range(i):
token = word[j:i]
if token in vocab:
prob = dp[j][0] + log(probs[token])
if prob > best_prob:
best_prob = prob
best_tokens = dp[j][1] + [token]
dp[i] = (best_prob, best_tokens)
return dp[n][1]
```
**Time complexity**: O(n² × vocab_size) vs O(2^n) brute force
### Implementation
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
# Initialize
tokenizer = Tokenizer(Unigram())
# Configure trainer
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Max token length
n_sub_iterations=2, # EM iterations
shrinking_factor=0.75 # Remove 25% each iteration
)
# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)
# Use
output = tokenizer.encode("Tokenization with Unigram")
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
```
### Unigram advantages
**Probabilistic**:
- Multiple valid tokenizations
- Can sample different tokenizations (data augmentation)
**Subword regularization**:
```python
# Sample different tokenizations
for _ in range(3):
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
print(tokens)
# Output (different each time):
# ['token', 'ization']
# ['tok', 'en', 'ization']
# ['token', 'iz', 'ation']
```
**Language-independent**:
- No word boundaries needed
- Works for CJK languages (Chinese, Japanese, Korean)
- Treats input as character stream
**Trade-offs**:
- Slower training (EM algorithm)
- More hyperparameters
- Larger model (stores probabilities)
## Algorithm comparison
### Training speed
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|------------|--------------|----------------|-------------|
| BPE | 10-15 sec | 1-2 min | 10-20 min |
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
**Tested on**: 16-core CPU, 30k vocab
### Tokenization quality
Tested on English Wikipedia (perplexity measurement):
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|------------|------------|-------------|--------------|
| BPE | 30k | 1.3 | 0.5% |
| WordPiece | 30k | 1.2 | 1.2% |
| Unigram | 8k | 1.5 | 0.3% |
**Key observations**:
- WordPiece: Slightly better compression
- BPE: Lower unknown rate
- Unigram: Smallest vocab, good coverage
### Compression ratio
Characters per token (higher = better compression):
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|----------|-----------|-----------------|--------------|
| English | 4.2 | 4.5 | 3.8 |
| Chinese | 2.1 | 2.3 | 2.5 |
| Arabic | 3.5 | 3.8 | 3.2 |
**Best for each**:
- English: WordPiece
- Chinese: Unigram (language-independent)
- Arabic: WordPiece
### Use case recommendations
**BPE** - Best for:
- English language models
- Code (handles symbols well)
- Fast training needed
- **Models**: GPT-2, GPT-3, RoBERTa, BART
**WordPiece** - Best for:
- Masked language modeling (BERT-style)
- Morphologically rich languages
- Semantic understanding tasks
- **Models**: BERT, DistilBERT, ELECTRA
**Unigram** - Best for:
- Multilingual models
- Languages without word boundaries (CJK)
- Data augmentation via subword regularization
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
## Advanced topics
### Handling rare words
**BPE approach**:
```
"antidisestablishmentarianism"
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
**WordPiece approach**:
```
"antidisestablishmentarianism"
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
```
**Unigram approach**:
```
"antidisestablishmentarianism"
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
### Handling numbers
**Challenge**: Infinite number combinations
**BPE solution**: Byte-level (handles any digit sequence)
```python
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
# Handles any number
"123456789" byte-level tokens
```
**WordPiece solution**: Digit pre-tokenization
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually or as groups
tokenizer.pre_tokenizer = Digits(individual_digits=True)
"123" ['1', '2', '3']
```
**Unigram solution**: Learns common number patterns
```python
# Learns patterns during training
"2023" ['202', '3'] or ['20', '23']
```
### Handling case sensitivity
**Lowercase (BERT)**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
"Hello WORLD" "hello world" ['hello', 'world']
```
**Preserve case (GPT-2)**:
```python
# No case normalization
tokenizer.normalizer = None
"Hello WORLD" ['Hello', 'WORLD']
```
**Cased tokens (RoBERTa)**:
```python
# Learns separate tokens for different cases
Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
```
### Handling emojis and special characters
**Byte-level (GPT-2)**:
```python
tokenizer.pre_tokenizer = ByteLevel()
"Hello 🌍 👋" byte-level representation (always works)
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFKC
tokenizer.normalizer = NFKC()
"é" (composed) "é" (decomposed) normalized to one form
```
## Troubleshooting
### Issue: Poor subword splitting
**Symptom**:
```
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
```
**Solutions**:
1. Increase vocabulary size
2. Train longer (more merge iterations)
3. Lower `min_frequency` threshold
### Issue: Too many unknown tokens
**Symptom**:
```
5% of tokens are [UNK]
```
**Solutions**:
1. Increase vocabulary size
2. Use byte-level BPE (no UNK possible)
3. Verify training corpus is representative
### Issue: Inconsistent tokenization
**Symptom**:
```
"running" → ['run', 'ning']
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
```
**Solutions**:
1. Check normalization consistency
2. Ensure pre-tokenization is deterministic
3. Use Unigram for probabilistic variance
## Best practices
1. **Match algorithm to model architecture**:
- BERT-style → WordPiece
- GPT-style → BPE
- T5-style → Unigram
2. **Use byte-level for multilingual**:
- Handles any Unicode
- No unknown tokens
3. **Test on representative data**:
- Measure compression ratio
- Check unknown token rate
- Inspect sample tokenizations
4. **Version control tokenizers**:
- Save with model
- Document special tokens
- Track vocabulary changes

View File

@@ -0,0 +1,637 @@
# Transformers Integration
Complete guide to using HuggingFace Tokenizers with the Transformers library.
## AutoTokenizer
The easiest way to load tokenizers.
### Loading pretrained tokenizers
```python
from transformers import AutoTokenizer
# Load from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Check if using fast tokenizer (Rust-based)
print(tokenizer.is_fast) # True
# Access underlying tokenizers.Tokenizer
if tokenizer.is_fast:
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
```
### Fast vs slow tokenizers
| Feature | Fast (Rust) | Slow (Python) |
|--------------------------|----------------|---------------|
| Speed | 5-10× faster | Baseline |
| Alignment tracking | ✅ Full support | ❌ Limited |
| Batch processing | ✅ Optimized | ⚠️ Slower |
| Offset mapping | ✅ Yes | ❌ No |
| Installation | `tokenizers` | Built-in |
**Always use fast tokenizers when available.**
### Check available tokenizers
```python
from transformers import TOKENIZER_MAPPING
# List all fast tokenizers
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
if fast is not None:
print(f"{config_class.__name__}: {fast.__name__}")
```
## PreTrainedTokenizerFast
Wrap custom tokenizers for transformers.
### Convert custom tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
# Train custom tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(
vocab_size=30000,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
# Save tokenizer
tokenizer.save("my-tokenizer.json")
# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="my-tokenizer.json",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]"
)
# Save in transformers format
transformers_tokenizer.save_pretrained("my-tokenizer")
```
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
### Use like any transformers tokenizer
```python
# Load
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
# Encode with all transformers features
outputs = tokenizer(
"Hello world",
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
print(outputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
```
## Special tokens
### Default special tokens
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|--------------|---------|---------------|---------|---------|---------|
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
| GPT-2 | - | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | - |
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
| T5 | - | </s> | <pad> | <unk> | - |
### Adding special tokens
```python
# Add new special tokens
special_tokens_dict = {
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} tokens")
# Resize model embeddings
model.resize_token_embeddings(len(tokenizer))
# Use new tokens
text = "This is an image: <|image|>"
tokens = tokenizer.encode(text)
```
### Adding regular tokens
```python
# Add domain-specific tokens
new_tokens = ["COVID-19", "mRNA", "vaccine"]
num_added = tokenizer.add_tokens(new_tokens)
# These are NOT special tokens (can be split if needed)
tokenizer.add_tokens(new_tokens, special_tokens=False)
# These ARE special tokens (never split)
tokenizer.add_tokens(new_tokens, special_tokens=True)
```
## Encoding and decoding
### Basic encoding
```python
# Single sentence
text = "Hello, how are you?"
encoded = tokenizer(text)
print(encoded)
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```
### Batch encoding
```python
# Multiple sentences
texts = ["Hello world", "How are you?", "I am fine"]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
print(encoded['input_ids'])
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
```
### Return tensors
```python
# Return PyTorch tensors
outputs = tokenizer("Hello world", return_tensors="pt")
print(outputs['input_ids'].shape) # torch.Size([1, 5])
# Return TensorFlow tensors
outputs = tokenizer("Hello world", return_tensors="tf")
# Return NumPy arrays
outputs = tokenizer("Hello world", return_tensors="np")
# Return lists (default)
outputs = tokenizer("Hello world", return_tensors=None)
```
### Decoding
```python
# Decode token IDs
ids = [101, 7592, 2088, 102]
text = tokenizer.decode(ids)
print(text) # "[CLS] hello world [SEP]"
# Skip special tokens
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text) # "hello world"
# Batch decode
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
print(texts) # ["hello", "world"]
```
## Padding and truncation
### Padding strategies
```python
# Pad to max length in batch
tokenizer(texts, padding="longest")
# Pad to model max length
tokenizer(texts, padding="max_length", max_length=128)
# No padding
tokenizer(texts, padding=False)
# Pad to multiple of value (for efficient computation)
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
# Result: length will be 128 (already multiple of 8)
```
### Truncation strategies
```python
# Truncate to max length
tokenizer(text, truncation=True, max_length=10)
# Only truncate first sequence (for pairs)
tokenizer(text1, text2, truncation="only_first", max_length=20)
# Only truncate second sequence
tokenizer(text1, text2, truncation="only_second", max_length=20)
# Truncate longest first (default for pairs)
tokenizer(text1, text2, truncation="longest_first", max_length=20)
# No truncation (error if too long)
tokenizer(text, truncation=False)
```
### Stride for long documents
```python
# For documents longer than max_length
text = "Very long document " * 1000
# Encode with overlap
encodings = tokenizer(
text,
max_length=512,
stride=128, # Overlap between chunks
truncation=True,
return_overflowing_tokens=True,
return_offsets_mapping=True
)
# Get all chunks
num_chunks = len(encodings['input_ids'])
print(f"Split into {num_chunks} chunks")
# Each chunk overlaps by stride tokens
for i, chunk in enumerate(encodings['input_ids']):
print(f"Chunk {i}: {len(chunk)} tokens")
```
**Use case**: Long document QA, sliding window inference
## Alignment and offsets
### Offset mapping
```python
# Get character offsets for each token
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
for token, (start, end) in zip(
encoded.tokens(),
encoded['offset_mapping'][0]
):
print(f"{token:10s} → [{start:2d}, {end:2d})")
# Output:
# [CLS] → [ 0, 0)
# Hello → [ 0, 5)
# , → [ 5, 6)
# world → [ 7, 12)
# ! → [12, 13)
# [SEP] → [ 0, 0)
```
### Word IDs
```python
# Get word index for each token
encoded = tokenizer("Hello world", return_offsets_mapping=True)
word_ids = encoded.word_ids()
print(word_ids)
# [None, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER, POS tagging)
### Character to token mapping
```python
text = "Machine learning is awesome"
encoded = tokenizer(text, return_offsets_mapping=True)
# Find token for character position
char_pos = 8 # "l" in "learning"
token_idx = encoded.char_to_token(char_pos)
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
# Character 8 is in token 2: learning
```
**Use case**: Question answering (map answer character span to tokens)
### Sequence pairs
```python
# Encode sentence pair
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
# Get sequence IDs (which sequence each token belongs to)
sequence_ids = encoded.sequence_ids()
print(sequence_ids)
# [None, 0, 0, 0, None, 1, 1, 1, None]
# None = special token, 0 = question, 1 = answer
```
## Model integration
### Use with transformers models
```python
from transformers import AutoModel, AutoTokenizer
import torch
# Load model and tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Tokenize
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
# Get embeddings
last_hidden_state = outputs.last_hidden_state
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
```
### Custom model with custom tokenizer
```python
from transformers import BertConfig, BertModel
# Train custom tokenizer
from tokenizers import Tokenizer, models, trainers
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=30000)
tokenizer.train(files=["data.txt"], trainer=trainer)
# Wrap for transformers
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]"
)
# Create model with custom vocab size
config = BertConfig(vocab_size=30000)
model = BertModel(config)
# Use together
inputs = fast_tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```
### Save and load together
```python
# Save both
model.save_pretrained("my-model")
tokenizer.save_pretrained("my-model")
# Directory structure:
# my-model/
# ├── config.json
# ├── pytorch_model.bin
# ├── tokenizer.json
# ├── tokenizer_config.json
# └── special_tokens_map.json
# Load both
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")
```
## Advanced features
### Multimodal tokenization
```python
from transformers import AutoTokenizer
# LLaVA-style (image + text)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Add image placeholder token
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
# Use in prompt
text = "Describe this image: <image>"
inputs = tokenizer(text, return_tensors="pt")
```
### Template formatting
```python
# Chat template
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi! How can I help?"},
{"role": "user", "content": "What's the weather?"}
]
# Apply chat template (if tokenizer has one)
if hasattr(tokenizer, "apply_chat_template"):
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
```
### Custom template
```python
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# Define chat template
tokenizer.chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
System: {{ message['content'] }}\\n
{%- elif message['role'] == 'user' %}
User: {{ message['content'] }}\\n
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}\\n
{%- endif %}
{%- endfor %}
Assistant:
"""
# Use template
text = tokenizer.apply_chat_template(messages, tokenize=False)
```
## Performance optimization
### Batch processing
```python
# Process large datasets efficiently
from datasets import load_dataset
dataset = load_dataset("imdb", split="train[:1000]")
# Tokenize in batches
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=512
)
# Map over dataset (batched)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
batch_size=1000,
num_proc=4 # Parallel processing
)
```
### Caching
```python
# Enable caching for repeated tokenization
tokenizer = AutoTokenizer.from_pretrained(
"bert-base-uncased",
use_fast=True,
cache_dir="./cache" # Cache tokenizer files
)
# Tokenize with caching
from functools import lru_cache
@lru_cache(maxsize=10000)
def cached_tokenize(text):
return tuple(tokenizer.encode(text))
# Reuses cached results for repeated inputs
```
### Memory efficiency
```python
# For very large datasets, use streaming
from datasets import load_dataset
dataset = load_dataset("pile", split="train", streaming=True)
def process_batch(batch):
# Tokenize
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
# Process tokens...
return tokens
# Process in chunks (memory efficient)
for batch in dataset.batch(batch_size=1000):
processed = process_batch(batch)
```
## Troubleshooting
### Issue: Tokenizer not fast
**Symptom**:
```python
tokenizer.is_fast # False
```
**Solution**: Install tokenizers library
```bash
pip install tokenizers
```
### Issue: Special tokens not working
**Symptom**: Special tokens are split into subwords
**Solution**: Add as special tokens, not regular tokens
```python
# Wrong
tokenizer.add_tokens(["<|image|>"])
# Correct
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
```
### Issue: Offset mapping not available
**Symptom**:
```python
tokenizer("text", return_offsets_mapping=True)
# Error: return_offsets_mapping not supported
```
**Solution**: Use fast tokenizer
```python
from transformers import AutoTokenizer
# Load fast version
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```
### Issue: Padding inconsistent
**Symptom**: Some sequences padded, others not
**Solution**: Specify padding strategy
```python
# Explicit padding
tokenizer(
texts,
padding="max_length", # or "longest"
max_length=128
)
```
## Best practices
1. **Always use fast tokenizers**:
- 5-10× faster
- Full alignment tracking
- Better batch processing
2. **Save tokenizer with model**:
- Ensures reproducibility
- Prevents version mismatches
3. **Use batch processing for datasets**:
- Tokenize with `.map(batched=True)`
- Set `num_proc` for parallelism
4. **Enable caching for repeated inputs**:
- Use `lru_cache` for inference
- Cache tokenizer files with `cache_dir`
5. **Handle special tokens properly**:
- Use `add_special_tokens()` for never-split tokens
- Resize embeddings after adding tokens
6. **Test alignment for downstream tasks**:
- Verify `offset_mapping` is correct
- Test `char_to_token()` on samples
7. **Version control tokenizer config**:
- Save `tokenizer_config.json`
- Document custom templates
- Track vocabulary changes

View File

@@ -0,0 +1,723 @@
# Tokenization Pipeline Components
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
## Pipeline overview
**Full tokenization pipeline**:
```
Raw Text
Normalization (cleaning, lowercasing)
Pre-tokenization (split into words)
Model (apply BPE/WordPiece/Unigram)
Post-processing (add special tokens)
Token IDs
```
**Decoding reverses the process**:
```
Token IDs
Decoder (handle special encodings)
Raw Text
```
## Normalizers
Clean and standardize input text.
### Common normalizers
**Lowercase**:
```python
from tokenizers.normalizers import Lowercase
tokenizer.normalizer = Lowercase()
# Input: "Hello WORLD"
# Output: "hello world"
```
**Unicode normalization**:
```python
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
# NFD: Canonical decomposition
tokenizer.normalizer = NFD()
# "é" → "e" + "́" (separate characters)
# NFC: Canonical composition (default)
tokenizer.normalizer = NFC()
# "e" + "́" → "é" (composed)
# NFKD: Compatibility decomposition
tokenizer.normalizer = NFKD()
# "fi" → "f" + "i"
# NFKC: Compatibility composition
tokenizer.normalizer = NFKC()
# Most aggressive normalization
```
**Strip accents**:
```python
from tokenizers.normalizers import StripAccents
tokenizer.normalizer = StripAccents()
# Input: "café"
# Output: "cafe"
```
**Whitespace handling**:
```python
from tokenizers.normalizers import Strip, StripAccents
# Remove leading/trailing whitespace
tokenizer.normalizer = Strip()
# Input: " hello "
# Output: "hello"
```
**Replace patterns**:
```python
from tokenizers.normalizers import Replace
# Replace newlines with spaces
tokenizer.normalizer = Replace("\\n", " ")
# Input: "hello\\nworld"
# Output: "hello world"
```
### Combining normalizers
```python
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
# BERT-style normalization
tokenizer.normalizer = Sequence([
NFD(), # Unicode decomposition
Lowercase(), # Convert to lowercase
StripAccents() # Remove accents
])
# Input: "Café au Lait"
# After NFD: "Café au Lait" (e + ́)
# After Lowercase: "café au lait"
# After StripAccents: "cafe au lait"
```
### Use case examples
**Case-insensitive model (BERT)**:
```python
from tokenizers.normalizers import BertNormalizer
# All-in-one BERT normalization
tokenizer.normalizer = BertNormalizer(
clean_text=True, # Remove control characters
handle_chinese_chars=True, # Add spaces around Chinese
strip_accents=True, # Remove accents
lowercase=True # Lowercase
)
```
**Case-sensitive model (GPT-2)**:
```python
# Minimal normalization
tokenizer.normalizer = NFC() # Only normalize Unicode
```
**Multilingual (mBERT)**:
```python
# Preserve scripts, normalize form
tokenizer.normalizer = NFKC()
```
## Pre-tokenizers
Split text into word-like units before tokenization.
### Whitespace splitting
```python
from tokenizers.pre_tokenizers import Whitespace
tokenizer.pre_tokenizer = Whitespace()
# Input: "Hello world! How are you?"
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
```
### Punctuation isolation
```python
from tokenizers.pre_tokenizers import Punctuation
tokenizer.pre_tokenizer = Punctuation()
# Input: "Hello, world!"
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Byte-level (GPT-2)
```python
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
# Input: "Hello world"
# Output: Byte-level tokens with Ġ prefix for spaces
# [("ĠHello", ...), ("Ġworld", ...)]
```
**Key feature**: Handles ALL Unicode characters (256 byte combinations)
### Metaspace (SentencePiece)
```python
from tokenizers.pre_tokenizers import Metaspace
tokenizer.pre_tokenizer = Metaspace(replacement="", add_prefix_space=True)
# Input: "Hello world"
# Output: [("▁Hello", ...), ("▁world", ...)]
```
**Used by**: T5, ALBERT (via SentencePiece)
### Digits splitting
```python
from tokenizers.pre_tokenizers import Digits
# Split digits individually
tokenizer.pre_tokenizer = Digits(individual_digits=True)
# Input: "Room 123"
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
# Keep digits together
tokenizer.pre_tokenizer = Digits(individual_digits=False)
# Input: "Room 123"
# Output: [("Room", ...), ("123", ...)]
```
### BERT pre-tokenizer
```python
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer.pre_tokenizer = BertPreTokenizer()
# Splits on whitespace and punctuation, preserves CJK
# Input: "Hello, 世界!"
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
```
### Combining pre-tokenizers
```python
from tokenizers.pre_tokenizers import Sequence, Whitespace, Punctuation
tokenizer.pre_tokenizer = Sequence([
Whitespace(), # Split on whitespace first
Punctuation() # Then isolate punctuation
])
# Input: "Hello, world!"
# After Whitespace: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```
### Pre-tokenizer comparison
| Pre-tokenizer | Use Case | Example |
|-------------------|---------------------------------|--------------------------------------------|
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
## Models
Core tokenization algorithms.
### BPE Model
```python
from tokenizers.models import BPE
model = BPE(
vocab=None, # Or provide pre-built vocab
merges=None, # Or provide merge rules
unk_token="[UNK]", # Unknown token
continuing_subword_prefix="",
end_of_word_suffix="",
fuse_unk=False # Keep unknown tokens separate
)
tokenizer = Tokenizer(model)
```
**Parameters**:
- `vocab`: Dict of token → id
- `merges`: List of merge rules `["a b", "ab c"]`
- `unk_token`: Token for unknown words
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
### WordPiece Model
```python
from tokenizers.models import WordPiece
model = WordPiece(
vocab=None,
unk_token="[UNK]",
max_input_chars_per_word=100, # Max word length
continuing_subword_prefix="##" # BERT-style prefix
)
tokenizer = Tokenizer(model)
```
**Key difference**: Uses `##` prefix for continuing subwords.
### Unigram Model
```python
from tokenizers.models import Unigram
model = Unigram(
vocab=None, # List of (token, score) tuples
unk_id=0, # ID for unknown token
byte_fallback=False # Fall back to bytes if no match
)
tokenizer = Tokenizer(model)
```
**Probabilistic**: Selects tokenization with highest probability.
### WordLevel Model
```python
from tokenizers.models import WordLevel
# Simple word-to-ID mapping (no subwords)
model = WordLevel(
vocab=None,
unk_token="[UNK]"
)
tokenizer = Tokenizer(model)
```
**Warning**: Requires huge vocabulary (one token per word).
## Post-processors
Add special tokens and format output.
### Template processing
**BERT-style** (`[CLS] sentence [SEP]`):
```python
from tokenizers.processors import TemplateProcessing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", 101),
("[SEP]", 102),
],
)
# Single sentence
output = tokenizer.encode("Hello world")
# [101, ..., 102] ([CLS] hello world [SEP])
# Sentence pair
output = tokenizer.encode("Hello", "world")
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
```
**GPT-2 style** (`sentence <|endoftext|>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", 50256),
],
)
```
**RoBERTa style** (`<s> sentence </s>`):
```python
tokenizer.post_processor = TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> </s> $B </s>",
special_tokens=[
("<s>", 0),
("</s>", 2),
],
)
```
**T5 style** (no special tokens):
```python
# T5 doesn't add special tokens via post-processor
tokenizer.post_processor = None
```
### RobertaProcessing
```python
from tokenizers.processors import RobertaProcessing
tokenizer.post_processor = RobertaProcessing(
sep=("</s>", 2),
cls=("<s>", 0),
add_prefix_space=True, # Add space before first token
trim_offsets=True # Trim leading space from offsets
)
```
### ByteLevelProcessing
```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing
tokenizer.post_processor = ByteLevelProcessing(
trim_offsets=True # Remove Ġ from offsets
)
```
## Decoders
Convert token IDs back to text.
### ByteLevel decoder
```python
from tokenizers.decoders import ByteLevel
tokenizer.decoder = ByteLevel()
# Handles byte-level tokens
# ["ĠHello", "Ġworld"] → "Hello world"
```
### WordPiece decoder
```python
from tokenizers.decoders import WordPiece
tokenizer.decoder = WordPiece(prefix="##")
# Removes ## prefix and concatenates
# ["token", "##ization"] → "tokenization"
```
### Metaspace decoder
```python
from tokenizers.decoders import Metaspace
tokenizer.decoder = Metaspace(replacement="", add_prefix_space=True)
# Converts ▁ back to spaces
# ["▁Hello", "▁world"] → "Hello world"
```
### BPEDecoder
```python
from tokenizers.decoders import BPEDecoder
tokenizer.decoder = BPEDecoder(suffix="</w>")
# Removes suffix and concatenates
# ["token", "ization</w>"] → "tokenization"
```
### Sequence decoder
```python
from tokenizers.decoders import Sequence, ByteLevel, Strip
tokenizer.decoder = Sequence([
ByteLevel(), # Decode byte-level first
Strip(' ', 1, 1) # Strip leading/trailing spaces
])
```
## Complete pipeline examples
### BERT tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder
# Model
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
# Normalization
tokenizer.normalizer = BertNormalizer(lowercase=True)
# Pre-tokenization
tokenizer.pre_tokenizer = BertPreTokenizer()
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
# Decoder
tokenizer.decoder = WordPieceDecoder(prefix="##")
# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
# Enable truncation
tokenizer.enable_truncation(max_length=512)
```
### GPT-2 tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing
# Model
tokenizer = Tokenizer(BPE())
# Normalization (minimal)
tokenizer.normalizer = NFC()
# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
# Post-processing
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[("<|endoftext|>", 50256)],
)
# Byte-level decoder
tokenizer.decoder = ByteLevelDecoder()
```
### T5 tokenizer (SentencePiece-style)
```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder
# Model
tokenizer = Tokenizer(Unigram())
# Normalization
tokenizer.normalizer = NFKC()
# Metaspace pre-tokenization
tokenizer.pre_tokenizer = Metaspace(replacement="", add_prefix_space=True)
# No post-processing (T5 doesn't add CLS/SEP)
tokenizer.post_processor = None
# Metaspace decoder
tokenizer.decoder = MetaspaceDecoder(replacement="", add_prefix_space=True)
```
## Alignment tracking
Track token positions in original text.
### Basic alignment
```python
text = "Hello, world!"
output = tokenizer.encode(text)
for token, (start, end) in zip(output.tokens, output.offsets):
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
# Output:
# [CLS] → [ 0, 0): ''
# hello → [ 0, 5): 'Hello'
# , → [ 5, 6): ','
# world → [ 7, 12): 'world'
# ! → [12, 13): '!'
# [SEP] → [ 0, 0): ''
```
### Word-level alignment
```python
# Get word_ids (which word each token belongs to)
encoding = tokenizer.encode("Hello world")
word_ids = encoding.word_ids
print(word_ids)
# [None, 0, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```
**Use case**: Token classification (NER)
```python
# Align predictions to words
predictions = ["O", "B-PER", "I-PER", "O", "O"]
word_predictions = {}
for token_idx, word_idx in enumerate(encoding.word_ids):
if word_idx is not None and word_idx not in word_predictions:
word_predictions[word_idx] = predictions[token_idx]
print(word_predictions)
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
```
### Span alignment
```python
# Find token span for character span
text = "Machine learning is awesome"
char_start, char_end = 8, 16 # "learning"
encoding = tokenizer.encode(text)
# Find token span
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end - 1) + 1
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
# Tokens 2:3 = ['learning']
```
**Use case**: Question answering (extract answer span)
## Custom components
### Custom normalizer
```python
from tokenizers import NormalizedString, Normalizer
class CustomNormalizer:
def normalize(self, normalized: NormalizedString):
# Custom normalization logic
normalized.lowercase()
normalized.replace(" ", " ") # Replace double spaces
# Use custom normalizer
tokenizer.normalizer = CustomNormalizer()
```
### Custom pre-tokenizer
```python
from tokenizers import PreTokenizedString
class CustomPreTokenizer:
def pre_tokenize(self, pretok: PreTokenizedString):
# Custom pre-tokenization logic
pretok.split(lambda i, char: char.isspace())
tokenizer.pre_tokenizer = CustomPreTokenizer()
```
## Troubleshooting
### Issue: Misaligned offsets
**Symptom**: Offsets don't match original text
```python
text = " hello" # Leading spaces
offsets = [(0, 5)] # Expects " hel"
```
**Solution**: Check normalization strips spaces
```python
# Preserve offsets
tokenizer.normalizer = Sequence([
Strip(), # This changes offsets!
])
# Use trim_offsets in post-processor instead
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```
### Issue: Special tokens not added
**Symptom**: No [CLS] or [SEP] in output
**Solution**: Check post-processor is set
```python
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
```
### Issue: Incorrect decoding
**Symptom**: Decoded text has ## or ▁
**Solution**: Set correct decoder
```python
# For WordPiece
tokenizer.decoder = WordPieceDecoder(prefix="##")
# For SentencePiece
tokenizer.decoder = MetaspaceDecoder(replacement="")
```
## Best practices
1. **Match pipeline to model architecture**:
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
- GPT-2 → NFC + ByteLevel + BPE
- T5 → NFKC + Metaspace + Unigram
2. **Test pipeline on sample inputs**:
- Check normalization doesn't over-normalize
- Verify pre-tokenization splits correctly
- Ensure decoding reconstructs text
3. **Preserve alignment for downstream tasks**:
- Use `trim_offsets` instead of stripping in normalizer
- Test `char_to_token()` on sample spans
4. **Document your pipeline**:
- Save complete tokenizer config
- Document special tokens
- Note any custom components

View File

@@ -0,0 +1,565 @@
# Training Custom Tokenizers
Complete guide to training tokenizers from scratch.
## Training workflow
### Step 1: Choose tokenization algorithm
**Decision tree**:
- **GPT-style model** → BPE
- **BERT-style model** → WordPiece
- **Multilingual/No word boundaries** → Unigram
### Step 2: Prepare training data
```python
# Option 1: From files
files = ["train.txt", "validation.txt"]
# Option 2: From Python list
texts = [
"This is the first sentence.",
"This is the second sentence.",
# ... more texts
]
# Option 3: From dataset iterator
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]["text"]
```
### Step 3: Initialize tokenizer
**BPE example**:
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
trainer = BpeTrainer(
vocab_size=50000,
min_frequency=2,
special_tokens=["<|endoftext|>", "<|padding|>"],
show_progress=True
)
```
**WordPiece example**:
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = BertPreTokenizer()
trainer = WordPieceTrainer(
vocab_size=30522,
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
continuing_subword_prefix="##",
show_progress=True
)
```
**Unigram example**:
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
vocab_size=8000,
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
unk_token="<unk>",
show_progress=True
)
```
### Step 4: Train
```python
# From files
tokenizer.train(files=files, trainer=trainer)
# From iterator (recommended for large datasets)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer,
length=len(dataset) # Optional, for progress bar
)
```
**Training time** (30k vocab on 16-core CPU):
- 10 MB: 15-30 seconds
- 100 MB: 1-3 minutes
- 1 GB: 15-30 minutes
- 10 GB: 2-4 hours
### Step 5: Add post-processing
```python
from tokenizers.processors import TemplateProcessing
# BERT-style
tokenizer.post_processor = TemplateProcessing(
single="[CLS] $A [SEP]",
pair="[CLS] $A [SEP] $B [SEP]",
special_tokens=[
("[CLS]", tokenizer.token_to_id("[CLS]")),
("[SEP]", tokenizer.token_to_id("[SEP]")),
],
)
# GPT-2 style
tokenizer.post_processor = TemplateProcessing(
single="$A <|endoftext|>",
special_tokens=[
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
],
)
```
### Step 6: Save
```python
# Save to JSON
tokenizer.save("my-tokenizer.json")
# Save to directory (for transformers)
tokenizer.save("my-tokenizer-dir/tokenizer.json")
# Convert to transformers format
from transformers import PreTrainedTokenizerFast
transformers_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]"
)
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
```
## Trainer configuration
### BpeTrainer parameters
```python
from tokenizers.trainers import BpeTrainer
trainer = BpeTrainer(
vocab_size=30000, # Target vocabulary size
min_frequency=2, # Minimum frequency for merges
special_tokens=["[UNK]"], # Special tokens (added first)
limit_alphabet=1000, # Limit initial alphabet size
initial_alphabet=[], # Pre-defined initial characters
show_progress=True, # Show progress bar
continuing_subword_prefix="", # Prefix for continuing subwords
end_of_word_suffix="" # Suffix for end of words
)
```
**Parameter tuning**:
- **vocab_size**: Start with 30k for English, 50k for multilingual
- **min_frequency**: 2-5 for large corpora, 1 for small
- **limit_alphabet**: Reduce for non-English (CJK languages)
### WordPieceTrainer parameters
```python
from tokenizers.trainers import WordPieceTrainer
trainer = WordPieceTrainer(
vocab_size=30522, # BERT uses 30,522
min_frequency=2,
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
limit_alphabet=1000,
continuing_subword_prefix="##", # BERT-style prefix
show_progress=True
)
```
### UnigramTrainer parameters
```python
from tokenizers.trainers import UnigramTrainer
trainer = UnigramTrainer(
vocab_size=8000, # Typically smaller than BPE/WordPiece
special_tokens=["<unk>", "<s>", "</s>"],
unk_token="<unk>",
max_piece_length=16, # Maximum token length
n_sub_iterations=2, # EM algorithm iterations
shrinking_factor=0.75, # Vocabulary reduction rate
show_progress=True
)
```
## Training from large datasets
### Memory-efficient training
```python
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
# Load dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
# Create iterator (yields batches)
def batch_iterator(batch_size=1000):
batch = []
for sample in dataset:
batch.append(sample["text"])
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
# Initialize tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
# Train (memory efficient - streams data)
tokenizer.train_from_iterator(
batch_iterator(),
trainer=trainer
)
```
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
### Multi-file training
```python
import glob
# Find all training files
files = glob.glob("data/train/*.txt")
print(f"Training on {len(files)} files")
# Train on all files
tokenizer.train(files=files, trainer=trainer)
```
### Parallel training (multi-processing)
```python
from multiprocessing import Pool, cpu_count
import os
def train_shard(shard_files):
"""Train tokenizer on a shard of files."""
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000)
tokenizer.train(files=shard_files, trainer=trainer)
return tokenizer.get_vocab()
# Split files into shards
num_shards = cpu_count()
file_shards = [files[i::num_shards] for i in range(num_shards)]
# Train shards in parallel
with Pool(num_shards) as pool:
vocab_shards = pool.map(train_shard, file_shards)
# Merge vocabularies (custom logic needed)
# This is a simplified example - real implementation would merge intelligently
final_vocab = {}
for vocab in vocab_shards:
final_vocab.update(vocab)
```
## Domain-specific tokenizers
### Code tokenizer
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import Sequence, NFC
# Code-optimized configuration
tokenizer = Tokenizer(BPE())
# Minimal normalization (preserve case, whitespace)
tokenizer.normalizer = NFC() # Only normalize Unicode
# Byte-level pre-tokenization (handles all characters)
tokenizer.pre_tokenizer = ByteLevel()
# Train on code corpus
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["<|endoftext|>", "<|pad|>"],
min_frequency=2
)
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
```
### Medical/scientific tokenizer
```python
# Preserve case and special characters
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
tokenizer = Tokenizer(BPE())
# Minimal normalization
tokenizer.normalizer = NFKC()
# Preserve medical terms
tokenizer.pre_tokenizer = Sequence([
Whitespace(),
Punctuation(behavior="isolated") # Keep punctuation separate
])
trainer = BpeTrainer(
vocab_size=50000,
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
min_frequency=3 # Higher threshold for rare medical terms
)
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
```
### Multilingual tokenizer
```python
# Handle multiple scripts
from tokenizers.normalizers import NFKC, Lowercase, Sequence
tokenizer = Tokenizer(BPE())
# Normalize but don't lowercase (preserves script differences)
tokenizer.normalizer = NFKC()
# Byte-level handles all Unicode
from tokenizers.pre_tokenizers import ByteLevel
tokenizer.pre_tokenizer = ByteLevel()
trainer = BpeTrainer(
vocab_size=100000, # Larger vocab for multiple languages
special_tokens=["<unk>", "<s>", "</s>"],
limit_alphabet=None # No limit (handles all scripts)
)
# Train on multilingual corpus
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
```
## Vocabulary size selection
### Guidelines by task
| Task | Recommended Vocab Size | Rationale |
|-----------------------|------------------------|-----------|
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
| Code | 30,000 - 50,000 | Similar to English |
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
### Vocabulary size impact
**Small vocab (10k)**:
- Pros: Faster training, smaller model, less memory
- Cons: More tokens per sentence, worse OOV handling
**Medium vocab (30k-50k)**:
- Pros: Good balance, standard choice
- Cons: None (recommended default)
**Large vocab (100k+)**:
- Pros: Fewer tokens per sentence, better OOV
- Cons: Slower training, larger embedding table
### Empirical testing
```python
# Train multiple tokenizers with different vocab sizes
vocab_sizes = [10000, 30000, 50000, 100000]
for vocab_size in vocab_sizes:
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=vocab_size)
tokenizer.train(files=["sample.txt"], trainer=trainer)
# Evaluate on test set
test_text = "Test sentence for evaluation..."
tokens = tokenizer.encode(test_text).ids
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
# Example output:
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
```
## Testing tokenizer quality
### Coverage test
```python
# Test on held-out data
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
total_tokens = 0
unk_tokens = 0
unk_id = tokenizer.token_to_id("[UNK]")
for text in test_corpus["text"]:
if text.strip():
encoding = tokenizer.encode(text)
total_tokens += len(encoding.ids)
unk_tokens += encoding.ids.count(unk_id)
unk_rate = unk_tokens / total_tokens
print(f"Unknown token rate: {unk_rate:.2%}")
# Good quality: <1% unknown tokens
# Acceptable: 1-5%
# Poor: >5%
```
### Compression test
```python
# Measure tokenization efficiency
import numpy as np
token_lengths = []
for text in test_corpus["text"][:1000]:
if text.strip():
encoding = tokenizer.encode(text)
chars_per_token = len(text) / len(encoding.ids)
token_lengths.append(chars_per_token)
avg_chars_per_token = np.mean(token_lengths)
print(f"Average characters per token: {avg_chars_per_token:.2f}")
# Good: 4-6 chars/token (English)
# Acceptable: 3-4 chars/token
# Poor: <3 chars/token (under-compression)
```
### Semantic test
```python
# Manually inspect tokenization of common words/phrases
test_phrases = [
"tokenization",
"machine learning",
"artificial intelligence",
"preprocessing",
"hello world"
]
for phrase in test_phrases:
tokens = tokenizer.encode(phrase).tokens
print(f"{phrase:25s}{tokens}")
# Good tokenization:
# tokenization → ['token', 'ization']
# machine learning → ['machine', 'learning']
# artificial intelligence → ['artificial', 'intelligence']
```
## Troubleshooting
### Issue: Training too slow
**Solutions**:
1. Reduce vocabulary size
2. Increase `min_frequency`
3. Use `limit_alphabet` to reduce initial alphabet
4. Train on subset first
```python
# Fast training configuration
trainer = BpeTrainer(
vocab_size=20000, # Smaller vocab
min_frequency=5, # Higher threshold
limit_alphabet=500, # Limit alphabet
show_progress=True
)
```
### Issue: High unknown token rate
**Solutions**:
1. Increase vocabulary size
2. Decrease `min_frequency`
3. Check normalization (might be too aggressive)
```python
# Better coverage configuration
trainer = BpeTrainer(
vocab_size=50000, # Larger vocab
min_frequency=1, # Lower threshold
)
```
### Issue: Poor quality tokenization
**Solutions**:
1. Verify normalization matches your use case
2. Check pre-tokenization splits correctly
3. Ensure training data is representative
4. Try different algorithm (BPE vs WordPiece vs Unigram)
```python
# Debug tokenization pipeline
text = "Sample text to debug"
# Check normalization
normalized = tokenizer.normalizer.normalize_str(text)
print(f"Normalized: {normalized}")
# Check pre-tokenization
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
print(f"Pre-tokens: {pre_tokens}")
# Check final tokenization
tokens = tokenizer.encode(text).tokens
print(f"Tokens: {tokens}")
```
## Best practices
1. **Use representative training data** - Match your target domain
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
3. **Test on held-out data** - Measure unknown token rate
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
5. **Save tokenizer with model** - Ensure reproducibility
6. **Version your tokenizers** - Track changes for reproducibility
7. **Document special tokens** - Critical for model training

View File

@@ -0,0 +1,493 @@
---
name: evaluating-llms-harness
description: Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lm-eval, transformers, vllm]
metadata:
hermes:
tags: [Evaluation, LM Evaluation Harness, Benchmarking, MMLU, HumanEval, GSM8K, EleutherAI, Model Quality, Academic Benchmarks, Industry Standard]
---
# lm-evaluation-harness - LLM Benchmarking
## Quick start
lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.
**Installation**:
```bash
pip install lm-eval
```
**Evaluate any HuggingFace model**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag \
--device cuda:0 \
--batch_size 8
```
**View available tasks**:
```bash
lm_eval --tasks list
```
## Common workflows
### Workflow 1: Standard benchmark evaluation
Evaluate model on core benchmarks (MMLU, GSM8K, HumanEval).
Copy this checklist:
```
Benchmark Evaluation:
- [ ] Step 1: Choose benchmark suite
- [ ] Step 2: Configure model
- [ ] Step 3: Run evaluation
- [ ] Step 4: Analyze results
```
**Step 1: Choose benchmark suite**
**Core reasoning benchmarks**:
- **MMLU** (Massive Multitask Language Understanding) - 57 subjects, multiple choice
- **GSM8K** - Grade school math word problems
- **HellaSwag** - Common sense reasoning
- **TruthfulQA** - Truthfulness and factuality
- **ARC** (AI2 Reasoning Challenge) - Science questions
**Code benchmarks**:
- **HumanEval** - Python code generation (164 problems)
- **MBPP** (Mostly Basic Python Problems) - Python coding
**Standard suite** (recommended for model releases):
```bash
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge
```
**Step 2: Configure model**
**HuggingFace model**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
--tasks mmlu \
--device cuda:0 \
--batch_size auto # Auto-detect optimal batch size
```
**Quantized model (4-bit/8-bit)**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,load_in_4bit=True \
--tasks mmlu \
--device cuda:0
```
**Custom checkpoint**:
```bash
lm_eval --model hf \
--model_args pretrained=/path/to/my-model,tokenizer=/path/to/tokenizer \
--tasks mmlu \
--device cuda:0
```
**Step 3: Run evaluation**
```bash
# Full MMLU evaluation (57 subjects)
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--num_fewshot 5 \ # 5-shot evaluation (standard)
--batch_size 8 \
--output_path results/ \
--log_samples # Save individual predictions
# Multiple benchmarks at once
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge \
--num_fewshot 5 \
--batch_size 8 \
--output_path results/llama2-7b-eval.json
```
**Step 4: Analyze results**
Results saved to `results/llama2-7b-eval.json`:
```json
{
"results": {
"mmlu": {
"acc": 0.459,
"acc_stderr": 0.004
},
"gsm8k": {
"exact_match": 0.142,
"exact_match_stderr": 0.006
},
"hellaswag": {
"acc_norm": 0.765,
"acc_norm_stderr": 0.004
}
},
"config": {
"model": "hf",
"model_args": "pretrained=meta-llama/Llama-2-7b-hf",
"num_fewshot": 5
}
}
```
### Workflow 2: Track training progress
Evaluate checkpoints during training.
```
Training Progress Tracking:
- [ ] Step 1: Set up periodic evaluation
- [ ] Step 2: Choose quick benchmarks
- [ ] Step 3: Automate evaluation
- [ ] Step 4: Plot learning curves
```
**Step 1: Set up periodic evaluation**
Evaluate every N training steps:
```bash
#!/bin/bash
# eval_checkpoint.sh
CHECKPOINT_DIR=$1
STEP=$2
lm_eval --model hf \
--model_args pretrained=$CHECKPOINT_DIR/checkpoint-$STEP \
--tasks gsm8k,hellaswag \
--num_fewshot 0 \ # 0-shot for speed
--batch_size 16 \
--output_path results/step-$STEP.json
```
**Step 2: Choose quick benchmarks**
Fast benchmarks for frequent evaluation:
- **HellaSwag**: ~10 minutes on 1 GPU
- **GSM8K**: ~5 minutes
- **PIQA**: ~2 minutes
Avoid for frequent eval (too slow):
- **MMLU**: ~2 hours (57 subjects)
- **HumanEval**: Requires code execution
**Step 3: Automate evaluation**
Integrate with training script:
```python
# In training loop
if step % eval_interval == 0:
model.save_pretrained(f"checkpoints/step-{step}")
# Run evaluation
os.system(f"./eval_checkpoint.sh checkpoints step-{step}")
```
Or use PyTorch Lightning callbacks:
```python
from pytorch_lightning import Callback
class EvalHarnessCallback(Callback):
def on_validation_epoch_end(self, trainer, pl_module):
step = trainer.global_step
checkpoint_path = f"checkpoints/step-{step}"
# Save checkpoint
trainer.save_checkpoint(checkpoint_path)
# Run lm-eval
os.system(f"lm_eval --model hf --model_args pretrained={checkpoint_path} ...")
```
**Step 4: Plot learning curves**
```python
import json
import matplotlib.pyplot as plt
# Load all results
steps = []
mmlu_scores = []
for file in sorted(glob.glob("results/step-*.json")):
with open(file) as f:
data = json.load(f)
step = int(file.split("-")[1].split(".")[0])
steps.append(step)
mmlu_scores.append(data["results"]["mmlu"]["acc"])
# Plot
plt.plot(steps, mmlu_scores)
plt.xlabel("Training Step")
plt.ylabel("MMLU Accuracy")
plt.title("Training Progress")
plt.savefig("training_curve.png")
```
### Workflow 3: Compare multiple models
Benchmark suite for model comparison.
```
Model Comparison:
- [ ] Step 1: Define model list
- [ ] Step 2: Run evaluations
- [ ] Step 3: Generate comparison table
```
**Step 1: Define model list**
```bash
# models.txt
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-hf
mistralai/Mistral-7B-v0.1
microsoft/phi-2
```
**Step 2: Run evaluations**
```bash
#!/bin/bash
# eval_all_models.sh
TASKS="mmlu,gsm8k,hellaswag,truthfulqa"
while read model; do
echo "Evaluating $model"
# Extract model name for output file
model_name=$(echo $model | sed 's/\//-/g')
lm_eval --model hf \
--model_args pretrained=$model,dtype=bfloat16 \
--tasks $TASKS \
--num_fewshot 5 \
--batch_size auto \
--output_path results/$model_name.json
done < models.txt
```
**Step 3: Generate comparison table**
```python
import json
import pandas as pd
models = [
"meta-llama-Llama-2-7b-hf",
"meta-llama-Llama-2-13b-hf",
"mistralai-Mistral-7B-v0.1",
"microsoft-phi-2"
]
tasks = ["mmlu", "gsm8k", "hellaswag", "truthfulqa"]
results = []
for model in models:
with open(f"results/{model}.json") as f:
data = json.load(f)
row = {"Model": model.replace("-", "/")}
for task in tasks:
# Get primary metric for each task
metrics = data["results"][task]
if "acc" in metrics:
row[task.upper()] = f"{metrics['acc']:.3f}"
elif "exact_match" in metrics:
row[task.upper()] = f"{metrics['exact_match']:.3f}"
results.append(row)
df = pd.DataFrame(results)
print(df.to_markdown(index=False))
```
Output:
```
| Model | MMLU | GSM8K | HELLASWAG | TRUTHFULQA |
|------------------------|-------|-------|-----------|------------|
| meta-llama/Llama-2-7b | 0.459 | 0.142 | 0.765 | 0.391 |
| meta-llama/Llama-2-13b | 0.549 | 0.287 | 0.801 | 0.430 |
| mistralai/Mistral-7B | 0.626 | 0.395 | 0.812 | 0.428 |
| microsoft/phi-2 | 0.560 | 0.613 | 0.682 | 0.447 |
```
### Workflow 4: Evaluate with vLLM (faster inference)
Use vLLM backend for 5-10x faster evaluation.
```
vLLM Evaluation:
- [ ] Step 1: Install vLLM
- [ ] Step 2: Configure vLLM backend
- [ ] Step 3: Run evaluation
```
**Step 1: Install vLLM**
```bash
pip install vllm
```
**Step 2: Configure vLLM backend**
```bash
lm_eval --model vllm \
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8 \
--tasks mmlu \
--batch_size auto
```
**Step 3: Run evaluation**
vLLM is 5-10× faster than standard HuggingFace:
```bash
# Standard HF: ~2 hours for MMLU on 7B model
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--batch_size 8
# vLLM: ~15-20 minutes for MMLU on 7B model
lm_eval --model vllm \
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=2 \
--tasks mmlu \
--batch_size auto
```
## When to use vs alternatives
**Use lm-evaluation-harness when:**
- Benchmarking models for academic papers
- Comparing model quality across standard tasks
- Tracking training progress
- Reporting standardized metrics (everyone uses same prompts)
- Need reproducible evaluation
**Use alternatives instead:**
- **HELM** (Stanford): Broader evaluation (fairness, efficiency, calibration)
- **AlpacaEval**: Instruction-following evaluation with LLM judges
- **MT-Bench**: Conversational multi-turn evaluation
- **Custom scripts**: Domain-specific evaluation
## Common issues
**Issue: Evaluation too slow**
Use vLLM backend:
```bash
lm_eval --model vllm \
--model_args pretrained=model-name,tensor_parallel_size=2
```
Or reduce fewshot examples:
```bash
--num_fewshot 0 # Instead of 5
```
Or evaluate subset of MMLU:
```bash
--tasks mmlu_stem # Only STEM subjects
```
**Issue: Out of memory**
Reduce batch size:
```bash
--batch_size 1 # Or --batch_size auto
```
Use quantization:
```bash
--model_args pretrained=model-name,load_in_8bit=True
```
Enable CPU offloading:
```bash
--model_args pretrained=model-name,device_map=auto,offload_folder=offload
```
**Issue: Different results than reported**
Check fewshot count:
```bash
--num_fewshot 5 # Most papers use 5-shot
```
Check exact task name:
```bash
--tasks mmlu # Not mmlu_direct or mmlu_fewshot
```
Verify model and tokenizer match:
```bash
--model_args pretrained=model-name,tokenizer=same-model-name
```
**Issue: HumanEval not executing code**
Install execution dependencies:
```bash
pip install human-eval
```
Enable code execution:
```bash
lm_eval --model hf \
--model_args pretrained=model-name \
--tasks humaneval \
--allow_code_execution # Required for HumanEval
```
## Advanced topics
**Benchmark descriptions**: See [references/benchmark-guide.md](references/benchmark-guide.md) for detailed description of all 60+ tasks, what they measure, and interpretation.
**Custom tasks**: See [references/custom-tasks.md](references/custom-tasks.md) for creating domain-specific evaluation tasks.
**API evaluation**: See [references/api-evaluation.md](references/api-evaluation.md) for evaluating OpenAI, Anthropic, and other API models.
**Multi-GPU strategies**: See [references/distributed-eval.md](references/distributed-eval.md) for data parallel and tensor parallel evaluation.
## Hardware requirements
- **GPU**: NVIDIA (CUDA 11.8+), works on CPU (very slow)
- **VRAM**:
- 7B model: 16GB (bf16) or 8GB (8-bit)
- 13B model: 28GB (bf16) or 14GB (8-bit)
- 70B model: Requires multi-GPU or quantization
- **Time** (7B model, single A100):
- HellaSwag: 10 minutes
- GSM8K: 5 minutes
- MMLU (full): 2 hours
- HumanEval: 20 minutes
## Resources
- GitHub: https://github.com/EleutherAI/lm-evaluation-harness
- Docs: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs
- Task library: 60+ tasks including MMLU, GSM8K, HumanEval, TruthfulQA, HellaSwag, ARC, WinoGrande, etc.
- Leaderboard: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard (uses this harness)

View File

@@ -0,0 +1,490 @@
# API Evaluation
Guide to evaluating OpenAI, Anthropic, and other API-based language models.
## Overview
The lm-evaluation-harness supports evaluating API-based models through a unified `TemplateAPI` interface. This allows benchmarking of:
- OpenAI models (GPT-4, GPT-3.5, etc.)
- Anthropic models (Claude 3, Claude 2, etc.)
- Local OpenAI-compatible APIs
- Custom API endpoints
**Why evaluate API models**:
- Benchmark closed-source models
- Compare API models to open models
- Validate API performance
- Track model updates over time
## Supported API Models
| Provider | Model Type | Request Types | Logprobs |
|----------|------------|---------------|----------|
| OpenAI (completions) | `openai-completions` | All | ✅ Yes |
| OpenAI (chat) | `openai-chat-completions` | `generate_until` only | ❌ No |
| Anthropic (completions) | `anthropic-completions` | All | ❌ No |
| Anthropic (chat) | `anthropic-chat` | `generate_until` only | ❌ No |
| Local (OpenAI-compatible) | `local-completions` | Depends on server | Varies |
**Note**: Models without logprobs can only be evaluated on generation tasks, not perplexity or loglikelihood tasks.
## OpenAI Models
### Setup
```bash
export OPENAI_API_KEY=sk-...
```
### Completion Models (Legacy)
**Available models**: `davinci-002`, `babbage-002`
```bash
lm_eval --model openai-completions \
--model_args model=davinci-002 \
--tasks lambada_openai,hellaswag \
--batch_size auto
```
**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ✅
- `loglikelihood_rolling`: ✅
### Chat Models
**Available models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
```
**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ❌ (no logprobs)
- `loglikelihood_rolling`: ❌
**Important**: Chat models don't provide logprobs, so they can only be used with generation tasks (MMLU, GSM8K, HumanEval), not perplexity tasks.
### Configuration Options
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
base_url=https://api.openai.com/v1,\
num_concurrent=5,\
max_retries=3,\
timeout=60,\
batch_size=auto
```
**Parameters**:
- `model`: Model identifier (required)
- `base_url`: API endpoint (default: OpenAI)
- `num_concurrent`: Concurrent requests (default: 5)
- `max_retries`: Retry failed requests (default: 3)
- `timeout`: Request timeout in seconds (default: 60)
- `tokenizer`: Tokenizer to use (default: matches model)
- `tokenizer_backend`: `"tiktoken"` or `"huggingface"`
### Cost Management
OpenAI charges per token. Estimate costs before running:
```python
# Rough estimate
num_samples = 1000
avg_tokens_per_sample = 500 # input + output
cost_per_1k_tokens = 0.01 # GPT-3.5 Turbo
total_cost = (num_samples * avg_tokens_per_sample / 1000) * cost_per_1k_tokens
print(f"Estimated cost: ${total_cost:.2f}")
```
**Cost-saving tips**:
- Use `--limit N` for testing
- Start with `gpt-3.5-turbo` before `gpt-4`
- Set `max_gen_toks` to minimum needed
- Use `num_fewshot=0` for zero-shot when possible
## Anthropic Models
### Setup
```bash
export ANTHROPIC_API_KEY=sk-ant-...
```
### Completion Models (Legacy)
```bash
lm_eval --model anthropic-completions \
--model_args model=claude-2.1 \
--tasks lambada_openai,hellaswag \
--batch_size auto
```
### Chat Models (Recommended)
**Available models**: `claude-3-5-sonnet-20241022`, `claude-3-opus-20240229`, `claude-3-sonnet-20240229`, `claude-3-haiku-20240307`
```bash
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
```
**Aliases**: `anthropic-chat-completions` (same as `anthropic-chat`)
### Configuration Options
```bash
lm_eval --model anthropic-chat \
--model_args \
model=claude-3-5-sonnet-20241022,\
base_url=https://api.anthropic.com,\
num_concurrent=5,\
max_retries=3,\
timeout=60
```
### Cost Management
Anthropic pricing (as of 2024):
- Claude 3.5 Sonnet: $3.00 / 1M input, $15.00 / 1M output
- Claude 3 Opus: $15.00 / 1M input, $75.00 / 1M output
- Claude 3 Haiku: $0.25 / 1M input, $1.25 / 1M output
**Budget-friendly strategy**:
```bash
# Test on small sample first
lm_eval --model anthropic-chat \
--model_args model=claude-3-haiku-20240307 \
--tasks mmlu \
--limit 100
# Then run full eval on best model
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks mmlu \
--num_fewshot 5
```
## Local OpenAI-Compatible APIs
Many local inference servers expose OpenAI-compatible APIs (vLLM, Text Generation Inference, llama.cpp, Ollama).
### vLLM Local Server
**Start server**:
```bash
vllm serve meta-llama/Llama-2-7b-hf \
--host 0.0.0.0 \
--port 8000
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=meta-llama/Llama-2-7b-hf,\
base_url=http://localhost:8000/v1,\
num_concurrent=1 \
--tasks mmlu,gsm8k \
--batch_size auto
```
### Text Generation Inference (TGI)
**Start server**:
```bash
docker run --gpus all --shm-size 1g -p 8080:80 \
ghcr.io/huggingface/text-generation-inference:latest \
--model-id meta-llama/Llama-2-7b-hf
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=meta-llama/Llama-2-7b-hf,\
base_url=http://localhost:8080/v1 \
--tasks hellaswag,arc_challenge
```
### Ollama
**Start server**:
```bash
ollama serve
ollama pull llama2:7b
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=llama2:7b,\
base_url=http://localhost:11434/v1 \
--tasks mmlu
```
### llama.cpp Server
**Start server**:
```bash
./server -m models/llama-2-7b.gguf --host 0.0.0.0 --port 8080
```
**Evaluate**:
```bash
lm_eval --model local-completions \
--model_args \
model=llama2,\
base_url=http://localhost:8080/v1 \
--tasks gsm8k
```
## Custom API Implementation
For custom API endpoints, subclass `TemplateAPI`:
### Create `my_api.py`
```python
from lm_eval.models.api_models import TemplateAPI
import requests
class MyCustomAPI(TemplateAPI):
"""Custom API model."""
def __init__(self, base_url, api_key, **kwargs):
super().__init__(base_url=base_url, **kwargs)
self.api_key = api_key
def _create_payload(self, messages, gen_kwargs):
"""Create API request payload."""
return {
"messages": messages,
"api_key": self.api_key,
**gen_kwargs
}
def parse_generations(self, response):
"""Parse generation response."""
return response.json()["choices"][0]["text"]
def parse_logprobs(self, response):
"""Parse logprobs (if available)."""
# Return None if API doesn't provide logprobs
logprobs = response.json().get("logprobs")
if logprobs:
return logprobs["token_logprobs"]
return None
```
### Register and Use
```python
from lm_eval import evaluator
from my_api import MyCustomAPI
model = MyCustomAPI(
base_url="https://api.example.com/v1",
api_key="your-key"
)
results = evaluator.simple_evaluate(
model=model,
tasks=["mmlu", "gsm8k"],
num_fewshot=5,
batch_size="auto"
)
```
## Comparing API and Open Models
### Side-by-Side Evaluation
```bash
# Evaluate OpenAI GPT-4
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu,gsm8k,hellaswag \
--num_fewshot 5 \
--output_path results/gpt4.json
# Evaluate open Llama 2 70B
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-70b-hf,dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag \
--num_fewshot 5 \
--output_path results/llama2-70b.json
# Compare results
python scripts/compare_results.py \
results/gpt4.json \
results/llama2-70b.json
```
### Typical Comparisons
| Model | MMLU | GSM8K | HumanEval | Cost |
|-------|------|-------|-----------|------|
| GPT-4 Turbo | 86.4% | 92.0% | 67.0% | $$$$ |
| Claude 3 Opus | 86.8% | 95.0% | 84.9% | $$$$ |
| GPT-3.5 Turbo | 70.0% | 57.1% | 48.1% | $$ |
| Llama 2 70B | 68.9% | 56.8% | 29.9% | Free (self-host) |
| Mixtral 8x7B | 70.6% | 58.4% | 40.2% | Free (self-host) |
## Best Practices
### Rate Limiting
Respect API rate limits:
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
num_concurrent=3,\ # Lower concurrency
timeout=120 \ # Longer timeout
--tasks mmlu
```
### Reproducibility
Set temperature to 0 for deterministic results:
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--gen_kwargs temperature=0.0
```
Or use `seed` for sampling:
```bash
lm_eval --model anthropic-chat \
--model_args model=claude-3-5-sonnet-20241022 \
--tasks gsm8k \
--gen_kwargs temperature=0.7,seed=42
```
### Caching
API models automatically cache responses to avoid redundant calls:
```bash
# First run: makes API calls
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 100
# Second run: uses cache (instant, free)
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 100
```
Cache location: `~/.cache/lm_eval/`
### Error Handling
APIs can fail. Use retries:
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
max_retries=5,\
timeout=120 \
--tasks mmlu
```
## Troubleshooting
### "Authentication failed"
Check API key:
```bash
echo $OPENAI_API_KEY # Should print sk-...
echo $ANTHROPIC_API_KEY # Should print sk-ant-...
```
### "Rate limit exceeded"
Reduce concurrency:
```bash
--model_args num_concurrent=1
```
Or add delays between requests.
### "Timeout error"
Increase timeout:
```bash
--model_args timeout=180
```
### "Model not found"
For local APIs, verify server is running:
```bash
curl http://localhost:8000/v1/models
```
### Cost Runaway
Use `--limit` for testing:
```bash
lm_eval --model openai-chat-completions \
--model_args model=gpt-4-turbo \
--tasks mmlu \
--limit 50 # Only 50 samples
```
## Advanced Features
### Custom Headers
```bash
lm_eval --model local-completions \
--model_args \
base_url=http://api.example.com/v1,\
header="Authorization: Bearer token,X-Custom: value"
```
### Disable SSL Verification (Development Only)
```bash
lm_eval --model local-completions \
--model_args \
base_url=https://localhost:8000/v1,\
verify_certificate=false
```
### Custom Tokenizer
```bash
lm_eval --model openai-chat-completions \
--model_args \
model=gpt-4-turbo,\
tokenizer=gpt2,\
tokenizer_backend=huggingface
```
## References
- OpenAI API: https://platform.openai.com/docs/api-reference
- Anthropic API: https://docs.anthropic.com/claude/reference
- TemplateAPI: `lm_eval/models/api_models.py`
- OpenAI models: `lm_eval/models/openai_completions.py`
- Anthropic models: `lm_eval/models/anthropic_llms.py`

View File

@@ -0,0 +1,488 @@
# Benchmark Guide
Complete guide to all 60+ evaluation tasks in lm-evaluation-harness, what they measure, and how to interpret results.
## Overview
The lm-evaluation-harness includes 60+ benchmarks spanning:
- Language understanding (MMLU, GLUE)
- Mathematical reasoning (GSM8K, MATH)
- Code generation (HumanEval, MBPP)
- Instruction following (IFEval, AlpacaEval)
- Long-context understanding (LongBench)
- Multilingual capabilities (AfroBench, NorEval)
- Reasoning (BBH, ARC)
- Truthfulness (TruthfulQA)
**List all tasks**:
```bash
lm_eval --tasks list
```
## Major Benchmarks
### MMLU (Massive Multitask Language Understanding)
**What it measures**: Broad knowledge across 57 subjects (STEM, humanities, social sciences, law).
**Task variants**:
- `mmlu`: Original 57-subject benchmark
- `mmlu_pro`: More challenging version with reasoning-focused questions
- `mmlu_prox`: Multilingual extension
**Format**: Multiple choice (4 options)
**Example**:
```
Question: What is the capital of France?
A. Berlin
B. Paris
C. London
D. Madrid
Answer: B
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--num_fewshot 5
```
**Interpretation**:
- Random: 25% (chance)
- GPT-3 (175B): 43.9%
- GPT-4: 86.4%
- Human expert: ~90%
**Good for**: Assessing general knowledge and domain expertise.
### GSM8K (Grade School Math 8K)
**What it measures**: Mathematical reasoning on grade-school level word problems.
**Task variants**:
- `gsm8k`: Base task
- `gsm8k_cot`: With chain-of-thought prompting
- `gsm_plus`: Adversarial variant with perturbations
**Format**: Free-form generation, extract numerical answer
**Example**:
```
Question: A baker made 200 cookies. He sold 3/5 of them in the morning and 1/4 of the remaining in the afternoon. How many cookies does he have left?
Answer: 60
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks gsm8k \
--num_fewshot 5
```
**Interpretation**:
- Random: ~0%
- GPT-3 (175B): 17.0%
- GPT-4: 92.0%
- Llama 2 70B: 56.8%
**Good for**: Testing multi-step reasoning and arithmetic.
### HumanEval
**What it measures**: Python code generation from docstrings (functional correctness).
**Task variants**:
- `humaneval`: Standard benchmark
- `humaneval_instruct`: For instruction-tuned models
**Format**: Code generation, execution-based evaluation
**Example**:
```python
def has_close_elements(numbers: List[float], threshold: float) -> bool:
""" Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
"""
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks humaneval \
--batch_size 1
```
**Interpretation**:
- Random: 0%
- GPT-3 (175B): 0%
- Codex: 28.8%
- GPT-4: 67.0%
- Code Llama 34B: 53.7%
**Good for**: Evaluating code generation capabilities.
### BBH (BIG-Bench Hard)
**What it measures**: 23 challenging reasoning tasks where models previously failed to beat humans.
**Categories**:
- Logical reasoning
- Math word problems
- Social understanding
- Algorithmic reasoning
**Format**: Multiple choice and free-form
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks bbh \
--num_fewshot 3
```
**Interpretation**:
- Random: ~25%
- GPT-3 (175B): 33.9%
- PaLM 540B: 58.3%
- GPT-4: 86.7%
**Good for**: Testing advanced reasoning capabilities.
### IFEval (Instruction-Following Evaluation)
**What it measures**: Ability to follow specific, verifiable instructions.
**Instruction types**:
- Format constraints (e.g., "answer in 3 sentences")
- Length constraints (e.g., "use at least 100 words")
- Content constraints (e.g., "include the word 'banana'")
- Structural constraints (e.g., "use bullet points")
**Format**: Free-form generation with rule-based verification
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
--tasks ifeval \
--batch_size auto
```
**Interpretation**:
- Measures: Instruction adherence (not quality)
- GPT-4: 86% instruction following
- Claude 2: 84%
**Good for**: Evaluating chat/instruct models.
### GLUE (General Language Understanding Evaluation)
**What it measures**: Natural language understanding across 9 tasks.
**Tasks**:
- `cola`: Grammatical acceptability
- `sst2`: Sentiment analysis
- `mrpc`: Paraphrase detection
- `qqp`: Question pairs
- `stsb`: Semantic similarity
- `mnli`: Natural language inference
- `qnli`: Question answering NLI
- `rte`: Recognizing textual entailment
- `wnli`: Winograd schemas
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=bert-base-uncased \
--tasks glue \
--num_fewshot 0
```
**Interpretation**:
- BERT Base: 78.3 (GLUE score)
- RoBERTa Large: 88.5
- Human baseline: 87.1
**Good for**: Encoder-only models, fine-tuning baselines.
### LongBench
**What it measures**: Long-context understanding (4K-32K tokens).
**21 tasks covering**:
- Single-document QA
- Multi-document QA
- Summarization
- Few-shot learning
- Code completion
- Synthetic tasks
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks longbench \
--batch_size 1
```
**Interpretation**:
- Tests context utilization
- Many models struggle beyond 4K tokens
- GPT-4 Turbo: 54.3%
**Good for**: Evaluating long-context models.
## Additional Benchmarks
### TruthfulQA
**What it measures**: Model's propensity to be truthful vs. generate plausible-sounding falsehoods.
**Format**: Multiple choice with 4-5 options
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks truthfulqa_mc2 \
--batch_size auto
```
**Interpretation**:
- Larger models often score worse (more convincing lies)
- GPT-3: 58.8%
- GPT-4: 59.0%
- Human: ~94%
### ARC (AI2 Reasoning Challenge)
**What it measures**: Grade-school science questions.
**Variants**:
- `arc_easy`: Easier questions
- `arc_challenge`: Harder questions requiring reasoning
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks arc_challenge \
--num_fewshot 25
```
**Interpretation**:
- ARC-Easy: Most models >80%
- ARC-Challenge random: 25%
- GPT-4: 96.3%
### HellaSwag
**What it measures**: Commonsense reasoning about everyday situations.
**Format**: Choose most plausible continuation
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks hellaswag \
--num_fewshot 10
```
**Interpretation**:
- Random: 25%
- GPT-3: 78.9%
- Llama 2 70B: 85.3%
### WinoGrande
**What it measures**: Commonsense reasoning via pronoun resolution.
**Example**:
```
The trophy doesn't fit in the brown suitcase because _ is too large.
A. the trophy
B. the suitcase
```
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks winogrande \
--num_fewshot 5
```
### PIQA
**What it measures**: Physical commonsense reasoning.
**Example**: "To clean a keyboard, use compressed air or..."
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks piqa
```
## Multilingual Benchmarks
### AfroBench
**What it measures**: Performance across 64 African languages.
**15 tasks**: NLU, text generation, knowledge, QA, math reasoning
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks afrobench
```
### NorEval
**What it measures**: Norwegian language understanding (9 task categories).
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=NbAiLab/nb-gpt-j-6B \
--tasks noreval
```
## Domain-Specific Benchmarks
### MATH
**What it measures**: High-school competition math problems.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks math \
--num_fewshot 4
```
**Interpretation**:
- Very challenging
- GPT-4: 42.5%
- Minerva 540B: 33.6%
### MBPP (Mostly Basic Python Problems)
**What it measures**: Python programming from natural language descriptions.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks mbpp \
--batch_size 1
```
### DROP
**What it measures**: Reading comprehension requiring discrete reasoning.
**Command**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks drop
```
## Benchmark Selection Guide
### For General Purpose Models
Run this suite:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag,arc_challenge,truthfulqa_mc2 \
--num_fewshot 5
```
### For Code Models
```bash
lm_eval --model hf \
--model_args pretrained=codellama/CodeLlama-7b-hf \
--tasks humaneval,mbpp \
--batch_size 1
```
### For Chat/Instruct Models
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
--tasks ifeval,mmlu,gsm8k_cot \
--batch_size auto
```
### For Long Context Models
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-3.1-8B \
--tasks longbench \
--batch_size 1
```
## Interpreting Results
### Understanding Metrics
**Accuracy**: Percentage of correct answers (most common)
**Exact Match (EM)**: Requires exact string match (strict)
**F1 Score**: Balances precision and recall
**BLEU/ROUGE**: Text generation similarity
**Pass@k**: Percentage passing when generating k samples
### Typical Score Ranges
| Model Size | MMLU | GSM8K | HumanEval | HellaSwag |
|------------|------|-------|-----------|-----------|
| 7B | 40-50% | 10-20% | 5-15% | 70-80% |
| 13B | 45-55% | 20-35% | 15-25% | 75-82% |
| 70B | 60-70% | 50-65% | 35-50% | 82-87% |
| GPT-4 | 86% | 92% | 67% | 95% |
### Red Flags
- **All tasks at random chance**: Model not trained properly
- **Exact 0% on generation tasks**: Likely format/parsing issue
- **Huge variance across runs**: Check seed/sampling settings
- **Better than GPT-4 on everything**: Likely contamination
## Best Practices
1. **Always report few-shot setting**: 0-shot, 5-shot, etc.
2. **Run multiple seeds**: Report mean ± std
3. **Check for data contamination**: Search training data for benchmark examples
4. **Compare to published baselines**: Validate your setup
5. **Report all hyperparameters**: Model, batch size, max tokens, temperature
## References
- Task list: `lm_eval --tasks list`
- Task README: `lm_eval/tasks/README.md`
- Papers: See individual benchmark papers

View File

@@ -0,0 +1,602 @@
# Custom Tasks
Complete guide to creating domain-specific evaluation tasks in lm-evaluation-harness.
## Overview
Custom tasks allow you to evaluate models on your own datasets and metrics. Tasks are defined using YAML configuration files with optional Python utilities for complex logic.
**Why create custom tasks**:
- Evaluate on proprietary/domain-specific data
- Test specific capabilities not covered by existing benchmarks
- Create evaluation pipelines for internal models
- Reproduce research experiments
## Quick Start
### Minimal Custom Task
Create `my_tasks/simple_qa.yaml`:
```yaml
task: simple_qa
dataset_path: data/simple_qa.jsonl
output_type: generate_until
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
```
**Run it**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks simple_qa \
--include_path my_tasks/
```
## Task Configuration Reference
### Essential Fields
```yaml
# Task identification
task: my_custom_task # Unique task name (required)
task_alias: "My Task" # Display name
tag: # Tags for grouping
- custom
- domain_specific
# Dataset configuration
dataset_path: data/my_data.jsonl # HuggingFace dataset or local path
dataset_name: default # Subset name (if applicable)
training_split: train
validation_split: validation
test_split: test
# Evaluation configuration
output_type: generate_until # or loglikelihood, multiple_choice
num_fewshot: 5 # Number of few-shot examples
batch_size: auto # Batch size
# Prompt templates (Jinja2)
doc_to_text: "Question: {{question}}"
doc_to_target: "{{answer}}"
# Metrics
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
# Metadata
metadata:
version: 1.0
```
### Output Types
**`generate_until`**: Free-form generation
```yaml
output_type: generate_until
generation_kwargs:
max_gen_toks: 256
until:
- "\n"
- "."
temperature: 0.0
```
**`loglikelihood`**: Compute log probability of targets
```yaml
output_type: loglikelihood
# Used for perplexity, classification
```
**`multiple_choice`**: Choose from options
```yaml
output_type: multiple_choice
doc_to_choice: "{{choices}}" # List of choices
```
## Data Formats
### Local JSONL File
`data/my_data.jsonl`:
```json
{"question": "What is 2+2?", "answer": "4"}
{"question": "Capital of France?", "answer": "Paris"}
```
**Task config**:
```yaml
dataset_path: data/my_data.jsonl
dataset_kwargs:
data_files:
test: data/my_data.jsonl
```
### HuggingFace Dataset
```yaml
dataset_path: squad
dataset_name: plain_text
test_split: validation
```
### CSV File
`data/my_data.csv`:
```csv
question,answer,category
What is 2+2?,4,math
Capital of France?,Paris,geography
```
**Task config**:
```yaml
dataset_path: data/my_data.csv
dataset_kwargs:
data_files:
test: data/my_data.csv
```
## Prompt Engineering
### Simple Template
```yaml
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
```
### Conditional Logic
```yaml
doc_to_text: |
{% if context %}
Context: {{context}}
{% endif %}
Question: {{question}}
Answer:
```
### Multiple Choice
```yaml
doc_to_text: |
Question: {{question}}
A. {{choices[0]}}
B. {{choices[1]}}
C. {{choices[2]}}
D. {{choices[3]}}
Answer:
doc_to_target: "{{ 'ABCD'[answer_idx] }}"
doc_to_choice: ["A", "B", "C", "D"]
```
### Few-Shot Formatting
```yaml
fewshot_delimiter: "\n\n" # Between examples
target_delimiter: " " # Between question and answer
doc_to_text: "Q: {{question}}"
doc_to_target: "A: {{answer}}"
```
## Custom Python Functions
For complex logic, use Python functions in `utils.py`.
### Create `my_tasks/utils.py`
```python
def process_docs(dataset):
"""Preprocess documents."""
def _process(doc):
# Custom preprocessing
doc["question"] = doc["question"].strip().lower()
return doc
return dataset.map(_process)
def doc_to_text(doc):
"""Custom prompt formatting."""
context = doc.get("context", "")
question = doc["question"]
if context:
return f"Context: {context}\nQuestion: {question}\nAnswer:"
return f"Question: {question}\nAnswer:"
def doc_to_target(doc):
"""Custom target extraction."""
return doc["answer"].strip().lower()
def aggregate_scores(items):
"""Custom metric aggregation."""
correct = sum(1 for item in items if item == 1.0)
total = len(items)
return correct / total if total > 0 else 0.0
```
### Use in Task Config
```yaml
task: my_custom_task
dataset_path: data/my_data.jsonl
# Use Python functions
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
metric_list:
- metric: exact_match
aggregation: !function utils.aggregate_scores
higher_is_better: true
```
## Real-World Examples
### Example 1: Domain QA Task
**Goal**: Evaluate medical question answering.
`medical_qa/medical_qa.yaml`:
```yaml
task: medical_qa
dataset_path: data/medical_qa.jsonl
output_type: generate_until
num_fewshot: 3
doc_to_text: |
Medical Question: {{question}}
Context: {{context}}
Answer (be concise):
doc_to_target: "{{answer}}"
generation_kwargs:
max_gen_toks: 100
until:
- "\n\n"
temperature: 0.0
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
- metric: !function utils.medical_f1
aggregation: mean
higher_is_better: true
filter_list:
- name: lowercase
filter:
- function: lowercase
- function: remove_whitespace
metadata:
version: 1.0
domain: medical
```
`medical_qa/utils.py`:
```python
from sklearn.metrics import f1_score
import re
def medical_f1(predictions, references):
"""Custom F1 for medical terms."""
pred_terms = set(extract_medical_terms(predictions[0]))
ref_terms = set(extract_medical_terms(references[0]))
if not pred_terms and not ref_terms:
return 1.0
if not pred_terms or not ref_terms:
return 0.0
tp = len(pred_terms & ref_terms)
fp = len(pred_terms - ref_terms)
fn = len(ref_terms - pred_terms)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
def extract_medical_terms(text):
"""Extract medical terminology."""
# Custom logic
return re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)*\b', text)
```
### Example 2: Code Evaluation
`code_eval/python_challenges.yaml`:
```yaml
task: python_challenges
dataset_path: data/python_problems.jsonl
output_type: generate_until
num_fewshot: 0
doc_to_text: |
Write a Python function to solve:
{{problem_statement}}
Function signature:
{{function_signature}}
doc_to_target: "{{canonical_solution}}"
generation_kwargs:
max_gen_toks: 512
until:
- "\n\nclass"
- "\n\ndef"
temperature: 0.2
metric_list:
- metric: !function utils.execute_code
aggregation: mean
higher_is_better: true
process_results: !function utils.process_code_results
metadata:
version: 1.0
```
`code_eval/utils.py`:
```python
import subprocess
import json
def execute_code(predictions, references):
"""Execute generated code against test cases."""
generated_code = predictions[0]
test_cases = json.loads(references[0])
try:
# Execute code with test cases
for test_input, expected_output in test_cases:
result = execute_with_timeout(generated_code, test_input, timeout=5)
if result != expected_output:
return 0.0
return 1.0
except Exception:
return 0.0
def execute_with_timeout(code, input_data, timeout=5):
"""Safely execute code with timeout."""
# Implementation with subprocess and timeout
pass
def process_code_results(doc, results):
"""Process code execution results."""
return {
"passed": results[0] == 1.0,
"generated_code": results[1]
}
```
### Example 3: Instruction Following
`instruction_eval/instruction_eval.yaml`:
```yaml
task: instruction_following
dataset_path: data/instructions.jsonl
output_type: generate_until
num_fewshot: 0
doc_to_text: |
Instruction: {{instruction}}
{% if constraints %}
Constraints: {{constraints}}
{% endif %}
Response:
doc_to_target: "{{expected_response}}"
generation_kwargs:
max_gen_toks: 256
temperature: 0.7
metric_list:
- metric: !function utils.check_constraints
aggregation: mean
higher_is_better: true
- metric: !function utils.semantic_similarity
aggregation: mean
higher_is_better: true
process_docs: !function utils.add_constraint_checkers
```
`instruction_eval/utils.py`:
```python
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('all-MiniLM-L6-v2')
def check_constraints(predictions, references):
"""Check if response satisfies constraints."""
response = predictions[0]
constraints = json.loads(references[0])
satisfied = 0
total = len(constraints)
for constraint in constraints:
if verify_constraint(response, constraint):
satisfied += 1
return satisfied / total if total > 0 else 1.0
def verify_constraint(response, constraint):
"""Verify single constraint."""
if constraint["type"] == "length":
return len(response.split()) >= constraint["min_words"]
elif constraint["type"] == "contains":
return constraint["keyword"] in response.lower()
# Add more constraint types
return True
def semantic_similarity(predictions, references):
"""Compute semantic similarity."""
pred_embedding = model.encode(predictions[0])
ref_embedding = model.encode(references[0])
return float(util.cos_sim(pred_embedding, ref_embedding))
def add_constraint_checkers(dataset):
"""Parse constraints into verifiable format."""
def _parse(doc):
# Parse constraint string into structured format
doc["parsed_constraints"] = parse_constraints(doc.get("constraints", ""))
return doc
return dataset.map(_parse)
```
## Advanced Features
### Output Filtering
```yaml
filter_list:
- name: extract_answer
filter:
- function: regex
regex_pattern: "Answer: (.*)"
group: 1
- function: lowercase
- function: strip_whitespace
```
### Multiple Metrics
```yaml
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
- metric: f1
aggregation: mean
higher_is_better: true
- metric: bleu
aggregation: mean
higher_is_better: true
```
### Task Groups
Create `my_tasks/_default.yaml`:
```yaml
group: my_eval_suite
task:
- simple_qa
- medical_qa
- python_challenges
```
**Run entire suite**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks my_eval_suite \
--include_path my_tasks/
```
## Testing Your Task
### Validate Configuration
```bash
# Test task loading
lm_eval --tasks my_custom_task --include_path my_tasks/ --limit 0
# Run on 5 samples
lm_eval --model hf \
--model_args pretrained=gpt2 \
--tasks my_custom_task \
--include_path my_tasks/ \
--limit 5
```
### Debug Mode
```bash
lm_eval --model hf \
--model_args pretrained=gpt2 \
--tasks my_custom_task \
--include_path my_tasks/ \
--limit 1 \
--log_samples # Save input/output samples
```
## Best Practices
1. **Start simple**: Test with minimal config first
2. **Version your tasks**: Use `metadata.version`
3. **Document your metrics**: Explain custom metrics in comments
4. **Test with multiple models**: Ensure robustness
5. **Validate on known examples**: Include sanity checks
6. **Use filters carefully**: Can hide errors
7. **Handle edge cases**: Empty strings, missing fields
## Common Patterns
### Classification Task
```yaml
output_type: loglikelihood
doc_to_text: "Text: {{text}}\nLabel:"
doc_to_target: " {{label}}" # Space prefix important!
metric_list:
- metric: acc
aggregation: mean
```
### Perplexity Evaluation
```yaml
output_type: loglikelihood_rolling
doc_to_text: "{{text}}"
metric_list:
- metric: perplexity
aggregation: perplexity
```
### Ranking Task
```yaml
output_type: loglikelihood
doc_to_text: "Query: {{query}}\nPassage: {{passage}}\nRelevant:"
doc_to_target: [" Yes", " No"]
metric_list:
- metric: acc
aggregation: mean
```
## Troubleshooting
**"Task not found"**: Check `--include_path` and task name
**Empty results**: Verify `doc_to_text` and `doc_to_target` templates
**Metric errors**: Ensure metric names are correct (exact_match, not exact-match)
**Filter issues**: Test filters with `--log_samples`
**Python function not found**: Check `!function module.function_name` syntax
## References
- Task system: EleutherAI/lm-evaluation-harness docs
- Example tasks: `lm_eval/tasks/` directory
- TaskConfig: `lm_eval/api/task.py`

View File

@@ -0,0 +1,519 @@
# Distributed Evaluation
Guide to running evaluation across multiple GPUs using data parallelism and tensor/pipeline parallelism.
## Overview
Distributed evaluation speeds up benchmarking by:
- **Data Parallelism**: Split evaluation samples across GPUs (each GPU has full model copy)
- **Tensor Parallelism**: Split model weights across GPUs (for large models)
- **Pipeline Parallelism**: Split model layers across GPUs (for very large models)
**When to use**:
- Data Parallel: Model fits on single GPU, want faster evaluation
- Tensor/Pipeline Parallel: Model too large for single GPU
## HuggingFace Models (`hf`)
### Data Parallelism (Recommended)
Each GPU loads a full copy of the model and processes a subset of evaluation data.
**Single Node (8 GPUs)**:
```bash
accelerate launch --multi_gpu --num_processes 8 \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag \
--batch_size 16
```
**Speedup**: Near-linear (8 GPUs = ~8× faster)
**Memory**: Each GPU needs full model (7B model ≈ 14GB × 8 = 112GB total)
### Tensor Parallelism (Model Sharding)
Split model weights across GPUs for models too large for single GPU.
**Without accelerate launcher**:
```bash
lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
--tasks mmlu,gsm8k \
--batch_size 8
```
**With 8 GPUs**: 70B model (140GB) / 8 = 17.5GB per GPU ✅
**Advanced sharding**:
```bash
lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
device_map_option=auto,\
max_memory_per_gpu=40GB,\
max_cpu_memory=100GB,\
dtype=bfloat16 \
--tasks mmlu
```
**Options**:
- `device_map_option`: `"auto"` (default), `"balanced"`, `"balanced_low_0"`
- `max_memory_per_gpu`: Max memory per GPU (e.g., `"40GB"`)
- `max_cpu_memory`: Max CPU memory for offloading
- `offload_folder`: Disk offloading directory
### Combined Data + Tensor Parallelism
Use both for very large models.
**Example: 70B model on 16 GPUs (2 copies, 8 GPUs each)**:
```bash
accelerate launch --multi_gpu --num_processes 2 \
-m lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
--tasks mmlu \
--batch_size 8
```
**Result**: 2× speedup from data parallelism, 70B model fits via tensor parallelism
### Configuration with `accelerate config`
Create `~/.cache/huggingface/accelerate/default_config.yaml`:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 8
gpu_ids: all
mixed_precision: bf16
```
**Then run**:
```bash
accelerate launch -m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu
```
## vLLM Models (`vllm`)
vLLM provides highly optimized distributed inference.
### Tensor Parallelism
**Single Node (4 GPUs)**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.9 \
--tasks mmlu,gsm8k \
--batch_size auto
```
**Memory**: 70B model split across 4 GPUs = ~35GB per GPU
### Data Parallelism
**Multiple model replicas**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
data_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.8 \
--tasks hellaswag,arc_challenge \
--batch_size auto
```
**Result**: 4 model replicas = 4× throughput
### Combined Tensor + Data Parallelism
**Example: 8 GPUs = 4 TP × 2 DP**:
```bash
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
data_parallel_size=2,\
dtype=auto,\
gpu_memory_utilization=0.85 \
--tasks mmlu \
--batch_size auto
```
**Result**: 70B model fits (TP=4), 2× speedup (DP=2)
### Multi-Node vLLM
vLLM doesn't natively support multi-node. Use Ray:
```bash
# Start Ray cluster
ray start --head --port=6379
# Run evaluation
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto \
--tasks mmlu
```
## NVIDIA NeMo Models (`nemo_lm`)
### Data Replication
**8 replicas on 8 GPUs**:
```bash
torchrun --nproc-per-node=8 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/model.nemo,\
devices=8 \
--tasks hellaswag,arc_challenge \
--batch_size 32
```
**Speedup**: Near-linear (8× faster)
### Tensor Parallelism
**4-way tensor parallelism**:
```bash
torchrun --nproc-per-node=4 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/70b_model.nemo,\
devices=4,\
tensor_model_parallel_size=4 \
--tasks mmlu,gsm8k \
--batch_size 16
```
### Pipeline Parallelism
**2 TP × 2 PP on 4 GPUs**:
```bash
torchrun --nproc-per-node=4 --no-python \
lm_eval --model nemo_lm \
--model_args \
path=/path/to/model.nemo,\
devices=4,\
tensor_model_parallel_size=2,\
pipeline_model_parallel_size=2 \
--tasks mmlu \
--batch_size 8
```
**Constraint**: `devices = TP × PP`
### Multi-Node NeMo
Currently not supported by lm-evaluation-harness.
## SGLang Models (`sglang`)
### Tensor Parallelism
```bash
lm_eval --model sglang \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tp_size=4,\
dtype=auto \
--tasks gsm8k \
--batch_size auto
```
### Data Parallelism (Deprecated)
**Note**: SGLang is deprecating data parallelism. Use tensor parallelism instead.
```bash
lm_eval --model sglang \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dp_size=4,\
dtype=auto \
--tasks mmlu
```
## Performance Comparison
### 70B Model Evaluation (MMLU, 5-shot)
| Method | GPUs | Time | Memory/GPU | Notes |
|--------|------|------|------------|-------|
| HF (no parallel) | 1 | 8 hours | 140GB (OOM) | Won't fit |
| HF (TP=8) | 8 | 2 hours | 17.5GB | Slower, fits |
| HF (DP=8) | 8 | 1 hour | 140GB (OOM) | Won't fit |
| vLLM (TP=4) | 4 | 30 min | 35GB | Fast! |
| vLLM (TP=4, DP=2) | 8 | 15 min | 35GB | Fastest |
### 7B Model Evaluation (Multiple Tasks)
| Method | GPUs | Time | Speedup |
|--------|------|------|---------|
| HF (single) | 1 | 4 hours | 1× |
| HF (DP=4) | 4 | 1 hour | 4× |
| HF (DP=8) | 8 | 30 min | 8× |
| vLLM (DP=8) | 8 | 15 min | 16× |
**Takeaway**: vLLM is significantly faster than HuggingFace for inference.
## Choosing Parallelism Strategy
### Decision Tree
```
Model fits on single GPU?
├─ YES: Use data parallelism
│ ├─ HF: accelerate launch --multi_gpu --num_processes N
│ └─ vLLM: data_parallel_size=N (fastest)
└─ NO: Use tensor/pipeline parallelism
├─ Model < 70B:
│ └─ vLLM: tensor_parallel_size=4
├─ Model 70-175B:
│ ├─ vLLM: tensor_parallel_size=8
│ └─ Or HF: parallelize=True
└─ Model > 175B:
└─ Contact framework authors
```
### Memory Estimation
**Rule of thumb**:
```
Memory (GB) = Parameters (B) × Precision (bytes) × 1.2 (overhead)
```
**Examples**:
- 7B FP16: 7 × 2 × 1.2 = 16.8GB ✅ Fits A100 40GB
- 13B FP16: 13 × 2 × 1.2 = 31.2GB ✅ Fits A100 40GB
- 70B FP16: 70 × 2 × 1.2 = 168GB ❌ Need TP=4 or TP=8
- 70B BF16: 70 × 2 × 1.2 = 168GB (same as FP16)
**With tensor parallelism**:
```
Memory per GPU = Total Memory / TP
```
- 70B on 4 GPUs: 168GB / 4 = 42GB per GPU ✅
- 70B on 8 GPUs: 168GB / 8 = 21GB per GPU ✅
## Multi-Node Evaluation
### HuggingFace with SLURM
**Submit job**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
srun accelerate launch --multi_gpu \
--num_processes $((SLURM_NNODES * 8)) \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu,gsm8k,hellaswag \
--batch_size 16
```
**Submit**:
```bash
sbatch eval_job.sh
```
### Manual Multi-Node Setup
**On each node, run**:
```bash
accelerate launch \
--multi_gpu \
--num_machines 4 \
--num_processes 32 \
--main_process_ip $MASTER_IP \
--main_process_port 29500 \
--machine_rank $NODE_RANK \
-m lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu
```
**Environment variables**:
- `MASTER_IP`: IP of rank 0 node
- `NODE_RANK`: 0, 1, 2, 3 for each node
## Best Practices
### 1. Start Small
Test on small sample first:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-70b-hf,parallelize=True \
--tasks mmlu \
--limit 100 # Just 100 samples
```
### 2. Monitor GPU Usage
```bash
# Terminal 1: Run evaluation
lm_eval --model hf ...
# Terminal 2: Monitor
watch -n 1 nvidia-smi
```
Look for:
- GPU utilization > 90%
- Memory usage stable
- All GPUs active
### 3. Optimize Batch Size
```bash
# Auto batch size (recommended)
--batch_size auto
# Or tune manually
--batch_size 16 # Start here
--batch_size 32 # Increase if memory allows
```
### 4. Use Mixed Precision
```bash
--model_args dtype=bfloat16 # Faster, less memory
```
### 5. Check Communication
For data parallelism, check network bandwidth:
```bash
# Should see InfiniBand or high-speed network
nvidia-smi topo -m
```
## Troubleshooting
### "CUDA out of memory"
**Solutions**:
1. Increase tensor parallelism:
```bash
--model_args tensor_parallel_size=8 # Was 4
```
2. Reduce batch size:
```bash
--batch_size 4 # Was 16
```
3. Lower precision:
```bash
--model_args dtype=int8 # Quantization
```
### "NCCL error" or Hanging
**Check**:
1. All GPUs visible: `nvidia-smi`
2. NCCL installed: `python -c "import torch; print(torch.cuda.nccl.version())"`
3. Network connectivity between nodes
**Fix**:
```bash
export NCCL_DEBUG=INFO # Enable debug logging
export NCCL_IB_DISABLE=0 # Use InfiniBand if available
```
### Slow Evaluation
**Possible causes**:
1. **Data loading bottleneck**: Preprocess dataset
2. **Low GPU utilization**: Increase batch size
3. **Communication overhead**: Reduce parallelism degree
**Profile**:
```bash
lm_eval --model hf \
--model_args pretrained=meta-llama/Llama-2-7b-hf \
--tasks mmlu \
--limit 100 \
--log_samples # Check timing
```
### GPUs Imbalanced
**Symptom**: GPU 0 at 100%, others at 50%
**Solution**: Use `device_map_option=balanced`:
```bash
--model_args parallelize=True,device_map_option=balanced
```
## Example Configurations
### Small Model (7B) - Fast Evaluation
```bash
# 8 A100s, data parallel
accelerate launch --multi_gpu --num_processes 8 \
-m lm_eval --model hf \
--model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dtype=bfloat16 \
--tasks mmlu,gsm8k,hellaswag,arc_challenge \
--num_fewshot 5 \
--batch_size 32
# Time: ~30 minutes
```
### Large Model (70B) - vLLM
```bash
# 8 H100s, tensor parallel
lm_eval --model vllm \
--model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto,\
gpu_memory_utilization=0.9 \
--tasks mmlu,gsm8k,humaneval \
--num_fewshot 5 \
--batch_size auto
# Time: ~1 hour
```
### Very Large Model (175B+)
**Requires specialized setup - contact framework maintainers**
## References
- HuggingFace Accelerate: https://huggingface.co/docs/accelerate/
- vLLM docs: https://docs.vllm.ai/
- NeMo docs: https://docs.nvidia.com/nemo-framework/
- lm-eval distributed guide: `docs/model_guide.md`

View File

@@ -0,0 +1,386 @@
---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [nemo-curator, cudf, dask, rapids]
metadata:
hermes:
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
---
# NeMo Curator - GPU-Accelerated Data Curation
NVIDIA's toolkit for preparing high-quality training data for LLMs.
## When to use NeMo Curator
**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster
**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes
**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)
## Quick start
### Installation
```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"
# All modalities
uv pip install "nemo-curator[all_cuda12]"
# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```
### Basic text curation pipeline
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd
# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)
# Quality filtering
def quality_score(doc):
return len(doc["text"].split()) > 5 # Filter short docs
filtered = ScoreFilter(quality_score)(dataset)
# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)
# Save
deduped.to_parquet("curated_data/")
```
## Data curation pipeline
### Stage 1: Quality filtering
```python
from nemo_curator.filters import (
WordCountFilter,
RepeatedLinesFilter,
UrlRatioFilter,
NonAlphaNumericFilter
)
# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter
# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
### Stage 2: Deduplication
**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates
# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```
**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates
# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash parameters
num_buckets=20,
hash_method="md5"
)
deduped = fuzzy_dedup(dataset)
```
**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates
# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
threshold=0.8 # Cosine similarity threshold
)
deduped = semantic_dedup(dataset)
```
### Stage 3: PII redaction
```python
from nemo_curator.modules import Modify
from nemo_curator.modifiers import PIIRedactor
# Redact personally identifiable information
pii_redactor = PIIRedactor(
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
anonymize_action="replace" # or "redact"
)
redacted = Modify(pii_redactor)(dataset)
```
### Stage 4: Classifier filtering
```python
from nemo_curator.classifiers import QualityClassifier
# Quality classification
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality documents
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
## GPU acceleration
### GPU vs CPU performance
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |
### Multi-GPU scaling
```python
from nemo_curator import get_client
import dask_cuda
# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)
# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```
## Multi-modal curation
### Image curation
```python
from nemo_curator.image import (
AestheticFilter,
NSFWFilter,
CLIPEmbedder
)
# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)
# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)
# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```
### Video curation
```python
from nemo_curator.video import (
SceneDetector,
ClipExtractor,
InternVideo2Embedder
)
# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)
# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)
# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```
### Audio curation
```python
from nemo_curator.audio import (
ASRInference,
WERFilter,
DurationFilter
)
# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)
# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)
# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```
## Common patterns
### Web scrape curation (Common Crawl)
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset
# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
# Pipeline
pipeline = [
# 1. Quality filtering
WordCountFilter(min_words=100, max_words=50000),
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
UrlRatioFilter(max_url_ratio=0.3),
# 2. Language filtering
LanguageIdentificationFilter(target_languages=["en"]),
# 3. Deduplication
ExactDuplicates(id_field="id", text_field="text"),
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
# 4. PII redaction
PIIRedactor(),
# 5. NSFW filtering
NSFWClassifier(threshold=0.8)
]
# Execute
for stage in pipeline:
dataset = stage(dataset)
# Save
dataset.to_parquet("curated_common_crawl/")
```
### Distributed processing
```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster
# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)
# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)
# Cleanup
client.close()
cluster.close()
```
## Performance benchmarks
### Fuzzy deduplication (8TB RedPajama v2)
- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×
### Exact deduplication (1TB)
- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×
### Quality filtering (100GB)
- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×
## Cost comparison
**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320
**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55
**Savings**: 89% reduction ($3,828 saved)
## Supported data formats
- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal
## Use cases
**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile
## References
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
## Resources
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0

View File

@@ -0,0 +1,87 @@
# Deduplication Guide
Complete guide to exact, fuzzy, and semantic deduplication.
## Exact deduplication
Remove documents with identical content.
```python
from nemo_curator.modules import ExactDuplicates
# Exact deduplication
exact_dedup = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5" # or "sha256"
)
deduped = exact_dedup(dataset)
```
**Performance**: ~16× faster on GPU vs CPU
## Fuzzy deduplication
Remove near-duplicate documents using MinHash + LSH.
```python
from nemo_curator.modules import FuzzyDuplicates
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash permutations (more = accurate)
num_buckets=20, # LSH buckets (more = faster, less recall)
hash_method="md5",
jaccard_threshold=0.8 # Similarity threshold
)
deduped = fuzzy_dedup(dataset)
```
**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
## Semantic deduplication
Remove semantically similar documents using embeddings.
```python
from nemo_curator.modules import SemanticDuplicates
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=256,
threshold=0.85, # Cosine similarity threshold
device="cuda"
)
deduped = semantic_dedup(dataset)
```
**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported
## Comparison
| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |
## Best practices
1. **Start with exact dedup** - Remove obvious duplicates
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup

View File

@@ -0,0 +1,102 @@
# Quality Filtering Guide
Complete guide to NeMo Curator's 30+ quality filters.
## Text-based filters
### Word count
```python
from nemo_curator.filters import WordCountFilter
# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```
### Repeated content
```python
from nemo_curator.filters import RepeatedLinesFilter
# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```
### Symbol ratio
```python
from nemo_curator.filters import SymbolToWordRatioFilter
# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```
### URL ratio
```python
from nemo_curator.filters import UrlRatioFilter
# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
## Language filtering
```python
from nemo_curator.filters import LanguageIdentificationFilter
# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```
## Classifier-based filtering
### Quality classifier
```python
from nemo_curator.classifiers import QualityClassifier
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality (threshold > 0.5 = high quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
### NSFW classifier
```python
from nemo_curator.classifiers import NSFWClassifier
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```
## Heuristic filters
Full list of 30+ filters:
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...
## Best practices
1. **Apply cheap filters first** - Word count before GPU classifiers
2. **Tune thresholds on sample** - Test on 10k docs before full run
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)

View File

@@ -0,0 +1,389 @@
---
name: sparse-autoencoder-training
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
metadata:
hermes:
tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]
---
# SAELens: Sparse Autoencoders for Mechanistic Interpretability
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.
**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)
## The Problem: Polysemanticity & Superposition
Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.
**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
## When to Use SAELens
**Use SAELens when you need to:**
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)
**Consider alternatives when:**
- You need basic activation analysis → Use **TransformerLens** directly
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
- You need production steering → Consider direct activation engineering
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Core Concepts
### What SAEs Learn
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
(d_model) ↓ (d_sae >> d_model) ↓ (d_model)
sparsity reconstruction
penalty loss
```
**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Validation (Anthropic Research)
In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures
## Workflow 1: Loading and Analyzing Pre-trained SAEs
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]
# 3. Encode to SAE features
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
# 4. Find top features for each position
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"Token '{token}': features {top_features.indices.tolist()}")
# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```
### Available Pre-trained SAEs
| Release | Model | Layers |
|---------|-------|--------|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |
### Checklist
- [ ] Load model with TransformerLens
- [ ] Load matching SAE for target layer
- [ ] Encode activations to sparse features
- [ ] Identify top-activating features per token
- [ ] Validate reconstruction quality
## Workflow 2: Training a Custom SAE
### Step-by-Step
```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # Model dimension
# SAE architecture
architecture="standard", # or "gated", "topk"
d_sae=768 * 8, # Expansion factor of 8
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5, # Sparsity penalty
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,
# Data
dataset_path="monology/pile-uncopyrighted",
context_size=128,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
)
# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```
### Key Hyperparameters
| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
### Evaluation Metrics
| Metric | Target | Meaning |
|--------|--------|---------|
| **L0** | 50-200 | Average active features per token |
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
| **Dead Features** | <5% | Features that never activate |
| **Explained Variance** | >90% | Reconstruction quality |
### Checklist
- [ ] Choose target layer and hook point
- [ ] Set expansion factor (d_sae = 4-16× d_model)
- [ ] Tune L1 coefficient for desired sparsity
- [ ] Enable L1 warm-up to prevent dead features
- [ ] Monitor metrics during training (W&B)
- [ ] Validate L0 and CE loss recovery
- [ ] Check dead feature ratio
## Workflow 3: Feature Analysis and Steering
### Analyzing Individual Features
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Find what activates a specific feature
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]
for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")
```
### Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
"""Add SAE feature direction to residual stream."""
tokens = model.to_tokens(prompt)
# Get feature direction from decoder
feature_direction = sae.W_dec[feature_idx] # [d_model]
def steering_hook(activation, hook):
# Add scaled feature direction at all positions
activation += strength * feature_direction
return activation
# Generate with steering
output = model.generate(
tokens,
max_new_tokens=50,
fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]
)
return model.to_string(output[0])
```
### Feature Attribution
```python
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]
# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]
# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
```
## Common Issues & Solutions
### Issue: High dead feature ratio
```python
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # Bad!
)
# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # Gradually increase
use_ghost_grads=True, # Revive dead features
)
```
### Issue: Poor reconstruction (low CE recovery)
```python
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # Lower = better reconstruction
d_sae=768 * 16, # More capacity
)
```
### Issue: Features not interpretable
```python
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # Higher = sparser, more interpretable
)
# Or use TopK architecture
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### Issue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # Reduce batch size
store_batch_size_prompts=4, # Fewer prompts in buffer
n_batches_in_buffer=8, # Smaller activation buffer
)
```
## Integration with Neuronpedia
Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):
```python
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
```
## Key Classes Reference
| Class | Purpose |
|-------|---------|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |
## Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:
| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |
## External Resources
### Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)
### Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)
### Official Documentation
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
- [Neuronpedia](https://neuronpedia.org) - Feature browser
## SAE Architectures
| Architecture | Description | Use Case |
|--------------|-------------|----------|
| **Standard** | ReLU + L1 penalty | General purpose |
| **Gated** | Learned gating mechanism | Better sparsity control |
| **TopK** | Exactly K active features | Consistent sparsity |
```python
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)
```

View File

@@ -0,0 +1,70 @@
# SAELens Reference Documentation
This directory contains comprehensive reference materials for SAELens.
## Contents
- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
- [papers.md](papers.md) - Key research papers on sparse autoencoders
## Quick Links
- **GitHub Repository**: https://github.com/jbloomAus/SAELens
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
- **HuggingFace SAEs**: Search for tag `saelens`
## Installation
```bash
pip install sae-lens
```
Requirements: Python 3.10+, transformer-lens>=2.0.0
## Basic Usage
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Encode activations to sparse features
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations) # Sparse feature activations
reconstructed = sae.decode(features) # Reconstructed activations
```
## Key Concepts
### Sparse Autoencoders
SAEs decompose dense neural activations into sparse, interpretable features:
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
- **ReLU/TopK**: Enforces sparsity
- **Decoder**: Reconstructs original activations
### Training Loss
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`
### Key Metrics
- **L0**: Average number of active features (target: 50-200)
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
- **Dead Features**: Features that never activate (target: <5%)
## Available Pre-trained SAEs
| Release | Model | Description |
|---------|-------|-------------|
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
| Various | Search HuggingFace | Community-trained SAEs |

View File

@@ -0,0 +1,333 @@
# SAELens API Reference
## SAE Class
The core class representing a Sparse Autoencoder.
### Loading Pre-trained SAEs
```python
from sae_lens import SAE
# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="username/repo-name",
sae_id="path/to/sae",
device="cuda"
)
# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```
### SAE Attributes
| Attribute | Shape | Description |
|-----------|-------|-------------|
| `W_enc` | [d_in, d_sae] | Encoder weights |
| `W_dec` | [d_sae, d_in] | Decoder weights |
| `b_enc` | [d_sae] | Encoder bias |
| `b_dec` | [d_in] | Decoder bias |
| `cfg` | SAEConfig | Configuration object |
### Core Methods
#### encode()
```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```
#### decode()
```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```
#### forward()
```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```
#### save_model()
```python
sae.save_model("/path/to/save")
```
---
## SAEConfig
Configuration class for SAE architecture and training context.
### Key Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `d_in` | int | Input dimension (model's d_model) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |
### Accessing Config
```python
print(sae.cfg.d_in) # 768 for GPT-2 small
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
```
---
## LanguageModelSAERunnerConfig
Comprehensive configuration for training SAEs.
### Example Configuration
```python
from sae_lens import LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(
# Model and hook
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768,
# SAE architecture
architecture="standard", # "standard", "gated", "jumprelu", "topk"
d_sae=768 * 8, # Expansion factor
activation_fn="relu",
# Training hyperparameters
lr=4e-4,
l1_coefficient=8e-5,
lp_norm=1.0,
lr_scheduler_name="constant",
lr_warm_up_steps=500,
# Sparsity control
l1_warm_up_steps=1000,
use_ghost_grads=True,
feature_sampling_window=1000,
dead_feature_window=5000,
dead_feature_threshold=1e-8,
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Batch sizes
train_batch_size_tokens=4096,
store_batch_size_prompts=16,
n_batches_in_buffer=64,
# Training duration
training_tokens=100_000_000,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
wandb_log_frequency=100,
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
# Hardware
device="cuda",
dtype="float32",
)
```
### Key Parameters Explained
#### Architecture Parameters
| Parameter | Description |
|-----------|-------------|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |
#### Sparsity Parameters
| Parameter | Description |
|-----------|-------------|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold for "dead" |
| `dead_feature_window` | Steps to check for dead features |
#### Learning Rate Parameters
| Parameter | Description |
|-----------|-------------|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | LR warmup steps |
| `lr_decay_steps` | Steps for LR decay |
---
## SAETrainingRunner
Main class for executing training.
### Basic Training
```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```
### Accessing Training Metrics
```python
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
```
---
## ActivationsStore
Manages activation collection and batching.
### Basic Usage
```python
from sae_lens import ActivationsStore
store = ActivationsStore.from_sae(
model=model,
sae=sae,
store_batch_size_prompts=8,
train_batch_size_tokens=4096,
n_batches_in_buffer=32,
device="cuda",
)
# Get batch of activations
activations = store.get_batch_tokens()
```
---
## HookedSAETransformer
Integration of SAEs with TransformerLens models.
### Basic Usage
```python
from sae_lens import HookedSAETransformer
# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)
# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])
# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
```
---
## SAE Architectures
### Standard (ReLU + L1)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="standard",
activation_fn="relu",
l1_coefficient=8e-5,
)
```
### Gated
```python
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
)
```
### TopK
```python
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50}, # Exactly 50 active features
)
```
### JumpReLU (State-of-the-art)
```python
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
)
```
---
## Utility Functions
### Upload to HuggingFace
```python
from sae_lens import upload_saes_to_huggingface
upload_saes_to_huggingface(
saes=[sae],
repo_id="username/my-saes",
token="hf_token",
)
```
### Neuronpedia Integration
```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```

View File

@@ -0,0 +1,318 @@
# SAELens Tutorials
## Tutorial 1: Loading and Analyzing Pre-trained SAEs
### Goal
Load a pre-trained SAE and analyze which features activate on specific inputs.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
# 1. Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
print(f"SAE input dim: {sae.cfg.d_in}")
print(f"SAE hidden dim: {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")
# 2. Get model activations
prompt = "The capital of France is Paris"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [1, seq_len, 768]
# 3. Encode to SAE features
features = sae.encode(activations) # [1, seq_len, d_sae]
# 4. Analyze sparsity
active_per_token = (features > 0).sum(dim=-1)
print(f"Average active features per token: {active_per_token.float().mean():.1f}")
# 5. Find top features for each token
str_tokens = model.to_str_tokens(prompt)
for pos in range(len(str_tokens)):
top_features = features[0, pos].topk(5)
print(f"\nToken '{str_tokens[pos]}':")
for feat_idx, feat_val in zip(top_features.indices, top_features.values):
print(f" Feature {feat_idx.item()}: {feat_val.item():.3f}")
# 6. Check reconstruction quality
reconstructed = sae.decode(features)
mse = ((activations - reconstructed) ** 2).mean()
print(f"\nReconstruction MSE: {mse.item():.6f}")
```
---
## Tutorial 2: Training a Custom SAE
### Goal
Train a Sparse Autoencoder on GPT-2 activations.
### Step-by-Step
```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
# Model
model_name="gpt2-small",
hook_name="blocks.6.hook_resid_pre",
hook_layer=6,
d_in=768,
# SAE architecture
architecture="standard",
d_sae=768 * 8, # 8x expansion
activation_fn="relu",
# Training
lr=4e-4,
l1_coefficient=8e-5,
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=10_000_000, # Small run for demo
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Dead feature prevention
use_ghost_grads=True,
dead_feature_window=5000,
# Logging
log_to_wandb=True,
wandb_project="sae-training-demo",
# Hardware
device="cuda",
dtype="float32",
)
# 2. Train
runner = SAETrainingRunner(cfg)
sae = runner.run()
# 3. Save
sae.save_model("./my_trained_sae")
```
### Hyperparameter Tuning Guide
| If you see... | Try... |
|---------------|--------|
| High L0 (>200) | Increase `l1_coefficient` |
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |
---
## Tutorial 3: Feature Attribution and Steering
### Goal
Identify which SAE features contribute to specific predictions and use them for steering.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# 1. Feature attribution for a specific prediction
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Target token
target_token = model.to_single_token(" Paris")
# Compute feature contributions to target logit
# contribution = feature_activation * decoder_weight * unembedding
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, d_vocab]
# Feature direction projected to vocabulary
feature_to_logit = W_dec @ W_U # [d_sae, d_vocab]
# Contribution of each feature to "Paris" at final position
feature_acts = features[0, -1] # [d_sae]
contributions = feature_acts * feature_to_logit[:, target_token]
# Top contributing features
top_features = contributions.topk(10)
print("Top features contributing to 'Paris':")
for idx, val in zip(top_features.indices, top_features.values):
print(f" Feature {idx.item()}: {val.item():.3f}")
# 2. Feature steering
def steer_with_feature(feature_idx, strength=5.0):
"""Add a feature direction to the residual stream."""
feature_direction = sae.W_dec[feature_idx] # [d_model]
def hook(activation, hook_obj):
activation[:, -1, :] += strength * feature_direction
return activation
output = model.generate(
tokens,
max_new_tokens=10,
fwd_hooks=[("blocks.8.hook_resid_pre", hook)]
)
return model.to_string(output[0])
# Try steering with top feature
top_feature_idx = top_features.indices[0].item()
print(f"\nSteering with feature {top_feature_idx}:")
print(steer_with_feature(top_feature_idx, strength=10.0))
```
---
## Tutorial 4: Feature Ablation
### Goal
Test the causal importance of features by ablating them.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
# Baseline prediction
baseline_logits = model(tokens)
target_token = model.to_single_token(" Paris")
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
print(f"Baseline P(Paris): {baseline_prob:.4f}")
# Get features to ablate
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
top_features = features[0, -1].topk(10).indices
# Ablate top features one by one
for feat_idx in top_features:
def ablation_hook(activation, hook, feat_idx=feat_idx):
# Encode → zero feature → decode
feats = sae.encode(activation)
feats[:, :, feat_idx] = 0
return sae.decode(feats)
ablated_logits = model.run_with_hooks(
tokens,
fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
)
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
change = (ablated_prob - baseline_prob) / baseline_prob * 100
print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
```
---
## Tutorial 5: Comparing Features Across Prompts
### Goal
Find which features activate consistently for a concept.
### Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# Test prompts about the same concept
prompts = [
"The Eiffel Tower is located in",
"Paris is the capital of",
"France's largest city is",
"The Louvre museum is in",
]
# Collect feature activations
all_features = []
for prompt in prompts:
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
# Take max activation across positions
max_features = features[0].max(dim=0).values
all_features.append(max_features)
all_features = torch.stack(all_features) # [n_prompts, d_sae]
# Find features that activate consistently
mean_activation = all_features.mean(dim=0)
min_activation = all_features.min(dim=0).values
# Features active in ALL prompts
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
print(f"Features active in all prompts: {len(consistent_features)}")
# Top consistent features
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
print("\nTop consistent features (possibly 'France/Paris' related):")
for idx, val in zip(top_consistent.indices, top_consistent.values):
feat_idx = consistent_features[idx].item()
print(f" Feature {feat_idx}: mean activation {val.item():.3f}")
```
---
## External Resources
### Official Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
### ARENA Curriculum
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab
### Key Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024

View File

@@ -0,0 +1,593 @@
---
name: weights-and-biases
description: Track ML experiments with automatic logging, visualize training in real-time, optimize hyperparameters with sweeps, and manage model registry with W&B - collaborative MLOps platform
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [wandb]
metadata:
hermes:
tags: [MLOps, Weights And Biases, WandB, Experiment Tracking, Hyperparameter Tuning, Model Registry, Collaboration, Real-Time Visualization, PyTorch, TensorFlow, HuggingFace]
---
# Weights & Biases: ML Experiment Tracking & MLOps
## When to Use This Skill
Use Weights & Biases (W&B) when you need to:
- **Track ML experiments** with automatic metric logging
- **Visualize training** in real-time dashboards
- **Compare runs** across hyperparameters and configurations
- **Optimize hyperparameters** with automated sweeps
- **Manage model registry** with versioning and lineage
- **Collaborate on ML projects** with team workspaces
- **Track artifacts** (datasets, models, code) with lineage
**Users**: 200,000+ ML practitioners | **GitHub Stars**: 10.5k+ | **Integrations**: 100+
## Installation
```bash
# Install W&B
pip install wandb
# Login (creates API key)
wandb login
# Or set API key programmatically
export WANDB_API_KEY=your_api_key_here
```
## Quick Start
### Basic Experiment Tracking
```python
import wandb
# Initialize a run
run = wandb.init(
project="my-project",
config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32,
"architecture": "ResNet50"
}
)
# Training loop
for epoch in range(run.config.epochs):
# Your training code
train_loss = train_epoch()
val_loss = validate()
# Log metrics
wandb.log({
"epoch": epoch,
"train/loss": train_loss,
"val/loss": val_loss,
"train/accuracy": train_acc,
"val/accuracy": val_acc
})
# Finish the run
wandb.finish()
```
### With PyTorch
```python
import torch
import wandb
# Initialize
wandb.init(project="pytorch-demo", config={
"lr": 0.001,
"epochs": 10
})
# Access config
config = wandb.config
# Training loop
for epoch in range(config.epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Log every 100 batches
if batch_idx % 100 == 0:
wandb.log({
"loss": loss.item(),
"epoch": epoch,
"batch": batch_idx
})
# Save model
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth") # Upload to W&B
wandb.finish()
```
## Core Concepts
### 1. Projects and Runs
**Project**: Collection of related experiments
**Run**: Single execution of your training script
```python
# Create/use project
run = wandb.init(
project="image-classification",
name="resnet50-experiment-1", # Optional run name
tags=["baseline", "resnet"], # Organize with tags
notes="First baseline run" # Add notes
)
# Each run has unique ID
print(f"Run ID: {run.id}")
print(f"Run URL: {run.url}")
```
### 2. Configuration Tracking
Track hyperparameters automatically:
```python
config = {
# Model architecture
"model": "ResNet50",
"pretrained": True,
# Training params
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 50,
"optimizer": "Adam",
# Data params
"dataset": "ImageNet",
"augmentation": "standard"
}
wandb.init(project="my-project", config=config)
# Access config during training
lr = wandb.config.learning_rate
batch_size = wandb.config.batch_size
```
### 3. Metric Logging
```python
# Log scalars
wandb.log({"loss": 0.5, "accuracy": 0.92})
# Log multiple metrics
wandb.log({
"train/loss": train_loss,
"train/accuracy": train_acc,
"val/loss": val_loss,
"val/accuracy": val_acc,
"learning_rate": current_lr,
"epoch": epoch
})
# Log with custom x-axis
wandb.log({"loss": loss}, step=global_step)
# Log media (images, audio, video)
wandb.log({"examples": [wandb.Image(img) for img in images]})
# Log histograms
wandb.log({"gradients": wandb.Histogram(gradients)})
# Log tables
table = wandb.Table(columns=["id", "prediction", "ground_truth"])
wandb.log({"predictions": table})
```
### 4. Model Checkpointing
```python
import torch
import wandb
# Save model checkpoint
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')
# Upload to W&B
wandb.save('checkpoint.pth')
# Or use Artifacts (recommended)
artifact = wandb.Artifact('model', type='model')
artifact.add_file('checkpoint.pth')
wandb.log_artifact(artifact)
```
## Hyperparameter Sweeps
Automatically search for optimal hyperparameters.
### Define Sweep Configuration
```python
sweep_config = {
'method': 'bayes', # or 'grid', 'random'
'metric': {
'name': 'val/accuracy',
'goal': 'maximize'
},
'parameters': {
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128]
},
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop']
},
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
}
}
}
# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```
### Define Training Function
```python
def train():
# Initialize run
run = wandb.init()
# Access sweep parameters
lr = wandb.config.learning_rate
batch_size = wandb.config.batch_size
optimizer_name = wandb.config.optimizer
# Build model with sweep config
model = build_model(wandb.config)
optimizer = get_optimizer(optimizer_name, lr)
# Training loop
for epoch in range(NUM_EPOCHS):
train_loss = train_epoch(model, optimizer, batch_size)
val_acc = validate(model)
# Log metrics
wandb.log({
"train/loss": train_loss,
"val/accuracy": val_acc
})
# Run sweep
wandb.agent(sweep_id, function=train, count=50) # Run 50 trials
```
### Sweep Strategies
```python
# Grid search - exhaustive
sweep_config = {
'method': 'grid',
'parameters': {
'lr': {'values': [0.001, 0.01, 0.1]},
'batch_size': {'values': [16, 32, 64]}
}
}
# Random search
sweep_config = {
'method': 'random',
'parameters': {
'lr': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.1},
'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
}
}
# Bayesian optimization (recommended)
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/loss', 'goal': 'minimize'},
'parameters': {
'lr': {'distribution': 'log_uniform', 'min': 1e-5, 'max': 1e-1}
}
}
```
## Artifacts
Track datasets, models, and other files with lineage.
### Log Artifacts
```python
# Create artifact
artifact = wandb.Artifact(
name='training-dataset',
type='dataset',
description='ImageNet training split',
metadata={'size': '1.2M images', 'split': 'train'}
)
# Add files
artifact.add_file('data/train.csv')
artifact.add_dir('data/images/')
# Log artifact
wandb.log_artifact(artifact)
```
### Use Artifacts
```python
# Download and use artifact
run = wandb.init(project="my-project")
# Download artifact
artifact = run.use_artifact('training-dataset:latest')
artifact_dir = artifact.download()
# Use the data
data = load_data(f"{artifact_dir}/train.csv")
```
### Model Registry
```python
# Log model as artifact
model_artifact = wandb.Artifact(
name='resnet50-model',
type='model',
metadata={'architecture': 'ResNet50', 'accuracy': 0.95}
)
model_artifact.add_file('model.pth')
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
# Link to model registry
run.link_artifact(model_artifact, 'model-registry/production-models')
```
## Integration Examples
### HuggingFace Transformers
```python
from transformers import Trainer, TrainingArguments
import wandb
# Initialize W&B
wandb.init(project="hf-transformers")
# Training arguments with W&B
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb", # Enable W&B logging
run_name="bert-finetuning",
logging_steps=100,
save_steps=500
)
# Trainer automatically logs to W&B
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
```
### PyTorch Lightning
```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
# Create W&B logger
wandb_logger = WandbLogger(
project="lightning-demo",
log_model=True # Log model checkpoints
)
# Use with Trainer
trainer = Trainer(
logger=wandb_logger,
max_epochs=10
)
trainer.fit(model, datamodule=dm)
```
### Keras/TensorFlow
```python
import wandb
from wandb.keras import WandbCallback
# Initialize
wandb.init(project="keras-demo")
# Add callback
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=10,
callbacks=[WandbCallback()] # Auto-logs metrics
)
```
## Visualization & Analysis
### Custom Charts
```python
# Log custom visualizations
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(x, y)
wandb.log({"custom_plot": wandb.Image(fig)})
# Log confusion matrix
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
probs=None,
y_true=ground_truth,
preds=predictions,
class_names=class_names
)})
```
### Reports
Create shareable reports in W&B UI:
- Combine runs, charts, and text
- Markdown support
- Embeddable visualizations
- Team collaboration
## Best Practices
### 1. Organize with Tags and Groups
```python
wandb.init(
project="my-project",
tags=["baseline", "resnet50", "imagenet"],
group="resnet-experiments", # Group related runs
job_type="train" # Type of job
)
```
### 2. Log Everything Relevant
```python
# Log system metrics
wandb.log({
"gpu/util": gpu_utilization,
"gpu/memory": gpu_memory_used,
"cpu/util": cpu_utilization
})
# Log code version
wandb.log({"git_commit": git_commit_hash})
# Log data splits
wandb.log({
"data/train_size": len(train_dataset),
"data/val_size": len(val_dataset)
})
```
### 3. Use Descriptive Names
```python
# ✅ Good: Descriptive run names
wandb.init(
project="nlp-classification",
name="bert-base-lr0.001-bs32-epoch10"
)
# ❌ Bad: Generic names
wandb.init(project="nlp", name="run1")
```
### 4. Save Important Artifacts
```python
# Save final model
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)
# Save predictions for analysis
predictions_table = wandb.Table(
columns=["id", "input", "prediction", "ground_truth"],
data=predictions_data
)
wandb.log({"predictions": predictions_table})
```
### 5. Use Offline Mode for Unstable Connections
```python
import os
# Enable offline mode
os.environ["WANDB_MODE"] = "offline"
wandb.init(project="my-project")
# ... your code ...
# Sync later
# wandb sync <run_directory>
```
## Team Collaboration
### Share Runs
```python
# Runs are automatically shareable via URL
run = wandb.init(project="team-project")
print(f"Share this URL: {run.url}")
```
### Team Projects
- Create team account at wandb.ai
- Add team members
- Set project visibility (private/public)
- Use team-level artifacts and model registry
## Pricing
- **Free**: Unlimited public projects, 100GB storage
- **Academic**: Free for students/researchers
- **Teams**: $50/seat/month, private projects, unlimited storage
- **Enterprise**: Custom pricing, on-prem options
## Resources
- **Documentation**: https://docs.wandb.ai
- **GitHub**: https://github.com/wandb/wandb (10.5k+ stars)
- **Examples**: https://github.com/wandb/examples
- **Community**: https://wandb.ai/community
- **Discord**: https://wandb.me/discord
## See Also
- `references/sweeps.md` - Comprehensive hyperparameter optimization guide
- `references/artifacts.md` - Data and model versioning patterns
- `references/integrations.md` - Framework-specific examples

View File

@@ -0,0 +1,584 @@
# Artifacts & Model Registry Guide
Complete guide to data versioning and model management with W&B Artifacts.
## Table of Contents
- What are Artifacts
- Creating Artifacts
- Using Artifacts
- Model Registry
- Versioning & Lineage
- Best Practices
## What are Artifacts
Artifacts are versioned datasets, models, or files tracked with lineage.
**Key Features:**
- Automatic versioning (v0, v1, v2...)
- Lineage tracking (which runs produced/used artifacts)
- Efficient storage (deduplication)
- Collaboration (team-wide access)
- Aliases (latest, best, production)
**Common Use Cases:**
- Dataset versioning
- Model checkpoints
- Preprocessed data
- Evaluation results
- Configuration files
## Creating Artifacts
### Basic Dataset Artifact
```python
import wandb
run = wandb.init(project="my-project")
# Create artifact
dataset = wandb.Artifact(
name='training-data',
type='dataset',
description='ImageNet training split with augmentations',
metadata={
'size': '1.2M images',
'format': 'JPEG',
'resolution': '224x224'
}
)
# Add files
dataset.add_file('data/train.csv') # Single file
dataset.add_dir('data/images') # Entire directory
dataset.add_reference('s3://bucket/data') # Cloud reference
# Log artifact
run.log_artifact(dataset)
wandb.finish()
```
### Model Artifact
```python
import torch
import wandb
run = wandb.init(project="my-project")
# Train model
model = train_model()
# Save model
torch.save(model.state_dict(), 'model.pth')
# Create model artifact
model_artifact = wandb.Artifact(
name='resnet50-classifier',
type='model',
description='ResNet50 trained on ImageNet',
metadata={
'architecture': 'ResNet50',
'accuracy': 0.95,
'loss': 0.15,
'epochs': 50,
'framework': 'PyTorch'
}
)
# Add model file
model_artifact.add_file('model.pth')
# Add config
model_artifact.add_file('config.yaml')
# Log with aliases
run.log_artifact(model_artifact, aliases=['latest', 'best'])
wandb.finish()
```
### Preprocessed Data Artifact
```python
import pandas as pd
import wandb
run = wandb.init(project="nlp-project")
# Preprocess data
df = pd.read_csv('raw_data.csv')
df_processed = preprocess(df)
df_processed.to_csv('processed_data.csv', index=False)
# Create artifact
processed_data = wandb.Artifact(
name='processed-text-data',
type='dataset',
metadata={
'rows': len(df_processed),
'columns': list(df_processed.columns),
'preprocessing_steps': ['lowercase', 'remove_stopwords', 'tokenize']
}
)
processed_data.add_file('processed_data.csv')
# Log artifact
run.log_artifact(processed_data)
```
## Using Artifacts
### Download and Use
```python
import wandb
run = wandb.init(project="my-project")
# Download artifact
artifact = run.use_artifact('training-data:latest')
artifact_dir = artifact.download()
# Use files
import pandas as pd
df = pd.read_csv(f'{artifact_dir}/train.csv')
# Train with artifact data
model = train_model(df)
```
### Use Specific Version
```python
# Use specific version
artifact_v2 = run.use_artifact('training-data:v2')
# Use alias
artifact_best = run.use_artifact('model:best')
artifact_prod = run.use_artifact('model:production')
# Use from another project
artifact = run.use_artifact('team/other-project/model:latest')
```
### Check Artifact Metadata
```python
artifact = run.use_artifact('training-data:latest')
# Access metadata
print(artifact.metadata)
print(f"Size: {artifact.metadata['size']}")
# Access version info
print(f"Version: {artifact.version}")
print(f"Created at: {artifact.created_at}")
print(f"Digest: {artifact.digest}")
```
## Model Registry
Link models to a central registry for governance and deployment.
### Create Model Registry
```python
# In W&B UI:
# 1. Go to "Registry" tab
# 2. Create new registry: "production-models"
# 3. Define stages: development, staging, production
```
### Link Model to Registry
```python
import wandb
run = wandb.init(project="training")
# Create model artifact
model_artifact = wandb.Artifact(
name='sentiment-classifier',
type='model',
metadata={'accuracy': 0.94, 'f1': 0.92}
)
model_artifact.add_file('model.pth')
# Log artifact
run.log_artifact(model_artifact)
# Link to registry
run.link_artifact(
model_artifact,
'model-registry/production-models',
aliases=['staging'] # Deploy to staging
)
wandb.finish()
```
### Promote Model in Registry
```python
# Retrieve model from registry
api = wandb.Api()
artifact = api.artifact('model-registry/production-models/sentiment-classifier:staging')
# Promote to production
artifact.link('model-registry/production-models', aliases=['production'])
# Demote from production
artifact.aliases = ['archived']
artifact.save()
```
### Use Model from Registry
```python
import wandb
run = wandb.init()
# Download production model
model_artifact = run.use_artifact(
'model-registry/production-models/sentiment-classifier:production'
)
model_dir = model_artifact.download()
# Load and use
import torch
model = torch.load(f'{model_dir}/model.pth')
model.eval()
```
## Versioning & Lineage
### Automatic Versioning
```python
# First log: creates v0
run1 = wandb.init(project="my-project")
dataset_v0 = wandb.Artifact('my-dataset', type='dataset')
dataset_v0.add_file('data_v1.csv')
run1.log_artifact(dataset_v0)
# Second log with same name: creates v1
run2 = wandb.init(project="my-project")
dataset_v1 = wandb.Artifact('my-dataset', type='dataset')
dataset_v1.add_file('data_v2.csv') # Different content
run2.log_artifact(dataset_v1)
# Third log with SAME content as v1: references v1 (no new version)
run3 = wandb.init(project="my-project")
dataset_v1_again = wandb.Artifact('my-dataset', type='dataset')
dataset_v1_again.add_file('data_v2.csv') # Same content as v1
run3.log_artifact(dataset_v1_again) # Still v1, no v2 created
```
### Track Lineage
```python
# Training run
run = wandb.init(project="my-project")
# Use dataset (input)
dataset = run.use_artifact('training-data:v3')
data = load_data(dataset.download())
# Train model
model = train(data)
# Save model (output)
model_artifact = wandb.Artifact('trained-model', type='model')
torch.save(model.state_dict(), 'model.pth')
model_artifact.add_file('model.pth')
run.log_artifact(model_artifact)
# Lineage automatically tracked:
# training-data:v3 --> [run] --> trained-model:v0
```
### View Lineage Graph
```python
# In W&B UI:
# Artifacts → Select artifact → Lineage tab
# Shows:
# - Which runs produced this artifact
# - Which runs used this artifact
# - Parent/child artifacts
```
## Artifact Types
### Dataset Artifacts
```python
# Raw data
raw_data = wandb.Artifact('raw-data', type='dataset')
raw_data.add_dir('raw/')
# Processed data
processed_data = wandb.Artifact('processed-data', type='dataset')
processed_data.add_dir('processed/')
# Train/val/test splits
train_split = wandb.Artifact('train-split', type='dataset')
train_split.add_file('train.csv')
val_split = wandb.Artifact('val-split', type='dataset')
val_split.add_file('val.csv')
```
### Model Artifacts
```python
# Checkpoint during training
checkpoint = wandb.Artifact('checkpoint-epoch-10', type='model')
checkpoint.add_file('checkpoint_epoch_10.pth')
# Final model
final_model = wandb.Artifact('final-model', type='model')
final_model.add_file('model.pth')
final_model.add_file('tokenizer.json')
# Quantized model
quantized = wandb.Artifact('quantized-model', type='model')
quantized.add_file('model_int8.onnx')
```
### Result Artifacts
```python
# Predictions
predictions = wandb.Artifact('test-predictions', type='predictions')
predictions.add_file('predictions.csv')
# Evaluation metrics
eval_results = wandb.Artifact('evaluation', type='evaluation')
eval_results.add_file('metrics.json')
eval_results.add_file('confusion_matrix.png')
```
## Advanced Patterns
### Incremental Artifacts
Add files incrementally without re-uploading.
```python
run = wandb.init(project="my-project")
# Create artifact
dataset = wandb.Artifact('incremental-dataset', type='dataset')
# Add files incrementally
for i in range(100):
filename = f'batch_{i}.csv'
process_batch(i, filename)
dataset.add_file(filename)
# Log progress
if (i + 1) % 10 == 0:
print(f"Added {i + 1}/100 batches")
# Log complete artifact
run.log_artifact(dataset)
```
### Artifact Tables
Track structured data with W&B Tables.
```python
import wandb
run = wandb.init(project="my-project")
# Create table
table = wandb.Table(columns=["id", "image", "label", "prediction"])
for idx, (img, label, pred) in enumerate(zip(images, labels, predictions)):
table.add_data(
idx,
wandb.Image(img),
label,
pred
)
# Log as artifact
artifact = wandb.Artifact('predictions-table', type='predictions')
artifact.add(table, "predictions")
run.log_artifact(artifact)
```
### Artifact References
Reference external data without copying.
```python
# S3 reference
dataset = wandb.Artifact('s3-dataset', type='dataset')
dataset.add_reference('s3://my-bucket/data/', name='train')
dataset.add_reference('s3://my-bucket/labels/', name='labels')
# GCS reference
dataset.add_reference('gs://my-bucket/data/')
# HTTP reference
dataset.add_reference('https://example.com/data.zip')
# Local filesystem reference (for shared storage)
dataset.add_reference('file:///mnt/shared/data')
```
## Collaboration Patterns
### Team Dataset Sharing
```python
# Data engineer creates dataset
run = wandb.init(project="data-eng", entity="my-team")
dataset = wandb.Artifact('shared-dataset', type='dataset')
dataset.add_dir('data/')
run.log_artifact(dataset, aliases=['latest', 'production'])
# ML engineer uses dataset
run = wandb.init(project="ml-training", entity="my-team")
dataset = run.use_artifact('my-team/data-eng/shared-dataset:production')
data = load_data(dataset.download())
```
### Model Handoff
```python
# Training team
train_run = wandb.init(project="model-training", entity="ml-team")
model = train_model()
model_artifact = wandb.Artifact('nlp-model', type='model')
model_artifact.add_file('model.pth')
train_run.log_artifact(model_artifact)
train_run.link_artifact(model_artifact, 'model-registry/nlp-models', aliases=['candidate'])
# Evaluation team
eval_run = wandb.init(project="model-eval", entity="ml-team")
model_artifact = eval_run.use_artifact('model-registry/nlp-models/nlp-model:candidate')
metrics = evaluate_model(model_artifact)
if metrics['f1'] > 0.9:
# Promote to production
model_artifact.link('model-registry/nlp-models', aliases=['production'])
```
## Best Practices
### 1. Use Descriptive Names
```python
# ✅ Good: Descriptive names
wandb.Artifact('imagenet-train-augmented-v2', type='dataset')
wandb.Artifact('bert-base-sentiment-finetuned', type='model')
# ❌ Bad: Generic names
wandb.Artifact('dataset1', type='dataset')
wandb.Artifact('model', type='model')
```
### 2. Add Comprehensive Metadata
```python
model_artifact = wandb.Artifact(
'production-model',
type='model',
description='ResNet50 classifier for product categorization',
metadata={
# Model info
'architecture': 'ResNet50',
'framework': 'PyTorch 2.0',
'pretrained': True,
# Performance
'accuracy': 0.95,
'f1_score': 0.93,
'inference_time_ms': 15,
# Training
'epochs': 50,
'dataset': 'imagenet',
'num_samples': 1200000,
# Business context
'use_case': 'e-commerce product classification',
'owner': 'ml-team@company.com',
'approved_by': 'data-science-lead'
}
)
```
### 3. Use Aliases for Deployment Stages
```python
# Development
run.log_artifact(model, aliases=['dev', 'latest'])
# Staging
run.log_artifact(model, aliases=['staging'])
# Production
run.log_artifact(model, aliases=['production', 'v1.2.0'])
# Archive old versions
old_artifact = api.artifact('model:production')
old_artifact.aliases = ['archived-v1.1.0']
old_artifact.save()
```
### 4. Track Data Lineage
```python
def create_training_pipeline():
run = wandb.init(project="pipeline")
# 1. Load raw data
raw_data = run.use_artifact('raw-data:latest')
# 2. Preprocess
processed = preprocess(raw_data)
processed_artifact = wandb.Artifact('processed-data', type='dataset')
processed_artifact.add_file('processed.csv')
run.log_artifact(processed_artifact)
# 3. Train model
model = train(processed)
model_artifact = wandb.Artifact('trained-model', type='model')
model_artifact.add_file('model.pth')
run.log_artifact(model_artifact)
# Lineage: raw-data → processed-data → trained-model
```
### 5. Efficient Storage
```python
# ✅ Good: Reference large files
large_dataset = wandb.Artifact('large-dataset', type='dataset')
large_dataset.add_reference('s3://bucket/huge-file.tar.gz')
# ❌ Bad: Upload giant files
# large_dataset.add_file('huge-file.tar.gz') # Don't do this
# ✅ Good: Upload only metadata
metadata_artifact = wandb.Artifact('dataset-metadata', type='dataset')
metadata_artifact.add_file('metadata.json') # Small file
```
## Resources
- **Artifacts Documentation**: https://docs.wandb.ai/guides/artifacts
- **Model Registry**: https://docs.wandb.ai/guides/model-registry
- **Best Practices**: https://wandb.ai/site/articles/versioning-data-and-models-in-ml

View File

@@ -0,0 +1,700 @@
# Framework Integrations Guide
Complete guide to integrating W&B with popular ML frameworks.
## Table of Contents
- HuggingFace Transformers
- PyTorch Lightning
- Keras/TensorFlow
- Fast.ai
- XGBoost/LightGBM
- PyTorch Native
- Custom Integrations
## HuggingFace Transformers
### Automatic Integration
```python
from transformers import Trainer, TrainingArguments
import wandb
# Initialize W&B
wandb.init(project="hf-transformers", name="bert-finetuning")
# Training arguments with W&B
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb", # Enable W&B logging
run_name="bert-base-finetuning",
# Training params
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=64,
learning_rate=2e-5,
# Logging
logging_dir="./logs",
logging_steps=100,
logging_first_step=True,
# Evaluation
evaluation_strategy="steps",
eval_steps=500,
save_steps=500,
# Other
load_best_model_at_end=True,
metric_for_best_model="eval_accuracy"
)
# Trainer automatically logs to W&B
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics
)
# Train (metrics logged automatically)
trainer.train()
# Finish W&B run
wandb.finish()
```
### Custom Logging
```python
from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback
import wandb
class CustomWandbCallback(WandbCallback):
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
super().on_evaluate(args, state, control, metrics, **kwargs)
# Log custom metrics
wandb.log({
"custom/eval_score": metrics["eval_accuracy"] * 100,
"custom/epoch": state.epoch
})
# Use custom callback
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[CustomWandbCallback()]
)
```
### Log Model to Registry
```python
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
report_to="wandb",
load_best_model_at_end=True
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
# Save final model as artifact
model_artifact = wandb.Artifact(
'hf-bert-model',
type='model',
description='BERT finetuned on sentiment analysis'
)
# Save model files
trainer.save_model("./final_model")
model_artifact.add_dir("./final_model")
# Log artifact
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
wandb.finish()
```
## PyTorch Lightning
### Basic Integration
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# Create W&B logger
wandb_logger = WandbLogger(
project="lightning-demo",
name="resnet50-training",
log_model=True, # Log model checkpoints as artifacts
save_code=True # Save code as artifact
)
# Lightning module
class LitModel(pl.LightningModule):
def __init__(self, learning_rate=0.001):
super().__init__()
self.save_hyperparameters()
self.model = create_model()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# Log metrics (automatically sent to W&B)
self.log('train/loss', loss, on_step=True, on_epoch=True)
self.log('train/accuracy', accuracy(y_hat, y), on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('val/loss', loss, on_step=False, on_epoch=True)
self.log('val/accuracy', accuracy(y_hat, y), on_epoch=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
# Trainer with W&B logger
trainer = pl.Trainer(
logger=wandb_logger,
max_epochs=10,
accelerator="gpu",
devices=1
)
# Train (metrics logged automatically)
trainer.fit(model, datamodule=dm)
# Finish W&B run
wandb.finish()
```
### Log Media
```python
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
# Log images (first batch only)
if batch_idx == 0:
self.logger.experiment.log({
"examples": [wandb.Image(img) for img in x[:8]]
})
return loss
def on_validation_epoch_end(self):
# Log confusion matrix
cm = compute_confusion_matrix(self.all_preds, self.all_targets)
self.logger.experiment.log({
"confusion_matrix": wandb.plot.confusion_matrix(
probs=None,
y_true=self.all_targets,
preds=self.all_preds,
class_names=self.class_names
)
})
```
### Hyperparameter Sweeps
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# Define sweep
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {
'learning_rate': {'min': 1e-5, 'max': 1e-2, 'distribution': 'log_uniform'},
'batch_size': {'values': [16, 32, 64]},
'hidden_size': {'values': [128, 256, 512]}
}
}
sweep_id = wandb.sweep(sweep_config, project="lightning-sweeps")
def train():
# Initialize W&B
run = wandb.init()
# Get hyperparameters
config = wandb.config
# Create logger
wandb_logger = WandbLogger()
# Create model with sweep params
model = LitModel(
learning_rate=config.learning_rate,
hidden_size=config.hidden_size
)
# Create datamodule with sweep batch size
dm = DataModule(batch_size=config.batch_size)
# Train
trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
trainer.fit(model, dm)
# Run sweep
wandb.agent(sweep_id, function=train, count=30)
```
## Keras/TensorFlow
### With Callback
```python
import tensorflow as tf
from wandb.keras import WandbCallback
import wandb
# Initialize W&B
wandb.init(
project="keras-demo",
config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32
}
)
config = wandb.config
# Build model
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer=tf.keras.optimizers.Adam(config.learning_rate),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# Train with W&B callback
history = model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=config.epochs,
batch_size=config.batch_size,
callbacks=[
WandbCallback(
log_weights=True, # Log model weights
log_gradients=True, # Log gradients
training_data=(x_train, y_train),
validation_data=(x_val, y_val),
labels=class_names
)
]
)
# Save model as artifact
model.save('model.h5')
artifact = wandb.Artifact('keras-model', type='model')
artifact.add_file('model.h5')
wandb.log_artifact(artifact)
wandb.finish()
```
### Custom Training Loop
```python
import tensorflow as tf
import wandb
wandb.init(project="tf-custom-loop")
# Model, optimizer, loss
model = create_model()
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
# Metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
predictions = model(x, training=True)
loss = loss_fn(y, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
train_accuracy(y, predictions)
# Training loop
for epoch in range(EPOCHS):
train_loss.reset_states()
train_accuracy.reset_states()
for step, (x, y) in enumerate(train_dataset):
train_step(x, y)
# Log every 100 steps
if step % 100 == 0:
wandb.log({
'train/loss': train_loss.result().numpy(),
'train/accuracy': train_accuracy.result().numpy(),
'epoch': epoch,
'step': step
})
# Log epoch metrics
wandb.log({
'epoch/train_loss': train_loss.result().numpy(),
'epoch/train_accuracy': train_accuracy.result().numpy(),
'epoch': epoch
})
wandb.finish()
```
## Fast.ai
### With Callback
```python
from fastai.vision.all import *
from fastai.callback.wandb import *
import wandb
# Initialize W&B
wandb.init(project="fastai-demo")
# Create data loaders
dls = ImageDataLoaders.from_folder(
path,
train='train',
valid='valid',
bs=64
)
# Create learner with W&B callback
learn = vision_learner(
dls,
resnet34,
metrics=accuracy,
cbs=WandbCallback(
log_preds=True, # Log predictions
log_model=True, # Log model as artifact
log_dataset=True # Log dataset as artifact
)
)
# Train (metrics logged automatically)
learn.fine_tune(5)
wandb.finish()
```
## XGBoost/LightGBM
### XGBoost
```python
import xgboost as xgb
import wandb
# Initialize W&B
run = wandb.init(project="xgboost-demo", config={
"max_depth": 6,
"learning_rate": 0.1,
"n_estimators": 100
})
config = wandb.config
# Create DMatrix
dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)
# XGBoost params
params = {
'max_depth': config.max_depth,
'learning_rate': config.learning_rate,
'objective': 'binary:logistic',
'eval_metric': ['logloss', 'auc']
}
# Custom callback for W&B
def wandb_callback(env):
"""Log XGBoost metrics to W&B."""
for metric_name, metric_value in env.evaluation_result_list:
wandb.log({
f"{metric_name}": metric_value,
"iteration": env.iteration
})
# Train with callback
model = xgb.train(
params,
dtrain,
num_boost_round=config.n_estimators,
evals=[(dtrain, 'train'), (dval, 'val')],
callbacks=[wandb_callback],
verbose_eval=10
)
# Save model
model.save_model('xgboost_model.json')
artifact = wandb.Artifact('xgboost-model', type='model')
artifact.add_file('xgboost_model.json')
wandb.log_artifact(artifact)
wandb.finish()
```
### LightGBM
```python
import lightgbm as lgb
import wandb
run = wandb.init(project="lgbm-demo")
# Create datasets
train_data = lgb.Dataset(X_train, label=y_train)
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
# Parameters
params = {
'objective': 'binary',
'metric': ['binary_logloss', 'auc'],
'learning_rate': 0.1,
'num_leaves': 31
}
# Custom callback
def log_to_wandb(env):
"""Log LightGBM metrics to W&B."""
for entry in env.evaluation_result_list:
dataset_name, metric_name, metric_value, _ = entry
wandb.log({
f"{dataset_name}/{metric_name}": metric_value,
"iteration": env.iteration
})
# Train
model = lgb.train(
params,
train_data,
num_boost_round=100,
valid_sets=[train_data, val_data],
valid_names=['train', 'val'],
callbacks=[log_to_wandb]
)
# Save model
model.save_model('lgbm_model.txt')
artifact = wandb.Artifact('lgbm-model', type='model')
artifact.add_file('lgbm_model.txt')
wandb.log_artifact(artifact)
wandb.finish()
```
## PyTorch Native
### Training Loop Integration
```python
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
# Initialize W&B
wandb.init(project="pytorch-native", config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32
})
config = wandb.config
# Model, loss, optimizer
model = create_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# Watch model (logs gradients and parameters)
wandb.watch(model, criterion, log="all", log_freq=100)
# Training loop
for epoch in range(config.epochs):
model.train()
train_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
# Forward pass
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
optimizer.step()
# Track metrics
train_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
# Log every 100 batches
if batch_idx % 100 == 0:
wandb.log({
'train/loss': loss.item(),
'train/batch_accuracy': 100. * correct / total,
'epoch': epoch,
'batch': batch_idx
})
# Validation
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
val_loss += loss.item()
_, predicted = output.max(1)
val_total += target.size(0)
val_correct += predicted.eq(target).sum().item()
# Log epoch metrics
wandb.log({
'epoch/train_loss': train_loss / len(train_loader),
'epoch/train_accuracy': 100. * correct / total,
'epoch/val_loss': val_loss / len(val_loader),
'epoch/val_accuracy': 100. * val_correct / val_total,
'epoch': epoch
})
# Save final model
torch.save(model.state_dict(), 'model.pth')
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)
wandb.finish()
```
## Custom Integrations
### Generic Framework Integration
```python
import wandb
class WandbIntegration:
"""Generic W&B integration wrapper."""
def __init__(self, project, config):
self.run = wandb.init(project=project, config=config)
self.config = wandb.config
self.step = 0
def log_metrics(self, metrics, step=None):
"""Log training metrics."""
if step is None:
step = self.step
self.step += 1
wandb.log(metrics, step=step)
def log_images(self, images, caption=""):
"""Log images."""
wandb.log({
caption: [wandb.Image(img) for img in images]
})
def log_table(self, data, columns):
"""Log tabular data."""
table = wandb.Table(columns=columns, data=data)
wandb.log({"table": table})
def save_model(self, model_path, metadata=None):
"""Save model as artifact."""
artifact = wandb.Artifact(
'model',
type='model',
metadata=metadata or {}
)
artifact.add_file(model_path)
self.run.log_artifact(artifact)
def finish(self):
"""Finish W&B run."""
wandb.finish()
# Usage
wb = WandbIntegration(project="my-project", config={"lr": 0.001})
# Training loop
for epoch in range(10):
# Your training code
loss, accuracy = train_epoch()
# Log metrics
wb.log_metrics({
'train/loss': loss,
'train/accuracy': accuracy
})
# Save model
wb.save_model('model.pth', metadata={'accuracy': 0.95})
wb.finish()
```
## Resources
- **Integrations Guide**: https://docs.wandb.ai/guides/integrations
- **HuggingFace**: https://docs.wandb.ai/guides/integrations/huggingface
- **PyTorch Lightning**: https://docs.wandb.ai/guides/integrations/lightning
- **Keras**: https://docs.wandb.ai/guides/integrations/keras
- **Examples**: https://github.com/wandb/examples

View File

@@ -0,0 +1,847 @@
# Comprehensive Hyperparameter Sweeps Guide
Complete guide to hyperparameter optimization with W&B Sweeps.
## Table of Contents
- Sweep Configuration
- Search Strategies
- Parameter Distributions
- Early Termination
- Parallel Execution
- Advanced Patterns
- Real-World Examples
## Sweep Configuration
### Basic Sweep Config
```python
sweep_config = {
'method': 'bayes', # Search strategy
'metric': {
'name': 'val/accuracy',
'goal': 'maximize' # or 'minimize'
},
'parameters': {
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128]
}
}
}
# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```
### Complete Config Example
```python
sweep_config = {
# Required: Search method
'method': 'bayes',
# Required: Optimization metric
'metric': {
'name': 'val/f1_score',
'goal': 'maximize'
},
# Required: Parameters to search
'parameters': {
# Continuous parameter
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
# Discrete values
'batch_size': {
'values': [16, 32, 64, 128]
},
# Categorical
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
},
# Uniform distribution
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
},
# Integer range
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 10
},
# Fixed value (constant across runs)
'epochs': {
'value': 50
}
},
# Optional: Early termination
'early_terminate': {
'type': 'hyperband',
'min_iter': 5,
's': 2,
'eta': 3,
'max_iter': 27
}
}
```
## Search Strategies
### 1. Grid Search
Exhaustively search all combinations.
```python
sweep_config = {
'method': 'grid',
'parameters': {
'learning_rate': {
'values': [0.001, 0.01, 0.1]
},
'batch_size': {
'values': [16, 32, 64]
},
'optimizer': {
'values': ['adam', 'sgd']
}
}
}
# Total runs: 3 × 3 × 2 = 18 runs
```
**Pros:**
- Comprehensive search
- Reproducible results
- No randomness
**Cons:**
- Exponential growth with parameters
- Inefficient for continuous parameters
- Not scalable beyond 3-4 parameters
**When to use:**
- Few parameters (< 4)
- All discrete values
- Need complete coverage
### 2. Random Search
Randomly sample parameter combinations.
```python
sweep_config = {
'method': 'random',
'parameters': {
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
'batch_size': {
'values': [16, 32, 64, 128, 256]
},
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
},
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 8
}
}
}
# Run 100 random trials
wandb.agent(sweep_id, function=train, count=100)
```
**Pros:**
- Scales to many parameters
- Can run indefinitely
- Often finds good solutions quickly
**Cons:**
- No learning from previous runs
- May miss optimal region
- Results vary with random seed
**When to use:**
- Many parameters (> 4)
- Quick exploration
- Limited budget
### 3. Bayesian Optimization (Recommended)
Learn from previous trials to sample promising regions.
```python
sweep_config = {
'method': 'bayes',
'metric': {
'name': 'val/loss',
'goal': 'minimize'
},
'parameters': {
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
'weight_decay': {
'distribution': 'log_uniform',
'min': 1e-6,
'max': 1e-2
},
'dropout': {
'distribution': 'uniform',
'min': 0.1,
'max': 0.5
},
'num_layers': {
'values': [2, 3, 4, 5, 6]
}
}
}
```
**Pros:**
- Most sample-efficient
- Learns from past trials
- Focuses on promising regions
**Cons:**
- Initial random exploration phase
- May get stuck in local optima
- Slower per iteration
**When to use:**
- Expensive training runs
- Need best performance
- Limited compute budget
## Parameter Distributions
### Continuous Distributions
```python
# Log-uniform: Good for learning rates, regularization
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-6,
'max': 1e-1
}
# Uniform: Good for dropout, momentum
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
}
# Normal distribution
'parameter': {
'distribution': 'normal',
'mu': 0.5,
'sigma': 0.1
}
# Log-normal distribution
'parameter': {
'distribution': 'log_normal',
'mu': 0.0,
'sigma': 1.0
}
```
### Discrete Distributions
```python
# Fixed values
'batch_size': {
'values': [16, 32, 64, 128, 256]
}
# Integer uniform
'num_layers': {
'distribution': 'int_uniform',
'min': 2,
'max': 10
}
# Quantized uniform (step size)
'layer_size': {
'distribution': 'q_uniform',
'min': 32,
'max': 512,
'q': 32 # Step by 32: 32, 64, 96, 128...
}
# Quantized log-uniform
'hidden_size': {
'distribution': 'q_log_uniform',
'min': 32,
'max': 1024,
'q': 32
}
```
### Categorical Parameters
```python
# Optimizers
'optimizer': {
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
}
# Model architectures
'model': {
'values': ['resnet18', 'resnet34', 'resnet50', 'efficientnet_b0']
}
# Activation functions
'activation': {
'values': ['relu', 'gelu', 'silu', 'leaky_relu']
}
```
## Early Termination
Stop underperforming runs early to save compute.
### Hyperband
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {...},
# Hyperband early termination
'early_terminate': {
'type': 'hyperband',
'min_iter': 3, # Minimum iterations before termination
's': 2, # Bracket count
'eta': 3, # Downsampling rate
'max_iter': 27 # Maximum iterations
}
}
```
**How it works:**
- Runs trials in brackets
- Keeps top 1/eta performers each round
- Eliminates bottom performers early
### Custom Termination
```python
def train():
run = wandb.init()
for epoch in range(MAX_EPOCHS):
loss = train_epoch()
val_acc = validate()
wandb.log({'val/accuracy': val_acc, 'epoch': epoch})
# Custom early stopping
if epoch > 5 and val_acc < 0.5:
print("Early stop: Poor performance")
break
if epoch > 10 and val_acc > best_acc - 0.01:
print("Early stop: No improvement")
break
```
## Training Function
### Basic Template
```python
def train():
# Initialize W&B run
run = wandb.init()
# Get hyperparameters
config = wandb.config
# Build model with config
model = build_model(
hidden_size=config.hidden_size,
num_layers=config.num_layers,
dropout=config.dropout
)
# Create optimizer
optimizer = create_optimizer(
model.parameters(),
name=config.optimizer,
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# Training loop
for epoch in range(config.epochs):
# Train
train_loss, train_acc = train_epoch(
model, optimizer, train_loader, config.batch_size
)
# Validate
val_loss, val_acc = validate(model, val_loader)
# Log metrics
wandb.log({
'train/loss': train_loss,
'train/accuracy': train_acc,
'val/loss': val_loss,
'val/accuracy': val_acc,
'epoch': epoch
})
# Log final model
torch.save(model.state_dict(), 'model.pth')
wandb.save('model.pth')
# Finish run
wandb.finish()
```
### With PyTorch
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
def train():
run = wandb.init()
config = wandb.config
# Data
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True
)
# Model
model = ResNet(
num_classes=config.num_classes,
dropout=config.dropout
).to(device)
# Optimizer
if config.optimizer == 'adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
elif config.optimizer == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.learning_rate,
momentum=config.momentum,
weight_decay=config.weight_decay
)
# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config.epochs
)
# Training
for epoch in range(config.epochs):
model.train()
train_loss = 0.0
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation
model.eval()
val_loss, val_acc = validate(model, val_loader)
# Step scheduler
scheduler.step()
# Log
wandb.log({
'train/loss': train_loss / len(train_loader),
'val/loss': val_loss,
'val/accuracy': val_acc,
'learning_rate': scheduler.get_last_lr()[0],
'epoch': epoch
})
```
## Parallel Execution
### Multiple Agents
Run sweep agents in parallel to speed up search.
```python
# Initialize sweep once
sweep_id = wandb.sweep(sweep_config, project="my-project")
# Run multiple agents in parallel
# Agent 1 (Terminal 1)
wandb.agent(sweep_id, function=train, count=20)
# Agent 2 (Terminal 2)
wandb.agent(sweep_id, function=train, count=20)
# Agent 3 (Terminal 3)
wandb.agent(sweep_id, function=train, count=20)
# Total: 60 runs across 3 agents
```
### Multi-GPU Execution
```python
import os
def train():
# Get available GPU
gpu_id = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
run = wandb.init()
config = wandb.config
# Train on specific GPU
device = torch.device(f'cuda:{gpu_id}')
model = model.to(device)
# ... rest of training ...
# Run agents on different GPUs
# Terminal 1
# CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
# Terminal 2
# CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
# Terminal 3
# CUDA_VISIBLE_DEVICES=2 wandb agent sweep_id
```
## Advanced Patterns
### Nested Parameters
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
'parameters': {
'model': {
'parameters': {
'type': {
'values': ['resnet', 'efficientnet']
},
'size': {
'values': ['small', 'medium', 'large']
}
}
},
'optimizer': {
'parameters': {
'type': {
'values': ['adam', 'sgd']
},
'lr': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
}
}
}
}
}
# Access nested config
def train():
run = wandb.init()
model_type = wandb.config.model.type
model_size = wandb.config.model.size
opt_type = wandb.config.optimizer.type
lr = wandb.config.optimizer.lr
```
### Conditional Parameters
```python
sweep_config = {
'method': 'bayes',
'parameters': {
'optimizer': {
'values': ['adam', 'sgd']
},
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-1
},
# Only used if optimizer == 'sgd'
'momentum': {
'distribution': 'uniform',
'min': 0.5,
'max': 0.99
}
}
}
def train():
run = wandb.init()
config = wandb.config
if config.optimizer == 'adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=config.learning_rate
)
elif config.optimizer == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.learning_rate,
momentum=config.momentum # Conditional parameter
)
```
## Real-World Examples
### Image Classification
```python
sweep_config = {
'method': 'bayes',
'metric': {
'name': 'val/top1_accuracy',
'goal': 'maximize'
},
'parameters': {
# Model
'architecture': {
'values': ['resnet50', 'resnet101', 'efficientnet_b0', 'efficientnet_b3']
},
'pretrained': {
'values': [True, False]
},
# Training
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-5,
'max': 1e-2
},
'batch_size': {
'values': [16, 32, 64, 128]
},
'optimizer': {
'values': ['adam', 'sgd', 'adamw']
},
'weight_decay': {
'distribution': 'log_uniform',
'min': 1e-6,
'max': 1e-2
},
# Regularization
'dropout': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.5
},
'label_smoothing': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.2
},
# Data augmentation
'mixup_alpha': {
'distribution': 'uniform',
'min': 0.0,
'max': 1.0
},
'cutmix_alpha': {
'distribution': 'uniform',
'min': 0.0,
'max': 1.0
}
},
'early_terminate': {
'type': 'hyperband',
'min_iter': 5
}
}
```
### NLP Fine-Tuning
```python
sweep_config = {
'method': 'bayes',
'metric': {'name': 'eval/f1', 'goal': 'maximize'},
'parameters': {
# Model
'model_name': {
'values': ['bert-base-uncased', 'roberta-base', 'distilbert-base-uncased']
},
# Training
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-6,
'max': 1e-4
},
'per_device_train_batch_size': {
'values': [8, 16, 32]
},
'num_train_epochs': {
'values': [3, 4, 5]
},
'warmup_ratio': {
'distribution': 'uniform',
'min': 0.0,
'max': 0.1
},
'weight_decay': {
'distribution': 'log_uniform',
'min': 1e-4,
'max': 1e-1
},
# Optimizer
'adam_beta1': {
'distribution': 'uniform',
'min': 0.8,
'max': 0.95
},
'adam_beta2': {
'distribution': 'uniform',
'min': 0.95,
'max': 0.999
}
}
}
```
## Best Practices
### 1. Start Small
```python
# Initial exploration: Random search, 20 runs
sweep_config_v1 = {
'method': 'random',
'parameters': {...}
}
wandb.agent(sweep_id_v1, train, count=20)
# Refined search: Bayes, narrow ranges
sweep_config_v2 = {
'method': 'bayes',
'parameters': {
'learning_rate': {
'min': 5e-5, # Narrowed from 1e-6 to 1e-4
'max': 1e-4
}
}
}
```
### 2. Use Log Scales
```python
# ✅ Good: Log scale for learning rate
'learning_rate': {
'distribution': 'log_uniform',
'min': 1e-6,
'max': 1e-2
}
# ❌ Bad: Linear scale
'learning_rate': {
'distribution': 'uniform',
'min': 0.000001,
'max': 0.01
}
```
### 3. Set Reasonable Ranges
```python
# Base ranges on prior knowledge
'learning_rate': {'min': 1e-5, 'max': 1e-3}, # Typical for Adam
'batch_size': {'values': [16, 32, 64]}, # GPU memory limits
'dropout': {'min': 0.1, 'max': 0.5} # Too high hurts training
```
### 4. Monitor Resource Usage
```python
def train():
run = wandb.init()
# Log system metrics
wandb.log({
'system/gpu_memory_allocated': torch.cuda.memory_allocated(),
'system/gpu_memory_reserved': torch.cuda.memory_reserved()
})
```
### 5. Save Best Models
```python
def train():
run = wandb.init()
best_acc = 0.0
for epoch in range(config.epochs):
val_acc = validate(model)
if val_acc > best_acc:
best_acc = val_acc
# Save best checkpoint
torch.save(model.state_dict(), 'best_model.pth')
wandb.save('best_model.pth')
```
## Resources
- **Sweeps Documentation**: https://docs.wandb.ai/guides/sweeps
- **Configuration Reference**: https://docs.wandb.ai/guides/sweeps/configuration
- **Examples**: https://github.com/wandb/examples/tree/master/examples/wandb-sweeps