#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""

import json
import tempfile
import unittest
from pathlib import Path
from datetime import datetime, timezone

# Adjust import path so the module under test, which lives next to this file,
# is importable regardless of the current working directory.
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent))

from training_pair_provenance import (
    make_provenance,
    attach_provenance,
    extract_trajectory_provenance,
    pair_fingerprint,
    validate_provenance,
    provenance_dashboard,
    backfill_provenance,
    load_jsonl,
    save_jsonl,
    VALID_SOURCES,
    REQUIRED_FIELDS,
)


class _TempDirTestCase(unittest.TestCase):
    """Base class that gives each test a private temporary directory.

    Every file a test creates through ``_tmp_path`` / ``_write_jsonl`` lives
    in that directory, which is deleted when the test finishes (pass or fail).
    """

    def setUp(self):
        super().setUp()
        # A TemporaryDirectory (instead of NamedTemporaryFile(delete=False))
        # guarantees nothing is left behind; addCleanup runs even on failure.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmp_dir = Path(tmp.name)

    def _tmp_path(self, name):
        """Return the path, as ``str``, of *name* inside this test's temp dir."""
        return str(self.tmp_dir / name)

    def _write_jsonl(self, entries, name="input.jsonl"):
        """Write *entries* as JSON Lines (one object per line); return the path as ``str``."""
        path = self._tmp_path(name)
        with open(path, "w", encoding="utf-8") as f:
            for entry in entries:
                f.write(json.dumps(entry) + "\n")
        return path


class TestMakeProvenance(unittest.TestCase):
    """make_provenance builds a provenance dict and rejects unknown sources."""

    def test_creates_valid_provenance(self):
        prov = make_provenance(
            source="curated",
            source_session_id="test-001",
            model="timmy-curated",
        )
        for field in REQUIRED_FIELDS:
            self.assertIn(field, prov)
        self.assertEqual(prov["source"], "curated")
        self.assertEqual(prov["source_session_id"], "test-001")
        self.assertEqual(prov["model"], "timmy-curated")

    def test_rejects_invalid_source(self):
        with self.assertRaises(ValueError):
            make_provenance(source="bogus", source_session_id="x", model="y")

    def test_accepts_all_valid_sources(self):
        for source in VALID_SOURCES:
            prov = make_provenance(source=source, source_session_id="x", model="y")
            self.assertEqual(prov["source"], source)

    def test_includes_extras(self):
        prov = make_provenance(
            source="curated",
            source_session_id="x",
            model="y",
            extras={"custom_field": "value"},
        )
        self.assertEqual(prov["custom_field"], "value")

    def test_uses_provided_timestamp(self):
        ts = "2026-01-01T00:00:00Z"
        prov = make_provenance(source="curated", source_session_id="x", model="y", timestamp=ts)
        self.assertEqual(prov["timestamp"], ts)

    def test_generates_timestamp_if_not_provided(self):
        prov = make_provenance(source="curated", source_session_id="x", model="y")
        self.assertIn("T", prov["timestamp"])
        self.assertIn("Z", prov["timestamp"])


class TestAttachProvenance(unittest.TestCase):
    """attach_provenance adds a "provenance" key without clobbering existing data."""

    def test_attaches_to_pair(self):
        pair = {"id": "test", "conversations": []}
        result = attach_provenance(pair, source="curated", source_session_id="s1", model="m1")
        self.assertIn("provenance", result)
        self.assertEqual(result["provenance"]["source"], "curated")

    def test_does_not_overwrite_existing(self):
        pair = {
            "id": "test",
            "provenance": {"source": "original", "source_session_id": "old", "model": "x"},
        }
        result = attach_provenance(pair, source="new", source_session_id="new", model="y")
        self.assertEqual(result["provenance"]["source"], "original")

    def test_overwrites_with_force(self):
        pair = {
            "id": "test",
            "provenance": {
                "source": "curated",
                "source_session_id": "old",
                "model": "x",
                "timestamp": "t",
            },
        }
        # NOTE(review): "force" is passed through ``extras`` here; confirm this
        # is the intended API rather than a dedicated keyword argument.
        result = attach_provenance(
            pair, source="backfill", source_session_id="new", model="y", extras={"force": True}
        )
        self.assertEqual(result["provenance"]["source"], "backfill")


class TestExtractTrajectoryProvenance(unittest.TestCase):
    """extract_trajectory_provenance maps trajectory entry fields to provenance."""

    def test_extracts_from_trajectory(self):
        entry = {
            "id": "session-123",
            "model": "hermes3:latest",
            "started_at": "2026-04-14T10:00:00",
        }
        prov = extract_trajectory_provenance(entry)
        self.assertEqual(prov["source_session_id"], "session-123")
        self.assertEqual(prov["model"], "hermes3:latest")
        self.assertEqual(prov["timestamp"], "2026-04-14T10:00:00")

    def test_uses_defaults_for_missing(self):
        entry = {}
        prov = extract_trajectory_provenance(entry)
        self.assertEqual(prov["source_session_id"], "unknown")
        self.assertEqual(prov["model"], "unknown")
        self.assertIn("T", prov["timestamp"])


class TestPairFingerprint(unittest.TestCase):
    """pair_fingerprint hashes conversation content, ignoring provenance and system prompts."""

    def test_same_content_same_hash(self):
        pair1 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
        pair2 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_different_content_different_hash(self):
        pair1 = {"conversations": [{"from": "human", "value": "hi"}]}
        pair2 = {"conversations": [{"from": "human", "value": "bye"}]}
        self.assertNotEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_ignores_provenance_in_hash(self):
        pair1 = {
            "conversations": [{"from": "human", "value": "hi"}],
            "provenance": {"source": "a"},
        }
        pair2 = {
            "conversations": [{"from": "human", "value": "hi"}],
            "provenance": {"source": "b"},
        }
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_ignores_system_prompt(self):
        pair1 = {
            "conversations": [
                {"from": "system", "value": "prompt A"},
                {"from": "human", "value": "hi"},
            ]
        }
        pair2 = {
            "conversations": [
                {"from": "system", "value": "prompt B"},
                {"from": "human", "value": "hi"},
            ]
        }
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))


class TestValidateProvenance(_TempDirTestCase):
    """validate_provenance reports coverage and field/source problems in a JSONL file."""

    def test_all_valid(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "curated", "source_session_id": "s1",
                                       "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
            {"id": "2", "provenance": {"source": "trajectory", "source_session_id": "s2",
                                       "model": "m2", "timestamp": "2026-01-01T00:00:00Z"}},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_provenance"], 0)
        self.assertEqual(report["missing_fields"], 0)
        self.assertEqual(report["coverage"], 100.0)

    def test_missing_provenance(self):
        path = self._write_jsonl([
            {"id": "1", "conversations": []},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_provenance"], 1)
        self.assertEqual(report["coverage"], 0.0)

    def test_missing_fields(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "curated"}},  # missing session_id, model, timestamp
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_fields"], 1)

    def test_invalid_source(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "bogus", "source_session_id": "s1",
                                       "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["invalid_source"], 1)


class TestBackfillProvenance(_TempDirTestCase):
    """backfill_provenance fills in missing provenance and keeps existing entries."""

    def test_backfills_missing(self):
        path = self._write_jsonl([
            {"id": "1", "conversations": []},
            {"id": "2", "conversations": [], "provenance": {"source": "existing",
                                                             "source_session_id": "s",
                                                             "model": "m", "timestamp": "t"}},
        ])
        # Output path inside the per-test temp dir: no stray open handle and
        # the file is removed with the directory.
        out = self._tmp_path("output.jsonl")
        stats = backfill_provenance(path, source="backfill", model="test-model", output_path=out)
        self.assertEqual(stats["backfilled"], 1)
        self.assertEqual(stats["already_had"], 1)

        entries = load_jsonl(out)
        self.assertEqual(entries[0]["provenance"]["source"], "backfill")
        self.assertEqual(entries[1]["provenance"]["source"], "existing")


class TestDashboard(_TempDirTestCase):
    """provenance_dashboard renders a text report for a JSONL file."""

    def test_generates_output(self):
        entries = [
            {
                "id": str(i),
                "provenance": {
                    "source": "curated" if i < 3 else "trajectory",
                    "source_session_id": f"s{i}",
                    "model": "timmy-curated" if i < 3 else "hermes3",
                    "timestamp": f"2026-04-{14 + i:02d}T00:00:00Z",
                },
            }
            for i in range(5)
        ]
        output = provenance_dashboard(self._write_jsonl(entries))
        self.assertIn("PROVENANCE DASHBOARD", output)
        self.assertIn("timmy-curated", output)
        self.assertIn("hermes3", output)
        # NOTE(review): "curated" is a substring of "timmy-curated", so this
        # check cannot fail on its own; consider asserting on the source
        # breakdown (e.g. "trajectory") once the dashboard format is confirmed.
        self.assertIn("curated", output)


if __name__ == "__main__":
    unittest.main()