test: provenance tracking tests — 21 cases (#752)

This commit is contained in:
2026-04-15 22:43:29 +00:00
parent 3b6ff9038e
commit b5f480da47

View File

@@ -0,0 +1,240 @@
#!/usr/bin/env python3
"""Tests for training_pair_provenance.py"""
import json
import tempfile
import unittest
from pathlib import Path
from datetime import datetime, timezone
# Adjust import path
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent))
from training_pair_provenance import (
make_provenance,
attach_provenance,
extract_trajectory_provenance,
pair_fingerprint,
validate_provenance,
provenance_dashboard,
backfill_provenance,
load_jsonl,
save_jsonl,
VALID_SOURCES,
REQUIRED_FIELDS,
)
class TestMakeProvenance(unittest.TestCase):
def test_creates_valid_provenance(self):
prov = make_provenance(
source="curated",
source_session_id="test-001",
model="timmy-curated",
)
for field in REQUIRED_FIELDS:
self.assertIn(field, prov)
self.assertEqual(prov["source"], "curated")
self.assertEqual(prov["source_session_id"], "test-001")
self.assertEqual(prov["model"], "timmy-curated")
def test_rejects_invalid_source(self):
with self.assertRaises(ValueError):
make_provenance(source="bogus", source_session_id="x", model="y")
def test_accepts_all_valid_sources(self):
for source in VALID_SOURCES:
prov = make_provenance(source=source, source_session_id="x", model="y")
self.assertEqual(prov["source"], source)
def test_includes_extras(self):
prov = make_provenance(
source="curated",
source_session_id="x",
model="y",
extras={"custom_field": "value"},
)
self.assertEqual(prov["custom_field"], "value")
def test_uses_provided_timestamp(self):
ts = "2026-01-01T00:00:00Z"
prov = make_provenance(source="curated", source_session_id="x", model="y", timestamp=ts)
self.assertEqual(prov["timestamp"], ts)
def test_generates_timestamp_if_not_provided(self):
prov = make_provenance(source="curated", source_session_id="x", model="y")
self.assertIn("T", prov["timestamp"])
self.assertIn("Z", prov["timestamp"])
class TestAttachProvenance(unittest.TestCase):
def test_attaches_to_pair(self):
pair = {"id": "test", "conversations": []}
result = attach_provenance(pair, source="curated", source_session_id="s1", model="m1")
self.assertIn("provenance", result)
self.assertEqual(result["provenance"]["source"], "curated")
def test_does_not_overwrite_existing(self):
pair = {
"id": "test",
"provenance": {"source": "original", "source_session_id": "old", "model": "x"},
}
result = attach_provenance(pair, source="new", source_session_id="new", model="y")
self.assertEqual(result["provenance"]["source"], "original")
def test_overwrites_with_force(self):
pair = {
"id": "test",
"provenance": {"source": "curated", "source_session_id": "old", "model": "x", "timestamp": "t"},
}
result = attach_provenance(
pair, source="backfill", source_session_id="new", model="y", extras={"force": True}
)
self.assertEqual(result["provenance"]["source"], "backfill")
class TestExtractTrajectoryProvenance(unittest.TestCase):
def test_extracts_from_trajectory(self):
entry = {
"id": "session-123",
"model": "hermes3:latest",
"started_at": "2026-04-14T10:00:00",
}
prov = extract_trajectory_provenance(entry)
self.assertEqual(prov["source_session_id"], "session-123")
self.assertEqual(prov["model"], "hermes3:latest")
self.assertEqual(prov["timestamp"], "2026-04-14T10:00:00")
def test_uses_defaults_for_missing(self):
entry = {}
prov = extract_trajectory_provenance(entry)
self.assertEqual(prov["source_session_id"], "unknown")
self.assertEqual(prov["model"], "unknown")
self.assertIn("T", prov["timestamp"])
class TestPairFingerprint(unittest.TestCase):
def test_same_content_same_hash(self):
pair1 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
pair2 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
def test_different_content_different_hash(self):
pair1 = {"conversations": [{"from": "human", "value": "hi"}]}
pair2 = {"conversations": [{"from": "human", "value": "bye"}]}
self.assertNotEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
def test_ignores_provenance_in_hash(self):
pair1 = {
"conversations": [{"from": "human", "value": "hi"}],
"provenance": {"source": "a"},
}
pair2 = {
"conversations": [{"from": "human", "value": "hi"}],
"provenance": {"source": "b"},
}
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
def test_ignores_system_prompt(self):
pair1 = {
"conversations": [
{"from": "system", "value": "prompt A"},
{"from": "human", "value": "hi"},
]
}
pair2 = {
"conversations": [
{"from": "system", "value": "prompt B"},
{"from": "human", "value": "hi"},
]
}
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
class TestValidateProvenance(unittest.TestCase):
def _write_jsonl(self, entries):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
for entry in entries:
f.write(json.dumps(entry) + "\n")
f.close()
return f.name
def test_all_valid(self):
path = self._write_jsonl([
{"id": "1", "provenance": {"source": "curated", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
{"id": "2", "provenance": {"source": "trajectory", "source_session_id": "s2", "model": "m2", "timestamp": "2026-01-01T00:00:00Z"}},
])
report = validate_provenance(path)
self.assertEqual(report["missing_provenance"], 0)
self.assertEqual(report["missing_fields"], 0)
self.assertEqual(report["coverage"], 100.0)
def test_missing_provenance(self):
path = self._write_jsonl([
{"id": "1", "conversations": []},
])
report = validate_provenance(path)
self.assertEqual(report["missing_provenance"], 1)
self.assertEqual(report["coverage"], 0.0)
def test_missing_fields(self):
path = self._write_jsonl([
{"id": "1", "provenance": {"source": "curated"}}, # missing session_id, model, timestamp
])
report = validate_provenance(path)
self.assertEqual(report["missing_fields"], 1)
def test_invalid_source(self):
path = self._write_jsonl([
{"id": "1", "provenance": {"source": "bogus", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
])
report = validate_provenance(path)
self.assertEqual(report["invalid_source"], 1)
class TestBackfillProvenance(unittest.TestCase):
def _write_jsonl(self, entries):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
for entry in entries:
f.write(json.dumps(entry) + "\n")
f.close()
return f.name
def test_backfills_missing(self):
path = self._write_jsonl([
{"id": "1", "conversations": []},
{"id": "2", "conversations": [], "provenance": {"source": "existing", "source_session_id": "s", "model": "m", "timestamp": "t"}},
])
out = tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False).name
stats = backfill_provenance(path, source="backfill", model="test-model", output_path=out)
self.assertEqual(stats["backfilled"], 1)
self.assertEqual(stats["already_had"], 1)
entries = load_jsonl(out)
self.assertEqual(entries[0]["provenance"]["source"], "backfill")
self.assertEqual(entries[1]["provenance"]["source"], "existing")
class TestDashboard(unittest.TestCase):
def test_generates_output(self):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
for i in range(5):
f.write(json.dumps({
"id": str(i),
"provenance": {
"source": "curated" if i < 3 else "trajectory",
"source_session_id": f"s{i}",
"model": "timmy-curated" if i < 3 else "hermes3",
"timestamp": f"2026-04-{14+i:02d}T00:00:00Z",
},
}) + "\n")
f.close()
output = provenance_dashboard(f.name)
self.assertIn("PROVENANCE DASHBOARD", output)
self.assertIn("timmy-curated", output)
self.assertIn("hermes3", output)
self.assertIn("curated", output)
if __name__ == "__main__":
unittest.main()