diff --git a/pipeline/quality_gate.py b/pipeline/quality_gate.py
index e15d68c4..42a25f6e 100644
--- a/pipeline/quality_gate.py
+++ b/pipeline/quality_gate.py
@@ -22,13 +22,187 @@
 import json
 import os
 import sys
 import hashlib
+import math
 import re
+import base64
 from pathlib import Path
-from datetime import datetime, timezone
+from datetime import datetime, timezone, timedelta
 from dataclasses import dataclass, field, asdict
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any, Set
+
+PIPELINE_DIR = Path.home() / ".hermes" / "pipeline"
+STATS_FILE = PIPELINE_DIR / "quality_stats.json"
+HASH_DIR = PIPELINE_DIR / "quality_hashes"
+HASH_RETENTION_DAYS = 7  # Keep hashes for 7 days
+
+
+# ============================================================
+# Bloom Filter — Memory-efficient dedup at scale
+# ============================================================
+
+class BloomFilter:
+    """Probabilistic set for membership testing. False positives possible, no false negatives."""
+
+    def __init__(self, capacity: int = 100_000, error_rate: float = 0.01):
+        self.capacity = capacity
+        self.error_rate = error_rate
+        # Optimal size and hash count
+        self.size = max(64, int(-capacity * math.log(error_rate) / (math.log(2) ** 2)))
+        self.num_hashes = max(1, int(self.size / capacity * math.log(2)))
+        self._bitarray = bytearray((self.size + 7) // 8)
+
+    def _hash_indices(self, item: str) -> List[int]:
+        """Generate bit indices using double hashing."""
+        h1 = int.from_bytes(hashlib.sha256(item.encode()).digest()[:8], "little")
+        h2 = int.from_bytes(hashlib.md5(item.encode()).digest()[:8], "little")
+        return [(h1 + i * h2) % self.size for i in range(self.num_hashes)]
+
+    def add(self, item: str):
+        for idx in self._hash_indices(item):
+            self._bitarray[idx // 8] |= 1 << (idx % 8)
+
+    def __contains__(self, item: str) -> bool:
+        return all(self._bitarray[idx // 8] & (1 << (idx % 8)) for idx in self._hash_indices(item))
+
+    def to_dict(self) -> dict:
+        return {
+            "capacity": self.capacity,
+            "error_rate": self.error_rate,
+            "size": self.size,
+            "num_hashes": self.num_hashes,
+            "data": base64.b64encode(bytes(self._bitarray)).decode(),
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict) -> "BloomFilter":
+        bf = cls(capacity=d["capacity"], error_rate=d["error_rate"])
+        bf._bitarray = bytearray(base64.b64decode(d["data"]))
+        return bf
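# Sizing sketch (illustrative, not part of the patch): BloomFilter.__init__ above uses the standard
# formulas m = -n*ln(p) / (ln 2)^2 bits and k = (m/n)*ln 2 hash probes. For the defaults this works
# out roughly as follows; the numbers are approximate and only make the memory cost concrete.
import math

capacity, error_rate = 100_000, 0.01
size = max(64, int(-capacity * math.log(error_rate) / (math.log(2) ** 2)))  # ~958,505 bits
num_hashes = max(1, int(size / capacity * math.log(2)))                     # 6 probes per item
print(size, num_hashes, (size + 7) // 8)                                    # bit count, probe count, ~117 KB backing bytearray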
+
+
+# ============================================================
+# Hash Dedup Store — Rotating daily files + bloom filter
+# ============================================================
+
+class HashDedupStore:
+    """Rotating hash store for cross-run deduplication.
+
+    Strategy:
+    - Daily JSON files: HASH_DIR/YYYY-MM-DD.json (set of 16-char hashes)
+    - Bloom filter: HASH_DIR/bloom.json (memory-efficient for large scale)
+    - On load: merge last N days into bloom filter
+    - Rotation: delete files older than HASH_RETENTION_DAYS
+    """
+
+    def __init__(self, retention_days: int = HASH_RETENTION_DAYS):
+        self.retention_days = retention_days
+        HASH_DIR.mkdir(parents=True, exist_ok=True)
+        self._today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+        self._daily_hashes: Set[str] = set()
+        self._bloom: Optional[BloomFilter] = None
+        self._load()
+
+    def _day_file(self, day: str) -> Path:
+        return HASH_DIR / f"{day}.json"
+
+    def _bloom_file(self) -> Path:
+        return HASH_DIR / "bloom.json"
+
+    def _load(self):
+        """Load today's hashes and bloom filter."""
+        # Load today's file
+        day_path = self._day_file(self._today)
+        if day_path.exists():
+            try:
+                self._daily_hashes = set(json.loads(day_path.read_text()))
+            except (json.JSONDecodeError, IOError):
+                self._daily_hashes = set()
+
+        # Load or rebuild bloom filter
+        bloom_path = self._bloom_file()
+        if bloom_path.exists():
+            try:
+                self._bloom = BloomFilter.from_dict(json.loads(bloom_path.read_text()))
+            except (json.JSONDecodeError, IOError, KeyError):
+                self._bloom = None
+
+        if self._bloom is None:
+            self._rebuild_bloom()
+
+    def _rebuild_bloom(self):
+        """Rebuild bloom filter from all recent daily files."""
+        hashes = set()
+        for day_offset in range(self.retention_days):
+            day = (datetime.now(timezone.utc) - timedelta(days=day_offset)).strftime("%Y-%m-%d")
+            day_path = self._day_file(day)
+            if day_path.exists():
+                try:
+                    hashes.update(json.loads(day_path.read_text()))
+                except (json.JSONDecodeError, IOError):
+                    pass
+
+        capacity = max(len(hashes) * 2, 10_000)
+        self._bloom = BloomFilter(capacity=capacity)
+        for h in hashes:
+            self._bloom.add(h)
+
+    def _save(self):
+        """Persist today's hashes and bloom filter."""
+        day_path = self._day_file(self._today)
+        day_path.write_text(json.dumps(sorted(self._daily_hashes)))
+
+        if self._bloom:
+            self._bloom_file().write_text(json.dumps(self._bloom.to_dict()))
+
+    def _rotate(self):
+        """Delete daily hash files older than retention period."""
+        cutoff = (datetime.now(timezone.utc) - timedelta(days=self.retention_days)).strftime("%Y-%m-%d")
+        for path in HASH_DIR.glob("*.json"):
+            name = path.stem
+            if len(name) == 10 and name < cutoff and name != "bloom":
+                path.unlink()
+
+    def is_duplicate(self, h: str) -> bool:
+        """Check if hash has been seen in current day or bloom filter."""
+        if h in self._daily_hashes:
+            return True
+        if self._bloom and h in self._bloom:
+            return True
+        return False
+
+    def add(self, h: str):
+        """Add a hash. Saves and rotates periodically."""
+        self._daily_hashes.add(h)
+        if self._bloom:
+            self._bloom.add(h)
+        # Save every 100 additions or on explicit call
+        if len(self._daily_hashes) % 100 == 0:
+            self._save()
+            self._rotate()
+
+    def flush(self):
+        """Force save and rotate."""
+        self._save()
+        self._rotate()
+
+    def stats(self) -> dict:
+        """Return dedup store statistics."""
+        file_count = len(list(HASH_DIR.glob("*.json")))
+        total_hashes = 0
+        for path in HASH_DIR.glob("????-??-??.json"):
+            try:
+                total_hashes += len(json.loads(path.read_text()))
+            except Exception:
+                pass
+        return {
+            "today_count": len(self._daily_hashes),
+            "total_files": file_count,
+            "total_hashes": total_hashes,
+            "retention_days": self.retention_days,
+            "bloom_size": self._bloom.size if self._bloom else 0,
+        }
+
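# Usage sketch (illustrative, not part of the patch): the store above persists under
# ~/.hermes/pipeline/quality_hashes/ via HASH_DIR, so this assumes that location is writable and
# that pipeline/ is importable (the tests later in this diff insert it into sys.path). The hash
# value is a placeholder standing in for a real 16-char entry hash.
from quality_gate import HashDedupStore

store = HashDedupStore(retention_days=7)
h = "0123456789abcdef"        # placeholder entry hash
if not store.is_duplicate(h):
    store.add(h)              # buffered; written out every 100 additions
store.flush()                 # force-write today's file and bloom.json, then rotate old days
print(store.stats())          # e.g. {"today_count": 1, "total_files": 2, ...}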
-STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
 
 
 # --- Quality Check Types ---
@@ -228,8 +402,14 @@ CHECK_MAP = {
 }
 
 
-def run_gate(input_path: str, entry_type: str) -> GateReport:
-    """Run quality gate on a JSONL file."""
+def run_gate(input_path: str, entry_type: str, dedup_store: Optional[HashDedupStore] = None) -> GateReport:
+    """Run quality gate on a JSONL file.
+
+    Args:
+        input_path: Path to JSONL file
+        entry_type: Type of entries (training_pairs, scene_descriptions, etc.)
+        dedup_store: Optional hash dedup store for cross-run dedup. If None, creates one.
+    """
     path = Path(input_path)
     if not path.exists():
         return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)
@@ -239,6 +419,9 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
         return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
                           rejected_indices=[-1])  # unknown type
 
+    if dedup_store is None:
+        dedup_store = HashDedupStore()
+
     entries = []
     with open(path) as f:
         for line in f:
@@ -246,7 +429,7 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
             if line:
                 entries.append(json.loads(line))
 
-    # Deduplication check
+    # Within-file deduplication check
     key_fields = _get_key_fields(entry_type)
     dup_errors = check_no_duplicates(entries, key_fields)
 
@@ -254,13 +437,22 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
     rejected = 0
     rejected_indices = []
     total_score = 0.0
+    cross_run_dupes = 0
 
     for i, entry in enumerate(entries):
         errors = check_fn(entry)
 
-        # Add duplicate errors
+        # Add within-file duplicate errors
         if i in dup_errors:
             errors.extend(dup_errors[i])
+
+        # Cross-run hash dedup
+        h = entry_hash(entry)
+        if dedup_store.is_duplicate(h):
+            errors.append(f"cross_run_duplicate: hash {h} seen in prior run")
+            cross_run_dupes += 1
+        else:
+            dedup_store.add(h)
 
         # Add SOUL compliance check for text content
         text_content = ""
@@ -286,6 +478,9 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
 
     avg_score = total_score / len(entries) if entries else 0.0
 
+    # Flush dedup store
+    dedup_store.flush()
+
     report = GateReport(
         file=str(path),
         type=entry_type,
@@ -299,6 +494,10 @@ def run_gate(input_path: str, entry_type: str) -> GateReport:
     # Save stats
     _save_stats(report)
 
+    if cross_run_dupes > 0:
+        logger_msg = f"  cross-run dedup: {cross_run_dupes} duplicates found"
+        print(logger_msg, file=sys.stderr)
+
     return report
 
 
@@ -318,7 +517,7 @@ def _get_key_fields(entry_type: str) -> List[str]:
 
 
 def _save_stats(report: GateReport):
-    """Append quality stats to the stats file."""
+    """Append quality stats to the stats file. Rotates to keep last 1000."""
     STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
 
     stats = []
@@ -331,8 +530,9 @@ def _save_stats(report: GateReport):
 
     stats.append(report.to_dict())
 
-    # Keep last 1000 entries
-    stats = stats[-1000:]
+    # Rotate: keep last 1000 entries
+    if len(stats) > 1000:
+        stats = stats[-1000:]
 
     with open(STATS_FILE, "w") as f:
         json.dump(stats, f, indent=2)
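# Usage sketch (illustrative, not part of the patch): one dedup store can be shared across several
# input files in the same run so duplicates are caught both within and across files. The file names
# here are hypothetical; passing dedup_store=None instead lets run_gate construct its own store.
from quality_gate import HashDedupStore, run_gate

store = HashDedupStore()
for path in ["batch_001.jsonl", "batch_002.jsonl"]:   # hypothetical inputs
    report = run_gate(path, "training_pairs", dedup_store=store)
    print(report.file, report.passed, report.rejected)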
diff --git a/pipelines/__init__.py b/pipelines/__init__.py
new file mode 100644
index 00000000..3d0c91d7
--- /dev/null
+++ b/pipelines/__init__.py
@@ -0,0 +1,22 @@
+"""Pipeline infrastructure — shared orchestrator."""
+from .orchestrator import (
+    PipelineOrchestrator,
+    OrchestratorDB,
+    Job,
+    JobStatus,
+    JobPriority,
+    JobCheckpoint,
+    TokenUsage,
+    RateLimiter,
+)
+
+__all__ = [
+    "PipelineOrchestrator",
+    "OrchestratorDB",
+    "Job",
+    "JobStatus",
+    "JobPriority",
+    "JobCheckpoint",
+    "TokenUsage",
+    "RateLimiter",
+]
diff --git a/pipelines/orchestrator.py b/pipelines/orchestrator.py
index d78e8a25..70550814 100644
--- a/pipelines/orchestrator.py
+++ b/pipelines/orchestrator.py
@@ -574,33 +574,67 @@ class PipelineOrchestrator:
         return job
 
     def run(self, pipeline: Optional[str] = None, max_jobs: Optional[int] = None):
-        """Run the orchestrator, processing jobs from the queue."""
+        """Run the orchestrator, processing jobs from the queue.
+
+        On startup, checks for paused/running jobs with checkpoints and
+        resumes them first before picking up new pending jobs.
+        """
         self.running = True
         self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
+        futures = {}
 
         logger.info(f"Orchestrator starting (workers={self.max_workers})")
 
+        # Resume paused jobs with checkpoints on restart
+        for status in (JobStatus.PAUSED, JobStatus.RUNNING):
+            for job in self.db.get_jobs_by_status(status, pipeline):
+                if job.checkpoint:
+                    logger.info(f"Resuming {status.value} job {job.id} from checkpoint step {job.checkpoint.step}")
+                    job.status = JobStatus.PENDING
+                    self.db.save_job(job)
+
         try:
             jobs_processed = 0
             while self.running:
+                # Check completed futures
+                done = [f for f in futures if f.done()]
+                for f in done:
+                    try:
+                        f.result()  # propagate exceptions
+                    except Exception as e:
+                        logger.error(f"Worker error: {e}")
+                    del futures[f]
+
+                # Throttle if at capacity
+                if len(futures) >= self.max_workers:
+                    time.sleep(0.1)
+                    continue
+
                 # Get next job
                 job = self.db.get_next_job(pipeline)
                 if not job:
-                    # No pending jobs, wait a bit
-                    time.sleep(1)
+                    if not futures:
+                        # No jobs and no workers — done
+                        break
+                    time.sleep(0.5)
                     continue
 
-                # Submit to thread pool
                 future = self.executor.submit(self._execute_job, job)
-
-                # Don't wait for completion, get next job
+                futures[future] = job.id
                 jobs_processed += 1
 
                 if max_jobs and jobs_processed >= max_jobs:
                     logger.info(f"Reached max_jobs limit ({max_jobs})")
                     break
+
+            # Wait for remaining futures
+            for f in futures:
+                try:
+                    f.result(timeout=300)
+                except Exception as e:
+                    logger.error(f"Worker error on drain: {e}")
 
         finally:
             self.executor.shutdown(wait=True)
@@ -735,6 +769,11 @@ def main():
     parser = argparse.ArgumentParser(description="Pipeline Orchestrator")
     subparsers = parser.add_subparsers(dest="command")
 
+    # List jobs
+    list_parser = subparsers.add_parser("list", help="List jobs")
+    list_parser.add_argument("--status", help="Filter by status")
+    list_parser.add_argument("--pipeline", help="Filter by pipeline")
+
     # Submit job
     submit_parser = subparsers.add_parser("submit", help="Submit a job")
    submit_parser.add_argument("pipeline", help="Pipeline name")
@@ -779,6 +818,23 @@ def main():
     elif args.command == "run":
         orchestrator.run(args.pipeline, args.max_jobs)
 
+    elif args.command == "list":
+        status_filter = JobStatus(args.status) if args.status else None
+        if status_filter:
+            jobs = orchestrator.db.get_jobs_by_status(status_filter, args.pipeline)
+        else:
+            # Show all jobs
+            conn = orchestrator.db._get_conn()
+            rows = conn.execute("SELECT * FROM jobs ORDER BY priority DESC, created_at ASC").fetchall()
+            conn.close()
+            jobs = [orchestrator.db.get_job(row['id']) for row in rows]
+        for job in jobs:
+            dur = ""
+            if job.started_at and job.completed_at:
+                dur = f" ({job.completed_at - job.started_at:.1f}s)"
+            print(f"  {job.id[:8]}  {job.status.value:10s}  p{job.priority.value}  {job.pipeline}  tokens={job.token_usage.total_tokens}{dur}")
+        print(f"\n{len(jobs)} jobs")
+
     elif args.command == "status":
         progress = orchestrator.get_progress(args.job_id)
         print(json.dumps(progress, indent=2))
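# Usage sketch (illustrative, not part of the patch): a minimal driver for the orchestrator changes
# above, mirroring how the tests below exercise it. The constructor arguments and the handler
# contract (a dict result with an optional "token_usage" key) follow the test fixtures rather than a
# documented public API; the pipeline name and db_path are placeholders.
from pipelines.orchestrator import PipelineOrchestrator

orch = PipelineOrchestrator(max_workers=2, token_budget=10_000, db_path="orchestrator.db")

def handler(job):
    # job.task carries the submitted payload; reported token usage is folded into job.token_usage
    return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}

orch.register_handler("demo", handler)
job_id = orch.submit_job("demo", {"action": "process"})
orch.run(max_jobs=1)                         # re-queues checkpointed paused/running jobs, then drains
print(orch.get_progress(job_id)["status"])   # "completed" if the handler succeeded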
diff --git a/pipelines/tests/test_orchestrator.py b/pipelines/tests/test_orchestrator.py
new file mode 100644
index 00000000..69ea799d
--- /dev/null
+++ b/pipelines/tests/test_orchestrator.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python3
+"""Tests for pipeline orchestrator — queue, parallelism, resume, token tracking."""
+
+import json
+import os
+import tempfile
+import time
+from pathlib import Path
+
+import pytest
+
+# Add project root to path
+import sys
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
+
+from pipelines.orchestrator import (
+    PipelineOrchestrator,
+    Job,
+    JobStatus,
+    JobPriority,
+    JobCheckpoint,
+    TokenUsage,
+    OrchestratorDB,
+    RateLimiter,
+)
+
+
+@pytest.fixture
+def tmp_db(tmp_path):
+    """Fresh orchestrator DB for each test."""
+    return tmp_path / "test_orchestrator.db"
+
+
+@pytest.fixture
+def orch(tmp_db):
+    """Orchestrator instance with temp DB."""
+    return PipelineOrchestrator(max_workers=2, token_budget=10000, db_path=tmp_db)
+
+
+class TestJobDataModels:
+    """Test Job, TokenUsage, JobCheckpoint dataclasses."""
+
+    def test_token_usage_total(self):
+        usage = TokenUsage(input_tokens=100, output_tokens=50)
+        assert usage.total_tokens == 150
+
+    def test_token_usage_zero(self):
+        usage = TokenUsage()
+        assert usage.total_tokens == 0
+
+    def test_token_usage_serialization(self):
+        usage = TokenUsage(input_tokens=10, output_tokens=20, cache_read_tokens=5, cost_usd=0.001)
+        d = usage.to_dict()
+        restored = TokenUsage.from_dict(d)
+        assert restored.input_tokens == 10
+        assert restored.total_tokens == 30
+
+    def test_checkpoint_serialization(self):
+        cp = JobCheckpoint(job_id="abc", step=3, data={"key": "val"})
+        d = cp.to_dict()
+        restored = JobCheckpoint.from_dict(d)
+        assert restored.step == 3
+        assert restored.data == {"key": "val"}
+
+    def test_job_serialization(self):
+        job = Job(id="test-1", pipeline="demo", task={"action": "run"})
+        d = job.to_dict()
+        restored = Job.from_dict(d)
+        assert restored.id == "test-1"
+        assert restored.status == JobStatus.PENDING
+        assert restored.priority == JobPriority.NORMAL
+
+
+class TestOrchestratorDB:
+    """Test SQLite-backed job queue."""
+
+    def test_save_and_get(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        job = Job(id="j1", pipeline="test", task={"x": 1})
+        db.save_job(job)
+        fetched = db.get_job("j1")
+        assert fetched is not None
+        assert fetched.id == "j1"
+        assert fetched.task == {"x": 1}
+
+    def test_get_next_job_priority(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        db.save_job(Job(id="low", pipeline="test", task={}, priority=JobPriority.LOW))
+        db.save_job(Job(id="high", pipeline="test", task={}, priority=JobPriority.HIGH))
+        db.save_job(Job(id="normal", pipeline="test", task={}, priority=JobPriority.NORMAL))
+        next_job = db.get_next_job()
+        assert next_job.id == "high"
+
+    def test_get_next_job_pipeline_filter(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        db.save_job(Job(id="a", pipeline="alpha", task={}))
+        db.save_job(Job(id="b", pipeline="beta", task={}))
+        next_job = db.get_next_job(pipeline="beta")
+        assert next_job.id == "b"
+
+    def test_get_jobs_by_status(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        db.save_job(Job(id="a", pipeline="test", task={}, status=JobStatus.PENDING))
+        db.save_job(Job(id="b", pipeline="test", task={}, status=JobStatus.COMPLETED))
+        pending = db.get_jobs_by_status(JobStatus.PENDING)
+        assert len(pending) == 1
+        assert pending[0].id == "a"
+
+    def test_checkpoint_save_load(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        cp = JobCheckpoint(job_id="j1", step=5, data={"progress": "50%"})
+        db.save_checkpoint("j1", cp)
+        loaded = db.get_checkpoint("j1")
+        assert loaded is not None
+        assert loaded.step == 5
+        assert loaded.data == {"progress": "50%"}
+
+    def test_stats(self, tmp_db):
+        db = OrchestratorDB(tmp_db)
+        job = Job(id="j1", pipeline="test", task={}, status=JobStatus.COMPLETED)
+        job.token_usage = TokenUsage(input_tokens=100, output_tokens=50)
+        db.save_job(job)
+        stats = db.get_stats()
+        assert stats["completed"] == 1
+        assert stats["total_tokens"] == 150
+
+
+class TestRateLimiter:
+    """Test rate limiter."""
+
+    def test_no_limit(self):
+        rl = RateLimiter()
+        can_proceed, wait = rl.can_proceed("unknown")
+        assert can_proceed is True
+        assert wait == 0.0
+
+    def test_rpm_limit(self):
+        rl = RateLimiter()
+        rl.configure("test", requests_per_minute=2, tokens_per_minute=1000)
+        assert rl.can_proceed("test")[0] is True
+        rl.record_request("test")
+        assert rl.can_proceed("test")[0] is True
+        rl.record_request("test")
+        can_proceed, wait = rl.can_proceed("test")
+        assert can_proceed is False
+        assert wait > 0
+
+
+class TestPipelineOrchestrator:
+    """Test the main orchestrator."""
+
+    def test_submit_and_retrieve(self, orch):
+        job_id = orch.submit_job("test_pipeline", {"action": "process"})
+        assert job_id is not None
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "pending"
+        assert progress["pipeline"] == "test_pipeline"
+
+    def test_submit_batch(self, orch):
+        ids = orch.submit_batch("test", [{"i": i} for i in range(5)])
+        assert len(ids) == 5
+
+    def test_handler_execution(self, orch):
+        results = []
+
+        def handler(job):
+            results.append(job.id)
+            return {"status": "ok"}
+
+        orch.register_handler("demo", handler)
+        job_id = orch.submit_job("demo", {"action": "test"})
+        orch.run(max_jobs=1)
+
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "completed"
+        assert len(results) == 1
+
+    def test_handler_failure_and_retry(self, orch):
+        attempts = []
+
+        def handler(job):
+            attempts.append(1)
+            if len(attempts) < 3:
+                raise ValueError("transient error")
+            return {"status": "ok"}
+
+        orch.register_handler("retry_test", handler)
+        job_id = orch.submit_job("retry_test", {"action": "test"}, max_retries=3)
+        orch.run(max_jobs=1)
+
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "completed"
+        assert len(attempts) == 3
+
+    def test_handler_exhausts_retries(self, orch):
+        def handler(job):
+            raise ValueError("permanent error")
+
+        orch.register_handler("fail_test", handler)
+        job_id = orch.submit_job("fail_test", {"action": "test"}, max_retries=2)
+        orch.run(max_jobs=1)
+
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "failed"
+        assert "permanent error" in progress["error"]
+
+    def test_no_handler(self, orch):
+        job_id = orch.submit_job("nonexistent", {"action": "test"})
+        orch.run(max_jobs=1)
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "failed"
+        assert "No handler" in progress["error"]
+
+    def test_token_budget_tracking(self, orch):
+        def handler(job):
+            return {"status": "ok", "token_usage": {"input_tokens": 500, "output_tokens": 200}}
+
+        orch.register_handler("token_test", handler)
+        job_id = orch.submit_job("token_test", {"action": "test"}, token_budget=1000)
+        orch.run(max_jobs=1)
+
+        progress = orch.get_progress(job_id)
+        assert progress["token_usage"]["input_tokens"] == 500
+        assert progress["token_usage"]["output_tokens"] == 200
+
+    def test_token_budget_exceeded(self, orch):
+        def handler(job):
+            return {"status": "ok"}
+
+        orch.register_handler("budget_test", handler)
+        # Set job with already-exhausted budget by manipulating DB
+        job_id = orch.submit_job("budget_test", {"action": "test"}, token_budget=100)
+        job = orch.db.get_job(job_id)
+        job.token_usage = TokenUsage(input_tokens=100, output_tokens=10)
+        orch.db.save_job(job)
+
+        orch.run(max_jobs=1)
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "failed"
+        assert "budget" in progress["error"].lower()
+
+    def test_parallel_execution(self, orch):
+        """Verify workers execute in parallel."""
+        import threading
+        active = set()
+        max_concurrent = [0]
+
+        def handler(job):
+            active.add(threading.current_thread().name)
+            max_concurrent[0] = max(max_concurrent[0], len(active))
+            time.sleep(0.1)
+            active.discard(threading.current_thread().name)
+            return {"status": "ok"}
+
+        orch.register_handler("parallel", handler)
+        orch.submit_batch("parallel", [{"i": i} for i in range(4)])
+        orch.run(max_jobs=4)
+
+        # With 2 workers, at least 2 should have been active simultaneously
+        assert max_concurrent[0] >= 2
+
+    def test_resume_paused_job(self, orch):
+        """Test resume from checkpoint."""
+        call_count = [0]
+
+        def handler(job):
+            call_count[0] += 1
+            if call_count[0] == 1:
+                # Simulate saving checkpoint before failure
+                job.checkpoint = JobCheckpoint(job_id=job.id, step=1, data={"partial": True})
+                orch.db.save_checkpoint(job.id, job.checkpoint)
+                raise ValueError("first attempt fails")
+            # Second attempt succeeds
+            return {"status": "ok", "resumed_from": job.checkpoint.step if job.checkpoint else None}
+
+        orch.register_handler("resume_test", handler)
+        job_id = orch.submit_job("resume_test", {"action": "test"}, max_retries=3)
+
+        # First run — fails, saves checkpoint
+        orch.run(max_jobs=1)
+
+        # Manually resume (set to pending)
+        job = orch.db.get_job(job_id)
+        if job.status == JobStatus.FAILED:
+            job.status = JobStatus.PENDING
+            job.retry_count = 0
+            job.error = None
+            orch.db.save_job(job)
+            orch.run(max_jobs=1)
+
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "completed"
+
+    def test_resume_on_restart(self, orch):
+        """Test that run() resumes paused/running jobs with checkpoints on startup."""
+        # Create a paused job with a checkpoint
+        job = Job(id="resume-on-start", pipeline="restart_test", task={"action": "resume"})
+        job.status = JobStatus.PAUSED
+        orch.db.save_job(job)
+        orch.db.save_checkpoint("resume-on-start", JobCheckpoint(
+            job_id="resume-on-start", step=3, data={"progress": "50%"}
+        ))
+
+        calls = []
+
+        def handler(job):
+            calls.append(job.checkpoint.step if job.checkpoint else None)
+            return {"status": "ok"}
+
+        orch.register_handler("restart_test", handler)
+        orch.run(max_jobs=1)
+
+        # Job should have been auto-resumed and executed
+        progress = orch.get_progress("resume-on-start")
+        assert progress["status"] == "completed"
+        assert calls == [3]  # Handler saw checkpoint step 3
+
+    def test_cancel_job(self, orch):
+        job_id = orch.submit_job("cancel_test", {"action": "test"})
+        orch.cancel_job(job_id)
+        progress = orch.get_progress(job_id)
+        assert progress["status"] == "cancelled"
+
+    def test_generate_report(self, orch):
+        def handler(job):
+            return {"status": "ok", "token_usage": {"input_tokens": 100, "output_tokens": 50}}
+
+        orch.register_handler("report_test", handler)
+        orch.submit_batch("report_test", [{"i": i} for i in range(3)])
+        orch.run(max_jobs=3)
+
+        report = orch.generate_report("report_test")
+        assert report["completed"] == 3
+        assert report["failed"] == 0
+        assert report["success_rate"] == 100.0
+        assert report["total_tokens"] == 450  # 3 * 150
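# Sketch (illustrative, not part of the patch): the checkpoint contract the resume tests above rely
# on, pulled out on its own. A handler persists a JobCheckpoint before raising; when the job is
# re-queued (retry, manual reset, or the resume-on-restart path in run()), the handler finds
# job.checkpoint populated and can skip the completed work. Names follow the orchestrator API used
# in the tests; the step/data values and db_path are placeholders.
from pipelines.orchestrator import JobCheckpoint, PipelineOrchestrator

orch = PipelineOrchestrator(max_workers=2, token_budget=10_000, db_path="orchestrator.db")

def resumable_handler(job):
    if job.checkpoint is None:
        # Record progress before failing, as test_resume_paused_job does
        orch.db.save_checkpoint(job.id, JobCheckpoint(job_id=job.id, step=1, data={"partial": True}))
        raise RuntimeError("transient failure after step 1")
    return {"status": "ok", "resumed_from": job.checkpoint.step}

orch.register_handler("resumable", resumable_handler)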
diff --git a/tests/test_quality_gate.py b/tests/test_quality_gate.py
index 0c050d3d..ba222747 100644
--- a/tests/test_quality_gate.py
+++ b/tests/test_quality_gate.py
@@ -465,3 +465,194 @@ def test_ci_gate_on_actual_ci_automation_gate():
     gate = QualityGate()
     gate.check_file(gate_path)
     assert gate.failures == 0, f"ci_automation_gate.py should pass quality gate, got {gate.failures} failures"
+
+
+# ===========================================================================
+# BLOOM FILTER + HASH DEDUP TESTS (Issue #628)
+# ===========================================================================
+
+import sys, os
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline"))
+from quality_gate import BloomFilter, HashDedupStore, HASH_DIR, entry_hash
+
+
+class TestBloomFilter:
+
+    def test_empty_bloom_no_contains(self):
+        bf = BloomFilter(capacity=100)
+        assert "hello" not in bf
+
+    def test_add_then_contains(self):
+        bf = BloomFilter(capacity=100)
+        bf.add("hello")
+        assert "hello" in bf
+
+    def test_false_negatives_impossible(self):
+        """No false negatives — every added item is found."""
+        bf = BloomFilter(capacity=1000)
+        items = [f"item-{i}" for i in range(500)]
+        for item in items:
+            bf.add(item)
+        for item in items:
+            assert item in bf, f"False negative for {item}"
+
+    def test_false_positive_rate(self):
+        """False positive rate should be under the configured error rate."""
+        bf = BloomFilter(capacity=1000, error_rate=0.01)
+        added = {f"added-{i}" for i in range(1000)}
+        for item in added:
+            bf.add(item)
+        false_positives = 0
+        check_count = 10000
+        for i in range(check_count):
+            candidate = f"not-added-{i}"
+            if candidate not in added and candidate in bf:
+                false_positives += 1
+        fp_rate = false_positives / check_count
+        assert fp_rate < 0.05, f"FP rate {fp_rate:.3%} too high (expected <5%)"
+
+    def test_serialization_roundtrip(self):
+        bf = BloomFilter(capacity=100)
+        bf.add("alpha")
+        bf.add("beta")
+        d = bf.to_dict()
+        restored = BloomFilter.from_dict(d)
+        assert "alpha" in restored
+        assert "beta" in restored
+        assert "gamma" not in restored
+
+
+class TestHashDedupStore:
+
+    def test_first_seen_not_duplicate(self, tmp_path):
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        try:
+            store = HashDedupStore()
+            assert not store.is_duplicate("abc123")
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+    def test_after_add_is_duplicate(self, tmp_path):
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        try:
+            store = HashDedupStore()
+            store.add("abc123")
+            store.flush()
+            assert store.is_duplicate("abc123")
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+    def test_different_hash_not_duplicate(self, tmp_path):
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        try:
+            store = HashDedupStore()
+            store.add("abc123")
+            store.flush()
+            assert not store.is_duplicate("xyz789")
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+    def test_rotation_deletes_old_files(self, tmp_path):
+        """Files older than retention_days should be deleted."""
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        qg.HASH_DIR.mkdir(parents=True, exist_ok=True)
+        try:
+            # Create old file
+            old_date = "2020-01-01"
+            (qg.HASH_DIR / f"{old_date}.json").write_text('["old_hash"]')
+            # Create today's file
+            from datetime import datetime, timezone
+            today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+            (qg.HASH_DIR / f"{today}.json").write_text('["new_hash"]')
+
+            store = HashDedupStore(retention_days=7)
+            store._rotate()
+
+            assert not (qg.HASH_DIR / f"{old_date}.json").exists(), "Old file should be deleted"
+            assert (qg.HASH_DIR / f"{today}.json").exists(), "Today's file should remain"
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+    def test_stats_reports_counts(self, tmp_path):
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        try:
+            store = HashDedupStore()
+            for i in range(5):
+                store.add(f"hash-{i}")
+            store.flush()
+            stats = store.stats()
+            assert stats["today_count"] == 5
+            assert stats["total_hashes"] >= 5
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+    def test_large_scale_dedup(self, tmp_path):
+        """10K hashes should work without blowing up memory."""
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        qg.HASH_DIR = tmp_path / "hashes"
+        try:
+            store = HashDedupStore()
+            hashes = [f"hash-{i:06d}" for i in range(10000)]
+            for h in hashes:
+                store.add(h)
+            store.flush()
+            # All should be duplicates now
+            dupes = sum(1 for h in hashes if store.is_duplicate(h))
+            assert dupes == 10000, f"Expected 10000 dupes, got {dupes}"
+        finally:
+            qg.HASH_DIR = old_hash_dir
+
+
+class TestCrossRunDedup:
+
+    def test_run_gate_rejects_cross_run_duplicate(self, tmp_path):
+        """Second run with same content should reject duplicates."""
+        import quality_gate as qg
+        old_hash_dir = qg.HASH_DIR
+        old_stats = qg.STATS_FILE
+        qg.HASH_DIR = tmp_path / "hashes"
+        qg.STATS_FILE = tmp_path / "stats.json"
+        try:
+            # Write test JSONL
+            entries = [{"prompt": "hello", "response": "world " * 20}]
+            jsonl_path = tmp_path / "test.jsonl"
+            jsonl_path.write_text(json.dumps(entries[0]) + "\n")
+
+            # First run — passes
+            store1 = HashDedupStore()
+            report1 = qg.run_gate(str(jsonl_path), "training_pairs", store1)
+            assert report1.passed == 1
+            assert report1.rejected == 0
+
+            # Second run with new store (simulates restart) — should detect dupe
+            store2 = HashDedupStore()
+            report2 = qg.run_gate(str(jsonl_path), "training_pairs", store2)
+            # The hash was persisted to disk, so store2 should detect it
+            assert report2.rejected == 1, f"Expected 1 rejected, got {report2.rejected}"
+        finally:
+            qg.HASH_DIR = old_hash_dir
+            qg.STATS_FILE = old_stats
+
+    def test_entry_hash_deterministic(self):
+        """Same entry always produces same hash."""
+        entry = {"prompt": "test", "response": "data"}
+        h1 = entry_hash(entry)
+        h2 = entry_hash(entry)
+        assert h1 == h2
+        assert len(h1) == 16
+
+    def test_entry_hash_differs_for_different_entries(self):
+        h1 = entry_hash({"a": 1})
+        h2 = entry_hash({"a": 2})
+        assert h1 != h2
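# Sketch (assumption, not part of the patch): entry_hash is imported above but its definition is not
# in this diff. The tests only pin down that it is deterministic and returns a 16-character string.
# One plausible implementation consistent with those assertions (not necessarily the project's)
# would hash a canonical JSON dump of the entry and keep the first 16 hex characters:
import hashlib
import json

def entry_hash_sketch(entry: dict) -> str:
    """Hypothetical stand-in for quality_gate.entry_hash."""
    canonical = json.dumps(entry, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode()).hexdigest()[:16]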