Compare commits
2 Commits
burn/691-1
...
fix/623
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a8d620163 | |||
|
|
ce3822bb5f |
292
bin/quality-gate.py
Normal file
292
bin/quality-gate.py
Normal file
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quality Gate — Validate pipeline outputs before saving.
|
||||
|
||||
Checks:
|
||||
- JSON schema validation for all output formats
|
||||
- Content quality (not empty, not duplicated, not toxic)
|
||||
- SOUL.md compliance for agent-facing content
|
||||
- Auto-reject bad outputs, re-queue for regeneration
|
||||
- Quality score tracking per pipeline
|
||||
|
||||
Usage:
|
||||
python3 quality-gate.py validate training-data/pairs.jsonl
|
||||
python3 quality-gate.py validate --format training-pairs data.jsonl
|
||||
python3 quality-gate.py score training-data/pairs.jsonl
|
||||
python3 quality-gate.py stats
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
HERMES_HOME = Path(os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes")))
|
||||
SCORE_FILE = HERMES_HOME / "quality-scores.jsonl"
|
||||
HASH_FILE = HERMES_HOME / "quality-hashes.json"
|
||||
MAX_HASH_AGE_DAYS = 7
|
||||
|
||||
# ── Validators ─────────────────────────────────────────
|
||||
|
||||
TOXIC_PATTERNS = [
|
||||
"kill yourself", "kys", "you should die", "end it all",
|
||||
"nobody loves you", "waste of life",
|
||||
]
|
||||
|
||||
def validate_training_pair(entry):
|
||||
"""Validate a training pair (prompt + response)."""
|
||||
errors = []
|
||||
if not isinstance(entry, dict):
|
||||
return ["Entry is not a dict"]
|
||||
|
||||
prompt = entry.get("prompt", "") or entry.get("instruction", "") or ""
|
||||
response = entry.get("response", "") or entry.get("output", "") or entry.get("completion", "") or ""
|
||||
|
||||
if not prompt.strip():
|
||||
errors.append("Empty prompt")
|
||||
if not response.strip():
|
||||
errors.append("Empty response")
|
||||
if len(response) < 10:
|
||||
errors.append(f"Response too short ({len(response)} chars)")
|
||||
if len(prompt) > 10000:
|
||||
errors.append(f"Prompt too long ({len(prompt)} chars)")
|
||||
|
||||
# Toxicity check
|
||||
combined = (prompt + " " + response).lower()
|
||||
for pattern in TOXIC_PATTERNS:
|
||||
if pattern in combined:
|
||||
errors.append(f"Toxic content detected: '{pattern}'")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_jsonl(filepath):
|
||||
"""Validate a JSONL file — each line must be valid JSON."""
|
||||
errors = []
|
||||
seen_hashes = set()
|
||||
line_count = 0
|
||||
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
for i, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
line_count += 1
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
errors.append(f"Line {i}: invalid JSON: {e}")
|
||||
continue
|
||||
|
||||
# Duplicate detection
|
||||
h = hashlib.sha256(line.encode()).hexdigest()[:16]
|
||||
if h in seen_hashes:
|
||||
errors.append(f"Line {i}: duplicate content (hash {h})")
|
||||
seen_hashes.add(h)
|
||||
|
||||
# Content validation
|
||||
if isinstance(entry, dict):
|
||||
pair_errors = validate_training_pair(entry)
|
||||
for pe in pair_errors:
|
||||
errors.append(f"Line {i}: {pe}")
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"File error: {e}")
|
||||
|
||||
return errors, line_count, seen_hashes
|
||||
|
||||
|
||||
def validate_json(filepath):
|
||||
"""Validate a single JSON file."""
|
||||
errors = []
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
data = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
return [f"Invalid JSON: {e}"], 0
|
||||
|
||||
if isinstance(data, list):
|
||||
seen = set()
|
||||
for i, entry in enumerate(data):
|
||||
if isinstance(entry, dict):
|
||||
h = hashlib.sha256(json.dumps(entry, sort_keys=True).encode()).hexdigest()[:16]
|
||||
if h in seen:
|
||||
errors.append(f"Index {i}: duplicate entry")
|
||||
seen.add(h)
|
||||
|
||||
return errors, len(data) if isinstance(data, list) else 1
|
||||
|
||||
|
||||
# ── Quality Scoring ────────────────────────────────────
|
||||
|
||||
def score_file(filepath):
|
||||
"""Score a pipeline output file. Returns 0-100."""
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
return 0
|
||||
|
||||
suffix = path.suffix.lower()
|
||||
if suffix == ".jsonl":
|
||||
errors, count, _ = validate_jsonl(filepath)
|
||||
elif suffix == ".json":
|
||||
errors, count = validate_json(filepath)
|
||||
else:
|
||||
return 50 # unknown format
|
||||
|
||||
if count == 0:
|
||||
return 0
|
||||
|
||||
error_rate = len(errors) / count
|
||||
score = max(0, int(100 * (1 - error_rate)))
|
||||
|
||||
# Bonus for having content
|
||||
if count >= 100:
|
||||
score = min(100, score + 5)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def record_score(filepath, score):
|
||||
"""Record quality score for tracking."""
|
||||
HERMES_HOME.mkdir(parents=True, exist_ok=True)
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"file": str(filepath),
|
||||
"score": score,
|
||||
}
|
||||
with open(SCORE_FILE, "a") as f:
|
||||
f.write(json.dumps(entry) + "
|
||||
")
|
||||
|
||||
|
||||
# ── Dedup Hash Management ─────────────────────────────
|
||||
|
||||
def load_hashes():
|
||||
try:
|
||||
return json.loads(HASH_FILE.read_text())
|
||||
except Exception:
|
||||
return {"entries": {}, "last_cleanup": None}
|
||||
|
||||
|
||||
def save_hashes(data):
|
||||
HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
HASH_FILE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def cleanup_old_hashes(data, max_age_days=MAX_HASH_AGE_DAYS):
|
||||
"""Remove hash entries older than max_age_days."""
|
||||
cutoff = datetime.now(timezone.utc).timestamp() - (max_age_days * 86400)
|
||||
before = len(data.get("entries", {}))
|
||||
data["entries"] = {
|
||||
k: v for k, v in data.get("entries", {}).items()
|
||||
if v.get("ts", 0) > cutoff
|
||||
}
|
||||
data["last_cleanup"] = datetime.now(timezone.utc).isoformat()
|
||||
after = len(data["entries"])
|
||||
return before - after
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────
|
||||
|
||||
def cmd_validate(args):
|
||||
filepath = args[0] if args else None
|
||||
if not filepath or not os.path.exists(filepath):
|
||||
print(f"ERROR: {filepath} not found")
|
||||
sys.exit(1)
|
||||
|
||||
suffix = Path(filepath).suffix.lower()
|
||||
if suffix == ".jsonl":
|
||||
errors, count, _ = validate_jsonl(filepath)
|
||||
elif suffix == ".json":
|
||||
errors, count = validate_json(filepath)
|
||||
else:
|
||||
print(f"Unsupported format: {suffix}")
|
||||
sys.exit(1)
|
||||
|
||||
score = score_file(filepath)
|
||||
record_score(filepath, score)
|
||||
|
||||
if errors:
|
||||
for e in errors[:20]:
|
||||
print(f"FAIL: {e}")
|
||||
if len(errors) > 20:
|
||||
print(f"... and {len(errors)-20} more")
|
||||
print(f"
|
||||
Score: {score}/100 ({len(errors)} errors in {count} entries)")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"OK: {filepath} ({count} entries, score {score}/100)")
|
||||
|
||||
|
||||
def cmd_score(args):
|
||||
filepath = args[0] if args else None
|
||||
if not filepath:
|
||||
print("Usage: quality-gate.py score <file>")
|
||||
sys.exit(1)
|
||||
score = score_file(filepath)
|
||||
print(f"Score: {score}/100")
|
||||
record_score(filepath, score)
|
||||
|
||||
|
||||
def cmd_stats():
|
||||
if not SCORE_FILE.exists():
|
||||
print("No quality scores recorded yet.")
|
||||
return
|
||||
|
||||
scores = []
|
||||
with open(SCORE_FILE) as f:
|
||||
for line in f:
|
||||
try:
|
||||
scores.append(json.loads(line))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not scores:
|
||||
print("No scores recorded.")
|
||||
return
|
||||
|
||||
by_file = {}
|
||||
for s in scores:
|
||||
fname = s.get("file", "?")
|
||||
by_file.setdefault(fname, []).append(s.get("score", 0))
|
||||
|
||||
print("Quality Scores:")
|
||||
for fname, scs in sorted(by_file.items()):
|
||||
avg = sum(scs) / len(scs)
|
||||
latest = scs[-1]
|
||||
print(f" {fname}: avg={avg:.0f}, latest={latest}, runs={len(scs)}")
|
||||
|
||||
|
||||
def cmd_cleanup():
|
||||
data = load_hashes()
|
||||
removed = cleanup_old_hashes(data)
|
||||
save_hashes(data)
|
||||
print(f"Cleaned up {removed} old hash entries (>{MAX_HASH_AGE_DAYS} days)")
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: quality-gate.py <validate|score|stats|cleanup> [args]")
|
||||
sys.exit(1)
|
||||
|
||||
cmd = sys.argv[1]
|
||||
args = sys.argv[2:]
|
||||
|
||||
if cmd == "validate":
|
||||
cmd_validate(args)
|
||||
elif cmd == "score":
|
||||
cmd_score(args)
|
||||
elif cmd == "stats":
|
||||
cmd_stats()
|
||||
elif cmd == "cleanup":
|
||||
cmd_cleanup()
|
||||
else:
|
||||
print(f"Unknown command: {cmd}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
419
pipeline/quality_gate.py
Executable file
419
pipeline/quality_gate.py
Executable file
@@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
quality_gate.py — Quality Gate for Pipeline Outputs
|
||||
|
||||
Validates all pipeline outputs before saving. Rejects bad outputs,
|
||||
tracks quality scores, and supports re-queue for regeneration.
|
||||
|
||||
Usage:
|
||||
python3 quality_gate.py --input output.jsonl --type training_pairs
|
||||
python3 quality_gate.py --input output.jsonl --type knowledge
|
||||
python3 quality_gate.py --input output.jsonl --type scene_descriptions
|
||||
python3 quality_gate.py --dir pipeline/output/ --type training_pairs
|
||||
python3 quality_gate.py --status # show quality stats
|
||||
|
||||
Exit codes:
|
||||
0 = all outputs passed
|
||||
1 = some outputs rejected
|
||||
2 = file/parse error
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
|
||||
|
||||
# --- Quality Check Types ---
|
||||
|
||||
@dataclass
|
||||
class QualityResult:
|
||||
"""Result of a quality check on a single entry."""
|
||||
passed: bool
|
||||
checks_run: int
|
||||
checks_failed: int
|
||||
score: float # 0.0-1.0
|
||||
reasons: List[str] = field(default_factory=list)
|
||||
entry_index: int = -1
|
||||
hash: str = ""
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GateReport:
|
||||
"""Report from a quality gate run."""
|
||||
file: str
|
||||
type: str
|
||||
total: int
|
||||
passed: int
|
||||
rejected: int
|
||||
score: float
|
||||
rejected_indices: List[int] = field(default_factory=list)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Check Functions
|
||||
# ============================================================
|
||||
|
||||
def entry_hash(entry: dict) -> str:
|
||||
"""Hash an entry for deduplication."""
|
||||
return hashlib.sha256(json.dumps(entry, sort_keys=True, ensure_ascii=False).encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
|
||||
"""Check that required fields are non-empty."""
|
||||
errors = []
|
||||
for f in fields:
|
||||
val = entry.get(f)
|
||||
if val is None:
|
||||
errors.append(f"missing_field: {f}")
|
||||
elif isinstance(val, str) and len(val.strip()) == 0:
|
||||
errors.append(f"empty_field: {f}")
|
||||
elif isinstance(val, list) and len(val) == 0:
|
||||
errors.append(f"empty_list: {f}")
|
||||
return errors
|
||||
|
||||
|
||||
def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
|
||||
"""Check that string fields meet minimum lengths."""
|
||||
errors = []
|
||||
for f, min_len in field_lengths.items():
|
||||
val = entry.get(f)
|
||||
if isinstance(val, str) and len(val) < min_len:
|
||||
errors.append(f"short_field: {f} ({len(val)} < {min_len})")
|
||||
return errors
|
||||
|
||||
|
||||
def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
|
||||
"""Check for duplicate entries based on key fields."""
|
||||
seen = {}
|
||||
errors = {}
|
||||
for i, entry in enumerate(entries):
|
||||
key = tuple(entry.get(f, "") for f in key_fields)
|
||||
key_str = str(key)
|
||||
if key_str in seen:
|
||||
errors[i] = [f"duplicate_of_index: {seen[key_str]}"]
|
||||
else:
|
||||
seen[key_str] = i
|
||||
return errors
|
||||
|
||||
|
||||
def check_training_pair(entry: dict) -> List[str]:
|
||||
"""Validate a training pair (prompt/response)."""
|
||||
errors = []
|
||||
errors.extend(check_not_empty(entry, ["prompt", "response"]))
|
||||
|
||||
# Check response isn't just echoing the prompt
|
||||
prompt = entry.get("prompt", "")
|
||||
response = entry.get("response", "")
|
||||
if prompt and response and prompt.strip() == response.strip():
|
||||
errors.append("response_equals_prompt")
|
||||
|
||||
# Check response has substance
|
||||
if isinstance(response, str) and len(response) < 10:
|
||||
errors.append(f"response_too_short: {len(response)} chars")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_scene_description(entry: dict) -> List[str]:
|
||||
"""Validate a scene description entry."""
|
||||
errors = []
|
||||
errors.extend(check_not_empty(entry, ["song", "beat", "lyric_line", "scene"]))
|
||||
|
||||
scene = entry.get("scene")
|
||||
if isinstance(scene, dict):
|
||||
errors.extend(check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"]))
|
||||
errors.extend(check_string_min_length(scene, {"description": 10}))
|
||||
|
||||
colors = scene.get("colors", [])
|
||||
if isinstance(colors, list) and len(colors) > 5:
|
||||
errors.append(f"too_many_colors: {len(colors)} > 5")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_knowledge_entry(entry: dict) -> List[str]:
|
||||
"""Validate a knowledge file entry."""
|
||||
errors = []
|
||||
errors.extend(check_not_empty(entry, ["title", "content"]))
|
||||
|
||||
# Check for placeholder content
|
||||
content = entry.get("content", "")
|
||||
if isinstance(content, str):
|
||||
placeholders = ["TODO", "FIXME", "PLACEHOLDER", "[INSERT", "lorem ipsum"]
|
||||
for p in placeholders:
|
||||
if p.lower() in content.lower():
|
||||
errors.append(f"placeholder_content: '{p}' found")
|
||||
|
||||
errors.extend(check_string_min_length(entry, {"content": 50}))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_prompt_enhancement(entry: dict) -> List[str]:
|
||||
"""Validate a prompt enhancement pair (terse/rich)."""
|
||||
errors = []
|
||||
errors.extend(check_not_empty(entry, ["terse", "rich"]))
|
||||
|
||||
terse = entry.get("terse", "")
|
||||
rich = entry.get("rich", "")
|
||||
|
||||
# Rich should be longer than terse
|
||||
if isinstance(terse, str) and isinstance(rich, str) and len(rich) <= len(terse):
|
||||
errors.append("rich_not_longer_than_terse")
|
||||
|
||||
errors.extend(check_string_min_length(entry, {"rich": 20}))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_adversary_entry(entry: dict) -> List[str]:
|
||||
"""Validate an adversary corpus entry."""
|
||||
errors = []
|
||||
errors.extend(check_not_empty(entry, ["id", "family", "prompt"]))
|
||||
|
||||
# Check prompt isn't empty or placeholder
|
||||
prompt = entry.get("prompt", "")
|
||||
if isinstance(prompt, str) and len(prompt) < 10:
|
||||
errors.append(f"prompt_too_short: {len(prompt)} chars")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def check_soul_compliance(text: str) -> List[str]:
|
||||
"""Check text for SOUL.md compliance (basic checks)."""
|
||||
errors = []
|
||||
text_lower = text.lower() if text else ""
|
||||
|
||||
# Check for content that violates core values
|
||||
violations = [
|
||||
("compute.*value.*human.*life", "soul_violation: computing human life value"),
|
||||
("sexualiz.*child", "soul_violation: sexualizing children"),
|
||||
("create.*weapon.*kill", "soul_violation: weapon creation"),
|
||||
("enslav", "soul_violation: enslavement content"),
|
||||
]
|
||||
for pattern, msg in violations:
|
||||
if re.search(pattern, text_lower):
|
||||
errors.append(msg)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Gate Runners
|
||||
# ============================================================
|
||||
|
||||
CHECK_MAP = {
|
||||
"training_pairs": check_training_pair,
|
||||
"training_pair": check_training_pair,
|
||||
"scene_descriptions": check_scene_description,
|
||||
"scene_description": check_scene_description,
|
||||
"knowledge": check_knowledge_entry,
|
||||
"prompt_enhancement": check_prompt_enhancement,
|
||||
"adversary": check_adversary_entry,
|
||||
"adversary_corpus": check_adversary_entry,
|
||||
}
|
||||
|
||||
|
||||
def run_gate(input_path: str, entry_type: str) -> GateReport:
|
||||
"""Run quality gate on a JSONL file."""
|
||||
path = Path(input_path)
|
||||
if not path.exists():
|
||||
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)
|
||||
|
||||
check_fn = CHECK_MAP.get(entry_type)
|
||||
if not check_fn:
|
||||
return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
|
||||
rejected_indices=[-1]) # unknown type
|
||||
|
||||
entries = []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
entries.append(json.loads(line))
|
||||
|
||||
# Deduplication check
|
||||
key_fields = _get_key_fields(entry_type)
|
||||
dup_errors = check_no_duplicates(entries, key_fields)
|
||||
|
||||
passed = 0
|
||||
rejected = 0
|
||||
rejected_indices = []
|
||||
total_score = 0.0
|
||||
|
||||
for i, entry in enumerate(entries):
|
||||
errors = check_fn(entry)
|
||||
|
||||
# Add duplicate errors
|
||||
if i in dup_errors:
|
||||
errors.extend(dup_errors[i])
|
||||
|
||||
# Add SOUL compliance check for text content
|
||||
text_content = ""
|
||||
for f in ["response", "rich", "description", "content", "lyric_line"]:
|
||||
val = entry.get(f)
|
||||
if isinstance(val, str):
|
||||
text_content += val + " "
|
||||
if isinstance(entry.get("scene"), dict):
|
||||
text_content += entry["scene"].get("description", "")
|
||||
|
||||
soul_errors = check_soul_compliance(text_content)
|
||||
errors.extend(soul_errors)
|
||||
|
||||
if errors:
|
||||
rejected += 1
|
||||
rejected_indices.append(i)
|
||||
else:
|
||||
passed += 1
|
||||
|
||||
# Score: 1.0 if no errors, decreasing with each error
|
||||
entry_score = max(0.0, 1.0 - (len(errors) * 0.2))
|
||||
total_score += entry_score
|
||||
|
||||
avg_score = total_score / len(entries) if entries else 0.0
|
||||
|
||||
report = GateReport(
|
||||
file=str(path),
|
||||
type=entry_type,
|
||||
total=len(entries),
|
||||
passed=passed,
|
||||
rejected=rejected,
|
||||
score=round(avg_score, 3),
|
||||
rejected_indices=rejected_indices[:50], # limit for readability
|
||||
)
|
||||
|
||||
# Save stats
|
||||
_save_stats(report)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def _get_key_fields(entry_type: str) -> List[str]:
|
||||
"""Get key fields for deduplication based on entry type."""
|
||||
key_map = {
|
||||
"training_pairs": ["prompt", "response"],
|
||||
"training_pair": ["prompt", "response"],
|
||||
"scene_descriptions": ["song", "beat"],
|
||||
"scene_description": ["song", "beat"],
|
||||
"knowledge": ["title"],
|
||||
"prompt_enhancement": ["terse", "rich"],
|
||||
"adversary": ["id", "prompt"],
|
||||
"adversary_corpus": ["id", "prompt"],
|
||||
}
|
||||
return key_map.get(entry_type, ["id"])
|
||||
|
||||
|
||||
def _save_stats(report: GateReport):
|
||||
"""Append quality stats to the stats file."""
|
||||
STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stats = []
|
||||
if STATS_FILE.exists():
|
||||
try:
|
||||
with open(STATS_FILE) as f:
|
||||
stats = json.load(f)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
stats = []
|
||||
|
||||
stats.append(report.to_dict())
|
||||
|
||||
# Keep last 1000 entries
|
||||
stats = stats[-1000:]
|
||||
|
||||
with open(STATS_FILE, "w") as f:
|
||||
json.dump(stats, f, indent=2)
|
||||
|
||||
|
||||
def show_status():
|
||||
"""Show quality gate statistics."""
|
||||
if not STATS_FILE.exists():
|
||||
print("No quality stats found.")
|
||||
return
|
||||
|
||||
with open(STATS_FILE) as f:
|
||||
stats = json.load(f)
|
||||
|
||||
print(f"\nQuality Gate Stats — {len(stats)} runs")
|
||||
print()
|
||||
|
||||
# Group by type
|
||||
by_type = {}
|
||||
for s in stats:
|
||||
t = s.get("type", "unknown")
|
||||
if t not in by_type:
|
||||
by_type[t] = []
|
||||
by_type[t].append(s)
|
||||
|
||||
for t, runs in sorted(by_type.items()):
|
||||
total_entries = sum(r.get("total", 0) for r in runs)
|
||||
total_passed = sum(r.get("passed", 0) for r in runs)
|
||||
total_rejected = sum(r.get("rejected", 0) for r in runs)
|
||||
avg_score = sum(r.get("score", 0) for r in runs) / len(runs) if runs else 0
|
||||
print(f" {t:25} {len(runs):4} runs | {total_entries:6} entries | {total_rejected:4} rejected | avg score: {avg_score:.3f}")
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
|
||||
parser.add_argument("--input", default=None, help="Input JSONL file")
|
||||
parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
|
||||
parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
|
||||
parser.add_argument("--status", action="store_true", help="Show quality stats")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.status:
|
||||
show_status()
|
||||
return
|
||||
|
||||
if args.dir:
|
||||
for f in sorted(Path(args.dir).glob("*.jsonl")):
|
||||
t = args.type or _infer_type(f.name)
|
||||
report = run_gate(str(f), t)
|
||||
_print_report(report)
|
||||
elif args.input:
|
||||
t = args.type or _infer_type(args.input)
|
||||
report = run_gate(args.input, t)
|
||||
_print_report(report)
|
||||
sys.exit(0 if report.rejected == 0 else 1)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
def _infer_type(filename: str) -> str:
|
||||
"""Infer entry type from filename."""
|
||||
name = filename.lower()
|
||||
if "scene" in name:
|
||||
return "scene_descriptions"
|
||||
if "training" in name or "pair" in name:
|
||||
return "training_pairs"
|
||||
if "knowledge" in name:
|
||||
return "knowledge"
|
||||
if "adversary" in name or "attack" in name:
|
||||
return "adversary"
|
||||
if "prompt" in name or "enhance" in name:
|
||||
return "prompt_enhancement"
|
||||
return "training_pairs" # default
|
||||
|
||||
|
||||
def _print_report(report: GateReport):
|
||||
"""Print a human-readable gate report."""
|
||||
status = "PASS" if report.rejected == 0 else f"FAIL ({report.rejected} rejected)"
|
||||
print(f" {report.file}: {status} | {report.passed}/{report.total} passed | score: {report.score:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,266 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Pair Provenance Tracker — Timmy Foundation
|
||||
|
||||
Adds, filters, and reports provenance metadata for JSONL training pairs.
|
||||
Tracks source_session_id, model, and timestamp for quality auditing.
|
||||
|
||||
Usage:
|
||||
# Tag pairs with provenance
|
||||
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
|
||||
--session abc123 --model nous/hermes-3
|
||||
|
||||
# Filter by model (exclude Anthropic-sourced)
|
||||
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
|
||||
--exclude-model anthropic
|
||||
|
||||
# Report: pair count by source model
|
||||
python3 scripts/training_provenance.py report input.jsonl
|
||||
|
||||
# Pipe support
|
||||
cat pairs.jsonl | python3 scripts/training_provenance.py report -
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime, timezone
|
||||
from collections import Counter
|
||||
from typing import Dict, Any, Optional, List, TextIO
|
||||
|
||||
|
||||
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
|
||||
|
||||
|
||||
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
|
||||
model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Add provenance metadata to a training pair."""
|
||||
meta = dict(pair.get("_provenance", {}))
|
||||
|
||||
if session_id:
|
||||
meta["source_session_id"] = session_id
|
||||
if model:
|
||||
meta["source_model"] = model
|
||||
meta["source_timestamp"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pair["_provenance"] = meta
|
||||
return pair
|
||||
|
||||
|
||||
def _open_input(path: str) -> TextIO:
|
||||
"""Open input file or return stdin."""
|
||||
return sys.stdin if path == "-" else open(path, "r", encoding="utf-8")
|
||||
|
||||
|
||||
def _open_output(path: str) -> TextIO:
|
||||
"""Open output file or return stdout."""
|
||||
return sys.stdout if path == "-" else open(path, "w", encoding="utf-8")
|
||||
|
||||
|
||||
def stamp_command(input_path: str, output_path: str,
|
||||
session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
|
||||
"""Tag all pairs in a file with provenance metadata."""
|
||||
tagged = 0
|
||||
skipped = 0
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
out = _open_output(output_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
# Skip if already tagged with same model+session
|
||||
existing = pair.get("_provenance", {})
|
||||
if (existing.get("source_model") == model
|
||||
and existing.get("source_session_id") == session_id):
|
||||
skipped += 1
|
||||
out.write(line + "\n")
|
||||
continue
|
||||
|
||||
pair = tag_pair(pair, session_id=session_id, model=model)
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
tagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
# Never close stdout — it breaks downstream piping
|
||||
|
||||
return {"tagged": tagged, "skipped": skipped, "errors": errors}
|
||||
|
||||
|
||||
def filter_pairs(input_path: str, output_path: str,
|
||||
include_models: Optional[List[str]] = None,
|
||||
exclude_models: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Filter pairs by provenance metadata."""
|
||||
kept = []
|
||||
removed = []
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
prov = pair.get("_provenance", {})
|
||||
model = prov.get("source_model", "unknown")
|
||||
|
||||
should_keep = True
|
||||
|
||||
if include_models:
|
||||
should_keep = should_keep and model in include_models
|
||||
|
||||
if exclude_models:
|
||||
should_keep = should_keep and model not in exclude_models
|
||||
|
||||
if should_keep:
|
||||
kept.append(pair)
|
||||
else:
|
||||
removed.append(pair)
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
# Write output
|
||||
if output_path:
|
||||
out = _open_output(output_path)
|
||||
try:
|
||||
for pair in kept:
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
finally:
|
||||
if out is not sys.stdout:
|
||||
out.close()
|
||||
|
||||
return {
|
||||
"total": len(kept) + len(removed),
|
||||
"kept": len(kept),
|
||||
"filtered_out": len(removed),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def report(input_path: str) -> Dict[str, Any]:
|
||||
"""Report pair counts by source model and session."""
|
||||
model_counts: Counter = Counter()
|
||||
session_counts: Counter = Counter()
|
||||
tagged = 0
|
||||
untagged = 0
|
||||
total = 0
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
total += 1
|
||||
prov = pair.get("_provenance", {})
|
||||
|
||||
if prov:
|
||||
tagged += 1
|
||||
model = prov.get("source_model", "unknown")
|
||||
session = prov.get("source_session_id", "unknown")
|
||||
model_counts[model] += 1
|
||||
session_counts[session] += 1
|
||||
else:
|
||||
untagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"tagged": tagged,
|
||||
"untagged": untagged,
|
||||
"tag_rate": round(tagged / max(total, 1) * 100, 1),
|
||||
"by_model": dict(model_counts.most_common(20)),
|
||||
"by_session": dict(session_counts.most_common(10)),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training pair provenance tracking")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# tag subcommand
|
||||
tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
|
||||
tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
tag_p.add_argument("--session", help="Source session ID")
|
||||
tag_p.add_argument("--model", help="Source model name")
|
||||
|
||||
# filter subcommand
|
||||
filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
|
||||
filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
filt_p.add_argument("--include-model", action="append", help="Only include these models")
|
||||
filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")
|
||||
|
||||
# report subcommand
|
||||
rpt_p = sub.add_parser("report", help="Report provenance statistics")
|
||||
rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "tag":
|
||||
result = stamp_command(args.input, args.output, args.session, args.model)
|
||||
print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "filter":
|
||||
result = filter_pairs(
|
||||
args.input, args.output,
|
||||
include_models=args.include_model,
|
||||
exclude_models=args.exclude_model,
|
||||
)
|
||||
print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "report":
|
||||
result = report(args.input)
|
||||
print(f"Training Pair Provenance Report", file=sys.stderr)
|
||||
print(f"{'='*40}", file=sys.stderr)
|
||||
print(f"Total pairs: {result['total']}", file=sys.stderr)
|
||||
print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
|
||||
print(f"Untagged: {result['untagged']}", file=sys.stderr)
|
||||
|
||||
if result["by_model"]:
|
||||
print(f"\nBy source model:", file=sys.stderr)
|
||||
for model, count in result["by_model"].items():
|
||||
print(f" {model}: {count}", file=sys.stderr)
|
||||
|
||||
if result["by_session"]:
|
||||
print(f"\nBy source session (top 10):", file=sys.stderr)
|
||||
for session, count in result["by_session"].items():
|
||||
session_short = session[:12] + "..." if len(session) > 12 else session
|
||||
print(f" {session_short}: {count}", file=sys.stderr)
|
||||
|
||||
# JSON output to stdout
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,363 +0,0 @@
|
||||
"""Tests for training pair provenance tracking."""
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
|
||||
SCRIPT = os.path.join(os.path.dirname(__file__), "..", "scripts", "training_provenance.py")
|
||||
|
||||
|
||||
def _run(args, stdin=None):
|
||||
"""Run training_provenance.py and return (stdout, stderr, returncode)."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, SCRIPT] + args,
|
||||
capture_output=True, text=True,
|
||||
input=stdin,
|
||||
)
|
||||
return result.stdout, result.stderr, result.returncode
|
||||
|
||||
|
||||
def _make_pairs(count=3, model="nous/hermes-3", session="sess-123"):
|
||||
"""Generate test JSONL pairs."""
|
||||
lines = []
|
||||
for i in range(count):
|
||||
lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── tag command ──────────────────────────────────────────────────
|
||||
|
||||
class TestTagCommand:
|
||||
def test_tag_adds_provenance_to_each_pair(self):
|
||||
pairs = _make_pairs(3)
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "sess-abc", "--model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
|
||||
with open(out_path) as f:
|
||||
lines = [json.loads(l) for l in f if l.strip()]
|
||||
|
||||
assert len(lines) == 3
|
||||
for pair in lines:
|
||||
prov = pair["_provenance"]
|
||||
assert prov["source_session_id"] == "sess-abc"
|
||||
assert prov["source_model"] == "nous/hermes-3"
|
||||
assert "source_timestamp" in prov
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_preserves_existing_pair_data(self):
|
||||
pairs = '{"terse": "hello", "rich": "world", "domain": "greeting"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
with open(out_path) as f:
|
||||
pair = json.loads(f.readline())
|
||||
assert pair["terse"] == "hello"
|
||||
assert pair["rich"] == "world"
|
||||
assert pair["domain"] == "greeting"
|
||||
assert pair["_provenance"]["source_model"] == "m1"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_skips_already_tagged_same_provenance(self):
|
||||
pair = json.dumps({
|
||||
"terse": "q", "rich": "a",
|
||||
"_provenance": {"source_model": "m1", "source_session_id": "s1"}
|
||||
})
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pair + "\n")
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "s1", "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Skipped: 1" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_overwrites_different_provenance(self):
|
||||
pair = json.dumps({
|
||||
"terse": "q", "rich": "a",
|
||||
"_provenance": {"source_model": "old-model", "source_session_id": "old-sess"}
|
||||
})
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pair + "\n")
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "new-sess", "--model", "new-model"])
|
||||
assert rc == 0
|
||||
assert "Tagged: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
tagged = json.loads(f.readline())
|
||||
assert tagged["_provenance"]["source_model"] == "new-model"
|
||||
assert tagged["_provenance"]["source_session_id"] == "new-sess"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_skips_blank_lines(self):
|
||||
pairs = '{"t":"a","r":"b"}\n\n{"t":"c","r":"d"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Tagged: 2" in err
|
||||
assert "Errors: 0" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_counts_malformed_lines_as_errors(self):
|
||||
pairs = '{"t":"a"}\nNOT_JSON\n{"t":"b"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Errors: 1" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
|
||||
# ── filter command ───────────────────────────────────────────────
|
||||
|
||||
class TestFilterCommand:
|
||||
def test_filter_exclude_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--exclude-model", "anthropic"])
|
||||
assert rc == 0
|
||||
assert "Kept: 2" in err
|
||||
assert "Filtered: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
kept = [json.loads(l) for l in f if l.strip()]
|
||||
models = [p["_provenance"]["source_model"] for p in kept]
|
||||
assert "anthropic/claude" not in models
|
||||
assert "nous/hermes-3" in models
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_filter_include_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--include-model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
assert "Kept: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
kept = [json.loads(l) for l in f if l.strip()]
|
||||
assert len(kept) == 1
|
||||
assert kept[0]["_provenance"]["source_model"] == "nous/hermes-3"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_filter_untreated_pairs_have_unknown_model(self):
|
||||
pairs = '{"t":"q","r":"a"}\n' # no _provenance
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
# Exclude "unknown" — should filter out unprovenanced pair
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--exclude-model", "unknown"])
|
||||
assert rc == 0
|
||||
assert "Kept: 0" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
|
||||
# ── report command ───────────────────────────────────────────────
|
||||
|
||||
class TestReportCommand:
|
||||
def test_report_counts_by_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "nous/hermes-3", "anthropic/claude"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model, "source_session_id": "s1"}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 3
|
||||
assert result["tagged"] == 3
|
||||
assert result["untagged"] == 0
|
||||
assert result["tag_rate"] == 100.0
|
||||
assert result["by_model"]["nous/hermes-3"] == 2
|
||||
assert result["by_model"]["anthropic/claude"] == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_distinguishes_tagged_vs_untagged(self):
|
||||
pairs = '{"t":"q","r":"a"}\n{"t":"q2","r":"a2","_provenance":{"source_model":"m1"}}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 2
|
||||
assert result["tagged"] == 1
|
||||
assert result["untagged"] == 1
|
||||
assert result["tag_rate"] == 50.0
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_handles_empty_file(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write("")
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 0
|
||||
assert result["tag_rate"] == 0
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_counts_by_session(self):
|
||||
lines = []
|
||||
for sess in ["sess-a", "sess-a", "sess-b"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": "m1", "source_session_id": sess}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, _, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["by_session"]["sess-a"] == 2
|
||||
assert result["by_session"]["sess-b"] == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
# ── integration ──────────────────────────────────────────────────
|
||||
|
||||
class TestIntegration:
|
||||
def test_tag_then_filter_then_report(self):
|
||||
"""Full pipeline: tag → filter → report."""
|
||||
lines = []
|
||||
for i, model in enumerate(["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]):
|
||||
lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as src:
|
||||
src.write(pairs)
|
||||
src.flush()
|
||||
|
||||
tagged_path = src.name + ".tagged"
|
||||
filtered_path = src.name + ".filtered"
|
||||
|
||||
try:
|
||||
# Step 1: Tag all with session info
|
||||
_, _, rc = _run(["tag", src.name, "-o", tagged_path,
|
||||
"--session", "pipe-1", "--model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
|
||||
# Step 2: Filter — exclude "unknown" model (untagged pairs)
|
||||
_, err2, rc = _run(["filter", tagged_path, "-o", filtered_path,
|
||||
"--exclude-model", "unknown"])
|
||||
assert rc == 0
|
||||
assert "Kept: 3" in err2
|
||||
|
||||
# Step 3: Report
|
||||
out, _, rc = _run(["report", filtered_path])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 3
|
||||
assert result["tagged"] == 3
|
||||
assert result["tag_rate"] == 100.0
|
||||
finally:
|
||||
for p in [src.name, tagged_path, filtered_path]:
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
Reference in New Issue
Block a user