#!/usr/bin/env python3
"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking."""

import json
import os
import tempfile
import time
from pathlib import Path

import pytest

# Add project root to path
import sys

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from pipelines.orchestrator import (
    PipelineOrchestrator,
    Job,
    JobStatus,
    JobPriority,
    JobCheckpoint,
    TokenUsage,
    OrchestratorDB,
    RateLimiter,
)


@pytest.fixture
def tmp_db(tmp_path):
    """Fresh orchestrator DB for each test."""
    return tmp_path / "test_orchestrator.db"


@pytest.fixture
def orch(tmp_db):
    """Orchestrator instance with temp DB."""
    return PipelineOrchestrator(max_workers=2, token_budget=10000, db_path=tmp_db)


class TestJobDataModels:
    """Test Job, TokenUsage, JobCheckpoint dataclasses."""

    def test_token_usage_total(self):
        """total_tokens is the sum of input and output tokens."""
        usage = TokenUsage(input_tokens=100, output_tokens=50)
        assert usage.total_tokens == 150

    def test_token_usage_zero(self):
        """Default-constructed usage has zero total."""
        usage = TokenUsage()
        assert usage.total_tokens == 0

    def test_token_usage_serialization(self):
        """TokenUsage round-trips through to_dict/from_dict."""
        usage = TokenUsage(input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001)
        d = usage.to_dict()
        restored = TokenUsage.from_dict(d)
        assert restored.input_tokens == 10
        assert restored.total_tokens == 30

    def test_checkpoint_serialization(self):
        """JobCheckpoint round-trips through to_dict/from_dict."""
        cp = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
        d = cp.to_dict()
        restored = JobCheckpoint.from_dict(d)
        assert restored.step == 3
        assert restored.data == {"key": "val"}

    def test_job_serialization(self):
        """Job round-trips and defaults to PENDING status / NORMAL priority."""
        job = Job(id="test-1", pipeline="demo", task={"action": "run"})
        d = job.to_dict()
        restored = Job.from_dict(d)
        assert restored.id == "test-1"
        assert restored.status == JobStatus.PENDING
        assert restored.priority == JobPriority.NORMAL


class TestOrchestratorDB:
    """Test SQLite-backed job queue."""

    def test_save_and_get(self, tmp_db):
        """A saved job can be fetched back by id with its task intact."""
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={"x": 1})
        db.save_job(job)
        fetched = db.get_job("j1")
        assert fetched is not None
        assert fetched.id == "j1"
        assert fetched.task == {"x": 1}

    def test_get_next_job_priority(self, tmp_db):
        """get_next_job returns the highest-priority pending job first."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="low", pipeline="test", task={}, priority=JobPriority.LOW))
        db.save_job(Job(id="high", pipeline="test", task={}, priority=JobPriority.HIGH))
        db.save_job(Job(id="normal", pipeline="test", task={}, priority=JobPriority.NORMAL))
        next_job = db.get_next_job()
        assert next_job.id == "high"

    def test_get_next_job_pipeline_filter(self, tmp_db):
        """get_next_job honors the pipeline filter."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="alpha", task={}))
        db.save_job(Job(id="b", pipeline="beta", task={}))
        next_job = db.get_next_job(pipeline="beta")
        assert next_job.id == "b"

    def test_get_jobs_by_status(self, tmp_db):
        """get_jobs_by_status returns only jobs with the requested status."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
        db.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
        pending = db.get_jobs_by_status(JobStatus.PENDING)
        assert len(pending) == 1
        assert pending[0].id == "a"

    def test_checkpoint_save_load(self, tmp_db):
        """A saved checkpoint can be loaded back with step and data intact."""
        db = OrchestratorDB(tmp_db)
        cp = JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"})
        db.save_checkpoint("j1", cp)
        loaded = db.get_checkpoint("j1")
        assert loaded is not None
        assert loaded.step == 5
        assert loaded.data == {"progress": "50%"}

    def test_stats(self, tmp_db):
        """get_stats aggregates completed counts and token totals."""
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
        db.save_job(job)
        stats = db.get_stats()
        assert stats["completed"] == 1
        assert stats["total_tokens"] == 150


class TestRateLimiter:
    """Test rate limiter."""

    def test_no_limit(self):
        """Unknown (unconfigured) keys always proceed with zero wait."""
        rl = RateLimiter()
        can_proceed, wait = rl.can_proceed("unknown")
        assert can_proceed is True
        assert wait == 0.0

    def test_rpm_limit(self):
        """After hitting requests_per_minute, can_proceed returns False with a wait."""
        rl = RateLimiter()
        rl.configure("test", requests_per_minute=2, tokens_per_minute=1000)
        assert rl.can_proceed("test")[0] is True
        rl.record_request("test")
        assert rl.can_proceed("test")[0] is True
        rl.record_request("test")
        can_proceed, wait = rl.can_proceed("test")
        assert can_proceed is False
        assert wait > 0


class TestPipelineOrchestrator:
    """Test the main orchestrator."""

    def test_submit_and_retrieve(self, orch):
        """Submitted jobs report pending status and their pipeline name."""
        job_id = orch.submit_job("test_pipeline", {"action": "process"})
        assert job_id is not None
        progress = orch.get_progress(job_id)
        assert progress["status"] == "pending"
        assert progress["pipeline"] == "test_pipeline"

    def test_submit_batch(self, orch):
        """submit_batch returns one job id per task."""
        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
        assert len(ids) == 5

    def test_handler_execution(self, orch):
        """A registered handler runs exactly once and completes the job."""
        results = []

        def handler(job):
            results.append(job.id)
            return {"status": "ok"}

        orch.register_handler("demo", handler)
        job_id = orch.submit_job("demo", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(results) == 1

    def test_handler_failure_and_retry(self, orch):
        """Transient handler failures are retried until success."""
        attempts = []

        def handler(job):
            attempts.append(1)
            if len(attempts) < 3:
                raise ValueError("transient error")
            return {"status": "ok"}

        orch.register_handler("retry_test", handler)
        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(attempts) == 3

    def test_handler_exhausts_retries(self, orch):
        """A persistently failing handler leaves the job failed with its error."""
        def handler(job):
            raise ValueError("permanent error")

        orch.register_handler("fail_test", handler)
        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "permanent error" in progress["error"]

    def test_no_handler(self, orch):
        """Jobs for unregistered pipelines fail with a 'No handler' error."""
        job_id = orch.submit_job("nonexistent", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "No handler" in progress["error"]

    def test_token_budget_tracking(self, orch):
        """Token usage returned by a handler is recorded on the job."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}

        orch.register_handler("token_test", handler)
        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["token_usage"]["input_tokens"] == 500
        assert progress["token_usage"]["output_tokens"] == 200

    def test_token_budget_exceeded(self, orch):
        """A job whose usage already exceeds its budget fails before running."""
        def handler(job):
            return {"status": "ok"}

        orch.register_handler("budget_test", handler)
        # Set job with already-exhausted budget by manipulating DB
        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
        job = orch.db.get_job(job_id)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
        orch.db.save_job(job)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "budget" in progress["error"].lower()

    def test_parallel_execution(self, orch):
        """Verify workers execute in parallel."""
        import threading

        active = set()
        max_concurrent = [0]
        # Lock guards `active` and the max update: without it, the
        # read-modify-write on max_concurrent[0] races across worker
        # threads and can under-count true concurrency, making this
        # test flaky.
        lock = threading.Lock()

        def handler(job):
            name = threading.current_thread().name
            with lock:
                active.add(name)
                max_concurrent[0] = max(max_concurrent[0], len(active))
            time.sleep(0.1)
            with lock:
                active.discard(name)
            return {"status": "ok"}

        orch.register_handler("parallel", handler)
        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
        orch.run(max_jobs=4)
        # With 2 workers, at least 2 should have been active simultaneously
        assert max_concurrent[0] >= 2

    def test_resume_paused_job(self, orch):
        """Test resume from checkpoint."""
        call_count = [0]

        def handler(job):
            call_count[0] += 1
            if call_count[0] == 1:
                # Simulate saving checkpoint before failure
                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
                orch.db.save_checkpoint(job.id, job.checkpoint)
                raise ValueError("first attempt fails")
            # Second attempt succeeds
            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}

        orch.register_handler("resume_test", handler)
        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)
        # First run — fails, saves checkpoint
        orch.run(max_jobs=1)
        # Manually resume (set to pending)
        job = orch.db.get_job(job_id)
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            orch.db.save_job(job)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"

    def test_resume_on_restart(self, orch):
        """Test that run() resumes paused/running jobs with checkpoints on startup."""
        # Create a paused job with a checkpoint
        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
        job.status = JobStatus.PAUSED
        orch.db.save_job(job)
        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
            job_id="resume-on-start", step=3, data={"progress": "50%"}
        ))

        calls = []

        def handler(job):
            calls.append(job.checkpoint.step if job.checkpoint else None)
            return {"status": "ok"}

        orch.register_handler("restart_test", handler)
        orch.run(max_jobs=1)
        # Job should have been auto-resumed and executed
        progress = orch.get_progress("resume-on-start")
        assert progress["status"] == "completed"
        assert calls == [3]  # Handler saw checkpoint step 3

    def test_cancel_job(self, orch):
        """cancel_job moves a pending job to cancelled."""
        job_id = orch.submit_job("cancel_test", {"action": "test"})
        orch.cancel_job(job_id)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "cancelled"

    def test_generate_report(self, orch):
        """generate_report aggregates completion counts and token totals."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

        orch.register_handler("report_test", handler)
        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
        orch.run(max_jobs=3)
        report = orch.generate_report("report_test")
        assert report["completed"] == 3
        assert report["failed"] == 0
        assert report["success_rate"] == 100.0
        assert report["total_tokens"] == 450  # 3 * 150