# PyTorch Lightning Distributed Training

## Distributed Strategies

Lightning supports multiple distributed strategies with a single parameter change.

### 1. DDP (DistributedDataParallel)

**Default strategy for multi-GPU**:

```python
import lightning as L

# Automatic DDP on all available GPUs
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp')

# Or auto-detect the number of devices
trainer = L.Trainer(accelerator='gpu', devices='auto')
```

**How DDP works**:
- Replicates the model on each GPU
- Each GPU processes a different shard of every batch
- Gradients are all-reduced (averaged) across GPUs after backward
- Model weights stay synchronized because every rank applies the same averaged update

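
The all-reduce step above amounts to averaging each gradient element-wise across ranks. A minimal single-process sketch of that arithmetic (plain Python lists stand in for per-GPU gradient tensors; `all_reduce_mean` is an illustrative helper, not a Lightning API):

```python
def all_reduce_mean(per_rank_grads):
    """Average gradients element-wise across ranks; every rank gets the result."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    # After all-reduce, every rank holds an identical copy
    return [list(averaged) for _ in range(world_size)]

# 4 "GPUs", each with gradients from its own micro-batch
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
synced = all_reduce_mean(grads)
print(synced[0])  # [4.0, 5.0] on every rank
```

Because every rank then takes the same optimizer step from the same gradients, the replicas never drift apart.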
**Launch**:
```bash
# Lightning spawns one process per GPU automatically
python train.py
```

**DDP Configuration**:
```python
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    find_unused_parameters=False,  # Set True only if some params get no gradient
    gradient_as_bucket_view=True,  # Memory optimization: gradients share bucket storage
    static_graph=False,            # Set True if the graph doesn't change between iterations
)

trainer = L.Trainer(strategy=strategy)
```

### 2. FSDP (Fully Sharded Data Parallel)

**For large models (7B+ parameters)**:

```python
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",         # ZeRO-3 equivalent
    activation_checkpointing_policy=None,   # Or a set of layer types to checkpoint
    cpu_offload=False,                      # Offload shards to CPU to save GPU memory
)

trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy=strategy,
    precision='bf16-mixed',  # Recommended with FSDP
)

trainer.fit(model, train_loader)
```

**FSDP Sharding Strategies**:
```python
# FULL_SHARD: shards params, gradients, and optimizer state
# (most memory efficient, equivalent to ZeRO-3)
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")

# SHARD_GRAD_OP: shards gradients and optimizer state only
# (less memory efficient, equivalent to ZeRO-2)
strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")

# NO_SHARD: no sharding, behaves like DDP
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
```

**Auto-wrap policy** (wrap transformer blocks):
```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block},
)

strategy = FSDPStrategy(
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing_policy={GPT2Block},  # Checkpoint these blocks
)
```

### 3. DeepSpeed

**For massive models (70B+ parameters)**:

```python
from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO-3 with CPU offload
strategy = DeepSpeedStrategy(
    stage=3,                  # ZeRO-3: shard params, gradients, and optimizer state
    offload_optimizer=True,   # Offload optimizer state to CPU
    offload_parameters=True,  # Offload parameters to CPU
    cpu_checkpointing=True,   # Store activation checkpoints on CPU
)

trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy=strategy,
    precision='bf16-mixed',
)

trainer.fit(model, train_loader)
```

**DeepSpeed configuration file**:
```json
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6
  },
  "bf16": {
    "enabled": true
  }
}
```

**Use config file**:
```python
strategy = DeepSpeedStrategy(config='deepspeed_config.json')
trainer = L.Trainer(strategy=strategy)
```

### 4. DDP Spawn

**Windows- and notebook-compatible DDP**:

```python
# Use when standard DDP doesn't work (e.g., Windows, Jupyter)
trainer = L.Trainer(
    accelerator='gpu',
    devices=2,
    strategy='ddp_spawn',  # Spawns new processes instead of re-launching the script
)
```

**Note**: Slower than standard DDP due to process-spawning overhead.

## Multi-Node Training

### Setup Multi-Node Cluster

**Node 0 (master)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16  # 2 nodes × 8 GPUs
export NODE_RANK=0

python train.py
```

**Node 1 (worker)**:
```bash
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355
export WORLD_SIZE=16
export NODE_RANK=1

python train.py
```

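
With these variables set, each process derives its global rank from the node rank and its local (per-GPU) rank. A sketch of the arithmetic (`global_rank` is an illustrative helper; Lightning computes this internally):

```python
def global_rank(node_rank, local_rank, gpus_per_node):
    """Global rank of a process: node offset plus position within the node."""
    return node_rank * gpus_per_node + local_rank

# 2 nodes × 8 GPUs -> WORLD_SIZE = 16, global ranks 0..15
print(global_rank(node_rank=0, local_rank=0, gpus_per_node=8))  # 0 (the master)
print(global_rank(node_rank=1, local_rank=3, gpus_per_node=8))  # 11
```

Rank 0 on node 0 is the "global zero" process that coordinates logging and checkpointing.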
**Training script** (identical on every node):
```python
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,     # GPUs per node
    num_nodes=2,   # Total nodes
    strategy='ddp',
)

trainer.fit(model, train_loader)
```

### SLURM Integration

**SLURM job script**:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --time=24:00:00

# Lightning auto-detects the SLURM environment
srun python train.py
```

**Training script** (no changes needed):
```python
# Lightning automatically reads SLURM environment variables
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,     # Matches --ntasks-per-node / --gres
    num_nodes=4,   # Matches #SBATCH --nodes
    strategy='ddp',
)
```

### Kubernetes (Kubeflow)

**Training script**:
```python
import os

# Lightning auto-detects the Kubeflow environment
# (RANK, WORLD_SIZE, MASTER_ADDR set by the PyTorchJob)
trainer = L.Trainer(
    accelerator='gpu',
    devices=1,                                  # GPUs per pod
    num_nodes=int(os.getenv('WORLD_SIZE', 1)),  # One process per pod
    strategy='ddp',
)
```

## Mixed Precision Training

### BF16 (A100/H100)

```python
trainer = L.Trainer(
    precision='bf16-mixed',  # 'bf16' is accepted as a shorthand
    accelerator='gpu',
)
```

**Advantages**:
- No gradient scaler needed
- Same dynamic range as FP32 (same 8-bit exponent)
- Roughly 2× speedup and ~50% memory reduction for 16-bit tensors

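
The "same dynamic range" point follows from the formats themselves: BF16 keeps FP32's 8 exponent bits and trades away mantissa precision, while FP16 has only 5 exponent bits. A pure-Python sketch computing the largest finite value each format can represent (`max_finite` is an illustrative helper):

```python
def max_finite(exponent_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary float format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_exp = bias  # the top exponent code is reserved for inf/NaN
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

print(max_finite(5, 10))   # FP16: 65504.0 -- large activations/grads overflow easily
print(max_finite(8, 7))    # BF16: ~3.39e38, essentially FP32's range
print(max_finite(8, 23))   # FP32: ~3.40e38
```

This is why BF16 training skips loss scaling entirely: values that fit in FP32 also fit in BF16.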
### FP16 (V100, older GPUs)

```python
trainer = L.Trainer(
    precision='16-mixed',  # FP16 autocast with automatic gradient scaling
    accelerator='gpu',
)
```

**Automatic gradient scaling** is handled by Lightning: FP16's narrow range can underflow small gradients, so the loss is scaled up before backward and gradients are unscaled before the optimizer step.

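
A minimal sketch of what such a scaler does conceptually (this is a simplification for illustration, not Lightning's actual implementation): gradients come out of backward already multiplied by the scale, they are unscaled before the step, and the step is skipped when an overflow produced inf/NaN.

```python
import math

def scaled_step(raw_grads, scale):
    """One loss-scaled optimizer step (sketch).

    raw_grads: gradients from backward on the scaled loss, i.e. already
    multiplied by `scale`. Returns (grads_to_apply_or_None, new_scale).
    """
    unscaled = [g / scale for g in raw_grads]
    if any(math.isinf(g) or math.isnan(g) for g in unscaled):
        return None, scale / 2   # overflow: skip this step, lower the scale
    return unscaled, scale       # safe: apply the unscaled gradients

grads, scale = scaled_step([1024.0, 2048.0], scale=1024.0)
print(grads)  # [1.0, 2.0]

grads, scale = scaled_step([float('inf')], scale=1024.0)
print(grads, scale)  # None 512.0
```

Real scalers also grow the scale back up after a run of overflow-free steps.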
### FP8 (H100)

```python
# Requires NVIDIA Transformer Engine:
# pip install transformer-engine[pytorch]

trainer = L.Trainer(
    precision='transformer-engine',
    accelerator='gpu',
)
```

**Benefits**: Up to 2× faster than BF16 on H100 for supported layers.

## Gradient Accumulation

**Simulate a larger batch size**:

```python
trainer = L.Trainer(
    accumulate_grad_batches=4,  # Accumulate gradients over 4 batches per optimizer step
    precision='bf16-mixed',
)

# Effective batch = batch_size × accumulate_grad_batches × num_gpus
# Example: 32 × 4 × 8 = 1024
```

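
Schematically, accumulation just defers the optimizer step until N batches' gradients have been summed; normalizing each contribution by N makes the result match one large batch. A sketch with a simple step counter (illustrative only, not Lightning internals):

```python
def train_with_accumulation(batch_grads, accumulate=4):
    """Count optimizer steps when stepping only every `accumulate` batches."""
    steps = 0
    running = 0.0
    for i, grad in enumerate(batch_grads, start=1):
        running += grad / accumulate  # normalize so the sum matches one big batch
        if i % accumulate == 0:
            steps += 1       # optimizer.step() + zero_grad() would happen here
            running = 0.0
    return steps

print(train_with_accumulation([0.5] * 16, accumulate=4))  # 4 optimizer steps for 16 batches
```

Note that in DDP, gradient all-reduce can also be skipped on the intermediate batches, which is part of why accumulation reduces communication per sample.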
**Dynamic accumulation** (accumulate more early in training). In Lightning 2.x, per-epoch schedules use the `GradientAccumulationScheduler` callback rather than a dict passed to `accumulate_grad_batches`:
```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Epoch -> accumulation factor; each value applies from that epoch onward
scheduler = GradientAccumulationScheduler(scheduling={
    0: 8,    # Epochs 0-4: accumulate 8
    5: 4,    # Epochs 5-9: accumulate 4
    10: 2,   # Epochs 10+: accumulate 2
})

trainer = L.Trainer(callbacks=[scheduler])
```

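
The scheduling dict is interpreted as "from this epoch onward": the factor in effect at a given epoch is the value of the largest key not exceeding it. A sketch of that lookup (`factor_for_epoch` is an illustrative helper, not a Lightning API):

```python
def factor_for_epoch(scheduling, epoch):
    """Accumulation factor in effect at `epoch`: value of the largest key <= epoch."""
    applicable = [e for e in scheduling if e <= epoch]
    return scheduling[max(applicable)] if applicable else 1

sched = {0: 8, 5: 4, 10: 2}
print([factor_for_epoch(sched, e) for e in (0, 4, 5, 9, 10, 30)])  # [8, 8, 4, 4, 2, 2]
```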
## Checkpointing in Distributed Training

### Save Checkpoint

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Only rank 0 writes the checkpoint by default
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}',
    save_top_k=3,
)

trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp')
trainer.fit(model, train_loader)
```

**Manual save**:
```python
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = ...  # compute loss

        # Save every 1000 steps. Call save_checkpoint on ALL ranks:
        # it writes from rank 0 internally, and gating it behind
        # is_global_zero can deadlock sharded strategies.
        if batch_idx % 1000 == 0:
            self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt')

        return loss
```

### Load Checkpoint

```python
# Resume training
trainer = L.Trainer(strategy='ddp')
trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt')

# Load for inference
model = MyModel.load_from_checkpoint('checkpoints/best.ckpt')
model.eval()
```

## Strategy Comparison

| Strategy | Memory Efficiency | Speed | Use Case |
|----------|-------------------|-------|----------|
| DDP | Low | Fast | Small models (<7B), single node |
| FSDP | High | Medium | Large models (7-70B) |
| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) |
| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) |
| DDP Spawn | Low | Slow | Windows, notebooks, debugging |

## Best Practices

### 1. Choose the Right Strategy

```python
from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy

# Rough model-size guide (model_params = number of parameters)
if model_params < 1e9:        # <1B
    strategy = 'ddp'
elif model_params < 7e9:      # 1-7B
    strategy = DeepSpeedStrategy(stage=2)  # or plain 'ddp' if memory allows
elif model_params < 70e9:     # 7-70B
    strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
else:                         # 70B+
    strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True)

trainer = L.Trainer(strategy=strategy)
```

### 2. Avoid Sync Issues

```python
class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        # WRONG: this runs on every rank independently
        if batch_idx % 100 == 0:
            self.log_something()  # Logged 8 times on 8 GPUs!

        # CORRECT: gate side effects on the global-zero rank
        if batch_idx % 100 == 0 and self.trainer.is_global_zero:
            self.log_something()  # Logged once

        loss = ...
        return loss
```

Note that `self.log(..., sync_dist=True)` already aggregates a metric across ranks; the `is_global_zero` gate is for side effects such as printing or writing files.

### 3. Efficient Data Loading

```python
from torch.utils.data import DataLoader

# No manual DistributedSampler needed: Lightning wraps the loader
# with one automatically when using DDP
train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,            # Workers per GPU process
    pin_memory=True,          # Faster host-to-GPU copies
    persistent_workers=True,  # Keep workers alive between epochs
)

trainer.fit(model, train_loader)
```

### 4. Reduce Communication Overhead

```python
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    gradient_as_bucket_view=True,  # Avoid extra gradient copies
    static_graph=True,             # Only if the graph doesn't change between iterations
)

trainer = L.Trainer(strategy=strategy)
```

## Common Issues

### Issue: NCCL Timeout

**Symptom**: Training hangs and eventually fails with an NCCL timeout error

**Solution 1**: Increase the collective timeout (default is 30 minutes)
```python
from datetime import timedelta
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(timeout=timedelta(hours=1))
trainer = L.Trainer(strategy=strategy)
```

**Solution 2**: Check the network
```bash
# Check NVLink status on each node
nvidia-smi nvlink -s

# Verify all nodes can reach each other
ping <node-2-ip>
```

### Issue: OOM with FSDP

**Solution**: Enable CPU offload
```python
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    cpu_offload=True,  # Offload shards to CPU between uses (slower, less GPU memory)
)
```

### Issue: Different Results with DDP

**Cause**: Different random seeds per process

**Solution**: Seed everything before building the model and trainer
```python
import lightning as L

L.seed_everything(42, workers=True)  # Same seed on every rank and dataloader worker

model = MyModel()
trainer = L.Trainer(strategy='ddp')
```

### Issue: DeepSpeed Config Errors

**Solution**: Let Lightning generate the config
```python
strategy = DeepSpeedStrategy(
    stage=3,
    # Don't pass a config file; Lightning generates one from these arguments
)
```

## Resources

- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html
- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html
- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html