#!/usr/bin/env python3
"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking.

Covers the Job/TokenUsage/JobCheckpoint data models, the SQLite-backed job
queue (OrchestratorDB), the RateLimiter, and the PipelineOrchestrator's
submit / run / retry / resume / cancel / report behavior.
"""

import json
import os
import sys
import tempfile
import time
from pathlib import Path

import pytest

# Make `pipelines.orchestrator` importable when the tests are run directly.
# This file lives at <root>/pipelines/tests/test_orchestrator.py, so the
# project root is parents[2] (parents[1] would be the `pipelines` package
# directory itself, from which the absolute import below cannot resolve).
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from pipelines.orchestrator import (
    PipelineOrchestrator,
    Job,
    JobStatus,
    JobPriority,
    JobCheckpoint,
    TokenUsage,
    OrchestratorDB,
    RateLimiter,
)


@pytest.fixture
def tmp_db(tmp_path):
    """Fresh orchestrator DB for each test."""
    return tmp_path / "test_orchestrator.db"


@pytest.fixture
def orch(tmp_db):
    """Orchestrator instance with temp DB."""
    return PipelineOrchestrator(max_workers=2, token_budget=10000, db_path=tmp_db)


class TestJobDataModels:
    """Test Job, TokenUsage, JobCheckpoint dataclasses."""

    def test_token_usage_total(self):
        usage = TokenUsage(input_tokens=100, output_tokens=50)
        assert usage.total_tokens == 150

    def test_token_usage_zero(self):
        usage = TokenUsage()
        assert usage.total_tokens == 0

    def test_token_usage_serialization(self):
        # Round-trip through the dict form; cache/cost fields must survive too.
        usage = TokenUsage(input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001)
        d = usage.to_dict()
        restored = TokenUsage.from_dict(d)
        assert restored.input_tokens == 10
        assert restored.total_tokens == 30

    def test_checkpoint_serialization(self):
        cp = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
        d = cp.to_dict()
        restored = JobCheckpoint.from_dict(d)
        assert restored.step == 3
        assert restored.data == {"key": "val"}

    def test_job_serialization(self):
        # A freshly constructed job defaults to PENDING / NORMAL priority.
        job = Job(id="test-1", pipeline="demo", task={"action": "run"})
        d = job.to_dict()
        restored = Job.from_dict(d)
        assert restored.id == "test-1"
        assert restored.status == JobStatus.PENDING
        assert restored.priority == JobPriority.NORMAL


class TestOrchestratorDB:
    """Test SQLite-backed job queue."""

    def test_save_and_get(self, tmp_db):
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={"x": 1})
        db.save_job(job)
        fetched = db.get_job("j1")
        assert fetched is not None
        assert fetched.id == "j1"
        assert fetched.task == {"x": 1}

    def test_get_next_job_priority(self, tmp_db):
        # Highest-priority pending job should be dequeued first regardless
        # of insertion order.
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="low", pipeline="test", task={}, priority=JobPriority.LOW))
        db.save_job(Job(id="high", pipeline="test", task={}, priority=JobPriority.HIGH))
        db.save_job(Job(id="normal", pipeline="test", task={}, priority=JobPriority.NORMAL))
        next_job = db.get_next_job()
        assert next_job.id == "high"

    def test_get_next_job_pipeline_filter(self, tmp_db):
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="alpha", task={}))
        db.save_job(Job(id="b", pipeline="beta", task={}))
        next_job = db.get_next_job(pipeline="beta")
        assert next_job.id == "b"

    def test_get_jobs_by_status(self, tmp_db):
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
        db.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
        pending = db.get_jobs_by_status(JobStatus.PENDING)
        assert len(pending) == 1
        assert pending[0].id == "a"

    def test_checkpoint_save_load(self, tmp_db):
        db = OrchestratorDB(tmp_db)
        cp = JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"})
        db.save_checkpoint("j1", cp)
        loaded = db.get_checkpoint("j1")
        assert loaded is not None
        assert loaded.step == 5
        assert loaded.data == {"progress": "50%"}

    def test_stats(self, tmp_db):
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
        db.save_job(job)
        stats = db.get_stats()
        assert stats["completed"] == 1
        assert stats["total_tokens"] == 150


class TestRateLimiter:
    """Test rate limiter."""

    def test_no_limit(self):
        # An unconfigured key is unconstrained.
        rl = RateLimiter()
        can_proceed, wait = rl.can_proceed("unknown")
        assert can_proceed is True
        assert wait == 0.0

    def test_rpm_limit(self):
        # With a 2-requests-per-minute cap, the third request must wait.
        rl = RateLimiter()
        rl.configure("test", requests_per_minute=2, tokens_per_minute=1000)
        assert rl.can_proceed("test")[0] is True
        rl.record_request("test")
        assert rl.can_proceed("test")[0] is True
        rl.record_request("test")
        can_proceed, wait = rl.can_proceed("test")
        assert can_proceed is False
        assert wait > 0


class TestPipelineOrchestrator:
    """Test the main orchestrator."""

    def test_submit_and_retrieve(self, orch):
        job_id = orch.submit_job("test_pipeline", {"action": "process"})
        assert job_id is not None
        progress = orch.get_progress(job_id)
        assert progress["status"] == "pending"
        assert progress["pipeline"] == "test_pipeline"

    def test_submit_batch(self, orch):
        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
        assert len(ids) == 5

    def test_handler_execution(self, orch):
        results = []

        def handler(job):
            results.append(job.id)
            return {"status": "ok"}

        orch.register_handler("demo", handler)
        job_id = orch.submit_job("demo", {"action": "test"})
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(results) == 1

    def test_handler_failure_and_retry(self, orch):
        # Handler fails twice, succeeds on the third attempt; with
        # max_retries=3 the job must end up completed.
        attempts = []

        def handler(job):
            attempts.append(1)
            if len(attempts) < 3:
                raise ValueError("transient error")
            return {"status": "ok"}

        orch.register_handler("retry_test", handler)
        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(attempts) == 3

    def test_handler_exhausts_retries(self, orch):
        def handler(job):
            raise ValueError("permanent error")

        orch.register_handler("fail_test", handler)
        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "permanent error" in progress["error"]

    def test_no_handler(self, orch):
        job_id = orch.submit_job("nonexistent", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "No handler" in progress["error"]

    def test_token_budget_tracking(self, orch):
        # Token usage reported by the handler result must be recorded on
        # the job and surfaced via get_progress().
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}

        orch.register_handler("token_test", handler)
        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["token_usage"]["input_tokens"] == 500
        assert progress["token_usage"]["output_tokens"] == 200

    def test_token_budget_exceeded(self, orch):
        def handler(job):
            return {"status": "ok"}

        orch.register_handler("budget_test", handler)
        # Set job with already-exhausted budget by manipulating DB
        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
        job = orch.db.get_job(job_id)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
        orch.db.save_job(job)

        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "budget" in progress["error"].lower()

    def test_parallel_execution(self, orch):
        """Verify workers execute in parallel."""
        import threading

        # Handlers run on worker threads, so the shared set/counter must be
        # guarded by a lock — unsynchronized mutation would be a data race
        # and could undercount peak concurrency, making the test flaky.
        lock = threading.Lock()
        active = set()
        max_concurrent = [0]

        def handler(job):
            name = threading.current_thread().name
            with lock:
                active.add(name)
                max_concurrent[0] = max(max_concurrent[0], len(active))
            time.sleep(0.1)
            with lock:
                active.discard(name)
            return {"status": "ok"}

        orch.register_handler("parallel", handler)
        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
        orch.run(max_jobs=4)

        # With 2 workers, at least 2 should have been active simultaneously
        assert max_concurrent[0] >= 2

    def test_resume_paused_job(self, orch):
        """Test resume from checkpoint."""
        call_count = [0]

        def handler(job):
            call_count[0] += 1
            if call_count[0] == 1:
                # Simulate saving checkpoint before failure
                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
                orch.db.save_checkpoint(job.id, job.checkpoint)
                raise ValueError("first attempt fails")
            # Second attempt succeeds
            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}

        orch.register_handler("resume_test", handler)
        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)

        # First run — fails, saves checkpoint
        orch.run(max_jobs=1)

        # Manually resume (set to pending)
        job = orch.db.get_job(job_id)
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            orch.db.save_job(job)
            orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"

    def test_resume_on_restart(self, orch):
        """Test that run() resumes paused/running jobs with checkpoints on startup."""
        # Create a paused job with a checkpoint
        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
        job.status = JobStatus.PAUSED
        orch.db.save_job(job)
        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
            job_id="resume-on-start", step=3, data={"progress": "50%"}
        ))

        calls = []

        def handler(job):
            calls.append(job.checkpoint.step if job.checkpoint else None)
            return {"status": "ok"}

        orch.register_handler("restart_test", handler)
        orch.run(max_jobs=1)

        # Job should have been auto-resumed and executed
        progress = orch.get_progress("resume-on-start")
        assert progress["status"] == "completed"
        assert calls == [3]  # Handler saw checkpoint step 3

    def test_cancel_job(self, orch):
        job_id = orch.submit_job("cancel_test", {"action": "test"})
        orch.cancel_job(job_id)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "cancelled"

    def test_generate_report(self, orch):
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

        orch.register_handler("report_test", handler)
        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
        orch.run(max_jobs=3)

        report = orch.generate_report("report_test")
        assert report["completed"] == 3
        assert report["failed"] == 0
        assert report["success_rate"] == 100.0
        assert report["total_tokens"] == 450  # 3 * 150