feat: add provenance validation tests (#752)
This commit is contained in:
90
training/tests/test_provenance.py
Normal file
90
training/tests/test_provenance.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for training_pair_provenance.py"""
|
||||
import json, tempfile, unittest
|
||||
from pathlib import Path
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
from training.training_pair_provenance import ProvenanceTracker, load_jsonl, write_jsonl
|
||||
|
||||
|
||||
class TestAnnotate(unittest.TestCase):
    """Checks for the per-pair ProvenanceTracker calls: annotate, exclude, backfill."""

    def setUp(self):
        # A fresh tracker per test keeps the cases independent of one another.
        self.tracker = ProvenanceTracker()

    def test_annotate_adds_required_fields(self):
        # annotate() should return a pair whose "provenance" entry echoes the
        # supplied source/model, is approved, and has a non-empty timestamp.
        pair = {"conversations": [{"from": "human", "value": "hi"}]}
        annotated = self.tracker.annotate(
            pair,
            source="trajectory",
            model="hermes4:14b",
            session_id="s1",
        )
        meta = annotated["provenance"]
        self.assertEqual(meta["source"], "trajectory")
        self.assertEqual(meta["model"], "hermes4:14b")
        self.assertTrue(meta["approved"])
        self.assertNotEqual(meta["timestamp"], "")

    def test_exclude_sets_flag(self):
        # The return value of annotate() is deliberately ignored here; exclude()
        # is then called on the same dict object and its result is inspected.
        pair = {"conversations": []}
        self.tracker.annotate(pair, source="trajectory", model="hermes4:14b")
        excluded = self.tracker.exclude(pair, "quality_filter")
        provenance = excluded["provenance"]
        self.assertTrue(provenance["excluded"])
        self.assertFalse(provenance["approved"])

    def test_backfill_adds_provenance(self):
        # backfill() on a pair without provenance should record the given source.
        pair = {"conversations": []}
        backfilled = self.tracker.backfill(pair, source="backfill", model="unknown")
        self.assertEqual(backfilled["provenance"]["source"], "backfill")
|
||||
|
||||
|
||||
class TestValidate(unittest.TestCase):
    """Checks for ProvenanceTracker.validate, which returns a list of error strings."""

    def setUp(self):
        # A fresh tracker per test keeps the cases independent of one another.
        self.tracker = ProvenanceTracker()

    def test_valid_pair(self):
        # A pair carrying every field used below must validate with no errors.
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
            }
        }
        self.assertEqual(self.tracker.validate(pair), [])

    def test_missing_provenance(self):
        # With no "provenance" key at all, some error must mention "missing".
        errors = self.tracker.validate({"conversations": []})
        self.assertTrue(any("missing" in err for err in errors))

    def test_missing_field(self):
        # "source_session_id" is left out on purpose; an error must name it.
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "approved": True,
            }
        }
        errors = self.tracker.validate(pair)
        self.assertTrue(any("source_session_id" in err for err in errors))

    def test_excluded_no_reason(self):
        # excluded=True without an accompanying reason must be reported.
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
                "excluded": True,
            }
        }
        errors = self.tracker.validate(pair)
        self.assertTrue(any("exclusion_reason" in err for err in errors))
|
||||
|
||||
|
||||
class TestReport(unittest.TestCase):
    """Checks the aggregate counts returned by ProvenanceTracker.report."""

    def test_report_counts(self):
        tracker = ProvenanceTracker()

        # (how many pairs, source, model, session-id prefix) -- annotated in
        # this order: five trajectory pairs first, then three curated pairs.
        batches = (
            (5, "trajectory", "hermes4:14b", "s"),
            (3, "curated", "timmy-curated", "c"),
        )
        for count, source, model, prefix in batches:
            for idx in range(count):
                tracker.annotate(
                    {"conversations": []},
                    source=source,
                    model=model,
                    session_id=f"{prefix}{idx}",
                )

        summary = tracker.report()
        self.assertEqual(summary["total"], 8)
        self.assertEqual(summary["approved"], 8)
        per_source = summary["by_source"]
        self.assertEqual(per_source["trajectory"], 5)
        self.assertEqual(per_source["curated"], 3)
|
||||
|
||||
|
||||
class TestBackfillFile(unittest.TestCase):
    """Round-trip check for ProvenanceTracker.backfill_file on a JSONL file."""

    def test_round_trip(self):
        # Write three pairs without provenance to a temporary .jsonl file,
        # backfill the file, then reload it and check every pair was tagged.
        t = ProvenanceTracker()
        # delete=False so the file survives the `with` block and can be
        # reopened by path; it is removed explicitly in the `finally` below.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            for i in range(3):
                record = {"conversations": [{"from": "human", "value": f"p{i}"}]}
                # One JSON object per line. The terminator must be the "\n"
                # escape: a literal line break inside the quotes is an
                # unterminated string and the module would not even import.
                f.write(json.dumps(record) + "\n")
            p = Path(f.name)
        # The file is closed (and therefore flushed) before it is read back.
        try:
            cnt = t.backfill_file(p, source="backfill", model="unknown")
            # The value returned by backfill_file must equal the number of
            # records written above.
            self.assertEqual(cnt, 3)
            loaded = load_jsonl(p)
            for pair in loaded:
                self.assertEqual(pair["provenance"]["source"], "backfill")
        finally:
            # Always clean up the temp file, even when an assertion fails.
            p.unlink()
|
||||
|
||||
|
||||
# Allow running this test module directly as a script; verbosity=2 makes
# unittest report each test individually instead of printing dots.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user