#!/usr/bin/env python3
"""Tests for training_pair_provenance module."""

import json
import sys
from pathlib import Path

# Make the sibling "training" directory importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent / "training"))

from training_pair_provenance import (
    ProvenanceMetadata,
    add_provenance,
    extract_provenance_from_trajectory,
    get_provenance_stats,
    validate_provenance,
)
def test_add_provenance():
    """Test adding provenance to a pair."""
    sample = {
        "id": "test_001",
        "conversations": [
            {"from": "human", "value": "Hello"},
            {"from": "gpt", "value": "Hi"},
        ],
    }

    enriched = add_provenance(
        sample,
        source_session_id="session_123",
        source_type="trajectory",
        model="hermes-4.3-70b",
    )

    # The pair should gain a provenance record carrying the call's metadata.
    assert "provenance" in enriched
    prov = enriched["provenance"]
    assert prov["source_session_id"] == "session_123"
    assert prov["source_type"] == "trajectory"
    assert prov["model"] == "hermes-4.3-70b"
    assert "timestamp" in prov
    print("✓ test_add_provenance")
def test_validate_provenance_valid():
    """Test validation of valid provenance."""
    record = {
        "id": "test",
        "provenance": {
            "source_session_id": "session_1",
            "source_type": "trajectory",
            "model": "hermes-4.3",
            "timestamp": "2026-04-15T12:00:00",
        },
    }

    ok, problems = validate_provenance(record)

    # A fully-populated provenance block should pass with no error messages.
    assert ok, f"Expected valid, got errors: {problems}"
    assert len(problems) == 0
    print("✓ test_validate_provenance_valid")
def test_validate_provenance_missing():
    """Test validation fails on missing provenance."""
    bare_record = {"id": "test"}

    ok, problems = validate_provenance(bare_record)

    # Missing provenance must be rejected with the canonical error message.
    assert not ok
    assert "Missing provenance metadata" in problems
    print("✓ test_validate_provenance_missing")
def test_validate_provenance_invalid_type():
    """Test validation fails on invalid source_type."""
    record = {
        "id": "test",
        "provenance": {
            "source_session_id": "session_1",
            "source_type": "invalid",
            "model": "hermes-4.3",
            "timestamp": "2026-04-15T12:00:00",
        },
    }

    ok, problems = validate_provenance(record)

    assert not ok
    # At least one error message should flag the bad source_type value.
    assert any("Invalid source_type" in message for message in problems)
    print("✓ test_validate_provenance_invalid_type")
def test_get_provenance_stats():
    """Test statistics computation."""
    dataset = [
        {"provenance": {"source_type": "trajectory", "model": "hermes-4.3"}},
        {"provenance": {"source_type": "curated", "model": "timmy-curated"}},
        {"provenance": {"source_type": "trajectory", "model": "hermes-4.3", "excluded": True}},
        {},  # No provenance
    ]

    summary = get_provenance_stats(dataset)

    # 3 of 4 pairs carry provenance -> 75% coverage.
    assert summary["total_pairs"] == 4
    assert summary["with_provenance"] == 3
    assert summary["coverage_pct"] == 75.0
    by_type = summary["by_source_type"]
    assert by_type["trajectory"] == 2
    assert by_type["curated"] == 1
    assert summary["by_model"]["hermes-4.3"] == 2
    # Exactly one pair is marked excluded.
    assert summary["excluded"] == 1
    print("✓ test_get_provenance_stats")
def test_extract_provenance_from_trajectory():
    """Test extracting provenance from trajectory data."""
    trajectory = {
        "id": "traj_001",
        "model": "nexus-consciousness",
        "started_at": "2026-04-15T10:00:00",
    }

    prov = extract_provenance_from_trajectory(trajectory)

    # Trajectory fields map 1:1 onto the provenance record.
    expected = {
        "source_session_id": "traj_001",
        "source_type": "trajectory",
        "model": "nexus-consciousness",
        "timestamp": "2026-04-15T10:00:00",
    }
    for key, value in expected.items():
        assert prov[key] == value
    print("✓ test_extract_provenance_from_trajectory")
def test_provenance_metadata_dataclass():
    """Test ProvenanceMetadata dataclass."""
    meta = ProvenanceMetadata(
        source_session_id="session_1",
        source_type="curated",
        model="timmy-curated",
        timestamp="2026-04-15T12:00:00",
    )

    payload = meta.to_dict()

    assert payload["source_session_id"] == "session_1"
    assert payload["source_type"] == "curated"
    # A falsy `excluded` flag is omitted from the serialized dict entirely.
    assert "excluded" not in payload
    print("✓ test_provenance_metadata_dataclass")
if __name__ == "__main__":
    print("Running provenance tests...")
    print()

    # Run every test in order; each prints its own ✓ line and any
    # failure raises AssertionError, aborting before the success banner.
    for case in (
        test_add_provenance,
        test_validate_provenance_valid,
        test_validate_provenance_missing,
        test_validate_provenance_invalid_type,
        test_get_provenance_stats,
        test_extract_provenance_from_trajectory,
        test_provenance_metadata_dataclass,
    ):
        case()

    print()
    print("All tests passed!")