"""Tests for Mem0 Local memory provider - ChromaDB-backed, no API key."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import tempfile
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
|
||
|
|
# Fact extraction tests
|
||
|
|
|
||
|
|
class TestFactExtraction:
    """Checks for the regex-based ``_extract_facts`` helper."""

    def _extract(self, text):
        """Run ``_extract_facts`` on *text* (imported at call time) and return its result."""
        from plugins.memory.mem0_local import _extract_facts

        return _extract_facts(text)

    def test_name_extraction(self):
        """A stated name appears in the content of at least one extracted fact."""
        extracted = self._extract("My name is Alexander Whitestone.")
        contents = (fact["content"].lower() for fact in extracted)
        assert any("alexander whitestone" in text for text in contents)

    def test_preference_extraction(self):
        """A stated preference appears in the content of at least one fact."""
        extracted = self._extract("I prefer using vim for editing.")
        contents = (fact["content"].lower() for fact in extracted)
        assert any("vim" in text for text in contents)

    def test_timezone_extraction(self):
        """A stated timezone appears in the content of at least one fact."""
        extracted = self._extract("My timezone is America/New_York.")
        contents = (fact["content"].lower() for fact in extracted)
        assert any("america/new_york" in text for text in contents)

    def test_explicit_remember(self):
        """A "Remember:" instruction produces at least one fact."""
        assert len(self._extract("Remember: always use f-strings in Python.")) > 0

    def test_correction_extraction(self):
        """An "Actually:" correction produces at least one fact."""
        assert len(self._extract("Actually: the port is 8080, not 3000.")) > 0

    def test_empty_input(self):
        """The empty string yields an empty list."""
        assert self._extract("") == []

    def test_short_input_ignored(self):
        """A two-character greeting yields an empty list."""
        assert self._extract("Hi") == []

    def test_no_crash_on_random_text(self):
        """Long text without an obvious fact still returns a list."""
        filler = "The quick brown fox jumps over the lazy dog. " * 10
        outcome = self._extract(filler)
        assert isinstance(outcome, list)
|
||
|
|
|
||
|
|
|
||
|
|
# Config tests
|
||
|
|
|
||
|
|
class TestConfig:
    """Test configuration loading via ``_load_config``."""

    def test_default_storage_path(self, tmp_path, monkeypatch):
        """Without an explicit override, the storage path mentions ``mem0-local``."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
        # MEM0_LOCAL_PATH overrides the storage path (see test_env_override),
        # so a value inherited from the developer/CI environment would make
        # this test check the override instead of the default.
        monkeypatch.delenv("MEM0_LOCAL_PATH", raising=False)
        from plugins.memory.mem0_local import _load_config

        config = _load_config()
        assert "mem0-local" in config["storage_path"]

    def test_env_override(self, tmp_path, monkeypatch):
        """``MEM0_LOCAL_PATH`` is used verbatim as the storage path."""
        custom_path = str(tmp_path / "custom-mem0")
        monkeypatch.setenv("MEM0_LOCAL_PATH", custom_path)
        from plugins.memory.mem0_local import _load_config

        config = _load_config()
        assert config["storage_path"] == custom_path
|
||
|
|
|
||
|
|
|
||
|
|
# Provider interface tests
|
||
|
|
|
||
|
|
class TestProviderInterface:
    """Checks on the provider's name and advertised tool schemas."""

    def _new_provider(self):
        """Build a fresh ``Mem0LocalProvider``; the import happens at call time."""
        from plugins.memory.mem0_local import Mem0LocalProvider

        return Mem0LocalProvider()

    def test_name(self):
        """The provider identifies itself as ``mem0-local``."""
        assert self._new_provider().name == "mem0-local"

    def test_tool_schemas(self):
        """Exactly the profile, search and conclude tools are exposed."""
        advertised = set()
        for schema in self._new_provider().get_tool_schemas():
            advertised.add(schema["name"])
        assert advertised == {"mem0_profile", "mem0_search", "mem0_conclude"}

    def test_schema_required_params(self):
        """Search requires ``query``; conclude requires ``conclusion``."""
        by_name = {}
        for schema in self._new_provider().get_tool_schemas():
            by_name[schema["name"]] = schema
        search_required = by_name["mem0_search"]["parameters"]["required"]
        assert "query" in search_required
        conclude_required = by_name["mem0_conclude"]["parameters"]["required"]
        assert "conclusion" in conclude_required
|
||
|
|
|
||
|
|
|
||
|
|
# ChromaDB integration tests
|
||
|
|
|
||
|
|
# Optional dependency: bind ``chromadb`` to None when it is not installed so
# the integration tests below can be skipped instead of failing at import.
try:
    import chromadb
except ImportError:
    chromadb = None
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.skipif(chromadb is None, reason="chromadb not installed")
class TestChromaDBIntegration:
    """Integration tests with real ChromaDB.

    The whole class is skipped when ``chromadb`` could not be imported.
    """

    @pytest.fixture
    def provider(self, tmp_path, monkeypatch):
        """Return an initialized provider pointed at directories under ``tmp_path``."""
        from plugins.memory.mem0_local import Mem0LocalProvider

        monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
        # MEM0_LOCAL_PATH overrides the configured storage path, and
        # initialize() runs before _storage_path is replaced below, so an
        # ambient value must not be allowed to leak into the test.
        monkeypatch.delenv("MEM0_LOCAL_PATH", raising=False)
        provider = Mem0LocalProvider()
        provider.initialize("test-session")
        # NOTE(review): the path is overridden after initialize(); presumably
        # the store is opened lazily on first use - confirm against the provider.
        provider._storage_path = str(tmp_path / "mem0-test")
        return provider

    def test_store_and_search(self, provider):
        """A stored conclusion is found by a semantically related query."""
        result = provider.handle_tool_call(
            "mem0_conclude",
            {"conclusion": "User prefers Python over JavaScript"},
        )
        data = json.loads(result)
        assert data.get("result") == "Fact stored locally."

        result = provider.handle_tool_call(
            "mem0_search",
            {"query": "programming language preference"},
        )
        data = json.loads(result)
        assert data["count"] > 0
        assert any("python" in item["memory"].lower() for item in data["results"])

    def test_profile_empty(self, provider):
        """A fresh store reports no memories, and does so without an error."""
        result = provider.handle_tool_call("mem0_profile", {})
        data = json.loads(result)
        # Without this guard the ``count`` default below would also accept an
        # ``{"error": ...}`` payload, hiding a broken profile call.
        assert "error" not in data
        assert "No memories" in data.get("result", "") or data.get("count", 0) == 0

    def test_profile_after_store(self, provider):
        """Two distinct conclusions are both reflected in the profile count."""
        provider.handle_tool_call("mem0_conclude", {"conclusion": "User name is Alexander"})
        provider.handle_tool_call("mem0_conclude", {"conclusion": "User timezone is UTC"})

        result = provider.handle_tool_call("mem0_profile", {})
        data = json.loads(result)
        assert data["count"] >= 2

    def test_dedup(self, provider):
        """Storing the same conclusion twice keeps a single memory."""
        provider.handle_tool_call("mem0_conclude", {"conclusion": "Project uses SQLite"})
        provider.handle_tool_call("mem0_conclude", {"conclusion": "Project uses SQLite"})

        result = provider.handle_tool_call("mem0_profile", {})
        data = json.loads(result)
        assert data["count"] == 1

    def test_search_no_results(self, provider):
        """Searching an empty store reports no matches, and does so without an error."""
        result = provider.handle_tool_call("mem0_search", {"query": "nonexistent topic xyz123"})
        data = json.loads(result)
        # Same guard as test_profile_empty: an error payload has no ``count``
        # and would otherwise satisfy the fallback comparison.
        assert "error" not in data
        assert data.get("result") == "No relevant memories found." or data.get("count", 0) == 0

    def test_sync_turn_extraction(self, provider):
        """After ``sync_turn`` the profile call still returns a ``count`` field."""
        provider.sync_turn(
            "My name is TestUser and I prefer dark mode.",
            "Hello TestUser! I'll remember your preference.",
        )
        result = provider.handle_tool_call("mem0_profile", {})
        data = json.loads(result)
        # Only the response shape is checked; no particular number of
        # extracted facts is asserted here.
        assert "count" in data

    def test_conclude_missing_param(self, provider):
        """``mem0_conclude`` without ``conclusion`` returns an error payload."""
        result = provider.handle_tool_call("mem0_conclude", {})
        data = json.loads(result)
        assert "error" in data

    def test_search_missing_query(self, provider):
        """``mem0_search`` without ``query`` returns an error payload."""
        result = provider.handle_tool_call("mem0_search", {})
        data = json.loads(result)
        assert "error" in data
|