Compare commits
5 Commits
feat/647-s
...
burn/621-s
| Author | SHA1 | Date | |
|---|---|---|---|
| 15713958e6 | |||
| 776597712f | |||
| 3250eba0cc | |||
| 99d4facdad | |||
| c808c4efb3 |
@@ -22,13 +22,187 @@ import json
|
||||
import base64
import hashlib
import math
import os
import re
import struct
import sys
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import List, Optional, Dict, Any, Set
|
||||
|
||||
PIPELINE_DIR = Path.home() / ".hermes" / "pipeline"
|
||||
STATS_FILE = PIPELINE_DIR / "quality_stats.json"
|
||||
HASH_DIR = PIPELINE_DIR / "quality_hashes"
|
||||
HASH_RETENTION_DAYS = 7 # Keep hashes for 7 days
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Bloom Filter — Memory-efficient dedup at scale
|
||||
# ============================================================
|
||||
|
||||
class BloomFilter:
    """Probabilistic set for membership testing. False positives possible, no false negatives.

    The bit-array size and hash count are derived from the requested capacity
    and target error rate using the standard optimal-parameter formulas:
        m = -n * ln(p) / (ln 2)^2      (bits)
        k = (m / n) * ln(2)            (hash functions)
    """

    def __init__(self, capacity: int = 100_000, error_rate: float = 0.01):
        """
        Args:
            capacity: Expected number of distinct items.
            error_rate: Target false-positive probability at full capacity.
        """
        self.capacity = capacity
        self.error_rate = error_rate
        # Optimal size and hash count, floored (64 bits / 1 hash) so tiny
        # capacities still yield a usable filter.
        self.size = max(64, int(-capacity * math.log(error_rate) / (math.log(2) ** 2)))
        self.num_hashes = max(1, int(self.size / capacity * math.log(2)))
        self._bitarray = bytearray((self.size + 7) // 8)

    def _hash_indices(self, item: str) -> List[int]:
        """Generate bit indices using double hashing.

        Two independent digests combined as (h1 + i*h2) simulate k hash
        functions (Kirsch–Mitzenmacher). MD5 is used purely as a cheap
        mixing function here, not for any security property.
        """
        h1 = int.from_bytes(hashlib.sha256(item.encode()).digest()[:8], "little")
        h2 = int.from_bytes(hashlib.md5(item.encode()).digest()[:8], "little")
        return [(h1 + i * h2) % self.size for i in range(self.num_hashes)]

    def add(self, item: str):
        """Mark *item* as present by setting all of its bit positions."""
        for idx in self._hash_indices(item):
            self._bitarray[idx // 8] |= 1 << (idx % 8)

    def __contains__(self, item: str) -> bool:
        """True if *item* may have been added (subject to false positives)."""
        return all(self._bitarray[idx // 8] & (1 << (idx % 8)) for idx in self._hash_indices(item))

    def to_dict(self) -> dict:
        """Serialize to a JSON-safe dict; the bit array is base64-encoded.

        NOTE: requires `import base64` at module level — the original import
        block was missing it (fixed alongside this change).
        """
        return {
            "capacity": self.capacity,
            "error_rate": self.error_rate,
            "size": self.size,
            "num_hashes": self.num_hashes,
            "data": base64.b64encode(bytes(self._bitarray)).decode(),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "BloomFilter":
        """Reconstruct a filter from :meth:`to_dict` output.

        Raises:
            KeyError: if a required field is missing — callers treat this
                as a corrupt file and rebuild the filter.
        """
        bf = cls(capacity=d["capacity"], error_rate=d["error_rate"])
        bf._bitarray = bytearray(base64.b64decode(d["data"]))
        return bf
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Hash Dedup Store — Rotating daily files + bloom filter
|
||||
# ============================================================
|
||||
|
||||
class HashDedupStore:
    """Rotating hash store for cross-run deduplication.

    Strategy:
    - Daily JSON files: HASH_DIR/YYYY-MM-DD.json (set of 16-char hashes)
    - Bloom filter: HASH_DIR/bloom.json (memory-efficient for large scale)
    - On load: merge last N days into bloom filter
    - Rotation: delete files older than HASH_RETENTION_DAYS
    """

    def __init__(self, retention_days: int = HASH_RETENTION_DAYS):
        """Create the store, ensure HASH_DIR exists, and load persisted state.

        Args:
            retention_days: How many days of daily hash files to keep and
                merge into the bloom filter.
        """
        self.retention_days = retention_days
        HASH_DIR.mkdir(parents=True, exist_ok=True)
        # Day key computed in UTC so rotation boundaries don't depend on
        # the host's local timezone.
        self._today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        self._daily_hashes: Set[str] = set()
        self._bloom: Optional[BloomFilter] = None
        self._load()

    def _day_file(self, day: str) -> Path:
        """Path of the daily hash file for a YYYY-MM-DD day key."""
        return HASH_DIR / f"{day}.json"

    def _bloom_file(self) -> Path:
        """Path of the persisted bloom filter."""
        return HASH_DIR / "bloom.json"

    def _load(self):
        """Load today's hashes and bloom filter."""
        # Load today's file
        day_path = self._day_file(self._today)
        if day_path.exists():
            try:
                self._daily_hashes = set(json.loads(day_path.read_text()))
            except (json.JSONDecodeError, IOError):
                # Corrupt/unreadable daily file: start the day empty
                # rather than crash the pipeline.
                self._daily_hashes = set()

        # Load or rebuild bloom filter
        bloom_path = self._bloom_file()
        if bloom_path.exists():
            try:
                self._bloom = BloomFilter.from_dict(json.loads(bloom_path.read_text()))
            except (json.JSONDecodeError, IOError, KeyError):
                # KeyError covers a structurally-corrupt dict from from_dict.
                self._bloom = None

        if self._bloom is None:
            # Missing or corrupt bloom file: rebuild from the daily files.
            self._rebuild_bloom()

    def _rebuild_bloom(self):
        """Rebuild bloom filter from all recent daily files."""
        hashes = set()
        for day_offset in range(self.retention_days):
            day = (datetime.now(timezone.utc) - timedelta(days=day_offset)).strftime("%Y-%m-%d")
            day_path = self._day_file(day)
            if day_path.exists():
                try:
                    hashes.update(json.loads(day_path.read_text()))
                except (json.JSONDecodeError, IOError):
                    # Best-effort: skip unreadable daily files.
                    pass

        # Size with 2x headroom (min 10K) so the filter stays below its
        # configured error rate as more hashes arrive.
        capacity = max(len(hashes) * 2, 10_000)
        self._bloom = BloomFilter(capacity=capacity)
        for h in hashes:
            self._bloom.add(h)

    def _save(self):
        """Persist today's hashes and bloom filter."""
        day_path = self._day_file(self._today)
        # Sorted for stable, diff-friendly on-disk output.
        day_path.write_text(json.dumps(sorted(self._daily_hashes)))

        if self._bloom:
            self._bloom_file().write_text(json.dumps(self._bloom.to_dict()))

    def _rotate(self):
        """Delete daily hash files older than retention period."""
        cutoff = (datetime.now(timezone.utc) - timedelta(days=self.retention_days)).strftime("%Y-%m-%d")
        for path in HASH_DIR.glob("*.json"):
            name = path.stem
            # len == 10 matches only YYYY-MM-DD stems; lexicographic "<"
            # is a correct date comparison for ISO-formatted dates. The
            # "bloom" guard is redundant (its stem has length 5) but kept
            # as an explicit safety net.
            if len(name) == 10 and name < cutoff and name != "bloom":
                path.unlink()

    def is_duplicate(self, h: str) -> bool:
        """Check if hash has been seen in current day or bloom filter.

        NOTE: a bloom-filter hit can be a false positive, so this may
        (rarely) report a never-seen hash as a duplicate; it never misses
        a hash that was actually added.
        """
        if h in self._daily_hashes:
            return True
        if self._bloom and h in self._bloom:
            return True
        return False

    def add(self, h: str):
        """Add a hash. Saves and rotates periodically."""
        self._daily_hashes.add(h)
        if self._bloom:
            self._bloom.add(h)
        # Save every 100 additions or on explicit call.
        # NOTE(review): re-adding an existing hash leaves the set size
        # unchanged, so this modulo trigger can fire repeatedly (or be
        # skipped) — persistence is best-effort; flush() is the guarantee.
        if len(self._daily_hashes) % 100 == 0:
            self._save()
            self._rotate()

    def flush(self):
        """Force save and rotate."""
        self._save()
        self._rotate()

    def stats(self) -> dict:
        """Return dedup store statistics.

        NOTE: total_files counts every *.json in HASH_DIR (including
        bloom.json), while total_hashes sums only the YYYY-MM-DD daily
        files, so the two are not directly comparable.
        """
        file_count = len(list(HASH_DIR.glob("*.json")))
        total_hashes = 0
        for path in HASH_DIR.glob("????-??-??.json"):
            try:
                total_hashes += len(json.loads(path.read_text()))
            except Exception:
                pass
        return {
            "today_count": len(self._daily_hashes),
            "total_files": file_count,
            "total_hashes": total_hashes,
            "retention_days": self.retention_days,
            "bloom_size": self._bloom.size if self._bloom else 0,
        }
|
||||
|
||||
|
||||
# NOTE(review): redundant — STATS_FILE is already defined above as
# PIPELINE_DIR / "quality_stats.json" with the same value; consider
# removing this duplicate definition.
STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
|
||||
|
||||
# --- Quality Check Types ---
|
||||
|
||||
@@ -228,8 +402,14 @@ CHECK_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
"""Run quality gate on a JSONL file."""
|
||||
def run_gate(input_path: str, entry_type: str, dedup_store: Optional[HashDedupStore] = None) -> GateReport:
|
||||
"""Run quality gate on a JSONL file.
|
||||
|
||||
Args:
|
||||
input_path: Path to JSONL file
|
||||
entry_type: Type of entries (training_pairs, scene_descriptions, etc.)
|
||||
dedup_store: Optional hash dedup store for cross-run dedup. If None, creates one.
|
||||
"""
|
||||
path = Path(input_path)
|
||||
if not path.exists():
|
||||
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)
|
||||
@@ -239,6 +419,9 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
|
||||
rejected_indices=[-1]) # unknown type
|
||||
|
||||
if dedup_store is None:
|
||||
dedup_store = HashDedupStore()
|
||||
|
||||
entries = []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
@@ -246,7 +429,7 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
if line:
|
||||
entries.append(json.loads(line))
|
||||
|
||||
# Deduplication check
|
||||
# Within-file deduplication check
|
||||
key_fields = _get_key_fields(entry_type)
|
||||
dup_errors = check_no_duplicates(entries, key_fields)
|
||||
|
||||
@@ -254,13 +437,22 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
rejected = 0
|
||||
rejected_indices = []
|
||||
total_score = 0.0
|
||||
cross_run_dupes = 0
|
||||
|
||||
for i, entry in enumerate(entries):
|
||||
errors = check_fn(entry)
|
||||
|
||||
# Add duplicate errors
|
||||
# Add within-file duplicate errors
|
||||
if i in dup_errors:
|
||||
errors.extend(dup_errors[i])
|
||||
|
||||
# Cross-run hash dedup
|
||||
h = entry_hash(entry)
|
||||
if dedup_store.is_duplicate(h):
|
||||
errors.append(f"cross_run_duplicate: hash {h} seen in prior run")
|
||||
cross_run_dupes += 1
|
||||
else:
|
||||
dedup_store.add(h)
|
||||
|
||||
# Add SOUL compliance check for text content
|
||||
text_content = ""
|
||||
@@ -286,6 +478,9 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
|
||||
avg_score = total_score / len(entries) if entries else 0.0
|
||||
|
||||
# Flush dedup store
|
||||
dedup_store.flush()
|
||||
|
||||
report = GateReport(
|
||||
file=str(path),
|
||||
type=entry_type,
|
||||
@@ -299,6 +494,10 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
# Save stats
|
||||
_save_stats(report)
|
||||
|
||||
if cross_run_dupes > 0:
|
||||
logger_msg = f" cross-run dedup: {cross_run_dupes} duplicates found"
|
||||
print(logger_msg, file=sys.stderr)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
@@ -318,7 +517,7 @@ def _get_key_fields(entry_type: str) -> List[str]:
|
||||
|
||||
|
||||
def _save_stats(report: GateReport):
|
||||
"""Append quality stats to the stats file."""
|
||||
"""Append quality stats to the stats file. Rotates to keep last 1000."""
|
||||
STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stats = []
|
||||
@@ -331,8 +530,9 @@ def _save_stats(report: GateReport):
|
||||
|
||||
stats.append(report.to_dict())
|
||||
|
||||
# Keep last 1000 entries
|
||||
stats = stats[-1000:]
|
||||
# Rotate: keep last 1000 entries
|
||||
if len(stats) > 1000:
|
||||
stats = stats[-1000:]
|
||||
|
||||
with open(STATS_FILE, "w") as f:
|
||||
json.dump(stats, f, indent=2)
|
||||
|
||||
22
pipelines/__init__.py
Normal file
22
pipelines/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Pipeline infrastructure — shared orchestrator."""
|
||||
from .orchestrator import (
|
||||
PipelineOrchestrator,
|
||||
OrchestratorDB,
|
||||
Job,
|
||||
JobStatus,
|
||||
JobPriority,
|
||||
JobCheckpoint,
|
||||
TokenUsage,
|
||||
RateLimiter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PipelineOrchestrator",
|
||||
"OrchestratorDB",
|
||||
"Job",
|
||||
"JobStatus",
|
||||
"JobPriority",
|
||||
"JobCheckpoint",
|
||||
"TokenUsage",
|
||||
"RateLimiter",
|
||||
]
|
||||
@@ -574,33 +574,67 @@ class PipelineOrchestrator:
|
||||
return job
|
||||
|
||||
def run(self, pipeline: Optional[str] = None, max_jobs: Optional[int] = None):
|
||||
"""Run the orchestrator, processing jobs from the queue."""
|
||||
"""Run the orchestrator, processing jobs from the queue.
|
||||
|
||||
On startup, checks for paused/running jobs with checkpoints and
|
||||
resumes them first before picking up new pending jobs.
|
||||
"""
|
||||
self.running = True
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
futures = {}
|
||||
|
||||
logger.info(f"Orchestrator starting (workers={self.max_workers})")
|
||||
|
||||
# Resume paused jobs with checkpoints on restart
|
||||
for status in (JobStatus.PAUSED, JobStatus.RUNNING):
|
||||
for job in self.db.get_jobs_by_status(status, pipeline):
|
||||
if job.checkpoint:
|
||||
logger.info(f"Resuming {status.value} job {job.id} from checkpoint step {job.checkpoint.step}")
|
||||
job.status = JobStatus.PENDING
|
||||
self.db.save_job(job)
|
||||
|
||||
try:
|
||||
jobs_processed = 0
|
||||
|
||||
while self.running:
|
||||
# Check completed futures
|
||||
done = [f for f in futures if f.done()]
|
||||
for f in done:
|
||||
try:
|
||||
f.result() # propagate exceptions
|
||||
except Exception as e:
|
||||
logger.error(f"Worker error: {e}")
|
||||
del futures[f]
|
||||
|
||||
# Throttle if at capacity
|
||||
if len(futures) >= self.max_workers:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Get next job
|
||||
job = self.db.get_next_job(pipeline)
|
||||
|
||||
if not job:
|
||||
# No pending jobs, wait a bit
|
||||
time.sleep(1)
|
||||
if not futures:
|
||||
# No jobs and no workers — done
|
||||
break
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
# Submit to thread pool
|
||||
future = self.executor.submit(self._execute_job, job)
|
||||
|
||||
# Don't wait for completion, get next job
|
||||
futures[future] = job.id
|
||||
jobs_processed += 1
|
||||
|
||||
if max_jobs and jobs_processed >= max_jobs:
|
||||
logger.info(f"Reached max_jobs limit ({max_jobs})")
|
||||
break
|
||||
|
||||
# Wait for remaining futures
|
||||
for f in futures:
|
||||
try:
|
||||
f.result(timeout=300)
|
||||
except Exception as e:
|
||||
logger.error(f"Worker error on drain: {e}")
|
||||
|
||||
finally:
|
||||
self.executor.shutdown(wait=True)
|
||||
@@ -735,6 +769,11 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="Pipeline Orchestrator")
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
|
||||
# List jobs
|
||||
list_parser = subparsers.add_parser("list", help="List jobs")
|
||||
list_parser.add_argument("--status", help="Filter by status")
|
||||
list_parser.add_argument("--pipeline", help="Filter by pipeline")
|
||||
|
||||
# Submit job
|
||||
submit_parser = subparsers.add_parser("submit", help="Submit a job")
|
||||
submit_parser.add_argument("pipeline", help="Pipeline name")
|
||||
@@ -779,6 +818,23 @@ def main():
|
||||
elif args.command == "run":
|
||||
orchestrator.run(args.pipeline, args.max_jobs)
|
||||
|
||||
elif args.command == "list":
|
||||
status_filter = JobStatus(args.status) if args.status else None
|
||||
if status_filter:
|
||||
jobs = orchestrator.db.get_jobs_by_status(status_filter, args.pipeline)
|
||||
else:
|
||||
# Show all jobs
|
||||
conn = orchestrator.db._get_conn()
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY priority DESC, created_at ASC").fetchall()
|
||||
conn.close()
|
||||
jobs = [orchestrator.db.get_job(row['id']) for row in rows]
|
||||
for job in jobs:
|
||||
dur = ""
|
||||
if job.started_at and job.completed_at:
|
||||
dur = f" ({job.completed_at - job.started_at:.1f}s)"
|
||||
print(f" {job.id[:8]} {job.status.value:10s} p{job.priority.value} {job.pipeline} tokens={job.token_usage.total_tokens}{dur}")
|
||||
print(f"\n{len(jobs)} jobs")
|
||||
|
||||
elif args.command == "status":
|
||||
progress = orchestrator.get_progress(args.job_id)
|
||||
print(json.dumps(progress, indent=2))
|
||||
|
||||
333
pipelines/tests/test_orchestrator.py
Normal file
333
pipelines/tests/test_orchestrator.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add project root to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from pipelines.orchestrator import (
|
||||
PipelineOrchestrator,
|
||||
Job,
|
||||
JobStatus,
|
||||
JobPriority,
|
||||
JobCheckpoint,
|
||||
TokenUsage,
|
||||
OrchestratorDB,
|
||||
RateLimiter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_db(tmp_path):
    """Path to a throwaway orchestrator database, unique per test."""
    db_path = tmp_path / "test_orchestrator.db"
    return db_path
|
||||
|
||||
|
||||
@pytest.fixture
def orch(tmp_db):
    """PipelineOrchestrator wired to the per-test temp database."""
    orchestrator = PipelineOrchestrator(max_workers=2, token_budget=10000, db_path=tmp_db)
    return orchestrator
|
||||
|
||||
|
||||
class TestJobDataModels:
    """Exercise the Job, TokenUsage, and JobCheckpoint dataclasses."""

    def test_token_usage_total(self):
        tu = TokenUsage(input_tokens=100, output_tokens=50)
        assert tu.total_tokens == 150

    def test_token_usage_zero(self):
        # A default-constructed TokenUsage reports zero total tokens.
        assert TokenUsage().total_tokens == 0

    def test_token_usage_serialization(self):
        original = TokenUsage(input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001)
        round_tripped = TokenUsage.from_dict(original.to_dict())
        assert round_tripped.input_tokens == 10
        assert round_tripped.total_tokens == 30

    def test_checkpoint_serialization(self):
        original = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
        round_tripped = JobCheckpoint.from_dict(original.to_dict())
        assert round_tripped.step == 3
        assert round_tripped.data == {"key": "val"}

    def test_job_serialization(self):
        # A fresh Job round-trips with its default status and priority intact.
        original = Job(id="test-1", pipeline="demo", task={"action": "run"})
        round_tripped = Job.from_dict(original.to_dict())
        assert round_tripped.id == "test-1"
        assert round_tripped.status == JobStatus.PENDING
        assert round_tripped.priority == JobPriority.NORMAL
|
||||
|
||||
|
||||
class TestOrchestratorDB:
    """Exercise the SQLite-backed job queue."""

    def test_save_and_get(self, tmp_db):
        queue = OrchestratorDB(tmp_db)
        queue.save_job(Job(id="j1", pipeline="test", task={"x": 1}))
        loaded = queue.get_job("j1")
        assert loaded is not None
        assert loaded.id == "j1"
        assert loaded.task == {"x": 1}

    def test_get_next_job_priority(self, tmp_db):
        # Highest priority wins regardless of insertion order.
        queue = OrchestratorDB(tmp_db)
        for job_id, prio in (("low", JobPriority.LOW),
                             ("high", JobPriority.HIGH),
                             ("normal", JobPriority.NORMAL)):
            queue.save_job(Job(id=job_id, pipeline="test", task={}, priority=prio))
        assert queue.get_next_job().id == "high"

    def test_get_next_job_pipeline_filter(self, tmp_db):
        queue = OrchestratorDB(tmp_db)
        queue.save_job(Job(id="a", pipeline="alpha", task={}))
        queue.save_job(Job(id="b", pipeline="beta", task={}))
        assert queue.get_next_job(pipeline="beta").id == "b"

    def test_get_jobs_by_status(self, tmp_db):
        queue = OrchestratorDB(tmp_db)
        queue.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
        queue.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
        pending_jobs = queue.get_jobs_by_status(JobStatus.PENDING)
        assert len(pending_jobs) == 1
        assert pending_jobs[0].id == "a"

    def test_checkpoint_save_load(self, tmp_db):
        queue = OrchestratorDB(tmp_db)
        queue.save_checkpoint("j1", JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"}))
        restored = queue.get_checkpoint("j1")
        assert restored is not None
        assert restored.step == 5
        assert restored.data == {"progress": "50%"}

    def test_stats(self, tmp_db):
        queue = OrchestratorDB(tmp_db)
        done = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
        done.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
        queue.save_job(done)
        summary = queue.get_stats()
        assert summary["completed"] == 1
        assert summary["total_tokens"] == 150
|
||||
|
||||
|
||||
class TestRateLimiter:
    """Exercise the request/token rate limiter."""

    def test_no_limit(self):
        # An unconfigured key is never throttled.
        limiter = RateLimiter()
        allowed, wait_time = limiter.can_proceed("unknown")
        assert allowed is True
        assert wait_time == 0.0

    def test_rpm_limit(self):
        limiter = RateLimiter()
        limiter.configure("test", requests_per_minute=2, tokens_per_minute=1000)
        # First two requests fit within the 2-rpm budget.
        for _ in range(2):
            assert limiter.can_proceed("test")[0] is True
            limiter.record_request("test")
        # Third request is throttled with a positive wait hint.
        allowed, wait_time = limiter.can_proceed("test")
        assert allowed is False
        assert wait_time > 0
|
||||
|
||||
|
||||
class TestPipelineOrchestrator:
    """Test the main orchestrator."""

    def test_submit_and_retrieve(self, orch):
        # A submitted job is immediately queryable and starts as pending.
        job_id = orch.submit_job("test_pipeline", {"action": "process"})
        assert job_id is not None
        progress = orch.get_progress(job_id)
        assert progress["status"] == "pending"
        assert progress["pipeline"] == "test_pipeline"

    def test_submit_batch(self, orch):
        # Batch submission returns one id per task.
        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
        assert len(ids) == 5

    def test_handler_execution(self, orch):
        # The registered handler runs exactly once and completes the job.
        results = []
        def handler(job):
            results.append(job.id)
            return {"status": "ok"}

        orch.register_handler("demo", handler)
        job_id = orch.submit_job("demo", {"action": "test"})
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(results) == 1

    def test_handler_failure_and_retry(self, orch):
        # Handler fails twice, succeeds on the third attempt — within
        # max_retries=3, so the job ends completed.
        attempts = []
        def handler(job):
            attempts.append(1)
            if len(attempts) < 3:
                raise ValueError("transient error")
            return {"status": "ok"}

        orch.register_handler("retry_test", handler)
        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"
        assert len(attempts) == 3

    def test_handler_exhausts_retries(self, orch):
        # A handler that always raises exhausts retries and fails the job,
        # preserving the exception message in the job error.
        def handler(job):
            raise ValueError("permanent error")

        orch.register_handler("fail_test", handler)
        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "permanent error" in progress["error"]

    def test_no_handler(self, orch):
        # Jobs for pipelines without a registered handler fail explicitly.
        job_id = orch.submit_job("nonexistent", {"action": "test"})
        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "No handler" in progress["error"]

    def test_token_budget_tracking(self, orch):
        # Token usage returned by the handler is recorded on the job.
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}

        orch.register_handler("token_test", handler)
        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
        orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["token_usage"]["input_tokens"] == 500
        assert progress["token_usage"]["output_tokens"] == 200

    def test_token_budget_exceeded(self, orch):
        def handler(job):
            return {"status": "ok"}

        orch.register_handler("budget_test", handler)
        # Set job with already-exhausted budget by manipulating DB
        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
        job = orch.db.get_job(job_id)
        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
        orch.db.save_job(job)

        orch.run(max_jobs=1)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "failed"
        assert "budget" in progress["error"].lower()

    def test_parallel_execution(self, orch):
        """Verify workers execute in parallel."""
        import threading
        # Mutable cell so the handler closure can record the peak
        # concurrency it observed across worker threads.
        active = set()
        max_concurrent = [0]

        def handler(job):
            active.add(threading.current_thread().name)
            max_concurrent[0] = max(max_concurrent[0], len(active))
            # Sleep long enough for the second worker to overlap.
            time.sleep(0.1)
            active.discard(threading.current_thread().name)
            return {"status": "ok"}

        orch.register_handler("parallel", handler)
        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
        orch.run(max_jobs=4)

        # With 2 workers, at least 2 should have been active simultaneously
        assert max_concurrent[0] >= 2

    def test_resume_paused_job(self, orch):
        """Test resume from checkpoint."""
        call_count = [0]

        def handler(job):
            call_count[0] += 1
            if call_count[0] == 1:
                # Simulate saving checkpoint before failure
                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
                orch.db.save_checkpoint(job.id, job.checkpoint)
                raise ValueError("first attempt fails")
            # Second attempt succeeds
            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}

        orch.register_handler("resume_test", handler)
        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)

        # First run — fails, saves checkpoint
        orch.run(max_jobs=1)

        # Manually resume (set to pending)
        # NOTE(review): retries may already have completed the job during
        # the first run; this branch only fires if it ended failed.
        job = orch.db.get_job(job_id)
        if job.status == JobStatus.FAILED:
            job.status = JobStatus.PENDING
            job.retry_count = 0
            job.error = None
            orch.db.save_job(job)
            orch.run(max_jobs=1)

        progress = orch.get_progress(job_id)
        assert progress["status"] == "completed"

    def test_resume_on_restart(self, orch):
        """Test that run() resumes paused/running jobs with checkpoints on startup."""
        # Create a paused job with a checkpoint
        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
        job.status = JobStatus.PAUSED
        orch.db.save_job(job)
        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
            job_id="resume-on-start", step=3, data={"progress": "50%"}
        ))

        calls = []
        def handler(job):
            calls.append(job.checkpoint.step if job.checkpoint else None)
            return {"status": "ok"}

        orch.register_handler("restart_test", handler)
        orch.run(max_jobs=1)

        # Job should have been auto-resumed and executed
        progress = orch.get_progress("resume-on-start")
        assert progress["status"] == "completed"
        assert calls == [3]  # Handler saw checkpoint step 3

    def test_cancel_job(self, orch):
        # Cancelling a pending job transitions it to cancelled.
        job_id = orch.submit_job("cancel_test", {"action": "test"})
        orch.cancel_job(job_id)
        progress = orch.get_progress(job_id)
        assert progress["status"] == "cancelled"

    def test_generate_report(self, orch):
        # Three successful jobs aggregate into a 100% success report with
        # summed token usage.
        def handler(job):
            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

        orch.register_handler("report_test", handler)
        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
        orch.run(max_jobs=3)

        report = orch.generate_report("report_test")
        assert report["completed"] == 3
        assert report["failed"] == 0
        assert report["success_rate"] == 100.0
        assert report["total_tokens"] == 450  # 3 * 150
|
||||
@@ -465,3 +465,194 @@ def test_ci_gate_on_actual_ci_automation_gate():
|
||||
gate = QualityGate()
|
||||
gate.check_file(gate_path)
|
||||
assert gate.failures == 0, f"ci_automation_gate.py should pass quality gate, got {gate.failures} failures"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# BLOOM FILTER + HASH DEDUP TESTS (Issue #628)
|
||||
# ===========================================================================
|
||||
|
||||
import sys, os
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline"))
|
||||
from quality_gate import BloomFilter, HashDedupStore, HASH_DIR, entry_hash
|
||||
|
||||
|
||||
class TestBloomFilter:
    """Behavioral tests for the BloomFilter used by cross-run dedup."""

    def test_empty_bloom_no_contains(self):
        # A fresh filter reports nothing as present.
        assert "hello" not in BloomFilter(capacity=100)

    def test_add_then_contains(self):
        filt = BloomFilter(capacity=100)
        filt.add("hello")
        assert "hello" in filt

    def test_false_negatives_impossible(self):
        """No false negatives — every added item is found."""
        filt = BloomFilter(capacity=1000)
        members = [f"item-{i}" for i in range(500)]
        for member in members:
            filt.add(member)
        for member in members:
            assert member in filt, f"False negative for {member}"

    def test_false_positive_rate(self):
        """False positive rate should be under the configured error rate."""
        filt = BloomFilter(capacity=1000, error_rate=0.01)
        members = {f"added-{i}" for i in range(1000)}
        for member in members:
            filt.add(member)
        trials = 10000
        hits = 0
        for i in range(trials):
            probe = f"not-added-{i}"
            if probe not in members and probe in filt:
                hits += 1
        fp_rate = hits / trials
        assert fp_rate < 0.05, f"FP rate {fp_rate:.3%} too high (expected <5%)"

    def test_serialization_roundtrip(self):
        # to_dict/from_dict must preserve membership exactly.
        filt = BloomFilter(capacity=100)
        filt.add("alpha")
        filt.add("beta")
        restored = BloomFilter.from_dict(filt.to_dict())
        assert "alpha" in restored
        assert "beta" in restored
        assert "gamma" not in restored
|
||||
|
||||
|
||||
class TestHashDedupStore:
    # Each test redirects the module-level HASH_DIR global into tmp_path so
    # the store reads/writes under a per-test directory, and restores it in
    # a finally block. HashDedupStore looks HASH_DIR up at call time, so
    # rebinding the module attribute is sufficient.

    def test_first_seen_not_duplicate(self, tmp_path):
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        try:
            store = HashDedupStore()
            # Never-added hash must not be reported as a duplicate.
            assert not store.is_duplicate("abc123")
        finally:
            qg.HASH_DIR = old_hash_dir

    def test_after_add_is_duplicate(self, tmp_path):
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        try:
            store = HashDedupStore()
            store.add("abc123")
            store.flush()
            assert store.is_duplicate("abc123")
        finally:
            qg.HASH_DIR = old_hash_dir

    def test_different_hash_not_duplicate(self, tmp_path):
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        try:
            store = HashDedupStore()
            store.add("abc123")
            store.flush()
            # Unrelated hash stays non-duplicate (bloom FP chance is tiny here).
            assert not store.is_duplicate("xyz789")
        finally:
            qg.HASH_DIR = old_hash_dir

    def test_rotation_deletes_old_files(self, tmp_path):
        """Files older than retention_days should be deleted."""
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        qg.HASH_DIR.mkdir(parents=True, exist_ok=True)
        try:
            # Create old file
            old_date = "2020-01-01"
            (qg.HASH_DIR / f"{old_date}.json").write_text('["old_hash"]')
            # Create today's file
            from datetime import datetime, timezone
            today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
            (qg.HASH_DIR / f"{today}.json").write_text('["new_hash"]')

            store = HashDedupStore(retention_days=7)
            store._rotate()

            assert not (qg.HASH_DIR / f"{old_date}.json").exists(), "Old file should be deleted"
            assert (qg.HASH_DIR / f"{today}.json").exists(), "Today's file should remain"
        finally:
            qg.HASH_DIR = old_hash_dir

    def test_stats_reports_counts(self, tmp_path):
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        try:
            store = HashDedupStore()
            for i in range(5):
                store.add(f"hash-{i}")
            store.flush()
            stats = store.stats()
            assert stats["today_count"] == 5
            assert stats["total_hashes"] >= 5
        finally:
            qg.HASH_DIR = old_hash_dir

    def test_large_scale_dedup(self, tmp_path):
        """10K hashes should work without blowing up memory."""
        import quality_gate as qg
        old_hash_dir = qg.HASH_DIR
        qg.HASH_DIR = tmp_path / "hashes"
        try:
            store = HashDedupStore()
            hashes = [f"hash-{i:06d}" for i in range(10000)]
            for h in hashes:
                store.add(h)
            store.flush()
            # All should be duplicates now
            dupes = sum(1 for h in hashes if store.is_duplicate(h))
            assert dupes == 10000, f"Expected 10000 dupes, got {dupes}"
        finally:
            qg.HASH_DIR = old_hash_dir
|
||||
|
||||
|
||||
class TestCrossRunDedup:
    # End-to-end dedup behavior through run_gate: hashes persisted by one
    # HashDedupStore instance must be visible to a freshly-constructed one
    # (simulating a pipeline restart).

    def test_run_gate_rejects_cross_run_duplicate(self, tmp_path):
        """Second run with same content should reject duplicates."""
        import quality_gate as qg
        # Redirect both module globals (hash dir and stats file) into
        # tmp_path; restored in the finally block.
        old_hash_dir = qg.HASH_DIR
        old_stats = qg.STATS_FILE
        qg.HASH_DIR = tmp_path / "hashes"
        qg.STATS_FILE = tmp_path / "stats.json"
        try:
            # Write test JSONL
            entries = [{"prompt": "hello", "response": "world " * 20}]
            jsonl_path = tmp_path / "test.jsonl"
            jsonl_path.write_text(json.dumps(entries[0]) + "\n")

            # First run — passes
            store1 = HashDedupStore()
            report1 = qg.run_gate(str(jsonl_path), "training_pairs", store1)
            assert report1.passed == 1
            assert report1.rejected == 0

            # Second run with new store (simulates restart) — should detect dupe
            store2 = HashDedupStore()
            report2 = qg.run_gate(str(jsonl_path), "training_pairs", store2)
            # The hash was persisted to disk, so store2 should detect it
            assert report2.rejected == 1, f"Expected 1 rejected, got {report2.rejected}"
        finally:
            qg.HASH_DIR = old_hash_dir
            qg.STATS_FILE = old_stats

    def test_entry_hash_deterministic(self):
        """Same entry always produces same hash."""
        entry = {"prompt": "test", "response": "data"}
        h1 = entry_hash(entry)
        h2 = entry_hash(entry)
        assert h1 == h2
        # 16-char hash matches the daily-file format documented on
        # HashDedupStore.
        assert len(h1) == 16

    def test_entry_hash_differs_for_different_entries(self):
        h1 = entry_hash({"a": 1})
        h2 = entry_hash({"a": 2})
        assert h1 != h2
|
||||
|
||||
Reference in New Issue
Block a user