#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""
# Recovered from a git diff that adds training/tests/test_provenance.py.
import json
import os
import sys
import tempfile
import unittest
from pathlib import Path

# This file lives in training/tests/, so two levels up is the directory that
# contains the ``training`` package; put it on sys.path before importing it.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from training.training_pair_provenance import ProvenanceTracker, load_jsonl, write_jsonl


class TestAnnotate(unittest.TestCase):
    """Checks on the provenance block produced by annotate/exclude/backfill."""

    def test_annotate_adds_required_fields(self):
        """annotate() is expected to record source, model, approval and a timestamp."""
        tracker = ProvenanceTracker()
        pair = {"conversations": [{"from": "human", "value": "hi"}]}
        result = tracker.annotate(
            pair, source="trajectory", model="hermes4:14b", session_id="s1"
        )
        meta = result["provenance"]
        self.assertEqual(meta["source"], "trajectory")
        self.assertEqual(meta["model"], "hermes4:14b")
        self.assertTrue(meta["approved"])
        self.assertNotEqual(meta["timestamp"], "")

    def test_exclude_sets_flag(self):
        """exclude() is expected to set ``excluded`` and clear ``approved``."""
        tracker = ProvenanceTracker()
        pair = {"conversations": []}
        # The return value of annotate() is deliberately unused here; exclude()
        # is then called with the same dict object.
        tracker.annotate(pair, source="trajectory", model="hermes4:14b")
        result = tracker.exclude(pair, "quality_filter")
        self.assertTrue(result["provenance"]["excluded"])
        self.assertFalse(result["provenance"]["approved"])

    def test_backfill_adds_provenance(self):
        """backfill() is expected to attach provenance carrying the given source."""
        tracker = ProvenanceTracker()
        pair = {"conversations": []}
        result = tracker.backfill(pair, source="backfill", model="unknown")
        self.assertEqual(result["provenance"]["source"], "backfill")


class TestValidate(unittest.TestCase):
    """Checks on the list of error strings returned by validate()."""

    def test_valid_pair(self):
        """A provenance block with all five fields yields no errors."""
        tracker = ProvenanceTracker()
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
            }
        }
        self.assertEqual(tracker.validate(pair), [])

    def test_missing_provenance(self):
        """A pair without a ``provenance`` key yields an error mentioning "missing"."""
        tracker = ProvenanceTracker()
        errors = tracker.validate({"conversations": []})
        self.assertTrue(any("missing" in e for e in errors))

    def test_missing_field(self):
        """Omitting ``source_session_id`` yields an error naming that field."""
        tracker = ProvenanceTracker()
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "approved": True,
            }
        }
        errors = tracker.validate(pair)
        self.assertTrue(any("source_session_id" in e for e in errors))

    def test_excluded_no_reason(self):
        """``excluded: True`` without a reason yields an error naming ``exclusion_reason``."""
        tracker = ProvenanceTracker()
        pair = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
                "excluded": True,
            }
        }
        errors = tracker.validate(pair)
        self.assertTrue(any("exclusion_reason" in e for e in errors))


class TestReport(unittest.TestCase):
    """Checks on the aggregate counts returned by report()."""

    def test_report_counts(self):
        """Five trajectory plus three curated annotations are counted per source."""
        tracker = ProvenanceTracker()
        for i in range(5):
            tracker.annotate(
                {"conversations": []},
                source="trajectory",
                model="hermes4:14b",
                session_id=f"s{i}",
            )
        for i in range(3):
            tracker.annotate(
                {"conversations": []},
                source="curated",
                model="timmy-curated",
                session_id=f"c{i}",
            )
        summary = tracker.report()
        self.assertEqual(summary["total"], 8)
        self.assertEqual(summary["approved"], 8)
        self.assertEqual(summary["by_source"]["trajectory"], 5)
        self.assertEqual(summary["by_source"]["curated"], 3)


class TestBackfillFile(unittest.TestCase):
    """Checks on backfill_file() against a JSONL file on disk."""

    def test_round_trip(self):
        """Backfilling a three-record file reports 3 and tags every record."""
        tracker = ProvenanceTracker()
        # delete=False so the file survives the ``with`` block; it is removed
        # explicitly in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            for i in range(3):
                record = {"conversations": [{"from": "human", "value": f"p{i}"}]}
                # One JSON document per line. The terminator must be the "\n"
                # escape: a raw line break inside the quotes is a SyntaxError.
                f.write(json.dumps(record) + "\n")
            path = Path(f.name)
        try:
            count = tracker.backfill_file(path, source="backfill", model="unknown")
            self.assertEqual(count, 3)
            loaded = load_jsonl(path)
            for pair in loaded:
                self.assertEqual(pair["provenance"]["source"], "backfill")
        finally:
            path.unlink()


if __name__ == "__main__":
    unittest.main(verbosity=2)