Merge PR #760: training/tests/test_provenance.py (added)
This commit is contained in:
@@ -1,158 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for training_pair_provenance module.
|
||||
"""
|
||||
"""Tests for training_pair_provenance.py"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "training"))
|
||||
# Adjust import path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
from training_pair_provenance import (
|
||||
add_provenance,
|
||||
make_provenance,
|
||||
attach_provenance,
|
||||
extract_trajectory_provenance,
|
||||
pair_fingerprint,
|
||||
validate_provenance,
|
||||
get_provenance_stats,
|
||||
extract_provenance_from_trajectory,
|
||||
ProvenanceMetadata
|
||||
provenance_dashboard,
|
||||
backfill_provenance,
|
||||
load_jsonl,
|
||||
save_jsonl,
|
||||
VALID_SOURCES,
|
||||
REQUIRED_FIELDS,
|
||||
)
|
||||
|
||||
|
||||
def test_add_provenance():
|
||||
"""Test adding provenance to a pair."""
|
||||
pair = {
|
||||
"id": "test_001",
|
||||
"conversations": [
|
||||
{"from": "human", "value": "Hello"},
|
||||
{"from": "gpt", "value": "Hi"}
|
||||
]
|
||||
}
|
||||
class TestMakeProvenance(unittest.TestCase):
|
||||
def test_creates_valid_provenance(self):
|
||||
prov = make_provenance(
|
||||
source="curated",
|
||||
source_session_id="test-001",
|
||||
model="timmy-curated",
|
||||
)
|
||||
for field in REQUIRED_FIELDS:
|
||||
self.assertIn(field, prov)
|
||||
self.assertEqual(prov["source"], "curated")
|
||||
self.assertEqual(prov["source_session_id"], "test-001")
|
||||
self.assertEqual(prov["model"], "timmy-curated")
|
||||
|
||||
result = add_provenance(
|
||||
pair,
|
||||
source_session_id="session_123",
|
||||
source_type="trajectory",
|
||||
model="hermes-4.3-70b"
|
||||
)
|
||||
def test_rejects_invalid_source(self):
|
||||
with self.assertRaises(ValueError):
|
||||
make_provenance(source="bogus", source_session_id="x", model="y")
|
||||
|
||||
assert "provenance" in result
|
||||
assert result["provenance"]["source_session_id"] == "session_123"
|
||||
assert result["provenance"]["source_type"] == "trajectory"
|
||||
assert result["provenance"]["model"] == "hermes-4.3-70b"
|
||||
assert "timestamp" in result["provenance"]
|
||||
print("✓ test_add_provenance")
|
||||
def test_accepts_all_valid_sources(self):
|
||||
for source in VALID_SOURCES:
|
||||
prov = make_provenance(source=source, source_session_id="x", model="y")
|
||||
self.assertEqual(prov["source"], source)
|
||||
|
||||
def test_includes_extras(self):
|
||||
prov = make_provenance(
|
||||
source="curated",
|
||||
source_session_id="x",
|
||||
model="y",
|
||||
extras={"custom_field": "value"},
|
||||
)
|
||||
self.assertEqual(prov["custom_field"], "value")
|
||||
|
||||
def test_uses_provided_timestamp(self):
|
||||
ts = "2026-01-01T00:00:00Z"
|
||||
prov = make_provenance(source="curated", source_session_id="x", model="y", timestamp=ts)
|
||||
self.assertEqual(prov["timestamp"], ts)
|
||||
|
||||
def test_generates_timestamp_if_not_provided(self):
|
||||
prov = make_provenance(source="curated", source_session_id="x", model="y")
|
||||
self.assertIn("T", prov["timestamp"])
|
||||
self.assertIn("Z", prov["timestamp"])
|
||||
|
||||
|
||||
def test_validate_provenance_valid():
|
||||
"""Test validation of valid provenance."""
|
||||
pair = {
|
||||
"id": "test",
|
||||
"provenance": {
|
||||
"source_session_id": "session_1",
|
||||
"source_type": "trajectory",
|
||||
"model": "hermes-4.3",
|
||||
"timestamp": "2026-04-15T12:00:00"
|
||||
class TestAttachProvenance(unittest.TestCase):
|
||||
def test_attaches_to_pair(self):
|
||||
pair = {"id": "test", "conversations": []}
|
||||
result = attach_provenance(pair, source="curated", source_session_id="s1", model="m1")
|
||||
self.assertIn("provenance", result)
|
||||
self.assertEqual(result["provenance"]["source"], "curated")
|
||||
|
||||
def test_does_not_overwrite_existing(self):
|
||||
pair = {
|
||||
"id": "test",
|
||||
"provenance": {"source": "original", "source_session_id": "old", "model": "x"},
|
||||
}
|
||||
}
|
||||
result = attach_provenance(pair, source="new", source_session_id="new", model="y")
|
||||
self.assertEqual(result["provenance"]["source"], "original")
|
||||
|
||||
valid, errors = validate_provenance(pair)
|
||||
assert valid, f"Expected valid, got errors: {errors}"
|
||||
assert len(errors) == 0
|
||||
print("✓ test_validate_provenance_valid")
|
||||
|
||||
|
||||
def test_validate_provenance_missing():
|
||||
"""Test validation fails on missing provenance."""
|
||||
pair = {"id": "test"}
|
||||
|
||||
valid, errors = validate_provenance(pair)
|
||||
assert not valid
|
||||
assert "Missing provenance metadata" in errors
|
||||
print("✓ test_validate_provenance_missing")
|
||||
|
||||
|
||||
def test_validate_provenance_invalid_type():
|
||||
"""Test validation fails on invalid source_type."""
|
||||
pair = {
|
||||
"id": "test",
|
||||
"provenance": {
|
||||
"source_session_id": "session_1",
|
||||
"source_type": "invalid",
|
||||
"model": "hermes-4.3",
|
||||
"timestamp": "2026-04-15T12:00:00"
|
||||
def test_overwrites_with_force(self):
|
||||
pair = {
|
||||
"id": "test",
|
||||
"provenance": {"source": "curated", "source_session_id": "old", "model": "x", "timestamp": "t"},
|
||||
}
|
||||
}
|
||||
|
||||
valid, errors = validate_provenance(pair)
|
||||
assert not valid
|
||||
assert any("Invalid source_type" in e for e in errors)
|
||||
print("✓ test_validate_provenance_invalid_type")
|
||||
result = attach_provenance(
|
||||
pair, source="backfill", source_session_id="new", model="y", extras={"force": True}
|
||||
)
|
||||
self.assertEqual(result["provenance"]["source"], "backfill")
|
||||
|
||||
|
||||
def test_get_provenance_stats():
|
||||
"""Test statistics computation."""
|
||||
pairs = [
|
||||
{"provenance": {"source_type": "trajectory", "model": "hermes-4.3"}},
|
||||
{"provenance": {"source_type": "curated", "model": "timmy-curated"}},
|
||||
{"provenance": {"source_type": "trajectory", "model": "hermes-4.3", "excluded": True}},
|
||||
{}, # No provenance
|
||||
]
|
||||
class TestExtractTrajectoryProvenance(unittest.TestCase):
|
||||
def test_extracts_from_trajectory(self):
|
||||
entry = {
|
||||
"id": "session-123",
|
||||
"model": "hermes3:latest",
|
||||
"started_at": "2026-04-14T10:00:00",
|
||||
}
|
||||
prov = extract_trajectory_provenance(entry)
|
||||
self.assertEqual(prov["source_session_id"], "session-123")
|
||||
self.assertEqual(prov["model"], "hermes3:latest")
|
||||
self.assertEqual(prov["timestamp"], "2026-04-14T10:00:00")
|
||||
|
||||
stats = get_provenance_stats(pairs)
|
||||
|
||||
assert stats["total_pairs"] == 4
|
||||
assert stats["with_provenance"] == 3
|
||||
assert stats["coverage_pct"] == 75.0
|
||||
assert stats["by_source_type"]["trajectory"] == 2
|
||||
assert stats["by_source_type"]["curated"] == 1
|
||||
assert stats["by_model"]["hermes-4.3"] == 2
|
||||
assert stats["excluded"] == 1
|
||||
print("✓ test_get_provenance_stats")
|
||||
def test_uses_defaults_for_missing(self):
|
||||
entry = {}
|
||||
prov = extract_trajectory_provenance(entry)
|
||||
self.assertEqual(prov["source_session_id"], "unknown")
|
||||
self.assertEqual(prov["model"], "unknown")
|
||||
self.assertIn("T", prov["timestamp"])
|
||||
|
||||
|
||||
def test_extract_provenance_from_trajectory():
|
||||
"""Test extracting provenance from trajectory data."""
|
||||
trajectory = {
|
||||
"id": "traj_001",
|
||||
"model": "nexus-consciousness",
|
||||
"started_at": "2026-04-15T10:00:00"
|
||||
}
|
||||
class TestPairFingerprint(unittest.TestCase):
|
||||
def test_same_content_same_hash(self):
|
||||
pair1 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
|
||||
pair2 = {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}
|
||||
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
|
||||
|
||||
result = extract_provenance_from_trajectory(trajectory)
|
||||
def test_different_content_different_hash(self):
|
||||
pair1 = {"conversations": [{"from": "human", "value": "hi"}]}
|
||||
pair2 = {"conversations": [{"from": "human", "value": "bye"}]}
|
||||
self.assertNotEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
|
||||
|
||||
assert result["source_session_id"] == "traj_001"
|
||||
assert result["source_type"] == "trajectory"
|
||||
assert result["model"] == "nexus-consciousness"
|
||||
assert result["timestamp"] == "2026-04-15T10:00:00"
|
||||
print("✓ test_extract_provenance_from_trajectory")
|
||||
def test_ignores_provenance_in_hash(self):
|
||||
pair1 = {
|
||||
"conversations": [{"from": "human", "value": "hi"}],
|
||||
"provenance": {"source": "a"},
|
||||
}
|
||||
pair2 = {
|
||||
"conversations": [{"from": "human", "value": "hi"}],
|
||||
"provenance": {"source": "b"},
|
||||
}
|
||||
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
|
||||
|
||||
def test_ignores_system_prompt(self):
|
||||
pair1 = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "prompt A"},
|
||||
{"from": "human", "value": "hi"},
|
||||
]
|
||||
}
|
||||
pair2 = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "prompt B"},
|
||||
{"from": "human", "value": "hi"},
|
||||
]
|
||||
}
|
||||
self.assertEqual(pair_fingerprint(pair1), pair_fingerprint(pair2))
|
||||
|
||||
|
||||
def test_provenance_metadata_dataclass():
|
||||
"""Test ProvenanceMetadata dataclass."""
|
||||
meta = ProvenanceMetadata(
|
||||
source_session_id="session_1",
|
||||
source_type="curated",
|
||||
model="timmy-curated",
|
||||
timestamp="2026-04-15T12:00:00"
|
||||
)
|
||||
class TestValidateProvenance(unittest.TestCase):
|
||||
def _write_jsonl(self, entries):
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
|
||||
for entry in entries:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
f.close()
|
||||
return f.name
|
||||
|
||||
d = meta.to_dict()
|
||||
assert d["source_session_id"] == "session_1"
|
||||
assert d["source_type"] == "curated"
|
||||
assert "excluded" not in d # False is excluded
|
||||
print("✓ test_provenance_metadata_dataclass")
|
||||
def test_all_valid(self):
|
||||
path = self._write_jsonl([
|
||||
{"id": "1", "provenance": {"source": "curated", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
|
||||
{"id": "2", "provenance": {"source": "trajectory", "source_session_id": "s2", "model": "m2", "timestamp": "2026-01-01T00:00:00Z"}},
|
||||
])
|
||||
report = validate_provenance(path)
|
||||
self.assertEqual(report["missing_provenance"], 0)
|
||||
self.assertEqual(report["missing_fields"], 0)
|
||||
self.assertEqual(report["coverage"], 100.0)
|
||||
|
||||
def test_missing_provenance(self):
|
||||
path = self._write_jsonl([
|
||||
{"id": "1", "conversations": []},
|
||||
])
|
||||
report = validate_provenance(path)
|
||||
self.assertEqual(report["missing_provenance"], 1)
|
||||
self.assertEqual(report["coverage"], 0.0)
|
||||
|
||||
def test_missing_fields(self):
|
||||
path = self._write_jsonl([
|
||||
{"id": "1", "provenance": {"source": "curated"}}, # missing session_id, model, timestamp
|
||||
])
|
||||
report = validate_provenance(path)
|
||||
self.assertEqual(report["missing_fields"], 1)
|
||||
|
||||
def test_invalid_source(self):
|
||||
path = self._write_jsonl([
|
||||
{"id": "1", "provenance": {"source": "bogus", "source_session_id": "s1", "model": "m1", "timestamp": "2026-01-01T00:00:00Z"}},
|
||||
])
|
||||
report = validate_provenance(path)
|
||||
self.assertEqual(report["invalid_source"], 1)
|
||||
|
||||
|
||||
class TestBackfillProvenance(unittest.TestCase):
    """Exercise ``backfill_provenance`` against small on-disk JSONL fixtures."""

    def _write_jsonl(self, entries):
        """Write *entries* to a temporary ``.jsonl`` file and return its path.

        Each entry is serialised with ``json.dumps`` on its own line. The file
        is created with ``delete=False`` so it can be reopened by name, and it
        is registered for removal once the running test finishes.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            # Register cleanup first so the file is removed even if a write fails.
            self.addCleanup(Path(f.name).unlink, missing_ok=True)
            f.writelines(json.dumps(entry) + "\n" for entry in entries)
        return f.name

    def test_backfills_missing(self):
        """Only the entry without provenance is backfilled; existing provenance is kept."""
        path = self._write_jsonl([
            {"id": "1", "conversations": []},
            {"id": "2", "conversations": [], "provenance": {"source": "existing", "source_session_id": "s", "model": "m", "timestamp": "t"}},
        ])

        # Reserve an output path, closing the handle straight away so the code
        # under test can reopen the file by name.
        with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False) as out_file:
            out = out_file.name
        self.addCleanup(Path(out).unlink, missing_ok=True)

        stats = backfill_provenance(path, source="backfill", model="test-model", output_path=out)
        self.assertEqual(stats["backfilled"], 1)
        self.assertEqual(stats["already_had"], 1)

        entries = load_jsonl(out)
        self.assertEqual(entries[0]["provenance"]["source"], "backfill")
        self.assertEqual(entries[1]["provenance"]["source"], "existing")
|
||||
|
||||
|
||||
class TestDashboard(unittest.TestCase):
    """Smoke test for the text produced by ``provenance_dashboard``."""

    def test_generates_output(self):
        """The dashboard text contains its header, both model names and the source label."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            # Remove the fixture after the test; it is created with delete=False
            # so the dashboard can reopen it by name.
            self.addCleanup(Path(f.name).unlink, missing_ok=True)
            for i in range(5):
                # Entries 0-2 are curated; entries 3-4 come from trajectories.
                f.write(json.dumps({
                    "id": str(i),
                    "provenance": {
                        "source": "curated" if i < 3 else "trajectory",
                        "source_session_id": f"s{i}",
                        "model": "timmy-curated" if i < 3 else "hermes3",
                        "timestamp": f"2026-04-{14+i:02d}T00:00:00Z",
                    },
                }) + "\n")

        output = provenance_dashboard(f.name)
        self.assertIn("PROVENANCE DASHBOARD", output)
        self.assertIn("timmy-curated", output)
        self.assertIn("hermes3", output)
        # NOTE: "curated" is a substring of "timmy-curated", so this check is
        # already implied by the model-name assertion above.
        self.assertIn("curated", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running provenance tests...")
|
||||
print()
|
||||
|
||||
test_add_provenance()
|
||||
test_validate_provenance_valid()
|
||||
test_validate_provenance_missing()
|
||||
test_validate_provenance_invalid_type()
|
||||
test_get_provenance_stats()
|
||||
test_extract_provenance_from_trajectory()
|
||||
test_provenance_metadata_dataclass()
|
||||
|
||||
print()
|
||||
print("All tests passed!")
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user