Compare commits
2 Commits
fix/602
...
feat/621-p
| Author | SHA1 | Date | |
|---|---|---|---|
| c39b80aec5 | |||
| ee6919c73c |
94
pipelines/README.md
Normal file
94
pipelines/README.md
Normal file
@@ -0,0 +1,94 @@
|
||||
# Pipeline Infrastructure
|
||||
|
||||
Shared orchestrator for all batch pipelines.
|
||||
|
||||
## Components
|
||||
|
||||
### orchestrator.py
|
||||
Shared orchestrator providing:
|
||||
- **Job Queue**: SQLite-backed with priority support
|
||||
- **Worker Pool**: Configurable parallelism (default 10)
|
||||
- **Token Budget**: Per-job tracking and limits
|
||||
- **Checkpointing**: Resume from any point after restart
|
||||
- **Rate Limiting**: Provider-aware request throttling
|
||||
- **Retry Logic**: Exponential backoff with configurable retries
|
||||
- **Reporting**: Generate summary reports
|
||||
|
||||
## Usage
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from pipelines.orchestrator import PipelineOrchestrator, JobPriority
|
||||
|
||||
# Create orchestrator
|
||||
orchestrator = PipelineOrchestrator(max_workers=10)
|
||||
|
||||
# Register pipeline handler
|
||||
def my_handler(job):
|
||||
# Process job.task
|
||||
return {"result": "done"}
|
||||
|
||||
orchestrator.register_handler("my_pipeline", my_handler)
|
||||
|
||||
# Submit jobs
|
||||
job_id = orchestrator.submit_job(
|
||||
pipeline="my_pipeline",
|
||||
task={"action": "process", "data": "..."},
|
||||
priority=JobPriority.HIGH,
|
||||
token_budget=100000
|
||||
)
|
||||
|
||||
# Run orchestrator
|
||||
orchestrator.run()
|
||||
```
|
||||
|
||||
### CLI
|
||||
```bash
|
||||
# Submit a job
|
||||
python -m pipelines.orchestrator submit my_pipeline --task '{"action": "process"}'
|
||||
|
||||
# Run orchestrator
|
||||
python -m pipelines.orchestrator run --workers 10 --max-jobs 100
|
||||
|
||||
# Check job status
|
||||
python -m pipelines.orchestrator status <job_id>
|
||||
|
||||
# Resume paused job
|
||||
python -m pipelines.orchestrator resume <job_id>
|
||||
|
||||
# Show stats
|
||||
python -m pipelines.orchestrator stats
|
||||
|
||||
# Generate report
|
||||
python -m pipelines.orchestrator report
|
||||
```
|
||||
|
||||
## Database
|
||||
|
||||
Jobs are stored in `~/.hermes/pipelines/orchestrator.db`:
|
||||
- `jobs` - Job queue and state
|
||||
- `checkpoints` - Resume points
|
||||
- `reports` - Generated reports
|
||||
|
||||
## Configuration
|
||||
|
||||
### Rate Limits
|
||||
```python
|
||||
orchestrator.configure_rate_limit("Nous", rpm=60, tpm=1000000)
|
||||
orchestrator.configure_rate_limit("Anthropic", rpm=50, tpm=800000)
|
||||
```
|
||||
|
||||
### Token Budgets
|
||||
Default: 1M tokens per job. Override per-job:
|
||||
```python
|
||||
orchestrator.submit_job("pipeline", task, token_budget=500000)
|
||||
```
|
||||
|
||||
## Pipelines
|
||||
|
||||
All pipelines share this orchestrator:
|
||||
1. **batch-runner** - Run prompts across datasets
|
||||
2. **data-gen** - Generate training data
|
||||
3. **eval-runner** - Run evaluations
|
||||
4. **trajectory-compress** - Compress trajectories
|
||||
5. **web-research** - Research tasks
|
||||
807
pipelines/orchestrator.py
Normal file
807
pipelines/orchestrator.py
Normal file
@@ -0,0 +1,807 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pipeline Orchestrator - Shared infrastructure for all pipelines.
|
||||
|
||||
Provides:
|
||||
- Job queue (SQLite-backed)
|
||||
- Parallel worker pool (configurable, default 10)
|
||||
- Token budget tracking per job
|
||||
- Progress persistence (resume from checkpoint)
|
||||
- Rate limiting (respect provider limits)
|
||||
- Error retry with exponential backoff
|
||||
- Final report generation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import hashlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
PIPELINES_DIR = HERMES_HOME / "pipelines"
|
||||
ORCHESTRATOR_DB = PIPELINES_DIR / "orchestrator.db"
|
||||
DEFAULT_WORKERS = 10
|
||||
DEFAULT_TOKEN_BUDGET = 1_000_000 # 1M tokens default
|
||||
|
||||
|
||||
class JobStatus(Enum):
    """Lifecycle states of a pipeline job (stored as the string value)."""
    PENDING = "pending"        # queued, waiting for a worker
    RUNNING = "running"        # currently executing in a worker thread
    COMPLETED = "completed"    # handler returned successfully
    FAILED = "failed"          # handler exhausted retries or hit a fatal error
    PAUSED = "paused"          # suspended; may be resumed from a checkpoint
    CANCELLED = "cancelled"    # terminated by user request
|
||||
|
||||
|
||||
class JobPriority(Enum):
    """Job scheduling priority; higher numeric value is dequeued first."""
    LOW = 0
    NORMAL = 5
    HIGH = 10
    CRITICAL = 20
|
||||
|
||||
|
||||
@dataclass
class JobCheckpoint:
    """Checkpoint for resumable job execution.

    Captures the step number and arbitrary step data so a job can be
    resumed after a pause or process restart.
    """
    job_id: str                  # id of the job this checkpoint belongs to
    step: int                    # last completed step
    data: Dict[str, Any]         # handler-defined resume payload
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the checkpoint to a plain (JSON-friendly) dict."""
        snapshot = asdict(self)
        return snapshot

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobCheckpoint':
        """Rebuild a checkpoint from the dict produced by to_dict()."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class TokenUsage:
    """Accumulated token counts and cost for a single job."""
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    cost_usd: float = 0.0

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens (cache counts are excluded)."""
        return self.output_tokens + self.input_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-friendly)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TokenUsage':
        """Build a TokenUsage from a dict; absent fields keep their zero defaults."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class Job:
    """A pipeline job: the unit of work tracked by the orchestrator.

    Serialized to/from the SQLite ``jobs`` table via to_dict/from_dict.
    """
    id: str                      # unique job identifier (uuid4 string)
    pipeline: str                # name of the registered handler to run
    task: Dict[str, Any]         # handler-specific task payload
    status: JobStatus = JobStatus.PENDING
    priority: JobPriority = JobPriority.NORMAL
    token_budget: int = DEFAULT_TOKEN_BUDGET
    token_usage: TokenUsage = field(default_factory=TokenUsage)
    retry_count: int = 0
    max_retries: int = 3
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[str] = None
    result: Optional[Dict[str, Any]] = None
    checkpoint: Optional[JobCheckpoint] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (enum members flattened to values)."""
        d = asdict(self)
        # asdict() leaves Enum members untouched; flatten them to raw values.
        d['status'] = self.status.value
        d['priority'] = self.priority.value
        d['token_usage'] = self.token_usage.to_dict()
        if self.checkpoint:
            d['checkpoint'] = self.checkpoint.to_dict()
        return d

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Rebuild a Job from to_dict() output.

        Operates on a shallow copy so the caller's dict is not mutated
        (the previous implementation rewrote 'status'/'priority'/
        'token_usage' in place, corrupting the caller's data).
        """
        data = dict(data)
        data['status'] = JobStatus(data['status'])
        data['priority'] = JobPriority(data['priority'])
        data['token_usage'] = TokenUsage.from_dict(data.get('token_usage', {}))
        if data.get('checkpoint'):
            data['checkpoint'] = JobCheckpoint.from_dict(data['checkpoint'])
        return cls(**data)
|
||||
|
||||
|
||||
class RateLimiter:
    """Sliding-window (60 s) rate limiter for API providers.

    Tracks request timestamps and token counts per provider and enforces
    both requests-per-minute (rpm) and tokens-per-minute (tpm) limits.
    Providers with no configured limit are never throttled.
    """

    def __init__(self):
        # provider -> {'rpm': int, 'tpm': int}
        self.limits: Dict[str, Dict[str, Any]] = {}
        # provider -> timestamps of requests within the last minute
        self.requests: Dict[str, List[float]] = {}
        # provider -> (timestamp, tokens) pairs for tpm enforcement
        self.token_log: Dict[str, List[Tuple[float, int]]] = {}

    def configure(self, provider: str, requests_per_minute: int, tokens_per_minute: int):
        """Configure rate limits for a provider."""
        self.limits[provider] = {
            'rpm': requests_per_minute,
            'tpm': tokens_per_minute,
        }
        self.requests.setdefault(provider, [])
        self.token_log.setdefault(provider, [])

    def can_proceed(self, provider: str, tokens: int = 0) -> Tuple[bool, float]:
        """Check whether a request may proceed.

        Args:
            provider: Provider name as passed to configure().
            tokens: Estimated tokens the request will consume. 0 skips
                the tpm check, preserving the original rpm-only behavior
                for callers that do not supply an estimate.

        Returns:
            (can_proceed, wait_seconds); wait_seconds is how long to
            sleep until the oldest in-window event expires.
        """
        if provider not in self.limits:
            return True, 0.0

        now = time.time()
        minute_ago = now - 60

        # Drop events that have fallen out of the 60-second window.
        self.requests[provider] = [t for t in self.requests[provider] if t > minute_ago]
        self.token_log[provider] = [
            (t, n) for t, n in self.token_log.get(provider, []) if t > minute_ago
        ]

        limit = self.limits[provider]

        # RPM check.
        if len(self.requests[provider]) >= limit['rpm']:
            oldest = min(self.requests[provider])
            return False, max(0.0, 60 - (now - oldest))

        # TPM check. The previous implementation accepted a tpm limit in
        # configure() but never enforced it.
        if tokens:
            used = sum(n for _, n in self.token_log[provider])
            if used + tokens > limit['tpm']:
                if self.token_log[provider]:
                    oldest = min(t for t, _ in self.token_log[provider])
                    return False, max(0.0, 60 - (now - oldest))
                # A single request larger than the whole tpm budget can
                # never fit; report no wait rather than an infinite one.
                return False, 0.0

        return True, 0.0

    def record_request(self, provider: str, tokens: int = 0):
        """Record a request (and its token usage, when known)."""
        now = time.time()
        self.requests.setdefault(provider, []).append(now)
        if tokens:
            self.token_log.setdefault(provider, []).append((now, tokens))
|
||||
|
||||
|
||||
class OrchestratorDB:
    """SQLite-backed job queue and state management.

    Persists jobs, checkpoints and reports. Each public method opens a
    short-lived connection, so one instance can be shared across worker
    threads (SQLite serializes the writers).
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS jobs (
        id TEXT PRIMARY KEY,
        pipeline TEXT NOT NULL,
        task TEXT NOT NULL,
        status TEXT NOT NULL,
        priority INTEGER NOT NULL,
        token_budget INTEGER NOT NULL,
        token_usage TEXT NOT NULL,
        retry_count INTEGER DEFAULT 0,
        max_retries INTEGER DEFAULT 3,
        created_at REAL NOT NULL,
        started_at REAL,
        completed_at REAL,
        error TEXT,
        result TEXT,
        checkpoint TEXT,
        metadata TEXT
    );

    CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
    CREATE INDEX IF NOT EXISTS idx_jobs_pipeline ON jobs(pipeline);
    CREATE INDEX IF NOT EXISTS idx_jobs_priority ON jobs(priority DESC);

    CREATE TABLE IF NOT EXISTS checkpoints (
        job_id TEXT PRIMARY KEY,
        step INTEGER NOT NULL,
        data TEXT NOT NULL,
        timestamp REAL NOT NULL,
        FOREIGN KEY (job_id) REFERENCES jobs(id)
    );

    CREATE TABLE IF NOT EXISTS reports (
        id TEXT PRIMARY KEY,
        pipeline TEXT NOT NULL,
        job_ids TEXT NOT NULL,
        summary TEXT NOT NULL,
        token_usage TEXT NOT NULL,
        created_at REAL NOT NULL
    );
    """

    def __init__(self, db_path: Optional[Path] = None):
        """Open (creating if needed) the orchestrator database.

        Args:
            db_path: Override for the database file; defaults to
                ``ORCHESTRATOR_DB``.
        """
        self.db_path = db_path or ORCHESTRATOR_DB
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Create tables and indexes if they do not already exist."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            conn.executescript(self.SCHEMA)
            conn.commit()
        finally:
            conn.close()

    def _get_conn(self) -> sqlite3.Connection:
        """Return a new connection with dict-style row access."""
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _row_to_job(row: sqlite3.Row) -> Job:
        """Deserialize a ``jobs`` table row into a Job.

        Shared by get_job / get_next_job / get_jobs_by_status, which
        previously each carried their own copy of this mapping.
        """
        return Job(
            id=row['id'],
            pipeline=row['pipeline'],
            task=json.loads(row['task']),
            status=JobStatus(row['status']),
            priority=JobPriority(row['priority']),
            token_budget=row['token_budget'],
            token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])),
            retry_count=row['retry_count'],
            max_retries=row['max_retries'],
            created_at=row['created_at'],
            started_at=row['started_at'],
            completed_at=row['completed_at'],
            error=row['error'],
            result=json.loads(row['result']) if row['result'] else None,
            checkpoint=JobCheckpoint.from_dict(json.loads(row['checkpoint'])) if row['checkpoint'] else None,
            metadata=json.loads(row['metadata']) if row['metadata'] else {}
        )

    def save_job(self, job: Job):
        """Insert or update *job* (upsert keyed on job.id)."""
        conn = self._get_conn()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO jobs
                (id, pipeline, task, status, priority, token_budget, token_usage,
                 retry_count, max_retries, created_at, started_at, completed_at,
                 error, result, checkpoint, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                job.id, job.pipeline, json.dumps(job.task), job.status.value,
                job.priority.value, job.token_budget, json.dumps(job.token_usage.to_dict()),
                job.retry_count, job.max_retries, job.created_at, job.started_at,
                job.completed_at, job.error, json.dumps(job.result) if job.result else None,
                json.dumps(job.checkpoint.to_dict()) if job.checkpoint else None,
                json.dumps(job.metadata)
            ))
            conn.commit()
        finally:
            conn.close()

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get a job by ID, or None if it does not exist."""
        conn = self._get_conn()
        try:
            row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
        finally:
            conn.close()

        return self._row_to_job(row) if row else None

    def get_next_job(self, pipeline: Optional[str] = None) -> Optional[Job]:
        """Get the next pending job (highest priority, then oldest first).

        Args:
            pipeline: Restrict to jobs of this pipeline when given.
        """
        query = "SELECT * FROM jobs WHERE status = 'pending'"
        params: List[Any] = []

        if pipeline:
            query += " AND pipeline = ?"
            params.append(pipeline)

        query += " ORDER BY priority DESC, created_at ASC LIMIT 1"

        conn = self._get_conn()
        try:
            row = conn.execute(query, params).fetchone()
        finally:
            conn.close()

        return self._row_to_job(row) if row else None

    def get_jobs_by_status(self, status: JobStatus, pipeline: Optional[str] = None) -> List[Job]:
        """Get all jobs with the given status, highest priority first."""
        query = "SELECT * FROM jobs WHERE status = ?"
        params: List[Any] = [status.value]

        if pipeline:
            query += " AND pipeline = ?"
            params.append(pipeline)

        query += " ORDER BY priority DESC, created_at ASC"

        conn = self._get_conn()
        try:
            rows = conn.execute(query, params).fetchall()
        finally:
            conn.close()

        return [self._row_to_job(row) for row in rows]

    def save_checkpoint(self, job_id: str, checkpoint: JobCheckpoint):
        """Save (or overwrite) the checkpoint for a job."""
        conn = self._get_conn()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO checkpoints (job_id, step, data, timestamp)
                VALUES (?, ?, ?, ?)
            """, (job_id, checkpoint.step, json.dumps(checkpoint.data), checkpoint.timestamp))
            conn.commit()
        finally:
            conn.close()

    def get_checkpoint(self, job_id: str) -> Optional[JobCheckpoint]:
        """Get the latest checkpoint for a job, or None."""
        conn = self._get_conn()
        try:
            row = conn.execute(
                "SELECT * FROM checkpoints WHERE job_id = ? ORDER BY step DESC LIMIT 1",
                (job_id,)
            ).fetchone()
        finally:
            conn.close()

        if not row:
            return None

        return JobCheckpoint(
            job_id=row['job_id'],
            step=row['step'],
            data=json.loads(row['data']),
            timestamp=row['timestamp']
        )

    def get_stats(self, pipeline: Optional[str] = None) -> Dict[str, Any]:
        """Return per-status job counts plus aggregate token totals.

        Keys: one entry per status value seen, plus 'total_input_tokens',
        'total_output_tokens' and 'total_tokens'.
        """
        count_query = "SELECT status, COUNT(*) as count FROM jobs"
        params: List[Any] = []

        if pipeline:
            count_query += " WHERE pipeline = ?"
            params.append(pipeline)

        count_query += " GROUP BY status"

        token_query = (
            "SELECT SUM(CAST(json_extract(token_usage, '$.input_tokens') AS INTEGER)) as input, "
            "SUM(CAST(json_extract(token_usage, '$.output_tokens') AS INTEGER)) as output FROM jobs"
        )
        if pipeline:
            token_query += " WHERE pipeline = ?"

        conn = self._get_conn()
        try:
            rows = conn.execute(count_query, params).fetchall()
            token_row = conn.execute(token_query, params).fetchone()
        finally:
            conn.close()

        stats: Dict[str, Any] = {row['status']: row['count'] for row in rows}
        # SUM() is NULL on an empty table; coerce to 0.
        stats['total_input_tokens'] = token_row['input'] or 0
        stats['total_output_tokens'] = token_row['output'] or 0
        stats['total_tokens'] = stats['total_input_tokens'] + stats['total_output_tokens']

        return stats
|
||||
|
||||
|
||||
class PipelineOrchestrator:
    """Main orchestrator for pipeline execution.

    Pulls jobs from the SQLite-backed queue, dispatches them to
    registered per-pipeline handlers on a thread pool, and applies token
    budgets, provider rate limits, and retry with exponential backoff.
    """

    def __init__(
        self,
        max_workers: int = DEFAULT_WORKERS,
        token_budget: int = DEFAULT_TOKEN_BUDGET,
        db_path: Optional[Path] = None
    ):
        """Create an orchestrator.

        Args:
            max_workers: Thread-pool size used by run().
            token_budget: Default per-job token budget.
            db_path: Override for the job database location.
        """
        self.max_workers = max_workers
        self.token_budget = token_budget
        self.db = OrchestratorDB(db_path)
        self.rate_limiter = RateLimiter()
        self.executor: Optional[ThreadPoolExecutor] = None
        self.running = False
        self._handlers: Dict[str, Callable] = {}  # pipeline name -> handler

    def register_handler(self, pipeline: str, handler: Callable):
        """Register a handler for a pipeline type.

        The handler is called with the Job and should return a result
        dict; if that dict contains a 'token_usage' key, it is folded
        into the job's usage counters.
        """
        self._handlers[pipeline] = handler

    def configure_rate_limit(self, provider: str, rpm: int, tpm: int):
        """Configure rate limits for a provider."""
        self.rate_limiter.configure(provider, rpm, tpm)

    def submit_job(
        self,
        pipeline: str,
        task: Dict[str, Any],
        priority: JobPriority = JobPriority.NORMAL,
        token_budget: Optional[int] = None,
        max_retries: int = 3,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Submit a new job to the queue and return its ID."""
        job_id = str(uuid.uuid4())

        job = Job(
            id=job_id,
            pipeline=pipeline,
            task=task,
            priority=priority,
            token_budget=token_budget or self.token_budget,
            max_retries=max_retries,
            metadata=metadata or {}
        )

        self.db.save_job(job)
        logger.info(f"Job {job_id} submitted to pipeline {pipeline}")

        return job_id

    def submit_batch(
        self,
        pipeline: str,
        tasks: List[Dict[str, Any]],
        priority: JobPriority = JobPriority.NORMAL,
        token_budget: Optional[int] = None
    ) -> List[str]:
        """Submit multiple jobs at once; returns their IDs in order."""
        job_ids = []
        for task in tasks:
            job_ids.append(self.submit_job(pipeline, task, priority, token_budget))

        logger.info(f"Submitted {len(job_ids)} jobs to pipeline {pipeline}")
        return job_ids

    def _fail(self, job: Job, error: str) -> Job:
        """Mark *job* failed with *error*, persist it, and return it."""
        job.status = JobStatus.FAILED
        job.error = error
        job.completed_at = time.time()
        self.db.save_job(job)
        return job

    def _execute_job(self, job: Job) -> Job:
        """Execute a single job with rate limiting and retry logic.

        Persists every state transition; returns the job in a terminal
        state (COMPLETED or FAILED).
        """
        handler = self._handlers.get(job.pipeline)
        if not handler:
            return self._fail(job, f"No handler registered for pipeline: {job.pipeline}")

        # Refuse to (re)start a job that has already spent its budget.
        if job.token_usage.total_tokens >= job.token_budget:
            return self._fail(job, "Token budget exceeded")

        job.status = JobStatus.RUNNING
        job.started_at = time.time()
        self.db.save_job(job)

        try:
            # Execute with exponential backoff retry.
            for attempt in range(job.max_retries + 1):
                try:
                    # Respect provider rate limits before calling out.
                    provider = job.metadata.get('provider', 'default')
                    can_proceed, wait_time = self.rate_limiter.can_proceed(provider)

                    if not can_proceed:
                        logger.info(f"Rate limited, waiting {wait_time:.1f}s")
                        time.sleep(wait_time)

                    result = handler(job)

                    self.rate_limiter.record_request(provider)

                    job.result = result
                    job.status = JobStatus.COMPLETED
                    job.completed_at = time.time()

                    # Fold reported usage into the job's counters. Guard
                    # the type: a handler returning None previously
                    # raised TypeError here and triggered bogus retries.
                    if isinstance(result, dict) and 'token_usage' in result:
                        usage = result['token_usage']
                        job.token_usage.input_tokens += usage.get('input_tokens', 0)
                        job.token_usage.output_tokens += usage.get('output_tokens', 0)
                        job.token_usage.cache_read_tokens += usage.get('cache_read_tokens', 0)
                        job.token_usage.cache_write_tokens += usage.get('cache_write_tokens', 0)
                        job.token_usage.cost_usd += usage.get('cost_usd', 0)

                    self.db.save_job(job)
                    return job

                except Exception as e:
                    job.retry_count += 1

                    if job.retry_count >= job.max_retries:
                        return self._fail(job, str(e))

                    # Exponential backoff with sub-second jitter.
                    wait_time = (2 ** job.retry_count) + (time.time() % 1)
                    logger.warning(f"Job {job.id} failed (attempt {job.retry_count}), retrying in {wait_time:.1f}s: {e}")
                    time.sleep(wait_time)

        except Exception as e:
            return self._fail(job, f"Unexpected error: {str(e)}")

        return job

    @staticmethod
    def _log_worker_error(future):
        """Surface any exception that escaped a worker thread."""
        exc = future.exception()
        if exc is not None:
            logger.error("Worker raised unexpectedly: %s", exc)

    def run(self, pipeline: Optional[str] = None, max_jobs: Optional[int] = None):
        """Run the orchestrator, processing jobs from the queue.

        Args:
            pipeline: Only process jobs for this pipeline when given.
            max_jobs: Stop dispatching after this many jobs.
        """
        self.running = True
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)

        logger.info(f"Orchestrator starting (workers={self.max_workers})")

        jobs_processed = 0
        try:
            while self.running:
                job = self.db.get_next_job(pipeline)

                if not job:
                    # No pending jobs; poll again shortly.
                    time.sleep(1)
                    continue

                # Claim the job BEFORE dispatching. The worker flips it
                # to RUNNING asynchronously, so without this the next
                # get_next_job() call could return the same still-pending
                # job and execute it twice.
                job.status = JobStatus.RUNNING
                self.db.save_job(job)

                future = self.executor.submit(self._execute_job, job)
                # Don't wait for completion, but don't lose exceptions.
                future.add_done_callback(self._log_worker_error)

                jobs_processed += 1

                if max_jobs and jobs_processed >= max_jobs:
                    logger.info(f"Reached max_jobs limit ({max_jobs})")
                    break

        finally:
            self.executor.shutdown(wait=True)
            self.running = False
            logger.info(f"Orchestrator stopped (processed {jobs_processed} jobs)")

    def run_single(self, job_id: str) -> Job:
        """Run a single job by ID (useful for resume).

        Raises:
            ValueError: If the job is missing or not pending/failed.
        """
        job = self.db.get_job(job_id)

        if not job:
            raise ValueError(f"Job not found: {job_id}")

        if job.status not in (JobStatus.PENDING, JobStatus.FAILED):
            raise ValueError(f"Job {job_id} is not pending or failed (status: {job.status})")

        # A failed job gets a clean slate before re-execution.
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            self.db.save_job(job)

        return self._execute_job(job)

    def pause_job(self, job_id: str):
        """Pause a running job (no-op for other states)."""
        job = self.db.get_job(job_id)
        if job and job.status == JobStatus.RUNNING:
            job.status = JobStatus.PAUSED
            self.db.save_job(job)
            logger.info(f"Job {job_id} paused")

    def resume_job(self, job_id: str) -> Job:
        """Resume a paused job from its latest checkpoint.

        Raises:
            ValueError: If the job is missing or not paused.
        """
        job = self.db.get_job(job_id)

        if not job:
            raise ValueError(f"Job not found: {job_id}")

        if job.status != JobStatus.PAUSED:
            raise ValueError(f"Job {job_id} is not paused (status: {job.status})")

        # Attach the latest checkpoint so the handler can pick up there.
        checkpoint = self.db.get_checkpoint(job_id)
        if checkpoint:
            job.checkpoint = checkpoint
            logger.info(f"Resuming job {job_id} from checkpoint step {checkpoint.step}")

        job.status = JobStatus.PENDING
        self.db.save_job(job)

        return self._execute_job(job)

    def cancel_job(self, job_id: str):
        """Cancel a pending/running/paused job (no-op for other states)."""
        job = self.db.get_job(job_id)
        if job and job.status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.PAUSED):
            job.status = JobStatus.CANCELLED
            job.completed_at = time.time()
            self.db.save_job(job)
            logger.info(f"Job {job_id} cancelled")

    def get_progress(self, job_id: str) -> Dict[str, Any]:
        """Return a progress summary dict for a job.

        Includes status, retry count, token usage/budget/percentage and
        timestamps; checkpoint and error entries when present.
        """
        job = self.db.get_job(job_id)

        if not job:
            return {"error": "Job not found"}

        progress = {
            "job_id": job.id,
            "pipeline": job.pipeline,
            "status": job.status.value,
            "retry_count": job.retry_count,
            "token_usage": job.token_usage.to_dict(),
            "token_budget": job.token_budget,
            "token_percent": (job.token_usage.total_tokens / job.token_budget * 100) if job.token_budget > 0 else 0,
            "created_at": job.created_at,
            "started_at": job.started_at,
            "completed_at": job.completed_at,
        }

        if job.checkpoint:
            progress["checkpoint"] = {
                "step": job.checkpoint.step,
                "timestamp": job.checkpoint.timestamp,
            }

        if job.error:
            progress["error"] = job.error

        return progress

    def generate_report(self, pipeline: Optional[str] = None) -> Dict[str, Any]:
        """Generate a summary report (counts, success rate, timings, tokens)."""
        stats = self.db.get_stats(pipeline)

        completed_jobs = self.db.get_jobs_by_status(JobStatus.COMPLETED, pipeline)
        failed_jobs = self.db.get_jobs_by_status(JobStatus.FAILED, pipeline)

        # Wall-clock duration per completed job (skip incomplete timing).
        durations = [
            job.completed_at - job.started_at
            for job in completed_jobs
            if job.started_at and job.completed_at
        ]

        report = {
            "timestamp": datetime.now().isoformat(),
            "pipeline": pipeline or "all",
            "stats": stats,
            "completed": len(completed_jobs),
            "failed": len(failed_jobs),
            "success_rate": len(completed_jobs) / (len(completed_jobs) + len(failed_jobs)) * 100 if (completed_jobs or failed_jobs) else 0,
            "avg_duration": sum(durations) / len(durations) if durations else 0,
            "total_tokens": stats.get('total_tokens', 0),
        }

        return report

    def stop(self):
        """Request orchestrator shutdown; run() exits after the current pass."""
        self.running = False
        logger.info("Orchestrator stop requested")
|
||||
|
||||
|
||||
# CLI interface
|
||||
def main():
    """CLI for orchestrator management.

    Subcommands: submit, run, status, resume, cancel, stats, report.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Pipeline Orchestrator")
    subparsers = parser.add_subparsers(dest="command")

    # Submit job
    submit_parser = subparsers.add_parser("submit", help="Submit a job")
    submit_parser.add_argument("pipeline", help="Pipeline name")
    submit_parser.add_argument("--task", required=True, help="Task JSON")
    submit_parser.add_argument("--priority", type=int, default=5, help="Job priority")
    submit_parser.add_argument("--budget", type=int, help="Token budget")

    # Run orchestrator
    run_parser = subparsers.add_parser("run", help="Run orchestrator")
    run_parser.add_argument("--pipeline", help="Filter by pipeline")
    run_parser.add_argument("--workers", type=int, default=10, help="Max workers")
    run_parser.add_argument("--max-jobs", type=int, help="Max jobs to process")

    # Job management
    status_parser = subparsers.add_parser("status", help="Get job status")
    status_parser.add_argument("job_id", help="Job ID")

    resume_parser = subparsers.add_parser("resume", help="Resume paused job")
    resume_parser.add_argument("job_id", help="Job ID")

    cancel_parser = subparsers.add_parser("cancel", help="Cancel job")
    cancel_parser.add_argument("job_id", help="Job ID")

    # Stats
    stats_parser = subparsers.add_parser("stats", help="Show queue stats")
    stats_parser.add_argument("--pipeline", help="Filter by pipeline")

    # Report
    report_parser = subparsers.add_parser("report", help="Generate report")
    report_parser.add_argument("--pipeline", help="Filter by pipeline")

    args = parser.parse_args()

    # Only the 'run' subcommand defines --workers; other commands fall
    # back to the default pool size.
    orchestrator = PipelineOrchestrator(max_workers=getattr(args, 'workers', 10))

    if args.command == "submit":
        task = json.loads(args.task)
        # JobPriority(...) raised an unhandled ValueError traceback for
        # any value other than the defined ones; report it cleanly.
        try:
            priority = JobPriority(args.priority)
        except ValueError:
            valid = ", ".join(f"{p.name}={p.value}" for p in JobPriority)
            parser.error(f"invalid --priority {args.priority}; valid values: {valid}")
        job_id = orchestrator.submit_job(args.pipeline, task, priority, args.budget)
        print(f"Job submitted: {job_id}")

    elif args.command == "run":
        orchestrator.run(args.pipeline, args.max_jobs)

    elif args.command == "status":
        progress = orchestrator.get_progress(args.job_id)
        print(json.dumps(progress, indent=2))

    elif args.command == "resume":
        job = orchestrator.resume_job(args.job_id)
        print(f"Job {args.job_id} completed with status: {job.status.value}")

    elif args.command == "cancel":
        orchestrator.cancel_job(args.job_id)
        print(f"Job {args.job_id} cancelled")

    elif args.command == "stats":
        stats = orchestrator.db.get_stats(args.pipeline)
        print(json.dumps(stats, indent=2))

    elif args.command == "report":
        report = orchestrator.generate_report(args.pipeline)
        print(json.dumps(report, indent=2))

    else:
        parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user