"""Tests for session compaction with fact extraction."""

import pytest
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||
|
|
|
||
|
|
from agent.session_compactor import (
|
||
|
|
ExtractedFact,
|
||
|
|
extract_facts_from_messages,
|
||
|
|
save_facts_to_store,
|
||
|
|
extract_and_save_facts,
|
||
|
|
format_facts_summary,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestFactExtraction:
    """Exercise ``extract_facts_from_messages`` on tiny hand-built chats."""

    @staticmethod
    def _extract(*messages):
        """Pass the given chat-message dicts to the extractor as one list."""
        return extract_facts_from_messages(list(messages))

    def test_extract_preference(self):
        found = self._extract(
            {"role": "user", "content": "I prefer Python over JavaScript for backend work."},
        )
        assert len(found) > 0
        assert any("Python" in fact.content for fact in found)

    def test_extract_correction(self):
        found = self._extract(
            {"role": "user", "content": "Actually the port is 8081 not 8080."},
        )
        assert len(found) > 0
        assert any("8081" in fact.content for fact in found)

    def test_extract_project_fact(self):
        found = self._extract(
            {"role": "user", "content": "The project uses Gitea for source control."},
        )
        assert len(found) > 0

    def test_skip_tool_results(self):
        # An assistant turn that only issues a tool call, followed by the
        # tool's raw output, is expected to yield no facts.
        found = self._extract(
            {"role": "assistant", "content": "Running command...", "tool_calls": [{"id": "1"}]},
            {"role": "tool", "content": "output here"},
        )
        assert len(found) == 0

    def test_skip_short_messages(self):
        found = self._extract({"role": "user", "content": "ok"})
        assert len(found) == 0

    def test_deduplication(self):
        found = self._extract(
            {"role": "user", "content": "I prefer Python."},
            {"role": "user", "content": "I prefer Python."},
        )
        # Identical messages must collapse into a single Python-related fact.
        assert sum("Python" in fact.content for fact in found) == 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestSaveFacts:
    """Exercise ``save_facts_to_store`` with an injected store callback."""

    def test_save_with_callback(self):
        recorded = []

        # Parameter names mirror the callback contract in case the code under
        # test passes them by keyword; only category/content are captured.
        def _record(category, entity, content, trust):
            recorded.append(dict(category=category, content=content))

        single_fact = ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0)
        stored_count = save_facts_to_store([single_fact], fact_store_fn=_record)

        assert stored_count == 1
        assert len(recorded) == 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestFormatSummary:
    """Exercise the human-readable output of ``format_facts_summary``."""

    def test_empty(self):
        rendered = format_facts_summary([])
        assert "No facts" in rendered

    def test_with_facts(self):
        rendered = format_facts_summary(
            [
                ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0),
                ExtractedFact("correction", "user", "port is 8081", 0.9, 1),
            ]
        )
        # The summary should report the fact count and mention a category.
        for expected in ("2 facts", "user_pref"):
            assert expected in rendered
|