fix: follow-up improvements for salvaged PR #5456
- SQLite write queue: thread-local connection pooling instead of creating+closing a new connection per operation - Prefetch threads: join previous batch before spawning new ones to prevent thread accumulation on rapid queue_prefetch() calls - Shutdown: join prefetch threads before stopping write queue - Add 73 tests covering _Client HTTP payloads, _WriteQueue crash recovery & connection reuse, _build_overlay deduplication, RetainDBMemoryProvider lifecycle/tools/prefetch/hooks, thread accumulation guard, and reasoning_level heuristic
This commit is contained in:
@@ -336,52 +336,58 @@ class _WriteQueue:
|
||||
self._q: queue.Queue = queue.Queue()
|
||||
self._thread = threading.Thread(target=self._loop, name="retaindb-writer", daemon=True)
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Thread-local connection cache — one connection per thread, reused.
|
||||
self._local = threading.local()
|
||||
self._init_db()
|
||||
self._thread.start()
|
||||
# Replay any rows left from a previous crash
|
||||
for row_id, user_id, session_id, msgs_json in self._pending_rows():
|
||||
self._q.put((row_id, user_id, session_id, json.loads(msgs_json)))
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
def _get_conn(self) -> sqlite3.Connection:
|
||||
"""Return a cached connection for the current thread."""
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn is None:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
self._local.conn = conn
|
||||
return conn
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute("""CREATE TABLE IF NOT EXISTS pending (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT, session_id TEXT, messages_json TEXT,
|
||||
created_at TEXT, last_error TEXT
|
||||
)""")
|
||||
conn.commit()
|
||||
conn = self._get_conn()
|
||||
conn.execute("""CREATE TABLE IF NOT EXISTS pending (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT, session_id TEXT, messages_json TEXT,
|
||||
created_at TEXT, last_error TEXT
|
||||
)""")
|
||||
conn.commit()
|
||||
|
||||
def _pending_rows(self) -> list:
|
||||
with self._connect() as conn:
|
||||
return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall()
|
||||
conn = self._get_conn()
|
||||
return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall()
|
||||
|
||||
def enqueue(self, user_id: str, session_id: str, messages: list) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._connect() as conn:
|
||||
cur = conn.execute(
|
||||
"INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)",
|
||||
(user_id, session_id, json.dumps(messages, ensure_ascii=False), now),
|
||||
)
|
||||
row_id = cur.lastrowid
|
||||
conn.commit()
|
||||
conn = self._get_conn()
|
||||
cur = conn.execute(
|
||||
"INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)",
|
||||
(user_id, session_id, json.dumps(messages, ensure_ascii=False), now),
|
||||
)
|
||||
row_id = cur.lastrowid
|
||||
conn.commit()
|
||||
self._q.put((row_id, user_id, session_id, messages))
|
||||
|
||||
def _flush_row(self, row_id: int, user_id: str, session_id: str, messages: list) -> None:
|
||||
try:
|
||||
self._client.ingest_session(user_id, session_id, messages)
|
||||
with self._connect() as conn:
|
||||
conn.execute("DELETE FROM pending WHERE id = ?", (row_id,))
|
||||
conn.commit()
|
||||
conn = self._get_conn()
|
||||
conn.execute("DELETE FROM pending WHERE id = ?", (row_id,))
|
||||
conn.commit()
|
||||
except Exception as exc:
|
||||
logger.warning("RetainDB ingest failed (will retry): %s", exc)
|
||||
with self._connect() as conn:
|
||||
conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id))
|
||||
conn.commit()
|
||||
conn = self._get_conn()
|
||||
conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id))
|
||||
conn.commit()
|
||||
time.sleep(2)
|
||||
|
||||
def _loop(self) -> None:
|
||||
@@ -459,6 +465,9 @@ class RetainDBMemoryProvider(MemoryProvider):
|
||||
self._dialectic_result = ""
|
||||
self._agent_model: dict = {}
|
||||
|
||||
# Prefetch thread tracking — prevents accumulation on rapid calls
|
||||
self._prefetch_threads: list[threading.Thread] = []
|
||||
|
||||
# ── Core identity ──────────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
@@ -533,9 +542,18 @@ class RetainDBMemoryProvider(MemoryProvider):
|
||||
"""Fire context + dialectic + agent model prefetches in background."""
|
||||
if not self._client:
|
||||
return
|
||||
threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True).start()
|
||||
threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True).start()
|
||||
threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True).start()
|
||||
# Wait for any still-running prefetch threads before spawning new ones.
|
||||
# Prevents thread accumulation if turns fire faster than prefetches complete.
|
||||
for t in self._prefetch_threads:
|
||||
t.join(timeout=2.0)
|
||||
threads = [
|
||||
threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True),
|
||||
threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True),
|
||||
threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True),
|
||||
]
|
||||
self._prefetch_threads = threads
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
def _prefetch_context(self, query: str) -> None:
|
||||
try:
|
||||
@@ -736,6 +754,8 @@ class RetainDBMemoryProvider(MemoryProvider):
|
||||
logger.debug("RetainDB memory mirror failed: %s", exc)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
for t in self._prefetch_threads:
|
||||
t.join(timeout=3.0)
|
||||
if self._queue:
|
||||
self._queue.shutdown()
|
||||
|
||||
|
||||
776
tests/plugins/test_retaindb_plugin.py
Normal file
776
tests/plugins/test_retaindb_plugin.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""Tests for the RetainDB memory plugin.
|
||||
|
||||
Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter,
|
||||
RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Imports — guarded since plugins/memory lives outside the standard test path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_env(tmp_path, monkeypatch):
|
||||
"""Ensure HERMES_HOME and RETAINDB vars are isolated."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("RETAINDB_API_KEY", raising=False)
|
||||
monkeypatch.delenv("RETAINDB_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
|
||||
|
||||
|
||||
# We need the repo root on sys.path so the plugin can import agent.memory_provider
|
||||
import sys
|
||||
_repo_root = str(Path(__file__).resolve().parents[2])
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
from plugins.memory.retaindb import (
|
||||
_Client,
|
||||
_WriteQueue,
|
||||
_build_overlay,
|
||||
RetainDBMemoryProvider,
|
||||
_ASYNC_SHUTDOWN,
|
||||
_DEFAULT_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _Client tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestClient:
|
||||
"""Test the HTTP client with mocked requests."""
|
||||
|
||||
def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
|
||||
return _Client(api_key, base_url, project)
|
||||
|
||||
def test_base_url_trailing_slash_stripped(self):
|
||||
c = self._make_client(base_url="https://api.retaindb.com///")
|
||||
assert c.base_url == "https://api.retaindb.com"
|
||||
|
||||
def test_headers_include_auth(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/files")
|
||||
assert h["Authorization"] == "Bearer rdb-test-key"
|
||||
assert "X-API-Key" not in h
|
||||
|
||||
def test_headers_include_api_key_for_memory_path(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/memory/search")
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_headers_include_api_key_for_context_path(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/context/query")
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_headers_strip_bearer_prefix(self):
|
||||
c = self._make_client(api_key="Bearer rdb-test-key")
|
||||
h = c._headers("/v1/memory/search")
|
||||
assert h["Authorization"] == "Bearer rdb-test-key"
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_query_context_builds_correct_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"results": []}
|
||||
c.query_context("user1", "sess1", "test query", max_tokens=500)
|
||||
mock_req.assert_called_once_with("POST", "/v1/context/query", json_body={
|
||||
"project": "test",
|
||||
"query": "test query",
|
||||
"user_id": "user1",
|
||||
"session_id": "sess1",
|
||||
"include_memories": True,
|
||||
"max_tokens": 500,
|
||||
})
|
||||
|
||||
def test_search_builds_correct_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"results": []}
|
||||
c.search("user1", "sess1", "find this", top_k=5)
|
||||
mock_req.assert_called_once_with("POST", "/v1/memory/search", json_body={
|
||||
"project": "test",
|
||||
"query": "find this",
|
||||
"user_id": "user1",
|
||||
"session_id": "sess1",
|
||||
"top_k": 5,
|
||||
"include_pending": True,
|
||||
})
|
||||
|
||||
def test_add_memory_tries_fallback(self):
|
||||
c = self._make_client()
|
||||
call_count = 0
|
||||
def fake_request(method, path, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("404")
|
||||
return {"id": "mem-1"}
|
||||
|
||||
with patch.object(c, "request", side_effect=fake_request):
|
||||
result = c.add_memory("u1", "s1", "test fact")
|
||||
assert result == {"id": "mem-1"}
|
||||
assert call_count == 2
|
||||
|
||||
def test_delete_memory_tries_fallback(self):
|
||||
c = self._make_client()
|
||||
call_count = 0
|
||||
def fake_request(method, path, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("404")
|
||||
return {"deleted": True}
|
||||
|
||||
with patch.object(c, "request", side_effect=fake_request):
|
||||
result = c.delete_memory("mem-123")
|
||||
assert result == {"deleted": True}
|
||||
assert call_count == 2
|
||||
|
||||
def test_ingest_session_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"status": "ok"}
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
c.ingest_session("u1", "s1", msgs, timeout=10.0)
|
||||
mock_req.assert_called_once_with("POST", "/v1/memory/ingest/session", json_body={
|
||||
"project": "test",
|
||||
"session_id": "s1",
|
||||
"user_id": "u1",
|
||||
"messages": msgs,
|
||||
"write_mode": "sync",
|
||||
}, timeout=10.0)
|
||||
|
||||
def test_ask_user_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"answer": "test answer"}
|
||||
c.ask_user("u1", "who am i?", reasoning_level="medium")
|
||||
mock_req.assert_called_once()
|
||||
call_kwargs = mock_req.call_args
|
||||
assert call_kwargs[1]["json_body"]["reasoning_level"] == "medium"
|
||||
|
||||
def test_get_agent_model_path(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"memory_count": 3}
|
||||
c.get_agent_model("hermes")
|
||||
mock_req.assert_called_once_with(
|
||||
"GET", "/v1/memory/agent/hermes/model",
|
||||
params={"project": "test"}, timeout=4.0
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _WriteQueue tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestWriteQueue:
|
||||
"""Test the SQLite-backed write queue with real SQLite."""
|
||||
|
||||
def _make_queue(self, tmp_path, client=None):
|
||||
if client is None:
|
||||
client = MagicMock()
|
||||
client.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
return _WriteQueue(client, db_path), client, db_path
|
||||
|
||||
def test_enqueue_creates_row(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
# Give the writer thread a moment to process
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# If ingest succeeded, the row should be deleted
|
||||
client.ingest_session.assert_called_once()
|
||||
|
||||
def test_enqueue_persists_to_sqlite(self, tmp_path):
|
||||
client = MagicMock()
|
||||
# Make ingest hang so the row stays in SQLite
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
|
||||
# Check SQLite directly — row should exist since flush is slow
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall()
|
||||
conn.close()
|
||||
assert len(rows) >= 1
|
||||
assert rows[0][0] == "user1"
|
||||
q.shutdown()
|
||||
|
||||
def test_flush_deletes_row_on_success(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# Row should be gone
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
|
||||
conn.close()
|
||||
assert rows == 0
|
||||
|
||||
def test_flush_records_error_on_failure(self, tmp_path):
|
||||
client = MagicMock()
|
||||
client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(3) # Allow retry + sleep(2) in _flush_row
|
||||
q.shutdown()
|
||||
# Row should still exist with error recorded
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert "API down" in row[0]
|
||||
|
||||
def test_thread_local_connection_reuse(self, tmp_path):
|
||||
q, _, _ = self._make_queue(tmp_path)
|
||||
# Same thread should get same connection
|
||||
conn1 = q._get_conn()
|
||||
conn2 = q._get_conn()
|
||||
assert conn1 is conn2
|
||||
q.shutdown()
|
||||
|
||||
def test_crash_recovery_replays_pending(self, tmp_path):
|
||||
"""Simulate crash: create rows, then new queue should replay them."""
|
||||
db_path = tmp_path / "recovery_test.db"
|
||||
# First: create a queue and insert rows, but don't let them flush
|
||||
client1 = MagicMock()
|
||||
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
|
||||
q1 = _WriteQueue(client1, db_path)
|
||||
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
|
||||
time.sleep(3)
|
||||
q1.shutdown()
|
||||
|
||||
# Now create a new queue — it should replay the pending rows
|
||||
client2 = MagicMock()
|
||||
client2.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
q2 = _WriteQueue(client2, db_path)
|
||||
time.sleep(2)
|
||||
q2.shutdown()
|
||||
|
||||
# The replayed row should have been ingested via client2
|
||||
client2.ingest_session.assert_called_once()
|
||||
call_args = client2.ingest_session.call_args
|
||||
assert call_args[0][0] == "user1" # user_id
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _build_overlay tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestBuildOverlay:
|
||||
"""Test the overlay formatter (pure function)."""
|
||||
|
||||
def test_empty_inputs_returns_empty(self):
|
||||
assert _build_overlay({}, {}) == ""
|
||||
|
||||
def test_empty_memories_returns_empty(self):
|
||||
assert _build_overlay({"memories": []}, {"results": []}) == ""
|
||||
|
||||
def test_profile_items_included(self):
|
||||
profile = {"memories": [{"content": "User likes Python"}]}
|
||||
result = _build_overlay(profile, {})
|
||||
assert "User likes Python" in result
|
||||
assert "[RetainDB Context]" in result
|
||||
|
||||
def test_query_results_included(self):
|
||||
query_result = {"results": [{"content": "Previous discussion about Rust"}]}
|
||||
result = _build_overlay({}, query_result)
|
||||
assert "Previous discussion about Rust" in result
|
||||
|
||||
def test_deduplication_removes_duplicates(self):
|
||||
profile = {"memories": [{"content": "User likes Python"}]}
|
||||
query_result = {"results": [{"content": "User likes Python"}]}
|
||||
result = _build_overlay(profile, query_result)
|
||||
assert result.count("User likes Python") == 1
|
||||
|
||||
def test_local_entries_filter(self):
|
||||
profile = {"memories": [{"content": "Already known fact"}]}
|
||||
result = _build_overlay(profile, {}, local_entries=["Already known fact"])
|
||||
# The profile item matches a local entry, should be filtered
|
||||
assert result == ""
|
||||
|
||||
def test_max_five_items_per_section(self):
|
||||
profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
|
||||
result = _build_overlay(profile, {})
|
||||
# Should only include first 5
|
||||
assert "Fact 0" in result
|
||||
assert "Fact 4" in result
|
||||
assert "Fact 5" not in result
|
||||
|
||||
def test_none_content_handled(self):
|
||||
profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
|
||||
result = _build_overlay(profile, {})
|
||||
assert "Real fact" in result
|
||||
|
||||
def test_truncation_at_320_chars(self):
|
||||
long_content = "x" * 500
|
||||
profile = {"memories": [{"content": long_content}]}
|
||||
result = _build_overlay(profile, {})
|
||||
# Each item is compacted to 320 chars max
|
||||
for line in result.split("\n"):
|
||||
if line.startswith("- "):
|
||||
assert len(line) <= 322 # "- " + 320
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# RetainDBMemoryProvider tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestRetainDBMemoryProvider:
|
||||
"""Test the main plugin class."""
|
||||
|
||||
def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", api_key)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
(tmp_path / ".hermes").mkdir(exist_ok=True)
|
||||
provider = RetainDBMemoryProvider()
|
||||
return provider
|
||||
|
||||
def test_name(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.name == "retaindb"
|
||||
|
||||
def test_is_available_without_key(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.is_available() is False
|
||||
|
||||
def test_is_available_with_key(self, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.is_available() is True
|
||||
|
||||
def test_config_schema(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
schema = p.get_config_schema()
|
||||
assert len(schema) == 3
|
||||
keys = [s["key"] for s in schema]
|
||||
assert "api_key" in keys
|
||||
assert "base_url" in keys
|
||||
assert "project" in keys
|
||||
|
||||
def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client is not None
|
||||
assert p._queue is not None
|
||||
assert p._session_id == "test-session"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_default_project(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client.project == "default"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_explicit_project(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client.project == "my-project"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_profile_project(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
profile_home = str(tmp_path / "profiles" / "coder")
|
||||
p.initialize("test-session", hermes_home=profile_home)
|
||||
assert p._client.project == "hermes-coder"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
soul_path = tmp_path / ".hermes" / "SOUL.md"
|
||||
soul_path.write_text("I am a helpful agent.")
|
||||
with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
# Give thread time to start
|
||||
time.sleep(0.5)
|
||||
mock_seed.assert_called_once_with("I am a helpful agent.")
|
||||
p.shutdown()
|
||||
|
||||
def test_system_prompt_block(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
block = p.system_prompt_block()
|
||||
assert "RetainDB Memory" in block
|
||||
assert "Active" in block
|
||||
p.shutdown()
|
||||
|
||||
def test_tool_schemas_count(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
schemas = p.get_tool_schemas()
|
||||
assert len(schemas) == 10 # 5 memory + 5 file tools
|
||||
names = [s["name"] for s in schemas]
|
||||
assert "retaindb_profile" in names
|
||||
assert "retaindb_search" in names
|
||||
assert "retaindb_context" in names
|
||||
assert "retaindb_remember" in names
|
||||
assert "retaindb_forget" in names
|
||||
assert "retaindb_upload_file" in names
|
||||
assert "retaindb_list_files" in names
|
||||
assert "retaindb_read_file" in names
|
||||
assert "retaindb_ingest_file" in names
|
||||
assert "retaindb_delete_file" in names
|
||||
|
||||
def test_handle_tool_call_not_initialized(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "error" in result
|
||||
assert "not initialized" in result["error"]
|
||||
|
||||
def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
|
||||
assert result == {"error": "Unknown tool: retaindb_nonexistent"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_profile(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "get_profile", return_value={"memories": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "memories" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_search", {}))
|
||||
assert result == {"error": "query is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
|
||||
assert "results" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "search") as mock_search:
|
||||
mock_search.return_value = {"results": []}
|
||||
p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
|
||||
# top_k should be capped at 20
|
||||
assert mock_search.call_args[1]["top_k"] == 20
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_remember(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
|
||||
assert result["id"] == "mem-1"
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_remember", {}))
|
||||
assert result == {"error": "content is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_forget(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
|
||||
assert result["deleted"] is True
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_forget", {}))
|
||||
assert result == {"error": "memory_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_context(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
|
||||
patch.object(p._client, "get_profile", return_value={"memories": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
|
||||
assert "context" in result
|
||||
assert "raw" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_list(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "list_files", return_value={"files": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
|
||||
assert "files" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
|
||||
assert "error" in result
|
||||
|
||||
def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
|
||||
assert "File not found" in result["error"]
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "API exploded" in result["error"]
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Prefetch and thread management tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestPrefetch:
|
||||
"""Test background prefetch and thread accumulation prevention."""
|
||||
|
||||
def _make_initialized_provider(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
return p
|
||||
|
||||
def test_queue_prefetch_skips_without_client(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
p.queue_prefetch("test") # Should not raise
|
||||
|
||||
def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
result = p.prefetch("test")
|
||||
assert result == ""
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
# Manually set the cached result
|
||||
with p._lock:
|
||||
p._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
|
||||
result = p.prefetch("test")
|
||||
assert "User likes tests" in result
|
||||
# Should be consumed
|
||||
assert p.prefetch("test") == ""
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._dialectic_result = "User is a software engineer who prefers Python."
|
||||
result = p.prefetch("test")
|
||||
assert "[RetainDB User Synthesis]" in result
|
||||
assert "software engineer" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._agent_model = {
|
||||
"memory_count": 5,
|
||||
"persona": "Helpful coding assistant",
|
||||
"persistent_instructions": ["Be concise", "Use Python"],
|
||||
"working_style": "Direct and efficient",
|
||||
}
|
||||
result = p.prefetch("test")
|
||||
assert "[RetainDB Agent Self-Model]" in result
|
||||
assert "Helpful coding assistant" in result
|
||||
assert "Be concise" in result
|
||||
assert "Direct and efficient" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._agent_model = {"memory_count": 0}
|
||||
result = p.prefetch("test")
|
||||
assert "Agent Self-Model" not in result
|
||||
p.shutdown()
|
||||
|
||||
def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
|
||||
"""Verify old prefetch threads are joined before new ones spawn."""
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
# Mock the prefetch methods to be slow
|
||||
with patch.object(p, "_prefetch_context", side_effect=lambda q: time.sleep(0.5)), \
|
||||
patch.object(p, "_prefetch_dialectic", side_effect=lambda q: time.sleep(0.5)), \
|
||||
patch.object(p, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)):
|
||||
p.queue_prefetch("query 1")
|
||||
first_threads = list(p._prefetch_threads)
|
||||
assert len(first_threads) == 3
|
||||
|
||||
# Call again — should join first batch before spawning new
|
||||
p.queue_prefetch("query 2")
|
||||
second_threads = list(p._prefetch_threads)
|
||||
assert len(second_threads) == 3
|
||||
# Should be different thread objects
|
||||
for t in second_threads:
|
||||
assert t not in first_threads
|
||||
p.shutdown()
|
||||
|
||||
def test_reasoning_level_short(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("hi") == "low"
|
||||
|
||||
def test_reasoning_level_medium(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium"
|
||||
|
||||
def test_reasoning_level_long(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# sync_turn tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestSyncTurn:
|
||||
"""Test turn synchronization via the write queue."""
|
||||
|
||||
def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._queue, "enqueue") as mock_enqueue:
|
||||
p.sync_turn("user msg", "assistant msg")
|
||||
mock_enqueue.assert_called_once()
|
||||
args = mock_enqueue.call_args[0]
|
||||
assert args[0] == "default" # user_id
|
||||
assert args[1] == "test-session" # session_id
|
||||
msgs = args[2]
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[1]["role"] == "assistant"
|
||||
p.shutdown()
|
||||
|
||||
def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._queue, "enqueue") as mock_enqueue:
|
||||
p.sync_turn("", "assistant msg")
|
||||
mock_enqueue.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# on_memory_write hook tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestOnMemoryWrite:
|
||||
"""Test the built-in memory mirror hook."""
|
||||
|
||||
def test_mirrors_add_action(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
|
||||
p.on_memory_write("add", "user", "User prefers dark mode")
|
||||
mock_add.assert_called_once()
|
||||
assert mock_add.call_args[1]["memory_type"] == "preference"
|
||||
p.shutdown()
|
||||
|
||||
def test_skips_non_add_action(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory") as mock_add:
|
||||
p.on_memory_write("remove", "user", "something")
|
||||
mock_add.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
def test_skips_empty_content(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory") as mock_add:
|
||||
p.on_memory_write("add", "user", "")
|
||||
mock_add.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
|
||||
p.on_memory_write("add", "memory", "Some env fact")
|
||||
assert mock_add.call_args[1]["memory_type"] == "factual"
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# register() test
|
||||
# ===========================================================================
|
||||
|
||||
class TestRegister:
|
||||
def test_register_calls_register_memory_provider(self):
|
||||
from plugins.memory.retaindb import register
|
||||
ctx = MagicMock()
|
||||
register(ctx)
|
||||
ctx.register_memory_provider.assert_called_once()
|
||||
arg = ctx.register_memory_provider.call_args[0][0]
|
||||
assert isinstance(arg, RetainDBMemoryProvider)
|
||||
Reference in New Issue
Block a user