diff --git a/skills/mlops/accelerate/SKILL.md b/skills/mlops/accelerate/SKILL.md deleted file mode 100644 index ad2d6fdd7..000000000 --- a/skills/mlops/accelerate/SKILL.md +++ /dev/null @@ -1,335 +0,0 @@ ---- -name: huggingface-accelerate -description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [accelerate, torch, transformers] -metadata: - hermes: - tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple] - ---- - -# HuggingFace Accelerate - Unified Distributed Training - -## Quick start - -Accelerate simplifies distributed training to 4 lines of code. - -**Installation**: -```bash -pip install accelerate -``` - -**Convert PyTorch script** (4 lines): -```python -import torch -+ from accelerate import Accelerator - -+ accelerator = Accelerator() - - model = torch.nn.Transformer() - optimizer = torch.optim.Adam(model.parameters()) - dataloader = torch.utils.data.DataLoader(dataset) - -+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - - for batch in dataloader: - optimizer.zero_grad() - loss = model(batch) -- loss.backward() -+ accelerator.backward(loss) - optimizer.step() -``` - -**Run** (single command): -```bash -accelerate launch train.py -``` - -## Common workflows - -### Workflow 1: From single GPU to multi-GPU - -**Original script**: -```python -# train.py -import torch - -model = torch.nn.Linear(10, 2).to('cuda') -optimizer = torch.optim.Adam(model.parameters()) -dataloader = torch.utils.data.DataLoader(dataset, batch_size=32) - -for epoch in range(10): - for batch in dataloader: - batch = batch.to('cuda') - optimizer.zero_grad() - loss = model(batch).mean() - loss.backward() - optimizer.step() -``` - -**With Accelerate** (4 lines added): -```python -# train.py -import torch -from accelerate import Accelerator # +1 - -accelerator = Accelerator() # +2 - -model = torch.nn.Linear(10, 2) -optimizer = torch.optim.Adam(model.parameters()) -dataloader = torch.utils.data.DataLoader(dataset, batch_size=32) - -model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3 - -for epoch in range(10): - for batch in dataloader: - # No .to('cuda') needed - automatic! - optimizer.zero_grad() - loss = model(batch).mean() - accelerator.backward(loss) # +4 - optimizer.step() -``` - -**Configure** (interactive): -```bash -accelerate config -``` - -**Questions**: -- Which machine? (single/multi GPU/TPU/CPU) -- How many machines? (1) -- Mixed precision? (no/fp16/bf16/fp8) -- DeepSpeed? (no/yes) - -**Launch** (works on any setup): -```bash -# Single GPU -accelerate launch train.py - -# Multi-GPU (8 GPUs) -accelerate launch --multi_gpu --num_processes 8 train.py - -# Multi-node -accelerate launch --multi_gpu --num_processes 16 \ - --num_machines 2 --machine_rank 0 \ - --main_process_ip $MASTER_ADDR \ - train.py -``` - -### Workflow 2: Mixed precision training - -**Enable FP16/BF16**: -```python -from accelerate import Accelerator - -# FP16 (with gradient scaling) -accelerator = Accelerator(mixed_precision='fp16') - -# BF16 (no scaling, more stable) -accelerator = Accelerator(mixed_precision='bf16') - -# FP8 (H100+) -accelerator = Accelerator(mixed_precision='fp8') - -model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - -# Everything else is automatic! -for batch in dataloader: - with accelerator.autocast(): # Optional, done automatically - loss = model(batch) - accelerator.backward(loss) -``` - -### Workflow 3: DeepSpeed ZeRO integration - -**Enable DeepSpeed ZeRO-2**: -```python -from accelerate import Accelerator - -accelerator = Accelerator( - mixed_precision='bf16', - deepspeed_plugin={ - "zero_stage": 2, # ZeRO-2 - "offload_optimizer": False, - "gradient_accumulation_steps": 4 - } -) - -# Same code as before! -model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) -``` - -**Or via config**: -```bash -accelerate config -# Select: DeepSpeed → ZeRO-2 -``` - -**deepspeed_config.json**: -```json -{ - "fp16": {"enabled": false}, - "bf16": {"enabled": true}, - "zero_optimization": { - "stage": 2, - "offload_optimizer": {"device": "cpu"}, - "allgather_bucket_size": 5e8, - "reduce_bucket_size": 5e8 - } -} -``` - -**Launch**: -```bash -accelerate launch --config_file deepspeed_config.json train.py -``` - -### Workflow 4: FSDP (Fully Sharded Data Parallel) - -**Enable FSDP**: -```python -from accelerate import Accelerator, FullyShardedDataParallelPlugin - -fsdp_plugin = FullyShardedDataParallelPlugin( - sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent - auto_wrap_policy="TRANSFORMER_AUTO_WRAP", - cpu_offload=False -) - -accelerator = Accelerator( - mixed_precision='bf16', - fsdp_plugin=fsdp_plugin -) - -model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) -``` - -**Or via config**: -```bash -accelerate config -# Select: FSDP → Full Shard → No CPU Offload -``` - -### Workflow 5: Gradient accumulation - -**Accumulate gradients**: -```python -from accelerate import Accelerator - -accelerator = Accelerator(gradient_accumulation_steps=4) - -model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - -for batch in dataloader: - with accelerator.accumulate(model): # Handles accumulation - optimizer.zero_grad() - loss = model(batch) - accelerator.backward(loss) - optimizer.step() -``` - -**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps` - -## When to use vs alternatives - -**Use Accelerate when**: -- Want simplest distributed training -- Need single script for any hardware -- Use HuggingFace ecosystem -- Want flexibility (DDP/DeepSpeed/FSDP/Megatron) -- Need quick prototyping - -**Key advantages**: -- **4 lines**: Minimal code changes -- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron -- **Automatic**: Device placement, mixed precision, sharding -- **Interactive config**: No manual launcher setup -- **Single launch**: Works everywhere - -**Use alternatives instead**: -- **PyTorch Lightning**: Need callbacks, high-level abstractions -- **Ray Train**: Multi-node orchestration, hyperparameter tuning -- **DeepSpeed**: Direct API control, advanced features -- **Raw DDP**: Maximum control, minimal abstraction - -## Common issues - -**Issue: Wrong device placement** - -Don't manually move to device: -```python -# WRONG -batch = batch.to('cuda') - -# CORRECT -# Accelerate handles it automatically after prepare() -``` - -**Issue: Gradient accumulation not working** - -Use context manager: -```python -# CORRECT -with accelerator.accumulate(model): - optimizer.zero_grad() - accelerator.backward(loss) - optimizer.step() -``` - -**Issue: Checkpointing in distributed** - -Use accelerator methods: -```python -# Save only on main process -if accelerator.is_main_process: - accelerator.save_state('checkpoint/') - -# Load on all processes -accelerator.load_state('checkpoint/') -``` - -**Issue: Different results with FSDP** - -Ensure same random seed: -```python -from accelerate.utils import set_seed -set_seed(42) -``` - -## Advanced topics - -**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup. - -**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration. - -**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices. - -## Hardware requirements - -- **CPU**: Works (slow) -- **Single GPU**: Works -- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP -- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron -- **TPU**: Supported -- **Apple MPS**: Supported - -**Launcher requirements**: -- **DDP**: `torch.distributed.run` (built-in) -- **DeepSpeed**: `deepspeed` (pip install deepspeed) -- **FSDP**: PyTorch 1.12+ (built-in) -- **Megatron**: Custom setup - -## Resources - -- Docs: https://huggingface.co/docs/accelerate -- GitHub: https://github.com/huggingface/accelerate -- Version: 1.11.0+ -- Tutorial: "Accelerate your scripts" -- Examples: https://github.com/huggingface/accelerate/tree/main/examples -- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries - - - diff --git a/skills/mlops/accelerate/references/custom-plugins.md b/skills/mlops/accelerate/references/custom-plugins.md deleted file mode 100644 index d8207ee85..000000000 --- a/skills/mlops/accelerate/references/custom-plugins.md +++ /dev/null @@ -1,453 +0,0 @@ -# Custom Plugins for Accelerate - -## Overview - -Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed). - -## Plugin Architecture - -### Base Plugin Structure - -```python -from accelerate.utils import DistributedDataParallelKwargs -from dataclasses import dataclass - -@dataclass -class CustomPlugin: - """Custom training plugin.""" - - # Plugin configuration - param1: int = 1 - param2: str = "default" - - def __post_init__(self): - # Validation logic - if self.param1 < 1: - raise ValueError("param1 must be >= 1") -``` - -### Using Custom Plugin - -```python -from accelerate import Accelerator - -# Create plugin -custom_plugin = CustomPlugin(param1=4, param2="value") - -# Pass to Accelerator -accelerator = Accelerator( - custom_plugin=custom_plugin # Not a real parameter, example only -) -``` - -## Built-In Plugin Examples - -### 1. GradScalerKwargs (FP16 Configuration) - -```python -from accelerate.utils import GradScalerKwargs - -# Configure gradient scaler for FP16 -scaler_kwargs = GradScalerKwargs( - init_scale=2.**16, # Initial loss scale - growth_factor=2.0, # Scale growth rate - backoff_factor=0.5, # Scale backoff rate - growth_interval=2000, # Steps between scale increases - enabled=True # Enable scaler -) - -accelerator = Accelerator( - mixed_precision='fp16', - kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler -) -``` - -**Use case**: Fine-tune FP16 gradient scaling behavior - -### 2. DistributedDataParallelKwargs - -```python -from accelerate.utils import DistributedDataParallelKwargs - -# Configure DDP behavior -ddp_kwargs = DistributedDataParallelKwargs( - bucket_cap_mb=25, # Gradient bucketing size - find_unused_parameters=False, # Find unused params (slower) - check_reduction=False, # Check gradient reduction - gradient_as_bucket_view=True, # Memory optimization - static_graph=False # Static computation graph -) - -accelerator = Accelerator( - kwargs_handlers=[ddp_kwargs] -) -``` - -**Use case**: Optimize DDP performance for specific models - -### 3. FP8RecipeKwargs (H100 FP8) - -```python -from accelerate.utils import FP8RecipeKwargs - -# Configure FP8 training (H100) -fp8_recipe = FP8RecipeKwargs( - backend="te", # TransformerEngine backend - margin=0, # Scaling margin - interval=1, # Scaling interval - fp8_format="HYBRID", # E4M3 + E5M2 hybrid - amax_history_len=1024, # AMAX history length - amax_compute_algo="max" # AMAX computation algorithm -) - -accelerator = Accelerator( - mixed_precision='fp8', - kwargs_handlers=[fp8_recipe] -) -``` - -**Use case**: Ultra-fast training on H100 GPUs - -## Custom DeepSpeed Configuration - -### ZeRO-3 with CPU Offload - -```python -from accelerate import Accelerator -from accelerate.utils import DeepSpeedPlugin - -# Custom DeepSpeed config -ds_plugin = DeepSpeedPlugin( - zero_stage=3, # ZeRO-3 - offload_optimizer_device="cpu", # CPU offload optimizer - offload_param_device="cpu", # CPU offload parameters - zero3_init_flag=True, # ZeRO-3 initialization - zero3_save_16bit_model=True, # Save FP16 weights -) - -accelerator = Accelerator( - deepspeed_plugin=ds_plugin, - mixed_precision='bf16' -) -``` - -### ZeRO-2 with NVMe Offload - -```python -ds_plugin = DeepSpeedPlugin( - zero_stage=2, - offload_optimizer_device="nvme", # NVMe offload - offload_param_device="nvme", - nvme_path="/local_nvme", # NVMe mount path -) -``` - -### Custom JSON Config - -```python -import json - -# Load custom DeepSpeed config -with open('deepspeed_config.json', 'r') as f: - ds_config = json.load(f) - -ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config) - -accelerator = Accelerator(deepspeed_plugin=ds_plugin) -``` - -**Example config** (`deepspeed_config.json`): -```json -{ - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "gradient_accumulation_steps": "auto", - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, - "offload_param": { - "device": "cpu", - "pin_memory": true - }, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 1e9, - "reduce_bucket_size": 5e8, - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e6, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_16bit_weights_on_model_save": true - }, - "bf16": { - "enabled": true - }, - "steps_per_print": 100, - "wall_clock_breakdown": false -} -``` - -## Custom FSDP Configuration - -### FSDP with Custom Auto-Wrap Policy - -```python -from accelerate.utils import FullyShardedDataParallelPlugin -from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -import functools - -# Custom wrap policy (size-based) -wrap_policy = functools.partial( - size_based_auto_wrap_policy, - min_num_params=1e6 # Wrap layers with 1M+ params -) - -fsdp_plugin = FullyShardedDataParallelPlugin( - sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy - mixed_precision_policy=None, # Use Accelerator's mixed precision - auto_wrap_policy=wrap_policy, # Custom wrapping - cpu_offload=False, - ignored_modules=None, # Modules to not wrap - state_dict_type="FULL_STATE_DICT", # Save format - optim_state_dict_config=None, - limit_all_gathers=False, - use_orig_params=True, # Use original param shapes -) - -accelerator = Accelerator( - fsdp_plugin=fsdp_plugin, - mixed_precision='bf16' -) -``` - -### FSDP with Transformer Auto-Wrap - -```python -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformers.models.gpt2.modeling_gpt2 import GPT2Block - -# Wrap at transformer block level -wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers -) - -fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=wrap_policy -) -``` - -## Creating Custom Training Strategy - -### Example: Custom Gradient Accumulation - -```python -from accelerate import Accelerator - -class CustomGradientAccumulation: - def __init__(self, steps=4, adaptive=False): - self.steps = steps - self.adaptive = adaptive - self.current_step = 0 - - def should_sync(self, loss): - """Decide whether to sync gradients.""" - self.current_step += 1 - - # Adaptive: sync on high loss - if self.adaptive and loss > threshold: - self.current_step = 0 - return True - - # Regular: sync every N steps - if self.current_step >= self.steps: - self.current_step = 0 - return True - - return False - -# Usage -custom_accum = CustomGradientAccumulation(steps=8, adaptive=True) -accelerator = Accelerator() - -for batch in dataloader: - outputs = model(**batch) - loss = outputs.loss - - # Scale loss - loss = loss / custom_accum.steps - accelerator.backward(loss) - - # Conditional sync - if custom_accum.should_sync(loss.item()): - optimizer.step() - optimizer.zero_grad() -``` - -### Example: Custom Mixed Precision - -```python -import torch - -class CustomMixedPrecision: - """Custom mixed precision with dynamic loss scaling.""" - - def __init__(self, init_scale=2**16, scale_window=2000): - self.scaler = torch.cuda.amp.GradScaler( - init_scale=init_scale, - growth_interval=scale_window - ) - self.scale_history = [] - - def scale_loss(self, loss): - """Scale loss for backward.""" - return self.scaler.scale(loss) - - def unscale_and_clip(self, optimizer, max_norm=1.0): - """Unscale gradients and clip.""" - self.scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - optimizer.param_groups[0]['params'], - max_norm - ) - - def step(self, optimizer): - """Optimizer step with scaler update.""" - scale_before = self.scaler.get_scale() - self.scaler.step(optimizer) - self.scaler.update() - scale_after = self.scaler.get_scale() - - # Track scale changes - if scale_before != scale_after: - self.scale_history.append(scale_after) - -# Usage -custom_mp = CustomMixedPrecision() - -for batch in dataloader: - with torch.cuda.amp.autocast(dtype=torch.float16): - loss = model(**batch).loss - - scaled_loss = custom_mp.scale_loss(loss) - scaled_loss.backward() - - custom_mp.unscale_and_clip(optimizer, max_norm=1.0) - custom_mp.step(optimizer) - optimizer.zero_grad() -``` - -## Advanced: Custom Distributed Backend - -### Custom AllReduce Strategy - -```python -import torch.distributed as dist - -class CustomAllReduce: - """Custom all-reduce with compression.""" - - def __init__(self, compression_ratio=0.1): - self.compression_ratio = compression_ratio - - def compress_gradients(self, tensor): - """Top-k gradient compression.""" - k = int(tensor.numel() * self.compression_ratio) - values, indices = torch.topk(tensor.abs().view(-1), k) - return values, indices - - def all_reduce_compressed(self, tensor): - """All-reduce with gradient compression.""" - # Compress - values, indices = self.compress_gradients(tensor) - - # All-reduce compressed gradients - dist.all_reduce(values, op=dist.ReduceOp.SUM) - - # Decompress - tensor_compressed = torch.zeros_like(tensor).view(-1) - tensor_compressed[indices] = values / dist.get_world_size() - - return tensor_compressed.view_as(tensor) - -# Usage in training loop -custom_ar = CustomAllReduce(compression_ratio=0.1) - -for batch in dataloader: - loss = model(**batch).loss - loss.backward() - - # Custom all-reduce - for param in model.parameters(): - if param.grad is not None: - param.grad.data = custom_ar.all_reduce_compressed(param.grad.data) - - optimizer.step() - optimizer.zero_grad() -``` - -## Plugin Best Practices - -### 1. Validation in `__post_init__` - -```python -@dataclass -class CustomPlugin: - learning_rate: float = 1e-3 - warmup_steps: int = 1000 - - def __post_init__(self): - # Validate parameters - if self.learning_rate <= 0: - raise ValueError("learning_rate must be positive") - if self.warmup_steps < 0: - raise ValueError("warmup_steps must be non-negative") - - # Compute derived values - self.min_lr = self.learning_rate * 0.1 -``` - -### 2. Compatibility Checks - -```python -@dataclass -class CustomPlugin: - feature_enabled: bool = True - - def is_compatible(self, accelerator): - """Check if plugin is compatible with accelerator config.""" - if self.feature_enabled and accelerator.mixed_precision == 'fp8': - raise ValueError("Custom plugin not compatible with FP8") - return True -``` - -### 3. State Management - -```python -@dataclass -class CustomPlugin: - counter: int = 0 - history: list = None - - def __post_init__(self): - if self.history is None: - self.history = [] - - def update_state(self, value): - """Update plugin state during training.""" - self.counter += 1 - self.history.append(value) -``` - -## Resources - -- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs -- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/ -- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html -- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu diff --git a/skills/mlops/accelerate/references/megatron-integration.md b/skills/mlops/accelerate/references/megatron-integration.md deleted file mode 100644 index 61b025b5e..000000000 --- a/skills/mlops/accelerate/references/megatron-integration.md +++ /dev/null @@ -1,489 +0,0 @@ -# Megatron Integration with Accelerate - -## Overview - -Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism. - -**Megatron capabilities**: -- **Tensor Parallelism (TP)**: Split layers across GPUs -- **Pipeline Parallelism (PP)**: Split model depth across GPUs -- **Data Parallelism (DP)**: Replicate model across GPU groups -- **Sequence Parallelism**: Split sequences for long contexts - -## Setup - -### Install Megatron-LM - -```bash -# Clone Megatron-LM repository -git clone https://github.com/NVIDIA/Megatron-LM.git -cd Megatron-LM -pip install -e . - -# Install Apex (NVIDIA optimizations) -git clone https://github.com/NVIDIA/apex -cd apex -pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ - --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ -``` - -### Accelerate Configuration - -```bash -accelerate config -``` - -**Questions**: -``` -In which compute environment are you running? -> This machine - -Which type of machine are you using? -> Multi-GPU - -How many different machines will you use? -> 1 - -Do you want to use DeepSpeed/FSDP? -> No - -Do you want to use Megatron-LM? -> Yes - -What is the Tensor Parallelism degree? [1-8] -> 2 - -Do you want to enable Sequence Parallelism? -> No - -What is the Pipeline Parallelism degree? [1-8] -> 2 - -What is the Data Parallelism degree? [1-8] -> 2 - -Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE'] -> SELECTIVE - -Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM'] -> SEQUENTIAL -``` - -**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`): -```yaml -compute_environment: LOCAL_MACHINE -distributed_type: MEGATRON_LM -downcast_bf16: 'no' -machine_rank: 0 -main_training_function: main -megatron_lm_config: - megatron_lm_gradient_clipping: 1.0 - megatron_lm_learning_rate_decay_iters: 320000 - megatron_lm_num_micro_batches: 1 - megatron_lm_pp_degree: 2 - megatron_lm_recompute_activations: true - megatron_lm_sequence_parallelism: false - megatron_lm_tp_degree: 2 -mixed_precision: bf16 -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false -``` - -## Parallelism Strategies - -### Tensor Parallelism (TP) - -**Splits each transformer layer across GPUs**: - -```python -# Layer split across 2 GPUs -# GPU 0: First half of attention heads -# GPU 1: Second half of attention heads - -# Each GPU computes partial outputs -# All-reduce combines results -``` - -**TP degree recommendations**: -- **TP=1**: No tensor parallelism (single GPU per layer) -- **TP=2**: 2 GPUs per layer (good for 7-13B models) -- **TP=4**: 4 GPUs per layer (good for 20-40B models) -- **TP=8**: 8 GPUs per layer (good for 70B+ models) - -**Benefits**: -- Reduces memory per GPU -- All-reduce communication (fast) - -**Drawbacks**: -- Requires fast inter-GPU bandwidth (NVLink) -- Communication overhead per layer - -### Pipeline Parallelism (PP) - -**Splits model depth across GPUs**: - -```python -# 12-layer model, PP=4 -# GPU 0: Layers 0-2 -# GPU 1: Layers 3-5 -# GPU 2: Layers 6-8 -# GPU 3: Layers 9-11 -``` - -**PP degree recommendations**: -- **PP=1**: No pipeline parallelism -- **PP=2**: 2 pipeline stages (good for 20-40B models) -- **PP=4**: 4 pipeline stages (good for 70B+ models) -- **PP=8**: 8 pipeline stages (good for 175B+ models) - -**Benefits**: -- Linear memory reduction (4× PP = 4× less memory) -- Works across nodes (slower interconnect OK) - -**Drawbacks**: -- Pipeline bubbles (idle time) -- Requires micro-batching - -### Data Parallelism (DP) - -**Replicates model across GPU groups**: - -```python -# 8 GPUs, TP=2, PP=2, DP=2 -# Group 0 (GPUs 0-3): Full model replica -# Group 1 (GPUs 4-7): Full model replica -``` - -**DP degree**: -- `DP = total_gpus / (TP × PP)` -- Example: 8 GPUs, TP=2, PP=2 → DP=2 - -**Benefits**: -- Increases throughput -- Scales batch size - -### Sequence Parallelism - -**Splits long sequences across GPUs** (extends TP): - -```python -# 8K sequence, TP=2, Sequence Parallel=True -# GPU 0: Tokens 0-4095 -# GPU 1: Tokens 4096-8191 -``` - -**Benefits**: -- Enables very long sequences (100K+ tokens) -- Reduces activation memory - -**Requirements**: -- Must use with TP > 1 -- RoPE/ALiBi position encodings work best - -## Accelerate Code Example - -### Basic Setup - -```python -from accelerate import Accelerator -from accelerate.utils import MegatronLMPlugin - -# Configure Megatron -megatron_plugin = MegatronLMPlugin( - tp_degree=2, # Tensor parallelism degree - pp_degree=2, # Pipeline parallelism degree - num_micro_batches=4, # Micro-batches for pipeline - gradient_clipping=1.0, # Gradient clipping value - sequence_parallelism=False, # Enable sequence parallelism - recompute_activations=True, # Activation checkpointing - use_distributed_optimizer=True, # Distributed optimizer - custom_prepare_model_function=None, # Custom model prep -) - -# Initialize accelerator -accelerator = Accelerator( - mixed_precision='bf16', - megatron_lm_plugin=megatron_plugin -) - -# Prepare model and optimizer -model, optimizer, train_dataloader = accelerator.prepare( - model, optimizer, train_dataloader -) - -# Training loop (same as DDP!) -for batch in train_dataloader: - optimizer.zero_grad() - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() -``` - -### Full Training Script - -```python -import torch -from accelerate import Accelerator -from accelerate.utils import MegatronLMPlugin -from transformers import GPT2Config, GPT2LMHeadModel - -def main(): - # Megatron configuration - megatron_plugin = MegatronLMPlugin( - tp_degree=2, - pp_degree=2, - num_micro_batches=4, - gradient_clipping=1.0, - ) - - accelerator = Accelerator( - mixed_precision='bf16', - gradient_accumulation_steps=8, - megatron_lm_plugin=megatron_plugin - ) - - # Model - config = GPT2Config( - n_layer=24, - n_head=16, - n_embd=1024, - ) - model = GPT2LMHeadModel(config) - - # Optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4) - - # Prepare - model, optimizer, train_loader = accelerator.prepare( - model, optimizer, train_loader - ) - - # Training loop - for epoch in range(num_epochs): - for batch in train_loader: - with accelerator.accumulate(model): - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - - # Save checkpoint - accelerator.wait_for_everyone() - accelerator.save_state(f'checkpoint-epoch-{epoch}') - -if __name__ == '__main__': - main() -``` - -### Launch Command - -```bash -# 8 GPUs, TP=2, PP=2, DP=2 -accelerate launch --multi_gpu --num_processes 8 train.py - -# Multi-node (2 nodes, 8 GPUs each) -# Node 0 -accelerate launch --multi_gpu --num_processes 16 \ - --num_machines 2 --machine_rank 0 \ - --main_process_ip $MASTER_ADDR \ - --main_process_port 29500 \ - train.py - -# Node 1 -accelerate launch --multi_gpu --num_processes 16 \ - --num_machines 2 --machine_rank 1 \ - --main_process_ip $MASTER_ADDR \ - --main_process_port 29500 \ - train.py -``` - -## Activation Checkpointing - -**Reduces memory by recomputing activations**: - -```python -megatron_plugin = MegatronLMPlugin( - recompute_activations=True, # Enable checkpointing - checkpoint_num_layers=1, # Checkpoint every N layers - distribute_checkpointed_activations=True, # Distribute across TP - partition_activations=True, # Partition in PP - check_for_nan_in_loss_and_grad=True, # Stability check -) -``` - -**Strategies**: -- `SELECTIVE`: Checkpoint transformer blocks only -- `FULL`: Checkpoint all layers -- `NONE`: No checkpointing - -**Memory savings**: 30-50% with 10-15% slowdown - -## Distributed Optimizer - -**Shards optimizer state across DP ranks**: - -```python -megatron_plugin = MegatronLMPlugin( - use_distributed_optimizer=True, # Enable sharded optimizer -) -``` - -**Benefits**: -- Reduces optimizer memory by DP degree -- Example: DP=4 → 4× less optimizer memory per GPU - -**Compatible with**: -- AdamW, Adam, SGD -- Mixed precision training - -## Performance Tuning - -### Micro-Batch Size - -```python -# Pipeline parallelism requires micro-batching -megatron_plugin = MegatronLMPlugin( - pp_degree=4, - num_micro_batches=16, # 16 micro-batches per pipeline -) - -# Effective batch = num_micro_batches × micro_batch_size × DP -# Example: 16 × 2 × 4 = 128 -``` - -**Recommendations**: -- More micro-batches → less pipeline bubble -- Typical: 4-16 micro-batches - -### Sequence Length - -```python -# For long sequences, enable sequence parallelism -megatron_plugin = MegatronLMPlugin( - tp_degree=4, - sequence_parallelism=True, # Required: TP > 1 -) - -# Enables sequences up to TP × normal limit -# Example: TP=4, 8K normal → 32K with sequence parallel -``` - -### GPU Topology - -**NVLink required for TP**: -```bash -# Check NVLink topology -nvidia-smi topo -m - -# Good topology (NVLink between all GPUs) -# GPU0 - GPU1: NV12 (fast) -# GPU0 - GPU2: NV12 (fast) - -# Bad topology (PCIe only) -# GPU0 - GPU4: PHB (slow, avoid TP across these) -``` - -**Recommendations**: -- **TP**: Within same node (NVLink) -- **PP**: Across nodes (slower interconnect OK) -- **DP**: Any topology - -## Model Size Guidelines - -| Model Size | GPUs | TP | PP | DP | Micro-Batches | -|------------|------|----|----|----|--------------| -| 7B | 8 | 1 | 1 | 8 | 1 | -| 13B | 8 | 2 | 1 | 4 | 1 | -| 20B | 16 | 4 | 1 | 4 | 1 | -| 40B | 32 | 4 | 2 | 4 | 4 | -| 70B | 64 | 8 | 2 | 4 | 8 | -| 175B | 128 | 8 | 4 | 4 | 16 | - -**Assumptions**: BF16, 2K sequence length, A100 80GB - -## Checkpointing - -### Save Checkpoint - -```python -# Save full model state -accelerator.save_state('checkpoint-1000') - -# Megatron saves separate files per rank -# checkpoint-1000/ -# pytorch_model_tp_0_pp_0.bin -# pytorch_model_tp_0_pp_1.bin -# pytorch_model_tp_1_pp_0.bin -# pytorch_model_tp_1_pp_1.bin -# optimizer_tp_0_pp_0.bin -# ... -``` - -### Load Checkpoint - -```python -# Resume training -accelerator.load_state('checkpoint-1000') - -# Automatically loads correct shard per rank -``` - -### Convert to Standard PyTorch - -```bash -# Merge Megatron checkpoint to single file -python merge_megatron_checkpoint.py \ - --checkpoint-dir checkpoint-1000 \ - --output pytorch_model.bin -``` - -## Common Issues - -### Issue: OOM with Pipeline Parallelism - -**Solution**: Increase micro-batches -```python -megatron_plugin = MegatronLMPlugin( - pp_degree=4, - num_micro_batches=16, # Increase from 4 -) -``` - -### Issue: Slow Training - -**Check 1**: Pipeline bubbles (PP too high) -```python -# Reduce PP, increase TP -tp_degree=4 # Increase -pp_degree=2 # Decrease -``` - -**Check 2**: Micro-batch size too small -```python -num_micro_batches=8 # Increase -``` - -### Issue: NVLink Not Detected - -```bash -# Verify NVLink -nvidia-smi nvlink -s - -# If no NVLink, avoid TP > 1 -# Use PP or DP instead -``` - -## Resources - -- Megatron-LM: https://github.com/NVIDIA/Megatron-LM -- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm -- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" -- NVIDIA Apex: https://github.com/NVIDIA/apex diff --git a/skills/mlops/accelerate/references/performance.md b/skills/mlops/accelerate/references/performance.md deleted file mode 100644 index 62560d2bf..000000000 --- a/skills/mlops/accelerate/references/performance.md +++ /dev/null @@ -1,525 +0,0 @@ -# Accelerate Performance Tuning - -## Profiling - -### Basic Profiling - -```python -from accelerate import Accelerator -import time - -accelerator = Accelerator() - -# Warmup -for _ in range(10): - batch = next(iter(dataloader)) - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - -# Profile training loop -start = time.time() -total_batches = 100 - -for i, batch in enumerate(dataloader): - if i >= total_batches: - break - - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - -accelerator.wait_for_everyone() # Sync all processes -elapsed = time.time() - start - -# Metrics -batches_per_sec = total_batches / elapsed -samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed - -print(f"Throughput: {samples_per_sec:.2f} samples/sec") -print(f"Batches/sec: {batches_per_sec:.2f}") -``` - -### PyTorch Profiler Integration - -```python -from torch.profiler import profile, ProfilerActivity - -with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - profile_memory=True, - with_stack=True -) as prof: - for i, batch in enumerate(dataloader): - if i >= 10: # Profile first 10 batches - break - - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - -# Print profiling results -print(prof.key_averages().table( - sort_by="cuda_time_total", row_limit=20 -)) - -# Export to Chrome tracing -prof.export_chrome_trace("trace.json") -# View at chrome://tracing -``` - -## Memory Optimization - -### 1. Gradient Accumulation - -**Problem**: Large batch size causes OOM - -**Solution**: Accumulate gradients across micro-batches - -```python -accelerator = Accelerator(gradient_accumulation_steps=8) - -# Effective batch = batch_size × accumulation_steps × num_gpus -# Example: 4 × 8 × 8 = 256 - -for batch in dataloader: - with accelerator.accumulate(model): # Handles accumulation logic - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() -``` - -**Memory savings**: 8× less activation memory (with 8 accumulation steps) - -### 2. Gradient Checkpointing - -**Enable in model**: - -```python -from transformers import AutoModelForCausalLM - -model = AutoModelForCausalLM.from_pretrained( - "gpt2", - use_cache=False # Required for gradient checkpointing -) - -# Enable checkpointing -model.gradient_checkpointing_enable() - -# Prepare with Accelerate -model = accelerator.prepare(model) -``` - -**Memory savings**: 30-50% with 10-15% slowdown - -### 3. Mixed Precision - -**BF16 (A100/H100)**: -```python -accelerator = Accelerator(mixed_precision='bf16') - -# Automatic mixed precision -for batch in dataloader: - outputs = model(**batch) # Forward in BF16 - loss = outputs.loss - accelerator.backward(loss) # Backward in FP32 - optimizer.step() -``` - -**FP16 (V100, older GPUs)**: -```python -from accelerate.utils import GradScalerKwargs - -scaler_kwargs = GradScalerKwargs( - init_scale=2.**16, - growth_interval=2000 -) - -accelerator = Accelerator( - mixed_precision='fp16', - kwargs_handlers=[scaler_kwargs] -) -``` - -**Memory savings**: 50% compared to FP32 - -### 4. CPU Offloading (DeepSpeed) - -```python -from accelerate.utils import DeepSpeedPlugin - -ds_plugin = DeepSpeedPlugin( - zero_stage=3, - offload_optimizer_device="cpu", # Offload optimizer to CPU - offload_param_device="cpu", # Offload parameters to CPU -) - -accelerator = Accelerator( - deepspeed_plugin=ds_plugin, - mixed_precision='bf16' -) -``` - -**Memory savings**: 10-20× for optimizer state, 5-10× for parameters - -**Trade-off**: 20-30% slower due to CPU-GPU transfers - -### 5. Flash Attention - -```python -# Install flash-attn -# pip install flash-attn - -from transformers import AutoModelForCausalLM - -model = AutoModelForCausalLM.from_pretrained( - "gpt2", - attn_implementation="flash_attention_2" # Enable Flash Attention 2 -) - -model = accelerator.prepare(model) -``` - -**Memory savings**: 50% for attention, 2× faster - -**Requirements**: A100/H100, sequence length must be multiple of 128 - -## Communication Optimization - -### 1. Gradient Bucketing (DDP) - -```python -from accelerate.utils import DistributedDataParallelKwargs - -ddp_kwargs = DistributedDataParallelKwargs( - bucket_cap_mb=25, # Bucket size for gradient reduction - gradient_as_bucket_view=True, # Reduce memory copies - static_graph=False # Set True if model doesn't change -) - -accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) -``` - -**Recommended bucket sizes**: -- Small models (<1B): 25 MB -- Medium models (1-10B): 50-100 MB -- Large models (>10B): 100-200 MB - -### 2. Find Unused Parameters - -```python -# Only enable if model has unused parameters (slower!) -ddp_kwargs = DistributedDataParallelKwargs( - find_unused_parameters=True -) -``` - -**Use case**: Models with conditional branches (e.g., mixture of experts) - -**Cost**: 10-20% slower - -### 3. NCCL Tuning - -```bash -# Set environment variables before launch -export NCCL_DEBUG=INFO # Debug info -export NCCL_IB_DISABLE=0 # Enable InfiniBand -export NCCL_SOCKET_IFNAME=eth0 # Network interface -export NCCL_P2P_LEVEL=NVL # Use NVLink - -accelerate launch train.py -``` - -**NCCL_P2P_LEVEL options**: -- `NVL`: NVLink (fastest, within node) -- `PIX`: PCIe (fast, within node) -- `PHB`: PCIe host bridge (slow, cross-node) - -## Data Loading Optimization - -### 1. DataLoader Workers - -```python -from torch.utils.data import DataLoader - -train_loader = DataLoader( - dataset, - batch_size=32, - num_workers=4, # Parallel data loading - pin_memory=True, # Pin memory for faster GPU transfer - prefetch_factor=2, # Prefetch batches per worker - persistent_workers=True # Keep workers alive between epochs -) - -train_loader = accelerator.prepare(train_loader) -``` - -**Recommendations**: -- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers) -- `pin_memory`: Always True for GPU training -- `prefetch_factor`: 2-4 (higher for slow data loading) - -### 2. Data Preprocessing - -```python -from datasets import load_dataset - -# Bad: Preprocess during training (slow) -dataset = load_dataset("openwebtext") - -for batch in dataset: - tokens = tokenizer(batch['text']) # Slow! - ... - -# Good: Preprocess once, save -dataset = load_dataset("openwebtext") -tokenized = dataset.map( - lambda x: tokenizer(x['text']), - batched=True, - num_proc=8, # Parallel preprocessing - remove_columns=['text'] -) -tokenized.save_to_disk("preprocessed_data") - -# Load preprocessed -dataset = load_from_disk("preprocessed_data") -``` - -### 3. Faster Tokenization - -```python -import os - -# Enable Rust-based tokenizers (10× faster) -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -from transformers import AutoTokenizer - -tokenizer = AutoTokenizer.from_pretrained( - "gpt2", - use_fast=True # Use fast Rust tokenizer -) -``` - -## Compilation (PyTorch 2.0+) - -### Compile Model - -```python -import torch - -# Compile model for faster execution -model = torch.compile( - model, - mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune - fullgraph=False, # Compile entire graph (stricter) - dynamic=True # Support dynamic shapes -) - -model = accelerator.prepare(model) -``` - -**Speedup**: 10-50% depending on model - -**Compilation modes**: -- `default`: Balanced (best for most cases) -- `reduce-overhead`: Min overhead (best for small batches) -- `max-autotune`: Max performance (slow compile, best for production) - -### Compilation Best Practices - -```python -# Bad: Compile after prepare (won't work) -model = accelerator.prepare(model) -model = torch.compile(model) # Error! - -# Good: Compile before prepare -model = torch.compile(model) -model = accelerator.prepare(model) - -# Training loop -for batch in dataloader: - # First iteration: slow (compilation) - # Subsequent iterations: fast (compiled) - outputs = model(**batch) - ... -``` - -## Benchmarking Different Strategies - -### Script Template - -```python -import time -import torch -from accelerate import Accelerator - -def benchmark_strategy(strategy_name, accelerator_kwargs): - """Benchmark a specific training strategy.""" - accelerator = Accelerator(**accelerator_kwargs) - - # Setup - model = create_model() - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) - dataloader = create_dataloader() - - model, optimizer, dataloader = accelerator.prepare( - model, optimizer, dataloader - ) - - # Warmup - for i, batch in enumerate(dataloader): - if i >= 10: - break - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - - # Benchmark - accelerator.wait_for_everyone() - torch.cuda.synchronize() - start = time.time() - - num_batches = 100 - for i, batch in enumerate(dataloader): - if i >= num_batches: - break - - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - - accelerator.wait_for_everyone() - torch.cuda.synchronize() - elapsed = time.time() - start - - # Metrics - throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed - memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB - - if accelerator.is_main_process: - print(f"\n{strategy_name}:") - print(f" Throughput: {throughput:.2f} samples/sec") - print(f" Memory: {memory_used:.2f} GB") - print(f" Time: {elapsed:.2f} sec") - - torch.cuda.reset_peak_memory_stats() - -# Benchmark different strategies -strategies = [ - ("DDP + FP32", {}), - ("DDP + BF16", {"mixed_precision": "bf16"}), - ("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}), - ("FSDP", {"fsdp_plugin": fsdp_plugin}), - ("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}), - ("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}), -] - -for name, kwargs in strategies: - benchmark_strategy(name, kwargs) -``` - -## Performance Checklist - -**Before training**: -- [ ] Use BF16/FP16 mixed precision -- [ ] Enable gradient checkpointing (if OOM) -- [ ] Set appropriate `num_workers` (2-4 per GPU) -- [ ] Enable `pin_memory=True` -- [ ] Preprocess data once, not during training -- [ ] Compile model with `torch.compile` (PyTorch 2.0+) - -**For large models**: -- [ ] Use FSDP or DeepSpeed ZeRO-3 -- [ ] Enable CPU offloading (if still OOM) -- [ ] Use Flash Attention -- [ ] Increase gradient accumulation - -**For multi-node**: -- [ ] Check network topology (InfiniBand > Ethernet) -- [ ] Tune NCCL settings -- [ ] Use larger bucket sizes for DDP -- [ ] Verify NVLink for tensor parallelism - -**Profiling**: -- [ ] Profile first 10-100 batches -- [ ] Check GPU utilization (`nvidia-smi dmon`) -- [ ] Check data loading time (should be <5% of iteration) -- [ ] Identify communication bottlenecks - -## Common Performance Issues - -### Issue: Low GPU Utilization (<80%) - -**Cause 1**: Data loading bottleneck -```python -# Solution: Increase workers and prefetch -num_workers=8 -prefetch_factor=4 -``` - -**Cause 2**: Small batch size -```python -# Solution: Increase batch size or use gradient accumulation -batch_size=32 # Increase -gradient_accumulation_steps=4 # Or accumulate -``` - -### Issue: High Memory Usage - -**Solution 1**: Gradient checkpointing -```python -model.gradient_checkpointing_enable() -``` - -**Solution 2**: Reduce batch size, increase accumulation -```python -batch_size=8 # Reduce from 32 -gradient_accumulation_steps=16 # Maintain effective batch -``` - -**Solution 3**: Use FSDP or DeepSpeed ZeRO-3 -```python -accelerator = Accelerator(fsdp_plugin=fsdp_plugin) -``` - -### Issue: Slow Multi-GPU Training - -**Cause**: Communication bottleneck - -**Check 1**: Gradient bucket size -```python -ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100) -``` - -**Check 2**: NCCL settings -```bash -export NCCL_DEBUG=INFO -# Check for "Using NVLS" (good) vs "Using PHB" (bad) -``` - -**Check 3**: Network bandwidth -```bash -# Test inter-GPU bandwidth -nvidia-smi nvlink -s -``` - -## Resources - -- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance -- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html -- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html -- Flash Attention: https://github.com/Dao-AILab/flash-attention diff --git a/skills/mlops/audiocraft/SKILL.md b/skills/mlops/audiocraft/SKILL.md deleted file mode 100644 index 3d3bf7158..000000000 --- a/skills/mlops/audiocraft/SKILL.md +++ /dev/null @@ -1,567 +0,0 @@ ---- -name: audiocraft-audio-generation -description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0] -metadata: - hermes: - tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen] - ---- - -# AudioCraft: Audio Generation - -Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec. - -## When to use AudioCraft - -**Use AudioCraft when:** -- Need to generate music from text descriptions -- Creating sound effects and environmental audio -- Building music generation applications -- Need melody-conditioned music generation -- Want stereo audio output -- Require controllable music generation with style transfer - -**Key features:** -- **MusicGen**: Text-to-music generation with melody conditioning -- **AudioGen**: Text-to-sound effects generation -- **EnCodec**: High-fidelity neural audio codec -- **Multiple model sizes**: Small (300M) to Large (3.3B) -- **Stereo support**: Full stereo audio generation -- **Style conditioning**: MusicGen-Style for reference-based generation - -**Use alternatives instead:** -- **Stable Audio**: For longer commercial music generation -- **Bark**: For text-to-speech with music/sound effects -- **Riffusion**: For spectogram-based music generation -- **OpenAI Jukebox**: For raw audio generation with lyrics - -## Quick start - -### Installation - -```bash -# From PyPI -pip install audiocraft - -# From GitHub (latest) -pip install git+https://github.com/facebookresearch/audiocraft.git - -# Or use HuggingFace Transformers -pip install transformers torch torchaudio -``` - -### Basic text-to-music (AudioCraft) - -```python -import torchaudio -from audiocraft.models import MusicGen - -# Load model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Set generation parameters -model.set_generation_params( - duration=8, # seconds - top_k=250, - temperature=1.0 -) - -# Generate from text -descriptions = ["happy upbeat electronic dance music with synths"] -wav = model.generate(descriptions) - -# Save audio -torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000) -``` - -### Using HuggingFace Transformers - -```python -from transformers import AutoProcessor, MusicgenForConditionalGeneration -import scipy - -# Load model and processor -processor = AutoProcessor.from_pretrained("facebook/musicgen-small") -model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") -model.to("cuda") - -# Generate music -inputs = processor( - text=["80s pop track with bassy drums and synth"], - padding=True, - return_tensors="pt" -).to("cuda") - -audio_values = model.generate( - **inputs, - do_sample=True, - guidance_scale=3, - max_new_tokens=256 -) - -# Save -sampling_rate = model.config.audio_encoder.sampling_rate -scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy()) -``` - -### Text-to-sound with AudioGen - -```python -from audiocraft.models import AudioGen - -# Load AudioGen -model = AudioGen.get_pretrained('facebook/audiogen-medium') - -model.set_generation_params(duration=5) - -# Generate sound effects -descriptions = ["dog barking in a park with birds chirping"] -wav = model.generate(descriptions) - -torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000) -``` - -## Core concepts - -### Architecture overview - -``` -AudioCraft Architecture: -┌──────────────────────────────────────────────────────────────┐ -│ Text Encoder (T5) │ -│ │ │ -│ Text Embeddings │ -└────────────────────────┬─────────────────────────────────────┘ - │ -┌────────────────────────▼─────────────────────────────────────┐ -│ Transformer Decoder (LM) │ -│ Auto-regressively generates audio tokens │ -│ Using efficient token interleaving patterns │ -└────────────────────────┬─────────────────────────────────────┘ - │ -┌────────────────────────▼─────────────────────────────────────┐ -│ EnCodec Audio Decoder │ -│ Converts tokens back to audio waveform │ -└──────────────────────────────────────────────────────────────┘ -``` - -### Model variants - -| Model | Size | Description | Use Case | -|-------|------|-------------|----------| -| `musicgen-small` | 300M | Text-to-music | Quick generation | -| `musicgen-medium` | 1.5B | Text-to-music | Balanced | -| `musicgen-large` | 3.3B | Text-to-music | Best quality | -| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning | -| `musicgen-melody-large` | 3.3B | Text + melody | Best melody | -| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation | -| `musicgen-style` | 1.5B | Style transfer | Reference-based | -| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects | - -### Generation parameters - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `duration` | 8.0 | Length in seconds (1-120) | -| `top_k` | 250 | Top-k sampling | -| `top_p` | 0.0 | Nucleus sampling (0 = disabled) | -| `temperature` | 1.0 | Sampling temperature | -| `cfg_coef` | 3.0 | Classifier-free guidance | - -## MusicGen usage - -### Text-to-music generation - -```python -from audiocraft.models import MusicGen -import torchaudio - -model = MusicGen.get_pretrained('facebook/musicgen-medium') - -# Configure generation -model.set_generation_params( - duration=30, # Up to 30 seconds - top_k=250, # Sampling diversity - top_p=0.0, # 0 = use top_k only - temperature=1.0, # Creativity (higher = more varied) - cfg_coef=3.0 # Text adherence (higher = stricter) -) - -# Generate multiple samples -descriptions = [ - "epic orchestral soundtrack with strings and brass", - "chill lo-fi hip hop beat with jazzy piano", - "energetic rock song with electric guitar" -] - -# Generate (returns [batch, channels, samples]) -wav = model.generate(descriptions) - -# Save each -for i, audio in enumerate(wav): - torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000) -``` - -### Melody-conditioned generation - -```python -from audiocraft.models import MusicGen -import torchaudio - -# Load melody model -model = MusicGen.get_pretrained('facebook/musicgen-melody') -model.set_generation_params(duration=30) - -# Load melody audio -melody, sr = torchaudio.load("melody.wav") - -# Generate with melody conditioning -descriptions = ["acoustic guitar folk song"] -wav = model.generate_with_chroma(descriptions, melody, sr) - -torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000) -``` - -### Stereo generation - -```python -from audiocraft.models import MusicGen - -# Load stereo model -model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium') -model.set_generation_params(duration=15) - -descriptions = ["ambient electronic music with wide stereo panning"] -wav = model.generate(descriptions) - -# wav shape: [batch, 2, samples] for stereo -print(f"Stereo shape: {wav.shape}") # [1, 2, 480000] -torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000) -``` - -### Audio continuation - -```python -from transformers import AutoProcessor, MusicgenForConditionalGeneration - -processor = AutoProcessor.from_pretrained("facebook/musicgen-medium") -model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium") - -# Load audio to continue -import torchaudio -audio, sr = torchaudio.load("intro.wav") - -# Process with text and audio -inputs = processor( - audio=audio.squeeze().numpy(), - sampling_rate=sr, - text=["continue with a epic chorus"], - padding=True, - return_tensors="pt" -) - -# Generate continuation -audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512) -``` - -## MusicGen-Style usage - -### Style-conditioned generation - -```python -from audiocraft.models import MusicGen - -# Load style model -model = MusicGen.get_pretrained('facebook/musicgen-style') - -# Configure generation with style -model.set_generation_params( - duration=30, - cfg_coef=3.0, - cfg_coef_beta=5.0 # Style influence -) - -# Configure style conditioner -model.set_style_conditioner_params( - eval_q=3, # RVQ quantizers (1-6) - excerpt_length=3.0 # Style excerpt length -) - -# Load style reference -style_audio, sr = torchaudio.load("reference_style.wav") - -# Generate with text + style -descriptions = ["upbeat dance track"] -wav = model.generate_with_style(descriptions, style_audio, sr) -``` - -### Style-only generation (no text) - -```python -# Generate matching style without text prompt -model.set_generation_params( - duration=30, - cfg_coef=3.0, - cfg_coef_beta=None # Disable double CFG for style-only -) - -wav = model.generate_with_style([None], style_audio, sr) -``` - -## AudioGen usage - -### Sound effect generation - -```python -from audiocraft.models import AudioGen -import torchaudio - -model = AudioGen.get_pretrained('facebook/audiogen-medium') -model.set_generation_params(duration=10) - -# Generate various sounds -descriptions = [ - "thunderstorm with heavy rain and lightning", - "busy city traffic with car horns", - "ocean waves crashing on rocks", - "crackling campfire in forest" -] - -wav = model.generate(descriptions) - -for i, audio in enumerate(wav): - torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000) -``` - -## EnCodec usage - -### Audio compression - -```python -from audiocraft.models import CompressionModel -import torch -import torchaudio - -# Load EnCodec -model = CompressionModel.get_pretrained('facebook/encodec_32khz') - -# Load audio -wav, sr = torchaudio.load("audio.wav") - -# Ensure correct sample rate -if sr != 32000: - resampler = torchaudio.transforms.Resample(sr, 32000) - wav = resampler(wav) - -# Encode to tokens -with torch.no_grad(): - encoded = model.encode(wav.unsqueeze(0)) - codes = encoded[0] # Audio codes - -# Decode back to audio -with torch.no_grad(): - decoded = model.decode(codes) - -torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000) -``` - -## Common workflows - -### Workflow 1: Music generation pipeline - -```python -import torch -import torchaudio -from audiocraft.models import MusicGen - -class MusicGenerator: - def __init__(self, model_name="facebook/musicgen-medium"): - self.model = MusicGen.get_pretrained(model_name) - self.sample_rate = 32000 - - def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0): - self.model.set_generation_params( - duration=duration, - top_k=250, - temperature=temperature, - cfg_coef=cfg - ) - - with torch.no_grad(): - wav = self.model.generate([prompt]) - - return wav[0].cpu() - - def generate_batch(self, prompts, duration=30): - self.model.set_generation_params(duration=duration) - - with torch.no_grad(): - wav = self.model.generate(prompts) - - return wav.cpu() - - def save(self, audio, path): - torchaudio.save(path, audio, sample_rate=self.sample_rate) - -# Usage -generator = MusicGenerator() -audio = generator.generate( - "epic cinematic orchestral music", - duration=30, - temperature=1.0 -) -generator.save(audio, "epic_music.wav") -``` - -### Workflow 2: Sound design batch processing - -```python -import json -from pathlib import Path -from audiocraft.models import AudioGen -import torchaudio - -def batch_generate_sounds(sound_specs, output_dir): - """ - Generate multiple sounds from specifications. - - Args: - sound_specs: list of {"name": str, "description": str, "duration": float} - output_dir: output directory path - """ - model = AudioGen.get_pretrained('facebook/audiogen-medium') - output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True) - - results = [] - - for spec in sound_specs: - model.set_generation_params(duration=spec.get("duration", 5)) - - wav = model.generate([spec["description"]]) - - output_path = output_dir / f"{spec['name']}.wav" - torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000) - - results.append({ - "name": spec["name"], - "path": str(output_path), - "description": spec["description"] - }) - - return results - -# Usage -sounds = [ - {"name": "explosion", "description": "massive explosion with debris", "duration": 3}, - {"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5}, - {"name": "door", "description": "wooden door creaking and closing", "duration": 2} -] - -results = batch_generate_sounds(sounds, "sound_effects/") -``` - -### Workflow 3: Gradio demo - -```python -import gradio as gr -import torch -import torchaudio -from audiocraft.models import MusicGen - -model = MusicGen.get_pretrained('facebook/musicgen-small') - -def generate_music(prompt, duration, temperature, cfg_coef): - model.set_generation_params( - duration=duration, - temperature=temperature, - cfg_coef=cfg_coef - ) - - with torch.no_grad(): - wav = model.generate([prompt]) - - # Save to temp file - path = "temp_output.wav" - torchaudio.save(path, wav[0].cpu(), sample_rate=32000) - return path - -demo = gr.Interface( - fn=generate_music, - inputs=[ - gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"), - gr.Slider(1, 30, value=8, label="Duration (seconds)"), - gr.Slider(0.5, 2.0, value=1.0, label="Temperature"), - gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient") - ], - outputs=gr.Audio(label="Generated Music"), - title="MusicGen Demo" -) - -demo.launch() -``` - -## Performance optimization - -### Memory optimization - -```python -# Use smaller model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Clear cache between generations -torch.cuda.empty_cache() - -# Generate shorter durations -model.set_generation_params(duration=10) # Instead of 30 - -# Use half precision -model = model.half() -``` - -### Batch processing efficiency - -```python -# Process multiple prompts at once (more efficient) -descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"] -wav = model.generate(descriptions) # Single batch - -# Instead of -for desc in descriptions: - wav = model.generate([desc]) # Multiple batches (slower) -``` - -### GPU memory requirements - -| Model | FP32 VRAM | FP16 VRAM | -|-------|-----------|-----------| -| musicgen-small | ~4GB | ~2GB | -| musicgen-medium | ~8GB | ~4GB | -| musicgen-large | ~16GB | ~8GB | - -## Common issues - -| Issue | Solution | -|-------|----------| -| CUDA OOM | Use smaller model, reduce duration | -| Poor quality | Increase cfg_coef, better prompts | -| Generation too short | Check max duration setting | -| Audio artifacts | Try different temperature | -| Stereo not working | Use stereo model variant | - -## References - -- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment -- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions - -## Resources - -- **GitHub**: https://github.com/facebookresearch/audiocraft -- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284 -- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352 -- **HuggingFace**: https://huggingface.co/facebook/musicgen-small -- **Demo**: https://huggingface.co/spaces/facebook/MusicGen diff --git a/skills/mlops/audiocraft/references/advanced-usage.md b/skills/mlops/audiocraft/references/advanced-usage.md deleted file mode 100644 index 953be2b4a..000000000 --- a/skills/mlops/audiocraft/references/advanced-usage.md +++ /dev/null @@ -1,666 +0,0 @@ -# AudioCraft Advanced Usage Guide - -## Fine-tuning MusicGen - -### Custom dataset preparation - -```python -import os -import json -from pathlib import Path -import torchaudio - -def prepare_dataset(audio_dir, output_dir, metadata_file): - """ - Prepare dataset for MusicGen fine-tuning. - - Directory structure: - output_dir/ - ├── audio/ - │ ├── 0001.wav - │ ├── 0002.wav - │ └── ... - └── metadata.json - """ - output_dir = Path(output_dir) - audio_output = output_dir / "audio" - audio_output.mkdir(parents=True, exist_ok=True) - - # Load metadata (format: {"path": "...", "description": "..."}) - with open(metadata_file) as f: - metadata = json.load(f) - - processed = [] - - for idx, item in enumerate(metadata): - audio_path = Path(audio_dir) / item["path"] - - # Load and resample to 32kHz - wav, sr = torchaudio.load(str(audio_path)) - if sr != 32000: - resampler = torchaudio.transforms.Resample(sr, 32000) - wav = resampler(wav) - - # Convert to mono if stereo - if wav.shape[0] > 1: - wav = wav.mean(dim=0, keepdim=True) - - # Save processed audio - output_path = audio_output / f"{idx:04d}.wav" - torchaudio.save(str(output_path), wav, sample_rate=32000) - - processed.append({ - "path": str(output_path.relative_to(output_dir)), - "description": item["description"], - "duration": wav.shape[1] / 32000 - }) - - # Save processed metadata - with open(output_dir / "metadata.json", "w") as f: - json.dump(processed, f, indent=2) - - print(f"Processed {len(processed)} samples") - return processed -``` - -### Fine-tuning with dora - -```bash -# AudioCraft uses dora for experiment management -# Install dora -pip install dora-search - -# Clone AudioCraft -git clone https://github.com/facebookresearch/audiocraft.git -cd audiocraft - -# Create config for fine-tuning -cat > config/solver/musicgen/finetune.yaml << 'EOF' -defaults: - - musicgen/musicgen_base - - /model: lm/musicgen_lm - - /conditioner: cond_base - -solver: musicgen -autocast: true -autocast_dtype: float16 - -optim: - epochs: 100 - batch_size: 4 - lr: 1e-4 - ema: 0.999 - optimizer: adamw - -dataset: - batch_size: 4 - num_workers: 4 - train: - - dset: your_dataset - root: /path/to/dataset - valid: - - dset: your_dataset - root: /path/to/dataset - -checkpoint: - save_every: 10 - keep_every_states: null -EOF - -# Run fine-tuning -dora run solver=musicgen/finetune -``` - -### LoRA fine-tuning - -```python -from peft import LoraConfig, get_peft_model -from audiocraft.models import MusicGen -import torch - -# Load base model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Get the language model component -lm = model.lm - -# Configure LoRA -lora_config = LoraConfig( - r=8, - lora_alpha=16, - target_modules=["q_proj", "v_proj", "k_proj", "out_proj"], - lora_dropout=0.05, - bias="none" -) - -# Apply LoRA -lm = get_peft_model(lm, lora_config) -lm.print_trainable_parameters() -``` - -## Multi-GPU Training - -### DataParallel - -```python -import torch -import torch.nn as nn -from audiocraft.models import MusicGen - -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Wrap LM with DataParallel -if torch.cuda.device_count() > 1: - model.lm = nn.DataParallel(model.lm) - -model.to("cuda") -``` - -### DistributedDataParallel - -```python -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -def setup(rank, world_size): - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - -def train(rank, world_size): - setup(rank, world_size) - - model = MusicGen.get_pretrained('facebook/musicgen-small') - model.lm = model.lm.to(rank) - model.lm = DDP(model.lm, device_ids=[rank]) - - # Training loop - # ... - - dist.destroy_process_group() -``` - -## Custom Conditioning - -### Adding new conditioners - -```python -from audiocraft.modules.conditioners import BaseConditioner -import torch - -class CustomConditioner(BaseConditioner): - """Custom conditioner for additional control signals.""" - - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embed = torch.nn.Linear(dim, output_dim) - - def forward(self, x): - return self.embed(x) - - def tokenize(self, x): - # Tokenize input for conditioning - return x - -# Use with MusicGen -from audiocraft.models.builders import get_lm_model - -# Modify model config to include custom conditioner -# This requires editing the model configuration -``` - -### Melody conditioning internals - -```python -from audiocraft.models import MusicGen -from audiocraft.modules.codebooks_patterns import DelayedPatternProvider -import torch - -model = MusicGen.get_pretrained('facebook/musicgen-melody') - -# Access chroma extractor -chroma_extractor = model.lm.condition_provider.conditioners.get('chroma') - -# Manual chroma extraction -def extract_chroma(audio, sr): - """Extract chroma features from audio.""" - import librosa - - # Compute chroma - chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr) - - return torch.from_numpy(chroma).float() - -# Use extracted chroma for conditioning -chroma = extract_chroma(melody_audio, sample_rate) -``` - -## EnCodec Deep Dive - -### Custom compression settings - -```python -from audiocraft.models import CompressionModel -import torch - -# Load EnCodec -encodec = CompressionModel.get_pretrained('facebook/encodec_32khz') - -# Access codec parameters -print(f"Sample rate: {encodec.sample_rate}") -print(f"Channels: {encodec.channels}") -print(f"Cardinality: {encodec.cardinality}") # Codebook size -print(f"Num codebooks: {encodec.num_codebooks}") -print(f"Frame rate: {encodec.frame_rate}") - -# Encode with specific bandwidth -# Lower bandwidth = more compression, lower quality -encodec.set_target_bandwidth(6.0) # 6 kbps - -audio = torch.randn(1, 1, 32000) # 1 second -encoded = encodec.encode(audio) -decoded = encodec.decode(encoded[0]) -``` - -### Streaming encoding - -```python -import torch -from audiocraft.models import CompressionModel - -encodec = CompressionModel.get_pretrained('facebook/encodec_32khz') - -def encode_streaming(audio_stream, chunk_size=32000): - """Encode audio in streaming fashion.""" - all_codes = [] - - for chunk in audio_stream: - # Ensure chunk is right shape - if chunk.dim() == 1: - chunk = chunk.unsqueeze(0).unsqueeze(0) - - with torch.no_grad(): - codes = encodec.encode(chunk)[0] - all_codes.append(codes) - - return torch.cat(all_codes, dim=-1) - -def decode_streaming(codes_stream, output_stream): - """Decode codes in streaming fashion.""" - for codes in codes_stream: - with torch.no_grad(): - audio = encodec.decode(codes) - output_stream.write(audio.cpu().numpy()) -``` - -## MultiBand Diffusion - -### Using MBD for enhanced quality - -```python -from audiocraft.models import MusicGen, MultiBandDiffusion - -# Load MusicGen -model = MusicGen.get_pretrained('facebook/musicgen-medium') - -# Load MultiBand Diffusion -mbd = MultiBandDiffusion.get_mbd_musicgen() - -model.set_generation_params(duration=10) - -# Generate with standard decoder -descriptions = ["epic orchestral music"] -wav_standard = model.generate(descriptions) - -# Generate tokens and use MBD decoder -with torch.no_grad(): - # Get tokens - gen_tokens = model.generate_tokens(descriptions) - - # Decode with MBD - wav_mbd = mbd.tokens_to_wav(gen_tokens) - -# Compare quality -print(f"Standard shape: {wav_standard.shape}") -print(f"MBD shape: {wav_mbd.shape}") -``` - -## API Server Deployment - -### FastAPI server - -```python -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -import torch -import torchaudio -from audiocraft.models import MusicGen -import io -import base64 - -app = FastAPI() - -# Load model at startup -model = None - -@app.on_event("startup") -async def load_model(): - global model - model = MusicGen.get_pretrained('facebook/musicgen-small') - model.set_generation_params(duration=10) - -class GenerateRequest(BaseModel): - prompt: str - duration: float = 10.0 - temperature: float = 1.0 - cfg_coef: float = 3.0 - -class GenerateResponse(BaseModel): - audio_base64: str - sample_rate: int - duration: float - -@app.post("/generate", response_model=GenerateResponse) -async def generate(request: GenerateRequest): - if model is None: - raise HTTPException(status_code=500, detail="Model not loaded") - - try: - model.set_generation_params( - duration=min(request.duration, 30), - temperature=request.temperature, - cfg_coef=request.cfg_coef - ) - - with torch.no_grad(): - wav = model.generate([request.prompt]) - - # Convert to bytes - buffer = io.BytesIO() - torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav") - buffer.seek(0) - - audio_base64 = base64.b64encode(buffer.read()).decode() - - return GenerateResponse( - audio_base64=audio_base64, - sample_rate=32000, - duration=wav.shape[-1] / 32000 - ) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/health") -async def health(): - return {"status": "ok", "model_loaded": model is not None} - -# Run: uvicorn server:app --host 0.0.0.0 --port 8000 -``` - -### Batch processing service - -```python -import asyncio -from concurrent.futures import ThreadPoolExecutor -import torch -from audiocraft.models import MusicGen - -class MusicGenService: - def __init__(self, model_name='facebook/musicgen-small', max_workers=2): - self.model = MusicGen.get_pretrained(model_name) - self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.lock = asyncio.Lock() - - async def generate_async(self, prompt, duration=10): - """Async generation with thread pool.""" - loop = asyncio.get_event_loop() - - def _generate(): - with torch.no_grad(): - self.model.set_generation_params(duration=duration) - return self.model.generate([prompt]) - - # Run in thread pool - wav = await loop.run_in_executor(self.executor, _generate) - return wav[0].cpu() - - async def generate_batch_async(self, prompts, duration=10): - """Process multiple prompts concurrently.""" - tasks = [self.generate_async(p, duration) for p in prompts] - return await asyncio.gather(*tasks) - -# Usage -service = MusicGenService() - -async def main(): - prompts = ["jazz piano", "rock guitar", "electronic beats"] - results = await service.generate_batch_async(prompts) - return results -``` - -## Integration Patterns - -### LangChain tool - -```python -from langchain.tools import BaseTool -import torch -import torchaudio -from audiocraft.models import MusicGen -import tempfile - -class MusicGeneratorTool(BaseTool): - name = "music_generator" - description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments." - - def __init__(self): - super().__init__() - self.model = MusicGen.get_pretrained('facebook/musicgen-small') - self.model.set_generation_params(duration=15) - - def _run(self, description: str) -> str: - with torch.no_grad(): - wav = self.model.generate([description]) - - # Save to temp file - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: - torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000) - return f"Generated music saved to: {f.name}" - - async def _arun(self, description: str) -> str: - return self._run(description) -``` - -### Gradio with advanced controls - -```python -import gradio as gr -import torch -import torchaudio -from audiocraft.models import MusicGen - -models = {} - -def load_model(model_size): - if model_size not in models: - model_name = f"facebook/musicgen-{model_size}" - models[model_size] = MusicGen.get_pretrained(model_name) - return models[model_size] - -def generate(prompt, duration, temperature, cfg_coef, top_k, model_size): - model = load_model(model_size) - - model.set_generation_params( - duration=duration, - temperature=temperature, - cfg_coef=cfg_coef, - top_k=top_k - ) - - with torch.no_grad(): - wav = model.generate([prompt]) - - # Save - path = "output.wav" - torchaudio.save(path, wav[0].cpu(), sample_rate=32000) - return path - -demo = gr.Interface( - fn=generate, - inputs=[ - gr.Textbox(label="Prompt", lines=3), - gr.Slider(1, 30, value=10, label="Duration (s)"), - gr.Slider(0.1, 2.0, value=1.0, label="Temperature"), - gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"), - gr.Slider(50, 500, value=250, step=50, label="Top-K"), - gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size") - ], - outputs=gr.Audio(label="Generated Music"), - title="MusicGen Advanced", - allow_flagging="never" -) - -demo.launch(share=True) -``` - -## Audio Processing Pipeline - -### Post-processing chain - -```python -import torch -import torchaudio -import torchaudio.transforms as T -import numpy as np - -class AudioPostProcessor: - def __init__(self, sample_rate=32000): - self.sample_rate = sample_rate - - def normalize(self, audio, target_db=-14.0): - """Normalize audio to target loudness.""" - rms = torch.sqrt(torch.mean(audio ** 2)) - target_rms = 10 ** (target_db / 20) - gain = target_rms / (rms + 1e-8) - return audio * gain - - def fade_in_out(self, audio, fade_duration=0.1): - """Apply fade in/out.""" - fade_samples = int(fade_duration * self.sample_rate) - - # Create fade curves - fade_in = torch.linspace(0, 1, fade_samples) - fade_out = torch.linspace(1, 0, fade_samples) - - # Apply fades - audio[..., :fade_samples] *= fade_in - audio[..., -fade_samples:] *= fade_out - - return audio - - def apply_reverb(self, audio, decay=0.5): - """Apply simple reverb effect.""" - impulse = torch.zeros(int(self.sample_rate * 0.5)) - impulse[0] = 1.0 - impulse[int(self.sample_rate * 0.1)] = decay * 0.5 - impulse[int(self.sample_rate * 0.2)] = decay * 0.25 - - # Convolve - audio = torch.nn.functional.conv1d( - audio.unsqueeze(0), - impulse.unsqueeze(0).unsqueeze(0), - padding=len(impulse) // 2 - ).squeeze(0) - - return audio - - def process(self, audio): - """Full processing pipeline.""" - audio = self.normalize(audio) - audio = self.fade_in_out(audio) - return audio - -# Usage with MusicGen -from audiocraft.models import MusicGen - -model = MusicGen.get_pretrained('facebook/musicgen-small') -model.set_generation_params(duration=10) - -wav = model.generate(["chill ambient music"]) -processor = AudioPostProcessor() -wav_processed = processor.process(wav[0].cpu()) - -torchaudio.save("processed.wav", wav_processed, sample_rate=32000) -``` - -## Evaluation - -### Audio quality metrics - -```python -import torch -from audiocraft.metrics import CLAPTextConsistencyMetric -from audiocraft.data.audio import audio_read - -def evaluate_generation(audio_path, text_prompt): - """Evaluate generated audio quality.""" - # Load audio - wav, sr = audio_read(audio_path) - - # CLAP consistency (text-audio alignment) - clap_metric = CLAPTextConsistencyMetric() - clap_score = clap_metric.compute(wav, [text_prompt]) - - return { - "clap_score": clap_score, - "duration": wav.shape[-1] / sr - } - -# Batch evaluation -def evaluate_batch(generations): - """Evaluate multiple generations.""" - results = [] - for gen in generations: - result = evaluate_generation(gen["path"], gen["prompt"]) - result["prompt"] = gen["prompt"] - results.append(result) - - # Aggregate - avg_clap = sum(r["clap_score"] for r in results) / len(results) - return { - "individual": results, - "average_clap": avg_clap - } -``` - -## Model Comparison - -### MusicGen variants benchmark - -| Model | CLAP Score | Generation Time (10s) | VRAM | -|-------|------------|----------------------|------| -| musicgen-small | 0.35 | ~5s | 2GB | -| musicgen-medium | 0.42 | ~15s | 4GB | -| musicgen-large | 0.48 | ~30s | 8GB | -| musicgen-melody | 0.45 | ~15s | 4GB | -| musicgen-stereo-medium | 0.41 | ~18s | 5GB | - -### Prompt engineering tips - -```python -# Good prompts - specific and descriptive -good_prompts = [ - "upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm", - "melancholic piano ballad with strings, slow tempo, emotional and cinematic", - "funky disco groove with slap bass, brass section, and rhythmic guitar" -] - -# Bad prompts - too vague -bad_prompts = [ - "nice music", - "song", - "good beat" -] - -# Structure: [mood] [genre] with [instruments] at [tempo/style] -``` diff --git a/skills/mlops/audiocraft/references/troubleshooting.md b/skills/mlops/audiocraft/references/troubleshooting.md deleted file mode 100644 index 7b83e863d..000000000 --- a/skills/mlops/audiocraft/references/troubleshooting.md +++ /dev/null @@ -1,504 +0,0 @@ -# AudioCraft Troubleshooting Guide - -## Installation Issues - -### Import errors - -**Error**: `ModuleNotFoundError: No module named 'audiocraft'` - -**Solutions**: -```bash -# Install from PyPI -pip install audiocraft - -# Or from GitHub -pip install git+https://github.com/facebookresearch/audiocraft.git - -# Verify installation -python -c "from audiocraft.models import MusicGen; print('OK')" -``` - -### FFmpeg not found - -**Error**: `RuntimeError: ffmpeg not found` - -**Solutions**: -```bash -# Ubuntu/Debian -sudo apt-get install ffmpeg - -# macOS -brew install ffmpeg - -# Windows (using conda) -conda install -c conda-forge ffmpeg - -# Verify -ffmpeg -version -``` - -### PyTorch CUDA mismatch - -**Error**: `RuntimeError: CUDA error: no kernel image is available` - -**Solutions**: -```bash -# Check CUDA version -nvcc --version -python -c "import torch; print(torch.version.cuda)" - -# Install matching PyTorch -pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121 - -# For CUDA 11.8 -pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118 -``` - -### xformers issues - -**Error**: `ImportError: xformers` related errors - -**Solutions**: -```bash -# Install xformers for memory efficiency -pip install xformers - -# Or disable xformers -export AUDIOCRAFT_USE_XFORMERS=0 - -# In Python -import os -os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0" -from audiocraft.models import MusicGen -``` - -## Model Loading Issues - -### Out of memory during load - -**Error**: `torch.cuda.OutOfMemoryError` during model loading - -**Solutions**: -```python -# Use smaller model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Force CPU loading first -import torch -device = "cpu" -model = MusicGen.get_pretrained('facebook/musicgen-small', device=device) -model = model.to("cuda") - -# Use HuggingFace with device_map -from transformers import MusicgenForConditionalGeneration -model = MusicgenForConditionalGeneration.from_pretrained( - "facebook/musicgen-small", - device_map="auto" -) -``` - -### Download failures - -**Error**: Connection errors or incomplete downloads - -**Solutions**: -```python -# Set cache directory -import os -os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache" - -# Or for HuggingFace -os.environ["HF_HOME"] = "/path/to/hf_cache" - -# Resume download -from huggingface_hub import snapshot_download -snapshot_download("facebook/musicgen-small", resume_download=True) - -# Use local files -model = MusicGen.get_pretrained('/local/path/to/model') -``` - -### Wrong model type - -**Error**: Loading wrong model for task - -**Solutions**: -```python -# For text-to-music: use MusicGen -from audiocraft.models import MusicGen -model = MusicGen.get_pretrained('facebook/musicgen-medium') - -# For text-to-sound: use AudioGen -from audiocraft.models import AudioGen -model = AudioGen.get_pretrained('facebook/audiogen-medium') - -# For melody conditioning: use melody variant -model = MusicGen.get_pretrained('facebook/musicgen-melody') - -# For stereo: use stereo variant -model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium') -``` - -## Generation Issues - -### Empty or silent output - -**Problem**: Generated audio is silent or very quiet - -**Solutions**: -```python -import torch - -# Check output -wav = model.generate(["upbeat music"]) -print(f"Shape: {wav.shape}") -print(f"Max amplitude: {wav.abs().max().item()}") -print(f"Mean amplitude: {wav.abs().mean().item()}") - -# If too quiet, normalize -def normalize_audio(audio, target_db=-14.0): - rms = torch.sqrt(torch.mean(audio ** 2)) - target_rms = 10 ** (target_db / 20) - gain = target_rms / (rms + 1e-8) - return audio * gain - -wav_normalized = normalize_audio(wav) -``` - -### Poor quality output - -**Problem**: Generated music sounds bad or noisy - -**Solutions**: -```python -# Use larger model -model = MusicGen.get_pretrained('facebook/musicgen-large') - -# Adjust generation parameters -model.set_generation_params( - duration=15, - top_k=250, # Increase for more diversity - temperature=0.8, # Lower for more focused output - cfg_coef=4.0 # Increase for better text adherence -) - -# Use better prompts -# Bad: "music" -# Good: "upbeat electronic dance music with synthesizers and punchy drums" - -# Try MultiBand Diffusion -from audiocraft.models import MultiBandDiffusion -mbd = MultiBandDiffusion.get_mbd_musicgen() -tokens = model.generate_tokens(["prompt"]) -wav = mbd.tokens_to_wav(tokens) -``` - -### Generation too short - -**Problem**: Audio shorter than expected - -**Solutions**: -```python -# Check duration setting -model.set_generation_params(duration=30) # Set before generate - -# Verify in generation -print(f"Duration setting: {model.generation_params}") - -# Check output shape -wav = model.generate(["prompt"]) -actual_duration = wav.shape[-1] / 32000 -print(f"Actual duration: {actual_duration}s") - -# Note: max duration is typically 30s -``` - -### Melody conditioning fails - -**Error**: Issues with melody-conditioned generation - -**Solutions**: -```python -import torchaudio -from audiocraft.models import MusicGen - -# Load melody model (not base model) -model = MusicGen.get_pretrained('facebook/musicgen-melody') - -# Load and prepare melody -melody, sr = torchaudio.load("melody.wav") - -# Resample to model sample rate if needed -if sr != 32000: - resampler = torchaudio.transforms.Resample(sr, 32000) - melody = resampler(melody) - -# Ensure correct shape [batch, channels, samples] -if melody.dim() == 1: - melody = melody.unsqueeze(0).unsqueeze(0) -elif melody.dim() == 2: - melody = melody.unsqueeze(0) - -# Convert stereo to mono -if melody.shape[1] > 1: - melody = melody.mean(dim=1, keepdim=True) - -# Generate with melody -model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30)) -wav = model.generate_with_chroma(["piano cover"], melody, 32000) -``` - -## Memory Issues - -### CUDA out of memory - -**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory` - -**Solutions**: -```python -import torch - -# Clear cache before generation -torch.cuda.empty_cache() - -# Use smaller model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Reduce duration -model.set_generation_params(duration=10) # Instead of 30 - -# Generate one at a time -for prompt in prompts: - wav = model.generate([prompt]) - save_audio(wav) - torch.cuda.empty_cache() - -# Use CPU for very large generations -model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu") -``` - -### Memory leak during batch processing - -**Problem**: Memory grows over time - -**Solutions**: -```python -import gc -import torch - -def generate_with_cleanup(model, prompts): - results = [] - - for prompt in prompts: - with torch.no_grad(): - wav = model.generate([prompt]) - results.append(wav.cpu()) - - # Cleanup - del wav - gc.collect() - torch.cuda.empty_cache() - - return results - -# Use context manager -with torch.inference_mode(): - wav = model.generate(["prompt"]) -``` - -## Audio Format Issues - -### Wrong sample rate - -**Problem**: Audio plays at wrong speed - -**Solutions**: -```python -import torchaudio - -# MusicGen outputs at 32kHz -sample_rate = 32000 - -# AudioGen outputs at 16kHz -sample_rate = 16000 - -# Always use correct rate when saving -torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate) - -# Resample if needed -resampler = torchaudio.transforms.Resample(32000, 44100) -wav_resampled = resampler(wav) -``` - -### Stereo/mono mismatch - -**Problem**: Wrong number of channels - -**Solutions**: -```python -# Check model type -print(f"Audio channels: {wav.shape}") -# Mono: [batch, 1, samples] -# Stereo: [batch, 2, samples] - -# Convert mono to stereo -if wav.shape[1] == 1: - wav_stereo = wav.repeat(1, 2, 1) - -# Convert stereo to mono -if wav.shape[1] == 2: - wav_mono = wav.mean(dim=1, keepdim=True) - -# Use stereo model for stereo output -model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium') -``` - -### Clipping and distortion - -**Problem**: Audio has clipping or distortion - -**Solutions**: -```python -import torch - -# Check for clipping -max_val = wav.abs().max().item() -print(f"Max amplitude: {max_val}") - -# Normalize to prevent clipping -if max_val > 1.0: - wav = wav / max_val - -# Apply soft clipping -def soft_clip(x, threshold=0.9): - return torch.tanh(x / threshold) * threshold - -wav_clipped = soft_clip(wav) - -# Lower temperature during generation -model.set_generation_params(temperature=0.7) # More controlled -``` - -## HuggingFace Transformers Issues - -### Processor errors - -**Error**: Issues with MusicgenProcessor - -**Solutions**: -```python -from transformers import AutoProcessor, MusicgenForConditionalGeneration - -# Load matching processor and model -processor = AutoProcessor.from_pretrained("facebook/musicgen-small") -model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") - -# Ensure inputs are on same device -inputs = processor( - text=["prompt"], - padding=True, - return_tensors="pt" -).to("cuda") - -# Check processor configuration -print(processor.tokenizer) -print(processor.feature_extractor) -``` - -### Generation parameter errors - -**Error**: Invalid generation parameters - -**Solutions**: -```python -# HuggingFace uses different parameter names -audio_values = model.generate( - **inputs, - do_sample=True, # Enable sampling - guidance_scale=3.0, # CFG (not cfg_coef) - max_new_tokens=256, # Token limit (not duration) - temperature=1.0 -) - -# Calculate tokens from duration -# ~50 tokens per second -duration_seconds = 10 -max_tokens = duration_seconds * 50 -audio_values = model.generate(**inputs, max_new_tokens=max_tokens) -``` - -## Performance Issues - -### Slow generation - -**Problem**: Generation takes too long - -**Solutions**: -```python -# Use smaller model -model = MusicGen.get_pretrained('facebook/musicgen-small') - -# Reduce duration -model.set_generation_params(duration=10) - -# Use GPU -model.to("cuda") - -# Enable flash attention if available -# (requires compatible hardware) - -# Batch multiple prompts -prompts = ["prompt1", "prompt2", "prompt3"] -wav = model.generate(prompts) # Single batch is faster than loop - -# Use compile (PyTorch 2.0+) -model.lm = torch.compile(model.lm) -``` - -### CPU fallback - -**Problem**: Generation running on CPU instead of GPU - -**Solutions**: -```python -import torch - -# Check CUDA availability -print(f"CUDA available: {torch.cuda.is_available()}") -print(f"CUDA device: {torch.cuda.get_device_name(0)}") - -# Explicitly move to GPU -model = MusicGen.get_pretrained('facebook/musicgen-small') -model.to("cuda") - -# Verify model device -print(f"Model device: {next(model.lm.parameters()).device}") -``` - -## Common Error Messages - -| Error | Cause | Solution | -|-------|-------|----------| -| `CUDA out of memory` | Model too large | Use smaller model, reduce duration | -| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg | -| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` | -| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions | -| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody | -| `Sample rate mismatch` | Wrong audio format | Resample to model rate | - -## Getting Help - -1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues -2. **HuggingFace Forums**: https://discuss.huggingface.co -3. **Paper**: https://arxiv.org/abs/2306.05284 - -### Reporting Issues - -Include: -- Python version -- PyTorch version -- CUDA version -- AudioCraft version: `pip show audiocraft` -- Full error traceback -- Minimal reproducible code -- Hardware (GPU model, VRAM) diff --git a/skills/mlops/code-review/SKILL.md b/skills/mlops/code-review/SKILL.md deleted file mode 100644 index 08efacda0..000000000 --- a/skills/mlops/code-review/SKILL.md +++ /dev/null @@ -1,81 +0,0 @@ ---- -name: code-review -description: Guidelines for performing thorough code reviews with security and quality focus ---- - -# Code Review Skill - -Use this skill when reviewing code changes, pull requests, or auditing existing code. - -## Review Checklist - -### 1. Security First -- [ ] No hardcoded secrets, API keys, or credentials -- [ ] Input validation on all user-provided data -- [ ] SQL queries use parameterized statements (no string concatenation) -- [ ] File operations validate paths (no path traversal) -- [ ] Authentication/authorization checks present where needed - -### 2. Error Handling -- [ ] All external calls (API, DB, file) have try/catch -- [ ] Errors are logged with context (but no sensitive data) -- [ ] User-facing errors are helpful but don't leak internals -- [ ] Resources are cleaned up in finally blocks or context managers - -### 3. Code Quality -- [ ] Functions do one thing and are reasonably sized (<50 lines ideal) -- [ ] Variable names are descriptive (no single letters except loops) -- [ ] No commented-out code left behind -- [ ] Complex logic has explanatory comments -- [ ] No duplicate code (DRY principle) - -### 4. Testing Considerations -- [ ] Edge cases handled (empty inputs, nulls, boundaries) -- [ ] Happy path and error paths both work -- [ ] New code has corresponding tests (if test suite exists) - -## Review Response Format - -When providing review feedback, structure it as: - -``` -## Summary -[1-2 sentence overall assessment] - -## Critical Issues (Must Fix) -- Issue 1: [description + suggested fix] -- Issue 2: ... - -## Suggestions (Nice to Have) -- Suggestion 1: [description] - -## Questions -- [Any clarifying questions about intent] -``` - -## Common Patterns to Flag - -### Python -```python -# Bad: SQL injection risk -cursor.execute(f"SELECT * FROM users WHERE id = {user_id}") - -# Good: Parameterized query -cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) -``` - -### JavaScript -```javascript -// Bad: XSS risk -element.innerHTML = userInput; - -// Good: Safe text content -element.textContent = userInput; -``` - -## Tone Guidelines - -- Be constructive, not critical -- Explain *why* something is an issue, not just *what* -- Offer solutions, not just problems -- Acknowledge good patterns you see diff --git a/skills/mlops/faiss/SKILL.md b/skills/mlops/faiss/SKILL.md deleted file mode 100644 index 2e33007b3..000000000 --- a/skills/mlops/faiss/SKILL.md +++ /dev/null @@ -1,224 +0,0 @@ ---- -name: faiss -description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [faiss-cpu, faiss-gpu, numpy] -metadata: - hermes: - tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale] - ---- - -# FAISS - Efficient Similarity Search - -Facebook AI's library for billion-scale vector similarity search. - -## When to use FAISS - -**Use FAISS when:** -- Need fast similarity search on large vector datasets (millions/billions) -- GPU acceleration required -- Pure vector similarity (no metadata filtering needed) -- High throughput, low latency critical -- Offline/batch processing of embeddings - -**Metrics**: -- **31,700+ GitHub stars** -- Meta/Facebook AI Research -- **Handles billions of vectors** -- **C++** with Python bindings - -**Use alternatives instead**: -- **Chroma/Pinecone**: Need metadata filtering -- **Weaviate**: Need full database features -- **Annoy**: Simpler, fewer features - -## Quick start - -### Installation - -```bash -# CPU only -pip install faiss-cpu - -# GPU support -pip install faiss-gpu -``` - -### Basic usage - -```python -import faiss -import numpy as np - -# Create sample data (1000 vectors, 128 dimensions) -d = 128 -nb = 1000 -vectors = np.random.random((nb, d)).astype('float32') - -# Create index -index = faiss.IndexFlatL2(d) # L2 distance -index.add(vectors) # Add vectors - -# Search -k = 5 # Find 5 nearest neighbors -query = np.random.random((1, d)).astype('float32') -distances, indices = index.search(query, k) - -print(f"Nearest neighbors: {indices}") -print(f"Distances: {distances}") -``` - -## Index types - -### 1. Flat (exact search) - -```python -# L2 (Euclidean) distance -index = faiss.IndexFlatL2(d) - -# Inner product (cosine similarity if normalized) -index = faiss.IndexFlatIP(d) - -# Slowest, most accurate -``` - -### 2. IVF (inverted file) - Fast approximate - -```python -# Create quantizer -quantizer = faiss.IndexFlatL2(d) - -# IVF index with 100 clusters -nlist = 100 -index = faiss.IndexIVFFlat(quantizer, d, nlist) - -# Train on data -index.train(vectors) - -# Add vectors -index.add(vectors) - -# Search (nprobe = clusters to search) -index.nprobe = 10 -distances, indices = index.search(query, k) -``` - -### 3. HNSW (Hierarchical NSW) - Best quality/speed - -```python -# HNSW index -M = 32 # Number of connections per layer -index = faiss.IndexHNSWFlat(d, M) - -# No training needed -index.add(vectors) - -# Search -distances, indices = index.search(query, k) -``` - -### 4. Product Quantization - Memory efficient - -```python -# PQ reduces memory by 16-32× -m = 8 # Number of subquantizers -nbits = 8 -index = faiss.IndexPQ(d, m, nbits) - -# Train and add -index.train(vectors) -index.add(vectors) -``` - -## Save and load - -```python -# Save index -faiss.write_index(index, "large.index") - -# Load index -index = faiss.read_index("large.index") - -# Continue using -distances, indices = index.search(query, k) -``` - -## GPU acceleration - -```python -# Single GPU -res = faiss.StandardGpuResources() -index_cpu = faiss.IndexFlatL2(d) -index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0 - -# Multi-GPU -index_gpu = faiss.index_cpu_to_all_gpus(index_cpu) - -# 10-100× faster than CPU -``` - -## LangChain integration - -```python -from langchain_community.vectorstores import FAISS -from langchain_openai import OpenAIEmbeddings - -# Create FAISS vector store -vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings()) - -# Save -vectorstore.save_local("faiss_index") - -# Load -vectorstore = FAISS.load_local( - "faiss_index", - OpenAIEmbeddings(), - allow_dangerous_deserialization=True -) - -# Search -results = vectorstore.similarity_search("query", k=5) -``` - -## LlamaIndex integration - -```python -from llama_index.vector_stores.faiss import FaissVectorStore -import faiss - -# Create FAISS index -d = 1536 -faiss_index = faiss.IndexFlatL2(d) - -vector_store = FaissVectorStore(faiss_index=faiss_index) -``` - -## Best practices - -1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality -2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors -3. **Use GPU for large datasets** - 10-100× faster -4. **Save trained indices** - Training is expensive -5. **Tune nprobe/ef_search** - Balance speed/accuracy -6. **Monitor memory** - PQ for large datasets -7. **Batch queries** - Better GPU utilization - -## Performance - -| Index Type | Build Time | Search Time | Memory | Accuracy | -|------------|------------|-------------|--------|----------| -| Flat | Fast | Slow | High | 100% | -| IVF | Medium | Fast | Medium | 95-99% | -| HNSW | Slow | Fastest | High | 99% | -| PQ | Medium | Fast | Low | 90-95% | - -## Resources - -- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+ -- **Wiki**: https://github.com/facebookresearch/faiss/wiki -- **License**: MIT - - diff --git a/skills/mlops/faiss/references/index_types.md b/skills/mlops/faiss/references/index_types.md deleted file mode 100644 index f75bd3e9e..000000000 --- a/skills/mlops/faiss/references/index_types.md +++ /dev/null @@ -1,280 +0,0 @@ -# FAISS Index Types Guide - -Complete guide to choosing and using FAISS index types. - -## Index selection guide - -| Dataset Size | Index Type | Training | Accuracy | Speed | -|--------------|------------|----------|----------|-------| -| < 10K | Flat | No | 100% | Slow | -| 10K-1M | IVF | Yes | 95-99% | Fast | -| 1M-10M | HNSW | No | 99% | Fastest | -| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory | - -## Flat indices (exact search) - -### IndexFlatL2 - L2 (Euclidean) distance - -```python -import faiss -import numpy as np - -d = 128 # Dimension -index = faiss.IndexFlatL2(d) - -# Add vectors -vectors = np.random.random((1000, d)).astype('float32') -index.add(vectors) - -# Search -k = 5 -query = np.random.random((1, d)).astype('float32') -distances, indices = index.search(query, k) -``` - -**Use when:** -- Dataset < 10,000 vectors -- Need 100% accuracy -- Serving as baseline - -### IndexFlatIP - Inner product (cosine similarity) - -```python -# For cosine similarity, normalize vectors first -import faiss - -d = 128 -index = faiss.IndexFlatIP(d) - -# Normalize vectors (required for cosine similarity) -faiss.normalize_L2(vectors) -index.add(vectors) - -# Search -faiss.normalize_L2(query) -distances, indices = index.search(query, k) -``` - -**Use when:** -- Need cosine similarity -- Recommendation systems -- Text embeddings - -## IVF indices (inverted file) - -### IndexIVFFlat - Cluster-based search - -```python -# Create quantizer -quantizer = faiss.IndexFlatL2(d) - -# Create IVF index with 100 clusters -nlist = 100 # Number of clusters -index = faiss.IndexIVFFlat(quantizer, d, nlist) - -# Train on data (required!) -index.train(vectors) - -# Add vectors -index.add(vectors) - -# Search (nprobe = clusters to search) -index.nprobe = 10 # Search 10 closest clusters -distances, indices = index.search(query, k) -``` - -**Parameters:** -- `nlist`: Number of clusters (√N to 4√N recommended) -- `nprobe`: Clusters to search (1-nlist, higher = more accurate) - -**Use when:** -- Dataset 10K-1M vectors -- Need fast approximate search -- Can afford training time - -### Tuning nprobe - -```python -# Test different nprobe values -for nprobe in [1, 5, 10, 20, 50]: - index.nprobe = nprobe - distances, indices = index.search(query, k) - # Measure recall/speed trade-off -``` - -**Guidelines:** -- `nprobe=1`: Fastest, ~50% recall -- `nprobe=10`: Good balance, ~95% recall -- `nprobe=nlist`: Exact search (same as Flat) - -## HNSW indices (graph-based) - -### IndexHNSWFlat - Hierarchical NSW - -```python -# HNSW index -M = 32 # Number of connections per layer (16-64) -index = faiss.IndexHNSWFlat(d, M) - -# Optional: Set ef_construction (build time parameter) -index.hnsw.efConstruction = 40 # Higher = better quality, slower build - -# Add vectors (no training needed!) -index.add(vectors) - -# Search -index.hnsw.efSearch = 16 # Search time parameter -distances, indices = index.search(query, k) -``` - -**Parameters:** -- `M`: Connections per layer (16-64, default 32) -- `efConstruction`: Build quality (40-200, higher = better) -- `efSearch`: Search quality (16-512, higher = more accurate) - -**Use when:** -- Need best quality approximate search -- Can afford higher memory (more connections) -- Dataset 1M-10M vectors - -## PQ indices (product quantization) - -### IndexPQ - Memory-efficient - -```python -# PQ reduces memory by 16-32× -m = 8 # Number of subquantizers (divides d) -nbits = 8 # Bits per subquantizer - -index = faiss.IndexPQ(d, m, nbits) - -# Train (required!) -index.train(vectors) - -# Add vectors -index.add(vectors) - -# Search -distances, indices = index.search(query, k) -``` - -**Parameters:** -- `m`: Subquantizers (d must be divisible by m) -- `nbits`: Bits per code (8 or 16) - -**Memory savings:** -- Original: d × 4 bytes (float32) -- PQ: m bytes -- Compression ratio: 4d/m - -**Use when:** -- Limited memory -- Large datasets (> 10M vectors) -- Can accept ~90-95% accuracy - -### IndexIVFPQ - IVF + PQ combined - -```python -# Best for very large datasets -nlist = 4096 -m = 8 -nbits = 8 - -quantizer = faiss.IndexFlatL2(d) -index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits) - -# Train -index.train(vectors) -index.add(vectors) - -# Search -index.nprobe = 32 -distances, indices = index.search(query, k) -``` - -**Use when:** -- Dataset > 10M vectors -- Need fast search + low memory -- Can accept 90-95% accuracy - -## GPU indices - -### Single GPU - -```python -import faiss - -# Create CPU index -index_cpu = faiss.IndexFlatL2(d) - -# Move to GPU -res = faiss.StandardGpuResources() # GPU resources -index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0 - -# Use normally -index_gpu.add(vectors) -distances, indices = index_gpu.search(query, k) -``` - -### Multi-GPU - -```python -# Use all available GPUs -index_gpu = faiss.index_cpu_to_all_gpus(index_cpu) - -# Or specific GPUs -gpus = [0, 1, 2, 3] # Use GPUs 0-3 -index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus) -``` - -**Speedup:** -- Single GPU: 10-50× faster than CPU -- Multi-GPU: Near-linear scaling - -## Index factory - -```python -# Easy index creation with string descriptors -index = faiss.index_factory(d, "IVF100,Flat") -index = faiss.index_factory(d, "HNSW32") -index = faiss.index_factory(d, "IVF4096,PQ8") - -# Train and use -index.train(vectors) -index.add(vectors) -``` - -**Common descriptors:** -- `"Flat"`: Exact search -- `"IVF100,Flat"`: IVF with 100 clusters -- `"HNSW32"`: HNSW with M=32 -- `"IVF4096,PQ8"`: IVF + PQ compression - -## Performance comparison - -### Search speed (1M vectors, k=10) - -| Index | Build Time | Search Time | Memory | Recall | -|-------|------------|-------------|--------|--------| -| Flat | 0s | 50ms | 512 MB | 100% | -| IVF100 | 5s | 2ms | 512 MB | 95% | -| HNSW32 | 60s | 1ms | 1GB | 99% | -| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% | - -*CPU (16 cores), 128-dim vectors* - -## Best practices - -1. **Start with Flat** - Baseline for comparison -2. **Use IVF for medium datasets** - Good balance -3. **Use HNSW for best quality** - If memory allows -4. **Add PQ for memory savings** - Large datasets -5. **GPU for > 100K vectors** - 10-50× speedup -6. **Tune nprobe/efSearch** - Trade-off speed/accuracy -7. **Train on representative data** - Better clustering -8. **Save trained indices** - Avoid retraining - -## Resources - -- **Wiki**: https://github.com/facebookresearch/faiss/wiki -- **Paper**: https://arxiv.org/abs/1702.08734 diff --git a/skills/mlops/flash-attention/SKILL.md b/skills/mlops/flash-attention/SKILL.md deleted file mode 100644 index 6a3839bf7..000000000 --- a/skills/mlops/flash-attention/SKILL.md +++ /dev/null @@ -1,370 +0,0 @@ ---- -name: optimizing-attention-flash -description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [flash-attn, torch, transformers] -metadata: - hermes: - tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers] - ---- - -# Flash Attention - Fast Memory-Efficient Attention - -## Quick start - -Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation. - -**PyTorch native (easiest, PyTorch 2.2+)**: -```python -import torch -import torch.nn.functional as F - -q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim] -k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) -v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) - -# Automatically uses Flash Attention if available -out = F.scaled_dot_product_attention(q, k, v) -``` - -**flash-attn library (more features)**: -```bash -pip install flash-attn --no-build-isolation -``` - -```python -from flash_attn import flash_attn_func - -# q, k, v: [batch, seqlen, nheads, headdim] -out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True) -``` - -## Common workflows - -### Workflow 1: Enable in existing PyTorch model - -Copy this checklist: - -``` -Flash Attention Integration: -- [ ] Step 1: Check PyTorch version (≥2.2) -- [ ] Step 2: Enable Flash Attention backend -- [ ] Step 3: Verify speedup with profiling -- [ ] Step 4: Test accuracy matches baseline -``` - -**Step 1: Check PyTorch version** - -```bash -python -c "import torch; print(torch.__version__)" -# Should be ≥2.2.0 -``` - -If <2.2, upgrade: -```bash -pip install --upgrade torch -``` - -**Step 2: Enable Flash Attention backend** - -Replace standard attention: -```python -# Before (standard attention) -attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1) -out = attn_weights @ v - -# After (Flash Attention) -import torch.nn.functional as F -out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) -``` - -Force Flash Attention backend: -```python -with torch.backends.cuda.sdp_kernel( - enable_flash=True, - enable_math=False, - enable_mem_efficient=False -): - out = F.scaled_dot_product_attention(q, k, v) -``` - -**Step 3: Verify speedup with profiling** - -```python -import torch.utils.benchmark as benchmark - -def test_attention(use_flash): - q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)] - - if use_flash: - with torch.backends.cuda.sdp_kernel(enable_flash=True): - return F.scaled_dot_product_attention(q, k, v) - else: - attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1) - return attn @ v - -# Benchmark -t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals()) -t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals()) - -print(f"Flash: {t_flash.timeit(100).mean:.3f}s") -print(f"Standard: {t_standard.timeit(100).mean:.3f}s") -``` - -Expected: 2-4x speedup for sequences >512 tokens. - -**Step 4: Test accuracy matches baseline** - -```python -# Compare outputs -q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)] - -# Flash Attention -out_flash = F.scaled_dot_product_attention(q, k, v) - -# Standard attention -attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1) -out_standard = attn_weights @ v - -# Check difference -diff = (out_flash - out_standard).abs().max() -print(f"Max difference: {diff:.6f}") -# Should be <1e-3 for float16 -``` - -### Workflow 2: Use flash-attn library for advanced features - -For multi-query attention, sliding window, or H100 FP8. - -Copy this checklist: - -``` -flash-attn Library Setup: -- [ ] Step 1: Install flash-attn library -- [ ] Step 2: Modify attention code -- [ ] Step 3: Enable advanced features -- [ ] Step 4: Benchmark performance -``` - -**Step 1: Install flash-attn library** - -```bash -# NVIDIA GPUs (CUDA 12.0+) -pip install flash-attn --no-build-isolation - -# Verify installation -python -c "from flash_attn import flash_attn_func; print('Success')" -``` - -**Step 2: Modify attention code** - -```python -from flash_attn import flash_attn_func - -# Input: [batch_size, seq_len, num_heads, head_dim] -# Transpose from [batch, heads, seq, dim] if needed -q = q.transpose(1, 2) # [batch, seq, heads, dim] -k = k.transpose(1, 2) -v = v.transpose(1, 2) - -out = flash_attn_func( - q, k, v, - dropout_p=0.1, - causal=True, # For autoregressive models - window_size=(-1, -1), # No sliding window - softmax_scale=None # Auto-scale -) - -out = out.transpose(1, 2) # Back to [batch, heads, seq, dim] -``` - -**Step 3: Enable advanced features** - -Multi-query attention (shared K/V across heads): -```python -from flash_attn import flash_attn_func - -# q: [batch, seq, num_q_heads, dim] -# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads -out = flash_attn_func(q, k, v) # Automatically handles MQA -``` - -Sliding window attention (local attention): -```python -# Only attend to window of 256 tokens before/after -out = flash_attn_func( - q, k, v, - window_size=(256, 256), # (left, right) window - causal=True -) -``` - -**Step 4: Benchmark performance** - -```python -import torch -from flash_attn import flash_attn_func -import time - -q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)] - -# Warmup -for _ in range(10): - _ = flash_attn_func(q, k, v) - -# Benchmark -torch.cuda.synchronize() -start = time.time() -for _ in range(100): - out = flash_attn_func(q, k, v) - torch.cuda.synchronize() -end = time.time() - -print(f"Time per iteration: {(end-start)/100*1000:.2f}ms") -print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB") -``` - -### Workflow 3: H100 FP8 optimization (FlashAttention-3) - -For maximum performance on H100 GPUs. - -``` -FP8 Setup: -- [ ] Step 1: Verify H100 GPU available -- [ ] Step 2: Install flash-attn with FP8 support -- [ ] Step 3: Convert inputs to FP8 -- [ ] Step 4: Run with FP8 attention -``` - -**Step 1: Verify H100 GPU** - -```bash -nvidia-smi --query-gpu=name --format=csv -# Should show "H100" or "H800" -``` - -**Step 2: Install flash-attn with FP8 support** - -```bash -pip install flash-attn --no-build-isolation -# FP8 support included for H100 -``` - -**Step 3: Convert inputs to FP8** - -```python -import torch - -q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16) -k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16) -v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16) - -# Convert to float8_e4m3 (FP8) -q_fp8 = q.to(torch.float8_e4m3fn) -k_fp8 = k.to(torch.float8_e4m3fn) -v_fp8 = v.to(torch.float8_e4m3fn) -``` - -**Step 4: Run with FP8 attention** - -```python -from flash_attn import flash_attn_func - -# FlashAttention-3 automatically uses FP8 kernels on H100 -out = flash_attn_func(q_fp8, k_fp8, v_fp8) -# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16 -``` - -## When to use vs alternatives - -**Use Flash Attention when:** -- Training transformers with sequences >512 tokens -- Running inference with long context (>2K tokens) -- GPU memory constrained (OOM with standard attention) -- Need 2-4x speedup without accuracy loss -- Using PyTorch 2.2+ or can install flash-attn - -**Use alternatives instead:** -- **Standard attention**: Sequences <256 tokens (overhead not worth it) -- **xFormers**: Need more attention variants (not just speed) -- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU) - -## Common issues - -**Issue: ImportError: cannot import flash_attn** - -Install with no-build-isolation flag: -```bash -pip install flash-attn --no-build-isolation -``` - -Or install CUDA toolkit first: -```bash -conda install cuda -c nvidia -pip install flash-attn --no-build-isolation -``` - -**Issue: Slower than expected (no speedup)** - -Flash Attention benefits increase with sequence length: -- <512 tokens: Minimal speedup (10-20%) -- 512-2K tokens: 2-3x speedup -- >2K tokens: 3-4x speedup - -Check sequence length is sufficient. - -**Issue: RuntimeError: CUDA error** - -Verify GPU supports Flash Attention: -```python -import torch -print(torch.cuda.get_device_capability()) -# Should be ≥(7, 5) for Turing+ -``` - -Flash Attention requires: -- Ampere (A100, A10): ✅ Full support -- Turing (T4): ✅ Supported -- Volta (V100): ❌ Not supported - -**Issue: Accuracy degradation** - -Check dtype is float16 or bfloat16 (not float32): -```python -q = q.to(torch.float16) # Or torch.bfloat16 -``` - -Flash Attention uses float16/bfloat16 for speed. Float32 not supported. - -## Advanced topics - -**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models. - -**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths. - -**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis. - -**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks. - -## Hardware requirements - -- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+ -- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory) -- **CUDA**: 12.0+ (11.8 minimum) -- **PyTorch**: 2.2+ for native support - -**Not supported**: V100 (Volta), CPU inference - -## Resources - -- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022) -- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024) -- Blog: https://tridao.me/blog/2024/flash3/ -- GitHub: https://github.com/Dao-AILab/flash-attention -- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - - diff --git a/skills/mlops/flash-attention/references/benchmarks.md b/skills/mlops/flash-attention/references/benchmarks.md deleted file mode 100644 index f798a6dda..000000000 --- a/skills/mlops/flash-attention/references/benchmarks.md +++ /dev/null @@ -1,215 +0,0 @@ -# Performance Benchmarks - -## Contents -- Speed comparisons across GPUs -- Memory usage analysis -- Scaling with sequence length -- Training vs inference performance -- Flash Attention versions comparison - -## Speed comparisons across GPUs - -### A100 80GB (Ampere) - -**Forward pass time** (milliseconds, batch=8, heads=32, dim=64): - -| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) | -|------------|----------|--------------|--------------|---------------| -| 512 | 1.2 | 0.9 | N/A | 1.3x | -| 1024 | 3.8 | 1.4 | N/A | 2.7x | -| 2048 | 14.2 | 4.8 | N/A | 3.0x | -| 4096 | 55.1 | 17.3 | N/A | 3.2x | -| 8192 | 218.5 | 66.2 | N/A | 3.3x | - -### H100 80GB (Hopper) - -**Forward pass time** (milliseconds, same config): - -| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup | -|------------|----------|--------------|---------------------|--------------------|--------------| -| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x | -| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x | -| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x | -| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x | -| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x | - -**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max). - -### A10G 24GB (Ampere) - -**Forward pass time** (milliseconds, batch=4): - -| Seq Length | Standard | Flash Attn 2 | Speedup | -|------------|----------|--------------|---------| -| 512 | 2.1 | 1.6 | 1.3x | -| 1024 | 6.8 | 2.8 | 2.4x | -| 2048 | 25.9 | 9.4 | 2.8x | -| 4096 | 102.1 | 35.2 | 2.9x | - -## Memory usage analysis - -### GPU memory consumption (batch=8, heads=32, dim=64) - -**Standard attention memory**: - -| Seq Length | Attention Matrix | KV Cache | Total | Notes | -|------------|------------------|----------|-------|-------| -| 512 | 8 MB | 32 MB | 40 MB | Manageable | -| 2048 | 128 MB | 128 MB | 256 MB | Growing | -| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large | -| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs | - -**Flash Attention 2 memory**: - -| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction | -|------------|---------------------|----------|-------|-----------| -| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% | -| 2048 | 0 MB | 128 MB | 128 MB | 50% | -| 8192 | 0 MB | 512 MB | 512 MB | 80% | -| 32768 | 0 MB | 2048 MB | 2 GB | 94% | - -**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory. - -### Memory scaling comparison - -**Llama 2 7B model memory** (float16, batch=1): - -| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? | -|----------------|-------------------|-------------------|-------------------| -| 2K | 3.2 GB | 2.1 GB | Both: Yes | -| 4K | 5.8 GB | 2.8 GB | Both: Yes | -| 8K | 12.1 GB | 4.2 GB | Both: Yes | -| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes | -| 32K | OOM | 14.2 GB | Only Flash: Yes | - -### Training memory (Llama 2 7B, batch=4) - -| Context | Standard (GB) | Flash Attn (GB) | Reduction | -|---------|---------------|-----------------|-----------| -| 2K | 18.2 | 12.4 | 32% | -| 4K | 34.8 | 16.8 | 52% | -| 8K | OOM (>40GB) | 26.2 | Fits! | - -## Scaling with sequence length - -### Computational complexity - -**Standard attention**: -- Time: O(N² × d) -- Memory: O(N² + N × d) - -**Flash Attention**: -- Time: O(N² × d) (same, but with better constants) -- Memory: O(N × d) (linear!) - -### Empirical scaling (A100, batch=1, heads=32, dim=64) - -**Time per token (milliseconds)**: - -| Sequence | 512 | 1K | 2K | 4K | 8K | 16K | -|----------|-----|-----|-----|-----|-----|------| -| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 | -| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 | -| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x | - -**Observation**: Speedup increases quadratically with sequence length! - -### Memory per token (MB) - -| Sequence | 512 | 1K | 2K | 4K | 8K | 16K | -|----------|-----|-----|-----|-----|-----|------| -| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 | -| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | - -**Observation**: Flash Attention memory per token is constant! - -## Training vs inference performance - -### Training (forward + backward, Llama 2 7B, A100) - -| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup | -|-------------|------------------------|--------------------------|---------| -| 4 × 2K | 1.2 | 3.1 | 2.6x | -| 8 × 2K | 2.1 | 5.8 | 2.8x | -| 4 × 4K | 0.4 | 1.3 | 3.3x | -| 8 × 4K | OOM | 2.4 | Enabled | -| 2 × 8K | 0.1 | 0.4 | 4.0x | - -### Inference (generation, Llama 2 7B, A100) - -| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup | -|----------------|----------------------|-------------------------|---------| -| 512 | 48 | 52 | 1.1x | -| 2K | 42 | 62 | 1.5x | -| 4K | 31 | 58 | 1.9x | -| 8K | 18 | 51 | 2.8x | -| 16K | OOM | 42 | Enabled | - -**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses). - -## Flash Attention versions comparison - -### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8) - -| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) | -|--------|-----|-----|------------|-----------| -| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 | -| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 | -| TFLOPS | 180 | 420 | 740 | 1150 | -| GPU util % | 35% | 55% | 75% | 82% | - -**Key improvements**: -- FA2: 2.3x faster than FA1 (better parallelism) -- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations) -- FA3 (FP8): 2.6x faster than FA2 (low precision) - -### Features by version - -| Feature | FA1 | FA2 | FA3 | -|---------|-----|-----|-----| -| Basic attention | ✅ | ✅ | ✅ | -| Causal masking | ✅ | ✅ | ✅ | -| Multi-query attention | ❌ | ✅ | ✅ | -| Sliding window | ❌ | ✅ | ✅ | -| Paged KV cache | ❌ | ✅ | ✅ | -| FP8 support | ❌ | ❌ | ✅ (H100 only) | -| Work partitioning | Basic | Advanced | Optimal | - -## Real-world model benchmarks - -### Llama 2 models (A100 80GB, batch=4, seq=2048) - -| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup | -|-------|--------|------------------------|--------------------------|---------| -| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x | -| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x | -| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x | - -### GPT-style models (seq=1024) - -| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup | -|-------|----------------------|-------------------------|---------| -| GPT-2 (124M) | 520 | 680 | 1.3x | -| GPT-J (6B) | 42 | 98 | 2.3x | -| GPT-NeoX (20B) | 8 | 22 | 2.75x | - -## Recommendations by use case - -**Training large models (>7B parameters)**: -- Use Flash Attention 2 on A100 -- Use Flash Attention 3 FP8 on H100 for maximum speed -- Expected: 2.5-3x speedup - -**Long context inference (>4K tokens)**: -- Flash Attention essential (enables contexts standard attention can't handle) -- Expected: 2-4x speedup, 5-10x memory reduction - -**Short sequences (<512 tokens)**: -- Flash Attention provides 1.2-1.5x speedup -- Minimal memory benefit -- Still worth enabling (no downside) - -**Multi-user serving**: -- Flash Attention reduces per-request memory -- Allows higher concurrent batch sizes -- Can serve 2-3x more users on same hardware diff --git a/skills/mlops/flash-attention/references/transformers-integration.md b/skills/mlops/flash-attention/references/transformers-integration.md deleted file mode 100644 index 48736755d..000000000 --- a/skills/mlops/flash-attention/references/transformers-integration.md +++ /dev/null @@ -1,293 +0,0 @@ -# HuggingFace Transformers Integration - -## Contents -- Enabling Flash Attention in Transformers -- Supported model architectures -- Configuration examples -- Performance comparisons -- Troubleshooting model-specific issues - -## Enabling Flash Attention in Transformers - -HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively. - -**Simple enable for any supported model**: -```python -from transformers import AutoModel - -model = AutoModel.from_pretrained( - "meta-llama/Llama-2-7b-hf", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16, - device_map="auto" -) -``` - -**Install requirements**: -```bash -pip install transformers>=4.36 -pip install flash-attn --no-build-isolation -``` - -## Supported model architectures - -As of Transformers 4.40: - -**Fully supported**: -- Llama / Llama 2 / Llama 3 -- Mistral / Mixtral -- Falcon -- GPT-NeoX -- Phi / Phi-2 / Phi-3 -- Qwen / Qwen2 -- Gemma -- Starcoder2 -- GPT-J -- OPT -- BLOOM - -**Partially supported** (encoder-decoder): -- BART -- T5 / Flan-T5 -- Whisper - -**Check support**: -```python -from transformers import AutoConfig - -config = AutoConfig.from_pretrained("model-name") -print(config._attn_implementation_internal) -# 'flash_attention_2' if supported -``` - -## Configuration examples - -### Llama 2 with Flash Attention - -```python -from transformers import AutoModelForCausalLM, AutoTokenizer -import torch - -model_id = "meta-llama/Llama-2-7b-hf" - -model = AutoModelForCausalLM.from_pretrained( - model_id, - attn_implementation="flash_attention_2", - torch_dtype=torch.float16, - device_map="auto" -) - -tokenizer = AutoTokenizer.from_pretrained(model_id) - -# Generate -inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda") -outputs = model.generate(**inputs, max_length=100) -print(tokenizer.decode(outputs[0])) -``` - -### Mistral with Flash Attention for long context - -```python -from transformers import AutoModelForCausalLM -import torch - -model = AutoModelForCausalLM.from_pretrained( - "mistralai/Mistral-7B-v0.1", - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, # Better for long context - device_map="auto", - max_position_embeddings=32768 # Extended context -) - -# Process long document (32K tokens) -long_text = "..." * 10000 -inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda") -outputs = model.generate(**inputs, max_new_tokens=512) -``` - -### Fine-tuning with Flash Attention - -```python -from transformers import Trainer, TrainingArguments -from transformers import AutoModelForCausalLM - -model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16 -) - -training_args = TrainingArguments( - output_dir="./results", - per_device_train_batch_size=4, - gradient_accumulation_steps=4, - num_train_epochs=3, - fp16=True, # Must match model dtype - optim="adamw_torch_fused" # Fast optimizer -) - -trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset -) - -trainer.train() -``` - -### Multi-GPU training - -```python -from transformers import AutoModelForCausalLM -import torch - -# Model parallelism with Flash Attention -model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-13b-hf", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16, - device_map="auto", # Automatic multi-GPU placement - max_memory={0: "20GB", 1: "20GB"} # Limit per GPU -) -``` - -## Performance comparisons - -### Memory usage (Llama 2 7B, batch=1) - -| Sequence Length | Standard Attention | Flash Attention 2 | Reduction | -|-----------------|-------------------|-------------------|-----------| -| 512 | 1.2 GB | 0.9 GB | 25% | -| 2048 | 3.8 GB | 1.4 GB | 63% | -| 8192 | 14.2 GB | 3.2 GB | 77% | -| 32768 | OOM (>24GB) | 10.8 GB | Fits! | - -### Speed (tokens/sec, A100 80GB) - -| Model | Standard | Flash Attn 2 | Speedup | -|-------|----------|--------------|---------| -| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x | -| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x | -| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x | - -### Training throughput (samples/sec) - -| Model | Batch Size | Standard | Flash Attn 2 | Speedup | -|-------|------------|----------|--------------|---------| -| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x | -| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x | -| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x | - -## Troubleshooting model-specific issues - -### Issue: Model doesn't support Flash Attention - -Check support list above. If not supported, use PyTorch SDPA as fallback: - -```python -model = AutoModelForCausalLM.from_pretrained( - "model-name", - attn_implementation="sdpa", # PyTorch native (still faster) - torch_dtype=torch.float16 -) -``` - -### Issue: CUDA out of memory during loading - -Reduce memory footprint: - -```python -model = AutoModelForCausalLM.from_pretrained( - "model-name", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16, - device_map="auto", - max_memory={0: "18GB"}, # Reserve memory for KV cache - low_cpu_mem_usage=True -) -``` - -### Issue: Slower inference than expected - -Ensure dtype matches: - -```python -# Model and inputs must both be float16/bfloat16 -model = model.to(torch.float16) -inputs = tokenizer(..., return_tensors="pt").to("cuda") -inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v - for k, v in inputs.items()} -``` - -### Issue: Different outputs vs standard attention - -Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal: - -```python -# Compare outputs -model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16) -model_flash = AutoModelForCausalLM.from_pretrained( - "model-name", - attn_implementation="flash_attention_2", - torch_dtype=torch.float16 -) - -inputs = tokenizer("Test", return_tensors="pt").to("cuda") - -with torch.no_grad(): - out_standard = model_standard(**inputs).logits - out_flash = model_flash(**inputs).logits - -diff = (out_standard - out_flash).abs().max() -print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4 -``` - -### Issue: ImportError during model loading - -Install flash-attn: -```bash -pip install flash-attn --no-build-isolation -``` - -Or disable Flash Attention: -```python -model = AutoModelForCausalLM.from_pretrained( - "model-name", - attn_implementation="eager", # Standard PyTorch - torch_dtype=torch.float16 -) -``` - -## Best practices - -1. **Always use float16/bfloat16** with Flash Attention (not float32) -2. **Set device_map="auto"** for automatic memory management -3. **Use bfloat16 for long context** (better numerical stability) -4. **Enable gradient checkpointing** for training large models -5. **Monitor memory** with `torch.cuda.max_memory_allocated()` - -**Example with all best practices**: -```python -from transformers import AutoModelForCausalLM, TrainingArguments - -model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, # Better for training - device_map="auto", - low_cpu_mem_usage=True -) - -# Enable gradient checkpointing for memory -model.gradient_checkpointing_enable() - -# Training with optimizations -training_args = TrainingArguments( - output_dir="./results", - per_device_train_batch_size=8, - gradient_accumulation_steps=2, - bf16=True, # Match model dtype - optim="adamw_torch_fused", - gradient_checkpointing=True -) -``` diff --git a/skills/mlops/gguf/SKILL.md b/skills/mlops/gguf/SKILL.md deleted file mode 100644 index 21bb176c8..000000000 --- a/skills/mlops/gguf/SKILL.md +++ /dev/null @@ -1,430 +0,0 @@ ---- -name: gguf-quantization -description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [llama-cpp-python>=0.2.0] -metadata: - hermes: - tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization] - ---- - -# GGUF - Quantization Format for llama.cpp - -The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options. - -## When to use GGUF - -**Use GGUF when:** -- Deploying on consumer hardware (laptops, desktops) -- Running on Apple Silicon (M1/M2/M3) with Metal acceleration -- Need CPU inference without GPU requirements -- Want flexible quantization (Q2_K to Q8_0) -- Using local AI tools (LM Studio, Ollama, text-generation-webui) - -**Key advantages:** -- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support -- **No Python runtime**: Pure C/C++ inference -- **Flexible quantization**: 2-8 bit with various methods (K-quants) -- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more -- **imatrix**: Importance matrix for better low-bit quality - -**Use alternatives instead:** -- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs -- **HQQ**: Fast calibration-free quantization for HuggingFace -- **bitsandbytes**: Simple integration with transformers library -- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed - -## Quick start - -### Installation - -```bash -# Clone llama.cpp -git clone https://github.com/ggml-org/llama.cpp -cd llama.cpp - -# Build (CPU) -make - -# Build with CUDA (NVIDIA) -make GGML_CUDA=1 - -# Build with Metal (Apple Silicon) -make GGML_METAL=1 - -# Install Python bindings (optional) -pip install llama-cpp-python -``` - -### Convert model to GGUF - -```bash -# Install requirements -pip install -r requirements.txt - -# Convert HuggingFace model to GGUF (FP16) -python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf - -# Or specify output type -python convert_hf_to_gguf.py ./path/to/model \ - --outfile model-f16.gguf \ - --outtype f16 -``` - -### Quantize model - -```bash -# Basic quantization to Q4_K_M -./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M - -# Quantize with importance matrix (better quality) -./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix -./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M -``` - -### Run inference - -```bash -# CLI inference -./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?" - -# Interactive mode -./llama-cli -m model-q4_k_m.gguf --interactive - -# With GPU offload -./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!" -``` - -## Quantization types - -### K-quant methods (recommended) - -| Type | Bits | Size (7B) | Quality | Use Case | -|------|------|-----------|---------|----------| -| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression | -| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained | -| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance | -| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance | -| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** | -| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused | -| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality | -| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original | -| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality | - -### Legacy methods - -| Type | Description | -|------|-------------| -| Q4_0 | 4-bit, basic | -| Q4_1 | 4-bit with delta | -| Q5_0 | 5-bit, basic | -| Q5_1 | 5-bit with delta | - -**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio. - -## Conversion workflows - -### Workflow 1: HuggingFace to GGUF - -```bash -# 1. Download model -huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b - -# 2. Convert to GGUF (FP16) -python convert_hf_to_gguf.py ./llama-3.1-8b \ - --outfile llama-3.1-8b-f16.gguf \ - --outtype f16 - -# 3. Quantize -./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M - -# 4. Test -./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50 -``` - -### Workflow 2: With importance matrix (better quality) - -```bash -# 1. Convert to GGUF -python convert_hf_to_gguf.py ./model --outfile model-f16.gguf - -# 2. Create calibration text (diverse samples) -cat > calibration.txt << 'EOF' -The quick brown fox jumps over the lazy dog. -Machine learning is a subset of artificial intelligence. -Python is a popular programming language. -# Add more diverse text samples... -EOF - -# 3. Generate importance matrix -./llama-imatrix -m model-f16.gguf \ - -f calibration.txt \ - --chunk 512 \ - -o model.imatrix \ - -ngl 35 # GPU layers if available - -# 4. Quantize with imatrix -./llama-quantize --imatrix model.imatrix \ - model-f16.gguf \ - model-q4_k_m.gguf \ - Q4_K_M -``` - -### Workflow 3: Multiple quantizations - -```bash -#!/bin/bash -MODEL="llama-3.1-8b-f16.gguf" -IMATRIX="llama-3.1-8b.imatrix" - -# Generate imatrix once -./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35 - -# Create multiple quantizations -for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do - OUTPUT="llama-3.1-8b-${QUANT,,}.gguf" - ./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT - echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))" -done -``` - -## Python usage - -### llama-cpp-python - -```python -from llama_cpp import Llama - -# Load model -llm = Llama( - model_path="./model-q4_k_m.gguf", - n_ctx=4096, # Context window - n_gpu_layers=35, # GPU offload (0 for CPU only) - n_threads=8 # CPU threads -) - -# Generate -output = llm( - "What is machine learning?", - max_tokens=256, - temperature=0.7, - stop=["", "\n\n"] -) -print(output["choices"][0]["text"]) -``` - -### Chat completion - -```python -from llama_cpp import Llama - -llm = Llama( - model_path="./model-q4_k_m.gguf", - n_ctx=4096, - n_gpu_layers=35, - chat_format="llama-3" # Or "chatml", "mistral", etc. -) - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is Python?"} -] - -response = llm.create_chat_completion( - messages=messages, - max_tokens=256, - temperature=0.7 -) -print(response["choices"][0]["message"]["content"]) -``` - -### Streaming - -```python -from llama_cpp import Llama - -llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35) - -# Stream tokens -for chunk in llm( - "Explain quantum computing:", - max_tokens=256, - stream=True -): - print(chunk["choices"][0]["text"], end="", flush=True) -``` - -## Server mode - -### Start OpenAI-compatible server - -```bash -# Start server -./llama-server -m model-q4_k_m.gguf \ - --host 0.0.0.0 \ - --port 8080 \ - -ngl 35 \ - -c 4096 - -# Or with Python bindings -python -m llama_cpp.server \ - --model model-q4_k_m.gguf \ - --n_gpu_layers 35 \ - --host 0.0.0.0 \ - --port 8080 -``` - -### Use with OpenAI client - -```python -from openai import OpenAI - -client = OpenAI( - base_url="http://localhost:8080/v1", - api_key="not-needed" -) - -response = client.chat.completions.create( - model="local-model", - messages=[{"role": "user", "content": "Hello!"}], - max_tokens=256 -) -print(response.choices[0].message.content) -``` - -## Hardware optimization - -### Apple Silicon (Metal) - -```bash -# Build with Metal -make clean && make GGML_METAL=1 - -# Run with Metal acceleration -./llama-cli -m model.gguf -ngl 99 -p "Hello" - -# Python with Metal -llm = Llama( - model_path="model.gguf", - n_gpu_layers=99, # Offload all layers - n_threads=1 # Metal handles parallelism -) -``` - -### NVIDIA CUDA - -```bash -# Build with CUDA -make clean && make GGML_CUDA=1 - -# Run with CUDA -./llama-cli -m model.gguf -ngl 35 -p "Hello" - -# Specify GPU -CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35 -``` - -### CPU optimization - -```bash -# Build with AVX2/AVX512 -make clean && make - -# Run with optimal threads -./llama-cli -m model.gguf -t 8 -p "Hello" - -# Python CPU config -llm = Llama( - model_path="model.gguf", - n_gpu_layers=0, # CPU only - n_threads=8, # Match physical cores - n_batch=512 # Batch size for prompt processing -) -``` - -## Integration with tools - -### Ollama - -```bash -# Create Modelfile -cat > Modelfile << 'EOF' -FROM ./model-q4_k_m.gguf -TEMPLATE """{{ .System }} -{{ .Prompt }}""" -PARAMETER temperature 0.7 -PARAMETER num_ctx 4096 -EOF - -# Create Ollama model -ollama create mymodel -f Modelfile - -# Run -ollama run mymodel "Hello!" -``` - -### LM Studio - -1. Place GGUF file in `~/.cache/lm-studio/models/` -2. Open LM Studio and select the model -3. Configure context length and GPU offload -4. Start inference - -### text-generation-webui - -```bash -# Place in models folder -cp model-q4_k_m.gguf text-generation-webui/models/ - -# Start with llama.cpp loader -python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35 -``` - -## Best practices - -1. **Use K-quants**: Q4_K_M offers best quality/size balance -2. **Use imatrix**: Always use importance matrix for Q4 and below -3. **GPU offload**: Offload as many layers as VRAM allows -4. **Context length**: Start with 4096, increase if needed -5. **Thread count**: Match physical CPU cores, not logical -6. **Batch size**: Increase n_batch for faster prompt processing - -## Common issues - -**Model loads slowly:** -```bash -# Use mmap for faster loading -./llama-cli -m model.gguf --mmap -``` - -**Out of memory:** -```bash -# Reduce GPU layers -./llama-cli -m model.gguf -ngl 20 # Reduce from 35 - -# Or use smaller quantization -./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M -``` - -**Poor quality at low bits:** -```bash -# Always use imatrix for Q4 and below -./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix -./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M -``` - -## References - -- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds -- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks - -## Resources - -- **Repository**: https://github.com/ggml-org/llama.cpp -- **Python Bindings**: https://github.com/abetlen/llama-cpp-python -- **Pre-quantized Models**: https://huggingface.co/TheBloke -- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo -- **License**: MIT diff --git a/skills/mlops/gguf/references/advanced-usage.md b/skills/mlops/gguf/references/advanced-usage.md deleted file mode 100644 index de01fda24..000000000 --- a/skills/mlops/gguf/references/advanced-usage.md +++ /dev/null @@ -1,504 +0,0 @@ -# GGUF Advanced Usage Guide - -## Speculative Decoding - -### Draft Model Approach - -```bash -# Use smaller model as draft for faster generation -./llama-speculative \ - -m large-model-q4_k_m.gguf \ - -md draft-model-q4_k_m.gguf \ - -p "Write a story about AI" \ - -n 500 \ - --draft 8 # Draft tokens before verification -``` - -### Self-Speculative Decoding - -```bash -# Use same model with different context for speculation -./llama-cli -m model-q4_k_m.gguf \ - --lookup-cache-static lookup.bin \ - --lookup-cache-dynamic lookup-dynamic.bin \ - -p "Hello world" -``` - -## Batched Inference - -### Process Multiple Prompts - -```python -from llama_cpp import Llama - -llm = Llama( - model_path="model-q4_k_m.gguf", - n_ctx=4096, - n_gpu_layers=35, - n_batch=512 # Larger batch for parallel processing -) - -prompts = [ - "What is Python?", - "Explain machine learning.", - "Describe neural networks." -] - -# Process in batch (each prompt gets separate context) -for prompt in prompts: - output = llm(prompt, max_tokens=100) - print(f"Q: {prompt}") - print(f"A: {output['choices'][0]['text']}\n") -``` - -### Server Batching - -```bash -# Start server with batching -./llama-server -m model-q4_k_m.gguf \ - --host 0.0.0.0 \ - --port 8080 \ - -ngl 35 \ - -c 4096 \ - --parallel 4 # Concurrent requests - --cont-batching # Continuous batching -``` - -## Custom Model Conversion - -### Convert with Vocabulary Modifications - -```python -# custom_convert.py -import sys -sys.path.insert(0, './llama.cpp') - -from convert_hf_to_gguf import main -from gguf import GGUFWriter - -# Custom conversion with modified vocab -def convert_with_custom_vocab(model_path, output_path): - # Load and modify tokenizer - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) - - # Add special tokens if needed - special_tokens = {"additional_special_tokens": ["<|custom|>"]} - tokenizer.add_special_tokens(special_tokens) - tokenizer.save_pretrained(model_path) - - # Then run standard conversion - main([model_path, "--outfile", output_path]) -``` - -### Convert Specific Architecture - -```bash -# For Mistral-style models -python convert_hf_to_gguf.py ./mistral-model \ - --outfile mistral-f16.gguf \ - --outtype f16 - -# For Qwen models -python convert_hf_to_gguf.py ./qwen-model \ - --outfile qwen-f16.gguf \ - --outtype f16 - -# For Phi models -python convert_hf_to_gguf.py ./phi-model \ - --outfile phi-f16.gguf \ - --outtype f16 -``` - -## Advanced Quantization - -### Mixed Quantization - -```bash -# Quantize different layer types differently -./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \ - --allow-requantize \ - --leave-output-tensor -``` - -### Quantization with Token Embeddings - -```bash -# Keep embeddings at higher precision -./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \ - --token-embedding-type f16 -``` - -### IQ Quantization (Importance-aware) - -```bash -# Ultra-low bit quantization with importance -./llama-quantize --imatrix model.imatrix \ - model-f16.gguf model-iq2_xxs.gguf IQ2_XXS - -# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS -``` - -## Memory Optimization - -### Memory Mapping - -```python -from llama_cpp import Llama - -# Use memory mapping for large models -llm = Llama( - model_path="model-q4_k_m.gguf", - use_mmap=True, # Memory map the model - use_mlock=False, # Don't lock in RAM - n_gpu_layers=35 -) -``` - -### Partial GPU Offload - -```python -# Calculate layers to offload based on VRAM -import subprocess - -def get_free_vram_gb(): - result = subprocess.run( - ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'], - capture_output=True, text=True - ) - return int(result.stdout.strip()) / 1024 - -# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4) -free_vram = get_free_vram_gb() -layers_to_offload = int(free_vram / 0.5) - -llm = Llama( - model_path="model-q4_k_m.gguf", - n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers -) -``` - -### KV Cache Optimization - -```python -from llama_cpp import Llama - -# Optimize KV cache for long contexts -llm = Llama( - model_path="model-q4_k_m.gguf", - n_ctx=8192, # Large context - n_gpu_layers=35, - type_k=1, # Q8_0 for K cache (1) - type_v=1, # Q8_0 for V cache (1) - # Or use Q4_0 (2) for more compression -) -``` - -## Context Management - -### Context Shifting - -```python -from llama_cpp import Llama - -llm = Llama( - model_path="model-q4_k_m.gguf", - n_ctx=4096, - n_gpu_layers=35 -) - -# Handle long conversations with context shifting -conversation = [] -max_history = 10 - -def chat(user_message): - conversation.append({"role": "user", "content": user_message}) - - # Keep only recent history - if len(conversation) > max_history * 2: - conversation = conversation[-max_history * 2:] - - response = llm.create_chat_completion( - messages=conversation, - max_tokens=256 - ) - - assistant_message = response["choices"][0]["message"]["content"] - conversation.append({"role": "assistant", "content": assistant_message}) - return assistant_message -``` - -### Save and Load State - -```bash -# Save state to file -./llama-cli -m model.gguf \ - -p "Once upon a time" \ - --save-session session.bin \ - -n 100 - -# Load and continue -./llama-cli -m model.gguf \ - --load-session session.bin \ - -p " and they lived" \ - -n 100 -``` - -## Grammar Constrained Generation - -### JSON Output - -```python -from llama_cpp import Llama, LlamaGrammar - -# Define JSON grammar -json_grammar = LlamaGrammar.from_string(''' -root ::= object -object ::= "{" ws pair ("," ws pair)* "}" ws -pair ::= string ":" ws value -value ::= string | number | object | array | "true" | "false" | "null" -array ::= "[" ws value ("," ws value)* "]" ws -string ::= "\\"" [^"\\\\]* "\\"" -number ::= [0-9]+ -ws ::= [ \\t\\n]* -''') - -llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35) - -output = llm( - "Output a JSON object with name and age:", - grammar=json_grammar, - max_tokens=100 -) -print(output["choices"][0]["text"]) -``` - -### Custom Grammar - -```python -# Grammar for specific format -answer_grammar = LlamaGrammar.from_string(''' -root ::= "Answer: " letter "\\n" "Explanation: " explanation -letter ::= [A-D] -explanation ::= [a-zA-Z0-9 .,!?]+ -''') - -output = llm( - "Q: What is 2+2? A) 3 B) 4 C) 5 D) 6", - grammar=answer_grammar, - max_tokens=100 -) -``` - -## LoRA Integration - -### Load LoRA Adapter - -```bash -# Apply LoRA at runtime -./llama-cli -m base-model-q4_k_m.gguf \ - --lora lora-adapter.gguf \ - --lora-scale 1.0 \ - -p "Hello!" -``` - -### Multiple LoRA Adapters - -```bash -# Stack multiple adapters -./llama-cli -m base-model.gguf \ - --lora adapter1.gguf --lora-scale 0.5 \ - --lora adapter2.gguf --lora-scale 0.5 \ - -p "Hello!" -``` - -### Python LoRA Usage - -```python -from llama_cpp import Llama - -llm = Llama( - model_path="base-model-q4_k_m.gguf", - lora_path="lora-adapter.gguf", - lora_scale=1.0, - n_gpu_layers=35 -) -``` - -## Embedding Generation - -### Extract Embeddings - -```python -from llama_cpp import Llama - -llm = Llama( - model_path="model-q4_k_m.gguf", - embedding=True, # Enable embedding mode - n_gpu_layers=35 -) - -# Get embeddings -embeddings = llm.embed("This is a test sentence.") -print(f"Embedding dimension: {len(embeddings)}") -``` - -### Batch Embeddings - -```python -texts = [ - "Machine learning is fascinating.", - "Deep learning uses neural networks.", - "Python is a programming language." -] - -embeddings = [llm.embed(text) for text in texts] - -# Calculate similarity -import numpy as np - -def cosine_similarity(a, b): - return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) - -sim = cosine_similarity(embeddings[0], embeddings[1]) -print(f"Similarity: {sim:.4f}") -``` - -## Performance Tuning - -### Benchmark Script - -```python -import time -from llama_cpp import Llama - -def benchmark(model_path, prompt, n_tokens=100, n_runs=5): - llm = Llama( - model_path=model_path, - n_gpu_layers=35, - n_ctx=2048, - verbose=False - ) - - # Warmup - llm(prompt, max_tokens=10) - - # Benchmark - times = [] - for _ in range(n_runs): - start = time.time() - output = llm(prompt, max_tokens=n_tokens) - elapsed = time.time() - start - times.append(elapsed) - - avg_time = sum(times) / len(times) - tokens_per_sec = n_tokens / avg_time - - print(f"Model: {model_path}") - print(f"Avg time: {avg_time:.2f}s") - print(f"Tokens/sec: {tokens_per_sec:.1f}") - - return tokens_per_sec - -# Compare quantizations -for quant in ["q4_k_m", "q5_k_m", "q8_0"]: - benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100) -``` - -### Optimal Configuration Finder - -```python -def find_optimal_config(model_path, target_vram_gb=8): - """Find optimal n_gpu_layers and n_batch for target VRAM.""" - from llama_cpp import Llama - import gc - - best_config = None - best_speed = 0 - - for n_gpu_layers in range(0, 50, 5): - for n_batch in [128, 256, 512, 1024]: - try: - gc.collect() - llm = Llama( - model_path=model_path, - n_gpu_layers=n_gpu_layers, - n_batch=n_batch, - n_ctx=2048, - verbose=False - ) - - # Quick benchmark - start = time.time() - llm("Hello", max_tokens=50) - speed = 50 / (time.time() - start) - - if speed > best_speed: - best_speed = speed - best_config = { - "n_gpu_layers": n_gpu_layers, - "n_batch": n_batch, - "speed": speed - } - - del llm - gc.collect() - - except Exception as e: - print(f"OOM at layers={n_gpu_layers}, batch={n_batch}") - break - - return best_config -``` - -## Multi-GPU Setup - -### Distribute Across GPUs - -```bash -# Split model across multiple GPUs -./llama-cli -m large-model.gguf \ - --tensor-split 0.5,0.5 \ - -ngl 60 \ - -p "Hello!" -``` - -### Python Multi-GPU - -```python -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" - -from llama_cpp import Llama - -llm = Llama( - model_path="large-model-q4_k_m.gguf", - n_gpu_layers=60, - tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs -) -``` - -## Custom Builds - -### Build with All Optimizations - -```bash -# Clean build with all CPU optimizations -make clean -LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j - -# With CUDA and cuBLAS -make clean -GGML_CUDA=1 LLAMA_CUBLAS=1 make -j - -# With specific CUDA architecture -GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j -``` - -### CMake Build - -```bash -mkdir build && cd build -cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release -cmake --build . --config Release -j -``` diff --git a/skills/mlops/gguf/references/troubleshooting.md b/skills/mlops/gguf/references/troubleshooting.md deleted file mode 100644 index 3d5c579cb..000000000 --- a/skills/mlops/gguf/references/troubleshooting.md +++ /dev/null @@ -1,442 +0,0 @@ -# GGUF Troubleshooting Guide - -## Installation Issues - -### Build Fails - -**Error**: `make: *** No targets specified and no makefile found` - -**Fix**: -```bash -# Ensure you're in llama.cpp directory -cd llama.cpp -make -``` - -**Error**: `fatal error: cuda_runtime.h: No such file or directory` - -**Fix**: -```bash -# Install CUDA toolkit -# Ubuntu -sudo apt install nvidia-cuda-toolkit - -# Or set CUDA path -export CUDA_PATH=/usr/local/cuda -export PATH=$CUDA_PATH/bin:$PATH -make GGML_CUDA=1 -``` - -### Python Bindings Issues - -**Error**: `ERROR: Failed building wheel for llama-cpp-python` - -**Fix**: -```bash -# Install build dependencies -pip install cmake scikit-build-core - -# For CUDA support -CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir - -# For Metal (macOS) -CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir -``` - -**Error**: `ImportError: libcudart.so.XX: cannot open shared object file` - -**Fix**: -```bash -# Add CUDA libraries to path -export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH - -# Or reinstall with correct CUDA version -pip uninstall llama-cpp-python -CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python -``` - -## Conversion Issues - -### Model Not Supported - -**Error**: `KeyError: 'model.embed_tokens.weight'` - -**Fix**: -```bash -# Check model architecture -python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)" - -# Use appropriate conversion script -# For most models: -python convert_hf_to_gguf.py ./model --outfile model.gguf - -# For older models, check if legacy script needed -``` - -### Vocabulary Mismatch - -**Error**: `RuntimeError: Vocabulary size mismatch` - -**Fix**: -```python -# Ensure tokenizer matches model -from transformers import AutoTokenizer, AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("./model") -model = AutoModelForCausalLM.from_pretrained("./model") - -print(f"Tokenizer vocab size: {len(tokenizer)}") -print(f"Model vocab size: {model.config.vocab_size}") - -# If mismatch, resize embeddings before conversion -model.resize_token_embeddings(len(tokenizer)) -model.save_pretrained("./model-fixed") -``` - -### Out of Memory During Conversion - -**Error**: `torch.cuda.OutOfMemoryError` during conversion - -**Fix**: -```bash -# Use CPU for conversion -CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf - -# Or use low memory mode -python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16 -``` - -## Quantization Issues - -### Wrong Output File Size - -**Problem**: Quantized file is larger than expected - -**Check**: -```bash -# Verify quantization type -./llama-cli -m model.gguf --verbose - -# Expected sizes for 7B model: -# Q4_K_M: ~4.1 GB -# Q5_K_M: ~4.8 GB -# Q8_0: ~7.2 GB -# F16: ~13.5 GB -``` - -### Quantization Crashes - -**Error**: `Segmentation fault` during quantization - -**Fix**: -```bash -# Increase stack size -ulimit -s unlimited - -# Or use less threads -./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M -``` - -### Poor Quality After Quantization - -**Problem**: Model outputs gibberish after quantization - -**Solutions**: - -1. **Use importance matrix**: -```bash -# Generate imatrix with good calibration data -./llama-imatrix -m model-f16.gguf \ - -f wiki_sample.txt \ - --chunk 512 \ - -o model.imatrix - -# Quantize with imatrix -./llama-quantize --imatrix model.imatrix \ - model-f16.gguf model-q4_k_m.gguf Q4_K_M -``` - -2. **Try higher precision**: -```bash -# Use Q5_K_M or Q6_K instead of Q4 -./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M -``` - -3. **Check original model**: -```bash -# Test FP16 version first -./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50 -``` - -## Inference Issues - -### Slow Generation - -**Problem**: Generation is slower than expected - -**Solutions**: - -1. **Enable GPU offload**: -```bash -./llama-cli -m model.gguf -ngl 35 -p "Hello" -``` - -2. **Optimize batch size**: -```python -llm = Llama( - model_path="model.gguf", - n_batch=512, # Increase for faster prompt processing - n_gpu_layers=35 -) -``` - -3. **Use appropriate threads**: -```bash -# Match physical cores, not logical -./llama-cli -m model.gguf -t 8 -p "Hello" -``` - -4. **Enable Flash Attention** (if supported): -```bash -./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello" -``` - -### Out of Memory - -**Error**: `CUDA out of memory` or system freeze - -**Solutions**: - -1. **Reduce GPU layers**: -```python -# Start low and increase -llm = Llama(model_path="model.gguf", n_gpu_layers=10) -``` - -2. **Use smaller quantization**: -```bash -./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M -``` - -3. **Reduce context length**: -```python -llm = Llama( - model_path="model.gguf", - n_ctx=2048, # Reduce from 4096 - n_gpu_layers=35 -) -``` - -4. **Quantize KV cache**: -```python -llm = Llama( - model_path="model.gguf", - type_k=2, # Q4_0 for K cache - type_v=2, # Q4_0 for V cache - n_gpu_layers=35 -) -``` - -### Garbage Output - -**Problem**: Model outputs random characters or nonsense - -**Diagnose**: -```python -# Check model loading -llm = Llama(model_path="model.gguf", verbose=True) - -# Test with simple prompt -output = llm("1+1=", max_tokens=5, temperature=0) -print(output) -``` - -**Solutions**: - -1. **Check model integrity**: -```bash -# Verify GGUF file -./llama-cli -m model.gguf --verbose 2>&1 | head -50 -``` - -2. **Use correct chat format**: -```python -llm = Llama( - model_path="model.gguf", - chat_format="llama-3" # Match your model: chatml, mistral, etc. -) -``` - -3. **Check temperature**: -```python -# Use lower temperature for deterministic output -output = llm("Hello", max_tokens=50, temperature=0.1) -``` - -### Token Issues - -**Error**: `RuntimeError: unknown token` or encoding errors - -**Fix**: -```python -# Ensure UTF-8 encoding -prompt = "Hello, world!".encode('utf-8').decode('utf-8') -output = llm(prompt, max_tokens=50) -``` - -## Server Issues - -### Connection Refused - -**Error**: `Connection refused` when accessing server - -**Fix**: -```bash -# Bind to all interfaces -./llama-server -m model.gguf --host 0.0.0.0 --port 8080 - -# Check if port is in use -lsof -i :8080 -``` - -### Server Crashes Under Load - -**Problem**: Server crashes with multiple concurrent requests - -**Solutions**: - -1. **Limit parallelism**: -```bash -./llama-server -m model.gguf \ - --parallel 2 \ - -c 4096 \ - --cont-batching -``` - -2. **Add request timeout**: -```bash -./llama-server -m model.gguf --timeout 300 -``` - -3. **Monitor memory**: -```bash -watch -n 1 nvidia-smi # For GPU -watch -n 1 free -h # For RAM -``` - -### API Compatibility Issues - -**Problem**: OpenAI client not working with server - -**Fix**: -```python -from openai import OpenAI - -# Use correct base URL format -client = OpenAI( - base_url="http://localhost:8080/v1", # Include /v1 - api_key="not-needed" -) - -# Use correct model name -response = client.chat.completions.create( - model="local", # Or the actual model name - messages=[{"role": "user", "content": "Hello"}] -) -``` - -## Apple Silicon Issues - -### Metal Not Working - -**Problem**: Metal acceleration not enabled - -**Check**: -```bash -# Verify Metal support -./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal -``` - -**Fix**: -```bash -# Rebuild with Metal -make clean -make GGML_METAL=1 - -# Python bindings -CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall -``` - -### Incorrect Memory Usage on M1/M2 - -**Problem**: Model uses too much unified memory - -**Fix**: -```python -# Offload all layers for Metal -llm = Llama( - model_path="model.gguf", - n_gpu_layers=99, # Offload everything - n_threads=1 # Metal handles parallelism -) -``` - -## Debugging - -### Enable Verbose Output - -```bash -# CLI verbose mode -./llama-cli -m model.gguf --verbose -p "Hello" -n 50 - -# Python verbose -llm = Llama(model_path="model.gguf", verbose=True) -``` - -### Check Model Metadata - -```bash -# View GGUF metadata -./llama-cli -m model.gguf --verbose 2>&1 | head -100 -``` - -### Validate GGUF File - -```python -import struct - -def validate_gguf(filepath): - with open(filepath, 'rb') as f: - magic = f.read(4) - if magic != b'GGUF': - print(f"Invalid magic: {magic}") - return False - - version = struct.unpack(' 0.1 (avoid mode collapse) -- **Start with num_generations=4-8** - Scale up if GPU allows - -## 🔗 External Resources - -- [TRL Documentation](https://huggingface.co/docs/trl) -- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948) -- [Open R1 Implementation](https://github.com/huggingface/open-r1) -- [Unsloth (2-3x faster)](https://docs.unsloth.ai/) - -## 📝 Version - -**v1.0.0** - Initial release (January 2025) - -## 👨‍💻 Maintained By - -Orchestra Research -For questions or improvements, see https://orchestra.com - ---- - -**License:** MIT -**Last Updated:** January 2025 diff --git a/skills/mlops/grpo-rl-training/SKILL.md b/skills/mlops/grpo-rl-training/SKILL.md deleted file mode 100644 index 1d7629ab6..000000000 --- a/skills/mlops/grpo-rl-training/SKILL.md +++ /dev/null @@ -1,575 +0,0 @@ ---- -name: grpo-rl-training -description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch] -metadata: - hermes: - tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output] - ---- - -# GRPO/RL Training with TRL - -Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions. - -## When to Use This Skill - -Use GRPO training when you need to: -- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning) -- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking) -- **Improve reasoning capabilities** by rewarding chain-of-thought patterns -- **Align models to domain-specific behaviors** without labeled preference data -- **Optimize for multiple objectives** simultaneously (format + correctness + style) - -**Do NOT use GRPO for:** -- Simple supervised fine-tuning tasks (use SFT instead) -- Tasks without clear reward signals -- When you already have high-quality preference pairs (use DPO/PPO instead) - ---- - -## Core Concepts - -### 1. GRPO Algorithm Fundamentals - -**Key Mechanism:** -- Generates **multiple completions** for each prompt (group size: 4-16) -- Compares completions within each group using reward functions -- Updates policy to favor higher-rewarded responses relative to the group - -**Critical Difference from PPO:** -- No separate reward model needed -- More sample-efficient (learns from within-group comparisons) -- Simpler to implement and debug - -**Mathematical Intuition:** -``` -For each prompt p: - 1. Generate N completions: {c₁, c₂, ..., cₙ} - 2. Compute rewards: {r₁, r₂, ..., rₙ} - 3. Learn to increase probability of high-reward completions - relative to low-reward ones in the same group -``` - -### 2. Reward Function Design Philosophy - -**Golden Rules:** -1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style) -2. **Scale rewards appropriately** - Higher weight = stronger signal -3. **Use incremental rewards** - Partial credit for partial compliance -4. **Test rewards independently** - Debug each reward function in isolation - -**Reward Function Types:** - -| Type | Use Case | Example Weight | -|------|----------|----------------| -| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) | -| **Format** | Strict structure enforcement | 0.5-1.0 | -| **Length** | Encourage verbosity/conciseness | 0.1-0.5 | -| **Style** | Penalize unwanted patterns | -0.5 to 0.5 | - ---- - -## Implementation Workflow - -### Step 1: Dataset Preparation - -**Critical Requirements:** -- Prompts in chat format (list of dicts with 'role' and 'content') -- Include system prompts to set expectations -- For verifiable tasks, include ground truth answers as additional columns - -**Example Structure:** -```python -from datasets import load_dataset, Dataset - -SYSTEM_PROMPT = """ -Respond in the following format: - -[Your step-by-step thinking] - - -[Final answer] - -""" - -def prepare_dataset(raw_data): - """ - Transform raw data into GRPO-compatible format. - - Returns: Dataset with columns: - - 'prompt': List[Dict] with role/content (system + user messages) - - 'answer': str (ground truth, optional but recommended) - """ - return raw_data.map(lambda x: { - 'prompt': [ - {'role': 'system', 'content': SYSTEM_PROMPT}, - {'role': 'user', 'content': x['question']} - ], - 'answer': extract_answer(x['raw_answer']) - }) -``` - -**Pro Tips:** -- Use one-shot or few-shot examples in system prompt for complex formats -- Keep prompts concise (max_prompt_length: 256-512 tokens) -- Validate data quality before training (garbage in = garbage out) - -### Step 2: Reward Function Implementation - -**Template Structure:** -```python -def reward_function_name( - prompts, # List[List[Dict]]: Original prompts - completions, # List[List[Dict]]: Model generations - answer=None, # Optional: Ground truth from dataset - **kwargs # Additional dataset columns -) -> list[float]: - """ - Evaluate completions and return rewards. - - Returns: List of floats (one per completion) - """ - # Extract completion text - responses = [comp[0]['content'] for comp in completions] - - # Compute rewards - rewards = [] - for response in responses: - score = compute_score(response) - rewards.append(score) - - return rewards -``` - -**Example 1: Correctness Reward (Math/Coding)** -```python -def correctness_reward(prompts, completions, answer, **kwargs): - """Reward correct answers with high score.""" - responses = [comp[0]['content'] for comp in completions] - extracted = [extract_final_answer(r) for r in responses] - return [2.0 if ans == gt else 0.0 - for ans, gt in zip(extracted, answer)] -``` - -**Example 2: Format Reward (Structured Output)** -```python -import re - -def format_reward(completions, **kwargs): - """Reward XML-like structured format.""" - pattern = r'.*?\s*.*?' - responses = [comp[0]['content'] for comp in completions] - return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0 - for r in responses] -``` - -**Example 3: Incremental Format Reward (Partial Credit)** -```python -def incremental_format_reward(completions, **kwargs): - """Award partial credit for format compliance.""" - responses = [comp[0]['content'] for comp in completions] - rewards = [] - - for r in responses: - score = 0.0 - if '' in r: - score += 0.25 - if '' in r: - score += 0.25 - if '' in r: - score += 0.25 - if '' in r: - score += 0.25 - # Penalize extra text after closing tag - if r.count('') == 1: - extra_text = r.split('')[-1].strip() - score -= len(extra_text) * 0.001 - rewards.append(score) - - return rewards -``` - -**Critical Insight:** -Combine 3-5 reward functions for robust training. Order matters less than diversity of signals. - -### Step 3: Training Configuration - -**Memory-Optimized Config (Small GPU)** -```python -from trl import GRPOConfig - -training_args = GRPOConfig( - output_dir="outputs/grpo-model", - - # Learning rate - learning_rate=5e-6, # Lower = more stable - adam_beta1=0.9, - adam_beta2=0.99, - weight_decay=0.1, - warmup_ratio=0.1, - lr_scheduler_type='cosine', - - # Batch settings - per_device_train_batch_size=1, - gradient_accumulation_steps=4, # Effective batch = 4 - - # GRPO-specific - num_generations=8, # Group size: 8-16 recommended - max_prompt_length=256, - max_completion_length=512, - - # Training duration - num_train_epochs=1, - max_steps=None, # Or set fixed steps (e.g., 500) - - # Optimization - bf16=True, # Faster on A100/H100 - optim="adamw_8bit", # Memory-efficient optimizer - max_grad_norm=0.1, - - # Logging - logging_steps=1, - save_steps=100, - report_to="wandb", # Or "none" for no logging -) -``` - -**High-Performance Config (Large GPU)** -```python -training_args = GRPOConfig( - output_dir="outputs/grpo-model", - learning_rate=1e-5, - per_device_train_batch_size=4, - gradient_accumulation_steps=2, - num_generations=16, # Larger groups = better signal - max_prompt_length=512, - max_completion_length=1024, - num_train_epochs=1, - bf16=True, - use_vllm=True, # Fast generation with vLLM - logging_steps=10, -) -``` - -**Critical Hyperparameters:** - -| Parameter | Impact | Tuning Advice | -|-----------|--------|---------------| -| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows | -| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) | -| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) | -| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited | - -### Step 4: Model Setup and Training - -**Standard Setup (Transformers)** -```python -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import LoraConfig -from trl import GRPOTrainer - -# Load model -model_name = "Qwen/Qwen2.5-1.5B-Instruct" -model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", # 2-3x faster - device_map="auto" -) - -tokenizer = AutoTokenizer.from_pretrained(model_name) -tokenizer.pad_token = tokenizer.eos_token - -# Optional: LoRA for parameter-efficient training -peft_config = LoraConfig( - r=16, # Rank (higher = more capacity) - lora_alpha=32, # Scaling factor (typically 2*r) - target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj" - ], - task_type="CAUSAL_LM", - lora_dropout=0.05, -) - -# Initialize trainer -trainer = GRPOTrainer( - model=model, - processing_class=tokenizer, - reward_funcs=[ - incremental_format_reward, - format_reward, - correctness_reward, - ], - args=training_args, - train_dataset=dataset, - peft_config=peft_config, # Remove for full fine-tuning -) - -# Train -trainer.train() - -# Save -trainer.save_model("final_model") -``` - -**Unsloth Setup (2-3x Faster)** -```python -from unsloth import FastLanguageModel - -model, tokenizer = FastLanguageModel.from_pretrained( - model_name="google/gemma-3-1b-it", - max_seq_length=1024, - load_in_4bit=True, - fast_inference=True, - max_lora_rank=32, -) - -model = FastLanguageModel.get_peft_model( - model, - r=32, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj"], - lora_alpha=32, - use_gradient_checkpointing="unsloth", -) - -# Rest is identical to standard setup -trainer = GRPOTrainer(model=model, ...) -trainer.train() -``` - ---- - -## Critical Training Insights - -### 1. Loss Behavior (EXPECTED PATTERN) -- **Loss starts near 0 and INCREASES during training** -- This is CORRECT - loss measures KL divergence from initial policy -- Model is learning (diverging from original behavior to optimize rewards) -- Monitor reward metrics instead of loss for progress - -### 2. Reward Tracking -Key metrics to watch: -- `reward`: Average across all completions -- `reward_std`: Diversity within groups (should remain > 0) -- `kl`: KL divergence from reference (should grow moderately) - -**Healthy Training Pattern:** -``` -Step Reward Reward_Std KL -100 0.5 0.3 0.02 -200 0.8 0.25 0.05 -300 1.2 0.2 0.08 ← Good progression -400 1.5 0.15 0.12 -``` - -**Warning Signs:** -- Reward std → 0 (model collapsing to single response) -- KL exploding (> 0.5) (diverging too much, reduce LR) -- Reward stuck (reward functions too harsh or model capacity issue) - -### 3. Common Pitfalls and Solutions - -| Problem | Symptom | Solution | -|---------|---------|----------| -| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty | -| **No learning** | Flat rewards | Check reward function logic, increase LR | -| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing | -| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length | -| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards | - ---- - -## Advanced Patterns - -### 1. Multi-Stage Training -For complex tasks, train in stages: - -```python -# Stage 1: Format compliance (epochs=1) -trainer_stage1 = GRPOTrainer( - model=model, - reward_funcs=[incremental_format_reward, format_reward], - ... -) -trainer_stage1.train() - -# Stage 2: Correctness (epochs=1) -trainer_stage2 = GRPOTrainer( - model=model, - reward_funcs=[format_reward, correctness_reward], - ... -) -trainer_stage2.train() -``` - -### 2. Adaptive Reward Scaling -```python -class AdaptiveReward: - def __init__(self, base_reward_func, initial_weight=1.0): - self.func = base_reward_func - self.weight = initial_weight - - def __call__(self, *args, **kwargs): - rewards = self.func(*args, **kwargs) - return [r * self.weight for r in rewards] - - def adjust_weight(self, success_rate): - """Increase weight if model struggling, decrease if succeeding.""" - if success_rate < 0.3: - self.weight *= 1.2 - elif success_rate > 0.8: - self.weight *= 0.9 -``` - -### 3. Custom Dataset Integration -```python -def load_custom_knowledge_base(csv_path): - """Example: School communication platform docs.""" - import pandas as pd - df = pd.read_csv(csv_path) - - dataset = Dataset.from_pandas(df).map(lambda x: { - 'prompt': [ - {'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT}, - {'role': 'user', 'content': x['question']} - ], - 'answer': x['expert_answer'] - }) - return dataset -``` - ---- - -## Deployment and Inference - -### Save and Merge LoRA -```python -# Merge LoRA adapters into base model -if hasattr(trainer.model, 'merge_and_unload'): - merged_model = trainer.model.merge_and_unload() - merged_model.save_pretrained("production_model") - tokenizer.save_pretrained("production_model") -``` - -### Inference Example -```python -from transformers import pipeline - -generator = pipeline( - "text-generation", - model="production_model", - tokenizer=tokenizer -) - -result = generator( - [ - {'role': 'system', 'content': SYSTEM_PROMPT}, - {'role': 'user', 'content': "What is 15 + 27?"} - ], - max_new_tokens=256, - do_sample=True, - temperature=0.7, - top_p=0.9 -) -print(result[0]['generated_text']) -``` - ---- - -## Best Practices Checklist - -**Before Training:** -- [ ] Validate dataset format (prompts as List[Dict]) -- [ ] Test reward functions on sample data -- [ ] Calculate expected max_prompt_length from data -- [ ] Choose appropriate num_generations based on GPU memory -- [ ] Set up logging (wandb recommended) - -**During Training:** -- [ ] Monitor reward progression (should increase) -- [ ] Check reward_std (should stay > 0.1) -- [ ] Watch for OOM errors (reduce batch size if needed) -- [ ] Sample generations every 50-100 steps -- [ ] Validate format compliance on holdout set - -**After Training:** -- [ ] Merge LoRA weights if using PEFT -- [ ] Test on diverse prompts -- [ ] Compare to baseline model -- [ ] Document reward weights and hyperparameters -- [ ] Save reproducibility config - ---- - -## Troubleshooting Guide - -### Debugging Workflow -1. **Isolate reward functions** - Test each independently -2. **Check data distribution** - Ensure diversity in prompts -3. **Reduce complexity** - Start with single reward, add gradually -4. **Monitor generations** - Print samples every N steps -5. **Validate extraction logic** - Ensure answer parsing works - -### Quick Fixes -```python -# Debug reward function -def debug_reward(completions, **kwargs): - responses = [comp[0]['content'] for comp in completions] - for i, r in enumerate(responses[:2]): # Print first 2 - print(f"Response {i}: {r[:200]}...") - return [1.0] * len(responses) # Dummy rewards - -# Test without training -trainer = GRPOTrainer(..., reward_funcs=[debug_reward]) -trainer.generate_completions(dataset[:1]) # Generate without updating -``` - ---- - -## References and Resources - -**Official Documentation:** -- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer -- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948 -- Unsloth Docs: https://docs.unsloth.ai/ - -**Example Repositories:** -- Open R1 Implementation: https://github.com/huggingface/open-r1 -- TRL Examples: https://github.com/huggingface/trl/tree/main/examples - -**Recommended Reading:** -- Progressive Disclosure Pattern for agent instructions -- Reward shaping in RL (Ng et al.) -- LoRA paper (Hu et al., 2021) - ---- - -## Usage Instructions for Agents - -When this skill is loaded: - -1. **Read this entire file** before implementing GRPO training -2. **Start with the simplest reward function** (e.g., length-based) to validate setup -3. **Use the templates** in `templates/` directory as starting points -4. **Reference examples** in `examples/` for task-specific implementations -5. **Follow the workflow** sequentially (don't skip steps) -6. **Debug incrementally** - add one reward function at a time - -**Critical Reminders:** -- Always use multiple reward functions (3-5 is optimal) -- Monitor reward metrics, not loss -- Test reward functions before training -- Start small (num_generations=4), scale up gradually -- Save checkpoints frequently (every 100 steps) - -This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO. - - - diff --git a/skills/mlops/grpo-rl-training/templates/basic_grpo_training.py b/skills/mlops/grpo-rl-training/templates/basic_grpo_training.py deleted file mode 100644 index 228a93e7c..000000000 --- a/skills/mlops/grpo-rl-training/templates/basic_grpo_training.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Basic GRPO Training Template -============================= - -A minimal, production-ready template for GRPO training with TRL. -Adapt this for your specific task by modifying: -1. Dataset loading (get_dataset function) -2. Reward functions (reward_*_func) -3. System prompt (SYSTEM_PROMPT) -4. Hyperparameters (GRPOConfig) -""" - -import torch -import re -from datasets import load_dataset, Dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import LoraConfig -from trl import GRPOTrainer, GRPOConfig - -# ==================== CONFIGURATION ==================== - -MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" -OUTPUT_DIR = "outputs/grpo-model" -MAX_PROMPT_LENGTH = 256 -MAX_COMPLETION_LENGTH = 512 - -SYSTEM_PROMPT = """ -Respond in the following format: - -[Your step-by-step thinking] - - -[Final answer] - -""" - -# ==================== DATASET ==================== - -def get_dataset(split="train"): - """ - Load and prepare your dataset. - - Returns: Dataset with columns: - - 'prompt': List[Dict] with role/content - - 'answer': str (ground truth, optional) - """ - # Example: GSM8K math dataset - data = load_dataset('openai/gsm8k', 'main')[split] - - def process_example(x): - # Extract ground truth answer - answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None - - return { - 'prompt': [ - {'role': 'system', 'content': SYSTEM_PROMPT}, - {'role': 'user', 'content': x['question']} - ], - 'answer': answer - } - - return data.map(process_example) - -# ==================== HELPER FUNCTIONS ==================== - -def extract_xml_tag(text: str, tag: str) -> str: - """Extract content between XML tags.""" - pattern = f'<{tag}>(.*?)' - match = re.search(pattern, text, re.DOTALL) - return match.group(1).strip() if match else "" - -def extract_answer(text: str) -> str: - """Extract the final answer from structured output.""" - return extract_xml_tag(text, 'answer') - -# ==================== REWARD FUNCTIONS ==================== - -def correctness_reward_func(prompts, completions, answer, **kwargs): - """ - Reward correct answers. - Weight: 2.0 (highest priority) - """ - responses = [comp[0]['content'] for comp in completions] - extracted = [extract_answer(r) for r in responses] - return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)] - -def format_reward_func(completions, **kwargs): - """ - Reward proper XML format. - Weight: 0.5 - """ - pattern = r'.*?\s*.*?' - responses = [comp[0]['content'] for comp in completions] - return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses] - -def incremental_format_reward_func(completions, **kwargs): - """ - Incremental reward for partial format compliance. - Weight: up to 0.5 - """ - responses = [comp[0]['content'] for comp in completions] - rewards = [] - - for r in responses: - score = 0.0 - if '' in r: - score += 0.125 - if '' in r: - score += 0.125 - if '' in r: - score += 0.125 - if '' in r: - score += 0.125 - - # Penalize extra content after closing tag - if '' in r: - extra = r.split('')[-1].strip() - score -= len(extra) * 0.001 - - rewards.append(score) - - return rewards - -# ==================== MODEL SETUP ==================== - -def setup_model_and_tokenizer(): - """Load model and tokenizer with optimizations.""" - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - device_map="auto" - ) - - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - tokenizer.pad_token = tokenizer.eos_token - - return model, tokenizer - -def get_peft_config(): - """LoRA configuration for parameter-efficient training.""" - return LoraConfig( - r=16, - lora_alpha=32, - target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj" - ], - task_type="CAUSAL_LM", - lora_dropout=0.05, - ) - -# ==================== TRAINING ==================== - -def main(): - """Main training function.""" - - # Load data - print("Loading dataset...") - dataset = get_dataset() - print(f"Dataset size: {len(dataset)}") - - # Setup model - print("Loading model...") - model, tokenizer = setup_model_and_tokenizer() - - # Training configuration - training_args = GRPOConfig( - output_dir=OUTPUT_DIR, - run_name="grpo-training", - - # Learning rate - learning_rate=5e-6, - adam_beta1=0.9, - adam_beta2=0.99, - weight_decay=0.1, - warmup_ratio=0.1, - lr_scheduler_type='cosine', - - # Batch settings - per_device_train_batch_size=1, - gradient_accumulation_steps=4, - - # GRPO specific - num_generations=8, - max_prompt_length=MAX_PROMPT_LENGTH, - max_completion_length=MAX_COMPLETION_LENGTH, - - # Training duration - num_train_epochs=1, - - # Optimization - bf16=True, - optim="adamw_8bit", - max_grad_norm=0.1, - - # Logging - logging_steps=1, - save_steps=100, - report_to="wandb", # Change to "none" to disable logging - ) - - # Initialize trainer - trainer = GRPOTrainer( - model=model, - processing_class=tokenizer, - reward_funcs=[ - incremental_format_reward_func, - format_reward_func, - correctness_reward_func, - ], - args=training_args, - train_dataset=dataset, - peft_config=get_peft_config(), - ) - - # Train - print("Starting training...") - trainer.train() - - # Save final model - print(f"Saving model to {OUTPUT_DIR}/final") - trainer.save_model(f"{OUTPUT_DIR}/final") - - print("Training complete!") - -if __name__ == "__main__": - main() diff --git a/skills/mlops/guidance/SKILL.md b/skills/mlops/guidance/SKILL.md deleted file mode 100644 index 12f5139ff..000000000 --- a/skills/mlops/guidance/SKILL.md +++ /dev/null @@ -1,575 +0,0 @@ ---- -name: guidance -description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [guidance, transformers] -metadata: - hermes: - tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows] - ---- - -# Guidance: Constrained LLM Generation - -## When to Use This Skill - -Use Guidance when you need to: -- **Control LLM output syntax** with regex or grammars -- **Guarantee valid JSON/XML/code** generation -- **Reduce latency** vs traditional prompting approaches -- **Enforce structured formats** (dates, emails, IDs, etc.) -- **Build multi-step workflows** with Pythonic control flow -- **Prevent invalid outputs** through grammatical constraints - -**GitHub Stars**: 18,000+ | **From**: Microsoft Research - -## Installation - -```bash -# Base installation -pip install guidance - -# With specific backends -pip install guidance[transformers] # Hugging Face models -pip install guidance[llama_cpp] # llama.cpp models -``` - -## Quick Start - -### Basic Example: Structured Generation - -```python -from guidance import models, gen - -# Load model (supports OpenAI, Transformers, llama.cpp) -lm = models.OpenAI("gpt-4") - -# Generate with constraints -result = lm + "The capital of France is " + gen("capital", max_tokens=5) - -print(result["capital"]) # "Paris" -``` - -### With Anthropic Claude - -```python -from guidance import models, gen, system, user, assistant - -# Configure Claude -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Use context managers for chat format -with system(): - lm += "You are a helpful assistant." - -with user(): - lm += "What is the capital of France?" - -with assistant(): - lm += gen(max_tokens=20) -``` - -## Core Concepts - -### 1. Context Managers - -Guidance uses Pythonic context managers for chat-style interactions. - -```python -from guidance import system, user, assistant, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# System message -with system(): - lm += "You are a JSON generation expert." - -# User message -with user(): - lm += "Generate a person object with name and age." - -# Assistant response -with assistant(): - lm += gen("response", max_tokens=100) - -print(lm["response"]) -``` - -**Benefits:** -- Natural chat flow -- Clear role separation -- Easy to read and maintain - -### 2. Constrained Generation - -Guidance ensures outputs match specified patterns using regex or grammars. - -#### Regex Constraints - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Constrain to valid email format -lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") - -# Constrain to date format (YYYY-MM-DD) -lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") - -# Constrain to phone number -lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") - -print(lm["email"]) # Guaranteed valid email -print(lm["date"]) # Guaranteed YYYY-MM-DD format -``` - -**How it works:** -- Regex converted to grammar at token level -- Invalid tokens filtered during generation -- Model can only produce matching outputs - -#### Selection Constraints - -```python -from guidance import models, gen, select - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Constrain to specific choices -lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment") - -# Multiple-choice selection -lm += "Best answer: " + select( - ["A) Paris", "B) London", "C) Berlin", "D) Madrid"], - name="answer" -) - -print(lm["sentiment"]) # One of: positive, negative, neutral -print(lm["answer"]) # One of: A, B, C, or D -``` - -### 3. Token Healing - -Guidance automatically "heals" token boundaries between prompt and generation. - -**Problem:** Tokenization creates unnatural boundaries. - -```python -# Without token healing -prompt = "The capital of France is " -# Last token: " is " -# First generated token might be " Par" (with leading space) -# Result: "The capital of France is Paris" (double space!) -``` - -**Solution:** Guidance backs up one token and regenerates. - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Token healing enabled by default -lm += "The capital of France is " + gen("capital", max_tokens=5) -# Result: "The capital of France is Paris" (correct spacing) -``` - -**Benefits:** -- Natural text boundaries -- No awkward spacing issues -- Better model performance (sees natural token sequences) - -### 4. Grammar-Based Generation - -Define complex structures using context-free grammars. - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# JSON grammar (simplified) -json_grammar = """ -{ - "name": , - "age": , - "email": -} -""" - -# Generate valid JSON -lm += gen("person", grammar=json_grammar) - -print(lm["person"]) # Guaranteed valid JSON structure -``` - -**Use cases:** -- Complex structured outputs -- Nested data structures -- Programming language syntax -- Domain-specific languages - -### 5. Guidance Functions - -Create reusable generation patterns with the `@guidance` decorator. - -```python -from guidance import guidance, gen, models - -@guidance -def generate_person(lm): - """Generate a person with name and age.""" - lm += "Name: " + gen("name", max_tokens=20, stop="\n") - lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3) - return lm - -# Use the function -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_person(lm) - -print(lm["name"]) -print(lm["age"]) -``` - -**Stateful Functions:** - -```python -@guidance(stateless=False) -def react_agent(lm, question, tools, max_rounds=5): - """ReAct agent with tool use.""" - lm += f"Question: {question}\n\n" - - for i in range(max_rounds): - # Thought - lm += f"Thought {i+1}: " + gen("thought", stop="\n") - - # Action - lm += "\nAction: " + select(list(tools.keys()), name="action") - - # Execute tool - tool_result = tools[lm["action"]]() - lm += f"\nObservation: {tool_result}\n\n" - - # Check if done - lm += "Done? " + select(["Yes", "No"], name="done") - if lm["done"] == "Yes": - break - - # Final answer - lm += "\nFinal Answer: " + gen("answer", max_tokens=100) - return lm -``` - -## Backend Configuration - -### Anthropic Claude - -```python -from guidance import models - -lm = models.Anthropic( - model="claude-sonnet-4-5-20250929", - api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var -) -``` - -### OpenAI - -```python -lm = models.OpenAI( - model="gpt-4o-mini", - api_key="your-api-key" # Or set OPENAI_API_KEY env var -) -``` - -### Local Models (Transformers) - -```python -from guidance.models import Transformers - -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cuda" # Or "cpu" -) -``` - -### Local Models (llama.cpp) - -```python -from guidance.models import LlamaCpp - -lm = LlamaCpp( - model_path="/path/to/model.gguf", - n_ctx=4096, - n_gpu_layers=35 -) -``` - -## Common Patterns - -### Pattern 1: JSON Generation - -```python -from guidance import models, gen, system, user, assistant - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -with system(): - lm += "You generate valid JSON." - -with user(): - lm += "Generate a user profile with name, age, and email." - -with assistant(): - lm += """{ - "name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """, - "age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """, - "email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """ -}""" - -print(lm) # Valid JSON guaranteed -``` - -### Pattern 2: Classification - -```python -from guidance import models, gen, select - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -text = "This product is amazing! I love it." - -lm += f"Text: {text}\n" -lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment") -lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%" - -print(f"Sentiment: {lm['sentiment']}") -print(f"Confidence: {lm['confidence']}%") -``` - -### Pattern 3: Multi-Step Reasoning - -```python -from guidance import models, gen, guidance - -@guidance -def chain_of_thought(lm, question): - """Generate answer with step-by-step reasoning.""" - lm += f"Question: {question}\n\n" - - # Generate multiple reasoning steps - for i in range(3): - lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n" - - # Final answer - lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50) - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = chain_of_thought(lm, "What is 15% of 200?") - -print(lm["answer"]) -``` - -### Pattern 4: ReAct Agent - -```python -from guidance import models, gen, select, guidance - -@guidance(stateless=False) -def react_agent(lm, question): - """ReAct agent with tool use.""" - tools = { - "calculator": lambda expr: eval(expr), - "search": lambda query: f"Search results for: {query}", - } - - lm += f"Question: {question}\n\n" - - for round in range(5): - # Thought - lm += f"Thought: " + gen("thought", stop="\n") + "\n" - - # Action selection - lm += "Action: " + select(["calculator", "search", "answer"], name="action") - - if lm["action"] == "answer": - lm += "\nFinal Answer: " + gen("answer", max_tokens=100) - break - - # Action input - lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n" - - # Execute tool - if lm["action"] in tools: - result = tools[lm["action"]](lm["action_input"]) - lm += f"Observation: {result}\n\n" - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = react_agent(lm, "What is 25 * 4 + 10?") -print(lm["answer"]) -``` - -### Pattern 5: Data Extraction - -```python -from guidance import models, gen, guidance - -@guidance -def extract_entities(lm, text): - """Extract structured entities from text.""" - lm += f"Text: {text}\n\n" - - # Extract person - lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n" - - # Extract organization - lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n" - - # Extract date - lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n" - - # Extract location - lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n" - - return lm - -text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino." - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = extract_entities(lm, text) - -print(f"Person: {lm['person']}") -print(f"Organization: {lm['organization']}") -print(f"Date: {lm['date']}") -print(f"Location: {lm['location']}") -``` - -## Best Practices - -### 1. Use Regex for Format Validation - -```python -# ✅ Good: Regex ensures valid format -lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") - -# ❌ Bad: Free generation may produce invalid emails -lm += "Email: " + gen("email", max_tokens=50) -``` - -### 2. Use select() for Fixed Categories - -```python -# ✅ Good: Guaranteed valid category -lm += "Status: " + select(["pending", "approved", "rejected"], name="status") - -# ❌ Bad: May generate typos or invalid values -lm += "Status: " + gen("status", max_tokens=20) -``` - -### 3. Leverage Token Healing - -```python -# Token healing is enabled by default -# No special action needed - just concatenate naturally -lm += "The capital is " + gen("capital") # Automatic healing -``` - -### 4. Use stop Sequences - -```python -# ✅ Good: Stop at newline for single-line outputs -lm += "Name: " + gen("name", stop="\n") - -# ❌ Bad: May generate multiple lines -lm += "Name: " + gen("name", max_tokens=50) -``` - -### 5. Create Reusable Functions - -```python -# ✅ Good: Reusable pattern -@guidance -def generate_person(lm): - lm += "Name: " + gen("name", stop="\n") - lm += "\nAge: " + gen("age", regex=r"[0-9]+") - return lm - -# Use multiple times -lm = generate_person(lm) -lm += "\n\n" -lm = generate_person(lm) -``` - -### 6. Balance Constraints - -```python -# ✅ Good: Reasonable constraints -lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30) - -# ❌ Too strict: May fail or be very slow -lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10) -``` - -## Comparison to Alternatives - -| Feature | Guidance | Instructor | Outlines | LMQL | -|---------|----------|------------|----------|------| -| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes | -| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG | -| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No | -| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No | -| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes | -| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes | -| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like | -| Learning Curve | Low | Low | Medium | High | - -**When to choose Guidance:** -- Need regex/grammar constraints -- Want token healing -- Building complex workflows with control flow -- Using local models (Transformers, llama.cpp) -- Prefer Pythonic syntax - -**When to choose alternatives:** -- Instructor: Need Pydantic validation with automatic retrying -- Outlines: Need JSON schema validation -- LMQL: Prefer declarative query syntax - -## Performance Characteristics - -**Latency Reduction:** -- 30-50% faster than traditional prompting for constrained outputs -- Token healing reduces unnecessary regeneration -- Grammar constraints prevent invalid token generation - -**Memory Usage:** -- Minimal overhead vs unconstrained generation -- Grammar compilation cached after first use -- Efficient token filtering at inference time - -**Token Efficiency:** -- Prevents wasted tokens on invalid outputs -- No need for retry loops -- Direct path to valid outputs - -## Resources - -- **Documentation**: https://guidance.readthedocs.io -- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars) -- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks -- **Discord**: Community support available - -## See Also - -- `references/constraints.md` - Comprehensive regex and grammar patterns -- `references/backends.md` - Backend-specific configuration -- `references/examples.md` - Production-ready examples - - diff --git a/skills/mlops/guidance/references/backends.md b/skills/mlops/guidance/references/backends.md deleted file mode 100644 index e1e9c5e44..000000000 --- a/skills/mlops/guidance/references/backends.md +++ /dev/null @@ -1,554 +0,0 @@ -# Backend Configuration Guide - -Complete guide to configuring Guidance with different LLM backends. - -## Table of Contents -- API-Based Models (Anthropic, OpenAI) -- Local Models (Transformers, llama.cpp) -- Backend Comparison -- Performance Tuning -- Advanced Configuration - -## API-Based Models - -### Anthropic Claude - -#### Basic Setup - -```python -from guidance import models - -# Using environment variable -lm = models.Anthropic("claude-sonnet-4-5-20250929") -# Reads ANTHROPIC_API_KEY from environment - -# Explicit API key -lm = models.Anthropic( - model="claude-sonnet-4-5-20250929", - api_key="your-api-key-here" -) -``` - -#### Available Models - -```python -# Claude 3.5 Sonnet (Latest, recommended) -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Claude 3.7 Sonnet (Fast, cost-effective) -lm = models.Anthropic("claude-sonnet-3.7-20250219") - -# Claude 3 Opus (Most capable) -lm = models.Anthropic("claude-3-opus-20240229") - -# Claude 3.5 Haiku (Fastest, cheapest) -lm = models.Anthropic("claude-3-5-haiku-20241022") -``` - -#### Configuration Options - -```python -lm = models.Anthropic( - model="claude-sonnet-4-5-20250929", - api_key="your-api-key", - max_tokens=4096, # Max tokens to generate - temperature=0.7, # Sampling temperature (0-1) - top_p=0.9, # Nucleus sampling - timeout=30, # Request timeout (seconds) - max_retries=3 # Retry failed requests -) -``` - -#### With Context Managers - -```python -from guidance import models, system, user, assistant, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -with system(): - lm += "You are a helpful assistant." - -with user(): - lm += "What is the capital of France?" - -with assistant(): - lm += gen(max_tokens=50) - -print(lm) -``` - -### OpenAI - -#### Basic Setup - -```python -from guidance import models - -# Using environment variable -lm = models.OpenAI("gpt-4o") -# Reads OPENAI_API_KEY from environment - -# Explicit API key -lm = models.OpenAI( - model="gpt-4o", - api_key="your-api-key-here" -) -``` - -#### Available Models - -```python -# GPT-4o (Latest, multimodal) -lm = models.OpenAI("gpt-4o") - -# GPT-4o Mini (Fast, cost-effective) -lm = models.OpenAI("gpt-4o-mini") - -# GPT-4 Turbo -lm = models.OpenAI("gpt-4-turbo") - -# GPT-3.5 Turbo (Cheapest) -lm = models.OpenAI("gpt-3.5-turbo") -``` - -#### Configuration Options - -```python -lm = models.OpenAI( - model="gpt-4o-mini", - api_key="your-api-key", - max_tokens=2048, - temperature=0.7, - top_p=1.0, - frequency_penalty=0.0, - presence_penalty=0.0, - timeout=30 -) -``` - -#### Chat Format - -```python -from guidance import models, gen - -lm = models.OpenAI("gpt-4o-mini") - -# OpenAI uses chat format -lm += [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"} -] - -# Generate response -lm += gen(max_tokens=50) -``` - -### Azure OpenAI - -```python -from guidance import models - -lm = models.AzureOpenAI( - model="gpt-4o", - azure_endpoint="https://your-resource.openai.azure.com/", - api_key="your-azure-api-key", - api_version="2024-02-15-preview", - deployment_name="your-deployment-name" -) -``` - -## Local Models - -### Transformers (Hugging Face) - -#### Basic Setup - -```python -from guidance.models import Transformers - -# Load model from Hugging Face -lm = Transformers("microsoft/Phi-4-mini-instruct") -``` - -#### GPU Configuration - -```python -# Use GPU -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cuda" -) - -# Use specific GPU -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cuda:0" # GPU 0 -) - -# Use CPU -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cpu" -) -``` - -#### Advanced Configuration - -```python -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cuda", - torch_dtype="float16", # Use FP16 (faster, less memory) - load_in_8bit=True, # 8-bit quantization - max_memory={0: "20GB"}, # GPU memory limit - offload_folder="./offload" # Offload to disk if needed -) -``` - -#### Popular Models - -```python -# Phi-4 (Microsoft) -lm = Transformers("microsoft/Phi-4-mini-instruct") -lm = Transformers("microsoft/Phi-3-medium-4k-instruct") - -# Llama 3 (Meta) -lm = Transformers("meta-llama/Llama-3.1-8B-Instruct") -lm = Transformers("meta-llama/Llama-3.1-70B-Instruct") - -# Mistral (Mistral AI) -lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3") -lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1") - -# Qwen (Alibaba) -lm = Transformers("Qwen/Qwen2.5-7B-Instruct") - -# Gemma (Google) -lm = Transformers("google/gemma-2-9b-it") -``` - -#### Generation Configuration - -```python -lm = Transformers( - "microsoft/Phi-4-mini-instruct", - device="cuda" -) - -# Configure generation -from guidance import gen - -result = lm + gen( - max_tokens=100, - temperature=0.7, - top_p=0.9, - top_k=50, - repetition_penalty=1.1 -) -``` - -### llama.cpp - -#### Basic Setup - -```python -from guidance.models import LlamaCpp - -# Load GGUF model -lm = LlamaCpp( - model_path="/path/to/model.gguf", - n_ctx=4096 # Context window -) -``` - -#### GPU Configuration - -```python -# Use GPU acceleration -lm = LlamaCpp( - model_path="/path/to/model.gguf", - n_ctx=4096, - n_gpu_layers=35, # Offload 35 layers to GPU - n_threads=8 # CPU threads for remaining layers -) - -# Full GPU offload -lm = LlamaCpp( - model_path="/path/to/model.gguf", - n_ctx=4096, - n_gpu_layers=-1 # Offload all layers -) -``` - -#### Advanced Configuration - -```python -lm = LlamaCpp( - model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf", - n_ctx=8192, # Context window (tokens) - n_gpu_layers=35, # GPU layers - n_threads=8, # CPU threads - n_batch=512, # Batch size for prompt processing - use_mmap=True, # Memory-map the model file - use_mlock=False, # Lock model in RAM - seed=42, # Random seed - verbose=False # Suppress verbose output -) -``` - -#### Quantized Models - -```python -# Q4_K_M (4-bit, recommended for most cases) -lm = LlamaCpp("/path/to/model.Q4_K_M.gguf") - -# Q5_K_M (5-bit, better quality) -lm = LlamaCpp("/path/to/model.Q5_K_M.gguf") - -# Q8_0 (8-bit, high quality) -lm = LlamaCpp("/path/to/model.Q8_0.gguf") - -# F16 (16-bit float, highest quality) -lm = LlamaCpp("/path/to/model.F16.gguf") -``` - -#### Popular GGUF Models - -```python -# Llama 3.1 -lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf") - -# Mistral -lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf") - -# Phi-4 -lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf") -``` - -## Backend Comparison - -### Feature Matrix - -| Feature | Anthropic | OpenAI | Transformers | llama.cpp | -|---------|-----------|--------|--------------|-----------| -| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full | -| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes | -| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes | -| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes | -| Quantization | N/A | N/A | ✅ Yes | ✅ Yes | -| Cost | $$$ | $$$ | Free | Free | -| Latency | Low | Low | Medium | Low | -| Setup Difficulty | Easy | Easy | Medium | Medium | - -### Performance Characteristics - -**Anthropic Claude:** -- **Latency**: 200-500ms (API call) -- **Throughput**: Limited by API rate limits -- **Cost**: $3-15 per 1M input tokens -- **Best for**: Production systems, high-quality outputs - -**OpenAI:** -- **Latency**: 200-400ms (API call) -- **Throughput**: Limited by API rate limits -- **Cost**: $0.15-30 per 1M input tokens -- **Best for**: Cost-sensitive production, gpt-4o-mini - -**Transformers:** -- **Latency**: 50-200ms (local inference) -- **Throughput**: GPU-dependent (10-100 tokens/sec) -- **Cost**: Hardware cost only -- **Best for**: Privacy-sensitive, high-volume, experimentation - -**llama.cpp:** -- **Latency**: 30-150ms (local inference) -- **Throughput**: Hardware-dependent (20-150 tokens/sec) -- **Cost**: Hardware cost only -- **Best for**: Edge deployment, Apple Silicon, CPU inference - -### Memory Requirements - -**Transformers (FP16):** -- 7B model: ~14GB GPU VRAM -- 13B model: ~26GB GPU VRAM -- 70B model: ~140GB GPU VRAM (multi-GPU) - -**llama.cpp (Q4_K_M):** -- 7B model: ~4.5GB RAM -- 13B model: ~8GB RAM -- 70B model: ~40GB RAM - -**Optimization Tips:** -- Use quantized models (Q4_K_M) for lower memory -- Use GPU offloading for faster inference -- Use CPU inference for smaller models (<7B) - -## Performance Tuning - -### API Models (Anthropic, OpenAI) - -#### Reduce Latency - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Use lower max_tokens (faster response) -lm += gen(max_tokens=100) # Instead of 1000 - -# Use streaming (perceived latency reduction) -for chunk in lm.stream(gen(max_tokens=500)): - print(chunk, end="", flush=True) -``` - -#### Reduce Cost - -```python -# Use cheaper models -lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet -lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o - -# Reduce context size -# - Keep prompts concise -# - Avoid large few-shot examples -# - Use max_tokens limits -``` - -### Local Models (Transformers, llama.cpp) - -#### Optimize GPU Usage - -```python -from guidance.models import Transformers - -# Use FP16 for 2x speedup -lm = Transformers( - "meta-llama/Llama-3.1-8B-Instruct", - device="cuda", - torch_dtype="float16" -) - -# Use 8-bit quantization for 4x memory reduction -lm = Transformers( - "meta-llama/Llama-3.1-8B-Instruct", - device="cuda", - load_in_8bit=True -) - -# Use flash attention (requires flash-attn package) -lm = Transformers( - "meta-llama/Llama-3.1-8B-Instruct", - device="cuda", - use_flash_attention_2=True -) -``` - -#### Optimize llama.cpp - -```python -from guidance.models import LlamaCpp - -# Maximize GPU layers -lm = LlamaCpp( - model_path="/path/to/model.Q4_K_M.gguf", - n_gpu_layers=-1 # All layers on GPU -) - -# Optimize batch size -lm = LlamaCpp( - model_path="/path/to/model.Q4_K_M.gguf", - n_batch=512, # Larger batch = faster prompt processing - n_gpu_layers=-1 -) - -# Use Metal (Apple Silicon) -lm = LlamaCpp( - model_path="/path/to/model.Q4_K_M.gguf", - n_gpu_layers=-1, # Use Metal GPU acceleration - use_mmap=True -) -``` - -#### Batch Processing - -```python -# Process multiple requests efficiently -requests = [ - "What is 2+2?", - "What is the capital of France?", - "What is photosynthesis?" -] - -# Bad: Sequential processing -for req in requests: - lm = Transformers("microsoft/Phi-4-mini-instruct") - lm += req + gen(max_tokens=50) - -# Good: Reuse loaded model -lm = Transformers("microsoft/Phi-4-mini-instruct") -for req in requests: - lm += req + gen(max_tokens=50) -``` - -## Advanced Configuration - -### Custom Model Configurations - -```python -from transformers import AutoTokenizer, AutoModelForCausalLM -from guidance.models import Transformers - -# Load custom model -tokenizer = AutoTokenizer.from_pretrained("your-model") -model = AutoModelForCausalLM.from_pretrained( - "your-model", - device_map="auto", - torch_dtype="float16" -) - -# Use with Guidance -lm = Transformers(model=model, tokenizer=tokenizer) -``` - -### Environment Variables - -```bash -# API keys -export ANTHROPIC_API_KEY="sk-ant-..." -export OPENAI_API_KEY="sk-..." - -# Transformers cache -export HF_HOME="/path/to/cache" -export TRANSFORMERS_CACHE="/path/to/cache" - -# GPU selection -export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1 -``` - -### Debugging - -```python -# Enable verbose logging -import logging -logging.basicConfig(level=logging.DEBUG) - -# Check backend info -lm = models.Anthropic("claude-sonnet-4-5-20250929") -print(f"Model: {lm.model_name}") -print(f"Backend: {lm.backend}") - -# Check GPU usage (Transformers) -lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda") -print(f"Device: {lm.device}") -print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB") -``` - -## Resources - -- **Anthropic Docs**: https://docs.anthropic.com -- **OpenAI Docs**: https://platform.openai.com/docs -- **Hugging Face Models**: https://huggingface.co/models -- **llama.cpp**: https://github.com/ggerganov/llama.cpp -- **GGUF Models**: https://huggingface.co/models?library=gguf diff --git a/skills/mlops/guidance/references/constraints.md b/skills/mlops/guidance/references/constraints.md deleted file mode 100644 index 99c81890c..000000000 --- a/skills/mlops/guidance/references/constraints.md +++ /dev/null @@ -1,674 +0,0 @@ -# Comprehensive Constraint Patterns - -Guide to regex constraints, grammar-based generation, and token healing in Guidance. - -## Table of Contents -- Regex Constraints -- Grammar-Based Generation -- Token Healing -- Selection Constraints -- Complex Patterns -- Performance Optimization - -## Regex Constraints - -### Basic Patterns - -#### Numeric Constraints - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Integer (positive) -lm += "Age: " + gen("age", regex=r"[0-9]+") - -# Integer (with negatives) -lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+") - -# Float (positive) -lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}") - -# Float (with negatives and optional decimals) -lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?") - -# Percentage (0-100) -lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})") - -# Range (1-5 stars) -lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars" -``` - -#### Text Constraints - -```python -# Alphabetic only -lm += "Name: " + gen("name", regex=r"[A-Za-z]+") - -# Alphabetic with spaces -lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+") - -# Alphanumeric -lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+") - -# Capitalized words -lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*") - -# Lowercase only -lm += "Code: " + gen("code", regex=r"[a-z0-9-]+") - -# Specific length -lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456" -``` - -#### Date and Time Constraints - -```python -# Date (YYYY-MM-DD) -lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") - -# Date (MM/DD/YYYY) -lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}") - -# Time (HH:MM) -lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}") - -# Time (HH:MM:SS) -lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}") - -# ISO 8601 datetime -lm += "Timestamp: " + gen( - "timestamp", - regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z" -) - -# Year (YYYY) -lm += "Year: " + gen("year", regex=r"(19|20)\d{2}") - -# Month name -lm += "Month: " + gen( - "month", - regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)" -) -``` - -#### Contact Information - -```python -# Email -lm += "Email: " + gen( - "email", - regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}" -) - -# Phone (US format) -lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") - -# Phone (international format) -lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}") - -# ZIP code (US) -lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?") - -# Postal code (Canada) -lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d") - -# URL -lm += "URL: " + gen( - "url", - regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?" -) -``` - -### Advanced Patterns - -#### JSON Field Constraints - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# String field with quotes -lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"') - -# Numeric field (no quotes) -lm += '"age": ' + gen("age", regex=r"[0-9]+") - -# Boolean field -lm += '"active": ' + gen("active", regex=r"(true|false)") - -# Null field -lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)") - -# Array of strings -lm += '"tags": [' + gen( - "tags", - regex=r'"[a-z]+"(, "[a-z]+")*' -) + ']' - -# Complete JSON object -lm += """{ - "name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """, - "age": """ + gen("age", regex=r"[0-9]+") + """, - "email": """ + gen( - "email", - regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"' - ) + """ -}""" -``` - -#### Code Patterns - -```python -# Python variable name -lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*") - -# Python function name -lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*") - -# Hex color code -lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}") - -# UUID -lm += "UUID: " + gen( - "uuid", - regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" -) - -# Git commit hash (short) -lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}") - -# Semantic version -lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+") - -# IP address (IPv4) -lm += "IP: " + gen( - "ip", - regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)" -) -``` - -#### Domain-Specific Patterns - -```python -# Credit card number -lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}") - -# Social Security Number (US) -lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}") - -# ISBN-13 -lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d") - -# License plate (US) -lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}") - -# Currency amount -lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}") - -# Percentage with decimal -lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%") -``` - -## Grammar-Based Generation - -### JSON Grammar - -```python -from guidance import models, gen, guidance - -@guidance -def json_object(lm): - """Generate valid JSON object.""" - lm += "{\n" - - # Name field (required) - lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n" - - # Age field (required) - lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n" - - # Email field (required) - lm += ' "email": ' + gen( - "email", - regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"' - ) + ",\n" - - # Active field (required, boolean) - lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n" - - lm += "}" - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = json_object(lm) -print(lm) # Valid JSON guaranteed -``` - -### Nested JSON Grammar - -```python -@guidance -def nested_json(lm): - """Generate nested JSON structure.""" - lm += "{\n" - - # User object - lm += ' "user": {\n' - lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n" - lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n" - lm += " },\n" - - # Address object - lm += ' "address": {\n' - lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n" - lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n" - lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n" - lm += " }\n" - - lm += "}" - return lm -``` - -### Array Grammar - -```python -@guidance -def json_array(lm, count=3): - """Generate JSON array with fixed count.""" - lm += "[\n" - - for i in range(count): - lm += " {\n" - lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n" - lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n" - lm += " }" - if i < count - 1: - lm += "," - lm += "\n" - - lm += "]" - return lm -``` - -### XML Grammar - -```python -@guidance -def xml_document(lm): - """Generate valid XML document.""" - lm += '\n' - lm += "\n" - - # Name element - lm += " " + gen("name", regex=r"[A-Za-z ]+") + "\n" - - # Age element - lm += " " + gen("age", regex=r"[0-9]+") + "\n" - - # Email element - lm += " " + gen( - "email", - regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}" - ) + "\n" - - lm += "" - return lm -``` - -### CSV Grammar - -```python -@guidance -def csv_row(lm): - """Generate CSV row.""" - lm += gen("name", regex=r"[A-Za-z ]+") + "," - lm += gen("age", regex=r"[0-9]+") + "," - lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") - return lm - -@guidance -def csv_document(lm, rows=5): - """Generate complete CSV.""" - # Header - lm += "Name,Age,Email\n" - - # Rows - for i in range(rows): - lm = csv_row(lm) - if i < rows - 1: - lm += "\n" - - return lm -``` - -## Token Healing - -### How Token Healing Works - -**Problem:** Tokenization creates unnatural boundaries. - -```python -# Example without token healing -prompt = "The capital of France is " -# Tokenization: ["The", " capital", " of", " France", " is", " "] -# Model sees last token: " " -# First generated token might include leading space: " Paris" -# Result: "The capital of France is Paris" (double space) -``` - -**Solution:** Guidance backs up and regenerates the last token. - -```python -from guidance import models, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Token healing enabled by default -lm += "The capital of France is " + gen("capital", max_tokens=5) - -# Process: -# 1. Back up to token before " is " -# 2. Regenerate " is" + "capital" together -# 3. Result: "The capital of France is Paris" (correct) -``` - -### Token Healing Examples - -#### Natural Continuations - -```python -# Before token healing -lm += "The function name is get" + gen("rest") -# Might generate: "The function name is get User" (space before User) - -# With token healing -lm += "The function name is get" + gen("rest") -# Generates: "The function name is getUser" (correct camelCase) -``` - -#### Code Generation - -```python -# Function name completion -lm += "def calculate_" + gen("rest", stop="(") -# Token healing ensures smooth connection: "calculate_total" - -# Variable name completion -lm += "my_" + gen("var_name", regex=r"[a-z_]+") -# Token healing ensures: "my_variable_name" (not "my_ variable_name") -``` - -#### Domain-Specific Terms - -```python -# Medical terms -lm += "The patient has hyper" + gen("condition") -# Token healing helps: "hypertension" (not "hyper tension") - -# Technical terms -lm += "Using micro" + gen("tech") -# Token healing helps: "microservices" (not "micro services") -``` - -### Disabling Token Healing - -```python -# Disable token healing if needed (rare) -lm += gen("text", token_healing=False) -``` - -## Selection Constraints - -### Basic Selection - -```python -from guidance import models, select - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -# Simple selection -lm += "Status: " + select(["active", "inactive", "pending"], name="status") - -# Boolean selection -lm += "Approved: " + select(["Yes", "No"], name="approved") - -# Multiple choice -lm += "Answer: " + select( - ["A) Paris", "B) London", "C) Berlin", "D) Madrid"], - name="answer" -) -``` - -### Conditional Selection - -```python -from guidance import models, select, gen, guidance - -@guidance -def conditional_fields(lm): - """Generate fields conditionally based on type.""" - lm += "Type: " + select(["person", "company"], name="type") - - if lm["type"] == "person": - lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+") - lm += "\nAge: " + gen("age", regex=r"[0-9]+") - else: - lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+") - lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+") - - return lm -``` - -### Repeated Selection - -```python -@guidance -def multiple_selections(lm): - """Select multiple items.""" - lm += "Select 3 colors:\n" - - colors = ["red", "blue", "green", "yellow", "purple"] - - for i in range(3): - lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n" - - return lm -``` - -## Complex Patterns - -### Pattern 1: Structured Forms - -```python -@guidance -def user_form(lm): - """Generate structured user form.""" - lm += "=== User Registration ===\n\n" - - # Name (alphabetic only) - lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n" - - # Age (numeric) - lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n" - - # Email (validated format) - lm += "Email: " + gen( - "email", - regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", - stop="\n" - ) + "\n" - - # Phone (US format) - lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n" - - # Account type (selection) - lm += "Account Type: " + select( - ["Standard", "Premium", "Enterprise"], - name="account_type" - ) + "\n" - - # Active status (boolean) - lm += "Active: " + select(["Yes", "No"], name="active") + "\n" - - return lm -``` - -### Pattern 2: Multi-Entity Extraction - -```python -@guidance -def extract_entities(lm, text): - """Extract multiple entities with constraints.""" - lm += f"Text: {text}\n\n" - - # Person name (alphabetic) - lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n" - - # Organization (alphanumeric with spaces) - lm += "Organization: " + gen( - "organization", - regex=r"[A-Za-z0-9 ]+", - stop="\n" - ) + "\n" - - # Date (YYYY-MM-DD format) - lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n" - - # Location (alphabetic with spaces) - lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n" - - # Amount (currency) - lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n" - - return lm -``` - -### Pattern 3: Code Generation - -```python -@guidance -def generate_python_function(lm): - """Generate Python function with constraints.""" - # Function name (valid Python identifier) - lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "(" - - # Parameter name - lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n" - - # Docstring - lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n' - - # Function body (constrained to valid Python) - lm += " return " + gen("return_value", stop="\n") + "\n" - - return lm -``` - -### Pattern 4: Hierarchical Data - -```python -@guidance -def org_chart(lm): - """Generate organizational chart.""" - lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n" - - # CEO - lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n" - - # Departments - for dept in ["Engineering", "Sales", "Marketing"]: - lm += f"\n{dept} Department:\n" - lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n" - lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n" - - return lm -``` - -## Performance Optimization - -### Best Practices - -#### 1. Use Specific Patterns - -```python -# ✅ Good: Specific pattern -lm += gen("age", regex=r"[0-9]{1,3}") # Fast - -# ❌ Bad: Overly broad pattern -lm += gen("age", regex=r"[0-9]+") # Slower -``` - -#### 2. Limit Max Tokens - -```python -# ✅ Good: Reasonable limit -lm += gen("name", max_tokens=30) - -# ❌ Bad: No limit -lm += gen("name") # May generate forever -``` - -#### 3. Use stop Sequences - -```python -# ✅ Good: Stop at newline -lm += gen("line", stop="\n") - -# ❌ Bad: Rely on max_tokens -lm += gen("line", max_tokens=100) -``` - -#### 4. Cache Compiled Grammars - -```python -# Grammars are cached automatically after first use -# No manual caching needed -@guidance -def reusable_pattern(lm): - """This grammar is compiled once and cached.""" - lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") - return lm - -# First call: compiles grammar -lm = reusable_pattern(lm) - -# Subsequent calls: uses cached grammar (fast) -lm = reusable_pattern(lm) -``` - -#### 5. Avoid Overlapping Constraints - -```python -# ✅ Good: Clear constraints -lm += gen("age", regex=r"[0-9]+", max_tokens=3) - -# ❌ Bad: Conflicting constraints -lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary -``` - -### Performance Benchmarks - -**Regex vs Free Generation:** -- Simple regex (digits): ~1.2x slower than free gen -- Complex regex (email): ~1.5x slower than free gen -- Grammar-based: ~2x slower than free gen - -**But:** -- 100% valid outputs (vs ~70% with free gen + validation) -- No retry loops needed -- Overall faster end-to-end for structured outputs - -**Optimization Tips:** -- Use regex for critical fields only -- Use `select()` for small fixed sets (fastest) -- Use `stop` sequences when possible (faster than max_tokens) -- Cache compiled grammars by reusing functions - -## Resources - -- **Token Healing Paper**: https://arxiv.org/abs/2306.17648 -- **Guidance Docs**: https://guidance.readthedocs.io -- **GitHub**: https://github.com/guidance-ai/guidance diff --git a/skills/mlops/guidance/references/examples.md b/skills/mlops/guidance/references/examples.md deleted file mode 100644 index 315388748..000000000 --- a/skills/mlops/guidance/references/examples.md +++ /dev/null @@ -1,767 +0,0 @@ -# Production-Ready Examples - -Real-world examples of using Guidance for structured generation, agents, and workflows. - -## Table of Contents -- JSON Generation -- Data Extraction -- Classification Systems -- Agent Systems -- Multi-Step Workflows -- Code Generation -- Production Tips - -## JSON Generation - -### Basic JSON - -```python -from guidance import models, gen, guidance - -@guidance -def generate_user(lm): - """Generate valid user JSON.""" - lm += "{\n" - lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n" - lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n" - lm += ' "email": ' + gen( - "email", - regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"' - ) + "\n" - lm += "}" - return lm - -# Use it -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm += "Generate a user profile:\n" -lm = generate_user(lm) - -print(lm) -# Output: Valid JSON guaranteed -``` - -### Nested JSON - -```python -@guidance -def generate_order(lm): - """Generate nested order JSON.""" - lm += "{\n" - - # Customer info - lm += ' "customer": {\n' - lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n" - lm += ' "email": ' + gen( - "customer_email", - regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"' - ) + "\n" - lm += " },\n" - - # Order details - lm += ' "order": {\n' - lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n" - lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n" - lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n" - lm += " },\n" - - # Status - lm += ' "status": ' + gen( - "status", - regex=r'"(pending|processing|shipped|delivered)"' - ) + "\n" - - lm += "}" - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_order(lm) -``` - -### JSON Array - -```python -@guidance -def generate_user_list(lm, count=3): - """Generate JSON array of users.""" - lm += "[\n" - - for i in range(count): - lm += " {\n" - lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n" - lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n" - lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n" - lm += " }" - if i < count - 1: - lm += "," - lm += "\n" - - lm += "]" - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_user_list(lm, count=5) -``` - -### Dynamic JSON Schema - -```python -import json -from guidance import models, gen, guidance - -@guidance -def json_from_schema(lm, schema): - """Generate JSON matching a schema.""" - lm += "{\n" - - fields = list(schema["properties"].items()) - for i, (field_name, field_schema) in enumerate(fields): - lm += f' "{field_name}": ' - - # Handle different types - if field_schema["type"] == "string": - if "pattern" in field_schema: - lm += gen(field_name, regex=f'"{field_schema["pattern"]}"') - else: - lm += gen(field_name, regex=r'"[^"]+"') - elif field_schema["type"] == "number": - lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?") - elif field_schema["type"] == "integer": - lm += gen(field_name, regex=r"[0-9]+") - elif field_schema["type"] == "boolean": - lm += gen(field_name, regex=r"(true|false)") - - if i < len(fields) - 1: - lm += "," - lm += "\n" - - lm += "}" - return lm - -# Define schema -schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - "score": {"type": "number"}, - "active": {"type": "boolean"} - } -} - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = json_from_schema(lm, schema) -``` - -## Data Extraction - -### Extract from Text - -```python -from guidance import models, gen, guidance, system, user, assistant - -@guidance -def extract_person_info(lm, text): - """Extract structured info from text.""" - lm += f"Text: {text}\n\n" - - with assistant(): - lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n" - lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n" - lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n" - lm += "Email: " + gen( - "email", - regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", - stop="\n" - ) + "\n" - - return lm - -text = "John Smith is a 35-year-old software engineer. Contact: john@example.com" - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -with system(): - lm += "You extract structured information from text." - -with user(): - lm = extract_person_info(lm, text) - -print(f"Name: {lm['name']}") -print(f"Age: {lm['age']}") -print(f"Occupation: {lm['occupation']}") -print(f"Email: {lm['email']}") -``` - -### Multi-Entity Extraction - -```python -@guidance -def extract_entities(lm, text): - """Extract multiple entity types.""" - lm += f"Analyze: {text}\n\n" - - # Person entities - lm += "People:\n" - for i in range(3): # Up to 3 people - lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n" - - # Organization entities - lm += "\nOrganizations:\n" - for i in range(2): # Up to 2 orgs - lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n" - - # Dates - lm += "\nDates:\n" - for i in range(2): # Up to 2 dates - lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n" - - # Locations - lm += "\nLocations:\n" - for i in range(2): # Up to 2 locations - lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n" - - return lm - -text = """ -Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15 -to discuss the collaboration between Apple and Microsoft. The meeting continued -in Cupertino on 2024-09-20. -""" - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = extract_entities(lm, text) -``` - -### Batch Extraction - -```python -@guidance -def batch_extract(lm, texts): - """Extract from multiple texts.""" - lm += "Batch Extraction Results:\n\n" - - for i, text in enumerate(texts): - lm += f"=== Item {i+1} ===\n" - lm += f"Text: {text}\n" - lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n" - lm += "Sentiment: " + gen( - f"sentiment_{i}", - regex=r"(positive|negative|neutral)", - stop="\n" - ) + "\n\n" - - return lm - -texts = [ - "Alice is happy with the product", - "Bob is disappointed with the service", - "Carol has no strong feelings either way" -] - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = batch_extract(lm, texts) -``` - -## Classification Systems - -### Sentiment Analysis - -```python -from guidance import models, select, gen - -lm = models.Anthropic("claude-sonnet-4-5-20250929") - -text = "This product is absolutely amazing! Best purchase ever." - -lm += f"Text: {text}\n\n" -lm += "Sentiment: " + select( - ["positive", "negative", "neutral"], - name="sentiment" -) -lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n" -lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50) - -print(f"Sentiment: {lm['sentiment']}") -print(f"Confidence: {lm['confidence']}%") -print(f"Reasoning: {lm['reasoning']}") -``` - -### Multi-Label Classification - -```python -@guidance -def classify_article(lm, text): - """Classify article with multiple labels.""" - lm += f"Article: {text}\n\n" - - # Primary category - lm += "Primary Category: " + select( - ["Technology", "Business", "Science", "Politics", "Entertainment"], - name="primary_category" - ) + "\n" - - # Secondary categories (up to 3) - lm += "\nSecondary Categories:\n" - categories = ["Technology", "Business", "Science", "Politics", "Entertainment"] - for i in range(3): - lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n" - - # Tags - lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n" - - # Target audience - lm += "Target Audience: " + select( - ["General", "Expert", "Beginner"], - name="audience" - ) - - return lm - -article = """ -Apple announced new AI features in iOS 18, leveraging machine learning to improve -battery life and performance. The company's stock rose 5% following the announcement. -""" - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = classify_article(lm, article) -``` - -### Intent Classification - -```python -@guidance -def classify_intent(lm, message): - """Classify user intent.""" - lm += f"User Message: {message}\n\n" - - # Intent - lm += "Intent: " + select( - ["question", "complaint", "request", "feedback", "other"], - name="intent" - ) + "\n" - - # Urgency - lm += "Urgency: " + select( - ["low", "medium", "high", "critical"], - name="urgency" - ) + "\n" - - # Department - lm += "Route To: " + select( - ["support", "sales", "billing", "technical"], - name="department" - ) + "\n" - - # Sentiment - lm += "Sentiment: " + select( - ["positive", "neutral", "negative"], - name="sentiment" - ) - - return lm - -message = "My account was charged twice for the same order. Need help ASAP!" - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = classify_intent(lm, message) - -print(f"Intent: {lm['intent']}") -print(f"Urgency: {lm['urgency']}") -print(f"Department: {lm['department']}") -``` - -## Agent Systems - -### ReAct Agent - -```python -from guidance import models, gen, select, guidance - -@guidance(stateless=False) -def react_agent(lm, question, tools, max_rounds=5): - """ReAct agent with tool use.""" - lm += f"Question: {question}\n\n" - - for round in range(max_rounds): - # Thought - lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n" - - # Action selection - lm += "Action: " + select( - list(tools.keys()) + ["answer"], - name="action" - ) - - if lm["action"] == "answer": - lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200) - break - - # Action input - lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n" - - # Execute tool - if lm["action"] in tools: - try: - result = tools[lm["action"]](lm["action_input"]) - lm += f"Observation: {result}\n\n" - except Exception as e: - lm += f"Observation: Error - {str(e)}\n\n" - - return lm - -# Define tools -tools = { - "calculator": lambda expr: eval(expr), - "search": lambda query: f"Search results for '{query}': [Mock results]", - "weather": lambda city: f"Weather in {city}: Sunny, 72°F" -} - -# Use agent -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = react_agent(lm, "What is (25 * 4) + 10?", tools) - -print(lm["answer"]) -``` - -### Multi-Agent System - -```python -@guidance -def coordinator_agent(lm, task): - """Coordinator that delegates to specialists.""" - lm += f"Task: {task}\n\n" - - # Determine which specialist to use - lm += "Specialist: " + select( - ["researcher", "writer", "coder", "analyst"], - name="specialist" - ) + "\n" - - lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n" - - return lm - -@guidance -def researcher_agent(lm, query): - """Research specialist.""" - lm += f"Research Query: {query}\n\n" - lm += "Findings:\n" - for i in range(3): - lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n" - return lm - -@guidance -def writer_agent(lm, topic): - """Writing specialist.""" - lm += f"Topic: {topic}\n\n" - lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n" - lm += "Content:\n" + gen("content", max_tokens=500) - return lm - -# Coordination workflow -task = "Write an article about AI safety" - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = coordinator_agent(lm, task) - -specialist = lm["specialist"] -if specialist == "researcher": - lm = researcher_agent(lm, task) -elif specialist == "writer": - lm = writer_agent(lm, task) -``` - -### Tool Use with Validation - -```python -@guidance(stateless=False) -def validated_tool_agent(lm, question): - """Agent with validated tool calls.""" - tools = { - "add": lambda a, b: float(a) + float(b), - "multiply": lambda a, b: float(a) * float(b), - "divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero" - } - - lm += f"Question: {question}\n\n" - - for i in range(5): - # Select tool - lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool") - - if lm["tool"] == "done": - lm += "\nAnswer: " + gen("answer", max_tokens=100) - break - - # Get validated numeric arguments - lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n" - lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n" - - # Execute - result = tools[lm["tool"]](lm["arg1"], lm["arg2"]) - lm += f"Result: {result}\n\n" - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = validated_tool_agent(lm, "What is (10 + 5) * 3?") -``` - -## Multi-Step Workflows - -### Chain of Thought - -```python -@guidance -def chain_of_thought(lm, question): - """Multi-step reasoning with CoT.""" - lm += f"Question: {question}\n\n" - - # Generate reasoning steps - lm += "Let me think step by step:\n\n" - for i in range(4): - lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n" - - # Final answer - lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50) - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?") - -print(lm["answer"]) -``` - -### Self-Consistency - -```python -@guidance -def self_consistency(lm, question, num_samples=3): - """Generate multiple reasoning paths and aggregate.""" - lm += f"Question: {question}\n\n" - - answers = [] - for i in range(num_samples): - lm += f"=== Attempt {i+1} ===\n" - lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n" - lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n" - answers.append(lm[f"answer_{i}"]) - - # Aggregate (simple majority vote) - from collections import Counter - most_common = Counter(answers).most_common(1)[0][0] - - lm += f"Final Answer (by majority): {most_common}\n" - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = self_consistency(lm, "What is 15% of 200?") -``` - -### Planning and Execution - -```python -@guidance -def plan_and_execute(lm, goal): - """Plan tasks then execute them.""" - lm += f"Goal: {goal}\n\n" - - # Planning phase - lm += "Plan:\n" - num_steps = 4 - for i in range(num_steps): - lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n" - - # Execution phase - lm += "\nExecution:\n\n" - for i in range(num_steps): - lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n" - lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n" - lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n" - - # Summary - lm += "Summary: " + gen("summary", max_tokens=200) - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = plan_and_execute(lm, "Build a REST API for a blog platform") -``` - -## Code Generation - -### Python Function - -```python -@guidance -def generate_python_function(lm, description): - """Generate Python function from description.""" - lm += f"Description: {description}\n\n" - - # Function signature - lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "(" - lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n" - - # Docstring - lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n' - - # Function body - lm += " " + gen("body", stop="\n", max_tokens=200) + "\n" - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_python_function(lm, "Check if a number is prime") - -print(lm) -``` - -### SQL Query - -```python -@guidance -def generate_sql(lm, description): - """Generate SQL query from description.""" - lm += f"Description: {description}\n\n" - lm += "SQL Query:\n" - - # SELECT clause - lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100) - - # FROM clause - lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50) - - # WHERE clause (optional) - lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";" - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_sql(lm, "Get all users who signed up in the last 30 days") -``` - -### API Endpoint - -```python -@guidance -def generate_api_endpoint(lm, description): - """Generate REST API endpoint.""" - lm += f"Description: {description}\n\n" - - # HTTP method - lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n" - - # Path - lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n" - - # Request body (if POST/PUT) - if lm["method"] in ["POST", "PUT"]: - lm += "\nRequest Body:\n" - lm += "{\n" - lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n" - lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n" - lm += "}\n" - - # Response - lm += "\nResponse (200 OK):\n" - lm += "{\n" - lm += ' "status": "success",\n' - lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n" - lm += "}\n" - - return lm - -lm = models.Anthropic("claude-sonnet-4-5-20250929") -lm = generate_api_endpoint(lm, "Create a new blog post") -``` - -## Production Tips - -### Error Handling - -```python -@guidance -def safe_extraction(lm, text): - """Extract with fallback handling.""" - try: - lm += f"Text: {text}\n" - lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30) - return lm - except Exception as e: - # Fallback to less strict extraction - lm += f"Text: {text}\n" - lm += "Name: " + gen("name", stop="\n", max_tokens=30) - return lm -``` - -### Caching - -```python -from functools import lru_cache - -@lru_cache(maxsize=100) -def cached_generation(text): - """Cache LLM generations.""" - lm = models.Anthropic("claude-sonnet-4-5-20250929") - lm += f"Analyze: {text}\n" - lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment") - return lm["sentiment"] - -# First call: hits LLM -result1 = cached_generation("This is great!") - -# Second call: returns cached result -result2 = cached_generation("This is great!") # Instant! -``` - -### Monitoring - -```python -import time - -@guidance -def monitored_generation(lm, text): - """Track generation metrics.""" - start_time = time.time() - - lm += f"Text: {text}\n" - lm += "Analysis: " + gen("analysis", max_tokens=100) - - elapsed = time.time() - start_time - - # Log metrics - print(f"Generation time: {elapsed:.2f}s") - print(f"Output length: {len(lm['analysis'])} chars") - - return lm -``` - -### Batch Processing - -```python -def batch_process(texts, batch_size=10): - """Process texts in batches.""" - lm = models.Anthropic("claude-sonnet-4-5-20250929") - results = [] - - for i in range(0, len(texts), batch_size): - batch = texts[i:i+batch_size] - - for text in batch: - lm += f"Text: {text}\n" - lm += "Sentiment: " + select( - ["positive", "negative", "neutral"], - name=f"sentiment_{i}" - ) + "\n\n" - - results.extend([lm[f"sentiment_{i}"] for i in range(len(batch))]) - - return results -``` - -## Resources - -- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks -- **Guidance Docs**: https://guidance.readthedocs.io -- **Community Examples**: https://github.com/guidance-ai/guidance/discussions diff --git a/skills/mlops/llava/SKILL.md b/skills/mlops/llava/SKILL.md deleted file mode 100644 index 5fe0b7298..000000000 --- a/skills/mlops/llava/SKILL.md +++ /dev/null @@ -1,307 +0,0 @@ ---- -name: llava -description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [transformers, torch, pillow] -metadata: - hermes: - tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA] - ---- - -# LLaVA - Large Language and Vision Assistant - -Open-source vision-language model for conversational image understanding. - -## When to use LLaVA - -**Use when:** -- Building vision-language chatbots -- Visual question answering (VQA) -- Image description and captioning -- Multi-turn image conversations -- Visual instruction following -- Document understanding with images - -**Metrics**: -- **23,000+ GitHub stars** -- GPT-4V level capabilities (targeted) -- Apache 2.0 License -- Multiple model sizes (7B-34B params) - -**Use alternatives instead**: -- **GPT-4V**: Highest quality, API-based -- **CLIP**: Simple zero-shot classification -- **BLIP-2**: Better for captioning only -- **Flamingo**: Research, not open-source - -## Quick start - -### Installation - -```bash -# Clone repository -git clone https://github.com/haotian-liu/LLaVA -cd LLaVA - -# Install -pip install -e . -``` - -### Basic usage - -```python -from llava.model.builder import load_pretrained_model -from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token -from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN -from llava.conversation import conv_templates -from PIL import Image -import torch - -# Load model -model_path = "liuhaotian/llava-v1.5-7b" -tokenizer, model, image_processor, context_len = load_pretrained_model( - model_path=model_path, - model_base=None, - model_name=get_model_name_from_path(model_path) -) - -# Load image -image = Image.open("image.jpg") -image_tensor = process_images([image], image_processor, model.config) -image_tensor = image_tensor.to(model.device, dtype=torch.float16) - -# Create conversation -conv = conv_templates["llava_v1"].copy() -conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?") -conv.append_message(conv.roles[1], None) -prompt = conv.get_prompt() - -# Generate response -input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) - -with torch.inference_mode(): - output_ids = model.generate( - input_ids, - images=image_tensor, - do_sample=True, - temperature=0.2, - max_new_tokens=512 - ) - -response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() -print(response) -``` - -## Available models - -| Model | Parameters | VRAM | Quality | -|-------|------------|------|---------| -| LLaVA-v1.5-7B | 7B | ~14 GB | Good | -| LLaVA-v1.5-13B | 13B | ~28 GB | Better | -| LLaVA-v1.6-34B | 34B | ~70 GB | Best | - -```python -# Load different models -model_7b = "liuhaotian/llava-v1.5-7b" -model_13b = "liuhaotian/llava-v1.5-13b" -model_34b = "liuhaotian/llava-v1.6-34b" - -# 4-bit quantization for lower VRAM -load_4bit = True # Reduces VRAM by ~4× -``` - -## CLI usage - -```bash -# Single image query -python -m llava.serve.cli \ - --model-path liuhaotian/llava-v1.5-7b \ - --image-file image.jpg \ - --query "What is in this image?" - -# Multi-turn conversation -python -m llava.serve.cli \ - --model-path liuhaotian/llava-v1.5-7b \ - --image-file image.jpg -# Then type questions interactively -``` - -## Web UI (Gradio) - -```bash -# Launch Gradio interface -python -m llava.serve.gradio_web_server \ - --model-path liuhaotian/llava-v1.5-7b \ - --load-4bit # Optional: reduce VRAM - -# Access at http://localhost:7860 -``` - -## Multi-turn conversations - -```python -# Initialize conversation -conv = conv_templates["llava_v1"].copy() - -# Turn 1 -conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?") -conv.append_message(conv.roles[1], None) -response1 = generate(conv, model, image) # "A dog playing in a park" - -# Turn 2 -conv.messages[-1][1] = response1 # Add previous response -conv.append_message(conv.roles[0], "What breed is the dog?") -conv.append_message(conv.roles[1], None) -response2 = generate(conv, model, image) # "Golden Retriever" - -# Turn 3 -conv.messages[-1][1] = response2 -conv.append_message(conv.roles[0], "What time of day is it?") -conv.append_message(conv.roles[1], None) -response3 = generate(conv, model, image) -``` - -## Common tasks - -### Image captioning - -```python -question = "Describe this image in detail." -response = ask(model, image, question) -``` - -### Visual question answering - -```python -question = "How many people are in the image?" -response = ask(model, image, question) -``` - -### Object detection (textual) - -```python -question = "List all the objects you can see in this image." -response = ask(model, image, question) -``` - -### Scene understanding - -```python -question = "What is happening in this scene?" -response = ask(model, image, question) -``` - -### Document understanding - -```python -question = "What is the main topic of this document?" -response = ask(model, document_image, question) -``` - -## Training custom model - -```bash -# Stage 1: Feature alignment (558K image-caption pairs) -bash scripts/v1_5/pretrain.sh - -# Stage 2: Visual instruction tuning (150K instruction data) -bash scripts/v1_5/finetune.sh -``` - -## Quantization (reduce VRAM) - -```python -# 4-bit quantization -tokenizer, model, image_processor, context_len = load_pretrained_model( - model_path="liuhaotian/llava-v1.5-13b", - model_base=None, - model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"), - load_4bit=True # Reduces VRAM ~4× -) - -# 8-bit quantization -load_8bit=True # Reduces VRAM ~2× -``` - -## Best practices - -1. **Start with 7B model** - Good quality, manageable VRAM -2. **Use 4-bit quantization** - Reduces VRAM significantly -3. **GPU required** - CPU inference extremely slow -4. **Clear prompts** - Specific questions get better answers -5. **Multi-turn conversations** - Maintain conversation context -6. **Temperature 0.2-0.7** - Balance creativity/consistency -7. **max_new_tokens 512-1024** - For detailed responses -8. **Batch processing** - Process multiple images sequentially - -## Performance - -| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) | -|-------|-------------|--------------|------------------| -| 7B | ~14 GB | ~4 GB | ~20 | -| 13B | ~28 GB | ~8 GB | ~12 | -| 34B | ~70 GB | ~18 GB | ~5 | - -*On A100 GPU* - -## Benchmarks - -LLaVA achieves competitive scores on: -- **VQAv2**: 78.5% -- **GQA**: 62.0% -- **MM-Vet**: 35.4% -- **MMBench**: 64.3% - -## Limitations - -1. **Hallucinations** - May describe things not in image -2. **Spatial reasoning** - Struggles with precise locations -3. **Small text** - Difficulty reading fine print -4. **Object counting** - Imprecise for many objects -5. **VRAM requirements** - Need powerful GPU -6. **Inference speed** - Slower than CLIP - -## Integration with frameworks - -### LangChain - -```python -from langchain.llms.base import LLM - -class LLaVALLM(LLM): - def _call(self, prompt, stop=None): - # Custom LLaVA inference - return response - -llm = LLaVALLM() -``` - -### Gradio App - -```python -import gradio as gr - -def chat(image, text, history): - response = ask_llava(model, image, text) - return response - -demo = gr.ChatInterface( - chat, - additional_inputs=[gr.Image(type="pil")], - title="LLaVA Chat" -) -demo.launch() -``` - -## Resources - -- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+ -- **Paper**: https://arxiv.org/abs/2304.08485 -- **Demo**: https://llava.hliu.cc -- **Models**: https://huggingface.co/liuhaotian -- **License**: Apache 2.0 - - diff --git a/skills/mlops/llava/references/training.md b/skills/mlops/llava/references/training.md deleted file mode 100644 index 9ab89c96f..000000000 --- a/skills/mlops/llava/references/training.md +++ /dev/null @@ -1,197 +0,0 @@ -# LLaVA Training Guide - -Guide to training and fine-tuning LLaVA models. - -## Training stages - -### Stage 1: Feature alignment (Pretraining) - -**Purpose**: Align vision encoder with language model - -**Data**: 558K image-caption pairs (CC3M subset) - -```bash -# Download pretrained projector or train from scratch -bash scripts/v1_5/pretrain.sh -``` - -**Configuration:** -- Base model: Vicuna-7B or LLaMA-2-7B -- Vision encoder: CLIP ViT-L/14 -- Training time: ~20 hours on 8× A100 - -### Stage 2: Visual instruction tuning - -**Purpose**: Teach model to follow visual instructions - -**Data**: 150K GPT-generated multimodal instruction data - -```bash -# Fine-tune with instruction data -bash scripts/v1_5/finetune.sh -``` - -**Configuration:** -- Epochs: 1 -- Batch size: 128 (across 8 GPUs) -- Learning rate: 2e-5 -- Training time: ~24 hours on 8× A100 - -## Data format - -### Instruction data format - -```json -[ - { - "id": "001", - "image": "path/to/image.jpg", - "conversations": [ - { - "from": "human", - "value": "\nWhat is in this image?" - }, - { - "from": "gpt", - "value": "The image shows a dog playing in a park." - }, - { - "from": "human", - "value": "What breed is the dog?" - }, - { - "from": "gpt", - "value": "It appears to be a Golden Retriever." - } - ] - } -] -``` - -## Fine-tuning on custom data - -### Prepare your data - -```python -import json - -# Create instruction data -data = [] -for image_path, qa_pairs in your_dataset: - conversations = [] - for q, a in qa_pairs: - conversations.append({"from": "human", "value": f"\n{q}"}) - conversations.append({"from": "gpt", "value": a}) - - data.append({ - "id": str(len(data)), - "image": image_path, - "conversations": conversations - }) - -# Save -with open("custom_data.json", "w") as f: - json.dump(data, f, indent=2) -``` - -### Fine-tune script - -```bash -#!/bin/bash - -# Set paths -DATA_PATH="custom_data.json" -IMAGE_FOLDER="path/to/images" -MODEL_PATH="liuhaotian/llava-v1.5-7b" -OUTPUT_DIR="./checkpoints/llava-custom" - -# Fine-tune -deepspeed llava/train/train_mem.py \ - --deepspeed ./scripts/zero2.json \ - --model_name_or_path $MODEL_PATH \ - --version v1 \ - --data_path $DATA_PATH \ - --image_folder $IMAGE_FOLDER \ - --vision_tower openai/clip-vit-large-patch14-336 \ - --mm_projector_type mlp2x_gelu \ - --mm_vision_select_layer -2 \ - --mm_use_im_start_end False \ - --mm_use_im_patch_token False \ - --image_aspect_ratio pad \ - --group_by_modality_length True \ - --bf16 True \ - --output_dir $OUTPUT_DIR \ - --num_train_epochs 1 \ - --per_device_train_batch_size 16 \ - --per_device_eval_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --evaluation_strategy "no" \ - --save_strategy "steps" \ - --save_steps 50000 \ - --save_total_limit 1 \ - --learning_rate 2e-5 \ - --weight_decay 0. \ - --warmup_ratio 0.03 \ - --lr_scheduler_type "cosine" \ - --logging_steps 1 \ - --tf32 True \ - --model_max_length 2048 \ - --gradient_checkpointing True \ - --dataloader_num_workers 4 \ - --lazy_preprocess True \ - --report_to wandb -``` - -## LoRA fine-tuning (memory efficient) - -```python -from peft import LoraConfig, get_peft_model - -# LoRA config -lora_config = LoraConfig( - r=8, # LoRA rank - lora_alpha=16, - target_modules=["q_proj", "v_proj"], - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM" -) - -# Apply LoRA -model = get_peft_model(base_model, lora_config) - -# Train with much lower memory -``` - -## Hardware requirements - -### Full fine-tuning - -- **7B model**: 8× A100 (40GB) -- **13B model**: 8× A100 (80GB) -- **Training time**: 20-48 hours - -### LoRA fine-tuning - -- **7B model**: 1× A100 (40GB) -- **13B model**: 2× A100 (40GB) -- **Training time**: 10-24 hours - -## Best practices - -1. **Start with pretrained** - Don't train from scratch -2. **Use LoRA for efficiency** - 10× less memory -3. **Quality over quantity** - 1K high-quality > 10K low-quality -4. **Multi-turn conversations** - More engaging than single Q&A -5. **Diverse images** - Cover different scenarios -6. **Clear instructions** - Specific questions get better answers -7. **Monitor loss** - Should decrease smoothly -8. **Save checkpoints** - Training can fail -9. **Test regularly** - Validate on held-out set -10. **Use DeepSpeed** - For multi-GPU training - -## Resources - -- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts -- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md -- **Paper**: https://arxiv.org/abs/2304.08485 diff --git a/skills/mlops/nemo-curator/SKILL.md b/skills/mlops/nemo-curator/SKILL.md deleted file mode 100644 index c9262f11a..000000000 --- a/skills/mlops/nemo-curator/SKILL.md +++ /dev/null @@ -1,386 +0,0 @@ ---- -name: nemo-curator -description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [nemo-curator, cudf, dask, rapids] -metadata: - hermes: - tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data] - ---- - -# NeMo Curator - GPU-Accelerated Data Curation - -NVIDIA's toolkit for preparing high-quality training data for LLMs. - -## When to use NeMo Curator - -**Use NeMo Curator when:** -- Preparing LLM training data from web scrapes (Common Crawl) -- Need fast deduplication (16× faster than CPU) -- Curating multi-modal datasets (text, images, video, audio) -- Filtering low-quality or toxic content -- Scaling data processing across GPU cluster - -**Performance**: -- **16× faster** fuzzy deduplication (8TB RedPajama v2) -- **40% lower TCO** vs CPU alternatives -- **Near-linear scaling** across GPU nodes - -**Use alternatives instead**: -- **datatrove**: CPU-based, open-source data processing -- **dolma**: Allen AI's data toolkit -- **Ray Data**: General ML data processing (no curation focus) - -## Quick start - -### Installation - -```bash -# Text curation (CUDA 12) -uv pip install "nemo-curator[text_cuda12]" - -# All modalities -uv pip install "nemo-curator[all_cuda12]" - -# CPU-only (slower) -uv pip install "nemo-curator[cpu]" -``` - -### Basic text curation pipeline - -```python -from nemo_curator import ScoreFilter, Modify -from nemo_curator.datasets import DocumentDataset -import pandas as pd - -# Load data -df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]}) -dataset = DocumentDataset(df) - -# Quality filtering -def quality_score(doc): - return len(doc["text"].split()) > 5 # Filter short docs - -filtered = ScoreFilter(quality_score)(dataset) - -# Deduplication -from nemo_curator.modules import ExactDuplicates -deduped = ExactDuplicates()(filtered) - -# Save -deduped.to_parquet("curated_data/") -``` - -## Data curation pipeline - -### Stage 1: Quality filtering - -```python -from nemo_curator.filters import ( - WordCountFilter, - RepeatedLinesFilter, - UrlRatioFilter, - NonAlphaNumericFilter -) - -# Apply 30+ heuristic filters -from nemo_curator import ScoreFilter - -# Word count filter -dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000)) - -# Remove repetitive content -dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3)) - -# URL ratio filter -dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2)) -``` - -### Stage 2: Deduplication - -**Exact deduplication**: -```python -from nemo_curator.modules import ExactDuplicates - -# Remove exact duplicates -deduped = ExactDuplicates(id_field="id", text_field="text")(dataset) -``` - -**Fuzzy deduplication** (16× faster on GPU): -```python -from nemo_curator.modules import FuzzyDuplicates - -# MinHash + LSH deduplication -fuzzy_dedup = FuzzyDuplicates( - id_field="id", - text_field="text", - num_hashes=260, # MinHash parameters - num_buckets=20, - hash_method="md5" -) - -deduped = fuzzy_dedup(dataset) -``` - -**Semantic deduplication**: -```python -from nemo_curator.modules import SemanticDuplicates - -# Embedding-based deduplication -semantic_dedup = SemanticDuplicates( - id_field="id", - text_field="text", - embedding_model="sentence-transformers/all-MiniLM-L6-v2", - threshold=0.8 # Cosine similarity threshold -) - -deduped = semantic_dedup(dataset) -``` - -### Stage 3: PII redaction - -```python -from nemo_curator.modules import Modify -from nemo_curator.modifiers import PIIRedactor - -# Redact personally identifiable information -pii_redactor = PIIRedactor( - supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"], - anonymize_action="replace" # or "redact" -) - -redacted = Modify(pii_redactor)(dataset) -``` - -### Stage 4: Classifier filtering - -```python -from nemo_curator.classifiers import QualityClassifier - -# Quality classification -quality_clf = QualityClassifier( - model_path="nvidia/quality-classifier-deberta", - batch_size=256, - device="cuda" -) - -# Filter low-quality documents -high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5) -``` - -## GPU acceleration - -### GPU vs CPU performance - -| Operation | CPU (16 cores) | GPU (A100) | Speedup | -|-----------|----------------|------------|---------| -| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× | -| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× | -| Quality filtering | 2 hours | 0.2 hours | 10× | - -### Multi-GPU scaling - -```python -from nemo_curator import get_client -import dask_cuda - -# Initialize GPU cluster -client = get_client(cluster_type="gpu", n_workers=8) - -# Process with 8 GPUs -deduped = FuzzyDuplicates(...)(dataset) -``` - -## Multi-modal curation - -### Image curation - -```python -from nemo_curator.image import ( - AestheticFilter, - NSFWFilter, - CLIPEmbedder -) - -# Aesthetic scoring -aesthetic_filter = AestheticFilter(threshold=5.0) -filtered_images = aesthetic_filter(image_dataset) - -# NSFW detection -nsfw_filter = NSFWFilter(threshold=0.9) -safe_images = nsfw_filter(filtered_images) - -# Generate CLIP embeddings -clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32") -image_embeddings = clip_embedder(safe_images) -``` - -### Video curation - -```python -from nemo_curator.video import ( - SceneDetector, - ClipExtractor, - InternVideo2Embedder -) - -# Detect scenes -scene_detector = SceneDetector(threshold=27.0) -scenes = scene_detector(video_dataset) - -# Extract clips -clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0) -clips = clip_extractor(scenes) - -# Generate embeddings -video_embedder = InternVideo2Embedder() -video_embeddings = video_embedder(clips) -``` - -### Audio curation - -```python -from nemo_curator.audio import ( - ASRInference, - WERFilter, - DurationFilter -) - -# ASR transcription -asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc") -transcribed = asr(audio_dataset) - -# Filter by WER (word error rate) -wer_filter = WERFilter(max_wer=0.3) -high_quality_audio = wer_filter(transcribed) - -# Duration filtering -duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0) -filtered_audio = duration_filter(high_quality_audio) -``` - -## Common patterns - -### Web scrape curation (Common Crawl) - -```python -from nemo_curator import ScoreFilter, Modify -from nemo_curator.filters import * -from nemo_curator.modules import * -from nemo_curator.datasets import DocumentDataset - -# Load Common Crawl data -dataset = DocumentDataset.read_parquet("common_crawl/*.parquet") - -# Pipeline -pipeline = [ - # 1. Quality filtering - WordCountFilter(min_words=100, max_words=50000), - RepeatedLinesFilter(max_repeated_line_fraction=0.2), - SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3), - UrlRatioFilter(max_url_ratio=0.3), - - # 2. Language filtering - LanguageIdentificationFilter(target_languages=["en"]), - - # 3. Deduplication - ExactDuplicates(id_field="id", text_field="text"), - FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260), - - # 4. PII redaction - PIIRedactor(), - - # 5. NSFW filtering - NSFWClassifier(threshold=0.8) -] - -# Execute -for stage in pipeline: - dataset = stage(dataset) - -# Save -dataset.to_parquet("curated_common_crawl/") -``` - -### Distributed processing - -```python -from nemo_curator import get_client -from dask_cuda import LocalCUDACluster - -# Multi-GPU cluster -cluster = LocalCUDACluster(n_workers=8) -client = get_client(cluster=cluster) - -# Process large dataset -dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet") -deduped = FuzzyDuplicates(...)(dataset) - -# Cleanup -client.close() -cluster.close() -``` - -## Performance benchmarks - -### Fuzzy deduplication (8TB RedPajama v2) - -- **CPU (256 cores)**: 120 hours -- **GPU (8× A100)**: 7.5 hours -- **Speedup**: 16× - -### Exact deduplication (1TB) - -- **CPU (64 cores)**: 8 hours -- **GPU (4× A100)**: 0.5 hours -- **Speedup**: 16× - -### Quality filtering (100GB) - -- **CPU (32 cores)**: 2 hours -- **GPU (2× A100)**: 0.2 hours -- **Speedup**: 10× - -## Cost comparison - -**CPU-based curation** (AWS c5.18xlarge × 10): -- Cost: $3.60/hour × 10 = $36/hour -- Time for 8TB: 120 hours -- **Total**: $4,320 - -**GPU-based curation** (AWS p4d.24xlarge × 2): -- Cost: $32.77/hour × 2 = $65.54/hour -- Time for 8TB: 7.5 hours -- **Total**: $491.55 - -**Savings**: 89% reduction ($3,828 saved) - -## Supported data formats - -- **Input**: Parquet, JSONL, CSV -- **Output**: Parquet (recommended), JSONL -- **WebDataset**: TAR archives for multi-modal - -## Use cases - -**Production deployments**: -- NVIDIA used NeMo Curator to prepare Nemotron-4 training data -- Open-source datasets curated: RedPajama v2, The Pile - -## References - -- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics -- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods - -## Resources - -- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+ -- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/ -- **Version**: 0.4.0+ -- **License**: Apache 2.0 - - - diff --git a/skills/mlops/nemo-curator/references/deduplication.md b/skills/mlops/nemo-curator/references/deduplication.md deleted file mode 100644 index b3336c1c7..000000000 --- a/skills/mlops/nemo-curator/references/deduplication.md +++ /dev/null @@ -1,87 +0,0 @@ -# Deduplication Guide - -Complete guide to exact, fuzzy, and semantic deduplication. - -## Exact deduplication - -Remove documents with identical content. - -```python -from nemo_curator.modules import ExactDuplicates - -# Exact deduplication -exact_dedup = ExactDuplicates( - id_field="id", - text_field="text", - hash_method="md5" # or "sha256" -) - -deduped = exact_dedup(dataset) -``` - -**Performance**: ~16× faster on GPU vs CPU - -## Fuzzy deduplication - -Remove near-duplicate documents using MinHash + LSH. - -```python -from nemo_curator.modules import FuzzyDuplicates - -fuzzy_dedup = FuzzyDuplicates( - id_field="id", - text_field="text", - num_hashes=260, # MinHash permutations (more = accurate) - num_buckets=20, # LSH buckets (more = faster, less recall) - hash_method="md5", - jaccard_threshold=0.8 # Similarity threshold -) - -deduped = fuzzy_dedup(dataset) -``` - -**Parameters**: -- `num_hashes`: 128-512 (default 260) -- `num_buckets`: 10-50 (default 20) -- `jaccard_threshold`: 0.7-0.9 (default 0.8) - -**Performance**: 16× faster on 8TB dataset (120h → 7.5h) - -## Semantic deduplication - -Remove semantically similar documents using embeddings. - -```python -from nemo_curator.modules import SemanticDuplicates - -semantic_dedup = SemanticDuplicates( - id_field="id", - text_field="text", - embedding_model="sentence-transformers/all-MiniLM-L6-v2", - embedding_batch_size=256, - threshold=0.85, # Cosine similarity threshold - device="cuda" -) - -deduped = semantic_dedup(dataset) -``` - -**Models**: -- `all-MiniLM-L6-v2`: Fast, 384 dims -- `all-mpnet-base-v2`: Better quality, 768 dims -- Custom models supported - -## Comparison - -| Method | Speed | Recall | Use Case | -|--------|-------|--------|----------| -| Exact | Fastest | 100% | Exact matches only | -| Fuzzy | Fast | ~95% | Near-duplicates (recommended) | -| Semantic | Slow | ~90% | Paraphrases, rewrites | - -## Best practices - -1. **Start with exact dedup** - Remove obvious duplicates -2. **Use fuzzy for large datasets** - Best speed/quality trade-off -3. **Semantic for high-value data** - Expensive but thorough -4. **GPU acceleration required** - 10-16× speedup diff --git a/skills/mlops/nemo-curator/references/filtering.md b/skills/mlops/nemo-curator/references/filtering.md deleted file mode 100644 index 565160685..000000000 --- a/skills/mlops/nemo-curator/references/filtering.md +++ /dev/null @@ -1,102 +0,0 @@ -# Quality Filtering Guide - -Complete guide to NeMo Curator's 30+ quality filters. - -## Text-based filters - -### Word count - -```python -from nemo_curator.filters import WordCountFilter - -# Filter by word count -dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000)) -``` - -### Repeated content - -```python -from nemo_curator.filters import RepeatedLinesFilter - -# Remove documents with >30% repeated lines -dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3)) -``` - -### Symbol ratio - -```python -from nemo_curator.filters import SymbolToWordRatioFilter - -# Remove documents with too many symbols -dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3)) -``` - -### URL ratio - -```python -from nemo_curator.filters import UrlRatioFilter - -# Remove documents with many URLs -dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2)) -``` - -## Language filtering - -```python -from nemo_curator.filters import LanguageIdentificationFilter - -# Keep only English documents -dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"])) - -# Multiple languages -dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"])) -``` - -## Classifier-based filtering - -### Quality classifier - -```python -from nemo_curator.classifiers import QualityClassifier - -quality_clf = QualityClassifier( - model_path="nvidia/quality-classifier-deberta", - batch_size=256, - device="cuda" -) - -# Filter low-quality (threshold > 0.5 = high quality) -dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5) -``` - -### NSFW classifier - -```python -from nemo_curator.classifiers import NSFWClassifier - -nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda") - -# Remove NSFW content -dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9) -``` - -## Heuristic filters - -Full list of 30+ filters: -- WordCountFilter -- RepeatedLinesFilter -- UrlRatioFilter -- SymbolToWordRatioFilter -- NonAlphaNumericFilter -- BulletsFilter -- WhiteSpaceFilter -- ParenthesesFilter -- LongWordFilter -- And 20+ more... - -## Best practices - -1. **Apply cheap filters first** - Word count before GPU classifiers -2. **Tune thresholds on sample** - Test on 10k docs before full run -3. **Use GPU classifiers sparingly** - Expensive but effective -4. **Chain filters efficiently** - Order by cost (cheap → expensive) diff --git a/skills/mlops/pytorch-fsdp/SKILL.md b/skills/mlops/pytorch-fsdp/SKILL.md deleted file mode 100644 index 9e16f446f..000000000 --- a/skills/mlops/pytorch-fsdp/SKILL.md +++ /dev/null @@ -1,129 +0,0 @@ ---- -name: pytorch-fsdp -description: Expert guidance for Fully Sharded Data Parallel training with PyTorch FSDP - parameter sharding, mixed precision, CPU offloading, FSDP2 -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [torch>=2.0, transformers] -metadata: - hermes: - tags: [Distributed Training, PyTorch, FSDP, Data Parallel, Sharding, Mixed Precision, CPU Offloading, FSDP2, Large-Scale Training] - ---- - -# Pytorch-Fsdp Skill - -Comprehensive assistance with pytorch-fsdp development, generated from official documentation. - -## When to Use This Skill - -This skill should be triggered when: -- Working with pytorch-fsdp -- Asking about pytorch-fsdp features or APIs -- Implementing pytorch-fsdp solutions -- Debugging pytorch-fsdp code -- Learning pytorch-fsdp best practices - -## Quick Reference - -### Common Patterns - -**Pattern 1:** Generic Join Context Manager# Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025 The generic join context manager facilitates distributed training on uneven inputs. This page outlines the API of the relevant classes: Join, Joinable, and JoinHook. For a tutorial, see Distributed Training with Uneven Inputs Using the Join Context Manager. class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source]# This class defines the generic join context manager, which allows custom hooks to be called after a process joins. These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition. Warning The context manager requires each participating Joinable to call the method notify_join_context() before its own per- iteration collective communications to ensure correctness. Warning The context manager requires that all process_group attributes in the JoinHook objects are the same. If there are multiple JoinHook objects, then the device of the first is used. The process group and device information is used for checking for non- joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which using an all- reduce. Parameters joinables (List[Joinable]) – a list of the participating Joinable s; their hooks are iterated over in the given order. enable (bool) – a flag enabling uneven input detection; setting to False disables the context manager’s functionality and should only be set when the user knows the inputs will not be uneven (default: True). throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False). Example: >>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring static notify_join_context(joinable)[source]# Notifies the join context manager that the calling process has not yet joined. Then, if throw_on_early_termination=True, checks if uneven inputs have been detected (i.e. if one process has already joined) and throws an exception if so. This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel. Only the first Joinable object passed into the context manager performs the collective communications in this method, and for the others, this method is vacuous. Parameters joinable (Joinable) – the Joinable object calling this method. Returns An async work handle for the all-reduce meant to notify the context manager that the process has not yet joined if joinable is the first one passed into the context manager; None otherwise. class torch.distributed.algorithms.Joinable[source]# This defines an abstract base class for joinable classes. A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to join_device() and join_process_group() that return device and process group information, respectively. abstract property join_device: device# Return the device from which to perform collective communications needed by the join context manager. abstract join_hook(**kwargs)[source]# Return a JoinHook instance for the given Joinable. Parameters kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs. Return type JoinHook abstract property join_process_group: Any# Returns the process group for the collective communications needed by the join context manager itself. class torch.distributed.algorithms.JoinHook[source]# This defines a join hook, which provides two entry points in the join context manager. Entry points : a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined. To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate. main_hook()[source]# Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. Training iteration i.e., in one forward pass, backward pass, and optimizer step. post_hook(is_last_joiner)[source]# Call hook after all processes have joined. It is passed an additional bool argument is_last_joiner, which indicates if the rank is one of the last to join. Parameters is_last_joiner (bool) – True if the rank is one of the last to join; False otherwise. - -``` -Join -``` - -**Pattern 2:** Distributed communication package - torch.distributed# Created On: Jul 12, 2017 | Last Updated On: Sep 04, 2025 Note Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training. Backends# torch.distributed supports four built-in backends, each with different capabilities. The table below shows which functions are available for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU while for XCCL to XPU GPU. MPI supports CUDA only if the implementation used to build PyTorch supports it. Backend gloo mpi nccl xccl Device CPU GPU CPU GPU CPU GPU CPU GPU send ✓ ✘ ✓ ? ✘ ✓ ✘ ✓ recv ✓ ✘ ✓ ? ✘ ✓ ✘ ✓ broadcast ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ all_reduce ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ reduce ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ all_gather ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ gather ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ scatter ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ reduce_scatter ✓ ✓ ✘ ✘ ✘ ✓ ✘ ✓ all_to_all ✓ ✓ ✓ ? ✘ ✓ ✘ ✓ barrier ✓ ✘ ✓ ? ✘ ✓ ✘ ✓ Backends that come with PyTorch# PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). By default for Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source. (e.g. building PyTorch on a host that has MPI installed.) Note As of PyTorch v1.8, Windows supports all collective communications backend but NCCL, If the init_method argument of init_process_group() points to a file it must adhere to the following schema: Local file system, init_method="file:///d:/tmp/some_file" Shared file system, init_method="file://////{machine_name}/{share_folder_name}/some_file" Same as on Linux platform, you can enable TcpStore by setting environment variables, MASTER_ADDR and MASTER_PORT. Which backend to use?# In the past, we were often asked: “which backend should I use?”. Rule of thumb Use the NCCL backend for distributed training with CUDA GPU. Use the XCCL backend for distributed training with XPU GPU. Use the Gloo backend for distributed training with CPU. GPU hosts with InfiniBand interconnect Use NCCL, since it’s the only backend that currently supports InfiniBand and GPUDirect. GPU hosts with Ethernet interconnect Use NCCL, since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training. If you encounter any problem with NCCL, use Gloo as the fallback option. (Note that Gloo currently runs slower than NCCL for GPUs.) CPU hosts with InfiniBand interconnect If your InfiniBand has enabled IP over IB, use Gloo, otherwise, use MPI instead. We are planning on adding InfiniBand support for Gloo in the upcoming releases. CPU hosts with Ethernet interconnect Use Gloo, unless you have specific reasons to use MPI. Common environment variables# Choosing the network interface to use# By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend): NCCL_SOCKET_IFNAME, for example export NCCL_SOCKET_IFNAME=eth0 GLOO_SOCKET_IFNAME, for example export GLOO_SOCKET_IFNAME=eth0 If you’re using the Gloo backend, you can specify multiple interfaces by separating them by a comma, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. It is imperative that all processes specify the same number of interfaces in this variable. Other NCCL environment variables# Debugging - in case of NCCL failure, you can set NCCL_DEBUG=INFO to print an explicit warning message as well as basic NCCL initialization information. You may also use NCCL_DEBUG_SUBSYS to get more details about a specific aspect of NCCL. For example, NCCL_DEBUG_SUBSYS=COLL would print logs of collective calls, which may be helpful when debugging hangs, especially those caused by collective type or message size mismatch. In case of topology detection failure, it would be helpful to set NCCL_DEBUG_SUBSYS=GRAPH to inspect the detailed detection result and save as reference if further help from NCCL team is needed. Performance tuning - NCCL performs automatic tuning based on its topology detection to save users’ tuning effort. On some socket-based systems, users may still try tuning NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD to increase socket network bandwidth. These two environment variables have been pre-tuned by NCCL for some cloud providers, such as AWS or GCP. For a full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation You can tune NCCL communicators even further using torch.distributed.ProcessGroupNCCL.NCCLConfig and torch.distributed.ProcessGroupNCCL.Options. Learn more about them using help (e.g. help(torch.distributed.ProcessGroupNCCL.NCCLConfig)) in the interpreter. Basics# The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel() in that it supports multiple network-connected machines and in that the user must explicitly launch a separate copy of the main training script for each process. In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel(): Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes. Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components. Initialization# The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined. Warning Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs. torch.distributed.is_available()[source]# Return True if the distributed package is available. Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for MacOS. Return type bool torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)[source]# Initialize the default distributed process group. This will also initialize the distributed package. There are 2 main ways to initialize a process group: Specify store, rank, and world_size explicitly. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them. If neither is specified, init_method is assumed to be “env://”. Parameters backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, nccl, ucc, xccl or one that is registered by a third-party plugin. Since 2.6, if backend is not provided, c10d will use a backend registered for the device type indicated by the device_id kwarg (if provided). The known default registrations today are: nccl for cuda, gloo for cpu, xccl for xpu. If neither backend nor device_id is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or cpu). This field can be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlock or NCCL invalid usage. ucc backend is experimental. Default backend for the device can be queried with get_default_backend_for_device(). init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store. world_size (int, optional) – Number of processes participating in the job. Required if store is specified. rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified. store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method. timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. This is the duration after which collectives will be aborted asynchronously and the process will crash. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. group_name (str, optional, deprecated) – Group name. This argument is ignored pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. As of now, the only options we support is ProcessGroupNCCL.Options for the nccl backend, is_high_priority_stream can be specified so that the nccl backend can pick up high priority cuda streams when there’re compute kernels waiting. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t device_id (torch.device | int, optional) – a single, specific device this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is immediately formed (calling ncclCommInit* immediately rather than the normal lazy call) and sub-groups will use ncclCommSplit when possible to avoid unnecessary overhead of group creation. If you want to know NCCL initialization error early, you can also use this field. If an int is provided, the API assumes that the accelerator type at compile time will be used. Note To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI. Note Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “:,:”, e.g. “cpu:gloo,cuda:custom_backend”. torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None, backend_override=None)[source]# Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters. This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i]. Note init_device_mesh follows SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. Inconsistent mesh_shape may lead to hanging. Note If no process group is found, init_device_mesh will initialize distributed process group/groups required for distributed communications behind the scene. Parameters device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”, “xpu”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed. mesh_shape (Tuple[int]) – A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. mesh_dim_names (Tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional) – Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value). Returns A DeviceMesh object representing the device layout. Return type DeviceMesh Example: >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) torch.distributed.is_initialized()[source]# Check if the default process group has been initialized. Return type bool torch.distributed.is_mpi_available()[source]# Check if the MPI backend is available. Return type bool torch.distributed.is_nccl_available()[source]# Check if the NCCL backend is available. Return type bool torch.distributed.is_gloo_available()[source]# Check if the Gloo backend is available. Return type bool torch.distributed.distributed_c10d.is_xccl_available()[source]# Check if the XCCL backend is available. Return type bool torch.distributed.is_torchelastic_launched()[source]# Check whether this process was launched with torch.distributed.elastic (aka torchelastic). The existence of TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id which is always a non-null value indicating the job id for peer discovery purposes.. Return type bool torch.distributed.get_default_backend_for_device(device)[source]# Return the default backend for the given device. Parameters device (Union[str, torch.device]) – The device to get the default backend for. Returns The default backend for the given device as a lower case string. Return type str Currently three initialization methods are supported: TCP initialization# There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks. Note that multicast address is not supported anymore in the latest distributed package. group_name is deprecated as well. import torch.distributed as dist # Use address of one of the machines dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4) Shared file-system initialization# Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but will not delete the file. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name. Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well. Warning This method assumes that the file system supports locking using fcntl - most local systems and NFS support it. Warning This method will always create the file and try its best to clean up and remove the file at the end of the program. In other words, each initialization with the file init method will need a brand new empty file in order for the initialization to succeed. If the same file used by the previous initialization (which happens not to get cleaned up) is used again, this is unexpected behavior and can often cause deadlocks and failures. Therefore, even though this method will try its best to clean up the file, if the auto-delete happens to be unsuccessful, it is your responsibility to ensure that the file is removed at the end of the training to prevent the same file to be reused again during the next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. In other words, if the file is not removed/cleaned up and you call init_process_group() again on that file, failures are expected. The rule of thumb here is that, make sure that the file is non-existent or empty every time init_process_group() is called. import torch.distributed as dist # rank should always be specified dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile', world_size=4, rank=args.rank) Environment variable initialization# This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are: MASTER_PORT - required; has to be a free port on machine with rank 0 MASTER_ADDR - required (except for rank 0); address of rank 0 node WORLD_SIZE - required; can be set either here, or in a call to init function RANK - required; can be set either here, or in a call to init function The machine with rank 0 will be used to set up all connections. This is the default method, meaning that init_method does not have to be specified (or can be env://). Improving initialization time# TORCH_GLOO_LAZY_INIT - establishes connections on demand rather than using a full mesh which can greatly improve initialization time for non all2all operations. Post-Initialization# Once torch.distributed.init_process_group() was run, the following functions can be used. To check whether the process group has already been initialized use torch.distributed.is_initialized(). class torch.distributed.Backend(name)[source]# An enum-like class for backends. Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL. This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo". Note The entry Backend.UNDEFINED is present but only used as initial value of some fields. Users should neither use it directly nor assume its existence. classmethod register_backend(name, func, extended_api=False, devices=None)[source]# Register a new backend with the given name and instantiating function. This class method is used by 3rd party ProcessGroup extension to register new backends. Parameters name (str) – Backend name of the ProcessGroup extension. It should match the one in init_process_group(). func (function) – Function handler that instantiates the backend. The function should be implemented in the backend extension and takes four arguments, including store, rank, world_size, and timeout. extended_api (bool, optional) – Whether the backend supports extended argument structure. Default: False. If set to True, the backend will get an instance of c10d::DistributedBackendOptions, and a process group options object as defined by the backend implementation. device (str or list of str, optional) – device type this backend supports, e.g. “cpu”, “cuda”, etc. If None, assuming both “cpu” and “cuda” Note This support of 3rd party backend is experimental and subject to change. torch.distributed.get_backend(group=None)[source]# Return the backend of the given process group. Parameters group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group. Returns The backend of the given process group as a lower case string. Return type Backend torch.distributed.get_rank(group=None)[source]# Return the rank of the current process in the provided group, default otherwise. Rank is a unique identifier assigned to each process within a distributed process group. They are always consecutive integers ranging from 0 to world_size. Parameters group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Returns The rank of the process group -1, if not part of the group Return type int torch.distributed.get_world_size(group=None)[source]# Return the number of processes in the current process group. Parameters group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Returns The world size of the process group -1, if not part of the group Return type int Shutdown# It is important to clean up resources on exit by calling destroy_process_group(). The simplest pattern to follow is to destroy every process group and backend by calling destroy_process_group() with the default value of None for the group argument, at a point in the training script where communications are no longer needed, usually near the end of main(). The call should be made once per trainer-process, not at the outer process-launcher level. if destroy_process_group() is not called by all ranks in a pg within the timeout duration, especially when there are multiple process-groups in the application e.g. for N-D parallelism, hangs on exit are possible. This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, which must be called collectively, but the order of calling ProcessGroupNCCL’s destructor if called by python’s GC is not deterministic. Calling destroy_process_group() helps by ensuring ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort during ProcessGroupNCCL’s destructor. Reinitialization# destroy_process_group can also be used to destroy individual process groups. One use case could be fault tolerant training, where a process group may be destroyed and then a new one initialized during runtime. In this case, it’s critical to synchronize the trainer processes using some means other than torch.distributed primitives _after_ calling destroy and before subsequently initializing. This behavior is currently unsupported/untested, due to the difficulty of achieving this synchronization, and is considered a known issue. Please file a github issue or RFC if this is a use case that’s blocking you. Groups# By default collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns). torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)[source]# Create a new distributed group. This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes. Warning Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks. If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering. When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other cuda streams by calling work.wait() before using another process group. See Using multiple NCCL communicators concurrently for more details. Parameters ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None. timeout (timedelta, optional) – see init_process_group for details and default value. backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If None is passed in, the backend corresponding to the default process group will be used. Default is None. pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. i.e. for the nccl backend, is_high_priority_stream can be specified so that process group can pick up high priority cuda streams. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization (bool, optional): perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don’t need to call into API and don’t join the barrier. group_desc (str, optional) – a string to describe the process group. device_id (torch.device, optional) – a single, specific device to “bind” this process to, The new_group call will try to initialize a communication backend immediately for the device if this field is given. Returns A handle of distributed group that can be given to collective calls or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks. N.B. use_local_synchronization doesn’t work with MPI. N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior as non-member ranks don’t join the group barrier(). N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order. torch.distributed.get_group_rank(group, global_rank)[source]# Translate a global rank into a group rank. global_rank must be part of group otherwise this raises RuntimeError. Parameters group (ProcessGroup) – ProcessGroup to find the relative rank. global_rank (int) – Global rank to query. Returns Group rank of global_rank relative to group Return type int N.B. calling this function on the default process group returns identity torch.distributed.get_global_rank(group, group_rank)[source]# Translate a group rank into a global rank. group_rank must be part of group otherwise this raises RuntimeError. Parameters group (ProcessGroup) – ProcessGroup to find the global rank from. group_rank (int) – Group rank to query. Returns Global rank of group_rank relative to group Return type int N.B. calling this function on the default process group returns identity torch.distributed.get_process_group_ranks(group)[source]# Get all ranks associated with group. Parameters group (Optional[ProcessGroup]) – ProcessGroup to get all ranks from. If None, the default process group will be used. Returns List of global ranks ordered by group rank. Return type list[int] DeviceMesh# DeviceMesh is a higher level abstraction that manages process groups (or NCCL communicators). It allows user to easily create inter node and intra node process groups without worrying about how to set up the ranks correctly for different sub process groups, and it helps manage those distributed process group easily. init_device_mesh() function can be used to create new DeviceMesh, with a mesh shape describing the device topology. class torch.distributed.device_mesh.DeviceMesh(device_type, mesh, *, mesh_dim_names=None, backend_override=None, _init_backend=True)[source]# DeviceMesh represents a mesh of devices, where layout of devices could be represented as a n-d dimension array, and each value of the n-d dimensional array is the global id of the default process group ranks. DeviceMesh could be used to setup the N dimensional device connections across the cluster, and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects already (i.e. if user call torch.cuda.set_device before the DeviceMesh initialization), and will select/set the device for the current process if user does not set the device beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. DeviceMesh can also be used as a context manager when using together with DTensor APIs. Note DeviceMesh follows SPMD programming model, which means the same PyTorch Python program is running on all processes/ranks in the cluster. Therefore, users need to make sure the mesh array (which describes the layout of devices) should be identical across all ranks. Inconsistent mesh will lead to silent hang. Parameters device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. mesh (ndarray) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. Returns A DeviceMesh object representing the device layout. Return type DeviceMesh The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. A reduction over the first dimension of mesh will reduce across columns (0, 4), .. and (3, 7), a reduction over the second dimension of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). Example: >>> from torch.distributed.device_mesh import DeviceMesh >>> >>> # Initialize device mesh as (2, 4) to represent the topology >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) static from_group(group, device_type, mesh=None, *, mesh_dim_names=None)[source]# Constructs a DeviceMesh with device_type from an existing ProcessGroup or a list of existing ProcessGroup. The constructed device mesh has number of dimensions equal to the number of groups passed. For example, if a single process group is passed in, the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, the resulted DeviceMesh is a 2D mesh. If more than one group is passed, then the mesh and mesh_dim_names arguments are required. The order of the process groups passed in determines the topology of the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. The mesh tensor passed in must have the same number of dimensions as the number of process groups passed in, and the order of the dimensions in the mesh tensor must match the order in the process groups passed in. Parameters group (ProcessGroup or list[ProcessGroup]) – the existing ProcessGroup or a list of existing ProcessGroups. device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed. mesh (torch.Tensor or ArrayLike, optional) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. Default is None. mesh_dim_names (tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. Default is None. Returns A DeviceMesh object representing the device layout. Return type DeviceMesh get_all_groups()[source]# Returns a list of ProcessGroups for all mesh dimensions. Returns A list of ProcessGroup object. Return type list[torch.distributed.distributed_c10d.ProcessGroup] get_coordinate()[source]# Return the relative indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None. Return type Optional[list[int]] get_group(mesh_dim=None)[source]# Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. Parameters mesh_dim (str/python:int, optional) – it can be the name of the mesh dimension or the index None. (of the mesh dimension. Default is) – Returns A ProcessGroup object. Return type ProcessGroup get_local_rank(mesh_dim=None)[source]# Returns the local rank of the given mesh_dim of the DeviceMesh. Parameters mesh_dim (str/python:int, optional) – it can be the name of the mesh dimension or the index None. (of the mesh dimension. Default is) – Returns An integer denotes the local rank. Return type int The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. Example: >>> from torch.distributed.device_mesh import DeviceMesh >>> >>> # Initialize device mesh as (2, 4) to represent the topology >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) get_rank()[source]# Returns the current global rank. Return type int Point-to-point communication# torch.distributed.send(tensor, dst=None, group=None, tag=0, group_dst=None)[source]# Send a tensor synchronously. Warning tag is not supported with the NCCL backend. Parameters tensor (Tensor) – Tensor to send. dst (int) – Destination rank on global process group (regardless of group argument). Destination rank should not be the same as the rank of the current process. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. tag (int, optional) – Tag to match send with remote recv group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst. torch.distributed.recv(tensor, src=None, group=None, tag=0, group_src=None)[source]# Receives a tensor synchronously. Warning tag is not supported with the NCCL backend. Parameters tensor (Tensor) – Tensor to fill with received data. src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. tag (int, optional) – Tag to match recv with remote send group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. Returns Sender rank -1, if not part of the group Return type int isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified as they should never be created manually, but they are guaranteed to support two methods: is_completed() - returns True if the operation has finished wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns. torch.distributed.isend(tensor, dst=None, group=None, tag=0, group_dst=None)[source]# Send a tensor asynchronously. Warning Modifying tensor before the request completes causes undefined behavior. Warning tag is not supported with the NCCL backend. Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self. Parameters tensor (Tensor) – Tensor to send. dst (int) – Destination rank on global process group (regardless of group argument) group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. tag (int, optional) – Tag to match send with remote recv group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst Returns A distributed request object. None, if not part of the group Return type Optional[Work] torch.distributed.irecv(tensor, src=None, group=None, tag=0, group_src=None)[source]# Receives a tensor asynchronously. Warning tag is not supported with the NCCL backend. Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self. Parameters tensor (Tensor) – Tensor to fill with received data. src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. tag (int, optional) – Tag to match recv with remote send group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. Returns A distributed request object. None, if not part of the group Return type Optional[Work] torch.distributed.send_object_list(object_list, dst=None, group=None, device=None, group_dst=None, use_batch=False)[source]# Sends picklable objects in object_list synchronously. Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent. Parameters object_list (List[Any]) – List of input objects to sent. Each object must be picklable. Receiver must provide lists of equal sizes. dst (int) – Destination rank to send object_list to. Destination rank is based on global process group (regardless of group argument) group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before sending. Default is None. group_dst (int, optional) – Destination rank on group. Must specify one of dst and group_dst but not both use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False. Returns None. Note For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning send_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling send_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> # Assumes backend is not NCCL >>> device = torch.device("cpu") >>> if dist.get_rank() == 0: >>> # Assumes world_size of 2. >>> objects = ["foo", 12, {1: 2}] # any picklable object >>> dist.send_object_list(objects, dst=1, device=device) >>> else: >>> objects = [None, None, None] >>> dist.recv_object_list(objects, src=0, device=device) >>> objects ['foo', 12, {1: 2}] torch.distributed.recv_object_list(object_list, src=None, group=None, device=None, group_src=None, use_batch=False)[source]# Receives picklable objects in object_list synchronously. Similar to recv(), but can receive Python objects. Parameters object_list (List[Any]) – List of objects to receive into. Must provide a list of sizes equal to the size of the list being sent. src (int, optional) – Source rank from which to recv object_list. Source rank is based on global process group (regardless of group argument) Will receive from any rank if set to None. Default is None. group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. device (torch.device, optional) – If not None, receives on this device. Default is None. group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False. Returns Sender rank. -1 if rank is not part of the group. If rank is part of the group, object_list will contain the sent objects from src rank. Note For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning recv_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling recv_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> # Assumes backend is not NCCL >>> device = torch.device("cpu") >>> if dist.get_rank() == 0: >>> # Assumes world_size of 2. >>> objects = ["foo", 12, {1: 2}] # any picklable object >>> dist.send_object_list(objects, dst=1, device=device) >>> else: >>> objects = [None, None, None] >>> dist.recv_object_list(objects, src=0, device=device) >>> objects ['foo', 12, {1: 2}] torch.distributed.batch_isend_irecv(p2p_op_list)[source]# Send or Receive a batch of tensors asynchronously and return a list of requests. Process each of the operations in p2p_op_list and return the corresponding requests. NCCL, Gloo, and UCC backend are currently supported. Parameters p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – A list of point-to-point operations(type of each operator is torch.distributed.P2POp). The order of the isend/irecv in the list matters and it needs to match with corresponding isend/irecv on the remote end. Returns A list of distributed request objects returned by calling the corresponding op in the op_list. Return type list[torch.distributed.distributed_c10d.Work] Examples >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank >>> recv_tensor = torch.randn(2, dtype=torch.float32) >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size) >>> recv_op = dist.P2POp( ... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size ... ) >>> reqs = batch_isend_irecv([send_op, recv_op]) >>> for req in reqs: >>> req.wait() >>> recv_tensor tensor([2, 3]) # Rank 0 tensor([0, 1]) # Rank 1 Note Note that when this API is used with the NCCL PG backend, users must set the current GPU device with torch.cuda.set_device, otherwise it will lead to unexpected hang issues. In addition, if this API is the first collective call in the group passed to dist.P2POp, all ranks of the group must participate in this API call; otherwise, the behavior is undefined. If this API call is not the first collective call in the group, batched P2P operations involving only a subset of ranks of the group are allowed. class torch.distributed.P2POp(op, tensor, peer=None, group=None, tag=0, group_peer=None)[source]# A class to build point-to-point operations for batch_isend_irecv. This class builds the type of P2P operation, communication buffer, peer rank, Process Group, and tag. Instances of this class will be passed to batch_isend_irecv for point-to-point communications. Parameters op (Callable) – A function to send data to or receive data from a peer process. The type of op is either torch.distributed.isend or torch.distributed.irecv. tensor (Tensor) – Tensor to send or receive. peer (int, optional) – Destination or source rank. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. tag (int, optional) – Tag to match send with recv. group_peer (int, optional) – Destination or source rank. Synchronous and asynchronous collective operations# Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective: Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the below script to see examples of differences in these semantics for CPU and CUDA operations. Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually and it is guaranteed to support two methods: is_completed() - in the case of CPU collectives, returns True if completed. In the case of CUDA operations, returns True if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization. wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU). get_future() - returns torch._C.Future object. Supported for NCCL, also supported for most operations on GLOO and MPI, except for peer to peer operations. Note: as we continue adopting Futures and merging APIs, get_future() call might become redundant. Example The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. It shows the explicit need to synchronize when using collective outputs on different CUDA streams: # Code runs on each rank. dist.init_process_group("nccl", rank=rank, world_size=2) output = torch.tensor([rank]).cuda(rank) s = torch.cuda.Stream() handle = dist.all_reduce(output, async_op=True) # Wait ensures the operation is enqueued, but not necessarily complete. handle.wait() # Using result on non-default stream. with torch.cuda.stream(s): s.wait_stream(torch.cuda.default_stream()) output.add_(100) if rank == 0: # if the explicit call to wait_stream was omitted, the output below will be # non-deterministically 1 or 101, depending on whether the allreduce overwrote # the value after the add completed. print(output) Collective functions# torch.distributed.broadcast(tensor, src=None, group=None, async_op=False, group_src=None)[source]# Broadcasts the tensor to the whole group. tensor must have the same number of elements in all processes participating in the collective. Parameters tensor (Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise. src (int) – Source rank on global process group (regardless of group argument). group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op group_src (int) – Source rank on group. Must specify one of group_src and src but not both. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group torch.distributed.broadcast_object_list(object_list, src=None, group=None, device=None, group_src=None)[source]# Broadcasts picklable objects in object_list to the whole group. Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcasted. Parameters object_list (List[Any]) – List of input objects to broadcast. Each object must be picklable. Only objects on the src rank will be broadcast, but each rank must provide lists of equal sizes. src (int) – Source rank from which to broadcast object_list. Source rank is based on global process group (regardless of group argument) group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before broadcasting. Default is None. group_src (int) – Source rank on group. Must not specify one of group_src and src but not both. Returns None. If rank is part of the group, object_list will contain the broadcasted objects from src rank. Note For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). Note Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call. Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning broadcast_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling broadcast_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> if dist.get_rank() == 0: >>> # Assumes world_size of 3. >>> objects = ["foo", 12, {1: 2}] # any picklable object >>> else: >>> objects = [None, None, None] >>> # Assumes backend is not NCCL >>> device = torch.device("cpu") >>> dist.broadcast_object_list(objects, src=0, device=device) >>> objects ['foo', 12, {1: 2}] torch.distributed.all_reduce(tensor, op=, group=None, async_op=False)[source]# Reduces the tensor data across all machines in a way that all get the final result. After the call tensor is going to be bitwise identical in all processes. Complex tensors are supported. Parameters tensor (Tensor) – Input and output of the collective. The function operates in-place. op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Examples >>> # All tensors below are of torch.int64 type. >>> # We have 2 process groups, 2 ranks. >>> device = torch.device(f"cuda:{rank}") >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor tensor([1, 2], device='cuda:0') # Rank 0 tensor([3, 4], device='cuda:1') # Rank 1 >>> dist.all_reduce(tensor, op=ReduceOp.SUM) >>> tensor tensor([4, 6], device='cuda:0') # Rank 0 tensor([4, 6], device='cuda:1') # Rank 1 >>> # All tensors below are of torch.cfloat type. >>> # We have 2 process groups, 2 ranks. >>> tensor = torch.tensor( ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device ... ) + 2 * rank * (1 + 1j) >>> tensor tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 >>> dist.all_reduce(tensor, op=ReduceOp.SUM) >>> tensor tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0 tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1 torch.distributed.reduce(tensor, dst=None, op=, group=None, async_op=False, group_dst=None)[source]# Reduces the tensor data across all machines. Only the process with rank dst is going to receive the final result. Parameters tensor (Tensor) – Input and output of the collective. The function operates in-place. dst (int) – Destination rank on global process group (regardless of group argument) op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op group_dst (int) – Destination rank on group. Must specify one of group_dst and dst but not both. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)[source]# Gathers tensors from the whole group in a list. Complex and uneven sized tensors are supported. Parameters tensor_list (list[Tensor]) – Output list. It should contain correctly-sized tensors to be used for output of the collective. Uneven sized tensors are supported. tensor (Tensor) – Tensor to be broadcast from current process. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Examples >>> # All tensors below are of torch.int64 dtype. >>> # We have 2 process groups, 2 ranks. >>> device = torch.device(f"cuda:{rank}") >>> tensor_list = [ ... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2) ... ] >>> tensor_list [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor tensor([1, 2], device='cuda:0') # Rank 0 tensor([3, 4], device='cuda:1') # Rank 1 >>> dist.all_gather(tensor_list, tensor) >>> tensor_list [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 >>> # All tensors below are of torch.cfloat dtype. >>> # We have 2 process groups, 2 ranks. >>> tensor_list = [ ... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2) ... ] >>> tensor_list [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 >>> tensor = torch.tensor( ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device ... ) + 2 * rank * (1 + 1j) >>> tensor tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 >>> dist.all_gather(tensor_list, tensor) >>> tensor_list [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False)[source]# Gather tensors from all ranks and put them in a single output tensor. This function requires all tensors to be the same size on each process. Parameters output_tensor (Tensor) – Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the following forms: (i) a concatenation of all the input tensors along the primary dimension; for definition of “concatenation”, see torch.cat(); (ii) a stack of all the input tensors along the primary dimension; for definition of “stack”, see torch.stack(). Examples below may better explain the supported output forms. input_tensor (Tensor) – Tensor to be gathered from current rank. Different from the all_gather API, the input tensors in this API must have the same size across all ranks. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Examples >>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # We have two ranks. >>> device = torch.device(f"cuda:{rank}") >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor_in tensor([1, 2], device='cuda:0') # Rank 0 tensor([3, 4], device='cuda:1') # Rank 1 >>> # Output in concatenation form >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device) >>> dist.all_gather_into_tensor(tensor_out, tensor_in) >>> tensor_out tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 >>> # Output in stack form >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device) >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) >>> tensor_out2 tensor([[1, 2], [3, 4]], device='cuda:0') # Rank 0 tensor([[1, 2], [3, 4]], device='cuda:1') # Rank 1 torch.distributed.all_gather_object(object_list, obj, group=None)[source]# Gathers picklable objects from the whole group into a list. Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered. Parameters object_list (list[Any]) – Output list. It should be correctly sized as the size of the group for this collective and will contain the output. obj (Any) – Pickable Python object to be broadcast from current process. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None. Returns None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed in object_list will be unmodified. Note Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call. Note For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning all_gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using all_gather() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> # Assumes world_size of 3. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object >>> output = [None for _ in gather_objects] >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) >>> output ['foo', 12, {1: 2}] torch.distributed.gather(tensor, gather_list=None, dst=None, group=None, async_op=False, group_dst=None)[source]# Gathers a list of tensors in a single process. This function requires all tensors to be the same size on each process. Parameters tensor (Tensor) – Input tensor. gather_list (list[Tensor], optional) – List of appropriately, same-sized tensors to use for gathered data (default is None, must be specified on the destination rank) dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0) group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Note Note that all Tensors in gather_list must have the same size. Example::>>> # We have 2 process groups, 2 ranks. >>> tensor_size = 2 >>> device = torch.device(f'cuda:{rank}') >>> tensor = torch.ones(tensor_size, device=device) + rank >>> if dist.get_rank() == 0: >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] >>> else: >>> gather_list = None >>> dist.gather(tensor, gather_list, dst=0) >>> # Rank 0 gets gathered data. >>> gather_list [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 None # Rank 1 torch.distributed.gather_object(obj, object_gather_list=None, dst=None, group=None, group_dst=None)[source]# Gathers picklable objects from the whole group in a single process. Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered. Parameters obj (Any) – Input object. Must be picklable. object_gather_list (list[Any]) – Output list. On the dst rank, it should be correctly sized as the size of the group for this collective and will contain the output. Must be None on non-dst ranks. (default is None) dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0) group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst Returns None. On the dst rank, object_gather_list will contain the output of the collective. Note Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call. Note For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using gather() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> # Assumes world_size of 3. >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object >>> output = [None for _ in gather_objects] >>> dist.gather_object( ... gather_objects[dist.get_rank()], ... output if dist.get_rank() == 0 else None, ... dst=0 ... ) >>> # On rank 0 >>> output ['foo', 12, {1: 2}] torch.distributed.scatter(tensor, scatter_list=None, src=None, group=None, async_op=False, group_src=None)[source]# Scatters a list of tensors to all processes in a group. Each process will receive exactly one tensor and store its data in the tensor argument. Complex tensors are supported. Parameters tensor (Tensor) – Output tensor. scatter_list (list[Tensor]) – List of tensors to scatter (default is None, must be specified on the source rank) src (int) – Source rank on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0) group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Note Note that all Tensors in scatter_list must have the same size. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> tensor_size = 2 >>> device = torch.device(f'cuda:{rank}') >>> output_tensor = torch.zeros(tensor_size, device=device) >>> if dist.get_rank() == 0: >>> # Assumes world_size of 2. >>> # Only tensors, all of which must be the same size. >>> t_ones = torch.ones(tensor_size, device=device) >>> t_fives = torch.ones(tensor_size, device=device) * 5 >>> scatter_list = [t_ones, t_fives] >>> else: >>> scatter_list = None >>> dist.scatter(output_tensor, scatter_list, src=0) >>> # Rank i gets scatter_list[i]. >>> output_tensor tensor([1., 1.], device='cuda:0') # Rank 0 tensor([5., 5.], device='cuda:1') # Rank 1 torch.distributed.scatter_object_list(scatter_object_output_list, scatter_object_input_list=None, src=None, group=None, group_src=None)[source]# Scatters picklable objects in scatter_object_input_list to the whole group. Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered. Parameters scatter_object_output_list (List[Any]) – Non-empty list whose first element will store the object scattered to this rank. scatter_object_input_list (List[Any], optional) – List of input objects to scatter. Each object must be picklable. Only objects on the src rank will be scattered, and the argument can be None for non-src ranks. src (int) – Source rank from which to scatter scatter_object_input_list. Source rank is based on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0) group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src Returns None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank. Note Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call. Warning Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. Warning scatter_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. Warning Calling scatter_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using scatter() instead. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> if dist.get_rank() == 0: >>> # Assumes world_size of 3. >>> objects = ["foo", 12, {1: 2}] # any picklable object >>> else: >>> # Can be any list on non-src ranks, elements are not used. >>> objects = [None, None, None] >>> output_list = [None] >>> dist.scatter_object_list(output_list, objects, src=0) >>> # Rank i gets objects[i]. For example, on rank 2: >>> output_list [{1: 2}] torch.distributed.reduce_scatter(output, input_list, op=, group=None, async_op=False)[source]# Reduces, then scatters a list of tensors to all processes in a group. Parameters output (Tensor) – Output tensor. input_list (list[Tensor]) – List of tensors to reduce and scatter. op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. torch.distributed.reduce_scatter_tensor(output, input, op=, group=None, async_op=False)[source]# Reduces, then scatters a tensor to all ranks in a group. Parameters output (Tensor) – Output tensor. It should have the same size across all ranks. input (Tensor) – Input tensor to be reduced and scattered. Its size should be output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For definition of “concatenation”, see torch.cat(). For definition of “stack”, see torch.stack(). group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. Examples >>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # We have two ranks. >>> device = torch.device(f"cuda:{rank}") >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) >>> # Input in concatenation form >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) >>> tensor_in tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) >>> tensor_out tensor([0, 2], device='cuda:0') # Rank 0 tensor([4, 6], device='cuda:1') # Rank 1 >>> # Input in stack form >>> tensor_in = torch.reshape(tensor_in, (world_size, 2)) >>> tensor_in tensor([[0, 1], [2, 3]], device='cuda:0') # Rank 0 tensor([[0, 1], [2, 3]], device='cuda:1') # Rank 1 >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) >>> tensor_out tensor([0, 2], device='cuda:0') # Rank 0 tensor([4, 6], device='cuda:1') # Rank 1 torch.distributed.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False)[source]# Split input tensor and then scatter the split list to all processes in a group. Later the received tensors are concatenated from all the processes in the group and returned as a single output tensor. Complex tensors are supported. Parameters output (Tensor) – Gathered concatenated output tensor. input (Tensor) – Input tensor to scatter. output_split_sizes – (list[Int], optional): Output split sizes for dim 0 if specified None or empty, dim 0 of output tensor must divide equally by world_size. input_split_sizes – (list[Int], optional): Input split sizes for dim 0 if specified None or empty, dim 0 of input tensor must divide equally by world_size. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. Warning all_to_all_single is experimental and subject to change. Examples >>> input = torch.arange(4) + rank * 4 >>> input tensor([0, 1, 2, 3]) # Rank 0 tensor([4, 5, 6, 7]) # Rank 1 tensor([8, 9, 10, 11]) # Rank 2 tensor([12, 13, 14, 15]) # Rank 3 >>> output = torch.empty([4], dtype=torch.int64) >>> dist.all_to_all_single(output, input) >>> output tensor([0, 4, 8, 12]) # Rank 0 tensor([1, 5, 9, 13]) # Rank 1 tensor([2, 6, 10, 14]) # Rank 2 tensor([3, 7, 11, 15]) # Rank 3 >>> # Essentially, it is similar to following operation: >>> scatter_list = list(input.chunk(world_size)) >>> gather_list = list(output.chunk(world_size)) >>> for i in range(world_size): >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) >>> # Another example with uneven split >>> input tensor([0, 1, 2, 3, 4, 5]) # Rank 0 tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 tensor([20, 21, 22, 23, 24]) # Rank 2 tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 >>> input_splits [2, 2, 1, 1] # Rank 0 [3, 2, 2, 2] # Rank 1 [2, 1, 1, 1] # Rank 2 [2, 2, 2, 1] # Rank 3 >>> output_splits [2, 3, 2, 2] # Rank 0 [2, 2, 1, 2] # Rank 1 [1, 2, 1, 2] # Rank 2 [1, 2, 1, 1] # Rank 3 >>> output = ... >>> dist.all_to_all_single(output, input, output_splits, input_splits) >>> output tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 tensor([ 5, 17, 18, 24, 36]) # Rank 3 >>> # Another example with tensors of torch.cfloat type. >>> input = torch.tensor( ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat ... ) + 4 * rank * (1 + 1j) >>> input tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 >>> output = torch.empty([4], dtype=torch.int64) >>> dist.all_to_all_single(output, input) >>> output tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False)[source]# Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. Complex tensors are supported. Parameters output_tensor_list (list[Tensor]) – List of tensors to be gathered one per rank. input_tensor_list (list[Tensor]) – List of tensors to scatter one per rank. group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. Warning all_to_all is experimental and subject to change. Examples >>> input = torch.arange(4) + rank * 4 >>> input = list(input.chunk(4)) >>> input [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) >>> dist.all_to_all(output, input) >>> output [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 >>> # Essentially, it is similar to following operation: >>> scatter_list = input >>> gather_list = output >>> for i in range(world_size): >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) >>> input tensor([0, 1, 2, 3, 4, 5]) # Rank 0 tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 tensor([20, 21, 22, 23, 24]) # Rank 2 tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 >>> input_splits [2, 2, 1, 1] # Rank 0 [3, 2, 2, 2] # Rank 1 [2, 1, 1, 1] # Rank 2 [2, 2, 2, 1] # Rank 3 >>> output_splits [2, 3, 2, 2] # Rank 0 [2, 2, 1, 2] # Rank 1 [1, 2, 1, 2] # Rank 2 [1, 2, 1, 1] # Rank 3 >>> input = list(input.split(input_splits)) >>> input [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 >>> output = ... >>> dist.all_to_all(output, input) >>> output [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 >>> # Another example with tensors of torch.cfloat type. >>> input = torch.tensor( ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat ... ) + 4 * rank * (1 + 1j) >>> input = list(input.chunk(4)) >>> input [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) >>> dist.all_to_all(output, input) >>> output [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 torch.distributed.barrier(group=None, async_op=False, device_ids=None)[source]# Synchronize all processes. This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait(). Parameters group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. async_op (bool, optional) – Whether this op should be an async op device_ids ([int], optional) – List of device/GPU ids. Only one id is expected. Returns Async work handle, if async_op is set to True. None, if not async_op or if not part of the group Note ProcessGroupNCCL now blocks the cpu thread till the completion of the barrier collective. Note ProcessGroupNCCL implements barrier as an all_reduce of a 1-element tensor. A device must be chosen for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to device_ids arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device that was first used with this process group, if another collective with tensor inputs has been performed, (4) the device index indicated by the global rank mod local device count. torch.distributed.monitored_barrier(group=None, timeout=None, wait_all_ranks=False)[source]# Synchronize processes similar to torch.distributed.barrier, but consider a configurable timeout. It is able to report ranks that did not pass this barrier within the provided timeout. Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. Rank 0 will block until all send /recv from other ranks are processed, and will report failures for ranks that failed to respond in time. Note that if one rank does not reach the monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. This collective will block all processes/ranks in the group, until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only be used for debugging or scenarios that require full synchronization points on the host-side. For debugging purposes, this barrier can be inserted before the application’s collective calls to check if any ranks are desynchronized. Note Note that this collective is only supported with the GLOO backend. Parameters group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. timeout (datetime.timedelta, optional) – Timeout for monitored_barrier. If None, the default process group timeout will be used. wait_all_ranks (bool, optional) – Whether to collect all failed ranks or not. By default, this is False and monitored_barrier on rank 0 will throw on the first failed rank it encounters in order to fail fast. By setting wait_all_ranks=True monitored_barrier will collect all failed ranks and throw an error containing information about all failed ranks. Returns None. Example::>>> # Note: Process group initialization omitted on each rank. >>> import torch.distributed as dist >>> if dist.get_rank() != 1: >>> dist.monitored_barrier() # Raises exception indicating that >>> # rank 1 did not call into monitored_barrier. >>> # Example with wait_all_ranks=True >>> if dist.get_rank() == 0: >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into >>> # monitored_barrier. class torch.distributed.Work# A Work object represents the handle to a pending asynchronous operation in PyTorch’s distributed package. It is returned by non-blocking collective operations, such as dist.all_reduce(tensor, async_op=True). block_current_stream(self: torch._C._distributed_c10d.Work) → None# Blocks the currently active GPU stream on the operation to complete. For GPU based collectives this is equivalent to synchronize. For CPU initiated collectives such as with Gloo this will block the CUDA stream until the operation is complete. This returns immediately in all cases. To check whether an operation was successful you should check the Work object result asynchronously. boxed(self: torch._C._distributed_c10d.Work) → object# exception(self: torch._C._distributed_c10d.Work) → std::__exception_ptr::exception_ptr# get_future(self: torch._C._distributed_c10d.Work) → torch.Future# Returns A torch.futures.Future object which is associated with the completion of the Work. As an example, a future object can be retrieved by fut = process_group.allreduce(tensors).get_future(). Example::Below is an example of a simple allreduce DDP communication hook that uses get_future API to retrieve a Future associated with the completion of allreduce. >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket): -> torch.futures.Future >>> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD >>> tensor = bucket.buffer().div_(group_to_use.size()) >>> return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future() >>> ddp_model.register_comm_hook(state=None, hook=allreduce) Warning get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future. In the example above, allreduce work will be done on GPU using NCCL backend, fut.wait() will return after synchronizing the appropriate NCCL streams with PyTorch’s current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that CUDAFuture does not support TORCH_NCCL_BLOCKING_WAIT flag or NCCL’s barrier(). In addition, if a callback function was added by fut.then(), it will wait until WorkNCCL’s NCCL streams synchronize with ProcessGroupNCCL’s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. fut.then() will return another CUDAFuture that holds the return value of the callback and a CUDAEvent that recorded the callback stream. For CPU work, fut.done() returns true when work has been completed and value() tensors are ready. For GPU work, fut.done() returns true only whether the operation has been enqueued. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), fut.done() returns true when tensors have arrived on respective nodes, but not yet necessarily synched on respective GPUs (similarly to GPU work). get_future_result(self: torch._C._distributed_c10d.Work) → torch.Future# Returns A torch.futures.Future object of int type which maps to the enum type of WorkResult As an example, a future object can be retrieved by fut = process_group.allreduce(tensor).get_future_result(). Example::users can use fut.wait() to blocking wait for the completion of the work and get the WorkResult by fut.value(). Also, users can use fut.then(call_back_func) to register a callback function to be called when the work is completed, without blocking the current thread. Warning get_future_result API supports NCCL is_completed(self: torch._C._distributed_c10d.Work) → bool# is_success(self: torch._C._distributed_c10d.Work) → bool# result(self: torch._C._distributed_c10d.Work) → list[torch.Tensor]# source_rank(self: torch._C._distributed_c10d.Work) → int# synchronize(self: torch._C._distributed_c10d.Work) → None# static unbox(arg0: object) → torch._C._distributed_c10d.Work# wait(self: torch._C._distributed_c10d.Work, timeout: datetime.timedelta = datetime.timedelta(0)) → bool# Returns true/false. Example:: try:work.wait(timeout) except:# some handling Warning In normal cases, users do not need to set the timeout. calling wait() is the same as calling synchronize(): Letting the current stream block on the completion of the NCCL work. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or timed out. If timeout, exception will be thrown. class torch.distributed.ReduceOp# An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR, and PREMUL_SUM. BAND, BOR, and BXOR reductions are not available when using the NCCL backend. AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later. PREMUL_SUM multiplies inputs by a given scalar locally before reduction. PREMUL_SUM is only available with the NCCL backend, and only available for NCCL versions 2.11 or later. Users are supposed to use torch.distributed._make_nccl_premul_sum. Additionally, MAX, MIN and PRODUCT are not supported for complex tensors. The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce(). This class does not support __members__ property. class torch.distributed.reduce_op# Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX. ReduceOp is recommended to use instead. Distributed Key-Value Store# The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method.) There are 3 choices for Key-Value Stores: TCPStore, FileStore, and HashStore. class torch.distributed.Store# Base class for all store implementations, such as the 3 provided by PyTorch distributed: (TCPStore, FileStore, and HashStore). __init__(self: torch._C._distributed_c10d.Store) → None# add(self: torch._C._distributed_c10d.Store, arg0: str, arg1: SupportsInt) → int# The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception. Parameters key (str) – The key in the store whose counter will be incremented. amount (int) – The quantity by which the counter will be incremented. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.add("first_key", 1) >>> store.add("first_key", 6) >>> # Should return 7 >>> store.get("first_key") append(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None# Append the key-value pair into the store based on the supplied key and value. If key does not exists in the store, it will be created. Parameters key (str) – The key to be appended to the store. value (str) – The value associated with key to be added to the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.append("first_key", "po") >>> store.append("first_key", "tato") >>> # Should return "potato" >>> store.get("first_key") check(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) → bool# The call to check whether a given list of keys have value stored in the store. This call immediately returns in normal cases but still suffers from some edge deadlock cases, e.g, calling check after TCPStore has been destroyed. Calling check() with a list of keys that one wants to check whether stored in the store or not. Parameters keys (list[str]) – The keys to query whether stored in the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.add("first_key", 1) >>> # Should return 7 >>> store.check(["first_key"]) clone(self: torch._C._distributed_c10d.Store) → torch._C._distributed_c10d.Store# Clones the store and returns a new object that points to the same underlying store. The returned store can be used concurrently with the original object. This is intended to provide a safe way to use a store from multiple threads by cloning one store per thread. compare_set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str, arg2: str) → bytes# Inserts the key-value pair into the store based on the supplied key and performs comparison between expected_value and desired_value before inserting. desired_value will only be set if expected_value for the key already exists in the store or if expected_value is an empty string. Parameters key (str) – The key to be checked in the store. expected_value (str) – The value associated with key to be checked before insertion. desired_value (str) – The value associated with key to be added to the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("key", "first_value") >>> store.compare_set("key", "first_value", "second_value") >>> # Should return "second_value" >>> store.get("key") delete_key(self: torch._C._distributed_c10d.Store, arg0: str) → bool# Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not. Warning The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception. Parameters key (str) – The key to be deleted from the store Returns True if key was deleted, otherwise False. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, HashStore can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key") >>> # This should return true >>> store.delete_key("first_key") >>> # This should return false >>> store.delete_key("bad_key") get(self: torch._C._distributed_c10d.Store, arg0: str) → bytes# Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception. Parameters key (str) – The function will return the value associated with this key. Returns Value associated with key if key is in the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # Should return "first_value" >>> store.get("first_key") has_extended_api(self: torch._C._distributed_c10d.Store) → bool# Returns true if the store supports extended operations. multi_get(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) → list[bytes]# Retrieve all values in keys. If any key in keys is not present in the store, the function will wait for timeout Parameters keys (List[str]) – The keys to be retrieved from the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "po") >>> store.set("second_key", "tato") >>> # Should return [b"po", b"tato"] >>> store.multi_get(["first_key", "second_key"]) multi_set(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str], arg1: collections.abc.Sequence[str]) → None# Inserts a list key-value pair into the store based on the supplied keys and values Parameters keys (List[str]) – The keys to insert. values (List[str]) – The values to insert. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.multi_set(["first_key", "second_key"], ["po", "tato"]) >>> # Should return b"po" >>> store.get("first_key") num_keys(self: torch._C._distributed_c10d.Store) → int# Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store. Warning When used with the TCPStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained. Returns The number of keys present in the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # This should return 2 >>> store.num_keys() queue_len(self: torch._C._distributed_c10d.Store, arg0: str) → int# Returns the length of the specified queue. If the queue doesn’t exist it returns 0. See queue_push for more details. Parameters key (str) – The key of the queue to get the length. queue_pop(self: torch._C._distributed_c10d.Store, key: str, block: bool = True) → bytes# Pops a value from the specified queue or waits until timeout if the queue is empty. See queue_push for more details. If block is False, a dist.QueueEmptyError will be raised if the queue is empty. Parameters key (str) – The key of the queue to pop from. block (bool) – Whether to block waiting for the key or immediately return. queue_push(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None# Pushes a value into the specified queue. Using the same key for queues and set/get operations may result in unexpected behavior. wait/check operations are supported for queues. wait with queues will only wake one waiting worker rather than all. Parameters key (str) – The key of the queue to push to. value (str) – The value to push into the queue. set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None# Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value. Parameters key (str) – The key to be added to the store. value (str) – The value associated with key to be added to the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # Should return "first_value" >>> store.get("first_key") set_timeout(self: torch._C._distributed_c10d.Store, arg0: datetime.timedelta) → None# Sets the store’s default timeout. This timeout is used during initialization and in wait() and get(). Parameters timeout (timedelta) – timeout to be set in the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set_timeout(timedelta(seconds=10)) >>> # This will throw an exception after 10 seconds >>> store.wait(["bad_key"]) property timeout# Gets the timeout of the store. wait(*args, **kwargs)# Overloaded function. wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) -> None Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception. Parameters keys (list) – List of keys on which to wait until they are set in the store. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> # This will throw an exception after 30 seconds >>> store.wait(["bad_key"]) wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str], arg1: datetime.timedelta) -> None Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout. Parameters keys (list) – List of keys on which to wait until they are set in the store. timeout (timedelta) – Time to wait for the keys to be added before throwing an exception. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> # This will throw an exception after 10 seconds >>> store.wait(["bad_key"], timedelta(seconds=10)) class torch.distributed.TCPStore# A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. There should always be one server store initialized because the client store(s) will wait for the server to establish a connection. Parameters host_name (str) – The hostname or IP Address the server store should run on. port (int) – The port on which the server store should listen for incoming requests. world_size (int, optional) – The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users). is_master (bool, optional) – True when initializing the server store and False for client stores. Default is False. timeout (timedelta, optional) – Timeout used by the store during initialization and for methods such as get() and wait(). Default is timedelta(seconds=300) wait_for_workers (bool, optional) – Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. multi_tenant (bool, optional) – If True, all TCPStore instances in the current process with the same host/port will use the same underlying TCPServer. Default is False. master_listen_fd (int, optional) – If specified, the underlying TCPServer will listen on this file descriptor, which must be a socket already bound to port. To bind an ephemeral port we recommend setting the port to 0 and reading .port. Default is None (meaning the server creates a new socket and attempts to bind it to port). use_libuv (bool, optional) – If True, use libuv for TCPServer backend. Default is True. Example::>>> import torch.distributed as dist >>> from datetime import timedelta >>> # Run on process 1 (server) >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30)) >>> # Run on process 2 (client) >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False) >>> # Use any of the store methods from either the client or server after initialization >>> server_store.set("first_key", "first_value") >>> client_store.get("first_key") __init__(self: torch._C._distributed_c10d.TCPStore, host_name: str, port: SupportsInt, world_size: SupportsInt | None = None, is_master: bool = False, timeout: datetime.timedelta = datetime.timedelta(seconds=300), wait_for_workers: bool = True, multi_tenant: bool = False, master_listen_fd: SupportsInt | None = None, use_libuv: bool = True) → None# Creates a new TCPStore. property host# Gets the hostname on which the store listens for requests. property libuvBackend# Returns True if it’s using the libuv backend. property port# Gets the port number on which the store listens for requests. class torch.distributed.HashStore# A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes. Example::>>> import torch.distributed as dist >>> store = dist.HashStore() >>> # store can be used from other threads >>> # Use any of the store methods after initialization >>> store.set("first_key", "first_value") __init__(self: torch._C._distributed_c10d.HashStore) → None# Creates a new HashStore. class torch.distributed.FileStore# A store implementation that uses a file to store the underlying key-value pairs. Parameters file_name (str) – path of the file in which to store the key-value pairs world_size (int, optional) – The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users). Example::>>> import torch.distributed as dist >>> store1 = dist.FileStore("/tmp/filestore", 2) >>> store2 = dist.FileStore("/tmp/filestore", 2) >>> # Use any of the store methods from either the client or server after initialization >>> store1.set("first_key", "first_value") >>> store2.get("first_key") __init__(self: torch._C._distributed_c10d.FileStore, file_name: str, world_size: SupportsInt = -1) → None# Creates a new FileStore. property path# Gets the path of the file used by FileStore to store key-value pairs. class torch.distributed.PrefixStore# A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store. Parameters prefix (str) – The prefix string that is prepended to each key before being inserted into the store. store (torch.distributed.store) – A store object that forms the underlying key-value store. __init__(self: torch._C._distributed_c10d.PrefixStore, prefix: str, store: torch._C._distributed_c10d.Store) → None# Creates a new PrefixStore. property underlying_store# Gets the underlying store object that PrefixStore wraps around. Profiling Collective Communication# Note that you can use torch.profiler (recommended, only available after 1.8.1) or torch.autograd.profiler to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (gloo, nccl, mpi) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator: import torch import torch.distributed as dist with torch.profiler(): tensor = torch.randn(20, 10) dist.all_reduce(tensor) Please refer to the profiler documentation for a full overview of profiler features. Multi-GPU collective functions# Warning The multi-GPU functions (which stand for multiple GPUs per CPU thread) are deprecated. As of today, PyTorch Distributed’s preferred programming model is one device per thread, as exemplified by the APIs in this document. If you are a backend developer and want to support multiple devices per thread, please contact PyTorch Distributed’s maintainers. Object collectives# Warning Object collectives have a number of serious limitations. Read further to determine if they are safe to use for your use case. Object collectives are a set of collective-like operations that work on arbitrary Python objects, as long as they can be pickled. There are various collective patterns implemented (e.g. broadcast, all_gather, …) but they each roughly follow this pattern: convert the input object into a pickle (raw bytes), then shove it into a byte tensor communicate the size of this byte tensor to peers (first collective operation) allocate appropriately sized tensor to perform the real collective communicate the object data (second collective operation) convert raw data back into Python (unpickle) Object collectives sometimes have surprising performance or memory characteristics that lead to long runtimes or OOMs, and thus they should be used with caution. Here are some common issues. Asymmetric pickle/unpickle time - Pickling objects can be slow, depending on the number, type and size of the objects. When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective. Inefficient tensor communication - Tensors should be sent via regular collective APIs, not object collective APIs. It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead. Unexpected tensor devices - If you still want to send tensors via object collectives, there is another aspect specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on cuda:3, and then unpickle it, you will get another tensor on cuda:3 regardless of which process you are on, or which CUDA device is the ‘default’ device for that process. With regular tensor collective APIs, ‘output tensors’ will always be on the same, local device, which is generally what you’d expect. Unpickling a tensor will implicitly activate a CUDA context if it is the first time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by moving tensors to CPU before passing them as inputs to an object collective. Third-party backends# Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capability of third-party backends are decided by their own implementations. The new backend derives from c10d::ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported. When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend. Warning The support of third-party backend is experimental and subject to change. Launch utility# The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training. Module torch.distributed.launch. torch.distributed.launch is a module that spawns up multiple distributed training processes on each of the training nodes. Warning This module is going to be deprecated in favor of torchrun. The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will be operating on a single GPU. This can achieve well-improved single-node training performance. It can also be used in multi-node distributed training, by spawning up multiple processes on each node for well-improved multi-node distributed training performance as well. This will especially be beneficial for systems with multiple Infiniband interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth. In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less or equal to the number of GPUs on the current system (nproc_per_node), and each process will be operating on a single GPU from GPU 0 to GPU (nproc_per_node - 1). How to use this module: Single-Node multi-process distributed training python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script) Multi-Node multi-process distributed training: (e.g. two nodes) Node 1: (IP: 192.168.1.1, and has a free port: 1234) python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script) Node 2: python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script) To look up what optional arguments this module offers: python -m torch.distributed.launch --help Important Notices: 1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training. 2. In your training program, you must parse the command-line argument: --local-rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: Parsing the local_rank argument >>> import argparse >>> parser = argparse.ArgumentParser() >>> parser.add_argument("--local-rank", "--local_rank", type=int) >>> args = parser.parse_args() Set your device to local rank using either >>> torch.cuda.set_device(args.local_rank) # before your code runs or >>> with torch.cuda.device(args.local_rank): >>> # your code to run >>> ... Changed in version 2.0.0: The launcher will passes the --local-rank= argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank. For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, the launcher will trigger an error: “error: unrecognized arguments: –local-rank=”. For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient. 3. In your training program, you are supposed to call the following function at the beginning to start the distributed backend. It is strongly recommended that init_method=env://. Other init methods (e.g. tcp://) may work, but env:// is the one that is officially supported by this module. >>> torch.distributed.init_process_group(backend='YOUR BACKEND', >>> init_method='env://') 4. In your training program, you can either use regular distributed functions or use torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use torch.nn.parallel.DistributedDataParallel() module, here is how to configure it. >>> model = torch.nn.parallel.DistributedDataParallel(model, >>> device_ids=[args.local_rank], >>> output_device=args.local_rank) Please ensure that device_ids argument is set to be the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, the device_ids needs to be [args.local_rank], and output_device needs to be args.local_rank in order to use this utility 5. Another way to pass local_rank to the subprocesses via environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use-env=True. You must adjust the subprocess example above to replace args.local_rank with os.environ['LOCAL_RANK']; the launcher will not pass --local-rank when you specify this flag. Warning local_rank is NOT globally unique: it is only unique per process on a machine. Thus, don’t use it to decide if you should, e.g., write to a networked filesystem. See pytorch/pytorch#12042 for an example of how things can go wrong if you don’t do this correctly. Spawn utility# The Multiprocessing package - torch.multiprocessing package also provides a spawn function in torch.multiprocessing.spawn(). This helper function can be used to spawn multiple processes. It works by passing in the function that you want to run and spawns N processes to run it. This can be used for multiprocess distributed training as well. For references on how to use it, please refer to PyTorch example - ImageNet implementation Note that this function requires Python 3.4 or higher. Debugging torch.distributed applications# Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. torch.distributed provides a suite of tools to help debug training applications in a self-serve fashion: Python Breakpoint# It is extremely convenient to use python’s debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. PyTorch offers a customized wrapper around pdb that streamlines the process. torch.distributed.breakpoint makes this process easy. Internally, it customizes pdb’s breakpoint behavior in two ways but otherwise behaves as normal pdb. Attaches the debugger only on one rank (specified by the user). Ensures all other ranks stop, by using a torch.distributed.barrier() that will release once the debugged rank issues a continue Reroutes stdin from the child process such that it connects to your terminal. To use it, simply issue torch.distributed.breakpoint(rank) on all ranks, using the same value for rank in each case. Monitored Barrier# As of v1.10, torch.distributed.monitored_barrier() exists as an alternative to torch.distributed.barrier() which fails with helpful information about which rank may be faulty when crashing, i.e. not all ranks calling into torch.distributed.monitored_barrier() within the provided timeout. torch.distributed.monitored_barrier() implements a host-side barrier using send/recv communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge the barrier in time. As an example, consider the following function where rank 1 fails to call into torch.distributed.monitored_barrier() (in practice this could be due to an application bug or hang in a previous collective): import os from datetime import timedelta import torch import torch.distributed as dist import torch.multiprocessing as mp def worker(rank): dist.init_process_group("nccl", rank=rank, world_size=2) # monitored barrier requires gloo process group to perform host-side sync. group_gloo = dist.new_group(backend="gloo") if rank not in [1]: dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2)) if __name__ == "__main__": os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29501" mp.spawn(worker, nprocs=2, args=()) The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further: RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms Original exception: [gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594 TORCH_DISTRIBUTED_DEBUG# With TORCH_CPP_LOG_LEVEL=INFO, the environment variable TORCH_DISTRIBUTED_DEBUG can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. TORCH_DISTRIBUTED_DEBUG can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL may impact the application performance and thus should only be used when debugging issues. Setting TORCH_DISTRIBUTED_DEBUG=INFO will result in additional debug logging when models trained with torch.nn.parallel.DistributedDataParallel() are initialized, and TORCH_DISTRIBUTED_DEBUG=DETAIL will additionally log runtime performance statistics a select number of iterations. These runtime statistics include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application: import os import torch import torch.distributed as dist import torch.multiprocessing as mp class TwoLinLayerNet(torch.nn.Module): def __init__(self): super().__init__() self.a = torch.nn.Linear(10, 10, bias=False) self.b = torch.nn.Linear(10, 1, bias=False) def forward(self, x): a = self.a(x) b = self.b(x) return (a, b) def worker(rank): dist.init_process_group("nccl", rank=rank, world_size=2) torch.cuda.set_device(rank) print("init model") model = TwoLinLayerNet().cuda() print("init ddp") ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) inp = torch.randn(10, 10).cuda() print("train") for _ in range(20): output = ddp_model(inp) loss = output[0] + output[1] loss.sum().backward() if __name__ == "__main__": os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29501" os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" os.environ[ "TORCH_DISTRIBUTED_DEBUG" ] = "DETAIL" # set to DETAIL for runtime logging. mp.spawn(worker, nprocs=2, args=()) The following logs are rendered at initialization time: I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with: broadcast_buffers: 1 bucket_cap_bytes: 26214400 find_unused_parameters: 0 gradient_as_bucket_view: 0 is_multi_device_module: 0 iteration: 0 num_parameter_tensors: 2 output_device: 0 rank: 0 total_parameter_size_bytes: 440 world_size: 2 backend_name: nccl bucket_sizes: 440 cuda_visible_devices: N/A device_ids: 0 dtypes: float master_addr: localhost master_port: 29501 module_name: TwoLinLayerNet nccl_async_error_handling: N/A nccl_blocking_wait: N/A nccl_debug: WARN nccl_ib_timeout: N/A nccl_nthreads: N/A nccl_socket_ifname: N/A torch_distributed_debug: INFO The following logs are rendered during runtime (when TORCH_DISTRIBUTED_DEBUG=DETAIL is set): I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0 Avg forward compute time: 40838608 Avg backward compute time: 5983335 Avg backward comm. time: 4326421 Avg backward comm/comp overlap time: 4207652 I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0 Avg forward compute time: 42850427 Avg backward compute time: 3885553 Avg backward comm. time: 2357981 Avg backward comm/comp overlap time: 2234674 In addition, TORCH_DISTRIBUTED_DEBUG=INFO enhances crash logging in torch.nn.parallel.DistributedDataParallel() due to unused parameters in the model. Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required to be used in loss computation as torch.nn.parallel.DistributedDataParallel() does not support unused parameters in the backwards pass. These constraints are challenging especially for larger models, thus when crashing with an error, torch.nn.parallel.DistributedDataParallel() will log the fully qualified name of all parameters that went unused. For example, in the above application, if we modify loss to be instead computed as loss = output[1], then TwoLinLayerNet.a does not receive a gradient in the backwards pass, and thus results in DDP failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return va lue of `forward` of your module when reporting this issue (e.g. list, dict, iterable). Parameters which did not receive grad for rank 0: a.weight Parameter indices which did not receive grad for rank 0: 0 Setting TORCH_DISTRIBUTED_DEBUG=DETAIL will trigger additional consistency and synchronization checks on every collective call issued by the user either directly or indirectly (such as DDP allreduce). This is done by creating a wrapper process group that wraps all process groups returned by torch.distributed.init_process_group() and torch.distributed.new_group() APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a torch.distributed.monitored_barrier(), which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into torch.distributed.all_reduce(): import torch import torch.distributed as dist import torch.multiprocessing as mp def worker(rank): dist.init_process_group("nccl", rank=rank, world_size=2) torch.cuda.set_device(rank) tensor = torch.randn(10 if rank == 0 else 20).cuda() dist.all_reduce(tensor) torch.cuda.synchronize(device=rank) if __name__ == "__main__": os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29501" os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" mp.spawn(worker, nprocs=2, args=()) With the NCCL backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables TORCH_DISTRIBUTED_DEBUG=DETAIL and reruns the application, the following error message reveals the root cause: work = default_pg.allreduce([tensor], opts) RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0. This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: 10 [ torch.LongTensor{1} ] Note For fine-grained control of the debug level during runtime the functions torch.distributed.set_debug_level(), torch.distributed.set_debug_level_from_env(), and torch.distributed.get_debug_level() can also be used. In addition, TORCH_DISTRIBUTED_DEBUG=DETAIL can be used in conjunction with TORCH_SHOW_CPP_STACKTRACES=1 to log the entire callstack when a collective desynchronization is detected. These collective desynchronization checks will work for all applications that use c10d collective calls backed by process groups created with the torch.distributed.init_process_group() and torch.distributed.new_group() APIs. Logging# In addition to explicit debugging support via torch.distributed.monitored_barrier() and TORCH_DISTRIBUTED_DEBUG, the underlying C++ library of torch.distributed also outputs log messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The following matrix shows how the log level can be adjusted via the combination of TORCH_CPP_LOG_LEVEL and TORCH_DISTRIBUTED_DEBUG environment variables. TORCH_CPP_LOG_LEVEL TORCH_DISTRIBUTED_DEBUG Effective Log Level ERROR ignored Error WARNING ignored Warning INFO ignored Info INFO INFO Debug INFO DETAIL Trace (a.k.a. All) Distributed components raise custom Exception types derived from RuntimeError: torch.distributed.DistError: This is the base type of all distributed exceptions. torch.distributed.DistBackendError: This exception is thrown when a backend-specific error occurs. For example, if the NCCL backend is used and the user attempts to use a GPU that is not available to the NCCL library. torch.distributed.DistNetworkError: This exception is thrown when networking libraries encounter errors (ex: Connection reset by peer) torch.distributed.DistStoreError: This exception is thrown when the Store encounters an error (ex: TCPStore timeout) class torch.distributed.DistError# Exception raised when an error occurs in the distributed library class torch.distributed.DistBackendError# Exception raised when a backend error occurs in distributed class torch.distributed.DistNetworkError# Exception raised when a network error occurs in distributed class torch.distributed.DistStoreError# Exception raised when an error occurs in the distributed store If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank: torch.distributed.breakpoint(rank=0, skip=0, timeout_s=3600)[source]# Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing. Parameters rank (int) – Which rank to break on. Default: 0 skip (int) – Skip the first skip calls to this breakpoint. Default: 0. - -``` -torch.distributed -``` - -**Pattern 3:** Initialization# The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined. Warning Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs. torch.distributed.is_available()[source]# Return True if the distributed package is available. Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for MacOS. Return type bool torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)[source]# Initialize the default distributed process group. This will also initialize the distributed package. There are 2 main ways to initialize a process group: Specify store, rank, and world_size explicitly. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them. If neither is specified, init_method is assumed to be “env://”. Parameters backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, nccl, ucc, xccl or one that is registered by a third-party plugin. Since 2.6, if backend is not provided, c10d will use a backend registered for the device type indicated by the device_id kwarg (if provided). The known default registrations today are: nccl for cuda, gloo for cpu, xccl for xpu. If neither backend nor device_id is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or cpu). This field can be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlock or NCCL invalid usage. ucc backend is experimental. Default backend for the device can be queried with get_default_backend_for_device(). init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store. world_size (int, optional) – Number of processes participating in the job. Required if store is specified. rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified. store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method. timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. This is the duration after which collectives will be aborted asynchronously and the process will crash. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. group_name (str, optional, deprecated) – Group name. This argument is ignored pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. As of now, the only options we support is ProcessGroupNCCL.Options for the nccl backend, is_high_priority_stream can be specified so that the nccl backend can pick up high priority cuda streams when there’re compute kernels waiting. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t device_id (torch.device | int, optional) – a single, specific device this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is immediately formed (calling ncclCommInit* immediately rather than the normal lazy call) and sub-groups will use ncclCommSplit when possible to avoid unnecessary overhead of group creation. If you want to know NCCL initialization error early, you can also use this field. If an int is provided, the API assumes that the accelerator type at compile time will be used. Note To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI. Note Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “:,:”, e.g. “cpu:gloo,cuda:custom_backend”. torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None, backend_override=None)[source]# Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters. This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i]. Note init_device_mesh follows SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. Inconsistent mesh_shape may lead to hanging. Note If no process group is found, init_device_mesh will initialize distributed process group/groups required for distributed communications behind the scene. Parameters device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”, “xpu”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed. mesh_shape (Tuple[int]) – A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. mesh_dim_names (Tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional) – Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value). Returns A DeviceMesh object representing the device layout. Return type DeviceMesh Example: >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) torch.distributed.is_initialized()[source]# Check if the default process group has been initialized. Return type bool torch.distributed.is_mpi_available()[source]# Check if the MPI backend is available. Return type bool torch.distributed.is_nccl_available()[source]# Check if the NCCL backend is available. Return type bool torch.distributed.is_gloo_available()[source]# Check if the Gloo backend is available. Return type bool torch.distributed.distributed_c10d.is_xccl_available()[source]# Check if the XCCL backend is available. Return type bool torch.distributed.is_torchelastic_launched()[source]# Check whether this process was launched with torch.distributed.elastic (aka torchelastic). The existence of TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id which is always a non-null value indicating the job id for peer discovery purposes.. Return type bool torch.distributed.get_default_backend_for_device(device)[source]# Return the default backend for the given device. Parameters device (Union[str, torch.device]) – The device to get the default backend for. Returns The default backend for the given device as a lower case string. Return type str Currently three initialization methods are supported: TCP initialization# There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks. Note that multicast address is not supported anymore in the latest distributed package. group_name is deprecated as well. import torch.distributed as dist # Use address of one of the machines dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4) Shared file-system initialization# Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but will not delete the file. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name. Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well. Warning This method assumes that the file system supports locking using fcntl - most local systems and NFS support it. Warning This method will always create the file and try its best to clean up and remove the file at the end of the program. In other words, each initialization with the file init method will need a brand new empty file in order for the initialization to succeed. If the same file used by the previous initialization (which happens not to get cleaned up) is used again, this is unexpected behavior and can often cause deadlocks and failures. Therefore, even though this method will try its best to clean up the file, if the auto-delete happens to be unsuccessful, it is your responsibility to ensure that the file is removed at the end of the training to prevent the same file to be reused again during the next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. In other words, if the file is not removed/cleaned up and you call init_process_group() again on that file, failures are expected. The rule of thumb here is that, make sure that the file is non-existent or empty every time init_process_group() is called. import torch.distributed as dist # rank should always be specified dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile', world_size=4, rank=args.rank) Environment variable initialization# This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are: MASTER_PORT - required; has to be a free port on machine with rank 0 MASTER_ADDR - required (except for rank 0); address of rank 0 node WORLD_SIZE - required; can be set either here, or in a call to init function RANK - required; can be set either here, or in a call to init function The machine with rank 0 will be used to set up all connections. This is the default method, meaning that init_method does not have to be specified (or can be env://). Improving initialization time# TORCH_GLOO_LAZY_INIT - establishes connections on demand rather than using a full mesh which can greatly improve initialization time for non all2all operations. - -``` -torch.distributed.init_process_group() -``` - -**Pattern 4:** Example: - -``` ->>> from torch.distributed.device_mesh import init_device_mesh ->>> ->>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) ->>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) -``` - -**Pattern 5:** Groups# By default collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns). torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)[source]# Create a new distributed group. This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes. Warning Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks. If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering. When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other cuda streams by calling work.wait() before using another process group. See Using multiple NCCL communicators concurrently for more details. Parameters ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None. timeout (timedelta, optional) – see init_process_group for details and default value. backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If None is passed in, the backend corresponding to the default process group will be used. Default is None. pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. i.e. for the nccl backend, is_high_priority_stream can be specified so that process group can pick up high priority cuda streams. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization (bool, optional): perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don’t need to call into API and don’t join the barrier. group_desc (str, optional) – a string to describe the process group. device_id (torch.device, optional) – a single, specific device to “bind” this process to, The new_group call will try to initialize a communication backend immediately for the device if this field is given. Returns A handle of distributed group that can be given to collective calls or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks. N.B. use_local_synchronization doesn’t work with MPI. N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior as non-member ranks don’t join the group barrier(). N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order. torch.distributed.get_group_rank(group, global_rank)[source]# Translate a global rank into a group rank. global_rank must be part of group otherwise this raises RuntimeError. Parameters group (ProcessGroup) – ProcessGroup to find the relative rank. global_rank (int) – Global rank to query. Returns Group rank of global_rank relative to group Return type int N.B. calling this function on the default process group returns identity torch.distributed.get_global_rank(group, group_rank)[source]# Translate a group rank into a global rank. group_rank must be part of group otherwise this raises RuntimeError. Parameters group (ProcessGroup) – ProcessGroup to find the global rank from. group_rank (int) – Group rank to query. Returns Global rank of group_rank relative to group Return type int N.B. calling this function on the default process group returns identity torch.distributed.get_process_group_ranks(group)[source]# Get all ranks associated with group. Parameters group (Optional[ProcessGroup]) – ProcessGroup to get all ranks from. If None, the default process group will be used. Returns List of global ranks ordered by group rank. Return type list[int] - -``` -new_group() -``` - -**Pattern 6:** Warning Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks. If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering. When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other cuda streams by calling work.wait() before using another process group. See Using multiple NCCL communicators concurrently for more details. - -``` -NCCL -``` - -**Pattern 7:** Note If you are using DistributedDataParallel in conjunction with the Distributed RPC Framework, you should always use torch.distributed.autograd.backward() to compute gradients and torch.distributed.optim.DistributedOptimizer for optimizing parameters. Example: >>> import torch.distributed.autograd as dist_autograd >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> import torch >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> import torch.distributed.rpc as rpc >>> from torch.distributed.rpc import RRef >>> >>> t1 = torch.rand((3, 3), requires_grad=True) >>> t2 = torch.rand((3, 3), requires_grad=True) >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) >>> ddp_model = DDP(my_model) >>> >>> # Setup optimizer >>> optimizer_params = [rref] >>> for param in ddp_model.parameters(): >>> optimizer_params.append(RRef(param)) >>> >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> optimizer_params, >>> lr=0.05, >>> ) >>> >>> with dist_autograd.context() as context_id: >>> pred = ddp_model(rref.to_here()) >>> loss = loss_func(pred, target) >>> dist_autograd.backward(context_id, [loss]) >>> dist_optim.step(context_id) - -``` -torch.distributed.autograd.backward() -``` - -**Pattern 8:** static_graph (bool) – When set to True, DDP knows the trained graph is static. Static graph means 1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters = True or not. 2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to be True, DDP will support cases that can not be supported in the past: 1) Reentrant backwards. 2) Activation checkpointing multiple times. 3) Activation checkpointing when model has unused parameters. 4) There are model parameters that are outside of forward function. 5) Potentially improve performance when there are unused parameters, as DDP will not search graph in each iteration to detect unused parameters when static_graph is set to be True. To check whether you can set static_graph to be True, one way is to check ddp logging data at the end of your previous model training, if ddp_logging_data.get("can_set_static_graph") == True, mostly you can set static_graph = True as well. Example::>>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) >>> # Training loop >>> ... >>> ddp_logging_data = model_DDP._get_ddp_logging_data() >>> static_graph = ddp_logging_data.get("can_set_static_graph") - -``` -True -``` - -## Reference Files - -This skill includes comprehensive documentation in `references/`: - -- **other.md** - Other documentation - -Use `view` to read specific reference files when detailed information is needed. - -## Working with This Skill - -### For Beginners -Start with the getting_started or tutorials reference files for foundational concepts. - -### For Specific Features -Use the appropriate category reference file (api, guides, etc.) for detailed information. - -### For Code Examples -The quick reference section above contains common patterns extracted from the official docs. - -## Resources - -### references/ -Organized documentation extracted from official sources. These files contain: -- Detailed explanations -- Code examples with language annotations -- Links to original documentation -- Table of contents for quick navigation - -### scripts/ -Add helper scripts here for common automation tasks. - -### assets/ -Add templates, boilerplate, or example projects here. - -## Notes - -- This skill was automatically generated from official documentation -- Reference files preserve the structure and examples from source docs -- Code examples include language detection for better syntax highlighting -- Quick reference patterns are extracted from common usage examples in the docs - -## Updating - -To refresh this skill with updated documentation: -1. Re-run the scraper with the same configuration -2. The skill will be rebuilt with the latest information - - diff --git a/skills/mlops/pytorch-fsdp/references/index.md b/skills/mlops/pytorch-fsdp/references/index.md deleted file mode 100644 index 0eefba993..000000000 --- a/skills/mlops/pytorch-fsdp/references/index.md +++ /dev/null @@ -1,7 +0,0 @@ -# Pytorch-Fsdp Documentation Index - -## Categories - -### Other -**File:** `other.md` -**Pages:** 15 diff --git a/skills/mlops/pytorch-fsdp/references/other.md b/skills/mlops/pytorch-fsdp/references/other.md deleted file mode 100644 index d5b6cae6f..000000000 --- a/skills/mlops/pytorch-fsdp/references/other.md +++ /dev/null @@ -1,4249 +0,0 @@ -# Pytorch-Fsdp - Other - -**Pages:** 15 - ---- - -## Distributed Data Parallel# - -**URL:** https://pytorch.org/docs/stable/notes/ddp.html - -**Contents:** -- Distributed Data Parallel# -- Example# -- Internal Design# -- Implementation# - - ProcessGroup# - - DistributedDataParallel# - - TorchDynamo DDPOptimizer# - -Created On: Jan 15, 2020 | Last Updated On: Jan 25, 2024 - -The implementation of torch.nn.parallel.DistributedDataParallel evolves over time. This design note is written based on the state as of v1.4. - -torch.nn.parallel.DistributedDataParallel (DDP) transparently performs distributed data parallel training. This page describes how it works and reveals implementation details. - -Let us start with a simple torch.nn.parallel.DistributedDataParallel example. This example uses a torch.nn.Linear as the local model, wraps it with DDP, and then runs one forward pass, one backward pass, and an optimizer step on the DDP model. After that, parameters on the local model will be updated, and all models on different processes should be exactly the same. - -DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper before compiling the model, such that torchdynamo can apply DDPOptimizer (graph-break optimizations) based on DDP bucket sizes. (See TorchDynamo DDPOptimizer for more information.) - -This section reveals how it works under the hood of torch.nn.parallel.DistributedDataParallel by diving into details of every step in one iteration. - -Prerequisite: DDP relies on c10d ProcessGroup for communications. Hence, applications must create ProcessGroup instances before constructing DDP. - -Construction: The DDP constructor takes a reference to the local module, and broadcasts state_dict() from the process with rank 0 to all other processes in the group to make sure that all model replicas start from the exact same state. Then, each DDP process creates a local Reducer, which later will take care of the gradients synchronization during the backward pass. To improve communication efficiency, the Reducer organizes parameter gradients into buckets, and reduces one bucket at a time. Bucket size can be configured by setting the bucket_cap_mb argument in DDP constructor. The mapping from parameter gradients to buckets is determined at the construction time, based on the bucket size limit and parameter sizes. Model parameters are allocated into buckets in (roughly) the reverse order of Model.parameters() from the given model. The reason for using the reverse order is because DDP expects gradients to become ready during the backward pass in approximately that order. The figure below shows an example. Note that, the grad0 and grad1 are in bucket1, and the other two gradients are in bucket0. Of course, this assumption might not always be true, and when that happens it could hurt DDP backward speed as the Reducer cannot kick off the communication at the earliest possible time. Besides bucketing, the Reducer also registers autograd hooks during construction, one hook per parameter. These hooks will be triggered during the backward pass when the gradient becomes ready. - -Forward Pass: The DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True. This mode allows running backward on a subgraph of the model, and DDP finds out which parameters are involved in the backward pass by traversing the autograd graph from the model output and marking all unused parameters as ready for reduction. During the backward pass, the Reducer would only wait for unready parameters, but it would still reduce all buckets. Marking a parameter gradient as ready does not help DDP skip buckets as for now, but it will prevent DDP from waiting for absent gradients forever during the backward pass. Note that traversing the autograd graph introduces extra overheads, so applications should only set find_unused_parameters to True when necessary. - -Backward Pass: The backward() function is directly invoked on the loss Tensor, which is out of DDP’s control, and DDP uses autograd hooks registered at construction time to trigger gradients synchronizations. When one gradient becomes ready, its corresponding DDP hook on that grad accumulator will fire, and DDP will then mark that parameter gradient as ready for reduction. When gradients in one bucket are all ready, the Reducer kicks off an asynchronous allreduce on that bucket to calculate mean of gradients across all processes. When all buckets are ready, the Reducer will block waiting for all allreduce operations to finish. When this is done, averaged gradients are written to the param.grad field of all parameters. So after the backward pass, the grad field on the same corresponding parameter across different DDP processes should be the same. - -Optimizer Step: From the optimizer’s perspective, it is optimizing a local model. Model replicas on all DDP processes can keep in sync because they all start from the same state and they have the same averaged gradients in every iteration. - -DDP requires Reducer instances on all processes to invoke allreduce in exactly the same order, which is done by always running allreduce in the bucket index order instead of actual bucket ready order. Mismatched allreduce order across processes can lead to wrong results or DDP backward hang. - -Below are pointers to the DDP implementation components. The stacked graph shows the structure of the code. - -ProcessGroup.hpp: contains the abstract API of all process group implementations. The c10d library provides 3 implementations out of the box, namely, ProcessGroupGloo, ProcessGroupNCCL, and ProcessGroupMPI. DistributedDataParallel uses ProcessGroup::broadcast() to send model states from the process with rank 0 to others during initialization and ProcessGroup::allreduce() to sum gradients. - -Store.hpp: assists the rendezvous service for process group instances to find each other. - -distributed.py: is the Python entry point for DDP. It implements the initialization steps and the forward function for the nn.parallel.DistributedDataParallel module which call into C++ libraries. Its _sync_param function performs intra-process parameter synchronization when one DDP process works on multiple devices, and it also broadcasts model buffers from the process with rank 0 to all other processes. The inter-process parameter synchronization happens in Reducer.cpp. - -comm.h: implements the coalesced broadcast helper function which is invoked to broadcast model states during initialization and synchronize model buffers before the forward pass. - -reducer.h: provides the core implementation for gradient synchronization in the backward pass. It has three entry point functions: - -Reducer: The constructor is called in distributed.py which registers Reducer::autograd_hook() to gradient accumulators. - -autograd_hook() function will be invoked by the autograd engine when a gradient becomes ready. - -prepare_for_backward() is called at the end of DDP forward pass in distributed.py. It traverses the autograd graph to find unused parameters when find_unused_parameters is set to True in DDP constructor. - -DDP’s performance advantage comes from overlapping allreduce collectives with computations during backwards. AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph, because allreduce ops are launched by autograd hooks _after_ the whole optimized backwards computation finishes. - -TorchDynamo’s DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP’s allreduce buckets during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP’s allreduce hooks to fire in-between sections of backwards, and schedule communications to overlap with compute. - -See this blog post for a more in-depth explanation and experimental results, or read the docs and code at torch/_dynamo/optimizations/distributed.py - -To Debug DDPOptimizer, set TORCH_LOGS=’ddp_graphs’ for full graph dumps. For logs without graphs, add any of ‘dynamo’, ‘distributed’, or ‘dist_ddp’ to TORCH_LOGS (for basic info about bucket boundaries). To disable DDPOptimizer, set torch._dynamo.config.optimize_ddp=False. DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation. - ---- - -## PyTorch documentation# - -**URL:** https://pytorch.org/docs/stable/ - -**Contents:** -- PyTorch documentation# -- Indices and tables# - -PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. - -Features described in this documentation are classified by release status: - -Stable (API-Stable): These features will be maintained long-term and there should generally be no major performance limitations or gaps in documentation. We also expect to maintain backwards compatibility (although breaking changes can happen and notice will be given one release ahead of time). - -Unstable (API-Unstable): Encompasses all features that are under active development where APIs may change based on user feedback, requisite performance improvements or because coverage across operators is not yet complete. The APIs and performance characteristics of these features may change. - ---- - -## Generic Join Context Manager# - -**URL:** https://pytorch.org/docs/stable/distributed.algorithms.join.html - -**Contents:** -- Generic Join Context Manager# - -Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025 - -The generic join context manager facilitates distributed training on uneven inputs. This page outlines the API of the relevant classes: Join, Joinable, and JoinHook. For a tutorial, see Distributed Training with Uneven Inputs Using the Join Context Manager. - -This class defines the generic join context manager, which allows custom hooks to be called after a process joins. - -These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition. - -The context manager requires each participating Joinable to call the method notify_join_context() before its own per- iteration collective communications to ensure correctness. - -The context manager requires that all process_group attributes in the JoinHook objects are the same. If there are multiple JoinHook objects, then the device of the first is used. The process group and device information is used for checking for non- joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which using an all- reduce. - -joinables (List[Joinable]) – a list of the participating Joinable s; their hooks are iterated over in the given order. - -enable (bool) – a flag enabling uneven input detection; setting to False disables the context manager’s functionality and should only be set when the user knows the inputs will not be uneven (default: True). - -throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False). - -Notifies the join context manager that the calling process has not yet joined. - -Then, if throw_on_early_termination=True, checks if uneven inputs have been detected (i.e. if one process has already joined) and throws an exception if so. - -This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel. - -Only the first Joinable object passed into the context manager performs the collective communications in this method, and for the others, this method is vacuous. - -joinable (Joinable) – the Joinable object calling this method. - -An async work handle for the all-reduce meant to notify the context manager that the process has not yet joined if joinable is the first one passed into the context manager; None otherwise. - -This defines an abstract base class for joinable classes. - -A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to join_device() and join_process_group() that return device and process group information, respectively. - -Return the device from which to perform collective communications needed by the join context manager. - -Return a JoinHook instance for the given Joinable. - -kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs. - -Returns the process group for the collective communications needed by the join context manager itself. - -This defines a join hook, which provides two entry points in the join context manager. - -Entry points : a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined. - -To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate. - -Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. - -Training iteration i.e., in one forward pass, backward pass, and optimizer step. - -Call hook after all processes have joined. - -It is passed an additional bool argument is_last_joiner, which indicates if the rank is one of the last to join. - -is_last_joiner (bool) – True if the rank is one of the last to join; False otherwise. - ---- - -## Experimental Object Oriented Distributed API# - -**URL:** https://pytorch.org/docs/stable/distributed._dist2.html - -**Contents:** -- Experimental Object Oriented Distributed API# - -Created On: Jul 09, 2025 | Last Updated On: Jul 30, 2025 - -This is an experimental new API for PyTorch Distributed. This is actively in development and subject to change or deletion entirely. - -This is intended as a proving ground for more flexible and object oriented distributed APIs. - -Bases: pybind11_object - -A ProcessGroup is a communication primitive that allows for collective operations across a group of processes. - -This is a base class that provides the interface for all ProcessGroups. It is not meant to be used directly, but rather extended by subclasses. - -Bases: pybind11_object - -The type of the backend used for the process group. - -abort all operations and connections if supported by the backend - -allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], input_tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.AllgatherOptions = ) -> c10d::Work - -Allgathers the input tensors from all processes across the process group. - -See torch.distributed.all_gather() for more details. - -allgather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensor: torch.Tensor, timeout: datetime.timedelta | None = None) -> c10d::Work - -Allgathers the input tensors from all processes across the process group. - -See torch.distributed.all_gather() for more details. - -Allgathers the input tensors from all processes across the process group. - -See torch.distributed.all_gather() for more details. - -Allgathers the input tensors from all processes across the process group. - -See torch.distributed.all_gather() for more details. - -allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.AllreduceOptions = ) -> c10d::Work - -Allreduces the provided tensors across all processes in the process group. - -See torch.distributed.all_reduce() for more details. - -allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], op: torch._C._distributed_c10d.ReduceOp = , timeout: datetime.timedelta | None = None) -> c10d::Work - -Allreduces the provided tensors across all processes in the process group. - -See torch.distributed.all_reduce() for more details. - -allreduce(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, op: torch._C._distributed_c10d.ReduceOp = , timeout: datetime.timedelta | None = None) -> c10d::Work - -Allreduces the provided tensors across all processes in the process group. - -See torch.distributed.all_reduce() for more details. - -Allreduces the provided tensors across all processes in the process group. - -See torch.distributed.all_reduce() for more details. - -Alltoalls the input tensors from all processes across the process group. - -See torch.distributed.all_to_all() for more details. - -alltoall_base(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: torch.Tensor, output_split_sizes: collections.abc.Sequence[typing.SupportsInt], input_split_sizes: collections.abc.Sequence[typing.SupportsInt], opts: torch._C._distributed_c10d.AllToAllOptions = ) -> c10d::Work - -Alltoalls the input tensors from all processes across the process group. - -See torch.distributed.all_to_all() for more details. - -alltoall_base(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: torch.Tensor, output_split_sizes: collections.abc.Sequence[typing.SupportsInt], input_split_sizes: collections.abc.Sequence[typing.SupportsInt], timeout: datetime.timedelta | None = None) -> c10d::Work - -Alltoalls the input tensors from all processes across the process group. - -See torch.distributed.all_to_all() for more details. - -barrier(self: torch._C._distributed_c10d.ProcessGroup, opts: torch._C._distributed_c10d.BarrierOptions = ) -> c10d::Work - -then all leave the call together. - -See torch.distributed.barrier() for more details. - -barrier(self: torch._C._distributed_c10d.ProcessGroup, timeout: datetime.timedelta | None = None) -> c10d::Work - -then all leave the call together. - -See torch.distributed.barrier() for more details. - -broadcast(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.BroadcastOptions = ) -> c10d::Work - -Broadcasts the tensor to all processes in the process group. - -See torch.distributed.broadcast() for more details. - -broadcast(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work - -Broadcasts the tensor to all processes in the process group. - -See torch.distributed.broadcast() for more details. - -gather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], input_tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.GatherOptions = ) -> c10d::Work - -Gathers the input tensors from all processes across the process group. - -See torch.distributed.gather() for more details. - -gather(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensor: torch.Tensor, root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work - -Gathers the input tensors from all processes across the process group. - -See torch.distributed.gather() for more details. - -Get the store of this process group. - -Gets this process group description - -(Gets this process group name. It’s cluster unique) - -then all leave the call together. - -See torch.distributed.monitored_barrier() for more details. - -Get the name of this process group. - -Get the rank of this process group. - -Receives the tensor from the specified rank. - -See torch.distributed.recv() for more details. - -Receives the tensor from any source. - -See torch.distributed.recv() for more details. - -reduce(self: torch._C._distributed_c10d.ProcessGroup, tensors: collections.abc.Sequence[torch.Tensor], opts: torch._C._distributed_c10d.ReduceOptions = ) -> c10d::Work - -Reduces the provided tensors across all processes in the process group. - -See torch.distributed.reduce() for more details. - -reduce(self: torch._C._distributed_c10d.ProcessGroup, tensor: torch.Tensor, root: typing.SupportsInt, op: torch._C._distributed_c10d.ReduceOp = , timeout: datetime.timedelta | None = None) -> c10d::Work - -Reduces the provided tensors across all processes in the process group. - -See torch.distributed.reduce() for more details. - -reduce_scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], opts: torch._C._distributed_c10d.ReduceScatterOptions = ) -> c10d::Work - -Reduces and scatters the input tensors from all processes across the process group. - -See torch.distributed.reduce_scatter() for more details. - -reduce_scatter(self: torch._C._distributed_c10d.ProcessGroup, output: torch.Tensor, input: collections.abc.Sequence[torch.Tensor], op: torch._C._distributed_c10d.ReduceOp = , timeout: datetime.timedelta | None = None) -> c10d::Work - -Reduces and scatters the input tensors from all processes across the process group. - -See torch.distributed.reduce_scatter() for more details. - -Reduces and scatters the input tensors from all processes across the process group. - -See torch.distributed.reduce_scatter() for more details. - -scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensors: collections.abc.Sequence[torch.Tensor], input_tensors: collections.abc.Sequence[collections.abc.Sequence[torch.Tensor]], opts: torch._C._distributed_c10d.ScatterOptions = ) -> c10d::Work - -Scatters the input tensors from all processes across the process group. - -See torch.distributed.scatter() for more details. - -scatter(self: torch._C._distributed_c10d.ProcessGroup, output_tensor: torch.Tensor, input_tensors: collections.abc.Sequence[torch.Tensor], root: typing.SupportsInt, timeout: datetime.timedelta | None = None) -> c10d::Work - -Scatters the input tensors from all processes across the process group. - -See torch.distributed.scatter() for more details. - -Sends the tensor to the specified rank. - -See torch.distributed.send() for more details. - -Sets the default timeout for all future operations. - -shutdown the process group - -Get the size of this process group. - -Protocol for process group factories. - -Get the current process group. Thread local method. - -The current process group. - -Create a new process group with the given backend and options. This group is independent and will not be globally registered and thus not usable via the standard torch.distributed.* APIs. - -backend (str) – The backend to use for the process group. - -timeout (timedelta) – The timeout for collective operations. - -device (Union[str, device]) – The device to use for the process group. - -**kwargs (object) – All remaining arguments are passed to the backend constructor. See the backend specific documentation for details. - -Context manager for process groups. Thread local method. - -pg (ProcessGroup) – The process group to use. - -Generator[None, None, None] - -Register a new process group backend. - -name (str) – The name of the backend. - -func (ProcessGroupFactory) – The function to create the process group. - ---- - -## torch.distributed.fsdp.fully_shard# - -**URL:** https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html - -**Contents:** -- torch.distributed.fsdp.fully_shard# -- PyTorch FSDP2 (fully_shard)# - -Created On: Dec 04, 2024 | Last Updated On: Jun 16, 2025 - -PyTorch FSDP2 (RFC) provides a fully sharded data parallelism (FSDP) implementation targeting performant eager-mode while using per-parameter sharding for improved usability - -See the Getting Started with FSDP2 tutorial for more information. - -If you are currently using FSDP1, consider migrating to FSDP2 using our migration guide. - -The user contract for fully_shard(model) is as follows - -For model initialization, fully_shard converts model.parameters() from plain torch.Tensor to DTensor in-place. The parameters are moved to the appropriate device according to the device mesh. - -Before forward and backward passes, pre-forward/backward hooks are responsible for all-gathering the parameters and converting model.parameters() from DTensor to plain torch.Tensor. - -After forward and backward passes, post-forward/backward hooks free the unsharded parameters (no communication needed) and convert model.parameters() from plain torch.Tensor back to DTensor. - -For the optimizer, it must be initialized with the DTensor model.parameters(), and the optimizer step should be performed on DTensor parameters. - -Call model(input) instead of model.forward(input) to trigger pre-forward hooks to all-gather parameters. To make model.forward(input) work, users must either call model.unshard() explicitly or use register_fsdp_forward_method(model, "forward") to register the forward method for hooking. - -fully_shard groups parameters together for a single all-gather. User should apply fully_shard in a bottom-up manner. For example, in a Transformer model, fully_shard should be applied to each layer before applying it to the root model. When applied to the root model, fully_shard excludes model.parameters() from each layer and groups the remaining parameters (e.g., embeddings, output projection) into a single all-gather group. - -type(model) is “unioned” with FSDPModule in-place. For example, if model is originally of type nn.Linear, then fully_shard changes type(model) from nn.Linear to FSDPLinear in-place. FSDPLinear is an instance of both nn.Linear and FSDPModule. It retains all methods of nn.Linear while also exposing FSDP2-specific APIs under FSDPModule, such as reshard() and unshard(). - -Fully Qualified Names (FQNs) for parameters remain unchanged. If we call model.state_dict(), the FQNs are the same before and after applying fully_shard. This is because fully_shard does not wrap the module but only registers hooks to the original module. - -Compared to PyTorch FSDP1 (FullyShardedDataParallel): - -FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1’s flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making reasoning about what data is present on each worker and resharding to different parallelisms complex. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1. - -FSDP2 implements a different memory management approach to handle the multi-stream usages that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU like in FSDP1’s limit_all_gathers=True. - -FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details. - -FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensor s to full state dicts themselves using DTensor APIs like DTensor.full_tensor() or by using higher-level APIs like PyTorch Distributed Checkpoint ‘s distributed state dict APIs. Also, some other args have been removed; see here for details. - -The frontend API is fully_shard that can be called on a module: - -Apply fully sharded data parallelism (FSDP) to module, where FSDP shards module parameters, gradients, and optimizer states across data parallel workers to save memory at the cost of communication. - -At initialization, FSDP shards the module’s parameters across the data parallel workers given by mesh. Before forward, FSDP all-gathers the sharded parameters across the data-parallel workers to get the unsharded parameters for forward computation. If reshard_after_forward is True, then FSDP frees the unsharded parameters after forward and re-all-gathers them in backward before gradient computation. After gradient computation, FSDP frees the unsharded parameters and reduce-scatters the unsharded gradients across data-parallel workers. - -This implementation represents the sharded parameters as DTensor s sharded on dim-0, while the unsharded parameters will be like the original parameters on module (e.g. torch.Tensor if originally torch.Tensor). A module forward pre-hook on module all-gathers the parameters, and a module forward hook on module frees them (if needed). Similar backward hooks all-gather parameters and later free parameters and reduce-scatter gradients. - -Since grouping multiple tensors together for one collective is critical for communication efficiency, this implementation makes this grouping first class. Calling fully_shard() on module constructs one group that includes the parameters in module.parameters() except those already assigned to a group from an earlier call on a submodule. This means that fully_shard() should be called bottom-up on your model. Each group’s parameters are all-gathered in one collective, and its gradients are reduce-scattered in one collective. Partitioning the model into multiple groups (“layer by layer”) allows for peak memory savings and communication/computation overlap. Users generally should not call fully_shard() only on the topmost root module. - -module (Union[nn.Module, List[nn.Module]) – The module or modules to shard with FSDP and group together for communication. - -mesh (Optional[DeviceMesh]) – This data parallel mesh defines the sharding and device. If 1D, then parameters are fully sharded across the 1D mesh (FSDP) with (Shard(0),) placement. If 2D, then parameters are sharded across the 1st dim and replicated across the 0th dim (HSDP) with (Replicate(), Shard(0)) placement. The mesh’s device type gives the device type used for communication; if a CUDA or CUDA-like device type, then we use the current device. - -reshard_after_forward (Optional[Union[bool, int]]) – This controls the parameter behavior after forward and can trade off memory and communication: If True, then this reshards parameters after forward and re-all-gathers in backward. If False, then this keeps the unsharded parameters in memory after forward and avoids the all-gather in backward. For best performance, we usually set False for the root module, because the root module is typically required immediately when the backward pass begins. If None, it is set to True for non-root modules and False for root modules. If an int, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the mesh shard dim size (i.e. excluding 1 and the dim size itself). A choice may be the intra-node size (e.g. torch.cuda.device_count()). This allows the all-gather in backward to be over a smaller world size at the cost of higher memory usage than setting to True. After forward, the parameters registered to the module depend on to this: The registered parameters are the sharded parameters if True; unsharded parameters if False; and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For False or an int, this can be done by manually resharding via reshard(). - -This controls the parameter behavior after forward and can trade off memory and communication: - -If True, then this reshards parameters after forward and re-all-gathers in backward. - -If False, then this keeps the unsharded parameters in memory after forward and avoids the all-gather in backward. For best performance, we usually set False for the root module, because the root module is typically required immediately when the backward pass begins. - -If None, it is set to True for non-root modules and False for root modules. - -If an int, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the mesh shard dim size (i.e. excluding 1 and the dim size itself). A choice may be the intra-node size (e.g. torch.cuda.device_count()). This allows the all-gather in backward to be over a smaller world size at the cost of higher memory usage than setting to True. - -After forward, the parameters registered to the module depend on to this: The registered parameters are the sharded parameters if True; unsharded parameters if False; and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For False or an int, this can be done by manually resharding via reshard(). - -shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – This callable can be used to override the sharding placement for a parameter to shard a parameter on a dimension other than dim-0. If this callable returns a Shard placement (not None), then FSDP will shard according to that placement (e.g. Shard(1)). If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size. - -mp_policy (MixedPrecisionPolicy) – This controls the mixed precision policy, which offers parameter/reduction mixed precision for this module. See MixedPrecisionPolicy for details. - -offload_policy (OffloadPolicy) – This controls the offloading policy, which offers parameter/gradient/optimizer state offloading. See OffloadPolicy and its subclasses for details. - -ignored_params (Optional[set[nn.Parameter]]) – Optional(Set[nn.Parameter]): The set of parameters to be ignored by FSDP. They will not be sharded, nor moved to the device during init, nor have their gradients reduced in backward. - -The module with FSDP applied (in-place). - -Reshards the module’s parameters, freeing the unsharded parameters if they are allocated and registering the sharded parameters to the module. This method is not recursive. - -hook (Callable[[torch.Tensor], None]) – User-defined all-reduce hook with expected signature hook(reduce_output: torch.Tensor) -> None where reduce_output is the reduce-scatter output if only using FSDP or the all-reduce output if using native HSDP. - -stream (Optional[torch.cuda.Stream]) – Stream to run the all-reduce hook in. This should only be set if not using native HSDP. If using native HSDP, the hook will run in the internally defined all-reduce stream used by the native HSDP all-reduce. - -Sets whether the temporary staging buffers used to send and receive data over collective communications should be allocated using the custom optimized allocator provided by the ProcessGroup itself (if any). This might allow the ProcessGroup to be more efficient. For example, when using NCCL, this enables it to leverage zero-copy transfers over SHARP (for NVLink and/or InfiniBand). - -This cannot be used together with set_custom_all_gather() or set_custom_reduce_scatter() as those APIs allow for finer-grained control over each communication, and this method cannot determine their staging buffer allocation strategy. - -enable (bool) – Whether to turn on ProcessGroup allocation. - -Overrides the default all_gather communication behavior, to have better control over the communication and memory usage. See Comm and ReduceScatter for details. - -comm (AllGather) – Custom all-gather communication. - -Overrides the default reduce_scatter communication behavior, to have better control over the communication and memory usage. See Comm and ReduceScatter for details. - -comm (ReduceScatter) – Custom reduce_scatter communication. - -Sets whether to require the low-level collective communication primitives to exclusively use “sum”-type reductions, even if it comes at the cost of separate additional pre- or post-scaling operations. This is needed for example because NCCL currently supports zero-copy transfers only for this kind of collectives. - -NB: for MTIA devices, this is always implicitly enabled. - -NB: if set_all_reduce_hook is used under FSDP setup, the caller needs to ensure the custom all-reduce across FSDP units follow this strategy as well, as FSDP can no longer automatically handle that. - -enable (bool) – Whether to only ever use ReduceOp.SUM for comms. - -Sets a custom divide factor for the gradient reduction. This might use a custom reduce op using NCCL’s PreMulSum, which allows multiplying by the factor before reduction. - -factor (float) – Custom divide factor. - -Sets whether the next backward is the last one. On the last backward, FSDP waits on pending gradient reduction and clears internal data data structures for backward prefetching. This can be useful for microbatching. - -Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in backward. This overrides the default backward pretching implementation that prefetches the next FSDP module based on the reverse post-forward order. - -Passing a singleton list containing the previous FSDP module gives the same all-gather overlap behavior as the default overlap behavior. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory. - -modules (List[FSDPModule]) – FSDP modules to prefetch. - -Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in forward. The prefetching runs after this module’s all-gather copy-out. - -Passing a singleton list containing the next FSDP module gives the same all-gather overlap behavior as the default overlap behavior, except the prefetched all-gather is issued earlier from the CPU. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory. - -modules (List[FSDPModule]) – FSDP modules to prefetch. - -Sets a post-optimizer-step event for the root FSDP module to wait the all-gather streams on. - -By default, the root FSDP module waits the all-gather streams on the current stream to ensure that the optimizer step has finished before all-gathering. However, this may introduce false dependencies if there is unrelated computation after the optimizer step. This API allows the user to provide their own event to wait on. After the root waits on the event, the event is discarded, so this API should be called with a new event each iteration. - -event (torch.Event) – Event recorded after the optimizer step to wait all-gather streams on. - -Use set_gradient_divide_factor() instead - -Sets if the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP. - -Sets if the module should sync gradients. This can be used to implement gradient accumulation without communication. For HSDP, this controls both reduce-scatter and all-reduce together. This is the equivalence of no_sync in FSDP1. - -requires_gradient_sync (bool) – Whether to reduce gradients for the module’s parameters. - -recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module. - -Sets if the module should reshard parameters after backward. This can be used during gradient accumulation to trade off higher memory for reduced communication since the unsharded parameters do not need to be re-all-gathered before the next forward. - -reshard_after_backward (bool) – Whether to reshard parameters after backward. - -recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module. - -Sets if the module should reshard parameters after forward. This can be used to change the reshard_after_forward FSDP arg at runtime. For example, this can be used to set the FSDP root module’s value to True (since it is otherwise specially set to False), or it can set an FSDP module’s value to False for running evals and set back to True for training. - -reshard_after_forward (bool) – Whether to reshard parameters after forward. - -recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module. - -Sets whether the FSDP module’s parameters need to be unsharded in backward. This can be used in expert cases when the user knows that all parameters in this FSDP module’s parameter group are not needed for backward computation (e.g. embedding). - -Unshards the module’s parameters by allocating memory and all-gathering the parameters. This method is not recursive. The unshard follows the MixedPrecisionPolicy, so it will all-gather following param_dtype if set. - -async_op (bool) – If True, then returns a UnshardHandle that has a wait() method to wait on the unshard op. If False, then returns None and waits on the handle inside this function. - -Optional[UnshardHandle] - -If async_op=True, then FSDP will wait on the pending unshard in the module’s pre-forward for the user. The user only needs to call wait() explicitly if the wait should happen before pre-forward. - -A handle to wait on a FSDPModule.unshard() op. - -Waits on the unshard op. This ensures that the current stream can use the unsharded parameters, which are now registered to the module. - -Registers a method on module to be considered a forward method for FSDP. - -FSDP all-gathers parameters pre-forward and optionally frees parameters post-forward (depending on reshard_after_forward). FSDP only knows to do this for nn.Module.forward() by default. This function patches a user-specified method to run the pre/post-forward hooks before/after the method, respectively. If module is not an FSDPModule, then this is a no-op. - -module (nn.Module) – Module to register the forward method on. - -method_name (str) – Name of the forward method. - -This configures FSDP’s mixed precision. Unlike autocast, this applies mixed precision at the module level, not op level, which means low-precision activations are saved for backward and high-to-low-precision casts are incurred only at module boundaries. - -FSDP works well with module-level mixed precision since it keeps the high-precision sharded parameters in memory anyway. In other words, FSDP does not require any extra memory to keep a high-precision copy of the parameters for the optimizer step. - -param_dtype (Optional[torch.dtype]) – This specifies the dtype for the unsharded parameter and hence the dtype for forward/backward computation and the parameter all-gather. If this is None, then the unsharded parameter uses the original dtype. The optimizer step uses the sharded parameter in the original dtype. (Default: None) - -reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then the reduction uses the compute dtype. This can be used to run gradient reduction in full precision while using low precision for compute. If also gradient reduction is disabled via set_requires_gradient_sync(), then FSDP will accumulate gradients using reduce_dtype. (Default: None) - -output_dtype (Optional[torch.dtype]) – This specifies the dtype for casting floating-point forward outputs. This can be used to help implement cases where different modules have different mixed precision policies. (Default: None) - -cast_forward_inputs (bool) – This specifies whether FSDP should cast the forward’s floating-point input tensors to param_dtype or not. - -This base class represents the policy of no offloading and is only used as the default value for the offload_policy arg. - -This offload policy offloads parameters, gradients, and optimizer states to CPU. Sharded parameters are copied host-to-device before all-gather. The all-gathered parameters are freed according to reshard_after_forward. Sharded gradients are copied device-to-host in backward, and the optimizer step runs on CPU with CPU optimizer states. - -pin_memory (bool) – Whether to pin sharded parameter and gradient memory. Pinning memory allows both more efficient H2D/D2H copies and for the copies to overlap with compute. However, the pinned memory cannot be used by other processes. Set this to False if you have insufficient CPU memory. (Default: True) - ---- - -## Distributed communication package - torch.distributed# - -**URL:** https://pytorch.org/docs/stable/distributed.html - -**Contents:** -- Distributed communication package - torch.distributed# -- Backends# - - Backends that come with PyTorch# - - Which backend to use?# - - Common environment variables# - - Choosing the network interface to use# - - Other NCCL environment variables# -- Basics# -- Initialization# - - TCP initialization# - -Created On: Jul 12, 2017 | Last Updated On: Sep 04, 2025 - -Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training. - -torch.distributed supports four built-in backends, each with different capabilities. The table below shows which functions are available for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU while for XCCL to XPU GPU. - -MPI supports CUDA only if the implementation used to build PyTorch supports it. - -PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). By default for Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source. (e.g. building PyTorch on a host that has MPI installed.) - -As of PyTorch v1.8, Windows supports all collective communications backend but NCCL, If the init_method argument of init_process_group() points to a file it must adhere to the following schema: - -Local file system, init_method="file:///d:/tmp/some_file" - -Shared file system, init_method="file://////{machine_name}/{share_folder_name}/some_file" - -Same as on Linux platform, you can enable TcpStore by setting environment variables, MASTER_ADDR and MASTER_PORT. - -In the past, we were often asked: “which backend should I use?”. - -Use the NCCL backend for distributed training with CUDA GPU. - -Use the XCCL backend for distributed training with XPU GPU. - -Use the Gloo backend for distributed training with CPU. - -GPU hosts with InfiniBand interconnect - -Use NCCL, since it’s the only backend that currently supports InfiniBand and GPUDirect. - -GPU hosts with Ethernet interconnect - -Use NCCL, since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training. If you encounter any problem with NCCL, use Gloo as the fallback option. (Note that Gloo currently runs slower than NCCL for GPUs.) - -CPU hosts with InfiniBand interconnect - -If your InfiniBand has enabled IP over IB, use Gloo, otherwise, use MPI instead. We are planning on adding InfiniBand support for Gloo in the upcoming releases. - -CPU hosts with Ethernet interconnect - -Use Gloo, unless you have specific reasons to use MPI. - -By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend): - -NCCL_SOCKET_IFNAME, for example export NCCL_SOCKET_IFNAME=eth0 - -GLOO_SOCKET_IFNAME, for example export GLOO_SOCKET_IFNAME=eth0 - -If you’re using the Gloo backend, you can specify multiple interfaces by separating them by a comma, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. It is imperative that all processes specify the same number of interfaces in this variable. - -Debugging - in case of NCCL failure, you can set NCCL_DEBUG=INFO to print an explicit warning message as well as basic NCCL initialization information. - -You may also use NCCL_DEBUG_SUBSYS to get more details about a specific aspect of NCCL. For example, NCCL_DEBUG_SUBSYS=COLL would print logs of collective calls, which may be helpful when debugging hangs, especially those caused by collective type or message size mismatch. In case of topology detection failure, it would be helpful to set NCCL_DEBUG_SUBSYS=GRAPH to inspect the detailed detection result and save as reference if further help from NCCL team is needed. - -Performance tuning - NCCL performs automatic tuning based on its topology detection to save users’ tuning effort. On some socket-based systems, users may still try tuning NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD to increase socket network bandwidth. These two environment variables have been pre-tuned by NCCL for some cloud providers, such as AWS or GCP. - -For a full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation - -You can tune NCCL communicators even further using torch.distributed.ProcessGroupNCCL.NCCLConfig and torch.distributed.ProcessGroupNCCL.Options. Learn more about them using help (e.g. help(torch.distributed.ProcessGroupNCCL.NCCLConfig)) in the interpreter. - -The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel() in that it supports multiple network-connected machines and in that the user must explicitly launch a separate copy of the main training script for each process. - -In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel(): - -Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes. - -Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components. - -The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined. - -Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs. - -Return True if the distributed package is available. - -Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for MacOS. - -Initialize the default distributed process group. - -This will also initialize the distributed package. - -Specify store, rank, and world_size explicitly. - -Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them. - -If neither is specified, init_method is assumed to be “env://”. - -backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, nccl, ucc, xccl or one that is registered by a third-party plugin. Since 2.6, if backend is not provided, c10d will use a backend registered for the device type indicated by the device_id kwarg (if provided). The known default registrations today are: nccl for cuda, gloo for cpu, xccl for xpu. If neither backend nor device_id is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or cpu). This field can be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlock or NCCL invalid usage. ucc backend is experimental. Default backend for the device can be queried with get_default_backend_for_device(). - -init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store. - -world_size (int, optional) – Number of processes participating in the job. Required if store is specified. - -rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified. - -store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method. - -timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. This is the duration after which collectives will be aborted asynchronously and the process will crash. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. - -group_name (str, optional, deprecated) – Group name. This argument is ignored - -pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. As of now, the only options we support is ProcessGroupNCCL.Options for the nccl backend, is_high_priority_stream can be specified so that the nccl backend can pick up high priority cuda streams when there’re compute kernels waiting. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t - -device_id (torch.device | int, optional) – a single, specific device this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is immediately formed (calling ncclCommInit* immediately rather than the normal lazy call) and sub-groups will use ncclCommSplit when possible to avoid unnecessary overhead of group creation. If you want to know NCCL initialization error early, you can also use this field. If an int is provided, the API assumes that the accelerator type at compile time will be used. - -To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI. - -Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “:,:”, e.g. “cpu:gloo,cuda:custom_backend”. - -Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters. - -This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i]. - -init_device_mesh follows SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. Inconsistent mesh_shape may lead to hanging. - -If no process group is found, init_device_mesh will initialize distributed process group/groups required for distributed communications behind the scene. - -device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”, “xpu”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed. - -mesh_shape (Tuple[int]) – A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. - -mesh_dim_names (Tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. - -backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional) – Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value). - -A DeviceMesh object representing the device layout. - -Check if the default process group has been initialized. - -Check if the MPI backend is available. - -Check if the NCCL backend is available. - -Check if the Gloo backend is available. - -Check if the XCCL backend is available. - -Check whether this process was launched with torch.distributed.elastic (aka torchelastic). - -The existence of TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id which is always a non-null value indicating the job id for peer discovery purposes.. - -Return the default backend for the given device. - -device (Union[str, torch.device]) – The device to get the default backend for. - -The default backend for the given device as a lower case string. - -Currently three initialization methods are supported: - -There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks. - -Note that multicast address is not supported anymore in the latest distributed package. group_name is deprecated as well. - -Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but will not delete the file. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name. - -Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well. - -This method assumes that the file system supports locking using fcntl - most local systems and NFS support it. - -This method will always create the file and try its best to clean up and remove the file at the end of the program. In other words, each initialization with the file init method will need a brand new empty file in order for the initialization to succeed. If the same file used by the previous initialization (which happens not to get cleaned up) is used again, this is unexpected behavior and can often cause deadlocks and failures. Therefore, even though this method will try its best to clean up the file, if the auto-delete happens to be unsuccessful, it is your responsibility to ensure that the file is removed at the end of the training to prevent the same file to be reused again during the next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. In other words, if the file is not removed/cleaned up and you call init_process_group() again on that file, failures are expected. The rule of thumb here is that, make sure that the file is non-existent or empty every time init_process_group() is called. - -This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are: - -MASTER_PORT - required; has to be a free port on machine with rank 0 - -MASTER_ADDR - required (except for rank 0); address of rank 0 node - -WORLD_SIZE - required; can be set either here, or in a call to init function - -RANK - required; can be set either here, or in a call to init function - -The machine with rank 0 will be used to set up all connections. - -This is the default method, meaning that init_method does not have to be specified (or can be env://). - -TORCH_GLOO_LAZY_INIT - establishes connections on demand rather than using a full mesh which can greatly improve initialization time for non all2all operations. - -Once torch.distributed.init_process_group() was run, the following functions can be used. To check whether the process group has already been initialized use torch.distributed.is_initialized(). - -An enum-like class for backends. - -Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. - -The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL. - -This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo". - -The entry Backend.UNDEFINED is present but only used as initial value of some fields. Users should neither use it directly nor assume its existence. - -Register a new backend with the given name and instantiating function. - -This class method is used by 3rd party ProcessGroup extension to register new backends. - -name (str) – Backend name of the ProcessGroup extension. It should match the one in init_process_group(). - -func (function) – Function handler that instantiates the backend. The function should be implemented in the backend extension and takes four arguments, including store, rank, world_size, and timeout. - -extended_api (bool, optional) – Whether the backend supports extended argument structure. Default: False. If set to True, the backend will get an instance of c10d::DistributedBackendOptions, and a process group options object as defined by the backend implementation. - -device (str or list of str, optional) – device type this backend supports, e.g. “cpu”, “cuda”, etc. If None, assuming both “cpu” and “cuda” - -This support of 3rd party backend is experimental and subject to change. - -Return the backend of the given process group. - -group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group. - -The backend of the given process group as a lower case string. - -Return the rank of the current process in the provided group, default otherwise. - -Rank is a unique identifier assigned to each process within a distributed process group. They are always consecutive integers ranging from 0 to world_size. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -The rank of the process group -1, if not part of the group - -Return the number of processes in the current process group. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -The world size of the process group -1, if not part of the group - -It is important to clean up resources on exit by calling destroy_process_group(). - -The simplest pattern to follow is to destroy every process group and backend by calling destroy_process_group() with the default value of None for the group argument, at a point in the training script where communications are no longer needed, usually near the end of main(). The call should be made once per trainer-process, not at the outer process-launcher level. - -if destroy_process_group() is not called by all ranks in a pg within the timeout duration, especially when there are multiple process-groups in the application e.g. for N-D parallelism, hangs on exit are possible. This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, which must be called collectively, but the order of calling ProcessGroupNCCL’s destructor if called by python’s GC is not deterministic. Calling destroy_process_group() helps by ensuring ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort during ProcessGroupNCCL’s destructor. - -destroy_process_group can also be used to destroy individual process groups. One use case could be fault tolerant training, where a process group may be destroyed and then a new one initialized during runtime. In this case, it’s critical to synchronize the trainer processes using some means other than torch.distributed primitives _after_ calling destroy and before subsequently initializing. This behavior is currently unsupported/untested, due to the difficulty of achieving this synchronization, and is considered a known issue. Please file a github issue or RFC if this is a use case that’s blocking you. - -By default collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns). - -Create a new distributed group. - -This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes. - -Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks. - -If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering. - -When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other cuda streams by calling work.wait() before using another process group. - -See Using multiple NCCL communicators concurrently for more details. - -ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None. - -timeout (timedelta, optional) – see init_process_group for details and default value. - -backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If None is passed in, the backend corresponding to the default process group will be used. Default is None. - -pg_options (ProcessGroupOptions, optional) – process group options specifying what additional options need to be passed in during the construction of specific process groups. i.e. for the nccl backend, is_high_priority_stream can be specified so that process group can pick up high priority cuda streams. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization (bool, optional): perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don’t need to call into API and don’t join the barrier. - -group_desc (str, optional) – a string to describe the process group. - -device_id (torch.device, optional) – a single, specific device to “bind” this process to, The new_group call will try to initialize a communication backend immediately for the device if this field is given. - -A handle of distributed group that can be given to collective calls or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks. - -N.B. use_local_synchronization doesn’t work with MPI. - -N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior as non-member ranks don’t join the group barrier(). - -N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order. - -Translate a global rank into a group rank. - -global_rank must be part of group otherwise this raises RuntimeError. - -group (ProcessGroup) – ProcessGroup to find the relative rank. - -global_rank (int) – Global rank to query. - -Group rank of global_rank relative to group - -N.B. calling this function on the default process group returns identity - -Translate a group rank into a global rank. - -group_rank must be part of group otherwise this raises RuntimeError. - -group (ProcessGroup) – ProcessGroup to find the global rank from. - -group_rank (int) – Group rank to query. - -Global rank of group_rank relative to group - -N.B. calling this function on the default process group returns identity - -Get all ranks associated with group. - -group (Optional[ProcessGroup]) – ProcessGroup to get all ranks from. If None, the default process group will be used. - -List of global ranks ordered by group rank. - -DeviceMesh is a higher level abstraction that manages process groups (or NCCL communicators). It allows user to easily create inter node and intra node process groups without worrying about how to set up the ranks correctly for different sub process groups, and it helps manage those distributed process group easily. init_device_mesh() function can be used to create new DeviceMesh, with a mesh shape describing the device topology. - -DeviceMesh represents a mesh of devices, where layout of devices could be represented as a n-d dimension array, and each value of the n-d dimensional array is the global id of the default process group ranks. - -DeviceMesh could be used to setup the N dimensional device connections across the cluster, and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects already (i.e. if user call torch.cuda.set_device before the DeviceMesh initialization), and will select/set the device for the current process if user does not set the device beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. - -DeviceMesh can also be used as a context manager when using together with DTensor APIs. - -DeviceMesh follows SPMD programming model, which means the same PyTorch Python program is running on all processes/ranks in the cluster. Therefore, users need to make sure the mesh array (which describes the layout of devices) should be identical across all ranks. Inconsistent mesh will lead to silent hang. - -device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. - -mesh (ndarray) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. - -A DeviceMesh object representing the device layout. - -The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. A reduction over the first dimension of mesh will reduce across columns (0, 4), .. and (3, 7), a reduction over the second dimension of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). - -Constructs a DeviceMesh with device_type from an existing ProcessGroup or a list of existing ProcessGroup. - -The constructed device mesh has number of dimensions equal to the number of groups passed. For example, if a single process group is passed in, the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, the resulted DeviceMesh is a 2D mesh. - -If more than one group is passed, then the mesh and mesh_dim_names arguments are required. The order of the process groups passed in determines the topology of the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. The mesh tensor passed in must have the same number of dimensions as the number of process groups passed in, and the order of the dimensions in the mesh tensor must match the order in the process groups passed in. - -group (ProcessGroup or list[ProcessGroup]) – the existing ProcessGroup or a list of existing ProcessGroups. - -device_type (str) – The device type of the mesh. Currently supports: “cpu”, “cuda/cuda-like”. Passing in a device type with a GPU index, such as “cuda:0”, is not allowed. - -mesh (torch.Tensor or ArrayLike, optional) – A multi-dimensional array or an integer tensor describing the layout of devices, where the IDs are global IDs of the default process group. Default is None. - -mesh_dim_names (tuple[str], optional) – A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of mesh_shape. Each string in mesh_dim_names must be unique. Default is None. - -A DeviceMesh object representing the device layout. - -Returns a list of ProcessGroups for all mesh dimensions. - -A list of ProcessGroup object. - -list[torch.distributed.distributed_c10d.ProcessGroup] - -Return the relative indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None. - -Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. - -mesh_dim (str/python:int, optional) – it can be the name of the mesh dimension or the index - -None. (of the mesh dimension. Default is) – - -A ProcessGroup object. - -Returns the local rank of the given mesh_dim of the DeviceMesh. - -mesh_dim (str/python:int, optional) – it can be the name of the mesh dimension or the index - -None. (of the mesh dimension. Default is) – - -An integer denotes the local rank. - -The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. - -Returns the current global rank. - -Send a tensor synchronously. - -tag is not supported with the NCCL backend. - -tensor (Tensor) – Tensor to send. - -dst (int) – Destination rank on global process group (regardless of group argument). Destination rank should not be the same as the rank of the current process. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -tag (int, optional) – Tag to match send with remote recv - -group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst. - -Receives a tensor synchronously. - -tag is not supported with the NCCL backend. - -tensor (Tensor) – Tensor to fill with received data. - -src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -tag (int, optional) – Tag to match recv with remote send - -group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. - -Sender rank -1, if not part of the group - -isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified as they should never be created manually, but they are guaranteed to support two methods: - -is_completed() - returns True if the operation has finished - -wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns. - -Send a tensor asynchronously. - -Modifying tensor before the request completes causes undefined behavior. - -tag is not supported with the NCCL backend. - -Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self. - -tensor (Tensor) – Tensor to send. - -dst (int) – Destination rank on global process group (regardless of group argument) - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -tag (int, optional) – Tag to match send with remote recv - -group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst - -A distributed request object. None, if not part of the group - -Receives a tensor asynchronously. - -tag is not supported with the NCCL backend. - -Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self. - -tensor (Tensor) – Tensor to fill with received data. - -src (int, optional) – Source rank on global process group (regardless of group argument). Will receive from any process if unspecified. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -tag (int, optional) – Tag to match recv with remote send - -group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. - -A distributed request object. None, if not part of the group - -Sends picklable objects in object_list synchronously. - -Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent. - -object_list (List[Any]) – List of input objects to sent. Each object must be picklable. Receiver must provide lists of equal sizes. - -dst (int) – Destination rank to send object_list to. Destination rank is based on global process group (regardless of group argument) - -group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. - -device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before sending. Default is None. - -group_dst (int, optional) – Destination rank on group. Must specify one of dst and group_dst but not both - -use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False. - -For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -send_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling send_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead. - -Receives picklable objects in object_list synchronously. - -Similar to recv(), but can receive Python objects. - -object_list (List[Any]) – List of objects to receive into. Must provide a list of sizes equal to the size of the list being sent. - -src (int, optional) – Source rank from which to recv object_list. Source rank is based on global process group (regardless of group argument) Will receive from any rank if set to None. Default is None. - -group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. - -device (torch.device, optional) – If not None, receives on this device. Default is None. - -group_src (int, optional) – Destination rank on group. Invalid to specify both src and group_src. - -use_batch (bool, optional) – If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is False. - -Sender rank. -1 if rank is not part of the group. If rank is part of the group, object_list will contain the sent objects from src rank. - -For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -recv_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling recv_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead. - -Send or Receive a batch of tensors asynchronously and return a list of requests. - -Process each of the operations in p2p_op_list and return the corresponding requests. NCCL, Gloo, and UCC backend are currently supported. - -p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – A list of point-to-point operations(type of each operator is torch.distributed.P2POp). The order of the isend/irecv in the list matters and it needs to match with corresponding isend/irecv on the remote end. - -A list of distributed request objects returned by calling the corresponding op in the op_list. - -list[torch.distributed.distributed_c10d.Work] - -Note that when this API is used with the NCCL PG backend, users must set the current GPU device with torch.cuda.set_device, otherwise it will lead to unexpected hang issues. - -In addition, if this API is the first collective call in the group passed to dist.P2POp, all ranks of the group must participate in this API call; otherwise, the behavior is undefined. If this API call is not the first collective call in the group, batched P2P operations involving only a subset of ranks of the group are allowed. - -A class to build point-to-point operations for batch_isend_irecv. - -This class builds the type of P2P operation, communication buffer, peer rank, Process Group, and tag. Instances of this class will be passed to batch_isend_irecv for point-to-point communications. - -op (Callable) – A function to send data to or receive data from a peer process. The type of op is either torch.distributed.isend or torch.distributed.irecv. - -tensor (Tensor) – Tensor to send or receive. - -peer (int, optional) – Destination or source rank. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -tag (int, optional) – Tag to match send with recv. - -group_peer (int, optional) – Destination or source rank. - -Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective: - -Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the below script to see examples of differences in these semantics for CPU and CUDA operations. - -Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually and it is guaranteed to support two methods: - -is_completed() - in the case of CPU collectives, returns True if completed. In the case of CUDA operations, returns True if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization. - -wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU). - -get_future() - returns torch._C.Future object. Supported for NCCL, also supported for most operations on GLOO and MPI, except for peer to peer operations. Note: as we continue adopting Futures and merging APIs, get_future() call might become redundant. - -The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. It shows the explicit need to synchronize when using collective outputs on different CUDA streams: - -Broadcasts the tensor to the whole group. - -tensor must have the same number of elements in all processes participating in the collective. - -tensor (Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise. - -src (int) – Source rank on global process group (regardless of group argument). - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -group_src (int) – Source rank on group. Must specify one of group_src and src but not both. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Broadcasts picklable objects in object_list to the whole group. - -Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcasted. - -object_list (List[Any]) – List of input objects to broadcast. Each object must be picklable. Only objects on the src rank will be broadcast, but each rank must provide lists of equal sizes. - -src (int) – Source rank from which to broadcast object_list. Source rank is based on global process group (regardless of group argument) - -group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. - -device (torch.device, optional) – If not None, the objects are serialized and converted to tensors which are moved to the device before broadcasting. Default is None. - -group_src (int) – Source rank on group. Must not specify one of group_src and src but not both. - -None. If rank is part of the group, object_list will contain the broadcasted objects from src rank. - -For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call. - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -broadcast_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling broadcast_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead. - -Reduces the tensor data across all machines in a way that all get the final result. - -After the call tensor is going to be bitwise identical in all processes. - -Complex tensors are supported. - -tensor (Tensor) – Input and output of the collective. The function operates in-place. - -op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Reduces the tensor data across all machines. - -Only the process with rank dst is going to receive the final result. - -tensor (Tensor) – Input and output of the collective. The function operates in-place. - -dst (int) – Destination rank on global process group (regardless of group argument) - -op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -group_dst (int) – Destination rank on group. Must specify one of group_dst and dst but not both. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Gathers tensors from the whole group in a list. - -Complex and uneven sized tensors are supported. - -tensor_list (list[Tensor]) – Output list. It should contain correctly-sized tensors to be used for output of the collective. Uneven sized tensors are supported. - -tensor (Tensor) – Tensor to be broadcast from current process. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Gather tensors from all ranks and put them in a single output tensor. - -This function requires all tensors to be the same size on each process. - -output_tensor (Tensor) – Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the following forms: (i) a concatenation of all the input tensors along the primary dimension; for definition of “concatenation”, see torch.cat(); (ii) a stack of all the input tensors along the primary dimension; for definition of “stack”, see torch.stack(). Examples below may better explain the supported output forms. - -input_tensor (Tensor) – Tensor to be gathered from current rank. Different from the all_gather API, the input tensors in this API must have the same size across all ranks. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Gathers picklable objects from the whole group into a list. - -Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered. - -object_list (list[Any]) – Output list. It should be correctly sized as the size of the group for this collective and will contain the output. - -obj (Any) – Pickable Python object to be broadcast from current process. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None. - -None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed in object_list will be unmodified. - -Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call. - -For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -all_gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using all_gather() instead. - -Gathers a list of tensors in a single process. - -This function requires all tensors to be the same size on each process. - -tensor (Tensor) – Input tensor. - -gather_list (list[Tensor], optional) – List of appropriately, same-sized tensors to use for gathered data (default is None, must be specified on the destination rank) - -dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0) - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Note that all Tensors in gather_list must have the same size. - -Gathers picklable objects from the whole group in a single process. - -Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered. - -obj (Any) – Input object. Must be picklable. - -object_gather_list (list[Any]) – Output list. On the dst rank, it should be correctly sized as the size of the group for this collective and will contain the output. Must be None on non-dst ranks. (default is None) - -dst (int, optional) – Destination rank on global process group (regardless of group argument). (If both dst and group_dst are None, default is global rank 0) - -group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. - -group_dst (int, optional) – Destination rank on group. Invalid to specify both dst and group_dst - -None. On the dst rank, object_gather_list will contain the output of the collective. - -Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call. - -For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using gather() instead. - -Scatters a list of tensors to all processes in a group. - -Each process will receive exactly one tensor and store its data in the tensor argument. - -Complex tensors are supported. - -tensor (Tensor) – Output tensor. - -scatter_list (list[Tensor]) – List of tensors to scatter (default is None, must be specified on the source rank) - -src (int) – Source rank on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0) - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -Note that all Tensors in scatter_list must have the same size. - -Scatters picklable objects in scatter_object_input_list to the whole group. - -Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered. - -scatter_object_output_list (List[Any]) – Non-empty list whose first element will store the object scattered to this rank. - -scatter_object_input_list (List[Any], optional) – List of input objects to scatter. Each object must be picklable. Only objects on the src rank will be scattered, and the argument can be None for non-src ranks. - -src (int) – Source rank from which to scatter scatter_object_input_list. Source rank is based on global process group (regardless of group argument). (If both src and group_src are None, default is global rank 0) - -group (Optional[ProcessGroup]) – (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. Default is None. - -group_src (int, optional) – Source rank on group. Invalid to specify both src and group_src - -None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank. - -Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call. - -Object collectives have a number of serious performance and scalability limitations. See Object collectives for details. - -scatter_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. - -Calling scatter_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using scatter() instead. - -Reduces, then scatters a list of tensors to all processes in a group. - -output (Tensor) – Output tensor. - -input_list (list[Tensor]) – List of tensors to reduce and scatter. - -op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. - -Reduces, then scatters a tensor to all ranks in a group. - -output (Tensor) – Output tensor. It should have the same size across all ranks. - -input (Tensor) – Input tensor to be reduced and scattered. Its size should be output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For definition of “concatenation”, see torch.cat(). For definition of “stack”, see torch.stack(). - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. - -Split input tensor and then scatter the split list to all processes in a group. - -Later the received tensors are concatenated from all the processes in the group and returned as a single output tensor. - -Complex tensors are supported. - -output (Tensor) – Gathered concatenated output tensor. - -input (Tensor) – Input tensor to scatter. - -output_split_sizes – (list[Int], optional): Output split sizes for dim 0 if specified None or empty, dim 0 of output tensor must divide equally by world_size. - -input_split_sizes – (list[Int], optional): Input split sizes for dim 0 if specified None or empty, dim 0 of input tensor must divide equally by world_size. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. - -all_to_all_single is experimental and subject to change. - -Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. - -Complex tensors are supported. - -output_tensor_list (list[Tensor]) – List of tensors to be gathered one per rank. - -input_tensor_list (list[Tensor]) – List of tensors to scatter one per rank. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group. - -all_to_all is experimental and subject to change. - -Synchronize all processes. - -This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait(). - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -async_op (bool, optional) – Whether this op should be an async op - -device_ids ([int], optional) – List of device/GPU ids. Only one id is expected. - -Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - -ProcessGroupNCCL now blocks the cpu thread till the completion of the barrier collective. - -ProcessGroupNCCL implements barrier as an all_reduce of a 1-element tensor. A device must be chosen for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to device_ids arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device that was first used with this process group, if another collective with tensor inputs has been performed, (4) the device index indicated by the global rank mod local device count. - -Synchronize processes similar to torch.distributed.barrier, but consider a configurable timeout. - -It is able to report ranks that did not pass this barrier within the provided timeout. Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. Rank 0 will block until all send /recv from other ranks are processed, and will report failures for ranks that failed to respond in time. Note that if one rank does not reach the monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. - -This collective will block all processes/ranks in the group, until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only be used for debugging or scenarios that require full synchronization points on the host-side. For debugging purposes, this barrier can be inserted before the application’s collective calls to check if any ranks are desynchronized. - -Note that this collective is only supported with the GLOO backend. - -group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. - -timeout (datetime.timedelta, optional) – Timeout for monitored_barrier. If None, the default process group timeout will be used. - -wait_all_ranks (bool, optional) – Whether to collect all failed ranks or not. By default, this is False and monitored_barrier on rank 0 will throw on the first failed rank it encounters in order to fail fast. By setting wait_all_ranks=True monitored_barrier will collect all failed ranks and throw an error containing information about all failed ranks. - -A Work object represents the handle to a pending asynchronous operation in PyTorch’s distributed package. It is returned by non-blocking collective operations, such as dist.all_reduce(tensor, async_op=True). - -Blocks the currently active GPU stream on the operation to complete. For GPU based collectives this is equivalent to synchronize. For CPU initiated collectives such as with Gloo this will block the CUDA stream until the operation is complete. - -This returns immediately in all cases. - -To check whether an operation was successful you should check the Work object result asynchronously. - -A torch.futures.Future object which is associated with the completion of the Work. As an example, a future object can be retrieved by fut = process_group.allreduce(tensors).get_future(). - -Below is an example of a simple allreduce DDP communication hook that uses get_future API to retrieve a Future associated with the completion of allreduce. - -get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future. - -In the example above, allreduce work will be done on GPU using NCCL backend, fut.wait() will return after synchronizing the appropriate NCCL streams with PyTorch’s current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that CUDAFuture does not support TORCH_NCCL_BLOCKING_WAIT flag or NCCL’s barrier(). In addition, if a callback function was added by fut.then(), it will wait until WorkNCCL’s NCCL streams synchronize with ProcessGroupNCCL’s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. fut.then() will return another CUDAFuture that holds the return value of the callback and a CUDAEvent that recorded the callback stream. - -For CPU work, fut.done() returns true when work has been completed and value() tensors are ready. - -For GPU work, fut.done() returns true only whether the operation has been enqueued. - -For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), fut.done() returns true when tensors have arrived on respective nodes, but not yet necessarily synched on respective GPUs (similarly to GPU work). - -A torch.futures.Future object of int type which maps to the enum type of WorkResult As an example, a future object can be retrieved by fut = process_group.allreduce(tensor).get_future_result(). - -users can use fut.wait() to blocking wait for the completion of the work and get the WorkResult by fut.value(). Also, users can use fut.then(call_back_func) to register a callback function to be called when the work is completed, without blocking the current thread. - -get_future_result API supports NCCL - -In normal cases, users do not need to set the timeout. calling wait() is the same as calling synchronize(): Letting the current stream block on the completion of the NCCL work. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or timed out. If timeout, exception will be thrown. - -An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR, and PREMUL_SUM. - -BAND, BOR, and BXOR reductions are not available when using the NCCL backend. - -AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later. - -PREMUL_SUM multiplies inputs by a given scalar locally before reduction. PREMUL_SUM is only available with the NCCL backend, and only available for NCCL versions 2.11 or later. Users are supposed to use torch.distributed._make_nccl_premul_sum. - -Additionally, MAX, MIN and PRODUCT are not supported for complex tensors. - -The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce(). - -This class does not support __members__ property. - -Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX. - -ReduceOp is recommended to use instead. - -The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method.) There are 3 choices for Key-Value Stores: TCPStore, FileStore, and HashStore. - -Base class for all store implementations, such as the 3 provided by PyTorch distributed: (TCPStore, FileStore, and HashStore). - -The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception. - -key (str) – The key in the store whose counter will be incremented. - -amount (int) – The quantity by which the counter will be incremented. - -Append the key-value pair into the store based on the supplied key and value. If key does not exists in the store, it will be created. - -key (str) – The key to be appended to the store. - -value (str) – The value associated with key to be added to the store. - -The call to check whether a given list of keys have value stored in the store. This call immediately returns in normal cases but still suffers from some edge deadlock cases, e.g, calling check after TCPStore has been destroyed. Calling check() with a list of keys that one wants to check whether stored in the store or not. - -keys (list[str]) – The keys to query whether stored in the store. - -Clones the store and returns a new object that points to the same underlying store. The returned store can be used concurrently with the original object. This is intended to provide a safe way to use a store from multiple threads by cloning one store per thread. - -Inserts the key-value pair into the store based on the supplied key and performs comparison between expected_value and desired_value before inserting. desired_value will only be set if expected_value for the key already exists in the store or if expected_value is an empty string. - -key (str) – The key to be checked in the store. - -expected_value (str) – The value associated with key to be checked before insertion. - -desired_value (str) – The value associated with key to be added to the store. - -Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not. - -The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception. - -key (str) – The key to be deleted from the store - -True if key was deleted, otherwise False. - -Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception. - -key (str) – The function will return the value associated with this key. - -Value associated with key if key is in the store. - -Returns true if the store supports extended operations. - -Retrieve all values in keys. If any key in keys is not present in the store, the function will wait for timeout - -keys (List[str]) – The keys to be retrieved from the store. - -Inserts a list key-value pair into the store based on the supplied keys and values - -keys (List[str]) – The keys to insert. - -values (List[str]) – The values to insert. - -Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store. - -When used with the TCPStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained. - -The number of keys present in the store. - -Returns the length of the specified queue. - -If the queue doesn’t exist it returns 0. - -See queue_push for more details. - -key (str) – The key of the queue to get the length. - -Pops a value from the specified queue or waits until timeout if the queue is empty. - -See queue_push for more details. - -If block is False, a dist.QueueEmptyError will be raised if the queue is empty. - -key (str) – The key of the queue to pop from. - -block (bool) – Whether to block waiting for the key or immediately return. - -Pushes a value into the specified queue. - -Using the same key for queues and set/get operations may result in unexpected behavior. - -wait/check operations are supported for queues. - -wait with queues will only wake one waiting worker rather than all. - -key (str) – The key of the queue to push to. - -value (str) – The value to push into the queue. - -Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value. - -key (str) – The key to be added to the store. - -value (str) – The value associated with key to be added to the store. - -Sets the store’s default timeout. This timeout is used during initialization and in wait() and get(). - -timeout (timedelta) – timeout to be set in the store. - -Gets the timeout of the store. - -wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str]) -> None - -Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception. - -keys (list) – List of keys on which to wait until they are set in the store. - -wait(self: torch._C._distributed_c10d.Store, arg0: collections.abc.Sequence[str], arg1: datetime.timedelta) -> None - -Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout. - -keys (list) – List of keys on which to wait until they are set in the store. - -timeout (timedelta) – Time to wait for the keys to be added before throwing an exception. - -A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. There should always be one server store initialized because the client store(s) will wait for the server to establish a connection. - -host_name (str) – The hostname or IP Address the server store should run on. - -port (int) – The port on which the server store should listen for incoming requests. - -world_size (int, optional) – The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users). - -is_master (bool, optional) – True when initializing the server store and False for client stores. Default is False. - -timeout (timedelta, optional) – Timeout used by the store during initialization and for methods such as get() and wait(). Default is timedelta(seconds=300) - -wait_for_workers (bool, optional) – Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. - -multi_tenant (bool, optional) – If True, all TCPStore instances in the current process with the same host/port will use the same underlying TCPServer. Default is False. - -master_listen_fd (int, optional) – If specified, the underlying TCPServer will listen on this file descriptor, which must be a socket already bound to port. To bind an ephemeral port we recommend setting the port to 0 and reading .port. Default is None (meaning the server creates a new socket and attempts to bind it to port). - -use_libuv (bool, optional) – If True, use libuv for TCPServer backend. Default is True. - -Creates a new TCPStore. - -Gets the hostname on which the store listens for requests. - -Returns True if it’s using the libuv backend. - -Gets the port number on which the store listens for requests. - -A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes. - -Creates a new HashStore. - -A store implementation that uses a file to store the underlying key-value pairs. - -file_name (str) – path of the file in which to store the key-value pairs - -world_size (int, optional) – The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users). - -Creates a new FileStore. - -Gets the path of the file used by FileStore to store key-value pairs. - -A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store. - -prefix (str) – The prefix string that is prepended to each key before being inserted into the store. - -store (torch.distributed.store) – A store object that forms the underlying key-value store. - -Creates a new PrefixStore. - -Gets the underlying store object that PrefixStore wraps around. - -Note that you can use torch.profiler (recommended, only available after 1.8.1) or torch.autograd.profiler to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (gloo, nccl, mpi) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator: - -Please refer to the profiler documentation for a full overview of profiler features. - -The multi-GPU functions (which stand for multiple GPUs per CPU thread) are deprecated. As of today, PyTorch Distributed’s preferred programming model is one device per thread, as exemplified by the APIs in this document. If you are a backend developer and want to support multiple devices per thread, please contact PyTorch Distributed’s maintainers. - -Object collectives have a number of serious limitations. Read further to determine if they are safe to use for your use case. - -Object collectives are a set of collective-like operations that work on arbitrary Python objects, as long as they can be pickled. There are various collective patterns implemented (e.g. broadcast, all_gather, …) but they each roughly follow this pattern: - -convert the input object into a pickle (raw bytes), then shove it into a byte tensor - -communicate the size of this byte tensor to peers (first collective operation) - -allocate appropriately sized tensor to perform the real collective - -communicate the object data (second collective operation) - -convert raw data back into Python (unpickle) - -Object collectives sometimes have surprising performance or memory characteristics that lead to long runtimes or OOMs, and thus they should be used with caution. Here are some common issues. - -Asymmetric pickle/unpickle time - Pickling objects can be slow, depending on the number, type and size of the objects. When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective. - -Inefficient tensor communication - Tensors should be sent via regular collective APIs, not object collective APIs. It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead. - -Unexpected tensor devices - If you still want to send tensors via object collectives, there is another aspect specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on cuda:3, and then unpickle it, you will get another tensor on cuda:3 regardless of which process you are on, or which CUDA device is the ‘default’ device for that process. With regular tensor collective APIs, ‘output tensors’ will always be on the same, local device, which is generally what you’d expect. - -Unpickling a tensor will implicitly activate a CUDA context if it is the first time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by moving tensors to CPU before passing them as inputs to an object collective. - -Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capability of third-party backends are decided by their own implementations. - -The new backend derives from c10d::ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported. - -When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend. - -The support of third-party backend is experimental and subject to change. - -The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training. - -Module torch.distributed.launch. - -torch.distributed.launch is a module that spawns up multiple distributed training processes on each of the training nodes. - -This module is going to be deprecated in favor of torchrun. - -The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will be operating on a single GPU. This can achieve well-improved single-node training performance. It can also be used in multi-node distributed training, by spawning up multiple processes on each node for well-improved multi-node distributed training performance as well. This will especially be beneficial for systems with multiple Infiniband interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth. - -In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less or equal to the number of GPUs on the current system (nproc_per_node), and each process will be operating on a single GPU from GPU 0 to GPU (nproc_per_node - 1). - -How to use this module: - -Single-Node multi-process distributed training - -Multi-Node multi-process distributed training: (e.g. two nodes) - -Node 1: (IP: 192.168.1.1, and has a free port: 1234) - -To look up what optional arguments this module offers: - -1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training. - -2. In your training program, you must parse the command-line argument: --local-rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: - -Parsing the local_rank argument - -Set your device to local rank using either - -Changed in version 2.0.0: The launcher will passes the --local-rank= argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank. - -For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, the launcher will trigger an error: “error: unrecognized arguments: –local-rank=”. For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient. - -3. In your training program, you are supposed to call the following function at the beginning to start the distributed backend. It is strongly recommended that init_method=env://. Other init methods (e.g. tcp://) may work, but env:// is the one that is officially supported by this module. - -4. In your training program, you can either use regular distributed functions or use torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use torch.nn.parallel.DistributedDataParallel() module, here is how to configure it. - -Please ensure that device_ids argument is set to be the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, the device_ids needs to be [args.local_rank], and output_device needs to be args.local_rank in order to use this utility - -5. Another way to pass local_rank to the subprocesses via environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use-env=True. You must adjust the subprocess example above to replace args.local_rank with os.environ['LOCAL_RANK']; the launcher will not pass --local-rank when you specify this flag. - -local_rank is NOT globally unique: it is only unique per process on a machine. Thus, don’t use it to decide if you should, e.g., write to a networked filesystem. See pytorch/pytorch#12042 for an example of how things can go wrong if you don’t do this correctly. - -The Multiprocessing package - torch.multiprocessing package also provides a spawn function in torch.multiprocessing.spawn(). This helper function can be used to spawn multiple processes. It works by passing in the function that you want to run and spawns N processes to run it. This can be used for multiprocess distributed training as well. - -For references on how to use it, please refer to PyTorch example - ImageNet implementation - -Note that this function requires Python 3.4 or higher. - -Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. torch.distributed provides a suite of tools to help debug training applications in a self-serve fashion: - -It is extremely convenient to use python’s debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. PyTorch offers a customized wrapper around pdb that streamlines the process. - -torch.distributed.breakpoint makes this process easy. Internally, it customizes pdb’s breakpoint behavior in two ways but otherwise behaves as normal pdb. - -Attaches the debugger only on one rank (specified by the user). - -Ensures all other ranks stop, by using a torch.distributed.barrier() that will release once the debugged rank issues a continue - -Reroutes stdin from the child process such that it connects to your terminal. - -To use it, simply issue torch.distributed.breakpoint(rank) on all ranks, using the same value for rank in each case. - -As of v1.10, torch.distributed.monitored_barrier() exists as an alternative to torch.distributed.barrier() which fails with helpful information about which rank may be faulty when crashing, i.e. not all ranks calling into torch.distributed.monitored_barrier() within the provided timeout. torch.distributed.monitored_barrier() implements a host-side barrier using send/recv communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge the barrier in time. As an example, consider the following function where rank 1 fails to call into torch.distributed.monitored_barrier() (in practice this could be due to an application bug or hang in a previous collective): - -The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further: - -With TORCH_CPP_LOG_LEVEL=INFO, the environment variable TORCH_DISTRIBUTED_DEBUG can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. TORCH_DISTRIBUTED_DEBUG can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL may impact the application performance and thus should only be used when debugging issues. - -Setting TORCH_DISTRIBUTED_DEBUG=INFO will result in additional debug logging when models trained with torch.nn.parallel.DistributedDataParallel() are initialized, and TORCH_DISTRIBUTED_DEBUG=DETAIL will additionally log runtime performance statistics a select number of iterations. These runtime statistics include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application: - -The following logs are rendered at initialization time: - -The following logs are rendered during runtime (when TORCH_DISTRIBUTED_DEBUG=DETAIL is set): - -In addition, TORCH_DISTRIBUTED_DEBUG=INFO enhances crash logging in torch.nn.parallel.DistributedDataParallel() due to unused parameters in the model. Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required to be used in loss computation as torch.nn.parallel.DistributedDataParallel() does not support unused parameters in the backwards pass. These constraints are challenging especially for larger models, thus when crashing with an error, torch.nn.parallel.DistributedDataParallel() will log the fully qualified name of all parameters that went unused. For example, in the above application, if we modify loss to be instead computed as loss = output[1], then TwoLinLayerNet.a does not receive a gradient in the backwards pass, and thus results in DDP failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models: - -Setting TORCH_DISTRIBUTED_DEBUG=DETAIL will trigger additional consistency and synchronization checks on every collective call issued by the user either directly or indirectly (such as DDP allreduce). This is done by creating a wrapper process group that wraps all process groups returned by torch.distributed.init_process_group() and torch.distributed.new_group() APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a torch.distributed.monitored_barrier(), which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into torch.distributed.all_reduce(): - -With the NCCL backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables TORCH_DISTRIBUTED_DEBUG=DETAIL and reruns the application, the following error message reveals the root cause: - -For fine-grained control of the debug level during runtime the functions torch.distributed.set_debug_level(), torch.distributed.set_debug_level_from_env(), and torch.distributed.get_debug_level() can also be used. - -In addition, TORCH_DISTRIBUTED_DEBUG=DETAIL can be used in conjunction with TORCH_SHOW_CPP_STACKTRACES=1 to log the entire callstack when a collective desynchronization is detected. These collective desynchronization checks will work for all applications that use c10d collective calls backed by process groups created with the torch.distributed.init_process_group() and torch.distributed.new_group() APIs. - -In addition to explicit debugging support via torch.distributed.monitored_barrier() and TORCH_DISTRIBUTED_DEBUG, the underlying C++ library of torch.distributed also outputs log messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The following matrix shows how the log level can be adjusted via the combination of TORCH_CPP_LOG_LEVEL and TORCH_DISTRIBUTED_DEBUG environment variables. - -TORCH_DISTRIBUTED_DEBUG - -Distributed components raise custom Exception types derived from RuntimeError: - -torch.distributed.DistError: This is the base type of all distributed exceptions. - -torch.distributed.DistBackendError: This exception is thrown when a backend-specific error occurs. For example, if the NCCL backend is used and the user attempts to use a GPU that is not available to the NCCL library. - -torch.distributed.DistNetworkError: This exception is thrown when networking libraries encounter errors (ex: Connection reset by peer) - -torch.distributed.DistStoreError: This exception is thrown when the Store encounters an error (ex: TCPStore timeout) - -Exception raised when an error occurs in the distributed library - -Exception raised when a backend error occurs in distributed - -Exception raised when a network error occurs in distributed - -Exception raised when an error occurs in the distributed store - -If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank: - -Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing. - -rank (int) – Which rank to break on. Default: 0 - -skip (int) – Skip the first skip calls to this breakpoint. Default: 0. - ---- - -## DistributedDataParallel# - -**URL:** https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html - -**Contents:** -- DistributedDataParallel# - -Implement distributed data parallelism based on torch.distributed at module level. - -This container provides data parallelism by synchronizing gradients across each model replica. The devices to synchronize across are specified by the input process_group, which is the entire world by default. Note that DistributedDataParallel does not chunk or otherwise shard the input across participating GPUs; the user is responsible for defining how to do so, for example through the use of a DistributedSampler. - -See also: Basics and Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel. The same constraints on input as in torch.nn.DataParallel apply. - -Creation of this class requires that torch.distributed to be already initialized, by calling torch.distributed.init_process_group(). - -DistributedDataParallel is proven to be significantly faster than torch.nn.DataParallel for single-node multi-GPU data parallel training. - -To use DistributedDataParallel on a host with N GPUs, you should spawn up N processes, ensuring that each process exclusively works on a single GPU from 0 to N-1. This can be done by either setting CUDA_VISIBLE_DEVICES for every process or by calling the following API for GPUs, - -or calling the unified API for accelerator, - -where i is from 0 to N-1. In each process, you should refer the following to construct this module: - -Or you can use the latest API for initialization: - -In order to spawn up multiple processes per node, you can use either torch.distributed.launch or torch.multiprocessing.spawn. - -Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training. - -DistributedDataParallel can be used in conjunction with torch.distributed.optim.ZeroRedundancyOptimizer to reduce per-rank optimizer states memory footprint. Please refer to ZeroRedundancyOptimizer recipe for more details. - -nccl backend is currently the fastest and highly recommended backend when using GPUs. This applies to both single-node and multi-node distributed training. - -This module also supports mixed-precision distributed training. This means that your model can have different types of parameters such as mixed types of fp16 and fp32, the gradient reduction on these mixed types of parameters will just work fine. - -If you use torch.save on one process to checkpoint the module, and torch.load on some other processes to recover it, make sure that map_location is configured properly for every process. Without map_location, torch.load would recover the module to devices where the module was saved from. - -When a model is trained on M nodes with batch=N, the gradient will be M times smaller when compared to the same model trained on a single node with batch=M*N if the loss is summed (NOT averaged as usual) across instances in a batch (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent training process compared to the local training counterpart. But in most cases, you can just treat a DistributedDataParallel wrapped model, a DataParallel wrapped model and an ordinary model on a single GPU as the same (E.g. using the same learning rate for equivalent batch size). - -Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in process of rank 0, to all other replicas in the system in every iteration. - -If you are using DistributedDataParallel in conjunction with the Distributed RPC Framework, you should always use torch.distributed.autograd.backward() to compute gradients and torch.distributed.optim.DistributedOptimizer for optimizing parameters. - -DistributedDataParallel currently offers limited support for gradient checkpointing with torch.utils.checkpoint(). If the checkpoint is done with use_reentrant=False (recommended), DDP will work as expected without any limitations. If, however, the checkpoint is done with use_reentrant=True (the default), DDP will work as expected when there are no unused parameters in the model and each layer is checkpointed at most once (make sure you are not passing find_unused_parameters=True to DDP). We currently do not support the case where a layer is checkpointed multiple times, or when there unused parameters in the checkpointed model. - -To let a non-DDP model load a state dict from a DDP model, consume_prefix_in_state_dict_if_present() needs to be applied to strip the prefix “module.” in the DDP state dict before loading. - -Constructor, forward method, and differentiation of the output (or a function of the output of this module) are distributed synchronization points. Take that into account in case different processes might be executing different code. - -This module assumes all parameters are registered in the model by the time it is created. No parameters should be added nor removed later. Same applies to buffers. - -This module assumes all parameters are registered in the model of each distributed processes are in the same order. The module itself will conduct gradient allreduce following the reverse order of the registered parameters of the model. In other words, it is users’ responsibility to ensure that each distributed process has the exact same model and thus the exact same parameter registration order. - -This module allows parameters with non-rowmajor-contiguous strides. For example, your model may contain some parameters whose torch.memory_format is torch.contiguous_format and others whose format is torch.channels_last. However, corresponding parameters in different processes must have the same strides. - -This module doesn’t work with torch.autograd.grad() (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters). - -If you plan on using this module with a nccl backend or a gloo backend (that uses Infiniband), together with a DataLoader that uses multiple workers, please change the multiprocessing start method to forkserver (Python 3 only) or spawn. Unfortunately Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will likely experience deadlocks if you don’t change this setting. - -You should never try to change your model’s parameters after wrapping up your model with DistributedDataParallel. Because, when wrapping up your model with DistributedDataParallel, the constructor of DistributedDataParallel will register the additional gradient reduction functions on all the parameters of the model itself at the time of construction. If you change the model’s parameters afterwards, gradient reduction functions no longer match the correct set of parameters. - -Using DistributedDataParallel in conjunction with the Distributed RPC Framework is experimental and subject to change. - -module (Module) – module to be parallelized - -device_ids (list of int or torch.device) – CUDA devices. 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None. When device_ids is None for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None) - -CUDA devices. 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None. - -When device_ids is None for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None) - -output_device (int or torch.device) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: device_ids[0] for single-device modules) - -broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at beginning of the forward function. (default: True) - -init_sync (bool) – Whether to sync during initialization to verify param shapes and broadcast parameters and buffers. WARNING: if this is set to False the user is required to ensure themselves that the weights are the same on all ranks. (default: True) - -process_group – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group(), will be used. (default: None) - -bucket_cap_mb – DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in MebiBytes (MiB). If None, a default size of 25 MiB will be used. (default: None) - -find_unused_parameters (bool) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module’s forward function but were not part of loss computation and thus would also not receive gradients are preemptively marked as ready to be reduced. (default: False) - -check_reduction – This argument is deprecated. - -gradient_as_bucket_view (bool) – When set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. Moreover, it avoids the overhead of copying between gradients and allreduce communication buckets. When gradients are views, detach_() cannot be called on the gradients. If hitting such errors, please fix it by referring to the zero_grad() function in torch/optim/optimizer.py as a solution. Note that gradients will be views after first iteration, so the peak memory saving should be checked after first iteration. - -static_graph (bool) – When set to True, DDP knows the trained graph is static. Static graph means 1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters = True or not. 2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to be True, DDP will support cases that can not be supported in the past: 1) Reentrant backwards. 2) Activation checkpointing multiple times. 3) Activation checkpointing when model has unused parameters. 4) There are model parameters that are outside of forward function. 5) Potentially improve performance when there are unused parameters, as DDP will not search graph in each iteration to detect unused parameters when static_graph is set to be True. To check whether you can set static_graph to be True, one way is to check ddp logging data at the end of your previous model training, if ddp_logging_data.get("can_set_static_graph") == True, mostly you can set static_graph = True as well. Example::>>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) >>> # Training loop >>> ... >>> ddp_logging_data = model_DDP._get_ddp_logging_data() >>> static_graph = ddp_logging_data.get("can_set_static_graph") - -When set to True, DDP knows the trained graph is static. Static graph means 1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set find_unused_parameters = True or not. 2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to be True, DDP will support cases that can not be supported in the past: 1) Reentrant backwards. 2) Activation checkpointing multiple times. 3) Activation checkpointing when model has unused parameters. 4) There are model parameters that are outside of forward function. 5) Potentially improve performance when there are unused parameters, as DDP will not search graph in each iteration to detect unused parameters when static_graph is set to be True. To check whether you can set static_graph to be True, one way is to check ddp logging data at the end of your previous model training, if ddp_logging_data.get("can_set_static_graph") == True, mostly you can set static_graph = True as well. - -delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – a list of named parameters whose all reduce will be delayed when the gradient of the parameter specified in param_to_hook_all_reduce is ready. Other arguments of DDP do not apply to named params specified in this argument as these named params will be ignored by DDP reducer. - -param_to_hook_all_reduce (torch.nn.Parameter) – a parameter to hook delayed all reduce of parameters specified in delay_all_reduce_named_params. - -skip_all_reduce_unused_params – When set to True, DDP will skip reducing unused parameters. This requires that unused parameters remain the same across all ranks throughout the entire training process. If this condition is not met, it may cause desynchronization and result in training hang. - -module (Module) – the module to be parallelized. - -Context manager for training with uneven inputs across processes in DDP. - -This context manager will keep track of already-joined DDP processes, and “shadow” the forward and backward passes by inserting collective communication operations to match with the ones created by non-joined DDP processes. This will ensure each collective call has a corresponding call by already-joined DDP processes, preventing hangs or errors that would otherwise happen when training with uneven inputs across processes. Alternatively, if the flag throw_on_early_termination is specified to be True, all trainers will throw an error once one rank runs out of inputs, allowing these errors to be caught and handled according to application logic. - -Once all DDP processes have joined, the context manager will broadcast the model corresponding to the last joined process to all processes to ensure the model is the same across all processes (which is guaranteed by DDP). - -To use this to enable training with uneven inputs across processes, simply wrap this context manager around your training loop. No further modifications to the model or data loading is required. - -If the model or training loop this context manager is wrapped around has additional distributed collective operations, such as SyncBatchNorm in the model’s forward pass, then the flag throw_on_early_termination must be enabled. This is because this context manager is not aware of non-DDP collective communication. This flag will cause all ranks to throw when any one rank exhausts inputs, allowing these errors to be caught and recovered from across all ranks. - -divide_by_initial_world_size (bool) – If True, will divide gradients by the initial world_size DDP training was launched with. If False, will compute the effective world size (number of ranks that have not depleted their inputs yet) and divide gradients by that during allreduce. Set divide_by_initial_world_size=True to ensure every input sample including the uneven inputs have equal weight in terms of how much they contribute to the global gradient. This is achieved by always dividing the gradient by the initial world_size even when we encounter uneven inputs. If you set this to False, we divide the gradient by the remaining number of nodes. This ensures parity with training on a smaller world_size although it also means the uneven inputs would contribute more towards the global gradient. Typically, you would want to set this to True for cases where the last few inputs of your training job are uneven. In extreme cases, where there is a large discrepancy in the number of inputs, setting this to False might provide better results. - -enable (bool) – Whether to enable uneven input detection or not. Pass in enable=False to disable in cases where you know that inputs are even across participating processes. Default is True. - -throw_on_early_termination (bool) – Whether to throw an error or continue training when at least one rank has exhausted inputs. If True, will throw upon the first rank reaching end of data. If False, will continue training with a smaller effective world size until all ranks are joined. Note that if this flag is specified, then the flag divide_by_initial_world_size would be ignored. Default is False. - -DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes. - -kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs. - -If True, then gradients are divided by the initial world size that DDP was launched with. If False, then gradients are divided by the effective world size (i.e. the number of non-joined processes), meaning that the uneven inputs contribute more toward the global gradient. Typically, this should be set to True if the degree of unevenness is small but can be set to False in extreme cases for possibly better results. Default is True. - -Context manager to disable gradient synchronizations across DDP processes. - -Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context. - -The forward pass should be included inside the context manager, or else gradients will still be synchronized. - -Register communication hook for user-defined DDP aggregation of gradients across multiple workers. - -This hook would be very useful for researchers to try out new ideas. For example, this hook can be used to implement several algorithms like GossipGrad and gradient compression which involve different communication strategies for parameter syncs while running Distributed DataParallel training. - -state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker. - -Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. - -It is locally stored by each worker and shared by all the gradient tensors on the worker. - -hook (Callable) – Callable with the following signature: hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: This function is called once the bucket is ready. The hook can perform whatever processing is needed and return a Future indicating completion of any async work (ex: allreduce). If the hook doesn’t perform any communication, it still must return a completed Future. The Future should hold the new value of grad bucket’s tensors. Once a bucket is ready, c10d reducer would call this hook and use the tensors returned by the Future and copy grads to individual parameters. Note that the future’s return type must be a single tensor. We also provide an API called get_future to retrieve a Future associated with the completion of c10d.ProcessGroup.Work. get_future is currently supported for NCCL and also supported for most operations on GLOO and MPI, except for peer to peer operations (send/recv). - -Callable with the following signature: hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: - -This function is called once the bucket is ready. The hook can perform whatever processing is needed and return a Future indicating completion of any async work (ex: allreduce). If the hook doesn’t perform any communication, it still must return a completed Future. The Future should hold the new value of grad bucket’s tensors. Once a bucket is ready, c10d reducer would call this hook and use the tensors returned by the Future and copy grads to individual parameters. Note that the future’s return type must be a single tensor. - -We also provide an API called get_future to retrieve a Future associated with the completion of c10d.ProcessGroup.Work. get_future is currently supported for NCCL and also supported for most operations on GLOO and MPI, except for peer to peer operations (send/recv). - -Grad bucket’s tensors will not be predivided by world_size. User is responsible to divide by the world_size in case of operations like allreduce. - -DDP communication hook can only be registered once and should be registered before calling backward. - -The Future object that hook returns should contain a single tensor that has the same shape with the tensors inside grad bucket. - -get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future. - -Below is an example of a noop hook that returns the same tensor. - -Below is an example of a Parallel SGD algorithm where gradients are encoded before allreduce, and then decoded after allreduce. - ---- - -## DDP Communication Hooks# - -**URL:** https://pytorch.org/docs/stable/ddp_comm_hooks.html - -**Contents:** -- DDP Communication Hooks# -- How to Use a Communication Hook?# -- What Does a Communication Hook Operate On?# -- Default Communication Hooks# -- PowerSGD Communication Hook# - - PowerSGD State# - - PowerSGD Hooks# -- Debugging Communication Hooks# -- Checkpointing of Communication Hooks# -- Acknowledgements# - -Created On: Jun 06, 2025 | Last Updated On: Jun 06, 2025 - -DDP communication hook is a generic interface to control how to communicate gradients across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of these hooks to optimize communication. Besides, the hook interface can also support user-defined communication strategies for more advanced use cases. - -To use a communication hook, the user just needs to let the DDP model register the hook before the training loop as below. - -torch.nn.parallel.DistributedDataParallel.register_comm_hook() - -A communication hook provides a flexible way to allreduce gradients. Therefore, it mainly operates on the gradients on each replica before allreduce, which are bucketized to increase the overlap between communication and computation. Particularly, torch.distributed.GradBucket represents a bucket of gradient tensors to be allreduced. - -This class mainly passes a flattened gradient tensor (returned by buffer()) to DDP communication hook. This tensor can be further decomposed into a list of per-parameter tensors within this bucket (returned by get_per_parameter_tensors()) to apply layer-wise operations. - -Since the buckets are rebuilt after the first iteration, should not rely on the indices at the beginning of training. - -The index of a bucket that stores gradients of a few contiguous layers. All the gradients are bucketized. - -A flattened 1D torch.Tensor buffer, which can be further decomposed into a list of per-parameter tensors within this bucket. - -A list of torch.Tensor. Each tensor in the list corresponds to a gradient. - -Whether this bucket is the last bucket to allreduce in an iteration. This also means that this bucket corresponds to the first few layers in the forward pass. - -Replaces the tensor in the bucket with the input tensor buffer. - -A list of torch.Tensor. Each tensor in the list corresponds to a model parameter. - -Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None. The input bucket is a torch.distributed.GradBucket object. - -Call allreduce using GradBucket tensors. - -Once gradient tensors are aggregated across all workers, its then callback takes the mean and returns the result. - -If user registers this DDP communication hook, DDP results is expected to be same as the case where no hook was registered. Hence, this won’t change behavior of DDP and user can use this as a reference or modify this hook to log useful information or any other purposes while unaffecting DDP behavior. - -Compress by casting GradBucket to torch.float16 divided by process group size. - -This DDP communication hook implements a simple gradient compression approach that casts GradBucket tensor to half-precision floating-point format (torch.float16) and then divides it by the process group size. It allreduces those float16 gradient tensors. Once compressed gradient tensors are allreduced, the chained callback decompress casts it back to the input data type (such as float32). - -Warning: This API is experimental, and it requires NCCL version later than 2.9.6. - -This DDP communication hook implements a simple gradient compression approach that casts GradBucket tensor to half-precision Brain floating point format (torch.bfloat16) and then divides it by the process group size. It allreduces those bfloat16 gradient tensors. Once compressed gradient tensors are allreduced, the chained callback decompress casts it back to the input data type (such as float32). - -Additionally, a communication hook wrapper is provided to support fp16_compress_hook() or bf16_compress_hook() as a wrapper, which can be combined with other communication hooks. - -Cast input tensor to torch.float16, cast result of hook back to input dtype. - -This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision floating point format (torch.float16), and casts the resulting tensor of the given hook back to the input data type, such as float32. Therefore, fp16_compress_hook is equivalent to fp16_compress_wrapper(allreduce_hook). - -Callable[[Any, GradBucket], Future[Tensor]] - -Warning: This API is experimental, and it requires NCCL version later than 2.9.6. - -This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision Brain floating point format (torch.bfloat16), and casts the resulting tensor of the given hook back to the input data type, such as float32. - -Therefore, bf16_compress_hook is equivalent to bf16_compress_wrapper(allreduce_hook). - -Callable[[Any, GradBucket], Future[Tensor]] - -PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm, which can provide very high compression rates and accelerate bandwidth-bound distributed training. This algorithm needs to maintain both some hyperparameters and the internal state. Therefore, PowerSGD communication hook is a stateful hook, and the user needs to provide a state object defined as below. - -Store both the algorithm’s hyperparameters and internal state for all gradients during training. - -Particularly, matrix_approximation_rank and start_powerSGD_iter are the main hyperparameters that should be tuned by the user. For performance, we suggest to keep binary hyperparameters use_error_feedback and warm_start on. - -matrix_approximation_rank controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression. - -1.1. If matrix_approximation_rank is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy. - -1.2. The increase of matrix_approximation_rank can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain matrix_approximation_rank threshold. - -To tune matrix_approximation_rank, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, …), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32. - -start_powerSGD_iter defers PowerSGD compression until step start_powerSGD_iter, and vanilla allreduce runs prior to step start_powerSGD_iter. This hybrid scheme of vanilla allreduce + PowerSGD can effectively improve the accuracy, even a relatively small matrix_approximation_rank is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy. - -To tune start_powerSGD_iter, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, start_powerSGD_iter typically should be no less than the number of warm-up steps. - -min_compression_rate is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression. - -Compression statistics are logged every compression_stats_logging_frequency iterations once PowerSGD compression starts. - -orthogonalization_epsilon can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy. - -batch_tensors_with_same_shape controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., bucket_cap_mb arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to True if the compression / decompression computation is a bottleneck. - -If error feedback or warm-up is enabled, the minimum value of start_powerSGD_iter allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process. - -PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy. - -PowerSGD hooks may conflict with Apex automatic mixed precision package. Please use PyTorch native automatic mixed precision package instead. - -Implement PowerSGD algorithm. - -This DDP communication hook implements PowerSGD gradient compression algorithm described in the paper. Once gradient tensors are aggregated across all workers, this hook applies compression as follows: - -Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups: - -1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth. - -1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases). - -Handles uncompressed tensors: - -2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression; - -2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor. - -Handles the tensors that should be compressed by PowerSGD compression: - -3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; - -3.2. Computes each P in Ps, which is equal to MQ; - -3.3. Allreduces Ps as a batch; - -3.4. Orthogonalizes each P in Ps; - -3.5. Computes each Q in Qs, which is approximately equal to M^TP; - -3.6. Allreduces Qs as a batch; - -3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T. - -Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. - -state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, mainly need to tune matrix_approximation_rank, start_powerSGD_iter and min_compression_rate. - -bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket. - -Future handler of the communication, which updates the gradients in place. - -Implement simplified PowerSGD algorithm. - -This DDP communication hook implements a simplified PowerSGD gradient compression algorithm described in the paper. This variant does not compress the gradients layer by layer, but instead compresses the flattened input tensor that batches all the gradients. Therefore, it is faster than powerSGD_hook(), but usually results in a much lower accuracy, unless matrix_approximation_rank is 1. - -Increasing matrix_approximation_rank here may not necessarily increase the accuracy, because batching per-parameter tensors without column/row alignment can destroy low-rank structure. Therefore, the user should always consider powerSGD_hook() first, and only consider this variant when a satisfactory accuracy can be achieved when matrix_approximation_rank is 1. - -Once gradient tensors are aggregated across all workers, this hook applies compression as follows: - -Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; - -Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; - -Computes P, which is equal to MQ; - -Computes Q, which is approximately equal to M^TP; - -Computes M, which is approximately equal to PQ^T. - -Truncates the input tensor to the original length. - -Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. - -state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, mainly need to tune matrix_approximation_rank and start_powerSGD_iter. - -bucket (dist.GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode, only exactly one tensor is stored in this bucket. - -Future handler of the communication, which updates the gradients in place. - -As the name implies, debugging communication hooks are only used for debugging and performance optimization purpose. - -Debugging communication hooks do not necessarily output the correct results. - -Return a future that wraps the input, so it is a no-op that does not incur any communication overheads. - -This hook should only be used for headroom analysis of allreduce optimization, instead of the normal gradient synchronization. For example, if only less than 10% speedup of training time can be observed after this hook is registered, it usually implies that allreduce is not a performance bottleneck for this case. Such instrumentation can be particularly useful if GPU traces cannot be easily retrieved or the trace analysis is complicated some factors such as the overlap between allreduce and computation or the desynchronization across ranks. - -A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. To make a hook serializable, __setstate__ and __getstate__ should be defined. - -__getstate__ should exclude non-serializable attributes from a returned dictionary. - -__setstate__ should properly initialize non-serializable attributes, excluded from a provided state. - -PowerSGDState has __setstate__ and __getstate__ implemented and can be used as a reference. - -Return a Dict[str, Any] which will be pickled and saved. - -process_group is not serializable and excluded from a returned state. - -Take a provided state and set to this PowerSGDState instance. - -process_group is set to default. - -Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook. - -Many thanks to PowerSGD paper author Thijs Vogels for the code review on PowerSGD communication hook, as well as the comparison experiments, which show that the performance of PowerSGD communication hook is on par with the implementation in the original paper. - ---- - -## Distributed Checkpoint - torch.distributed.checkpoint# - -**URL:** https://pytorch.org/docs/stable/distributed.checkpoint.html - -**Contents:** -- Distributed Checkpoint - torch.distributed.checkpoint# -- Additional resources:# - -Created On: Nov 16, 2022 | Last Updated On: Sep 04, 2025 - -Distributed Checkpoint (DCP) support loading and saving models from multiple ranks in parallel. It handles load-time resharding which enables saving in one cluster topology and loading into another. - -DCP is different than torch.save and torch.load in a few significant ways: - -It produces multiple files per checkpoint, with at least one per rank. - -It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead. - -The entrypoints to load and save a checkpoint are the following: - -Getting Started with Distributed Checkpoint (DCP) - -Asynchronous Saving with Distributed Checkpoint (DCP) - -TorchTitan Checkpointing Docs - -TorchTitan DCP Implementation - -Enum for async checkpointer type. - -This class contains futures for staging and upload completion. It is returned by async_save(). staging_completion is a future that indicates when local copy of state_dict is complete. upload_completion is a future that indicates when a checkpoint completed saving. - -Save a distributed model in SPMD style. - -This function is different from torch.save() as it handles ShardedTensor , and DTensor by having each rank only save their local shards. - -For each Stateful object (having both a state_dict and a load_state_dict), save will call state_dict before serialization. - -There is no guarantees of Backwards Compatibility across PyTorch versions for saved state_dicts. - -If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belong to it. - -When saving checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one of the shard_group should be calling save_state_dict and the corresponding process group needs to be passed in. - -state_dict in the local process. - -state_dict (Dict[str, Any]) – The state_dict to save. - -checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None) - -storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform writes. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None) - -planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None) - -process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None) - -no_dist (bool) – If True, this function will assume the intent is to load a checkpoint on a single rank/process. (Default: False) - -use_collectives (bool) – If False, this function will assume the intent is to save a checkpoint without using cross-rank synchronization. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible. - -Metadata object for the saved checkpoint. - -save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -Asynchronous version of save. This code first de-stages the state_dict on to the staging storage (defaults to CPU memory), and then calls the save in a separate thread. - -This feature is experimental and subject to change. MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED - -state_dict (Dict[str, Any]) – The state_dict to save. - -checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None) - -storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform ‘stage’ and ‘save’. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None) - -planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None) - -process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None) - -async_checkpointer_type (AsyncCheckpointerType) – whether to do checkpoint in separate thread or process (Default: AsyncCheckpointerType.THREAD) - -async_stager (AsyncStager) – provides staging implementation. If storage_writer implements AsyncStager and async_stager is provided, async_stager will be used for staging - -no_dist (bool) – If True, this function will assume the intent is to save a checkpoint on a single rank/process. (Default: False) - -use_collectives (bool) – If False, Save the checkpoint without rank coordination. (Default: True) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible. - -A future holding the resultant Metadata object from save. - -This method is deprecated. Please switch to ‘save’. - -Load a checkpoint into a distributed state dict in SPMD style. - -Each rank must have the same keys in their state_dict provided to this API. Mismatched keys may result in hangs or errors. If unsure, you can use the utils._assert_same_keys API to check (but may incur communication costs). - -Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor or DTensor instances, each rank only reads data for their local shards. - -For each Stateful object (having both a state_dict and a load_state_dict), load will first call state_dict before attempting deserialization, followed by load_state_dict once the deserialization is complete. For each non-Stateful object, load will deserialize the object, and then replace it in the state_dict with the deserialized object. - -All tensors in state_dict must be allocated on their destination device prior to calling this function. - -All non-tensor data is loaded using torch.load() and modified in place on state_dict. - -Users must call load_state_dict on the root module to ensure load pos-processing and non-tensor data properly propagates. - -state_dict (Dict[str, Any]) – The state_dict to load the checkpoint into. - -checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None) - -storage_reader (Optional[StorageReader]) – Instance of StorageWriter used to perform reads. If this is not specified, DCP will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None) - -planner (Optional[LoadPlanner]) – Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: None) - -process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None) - -no_dist (bool) – If True, this function will assume the intent is to load a checkpoint without using cross-rank synchronization. (Default: False) - -load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device(). - -This method is deprecated. Please switch to ‘load’. - -The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (torch.distributed.checkpoint.async_save): - -This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users to customize how data is staged previous to executing the usual dcp.save path in parallel. The expected order of operations (concretely defined in torch.distributed.state_dict_saver.async_save) is the following: - -This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. The expectation and purpose of staging in this context is to create a “training-safe” representation of the state dict, meaning that any updates to module data after staging is complete should not be reflected in the state dict returned from this method. For example, in the default case a copy of the entire state dict is created on CPU RAM and returned here, allowing users to continue training without risking changes to data which is being serialized. - -for serializing the state_dict and writing it to storage. - -the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is the user has defined a custom synchronization point for the the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the respondsibility of the user to call AsyncStager.synchronize_staging at the appropriate time. - -Clean up all resources used by the stager. - -Whether to synchronize after executing the stage. - -Returns a “staged” copy of state_dict. The expectation of the staged copy is that it is inoculated from any updates incurred after the stage call is complete. - -Union[Future[dict[str, Union[~StatefulT, Any]]], dict[str, Union[~StatefulT, Any]]] - -In the case stage is async in some way, this method should be called to ensure staging is complete and it is safe to begin modifying the original state_dict - -DefaultStager provides a full-featured staging implementation that combines multiple optimization techniques for efficient checkpoint preparation. - -The staging process works as follows: 1. State dictionary is submitted for staging (sync or async) 2. Tensors are copied from GPU to optimized CPU storage 3. CUDA operations are synchronized if non-blocking copies are used 4. Staged state dictionary is returned or made available via Future - -# Synchronous staging stager = DefaultStager(StagingOptions(use_async_staging=False)) staged_dict = stager.stage(state_dict) stager.close() - -# Asynchronous staging stager = DefaultStager(StagingOptions(use_async_staging=True)) future = stager.stage(state_dict) # … do other work … staged_dict = future.result() stager.close() - -# Context manager pattern (recommended) stager = DefaultStager(config) with stager: result = stager.stage(state_dict) - -Async staging provides best performance when model computation can overlap with staging operations - -Pinned memory improves CPU-GPU transfer speeds but uses more memory - -Shared memory allows efficient IPC to checkpoint process - -Non-blocking copies reduce GPU idle time during memory transfers - -DefaultStager is not thread-safe. Each thread should use its own instance, or external synchronization should be provided. - -Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor used for async staging operations and cleans up the underlying StateDictStager’s cached storages. Should be called when the stager is no longer needed to prevent resource leaks, especially in long-running applications. After calling close(), the stager should not be used for further staging operations. - -stager = DefaultStager(StagingOptions(use_async_staging=True)) future = stager.stage(state_dict) result = future.result() stager.close() # Clean up all resources - -This function is responsible for staging staging the state_dict. See class docstring for more details on staging. If use_async_staging is True, it will return a Future object that will be fulfilled when staging is complete. If use_async_staging is False, it will return the fully staged state_dict. - -state_dict (STATE_DICT_TYPE) – The state_dict to be staged. - -Union[dict[str, Union[~StatefulT, Any]], Future[dict[str, Union[~StatefulT, Any]]]] - -When use_async_staging is True, this method will wait until staging is complete. If use_async_staging is False, this method is a no-op. - -Configuration options for checkpoint staging behavior. - -use_pinned_memory (bool) – Enable pinned memory allocation for faster CPU-GPU transfers. Requires CUDA to be available. Default: True - -use_shared_memory (bool) – Enable shared memory for multi-process scenarios. Useful when multiple processes need access to the same staged data. Default: True - -use_async_staging (bool) – Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - -use_non_blocking_copy (bool) – Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True - -CUDA-dependent features will raise exception if CUDA is not available. - -An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. This implementation also provides an option to optimize stage latency using pinned memory. - -N.B. synchronize_staging is a no-op in this case. - -Returns a copy of state_dict on the CPU. - -dict[str, Union[~StatefulT, Any]] - -No-op function, since staging is blocking. - -In addition to the above entrypoints, Stateful objects, as described below, provide additional customization during saving/loading - -Stateful protocol for objects that can be checkpointed and restored. - -Restore the object’s state from the provided state_dict. - -state_dict (dict[str, Any]) – The state dict to restore from - -Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict(). - -Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load. - -The objects state dict - -This example shows how to use Pytorch Distributed Checkpoint to save a FSDP model. - -The following types define the IO interface used during checkpoint: - -Interface used by load_state_dict to read from storage. - -One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role. - -A subclass should expected the following sequence of calls by load_state_dict: - -(all ranks) set checkpoint_id if users pass a valid checkpoint_id. - -(all ranks) read_metadata() - -(all ranks) set_up_storage_reader() - -(all ranks) prepare_local_plan() - -(coordinator) prepare_global_plan() - -(all ranks) read_data() - -Perform centralized planning of storage loading. - -This method is only called on the coordinator instance. - -While this method can produce a completely different plan, the preferred way is to store storage specific data in LoadPlan::storage_data. - -plans (list[torch.distributed.checkpoint.planner.LoadPlan]) – A list of LoadPlan instances, one for each rank. - -A list of transformed LoadPlan after storage global planning - -list[torch.distributed.checkpoint.planner.LoadPlan] - -Perform storage-specific local planning. - -While this method can produce a completely different plan, the recommended way is to store storage specific data in LoadPlan::storage_data. - -plan (LoadPlan) – The local plan from the LoadPlan in use. - -A transformed LoadPlan after storage local planning - -Read all items from plan using planner to resolve the data. - -A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place. - -A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that in should load data into. - -It’s the StorageLayer responsibility to properly schedule any cross device copies required. - -plan (LoadPlan) – The local plan to execute on - -planner (LoadPlanner) – The planner object to use to resolve items. - -A future that completes once all reads are finished. - -Read the checkpoint metadata. - -The metadata object associated with the checkpoint being loaded. - -Calls to indicates a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage. - -checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None) - -Initialize this instance. - -metadata (Metadata) – The metadata schema to use. - -is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint. - -Check if the given checkpoint_id is supported by the storage. This allow us to enable automatic storage selection. - -Interface used by save_state_dict to write to storage. - -One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role. - -A subclass should expect the following sequence of calls. - -(all ranks) set checkpoint_id if users pass a valid checkpoint_id. - -(all ranks) set_up_storage_writer() - -(all ranks) prepare_local_plan() - -(coordinator) prepare_global_plan() - -(all ranks) write_data() - -(coordinator) finish() - -Write the metadata and marks the current checkpoint as successful. - -The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable in to the same object graph. - -metadata (Metadata) – metadata for the new checkpoint - -results (list[list[torch.distributed.checkpoint.storage.WriteResult]]) – A list of WriteResults from all ranks. - -Perform centralized planning of storage. - -This method is only called on the coordinator instance. - -While this method can produce a completely different plan, the preferred way is to store storage specific data in SavePlan::storage_data. - -plans (list[torch.distributed.checkpoint.planner.SavePlan]) – A list of SavePlan instances, one for each rank. - -A list of transformed SavePlan after storage global planning - -list[torch.distributed.checkpoint.planner.SavePlan] - -Perform storage-specific local planning. - -While this method can produce a completely different plan, the recommended way is to store storage specific data in SavePlan::storage_data. - -plan (SavePlan) – The local plan from the SavePlanner in use. - -A transformed SavePlan after storage local planning - -Calls to indicates a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage. - -checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None) - -Initialize this instance. - -is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint. - -Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default. - -TODO: provide an example - -Optional[StorageMeta] - -Check if the given checkpoint_id is supported by the storage. This allow us to enable automatic storage selection. - -Write all items from plan using planner to resolve the data. - -A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write. - -Subclasses should lazily call resolve_data as it can allocate memory. In case of tensors, make following assumptions: - -They might be on any device, including not matching the one on WriteItem::tensor_data - -They might be views or not contiguous. Only the projection needs to be saved. - -plan (SavePlan) – The save plan to execute. - -planner (SavePlanner) – Planner object to be used to resolve items to data. - -A future that completes to a list of WriteResult - -Future[list[torch.distributed.checkpoint.storage.WriteResult]] - -The following types define the planner interface used during checkpoint: - -Abstract class defining the protocol used by load_state_dict to plan the load process. - -LoadPlanner are stateful objects that can be used to customize the whole load process. - -LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process. - -A planner subclass can expect the following sequence of calls during load_state_dict: - -Signals the start of loading a checkpoint. - -Process the state_dict and produces a LoadPlan that will be sent for global planning. - -Takes the LoadPlan from all ranks and make any global decision. - -This is called once per non-tensor value in state_dict. - -They are called in pair for each Tensor value in state_dict. - -Users are recommended to extend DefaultLoadPlanner instead of this interface directly as most changes can be expressed by changes in a single method. - -There are two usual patterns of extension: - -Rewriting state_dict. This is the simplest way to extend the load process as it doesn’t requite understanding the intrincacies of how LoadPlan works. We need to keep a reference to the original state_dict as load happens in place so we need to be able to perform it in place - -Modifying resolve_tensor and commit_tensor to handle load time transformation. - -Call once the StorageReader finished loading data into tensor. - -The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post process tensor prior to copying it back to the one in the state_dict. - -The contents of tensor will follow its device synchronization model. - -Compute the global load plan and return plans for each rank. - -. N.B. This is called on the coordinator rank only - -list[torch.distributed.checkpoint.planner.LoadPlan] - -Create a LoadPlan based on state_dict and metadata provided by set_up_planner. - -. N.B. This is called on every rank. - -Accept the plan from coordinator and return final LoadPlan. - -Load the item described by read_item``and ``value. - -This method is expected to modify in-place the underlying state_dict. - -The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded. - -Return the BytesIO to be used by the StorageReader to load read_item. - -The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents. - -Return the tensor described by read_item to be used by the StorageReader to load read_item. - -The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict. - -Initialize this instance to load data into state_dict. - -. N.B. This is called on every rank. - -Abstract class defining the protocol used by save_state_dict to plan the save process. - -SavePlanners are stateful objects that can be used to customize the whole save process. - -SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process. - -A planner subclass can expect the following sequence of calls during save_state_dict: - -Signals the start of a checkpoint save. - -Process the state_dict and produces a SavePlan that will be sent for global planning. - -Takes the SavePlan from all ranks and make any global decision. - -This gives each rank a chance to adjust to global planning decisions. - -Lookups a value on the state_dict for the storage layer to write. - -Users are recommended to extend DefaultSavePlanner instead of this interface directly as most changes can be expressed by changes in a single method. - -There are 3 usual patterns of extension: - -Rewriting state_dict. This is the simplest way to extend the save process as it doesn’t requite understanding the intrincacies of how SavePlan works: - -Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted - -Using the global planning step to make central decisions that can’t be made individually by each rank - -Finally, some planners need to save additional metadata in the checkpoint, this is accomplished by having each rank contribute their data items in the local plan and the global planner aggregate them: - -Compute the global checkpoint plan and return the local plan of each rank. - -This is called on the coordinator rank only. - -tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata] - -Compute the save plan for the current rank. - -This will be aggregated and passed to create_global_plan. Planner specific data can be passed through SavePlan::planner_data. - -This is called on all ranks. - -Merge the plan created by create_local_plan and the result of create_global_plan. - -This is called on all ranks. - -Transform and prepare write_item from state_dict for storage, ensuring idempotency and thread-safety. - -Lookup the object associated with write_item in state_dict and apply any transformation (such as serialization) prior to the storage layer consuming it. - -Called on each rank multiple times, at least once per WriteItem in the final SavePlan. - -This method should be idempotent and thread-save. StorageWriter implementations are free to call it as frequently as they need. - -Any transformation that allocates memory should be lazily done when his method is called in order to reduce peak memory required by checkpointing. - -When returning tensors, they can be on any device or format, they can be views too. It’s the storage layer responsibility to figure out how to save them. - -Union[Tensor, BytesIO] - -Initialize this planner to save state_dict. - -Implementations should save those values as they won’t be provided lated in the save process. - -This is called on all ranks. - -Dataclass which holds information about what needs to be written to storage. - -Calculates the storage size of the underlying tensor, or None if this is not a tensor write. - -Optional[int] storage size, in bytes of underlying tensor if any. - -We provide a filesystem based storage layer: - -return the checkpoint_id that will be used to load the checkpoint. - -Basic implementation of StorageWriter using file IO. - -This implementation makes the following assumptions and simplifications: - -The checkpoint path is an empty or non-existing directory. - -File creation is atomic - -The checkpoint consist of one file per write request plus a global .metadata file with the serialized metadata if rank coordination is enabled. a rank local __{rank}.metadata file with the serialized metadata if rank coordination is NOT enabled. - -Override of AsyncStager.stage - -dict[str, Union[~StatefulT, Any]] - -We also provide other storage layers, including ones to interact with HuggingFace safetensors: - -.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageReader :members: - -.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter :members: - -.. autoclass:: torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader :members: - -We provide default implementations of LoadPlanner and SavePlanner that can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor. - -Extension from the planner interface to make it easy to extend the default planner. - -Extension from the planner interface to make it easy to extend the default planner. - -DefaultLoadPlanner that adds multiple features on top of LoadPlanner. - -In particular it adds the following: - -flatten_state_dict: Handle state_dict with nested dicts flatten_sharded_tensors: For FSDP in 2D parallel mode allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint. - -Extension from the planner interface to make it easy to extend the default planner. - -Extension from the planner interface to make it easy to extend the default planner. - -Due to legacy design decisions, the state dictionaries of FSDP and DDP may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, FSDP offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism). - -To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. get_model_state_dict() returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, get_optimizer_state_dict() provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, get_optimizer_state_dict() converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary. - -Note that results returned by these APIs can be used directly with the torch.distributed.checkpoint.save() and torch.distributed.checkpoint.load() methods without requiring any additional conversions. - -set_model_state_dict() and set_optimizer_state_dict() are provided to load the model and optimizer state_dict generated by by their respective getter APIs. - -Note that set_optimizer_state_dict() can only be called before backward() or after step() is called on optimizers. - -Note that this feature is experimental, and API signatures might change in the future. - -Return the model state_dict and optimizers state_dict. - -get_state_dict can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any combination of these parallelisms. The main functions of get_state_dict are: 1.) returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms. 2.) hiding the parallelism-specific state_dict APIs. Users don’t have to call these APIs. 3.) sanity checking the result state_dict. - -The keys of the result state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter’s position in an nn.Module hierarchy. More specifically, a canonical FQN to a parameter is the FQN returned by module.named_parameters() or module.named_buffers() when the module is not distributed by any parallelisms. Since the optimizer internally uses parameter IDs to represent a parameter, there will be a conversion from the parameter IDs to the canonical FQNs when calling this API. - -get_state_dict can also process a module that is not parallelized. In such a case, get_state_dict only performs one function – converting the optimizer parameter IDs to the canonical FQNs. - -model (nn.Module) – the nn.Module to the model. - -optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model. - -submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details. - -Tuple that contain model state_dict and optimizer state_dict. - -Tuple[Dict[str, ValueType], OptimizerStateType] - -Return the model state_dict of model. - -See get_state_dict for the detail usage. - -model (nn.Module) – the nn.Module to the model. - -submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details. - -The state_dict for model. - -Return the combined state_dict for optimizers. - -See get_state_dict for the detail usage. - -model (nn.Module) – the nn.Module to the model. - -optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model. - -submodules (deprecated) – Optional[set[nn.Module]]: only return the model parameters that belong to the submodules. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details. - -The state_dict for optimizers. - -Load the model state_dict and optimizers state_dict. - -The counterpart of get_state_dict to set the state_dict to the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements: 1) all FQNs are canonical FQNs as defined in get_state_dict, 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, 3) optimizer state_dict cannot contain the parameter IDs; the keys should be the canonical FQNs. - -is called on the optimizers. Otherwise, the optimizer states won’t be initialized correctly. - -model (nn.Module) – the nn.Module to the model. - -optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model. - -model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): the model state_dict to load. If the key of the model_state_dict is nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule will be append to the state_dict. - -optim_state_dict (OptimizerStateType) – OptimizerStateType: the optimizer state_dict to load. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details. - -missing_keys is a list of str containing the missing keys of the model state_dict. unexpected_keys is a list of str containing the unexpected keys of the model state_dict. - -missing_keys is a list of str containing the missing keys of the model state_dict. - -unexpected_keys is a list of str containing the unexpected keys of the model state_dict. - -NamedTuple with missing_keys and unexpected_keys fields - -Load the model state_dict. - -The counterpart of get_model_state_dict to set the state_dict to the model. See set_state_dict for the detail usage. - -model (nn.Module) – the nn.Module to the model. - -model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): the model state_dict to load. If the key of the model_state_dict is nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule will be append to the state_dict. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details. - -missing_keys is a list of str containing the missing keys unexpected_keys is a list of str containing the unexpected keys - -missing_keys is a list of str containing the missing keys - -unexpected_keys is a list of str containing the unexpected keys - -NamedTuple with missing_keys and unexpected_keys fields - -Load the optimizers state_dict. - -The counterpart of get_optimizer_state_dict to set the state_dict to the optimizers. See set_state_dict for the detail usage. - -step() is called on the optimizers. Otherwise, the optimizer states won’t be initialized correctly. - -model (nn.Module) – the nn.Module to the model. - -optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model. - -optim_state_dict (OptimizerStateType) – OptimizerStateType: the optimizer state_dict to load. - -options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details. - -This dataclass specifies how get_state_dict/set_state_dict will work. - -full_state_dict: if this is set to True, all the tensors in the returned state_dict will be gathered. No ShardedTensor and DTensor will be in the returned state_dict. - -cpu_offload: offload all the tensors to cpu. To prevent CPU OOM, if full_state_dict is also true, then only the rank0 will get the state_dict and all other ranks will get empty state_dict. - -ignore_frozen_params: if the value is True, the returned state_dict won’t contain any frozen parameters – the requires_grad is False. The default value is False. - -keep_submodule_prefixes (deprecated): when submodules is not None, this option indicates whether to keep the submodule prefixes from the state_dict keys. or example, if the submodule is module.pretrain and the full FQN of the parameter is pretrain.layer1.weight of the param. When this option is True, the parameter’s key in the returned state_dict will be pretrain.layer1.weight. If the options is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, there may be conflicted FQNs, hence there should be only one submodule in submodules. - -strict: the strict option when set_state_dict calls model.load_state_dict(). - -full state_dict and will broadcast the tensors in the state_dict/ optim_state_dict one by one to other ranks. Other ranks will receive the tensors and shard according to the local shards in the model and optimizer. full_state_dict must be set to True when using this option. This option currently only supports DTensor, not the legacy ShardedTensor. - -For users which are used to using and sharing models in the torch.save format, the following methods are provided which provide offline utilities for converting betweeing formats. - -Given a directory containing a DCP checkpoint, this function will convert it into a Torch save file. - -dcp_checkpoint_dir (Union[str, PathLike]) – Directory containing the DCP checkpoint. - -torch_save_path (Union[str, PathLike]) – Filename to store the converted Torch save file. - -To avoid OOM, it’s recommended to only run this function on a single rank. - -Given the location of a torch save file, converts it into a DCP checkpoint. - -torch_save_path (Union[str, PathLike]) – Filename of the Torch save file. - -dcp_checkpoint_dir (Union[str, PathLike]) – Directory to store the DCP checkpoint. - -To avoid OOM, it’s recommended to only run this function on a single rank. - -The following classes can also be utilized for online loading and resharding of models from the torch.save format. - -StorageReader for reading a Torch Save file. This reader will read the entire checkpoint on the coordinator rank, and then broadcast and shard each tensor to all ranks. - -. N.B. Intended to be used with DynamicMetaLoadPlanner - -Current implementation only supports loading Tensors. - -Implementation of the StorageReader method - -list[torch.distributed.checkpoint.planner.LoadPlan] - -Implementation of the StorageReader method - -Reads torch save data on the coordinator rank, and broadcast afterwards this incurrs a communication cost, but avoids having to load the entire checkpoint on each rank, hopefully preventing OOM issues - -Extends the default StorageReader to support building the metadata file - -Implementation of the StorageReader method - -Implementation of the StorageReader method - -Implementation of the StorageReader method - -Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, avoiding the need to read metadata from disk. This is useful when reading formats which don’t have a metadata file, like Torch Save files. - -. N.B. Intended to be used with BroadcastingTorchSaveReader - -Current implementation only supports loading Tensors. - -Setups of the planner, extnding default behavior by creating the Metadata object from the state dict - -The following experimental interfaces are provided for improved observability in production environments: - ---- - -## torch.distributed.tensor# - -**URL:** https://pytorch.org/docs/stable/distributed.tensor.html - -**Contents:** -- torch.distributed.tensor# -- PyTorch DTensor (Distributed Tensor)# - - DTensor Class APIs# - - DeviceMesh as the distributed communicator# - - DTensor Placement Types# -- Different ways to create a DTensor# - - Create DTensor from a logical torch.Tensor# - - DTensor Factory Functions# - - Random Operations# -- Debugging# - -Created On: Jun 13, 2025 | Last Updated On: Aug 23, 2025 - -torch.distributed.tensor is currently in alpha state and under development, we are committing backward compatibility for the most APIs listed in the doc, but there might be API changes if necessary. - -PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handles distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor could be used to build different parallelism solutions and support sharded state_dict representation when working with multi-dimensional sharding. - -Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor: - -DTensor follows the SPMD (single program, multiple data) programming model to empower users to write distributed program as if it’s a single-device program with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) through specifying the DeviceMesh and Placement: - -DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array. - -Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial. - -DTensor is a torch.Tensor subclass. This means once a DTensor is created, it could be used in very similar way to torch.Tensor, including running different types of PyTorch operators as if running them in a single device, allowing proper distributed computation for PyTorch operators. - -In addition to existing torch.Tensor methods, it also offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc. - -DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides single-device like abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and following types of Placement: - -Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension - -Replicate: Tensor replicated on the devices of the DeviceMesh dimension - -Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension - -When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor outputs. - -To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator be DTensor. - -Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly hence is not the public API). Please refer to the create_dtensor section to see how to create a DTensor. - -Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only has one element. - -This dunder method is primariy used for distributed checkpoint purpose. - -A List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank. - -Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified. - -local_tensor (torch.Tensor) – local torch.Tensor on each rank. - -device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor, if not specified, must be called under a DeviceMesh context manager, default: None - -placements (List[Placement], optional) – the placements that describes how to place the local torch.Tensor on DeviceMesh, must have the same number of elements as device_mesh.ndim. - -run_check (bool, optional) – at a cost of extra communications, perform sanity check across ranks to check each local tensor’s meta information to ensure correctness. If have Replicate in placements, the data on first rank of the device mesh dimension will be broadcasted to other ranks. default: False - -shape (torch.Size, optional) – A List of int which specifies the size of DTensor which build on top of local_tensor. Note this needs to be provided if the shape of local_tensor are different across the ranks. If not provided, shape will be computed assuming the given distributed tensor is evenly sharded across ranks. default: None - -stride (tuple, optional) – A List of int which specifies the stride of DTensor. If not provided, stride will be computed assuming the given distributed tensor is evenly sharded across ranks. default: None - -When run_check=False, it is the user’s responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined. - -from_local is differentiable, the requires_grad of the created DTensor object will depend on if local_tensor requires_grad or not. - -Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate them together. It’s a syntactic sugar of the following code: - -dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local() - -grad_placements (List[Placement], optional) – the placements describes the future layout of any gradient layout of the full Tensor returned from this function. full_tensor converts DTensor to a full torch.Tensor and the returned torch.tensor might not be used as the original replicated DTensor layout later in the code. This argument is the hint that user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we will assume the gradient layout of the full tensor be replicated. - -A torch.Tensor object that represents the full tensor of this DTensor. - -full_tensor is differentiable. - -redistribute performs necessary collective operations that redistribute the current DTensor from its current placements to a new placements, or from its current DeviceMesh to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh. - -When redistributing from current to the new placements on one device mesh dimension, we will perform the following operations including communication collective or local operation: - -Shard(dim) -> Replicate(): all_gather - -Shard(src_dim) -> Shard(dst_dim): all_to_all - -Replicate() -> Shard(dim): local chunking (i.e. torch.chunk) - -Partial() -> Replicate(): all_reduce - -Partial() -> Shard(dim): reduce_scatter - -redistribute would correctly figure out the necessary redistribute steps for DTensors that are created either on 1-D or N-D DeviceMesh. - -device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, it would use the current DTensor’s DeviceMesh. default: None - -placements (List[Placement], optional) – the new placements that describes how to place the DTensor into the DeviceMesh, must have the same number of elements as device_mesh.ndim. default: replicate on all mesh dimensions - -async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False - -forward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to forward_dtype before redistributing the local tensor in its forward. The result DTensor will be in forward_dtype Default: None. - -backward_dtype (torch.dtype, optional) – the local tensor datatype can be converted to backward_dtype before redistributing the local tensor in its backward. The result DTensor gradient would be converted back to the current DTensor dtype. Default: None - -redistribute is differentiable, which means user do not need to worry about the backward formula of the redistribute operation. - -redistribute currently only supports redistributing DTensor on the same DeviceMesh, Please file an issue if you need to redistribute DTensor to different DeviceMesh. - -Get the local tensor of this DTensor on its current rank. For sharding it returns a local shard of the logical tensor view, for replication it returns the replica on its current rank. - -grad_placements (List[Placement], optional) – the placements describes the future layout of any gradient layout of the Tensor returned from this function. to_local converts DTensor to local tensor and the returned local tensor might not be used as the original DTensor layout later in the code. This argument is the hint that user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we will assume the gradient layout remains the same as the original DTensor and use that for gradient computation. - -A torch.Tensor or AsyncCollectiveTensor object. it represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, it means the local tensor is not ready yet (i.e. communication is not finished). In this case, user needs to call wait to wait the local tensor to be ready. - -to_local is differentiable, the requires_grad of the local tensor returned will depend on if the DTensor requires_grad or not. - -The DeviceMesh attribute that associates with this DTensor object. - -device_mesh is a read-only property, it can not be set. - -The placements attribute of this DTensor that describes the layout of this DTensor on the its DeviceMesh. - -placements is a read-only property, it can not be set. - -DeviceMesh was built from DTensor as the abstraction to describe cluster’s device topology and represent multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create/use a DeviceMesh, please refer to the DeviceMesh recipe. - -DTensor supports the following types of Placement on each DeviceMesh dimension: - -The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) - -dim (int) – The tensor dimension that describes the DTensor is sharded over its corresponding DeviceMesh dimension. - -sharding on a tensor dimension where the tensor dimension size is not evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. - -The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.) - -The Partial(reduce_op) placement describes the DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. User can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which would trigger necessary communication operations under the hood (i.e. allreduce, reduce_scatter). - -reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”, default: “sum”. - -The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API. - -The base class for the Placement type, where it describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together could describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial. - -This class is not meant to be used directly, mainly served as a typing stub. - -distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This could be used to shard the leaf torch.Tensor s (i.e. model parameters/buffers and inputs). - -DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create DTensor from a non-leaf torch.Tensor s (i.e. intermediate activation tensors during forward/backward). - -DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) to allow different DTensor creations by directly specifying the DeviceMesh and Placement. Compare to distribute_tensor(), this could directly materializing the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory. - -The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (i.e. via torchrun) to execute the same program, this means that the model inside the program would be initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly on GPU if enough memory). - -DTensor offers a distribute_tensor() API that could shard the model weights or Tensors to DTensor s, where it would create a DTensor from the “logical” Tensor on each process. This would empower the created DTensor s to comply with the single device semantic, which is critical for numerical correctness. - -Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor, and the API would use the tensor from first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead. - -tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use torch.chunk semantic to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change. - -device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor, if not specified, must be called under a DeviceMesh context manager, default: None - -placements (List[Placement], optional) – the placements that describes how to place the tensor on DeviceMesh, must have the same number of elements as device_mesh.ndim. If not specified, we will by default replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh. - -src_data_rank (int, optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, distribute_tensor() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0 - -A DTensor or XLAShardedTensor object. - -When initialize the DeviceMesh with the xla device_type, distribute_tensor return XLAShardedTensor instead. see this issue for more details. The XLA integration is experimental and subject to change. - -Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding on the nn.Module level - -This function expose three functions to control the parameters/inputs/outputs of the module: - -1. To perform sharding on the module before runtime execution by specifying the partition_fn (i.e. allow user to convert Module parameters to DTensor parameters according to the partition_fn specified). 2. To control the inputs or outputs of the module during runtime execution by specifying the input_fn and output_fn. (i.e. convert the input to DTensor, convert the output back to torch.Tensor) - -module (nn.Module) – user module to be partitioned. - -device_mesh (DeviceMesh) – the device mesh to place the module. - -partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default we replicate all module parameters of module across the mesh. - -input_fn (Callable) – specify the input distribution, i.e. could control how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook). - -output_fn (Callable) – specify the output distribution, i.e. could control how the output is sharded, or convert it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook). - -A module that contains parameters/buffers that are all DTensor s. - -When initialize the DeviceMesh with the xla device_type, distribute_module return nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change. - -DTensor also provides dedicated tensor factory functions to allow creating DTensor directly using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally specifying the DeviceMesh and Placement for the DTensor created: - -Returns a DTensor filled with the scalar value 0. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). - -layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided. - -device_mesh – DeviceMesh type, contains the mesh info of ranks - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). - -layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided. - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -device_mesh – DeviceMesh type, contains the mesh info of ranks - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). layout (torch.layout, optional): the desired layout of returned DTensor. Default: torch.strided. - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -device_mesh – DeviceMesh type, contains the mesh info of ranks - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - -fill_value (Scalar) – the value to fill the output tensor with. - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). - -layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided. - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -device_mesh – DeviceMesh type, contains the mesh info of ranks. - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). - -layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided. - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -device_mesh – DeviceMesh type, contains the mesh info of ranks. - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size. - -size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - -dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()). - -layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided. - -requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False. - -device_mesh – DeviceMesh type, contains the mesh info of ranks. - -placements – a sequence of Placement type: Shard, Replicate - -A DTensor object on each rank - -DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating ranks (e.g. SPMD ranks) start out using the same generator state before each dtensor random operation is performed, and if this is true, it ensures they all end up at the same state after each dtensor random operation completes. There is no communication performed during random operations to synchronize RNG states. - -Operators that accept a generator kwarg will utilize the user-passed generator, if passed, or the default generator for the device otherwise. Whichever generator is used, it will be advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks if so. - -When using DTensor together with Pipeline Parallelism, ranks for each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed. - -DTensor’s RNG infra is based on the philox based RNG algorithm, and supports any philox based backend (cuda, and other cuda-like devices), but unfortunately does not yet support the CPU backend. - -When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging : - -TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it. - -TORCH_LOGS=dtensor will display logging.INFO messages and above. - -TORCH_LOGS=-dtensor will display logging.WARNING messages and above. - -To debug the program that applied DTensor, and understand more details about what collectives happened under the hood, DTensor provides a CommDebugMode: - -CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode. - -Not all collectives are supported yet. - -Generates detailed table displaying operations and collective tracing information on a module level. Amount of information is dependent on noise_level - -prints module-level collective counts - -prints dTensor operations not included in trivial operations, module information - -prints operations not included in trivial operations - -prints all operations - -Creates json file used to build browser visual 0. prints module-level collective counts 1. prints dTensor operations not included in trivial operations 2. prints operations not included in trivial operations 3. prints all operations - -Returns the communication counts as a dictionary. - -The communication counts as a dictionary. - -dict[str, dict[str, Any]] - -dict[str, dict[str, Any]] - -Alternative to console CommDebugMode output, writes to file specified by the user - -To visualize the sharding of a DTensor that have less than 3 dimensions, DTensor provides visualize_sharding(): - -Visualizes sharding in the terminal for DTensor that are 1D or 2D. - -This requires the tabulate package, or rich and matplotlib. No sharding info will be printed for empty tensors - -DTensor also provides a set of experimental features. These features are either in prototyping stage, or the basic functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to these features. - -context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) patch the SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, 2) shard buffers along the sequence dimension and each rank will preserve the corresponding shard according mesh. - -mesh (DeviceMesh) – the device mesh for the context parallelism. - -buffers (Optional[List[torch.Tensor]]) – buffers that the usage depend on the sequence dimension. Examples are input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure the accuracy. The sharding will happen in-place, the buffer’s shape will change within the context. The buffers will be restored after the context finishes. no_restore_buffers can be used to specify which buffers don’t need to be restored. Note that buffers should not contain any nn.Parameter. - -buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers. - -no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in these set won’t be restored after the context exits. This set must be a subset of buffers. If the buffers won’t be used after the context exits, these buffers can be put in this list to avoid extra restore time. - -Generator[None, None, None] - -torch.distributed.tensor.experimental.context_parallel is a prototype feature in PyTorch. The API is subject to change. - -local_map() is an experimental API that allows users to pass DTensor s to a function that is written to be applied on torch.Tensor s. It is done by extracting the local components of DTensor, call the function, and wrap the outputs to DTensor according to the out_placements. - -func (Callable) – the function to be applied on each local shard of DTensor s. - -out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensor s in func’s flattened output. If the flattened output is a single value, the out_placements should be of type PlacementType. Otherwise if the flattened output has multiple values, the out_placements should be a tuple of PlacementType values 1:1 mapping to the flattened output. Besides, for Tensor output, we use PlacementType as its placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType should be None. Note that the only exception is when no DTensor argument is passed in. In this case, even if out_placements is not None, the result function should ignore the desired placements because the function is not running with DTensor s. - -in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensor s in the flattened inputs of func. If in_placements is specified, local_map() would examine whether the placements of each DTensor argument is the same as the required placements or not. If the placements are not the same and redistribute_inputs is False, an exception will be raised. Otherwise if redistribute_inputs is True, the argument will be first redistributed to the required sharding placements before passing its local tensor to func. The only exception is when required placements are not None and the argument is a torch.Tensor. In this case, the placements examination will be skipped and the argument will be directly passed to func. If in_placements is None, no placements examination will be performed. Default: None - -in_grad_placements (Tuple[PlacementType, …], optional) – the placements hint of the DTensor s gradient corresponds to the flattened input DTensor. This argument is the hint that user can give to to_local() in case the gradient layout of the local tensor input does not match its DTensor input layout. If not specified, we will assume the gradient layout of the local tensor input remains the same as the original DTensor input and use that for gradient computation. Default: None. - -device_mesh (DeviceMesh, optional) – the device mesh that the output DTensor s are placed on. If not specified, this will be inferred from the first input DTensor’s device mesh. Default: None. - -redistribute_inputs (bool, optional) – the bool value indicating whether to reshard the input DTensor s when their placements are different from the required input placements. If this value is False and some DTensor input has a different placement, an exception will be raised. Default: False. - -A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func. - -AssertionError – For any non-DTensor output, we require its corresponding output placement in out_placements be None. An AssertionError will be raised if this is not the case. - -ValueError – If redistribute_inputs=False but the input DTensor needs a redistribution according to in_placements. - -This API is currently experimental and subject to change - -register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensor. It can be useful when: (1) there doesn’t exist a default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor; (2) when users would like to overwrite default sharding strategies of existing operators. - -op (Union[OpOverload, List[OpOverload]]) – An op or a list of ops to register the customized sharding function. - -A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its corresponding input placements. - -This API is currently experimental and subject to change - ---- - -## FullyShardedDataParallel# - -**URL:** https://pytorch.org/docs/stable/fsdp.html - -**Contents:** -- FullyShardedDataParallel# - -Created On: Feb 02, 2022 | Last Updated On: Jun 11, 2025 - -A wrapper for sharding module parameters across data parallel workers. - -This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP. - -Using FSDP involves wrapping your module and then initializing your optimizer after. This is required since FSDP changes the parameter variables. - -When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options: - -Place the module on that device - -Set the device using torch.cuda.set_device(dev_id) - -Pass dev_id into the device_id constructor argument. - -This ensures that the FSDP instance’s compute device is the destination device. For option 1 and 3, the FSDP initialization always occurs on GPU. For option 2, the FSDP initialization happens on module’s current device, which may be a CPU. - -If you’re using the sync_module_states=True flag, you need to ensure that the module is on a GPU or use the device_id argument to specify a CUDA device that FSDP will move the module to in the FSDP constructor. This is necessary because sync_module_states=True requires GPU communication. - -FSDP also takes care of moving input tensors to the forward method to the GPU compute device, so you don’t need to manually move them from CPU. - -For use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters, not the sharded parameters after forward, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True. - -With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in that way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution. - -FSDP replaces managed modules’ parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module’s forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP’s newly created views, and autograd will not work correctly. - -Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups. - -There are several limitations to be aware of when using FSDP: - -FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. This is because FSDP uses the newly-reduced gradient instead of accumulating with any existing gradient, which can lead to incorrect results. - -FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule’s parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately. - -FSDP does not work with double backwards due to the way it registers backward hooks. - -FSDP has some constraints when freezing parameters. For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen parameters, but it’s recommended to avoid doing so to prevent higher than expected gradient memory usage. - -As of PyTorch 1.12, FSDP offers limited support for shared parameters. If enhanced shared parameter support is needed for your use case, please post in this issue. - -You should avoid modifying the parameters between forward and backward without using the summon_full_params context, as the modifications may not persist. - -module (nn.Module) – This is the module to be wrapped with FSDP. - -process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group over which the model is sharded and thus the one used for FSDP’s all-gather and reduce-scatter collective communications. If None, then FSDP uses the default process group. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups, representing the groups over which to shard and replicate, respectively. If None, then FSDP constructs process groups for the user to shard intra-node and replicate inter-node. (Default: None) - -sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy, which may trade off memory saving and communication overhead. See ShardingStrategy for details. (Default: FULL_SHARD) - -cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None) - -auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) – This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should have FSDP applied if recurse=False or if the traversal should continue into the module’s subtree if recurse=True. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed. Example: >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) - -This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should have FSDP applied if recurse=False or if the traversal should continue into the module’s subtree if recurse=True. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed. - -backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. If None, then FSDP does not backward prefetch, and there is no communication and computation overlap in the backward pass. See BackwardPrefetch for details. (Default: BACKWARD_PRE) - -mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None) - -ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None) - -param_init_fn (Optional[Callable[[nn.Module], None]]) – A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. For both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX’s (pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX’s default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably case on the module type. FSDP calls the initialization function before parameter flattening and sharding. Example: >>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) - -A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. For both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX’s (pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX’s default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably case on the module type. FSDP calls the initialization function before parameter flattening and sharding. - -device_id (Optional[Union[int, torch.device]]) – An int or torch.device giving the CUDA device on which FSDP initialization takes place, including the module initialization if needed and the parameter sharding. This should be specified to improve initialization speed if module is on CPU. If the default CUDA device was set (e.g. via torch.cuda.set_device), then the user may pass torch.cuda.current_device to this. (Default: None) - -sync_module_states (bool) – If True, then each FSDP module will broadcast module parameters and buffers from rank 0 to ensure that they are replicated across ranks (adding communication overhead to this constructor). This can help load state_dict checkpoints via load_state_dict in a memory efficient way. See FullStateDictConfig for an example of this. (Default: False) - -forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration’s execution order. (Default: False) - -limit_all_gathers (bool) – If True, then FSDP explicitly synchronizes the CPU thread to ensure GPU memory usage from only two consecutive FSDP instances (the current instance running computation and the next instance whose all-gather is prefetched). If False, then FSDP allows the CPU thread to issue all-gathers without any extra synchronization. (Default: True) We often refer to this feature as the “rate limiter”. This flag should only be set to False for specific CPU-bound workloads with low memory pressure in which case the CPU thread can aggressively issue all kernels without concern for the GPU memory usage. - -use_orig_params (bool) – Setting this to True has FSDP use module ‘s original parameters. FSDP exposes those original parameters to the user via nn.Module.named_parameters() instead of FSDP’s internal FlatParameter s. This means that the optimizer step runs on the original parameters, enabling per-original-parameter hyperparameters. FSDP preserves the original parameter variables and manipulates their data between unsharded and sharded forms, where they are always views into the underlying unsharded or sharded FlatParameter, respectively. With the current algorithm, the sharded form is always 1D, losing the original tensor structure. An original parameter may have all, some, or none of its data present for a given rank. In the none case, its data will be like a size-0 empty tensor. Users should not author programs relying on what data is present for a given original parameter in its sharded form. True is required to use torch.compile(). Setting this to False exposes FSDP’s internal FlatParameter s to the user via nn.Module.named_parameters(). (Default: False) - -ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – Ignored parameters or modules that will not be managed by this FSDP instance, meaning that the parameters are not sharded and their gradients are not reduced across ranks. This argument unifies with the existing ignored_modules argument, and we may deprecate ignored_modules soon. For backward compatibility, we keep both ignored_states and ignored_modules`, but FSDP only allows one of them to be specified as not None. - -device_mesh (Optional[DeviceMesh]) – DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP will use the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, these two args need to be mutually exclusive. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a 2D DeviceMesh instead of a tuple of process groups. For 2D FSDP + TP, users are required to pass in device_mesh instead of process_group. For more DeviceMesh info, please visit: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html - -Apply fn recursively to every submodule (as returned by .children()) as well as self. - -Typical use includes initializing the parameters of a model (see also torch.nn.init). - -Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context. - -fn (Module -> None) – function to be applied to each submodule - -Check if this instance is a root FSDP module. - -Clip the gradient norm of all parameters. - -The norm is computed over all parameters’ gradients as viewed as a single vector, and the gradients are modified in-place. - -max_norm (float or int) – max norm of the gradients - -norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm. - -Total norm of the parameters (viewed as a single vector). - -If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_(). - -If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks. - -The total norm returned will have the “largest” dtype across all parameters/gradients as defined by PyTorch’s type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm’s dtype will be that low precision dtype, but if there exists at least one parameter/ gradient using FP32, then the returned norm’s dtype will be FP32. - -This needs to be called on all ranks since it uses collective communications. - -Flatten a sharded optimizer state-dict. - -The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensor s. - -sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state. - -model (torch.nn.Module) – Refer to shard_full_optim_state_dict(). - -optim (torch.optim.Optimizer) – Optimizer for model ‘s parameters. - -Refer to shard_full_optim_state_dict(). - -Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic. - -Return all nested FSDP instances. - -This possibly includes module itself and only includes FSDP root modules if root_only=True. - -module (torch.nn.Module) – Root module, which may or may not be an FSDP module. - -root_only (bool) – Whether to return only FSDP root modules. (Default: False) - -FSDP modules that are nested in the input module. - -List[FullyShardedDataParallel] - -Return the full optimizer state-dict. - -Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters. - -This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict. - -Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs. - -Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save(). - -model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim. - -optim (torch.optim.Optimizer) – Optimizer for model ‘s parameters. - -optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None) - -rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True) - -group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None) - -A dict containing the optimizer state for model ‘s original unflattened parameters and including keys “state” and “param_groups” following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict. - -Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module. - -The target module does not have to be an FSDP module. - -A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set. - -AssertionError` if the StateDictSettings for differen – - -FSDP submodules differ. – - -Return the wrapped module. - -Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself. - -Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager. - -Iterator[tuple[str, torch.Tensor]] - -Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself. - -Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager. - -Iterator[tuple[str, torch.nn.parameter.Parameter]] - -Disable gradient synchronizations across FSDP instances. - -Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances. - -This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync. - -When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync. - -Transform the state-dict of an optimizer corresponding to a sharded model. - -The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict. - -For full optimizer state_dict, all states are unflattened and not sharded. Rank0 only and CPU only can be specified via state_dict_type() to avoid OOM. - -For sharded optimizer state_dict, all states are unflattened but sharded. CPU only can be specified via state_dict_type() to further save memory. - -For local state_dict, no transformation will be performed. But a state will be converted from nn.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet). - -model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim. - -optim (torch.optim.Optimizer) – Optimizer for model ‘s parameters. - -optim_state_dict (Dict[str, Any]) – the target optimizer state_dict to transform. If the value is None, optim.state_dict() will be used. ( Default: None) - -group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. ( Default: None) - -A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type. - -Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. - -Given a optim_state_dict that is transformed through optim_state_dict(), it gets converted to the flattened optimizer state_dict that can be loaded to optim which is the optimizer for model. model must be sharded by FullyShardedDataParallel. - -model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim. - -optim (torch.optim.Optimizer) – Optimizer for model ‘s parameters. - -optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded. - -is_named_optimizer (bool) – Is this optimizer a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec’s KeyedOptimizer or torch.distributed’s NamedOptimizer. - -load_directly (bool) – If this is set to True, this API will also call optim.load_state_dict(result) before returning the result. Otherwise, users are responsible to call optim.load_state_dict() (Default: False) - -group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. ( Default: None) - -Register a communication hook. - -This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel. - -FSDP communication hook should be registered before running an initial forward pass and only once. - -state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker. - -Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker. - -hook (Callable) – Callable, which has one of the following signatures: 1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors, the first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The latter represents a pre-sized tensor to store a chunk of a sharded gradient after reduction. In both cases, callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for a NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases. - -Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type. - -This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without. - -To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model: - -To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model: - -The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type. - -Scatter the full optimizer state dict from rank 0 to all other ranks. - -Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict(). - -Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost. - -full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks. - -model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict. - -optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None) - -optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None) - -group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None) - -The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state. - -Set the state_dict_type of all the descendant FSDP modules of the target module. - -Also takes (optional) configuration for the model’s and optimizer’s state dict. The target module does not have to be a FSDP module. If the target module is a FSDP module, its state_dict_type will also be changed. - -This API should be called for only the top-level (root) module. - -This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into sharded_state_dict implementation for FSDP: - -module (torch.nn.Module) – Root module. - -state_dict_type (StateDictType) – the desired state_dict_type to set. - -state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type. - -optim_state_dict_config (Optional[OptimStateDictConfig]) – the configuration for the optimizer state dict. - -A StateDictSettings that include the previous state_dict type and configuration for the module. - -Shard a full optimizer state-dict. - -Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict(). - -Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost. - -full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state. - -model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict. - -optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None) - -optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None) - -The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state. - -Return the optimizer state-dict in its sharded form. - -The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived with the context manager with state_dict_type(SHARDED_STATE_DICT):. - -For the detailed usage, refer to full_optim_state_dict(). - -The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict. - -Set the state_dict_type of all the descendant FSDP modules of the target module. - -This context manager has the same functions as set_state_dict_type(). Read the document of set_state_dict_type() for the detail. - -module (torch.nn.Module) – Root module. - -state_dict_type (StateDictType) – the desired state_dict_type to set. - -state_dict_config (Optional[StateDictConfig]) – the model state_dict configuration for the target state_dict_type. - -optim_state_dict_config (Optional[OptimStateDictConfig]) – the optimizer state_dict configuration for the target state_dict_type. - -Expose full params for FSDP instances with this context manager. - -Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument. - -This can be used on inner FSDPs. - -This can not be used within a forward or backward pass. Nor can forward and backward be started from within this context. - -Parameters will revert to their local shards after the context manager exits, storage behavior is the same as forward. - -The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, currently only when world_size == 1, or NO_SHARD config, the modification is persisted regardless of writeback. - -This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units. - -Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited. - -Note that offload_to_cpu and rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True. - -recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True). - -writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True) - -rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited. - -offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (which is only not the case for world_size = 1 or NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory. - -with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False) - -This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage. - -BACKWARD_PRE: This enables the most overlap but increases memory usage the most. This prefetches the next set of parameters before the current set of parameters’ gradient computation. This overlaps the next all-gather and the current gradient computation, and at the peak, it holds the current set of parameters, next set of parameters, and current set of gradients in memory. - -BACKWARD_POST: This enables less overlap but requires less memory usage. This prefetches the next set of parameters after the current set of parameters’ gradient computation. This overlaps the current reduce-scatter and the next gradient computation, and it frees the current set of parameters before allocating memory for the next set of parameters, only holding the next set of parameters and current set of gradients in memory at the peak. - -FSDP’s backward_prefetch argument accepts None, which disables the backward prefetching altogether. This has no overlap and does not increase memory usage. In general, we do not recommend this setting since it may degrade throughput significantly. - -For more technical context: For a single process group using NCCL backend, any collectives, even if issued from different streams, contend for the same per-device NCCL stream, which implies that the relative order in which the collectives are issued matters for overlapping. The two backward prefetching values correspond to different issue orders. - -This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel. - -FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank. - -SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation. - -NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks similar to PyTorch’s DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank. - -HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This results in reduced communication volume as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium -sized models. - -_HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except this may provide even higher throughput since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward. - -This configures FSDP-native mixed precision training. - -param_dtype (Optional[torch.dtype]) – This specifies the dtype for model parameters during forward and backward and thus the dtype for forward and backward computation. Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision. (Default: None) - -reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then this takes on the param_dtype value, still running gradient reduction in low precision. This is permitted to differ from param_dtype, e.g. to force gradient reduction to run in full precision. (Default: None) - -buffer_dtype (Optional[torch.dtype]) – This specifies the dtype for buffers. FSDP does not shard buffers. Rather, FSDP casts them to buffer_dtype in the first forward pass and keeps them in that dtype thereafter. For model checkpointing, the buffers are saved in full precision except for LOCAL_STATE_DICT. (Default: None) - -keep_low_precision_grads (bool) – If False, then FSDP upcasts gradients to full precision after the backward pass in preparation for the optimizer step. If True, then FSDP keeps the gradients in the dtype used for gradient reduction, which can save memory if using a custom optimizer that supports running in low precision. (Default: False) - -cast_forward_inputs (bool) – If True, then this FSDP module casts its forward args and kwargs to param_dtype. This is to ensure that parameter and input dtypes match for forward computation, as required by many ops. This may need to be set to True when only applying mixed precision to some but not all FSDP modules, in which case a mixed-precision FSDP submodule needs to recast its inputs. (Default: False) - -cast_root_forward_inputs (bool) – If True, then the root FSDP module casts its forward args and kwargs to param_dtype, overriding the value of cast_forward_inputs. For non-root FSDP modules, this does not do anything. (Default: True) - -_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): This specifies module classes to ignore for mixed precision when using an auto_wrap_policy: Modules of these classes will have FSDP applied to them separately with mixed precision disabled (meaning that the final FSDP construction would deviate from the specified policy). If auto_wrap_policy is not specified, then this does not do anything. This API is experimental and subject to change. (Default: (_BatchNorm,)) - -This API is experimental and subject to change. - -Only floating point tensors are cast to their specified dtypes. - -In summon_full_params, parameters are forced to full precision, but buffers are not. - -Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP’s mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient, so if the workload permits, the user should prefer to still apply mixed precision to those modules. - -By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument. - -MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass. - -For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure casting inputs or not before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should have its non-FSDP submodules run before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration. - -The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones. - -This configures CPU offloading. - -offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU. - -StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP. - -offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False) - -FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts to save GPU memory and CPU memory, respectively. This config class is meant to be used via the state_dict_type() context manager as follows: - -rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False) - -ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT. - -_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False) - -_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor. - -OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP. - -offload_to_cpu (bool) – If True, then FSDP offloads the state dict’s tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True) - -rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False) - -ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT. - -_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False) - -_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor. - ---- - -## Distributed Optimizers# - -**URL:** https://pytorch.org/docs/stable/distributed.optim.html - -**Contents:** -- Distributed Optimizers# - -Created On: Mar 01, 2021 | Last Updated On: Jun 16, 2025 - -Distributed optimizer is not currently supported when using CUDA tensors - -torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any of the local optimizer Base class to apply the gradients on each worker. - -DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter. - -This class uses get_gradients() in order to retrieve the gradients for specific parameters. - -Concurrent calls to step(), either from the same or different clients, will be serialized on each worker – as each worker’s optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers. - -DistributedOptimizer creates the local optimizer with TorchScript enabled by default, so that optimizer updates are not blocked by the Python Global Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed Model Parallel). This feature is currently enabled for most optimizers. You can also follow the recipe in PyTorch tutorials to enable TorchScript support for your own custom optimizers. - -optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker. - -params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize. - -args – arguments to pass to the optimizer constructor on each worker. - -kwargs – arguments to pass to the optimizer constructor on each worker. - -Performs a single optimization step. - -This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The provided context_id will be used to retrieve the corresponding context that contains the gradients that should be applied to the parameters. - -context_id – the autograd context id for which we should run the optimizer step. - -Wraps an arbitrary torch.optim.Optimizer and runs post-local SGD, This optimizer runs local optimizer at every step. After the warm-up stage, it averages parameters periodically after the local optimizer is applied. - -optim (Optimizer) – The local optimizer. - -averager (ModelAverager) – A model averager instance to run post-localSGD algorithm. - -This is the same as torch.optim.Optimizer load_state_dict(), but also restores model averager’s step value to the one saved in the provided state_dict. - -If there is no "step" entry in state_dict, it will raise a warning and initialize the model averager’s step to 0. - -This is the same as torch.optim.Optimizer state_dict(), but adds an extra entry to record model averager’s step to the checkpoint to ensure reload does not cause unnecessary warm up again. - -Performs a single optimization step (parameter update). - -Wrap an arbitrary optim.Optimizer and shards its states across ranks in the group. - -The sharing is done as described by ZeRO. - -The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size parameters and hence only needs to keep 1 / world_size optimizer states. After parameters are updated locally, each rank will broadcast its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption. - -ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the the parameter registration or usage order. - -params (Iterable) – an Iterable of torch.Tensor s or dict s giving all parameters, which will be sharded across ranks. - -optimizer_class (torch.nn.Optimizer) – the class of the local optimizer. - -process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()). - -parameters_as_bucket_view (bool, optional) – if True, parameters are packed into buckets to speed up communication, and param.data fields point to bucket views at different offsets; if False, each individual parameter is communicated separately, and each params.data stays intact (default: False). - -overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel ‘s gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py; parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass (per normal). (default: False) - -**defaults – any trailing arguments, which are forwarded to the local optimizer. - -Currently, ZeroRedundancyOptimizer requires that all of the passed-in parameters are the same dense type. - -If you pass overlap_with_ddp=True, be wary of the following: Given the way that overlapping DistributedDataParallel with ZeroRedundancyOptimizer is currently implemented, the first two or three training iterations do not perform parameter updates in the optimizer step, depending on if static_graph=False or static_graph=True, respectively. This is because it needs information about the gradient bucketing strategy used by DistributedDataParallel, which is not finalized until the second forward pass if static_graph=False or until the third forward pass if static_graph=True. To adjust for this, one option is to prepend dummy inputs. - -ZeroRedundancyOptimizer is experimental and subject to change. - -Add a parameter group to the Optimizer ‘s param_groups. - -This can be useful when fine tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses. - -param_group (dict) – specifies the parameters to be optimized and group-specific optimization options. - -This method handles updating the shards on all partitions but needs to be called on all ranks. Calling this on a subset of the ranks will cause the training to hang because communication primitives are called depending on the managed parameters and expect all the ranks to participate on the same set of parameters. - -Consolidate a list of state_dict s (one per rank) on the target rank. - -to (int) – the rank that receives the optimizer states (default: 0). - -RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt. - -This needs to be called on all ranks. - -Return default device. - -Return the ZeRO join hook. - -It enables training on uneven inputs by shadowing the collective communications in the optimizer step. - -Gradients must be properly set before this hook is called. - -kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs. - -This hook does not support any keyword arguments; i.e. kwargs is unused. - -Return process group. - -Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed. - -state_dict (dict) – optimizer state; should be an object returned from a call to state_dict(). - -RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt. - -Return the last global optimizer state known to this rank. - -RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt; or if this method is called without a preceding call to consolidate_state_dict(). - -Perform a single optimizer step and syncs parameters across all ranks. - -closure (Callable) – a closure that re-evaluates the model and returns the loss; optional for most optimizers. - -Optional loss depending on the underlying local optimizer. - -Any extra parameters are passed to the base optimizer as-is. - ---- - -## Torch Distributed Elastic# - -**URL:** https://pytorch.org/docs/stable/distributed.elastic.html - -**Contents:** -- Torch Distributed Elastic# -- Get Started# -- Documentation# - -Created On: Jun 16, 2025 | Last Updated On: Jul 25, 2025 - -Makes distributed PyTorch fault-tolerant and elastic. - ---- - -## Pipeline Parallelism# - -**URL:** https://pytorch.org/docs/stable/distributed.pipelining.html - -**Contents:** -- Pipeline Parallelism# -- Why Pipeline Parallel?# -- What is torch.distributed.pipelining?# -- Step 1: build PipelineStage# -- Step 2: use PipelineSchedule for execution# -- Options for Splitting a Model# - - Option 1: splitting a model manually# - - Option 2: splitting a model automatically# -- Hugging Face Examples# -- Technical Deep Dive# - -Created On: Jun 16, 2025 | Last Updated On: Aug 13, 2025 - -torch.distributed.pipelining is currently in alpha state and under development. API changes may be possible. It was migrated from the PiPPy project. - -Pipeline Parallelism is one of the primitive parallelism for deep learning. It allows the execution of a model to be partitioned such that multiple micro-batches can execute different parts of the model code concurrently. Pipeline parallelism can be an effective technique for: - -bandwidth-limited clusters - -large model inference - -The above scenarios share a commonality that the computation per device cannot hide the communication of conventional parallelism, for example, the weight all-gather of FSDP. - -While promising for scaling, pipelining is often difficult to implement because it needs to partition the execution of a model in addition to model weights. The partitioning of execution often requires intrusive code changes to your model. Another aspect of complexity comes from scheduling micro-batches in a distributed environment, with data flow dependency considered. - -The pipelining package provides a toolkit that does said things automatically which allows easy implementation of pipeline parallelism on general models. - -It consists of two parts: a splitting frontend and a distributed runtime. The splitting frontend takes your model code as-is, splits it up into “model partitions”, and captures the data-flow relationship. The distributed runtime executes the pipeline stages on different devices in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. - -Overall, the pipelining package provides the following features: - -Splitting of model code based on simple specification. - -Rich support for pipeline schedules, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS, and providing the infrastructure for writing customized schedules. - -First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). - -Composability with other PyTorch parallel techniques such as data parallel (DDP, FSDP) or tensor parallel. The TorchTitan project demonstrates a “3D parallel” application on the Llama model. - -Before we can use a PipelineSchedule, we need to create PipelineStage objects that wrap the part of the model running in that stage. The PipelineStage is responsible for allocating communication buffers and creating send/recv ops to communicate with its peers. It manages intermediate buffers e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backwards for the stage model. - -A PipelineStage needs to know the input and output shapes for the stage model, so that it can correctly allocate communication buffers. The shapes must be static, e.g. at runtime the shapes can not change from step to step. A class PipeliningShapeError will be raised if runtime shapes do not match the expected shapes. When composing with other paralleisms or applying mixed precision, these techniques must be taken into account so the PipelineStage knows the correct shape (and dtype) for the output of the stage module at runtime. - -Users may construct a PipelineStage instance directly, by passing in an nn.Module representing the portion of the model that should run on the stage. This may require changes to the original model code. See the example in Option 1: splitting a model manually. - -Alternatively, the splitting frontend can use graph partitioning to split your model into a series of nn.Module automatically. This technique requires the model is traceable with torch.Export. Composability of the resulting nn.Module with other parallelism techniques is experimental, and may require some workarounds. Usage of this frontend may be more appealing if the user cannot easily change the model code. See Option 2: splitting a model automatically for more information. - -We can now attach the PipelineStage to a pipeline schedule, and run the schedule with input data. Here is a GPipe example: - -Note that the above code needs to be launched for each worker, thus we use a launcher service to launch multiple processes: - -To directly construct a PipelineStage, the user is responsible for providing a single nn.Module instance that owns the relevant nn.Parameters and nn.Buffers, and defines a forward() method that executes the operations relevant for that stage. For example, a condensed version of the Transformer class defined in Torchtitan shows a pattern of building an easily partitionable model. - -A model defined in this manner can be easily configured per stage by first initializing the whole model (using meta-device to avoid OOM errors), deleting undesired layers for that stage, and then creating a PipelineStage that wraps the model. For example: - -When composing with other Data or Model parallelism techniques, output_args may also be required, if the output shape/dtype of the model chunk will be affected. - -If you have a full model and do not want to spend time on modifying it into a sequence of “model partitions”, the pipeline API is here to help. Here is a brief example: - -If we print the model, we can see multiple hierarchies, which makes it hard to split by hand: - -Let us see how the pipeline API works: - -The pipeline API splits your model given a split_spec, where SplitPoint.BEGINNING stands for adding a split point before execution of certain submodule in the forward function, and similarly, SplitPoint.END for split point after such. - -If we print(pipe), we can see: - -The “model partitions” are represented by submodules (submod_0, submod_1), each of which is reconstructed with original model operations, weights and hierarchies. In addition, a “root-level” forward function is reconstructed to capture the data flow between those partitions. Such data flow will be replayed by the pipeline runtime later, in a distributed fashion. - -The Pipe object provides a method for retrieving the “model partitions”: - -The returned stage_mod is a nn.Module, with which you can create an optimizer, save or load checkpoints, or apply other parallelisms. - -Pipe also allows you to create a distributed stage runtime on a device given a ProcessGroup: - -Alternatively, if you would like to build the stage runtime later after some modification to the stage_mod, you can use a functional version of the build_stage API. For example: - -The pipeline frontend uses a tracer (torch.export) to capture your model into a single graph. If your model is not full-graph’able, you can use our manual frontend below. - -In the PiPPy repo where this package was original created, we kept examples based on unmodified Hugging Face models. See the examples/huggingface directory. - -First, the pipeline API turns our model into a directed acyclic graph (DAG) by tracing the model. It traces the model using torch.export – a PyTorch 2 full-graph capturing tool. - -Then, it groups together the operations and parameters needed by a stage into a reconstructed submodule: submod_0, submod_1, … - -Different from conventional submodule access methods like Module.children(), the pipeline API does not only cut the module structure of your model, but also the forward function of your model. - -This is necessary because model structure like Module.children() merely captures information during Module.__init__(), and does not capture any information about Module.forward(). Said differently, Module.children() lacks information about the following aspects key to pipelininig: - -Execution order of child modules in forward - -Activation flows between child modules - -Whether there are any functional operators between child modules (for example, relu or add operations will not be captured by Module.children()). - -The pipeline API, on the contrary, makes sure that the forward behavior is truly preserved. It also captures the activation flow between the partitions, helping the distributed runtime to make correct send/receive calls without human intervention. - -Another flexibility of the pipeline API is that split points can be at arbitrary levels within your model hierarchy. In the split partitions, the original model hierarchy related to that partition will be reconstructed at no cost to you. At a result, fully-qualified names (FQNs) pointing to a submodule or parameter would be still valid, and services that relies on FQNs (such as FSDP, TP or checkpointing) can still run with your partitioned modules with almost zero code change. - -You can implement your own pipeline schedule by extending one of the following two class: - -PipelineScheduleSingle - -PipelineScheduleMulti - -PipelineScheduleSingle is for schedules that assigns only one stage per rank. PipelineScheduleMulti is for schedules that assigns multiple stages per rank. - -For example, ScheduleGPipe and Schedule1F1B are subclasses of PipelineScheduleSingle. Whereas, ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble, and ScheduleZBVZeroBubble are subclasses of PipelineScheduleMulti. - -You can turn on additional logging using the TORCH_LOGS environment variable from torch._logging: - -TORCH_LOGS=+pp will display logging.DEBUG messages and all levels above it. - -TORCH_LOGS=pp will display logging.INFO messages and above. - -TORCH_LOGS=-pp will display logging.WARNING messages and above. - -The following set of APIs transform your model into a pipeline representation. - -Enum representing the points at which a split can occur in the execution of a submodule. :ivar BEGINNING: Represents adding a split point before the execution of a certain submodule in the forward function. :ivar END: Represents adding a split point after the execution of a certain submodule in the forward function. - -Split a module based on a specification. - -See Pipe for more details. - -module (Module) – The module to be split. - -mb_args (tuple[Any, ...]) – Example positional inputs, in micro-batch form. - -mb_kwargs (Optional[dict[str, Any]]) – Example keyword inputs, in micro-batch form. (default: None) - -split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – A dictionary using submodule names as split marker. (default: None) - -split_policy (Optional[Callable[[GraphModule], GraphModule]]) – The policy to use for splitting the module. (default: None) - -A pipeline representation of class Pipe. - -pipe_split is a special operator that is used to mark the boundary between stages in a module. It is used to split the module into stages. It is a no-op if your annotated module is run eagerly. - -The above example will be split into two stages. - -Class used to specify chunking of inputs - -Given a sequence of args and kwargs, split them into a number of chunks according to their respective chunking specs. - -args (tuple[Any, ...]) – Tuple of args - -kwargs (Optional[dict[str, Any]]) – Dict of kwargs - -chunks (int) – Number of chunks to split the args and kwargs into - -args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) – chunking specs for args, in same shape as args - -kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) – chunking specs for kwargs, in same shape as kwargs - -List of sharded args kwargs_split: List of sharded kwargs - -Given a list of chunks, merge them into a single value according to the chunk spec. - -chunks (list[Any]) – list of chunks - -chunk_spec – Chunking spec for the chunks - -A class representing a pipeline stage in a pipeline parallelism setup. - -PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from one chunk feed into inputs of the next chunk, with no skip connections. - -PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to stage1 and so forth, in linear order. To bypass shape inference, pass the input_args and output_args to each PipelineStage instance. - -submodule (nn.Module) – The PyTorch module wrapped by this stage. - -stage_index (int) – The ID of this stage. - -num_stages (int) – The total number of stages. - -device (torch.device) – The device where this stage is located. - -input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – The input arguments for the submodule. - -output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – The output arguments for the submodule. - -group (dist.ProcessGroup, optional) – The process group for distributed training. If None, default group. - -dw_builder (Optional[Callable[[], Callable[..., None]]) – If provided, dw_builder will build a new dw_runner function that will the W action (input weights) for F, I, W (Fwd, Input, Weight) zero bubble schedules. - -Create a pipeline stage given a stage_module to be wrapped by this stage and pipeline information. - -stage_module (torch.nn.Module) – the module to be wrapped by this stage - -stage_index (int) – the index of this stage in the pipeline - -pipe_info (PipeInfo) – information about the pipeline, can be retrieved by pipe.info() - -device (torch.device) – the device to be used by this stage - -group (Optional[dist.ProcessGroup]) – the process group to be used by this stage - -a pipeline stage that can run with PipelineSchedules. - -The GPipe schedule. Will go through all the microbatches in a fill-drain manner. - -The 1F1B schedule. Will perform one forward and one backward on the microbatches in steady state. - -The Interleaved 1F1B schedule. See https://arxiv.org/pdf/2104.04473 for details. Will perform one forward and one backward on the microbatches in steady state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch (also called “depth first”). - -This schedule is mostly similar to the original paper. It differs by being relaxing the requirement of num_microbatch % pp_size == 0. Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and it works as long as n_microbatches % num_rounds is 0. As a few examples, support - -pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. - -pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. - -Breadth-First Pipeline Parallelism. See https://arxiv.org/abs/2211.05953 for details. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. What is different is that when microbatches are ready for multiple local stages, Loops BFS will prioritizes the earlier stage, running all available microbatches at once. - -The Interleaved Zero Bubble schedule. See https://arxiv.org/pdf/2401.10241 for details. Will perform one forward and one backward on inputs for the microbatches in steady state and supports multiple stages per rank. Uses the backward for weights to fill in the pipeline bubble. - -In particular this is implementing the ZB1P schedule in the paper. - -The Zero Bubble schedule (ZBV variant). See https://arxiv.org/pdf/2401.10241 Section 6 for details. - -This schedules requires exactly two stages per rank. - -This schedule will perform one forward and one backward on inputs for the microbatches in steady state and supports multiple stages per rank. Uses backward with respect to weights to fill in the pipeline bubble. - -This ZB-V schedule would have the “zero bubble” property only if time forward == time backward input == time backward weights. In practice, this is not likely true for real models so alternatively a greedy scheduler could be implemented for unequal/unbalanced time. - -The DualPipeV schedule. A more efficient schedule variant based on the DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437 - -Based on the open sourced code from deepseek-ai/DualPipe - -Base class for single-stage schedules. Implements the step method. Derived classes should implement _step_microbatches. - -Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False). - -Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation. - -args: positional arguments to the model (as in non-pipeline case). kwargs: keyword arguments to the model (as in non-pipeline case). target: target for the loss function. losses: a list to store the losses for each microbatch. - -Base class for multi-stage schedules. Implements the step method. - -Gradients are scaled by num_microbatches depending on the scale_grads argument, defaulting to True. This setting should match the configuration of your loss_fn, which may either average losses (scale_grads=True) or sum losses (scale_grads=False). - -Run one iteration of the pipeline schedule with whole-batch input. Will chunk the input into microbatches automatically, and go through the microbatches according to the schedule implementation. - -args: positional arguments to the model (as in non-pipeline case). kwargs: keyword arguments to the model (as in non-pipeline case). target: target for the loss function. losses: a list to store the losses for each microbatch. - ---- - -## Tensor Parallelism - torch.distributed.tensor.parallel# - -**URL:** https://pytorch.org/docs/stable/distributed.tensor.parallel.html - -**Contents:** -- Tensor Parallelism - torch.distributed.tensor.parallel# - -Created On: Jun 13, 2025 | Last Updated On: Jun 13, 2025 - -Tensor Parallelism(TP) is built on top of the PyTorch DistributedTensor (DTensor)[https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md] and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism. - -Tensor Parallelism APIs are experimental and subject to change. - -The entrypoint to parallelize your nn.Module using Tensor Parallelism is: - -Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. - -We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains ParallelStyle, which indicates how user wants the module or sub_module to be parallelized. - -User can also specify different parallel style per module fully qualified name (FQN). - -Note that parallelize_module only accepts a 1-D DeviceMesh, if you have a 2-D or N-D DeviceMesh, slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. device_mesh["tp"]) - -module (nn.Module) – Module to be parallelized. - -device_mesh (DeviceMesh, optional) – Object which describes the mesh topology of devices for the DTensor. If not specified, the call must be under a DeviceMesh context. - -parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – The plan used to parallelize the module. It can be either a ParallelStyle object which contains how we prepare input/output for Tensor Parallelism or it can be a dict of module FQN and its corresponding ParallelStyle object. If not specified, the call will do nothing at the moment. - -src_data_rank (int, optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, parallelize_module() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0 - -A nn.Module object parallelized. - -For complex module architecture like Attention, MLP layers, we recommend composing different ParallelStyles together (i.e. ColwiseParallel and RowwiseParallel) and pass as a parallelize_plan, to achieves the desired sharding computation. - -Tensor Parallelism supports the following parallel styles: - -Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) - -input_layouts (Placement, optional) – The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be replicated. - -output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is sharded on the last dimension. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True. - -A ParallelStyle object that represents Colwise sharding of the nn.Module. - -By default ColwiseParallel output is sharded on the last dimension if the output_layouts not specified, if there’re operators that require specific tensor shape (i.e. before the paired RowwiseParallel), keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. - -Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) - -input_layouts (Placement, optional) – The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. - -output_layouts (Placement, optional) – The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is replicated. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True. - -A ParallelStyle object that represents Rowwise sharding of the nn.Module. - -SequenceParallel replicates a compatible nn.Module parameters and runs the sharded computation with input sharded on the sequence dimension. This currently supports nn.LayerNorm, nn.Dropout, and the RMSNorm python implementation - -This style implements the operation that is described in the paper Reducing Activation Recomputation in Large Transformer Models - -If the input passed in to this nn.Module is a torch.Tensor, it assumes that the input is already sharded on the sequence dimension and converts the input to a DTensor sharded on the sequence dimension. If the input passed in to this nn.Module is already a DTensor but is not sharded on the sequence dimension, it would redistribute the input to be sharded on the sequence dimension. - -The output of the nn.Module will be sharded on the sequence dimension. - -sequence_dim (int, optional) – The sequence dimension of the input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor that is sharded on the sequence dimension, default: 1. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: False. - -A ParallelStyle object that represents Sequence Parallel of the nn.Module. - -SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. nn.LayerNorm or RMSNorm, and they by default have ones initialization). If you have custom inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. - -To simply configure the nn.Module’s inputs and outputs with DTensor layouts and perform necessary layout redistributions, without distribute the module parameters to DTensors, the following ParallelStyle s can be used in the parallelize_plan when calling parallelize_module: - -Configure the nn.Module’s inputs to convert the input tensors of the nn.Module to DTensors at runtime according to input_layouts, and perform layout redistribution according to the desired_input_layouts. - -input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, None need to be specified as a placeholder. default: None. - -desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length with input_layouts. default: None. - -input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None - -desired_input_kwarg_layouts – (Dict[str, Placement]): The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False. - -A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs. - -Configure the nn.Module’s outputs to convert the output tensors of the nn.Module to DTensors at runtime according to output_layouts, and perform layout redistribution according to the desired_output_layouts. - -output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or no need to convert to DTensors, None need to be specified as a placeholder. - -desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True. - -A ParallelStyle object that prepares the sharding layouts of the nn.Module’s outputs. - -Configure the nn.Module’s inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module to DTensors at runtime according to input_layouts (and output_layouts, respectively), and perform layout redistribution according to the desired_input_layouts (and desired_output_layouts, respectively). This is a combination of PrepareModuleInput and PrepareModuleOutput. - -input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, None need to be specified as a placeholder. default: None. - -desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length with input_layouts. default: None. - -input_kwarg_layouts (Dict[str, Placement]) – The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None - -desired_input_kwarg_layouts – (Dict[str, Placement]): The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None. - -use_local_input (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False. - -output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are torch.Tensor. If some outputs are not torch.Tensor or no need to convert to DTensors, None need to be specified as a placeholder. - -desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts. - -use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True. - -A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs and outputs. - -when using the Shard(dim) as the input/output layouts for the above ParallelStyle s, we assume the input/output activation tensors are evenly sharded on the tensor dimension dim on the DeviceMesh that TP operates on. For instance, since RowwiseParallel accepts input that is sharded on the last dimension, it assumes the input tensor has already been evenly sharded on the last dimension. For the case of uneven sharded activation tensors, one could pass in DTensor directly to the partitioned modules, and use use_local_output=False to return DTensor after each ParallelStyle, where DTensor could track the uneven sharding information. - -For models like Transformer, we recommend users to use ColwiseParallel and RowwiseParallel together in the parallelize_plan for achieve the desired sharding for the entire model (i.e. Attention and MLP). - -Parallelized cross-entropy loss computation (loss parallelism), is supported via the following context manager: - -A context manager that enables loss parallelism, where efficient parallelized loss computation can be performed when the input is sharded on the class dimension. Currently only the cross-entropy loss is supported. - -Within this context manager, one can use cross_entropy() or CrossEntropyLoss as usual, with the following assumptions on the input parameters. The corresponding backward() call, if any, also needs to happen under this context manager. - -input (DTensor) – Input logits. Assumed to be sharded on the class dimension. - -target (Union[torch.Tensor, DTensor]) – Must be ground truth class indices (class probabilities currently not supported). Assumed to be replicated across the DeviceMesh. - -weight (Union[torch.Tensor, DTensor], optional) – If given, assumed to be replicated across the DeviceMesh. - -label_smoothing – Currently not supported. - -A replicated DTensor. - -A sharded DTensor is manually created here to showcase the usage. In practice, it is usually the output of a TP module. - ---- diff --git a/skills/mlops/pytorch-lightning/SKILL.md b/skills/mlops/pytorch-lightning/SKILL.md deleted file mode 100644 index b55f288ac..000000000 --- a/skills/mlops/pytorch-lightning/SKILL.md +++ /dev/null @@ -1,349 +0,0 @@ ---- -name: pytorch-lightning -description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [lightning, torch, transformers] -metadata: - hermes: - tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable] - ---- - -# PyTorch Lightning - High-Level Training Framework - -## Quick start - -PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility. - -**Installation**: -```bash -pip install lightning -``` - -**Convert PyTorch to Lightning** (3 steps): - -```python -import lightning as L -import torch -from torch import nn -from torch.utils.data import DataLoader, Dataset - -# Step 1: Define LightningModule (organize your PyTorch code) -class LitModel(L.LightningModule): - def __init__(self, hidden_size=128): - super().__init__() - self.model = nn.Sequential( - nn.Linear(28 * 28, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, 10) - ) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = nn.functional.cross_entropy(y_hat, y) - self.log('train_loss', loss) # Auto-logged to TensorBoard - return loss - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=1e-3) - -# Step 2: Create data -train_loader = DataLoader(train_dataset, batch_size=32) - -# Step 3: Train with Trainer (handles everything else!) -trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2) -model = LitModel() -trainer.fit(model, train_loader) -``` - -**That's it!** Trainer handles: -- GPU/TPU/CPU switching -- Distributed training (DDP, FSDP, DeepSpeed) -- Mixed precision (FP16, BF16) -- Gradient accumulation -- Checkpointing -- Logging -- Progress bars - -## Common workflows - -### Workflow 1: From PyTorch to Lightning - -**Original PyTorch code**: -```python -model = MyModel() -optimizer = torch.optim.Adam(model.parameters()) -model.to('cuda') - -for epoch in range(max_epochs): - for batch in train_loader: - batch = batch.to('cuda') - optimizer.zero_grad() - loss = model(batch) - loss.backward() - optimizer.step() -``` - -**Lightning version**: -```python -class LitModel(L.LightningModule): - def __init__(self): - super().__init__() - self.model = MyModel() - - def training_step(self, batch, batch_idx): - loss = self.model(batch) # No .to('cuda') needed! - return loss - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters()) - -# Train -trainer = L.Trainer(max_epochs=10, accelerator='gpu') -trainer.fit(LitModel(), train_loader) -``` - -**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed - -### Workflow 2: Validation and testing - -```python -class LitModel(L.LightningModule): - def __init__(self): - super().__init__() - self.model = MyModel() - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = nn.functional.cross_entropy(y_hat, y) - self.log('train_loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - val_loss = nn.functional.cross_entropy(y_hat, y) - acc = (y_hat.argmax(dim=1) == y).float().mean() - self.log('val_loss', val_loss) - self.log('val_acc', acc) - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - test_loss = nn.functional.cross_entropy(y_hat, y) - self.log('test_loss', test_loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=1e-3) - -# Train with validation -trainer = L.Trainer(max_epochs=10) -trainer.fit(model, train_loader, val_loader) - -# Test -trainer.test(model, test_loader) -``` - -**Automatic features**: -- Validation runs every epoch by default -- Metrics logged to TensorBoard -- Best model checkpointing based on val_loss - -### Workflow 3: Distributed training (DDP) - -```python -# Same code as single GPU! -model = LitModel() - -# 8 GPUs with DDP (automatic!) -trainer = L.Trainer( - accelerator='gpu', - devices=8, - strategy='ddp' # Or 'fsdp', 'deepspeed' -) - -trainer.fit(model, train_loader) -``` - -**Launch**: -```bash -# Single command, Lightning handles the rest -python train.py -``` - -**No changes needed**: -- Automatic data distribution -- Gradient synchronization -- Multi-node support (just set `num_nodes=2`) - -### Workflow 4: Callbacks for monitoring - -```python -from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor - -# Create callbacks -checkpoint = ModelCheckpoint( - monitor='val_loss', - mode='min', - save_top_k=3, - filename='model-{epoch:02d}-{val_loss:.2f}' -) - -early_stop = EarlyStopping( - monitor='val_loss', - patience=5, - mode='min' -) - -lr_monitor = LearningRateMonitor(logging_interval='epoch') - -# Add to Trainer -trainer = L.Trainer( - max_epochs=100, - callbacks=[checkpoint, early_stop, lr_monitor] -) - -trainer.fit(model, train_loader, val_loader) -``` - -**Result**: -- Auto-saves best 3 models -- Stops early if no improvement for 5 epochs -- Logs learning rate to TensorBoard - -### Workflow 5: Learning rate scheduling - -```python -class LitModel(L.LightningModule): - # ... (training_step, etc.) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) - - # Cosine annealing - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=100, - eta_min=1e-5 - ) - - return { - 'optimizer': optimizer, - 'lr_scheduler': { - 'scheduler': scheduler, - 'interval': 'epoch', # Update per epoch - 'frequency': 1 - } - } - -# Learning rate auto-logged! -trainer = L.Trainer(max_epochs=100) -trainer.fit(model, train_loader) -``` - -## When to use vs alternatives - -**Use PyTorch Lightning when**: -- Want clean, organized code -- Need production-ready training loops -- Switching between single GPU, multi-GPU, TPU -- Want built-in callbacks and logging -- Team collaboration (standardized structure) - -**Key advantages**: -- **Organized**: Separates research code from engineering -- **Automatic**: DDP, FSDP, DeepSpeed with 1 line -- **Callbacks**: Modular training extensions -- **Reproducible**: Less boilerplate = fewer bugs -- **Tested**: 1M+ downloads/month, battle-tested - -**Use alternatives instead**: -- **Accelerate**: Minimal changes to existing code, more flexibility -- **Ray Train**: Multi-node orchestration, hyperparameter tuning -- **Raw PyTorch**: Maximum control, learning purposes -- **Keras**: TensorFlow ecosystem - -## Common issues - -**Issue: Loss not decreasing** - -Check data and model setup: -```python -# Add to training_step -def training_step(self, batch, batch_idx): - if batch_idx == 0: - print(f"Batch shape: {batch[0].shape}") - print(f"Labels: {batch[1]}") - loss = ... - return loss -``` - -**Issue: Out of memory** - -Reduce batch size or use gradient accumulation: -```python -trainer = L.Trainer( - accumulate_grad_batches=4, # Effective batch = batch_size × 4 - precision='bf16' # Or 'fp16', reduces memory 50% -) -``` - -**Issue: Validation not running** - -Ensure you pass val_loader: -```python -# WRONG -trainer.fit(model, train_loader) - -# CORRECT -trainer.fit(model, train_loader, val_loader) -``` - -**Issue: DDP spawns multiple processes unexpectedly** - -Lightning auto-detects GPUs. Explicitly set devices: -```python -# Test on CPU first -trainer = L.Trainer(accelerator='cpu', devices=1) - -# Then GPU -trainer = L.Trainer(accelerator='gpu', devices=1) -``` - -## Advanced topics - -**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks. - -**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup. - -**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps. - -## Hardware requirements - -- **CPU**: Works (good for debugging) -- **Single GPU**: Works -- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed -- **Multi-node**: DDP, FSDP, DeepSpeed -- **TPU**: Supported (8 cores) -- **Apple MPS**: Supported - -**Precision options**: -- FP32 (default) -- FP16 (V100, older GPUs) -- BF16 (A100/H100, recommended) -- FP8 (H100) - -## Resources - -- Docs: https://lightning.ai/docs/pytorch/stable/ -- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+ -- Version: 2.5.5+ -- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples -- Discord: https://discord.gg/lightning-ai -- Used by: Kaggle winners, research labs, production teams - - diff --git a/skills/mlops/pytorch-lightning/references/callbacks.md b/skills/mlops/pytorch-lightning/references/callbacks.md deleted file mode 100644 index 3d65ffa2d..000000000 --- a/skills/mlops/pytorch-lightning/references/callbacks.md +++ /dev/null @@ -1,436 +0,0 @@ -# PyTorch Lightning Callbacks - -## Overview - -Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging. - -## Built-In Callbacks - -### 1. ModelCheckpoint - -**Saves best models during training**: - -```python -from lightning.pytorch.callbacks import ModelCheckpoint - -# Save top 3 models based on validation loss -checkpoint = ModelCheckpoint( - dirpath='checkpoints/', - filename='model-{epoch:02d}-{val_loss:.2f}', - monitor='val_loss', - mode='min', - save_top_k=3, - save_last=True, # Also save last epoch - verbose=True -) - -trainer = L.Trainer(callbacks=[checkpoint]) -trainer.fit(model, train_loader, val_loader) -``` - -**Configuration options**: -```python -checkpoint = ModelCheckpoint( - monitor='val_acc', # Metric to monitor - mode='max', # 'max' for accuracy, 'min' for loss - save_top_k=5, # Keep best 5 models - save_last=True, # Save last epoch separately - every_n_epochs=1, # Save every N epochs - save_on_train_epoch_end=False, # Save on validation end instead - filename='best-{epoch}-{val_acc:.3f}', # Naming pattern - auto_insert_metric_name=False # Don't auto-add metric to filename -) -``` - -**Load checkpoint**: -```python -# Load best model -best_model_path = checkpoint.best_model_path -model = LitModel.load_from_checkpoint(best_model_path) - -# Resume training -trainer = L.Trainer(callbacks=[checkpoint]) -trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt') -``` - -### 2. EarlyStopping - -**Stops training when metric stops improving**: - -```python -from lightning.pytorch.callbacks import EarlyStopping - -early_stop = EarlyStopping( - monitor='val_loss', - patience=5, # Wait 5 epochs - mode='min', - min_delta=0.001, # Minimum change to qualify as improvement - verbose=True, - strict=True, # Crash if monitored metric not found - check_on_train_epoch_end=False # Check on validation end -) - -trainer = L.Trainer(callbacks=[early_stop]) -trainer.fit(model, train_loader, val_loader) -# Stops automatically if no improvement for 5 epochs -``` - -**Advanced usage**: -```python -early_stop = EarlyStopping( - monitor='val_loss', - patience=10, - min_delta=0.0, - verbose=True, - mode='min', - stopping_threshold=0.1, # Stop if val_loss < 0.1 - divergence_threshold=5.0, # Stop if val_loss > 5.0 - check_finite=True # Stop on NaN/Inf -) -``` - -### 3. LearningRateMonitor - -**Logs learning rate**: - -```python -from lightning.pytorch.callbacks import LearningRateMonitor - -lr_monitor = LearningRateMonitor( - logging_interval='epoch', # Or 'step' - log_momentum=True # Also log momentum -) - -trainer = L.Trainer(callbacks=[lr_monitor]) -# Learning rate automatically logged to TensorBoard/WandB -``` - -### 4. TQDMProgressBar - -**Customizes progress bar**: - -```python -from lightning.pytorch.callbacks import TQDMProgressBar - -progress_bar = TQDMProgressBar( - refresh_rate=10, # Update every 10 batches - process_position=0 -) - -trainer = L.Trainer(callbacks=[progress_bar]) -``` - -### 5. GradientAccumulationScheduler - -**Dynamic gradient accumulation**: - -```python -from lightning.pytorch.callbacks import GradientAccumulationScheduler - -# Accumulate more gradients as training progresses -accumulator = GradientAccumulationScheduler( - scheduling={ - 0: 8, # Epochs 0-4: accumulate 8 batches - 5: 4, # Epochs 5-9: accumulate 4 batches - 10: 2 # Epochs 10+: accumulate 2 batches - } -) - -trainer = L.Trainer(callbacks=[accumulator]) -``` - -### 6. StochasticWeightAveraging (SWA) - -**Averages weights for better generalization**: - -```python -from lightning.pytorch.callbacks import StochasticWeightAveraging - -swa = StochasticWeightAveraging( - swa_lrs=1e-2, # SWA learning rate - swa_epoch_start=0.8, # Start at 80% of training - annealing_epochs=10, # Annealing period - annealing_strategy='cos' # 'cos' or 'linear' -) - -trainer = L.Trainer(callbacks=[swa]) -``` - -## Custom Callbacks - -### Basic Custom Callback - -```python -from lightning.pytorch.callbacks import Callback - -class PrintingCallback(Callback): - def on_train_start(self, trainer, pl_module): - print("Training is starting!") - - def on_train_end(self, trainer, pl_module): - print("Training is done!") - - def on_epoch_end(self, trainer, pl_module): - print(f"Epoch {trainer.current_epoch} ended") - -# Use it -trainer = L.Trainer(callbacks=[PrintingCallback()]) -``` - -### Advanced Custom Callback - -```python -class MetricsCallback(Callback): - """Logs custom metrics every N batches.""" - - def __init__(self, log_every_n_batches=100): - self.log_every_n_batches = log_every_n_batches - self.metrics = [] - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if batch_idx % self.log_every_n_batches == 0: - # Compute custom metric - metric = self.compute_metric(outputs) - self.metrics.append(metric) - - # Log to Lightning - pl_module.log('custom_metric', metric) - - def compute_metric(self, outputs): - # Your custom logic - return outputs['loss'].item() - - def state_dict(self): - """Save callback state in checkpoint.""" - return {'metrics': self.metrics} - - def load_state_dict(self, state_dict): - """Restore callback state from checkpoint.""" - self.metrics = state_dict['metrics'] -``` - -### Gradient Monitoring Callback - -```python -class GradientMonitorCallback(Callback): - """Monitor gradient norms.""" - - def on_after_backward(self, trainer, pl_module): - # Compute gradient norm - total_norm = 0.0 - for p in pl_module.parameters(): - if p.grad is not None: - param_norm = p.grad.data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm ** 0.5 - - # Log - pl_module.log('grad_norm', total_norm) - - # Warn if exploding - if total_norm > 100: - print(f"Warning: Large gradient norm: {total_norm:.2f}") -``` - -### Model Inspection Callback - -```python -class ModelInspectionCallback(Callback): - """Inspect model activations during training.""" - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - if batch_idx == 0: # First batch of epoch - # Register hooks - self.activations = {} - - def get_activation(name): - def hook(model, input, output): - self.activations[name] = output.detach() - return hook - - # Attach to specific layers - pl_module.model.layer1.register_forward_hook(get_activation('layer1')) - pl_module.model.layer2.register_forward_hook(get_activation('layer2')) - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if batch_idx == 0: - # Log activation statistics - for name, activation in self.activations.items(): - mean = activation.mean().item() - std = activation.std().item() - pl_module.log(f'{name}_mean', mean) - pl_module.log(f'{name}_std', std) -``` - -## Callback Hooks - -**All available hooks**: - -```python -class MyCallback(Callback): - # Setup/Teardown - def setup(self, trainer, pl_module, stage): - """Called at beginning of fit/test/predict.""" - pass - - def teardown(self, trainer, pl_module, stage): - """Called at end of fit/test/predict.""" - pass - - # Training - def on_train_start(self, trainer, pl_module): - pass - - def on_train_epoch_start(self, trainer, pl_module): - pass - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - pass - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - pass - - def on_train_epoch_end(self, trainer, pl_module): - pass - - def on_train_end(self, trainer, pl_module): - pass - - # Validation - def on_validation_start(self, trainer, pl_module): - pass - - def on_validation_epoch_start(self, trainer, pl_module): - pass - - def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - pass - - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - pass - - def on_validation_epoch_end(self, trainer, pl_module): - pass - - def on_validation_end(self, trainer, pl_module): - pass - - # Test (same structure as validation) - def on_test_start(self, trainer, pl_module): - pass - # ... (test_epoch_start, test_batch_start, etc.) - - # Predict - def on_predict_start(self, trainer, pl_module): - pass - # ... (predict_epoch_start, predict_batch_start, etc.) - - # Backward - def on_before_backward(self, trainer, pl_module, loss): - pass - - def on_after_backward(self, trainer, pl_module): - pass - - # Optimizer - def on_before_optimizer_step(self, trainer, pl_module, optimizer): - pass - - # Checkpointing - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - """Add data to checkpoint.""" - pass - - def on_load_checkpoint(self, trainer, pl_module, checkpoint): - """Restore data from checkpoint.""" - pass -``` - -## Combining Multiple Callbacks - -```python -from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor - -# Create all callbacks -checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3) -early_stop = EarlyStopping(monitor='val_loss', patience=5) -lr_monitor = LearningRateMonitor(logging_interval='epoch') -custom_callback = MyCustomCallback() - -# Add all to Trainer -trainer = L.Trainer( - callbacks=[checkpoint, early_stop, lr_monitor, custom_callback] -) - -trainer.fit(model, train_loader, val_loader) -``` - -**Execution order**: Callbacks execute in the order they're added - -## Best Practices - -### 1. Keep Callbacks Independent - -**Bad** (dependent on other callback): -```python -class BadCallback(Callback): - def on_train_end(self, trainer, pl_module): - # Assumes ModelCheckpoint is present - best_path = trainer.checkpoint_callback.best_model_path # Fragile! -``` - -**Good** (self-contained): -```python -class GoodCallback(Callback): - def on_train_end(self, trainer, pl_module): - # Find checkpoint callback if present - for callback in trainer.callbacks: - if isinstance(callback, ModelCheckpoint): - best_path = callback.best_model_path - break -``` - -### 2. Use State Dict for Persistence - -```python -class StatefulCallback(Callback): - def __init__(self): - self.counter = 0 - self.history = [] - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - self.counter += 1 - self.history.append(outputs['loss'].item()) - - def state_dict(self): - """Save state.""" - return { - 'counter': self.counter, - 'history': self.history - } - - def load_state_dict(self, state_dict): - """Restore state.""" - self.counter = state_dict['counter'] - self.history = state_dict['history'] -``` - -### 3. Handle Distributed Training - -```python -class DistributedCallback(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - # Only run on main process - if trainer.is_global_zero: - print("This only prints once in distributed training") - - # Run on all processes - loss = outputs['loss'] - # ... do something with loss on each GPU -``` - -## Resources - -- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html -- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks -- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks diff --git a/skills/mlops/pytorch-lightning/references/distributed.md b/skills/mlops/pytorch-lightning/references/distributed.md deleted file mode 100644 index 886b3c75a..000000000 --- a/skills/mlops/pytorch-lightning/references/distributed.md +++ /dev/null @@ -1,490 +0,0 @@ -# PyTorch Lightning Distributed Training - -## Distributed Strategies - -Lightning supports multiple distributed strategies with a single parameter change. - -### 1. DDP (DistributedDataParallel) - -**Default strategy for multi-GPU**: - -```python -# Automatic DDP on all available GPUs -trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp') - -# Or auto-detect -trainer = L.Trainer(accelerator='gpu', devices='auto') -``` - -**How DDP works**: -- Replicates model on each GPU -- Each GPU processes different batch -- Gradients all-reduced across GPUs -- Model weights synchronized - -**Launch**: -```bash -# Lightning handles spawning processes automatically -python train.py -``` - -**DDP Configuration**: -```python -from lightning.pytorch.strategies import DDPStrategy - -strategy = DDPStrategy( - find_unused_parameters=False, # Set True if model has unused params - gradient_as_bucket_view=True, # Memory optimization - static_graph=False, # Set True if graph doesn't change -) - -trainer = L.Trainer(strategy=strategy) -``` - -### 2. FSDP (Fully Sharded Data Parallel) - -**For large models (7B+ parameters)**: - -```python -from lightning.pytorch.strategies import FSDPStrategy - -strategy = FSDPStrategy( - sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent - activation_checkpointing=None, # Or specify layer types - cpu_offload=False, # CPU offload for memory -) - -trainer = L.Trainer( - accelerator='gpu', - devices=8, - strategy=strategy, - precision='bf16' # Recommended with FSDP -) - -trainer.fit(model, train_loader) -``` - -**FSDP Sharding Strategies**: -```python -# FULL_SHARD (most memory efficient, equivalent to ZeRO-3) -strategy = FSDPStrategy(sharding_strategy="FULL_SHARD") - -# SHARD_GRAD_OP (less memory efficient, equivalent to ZeRO-2) -strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP") - -# NO_SHARD (no sharding, like DDP) -strategy = FSDPStrategy(sharding_strategy="NO_SHARD") -``` - -**Auto-wrap policy** (wrap transformer blocks): -```python -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformers.models.gpt2.modeling_gpt2 import GPT2Block -import functools - -auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={GPT2Block} -) - -strategy = FSDPStrategy( - auto_wrap_policy=auto_wrap_policy, - activation_checkpointing_policy={GPT2Block} # Checkpoint these blocks -) -``` - -### 3. DeepSpeed - -**For massive models (70B+ parameters)**: - -```python -from lightning.pytorch.strategies import DeepSpeedStrategy - -# DeepSpeed ZeRO-3 with CPU offload -strategy = DeepSpeedStrategy( - stage=3, # ZeRO-3 - offload_optimizer=True, # CPU offload optimizer - offload_parameters=True, # CPU offload parameters - cpu_checkpointing=True, # Checkpoint to CPU -) - -trainer = L.Trainer( - accelerator='gpu', - devices=8, - strategy=strategy, - precision='bf16' -) - -trainer.fit(model, train_loader) -``` - -**DeepSpeed configuration file**: -```json -{ - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "gradient_accumulation_steps": "auto", - "zero_optimization": { - "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, - "offload_param": { - "device": "cpu", - "pin_memory": true - }, - "overlap_comm": true, - "contiguous_gradients": true, - "reduce_bucket_size": 5e8, - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e6 - }, - "bf16": { - "enabled": true - } -} -``` - -**Use config file**: -```python -strategy = DeepSpeedStrategy(config='deepspeed_config.json') -trainer = L.Trainer(strategy=strategy) -``` - -### 4. DDP Spawn - -**Windows-compatible DDP**: - -```python -# Use when DDP doesn't work (e.g., Windows, Jupyter) -trainer = L.Trainer( - accelerator='gpu', - devices=2, - strategy='ddp_spawn' # Spawns new processes -) -``` - -**Note**: Slower than DDP due to process spawning overhead - -## Multi-Node Training - -### Setup Multi-Node Cluster - -**Node 0 (master)**: -```bash -export MASTER_ADDR=192.168.1.100 -export MASTER_PORT=12355 -export WORLD_SIZE=16 # 2 nodes × 8 GPUs -export NODE_RANK=0 - -python train.py -``` - -**Node 1 (worker)**: -```bash -export MASTER_ADDR=192.168.1.100 -export MASTER_PORT=12355 -export WORLD_SIZE=16 -export NODE_RANK=1 - -python train.py -``` - -**Training script**: -```python -trainer = L.Trainer( - accelerator='gpu', - devices=8, # GPUs per node - num_nodes=2, # Total nodes - strategy='ddp' -) - -trainer.fit(model, train_loader) -``` - -### SLURM Integration - -**SLURM job script**: -```bash -#!/bin/bash -#SBATCH --nodes=4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gres=gpu:8 -#SBATCH --time=24:00:00 - -# Lightning auto-detects SLURM environment -srun python train.py -``` - -**Training script** (no changes needed): -```python -# Lightning automatically reads SLURM environment variables -trainer = L.Trainer( - accelerator='gpu', - devices=8, - num_nodes=4, # From SBATCH --nodes - strategy='ddp' -) -``` - -### Kubernetes (KubeFlow) - -**Training script**: -```python -import os - -# Lightning auto-detects Kubernetes -trainer = L.Trainer( - accelerator='gpu', - devices=int(os.getenv('WORLD_SIZE', 1)), - strategy='ddp' -) -``` - -## Mixed Precision Training - -### BF16 (A100/H100) - -```python -trainer = L.Trainer( - precision='bf16', # Or 'bf16-mixed' - accelerator='gpu' -) -``` - -**Advantages**: -- No gradient scaler needed -- Same dynamic range as FP32 -- 2× speedup, 50% memory reduction - -### FP16 (V100, older GPUs) - -```python -trainer = L.Trainer( - precision='16-mixed', # Or just '16' - accelerator='gpu' -) -``` - -**Automatic gradient scaling** handled by Lightning - -### FP8 (H100) - -```python -# Requires transformer_engine -# pip install transformer-engine[pytorch] - -trainer = L.Trainer( - precision='transformer-engine', - accelerator='gpu' -) -``` - -**Benefits**: 2× faster than BF16 on H100 - -## Gradient Accumulation - -**Simulate larger batch size**: - -```python -trainer = L.Trainer( - accumulate_grad_batches=4, # Accumulate 4 batches - precision='bf16' -) - -# Effective batch = batch_size × accumulate_grad_batches × num_gpus -# Example: 32 × 4 × 8 = 1024 -``` - -**Dynamic accumulation**: -```python -# Accumulate more early in training -trainer = L.Trainer( - accumulate_grad_batches={ - 0: 8, # Epochs 0-4: accumulate 8 - 5: 4, # Epochs 5-9: accumulate 4 - 10: 2 # Epochs 10+: accumulate 2 - } -) -``` - -## Checkpointing in Distributed - -### Save Checkpoint - -```python -from lightning.pytorch.callbacks import ModelCheckpoint - -# Only rank 0 saves by default -checkpoint = ModelCheckpoint( - dirpath='checkpoints/', - filename='model-{epoch:02d}', - save_top_k=3 -) - -trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp') -trainer.fit(model, train_loader) -``` - -**Manual save**: -```python -class MyModel(L.LightningModule): - def training_step(self, batch, batch_idx): - # Training... - loss = ... - - # Save every 1000 steps (only rank 0) - if batch_idx % 1000 == 0 and self.trainer.is_global_zero: - self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt') - - return loss -``` - -### Load Checkpoint - -```python -# Resume training -trainer = L.Trainer(strategy='ddp') -trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt') - -# Load for inference -model = MyModel.load_from_checkpoint('checkpoints/best.ckpt') -model.eval() -``` - -## Strategy Comparison - -| Strategy | Memory Efficiency | Speed | Use Case | -|----------|------------------|-------|----------| -| DDP | Low | Fast | Small models (<7B), single node | -| FSDP | High | Medium | Large models (7-70B) | -| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) | -| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) | -| DDP Spawn | Low | Slow | Windows, debugging | - -## Best Practices - -### 1. Choose Right Strategy - -```python -# Model size guide -if model_params < 1e9: # <1B - strategy = 'ddp' -elif model_params < 7e9: # 1-7B - strategy = 'ddp' or DeepSpeedStrategy(stage=2) -elif model_params < 70e9: # 7-70B - strategy = FSDPStrategy(sharding_strategy="FULL_SHARD") -else: # 70B+ - strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True) - -trainer = L.Trainer(strategy=strategy) -``` - -### 2. Avoid Sync Issues - -```python -class MyModel(L.LightningModule): - def training_step(self, batch, batch_idx): - # WRONG: This runs on all GPUs independently - if batch_idx % 100 == 0: - self.log_something() # Logged 8 times on 8 GPUs! - - # CORRECT: Use is_global_zero - if batch_idx % 100 == 0 and self.trainer.is_global_zero: - self.log_something() # Logged once - - loss = ... - return loss -``` - -### 3. Efficient Data Loading - -```python -from torch.utils.data import DataLoader, DistributedSampler - -# Lightning handles DistributedSampler automatically -train_loader = DataLoader( - dataset, - batch_size=32, - num_workers=4, # 4 workers per GPU - pin_memory=True, - persistent_workers=True -) - -# Lightning automatically wraps with DistributedSampler in DDP -trainer.fit(model, train_loader) -``` - -### 4. Reduce Communication Overhead - -```python -from lightning.pytorch.strategies import DDPStrategy - -strategy = DDPStrategy( - gradient_as_bucket_view=True, # Reduce memory copies - static_graph=True, # If model graph doesn't change (faster) -) - -trainer = L.Trainer(strategy=strategy) -``` - -## Common Issues - -### Issue: NCCL Timeout - -**Symptom**: Training hangs with `NCCL timeout` error - -**Solution 1**: Increase timeout -```bash -export NCCL_TIMEOUT=3600 # 1 hour -python train.py -``` - -**Solution 2**: Check network -```bash -# Test inter-node communication -nvidia-smi nvlink -s - -# Verify all nodes can ping each other -ping -``` - -### Issue: OOM with FSDP - -**Solution**: Enable CPU offload -```python -strategy = FSDPStrategy( - sharding_strategy="FULL_SHARD", - cpu_offload=True # Offload to CPU -) -``` - -### Issue: Different Results with DDP - -**Cause**: Different random seeds per GPU - -**Solution**: Set seed in LightningModule -```python -class MyModel(L.LightningModule): - def __init__(self): - super().__init__() - L.seed_everything(42, workers=True) # Same seed everywhere -``` - -### Issue: DeepSpeed Config Errors - -**Solution**: Use Lightning's auto config -```python -strategy = DeepSpeedStrategy( - stage=3, - # Don't specify config file, Lightning generates automatically -) -``` - -## Resources - -- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html -- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html -- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html -- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html diff --git a/skills/mlops/pytorch-lightning/references/hyperparameter-tuning.md b/skills/mlops/pytorch-lightning/references/hyperparameter-tuning.md deleted file mode 100644 index ea57f7116..000000000 --- a/skills/mlops/pytorch-lightning/references/hyperparameter-tuning.md +++ /dev/null @@ -1,556 +0,0 @@ -# Hyperparameter Tuning with PyTorch Lightning - -## Integration with Tuning Frameworks - -Lightning integrates seamlessly with popular hyperparameter tuning libraries. - -### 1. Ray Tune Integration - -**Installation**: -```bash -pip install ray[tune] -pip install lightning -``` - -**Basic Ray Tune example**: - -```python -import lightning as L -from ray import tune -from ray.tune.integration.pytorch_lightning import TuneReportCallback - -class LitModel(L.LightningModule): - def __init__(self, lr, batch_size): - super().__init__() - self.lr = lr - self.batch_size = batch_size - self.model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1)) - - def training_step(self, batch, batch_idx): - loss = self.model(batch).mean() - self.log('train_loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - val_loss = self.model(batch).mean() - self.log('val_loss', val_loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.lr) - -def train_fn(config): - """Training function for Ray Tune.""" - model = LitModel(lr=config["lr"], batch_size=config["batch_size"]) - - # Add callback to report metrics to Tune - trainer = L.Trainer( - max_epochs=10, - callbacks=[TuneReportCallback({"loss": "val_loss"}, on="validation_end")] - ) - - trainer.fit(model, train_loader, val_loader) - -# Define search space -config = { - "lr": tune.loguniform(1e-5, 1e-1), - "batch_size": tune.choice([16, 32, 64, 128]) -} - -# Run hyperparameter search -analysis = tune.run( - train_fn, - config=config, - num_samples=20, # 20 trials - resources_per_trial={"gpu": 1} -) - -# Best hyperparameters -best_config = analysis.get_best_config(metric="loss", mode="min") -print(f"Best config: {best_config}") -``` - -**Advanced: Population-Based Training (PBT)**: - -```python -from ray.tune.schedulers import PopulationBasedTraining - -# PBT scheduler -scheduler = PopulationBasedTraining( - time_attr='training_iteration', - metric='val_loss', - mode='min', - perturbation_interval=5, # Perturb every 5 epochs - hyperparam_mutations={ - "lr": tune.loguniform(1e-5, 1e-1), - "batch_size": [16, 32, 64, 128] - } -) - -analysis = tune.run( - train_fn, - config=config, - num_samples=8, # Population size - scheduler=scheduler, - resources_per_trial={"gpu": 1} -) -``` - -### 2. Optuna Integration - -**Installation**: -```bash -pip install optuna -pip install optuna-integration -``` - -**Optuna example**: - -```python -import optuna -from optuna.integration import PyTorchLightningPruningCallback - -def objective(trial): - # Suggest hyperparameters - lr = trial.suggest_loguniform('lr', 1e-5, 1e-1) - batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128]) - n_layers = trial.suggest_int('n_layers', 1, 3) - hidden_size = trial.suggest_int('hidden_size', 64, 512, step=64) - - # Create model - model = LitModel(lr=lr, n_layers=n_layers, hidden_size=hidden_size) - - # Pruning callback (early stopping for bad trials) - pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss") - - trainer = L.Trainer( - max_epochs=20, - callbacks=[pruning_callback], - enable_progress_bar=False, - logger=False - ) - - trainer.fit(model, train_loader, val_loader) - - return trainer.callback_metrics["val_loss"].item() - -# Create study -study = optuna.create_study( - direction='minimize', - pruner=optuna.pruners.MedianPruner() # Prune bad trials early -) - -# Optimize -study.optimize(objective, n_trials=50, timeout=3600) - -# Best params -print(f"Best trial: {study.best_trial.params}") -print(f"Best value: {study.best_value}") - -# Visualization -optuna.visualization.plot_optimization_history(study).show() -optuna.visualization.plot_param_importances(study).show() -``` - -**Optuna with distributed training**: - -```python -import optuna - -# Shared database for distributed optimization -storage = optuna.storages.RDBStorage( - url='postgresql://user:pass@localhost/optuna' -) - -study = optuna.create_study( - study_name='distributed_study', - storage=storage, - load_if_exists=True, - direction='minimize' -) - -# Run on multiple machines -study.optimize(objective, n_trials=50) -``` - -### 3. Weights & Biases (WandB) Sweeps - -**Installation**: -```bash -pip install wandb -``` - -**WandB sweep config** (`sweep.yaml`): -```yaml -program: train.py -method: bayes -metric: - name: val_loss - goal: minimize -parameters: - lr: - distribution: log_uniform_values - min: 0.00001 - max: 0.1 - batch_size: - values: [16, 32, 64, 128] - optimizer: - values: ['adam', 'sgd', 'adamw'] - dropout: - distribution: uniform - min: 0.0 - max: 0.5 -``` - -**Training script** (`train.py`): -```python -import wandb -import lightning as L -from lightning.pytorch.loggers import WandbLogger - -def train(): - # Initialize wandb - wandb.init() - config = wandb.config - - # Create model with sweep params - model = LitModel( - lr=config.lr, - batch_size=config.batch_size, - optimizer=config.optimizer, - dropout=config.dropout - ) - - # WandB logger - wandb_logger = WandbLogger(project='hyperparameter-sweep') - - trainer = L.Trainer( - max_epochs=20, - logger=wandb_logger - ) - - trainer.fit(model, train_loader, val_loader) - -if __name__ == '__main__': - train() -``` - -**Launch sweep**: -```bash -# Initialize sweep -wandb sweep sweep.yaml -# Output: wandb: Created sweep with ID: abc123 - -# Run agent (can run on multiple machines) -wandb agent your-entity/your-project/abc123 -``` - -### 4. Hyperopt Integration - -**Installation**: -```bash -pip install hyperopt -``` - -**Hyperopt example**: - -```python -from hyperopt import hp, fmin, tpe, Trials - -def objective(params): - model = LitModel( - lr=params['lr'], - batch_size=int(params['batch_size']), - hidden_size=int(params['hidden_size']) - ) - - trainer = L.Trainer( - max_epochs=10, - enable_progress_bar=False, - logger=False - ) - - trainer.fit(model, train_loader, val_loader) - - # Return loss (minimize) - return trainer.callback_metrics["val_loss"].item() - -# Define search space -space = { - 'lr': hp.loguniform('lr', np.log(1e-5), np.log(1e-1)), - 'batch_size': hp.quniform('batch_size', 16, 128, 16), - 'hidden_size': hp.quniform('hidden_size', 64, 512, 64) -} - -# Optimize -trials = Trials() -best = fmin( - fn=objective, - space=space, - algo=tpe.suggest, # Tree-structured Parzen Estimator - max_evals=50, - trials=trials -) - -print(f"Best hyperparameters: {best}") -``` - -## Built-In Lightning Tuning - -### Auto Learning Rate Finder - -```python -class LitModel(L.LightningModule): - def __init__(self, lr=1e-3): - super().__init__() - self.lr = lr - self.model = nn.Linear(10, 1) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.lr) - - def training_step(self, batch, batch_idx): - loss = self.model(batch).mean() - return loss - -# Find optimal learning rate -model = LitModel() -trainer = L.Trainer(auto_lr_find=True) - -# This runs LR finder before training -trainer.tune(model, train_loader) - -# Or manually -from lightning.pytorch.tuner import Tuner -tuner = Tuner(trainer) -lr_finder = tuner.lr_find(model, train_loader) - -# Plot results -fig = lr_finder.plot(suggest=True) -fig.show() - -# Get suggested LR -suggested_lr = lr_finder.suggestion() -print(f"Suggested LR: {suggested_lr}") - -# Update model -model.lr = suggested_lr - -# Train with optimal LR -trainer.fit(model, train_loader) -``` - -### Auto Batch Size Finder - -```python -class LitModel(L.LightningModule): - def __init__(self, batch_size=32): - super().__init__() - self.batch_size = batch_size - self.model = nn.Linear(10, 1) - - def train_dataloader(self): - return DataLoader(dataset, batch_size=self.batch_size) - -model = LitModel() -trainer = L.Trainer(auto_scale_batch_size='binsearch') - -# Find optimal batch size -trainer.tune(model) - -print(f"Optimal batch size: {model.batch_size}") - -# Train with optimal batch size -trainer.fit(model, train_loader) -``` - -## Advanced Tuning Strategies - -### 1. Multi-Fidelity Optimization (Successive Halving) - -```python -from ray.tune.schedulers import ASHAScheduler - -# ASHA: Asynchronous Successive Halving Algorithm -scheduler = ASHAScheduler( - max_t=100, # Max epochs - grace_period=10, # Min epochs before stopping - reduction_factor=2 # Halve resources each round -) - -analysis = tune.run( - train_fn, - config=config, - num_samples=64, - scheduler=scheduler, - resources_per_trial={"gpu": 1} -) -``` - -**How it works**: -- Start 64 trials -- After 10 epochs, stop bottom 50% (32 trials remain) -- After 20 epochs, stop bottom 50% (16 trials remain) -- After 40 epochs, stop bottom 50% (8 trials remain) -- After 80 epochs, stop bottom 50% (4 trials remain) -- Run remaining 4 trials to completion (100 epochs) - -### 2. Bayesian Optimization - -```python -from ray.tune.search.bayesopt import BayesOptSearch - -search = BayesOptSearch( - metric="val_loss", - mode="min" -) - -analysis = tune.run( - train_fn, - config=config, - num_samples=50, - search_alg=search, - resources_per_trial={"gpu": 1} -) -``` - -### 3. Grid Search - -```python -from ray import tune - -# Exhaustive grid search -config = { - "lr": tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]), - "batch_size": tune.grid_search([16, 32, 64, 128]), - "optimizer": tune.grid_search(['adam', 'sgd', 'adamw']) -} - -# Total trials: 4 × 4 × 3 = 48 -analysis = tune.run(train_fn, config=config) -``` - -### 4. Random Search - -```python -config = { - "lr": tune.loguniform(1e-5, 1e-1), - "batch_size": tune.choice([16, 32, 64, 128]), - "dropout": tune.uniform(0.0, 0.5), - "hidden_size": tune.randint(64, 512) -} - -# Random sampling -analysis = tune.run( - train_fn, - config=config, - num_samples=100 # 100 random samples -) -``` - -## Best Practices - -### 1. Start Simple - -```python -# Phase 1: Coarse search (fast) -coarse_config = { - "lr": tune.loguniform(1e-5, 1e-1), - "batch_size": tune.choice([32, 64]) -} -coarse_analysis = tune.run(train_fn, config=coarse_config, num_samples=10, max_epochs=5) - -# Phase 2: Fine-tune around best (slow) -best_lr = coarse_analysis.best_config["lr"] -fine_config = { - "lr": tune.uniform(best_lr * 0.5, best_lr * 2), - "batch_size": tune.choice([16, 32, 64, 128]) -} -fine_analysis = tune.run(train_fn, config=fine_config, num_samples=20, max_epochs=20) -``` - -### 2. Use Checkpointing - -```python -def train_fn(config, checkpoint_dir=None): - model = LitModel(lr=config["lr"]) - - trainer = L.Trainer( - max_epochs=100, - callbacks=[ - TuneReportCheckpointCallback( - metrics={"loss": "val_loss"}, - filename="checkpoint", - on="validation_end" - ) - ] - ) - - # Resume from checkpoint if exists - ckpt_path = None - if checkpoint_dir: - ckpt_path = os.path.join(checkpoint_dir, "checkpoint") - - trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path) -``` - -### 3. Monitor Resource Usage - -```python -import GPUtil - -def train_fn(config): - # Before training - GPUs = GPUtil.getGPUs() - print(f"GPU memory before: {GPUs[0].memoryUsed} MB") - - # Train - model = LitModel(lr=config["lr"], batch_size=config["batch_size"]) - trainer.fit(model, train_loader) - - # After training - GPUs = GPUtil.getGPUs() - print(f"GPU memory after: {GPUs[0].memoryUsed} MB") -``` - -## Common Issues - -### Issue: Trials Running Out of Memory - -**Solution**: Reduce concurrent trials or batch size -```python -analysis = tune.run( - train_fn, - config=config, - resources_per_trial={"gpu": 0.5}, # 2 trials per GPU - max_concurrent_trials=2 # Limit concurrent trials -) -``` - -### Issue: Slow Hyperparameter Search - -**Solution**: Use early stopping scheduler -```python -from ray.tune.schedulers import ASHAScheduler - -scheduler = ASHAScheduler( - max_t=100, - grace_period=5, # Stop bad trials after 5 epochs - reduction_factor=3 -) -``` - -### Issue: Can't Reproduce Best Trial - -**Solution**: Set seeds in training function -```python -def train_fn(config): - L.seed_everything(42, workers=True) - # Rest of training... -``` - -## Resources - -- Ray Tune + Lightning: https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html -- Optuna: https://optuna.readthedocs.io/ -- WandB Sweeps: https://docs.wandb.ai/guides/sweeps -- Lightning Tuner: https://lightning.ai/docs/pytorch/stable/tuning.html diff --git a/skills/mlops/simpo/SKILL.md b/skills/mlops/simpo/SKILL.md deleted file mode 100644 index 0af7b122c..000000000 --- a/skills/mlops/simpo/SKILL.md +++ /dev/null @@ -1,222 +0,0 @@ ---- -name: simpo-training -description: Simple Preference Optimization for LLM alignment. Reference-free alternative to DPO with better performance (+6.4 points on AlpacaEval 2.0). No reference model needed, more efficient than DPO. Use for preference alignment when want simpler, faster training than DPO/PPO. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [torch, transformers, datasets, trl, accelerate] -metadata: - hermes: - tags: [Post-Training, SimPO, Preference Optimization, Alignment, DPO Alternative, Reference-Free, LLM Alignment, Efficient Training] - ---- - -# SimPO - Simple Preference Optimization - -## Quick start - -SimPO is a reference-free preference optimization method that outperforms DPO without needing a reference model. - -**Installation**: -```bash -# Create environment -conda create -n simpo python=3.10 && conda activate simpo - -# Install PyTorch 2.2.2 -# Visit: https://pytorch.org/get-started/locally/ - -# Install alignment-handbook -git clone https://github.com/huggingface/alignment-handbook.git -cd alignment-handbook -python -m pip install . - -# Install Flash Attention 2 -python -m pip install flash-attn --no-build-isolation -``` - -**Training** (Mistral 7B): -```bash -ACCELERATE_LOG_LEVEL=info accelerate launch \ - --config_file accelerate_configs/deepspeed_zero3.yaml \ - scripts/run_simpo.py \ - training_configs/mistral-7b-base-simpo.yaml -``` - -## Common workflows - -### Workflow 1: Train from base model (Mistral 7B) - -**Config** (`mistral-7b-base-simpo.yaml`): -```yaml -# Model -model_name_or_path: mistralai/Mistral-7B-v0.1 -torch_dtype: bfloat16 - -# Dataset -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: - - train_prefs - - test_prefs - -# SimPO hyperparameters -beta: 2.0 # Reward scaling (2.0-10.0) -gamma_beta_ratio: 0.5 # Target margin (0-1) -loss_type: sigmoid # sigmoid or hinge -sft_weight: 0.0 # Optional SFT regularization - -# Training -learning_rate: 5e-7 # Critical: 3e-7 to 1e-6 -num_train_epochs: 1 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 8 - -# Output -output_dir: ./outputs/mistral-7b-simpo -``` - -**Launch training**: -```bash -accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \ - scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml -``` - -### Workflow 2: Fine-tune instruct model (Llama 3 8B) - -**Config** (`llama3-8b-instruct-simpo.yaml`): -```yaml -model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct - -dataset_mixer: - argilla/ultrafeedback-binarized-preferences-cleaned: 1.0 - -beta: 2.5 -gamma_beta_ratio: 0.5 -learning_rate: 5e-7 -sft_weight: 0.1 # Add SFT loss to preserve capabilities - -num_train_epochs: 1 -per_device_train_batch_size: 2 -gradient_accumulation_steps: 4 -output_dir: ./outputs/llama3-8b-simpo -``` - -**Launch**: -```bash -accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \ - scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml -``` - -### Workflow 3: Reasoning-intensive tasks (lower LR) - -**For math/code tasks**: -```yaml -model_name_or_path: deepseek-ai/deepseek-math-7b-base - -dataset_mixer: - argilla/distilabel-math-preference-dpo: 1.0 - -beta: 5.0 # Higher for stronger signal -gamma_beta_ratio: 0.7 # Larger margin -learning_rate: 3e-7 # Lower LR for reasoning -sft_weight: 0.0 - -num_train_epochs: 1 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 16 -``` - -## When to use vs alternatives - -**Use SimPO when**: -- Want simpler training than DPO (no reference model) -- Have preference data (chosen/rejected pairs) -- Need better performance than DPO -- Limited compute resources -- Single-node training sufficient - -**Algorithm selection**: -- **SimPO**: Simplest, best performance, no reference model -- **DPO**: Need reference model baseline, more conservative -- **PPO**: Maximum control, need reward model, complex setup -- **GRPO**: Memory-efficient RL, no critic - -**Use alternatives instead**: -- **OpenRLHF**: Multi-node distributed training, PPO/GRPO -- **TRL**: Need multiple methods in one framework -- **DPO**: Established baseline comparison - -## Common issues - -**Issue: Loss divergence** - -Reduce learning rate: -```yaml -learning_rate: 3e-7 # Reduce from 5e-7 -``` - -Reduce beta: -```yaml -beta: 1.0 # Reduce from 2.0 -``` - -**Issue: Model forgets capabilities** - -Add SFT regularization: -```yaml -sft_weight: 0.1 # Add SFT loss component -``` - -**Issue: Poor preference separation** - -Increase beta and margin: -```yaml -beta: 5.0 # Increase from 2.0 -gamma_beta_ratio: 0.8 # Increase from 0.5 -``` - -**Issue: OOM during training** - -Reduce batch size: -```yaml -per_device_train_batch_size: 1 -gradient_accumulation_steps: 16 # Maintain effective batch -``` - -Enable gradient checkpointing: -```yaml -gradient_checkpointing: true -``` - -## Advanced topics - -**Loss functions**: See [references/loss-functions.md](references/loss-functions.md) for sigmoid vs hinge loss, mathematical formulations, and when to use each. - -**Hyperparameter tuning**: See [references/hyperparameters.md](references/hyperparameters.md) for beta, gamma, learning rate selection guide, and model-size-specific recommendations. - -**Dataset preparation**: See [references/datasets.md](references/datasets.md) for preference data formats, quality filtering, and custom dataset creation. - -## Hardware requirements - -- **GPU**: NVIDIA A100/H100 recommended -- **VRAM**: - - 7B model: 1× A100 40GB (DeepSpeed ZeRO-3) - - 8B model: 2× A100 40GB - - 70B model: 8× A100 80GB -- **Single-node**: DeepSpeed ZeRO-3 sufficient -- **Mixed precision**: BF16 recommended - -**Memory optimization**: -- DeepSpeed ZeRO-3 (default config) -- Gradient checkpointing -- Flash Attention 2 - -## Resources - -- Paper: https://arxiv.org/abs/2405.14734 (NeurIPS 2024) -- GitHub: https://github.com/princeton-nlp/SimPO -- Models: https://huggingface.co/princeton-nlp -- Alignment Handbook: https://github.com/huggingface/alignment-handbook - - - diff --git a/skills/mlops/simpo/references/datasets.md b/skills/mlops/simpo/references/datasets.md deleted file mode 100644 index 449e6cf86..000000000 --- a/skills/mlops/simpo/references/datasets.md +++ /dev/null @@ -1,478 +0,0 @@ -# Datasets - -Complete guide to preference datasets for SimPO training. - -## Dataset Format - -### Required Fields - -Preference datasets must contain: -```json -{ - "prompt": "User question or instruction", - "chosen": "Better/preferred response", - "rejected": "Worse/rejected response" -} -``` - -**Alternative field names** (auto-detected): -- `prompt` → `question`, `instruction`, `input` -- `chosen` → `response_chosen`, `winner`, `preferred` -- `rejected` → `response_rejected`, `loser` - -### Example Entry - -```json -{ - "prompt": "Explain quantum computing in simple terms.", - "chosen": "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously through superposition. This allows quantum computers to process many possibilities at once, making them potentially much faster than classical computers for specific tasks like cryptography and optimization.", - "rejected": "It's like regular computing but quantum." -} -``` - -## Popular Datasets - -### 1. UltraFeedback (Recommended) - -**HuggingFaceH4/ultrafeedback_binarized**: -- **Size**: 60K preference pairs -- **Quality**: High (GPT-4 annotations) -- **Domain**: General instruction following -- **Format**: Clean, ready-to-use - -**Config**: -```yaml -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: - - train_prefs - - test_prefs -``` - -### 2. Argilla UltraFeedback (Cleaned) - -**argilla/ultrafeedback-binarized-preferences-cleaned**: -- **Size**: 50K pairs (filtered) -- **Quality**: Very high (deduped, cleaned) -- **Domain**: General -- **Format**: Clean - -**Config**: -```yaml -dataset_mixer: - argilla/ultrafeedback-binarized-preferences-cleaned: 1.0 -``` - -### 3. Distilabel Math - -**argilla/distilabel-math-preference-dpo**: -- **Size**: 30K pairs -- **Quality**: High (GSM8K, MATH) -- **Domain**: Math reasoning -- **Format**: Math-specific - -**Config**: -```yaml -dataset_mixer: - argilla/distilabel-math-preference-dpo: 1.0 -``` - -### 4. HelpSteer - -**nvidia/HelpSteer**: -- **Size**: 38K samples -- **Quality**: High (human ratings) -- **Domain**: Helpfulness alignment -- **Format**: Multi-attribute ratings - -**Config**: -```yaml -dataset_mixer: - nvidia/HelpSteer: 1.0 -``` - -### 5. Anthropic HH-RLHF - -**Anthropic/hh-rlhf**: -- **Size**: 161K samples -- **Quality**: High (human preferences) -- **Domain**: Harmless + helpful -- **Format**: Conversational - -**Config**: -```yaml -dataset_mixer: - Anthropic/hh-rlhf: 1.0 -``` - -## Dataset Mixing - -### Multiple Datasets - -**Equal mix**: -```yaml -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 0.5 - Anthropic/hh-rlhf: 0.5 -``` - -**Weighted mix**: -```yaml -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 0.7 - argilla/distilabel-math-preference-dpo: 0.2 - nvidia/HelpSteer: 0.1 -``` - -**Domain-specific emphasis**: -```yaml -# 80% general + 20% math -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 0.8 - argilla/distilabel-math-preference-dpo: 0.2 -``` - -## Data Quality - -### Quality Indicators - -**Good preference data**: -- ✅ Clear quality difference between chosen/rejected -- ✅ Diverse prompts -- ✅ Minimal noise/annotation errors -- ✅ Appropriate difficulty level - -**Poor preference data**: -- ❌ Ambiguous preferences -- ❌ Repetitive prompts -- ❌ Annotation noise -- ❌ Too easy/hard prompts - -### Quality Filtering - -**Filter by length difference**: -```python -def filter_by_length(example): - chosen_len = len(example['chosen'].split()) - rejected_len = len(example['rejected'].split()) - # Reject if chosen is much shorter (potential low-effort) - return chosen_len >= rejected_len * 0.5 - -dataset = dataset.filter(filter_by_length) -``` - -**Filter by diversity**: -```python -seen_prompts = set() - -def filter_duplicates(example): - prompt = example['prompt'] - if prompt in seen_prompts: - return False - seen_prompts.add(prompt) - return True - -dataset = dataset.filter(filter_duplicates) -``` - -## Custom Dataset Creation - -### Format 1: JSON Lines - -**File** (`preferences.jsonl`): -```jsonl -{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "It's a snake."} -{"prompt": "Explain AI.", "chosen": "AI refers to systems that can...", "rejected": "It's computers that think."} -``` - -**Load**: -```yaml -dataset_mixer: - json: - data_files: preferences.jsonl -``` - -### Format 2: HuggingFace Dataset - -**Create from dict**: -```python -from datasets import Dataset - -data = { - "prompt": ["What is Python?", "Explain AI."], - "chosen": ["Python is...", "AI refers to..."], - "rejected": ["It's a snake.", "It's computers..."] -} - -dataset = Dataset.from_dict(data) -dataset.push_to_hub("username/my-preferences") -``` - -**Use in config**: -```yaml -dataset_mixer: - username/my-preferences: 1.0 -``` - -### Format 3: ChatML - -**For conversational data**: -```json -{ - "prompt": [ - {"role": "user", "content": "What is quantum computing?"} - ], - "chosen": [ - {"role": "assistant", "content": "Quantum computing uses qubits..."} - ], - "rejected": [ - {"role": "assistant", "content": "It's like regular computing but quantum."} - ] -} -``` - -**Apply chat template**: -```yaml -dataset_text_field: null # Will apply chat template -``` - -## Synthetic Data Generation - -### Using GPT-4 - -**Prompt template**: -``` -Given the following question: -{prompt} - -Generate two responses: -1. A high-quality, detailed response (chosen) -2. A low-quality, brief response (rejected) - -Format as JSON with "chosen" and "rejected" fields. -``` - -**Example code**: -```python -import openai - -def generate_pair(prompt): - response = openai.ChatCompletion.create( - model="gpt-4", - messages=[{ - "role": "user", - "content": f"Given: {prompt}\n\nGenerate chosen/rejected pair in JSON." - }] - ) - return json.loads(response.choices[0].message.content) - -# Generate dataset -prompts = load_prompts() -dataset = [generate_pair(p) for p in prompts] -``` - -### Using Local Model - -**With vLLM**: -```python -from vllm import LLM - -llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct") - -def generate_variations(prompt): - # Generate multiple completions - outputs = llm.generate( - [prompt] * 4, - sampling_params={ - "temperature": 0.8, - "top_p": 0.9, - "max_tokens": 512 - } - ) - - # Select best/worst - chosen = max(outputs, key=lambda x: len(x.outputs[0].text)) - rejected = min(outputs, key=lambda x: len(x.outputs[0].text)) - - return { - "prompt": prompt, - "chosen": chosen.outputs[0].text, - "rejected": rejected.outputs[0].text - } -``` - -## Data Preprocessing - -### Truncation - -**Limit sequence length**: -```yaml -max_prompt_length: 512 -max_completion_length: 512 -max_length: 1024 # Total -``` - -**Implementation**: -```python -def truncate_example(example): - tokenizer.truncation_side = "left" # For prompts - prompt_tokens = tokenizer( - example['prompt'], - max_length=512, - truncation=True - ) - - tokenizer.truncation_side = "right" # For completions - chosen_tokens = tokenizer( - example['chosen'], - max_length=512, - truncation=True - ) - - return { - "prompt": tokenizer.decode(prompt_tokens['input_ids']), - "chosen": tokenizer.decode(chosen_tokens['input_ids']) - } - -dataset = dataset.map(truncate_example) -``` - -### Deduplication - -**Remove exact duplicates**: -```python -dataset = dataset.unique('prompt') -``` - -**Remove near-duplicates** (MinHash): -```python -from datasketch import MinHash, MinHashLSH - -def deduplicate_lsh(dataset, threshold=0.8): - lsh = MinHashLSH(threshold=threshold, num_perm=128) - seen = [] - - for i, example in enumerate(dataset): - m = MinHash(num_perm=128) - for word in example['prompt'].split(): - m.update(word.encode('utf8')) - - if not lsh.query(m): - lsh.insert(i, m) - seen.append(example) - - return Dataset.from_list(seen) - -dataset = deduplicate_lsh(dataset) -``` - -## Data Augmentation - -### Paraphrasing Prompts - -```python -def paraphrase_prompt(example): - # Use paraphrasing model - paraphrased = paraphrase_model(example['prompt']) - - return [ - example, # Original - { - "prompt": paraphrased, - "chosen": example['chosen'], - "rejected": example['rejected'] - } - ] - -dataset = dataset.map(paraphrase_prompt, batched=False, remove_columns=[]) -``` - -### Difficulty Balancing - -**Mix easy/medium/hard**: -```python -def categorize_difficulty(example): - prompt_len = len(example['prompt'].split()) - if prompt_len < 20: - return "easy" - elif prompt_len < 50: - return "medium" - else: - return "hard" - -dataset = dataset.map(lambda x: {"difficulty": categorize_difficulty(x)}) - -# Sample balanced dataset -easy = dataset.filter(lambda x: x['difficulty'] == 'easy').shuffle().select(range(1000)) -medium = dataset.filter(lambda x: x['difficulty'] == 'medium').shuffle().select(range(1000)) -hard = dataset.filter(lambda x: x['difficulty'] == 'hard').shuffle().select(range(1000)) - -balanced = concatenate_datasets([easy, medium, hard]).shuffle() -``` - -## Dataset Statistics - -### Compute Stats - -```python -def compute_stats(dataset): - prompt_lens = [len(x['prompt'].split()) for x in dataset] - chosen_lens = [len(x['chosen'].split()) for x in dataset] - rejected_lens = [len(x['rejected'].split()) for x in dataset] - - print(f"Dataset size: {len(dataset)}") - print(f"Avg prompt length: {np.mean(prompt_lens):.1f} words") - print(f"Avg chosen length: {np.mean(chosen_lens):.1f} words") - print(f"Avg rejected length: {np.mean(rejected_lens):.1f} words") - print(f"Chosen > Rejected: {sum(c > r for c, r in zip(chosen_lens, rejected_lens)) / len(dataset):.1%}") - -compute_stats(dataset) -``` - -**Expected output**: -``` -Dataset size: 50000 -Avg prompt length: 45.2 words -Avg chosen length: 180.5 words -Avg rejected length: 120.3 words -Chosen > Rejected: 85.2% -``` - -## Best Practices - -### 1. Data Quality Over Quantity - -- **Prefer**: 10K high-quality pairs -- **Over**: 100K noisy pairs - -### 2. Clear Preference Signals - -- Chosen should be noticeably better -- Avoid marginal differences -- Remove ambiguous pairs - -### 3. Domain Matching - -- Match dataset domain to target use case -- Mix datasets for broader coverage -- Include safety-filtered data - -### 4. Validate Before Training - -```python -# Sample 10 random examples -samples = dataset.shuffle().select(range(10)) - -for ex in samples: - print(f"Prompt: {ex['prompt']}") - print(f"Chosen: {ex['chosen'][:100]}...") - print(f"Rejected: {ex['rejected'][:100]}...") - print(f"Preference clear: {'✓' if len(ex['chosen']) > len(ex['rejected']) else '?'}") - print() -``` - -## References - -- HuggingFace Datasets: https://huggingface.co/datasets -- Alignment Handbook: https://github.com/huggingface/alignment-handbook -- UltraFeedback: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized diff --git a/skills/mlops/simpo/references/hyperparameters.md b/skills/mlops/simpo/references/hyperparameters.md deleted file mode 100644 index f55c31f86..000000000 --- a/skills/mlops/simpo/references/hyperparameters.md +++ /dev/null @@ -1,452 +0,0 @@ -# Hyperparameters - -Complete guide to SimPO hyperparameter selection and tuning. - -## Overview - -Key hyperparameters in SimPO: -1. **Learning Rate** - Most critical -2. **Beta (β)** - Reward scaling -3. **Gamma-Beta Ratio (γ/β)** - Target margin -4. **SFT Weight** - Regularization strength - -## Learning Rate - -### Recommended Ranges - -**By model size**: -| Model Size | Learning Rate | Notes | -|------------|---------------|-------| -| 1B-3B | 5e-7 to 1e-6 | Higher end safe | -| 7B-8B | 3e-7 to 5e-7 | **Standard** | -| 13B-30B | 1e-7 to 3e-7 | Lower for stability | -| 70B+ | 5e-8 to 1e-7 | Very conservative | - -**By task type**: -| Task | Learning Rate | Reason | -|------|---------------|--------| -| General chat | 5e-7 | Standard | -| Code generation | 3e-7 | **Precise reasoning** | -| Math reasoning | 3e-7 | **Careful optimization** | -| Creative writing | 1e-6 | More aggressive OK | - -### Why Learning Rate Matters - -**Too high** (> 1e-6 for 7B): -- Loss divergence -- Catastrophic forgetting -- Unstable training - -**Too low** (< 1e-7 for 7B): -- Very slow convergence -- May not finish in time -- Undertraining - -**Optimal** (3e-7 to 5e-7 for 7B): -- Stable convergence -- Good final performance -- Efficient training - -### Config Examples - -**Mistral 7B (general)**: -```yaml -learning_rate: 5e-7 -num_train_epochs: 1 -warmup_ratio: 0.1 -lr_scheduler_type: cosine -``` - -**Llama 3 8B (reasoning)**: -```yaml -learning_rate: 3e-7 -num_train_epochs: 1 -warmup_ratio: 0.1 -lr_scheduler_type: cosine -``` - -**Gemma 2 9B (creative)**: -```yaml -learning_rate: 1e-6 -num_train_epochs: 1 -warmup_ratio: 0.1 -lr_scheduler_type: linear -``` - -## Beta (β) - -### Recommended Values - -**Range**: 2.0 to 10.0 (much higher than DPO's 0.01-0.1) - -**By preference strength**: -| Beta | Preference Strength | Use Case | -|------|-------------------|----------| -| 1.0-2.0 | Weak | Subtle preferences | -| 2.0-5.0 | **Standard** | General alignment | -| 5.0-10.0 | Strong | Clear preferences | - -**Default**: 2.0 to 2.5 - -### Why Beta Matters - -**Low beta** (< 2.0): -- Weak reward signal -- Slow preference learning -- May underfit - -**High beta** (> 10.0): -- Very strong reward signal -- Risk of overfitting -- May ignore weak preferences - -**Optimal** (2.0-5.0): -- Balanced reward scaling -- Stable training -- Good generalization - -### Interaction with Gamma - -**Beta and gamma together**: -``` -Target margin in reward space = gamma -Target margin in logit space = gamma / beta -``` - -**Example**: -```yaml -beta: 2.0 -gamma_beta_ratio: 0.5 -# Effective gamma = 2.0 * 0.5 = 1.0 -``` - -### Config Examples - -**Weak preferences**: -```yaml -beta: 2.0 -gamma_beta_ratio: 0.3 # Small margin -``` - -**Standard**: -```yaml -beta: 2.5 -gamma_beta_ratio: 0.5 # Default -``` - -**Strong preferences**: -```yaml -beta: 5.0 -gamma_beta_ratio: 0.7 # Larger margin -``` - -## Gamma-Beta Ratio (γ/β) - -### Recommended Values - -**Range**: 0.0 to 1.0 - -**By scenario**: -| Ratio | Margin | Use Case | -|-------|--------|----------| -| 0.0-0.3 | Small | Weak preference data | -| 0.4-0.6 | **Standard** | General use | -| 0.7-1.0 | Large | Very clear preferences | - -**Default**: 0.5 - -### Why Gamma Matters - -**Low gamma** (< 0.3): -- Small target margin -- Less aggressive alignment -- More conservative - -**High gamma** (> 0.7): -- Large target margin -- Stronger alignment -- More aggressive - -**Optimal** (0.4-0.6): -- Balanced margin -- Stable training -- Good alignment - -### Mathematical Meaning - -**In loss function**: -```python -logits = pi_logratios - gamma_beta_ratio -loss = -log(sigmoid(beta * logits)) -``` - -**Interpretation**: -- gamma_beta_ratio shifts the decision boundary -- Higher ratio = requires larger log prob difference -- Controls how "clear" preferences must be - -### Config Examples - -**Noisy preferences**: -```yaml -gamma_beta_ratio: 0.3 # Smaller margin, more tolerant -``` - -**Standard**: -```yaml -gamma_beta_ratio: 0.5 # Default -``` - -**High-quality preferences**: -```yaml -gamma_beta_ratio: 0.8 # Larger margin, stricter -``` - -## SFT Weight - -### Recommended Values - -**Range**: 0.0 to 1.0 - -**By model type**: -| Model Type | SFT Weight | Reason | -|------------|-----------|--------| -| Base model | 0.0 | No prior capabilities | -| **Instruct model** | 0.05-0.1 | Preserve instruction following | -| Chat model | 0.1-0.2 | Preserve conversational skills | - -**Default**: 0.0 (no SFT regularization) - -### Why SFT Weight Matters - -**Zero SFT** (0.0): -- Pure preference optimization -- May forget capabilities -- Standard for base models - -**Low SFT** (0.05-0.1): -- Balanced approach -- **Recommended for instruct models** -- Slight capability preservation - -**High SFT** (> 0.2): -- Strong capability preservation -- Weaker preference alignment -- May reduce alignment gains - -### Trade-off - -``` -Total Loss = SimPO Loss + (sft_weight * SFT Loss) -``` - -**Example**: -```yaml -sft_weight: 0.1 -# 90% preference optimization + 10% capability preservation -``` - -### Config Examples - -**Base model (no SFT)**: -```yaml -model_name_or_path: mistralai/Mistral-7B-v0.1 -sft_weight: 0.0 -``` - -**Instruct model (light SFT)**: -```yaml -model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct -sft_weight: 0.1 -``` - -**Chat model (moderate SFT)**: -```yaml -model_name_or_path: HuggingFaceH4/zephyr-7b-beta -sft_weight: 0.2 -``` - -## Model-Size-Specific Recommendations - -### 7B Models (Mistral, Llama 3) - -**Standard config**: -```yaml -learning_rate: 5e-7 -beta: 2.0 -gamma_beta_ratio: 0.5 -sft_weight: 0.0 # 0.1 if instruct model -num_train_epochs: 1 -per_device_train_batch_size: 2 -gradient_accumulation_steps: 4 -``` - -### 8B-13B Models - -**Standard config**: -```yaml -learning_rate: 3e-7 -beta: 2.5 -gamma_beta_ratio: 0.5 -sft_weight: 0.1 # If instruct -num_train_epochs: 1 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 8 -``` - -### 70B Models - -**Standard config**: -```yaml -learning_rate: 1e-7 -beta: 2.0 -gamma_beta_ratio: 0.5 -sft_weight: 0.05 -num_train_epochs: 1 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 16 -``` - -## Batch Size & Gradient Accumulation - -### Effective Batch Size - -``` -Effective Batch Size = per_device_batch_size * num_gpus * grad_accum_steps -``` - -**Recommended effective batch sizes**: -- 7B: 128-256 -- 13B: 64-128 -- 70B: 32-64 - -### Config Examples - -**Single GPU (A100 40GB)**: -```yaml -per_device_train_batch_size: 1 -gradient_accumulation_steps: 128 # Effective batch = 128 -``` - -**4 GPUs (A100 40GB)**: -```yaml -per_device_train_batch_size: 2 -gradient_accumulation_steps: 16 # Effective batch = 2*4*16 = 128 -``` - -**8 GPUs (A100 80GB)**: -```yaml -per_device_train_batch_size: 2 -gradient_accumulation_steps: 8 # Effective batch = 2*8*8 = 128 -``` - -## Loss Type - -### Sigmoid vs Hinge - -**Sigmoid** (default, recommended): -```yaml -loss_type: sigmoid -label_smoothing: 0.0 -``` - -**Hinge** (experimental): -```yaml -loss_type: hinge -# No label smoothing for hinge -``` - -**When to use hinge**: -- Margin-based tasks -- SVM-style optimization -- Experimental purposes - -**Generally**: Stick with sigmoid - -## Tuning Guide - -### Step 1: Start with Defaults - -```yaml -learning_rate: 5e-7 # For 7B -beta: 2.0 -gamma_beta_ratio: 0.5 -sft_weight: 0.0 # 0.1 if instruct -loss_type: sigmoid -``` - -### Step 2: Monitor Training - -**Check every 100 steps**: -- Loss curve (should decrease smoothly) -- Reward margin (should increase) -- Chosen/rejected logps (should separate) - -### Step 3: Adjust if Needed - -**If loss diverges**: -```yaml -learning_rate: 3e-7 # Reduce from 5e-7 -beta: 1.0 # Reduce from 2.0 -``` - -**If loss plateaus early**: -```yaml -learning_rate: 1e-6 # Increase from 5e-7 -beta: 5.0 # Increase from 2.0 -``` - -**If model forgets**: -```yaml -sft_weight: 0.2 # Increase from 0.0 -``` - -## Complete Example Configs - -### Mistral 7B Base (Standard) - -```yaml -model_name_or_path: mistralai/Mistral-7B-v0.1 -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 - -learning_rate: 5e-7 -beta: 2.0 -gamma_beta_ratio: 0.5 -loss_type: sigmoid -sft_weight: 0.0 - -num_train_epochs: 1 -per_device_train_batch_size: 2 -gradient_accumulation_steps: 4 -warmup_ratio: 0.1 -lr_scheduler_type: cosine - -bf16: true -gradient_checkpointing: true -``` - -### Llama 3 8B Instruct (Reasoning) - -```yaml -model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct -dataset_mixer: - argilla/distilabel-math-preference-dpo: 1.0 - -learning_rate: 3e-7 -beta: 5.0 -gamma_beta_ratio: 0.7 -loss_type: sigmoid -sft_weight: 0.1 - -num_train_epochs: 1 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 16 -warmup_ratio: 0.1 -lr_scheduler_type: cosine -``` - -## References - -- SimPO paper: https://arxiv.org/abs/2405.14734 -- Alignment Handbook: https://github.com/huggingface/alignment-handbook diff --git a/skills/mlops/simpo/references/loss-functions.md b/skills/mlops/simpo/references/loss-functions.md deleted file mode 100644 index 3aba0dc5d..000000000 --- a/skills/mlops/simpo/references/loss-functions.md +++ /dev/null @@ -1,350 +0,0 @@ -# Loss Functions - -Complete guide to SimPO loss functions and mathematical formulations. - -## Overview - -SimPO supports two loss types: -- **Sigmoid** (default) - Smooth, differentiable loss -- **Hinge** - Margin-based, sparse loss - -Both are reference-free (no reference model needed). - -## SimPO Loss Formula - -### Core Calculation - -**Step 1: Log probability ratio**: -``` -pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x) -``` - -**Step 2: Apply target margin**: -``` -logits = pi_logratios - γ/β -``` -Where: -- γ/β = `gamma_beta_ratio` (target margin) - -**Step 3: Compute loss** (depends on loss type) - -### Sigmoid Loss (Default) - -**Formula**: -``` -L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε -``` - -Where: -- β = `beta` (reward scaling) -- σ = sigmoid function -- ε = `label_smoothing` (default 0.0) - -**Implementation**: -```python -losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing -) -``` - -**Characteristics**: -- Smooth, continuous gradients -- Probabilistic interpretation -- Standard choice for most tasks -- Works well with higher beta values - -### Hinge Loss - -**Formula**: -``` -L = max(0, 1 - β * logits) -``` - -**Implementation**: -```python -losses = torch.relu(1 - self.beta * logits) -``` - -**Characteristics**: -- Non-smooth (has kink at logits = 1/β) -- Margin-based (SVM-style) -- Can lead to sparser solutions -- Less commonly used - -## Comparison to DPO - -### DPO Loss (Reference Model Required) - -**Formula**: -``` -L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))] -``` - -**Key features**: -- Requires reference model π_ref -- Normalizes by reference log probabilities -- More conservative (stays close to reference) - -### SimPO Loss (Reference-Free) - -**Formula**: -``` -L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β)) -``` - -**Key features**: -- No reference model needed -- Direct preference optimization -- Target margin γ/β controls preference strength -- More efficient (fewer model forward passes) - -**Visual comparison**: -``` -DPO: [Policy] - [Reference] → Loss -SimPO: [Policy] → Loss -``` - -## Average Log Probability Reward - -### Calculation - -**Per-token log probabilities**: -```python -# Get log probs for each token -per_token_logps = log_softmax(logits).gather(dim=-1, index=labels) - -# Create mask to ignore padding -loss_mask = (labels != label_pad_token_id) -``` - -**Average log probability** (if `average_log_prob=True`): -```python -avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) -``` - -**Sum log probability** (if `average_log_prob=False`): -```python -sum_logp = (per_token_logps * loss_mask).sum(-1) -``` - -**Why average?** -- Normalizes for sequence length -- Prevents bias toward shorter/longer responses -- Standard practice in SimPO - -### Reward Metrics - -**Chosen reward**: -```python -chosen_rewards = beta * policy_chosen_logps.detach() -``` - -**Rejected reward**: -```python -rejected_rewards = beta * policy_rejected_logps.detach() -``` - -**Reward margin**: -```python -reward_margin = chosen_rewards.mean() - rejected_rewards.mean() -``` - -## Label Smoothing - -### Formula with Smoothing - -**Sigmoid loss**: -``` -L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε -``` - -**Effect**: -- ε = 0.0: No smoothing (default) -- ε = 0.1: 10% smoothing (soft labels) -- ε = 0.5: Maximum smoothing - -**When to use**: -- Noisy preference labels -- Uncertain preferences -- Prevent overconfidence - -**Config**: -```yaml -label_smoothing: 0.1 # 10% smoothing -``` - -## SFT Regularization - -### Combined Loss - -**With SFT component**: -``` -L_total = L_SimPO + λ * L_SFT -``` - -Where: -- L_SFT = cross-entropy loss on chosen responses -- λ = `sft_weight` (0.0 to 1.0) - -**Implementation**: -```python -if self.sft_weight > 0: - sft_loss = -policy_chosen_logps - total_loss = simpo_loss + self.sft_weight * sft_loss -``` - -**When to use**: -- Preserve model capabilities -- Prevent catastrophic forgetting -- Fine-tuning instruct models - -**Trade-off**: -- Higher sft_weight: Preserve capabilities, less alignment -- Lower sft_weight: Stronger alignment, may forget capabilities - -**Config**: -```yaml -sft_weight: 0.1 # 10% SFT regularization -``` - -## Loss Type Selection - -### Sigmoid vs Hinge - -| Aspect | Sigmoid | Hinge | -|--------|---------|-------| -| Smoothness | Smooth | Non-smooth | -| Gradients | Continuous | Discontinuous at margin | -| Sparsity | Dense solutions | Sparse solutions | -| Interpretability | Probabilistic | Geometric margin | -| Use case | **General purpose** | Margin-based tasks | -| Recommendation | **Default choice** | Experimental | - -**Config**: -```yaml -# Sigmoid (default) -loss_type: sigmoid - -# Hinge (alternative) -loss_type: hinge -``` - -## Mathematical Properties - -### Gradient Analysis - -**Sigmoid loss gradient**: -``` -∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε -``` - -**Hinge loss gradient**: -``` -∂L/∂logits = -β if logits < 1/β - 0 otherwise -``` - -**Implications**: -- Sigmoid: Always provides gradient signal -- Hinge: No gradient when margin satisfied - -### Convergence Behavior - -**Sigmoid**: -- Asymptotically approaches zero loss -- Continues optimizing even with large margins -- Smoother training curves - -**Hinge**: -- Reaches zero loss at margin -- Stops optimizing once margin satisfied -- May have training plateaus - -## Complete Loss Examples - -### Example 1: Basic SimPO (Sigmoid) - -**Config**: -```yaml -beta: 2.0 -gamma_beta_ratio: 0.5 -loss_type: sigmoid -label_smoothing: 0.0 -sft_weight: 0.0 -``` - -**Loss calculation**: -```python -# Step 1: Compute log probs -chosen_logps = avg_log_prob(policy(chosen)) # e.g., -1.2 -rejected_logps = avg_log_prob(policy(rejected)) # e.g., -2.5 - -# Step 2: Log ratio and margin -pi_logratios = -1.2 - (-2.5) = 1.3 -logits = 1.3 - 0.5 = 0.8 - -# Step 3: Sigmoid loss -loss = -log(sigmoid(2.0 * 0.8)) - = -log(sigmoid(1.6)) - = -log(0.832) - = 0.184 -``` - -### Example 2: SimPO with SFT - -**Config**: -```yaml -beta: 2.5 -gamma_beta_ratio: 0.5 -loss_type: sigmoid -sft_weight: 0.1 -``` - -**Loss calculation**: -```python -# SimPO loss (as above) -simpo_loss = 0.184 - -# SFT loss -sft_loss = -chosen_logps = -(-1.2) = 1.2 - -# Total loss -total_loss = simpo_loss + 0.1 * sft_loss - = 0.184 + 0.12 - = 0.304 -``` - -## Debugging - -### Check Reward Margins - -**Low margin (< 0.5)**: -- Preferences not being learned -- Increase beta or gamma_beta_ratio - -**High margin (> 5.0)**: -- May be overfitting -- Reduce beta or learning rate - -**Monitor**: -```python -reward_margin = chosen_rewards.mean() - rejected_rewards.mean() -print(f"Reward margin: {reward_margin:.2f}") -``` - -### Check Log Probabilities - -**Typical values**: -- Chosen: -1.0 to -2.0 (higher is better) -- Rejected: -2.0 to -4.0 (lower is worse) - -**Warning signs**: -- Both very negative (< -10): Model not learning -- Both very positive (> 0): Numerical instability - -## References - -- SimPO paper: https://arxiv.org/abs/2405.14734 -- DPO paper: https://arxiv.org/abs/2305.18290 -- Implementation: https://github.com/princeton-nlp/SimPO diff --git a/skills/mlops/stable-diffusion/SKILL.md b/skills/mlops/stable-diffusion/SKILL.md deleted file mode 100644 index d3932061b..000000000 --- a/skills/mlops/stable-diffusion/SKILL.md +++ /dev/null @@ -1,522 +0,0 @@ ---- -name: stable-diffusion-image-generation -description: State-of-the-art text-to-image generation with Stable Diffusion models via HuggingFace Diffusers. Use when generating images from text prompts, performing image-to-image translation, inpainting, or building custom diffusion pipelines. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [diffusers>=0.30.0, transformers>=4.41.0, accelerate>=0.31.0, torch>=2.0.0] -metadata: - hermes: - tags: [Image Generation, Stable Diffusion, Diffusers, Text-to-Image, Multimodal, Computer Vision] - ---- - -# Stable Diffusion Image Generation - -Comprehensive guide to generating images with Stable Diffusion using the HuggingFace Diffusers library. - -## When to use Stable Diffusion - -**Use Stable Diffusion when:** -- Generating images from text descriptions -- Performing image-to-image translation (style transfer, enhancement) -- Inpainting (filling in masked regions) -- Outpainting (extending images beyond boundaries) -- Creating variations of existing images -- Building custom image generation workflows - -**Key features:** -- **Text-to-Image**: Generate images from natural language prompts -- **Image-to-Image**: Transform existing images with text guidance -- **Inpainting**: Fill masked regions with context-aware content -- **ControlNet**: Add spatial conditioning (edges, poses, depth) -- **LoRA Support**: Efficient fine-tuning and style adaptation -- **Multiple Models**: SD 1.5, SDXL, SD 3.0, Flux support - -**Use alternatives instead:** -- **DALL-E 3**: For API-based generation without GPU -- **Midjourney**: For artistic, stylized outputs -- **Imagen**: For Google Cloud integration -- **Leonardo.ai**: For web-based creative workflows - -## Quick start - -### Installation - -```bash -pip install diffusers transformers accelerate torch -pip install xformers # Optional: memory-efficient attention -``` - -### Basic text-to-image - -```python -from diffusers import DiffusionPipeline -import torch - -# Load pipeline (auto-detects model type) -pipe = DiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -) -pipe.to("cuda") - -# Generate image -image = pipe( - "A serene mountain landscape at sunset, highly detailed", - num_inference_steps=50, - guidance_scale=7.5 -).images[0] - -image.save("output.png") -``` - -### Using SDXL (higher quality) - -```python -from diffusers import AutoPipelineForText2Image -import torch - -pipe = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - torch_dtype=torch.float16, - variant="fp16" -) -pipe.to("cuda") - -# Enable memory optimization -pipe.enable_model_cpu_offload() - -image = pipe( - prompt="A futuristic city with flying cars, cinematic lighting", - height=1024, - width=1024, - num_inference_steps=30 -).images[0] -``` - -## Architecture overview - -### Three-pillar design - -Diffusers is built around three core components: - -``` -Pipeline (orchestration) -├── Model (neural networks) -│ ├── UNet / Transformer (noise prediction) -│ ├── VAE (latent encoding/decoding) -│ └── Text Encoder (CLIP/T5) -└── Scheduler (denoising algorithm) -``` - -### Pipeline inference flow - -``` -Text Prompt → Text Encoder → Text Embeddings - ↓ -Random Noise → [Denoising Loop] ← Scheduler - ↓ - Predicted Noise - ↓ - VAE Decoder → Final Image -``` - -## Core concepts - -### Pipelines - -Pipelines orchestrate complete workflows: - -| Pipeline | Purpose | -|----------|---------| -| `StableDiffusionPipeline` | Text-to-image (SD 1.x/2.x) | -| `StableDiffusionXLPipeline` | Text-to-image (SDXL) | -| `StableDiffusion3Pipeline` | Text-to-image (SD 3.0) | -| `FluxPipeline` | Text-to-image (Flux models) | -| `StableDiffusionImg2ImgPipeline` | Image-to-image | -| `StableDiffusionInpaintPipeline` | Inpainting | - -### Schedulers - -Schedulers control the denoising process: - -| Scheduler | Steps | Quality | Use Case | -|-----------|-------|---------|----------| -| `EulerDiscreteScheduler` | 20-50 | Good | Default choice | -| `EulerAncestralDiscreteScheduler` | 20-50 | Good | More variation | -| `DPMSolverMultistepScheduler` | 15-25 | Excellent | Fast, high quality | -| `DDIMScheduler` | 50-100 | Good | Deterministic | -| `LCMScheduler` | 4-8 | Good | Very fast | -| `UniPCMultistepScheduler` | 15-25 | Excellent | Fast convergence | - -### Swapping schedulers - -```python -from diffusers import DPMSolverMultistepScheduler - -# Swap for faster generation -pipe.scheduler = DPMSolverMultistepScheduler.from_config( - pipe.scheduler.config -) - -# Now generate with fewer steps -image = pipe(prompt, num_inference_steps=20).images[0] -``` - -## Generation parameters - -### Key parameters - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `prompt` | Required | Text description of desired image | -| `negative_prompt` | None | What to avoid in the image | -| `num_inference_steps` | 50 | Denoising steps (more = better quality) | -| `guidance_scale` | 7.5 | Prompt adherence (7-12 typical) | -| `height`, `width` | 512/1024 | Output dimensions (multiples of 8) | -| `generator` | None | Torch generator for reproducibility | -| `num_images_per_prompt` | 1 | Batch size | - -### Reproducible generation - -```python -import torch - -generator = torch.Generator(device="cuda").manual_seed(42) - -image = pipe( - prompt="A cat wearing a top hat", - generator=generator, - num_inference_steps=50 -).images[0] -``` - -### Negative prompts - -```python -image = pipe( - prompt="Professional photo of a dog in a garden", - negative_prompt="blurry, low quality, distorted, ugly, bad anatomy", - guidance_scale=7.5 -).images[0] -``` - -## Image-to-image - -Transform existing images with text guidance: - -```python -from diffusers import AutoPipelineForImage2Image -from PIL import Image - -pipe = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -).to("cuda") - -init_image = Image.open("input.jpg").resize((512, 512)) - -image = pipe( - prompt="A watercolor painting of the scene", - image=init_image, - strength=0.75, # How much to transform (0-1) - num_inference_steps=50 -).images[0] -``` - -## Inpainting - -Fill masked regions: - -```python -from diffusers import AutoPipelineForInpainting -from PIL import Image - -pipe = AutoPipelineForInpainting.from_pretrained( - "runwayml/stable-diffusion-inpainting", - torch_dtype=torch.float16 -).to("cuda") - -image = Image.open("photo.jpg") -mask = Image.open("mask.png") # White = inpaint region - -result = pipe( - prompt="A red car parked on the street", - image=image, - mask_image=mask, - num_inference_steps=50 -).images[0] -``` - -## ControlNet - -Add spatial conditioning for precise control: - -```python -from diffusers import StableDiffusionControlNetPipeline, ControlNetModel -import torch - -# Load ControlNet for edge conditioning -controlnet = ControlNetModel.from_pretrained( - "lllyasviel/control_v11p_sd15_canny", - torch_dtype=torch.float16 -) - -pipe = StableDiffusionControlNetPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - controlnet=controlnet, - torch_dtype=torch.float16 -).to("cuda") - -# Use Canny edge image as control -control_image = get_canny_image(input_image) - -image = pipe( - prompt="A beautiful house in the style of Van Gogh", - image=control_image, - num_inference_steps=30 -).images[0] -``` - -### Available ControlNets - -| ControlNet | Input Type | Use Case | -|------------|------------|----------| -| `canny` | Edge maps | Preserve structure | -| `openpose` | Pose skeletons | Human poses | -| `depth` | Depth maps | 3D-aware generation | -| `normal` | Normal maps | Surface details | -| `mlsd` | Line segments | Architectural lines | -| `scribble` | Rough sketches | Sketch-to-image | - -## LoRA adapters - -Load fine-tuned style adapters: - -```python -from diffusers import DiffusionPipeline - -pipe = DiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -).to("cuda") - -# Load LoRA weights -pipe.load_lora_weights("path/to/lora", weight_name="style.safetensors") - -# Generate with LoRA style -image = pipe("A portrait in the trained style").images[0] - -# Adjust LoRA strength -pipe.fuse_lora(lora_scale=0.8) - -# Unload LoRA -pipe.unload_lora_weights() -``` - -### Multiple LoRAs - -```python -# Load multiple LoRAs -pipe.load_lora_weights("lora1", adapter_name="style") -pipe.load_lora_weights("lora2", adapter_name="character") - -# Set weights for each -pipe.set_adapters(["style", "character"], adapter_weights=[0.7, 0.5]) - -image = pipe("A portrait").images[0] -``` - -## Memory optimization - -### Enable CPU offloading - -```python -# Model CPU offload - moves models to CPU when not in use -pipe.enable_model_cpu_offload() - -# Sequential CPU offload - more aggressive, slower -pipe.enable_sequential_cpu_offload() -``` - -### Attention slicing - -```python -# Reduce memory by computing attention in chunks -pipe.enable_attention_slicing() - -# Or specific chunk size -pipe.enable_attention_slicing("max") -``` - -### xFormers memory-efficient attention - -```python -# Requires xformers package -pipe.enable_xformers_memory_efficient_attention() -``` - -### VAE slicing for large images - -```python -# Decode latents in tiles for large images -pipe.enable_vae_slicing() -pipe.enable_vae_tiling() -``` - -## Model variants - -### Loading different precisions - -```python -# FP16 (recommended for GPU) -pipe = DiffusionPipeline.from_pretrained( - "model-id", - torch_dtype=torch.float16, - variant="fp16" -) - -# BF16 (better precision, requires Ampere+ GPU) -pipe = DiffusionPipeline.from_pretrained( - "model-id", - torch_dtype=torch.bfloat16 -) -``` - -### Loading specific components - -```python -from diffusers import UNet2DConditionModel, AutoencoderKL - -# Load custom VAE -vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") - -# Use with pipeline -pipe = DiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - vae=vae, - torch_dtype=torch.float16 -) -``` - -## Batch generation - -Generate multiple images efficiently: - -```python -# Multiple prompts -prompts = [ - "A cat playing piano", - "A dog reading a book", - "A bird painting a picture" -] - -images = pipe(prompts, num_inference_steps=30).images - -# Multiple images per prompt -images = pipe( - "A beautiful sunset", - num_images_per_prompt=4, - num_inference_steps=30 -).images -``` - -## Common workflows - -### Workflow 1: High-quality generation - -```python -from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler -import torch - -# 1. Load SDXL with optimizations -pipe = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - torch_dtype=torch.float16, - variant="fp16" -) -pipe.to("cuda") -pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) -pipe.enable_model_cpu_offload() - -# 2. Generate with quality settings -image = pipe( - prompt="A majestic lion in the savanna, golden hour lighting, 8k, detailed fur", - negative_prompt="blurry, low quality, cartoon, anime, sketch", - num_inference_steps=30, - guidance_scale=7.5, - height=1024, - width=1024 -).images[0] -``` - -### Workflow 2: Fast prototyping - -```python -from diffusers import AutoPipelineForText2Image, LCMScheduler -import torch - -# Use LCM for 4-8 step generation -pipe = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - torch_dtype=torch.float16 -).to("cuda") - -# Load LCM LoRA for fast generation -pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") -pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) -pipe.fuse_lora() - -# Generate in ~1 second -image = pipe( - "A beautiful landscape", - num_inference_steps=4, - guidance_scale=1.0 -).images[0] -``` - -## Common issues - -**CUDA out of memory:** -```python -# Enable memory optimizations -pipe.enable_model_cpu_offload() -pipe.enable_attention_slicing() -pipe.enable_vae_slicing() - -# Or use lower precision -pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) -``` - -**Black/noise images:** -```python -# Check VAE configuration -# Use safety checker bypass if needed -pipe.safety_checker = None - -# Ensure proper dtype consistency -pipe = pipe.to(dtype=torch.float16) -``` - -**Slow generation:** -```python -# Use faster scheduler -from diffusers import DPMSolverMultistepScheduler -pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) - -# Reduce steps -image = pipe(prompt, num_inference_steps=20).images[0] -``` - -## References - -- **[Advanced Usage](references/advanced-usage.md)** - Custom pipelines, fine-tuning, deployment -- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions - -## Resources - -- **Documentation**: https://huggingface.co/docs/diffusers -- **Repository**: https://github.com/huggingface/diffusers -- **Model Hub**: https://huggingface.co/models?library=diffusers -- **Discord**: https://discord.gg/diffusers diff --git a/skills/mlops/stable-diffusion/references/advanced-usage.md b/skills/mlops/stable-diffusion/references/advanced-usage.md deleted file mode 100644 index 2384715f9..000000000 --- a/skills/mlops/stable-diffusion/references/advanced-usage.md +++ /dev/null @@ -1,716 +0,0 @@ -# Stable Diffusion Advanced Usage Guide - -## Custom Pipelines - -### Building from components - -```python -from diffusers import ( - UNet2DConditionModel, - AutoencoderKL, - DDPMScheduler, - StableDiffusionPipeline -) -from transformers import CLIPTextModel, CLIPTokenizer -import torch - -# Load components individually -unet = UNet2DConditionModel.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - subfolder="unet" -) -vae = AutoencoderKL.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - subfolder="vae" -) -text_encoder = CLIPTextModel.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - subfolder="text_encoder" -) -tokenizer = CLIPTokenizer.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - subfolder="tokenizer" -) -scheduler = DDPMScheduler.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - subfolder="scheduler" -) - -# Assemble pipeline -pipe = StableDiffusionPipeline( - unet=unet, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False -) -``` - -### Custom denoising loop - -```python -from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel -from transformers import CLIPTextModel, CLIPTokenizer -import torch - -def custom_generate( - prompt: str, - num_steps: int = 50, - guidance_scale: float = 7.5, - height: int = 512, - width: int = 512 -): - # Load components - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - unet = UNet2DConditionModel.from_pretrained("sd-model", subfolder="unet") - vae = AutoencoderKL.from_pretrained("sd-model", subfolder="vae") - scheduler = DDIMScheduler.from_pretrained("sd-model", subfolder="scheduler") - - device = "cuda" - text_encoder.to(device) - unet.to(device) - vae.to(device) - - # Encode prompt - text_input = tokenizer( - prompt, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt" - ) - text_embeddings = text_encoder(text_input.input_ids.to(device))[0] - - # Unconditional embeddings for classifier-free guidance - uncond_input = tokenizer( - "", - padding="max_length", - max_length=77, - return_tensors="pt" - ) - uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] - - # Concatenate for batch processing - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # Initialize latents - latents = torch.randn( - (1, 4, height // 8, width // 8), - device=device - ) - latents = latents * scheduler.init_noise_sigma - - # Denoising loop - scheduler.set_timesteps(num_steps) - for t in scheduler.timesteps: - latent_model_input = torch.cat([latents] * 2) - latent_model_input = scheduler.scale_model_input(latent_model_input, t) - - # Predict noise - with torch.no_grad(): - noise_pred = unet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings - ).sample - - # Classifier-free guidance - noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_cond - noise_pred_uncond - ) - - # Update latents - latents = scheduler.step(noise_pred, t, latents).prev_sample - - # Decode latents - latents = latents / vae.config.scaling_factor - with torch.no_grad(): - image = vae.decode(latents).sample - - # Convert to PIL - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - image = (image * 255).round().astype("uint8")[0] - - return Image.fromarray(image) -``` - -## IP-Adapter - -Use image prompts alongside text: - -```python -from diffusers import StableDiffusionPipeline -from diffusers.utils import load_image -import torch - -pipe = StableDiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -).to("cuda") - -# Load IP-Adapter -pipe.load_ip_adapter( - "h94/IP-Adapter", - subfolder="models", - weight_name="ip-adapter_sd15.bin" -) - -# Set IP-Adapter scale -pipe.set_ip_adapter_scale(0.6) - -# Load reference image -ip_image = load_image("reference_style.jpg") - -# Generate with image + text prompt -image = pipe( - prompt="A portrait in a garden", - ip_adapter_image=ip_image, - num_inference_steps=50 -).images[0] -``` - -### Multiple IP-Adapter images - -```python -# Use multiple reference images -pipe.set_ip_adapter_scale([0.5, 0.7]) - -images = [ - load_image("style_reference.jpg"), - load_image("composition_reference.jpg") -] - -result = pipe( - prompt="A landscape painting", - ip_adapter_image=images, - num_inference_steps=50 -).images[0] -``` - -## SDXL Refiner - -Two-stage generation for higher quality: - -```python -from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline -import torch - -# Load base model -base = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - torch_dtype=torch.float16, - variant="fp16" -).to("cuda") - -# Load refiner -refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-refiner-1.0", - torch_dtype=torch.float16, - variant="fp16" -).to("cuda") - -# Generate with base (partial denoising) -image = base( - prompt="A majestic eagle soaring over mountains", - num_inference_steps=40, - denoising_end=0.8, - output_type="latent" -).images - -# Refine with refiner -refined = refiner( - prompt="A majestic eagle soaring over mountains", - image=image, - num_inference_steps=40, - denoising_start=0.8 -).images[0] -``` - -## T2I-Adapter - -Lightweight conditioning without full ControlNet: - -```python -from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter -import torch - -# Load adapter -adapter = T2IAdapter.from_pretrained( - "TencentARC/t2i-adapter-canny-sdxl-1.0", - torch_dtype=torch.float16 -) - -pipe = StableDiffusionXLAdapterPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - adapter=adapter, - torch_dtype=torch.float16 -).to("cuda") - -# Get canny edges -canny_image = get_canny_image(input_image) - -image = pipe( - prompt="A colorful anime character", - image=canny_image, - num_inference_steps=30, - adapter_conditioning_scale=0.8 -).images[0] -``` - -## Fine-tuning with DreamBooth - -Train on custom subjects: - -```python -from diffusers import StableDiffusionPipeline, DDPMScheduler -from diffusers.optimization import get_scheduler -import torch -from torch.utils.data import Dataset, DataLoader -from PIL import Image -import os - -class DreamBoothDataset(Dataset): - def __init__(self, instance_images_path, instance_prompt, tokenizer, size=512): - self.instance_images_path = instance_images_path - self.instance_prompt = instance_prompt - self.tokenizer = tokenizer - self.size = size - - self.instance_images = [ - os.path.join(instance_images_path, f) - for f in os.listdir(instance_images_path) - if f.endswith(('.png', '.jpg', '.jpeg')) - ] - - def __len__(self): - return len(self.instance_images) - - def __getitem__(self, idx): - image = Image.open(self.instance_images[idx]).convert("RGB") - image = image.resize((self.size, self.size)) - image = torch.tensor(np.array(image)).permute(2, 0, 1) / 127.5 - 1.0 - - tokens = self.tokenizer( - self.instance_prompt, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt" - ) - - return {"image": image, "input_ids": tokens.input_ids.squeeze()} - -def train_dreambooth( - pretrained_model: str, - instance_data_dir: str, - instance_prompt: str, - output_dir: str, - learning_rate: float = 5e-6, - max_train_steps: int = 800, - train_batch_size: int = 1 -): - # Load pipeline - pipe = StableDiffusionPipeline.from_pretrained(pretrained_model) - - unet = pipe.unet - vae = pipe.vae - text_encoder = pipe.text_encoder - tokenizer = pipe.tokenizer - noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler") - - # Freeze VAE and text encoder - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - - # Create dataset - dataset = DreamBoothDataset( - instance_data_dir, instance_prompt, tokenizer - ) - dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True) - - # Setup optimizer - optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate) - lr_scheduler = get_scheduler( - "constant", - optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=max_train_steps - ) - - # Training loop - unet.train() - device = "cuda" - unet.to(device) - vae.to(device) - text_encoder.to(device) - - global_step = 0 - for epoch in range(max_train_steps // len(dataloader) + 1): - for batch in dataloader: - if global_step >= max_train_steps: - break - - # Encode images to latents - latents = vae.encode(batch["image"].to(device)).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - # Sample noise - noise = torch.randn_like(latents) - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],)) - timesteps = timesteps.to(device) - - # Add noise - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get text embeddings - encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0] - - # Predict noise - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - # Compute loss - loss = torch.nn.functional.mse_loss(noise_pred, noise) - - # Backprop - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - global_step += 1 - - if global_step % 100 == 0: - print(f"Step {global_step}, Loss: {loss.item():.4f}") - - # Save model - pipe.unet = unet - pipe.save_pretrained(output_dir) -``` - -## LoRA Training - -Efficient fine-tuning with Low-Rank Adaptation: - -```python -from peft import LoraConfig, get_peft_model -from diffusers import StableDiffusionPipeline -import torch - -def train_lora( - base_model: str, - train_dataset, - output_dir: str, - lora_rank: int = 4, - learning_rate: float = 1e-4, - max_train_steps: int = 1000 -): - pipe = StableDiffusionPipeline.from_pretrained(base_model) - unet = pipe.unet - - # Configure LoRA - lora_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_rank, - target_modules=["to_q", "to_v", "to_k", "to_out.0"], - lora_dropout=0.1 - ) - - # Apply LoRA to UNet - unet = get_peft_model(unet, lora_config) - unet.print_trainable_parameters() # Shows ~0.1% trainable - - # Train (similar to DreamBooth but only LoRA params) - optimizer = torch.optim.AdamW( - unet.parameters(), - lr=learning_rate - ) - - # ... training loop ... - - # Save LoRA weights only - unet.save_pretrained(output_dir) -``` - -## Textual Inversion - -Learn new concepts through embeddings: - -```python -from diffusers import StableDiffusionPipeline -import torch - -# Load with textual inversion -pipe = StableDiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -).to("cuda") - -# Load learned embedding -pipe.load_textual_inversion( - "sd-concepts-library/cat-toy", - token="" -) - -# Use in prompts -image = pipe("A photo of on a beach").images[0] -``` - -## Quantization - -Reduce memory with quantization: - -```python -from diffusers import BitsAndBytesConfig, StableDiffusionXLPipeline -import torch - -# 8-bit quantization -quantization_config = BitsAndBytesConfig(load_in_8bit=True) - -pipe = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - quantization_config=quantization_config, - torch_dtype=torch.float16 -) -``` - -### NF4 quantization (4-bit) - -```python -quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.float16 -) - -pipe = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - quantization_config=quantization_config -) -``` - -## Production Deployment - -### FastAPI server - -```python -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from diffusers import DiffusionPipeline -import torch -import base64 -from io import BytesIO - -app = FastAPI() - -# Load model at startup -pipe = DiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16 -).to("cuda") -pipe.enable_model_cpu_offload() - -class GenerationRequest(BaseModel): - prompt: str - negative_prompt: str = "" - num_inference_steps: int = 30 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = None - -class GenerationResponse(BaseModel): - image_base64: str - seed: int - -@app.post("/generate", response_model=GenerationResponse) -async def generate(request: GenerationRequest): - try: - generator = None - seed = request.seed or torch.randint(0, 2**32, (1,)).item() - generator = torch.Generator("cuda").manual_seed(seed) - - image = pipe( - prompt=request.prompt, - negative_prompt=request.negative_prompt, - num_inference_steps=request.num_inference_steps, - guidance_scale=request.guidance_scale, - width=request.width, - height=request.height, - generator=generator - ).images[0] - - # Convert to base64 - buffer = BytesIO() - image.save(buffer, format="PNG") - image_base64 = base64.b64encode(buffer.getvalue()).decode() - - return GenerationResponse(image_base64=image_base64, seed=seed) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/health") -async def health(): - return {"status": "healthy"} -``` - -### Docker deployment - -```dockerfile -FROM nvidia/cuda:12.1-runtime-ubuntu22.04 - -RUN apt-get update && apt-get install -y python3 python3-pip - -WORKDIR /app - -COPY requirements.txt . -RUN pip3 install -r requirements.txt - -COPY . . - -# Pre-download model -RUN python3 -c "from diffusers import DiffusionPipeline; DiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5')" - -EXPOSE 8000 -CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] -``` - -### Kubernetes deployment - -```yaml -apiVersion: apps/v1 -kind: Deployment -metadata: - name: stable-diffusion -spec: - replicas: 2 - selector: - matchLabels: - app: stable-diffusion - template: - metadata: - labels: - app: stable-diffusion - spec: - containers: - - name: sd - image: your-registry/stable-diffusion:latest - ports: - - containerPort: 8000 - resources: - limits: - nvidia.com/gpu: 1 - memory: "16Gi" - requests: - nvidia.com/gpu: 1 - memory: "8Gi" - env: - - name: TRANSFORMERS_CACHE - value: "/cache/huggingface" - volumeMounts: - - name: model-cache - mountPath: /cache - volumes: - - name: model-cache - persistentVolumeClaim: - claimName: model-cache-pvc ---- -apiVersion: v1 -kind: Service -metadata: - name: stable-diffusion -spec: - selector: - app: stable-diffusion - ports: - - port: 80 - targetPort: 8000 - type: LoadBalancer -``` - -## Callback System - -Monitor and modify generation: - -```python -from diffusers import StableDiffusionPipeline -from diffusers.callbacks import PipelineCallback -import torch - -class ProgressCallback(PipelineCallback): - def __init__(self): - self.progress = [] - - def callback_fn(self, pipe, step_index, timestep, callback_kwargs): - self.progress.append({ - "step": step_index, - "timestep": timestep.item() - }) - - # Optionally modify latents - latents = callback_kwargs["latents"] - - return callback_kwargs - -# Use callback -callback = ProgressCallback() - -image = pipe( - prompt="A sunset", - callback_on_step_end=callback.callback_fn, - callback_on_step_end_tensor_inputs=["latents"] -).images[0] - -print(f"Generation completed in {len(callback.progress)} steps") -``` - -### Early stopping - -```python -def early_stop_callback(pipe, step_index, timestep, callback_kwargs): - # Stop after 20 steps - if step_index >= 20: - pipe._interrupt = True - return callback_kwargs - -image = pipe( - prompt="A landscape", - num_inference_steps=50, - callback_on_step_end=early_stop_callback -).images[0] -``` - -## Multi-GPU Inference - -### Device map auto - -```python -from diffusers import StableDiffusionXLPipeline - -pipe = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - device_map="auto", # Automatically distribute across GPUs - torch_dtype=torch.float16 -) -``` - -### Manual distribution - -```python -from accelerate import infer_auto_device_map, dispatch_model - -# Create device map -device_map = infer_auto_device_map( - pipe.unet, - max_memory={0: "10GiB", 1: "10GiB"} -) - -# Dispatch model -pipe.unet = dispatch_model(pipe.unet, device_map=device_map) -``` diff --git a/skills/mlops/stable-diffusion/references/troubleshooting.md b/skills/mlops/stable-diffusion/references/troubleshooting.md deleted file mode 100644 index f358643b6..000000000 --- a/skills/mlops/stable-diffusion/references/troubleshooting.md +++ /dev/null @@ -1,555 +0,0 @@ -# Stable Diffusion Troubleshooting Guide - -## Installation Issues - -### Package conflicts - -**Error**: `ImportError: cannot import name 'cached_download' from 'huggingface_hub'` - -**Fix**: -```bash -# Update huggingface_hub -pip install --upgrade huggingface_hub - -# Reinstall diffusers -pip install --upgrade diffusers -``` - -### xFormers installation fails - -**Error**: `RuntimeError: CUDA error: no kernel image is available for execution` - -**Fix**: -```bash -# Check CUDA version -nvcc --version - -# Install matching xformers -pip install xformers --index-url https://download.pytorch.org/whl/cu121 # For CUDA 12.1 - -# Or build from source -pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers -``` - -### Torch/CUDA mismatch - -**Error**: `RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED` - -**Fix**: -```bash -# Check versions -python -c "import torch; print(torch.__version__, torch.cuda.is_available())" - -# Reinstall PyTorch with correct CUDA -pip uninstall torch torchvision -pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -``` - -## Memory Issues - -### CUDA out of memory - -**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory` - -**Solutions**: - -```python -# Solution 1: Enable CPU offloading -pipe.enable_model_cpu_offload() - -# Solution 2: Sequential CPU offload (more aggressive) -pipe.enable_sequential_cpu_offload() - -# Solution 3: Attention slicing -pipe.enable_attention_slicing() - -# Solution 4: VAE slicing for large images -pipe.enable_vae_slicing() - -# Solution 5: Use lower precision -pipe = DiffusionPipeline.from_pretrained( - "model-id", - torch_dtype=torch.float16 # or torch.bfloat16 -) - -# Solution 6: Reduce batch size -image = pipe(prompt, num_images_per_prompt=1).images[0] - -# Solution 7: Generate smaller images -image = pipe(prompt, height=512, width=512).images[0] - -# Solution 8: Clear cache between generations -import gc -torch.cuda.empty_cache() -gc.collect() -``` - -### Memory grows over time - -**Problem**: Memory usage increases with each generation - -**Fix**: -```python -import gc -import torch - -def generate_with_cleanup(pipe, prompt, **kwargs): - try: - image = pipe(prompt, **kwargs).images[0] - return image - finally: - # Clear cache after generation - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() -``` - -### Large model loading fails - -**Error**: `RuntimeError: Unable to load model weights` - -**Fix**: -```python -# Use low CPU memory mode -pipe = DiffusionPipeline.from_pretrained( - "large-model-id", - low_cpu_mem_usage=True, - torch_dtype=torch.float16 -) -``` - -## Generation Issues - -### Black images - -**Problem**: Output images are completely black - -**Solutions**: -```python -# Solution 1: Disable safety checker -pipe.safety_checker = None - -# Solution 2: Check VAE scaling -# The issue might be with VAE encoding/decoding -latents = latents / pipe.vae.config.scaling_factor # Before decode - -# Solution 3: Ensure proper dtype -pipe = pipe.to(dtype=torch.float16) -pipe.vae = pipe.vae.to(dtype=torch.float32) # VAE often needs fp32 - -# Solution 4: Check guidance scale -# Too high can cause issues -image = pipe(prompt, guidance_scale=7.5).images[0] # Not 20+ -``` - -### Noise/static images - -**Problem**: Output looks like random noise - -**Solutions**: -```python -# Solution 1: Increase inference steps -image = pipe(prompt, num_inference_steps=50).images[0] - -# Solution 2: Check scheduler configuration -pipe.scheduler = pipe.scheduler.from_config(pipe.scheduler.config) - -# Solution 3: Verify model was loaded correctly -print(pipe.unet) # Should show model architecture -``` - -### Blurry images - -**Problem**: Output images are low quality or blurry - -**Solutions**: -```python -# Solution 1: Use more steps -image = pipe(prompt, num_inference_steps=50).images[0] - -# Solution 2: Use better VAE -from diffusers import AutoencoderKL -vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") -pipe.vae = vae - -# Solution 3: Use SDXL or refiner -pipe = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0" -) - -# Solution 4: Upscale with img2img -upscale_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(...) -upscaled = upscale_pipe( - prompt=prompt, - image=image.resize((1024, 1024)), - strength=0.3 -).images[0] -``` - -### Prompt not being followed - -**Problem**: Generated image doesn't match the prompt - -**Solutions**: -```python -# Solution 1: Increase guidance scale -image = pipe(prompt, guidance_scale=10.0).images[0] - -# Solution 2: Use negative prompts -image = pipe( - prompt="A red car", - negative_prompt="blue, green, yellow, wrong color", - guidance_scale=7.5 -).images[0] - -# Solution 3: Use prompt weighting -# Emphasize important words -prompt = "A (red:1.5) car on a street" - -# Solution 4: Use longer, more detailed prompts -prompt = """ -A bright red sports car, ferrari style, parked on a city street, -photorealistic, high detail, 8k, professional photography -""" -``` - -### Distorted faces/hands - -**Problem**: Faces and hands look deformed - -**Solutions**: -```python -# Solution 1: Use negative prompts -negative_prompt = """ -bad hands, bad anatomy, deformed, ugly, blurry, -extra fingers, mutated hands, poorly drawn hands, -poorly drawn face, mutation, deformed face -""" - -# Solution 2: Use face-specific models -# ADetailer or similar post-processing - -# Solution 3: Use ControlNet for poses -# Load pose estimation and condition generation - -# Solution 4: Inpaint problematic areas -mask = create_face_mask(image) -fixed = inpaint_pipe( - prompt="beautiful detailed face", - image=image, - mask_image=mask -).images[0] -``` - -## Scheduler Issues - -### Scheduler not compatible - -**Error**: `ValueError: Scheduler ... is not compatible with pipeline` - -**Fix**: -```python -from diffusers import EulerDiscreteScheduler - -# Create scheduler from config -pipe.scheduler = EulerDiscreteScheduler.from_config( - pipe.scheduler.config -) - -# Check compatible schedulers -print(pipe.scheduler.compatibles) -``` - -### Wrong number of steps - -**Problem**: Model generates different quality with same steps - -**Fix**: -```python -# Reset timesteps explicitly -pipe.scheduler.set_timesteps(num_inference_steps) - -# Check scheduler's step count -print(len(pipe.scheduler.timesteps)) -``` - -## LoRA Issues - -### LoRA weights not loading - -**Error**: `RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel` - -**Fix**: -```python -# Check weight file format -# Should be .safetensors or .bin - -# Load with correct key prefix -pipe.load_lora_weights( - "path/to/lora", - weight_name="lora.safetensors" -) - -# Try loading into specific component -pipe.unet.load_attn_procs("path/to/lora") -``` - -### LoRA not affecting output - -**Problem**: Generated images look the same with/without LoRA - -**Fix**: -```python -# Fuse LoRA weights -pipe.fuse_lora(lora_scale=1.0) - -# Or set scale explicitly -pipe.set_adapters(["lora_name"], adapter_weights=[1.0]) - -# Verify LoRA is loaded -print(list(pipe.unet.attn_processors.keys())) -``` - -### Multiple LoRAs conflict - -**Problem**: Multiple LoRAs produce artifacts - -**Fix**: -```python -# Load with different adapter names -pipe.load_lora_weights("lora1", adapter_name="style") -pipe.load_lora_weights("lora2", adapter_name="subject") - -# Balance weights -pipe.set_adapters( - ["style", "subject"], - adapter_weights=[0.5, 0.5] # Lower weights -) - -# Or use LoRA merge before loading -# Merge LoRAs offline with appropriate ratios -``` - -## ControlNet Issues - -### ControlNet not conditioning - -**Problem**: ControlNet has no effect on output - -**Fix**: -```python -# Check control image format -# Should be RGB, matching generation size -control_image = control_image.resize((512, 512)) - -# Increase conditioning scale -image = pipe( - prompt=prompt, - image=control_image, - controlnet_conditioning_scale=1.0, # Try 0.5-1.5 - num_inference_steps=30 -).images[0] - -# Verify ControlNet is loaded -print(pipe.controlnet) -``` - -### Control image preprocessing - -**Fix**: -```python -from controlnet_aux import CannyDetector - -# Proper preprocessing -canny = CannyDetector() -control_image = canny(input_image) - -# Ensure correct format -control_image = control_image.convert("RGB") -control_image = control_image.resize((512, 512)) -``` - -## Hub/Download Issues - -### Model download fails - -**Error**: `requests.exceptions.ConnectionError` - -**Fix**: -```bash -# Set longer timeout -export HF_HUB_DOWNLOAD_TIMEOUT=600 - -# Use mirror if available -export HF_ENDPOINT=https://hf-mirror.com - -# Or download manually -huggingface-cli download stable-diffusion-v1-5/stable-diffusion-v1-5 -``` - -### Cache issues - -**Error**: `OSError: Can't load model from cache` - -**Fix**: -```bash -# Clear cache -rm -rf ~/.cache/huggingface/hub - -# Or set different cache location -export HF_HOME=/path/to/cache - -# Force re-download -pipe = DiffusionPipeline.from_pretrained( - "model-id", - force_download=True -) -``` - -### Access denied for gated models - -**Error**: `401 Client Error: Unauthorized` - -**Fix**: -```bash -# Login to Hugging Face -huggingface-cli login - -# Or use token -pipe = DiffusionPipeline.from_pretrained( - "model-id", - token="hf_xxxxx" -) - -# Accept model license on Hub website first -``` - -## Performance Issues - -### Slow generation - -**Problem**: Generation takes too long - -**Solutions**: -```python -# Solution 1: Use faster scheduler -from diffusers import DPMSolverMultistepScheduler -pipe.scheduler = DPMSolverMultistepScheduler.from_config( - pipe.scheduler.config -) - -# Solution 2: Reduce steps -image = pipe(prompt, num_inference_steps=20).images[0] - -# Solution 3: Use LCM -from diffusers import LCMScheduler -pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") -pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) -image = pipe(prompt, num_inference_steps=4, guidance_scale=1.0).images[0] - -# Solution 4: Enable xFormers -pipe.enable_xformers_memory_efficient_attention() - -# Solution 5: Compile model -pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) -``` - -### First generation is slow - -**Problem**: First image takes much longer - -**Fix**: -```python -# Warm up the model -_ = pipe("warmup", num_inference_steps=1) - -# Then run actual generation -image = pipe(prompt, num_inference_steps=50).images[0] - -# Compile for faster subsequent runs -pipe.unet = torch.compile(pipe.unet) -``` - -## Debugging Tips - -### Enable debug logging - -```python -import logging -logging.basicConfig(level=logging.DEBUG) - -# Or for specific modules -logging.getLogger("diffusers").setLevel(logging.DEBUG) -logging.getLogger("transformers").setLevel(logging.DEBUG) -``` - -### Check model components - -```python -# Print pipeline components -print(pipe.components) - -# Check model config -print(pipe.unet.config) -print(pipe.vae.config) -print(pipe.scheduler.config) - -# Verify device placement -print(pipe.device) -for name, module in pipe.components.items(): - if hasattr(module, 'device'): - print(f"{name}: {module.device}") -``` - -### Validate inputs - -```python -# Check image dimensions -print(f"Height: {height}, Width: {width}") -assert height % 8 == 0, "Height must be divisible by 8" -assert width % 8 == 0, "Width must be divisible by 8" - -# Check prompt tokenization -tokens = pipe.tokenizer(prompt, return_tensors="pt") -print(f"Token count: {tokens.input_ids.shape[1]}") # Max 77 for SD -``` - -### Save intermediate results - -```python -def save_latents_callback(pipe, step_index, timestep, callback_kwargs): - latents = callback_kwargs["latents"] - - # Decode and save intermediate - with torch.no_grad(): - image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy()[0] - Image.fromarray((image * 255).astype("uint8")).save(f"step_{step_index}.png") - - return callback_kwargs - -image = pipe( - prompt, - callback_on_step_end=save_latents_callback, - callback_on_step_end_tensor_inputs=["latents"] -).images[0] -``` - -## Getting Help - -1. **Documentation**: https://huggingface.co/docs/diffusers -2. **GitHub Issues**: https://github.com/huggingface/diffusers/issues -3. **Discord**: https://discord.gg/diffusers -4. **Forum**: https://discuss.huggingface.co - -### Reporting Issues - -Include: -- Diffusers version: `pip show diffusers` -- PyTorch version: `python -c "import torch; print(torch.__version__)"` -- CUDA version: `nvcc --version` -- GPU model: `nvidia-smi` -- Full error traceback -- Minimal reproducible code -- Model name/ID used diff --git a/skills/mlops/tensorrt-llm/SKILL.md b/skills/mlops/tensorrt-llm/SKILL.md deleted file mode 100644 index 056511699..000000000 --- a/skills/mlops/tensorrt-llm/SKILL.md +++ /dev/null @@ -1,190 +0,0 @@ ---- -name: tensorrt-llm -description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling. -version: 1.0.0 -author: Orchestra Research -license: MIT -dependencies: [tensorrt-llm, torch] -metadata: - hermes: - tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU] - ---- - -# TensorRT-LLM - -NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs. - -## When to use TensorRT-LLM - -**Use TensorRT-LLM when:** -- Deploying on NVIDIA GPUs (A100, H100, GB200) -- Need maximum throughput (24,000+ tokens/sec on Llama 3) -- Require low latency for real-time applications -- Working with quantized models (FP8, INT4, FP4) -- Scaling across multiple GPUs or nodes - -**Use vLLM instead when:** -- Need simpler setup and Python-first API -- Want PagedAttention without TensorRT compilation -- Working with AMD GPUs or non-NVIDIA hardware - -**Use llama.cpp instead when:** -- Deploying on CPU or Apple Silicon -- Need edge deployment without NVIDIA GPUs -- Want simpler GGUF quantization format - -## Quick start - -### Installation - -```bash -# Docker (recommended) -docker pull nvidia/tensorrt_llm:latest - -# pip install -pip install tensorrt_llm==1.2.0rc3 - -# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12 -``` - -### Basic inference - -```python -from tensorrt_llm import LLM, SamplingParams - -# Initialize model -llm = LLM(model="meta-llama/Meta-Llama-3-8B") - -# Configure sampling -sampling_params = SamplingParams( - max_tokens=100, - temperature=0.7, - top_p=0.9 -) - -# Generate -prompts = ["Explain quantum computing"] -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - print(output.text) -``` - -### Serving with trtllm-serve - -```bash -# Start server (automatic model download and compilation) -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --tp_size 4 \ # Tensor parallelism (4 GPUs) - --max_batch_size 256 \ - --max_num_tokens 4096 - -# Client request -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3-8B", - "messages": [{"role": "user", "content": "Hello!"}], - "temperature": 0.7, - "max_tokens": 100 - }' -``` - -## Key features - -### Performance optimizations -- **In-flight batching**: Dynamic batching during generation -- **Paged KV cache**: Efficient memory management -- **Flash Attention**: Optimized attention kernels -- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference -- **CUDA graphs**: Reduced kernel launch overhead - -### Parallelism -- **Tensor parallelism (TP)**: Split model across GPUs -- **Pipeline parallelism (PP)**: Layer-wise distribution -- **Expert parallelism**: For Mixture-of-Experts models -- **Multi-node**: Scale beyond single machine - -### Advanced features -- **Speculative decoding**: Faster generation with draft models -- **LoRA serving**: Efficient multi-adapter deployment -- **Disaggregated serving**: Separate prefill and generation - -## Common patterns - -### Quantized model (FP8) - -```python -from tensorrt_llm import LLM - -# Load FP8 quantized model (2× faster, 50% memory) -llm = LLM( - model="meta-llama/Meta-Llama-3-70B", - dtype="fp8", - max_num_tokens=8192 -) - -# Inference same as before -outputs = llm.generate(["Summarize this article..."]) -``` - -### Multi-GPU deployment - -```python -# Tensor parallelism across 8 GPUs -llm = LLM( - model="meta-llama/Meta-Llama-3-405B", - tensor_parallel_size=8, - dtype="fp8" -) -``` - -### Batch inference - -```python -# Process 100 prompts efficiently -prompts = [f"Question {i}: ..." for i in range(100)] - -outputs = llm.generate( - prompts, - sampling_params=SamplingParams(max_tokens=200) -) - -# Automatic in-flight batching for maximum throughput -``` - -## Performance benchmarks - -**Meta Llama 3-8B** (H100 GPU): -- Throughput: 24,000 tokens/sec -- Latency: ~10ms per token -- vs PyTorch: **100× faster** - -**Llama 3-70B** (8× A100 80GB): -- FP8 quantization: 2× faster than FP16 -- Memory: 50% reduction with FP8 - -## Supported models - -- **LLaMA family**: Llama 2, Llama 3, CodeLlama -- **GPT family**: GPT-2, GPT-J, GPT-NeoX -- **Qwen**: Qwen, Qwen2, QwQ -- **DeepSeek**: DeepSeek-V2, DeepSeek-V3 -- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B -- **Vision**: LLaVA, Phi-3-vision -- **100+ models** on HuggingFace - -## References - -- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning -- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node -- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling - -## Resources - -- **Docs**: https://nvidia.github.io/TensorRT-LLM/ -- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM -- **Models**: https://huggingface.co/models?library=tensorrt_llm - - diff --git a/skills/mlops/tensorrt-llm/references/multi-gpu.md b/skills/mlops/tensorrt-llm/references/multi-gpu.md deleted file mode 100644 index 1c0a5e7e9..000000000 --- a/skills/mlops/tensorrt-llm/references/multi-gpu.md +++ /dev/null @@ -1,298 +0,0 @@ -# Multi-GPU Deployment Guide - -Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes. - -## Parallelism Strategies - -### Tensor Parallelism (TP) - -**What it does**: Splits model layers across GPUs horizontally. - -**Use case**: -- Model fits in total GPU memory but not single GPU -- Need low latency (single forward pass) -- GPUs on same node (NVLink required for best performance) - -**Example** (Llama 3-70B on 4× A100): -```python -from tensorrt_llm import LLM - -llm = LLM( - model="meta-llama/Meta-Llama-3-70B", - tensor_parallel_size=4, # Split across 4 GPUs - dtype="fp16" -) - -# Model automatically sharded across GPUs -# Single forward pass, low latency -``` - -**Performance**: -- Latency: ~Same as single GPU -- Throughput: 4× higher (4 GPUs) -- Communication: High (activations synced every layer) - -### Pipeline Parallelism (PP) - -**What it does**: Splits model layers across GPUs vertically (layer-wise). - -**Use case**: -- Very large models (175B+) -- Can tolerate higher latency -- GPUs across multiple nodes - -**Example** (Llama 3-405B on 8× H100): -```python -llm = LLM( - model="meta-llama/Meta-Llama-3-405B", - tensor_parallel_size=4, # TP=4 within nodes - pipeline_parallel_size=2, # PP=2 across nodes - dtype="fp8" -) - -# Total: 8 GPUs (4×2) -# Layers 0-40: Node 1 (4 GPUs with TP) -# Layers 41-80: Node 2 (4 GPUs with TP) -``` - -**Performance**: -- Latency: Higher (sequential through pipeline) -- Throughput: High with micro-batching -- Communication: Lower than TP - -### Expert Parallelism (EP) - -**What it does**: Distributes MoE experts across GPUs. - -**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2) - -**Example** (Mixtral-8x22B on 8× A100): -```python -llm = LLM( - model="mistralai/Mixtral-8x22B", - tensor_parallel_size=4, - expert_parallel_size=2, # Distribute 8 experts across 2 groups - dtype="fp8" -) -``` - -## Configuration Examples - -### Small model (7-13B) - Single GPU - -```python -# Llama 3-8B on 1× A100 80GB -llm = LLM( - model="meta-llama/Meta-Llama-3-8B", - dtype="fp16" # or fp8 for H100 -) -``` - -**Resources**: -- GPU: 1× A100 80GB -- Memory: ~16GB model + 30GB KV cache -- Throughput: 3,000-5,000 tokens/sec - -### Medium model (70B) - Multi-GPU same node - -```python -# Llama 3-70B on 4× A100 80GB (NVLink) -llm = LLM( - model="meta-llama/Meta-Llama-3-70B", - tensor_parallel_size=4, - dtype="fp8" # 70GB → 35GB per GPU -) -``` - -**Resources**: -- GPU: 4× A100 80GB with NVLink -- Memory: ~35GB per GPU (FP8) -- Throughput: 10,000-15,000 tokens/sec -- Latency: 15-20ms per token - -### Large model (405B) - Multi-node - -```python -# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs -llm = LLM( - model="meta-llama/Meta-Llama-3-405B", - tensor_parallel_size=8, # TP within each node - pipeline_parallel_size=2, # PP across 2 nodes - dtype="fp8" -) -``` - -**Resources**: -- GPU: 2 nodes × 8 H100 80GB -- Memory: ~25GB per GPU (FP8) -- Throughput: 20,000-30,000 tokens/sec -- Network: InfiniBand recommended - -## Server Deployment - -### Single-node multi-GPU - -```bash -# Llama 3-70B on 4 GPUs (automatic TP) -trtllm-serve meta-llama/Meta-Llama-3-70B \ - --tp_size 4 \ - --max_batch_size 256 \ - --dtype fp8 - -# Listens on http://localhost:8000 -``` - -### Multi-node with Ray - -```bash -# Node 1 (head node) -ray start --head --port=6379 - -# Node 2 (worker) -ray start --address='node1:6379' - -# Deploy across cluster -trtllm-serve meta-llama/Meta-Llama-3-405B \ - --tp_size 8 \ - --pp_size 2 \ - --num_workers 2 \ # 2 nodes - --dtype fp8 -``` - -### Kubernetes deployment - -```yaml -apiVersion: apps/v1 -kind: Deployment -metadata: - name: tensorrt-llm-llama3-70b -spec: - replicas: 1 - template: - spec: - containers: - - name: trtllm - image: nvidia/tensorrt_llm:latest - command: - - trtllm-serve - - meta-llama/Meta-Llama-3-70B - - --tp_size=4 - - --max_batch_size=256 - resources: - limits: - nvidia.com/gpu: 4 # Request 4 GPUs -``` - -## Parallelism Decision Tree - -``` -Model size < 20GB? -├─ YES: Single GPU (no parallelism) -└─ NO: Model size < 80GB? - ├─ YES: TP=2 or TP=4 (same node) - └─ NO: Model size < 320GB? - ├─ YES: TP=4 or TP=8 (same node, NVLink required) - └─ NO: TP=8 + PP=2 (multi-node) -``` - -## Communication Optimization - -### NVLink vs PCIe - -**NVLink** (DGX A100, HGX H100): -- Bandwidth: 600 GB/s (A100), 900 GB/s (H100) -- Ideal for TP (high communication) -- **Recommended for all multi-GPU setups** - -**PCIe**: -- Bandwidth: 64 GB/s (PCIe 4.0 x16) -- 10× slower than NVLink -- Avoid TP, use PP instead - -### InfiniBand for multi-node - -**HDR InfiniBand** (200 Gb/s): -- Required for multi-node TP or PP -- Latency: <1μs -- **Essential for 405B+ models** - -## Monitoring Multi-GPU - -```python -# Monitor GPU utilization -nvidia-smi dmon -s u - -# Monitor memory -nvidia-smi dmon -s m - -# Monitor NVLink utilization -nvidia-smi nvlink --status - -# TensorRT-LLM built-in metrics -curl http://localhost:8000/metrics -``` - -**Key metrics**: -- GPU utilization: Target 80-95% -- Memory usage: Should be balanced across GPUs -- NVLink traffic: High for TP, low for PP -- Throughput: Tokens/sec across all GPUs - -## Common Issues - -### Imbalanced GPU memory - -**Symptom**: GPU 0 has 90% memory, GPU 3 has 40% - -**Solutions**: -- Verify TP/PP configuration -- Check model sharding (should be equal) -- Restart server to reset state - -### Low NVLink utilization - -**Symptom**: NVLink bandwidth <100 GB/s with TP=4 - -**Solutions**: -- Verify NVLink topology: `nvidia-smi topo -m` -- Check for PCIe fallback -- Ensure GPUs are on same NVSwitch - -### OOM with multi-GPU - -**Solutions**: -- Increase TP size (more GPUs) -- Reduce batch size -- Enable FP8 quantization -- Use pipeline parallelism - -## Performance Scaling - -### TP Scaling (Llama 3-70B, FP8) - -| GPUs | TP Size | Throughput | Latency | Efficiency | -|------|---------|------------|---------|------------| -| 1 | 1 | OOM | - | - | -| 2 | 2 | 6,000 tok/s | 18ms | 85% | -| 4 | 4 | 11,000 tok/s | 16ms | 78% | -| 8 | 8 | 18,000 tok/s | 15ms | 64% | - -**Note**: Efficiency drops with more GPUs due to communication overhead. - -### PP Scaling (Llama 3-405B, FP8) - -| Nodes | TP | PP | Total GPUs | Throughput | -|-------|----|----|------------|------------| -| 1 | 8 | 1 | 8 | OOM | -| 2 | 8 | 2 | 16 | 25,000 tok/s | -| 4 | 8 | 4 | 32 | 45,000 tok/s | - -## Best Practices - -1. **Prefer TP over PP** when possible (lower latency) -2. **Use NVLink** for all TP deployments -3. **Use InfiniBand** for multi-node deployments -4. **Start with smallest TP** that fits model in memory -5. **Monitor GPU balance** - all GPUs should have similar utilization -6. **Test with benchmark** before production -7. **Use FP8** on H100 for 2× speedup diff --git a/skills/mlops/tensorrt-llm/references/optimization.md b/skills/mlops/tensorrt-llm/references/optimization.md deleted file mode 100644 index 2eb255ddf..000000000 --- a/skills/mlops/tensorrt-llm/references/optimization.md +++ /dev/null @@ -1,242 +0,0 @@ -# TensorRT-LLM Optimization Guide - -Comprehensive guide to optimizing LLM inference with TensorRT-LLM. - -## Quantization - -### FP8 Quantization (Recommended for H100) - -**Benefits**: -- 2× faster inference -- 50% memory reduction -- Minimal accuracy loss (<1% perplexity degradation) - -**Usage**: -```python -from tensorrt_llm import LLM - -# Automatic FP8 quantization -llm = LLM( - model="meta-llama/Meta-Llama-3-70B", - dtype="fp8", - quantization="fp8" -) -``` - -**Performance** (Llama 3-70B on 8× H100): -- FP16: 5,000 tokens/sec -- FP8: **10,000 tokens/sec** (2× speedup) -- Memory: 140GB → 70GB - -### INT4 Quantization (Maximum compression) - -**Benefits**: -- 4× memory reduction -- 3-4× faster inference -- Fits larger models on same hardware - -**Usage**: -```python -# INT4 with AWQ calibration -llm = LLM( - model="meta-llama/Meta-Llama-3-405B", - dtype="int4_awq", - quantization="awq" -) - -# INT4 with GPTQ calibration -llm = LLM( - model="meta-llama/Meta-Llama-3-405B", - dtype="int4_gptq", - quantization="gptq" -) -``` - -**Trade-offs**: -- Accuracy: 1-3% perplexity increase -- Speed: 3-4× faster than FP16 -- Use case: When memory is critical - -## In-Flight Batching - -**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish. - -**Configuration**: -```python -# Server configuration -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --max_batch_size 256 \ # Maximum concurrent sequences - --max_num_tokens 4096 \ # Total tokens in batch - --enable_chunked_context \ # Split long prompts - --scheduler_policy max_utilization -``` - -**Performance**: -- Throughput: **4-8× higher** vs static batching -- Latency: Lower P50/P99 for mixed workloads -- GPU utilization: 80-95% vs 40-60% - -## Paged KV Cache - -**What it does**: Manages KV cache memory like OS manages virtual memory (paging). - -**Benefits**: -- 40-60% higher throughput -- No memory fragmentation -- Supports longer sequences - -**Configuration**: -```python -# Automatic paged KV cache (default) -llm = LLM( - model="meta-llama/Meta-Llama-3-8B", - kv_cache_free_gpu_mem_fraction=0.9, # Use 90% GPU mem for cache - enable_prefix_caching=True # Cache common prefixes -) -``` - -## Speculative Decoding - -**What it does**: Uses small draft model to predict multiple tokens, verified by target model in parallel. - -**Speedup**: 2-3× faster for long generations - -**Usage**: -```python -from tensorrt_llm import LLM - -# Target model (Llama 3-70B) -llm = LLM( - model="meta-llama/Meta-Llama-3-70B", - speculative_model="meta-llama/Meta-Llama-3-8B", # Draft model - num_speculative_tokens=5 # Tokens to predict ahead -) - -# Same API, 2-3× faster -outputs = llm.generate(prompts) -``` - -**Best models for drafting**: -- Target: Llama 3-70B → Draft: Llama 3-8B -- Target: Qwen2-72B → Draft: Qwen2-7B -- Same family, 8-10× smaller - -## CUDA Graphs - -**What it does**: Reduces kernel launch overhead by recording GPU operations. - -**Benefits**: -- 10-20% lower latency -- More stable P99 latency -- Better for small batch sizes - -**Configuration** (automatic by default): -```python -llm = LLM( - model="meta-llama/Meta-Llama-3-8B", - enable_cuda_graph=True, # Default: True - cuda_graph_cache_size=2 # Cache 2 graph variants -) -``` - -## Chunked Context - -**What it does**: Splits long prompts into chunks to reduce memory spikes. - -**Use case**: Prompts >8K tokens with limited GPU memory - -**Configuration**: -```bash -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --max_num_tokens 4096 \ - --enable_chunked_context \ - --max_chunked_prefill_length 2048 # Process 2K tokens at a time -``` - -## Overlap Scheduling - -**What it does**: Overlaps compute and memory operations. - -**Benefits**: -- 15-25% higher throughput -- Better GPU utilization -- Default in v1.2.0+ - -**No configuration needed** - enabled automatically. - -## Quantization Comparison Table - -| Method | Memory | Speed | Accuracy | Use Case | -|--------|--------|-------|----------|----------| -| FP16 | 1× (baseline) | 1× | Best | High accuracy needed | -| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** | -| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical | -| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed | - -## Tuning Workflow - -1. **Start with defaults**: - ```python - llm = LLM(model="meta-llama/Meta-Llama-3-70B") - ``` - -2. **Enable FP8** (if H100): - ```python - llm = LLM(model="...", dtype="fp8") - ``` - -3. **Tune batch size**: - ```python - # Increase until OOM, then reduce 20% - trtllm-serve ... --max_batch_size 256 - ``` - -4. **Enable chunked context** (if long prompts): - ```bash - --enable_chunked_context --max_chunked_prefill_length 2048 - ``` - -5. **Try speculative decoding** (if latency critical): - ```python - llm = LLM(model="...", speculative_model="...") - ``` - -## Benchmarking - -```bash -# Install benchmark tool -pip install tensorrt_llm[benchmark] - -# Run benchmark -python benchmarks/python/benchmark.py \ - --model meta-llama/Meta-Llama-3-8B \ - --batch_size 64 \ - --input_len 128 \ - --output_len 256 \ - --dtype fp8 -``` - -**Metrics to track**: -- Throughput (tokens/sec) -- Latency P50/P90/P99 (ms) -- GPU memory usage (GB) -- GPU utilization (%) - -## Common Issues - -**OOM errors**: -- Reduce `max_batch_size` -- Reduce `max_num_tokens` -- Enable INT4 quantization -- Increase `tensor_parallel_size` - -**Low throughput**: -- Increase `max_batch_size` -- Enable in-flight batching -- Verify CUDA graphs enabled -- Check GPU utilization - -**High latency**: -- Try speculative decoding -- Reduce `max_batch_size` (less queueing) -- Use FP8 instead of FP16 diff --git a/skills/mlops/tensorrt-llm/references/serving.md b/skills/mlops/tensorrt-llm/references/serving.md deleted file mode 100644 index 6ff1f18a4..000000000 --- a/skills/mlops/tensorrt-llm/references/serving.md +++ /dev/null @@ -1,470 +0,0 @@ -# Production Serving Guide - -Comprehensive guide to deploying TensorRT-LLM in production environments. - -## Server Modes - -### trtllm-serve (Recommended) - -**Features**: -- OpenAI-compatible API -- Automatic model download and compilation -- Built-in load balancing -- Prometheus metrics -- Health checks - -**Basic usage**: -```bash -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --tp_size 1 \ - --max_batch_size 256 \ - --port 8000 -``` - -**Advanced configuration**: -```bash -trtllm-serve meta-llama/Meta-Llama-3-70B \ - --tp_size 4 \ - --dtype fp8 \ - --max_batch_size 256 \ - --max_num_tokens 4096 \ - --enable_chunked_context \ - --scheduler_policy max_utilization \ - --port 8000 \ - --api_key $API_KEY # Optional authentication -``` - -### Python LLM API (For embedding) - -```python -from tensorrt_llm import LLM - -class LLMService: - def __init__(self): - self.llm = LLM( - model="meta-llama/Meta-Llama-3-8B", - dtype="fp8" - ) - - def generate(self, prompt, max_tokens=100): - from tensorrt_llm import SamplingParams - - params = SamplingParams( - max_tokens=max_tokens, - temperature=0.7 - ) - outputs = self.llm.generate([prompt], params) - return outputs[0].text - -# Use in FastAPI, Flask, etc -from fastapi import FastAPI -app = FastAPI() -service = LLMService() - -@app.post("/generate") -def generate(prompt: str): - return {"response": service.generate(prompt)} -``` - -## OpenAI-Compatible API - -### Chat Completions - -```bash -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3-8B", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Explain quantum computing"} - ], - "temperature": 0.7, - "max_tokens": 500, - "stream": false - }' -``` - -**Response**: -```json -{ - "id": "chat-abc123", - "object": "chat.completion", - "created": 1234567890, - "model": "meta-llama/Meta-Llama-3-8B", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Quantum computing is..." - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 25, - "completion_tokens": 150, - "total_tokens": 175 - } -} -``` - -### Streaming - -```bash -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3-8B", - "messages": [{"role": "user", "content": "Count to 10"}], - "stream": true - }' -``` - -**Response** (SSE stream): -``` -data: {"choices":[{"delta":{"content":"1"}}]} - -data: {"choices":[{"delta":{"content":", 2"}}]} - -data: {"choices":[{"delta":{"content":", 3"}}]} - -data: [DONE] -``` - -### Completions - -```bash -curl -X POST http://localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3-8B", - "prompt": "The capital of France is", - "max_tokens": 10, - "temperature": 0.0 - }' -``` - -## Monitoring - -### Prometheus Metrics - -**Enable metrics**: -```bash -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --enable_metrics \ - --metrics_port 9090 -``` - -**Key metrics**: -```bash -# Scrape metrics -curl http://localhost:9090/metrics - -# Important metrics: -# - trtllm_request_success_total - Total successful requests -# - trtllm_request_latency_seconds - Request latency histogram -# - trtllm_tokens_generated_total - Total tokens generated -# - trtllm_active_requests - Current active requests -# - trtllm_queue_size - Requests waiting in queue -# - trtllm_gpu_memory_usage_bytes - GPU memory usage -# - trtllm_kv_cache_usage_ratio - KV cache utilization -``` - -### Health Checks - -```bash -# Readiness probe -curl http://localhost:8000/health/ready - -# Liveness probe -curl http://localhost:8000/health/live - -# Model info -curl http://localhost:8000/v1/models -``` - -**Kubernetes probes**: -```yaml -livenessProbe: - httpGet: - path: /health/live - port: 8000 - initialDelaySeconds: 60 - periodSeconds: 10 - -readinessProbe: - httpGet: - path: /health/ready - port: 8000 - initialDelaySeconds: 30 - periodSeconds: 5 -``` - -## Production Deployment - -### Docker Deployment - -**Dockerfile**: -```dockerfile -FROM nvidia/tensorrt_llm:latest - -# Copy any custom configs -COPY config.yaml /app/config.yaml - -# Expose ports -EXPOSE 8000 9090 - -# Start server -CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \ - "--tp_size", "4", \ - "--dtype", "fp8", \ - "--max_batch_size", "256", \ - "--enable_metrics", \ - "--metrics_port", "9090"] -``` - -**Run container**: -```bash -docker run --gpus all -p 8000:8000 -p 9090:9090 \ - tensorrt-llm:latest -``` - -### Kubernetes Deployment - -**Complete deployment**: -```yaml -apiVersion: apps/v1 -kind: Deployment -metadata: - name: tensorrt-llm -spec: - replicas: 2 # Multiple replicas for HA - selector: - matchLabels: - app: tensorrt-llm - template: - metadata: - labels: - app: tensorrt-llm - spec: - containers: - - name: trtllm - image: nvidia/tensorrt_llm:latest - command: - - trtllm-serve - - meta-llama/Meta-Llama-3-70B - - --tp_size=4 - - --dtype=fp8 - - --max_batch_size=256 - - --enable_metrics - ports: - - containerPort: 8000 - name: http - - containerPort: 9090 - name: metrics - resources: - limits: - nvidia.com/gpu: 4 - livenessProbe: - httpGet: - path: /health/live - port: 8000 - readinessProbe: - httpGet: - path: /health/ready - port: 8000 ---- -apiVersion: v1 -kind: Service -metadata: - name: tensorrt-llm -spec: - selector: - app: tensorrt-llm - ports: - - name: http - port: 80 - targetPort: 8000 - - name: metrics - port: 9090 - targetPort: 9090 - type: LoadBalancer -``` - -### Load Balancing - -**NGINX configuration**: -```nginx -upstream tensorrt_llm { - least_conn; # Route to least busy server - server trtllm-1:8000 max_fails=3 fail_timeout=30s; - server trtllm-2:8000 max_fails=3 fail_timeout=30s; - server trtllm-3:8000 max_fails=3 fail_timeout=30s; -} - -server { - listen 80; - location / { - proxy_pass http://tensorrt_llm; - proxy_read_timeout 300s; # Long timeout for slow generations - proxy_connect_timeout 10s; - } -} -``` - -## Autoscaling - -### Horizontal Pod Autoscaler (HPA) - -```yaml -apiVersion: autoscaling/v2 -kind: HorizontalPodAutoscaler -metadata: - name: tensorrt-llm-hpa -spec: - scaleTargetRef: - apiVersion: apps/v1 - kind: Deployment - name: tensorrt-llm - minReplicas: 2 - maxReplicas: 10 - metrics: - - type: Pods - pods: - metric: - name: trtllm_active_requests - target: - type: AverageValue - averageValue: "50" # Scale when avg >50 active requests -``` - -### Custom Metrics - -```yaml -# Scale based on queue size -- type: Pods - pods: - metric: - name: trtllm_queue_size - target: - type: AverageValue - averageValue: "10" -``` - -## Cost Optimization - -### GPU Selection - -**A100 80GB** ($3-4/hour): -- Use for: 70B models with FP8 -- Throughput: 10,000-15,000 tok/s (TP=4) -- Cost per 1M tokens: $0.20-0.30 - -**H100 80GB** ($6-8/hour): -- Use for: 70B models with FP8, 405B models -- Throughput: 20,000-30,000 tok/s (TP=4) -- Cost per 1M tokens: $0.15-0.25 (2× faster = lower cost) - -**L4** ($0.50-1/hour): -- Use for: 7-8B models -- Throughput: 1,000-2,000 tok/s -- Cost per 1M tokens: $0.25-0.50 - -### Batch Size Tuning - -**Impact on cost**: -- Batch size 1: 1,000 tok/s → $3/hour per 1M = $3/M tokens -- Batch size 64: 5,000 tok/s → $3/hour per 5M = $0.60/M tokens -- **5× cost reduction** with batching - -**Recommendation**: Target batch size 32-128 for cost efficiency. - -## Security - -### API Authentication - -```bash -# Generate API key -export API_KEY=$(openssl rand -hex 32) - -# Start server with authentication -trtllm-serve meta-llama/Meta-Llama-3-8B \ - --api_key $API_KEY - -# Client request -curl -X POST http://localhost:8000/v1/chat/completions \ - -H "Authorization: Bearer $API_KEY" \ - -H "Content-Type: application/json" \ - -d '{"model": "...", "messages": [...]}' -``` - -### Network Policies - -```yaml -apiVersion: networking.k8s.io/v1 -kind: NetworkPolicy -metadata: - name: tensorrt-llm-policy -spec: - podSelector: - matchLabels: - app: tensorrt-llm - policyTypes: - - Ingress - ingress: - - from: - - podSelector: - matchLabels: - app: api-gateway # Only allow from gateway - ports: - - protocol: TCP - port: 8000 -``` - -## Troubleshooting - -### High latency - -**Diagnosis**: -```bash -# Check queue size -curl http://localhost:9090/metrics | grep queue_size - -# Check active requests -curl http://localhost:9090/metrics | grep active_requests -``` - -**Solutions**: -- Scale horizontally (more replicas) -- Increase batch size (if GPU underutilized) -- Enable chunked context (if long prompts) -- Use FP8 quantization - -### OOM crashes - -**Solutions**: -- Reduce `max_batch_size` -- Reduce `max_num_tokens` -- Enable FP8 or INT4 quantization -- Increase `tensor_parallel_size` - -### Timeout errors - -**NGINX config**: -```nginx -proxy_read_timeout 600s; # 10 minutes for very long generations -proxy_send_timeout 600s; -``` - -## Best Practices - -1. **Use FP8 on H100** for 2× speedup and 50% cost reduction -2. **Monitor metrics** - Set up Prometheus + Grafana -3. **Set readiness probes** - Prevent routing to unhealthy pods -4. **Use load balancing** - Distribute load across replicas -5. **Tune batch size** - Balance latency and throughput -6. **Enable streaming** - Better UX for chat applications -7. **Set up autoscaling** - Handle traffic spikes -8. **Use persistent volumes** - Cache compiled models -9. **Implement retries** - Handle transient failures -10. **Monitor costs** - Track cost per token diff --git a/skills/music-creation/heartmula/SKILL.md b/skills/music-creation/heartmula/SKILL.md new file mode 100644 index 000000000..d8905dd5d --- /dev/null +++ b/skills/music-creation/heartmula/SKILL.md @@ -0,0 +1,170 @@ +--- +name: heartmula +description: Set up and run HeartMuLa, the open-source music generation model family (Suno-like). Generates full songs from lyrics + tags with multilingual support. +version: 1.0.0 +metadata: + hermes: + tags: [music, audio, generation, ai, heartmula, heartcodec, lyrics, songs] + related_skills: [audiocraft] +--- + +# HeartMuLa - Open-Source Music Generation + +## Overview +HeartMuLa is a family of open-source music foundation models (Apache-2.0) that generates music conditioned on lyrics and tags. Comparable to Suno for open-source. Includes: +- **HeartMuLa** - Music language model (3B/7B) for generation from lyrics + tags +- **HeartCodec** - 12.5Hz music codec for high-fidelity audio reconstruction +- **HeartTranscriptor** - Whisper-based lyrics transcription +- **HeartCLAP** - Audio-text alignment model + +## When to Use +- User wants to generate music/songs from text descriptions +- User wants an open-source Suno alternative +- User wants local/offline music generation +- User asks about HeartMuLa, heartlib, or AI music generation + +## Hardware Requirements +- **Minimum**: 8GB VRAM with `--lazy_load true` (loads/unloads models sequentially) +- **Recommended**: 16GB+ VRAM for comfortable single-GPU usage +- **Multi-GPU**: Use `--mula_device cuda:0 --codec_device cuda:1` to split across GPUs +- 3B model with lazy_load peaks at ~6.2GB VRAM + +## Installation Steps + +### 1. Clone Repository +```bash +cd ~/ # or desired directory +git clone https://github.com/HeartMuLa/heartlib.git +cd heartlib +``` + +### 2. Create Virtual Environment (Python 3.10 required) +```bash +uv venv --python 3.10 .venv +. .venv/bin/activate +uv pip install -e . +``` + +### 3. Fix Dependency Compatibility Issues + +**IMPORTANT**: As of Feb 2026, the pinned dependencies have conflicts with newer packages. Apply these fixes: + +```bash +# Upgrade datasets (old version incompatible with current pyarrow) +uv pip install --upgrade datasets + +# Upgrade transformers (needed for huggingface-hub 1.x compatibility) +uv pip install --upgrade transformers +``` + +### 4. Patch Source Code (Required for transformers 5.x) + +**Patch 1 - RoPE cache fix** in `src/heartlib/heartmula/modeling_heartmula.py`: + +In the `setup_caches` method of the `HeartMuLa` class, add RoPE reinitialization after the `reset_caches` try/except block and before the `with device:` block: + +```python +# Re-initialize RoPE caches that were skipped during meta-device loading +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +for module in self.modules(): + if isinstance(module, Llama3ScaledRoPE) and not module.is_cache_built: + module.rope_init() + module.to(device) +``` + +**Why**: `from_pretrained` creates model on meta device first; `Llama3ScaledRoPE.rope_init()` skips cache building on meta tensors, then never rebuilds after weights are loaded to real device. + +**Patch 2 - HeartCodec loading fix** in `src/heartlib/pipelines/music_generation.py`: + +Add `ignore_mismatched_sizes=True` to ALL `HeartCodec.from_pretrained()` calls (there are 2: the eager load in `__init__` and the lazy load in the `codec` property). + +**Why**: VQ codebook `initted` buffers have shape `[1]` in checkpoint vs `[]` in model. Same data, just scalar vs 0-d tensor. Safe to ignore. + +### 5. Download Model Checkpoints +```bash +cd heartlib # project root +hf download --local-dir './ckpt' 'HeartMuLa/HeartMuLaGen' +hf download --local-dir './ckpt/HeartMuLa-oss-3B' 'HeartMuLa/HeartMuLa-oss-3B-happy-new-year' +hf download --local-dir './ckpt/HeartCodec-oss' 'HeartMuLa/HeartCodec-oss-20260123' +``` + +All 3 can be downloaded in parallel. Total size is several GB. + +## GPU / CUDA + +HeartMuLa uses CUDA by default (`--mula_device cuda --codec_device cuda`). No extra setup needed if the user has an NVIDIA GPU with PyTorch CUDA support installed. + +- The installed `torch==2.4.1` includes CUDA 12.1 support out of the box +- `torchtune` may report version `0.4.0+cpu` — this is just package metadata, it still uses CUDA via PyTorch +- To verify GPU is being used, look for "CUDA memory" lines in the output (e.g. "CUDA memory before unloading: 6.20 GB") +- **No GPU?** You can run on CPU with `--mula_device cpu --codec_device cpu`, but expect generation to be **extremely slow** (potentially 30-60+ minutes for a single song vs ~4 minutes on GPU). CPU mode also requires significant RAM (~12GB+ free). If the user has no NVIDIA GPU, recommend using a cloud GPU service (Google Colab free tier with T4, Lambda Labs, etc.) or the online demo at https://heartmula.github.io/ instead. + +## Usage + +### Basic Generation +```bash +cd heartlib +. .venv/bin/activate +python ./examples/run_music_generation.py \ + --model_path=./ckpt \ + --version="3B" \ + --lyrics="./assets/lyrics.txt" \ + --tags="./assets/tags.txt" \ + --save_path="./assets/output.mp3" \ + --lazy_load true +``` + +### Input Formatting + +**Tags** (comma-separated, no spaces): +``` +piano,happy,wedding,synthesizer,romantic +``` +or +``` +rock,energetic,guitar,drums,male-vocal +``` + +**Lyrics** (use bracketed structural tags): +``` +[Intro] + +[Verse] +Your lyrics here... + +[Chorus] +Chorus lyrics... + +[Bridge] +Bridge lyrics... + +[Outro] +``` + +### Key Parameters +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--max_audio_length_ms` | 240000 | Max length in ms (240s = 4 min) | +| `--topk` | 50 | Top-k sampling | +| `--temperature` | 1.0 | Sampling temperature | +| `--cfg_scale` | 1.5 | Classifier-free guidance scale | +| `--lazy_load` | false | Load/unload models on demand (saves VRAM) | +| `--mula_dtype` | bfloat16 | Dtype for HeartMuLa (bf16 recommended) | +| `--codec_dtype` | float32 | Dtype for HeartCodec (fp32 recommended for quality) | + +### Performance +- RTF (Real-Time Factor) ≈ 1.0 — a 4-minute song takes ~4 minutes to generate +- Output: MP3, 48kHz stereo, 128kbps + +## Pitfalls +1. **Do NOT use bf16 for HeartCodec** — degrades audio quality. Use fp32 (default). +2. **Tags may be ignored** — known issue (#90). Lyrics tend to dominate; experiment with tag ordering. +3. **Triton not available on macOS** — Linux/CUDA only for GPU acceleration. +4. **RTX 5080 incompatibility** reported in upstream issues. +5. The dependency pin conflicts require the manual upgrades and patches described above. + +## Links +- Repo: https://github.com/HeartMuLa/heartlib +- Models: https://huggingface.co/HeartMuLa +- Paper: https://arxiv.org/abs/2601.10547 +- License: Apache-2.0