diff --git a/pipeline/orchestrator.py b/pipeline/orchestrator.py
new file mode 100644
index 00000000..f094fead
--- /dev/null
+++ b/pipeline/orchestrator.py
@@ -0,0 +1,568 @@
+#!/usr/bin/env python3
+"""
+orchestrator.py — Shared Pipeline Orchestrator
+
+SQLite-backed job queue with parallel workers, token budget tracking,
+checkpoint resume, rate limiting, and error retry.
+
+All 5 pipelines use this orchestrator for consistent execution.
+
+Usage:
+    python3 orchestrator.py --pipeline training_factory --jobs jobs.jsonl
+    python3 orchestrator.py --pipeline adversary --jobs jobs.jsonl --workers 5
+    python3 orchestrator.py --status
+    python3 orchestrator.py --pipeline training_factory --resume
+"""
+
+import json
+import time
+import sqlite3
+import hashlib
+import signal
+from datetime import datetime, timezone
+from pathlib import Path
+from dataclasses import dataclass
+from typing import List, Optional, Dict, Callable
+from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
+
+DB_PATH = Path.home() / ".hermes" / "pipeline" / "orchestrator.db"
+REPORT_DIR = Path.home() / ".hermes" / "pipeline" / "reports"
+
+# ============================================================
+# Data Structures
+# ============================================================
+
+class JobStatus:
+    """Job status constants (a plain class; there are no instance fields)."""
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    RETRYING = "retrying"
+    SKIPPED = "skipped"
+
+
+@dataclass
+class PipelineStats:
+    """Runtime statistics for a pipeline run."""
+    pipeline: str
+    total_jobs: int = 0
+    completed: int = 0
+    failed: int = 0
+    skipped: int = 0
+    tokens_used: int = 0
+    tokens_budget: int = 5_000_000
+    elapsed_seconds: float = 0.0
+    start_time: str = ""
+    jobs_per_minute: float = 0.0
+
+    def to_dict(self):
+        return {
+            "pipeline": self.pipeline,
+            "total_jobs": self.total_jobs,
+            "completed": self.completed,
+            "failed": self.failed,
+            "skipped": self.skipped,
+            "tokens_used": self.tokens_used,
+            "tokens_budget": self.tokens_budget,
+            "elapsed_seconds": round(self.elapsed_seconds, 1),
+            "start_time": self.start_time,
+            "jobs_per_minute": round(self.jobs_per_minute, 2),
+        }
+
+
+# ============================================================
+# Database
+# ============================================================
+
+def get_db():
+    """Get SQLite database connection."""
+    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
+    conn = sqlite3.connect(str(DB_PATH), timeout=30, check_same_thread=False)
+    conn.execute("PRAGMA journal_mode=WAL")
+    conn.execute("PRAGMA busy_timeout=5000")
+    _init_db(conn)
+    return conn
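+
+# Note: a single connection is shared across worker threads
+# (check_same_thread=False). WAL mode lets readers proceed alongside the
+# writer, and busy_timeout makes contending writers wait instead of
+# failing immediately.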
+
+
+def _init_db(conn):
+    """Initialize database schema."""
+    conn.executescript("""
+        CREATE TABLE IF NOT EXISTS jobs (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            pipeline TEXT NOT NULL,
+            job_key TEXT NOT NULL,
+            payload TEXT NOT NULL,
+            status TEXT DEFAULT 'pending',
+            attempts INTEGER DEFAULT 0,
+            max_attempts INTEGER DEFAULT 3,
+            tokens_used INTEGER DEFAULT 0,
+            error TEXT,
+            result TEXT,
+            checkpoint TEXT,
+            created_at TEXT DEFAULT (datetime('now')),
+            started_at TEXT,
+            completed_at TEXT,
+            UNIQUE(pipeline, job_key)
+        );
+
+        CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_status ON jobs(pipeline, status);
+        CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_key ON jobs(pipeline, job_key);
+
+        CREATE TABLE IF NOT EXISTS pipeline_runs (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            pipeline TEXT NOT NULL,
+            started_at TEXT DEFAULT (datetime('now')),
+            completed_at TEXT,
+            total_jobs INTEGER DEFAULT 0,
+            completed INTEGER DEFAULT 0,
+            failed INTEGER DEFAULT 0,
+            tokens_used INTEGER DEFAULT 0,
+            report TEXT
+        );
+    """)
+    conn.commit()
+
+
+# ============================================================
+# Job Queue
+# ============================================================
+
+class JobQueue:
+    """SQLite-backed job queue."""
+
+    def __init__(self, pipeline: str, conn=None):
+        self.pipeline = pipeline
+        self.conn = conn or get_db()
+
+    def enqueue(self, job_key: str, payload: dict, max_attempts: int = 3):
+        """Add a job to the queue (skip if it already exists)."""
+        try:
+            self.conn.execute(
+                "INSERT INTO jobs (pipeline, job_key, payload, max_attempts) VALUES (?, ?, ?, ?)",
+                (self.pipeline, job_key, json.dumps(payload), max_attempts),
+            )
+            self.conn.commit()
+            return True
+        except sqlite3.IntegrityError:
+            # Already exists — check if it needs retry
+            row = self.conn.execute(
+                "SELECT status FROM jobs WHERE pipeline=? AND job_key=?",
+                (self.pipeline, job_key),
+            ).fetchone()
+            if row and row[0] == "failed":
+                # Reset for retry
+                self.conn.execute(
+                    "UPDATE jobs SET status='pending', attempts=0, error=NULL WHERE pipeline=? AND job_key=?",
+                    (self.pipeline, job_key),
+                )
+                self.conn.commit()
+                return True
+            return False
+
+    def enqueue_batch(self, jobs: List[dict], key_field: str = "id"):
+        """Enqueue multiple jobs. Returns (added, skipped) counts."""
+        added = 0
+        skipped = 0
+        for job in jobs:
+            # Fallback key hashes the payload with sort_keys=True so the same
+            # job always produces the same key.
+            key = str(job.get(key_field, hashlib.md5(
+                json.dumps(job, sort_keys=True).encode()).hexdigest()[:12]))
+            if self.enqueue(key, job):
+                added += 1
+            else:
+                skipped += 1
+        return added, skipped
+
+    def claim_next(self) -> Optional[dict]:
+        """Claim the next pending job (atomic; RETURNING requires SQLite 3.35+)."""
+        row = self.conn.execute(
+            """UPDATE jobs SET status='running', started_at=datetime('now')
+               WHERE id = (
+                   SELECT id FROM jobs WHERE pipeline=? AND status IN ('pending', 'retrying')
+                   ORDER BY attempts ASC, created_at ASC LIMIT 1
+               ) RETURNING *""",
+            (self.pipeline,),
+        ).fetchone()
+        self.conn.commit()  # close the implicit write transaction opened by the UPDATE
+
+        if not row:
+            return None
+
+        cols = [d[1] for d in self.conn.execute("PRAGMA table_info(jobs)").fetchall()]
+        return dict(zip(cols, row))
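+
+    # A claimed job comes back as a plain dict keyed by column name, e.g.
+    # (illustrative values):
+    #   {"id": 7, "pipeline": "adversary", "job_key": "job-42",
+    #    "payload": "{...}", "status": "running", "attempts": 0, ...}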
+
+    def complete(self, job_key: str, result: dict, tokens_used: int = 0):
+        """Mark a job as completed."""
+        self.conn.execute(
+            """UPDATE jobs SET status='completed', completed_at=datetime('now'),
+               result=?, tokens_used=? WHERE pipeline=? AND job_key=?""",
+            (json.dumps(result), tokens_used, self.pipeline, job_key),
+        )
+        self.conn.commit()
+
+    def fail(self, job_key: str, error: str, retry: bool = True):
+        """Mark a job as failed, optionally re-queueing it for retry."""
+        row = self.conn.execute(
+            "SELECT attempts, max_attempts FROM jobs WHERE pipeline=? AND job_key=?",
+            (self.pipeline, job_key),
+        ).fetchone()
+
+        if not row:
+            return
+
+        attempts, max_attempts = row
+        new_attempts = attempts + 1
+
+        if retry and new_attempts < max_attempts:
+            # Re-queue for retry. claim_next orders by attempts ASC, so jobs
+            # that have failed repeatedly are deferred behind fresher work
+            # (no sleep-based backoff delay is applied).
+            self.conn.execute(
+                """UPDATE jobs SET status='retrying', attempts=?, error=?
+                   WHERE pipeline=? AND job_key=?""",
+                (new_attempts, error, self.pipeline, job_key),
+            )
+        else:
+            self.conn.execute(
+                """UPDATE jobs SET status='failed', attempts=?, error=?,
+                   completed_at=datetime('now') WHERE pipeline=? AND job_key=?""",
+                (new_attempts, error, self.pipeline, job_key),
+            )
+        self.conn.commit()
+
+    def save_checkpoint(self, job_key: str, checkpoint: dict):
+        """Save progress checkpoint for resume."""
+        self.conn.execute(
+            "UPDATE jobs SET checkpoint=? WHERE pipeline=? AND job_key=?",
+            (json.dumps(checkpoint), self.pipeline, job_key),
+        )
+        self.conn.commit()
+
+    def get_checkpoint(self, job_key: str) -> Optional[dict]:
+        """Get saved checkpoint."""
+        row = self.conn.execute(
+            "SELECT checkpoint FROM jobs WHERE pipeline=? AND job_key=?",
+            (self.pipeline, job_key),
+        ).fetchone()
+        if row and row[0]:
+            return json.loads(row[0])
+        return None
+
+    def stats(self) -> dict:
+        """Get queue statistics."""
+        rows = self.conn.execute(
+            """SELECT status, COUNT(*), COALESCE(SUM(tokens_used), 0)
+               FROM jobs WHERE pipeline=? GROUP BY status""",
+            (self.pipeline,),
+        ).fetchall()
+
+        result = {"total": 0, "tokens_used": 0}
+        for status, count, tokens in rows:
+            result[status] = count
+            result["total"] += count
+            result["tokens_used"] += tokens
+        return result
+
+
+# ============================================================
+# Orchestrator
+# ============================================================
+
+class Orchestrator:
+    """
+    Shared orchestrator for all pipelines.
+
+    Features:
+    - Parallel worker pool (configurable)
+    - Token budget tracking
+    - Checkpoint resume
+    - Rate limiting
+    - Error retry (attempt-ordered re-queueing)
+    - Final report generation
+    """
+
+    def __init__(self, pipeline: str, workers: int = 10, token_budget: int = 5_000_000):
+        self.pipeline = pipeline
+        self.workers = workers
+        self.token_budget = token_budget
+        self.queue = JobQueue(pipeline)
+        self.conn = self.queue.conn
+        self._shutdown = False
+        self._stats = PipelineStats(pipeline=pipeline, tokens_budget=token_budget)
+        self._rate_limit_delay = 0.1  # seconds between job submissions
+        self._response_cache: Dict[str, dict] = {}
+
+        signal.signal(signal.SIGINT, self._handle_signal)
+        signal.signal(signal.SIGTERM, self._handle_signal)
+
+    def _handle_signal(self, signum, frame):
+        """Graceful shutdown on signal."""
+        print(f"\nReceived signal {signum}. Shutting down gracefully...")
+        self._shutdown = True
+
+    def load_jobs(self, jobs_path: str, key_field: str = "id"):
+        """Load jobs from a JSONL file into the queue."""
+        jobs = []
+        with open(jobs_path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    jobs.append(json.loads(line))
+
+        added, skipped = self.queue.enqueue_batch(jobs, key_field)
+        print(f"Loaded: {added} new, {skipped} existing")
+
+    def run(self, job_handler: Optional[Callable[..., dict]] = None):
+        """
+        Run the orchestrator. Processes all pending jobs with parallel workers.
+
+        Args:
+            job_handler: called as handler(payload, checkpoint=None); must
+                return a dict with 'success' and 'tokens_used' keys.
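+
+        Example handler (hypothetical; the payload fields are whatever the
+        jobs JSONL provides, and call_model stands in for your pipeline's
+        actual work):
+
+            def my_handler(payload, checkpoint=None):
+                text = call_model(payload["prompt"])
+                return {"success": True, "tokens_used": 1234, "output": text}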
+        """
+        start = time.time()
+        self._stats.start_time = datetime.now(timezone.utc).isoformat()
+
+        # Record run
+        self.conn.execute(
+            "INSERT INTO pipeline_runs (pipeline, started_at) VALUES (?, ?)",
+            (self.pipeline, self._stats.start_time),
+        )
+        run_id = self.conn.execute("SELECT last_insert_rowid()").fetchone()[0]
+        self.conn.commit()
+
+        stats = self.queue.stats()
+        self._stats.total_jobs = stats.get("pending", 0) + stats.get("retrying", 0)
+        print(f"\nPipeline: {self.pipeline}")
+        print(f"Jobs: {self._stats.total_jobs} pending | Workers: {self.workers} | Budget: {self.token_budget:,} tokens")
+        print()
+
+        if self._stats.total_jobs == 0:
+            print("No jobs to process.")
+            return
+
+        completed = 0
+        failed = 0
+        skipped = 0
+        tokens_used = 0
+
+        with ThreadPoolExecutor(max_workers=self.workers) as executor:
+            futures = {}
+
+            while not self._shutdown:
+                # Check token budget
+                if tokens_used >= self.token_budget:
+                    print(f"Token budget exhausted ({tokens_used:,}/{self.token_budget:,})")
+                    break
+
+                # Fill worker pool
+                while len(futures) < self.workers and not self._shutdown:
+                    job = self.queue.claim_next()
+                    if not job:
+                        break
+
+                    # Check response cache (zero-token retries)
+                    job_key = job["job_key"]
+                    payload = json.loads(job["payload"])
+                    cache_key = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
+
+                    if cache_key in self._response_cache:
+                        result = self._response_cache[cache_key]
+                        self.queue.complete(job_key, result, tokens_used=0)
+                        skipped += 1
+                        continue
+
+                    # Submit to worker
+                    future = executor.submit(self._process_job, job, job_handler)
+                    futures[future] = job
+
+                    # Rate limiting
+                    time.sleep(self._rate_limit_delay)
+
+                if not futures:
+                    break
+
+                # Collect results. wait() returns whatever finished within the
+                # timeout; unlike as_completed(timeout=...), it never raises
+                # TimeoutError while jobs are still running.
+                done, _ = wait(list(futures), timeout=1, return_when=FIRST_COMPLETED)
+                for future in done:
+                    job = futures.pop(future)
+                    try:
+                        result = future.result()
+                        if result.get("success"):
+                            tokens = result.get("tokens_used", 0)
+                            tokens_used += tokens
+                            self.queue.complete(job["job_key"], result, tokens_used=tokens)
+                            # Populate the response cache so re-running the
+                            # same payload completes without spending tokens.
+                            cache_key = hashlib.md5(json.dumps(
+                                json.loads(job["payload"]), sort_keys=True).encode()).hexdigest()
+                            self._response_cache[cache_key] = result
+                            completed += 1
+                        else:
+                            error = result.get("error", "unknown error")
+                            self.queue.fail(job["job_key"], error, retry=True)
+                            failed += 1
+                    except Exception as e:
+                        self.queue.fail(job["job_key"], str(e), retry=True)
+                        failed += 1
+
+                # Progress update
+                total = completed + failed + skipped
+                if total % 10 == 0:
+                    elapsed = time.time() - start
+                    rate = completed / (elapsed / 60) if elapsed > 0 else 0
+                    print(f"  Progress: {total}/{self._stats.total_jobs} | "
+                          f"completed={completed} failed={failed} | "
+                          f"tokens={tokens_used:,} | "
+                          f"{rate:.1f}/min")
+
+        # Final report
+        elapsed = time.time() - start
+        self._stats.completed = completed
+        self._stats.failed = failed
+        self._stats.skipped = skipped
+        self._stats.tokens_used = tokens_used
+        self._stats.elapsed_seconds = elapsed
+        self._stats.jobs_per_minute = completed / (elapsed / 60) if elapsed > 0 else 0
+
+        # Save run
+        self.conn.execute(
+            """UPDATE pipeline_runs SET completed_at=?, total_jobs=?, completed=?,
+               failed=?, tokens_used=?, report=? WHERE id=?""",
+            (datetime.now(timezone.utc).isoformat(), self._stats.total_jobs,
+             completed, failed, tokens_used, json.dumps(self._stats.to_dict()), run_id),
+        )
+        self.conn.commit()
+
+        # Print report
+        print(f"\n{'='*50}")
+        print(f"Pipeline: {self.pipeline}")
+        print(f"Completed: {completed}/{self._stats.total_jobs}")
+        print(f"Failed: {failed}")
+        print(f"Skipped (cached): {skipped}")
+        print(f"Tokens: {tokens_used:,}/{self.token_budget:,}")
+        print(f"Time: {elapsed:.1f}s ({self._stats.jobs_per_minute:.1f}/min)")
+        print(f"{'='*50}")
+
+        # Save report file
+        self._save_report()
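+
+    # Note: _process_job runs on worker threads, and its only database access
+    # is get_checkpoint(). All status writes (complete/fail) happen on the
+    # main thread inside run(), which keeps write contention on the shared
+    # connection low.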
+
+    def _process_job(self, job: dict, handler: Optional[Callable] = None) -> dict:
+        """Process a single job."""
+        payload = json.loads(job["payload"])
+        job_key = job["job_key"]
+        checkpoint = self.queue.get_checkpoint(job_key)
+
+        if handler:
+            try:
+                result = handler(payload, checkpoint=checkpoint)
+                return result or {"success": True, "tokens_used": 0}
+            except Exception as e:
+                return {"success": False, "error": str(e)}
+        else:
+            # Default handler: just mark as complete
+            return {"success": True, "tokens_used": 0}
+
+    def _save_report(self):
+        """Save pipeline run report."""
+        REPORT_DIR.mkdir(parents=True, exist_ok=True)
+        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+        path = REPORT_DIR / f"{self.pipeline}_{ts}.json"
+        with open(path, "w") as f:
+            json.dump(self._stats.to_dict(), f, indent=2)
+        print(f"Report: {path}")
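+
+    # A saved report is the PipelineStats dict, e.g. (illustrative values):
+    #   {"pipeline": "training_factory", "total_jobs": 200, "completed": 195,
+    #    "failed": 5, "skipped": 0, "tokens_used": 1843210,
+    #    "tokens_budget": 5000000, "elapsed_seconds": 912.4,
+    #    "start_time": "2024-01-01T00:00:00+00:00", "jobs_per_minute": 12.83}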
+
+    def resume(self):
+        """Resume failed/retrying jobs (and stale 'running' jobs) from a previous run."""
+        stats = self.queue.stats()
+        retrying = stats.get("retrying", 0)
+        failed = stats.get("failed", 0)
+        stale = stats.get("running", 0)
+        print(f"Resume {self.pipeline}: {retrying} retrying, {failed} failed, {stale} stale running to reset")
+
+        # Reset failed jobs to pending for retry
+        self.conn.execute(
+            "UPDATE jobs SET status='pending', attempts=0 WHERE pipeline=? AND status='failed'",
+            (self.pipeline,),
+        )
+        # Retrying jobs keep their attempt count. Jobs left in 'running' by a
+        # crashed run are reset too, so they are not orphaned.
+        self.conn.execute(
+            "UPDATE jobs SET status='pending' WHERE pipeline=? AND status IN ('retrying', 'running')",
+            (self.pipeline,),
+        )
+        self.conn.commit()
+
+    def status(self):
+        """Show pipeline status."""
+        stats = self.queue.stats()
+        print(f"\nPipeline: {self.pipeline}")
+        for k, v in sorted(stats.items()):
+            print(f"  {k}: {v}")
+
+
+# ============================================================
+# CLI
+# ============================================================
+
+def show_all_status():
+    """Show status of all pipelines."""
+    conn = get_db()
+    pipelines = conn.execute(
+        "SELECT DISTINCT pipeline FROM jobs ORDER BY pipeline"
+    ).fetchall()
+
+    if not pipelines:
+        print("No pipelines in database.")
+        return
+
+    print("\nAll Pipeline Status")
+    print("=" * 60)
+
+    for (pipeline,) in pipelines:
+        queue = JobQueue(pipeline, conn)
+        stats = queue.stats()
+        total = stats.get("total", 0)
+        pending = stats.get("pending", 0)
+        running = stats.get("running", 0)
+        completed = stats.get("completed", 0)
+        failed = stats.get("failed", 0)
+        tokens = stats.get("tokens_used", 0)
+        print(f"  {pipeline:25} total={total:4} pending={pending:3} running={running:2} "
+              f"completed={completed:4} failed={failed:3} tokens={tokens:,}")
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser(description="Shared Pipeline Orchestrator")
+    parser.add_argument("--pipeline", "-p", help="Pipeline name")
+    parser.add_argument("--jobs", "-j", help="Jobs JSONL file to load")
+    parser.add_argument("--workers", "-w", type=int, default=10, help="Parallel workers")
+    parser.add_argument("--budget", "-b", type=int, default=5_000_000, help="Token budget")
+    parser.add_argument("--status", action="store_true", help="Show status")
+    parser.add_argument("--resume", action="store_true", help="Resume failed jobs")
+    parser.add_argument("--key-field", default="id", help="Job key field name")
+    args = parser.parse_args()
+
+    if args.status:
+        if args.pipeline:
+            orch = Orchestrator(args.pipeline)
+            orch.status()
+        else:
+            show_all_status()
+        return
+
+    if not args.pipeline:
+        parser.error("--pipeline is required")
+
+    orch = Orchestrator(args.pipeline, workers=args.workers, token_budget=args.budget)
+
+    if args.jobs:
+        orch.load_jobs(args.jobs, key_field=args.key_field)
+
+    if args.resume:
+        orch.resume()
+
+    if args.jobs or args.resume:
+        orch.run()
+
+
+if __name__ == "__main__":
+    main()
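+
+# Example end-to-end run (hypothetical pipeline name and jobs file):
+#   echo '{"id": "demo-1", "prompt": "hello"}' > demo_jobs.jsonl
+#   python3 orchestrator.py --pipeline demo --jobs demo_jobs.jsonl --workers 4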