91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Tests for session_pair_harvester."""
|
|
|
|
import json
|
|
import sys
|
|
import os
|
|
import tempfile
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
from session_pair_harvester import extract_pairs_from_session, deduplicate_pairs, compute_hash
|
|
|
|
|
|
def test_basic_extraction():
    """Check that a system/human/gpt session is harvested as a single pair.

    The pair must expose the human prompt under ``terse``, a ``rich`` text
    containing "programming language", and the session id under ``source``.
    """
    session_id = "test_001"
    question = "What is Python?"
    answer = "Python is a high-level programming language known for its readability and versatility. It supports multiple paradigms including procedural, object-oriented, and functional programming. Python is widely used in web development, data science, machine learning, and automation."
    conversation = {
        "id": session_id,
        "model": "test-model",
        "conversations": [
            {"from": "system", "value": "You are helpful."},
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
    }

    harvested = extract_pairs_from_session(
        conversation, min_ratio=1.5, min_response_words=10
    )

    # Exactly one pair is expected for the three turns above.
    assert len(harvested) == 1
    only_pair = harvested[0]
    assert only_pair["terse"] == question
    assert "programming language" in only_pair["rich"]
    assert only_pair["source"] == session_id
    print("PASS: test_basic_extraction")
|
|
|
|
|
|
def test_filters_short_responses():
    """Check that a one-word reply yields no pairs when 20 words are required."""
    greeting_turns = [
        {"from": "human", "value": "Hi"},
        {"from": "gpt", "value": "Hello!"},
    ]
    conversation = {
        "id": "test_002",
        "model": "test",
        "conversations": greeting_turns,
    }

    # NOTE(review): prompt and reply are the same length here, so the
    # min_ratio=1.5 check might reject this pair as well -- confirm that the
    # word-count filter is really the one being exercised.
    harvested = extract_pairs_from_session(
        conversation, min_ratio=1.5, min_response_words=20
    )

    pair_count = len(harvested)
    assert pair_count == 0
    print("PASS: test_filters_short_responses")
|
|
|
|
|
|
def test_skips_tool_results():
    """Check that a human turn holding a JSON tool payload yields no pairs."""
    # The "human" message is a serialized tool result, not a real prompt.
    tool_payload = '{"output": "file content", "exit_code": 0}'
    follow_up = "The file was read successfully. Now let me analyze the content and provide a detailed summary of what was found in the file system."
    conversation = {
        "id": "test_003",
        "model": "test",
        "conversations": [
            {"from": "human", "value": tool_payload},
            {"from": "gpt", "value": follow_up},
        ],
    }

    harvested = extract_pairs_from_session(
        conversation, min_ratio=1.5, min_response_words=10
    )

    pair_count = len(harvested)
    assert pair_count == 0
    print("PASS: test_skips_tool_results")
|
|
|
|
|
|
def test_deduplication():
    """Check that three pairs, two sharing terse/rich text, reduce to two."""
    # (terse, rich, source) -- the first two rows differ only in their source.
    rows = (
        ("What is X?", "X is Y.", "s1"),
        ("What is X?", "X is Y.", "s2"),
        ("What is Z?", "Z is W.", "s1"),
    )
    candidates = [
        {"terse": terse, "rich": rich, "source": source, "model": "m"}
        for terse, rich, source in rows
    ]

    remaining = deduplicate_pairs(candidates)

    assert len(remaining) == 2
    print("PASS: test_deduplication")
|
|
|
|
|
|
def test_ratio_filter():
    """Check that a long prompt answered with "OK." produces no pairs."""
    long_prompt = "Explain quantum computing in detail with examples and applications"
    curt_reply = "OK."
    conversation = {
        "id": "test_005",
        "model": "test",
        "conversations": [
            {"from": "human", "value": long_prompt},
            {"from": "gpt", "value": curt_reply},
        ],
    }

    # NOTE(review): "OK." is also shorter than min_response_words=10, so this
    # pair may be dropped by the word-count filter before the ratio check
    # matters -- confirm which filter actually rejects it.
    harvested = extract_pairs_from_session(
        conversation, min_ratio=1.5, min_response_words=10
    )

    pair_count = len(harvested)
    assert pair_count == 0  # response too short relative to prompt
    print("PASS: test_ratio_filter")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run each test in declaration order. A failing
    # assertion propagates and stops the run before the summary is printed.
    for run_test in (
        test_basic_extraction,
        test_filters_short_responses,
        test_skips_tool_results,
        test_deduplication,
        test_ratio_filter,
    ):
        run_test()
    print("\nAll tests passed.")
|