#!/usr/bin/env python3
"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking."""

import json
import os
import tempfile
import time
from pathlib import Path

import pytest

# Add project root to path
import sys

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from pipelines.orchestrator import (
    PipelineOrchestrator,
    Job,
    JobStatus,
    JobPriority,
    JobCheckpoint,
    TokenUsage,
    OrchestratorDB,
    RateLimiter,
)


@pytest.fixture
def tmp_db(tmp_path):
    """Path to a brand-new orchestrator database, isolated per test."""
    db_file = tmp_path / "test_orchestrator.db"
    return db_file


@pytest.fixture
def orch(tmp_db):
    """Orchestrator wired to the per-test temporary database."""
    instance = PipelineOrchestrator(max_workers=2, token_budget=10000, db_path=tmp_db)
    return instance


class TestJobDataModels:
    """Arithmetic and dict round-trip checks for Job, TokenUsage, JobCheckpoint."""

    def test_token_usage_total(self):
        u = TokenUsage(input_tokens=100, output_tokens=50)
        assert u.total_tokens == 150

    def test_token_usage_zero(self):
        # A default-constructed usage record has nothing counted yet.
        assert TokenUsage().total_tokens == 0

    def test_token_usage_serialization(self):
        original = TokenUsage(
            input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001
        )
        roundtrip = TokenUsage.from_dict(original.to_dict())
        assert roundtrip.input_tokens == 10
        assert roundtrip.total_tokens == 30

    def test_checkpoint_serialization(self):
        original = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
        roundtrip = JobCheckpoint.from_dict(original.to_dict())
        assert roundtrip.step == 3
        assert roundtrip.data == {"key": "val"}

    def test_job_serialization(self):
        original = Job(id="test-1", pipeline="demo", task={"action": "run"})
        roundtrip = Job.from_dict(original.to_dict())
        assert roundtrip.id == "test-1"
        # Status and priority should come back as the enum defaults.
        assert roundtrip.status == JobStatus.PENDING
        assert roundtrip.priority == JobPriority.NORMAL


class TestOrchestratorDB:
    """Exercise the SQLite-backed job queue directly."""

    def test_save_and_get(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        store.save_job(Job(id="j1", pipeline="test", task={"x": 1}))
        fetched = store.get_job("j1")
        assert fetched is not None
        assert fetched.id == "j1"
        assert fetched.task == {"x": 1}

    def test_get_next_job_priority(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        # Insert in a deliberately mixed order; HIGH must still win.
        for jid, prio in [
            ("low", JobPriority.LOW),
            ("high", JobPriority.HIGH),
            ("normal", JobPriority.NORMAL),
        ]:
            store.save_job(Job(id=jid, pipeline="test", task={}, priority=prio))
        assert store.get_next_job().id == "high"

    def test_get_next_job_pipeline_filter(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        store.save_job(Job(id="a", pipeline="alpha", task={}))
        store.save_job(Job(id="b", pipeline="beta", task={}))
        assert store.get_next_job(pipeline="beta").id == "b"

    def test_get_jobs_by_status(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        store.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
        store.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
        pending = store.get_jobs_by_status(JobStatus.PENDING)
        assert len(pending) == 1
        assert pending[0].id == "a"

    def test_checkpoint_save_load(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        store.save_checkpoint(
            "j1", JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"})
        )
        loaded = store.get_checkpoint("j1")
        assert loaded is not None
        assert loaded.step == 5
        assert loaded.data == {"progress": "50%"}

    def test_stats(self, tmp_db):
        store = OrchestratorDB(tmp_db)
        finished = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
        finished.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
        store.save_job(finished)
        stats = store.get_stats()
        assert stats["completed"] == 1
        assert stats["total_tokens"] == 150


class TestRateLimiter:
    """Behavior of the per-pipeline rate limiter."""

    def test_no_limit(self):
        limiter = RateLimiter()
        # Unconfigured pipelines are never throttled.
        allowed, wait = limiter.can_proceed("unknown")
        assert allowed is True
        assert wait == 0.0

    def test_rpm_limit(self):
        limiter = RateLimiter()
        limiter.configure("test", requests_per_minute=2, tokens_per_minute=1000)
        # The first two requests fit under the 2-per-minute cap.
        for _ in range(2):
            assert limiter.can_proceed("test")[0] is True
            limiter.record_request("test")
        # The third is refused with a positive wait time.
        allowed, wait = limiter.can_proceed("test")
        assert allowed is False
        assert wait > 0


class TestPipelineOrchestrator:
    """End-to-end tests of the orchestrator: submission, execution, retries,
    token budgets, parallelism, checkpoint resume, cancellation, reporting."""

    def test_submit_and_retrieve(self, orch):
        """A freshly submitted job is pending and reports its pipeline name."""
        job_id = orch.submit_job("test_pipeline", {"action": "process"})
        assert job_id is not None
        progress = orch.get_progress(job_id)
        assert progress["status"] == "pending"
        assert progress["pipeline"] == "test_pipeline"

    def test_submit_batch(self, orch):
        """Batch submission returns one job id per task."""
        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
        assert len(ids) == 5

    def test_handler_execution(self, orch):
        """A registered handler runs exactly once and the job completes."""
        results = []

        def handler(job):
            results.append(job.id)
            return {"status": "ok"}

        orch.register_handler("demo", handler)
        job_id = orch.submit_job("demo", {"action": "test"})
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(results) == 1

    def test_handler_failure_and_retry(self, orch):
        """Transient failures are retried until the handler succeeds."""
        attempts = []

        def handler(job):
            attempts.append(1)
            if len(attempts) < 3:
                raise ValueError("transient error")
            return {"status": "ok"}

        orch.register_handler("retry_test", handler)
        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(attempts) == 3  # two failures + one success

    def test_handler_exhausts_retries(self, orch):
        """A permanently failing handler leaves the job failed with its error."""
        def handler(job):
            raise ValueError("permanent error")

        orch.register_handler("fail_test", handler)
        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "permanent error" in progress["error"]

    def test_no_handler(self, orch):
        """A job for an unregistered pipeline fails with a clear error."""
        job_id = orch.submit_job("nonexistent", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "No handler" in progress["error"]

    def test_token_budget_tracking(self, orch):
        """Token usage reported by the handler is recorded on the job."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}

        orch.register_handler("token_test", handler)
        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["token_usage"]["input_tokens"] == 500
        assert progress["token_usage"]["output_tokens"] == 200

    def test_token_budget_exceeded(self, orch):
        """A job whose budget is already spent fails instead of running."""
        def handler(job):
            return {"status": "ok"}

        orch.register_handler("budget_test", handler)
        # Set job with already-exhausted budget by manipulating DB
        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
        job = orch.db.get_job(job_id)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
        orch.db.save_job(job)

        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "budget" in progress["error"].lower()

    def test_parallel_execution(self, orch):
        """Verify workers execute in parallel."""
        import threading

        # BUGFIX: the shared `active` set and `max_concurrent` cell are
        # mutated from multiple worker threads. Without a lock the
        # read-modify-write in `max(...)` can interleave and under-record
        # the true peak concurrency, making this test flaky.
        lock = threading.Lock()
        active = set()
        max_concurrent = [0]

        def handler(job):
            me = threading.current_thread().name
            with lock:
                active.add(me)
                max_concurrent[0] = max(max_concurrent[0], len(active))
            time.sleep(0.1)
            with lock:
                active.discard(me)
            return {"status": "ok"}

        orch.register_handler("parallel", handler)
        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
        orch.run(max_jobs=4)

        # With 2 workers, at least 2 should have been active simultaneously
        assert max_concurrent[0] >= 2

    def test_resume_paused_job(self, orch):
        """Test resume from checkpoint."""
        call_count = [0]

        def handler(job):
            call_count[0] += 1
            if call_count[0] == 1:
                # Simulate saving checkpoint before failure
                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
                orch.db.save_checkpoint(job.id, job.checkpoint)
                raise ValueError("first attempt fails")
            # Second attempt succeeds
            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}

        orch.register_handler("resume_test", handler)
        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)

        # First run — fails, saves checkpoint
        orch.run(max_jobs=1)

        # Manually resume (set to pending)
        job = orch.db.get_job(job_id)
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            orch.db.save_job(job)
            orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"

    def test_resume_on_restart(self, orch):
        """Test that run() resumes paused/running jobs with checkpoints on startup."""
        # Create a paused job with a checkpoint
        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
        job.status = JobStatus.PAUSED
        orch.db.save_job(job)
        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
            job_id="resume-on-start", step=3, data={"progress": "50%"}
        ))

        calls = []

        def handler(job):
            calls.append(job.checkpoint.step if job.checkpoint else None)
            return {"status": "ok"}

        orch.register_handler("restart_test", handler)
        orch.run(max_jobs=1)

        # Job should have been auto-resumed and executed
        progress = orch.get_progress("resume-on-start")
        assert progress["status"] == "completed"
        assert calls == [3]  # Handler saw checkpoint step 3

    def test_cancel_job(self, orch):
        """Cancelling a pending job moves it to the cancelled state."""
        job_id = orch.submit_job("cancel_test", {"action": "test"})
        orch.cancel_job(job_id)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "cancelled"

    def test_generate_report(self, orch):
        """Per-pipeline report aggregates counts, success rate, and tokens."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

        orch.register_handler("report_test", handler)
        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
        orch.run(max_jobs=3)

        report = orch.generate_report("report_test")
        assert report["completed"] == 3
        assert report["failed"] == 0
        assert report["success_rate"] == 100.0
        assert report["total_tokens"] == 450  # 3 * 150