#!/usr/bin/env python3
"""
Tests for Training Pair Provenance Tracking
"""

import json
import tempfile
from pathlib import Path

import pytest

from training_pair_provenance import ProvenanceTracker, load_jsonl, save_jsonl


class TestProvenanceTracker:
    """Test the ProvenanceTracker class."""

    def test_init(self):
        """Test tracker initialization."""
        tracker = ProvenanceTracker()
        assert tracker.stats["total_pairs"] == 0
        assert tracker.stats["pairs_with_provenance"] == 0
        assert tracker.stats["pairs_without_provenance"] == 0

    def test_generate_pair_id(self):
        """Test pair ID generation."""
        tracker = ProvenanceTracker()
        pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
        id1 = tracker.generate_pair_id(pair)
        id2 = tracker.generate_pair_id(pair)
        # Same content should generate same ID
        assert id1 == id2
        assert len(id1) == 16

    def test_add_provenance(self):
        """Test adding provenance to a pair."""
        tracker = ProvenanceTracker()
        pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
        result = tracker.add_provenance(
            pair, source_session_id="session123", model="test-model"
        )
        assert "provenance" in result
        assert result["provenance"]["source_session_id"] == "session123"
        assert result["provenance"]["model"] == "test-model"
        assert "timestamp" in result["provenance"]
        assert result["provenance"]["source"] == "curated"
        assert "content_hash" in result["provenance"]

    def test_extract_provenance_from_existing(self):
        """Test extracting provenance from existing fields."""
        tracker = ProvenanceTracker()
        pair = {
            "id": "session456",
            "model": "claude-3-opus",
            "started_at": "2024-01-01T00:00:00Z",
            "conversations": [{"from": "human", "value": "test"}],
        }
        provenance = tracker.extract_provenance_from_existing(pair)
        # Top-level "id"/"model"/"started_at" are expected to map onto the
        # provenance fields checked below.
        assert provenance["source_session_id"] == "session456"
        assert provenance["model"] == "claude-3-opus"
        assert provenance["timestamp"] == "2024-01-01T00:00:00Z"
        assert provenance["source"] == "curated"
        assert "content_hash" in provenance

    def test_process_pair(self):
        """Test processing a pair."""
        tracker = ProvenanceTracker()
        # No "provenance" key on input, so it should be counted as lacking one.
        pair = {"id": "test123", "model": "test-model", "conversations": []}
        result = tracker.process_pair(pair)
        assert tracker.stats["total_pairs"] == 1
        assert tracker.stats["pairs_without_provenance"] == 1
        assert "provenance" in result

    def test_filter_by_provenance(self):
        """Test filtering pairs by provenance."""
        tracker = ProvenanceTracker()
        pairs = [
            {"provenance": {"model": "anthropic/claude-3-opus"}},
            {"provenance": {"model": "gpt-4"}},
            {"provenance": {"model": "anthropic/claude-3-sonnet"}},
        ]
        filtered = tracker.filter_by_provenance(
            pairs,
            exclude_models=["anthropic/claude-3-opus", "anthropic/claude-3-sonnet"],
        )
        assert len(filtered) == 1
        assert filtered[0]["provenance"]["model"] == "gpt-4"
        assert tracker.stats["excluded"] == 2

    def test_generate_report(self):
        """Test report generation."""
        tracker = ProvenanceTracker()
        # Overwrite stats wholesale so the report content is fully determined.
        tracker.stats = {
            "total_pairs": 10,
            "pairs_with_provenance": 8,
            "pairs_without_provenance": 2,
            "by_model": {"gpt-4": 5, "claude-3": 3},
            "by_source": {"curated": 8},
            "excluded": 0,
        }
        report = tracker.generate_report()
        assert "Total pairs: 10" in report
        assert "Pairs with provenance: 8" in report
        assert "gpt-4: 5" in report


class TestJsonlFunctions:
    """Test JSONL load/save functions."""

    def test_load_jsonl(self):
        """Test loading JSONL file."""
        # delete=False so the file survives closing and can be reopened by
        # load_jsonl (required on Windows); it is removed in the finally.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.jsonl', delete=False, encoding="utf-8"
        ) as f:
            f.write('{"id": "1", "value": "test1"}\n')
            f.write('{"id": "2", "value": "test2"}\n')
            f.write('{"id": "3", "value": "test3"}\n')
            temp_path = Path(f.name)

        try:
            entries = load_jsonl(temp_path)
            assert len(entries) == 3
            assert entries[0]["id"] == "1"
            assert entries[2]["value"] == "test3"
        finally:
            temp_path.unlink()

    def test_save_jsonl(self):
        """Test saving JSONL file."""
        entries = [
            {"id": "1", "value": "test1"},
            {"id": "2", "value": "test2"},
        ]

        # Only reserve a unique path here; save_jsonl writes the content.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.jsonl', delete=False, encoding="utf-8"
        ) as f:
            temp_path = Path(f.name)

        try:
            save_jsonl(entries, temp_path)
            with open(temp_path, encoding="utf-8") as fh:
                lines = fh.readlines()
            assert len(lines) == 2
            assert json.loads(lines[0])["id"] == "1"
            assert json.loads(lines[1])["value"] == "test2"
        finally:
            temp_path.unlink()


if __name__ == "__main__":
    # Propagate pytest's exit code so direct script runs report failures.
    raise SystemExit(pytest.main([__file__, "-v"]))