#!/usr/bin/env python3
"""
quality_gate.py — Quality Gate for Pipeline Outputs

Validates all pipeline outputs before saving. Rejects bad outputs,
tracks quality scores, and supports re-queue for regeneration.

Usage:
    python3 quality_gate.py --input output.jsonl --type training_pairs
    python3 quality_gate.py --input output.jsonl --type knowledge
    python3 quality_gate.py --input output.jsonl --type scene_descriptions
    python3 quality_gate.py --dir pipeline/output/ --type training_pairs
    python3 quality_gate.py --status     # show quality stats

Exit codes:
    0 = all outputs passed
    1 = some outputs rejected
    2 = file/parse error (missing input, unknown type, malformed JSONL)
"""

import base64
import json
import os
import sys
import hashlib
import math
import re
import struct
import logging
from pathlib import Path
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Any, Set

# FIX: Use 'pipelines' (plural) to match local ~/.hermes/pipelines/ layout
PIPELINE_DIR = Path.home() / ".hermes" / "pipelines"
STATS_FILE = PIPELINE_DIR / "quality_stats.json"
HASH_DIR = PIPELINE_DIR / "quality_hashes"
HASH_RETENTION_DAYS = 7  # Keep hashes for 7 days

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("quality_gate")


def _atomic_write_text(path: Path, text: str) -> None:
    """Write *text* to *path* via a sibling temp file and ``os.replace``.

    A crash mid-write therefore leaves the previous file intact instead of a
    truncated JSON document. The temp name ends in ``.tmp`` so it is never
    picked up by the ``*.json`` globs used elsewhere in this module.
    """
    tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
    try:
        tmp_path.write_text(text, encoding="utf-8")
        os.replace(tmp_path, path)
    except OSError:
        # Best-effort cleanup of the partial temp file, then surface the error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


# ============================================================
# Bloom Filter — Memory-efficient dedup at scale
# ============================================================

class BloomFilter:
    """Probabilistic set for membership testing.

    False positives possible, no false negatives.
    """

    def __init__(self, capacity: int = 100_000, error_rate: float = 0.01):
        """Size the filter for *capacity* items at the target *error_rate*.

        Raises:
            ValueError: if *capacity* is not positive or *error_rate* is not
                strictly between 0 and 1 (the sizing formulas are undefined
                outside that range).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity!r}")
        if not 0.0 < error_rate < 1.0:
            raise ValueError(f"error_rate must be in (0, 1), got {error_rate!r}")
        self.capacity = capacity
        self.error_rate = error_rate
        # Optimal size and hash count
        self.size = max(64, int(-capacity * math.log(error_rate) / (math.log(2) ** 2)))
        self.num_hashes = max(1, int(self.size / capacity * math.log(2)))
        self._bitarray = bytearray((self.size + 7) // 8)

    def _hash_indices(self, item: str) -> List[int]:
        """Generate bit indices using double hashing."""
        encoded = item.encode()
        h1 = int.from_bytes(hashlib.sha256(encoded).digest()[:8], "little")
        h2 = int.from_bytes(hashlib.md5(encoded).digest()[:8], "little")
        return [(h1 + i * h2) % self.size for i in range(self.num_hashes)]

    def add(self, item: str):
        """Set every bit that *item* hashes to."""
        for idx in self._hash_indices(item):
            self._bitarray[idx // 8] |= 1 << (idx % 8)

    def __contains__(self, item: str) -> bool:
        """Return True if *item* may have been added (all of its bits are set)."""
        return all(
            self._bitarray[idx // 8] & (1 << (idx % 8))
            for idx in self._hash_indices(item)
        )

    def to_dict(self) -> dict:
        """Serialize the filter (geometry + base64 bit array) to a JSON-safe dict."""
        return {
            "capacity": self.capacity,
            "error_rate": self.error_rate,
            "size": self.size,
            "num_hashes": self.num_hashes,
            "data": base64.b64encode(bytes(self._bitarray)).decode(),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "BloomFilter":
        """Rebuild a filter from the output of :meth:`to_dict`.

        The persisted ``size`` and ``num_hashes`` take precedence over the
        values recomputed from capacity/error_rate, so the bit array is always
        interpreted with the geometry it was written with.

        Raises:
            KeyError: if ``capacity``, ``error_rate`` or ``data`` is missing.
            ValueError: if the payload is not valid base64 or its length does
                not match the declared size.
        """
        bf = cls(capacity=d["capacity"], error_rate=d["error_rate"])
        bf.size = int(d.get("size", bf.size))
        bf.num_hashes = int(d.get("num_hashes", bf.num_hashes))
        bits = bytearray(base64.b64decode(d["data"]))
        if bf.size <= 0 or bf.num_hashes <= 0 or len(bits) != (bf.size + 7) // 8:
            raise ValueError("bloom filter payload does not match its declared size")
        bf._bitarray = bits
        return bf


# ============================================================
# Hash Dedup Store — Rotating daily files + bloom filter
# ============================================================

class HashDedupStore:
    """Rotating hash store for cross-run deduplication.

    Strategy:
    - Daily JSON files: HASH_DIR/YYYY-MM-DD.json (set of 16-char hashes)
    - Bloom filter: HASH_DIR/bloom.json (memory-efficient for large scale)
    - On load: use the persisted bloom filter, or rebuild it from the
      retained daily files when it is missing or unreadable
    - Rotation: delete files older than HASH_RETENTION_DAYS, then rebuild the
      bloom filter so expired hashes stop matching

    Note: the bloom filter can report false positives, so a small fraction of
    never-seen hashes may be classified as duplicates.
    """

    def __init__(self, retention_days: int = HASH_RETENTION_DAYS):
        self.retention_days = retention_days
        HASH_DIR.mkdir(parents=True, exist_ok=True)
        self._today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        self._daily_hashes: Set[str] = set()
        self._bloom: Optional[BloomFilter] = None
        self._load()

    def _day_file(self, day: str) -> Path:
        return HASH_DIR / f"{day}.json"

    def _bloom_file(self) -> Path:
        return HASH_DIR / "bloom.json"

    @staticmethod
    def _read_day_file(path: Path) -> Set[str]:
        """Return the hashes stored in one daily file, or an empty set if unreadable."""
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            log.warning("ignoring unreadable hash file %s: %s", path, exc)
            return set()
        if not isinstance(data, list):
            log.warning("ignoring hash file %s: expected a JSON list", path)
            return set()
        return {h for h in data if isinstance(h, str)}

    def _load(self):
        """Load today's hashes and bloom filter."""
        day_path = self._day_file(self._today)
        if day_path.exists():
            self._daily_hashes = self._read_day_file(day_path)

        # Load or rebuild bloom filter
        bloom_path = self._bloom_file()
        if bloom_path.exists():
            try:
                payload = json.loads(bloom_path.read_text(encoding="utf-8"))
                self._bloom = BloomFilter.from_dict(payload)
            except (OSError, ValueError, KeyError, TypeError) as exc:
                # ValueError covers JSON, base64 and geometry errors alike.
                log.warning("bloom filter unreadable (%s); rebuilding", exc)
                self._bloom = None

        if self._bloom is None:
            self._rebuild_bloom()

    def _rebuild_bloom(self):
        """Rebuild bloom filter from all recent daily files."""
        now = datetime.now(timezone.utc)
        hashes = set(self._daily_hashes)
        # +1 so the window matches _rotate(), which keeps the file dated
        # exactly `retention_days` ago.
        for day_offset in range(self.retention_days + 1):
            day = (now - timedelta(days=day_offset)).strftime("%Y-%m-%d")
            day_path = self._day_file(day)
            if day_path.exists():
                hashes |= self._read_day_file(day_path)

        capacity = max(len(hashes) * 2, 10_000)
        self._bloom = BloomFilter(capacity=capacity)
        for h in hashes:
            self._bloom.add(h)

    def _save_bloom(self):
        """Persist the bloom filter, if one is loaded."""
        if self._bloom is not None:
            _atomic_write_text(self._bloom_file(), json.dumps(self._bloom.to_dict()))

    def _save(self):
        """Persist today's hashes and bloom filter."""
        day_path = self._day_file(self._today)
        _atomic_write_text(day_path, json.dumps(sorted(self._daily_hashes)))
        self._save_bloom()

    def _rotate(self):
        """Delete daily hash files older than retention period.

        When anything was deleted the bloom filter is rebuilt from the
        remaining files; otherwise the expired hashes would stay set in the
        filter forever and keep being reported as duplicates.
        """
        cutoff = (
            datetime.now(timezone.utc) - timedelta(days=self.retention_days)
        ).strftime("%Y-%m-%d")
        removed = False
        for path in HASH_DIR.glob("????-??-??.json"):
            if path.stem < cutoff:
                try:
                    path.unlink()
                    removed = True
                except FileNotFoundError:
                    pass  # already removed, e.g. by a concurrent run
        if removed:
            self._rebuild_bloom()
            self._save_bloom()

    def is_duplicate(self, h: str) -> bool:
        """Check if hash has been seen in current day or bloom filter."""
        if h in self._daily_hashes:
            return True
        return self._bloom is not None and h in self._bloom

    def add(self, h: str):
        """Add a hash. Saves and rotates periodically."""
        self._daily_hashes.add(h)
        if self._bloom is not None:
            self._bloom.add(h)
        # Save every 100 additions or on explicit call
        if len(self._daily_hashes) % 100 == 0:
            self._save()
            self._rotate()

    def flush(self):
        """Force save and rotate."""
        self._save()
        self._rotate()

    def stats(self) -> dict:
        """Return dedup store statistics."""
        file_count = len(list(HASH_DIR.glob("*.json")))
        total_hashes = 0
        for path in HASH_DIR.glob("????-??-??.json"):
            try:
                total_hashes += len(json.loads(path.read_text(encoding="utf-8")))
            except (OSError, ValueError, TypeError):
                pass  # unreadable file: statistics are best-effort
        return {
            "today_count": len(self._daily_hashes),
            "total_files": file_count,
            "total_hashes": total_hashes,
            "retention_days": self.retention_days,
            "bloom_size": self._bloom.size if self._bloom is not None else 0,
        }


# --- Quality Check Types ---

@dataclass
class QualityResult:
    """Result of a quality check on a single entry."""
    passed: bool
    checks_run: int
    checks_failed: int
    score: float  # 0.0-1.0
    reasons: List[str] = field(default_factory=list)
    entry_index: int = -1
    hash: str = ""

    def to_dict(self):
        return asdict(self)


@dataclass
class GateReport:
    """Report from a quality gate run."""
    file: str
    type: str
    total: int
    passed: int
    rejected: int
    score: float
    rejected_indices: List[int] = field(default_factory=list)
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self):
        return asdict(self)


# ============================================================
# Check Functions
# ============================================================

def entry_hash(entry: dict) -> str:
    """Hash an entry for deduplication (first 16 hex chars of SHA-256)."""
    canonical = json.dumps(entry, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode()).hexdigest()[:16]


def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
    """Check that required fields are non-empty."""
    errors = []
    for f in fields:
        val = entry.get(f)
        if val is None:
            errors.append(f"missing_field: {f}")
        elif isinstance(val, str) and not val.strip():
            errors.append(f"empty_field: {f}")
        elif isinstance(val, list) and not val:
            errors.append(f"empty_list: {f}")
    return errors


def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
    """Check that string fields meet minimum lengths."""
    errors = []
    for f, min_len in field_lengths.items():
        val = entry.get(f)
        if isinstance(val, str) and len(val) < min_len:
            errors.append(f"short_field: {f} ({len(val)} < {min_len})")
    return errors


def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
    """Check for duplicate entries based on key fields.

    Returns a mapping from the index of each repeated entry to its error
    list. Entries that are not dicts are skipped here; the caller is expected
    to reject them separately.
    """
    seen = {}
    errors = {}
    for i, entry in enumerate(entries):
        if not isinstance(entry, dict):
            continue
        key_str = str(tuple(entry.get(f, "") for f in key_fields))
        if key_str in seen:
            errors[i] = [f"duplicate_of_index: {seen[key_str]}"]
        else:
            seen[key_str] = i
    return errors


def check_training_pair(entry: dict) -> List[str]:
    """Validate a training pair (prompt/response)."""
    errors = []
    errors.extend(check_not_empty(entry, ["prompt", "response"]))

    # Check response isn't just echoing the prompt. Both values must be
    # strings before .strip() is safe; other types are left to other checks.
    prompt = entry.get("prompt", "")
    response = entry.get("response", "")
    if (
        isinstance(prompt, str)
        and isinstance(response, str)
        and prompt
        and response
        and prompt.strip() == response.strip()
    ):
        errors.append("response_equals_prompt")

    # Check response has substance
    if isinstance(response, str) and len(response) < 10:
        errors.append(f"response_too_short: {len(response)} chars")

    return errors


def check_scene_description(entry: dict) -> List[str]:
    """Validate a scene description entry."""
    errors = []
    errors.extend(check_not_empty(entry, ["song", "beat", "lyric_line", "scene"]))

    scene = entry.get("scene")
    if isinstance(scene, dict):
        errors.extend(check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"]))
        errors.extend(check_string_min_length(scene, {"description": 10}))

        colors = scene.get("colors", [])
        if isinstance(colors, list) and len(colors) > 5:
            errors.append(f"too_many_colors: {len(colors)} > 5")

    return errors


def check_knowledge_entry(entry: dict) -> List[str]:
    """Validate a knowledge file entry."""
    errors = []
    errors.extend(check_not_empty(entry, ["title", "content"]))

    # Check for placeholder content
    content = entry.get("content", "")
    if isinstance(content, str):
        lowered = content.lower()
        placeholders = ["TODO", "FIXME", "PLACEHOLDER", "[INSERT", "lorem ipsum"]
        for p in placeholders:
            if p.lower() in lowered:
                errors.append(f"placeholder_content: '{p}' found")

    errors.extend(check_string_min_length(entry, {"content": 50}))

    return errors


def check_prompt_enhancement(entry: dict) -> List[str]:
    """Validate a prompt enhancement pair (terse/rich)."""
    errors = []
    errors.extend(check_not_empty(entry, ["terse", "rich"]))

    terse = entry.get("terse", "")
    rich = entry.get("rich", "")

    # Rich should be longer than terse
    if isinstance(terse, str) and isinstance(rich, str) and len(rich) <= len(terse):
        errors.append("rich_not_longer_than_terse")

    errors.extend(check_string_min_length(entry, {"rich": 20}))

    return errors


def check_adversary_entry(entry: dict) -> List[str]:
    """Validate an adversary corpus entry."""
    errors = []
    errors.extend(check_not_empty(entry, ["id", "family", "prompt"]))

    # Check prompt isn't empty or placeholder
    prompt = entry.get("prompt", "")
    if isinstance(prompt, str) and len(prompt) < 10:
        errors.append(f"prompt_too_short: {len(prompt)} chars")

    return errors


# Patterns for content that violates core values; compiled once at import.
_SOUL_VIOLATIONS = [
    (re.compile(r"compute.*value.*human.*life"), "soul_violation: computing human life value"),
    (re.compile(r"sexualiz.*child"), "soul_violation: sexualizing children"),
    (re.compile(r"create.*weapon.*kill"), "soul_violation: weapon creation"),
    (re.compile(r"enslav"), "soul_violation: enslavement content"),
]


def check_soul_compliance(text: str) -> List[str]:
    """Check text for SOUL.md compliance (basic checks).

    Matching is case-insensitive (the text is lowercased first). A falsy
    *text* (``None`` or ``""``) yields no errors.
    """
    text_lower = text.lower() if text else ""
    return [msg for pattern, msg in _SOUL_VIOLATIONS if pattern.search(text_lower)]


# ============================================================
# Gate Runners
# ============================================================

CHECK_MAP = {
    "training_pairs": check_training_pair,
    "training_pair": check_training_pair,
    "scene_descriptions": check_scene_description,
    "scene_description": check_scene_description,
    "knowledge": check_knowledge_entry,
    "prompt_enhancement": check_prompt_enhancement,
    "adversary": check_adversary_entry,
    "adversary_corpus": check_adversary_entry,
}

# Top-level fields whose text is fed to the SOUL compliance check.
_SOUL_TEXT_FIELDS = ("response", "rich", "description", "content", "lyric_line")


def _read_jsonl(path: Path) -> List[Any]:
    """Parse every non-blank line of *path* as JSON.

    Raises:
        json.JSONDecodeError: on the first malformed line (the file line
            number is logged before re-raising).
        OSError: if the file cannot be read.
    """
    entries = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError as exc:
                log.error("%s:%d: invalid JSON: %s", path, line_no, exc)
                raise
    return entries


def _collect_soul_text(entry: dict) -> str:
    """Join the string fields of *entry* that are subject to the SOUL check."""
    parts = [entry[f] for f in _SOUL_TEXT_FIELDS if isinstance(entry.get(f), str)]
    scene = entry.get("scene")
    if isinstance(scene, dict):
        description = scene.get("description", "")
        if isinstance(description, str):
            parts.append(description)
    return " ".join(parts)


def run_gate(input_path: str, entry_type: str, dedup_store: Optional[HashDedupStore] = None) -> GateReport:
    """Run quality gate on a JSONL file.

    Args:
        input_path: Path to JSONL file
        entry_type: Type of entries (training_pairs, scene_descriptions, etc.)
        dedup_store: Optional hash dedup store for cross-run dedup. If None, creates one.

    Returns:
        A GateReport. A missing file yields an all-zero report; an unknown
        *entry_type* yields an all-zero report with ``rejected_indices=[-1]``.

    Raises:
        json.JSONDecodeError: if a line of the file is not valid JSON.
        OSError: if the file cannot be read or state cannot be persisted.
    """
    path = Path(input_path)
    if not path.exists():
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)

    check_fn = CHECK_MAP.get(entry_type)
    if not check_fn:
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
                          rejected_indices=[-1])  # unknown type

    entries = _read_jsonl(path)

    if dedup_store is None:
        dedup_store = HashDedupStore()

    # Within-file deduplication check
    key_fields = _get_key_fields(entry_type)
    dup_errors = check_no_duplicates(entries, key_fields)

    passed = 0
    rejected = 0
    rejected_indices = []
    total_score = 0.0
    cross_run_dupes = 0
    seen_in_file: Set[str] = set()

    for i, entry in enumerate(entries):
        if not isinstance(entry, dict):
            # Valid JSON but not an object (e.g. a list or number): no check
            # function can handle it, so reject it outright.
            errors = ["invalid_entry: not a JSON object"]
        else:
            errors = list(check_fn(entry))

            # Add within-file duplicate errors
            errors.extend(dup_errors.get(i, []))

            # Cross-run hash dedup. An exact repeat inside this file is
            # already reported by check_no_duplicates, so it must not also be
            # blamed on a prior run just because its first occurrence was
            # added to the store a moment ago.
            h = entry_hash(entry)
            if h not in seen_in_file:
                if dedup_store.is_duplicate(h):
                    errors.append(f"cross_run_duplicate: hash {h} seen in prior run")
                    cross_run_dupes += 1
                else:
                    dedup_store.add(h)
                seen_in_file.add(h)

            # Add SOUL compliance check for text content
            errors.extend(check_soul_compliance(_collect_soul_text(entry)))

        if errors:
            rejected += 1
            rejected_indices.append(i)
        else:
            passed += 1

        # Score: 1.0 if no errors, decreasing with each error
        total_score += max(0.0, 1.0 - (len(errors) * 0.2))

    avg_score = total_score / len(entries) if entries else 0.0

    # Flush dedup store
    dedup_store.flush()

    report = GateReport(
        file=str(path),
        type=entry_type,
        total=len(entries),
        passed=passed,
        rejected=rejected,
        score=round(avg_score, 3),
        rejected_indices=rejected_indices[:50],  # limit for readability
    )

    # Save stats
    _save_stats(report)

    if cross_run_dupes > 0:
        log.info("cross-run dedup: %d duplicates found", cross_run_dupes)

    return report


def _get_key_fields(entry_type: str) -> List[str]:
    """Get key fields for deduplication based on entry type."""
    key_map = {
        "training_pairs": ["prompt", "response"],
        "training_pair": ["prompt", "response"],
        "scene_descriptions": ["song", "beat"],
        "scene_description": ["song", "beat"],
        "knowledge": ["title"],
        "prompt_enhancement": ["terse", "rich"],
        "adversary": ["id", "prompt"],
        "adversary_corpus": ["id", "prompt"],
    }
    return key_map.get(entry_type, ["id"])


def _save_stats(report: GateReport):
    """Append quality stats to the stats file. Rotates to keep last 1000."""
    STATS_FILE.parent.mkdir(parents=True, exist_ok=True)

    stats = []
    if STATS_FILE.exists():
        try:
            with open(STATS_FILE, encoding="utf-8") as f:
                stats = json.load(f)
        except (OSError, ValueError):
            stats = []
        if not isinstance(stats, list):
            # A stats file holding anything but a list cannot be appended to.
            stats = []

    stats.append(report.to_dict())

    # Rotate: keep last 1000 entries
    if len(stats) > 1000:
        stats = stats[-1000:]

    _atomic_write_text(STATS_FILE, json.dumps(stats, indent=2))


def show_status():
    """Show quality gate statistics."""
    if not STATS_FILE.exists():
        print("No quality stats found.")
        return

    try:
        with open(STATS_FILE, encoding="utf-8") as f:
            stats = json.load(f)
    except (OSError, ValueError) as exc:
        log.error("could not read %s: %s", STATS_FILE, exc)
        return
    if not isinstance(stats, list):
        log.error("%s does not contain a list of runs", STATS_FILE)
        return

    print(f"\nQuality Gate Stats — {len(stats)} runs")
    print()

    # Group by type
    by_type: Dict[str, List[dict]] = {}
    for s in stats:
        if not isinstance(s, dict):
            continue
        by_type.setdefault(str(s.get("type", "unknown")), []).append(s)

    for t, runs in sorted(by_type.items()):
        total_entries = sum(r.get("total", 0) for r in runs)
        total_rejected = sum(r.get("rejected", 0) for r in runs)
        avg_score = sum(r.get("score", 0) for r in runs) / len(runs)
        print(f"  {t:25} {len(runs):4} runs | {total_entries:6} entries | {total_rejected:4} rejected | avg score: {avg_score:.3f}")


def _gate_file(input_path: str, entry_type: str) -> int:
    """Gate one file, print its report, and return the process exit code.

    Returns 0 when nothing was rejected, 1 when some entries were rejected and
    2 for a missing file, an unknown entry type or a read/parse error.
    """
    if entry_type not in CHECK_MAP:
        log.error("%s: unknown entry type %r", input_path, entry_type)
        return 2
    if not Path(input_path).is_file():
        log.error("%s: input file not found", input_path)
        return 2
    try:
        report = run_gate(input_path, entry_type)
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode errors
        log.error("%s: gate aborted: %s", input_path, exc)
        return 2
    _print_report(report)
    return 0 if report.rejected == 0 else 1


def main():
    """Command-line entry point; see the module docstring for exit codes."""
    import argparse
    parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
    parser.add_argument("--input", default=None, help="Input JSONL file")
    parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
    parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
    parser.add_argument("--status", action="store_true", help="Show quality stats")
    args = parser.parse_args()

    if args.status:
        show_status()
        return

    if args.dir:
        directory = Path(args.dir)
        if not directory.is_dir():
            log.error("%s: not a directory", directory)
            sys.exit(2)
        exit_code = 0
        for f in sorted(directory.glob("*.jsonl")):
            t = args.type or _infer_type(f.name)
            # Keep the most severe code seen across all files.
            exit_code = max(exit_code, _gate_file(str(f), t))
        sys.exit(exit_code)
    elif args.input:
        # Infer from the file name only, so directory names such as
        # "training/" cannot influence the detected type.
        t = args.type or _infer_type(Path(args.input).name)
        sys.exit(_gate_file(args.input, t))
    else:
        parser.print_help()


def _infer_type(filename: str) -> str:
    """Infer entry type from filename."""
    name = filename.lower()
    if "scene" in name:
        return "scene_descriptions"
    if "training" in name or "pair" in name:
        return "training_pairs"
    if "knowledge" in name:
        return "knowledge"
    if "adversary" in name or "attack" in name:
        return "adversary"
    if "prompt" in name or "enhance" in name:
        return "prompt_enhancement"
    return "training_pairs"  # default


def _print_report(report: GateReport):
    """Print a human-readable gate report."""
    status = "PASS" if report.rejected == 0 else f"FAIL ({report.rejected} rejected)"
    print(f"  {report.file}: {status} | {report.passed}/{report.total} passed | score: {report.score:.3f}")


if __name__ == "__main__":
    main()