Add stuck initiatives audit report

protected/skills-backup/mlops/evaluation/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Model evaluation benchmarks, experiment tracking, data curation, tokenizers, and interpretability tools.
---

@@ -0,0 +1,519 @@
---
name: huggingface-tokenizers
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tokenizers, transformers, datasets]
metadata:
  hermes:
    tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
---
# HuggingFace Tokenizers - Fast Tokenization for NLP

Fast, production-ready tokenizers with Rust performance and Python ease of use.

## When to use HuggingFace Tokenizers

**Use HuggingFace Tokenizers when:**
- Need extremely fast tokenization (<20s per GB of text)
- Training custom tokenizers from scratch
- Want alignment tracking (token → original text position)
- Building production NLP pipelines
- Need to tokenize large corpora efficiently

**Performance**:
- **Speed**: <20 seconds to tokenize 1GB on CPU
- **Implementation**: Rust core with Python/Node.js bindings
- **Efficiency**: 10-100× faster than pure Python implementations

**Use alternatives instead**:
- **SentencePiece**: Language-independent, used by T5/ALBERT
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
- **transformers AutoTokenizer**: Loading pretrained tokenizers only (uses this library internally)
## Quick start

### Installation

```bash
# Install tokenizers
pip install tokenizers

# With transformers integration
pip install tokenizers transformers
```

### Load pretrained tokenizer

```python
from tokenizers import Tokenizer

# Load from HuggingFace Hub
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

# Encode text (the pretrained BERT tokenizer adds [CLS]/[SEP] automatically)
output = tokenizer.encode("Hello, how are you?")
print(output.tokens)  # ['[CLS]', 'hello', ',', 'how', 'are', 'you', '?', '[SEP]']
print(output.ids)     # [101, 7592, 1010, 2129, 2024, 2017, 1029, 102]

# Decode back (special tokens are skipped by default)
text = tokenizer.decode(output.ids)
print(text)  # "hello, how are you?"
```
### Train custom BPE tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Initialize tokenizer with BPE model
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Configure trainer
trainer = BpeTrainer(
    vocab_size=30000,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    min_frequency=2
)

# Train on files
files = ["train.txt", "validation.txt"]
tokenizer.train(files, trainer)

# Save
tokenizer.save("my-tokenizer.json")
```

**Training time**: ~1-2 minutes for a 100MB corpus, ~10-20 minutes for 1GB
### Batch encoding with padding

```python
# Enable padding
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")

# Encode batch
texts = ["Hello world", "This is a longer sentence"]
encodings = tokenizer.encode_batch(texts)

for encoding in encodings:
    print(encoding.ids)
# [101, 7592, 2088, 102, 3, 3, 3]
# [101, 2023, 2003, 1037, 2936, 6251, 102]
```
## Tokenization algorithms

### BPE (Byte-Pair Encoding)

**How it works**:
1. Start with character-level vocabulary
2. Find most frequent character pair
3. Merge into new token, add to vocabulary
4. Repeat until vocabulary size reached

**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
tokenizer.pre_tokenizer = ByteLevel()

trainer = BpeTrainer(
    vocab_size=50257,
    special_tokens=["<|endoftext|>"],
    min_frequency=2
)

tokenizer.train(files=["data.txt"], trainer=trainer)
```

**Advantages**:
- Handles OOV words well (breaks into subwords)
- Flexible vocabulary size
- Good for morphologically rich languages

**Trade-offs**:
- Tokenization depends on merge order
- May split common words unexpectedly
### WordPiece

**How it works**:
1. Start with character vocabulary
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
3. Merge highest scoring pair
4. Repeat until vocabulary size reached

**Used by**: BERT, DistilBERT, MobileBERT

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import BertNormalizer

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = Whitespace()

trainer = WordPieceTrainer(
    vocab_size=30522,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    continuing_subword_prefix="##"
)

tokenizer.train(files=["corpus.txt"], trainer=trainer)
```

**Advantages**:
- Prioritizes meaningful merges (high score = parts rarely occur apart)
- Used successfully in BERT (state-of-the-art results)

**Trade-offs**:
- Unknown words become `[UNK]` if no subword match
- Saves vocabulary, not merge rules (larger files)
### Unigram

**How it works**:
1. Start with large vocabulary (all substrings)
2. Compute loss for corpus with current vocabulary
3. Remove tokens with minimal impact on loss
4. Repeat until vocabulary size reached

**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

tokenizer = Tokenizer(Unigram())

trainer = UnigramTrainer(
    vocab_size=8000,
    special_tokens=["<unk>", "<s>", "</s>"],
    unk_token="<unk>"
)

tokenizer.train(files=["data.txt"], trainer=trainer)
```

**Advantages**:
- Probabilistic (finds most likely tokenization)
- Works well for languages without word boundaries
- Handles diverse linguistic contexts

**Trade-offs**:
- Computationally expensive to train
- More hyperparameters to tune
## Tokenization pipeline

Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**

### Normalization

Clean and standardize text:

```python
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence

tokenizer.normalizer = Sequence([
    NFD(),          # Unicode normalization (decompose)
    Lowercase(),    # Convert to lowercase
    StripAccents()  # Remove accents
])

# Input: "Héllo WORLD"
# After normalization: "hello world"
```

**Common normalizers**:
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
- `Lowercase()` - Convert to lowercase
- `StripAccents()` - Remove accents (é → e)
- `Strip()` - Remove whitespace
- `Replace(pattern, content)` - Regex replacement
### Pre-tokenization

Split text into word-like units:

```python
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel

# Split on whitespace and punctuation
tokenizer.pre_tokenizer = Sequence([
    Whitespace(),
    Punctuation()
])

# Input: "Hello, world!"
# After pre-tokenization: ["Hello", ",", "world", "!"]
```

**Common pre-tokenizers**:
- `Whitespace()` - Split on spaces, tabs, newlines
- `ByteLevel()` - GPT-2 style byte-level splitting
- `Punctuation()` - Isolate punctuation
- `Digits(individual_digits=True)` - Split digits individually
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
### Post-processing

Add special tokens for model input:

```python
from tokenizers.processors import TemplateProcessing

# BERT-style: [CLS] sentence [SEP]
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B [SEP]",
    special_tokens=[
        ("[CLS]", 1),
        ("[SEP]", 2),
    ],
)
```

**Common patterns**:
```python
# GPT-2: sentence <|endoftext|>
TemplateProcessing(
    single="$A <|endoftext|>",
    special_tokens=[("<|endoftext|>", 50256)]
)

# RoBERTa: <s> sentence </s>
TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> </s> $B </s>",
    special_tokens=[("<s>", 0), ("</s>", 2)]
)
```
## Alignment tracking

Track token positions in original text:

```python
text = "Hello, world!"
output = tokenizer.encode(text)

# Get token offsets
for token, offset in zip(output.tokens, output.offsets):
    start, end = offset
    print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")

# Output:
# hello      → [ 0,  5): 'Hello'
# ,          → [ 5,  6): ','
# world      → [ 7, 12): 'world'
# !          → [12, 13): '!'
```

**Use cases**:
- Named entity recognition (map predictions back to text; see the sketch below)
- Question answering (extract answer spans)
- Token classification (align labels to original positions)
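
A minimal sketch of the NER use case. The predicted token span is a stand-in for real model output, not computed here:

```python
# Hypothetical example: map a predicted entity span (token indices,
# end-exclusive) back to characters in the source text via offsets.
text = "Hugging Face is based in NYC"
output = tokenizer.encode(text)

pred_span = (1, 3)  # assumed model prediction over token positions
start_char = output.offsets[pred_span[0]][0]
end_char = output.offsets[pred_span[1] - 1][1]
print(text[start_char:end_char])  # the entity as it appears in the text
```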
## Integration with transformers

### Load with AutoTokenizer

```python
from transformers import AutoTokenizer

# AutoTokenizer automatically uses fast tokenizers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Check if using fast tokenizer
print(tokenizer.is_fast)  # True

# Access underlying tokenizers.Tokenizer
fast_tokenizer = tokenizer.backend_tokenizer
print(type(fast_tokenizer))  # <class 'tokenizers.Tokenizer'>
```

### Convert custom tokenizer to transformers

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import PreTrainedTokenizerFast

# Train custom tokenizer
tokenizer = Tokenizer(BPE())
# ... train tokenizer ...
tokenizer.save("my-tokenizer.json")

# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="my-tokenizer.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]"
)

# Use like any transformers tokenizer
outputs = transformers_tokenizer(
    "Hello world",
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
)
```
## Common patterns

### Train from iterator (large datasets)

```python
from datasets import load_dataset

# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

# Create batch iterator
def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i:i + batch_size]["text"]

# Train tokenizer
tokenizer.train_from_iterator(
    batch_iterator(),
    trainer=trainer,
    length=len(dataset)  # For progress bar
)
```

**Performance**: Processes 1GB in ~10-20 minutes
### Enable truncation and padding

```python
# Enable truncation
tokenizer.enable_truncation(max_length=512)

# Enable padding
tokenizer.enable_padding(
    pad_id=tokenizer.token_to_id("[PAD]"),
    pad_token="[PAD]",
    length=512  # Fixed length, or None for batch max
)

# Encode with both
output = tokenizer.encode("This is a long sentence that will be truncated...")
print(len(output.ids))  # 512
```
### Multi-processing

```python
from tokenizers import Tokenizer
from multiprocessing import Pool

# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer.json")

# corpus: your list of raw text strings to encode
corpus = [...]

def encode_batch(texts):
    return tokenizer.encode_batch(texts)

# Process large corpus in parallel
with Pool(8) as pool:
    # Split corpus into chunks
    chunk_size = 1000
    chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]

    # Encode in parallel
    results = pool.map(encode_batch, chunks)
```

**Speedup**: 5-8× with 8 cores
## Performance benchmarks

### Training speed

| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|-------------|-----------------|-----------------|--------------|
| 10 MB       | 15 sec          | 18 sec          | 25 sec       |
| 100 MB      | 1.5 min         | 2 min           | 4 min        |
| 1 GB        | 15 min          | 20 min          | 40 min       |

**Hardware**: 16-core CPU, tested on English Wikipedia

### Tokenization speed

| Implementation | 1 GB corpus | Throughput |
|----------------|-------------|------------|
| Pure Python    | ~20 minutes | ~50 MB/min |
| HF Tokenizers  | ~15 seconds | ~4 GB/min  |
| **Speedup**    | **80×**     | **80×**    |

**Test**: English text, average sentence length 20 words

### Memory usage

| Task                  | Memory  |
|-----------------------|---------|
| Load tokenizer        | ~10 MB  |
| Train BPE (30k vocab) | ~200 MB |
| Encode 1M sentences   | ~500 MB |
## Supported models

Pre-trained tokenizers available via `from_pretrained()`:

**BERT family**:
- `bert-base-uncased`, `bert-large-cased`
- `distilbert-base-uncased`
- `roberta-base`, `roberta-large`

**GPT family**:
- `gpt2`, `gpt2-medium`, `gpt2-large`
- `distilgpt2`

**T5 family**:
- `t5-small`, `t5-base`, `t5-large`
- `google/flan-t5-xxl`

**Other**:
- `facebook/bart-base`, `facebook/mbart-large-cc25`
- `albert-base-v2`, `albert-xlarge-v2`
- `xlm-roberta-base`, `xlm-roberta-large`

Browse all: https://huggingface.co/models?library=tokenizers

## References

- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens

## Resources

- **Docs**: https://huggingface.co/docs/tokenizers
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
- **Version**: 0.20.0+
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
- **Papers**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)

@@ -0,0 +1,653 @@
# Tokenization Algorithms Deep Dive

Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.

## Byte-Pair Encoding (BPE)

### Algorithm overview

BPE iteratively merges the most frequent pair of tokens in a corpus.

**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all adjacent token pairs
3. Merge most frequent pair into new token
4. Add new token to vocabulary
5. Update corpus with new token
6. Repeat until vocabulary size reached (see the sketch below)
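
As a minimal illustration of this loop, here is a toy pure-Python trainer. It is a sketch, not the library's implementation; the `bpe_train` helper and its corpus format (space-separated symbols mapped to counts) are our own for illustration:

```python
from collections import Counter

def bpe_train(corpus, num_merges):
    """Toy BPE trainer: `corpus` maps space-separated symbol strings to counts."""
    corpus = dict(corpus)
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency
        pairs = Counter()
        for word, freq in corpus.items():
            symbols = word.split()
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair
        merges.append(best)
        # Replace the pair with its merged symbol everywhere
        corpus = {
            word.replace(" ".join(best), "".join(best)): freq
            for word, freq in corpus.items()
        }
    return merges

# The corpus from the worked example below, pre-split into characters
corpus = {"l o w": 5, "l o w e r": 2, "n e w e s t": 6, "w i d e s t": 3}
print(bpe_train(corpus, 2))  # [('e', 's'), ('es', 't')]
```

Note the simplification: `str.replace` assumes the joined pair never matches across symbol boundaries, which holds for this character-level corpus.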
### Step-by-step example

**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```

**Iteration 1**:
```
Count pairs:
  'e' + 's': 9  (newest: 6, widest: 3)  ← most frequent
  'l' + 'o': 7
  'o' + 'w': 7
  ...

Merge: 'e' + 's' → 'es'

Updated corpus:
  low: 5
  lower: 2
  newest: 6 → newes|t: 6
  widest: 3 → wides|t: 3

Vocabulary: [a-z] + ['es']
```

**Iteration 2**:
```
Count pairs:
  'es' + 't': 9  ← most frequent
  'l' + 'o': 7
  ...

Merge: 'es' + 't' → 'est'

Updated corpus:
  low: 5
  lower: 2
  newest: 6 → new|est: 6
  widest: 3 → wid|est: 3

Vocabulary: [a-z] + ['es', 'est']
```

**Continue until desired vocabulary size...**
### Tokenization with trained BPE

Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`

Tokenize "lowest":
```
Step 1: Split into characters
  ['l', 'o', 'w', 'e', 's', 't']

Step 2: Apply merges in the order learned during training
  - Merge 'l' + 'o' → 'lo' (if this merge was learned)
  - Merge 'lo' + 'w' → 'low' (if learned)
  - Merge 'e' + 's' → 'es' (learned)
  - Merge 'es' + 't' → 'est' (learned)

Final: ['low', 'est']
```
### Implementation

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Initialize
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Configure trainer
trainer = BpeTrainer(
    vocab_size=1000,
    min_frequency=2,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)

# Train
corpus = [
    "This is a sample corpus for BPE training.",
    "BPE learns subword units from the training data.",
    # ... more sentences
]

tokenizer.train_from_iterator(corpus, trainer=trainer)

# Use
output = tokenizer.encode("This is tokenization")
print(output.tokens)  # ['This', 'is', 'token', 'ization']
```
### Byte-level BPE (GPT-2 variant)

**Problem**: Standard BPE needs a base vocabulary containing every character in the corpus; with full Unicode that can mean tens of thousands of base symbols.

**Solution**: Operate on the byte level, so the base vocabulary is exactly 256 symbols.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

tokenizer = Tokenizer(BPE())

# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()

# This handles ALL possible characters, including emojis
text = "Hello 🌍 世界"
tokens = tokenizer.encode(text).tokens
```

**Advantages**:
- Handles any Unicode character (every string reduces to the 256 byte values)
- No unknown tokens (worst case: individual bytes)
- Used by GPT-2, GPT-3, BART

**Trade-offs**:
- Slightly worse compression (bytes vs characters)
- More tokens for non-ASCII text
### BPE variants

**SentencePiece BPE**:
- Language-independent (no pre-tokenization)
- Treats input as a raw character stream
- Used by T5, ALBERT, XLNet

**BPE-dropout**:
- Randomly skips merges while segmenting
- Produces varied segmentations, making tokenization more robust at inference
- Reduces overfitting to a single segmentation of the training data
## WordPiece

### Algorithm overview

WordPiece is similar to BPE but uses a different merge selection criterion.

**Training process**:
1. Initialize vocabulary with all characters
2. Count frequency of all token pairs
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
4. Merge pair with highest score
5. Repeat until vocabulary size reached

### Why different scoring?

**BPE**: Merges most frequent pairs
- "aa" appears 100 times → high priority
- Even if 'a' appears 1000 times alone

**WordPiece**: Merges pairs whose parts rarely occur apart
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
- Prioritizes pairs that appear together more often than their parts' frequencies predict (see the snippet below)
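
The scoring rule is one line of arithmetic; a quick sketch with the numbers above (the `wordpiece_score` helper name is ours):

```python
def wordpiece_score(pair_freq, first_freq, second_freq):
    # Pairs that almost always occur together score high,
    # even when they are rare in absolute terms
    return pair_freq / (first_freq * second_freq)

print(wordpiece_score(100, 1000, 1000))  # 0.0001   ('a' + 'a')
print(wordpiece_score(50, 60, 55))       # ~0.01515 ('t' + 'h')
```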
### Step-by-step example

**Corpus**:
```
low: 5
lower: 2
newest: 6
widest: 3
```

**Iteration 1**:
```
Count frequencies:
  'e': 11  (lower: 2, newest: 6, widest: 3)
  's': 9
  't': 9
  'l': 7
  'o': 7
  'i': 3
  'd': 3
  ...

Count pairs:
  'e' + 's': 9  (newest: 6, widest: 3)
  'l' + 'o': 7  (low: 5, lower: 2)
  'i' + 'd': 3  (widest: 3)
  ...

Compute scores:
  score('e' + 's') = 9 / (11 × 9) ≈ 0.091
  score('l' + 'o') = 7 / (7 × 7)  ≈ 0.143
  score('i' + 'd') = 3 / (3 × 3)  ≈ 0.333  ← highest score

Choose: 'i' + 'd' → 'id'
```

**Key difference**: 'e' + 's' is the most frequent pair (BPE would merge it first), but WordPiece chooses 'i' + 'd': the pair is rare overall, yet 'i' and 'd' never occur apart in this corpus.
### Tokenization with WordPiece

Given vocabulary: `['##e', '##s', '##t', '##est', 'l', 'o', 'w', 'new', 'low']`

Tokenize "lowest":
```
Step 1: Find longest matching prefix
  'lowest' → 'low' (matches)

Step 2: Find longest match for the remainder (with the ## continuation prefix)
  'est' → '##est' (matches)

Final: ['low', '##est']
```

**If no match**:
```
Tokenize "unknownword":
  'unknownword' → no match
  'unknown'     → no match
  'unkn'        → no match
  'un'          → no match
  'u'           → no match
  → [UNK]  (the whole word becomes [UNK])
```
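
A minimal sketch of this greedy longest-match-first procedure (the `wordpiece_tokenize` helper is ours, not the library API):

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]", prefix="##"):
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        cur = None
        # Find the longest piece starting at `start` that is in the vocab
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = prefix + piece  # continuation pieces carry ##
            if piece in vocab:
                cur = piece
                break
            end -= 1
        if cur is None:
            return [unk]  # whole word becomes [UNK] if any position fails
        tokens.append(cur)
        start = end
    return tokens

vocab = {"low", "##est", "##e", "##s", "##t", "l", "o", "w", "new"}
print(wordpiece_tokenize("lowest", vocab))  # ['low', '##est']
```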
### Implementation

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer

# Initialize BERT-style tokenizer
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))

# Normalization (lowercase, accent stripping)
tokenizer.normalizer = BertNormalizer(lowercase=True)

# Pre-tokenization (whitespace + punctuation)
tokenizer.pre_tokenizer = BertPreTokenizer()

# Configure trainer
trainer = WordPieceTrainer(
    vocab_size=30522,  # BERT vocab size
    min_frequency=2,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    continuing_subword_prefix="##"  # BERT uses ##
)

# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)

# Use
output = tokenizer.encode("Tokenization works great!")
print(output.tokens)  # ['token', '##ization', 'works', 'great', '!']
```

### Subword prefix

**BERT uses the `##` prefix**:
```
"unbelievable" → ['un', '##believ', '##able']
```

**Why?**
- Indicates the token is a continuation
- Allows reconstruction: remove ##, concatenate (sketch below)
- Helps the model distinguish word boundaries
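
Reconstruction is mechanical; a minimal sketch (the helper name is ours):

```python
def wordpiece_detokenize(tokens):
    # Strip the ## continuation prefix and rejoin pieces into words
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return " ".join(words)

print(wordpiece_detokenize(["un", "##believ", "##able"]))  # unbelievable
```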
### WordPiece advantages

**Semantic merges**:
- Prioritizes meaningful combinations
- "qu" has high score (always together)
- "qx" has low score (rare combination)

**Better for morphology**:
- Captures affixes: un-, -ing, -ed
- Preserves word stems

**Trade-offs**:
- Slower training than BPE
- More memory (stores vocabulary, not merges)
- Original implementation not open-source (HF reimplementation)
## Unigram

### Algorithm overview

Unigram works backward: start with a large vocabulary, remove tokens.

**Training process**:
1. Initialize with large vocabulary (all substrings)
2. Estimate probability of each token (frequency-based)
3. For each token, compute loss increase if removed
4. Remove 10-20% of tokens with lowest loss impact
5. Re-estimate probabilities
6. Repeat until desired vocabulary size

### Probabilistic tokenization

**Unigram assumption**: Each token is independent.

Given vocabulary with probabilities:
```
P('low') = 0.02
P('l')   = 0.01
P('o')   = 0.015
P('w')   = 0.01
P('est') = 0.03
P('e')   = 0.02
P('s')   = 0.015
P('t')   = 0.015
```

Tokenize "lowest":
```
Option 1: ['low', 'est']
  P = P('low') × P('est') = 0.02 × 0.03 = 0.0006

Option 2: ['l', 'o', 'w', 'est']
  P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045

Option 3: ['low', 'e', 's', 't']
  P = 0.02 × 0.02 × 0.015 × 0.015 = 0.0000009

Choose option 1 (highest probability)
```
### Viterbi algorithm

Finding the best tokenization by brute force is expensive (exponentially many possibilities).

**Viterbi algorithm** (dynamic programming):
```python
from math import log

def tokenize_viterbi(word, vocab, probs):
    n = len(word)
    # dp[i] = (best_log_prob, best_tokens) for word[:i]
    dp = [(float('-inf'), []) for _ in range(n + 1)]
    dp[0] = (0.0, [])  # empty prefix has log probability 0

    for i in range(1, n + 1):
        best_prob = float('-inf')
        best_tokens = []

        # Try all possible last tokens ending at position i
        for j in range(i):
            token = word[j:i]
            if token in vocab and dp[j][0] > float('-inf'):
                prob = dp[j][0] + log(probs[token])
                if prob > best_prob:
                    best_prob = prob
                    best_tokens = dp[j][1] + [token]

        dp[i] = (best_prob, best_tokens)

    return dp[n][1]
```

**Time complexity**: O(n²) substring lookups vs O(2^n) brute force
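
Running the sketch on the probabilities from the previous section recovers the same answer:

```python
vocab = {"low", "l", "o", "w", "est", "e", "s", "t"}
probs = {"low": 0.02, "l": 0.01, "o": 0.015, "w": 0.01,
         "est": 0.03, "e": 0.02, "s": 0.015, "t": 0.015}

print(tokenize_viterbi("lowest", vocab, probs))  # ['low', 'est']
```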
### Implementation

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

# Initialize
tokenizer = Tokenizer(Unigram())

# Configure trainer
trainer = UnigramTrainer(
    vocab_size=8000,
    special_tokens=["<unk>", "<s>", "</s>"],
    unk_token="<unk>",
    max_piece_length=16,   # Max token length
    n_sub_iterations=2,    # EM iterations
    shrinking_factor=0.75  # Remove 25% each iteration
)

# Train
tokenizer.train_from_iterator(corpus, trainer=trainer)

# Use
output = tokenizer.encode("Tokenization with Unigram")
print(output.tokens)  # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
```
### Unigram advantages

**Probabilistic**:
- Multiple valid tokenizations
- Can sample different tokenizations (data augmentation)

**Subword regularization**:
```python
# Sampling alternative segmentations is what SentencePiece calls
# "subword regularization" (enable_sampling=True there). Note that plain
# Tokenizer.encode is deterministic and always returns the best path.
# Illustrative alternative segmentations of the same word:
#   ['token', 'ization']
#   ['tok', 'en', 'ization']
#   ['token', 'iz', 'ation']
```

**Language-independent**:
- No word boundaries needed
- Works for CJK languages (Chinese, Japanese, Korean)
- Treats input as character stream

**Trade-offs**:
- Slower training (EM algorithm)
- More hyperparameters
- Larger model (stores probabilities)
## Algorithm comparison

### Training speed

| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|-----------|--------------|----------------|-------------|
| BPE       | 10-15 sec    | 1-2 min        | 10-20 min   |
| WordPiece | 15-20 sec    | 2-3 min        | 15-30 min   |
| Unigram   | 20-30 sec    | 3-5 min        | 30-60 min   |

**Tested on**: 16-core CPU, 30k vocab

### Tokenization quality

Measured on English Wikipedia:

| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|-----------|------------|-------------|--------------|
| BPE       | 30k        | 1.3         | 0.5%         |
| WordPiece | 30k        | 1.2         | 1.2%         |
| Unigram   | 8k         | 1.5         | 0.3%         |

**Key observations**:
- WordPiece: Slightly better compression (fewer tokens per word)
- BPE: Lower unknown rate than WordPiece
- Unigram: Smallest vocab, good coverage
### Compression ratio

Characters per token (higher = better compression; a measurement sketch follows below):

| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|----------|-----------|-----------------|--------------|
| English  | 4.2       | 4.5             | 3.8          |
| Chinese  | 2.1       | 2.3             | 2.5          |
| Arabic   | 3.5       | 3.8             | 3.2          |

**Best for each**:
- English: WordPiece
- Chinese: Unigram (language-independent)
- Arabic: WordPiece
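
A quick way to measure this ratio for your own tokenizer and corpus (a sketch assuming a `tokenizers.Tokenizer`; `sample_texts` is a placeholder for your data):

```python
def chars_per_token(tokenizer, texts):
    # Higher values mean better compression
    total_chars = sum(len(t) for t in texts)
    total_tokens = sum(len(tokenizer.encode(t).ids) for t in texts)
    return total_chars / total_tokens

# ratio = chars_per_token(tokenizer, sample_texts)
```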
### Use case recommendations

**BPE** - Best for:
- English language models
- Code (handles symbols well)
- Fast training needed
- **Models**: GPT-2, GPT-3, RoBERTa, BART

**WordPiece** - Best for:
- Masked language modeling (BERT-style)
- Morphologically rich languages
- Semantic understanding tasks
- **Models**: BERT, DistilBERT, ELECTRA

**Unigram** - Best for:
- Multilingual models
- Languages without word boundaries (CJK)
- Data augmentation via subword regularization
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
## Advanced topics

### Handling rare words

**BPE approach**:
```
"antidisestablishmentarianism"
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```

**WordPiece approach**:
```
"antidisestablishmentarianism"
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
```

**Unigram approach**:
```
"antidisestablishmentarianism"
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
```
### Handling numbers

**Challenge**: Infinite number combinations

**BPE solution**: Byte-level (handles any digit sequence)
```python
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()

# Handles any number:
# "123456789" → byte-level tokens
```

**WordPiece solution**: Digit pre-tokenization
```python
from tokenizers.pre_tokenizers import Digits

# Split digits individually or as groups
tokenizer.pre_tokenizer = Digits(individual_digits=True)

# "123" → ['1', '2', '3']
```

**Unigram solution**: Learns common number patterns
```python
# Patterns learned during training, e.g.:
# "2023" → ['202', '3'] or ['20', '23']
```
### Handling case sensitivity

**Lowercase (BERT)**:
```python
from tokenizers.normalizers import Lowercase

tokenizer.normalizer = Lowercase()

# "Hello WORLD" → "hello world" → ['hello', 'world']
```

**Preserve case (GPT-2)**:
```python
# No case normalization is applied

# "Hello WORLD" → ['Hello', 'WORLD']
```

**Cased tokens (RoBERTa)**:
```python
# Learns separate tokens for different cases, e.g.:
# Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
```
### Handling emojis and special characters

**Byte-level (GPT-2)**:
```python
tokenizer.pre_tokenizer = ByteLevel()

# "Hello 🌍 👋" → byte-level representation (always works)
```

**Unicode normalization**:
```python
from tokenizers.normalizers import NFKC

tokenizer.normalizer = NFKC()

# "é" (composed) and "e" + combining accent (decomposed) → normalized to one form
```
## Troubleshooting

### Issue: Poor subword splitting

**Symptom**:
```
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
```

**Solutions**:
1. Increase vocabulary size
2. Train longer (more merge iterations)
3. Lower the `min_frequency` threshold

### Issue: Too many unknown tokens

**Symptom**:
```
5% of tokens are [UNK]
```

**Solutions** (a quick way to measure the rate follows below):
1. Increase vocabulary size
2. Use byte-level BPE (no UNK possible)
3. Verify training corpus is representative
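
A minimal sketch for measuring the [UNK] rate on held-out text (`sample_texts` is a placeholder for your evaluation data):

```python
unk_id = tokenizer.token_to_id("[UNK]")
ids = [i for text in sample_texts for i in tokenizer.encode(text).ids]
print(f"[UNK] rate: {ids.count(unk_id) / len(ids):.2%}")
```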
### Issue: Inconsistent tokenization

**Symptom**:
```
"running" → ['run', 'ning']
"runner"  → ['r', 'u', 'n', 'n', 'e', 'r']
```

**Solutions**:
1. Check normalization consistency
2. Ensure pre-tokenization is deterministic
3. If variation is intended (subword regularization), use Unigram, where sampling is explicit
## Best practices

1. **Match algorithm to model architecture**:
   - BERT-style → WordPiece
   - GPT-style → BPE
   - T5-style → Unigram

2. **Use byte-level for multilingual**:
   - Handles any Unicode
   - No unknown tokens

3. **Test on representative data**:
   - Measure compression ratio
   - Check unknown token rate
   - Inspect sample tokenizations

4. **Version control tokenizers**:
   - Save with model
   - Document special tokens
   - Track vocabulary changes

@@ -0,0 +1,637 @@
# Transformers Integration

Complete guide to using HuggingFace Tokenizers with the Transformers library.

## AutoTokenizer

The easiest way to load tokenizers.

### Loading pretrained tokenizers

```python
from transformers import AutoTokenizer

# Load from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Check if using fast tokenizer (Rust-based)
print(tokenizer.is_fast)  # True

# Access underlying tokenizers.Tokenizer
if tokenizer.is_fast:
    fast_tokenizer = tokenizer.backend_tokenizer
    print(type(fast_tokenizer))  # <class 'tokenizers.Tokenizer'>
```
### Fast vs slow tokenizers

| Feature            | Fast (Rust)     | Slow (Python) |
|--------------------|-----------------|---------------|
| Speed              | 5-10× faster    | Baseline      |
| Alignment tracking | ✅ Full support | ❌ Limited    |
| Batch processing   | ✅ Optimized    | ⚠️ Slower     |
| Offset mapping     | ✅ Yes          | ❌ No         |
| Installation       | `tokenizers`    | Built-in      |

**Always use fast tokenizers when available.**

### Check available tokenizers

```python
from transformers import TOKENIZER_MAPPING

# List all fast tokenizers
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
    if fast is not None:
        print(f"{config_class.__name__}: {fast.__name__}")
```
## PreTrainedTokenizerFast

Wrap custom tokenizers for transformers.

### Convert custom tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast

# Train custom tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(
    vocab_size=30000,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)

# Save tokenizer
tokenizer.save("my-tokenizer.json")

# Wrap for transformers
transformers_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="my-tokenizer.json",
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]"
)

# Save in transformers format
transformers_tokenizer.save_pretrained("my-tokenizer")
```

**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`

### Use like any transformers tokenizer

```python
# Load
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")

# Encode with all transformers features
outputs = tokenizer(
    "Hello world",
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors="pt"
)

print(outputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
```
## Special tokens

### Default special tokens

| Model Family | CLS/BOS | SEP/EOS         | PAD             | UNK             | MASK   |
|--------------|---------|-----------------|-----------------|-----------------|--------|
| BERT         | [CLS]   | [SEP]           | [PAD]           | [UNK]           | [MASK] |
| GPT-2        | -       | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | -      |
| RoBERTa      | <s>     | </s>            | <pad>           | <unk>           | <mask> |
| T5           | -       | </s>            | <pad>           | <unk>           | -      |
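
You can inspect what a loaded tokenizer has configured; for example, with `bert-base-uncased`:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tok.special_tokens_map)
# {'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]',
#  'cls_token': '[CLS]', 'mask_token': '[MASK]'}
print(tok.cls_token_id, tok.sep_token_id)  # 101 102
```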
### Adding special tokens

```python
# Add new special tokens
special_tokens_dict = {
    "additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
}

num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_tokens} tokens")

# Resize model embeddings
model.resize_token_embeddings(len(tokenizer))

# Use new tokens
text = "This is an image: <|image|>"
tokens = tokenizer.encode(text)
```

### Adding regular tokens

```python
# Add domain-specific tokens
new_tokens = ["COVID-19", "mRNA", "vaccine"]
num_added = tokenizer.add_tokens(new_tokens)

# These are NOT special tokens (can be split if needed)
tokenizer.add_tokens(new_tokens, special_tokens=False)

# These ARE special tokens (never split)
tokenizer.add_tokens(new_tokens, special_tokens=True)
```
## Encoding and decoding

### Basic encoding

```python
# Single sentence
text = "Hello, how are you?"
encoded = tokenizer(text)

print(encoded)
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
#  'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
#  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```

### Batch encoding

```python
# Multiple sentences (padding=True pads to the longest in the batch)
texts = ["Hello world", "How are you?", "I am fine"]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)

print(encoded['input_ids'])
# [[101, 7592, 2088, 102, 0, 0],
#  [101, 2129, 2024, 2017, 1029, 102],
#  [101, 1045, 2572, 2986, 102, 0]]
```
### Return tensors

```python
# Return PyTorch tensors
outputs = tokenizer("Hello world", return_tensors="pt")
print(outputs['input_ids'].shape)  # torch.Size([1, 4])

# Return TensorFlow tensors
outputs = tokenizer("Hello world", return_tensors="tf")

# Return NumPy arrays
outputs = tokenizer("Hello world", return_tensors="np")

# Return lists (default)
outputs = tokenizer("Hello world", return_tensors=None)
```

### Decoding

```python
# Decode token IDs
ids = [101, 7592, 2088, 102]
text = tokenizer.decode(ids)
print(text)  # "[CLS] hello world [SEP]"

# Skip special tokens
text = tokenizer.decode(ids, skip_special_tokens=True)
print(text)  # "hello world"

# Batch decode
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
print(texts)  # ["hello", "world"]
```
## Padding and truncation

### Padding strategies

```python
# Pad to the longest sequence in the batch
tokenizer(texts, padding="longest")

# Pad to a fixed max length
tokenizer(texts, padding="max_length", max_length=128)

# No padding
tokenizer(texts, padding=False)

# Pad to a multiple of a value (for efficient computation)
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
# Result: length will be 128 (already a multiple of 8)
```

### Truncation strategies

```python
# Truncate to max length
tokenizer(text, truncation=True, max_length=10)

# Only truncate first sequence (for pairs)
tokenizer(text1, text2, truncation="only_first", max_length=20)

# Only truncate second sequence
tokenizer(text1, text2, truncation="only_second", max_length=20)

# Truncate longest first (default for pairs)
tokenizer(text1, text2, truncation="longest_first", max_length=20)

# No truncation (sequences longer than the model limit will fail downstream)
tokenizer(text, truncation=False)
```
### Stride for long documents

```python
# For documents longer than max_length
text = "Very long document " * 1000

# Encode with overlap
encodings = tokenizer(
    text,
    max_length=512,
    stride=128,  # Overlap between chunks
    truncation=True,
    return_overflowing_tokens=True,
    return_offsets_mapping=True
)

# Get all chunks
num_chunks = len(encodings['input_ids'])
print(f"Split into {num_chunks} chunks")

# Each chunk overlaps the previous one by `stride` tokens
for i, chunk in enumerate(encodings['input_ids']):
    print(f"Chunk {i}: {len(chunk)} tokens")
```

**Use case**: Long document QA, sliding window inference
## Alignment and offsets

### Offset mapping

```python
# Get character offsets for each token
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)

for token, (start, end) in zip(
    encoded.tokens(),
    encoded['offset_mapping']
):
    print(f"{token:10s} → [{start:2d}, {end:2d})")

# Output:
# [CLS]      → [ 0,  0)
# hello      → [ 0,  5)
# ,          → [ 5,  6)
# world      → [ 7, 12)
# !          → [12, 13)
# [SEP]      → [ 0,  0)
```

### Word IDs

```python
# Get word index for each token
encoded = tokenizer("Hello world")
word_ids = encoded.word_ids()

print(word_ids)
# [None, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```

**Use case**: Token classification (NER, POS tagging); see the sketch below
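
A minimal sketch of label alignment for token classification. The word-level `labels` and the `label_to_id` mapping are hypothetical; `-100` is the conventional ignore index for PyTorch loss functions:

```python
labels = ["B-ORG", "O"]             # one label per word in "Hello world"
label_to_id = {"B-ORG": 0, "O": 1}  # assumed label vocabulary

encoded = tokenizer("Hello world")
aligned = [
    -100 if word_id is None else label_to_id[labels[word_id]]
    for word_id in encoded.word_ids()
]
print(aligned)  # [-100, 0, 1, -100]
```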
### Character to token mapping

```python
text = "Machine learning is awesome"
encoded = tokenizer(text, return_offsets_mapping=True)

# Find token for character position
char_pos = 8  # "l" in "learning"
token_idx = encoded.char_to_token(char_pos)

print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
# Character 8 is in token 2: learning
```

**Use case**: Question answering (map answer character span to tokens)

### Sequence pairs

```python
# Encode sentence pair
encoded = tokenizer("Question here", "Answer here")

# Get sequence IDs (which sequence each token belongs to)
sequence_ids = encoded.sequence_ids()
print(sequence_ids)
# [None, 0, 0, None, 1, 1, None]
# None = special token, 0 = question, 1 = answer
```
## Model integration

### Use with transformers models

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)

# Get embeddings
last_hidden_state = outputs.last_hidden_state
print(last_hidden_state.shape)  # [1, seq_len, hidden_size]
```

### Custom model with custom tokenizer

```python
from transformers import BertConfig, BertModel

# Train custom tokenizer
from tokenizers import Tokenizer, models, trainers
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=30000)
tokenizer.train(files=["data.txt"], trainer=trainer)

# Wrap for transformers
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]"
)

# Create model with custom vocab size
config = BertConfig(vocab_size=30000)
model = BertModel(config)

# Use together
inputs = fast_tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
```
### Save and load together

```python
# Save both
model.save_pretrained("my-model")
tokenizer.save_pretrained("my-model")

# Directory structure:
# my-model/
# ├── config.json
# ├── pytorch_model.bin
# ├── tokenizer.json
# ├── tokenizer_config.json
# └── special_tokens_map.json

# Load both
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")
```
## Advanced features

### Multimodal tokenization

```python
from transformers import AutoTokenizer

# LLaVA-style (image + text)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Add image placeholder token
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})

# Use in prompt
text = "Describe this image: <image>"
inputs = tokenizer(text, return_tensors="pt")
```
### Template formatting

```python
# Chat template
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "What's the weather?"}
]

# Apply chat template (if tokenizer has one)
if hasattr(tokenizer, "apply_chat_template"):
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt")
```
### Custom template

```python
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

# Define chat template (Jinja2 syntax)
tokenizer.chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
System: {{ message['content'] }}\n
{%- elif message['role'] == 'user' %}
User: {{ message['content'] }}\n
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}\n
{%- endif %}
{%- endfor %}
Assistant:
"""

# Use template (messages as defined above)
text = tokenizer.apply_chat_template(messages, tokenize=False)
```
## Performance optimization

### Batch processing

```python
# Process large datasets efficiently
from datasets import load_dataset

dataset = load_dataset("imdb", split="train[:1000]")

# Tokenize in batches
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )

# Map over dataset (batched)
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    num_proc=4  # Parallel processing
)
```
### Caching

```python
# Cache downloaded tokenizer files locally
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    use_fast=True,
    cache_dir="./cache"  # Where tokenizer files are stored
)

# Memoize repeated tokenization
from functools import lru_cache

@lru_cache(maxsize=10000)
def cached_tokenize(text):
    return tuple(tokenizer.encode(text))

# Reuses cached results for repeated inputs
```
### Memory efficiency

```python
# For very large datasets, use streaming
from datasets import load_dataset

dataset = load_dataset("pile", split="train", streaming=True)

def process_batch(batch):
    # Tokenize
    tokens = tokenizer(batch["text"], truncation=True, max_length=512)

    # Process tokens...

    return tokens

# Process in chunks (memory efficient)
for batch in dataset.batch(batch_size=1000):
    processed = process_batch(batch)
```
## Troubleshooting

### Issue: Tokenizer not fast

**Symptom**:
```python
tokenizer.is_fast  # False
```

**Solution**: Install the tokenizers library
```bash
pip install tokenizers
```

### Issue: Special tokens not working

**Symptom**: Special tokens are split into subwords

**Solution**: Add as special tokens, not regular tokens
```python
# Wrong
tokenizer.add_tokens(["<|image|>"])

# Correct
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
```

### Issue: Offset mapping not available

**Symptom**:
```python
tokenizer("text", return_offsets_mapping=True)
# Error: return_offsets_mapping not supported
```

**Solution**: Use a fast tokenizer
```python
from transformers import AutoTokenizer

# Load fast version
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
```

### Issue: Padding inconsistent

**Symptom**: Some sequences padded, others not

**Solution**: Specify a padding strategy
```python
# Explicit padding
tokenizer(
    texts,
    padding="max_length",  # or "longest"
    max_length=128
)
```
## Best practices

1. **Always use fast tokenizers**:
   - 5-10× faster
   - Full alignment tracking
   - Better batch processing

2. **Save tokenizer with model**:
   - Ensures reproducibility
   - Prevents version mismatches

3. **Use batch processing for datasets**:
   - Tokenize with `.map(batched=True)`
   - Set `num_proc` for parallelism

4. **Enable caching for repeated inputs**:
   - Use `lru_cache` for inference
   - Cache tokenizer files with `cache_dir`

5. **Handle special tokens properly**:
   - Use `add_special_tokens()` for never-split tokens
   - Resize embeddings after adding tokens

6. **Test alignment for downstream tasks**:
   - Verify `offset_mapping` is correct
   - Test `char_to_token()` on samples

7. **Version control tokenizer config**:
   - Save `tokenizer_config.json`
   - Document custom templates
   - Track vocabulary changes

@@ -0,0 +1,723 @@
# Tokenization Pipeline Components

Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.

## Pipeline overview

**Full tokenization pipeline**:
```
Raw Text
  ↓
Normalization (cleaning, lowercasing)
  ↓
Pre-tokenization (split into words)
  ↓
Model (apply BPE/WordPiece/Unigram)
  ↓
Post-processing (add special tokens)
  ↓
Token IDs
```

**Decoding reverses the process**:
```
Token IDs
  ↓
Decoder (handle special encodings)
  ↓
Raw Text
```
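
A minimal sketch wiring all four stages into one BERT-style tokenizer (the `[CLS]`/`[SEP]` token IDs are placeholders until training fixes the vocabulary):

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))                     # model
tokenizer.normalizer = Sequence([NFD(), Lowercase(), StripAccents()])  # stage 1
tokenizer.pre_tokenizer = Whitespace()                                  # stage 2
tokenizer.post_processor = TemplateProcessing(                          # stage 4
    single="[CLS] $A [SEP]",
    special_tokens=[("[CLS]", 1), ("[SEP]", 2)],  # assumed IDs
)
```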
## Normalizers

Clean and standardize input text.

### Common normalizers

**Lowercase**:
```python
from tokenizers.normalizers import Lowercase

tokenizer.normalizer = Lowercase()

# Input: "Hello WORLD"
# Output: "hello world"
```

**Unicode normalization**:
```python
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC

# NFD: Canonical decomposition
tokenizer.normalizer = NFD()
# "é" → "e" + combining accent (separate characters)

# NFC: Canonical composition (default)
tokenizer.normalizer = NFC()
# "e" + combining accent → "é" (composed)

# NFKD: Compatibility decomposition
tokenizer.normalizer = NFKD()
# "ﬁ" (ligature) → "f" + "i"

# NFKC: Compatibility composition
tokenizer.normalizer = NFKC()
# Most aggressive normalization
```

**Strip accents**:
```python
from tokenizers.normalizers import StripAccents

tokenizer.normalizer = StripAccents()

# Input: "café"
# Output: "cafe"
```

**Whitespace handling**:
```python
from tokenizers.normalizers import Strip

# Remove leading/trailing whitespace
tokenizer.normalizer = Strip()

# Input: "  hello  "
# Output: "hello"
```

**Replace patterns**:
```python
from tokenizers.normalizers import Replace

# Replace newlines with spaces
tokenizer.normalizer = Replace("\n", " ")

# Input: "hello\nworld"
# Output: "hello world"
```
### Combining normalizers
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
|
||||
|
||||
# BERT-style normalization
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode decomposition
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Café au Lait"
|
||||
# After NFD: "Café au Lait" (e + ́)
|
||||
# After Lowercase: "café au lait"
|
||||
# After StripAccents: "cafe au lait"
|
||||
```
|
||||
|
||||
### Use case examples
|
||||
|
||||
**Case-insensitive model (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
# All-in-one BERT normalization
|
||||
tokenizer.normalizer = BertNormalizer(
|
||||
clean_text=True, # Remove control characters
|
||||
handle_chinese_chars=True, # Add spaces around Chinese
|
||||
strip_accents=True, # Remove accents
|
||||
lowercase=True # Lowercase
|
||||
)
|
||||
```
|
||||
|
||||
**Case-sensitive model (GPT-2)**:
|
||||
```python
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
```
|
||||
|
||||
**Multilingual (mBERT)**:
|
||||
```python
|
||||
# Preserve scripts, normalize form
|
||||
tokenizer.normalizer = NFKC()
|
||||
```
|
||||
|
||||
## Pre-tokenizers

Split text into word-like units before tokenization.

### Whitespace splitting

```python
from tokenizers.pre_tokenizers import Whitespace

tokenizer.pre_tokenizer = Whitespace()

# Whitespace splits on word boundaries, isolating punctuation runs
# Input: "Hello world! How are you?"
# Output: [("Hello", (0, 5)), ("world", (6, 11)), ("!", (11, 12)),
#          ("How", (13, 16)), ("are", (17, 20)), ("you", (21, 24)), ("?", (24, 25))]
```

### Punctuation isolation

```python
from tokenizers.pre_tokenizers import Punctuation

tokenizer.pre_tokenizer = Punctuation()

# Input: "Hello, world!"
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```

### Byte-level (GPT-2)

```python
from tokenizers.pre_tokenizers import ByteLevel

tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)

# Input: "Hello world"
# Output: Byte-level tokens with Ġ marking spaces
# [("ĠHello", ...), ("Ġworld", ...)]
```

**Key feature**: Maps every byte to one of 256 visible symbols, so any Unicode text (emoji, CJK, rare scripts) can be represented without unknown tokens.

### Metaspace (SentencePiece)

```python
from tokenizers.pre_tokenizers import Metaspace

tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)

# Input: "Hello world"
# Output: [("▁Hello", ...), ("▁world", ...)]
```

**Used by**: T5, ALBERT (via SentencePiece)

### Digits splitting

```python
from tokenizers.pre_tokenizers import Digits

# Split digits individually
tokenizer.pre_tokenizer = Digits(individual_digits=True)

# Input: "Room 123"
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]

# Keep digits together
tokenizer.pre_tokenizer = Digits(individual_digits=False)

# Input: "Room 123"
# Output: [("Room", ...), ("123", ...)]
```

### BERT pre-tokenizer

```python
from tokenizers.pre_tokenizers import BertPreTokenizer

tokenizer.pre_tokenizer = BertPreTokenizer()

# Splits on whitespace and punctuation, isolates CJK characters
# Input: "Hello, 世界!"
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
```

### Combining pre-tokenizers

```python
from tokenizers.pre_tokenizers import Sequence, WhitespaceSplit, Punctuation

tokenizer.pre_tokenizer = Sequence([
    WhitespaceSplit(),  # Split on whitespace only
    Punctuation()       # Then isolate punctuation
])

# Input: "Hello, world!"
# After WhitespaceSplit: [("Hello,", ...), ("world!", ...)]
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
```

### Pre-tokenizer comparison

| Pre-tokenizer     | Use Case               | Example                            |
|-------------------|------------------------|------------------------------------|
| Whitespace        | Simple English         | "Hello world" → ["Hello", "world"] |
| Punctuation       | Isolate symbols        | "world!" → ["world", "!"]          |
| ByteLevel         | Multilingual, emojis   | "🌍" → byte tokens                 |
| Metaspace         | SentencePiece-style    | "Hello" → ["▁Hello"]               |
| BertPreTokenizer  | BERT-style (CJK aware) | "世界" → ["世", "界"]               |
| Digits            | Handle numbers         | "123" → ["1", "2", "3"] or ["123"] |

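To compare how two pre-tokenizers treat the same input, run them side by side with `pre_tokenize_str`; a quick sketch:

```python
from tokenizers.pre_tokenizers import Whitespace, ByteLevel

text = "Hello, world!"

# Whitespace isolates words and punctuation runs
print(Whitespace().pre_tokenize_str(text))

# ByteLevel maps bytes to visible symbols and marks spaces with Ġ
print(ByteLevel(add_prefix_space=True).pre_tokenize_str(text))
```
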
## Models

Core tokenization algorithms.

### BPE Model

```python
from tokenizers.models import BPE

model = BPE(
    vocab=None,                    # Or provide a pre-built vocab
    merges=None,                   # Or provide merge rules
    unk_token="[UNK]",             # Unknown token
    continuing_subword_prefix="",
    end_of_word_suffix="",
    fuse_unk=False                 # Keep unknown tokens separate
)

tokenizer = Tokenizer(model)
```

**Parameters**:
- `vocab`: Dict of token → id
- `merges`: List of merge rules `["a b", "ab c"]`
- `unk_token`: Token for unknown words
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
- `end_of_word_suffix`: Suffix for the last subword (empty for GPT-2)

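To reuse a previously trained BPE vocabulary, the model can be built directly from its files with `BPE.from_file`; a minimal sketch (the file names are assumptions):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Load vocab.json / merges.txt produced by an earlier training run
model = BPE.from_file("vocab.json", "merges.txt", unk_token="[UNK]")
tokenizer = Tokenizer(model)
```
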
### WordPiece Model

```python
from tokenizers.models import WordPiece

model = WordPiece(
    vocab=None,
    unk_token="[UNK]",
    max_input_chars_per_word=100,   # Max word length
    continuing_subword_prefix="##"  # BERT-style prefix
)

tokenizer = Tokenizer(model)
```

**Key difference**: Uses the `##` prefix for continuing subwords.

### Unigram Model

```python
from tokenizers.models import Unigram

model = Unigram(
    vocab=None,          # List of (token, score) tuples
    unk_id=0,            # ID for unknown token
    byte_fallback=False  # Fall back to bytes if no match
)

tokenizer = Tokenizer(model)
```

**Probabilistic**: Selects the tokenization with the highest probability.

### WordLevel Model

```python
from tokenizers.models import WordLevel

# Simple word-to-ID mapping (no subwords)
model = WordLevel(
    vocab=None,
    unk_token="[UNK]"
)

tokenizer = Tokenizer(model)
```

**Warning**: Requires a huge vocabulary (one token per word).

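A minimal WordLevel sketch with a toy hand-built vocabulary (real use would train or load one); any out-of-vocabulary word maps to `[UNK]`:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit

vocab = {"[UNK]": 0, "hello": 1, "world": 2}  # toy vocabulary for illustration
tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()

print(tokenizer.encode("hello brave world").tokens)
# ['hello', '[UNK]', 'world']
```
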
## Post-processors

Add special tokens and format output.

### Template processing

**BERT-style** (`[CLS] sentence [SEP]`):
```python
from tokenizers.processors import TemplateProcessing

tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B [SEP]",
    special_tokens=[
        ("[CLS]", 101),
        ("[SEP]", 102),
    ],
)

# Single sentence
output = tokenizer.encode("Hello world")
# [101, ..., 102] ([CLS] hello world [SEP])

# Sentence pair
output = tokenizer.encode("Hello", "world")
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
```

**GPT-2 style** (`sentence <|endoftext|>`):
```python
tokenizer.post_processor = TemplateProcessing(
    single="$A <|endoftext|>",
    special_tokens=[
        ("<|endoftext|>", 50256),
    ],
)
```

**RoBERTa style** (`<s> sentence </s>`):
```python
tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> </s> $B </s>",
    special_tokens=[
        ("<s>", 0),
        ("</s>", 2),
    ],
)
```

**T5 style** (no CLS/SEP):
```python
# T5 doesn't add CLS/SEP-style tokens via the post-processor
tokenizer.post_processor = None
```

### RobertaProcessing

```python
from tokenizers.processors import RobertaProcessing

tokenizer.post_processor = RobertaProcessing(
    sep=("</s>", 2),
    cls=("<s>", 0),
    add_prefix_space=True,  # Add space before first token
    trim_offsets=True       # Trim leading space from offsets
)
```

### ByteLevelProcessing

```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing

tokenizer.post_processor = ByteLevelProcessing(
    trim_offsets=True  # Remove Ġ from offsets
)
```

## Decoders

Convert token IDs back to text.

### ByteLevel decoder

```python
from tokenizers.decoders import ByteLevel

tokenizer.decoder = ByteLevel()

# Handles byte-level tokens
# ["ĠHello", "Ġworld"] → "Hello world"
```

### WordPiece decoder

```python
from tokenizers.decoders import WordPiece

tokenizer.decoder = WordPiece(prefix="##")

# Removes the ## prefix and concatenates
# ["token", "##ization"] → "tokenization"
```

### Metaspace decoder

```python
from tokenizers.decoders import Metaspace

tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)

# Converts ▁ back to spaces
# ["▁Hello", "▁world"] → "Hello world"
```

### BPEDecoder

```python
from tokenizers.decoders import BPEDecoder

tokenizer.decoder = BPEDecoder(suffix="</w>")

# Removes the suffix and concatenates
# ["token", "ization</w>"] → "tokenization"
```

### Sequence decoder

```python
from tokenizers.decoders import Sequence, ByteLevel, Strip

tokenizer.decoder = Sequence([
    ByteLevel(),      # Decode byte-level first
    Strip(' ', 1, 1)  # Strip one leading/trailing space
])
```

## Complete pipeline examples

### BERT tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder

# Model
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))

# Normalization
tokenizer.normalizer = BertNormalizer(lowercase=True)

# Pre-tokenization
tokenizer.pre_tokenizer = BertPreTokenizer()

# Post-processing
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B [SEP]",
    special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)

# Decoder
tokenizer.decoder = WordPieceDecoder(prefix="##")

# Enable padding
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")

# Enable truncation
tokenizer.enable_truncation(max_length=512)
```

### GPT-2 tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing

# Model
tokenizer = Tokenizer(BPE())

# Normalization (minimal)
tokenizer.normalizer = NFC()

# Byte-level pre-tokenization
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

# Post-processing
tokenizer.post_processor = TemplateProcessing(
    single="$A <|endoftext|>",
    special_tokens=[("<|endoftext|>", 50256)],
)

# Byte-level decoder
tokenizer.decoder = ByteLevelDecoder()
```

### T5 tokenizer (SentencePiece-style)

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder

# Model
tokenizer = Tokenizer(Unigram())

# Normalization
tokenizer.normalizer = NFKC()

# Metaspace pre-tokenization
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)

# No post-processing (T5 doesn't add CLS/SEP)
tokenizer.post_processor = None

# Metaspace decoder
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
```

## Alignment tracking

Track token positions in the original text.

### Basic alignment

```python
text = "Hello, world!"
output = tokenizer.encode(text)

for token, (start, end) in zip(output.tokens, output.offsets):
    print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")

# Output:
# [CLS]      → [ 0,  0): ''
# hello      → [ 0,  5): 'Hello'
# ,          → [ 5,  6): ','
# world      → [ 7, 12): 'world'
# !          → [12, 13): '!'
# [SEP]      → [ 0,  0): ''
```

### Word-level alignment

```python
# Get word_ids (which word each token belongs to)
encoding = tokenizer.encode("Hello world")
word_ids = encoding.word_ids

print(word_ids)
# [None, 0, 0, 1, None]
# None = special token, 0 = first word, 1 = second word
```

**Use case**: Token classification (NER)
```python
# Align per-token predictions to words (keep the first subword's label)
predictions = ["O", "B-PER", "I-PER", "O", "O"]
word_predictions = {}

for token_idx, word_idx in enumerate(encoding.word_ids):
    if word_idx is not None and word_idx not in word_predictions:
        word_predictions[word_idx] = predictions[token_idx]

print(word_predictions)
# {0: "B-PER", 1: "O"}  # First word is PERSON, second is OTHER
```

### Span alignment

```python
# Find the token span for a character span
text = "Machine learning is awesome"
char_start, char_end = 8, 16  # "learning"

encoding = tokenizer.encode(text)

# Find the token span
token_start = encoding.char_to_token(char_start)
token_end = encoding.char_to_token(char_end - 1) + 1

print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
# Tokens 2:3 = ['learning']
```

**Use case**: Question answering (extract answer span), as sketched below.

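A minimal QA sketch going the other direction, from predicted token indices back to a character span; `start_tok` and `end_tok` stand in for a model's argmax predictions (hypothetical values):

```python
text = "Machine learning is awesome"
encoding = tokenizer.encode(text)

start_tok, end_tok = 2, 2  # assumed model predictions for the answer span
char_start = encoding.offsets[start_tok][0]
char_end = encoding.offsets[end_tok][1]

print(text[char_start:char_end])  # "learning"
```
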
## Custom components

### Custom normalizer

```python
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer

class CustomNormalizer:
    def normalize(self, normalized: NormalizedString):
        # Custom normalization logic
        normalized.lowercase()
        normalized.replace("  ", " ")  # Collapse double spaces

# Custom Python components must be wrapped with Normalizer.custom
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
```

### Custom pre-tokenizer

```python
from tokenizers import PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer

class CustomPreTokenizer:
    def pre_tokenize(self, pretok: PreTokenizedString):
        # The callback receives (index, NormalizedString) and
        # returns the pieces each segment splits into
        pretok.split(lambda i, normalized: normalized.split(" ", "removed"))

tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
```

## Troubleshooting

### Issue: Misaligned offsets

**Symptom**: Offsets don't match the original text
```python
text = "  hello"   # Leading spaces
offsets = [(0, 5)]  # Maps to "  hel" instead of "hello"
```

**Solution**: Avoid stripping in the normalizer; trim offsets in the post-processor instead
```python
from tokenizers.processors import ByteLevel as ByteLevelProcessing

# Stripping in the normalizer shifts character positions:
# tokenizer.normalizer = Sequence([Strip()])  # This changes offsets!

# Prefer trim_offsets in the post-processor
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
```

### Issue: Special tokens not added

**Symptom**: No [CLS] or [SEP] in the output

**Solution**: Check that a post-processor is set
```python
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
)
```

### Issue: Incorrect decoding

**Symptom**: Decoded text contains ## or ▁

**Solution**: Set the correct decoder
```python
# For WordPiece
tokenizer.decoder = WordPieceDecoder(prefix="##")

# For SentencePiece-style tokenizers
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
```

## Best practices

1. **Match pipeline to model architecture**:
   - BERT → BertNormalizer + BertPreTokenizer + WordPiece
   - GPT-2 → NFC + ByteLevel + BPE
   - T5 → NFKC + Metaspace + Unigram

2. **Test pipeline on sample inputs** (see the round-trip sketch after this list):
   - Check normalization doesn't over-normalize
   - Verify pre-tokenization splits correctly
   - Ensure decoding reconstructs text

3. **Preserve alignment for downstream tasks**:
   - Use `trim_offsets` instead of stripping in the normalizer
   - Test `char_to_token()` on sample spans

4. **Document your pipeline**:
   - Save the complete tokenizer config
   - Document special tokens
   - Note any custom components

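A minimal round-trip check for practice 2, assuming `tokenizer` is the pipeline under test:

```python
# Encode then decode a few representative samples
samples = ["Hello, world!", "Café au lait", "Room 123", "世界"]
for s in samples:
    enc = tokenizer.encode(s)
    print(f"{s!r} → {enc.tokens} → {tokenizer.decode(enc.ids)!r}")
# For lossless pipelines (e.g. byte-level), decoding should reproduce the
# input exactly; lossy ones (lowercasing, accent stripping) will differ
```
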
@@ -0,0 +1,565 @@
# Training Custom Tokenizers

Complete guide to training tokenizers from scratch.

## Training workflow

### Step 1: Choose tokenization algorithm

**Decision tree**:
- **GPT-style model** → BPE
- **BERT-style model** → WordPiece
- **Multilingual/No word boundaries** → Unigram

### Step 2: Prepare training data

```python
# Option 1: From files
files = ["train.txt", "validation.txt"]

# Option 2: From a Python list
texts = [
    "This is the first sentence.",
    "This is the second sentence.",
    # ... more texts
]

# Option 3: From a dataset iterator
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i:i + batch_size]["text"]
```

### Step 3: Initialize tokenizer

**BPE example**:
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()

trainer = BpeTrainer(
    vocab_size=50000,
    min_frequency=2,
    special_tokens=["<|endoftext|>", "<|padding|>"],
    show_progress=True
)
```

**WordPiece example**:
```python
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = BertPreTokenizer()

trainer = WordPieceTrainer(
    vocab_size=30522,
    min_frequency=2,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    continuing_subword_prefix="##",
    show_progress=True
)
```

**Unigram example**:
```python
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

tokenizer = Tokenizer(Unigram())

trainer = UnigramTrainer(
    vocab_size=8000,
    special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
    unk_token="<unk>",
    show_progress=True
)
```

### Step 4: Train

```python
# From files
tokenizer.train(files=files, trainer=trainer)

# From an iterator (recommended for large datasets)
tokenizer.train_from_iterator(
    batch_iterator(),
    trainer=trainer,
    length=len(dataset)  # Optional, for the progress bar
)
```

**Training time** (30k vocab on a 16-core CPU):
- 10 MB: 15-30 seconds
- 100 MB: 1-3 minutes
- 1 GB: 15-30 minutes
- 10 GB: 2-4 hours

### Step 5: Add post-processing

```python
from tokenizers.processors import TemplateProcessing

# BERT-style
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B [SEP]",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

# GPT-2 style
tokenizer.post_processor = TemplateProcessing(
    single="$A <|endoftext|>",
    special_tokens=[
        ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
    ],
)
```

### Step 6: Save

```python
# Save to JSON
tokenizer.save("my-tokenizer.json")

# Save to a directory (for transformers)
tokenizer.save("my-tokenizer-dir/tokenizer.json")

# Convert to transformers format
from transformers import PreTrainedTokenizerFast

transformers_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]"
)

transformers_tokenizer.save_pretrained("my-tokenizer-dir")
```

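Reloading the saved artifacts is symmetric; a minimal sketch using the paths from Step 6:

```python
from tokenizers import Tokenizer
from transformers import AutoTokenizer

# Reload the raw tokenizers object
tokenizer = Tokenizer.from_file("my-tokenizer.json")

# Reload the transformers wrapper saved with save_pretrained
hf_tokenizer = AutoTokenizer.from_pretrained("my-tokenizer-dir")
```
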
## Trainer configuration

### BpeTrainer parameters

```python
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(
    vocab_size=30000,              # Target vocabulary size
    min_frequency=2,               # Minimum frequency for merges
    special_tokens=["[UNK]"],      # Special tokens (added first)
    limit_alphabet=1000,           # Limit initial alphabet size
    initial_alphabet=[],           # Pre-defined initial characters
    show_progress=True,            # Show progress bar
    continuing_subword_prefix="",  # Prefix for continuing subwords
    end_of_word_suffix=""          # Suffix for end of words
)
```

**Parameter tuning**:
- **vocab_size**: Start with 30k for English, 50k for multilingual
- **min_frequency**: 2-5 for large corpora, 1 for small
- **limit_alphabet**: Reduce for non-English (CJK languages)

### WordPieceTrainer parameters

```python
from tokenizers.trainers import WordPieceTrainer

trainer = WordPieceTrainer(
    vocab_size=30522,  # BERT uses 30,522
    min_frequency=2,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    limit_alphabet=1000,
    continuing_subword_prefix="##",  # BERT-style prefix
    show_progress=True
)
```

### UnigramTrainer parameters

```python
from tokenizers.trainers import UnigramTrainer

trainer = UnigramTrainer(
    vocab_size=8000,        # Typically smaller than BPE/WordPiece
    special_tokens=["<unk>", "<s>", "</s>"],
    unk_token="<unk>",
    max_piece_length=16,    # Maximum token length
    n_sub_iterations=2,     # EM algorithm iterations
    shrinking_factor=0.75,  # Vocabulary reduction rate
    show_progress=True
)
```

## Training from large datasets

### Memory-efficient training

```python
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

# Load dataset in streaming mode
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)

# Create an iterator (yields batches)
def batch_iterator(batch_size=1000):
    batch = []
    for sample in dataset:
        batch.append(sample["text"])
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Initialize tokenizer
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])

# Train (memory efficient - streams data)
tokenizer.train_from_iterator(
    batch_iterator(),
    trainer=trainer
)
```

**Memory usage**: ~200 MB (vs 10+ GB when loading the full dataset)

### Multi-file training

```python
import glob

# Find all training files
files = glob.glob("data/train/*.txt")
print(f"Training on {len(files)} files")

# Train on all files
tokenizer.train(files=files, trainer=trainer)
```

### Parallel training (multi-processing)

```python
from multiprocessing import Pool, cpu_count

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

def train_shard(shard_files):
    """Train a tokenizer on one shard of files."""
    tokenizer = Tokenizer(BPE())
    trainer = BpeTrainer(vocab_size=50000)
    tokenizer.train(files=shard_files, trainer=trainer)
    return tokenizer.get_vocab()

# Split files into shards
num_shards = cpu_count()
file_shards = [files[i::num_shards] for i in range(num_shards)]

# Train shards in parallel
with Pool(num_shards) as pool:
    vocab_shards = pool.map(train_shard, file_shards)

# Merge vocabularies. This naive dict update is only a sketch: token IDs
# from different shards collide and BPE merge rules cannot be combined
# this way. For a correct vocabulary, prefer a single train_from_iterator
# run - the Rust core already parallelizes across CPU cores.
final_vocab = {}
for vocab in vocab_shards:
    final_vocab.update(vocab)
```

## Domain-specific tokenizers

### Code tokenizer

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import NFC

# Code-optimized configuration
tokenizer = Tokenizer(BPE())

# Minimal normalization (preserve case, whitespace)
tokenizer.normalizer = NFC()  # Only normalize Unicode

# Byte-level pre-tokenization (handles all characters)
tokenizer.pre_tokenizer = ByteLevel()

# Train on a code corpus
trainer = BpeTrainer(
    vocab_size=50000,
    special_tokens=["<|endoftext|>", "<|pad|>"],
    min_frequency=2
)

tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
```

### Medical/scientific tokenizer

```python
# Preserve case and special characters
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence

tokenizer = Tokenizer(BPE())

# Minimal normalization
tokenizer.normalizer = NFKC()

# Preserve medical terms
tokenizer.pre_tokenizer = Sequence([
    Whitespace(),
    Punctuation(behavior="isolated")  # Keep punctuation separate
])

trainer = BpeTrainer(
    vocab_size=50000,
    special_tokens=["[UNK]", "[CLS]", "[SEP]"],
    min_frequency=3  # Higher threshold filters rare noise and misspellings
)

tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
```

### Multilingual tokenizer

```python
# Handle multiple scripts
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import ByteLevel

tokenizer = Tokenizer(BPE())

# Normalize form but don't lowercase (preserves script differences)
tokenizer.normalizer = NFKC()

# Byte-level handles all Unicode
tokenizer.pre_tokenizer = ByteLevel()

trainer = BpeTrainer(
    vocab_size=100000,  # Larger vocab for multiple languages
    special_tokens=["<unk>", "<s>", "</s>"],
    limit_alphabet=None  # No limit (handles all scripts)
)

# Train on a multilingual corpus
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
```

## Vocabulary size selection

### Guidelines by task

| Task                   | Recommended Vocab Size | Rationale                    |
|------------------------|------------------------|------------------------------|
| English (monolingual)  | 30,000 - 50,000        | Balanced coverage            |
| Multilingual           | 50,000 - 250,000       | More languages = more tokens |
| Code                   | 30,000 - 50,000        | Similar to English           |
| Domain-specific        | 10,000 - 30,000        | Smaller, focused vocabulary  |
| Character-level tasks  | 1,000 - 5,000          | Only characters + subwords   |

### Vocabulary size impact

**Small vocab (10k)**:
- Pros: Faster training, smaller model, less memory
- Cons: More tokens per sentence, worse OOV handling

**Medium vocab (30k-50k)**:
- Pros: Good balance, standard choice
- Cons: None (recommended default)

**Large vocab (100k+)**:
- Pros: Fewer tokens per sentence, better OOV handling
- Cons: Slower training, larger embedding table

### Empirical testing

```python
# Train multiple tokenizers with different vocab sizes
vocab_sizes = [10000, 30000, 50000, 100000]

for vocab_size in vocab_sizes:
    tokenizer = Tokenizer(BPE())
    trainer = BpeTrainer(vocab_size=vocab_size)
    tokenizer.train(files=["sample.txt"], trainer=trainer)

    # Evaluate on a test sentence
    test_text = "Test sentence for evaluation..."
    tokens = tokenizer.encode(test_text).ids

    print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")

# Example output:
# Vocab:  10000 | Tokens:  12 | Avg: 2.33 chars/token
# Vocab:  30000 | Tokens:   8 | Avg: 3.50 chars/token
# Vocab:  50000 | Tokens:   7 | Avg: 4.00 chars/token
# Vocab: 100000 | Tokens:   6 | Avg: 4.67 chars/token
```

## Testing tokenizer quality

### Coverage test

```python
# Test on held-out data
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")

total_tokens = 0
unk_tokens = 0
unk_id = tokenizer.token_to_id("[UNK]")

for text in test_corpus["text"]:
    if text.strip():
        encoding = tokenizer.encode(text)
        total_tokens += len(encoding.ids)
        unk_tokens += encoding.ids.count(unk_id)

unk_rate = unk_tokens / total_tokens
print(f"Unknown token rate: {unk_rate:.2%}")

# Good quality: <1% unknown tokens
# Acceptable: 1-5%
# Poor: >5%
```

### Compression test

```python
# Measure tokenization efficiency
import numpy as np

token_lengths = []

for text in test_corpus["text"][:1000]:
    if text.strip():
        encoding = tokenizer.encode(text)
        chars_per_token = len(text) / len(encoding.ids)
        token_lengths.append(chars_per_token)

avg_chars_per_token = np.mean(token_lengths)
print(f"Average characters per token: {avg_chars_per_token:.2f}")

# Good: 4-6 chars/token (English)
# Acceptable: 3-4 chars/token
# Poor: <3 chars/token (under-compression)
```

### Semantic test

```python
# Manually inspect tokenization of common words/phrases
test_phrases = [
    "tokenization",
    "machine learning",
    "artificial intelligence",
    "preprocessing",
    "hello world"
]

for phrase in test_phrases:
    tokens = tokenizer.encode(phrase).tokens
    print(f"{phrase:25s} → {tokens}")

# Good tokenization:
# tokenization → ['token', 'ization']
# machine learning → ['machine', 'learning']
# artificial intelligence → ['artificial', 'intelligence']
```

## Troubleshooting

### Issue: Training too slow

**Solutions**:
1. Reduce vocabulary size
2. Increase `min_frequency`
3. Use `limit_alphabet` to reduce the initial alphabet
4. Train on a subset first

```python
# Fast training configuration
trainer = BpeTrainer(
    vocab_size=20000,    # Smaller vocab
    min_frequency=5,     # Higher threshold
    limit_alphabet=500,  # Limit alphabet
    show_progress=True
)
```

### Issue: High unknown token rate

**Solutions**:
1. Increase vocabulary size
2. Decrease `min_frequency`
3. Check normalization (might be too aggressive)

```python
# Better coverage configuration
trainer = BpeTrainer(
    vocab_size=50000,  # Larger vocab
    min_frequency=1,   # Lower threshold
)
```

### Issue: Poor quality tokenization

**Solutions**:
1. Verify normalization matches your use case
2. Check pre-tokenization splits correctly
3. Ensure training data is representative
4. Try a different algorithm (BPE vs WordPiece vs Unigram)

```python
# Debug the tokenization pipeline stage by stage
text = "Sample text to debug"

# Check normalization
normalized = tokenizer.normalizer.normalize_str(text)
print(f"Normalized: {normalized}")

# Check pre-tokenization
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
print(f"Pre-tokens: {pre_tokens}")

# Check final tokenization
tokens = tokenizer.encode(text).tokens
print(f"Tokens: {tokens}")
```

## Best practices

1. **Use representative training data** - Match your target domain
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
3. **Test on held-out data** - Measure the unknown token rate
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
5. **Save tokenizer with model** - Ensure reproducibility
6. **Version your tokenizers** - Track changes for reproducibility
7. **Document special tokens** - Critical for model training

@@ -0,0 +1,493 @@
---
name: evaluating-llms-harness
description: Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lm-eval, transformers, vllm]
metadata:
  hermes:
    tags: [Evaluation, LM Evaluation Harness, Benchmarking, MMLU, HumanEval, GSM8K, EleutherAI, Model Quality, Academic Benchmarks, Industry Standard]

---

# lm-evaluation-harness - LLM Benchmarking

## Quick start

lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.

**Installation**:
```bash
pip install lm-eval
```

**Evaluate any HuggingFace model**:
```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu,gsm8k,hellaswag \
    --device cuda:0 \
    --batch_size 8
```

**View available tasks**:
```bash
lm_eval --tasks list
```

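The same evaluation can be driven from Python. A minimal sketch using `simple_evaluate`, assuming lm-eval ≥ 0.4 where it is exposed at the package level:

```python
import lm_eval

# Python equivalent of the CLI call above
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=["hellaswag"],
    num_fewshot=0,
    batch_size=8,
)
print(results["results"]["hellaswag"])
```
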
## Common workflows

### Workflow 1: Standard benchmark evaluation

Evaluate a model on core benchmarks (MMLU, GSM8K, HumanEval).

Copy this checklist:

```
Benchmark Evaluation:
- [ ] Step 1: Choose benchmark suite
- [ ] Step 2: Configure model
- [ ] Step 3: Run evaluation
- [ ] Step 4: Analyze results
```

**Step 1: Choose benchmark suite**

**Core reasoning benchmarks**:
- **MMLU** (Massive Multitask Language Understanding) - 57 subjects, multiple choice
- **GSM8K** - Grade school math word problems
- **HellaSwag** - Common sense reasoning
- **TruthfulQA** - Truthfulness and factuality
- **ARC** (AI2 Reasoning Challenge) - Science questions

**Code benchmarks**:
- **HumanEval** - Python code generation (164 problems)
- **MBPP** (Mostly Basic Python Problems) - Python coding

**Standard suite** (recommended for model releases):
```bash
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge
```

**Step 2: Configure model**

**HuggingFace model**:
```bash
# batch_size auto picks the largest batch that fits in memory
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
    --tasks mmlu \
    --device cuda:0 \
    --batch_size auto
```

**Quantized model (4-bit/8-bit)**:
```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf,load_in_4bit=True \
    --tasks mmlu \
    --device cuda:0
```

**Custom checkpoint**:
```bash
lm_eval --model hf \
    --model_args pretrained=/path/to/my-model,tokenizer=/path/to/tokenizer \
    --tasks mmlu \
    --device cuda:0
```

**Step 3: Run evaluation**

```bash
# Full MMLU evaluation (57 subjects), 5-shot (the standard setting);
# --log_samples saves individual predictions
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu \
    --num_fewshot 5 \
    --batch_size 8 \
    --output_path results/ \
    --log_samples

# Multiple benchmarks at once
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge \
    --num_fewshot 5 \
    --batch_size 8 \
    --output_path results/llama2-7b-eval.json
```

**Step 4: Analyze results**

Results are saved to `results/llama2-7b-eval.json`:

```json
{
  "results": {
    "mmlu": {
      "acc": 0.459,
      "acc_stderr": 0.004
    },
    "gsm8k": {
      "exact_match": 0.142,
      "exact_match_stderr": 0.006
    },
    "hellaswag": {
      "acc_norm": 0.765,
      "acc_norm_stderr": 0.004
    }
  },
  "config": {
    "model": "hf",
    "model_args": "pretrained=meta-llama/Llama-2-7b-hf",
    "num_fewshot": 5
  }
}
```

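A short sketch that loads this file and prints one line per task; it assumes the primary metric is the first key listed, as in the JSON above:

```python
import json

with open("results/llama2-7b-eval.json") as f:
    data = json.load(f)

for task, metrics in data["results"].items():
    # The primary metric (acc, exact_match, acc_norm) precedes its stderr
    name, value = next(iter(metrics.items()))
    print(f"{task:12s} {name} = {value:.3f}")
```
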
### Workflow 2: Track training progress

Evaluate checkpoints during training.

```
Training Progress Tracking:
- [ ] Step 1: Set up periodic evaluation
- [ ] Step 2: Choose quick benchmarks
- [ ] Step 3: Automate evaluation
- [ ] Step 4: Plot learning curves
```

**Step 1: Set up periodic evaluation**

Evaluate every N training steps:

```bash
#!/bin/bash
# eval_checkpoint.sh

CHECKPOINT_DIR=$1
STEP=$2

# 0-shot for speed during training
lm_eval --model hf \
    --model_args pretrained=$CHECKPOINT_DIR/step-$STEP \
    --tasks gsm8k,hellaswag \
    --num_fewshot 0 \
    --batch_size 16 \
    --output_path results/step-$STEP.json
```

**Step 2: Choose quick benchmarks**

Fast benchmarks for frequent evaluation:
- **HellaSwag**: ~10 minutes on 1 GPU
- **GSM8K**: ~5 minutes
- **PIQA**: ~2 minutes

Avoid for frequent eval (too slow):
- **MMLU**: ~2 hours (57 subjects)
- **HumanEval**: Requires code execution

**Step 3: Automate evaluation**

Integrate with the training script:

```python
import os

# In the training loop
if step % eval_interval == 0:
    model.save_pretrained(f"checkpoints/step-{step}")

    # Run evaluation (paths match eval_checkpoint.sh above)
    os.system(f"./eval_checkpoint.sh checkpoints {step}")
```

Or use PyTorch Lightning callbacks:

```python
import os

from pytorch_lightning import Callback

class EvalHarnessCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        step = trainer.global_step
        checkpoint_path = f"checkpoints/step-{step}"

        # Save checkpoint
        trainer.save_checkpoint(checkpoint_path)

        # Run lm-eval
        os.system(f"lm_eval --model hf --model_args pretrained={checkpoint_path} ...")
```

**Step 4: Plot learning curves**

```python
import glob
import json

import matplotlib.pyplot as plt

# Load all results, sorting numerically by step (not lexicographically)
steps = []
scores = []

files = sorted(glob.glob("results/step-*.json"),
               key=lambda f: int(f.split("-")[1].split(".")[0]))
for file in files:
    with open(file) as f:
        data = json.load(f)
    steps.append(int(file.split("-")[1].split(".")[0]))
    scores.append(data["results"]["hellaswag"]["acc_norm"])

# Plot
plt.plot(steps, scores)
plt.xlabel("Training Step")
plt.ylabel("HellaSwag Accuracy (normalized)")
plt.title("Training Progress")
plt.savefig("training_curve.png")
```

### Workflow 3: Compare multiple models

Benchmark suite for model comparison.

```
Model Comparison:
- [ ] Step 1: Define model list
- [ ] Step 2: Run evaluations
- [ ] Step 3: Generate comparison table
```

**Step 1: Define model list**

```bash
# models.txt
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-hf
mistralai/Mistral-7B-v0.1
microsoft/phi-2
```

**Step 2: Run evaluations**

```bash
#!/bin/bash
# eval_all_models.sh

TASKS="mmlu,gsm8k,hellaswag,truthfulqa"

while read model; do
    echo "Evaluating $model"

    # Replace / so the model name is usable as a file name
    model_name=$(echo $model | sed 's/\//-/g')

    lm_eval --model hf \
        --model_args pretrained=$model,dtype=bfloat16 \
        --tasks $TASKS \
        --num_fewshot 5 \
        --batch_size auto \
        --output_path results/$model_name.json

done < models.txt
```

**Step 3: Generate comparison table**

```python
import json

import pandas as pd

models = [
    "meta-llama-Llama-2-7b-hf",
    "meta-llama-Llama-2-13b-hf",
    "mistralai-Mistral-7B-v0.1",
    "microsoft-phi-2"
]

tasks = ["mmlu", "gsm8k", "hellaswag", "truthfulqa"]

results = []
for model in models:
    with open(f"results/{model}.json") as f:
        data = json.load(f)
    row = {"Model": model}
    for task in tasks:
        # Get the primary metric for each task
        metrics = data["results"][task]
        if "acc" in metrics:
            row[task.upper()] = f"{metrics['acc']:.3f}"
        elif "exact_match" in metrics:
            row[task.upper()] = f"{metrics['exact_match']:.3f}"
    results.append(row)

df = pd.DataFrame(results)
print(df.to_markdown(index=False))
```

Output:
```
| Model                  | MMLU  | GSM8K | HELLASWAG | TRUTHFULQA |
|------------------------|-------|-------|-----------|------------|
| meta-llama/Llama-2-7b  | 0.459 | 0.142 | 0.765     | 0.391      |
| meta-llama/Llama-2-13b | 0.549 | 0.287 | 0.801     | 0.430      |
| mistralai/Mistral-7B   | 0.626 | 0.395 | 0.812     | 0.428      |
| microsoft/phi-2        | 0.560 | 0.613 | 0.682     | 0.447      |
```

### Workflow 4: Evaluate with vLLM (faster inference)

Use the vLLM backend for 5-10× faster evaluation.

```
vLLM Evaluation:
- [ ] Step 1: Install vLLM
- [ ] Step 2: Configure vLLM backend
- [ ] Step 3: Run evaluation
```

**Step 1: Install vLLM**

```bash
pip install vllm
```

**Step 2: Configure vLLM backend**

```bash
lm_eval --model vllm \
    --model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8 \
    --tasks mmlu \
    --batch_size auto
```

**Step 3: Run evaluation**

vLLM is 5-10× faster than the standard HuggingFace backend:

```bash
# Standard HF: ~2 hours for MMLU on a 7B model
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu \
    --batch_size 8

# vLLM: ~15-20 minutes for MMLU on a 7B model
lm_eval --model vllm \
    --model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=2 \
    --tasks mmlu \
    --batch_size auto
```

## When to use vs alternatives

**Use lm-evaluation-harness when:**
- Benchmarking models for academic papers
- Comparing model quality across standard tasks
- Tracking training progress
- Reporting standardized metrics (everyone uses the same prompts)
- Need reproducible evaluation

**Use alternatives instead:**
- **HELM** (Stanford): Broader evaluation (fairness, efficiency, calibration)
- **AlpacaEval**: Instruction-following evaluation with LLM judges
- **MT-Bench**: Conversational multi-turn evaluation
- **Custom scripts**: Domain-specific evaluation

## Common issues

**Issue: Evaluation too slow**

Use the vLLM backend:
```bash
lm_eval --model vllm \
    --model_args pretrained=model-name,tensor_parallel_size=2
```

Or reduce fewshot examples:
```bash
--num_fewshot 0  # Instead of 5
```

Or evaluate a subset of MMLU:
```bash
--tasks mmlu_stem  # Only STEM subjects
```

**Issue: Out of memory**

Reduce batch size:
```bash
--batch_size 1  # Or --batch_size auto
```

Use quantization:
```bash
--model_args pretrained=model-name,load_in_8bit=True
```

Enable CPU offloading:
```bash
--model_args pretrained=model-name,device_map=auto,offload_folder=offload
```

**Issue: Different results than reported**

Check the fewshot count:
```bash
--num_fewshot 5  # Most papers use 5-shot
```

Check the exact task name:
```bash
--tasks mmlu  # Not mmlu_direct or mmlu_fewshot
```

Verify model and tokenizer match:
```bash
--model_args pretrained=model-name,tokenizer=same-model-name
```

**Issue: HumanEval not executing code**

Install execution dependencies:
```bash
pip install human-eval
```

Enable code execution:
```bash
lm_eval --model hf \
    --model_args pretrained=model-name \
    --tasks humaneval \
    --allow_code_execution  # Required for HumanEval
```

## Advanced topics

**Benchmark descriptions**: See [references/benchmark-guide.md](references/benchmark-guide.md) for detailed descriptions of all 60+ tasks, what they measure, and how to interpret scores.

**Custom tasks**: See [references/custom-tasks.md](references/custom-tasks.md) for creating domain-specific evaluation tasks.

**API evaluation**: See [references/api-evaluation.md](references/api-evaluation.md) for evaluating OpenAI, Anthropic, and other API models.

**Multi-GPU strategies**: See [references/distributed-eval.md](references/distributed-eval.md) for data parallel and tensor parallel evaluation.

## Hardware requirements

- **GPU**: NVIDIA (CUDA 11.8+); also works on CPU (very slow)
- **VRAM**:
  - 7B model: 16GB (bf16) or 8GB (8-bit)
  - 13B model: 28GB (bf16) or 14GB (8-bit)
  - 70B model: Requires multi-GPU or quantization
- **Time** (7B model, single A100):
  - HellaSwag: 10 minutes
  - GSM8K: 5 minutes
  - MMLU (full): 2 hours
  - HumanEval: 20 minutes

## Resources

- GitHub: https://github.com/EleutherAI/lm-evaluation-harness
- Docs: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs
- Task library: 60+ tasks including MMLU, GSM8K, HumanEval, TruthfulQA, HellaSwag, ARC, WinoGrande, etc.
- Leaderboard: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard (uses this harness)

@@ -0,0 +1,490 @@
# API Evaluation

Guide to evaluating OpenAI, Anthropic, and other API-based language models.

## Overview

The lm-evaluation-harness supports evaluating API-based models through a unified `TemplateAPI` interface. This allows benchmarking of:
- OpenAI models (GPT-4, GPT-3.5, etc.)
- Anthropic models (Claude 3, Claude 2, etc.)
- Local OpenAI-compatible APIs
- Custom API endpoints

**Why evaluate API models**:
- Benchmark closed-source models
- Compare API models to open models
- Validate API performance
- Track model updates over time

## Supported API Models

| Provider | Model Type | Request Types | Logprobs |
|----------|------------|---------------|----------|
| OpenAI (completions) | `openai-completions` | All | ✅ Yes |
| OpenAI (chat) | `openai-chat-completions` | `generate_until` only | ❌ No |
| Anthropic (completions) | `anthropic-completions` | All | ❌ No |
| Anthropic (chat) | `anthropic-chat` | `generate_until` only | ❌ No |
| Local (OpenAI-compatible) | `local-completions` | Depends on server | Varies |

**Note**: Models without logprobs can only be evaluated on generation tasks, not perplexity or loglikelihood tasks.

## OpenAI Models

### Setup

```bash
export OPENAI_API_KEY=sk-...
```

### Completion Models (Legacy)

**Available models**: `davinci-002`, `babbage-002`

```bash
lm_eval --model openai-completions \
    --model_args model=davinci-002 \
    --tasks lambada_openai,hellaswag \
    --batch_size auto
```

**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ✅
- `loglikelihood_rolling`: ✅

### Chat Models

**Available models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`

```bash
lm_eval --model openai-chat-completions \
    --model_args model=gpt-4-turbo \
    --tasks mmlu,gsm8k,humaneval \
    --num_fewshot 5 \
    --batch_size auto
```

**Supports**:
- `generate_until`: ✅
- `loglikelihood`: ❌ (no logprobs)
- `loglikelihood_rolling`: ❌

**Important**: Chat models don't provide logprobs, so they can only be used with generation tasks (MMLU, GSM8K, HumanEval), not perplexity tasks.

### Configuration Options

```bash
lm_eval --model openai-chat-completions \
    --model_args \
model=gpt-4-turbo,\
base_url=https://api.openai.com/v1,\
num_concurrent=5,\
max_retries=3,\
timeout=60,\
batch_size=auto
```

**Parameters**:
- `model`: Model identifier (required)
- `base_url`: API endpoint (default: OpenAI)
- `num_concurrent`: Concurrent requests (default: 5)
- `max_retries`: Retry failed requests (default: 3)
- `timeout`: Request timeout in seconds (default: 60)
- `tokenizer`: Tokenizer to use (default: matches model)
- `tokenizer_backend`: `"tiktoken"` or `"huggingface"`

### Cost Management

OpenAI charges per token. Estimate costs before running:

```python
# Rough estimate
num_samples = 1000
avg_tokens_per_sample = 500  # input + output
cost_per_1k_tokens = 0.01    # GPT-3.5 Turbo; check current pricing

total_cost = (num_samples * avg_tokens_per_sample / 1000) * cost_per_1k_tokens
print(f"Estimated cost: ${total_cost:.2f}")
```

**Cost-saving tips** (see the tiktoken sketch below for a tighter estimate):
- Use `--limit N` for testing
- Start with `gpt-3.5-turbo` before `gpt-4`
- Set `max_gen_toks` to the minimum needed
- Use `num_fewshot=0` for zero-shot when possible

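For a tighter input-side estimate, token counts can be computed with `tiktoken` (OpenAI's tokenizer library); the prompt list and the rate below are assumptions for illustration:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
prompts = ["What is the capital of France?"] * 1000  # assumed workload

input_tokens = sum(len(enc.encode(p)) for p in prompts)
est_cost = input_tokens / 1000 * 0.01  # illustrative $/1k rate
print(f"~{input_tokens} input tokens, est. ${est_cost:.2f} before output tokens")
```
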
## Anthropic Models

### Setup

```bash
export ANTHROPIC_API_KEY=sk-ant-...
```

### Completion Models (Legacy)

```bash
lm_eval --model anthropic-completions \
    --model_args model=claude-2.1 \
    --tasks lambada_openai,hellaswag \
    --batch_size auto
```

### Chat Models (Recommended)

**Available models**: `claude-3-5-sonnet-20241022`, `claude-3-opus-20240229`, `claude-3-sonnet-20240229`, `claude-3-haiku-20240307`

```bash
lm_eval --model anthropic-chat \
    --model_args model=claude-3-5-sonnet-20241022 \
    --tasks mmlu,gsm8k,humaneval \
    --num_fewshot 5 \
    --batch_size auto
```

**Aliases**: `anthropic-chat-completions` (same as `anthropic-chat`)

### Configuration Options

```bash
lm_eval --model anthropic-chat \
    --model_args \
model=claude-3-5-sonnet-20241022,\
base_url=https://api.anthropic.com,\
num_concurrent=5,\
max_retries=3,\
timeout=60
```

### Cost Management

Anthropic pricing (as of 2024):
- Claude 3.5 Sonnet: $3.00 / 1M input, $15.00 / 1M output
- Claude 3 Opus: $15.00 / 1M input, $75.00 / 1M output
- Claude 3 Haiku: $0.25 / 1M input, $1.25 / 1M output

**Budget-friendly strategy**:
```bash
# Test on a small sample first
lm_eval --model anthropic-chat \
    --model_args model=claude-3-haiku-20240307 \
    --tasks mmlu \
    --limit 100

# Then run the full eval on the best model
lm_eval --model anthropic-chat \
    --model_args model=claude-3-5-sonnet-20241022 \
    --tasks mmlu \
    --num_fewshot 5
```

## Local OpenAI-Compatible APIs
|
||||
|
||||
Many local inference servers expose OpenAI-compatible APIs (vLLM, Text Generation Inference, llama.cpp, Ollama).
|
||||
|
||||
### vLLM Local Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-2-7b-hf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8000/v1,\
|
||||
num_concurrent=1 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Text Generation Inference (TGI)
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 \
|
||||
ghcr.io/huggingface/text-generation-inference:latest \
|
||||
--model-id meta-llama/Llama-2-7b-hf
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks hellaswag,arc_challenge
|
||||
```
|
||||
|
||||
### Ollama
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
ollama serve
|
||||
ollama pull llama2:7b
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2:7b,\
|
||||
base_url=http://localhost:11434/v1 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
### llama.cpp Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
./server -m models/llama-2-7b.gguf --host 0.0.0.0 --port 8080
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks gsm8k
|
||||
```
|
||||
|
||||
## Custom API Implementation
|
||||
|
||||
For custom API endpoints, subclass `TemplateAPI`:
|
||||
|
||||
### Create `my_api.py`
|
||||
|
||||
```python
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
import requests
|
||||
|
||||
class MyCustomAPI(TemplateAPI):
|
||||
"""Custom API model."""
|
||||
|
||||
def __init__(self, base_url, api_key, **kwargs):
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
self.api_key = api_key
|
||||
|
||||
def _create_payload(self, messages, gen_kwargs):
|
||||
"""Create API request payload."""
|
||||
return {
|
||||
"messages": messages,
|
||||
"api_key": self.api_key,
|
||||
**gen_kwargs
|
||||
}
|
||||
|
||||
def parse_generations(self, response):
|
||||
"""Parse generation response."""
|
||||
return response.json()["choices"][0]["text"]
|
||||
|
||||
def parse_logprobs(self, response):
|
||||
"""Parse logprobs (if available)."""
|
||||
# Return None if API doesn't provide logprobs
|
||||
logprobs = response.json().get("logprobs")
|
||||
if logprobs:
|
||||
return logprobs["token_logprobs"]
|
||||
return None
|
||||
```
|
||||
|
||||
### Register and Use
|
||||
|
||||
```python
|
||||
from lm_eval import evaluator
|
||||
from my_api import MyCustomAPI
|
||||
|
||||
model = MyCustomAPI(
|
||||
base_url="https://api.example.com/v1",
|
||||
api_key="your-key"
|
||||
)
|
||||
|
||||
results = evaluator.simple_evaluate(
|
||||
model=model,
|
||||
tasks=["mmlu", "gsm8k"],
|
||||
num_fewshot=5,
|
||||
batch_size="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Comparing API and Open Models
|
||||
|
||||
### Side-by-Side Evaluation
|
||||
|
||||
```bash
|
||||
# Evaluate OpenAI GPT-4
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/gpt4.json
|
||||
|
||||
# Evaluate open Llama 2 70B
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-70b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/llama2-70b.json
|
||||
|
||||
# Compare results
|
||||
python scripts/compare_results.py \
|
||||
results/gpt4.json \
|
||||
results/llama2-70b.json
|
||||
```
|
||||
|
||||
### Typical Comparisons
|
||||
|
||||
| Model | MMLU | GSM8K | HumanEval | Cost |
|
||||
|-------|------|-------|-----------|------|
|
||||
| GPT-4 Turbo | 86.4% | 92.0% | 67.0% | $$$$ |
|
||||
| Claude 3 Opus | 86.8% | 95.0% | 84.9% | $$$$ |
|
||||
| GPT-3.5 Turbo | 70.0% | 57.1% | 48.1% | $$ |
|
||||
| Llama 2 70B | 68.9% | 56.8% | 29.9% | Free (self-host) |
|
||||
| Mixtral 8x7B | 70.6% | 58.4% | 40.2% | Free (self-host) |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
Respect API rate limits:
|
||||
```bash
# Lower num_concurrent and raise timeout to stay under rate limits
lm_eval --model openai-chat-completions \
  --model_args \
model=gpt-4-turbo,\
num_concurrent=3,\
timeout=120 \
  --tasks mmlu
```

### Reproducibility

Set temperature to 0 for deterministic results:
```bash
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo \
  --tasks mmlu \
  --gen_kwargs temperature=0.0
```

Or use `seed` for sampling:
```bash
lm_eval --model anthropic-chat \
  --model_args model=claude-3-5-sonnet-20241022 \
  --tasks gsm8k \
  --gen_kwargs temperature=0.7,seed=42
```

### Caching

API models automatically cache responses to avoid redundant calls:
```bash
# First run: makes API calls
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo \
  --tasks mmlu \
  --limit 100

# Second run: uses cache (instant, free)
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo \
  --tasks mmlu \
  --limit 100
```

Cache location: `~/.cache/lm_eval/`
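
If you want the cache in an explicit location (for example, on shared storage), the harness also accepts a `--use_cache` path pointing at a SQLite response cache; treat the exact behavior as version-dependent:

```bash
# Cache responses under an explicit path (flag behavior may vary by version)
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo \
  --tasks mmlu \
  --use_cache ./eval_cache \
  --limit 100
```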

### Error Handling

APIs can fail. Use retries:
```bash
lm_eval --model openai-chat-completions \
  --model_args \
model=gpt-4-turbo,\
max_retries=5,\
timeout=120 \
  --tasks mmlu
```

## Troubleshooting

### "Authentication failed"

Check API key:
```bash
echo $OPENAI_API_KEY     # Should print sk-...
echo $ANTHROPIC_API_KEY  # Should print sk-ant-...
```

### "Rate limit exceeded"

Reduce concurrency:
```bash
--model_args num_concurrent=1
```

Or add delays between requests.

### "Timeout error"

Increase timeout:
```bash
--model_args timeout=180
```

### "Model not found"

For local APIs, verify the server is running:
```bash
curl http://localhost:8000/v1/models
```

### Cost Runaway

Use `--limit` for testing:
```bash
lm_eval --model openai-chat-completions \
  --model_args model=gpt-4-turbo \
  --tasks mmlu \
  --limit 50  # Only 50 samples
```

## Advanced Features

### Custom Headers

```bash
lm_eval --model local-completions \
  --model_args \
base_url=http://api.example.com/v1,\
header="Authorization: Bearer token,X-Custom: value"
```

### Disable SSL Verification (Development Only)

```bash
lm_eval --model local-completions \
  --model_args \
base_url=https://localhost:8000/v1,\
verify_certificate=false
```

### Custom Tokenizer

```bash
lm_eval --model openai-chat-completions \
  --model_args \
model=gpt-4-turbo,\
tokenizer=gpt2,\
tokenizer_backend=huggingface
```

## References

- OpenAI API: https://platform.openai.com/docs/api-reference
- Anthropic API: https://docs.anthropic.com/claude/reference
- TemplateAPI: `lm_eval/models/api_models.py`
- OpenAI models: `lm_eval/models/openai_completions.py`
- Anthropic models: `lm_eval/models/anthropic_llms.py`

@@ -0,0 +1,488 @@

# Benchmark Guide

Complete guide to all 60+ evaluation tasks in lm-evaluation-harness, what they measure, and how to interpret results.

## Overview

The lm-evaluation-harness includes 60+ benchmarks spanning:
- Language understanding (MMLU, GLUE)
- Mathematical reasoning (GSM8K, MATH)
- Code generation (HumanEval, MBPP)
- Instruction following (IFEval, AlpacaEval)
- Long-context understanding (LongBench)
- Multilingual capabilities (AfroBench, NorEval)
- Reasoning (BBH, ARC)
- Truthfulness (TruthfulQA)

**List all tasks**:
```bash
lm_eval --tasks list
```

## Major Benchmarks

### MMLU (Massive Multitask Language Understanding)

**What it measures**: Broad knowledge across 57 subjects (STEM, humanities, social sciences, law).

**Task variants**:
- `mmlu`: Original 57-subject benchmark
- `mmlu_pro`: More challenging version with reasoning-focused questions
- `mmlu_prox`: Multilingual extension

**Format**: Multiple choice (4 options)

**Example**:
```
Question: What is the capital of France?
A. Berlin
B. Paris
C. London
D. Madrid
Answer: B
```

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu \
  --num_fewshot 5
```

**Interpretation**:
- Random: 25% (chance)
- GPT-3 (175B): 43.9%
- GPT-4: 86.4%
- Human expert: ~90%

**Good for**: Assessing general knowledge and domain expertise.

### GSM8K (Grade School Math 8K)

**What it measures**: Mathematical reasoning on grade-school level word problems.

**Task variants**:
- `gsm8k`: Base task
- `gsm8k_cot`: With chain-of-thought prompting
- `gsm_plus`: Adversarial variant with perturbations

**Format**: Free-form generation, extract numerical answer

**Example**:
```
Question: A baker made 200 cookies. He sold 3/5 of them in the morning and 1/4 of the remaining in the afternoon. How many cookies does he have left?
Answer: 60
```
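
Scoring depends on pulling the final number out of the generation. The harness does this with task-defined filters; a minimal standalone heuristic (not the harness's exact filter) looks like:

```python
import re

def extract_final_number(generation: str):
    """Return the last number in a generation (a common GSM8K scoring heuristic)."""
    matches = re.findall(r"-?\d[\d,]*\.?\d*", generation)
    return matches[-1].replace(",", "") if matches else None

print(extract_final_number("He sold 120, then 20 more. 80 - 20 = 60. Answer: 60"))  # 60
```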

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks gsm8k \
  --num_fewshot 5
```

**Interpretation**:
- Random: ~0%
- GPT-3 (175B): 17.0%
- GPT-4: 92.0%
- Llama 2 70B: 56.8%

**Good for**: Testing multi-step reasoning and arithmetic.

### HumanEval

**What it measures**: Python code generation from docstrings (functional correctness).

**Task variants**:
- `humaneval`: Standard benchmark
- `humaneval_instruct`: For instruction-tuned models

**Format**: Code generation, execution-based evaluation

**Example**:
```python
def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
```

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=codellama/CodeLlama-7b-hf \
  --tasks humaneval \
  --batch_size 1
```

**Interpretation**:
- Random: 0%
- GPT-3 (175B): 0%
- Codex: 28.8%
- GPT-4: 67.0%
- Code Llama 34B: 53.7%

**Good for**: Evaluating code generation capabilities.

### BBH (BIG-Bench Hard)

**What it measures**: 23 challenging reasoning tasks where models previously failed to beat humans.

**Categories**:
- Logical reasoning
- Math word problems
- Social understanding
- Algorithmic reasoning

**Format**: Multiple choice and free-form

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks bbh \
  --num_fewshot 3
```

**Interpretation**:
- Random: ~25%
- GPT-3 (175B): 33.9%
- PaLM 540B: 58.3%
- GPT-4: 86.7%

**Good for**: Testing advanced reasoning capabilities.

### IFEval (Instruction-Following Evaluation)

**What it measures**: Ability to follow specific, verifiable instructions.

**Instruction types**:
- Format constraints (e.g., "answer in 3 sentences")
- Length constraints (e.g., "use at least 100 words")
- Content constraints (e.g., "include the word 'banana'")
- Structural constraints (e.g., "use bullet points")

**Format**: Free-form generation with rule-based verification

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
  --tasks ifeval \
  --batch_size auto
```

**Interpretation**:
- Measures: Instruction adherence (not quality)
- GPT-4: 86% instruction following
- Claude 2: 84%

**Good for**: Evaluating chat/instruct models.

### GLUE (General Language Understanding Evaluation)

**What it measures**: Natural language understanding across 9 tasks.

**Tasks**:
- `cola`: Grammatical acceptability
- `sst2`: Sentiment analysis
- `mrpc`: Paraphrase detection
- `qqp`: Question pairs
- `stsb`: Semantic similarity
- `mnli`: Natural language inference
- `qnli`: Question answering NLI
- `rte`: Recognizing textual entailment
- `wnli`: Winograd schemas

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=bert-base-uncased \
  --tasks glue \
  --num_fewshot 0
```

**Interpretation**:
- BERT Base: 78.3 (GLUE score)
- RoBERTa Large: 88.5
- Human baseline: 87.1

**Good for**: Encoder-only models, fine-tuning baselines.

### LongBench

**What it measures**: Long-context understanding (4K-32K tokens).

**21 tasks covering**:
- Single-document QA
- Multi-document QA
- Summarization
- Few-shot learning
- Code completion
- Synthetic tasks

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks longbench \
  --batch_size 1
```

**Interpretation**:
- Tests context utilization
- Many models struggle beyond 4K tokens
- GPT-4 Turbo: 54.3%

**Good for**: Evaluating long-context models.

## Additional Benchmarks

### TruthfulQA

**What it measures**: Model's propensity to be truthful vs. generate plausible-sounding falsehoods.

**Format**: Multiple choice with 4-5 options

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks truthfulqa_mc2 \
  --batch_size auto
```

**Interpretation**:
- Larger models often score worse (more convincing lies)
- GPT-3: 58.8%
- GPT-4: 59.0%
- Human: ~94%

### ARC (AI2 Reasoning Challenge)

**What it measures**: Grade-school science questions.

**Variants**:
- `arc_easy`: Easier questions
- `arc_challenge`: Harder questions requiring reasoning

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks arc_challenge \
  --num_fewshot 25
```

**Interpretation**:
- ARC-Easy: Most models >80%
- ARC-Challenge random: 25%
- GPT-4: 96.3%

### HellaSwag

**What it measures**: Commonsense reasoning about everyday situations.

**Format**: Choose most plausible continuation

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks hellaswag \
  --num_fewshot 10
```

**Interpretation**:
- Random: 25%
- GPT-3: 78.9%
- Llama 2 70B: 85.3%

### WinoGrande

**What it measures**: Commonsense reasoning via pronoun resolution.

**Example**:
```
The trophy doesn't fit in the brown suitcase because _ is too large.
A. the trophy
B. the suitcase
```

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks winogrande \
  --num_fewshot 5
```

### PIQA

**What it measures**: Physical commonsense reasoning.

**Example**: "To clean a keyboard, use compressed air or..."

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks piqa
```

## Multilingual Benchmarks

### AfroBench

**What it measures**: Performance across 64 African languages.

**15 tasks**: NLU, text generation, knowledge, QA, math reasoning

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks afrobench
```

### NorEval

**What it measures**: Norwegian language understanding (9 task categories).

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=NbAiLab/nb-gpt-j-6B \
  --tasks noreval
```

## Domain-Specific Benchmarks

### MATH

**What it measures**: High-school competition math problems.

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks math \
  --num_fewshot 4
```

**Interpretation**:
- Very challenging
- GPT-4: 42.5%
- Minerva 540B: 33.6%

### MBPP (Mostly Basic Python Problems)

**What it measures**: Python programming from natural language descriptions.

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=codellama/CodeLlama-7b-hf \
  --tasks mbpp \
  --batch_size 1
```

### DROP

**What it measures**: Reading comprehension requiring discrete reasoning.

**Command**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks drop
```

## Benchmark Selection Guide

### For General Purpose Models

Run this suite:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu,gsm8k,hellaswag,arc_challenge,truthfulqa_mc2 \
  --num_fewshot 5
```

### For Code Models

```bash
lm_eval --model hf \
  --model_args pretrained=codellama/CodeLlama-7b-hf \
  --tasks humaneval,mbpp \
  --batch_size 1
```

### For Chat/Instruct Models

```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
  --tasks ifeval,mmlu,gsm8k_cot \
  --batch_size auto
```

### For Long Context Models

```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-3.1-8B \
  --tasks longbench \
  --batch_size 1
```

## Interpreting Results

### Understanding Metrics

**Accuracy**: Percentage of correct answers (most common)

**Exact Match (EM)**: Requires exact string match (strict)

**F1 Score**: Balances precision and recall

**BLEU/ROUGE**: Text generation similarity

**Pass@k**: Percentage passing when generating k samples
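
The standard unbiased pass@k estimator (Chen et al., 2021) is easy to compute from n sampled completions of which c pass:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: 1 - C(n-c, k) / C(n, k), given c of n samples pass."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(n=20, c=5, k=1))   # 0.25
print(pass_at_k(n=20, c=5, k=10))  # much higher with more attempts
```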

### Typical Score Ranges

| Model Size | MMLU | GSM8K | HumanEval | HellaSwag |
|------------|------|-------|-----------|-----------|
| 7B | 40-50% | 10-20% | 5-15% | 70-80% |
| 13B | 45-55% | 20-35% | 15-25% | 75-82% |
| 70B | 60-70% | 50-65% | 35-50% | 82-87% |
| GPT-4 | 86% | 92% | 67% | 95% |

### Red Flags

- **All tasks at random chance**: Model not trained properly
- **Exact 0% on generation tasks**: Likely format/parsing issue
- **Huge variance across runs**: Check seed/sampling settings
- **Better than GPT-4 on everything**: Likely contamination

## Best Practices

1. **Always report few-shot setting**: 0-shot, 5-shot, etc.
2. **Run multiple seeds**: Report mean ± std
3. **Check for data contamination**: Search training data for benchmark examples
4. **Compare to published baselines**: Validate your setup
5. **Report all hyperparameters**: Model, batch size, max tokens, temperature

## References

- Task list: `lm_eval --tasks list`
- Task README: `lm_eval/tasks/README.md`
- Papers: See individual benchmark papers

@@ -0,0 +1,602 @@

# Custom Tasks

Complete guide to creating domain-specific evaluation tasks in lm-evaluation-harness.

## Overview

Custom tasks allow you to evaluate models on your own datasets and metrics. Tasks are defined using YAML configuration files with optional Python utilities for complex logic.

**Why create custom tasks**:
- Evaluate on proprietary/domain-specific data
- Test specific capabilities not covered by existing benchmarks
- Create evaluation pipelines for internal models
- Reproduce research experiments

## Quick Start

### Minimal Custom Task

Create `my_tasks/simple_qa.yaml`:

```yaml
task: simple_qa
dataset_path: data/simple_qa.jsonl
output_type: generate_until
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
```

**Run it**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks simple_qa \
  --include_path my_tasks/
```

## Task Configuration Reference

### Essential Fields

```yaml
# Task identification
task: my_custom_task   # Unique task name (required)
task_alias: "My Task"  # Display name
tag:                   # Tags for grouping
  - custom
  - domain_specific

# Dataset configuration
dataset_path: data/my_data.jsonl  # HuggingFace dataset or local path
dataset_name: default             # Subset name (if applicable)
training_split: train
validation_split: validation
test_split: test

# Evaluation configuration
output_type: generate_until  # or loglikelihood, multiple_choice
num_fewshot: 5               # Number of few-shot examples
batch_size: auto             # Batch size

# Prompt templates (Jinja2)
doc_to_text: "Question: {{question}}"
doc_to_target: "{{answer}}"

# Metrics
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true

# Metadata
metadata:
  version: 1.0
```

### Output Types

**`generate_until`**: Free-form generation
```yaml
output_type: generate_until
generation_kwargs:
  max_gen_toks: 256
  until:
    - "\n"
    - "."
  temperature: 0.0
```

**`loglikelihood`**: Compute log probability of targets
```yaml
output_type: loglikelihood
# Used for perplexity, classification
```

**`multiple_choice`**: Choose from options
```yaml
output_type: multiple_choice
doc_to_choice: "{{choices}}"  # List of choices
```

## Data Formats

### Local JSONL File

`data/my_data.jsonl`:
```json
{"question": "What is 2+2?", "answer": "4"}
{"question": "Capital of France?", "answer": "Paris"}
```

**Task config**:
```yaml
dataset_path: data/my_data.jsonl
dataset_kwargs:
  data_files:
    test: data/my_data.jsonl
```

### HuggingFace Dataset

```yaml
dataset_path: squad
dataset_name: plain_text
test_split: validation
```

### CSV File

`data/my_data.csv`:
```csv
question,answer,category
What is 2+2?,4,math
Capital of France?,Paris,geography
```

**Task config**:
```yaml
dataset_path: data/my_data.csv
dataset_kwargs:
  data_files:
    test: data/my_data.csv
```

## Prompt Engineering

### Simple Template

```yaml
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}"
```

### Conditional Logic

```yaml
doc_to_text: |
  {% if context %}
  Context: {{context}}
  {% endif %}
  Question: {{question}}
  Answer:
```

### Multiple Choice

```yaml
doc_to_text: |
  Question: {{question}}
  A. {{choices[0]}}
  B. {{choices[1]}}
  C. {{choices[2]}}
  D. {{choices[3]}}
  Answer:

doc_to_target: "{{ 'ABCD'[answer_idx] }}"
doc_to_choice: ["A", "B", "C", "D"]
```

### Few-Shot Formatting

```yaml
fewshot_delimiter: "\n\n"  # Between examples
target_delimiter: " "      # Between question and answer
doc_to_text: "Q: {{question}}"
doc_to_target: "A: {{answer}}"
```
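
With the JSONL examples from earlier as shots, these delimiters render a one-shot prompt roughly like this (the model is then expected to continue with `A: Paris`):

```
Q: What is 2+2? A: 4

Q: Capital of France?
```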

## Custom Python Functions

For complex logic, use Python functions in `utils.py`.

### Create `my_tasks/utils.py`

```python
def process_docs(dataset):
    """Preprocess documents."""
    def _process(doc):
        # Custom preprocessing
        doc["question"] = doc["question"].strip().lower()
        return doc

    return dataset.map(_process)


def doc_to_text(doc):
    """Custom prompt formatting."""
    context = doc.get("context", "")
    question = doc["question"]

    if context:
        return f"Context: {context}\nQuestion: {question}\nAnswer:"
    return f"Question: {question}\nAnswer:"


def doc_to_target(doc):
    """Custom target extraction."""
    return doc["answer"].strip().lower()


def aggregate_scores(items):
    """Custom metric aggregation."""
    correct = sum(1 for item in items if item == 1.0)
    total = len(items)
    return correct / total if total > 0 else 0.0
```

### Use in Task Config

```yaml
task: my_custom_task
dataset_path: data/my_data.jsonl

# Use Python functions
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target

metric_list:
  - metric: exact_match
    aggregation: !function utils.aggregate_scores
    higher_is_better: true
```

## Real-World Examples

### Example 1: Domain QA Task

**Goal**: Evaluate medical question answering.

`medical_qa/medical_qa.yaml`:
```yaml
task: medical_qa
dataset_path: data/medical_qa.jsonl
output_type: generate_until
num_fewshot: 3

doc_to_text: |
  Medical Question: {{question}}
  Context: {{context}}
  Answer (be concise):

doc_to_target: "{{answer}}"

generation_kwargs:
  max_gen_toks: 100
  until:
    - "\n\n"
  temperature: 0.0

metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
  - metric: !function utils.medical_f1
    aggregation: mean
    higher_is_better: true

filter_list:
  - name: lowercase
    filter:
      - function: lowercase
      - function: remove_whitespace

metadata:
  version: 1.0
  domain: medical
```

`medical_qa/utils.py`:
```python
import re


def medical_f1(predictions, references):
    """Custom F1 over extracted medical terms."""
    pred_terms = set(extract_medical_terms(predictions[0]))
    ref_terms = set(extract_medical_terms(references[0]))

    if not pred_terms and not ref_terms:
        return 1.0
    if not pred_terms or not ref_terms:
        return 0.0

    tp = len(pred_terms & ref_terms)
    fp = len(pred_terms - ref_terms)
    fn = len(ref_terms - pred_terms)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0


def extract_medical_terms(text):
    """Extract medical terminology."""
    # Placeholder heuristic (capitalized terms) -- swap in real medical NER for production
    return re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)*\b', text)
```

### Example 2: Code Evaluation

`code_eval/python_challenges.yaml`:
```yaml
task: python_challenges
dataset_path: data/python_problems.jsonl
output_type: generate_until
num_fewshot: 0

doc_to_text: |
  Write a Python function to solve:
  {{problem_statement}}

  Function signature:
  {{function_signature}}

doc_to_target: "{{canonical_solution}}"

generation_kwargs:
  max_gen_toks: 512
  until:
    - "\n\nclass"
    - "\n\ndef"
  temperature: 0.2

metric_list:
  - metric: !function utils.execute_code
    aggregation: mean
    higher_is_better: true

process_results: !function utils.process_code_results

metadata:
  version: 1.0
```

`code_eval/utils.py`:
```python
import json
import subprocess


def execute_code(predictions, references):
    """Execute generated code against test cases."""
    generated_code = predictions[0]
    test_cases = json.loads(references[0])

    try:
        # Execute code with test cases
        for test_input, expected_output in test_cases:
            result = execute_with_timeout(generated_code, test_input, timeout=5)
            if result != expected_output:
                return 0.0
        return 1.0
    except Exception:
        return 0.0


def execute_with_timeout(code, input_data, timeout=5):
    """Run code in a subprocess with input on stdin and a hard timeout."""
    result = subprocess.run(
        ["python", "-c", code],
        input=str(input_data),
        capture_output=True,
        text=True,
        timeout=timeout,  # raises subprocess.TimeoutExpired on hang
    )
    return result.stdout.strip()


def process_code_results(doc, results):
    """Process code execution results."""
    return {
        "passed": results[0] == 1.0,
        "generated_code": results[1]
    }
```

### Example 3: Instruction Following

`instruction_eval/instruction_eval.yaml`:
```yaml
task: instruction_following
dataset_path: data/instructions.jsonl
output_type: generate_until
num_fewshot: 0

doc_to_text: |
  Instruction: {{instruction}}
  {% if constraints %}
  Constraints: {{constraints}}
  {% endif %}
  Response:

doc_to_target: "{{expected_response}}"

generation_kwargs:
  max_gen_toks: 256
  temperature: 0.7

metric_list:
  - metric: !function utils.check_constraints
    aggregation: mean
    higher_is_better: true
  - metric: !function utils.semantic_similarity
    aggregation: mean
    higher_is_better: true

process_docs: !function utils.add_constraint_checkers
```

`instruction_eval/utils.py`:
```python
import json

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2')


def check_constraints(predictions, references):
    """Check if the response satisfies its constraints."""
    response = predictions[0]
    constraints = json.loads(references[0])

    satisfied = 0
    total = len(constraints)

    for constraint in constraints:
        if verify_constraint(response, constraint):
            satisfied += 1

    return satisfied / total if total > 0 else 1.0


def verify_constraint(response, constraint):
    """Verify single constraint."""
    if constraint["type"] == "length":
        return len(response.split()) >= constraint["min_words"]
    elif constraint["type"] == "contains":
        return constraint["keyword"] in response.lower()
    # Add more constraint types
    return True


def semantic_similarity(predictions, references):
    """Compute semantic similarity."""
    pred_embedding = model.encode(predictions[0])
    ref_embedding = model.encode(references[0])
    return float(util.cos_sim(pred_embedding, ref_embedding))


def add_constraint_checkers(dataset):
    """Parse constraints into verifiable format."""
    def _parse(doc):
        # parse_constraints is assumed to be defined alongside these helpers
        doc["parsed_constraints"] = parse_constraints(doc.get("constraints", ""))
        return doc
    return dataset.map(_parse)
```

## Advanced Features

### Output Filtering

```yaml
filter_list:
  - name: extract_answer
    filter:
      - function: regex
        regex_pattern: "Answer: (.*)"
        group: 1
      - function: lowercase
      - function: strip_whitespace
```
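
Conceptually, this chain transforms a raw generation as follows (a plain-Python sketch, not the harness's internal implementation):

```python
import re

def extract_answer(generation: str) -> str:
    """Mirror the filter chain: regex capture, then lowercase, then strip."""
    match = re.search(r"Answer: (.*)", generation)
    text = match.group(1) if match else generation
    return text.lower().strip()

print(extract_answer("Let me think...\nAnswer:  Paris "))  # "paris"
```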

### Multiple Metrics

```yaml
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
  - metric: bleu
    aggregation: mean
    higher_is_better: true
```

### Task Groups

Create `my_tasks/_default.yaml`:
```yaml
group: my_eval_suite
task:
  - simple_qa
  - medical_qa
  - python_challenges
```

**Run entire suite**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks my_eval_suite \
  --include_path my_tasks/
```

## Testing Your Task

### Validate Configuration

```bash
# Test task loading
lm_eval --tasks my_custom_task --include_path my_tasks/ --limit 0

# Run on 5 samples
lm_eval --model hf \
  --model_args pretrained=gpt2 \
  --tasks my_custom_task \
  --include_path my_tasks/ \
  --limit 5
```

### Debug Mode

```bash
lm_eval --model hf \
  --model_args pretrained=gpt2 \
  --tasks my_custom_task \
  --include_path my_tasks/ \
  --limit 1 \
  --log_samples  # Save input/output samples
```

## Best Practices

1. **Start simple**: Test with minimal config first
2. **Version your tasks**: Use `metadata.version`
3. **Document your metrics**: Explain custom metrics in comments
4. **Test with multiple models**: Ensure robustness
5. **Validate on known examples**: Include sanity checks
6. **Use filters carefully**: Can hide errors
7. **Handle edge cases**: Empty strings, missing fields

## Common Patterns

### Classification Task

```yaml
output_type: loglikelihood
doc_to_text: "Text: {{text}}\nLabel:"
doc_to_target: " {{label}}"  # Space prefix important!
metric_list:
  - metric: acc
    aggregation: mean
```
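
The leading space matters because BPE-style tokenizers fold a preceding space into the next token, so ` positive` and `positive` map to different token ids. A quick check (exact ids vary by tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.encode(" positive"))  # leading-space variant ("Ġpositive")
print(tok.encode("positive"))   # different tokenization without the space
```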

### Perplexity Evaluation

```yaml
output_type: loglikelihood_rolling
doc_to_text: "{{text}}"
metric_list:
  - metric: perplexity
    aggregation: perplexity
```

### Ranking Task

```yaml
output_type: loglikelihood
doc_to_text: "Query: {{query}}\nPassage: {{passage}}\nRelevant:"
doc_to_target: [" Yes", " No"]
metric_list:
  - metric: acc
    aggregation: mean
```

## Troubleshooting

**"Task not found"**: Check `--include_path` and task name

**Empty results**: Verify `doc_to_text` and `doc_to_target` templates

**Metric errors**: Ensure metric names are correct (exact_match, not exact-match)

**Filter issues**: Test filters with `--log_samples`

**Python function not found**: Check `!function module.function_name` syntax

## References

- Task system: EleutherAI/lm-evaluation-harness docs
- Example tasks: `lm_eval/tasks/` directory
- TaskConfig: `lm_eval/api/task.py`

@@ -0,0 +1,519 @@

# Distributed Evaluation

Guide to running evaluation across multiple GPUs using data parallelism and tensor/pipeline parallelism.

## Overview

Distributed evaluation speeds up benchmarking by:
- **Data Parallelism**: Split evaluation samples across GPUs (each GPU has full model copy)
- **Tensor Parallelism**: Split model weights across GPUs (for large models)
- **Pipeline Parallelism**: Split model layers across GPUs (for very large models)

**When to use**:
- Data Parallel: Model fits on single GPU, want faster evaluation
- Tensor/Pipeline Parallel: Model too large for single GPU

## HuggingFace Models (`hf`)

### Data Parallelism (Recommended)

Each GPU loads a full copy of the model and processes a subset of evaluation data.

**Single Node (8 GPUs)**:
```bash
accelerate launch --multi_gpu --num_processes 8 \
  -m lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
  --tasks mmlu,gsm8k,hellaswag \
  --batch_size 16
```

**Speedup**: Near-linear (8 GPUs = ~8× faster)

**Memory**: Each GPU needs full model (7B model ≈ 14GB × 8 = 112GB total)

### Tensor Parallelism (Model Sharding)

Split model weights across GPUs for models too large for single GPU.

**Without accelerate launcher**:
```bash
lm_eval --model hf \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
  --tasks mmlu,gsm8k \
  --batch_size 8
```

**With 8 GPUs**: 70B model (140GB) / 8 = 17.5GB per GPU ✅

**Advanced sharding**:
```bash
lm_eval --model hf \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
device_map_option=auto,\
max_memory_per_gpu=40GB,\
max_cpu_memory=100GB,\
dtype=bfloat16 \
  --tasks mmlu
```

**Options**:
- `device_map_option`: `"auto"` (default), `"balanced"`, `"balanced_low_0"`
- `max_memory_per_gpu`: Max memory per GPU (e.g., `"40GB"`)
- `max_cpu_memory`: Max CPU memory for offloading
- `offload_folder`: Disk offloading directory

### Combined Data + Tensor Parallelism

Use both for very large models.

**Example: 70B model on 16 GPUs (2 copies, 8 GPUs each)**:
```bash
accelerate launch --multi_gpu --num_processes 2 \
  -m lm_eval --model hf \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
parallelize=True,\
dtype=bfloat16 \
  --tasks mmlu \
  --batch_size 8
```

**Result**: 2× speedup from data parallelism, 70B model fits via tensor parallelism

### Configuration with `accelerate config`

Create `~/.cache/huggingface/accelerate/default_config.yaml`:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 8
gpu_ids: all
mixed_precision: bf16
```

**Then run**:
```bash
accelerate launch -m lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu
```

## vLLM Models (`vllm`)

vLLM provides highly optimized distributed inference.

### Tensor Parallelism

**Single Node (4 GPUs)**:
```bash
lm_eval --model vllm \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.9 \
  --tasks mmlu,gsm8k \
  --batch_size auto
```

**Memory**: 70B model split across 4 GPUs = ~35GB per GPU

### Data Parallelism

**Multiple model replicas**:
```bash
lm_eval --model vllm \
  --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
data_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.8 \
  --tasks hellaswag,arc_challenge \
  --batch_size auto
```

**Result**: 4 model replicas = 4× throughput

### Combined Tensor + Data Parallelism

**Example: 8 GPUs = 4 TP × 2 DP**:
```bash
lm_eval --model vllm \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
data_parallel_size=2,\
dtype=auto,\
gpu_memory_utilization=0.85 \
  --tasks mmlu \
  --batch_size auto
```

**Result**: 70B model fits (TP=4), 2× speedup (DP=2)

### Multi-Node vLLM

vLLM doesn't natively support multi-node. Use Ray:

```bash
# Start Ray cluster
ray start --head --port=6379

# Run evaluation
lm_eval --model vllm \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto \
  --tasks mmlu
```

## NVIDIA NeMo Models (`nemo_lm`)

### Data Replication

**8 replicas on 8 GPUs**:
```bash
torchrun --nproc-per-node=8 --no-python \
  lm_eval --model nemo_lm \
  --model_args \
path=/path/to/model.nemo,\
devices=8 \
  --tasks hellaswag,arc_challenge \
  --batch_size 32
```

**Speedup**: Near-linear (8× faster)

### Tensor Parallelism

**4-way tensor parallelism**:
```bash
torchrun --nproc-per-node=4 --no-python \
  lm_eval --model nemo_lm \
  --model_args \
path=/path/to/70b_model.nemo,\
devices=4,\
tensor_model_parallel_size=4 \
  --tasks mmlu,gsm8k \
  --batch_size 16
```

### Pipeline Parallelism

**2 TP × 2 PP on 4 GPUs**:
```bash
torchrun --nproc-per-node=4 --no-python \
  lm_eval --model nemo_lm \
  --model_args \
path=/path/to/model.nemo,\
devices=4,\
tensor_model_parallel_size=2,\
pipeline_model_parallel_size=2 \
  --tasks mmlu \
  --batch_size 8
```

**Constraint**: `devices = TP × PP`

### Multi-Node NeMo

Currently not supported by lm-evaluation-harness.

## SGLang Models (`sglang`)

### Tensor Parallelism

```bash
lm_eval --model sglang \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tp_size=4,\
dtype=auto \
  --tasks gsm8k \
  --batch_size auto
```

### Data Parallelism (Deprecated)

**Note**: SGLang is deprecating data parallelism. Use tensor parallelism instead.

```bash
lm_eval --model sglang \
  --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dp_size=4,\
dtype=auto \
  --tasks mmlu
```

## Performance Comparison

### 70B Model Evaluation (MMLU, 5-shot)

| Method | GPUs | Time | Memory/GPU | Notes |
|--------|------|------|------------|-------|
| HF (no parallel) | 1 | 8 hours | 140GB (OOM) | Won't fit |
| HF (TP=8) | 8 | 2 hours | 17.5GB | Slower, fits |
| HF (DP=8) | 8 | 1 hour | 140GB (OOM) | Won't fit |
| vLLM (TP=4) | 4 | 30 min | 35GB | Fast! |
| vLLM (TP=4, DP=2) | 8 | 15 min | 35GB | Fastest |

### 7B Model Evaluation (Multiple Tasks)

| Method | GPUs | Time | Speedup |
|--------|------|------|---------|
| HF (single) | 1 | 4 hours | 1× |
| HF (DP=4) | 4 | 1 hour | 4× |
| HF (DP=8) | 8 | 30 min | 8× |
| vLLM (DP=8) | 8 | 15 min | 16× |

**Takeaway**: vLLM is significantly faster than HuggingFace for inference.

## Choosing Parallelism Strategy

### Decision Tree

```
Model fits on single GPU?
├─ YES: Use data parallelism
│   ├─ HF: accelerate launch --multi_gpu --num_processes N
│   └─ vLLM: data_parallel_size=N (fastest)
│
└─ NO: Use tensor/pipeline parallelism
    ├─ Model < 70B:
    │   └─ vLLM: tensor_parallel_size=4
    ├─ Model 70-175B:
    │   ├─ vLLM: tensor_parallel_size=8
    │   └─ Or HF: parallelize=True
    └─ Model > 175B:
        └─ Contact framework authors
```

### Memory Estimation

**Rule of thumb**:
```
Memory (GB) = Parameters (B) × Precision (bytes) × 1.2 (overhead)
```

**Examples**:
- 7B FP16: 7 × 2 × 1.2 = 16.8GB ✅ Fits A100 40GB
- 13B FP16: 13 × 2 × 1.2 = 31.2GB ✅ Fits A100 40GB
- 70B FP16: 70 × 2 × 1.2 = 168GB ❌ Need TP=4 or TP=8
- 70B BF16: 70 × 2 × 1.2 = 168GB (same as FP16)

**With tensor parallelism**:
```
Memory per GPU = Total Memory / TP
```

- 70B on 4 GPUs: 168GB / 4 = 42GB per GPU ✅
- 70B on 8 GPUs: 168GB / 8 = 21GB per GPU ✅
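
The same rule of thumb as a helper (a sketch; it covers weights only and ignores KV cache and activation memory):

```python
def per_gpu_memory_gb(params_b: float, bytes_per_param: int = 2,
                      overhead: float = 1.2, tp: int = 1) -> float:
    """Rule-of-thumb weight memory per GPU; excludes KV cache and activations."""
    return params_b * bytes_per_param * overhead / tp

print(per_gpu_memory_gb(70, tp=8))  # 21.0
print(per_gpu_memory_gb(7))         # 16.8
```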

## Multi-Node Evaluation

### HuggingFace with SLURM

**Submit job**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1

srun accelerate launch --multi_gpu \
  --num_processes $((SLURM_NNODES * 8)) \
  -m lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu,gsm8k,hellaswag \
  --batch_size 16
```

**Submit**:
```bash
sbatch eval_job.sh
```

### Manual Multi-Node Setup

**On each node, run**:
```bash
accelerate launch \
  --multi_gpu \
  --num_machines 4 \
  --num_processes 32 \
  --main_process_ip $MASTER_IP \
  --main_process_port 29500 \
  --machine_rank $NODE_RANK \
  -m lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu
```

**Environment variables**:
- `MASTER_IP`: IP of rank 0 node
- `NODE_RANK`: 0, 1, 2, 3 for each node

## Best Practices

### 1. Start Small

Test on small sample first:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-70b-hf,parallelize=True \
  --tasks mmlu \
  --limit 100  # Just 100 samples
```

### 2. Monitor GPU Usage

```bash
# Terminal 1: Run evaluation
lm_eval --model hf ...

# Terminal 2: Monitor
watch -n 1 nvidia-smi
```

Look for:
- GPU utilization > 90%
- Memory usage stable
- All GPUs active

### 3. Optimize Batch Size

```bash
# Auto batch size (recommended)
--batch_size auto

# Or tune manually
--batch_size 16  # Start here
--batch_size 32  # Increase if memory allows
```

### 4. Use Mixed Precision

```bash
--model_args dtype=bfloat16  # Faster, less memory
```

### 5. Check Communication

For data parallelism, check network bandwidth:
```bash
# Should see InfiniBand or high-speed network
nvidia-smi topo -m
```

## Troubleshooting

### "CUDA out of memory"

**Solutions**:
1. Increase tensor parallelism:
```bash
--model_args tensor_parallel_size=8  # Was 4
```

2. Reduce batch size:
```bash
--batch_size 4  # Was 16
```

3. Lower precision:
```bash
--model_args dtype=int8  # Quantization
```

### "NCCL error" or Hanging

**Check**:
1. All GPUs visible: `nvidia-smi`
2. NCCL installed: `python -c "import torch; print(torch.cuda.nccl.version())"`
3. Network connectivity between nodes

**Fix**:
```bash
export NCCL_DEBUG=INFO    # Enable debug logging
export NCCL_IB_DISABLE=0  # Use InfiniBand if available
```

### Slow Evaluation

**Possible causes**:
1. **Data loading bottleneck**: Preprocess dataset
2. **Low GPU utilization**: Increase batch size
3. **Communication overhead**: Reduce parallelism degree

**Profile**:
```bash
lm_eval --model hf \
  --model_args pretrained=meta-llama/Llama-2-7b-hf \
  --tasks mmlu \
  --limit 100 \
  --log_samples  # Check timing
```

### GPUs Imbalanced

**Symptom**: GPU 0 at 100%, others at 50%

**Solution**: Use `device_map_option=balanced`:
```bash
--model_args parallelize=True,device_map_option=balanced
```

## Example Configurations

### Small Model (7B) - Fast Evaluation

```bash
# 8 A100s, data parallel
accelerate launch --multi_gpu --num_processes 8 \
  -m lm_eval --model hf \
  --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dtype=bfloat16 \
  --tasks mmlu,gsm8k,hellaswag,arc_challenge \
  --num_fewshot 5 \
  --batch_size 32

# Time: ~30 minutes
```

### Large Model (70B) - vLLM

```bash
# 8 H100s, tensor parallel
lm_eval --model vllm \
  --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto,\
gpu_memory_utilization=0.9 \
  --tasks mmlu,gsm8k,humaneval \
  --num_fewshot 5 \
  --batch_size auto

# Time: ~1 hour
```

### Very Large Model (175B+)

**Requires specialized setup - contact framework maintainers**

## References

- HuggingFace Accelerate: https://huggingface.co/docs/accelerate/
- vLLM docs: https://docs.vllm.ai/
- NeMo docs: https://docs.nvidia.com/nemo-framework/
- lm-eval distributed guide: `docs/model_guide.md`
386
protected/skills-backup/mlops/evaluation/nemo-curator/SKILL.md
Normal file
@@ -0,0 +1,386 @@

---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [nemo-curator, cudf, dask, rapids]
metadata:
  hermes:
    tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]

---

# NeMo Curator - GPU-Accelerated Data Curation

NVIDIA's toolkit for preparing high-quality training data for LLMs.

## When to use NeMo Curator

**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster

**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes

**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)

## Quick start

### Installation

```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"

# All modalities
uv pip install "nemo-curator[all_cuda12]"

# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```

### Basic text curation pipeline

```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd

# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)

# Quality filtering
def quality_score(doc):
    return len(doc["text"].split()) > 5  # Filter short docs

filtered = ScoreFilter(quality_score)(dataset)

# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)

# Save
deduped.to_parquet("curated_data/")
```

## Data curation pipeline

### Stage 1: Quality filtering

```python
from nemo_curator.filters import (
    WordCountFilter,
    RepeatedLinesFilter,
    UrlRatioFilter,
    NonAlphaNumericFilter
)

# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter

# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))

# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))

# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```

### Stage 2: Deduplication

**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates

# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```

**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates

# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
    id_field="id",
    text_field="text",
    num_hashes=260,  # MinHash parameters
    num_buckets=20,
    hash_method="md5"
)

deduped = fuzzy_dedup(dataset)
```

**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates

# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
    id_field="id",
    text_field="text",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    threshold=0.8  # Cosine similarity threshold
)

deduped = semantic_dedup(dataset)
```
|
||||
|
||||
### Stage 3: PII redaction
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import Modify
|
||||
from nemo_curator.modifiers import PIIRedactor
|
||||
|
||||
# Redact personally identifiable information
|
||||
pii_redactor = PIIRedactor(
|
||||
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
|
||||
anonymize_action="replace" # or "redact"
|
||||
)
|
||||
|
||||
redacted = Modify(pii_redactor)(dataset)
|
||||
```
|
||||
|
||||
### Stage 4: Classifier filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
# Quality classification
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality documents
|
||||
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
## GPU acceleration

### GPU vs CPU performance

| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |

### Multi-GPU scaling

```python
from nemo_curator import get_client
import dask_cuda

# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)

# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```

## Multi-modal curation

### Image curation

```python
from nemo_curator.image import (
    AestheticFilter,
    NSFWFilter,
    CLIPEmbedder
)

# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)

# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)

# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```

### Video curation

```python
from nemo_curator.video import (
    SceneDetector,
    ClipExtractor,
    InternVideo2Embedder
)

# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)

# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)

# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```

### Audio curation

```python
from nemo_curator.audio import (
    ASRInference,
    WERFilter,
    DurationFilter
)

# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)

# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)

# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```

## Common patterns

### Web scrape curation (Common Crawl)

```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset

# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")

# Pipeline (illustrative; in practice, wrap heuristic filters with ScoreFilter
# as in the quick-start example so every stage is callable on the dataset)
pipeline = [
    # 1. Quality filtering
    WordCountFilter(min_words=100, max_words=50000),
    RepeatedLinesFilter(max_repeated_line_fraction=0.2),
    SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
    UrlRatioFilter(max_url_ratio=0.3),

    # 2. Language filtering
    LanguageIdentificationFilter(target_languages=["en"]),

    # 3. Deduplication
    ExactDuplicates(id_field="id", text_field="text"),
    FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),

    # 4. PII redaction
    PIIRedactor(),

    # 5. NSFW filtering
    NSFWClassifier(threshold=0.8)
]

# Execute
for stage in pipeline:
    dataset = stage(dataset)

# Save
dataset.to_parquet("curated_common_crawl/")
```

### Distributed processing

```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster

# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)

# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)

# Cleanup
client.close()
cluster.close()
```

## Performance benchmarks

### Fuzzy deduplication (8TB RedPajama v2)

- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×

### Exact deduplication (1TB)

- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×

### Quality filtering (100GB)

- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×

## Cost comparison

**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320

**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55

**Savings**: 89% reduction ($3,828 saved)

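The arithmetic generalizes to any fleet size and hourly rate; a minimal sketch using the figures quoted above (not live AWS prices):

```python
def curation_cost(rate_per_hour: float, n_instances: int, hours: float) -> float:
    """Total cost of a curation run at a fixed per-instance hourly rate."""
    return rate_per_hour * n_instances * hours

cpu_total = curation_cost(3.60, 10, 120)   # $4,320.00
gpu_total = curation_cost(32.77, 2, 7.5)   # $491.55
savings_pct = (cpu_total - gpu_total) / cpu_total * 100
print(f"CPU: ${cpu_total:,.2f}  GPU: ${gpu_total:,.2f}  savings: {savings_pct:.0f}%")
```
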
## Supported data formats

- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal

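A round-trip sketch for these formats; `DocumentDataset.read_parquet` is used elsewhere in this skill, while `read_json` for JSONL input is an assumption about the analogous reader:

```python
from nemo_curator.datasets import DocumentDataset

# JSONL in (assumed reader, mirroring read_parquet), Parquet out
dataset = DocumentDataset.read_json("raw_data/*.jsonl")
dataset.to_parquet("converted/")  # Parquet is the recommended output format
```
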
## Use cases

**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile

## References

- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods

## Resources

- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0

@@ -0,0 +1,87 @@
# Deduplication Guide

Complete guide to exact, fuzzy, and semantic deduplication.

## Exact deduplication

Remove documents with identical content.

```python
from nemo_curator.modules import ExactDuplicates

# Exact deduplication
exact_dedup = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5"  # or "sha256"
)

deduped = exact_dedup(dataset)
```

**Performance**: ~16× faster on GPU vs CPU

## Fuzzy deduplication

Remove near-duplicate documents using MinHash + LSH.

```python
from nemo_curator.modules import FuzzyDuplicates

fuzzy_dedup = FuzzyDuplicates(
    id_field="id",
    text_field="text",
    num_hashes=260,  # MinHash permutations (more = accurate)
    num_buckets=20,  # LSH buckets (more = faster, less recall)
    hash_method="md5",
    jaccard_threshold=0.8  # Similarity threshold
)

deduped = fuzzy_dedup(dataset)
```

**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)

**Performance**: 16× faster on 8TB dataset (120h → 7.5h)

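To see what `num_hashes` buys, here is a minimal pure-Python sketch of MinHash Jaccard estimation (illustrative only; NeMo Curator's GPU implementation differs): each extra hash permutation tightens the estimate of the true Jaccard similarity that `jaccard_threshold` is compared against.

```python
import hashlib

def minhash_signature(tokens: set, num_hashes: int) -> list:
    """One minimum per seeded hash function over the token set."""
    return [
        min(int(hashlib.md5(f"{seed}:{t}".encode()).hexdigest(), 16) for t in tokens)
        for seed in range(num_hashes)
    ]

def estimated_jaccard(a: set, b: set, num_hashes: int = 260) -> float:
    sig_a = minhash_signature(a, num_hashes)
    sig_b = minhash_signature(b, num_hashes)
    # Fraction of matching minima estimates |a ∩ b| / |a ∪ b|
    return sum(x == y for x, y in zip(sig_a, sig_b)) / num_hashes

doc1 = set("the quick brown fox jumps over the lazy dog".split())
doc2 = set("the quick brown fox leaps over the lazy dog".split())
print(estimated_jaccard(doc1, doc2))  # close to the true Jaccard 7/9 ≈ 0.78
```
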
## Semantic deduplication

Remove semantically similar documents using embeddings.

```python
from nemo_curator.modules import SemanticDuplicates

semantic_dedup = SemanticDuplicates(
    id_field="id",
    text_field="text",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=256,
    threshold=0.85,  # Cosine similarity threshold
    device="cuda"
)

deduped = semantic_dedup(dataset)
```

**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported

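To get a feel for where to set `threshold`, you can score a few pairs by hand; a minimal sketch using the sentence-transformers library directly (outside NeMo Curator):

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
a = model.encode("The cat sat on the mat.", convert_to_tensor=True)
b = model.encode("A cat was sitting on a mat.", convert_to_tensor=True)
print(util.cos_sim(a, b).item())  # paraphrases typically score well above 0.8
```
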
## Comparison

| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |

## Best practices

1. **Start with exact dedup** - Remove obvious duplicates (see the staged sketch below)
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup

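A staged pipeline following practices 1-3, composed from the classes documented above (cheap exact matching first, expensive semantic matching last on the already-reduced set):

```python
from nemo_curator.modules import ExactDuplicates, FuzzyDuplicates, SemanticDuplicates

# Stage 1: exact matches (cheapest, removes obvious copies)
dataset = ExactDuplicates(id_field="id", text_field="text")(dataset)

# Stage 2: near-duplicates via MinHash + LSH
dataset = FuzzyDuplicates(id_field="id", text_field="text",
                          num_hashes=260, jaccard_threshold=0.8)(dataset)

# Stage 3: paraphrases via embeddings (most expensive, run last)
dataset = SemanticDuplicates(id_field="id", text_field="text",
                             embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                             threshold=0.85)(dataset)
```
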
@@ -0,0 +1,102 @@

# Quality Filtering Guide

Complete guide to NeMo Curator's 30+ quality filters.

## Text-based filters

### Word count

```python
from nemo_curator.filters import WordCountFilter

# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```

### Repeated content

```python
from nemo_curator.filters import RepeatedLinesFilter

# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```

### Symbol ratio

```python
from nemo_curator.filters import SymbolToWordRatioFilter

# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```

### URL ratio

```python
from nemo_curator.filters import UrlRatioFilter

# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```

## Language filtering

```python
from nemo_curator.filters import LanguageIdentificationFilter

# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))

# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```

## Classifier-based filtering

### Quality classifier

```python
from nemo_curator.classifiers import QualityClassifier

quality_clf = QualityClassifier(
    model_path="nvidia/quality-classifier-deberta",
    batch_size=256,
    device="cuda"
)

# Filter low-quality (threshold > 0.5 = high quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```

### NSFW classifier

```python
from nemo_curator.classifiers import NSFWClassifier

nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")

# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```

## Heuristic filters

Full list of 30+ filters:
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...

## Best practices

1. **Apply cheap filters first** - Word count before GPU classifiers (see the ordering sketch below)
2. **Tune thresholds on sample** - Test on 10k docs before full run
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)

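A cost-ordered chain following practices 1 and 4: the heuristics above prune most documents cheaply so the GPU classifier only sees survivors.

```python
from nemo_curator.filters import WordCountFilter, UrlRatioFilter
from nemo_curator.classifiers import QualityClassifier

# Cheap CPU heuristics first: each stage shrinks the input to the next
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))

# Expensive GPU classifier last, on the already-reduced set
quality_clf = QualityClassifier(model_path="nvidia/quality-classifier-deberta",
                                batch_size=256, device="cuda")
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
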
389
protected/skills-backup/mlops/evaluation/saelens/SKILL.md
Normal file
@@ -0,0 +1,389 @@

---
name: sparse-autoencoder-training
description: Provides guidance for training and analyzing Sparse Autoencoders (SAEs) using SAELens to decompose neural network activations into interpretable features. Use when discovering interpretable features, analyzing superposition, or studying monosemantic representations in language models.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0]
metadata:
  hermes:
    tags: [Sparse Autoencoders, SAE, Mechanistic Interpretability, Feature Discovery, Superposition]

---

# SAELens: Sparse Autoencoders for Mechanistic Interpretability

SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs) - a technique for decomposing polysemantic neural network activations into sparse, interpretable features. Based on Anthropic's groundbreaking research on monosemanticity.

**GitHub**: [jbloomAus/SAELens](https://github.com/jbloomAus/SAELens) (1,100+ stars)

## The Problem: Polysemanticity & Superposition

Individual neurons in neural networks are **polysemantic** - they activate in multiple, semantically distinct contexts. This happens because models use **superposition** to represent more features than they have neurons, making interpretability difficult.

**SAEs solve this** by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.

## When to Use SAELens

**Use SAELens when you need to:**
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)

**Consider alternatives when:**
- You need basic activation analysis → Use **TransformerLens** directly
- You want causal intervention experiments → Use **pyvene** or **TransformerLens**
- You need production steering → Consider direct activation engineering

## Installation

```bash
pip install sae-lens
```

Requirements: Python 3.10+, transformer-lens>=2.0.0

## Core Concepts

### What SAEs Learn

SAEs are trained to reconstruct model activations through a sparse bottleneck:

```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
   (d_model)         ↓      (d_sae >> d_model)    ↓           (d_model)
                  sparsity                  reconstruction
                  penalty                        loss
```

**Loss Function**: `MSE(original, reconstructed) + L1_coefficient × L1(features)`

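A minimal PyTorch sketch of this objective (illustrative, not SAELens's actual implementation; the tied decoder bias and resampling tricks real SAEs use are omitted):

```python
import torch
import torch.nn as nn

class TinySAE(nn.Module):
    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.enc = nn.Linear(d_model, d_sae)   # encoder: d_model -> d_sae
        self.dec = nn.Linear(d_sae, d_model)   # decoder: d_sae -> d_model

    def forward(self, x):
        feats = torch.relu(self.enc(x))        # sparse features
        return self.dec(feats), feats

sae = TinySAE(d_model=768, d_sae=768 * 8)
x = torch.randn(32, 768)                       # a batch of activations

recon, feats = sae(x)
l1_coefficient = 8e-5
loss = ((x - recon) ** 2).mean() + l1_coefficient * feats.abs().sum(-1).mean()
loss.backward()
```
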
### Key Validation (Anthropic Research)

In "Towards Monosemanticity", human evaluators found **70% of SAE features genuinely interpretable**. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures

## Workflow 1: Loading and Analyzing Pre-trained SAEs

### Step-by-Step

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

# 1. Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# 2. Get model activations
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]  # [batch, pos, d_model]

# 3. Encode to SAE features
sae_features = sae.encode(activations)  # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")

# 4. Find top features for each position
for pos in range(tokens.shape[1]):
    top_features = sae_features[0, pos].topk(5)
    token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
    print(f"Token '{token}': features {top_features.indices.tolist()}")

# 5. Reconstruct activations
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```

### Available Pre-trained SAEs

| Release | Model | Layers |
|---------|-------|--------|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |

### Checklist
- [ ] Load model with TransformerLens
- [ ] Load matching SAE for target layer
- [ ] Encode activations to sparse features
- [ ] Identify top-activating features per token
- [ ] Validate reconstruction quality

## Workflow 2: Training a Custom SAE

### Step-by-Step

```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner

# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
    # Model
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,  # Model dimension

    # SAE architecture
    architecture="standard",  # or "gated", "topk"
    d_sae=768 * 8,  # Expansion factor of 8
    activation_fn="relu",

    # Training
    lr=4e-4,
    l1_coefficient=8e-5,  # Sparsity penalty
    l1_warm_up_steps=1000,
    train_batch_size_tokens=4096,
    training_tokens=100_000_000,

    # Data
    dataset_path="monology/pile-uncopyrighted",
    context_size=128,

    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",

    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,
)

# 2. Train
trainer = SAETrainingRunner(cfg)
sae = trainer.run()

# 3. Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```

### Key Hyperparameters

| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |

### Evaluation Metrics

| Metric | Target | Meaning |
|--------|--------|---------|
| **L0** | 50-200 | Average active features per token |
| **CE Loss Score** | 80-95% | Cross-entropy recovered vs original |
| **Dead Features** | <5% | Features that never activate |
| **Explained Variance** | >90% | Reconstruction quality |

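These metrics fall out of the feature tensor directly; a minimal sketch of computing L0, the dead-feature ratio, and explained variance from a batch (the function and variable names are ours, not SAELens's):

```python
import torch

def sae_health_metrics(features: torch.Tensor, original: torch.Tensor,
                       reconstructed: torch.Tensor) -> dict:
    """features: [n_tokens, d_sae]; original/reconstructed: [n_tokens, d_model]."""
    l0 = (features > 0).float().sum(-1).mean().item()            # avg active per token
    dead_ratio = ((features > 0).sum(0) == 0).float().mean().item()
    resid_var = ((original - reconstructed) ** 2).sum()
    total_var = ((original - original.mean(0)) ** 2).sum()
    explained_var = (1 - resid_var / total_var).item()
    return {"L0": l0, "dead_features": dead_ratio, "explained_variance": explained_var}
```

On a healthy run you would expect L0 in the 50-200 range and a dead-feature ratio under 5%, matching the table above.
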
### Checklist
- [ ] Choose target layer and hook point
- [ ] Set expansion factor (d_sae = 4-16× d_model)
- [ ] Tune L1 coefficient for desired sparsity
- [ ] Enable L1 warm-up to prevent dead features
- [ ] Monitor metrics during training (W&B)
- [ ] Validate L0 and CE loss recovery
- [ ] Check dead feature ratio

## Workflow 3: Feature Analysis and Steering

### Analyzing Individual Features

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# Find what activates a specific feature
feature_idx = 1234
test_texts = [
    "The scientist conducted an experiment",
    "I love chocolate cake",
    "The code compiles successfully",
    "Paris is beautiful in spring",
]

for text in test_texts:
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    features = sae.encode(cache["resid_pre", 8])
    activation = features[0, :, feature_idx].max().item()
    print(f"{activation:.3f}: {text}")
```

### Feature Steering

```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
    """Add SAE feature direction to residual stream."""
    tokens = model.to_tokens(prompt)

    # Get feature direction from decoder
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def steering_hook(activation, hook):
        # Add scaled feature direction at all positions
        activation += strength * feature_direction
        return activation

    # Generate with steering; generate() does not take fwd_hooks directly,
    # so attach the hook via the model.hooks context manager
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
    return model.to_string(output[0])
```

### Feature Attribution

```python
# Which features most affect a specific output?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)

# Get features at final position
features = sae.encode(cache["resid_pre", 8])[0, -1]  # [d_sae]

# Get logit attribution per feature
# Feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U  # [d_model, vocab]

# Contribution to "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])

top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
    print(f"  Feature {idx.item()}: {val.item():.3f}")
```

## Common Issues & Solutions

### Issue: High dead feature ratio
```python
# WRONG: No warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,
    l1_warm_up_steps=0,  # Bad!
)

# RIGHT: Warm-up L1 penalty
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=8e-5,
    l1_warm_up_steps=1000,  # Gradually increase
    use_ghost_grads=True,  # Revive dead features
)
```

### Issue: Poor reconstruction (low CE recovery)
```python
# Reduce sparsity penalty
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=5e-5,  # Lower = better reconstruction
    d_sae=768 * 16,  # More capacity
)
```

### Issue: Features not interpretable
```python
# Increase sparsity (higher L1)
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,  # Higher = sparser, more interpretable
)
# Or use TopK architecture
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn_kwargs={"k": 50},  # Exactly 50 active features
)
```

### Issue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
    train_batch_size_tokens=2048,  # Reduce batch size
    store_batch_size_prompts=4,  # Fewer prompts in buffer
    n_batches_in_buffer=8,  # Smaller activation buffer
)
```

## Integration with Neuronpedia

Browse pre-trained SAE features at [neuronpedia.org](https://neuronpedia.org):

```python
# Features are indexed by SAE ID
# Example: gpt2-small layer 8 feature 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234
```

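A trivial helper for building those URLs programmatically (hypothetical convenience code, not part of SAELens):

```python
def neuronpedia_url(model: str, sae_suffix: str, feature_id: int) -> str:
    """E.g. neuronpedia_url("gpt2-small", "8-res-jb", 1234)."""
    return f"https://neuronpedia.org/{model}/{sae_suffix}/{feature_id}"

print(neuronpedia_url("gpt2-small", "8-res-jb", 1234))
```
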
## Key Classes Reference

| Class | Purpose |
|-------|---------|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |

## Reference Documentation

For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:

| File | Contents |
|------|----------|
| [references/README.md](references/README.md) | Overview and quick start guide |
| [references/api.md](references/api.md) | Complete API reference for SAE, TrainingSAE, configurations |
| [references/tutorials.md](references/tutorials.md) | Step-by-step tutorials for training, analysis, steering |

## External Resources

### Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training a Sparse Autoencoder](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [ARENA SAE Curriculum](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab)

### Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Highly Interpretable Features](https://arxiv.org/abs/2309.08600) - Cunningham et al. (ICLR 2024)

### Official Documentation
- [SAELens Docs](https://jbloomaus.github.io/SAELens/)
- [Neuronpedia](https://neuronpedia.org) - Feature browser

## SAE Architectures

| Architecture | Description | Use Case |
|--------------|-------------|----------|
| **Standard** | ReLU + L1 penalty | General purpose |
| **Gated** | Learned gating mechanism | Better sparsity control |
| **TopK** | Exactly K active features | Consistent sparsity |

```python
# TopK SAE (exactly 50 features active)
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},
)
```

@@ -0,0 +1,70 @@

# SAELens Reference Documentation

This directory contains comprehensive reference materials for SAELens.

## Contents

- [api.md](api.md) - Complete API reference for SAE, TrainingSAE, and configuration classes
- [tutorials.md](tutorials.md) - Step-by-step tutorials for training and analyzing SAEs
- [papers.md](papers.md) - Key research papers on sparse autoencoders

## Quick Links

- **GitHub Repository**: https://github.com/jbloomAus/SAELens
- **Neuronpedia**: https://neuronpedia.org (browse pre-trained SAE features)
- **HuggingFace SAEs**: Search for tag `saelens`

## Installation

```bash
pip install sae-lens
```

Requirements: Python 3.10+, transformer-lens>=2.0.0

## Basic Usage

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# Encode activations to sparse features
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]

features = sae.encode(activations)  # Sparse feature activations
reconstructed = sae.decode(features)  # Reconstructed activations
```

## Key Concepts

### Sparse Autoencoders
SAEs decompose dense neural activations into sparse, interpretable features:
- **Encoder**: Maps d_model → d_sae (typically 4-16x expansion)
- **ReLU/TopK**: Enforces sparsity
- **Decoder**: Reconstructs original activations

### Training Loss
`Loss = MSE(original, reconstructed) + L1_coefficient × L1(features)`

### Key Metrics
- **L0**: Average number of active features (target: 50-200)
- **CE Loss Score**: Cross-entropy recovered vs original model (target: 80-95%)
- **Dead Features**: Features that never activate (target: <5%)

## Available Pre-trained SAEs

| Release | Model | Description |
|---------|-------|-------------|
| `gpt2-small-res-jb` | GPT-2 Small | Residual stream SAEs |
| `gemma-2b-res` | Gemma 2B | Residual stream SAEs |
| Various | Search HuggingFace | Community-trained SAEs |

@@ -0,0 +1,333 @@

# SAELens API Reference

## SAE Class

The core class representing a Sparse Autoencoder.

### Loading Pre-trained SAEs

```python
from sae_lens import SAE

# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="username/repo-name",
    sae_id="path/to/sae",
    device="cuda"
)

# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```

### SAE Attributes

| Attribute | Shape | Description |
|-----------|-------|-------------|
| `W_enc` | [d_in, d_sae] | Encoder weights |
| `W_dec` | [d_sae, d_in] | Decoder weights |
| `b_enc` | [d_sae] | Encoder bias |
| `b_dec` | [d_in] | Decoder bias |
| `cfg` | SAEConfig | Configuration object |

### Core Methods

#### encode()

```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```

#### decode()

```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```

#### forward()

```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```

#### save_model()

```python
sae.save_model("/path/to/save")
```

---

## SAEConfig

Configuration class for SAE architecture and training context.

### Key Parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `d_in` | int | Input dimension (model's d_model) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |

### Accessing Config

```python
print(sae.cfg.d_in)  # 768 for GPT-2 small
print(sae.cfg.d_sae)  # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name)  # e.g., "blocks.8.hook_resid_pre"
```

---

## LanguageModelSAERunnerConfig

Comprehensive configuration for training SAEs.

### Example Configuration

```python
from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # Model and hook
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,

    # SAE architecture
    architecture="standard",  # "standard", "gated", "jumprelu", "topk"
    d_sae=768 * 8,  # Expansion factor
    activation_fn="relu",

    # Training hyperparameters
    lr=4e-4,
    l1_coefficient=8e-5,
    lp_norm=1.0,
    lr_scheduler_name="constant",
    lr_warm_up_steps=500,

    # Sparsity control
    l1_warm_up_steps=1000,
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-8,

    # Data
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=128,

    # Batch sizes
    train_batch_size_tokens=4096,
    store_batch_size_prompts=16,
    n_batches_in_buffer=64,

    # Training duration
    training_tokens=100_000_000,

    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",
    wandb_log_frequency=100,

    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,

    # Hardware
    device="cuda",
    dtype="float32",
)
```

### Key Parameters Explained

#### Architecture Parameters

| Parameter | Description |
|-----------|-------------|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to d_sae: d_sae = d_in × expansion_factor |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict for activation params (e.g., {"k": 50} for topk) |

#### Sparsity Parameters

| Parameter | Description |
|-----------|-------------|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps to ramp up L1 penalty |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold for "dead" |
| `dead_feature_window` | Steps to check for dead features |

#### Learning Rate Parameters

| Parameter | Description |
|-----------|-------------|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | LR warmup steps |
| `lr_decay_steps` | Steps for LR decay |

---

## SAETrainingRunner

Main class for executing training.

### Basic Training

```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```

### Accessing Training Metrics

```python
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
```

---

## ActivationsStore

Manages activation collection and batching.

### Basic Usage

```python
from sae_lens import ActivationsStore

store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device="cuda",
)

# Get a batch of tokens (activations are computed and buffered internally)
batch_tokens = store.get_batch_tokens()
```

---

## HookedSAETransformer

Integration of SAEs with TransformerLens models.

### Basic Usage

```python
from sae_lens import HookedSAETransformer

# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)

# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])

# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
```

---

## SAE Architectures

### Standard (ReLU + L1)

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="standard",
    activation_fn="relu",
    l1_coefficient=8e-5,
)
```

### Gated

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="gated",
)
```

### TopK

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},  # Exactly 50 active features
)
```

### JumpReLU (State-of-the-art)

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
)
```

---

## Utility Functions

### Upload to HuggingFace

```python
from sae_lens import upload_saes_to_huggingface

upload_saes_to_huggingface(
    saes=[sae],
    repo_id="username/my-saes",
    token="hf_token",
)
```

### Neuronpedia Integration

```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```

@@ -0,0 +1,318 @@

# SAELens Tutorials

## Tutorial 1: Loading and Analyzing Pre-trained SAEs

### Goal
Load a pre-trained SAE and analyze which features activate on specific inputs.

### Step-by-Step

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

# 1. Load model and SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

print(f"SAE input dim: {sae.cfg.d_in}")
print(f"SAE hidden dim: {sae.cfg.d_sae}")
print(f"Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")

# 2. Get model activations
prompt = "The capital of France is Paris"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]  # [1, seq_len, 768]

# 3. Encode to SAE features
features = sae.encode(activations)  # [1, seq_len, d_sae]

# 4. Analyze sparsity
active_per_token = (features > 0).sum(dim=-1)
print(f"Average active features per token: {active_per_token.float().mean():.1f}")

# 5. Find top features for each token
str_tokens = model.to_str_tokens(prompt)
for pos in range(len(str_tokens)):
    top_features = features[0, pos].topk(5)
    print(f"\nToken '{str_tokens[pos]}':")
    for feat_idx, feat_val in zip(top_features.indices, top_features.values):
        print(f"  Feature {feat_idx.item()}: {feat_val.item():.3f}")

# 6. Check reconstruction quality
reconstructed = sae.decode(features)
mse = ((activations - reconstructed) ** 2).mean()
print(f"\nReconstruction MSE: {mse.item():.6f}")
```

---

## Tutorial 2: Training a Custom SAE

### Goal
Train a Sparse Autoencoder on GPT-2 activations.

### Step-by-Step

```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# 1. Configure training
cfg = LanguageModelSAERunnerConfig(
    # Model
    model_name="gpt2-small",
    hook_name="blocks.6.hook_resid_pre",
    hook_layer=6,
    d_in=768,

    # SAE architecture
    architecture="standard",
    d_sae=768 * 8,  # 8x expansion
    activation_fn="relu",

    # Training
    lr=4e-4,
    l1_coefficient=8e-5,
    l1_warm_up_steps=1000,
    train_batch_size_tokens=4096,
    training_tokens=10_000_000,  # Small run for demo

    # Data
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=128,

    # Dead feature prevention
    use_ghost_grads=True,
    dead_feature_window=5000,

    # Logging
    log_to_wandb=True,
    wandb_project="sae-training-demo",

    # Hardware
    device="cuda",
    dtype="float32",
)

# 2. Train
runner = SAETrainingRunner(cfg)
sae = runner.run()

# 3. Save
sae.save_model("./my_trained_sae")
```

### Hyperparameter Tuning Guide

| If you see... | Try... |
|---------------|--------|
| High L0 (>200) | Increase `l1_coefficient` |
| Low CE recovery (<80%) | Decrease `l1_coefficient`, increase `d_sae` |
| Many dead features (>5%) | Enable `use_ghost_grads`, increase `l1_warm_up_steps` |
| Training instability | Lower `lr`, increase `lr_warm_up_steps` |

---

## Tutorial 3: Feature Attribution and Steering

### Goal
Identify which SAE features contribute to specific predictions and use them for steering.

### Step-by-Step

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# 1. Feature attribution for a specific prediction
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)

# Target token
target_token = model.to_single_token(" Paris")

# Compute feature contributions to target logit
# contribution = feature_activation * decoder_weight * unembedding
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U  # [d_model, d_vocab]

# Feature direction projected to vocabulary
feature_to_logit = W_dec @ W_U  # [d_sae, d_vocab]

# Contribution of each feature to "Paris" at final position
feature_acts = features[0, -1]  # [d_sae]
contributions = feature_acts * feature_to_logit[:, target_token]

# Top contributing features
top_features = contributions.topk(10)
print("Top features contributing to 'Paris':")
for idx, val in zip(top_features.indices, top_features.values):
    print(f"  Feature {idx.item()}: {val.item():.3f}")

# 2. Feature steering
def steer_with_feature(feature_idx, strength=5.0):
    """Add a feature direction to the residual stream."""
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def hook(activation, hook_obj):
        activation[:, -1, :] += strength * feature_direction
        return activation

    # generate() does not accept fwd_hooks directly; attach the hook
    # with the model.hooks context manager instead
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", hook)]):
        output = model.generate(tokens, max_new_tokens=10)
    return model.to_string(output[0])

# Try steering with top feature
top_feature_idx = top_features.indices[0].item()
print(f"\nSteering with feature {top_feature_idx}:")
print(steer_with_feature(top_feature_idx, strength=10.0))
```

---

## Tutorial 4: Feature Ablation

### Goal
Test the causal importance of features by ablating them.

### Step-by-Step

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

prompt = "The capital of France is"
tokens = model.to_tokens(prompt)

# Baseline prediction
baseline_logits = model(tokens)
target_token = model.to_single_token(" Paris")
baseline_prob = torch.softmax(baseline_logits[0, -1], dim=-1)[target_token].item()
print(f"Baseline P(Paris): {baseline_prob:.4f}")

# Get features to ablate
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]
features = sae.encode(activations)
top_features = features[0, -1].topk(10).indices

# Ablate top features one by one
for feat_idx in top_features:
    def ablation_hook(activation, hook, feat_idx=feat_idx):
        # Encode → zero feature → decode
        feats = sae.encode(activation)
        feats[:, :, feat_idx] = 0
        return sae.decode(feats)

    ablated_logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[("blocks.8.hook_resid_pre", ablation_hook)]
    )
    ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)[target_token].item()
    change = (ablated_prob - baseline_prob) / baseline_prob * 100
    print(f"Ablate feature {feat_idx.item()}: P(Paris)={ablated_prob:.4f} ({change:+.1f}%)")
```

---

## Tutorial 5: Comparing Features Across Prompts

### Goal
Find which features activate consistently for a concept.

### Step-by-Step

```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)

# Test prompts about the same concept
prompts = [
    "The Eiffel Tower is located in",
    "Paris is the capital of",
    "France's largest city is",
    "The Louvre museum is in",
]

# Collect feature activations
all_features = []
for prompt in prompts:
    tokens = model.to_tokens(prompt)
    _, cache = model.run_with_cache(tokens)
    activations = cache["resid_pre", 8]
    features = sae.encode(activations)
    # Take max activation across positions
    max_features = features[0].max(dim=0).values
    all_features.append(max_features)

all_features = torch.stack(all_features)  # [n_prompts, d_sae]

# Find features that activate consistently
mean_activation = all_features.mean(dim=0)
min_activation = all_features.min(dim=0).values

# Features active in ALL prompts
consistent_features = (min_activation > 0.5).nonzero().squeeze(-1)
print(f"Features active in all prompts: {len(consistent_features)}")

# Top consistent features
top_consistent = mean_activation[consistent_features].topk(min(10, len(consistent_features)))
print("\nTop consistent features (possibly 'France/Paris' related):")
for idx, val in zip(top_consistent.indices, top_consistent.values):
    feat_idx = consistent_features[idx].item()
    print(f"  Feature {feat_idx}: mean activation {val.item():.3f}")
```

---

## External Resources

### Official Tutorials
- [Basic Loading & Analysis](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Training SAEs](https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- [Logits Lens with Features](https://github.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)

### ARENA Curriculum
Comprehensive SAE course: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab

### Key Papers
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features) - Anthropic (2023)
- [Scaling Monosemanticity](https://transformer-circuits.pub/2024/scaling-monosemanticity/) - Anthropic (2024)
- [Sparse Autoencoders Find Interpretable Features](https://arxiv.org/abs/2309.08600) - ICLR 2024

@@ -0,0 +1,593 @@

---
name: weights-and-biases
description: Track ML experiments with automatic logging, visualize training in real-time, optimize hyperparameters with sweeps, and manage model registry with W&B - collaborative MLOps platform
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [wandb]
metadata:
  hermes:
    tags: [MLOps, Weights And Biases, WandB, Experiment Tracking, Hyperparameter Tuning, Model Registry, Collaboration, Real-Time Visualization, PyTorch, TensorFlow, HuggingFace]

---

# Weights & Biases: ML Experiment Tracking & MLOps

## When to Use This Skill

Use Weights & Biases (W&B) when you need to:
- **Track ML experiments** with automatic metric logging
- **Visualize training** in real-time dashboards
- **Compare runs** across hyperparameters and configurations
- **Optimize hyperparameters** with automated sweeps
- **Manage model registry** with versioning and lineage
- **Collaborate on ML projects** with team workspaces
- **Track artifacts** (datasets, models, code) with lineage

**Users**: 200,000+ ML practitioners | **GitHub Stars**: 10.5k+ | **Integrations**: 100+

## Installation

```bash
# Install W&B
pip install wandb

# Login (creates API key)
wandb login

# Or set API key programmatically
export WANDB_API_KEY=your_api_key_here
```

## Quick Start

### Basic Experiment Tracking

```python
import wandb

# Initialize a run
run = wandb.init(
    project="my-project",
    config={
        "learning_rate": 0.001,
        "epochs": 10,
        "batch_size": 32,
        "architecture": "ResNet50"
    }
)

# Training loop
for epoch in range(run.config.epochs):
    # Your training code
    train_loss, train_acc = train_epoch()
    val_loss, val_acc = validate()

    # Log metrics
    wandb.log({
        "epoch": epoch,
        "train/loss": train_loss,
        "val/loss": val_loss,
        "train/accuracy": train_acc,
        "val/accuracy": val_acc
    })

# Finish the run
wandb.finish()
```

### With PyTorch

```python
import torch
import wandb

# Initialize
wandb.init(project="pytorch-demo", config={
    "lr": 0.001,
    "epochs": 10
})

# Access config
config = wandb.config

# Training loop
for epoch in range(config.epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log every 100 batches
        if batch_idx % 100 == 0:
            wandb.log({
                "loss": loss.item(),
                "epoch": epoch,
                "batch": batch_idx
            })

# Save model
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth")  # Upload to W&B

wandb.finish()
```

## Core Concepts
|
||||
|
||||
### 1. Projects and Runs
|
||||
|
||||
**Project**: Collection of related experiments
|
||||
**Run**: Single execution of your training script
|
||||
|
||||
```python
|
||||
# Create/use project
|
||||
run = wandb.init(
|
||||
project="image-classification",
|
||||
name="resnet50-experiment-1", # Optional run name
|
||||
tags=["baseline", "resnet"], # Organize with tags
|
||||
notes="First baseline run" # Add notes
|
||||
)
|
||||
|
||||
# Each run has unique ID
|
||||
print(f"Run ID: {run.id}")
|
||||
print(f"Run URL: {run.url}")
|
||||
```
|
||||
|
||||
### 2. Configuration Tracking
|
||||
|
||||
Track hyperparameters automatically:
|
||||
|
||||
```python
|
||||
config = {
|
||||
# Model architecture
|
||||
"model": "ResNet50",
|
||||
"pretrained": True,
|
||||
|
||||
# Training params
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32,
|
||||
"epochs": 50,
|
||||
"optimizer": "Adam",
|
||||
|
||||
# Data params
|
||||
"dataset": "ImageNet",
|
||||
"augmentation": "standard"
|
||||
}
|
||||
|
||||
wandb.init(project="my-project", config=config)
|
||||
|
||||
# Access config during training
|
||||
lr = wandb.config.learning_rate
|
||||
batch_size = wandb.config.batch_size
|
||||
```
|
||||
|
||||
### 3. Metric Logging
|
||||
|
||||
```python
|
||||
# Log scalars
|
||||
wandb.log({"loss": 0.5, "accuracy": 0.92})
|
||||
|
||||
# Log multiple metrics
|
||||
wandb.log({
|
||||
"train/loss": train_loss,
|
||||
"train/accuracy": train_acc,
|
||||
"val/loss": val_loss,
|
||||
"val/accuracy": val_acc,
|
||||
"learning_rate": current_lr,
|
||||
"epoch": epoch
|
||||
})
|
||||
|
||||
# Log with custom x-axis
|
||||
wandb.log({"loss": loss}, step=global_step)
|
||||
|
||||
# Log media (images, audio, video)
|
||||
wandb.log({"examples": [wandb.Image(img) for img in images]})
|
||||
|
||||
# Log histograms
|
||||
wandb.log({"gradients": wandb.Histogram(gradients)})
|
||||
|
||||
# Log tables
|
||||
table = wandb.Table(columns=["id", "prediction", "ground_truth"])
|
||||
wandb.log({"predictions": table})
|
||||
```
|
||||
|
||||
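If you log with `step=`, the step value must increase monotonically. An alternative is to declare a custom x-axis once with `define_metric`; a minimal sketch, assuming a wandb client recent enough to expose `define_metric` (the metric names here are placeholders):

```python
import wandb

run = wandb.init(project="my-project")

# Declare "epoch" as the x-axis for everything logged under val/
run.define_metric("epoch")
run.define_metric("val/*", step_metric="epoch")

for epoch in range(10):
    # val/loss is a placeholder for your own computed metric
    wandb.log({"val/loss": 0.1 * (10 - epoch), "epoch": epoch})

run.finish()
```
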
### 4. Model Checkpointing

```python
import torch
import wandb

# Save model checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}

torch.save(checkpoint, 'checkpoint.pth')

# Upload to W&B
wandb.save('checkpoint.pth')

# Or use Artifacts (recommended)
artifact = wandb.Artifact('model', type='model')
artifact.add_file('checkpoint.pth')
wandb.log_artifact(artifact)
```

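Files uploaded with `wandb.save` can be pulled back down with `wandb.restore`. A minimal sketch for resuming from such a checkpoint; the run path `entity/my-project/run_id` is a placeholder you would replace with a real run:

```python
import torch
import wandb

wandb.init(project="my-project")

# Download checkpoint.pth from a previous run into the current run directory
ckpt_file = wandb.restore('checkpoint.pth', run_path="entity/my-project/run_id")

# Resume training state from it
checkpoint = torch.load(ckpt_file.name)
start_epoch = checkpoint['epoch'] + 1
```
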
## Hyperparameter Sweeps

Automatically search for optimal hyperparameters.

### Define Sweep Configuration

```python
sweep_config = {
    'method': 'bayes',  # or 'grid', 'random'
    'metric': {
        'name': 'val/accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            # 'log_uniform_values' takes the actual bounds;
            # plain 'log_uniform' expects exponents instead
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        'batch_size': {
            'values': [16, 32, 64, 128]
        },
        'optimizer': {
            'values': ['adam', 'sgd', 'rmsprop']
        },
        'dropout': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.5
        }
    }
}

# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```

### Define Training Function

```python
def train():
    # Initialize run
    run = wandb.init()

    # Access sweep parameters
    lr = wandb.config.learning_rate
    batch_size = wandb.config.batch_size
    optimizer_name = wandb.config.optimizer

    # Build model with sweep config
    model = build_model(wandb.config)
    optimizer = get_optimizer(optimizer_name, lr)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        train_loss = train_epoch(model, optimizer, batch_size)
        val_acc = validate(model)

        # Log metrics
        wandb.log({
            "train/loss": train_loss,
            "val/accuracy": val_acc
        })

# Run sweep
wandb.agent(sweep_id, function=train, count=50)  # Run 50 trials
```

### Sweep Strategies

```python
# Grid search - exhaustive
sweep_config = {
    'method': 'grid',
    'parameters': {
        'lr': {'values': [0.001, 0.01, 0.1]},
        'batch_size': {'values': [16, 32, 64]}
    }
}

# Random search
sweep_config = {
    'method': 'random',
    'parameters': {
        'lr': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.1},
        'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
    }
}

# Bayesian optimization (recommended)
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/loss', 'goal': 'minimize'},
    'parameters': {
        'lr': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-1}
    }
}
```

## Artifacts

Track datasets, models, and other files with lineage.

### Log Artifacts

```python
# Create artifact
artifact = wandb.Artifact(
    name='training-dataset',
    type='dataset',
    description='ImageNet training split',
    metadata={'size': '1.2M images', 'split': 'train'}
)

# Add files
artifact.add_file('data/train.csv')
artifact.add_dir('data/images/')

# Log artifact
wandb.log_artifact(artifact)
```

### Use Artifacts

```python
# Download and use an artifact
run = wandb.init(project="my-project")

# Download artifact
artifact = run.use_artifact('training-dataset:latest')
artifact_dir = artifact.download()

# Use the data
data = load_data(f"{artifact_dir}/train.csv")
```

### Model Registry

```python
# Log model as artifact
model_artifact = wandb.Artifact(
    name='resnet50-model',
    type='model',
    metadata={'architecture': 'ResNet50', 'accuracy': 0.95}
)

model_artifact.add_file('model.pth')
wandb.log_artifact(model_artifact, aliases=['best', 'production'])

# Link to model registry
run.link_artifact(model_artifact, 'model-registry/production-models')
```

## Integration Examples

### HuggingFace Transformers

```python
from transformers import Trainer, TrainingArguments
import wandb

# Initialize W&B
wandb.init(project="hf-transformers")

# Training arguments with W&B
training_args = TrainingArguments(
    output_dir="./results",
    report_to="wandb",  # Enable W&B logging
    run_name="bert-finetuning",
    logging_steps=100,
    save_steps=500
)

# Trainer automatically logs to W&B
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()
```

### PyTorch Lightning

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb

# Create W&B logger
wandb_logger = WandbLogger(
    project="lightning-demo",
    log_model=True  # Log model checkpoints
)

# Use with Trainer
trainer = Trainer(
    logger=wandb_logger,
    max_epochs=10
)

trainer.fit(model, datamodule=dm)
```

### Keras/TensorFlow

```python
import wandb
from wandb.keras import WandbCallback

# Initialize
wandb.init(project="keras-demo")

# Add callback
model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    callbacks=[WandbCallback()]  # Auto-logs metrics
)
```

## Visualization & Analysis

### Custom Charts

```python
# Log custom visualizations
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(x, y)
wandb.log({"custom_plot": wandb.Image(fig)})

# Log a confusion matrix
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
    probs=None,
    y_true=ground_truth,
    preds=predictions,
    class_names=class_names
)})
```

### Reports

Create shareable reports in the W&B UI (reports can also be drafted from code; see the sketch below):
- Combine runs, charts, and text
- Markdown support
- Embeddable visualizations
- Team collaboration

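Reports are primarily built in the UI, but W&B has also shipped a beta Reports API under `wandb.apis.reports`. A heavily hedged sketch: the module path, `wr.Report`, and the block classes have shifted between client releases, so treat every name below as an assumption to verify against your installed version:

```python
import wandb.apis.reports as wr  # beta API; names vary by wandb version

# Draft a report programmatically (project name is a placeholder)
report = wr.Report(
    project="my-project",
    title="Baseline experiments",
    description="Weekly summary of the ResNet50 baselines.",
)

# Blocks compose top to bottom, like sections of a document
report.blocks = [
    wr.H1("Results"),
    wr.MarkdownBlock("Validation accuracy plateaued after epoch 30."),
]

report.save()  # Creates the report server-side
print(report.url)
```
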
## Best Practices

### 1. Organize with Tags and Groups

```python
wandb.init(
    project="my-project",
    tags=["baseline", "resnet50", "imagenet"],
    group="resnet-experiments",  # Group related runs
    job_type="train"             # Type of job
)
```

### 2. Log Everything Relevant

```python
# Log system metrics
wandb.log({
    "gpu/util": gpu_utilization,
    "gpu/memory": gpu_memory_used,
    "cpu/util": cpu_utilization
})

# Log code version
wandb.log({"git_commit": git_commit_hash})

# Log data splits
wandb.log({
    "data/train_size": len(train_dataset),
    "data/val_size": len(val_dataset)
})
```

### 3. Use Descriptive Names

```python
# ✅ Good: Descriptive run names
wandb.init(
    project="nlp-classification",
    name="bert-base-lr0.001-bs32-epoch10"
)

# ❌ Bad: Generic names
wandb.init(project="nlp", name="run1")
```

### 4. Save Important Artifacts

```python
# Save final model
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)

# Save predictions for analysis
predictions_table = wandb.Table(
    columns=["id", "input", "prediction", "ground_truth"],
    data=predictions_data
)
wandb.log({"predictions": predictions_table})
```

### 5. Use Offline Mode for Unstable Connections

```python
import os

# Enable offline mode
os.environ["WANDB_MODE"] = "offline"

wandb.init(project="my-project")
# ... your code ...

# Sync later:
# wandb sync <run_directory>
```

## Team Collaboration

### Share Runs

```python
# Runs are automatically shareable via URL
run = wandb.init(project="team-project")
print(f"Share this URL: {run.url}")
```

### Team Projects

- Create a team account at wandb.ai
- Add team members
- Set project visibility (private/public)
- Use team-level artifacts and model registry

Runs land under the team rather than your personal account when you pass the team as `entity`, as sketched below.

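A minimal sketch of logging to a team project; `my-team` and the artifact path are placeholders following the patterns used later in this guide:

```python
import wandb

# Log this run under the team entity instead of your personal account
run = wandb.init(entity="my-team", project="shared-experiments")

# Team members can pull team-level artifacts by their full path
dataset = run.use_artifact("my-team/data-eng/shared-dataset:latest")
data_dir = dataset.download()

run.finish()
```
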
## Pricing

- **Free**: Unlimited public projects, 100GB storage
- **Academic**: Free for students/researchers
- **Teams**: $50/seat/month, private projects, unlimited storage
- **Enterprise**: Custom pricing, on-prem options

## Resources

- **Documentation**: https://docs.wandb.ai
- **GitHub**: https://github.com/wandb/wandb (10.5k+ stars)
- **Examples**: https://github.com/wandb/examples
- **Community**: https://wandb.ai/community
- **Discord**: https://wandb.me/discord

## See Also

- `references/sweeps.md` - Comprehensive hyperparameter optimization guide
- `references/artifacts.md` - Data and model versioning patterns
- `references/integrations.md` - Framework-specific examples

@@ -0,0 +1,584 @@

# Artifacts & Model Registry Guide

Complete guide to data versioning and model management with W&B Artifacts.

## Table of Contents
- What are Artifacts
- Creating Artifacts
- Using Artifacts
- Model Registry
- Versioning & Lineage
- Best Practices

## What are Artifacts

Artifacts are versioned datasets, models, or files tracked with lineage.

**Key Features:**
- Automatic versioning (v0, v1, v2...)
- Lineage tracking (which runs produced/used artifacts)
- Efficient storage (deduplication)
- Collaboration (team-wide access)
- Aliases (latest, best, production)

**Common Use Cases:**
- Dataset versioning
- Model checkpoints
- Preprocessed data
- Evaluation results
- Configuration files

## Creating Artifacts

### Basic Dataset Artifact

```python
import wandb

run = wandb.init(project="my-project")

# Create artifact
dataset = wandb.Artifact(
    name='training-data',
    type='dataset',
    description='ImageNet training split with augmentations',
    metadata={
        'size': '1.2M images',
        'format': 'JPEG',
        'resolution': '224x224'
    }
)

# Add files
dataset.add_file('data/train.csv')          # Single file
dataset.add_dir('data/images')              # Entire directory
dataset.add_reference('s3://bucket/data')   # Cloud reference

# Log artifact
run.log_artifact(dataset)
wandb.finish()
```

### Model Artifact

```python
import torch
import wandb

run = wandb.init(project="my-project")

# Train model
model = train_model()

# Save model
torch.save(model.state_dict(), 'model.pth')

# Create model artifact
model_artifact = wandb.Artifact(
    name='resnet50-classifier',
    type='model',
    description='ResNet50 trained on ImageNet',
    metadata={
        'architecture': 'ResNet50',
        'accuracy': 0.95,
        'loss': 0.15,
        'epochs': 50,
        'framework': 'PyTorch'
    }
)

# Add model file
model_artifact.add_file('model.pth')

# Add config
model_artifact.add_file('config.yaml')

# Log with aliases
run.log_artifact(model_artifact, aliases=['latest', 'best'])

wandb.finish()
```

### Preprocessed Data Artifact

```python
import pandas as pd
import wandb

run = wandb.init(project="nlp-project")

# Preprocess data
df = pd.read_csv('raw_data.csv')
df_processed = preprocess(df)
df_processed.to_csv('processed_data.csv', index=False)

# Create artifact
processed_data = wandb.Artifact(
    name='processed-text-data',
    type='dataset',
    metadata={
        'rows': len(df_processed),
        'columns': list(df_processed.columns),
        'preprocessing_steps': ['lowercase', 'remove_stopwords', 'tokenize']
    }
)

processed_data.add_file('processed_data.csv')

# Log artifact
run.log_artifact(processed_data)
```

## Using Artifacts

### Download and Use

```python
import wandb

run = wandb.init(project="my-project")

# Download artifact
artifact = run.use_artifact('training-data:latest')
artifact_dir = artifact.download()

# Use files
import pandas as pd
df = pd.read_csv(f'{artifact_dir}/train.csv')

# Train with artifact data
model = train_model(df)
```

### Use Specific Version

```python
# Use a specific version
artifact_v2 = run.use_artifact('training-data:v2')

# Use an alias
artifact_best = run.use_artifact('model:best')
artifact_prod = run.use_artifact('model:production')

# Use from another project
artifact = run.use_artifact('team/other-project/model:latest')
```

### Check Artifact Metadata

```python
artifact = run.use_artifact('training-data:latest')

# Access metadata
print(artifact.metadata)
print(f"Size: {artifact.metadata['size']}")

# Access version info
print(f"Version: {artifact.version}")
print(f"Created at: {artifact.created_at}")
print(f"Digest: {artifact.digest}")
```

## Model Registry

Link models to a central registry for governance and deployment.

### Create Model Registry

```python
# In the W&B UI:
# 1. Go to the "Registry" tab
# 2. Create a new registry: "production-models"
# 3. Define stages: development, staging, production
```

### Link Model to Registry

```python
import wandb

run = wandb.init(project="training")

# Create model artifact
model_artifact = wandb.Artifact(
    name='sentiment-classifier',
    type='model',
    metadata={'accuracy': 0.94, 'f1': 0.92}
)

model_artifact.add_file('model.pth')

# Log artifact
run.log_artifact(model_artifact)

# Link to registry
run.link_artifact(
    model_artifact,
    'model-registry/production-models',
    aliases=['staging']  # Deploy to staging
)

wandb.finish()
```

### Promote Model in Registry

```python
# Retrieve model from registry
api = wandb.Api()
artifact = api.artifact('model-registry/production-models/sentiment-classifier:staging')

# Promote to production
artifact.link('model-registry/production-models', aliases=['production'])

# Demote from production
artifact.aliases = ['archived']
artifact.save()
```

### Use Model from Registry

```python
import torch
import wandb

run = wandb.init()

# Download production model
model_artifact = run.use_artifact(
    'model-registry/production-models/sentiment-classifier:production'
)

model_dir = model_artifact.download()

# Load and use (the artifact stores a state_dict, so rebuild the model first)
model = create_model()
model.load_state_dict(torch.load(f'{model_dir}/model.pth'))
model.eval()
```

## Versioning & Lineage

### Automatic Versioning

```python
# First log: creates v0
run1 = wandb.init(project="my-project")
dataset_v0 = wandb.Artifact('my-dataset', type='dataset')
dataset_v0.add_file('data_v1.csv')
run1.log_artifact(dataset_v0)

# Second log with same name: creates v1
run2 = wandb.init(project="my-project")
dataset_v1 = wandb.Artifact('my-dataset', type='dataset')
dataset_v1.add_file('data_v2.csv')  # Different content
run2.log_artifact(dataset_v1)

# Third log with SAME content as v1: references v1 (no new version)
run3 = wandb.init(project="my-project")
dataset_v1_again = wandb.Artifact('my-dataset', type='dataset')
dataset_v1_again.add_file('data_v2.csv')  # Same content as v1
run3.log_artifact(dataset_v1_again)       # Still v1, no v2 created
```

### Track Lineage

```python
# Training run
run = wandb.init(project="my-project")

# Use dataset (input)
dataset = run.use_artifact('training-data:v3')
data = load_data(dataset.download())

# Train model
model = train(data)

# Save model (output)
model_artifact = wandb.Artifact('trained-model', type='model')
torch.save(model.state_dict(), 'model.pth')
model_artifact.add_file('model.pth')
run.log_artifact(model_artifact)

# Lineage automatically tracked:
# training-data:v3 --> [run] --> trained-model:v0
```

### View Lineage Graph

```python
# In the W&B UI:
# Artifacts → Select artifact → Lineage tab
# Shows:
# - Which runs produced this artifact
# - Which runs used this artifact
# - Parent/child artifacts
# The same information is queryable from code; see the sketch below.
```

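The public API can walk the same lineage programmatically; a minimal sketch using `Api.artifact` together with `logged_by()` and `used_by()` (the artifact path is a placeholder):

```python
import wandb

api = wandb.Api()

# Fetch a specific artifact version ("entity/my-project/..." is a placeholder path)
artifact = api.artifact("entity/my-project/trained-model:v0")

# The run that produced this artifact
producer = artifact.logged_by()
print(f"Produced by run: {producer.name}")

# Every run that has consumed it since
for run in artifact.used_by():
    print(f"Used by run: {run.name}")
```
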
## Artifact Types

### Dataset Artifacts

```python
# Raw data
raw_data = wandb.Artifact('raw-data', type='dataset')
raw_data.add_dir('raw/')

# Processed data
processed_data = wandb.Artifact('processed-data', type='dataset')
processed_data.add_dir('processed/')

# Train/val/test splits
train_split = wandb.Artifact('train-split', type='dataset')
train_split.add_file('train.csv')

val_split = wandb.Artifact('val-split', type='dataset')
val_split.add_file('val.csv')
```

### Model Artifacts

```python
# Checkpoint during training
checkpoint = wandb.Artifact('checkpoint-epoch-10', type='model')
checkpoint.add_file('checkpoint_epoch_10.pth')

# Final model
final_model = wandb.Artifact('final-model', type='model')
final_model.add_file('model.pth')
final_model.add_file('tokenizer.json')

# Quantized model
quantized = wandb.Artifact('quantized-model', type='model')
quantized.add_file('model_int8.onnx')
```

### Result Artifacts

```python
# Predictions
predictions = wandb.Artifact('test-predictions', type='predictions')
predictions.add_file('predictions.csv')

# Evaluation metrics
eval_results = wandb.Artifact('evaluation', type='evaluation')
eval_results.add_file('metrics.json')
eval_results.add_file('confusion_matrix.png')
```

## Advanced Patterns

### Incremental Artifacts

Add files incrementally without re-uploading.

```python
run = wandb.init(project="my-project")

# Create artifact
dataset = wandb.Artifact('incremental-dataset', type='dataset')

# Add files incrementally
for i in range(100):
    filename = f'batch_{i}.csv'
    process_batch(i, filename)
    dataset.add_file(filename)

    # Log progress
    if (i + 1) % 10 == 0:
        print(f"Added {i + 1}/100 batches")

# Log complete artifact
run.log_artifact(dataset)
```

### Artifact Tables

Track structured data with W&B Tables.

```python
import wandb

run = wandb.init(project="my-project")

# Create table
table = wandb.Table(columns=["id", "image", "label", "prediction"])

for idx, (img, label, pred) in enumerate(zip(images, labels, predictions)):
    table.add_data(
        idx,
        wandb.Image(img),
        label,
        pred
    )

# Log as artifact
artifact = wandb.Artifact('predictions-table', type='predictions')
artifact.add(table, "predictions")
run.log_artifact(artifact)
```

### Artifact References

Reference external data without copying.

```python
# S3 reference
dataset = wandb.Artifact('s3-dataset', type='dataset')
dataset.add_reference('s3://my-bucket/data/', name='train')
dataset.add_reference('s3://my-bucket/labels/', name='labels')

# GCS reference
dataset.add_reference('gs://my-bucket/data/')

# HTTP reference
dataset.add_reference('https://example.com/data.zip')

# Local filesystem reference (for shared storage)
dataset.add_reference('file:///mnt/shared/data')
```

## Collaboration Patterns

### Team Dataset Sharing

```python
# Data engineer creates dataset
run = wandb.init(project="data-eng", entity="my-team")
dataset = wandb.Artifact('shared-dataset', type='dataset')
dataset.add_dir('data/')
run.log_artifact(dataset, aliases=['latest', 'production'])

# ML engineer uses dataset
run = wandb.init(project="ml-training", entity="my-team")
dataset = run.use_artifact('my-team/data-eng/shared-dataset:production')
data = load_data(dataset.download())
```

### Model Handoff

```python
# Training team
train_run = wandb.init(project="model-training", entity="ml-team")
model = train_model()
model_artifact = wandb.Artifact('nlp-model', type='model')
model_artifact.add_file('model.pth')
train_run.log_artifact(model_artifact)
train_run.link_artifact(model_artifact, 'model-registry/nlp-models', aliases=['candidate'])

# Evaluation team
eval_run = wandb.init(project="model-eval", entity="ml-team")
model_artifact = eval_run.use_artifact('model-registry/nlp-models/nlp-model:candidate')
metrics = evaluate_model(model_artifact)

if metrics['f1'] > 0.9:
    # Promote to production
    model_artifact.link('model-registry/nlp-models', aliases=['production'])
```

## Best Practices

### 1. Use Descriptive Names

```python
# ✅ Good: Descriptive names
wandb.Artifact('imagenet-train-augmented-v2', type='dataset')
wandb.Artifact('bert-base-sentiment-finetuned', type='model')

# ❌ Bad: Generic names
wandb.Artifact('dataset1', type='dataset')
wandb.Artifact('model', type='model')
```

### 2. Add Comprehensive Metadata

```python
model_artifact = wandb.Artifact(
    'production-model',
    type='model',
    description='ResNet50 classifier for product categorization',
    metadata={
        # Model info
        'architecture': 'ResNet50',
        'framework': 'PyTorch 2.0',
        'pretrained': True,

        # Performance
        'accuracy': 0.95,
        'f1_score': 0.93,
        'inference_time_ms': 15,

        # Training
        'epochs': 50,
        'dataset': 'imagenet',
        'num_samples': 1200000,

        # Business context
        'use_case': 'e-commerce product classification',
        'owner': 'ml-team@company.com',
        'approved_by': 'data-science-lead'
    }
)
```

### 3. Use Aliases for Deployment Stages

```python
# Development
run.log_artifact(model, aliases=['dev', 'latest'])

# Staging
run.log_artifact(model, aliases=['staging'])

# Production
run.log_artifact(model, aliases=['production', 'v1.2.0'])

# Archive old versions
old_artifact = api.artifact('model:production')
old_artifact.aliases = ['archived-v1.1.0']
old_artifact.save()
```

### 4. Track Data Lineage

```python
def create_training_pipeline():
    run = wandb.init(project="pipeline")

    # 1. Load raw data
    raw_data = run.use_artifact('raw-data:latest')

    # 2. Preprocess
    processed = preprocess(raw_data)
    processed_artifact = wandb.Artifact('processed-data', type='dataset')
    processed_artifact.add_file('processed.csv')
    run.log_artifact(processed_artifact)

    # 3. Train model
    model = train(processed)
    model_artifact = wandb.Artifact('trained-model', type='model')
    model_artifact.add_file('model.pth')
    run.log_artifact(model_artifact)

    # Lineage: raw-data → processed-data → trained-model
```

### 5. Efficient Storage

```python
# ✅ Good: Reference large files
large_dataset = wandb.Artifact('large-dataset', type='dataset')
large_dataset.add_reference('s3://bucket/huge-file.tar.gz')

# ❌ Bad: Upload giant files
# large_dataset.add_file('huge-file.tar.gz')  # Don't do this

# ✅ Good: Upload only metadata
metadata_artifact = wandb.Artifact('dataset-metadata', type='dataset')
metadata_artifact.add_file('metadata.json')  # Small file
```

## Resources

- **Artifacts Documentation**: https://docs.wandb.ai/guides/artifacts
- **Model Registry**: https://docs.wandb.ai/guides/model-registry
- **Best Practices**: https://wandb.ai/site/articles/versioning-data-and-models-in-ml

@@ -0,0 +1,700 @@

# Framework Integrations Guide

Complete guide to integrating W&B with popular ML frameworks.

## Table of Contents
- HuggingFace Transformers
- PyTorch Lightning
- Keras/TensorFlow
- Fast.ai
- XGBoost/LightGBM
- PyTorch Native
- Custom Integrations

## HuggingFace Transformers

### Automatic Integration

```python
from transformers import Trainer, TrainingArguments
import wandb

# Initialize W&B
wandb.init(project="hf-transformers", name="bert-finetuning")

# Training arguments with W&B
training_args = TrainingArguments(
    output_dir="./results",
    report_to="wandb",  # Enable W&B logging
    run_name="bert-base-finetuning",

    # Training params
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,

    # Logging
    logging_dir="./logs",
    logging_steps=100,
    logging_first_step=True,

    # Evaluation
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,

    # Other
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy"
)

# Trainer automatically logs to W&B
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

# Train (metrics logged automatically)
trainer.train()

# Finish W&B run
wandb.finish()
```

### Custom Logging

```python
from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback
import wandb

class CustomWandbCallback(WandbCallback):
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        super().on_evaluate(args, state, control, metrics, **kwargs)

        # Log custom metrics
        wandb.log({
            "custom/eval_score": metrics["eval_accuracy"] * 100,
            "custom/epoch": state.epoch
        })

# Use custom callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[CustomWandbCallback()]
)
```

### Log Model to Registry

```python
from transformers import Trainer, TrainingArguments
import wandb

training_args = TrainingArguments(
    output_dir="./results",
    report_to="wandb",
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()

# Save final model as artifact
model_artifact = wandb.Artifact(
    'hf-bert-model',
    type='model',
    description='BERT finetuned on sentiment analysis'
)

# Save model files
trainer.save_model("./final_model")
model_artifact.add_dir("./final_model")

# Log artifact
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
wandb.finish()
```

## PyTorch Lightning

### Basic Integration

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

# Create W&B logger
wandb_logger = WandbLogger(
    project="lightning-demo",
    name="resnet50-training",
    log_model=True,  # Log model checkpoints as artifacts
    save_code=True   # Save code as artifact
)

# Lightning module
class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = create_model()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        # Log metrics (automatically sent to W&B)
        self.log('train/loss', loss, on_step=True, on_epoch=True)
        self.log('train/accuracy', accuracy(y_hat, y), on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        self.log('val/loss', loss, on_step=False, on_epoch=True)
        self.log('val/accuracy', accuracy(y_hat, y), on_epoch=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Trainer with W&B logger
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=10,
    accelerator="gpu",
    devices=1
)

# Train (metrics logged automatically)
trainer.fit(model, datamodule=dm)

# Finish W&B run
wandb.finish()
```

### Log Media

```python
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        # Log images (first batch only)
        if batch_idx == 0:
            self.logger.experiment.log({
                "examples": [wandb.Image(img) for img in x[:8]]
            })

        return loss

    def on_validation_epoch_end(self):
        # Log confusion matrix
        self.logger.experiment.log({
            "confusion_matrix": wandb.plot.confusion_matrix(
                probs=None,
                y_true=self.all_targets,
                preds=self.all_preds,
                class_names=self.class_names
            )
        })
```

### Hyperparameter Sweeps

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

# Define sweep
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
    'parameters': {
        'learning_rate': {'min': 1e-5, 'max': 1e-2, 'distribution': 'log_uniform_values'},
        'batch_size': {'values': [16, 32, 64]},
        'hidden_size': {'values': [128, 256, 512]}
    }
}

sweep_id = wandb.sweep(sweep_config, project="lightning-sweeps")

def train():
    # Initialize W&B
    run = wandb.init()

    # Get hyperparameters
    config = wandb.config

    # Create logger
    wandb_logger = WandbLogger()

    # Create model with sweep params
    model = LitModel(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size
    )

    # Create datamodule with sweep batch size
    dm = DataModule(batch_size=config.batch_size)

    # Train
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
    trainer.fit(model, dm)

# Run sweep
wandb.agent(sweep_id, function=train, count=30)
```

## Keras/TensorFlow

### With Callback

```python
import tensorflow as tf
from wandb.keras import WandbCallback
import wandb

# Initialize W&B
wandb.init(
    project="keras-demo",
    config={
        "learning_rate": 0.001,
        "epochs": 10,
        "batch_size": 32
    }
)

config = wandb.config

# Build model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(config.learning_rate),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train with W&B callback
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=config.epochs,
    batch_size=config.batch_size,
    callbacks=[
        WandbCallback(
            log_weights=True,    # Log model weights
            log_gradients=True,  # Log gradients
            training_data=(x_train, y_train),
            validation_data=(x_val, y_val),
            labels=class_names
        )
    ]
)

# Save model as artifact
model.save('model.h5')
artifact = wandb.Artifact('keras-model', type='model')
artifact.add_file('model.h5')
wandb.log_artifact(artifact)

wandb.finish()
```

### Custom Training Loop

```python
import tensorflow as tf
import wandb

wandb.init(project="tf-custom-loop")

# Model, optimizer, loss
model = create_model()
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(y, predictions)

# Training loop
for epoch in range(EPOCHS):
    train_loss.reset_states()
    train_accuracy.reset_states()

    for step, (x, y) in enumerate(train_dataset):
        train_step(x, y)

        # Log every 100 steps
        if step % 100 == 0:
            wandb.log({
                'train/loss': train_loss.result().numpy(),
                'train/accuracy': train_accuracy.result().numpy(),
                'epoch': epoch,
                'step': step
            })

    # Log epoch metrics
    wandb.log({
        'epoch/train_loss': train_loss.result().numpy(),
        'epoch/train_accuracy': train_accuracy.result().numpy(),
        'epoch': epoch
    })

wandb.finish()
```

## Fast.ai

### With Callback

```python
from fastai.vision.all import *
from fastai.callback.wandb import *
import wandb

# Initialize W&B
wandb.init(project="fastai-demo")

# Create data loaders
dls = ImageDataLoaders.from_folder(
    path,
    train='train',
    valid='valid',
    bs=64
)

# Create learner with W&B callback
learn = vision_learner(
    dls,
    resnet34,
    metrics=accuracy,
    cbs=WandbCallback(
        log_preds=True,    # Log predictions
        log_model=True,    # Log model as artifact
        log_dataset=True   # Log dataset as artifact
    )
)

# Train (metrics logged automatically)
learn.fine_tune(5)

wandb.finish()
```

## XGBoost/LightGBM

### XGBoost

```python
import xgboost as xgb
import wandb

# Initialize W&B
run = wandb.init(project="xgboost-demo", config={
    "max_depth": 6,
    "learning_rate": 0.1,
    "n_estimators": 100
})

config = wandb.config

# Create DMatrix
dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

# XGBoost params
params = {
    'max_depth': config.max_depth,
    'learning_rate': config.learning_rate,
    'objective': 'binary:logistic',
    'eval_metric': ['logloss', 'auc']
}

# Custom callback for W&B
# (legacy function-style callback; xgboost >= 1.3 prefers the
# TrainingCallback class shown in the sketch after this section)
def wandb_callback(env):
    """Log XGBoost metrics to W&B."""
    for metric_name, metric_value in env.evaluation_result_list:
        wandb.log({
            f"{metric_name}": metric_value,
            "iteration": env.iteration
        })

# Train with callback
model = xgb.train(
    params,
    dtrain,
    num_boost_round=config.n_estimators,
    evals=[(dtrain, 'train'), (dval, 'val')],
    callbacks=[wandb_callback],
    verbose_eval=10
)

# Save model
model.save_model('xgboost_model.json')
artifact = wandb.Artifact('xgboost-model', type='model')
artifact.add_file('xgboost_model.json')
wandb.log_artifact(artifact)

wandb.finish()
```

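On newer xgboost releases (1.3+), bare-function callbacks are rejected; the supported path is subclassing `xgboost.callback.TrainingCallback`. A minimal sketch of the same logging, treating the exact `evals_log` layout as an assumption worth checking against your xgboost version:

```python
import wandb
import xgboost as xgb

class WandbXGBCallback(xgb.callback.TrainingCallback):
    """Log per-iteration eval metrics to W&B (modern callback API)."""

    def after_iteration(self, model, epoch, evals_log):
        # evals_log is nested: {dataset: {metric: [values per iteration]}}
        for dataset, metrics in evals_log.items():
            for metric_name, values in metrics.items():
                wandb.log({
                    f"{dataset}/{metric_name}": values[-1],
                    "iteration": epoch
                })
        return False  # Returning True would stop training


# Usage: pass an instance instead of a function
# model = xgb.train(params, dtrain,
#                   evals=[(dtrain, 'train'), (dval, 'val')],
#                   callbacks=[WandbXGBCallback()])
```
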
### LightGBM

```python
import lightgbm as lgb
import wandb

run = wandb.init(project="lgbm-demo")

# Create datasets
train_data = lgb.Dataset(X_train, label=y_train)
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

# Parameters
params = {
    'objective': 'binary',
    'metric': ['binary_logloss', 'auc'],
    'learning_rate': 0.1,
    'num_leaves': 31
}

# Custom callback
def log_to_wandb(env):
    """Log LightGBM metrics to W&B."""
    for entry in env.evaluation_result_list:
        dataset_name, metric_name, metric_value, _ = entry
        wandb.log({
            f"{dataset_name}/{metric_name}": metric_value,
            "iteration": env.iteration
        })

# Train
model = lgb.train(
    params,
    train_data,
    num_boost_round=100,
    valid_sets=[train_data, val_data],
    valid_names=['train', 'val'],
    callbacks=[log_to_wandb]
)

# Save model
model.save_model('lgbm_model.txt')
artifact = wandb.Artifact('lgbm-model', type='model')
artifact.add_file('lgbm_model.txt')
wandb.log_artifact(artifact)

wandb.finish()
```

## PyTorch Native

### Training Loop Integration

```python
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

# Initialize W&B
wandb.init(project="pytorch-native", config={
    "learning_rate": 0.001,
    "epochs": 10,
    "batch_size": 32
})

config = wandb.config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model, loss, optimizer
model = create_model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# Watch model (logs gradients and parameters)
wandb.watch(model, criterion, log="all", log_freq=100)

# Training loop
for epoch in range(config.epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        # Log every 100 batches
        if batch_idx % 100 == 0:
            wandb.log({
                'train/loss': loss.item(),
                'train/batch_accuracy': 100. * correct / total,
                'epoch': epoch,
                'batch': batch_idx
            })

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            val_loss += loss.item()
            _, predicted = output.max(1)
            val_total += target.size(0)
            val_correct += predicted.eq(target).sum().item()

    # Log epoch metrics
    wandb.log({
        'epoch/train_loss': train_loss / len(train_loader),
        'epoch/train_accuracy': 100. * correct / total,
        'epoch/val_loss': val_loss / len(val_loader),
        'epoch/val_accuracy': 100. * val_correct / val_total,
        'epoch': epoch
    })

# Save final model
torch.save(model.state_dict(), 'model.pth')
artifact = wandb.Artifact('final-model', type='model')
artifact.add_file('model.pth')
wandb.log_artifact(artifact)

wandb.finish()
```

## Custom Integrations

### Generic Framework Integration

```python
import wandb

class WandbIntegration:
    """Generic W&B integration wrapper."""

    def __init__(self, project, config):
        self.run = wandb.init(project=project, config=config)
        self.config = wandb.config
        self.step = 0

    def log_metrics(self, metrics, step=None):
        """Log training metrics."""
        if step is None:
            step = self.step
            self.step += 1

        wandb.log(metrics, step=step)

    def log_images(self, images, caption=""):
        """Log images."""
        wandb.log({
            caption: [wandb.Image(img) for img in images]
        })

    def log_table(self, data, columns):
        """Log tabular data."""
        table = wandb.Table(columns=columns, data=data)
        wandb.log({"table": table})

    def save_model(self, model_path, metadata=None):
        """Save model as artifact."""
        artifact = wandb.Artifact(
            'model',
            type='model',
            metadata=metadata or {}
        )
        artifact.add_file(model_path)
        self.run.log_artifact(artifact)

    def finish(self):
        """Finish W&B run."""
        wandb.finish()

# Usage
wb = WandbIntegration(project="my-project", config={"lr": 0.001})

# Training loop
for epoch in range(10):
    # Your training code
    loss, accuracy = train_epoch()

    # Log metrics
    wb.log_metrics({
        'train/loss': loss,
        'train/accuracy': accuracy
    })

# Save model
wb.save_model('model.pth', metadata={'accuracy': 0.95})
wb.finish()
```

## Resources

- **Integrations Guide**: https://docs.wandb.ai/guides/integrations
- **HuggingFace**: https://docs.wandb.ai/guides/integrations/huggingface
- **PyTorch Lightning**: https://docs.wandb.ai/guides/integrations/lightning
- **Keras**: https://docs.wandb.ai/guides/integrations/keras
- **Examples**: https://github.com/wandb/examples

@@ -0,0 +1,847 @@

# Comprehensive Hyperparameter Sweeps Guide

Complete guide to hyperparameter optimization with W&B Sweeps.

## Table of Contents
- Sweep Configuration
- Search Strategies
- Parameter Distributions
- Early Termination
- Parallel Execution
- Advanced Patterns
- Real-World Examples

## Sweep Configuration

### Basic Sweep Config

```python
sweep_config = {
    'method': 'bayes',  # Search strategy
    'metric': {
        'name': 'val/accuracy',
        'goal': 'maximize'  # or 'minimize'
    },
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        'batch_size': {
            'values': [16, 32, 64, 128]
        }
    }
}

# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project="my-project")
```

### Complete Config Example

```python
sweep_config = {
    # Required: Search method
    'method': 'bayes',

    # Required: Optimization metric
    'metric': {
        'name': 'val/f1_score',
        'goal': 'maximize'
    },

    # Required: Parameters to search
    'parameters': {
        # Continuous parameter
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },

        # Discrete values
        'batch_size': {
            'values': [16, 32, 64, 128]
        },

        # Categorical
        'optimizer': {
            'values': ['adam', 'sgd', 'rmsprop', 'adamw']
        },

        # Uniform distribution
        'dropout': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.5
        },

        # Integer range
        'num_layers': {
            'distribution': 'int_uniform',
            'min': 2,
            'max': 10
        },

        # Fixed value (constant across runs)
        'epochs': {
            'value': 50
        }
    },

    # Optional: Early termination
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 5,
        's': 2,
        'eta': 3,
        'max_iter': 27
    }
}
```

## Search Strategies

### 1. Grid Search

Exhaustively search all combinations.

```python
sweep_config = {
    'method': 'grid',
    'parameters': {
        'learning_rate': {
            'values': [0.001, 0.01, 0.1]
        },
        'batch_size': {
            'values': [16, 32, 64]
        },
        'optimizer': {
            'values': ['adam', 'sgd']
        }
    }
}

# Total runs: 3 × 3 × 2 = 18 runs
```

**Pros:**
- Comprehensive search
- Reproducible results
- No randomness

**Cons:**
- Exponential growth with parameters
- Inefficient for continuous parameters
- Not scalable beyond 3-4 parameters

**When to use:**
- Few parameters (< 4)
- All discrete values
- Need complete coverage

### 2. Random Search

Randomly sample parameter combinations.

```python
sweep_config = {
    'method': 'random',
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        'batch_size': {
            'values': [16, 32, 64, 128, 256]
        },
        'dropout': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.5
        },
        'num_layers': {
            'distribution': 'int_uniform',
            'min': 2,
            'max': 8
        }
    }
}

# Run 100 random trials
wandb.agent(sweep_id, function=train, count=100)
```

**Pros:**
- Scales to many parameters
- Can run indefinitely
- Often finds good solutions quickly

**Cons:**
- No learning from previous runs
- May miss the optimal region
- Results vary with random seed

**When to use:**
- Many parameters (> 4)
- Quick exploration
- Limited budget

### 3. Bayesian Optimization (Recommended)

Learn from previous trials to sample promising regions.

```python
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val/loss',
        'goal': 'minimize'
    },
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 1e-2
        },
        'dropout': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.5
        },
        'num_layers': {
            'values': [2, 3, 4, 5, 6]
        }
    }
}
```

**Pros:**
- Most sample-efficient
- Learns from past trials
- Focuses on promising regions

**Cons:**
- Initial random exploration phase
- May get stuck in local optima
- Slower per iteration

**When to use:**
- Expensive training runs
- Need best performance
- Limited compute budget

## Parameter Distributions

### Continuous Distributions

```python
# Log-uniform over actual values: good for learning rates, regularization
'learning_rate': {
    'distribution': 'log_uniform_values',
    'min': 1e-6,
    'max': 1e-1
}

# Uniform: good for dropout, momentum
'dropout': {
    'distribution': 'uniform',
    'min': 0.0,
    'max': 0.5
}

# Normal distribution
'parameter': {
    'distribution': 'normal',
    'mu': 0.5,
    'sigma': 0.1
}

# Log-normal distribution
'parameter': {
    'distribution': 'log_normal',
    'mu': 0.0,
    'sigma': 1.0
}
```

### Discrete Distributions

```python
# Fixed values
'batch_size': {
    'values': [16, 32, 64, 128, 256]
}

# Integer uniform
'num_layers': {
    'distribution': 'int_uniform',
    'min': 2,
    'max': 10
}

# Quantized uniform (step size)
'layer_size': {
    'distribution': 'q_uniform',
    'min': 32,
    'max': 512,
    'q': 32  # Step by 32: 32, 64, 96, 128...
}

# Quantized log-uniform over actual values
'hidden_size': {
    'distribution': 'q_log_uniform_values',
    'min': 32,
    'max': 1024,
    'q': 32
}
```

### Categorical Parameters

```python
# Optimizers
'optimizer': {
    'values': ['adam', 'sgd', 'rmsprop', 'adamw']
}

# Model architectures
'model': {
    'values': ['resnet18', 'resnet34', 'resnet50', 'efficientnet_b0']
}

# Activation functions
'activation': {
    'values': ['relu', 'gelu', 'silu', 'leaky_relu']
}
```

## Early Termination

Stop underperforming runs early to save compute.

### Hyperband

```python
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
    'parameters': {...},

    # Hyperband early termination
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 3,   # Minimum iterations before termination
        's': 2,          # Bracket count
        'eta': 3,        # Downsampling rate
        'max_iter': 27   # Maximum iterations
    }
}
```

**How it works:**
- Runs trials in brackets
- Keeps the top 1/eta performers each round
- Eliminates the bottom performers early (see the worked schedule below)

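A worked example of the schedule this config implies. With `eta = 3` and `max_iter = 27`, evaluation checkpoints fall at iterations 3, 9, and 27, and each round keeps roughly the top 1/3 of surviving runs. A small sketch computing that (pure arithmetic, no W&B calls; the starting count of 27 runs is illustrative):

```python
# Hyperband-style schedule for min_iter=3, eta=3, max_iter=27
min_iter, eta, max_iter = 3, 3, 27

checkpoint = min_iter
survivors = 27  # assume 27 runs start (illustrative)
while checkpoint <= max_iter:
    print(f"at iteration {checkpoint}: {survivors} runs still training")
    survivors = max(1, survivors // eta)  # keep ~top 1/eta each round
    checkpoint *= eta

# Output:
# at iteration 3: 27 runs still training
# at iteration 9: 9 runs still training
# at iteration 27: 3 runs still training
```
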
### Custom Termination

```python
def train():
    run = wandb.init()

    best_acc = 0.0
    for epoch in range(MAX_EPOCHS):
        loss = train_epoch()
        val_acc = validate()

        wandb.log({'val/accuracy': val_acc, 'epoch': epoch})

        # Custom early stopping
        if epoch > 5 and val_acc < 0.5:
            print("Early stop: Poor performance")
            break

        if epoch > 10 and val_acc < best_acc + 0.01:
            print("Early stop: No improvement")
            break

        best_acc = max(best_acc, val_acc)
```

A reusable patience-based variant is sketched below.

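The single-epoch check above stops at the first stagnant epoch; in practice you usually wait out a patience window. A minimal sketch of a reusable helper (the class and its names are illustrative, not a wandb API):

```python
class EarlyStopper:
    """Stop when the metric hasn't improved by min_delta for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.stale_epochs = 0

    def should_stop(self, value):
        if value > self.best + self.min_delta:
            self.best = value
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Usage inside the training function:
# stopper = EarlyStopper(patience=5)
# if stopper.should_stop(val_acc):
#     break
```
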
## Training Function

### Basic Template

```python
def train():
    # Initialize W&B run
    run = wandb.init()

    # Get hyperparameters
    config = wandb.config

    # Build model with config
    model = build_model(
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        dropout=config.dropout
    )

    # Create optimizer
    optimizer = create_optimizer(
        model.parameters(),
        name=config.optimizer,
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Training loop
    for epoch in range(config.epochs):
        # Train
        train_loss, train_acc = train_epoch(
            model, optimizer, train_loader, config.batch_size
        )

        # Validate
        val_loss, val_acc = validate(model, val_loader)

        # Log metrics
        wandb.log({
            'train/loss': train_loss,
            'train/accuracy': train_acc,
            'val/loss': val_loss,
            'val/accuracy': val_acc,
            'epoch': epoch
        })

    # Log final model
    torch.save(model.state_dict(), 'model.pth')
    wandb.save('model.pth')

    # Finish run
    wandb.finish()
```

### With PyTorch

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train():
    run = wandb.init()
    config = wandb.config

    # Data
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True
    )

    # Model
    model = ResNet(
        num_classes=config.num_classes,
        dropout=config.dropout
    ).to(device)

    # Optimizer
    if config.optimizer == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
    elif config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay
        )

    # Scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.epochs
    )

    # Training
    for epoch in range(config.epochs):
        model.train()
        train_loss = 0.0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss, val_acc = validate(model, val_loader)

        # Step scheduler
        scheduler.step()

        # Log
        wandb.log({
            'train/loss': train_loss / len(train_loader),
            'val/loss': val_loss,
            'val/accuracy': val_acc,
            'learning_rate': scheduler.get_last_lr()[0],
            'epoch': epoch
        })
```

## Parallel Execution

### Multiple Agents

Run multiple sweep agents in parallel to speed up the search. Each agent requests its next configuration from the central sweep controller, so agents never duplicate work.

```python
# Initialize the sweep once
sweep_id = wandb.sweep(sweep_config, project="my-project")

# Then start one agent per terminal, all pointing at the same sweep_id:

# Agent 1 (Terminal 1)
wandb.agent(sweep_id, function=train, count=20)

# Agent 2 (Terminal 2)
wandb.agent(sweep_id, function=train, count=20)

# Agent 3 (Terminal 3)
wandb.agent(sweep_id, function=train, count=20)

# Total: 60 runs across 3 agents
```

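Agents can also be launched from a single script with one process each; a minimal sketch using `multiprocessing` (the process count of 3 is arbitrary, and `train`/`sweep_config` are assumed to be defined at module level):

```python
import multiprocessing as mp
import wandb

def run_agent(sweep_id):
    # Each process hosts one agent pulling runs from the shared sweep
    wandb.agent(sweep_id, function=train, count=20)

if __name__ == '__main__':
    # 'spawn' gives each agent a clean interpreter state
    mp.set_start_method('spawn')
    sweep_id = wandb.sweep(sweep_config, project="my-project")
    procs = [mp.Process(target=run_agent, args=(sweep_id,)) for _ in range(3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```
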
### Multi-GPU Execution

```python
import torch
import wandb

def train():
    run = wandb.init()
    config = wandb.config

    # CUDA_VISIBLE_DEVICES restricts which GPU this process can see,
    # and PyTorch renumbers the visible device as cuda:0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # ... rest of training ...

# Run one agent per GPU, each in its own terminal:
# Terminal 1: CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
# Terminal 2: CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
# Terminal 3: CUDA_VISIBLE_DEVICES=2 wandb agent sweep_id
```

## Advanced Patterns

### Nested Parameters

```python
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
    'parameters': {
        'model': {
            'parameters': {
                'type': {
                    'values': ['resnet', 'efficientnet']
                },
                'size': {
                    'values': ['small', 'medium', 'large']
                }
            }
        },
        'optimizer': {
            'parameters': {
                'type': {
                    'values': ['adam', 'sgd']
                },
                'lr': {
                    # log_uniform_values takes actual bounds;
                    # plain log_uniform expects exponents
                    'distribution': 'log_uniform_values',
                    'min': 1e-5,
                    'max': 1e-1
                }
            }
        }
    }
}

# Access nested config (nested values come back as plain dicts,
# so use key access at the inner level)
def train():
    run = wandb.init()
    model_type = wandb.config.model['type']
    model_size = wandb.config.model['size']
    opt_type = wandb.config.optimizer['type']
    lr = wandb.config.optimizer['lr']
```

### Conditional Parameters

```python
sweep_config = {
    'method': 'bayes',
    'parameters': {
        'optimizer': {
            'values': ['adam', 'sgd']
        },
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-1
        },
        # Only used if optimizer == 'sgd'
        'momentum': {
            'distribution': 'uniform',
            'min': 0.5,
            'max': 0.99
        }
    }
}

def train():
    run = wandb.init()
    config = wandb.config

    if config.optimizer == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.learning_rate
        )
    elif config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum  # Conditional parameter
        )
```

## Real-World Examples

### Image Classification

```python
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val/top1_accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        # Model
        'architecture': {
            'values': ['resnet50', 'resnet101', 'efficientnet_b0', 'efficientnet_b3']
        },
        'pretrained': {
            'values': [True, False]
        },

        # Training
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-2
        },
        'batch_size': {
            'values': [16, 32, 64, 128]
        },
        'optimizer': {
            'values': ['adam', 'sgd', 'adamw']
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 1e-2
        },

        # Regularization
        'dropout': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.5
        },
        'label_smoothing': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.2
        },

        # Data augmentation
        'mixup_alpha': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 1.0
        },
        'cutmix_alpha': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 1.0
        }
    },
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 5
    }
}
```

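The augmentation parameters only matter if the training loop consumes them. As an illustration (the helper below is hypothetical, following the standard mixup formulation), `mixup_alpha` controls the Beta distribution that blends pairs of examples:

```python
import numpy as np
import torch

def mixup_batch(x, y, alpha):
    """Standard mixup: blend pairs of examples with a Beta-sampled weight."""
    if alpha <= 0:
        return x, y, y, 1.0
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    return mixed_x, y, y[idx], lam

# Loss for a mixed batch:
# lam * criterion(output, y_a) + (1 - lam) * criterion(output, y_b)
```
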
### NLP Fine-Tuning

```python
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'eval/f1', 'goal': 'maximize'},
    'parameters': {
        # Model
        'model_name': {
            'values': ['bert-base-uncased', 'roberta-base', 'distilbert-base-uncased']
        },

        # Training
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 1e-4
        },
        'per_device_train_batch_size': {
            'values': [8, 16, 32]
        },
        'num_train_epochs': {
            'values': [3, 4, 5]
        },
        'warmup_ratio': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.1
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-4,
            'max': 1e-1
        },

        # Optimizer
        'adam_beta1': {
            'distribution': 'uniform',
            'min': 0.8,
            'max': 0.95
        },
        'adam_beta2': {
            'distribution': 'uniform',
            'min': 0.95,
            'max': 0.999
        }
    }
}
```

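These parameter names map directly onto `transformers.TrainingArguments`; a minimal sketch of a matching train function, where `train_ds`, `eval_ds`, and `compute_metrics` are assumed to be defined elsewhere:

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
import wandb

def train():
    run = wandb.init()
    config = wandb.config

    model = AutoModelForSequenceClassification.from_pretrained(config.model_name)

    # Each swept value feeds the corresponding TrainingArguments field
    args = TrainingArguments(
        output_dir='out',
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.per_device_train_batch_size,
        num_train_epochs=config.num_train_epochs,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        report_to='wandb'  # Logs eval metrics (e.g. eval/f1) to the sweep
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        compute_metrics=compute_metrics
    )
    trainer.train()
```
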
## Best Practices

### 1. Start Small

```python
# Initial exploration: random search, 20 runs
sweep_config_v1 = {
    'method': 'random',
    'parameters': {...}
}
sweep_id_v1 = wandb.sweep(sweep_config_v1, project="my-project")
wandb.agent(sweep_id_v1, train, count=20)

# Refined search: Bayes over narrowed ranges
sweep_config_v2 = {
    'method': 'bayes',
    'parameters': {
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 5e-5,  # Narrowed from the initial 1e-6 to 1e-4 range
            'max': 1e-4
        }
    }
}
```

### 2. Use Log Scales

```python
# ✅ Good: log scale for learning rate samples each decade equally
'learning_rate': {
    'distribution': 'log_uniform_values',
    'min': 1e-6,
    'max': 1e-2
}

# ❌ Bad: linear scale concentrates almost all samples in the top decade
'learning_rate': {
    'distribution': 'uniform',
    'min': 0.000001,
    'max': 0.01
}
```

### 3. Set Reasonable Ranges

```python
# Base ranges on prior knowledge
'learning_rate': {'min': 1e-5, 'max': 1e-3},  # Typical for Adam
'batch_size': {'values': [16, 32, 64]},       # GPU memory limits
'dropout': {'min': 0.1, 'max': 0.5}           # Too high hurts training
```

### 4. Monitor Resource Usage

```python
def train():
    run = wandb.init()

    # Log system metrics alongside training metrics
    wandb.log({
        'system/gpu_memory_allocated': torch.cuda.memory_allocated(),
        'system/gpu_memory_reserved': torch.cuda.memory_reserved()
    })
```

### 5. Save Best Models

```python
def train():
    run = wandb.init()
    config = wandb.config
    best_acc = 0.0

    for epoch in range(config.epochs):
        val_acc = validate(model)

        if val_acc > best_acc:
            best_acc = val_acc
            # Save best checkpoint
            torch.save(model.state_dict(), 'best_model.pth')
            wandb.save('best_model.pth')
```

## Resources

- **Sweeps Documentation**: https://docs.wandb.ai/guides/sweeps
- **Configuration Reference**: https://docs.wandb.ai/guides/sweeps/configuration
- **Examples**: https://github.com/wandb/examples/tree/master/examples/wandb-sweeps