Compare commits
3 Commits
fix/660-py ... feat/687-q
| Author | SHA1 | Date |
|---|---|---|
| | a0266c83a4 | |
| | b28071bb71 | |
| | 04ecad3b43 | |
276
scripts/quality_filter.py
Normal file
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
Training Data Quality Filter — Score and remove low-quality training pairs.

Scores each pair on:
1. Specificity: How concrete vs generic is the content?
2. Length ratio: Balanced input/output lengths?
3. Code correctness: If code is present, does it parse?

Usage:
    python3 quality_filter.py input.jsonl -o output.jsonl
    python3 quality_filter.py input.jsonl --report
    python3 quality_filter.py input.jsonl --threshold 0.4

Accepts JSONL where each line has:
    {"prompt": "...", "response": "..."} or {"input": "...", "output": "..."}
"""

import argparse
import json
import re
import sys
import ast
from pathlib import Path


# ---------------------------------------------------------------------------
# SCORING
# ---------------------------------------------------------------------------

GENERIC_PHRASES = [
    "i don't know", "it depends", "there are many ways",
    "that's a good question", "let me think about", "in general",
    "as an ai", "i cannot", "i'm sorry but", "unfortunately",
    "that being said", "it's worth noting", "in conclusion",
    "to summarize", "overall", "basically", "essentially",
]

SPECIFIC_MARKERS = [
    r"(?:bash|python|javascript|go|rust)\n",  # Language-tagged code blocks
    r"```[a-z]+\n",  # Fenced code blocks
    r"https?://\S+",  # URLs
    r"(?:file|path|dir|repo|branch|commit)\b",  # Concrete references
    r"\d+\.\d+\.\d+",  # Version numbers
    r"(?:error|exception|traceback|stderr)",  # Error messages
    r"(?:curl|git|apt|brew|pip|npm)\s",  # CLI commands
    r"(?:GET|POST|PUT|DELETE|PATCH)\s",  # HTTP methods
    r"(?:Issue|PR|commit|merge|branch)\s*#",  # Gitea/GitHub refs
]


def score_specificity(text: str) -> float:
    """Score 0-1 for how specific/concrete the text is."""
    text_lower = text.lower()
    score = 0.5  # baseline

    # Penalize generic phrases
    generic_count = sum(1 for p in GENERIC_PHRASES if p in text_lower)
    score -= generic_count * 0.05

    # Reward specific markers
    specific_count = sum(1 for p in SPECIFIC_MARKERS if re.search(p, text, re.IGNORECASE))
    score += specific_count * 0.08

    # Reward longer, detailed responses
    word_count = len(text.split())
    if word_count > 100:
        score += 0.1
    elif word_count > 50:
        score += 0.05
    elif word_count < 10:
        score -= 0.15

    return max(0.0, min(1.0, score))


def score_length_ratio(prompt: str, response: str) -> float:
    """Score 0-1 for balanced input/output lengths."""
    p_len = len(prompt.split())
    r_len = len(response.split())

    if p_len == 0 or r_len == 0:
        return 0.0

    ratio = r_len / p_len

    # Ideal: response is 1-10x the prompt length
    if 1.0 <= ratio <= 10.0:
        return 1.0
    elif 0.5 <= ratio <= 20.0:
        return 0.7
    elif 0.2 <= ratio <= 50.0:
        return 0.4
    else:
        return 0.1


def score_code_correctness(text: str) -> float:
    """Score 0-1 for code blocks that parse correctly."""
    code_blocks = re.findall(r"```(?:\w*\n)?(.*?)```", text, re.DOTALL)

    if not code_blocks:
        return 1.0  # No code = no code errors

    total = len(code_blocks)
    valid = 0

    for block in code_blocks:
        block = block.strip()
        if not block:
            continue

        # Try Python parse
        try:
            ast.parse(block)
            valid += 1
            continue
        except SyntaxError:
            pass

        # Try JSON parse
        try:
            json.loads(block)
            valid += 1
            continue
        except (json.JSONDecodeError, ValueError):
            pass

        # Shell scripts: check for balanced braces/parens
        open_count = block.count("{") + block.count("(") + block.count("[")
        close_count = block.count("}") + block.count(")") + block.count("]")
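        # Exact balance is required: a snippet that fails both parsers and has
        # mismatched delimiters counts as invalid.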
        if open_count == close_count:
            valid += 1

    return valid / total if total > 0 else 1.0


def score_pair(pair: dict) -> dict:
    """Score a single training pair. Returns scores dict and composite."""
    prompt = str(pair.get("prompt") or pair.get("input") or pair.get("question") or "")
    response = str(pair.get("response") or pair.get("output") or pair.get("answer") or pair.get("completion") or "")

    if not prompt or not response:
        return {"specificity": 0.0, "length_ratio": 0.0, "code_correctness": 0.0, "composite": 0.0}

    spec = score_specificity(response)
    length = score_length_ratio(prompt, response)
    code = score_code_correctness(response)
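
    # Composite: specificity scaled by the length and code scores, so a
    # generic answer cannot reach a high composite on balanced length and
    # clean code alone.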
    composite = spec * (0.5 + (length * 0.2) + (code * 0.3))

    return {
        "specificity": round(spec, 3),
        "length_ratio": round(length, 3),
        "code_correctness": round(code, 3),
        "composite": round(composite, 3),
    }


# ---------------------------------------------------------------------------
# FILTER
# ---------------------------------------------------------------------------

def filter_pairs(input_path: str, output_path: str = None, threshold: float = 0.3,
                 report: bool = False) -> dict:
    """Filter JSONL training pairs by quality score."""
    kept = []
    removed = []
    total = 0

    with open(input_path, "r") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                removed.append({"line": line_num, "reason": "invalid JSON", "scores": {}})
                continue

            total += 1
            scores = score_pair(pair)
            pair["_quality_scores"] = scores

            if scores["composite"] >= threshold:
                kept.append(pair)
            else:
                pair["_filter_reason"] = f"composite {scores['composite']} < {threshold}"
                removed.append(pair)

    # Write filtered output
    if output_path and kept:
        with open(output_path, "w") as f:
            for pair in kept:
                # Remove internal scoring metadata before writing
                clean = {k: v for k, v in pair.items() if not k.startswith("_")}
                f.write(json.dumps(clean, ensure_ascii=False) + "\n")

    result = {
        "total": total,
        "kept": len(kept),
        "removed": len(removed),
        "threshold": threshold,
        "removal_rate": round(len(removed) / total * 100, 1) if total > 0 else 0,
    }

    if report:
        print(f"\n=== QUALITY FILTER REPORT ===")
        print(f"Input: {input_path}")
        if output_path:
            print(f"Output: {output_path}")
        print(f"")
        print(f"Total pairs: {result['total']}")
        print(f"Kept: {result['kept']}")
        print(f"Removed: {result['removed']} ({result['removal_rate']}%)")
        print(f"Threshold: {result['threshold']}")
        print(f"")

        # Score distribution
        if kept:
            composites = [p["_quality_scores"]["composite"] for p in kept]
            print(f"Kept scores: min={min(composites):.3f} max={max(composites):.3f} avg={sum(composites)/len(composites):.3f}")

        if removed:
            reasons = {}
            for r in removed:
                reason = r.get("_filter_reason", r.get("reason", "unknown"))
                reasons[reason] = reasons.get(reason, 0) + 1
            print(f"\nRemoval reasons:")
            for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
                print(f"  {reason}: {count}")

    return result


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description="Training data quality filter — score and remove low-quality pairs"
    )
    parser.add_argument("input", help="Input JSONL file")
    parser.add_argument("-o", "--output", help="Output JSONL file (filtered)")
    parser.add_argument("-t", "--threshold", type=float, default=0.3,
                        help="Quality threshold (0.0-1.0, default: 0.3)")
    parser.add_argument("--report", action="store_true",
                        help="Print detailed report")
    parser.add_argument("--dry-run", action="store_true",
                        help="Score only, don't filter")

    args = parser.parse_args()

    if not Path(args.input).exists():
        print(f"ERROR: Input file not found: {args.input}")
        sys.exit(1)

    if args.dry_run and not args.output:
        args.report = True

    output = args.output
    if args.dry_run:
        output = None

    result = filter_pairs(args.input, output, args.threshold, args.report)

    if not args.report:
        print(f"{result['kept']}/{result['total']} pairs kept (removed {result['removed']}, {result['removal_rate']}%)")


if __name__ == "__main__":
    main()
136
scripts/test_quality_filter.py
Normal file
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
Tests for training data quality filter.
"""

import json
import os
import sys
import tempfile
import unittest
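
# The scripts/ directory is put on sys.path below so that quality_filter can
# be imported no matter where the tests are run from.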

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from quality_filter import score_specificity, score_length_ratio, score_code_correctness, score_pair, filter_pairs


class TestSpecificity(unittest.TestCase):

    def test_generic_response_scores_low(self):
        text = "I don't know. It depends on many factors. There are many ways to approach this."
        score = score_specificity(text)
        self.assertLess(score, 0.4)

    def test_specific_response_scores_high(self):
        text = 'Run: curl -s https://api.example.com/v1/repos | python3 -c "import sys,json; print(json.load(sys.stdin))"'
        score = score_specificity(text)
        self.assertGreater(score, 0.6)

    def test_code_block_boosts_score(self):
        text = """Here's the fix:
```python
def hello():
    return "world"
```"""
        score = score_specificity(text)
        self.assertGreater(score, 0.5)

    def test_long_detailed_response(self):
        text = " ".join(["word"] * 150) + " GET /api/v1/repos"
        score = score_specificity(text)
        self.assertGreater(score, 0.5)

    def test_short_response_penalized(self):
        score = score_specificity("yes")
        self.assertLess(score, 0.4)


class TestLengthRatio(unittest.TestCase):

    def test_balanced_ratio(self):
        score = score_length_ratio("short prompt", "This is a medium length response with some detail.")
        self.assertEqual(score, 1.0)

    def test_too_short_response(self):
        score = score_length_ratio("A long prompt with many words here", "ok")
        self.assertLess(score, 1.0)

    def test_empty_returns_zero(self):
        self.assertEqual(score_length_ratio("", "something"), 0.0)
        self.assertEqual(score_length_ratio("something", ""), 0.0)


class TestCodeCorrectness(unittest.TestCase):

    def test_no_code_returns_one(self):
        self.assertEqual(score_code_correctness("Just text, no code."), 1.0)

    def test_valid_python(self):
        text = '```python\ndef foo():\n    return 42\n```'
        self.assertEqual(score_code_correctness(text), 1.0)

    def test_valid_json(self):
        text = '```json\n{"key": "value"}\n```'
        self.assertEqual(score_code_correctness(text), 1.0)

    def test_invalid_python(self):
        text = '```python\ndef foo(\n    return broken\n```'
        score = score_code_correctness(text)
        self.assertLess(score, 1.0)


class TestScorePair(unittest.TestCase):

    def test_good_pair(self):
        pair = {
            "prompt": "How do I list files in Python?",
            "response": 'Use `os.listdir()` or `pathlib.Path.iterdir()`. Example:\n```python\nfrom pathlib import Path\nfor f in Path(".").iterdir():\n    print(f)\n```'
        }
        scores = score_pair(pair)
        self.assertGreater(scores["composite"], 0.4)

    def test_bad_pair(self):
        pair = {
            "prompt": "How do I deploy?",
            "response": "It depends. There are many ways. I don't know your setup."
        }
        scores = score_pair(pair)
        self.assertLess(scores["composite"], 0.4)

    def test_empty_pair_returns_zero(self):
        scores = score_pair({})
        self.assertEqual(scores["composite"], 0.0)


class TestFilterPairs(unittest.TestCase):

    def test_filter_removes_low_quality(self):
        pairs = [
            json.dumps({"prompt": "How?", "response": "Yes."}),
            json.dumps({"prompt": "List files?", "response": 'Use os.listdir():\n```python\nimport os\nos.listdir(".")\n```'}),
            json.dumps({"prompt": "Deploy?", "response": "It depends. I don't know."}),
        ]

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write("\n".join(pairs) + "\n")
            input_path = f.name

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            output_path = f.name

        try:
            result = filter_pairs(input_path, output_path, threshold=0.3)
            self.assertEqual(result["total"], 3)
            self.assertGreater(result["kept"], 0)
            self.assertGreater(result["removed"], 0)

            # Verify output is valid JSONL
            with open(output_path) as f:
                for line in f:
                    json.loads(line.strip())
        finally:
            os.unlink(input_path)
            os.unlink(output_path)


if __name__ == "__main__":
    unittest.main()