# Pasted from a repository file viewer:
# timmy-config/pipelines/tests/test_orchestrator.py (334 lines, 12 KiB, Python)

#!/usr/bin/env python3
"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking."""
import json
import os
import tempfile
import time
from pathlib import Path
import pytest
# Add project root to path
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from pipelines.orchestrator import (
PipelineOrchestrator,
Job,
JobStatus,
JobPriority,
JobCheckpoint,
TokenUsage,
OrchestratorDB,
RateLimiter,
)
@pytest.fixture
def tmp_db(tmp_path):
    """Path for a fresh orchestrator DB file inside pytest's per-test ``tmp_path``."""
    db_file = tmp_path / "test_orchestrator.db"
    return db_file


@pytest.fixture
def orch(tmp_db):
    """PipelineOrchestrator (2 workers, token_budget=10000) backed by the ``tmp_db`` path."""
    orchestrator = PipelineOrchestrator(db_path=tmp_db, max_workers=2, token_budget=10000)
    return orchestrator


class TestJobDataModels:
    """Test Job, TokenUsage, JobCheckpoint dataclasses."""

    def test_token_usage_total(self):
        """total_tokens is input + output tokens."""
        usage = TokenUsage(input_tokens=100, output_tokens=50)
        assert usage.total_tokens == 150

    def test_token_usage_zero(self):
        """A default-constructed TokenUsage reports zero tokens."""
        usage = TokenUsage()
        assert usage.total_tokens == 0

    def test_token_usage_serialization(self):
        """Every constructor field survives a to_dict/from_dict round trip."""
        usage = TokenUsage(input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001)
        restored = TokenUsage.from_dict(usage.to_dict())
        assert restored.input_tokens == 10
        assert restored.output_tokens == 20
        assert restored.cache_read_tokens == 5
        assert restored.cost_usd == 0.001
        # cache_read_tokens is not part of the total: 10 + 20 == 30.
        assert restored.total_tokens == 30

    def test_checkpoint_serialization(self):
        """A checkpoint round-trips its job id, step and payload."""
        cp = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
        restored = JobCheckpoint.from_dict(cp.to_dict())
        assert restored.job_id == "abc"
        assert restored.step == 3
        assert restored.data == {"key": "val"}

    def test_job_serialization(self):
        """A job round-trips its identity and task, with PENDING/NORMAL defaults."""
        job = Job(id="test-1", pipeline="demo", task={"action": "run"})
        restored = Job.from_dict(job.to_dict())
        assert restored.id == "test-1"
        assert restored.pipeline == "demo"
        assert restored.task == {"action": "run"}
        assert restored.status == JobStatus.PENDING
        assert restored.priority == JobPriority.NORMAL


class TestOrchestratorDB:
    """Test SQLite-backed job queue."""

    def test_save_and_get(self, tmp_db):
        """A saved job can be fetched back by id with its task intact."""
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={"x": 1})
        db.save_job(job)
        fetched = db.get_job("j1")
        assert fetched is not None
        assert fetched.id == "j1"
        assert fetched.task == {"x": 1}

    def test_get_next_job_priority(self, tmp_db):
        """get_next_job hands out the highest-priority job first."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="low", pipeline="test", task={}, priority=JobPriority.LOW))
        db.save_job(Job(id="high", pipeline="test", task={}, priority=JobPriority.HIGH))
        db.save_job(Job(id="normal", pipeline="test", task={}, priority=JobPriority.NORMAL))
        next_job = db.get_next_job()
        # Check explicitly so an empty queue fails as an assertion,
        # not as an AttributeError on ``None.id``.
        assert next_job is not None
        assert next_job.id == "high"

    def test_get_next_job_pipeline_filter(self, tmp_db):
        """get_next_job(pipeline=...) only considers jobs from that pipeline."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="alpha", task={}))
        db.save_job(Job(id="b", pipeline="beta", task={}))
        next_job = db.get_next_job(pipeline="beta")
        assert next_job is not None
        assert next_job.id == "b"

    def test_get_jobs_by_status(self, tmp_db):
        """get_jobs_by_status returns only jobs in the requested status."""
        db = OrchestratorDB(tmp_db)
        db.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
        db.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
        pending = db.get_jobs_by_status(JobStatus.PENDING)
        assert len(pending) == 1
        assert pending[0].id == "a"

    def test_checkpoint_save_load(self, tmp_db):
        """A checkpoint saved under a job id loads back with the same step and data."""
        db = OrchestratorDB(tmp_db)
        cp = JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"})
        db.save_checkpoint("j1", cp)
        loaded = db.get_checkpoint("j1")
        assert loaded is not None
        assert loaded.step == 5
        assert loaded.data == {"progress": "50%"}

    def test_stats(self, tmp_db):
        """get_stats counts completed jobs and sums their token usage."""
        db = OrchestratorDB(tmp_db)
        job = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
        db.save_job(job)
        stats = db.get_stats()
        assert stats["completed"] == 1
        assert stats["total_tokens"] == 150


class TestRateLimiter:
    """Exercise RateLimiter's per-key request throttling."""

    def test_no_limit(self):
        # A key that was never configured is allowed through with no wait.
        allowed, delay = RateLimiter().can_proceed("unknown")
        assert allowed is True
        assert delay == 0.0

    def test_rpm_limit(self):
        limiter = RateLimiter()
        limiter.configure("test", requests_per_minute=2, tokens_per_minute=1000)
        # The first two requests fit within the 2-requests-per-minute limit...
        for _ in range(2):
            assert limiter.can_proceed("test")[0] is True
            limiter.record_request("test")
        # ...so the third is refused with a positive wait.
        allowed, delay = limiter.can_proceed("test")
        assert allowed is False
        assert delay > 0


class TestPipelineOrchestrator:
    """Test the main orchestrator."""

    def test_submit_and_retrieve(self, orch):
        """A freshly submitted job is reported as pending under its pipeline."""
        job_id = orch.submit_job("test_pipeline", {"action": "process"})
        assert job_id is not None
        progress = orch.get_progress(job_id)
        assert progress["status"] == "pending"
        assert progress["pipeline"] == "test_pipeline"

    def test_submit_batch(self, orch):
        """submit_batch returns one distinct job id per submitted task."""
        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
        assert len(ids) == 5
        # Ids are lookup keys (get_progress, cancel_job), so they must not collide.
        assert len(set(ids)) == 5

    def test_handler_execution(self, orch):
        """A registered handler runs once for the submitted job, which completes."""
        results = []

        def handler(job):
            results.append(job.id)
            return {"status": "ok"}

        orch.register_handler("demo", handler)
        job_id = orch.submit_job("demo", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert results == [job_id]

    def test_handler_failure_and_retry(self, orch):
        """Transient handler errors are retried until the handler succeeds."""
        attempts = []

        def handler(job):
            attempts.append(1)
            # Fail the first two attempts; succeed on the third.
            if len(attempts) < 3:
                raise ValueError("transient error")
            return {"status": "ok"}

        orch.register_handler("retry_test", handler)
        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(attempts) == 3

    def test_handler_exhausts_retries(self, orch):
        """A handler that always raises leaves the job failed with its error text."""
        def handler(job):
            raise ValueError("permanent error")

        orch.register_handler("fail_test", handler)
        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "permanent error" in progress["error"]

    def test_no_handler(self, orch):
        """A job whose pipeline has no registered handler fails with 'No handler'."""
        job_id = orch.submit_job("nonexistent", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "No handler" in progress["error"]

    def test_token_budget_tracking(self, orch):
        """Token usage reported in the handler's result is recorded on the job."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}

        orch.register_handler("token_test", handler)
        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["token_usage"]["input_tokens"] == 500
        assert progress["token_usage"]["output_tokens"] == 200

    def test_token_budget_exceeded(self, orch):
        """A job whose recorded usage already exceeds its token budget is failed."""
        def handler(job):
            return {"status": "ok"}

        orch.register_handler("budget_test", handler)
        # Pre-load 110 used tokens (100 in + 10 out) against a budget of 100
        # by editing the stored job directly.
        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
        job = orch.db.get_job(job_id)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
        orch.db.save_job(job)
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "budget" in progress["error"].lower()

    def test_parallel_execution(self, orch):
        """Verify workers execute in parallel."""
        import threading

        # Count in-flight handlers under a lock. An unsynchronized
        # read-max-write of a shared peak can lose an update when two
        # workers race, which would make this test flaky.
        lock = threading.Lock()
        in_flight = 0
        peak = 0

        def handler(job):
            nonlocal in_flight, peak
            with lock:
                in_flight += 1
                peak = max(peak, in_flight)
            try:
                # Hold the slot long enough for another worker to overlap.
                time.sleep(0.1)
            finally:
                with lock:
                    in_flight -= 1
            return {"status": "ok"}

        orch.register_handler("parallel", handler)
        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
        orch.run(max_jobs=4)
        # With 2 workers, at least 2 should have been active simultaneously
        assert peak >= 2

    def test_resume_paused_job(self, orch):
        """Test resume from checkpoint."""
        call_count = [0]

        def handler(job):
            call_count[0] += 1
            if call_count[0] == 1:
                # Simulate saving checkpoint before failure
                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
                orch.db.save_checkpoint(job.id, job.checkpoint)
                raise ValueError("first attempt fails")
            # Second attempt succeeds
            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}

        orch.register_handler("resume_test", handler)
        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)
        # First run — the first attempt fails after saving a checkpoint.
        orch.run(max_jobs=1)
        # If the orchestrator's own retries did not already finish the job,
        # resume it manually by resetting it to pending and running again.
        job = orch.db.get_job(job_id)
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            orch.db.save_job(job)
            orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        # Whichever path completed the job, the handler must have failed once
        # and then succeeded once, and must not run again after completion.
        assert call_count[0] == 2

    def test_resume_on_restart(self, orch):
        """Test that run() resumes paused/running jobs with checkpoints on startup."""
        # Create a paused job with a checkpoint
        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
        job.status = JobStatus.PAUSED
        orch.db.save_job(job)
        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
            job_id="resume-on-start", step=3, data={"progress": "50%"}
        ))
        calls = []

        def handler(job):
            calls.append(job.checkpoint.step if job.checkpoint else None)
            return {"status": "ok"}

        orch.register_handler("restart_test", handler)
        orch.run(max_jobs=1)
        # Job should have been auto-resumed and executed
        progress = orch.get_progress("resume-on-start")
        assert progress["status"] == "completed"
        assert calls == [3]  # Handler saw checkpoint step 3

    def test_cancel_job(self, orch):
        """Cancelling a submitted job marks it cancelled."""
        job_id = orch.submit_job("cancel_test", {"action": "test"})
        orch.cancel_job(job_id)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "cancelled"

    def test_generate_report(self, orch):
        """The per-pipeline report aggregates status counts and token totals."""
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

        orch.register_handler("report_test", handler)
        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
        orch.run(max_jobs=3)
        report = orch.generate_report("report_test")
        assert report["completed"] == 3
        assert report["failed"] == 0
        assert report["success_rate"] == 100.0
        assert report["total_tokens"] == 450  # 3 * 150