#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""

import json
import tempfile
import unittest
from pathlib import Path
from datetime import datetime, timezone

# Adjust import path
# NOTE(review): this puts the directory containing this test file on sys.path;
# confirm training_pair_provenance.py is importable from there.
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent))

from training_pair_provenance import (
    make_provenance,
    attach_provenance,
    extract_trajectory_provenance,
    pair_fingerprint,
    validate_provenance,
    provenance_dashboard,
    backfill_provenance,
    load_jsonl,
    save_jsonl,
    VALID_SOURCES,
    REQUIRED_FIELDS,
)


def _remove_quietly(path):
    """Delete *path*, ignoring the case where it is already gone."""
    try:
        Path(path).unlink()
    except FileNotFoundError:
        pass


class _TempJsonlMixin:
    """Temp-file helpers shared by the file-based test cases.

    Every file created here is registered with ``addCleanup`` so it is
    removed after the test, whether the test passes or fails.
    """

    def _temp_path(self, suffix=".jsonl"):
        """Create an empty, closed temp file and return its path."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as handle:
            path = handle.name
        self.addCleanup(_remove_quietly, path)
        return path

    def _write_jsonl(self, entries):
        """Write *entries* as one JSON object per line; return the file path."""
        path = self._temp_path()
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(entry) + "\n" for entry in entries)
        return path


class TestMakeProvenance(unittest.TestCase):
    """Checks on the dict returned by make_provenance()."""

    def test_creates_valid_provenance(self):
        prov = make_provenance(
            source="curated",
            source_session_id="test-001",
            model="timmy-curated",
        )
        # subTest reports every missing field instead of stopping at the first.
        for field in REQUIRED_FIELDS:
            with self.subTest(field=field):
                self.assertIn(field, prov)
        self.assertEqual(prov["source"], "curated")
        self.assertEqual(prov["source_session_id"], "test-001")
        self.assertEqual(prov["model"], "timmy-curated")

    def test_rejects_invalid_source(self):
        with self.assertRaises(ValueError):
            make_provenance(source="bogus", source_session_id="x", model="y")

    def test_accepts_all_valid_sources(self):
        for source in VALID_SOURCES:
            with self.subTest(source=source):
                prov = make_provenance(source=source, source_session_id="x", model="y")
                self.assertEqual(prov["source"], source)

    def test_includes_extras(self):
        # Keys passed via extras are expected at the top level of the result.
        prov = make_provenance(
            source="curated",
            source_session_id="x",
            model="y",
            extras={"custom_field": "value"},
        )
        self.assertEqual(prov["custom_field"], "value")

    def test_uses_provided_timestamp(self):
        ts = "2026-01-01T00:00:00Z"
        prov = make_provenance(source="curated", source_session_id="x", model="y", timestamp=ts)
        self.assertEqual(prov["timestamp"], ts)

    def test_generates_timestamp_if_not_provided(self):
        # Only the presence of the "T" and "Z" markers is checked, not the
        # full timestamp format.
        prov = make_provenance(source="curated", source_session_id="x", model="y")
        self.assertIn("T", prov["timestamp"])
        self.assertIn("Z", prov["timestamp"])


class TestAttachProvenance(unittest.TestCase):
    """Checks on how attach_provenance() sets the "provenance" key of a pair."""

    def test_attaches_to_pair(self):
        pair = {"id": "test", "conversations": []}
        result = attach_provenance(pair, source="curated", source_session_id="s1", model="m1")
        self.assertIn("provenance", result)
        self.assertEqual(result["provenance"]["source"], "curated")

    def test_does_not_overwrite_existing(self):
        # The pair already carries provenance; the original source must survive.
        pair = {
            "id": "test",
            "provenance": {"source": "original", "source_session_id": "old", "model": "x"},
        }
        result = attach_provenance(pair, source="new", source_session_id="new", model="y")
        self.assertEqual(result["provenance"]["source"], "original")

    def test_overwrites_with_force(self):
        # NOTE(review): the force flag is passed inside `extras` rather than as
        # its own keyword; confirm this matches attach_provenance's contract.
        pair = {
            "id": "test",
            "provenance": {"source": "curated", "source_session_id": "old", "model": "x", "timestamp": "t"},
        }
        result = attach_provenance(
            pair, source="backfill", source_session_id="new", model="y", extras={"force": True}
        )
        self.assertEqual(result["provenance"]["source"], "backfill")


class TestExtractTrajectoryProvenance(unittest.TestCase):
    """Checks on the field mapping done by extract_trajectory_provenance()."""

    def test_extracts_from_trajectory(self):
        entry = {
            "id": "session-123",
            "model": "hermes3:latest",
            "started_at": "2026-04-14T10:00:00",
        }
        prov = extract_trajectory_provenance(entry)
        self.assertEqual(prov["source_session_id"], "session-123")
        self.assertEqual(prov["model"], "hermes3:latest")
        self.assertEqual(prov["timestamp"], "2026-04-14T10:00:00")

    def test_uses_defaults_for_missing(self):
        # An empty entry must still yield all three fields.
        entry = {}
        prov = extract_trajectory_provenance(entry)
        self.assertEqual(prov["source_session_id"], "unknown")
        self.assertEqual(prov["model"], "unknown")
        self.assertIn("T", prov["timestamp"])


class TestPairFingerprint(unittest.TestCase):
    """Checks on which parts of a pair influence pair_fingerprint()."""

    def test_same_content_same_hash(self):
        pair1 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
        pair2 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_different_content_different_hash(self):
        pair1 = {"conversations": [{"from": "human", "value": "hi"}]}
        pair2 = {"conversations": [{"from": "human", "value": "bye"}]}
        self.assertNotEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_ignores_provenance_in_hash(self):
        # The pairs differ only in their provenance block.
        pair1 = {
            "conversations": [{"from": "human", "value": "hi"}],
            "provenance": {"source": "a"},
        }
        pair2 = {
            "conversations": [{"from": "human", "value": "hi"}],
            "provenance": {"source": "b"},
        }
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))

    def test_ignores_system_prompt(self):
        # The pairs differ only in the text of the "system" turn.
        pair1 = {
            "conversations": [
                {"from": "system", "value": "prompt A"},
                {"from": "human", "value": "hi"},
            ]
        }
        pair2 = {
            "conversations": [
                {"from": "system", "value": "prompt B"},
                {"from": "human", "value": "hi"},
            ]
        }
        self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))


class TestValidateProvenance(_TempJsonlMixin, unittest.TestCase):
    """Checks on the report validate_provenance() returns for a JSONL file."""

    def test_all_valid(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "curated", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
            {"id": "2", "provenance": {"source": "trajectory", "source_session_id": "s2", "model": "m2", "timestamp": "2026-01-01T00:00:00Z"}},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_provenance"], 0)
        self.assertEqual(report["missing_fields"], 0)
        self.assertEqual(report["coverage"], 100.0)

    def test_missing_provenance(self):
        path = self._write_jsonl([
            {"id": "1", "conversations": []},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_provenance"], 1)
        self.assertEqual(report["coverage"], 0.0)

    def test_missing_fields(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "curated"}},  # missing session_id, model, timestamp
        ])
        report = validate_provenance(path)
        self.assertEqual(report["missing_fields"], 1)

    def test_invalid_source(self):
        path = self._write_jsonl([
            {"id": "1", "provenance": {"source": "bogus", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
        ])
        report = validate_provenance(path)
        self.assertEqual(report["invalid_source"], 1)


class TestBackfillProvenance(_TempJsonlMixin, unittest.TestCase):
    """Checks on backfill_provenance() stats and the file it writes."""

    def test_backfills_missing(self):
        path = self._write_jsonl([
            {"id": "1", "conversations": []},
            {"id": "2", "conversations": [], "provenance": {"source": "existing", "source_session_id": "s", "model": "m", "timestamp": "t"}},
        ])
        # The output file is created up front (empty and closed) so it has a
        # known path and is cleaned up with the input.
        out = self._temp_path()
        stats = backfill_provenance(path, source="backfill", model="test-model", output_path=out)
        self.assertEqual(stats["backfilled"], 1)
        self.assertEqual(stats["already_had"], 1)

        entries = load_jsonl(out)
        self.assertEqual(entries[0]["provenance"]["source"], "backfill")
        self.assertEqual(entries[1]["provenance"]["source"], "existing")


class TestDashboard(_TempJsonlMixin, unittest.TestCase):
    """Checks on the text returned by provenance_dashboard()."""

    def test_generates_output(self):
        # Five entries: the first three are "curated", the last two "trajectory".
        path = self._write_jsonl([
            {
                "id": str(i),
                "provenance": {
                    "source": "curated" if i < 3 else "trajectory",
                    "source_session_id": f"s{i}",
                    "model": "timmy-curated" if i < 3 else "hermes3",
                    "timestamp": f"2026-04-{14+i:02d}T00:00:00Z",
                },
            }
            for i in range(5)
        ])
        output = provenance_dashboard(path)
        self.assertIn("PROVENANCE DASHBOARD", output)
        self.assertIn("timmy-curated", output)
        self.assertIn("hermes3", output)
        self.assertIn("curated", output)


if __name__ == "__main__":
    unittest.main()