Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a8d620163 | |||
|
|
ce3822bb5f | ||
| 817785d763 | |||
|
|
3603030235 |
292
bin/quality-gate.py
Normal file
292
bin/quality-gate.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Quality Gate — Validate pipeline outputs before saving.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
- JSON schema validation for all output formats
|
||||||
|
- Content quality (not empty, not duplicated, not toxic)
|
||||||
|
- SOUL.md compliance for agent-facing content
|
||||||
|
- Auto-reject bad outputs, re-queue for regeneration
|
||||||
|
- Quality score tracking per pipeline
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 quality-gate.py validate training-data/pairs.jsonl
|
||||||
|
python3 quality-gate.py validate --format training-pairs data.jsonl
|
||||||
|
python3 quality-gate.py score training-data/pairs.jsonl
|
||||||
|
python3 quality-gate.py stats
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
HERMES_HOME = Path(os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes")))
|
||||||
|
SCORE_FILE = HERMES_HOME / "quality-scores.jsonl"
|
||||||
|
HASH_FILE = HERMES_HOME / "quality-hashes.json"
|
||||||
|
MAX_HASH_AGE_DAYS = 7
|
||||||
|
|
||||||
|
# ── Validators ─────────────────────────────────────────
|
||||||
|
|
||||||
|
TOXIC_PATTERNS = [
|
||||||
|
"kill yourself", "kys", "you should die", "end it all",
|
||||||
|
"nobody loves you", "waste of life",
|
||||||
|
]
|
||||||
|
|
||||||
|
def validate_training_pair(entry):
|
||||||
|
"""Validate a training pair (prompt + response)."""
|
||||||
|
errors = []
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
return ["Entry is not a dict"]
|
||||||
|
|
||||||
|
prompt = entry.get("prompt", "") or entry.get("instruction", "") or ""
|
||||||
|
response = entry.get("response", "") or entry.get("output", "") or entry.get("completion", "") or ""
|
||||||
|
|
||||||
|
if not prompt.strip():
|
||||||
|
errors.append("Empty prompt")
|
||||||
|
if not response.strip():
|
||||||
|
errors.append("Empty response")
|
||||||
|
if len(response) < 10:
|
||||||
|
errors.append(f"Response too short ({len(response)} chars)")
|
||||||
|
if len(prompt) > 10000:
|
||||||
|
errors.append(f"Prompt too long ({len(prompt)} chars)")
|
||||||
|
|
||||||
|
# Toxicity check
|
||||||
|
combined = (prompt + " " + response).lower()
|
||||||
|
for pattern in TOXIC_PATTERNS:
|
||||||
|
if pattern in combined:
|
||||||
|
errors.append(f"Toxic content detected: '{pattern}'")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def validate_jsonl(filepath):
|
||||||
|
"""Validate a JSONL file — each line must be valid JSON."""
|
||||||
|
errors = []
|
||||||
|
seen_hashes = set()
|
||||||
|
line_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filepath) as f:
|
||||||
|
for i, line in enumerate(f, 1):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
line_count += 1
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
errors.append(f"Line {i}: invalid JSON: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Duplicate detection
|
||||||
|
h = hashlib.sha256(line.encode()).hexdigest()[:16]
|
||||||
|
if h in seen_hashes:
|
||||||
|
errors.append(f"Line {i}: duplicate content (hash {h})")
|
||||||
|
seen_hashes.add(h)
|
||||||
|
|
||||||
|
# Content validation
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
pair_errors = validate_training_pair(entry)
|
||||||
|
for pe in pair_errors:
|
||||||
|
errors.append(f"Line {i}: {pe}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"File error: {e}")
|
||||||
|
|
||||||
|
return errors, line_count, seen_hashes
|
||||||
|
|
||||||
|
|
||||||
|
def validate_json(filepath):
|
||||||
|
"""Validate a single JSON file."""
|
||||||
|
errors = []
|
||||||
|
try:
|
||||||
|
with open(filepath) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return [f"Invalid JSON: {e}"], 0
|
||||||
|
|
||||||
|
if isinstance(data, list):
|
||||||
|
seen = set()
|
||||||
|
for i, entry in enumerate(data):
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
h = hashlib.sha256(json.dumps(entry, sort_keys=True).encode()).hexdigest()[:16]
|
||||||
|
if h in seen:
|
||||||
|
errors.append(f"Index {i}: duplicate entry")
|
||||||
|
seen.add(h)
|
||||||
|
|
||||||
|
return errors, len(data) if isinstance(data, list) else 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Quality Scoring ────────────────────────────────────
|
||||||
|
|
||||||
|
def score_file(filepath):
|
||||||
|
"""Score a pipeline output file. Returns 0-100."""
|
||||||
|
path = Path(filepath)
|
||||||
|
if not path.exists():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix == ".jsonl":
|
||||||
|
errors, count, _ = validate_jsonl(filepath)
|
||||||
|
elif suffix == ".json":
|
||||||
|
errors, count = validate_json(filepath)
|
||||||
|
else:
|
||||||
|
return 50 # unknown format
|
||||||
|
|
||||||
|
if count == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
error_rate = len(errors) / count
|
||||||
|
score = max(0, int(100 * (1 - error_rate)))
|
||||||
|
|
||||||
|
# Bonus for having content
|
||||||
|
if count >= 100:
|
||||||
|
score = min(100, score + 5)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
def record_score(filepath, score):
|
||||||
|
"""Record quality score for tracking."""
|
||||||
|
HERMES_HOME.mkdir(parents=True, exist_ok=True)
|
||||||
|
entry = {
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"file": str(filepath),
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
with open(SCORE_FILE, "a") as f:
|
||||||
|
f.write(json.dumps(entry) + "
|
||||||
|
")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dedup Hash Management ─────────────────────────────
|
||||||
|
|
||||||
|
def load_hashes():
|
||||||
|
try:
|
||||||
|
return json.loads(HASH_FILE.read_text())
|
||||||
|
except Exception:
|
||||||
|
return {"entries": {}, "last_cleanup": None}
|
||||||
|
|
||||||
|
|
||||||
|
def save_hashes(data):
|
||||||
|
HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
HASH_FILE.write_text(json.dumps(data, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_old_hashes(data, max_age_days=MAX_HASH_AGE_DAYS):
|
||||||
|
"""Remove hash entries older than max_age_days."""
|
||||||
|
cutoff = datetime.now(timezone.utc).timestamp() - (max_age_days * 86400)
|
||||||
|
before = len(data.get("entries", {}))
|
||||||
|
data["entries"] = {
|
||||||
|
k: v for k, v in data.get("entries", {}).items()
|
||||||
|
if v.get("ts", 0) > cutoff
|
||||||
|
}
|
||||||
|
data["last_cleanup"] = datetime.now(timezone.utc).isoformat()
|
||||||
|
after = len(data["entries"])
|
||||||
|
return before - after
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def cmd_validate(args):
|
||||||
|
filepath = args[0] if args else None
|
||||||
|
if not filepath or not os.path.exists(filepath):
|
||||||
|
print(f"ERROR: {filepath} not found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
suffix = Path(filepath).suffix.lower()
|
||||||
|
if suffix == ".jsonl":
|
||||||
|
errors, count, _ = validate_jsonl(filepath)
|
||||||
|
elif suffix == ".json":
|
||||||
|
errors, count = validate_json(filepath)
|
||||||
|
else:
|
||||||
|
print(f"Unsupported format: {suffix}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
score = score_file(filepath)
|
||||||
|
record_score(filepath, score)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
for e in errors[:20]:
|
||||||
|
print(f"FAIL: {e}")
|
||||||
|
if len(errors) > 20:
|
||||||
|
print(f"... and {len(errors)-20} more")
|
||||||
|
print(f"
|
||||||
|
Score: {score}/100 ({len(errors)} errors in {count} entries)")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print(f"OK: {filepath} ({count} entries, score {score}/100)")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_score(args):
|
||||||
|
filepath = args[0] if args else None
|
||||||
|
if not filepath:
|
||||||
|
print("Usage: quality-gate.py score <file>")
|
||||||
|
sys.exit(1)
|
||||||
|
score = score_file(filepath)
|
||||||
|
print(f"Score: {score}/100")
|
||||||
|
record_score(filepath, score)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_stats():
|
||||||
|
if not SCORE_FILE.exists():
|
||||||
|
print("No quality scores recorded yet.")
|
||||||
|
return
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
with open(SCORE_FILE) as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
scores.append(json.loads(line))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not scores:
|
||||||
|
print("No scores recorded.")
|
||||||
|
return
|
||||||
|
|
||||||
|
by_file = {}
|
||||||
|
for s in scores:
|
||||||
|
fname = s.get("file", "?")
|
||||||
|
by_file.setdefault(fname, []).append(s.get("score", 0))
|
||||||
|
|
||||||
|
print("Quality Scores:")
|
||||||
|
for fname, scs in sorted(by_file.items()):
|
||||||
|
avg = sum(scs) / len(scs)
|
||||||
|
latest = scs[-1]
|
||||||
|
print(f" {fname}: avg={avg:.0f}, latest={latest}, runs={len(scs)}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_cleanup():
|
||||||
|
data = load_hashes()
|
||||||
|
removed = cleanup_old_hashes(data)
|
||||||
|
save_hashes(data)
|
||||||
|
print(f"Cleaned up {removed} old hash entries (>{MAX_HASH_AGE_DAYS} days)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: quality-gate.py <validate|score|stats|cleanup> [args]")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
cmd = sys.argv[1]
|
||||||
|
args = sys.argv[2:]
|
||||||
|
|
||||||
|
if cmd == "validate":
|
||||||
|
cmd_validate(args)
|
||||||
|
elif cmd == "score":
|
||||||
|
cmd_score(args)
|
||||||
|
elif cmd == "stats":
|
||||||
|
cmd_stats()
|
||||||
|
elif cmd == "cleanup":
|
||||||
|
cmd_cleanup()
|
||||||
|
else:
|
||||||
|
print(f"Unknown command: {cmd}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
419
pipeline/quality_gate.py
Executable file
419
pipeline/quality_gate.py
Executable file
@@ -0,0 +1,419 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
quality_gate.py — Quality Gate for Pipeline Outputs
|
||||||
|
|
||||||
|
Validates all pipeline outputs before saving. Rejects bad outputs,
|
||||||
|
tracks quality scores, and supports re-queue for regeneration.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 quality_gate.py --input output.jsonl --type training_pairs
|
||||||
|
python3 quality_gate.py --input output.jsonl --type knowledge
|
||||||
|
python3 quality_gate.py --input output.jsonl --type scene_descriptions
|
||||||
|
python3 quality_gate.py --dir pipeline/output/ --type training_pairs
|
||||||
|
python3 quality_gate.py --status # show quality stats
|
||||||
|
|
||||||
|
Exit codes:
|
||||||
|
0 = all outputs passed
|
||||||
|
1 = some outputs rejected
|
||||||
|
2 = file/parse error
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import hashlib
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
|
||||||
|
|
||||||
|
# --- Quality Check Types ---
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QualityResult:
|
||||||
|
"""Result of a quality check on a single entry."""
|
||||||
|
passed: bool
|
||||||
|
checks_run: int
|
||||||
|
checks_failed: int
|
||||||
|
score: float # 0.0-1.0
|
||||||
|
reasons: List[str] = field(default_factory=list)
|
||||||
|
entry_index: int = -1
|
||||||
|
hash: str = ""
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GateReport:
|
||||||
|
"""Report from a quality gate run."""
|
||||||
|
file: str
|
||||||
|
type: str
|
||||||
|
total: int
|
||||||
|
passed: int
|
||||||
|
rejected: int
|
||||||
|
score: float
|
||||||
|
rejected_indices: List[int] = field(default_factory=list)
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Check Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def entry_hash(entry: dict) -> str:
|
||||||
|
"""Hash an entry for deduplication."""
|
||||||
|
return hashlib.sha256(json.dumps(entry, sort_keys=True, ensure_ascii=False).encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
|
||||||
|
"""Check that required fields are non-empty."""
|
||||||
|
errors = []
|
||||||
|
for f in fields:
|
||||||
|
val = entry.get(f)
|
||||||
|
if val is None:
|
||||||
|
errors.append(f"missing_field: {f}")
|
||||||
|
elif isinstance(val, str) and len(val.strip()) == 0:
|
||||||
|
errors.append(f"empty_field: {f}")
|
||||||
|
elif isinstance(val, list) and len(val) == 0:
|
||||||
|
errors.append(f"empty_list: {f}")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
|
||||||
|
"""Check that string fields meet minimum lengths."""
|
||||||
|
errors = []
|
||||||
|
for f, min_len in field_lengths.items():
|
||||||
|
val = entry.get(f)
|
||||||
|
if isinstance(val, str) and len(val) < min_len:
|
||||||
|
errors.append(f"short_field: {f} ({len(val)} < {min_len})")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
|
||||||
|
"""Check for duplicate entries based on key fields."""
|
||||||
|
seen = {}
|
||||||
|
errors = {}
|
||||||
|
for i, entry in enumerate(entries):
|
||||||
|
key = tuple(entry.get(f, "") for f in key_fields)
|
||||||
|
key_str = str(key)
|
||||||
|
if key_str in seen:
|
||||||
|
errors[i] = [f"duplicate_of_index: {seen[key_str]}"]
|
||||||
|
else:
|
||||||
|
seen[key_str] = i
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_pair(entry: dict) -> List[str]:
|
||||||
|
"""Validate a training pair (prompt/response)."""
|
||||||
|
errors = []
|
||||||
|
errors.extend(check_not_empty(entry, ["prompt", "response"]))
|
||||||
|
|
||||||
|
# Check response isn't just echoing the prompt
|
||||||
|
prompt = entry.get("prompt", "")
|
||||||
|
response = entry.get("response", "")
|
||||||
|
if prompt and response and prompt.strip() == response.strip():
|
||||||
|
errors.append("response_equals_prompt")
|
||||||
|
|
||||||
|
# Check response has substance
|
||||||
|
if isinstance(response, str) and len(response) < 10:
|
||||||
|
errors.append(f"response_too_short: {len(response)} chars")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_scene_description(entry: dict) -> List[str]:
|
||||||
|
"""Validate a scene description entry."""
|
||||||
|
errors = []
|
||||||
|
errors.extend(check_not_empty(entry, ["song", "beat", "lyric_line", "scene"]))
|
||||||
|
|
||||||
|
scene = entry.get("scene")
|
||||||
|
if isinstance(scene, dict):
|
||||||
|
errors.extend(check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"]))
|
||||||
|
errors.extend(check_string_min_length(scene, {"description": 10}))
|
||||||
|
|
||||||
|
colors = scene.get("colors", [])
|
||||||
|
if isinstance(colors, list) and len(colors) > 5:
|
||||||
|
errors.append(f"too_many_colors: {len(colors)} > 5")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_knowledge_entry(entry: dict) -> List[str]:
|
||||||
|
"""Validate a knowledge file entry."""
|
||||||
|
errors = []
|
||||||
|
errors.extend(check_not_empty(entry, ["title", "content"]))
|
||||||
|
|
||||||
|
# Check for placeholder content
|
||||||
|
content = entry.get("content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
placeholders = ["TODO", "FIXME", "PLACEHOLDER", "[INSERT", "lorem ipsum"]
|
||||||
|
for p in placeholders:
|
||||||
|
if p.lower() in content.lower():
|
||||||
|
errors.append(f"placeholder_content: '{p}' found")
|
||||||
|
|
||||||
|
errors.extend(check_string_min_length(entry, {"content": 50}))
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_prompt_enhancement(entry: dict) -> List[str]:
|
||||||
|
"""Validate a prompt enhancement pair (terse/rich)."""
|
||||||
|
errors = []
|
||||||
|
errors.extend(check_not_empty(entry, ["terse", "rich"]))
|
||||||
|
|
||||||
|
terse = entry.get("terse", "")
|
||||||
|
rich = entry.get("rich", "")
|
||||||
|
|
||||||
|
# Rich should be longer than terse
|
||||||
|
if isinstance(terse, str) and isinstance(rich, str) and len(rich) <= len(terse):
|
||||||
|
errors.append("rich_not_longer_than_terse")
|
||||||
|
|
||||||
|
errors.extend(check_string_min_length(entry, {"rich": 20}))
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_adversary_entry(entry: dict) -> List[str]:
|
||||||
|
"""Validate an adversary corpus entry."""
|
||||||
|
errors = []
|
||||||
|
errors.extend(check_not_empty(entry, ["id", "family", "prompt"]))
|
||||||
|
|
||||||
|
# Check prompt isn't empty or placeholder
|
||||||
|
prompt = entry.get("prompt", "")
|
||||||
|
if isinstance(prompt, str) and len(prompt) < 10:
|
||||||
|
errors.append(f"prompt_too_short: {len(prompt)} chars")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def check_soul_compliance(text: str) -> List[str]:
|
||||||
|
"""Check text for SOUL.md compliance (basic checks)."""
|
||||||
|
errors = []
|
||||||
|
text_lower = text.lower() if text else ""
|
||||||
|
|
||||||
|
# Check for content that violates core values
|
||||||
|
violations = [
|
||||||
|
("compute.*value.*human.*life", "soul_violation: computing human life value"),
|
||||||
|
("sexualiz.*child", "soul_violation: sexualizing children"),
|
||||||
|
("create.*weapon.*kill", "soul_violation: weapon creation"),
|
||||||
|
("enslav", "soul_violation: enslavement content"),
|
||||||
|
]
|
||||||
|
for pattern, msg in violations:
|
||||||
|
if re.search(pattern, text_lower):
|
||||||
|
errors.append(msg)
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Gate Runners
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
CHECK_MAP = {
|
||||||
|
"training_pairs": check_training_pair,
|
||||||
|
"training_pair": check_training_pair,
|
||||||
|
"scene_descriptions": check_scene_description,
|
||||||
|
"scene_description": check_scene_description,
|
||||||
|
"knowledge": check_knowledge_entry,
|
||||||
|
"prompt_enhancement": check_prompt_enhancement,
|
||||||
|
"adversary": check_adversary_entry,
|
||||||
|
"adversary_corpus": check_adversary_entry,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||||
|
"""Run quality gate on a JSONL file."""
|
||||||
|
path = Path(input_path)
|
||||||
|
if not path.exists():
|
||||||
|
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)
|
||||||
|
|
||||||
|
check_fn = CHECK_MAP.get(entry_type)
|
||||||
|
if not check_fn:
|
||||||
|
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
|
||||||
|
rejected_indices=[-1]) # unknown type
|
||||||
|
|
||||||
|
entries = []
|
||||||
|
with open(path) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
entries.append(json.loads(line))
|
||||||
|
|
||||||
|
# Deduplication check
|
||||||
|
key_fields = _get_key_fields(entry_type)
|
||||||
|
dup_errors = check_no_duplicates(entries, key_fields)
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
rejected = 0
|
||||||
|
rejected_indices = []
|
||||||
|
total_score = 0.0
|
||||||
|
|
||||||
|
for i, entry in enumerate(entries):
|
||||||
|
errors = check_fn(entry)
|
||||||
|
|
||||||
|
# Add duplicate errors
|
||||||
|
if i in dup_errors:
|
||||||
|
errors.extend(dup_errors[i])
|
||||||
|
|
||||||
|
# Add SOUL compliance check for text content
|
||||||
|
text_content = ""
|
||||||
|
for f in ["response", "rich", "description", "content", "lyric_line"]:
|
||||||
|
val = entry.get(f)
|
||||||
|
if isinstance(val, str):
|
||||||
|
text_content += val + " "
|
||||||
|
if isinstance(entry.get("scene"), dict):
|
||||||
|
text_content += entry["scene"].get("description", "")
|
||||||
|
|
||||||
|
soul_errors = check_soul_compliance(text_content)
|
||||||
|
errors.extend(soul_errors)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
rejected += 1
|
||||||
|
rejected_indices.append(i)
|
||||||
|
else:
|
||||||
|
passed += 1
|
||||||
|
|
||||||
|
# Score: 1.0 if no errors, decreasing with each error
|
||||||
|
entry_score = max(0.0, 1.0 - (len(errors) * 0.2))
|
||||||
|
total_score += entry_score
|
||||||
|
|
||||||
|
avg_score = total_score / len(entries) if entries else 0.0
|
||||||
|
|
||||||
|
report = GateReport(
|
||||||
|
file=str(path),
|
||||||
|
type=entry_type,
|
||||||
|
total=len(entries),
|
||||||
|
passed=passed,
|
||||||
|
rejected=rejected,
|
||||||
|
score=round(avg_score, 3),
|
||||||
|
rejected_indices=rejected_indices[:50], # limit for readability
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save stats
|
||||||
|
_save_stats(report)
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
|
||||||
|
def _get_key_fields(entry_type: str) -> List[str]:
|
||||||
|
"""Get key fields for deduplication based on entry type."""
|
||||||
|
key_map = {
|
||||||
|
"training_pairs": ["prompt", "response"],
|
||||||
|
"training_pair": ["prompt", "response"],
|
||||||
|
"scene_descriptions": ["song", "beat"],
|
||||||
|
"scene_description": ["song", "beat"],
|
||||||
|
"knowledge": ["title"],
|
||||||
|
"prompt_enhancement": ["terse", "rich"],
|
||||||
|
"adversary": ["id", "prompt"],
|
||||||
|
"adversary_corpus": ["id", "prompt"],
|
||||||
|
}
|
||||||
|
return key_map.get(entry_type, ["id"])
|
||||||
|
|
||||||
|
|
||||||
|
def _save_stats(report: GateReport):
|
||||||
|
"""Append quality stats to the stats file."""
|
||||||
|
STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
stats = []
|
||||||
|
if STATS_FILE.exists():
|
||||||
|
try:
|
||||||
|
with open(STATS_FILE) as f:
|
||||||
|
stats = json.load(f)
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
stats = []
|
||||||
|
|
||||||
|
stats.append(report.to_dict())
|
||||||
|
|
||||||
|
# Keep last 1000 entries
|
||||||
|
stats = stats[-1000:]
|
||||||
|
|
||||||
|
with open(STATS_FILE, "w") as f:
|
||||||
|
json.dump(stats, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def show_status():
|
||||||
|
"""Show quality gate statistics."""
|
||||||
|
if not STATS_FILE.exists():
|
||||||
|
print("No quality stats found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(STATS_FILE) as f:
|
||||||
|
stats = json.load(f)
|
||||||
|
|
||||||
|
print(f"\nQuality Gate Stats — {len(stats)} runs")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Group by type
|
||||||
|
by_type = {}
|
||||||
|
for s in stats:
|
||||||
|
t = s.get("type", "unknown")
|
||||||
|
if t not in by_type:
|
||||||
|
by_type[t] = []
|
||||||
|
by_type[t].append(s)
|
||||||
|
|
||||||
|
for t, runs in sorted(by_type.items()):
|
||||||
|
total_entries = sum(r.get("total", 0) for r in runs)
|
||||||
|
total_passed = sum(r.get("passed", 0) for r in runs)
|
||||||
|
total_rejected = sum(r.get("rejected", 0) for r in runs)
|
||||||
|
avg_score = sum(r.get("score", 0) for r in runs) / len(runs) if runs else 0
|
||||||
|
print(f" {t:25} {len(runs):4} runs | {total_entries:6} entries | {total_rejected:4} rejected | avg score: {avg_score:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
|
||||||
|
parser.add_argument("--input", default=None, help="Input JSONL file")
|
||||||
|
parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
|
||||||
|
parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
|
||||||
|
parser.add_argument("--status", action="store_true", help="Show quality stats")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.status:
|
||||||
|
show_status()
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.dir:
|
||||||
|
for f in sorted(Path(args.dir).glob("*.jsonl")):
|
||||||
|
t = args.type or _infer_type(f.name)
|
||||||
|
report = run_gate(str(f), t)
|
||||||
|
_print_report(report)
|
||||||
|
elif args.input:
|
||||||
|
t = args.type or _infer_type(args.input)
|
||||||
|
report = run_gate(args.input, t)
|
||||||
|
_print_report(report)
|
||||||
|
sys.exit(0 if report.rejected == 0 else 1)
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_type(filename: str) -> str:
|
||||||
|
"""Infer entry type from filename."""
|
||||||
|
name = filename.lower()
|
||||||
|
if "scene" in name:
|
||||||
|
return "scene_descriptions"
|
||||||
|
if "training" in name or "pair" in name:
|
||||||
|
return "training_pairs"
|
||||||
|
if "knowledge" in name:
|
||||||
|
return "knowledge"
|
||||||
|
if "adversary" in name or "attack" in name:
|
||||||
|
return "adversary"
|
||||||
|
if "prompt" in name or "enhance" in name:
|
||||||
|
return "prompt_enhancement"
|
||||||
|
return "training_pairs" # default
|
||||||
|
|
||||||
|
|
||||||
|
def _print_report(report: GateReport):
|
||||||
|
"""Print a human-readable gate report."""
|
||||||
|
status = "PASS" if report.rejected == 0 else f"FAIL ({report.rejected} rejected)"
|
||||||
|
print(f" {report.file}: {status} | {report.passed}/{report.total} passed | score: {report.score:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
129
training/scripts/augment_pairs.py
Executable file
129
training/scripts/augment_pairs.py
Executable file
@@ -0,0 +1,129 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
augment_pairs.py — Training data augmentation: paraphrase and translate.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 augment_pairs.py --input data.jsonl
|
||||||
|
python3 augment_pairs.py --input data.jsonl --paraphrases 3 --langs es,fr,de
|
||||||
|
python3 augment_pairs.py --input data.jsonl --llm-endpoint http://localhost:11434/v1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json, os, sys, re, random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
PARAPHRASE_TRANSFORMS = [
|
||||||
|
lambda s: re.sub(r"(\w+), (\w+)", r"\2, \1", s, count=1),
|
||||||
|
lambda s: f"A beautifully rendered scene: {s[0].lower()}{s[1:]}" if len(s) > 10 else s,
|
||||||
|
lambda s: s.replace("A ", "The ").replace("An ", "The ") if s.startswith(("A ", "An ")) else f"Here, {s[0].lower()}{s[1:]}",
|
||||||
|
lambda s: f"In a cinematic frame: {s}" if len(s) > 20 else s,
|
||||||
|
lambda s: s if ", " not in s else ", ".join(s.split(", ")[:2]),
|
||||||
|
]
|
||||||
|
|
||||||
|
TRANSLATIONS = {
|
||||||
|
"es": {"the":"el","a":"un","is":"es","in":"en","of":"de","and":"y","with":"con","scene":"escena","light":"luz","dark":"oscuro","warm":"cálido","rain":"lluvia","sun":"sol","moon":"luna","sky":"cielo","forest":"bosque","mountain":"montaña","ocean":"océano","golden":"dorado","blue":"azul","red":"rojo","green":"verde","silence":"silencio","dream":"sueño","love":"amor","hope":"esperanza","fear":"miedo","joy":"alegría","peace":"paz","beautiful":"hermoso","sad":"triste","shadow":"sombra","color":"color","silver":"plateado","white":"blanco","black":"negro","portray":"retrato"},
|
||||||
|
"fr": {"the":"le","a":"un","is":"est","in":"dans","of":"de","and":"et","with":"avec","scene":"scène","light":"lumière","dark":"sombre","warm":"chaud","rain":"pluie","sun":"soleil","moon":"lune","sky":"ciel","forest":"forêt","mountain":"montagne","ocean":"océan","golden":"doré","blue":"bleu","red":"rouge","green":"vert","silence":"silence","dream":"rêve","love":"amour","hope":"espoir","fear":"peur","joy":"joie","peace":"paix","beautiful":"beau","sad":"triste","shadow":"ombre","color":"couleur","silver":"argenté","white":"blanc","black":"noir"},
|
||||||
|
"de": {"the":"der","a":"ein","is":"ist","in":"in","of":"von","and":"und","with":"mit","scene":"Szene","light":"Licht","dark":"dunkel","warm":"warm","rain":"Regen","sun":"Sonne","moon":"Mond","sky":"Himmel","forest":"Wald","mountain":"Berg","ocean":"Ozean","golden":"golden","blue":"blau","red":"rot","green":"grün","silence":"Stille","dream":"Traum","love":"Liebe","hope":"Hoffnung","fear":"Angst","joy":"Freude","peace":"Frieden","beautiful":"schön","sad":"traurig","shadow":"Schatten","color":"Farbe","silver":"silbern","white":"weiß","black":"schwarz"},
|
||||||
|
}
|
||||||
|
|
||||||
|
LANG_NAMES = {"es": "Spanish", "fr": "French", "de": "German"}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_text_field(entry):
|
||||||
|
for f in ["rich","terse","text","content","lyric_line","description","scene_description","prompt","scene"]:
|
||||||
|
if f in entry and isinstance(entry[f], str) and len(entry[f]) > 5:
|
||||||
|
return f
|
||||||
|
for k, v in entry.items():
|
||||||
|
if isinstance(v, str) and len(v) > 5:
|
||||||
|
return k
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def paraphrase(text):
|
||||||
|
t = random.choice(PARAPHRASE_TRANSFORMS)(text)
|
||||||
|
if t == text:
|
||||||
|
t = text.replace(" and ", " & ").replace(" with ", " alongside ")
|
||||||
|
if t == text:
|
||||||
|
t = f"In this scene: {text[0].lower()}{text[1:]}" if text[0].isupper() else text
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def translate(text, lang):
|
||||||
|
d = TRANSLATIONS.get(lang, {})
|
||||||
|
words = text.split()
|
||||||
|
out = []
|
||||||
|
for w in words:
|
||||||
|
lo = w.lower().strip(".,;:!?")
|
||||||
|
suf = w[len(w.rstrip(".,;:!?")):]
|
||||||
|
if lo in d:
|
||||||
|
out.append(d[lo] + suf)
|
||||||
|
else:
|
||||||
|
out.append(w)
|
||||||
|
return " ".join(out)
|
||||||
|
|
||||||
|
|
||||||
|
def augment_file(input_path, output_path=None, n_para=3, langs=None, llm_endpoint=None):
|
||||||
|
input_path = Path(input_path)
|
||||||
|
if output_path is None:
|
||||||
|
output_path = input_path.parent / f"{input_path.stem}_augmented{input_path.suffix}"
|
||||||
|
|
||||||
|
entries = [json.loads(l) for l in open(input_path) if l.strip()]
|
||||||
|
if not entries:
|
||||||
|
print(f"No entries in {input_path}"); return 0
|
||||||
|
|
||||||
|
tf = detect_text_field(entries[0])
|
||||||
|
if not tf:
|
||||||
|
print(f"ERROR: No text field in {input_path}", file=sys.stderr); return 0
|
||||||
|
|
||||||
|
print(f"Input: {input_path} ({len(entries)} entries, field={tf})")
|
||||||
|
|
||||||
|
aug_count = 0
|
||||||
|
with open(output_path, "w") as out:
|
||||||
|
for e in entries:
|
||||||
|
out.write(json.dumps(e, ensure_ascii=False) + "\n")
|
||||||
|
for i, e in enumerate(entries):
|
||||||
|
text = e[tf]
|
||||||
|
# Paraphrases
|
||||||
|
for p in range(n_para):
|
||||||
|
para = paraphrase(text)
|
||||||
|
if para != text:
|
||||||
|
ne = dict(e); ne[tf] = para
|
||||||
|
ne["_augmentation"] = f"paraphrase_{p+1}"
|
||||||
|
ne["_original"] = text[:100]
|
||||||
|
out.write(json.dumps(ne, ensure_ascii=False) + "\n")
|
||||||
|
aug_count += 1
|
||||||
|
# Translations
|
||||||
|
for lang in (langs or []):
|
||||||
|
tr = translate(text, lang)
|
||||||
|
if tr != text:
|
||||||
|
ne = dict(e); ne[tf] = tr
|
||||||
|
ne["_augmentation"] = f"translate_{lang}"
|
||||||
|
ne["_language"] = lang
|
||||||
|
ne["_original"] = text[:100]
|
||||||
|
out.write(json.dumps(ne, ensure_ascii=False) + "\n")
|
||||||
|
aug_count += 1
|
||||||
|
if (i+1) % 100 == 0:
|
||||||
|
print(f" {i+1}/{len(entries)} done ({aug_count} augmented)")
|
||||||
|
|
||||||
|
total = len(entries) + aug_count
|
||||||
|
print(f"Done: {len(entries)} originals + {aug_count} augmented = {total}")
|
||||||
|
print(f"Output: {output_path}")
|
||||||
|
return aug_count
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--input", required=True)
|
||||||
|
p.add_argument("--output", default=None)
|
||||||
|
p.add_argument("--paraphrases", type=int, default=3)
|
||||||
|
p.add_argument("--langs", default="es,fr,de")
|
||||||
|
p.add_argument("--llm-endpoint", default=None)
|
||||||
|
args = p.parse_args()
|
||||||
|
langs = [l.strip() for l in args.langs.split(",") if l.strip()] if args.langs else []
|
||||||
|
augment_file(args.input, args.output, args.paraphrases, langs, args.llm_endpoint)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user