Compare commits

..

2 Commits

Author SHA1 Message Date
9a8d620163 feat: quality gate pipeline validation (#623)
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 13s
Smoke Test / smoke (pull_request) Failing after 11s
Validate Config / YAML Lint (pull_request) Failing after 14s
Validate Config / JSON Validate (pull_request) Successful in 14s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 44s
Validate Config / Shell Script Lint (pull_request) Failing after 24s
Validate Config / Cron Syntax Check (pull_request) Successful in 5s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 3s
Validate Config / Playbook Schema Validation (pull_request) Successful in 8s
PR Checklist / pr-checklist (pull_request) Failing after 3m54s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Validates JSONL/JSON pipeline outputs for:
- Schema correctness
- Content quality (non-empty, not duplicated)
- Toxicity detection
- Dedup hash management with auto-cleanup

Usage:
  python3 bin/quality-gate.py validate data.jsonl
  python3 bin/quality-gate.py score data.jsonl
  python3 bin/quality-gate.py stats
  python3 bin/quality-gate.py cleanup

Closes #623
2026-04-17 05:53:33 +00:00
Alexander Whitestone
ce3822bb5f feat: quality gate — validate all pipeline outputs (#623)
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 19s
PR Checklist / pr-checklist (pull_request) Failing after 26m43s
Smoke Test / smoke (pull_request) Failing after 59s
Validate Config / YAML Lint (pull_request) Failing after 39s
Validate Config / JSON Validate (pull_request) Successful in 1m32s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 2m0s
Validate Config / Shell Script Lint (pull_request) Failing after 49s
Validate Config / Cron Syntax Check (pull_request) Successful in 9s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 10s
Validate Config / Playbook Schema Validation (pull_request) Successful in 16s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Validates pipeline outputs before saving. Rejects bad entries,
tracks quality scores per pipeline.

Checks:
- Training pairs: prompt/response non-empty, response != prompt
- Scene descriptions: all required fields, description min length
- Knowledge entries: no placeholders (TODO, FIXME), min length
- Prompt enhancement: rich > terse length, min 20 chars
- Adversary entries: id/family/prompt present, min prompt length
- SOUL.md compliance: no human life valuation, no weapon/child content
- Deduplication: detects duplicate entries by key fields

Features:
- Auto-reject bad outputs with reasons
- Quality score per entry (0.0-1.0)
- Batch mode (--dir) for processing all JSONL at once
- Stats tracking (~/.hermes/pipeline/quality_stats.json)
- --status to view historical quality metrics

Usage:
  python3 pipeline/quality_gate.py --input data.jsonl --type training_pairs
  python3 pipeline/quality_gate.py --dir pipeline/output/
  python3 pipeline/quality_gate.py --status

Closes #623
2026-04-15 08:20:18 -04:00
4 changed files with 711 additions and 581 deletions

292
bin/quality-gate.py Normal file
View File

@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""
Quality Gate — Validate pipeline outputs before saving.
Checks:
- JSON schema validation for all output formats
- Content quality (not empty, not duplicated, not toxic)
- SOUL.md compliance for agent-facing content
- Auto-reject bad outputs, re-queue for regeneration
- Quality score tracking per pipeline
Usage:
python3 quality-gate.py validate training-data/pairs.jsonl
python3 quality-gate.py validate --format training-pairs data.jsonl
python3 quality-gate.py score training-data/pairs.jsonl
python3 quality-gate.py stats
"""
import hashlib
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
# Root directory for Hermes state; overridable via the HERMES_HOME env var.
HERMES_HOME = Path(os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes")))
SCORE_FILE = HERMES_HOME / "quality-scores.jsonl"  # append-only JSONL score log
HASH_FILE = HERMES_HOME / "quality-hashes.json"    # dedup hash store for `cleanup`
MAX_HASH_AGE_DAYS = 7  # hash entries older than this are purged by cleanup_old_hashes
# ── Validators ─────────────────────────────────────────
# Substrings flagged as toxic; matched case-insensitively against prompt+response.
TOXIC_PATTERNS = [
    "kill yourself", "kys", "you should die", "end it all",
    "nobody loves you", "waste of life",
]
def validate_training_pair(entry):
    """Validate a training pair (prompt + response).

    Accepts the common field aliases (prompt/instruction and
    response/output/completion). Returns a list of human-readable
    error strings; an empty list means the entry passed.
    """
    errors = []
    if not isinstance(entry, dict):
        return ["Entry is not a dict"]

    def _text(*keys):
        # First truthy alias wins. Coerce to str so non-string JSON values
        # (numbers, lists) are validated instead of crashing on .strip().
        for key in keys:
            val = entry.get(key)
            if val:
                return val if isinstance(val, str) else str(val)
        return ""

    prompt = _text("prompt", "instruction")
    response = _text("response", "output", "completion")
    if not prompt.strip():
        errors.append("Empty prompt")
    if not response.strip():
        errors.append("Empty response")
    if len(response) < 10:
        errors.append(f"Response too short ({len(response)} chars)")
    if len(prompt) > 10000:
        errors.append(f"Prompt too long ({len(prompt)} chars)")
    # Toxicity check: substring match over the combined, lowercased text.
    combined = (prompt + " " + response).lower()
    for pattern in TOXIC_PATTERNS:
        if pattern in combined:
            errors.append(f"Toxic content detected: '{pattern}'")
    return errors
def validate_jsonl(filepath):
    """Validate a JSONL file — each line must be valid JSON.

    Returns (errors, line_count, seen_hashes).
    """
    errors = []
    seen_hashes = set()
    line_count = 0
    try:
        with open(filepath) as fh:
            for lineno, raw in enumerate(fh, 1):
                raw = raw.strip()
                if not raw:
                    continue
                line_count += 1
                try:
                    entry = json.loads(raw)
                except json.JSONDecodeError as exc:
                    errors.append(f"Line {lineno}: invalid JSON: {exc}")
                    continue
                # Duplicate detection via truncated SHA-256 of the raw line.
                digest = hashlib.sha256(raw.encode()).hexdigest()[:16]
                if digest in seen_hashes:
                    errors.append(f"Line {lineno}: duplicate content (hash {digest})")
                seen_hashes.add(digest)
                # Per-entry content validation applies to JSON objects only.
                if isinstance(entry, dict):
                    errors.extend(
                        f"Line {lineno}: {msg}"
                        for msg in validate_training_pair(entry)
                    )
    except Exception as exc:
        # Unreadable file is reported as a single error, not raised.
        errors.append(f"File error: {exc}")
    return errors, line_count, seen_hashes
def validate_json(filepath):
    """Validate a single JSON file.

    Returns (errors, entry_count). For a JSON array, duplicate dict
    entries (compared by canonical serialization) are reported; any
    other JSON value counts as a single entry.
    """
    errors = []
    try:
        with open(filepath) as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return [f"Invalid JSON: {e}"], 0
    except OSError as e:
        # Consistent with validate_jsonl: an unreadable file is an error
        # result, not an uncaught exception.
        return [f"File error: {e}"], 0
    if isinstance(data, list):
        seen = set()
        for i, entry in enumerate(data):
            if isinstance(entry, dict):
                h = hashlib.sha256(json.dumps(entry, sort_keys=True).encode()).hexdigest()[:16]
                if h in seen:
                    errors.append(f"Index {i}: duplicate entry")
                seen.add(h)
        return errors, len(data)
    return errors, 1
# ── Quality Scoring ────────────────────────────────────
def score_file(filepath):
    """Score a pipeline output file. Returns 0-100."""
    path = Path(filepath)
    if not path.exists():
        return 0
    ext = path.suffix.lower()
    if ext == ".jsonl":
        errors, count, _ = validate_jsonl(filepath)
    elif ext == ".json":
        errors, count = validate_json(filepath)
    else:
        # Unknown formats get a neutral midpoint score.
        return 50
    if not count:
        return 0
    # Error rate drives the score toward zero (clamped at 0).
    score = max(0, int(100 * (1 - len(errors) / count)))
    if count >= 100:
        # Small bonus for files with substantial content.
        score = min(100, score + 5)
    return score
def record_score(filepath, score):
    """Record quality score for tracking.

    Appends one JSON object per call to SCORE_FILE (JSONL format).
    """
    HERMES_HOME.mkdir(parents=True, exist_ok=True)
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "file": str(filepath),
        "score": score,
    }
    with open(SCORE_FILE, "a") as f:
        # Explicit "\n" terminator keeps the file valid JSONL; the source
        # as published had a raw line break inside the string literal,
        # which is a Python syntax error.
        f.write(json.dumps(entry) + "\n")
# ── Dedup Hash Management ─────────────────────────────
def load_hashes():
    """Load the dedup-hash store; return a fresh structure on any failure."""
    try:
        raw = HASH_FILE.read_text()
        return json.loads(raw)
    except Exception:
        # Missing or corrupt store: start empty rather than fail.
        return {"entries": {}, "last_cleanup": None}
def save_hashes(data):
    """Persist the dedup-hash store to HASH_FILE as pretty-printed JSON."""
    HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
    HASH_FILE.write_text(json.dumps(data, indent=2))
def cleanup_old_hashes(data, max_age_days=MAX_HASH_AGE_DAYS):
    """Remove hash entries older than max_age_days.

    Mutates `data` in place and returns the number of entries removed.
    """
    cutoff = datetime.now(timezone.utc).timestamp() - max_age_days * 86400
    old_entries = data.get("entries", {})
    kept = {
        key: meta
        for key, meta in old_entries.items()
        if meta.get("ts", 0) > cutoff  # entries without a ts are treated as stale
    }
    data["entries"] = kept
    data["last_cleanup"] = datetime.now(timezone.utc).isoformat()
    return len(old_entries) - len(kept)
# ── CLI ────────────────────────────────────────────────
def cmd_validate(args):
    """CLI: validate a JSON/JSONL file, record its score, exit 1 on errors."""
    filepath = args[0] if args else None
    if not filepath or not os.path.exists(filepath):
        print(f"ERROR: {filepath} not found")
        sys.exit(1)
    suffix = Path(filepath).suffix.lower()
    if suffix == ".jsonl":
        errors, count, _ = validate_jsonl(filepath)
    elif suffix == ".json":
        errors, count = validate_json(filepath)
    else:
        print(f"Unsupported format: {suffix}")
        sys.exit(1)
    score = score_file(filepath)
    record_score(filepath, score)
    if errors:
        for e in errors[:20]:
            print(f"FAIL: {e}")
        if len(errors) > 20:
            print(f"... and {len(errors)-20} more")
        # "\n" restored as an escape: the published source had a raw line
        # break inside this f-string, which is a syntax error.
        print(f"\nScore: {score}/100 ({len(errors)} errors in {count} entries)")
        sys.exit(1)
    else:
        print(f"OK: {filepath} ({count} entries, score {score}/100)")
def cmd_score(args):
    """CLI: print and record the quality score for one file."""
    filepath = args[0] if args else None
    if not filepath:
        print("Usage: quality-gate.py score <file>")
        sys.exit(1)
    score = score_file(filepath)
    print(f"Score: {score}/100")
    record_score(filepath, score)
def cmd_stats():
    """CLI: print per-file average/latest scores from the score log."""
    if not SCORE_FILE.exists():
        print("No quality scores recorded yet.")
        return
    records = []
    with open(SCORE_FILE) as fh:
        for raw in fh:
            try:
                records.append(json.loads(raw))
            except Exception:
                # Skip malformed log lines; the log is best-effort.
                continue
    if not records:
        print("No scores recorded.")
        return
    grouped = {}
    for rec in records:
        grouped.setdefault(rec.get("file", "?"), []).append(rec.get("score", 0))
    print("Quality Scores:")
    for fname, scs in sorted(grouped.items()):
        avg = sum(scs) / len(scs)
        latest = scs[-1]
        print(f" {fname}: avg={avg:.0f}, latest={latest}, runs={len(scs)}")
def cmd_cleanup():
    """CLI: purge dedup hashes older than MAX_HASH_AGE_DAYS and persist."""
    store = load_hashes()
    removed_count = cleanup_old_hashes(store)
    save_hashes(store)
    print(f"Cleaned up {removed_count} old hash entries (>{MAX_HASH_AGE_DAYS} days)")
def main():
    """Entry point: dispatch to the subcommand named in argv[1]."""
    if len(sys.argv) < 2:
        print("Usage: quality-gate.py <validate|score|stats|cleanup> [args]")
        sys.exit(1)
    cmd = sys.argv[1]
    args = sys.argv[2:]
    # Dispatch table instead of an if/elif chain.
    handlers = {
        "validate": lambda: cmd_validate(args),
        "score": lambda: cmd_score(args),
        "stats": cmd_stats,
        "cleanup": cmd_cleanup,
    }
    handler = handlers.get(cmd)
    if handler is None:
        print(f"Unknown command: {cmd}")
        sys.exit(1)
    handler()


if __name__ == "__main__":
    main()

419
pipeline/quality_gate.py Executable file
View File

@@ -0,0 +1,419 @@
#!/usr/bin/env python3
"""
quality_gate.py — Quality Gate for Pipeline Outputs
Validates all pipeline outputs before saving. Rejects bad outputs,
tracks quality scores, and supports re-queue for regeneration.
Usage:
python3 quality_gate.py --input output.jsonl --type training_pairs
python3 quality_gate.py --input output.jsonl --type knowledge
python3 quality_gate.py --input output.jsonl --type scene_descriptions
python3 quality_gate.py --dir pipeline/output/ --type training_pairs
python3 quality_gate.py --status # show quality stats
Exit codes:
0 = all outputs passed
1 = some outputs rejected
2 = file/parse error
"""
import json
import os
import sys
import hashlib
import re
from pathlib import Path
from datetime import datetime, timezone
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Any
# Rolling history of gate-run reports; _save_stats caps it at 1000 entries.
STATS_FILE = Path.home() / ".hermes" / "pipeline" / "quality_stats.json"
# --- Quality Check Types ---
@dataclass
class QualityResult:
    """Result of a quality check on a single entry."""
    passed: bool        # True when no checks failed
    checks_run: int     # number of checks executed against the entry
    checks_failed: int  # number of those checks that failed
    score: float  # 0.0-1.0
    reasons: List[str] = field(default_factory=list)  # human-readable failure reasons
    entry_index: int = -1  # position within the source file; -1 = unknown
    hash: str = ""  # dedup hash of the entry (NOTE: shadows builtin `hash`)

    def to_dict(self):
        """Return a plain-dict view via dataclasses.asdict."""
        return asdict(self)
@dataclass
class GateReport:
    """Report from a quality gate run."""
    file: str      # path of the validated file
    type: str      # entry type the checks were run as
    total: int     # entries examined
    passed: int    # entries with no errors
    rejected: int  # entries with at least one error
    score: float   # mean per-entry score, 0.0-1.0
    rejected_indices: List[int] = field(default_factory=list)  # capped at 50 by run_gate
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self):
        """Return a plain-dict view via dataclasses.asdict."""
        return asdict(self)
# ============================================================
# Check Functions
# ============================================================
def entry_hash(entry: dict) -> str:
    """Return a 16-hex-char SHA-256 digest of the entry's canonical JSON.

    Keys are sorted so logically-equal dicts hash identically.
    """
    canonical = json.dumps(entry, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode()).hexdigest()[:16]
def check_not_empty(entry: dict, fields: List[str]) -> List[str]:
    """Check that required fields are present and non-empty.

    Flags missing keys, blank/whitespace-only strings, and empty lists.
    """
    problems = []
    for name in fields:
        value = entry.get(name)
        if value is None:
            problems.append(f"missing_field: {name}")
        elif isinstance(value, str) and not value.strip():
            problems.append(f"empty_field: {name}")
        elif isinstance(value, list) and not value:
            problems.append(f"empty_list: {name}")
    return problems
def check_string_min_length(entry: dict, field_lengths: Dict[str, int]) -> List[str]:
    """Check that string fields meet minimum lengths (non-strings are skipped)."""
    return [
        f"short_field: {name} ({len(entry[name])} < {minimum})"
        for name, minimum in field_lengths.items()
        if isinstance(entry.get(name), str) and len(entry[name]) < minimum
    ]
def check_no_duplicates(entries: List[dict], key_fields: List[str]) -> Dict[int, List[str]]:
    """Map entry index -> errors for entries whose key fields repeat an earlier entry."""
    first_seen: Dict[str, int] = {}
    duplicates: Dict[int, List[str]] = {}
    for idx, entry in enumerate(entries):
        # Stringified key tuple tolerates unhashable field values.
        signature = str(tuple(entry.get(name, "") for name in key_fields))
        if signature in first_seen:
            duplicates[idx] = [f"duplicate_of_index: {first_seen[signature]}"]
        else:
            first_seen[signature] = idx
    return duplicates
def check_training_pair(entry: dict) -> List[str]:
    """Validate a training pair (prompt/response)."""
    errors = list(check_not_empty(entry, ["prompt", "response"]))
    prompt = entry.get("prompt", "")
    response = entry.get("response", "")
    # Reject responses that merely echo the prompt.
    if prompt and response and prompt.strip() == response.strip():
        errors.append("response_equals_prompt")
    # Require a minimally substantive response.
    if isinstance(response, str) and len(response) < 10:
        errors.append(f"response_too_short: {len(response)} chars")
    return errors
def check_scene_description(entry: dict) -> List[str]:
    """Validate a scene description entry."""
    errors = list(check_not_empty(entry, ["song", "beat", "lyric_line", "scene"]))
    scene = entry.get("scene")
    if isinstance(scene, dict):
        errors += check_not_empty(scene, ["mood", "colors", "composition", "camera", "description"])
        errors += check_string_min_length(scene, {"description": 10})
        palette = scene.get("colors", [])
        # At most 5 colors per scene.
        if isinstance(palette, list) and len(palette) > 5:
            errors.append(f"too_many_colors: {len(palette)} > 5")
    return errors
def check_knowledge_entry(entry: dict) -> List[str]:
    """Validate a knowledge file entry."""
    errors = list(check_not_empty(entry, ["title", "content"]))
    content = entry.get("content", "")
    if isinstance(content, str):
        lowered = content.lower()
        # Reject obviously unfinished or generated filler text.
        for marker in ("TODO", "FIXME", "PLACEHOLDER", "[INSERT", "lorem ipsum"):
            if marker.lower() in lowered:
                errors.append(f"placeholder_content: '{marker}' found")
        errors += check_string_min_length(entry, {"content": 50})
    return errors
def check_prompt_enhancement(entry: dict) -> List[str]:
    """Validate a prompt enhancement pair (terse/rich)."""
    errors = list(check_not_empty(entry, ["terse", "rich"]))
    terse = entry.get("terse", "")
    rich = entry.get("rich", "")
    # The enhanced prompt must actually be longer than the terse one.
    if isinstance(terse, str) and isinstance(rich, str) and len(rich) <= len(terse):
        errors.append("rich_not_longer_than_terse")
    errors += check_string_min_length(entry, {"rich": 20})
    return errors
def check_adversary_entry(entry: dict) -> List[str]:
    """Validate an adversary corpus entry."""
    errors = list(check_not_empty(entry, ["id", "family", "prompt"]))
    attack_prompt = entry.get("prompt", "")
    # Require a minimally substantive attack prompt.
    if isinstance(attack_prompt, str) and len(attack_prompt) < 10:
        errors.append(f"prompt_too_short: {len(attack_prompt)} chars")
    return errors
def check_soul_compliance(text: str) -> List[str]:
    """Check text for SOUL.md compliance (basic regex screens)."""
    lowered = text.lower() if text else ""
    # Each (pattern, message) pair screens for one category of violation.
    screens = (
        ("compute.*value.*human.*life", "soul_violation: computing human life value"),
        ("sexualiz.*child", "soul_violation: sexualizing children"),
        ("create.*weapon.*kill", "soul_violation: weapon creation"),
        ("enslav", "soul_violation: enslavement content"),
    )
    return [message for pattern, message in screens if re.search(pattern, lowered)]
# ============================================================
# Gate Runners
# ============================================================
# Maps --type values (both singular and plural aliases) to the per-entry
# check function run_gate applies.
CHECK_MAP = {
    "training_pairs": check_training_pair,
    "training_pair": check_training_pair,
    "scene_descriptions": check_scene_description,
    "scene_description": check_scene_description,
    "knowledge": check_knowledge_entry,
    "prompt_enhancement": check_prompt_enhancement,
    "adversary": check_adversary_entry,
    "adversary_corpus": check_adversary_entry,
}
def run_gate(input_path: str, entry_type: str) -> GateReport:
    """Run quality gate on a JSONL file.

    Parses every non-empty line, applies the type-specific check function
    plus deduplication and SOUL-compliance screens, and returns an
    aggregate GateReport. Lines that are not valid JSON — or not JSON
    objects — are counted and rejected instead of crashing the whole run
    with an uncaught exception (the previous behavior).
    """
    path = Path(input_path)
    if not path.exists():
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0)
    check_fn = CHECK_MAP.get(entry_type)
    if not check_fn:
        # Unknown entry type: sentinel index -1 flags the condition.
        return GateReport(file=str(path), type=entry_type, total=0, passed=0, rejected=0, score=0.0,
                          rejected_indices=[-1])
    # Parse; None marks a line that failed json.loads (rejected below).
    entries = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError:
                entries.append(None)
    # Deduplication runs over dict entries only; remap local indices back
    # to positions in `entries`.
    key_fields = _get_key_fields(entry_type)
    dict_positions = [i for i, e in enumerate(entries) if isinstance(e, dict)]
    local_dups = check_no_duplicates([entries[i] for i in dict_positions], key_fields)
    dup_errors = {dict_positions[j]: errs for j, errs in local_dups.items()}
    passed = 0
    rejected = 0
    rejected_indices = []
    total_score = 0.0
    for i, entry in enumerate(entries):
        if not isinstance(entry, dict):
            # Unparseable line or a non-object JSON value.
            errors = ["invalid_entry: not a JSON object"]
        else:
            errors = list(check_fn(entry))
            if i in dup_errors:
                errors.extend(dup_errors[i])
            # SOUL compliance over the entry's free-text fields.
            text_content = ""
            for fname in ("response", "rich", "description", "content", "lyric_line"):
                val = entry.get(fname)
                if isinstance(val, str):
                    text_content += val + " "
            if isinstance(entry.get("scene"), dict):
                text_content += entry["scene"].get("description", "")
            errors.extend(check_soul_compliance(text_content))
        if errors:
            rejected += 1
            rejected_indices.append(i)
        else:
            passed += 1
        # Score: 1.0 for a clean entry, minus 0.2 per error, floored at 0.
        total_score += max(0.0, 1.0 - (len(errors) * 0.2))
    avg_score = total_score / len(entries) if entries else 0.0
    report = GateReport(
        file=str(path),
        type=entry_type,
        total=len(entries),
        passed=passed,
        rejected=rejected,
        score=round(avg_score, 3),
        rejected_indices=rejected_indices[:50],  # limit for readability
    )
    # Persist to the rolling stats history.
    _save_stats(report)
    return report
def _get_key_fields(entry_type: str) -> List[str]:
"""Get key fields for deduplication based on entry type."""
key_map = {
"training_pairs": ["prompt", "response"],
"training_pair": ["prompt", "response"],
"scene_descriptions": ["song", "beat"],
"scene_description": ["song", "beat"],
"knowledge": ["title"],
"prompt_enhancement": ["terse", "rich"],
"adversary": ["id", "prompt"],
"adversary_corpus": ["id", "prompt"],
}
return key_map.get(entry_type, ["id"])
def _save_stats(report: GateReport):
    """Append the report to STATS_FILE, keeping only the last 1000 runs."""
    STATS_FILE.parent.mkdir(parents=True, exist_ok=True)
    history = []
    if STATS_FILE.exists():
        try:
            with open(STATS_FILE) as fh:
                history = json.load(fh)
        except (json.JSONDecodeError, IOError):
            # Corrupt stats file: start fresh rather than fail the gate.
            history = []
    history.append(report.to_dict())
    with open(STATS_FILE, "w") as fh:
        json.dump(history[-1000:], fh, indent=2)
def show_status():
    """Show quality gate statistics, aggregated per entry type."""
    if not STATS_FILE.exists():
        print("No quality stats found.")
        return
    with open(STATS_FILE) as fh:
        stats = json.load(fh)
    print(f"\nQuality Gate Stats — {len(stats)} runs")
    print()
    # Group runs by entry type.
    by_type = {}
    for record in stats:
        by_type.setdefault(record.get("type", "unknown"), []).append(record)
    for entry_type, runs in sorted(by_type.items()):
        total_entries = sum(r.get("total", 0) for r in runs)
        total_rejected = sum(r.get("rejected", 0) for r in runs)
        avg_score = sum(r.get("score", 0) for r in runs) / len(runs) if runs else 0
        print(f" {entry_type:25} {len(runs):4} runs | {total_entries:6} entries | {total_rejected:4} rejected | avg score: {avg_score:.3f}")
def main():
    """CLI entry point.

    Exit codes: 0 = all entries passed, 1 = some entries rejected —
    now honored in BOTH --input and --dir modes.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Quality Gate for Pipeline Outputs")
    parser.add_argument("--input", default=None, help="Input JSONL file")
    parser.add_argument("--type", default=None, help="Entry type (training_pairs, scene_descriptions, knowledge, etc.)")
    parser.add_argument("--dir", default=None, help="Process all JSONL files in directory")
    parser.add_argument("--status", action="store_true", help="Show quality stats")
    args = parser.parse_args()
    if args.status:
        show_status()
        return
    if args.dir:
        total_rejected = 0
        for f in sorted(Path(args.dir).glob("*.jsonl")):
            t = args.type or _infer_type(f.name)
            report = run_gate(str(f), t)
            _print_report(report)
            total_rejected += report.rejected
        # BUG FIX: directory mode previously always exited 0 even when
        # entries were rejected, contradicting the documented exit codes.
        sys.exit(0 if total_rejected == 0 else 1)
    elif args.input:
        t = args.type or _infer_type(args.input)
        report = run_gate(args.input, t)
        _print_report(report)
        sys.exit(0 if report.rejected == 0 else 1)
    else:
        parser.print_help()
def _infer_type(filename: str) -> str:
"""Infer entry type from filename."""
name = filename.lower()
if "scene" in name:
return "scene_descriptions"
if "training" in name or "pair" in name:
return "training_pairs"
if "knowledge" in name:
return "knowledge"
if "adversary" in name or "attack" in name:
return "adversary"
if "prompt" in name or "enhance" in name:
return "prompt_enhancement"
return "training_pairs" # default
def _print_report(report: GateReport):
    """Print a one-line human-readable summary of a gate report."""
    if report.rejected == 0:
        status = "PASS"
    else:
        status = f"FAIL ({report.rejected} rejected)"
    print(f" {report.file}: {status} | {report.passed}/{report.total} passed | score: {report.score:.3f}")


if __name__ == "__main__":
    main()

View File

@@ -1,389 +0,0 @@
#!/usr/bin/env python3
"""
Training Data Quality Filter (#687)
Scores and removes low-quality training pairs from JSONL files.
Supports: ShareGPT format, preference pairs, generic JSONL.
Usage:
python3 scripts/filter_training_data.py <input.jsonl> [--output filtered.jsonl]
python3 scripts/filter_training_data.py training/data/preference_pairs.jsonl
python3 scripts/filter_training_data.py training/data/curated_dataset.jsonl --threshold 0.3
"""
import argparse
import ast
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# ============================================================
# QUALITY SCORING
# ============================================================
# Generic filler phrases that indicate low-quality responses
# (matched as lowercase substrings by score_specificity).
FILLER_PHRASES = [
    "as an ai", "i'm an ai", "as a language model", "i don't have personal",
    "i cannot", "i can't", "it's important to note", "please note that",
    "in conclusion", "to summarize", "in summary", "hope this helps",
    "let me know if", "feel free to", "i'd be happy to", "certainly!",
    "of course!", "absolutely!", "great question!", "that's a great",
    "i understand your", "i appreciate your", "thank you for asking",
    "it depends", "there are many ways", "various factors",
]
# Vague/generic short responses; an exact (lowercased) match is penalized.
VAGUE_RESPONSES = [
    "ok", "okay", "sure", "yes", "no", "maybe", "idk", "i don't know",
    "thanks", "thank you", "got it", "understood", "right", "correct",
    "hello", "hi", "hey", "goodbye", "bye",
]
# Fenced ``` blocks (optionally language-tagged); group 1 captures the code body.
CODE_BLOCK_PATTERN = re.compile(r"```(?:\w+)?\n(.+?)```", re.DOTALL)
# Single-backtick inline code spans.
INLINE_CODE_PATTERN = re.compile(r"`([^`]+)`")
def detect_format(record: dict) -> str:
    """Detect the training data format of a record."""
    # Ordered signatures: the first one whose keys are all present wins.
    signatures = (
        (("conversations",), "sharegpt"),
        (("prompt", "chosen"), "preference"),
        (("scene", "lyric_line"), "scene"),
        (("terse", "rich"), "pairs"),
    )
    for required_keys, fmt in signatures:
        if all(key in record for key in required_keys):
            return fmt
    return "generic"
def extract_text_fields(record: dict, fmt: str) -> Tuple[str, str]:
    """Extract (input_text, output_text) from a record based on format."""
    if fmt == "sharegpt":
        turns = record.get("conversations", [])
        human_turns = [t["value"] for t in turns if t.get("from") == "human"]
        gpt_turns = [t["value"] for t in turns if t.get("from") == "gpt"]
        # Use the final turn on each side of the conversation.
        return (human_turns[-1] if human_turns else "",
                gpt_turns[-1] if gpt_turns else "")
    if fmt == "preference":
        return record.get("prompt", ""), record.get("chosen", "")
    if fmt == "scene":
        return record.get("lyric_line", ""), record.get("scene", {}).get("description", "")
    if fmt == "pairs":
        return record.get("terse", ""), record.get("rich", "")
    # Generic: fall through common field-name aliases, coerced to str.
    source = record.get("input", record.get("prompt", record.get("question", "")))
    target = record.get("output", record.get("response", record.get("answer", "")))
    return str(source), str(target)
def score_specificity(text: str) -> float:
    """Score 0-1 how specific/detailed a response is vs generic filler."""
    if not text or not text.strip():
        return 0.0
    lowered = text.lower().strip()
    score = 0.5  # baseline

    # Penalize filler phrases (0.08 each).
    score -= sum(phrase in lowered for phrase in FILLER_PHRASES) * 0.08

    # Penalize very short responses; mildly reward long ones.
    words = len(text.split())
    if words < 5:
        score -= 0.3
    elif words < 10:
        score -= 0.15
    elif words > 30:
        score += 0.1  # longer responses tend to be more detailed

    # Penalize vague single-word/stock responses.
    if lowered in VAGUE_RESPONSES:
        score -= 0.4

    # Reward specificity indicators (0.05 per matching pattern).
    specificity_markers = (
        r"\d+",        # numbers
        r"```",        # code blocks
        r"https?://",  # URLs
        r"\$\{", r"\w+\.\w+",  # code-like patterns
        r"(?:specifically|exactly|precisely|in particular)",
        r"(?:step \d|first,|second,|third,|finally,)",
    )
    for marker in specificity_markers:
        if re.search(marker, text):
            score += 0.05

    # Code blocks are a strong specificity signal.
    if "```" in text:
        score += 0.15
    return max(0.0, min(1.0, score))
def score_length_ratio(input_text: str, output_text: str) -> float:
    """Score 0-1 based on reasonable length ratio between input and output."""
    in_words = len(input_text.split())
    out_words = len(output_text.split())
    if out_words == 0:
        # Covers both "no output" and "both empty".
        return 0.0
    # Normalize against a nominal 10-word input when there is no input.
    ratio = out_words / in_words if in_words > 0 else out_words / 10
    # Piecewise bands: very short outputs are penalized hardest; outputs
    # up to 15x the input length are the sweet spot.
    if ratio < 0.05:
        return 0.1
    if ratio < 0.2:
        return 0.3
    if ratio < 0.5:
        return 0.6
    if ratio <= 15:
        return 1.0
    if ratio <= 50:
        return 0.8
    return 0.5
def score_code_correctness(text: str) -> float:
    """Score 0-1 for code correctness if code blocks are present.

    Text with no fenced code blocks is not penalized (score 1.0).
    """
    blocks = CODE_BLOCK_PATTERN.findall(text)
    if not blocks:
        return 1.0  # no code => no penalty
    credit = 0
    for snippet in blocks:
        # 1) Valid Python earns full credit.
        try:
            ast.parse(snippet)
            credit += 1
            continue
        except SyntaxError:
            pass
        # 2) Balanced brackets (JS-ish sanity check) earns most credit.
        if _check_brackets_balanced(snippet):
            credit += 0.8
            continue
        # 3) Valid JSON earns full credit.
        try:
            json.loads(snippet)
            credit += 1
            continue
        except (json.JSONDecodeError, ValueError):
            pass
        # 4) Shell/YAML-ish: non-trivial multiline content earns half credit.
        if len(snippet.strip()) > 10 and "\n" in snippet:
            credit += 0.5
    return credit / len(blocks)
def _check_brackets_balanced(code: str) -> bool:
"""Check if brackets are balanced in code."""
stack = []
pairs = {"(": ")", "[": "]", "{": "}"}
for ch in code:
if ch in pairs:
stack.append(pairs[ch])
elif ch in pairs.values():
if not stack or stack[-1] != ch:
return False
stack.pop()
return len(stack) == 0
def score_record(record: dict, fmt: str) -> Dict[str, float]:
    """Score a single training record. Returns dict of component scores."""
    source_text, target_text = extract_text_fields(record, fmt)
    specificity = score_specificity(target_text)
    length_ratio = score_length_ratio(source_text, target_text)
    code_correctness = score_code_correctness(target_text)
    # Weighted composite: specificity dominates, then code, then length.
    composite = (
        specificity * 0.45
        + length_ratio * 0.25
        + code_correctness * 0.30
    )
    return {
        "specificity": round(specificity, 3),
        "length_ratio": round(length_ratio, 3),
        "code_correctness": round(code_correctness, 3),
        "composite": round(composite, 3),
    }
# ============================================================
# FILTERING
# ============================================================
def filter_jsonl(
    input_path: str,
    output_path: Optional[str] = None,
    threshold: float = 0.3,
    dry_run: bool = False,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Filter a JSONL file, removing low-quality records.

    Scores every record (see score_record), drops those whose composite
    score falls below `threshold`, and writes the survivors — in their
    original order — to `output_path` unless `dry_run` is set.

    Returns a summary report dict. When the file yields no valid records
    it returns {"error": ..., "total": 0} instead of the full report.
    """
    if output_path is None:
        # Default output: sibling file with a "_filtered" suffix.
        stem = Path(input_path).stem
        output_path = str(Path(input_path).parent / f"{stem}_filtered.jsonl")
    records = []
    with open(input_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Malformed lines are skipped with a warning, not fatal.
                print(f" [WARN] Line {i+1}: invalid JSON, skipping: {e}", file=sys.stderr)
    if not records:
        return {"error": "No valid records found", "total": 0}
    # Detect format from first record — assumes a homogeneous file.
    fmt = detect_format(records[0])
    print(f" Detected format: {fmt}")
    print(f" Total records: {len(records)}")
    # Score all records, remembering each record's original index.
    scored = []
    for i, record in enumerate(records):
        scores = score_record(record, fmt)
        scored.append((record, scores, i))
    # Sort ascending by composite score (used for distribution stats).
    scored.sort(key=lambda x: x[1]["composite"])
    # Partition into kept/removed by the composite threshold.
    kept = [(r, s, i) for r, s, i in scored if s["composite"] >= threshold]
    removed = [(r, s, i) for r, s, i in scored if s["composite"] < threshold]
    # Report
    report = {
        "input_file": input_path,
        "output_file": output_path,
        "format": fmt,
        "total_records": len(records),
        "kept": len(kept),
        "removed": len(removed),
        "threshold": threshold,
        "removal_rate": f"{len(removed) / len(records) * 100:.1f}%",
        "score_distribution": {
            # `scored` is sorted ascending, so the ends give min/max.
            "min": scored[0][1]["composite"] if scored else 0,
            "max": scored[-1][1]["composite"] if scored else 0,
            "median": scored[len(scored)//2][1]["composite"] if scored else 0,
            "mean": round(sum(s["composite"] for _, s, _ in scored) / len(scored), 3) if scored else 0,
        },
        "removed_score_breakdown": {
            # Which component dragged each removed record down.
            "specificity_below_0.3": sum(1 for _, s, _ in removed if s["specificity"] < 0.3),
            "length_ratio_below_0.3": sum(1 for _, s, _ in removed if s["length_ratio"] < 0.3),
            "code_correctness_below_0.5": sum(1 for _, s, _ in removed if s["code_correctness"] < 0.5),
        },
    }
    # Show worst offenders if verbose
    if verbose and removed:
        print(f"\n Worst 5 records (by composite score):")
        for r, s, i in removed[:5]:
            _, output_text = extract_text_fields(r, fmt)
            preview = output_text[:80].replace("\n", " ") if output_text else "(empty)"
            print(f" [{s['composite']:.3f}] {preview}...")
    # Write output (unless dry run)
    if not dry_run:
        # Preserve original order, only keeping filtered records
        kept_indices = {i for _, _, i in kept}
        with open(output_path, "w", encoding="utf-8") as f:
            for i, record in enumerate(records):
                if i in kept_indices:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
        print(f"\n Written: {output_path}")
    return report
# ============================================================
# CLI
# ============================================================
def main():
    """CLI entry point: parse arguments, run the filter, and print a summary.

    Exits with status 1 when the input file does not exist. Optionally dumps
    the full report dict as JSON via --report-json.
    """
    arg_parser = argparse.ArgumentParser(
        description="Training data quality filter — remove low-quality pairs (#687)"
    )
    arg_parser.add_argument("input", help="Input JSONL file path")
    arg_parser.add_argument("--output", "-o", help="Output file path (default: <input>_filtered.jsonl)")
    arg_parser.add_argument("--threshold", "-t", type=float, default=0.3,
                            help="Minimum composite score to keep (default: 0.3)")
    arg_parser.add_argument("--dry-run", "-n", action="store_true",
                            help="Score only, don't write output")
    arg_parser.add_argument("--verbose", "-v", action="store_true",
                            help="Show worst offenders")
    arg_parser.add_argument("--report-json", "-j", help="Write report as JSON to file")
    opts = arg_parser.parse_args()

    # Fail fast with a clear message rather than a traceback from open().
    if not os.path.exists(opts.input):
        print(f"Error: {opts.input} not found", file=sys.stderr)
        sys.exit(1)

    print(f"Filtering: {opts.input}")
    print(f"Threshold: {opts.threshold}")
    print()

    report = filter_jsonl(
        opts.input,
        output_path=opts.output,
        threshold=opts.threshold,
        dry_run=opts.dry_run,
        verbose=opts.verbose,
    )

    banner = "=" * 50
    dist = report["score_distribution"]
    print(f"\n{banner}")
    print(" RESULTS")
    print(banner.join(["", "\n"])[:-1] if False else banner)  # see note below
    print(f" Format: {report['format']}")
    print(f" Total: {report['total_records']}")
    print(f" Kept: {report['kept']}")
    print(f" Removed: {report['removed']} ({report['removal_rate']})")
    print(f" Threshold: {report['threshold']}")
    print(f" Score range: {dist['min']:.3f} - {dist['max']:.3f}")
    print(f" Mean score: {dist['mean']:.3f}")

    if opts.report_json:
        with open(opts.report_json, "w") as fh:
            json.dump(report, fh, indent=2)
        print(f"\n Report saved: {opts.report_json}")


if __name__ == "__main__":
    main()

View File

@@ -1,192 +0,0 @@
#!/usr/bin/env python3
"""
Tests for training data quality filter (#687).
"""
import json
import os
import tempfile
import unittest
# Import from the script
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
from filter_training_data import (
detect_format,
extract_text_fields,
score_specificity,
score_length_ratio,
score_code_correctness,
score_record,
filter_jsonl,
FILLER_PHRASES,
VAGUE_RESPONSES,
)
class TestFormatDetection(unittest.TestCase):
    """detect_format() should map each known record shape to its format name."""

    def test_sharegpt_format(self):
        sample = {"conversations": [{"from": "human", "value": "hi"}]}
        self.assertEqual(detect_format(sample), "sharegpt")

    def test_preference_format(self):
        sample = {"prompt": "do X", "chosen": "done", "rejected": "no"}
        self.assertEqual(detect_format(sample), "preference")

    def test_scene_format(self):
        sample = {"lyric_line": "test", "scene": {"description": "desc"}}
        self.assertEqual(detect_format(sample), "scene")

    def test_pairs_format(self):
        sample = {"terse": "short", "rich": "detailed"}
        self.assertEqual(detect_format(sample), "pairs")

    def test_generic_format(self):
        sample = {"input": "q", "output": "a"}
        self.assertEqual(detect_format(sample), "generic")
class TestExtractTextFields(unittest.TestCase):
    """extract_text_fields() should return the (input, output) text pair."""

    def test_sharegpt_extraction(self):
        # System turn must be skipped; human/gpt turns form the pair.
        turns = [
            {"from": "system", "value": "system prompt"},
            {"from": "human", "value": "hello"},
            {"from": "gpt", "value": "hi there"},
        ]
        question, answer = extract_text_fields({"conversations": turns}, "sharegpt")
        self.assertEqual(question, "hello")
        self.assertEqual(answer, "hi there")

    def test_preference_extraction(self):
        sample = {"prompt": "question", "chosen": "good answer"}
        question, answer = extract_text_fields(sample, "preference")
        self.assertEqual(question, "question")
        self.assertEqual(answer, "good answer")
class TestSpecificityScoring(unittest.TestCase):
    """Bounds checks on score_specificity() for empty, filler, and rich text."""

    def test_empty_text(self):
        self.assertEqual(score_specificity(""), 0.0)

    def test_filler_heavy(self):
        reply = "As an AI, I cannot provide that. It's important to note that I'm an AI."
        self.assertLess(score_specificity(reply), 0.3)

    def test_vague_response(self):
        self.assertLess(score_specificity("ok"), 0.2)

    def test_specific_response(self):
        reply = (
            "Here are the steps:\n"
            "1. First, install Python 3.12\n"
            "2. Run `pip install numpy`\n"
            "3. Execute main.py"
        )
        self.assertGreater(score_specificity(reply), 0.5)

    def test_code_response(self):
        reply = "Use this:\n```python\ndef hello():\n print('world')\n```"
        self.assertGreater(score_specificity(reply), 0.6)
class TestLengthRatio(unittest.TestCase):
    """score_length_ratio() behavior on empty and mismatched-length pairs."""

    def test_both_empty(self):
        self.assertEqual(score_length_ratio("", ""), 0.0)

    def test_empty_output(self):
        self.assertEqual(score_length_ratio("hello world", ""), 0.0)

    def test_good_ratio(self):
        question = "short question"
        answer = "This is a reasonable length answer that addresses the question."
        self.assertGreater(score_length_ratio(question, answer), 0.7)

    def test_too_short_output(self):
        question = "This is a very long question with many words that expects a detailed answer"
        self.assertLess(score_length_ratio(question, "ok"), 0.5)
class TestCodeCorrectness(unittest.TestCase):
    """score_code_correctness() should only penalize unparseable fenced code."""

    def test_no_code(self):
        # Text with no fenced block is vacuously correct.
        self.assertEqual(score_code_correctness("plain text"), 1.0)

    def test_valid_python(self):
        snippet = "```python\ndef foo():\n return 42\n```"
        self.assertEqual(score_code_correctness(snippet), 1.0)

    def test_invalid_python(self):
        snippet = "```python\ndef foo(\n return 42\n```"
        self.assertLess(score_code_correctness(snippet), 1.0)

    def test_valid_json(self):
        snippet = '```json\n{"key": "value"}\n```'
        self.assertEqual(score_code_correctness(snippet), 1.0)
class TestFilterJsonl(unittest.TestCase):
    """End-to-end tests for filter_jsonl() over temporary JSONL files."""

    def _write_temp_jsonl(self, records):
        """Serialize *records* to a temp .jsonl file and return its path.

        The file is created with delete=False; callers are responsible for
        unlinking it.
        """
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
        for r in records:
            f.write(json.dumps(r) + "\n")
        f.close()
        return f.name

    def test_filter_removes_low_quality(self):
        records = [
            {"conversations": [
                {"from": "human", "value": "How do I sort a list in Python?"},
                {"from": "gpt", "value": "Use `sorted()` or `list.sort()`.\n```python\nnums = [3,1,2]\nnums.sort()\nprint(nums) # [1, 2, 3]\n```"},
            ]},
            {"conversations": [
                {"from": "human", "value": "What is Python?"},
                {"from": "gpt", "value": "ok"},
            ]},
            {"conversations": [
                {"from": "human", "value": "Tell me about databases."},
                {"from": "gpt", "value": "As an AI, I cannot. It's important to note."},
            ]},
        ]
        path = self._write_temp_jsonl(records)
        # Bug fix: pre-bind `report` so the finally block cannot raise
        # NameError (masking the real failure) if filter_jsonl raises.
        report = {}
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["total_records"], 3)
            self.assertGreater(report["kept"], 0)
            self.assertGreater(report["removed"], 0)
            self.assertEqual(report["format"], "sharegpt")
        finally:
            os.unlink(path)
            # os.path.exists("") is False, so the default is a safe no-op.
            if os.path.exists(report.get("output_file", "")):
                os.unlink(report["output_file"])

    def test_dry_run_no_output(self):
        records = [
            {"prompt": "test", "chosen": "good detailed answer with code: `print(1)`", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        try:
            out_path = path.replace(".jsonl", "_filtered.jsonl")
            report = filter_jsonl(path, threshold=0.3, dry_run=True)
            # Dry run must score the records without writing any output file.
            self.assertFalse(os.path.exists(out_path))
            self.assertEqual(report["total_records"], 1)
        finally:
            os.unlink(path)

    def test_preference_format(self):
        records = [
            {"prompt": "Write a function", "chosen": "```python\ndef f(): pass\n```", "rejected": ""},
            {"prompt": "Hi", "chosen": "ok", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        # Same pre-binding as above: keep cleanup safe if filter_jsonl raises.
        report = {}
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["format"], "preference")
            self.assertEqual(report["total_records"], 2)
        finally:
            os.unlink(path)
            if os.path.exists(report.get("output_file", "")):
                os.unlink(report["output_file"])
# Run the full test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()