diff --git a/pipeline/quality_gate.py b/pipeline/quality_gate.py
new file mode 100755
index 00000000..e15d68c4
--- /dev/null
+++ b/pipeline/quality_gate.py
@@ -0,0 +1,419 @@
+#!/usr/bin/env python3
+"""
+quality_gate.py — Quality Gate for Pipeline Outputs
+
+Validates all pipeline outputs before saving. Rejects bad outputs,
+tracks quality scores, and supports re-queue for regeneration.
+
+Usage:
+    python3 quality_gate.py --input output.jsonl --type training_pairs
+    python3 quality_gate.py --input output.jsonl --type knowledge
+    python3 quality_gate.py --input output.jsonl --type scene_descriptions
+    python3 quality_gate.py --dir pipeline/output/ --type training_pairs
+    python3 quality_gate.py --status    # show quality stats
+
+Exit codes:
+    0 = all outputs passed
+    1 = some outputs rejected
+    2 = file/parse error
+"""
+
+import json
+import sys
+import hashlib
+import re
+from pathlib import Path
+from datetime import datetime, timezone
+from dataclasses import dataclass, field, asdict
+from typing import List, Dict
+
+STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
+
+# --- Quality Check Types ---
+
+@dataclass
+class QualityResult:
+    """Result of a quality check on a single entry."""
+    passed: bool
+    checks_run: int
+    checks_failed: int
+    score: float  # 0.0-1.0
+    reasons: List[str] = field(default_factory=list)
+    entry_index: int = -1
+    hash: str = ""
+
+    def to_dict(self):
+        return asdict(self)
+
+
+@dataclass
+class GateReport:
+    """Report from a quality gate run."""
+    file: str
+    type: str
+    total: int
+    passed: int
+    rejected: int
+    score: float
+    rejected_indices: List[int] = field(default_factory=list)
+    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
+
+    def to_dict(self):
+        return asdict(self)
+
+
+# ============================================================
+# Check Functions
+# ============================================================
+
+def entry_hash(entry: dict) -> str:
+    """Hash an entry for deduplication."""
+    return hashlib.sha256(json.dumps(entry, sort_keys=True, ensure_ascii=False).encode()).hexdigest()[:16]
+
+
+def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
+    """Check that required fields are non-empty."""
+    errors = []
+    for f in fields:
+        val = entry.get(f)
+        if val is None:
+            errors.append(f"missing_field: {f}")
+        elif isinstance(val, str) and len(val.strip()) == 0:
+            errors.append(f"empty_field: {f}")
+        elif isinstance(val, list) and len(val) == 0:
+            errors.append(f"empty_list: {f}")
+    return errors
+
+
+def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
+    """Check that string fields meet minimum lengths."""
+    errors = []
+    for f, min_len in field_lengths.items():
+        val = entry.get(f)
+        if isinstance(val, str) and len(val) < min_len:
+            errors.append(f"short_field: {f} ({len(val)} < {min_len})")
+    return errors
+
+
+def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
+    """Check for duplicate entries based on key fields."""
+    seen = {}
+    errors = {}
+    for i, entry in enumerate(entries):
+        key = tuple(entry.get(f, "") for f in key_fields)
+        key_str = str(key)
+        if key_str in seen:
+            errors[i] = [f"duplicate_of_index: {seen[key_str]}"]
+        else:
+            seen[key_str] = i
+    return errors
+
+
+def check_training_pair(entry: dict) -> List[str]:
+    """Validate a training pair (prompt/response)."""
+    errors = []
+    errors.extend(check_not_empty(entry, ["prompt", "response"]))
+
+    # Check response isn't just echoing the prompt
+    prompt = entry.get("prompt", "")
+    response = entry.get("response", "")
+    if prompt and response and prompt.strip() == response.strip():
+        errors.append("response_equals_prompt")
+
+    # Check response has substance
+    if isinstance(response, str) and len(response) < 10:
+        errors.append(f"response_too_short: {len(response)} chars")
+
+    return errors
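+
+
+# Illustrative example (not part of the pipeline contract): an entry like
+#   {"prompt": "Summarize the plot.", "response": "A compact summary naming the protagonist and stakes."}
+# passes check_training_pair, while {"prompt": "x", "response": "x"} fails with
+# ["response_equals_prompt", "response_too_short: 1 chars"].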
+
+
+def check_scene_description(entry: dict) -> List[str]:
+    """Validate a scene description entry."""
+    errors = []
+    errors.extend(check_not_empty(entry, ["song", "beat", "lyric_line", "scene"]))
+
+    scene = entry.get("scene")
+    if isinstance(scene, dict):
+        errors.extend(check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"]))
+        errors.extend(check_string_min_length(scene, {"description": 10}))
+
+        colors = scene.get("colors", [])
+        if isinstance(colors, list) and len(colors) > 5:
+            errors.append(f"too_many_colors: {len(colors)} > 5")
+
+    return errors
+
+
+def check_knowledge_entry(entry: dict) -> List[str]:
+    """Validate a knowledge file entry."""
+    errors = []
+    errors.extend(check_not_empty(entry, ["title", "content"]))
+
+    # Check for placeholder content
+    content = entry.get("content", "")
+    if isinstance(content, str):
+        placeholders = ["TODO", "FIXME", "PLACEHOLDER", "[INSERT", "lorem ipsum"]
+        for p in placeholders:
+            if p.lower() in content.lower():
+                errors.append(f"placeholder_content: '{p}' found")
+
+    errors.extend(check_string_min_length(entry, {"content": 50}))
+
+    return errors
+
+
+def check_prompt_enhancement(entry: dict) -> List[str]:
+    """Validate a prompt enhancement pair (terse/rich)."""
+    errors = []
+    errors.extend(check_not_empty(entry, ["terse", "rich"]))
+
+    terse = entry.get("terse", "")
+    rich = entry.get("rich", "")
+
+    # Rich should be longer than terse
+    if isinstance(terse, str) and isinstance(rich, str) and len(rich) <= len(terse):
+        errors.append("rich_not_longer_than_terse")
+
+    errors.extend(check_string_min_length(entry, {"rich": 20}))
+
+    return errors
+
+
+def check_adversary_entry(entry: dict) -> List[str]:
+    """Validate an adversary corpus entry."""
+    errors = []
+    errors.extend(check_not_empty(entry, ["id", "family", "prompt"]))
+
+    # Check prompt isn't empty or placeholder
+    prompt = entry.get("prompt", "")
+    if isinstance(prompt, str) and len(prompt) < 10:
+        errors.append(f"prompt_too_short: {len(prompt)} chars")
+
+    return errors
+
+
+def check_soul_compliance(text: str) -> List[str]:
+    """Check text for SOUL.md compliance (basic checks)."""
+    errors = []
+    text_lower = text.lower() if text else ""
+
+    # Check for content that violates core values
+    violations = [
+        ("compute.*value.*human.*life", "soul_violation: computing human life value"),
+        ("sexualiz.*child", "soul_violation: sexualizing children"),
+        ("create.*weapon.*kill", "soul_violation: weapon creation"),
+        ("enslav", "soul_violation: enslavement content"),
+    ]
+    for pattern, msg in violations:
+        if re.search(pattern, text_lower):
+            errors.append(msg)
+
+    return errors
+
+
+# ============================================================
+# Gate Runners
+# ============================================================
+
+CHECK_MAP = {
+    "training_pairs": check_training_pair,
+    "training_pair": check_training_pair,
+    "scene_descriptions": check_scene_description,
+    "scene_description": check_scene_description,
+    "knowledge": check_knowledge_entry,
+    "prompt_enhancement": check_prompt_enhancement,
+    "adversary": check_adversary_entry,
+    "adversary_corpus": check_adversary_entry,
+}
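+
+# Scoring sketch (mirrors the arithmetic in run_gate below): each failed
+# check costs 0.2, floored at 0.0, so an entry with two errors scores
+# max(0.0, 1.0 - 2 * 0.2) = 0.6, and a file holding that entry plus one
+# clean entry averages (1.0 + 0.6) / 2 = 0.8.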
+
+
+def run_gate(input_path: str, entry_type: str) -> GateReport:
+    """Run quality gate on a JSONL file."""
+    path = Path(input_path)
+    if not path.exists():
+        # Missing input is a file error per the documented exit codes.
+        print(f"error: no such file: {path}", file=sys.stderr)
+        sys.exit(2)
+
+    check_fn = CHECK_MAP.get(entry_type)
+    if not check_fn:
+        print(f"warning: unknown entry type '{entry_type}', nothing validated", file=sys.stderr)
+        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
+                          rejected_indices=[-1])  # unknown type
+
+    entries = []
+    try:
+        with open(path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    entries.append(json.loads(line))
+    except (OSError, json.JSONDecodeError) as e:
+        # Unreadable file or malformed JSONL is a parse error (exit code 2).
+        print(f"error: failed to read {path}: {e}", file=sys.stderr)
+        sys.exit(2)
+
+    # Deduplication check
+    key_fields = _get_key_fields(entry_type)
+    dup_errors = check_no_duplicates(entries, key_fields)
+
+    passed = 0
+    rejected = 0
+    rejected_indices = []
+    total_score = 0.0
+
+    for i, entry in enumerate(entries):
+        errors = check_fn(entry)
+
+        # Add duplicate errors
+        if i in dup_errors:
+            errors.extend(dup_errors[i])
+
+        # Add SOUL compliance check for text content
+        text_content = ""
+        for f in ["response", "rich", "description", "content", "lyric_line"]:
+            val = entry.get(f)
+            if isinstance(val, str):
+                text_content += val + " "
+        if isinstance(entry.get("scene"), dict):
+            text_content += entry["scene"].get("description", "")
+
+        soul_errors = check_soul_compliance(text_content)
+        errors.extend(soul_errors)
+
+        if errors:
+            rejected += 1
+            rejected_indices.append(i)
+        else:
+            passed += 1
+
+        # Score: 1.0 if no errors, decreasing with each error
+        entry_score = max(0.0, 1.0 - (len(errors) * 0.2))
+        total_score += entry_score
+
+    avg_score = total_score / len(entries) if entries else 0.0
+
+    report = GateReport(
+        file=str(path),
+        type=entry_type,
+        total=len(entries),
+        passed=passed,
+        rejected=rejected,
+        score=round(avg_score, 3),
+        rejected_indices=rejected_indices[:50],  # limit for readability
+    )
+
+    # Save stats
+    _save_stats(report)
+
+    return report
+
+
+def _get_key_fields(entry_type: str) -> List[str]:
+    """Get key fields for deduplication based on entry type."""
+    key_map = {
+        "training_pairs": ["prompt", "response"],
+        "training_pair": ["prompt", "response"],
+        "scene_descriptions": ["song", "beat"],
+        "scene_description": ["song", "beat"],
+        "knowledge": ["title"],
+        "prompt_enhancement": ["terse", "rich"],
+        "adversary": ["id", "prompt"],
+        "adversary_corpus": ["id", "prompt"],
+    }
+    return key_map.get(entry_type, ["id"])
+
+
+def _save_stats(report: GateReport):
+    """Append quality stats to the stats file."""
+    STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
+
+    stats = []
+    if STATS_FILE.exists():
+        try:
+            with open(STATS_FILE) as f:
+                stats = json.load(f)
+        except (json.JSONDecodeError, IOError):
+            stats = []
+
+    stats.append(report.to_dict())
+
+    # Keep last 1000 entries
+    stats = stats[-1000:]
+
+    with open(STATS_FILE, "w") as f:
+        json.dump(stats, f, indent=2)
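+
+
+# Shape of one stored stats record, as read back by show_status below
+# (field values here are hypothetical):
+#   {"file": "pipeline/output/pairs.jsonl", "type": "training_pairs",
+#    "total": 120, "passed": 114, "rejected": 6, "score": 0.957,
+#    "rejected_indices": [3, 17], "timestamp": "2025-01-01T00:00:00+00:00"}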
+
+
+def show_status():
+    """Show quality gate statistics."""
+    if not STATS_FILE.exists():
+        print("No quality stats found.")
+        return
+
+    with open(STATS_FILE) as f:
+        stats = json.load(f)
+
+    print(f"\nQuality Gate Stats — {len(stats)} runs")
+    print()
+
+    # Group by type
+    by_type = {}
+    for s in stats:
+        t = s.get("type", "unknown")
+        if t not in by_type:
+            by_type[t] = []
+        by_type[t].append(s)
+
+    for t, runs in sorted(by_type.items()):
+        total_entries = sum(r.get("total", 0) for r in runs)
+        total_passed = sum(r.get("passed", 0) for r in runs)
+        total_rejected = sum(r.get("rejected", 0) for r in runs)
+        avg_score = sum(r.get("score", 0) for r in runs) / len(runs) if runs else 0
+        print(f"  {t:25} {len(runs):4} runs | {total_entries:6} entries | {total_passed:6} passed | {total_rejected:4} rejected | avg score: {avg_score:.3f}")
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
+    parser.add_argument("--input", default=None, help="Input JSONL file")
+    parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
+    parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
+    parser.add_argument("--status", action="store_true", help="Show quality stats")
+    args = parser.parse_args()
+
+    if args.status:
+        show_status()
+        return
+
+    if args.dir:
+        any_rejected = False
+        for f in sorted(Path(args.dir).glob("*.jsonl")):
+            t = args.type or _infer_type(f.name)
+            report = run_gate(str(f), t)
+            _print_report(report)
+            if report.rejected:
+                any_rejected = True
+        # Exit 1 if any file had rejected entries, per the documented exit codes.
+        sys.exit(1 if any_rejected else 0)
+    elif args.input:
+        t = args.type or _infer_type(args.input)
+        report = run_gate(args.input, t)
+        _print_report(report)
+        sys.exit(0 if report.rejected == 0 else 1)
+    else:
+        parser.print_help()
+
+
+def _infer_type(filename: str) -> str:
+    """Infer entry type from filename."""
+    name = filename.lower()
+    if "scene" in name:
+        return "scene_descriptions"
+    if "training" in name or "pair" in name:
+        return "training_pairs"
+    if "knowledge" in name:
+        return "knowledge"
+    if "adversary" in name or "attack" in name:
+        return "adversary"
+    if "prompt" in name or "enhance" in name:
+        return "prompt_enhancement"
+    return "training_pairs"  # default
+
+
+def _print_report(report: GateReport):
+    """Print a human-readable gate report."""
+    status = "PASS" if report.rejected == 0 else f"FAIL ({report.rejected} rejected)"
+    print(f"  {report.file}: {status} | {report.passed}/{report.total} passed | score: {report.score:.3f}")
+
+
+if __name__ == "__main__":
+    main()