# Fixes #621 - Fix DEFAULT_TOKEN_BUDGET syntax error - Resume paused/running jobs
# with checkpoints on restart - Proper future collection and drain in run() -
# Add 'list' CLI command for job inspection - Throttle when at worker capacity
#!/usr/bin/env python3
"""
Pipeline Orchestrator - Shared infrastructure for all pipelines.

Provides:
- Job queue (SQLite-backed)
- Parallel worker pool (configurable, default 10)
- Token budget tracking per job
- Progress persistence (resume from checkpoint)
- Rate limiting (respect provider limits)
- Error retry with exponential backoff
- Final report generation
"""
|
|
|
|
import asyncio
import hashlib
import json
import logging
import os
import sqlite3
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
# Module-level logger; consumers configure handlers/levels externally.
logger = logging.getLogger(__name__)

# Configuration
HERMES_HOME = Path.home() / ".hermes"            # root directory for orchestrator state
PIPELINES_DIR = HERMES_HOME / "pipelines"        # per-pipeline data lives under here
ORCHESTRATOR_DB = PIPELINES_DIR / "orchestrator.db"  # default SQLite database path
DEFAULT_WORKERS = 10                             # default thread-pool size for run()
DEFAULT_TOKEN_BUDGET = 1_000_000  # 1M tokens default
|
|
|
|
|
|
class JobStatus(Enum):
    """Lifecycle states for a pipeline job (stored as TEXT in SQLite)."""
    PENDING = "pending"      # queued, waiting to be picked up by a worker
    RUNNING = "running"      # currently executing in a worker thread
    COMPLETED = "completed"  # handler returned successfully
    FAILED = "failed"        # retries exhausted or a fatal error occurred
    PAUSED = "paused"        # suspended; may resume from a checkpoint
    CANCELLED = "cancelled"  # explicitly cancelled via cancel_job()
|
|
|
|
|
|
class JobPriority(Enum):
    """Scheduling priority; higher values dequeue first (ORDER BY priority DESC)."""
    LOW = 0
    NORMAL = 5
    HIGH = 10
    CRITICAL = 20
|
|
|
|
|
|
@dataclass
class JobCheckpoint:
    """Checkpoint for resumable job execution.

    Persisted by OrchestratorDB.save_checkpoint so a job can be resumed from
    `step` with handler-defined `data` after a pause or orchestrator restart.
    """
    job_id: str           # id of the job this checkpoint belongs to
    step: int             # last completed step number
    data: Dict[str, Any]  # handler-defined state needed to resume
    timestamp: float = field(default_factory=time.time)  # creation time, epoch seconds

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobCheckpoint':
        """Rebuild from to_dict() output (unknown keys raise TypeError)."""
        return cls(**data)
|
|
|
|
|
|
@dataclass
class TokenUsage:
    """Cumulative token usage and cost for a single job."""
    input_tokens: int = 0        # prompt tokens consumed
    output_tokens: int = 0       # completion tokens produced
    cache_read_tokens: int = 0   # tokens served from provider cache reads
    cache_write_tokens: int = 0  # tokens written to provider cache
    cost_usd: float = 0.0        # accumulated dollar cost

    @property
    def total_tokens(self) -> int:
        """Input + output tokens; used for budget checks (cache tokens excluded)."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TokenUsage':
        """Rebuild from to_dict() output; missing keys fall back to field defaults."""
        return cls(**data)
|
|
|
|
|
|
@dataclass
class Job:
    """A pipeline job.

    Carries the task payload, scheduling state (status/priority), the token
    budget and accumulated usage, retry bookkeeping, timing fields, and an
    optional checkpoint for resumable execution.
    """
    id: str                # unique job id (uuid4 string)
    pipeline: str          # pipeline name; selects the registered handler
    task: Dict[str, Any]   # arbitrary task payload passed to the handler
    status: JobStatus = JobStatus.PENDING
    priority: JobPriority = JobPriority.NORMAL
    token_budget: int = DEFAULT_TOKEN_BUDGET
    token_usage: TokenUsage = field(default_factory=TokenUsage)
    retry_count: int = 0   # attempts consumed so far
    max_retries: int = 3   # give up after this many failed attempts
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None    # set when a worker picks the job up
    completed_at: Optional[float] = None  # set on completion/failure/cancel
    error: Optional[str] = None           # last error message, if any
    result: Optional[Dict[str, Any]] = None   # handler return value
    checkpoint: Optional[JobCheckpoint] = None
    metadata: Dict[str, Any] = field(default_factory=dict)  # e.g. {'provider': ...}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums -> values, nested -> dicts)."""
        d = asdict(self)
        d['status'] = self.status.value
        d['priority'] = self.priority.value
        d['token_usage'] = self.token_usage.to_dict()
        if self.checkpoint:
            d['checkpoint'] = self.checkpoint.to_dict()
        return d

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Rebuild a Job from to_dict() output.

        Fix: operate on a shallow copy — the original implementation replaced
        values inside the caller's dict (status/priority/... became enum and
        dataclass objects), mutating shared input in place.
        """
        data = dict(data)
        data['status'] = JobStatus(data['status'])
        data['priority'] = JobPriority(data['priority'])
        data['token_usage'] = TokenUsage.from_dict(data.get('token_usage', {}))
        if data.get('checkpoint'):
            data['checkpoint'] = JobCheckpoint.from_dict(data['checkpoint'])
        return cls(**data)
|
|
|
|
|
|
class RateLimiter:
    """Sliding-window (60 s) rate limiter for API providers.

    Enforces both requests-per-minute (RPM) and tokens-per-minute (TPM).
    Providers with no configured limits are never throttled.

    Fix: the original accepted a `tokens` argument and a tokens_per_minute
    configuration but never enforced or recorded token counts; the TPM limit
    was dead configuration.
    """

    def __init__(self):
        # provider -> {'rpm': int, 'tpm': int}
        self.limits: Dict[str, Dict[str, Any]] = {}
        # provider -> list of (timestamp, tokens) for requests in the window
        self.requests: Dict[str, List[Tuple[float, int]]] = {}

    def configure(self, provider: str, requests_per_minute: int, tokens_per_minute: int):
        """Configure rate limits for a provider."""
        self.limits[provider] = {
            'rpm': requests_per_minute,
            'tpm': tokens_per_minute,
        }
        self.requests.setdefault(provider, [])

    def can_proceed(self, provider: str, tokens: int = 0) -> Tuple[bool, float]:
        """Check if a request may proceed. Returns (can_proceed, wait_seconds).

        `tokens` is the estimated cost of the pending request; it counts
        against the TPM limit together with tokens already recorded in the
        current 60 s window.
        """
        if provider not in self.limits:
            return True, 0.0

        now = time.time()
        minute_ago = now - 60

        # Drop requests that fell out of the 60 s window
        window = [(t, tok) for (t, tok) in self.requests.get(provider, []) if t > minute_ago]
        self.requests[provider] = window

        limit = self.limits[provider]

        # RPM check: wait until the oldest request ages out of the window
        if len(window) >= limit['rpm']:
            oldest = min(t for t, _ in window)
            return False, max(0.0, 60 - (now - oldest))

        # TPM check. Only applied when the window is non-empty so that a
        # single oversized request cannot deadlock (it would otherwise wait
        # forever with nothing to age out).
        if window and sum(tok for _, tok in window) + tokens > limit['tpm']:
            oldest = min(t for t, _ in window)
            return False, max(0.0, 60 - (now - oldest))

        return True, 0.0

    def record_request(self, provider: str, tokens: int = 0):
        """Record a completed request and its token cost."""
        self.requests.setdefault(provider, []).append((time.time(), tokens))
|
|
|
|
|
|
class OrchestratorDB:
    """SQLite-backed job queue and state management.

    Opens a fresh connection per operation (no pooling), which keeps the
    class usable from multiple worker threads without sharing connections.
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS jobs (
        id TEXT PRIMARY KEY,
        pipeline TEXT NOT NULL,
        task TEXT NOT NULL,
        status TEXT NOT NULL,
        priority INTEGER NOT NULL,
        token_budget INTEGER NOT NULL,
        token_usage TEXT NOT NULL,
        retry_count INTEGER DEFAULT 0,
        max_retries INTEGER DEFAULT 3,
        created_at REAL NOT NULL,
        started_at REAL,
        completed_at REAL,
        error TEXT,
        result TEXT,
        checkpoint TEXT,
        metadata TEXT
    );

    CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
    CREATE INDEX IF NOT EXISTS idx_jobs_pipeline ON jobs(pipeline);
    CREATE INDEX IF NOT EXISTS idx_jobs_priority ON jobs(priority DESC);

    CREATE TABLE IF NOT EXISTS checkpoints (
        job_id TEXT PRIMARY KEY,
        step INTEGER NOT NULL,
        data TEXT NOT NULL,
        timestamp REAL NOT NULL,
        FOREIGN KEY (job_id) REFERENCES jobs(id)
    );

    CREATE TABLE IF NOT EXISTS reports (
        id TEXT PRIMARY KEY,
        pipeline TEXT NOT NULL,
        job_ids TEXT NOT NULL,
        summary TEXT NOT NULL,
        token_usage TEXT NOT NULL,
        created_at REAL NOT NULL
    );
    """

    def __init__(self, db_path: Optional[Path] = None):
        self.db_path = db_path or ORCHESTRATOR_DB
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Create tables and indexes if missing (idempotent)."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            conn.executescript(self.SCHEMA)
            conn.commit()
        finally:
            # Fix: connection previously leaked if executescript raised.
            conn.close()

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection with dict-like row access."""
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _row_to_job(row: sqlite3.Row) -> Job:
        """Deserialize a jobs-table row into a Job.

        Fix: this mapping was duplicated verbatim in get_job, get_next_job
        and get_jobs_by_status; any schema change had to be made three times.
        """
        return Job(
            id=row['id'],
            pipeline=row['pipeline'],
            task=json.loads(row['task']),
            status=JobStatus(row['status']),
            priority=JobPriority(row['priority']),
            token_budget=row['token_budget'],
            token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])),
            retry_count=row['retry_count'],
            max_retries=row['max_retries'],
            created_at=row['created_at'],
            started_at=row['started_at'],
            completed_at=row['completed_at'],
            error=row['error'],
            result=json.loads(row['result']) if row['result'] else None,
            checkpoint=JobCheckpoint.from_dict(json.loads(row['checkpoint'])) if row['checkpoint'] else None,
            metadata=json.loads(row['metadata']) if row['metadata'] else {}
        )

    def save_job(self, job: Job):
        """Insert or update a job (full-row upsert keyed on id)."""
        conn = self._get_conn()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO jobs
                (id, pipeline, task, status, priority, token_budget, token_usage,
                 retry_count, max_retries, created_at, started_at, completed_at,
                 error, result, checkpoint, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                job.id, job.pipeline, json.dumps(job.task), job.status.value,
                job.priority.value, job.token_budget, json.dumps(job.token_usage.to_dict()),
                job.retry_count, job.max_retries, job.created_at, job.started_at,
                job.completed_at, job.error, json.dumps(job.result) if job.result else None,
                json.dumps(job.checkpoint.to_dict()) if job.checkpoint else None,
                json.dumps(job.metadata)
            ))
            conn.commit()
        finally:
            conn.close()

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get a job by ID, or None if it does not exist."""
        conn = self._get_conn()
        try:
            row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
        finally:
            conn.close()
        return self._row_to_job(row) if row else None

    def get_next_job(self, pipeline: Optional[str] = None) -> Optional[Job]:
        """Get the next pending job (highest priority first, then oldest)."""
        query = "SELECT * FROM jobs WHERE status = 'pending'"
        params: List[Any] = []
        if pipeline:
            query += " AND pipeline = ?"
            params.append(pipeline)
        query += " ORDER BY priority DESC, created_at ASC LIMIT 1"

        conn = self._get_conn()
        try:
            row = conn.execute(query, params).fetchone()
        finally:
            conn.close()
        return self._row_to_job(row) if row else None

    def get_jobs_by_status(self, status: JobStatus, pipeline: Optional[str] = None) -> List[Job]:
        """Get all jobs with the given status, optionally filtered by pipeline."""
        query = "SELECT * FROM jobs WHERE status = ?"
        params: List[Any] = [status.value]
        if pipeline:
            query += " AND pipeline = ?"
            params.append(pipeline)
        query += " ORDER BY priority DESC, created_at ASC"

        conn = self._get_conn()
        try:
            rows = conn.execute(query, params).fetchall()
        finally:
            conn.close()
        return [self._row_to_job(row) for row in rows]

    def save_checkpoint(self, job_id: str, checkpoint: JobCheckpoint):
        """Save (upsert) the checkpoint for a job."""
        conn = self._get_conn()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO checkpoints (job_id, step, data, timestamp)
                VALUES (?, ?, ?, ?)
            """, (job_id, checkpoint.step, json.dumps(checkpoint.data), checkpoint.timestamp))
            conn.commit()
        finally:
            conn.close()

    def get_checkpoint(self, job_id: str) -> Optional[JobCheckpoint]:
        """Get the latest checkpoint for a job (highest step), or None."""
        conn = self._get_conn()
        try:
            row = conn.execute(
                "SELECT * FROM checkpoints WHERE job_id = ? ORDER BY step DESC LIMIT 1",
                (job_id,)
            ).fetchone()
        finally:
            conn.close()

        if not row:
            return None

        return JobCheckpoint(
            job_id=row['job_id'],
            step=row['step'],
            data=json.loads(row['data']),
            timestamp=row['timestamp']
        )

    def get_stats(self, pipeline: Optional[str] = None) -> Dict[str, Any]:
        """Get queue statistics: per-status counts plus aggregate token totals."""
        count_query = "SELECT status, COUNT(*) as count FROM jobs"
        usage_query = (
            "SELECT SUM(CAST(json_extract(token_usage, '$.input_tokens') AS INTEGER)) as input, "
            "SUM(CAST(json_extract(token_usage, '$.output_tokens') AS INTEGER)) as output FROM jobs"
        )
        params: List[Any] = []
        if pipeline:
            count_query += " WHERE pipeline = ?"
            usage_query += " WHERE pipeline = ?"
            params.append(pipeline)
        count_query += " GROUP BY status"

        # Fix: run both queries on one connection (previously opened/closed
        # twice and reused `params` in a confusing conditional).
        conn = self._get_conn()
        try:
            rows = conn.execute(count_query, params).fetchall()
            row = conn.execute(usage_query, params).fetchone()
        finally:
            conn.close()

        stats: Dict[str, Any] = {r['status']: r['count'] for r in rows}
        stats['total_input_tokens'] = row['input'] or 0
        stats['total_output_tokens'] = row['output'] or 0
        stats['total_tokens'] = stats['total_input_tokens'] + stats['total_output_tokens']

        return stats
|
|
|
|
|
|
class PipelineOrchestrator:
    """Main orchestrator for pipeline execution.

    Pulls pending jobs from OrchestratorDB, runs them on a ThreadPoolExecutor
    via handlers registered per pipeline name, applies per-provider rate
    limits and per-job retry with exponential backoff, and persists all state
    transitions back to the database.
    """

    def __init__(
        self,
        max_workers: int = DEFAULT_WORKERS,
        token_budget: int = DEFAULT_TOKEN_BUDGET,
        db_path: Optional[Path] = None
    ):
        self.max_workers = max_workers
        self.token_budget = token_budget        # default budget for submitted jobs
        self.db = OrchestratorDB(db_path)
        self.rate_limiter = RateLimiter()
        self.executor: Optional[ThreadPoolExecutor] = None  # created in run()
        self.running = False                    # loop flag; cleared by stop()
        self._handlers: Dict[str, Callable] = {}  # pipeline name -> handler(job) -> result dict

    def register_handler(self, pipeline: str, handler: Callable):
        """Register a handler for a pipeline type.

        The handler is called as handler(job) and should return a dict;
        an optional 'token_usage' key in the result updates job accounting.
        """
        self._handlers[pipeline] = handler

    def configure_rate_limit(self, provider: str, rpm: int, tpm: int):
        """Configure rate limits for a provider."""
        self.rate_limiter.configure(provider, rpm, tpm)

    def submit_job(
        self,
        pipeline: str,
        task: Dict[str, Any],
        priority: JobPriority = JobPriority.NORMAL,
        token_budget: Optional[int] = None,
        max_retries: int = 3,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Submit a new job to the queue; returns the new job's id (uuid4)."""
        job_id = str(uuid.uuid4())

        job = Job(
            id=job_id,
            pipeline=pipeline,
            task=task,
            priority=priority,
            token_budget=token_budget or self.token_budget,
            max_retries=max_retries,
            metadata=metadata or {}
        )

        self.db.save_job(job)
        logger.info(f"Job {job_id} submitted to pipeline {pipeline}")

        return job_id

    def submit_batch(
        self,
        pipeline: str,
        tasks: List[Dict[str, Any]],
        priority: JobPriority = JobPriority.NORMAL,
        token_budget: Optional[int] = None
    ) -> List[str]:
        """Submit multiple jobs at once; returns the list of job ids."""
        job_ids = []
        for task in tasks:
            job_id = self.submit_job(pipeline, task, priority, token_budget)
            job_ids.append(job_id)

        logger.info(f"Submitted {len(job_ids)} jobs to pipeline {pipeline}")
        return job_ids

    def _execute_job(self, job: Job) -> Job:
        """Execute a single job with retry logic.

        Runs inside a worker thread. Fails fast (no attempt) when no handler
        is registered or the token budget is already exhausted; otherwise
        retries the handler with exponential backoff up to max_retries.
        """
        handler = self._handlers.get(job.pipeline)
        if not handler:
            job.status = JobStatus.FAILED
            job.error = f"No handler registered for pipeline: {job.pipeline}"
            job.completed_at = time.time()
            self.db.save_job(job)
            return job

        # Check token budget
        if job.token_usage.total_tokens >= job.token_budget:
            job.status = JobStatus.FAILED
            job.error = "Token budget exceeded"
            job.completed_at = time.time()
            self.db.save_job(job)
            return job

        # Update status
        job.status = JobStatus.RUNNING
        job.started_at = time.time()
        self.db.save_job(job)

        try:
            # Execute with exponential backoff retry.
            # NOTE(review): `attempt` is unused — the real retry budget is
            # job.retry_count vs job.max_retries below, so a resumed job with
            # a non-zero retry_count gets fewer attempts; confirm intended.
            for attempt in range(job.max_retries + 1):
                try:
                    # Check rate limits before invoking the provider
                    provider = job.metadata.get('provider', 'default')
                    can_proceed, wait_time = self.rate_limiter.can_proceed(provider)

                    if not can_proceed:
                        logger.info(f"Rate limited, waiting {wait_time:.1f}s")
                        time.sleep(wait_time)

                    # Execute the handler
                    result = handler(job)

                    # Record request
                    self.rate_limiter.record_request(provider)

                    # Update job with result
                    job.result = result
                    job.status = JobStatus.COMPLETED
                    job.completed_at = time.time()

                    # Update token usage from result if provided
                    if 'token_usage' in result:
                        usage = result['token_usage']
                        job.token_usage.input_tokens += usage.get('input_tokens', 0)
                        job.token_usage.output_tokens += usage.get('output_tokens', 0)
                        job.token_usage.cache_read_tokens += usage.get('cache_read_tokens', 0)
                        job.token_usage.cache_write_tokens += usage.get('cache_write_tokens', 0)
                        job.token_usage.cost_usd += usage.get('cost_usd', 0)

                    self.db.save_job(job)
                    return job

                except Exception as e:
                    # NOTE(review): broad catch is deliberate — any handler
                    # failure should count as a retryable attempt.
                    job.retry_count += 1

                    if job.retry_count >= job.max_retries:
                        job.status = JobStatus.FAILED
                        job.error = str(e)
                        job.completed_at = time.time()
                        self.db.save_job(job)
                        return job

                    # Exponential backoff with sub-second jitter
                    wait_time = (2 ** job.retry_count) + (time.time() % 1)
                    logger.warning(f"Job {job.id} failed (attempt {job.retry_count}), retrying in {wait_time:.1f}s: {e}")
                    time.sleep(wait_time)

        except Exception as e:
            # Failure outside the per-attempt handling (e.g. rate-limiter or
            # bookkeeping errors) — fail the job without further retries.
            job.status = JobStatus.FAILED
            job.error = f"Unexpected error: {str(e)}"
            job.completed_at = time.time()
            self.db.save_job(job)
            return job

        return job

    def run(self, pipeline: Optional[str] = None, max_jobs: Optional[int] = None):
        """Run the orchestrator, processing jobs from the queue.

        On startup, checks for paused/running jobs with checkpoints and
        resumes them first before picking up new pending jobs.
        """
        self.running = True
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        futures = {}  # Future -> job id, for in-flight jobs

        logger.info(f"Orchestrator starting (workers={self.max_workers})")

        # Resume paused jobs with checkpoints on restart. RUNNING jobs are
        # included because a crashed orchestrator leaves them stuck in that
        # state; re-queueing lets a worker pick them up again.
        for status in (JobStatus.PAUSED, JobStatus.RUNNING):
            for job in self.db.get_jobs_by_status(status, pipeline):
                if job.checkpoint:
                    logger.info(f"Resuming {status.value} job {job.id} from checkpoint step {job.checkpoint.step}")
                    job.status = JobStatus.PENDING
                    self.db.save_job(job)

        try:
            jobs_processed = 0

            while self.running:
                # Check completed futures
                done = [f for f in futures if f.done()]
                for f in done:
                    try:
                        f.result()  # propagate exceptions
                    except Exception as e:
                        logger.error(f"Worker error: {e}")
                    del futures[f]

                # Throttle if at capacity
                if len(futures) >= self.max_workers:
                    time.sleep(0.1)
                    continue

                # Get next job
                job = self.db.get_next_job(pipeline)

                if not job:
                    if not futures:
                        # No jobs and no workers — done
                        break
                    time.sleep(0.5)
                    continue

                future = self.executor.submit(self._execute_job, job)
                futures[future] = job.id
                jobs_processed += 1

                if max_jobs and jobs_processed >= max_jobs:
                    logger.info(f"Reached max_jobs limit ({max_jobs})")
                    break

            # Wait for remaining futures (drain), surfacing worker errors
            for f in futures:
                try:
                    f.result(timeout=300)
                except Exception as e:
                    logger.error(f"Worker error on drain: {e}")

        finally:
            self.executor.shutdown(wait=True)
            self.running = False
            logger.info(f"Orchestrator stopped (processed {jobs_processed} jobs)")

    def run_single(self, job_id: str) -> Job:
        """Run a single job by ID (useful for resume).

        Raises ValueError when the job is missing or is neither pending nor
        failed; a failed job has its retry state reset before re-execution.
        """
        job = self.db.get_job(job_id)

        if not job:
            raise ValueError(f"Job not found: {job_id}")

        if job.status not in (JobStatus.PENDING, JobStatus.FAILED):
            raise ValueError(f"Job {job_id} is not pending or failed (status: {job.status})")

        # Reset for retry
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            self.db.save_job(job)

        return self._execute_job(job)

    def pause_job(self, job_id: str):
        """Pause a job (only RUNNING jobs can be paused; otherwise a no-op)."""
        job = self.db.get_job(job_id)
        if job and job.status == JobStatus.RUNNING:
            job.status = JobStatus.PAUSED
            self.db.save_job(job)
            logger.info(f"Job {job_id} paused")

    def resume_job(self, job_id: str) -> Job:
        """Resume a paused job from its latest checkpoint and execute it.

        Raises ValueError when the job is missing or not paused.
        """
        job = self.db.get_job(job_id)

        if not job:
            raise ValueError(f"Job not found: {job_id}")

        if job.status != JobStatus.PAUSED:
            raise ValueError(f"Job {job_id} is not paused (status: {job.status})")

        # Load checkpoint if exists
        checkpoint = self.db.get_checkpoint(job_id)
        if checkpoint:
            job.checkpoint = checkpoint
            logger.info(f"Resuming job {job_id} from checkpoint step {checkpoint.step}")

        job.status = JobStatus.PENDING
        self.db.save_job(job)

        return self._execute_job(job)

    def cancel_job(self, job_id: str):
        """Cancel a pending/running/paused job; terminal jobs are left alone."""
        job = self.db.get_job(job_id)
        if job and job.status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.PAUSED):
            job.status = JobStatus.CANCELLED
            job.completed_at = time.time()
            self.db.save_job(job)
            logger.info(f"Job {job_id} cancelled")

    def get_progress(self, job_id: str) -> Dict[str, Any]:
        """Get job progress as a JSON-friendly dict ({'error': ...} if missing)."""
        job = self.db.get_job(job_id)

        if not job:
            return {"error": "Job not found"}

        progress = {
            "job_id": job.id,
            "pipeline": job.pipeline,
            "status": job.status.value,
            "retry_count": job.retry_count,
            "token_usage": job.token_usage.to_dict(),
            "token_budget": job.token_budget,
            "token_percent": (job.token_usage.total_tokens / job.token_budget * 100) if job.token_budget > 0 else 0,
            "created_at": job.created_at,
            "started_at": job.started_at,
            "completed_at": job.completed_at,
        }

        if job.checkpoint:
            progress["checkpoint"] = {
                "step": job.checkpoint.step,
                "timestamp": job.checkpoint.timestamp,
            }

        if job.error:
            progress["error"] = job.error

        return progress

    def generate_report(self, pipeline: Optional[str] = None) -> Dict[str, Any]:
        """Generate a summary report (counts, success rate, avg duration, tokens)."""
        stats = self.db.get_stats(pipeline)

        completed_jobs = self.db.get_jobs_by_status(JobStatus.COMPLETED, pipeline)
        failed_jobs = self.db.get_jobs_by_status(JobStatus.FAILED, pipeline)

        # Calculate timing stats
        durations = []
        for job in completed_jobs:
            if job.started_at and job.completed_at:
                durations.append(job.completed_at - job.started_at)

        report = {
            "timestamp": datetime.now().isoformat(),
            "pipeline": pipeline or "all",
            "stats": stats,
            "completed": len(completed_jobs),
            "failed": len(failed_jobs),
            "success_rate": len(completed_jobs) / (len(completed_jobs) + len(failed_jobs)) * 100 if (completed_jobs or failed_jobs) else 0,
            "avg_duration": sum(durations) / len(durations) if durations else 0,
            "total_tokens": stats.get('total_tokens', 0),
        }

        return report

    def stop(self):
        """Stop the orchestrator (run() exits after in-flight jobs drain)."""
        self.running = False
        logger.info("Orchestrator stop requested")
|
|
|
|
|
|
# CLI interface
|
|
def main():
    """CLI for orchestrator management.

    Subcommands: list, submit, run, status, resume, cancel, stats, report.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Pipeline Orchestrator")
    subparsers = parser.add_subparsers(dest="command")

    # List jobs
    list_parser = subparsers.add_parser("list", help="List jobs")
    list_parser.add_argument("--status", help="Filter by status")
    list_parser.add_argument("--pipeline", help="Filter by pipeline")

    # Submit job
    submit_parser = subparsers.add_parser("submit", help="Submit a job")
    submit_parser.add_argument("pipeline", help="Pipeline name")
    submit_parser.add_argument("--task", required=True, help="Task JSON")
    submit_parser.add_argument("--priority", type=int, default=5, help="Job priority")
    submit_parser.add_argument("--budget", type=int, help="Token budget")

    # Run orchestrator
    run_parser = subparsers.add_parser("run", help="Run orchestrator")
    run_parser.add_argument("--pipeline", help="Filter by pipeline")
    run_parser.add_argument("--workers", type=int, default=10, help="Max workers")
    run_parser.add_argument("--max-jobs", type=int, help="Max jobs to process")

    # Job management
    status_parser = subparsers.add_parser("status", help="Get job status")
    status_parser.add_argument("job_id", help="Job ID")

    resume_parser = subparsers.add_parser("resume", help="Resume paused job")
    resume_parser.add_argument("job_id", help="Job ID")

    cancel_parser = subparsers.add_parser("cancel", help="Cancel job")
    cancel_parser.add_argument("job_id", help="Job ID")

    # Stats
    stats_parser = subparsers.add_parser("stats", help="Show queue stats")
    stats_parser.add_argument("--pipeline", help="Filter by pipeline")

    # Report
    report_parser = subparsers.add_parser("report", help="Generate report")
    report_parser.add_argument("--pipeline", help="Filter by pipeline")

    args = parser.parse_args()

    # Fix: bail out before constructing the orchestrator so that plain
    # `orchestrator` with no command does not create ~/.hermes and the DB.
    if not args.command:
        parser.print_help()
        return

    # Fix: getattr with a default instead of the hasattr ternary
    # (only the 'run' subcommand defines --workers).
    orchestrator = PipelineOrchestrator(max_workers=getattr(args, 'workers', 10))

    if args.command == "submit":
        task = json.loads(args.task)
        priority = JobPriority(args.priority)  # must be one of 0/5/10/20
        job_id = orchestrator.submit_job(args.pipeline, task, priority, args.budget)
        print(f"Job submitted: {job_id}")

    elif args.command == "run":
        orchestrator.run(args.pipeline, args.max_jobs)

    elif args.command == "list":
        # Fix: invalid --status previously escaped as a raw ValueError traceback.
        try:
            status_filter = JobStatus(args.status) if args.status else None
        except ValueError:
            valid = ", ".join(s.value for s in JobStatus)
            print(f"Invalid status: {args.status} (valid: {valid})")
            return

        if status_filter:
            jobs = orchestrator.db.get_jobs_by_status(status_filter, args.pipeline)
        else:
            # Show all jobs. Fix: honor --pipeline here too (the original
            # all-statuses path ignored it), and drop jobs deleted between
            # the id query and the per-id re-fetch.
            conn = orchestrator.db._get_conn()
            query = "SELECT id FROM jobs"
            params = []
            if args.pipeline:
                query += " WHERE pipeline = ?"
                params.append(args.pipeline)
            query += " ORDER BY priority DESC, created_at ASC"
            rows = conn.execute(query, params).fetchall()
            conn.close()
            jobs = [j for j in (orchestrator.db.get_job(row['id']) for row in rows) if j]

        for job in jobs:
            dur = ""
            if job.started_at and job.completed_at:
                dur = f" ({job.completed_at - job.started_at:.1f}s)"
            print(f" {job.id[:8]} {job.status.value:10s} p{job.priority.value} {job.pipeline} tokens={job.token_usage.total_tokens}{dur}")
        print(f"\n{len(jobs)} jobs")

    elif args.command == "status":
        progress = orchestrator.get_progress(args.job_id)
        print(json.dumps(progress, indent=2))

    elif args.command == "resume":
        job = orchestrator.resume_job(args.job_id)
        print(f"Job {args.job_id} completed with status: {job.status.value}")

    elif args.command == "cancel":
        orchestrator.cancel_job(args.job_id)
        print(f"Job {args.job_id} cancelled")

    elif args.command == "stats":
        stats = orchestrator.db.get_stats(args.pipeline)
        print(json.dumps(stats, indent=2))

    elif args.command == "report":
        report = orchestrator.generate_report(args.pipeline)
        print(json.dumps(report, indent=2))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: dispatch to the CLI.
    main()
|