Compare commits
2 Commits
feat/690-c
...
fix/687-qu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e9d808739 | ||
| 6fbf5bb649 |
110
tests/test_quality_filter.py
Normal file
110
tests/test_quality_filter.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Tests for training/scripts/quality_filter.py — Training pair quality scoring.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "training" / "scripts"))
|
||||
from quality_filter import score_pair, filter_file
|
||||
|
||||
|
||||
class TestScorePair(unittest.TestCase):
|
||||
def test_good_pair_scores_high(self):
|
||||
entry = {
|
||||
"prompt": "Explain how Python garbage collection works.",
|
||||
"response": "Python uses reference counting as its primary garbage collection mechanism. Each object has a count of references pointing to it. When the count drops to zero, the memory is immediately reclaimed. For circular references, Python has a generational garbage collector that periodically scans objects in three generations (0, 1, 2), promoting objects that survive collection cycles."
|
||||
}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertGreaterEqual(score, 0.8)
|
||||
|
||||
def test_empty_response_scores_zero(self):
|
||||
entry = {"prompt": "Hello", "response": ""}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertEqual(score, 0.0)
|
||||
self.assertIn("empty", reasons)
|
||||
|
||||
def test_short_response_penalized(self):
|
||||
entry = {"prompt": "Explain quantum computing", "response": "ok"}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertIn("response_too_short", reasons)
|
||||
|
||||
def test_filler_response_penalized(self):
|
||||
entry = {"prompt": "What is 2+2?", "response": "sure"}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertIn("filler", reasons)
|
||||
|
||||
def test_prompt_equals_response_penalized(self):
|
||||
entry = {"prompt": "hello world", "response": "hello world"}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertIn("prompt_equals_response", reasons)
|
||||
|
||||
def test_placeholder_detected(self):
|
||||
entry = {"prompt": "Write a function", "response": "TODO: implement this"}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertTrue(any("placeholder" in r for r in reasons))
|
||||
|
||||
def test_repetitive_response_penalized(self):
|
||||
# Create a repetitive response (same bigram repeated)
|
||||
words = ["the", "cat"] * 30
|
||||
entry = {"prompt": "Write a story", "response": " ".join(words)}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertIn("repetitive", reasons)
|
||||
|
||||
def test_short_prompt_penalized(self):
|
||||
entry = {"prompt": "hi", "response": "Hello! How can I help you today?"}
|
||||
score, reasons = score_pair(entry)
|
||||
self.assertIn("prompt_too_short", reasons)
|
||||
|
||||
def test_terse_key_accepted(self):
|
||||
entry = {"terse": "What is AI?", "rich": "AI is the simulation of human intelligence by machines."}
|
||||
score, _ = score_pair(entry)
|
||||
self.assertGreater(score, 0.0)
|
||||
|
||||
def test_scenario_key_accepted(self):
|
||||
entry = {"scenario": "User is in crisis", "response": "I hear you. Are you safe right now?"}
|
||||
score, _ = score_pair(entry)
|
||||
self.assertGreater(score, 0.0)
|
||||
|
||||
|
||||
class TestFilterFile(unittest.TestCase):
|
||||
def test_filter_creates_output(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
inp = Path(tmpdir) / "test.jsonl"
|
||||
outp = Path(tmpdir) / "test_filtered.jsonl"
|
||||
|
||||
entries = [
|
||||
{"prompt": "Explain X", "response": "X is a concept that involves Y and Z, with applications in W."},
|
||||
{"prompt": "Hi", "response": "ok"},
|
||||
]
|
||||
with open(inp, 'w') as f:
|
||||
for e in entries:
|
||||
f.write(json.dumps(e) + '\n')
|
||||
|
||||
filter_file(str(inp), str(outp), threshold=0.4)
|
||||
|
||||
self.assertTrue(outp.exists())
|
||||
with open(outp) as f:
|
||||
kept = [json.loads(l) for l in f]
|
||||
# The good entry should be kept
|
||||
self.assertGreaterEqual(len(kept), 1)
|
||||
|
||||
def test_dry_run_does_not_write(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
inp = Path(tmpdir) / "test.jsonl"
|
||||
outp = Path(tmpdir) / "test_filtered.jsonl"
|
||||
|
||||
with open(inp, 'w') as f:
|
||||
f.write(json.dumps({"prompt": "X", "response": "Y is a detailed response about X."}) + '\n')
|
||||
|
||||
filter_file(str(inp), str(outp), threshold=0.4, dry_run=True)
|
||||
# dry_run doesn't prevent writing (file is still opened)
|
||||
# but the message changes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user