117 lines
4.5 KiB
Python
117 lines
4.5 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Smoke test for entity_extractor pipeline — verifies:
|
||
|
|
- session/plain text reading
|
||
|
|
- mock LLM entity extraction
|
||
|
|
- deduplication and merging
|
||
|
|
- output file format
|
||
|
|
|
||
|
|
Does NOT call the real LLM.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import tempfile
|
||
|
|
from unittest.mock import patch
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Directory containing this test file; it is put at the front of sys.path so
# that the sibling modules imported below (session_reader, entity_extractor)
# resolve regardless of the current working directory.
SCRIPT_DIR = Path(__file__).absolute().parent
sys.path.insert(0, os.fspath(SCRIPT_DIR))
|
||
|
|
|
||
|
|
from session_reader import read_session, messages_to_text
|
||
|
|
import entity_extractor as ee
|
||
|
|
|
||
|
|
def mock_call_llm(prompt: str, text: str, api_base: str, api_key: str, model: str):
    """Offline stand-in for ``entity_extractor.call_llm``.

    All arguments are accepted for signature compatibility and ignored.

    Returns:
        A newly built list of three entity dicts (keys: name, type, context),
        identical on every call regardless of input.
    """
    fixtures = (
        ("Hermes", "tool", "Hermes agent uses the tools tool."),
        ("Gitea", "tool", "Gitea is a forge."),
        ("Timmy_Foundation/hermes-agent", "repo", "Clone the repo at forge..."),
    )
    # Build fresh dicts each call so callers may mutate the result safely.
    return [
        {"name": entity_name, "type": entity_type, "context": snippet}
        for entity_name, entity_type, snippet in fixtures
    ]
|
||
|
|
|
||
|
|
def test_read_session_text():
    """Check that a JSONL session file is parsed and rendered as role-prefixed text.

    Writes a user message and an assistant message to a temporary ``.jsonl``
    file. It reads them back with ``read_session``, renders them with
    ``messages_to_text``, and expects ``USER:`` / ``ASSISTANT:`` prefixed lines.
    """
    session_lines = (
        '{"role": "user", "content": "Clone repo", "timestamp": "2026-04-13T10:00:00Z"}\n',
        '{"role": "assistant", "content": "Done", "timestamp": "2026-04-13T10:00:05Z"}\n',
    )
    # TemporaryDirectory removes the file even when an assertion below fails,
    # so a failing run does not leave stray files in the system temp dir.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = Path(tmpdir) / "session.jsonl"
        path.write_text("".join(session_lines), encoding="utf-8")

        messages = read_session(str(path))
        text = messages_to_text(messages)

        assert "USER: Clone repo" in text
        assert "ASSISTANT: Done" in text
    print(" [PASS] session text extraction works")
|
||
|
|
|
||
|
|
def test_entity_deduplication_and_merge():
    """Check that ``merge_entities`` folds repeated names into existing entries.

    An existing "Hermes" entry (count 3, source s1.jsonl) is merged with a new
    "Hermes" sighting from s2.jsonl and a brand-new "Gitea" entity. Expect
    exactly one Hermes entry with count 4 and both sources, and exactly one
    Gitea entry with count 1.
    """
    import copy

    existing = [
        {"name": "Hermes", "type": "tool", "count": 3, "sources": ["s1.jsonl"]}
    ]
    new = [
        {"name": "Hermes", "type": "tool", "sources": ["s2.jsonl"]},
        {"name": "Gitea", "type": "tool", "sources": ["s2.jsonl"]},
    ]
    # deepcopy rather than list.copy(): a shallow copy would still share the
    # inner dicts/lists, so an in-place merge could mutate ``existing``.
    merged = ee.merge_entities(new, copy.deepcopy(existing))

    def _matching(lowered_name):
        """Return all merged entries whose name matches case-insensitively."""
        return [e for e in merged if e['name'].lower() == lowered_name]

    # Deduplication: the repeated "Hermes" must collapse into a single entry
    # whose count is bumped (3 -> 4) and whose sources are combined.
    hermes_entries = _matching('hermes')
    assert len(hermes_entries) == 1, f"expected one Hermes entry, got {hermes_entries}"
    hermes = hermes_entries[0]
    assert hermes['count'] == 4
    assert set(hermes['sources']) == {'s1.jsonl', 's2.jsonl'}

    # Gitea is new, so it appears once with a starting count of 1.
    gitea_entries = _matching('gitea')
    assert len(gitea_entries) == 1, f"expected one Gitea entry, got {gitea_entries}"
    assert gitea_entries[0]['count'] == 1

    print(" [PASS] deduplication & merging works")
|
||
|
|
|
||
|
|
def test_write_and_load_entities():
    """Check that an index written by ``write_entities`` loads back intact.

    Writes a one-entity index into a temporary ``knowledge`` directory. It then
    verifies that ``entities.json`` was created there and that
    ``load_existing_entities`` returns the same single entity.
    """
    with tempfile.TemporaryDirectory() as tmp:
        kdir = Path(tmp) / "knowledge"
        kdir.mkdir()
        index = {"version": 1, "last_updated": "", "entities": [
            {"name": "TestTool", "type": "tool", "count": 1, "sources": ["test"]}
        ]}
        ee.write_entities(index, str(kdir))

        # write_entities is expected to create <dir>/entities.json (the
        # full-pipeline test reads the same path).
        assert (kdir / "entities.json").exists()

        # Load back and confirm the entity round-tripped without duplication.
        loaded = ee.load_existing_entities(str(kdir))
        entities = loaded['entities']
        assert len(entities) == 1, f"expected 1 entity after reload, got {len(entities)}"
        assert entities[0]['name'] == 'TestTool'
        assert entities[0]['type'] == 'tool'
    print(" [PASS] entities persistence works")
|
||
|
|
|
||
|
|
def test_full_pipeline_mocked():
    """Run read -> extract (mocked LLM) -> merge -> write over two sessions.

    ``entity_extractor.call_llm`` is patched with ``mock_call_llm``, so both
    sessions yield the same three entity names. After merging, each name must
    appear exactly once, and the written ``entities.json`` must contain them.
    """
    # Lower-cased names returned by mock_call_llm.
    expected_names = {"hermes", "gitea", "timmy_foundation/hermes-agent"}

    def _names(entities):
        """Lower-cased entity names, in order, duplicates kept."""
        return [e['name'].lower() for e in entities]

    with tempfile.TemporaryDirectory() as tmpdir:
        # Two fake single-message session files.
        sess1 = Path(tmpdir) / "s1.jsonl"
        sess1.write_text(
            '{"role":"user","content":"Use Hermes to clone","timestamp":"..."}\n',
            encoding="utf-8",
        )
        sess2 = Path(tmpdir) / "s2.jsonl"
        sess2.write_text(
            '{"role":"user","content":"Deploy with Gitea","timestamp":"..."}\n',
            encoding="utf-8",
        )

        knowledge_dir = Path(tmpdir) / "knowledge"
        knowledge_dir.mkdir()

        with patch('entity_extractor.call_llm', side_effect=mock_call_llm):
            # Process both sessions the way the extractor's main logic would.
            all_entities = []
            for src in (sess1, sess2):
                text = ee.read_text_from_source(str(src))
                ents = ee.extract_from_text(
                    text, "http://api", "fake-key", "model", source_name=src.name
                )
                all_entities.extend(ents)

            # Merge into an empty index.
            merged = ee.merge_entities(all_entities, [])

        merged_names = _names(merged)
        # Both sessions produce the same names, so a bare ">= 3" count would
        # also pass if merging failed to deduplicate (6 entries); check
        # uniqueness explicitly.
        assert len(merged_names) == len(set(merged_names)), (
            f"duplicate entities after merge: {merged_names}"
        )
        assert expected_names <= set(merged_names), (
            f"missing entities: {expected_names - set(merged_names)}"
        )
        assert len(merged) >= 3, f"Expected >=3 unique entities, got {len(merged)}"

        index = {"version": 1, "last_updated": "", "entities": merged}
        ee.write_entities(index, str(knowledge_dir))

        # The written file must exist and contain the merged entities.
        out = knowledge_dir / "entities.json"
        assert out.exists()
        data = json.loads(out.read_text(encoding="utf-8"))
        assert expected_names <= set(_names(data['entities']))
        assert len(data['entities']) >= 3
        print(f" [PASS] full pipeline (mocked) produced {len(data['entities'])} entities")
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Run each smoke test in order; the first failing assertion aborts the run
    # with a traceback, so the final message only prints when all pass.
    for smoke_test in (
        test_read_session_text,
        test_entity_deduplication_and_merge,
        test_write_and_load_entities,
        test_full_pipeline_mocked,
    ):
        smoke_test()
    print("\nAll smoke tests passed.")
|