Merge PR #751: training/test_training_pair_provenance.py (added)
This commit is contained in:
157
training/test_training_pair_provenance.py
Normal file
157
training/test_training_pair_provenance.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for Training Pair Provenance Tracking
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from training_pair_provenance import ProvenanceTracker, load_jsonl, save_jsonl
|
||||
|
||||
|
||||
class TestProvenanceTracker:
|
||||
"""Test the ProvenanceTracker class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test tracker initialization."""
|
||||
tracker = ProvenanceTracker()
|
||||
assert tracker.stats["total_pairs"] == 0
|
||||
assert tracker.stats["pairs_with_provenance"] == 0
|
||||
assert tracker.stats["pairs_without_provenance"] == 0
|
||||
|
||||
def test_generate_pair_id(self):
|
||||
"""Test pair ID generation."""
|
||||
tracker = ProvenanceTracker()
|
||||
pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
|
||||
|
||||
id1 = tracker.generate_pair_id(pair)
|
||||
id2 = tracker.generate_pair_id(pair)
|
||||
|
||||
# Same content should generate same ID
|
||||
assert id1 == id2
|
||||
assert len(id1) == 16
|
||||
|
||||
def test_add_provenance(self):
|
||||
"""Test adding provenance to a pair."""
|
||||
tracker = ProvenanceTracker()
|
||||
pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
|
||||
|
||||
result = tracker.add_provenance(pair, source_session_id="session123", model="test-model")
|
||||
|
||||
assert "provenance" in result
|
||||
assert result["provenance"]["source_session_id"] == "session123"
|
||||
assert result["provenance"]["model"] == "test-model"
|
||||
assert "timestamp" in result["provenance"]
|
||||
assert result["provenance"]["source"] == "curated"
|
||||
assert "content_hash" in result["provenance"]
|
||||
|
||||
def test_extract_provenance_from_existing(self):
|
||||
"""Test extracting provenance from existing fields."""
|
||||
tracker = ProvenanceTracker()
|
||||
pair = {
|
||||
"id": "session456",
|
||||
"model": "claude-3-opus",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"conversations": [{"from": "human", "value": "test"}]
|
||||
}
|
||||
|
||||
provenance = tracker.extract_provenance_from_existing(pair)
|
||||
|
||||
assert provenance["source_session_id"] == "session456"
|
||||
assert provenance["model"] == "claude-3-opus"
|
||||
assert provenance["timestamp"] == "2024-01-01T00:00:00Z"
|
||||
assert provenance["source"] == "curated"
|
||||
assert "content_hash" in provenance
|
||||
|
||||
def test_process_pair(self):
|
||||
"""Test processing a pair."""
|
||||
tracker = ProvenanceTracker()
|
||||
pair = {"id": "test123", "model": "test-model", "conversations": []}
|
||||
|
||||
result = tracker.process_pair(pair)
|
||||
|
||||
assert tracker.stats["total_pairs"] == 1
|
||||
assert tracker.stats["pairs_without_provenance"] == 1
|
||||
assert "provenance" in result
|
||||
|
||||
def test_filter_by_provenance(self):
|
||||
"""Test filtering pairs by provenance."""
|
||||
tracker = ProvenanceTracker()
|
||||
|
||||
pairs = [
|
||||
{"provenance": {"model": "anthropic/claude-3-opus"}},
|
||||
{"provenance": {"model": "gpt-4"}},
|
||||
{"provenance": {"model": "anthropic/claude-3-sonnet"}},
|
||||
]
|
||||
|
||||
filtered = tracker.filter_by_provenance(pairs, exclude_models=["anthropic/claude-3-opus", "anthropic/claude-3-sonnet"])
|
||||
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["provenance"]["model"] == "gpt-4"
|
||||
assert tracker.stats["excluded"] == 2
|
||||
|
||||
def test_generate_report(self):
|
||||
"""Test report generation."""
|
||||
tracker = ProvenanceTracker()
|
||||
tracker.stats = {
|
||||
"total_pairs": 10,
|
||||
"pairs_with_provenance": 8,
|
||||
"pairs_without_provenance": 2,
|
||||
"by_model": {"gpt-4": 5, "claude-3": 3},
|
||||
"by_source": {"curated": 8},
|
||||
"excluded": 0
|
||||
}
|
||||
|
||||
report = tracker.generate_report()
|
||||
|
||||
assert "Total pairs: 10" in report
|
||||
assert "Pairs with provenance: 8" in report
|
||||
assert "gpt-4: 5" in report
|
||||
|
||||
|
||||
class TestJsonlFunctions:
|
||||
"""Test JSONL load/save functions."""
|
||||
|
||||
def test_load_jsonl(self):
|
||||
"""Test loading JSONL file."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
f.write('{"id": "1", "value": "test1"}\n')
|
||||
f.write('{"id": "2", "value": "test2"}\n')
|
||||
f.write('{"id": "3", "value": "test3"}\n')
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
entries = load_jsonl(temp_path)
|
||||
assert len(entries) == 3
|
||||
assert entries[0]["id"] == "1"
|
||||
assert entries[2]["value"] == "test3"
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_save_jsonl(self):
|
||||
"""Test saving JSONL file."""
|
||||
entries = [
|
||||
{"id": "1", "value": "test1"},
|
||||
{"id": "2", "value": "test2"}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
save_jsonl(entries, temp_path)
|
||||
|
||||
with open(temp_path) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
assert len(lines) == 2
|
||||
assert json.loads(lines[0])["id"] == "1"
|
||||
assert json.loads(lines[1])["value"] == "test2"
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user