feat: add provenance validation tests (#752)

This commit is contained in:
2026-04-15 22:48:23 +00:00
parent 08c2c5b945
commit 2b607f4eaf

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""
import json, tempfile, unittest
from pathlib import Path
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from training.training_pair_provenance import ProvenanceTracker, load_jsonl, write_jsonl
class TestAnnotate(unittest.TestCase):
def test_annotate_adds_required_fields(self):
t = ProvenanceTracker()
p = {"conversations": [{"from": "human", "value": "hi"}]}
r = t.annotate(p, source="trajectory", model="hermes4:14b", session_id="s1")
m = r["provenance"]
self.assertEqual(m["source"], "trajectory")
self.assertEqual(m["model"], "hermes4:14b")
self.assertTrue(m["approved"])
self.assertNotEqual(m["timestamp"], "")
def test_exclude_sets_flag(self):
t = ProvenanceTracker()
p = {"conversations": []}
t.annotate(p, source="trajectory", model="hermes4:14b")
r = t.exclude(p, "quality_filter")
self.assertTrue(r["provenance"]["excluded"])
self.assertFalse(r["provenance"]["approved"])
def test_backfill_adds_provenance(self):
t = ProvenanceTracker()
p = {"conversations": []}
r = t.backfill(p, source="backfill", model="unknown")
self.assertEqual(r["provenance"]["source"], "backfill")
class TestValidate(unittest.TestCase):
def test_valid_pair(self):
t = ProvenanceTracker()
p = {"provenance": {"source": "curated", "model": "timmy-curated", "timestamp": "2026-01-01", "source_session_id": "c1", "approved": True}}
self.assertEqual(t.validate(p), [])
def test_missing_provenance(self):
t = ProvenanceTracker()
self.assertTrue(any("missing" in e for e in t.validate({"conversations": []})))
def test_missing_field(self):
t = ProvenanceTracker()
p = {"provenance": {"source": "curated", "model": "timmy-curated", "timestamp": "2026-01-01", "approved": True}}
self.assertTrue(any("source_session_id" in e for e in t.validate(p)))
def test_excluded_no_reason(self):
t = ProvenanceTracker()
p = {"provenance": {"source": "curated", "model": "timmy-curated", "timestamp": "2026-01-01", "source_session_id": "c1", "approved": True, "excluded": True}}
self.assertTrue(any("exclusion_reason" in e for e in t.validate(p)))
class TestReport(unittest.TestCase):
def test_report_counts(self):
t = ProvenanceTracker()
for i in range(5):
t.annotate({"conversations": []}, source="trajectory", model="hermes4:14b", session_id=f"s{i}")
for i in range(3):
t.annotate({"conversations": []}, source="curated", model="timmy-curated", session_id=f"c{i}")
r = t.report()
self.assertEqual(r["total"], 8)
self.assertEqual(r["approved"], 8)
self.assertEqual(r["by_source"]["trajectory"], 5)
self.assertEqual(r["by_source"]["curated"], 3)
class TestBackfillFile(unittest.TestCase):
def test_round_trip(self):
t = ProvenanceTracker()
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
for i in range(3):
f.write(json.dumps({"conversations": [{"from": "human", "value": f"p{i}"}]}) + "
")
p = Path(f.name)
try:
cnt = t.backfill_file(p, source="backfill", model="unknown")
self.assertEqual(cnt, 3)
loaded = load_jsonl(p)
for pair in loaded:
self.assertEqual(pair["provenance"]["source"], "backfill")
finally:
p.unlink()
if __name__ == "__main__":
unittest.main(verbosity=2)