diff --git a/training/test_training_pair_provenance.py b/training/test_training_pair_provenance.py new file mode 100644 index 00000000..b7c50221 --- /dev/null +++ b/training/test_training_pair_provenance.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +""" +Tests for Training Pair Provenance Tracking +""" + +import json +import tempfile +from pathlib import Path +import pytest + +from training_pair_provenance import ProvenanceTracker, load_jsonl, save_jsonl + + +class TestProvenanceTracker: + """Test the ProvenanceTracker class.""" + + def test_init(self): + """Test tracker initialization.""" + tracker = ProvenanceTracker() + assert tracker.stats["total_pairs"] == 0 + assert tracker.stats["pairs_with_provenance"] == 0 + assert tracker.stats["pairs_without_provenance"] == 0 + + def test_generate_pair_id(self): + """Test pair ID generation.""" + tracker = ProvenanceTracker() + pair = {"prompt": "test", "chosen": "response", "rejected": "bad"} + + id1 = tracker.generate_pair_id(pair) + id2 = tracker.generate_pair_id(pair) + + # Same content should generate same ID + assert id1 == id2 + assert len(id1) == 16 + + def test_add_provenance(self): + """Test adding provenance to a pair.""" + tracker = ProvenanceTracker() + pair = {"prompt": "test", "chosen": "response", "rejected": "bad"} + + result = tracker.add_provenance(pair, source_session_id="session123", model="test-model") + + assert "provenance" in result + assert result["provenance"]["source_session_id"] == "session123" + assert result["provenance"]["model"] == "test-model" + assert "timestamp" in result["provenance"] + assert result["provenance"]["source"] == "curated" + assert "content_hash" in result["provenance"] + + def test_extract_provenance_from_existing(self): + """Test extracting provenance from existing fields.""" + tracker = ProvenanceTracker() + pair = { + "id": "session456", + "model": "claude-3-opus", + "started_at": "2024-01-01T00:00:00Z", + "conversations": [{"from": "human", "value": "test"}] + } + + provenance = tracker.extract_provenance_from_existing(pair) + + assert provenance["source_session_id"] == "session456" + assert provenance["model"] == "claude-3-opus" + assert provenance["timestamp"] == "2024-01-01T00:00:00Z" + assert provenance["source"] == "curated" + assert "content_hash" in provenance + + def test_process_pair(self): + """Test processing a pair.""" + tracker = ProvenanceTracker() + pair = {"id": "test123", "model": "test-model", "conversations": []} + + result = tracker.process_pair(pair) + + assert tracker.stats["total_pairs"] == 1 + assert tracker.stats["pairs_without_provenance"] == 1 + assert "provenance" in result + + def test_filter_by_provenance(self): + """Test filtering pairs by provenance.""" + tracker = ProvenanceTracker() + + pairs = [ + {"provenance": {"model": "anthropic/claude-3-opus"}}, + {"provenance": {"model": "gpt-4"}}, + {"provenance": {"model": "anthropic/claude-3-sonnet"}}, + ] + + filtered = tracker.filter_by_provenance(pairs, exclude_models=["anthropic/claude-3-opus", "anthropic/claude-3-sonnet"]) + + assert len(filtered) == 1 + assert filtered[0]["provenance"]["model"] == "gpt-4" + assert tracker.stats["excluded"] == 2 + + def test_generate_report(self): + """Test report generation.""" + tracker = ProvenanceTracker() + tracker.stats = { + "total_pairs": 10, + "pairs_with_provenance": 8, + "pairs_without_provenance": 2, + "by_model": {"gpt-4": 5, "claude-3": 3}, + "by_source": {"curated": 8}, + "excluded": 0 + } + + report = tracker.generate_report() + + assert "Total pairs: 10" in report + assert "Pairs with provenance: 8" in report + assert "gpt-4: 5" in report + + +class TestJsonlFunctions: + """Test JSONL load/save functions.""" + + def test_load_jsonl(self): + """Test loading JSONL file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f: + f.write('{"id": "1", "value": "test1"}\n') + f.write('{"id": "2", "value": "test2"}\n') + f.write('{"id": "3", "value": "test3"}\n') + temp_path = Path(f.name) + + try: + entries = load_jsonl(temp_path) + assert len(entries) == 3 + assert entries[0]["id"] == "1" + assert entries[2]["value"] == "test3" + finally: + temp_path.unlink() + + def test_save_jsonl(self): + """Test saving JSONL file.""" + entries = [ + {"id": "1", "value": "test1"}, + {"id": "2", "value": "test2"} + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f: + temp_path = Path(f.name) + + try: + save_jsonl(entries, temp_path) + + with open(temp_path) as f: + lines = f.readlines() + + assert len(lines) == 2 + assert json.loads(lines[0])["id"] == "1" + assert json.loads(lines[1])["value"] == "test2" + finally: + temp_path.unlink() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])