#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""

import json
import os
import sys
import tempfile
import unittest
from pathlib import Path

# Put the directory two levels above this file's directory on sys.path so the
# ``training`` package can be imported without installing it (presumably the
# repository root — confirm if the layout changes).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from training.training_pair_provenance import ProvenanceTracker, load_jsonl, write_jsonl  # noqa: E402


class TestAnnotate(unittest.TestCase):
    """Tests for annotating, excluding and backfilling individual pairs."""

    def test_annotate_adds_required_fields(self):
        tracker = ProvenanceTracker()
        pair = {"conversations": [{"from": "human", "value": "hi"}]}
        result = tracker.annotate(
            pair, source="trajectory", model="hermes4:14b", session_id="s1"
        )
        meta = result["provenance"]
        self.assertEqual(meta["source"], "trajectory")
        self.assertEqual(meta["model"], "hermes4:14b")
        self.assertTrue(meta["approved"])
        self.assertNotEqual(meta["timestamp"], "")
        # "Required fields" are whatever validate() checks for (including the
        # session id passed above); a freshly annotated pair must report none
        # missing.
        self.assertEqual(tracker.validate(result), [])

    def test_exclude_sets_flag(self):
        tracker = ProvenanceTracker()
        pair = {"conversations": []}
        tracker.annotate(pair, source="trajectory", model="hermes4:14b")
        result = tracker.exclude(pair, "quality_filter")
        prov = result["provenance"]
        self.assertTrue(prov["excluded"])
        self.assertFalse(prov["approved"])
        # validate() reports an "exclusion_reason" error for excluded pairs
        # without a reason; exclude() was given one, so none may appear.
        errors = tracker.validate(result)
        self.assertFalse(any("exclusion_reason" in e for e in errors), errors)

    def test_backfill_adds_provenance(self):
        tracker = ProvenanceTracker()
        pair = {"conversations": []}
        result = tracker.backfill(pair, source="backfill", model="unknown")
        self.assertEqual(result["provenance"]["source"], "backfill")


class TestValidate(unittest.TestCase):
    """Tests for the error list returned by ProvenanceTracker.validate."""

    @staticmethod
    def _provenance(**overrides):
        # Build a fresh, fully valid provenance record; keyword overrides are
        # appended/overwritten so each test can derive an invalid variant.
        record = {
            "source": "curated",
            "model": "timmy-curated",
            "timestamp": "2026-01-01",
            "source_session_id": "c1",
            "approved": True,
        }
        record.update(overrides)
        return record

    def test_valid_pair(self):
        tracker = ProvenanceTracker()
        self.assertEqual(tracker.validate({"provenance": self._provenance()}), [])

    def test_missing_provenance(self):
        tracker = ProvenanceTracker()
        errors = tracker.validate({"conversations": []})
        self.assertTrue(any("missing" in err for err in errors))

    def test_missing_field(self):
        tracker = ProvenanceTracker()
        prov = self._provenance()
        del prov["source_session_id"]
        errors = tracker.validate({"provenance": prov})
        self.assertTrue(any("source_session_id" in err for err in errors))

    def test_excluded_no_reason(self):
        tracker = ProvenanceTracker()
        # Excluded, but deliberately without an exclusion_reason.
        pair = {"provenance": self._provenance(excluded=True)}
        errors = tracker.validate(pair)
        self.assertTrue(any("exclusion_reason" in err for err in errors))


class TestReport(unittest.TestCase):
    """Tests for the aggregate counts produced by ProvenanceTracker.report."""

    def test_report_counts(self):
        tracker = ProvenanceTracker()
        # (source, model, session ids) — annotated in this order.
        batches = (
            ("trajectory", "hermes4:14b", [f"s{i}" for i in range(5)]),
            ("curated", "timmy-curated", [f"c{i}" for i in range(3)]),
        )
        for source, model, session_ids in batches:
            for session_id in session_ids:
                tracker.annotate(
                    {"conversations": []},
                    source=source,
                    model=model,
                    session_id=session_id,
                )

        summary = tracker.report()
        self.assertEqual(summary["total"], 8)
        self.assertEqual(summary["approved"], 8)
        self.assertEqual(summary["by_source"]["trajectory"], 5)
        self.assertEqual(summary["by_source"]["curated"], 3)


class TestBackfillFile(unittest.TestCase):
    """Tests for ProvenanceTracker.backfill_file on an on-disk JSONL file."""

    def test_round_trip(self):
        tracker = ProvenanceTracker()
        prompts = [f"p{i}" for i in range(3)]
        records = [{"conversations": [{"from": "human", "value": p}]} for p in prompts]

        # TemporaryDirectory removes the fixture even if writing or any
        # assertion fails (the old delete=False file leaked on early errors).
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "pairs.jsonl"
            # One JSON object per line, newline-terminated (JSONL).
            path.write_text(
                "".join(json.dumps(r) + "\n" for r in records), encoding="utf-8"
            )

            count = tracker.backfill_file(path, source="backfill", model="unknown")
            self.assertEqual(count, 3)

            # Materialise so the length check below cannot pass vacuously if
            # load_jsonl yields an iterator.
            loaded = list(load_jsonl(path))
            self.assertEqual(len(loaded), len(prompts))
            for pair, prompt in zip(loaded, prompts):
                self.assertEqual(pair["provenance"]["source"], "backfill")
                # A round trip must keep the original conversation content.
                self.assertEqual(pair["conversations"][0]["value"], prompt)


if __name__ == "__main__":
    # Run every TestCase in this module; verbosity=2 prints one line per test.
    unittest.main(verbosity=2)