#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""

import json
import os
import sys
import tempfile
import unittest
from pathlib import Path

# Make the repository root (two directories above this file) importable so the
# ``training`` package resolves when the tests are run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from training.training_pair_provenance import (  # noqa: E402
    ProvenanceTracker,
    load_jsonl,
    write_jsonl,  # noqa: F401  (kept for API-surface import coverage)
)


class TestAnnotate(unittest.TestCase):
    """Tests for ProvenanceTracker.annotate / exclude / backfill."""

    def test_annotate_adds_required_fields(self):
        """annotate() attaches a ``provenance`` dict with source, model,
        approval flag and a non-empty timestamp."""
        t = ProvenanceTracker()
        p = {"conversations": [{"from": "human", "value": "hi"}]}
        r = t.annotate(p, source="trajectory", model="hermes4:14b", session_id="s1")
        m = r["provenance"]
        self.assertEqual(m["source"], "trajectory")
        self.assertEqual(m["model"], "hermes4:14b")
        self.assertTrue(m["approved"])
        self.assertNotEqual(m["timestamp"], "")

    def test_exclude_sets_flag(self):
        """exclude() marks a previously annotated pair as excluded and
        no longer approved."""
        t = ProvenanceTracker()
        p = {"conversations": []}
        t.annotate(p, source="trajectory", model="hermes4:14b")
        r = t.exclude(p, "quality_filter")
        self.assertTrue(r["provenance"]["excluded"])
        self.assertFalse(r["provenance"]["approved"])

    def test_backfill_adds_provenance(self):
        """backfill() adds provenance to a pair that had none."""
        t = ProvenanceTracker()
        p = {"conversations": []}
        r = t.backfill(p, source="backfill", model="unknown")
        self.assertEqual(r["provenance"]["source"], "backfill")


class TestValidate(unittest.TestCase):
    """Tests for ProvenanceTracker.validate, which returns a list of error
    strings (empty when the pair is valid)."""

    def test_valid_pair(self):
        t = ProvenanceTracker()
        p = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
            }
        }
        self.assertEqual(t.validate(p), [])

    def test_missing_provenance(self):
        t = ProvenanceTracker()
        self.assertTrue(any("missing" in e for e in t.validate({"conversations": []})))

    def test_missing_field(self):
        t = ProvenanceTracker()
        p = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "approved": True,
            }
        }
        self.assertTrue(any("source_session_id" in e for e in t.validate(p)))

    def test_excluded_no_reason(self):
        t = ProvenanceTracker()
        p = {
            "provenance": {
                "source": "curated",
                "model": "timmy-curated",
                "timestamp": "2026-01-01",
                "source_session_id": "c1",
                "approved": True,
                "excluded": True,
            }
        }
        self.assertTrue(any("exclusion_reason" in e for e in t.validate(p)))


class TestReport(unittest.TestCase):
    """Tests for the aggregate counts produced by ProvenanceTracker.report."""

    def test_report_counts(self):
        t = ProvenanceTracker()
        for i in range(5):
            t.annotate(
                {"conversations": []},
                source="trajectory",
                model="hermes4:14b",
                session_id=f"s{i}",
            )
        for i in range(3):
            t.annotate(
                {"conversations": []},
                source="curated",
                model="timmy-curated",
                session_id=f"c{i}",
            )
        r = t.report()
        self.assertEqual(r["total"], 8)
        self.assertEqual(r["approved"], 8)
        self.assertEqual(r["by_source"]["trajectory"], 5)
        self.assertEqual(r["by_source"]["curated"], 3)


class TestBackfillFile(unittest.TestCase):
    """Tests for ProvenanceTracker.backfill_file on a JSONL file on disk."""

    def test_round_trip(self):
        """Backfilling a 3-record JSONL file reports 3 pairs, and reading the
        file back yields 3 pairs that each carry the backfill provenance."""
        t = ProvenanceTracker()
        # TemporaryDirectory guarantees cleanup even if writing the fixture
        # fails, unlike NamedTemporaryFile(delete=False) + manual unlink.
        with tempfile.TemporaryDirectory() as tmp:
            p = Path(tmp) / "pairs.jsonl"
            # JSON Lines: one object per line, so every record must end with
            # "\n" (a space separator would put all records on a single line).
            records = [
                json.dumps({"conversations": [{"from": "human", "value": f"p{i}"}]}) + "\n"
                for i in range(3)
            ]
            p.write_text("".join(records), encoding="utf-8")

            cnt = t.backfill_file(p, source="backfill", model="unknown")
            self.assertEqual(cnt, 3)

            loaded = load_jsonl(p)
            # Guard against the loop below passing vacuously on an empty read.
            self.assertEqual(len(loaded), 3)
            for pair in loaded:
                self.assertEqual(pair["provenance"]["source"], "backfill")


if __name__ == "__main__":
    unittest.main(verbosity=2)