Add stuck initiatives audit report

protected/skills-backup/mlops/training/DESCRIPTION.md (new file, 3 lines)
@@ -0,0 +1,3 @@
---
description: Fine-tuning, RLHF/DPO/GRPO training, distributed training frameworks, and optimization tools for training LLMs and other models.
---

protected/skills-backup/mlops/training/accelerate/SKILL.md (new file, 335 lines)
@@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
  hermes:
    tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---

# HuggingFace Accelerate - Unified Distributed Training

## Quick start

Accelerate simplifies distributed training to 4 lines of code.

**Installation**:
```bash
pip install accelerate
```

**Convert a PyTorch script** (4 lines, shown diff-style):
```python
  import torch
+ from accelerate import Accelerator

+ accelerator = Accelerator()

  model = torch.nn.Transformer()
  optimizer = torch.optim.Adam(model.parameters())
  dataloader = torch.utils.data.DataLoader(dataset)

+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

  for batch in dataloader:
      optimizer.zero_grad()
      loss = model(batch)
-     loss.backward()
+     accelerator.backward(loss)
      optimizer.step()
```

**Run** (single command):
```bash
accelerate launch train.py
```

## Common workflows

### Workflow 1: From single GPU to multi-GPU

**Original script**:
```python
# train.py
import torch

model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for epoch in range(10):
    for batch in dataloader:
        batch = batch.to('cuda')
        optimizer.zero_grad()
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
```

**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator  # +1

accelerator = Accelerator()  # +2

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)  # +3

for epoch in range(10):
    for batch in dataloader:
        # No .to('cuda') needed - automatic!
        optimizer.zero_grad()
        loss = model(batch).mean()
        accelerator.backward(loss)  # +4
        optimizer.step()
```

**Configure** (interactive):
```bash
accelerate config
```

**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)

**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py

# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py

# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 0 \
    --main_process_ip $MASTER_ADDR \
    train.py
```

### Workflow 2: Mixed precision training

**Enable FP16/BF16**:
```python
from accelerate import Accelerator

# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')

# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')

# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Everything else is automatic!
for batch in dataloader:
    with accelerator.autocast():  # Optional - applied automatically after prepare()
        loss = model(batch)
    accelerator.backward(loss)
```

### Workflow 3: DeepSpeed ZeRO integration

**Enable DeepSpeed ZeRO-2** (via `DeepSpeedPlugin`; `Accelerator` expects a plugin object, not a raw dict):
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    zero_stage=2,                     # ZeRO-2
    offload_optimizer_device="none",  # keep optimizer state on GPU
    gradient_accumulation_steps=4,
)

accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=ds_plugin,
)

# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```

**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```

**deepspeed_config.json**:
```json
{
  "fp16": {"enabled": false},
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {"device": "cpu"},
    "allgather_bucket_size": 5e8,
    "reduce_bucket_size": 5e8
  }
}
```

**Launch** (the JSON is wired in through your `accelerate` config file, e.g. a `deepspeed_config_file` entry, rather than passed directly):
```bash
accelerate launch --config_file accelerate_config.yaml train.py
```

### Workflow 4: FSDP (Fully Sharded Data Parallel)

**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",            # ZeRO-3 equivalent
    auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
    cpu_offload=False
)

accelerator = Accelerator(
    mixed_precision='bf16',
    fsdp_plugin=fsdp_plugin
)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```

**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```

### Workflow 5: Gradient accumulation

**Accumulate gradients**:
```python
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):  # Handles accumulation and gradient sync
        optimizer.zero_grad()
        loss = model(batch)
        accelerator.backward(loss)
        optimizer.step()
```

**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
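
For instance, with a per-device batch size of 8 on 4 GPUs and `gradient_accumulation_steps=4`, each optimizer step covers 128 samples:

```python
batch_size, num_gpus, grad_accum_steps = 8, 4, 4
effective_batch = batch_size * num_gpus * grad_accum_steps  # 128
```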

## When to use vs alternatives

**Use Accelerate when you**:
- Want the simplest route to distributed training
- Need a single script that runs on any hardware
- Use the HuggingFace ecosystem
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- Need quick prototyping

**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere

**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks and high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction

## Common issues

**Issue: Wrong device placement**

Don't move tensors to the device manually:
```python
# WRONG
batch = batch.to('cuda')

# CORRECT
# Accelerate places batches automatically after prepare()
```

**Issue: Gradient accumulation not working**

Use the context manager:
```python
# CORRECT
with accelerator.accumulate(model):
    optimizer.zero_grad()
    accelerator.backward(loss)
    optimizer.step()
```

**Issue: Checkpointing in distributed training**

Use the accelerator's methods:
```python
# Save (call on every process; Accelerate coordinates the writes)
accelerator.save_state('checkpoint/')

# Load on all processes
accelerator.load_state('checkpoint/')

# For ordinary file writes, guard with the main process only
if accelerator.is_main_process:
    ...
```

**Issue: Different results with FSDP**

Ensure the same random seed on every process:
```python
from accelerate.utils import set_seed
set_seed(42)
```

## Advanced topics

**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.

**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.

**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.

## Hardware requirements

- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported

**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup

## Resources

- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, and other HF libraries

protected/skills-backup/mlops/training/accelerate/references/custom-plugins.md (new file, 453 lines)
@@ -0,0 +1,453 @@
# Custom Plugins for Accelerate

## Overview

Accelerate lets you create **custom plugins** that extend its distributed training strategies beyond the built-in options (DDP, FSDP, DeepSpeed).

## Plugin Architecture

### Base Plugin Structure

```python
from dataclasses import dataclass

@dataclass
class CustomPlugin:
    """Custom training plugin."""

    # Plugin configuration
    param1: int = 1
    param2: str = "default"

    def __post_init__(self):
        # Validation logic
        if self.param1 < 1:
            raise ValueError("param1 must be >= 1")
```

### Using Custom Plugin

```python
from accelerate import Accelerator

# Create the plugin
custom_plugin = CustomPlugin(param1=4, param2="value")

# Pass it to the Accelerator
accelerator = Accelerator(
    custom_plugin=custom_plugin  # Illustrative only - not a real Accelerator parameter;
                                 # the real extension points are kwargs_handlers and the *_plugin arguments below
)
```

## Built-In Plugin Examples

### 1. GradScalerKwargs (FP16 Configuration)

```python
from accelerate import Accelerator
from accelerate.utils import GradScalerKwargs

# Configure the gradient scaler for FP16
scaler_kwargs = GradScalerKwargs(
    init_scale=2.**16,     # Initial loss scale
    growth_factor=2.0,     # Scale growth rate
    backoff_factor=0.5,    # Scale backoff rate
    growth_interval=2000,  # Steps between scale increases
    enabled=True           # Enable the scaler
)

accelerator = Accelerator(
    mixed_precision='fp16',
    kwargs_handlers=[scaler_kwargs]  # Pass as a kwargs handler
)
```

**Use case**: Fine-tune FP16 gradient scaling behavior.

### 2. DistributedDataParallelKwargs

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Configure DDP behavior
ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,              # Gradient bucketing size
    find_unused_parameters=False,  # Searching for unused params is slower
    check_reduction=False,         # Check gradient reduction
    gradient_as_bucket_view=True,  # Memory optimization
    static_graph=False             # Static computation graph
)

accelerator = Accelerator(
    kwargs_handlers=[ddp_kwargs]
)
```

**Use case**: Optimize DDP performance for specific models.

### 3. FP8RecipeKwargs (H100 FP8)

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Configure FP8 training (H100)
fp8_recipe = FP8RecipeKwargs(
    backend="te",             # TransformerEngine backend
    margin=0,                 # Scaling margin
    interval=1,               # Scaling interval
    fp8_format="HYBRID",      # E4M3 + E5M2 hybrid
    amax_history_len=1024,    # AMAX history length
    amax_compute_algo="max"   # AMAX computation algorithm
)

accelerator = Accelerator(
    mixed_precision='fp8',
    kwargs_handlers=[fp8_recipe]
)
```

**Use case**: Ultra-fast training on H100 GPUs.

## Custom DeepSpeed Configuration

### ZeRO-3 with CPU Offload

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Custom DeepSpeed config
ds_plugin = DeepSpeedPlugin(
    zero_stage=3,                    # ZeRO-3
    offload_optimizer_device="cpu",  # Offload optimizer state to CPU
    offload_param_device="cpu",      # Offload parameters to CPU
    zero3_init_flag=True,            # ZeRO-3 initialization
    zero3_save_16bit_model=True,     # Save FP16 weights
)

accelerator = Accelerator(
    deepspeed_plugin=ds_plugin,
    mixed_precision='bf16'
)
```

### ZeRO-2 with NVMe Offload

```python
ds_plugin = DeepSpeedPlugin(
    zero_stage=2,
    offload_optimizer_device="nvme",  # NVMe offload
    offload_param_device="nvme",
    nvme_path="/local_nvme",          # NVMe mount path
)
```

### Custom JSON Config

```python
import json

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Load a custom DeepSpeed config
with open('deepspeed_config.json', 'r') as f:
    ds_config = json.load(f)

ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)

accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```

**Example config** (`deepspeed_config.json`):
```json
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": {
    "enabled": true
  },
  "steps_per_print": 100,
  "wall_clock_breakdown": false
}
```

## Custom FSDP Configuration

### FSDP with Custom Auto-Wrap Policy

```python
import functools

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Custom wrap policy (size-based)
wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1e6  # Wrap layers with 1M+ params
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # ZeRO-3 equivalent
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # Prefetch strategy
    mixed_precision_policy=None,                      # Use the Accelerator's mixed precision
    auto_wrap_policy=wrap_policy,                     # Custom wrapping
    cpu_offload=False,
    ignored_modules=None,                             # Modules to exclude from wrapping
    state_dict_type="FULL_STATE_DICT",                # Save format
    optim_state_dict_config=None,
    limit_all_gathers=False,
    use_orig_params=True,                             # Keep original parameter shapes
)

accelerator = Accelerator(
    fsdp_plugin=fsdp_plugin,
    mixed_precision='bf16'
)
```

### FSDP with Transformer Auto-Wrap

```python
import functools

from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

# Wrap at the transformer block level
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block}  # Wrap GPT2Block layers
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrap_policy
)
```

## Creating Custom Training Strategy

### Example: Custom Gradient Accumulation

```python
from accelerate import Accelerator

class CustomGradientAccumulation:
    def __init__(self, steps=4, adaptive=False, threshold=10.0):
        self.steps = steps
        self.adaptive = adaptive
        self.threshold = threshold  # loss level that forces an early sync
        self.current_step = 0

    def should_sync(self, loss):
        """Decide whether to sync gradients."""
        self.current_step += 1

        # Adaptive: sync early on high loss
        if self.adaptive and loss > self.threshold:
            self.current_step = 0
            return True

        # Regular: sync every N steps
        if self.current_step >= self.steps:
            self.current_step = 0
            return True

        return False

# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss

    # Scale the loss by the accumulation window
    loss = loss / custom_accum.steps
    accelerator.backward(loss)

    # Conditional sync
    if custom_accum.should_sync(loss.item()):
        optimizer.step()
        optimizer.zero_grad()
```

### Example: Custom Mixed Precision

```python
import torch

class CustomMixedPrecision:
    """Custom mixed precision with dynamic loss scaling."""

    def __init__(self, init_scale=2**16, scale_window=2000):
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=init_scale,
            growth_interval=scale_window
        )
        self.scale_history = []

    def scale_loss(self, loss):
        """Scale the loss for backward."""
        return self.scaler.scale(loss)

    def unscale_and_clip(self, optimizer, max_norm=1.0):
        """Unscale gradients and clip them."""
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            optimizer.param_groups[0]['params'],
            max_norm
        )

    def step(self, optimizer):
        """Optimizer step with a scaler update."""
        scale_before = self.scaler.get_scale()
        self.scaler.step(optimizer)
        self.scaler.update()
        scale_after = self.scaler.get_scale()

        # Track scale changes
        if scale_before != scale_after:
            self.scale_history.append(scale_after)

# Usage
custom_mp = CustomMixedPrecision()

for batch in dataloader:
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(**batch).loss

    scaled_loss = custom_mp.scale_loss(loss)
    scaled_loss.backward()

    custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
    custom_mp.step(optimizer)
    optimizer.zero_grad()
```

## Advanced: Custom Distributed Backend

### Custom AllReduce Strategy

```python
import torch
import torch.distributed as dist

class CustomAllReduce:
    """Custom all-reduce with top-k gradient compression (sketch)."""

    def __init__(self, compression_ratio=0.1):
        self.compression_ratio = compression_ratio

    def compress_gradients(self, tensor):
        """Top-k gradient compression: keep the largest-magnitude entries."""
        k = max(1, int(tensor.numel() * self.compression_ratio))
        _, indices = torch.topk(tensor.abs().view(-1), k)
        values = tensor.view(-1)[indices]  # keep the signed values, not their magnitudes
        return values, indices

    def all_reduce_compressed(self, tensor):
        """All-reduce with gradient compression.

        Note: a production implementation must reconcile indices across
        ranks (e.g. all-gather them), since each rank picks its own top-k.
        """
        # Compress
        values, indices = self.compress_gradients(tensor)

        # All-reduce the compressed gradients
        dist.all_reduce(values, op=dist.ReduceOp.SUM)

        # Decompress
        tensor_compressed = torch.zeros_like(tensor).view(-1)
        tensor_compressed[indices] = values / dist.get_world_size()

        return tensor_compressed.view_as(tensor)

# Usage in a training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)

for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()

    # Custom all-reduce
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)

    optimizer.step()
    optimizer.zero_grad()
```

## Plugin Best Practices

### 1. Validation in `__post_init__`

```python
from dataclasses import dataclass

@dataclass
class CustomPlugin:
    learning_rate: float = 1e-3
    warmup_steps: int = 1000

    def __post_init__(self):
        # Validate parameters
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")

        # Compute derived values
        self.min_lr = self.learning_rate * 0.1
```

### 2. Compatibility Checks

```python
from dataclasses import dataclass

@dataclass
class CustomPlugin:
    feature_enabled: bool = True

    def is_compatible(self, accelerator):
        """Check whether the plugin is compatible with the accelerator config."""
        if self.feature_enabled and accelerator.mixed_precision == 'fp8':
            raise ValueError("Custom plugin is not compatible with FP8")
        return True
```

### 3. State Management

```python
from dataclasses import dataclass, field

@dataclass
class CustomPlugin:
    counter: int = 0
    history: list = field(default_factory=list)  # safe mutable default

    def update_state(self, value):
        """Update plugin state during training."""
        self.counter += 1
        self.history.append(value)
```

## Resources

- Accelerate kwargs handlers: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed config reference: https://www.deepspeed.ai/docs/config-json/
- FSDP guide: https://pytorch.org/docs/stable/fsdp.html
- Custom training loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu

protected/skills-backup/mlops/training/accelerate/references/megatron-integration.md (new file, 489 lines)
@@ -0,0 +1,489 @@
# Megatron Integration with Accelerate

## Overview

Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.

**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split individual layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate the model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts

## Setup

### Install Megatron-LM

```bash
# Clone the Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .

# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```

### Accelerate Configuration

```bash
accelerate config
```

**Questions**:
```
In which compute environment are you running?
> This machine

Which type of machine are you using?
> Multi-GPU

How many different machines will you use?
> 1

Do you want to use DeepSpeed/FSDP?
> No

Do you want to use Megatron-LM?
> Yes

What is the Tensor Parallelism degree? [1-8]
> 2

Do you want to enable Sequence Parallelism?
> No

What is the Pipeline Parallelism degree? [1-8]
> 2

What is the Data Parallelism degree? [1-8]
> 2

Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE

Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```

**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
  megatron_lm_gradient_clipping: 1.0
  megatron_lm_learning_rate_decay_iters: 320000
  megatron_lm_num_micro_batches: 1
  megatron_lm_pp_degree: 2
  megatron_lm_recompute_activations: true
  megatron_lm_sequence_parallelism: false
  megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

## Parallelism Strategies

### Tensor Parallelism (TP)

**Splits each transformer layer across GPUs**:

```python
# Layer split across 2 GPUs
# GPU 0: First half of the attention heads
# GPU 1: Second half of the attention heads

# Each GPU computes partial outputs
# An all-reduce combines the results
```
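
A minimal, self-contained sketch of the idea behind a column-parallel linear layer (illustrative only; Megatron's real implementation wraps this in communication primitives and fused kernels):

```python
import torch

world_size, rank = 2, 0                     # assume 2-way tensor parallelism
W = torch.randn(4096, 1024)                 # full weight [out_features, in_features]
W_shard = W.chunk(world_size, dim=0)[rank]  # this rank keeps 2048 of the 4096 rows

x = torch.randn(8, 1024)                    # input batch (replicated on every rank)
y_partial = x @ W_shard.T                   # partial output [8, 2048]
# the ranks then concatenate their partial outputs (all-gather) into the full [8, 4096] result
```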

**TP degree recommendations**:
- **TP=1**: No tensor parallelism (each layer on a single GPU)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)

**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)

**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead on every layer

### Pipeline Parallelism (PP)

**Splits model depth across GPUs**:

```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```

**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)

**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (a slower interconnect is OK)

**Drawbacks**:
- Pipeline bubbles (idle time), estimated below
- Requires micro-batching
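
A standard estimate of that idle time (GPipe-style analysis): with `p` pipeline stages and `m` micro-batches, the bubble fraction is roughly `(p - 1) / (m + p - 1)`, which is why more micro-batches shrink the bubble:

```python
p, m = 4, 16                    # 4 pipeline stages, 16 micro-batches
bubble = (p - 1) / (m + p - 1)  # ≈ 0.158 -> roughly 16% of step time is idle
```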

### Data Parallelism (DP)

**Replicates the model across GPU groups**:

```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```

**DP degree**:
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2

**Benefits**:
- Increases throughput
- Scales the batch size

### Sequence Parallelism

**Splits long sequences across GPUs** (extends TP):

```python
# 8K sequence, TP=2, sequence parallelism enabled
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```

**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory

**Requirements**:
- Must be used with TP > 1
- RoPE/ALiBi position encodings work best

## Accelerate Code Example

### Basic Setup

```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin

# Configure Megatron
megatron_plugin = MegatronLMPlugin(
    tp_degree=2,                         # Tensor parallelism degree
    pp_degree=2,                         # Pipeline parallelism degree
    num_micro_batches=4,                 # Micro-batches for the pipeline
    gradient_clipping=1.0,               # Gradient clipping value
    sequence_parallelism=False,          # Enable sequence parallelism
    recompute_activations=True,          # Activation checkpointing
    use_distributed_optimizer=True,      # Distributed optimizer
    custom_prepare_model_function=None,  # Custom model preparation hook
)

# Initialize the accelerator
accelerator = Accelerator(
    mixed_precision='bf16',
    megatron_lm_plugin=megatron_plugin
)

# Prepare the model, optimizer, and data
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Training loop (same as DDP!)
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
```

### Full Training Script

```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel

def main():
    # Megatron configuration
    megatron_plugin = MegatronLMPlugin(
        tp_degree=2,
        pp_degree=2,
        num_micro_batches=4,
        gradient_clipping=1.0,
    )

    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=8,
        megatron_lm_plugin=megatron_plugin
    )

    # Model
    config = GPT2Config(
        n_layer=24,
        n_head=16,
        n_embd=1024,
    )
    model = GPT2LMHeadModel(config)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

    # Prepare (train_loader and num_epochs are assumed to be defined elsewhere)
    model, optimizer, train_loader = accelerator.prepare(
        model, optimizer, train_loader
    )

    # Training loop
    for epoch in range(num_epochs):
        for batch in train_loader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

        # Save a checkpoint
        accelerator.wait_for_everyone()
        accelerator.save_state(f'checkpoint-epoch-{epoch}')

if __name__ == '__main__':
    main()
```

### Launch Command

```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py

# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 0 \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    train.py

# Node 1
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 1 \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    train.py
```

## Activation Checkpointing

**Reduces memory by recomputing activations during backward**:

```python
megatron_plugin = MegatronLMPlugin(
    recompute_activations=True,                # Enable checkpointing
    checkpoint_num_layers=1,                   # Checkpoint every N layers
    distribute_checkpointed_activations=True,  # Distribute across TP ranks
    partition_activations=True,                # Partition across PP stages
    check_for_nan_in_loss_and_grad=True,       # Stability check
)
```

**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing

**Memory savings**: 30-50%, at a 10-15% slowdown.

## Distributed Optimizer

**Shards optimizer state across DP ranks**:

```python
megatron_plugin = MegatronLMPlugin(
    use_distributed_optimizer=True,  # Enable the sharded optimizer
)
```

**Benefits**:
- Reduces optimizer memory by the DP degree
- Example: DP=4 → 4× less optimizer memory per GPU

**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training

## Performance Tuning

### Micro-Batch Size

```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,  # 16 micro-batches per pipeline
)

# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```

**Recommendations**:
- More micro-batches → smaller pipeline bubble
- Typical: 4-16 micro-batches

### Sequence Length

```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
    tp_degree=4,
    sequence_parallelism=True,  # Requires TP > 1
)

# Enables sequences up to TP × the normal limit
# Example: TP=4, 8K normally → 32K with sequence parallelism
```

### GPU Topology

**NVLink is required for TP**:
```bash
# Check the NVLink topology
nvidia-smi topo -m

# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)

# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow; avoid TP across these)
```

**Recommendations**:
- **TP**: Within the same node (NVLink)
- **PP**: Across nodes (a slower interconnect is OK)
- **DP**: Any topology

## Model Size Guidelines

| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|---------------|
| 7B         | 8    | 1  | 1  | 8  | 1             |
| 13B        | 8    | 2  | 1  | 4  | 1             |
| 20B        | 16   | 4  | 1  | 4  | 1             |
| 40B        | 32   | 4  | 2  | 4  | 4             |
| 70B        | 64   | 8  | 2  | 4  | 8             |
| 175B       | 128  | 8  | 4  | 4  | 16            |

**Assumptions**: BF16, 2K sequence length, A100 80GB.

## Checkpointing

### Save Checkpoint

```python
# Save the full training state
accelerator.save_state('checkpoint-1000')

# Megatron saves separate files per rank:
# checkpoint-1000/
#   pytorch_model_tp_0_pp_0.bin
#   pytorch_model_tp_0_pp_1.bin
#   pytorch_model_tp_1_pp_0.bin
#   pytorch_model_tp_1_pp_1.bin
#   optimizer_tp_0_pp_0.bin
#   ...
```

### Load Checkpoint

```python
# Resume training
accelerator.load_state('checkpoint-1000')

# The correct shard is loaded automatically on each rank
```

### Convert to Standard PyTorch

```bash
# Merge a Megatron checkpoint into a single file
python merge_megatron_checkpoint.py \
    --checkpoint-dir checkpoint-1000 \
    --output pytorch_model.bin
```

## Common Issues

### Issue: OOM with Pipeline Parallelism

**Solution**: Increase the number of micro-batches
```python
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,  # Increased from 4
)
```

### Issue: Slow Training

**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4  # Increase
pp_degree=2  # Decrease
```

**Check 2**: Too few micro-batches
```python
num_micro_batches=8  # Increase
```

### Issue: NVLink Not Detected

```bash
# Verify NVLink
nvidia-smi nvlink -s

# If there is no NVLink, avoid TP > 1;
# use PP or DP instead
```

## Resources

- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex

protected/skills-backup/mlops/training/accelerate/references/performance.md (new file, 525 lines)
@@ -0,0 +1,525 @@
# Accelerate Performance Tuning

## Profiling

### Basic Profiling

```python
import time
from accelerate import Accelerator

accelerator = Accelerator()

# Warmup
for _ in range(10):
    batch = next(iter(dataloader))
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

# Profile the training loop
start = time.time()
total_batches = 100

for i, batch in enumerate(dataloader):
    if i >= total_batches:
        break

    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

accelerator.wait_for_everyone()  # Sync all processes
elapsed = time.time() - start

# Metrics
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed

print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```

### PyTorch Profiler Integration

```python
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for i, batch in enumerate(dataloader):
        if i >= 10:  # Profile the first 10 batches
            break

        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

# Print the profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total", row_limit=20
))

# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```

## Memory Optimization

### 1. Gradient Accumulation

**Problem**: A large batch size causes OOM.

**Solution**: Accumulate gradients across micro-batches.

```python
accelerator = Accelerator(gradient_accumulation_steps=8)

# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256

for batch in dataloader:
    with accelerator.accumulate(model):  # Handles the accumulation logic
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

**Memory savings**: with 8 accumulation steps, each forward pass holds 8× less activation memory than running the full effective batch at once.

### 2. Gradient Checkpointing

**Enable it on the model**:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False  # Required for gradient checkpointing
)

# Enable checkpointing
model.gradient_checkpointing_enable()

# Prepare with Accelerate
model = accelerator.prepare(model)
```

**Memory savings**: 30-50%, at a 10-15% slowdown.

### 3. Mixed Precision

**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')

# Automatic mixed precision
for batch in dataloader:
    outputs = model(**batch)    # Forward in BF16
    loss = outputs.loss
    accelerator.backward(loss)  # Backward pass (no loss scaling needed for BF16)
    optimizer.step()
```

**FP16 (V100 and older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs

scaler_kwargs = GradScalerKwargs(
    init_scale=2.**16,
    growth_interval=2000
)

accelerator = Accelerator(
    mixed_precision='fp16',
    kwargs_handlers=[scaler_kwargs]
)
```

**Memory savings**: ~50% compared to FP32.

### 4. CPU Offloading (DeepSpeed)

```python
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_optimizer_device="cpu",  # Offload optimizer state to CPU
    offload_param_device="cpu",      # Offload parameters to CPU
)

accelerator = Accelerator(
    deepspeed_plugin=ds_plugin,
    mixed_precision='bf16'
)
```

**Memory savings**: 10-20× for optimizer state, 5-10× for parameters.

**Trade-off**: 20-30% slower due to CPU-GPU transfers.

### 5. Flash Attention

```python
# Install flash-attn first:
#   pip install flash-attn

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    attn_implementation="flash_attention_2"  # Enable Flash Attention 2
)

model = accelerator.prepare(model)
```

**Memory savings**: ~50% for attention, up to 2× faster.

**Requirements**: Ampere-class or newer GPUs (A100/H100); works best when sequence lengths are multiples of 128.

## Communication Optimization

### 1. Gradient Bucketing (DDP)

```python
from accelerate.utils import DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,              # Bucket size for gradient reduction
    gradient_as_bucket_view=True,  # Reduce memory copies
    static_graph=False             # Set True if the graph doesn't change
)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```

**Recommended bucket sizes**:
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB

### 2. Find Unused Parameters

```python
# Only enable this if the model really has unused parameters - it is slower!
ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True
)
```

**Use case**: Models with conditional branches (e.g., mixture of experts).

**Cost**: 10-20% slower.

### 3. NCCL Tuning

```bash
# Set environment variables before launching
export NCCL_DEBUG=INFO           # Debug info
export NCCL_IB_DISABLE=0         # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0   # Network interface
export NCCL_P2P_LEVEL=NVL        # Use NVLink

accelerate launch train.py
```

**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within a node)
- `PIX`: PCIe (fast, within a node)
- `PHB`: PCIe host bridge (slow, cross-node)

## Data Loading Optimization

### 1. DataLoader Workers

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Pin memory for faster GPU transfer
    prefetch_factor=2,        # Batches prefetched per worker
    persistent_workers=True   # Keep workers alive between epochs
)

train_loader = accelerator.prepare(train_loader)
```

**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher when data loading is slow)

### 2. Data Preprocessing

```python
from datasets import load_dataset, load_from_disk

# Bad: preprocess during training (slow)
dataset = load_dataset("openwebtext")

for batch in dataset:
    tokens = tokenizer(batch['text'])  # Slow!
    ...

# Good: preprocess once, then save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
    lambda x: tokenizer(x['text']),
    batched=True,
    num_proc=8,  # Parallel preprocessing
    remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")

# Load the preprocessed data
dataset = load_from_disk("preprocessed_data")
```

### 3. Faster Tokenization

```python
import os

# Allow tokenizer parallelism across dataloader workers
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=True  # Use the Rust-based fast tokenizer (roughly 10× faster)
)
```

## Compilation (PyTorch 2.0+)

### Compile Model

```python
import torch

# Compile the model for faster execution
model = torch.compile(
    model,
    mode="reduce-overhead",  # Options: default, reduce-overhead, max-autotune
    fullgraph=False,         # Compile the entire graph (stricter)
    dynamic=True             # Support dynamic shapes
)

model = accelerator.prepare(model)
```

**Speedup**: 10-50%, depending on the model.

**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Minimal overhead (best for small batches)
- `max-autotune`: Maximum performance (slow to compile; best for production)

### Compilation Best Practices

```python
# Not recommended: compiling after prepare() can interfere with Accelerate's wrappers
model = accelerator.prepare(model)
model = torch.compile(model)

# Good: compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)

# Training loop
for batch in dataloader:
    # First iteration: slow (compilation)
    # Subsequent iterations: fast (compiled)
    outputs = model(**batch)
    ...
```

## Benchmarking Different Strategies

### Script Template

```python
import time
import torch
from accelerate import Accelerator

def benchmark_strategy(strategy_name, accelerator_kwargs):
    """Benchmark a specific training strategy."""
    accelerator = Accelerator(**accelerator_kwargs)

    # Setup (create_model, create_dataloader, and batch_size are assumed to exist)
    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = create_dataloader()

    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader
    )

    # Warmup
    for i, batch in enumerate(dataloader):
        if i >= 10:
            break
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # Benchmark
    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    start = time.time()

    num_batches = 100
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break

        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    elapsed = time.time() - start

    # Metrics
    throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
    memory_used = torch.cuda.max_memory_allocated() / 1e9  # GB

    if accelerator.is_main_process:
        print(f"\n{strategy_name}:")
        print(f"  Throughput: {throughput:.2f} samples/sec")
        print(f"  Memory: {memory_used:.2f} GB")
        print(f"  Time: {elapsed:.2f} sec")

    torch.cuda.reset_peak_memory_stats()

# Benchmark different strategies
# (fsdp_plugin and the DeepSpeed plugin objects are assumed to be defined above)
strategies = [
    ("DDP + FP32", {}),
    ("DDP + BF16", {"mixed_precision": "bf16"}),
    ("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
    ("FSDP", {"fsdp_plugin": fsdp_plugin}),
    ("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
    ("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]

for name, kwargs in strategies:
    benchmark_strategy(name, kwargs)
```

## Performance Checklist

**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set an appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile the model with `torch.compile` (PyTorch 2.0+)

**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation

**For multi-node**:
- [ ] Check the network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism

**Profiling**:
- [ ] Profile the first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of each iteration)
- [ ] Identify communication bottlenecks

## Common Performance Issues

### Issue: Low GPU Utilization (<80%)

**Cause 1**: Data loading bottleneck
```python
# Solution: increase workers and prefetch
num_workers=8
prefetch_factor=4
```

**Cause 2**: Batch size too small
```python
# Solution: increase the batch size, or use gradient accumulation
batch_size=32                  # Increase
gradient_accumulation_steps=4  # Or accumulate
```

### Issue: High Memory Usage

**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```

**Solution 2**: Reduce the batch size, increase accumulation
```python
batch_size=8                    # Reduced from 32
gradient_accumulation_steps=16  # Maintains the effective batch
```

**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

### Issue: Slow Multi-GPU Training

**Cause**: Communication bottleneck

**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```

**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Look for "Using NVLS" (good) vs "Using PHB" (bad)
```

**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```

## Resources

- Accelerate performance guide: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention

protected/skills-backup/mlops/training/axolotl/SKILL.md (new file, 161 lines)
@@ -0,0 +1,161 @@
---
name: axolotl
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
metadata:
  hermes:
    tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
---

# Axolotl Skill

Comprehensive assistance with axolotl development, generated from official documentation.

## When to Use This Skill

This skill should be triggered when:
- Working with axolotl
- Asking about axolotl features or APIs
- Implementing axolotl solutions
- Debugging axolotl code
- Learning axolotl best practices

## Quick Reference

### Common Patterns

**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:

```
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```

**Pattern 2:** Configure your model to use FSDP in the Axolotl YAML. For example:

```
fsdp_version: 2
fsdp_config:
  offload_params: true
  state_dict_type: FULL_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  reshard_after_forward: true
```

**Pattern 3:** The `context_parallel_size` should be a divisor of the total number of GPUs. For example:

```
context_parallel_size
```
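
A hedged sketch of what that looks like in the config (the value 2 is an assumption for an 8-GPU job, since 2 divides 8):

```yaml
context_parallel_size: 2  # must divide the total GPU count (8 % 2 == 0)
```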

**Pattern 4:** For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_size=4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

```
context_parallel_size=4
```
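
The arithmetic behind that example:

```python
num_gpus, micro_batch_size = 8, 2
global_batch_no_cp = num_gpus * micro_batch_size        # 16
global_batch_cp4 = (num_gpus // 4) * micro_batch_size   # 4, with context_parallel_size=4
```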

**Pattern 5:** Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (for example, quantization)

```
save_compressed: true
```

**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
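
A sketch of wiring this helper into a `datasets` filter. The `functools.partial` wrapper follows the signature above; treat the exact import path and keep/drop semantics as assumptions to verify against the API reference:
```python
from functools import partial
from axolotl.utils.trainer import drop_long_seq  # path assumed from the signature above

# Keep tokenized samples whose length is within [2, 2048]
keep_fn = partial(drop_long_seq, sequence_len=2048, min_sequence_len=2)
dataset = dataset.filter(keep_fn)
```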

### Example Code Patterns

**Example 1** (python):
```python
cli.cloud.modal_.ModalCloud(config, app=None)
```

**Example 2** (python):
```python
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
```

**Example 3** (python):
```python
core.trainers.base.AxolotlTrainer(
    *_args,
    bench_data_collator=None,
    eval_data_collator=None,
    dataset_tags=None,
    **kwargs,
)
```

**Example 4** (python):
```python
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
```

**Example 5** (python):
```python
prompt_strategies.input_output.RawInputOutputPrompter()
```

## Reference Files

This skill includes comprehensive documentation in `references/`:

- **api.md** - API documentation
- **dataset-formats.md** - Dataset-formats documentation
- **other.md** - Other documentation

Use `view` to read specific reference files when detailed information is needed.

## Working with This Skill

### For Beginners
Start with the getting_started or tutorials reference files for foundational concepts.

### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.

### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.

## Resources

### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation

### scripts/
Add helper scripts here for common automation tasks.

### assets/
Add templates, boilerplate, or example projects here.

## Notes

- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs

## Updating

To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information
5548
protected/skills-backup/mlops/training/axolotl/references/api.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,15 @@
# Axolotl Documentation Index

## Categories

### Api
**File:** `api.md`
**Pages:** 150

### Dataset-Formats
**File:** `dataset-formats.md`
**Pages:** 9

### Other
**File:** `other.md`
**Pages:** 26
3563
protected/skills-backup/mlops/training/axolotl/references/other.md
Normal file
File diff suppressed because it is too large
Load Diff
370
protected/skills-backup/mlops/training/flash-attention/SKILL.md
Normal file
@@ -0,0 +1,370 @@
---
name: optimizing-attention-flash
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [flash-attn, torch, transformers]
metadata:
  hermes:
    tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]

---

# Flash Attention - Fast Memory-Efficient Attention

## Quick start

Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.

**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```

**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

## Common workflows

### Workflow 1: Enable in existing PyTorch model

Copy this checklist:

```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```

**Step 1: Check PyTorch version**

```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```

If <2.2, upgrade:
```bash
pip install --upgrade torch
```

**Step 2: Enable Flash Attention backend**

Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```
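
On PyTorch 2.3+, `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel`; the equivalent:
```python
from torch.nn.attention import sdpa_kernel, SDPBackend

# Same effect as above on newer PyTorch
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```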

**Step 3: Verify speedup with profiling**

```python
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```

Expected: 2-4x speedup for sequences >512 tokens.

**Step 4: Test accuracy matches baseline**

```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```

### Workflow 2: Use flash-attn library for advanced features

For multi-query attention, sliding window, or H100 FP8.

Copy this checklist:

```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```

**Step 1: Install flash-attn library**

```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```

**Step 2: Modify attention code**

```python
from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,  # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None  # Auto-scale
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```

**Step 3: Enable advanced features**

Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
out = flash_attn_func(q, k, v)  # Automatically handles MQA
```

Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True
)
```

**Step 4: Benchmark performance**

```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```

### Workflow 3: H100 FP8 optimization (FlashAttention-3)

For maximum performance on H100 GPUs.

```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```

**Step 1: Verify H100 GPU**

```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```

**Step 2: Install flash-attn with FP8 support**

```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```

**Step 3: Convert inputs to FP8**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```

**Step 4: Run with FP8 attention**

```python
from flash_attn import flash_attn_func

# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```

## When to use vs alternatives

**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn

**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)

## Common issues

**Issue: ImportError: cannot import flash_attn**

Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```

Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

**Issue: Slower than expected (no speedup)**

Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup

Check sequence length is sufficient.

**Issue: RuntimeError: CUDA error**

Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```

Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported

**Issue: Accuracy degradation**

Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

Flash Attention uses float16/bfloat16 for speed. Float32 not supported.

## Advanced topics

**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.

**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.

**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.

**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.

## Hardware requirements

- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support

**Not supported**: V100 (Volta), CPU inference

## Resources

- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -0,0 +1,215 @@
# Performance Benchmarks

## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison

## Speed comparisons across GPUs

### A100 80GB (Ampere)

**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |

### H100 80GB (Hopper)

**Forward pass time** (milliseconds, same config):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |

**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).

### A10G 24GB (Ampere)

**Forward pass time** (milliseconds, batch=4):

| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |

## Memory usage analysis

### GPU memory consumption (batch=8, heads=32, dim=64)

**Standard attention memory**:

| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |

**Flash Attention 2 memory**:

| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |

**Key insight**: Flash Attention doesn't materialize the attention matrix, saving O(N²) memory.

### Memory scaling comparison

**Llama 2 7B model memory** (float16, batch=1):

| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|-------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |

### Training memory (Llama 2 7B, batch=4)

| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |

## Scaling with sequence length

### Computational complexity

**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)

**Flash Attention**:
- Time: O(N² × d) (same, but with better constants)
- Memory: O(N × d) (linear!)
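
The O(N²) term is exactly what the tables above measure: standard attention materializes a `batch × heads × N × N` score matrix. A quick back-of-envelope (sizes illustrative):

```python
def attn_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_el: int = 2) -> int:
    """Memory of the materialized attention score matrix (fp16 = 2 bytes/element)."""
    return batch * heads * seq_len**2 * bytes_per_el

# batch=1, heads=32, seq=8192 -> 4.0 GiB just for the scores
print(attn_matrix_bytes(1, 32, 8192) / 2**30, "GiB")
```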

### Empirical scaling (A100, batch=1, heads=32, dim=64)

**Time per token (milliseconds)**:

| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |

**Observation**: Speedup grows roughly linearly with sequence length!

### Memory per token (MB)

| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |

**Observation**: Flash Attention memory per token is constant!

## Training vs inference performance

### Training (forward + backward, Llama 2 7B, A100)

| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |

### Inference (generation, Llama 2 7B, A100)

| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |

**Note**: Inference speedup is less dramatic than training because generation is memory-bound (KV cache accesses).

## Flash Attention versions comparison

### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)

| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |

**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)

### Features by version

| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |

## Real-world model benchmarks

### Llama 2 models (A100 80GB, batch=4, seq=2048)

| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |

### GPT-style models (seq=1024)

| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |

## Recommendations by use case

**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup

**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction

**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)

**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware
@@ -0,0 +1,293 @@
# HuggingFace Transformers Integration

## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues

## Enabling Flash Attention in Transformers

HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.

**Simple enable for any supported model**:
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto"
)
```

**Install requirements**:
```bash
pip install "transformers>=4.36"
pip install flash-attn --no-build-isolation
```

## Supported model architectures

As of Transformers 4.40:

**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM

**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper

**Check support**:
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```

## Configuration examples

### Llama 2 with Flash Attention

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "meta-llama/Llama-2-7b-hf"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```

### Mistral with Flash Attention for long context

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "mistralai/Mistral-7B-v0.1"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # Better for long context
    device_map="auto",
    max_position_embeddings=32768  # Extended context
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Process long document (32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```

### Fine-tuning with Flash Attention

```python
import torch
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    fp16=True,  # Must match model dtype
    optim="adamw_torch_fused"  # Fast optimizer
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)

trainer.train()
```

### Multi-GPU training

```python
from transformers import AutoModelForCausalLM
import torch

# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",  # Automatic multi-GPU placement
    max_memory={0: "20GB", 1: "20GB"}  # Limit per GPU
)
```

## Performance comparisons

### Memory usage (Llama 2 7B, batch=1)

| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|-------------------|-------------------|-----------|
| 512 | 1.2 GB | 0.9 GB | 25% |
| 2048 | 3.8 GB | 1.4 GB | 63% |
| 8192 | 14.2 GB | 3.2 GB | 77% |
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |

### Speed (tokens/sec, A100 80GB)

| Model | Standard | Flash Attn 2 | Speedup |
|-------|----------|--------------|---------|
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |

### Training throughput (samples/sec)

| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------|------------|----------|--------------|---------|
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |

## Troubleshooting model-specific issues

### Issue: Model doesn't support Flash Attention

Check the support list above. If not supported, use PyTorch SDPA as a fallback:

```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="sdpa",  # PyTorch native (still faster)
    torch_dtype=torch.float16
)
```

### Issue: CUDA out of memory during loading

Reduce memory footprint:

```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "18GB"},  # Reserve memory for KV cache
    low_cpu_mem_usage=True
)
```

### Issue: Slower inference than expected

Ensure dtypes match:

```python
# Model and floating-point inputs must both be float16/bfloat16
model = model.to(torch.float16)
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
          for k, v in inputs.items()}
```

### Issue: Different outputs vs standard attention

Flash Attention is numerically equivalent but uses a different computation order. Small differences (<1e-3) are normal:

```python
# Compare outputs
model_standard = AutoModelForCausalLM.from_pretrained(
    "model-name", torch_dtype=torch.float16, device_map="auto"
)
model_flash = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto"
)

inputs = tokenizer("Test", return_tensors="pt").to("cuda")

with torch.no_grad():
    out_standard = model_standard(**inputs).logits
    out_flash = model_flash(**inputs).logits

diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}")  # Should be ~1e-3 to 1e-4
```

### Issue: ImportError during model loading

Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```

Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="eager",  # Standard PyTorch
    torch_dtype=torch.float16
)
```

## Best practices

1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`

**Example with all best practices**:
```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # Better for training
    device_map="auto",
    low_cpu_mem_usage=True
)

# Enable gradient checkpointing for memory
model.gradient_checkpointing_enable()

# Training with optimizations
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    bf16=True,  # Match model dtype
    optim="adamw_torch_fused",
    gradient_checkpointing=True
)
```

@@ -0,0 +1,97 @@
# GRPO/RL Training Skill

**Expert-level guidance for Group Relative Policy Optimization with TRL**

## 📁 Skill Structure

```
grpo-rl-training/
├── SKILL.md                        # Main skill documentation (READ THIS FIRST)
├── README.md                       # This file
├── templates/
│   └── basic_grpo_training.py      # Production-ready training template
└── examples/
    └── reward_functions_library.py # 20+ reward function examples
```

## 🚀 Quick Start

1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
4. **Modify for your use case** - Adapt dataset, rewards, and config

## 💡 What's Inside

### SKILL.md (Main Documentation)
- Core GRPO concepts and algorithm fundamentals
- Complete implementation workflow (dataset → rewards → training → deployment)
- 10+ reward function examples with code
- Hyperparameter tuning guide
- Training insights (loss behavior, metrics, debugging)
- Troubleshooting guide
- Production best practices

### Templates
- **basic_grpo_training.py**: Minimal, production-ready training script
  - Uses Qwen 2.5 1.5B Instruct
  - 3 reward functions (format + correctness)
  - LoRA for efficient training
  - Fully documented and ready to run

### Examples
- **reward_functions_library.py**: 20+ battle-tested reward functions
  - Correctness rewards (exact match, fuzzy match, numeric, code execution)
  - Format rewards (XML, JSON, strict/soft)
  - Length rewards (ideal length, min/max)
  - Style rewards (reasoning quality, citations, repetition penalty)
  - Combined rewards (multi-objective optimization)
  - Preset collections for common tasks

## 📖 Usage for Agents

When this skill is loaded in your agent's context:

1. **Always read SKILL.md first** before implementing
2. **Start simple** - Use a length-based reward to validate setup
3. **Build incrementally** - Add one reward function at a time
4. **Reference examples** - Copy patterns from reward_functions_library.py
5. **Monitor training** - Watch reward metrics (not loss!)

## 🎯 Common Use Cases

| Task Type | Recommended Rewards | Template |
|-----------|---------------------|----------|
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |

## ⚠️ Critical Reminders

- **Loss goes UP during training** - This is normal (it's KL divergence)
- **Use 3-5 reward functions** - Single rewards often fail
- **Test rewards before training** - Debug each function independently
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
- **Start with num_generations=4-8** - Scale up if GPU allows

## 🔗 External Resources

- [TRL Documentation](https://huggingface.co/docs/trl)
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)

## 📝 Version

**v1.0.0** - Initial release (January 2025)

## 👨‍💻 Maintained By

Orchestra Research
For questions or improvements, see https://orchestra.com

---

**License:** MIT
**Last Updated:** January 2025
575
protected/skills-backup/mlops/training/grpo-rl-training/SKILL.md
Normal file
@@ -0,0 +1,575 @@
---
name: grpo-rl-training
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
metadata:
  hermes:
    tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]

---

# GRPO/RL Training with TRL

Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.

## When to Use This Skill

Use GRPO training when you need to:
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
- **Align models to domain-specific behaviors** without labeled preference data
- **Optimize for multiple objectives** simultaneously (format + correctness + style)

**Do NOT use GRPO for:**
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- When you already have high-quality preference pairs (use DPO/PPO instead)

---

## Core Concepts

### 1. GRPO Algorithm Fundamentals

**Key Mechanism:**
- Generates **multiple completions** for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group

**Critical Difference from PPO:**
- No separate reward model needed
- More sample-efficient (learns from within-group comparisons)
- Simpler to implement and debug

**Mathematical Intuition:**
```
For each prompt p:
  1. Generate N completions: {c₁, c₂, ..., cₙ}
  2. Compute rewards: {r₁, r₂, ..., rₙ}
  3. Learn to increase probability of high-reward completions
     relative to low-reward ones in the same group
```
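
Concretely, GRPO converts each group's rewards into advantages by normalizing within the group. A minimal sketch of that standard group-relative computation (not TRL's internal code):

```python
import statistics

def group_advantages(rewards: list[float]) -> list[float]:
    """Normalize rewards within one prompt's group of completions."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0  # guard against zero spread
    return [(r - mean) / std for r in rewards]

# Four completions for one prompt: the highest-reward one gets the largest advantage
print(group_advantages([2.0, 0.0, 1.0, 1.0]))
```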

### 2. Reward Function Design Philosophy

**Golden Rules:**
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
2. **Scale rewards appropriately** - Higher weight = stronger signal
3. **Use incremental rewards** - Partial credit for partial compliance
4. **Test rewards independently** - Debug each reward function in isolation

**Reward Function Types** (see the weighting sketch below):

| Type | Use Case | Example Weight |
|------|----------|----------------|
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
| **Format** | Strict structure enforcement | 0.5-1.0 |
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |

---

## Implementation Workflow

### Step 1: Dataset Preparation

**Critical Requirements:**
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns

**Example Structure:**
```python
from datasets import load_dataset, Dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""

def prepare_dataset(raw_data):
    """
    Transform raw data into GRPO-compatible format.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content (system + user messages)
    - 'answer': str (ground truth, optional but recommended)
    """
    return raw_data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_answer(x['raw_answer'])
    })
```

**Pro Tips:**
- Use one-shot or few-shot examples in system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out)

### Step 2: Reward Function Implementation

**Template Structure:**
```python
def reward_function_name(
    prompts,       # List[List[Dict]]: Original prompts
    completions,   # List[List[Dict]]: Model generations
    answer=None,   # Optional: Ground truth from dataset
    **kwargs       # Additional dataset columns
) -> list[float]:
    """
    Evaluate completions and return rewards.

    Returns: List of floats (one per completion)
    """
    # Extract completion text
    responses = [comp[0]['content'] for comp in completions]

    # Compute rewards
    rewards = []
    for response in responses:
        score = compute_score(response)
        rewards.append(score)

    return rewards
```

**Example 1: Correctness Reward (Math/Coding)**
```python
def correctness_reward(prompts, completions, answer, **kwargs):
    """Reward correct answers with high score."""
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_final_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0
            for ans, gt in zip(extracted, answer)]
```
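
`extract_final_answer` is task-specific; for the `<answer>` format from Step 1, a minimal sketch might be:

```python
import re

def extract_final_answer(text: str):
    """Return the contents of the last <answer>...</answer> block, or None."""
    matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    return matches[-1].strip() if matches else None
```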

**Example 2: Format Reward (Structured Output)**
```python
import re

def format_reward(completions, **kwargs):
    """Reward XML-like structured format."""
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
            for r in responses]
```

**Example 3: Incremental Format Reward (Partial Credit)**
```python
def incremental_format_reward(completions, **kwargs):
    """Award partial credit for format compliance."""
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.25
        if '</reasoning>' in r:
            score += 0.25
        if '<answer>' in r:
            score += 0.25
        if '</answer>' in r:
            score += 0.25
        # Penalize extra text after closing tag
        if r.count('</answer>') == 1:
            extra_text = r.split('</answer>')[-1].strip()
            score -= len(extra_text) * 0.001
        rewards.append(score)

    return rewards
```

**Critical Insight:**
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.

### Step 3: Training Configuration

**Memory-Optimized Config (Small GPU)**
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="outputs/grpo-model",

    # Learning rate
    learning_rate=5e-6,  # Lower = more stable
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',

    # Batch settings
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # Effective batch = 4

    # GRPO-specific
    num_generations=8,  # Group size: 8-16 recommended
    max_prompt_length=256,
    max_completion_length=512,

    # Training duration
    num_train_epochs=1,
    max_steps=None,  # Or set fixed steps (e.g., 500)

    # Optimization
    bf16=True,  # Faster on A100/H100
    optim="adamw_8bit",  # Memory-efficient optimizer
    max_grad_norm=0.1,

    # Logging
    logging_steps=1,
    save_steps=100,
    report_to="wandb",  # Or "none" for no logging
)
```

**High-Performance Config (Large GPU)**
```python
training_args = GRPOConfig(
    output_dir="outputs/grpo-model",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_generations=16,  # Larger groups = better signal
    max_prompt_length=512,
    max_completion_length=1024,
    num_train_epochs=1,
    bf16=True,
    use_vllm=True,  # Fast generation with vLLM
    logging_steps=10,
)
```

**Critical Hyperparameters:**

| Parameter | Impact | Tuning Advice |
|-----------|--------|---------------|
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |

### Step 4: Model Setup and Training

**Standard Setup (Transformers)**
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer

# Load model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # 2-3x faster
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Optional: LoRA for parameter-efficient training
peft_config = LoraConfig(
    r=16,           # Rank (higher = more capacity)
    lora_alpha=32,  # Scaling factor (typically 2*r)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        incremental_format_reward,
        format_reward,
        correctness_reward,
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # Remove for full fine-tuning
)

# Train
trainer.train()

# Save
trainer.save_model("final_model")
```

**Unsloth Setup (2-3x Faster)**
```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/gemma-3-1b-it",
    max_seq_length=1024,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=32,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    use_gradient_checkpointing="unsloth",
)

# Rest is identical to standard setup
trainer = GRPOTrainer(model=model, ...)
trainer.train()
```

---

## Critical Training Insights

### 1. Loss Behavior (EXPECTED PATTERN)
- **Loss starts near 0 and INCREASES during training**
- This is CORRECT - loss measures KL divergence from initial policy
- Model is learning (diverging from original behavior to optimize rewards)
- Monitor reward metrics instead of loss for progress

### 2. Reward Tracking
Key metrics to watch:
- `reward`: Average across all completions
- `reward_std`: Diversity within groups (should remain > 0)
- `kl`: KL divergence from reference (should grow moderately)

**Healthy Training Pattern:**
```
Step  Reward  Reward_Std  KL
100   0.5     0.3         0.02
200   0.8     0.25        0.05
300   1.2     0.2         0.08  ← Good progression
400   1.5     0.15        0.12
```

**Warning Signs:**
- Reward std → 0 (model collapsing to single response)
- KL exploding (> 0.5) (diverging too much, reduce LR)
- Reward stuck (reward functions too harsh or model capacity issue)

### 3. Common Pitfalls and Solutions

| Problem | Symptom | Solution |
|---------|---------|----------|
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
| **No learning** | Flat rewards | Check reward function logic, increase LR |
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |

---

## Advanced Patterns

### 1. Multi-Stage Training
For complex tasks, train in stages:

```python
# Stage 1: Format compliance (epochs=1)
trainer_stage1 = GRPOTrainer(
    model=model,
    reward_funcs=[incremental_format_reward, format_reward],
    ...
)
trainer_stage1.train()

# Stage 2: Correctness (epochs=1)
trainer_stage2 = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, correctness_reward],
    ...
)
trainer_stage2.train()
```

### 2. Adaptive Reward Scaling
```python
class AdaptiveReward:
    def __init__(self, base_reward_func, initial_weight=1.0):
        self.func = base_reward_func
        self.weight = initial_weight

    def __call__(self, *args, **kwargs):
        rewards = self.func(*args, **kwargs)
        return [r * self.weight for r in rewards]

    def adjust_weight(self, success_rate):
        """Increase weight if model struggling, decrease if succeeding."""
        if success_rate < 0.3:
            self.weight *= 1.2
        elif success_rate > 0.8:
            self.weight *= 0.9
```

### 3. Custom Dataset Integration
```python
def load_custom_knowledge_base(csv_path):
    """Example: School communication platform docs."""
    import pandas as pd
    df = pd.read_csv(csv_path)

    dataset = Dataset.from_pandas(df).map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': x['expert_answer']
    })
    return dataset
```

---

## Deployment and Inference

### Save and Merge LoRA
```python
# Merge LoRA adapters into base model
if hasattr(trainer.model, 'merge_and_unload'):
    merged_model = trainer.model.merge_and_unload()
    merged_model.save_pretrained("production_model")
    tokenizer.save_pretrained("production_model")
```

### Inference Example
```python
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="production_model",
    tokenizer=tokenizer
)

result = generator(
    [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': "What is 15 + 27?"}
    ],
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9
)
print(result[0]['generated_text'])
```

---

## Best Practices Checklist

**Before Training:**
- [ ] Validate dataset format (prompts as List[Dict])
- [ ] Test reward functions on sample data
- [ ] Calculate expected max_prompt_length from data
- [ ] Choose appropriate num_generations based on GPU memory
- [ ] Set up logging (wandb recommended)

**During Training:**
- [ ] Monitor reward progression (should increase)
- [ ] Check reward_std (should stay > 0.1)
- [ ] Watch for OOM errors (reduce batch size if needed)
- [ ] Sample generations every 50-100 steps
- [ ] Validate format compliance on holdout set

**After Training:**
- [ ] Merge LoRA weights if using PEFT
- [ ] Test on diverse prompts
- [ ] Compare to baseline model
- [ ] Document reward weights and hyperparameters
- [ ] Save reproducibility config

---

## Troubleshooting Guide

### Debugging Workflow
1. **Isolate reward functions** - Test each independently
2. **Check data distribution** - Ensure diversity in prompts
3. **Reduce complexity** - Start with single reward, add gradually
4. **Monitor generations** - Print samples every N steps
5. **Validate extraction logic** - Ensure answer parsing works

### Quick Fixes
```python
# Debug reward function
def debug_reward(completions, **kwargs):
    responses = [comp[0]['content'] for comp in completions]
    for i, r in enumerate(responses[:2]):  # Print first 2
        print(f"Response {i}: {r[:200]}...")
    return [1.0] * len(responses)  # Dummy rewards

# Test without training
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
trainer.generate_completions(dataset[:1])  # Generate without updating
```

---

## References and Resources

**Official Documentation:**
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
- Unsloth Docs: https://docs.unsloth.ai/

**Example Repositories:**
- Open R1 Implementation: https://github.com/huggingface/open-r1
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples

**Recommended Reading:**
- Progressive Disclosure Pattern for agent instructions
- Reward shaping in RL (Ng et al.)
- LoRA paper (Hu et al., 2021)

---

## Usage Instructions for Agents

When this skill is loaded:

1. **Read this entire file** before implementing GRPO training
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
3. **Use the templates** in `templates/` directory as starting points
4. **Reference examples** in `examples/` for task-specific implementations
5. **Follow the workflow** sequentially (don't skip steps)
6. **Debug incrementally** - add one reward function at a time

**Critical Reminders:**
- Always use multiple reward functions (3-5 is optimal)
- Monitor reward metrics, not loss
- Test reward functions before training
- Start small (num_generations=4), scale up gradually
- Save checkpoints frequently (every 100 steps)

This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.

@@ -0,0 +1,228 @@
"""
Basic GRPO Training Template
=============================

A minimal, production-ready template for GRPO training with TRL.
Adapt this for your specific task by modifying:
1. Dataset loading (get_dataset function)
2. Reward functions (reward_*_func)
3. System prompt (SYSTEM_PROMPT)
4. Hyperparameters (GRPOConfig)
"""

import torch
import re
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig

# ==================== CONFIGURATION ====================

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = "outputs/grpo-model"
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 512

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""

# ==================== DATASET ====================

def get_dataset(split="train"):
    """
    Load and prepare your dataset.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content
    - 'answer': str (ground truth, optional)
    """
    # Example: GSM8K math dataset
    data = load_dataset('openai/gsm8k', 'main')[split]

    def process_example(x):
        # Extract ground truth answer
        answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None

        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']}
            ],
            'answer': answer
        }

    return data.map(process_example)

# ==================== HELPER FUNCTIONS ====================

def extract_xml_tag(text: str, tag: str) -> str:
    """Extract content between XML tags."""
    pattern = f'<{tag}>(.*?)</{tag}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def extract_answer(text: str) -> str:
    """Extract the final answer from structured output."""
    return extract_xml_tag(text, 'answer')

# ==================== REWARD FUNCTIONS ====================

def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Reward correct answers.
    Weight: 2.0 (highest priority)
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]

def format_reward_func(completions, **kwargs):
    """
    Reward proper XML format.
    Weight: 0.5
    """
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

def incremental_format_reward_func(completions, **kwargs):
    """
    Incremental reward for partial format compliance.
    Weight: up to 0.5
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.125
        if '</reasoning>' in r:
            score += 0.125
        if '<answer>' in r:
            score += 0.125
        if '</answer>' in r:
            score += 0.125

        # Penalize extra content after the closing tag
        if '</answer>' in r:
            extra = r.split('</answer>')[-1].strip()
            score -= len(extra) * 0.001

        rewards.append(score)

    return rewards

# ==================== MODEL SETUP ====================

def setup_model_and_tokenizer():
    """Load model and tokenizer with optimizations."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def get_peft_config():
    """LoRA configuration for parameter-efficient training."""
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

# ==================== TRAINING ====================

def main():
    """Main training function."""

    # Load data
    print("Loading dataset...")
    dataset = get_dataset()
    print(f"Dataset size: {len(dataset)}")

    # Setup model
    print("Loading model...")
    model, tokenizer = setup_model_and_tokenizer()

    # Training configuration
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        run_name="grpo-training",

        # Learning rate
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',

        # Batch settings
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,

        # GRPO specific
        num_generations=8,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,

        # Training duration
        num_train_epochs=1,

        # Optimization
        bf16=True,
        optim="adamw_8bit",
        max_grad_norm=0.1,

        # Logging
        logging_steps=1,
        save_steps=100,
        report_to="wandb",  # Change to "none" to disable logging
    )

    # Initialize trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            incremental_format_reward_func,
            format_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        peft_config=get_peft_config(),
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save final model
    print(f"Saving model to {OUTPUT_DIR}/final")
    trainer.save_model(f"{OUTPUT_DIR}/final")

    print("Training complete!")

if __name__ == "__main__":
    main()
@@ -0,0 +1,302 @@
---
name: hermes-atropos-environments
description: Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
version: 1.1.0
author: Hermes Agent
license: MIT
metadata:
  hermes:
    tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
    related_skills: [axolotl, grpo-rl-training, trl-fine-tuning, lm-evaluation-harness]
---

# Hermes Agent Atropos Environments

Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.

## Architecture Overview

```
Atropos BaseEnv (atroposlib/envs/base.py)
└── HermesAgentBaseEnv (environments/hermes_base_env.py)
    ├── Handles agent loop orchestration
    ├── Handles tool resolution per group
    ├── Handles ToolContext for reward verification
    └── YOUR ENVIRONMENT (environments/your_env.py)
        Only implements: setup, get_next_item, format_prompt,
        compute_reward, evaluate, wandb_log
```

Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.

## File Locations

| File | Purpose |
|------|---------|
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
| `environments/tool_context.py` | `ToolContext` for reward verification |
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
| `environments/your_env.py` | Your environment implementation |

## Inference Setup — Ask the User First

**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:

1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in the environment.
2. **Self-hosted VLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.

Once the user tells you their setup, use those values in all CLI commands for that session. Example prompt:

> "Before I run this, how would you like to handle inference?
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
> 2. A self-hosted VLLM endpoint (give me the URL and model name)
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
> 4. Local Atropos training server (serve mode)"

### Key flags by provider:

| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|----------|----------------------|------------------------|-------------------|
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
| VLLM (self-hosted) | `vllm` | (default) | (not needed) |
| Other OpenAI-compatible | `openai` | `false` | As needed |
| Local Atropos | (default) | (default) | (not needed) |
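
As a concrete illustration: if the user chooses OpenRouter, a one-item process run could look like this. The model name is purely illustrative; the flag spellings follow the examples later in this skill:

```bash
python environments/your_env.py process \
  --env.total_steps 1 --env.group_size 1 --env.use_wandb false \
  --env.data_path_to_save_groups /tmp/smoke.jsonl \
  --openai.base_url "https://openrouter.ai/api/v1" \
  --openai.model_name "anthropic/claude-sonnet-4.5" \
  --openai.server_type openai \
  --openai.health_check false \
  --openai.api_key "$OPENROUTER_API_KEY"
```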
## Required Methods

### 1. `setup()` — Load dataset and initialize state

```python
import random

async def setup(self) -> None:
    """Called once at startup. Load datasets, initialize state."""
    # Try HuggingFace first, fall back to built-in samples
    try:
        from datasets import load_dataset
        ds = load_dataset("your/dataset", split="test")
        self._items = [...]
    except Exception:
        self._items = BUILTIN_SAMPLES

    # Always split into train/eval
    random.shuffle(self._items)
    eval_size = max(20, int(len(self._items) * 0.1))
    self._eval_items = self._items[:eval_size]
    self._items = self._items[eval_size:]
```

### 2. `get_next_item()` — Return next training item

```python
async def get_next_item(self) -> dict:
    """Return the next item, cycling through the dataset."""
    item = self._items[self._index % len(self._items)]
    self._index += 1
    return item
```

### 3. `format_prompt(item)` — Convert item to user message

```python
def format_prompt(self, item: dict) -> str:
    """Convert a dataset item into the user-facing prompt."""
    return f"Research this question: {item['question']}"
```

### 4. `compute_reward(item, result, ctx)` — Score the rollout

**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
- `result.messages` — List of message dicts (OpenAI format)
- `result.turns_used` — Number of LLM calls made
- `result.finished_naturally` — True if the model stopped voluntarily
- `result.tool_errors` — List of ToolError objects

**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
You must extract these from `result.messages`:

```python
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
    # Extract the final response (last assistant message with content)
    final_response = ""
    tools_used = []
    for msg in reversed(result.messages):
        if msg.get("role") == "assistant" and msg.get("content") and not final_response:
            final_response = msg["content"]
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            for tc in msg["tool_calls"]:
                fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                name = fn.get("name", "")
                if name:
                    tools_used.append(name)

    # Score using an LLM judge, a heuristic, or ToolContext verification
    correctness = await self._llm_judge(item, final_response)
    return correctness
```

`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:

```python
# Run tests in the agent's sandbox
result = ctx.terminal("pytest /workspace/test.py")
return 1.0 if result["exit_code"] == 0 else 0.0
```

### 5. `evaluate()` — Periodic evaluation with full agent loop

**MUST use the full agent loop with tools**, not single-turn chat_completion.
The whole point of hermes-agent environments is agentic evaluation:

```python
async def evaluate(self, *args, **kwargs) -> None:
    import time, uuid
    from environments.agent_loop import HermesAgentLoop
    from environments.tool_context import ToolContext

    start_time = time.time()
    tools, valid_names = self._resolve_tools_for_group()
    samples = []

    for item in self._eval_items[:self.config.eval_size]:
        task_id = str(uuid.uuid4())
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": self.format_prompt(item)})

        agent = HermesAgentLoop(
            server=self.server,
            tool_schemas=tools,
            valid_tool_names=valid_names,
            max_turns=self.config.max_agent_turns,
            task_id=task_id,
            temperature=0.0,  # Deterministic for eval
            max_tokens=self.config.max_token_length,
            extra_body=self.config.extra_body,
        )
        result = await agent.run(messages)

        ctx = ToolContext(task_id)
        try:
            reward = await self.compute_reward(item, result, ctx)
        finally:
            ctx.cleanup()

        samples.append({"prompt": ..., "response": ..., "reward": reward})

    eval_metrics = {"eval/mean_reward": ...}
    await self.evaluate_log(metrics=eval_metrics, samples=samples,
                            start_time=start_time, end_time=time.time())
```

### 6. `wandb_log()` — Custom metrics logging

Always call `super().wandb_log()` at the end:

```python
async def wandb_log(self, wandb_metrics=None):
    if wandb_metrics is None:
        wandb_metrics = {}
    if self._reward_buffer:
        n = len(self._reward_buffer)
        wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
        self._reward_buffer.clear()
    await super().wandb_log(wandb_metrics)  # MUST call super
```

**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval (a sketch follows).
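
A minimal sketch of that rollback, assuming the buffer is a plain list attribute as in the `wandb_log` example above:

```python
async def evaluate(self, *args, **kwargs) -> None:
    # Remember how many training rewards were buffered before eval started...
    buffered_before = len(self._reward_buffer)
    ...  # run the full agent-loop eval shown above
    # ...then drop anything compute_reward appended during eval,
    # so train/mean_reward reflects only training rollouts.
    del self._reward_buffer[buffered_before:]
```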
## Config Class

Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.

## config_init() — Default Configuration

Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load the API key from an environment variable. A sketch follows.
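
A minimal sketch of the shape this takes. The import path, base-config class name, and extra field are assumptions for illustration, not taken from the repo:

```python
import os
from pydantic import Field
from atroposlib.envs.base import APIServerConfig  # import path assumed

class MyEnvConfig(HermesAgentBaseEnvConfig):      # base config class name assumed
    eval_size: int = Field(default=20, description="Items per evaluation run")

@classmethod
def config_init(cls):
    env_config = MyEnvConfig(group_size=4, steps_per_eval=100)
    server_configs = [
        APIServerConfig(
            model_name="anthropic/claude-sonnet-4.5",      # illustrative
            base_url="https://openrouter.ai/api/v1",
            api_key=os.environ.get("OPENROUTER_API_KEY"),  # key from env var
            server_type="openai",                          # external API
        )
    ]
    return env_config, server_configs
```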
## Three CLI Modes

```bash
# SERVE — Full training loop (connects to the Atropos API server)
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1

# PROCESS — Offline data generation (saves JSONL)
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
  --env.use_wandb false --env.data_path_to_save_groups output.jsonl \
  --openai.base_url "<USER_BASE_URL>" \
  --openai.model_name "<USER_MODEL>" \
  --openai.server_type <USER_SERVER_TYPE> --openai.health_check false

# EVALUATE — Standalone eval (runs setup + evaluate only)
python environments/my_env.py evaluate --env.eval_size 20 \
  --env.data_dir_to_save_evals /tmp/eval_results \
  --openai.base_url "<USER_BASE_URL>" \
  --openai.model_name "<USER_MODEL>" \
  --openai.server_type <USER_SERVER_TYPE> --openai.health_check false
```

Config priority: CLI args > YAML file > config_init() defaults.

## Common Pitfalls

1. **AgentResult has `.messages`, not `.final_response`** — Extract the final response by iterating `reversed(result.messages)`, looking for the last assistant message with content.

2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.

3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling the judge separately in evaluate().

4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean.

5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.

6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.

7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.

8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.

9. **ToolContext.cleanup()** — Always call it in a finally block to release sandbox resources.

10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local VLLM server.

11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.

## Reward Function Patterns

### LLM Judge (for open-ended tasks)

Use `self.server.chat_completion()` with a scoring prompt. Parse the JSON response for a score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails. A sketch follows.
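
A minimal sketch of such a judge. The response shape is assumed to be OpenAI-style (`choices[0].message.content`), and the prompt wording is illustrative:

```python
import json

async def _llm_judge(self, item, final_response: str) -> float:
    """Score a response in [0, 1] with an LLM judge; fall back to keyword overlap."""
    judge_prompt = (
        f"Question: {item['question']}\n"
        f"Reference answer: {item['answer']}\n"
        f"Candidate answer: {final_response}\n"
        'Reply with JSON only: {"score": <float between 0 and 1>}'
    )
    try:
        resp = await self.server.chat_completion(
            messages=[{"role": "user", "content": judge_prompt}],
            temperature=0.0,
        )
        # Response shape assumed OpenAI-style
        score = float(json.loads(resp.choices[0].message.content)["score"])
        return max(0.0, min(1.0, score))
    except Exception:
        # Heuristic fallback: keyword overlap with the reference answer
        ref = set(str(item["answer"]).lower().split())
        cand = set(final_response.lower().split())
        return len(ref & cand) / max(1, len(ref))
```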
### Binary Verification (for code/terminal tasks)

Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.

### Multi-Signal (combine multiple indicators)

Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1]. See the sketch below.
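
A minimal sketch with the weights from the text above. The 10-turn budget and the tool-usage proxy (counting `tool` role messages) are assumptions:

```python
def multi_signal_reward(correctness: float, result) -> float:
    """Combine correctness, tool usage, and efficiency into one score in [0, 1]."""
    tool_msgs = sum(1 for m in result.messages if m.get("role") == "tool")
    tool_signal = 1.0 if tool_msgs > 0 else 0.0          # agent actually used tools
    efficiency = max(0.0, 1.0 - result.turns_used / 10)  # fewer turns score higher
    score = 0.6 * correctness + 0.2 * tool_signal + 0.2 * efficiency
    return max(0.0, min(1.0, score))
```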
## Testing Your Environment

1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
2. **Ask the user for inference setup** (see the "Inference Setup" section above)
3. **Process mode** (1 item): Verify the JSONL output has valid tokens, masks, scores
4. **Evaluate mode**: Verify the full agent loop runs with tools and metrics are logged correctly
5. **Check reward range**: Scores should be in [0, 1] and not all identical

## Minimum Implementation Checklist

```python
class MyEnv(HermesAgentBaseEnv):
    name = "my-env"
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls): ...              # Default server + env config
    async def setup(self): ...             # Load dataset + train/eval split
    async def get_next_item(self): ...     # Cycle through training items
    def format_prompt(self, item): ...     # Item → user message string
    async def compute_reward(self, item, result, ctx): ...  # Score rollout
    async def evaluate(self, *args, **kwargs): ...          # Full agent loop eval
    async def wandb_log(self, metrics=None): ...            # Custom metrics + super()

if __name__ == "__main__":
    MyEnv.cli()
```
@@ -0,0 +1,59 @@
# AgentResult Fields Reference

`AgentResult` is defined in `environments/agent_loop.py` as a dataclass.

## Fields

| Field | Type | Description |
|-------|------|-------------|
| `messages` | `List[Dict[str, Any]]` | Full conversation history in OpenAI message format |
| `managed_state` | `Optional[Dict]` | ManagedServer.get_state() if Phase 2, else None |
| `turns_used` | `int` | Number of LLM calls made during the loop |
| `finished_naturally` | `bool` | True if the model stopped calling tools on its own |
| `reasoning_per_turn` | `List[Optional[str]]` | Extracted reasoning content per turn |
| `tool_errors` | `List[ToolError]` | Tool errors encountered during the loop |

## ToolError Fields

| Field | Type | Description |
|-------|------|-------------|
| `turn` | `int` | Turn in which the error occurred |
| `tool_name` | `str` | Name of the tool that failed |
| `arguments` | `str` | Arguments passed to the tool |
| `error` | `str` | Error message |
| `tool_result` | `str` | The result returned to the model |

## Extracting Data from Messages

Messages follow the OpenAI format. Common patterns:

```python
# Get the final assistant response
for msg in reversed(result.messages):
    if msg.get("role") == "assistant" and msg.get("content"):
        final_response = msg["content"]
        break

# Get all tool names used
tools = []
for msg in result.messages:
    if msg.get("role") == "assistant" and msg.get("tool_calls"):
        for tc in msg["tool_calls"]:
            fn = tc.get("function", {}) if isinstance(tc, dict) else {}
            tools.append(fn.get("name", ""))

# Get tool results
for msg in result.messages:
    if msg.get("role") == "tool":
        tool_output = msg.get("content", "")
        call_id = msg.get("tool_call_id", "")
```

## Fields that DO NOT EXIST

These are common mistakes — AgentResult does NOT have:
- `final_response` — extract from messages
- `tool_calls` — extract from messages
- `tools_used` — extract from messages
- `output` — extract from messages
- `response` — extract from messages
@@ -0,0 +1,65 @@
# Atropos BaseEnv Reference

Source: `atroposlib/envs/base.py` (~2124 lines)

## Abstract Methods (MUST implement)

| Method | Signature | Description |
|--------|-----------|-------------|
| `get_next_item()` | `async def get_next_item(self) -> Item` | Return the next item for a trajectory. Return None to pause. |
| `evaluate()` | `async def evaluate(self, *args, **kwargs)` | Called every steps_per_eval steps. |
| `setup()` | `async def setup(self)` | Called once at start. Load datasets, init models. |
| `collect_trajectory()` | `async def collect_trajectory(self, item) -> Tuple[Optional[ScoredDataItem], List[Item]]` | Single rollout. Or override collect_trajectories instead. |

## Overridable Methods

| Method | Default Behavior | Override When |
|--------|-----------------|---------------|
| `collect_trajectories()` | Runs collect_trajectory group_size times in parallel | Batch generation, MCTS, coupled rollouts |
| `wandb_log()` | Logs completion lengths, rollout table, perf stats | Add custom metrics (always call super) |
| `config_init()` | Returns (env_config_cls(), ServerBaseline()) | Custom defaults + server configs |
| `postprocess_histories()` | Passthrough | Final processing before sending to the trainer |
| `save_checkpoint()` | Saves JSON to checkpoint_dir | Custom serialization |
| `cleanup()` | No-op | Release resources after each rollout |

## ScoredDataGroup Structure

```python
# TypedDict (the Optional field types are elided in the source)
class ScoredDataGroup(TypedDict, total=False):
    tokens: List[List[int]]       # Token IDs per rollout
    masks: List[List[int]]        # -100 = prompt, token_id = completion
    scores: List[float]           # Score per rollout
    advantages: Optional[...]     # Per-token advantages
    ref_logprobs: Optional[...]   # Reference model logprobs
    messages: Optional[...]       # OpenAI-format messages
    inference_logprobs: Optional[...]  # Inference logprobs
```

## BaseEnvConfig Key Fields

| Field | Default | Description |
|-------|---------|-------------|
| `group_size` | 4 | Responses grouped for scoring |
| `steps_per_eval` | 100 | Steps between evaluations |
| `max_token_length` | 2048 | Max token length for generations |
| `total_steps` | 1000 | Total training steps |
| `use_wandb` | True | Enable wandb logging |
| `tokenizer_name` | DeepHermes-3 | Tokenizer for token encoding |
| `ensure_scores_are_not_same` | True | Skip groups with identical scores |
| `worker_timeout` | 600 | Task timeout in seconds |

## Data Flow

```
env_manager() → add_train_workers() → handle_env()
  → collect_trajectories() → postprocess_histories()
  → handle_send_to_api() → training server
```

## Atropos Environment Statistics (82 environments analyzed)

- 95% implement setup, collect_trajectories, evaluate, get_next_item
- 76% override wandb_log
- 54% have a custom config class
- Most use collect_trajectories (plural), not collect_trajectory (singular)
- Common reward patterns: LLM judge (~40), regex extraction (~35), code execution (~12)
@@ -0,0 +1,199 @@
# Usage Patterns — Testing Environments and Evaluating Models

## Pattern 1: Test Your Environment Works (process mode)

Use `process` mode to verify your environment runs end-to-end before committing. This generates trajectories without needing an Atropos training server.

**Before running:** Ask the user for their inference setup (see the SKILL.md "Inference Setup" section). Replace `<BASE_URL>`, `<MODEL>`, and `<SERVER_TYPE>` below with their chosen values.

### Step 1: Run 1 trajectory

```bash
cd ~/.hermes/hermes-agent
source venv/bin/activate

python environments/your_env.py process \
  --env.total_steps 1 \
  --env.group_size 1 \
  --env.use_wandb false \
  --env.data_path_to_save_groups /tmp/test_output.jsonl \
  --openai.base_url "<BASE_URL>" \
  --openai.model_name "<MODEL>" \
  --openai.server_type <SERVER_TYPE> \
  --openai.health_check false
```

### Step 2: Verify the output

```python
import json

for line in open("/tmp/test_output.jsonl"):
    data = json.loads(line)
    print(f"Scores: {data.get('scores', [])}")
    print(f"Token sequences: {len(data.get('tokens', []))}")
    # Check that messages include tool calls
    for msg_list in data.get("messages", []):
        roles = [m.get("role") for m in msg_list]
        print(f"Roles: {roles}")
        for m in reversed(msg_list):
            if m.get("role") == "assistant" and m.get("content"):
                print(f"Response: {m['content'][:200]}...")
                break
```

### What to check:
- **Scores are not all 0.0** — if they are, compute_reward is broken
- **Scores are in [0, 1]** — not negative, not > 1
- **Messages include "tool" role entries** — the agent used tools
- **Token sequences are non-empty**
- **An HTML visualization is generated** next to the .jsonl

### Common failures:
- `'AgentResult' object has no attribute 'X'` — accessing a field that doesn't exist. See agentresult-fields.md.
- Score always 0.0 — reward function erroring silently
- Score always 1.0 — verification too lenient or not running

## Pattern 2: Evaluate a Model (evaluate mode)

Use `evaluate` mode to benchmark a model on your environment's eval split. This runs the full agent loop with tools for each eval item.

### Step 1: Run evaluation

```bash
python environments/your_env.py evaluate \
  --env.eval_size 20 \
  --env.use_wandb false \
  --env.data_dir_to_save_evals /tmp/eval_results \
  --openai.base_url "<BASE_URL>" \
  --openai.model_name "<MODEL>" \
  --openai.server_type <SERVER_TYPE> \
  --openai.health_check false
```

### Step 2: Read results

Stdout shows a lighteval-compatible table:

```
Evaluation Results: your-env_eval
|Metric          | Value|
|mean correctness| 0.850|
|mean reward     | 0.920|
|mean tool calls | 4.300|
|n items         | 20   |
Evaluation completed in 367 seconds
```

JSON results are saved to the eval directory:

```python
import json

data = json.load(open("/tmp/eval_results/metrics.json"))
for metric, value in data["results"]["all"].items():
    print(f"{metric}: {value}")
```

### Step 3: Compare models

Run evaluate with different models and compare the metrics.json files.

### What to check:
- **"data_dir_to_save_evals is not set"** — you forgot the flag; results won't be saved
- **Tool usage rate = 0** — evaluate() is using chat_completion instead of HermesAgentLoop
- **All scores identical** — the judge is failing and falling back to the heuristic
- **Very slow** — each item runs a full agent loop (~30-90s). Use `--env.eval_size 5` for quick checks.

## Pattern 3: Generate Training Data (process mode, larger scale)

Generate trajectory data for offline training or analysis:

```bash
python environments/your_env.py process \
  --env.total_steps 50 \
  --env.group_size 4 \
  --env.use_wandb false \
  --env.data_path_to_save_groups data/trajectories.jsonl \
  --openai.base_url "<BASE_URL>" \
  --openai.model_name "<MODEL>" \
  --openai.server_type <SERVER_TYPE> \
  --openai.health_check false
```

### Analyze the distribution:

```python
import json

scores = []
for line in open("data/trajectories.jsonl"):
    data = json.loads(line)
    scores.extend(data.get("scores", []))

print(f"Total: {len(scores)}, Mean: {sum(scores)/len(scores):.3f}")
for bucket in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:
    count = sum(1 for s in scores if abs(s - bucket) < 0.1)
    print(f"  {bucket:.1f}: {'█' * count} ({count})")
```

### What to check:
- **The score distribution has variance** — RL needs score variance. All-same scores are useless.

## Pattern 4: Full RL Training (serve mode)

For actual RL training with Atropos:

```bash
# Terminal 1: Start the Atropos API server
run-api

# Terminal 2: Start your environment
python environments/your_env.py serve \
  --config environments/your_env/default.yaml
```

For Phase 2 with VLLM:

```bash
# Terminal 1: VLLM server
python -m vllm.entrypoints.openai.api_server --model your-model --port 8000

# Terminal 2: Atropos API
run-api

# Terminal 3: Environment
python environments/your_env.py serve \
  --openai.base_url http://localhost:8000/v1 \
  --openai.model_name your-model \
  --openai.server_type vllm
```

## Pattern 5: Quick Smoke Test

Verify imports and config before spending money on API calls:

```python
from environments.your_env import YourEnv

print(f"Name: {YourEnv.name}")
cfg, servers = YourEnv.config_init()
print(f"Toolsets: {cfg.enabled_toolsets}")
print(f"Server: {servers[0].model_name}")
print("All imports OK")
```

## Timing Expectations

| Mode | Items | Time per item | Total |
|------|-------|--------------|-------|
| process (1 item) | 1 | 30-90s | ~1 min |
| evaluate (5 items) | 5 | 30-90s | ~5 min |
| evaluate (20 items) | 20 | 30-90s | ~15-30 min |
| process (50 items) | 50 | 30-90s | ~30-75 min |

Times are for cloud APIs with Claude Sonnet-class models. Local models may be faster or slower depending on hardware.
434
protected/skills-backup/mlops/training/peft/SKILL.md
Normal file
434
protected/skills-backup/mlops/training/peft/SKILL.md
Normal file
@@ -0,0 +1,434 @@
---
name: peft-fine-tuning
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with the transformers ecosystem.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
metadata:
  hermes:
    tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]

---

# PEFT (Parameter-Efficient Fine-Tuning)

Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.

## When to use PEFT

**Use PEFT/LoRA when:**
- Fine-tuning 7B-70B models on consumer GPUs (RTX 4090, A100)
- You need to train <1% of parameters (6MB adapters vs a 14GB full model)
- You want fast iteration with multiple task-specific adapters
- Deploying multiple fine-tuned variants from one base model

**Use QLoRA (PEFT + quantization) when:**
- Fine-tuning 70B models on a single 24GB GPU
- Memory is the primary constraint
- You can accept a ~5% quality trade-off vs full fine-tuning

**Use full fine-tuning instead when:**
- Training small models (<1B parameters)
- You need maximum quality and have the compute budget
- A significant domain shift requires updating all weights

## Quick start

### Installation

```bash
# Basic installation
pip install peft

# With quantization support (recommended)
pip install peft bitsandbytes

# Full stack
pip install peft transformers accelerate bitsandbytes datasets
```

### LoRA fine-tuning (standard)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

# Load base model
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# LoRA configuration
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                  # Rank (8-64, higher = more capacity)
    lora_alpha=32,         # Scaling factor (typically 2*r)
    lora_dropout=0.05,     # Dropout for regularization
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Attention layers
    bias="none"            # Don't train biases
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%

# Prepare dataset
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def tokenize(example):
    text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
    return tokenizer(text, truncation=True, max_length=512, padding="max_length")

tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

def collate(data):
    # Tokenizer output is lists of ints; convert to tensors for the Trainer
    input_ids = torch.tensor([f["input_ids"] for f in data])
    attention_mask = torch.tensor([f["attention_mask"] for f in data])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids.clone()}

# Training
training_args = TrainingArguments(
    output_dir="./lora-llama",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=collate
)

trainer.train()

# Save adapter only (6MB vs 16GB)
model.save_pretrained("./lora-llama-adapter")
```

### QLoRA fine-tuning (memory-efficient)

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # NormalFloat4 (best for LLMs)
    bnb_4bit_compute_dtype="bfloat16",  # Compute in bf16
    bnb_4bit_use_double_quant=True      # Nested quantization
)

# Load quantized model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-70B",
    quantization_config=bnb_config,
    device_map="auto"
)

# Prepare for training (enables gradient checkpointing)
model = prepare_model_for_kbit_training(model)

# LoRA config for QLoRA
lora_config = LoraConfig(
    r=64,            # Higher rank for 70B
    lora_alpha=128,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
# The 70B model now fits on a single 24GB GPU!
```

## LoRA parameter selection

### Rank (r) - capacity vs efficiency

| Rank | Trainable Params | Memory | Quality | Use Case |
|------|-----------------|--------|---------|----------|
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
| **8** | ~7M | Low | Good | **Recommended starting point** |
| **16** | ~14M | Medium | Better | **General fine-tuning** |
| 32 | ~27M | Higher | High | Complex tasks |
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |

### Alpha (lora_alpha) - scaling factor

```python
# Rule of thumb: alpha = 2 * rank
LoraConfig(r=16, lora_alpha=32)  # Standard
LoraConfig(r=16, lora_alpha=16)  # Conservative (lower effective learning rate)
LoraConfig(r=16, lora_alpha=64)  # Aggressive (higher effective learning rate)
```

### Target modules by architecture

```python
# Llama / Mistral / Qwen
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# GPT-2 / GPT-Neo
target_modules = ["c_attn", "c_proj", "c_fc"]

# Falcon
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]

# BLOOM
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]

# Auto-detect all linear layers
target_modules = "all-linear"  # PEFT 0.6.0+
```

## Loading and merging adapters

### Load trained adapter

```python
from peft import PeftModel, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM

# Option 1: Load with PeftModel
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")

# Option 2: Load directly (recommended)
model = AutoPeftModelForCausalLM.from_pretrained(
    "./lora-llama-adapter",
    device_map="auto"
)
```

### Merge adapter into base model

```python
# Merge for deployment (no adapter overhead)
merged_model = model.merge_and_unload()

# Save merged model
merged_model.save_pretrained("./llama-merged")
tokenizer.save_pretrained("./llama-merged")

# Push to Hub
merged_model.push_to_hub("username/llama-finetuned")
```

### Multi-adapter serving

```python
from peft import AutoPeftModelForCausalLM

# Load base with first adapter (registered under the name "default")
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1")

# Load additional adapters
model.load_adapter("./adapter-task2", adapter_name="task2")
model.load_adapter("./adapter-task3", adapter_name="task3")

# Switch between adapters at runtime
model.set_adapter("default")  # Use the first (task1) adapter
output1 = model.generate(**inputs)

model.set_adapter("task2")    # Switch to task2
output2 = model.generate(**inputs)

# Disable adapters (use base model)
with model.disable_adapter():
    base_output = model.generate(**inputs)
```

## PEFT methods comparison

| Method | Trainable % | Memory | Speed | Best For |
|--------|------------|--------|-------|----------|
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |

### IA3 (minimal parameters)

```python
from peft import IA3Config

ia3_config = IA3Config(
    target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
    feedforward_modules=["down_proj"]
)
model = get_peft_model(model, ia3_config)
# Trains only 0.01% of parameters!
```

### Prefix Tuning

```python
from peft import PrefixTuningConfig

prefix_config = PrefixTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,  # Prepended tokens
    prefix_projection=True  # Use MLP projection
)
model = get_peft_model(model, prefix_config)
```

## Integration patterns

### With TRL (SFTTrainer)

```python
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="./output", max_seq_length=512),
    train_dataset=dataset,
    peft_config=lora_config,  # Pass LoRA config directly
)
trainer.train()
```

### With Axolotl (YAML config)

```yaml
# axolotl config.yaml
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj
lora_target_linear: true  # Target all linear layers
```

### With vLLM (inference)

```python
from vllm import LLM
from vllm.lora.request import LoRARequest

# Load base model with LoRA support
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)

# Serve with adapter
outputs = llm.generate(
    prompts,
    lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
)
```

## Performance benchmarks

### Memory usage (Llama 3.1 8B)

| Method | GPU Memory | Trainable Params |
|--------|-----------|------------------|
| Full fine-tuning | 60+ GB | 8B (100%) |
| LoRA r=16 | 18 GB | 14M (0.17%) |
| QLoRA r=16 | 6 GB | 14M (0.17%) |
| IA3 | 16 GB | 800K (0.01%) |

### Training speed (A100 80GB)

| Method | Tokens/sec | vs Full FT |
|--------|-----------|------------|
| Full FT | 2,500 | 1x |
| LoRA | 3,200 | 1.3x |
| QLoRA | 2,100 | 0.84x |

### Quality (MMLU benchmark)

| Model | Full FT | LoRA | QLoRA |
|-------|---------|------|-------|
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
| Llama 2-13B | 54.8 | 54.2 | 53.5 |

## Common issues

### CUDA OOM during training

```python
# Solution 1: Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Solution 2: Reduce batch size + increase accumulation
TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16
)

# Solution 3: Use QLoRA
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
```

### Adapter not applying

```python
# Verify the adapter is active
print(model.active_adapters)  # Should show the adapter name

# Check trainable parameters
model.print_trainable_parameters()

# Ensure the model is in training mode
model.train()
```

### Quality degradation

```python
# Increase rank
LoraConfig(r=32, lora_alpha=64)

# Target more modules
target_modules = "all-linear"

# Use more training data and epochs
TrainingArguments(num_train_epochs=5)

# Lower the learning rate
TrainingArguments(learning_rate=1e-4)
```

## Best practices

1. **Start with r=8-16**, increase if quality is insufficient
2. **Use alpha = 2 * rank** as a starting point
3. **Target attention + MLP layers** for the best quality/efficiency
4. **Enable gradient checkpointing** for memory savings
5. **Save adapters frequently** (small files, easy rollback)
6. **Evaluate on held-out data** before merging
7. **Use QLoRA for 70B+ models** on consumer hardware

## References

- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization

## Resources

- **GitHub**: https://github.com/huggingface/peft
- **Docs**: https://huggingface.co/docs/peft
- **LoRA Paper**: arXiv:2106.09685
- **QLoRA Paper**: arXiv:2305.14314
- **Models**: https://huggingface.co/models?library=peft
@@ -0,0 +1,514 @@
# PEFT Advanced Usage Guide

## Advanced LoRA Variants

### DoRA (Weight-Decomposed Low-Rank Adaptation)

DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:

```python
from peft import LoraConfig

dora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    use_dora=True,  # Enable DoRA
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, dora_config)
```

**When to use DoRA**:
- Consistently outperforms LoRA on instruction-following tasks
- Slightly higher memory (~10%) due to magnitude vectors
- Best for quality-critical fine-tuning

### AdaLoRA (Adaptive Rank)

Automatically adjusts rank per layer based on importance:

```python
from peft import AdaLoraConfig

adalora_config = AdaLoraConfig(
    init_r=64,            # Initial rank
    target_r=16,          # Target average rank
    tinit=200,            # Warmup steps
    tfinal=1000,          # Final pruning step
    deltaT=10,            # Rank update frequency
    beta1=0.85,
    beta2=0.85,
    orth_reg_weight=0.5,  # Orthogonality regularization
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM"
)
```

**Benefits**:
- Allocates more rank to important layers
- Can reduce total parameters while maintaining quality
- Good for exploring the optimal rank distribution

### LoRA+ (Asymmetric Learning Rates)

LoRA+ uses different learning rates for the A and B matrices:

```python
from peft import LoraConfig

# Note: use_rslora is a related stabilization technique, not LoRA+ itself
lora_plus_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    use_rslora=True,  # Rank-stabilized LoRA (related technique)
)

# Manual implementation of LoRA+
from torch.optim import AdamW

# Group parameters
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]

optimizer = AdamW([
    {"params": lora_A_params, "lr": 1e-4},
    {"params": lora_B_params, "lr": 1e-3},  # 10x higher for B
])
```

### rsLoRA (Rank-Stabilized LoRA)

rsLoRA rescales the LoRA update so training stays stable across different ranks:

```python
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    use_rslora=True,  # Enables rank-stabilized scaling
    target_modules="all-linear"
)
```

**When to use**:
- When experimenting with different ranks
- Helps maintain consistent behavior across rank values
- Recommended for r > 32 (see the note below)
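
For intuition, the difference is only in the scaling constant applied to the low-rank update:

```python
# Standard LoRA scales the update BA by alpha / r, so the effective update
# shrinks as rank grows; rsLoRA scales by alpha / sqrt(r), which keeps the
# update magnitude roughly stable across ranks (Kalajdzievski, 2023).
import math

r, alpha = 64, 64
standard_scale = alpha / r        # 1.0
rs_scale = alpha / math.sqrt(r)   # 8.0
print(standard_scale, rs_scale)
```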
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
|
||||
|
||||
Initializes LoRA weights to compensate for quantization error:
|
||||
|
||||
```python
|
||||
from peft import LoftQConfig, LoraConfig, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
# LoftQ configuration
|
||||
loftq_config = LoftQConfig(
|
||||
loftq_bits=4, # Quantization bits
|
||||
loftq_iter=5, # Alternating optimization iterations
|
||||
)
|
||||
|
||||
# LoRA config with LoftQ initialization
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
**Benefits over standard QLoRA**:
|
||||
- Better initial quality after quantization
|
||||
- Faster convergence
|
||||
- ~1-2% better final accuracy on benchmarks
|
||||
|
||||
## Custom Module Targeting
|
||||
|
||||
### Target specific layers
|
||||
|
||||
```python
|
||||
# Target only first and last transformer layers
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
"model.layers.31.self_attn.q_proj",
|
||||
"model.layers.31.self_attn.v_proj"],
|
||||
layers_to_transform=[0, 31] # Alternative approach
|
||||
)
|
||||
```
|
||||
|
||||
### Layer pattern matching
|
||||
|
||||
```python
|
||||
# Target layers 0-10 only
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
layers_to_transform=list(range(11)), # Layers 0-10
|
||||
layers_pattern="model.layers"
|
||||
)
|
||||
```
|
||||
|
||||
### Exclude specific layers
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head"], # Train these fully (not LoRA)
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding and LM Head Training

### Train embeddings with LoRA

```python
from peft import LoraConfig

# Include embeddings
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "embed_tokens"],  # Include embeddings
    modules_to_save=["lm_head"],  # Train lm_head fully
)
```

### Extending vocabulary with LoRA

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig

# Add new tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
tokenizer.add_tokens(new_tokens)

# Resize model embeddings
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model.resize_token_embeddings(len(tokenizer))

# Configure LoRA to train new embeddings
lora_config = LoraConfig(
    r=16,
    target_modules="all-linear",
    modules_to_save=["embed_tokens", "lm_head"],  # Train these fully
)

model = get_peft_model(model, lora_config)
```
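
After wrapping, it is worth confirming that the resized embeddings really are trainable. `modules_to_save` stores full copies of those layers in the adapter checkpoint, so expect a far larger trainable count than the LoRA matrices alone:

```python
model.print_trainable_parameters()
# Trainable count includes full embed_tokens and lm_head copies
# (on the order of a billion parameters for an 8B Llama vocabulary),
# not just the few million LoRA parameters
```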
## Multi-Adapter Patterns

### Adapter composition

```python
from peft import AutoPeftModelForCausalLM

# Load model with multiple adapters
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
model.load_adapter("./style-adapter", adapter_name="style")
model.load_adapter("./task-adapter", adapter_name="task")

# Combine adapters (weighted sum)
model.add_weighted_adapter(
    adapters=["style", "task"],
    weights=[0.7, 0.3],
    adapter_name="combined",
    combination_type="linear"  # or "cat", "svd"
)

model.set_adapter("combined")
```

### Adapter stacking

```python
# Stack adapters (concatenate their outputs)
model.add_weighted_adapter(
    adapters=["base", "domain", "task"],
    weights=[1.0, 1.0, 1.0],
    adapter_name="stacked",
    combination_type="cat"  # Concatenate adapter outputs
)
```

### Dynamic adapter switching

```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

class MultiAdapterModel:
    def __init__(self, base_adapter_path, adapter_paths):
        # adapter_paths: dict mapping adapter_name -> adapter directory
        self.model = AutoPeftModelForCausalLM.from_pretrained(base_adapter_path)
        self.tokenizer = AutoTokenizer.from_pretrained(base_adapter_path)
        for name, path in adapter_paths.items():
            self.model.load_adapter(path, adapter_name=name)

    def tokenize(self, prompt):
        return self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

    def generate(self, prompt, adapter_name="default"):
        self.model.set_adapter(adapter_name)
        return self.model.generate(**self.tokenize(prompt))

    def generate_ensemble(self, prompt, adapters, weights):
        """Return logits from a weighted adapter ensemble."""
        outputs = []
        for adapter, weight in zip(adapters, weights):
            self.model.set_adapter(adapter)
            logits = self.model(**self.tokenize(prompt)).logits
            outputs.append(weight * logits)
        return torch.stack(outputs).sum(dim=0)
```
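
Usage of the class above, with hypothetical adapter directories (point them at your own saved adapters):

```python
mam = MultiAdapterModel(
    "./base-adapter",
    {"style": "./style-adapter", "task": "./task-adapter"},
)
output_ids = mam.generate("Summarize this report:", adapter_name="task")
print(mam.tokenizer.decode(output_ids[0], skip_special_tokens=True))
```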
## Memory Optimization

### Gradient checkpointing with LoRA

```python
from peft import prepare_model_for_kbit_training

# Enable gradient checkpointing
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```

### CPU offloading for training

```python
from accelerate import Accelerator, DeepSpeedPlugin

# Offload optimizer states to CPU via DeepSpeed ZeRO-2
ds_plugin = DeepSpeedPlugin(zero_stage=2, offload_optimizer_device="cpu")

accelerator = Accelerator(
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    deepspeed_plugin=ds_plugin
)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```

### Memory-efficient attention with LoRA

```python
import torch
from transformers import AutoModelForCausalLM
from peft import get_peft_model

# Combine Flash Attention 2 with LoRA
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)

# Apply LoRA
model = get_peft_model(model, lora_config)
```
## Inference Optimization

### Merge for deployment

```python
# Merge adapter weights into base model and save the result
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged-model")

# Quantize merged model for inference
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "./merged-model",
    quantization_config=bnb_config
)
```

### Export to different formats

```python
# Export to GGUF (llama.cpp)
# First merge, then convert
merged_model.save_pretrained("./merged-model")

# Use llama.cpp converter
# python convert-hf-to-gguf.py ./merged-model --outfile model.gguf

# Export to ONNX
from optimum.onnxruntime import ORTModelForCausalLM

ort_model = ORTModelForCausalLM.from_pretrained(
    "./merged-model",
    export=True
)
ort_model.save_pretrained("./onnx-model")
```

### Batch adapter inference

```python
from vllm import LLM
from vllm.lora.request import LoRARequest

# Initialize with LoRA support
llm = LLM(
    model="meta-llama/Llama-3.1-8B",
    enable_lora=True,
    max_lora_rank=64,
    max_loras=4  # Max concurrent adapters
)

# Batch with different adapters
requests = [
    ("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
    ("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
    ("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
]

outputs = llm.generate(
    [r[0] for r in requests],
    lora_request=[r[1] for r in requests]
)
```
## Training Recipes

### Instruction tuning recipe

```python
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    bias="none",
    task_type="CAUSAL_LM"
)

training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    bf16=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
)
```

### Code generation recipe

```python
from trl import SFTConfig

lora_config = LoraConfig(
    r=32,  # Higher rank for code
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# max_seq_length lives on trl's SFTConfig, not on plain TrainingArguments
training_args = SFTConfig(
    output_dir="./output",
    learning_rate=1e-4,  # Lower LR for code
    num_train_epochs=2,
    max_seq_length=2048,  # Longer sequences
)
```

### Conversational/Chat recipe

```python
from trl import SFTTrainer, SFTConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,  # alpha = r for chat
    lora_dropout=0.05,
    target_modules="all-linear"
)

# Use chat template
def format_chat(example):
    messages = [
        {"role": "user", "content": example["instruction"]},
        {"role": "assistant", "content": example["response"]}
    ]
    # dataset.map expects a dict; store the rendered template under "text"
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="./output", max_seq_length=1024),
    peft_config=lora_config,
    train_dataset=dataset.map(format_chat),
)
```
## Debugging and Validation

### Verify adapter application

```python
# Check which modules have LoRA
for name, module in model.named_modules():
    if hasattr(module, "lora_A"):
        print(f"LoRA applied to: {name}")

# Print detailed config
print(model.peft_config)

# Check adapter state
print(f"Active adapters: {model.active_adapters}")
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
```

### Compare with base model

```python
# Generate with adapter
model.set_adapter("default")
adapter_output = model.generate(**inputs)

# Generate without adapter
with model.disable_adapter():
    base_output = model.generate(**inputs)

print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
print(f"Base: {tokenizer.decode(base_output[0])}")
```

### Monitor training metrics

```python
from transformers import TrainerCallback

class LoRACallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            # Log adapter-specific metrics
            model = kwargs["model"]
            lora_params = sum(p.numel() for n, p in model.named_parameters()
                              if "lora" in n and p.requires_grad)
            print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
```
@@ -0,0 +1,480 @@
# PEFT Troubleshooting Guide

## Installation Issues

### bitsandbytes CUDA Error

**Error**: `CUDA Setup failed despite GPU being available`

**Fix**:
```bash
# Check CUDA version
nvcc --version

# Install matching bitsandbytes
pip uninstall bitsandbytes
pip install bitsandbytes --no-cache-dir

# Or compile from source for specific CUDA
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=118 make cuda11x  # Adjust for your CUDA
pip install .
```

### Triton Import Error

**Error**: `ModuleNotFoundError: No module named 'triton'`

**Fix**:
```bash
# Install triton (Linux only)
pip install triton

# Windows: Triton is not supported; avoid features that require
# triton-based kernels and use the plain CUDA backend instead
```

### PEFT Version Conflicts

**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`

**Fix**:
```bash
# Upgrade to latest PEFT (quote the spec so the shell doesn't treat > as a redirect)
pip install "peft>=0.13.0" --upgrade

# Check version
python -c "import peft; print(peft.__version__)"
```
## Training Issues

### CUDA Out of Memory

**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`

**Solutions**:

1. **Enable gradient checkpointing**:
```python
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
```

2. **Reduce batch size**:
```python
TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16  # Maintain effective batch size
)
```

3. **Use QLoRA**:
```python
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
```

4. **Lower LoRA rank**:
```python
LoraConfig(r=8)  # Instead of r=16 or higher
```

5. **Target fewer modules**:
```python
target_modules=["q_proj", "v_proj"]  # Instead of all-linear
```

### Loss Not Decreasing

**Problem**: Training loss stays flat or increases.

**Solutions**:

1. **Check learning rate**:
```python
# Start lower
TrainingArguments(learning_rate=1e-4)  # Not 2e-4 or higher
```

2. **Verify adapter is active**:
```python
model.print_trainable_parameters()
# Should show >0 trainable params

# Check adapter applied
print(model.peft_config)
```

3. **Check data formatting**:
```python
# Verify tokenization
sample = dataset[0]
decoded = tokenizer.decode(sample["input_ids"])
print(decoded)  # Should look correct
```

4. **Increase rank**:
```python
LoraConfig(r=32, lora_alpha=64)  # More capacity
```

### NaN Loss

**Error**: `Loss is NaN`

**Fix**:
```python
# Use bf16 instead of fp16
TrainingArguments(bf16=True, fp16=False)

# Or keep fp16 (the Trainer applies dynamic loss scaling automatically)
TrainingArguments(fp16=True)

# Lower learning rate
TrainingArguments(learning_rate=5e-5)

# Check for data issues (NaNs in any floating-point inputs)
for batch in dataloader:
    for key, value in batch.items():
        if torch.is_floating_point(value) and torch.isnan(value).any():
            print(f"NaN in {key}!")
```

### Adapter Not Training

**Problem**: `trainable params: 0` or model not updating.

**Fix**:
```python
# Verify LoRA applied to correct modules
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(f"Found LoRA: {name}")

# Check target_modules match model architecture
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))

# Ensure model in training mode
model.train()

# Check requires_grad
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")
```
## Loading Issues

### Adapter Loading Fails

**Error**: `ValueError: Can't find adapter weights`

**Fix**:
```python
# Check adapter files exist
import os
print(os.listdir("./adapter-path"))
# Should contain: adapter_config.json, adapter_model.safetensors

# Load with correct structure
from peft import PeftModel, PeftConfig

# Check config
config = PeftConfig.from_pretrained("./adapter-path")
print(config)

# Load base model first
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, "./adapter-path")
```

### Base Model Mismatch

**Error**: `RuntimeError: size mismatch`

**Fix**:
```python
# Ensure base model matches adapter
from peft import PeftConfig

config = PeftConfig.from_pretrained("./adapter-path")
print(f"Base model: {config.base_model_name_or_path}")

# Load exact same base model
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
```

### Safetensors vs PyTorch Format

**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`

**Fix**:
```python
# Force local loading
model = PeftModel.from_pretrained(
    base_model,
    "./adapter-path",
    local_files_only=True
)

# Or specify format when saving
model.save_pretrained("./adapter", safe_serialization=True)   # safetensors
model.save_pretrained("./adapter", safe_serialization=False)  # pytorch
```
## Inference Issues

### Slow Generation

**Problem**: Inference much slower than expected.

**Solutions**:

1. **Merge adapter for deployment**:
```python
merged_model = model.merge_and_unload()
# No adapter overhead during inference
```

2. **Use optimized inference engine**:
```python
from vllm import LLM
llm = LLM(model="./merged-model", dtype="half")
```

3. **Enable Flash Attention**:
```python
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2"
)
```

### Output Quality Issues

**Problem**: Fine-tuned model produces worse outputs.

**Solutions**:

1. **Check evaluation without adapter**:
```python
with model.disable_adapter():
    base_output = model.generate(**inputs)
# Compare with adapter output
```

2. **Use deterministic decoding during eval**:
```python
model.generate(**inputs, do_sample=False)
```

3. **Retrain with more data**:
- Increase training samples
- Use higher quality data
- Train for more epochs

### Wrong Adapter Active

**Problem**: Model using wrong adapter or no adapter.

**Fix**:
```python
# Check active adapters
print(model.active_adapters)

# Explicitly set adapter
model.set_adapter("your-adapter-name")

# List all adapters
print(model.peft_config.keys())
```
## QLoRA Specific Issues

### Quantization Errors

**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`

**Fix**:
```python
# Ensure compute dtype matches
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Match model dtype
    bnb_4bit_quant_type="nf4"
)

# Load with correct dtype
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16
)
```

### QLoRA OOM

**Error**: OOM even with 4-bit quantization.

**Fix**:
```python
# Enable double quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True  # Further memory reduction
)

# Use offloading
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    max_memory={0: "20GB", "cpu": "100GB"}
)
```

### QLoRA Merge Fails

**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`

**Fix**:
```python
# Dequantize before merging
from peft import PeftModel

# Load in higher precision for merging
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,  # Not quantized
    device_map="auto"
)

# Load adapter
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")

# Now merge
merged = model.merge_and_unload()
```
## Multi-Adapter Issues

### Adapter Conflict

**Error**: `ValueError: Adapter with name 'default' already exists`

**Fix**:
```python
# Use unique names
model.load_adapter("./adapter1", adapter_name="task1")
model.load_adapter("./adapter2", adapter_name="task2")

# Or delete existing
model.delete_adapter("default")
```

### Mixed Precision Adapters

**Problem**: Adapters trained with different dtypes.

**Fix**:
```python
# Convert adapter precision
model = PeftModel.from_pretrained(base_model, "./adapter")
model = model.to(torch.bfloat16)

# Or load with specific dtype
model = PeftModel.from_pretrained(
    base_model,
    "./adapter",
    torch_dtype=torch.bfloat16
)
```
## Performance Optimization

### Memory Profiling

```python
import torch

def print_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

# Profile during training
print_memory()  # Before
model.train()
loss = model(**batch).loss
loss.backward()
print_memory()  # After
```

### Speed Profiling

```python
import time
import torch

def benchmark_generation(model, tokenizer, prompt, n_runs=5):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Warmup
    model.generate(**inputs, max_new_tokens=10)
    torch.cuda.synchronize()

    # Benchmark
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        outputs = model.generate(**inputs, max_new_tokens=100)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    tokens = outputs.shape[1] - inputs.input_ids.shape[1]
    avg_time = sum(times) / len(times)
    print(f"Speed: {tokens/avg_time:.2f} tokens/sec")

# Compare adapter vs merged
benchmark_generation(adapter_model, tokenizer, "Hello")
benchmark_generation(merged_model, tokenizer, "Hello")
```
## Getting Help

1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co/
3. **PEFT Documentation**: https://huggingface.co/docs/peft

### Debugging Template

When reporting issues, include:

```python
# System info
import peft
import transformers
import torch

print(f"PEFT: {peft.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

# Config
print(model.peft_config)
model.print_trainable_parameters()
```
129
protected/skills-backup/mlops/training/pytorch-fsdp/SKILL.md
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,7 @@
# Pytorch-Fsdp Documentation Index

## Categories

### Other
**File:** `other.md`
**Pages:** 15
File diff suppressed because it is too large
@@ -0,0 +1,349 @@
---
name: pytorch-lightning
description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lightning, torch, transformers]
metadata:
  hermes:
    tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable]

---

# PyTorch Lightning - High-Level Training Framework

## Quick start

PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.

**Installation**:
```bash
pip install lightning
```

**Convert PyTorch to Lightning** (3 steps):

```python
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Step 1: Define LightningModule (organize your PyTorch code)
class LitModel(L.LightningModule):
    def __init__(self, hidden_size=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)  # Auto-logged to TensorBoard
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Step 2: Create data
train_loader = DataLoader(train_dataset, batch_size=32)

# Step 3: Train with Trainer (handles everything else!)
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
model = LitModel()
trainer.fit(model, train_loader)
```

**That's it!** Trainer handles:
- GPU/TPU/CPU switching
- Distributed training (DDP, FSDP, DeepSpeed)
- Mixed precision (FP16, BF16)
- Gradient accumulation
- Checkpointing
- Logging
- Progress bars
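
Several of these are single Trainer flags; a sketch combining a few (values are illustrative):

```python
trainer = L.Trainer(
    max_epochs=10,
    accelerator='gpu', devices=2,     # device switching + DDP across 2 GPUs
    precision='bf16-mixed',           # mixed precision
    accumulate_grad_batches=4,        # gradient accumulation
    default_root_dir='runs/',         # where checkpoints and logs land
)
```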
## Common workflows

### Workflow 1: From PyTorch to Lightning

**Original PyTorch code**:
```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model.to('cuda')

for epoch in range(max_epochs):
    for batch in train_loader:
        batch = batch.to('cuda')
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
```

**Lightning version**:
```python
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        loss = self.model(batch)  # No .to('cuda') needed!
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# Train
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
trainer.fit(LitModel(), train_loader)
```

**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed training

### Workflow 2: Validation and testing

```python
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        val_loss = nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', val_loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        test_loss = nn.functional.cross_entropy(y_hat, y)
        self.log('test_loss', test_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Train with validation
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# Test
trainer.test(model, test_loader)
```

**Automatic features**:
- Validation runs every epoch by default
- Metrics logged to TensorBoard
- Best-model checkpointing based on val_loss (via the ModelCheckpoint callback)

### Workflow 3: Distributed training (DDP)

```python
# Same code as single GPU!
model = LitModel()

# 8 GPUs with DDP (automatic!)
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy='ddp'  # Or 'fsdp', 'deepspeed'
)

trainer.fit(model, train_loader)
```

**Launch**:
```bash
# Single command, Lightning handles the rest
python train.py
```

**No changes needed**:
- Automatic data distribution
- Gradient synchronization
- Multi-node support (just set `num_nodes=2`; see the sketch below)
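
A minimal multi-node variant, assuming two 8-GPU nodes launched with the same script:

```python
# 2 nodes × 8 GPUs = 16 processes; run the same script on every node
trainer = L.Trainer(accelerator='gpu', devices=8, num_nodes=2, strategy='ddp')
trainer.fit(model, train_loader)
```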
### Workflow 4: Callbacks for monitoring

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Create callbacks
checkpoint = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename='model-{epoch:02d}-{val_loss:.2f}'
)

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Add to Trainer
trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint, early_stop, lr_monitor]
)

trainer.fit(model, train_loader, val_loader)
```

**Result**:
- Auto-saves best 3 models
- Stops early if no improvement for 5 epochs
- Logs learning rate to TensorBoard

### Workflow 5: Learning rate scheduling

```python
class LitModel(L.LightningModule):
    # ... (training_step, etc.)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # Cosine annealing
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=100,
            eta_min=1e-5
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',  # Update per epoch
                'frequency': 1
            }
        }

# Learning rate auto-logged!
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_loader)
```

## When to use vs alternatives

**Use PyTorch Lightning when**:
- You want clean, organized code
- You need production-ready training loops
- You switch between single GPU, multi-GPU, and TPU
- You want built-in callbacks and logging
- Your team collaborates on a standardized structure

**Key advantages**:
- **Organized**: Separates research code from engineering
- **Automatic**: DDP, FSDP, DeepSpeed with 1 line
- **Callbacks**: Modular training extensions
- **Reproducible**: Less boilerplate = fewer bugs
- **Tested**: 1M+ downloads/month, battle-tested

**Use alternatives instead**:
- **Accelerate**: Minimal changes to existing code, more flexibility
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **Raw PyTorch**: Maximum control, learning purposes
- **Keras**: TensorFlow ecosystem

## Common issues

**Issue: Loss not decreasing**

Check data and model setup:
```python
# Add to training_step
def training_step(self, batch, batch_idx):
    if batch_idx == 0:
        print(f"Batch shape: {batch[0].shape}")
        print(f"Labels: {batch[1]}")
    loss = ...
    return loss
```

**Issue: Out of memory**

Reduce batch size or use gradient accumulation:
```python
trainer = L.Trainer(
    accumulate_grad_batches=4,  # Effective batch = batch_size × 4
    precision='bf16'  # Or 'fp16'; roughly halves activation memory
)
```

**Issue: Validation not running**

Ensure you pass val_loader:
```python
# WRONG
trainer.fit(model, train_loader)

# CORRECT
trainer.fit(model, train_loader, val_loader)
```

**Issue: DDP spawns multiple processes unexpectedly**

Lightning auto-detects GPUs. Explicitly set devices:
```python
# Test on CPU first
trainer = L.Trainer(accelerator='cpu', devices=1)

# Then GPU
trainer = L.Trainer(accelerator='gpu', devices=1)
```

## Advanced topics

**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.

**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, and multi-node setup.

**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps.

## Hardware requirements

- **CPU**: Works (good for debugging)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed
- **Multi-node**: DDP, FSDP, DeepSpeed
- **TPU**: Supported (8 cores)
- **Apple MPS**: Supported

**Precision options** (see the sketch below):
- FP32 (default)
- FP16 (V100, older GPUs)
- BF16 (A100/H100, recommended)
- FP8 (H100)
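
Each option maps to a single Trainer flag; a sketch (pick one per run):

```python
trainer = L.Trainer(accelerator='gpu', precision='bf16-mixed')  # A100/H100
# precision='16-mixed'           -> FP16 with automatic loss scaling (V100)
# precision='32-true'            -> plain FP32 (the default)
# precision='transformer-engine' -> FP8 on H100 (requires transformer_engine)
```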
## Resources

- Docs: https://lightning.ai/docs/pytorch/stable/
- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+
- Version: 2.5.5+
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples
- Discord: https://discord.gg/lightning-ai
- Used by: Kaggle winners, research labs, production teams
@@ -0,0 +1,436 @@
# PyTorch Lightning Callbacks

## Overview

Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging.

## Built-In Callbacks

### 1. ModelCheckpoint

**Saves best models during training**:

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Save top 3 models based on validation loss
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,  # Also save last epoch
    verbose=True
)

trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
```

**Configuration options**:
```python
checkpoint = ModelCheckpoint(
    monitor='val_acc',  # Metric to monitor
    mode='max',  # 'max' for accuracy, 'min' for loss
    save_top_k=5,  # Keep best 5 models
    save_last=True,  # Save last epoch separately
    every_n_epochs=1,  # Save every N epochs
    save_on_train_epoch_end=False,  # Save on validation end instead
    filename='best-{epoch}-{val_acc:.3f}',  # Naming pattern
    auto_insert_metric_name=False  # Don't auto-add metric to filename
)
```

**Load checkpoint**:
```python
# Load best model
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)

# Resume training
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
```
### 2. EarlyStopping

**Stops training when metric stops improving**:

```python
from lightning.pytorch.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,  # Wait 5 epochs
    mode='min',
    min_delta=0.001,  # Minimum change to qualify as improvement
    verbose=True,
    strict=True,  # Crash if monitored metric not found
    check_on_train_epoch_end=False  # Check on validation end
)

trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Stops automatically if no improvement for 5 epochs
```

**Advanced usage**:
```python
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.0,
    verbose=True,
    mode='min',
    stopping_threshold=0.1,  # Stop if val_loss < 0.1
    divergence_threshold=5.0,  # Stop if val_loss > 5.0
    check_finite=True  # Stop on NaN/Inf
)
```

### 3. LearningRateMonitor

**Logs learning rate**:

```python
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(
    logging_interval='epoch',  # Or 'step'
    log_momentum=True  # Also log momentum
)

trainer = L.Trainer(callbacks=[lr_monitor])
# Learning rate automatically logged to TensorBoard/WandB
```

### 4. TQDMProgressBar

**Customizes progress bar**:

```python
from lightning.pytorch.callbacks import TQDMProgressBar

progress_bar = TQDMProgressBar(
    refresh_rate=10,  # Update every 10 batches
    process_position=0
)

trainer = L.Trainer(callbacks=[progress_bar])
```

### 5. GradientAccumulationScheduler

**Dynamic gradient accumulation**:

```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate fewer batches as training progresses
accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,  # Epochs 0-4: accumulate 8 batches
        5: 4,  # Epochs 5-9: accumulate 4 batches
        10: 2  # Epochs 10+: accumulate 2 batches
    }
)

trainer = L.Trainer(callbacks=[accumulator])
```

### 6. StochasticWeightAveraging (SWA)

**Averages weights for better generalization**:

```python
from lightning.pytorch.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_lrs=1e-2,  # SWA learning rate
    swa_epoch_start=0.8,  # Start at 80% of training
    annealing_epochs=10,  # Annealing period
    annealing_strategy='cos'  # 'cos' or 'linear'
)

trainer = L.Trainer(callbacks=[swa])
```
## Custom Callbacks

### Basic Custom Callback

```python
from lightning.pytorch.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} ended")

# Use it
trainer = L.Trainer(callbacks=[PrintingCallback()])
```

### Advanced Custom Callback

```python
class MetricsCallback(Callback):
    """Logs custom metrics every N batches."""

    def __init__(self, log_every_n_batches=100):
        self.log_every_n_batches = log_every_n_batches
        self.metrics = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % self.log_every_n_batches == 0:
            # Compute custom metric
            metric = self.compute_metric(outputs)
            self.metrics.append(metric)

            # Log to Lightning
            pl_module.log('custom_metric', metric)

    def compute_metric(self, outputs):
        # Your custom logic
        return outputs['loss'].item()

    def state_dict(self):
        """Save callback state in checkpoint."""
        return {'metrics': self.metrics}

    def load_state_dict(self, state_dict):
        """Restore callback state from checkpoint."""
        self.metrics = state_dict['metrics']
```

### Gradient Monitoring Callback

```python
class GradientMonitorCallback(Callback):
    """Monitor gradient norms."""

    def on_after_backward(self, trainer, pl_module):
        # Compute gradient norm
        total_norm = 0.0
        for p in pl_module.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        # Log
        pl_module.log('grad_norm', total_norm)

        # Warn if exploding
        if total_norm > 100:
            print(f"Warning: Large gradient norm: {total_norm:.2f}")
```

### Model Inspection Callback

```python
class ModelInspectionCallback(Callback):
    """Inspect model activations during training."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx == 0:  # First batch of epoch
            # Register hooks
            self.activations = {}

            def get_activation(name):
                def hook(model, input, output):
                    self.activations[name] = output.detach()
                return hook

            # Attach to specific layers
            pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
            pl_module.model.layer2.register_forward_hook(get_activation('layer2'))

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == 0:
            # Log activation statistics
            for name, activation in self.activations.items():
                mean = activation.mean().item()
                std = activation.std().item()
                pl_module.log(f'{name}_mean', mean)
                pl_module.log(f'{name}_std', std)
```
## Callback Hooks

**All available hooks**:

```python
class MyCallback(Callback):
    # Setup/Teardown
    def setup(self, trainer, pl_module, stage):
        """Called at beginning of fit/test/predict."""
        pass

    def teardown(self, trainer, pl_module, stage):
        """Called at end of fit/test/predict."""
        pass

    # Training
    def on_train_start(self, trainer, pl_module):
        pass

    def on_train_epoch_start(self, trainer, pl_module):
        pass

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        pass

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass

    # Validation
    def on_validation_start(self, trainer, pl_module):
        pass

    def on_validation_epoch_start(self, trainer, pl_module):
        pass

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        pass

    def on_validation_end(self, trainer, pl_module):
        pass

    # Test (same structure as validation)
    def on_test_start(self, trainer, pl_module):
        pass
    # ... (on_test_epoch_start, on_test_batch_start, etc.)

    # Predict
    def on_predict_start(self, trainer, pl_module):
        pass
    # ... (on_predict_epoch_start, on_predict_batch_start, etc.)

    # Backward
    def on_before_backward(self, trainer, pl_module, loss):
        pass

    def on_after_backward(self, trainer, pl_module):
        pass

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        pass

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Add data to checkpoint."""
        pass

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        """Restore data from checkpoint."""
        pass
```
## Combining Multiple Callbacks

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Create all callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
custom_callback = MyCustomCallback()

# Add all to Trainer
trainer = L.Trainer(
    callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
)

trainer.fit(model, train_loader, val_loader)
```

**Execution order**: Callbacks execute in the order they are added.

## Best Practices

### 1. Keep Callbacks Independent

**Bad** (dependent on another callback):
```python
class BadCallback(Callback):
    def on_train_end(self, trainer, pl_module):
        # Assumes ModelCheckpoint is present
        best_path = trainer.checkpoint_callback.best_model_path  # Fragile!
```

**Good** (self-contained):
```python
class GoodCallback(Callback):
    def on_train_end(self, trainer, pl_module):
        # Find checkpoint callback if present
        for callback in trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                best_path = callback.best_model_path
                break
```

### 2. Use State Dict for Persistence

```python
class StatefulCallback(Callback):
    def __init__(self):
        self.counter = 0
        self.history = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1
        self.history.append(outputs['loss'].item())

    def state_dict(self):
        """Save state."""
        return {
            'counter': self.counter,
            'history': self.history
        }

    def load_state_dict(self, state_dict):
        """Restore state."""
        self.counter = state_dict['counter']
        self.history = state_dict['history']
```

### 3. Handle Distributed Training

```python
class DistributedCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only run on main process
        if trainer.is_global_zero:
            print("This only prints once in distributed training")

        # Run on all processes
        loss = outputs['loss']
        # ... do something with loss on each GPU
```

## Resources

- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks
@@ -0,0 +1,490 @@
# PyTorch Lightning Distributed Training

## Distributed Strategies

Lightning supports multiple distributed strategies with a single parameter change.

### 1. DDP (DistributedDataParallel)

**Default strategy for multi-GPU**:

```python
# Automatic DDP on all available GPUs
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp')

# Or auto-detect
trainer = L.Trainer(accelerator='gpu', devices='auto')
```

**How DDP works** (see the sketch below):
- Replicates the model on each GPU
- Each GPU processes a different batch
- Gradients are all-reduced across GPUs
- Model weights stay synchronized
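
A conceptual sketch of the gradient averaging DDP performs after each `backward()` (Lightning and `torch.nn.parallel.DistributedDataParallel` do this for you, overlapped with the backward pass):

```python
import torch.distributed as dist

def average_gradients(model):
    # Sum each gradient across all ranks, then divide by world size
    # so every replica steps with the same mean gradient
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size
```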
**Launch**:
```bash
# Lightning handles spawning processes automatically
python train.py
```

**DDP Configuration**:
```python
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    find_unused_parameters=False,  # Set True if model has unused params
    gradient_as_bucket_view=True,  # Memory optimization
    static_graph=False,  # Set True if graph doesn't change
)

trainer = L.Trainer(strategy=strategy)
```

### 2. FSDP (Fully Sharded Data Parallel)

**For large models (7B+ parameters)**:

```python
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",  # ZeRO-3 equivalent
    activation_checkpointing=None,  # Or specify layer types
    cpu_offload=False,  # CPU offload for memory
)

trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy=strategy,
    precision='bf16'  # Recommended with FSDP
)

trainer.fit(model, train_loader)
```

**FSDP Sharding Strategies**:
```python
# FULL_SHARD (most memory efficient, equivalent to ZeRO-3)
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")

# SHARD_GRAD_OP (shards gradients and optimizer state only, equivalent to ZeRO-2)
strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")

# NO_SHARD (no sharding, like DDP)
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
```

**Auto-wrap policy** (wrap transformer blocks):
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
import functools

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block}
)

strategy = FSDPStrategy(
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing_policy={GPT2Block}  # Checkpoint these blocks
)
```
### 3. DeepSpeed

**For massive models (70B+ parameters)**:

```python
from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO-3 with CPU offload
strategy = DeepSpeedStrategy(
    stage=3,  # ZeRO-3
    offload_optimizer=True,  # CPU offload optimizer
    offload_parameters=True,  # CPU offload parameters
    cpu_checkpointing=True,  # Checkpoint activations to CPU
)

trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy=strategy,
    precision='bf16'
)

trainer.fit(model, train_loader)
```

**DeepSpeed configuration file**:
```json
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6
  },
  "bf16": {
    "enabled": true
  }
}
```

**Use config file**:
```python
strategy = DeepSpeedStrategy(config='deepspeed_config.json')
trainer = L.Trainer(strategy=strategy)
```

### 4. DDP Spawn

**Process-spawning DDP for restricted environments**:

```python
# Use when standard DDP doesn't work (e.g., Windows, Jupyter notebooks)
trainer = L.Trainer(
    accelerator='gpu',
    devices=2,
    strategy='ddp_spawn'  # Spawns new processes
)
```

**Note**: Slower than DDP due to process spawning overhead.
## Multi-Node Training

### Setup Multi-Node Cluster

**Node 0 (master)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16  # 2 nodes × 8 GPUs
export NODE_RANK=0

python train.py
```

**Node 1 (worker)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16
export NODE_RANK=1

python train.py
```

**Training script**:
```python
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,  # GPUs per node
    num_nodes=2,  # Total nodes
    strategy='ddp'
)

trainer.fit(model, train_loader)
```

### SLURM Integration

**SLURM job script**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --time=24:00:00

# Lightning auto-detects SLURM environment
srun python train.py
```

**Training script** (no changes needed):
```python
# Lightning automatically reads SLURM environment variables
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    num_nodes=4,  # From SBATCH --nodes
    strategy='ddp'
)
```

### Kubernetes (KubeFlow)

**Training script**:
```python
import os

# Lightning auto-detects Kubernetes
trainer = L.Trainer(
    accelerator='gpu',
    devices=int(os.getenv('WORLD_SIZE', 1)),
    strategy='ddp'
)
```
## Mixed Precision Training

### BF16 (A100/H100)

```python
trainer = L.Trainer(
    precision='bf16',  # Or 'bf16-mixed'
    accelerator='gpu'
)
```

**Advantages**:
- No gradient scaler needed
- Same dynamic range as FP32
- Up to 2× speedup, roughly 50% less activation memory

### FP16 (V100, older GPUs)

```python
trainer = L.Trainer(
    precision='16-mixed',  # Or just '16'
    accelerator='gpu'
)
```

**Automatic gradient scaling** is handled by Lightning.

### FP8 (H100)

```python
# Requires transformer_engine
# pip install transformer-engine[pytorch]

trainer = L.Trainer(
    precision='transformer-engine',
    accelerator='gpu'
)
```

**Benefits**: Up to 2× faster than BF16 on H100.

## Gradient Accumulation

**Simulate a larger batch size**:

```python
trainer = L.Trainer(
    accumulate_grad_batches=4,  # Accumulate 4 batches
    precision='bf16'
)

# Effective batch = batch_size × accumulate_grad_batches × num_gpus
# Example: 32 × 4 × 8 = 1024
```

**Dynamic accumulation** (epoch-based schedules go through the callback):
```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate more early in training
accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,  # Epochs 0-4: accumulate 8
        5: 4,  # Epochs 5-9: accumulate 4
        10: 2  # Epochs 10+: accumulate 2
    }
)
trainer = L.Trainer(callbacks=[accumulator])
```
## Checkpointing in Distributed

### Save Checkpoint

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Only rank 0 writes files by default
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}',
    save_top_k=3
)

trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp')
trainer.fit(model, train_loader)
```

**Manual save**:
```python
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # Training...
        loss = ...

        # Save every 1000 steps; call on all ranks (it is a collective
        # operation) -- Lightning writes the file from rank 0 only
        if batch_idx % 1000 == 0:
            self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt')

        return loss
```

### Load Checkpoint

```python
# Resume training
trainer = L.Trainer(strategy='ddp')
trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt')

# Load for inference
model = MyModel.load_from_checkpoint('checkpoints/best.ckpt')
model.eval()
```
## Strategy Comparison

| Strategy | Memory Efficiency | Speed | Use Case |
|----------|------------------|-------|----------|
| DDP | Low | Fast | Small models (<7B), single node |
| FSDP | High | Medium | Large models (7-70B) |
| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) |
| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) |
| DDP Spawn | Low | Slow | Windows, debugging |
## Best Practices

### 1. Choose Right Strategy

```python
from lightning.pytorch.strategies import FSDPStrategy, DeepSpeedStrategy

# Model size guide
if model_params < 1e9:       # <1B
    strategy = 'ddp'
elif model_params < 7e9:     # 1-7B
    strategy = DeepSpeedStrategy(stage=2)  # or plain 'ddp' if memory allows
elif model_params < 70e9:    # 7-70B
    strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
else:                        # 70B+
    strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True)

trainer = L.Trainer(strategy=strategy)
```

### 2. Avoid Sync Issues

```python
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # WRONG: This runs on all GPUs independently
        if batch_idx % 100 == 0:
            self.log_something()  # Logged 8 times on 8 GPUs!

        # CORRECT: Use is_global_zero
        if batch_idx % 100 == 0 and self.trainer.is_global_zero:
            self.log_something()  # Logged once

        loss = ...
        return loss
```

### 3. Efficient Data Loading

```python
from torch.utils.data import DataLoader, DistributedSampler

# Lightning handles DistributedSampler automatically
train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,  # 4 workers per GPU
    pin_memory=True,
    persistent_workers=True
)

# Lightning automatically wraps with DistributedSampler in DDP
trainer.fit(model, train_loader)
```

### 4. Reduce Communication Overhead

```python
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    gradient_as_bucket_view=True,  # Reduce memory copies
    static_graph=True,  # If model graph doesn't change (faster)
)

trainer = L.Trainer(strategy=strategy)
```

## Common Issues

### Issue: NCCL Timeout

**Symptom**: Training hangs with `NCCL timeout` error

**Solution 1**: Increase timeout
```bash
export NCCL_TIMEOUT=3600  # 1 hour
python train.py
```

**Solution 2**: Check network
```bash
# Test inter-node communication
nvidia-smi nvlink -s

# Verify all nodes can ping each other
ping <node-2-ip>
```

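The timeout can also be raised in code. A sketch, assuming Lightning's `DDPStrategy`, which forwards `timeout` to `torch.distributed.init_process_group`:

```python
from datetime import timedelta

import lightning as L
from lightning.pytorch.strategies import DDPStrategy

# Give collectives up to 1 hour before NCCL aborts
strategy = DDPStrategy(timeout=timedelta(hours=1))
trainer = L.Trainer(accelerator='gpu', strategy=strategy)
```
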
### Issue: OOM with FSDP

**Solution**: Enable CPU offload
```python
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    cpu_offload=True  # Offload to CPU
)
```

### Issue: Different Results with DDP

**Cause**: Different random seeds per GPU

**Solution**: Set seed in LightningModule
```python
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        L.seed_everything(42, workers=True)  # Same seed everywhere
```

### Issue: DeepSpeed Config Errors

**Solution**: Use Lightning's auto config
```python
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    # Don't specify config file, Lightning generates automatically
)
```

## Resources

- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html
- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html
- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html
@@ -0,0 +1,556 @@
# Hyperparameter Tuning with PyTorch Lightning

## Integration with Tuning Frameworks

Lightning integrates seamlessly with popular hyperparameter tuning libraries.

### 1. Ray Tune Integration

**Installation**:
```bash
pip install ray[tune]
pip install lightning
```

**Basic Ray Tune example**:

```python
import torch
import torch.nn as nn
import lightning as L
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback

class LitModel(L.LightningModule):
    def __init__(self, lr, batch_size):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        self.model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1))

    def training_step(self, batch, batch_idx):
        loss = self.model(batch).mean()
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss = self.model(batch).mean()
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

def train_fn(config):
    """Training function for Ray Tune."""
    model = LitModel(lr=config["lr"], batch_size=config["batch_size"])

    # Add callback to report metrics to Tune
    trainer = L.Trainer(
        max_epochs=10,
        callbacks=[TuneReportCallback({"loss": "val_loss"}, on="validation_end")]
    )

    trainer.fit(model, train_loader, val_loader)

# Define search space
config = {
    "lr": tune.loguniform(1e-5, 1e-1),
    "batch_size": tune.choice([16, 32, 64, 128])
}

# Run hyperparameter search
analysis = tune.run(
    train_fn,
    config=config,
    num_samples=20,  # 20 trials
    resources_per_trial={"gpu": 1}
)

# Best hyperparameters
best_config = analysis.get_best_config(metric="loss", mode="min")
print(f"Best config: {best_config}")
```

**Advanced: Population-Based Training (PBT)**:

```python
from ray.tune.schedulers import PopulationBasedTraining

# PBT scheduler
scheduler = PopulationBasedTraining(
    time_attr='training_iteration',
    metric='val_loss',
    mode='min',
    perturbation_interval=5,  # Perturb every 5 epochs
    hyperparam_mutations={
        "lr": tune.loguniform(1e-5, 1e-1),
        "batch_size": [16, 32, 64, 128]
    }
)

analysis = tune.run(
    train_fn,
    config=config,
    num_samples=8,  # Population size
    scheduler=scheduler,
    resources_per_trial={"gpu": 1}
)
```

### 2. Optuna Integration

**Installation**:
```bash
pip install optuna
pip install optuna-integration
```

**Optuna example**:

```python
import optuna
from optuna.integration import PyTorchLightningPruningCallback

def objective(trial):
    # Suggest hyperparameters
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    n_layers = trial.suggest_int('n_layers', 1, 3)
    hidden_size = trial.suggest_int('hidden_size', 64, 512, step=64)

    # Create model
    model = LitModel(lr=lr, n_layers=n_layers, hidden_size=hidden_size)

    # Pruning callback (early stopping for bad trials)
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")

    trainer = L.Trainer(
        max_epochs=20,
        callbacks=[pruning_callback],
        enable_progress_bar=False,
        logger=False
    )

    trainer.fit(model, train_loader, val_loader)

    return trainer.callback_metrics["val_loss"].item()

# Create study
study = optuna.create_study(
    direction='minimize',
    pruner=optuna.pruners.MedianPruner()  # Prune bad trials early
)

# Optimize
study.optimize(objective, n_trials=50, timeout=3600)

# Best params
print(f"Best trial: {study.best_trial.params}")
print(f"Best value: {study.best_value}")

# Visualization
optuna.visualization.plot_optimization_history(study).show()
optuna.visualization.plot_param_importances(study).show()
```

**Optuna with distributed training**:

```python
import optuna

# Shared database for distributed optimization
storage = optuna.storages.RDBStorage(
    url='postgresql://user:pass@localhost/optuna'
)

study = optuna.create_study(
    study_name='distributed_study',
    storage=storage,
    load_if_exists=True,
    direction='minimize'
)

# Run on multiple machines
study.optimize(objective, n_trials=50)
```

### 3. Weights & Biases (WandB) Sweeps

**Installation**:
```bash
pip install wandb
```

**WandB sweep config** (`sweep.yaml`):
```yaml
program: train.py
method: bayes
metric:
  name: val_loss
  goal: minimize
parameters:
  lr:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.1
  batch_size:
    values: [16, 32, 64, 128]
  optimizer:
    values: ['adam', 'sgd', 'adamw']
  dropout:
    distribution: uniform
    min: 0.0
    max: 0.5
```

**Training script** (`train.py`):
```python
import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger

def train():
    # Initialize wandb
    wandb.init()
    config = wandb.config

    # Create model with sweep params
    model = LitModel(
        lr=config.lr,
        batch_size=config.batch_size,
        optimizer=config.optimizer,
        dropout=config.dropout
    )

    # WandB logger
    wandb_logger = WandbLogger(project='hyperparameter-sweep')

    trainer = L.Trainer(
        max_epochs=20,
        logger=wandb_logger
    )

    trainer.fit(model, train_loader, val_loader)

if __name__ == '__main__':
    train()
```

**Launch sweep**:
```bash
# Initialize sweep
wandb sweep sweep.yaml
# Output: wandb: Created sweep with ID: abc123

# Run agent (can run on multiple machines)
wandb agent your-entity/your-project/abc123
```

### 4. Hyperopt Integration

**Installation**:
```bash
pip install hyperopt
```

**Hyperopt example**:

```python
import numpy as np
from hyperopt import hp, fmin, tpe, Trials

def objective(params):
    model = LitModel(
        lr=params['lr'],
        batch_size=int(params['batch_size']),
        hidden_size=int(params['hidden_size'])
    )

    trainer = L.Trainer(
        max_epochs=10,
        enable_progress_bar=False,
        logger=False
    )

    trainer.fit(model, train_loader, val_loader)

    # Return loss (minimize)
    return trainer.callback_metrics["val_loss"].item()

# Define search space
space = {
    'lr': hp.loguniform('lr', np.log(1e-5), np.log(1e-1)),
    'batch_size': hp.quniform('batch_size', 16, 128, 16),
    'hidden_size': hp.quniform('hidden_size', 64, 512, 64)
}

# Optimize
trials = Trials()
best = fmin(
    fn=objective,
    space=space,
    algo=tpe.suggest,  # Tree-structured Parzen Estimator
    max_evals=50,
    trials=trials
)

print(f"Best hyperparameters: {best}")
```

## Built-In Lightning Tuning

### Auto Learning Rate Finder

```python
class LitModel(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = nn.Linear(10, 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        loss = self.model(batch).mean()
        return loss

# Find optimal learning rate with the Tuner API
# (in Lightning 1.x this was Trainer(auto_lr_find=True) + trainer.tune())
from lightning.pytorch.tuner import Tuner

model = LitModel()
trainer = L.Trainer()
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_loader)

# Plot results
fig = lr_finder.plot(suggest=True)
fig.show()

# Get suggested LR
suggested_lr = lr_finder.suggestion()
print(f"Suggested LR: {suggested_lr}")

# Update model
model.lr = suggested_lr

# Train with optimal LR
trainer.fit(model, train_loader)
```

### Auto Batch Size Finder

```python
class LitModel(L.LightningModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.model = nn.Linear(10, 1)

    def train_dataloader(self):
        return DataLoader(dataset, batch_size=self.batch_size)

# Find optimal batch size with the Tuner API
# (in Lightning 1.x: Trainer(auto_scale_batch_size='binsearch') + trainer.tune())
model = LitModel()
trainer = L.Trainer()
tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode='binsearch')

print(f"Optimal batch size: {model.batch_size}")

# Train with optimal batch size
trainer.fit(model, train_loader)
```

## Advanced Tuning Strategies

### 1. Multi-Fidelity Optimization (Successive Halving)

```python
from ray.tune.schedulers import ASHAScheduler

# ASHA: Asynchronous Successive Halving Algorithm
scheduler = ASHAScheduler(
    max_t=100,           # Max epochs
    grace_period=10,     # Min epochs before stopping
    reduction_factor=2   # Halve resources each round
)

analysis = tune.run(
    train_fn,
    config=config,
    num_samples=64,
    scheduler=scheduler,
    resources_per_trial={"gpu": 1}
)
```

**How it works**:
- Start 64 trials
- After 10 epochs, stop bottom 50% (32 trials remain)
- After 20 epochs, stop bottom 50% (16 trials remain)
- After 40 epochs, stop bottom 50% (8 trials remain)
- After 80 epochs, stop bottom 50% (4 trials remain)
- Run remaining 4 trials to completion (100 epochs), as sketched below

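The halving arithmetic can be checked with a few lines of plain Python (illustrative only, not Ray code):

```python
# Successive halving: 64 trials, grace period 10, reduction factor 2
trials = 64
for epoch in (10, 20, 40, 80):
    trials //= 2  # bottom 50% stopped at each rung
    print(f"after epoch {epoch}: {trials} trials remain")
print(f"{trials} trials run to the full 100 epochs")
```
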
### 2. Bayesian Optimization

```python
from ray.tune.search.bayesopt import BayesOptSearch

# Note: BayesOptSearch only handles continuous parameters, so
# categorical entries like tune.choice must be dropped or discretized
search = BayesOptSearch(
    metric="val_loss",
    mode="min"
)

analysis = tune.run(
    train_fn,
    config=config,
    num_samples=50,
    search_alg=search,
    resources_per_trial={"gpu": 1}
)
```

### 3. Grid Search

```python
from ray import tune

# Exhaustive grid search
config = {
    "lr": tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
    "batch_size": tune.grid_search([16, 32, 64, 128]),
    "optimizer": tune.grid_search(['adam', 'sgd', 'adamw'])
}

# Total trials: 4 × 4 × 3 = 48
analysis = tune.run(train_fn, config=config)
```

### 4. Random Search

```python
config = {
    "lr": tune.loguniform(1e-5, 1e-1),
    "batch_size": tune.choice([16, 32, 64, 128]),
    "dropout": tune.uniform(0.0, 0.5),
    "hidden_size": tune.randint(64, 512)
}

# Random sampling
analysis = tune.run(
    train_fn,
    config=config,
    num_samples=100  # 100 random samples
)
```

## Best Practices

### 1. Start Simple

```python
# Phase 1: Coarse search (fast) - cap each trial at 5 epochs via `stop`
coarse_config = {
    "lr": tune.loguniform(1e-5, 1e-1),
    "batch_size": tune.choice([32, 64])
}
coarse_analysis = tune.run(train_fn, config=coarse_config, num_samples=10,
                           stop={"training_iteration": 5})

# Phase 2: Fine-tune around best (slow)
best_lr = coarse_analysis.get_best_config(metric="loss", mode="min")["lr"]
fine_config = {
    "lr": tune.uniform(best_lr * 0.5, best_lr * 2),
    "batch_size": tune.choice([16, 32, 64, 128])
}
fine_analysis = tune.run(train_fn, config=fine_config, num_samples=20,
                         stop={"training_iteration": 20})
```

### 2. Use Checkpointing

```python
import os
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

def train_fn(config, checkpoint_dir=None):
    model = LitModel(lr=config["lr"])

    trainer = L.Trainer(
        max_epochs=100,
        callbacks=[
            TuneReportCheckpointCallback(
                metrics={"loss": "val_loss"},
                filename="checkpoint",
                on="validation_end"
            )
        ]
    )

    # Resume from checkpoint if exists
    ckpt_path = None
    if checkpoint_dir:
        ckpt_path = os.path.join(checkpoint_dir, "checkpoint")

    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
```

### 3. Monitor Resource Usage

```python
import GPUtil

def train_fn(config):
    # Before training
    GPUs = GPUtil.getGPUs()
    print(f"GPU memory before: {GPUs[0].memoryUsed} MB")

    # Train
    model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, train_loader)

    # After training
    GPUs = GPUtil.getGPUs()
    print(f"GPU memory after: {GPUs[0].memoryUsed} MB")
```

## Common Issues

### Issue: Trials Running Out of Memory

**Solution**: Reduce concurrent trials or batch size
```python
analysis = tune.run(
    train_fn,
    config=config,
    resources_per_trial={"gpu": 0.5},  # 2 trials per GPU
    max_concurrent_trials=2  # Limit concurrent trials
)
```

### Issue: Slow Hyperparameter Search

**Solution**: Use early stopping scheduler
```python
from ray.tune.schedulers import ASHAScheduler

scheduler = ASHAScheduler(
    max_t=100,
    grace_period=5,  # Stop bad trials after 5 epochs
    reduction_factor=3
)
```

### Issue: Can't Reproduce Best Trial

**Solution**: Set seeds in training function
```python
def train_fn(config):
    L.seed_everything(42, workers=True)
    # Rest of training...
```

## Resources

- Ray Tune + Lightning: https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html
- Optuna: https://optuna.readthedocs.io/
- WandB Sweeps: https://docs.wandb.ai/guides/sweeps
- Lightning Tuner: https://lightning.ai/docs/pytorch/stable/tuning.html
222
protected/skills-backup/mlops/training/simpo/SKILL.md
Normal file
@@ -0,0 +1,222 @@
---
name: simpo-training
description: Simple Preference Optimization for LLM alignment. Reference-free alternative to DPO with better performance (+6.4 points on AlpacaEval 2.0). No reference model needed, more efficient than DPO. Use for preference alignment when you want simpler, faster training than DPO/PPO.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [torch, transformers, datasets, trl, accelerate]
metadata:
  hermes:
    tags: [Post-Training, SimPO, Preference Optimization, Alignment, DPO Alternative, Reference-Free, LLM Alignment, Efficient Training]

---

# SimPO - Simple Preference Optimization

## Quick start

SimPO is a reference-free preference optimization method that outperforms DPO without needing a reference model.

**Installation**:
```bash
# Create environment
conda create -n simpo python=3.10 && conda activate simpo

# Install PyTorch 2.2.2
# Visit: https://pytorch.org/get-started/locally/

# Install alignment-handbook
git clone https://github.com/huggingface/alignment-handbook.git
cd alignment-handbook
python -m pip install .

# Install Flash Attention 2
python -m pip install flash-attn --no-build-isolation
```

**Training** (Mistral 7B):
```bash
ACCELERATE_LOG_LEVEL=info accelerate launch \
  --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py \
  training_configs/mistral-7b-base-simpo.yaml
```

## Common workflows

### Workflow 1: Train from base model (Mistral 7B)

**Config** (`mistral-7b-base-simpo.yaml`):
```yaml
# Model
model_name_or_path: mistralai/Mistral-7B-v0.1
torch_dtype: bfloat16

# Dataset
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
  - train_prefs
  - test_prefs

# SimPO hyperparameters
beta: 2.0               # Reward scaling (2.0-10.0)
gamma_beta_ratio: 0.5   # Target margin (0-1)
loss_type: sigmoid      # sigmoid or hinge
sft_weight: 0.0         # Optional SFT regularization

# Training
learning_rate: 5e-7     # Critical: 3e-7 to 1e-6
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 8

# Output
output_dir: ./outputs/mistral-7b-simpo
```

**Launch training**:
```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
```

### Workflow 2: Fine-tune instruct model (Llama 3 8B)

**Config** (`llama3-8b-instruct-simpo.yaml`):
```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

dataset_mixer:
  argilla/ultrafeedback-binarized-preferences-cleaned: 1.0

beta: 2.5
gamma_beta_ratio: 0.5
learning_rate: 5e-7
sft_weight: 0.1   # Add SFT loss to preserve capabilities

num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
output_dir: ./outputs/llama3-8b-simpo
```

**Launch**:
```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml
```

### Workflow 3: Reasoning-intensive tasks (lower LR)

**For math/code tasks**:
```yaml
model_name_or_path: deepseek-ai/deepseek-math-7b-base

dataset_mixer:
  argilla/distilabel-math-preference-dpo: 1.0

beta: 5.0               # Higher for stronger signal
gamma_beta_ratio: 0.7   # Larger margin
learning_rate: 3e-7     # Lower LR for reasoning
sft_weight: 0.0

num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
```

## When to use vs alternatives

**Use SimPO when**:
- You want simpler training than DPO (no reference model)
- You have preference data (chosen/rejected pairs)
- You need better performance than DPO
- Compute resources are limited
- Single-node training is sufficient

**Algorithm selection**:
- **SimPO**: Simplest, best performance, no reference model
- **DPO**: Need reference model baseline, more conservative
- **PPO**: Maximum control, need reward model, complex setup
- **GRPO**: Memory-efficient RL, no critic

**Use alternatives instead**:
- **OpenRLHF**: Multi-node distributed training, PPO/GRPO
- **TRL**: Need multiple methods in one framework
- **DPO**: Established baseline comparison

## Common issues

**Issue: Loss divergence**

Reduce learning rate:
```yaml
learning_rate: 3e-7  # Reduce from 5e-7
```

Reduce beta:
```yaml
beta: 1.0  # Reduce from 2.0
```

**Issue: Model forgets capabilities**

Add SFT regularization:
```yaml
sft_weight: 0.1  # Add SFT loss component
```

**Issue: Poor preference separation**

Increase beta and margin:
```yaml
beta: 5.0              # Increase from 2.0
gamma_beta_ratio: 0.8  # Increase from 0.5
```

**Issue: OOM during training**

Reduce batch size:
```yaml
per_device_train_batch_size: 1
gradient_accumulation_steps: 16  # Maintain effective batch
```

Enable gradient checkpointing:
```yaml
gradient_checkpointing: true
```

## Advanced topics

**Loss functions**: See [references/loss-functions.md](references/loss-functions.md) for sigmoid vs hinge loss, mathematical formulations, and when to use each.

**Hyperparameter tuning**: See [references/hyperparameters.md](references/hyperparameters.md) for beta, gamma, learning rate selection guide, and model-size-specific recommendations.

**Dataset preparation**: See [references/datasets.md](references/datasets.md) for preference data formats, quality filtering, and custom dataset creation.

## Hardware requirements

- **GPU**: NVIDIA A100/H100 recommended
- **VRAM**:
  - 7B model: 1× A100 40GB (DeepSpeed ZeRO-3)
  - 8B model: 2× A100 40GB
  - 70B model: 8× A100 80GB
- **Single-node**: DeepSpeed ZeRO-3 sufficient
- **Mixed precision**: BF16 recommended

**Memory optimization**:
- DeepSpeed ZeRO-3 (default config)
- Gradient checkpointing
- Flash Attention 2

## Resources

- Paper: https://arxiv.org/abs/2405.14734 (NeurIPS 2024)
- GitHub: https://github.com/princeton-nlp/SimPO
- Models: https://huggingface.co/princeton-nlp
- Alignment Handbook: https://github.com/huggingface/alignment-handbook

@@ -0,0 +1,478 @@
# Datasets

Complete guide to preference datasets for SimPO training.

## Dataset Format

### Required Fields

Preference datasets must contain:
```json
{
  "prompt": "User question or instruction",
  "chosen": "Better/preferred response",
  "rejected": "Worse/rejected response"
}
```

**Alternative field names** (auto-detected):
- `prompt` → `question`, `instruction`, `input`
- `chosen` → `response_chosen`, `winner`, `preferred`
- `rejected` → `response_rejected`, `loser`

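If a dataset uses the alternative names, columns can be normalized up front. A minimal sketch with `datasets` (the dataset name and rename map are illustrative; adjust to your data):

```python
from datasets import load_dataset

dataset = load_dataset("username/my-preferences", split="train")  # placeholder

rename_map = {"question": "prompt", "response_chosen": "chosen", "response_rejected": "rejected"}
for old, new in rename_map.items():
    if old in dataset.column_names:
        dataset = dataset.rename_column(old, new)
```
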
### Example Entry

```json
{
  "prompt": "Explain quantum computing in simple terms.",
  "chosen": "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously through superposition. This allows quantum computers to process many possibilities at once, making them potentially much faster than classical computers for specific tasks like cryptography and optimization.",
  "rejected": "It's like regular computing but quantum."
}
```

## Popular Datasets

### 1. UltraFeedback (Recommended)

**HuggingFaceH4/ultrafeedback_binarized**:
- **Size**: 60K preference pairs
- **Quality**: High (GPT-4 annotations)
- **Domain**: General instruction following
- **Format**: Clean, ready-to-use

**Config**:
```yaml
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
  - train_prefs
  - test_prefs
```

### 2. Argilla UltraFeedback (Cleaned)

**argilla/ultrafeedback-binarized-preferences-cleaned**:
- **Size**: 50K pairs (filtered)
- **Quality**: Very high (deduped, cleaned)
- **Domain**: General
- **Format**: Clean

**Config**:
```yaml
dataset_mixer:
  argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
```

### 3. Distilabel Math

**argilla/distilabel-math-preference-dpo**:
- **Size**: 30K pairs
- **Quality**: High (GSM8K, MATH)
- **Domain**: Math reasoning
- **Format**: Math-specific

**Config**:
```yaml
dataset_mixer:
  argilla/distilabel-math-preference-dpo: 1.0
```

### 4. HelpSteer

**nvidia/HelpSteer**:
- **Size**: 38K samples
- **Quality**: High (human ratings)
- **Domain**: Helpfulness alignment
- **Format**: Multi-attribute ratings

**Config**:
```yaml
dataset_mixer:
  nvidia/HelpSteer: 1.0
```

### 5. Anthropic HH-RLHF

**Anthropic/hh-rlhf**:
- **Size**: 161K samples
- **Quality**: High (human preferences)
- **Domain**: Harmless + helpful
- **Format**: Conversational

**Config**:
```yaml
dataset_mixer:
  Anthropic/hh-rlhf: 1.0
```

## Dataset Mixing

### Multiple Datasets

**Equal mix**:
```yaml
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 0.5
  Anthropic/hh-rlhf: 0.5
```

**Weighted mix**:
```yaml
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 0.7
  argilla/distilabel-math-preference-dpo: 0.2
  nvidia/HelpSteer: 0.1
```

**Domain-specific emphasis**:
```yaml
# 80% general + 20% math
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 0.8
  argilla/distilabel-math-preference-dpo: 0.2
```

## Data Quality

### Quality Indicators

**Good preference data**:
- ✅ Clear quality difference between chosen/rejected
- ✅ Diverse prompts
- ✅ Minimal noise/annotation errors
- ✅ Appropriate difficulty level

**Poor preference data**:
- ❌ Ambiguous preferences
- ❌ Repetitive prompts
- ❌ Annotation noise
- ❌ Too easy/hard prompts

### Quality Filtering

**Filter by length difference**:
```python
def filter_by_length(example):
    chosen_len = len(example['chosen'].split())
    rejected_len = len(example['rejected'].split())
    # Reject if chosen is much shorter (potential low-effort)
    return chosen_len >= rejected_len * 0.5

dataset = dataset.filter(filter_by_length)
```

**Filter by diversity**:
```python
seen_prompts = set()

def filter_duplicates(example):
    prompt = example['prompt']
    if prompt in seen_prompts:
        return False
    seen_prompts.add(prompt)
    return True

dataset = dataset.filter(filter_duplicates)
```

## Custom Dataset Creation

### Format 1: JSON Lines

**File** (`preferences.jsonl`):
```jsonl
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "It's a snake."}
{"prompt": "Explain AI.", "chosen": "AI refers to systems that can...", "rejected": "It's computers that think."}
```

**Load**:
```yaml
dataset_mixer:
  json:
    data_files: preferences.jsonl
```

### Format 2: HuggingFace Dataset

**Create from dict**:
```python
from datasets import Dataset

data = {
    "prompt": ["What is Python?", "Explain AI."],
    "chosen": ["Python is...", "AI refers to..."],
    "rejected": ["It's a snake.", "It's computers..."]
}

dataset = Dataset.from_dict(data)
dataset.push_to_hub("username/my-preferences")
```

**Use in config**:
```yaml
dataset_mixer:
  username/my-preferences: 1.0
```

### Format 3: ChatML

**For conversational data**:
```json
{
  "prompt": [
    {"role": "user", "content": "What is quantum computing?"}
  ],
  "chosen": [
    {"role": "assistant", "content": "Quantum computing uses qubits..."}
  ],
  "rejected": [
    {"role": "assistant", "content": "It's like regular computing but quantum."}
  ]
}
```

**Apply chat template**:
```yaml
dataset_text_field: null  # Will apply chat template
```

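To see what a chat template actually produces, the tokenizer can be applied directly. An illustrative sketch with `transformers` (the model name is just an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Flatten ChatML-style messages into the model's prompt format
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is quantum computing?"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(text)
```
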
## Synthetic Data Generation

### Using GPT-4

**Prompt template**:
```
Given the following question:
{prompt}

Generate two responses:
1. A high-quality, detailed response (chosen)
2. A low-quality, brief response (rejected)

Format as JSON with "chosen" and "rejected" fields.
```

**Example code** (uses the legacy `openai<1.0` API):
```python
import json

import openai

def generate_pair(prompt):
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[{
            "role": "user",
            "content": f"Given: {prompt}\n\nGenerate chosen/rejected pair in JSON."
        }]
    )
    return json.loads(response.choices[0].message.content)

# Generate dataset
prompts = load_prompts()
dataset = [generate_pair(p) for p in prompts]
```

### Using Local Model

**With vLLM**:
```python
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct")

def generate_variations(prompt):
    # Generate multiple completions
    outputs = llm.generate(
        [prompt] * 4,
        SamplingParams(
            temperature=0.8,
            top_p=0.9,
            max_tokens=512
        )
    )

    # Select best/worst (here by length, a crude heuristic)
    chosen = max(outputs, key=lambda x: len(x.outputs[0].text))
    rejected = min(outputs, key=lambda x: len(x.outputs[0].text))

    return {
        "prompt": prompt,
        "chosen": chosen.outputs[0].text,
        "rejected": rejected.outputs[0].text
    }
```

## Data Preprocessing

### Truncation

**Limit sequence length**:
```yaml
max_prompt_length: 512
max_completion_length: 512
max_length: 1024  # Total
```

**Implementation** (assumes a `tokenizer` is already loaded):
```python
def truncate_example(example):
    tokenizer.truncation_side = "left"  # For prompts
    prompt_tokens = tokenizer(
        example['prompt'],
        max_length=512,
        truncation=True
    )

    tokenizer.truncation_side = "right"  # For completions
    chosen_tokens = tokenizer(
        example['chosen'],
        max_length=512,
        truncation=True
    )

    return {
        "prompt": tokenizer.decode(prompt_tokens['input_ids']),
        "chosen": tokenizer.decode(chosen_tokens['input_ids'])
    }

dataset = dataset.map(truncate_example)
```

### Deduplication

**Remove exact duplicates** (note: `Dataset.unique('prompt')` only returns the list of distinct values, it does not drop rows):
```python
import pandas as pd
from datasets import Dataset

df = dataset.to_pandas().drop_duplicates(subset='prompt')
dataset = Dataset.from_pandas(df, preserve_index=False)
```

**Remove near-duplicates** (MinHash):
```python
from datasets import Dataset
from datasketch import MinHash, MinHashLSH

def deduplicate_lsh(dataset, threshold=0.8):
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    seen = []

    for i, example in enumerate(dataset):
        m = MinHash(num_perm=128)
        for word in example['prompt'].split():
            m.update(word.encode('utf8'))

        if not lsh.query(m):
            lsh.insert(i, m)
            seen.append(example)

    return Dataset.from_list(seen)

dataset = deduplicate_lsh(dataset)
```

## Data Augmentation

### Paraphrasing Prompts

(`map` must be batched to return more rows than it receives; `paraphrase_model` is a placeholder for any paraphrasing model)
```python
def paraphrase_batch(batch):
    paraphrased = [paraphrase_model(p) for p in batch['prompt']]
    return {
        "prompt": batch['prompt'] + paraphrased,  # originals + paraphrases
        "chosen": batch['chosen'] * 2,
        "rejected": batch['rejected'] * 2,
    }

dataset = dataset.map(paraphrase_batch, batched=True)
```

### Difficulty Balancing

**Mix easy/medium/hard**:
```python
from datasets import concatenate_datasets

def categorize_difficulty(example):
    prompt_len = len(example['prompt'].split())
    if prompt_len < 20:
        return "easy"
    elif prompt_len < 50:
        return "medium"
    else:
        return "hard"

dataset = dataset.map(lambda x: {"difficulty": categorize_difficulty(x)})

# Sample balanced dataset
easy = dataset.filter(lambda x: x['difficulty'] == 'easy').shuffle().select(range(1000))
medium = dataset.filter(lambda x: x['difficulty'] == 'medium').shuffle().select(range(1000))
hard = dataset.filter(lambda x: x['difficulty'] == 'hard').shuffle().select(range(1000))

balanced = concatenate_datasets([easy, medium, hard]).shuffle()
```

## Dataset Statistics

### Compute Stats

```python
import numpy as np

def compute_stats(dataset):
    prompt_lens = [len(x['prompt'].split()) for x in dataset]
    chosen_lens = [len(x['chosen'].split()) for x in dataset]
    rejected_lens = [len(x['rejected'].split()) for x in dataset]

    print(f"Dataset size: {len(dataset)}")
    print(f"Avg prompt length: {np.mean(prompt_lens):.1f} words")
    print(f"Avg chosen length: {np.mean(chosen_lens):.1f} words")
    print(f"Avg rejected length: {np.mean(rejected_lens):.1f} words")
    print(f"Chosen > Rejected: {sum(c > r for c, r in zip(chosen_lens, rejected_lens)) / len(dataset):.1%}")

compute_stats(dataset)
```

**Expected output**:
```
Dataset size: 50000
Avg prompt length: 45.2 words
Avg chosen length: 180.5 words
Avg rejected length: 120.3 words
Chosen > Rejected: 85.2%
```

## Best Practices

### 1. Data Quality Over Quantity

- **Prefer**: 10K high-quality pairs
- **Over**: 100K noisy pairs

### 2. Clear Preference Signals

- Chosen should be noticeably better
- Avoid marginal differences
- Remove ambiguous pairs

### 3. Domain Matching

- Match dataset domain to target use case
- Mix datasets for broader coverage
- Include safety-filtered data

### 4. Validate Before Training

```python
# Sample 10 random examples
samples = dataset.shuffle().select(range(10))

for ex in samples:
    print(f"Prompt: {ex['prompt']}")
    print(f"Chosen: {ex['chosen'][:100]}...")
    print(f"Rejected: {ex['rejected'][:100]}...")
    print(f"Preference clear: {'✓' if len(ex['chosen']) > len(ex['rejected']) else '?'}")
    print()
```

## References

- HuggingFace Datasets: https://huggingface.co/datasets
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
- UltraFeedback: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized

@@ -0,0 +1,452 @@
# Hyperparameters

Complete guide to SimPO hyperparameter selection and tuning.

## Overview

Key hyperparameters in SimPO:
1. **Learning Rate** - Most critical
2. **Beta (β)** - Reward scaling
3. **Gamma-Beta Ratio (γ/β)** - Target margin
4. **SFT Weight** - Regularization strength

## Learning Rate

### Recommended Ranges

**By model size**:

| Model Size | Learning Rate | Notes |
|------------|---------------|-------|
| 1B-3B | 5e-7 to 1e-6 | Higher end safe |
| 7B-8B | 3e-7 to 5e-7 | **Standard** |
| 13B-30B | 1e-7 to 3e-7 | Lower for stability |
| 70B+ | 5e-8 to 1e-7 | Very conservative |

**By task type**:

| Task | Learning Rate | Reason |
|------|---------------|--------|
| General chat | 5e-7 | Standard |
| Code generation | 3e-7 | **Precise reasoning** |
| Math reasoning | 3e-7 | **Careful optimization** |
| Creative writing | 1e-6 | More aggressive OK |

### Why Learning Rate Matters

**Too high** (> 1e-6 for 7B):
- Loss divergence
- Catastrophic forgetting
- Unstable training

**Too low** (< 1e-7 for 7B):
- Very slow convergence
- May not finish in time
- Undertraining

**Optimal** (3e-7 to 5e-7 for 7B):
- Stable convergence
- Good final performance
- Efficient training

### Config Examples

**Mistral 7B (general)**:
```yaml
learning_rate: 5e-7
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: cosine
```

**Llama 3 8B (reasoning)**:
```yaml
learning_rate: 3e-7
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: cosine
```

**Gemma 2 9B (creative)**:
```yaml
learning_rate: 1e-6
num_train_epochs: 1
warmup_ratio: 0.1
lr_scheduler_type: linear
```

## Beta (β)

### Recommended Values

**Range**: 2.0 to 10.0 (much higher than DPO's 0.01-0.1)

**By preference strength**:

| Beta | Preference Strength | Use Case |
|------|-------------------|----------|
| 1.0-2.0 | Weak | Subtle preferences |
| 2.0-5.0 | **Standard** | General alignment |
| 5.0-10.0 | Strong | Clear preferences |

**Default**: 2.0 to 2.5

### Why Beta Matters

**Low beta** (< 2.0):
- Weak reward signal
- Slow preference learning
- May underfit

**High beta** (> 10.0):
- Very strong reward signal
- Risk of overfitting
- May ignore weak preferences

**Optimal** (2.0-5.0):
- Balanced reward scaling
- Stable training
- Good generalization

### Interaction with Gamma

**Beta and gamma together**:
```
Target margin in reward space = gamma
Target margin in logit space = gamma / beta
```

**Example**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.5
# Effective gamma = 2.0 * 0.5 = 1.0
```

### Config Examples

**Weak preferences**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.3  # Small margin
```

**Standard**:
```yaml
beta: 2.5
gamma_beta_ratio: 0.5  # Default
```

**Strong preferences**:
```yaml
beta: 5.0
gamma_beta_ratio: 0.7  # Larger margin
```

## Gamma-Beta Ratio (γ/β)

### Recommended Values

**Range**: 0.0 to 1.0

**By scenario**:

| Ratio | Margin | Use Case |
|-------|--------|----------|
| 0.0-0.3 | Small | Weak preference data |
| 0.4-0.6 | **Standard** | General use |
| 0.7-1.0 | Large | Very clear preferences |

**Default**: 0.5

### Why Gamma Matters

**Low gamma** (< 0.3):
- Small target margin
- Less aggressive alignment
- More conservative

**High gamma** (> 0.7):
- Large target margin
- Stronger alignment
- More aggressive

**Optimal** (0.4-0.6):
- Balanced margin
- Stable training
- Good alignment

### Mathematical Meaning

**In loss function**:
```python
logits = pi_logratios - gamma_beta_ratio
loss = -log(sigmoid(beta * logits))
```

**Interpretation**:
- gamma_beta_ratio shifts the decision boundary
- Higher ratio = requires larger log prob difference
- Controls how "clear" preferences must be

### Config Examples

**Noisy preferences**:
```yaml
gamma_beta_ratio: 0.3  # Smaller margin, more tolerant
```

**Standard**:
```yaml
gamma_beta_ratio: 0.5  # Default
```

**High-quality preferences**:
```yaml
gamma_beta_ratio: 0.8  # Larger margin, stricter
```

## SFT Weight

### Recommended Values

**Range**: 0.0 to 1.0

**By model type**:

| Model Type | SFT Weight | Reason |
|------------|-----------|--------|
| Base model | 0.0 | No prior capabilities |
| **Instruct model** | 0.05-0.1 | Preserve instruction following |
| Chat model | 0.1-0.2 | Preserve conversational skills |

**Default**: 0.0 (no SFT regularization)

### Why SFT Weight Matters

**Zero SFT** (0.0):
- Pure preference optimization
- May forget capabilities
- Standard for base models

**Low SFT** (0.05-0.1):
- Balanced approach
- **Recommended for instruct models**
- Slight capability preservation

**High SFT** (> 0.2):
- Strong capability preservation
- Weaker preference alignment
- May reduce alignment gains

### Trade-off

```
Total Loss = SimPO Loss + (sft_weight * SFT Loss)
```

**Example**:
```yaml
sft_weight: 0.1
# 90% preference optimization + 10% capability preservation
```

### Config Examples

**Base model (no SFT)**:
```yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
sft_weight: 0.0
```

**Instruct model (light SFT)**:
```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
sft_weight: 0.1
```

**Chat model (moderate SFT)**:
```yaml
model_name_or_path: HuggingFaceH4/zephyr-7b-beta
sft_weight: 0.2
```

## Model-Size-Specific Recommendations

### 7B Models (Mistral, Llama 3)

**Standard config**:
```yaml
learning_rate: 5e-7
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.0   # 0.1 if instruct model
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
```

### 8B-13B Models

**Standard config**:
```yaml
learning_rate: 3e-7
beta: 2.5
gamma_beta_ratio: 0.5
sft_weight: 0.1   # If instruct
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
```

### 70B Models

**Standard config**:
```yaml
learning_rate: 1e-7
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.05
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
```

## Batch Size & Gradient Accumulation

### Effective Batch Size

```
Effective Batch Size = per_device_batch_size * num_gpus * grad_accum_steps
```

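For quick sanity checks, the arithmetic can be wrapped in a one-liner (illustrative only):

```python
def effective_batch(per_device: int, num_gpus: int, grad_accum: int) -> int:
    return per_device * num_gpus * grad_accum

# Matches the 8-GPU example below: 2 * 8 * 8 = 128
assert effective_batch(2, 8, 8) == 128
```
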
**Recommended effective batch sizes**:
- 7B: 128-256
- 13B: 64-128
- 70B: 32-64

### Config Examples

**Single GPU (A100 40GB)**:
```yaml
per_device_train_batch_size: 1
gradient_accumulation_steps: 128  # Effective batch = 128
```

**4 GPUs (A100 40GB)**:
```yaml
per_device_train_batch_size: 2
gradient_accumulation_steps: 16  # Effective batch = 2*4*16 = 128
```

**8 GPUs (A100 80GB)**:
```yaml
per_device_train_batch_size: 2
gradient_accumulation_steps: 8  # Effective batch = 2*8*8 = 128
```

## Loss Type

### Sigmoid vs Hinge

**Sigmoid** (default, recommended):
```yaml
loss_type: sigmoid
label_smoothing: 0.0
```

**Hinge** (experimental):
```yaml
loss_type: hinge
# No label smoothing for hinge
```

**When to use hinge**:
- Margin-based tasks
- SVM-style optimization
- Experimental purposes

**Generally**: Stick with sigmoid

## Tuning Guide

### Step 1: Start with Defaults

```yaml
learning_rate: 5e-7   # For 7B
beta: 2.0
gamma_beta_ratio: 0.5
sft_weight: 0.0       # 0.1 if instruct
loss_type: sigmoid
```

### Step 2: Monitor Training

**Check every 100 steps**:
- Loss curve (should decrease smoothly)
- Reward margin (should increase)
- Chosen/rejected logps (should separate)

### Step 3: Adjust if Needed

**If loss diverges**:
```yaml
learning_rate: 3e-7  # Reduce from 5e-7
beta: 1.0            # Reduce from 2.0
```

**If loss plateaus early**:
```yaml
learning_rate: 1e-6  # Increase from 5e-7
beta: 5.0            # Increase from 2.0
```

**If model forgets**:
```yaml
sft_weight: 0.2  # Increase from 0.0
```

## Complete Example Configs

### Mistral 7B Base (Standard)

```yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0

learning_rate: 5e-7
beta: 2.0
gamma_beta_ratio: 0.5
loss_type: sigmoid
sft_weight: 0.0

num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
warmup_ratio: 0.1
lr_scheduler_type: cosine

bf16: true
gradient_checkpointing: true
```

### Llama 3 8B Instruct (Reasoning)

```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
dataset_mixer:
  argilla/distilabel-math-preference-dpo: 1.0

learning_rate: 3e-7
beta: 5.0
gamma_beta_ratio: 0.7
loss_type: sigmoid
sft_weight: 0.1

num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
warmup_ratio: 0.1
lr_scheduler_type: cosine
```

## References

- SimPO paper: https://arxiv.org/abs/2405.14734
- Alignment Handbook: https://github.com/huggingface/alignment-handbook

@@ -0,0 +1,350 @@
# Loss Functions

Complete guide to SimPO loss functions and mathematical formulations.

## Overview

SimPO supports two loss types:
- **Sigmoid** (default) - Smooth, differentiable loss
- **Hinge** - Margin-based, sparse loss

Both are reference-free (no reference model needed).

## SimPO Loss Formula

### Core Calculation

**Step 1: Log probability ratio**:
```
pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x)
```

**Step 2: Apply target margin**:
```
logits = pi_logratios - γ/β
```
Where:
- γ/β = `gamma_beta_ratio` (target margin)

**Step 3: Compute loss** (depends on loss type)

### Sigmoid Loss (Default)

**Formula**:
```
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
```

Where:
- β = `beta` (reward scaling)
- σ = sigmoid function
- ε = `label_smoothing` (default 0.0)

**Implementation**:
```python
losses = (
    -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
    - F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
```

**Characteristics**:
- Smooth, continuous gradients
- Probabilistic interpretation
- Standard choice for most tasks
- Works well with higher beta values

### Hinge Loss

**Formula**:
```
L = max(0, 1 - β * logits)
```

**Implementation**:
```python
losses = torch.relu(1 - self.beta * logits)
```

**Characteristics**:
- Non-smooth (has kink at logits = 1/β)
- Margin-based (SVM-style)
- Can lead to sparser solutions
- Less commonly used (both variants are combined in the sketch below)

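Putting the two variants together, a self-contained sketch of the loss on length-averaged log probs (the function name and defaults are illustrative, not the SimPO repo's exact API):

```python
import torch
import torch.nn.functional as F

def simpo_loss(chosen_logps, rejected_logps, beta=2.0,
               gamma_beta_ratio=0.5, label_smoothing=0.0, loss_type="sigmoid"):
    """Reference-free SimPO loss; inputs are average log probs per sequence."""
    pi_logratios = chosen_logps - rejected_logps
    logits = pi_logratios - gamma_beta_ratio
    if loss_type == "sigmoid":
        losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                  - F.logsigmoid(-beta * logits) * label_smoothing)
    else:  # hinge
        losses = torch.relu(1 - beta * logits)
    return losses.mean()

# Reproduces Example 1 below: chosen -1.2, rejected -2.5 -> loss ≈ 0.184
print(simpo_loss(torch.tensor([-1.2]), torch.tensor([-2.5])))
```
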
## Comparison to DPO

### DPO Loss (Reference Model Required)

**Formula**:
```
L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))]
```

**Key features**:
- Requires reference model π_ref
- Normalizes by reference log probabilities
- More conservative (stays close to reference)

### SimPO Loss (Reference-Free)

**Formula**:
```
L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β))
```

**Key features**:
- No reference model needed
- Direct preference optimization
- Target margin γ/β controls preference strength
- More efficient (fewer model forward passes)

**Visual comparison**:
```
DPO:   [Policy] - [Reference] → Loss
SimPO: [Policy] → Loss
```

## Average Log Probability Reward

### Calculation

**Per-token log probabilities**:
```python
# Get log probs for each token
per_token_logps = log_softmax(logits).gather(dim=-1, index=labels)

# Create mask to ignore padding
loss_mask = (labels != label_pad_token_id)
```

**Average log probability** (if `average_log_prob=True`):
```python
avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
```

**Sum log probability** (if `average_log_prob=False`):
```python
sum_logp = (per_token_logps * loss_mask).sum(-1)
```

**Why average?**
- Normalizes for sequence length
- Prevents bias toward shorter/longer responses
- Standard practice in SimPO

### Reward Metrics

**Chosen reward**:
```python
chosen_rewards = beta * policy_chosen_logps.detach()
```

**Rejected reward**:
```python
rejected_rewards = beta * policy_rejected_logps.detach()
```

**Reward margin**:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
```

## Label Smoothing

### Formula with Smoothing

**Sigmoid loss**:
```
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
```

**Effect**:
- ε = 0.0: No smoothing (default)
- ε = 0.1: 10% smoothing (soft labels)
- ε = 0.5: Maximum smoothing

**When to use**:
- Noisy preference labels
- Uncertain preferences
- Prevent overconfidence

**Config**:
```yaml
label_smoothing: 0.1  # 10% smoothing
```

## SFT Regularization
|
||||
|
||||
### Combined Loss
|
||||
|
||||
**With SFT component**:
|
||||
```
|
||||
L_total = L_SimPO + λ * L_SFT
|
||||
```
|
||||
|
||||
Where:
|
||||
- L_SFT = cross-entropy loss on chosen responses
|
||||
- λ = `sft_weight` (0.0 to 1.0)
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
if self.sft_weight > 0:
|
||||
sft_loss = -policy_chosen_logps
|
||||
total_loss = simpo_loss + self.sft_weight * sft_loss
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Preserve model capabilities
|
||||
- Prevent catastrophic forgetting
|
||||
- Fine-tuning instruct models
|
||||
|
||||
**Trade-off**:
|
||||
- Higher sft_weight: Preserve capabilities, less alignment
|
||||
- Lower sft_weight: Stronger alignment, may forget capabilities
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
sft_weight: 0.1 # 10% SFT regularization
|
||||
```

## Loss Type Selection

### Sigmoid vs Hinge

| Aspect | Sigmoid | Hinge |
|--------|---------|-------|
| Smoothness | Smooth | Non-smooth |
| Gradients | Continuous | Discontinuous at margin |
| Sparsity | Dense solutions | Sparse solutions |
| Interpretability | Probabilistic | Geometric margin |
| Use case | **General purpose** | Margin-based tasks |
| Recommendation | **Default choice** | Experimental |

**Config**:
```yaml
# Sigmoid (default)
loss_type: sigmoid

# Hinge (alternative)
loss_type: hinge
```
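
For reference, a minimal sketch of the hinge variant under the same conventions (margin-adjusted `logits`, scaled by β):

```python
import torch

def hinge_loss(logits: torch.Tensor, beta: float = 2.0) -> torch.Tensor:
    """Hinge preference loss: max(0, 1 - β · logits); zero once the margin is satisfied."""
    return torch.relu(1 - beta * logits).mean()
```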

## Mathematical Properties

### Gradient Analysis

**Sigmoid loss gradient**:
```
∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε
```

**Hinge loss gradient**:
```
∂L/∂logits = -β  if logits < 1/β
             0   otherwise
```

**Implications**:
- Sigmoid: Always provides gradient signal
- Hinge: No gradient when margin satisfied

### Convergence Behavior

**Sigmoid**:
- Asymptotically approaches zero loss
- Continues optimizing even with large margins
- Smoother training curves

**Hinge**:
- Reaches zero loss at margin
- Stops optimizing once margin satisfied
- May have training plateaus
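
These gradients are easy to sanity-check numerically; a small sketch comparing the analytic sigmoid gradient (ε = 0) with autograd:

```python
import torch
import torch.nn.functional as F

beta = 2.0
logits = torch.tensor([0.8], requires_grad=True)

# Sigmoid loss without smoothing: L = -log σ(β · logits)
loss = -F.logsigmoid(beta * logits).sum()
loss.backward()

analytic = -beta * torch.sigmoid(-beta * logits.detach())  # ∂L/∂logits with ε = 0
print(logits.grad, analytic)  # both ≈ -0.336
```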

## Complete Loss Examples

### Example 1: Basic SimPO (Sigmoid)

**Config**:
```yaml
beta: 2.0
gamma_beta_ratio: 0.5
loss_type: sigmoid
label_smoothing: 0.0
sft_weight: 0.0
```

**Loss calculation**:
```python
# Step 1: Compute average log probs
chosen_logps = avg_log_prob(policy(chosen))      # e.g., -1.2
rejected_logps = avg_log_prob(policy(rejected))  # e.g., -2.5

# Step 2: Log ratio and margin
pi_logratios = -1.2 - (-2.5)  # = 1.3
logits = 1.3 - 0.5            # = 0.8

# Step 3: Sigmoid loss
loss = -log(sigmoid(2.0 * 0.8))  # = -log(sigmoid(1.6)) = -log(0.832) ≈ 0.184
```

### Example 2: SimPO with SFT

**Config**:
```yaml
beta: 2.5
gamma_beta_ratio: 0.5
loss_type: sigmoid
sft_weight: 0.1
```

**Loss calculation**:
```python
# SimPO loss (same logits as Example 1, but beta = 2.5)
simpo_loss = 0.127  # -log(sigmoid(2.5 * 0.8)) = -log(0.881)

# SFT loss: negative of the chosen log prob
sft_loss = -(-1.2)  # = 1.2

# Total loss
total_loss = simpo_loss + 0.1 * sft_loss  # = 0.127 + 0.12 = 0.247
```

## Debugging

### Check Reward Margins

**Low margin (< 0.5)**:
- Preferences not being learned
- Increase beta or gamma_beta_ratio

**High margin (> 5.0)**:
- May be overfitting
- Reduce beta or learning rate

**Monitor**:
```python
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
print(f"Reward margin: {reward_margin:.2f}")
```

### Check Log Probabilities

**Typical values**:
- Chosen: -1.0 to -2.0 (higher is better)
- Rejected: -2.0 to -4.0 (more negative, as expected for dispreferred responses)

**Warning signs**:
- Both very negative (< -10): Model not learning
- Both very positive (> 0): Numerical instability
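
A quick monitor in the same spirit as the reward-margin check above (thresholds mirror the ranges listed):

```python
def check_logps(chosen_logps, rejected_logps) -> None:
    """Flag the warning signs above on batch-averaged log probs."""
    c = chosen_logps.mean().item()
    r = rejected_logps.mean().item()
    if c < -10 or r < -10:
        print(f"WARNING: very negative log probs (chosen={c:.2f}, rejected={r:.2f})")
    if c > 0 or r > 0:
        print(f"WARNING: positive log probs (chosen={c:.2f}, rejected={r:.2f})")
```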

## References

- SimPO paper: https://arxiv.org/abs/2405.14734
- DPO paper: https://arxiv.org/abs/2305.18290
- Implementation: https://github.com/princeton-nlp/SimPO

467
protected/skills-backup/mlops/training/slime/SKILL.md
Normal file
@@ -0,0 +1,467 @@
---
name: slime-rl-training
description: Provides guidance for LLM post-training with RL using slime, a Megatron+SGLang framework. Use when training GLM models, implementing custom data generation workflows, or needing tight Megatron-LM integration for RL scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
metadata:
  hermes:
    tags: [Reinforcement Learning, Megatron-LM, SGLang, GRPO, Post-Training, GLM]

---

# slime: LLM Post-Training Framework for RL Scaling

slime is an LLM post-training framework from Tsinghua's THUDM team, powering GLM-4.5, GLM-4.6, and GLM-4.7. It connects Megatron-LM for training with SGLang for high-throughput rollout generation.

## When to Use slime

**Choose slime when you need:**
- Megatron-LM native training with SGLang inference
- Custom data generation workflows with flexible data buffers
- Training GLM, Qwen3, DeepSeek V3, or Llama 3 models
- A research-grade framework with production backing (Z.ai)

**Consider alternatives when:**
- You need enterprise-grade stability features → use **miles**
- You want flexible backend swapping → use **verl**
- You need PyTorch-native abstractions → use **torchforge**

## Key Features

- **Training**: Megatron-LM with full parallelism support (TP, PP, DP, SP)
- **Rollout**: SGLang-based high-throughput generation with router
- **Data Buffer**: Flexible prompt management and sample storage
- **Models**: GLM-4.x, Qwen3, DeepSeek V3/R1, Llama 3

## Architecture Overview

```
┌─────────────────────────────────────────────────────────┐
│                       Data Buffer                        │
│  - Prompt initialization and management                  │
│  - Custom data generation and filtering                  │
│  - Rollout sample storage                                │
└─────────────┬───────────────────────────┬───────────────┘
              │                           │
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
│ Training (Megatron-LM)  │ │ Rollout (SGLang + Router)    │
│ - Actor model training  │ │ - Response generation        │
│ - Critic (optional)     │ │ - Reward/verifier output     │
│ - Weight sync to rollout│ │ - Multi-turn support         │
└─────────────────────────┘ └─────────────────────────────┘
```

## Installation

```bash
# Recommended: Docker
docker pull slimerl/slime:latest
docker run --rm --gpus all --ipc=host --shm-size=16g \
  -it slimerl/slime:latest /bin/bash

# Inside container
cd /root/slime && pip install -e . --no-deps
```

### From Source

```bash
git clone https://github.com/THUDM/slime.git
cd slime
pip install -r requirements.txt
pip install -e .
```

## Quick Start: GRPO Training

```bash
# Source model configuration
source scripts/models/qwen3-4B.sh

# Launch training
python train.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 4 \
  --rollout-num-gpus 4 \
  --advantage-estimator grpo \
  --use-kl-loss --kl-loss-coef 0.001 \
  --rollout-batch-size 32 \
  --n-samples-per-prompt 8 \
  --global-batch-size 256 \
  --num-rollout 3000 \
  --prompt-data /path/to/data.jsonl \
  ${MODEL_ARGS[@]} ${CKPT_ARGS[@]}
```

---

## Workflow 1: Standard GRPO Training

Use this workflow for training reasoning models with group-relative advantages.

### Prerequisites Checklist
- [ ] Docker environment or Megatron-LM + SGLang installed
- [ ] Model checkpoint (HuggingFace or Megatron format)
- [ ] Training data in JSONL format

### Step 1: Prepare Data

`data.jsonl` format:
```json
{"prompt": "What is 2 + 2?", "label": "4"}
{"prompt": "Solve: 3x = 12", "label": "x = 4"}
```

Or with chat format:
```json
{
  "prompt": [
    {"role": "system", "content": "You are a math tutor."},
    {"role": "user", "content": "What is 15 + 27?"}
  ],
  "label": "42"
}
```

### Step 2: Configure Model

Choose a pre-configured model script:

```bash
# List available models
ls scripts/models/
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh, ...

# Source your model
source scripts/models/qwen3-4B.sh
```

### Step 3: Launch Training

```bash
python train.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  --advantage-estimator grpo \
  --use-kl-loss \
  --kl-loss-coef 0.001 \
  --prompt-data /path/to/train.jsonl \
  --input-key prompt \
  --label-key label \
  --apply-chat-template \
  --rollout-batch-size 32 \
  --n-samples-per-prompt 8 \
  --global-batch-size 256 \
  --num-rollout 3000 \
  --save-interval 100 \
  --eval-interval 50 \
  ${MODEL_ARGS[@]}
```

### Step 4: Monitor Training
- [ ] Check TensorBoard: `tensorboard --logdir outputs/`
- [ ] Verify reward curves are increasing
- [ ] Monitor GPU utilization across nodes

---

## Workflow 2: Asynchronous Training

Use async mode for higher throughput by overlapping rollout and training.

### When to Use Async
- Large models with long generation times
- High GPU idle time in synchronous mode
- Sufficient memory for buffering

### Launch Async Training

```bash
python train_async.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  --advantage-estimator grpo \
  --async-buffer-size 4 \
  --prompt-data /path/to/train.jsonl \
  ${MODEL_ARGS[@]}
```

### Async-Specific Parameters

```bash
--async-buffer-size 4        # Number of rollouts to buffer
--update-weights-interval 2  # Sync weights every N rollouts
```

---

## Workflow 3: Multi-Turn Agentic Training

Use this workflow for training agents with tool use or multi-step reasoning.

### Prerequisites
- [ ] Custom generate function for multi-turn logic
- [ ] Tool/environment interface

### Step 1: Define Custom Generate Function

```python
# custom_generate.py
async def custom_generate(args, samples, evaluation=False):
    """Multi-turn generation with tool calling."""
    for sample in samples:
        conversation = sample.prompt

        for turn in range(args.max_turns):
            # Generate response
            response = await generate_single(conversation)

            # Check for tool call
            tool_call = extract_tool_call(response)
            if tool_call:
                tool_result = execute_tool(tool_call)
                conversation.append({"role": "assistant", "content": response})
                conversation.append({"role": "tool", "content": tool_result})
            else:
                break

        sample.response = response
        sample.reward = compute_reward(sample)

    return samples
```

### Step 2: Launch with Custom Function

```bash
python train.py \
  --custom-generate-function-path custom_generate.py \
  --max-turns 5 \
  --prompt-data /path/to/agent_data.jsonl \
  ${MODEL_ARGS[@]}
```

See `examples/search-r1/` for a complete multi-turn search example.

---

## Configuration Reference

### Three Argument Categories

slime uses three types of arguments:

**1. Megatron Arguments** (passed directly):
```bash
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--num-layers 32
--hidden-size 4096
```

**2. SGLang Arguments** (prefixed with `--sglang-`):
```bash
--sglang-mem-fraction-static 0.8
--sglang-context-length 8192
--sglang-log-level INFO
```

**3. slime Arguments**:
```bash
# Resource allocation
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--rollout-num-gpus 8
--colocate  # Share GPUs between training/inference

# Data
--prompt-data /path/to/data.jsonl
--input-key prompt
--label-key label

# Training loop
--num-rollout 3000
--rollout-batch-size 32
--n-samples-per-prompt 8
--global-batch-size 256

# Algorithm
--advantage-estimator grpo  # or: gspo, ppo, reinforce_plus_plus
--use-kl-loss
--kl-loss-coef 0.001
```

### Key Constraints

```
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
```

Example: 32 × 8 = 256 × 1
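
A tiny sketch for validating this constraint before launching a run (names mirror the CLI flags above):

```python
def check_batch_constraint(rollout_batch_size: int,
                           n_samples_per_prompt: int,
                           global_batch_size: int,
                           num_steps_per_rollout: int = 1) -> None:
    """Raise early if the rollout/training batch constraint is violated."""
    generated = rollout_batch_size * n_samples_per_prompt
    consumed = global_batch_size * num_steps_per_rollout
    if generated != consumed:
        raise ValueError(f"{generated} samples generated per rollout, "
                         f"but {consumed} consumed per training pass")

check_batch_constraint(32, 8, 256, 1)  # OK: 256 == 256
```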

---

## Data Buffer System

slime's data buffer enables flexible data management:

### Basic Data Source

```python
class RolloutDataSource:
    def get_samples(self, num_samples):
        """Fetch prompts from dataset."""
        return self.dataset.sample(num_samples)

    def add_samples(self, samples):
        """Called after generation (no-op by default)."""
        pass
```

### Buffered Data Source (Off-Policy)

```python
class RolloutDataSourceWithBuffer(RolloutDataSource):
    def __init__(self):
        self.buffer = []

    def add_samples(self, samples):
        """Store generated samples for reuse."""
        self.buffer.extend(samples)

    def buffer_filter(self, args, buffer, num_samples):
        """Custom selection logic (prioritized, stratified, etc.)."""
        return select_best(buffer, num_samples)
```

---

## Common Issues and Solutions

### Issue: SGLang Engine Crash

**Symptoms**: Inference engine dies mid-training

**Solutions**:
```bash
# Enable fault tolerance
--use-fault-tolerance

# Increase memory allocation
--sglang-mem-fraction-static 0.85

# Reduce batch size
--rollout-batch-size 16
```

### Issue: Weight Sync Timeout

**Symptoms**: Training hangs after rollout

**Solutions**:
```bash
# Increase sync interval
--update-weights-interval 5

# Use colocated mode (no network transfer)
--colocate
```

### Issue: OOM During Training

**Symptoms**: CUDA OOM in backward pass

**Solutions**:
```bash
# Enable gradient checkpointing
--recompute-activations

# Reduce micro-batch size
--micro-batch-size 1

# Enable sequence parallelism
--sequence-parallel
```

### Issue: Slow Data Loading

**Symptoms**: GPU idle during data fetch

**Solutions**:
```bash
# Increase data workers
--num-data-workers 4

# Use streaming dataset
--streaming-data
```

---

## Supported Models

| Model Family | Configurations |
|--------------|----------------|
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
| DeepSeek | V3, V3.1, R1 |
| Llama | Llama 3 (8B, 70B) |
| Others | Kimi K2, Moonlight-16B |

Each model has pre-configured scripts in `scripts/models/`.

---

## Advanced Topics

### Co-location Mode

Share GPUs between training and inference to reduce memory:

```bash
python train.py \
  --colocate \
  --actor-num-gpus-per-node 8 \
  --sglang-mem-fraction-static 0.4 \
  ${MODEL_ARGS[@]}
```

### Custom Reward Model

```python
# custom_rm.py
class CustomRewardModel:
    def __init__(self, model_path):
        self.model = load_model(model_path)

    def compute_reward(self, prompts, responses):
        inputs = self.tokenize(prompts, responses)
        scores = self.model(inputs)
        return scores.tolist()
```

```bash
--custom-rm-path custom_rm.py
```

### Multi-Task Evaluation

```bash
--eval-prompt-data aime /path/to/aime.jsonl \
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
--n-samples-per-eval-prompt 16
```

---

## Resources

- **Documentation**: https://thudm.github.io/slime/
- **GitHub**: https://github.com/THUDM/slime
- **Blog**: https://lmsys.org/blog/2025-07-09-slime/
- **Examples**: See `examples/` directory for 14+ worked examples

@@ -0,0 +1,392 @@

# slime API Reference

## Architecture Overview

slime operates with a three-module architecture orchestrated by Ray:

```
┌─────────────────────────────────────────────────────────┐
│                       Data Buffer                        │
│  - Prompt initialization and management                  │
│  - Custom data generation and filtering                  │
│  - Rollout sample storage                                │
└─────────────┬───────────────────────────┬───────────────┘
              │                           │
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
│ Training (Megatron-LM)  │ │ Rollout (SGLang + Router)    │
│ - Actor model training  │ │ - Response generation        │
│ - Critic (optional)     │ │ - Reward/verifier output     │
│ - Weight sync to rollout│ │ - Multi-turn support         │
└─────────────────────────┘ └─────────────────────────────┘
```

## Core Data Structures

### Sample Object

The `Sample` object is the core data structure, defined in `slime/utils/types.py`:

```python
from dataclasses import dataclass, field
from typing import Any, Optional

from slime.utils.types import Sample

@dataclass
class Sample:
    # Core fields
    group_index: Optional[int] = None   # Group index for batching
    index: Optional[int] = None         # Sample index
    prompt: str | list[dict] = ""       # Input prompt or chat history
    tokens: list[int] = field(default_factory=list)  # Token IDs
    response: str = ""                  # Generated response
    response_length: int = 0            # Response length in tokens
    label: Optional[str] = None         # Ground truth label
    reward: Optional[float | dict] = None   # RL reward signal
    loss_mask: Optional[list[int]] = None   # 1=compute loss, 0=mask
    status: Status = Status.PENDING     # Sample status
    metadata: dict = field(default_factory=dict)  # Custom data

    # Multimodal support
    multimodal_inputs: Optional[Any] = None        # Raw multimodal data (images, videos)
    multimodal_train_inputs: Optional[Any] = None  # Processed multimodal data (pixel_values)

    # Rollout tracking
    weight_versions: list[str] = field(default_factory=list)
    rollout_log_probs: Optional[list[float]] = None  # Log probs from SGLang
    rollout_routed_experts: Optional[list[list[int]]] = None  # Expert routing (MoE)

    # Control fields
    remove_sample: bool = False
    generate_function_path: Optional[str] = None
    train_metadata: Optional[dict] = None
    non_generation_time: float = 0.0

    # Speculative decoding info (nested dataclass)
    @dataclass
    class SpecInfo:
        spec_accept_token_num: int = 0
        spec_draft_token_num: int = 0
        spec_verify_ct: int = 0
        completion_token_num: int = 0
```

### Status Enum

```python
from enum import Enum

class Status(Enum):
    PENDING = "pending"      # Not yet processed
    COMPLETED = "completed"  # Successfully generated
    TRUNCATED = "truncated"  # Hit max length
    ABORTED = "aborted"      # Generation aborted
    FAILED = "failed"        # Generation failed
```
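
As a quick illustration of how these fields fit together, a sketch that grades a generated sample (the exact-match rule is illustrative, not part of the API):

```python
from slime.utils.types import Sample

def grade(sample: Sample) -> None:
    """Fill in reward/metadata on a generated sample (illustrative rule)."""
    sample.reward = 1.0 if sample.response.strip() == (sample.label or "").strip() else 0.0
    sample.metadata["graded"] = True
```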

## Configuration System

slime uses three categories of command-line arguments:

### 1. Megatron Arguments

All Megatron-LM arguments are supported directly:

```bash
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--num-layers 32
--hidden-size 4096
--num-attention-heads 32
--seq-length 4096
--micro-batch-size 1
--global-batch-size 256
```

### 2. SGLang Arguments

SGLang arguments are prefixed with `--sglang-`:

```bash
--sglang-mem-fraction-static 0.8   # GPU memory for KV cache
--sglang-context-length 8192       # Maximum context length
--sglang-log-level INFO            # Logging verbosity
--sglang-tp-size 2                 # Tensor parallelism
--sglang-disable-cuda-graph        # Disable CUDA graphs
```

### 3. slime-Specific Arguments

Defined in `slime/utils/arguments.py`:

```bash
# Resource Allocation
--actor-num-nodes 1               # Training nodes
--actor-num-gpus-per-node 8       # GPUs per training node
--rollout-num-gpus 8              # Total rollout GPUs
--rollout-num-gpus-per-engine 2   # GPUs per SGLang engine
--colocate                        # Share GPUs for train/inference

# Data Configuration
--prompt-data /path/to/data.jsonl # Training data path
--input-key prompt                # Key for prompts in JSON
--label-key label                 # Key for labels in JSON
--apply-chat-template             # Apply chat formatting

# Training Loop
--num-rollout 3000                # Total rollout iterations
--rollout-batch-size 32           # Prompts per rollout
--n-samples-per-prompt 8          # Responses per prompt
--global-batch-size 256           # Training batch size
--num-steps-per-rollout 1         # Training steps per rollout

# RL Algorithm
--advantage-estimator grpo        # grpo, gspo, ppo, reinforce_plus_plus
--use-kl-loss                     # Enable KL loss
--kl-loss-coef 0.001              # KL coefficient
--calculate-per-token-loss        # Token-level loss

# Off-Policy Options
--use-tis                         # Truncated Importance Sampling
--tis-threshold 0.9               # TIS threshold
--true-on-policy-mode             # Force on-policy training
```

## Data Buffer System

### RolloutDataSource (Base Class)

```python
from slime.data import RolloutDataSource

class RolloutDataSource:
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.args = args

    def get_samples(self, num_samples: int) -> list[Sample]:
        """Fetch prompts from dataset."""
        return [Sample(prompt=p) for p in self.dataset.sample(num_samples)]

    def add_samples(self, samples: list[Sample]) -> None:
        """Called after generation (no-op by default)."""
        pass
```

### Buffered Data Source (Off-Policy)

```python
from slime.data import RolloutDataSourceWithBuffer

class RolloutDataSourceWithBuffer(RolloutDataSource):
    def __init__(self, dataset, args):
        super().__init__(dataset, args)
        self.buffer = []

    def add_samples(self, samples: list[Sample]) -> None:
        """Store generated samples for reuse."""
        self.buffer.extend(samples)

    def buffer_filter(self, args, buffer, num_samples) -> list[Sample]:
        """Custom selection logic."""
        # Example: prioritized sampling based on reward
        sorted_buffer = sorted(buffer, key=lambda s: s.reward, reverse=True)
        return sorted_buffer[:num_samples]
```

## Custom Functions

### Custom Generate Function

For multi-turn or tool-calling scenarios:

```python
# custom_generate.py
from slime.utils.types import Sample

async def custom_generate(args, samples: list[Sample], evaluation: bool = False) -> list[Sample]:
    """
    Custom generation function for multi-turn interactions.

    Args:
        args: Training arguments
        samples: List of Sample objects with prompts
        evaluation: Whether this is an evaluation run

    Returns:
        List of Sample objects with responses and rewards
    """
    for sample in samples:
        conversation = sample.prompt if isinstance(sample.prompt, list) else [
            {"role": "user", "content": sample.prompt}
        ]

        for turn in range(args.max_turns):
            # Generate response
            response = await generate_single(conversation)

            # Check for tool call
            tool_call = extract_tool_call(response)
            if tool_call:
                # Execute tool
                tool_result = await execute_tool(tool_call)
                conversation.append({"role": "assistant", "content": response})
                conversation.append({"role": "tool", "content": tool_result})
            else:
                # Final response
                sample.response = response
                break

        # Compute reward
        sample.reward = compute_reward(sample)

        # Set loss mask (1 for model tokens, 0 for tool responses)
        sample.loss_mask = build_loss_mask(sample)

    return samples
```

Usage:
```bash
python train.py \
  --custom-generate-function-path custom_generate.py \
  --max-turns 5
```

### Custom Reward Function

```python
# custom_rm.py
from slime.utils.types import Sample

async def reward_func(args, sample: Sample, **kwargs) -> float:
    """
    Compute reward for a single sample.

    Args:
        args: Training arguments
        sample: Sample object with response

    Returns:
        Reward score (float)
    """
    response = sample.response
    ground_truth = sample.label or sample.metadata.get("answer", "")

    # Example: exact match reward
    if response.strip() == ground_truth.strip():
        return 1.0
    return 0.0

# For batched processing (more efficient)
async def batched_custom_rm(args, samples: list[Sample]) -> list[float]:
    """Batch reward computation."""
    rewards = []
    for sample in samples:
        reward = await reward_func(args, sample)
        rewards.append(reward)
    return rewards
```

Usage:
```bash
python train.py \
  --custom-rm-path custom_rm.py \
  --group-rm  # Enable batched processing
```

## Model Configuration

### Pre-configured Model Scripts

Located in `scripts/models/`:

```bash
# List available models
ls scripts/models/
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh

# Source model configuration
source scripts/models/qwen3-4B.sh
# This sets MODEL_ARGS and CKPT_ARGS arrays
```

### Example Model Script

```bash
# scripts/models/qwen3-4B.sh
export MODEL_ARGS=(
  --num-layers 36
  --hidden-size 2560
  --num-attention-heads 20
  --num-query-groups 4
  --ffn-hidden-size 6912
  --max-position-embeddings 32768
  --rotary-percent 1.0
  --rotary-base 1000000
  --swiglu
  --untie-embeddings-and-output-weights
  --no-position-embedding
  --normalization RMSNorm
  --tokenizer-type HuggingFaceTokenizer
  --bf16
)

export CKPT_ARGS=(
  --hf-checkpoint /path/to/qwen3-4b-hf
  --initial-megatron-checkpoint /path/to/megatron/ckpt
)
```

## Async Training

### Enabling Async Mode

```bash
python train_async.py \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  --async-buffer-size 4 \
  --update-weights-interval 2 \
  ${MODEL_ARGS[@]}
```

### Async-Specific Parameters

```bash
--async-buffer-size 4        # Number of rollouts to buffer
--update-weights-interval 2  # Sync weights every N rollouts
```

**Note**: Colocated mode (`--colocate`) is NOT supported with async training.

## Evaluation

### Multi-Task Evaluation

```bash
--eval-prompt-data aime /path/to/aime.jsonl \
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
--n-samples-per-eval-prompt 16 \
--eval-interval 50
```

### Evaluation Configuration

```bash
--eval-interval 50               # Evaluate every N rollouts
--n-samples-per-eval-prompt 16   # Samples for evaluation
--eval-temperature 0.0           # Greedy decoding for eval
```

## Supported Models

| Model Family | Configurations |
|--------------|----------------|
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
| DeepSeek | V3, V3.1, R1 |
| Llama | Llama 3 (8B, 70B) |
| Others | Kimi K2, Moonlight-16B |

## Resources

- Documentation: https://thudm.github.io/slime/
- GitHub: https://github.com/THUDM/slime
- Blog: https://lmsys.org/blog/2025-07-09-slime/
- Examples: `examples/` directory (14+ worked examples)

@@ -0,0 +1,386 @@

# slime Troubleshooting Guide

## Common Issues and Solutions

### SGLang Issues

#### Issue: SGLang Engine Crash

**Symptoms**: Inference engine dies mid-training, connection errors

**Solutions**:

1. **Enable fault tolerance**:
```bash
--use-fault-tolerance
```

2. **Increase memory allocation**:
```bash
--sglang-mem-fraction-static 0.85  # Increase from 0.8
```

3. **Reduce batch size**:
```bash
--rollout-batch-size 16  # Reduce from 32
```

4. **Disable CUDA graphs** (for debugging):
```bash
--sglang-disable-cuda-graph
```

#### Issue: SGLang Router Load Imbalance

**Symptoms**: Some SGLang engines overloaded while others idle

**Solutions**:

1. **Adjust routing strategy**:
```bash
--sglang-router-strategy round_robin
```

2. **Increase number of engines**:
```bash
--rollout-num-gpus-per-engine 1  # More engines, fewer GPUs each
```

### Weight Synchronization Issues

#### Issue: Weight Sync Timeout

**Symptoms**: Training hangs after rollout, timeout errors

**Solutions**:

1. **Increase sync interval** (async mode):
```bash
--update-weights-interval 5  # Increase from 2
```

2. **Use colocated mode** (eliminates network transfer):
```bash
--colocate
```

3. **Check network bandwidth**:
```bash
# Verify InfiniBand is enabled
ibstat
```

#### Issue: Weight Sync Failures in Multi-Node

**Symptoms**: Nodes fail to receive updated weights

**Solutions**:

1. **Set NCCL environment**:
```bash
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
```

2. **Increase timeout**:
```bash
export NCCL_TIMEOUT=1800
```

### Memory Issues

#### Issue: OOM During Training

**Symptoms**: CUDA OOM in backward pass

**Solutions**:

1. **Enable gradient checkpointing**:
```bash
--recompute-activations
```

2. **Reduce micro-batch size**:
```bash
--micro-batch-size 1
```

3. **Enable sequence parallelism**:
```bash
--sequence-parallel
```

4. **Reduce global batch size**:
```bash
--global-batch-size 128  # Reduce from 256
```

#### Issue: OOM in Colocated Mode

**Symptoms**: OOM when both training and inference run on the same GPUs

**Solutions**:

1. **Reduce SGLang memory**:
```bash
--sglang-mem-fraction-static 0.4  # Reduce from 0.8
```

2. **Enable offloading**:
```bash
--offload-optimizer-states
```

3. **Use a smaller sequence length**:
```bash
--seq-length 2048  # Reduce from 4096
```

### Data Loading Issues

#### Issue: Slow Data Loading

**Symptoms**: GPU idle during data fetch, low GPU utilization

**Solutions**:

1. **Increase data workers**:
```bash
--num-data-workers 4
```

2. **Use streaming dataset**:
```bash
--streaming-data
```

3. **Pre-tokenize data**:
```python
# Pre-process data offline
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("model_path")
# Tokenize and save the dataset before training
```

#### Issue: Data Format Errors

**Symptoms**: KeyError, missing fields, parsing failures

**Solutions**:

1. **Verify data format**:
```python
import json
with open("data.jsonl") as f:
    for line in f:
        data = json.loads(line)
        assert "prompt" in data, "Missing prompt field"
        assert "label" in data, "Missing label field"
```

2. **Check key names**:
```bash
--input-key prompt  # Must match your data
--label-key label   # Must match your data
```

### Training Stability Issues

#### Issue: Loss Explosion / NaN

**Symptoms**: Loss becomes NaN or explodes

**Solutions**:

1. **Reduce learning rate**:
```bash
--lr 1e-6  # Reduce from 5e-6
```

2. **Enable gradient clipping**:
```bash
--clip-grad 1.0
```

3. **Check for data issues**:
```python
# Verify no empty prompts or responses
for sample in dataset:
    assert len(sample["prompt"]) > 0
```

4. **Use BF16 instead of FP16**:
```bash
--bf16  # More numerically stable
```

#### Issue: Reward Collapse

**Symptoms**: Reward drops to zero, model outputs garbage

**Solutions**:

1. **Increase KL penalty**:
```bash
--kl-loss-coef 0.01  # Increase from 0.001
```

2. **Reduce number of samples**:
```bash
--n-samples-per-prompt 4  # Reduce from 8
```

3. **Verify reward function**:
```python
# Test reward function independently (reward_func is async)
import asyncio
from custom_rm import reward_func

sample = Sample(prompt="test", response="test response")
reward = asyncio.run(reward_func(args, sample))
print(f"Reward: {reward}")  # Should be reasonable
```

### Async Training Issues

#### Issue: Async Training Not Supported with Colocate

**Symptoms**: Error when using `--colocate` with `train_async.py`

**Solution**: Colocated mode is NOT supported for async training. Use separate GPUs:
```bash
# Remove the --colocate flag
python train_async.py \
  --actor-num-gpus-per-node 4 \
  --rollout-num-gpus 4
```

#### Issue: Stale Weights in Async Mode

**Symptoms**: Policy divergence, inconsistent behavior

**Solutions**:

1. **Reduce async buffer size**:
```bash
--async-buffer-size 2  # Reduce from 4
```

2. **Increase weight update frequency**:
```bash
--update-weights-interval 1  # Sync every rollout
```

### Multi-Turn Training Issues

#### Issue: Tool Responses Included in Loss

**Symptoms**: Model learns to output tool responses verbatim

**Solution**: Properly set the loss mask in your custom generate function:
```python
def build_loss_mask(sample):
    """Create loss mask that excludes tool responses."""
    mask = []
    for i, token in enumerate(sample.tokens):
        if is_tool_response(token, sample.metadata):
            mask.append(0)  # Don't compute loss
        else:
            mask.append(1)  # Compute loss
    return mask
```

#### Issue: Multi-Turn Context Too Long

**Symptoms**: OOM or truncation in multi-turn conversations

**Solutions**:

1. **Limit conversation history**:
```python
# In custom generate function
conversation = sample.prompt[-10:]  # Keep last 10 turns
```

2. **Increase context length**:
```bash
--sglang-context-length 16384
```

### Checkpoint Issues

#### Issue: Checkpoint Loading Fails

**Symptoms**: Cannot load saved checkpoint

**Solutions**:

1. **Verify checkpoint path**:
```bash
ls -la /path/to/checkpoint/
```

2. **Check parallelism matches**:
```bash
# A checkpoint saved with TP=2 must be loaded with TP=2
--tensor-model-parallel-size 2
```

3. **Convert HuggingFace to Megatron** (if needed):
```bash
python tools/convert_hf_to_megatron.py \
  --hf_model_path /path/to/hf/model \
  --save_path /path/to/megatron/checkpoint
```

### Debugging Tips

#### Enable Verbose Logging

```bash
--log-level DEBUG
export SLIME_DEBUG=1
```

#### Check GPU Utilization

```bash
watch -n 1 nvidia-smi
```

#### Monitor Training

```bash
tensorboard --logdir outputs/
```

#### Test Custom Functions Independently

```python
# Test reward function
import asyncio
from custom_rm import reward_func

async def test():
    sample = Sample(prompt="test", response="test", label="expected")
    reward = await reward_func(args, sample)
    print(f"Reward: {reward}")

asyncio.run(test())
```

## Constraint Reference

Key constraint to remember:

```
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
```

Example: `32 × 8 = 256 × 1`

## Resources

- GitHub Issues: https://github.com/THUDM/slime/issues
- Documentation: https://thudm.github.io/slime/
- Examples: `examples/` directory

361
protected/skills-backup/mlops/training/torchtitan/SKILL.md
Normal file
@@ -0,0 +1,361 @@

---
name: distributed-llm-pretraining-torchtitan
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
metadata:
  hermes:
    tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]

---

# TorchTitan - PyTorch Native Distributed LLM Pretraining

## Quick start

TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.

**Installation**:
```bash
# From PyPI (stable)
pip install torchtitan

# From source (latest features, requires PyTorch nightly)
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
```

**Download tokenizer**:
```bash
# Get HF token from https://huggingface.co/settings/tokens
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
```

**Start training on 8 GPUs**:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```

## Common workflows

### Workflow 1: Pretrain Llama 3.1 8B on single node

Copy this checklist:

```
Single Node Pretraining:
- [ ] Step 1: Download tokenizer
- [ ] Step 2: Configure training
- [ ] Step 3: Launch training
- [ ] Step 4: Monitor and checkpoint
```

**Step 1: Download tokenizer**

```bash
python scripts/download_hf_assets.py \
  --repo_id meta-llama/Llama-3.1-8B \
  --assets tokenizer \
  --hf_token=YOUR_HF_TOKEN
```

**Step 2: Configure training**

Edit or create a TOML config file:

```toml
# llama3_8b_custom.toml
[job]
dump_folder = "./outputs"
description = "Llama 3.1 8B training"

[model]
name = "llama3"
flavor = "8B"
hf_assets_path = "./assets/hf/Llama-3.1-8B"

[optimizer]
name = "AdamW"
lr = 3e-4

[lr_scheduler]
warmup_steps = 200

[training]
local_batch_size = 2
seq_len = 8192
max_norm = 1.0
steps = 1000
dataset = "c4"

[parallelism]
data_parallel_shard_degree = -1  # Use all GPUs for FSDP

[activation_checkpoint]
mode = "selective"
selective_ac_option = "op"

[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```

**Step 3: Launch training**

```bash
# 8 GPUs on single node
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh

# Or explicitly with torchrun
torchrun --nproc_per_node=8 \
  -m torchtitan.train \
  --job.config_file ./llama3_8b_custom.toml
```

**Step 4: Monitor and checkpoint**

TensorBoard logs are saved to `./outputs/tb/`:
```bash
tensorboard --logdir ./outputs/tb
```

### Workflow 2: Multi-node training with SLURM

```
Multi-Node Training:
- [ ] Step 1: Configure parallelism for scale
- [ ] Step 2: Set up SLURM script
- [ ] Step 3: Submit job
- [ ] Step 4: Resume from checkpoint
```

**Step 1: Configure parallelism for scale**

For a 70B model on 256 GPUs (32 nodes):
```toml
[parallelism]
data_parallel_shard_degree = 32  # FSDP across 32 ranks
tensor_parallel_degree = 8       # TP within node
pipeline_parallel_degree = 1     # No PP for 70B
context_parallel_degree = 1      # Increase for long sequences
```

**Step 2: Set up SLURM script**

```bash
#!/bin/bash
#SBATCH --job-name=llama70b
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8

srun torchrun \
  --nnodes=32 \
  --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
  -m torchtitan.train \
  --job.config_file ./llama3_70b.toml
```

**Step 3: Submit job**

```bash
sbatch multinode_trainer.slurm
```

**Step 4: Resume from checkpoint**

Training auto-resumes if a checkpoint exists in the configured folder.

### Workflow 3: Enable Float8 training for H100s

Float8 provides a 30-50% speedup on H100 GPUs.

```
Float8 Training:
- [ ] Step 1: Install torchao
- [ ] Step 2: Configure Float8
- [ ] Step 3: Launch with compile
```

**Step 1: Install torchao**

```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```

**Step 2: Configure Float8**

Add to your TOML config:
```toml
[model]
converters = ["quantize.linear.float8"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output"]  # Exclude output layer

[compile]
enable = true
components = ["model", "loss"]
```

**Step 3: Launch with compile**

```bash
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
  --model.converters="quantize.linear.float8" \
  --quantize.linear.float8.enable_fsdp_float8_all_gather \
  --compile.enable
```

### Workflow 4: 4D parallelism for 405B models

```
4D Parallelism (FSDP + TP + PP + CP):
- [ ] Step 1: Create seed checkpoint
- [ ] Step 2: Configure 4D parallelism
- [ ] Step 3: Launch on 512 GPUs
```

**Step 1: Create seed checkpoint**

Required for consistent initialization across PP stages:
```bash
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
  --checkpoint.enable \
  --checkpoint.create_seed_checkpoint \
  --parallelism.data_parallel_shard_degree 1 \
  --parallelism.tensor_parallel_degree 1 \
  --parallelism.pipeline_parallel_degree 1
```

**Step 2: Configure 4D parallelism**

```toml
[parallelism]
data_parallel_shard_degree = 8  # FSDP
tensor_parallel_degree = 8      # TP within node
pipeline_parallel_degree = 8    # PP across nodes
context_parallel_degree = 1     # CP for long sequences

[training]
local_batch_size = 32
seq_len = 8192
```

**Step 3: Launch on 512 GPUs**

```bash
# 64 nodes x 8 GPUs = 512 GPUs
srun torchrun --nnodes=64 --nproc_per_node=8 \
  -m torchtitan.train \
  --job.config_file ./llama3_405b.toml
```

## When to use vs alternatives

**Use TorchTitan when:**
- Pretraining LLMs from scratch (8B to 405B+)
- You need a PyTorch-native solution without third-party dependencies
- You require composable 4D parallelism (FSDP2, TP, PP, CP)
- Training on H100s with Float8 support
- You want interoperable checkpoints with torchtune/HuggingFace

**Use alternatives instead:**
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
- **Axolotl/TRL**: Fine-tuning rather than pretraining
- **LitGPT**: Educational, smaller-scale training

## Common issues

**Issue: Out of memory on large models**

Enable activation checkpointing and reduce batch size:
```toml
[activation_checkpoint]
mode = "full"  # Instead of "selective"

[training]
local_batch_size = 1
```

Or use gradient accumulation:
```toml
[training]
local_batch_size = 1
global_batch_size = 32  # Accumulates gradients
```
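
A sketch of how the accumulation count falls out of these settings, assuming the usual convention that gradients are accumulated over `global_batch_size / (local_batch_size × DP ranks)` microbatches:

```python
def grad_accum_steps(global_batch_size: int, local_batch_size: int, dp_degree: int) -> int:
    """Microbatches accumulated per optimizer step (usual convention; verify against your config)."""
    assert global_batch_size % (local_batch_size * dp_degree) == 0
    return global_batch_size // (local_batch_size * dp_degree)

print(grad_accum_steps(32, 1, 8))  # 4 accumulation steps on 8 GPUs
```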

**Issue: TP causes high memory with async collectives**

Set this environment variable:
```bash
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
```

**Issue: Float8 training not faster**

Float8 only benefits large GEMMs. Filter out small layers:
```toml
[quantize.linear.float8]
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
```

**Issue: Checkpoint loading fails after parallelism change**

Use DCP's resharding capability:
```bash
# Convert sharded checkpoint to single file
python -m torch.distributed.checkpoint.format_utils \
  dcp_to_torch checkpoint/step-1000 checkpoint.pt
```

**Issue: Pipeline parallelism initialization**

Create a seed checkpoint first (see Workflow 4, Step 1).

## Supported models

| Model | Sizes | Status |
|-------|-------|--------|
| Llama 3.1 | 8B, 70B, 405B | Production |
| Llama 4 | Various | Experimental |
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
| GPT-OSS | 20B, 120B (MoE) | Experimental |
| Qwen 3 | Various | Experimental |
| Flux | Diffusion | Experimental |

## Performance benchmarks (H100)

| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|-------|------|-------------|---------|------------|
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |

## Advanced topics

**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for a detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.

**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.

**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.

**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for the TrainSpec protocol.

## Resources

- GitHub: https://github.com/pytorch/torchtitan
- Paper: https://arxiv.org/abs/2410.06511
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44

@@ -0,0 +1,181 @@

# Checkpointing in TorchTitan

TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.

## Basic Configuration

```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
```

## Save Model Only (Smaller Checkpoints)

Exclude optimizer state and training metadata:

```toml
[checkpoint]
enable = true
last_save_model_only = true
export_dtype = "bfloat16"  # Optional: export in lower precision
```

## Excluding Keys from Loading

Partial checkpoint loading for modified settings:

```toml
[checkpoint]
enable = true
exclude_from_loading = ["data_loader", "lr_scheduler"]
```

CLI equivalent:
```bash
--checkpoint.exclude_from_loading data_loader,lr_scheduler
```

## Creating Seed Checkpoints

Required for Pipeline Parallelism to ensure consistent initialization:

```bash
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
  --checkpoint.enable \
  --checkpoint.create_seed_checkpoint \
  --parallelism.data_parallel_replicate_degree 1 \
  --parallelism.data_parallel_shard_degree 1 \
  --parallelism.tensor_parallel_degree 1 \
  --parallelism.pipeline_parallel_degree 1 \
  --parallelism.context_parallel_degree 1 \
  --parallelism.expert_parallel_degree 1
```

This initializes the model on a single CPU for reproducible initialization across any GPU count.

## Async Checkpointing

Reduce checkpoint overhead with async writes:

```toml
[checkpoint]
enable = true
async_mode = "async"  # Options: "disabled", "async", "async_with_pinned_mem"
```

## HuggingFace Conversion

### During Training

Save directly in HuggingFace format:

```toml
[checkpoint]
last_save_in_hf = true
last_save_model_only = true
```

Load from HuggingFace:

```toml
[checkpoint]
initial_load_in_hf = true

[model]
hf_assets_path = "./path/to/hf/checkpoint"
```

### Offline Conversion

Convert without running training:

```bash
# HuggingFace -> TorchTitan
python ./scripts/checkpoint_conversion/convert_from_hf.py \
  <input_dir> <output_dir> \
  --model_name llama3 \
  --model_flavor 8B

# TorchTitan -> HuggingFace
python ./scripts/checkpoint_conversion/convert_to_hf.py \
  <input_dir> <output_dir> \
  --hf_assets_path ./assets/hf/Llama3.1-8B \
  --model_name llama3 \
  --model_flavor 8B
```

### Example

```bash
python ./scripts/convert_from_hf.py \
  ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
  ./initial_load_path/ \
  --model_name llama3 \
  --model_flavor 8B
```

## Converting to Single .pt File

Convert a DCP sharded checkpoint to a single PyTorch file:

```bash
python -m torch.distributed.checkpoint.format_utils \
  dcp_to_torch \
  torchtitan/outputs/checkpoint/step-1000 \
  checkpoint.pt
```

## Checkpoint Structure

DCP saves sharded checkpoints that can be resharded for different parallelism configurations:

```
checkpoint/
├── step-500/
│   ├── .metadata
│   ├── __0_0.distcp
│   ├── __0_1.distcp
│   └── ...
└── step-1000/
    └── ...
```

## Resume Training

Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:

```toml
[checkpoint]
load_step = 500  # Resume from step 500
```

## Interoperability with TorchTune

Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.

## Full Configuration Example

```toml
[checkpoint]
enable = true
folder = "checkpoint"
interval = 500
load_step = -1  # -1 = latest, or specify step number
last_save_model_only = true
export_dtype = "bfloat16"
async_mode = "async"
exclude_from_loading = []
last_save_in_hf = false
initial_load_in_hf = false
create_seed_checkpoint = false
```

## Best Practices

1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
3. **Pipeline parallelism**: Always create the seed checkpoint first
4. **Debugging**: Save frequent checkpoints during development, reduce frequency for production
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows

@@ -0,0 +1,258 @@
# Adding Custom Models to TorchTitan

This guide explains how to add a new model to TorchTitan following the established patterns.

## Directory Structure

```
torchtitan/models/your_model/
├── model/
│   ├── __init__.py
│   ├── args.py                # Model arguments
│   ├── model.py               # Model definition
│   └── state_dict_adapter.py  # HF conversion (optional)
├── infra/
│   ├── __init__.py
│   ├── parallelize.py         # TP, FSDP, compile application
│   └── pipeline.py            # PP application (optional)
├── train_configs/
│   ├── debug_model.toml
│   └── your_model_XB.toml
├── __init__.py                # TrainSpec registration
└── README.md
```

## Step 1: Define Model Arguments

Inherit from `BaseModelArgs`:

```python
# model/args.py
from dataclasses import dataclass

from torchtitan.protocols.model import BaseModelArgs


@dataclass
class YourModelArgs(BaseModelArgs):
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 128256

    def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
        """Return (num_params, flops_per_token) for throughput calculation."""
        nparams = self.vocab_size * self.dim + ...  # Calculate params
        flops = 6 * nparams  # Approximate: 6 * params for forward + backward
        return nparams, flops

    def update_from_config(self, job_config) -> "YourModelArgs":
        """Update args from the training config."""
        # Override specific args from job_config if needed
        return self
```
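A quick way to sanity-check `get_nparams_and_flops` is to compare its parameter count against the instantiated module. A sketch, assuming the `YourModel` class defined in Step 2 below:

```python
args = YourModelArgs()
model = YourModel(args)

# Compare the reported count against the actual module parameters
actual = sum(p.numel() for p in model.parameters())
reported, flops_per_token = args.get_nparams_and_flops(seq_len=8192)
print(f"actual={actual:,} reported={reported:,} flops/token~{flops_per_token:,}")
```
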
## Step 2: Define Model

Inherit from `ModelProtocol`:

```python
# model/model.py
import torch
import torch.nn as nn

from torchtitan.protocols.model import ModelProtocol

from .args import YourModelArgs


class YourModel(ModelProtocol):
    def __init__(self, args: YourModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleDict({
            str(i): TransformerBlock(args) for i in range(args.n_layers)
        })
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        h = self.norm(h)
        return self.output(h)

    def init_weights(self):
        """Initialize weights recursively."""
        for module in self.modules():
            if hasattr(module, 'init_weights') and module is not self:
                module.init_weights()
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
```

**Important guidelines**:
- Write single-device model code (parallelism is applied externally)
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting layers for PP)
- Make input/output layers optional for PP compatibility
- Define `init_weights()` recursively

## Step 3: Parallelize Function

```python
# infra/parallelize.py
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module


def parallelize_your_model(
    model: YourModel,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    # Apply in this order: TP -> AC -> compile -> FSDP

    # 1. Tensor Parallelism
    if parallel_dims.tp_enabled:
        apply_tp(model, world_mesh["tp"], job_config)

    # 2. Activation Checkpointing
    if job_config.activation_checkpoint.mode == "full":
        apply_ac(model, job_config)

    # 3. torch.compile
    if job_config.compile.enable:
        model = torch.compile(model)

    # 4. FSDP
    if parallel_dims.dp_enabled:
        apply_fsdp(model, world_mesh["dp"], job_config)

    return model
```

## Step 4: Create TrainSpec

```python
# __init__.py
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec

from .infra.parallelize import parallelize_your_model
from .model.args import YourModelArgs
from .model.model import YourModel

MODEL_CONFIGS = {
    "8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
    "70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
}


def get_train_spec(flavor: str) -> TrainSpec:
    return TrainSpec(
        model_cls=YourModel,
        model_args=MODEL_CONFIGS[flavor],
        parallelize_fn=parallelize_your_model,
        pipeline_fn=None,  # Or your_pipeline_fn for PP
        build_optimizer_fn=build_optimizer,  # Reuse existing
        build_lr_scheduler_fn=build_lr_scheduler,  # Reuse existing
        build_dataloader_fn=build_dataloader,  # Reuse existing
        build_tokenizer_fn=build_tokenizer,  # Reuse existing
        build_loss_fn=build_loss,  # Reuse existing
        state_dict_adapter=None,  # Or YourStateDictAdapter
    )


# Register so train.py can find it
register_train_spec("your_model", get_train_spec)
```

## Step 5: State Dict Adapter (Optional)

For HuggingFace checkpoint conversion:

```python
# model/state_dict_adapter.py
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter


class YourStateDictAdapter(BaseStateDictAdapter):
    def to_hf(self, state_dict: dict) -> dict:
        """Convert a torchtitan state dict to HF format."""
        hf_state_dict = {}
        for key, value in state_dict.items():
            hf_key = self._convert_key_to_hf(key)
            hf_state_dict[hf_key] = value
        return hf_state_dict

    def from_hf(self, state_dict: dict) -> dict:
        """Convert an HF state dict to torchtitan format."""
        tt_state_dict = {}
        for key, value in state_dict.items():
            tt_key = self._convert_key_from_hf(key)
            tt_state_dict[tt_key] = value
        return tt_state_dict
```

## Step 6: Training Config

```toml
# train_configs/your_model_8b.toml
[job]
dump_folder = "./outputs"
description = "Your Model 8B training"

[model]
name = "your_model"
flavor = "8B"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
local_batch_size = 2
seq_len = 8192
steps = 1000
dataset = "c4"

[parallelism]
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
```

## Step 7: Register Model

Add to `torchtitan/models/__init__.py`:

```python
from .your_model import get_train_spec as get_your_model_train_spec

MODEL_REGISTRY["your_model"] = get_your_model_train_spec
```
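
With the model registered, training runs through the standard entry point. A sketch, assuming the config from Step 6 and TorchTitan's usual `run_train.sh` launcher (the same pattern used for the stock models):

```bash
CONFIG_FILE="./torchtitan/models/your_model/train_configs/your_model_8b.toml" ./run_train.sh
```
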
## Testing

### Numerics Test

Compare output with the HuggingFace implementation:

```python
def test_numerics():
    # Load the same checkpoint into both implementations
    tt_model = YourModel(args).load_checkpoint(...)
    hf_model = HFYourModel.from_pretrained(...)

    # Compare outputs
    input_ids = torch.randint(0, vocab_size, (1, 128))
    tt_output = tt_model(input_ids)
    hf_output = hf_model(input_ids).logits

    torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
```

### Loss Convergence

Compare loss curves with a verified baseline (see `docs/converging.md`).

### Performance Benchmark

Add a benchmark config to the `benchmarks/` folder.

## Guiding Principles

1. **Readability over flexibility**: Don't over-abstract
2. **Minimal model changes**: Parallelism is applied externally
3. **Clean, minimal codebase**: Reuse existing components where possible
4. **Single-device semantics**: Model code should work on a single GPU
@@ -0,0 +1,133 @@
# Float8 Training in TorchTitan

Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensor-core speedup outweighs the dynamic quantization overhead.

## Hardware Requirements

- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
- Blackwell GPUs for MXFP8 training

## Installation

```bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
```

## Usage: Tensorwise Scaling

Standard Float8 with tensorwise dynamic scaling:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters="quantize.linear.float8" \
  --quantize.linear.float8.enable_fsdp_float8_all_gather \
  --quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
  --compile.enable
```

### Key Arguments

| Argument | Description |
|----------|-------------|
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |

## Usage: Rowwise Scaling

Higher accuracy than tensorwise scaling:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters="quantize.linear.float8" \
  --quantize.linear.float8.recipe_name rowwise \
  --compile.enable
```

## Filtering Layers

Not all layers benefit from Float8. Filter out small layers:

```bash
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
```

### Auto-filtering

Automatically skip layers too small to benefit:

```bash
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
```

Thresholds are based on H100 microbenchmarks where the speedup exceeds the overhead.

## TOML Configuration

```toml
[model]
converters = ["quantize.linear.float8"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]

[compile]
enable = true
components = ["model", "loss"]
```

## How Float8 Works with Distributed Training

### Single Device

Cast input and weight to float8 inside the forward pass before calling `torch._scaled_mm`:

```python
# Float8 matmul requires scales
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
```

### FSDP + Float8

1. Cast sharded high-precision weights (1/N per rank) to float8
2. Perform the float8 all-gather (saves bandwidth vs bf16/fp32)
3. Communicate `max(abs)` across ranks for scale computation
4. At forward start, have unsharded float8 weights ready

**Net benefit**: Float8 all-gather + amax communication can beat a bf16/fp32 all-gather, depending on world size and message size.

### TP + Float8

- **Input**: Cast the sharded input to float8, all-gather in float8
- **Weights**: Communicate `max(abs)` for sharded weights
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales

## Scaling Strategies

| Strategy | Status | Description |
|----------|--------|-------------|
| Tensorwise dynamic | Stable | Single scale per tensor |
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |

## Performance Gains

From benchmarks on H100:

| Configuration | TPS/GPU | vs Baseline |
|---------------|---------|-------------|
| FSDP only | 5,762 | - |
| FSDP + compile | 6,667 | +16% |
| FSDP + compile + Float8 | 8,532 | +48% |

## Determining Float8 Benefit

Check the [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M, N, K sizes.

Rule of thumb: GEMMs with K, N > 4096 typically benefit from Float8.
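One way to apply that rule of thumb is to scan a model for `nn.Linear` layers whose K and N dimensions clear the threshold, then pass the rest to `filter_fqns`. A minimal sketch (the 4096 cutoff is the heuristic above, not a hard limit):

```python
import torch.nn as nn

def float8_candidates(model: nn.Module, threshold: int = 4096):
    """Yield names of linear layers large enough to likely benefit from Float8."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # For y = x @ W^T: K = in_features, N = out_features
            if module.in_features > threshold and module.out_features > threshold:
                yield name
```
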
## MXFP8 Training (Blackwell)

For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
@@ -0,0 +1,126 @@
# FSDP2 in TorchTitan

## Why FSDP2?

FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and a simpler implementation.

### Key improvements over FSDP1

- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
- **Simplified API**: Fewer arguments, no wrapper class

### Performance

On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, while matching the same loss curve.

## API Reference

```python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy

@contract(state_cls=FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:
    ...
```

## Sharding Strategies (ZeRO Equivalents)

| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|---------------------|------------------|-----------|
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
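The `reshard_after_forward` argument is the switch between these modes. A minimal sketch of picking a strategy at wrap time (assumes `model` is already built; choose one call):

```python
from torch.distributed._composable.fsdp import fully_shard

# ZeRO-3: free unsharded params after forward, re-gather in backward
fully_shard(model, reshard_after_forward=True)

# ZeRO-2: keep unsharded params alive between forward and backward
# fully_shard(model, reshard_after_forward=False)

# ZeRO++ hpZ-style: reshard to a smaller group (e.g. 8 ranks) after forward
# fully_shard(model, reshard_after_forward=8)
```
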
## Meta-Device Initialization

FSDP2 supports materializing tensors onto the GPU _after_ sharding:

```python
import itertools

import torch

# Initialize on the meta device (no memory allocated)
with torch.device("meta"):
    model = Transformer()

# Apply FSDP2 sharding
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)

# Parameters are still on the meta device
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")

# Allocate sharded parameters on the GPU
model.to_empty(device="cuda")

# Initialize weights
model.init_weights()
```

## State Dict Differences

| Operation | FSDP1 | FSDP2 |
|-----------|-------|-------|
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |

## Mixed Precision

```python
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

fully_shard(model, mp_policy=mp_policy)
```

## HSDP (Hybrid Sharded Data Parallel)

For 2D parallelism with replication + sharding:

```python
from torch.distributed.device_mesh import init_device_mesh

# Replicate across 4 groups, shard within 8 GPUs each
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))

fully_shard(model, mesh=mesh)
```

## Configuration in TorchTitan

```toml
[parallelism]
# FSDP sharding degree (-1 = auto, use all available GPUs)
data_parallel_shard_degree = -1

# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
data_parallel_replicate_degree = 1
```

## Removed Arguments from FSDP1

These FSDP1 arguments are no longer needed:

- `auto_wrap_policy`: Apply `fully_shard` directly to modules
- `backward_prefetch`: Always uses BACKWARD_PRE
- `param_init_fn`: Use meta-device initialization
- `device_id`: Uses the mesh's device automatically
- `sync_module_states`: Not needed with DTensor
- `limit_all_gathers`: The new memory management doesn't need it
- `use_orig_params`: Always true (no FlatParameter)
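For comparison, the same wrapping expressed in both APIs. A sketch (FSDP1's wrapper is `torch.distributed.fsdp.FullyShardedDataParallel`; `my_policy` stands in for whatever auto-wrap policy was used):

```python
# FSDP1: wrapper class plus an auto-wrap policy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, auto_wrap_policy=my_policy, use_orig_params=True)

# FSDP2: apply fully_shard directly, innermost modules first
from torch.distributed._composable.fsdp import fully_shard
for block in model.layers.values():
    fully_shard(block)
fully_shard(model)
```
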
458
protected/skills-backup/mlops/training/trl-fine-tuning/SKILL.md
Normal file
458
protected/skills-backup/mlops/training/trl-fine-tuning/SKILL.md
Normal file
@@ -0,0 +1,458 @@
---
name: fine-tuning-with-trl
description: Fine-tune LLMs using reinforcement learning with TRL - SFT for instruction tuning, DPO for preference alignment, PPO/GRPO for reward optimization, and reward model training. Use when you need RLHF, want to align a model with preferences, or train from human feedback. Works with HuggingFace Transformers.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [trl, transformers, datasets, peft, accelerate, torch]
metadata:
  hermes:
    tags: [Post-Training, TRL, Reinforcement Learning, Fine-Tuning, SFT, DPO, PPO, GRPO, RLHF, Preference Alignment, HuggingFace]

---

# TRL - Transformer Reinforcement Learning

## Quick start

TRL provides post-training methods for aligning language models with human preferences.

**Installation**:
```bash
pip install trl transformers datasets peft accelerate
```

**Supervised Fine-Tuning** (instruction tuning):
```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,  # Prompt-completion pairs
)
trainer.train()
```

**DPO** (align with preferences):
```python
from trl import DPOTrainer, DPOConfig

config = DPOConfig(output_dir="model-dpo", beta=0.1)
trainer = DPOTrainer(
    model=model,
    args=config,
    train_dataset=preference_dataset,  # chosen/rejected pairs
    processing_class=tokenizer
)
trainer.train()
```

## Common workflows

### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)

Complete pipeline from base model to human-aligned model.

Copy this checklist:

```
RLHF Training:
- [ ] Step 1: Supervised fine-tuning (SFT)
- [ ] Step 2: Train reward model
- [ ] Step 3: PPO reinforcement learning
- [ ] Step 4: Evaluate aligned model
```

**Step 1: Supervised fine-tuning**

Train the base model on instruction-following data:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# Load model
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Load instruction dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

# Configure training
training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-5,
    logging_steps=10,
    save_strategy="epoch"
)

# Train
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer
)
trainer.train()
trainer.save_model()
```

**Step 2: Train reward model**

Train a model to predict human preferences:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
from datasets import load_dataset

# Load the SFT model as the base
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen2.5-0.5B-SFT",
    num_labels=1  # Single reward score
)
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-SFT")

# Load preference data (chosen/rejected pairs)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Configure training
training_args = RewardConfig(
    output_dir="Qwen2.5-0.5B-Reward",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=1e-5
)

# Train reward model
trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset
)
trainer.train()
trainer.save_model()
```

**Step 3: PPO reinforcement learning**

Optimize the policy using the reward model:

```bash
python -m trl.scripts.ppo \
    --model_name_or_path Qwen2.5-0.5B-SFT \
    --reward_model_path Qwen2.5-0.5B-Reward \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --output_dir Qwen2.5-0.5B-PPO \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 64 \
    --total_episodes 10000
```

**Step 4: Evaluate**

```python
from transformers import pipeline

# Load the aligned model
generator = pipeline("text-generation", model="Qwen2.5-0.5B-PPO")

# Test
prompt = "Explain quantum computing to a 10-year-old"
output = generator(prompt, max_length=200)[0]["generated_text"]
print(output)
```

### Workflow 2: Simple preference alignment with DPO

Align a model with preferences without training a reward model.

Copy this checklist:

```
DPO Training:
- [ ] Step 1: Prepare preference dataset
- [ ] Step 2: Configure DPO
- [ ] Step 3: Train with DPOTrainer
- [ ] Step 4: Evaluate alignment
```

**Step 1: Prepare preference dataset**

Dataset format:
```json
{
  "prompt": "What is the capital of France?",
  "chosen": "The capital of France is Paris.",
  "rejected": "I don't know."
}
```

Load the dataset:
```python
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
# Or load your own
# dataset = load_dataset("json", data_files="preferences.json")
```

**Step 2: Configure DPO**

```python
from trl import DPOConfig

config = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=5e-7,
    beta=0.1,  # KL penalty strength
    max_prompt_length=512,
    max_length=1024,
    logging_steps=10
)
```

**Step 3: Train with DPOTrainer**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

trainer = DPOTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer
)

trainer.train()
trainer.save_model()
```

**CLI alternative**:
```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO \
    --per_device_train_batch_size 4 \
    --learning_rate 5e-7 \
    --beta 0.1
```

### Workflow 3: Memory-efficient online RL with GRPO

Train with reinforcement learning using minimal memory.

Copy this checklist:

```
GRPO Training:
- [ ] Step 1: Define reward function
- [ ] Step 2: Configure GRPO
- [ ] Step 3: Train with GRPOTrainer
```

**Step 1: Define reward function**

```python
def reward_function(completions, **kwargs):
    """
    Compute rewards for completions.

    Args:
        completions: List of generated texts

    Returns:
        List of reward scores (floats)
    """
    rewards = []
    for completion in completions:
        # Example: reward based on length and unique words
        score = len(completion.split())  # Favor longer responses
        score += len(set(completion.lower().split()))  # Reward unique words
        rewards.append(score)
    return rewards
```

Or use a reward model:
```python
from transformers import pipeline

reward_model = pipeline("text-classification", model="reward-model-path")

def reward_from_model(completions, prompts, **kwargs):
    # Combine prompt + completion
    full_texts = [p + c for p, c in zip(prompts, completions)]
    # Get reward scores
    results = reward_model(full_texts)
    return [r["score"] for r in results]
```

**Step 2: Configure GRPO**

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="Qwen2-GRPO",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=1e-5,
    num_generations=4,  # Generate 4 completions per prompt
    max_new_tokens=128
)
```

**Step 3: Train with GRPOTrainer**

```python
from datasets import load_dataset
from trl import GRPOTrainer

# Load a prompt-only dataset
dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_function,  # Your reward function
    args=config,
    train_dataset=dataset
)

trainer.train()
```

**CLI**:
```bash
trl grpo \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/tldr \
    --output_dir Qwen2-GRPO \
    --num_generations 4
```

## When to use vs alternatives

**Use TRL when:**
- You need to align a model with human preferences
- You have preference data (chosen/rejected pairs)
- You want to use reinforcement learning (PPO, GRPO)
- You need reward model training
- You are doing RLHF (the full pipeline)

**Method selection**:
- **SFT**: Have prompt-completion pairs, want basic instruction following
- **DPO**: Have preferences, want simple alignment (no reward model needed)
- **PPO**: Have a reward model, need maximum control over RL
- **GRPO**: Memory-constrained, want online RL
- **Reward Model**: Building an RLHF pipeline, need to score generations

**Use alternatives instead:**
- **HuggingFace Trainer**: Basic fine-tuning without RL
- **Axolotl**: YAML-based training configuration
- **LitGPT**: Educational, minimal fine-tuning
- **Unsloth**: Fast LoRA training

## Common issues

**Issue: OOM during DPO training**

Reduce batch size and sequence length:
```python
config = DPOConfig(
    per_device_train_batch_size=1,  # Reduce from 4
    max_length=512,  # Reduce from 1024
    gradient_accumulation_steps=8  # Maintain effective batch size
)
```

Or use gradient checkpointing:
```python
model.gradient_checkpointing_enable()
```

**Issue: Poor alignment quality**

Tune the beta parameter:
```python
# Higher beta = more conservative (stays closer to reference)
config = DPOConfig(beta=0.5)  # Default 0.1

# Lower beta = more aggressive alignment
config = DPOConfig(beta=0.01)
```

**Issue: Reward model not learning**

Check the loss type and learning rate:
```python
config = RewardConfig(
    learning_rate=1e-5,  # Try a different LR
    num_train_epochs=3  # Train longer
)
```

Ensure the preference dataset has clear winners:
```python
# Verify dataset
print(dataset[0])
# Should have a clearly better chosen than rejected
```

**Issue: PPO training unstable**

Adjust the KL coefficient:
```python
from trl import PPOConfig

config = PPOConfig(
    kl_coef=0.1,  # Increase from 0.05
    cliprange=0.1  # Reduce from 0.2
)
```

## Advanced topics

**SFT training guide**: See [references/sft-training.md](references/sft-training.md) for dataset formats, chat templates, packing strategies, and multi-GPU training.

**DPO variants**: See [references/dpo-variants.md](references/dpo-variants.md) for IPO, cDPO, RPO, and other DPO loss functions with recommended hyperparameters.

**Reward modeling**: See [references/reward-modeling.md](references/reward-modeling.md) for outcome vs process rewards, Bradley-Terry loss, and reward model evaluation.

**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.

## Hardware requirements

- **GPU**: NVIDIA (CUDA required)
- **VRAM**: Depends on model and method
  - SFT 7B: 16GB (with LoRA)
  - DPO 7B: 24GB (stores a reference model)
  - PPO 7B: 40GB (policy + reward model)
  - GRPO 7B: 24GB (more memory efficient)
- **Multi-GPU**: Supported via `accelerate`
- **Mixed precision**: BF16 recommended (A100/H100)

**Memory optimization** (combined in the sketch below):
- Use LoRA/QLoRA for all methods
- Enable gradient checkpointing
- Use smaller batch sizes with gradient accumulation
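These levers combine naturally. A minimal sketch of loading a 4-bit base model with LoRA adapters before handing it to any TRL trainer (the quantization settings shown are common QLoRA defaults, not TRL requirements):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig

# 4-bit base model (QLoRA)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B", quantization_config=bnb_config
)
model.gradient_checkpointing_enable()

# LoRA adapters, passed to the trainer via peft_config
peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
```
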
## Resources

- Docs: https://huggingface.co/docs/trl/
- GitHub: https://github.com/huggingface/trl
- Papers:
  - "Training language models to follow instructions with human feedback" (InstructGPT, 2022)
  - "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (DPO, 2023)
  - "Group Relative Policy Optimization" (GRPO, 2024)
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts

@@ -0,0 +1,227 @@
# DPO Variants

Complete guide to Direct Preference Optimization loss variants in TRL.

## Overview

DPO optimizes models using preference data (chosen/rejected pairs). TRL supports 10+ loss variants for different scenarios.

## Loss Types

### 1. Sigmoid (Standard DPO)

**Formula**: `-log(sigmoid(β * logits))`

**When to use**: Default choice, general preference alignment

**Config**:
```python
DPOConfig(
    loss_type="sigmoid",
    beta=0.1,  # KL penalty
    per_device_train_batch_size=64,
    learning_rate=1e-6
)
```
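Here `logits` is the policy-vs-reference log-ratio margin between the chosen and rejected completions. A worked sketch of the loss on raw sequence log-probabilities (tensor values are illustrative):

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Sequence log-probs under the policy and the frozen reference model
pi_chosen, pi_rejected = torch.tensor([-10.0]), torch.tensor([-14.0])
ref_chosen, ref_rejected = torch.tensor([-11.0]), torch.tensor([-13.0])

logits = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
loss = -F.logsigmoid(beta * logits).mean()  # standard DPO (sigmoid) loss
```
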
### 2. IPO (Identity Policy Optimization)

**Formula**: `(logits - 1/(2β))²`

**When to use**: Better theoretical foundation, reduces overfitting

**Config**:
```python
DPOConfig(
    loss_type="ipo",
    beta=0.1,
    per_device_train_batch_size=90,
    learning_rate=1e-2
)
```

### 3. Hinge (SLiC)

**Formula**: `ReLU(1 - β * logits)`

**When to use**: Margin-based objective

**Config**:
```python
DPOConfig(
    loss_type="hinge",
    beta=0.1,
    per_device_train_batch_size=512,
    learning_rate=1e-4
)
```

### 4. Robust DPO

**Formula**: Sigmoid with label smoothing for noise robustness

**When to use**: Noisy preference labels

**Config**:
```python
DPOConfig(
    loss_type="robust",
    beta=0.01,
    label_smoothing=0.1,  # Noise probability
    per_device_train_batch_size=16,
    learning_rate=1e-3,
    max_prompt_length=128,
    max_length=512
)
```

### 5. BCO Pair (Binary Classification)

**Formula**: Train a binary classifier (chosen=1, rejected=0)

**When to use**: Pairwise preference data

**Config**:
```python
DPOConfig(
    loss_type="bco_pair",
    beta=0.01,
    per_device_train_batch_size=128,
    learning_rate=5e-7,
    max_prompt_length=1536,
    max_completion_length=512
)
```

### 6. SPPO Hard

**Formula**: Push chosen→0.5, rejected→-0.5

**When to use**: Nash equilibrium, sparse data

**Config**:
```python
DPOConfig(
    loss_type="sppo_hard",
    beta=0.1
)
```

### 7. DiscoPOP

**Formula**: Log-Ratio Modulated Loss

**When to use**: Automated loss discovery

**Config**:
```python
DPOConfig(
    loss_type="discopop",
    beta=0.05,
    discopop_tau=0.05,
    per_device_train_batch_size=64,
    learning_rate=5e-7
)
```

### 8. APO Zero

**Formula**: Increase chosen likelihood, decrease rejected likelihood

**When to use**: The model is worse than the winning outputs

**Config**:
```python
DPOConfig(
    loss_type="apo_zero",
    beta=0.1,
    per_device_train_batch_size=64,
    learning_rate=2e-7,
    max_prompt_length=512,
    max_completion_length=512
)
```

### 9. APO Down

**Formula**: Decrease both likelihoods, emphasizing the rejected reduction

**When to use**: The model is better than the winning outputs

**Config**:
```python
DPOConfig(
    loss_type="apo_down",
    beta=0.1,
    # Same hyperparameters as apo_zero
)
```

### 10. AOT & AOT Pair

**Formula**: Distributional alignment via stochastic dominance

**When to use**:
- `aot_pair`: Paired preference data
- `aot`: Unpaired data

**Config**:
```python
DPOConfig(
    loss_type="aot_pair",  # or "aot"
    beta=0.1,
    label_smoothing=0.0
)
```

## Multi-Loss Training

Combine multiple losses:

```python
DPOConfig(
    loss_type=["sigmoid", "ipo"],
    loss_weights=[0.7, 0.3],  # Weighted combination
    beta=0.1
)
```

## Key Parameters

### Beta (β)

Controls deviation from the reference model:
- **Higher** (0.5): More conservative, stays close to the reference
- **Lower** (0.01): More aggressive alignment
- **Default**: 0.1

### Label Smoothing

For robust DPO:
- **0.0**: No smoothing (default)
- **0.1-0.3**: Moderate noise robustness
- **0.5**: Maximum noise tolerance

### Max Lengths

- `max_prompt_length`: 128-1536
- `max_completion_length`: 128-512
- `max_length`: Total sequence (1024-2048)

## Comparison Table

| Loss | Speed | Stability | Best For |
|------|-------|-----------|----------|
| Sigmoid | Fast | Good | **General use** |
| IPO | Fast | Better | Overfitting issues |
| Hinge | Fast | Good | Margin objectives |
| Robust | Fast | Best | Noisy data |
| BCO | Medium | Good | Binary classification |
| DiscoPOP | Fast | Good | New architectures |
| APO | Fast | Good | Model quality matching |

## References

- DPO paper: https://arxiv.org/abs/2305.18290
- IPO paper: https://arxiv.org/abs/2310.12036
- TRL docs: https://huggingface.co/docs/trl/dpo_trainer

@@ -0,0 +1,82 @@
# Online RL Methods

Guide to online reinforcement learning with PPO, GRPO, RLOO, and OnlineDPO.

## Overview

Online RL generates completions during training and optimizes based on rewards.

## PPO (Proximal Policy Optimization)

Classic RL algorithm for LLM alignment.

### Basic Usage

```bash
python -m trl.scripts.ppo \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --reward_model_path reward-model \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --output_dir model-ppo \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 64 \
    --total_episodes 10000 \
    --num_ppo_epochs 4 \
    --kl_coef 0.05
```

### Key Parameters

- `kl_coef`: KL penalty (0.05-0.2)
- `num_ppo_epochs`: Epochs per batch (2-4)
- `cliprange`: PPO clip range (0.1-0.3)
- `vf_coef`: Value function coefficient (0.1)

## GRPO (Group Relative Policy Optimization)

Memory-efficient online RL.

### Basic Usage

```python
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

# Define a reward function
def reward_func(completions, **kwargs):
    return [len(set(c.split())) for c in completions]

config = GRPOConfig(
    output_dir="model-grpo",
    num_generations=4,  # Completions per prompt
    max_new_tokens=128
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=config,
    train_dataset=load_dataset("trl-lib/tldr", split="train")
)
trainer.train()
```

### Key Parameters

- `num_generations`: 2-8 completions
- `max_new_tokens`: 64-256
- Learning rate: 1e-5 to 1e-4

## Memory Comparison

| Method | Memory (7B) | Speed | Use Case |
|--------|-------------|-------|----------|
| PPO | 40GB | Medium | Maximum control |
| GRPO | 24GB | Fast | **Memory-constrained** |
| OnlineDPO | 28GB | Fast | No reward model |

## References

- PPO paper: https://arxiv.org/abs/1707.06347
- GRPO paper: https://arxiv.org/abs/2402.03300
- TRL docs: https://huggingface.co/docs/trl/

@@ -0,0 +1,122 @@
# Reward Modeling

Guide to training reward models with TRL for RLHF pipelines.

## Overview

Reward models score completions based on human preferences. They are used in:
- PPO training (RL feedback)
- GRPO online RL
- Completion ranking

## Basic Training

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
from datasets import load_dataset

# Load model (num_labels=1 for a single reward score)
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    num_labels=1
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Load preference dataset (chosen/rejected pairs)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Configure
config = RewardConfig(
    output_dir="Qwen2.5-Reward",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=1e-5
)

# Train
trainer = RewardTrainer(
    model=model,
    args=config,
    processing_class=tokenizer,
    train_dataset=dataset
)
trainer.train()
```

## Dataset Format

Required fields:
```json
{
  "prompt": "Question or instruction",
  "chosen": "Better response",
  "rejected": "Worse response"
}
```

## Bradley-Terry Loss

Default loss function:
```
loss = -log(sigmoid(reward_chosen - reward_rejected))
```

Learns to score chosen > rejected.
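A worked sketch of that loss on a batch of paired reward scores (values are illustrative; `RewardTrainer` computes this internally):

```python
import torch
import torch.nn.functional as F

# Scalar rewards the model assigned to each chosen/rejected pair
reward_chosen = torch.tensor([1.2, 0.3])
reward_rejected = torch.tensor([0.4, 0.9])

# Bradley-Terry pairwise loss
loss = -F.logsigmoid(reward_chosen - reward_rejected).mean()
```
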
## Using Reward Models

### Inference

```python
from transformers import pipeline

# Load the trained reward model
reward_pipe = pipeline("text-classification", model="Qwen2.5-Reward")

# Score completions
texts = ["Good answer", "Bad answer"]
scores = reward_pipe(texts)
print(scores)  # Higher score = better
```

### In PPO

```python
from trl import PPOTrainer, PPOConfig

config = PPOConfig(
    reward_model_path="Qwen2.5-Reward"  # Use the trained reward model
)

trainer = PPOTrainer(
    model=policy_model,
    config=config,
    # Reward model loaded automatically
)
```

## Hyperparameters

| Model Size | Learning Rate | Batch Size | Epochs |
|------------|---------------|------------|--------|
| <1B | 2e-5 | 4-8 | 1-2 |
| 1-7B | 1e-5 | 2-4 | 1 |
| 7-13B | 5e-6 | 1-2 | 1 |

## Evaluation

Check reward separation:
```python
# Chosen should score higher than rejected
chosen_rewards = model(**chosen_inputs).logits
rejected_rewards = model(**rejected_inputs).logits

accuracy = (chosen_rewards > rejected_rewards).float().mean()
print(f"Accuracy: {accuracy:.2%}")  # Target: >80%
```

## References

- InstructGPT paper: https://arxiv.org/abs/2203.02155
- TRL docs: https://huggingface.co/docs/trl/reward_trainer
@@ -0,0 +1,168 @@
# SFT Training Guide

Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.

## Overview

SFT trains models on input-output pairs to minimize cross-entropy loss. Use it for:
- Instruction following
- Task-specific fine-tuning
- Chatbot training
- Domain adaptation

## Dataset Formats

### Format 1: Prompt-Completion

```json
[
  {
    "prompt": "What is the capital of France?",
    "completion": "The capital of France is Paris."
  }
]
```

### Format 2: Conversational (ChatML)

```json
[
  {
    "messages": [
      {"role": "user", "content": "What is Python?"},
      {"role": "assistant", "content": "Python is a programming language."}
    ]
  }
]
```

### Format 3: Text-only

```json
[
  {"text": "User: Hello\nAssistant: Hi! How can I help?"}
]
```

## Basic Training

```python
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Load model
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Load dataset
dataset = load_dataset("trl-lib/Capybara", split="train")

# Configure
config = SFTConfig(
    output_dir="Qwen2.5-SFT",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-5,
    save_strategy="epoch"
)

# Train
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    tokenizer=tokenizer
)
trainer.train()
```

## Chat Templates

Apply chat templates automatically:

```python
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,  # Messages format
    tokenizer=tokenizer
    # Chat template applied automatically
)
```

Or manually:
```python
def format_chat(example):
    messages = example["messages"]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}

dataset = dataset.map(format_chat)
```

## Packing for Efficiency

Pack multiple sequences into one to maximize GPU utilization:

```python
config = SFTConfig(
    packing=True,  # Enable packing
    max_seq_length=2048,
    dataset_text_field="text"
)
```

**Benefits**: 2-3× faster training
**Trade-off**: Slightly more complex batching

## Multi-GPU Training

```bash
accelerate launch --num_processes 4 train_sft.py
```

Or with config:
```python
config = SFTConfig(
    output_dir="model-sft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1
)
```

With 4 processes, this yields an effective batch size of 4 GPUs × 4 per device × 4 accumulation steps = 64.

## LoRA Fine-Tuning

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    lora_dropout=0.05,
    task_type="CAUSAL_LM"
)

trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    peft_config=lora_config  # Add LoRA
)
```

## Hyperparameters

| Model Size | Learning Rate | Batch Size | Epochs |
|------------|---------------|------------|--------|
| <1B | 5e-5 | 8-16 | 1-3 |
| 1-7B | 2e-5 | 4-8 | 1-2 |
| 7-13B | 1e-5 | 2-4 | 1 |
| 13B+ | 5e-6 | 1-2 | 1 |

## References

- TRL docs: https://huggingface.co/docs/trl/sft_trainer
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
83
protected/skills-backup/mlops/training/unsloth/SKILL.md
Normal file
83
protected/skills-backup/mlops/training/unsloth/SKILL.md
Normal file
@@ -0,0 +1,83 @@
---
name: unsloth
description: Expert guidance for fast fine-tuning with Unsloth - 2-5x faster training, 50-80% less memory, LoRA/QLoRA optimization
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [unsloth, torch, transformers, trl, datasets, peft]
metadata:
  hermes:
    tags: [Fine-Tuning, Unsloth, Fast Training, LoRA, QLoRA, Memory-Efficient, Optimization, Llama, Mistral, Gemma, Qwen]

---

# Unsloth Skill

Comprehensive assistance with Unsloth development, generated from official documentation.

## When to Use This Skill

This skill should be triggered when:
- Working with Unsloth
- Asking about Unsloth features or APIs
- Implementing Unsloth solutions
- Debugging Unsloth code
- Learning Unsloth best practices

## Quick Reference

### Common Patterns

*Quick reference patterns will be added as you use the skill.*

## Reference Files

This skill includes comprehensive documentation in `references/`:

- **llms-txt.md** - Llms-Txt documentation

Use `view` to read specific reference files when detailed information is needed.

## Working with This Skill

### For Beginners
Start with the getting_started or tutorials reference files for foundational concepts.

### For Specific Features
Use the appropriate category reference file (api, guides, etc.) for detailed information.

### For Code Examples
The quick reference section above contains common patterns extracted from the official docs.

## Resources

### references/
Organized documentation extracted from official sources. These files contain:
- Detailed explanations
- Code examples with language annotations
- Links to original documentation
- Table of contents for quick navigation

### scripts/
Add helper scripts here for common automation tasks.

### assets/
Add templates, boilerplate, or example projects here.

## Notes

- This skill was automatically generated from official documentation
- Reference files preserve the structure and examples from source docs
- Code examples include language detection for better syntax highlighting
- Quick reference patterns are extracted from common usage examples in the docs

## Updating

To refresh this skill with updated documentation:
1. Re-run the scraper with the same configuration
2. The skill will be rebuilt with the latest information

<!-- Trigger re-upload 1763621536 -->
@@ -0,0 +1,7 @@
# Unsloth Documentation Index

## Categories

### Llms-Txt
**File:** `llms-txt.md`
**Pages:** 136
16799
protected/skills-backup/mlops/training/unsloth/references/llms-full.md
Normal file
16799
protected/skills-backup/mlops/training/unsloth/references/llms-full.md
Normal file
File diff suppressed because one or more lines are too long
12044
protected/skills-backup/mlops/training/unsloth/references/llms-txt.md
Normal file
12044
protected/skills-backup/mlops/training/unsloth/references/llms-txt.md
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,82 @@
# Unsloth Documentation

## Unsloth Documentation

- [Unsloth Docs](/get-started/unsloth-docs.md): Train your own model with Unsloth, an open-source framework for LLM fine-tuning and reinforcement learning.
- [Beginner? Start here!](/get-started/beginner-start-here.md)
- [Unsloth Requirements](/get-started/beginner-start-here/unsloth-requirements.md): Here are Unsloth's requirements, including system and GPU VRAM requirements.
- [FAQ + Is Fine-tuning Right For Me?](/get-started/beginner-start-here/faq-+-is-fine-tuning-right-for-me.md): If you're unsure whether fine-tuning is right for you, see here! Learn about fine-tuning misconceptions, how it compares to RAG, and more:
- [Unsloth Notebooks](/get-started/unsloth-notebooks.md): Explore our catalog of Unsloth notebooks:
- [All Our Models](/get-started/all-our-models.md)
- [Install & Update](/get-started/install-and-update.md): Learn to install Unsloth locally or online.
- [Updating](/get-started/install-and-update/updating.md): To update or use an old version of Unsloth, follow the steps below:
- [Pip Install](/get-started/install-and-update/pip-install.md): To install Unsloth locally via Pip, follow the steps below:
- [Docker](/get-started/install-and-update/docker.md): Install Unsloth using our official Docker container
- [Windows Installation](/get-started/install-and-update/windows-installation.md): See how to install Unsloth on Windows with or without WSL.
- [AMD](/get-started/install-and-update/amd.md): Fine-tune with Unsloth on AMD GPUs.
- [Conda Install](/get-started/install-and-update/conda-install.md): To install Unsloth locally on Conda, follow the steps below:
- [Google Colab](/get-started/install-and-update/google-colab.md): To install and run Unsloth on Google Colab, follow the steps below:
- [Fine-tuning LLMs Guide](/get-started/fine-tuning-llms-guide.md): Learn all the basics and best practices of fine-tuning. Beginner-friendly.
- [What Model Should I Use?](/get-started/fine-tuning-llms-guide/what-model-should-i-use.md)
- [Datasets Guide](/get-started/fine-tuning-llms-guide/datasets-guide.md): Learn how to create & prepare a dataset for fine-tuning.
- [LoRA Hyperparameters Guide](/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide.md): Optimal LoRA rank, alpha, number of epochs, batch size & gradient accumulation, QLoRA vs LoRA, target modules, and more!
- [Tutorial: How to Finetune Llama-3 and Use In Ollama](/get-started/fine-tuning-llms-guide/tutorial-how-to-finetune-llama-3-and-use-in-ollama.md): Beginner's Guide for creating a customized personal assistant (like ChatGPT) to run locally on Ollama
- [Reinforcement Learning (RL) Guide](/get-started/reinforcement-learning-rl-guide.md): Learn all about Reinforcement Learning (RL) and how to train your own DeepSeek-R1 reasoning model with Unsloth using GRPO. A complete guide from beginner to advanced.
- [Tutorial: Train your own Reasoning model with GRPO](/get-started/reinforcement-learning-rl-guide/tutorial-train-your-own-reasoning-model-with-grpo.md): Beginner's Guide to transforming a model like Llama 3.1 (8B) into a reasoning model by using Unsloth and GRPO.
- [Advanced RL Documentation](/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation.md): Advanced documentation settings when using Unsloth with GRPO.
- [Memory Efficient RL](/get-started/reinforcement-learning-rl-guide/memory-efficient-rl.md)
- [RL Reward Hacking](/get-started/reinforcement-learning-rl-guide/rl-reward-hacking.md): Learn what Reward Hacking is in Reinforcement Learning and how to counter it.
- [GSPO Reinforcement Learning](/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning.md): Train with GSPO (Group Sequence Policy Optimization) RL in Unsloth.
- [Reinforcement Learning - DPO, ORPO & KTO](/get-started/reinforcement-learning-rl-guide/reinforcement-learning-dpo-orpo-and-kto.md): To use the reward modelling functions for DPO, GRPO, ORPO or KTO with Unsloth, follow the steps below:
- [DeepSeek-OCR: How to Run & Fine-tune](/new/deepseek-ocr-how-to-run-and-fine-tune.md): Guide on how to run and fine-tune DeepSeek-OCR locally.
- [How to Fine-tune LLMs with Unsloth & Docker](/new/how-to-fine-tune-llms-with-unsloth-and-docker.md): Learn how to fine-tune LLMs or do Reinforcement Learning (RL) with Unsloth's Docker image.
- [Vision Reinforcement Learning (VLM RL)](/new/vision-reinforcement-learning-vlm-rl.md): Train Vision/multimodal models via GRPO and RL with Unsloth!
- [gpt-oss Reinforcement Learning](/new/gpt-oss-reinforcement-learning.md)
- [Tutorial: How to Train gpt-oss with RL](/new/gpt-oss-reinforcement-learning/tutorial-how-to-train-gpt-oss-with-rl.md): Learn to train OpenAI gpt-oss with GRPO to autonomously beat 2048 locally or on Colab.
- [Unsloth Dynamic GGUFs on Aider Polyglot](/new/unsloth-dynamic-ggufs-on-aider-polyglot.md): Performance of Unsloth Dynamic GGUFs on Aider Polyglot Benchmarks
- [Qwen3-VL: How to Run & Fine-tune](/models/qwen3-vl-how-to-run-and-fine-tune.md): Learn to fine-tune and run Qwen3-VL locally with Unsloth.
- [gpt-oss: How to Run & Fine-tune](/models/gpt-oss-how-to-run-and-fine-tune.md): Run & fine-tune OpenAI's new open-source models!
- [Tutorial: How to Fine-tune gpt-oss](/models/gpt-oss-how-to-run-and-fine-tune/tutorial-how-to-fine-tune-gpt-oss.md): Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
- [Long Context gpt-oss Training](/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md)
- [GLM-4.6: How to Run Locally](/models/glm-4.6-how-to-run-locally.md): A guide on how to run Z.ai's new GLM-4.6 model on your own local device!
- [IBM Granite 4.0](/models/ibm-granite-4.0.md): How to run IBM Granite-4.0 with Unsloth GGUFs on llama.cpp and Ollama, and how to fine-tune!
- [DeepSeek-V3.1: How to Run Locally](/models/deepseek-v3.1-how-to-run-locally.md): A guide on how to run DeepSeek-V3.1 and Terminus on your own local device!
- [Qwen3-Coder: How to Run Locally](/models/qwen3-coder-how-to-run-locally.md): Run Qwen3-Coder-30B-A3B-Instruct and 480B-A35B locally with Unsloth Dynamic quants.
- [Gemma 3: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune.md): How to run Gemma 3 effectively with our GGUFs on llama.cpp, Ollama, and Open WebUI, and how to fine-tune with Unsloth!
- [Gemma 3n: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune.md): Run Google's new Gemma 3n locally with Dynamic GGUFs on llama.cpp, Ollama, Open WebUI and fine-tune with Unsloth!
- [Qwen3: How to Run & Fine-tune](/models/qwen3-how-to-run-and-fine-tune.md): Learn to run & fine-tune Qwen3 locally with Unsloth + our Dynamic 2.0 quants
- [Qwen3-2507](/models/qwen3-how-to-run-and-fine-tune/qwen3-2507.md): Run Qwen3-30B-A3B-2507 and 235B-A22B Thinking and Instruct versions locally on your device!
- [Tutorials: How To Fine-tune & Run LLMs](/models/tutorials-how-to-fine-tune-and-run-llms.md): Learn how to run and fine-tune models for optimal performance 100% locally with Unsloth.
- [DeepSeek-R1-0528: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-0528-how-to-run-locally.md): A guide on how to run DeepSeek-R1-0528, including Qwen3, on your own local device!
- [Magistral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/magistral-how-to-run-and-fine-tune.md): Meet Magistral - Mistral's new reasoning models.
- [Llama 4: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/llama-4-how-to-run-and-fine-tune.md): How to run Llama 4 locally using our dynamic GGUFs, which recover accuracy compared to standard quantization.
- [Kimi K2: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/kimi-k2-how-to-run-locally.md): Guide on running Kimi K2 and Kimi-K2-Instruct-0905 on your own local device!
- [Grok 2](/models/tutorials-how-to-fine-tune-and-run-llms/grok-2.md): Run xAI's Grok 2 model locally!
- [Devstral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/devstral-how-to-run-and-fine-tune.md): Run and fine-tune Mistral Devstral 1.1, including Small-2507 and 2505.
- [DeepSeek-V3-0324: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-v3-0324-how-to-run-locally.md): How to run DeepSeek-V3-0324 locally using our dynamic quants, which recover accuracy
- [DeepSeek-R1: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally.md): A guide on how you can run our 1.58-bit Dynamic Quants for DeepSeek-R1 using llama.cpp.
- [DeepSeek-R1 Dynamic 1.58-bit](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally/deepseek-r1-dynamic-1.58-bit.md): See performance comparison tables for Unsloth's Dynamic GGUF Quants vs Standard IMatrix Quants.
|
||||
- [QwQ-32B: How to Run effectively](/models/tutorials-how-to-fine-tune-and-run-llms/qwq-32b-how-to-run-effectively.md): How to run QwQ-32B effectively with our bug fixes and without endless generations + GGUFs.
|
||||
- [Phi-4 Reasoning: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/phi-4-reasoning-how-to-run-and-fine-tune.md): Learn to run & fine-tune Phi-4 reasoning models locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Running & Saving Models](/basics/running-and-saving-models.md): Learn how to save your finetuned model so you can run it in your favorite inference engine.
|
||||
- [Saving to GGUF](/basics/running-and-saving-models/saving-to-gguf.md): Saving models to 16bit for GGUF so you can use it for Ollama, Jan AI, Open WebUI and more!
|
||||
- [Saving to Ollama](/basics/running-and-saving-models/saving-to-ollama.md)
|
||||
- [Saving to vLLM for deployment](/basics/running-and-saving-models/saving-to-vllm-for-deployment.md): Saving models to 16bit for vLLM deployment and serving
|
||||
- [Saving to SGLang for deployment](/basics/running-and-saving-models/saving-to-sglang-for-deployment.md): Saving models to 16bit for SGLang for deployment and serving
|
||||
- [Unsloth Inference](/basics/running-and-saving-models/unsloth-inference.md): Learn how to run your finetuned model with Unsloth's faster inference.
|
||||
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
|
||||
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
|
||||
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
|
||||
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
|
||||
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
|
||||
- [Fine-tuning LLMs with Blackwell, RTX 50 series & Unsloth](/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth.md): Learn how to fine-tune LLMs on NVIDIA's Blackwell RTX 50 series and B200 GPUs with our step-by-step guide.
|
||||
- [Multi-GPU Training with Unsloth](/basics/multi-gpu-training-with-unsloth.md): Learn how to fine-tune LLMs on multiple GPUs and parallelism with Unsloth.
|
||||
- [Finetuning from Last Checkpoint](/basics/finetuning-from-last-checkpoint.md): Checkpointing allows you to save your finetuning progress so you can pause it and then continue.
|
||||
- [Troubleshooting & FAQs](/basics/troubleshooting-and-faqs.md): Tips to solve issues, and frequently asked questions.
|
||||
- [Chat Templates](/basics/chat-templates.md): Learn the fundamentals and customization options of chat templates, including Conversational, ChatML, ShareGPT, Alpaca formats, and more!
|
||||
- [Quantization-Aware Training (QAT)](/basics/quantization-aware-training-qat.md): Quantize models to 4-bit with Unsloth and PyTorch to recover accuracy.
|
||||
- [Unsloth Environment Flags](/basics/unsloth-environment-flags.md): Advanced flags which might be useful if you see breaking finetunes, or you want to turn stuff off.
|
||||
- [Continued Pretraining](/basics/continued-pretraining.md): AKA as Continued Finetuning. Unsloth allows you to continually pretrain so a model can learn a new language.
|
||||
- [Unsloth Benchmarks](/basics/unsloth-benchmarks.md): Unsloth recorded benchmarks on NVIDIA GPUs.