# Change summary:
#   - BloomFilter class: O(n) space, configurable error rate
#   - HashDedupStore: daily JSON files, 7-day retention, auto-rotation
#   - Cross-run dedup in run_gate(): rejects entries seen in prior runs
#   - CLI: --dedup-stats, --dedup-purge commands
#   - Stats file rotation capped at 1000 entries
#   - Purge command for full hash reset
# (Source listing: 620 lines, 20 KiB, Python)
#!/usr/bin/env python3
"""
quality_gate.py — Quality Gate for Pipeline Outputs

Validates all pipeline outputs before saving. Rejects bad outputs,
tracks quality scores, and supports re-queue for regeneration.

Usage:
    python3 quality_gate.py --input output.jsonl --type training_pairs
    python3 quality_gate.py --input output.jsonl --type knowledge
    python3 quality_gate.py --input output.jsonl --type scene_descriptions
    python3 quality_gate.py --dir pipeline/output/ --type training_pairs
    python3 quality_gate.py --status   # show quality stats

Exit codes:
    0 = all outputs passed
    1 = some outputs rejected
    2 = file/parse error
"""


import base64
import hashlib
import json
import math
import os
import re
import struct
import sys
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Set


# Root of the pipeline's on-disk state, under the user's home directory.
PIPELINE_DIR = Path.home().joinpath(".hermes", "pipeline")

# JSON list of GateReport dicts, one per gate run.
STATS_FILE = PIPELINE_DIR.joinpath("quality_stats.json")

# Daily hash files and the persisted bloom filter used for cross-run dedup.
HASH_DIR = PIPELINE_DIR.joinpath("quality_hashes")

# Retention window, in days, for cross-run dedup hashes.
HASH_RETENTION_DAYS = 7


# ============================================================
# Bloom Filter — Memory-efficient dedup at scale
# ============================================================

class BloomFilter:
    """Probabilistic set for membership testing.

    False positives are possible (at roughly ``error_rate`` while no more than
    ``capacity`` items have been added); false negatives are not.

    Raises:
        ValueError: ``capacity`` is not positive or ``error_rate`` is not in (0, 1).
    """

    def __init__(self, capacity: int = 100_000, error_rate: float = 0.01):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        if not 0.0 < error_rate < 1.0:
            raise ValueError(f"error_rate must be in (0, 1), got {error_rate}")
        self.capacity = capacity
        self.error_rate = error_rate
        # Standard sizing: m = -n*ln(p) / (ln 2)^2 bits and k = (m/n)*ln 2 hashes.
        self.size = max(64, int(-capacity * math.log(error_rate) / (math.log(2) ** 2)))
        self.num_hashes = max(1, int(self.size / capacity * math.log(2)))
        self._bitarray = bytearray((self.size + 7) // 8)

    def _hash_indices(self, item: str) -> List[int]:
        """Return ``num_hashes`` bit positions for *item* using double hashing.

        The hash functions must stay unchanged: persisted filters (see
        ``to_dict``) are only readable with the same index derivation.
        """
        data = item.encode()
        h1 = int.from_bytes(hashlib.sha256(data).digest()[:8], "little")
        h2 = int.from_bytes(hashlib.md5(data).digest()[:8], "little")
        return [(h1 + i * h2) % self.size for i in range(self.num_hashes)]

    def add(self, item: str):
        """Set every bit position for *item*."""
        for idx in self._hash_indices(item):
            self._bitarray[idx // 8] |= 1 << (idx % 8)

    def __contains__(self, item: str) -> bool:
        """Return True if every bit for *item* is set (maybe a false positive)."""
        return all(self._bitarray[idx // 8] & (1 << (idx % 8)) for idx in self._hash_indices(item))

    def to_dict(self) -> dict:
        """Serialise the filter to a JSON-compatible dict (bits base64-encoded)."""
        return {
            "capacity": self.capacity,
            "error_rate": self.error_rate,
            "size": self.size,
            "num_hashes": self.num_hashes,
            "data": base64.b64encode(bytes(self._bitarray)).decode(),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "BloomFilter":
        """Rebuild a filter produced by ``to_dict``.

        The persisted ``size``/``num_hashes`` take precedence over values
        recomputed from capacity/error_rate, so a saved filter stays readable
        even if the sizing formula changes.

        Raises:
            KeyError: a required key is missing.
            ValueError: the bit data is not valid base64 or its length does not
                match ``size``.
        """
        bf = cls(capacity=d["capacity"], error_rate=d["error_rate"])
        bf.size = int(d.get("size", bf.size))
        bf.num_hashes = int(d.get("num_hashes", bf.num_hashes))
        data = base64.b64decode(d["data"])
        expected = (bf.size + 7) // 8
        if len(data) != expected:
            raise ValueError(f"bloom data is {len(data)} bytes, expected {expected}")
        bf._bitarray = bytearray(data)
        return bf


# ============================================================
# Hash Dedup Store — Rotating daily files + bloom filter
# ============================================================

class HashDedupStore:
    """Rotating hash store for cross-run deduplication.

    Strategy:
      - Daily JSON files ``<hash_dir>/YYYY-MM-DD.json`` hold a sorted list of
        entry hashes. They are the source of truth.
      - ``<hash_dir>/bloom.json`` persists a BloomFilter over the retained
        daily files. It is reused only on the UTC day it was built; otherwise
        (or if unreadable) it is rebuilt from the daily files, so hashes that
        rotate out of the retention window stop being reported as duplicates.
      - Rotation deletes daily files older than the retention window.

    The retention window is the store's current UTC day plus the preceding
    ``retention_days - 1`` days (always at least one day). Both rebuilding and
    rotation use it.

    Bloom membership can report an unseen hash as a duplicate (false
    positive), but never misses a hash that was added.
    """

    _DAY_FORMAT = "%Y-%m-%d"
    _BLOOM_NAME = "bloom.json"
    _SAVE_EVERY = 100  # persist + rotate whenever the daily set size hits a multiple of this

    def __init__(self, retention_days: int = HASH_RETENTION_DAYS, hash_dir: Optional[Path] = None):
        """Open (creating if needed) the store in *hash_dir* (default: HASH_DIR)."""
        self.retention_days = retention_days
        self.hash_dir = HASH_DIR if hash_dir is None else Path(hash_dir)
        self.hash_dir.mkdir(parents=True, exist_ok=True)
        self._today = datetime.now(timezone.utc).strftime(self._DAY_FORMAT)
        self._daily_hashes: Set[str] = set()
        self._bloom: Optional[BloomFilter] = None
        self._bloom_day: Optional[str] = None  # UTC day the current bloom was built
        self._bloom_count = 0  # items added to the current bloom (upper bound on distinct items)
        self._load()

    def _day_file(self, day: str) -> Path:
        return self.hash_dir / f"{day}.json"

    def _bloom_file(self) -> Path:
        return self.hash_dir / self._BLOOM_NAME

    def _window_days(self) -> List[str]:
        """Return the day stamps inside the retention window, newest first."""
        anchor = datetime.strptime(self._today, self._DAY_FORMAT)
        span = max(1, self.retention_days)
        return [(anchor - timedelta(days=offset)).strftime(self._DAY_FORMAT) for offset in range(span)]

    @staticmethod
    def _read_hashes(path: Path) -> Set[str]:
        """Read one daily hash file; a missing or corrupt file yields an empty set."""
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return set()
        if not isinstance(data, list):
            return set()
        return {h for h in data if isinstance(h, str)}

    @staticmethod
    def _atomic_write(path: Path, text: str):
        """Write *text* to *path* via a temp file + rename, so readers never see a partial file."""
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(text, encoding="utf-8")
        os.replace(tmp, path)

    def _load(self):
        """Load today's hashes, then reuse today's persisted bloom or rebuild it."""
        self._daily_hashes = self._read_hashes(self._day_file(self._today))

        try:
            payload = json.loads(self._bloom_file().read_text(encoding="utf-8"))
        except (OSError, ValueError):
            payload = None

        # A bloom built on an earlier day may still contain hashes that have since
        # rotated out of the window (and files from older versions carry no build
        # date), so only a bloom built today is reused.
        if isinstance(payload, dict) and payload.get("built_on") == self._today:
            try:
                self._bloom = BloomFilter.from_dict(payload)
                self._bloom_count = int(payload.get("count", 0))
                self._bloom_day = self._today
            except (KeyError, TypeError, ValueError):
                self._bloom = None

        if self._bloom is None:
            self._rebuild_bloom()

    def _rebuild_bloom(self):
        """Rebuild the bloom from every daily file in the window plus unsaved hashes."""
        hashes: Set[str] = set()
        for day in self._window_days():
            hashes |= self._read_hashes(self._day_file(day))
        hashes |= self._daily_hashes  # include in-memory hashes not yet flushed

        capacity = max(len(hashes) * 2, 10_000)
        self._bloom = BloomFilter(capacity=capacity)
        for h in hashes:
            self._bloom.add(h)
        self._bloom_day = self._today
        self._bloom_count = len(hashes)

    def _save(self):
        """Persist today's hashes and the bloom filter (tagged with its build day)."""
        self._atomic_write(self._day_file(self._today), json.dumps(sorted(self._daily_hashes)))

        if self._bloom is not None:
            payload = self._bloom.to_dict()
            payload["built_on"] = self._bloom_day
            payload["count"] = self._bloom_count
            self._atomic_write(self._bloom_file(), json.dumps(payload))

    def _rotate(self):
        """Delete daily hash files older than the retention window."""
        oldest_kept = self._window_days()[-1]
        for path in self.hash_dir.glob("????-??-??.json"):
            day = path.stem
            try:
                datetime.strptime(day, self._DAY_FORMAT)
            except ValueError:
                continue  # not a daily hash file; leave it alone
            if day < oldest_kept:  # ISO dates compare correctly as strings
                try:
                    path.unlink()
                except FileNotFoundError:
                    pass  # already removed (e.g. by a concurrent run)

    def is_duplicate(self, h: str) -> bool:
        """Return True if *h* is in today's set or (probabilistically) in the bloom."""
        if h in self._daily_hashes:
            return True
        return self._bloom is not None and h in self._bloom

    def add(self, h: str):
        """Record *h*; saves and rotates every ``_SAVE_EVERY`` distinct daily hashes."""
        self._daily_hashes.add(h)
        if self._bloom is not None:
            self._bloom.add(h)
            self._bloom_count += 1
            # Past capacity the false-positive rate climbs; resize from source data.
            if self._bloom_count > self._bloom.capacity:
                self._rebuild_bloom()
        if len(self._daily_hashes) % self._SAVE_EVERY == 0:
            self._save()
            self._rotate()

    def flush(self):
        """Force save and rotate."""
        self._save()
        self._rotate()

    def stats(self) -> dict:
        """Return dedup store statistics (file counts include bloom.json)."""
        file_count = sum(1 for _ in self.hash_dir.glob("*.json"))
        total_hashes = 0
        for path in self.hash_dir.glob("????-??-??.json"):
            try:
                total_hashes += len(json.loads(path.read_text(encoding="utf-8")))
            except (OSError, ValueError, TypeError):
                pass  # unreadable or corrupt file: excluded from the count
        return {
            "today_count": len(self._daily_hashes),
            "total_files": file_count,
            "total_hashes": total_hashes,
            "retention_days": self.retention_days,
            "bloom_size": self._bloom.size if self._bloom else 0,
        }


# --- Quality Check Types ---

@dataclass
class QualityResult:
    """Outcome of quality-checking a single entry (fields in constructor order)."""

    passed: bool  # overall verdict for the entry
    checks_run: int  # number of checks applied
    checks_failed: int  # number of checks that failed
    score: float  # 0.0-1.0
    reasons: List[str] = field(default_factory=list)  # failure messages
    entry_index: int = -1  # presumably the entry's position in its input; -1 by default
    hash: str = ""  # presumably the entry's dedup hash; empty by default

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict via ``dataclasses.asdict``."""
        return asdict(self)


@dataclass
class GateReport:
    """Summary of one quality gate run over a single input file."""

    file: str  # path of the gated file
    type: str  # entry type used to select the validator
    total: int  # number of entries read
    passed: int  # entries with no errors
    rejected: int  # entries with at least one error
    score: float  # mean per-entry score
    rejected_indices: List[int] = field(default_factory=list)  # rejected entry indices (run_gate caps at 50; [-1] marks an unknown type)
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # UTC ISO-8601 creation time

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dict (the form written to the stats file)."""
        return asdict(self)


# ============================================================
# Check Functions
# ============================================================

def entry_hash(entry: dict) -> str:
    """Return a 16-hex-character fingerprint of *entry* for deduplication.

    Keys are sorted before serialising, so dicts that differ only in key
    order produce the same fingerprint.
    """
    canonical = json.dumps(entry, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]


def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
    """Report required fields that are absent, None, blank strings, or empty lists.

    Other value types (numbers, dicts, ...) are accepted as-is.
    """
    problems: List[str] = []
    for name in fields:
        value = entry.get(name)
        if value is None:
            problems.append(f"missing_field: {name}")
            continue
        if isinstance(value, str):
            if not value.strip():
                problems.append(f"empty_field: {name}")
        elif isinstance(value, list):
            if not value:
                problems.append(f"empty_list: {name}")
    return problems


def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
    """Report string fields shorter than their required minimum length.

    Missing or non-string fields are skipped here; presence is
    ``check_not_empty``'s job.
    """
    problems: List[str] = []
    for name, minimum in field_lengths.items():
        value = entry.get(name)
        if not isinstance(value, str) or len(value) >= minimum:
            continue
        problems.append(f"short_field: {name} ({len(value)} < {minimum})")
    return problems


def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
    """Flag entries whose key-field values repeat an earlier entry's.

    Missing key fields count as ``""``. The key tuple is compared via its
    ``str()`` form, so unhashable values (lists, dicts) are supported.

    Returns:
        Mapping of each later duplicate's index to a one-item error list that
        names the index of the first occurrence.
    """
    first_index: Dict[str, int] = {}
    duplicates: Dict[int, List[str]] = {}
    for position, item in enumerate(entries):
        signature = str(tuple(item.get(name, "") for name in key_fields))
        original = first_index.setdefault(signature, position)
        if original != position:
            duplicates[position] = [f"duplicate_of_index: {original}"]
    return duplicates


def check_training_pair(entry: dict) -> List[str]:
    """Validate a training pair (prompt/response).

    Reports missing or empty prompt/response fields, non-string values, a
    response that merely echoes the prompt (ignoring surrounding whitespace),
    and responses shorter than 10 characters.
    """
    errors = []
    errors.extend(check_not_empty(entry, ["prompt", "response"]))

    prompt = entry.get("prompt", "")
    response = entry.get("response", "")

    # Non-string values cannot be compared/stripped below; report them as
    # entry errors instead of letting .strip() raise and abort the whole file.
    for name, value in (("prompt", prompt), ("response", response)):
        if value is not None and not isinstance(value, str):
            errors.append(f"non_string_field: {name}")

    # Check response isn't just echoing the prompt
    if (isinstance(prompt, str) and isinstance(response, str)
            and prompt and response and prompt.strip() == response.strip()):
        errors.append("response_equals_prompt")

    # Check response has substance
    if isinstance(response, str) and len(response) < 10:
        errors.append(f"response_too_short: {len(response)} chars")

    return errors


def check_scene_description(entry: dict) -> List[str]:
    """Validate a scene description entry.

    Top-level song/beat/lyric_line/scene go through ``check_not_empty``. When
    ``scene`` is a dict, its mood/colors/composition/camera/description are
    checked the same way, its description must be at least 10 characters,
    and a list of colors may hold at most 5 items.
    """
    problems = check_not_empty(entry, ["song", "beat", "lyric_line", "scene"])

    scene = entry.get("scene")
    if not isinstance(scene, dict):
        return problems

    problems += check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"])
    problems += check_string_min_length(scene, {"description": 10})

    palette = scene.get("colors", [])
    if isinstance(palette, list) and len(palette) > 5:
        problems.append(f"too_many_colors: {len(palette)} > 5")
    return problems


def check_knowledge_entry(entry: dict) -> List[str]:
    """Validate a knowledge file entry.

    Requires non-empty ``title`` and ``content``, content of at least 50
    characters, and no placeholder markers in the content.

    Placeholder markers are matched case-insensitively and must start at a
    word boundary, so ordinary words that merely contain a marker
    (e.g. "mastodon" contains "todo") are not flagged, while "TODO:",
    "#todo" and "TODOs" still are.
    """
    errors = []
    errors.extend(check_not_empty(entry, ["title", "content"]))

    # Check for placeholder content: (reported token, regex) pairs.
    content = entry.get("content", "")
    if isinstance(content, str):
        placeholders = [
            ("TODO", r"\btodo"),
            ("FIXME", r"\bfixme"),
            ("PLACEHOLDER", r"\bplaceholder"),
            ("[INSERT", r"\[insert"),  # the bracket already anchors this marker
            ("lorem ipsum", r"\blorem ipsum"),
        ]
        for token, pattern in placeholders:
            if re.search(pattern, content, re.IGNORECASE):
                errors.append(f"placeholder_content: '{token}' found")

    errors.extend(check_string_min_length(entry, {"content": 50}))

    return errors


def check_prompt_enhancement(entry: dict) -> List[str]:
    """Validate a prompt enhancement pair (terse/rich).

    ``terse`` and ``rich`` must be non-empty; when both are strings, ``rich``
    must be strictly longer than ``terse``; ``rich`` must be at least 20
    characters.
    """
    problems = list(check_not_empty(entry, ["terse", "rich"]))

    terse, rich = entry.get("terse", ""), entry.get("rich", "")
    both_text = isinstance(terse, str) and isinstance(rich, str)
    if both_text and not len(rich) > len(terse):
        problems.append("rich_not_longer_than_terse")

    problems += check_string_min_length(entry, {"rich": 20})
    return problems


def check_adversary_entry(entry: dict) -> List[str]:
    """Validate an adversary corpus entry.

    ``id``, ``family`` and ``prompt`` must be non-empty, and a string prompt
    must be at least 10 characters long.
    """
    problems = check_not_empty(entry, ["id", "family", "prompt"])

    prompt = entry.get("prompt", "")
    too_short = isinstance(prompt, str) and len(prompt) < 10
    if too_short:
        problems.append(f"prompt_too_short: {len(prompt)} chars")
    return problems


def check_soul_compliance(text: str) -> List[str]:
    """Screen *text* against basic SOUL.md core-value patterns.

    The text is lower-cased and each rule's regex is applied with
    ``re.search`` (no DOTALL, so a pattern does not span newlines). One
    message is returned per matching rule, in rule order. Falsy input
    (``""`` or ``None``) yields no errors.
    """
    haystack = "" if not text else text.lower()

    # (regex, message) pairs describing content that violates core values.
    rules = (
        ("compute.*value.*human.*life", "soul_violation: computing human life value"),
        ("sexualiz.*child", "soul_violation: sexualizing children"),
        ("create.*weapon.*kill", "soul_violation: weapon creation"),
        ("enslav", "soul_violation: enslavement content"),
    )
    return [message for pattern, message in rules if re.search(pattern, haystack)]


# ============================================================
# Gate Runners
# ============================================================

# Entry type name -> validator. Plural and singular spellings of a type map
# to the same validator.
CHECK_MAP = dict(
    training_pairs=check_training_pair,
    training_pair=check_training_pair,
    scene_descriptions=check_scene_description,
    scene_description=check_scene_description,
    knowledge=check_knowledge_entry,
    prompt_enhancement=check_prompt_enhancement,
    adversary=check_adversary_entry,
    adversary_corpus=check_adversary_entry,
)


def run_gate(input_path: str, entry_type: str, dedup_store: Optional[HashDedupStore] = None) -> GateReport:
    """Run the quality gate on a JSONL file and record the result.

    Each entry goes through:
      - the type-specific validator;
      - within-file duplicate detection (by key fields);
      - cross-run hash dedup;
      - the SOUL compliance screen.

    An entry with any error is rejected. Every entry scores
    ``max(0, 1 - 0.2 * n_errors)``.

    Only entries that pass every check are recorded in the dedup store, so a
    rejected entry is not treated as already-seen in later runs. An entry
    repeated inside this file is reported as a within-file duplicate, not as
    a prior-run duplicate.

    Args:
        input_path: Path to the JSONL file (one JSON object per line).
        entry_type: Key of CHECK_MAP (training_pairs, scene_descriptions, etc.).
        dedup_store: Hash store for cross-run dedup. If None, creates one with
            default settings. It is flushed before returning.

    Returns:
        A GateReport, also appended to the stats file.
        - A missing file yields an empty report.
        - An unknown entry_type yields an empty report with ``rejected_indices == [-1]``.
        - ``rejected_indices`` is truncated to the first 50 indices.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        ValueError: a line is valid JSON but not an object, or the file is not UTF-8.
        OSError: the file cannot be read.
    """
    path = Path(input_path)
    if not path.exists():
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)

    check_fn = CHECK_MAP.get(entry_type)
    if not check_fn:
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
                          rejected_indices=[-1])  # unknown type

    if dedup_store is None:
        dedup_store = HashDedupStore()

    entries = _read_jsonl(path)

    # Within-file deduplication check
    dup_errors = check_no_duplicates(entries, _get_key_fields(entry_type))

    passed = 0
    rejected = 0
    rejected_indices: List[int] = []
    total_score = 0.0
    cross_run_dupes = 0
    hashes_this_run: Set[str] = set()

    for i, entry in enumerate(entries):
        errors = list(check_fn(entry))
        errors.extend(dup_errors.get(i, []))

        # Cross-run hash dedup: only the first occurrence of a hash in this file
        # is checked against the store; repeats are flagged by check_no_duplicates.
        h = entry_hash(entry)
        if h not in hashes_this_run and dedup_store.is_duplicate(h):
            errors.append(f"cross_run_duplicate: hash {h} seen in prior run")
            cross_run_dupes += 1
        hashes_this_run.add(h)

        errors.extend(check_soul_compliance(_soul_text(entry)))

        if errors:
            rejected += 1
            rejected_indices.append(i)
        else:
            passed += 1
            dedup_store.add(h)  # only accepted entries count as seen for future runs

        # Score: 1.0 if no errors, decreasing with each error
        total_score += max(0.0, 1.0 - (len(errors) * 0.2))

    avg_score = total_score / len(entries) if entries else 0.0

    dedup_store.flush()

    report = GateReport(
        file=str(path),
        type=entry_type,
        total=len(entries),
        passed=passed,
        rejected=rejected,
        score=round(avg_score, 3),
        rejected_indices=rejected_indices[:50],  # limit for readability
    )

    _save_stats(report)

    if cross_run_dupes > 0:
        print(f" cross-run dedup: {cross_run_dupes} duplicates found", file=sys.stderr)

    return report


def _read_jsonl(path: Path) -> List[dict]:
    """Parse a JSONL file into a list of objects, skipping blank lines.

    Raises:
        json.JSONDecodeError: a line is not valid JSON (message names the line).
        ValueError: a line parses to something other than a JSON object.
    """
    entries = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
            if not isinstance(entry, dict):
                raise ValueError(f"{path}:{lineno}: expected a JSON object, got {type(entry).__name__}")
            entries.append(entry)
    return entries


def _soul_text(entry: dict) -> str:
    """Concatenate the free-text fields of *entry* that get SOUL screening."""
    parts = []
    for name in ("response", "rich", "description", "content", "lyric_line"):
        value = entry.get(name)
        if isinstance(value, str):
            parts.append(value + " ")
    scene = entry.get("scene")
    if isinstance(scene, dict):
        description = scene.get("description", "")
        if isinstance(description, str):  # a non-str description used to raise TypeError
            parts.append(description)
    return "".join(parts)


def _get_key_fields(entry_type: str) -> List[str]:
|
|
"""Get key fields for deduplication based on entry type."""
|
|
key_map = {
|
|
"training_pairs": ["prompt", "response"],
|
|
"training_pair": ["prompt", "response"],
|
|
"scene_descriptions": ["song", "beat"],
|
|
"scene_description": ["song", "beat"],
|
|
"knowledge": ["title"],
|
|
"prompt_enhancement": ["terse", "rich"],
|
|
"adversary": ["id", "prompt"],
|
|
"adversary_corpus": ["id", "prompt"],
|
|
}
|
|
return key_map.get(entry_type, ["id"])
|
|
|
|
|
|
def _save_stats(report: GateReport):
|
|
"""Append quality stats to the stats file. Rotates to keep last 1000."""
|
|
STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
stats = []
|
|
if STATS_FILE.exists():
|
|
try:
|
|
with open(STATS_FILE) as f:
|
|
stats = json.load(f)
|
|
except (json.JSONDecodeError, IOError):
|
|
stats = []
|
|
|
|
stats.append(report.to_dict())
|
|
|
|
# Rotate: keep last 1000 entries
|
|
if len(stats) > 1000:
|
|
stats = stats[-1000:]
|
|
|
|
with open(STATS_FILE, "w") as f:
|
|
json.dump(stats, f, indent=2)
|
|
|
|
|
|
def show_status(stats_file: Optional[Path] = None):
    """Print per-type aggregates of recorded quality gate runs.

    A missing history prints "No quality stats found."; an unreadable or
    malformed history is reported on stderr instead of raising. Records that
    are not JSON objects are skipped.

    Args:
        stats_file: History file to read; defaults to ``STATS_FILE``.
    """
    source = STATS_FILE if stats_file is None else Path(stats_file)
    if not source.exists():
        print("No quality stats found.")
        return

    try:
        stats = json.loads(source.read_text(encoding="utf-8"))
    except (OSError, ValueError) as err:
        print(f"Could not read quality stats ({source}): {err}", file=sys.stderr)
        return
    if not isinstance(stats, list):
        print(f"Could not read quality stats ({source}): expected a JSON list", file=sys.stderr)
        return

    print(f"\nQuality Gate Stats — {len(stats)} runs")
    print()

    # Group by type
    by_type: Dict[str, List[dict]] = {}
    for s in stats:
        if not isinstance(s, dict):
            continue  # skip malformed records rather than abort the summary
        by_type.setdefault(str(s.get("type", "unknown")), []).append(s)

    for t, runs in sorted(by_type.items()):
        total_entries = sum(r.get("total", 0) for r in runs)
        total_rejected = sum(r.get("rejected", 0) for r in runs)
        avg_score = sum(r.get("score", 0) for r in runs) / len(runs)  # runs is never empty
        print(f" {t:25} {len(runs):4} runs | {total_entries:6} entries | {total_rejected:4} rejected | avg score: {avg_score:.3f}")


def main():
    """Command-line entry point.

    Exit codes (as documented in the module docstring):
      - 0: every gated entry passed;
      - 1: at least one entry was rejected;
      - 2: unknown entry type, missing input, or a file that could not be
        read or parsed.

    In ``--dir`` mode the worst code across all files is used.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
    parser.add_argument("--input", default=None, help="Input JSONL file")
    parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
    parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
    parser.add_argument("--status", action="store_true", help="Show quality stats")
    parser.add_argument("--dedup-stats", action="store_true", help="Show cross-run dedup hash store stats")
    parser.add_argument("--dedup-purge", action="store_true", help="Delete all stored dedup hashes (full reset)")
    args = parser.parse_args()

    if args.status:
        show_status()
        return

    if args.dedup_stats:
        for key, value in HashDedupStore().stats().items():
            print(f"  {key:15} {value}")
        return

    if args.dedup_purge:
        removed = 0
        if HASH_DIR.is_dir():
            for hash_file in HASH_DIR.glob("*.json"):
                hash_file.unlink()
                removed += 1
        print(f"Purged {removed} dedup hash file(s) from {HASH_DIR}")
        return

    if args.dir:
        directory = Path(args.dir)
        if not directory.is_dir():
            print(f"error: not a directory: {directory}", file=sys.stderr)
            sys.exit(2)
        store = HashDedupStore()  # one store for the whole batch instead of reloading per file
        worst = 0
        for f in sorted(directory.glob("*.jsonl")):
            t = args.type or _infer_type(f.name)
            worst = max(worst, _gate_file(str(f), t, store))
        sys.exit(worst)
    elif args.input:
        # Infer from the file name only, so directory names cannot skew the guess.
        t = args.type or _infer_type(Path(args.input).name)
        sys.exit(_gate_file(args.input, t))
    else:
        parser.print_help()


def _gate_file(input_path: str, entry_type: str, dedup_store: Optional[HashDedupStore] = None) -> int:
    """Gate one file, print its report, and return its exit code (0, 1 or 2)."""
    if entry_type not in CHECK_MAP:
        known = ", ".join(sorted(CHECK_MAP))
        print(f"error: unknown entry type '{entry_type}' (known: {known})", file=sys.stderr)
        return 2
    if not Path(input_path).is_file():
        print(f"error: input file not found: {input_path}", file=sys.stderr)
        return 2
    try:
        report = run_gate(input_path, entry_type, dedup_store)
    except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print(f"error: {input_path}: {err}", file=sys.stderr)
        return 2
    _print_report(report)
    return 0 if report.rejected == 0 else 1


def _infer_type(filename: str) -> str:
|
|
"""Infer entry type from filename."""
|
|
name = filename.lower()
|
|
if "scene" in name:
|
|
return "scene_descriptions"
|
|
if "training" in name or "pair" in name:
|
|
return "training_pairs"
|
|
if "knowledge" in name:
|
|
return "knowledge"
|
|
if "adversary" in name or "attack" in name:
|
|
return "adversary"
|
|
if "prompt" in name or "enhance" in name:
|
|
return "prompt_enhancement"
|
|
return "training_pairs" # default
|
|
|
|
|
|
def _print_report(report: GateReport):
    """Print a one-line, human-readable summary of *report* to stdout."""
    if report.rejected == 0:
        verdict = "PASS"
    else:
        verdict = f"FAIL ({report.rejected} rejected)"
    summary = f" {report.file}: {verdict} | {report.passed}/{report.total} passed | score: {report.score:.3f}"
    print(summary)


# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()