#!/usr/bin/env python3
"""
Tests for training_pair_provenance module.

Each test is a plain function built on bare ``assert`` statements, so the
file can be collected by pytest or executed directly as a script (see
``main``).  Every test prints a check-mark line once its assertions pass.
"""

import json
import sys
from pathlib import Path

# This file lives in training/tests/, so ``_TESTS_PARENT`` is training/.
_TESTS_PARENT = Path(__file__).parent.parent

# The originally configured location (<parent>/training) is inserted last so
# it keeps the highest priority on sys.path.  The parent directory itself is
# added as a fallback.
# NOTE(review): presumably training_pair_provenance.py sits directly in
# training/ rather than training/training/ - confirm against the repo layout
# and drop whichever entry turns out to be unnecessary.
for _candidate in (_TESTS_PARENT, _TESTS_PARENT / "training"):
    sys.path.insert(0, str(_candidate))

from training_pair_provenance import (
    add_provenance,
    validate_provenance,
    get_provenance_stats,
    extract_provenance_from_trajectory,
    ProvenanceMetadata,
)


def test_add_provenance() -> None:
    """Check that add_provenance attaches a populated ``provenance`` dict.

    The returned pair must carry the session id, source type and model that
    were passed in, plus a ``timestamp`` entry.
    """
    pair = {
        "id": "test_001",
        "conversations": [
            {"from": "human", "value": "Hello"},
            {"from": "gpt", "value": "Hi"},
        ],
    }

    result = add_provenance(
        pair,
        source_session_id="session_123",
        source_type="trajectory",
        model="hermes-4.3-70b",
    )

    assert "provenance" in result
    provenance = result["provenance"]
    assert provenance["source_session_id"] == "session_123"
    assert provenance["source_type"] == "trajectory"
    assert provenance["model"] == "hermes-4.3-70b"
    # Only the presence of the timestamp is checked, not its value.
    assert "timestamp" in provenance
    print("✓ test_add_provenance")


def test_validate_provenance_valid() -> None:
    """Check that a fully populated provenance record validates cleanly."""
    pair = {
        "id": "test",
        "provenance": {
            "source_session_id": "session_1",
            "source_type": "trajectory",
            "model": "hermes-4.3",
            "timestamp": "2026-04-15T12:00:00",
        },
    }

    valid, errors = validate_provenance(pair)
    assert valid, f"Expected valid, got errors: {errors}"
    assert len(errors) == 0
    print("✓ test_validate_provenance_valid")


def test_validate_provenance_missing() -> None:
    """Check that a pair without a ``provenance`` key is rejected."""
    pair = {"id": "test"}

    valid, errors = validate_provenance(pair)
    assert not valid
    # The exact message must appear as an element of the returned errors.
    assert "Missing provenance metadata" in errors
    print("✓ test_validate_provenance_missing")


def test_validate_provenance_invalid_type() -> None:
    """Check that an unrecognised ``source_type`` value is rejected."""
    pair = {
        "id": "test",
        "provenance": {
            "source_session_id": "session_1",
            "source_type": "invalid",
            "model": "hermes-4.3",
            "timestamp": "2026-04-15T12:00:00",
        },
    }

    valid, errors = validate_provenance(pair)
    assert not valid
    # Substring match: the error may carry extra detail around this phrase.
    assert any("Invalid source_type" in e for e in errors)
    print("✓ test_validate_provenance_invalid_type")


def test_get_provenance_stats() -> None:
    """Check the aggregate statistics computed over a mixed list of pairs.

    Four pairs are supplied: three with provenance (two ``trajectory``, one
    ``curated``; one of them flagged ``excluded``) and one with none.
    """
    pairs = [
        {"provenance": {"source_type": "trajectory", "model": "hermes-4.3"}},
        {"provenance": {"source_type": "curated", "model": "timmy-curated"}},
        {"provenance": {"source_type": "trajectory", "model": "hermes-4.3", "excluded": True}},
        {},  # No provenance
    ]

    stats = get_provenance_stats(pairs)

    assert stats["total_pairs"] == 4
    assert stats["with_provenance"] == 3
    assert stats["coverage_pct"] == 75.0  # 3 of 4 pairs
    assert stats["by_source_type"]["trajectory"] == 2
    assert stats["by_source_type"]["curated"] == 1
    assert stats["by_model"]["hermes-4.3"] == 2
    assert stats["excluded"] == 1
    print("✓ test_get_provenance_stats")


def test_extract_provenance_from_trajectory() -> None:
    """Check the mapping from trajectory fields to provenance fields.

    ``id`` is expected under ``source_session_id``, ``started_at`` under
    ``timestamp``, ``model`` is carried over, and ``source_type`` is
    ``"trajectory"``.
    """
    trajectory = {
        "id": "traj_001",
        "model": "nexus-consciousness",
        "started_at": "2026-04-15T10:00:00",
    }

    result = extract_provenance_from_trajectory(trajectory)

    assert result["source_session_id"] == "traj_001"
    assert result["source_type"] == "trajectory"
    assert result["model"] == "nexus-consciousness"
    assert result["timestamp"] == "2026-04-15T10:00:00"
    print("✓ test_extract_provenance_from_trajectory")


def test_provenance_metadata_dataclass() -> None:
    """Check ProvenanceMetadata construction and its ``to_dict`` output."""
    meta = ProvenanceMetadata(
        source_session_id="session_1",
        source_type="curated",
        model="timmy-curated",
        timestamp="2026-04-15T12:00:00",
    )

    d = meta.to_dict()
    assert d["source_session_id"] == "session_1"
    assert d["source_type"] == "curated"
    # ``excluded`` was not set here; a False value is expected to be left
    # out of the serialised dict.
    assert "excluded" not in d
    print("✓ test_provenance_metadata_dataclass")


def main() -> None:
    """Run every test in definition order, printing progress as it goes.

    Used when the file is executed as a script; the first failing assertion
    propagates as an AssertionError and stops the run.
    """
    tests = (
        test_add_provenance,
        test_validate_provenance_valid,
        test_validate_provenance_missing,
        test_validate_provenance_invalid_type,
        test_get_provenance_stats,
        test_extract_provenance_from_trajectory,
        test_provenance_metadata_dataclass,
    )

    print("Running provenance tests...")
    print()

    for run_test in tests:
        run_test()

    print()
    print("All tests passed!")


if __name__ == "__main__":
    main()