diff --git a/plugins/memory/retaindb/__init__.py b/plugins/memory/retaindb/__init__.py
index d1cbec54a..94dba6153 100644
--- a/plugins/memory/retaindb/__init__.py
+++ b/plugins/memory/retaindb/__init__.py
@@ -1,29 +1,45 @@
 """RetainDB memory plugin — MemoryProvider interface.
 
-Cross-session memory via RetainDB cloud API. Durable write-behind queue,
-semantic search with deduplication, and user profile retrieval.
+Cross-session memory via RetainDB cloud API.
 
-Original PR #2732 by Alinxus, adapted to MemoryProvider ABC.
+Features:
+- Correct API routes for all operations
+- Durable SQLite write-behind queue (crash-safe, async ingest)
+- Semantic search + user profile retrieval
+- Context query with deduplication overlay
+- Dialectic synthesis (LLM-powered user understanding, prefetched each turn)
+- Agent self-model (persona + instructions from SOUL.md, prefetched each turn)
+- Shared file store tools (upload, list, read, ingest, delete)
+- Explicit memory tools (profile, search, context, remember, forget)
 
-Config via environment variables:
-    RETAINDB_API_KEY  — API key (required)
-    RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
-    RETAINDB_PROJECT  — Project identifier (default: hermes)
+Config (env vars, or hermes config.yaml under the retaindb: key):
+    RETAINDB_API_KEY  — API key (required)
+    RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
+    RETAINDB_PROJECT  — Project identifier
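+
+Example (shell; values are placeholders):
+    export RETAINDB_API_KEY="<your key>"
+    export RETAINDB_PROJECT="hermes-work"   # optional; overrides the derived default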
 """
 
 from __future__ import annotations
 
+import hashlib
 import json
 import logging
 import os
+import queue
+import re
+import sqlite3
 import threading
+import time
+from datetime import datetime, timezone
+from pathlib import Path
 from typing import Any, Dict, List
+from urllib.parse import quote
 
 from agent.memory_provider import MemoryProvider
 
 logger = logging.getLogger(__name__)
 
 _DEFAULT_BASE_URL = "https://api.retaindb.com"
+_ASYNC_SHUTDOWN = object()
 
 
 # ---------------------------------------------------------------------------
@@ -32,16 +48,13 @@ _DEFAULT_BASE_URL = "https://api.retaindb.com"
 
 PROFILE_SCHEMA = {
     "name": "retaindb_profile",
-    "description": "Get the user's stable profile — preferences, facts, and patterns.",
+    "description": "Get the user's stable profile — preferences, facts, and patterns recalled from long-term memory.",
     "parameters": {"type": "object", "properties": {}, "required": []},
 }
 
 SEARCH_SCHEMA = {
     "name": "retaindb_search",
-    "description": (
-        "Semantic search across stored memories. Returns ranked results "
-        "with relevance scores."
-    ),
+    "description": "Semantic search across stored memories. Returns ranked results with relevance scores.",
     "parameters": {
         "type": "object",
         "properties": {
@@ -54,7 +67,7 @@ SEARCH_SCHEMA = {
 
 CONTEXT_SCHEMA = {
     "name": "retaindb_context",
-    "description": "Synthesized 'what matters now' context block for the current task.",
+    "description": "Synthesized context block — what matters most for the current task, pulled from long-term memory.",
     "parameters": {
         "type": "object",
         "properties": {
@@ -66,20 +79,17 @@ CONTEXT_SCHEMA = {
 
 REMEMBER_SCHEMA = {
     "name": "retaindb_remember",
-    "description": "Persist an explicit fact or preference to long-term memory.",
+    "description": "Persist an explicit fact, preference, or decision to long-term memory.",
     "parameters": {
         "type": "object",
         "properties": {
             "content": {"type": "string", "description": "The fact to remember."},
             "memory_type": {
                 "type": "string",
-                "enum": ["preference", "fact", "decision", "context"],
-                "description": "Category (default: fact).",
-            },
-            "importance": {
-                "type": "number",
-                "description": "Importance 0-1 (default: 0.5).",
+                "enum": ["factual", "preference", "goal", "instruction", "event", "opinion"],
+                "description": "Category (default: factual).",
             },
+            "importance": {"type": "number", "description": "Importance 0-1 (default: 0.7)."},
         },
         "required": ["content"],
     },
 }
@@ -97,23 +107,359 @@ FORGET_SCHEMA = {
     },
 }
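+
+# Illustrative retaindb_remember arguments matching REMEMBER_SCHEMA above
+# (values are made up):
+#   {"content": "Prefers TypeScript for new services",
+#    "memory_type": "preference", "importance": 0.8}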
+
+FILE_UPLOAD_SCHEMA = {
+    "name": "retaindb_upload_file",
+    "description": "Upload a file to the shared RetainDB file store. Returns an rdb:// URI any agent can reference.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "local_path": {"type": "string", "description": "Local file path to upload."},
+            "remote_path": {"type": "string", "description": "Destination path, e.g. /reports/q1.pdf"},
+            "scope": {"type": "string", "enum": ["USER", "PROJECT", "ORG"], "description": "Access scope (default: PROJECT)."},
+            "ingest": {"type": "boolean", "description": "Also extract memories from file after upload (default: false)."},
+        },
+        "required": ["local_path"],
+    },
+}
+
+FILE_LIST_SCHEMA = {
+    "name": "retaindb_list_files",
+    "description": "List files in the shared file store.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "prefix": {"type": "string", "description": "Path prefix to filter by, e.g. /reports/"},
+            "limit": {"type": "integer", "description": "Max results (default: 50)."},
+        },
+        "required": [],
+    },
+}
+
+FILE_READ_SCHEMA = {
+    "name": "retaindb_read_file",
+    "description": "Read the text content of a stored file by its file ID.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "file_id": {"type": "string", "description": "File ID returned from upload or list."},
+        },
+        "required": ["file_id"],
+    },
+}
+
+FILE_INGEST_SCHEMA = {
+    "name": "retaindb_ingest_file",
+    "description": "Chunk, embed, and extract memories from a stored file. Makes its contents searchable.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "file_id": {"type": "string", "description": "File ID to ingest."},
+        },
+        "required": ["file_id"],
+    },
+}
+
+FILE_DELETE_SCHEMA = {
+    "name": "retaindb_delete_file",
+    "description": "Delete a stored file.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "file_id": {"type": "string", "description": "File ID to delete."},
+        },
+        "required": ["file_id"],
+    },
+}
+
 
 # ---------------------------------------------------------------------------
-# MemoryProvider implementation
+# HTTP client
+# ---------------------------------------------------------------------------
+
+class _Client:
+    def __init__(self, api_key: str, base_url: str, project: str):
+        self.api_key = api_key
+        self.base_url = re.sub(r"/+$", "", base_url)
+        self.project = project
+
+    def _headers(self, path: str) -> dict:
+        token = self.api_key.replace("Bearer ", "").strip()
+        h = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+            "x-sdk-runtime": "hermes-plugin",
+        }
+        if path.startswith("/v1/memory") or path.startswith("/v1/context"):
+            h["X-API-Key"] = token
+        return h
+
+    def request(self, method: str, path: str, *, params=None, json_body=None, timeout: float = 8.0) -> Any:
+        import requests
+        url = f"{self.base_url}{path}"
+        resp = requests.request(
+            method.upper(), url,
+            params=params,
+            json=json_body if method.upper() not in {"GET", "DELETE"} else None,
+            headers=self._headers(path),
+            timeout=timeout,
+        )
+        try:
+            payload = resp.json()
+        except Exception:
+            payload = resp.text
+        if not resp.ok:
+            msg = ""
+            if isinstance(payload, dict):
+                msg = str(payload.get("message") or payload.get("error") or "")
+            raise RuntimeError(f"RetainDB {method} {path} failed ({resp.status_code}): {msg or payload}")
+        return payload
+
+    # ── Memory ──────────────────────────────────────────────────────────────
+
+    def query_context(self, user_id: str, session_id: str, query: str, max_tokens: int = 1200) -> dict:
+        return self.request("POST", "/v1/context/query", json_body={
+            "project": self.project,
+            "query": query,
+            "user_id": user_id,
+            "session_id": session_id,
+            "include_memories": True,
+            "max_tokens": max_tokens,
+        })
+
+    def search(self, user_id: str, session_id: str, query: str, top_k: int = 8) -> dict:
+        return self.request("POST", "/v1/memory/search", json_body={
+            "project": self.project,
+            "query": query,
+            "user_id": user_id,
+            "session_id": session_id,
+            "top_k": top_k,
+            "include_pending": True,
+        })
+
+    def get_profile(self, user_id: str) -> dict:
+        try:
+            return self.request("GET", f"/v1/memory/profile/{quote(user_id, safe='')}", params={"project": self.project, "include_pending": "true"})
+        except Exception:
+            return self.request("GET", "/v1/memories", params={"project": self.project, "user_id": user_id, "limit": "200"})
+
+    def add_memory(self, user_id: str, session_id: str, content: str, memory_type: str = "factual", importance: float = 0.7) -> dict:
+        try:
+            return self.request("POST", "/v1/memory", json_body={
+                "project": self.project, "content": content, "memory_type": memory_type,
+                "user_id": user_id, "session_id": session_id, "importance": importance, "write_mode": "sync",
+            }, timeout=5.0)
+        except Exception:
+            return self.request("POST", "/v1/memories", json_body={
+                "project": self.project, "content": content, "memory_type": memory_type,
+                "user_id": user_id, "session_id": session_id, "importance": importance,
+            }, timeout=5.0)
+
+    def delete_memory(self, memory_id: str) -> dict:
+        try:
+            return self.request("DELETE", f"/v1/memory/{quote(memory_id, safe='')}", timeout=5.0)
+        except Exception:
+            return self.request("DELETE", f"/v1/memories/{quote(memory_id, safe='')}", timeout=5.0)
+
+    def ingest_session(self, user_id: str, session_id: str, messages: list, timeout: float = 15.0) -> dict:
+        return self.request("POST", "/v1/memory/ingest/session", json_body={
+            "project": self.project, "session_id": session_id, "user_id": user_id,
+            "messages": messages, "write_mode": "sync",
+        }, timeout=timeout)
+
+    def ask_user(self, user_id: str, query: str, reasoning_level: str = "low") -> dict:
+        return self.request("POST", f"/v1/memory/profile/{quote(user_id, safe='')}/ask", json_body={
+            "project": self.project, "query": query, "reasoning_level": reasoning_level,
+        }, timeout=8.0)
+
+    def get_agent_model(self, agent_id: str) -> dict:
+        return self.request("GET", f"/v1/memory/agent/{quote(agent_id, safe='')}/model", params={"project": self.project}, timeout=4.0)
+
+    def seed_agent_identity(self, agent_id: str, content: str, source: str = "soul_md") -> dict:
+        return self.request("POST", f"/v1/memory/agent/{quote(agent_id, safe='')}/seed", json_body={
+            "project": self.project, "content": content, "source": source,
+        }, timeout=20.0)
+
+    # ── Files ───────────────────────────────────────────────────────────────
+
+    def upload_file(self, data: bytes, filename: str, remote_path: str, mime_type: str, scope: str, project_id: str | None) -> dict:
+        import io
+        import requests
+        url = f"{self.base_url}/v1/files"
+        token = self.api_key.replace("Bearer ", "").strip()
+        headers = {"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}
+        fields = {"path": remote_path, "scope": scope.upper()}
+        if project_id:
+            fields["project_id"] = project_id
+        resp = requests.post(url, files={"file": (filename, io.BytesIO(data), mime_type)}, data=fields, headers=headers, timeout=30)
+        resp.raise_for_status()
+        return resp.json()
+
+    def list_files(self, prefix: str | None = None, limit: int = 50) -> dict:
+        params: dict = {"limit": limit}
+        if prefix:
+            params["prefix"] = prefix
+        return self.request("GET", "/v1/files", params=params)
+
+    def get_file(self, file_id: str) -> dict:
+        return self.request("GET", f"/v1/files/{quote(file_id, safe='')}")
+
+    def read_file_content(self, file_id: str) -> bytes:
+        import requests
+        token = self.api_key.replace("Bearer ", "").strip()
+        url = f"{self.base_url}/v1/files/{quote(file_id, safe='')}/content"
+        resp = requests.get(url, headers={"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}, timeout=30, allow_redirects=True)
+        resp.raise_for_status()
+        return resp.content
+
+    def ingest_file(self, file_id: str, user_id: str | None = None, agent_id: str | None = None) -> dict:
+        body: dict = {}
+        if user_id:
+            body["user_id"] = user_id
+        if agent_id:
+            body["agent_id"] = agent_id
+        return self.request("POST", f"/v1/files/{quote(file_id, safe='')}/ingest", json_body=body, timeout=60.0)
+
+    def delete_file(self, file_id: str) -> dict:
+        return self.request("DELETE", f"/v1/files/{quote(file_id, safe='')}", timeout=5.0)
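+
+# Usage sketch for _Client (illustrative IDs; real calls need a valid key and
+# network access):
+#   client = _Client(api_key, _DEFAULT_BASE_URL, project="hermes")
+#   client.add_memory("user-1", "sess-1", "Prefers canary deploys", memory_type="preference")
+#   hits = client.search("user-1", "sess-1", "deploy preferences", top_k=5)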
+
+
+# ---------------------------------------------------------------------------
+# Durable write-behind queue
+# ---------------------------------------------------------------------------
+
+class _WriteQueue:
+    """SQLite-backed async write queue. Survives crashes — pending rows replay on startup."""
+
+    def __init__(self, client: _Client, db_path: Path):
+        self._client = client
+        self._db_path = db_path
+        self._q: queue.Queue = queue.Queue()
+        self._thread = threading.Thread(target=self._loop, name="retaindb-writer", daemon=True)
+        self._db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._init_db()
+        self._thread.start()
+        # Replay any rows left from a previous crash
+        for row_id, user_id, session_id, msgs_json in self._pending_rows():
+            self._q.put((row_id, user_id, session_id, json.loads(msgs_json)))
+
+    def _connect(self) -> sqlite3.Connection:
+        conn = sqlite3.connect(str(self._db_path), timeout=30)
+        conn.row_factory = sqlite3.Row
+        return conn
+
+    def _init_db(self) -> None:
+        with self._connect() as conn:
+            conn.execute("""CREATE TABLE IF NOT EXISTS pending (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                user_id TEXT, session_id TEXT, messages_json TEXT,
+                created_at TEXT, last_error TEXT
+            )""")
+            conn.commit()
+
+    def _pending_rows(self) -> list:
+        with self._connect() as conn:
+            return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall()
+
+    def enqueue(self, user_id: str, session_id: str, messages: list) -> None:
+        now = datetime.now(timezone.utc).isoformat()
+        with self._connect() as conn:
+            cur = conn.execute(
+                "INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)",
+                (user_id, session_id, json.dumps(messages, ensure_ascii=False), now),
+            )
+            row_id = cur.lastrowid
+            conn.commit()
+        self._q.put((row_id, user_id, session_id, messages))
+
+    def _flush_row(self, row_id: int, user_id: str, session_id: str, messages: list) -> None:
+        try:
+            self._client.ingest_session(user_id, session_id, messages)
+            with self._connect() as conn:
+                conn.execute("DELETE FROM pending WHERE id = ?", (row_id,))
+                conn.commit()
+        except Exception as exc:
+            # Row stays in `pending`; it is replayed on the next startup.
+            logger.warning("RetainDB ingest failed (kept in pending; will retry on restart): %s", exc)
+            with self._connect() as conn:
+                conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id))
+                conn.commit()
+            time.sleep(2)
+
+    def _loop(self) -> None:
+        while True:
+            try:
+                item = self._q.get(timeout=5)
+                if item is _ASYNC_SHUTDOWN:
+                    break
+                self._flush_row(*item)
+            except queue.Empty:
+                continue
+            except Exception as exc:
+                logger.error("RetainDB writer error: %s", exc)
+
+    def shutdown(self) -> None:
+        self._q.put(_ASYNC_SHUTDOWN)
+        self._thread.join(timeout=10)
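+
+# Durability note: enqueue() commits each row to SQLite before handing it to
+# the in-memory queue, so a crash between enqueue() and ingest loses nothing;
+# __init__ replays whatever is still in `pending` on the next startup.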
+
+
+# ---------------------------------------------------------------------------
+# Overlay formatter
+# ---------------------------------------------------------------------------
+
+def _build_overlay(profile: dict, query_result: dict, local_entries: list[str] | None = None) -> str:
+    def _compact(s: str) -> str:
+        return re.sub(r"\s+", " ", str(s or "")).strip()[:320]
+
+    def _norm(s: str) -> str:
+        return re.sub(r"[^a-z0-9 ]", "", _compact(s).lower())
+
+    seen: list[str] = [_norm(e) for e in (local_entries or []) if _norm(e)]
+    profile_items: list[str] = []
+    for m in list((profile or {}).get("memories") or [])[:5]:
+        c = _compact((m or {}).get("content") or "")
+        n = _norm(c)
+        if c and n not in seen:
+            seen.append(n)
+            profile_items.append(c)
+
+    query_items: list[str] = []
+    for r in list((query_result or {}).get("results") or [])[:5]:
+        c = _compact((r or {}).get("content") or "")
+        n = _norm(c)
+        if c and n not in seen:
+            seen.append(n)
+            query_items.append(c)
+
+    if not profile_items and not query_items:
+        return ""
+
+    lines = ["[RetainDB Context]", "Profile:"]
+    lines += [f"- {i}" for i in profile_items] or ["- None"]
+    lines.append("Relevant memories:")
+    lines += [f"- {i}" for i in query_items] or ["- None"]
+    return "\n".join(lines)
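+
+# Example overlay (illustrative content; the exact shape produced above):
+#   [RetainDB Context]
+#   Profile:
+#   - Works mostly in Python; prefers concise answers
+#   Relevant memories:
+#   - Asked last week how file ingestion works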
+
+
+# ---------------------------------------------------------------------------
+# Main plugin class
 # ---------------------------------------------------------------------------
 
 class RetainDBMemoryProvider(MemoryProvider):
-    """RetainDB cloud memory with write-behind queue and semantic search."""
+    """RetainDB cloud memory — durable queue, semantic search, dialectic synthesis, shared files."""
 
     def __init__(self):
-        self._api_key = ""
-        self._base_url = _DEFAULT_BASE_URL
-        self._project = "hermes"
-        self._user_id = ""
-        self._prefetch_result = ""
-        self._prefetch_lock = threading.Lock()
-        self._prefetch_thread = None
-        self._sync_thread = None
+        self._client: _Client | None = None
+        self._queue: _WriteQueue | None = None
+        self._user_id = "default"
+        self._session_id = ""
+        self._agent_id = "hermes"
+        self._lock = threading.Lock()
+
+        # Prefetch caches
+        self._context_result = ""
+        self._dialectic_result = ""
+        self._agent_model: dict = {}
+
+    # ── Core identity ──────────────────────────────────────────────────────
 
     @property
     def name(self) -> str:
@@ -122,179 +468,275 @@ class RetainDBMemoryProvider(MemoryProvider):
     def is_available(self) -> bool:
        return bool(os.environ.get("RETAINDB_API_KEY"))
 
-    def get_config_schema(self):
+    def get_config_schema(self) -> List[Dict[str, Any]]:
         return [
             {"key": "api_key", "description": "RetainDB API key", "secret": True, "required": True, "env_var": "RETAINDB_API_KEY", "url": "https://retaindb.com"},
-            {"key": "base_url", "description": "API endpoint", "default": "https://api.retaindb.com"},
+            {"key": "base_url", "description": "API endpoint", "default": _DEFAULT_BASE_URL},
             {"key": "project", "description": "Project identifier", "default": "hermes"},
         ]
 
-    def _headers(self) -> dict:
-        return {
-            "Authorization": f"Bearer {self._api_key}",
-            "Content-Type": "application/json",
-        }
-
-    def _api(self, method: str, path: str, **kwargs):
-        """Make an API call to RetainDB."""
-        import requests
-        url = f"{self._base_url}{path}"
-        resp = requests.request(method, url, headers=self._headers(), timeout=30, **kwargs)
-        resp.raise_for_status()
-        return resp.json()
+    # ── Lifecycle ──────────────────────────────────────────────────────────
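+
+    # Project derivation examples (illustrative paths):
+    #   RETAINDB_PROJECT=acme        -> project "acme"
+    #   hermes_home=~/.hermes        -> project "hermes"
+    #   hermes_home=~/profiles/work  -> project "hermes-work"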
 
     def initialize(self, session_id: str, **kwargs) -> None:
-        self._api_key = os.environ.get("RETAINDB_API_KEY", "")
-        self._base_url = os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL)
-        self._user_id = kwargs.get("user_id", "default")
-        self._session_id = session_id
+        api_key = os.environ.get("RETAINDB_API_KEY", "")
+        base_url = re.sub(r"/+$", "", os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL))
 
-        # Derive profile-scoped project name so different profiles don't
-        # share server-side memory. Explicit RETAINDB_PROJECT always wins.
-        explicit_project = os.environ.get("RETAINDB_PROJECT")
-        if explicit_project:
-            self._project = explicit_project
+        # Profile-isolated project: RETAINDB_PROJECT > hermes-<profile> > hermes
+        explicit = os.environ.get("RETAINDB_PROJECT")
+        if explicit:
+            project = explicit
         else:
-            hermes_home = kwargs.get("hermes_home", "")
+            hermes_home = str(kwargs.get("hermes_home", ""))
             profile_name = os.path.basename(hermes_home) if hermes_home else ""
-            # Default profile (~/.hermes) → "hermes"; named profiles → "hermes-<profile>"
-            if profile_name and profile_name != ".hermes":
-                self._project = f"hermes-{profile_name}"
-            else:
-                self._project = "hermes"
+            project = f"hermes-{profile_name}" if (profile_name and profile_name != ".hermes") else "hermes"
+
+        self._client = _Client(api_key, base_url, project)
+        self._session_id = session_id
+        self._user_id = kwargs.get("user_id", "default") or "default"
+        self._agent_id = kwargs.get("agent_id", "hermes") or "hermes"
+
+        hermes_home_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
+        db_path = hermes_home_path / "retaindb_queue.db"
+        self._queue = _WriteQueue(self._client, db_path)
+
+        # Seed agent identity from SOUL.md in background
+        soul_path = hermes_home_path / "SOUL.md"
+        if soul_path.exists():
+            soul_content = soul_path.read_text(encoding="utf-8", errors="replace").strip()
+            if soul_content:
+                threading.Thread(
+                    target=self._seed_soul,
+                    args=(soul_content,),
+                    name="retaindb-soul-seed",
+                    daemon=True,
+                ).start()
+
+    def _seed_soul(self, content: str) -> None:
+        try:
+            self._client.seed_agent_identity(self._agent_id, content, source="soul_md")
+        except Exception as exc:
+            logger.debug("RetainDB soul seed failed: %s", exc)
 
     def system_prompt_block(self) -> str:
+        project = self._client.project if self._client else "retaindb"
         return (
             "# RetainDB Memory\n"
-            f"Active. Project: {self._project}.\n"
+            f"Active. Project: {project}.\n"
             "Use retaindb_search to find memories, retaindb_remember to store facts, "
-            "retaindb_profile for a user overview, retaindb_context for task-relevant context."
+            "retaindb_profile for a user overview, retaindb_context for current-task context."
         )
 
-    def prefetch(self, query: str, *, session_id: str = "") -> str:
-        if self._prefetch_thread and self._prefetch_thread.is_alive():
-            self._prefetch_thread.join(timeout=3.0)
-        with self._prefetch_lock:
-            result = self._prefetch_result
-            self._prefetch_result = ""
-        if not result:
-            return ""
-        return f"## RetainDB Memory\n{result}"
+    # ── Background prefetch (fires at turn-end, consumed next turn-start) ──
 
     def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
-        def _run():
-            try:
-                data = self._api("POST", "/v1/recall", json={
-                    "project": self._project,
-                    "query": query,
-                    "user_id": self._user_id,
-                    "top_k": 5,
-                })
-                results = data.get("results", [])
-                if results:
-                    lines = [r.get("content", "") for r in results if r.get("content")]
-                    with self._prefetch_lock:
-                        self._prefetch_result = "\n".join(f"- {l}" for l in lines)
-            except Exception as e:
-                logger.debug("RetainDB prefetch failed: %s", e)
+        """Fire context + dialectic + agent model prefetches in background."""
+        if not self._client:
+            return
+        threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True).start()
+        threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True).start()
+        threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True).start()
 
-        self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="retaindb-prefetch")
-        self._prefetch_thread.start()
+    def _prefetch_context(self, query: str) -> None:
+        try:
+            query_result = self._client.query_context(self._user_id, self._session_id, query)
+            profile = self._client.get_profile(self._user_id)
+            overlay = _build_overlay(profile, query_result)
+            with self._lock:
+                self._context_result = overlay
+        except Exception as exc:
+            logger.debug("RetainDB context prefetch failed: %s", exc)
+
+    def _prefetch_dialectic(self, query: str) -> None:
+        try:
+            result = self._client.ask_user(self._user_id, query, reasoning_level=self._reasoning_level(query))
+            answer = str(result.get("answer") or "")
+            if answer:
+                with self._lock:
+                    self._dialectic_result = answer
+        except Exception as exc:
+            logger.debug("RetainDB dialectic prefetch failed: %s", exc)
+
+    def _prefetch_agent_model(self) -> None:
+        try:
+            model = self._client.get_agent_model(self._agent_id)
+            if model.get("memory_count", 0) > 0:
+                with self._lock:
+                    self._agent_model = model
+        except Exception as exc:
+            logger.debug("RetainDB agent model prefetch failed: %s", exc)
+
+    @staticmethod
+    def _reasoning_level(query: str) -> str:
+        n = len(query)
+        if n < 120:
+            return "low"
+        if n < 400:
+            return "medium"
+        return "high"
+
+    def prefetch(self, query: str, *, session_id: str = "") -> str:
+        """Consume prefetched results and return them as a context block."""
+        with self._lock:
+            context = self._context_result
+            dialectic = self._dialectic_result
+            agent_model = self._agent_model
+            self._context_result = ""
+            self._dialectic_result = ""
+            self._agent_model = {}
+
+        parts: list[str] = []
+        if context:
+            parts.append(context)
+        if dialectic:
+            parts.append(f"[RetainDB User Synthesis]\n{dialectic}")
+        if agent_model and agent_model.get("memory_count", 0) > 0:
+            model_lines: list[str] = []
+            if agent_model.get("persona"):
+                model_lines.append(f"Persona: {agent_model['persona']}")
+            if agent_model.get("persistent_instructions"):
+                model_lines.append("Instructions:\n" + "\n".join(f"- {i}" for i in agent_model["persistent_instructions"]))
+            if agent_model.get("working_style"):
+                model_lines.append(f"Working style: {agent_model['working_style']}")
+            if model_lines:
+                parts.append("[RetainDB Agent Self-Model]\n" + "\n".join(model_lines))
+
+        return "\n\n".join(parts)
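+
+    # Turn lifecycle sketch (illustrative; the agent loop drives the calls):
+    #   provider.queue_prefetch(q)    # turn end: spawns the three threads
+    #   block = provider.prefetch(q)  # next turn start: consumes cached parts
+    #   # block is "" when nothing was prefetched in time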
model_lines.append(f"Working style: {agent_model['working_style']}") + if model_lines: + parts.append("[RetainDB Agent Self-Model]\n" + "\n".join(model_lines)) + + return "\n\n".join(parts) + + # ── Turn sync ────────────────────────────────────────────────────────── def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Ingest conversation turn in background (non-blocking).""" - def _sync(): - try: - self._api("POST", "/v1/ingest", json={ - "project": self._project, - "user_id": self._user_id, - "session_id": self._session_id, - "messages": [ - {"role": "user", "content": user_content}, - {"role": "assistant", "content": assistant_content}, - ], - }) - except Exception as e: - logger.warning("RetainDB sync failed: %s", e) + """Queue turn for async ingest. Returns immediately.""" + if not self._queue or not user_content: + return + now = datetime.now(timezone.utc).isoformat() + self._queue.enqueue( + self._user_id, + session_id or self._session_id, + [ + {"role": "user", "content": user_content, "timestamp": now}, + {"role": "assistant", "content": assistant_content, "timestamp": now}, + ], + ) - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=5.0) - self._sync_thread = threading.Thread(target=_sync, daemon=True, name="retaindb-sync") - self._sync_thread.start() + # ── Tools ────────────────────────────────────────────────────────────── def get_tool_schemas(self) -> List[Dict[str, Any]]: - return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, REMEMBER_SCHEMA, FORGET_SCHEMA] + return [ + PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, + REMEMBER_SCHEMA, FORGET_SCHEMA, + FILE_UPLOAD_SCHEMA, FILE_LIST_SCHEMA, FILE_READ_SCHEMA, + FILE_INGEST_SCHEMA, FILE_DELETE_SCHEMA, + ] def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str: + if not self._client: + return json.dumps({"error": "RetainDB not initialized"}) try: - if tool_name == "retaindb_profile": - data = self._api("GET", f"/v1/profile/{self._project}/{self._user_id}") - return json.dumps(data) + return json.dumps(self._dispatch(tool_name, args)) + except Exception as exc: + return json.dumps({"error": str(exc)}) - elif tool_name == "retaindb_search": - query = args.get("query", "") - if not query: - return json.dumps({"error": "query is required"}) - data = self._api("POST", "/v1/search", json={ - "project": self._project, - "user_id": self._user_id, - "query": query, - "top_k": min(int(args.get("top_k", 8)), 20), - }) - return json.dumps(data) + def _dispatch(self, tool_name: str, args: dict) -> Any: + c = self._client - elif tool_name == "retaindb_context": - query = args.get("query", "") - if not query: - return json.dumps({"error": "query is required"}) - data = self._api("POST", "/v1/recall", json={ - "project": self._project, - "user_id": self._user_id, - "query": query, - "top_k": 5, - }) - return json.dumps(data) + if tool_name == "retaindb_profile": + return c.get_profile(self._user_id) - elif tool_name == "retaindb_remember": - content = args.get("content", "") - if not content: - return json.dumps({"error": "content is required"}) - data = self._api("POST", "/v1/remember", json={ - "project": self._project, - "user_id": self._user_id, - "content": content, - "memory_type": args.get("memory_type", "fact"), - "importance": float(args.get("importance", 0.5)), - }) - return json.dumps(data) + if tool_name == "retaindb_search": + query = args.get("query", "") + if not query: + return {"error": "query is required"} + return 
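+
+    # One queued row's messages_json looks like this (illustrative timestamps):
+    #   [{"role": "user", "content": "...", "timestamp": "2025-01-01T00:00:00+00:00"},
+    #    {"role": "assistant", "content": "...", "timestamp": "2025-01-01T00:00:00+00:00"}]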
+    # ── Tools ──────────────────────────────────────────────────────────────
 
     def get_tool_schemas(self) -> List[Dict[str, Any]]:
-        return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, REMEMBER_SCHEMA, FORGET_SCHEMA]
+        return [
+            PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA,
+            REMEMBER_SCHEMA, FORGET_SCHEMA,
+            FILE_UPLOAD_SCHEMA, FILE_LIST_SCHEMA, FILE_READ_SCHEMA,
+            FILE_INGEST_SCHEMA, FILE_DELETE_SCHEMA,
+        ]
 
     def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
+        if not self._client:
+            return json.dumps({"error": "RetainDB not initialized"})
         try:
-            if tool_name == "retaindb_profile":
-                data = self._api("GET", f"/v1/profile/{self._project}/{self._user_id}")
-                return json.dumps(data)
+            return json.dumps(self._dispatch(tool_name, args))
+        except Exception as exc:
+            return json.dumps({"error": str(exc)})
 
-            elif tool_name == "retaindb_search":
-                query = args.get("query", "")
-                if not query:
-                    return json.dumps({"error": "query is required"})
-                data = self._api("POST", "/v1/search", json={
-                    "project": self._project,
-                    "user_id": self._user_id,
-                    "query": query,
-                    "top_k": min(int(args.get("top_k", 8)), 20),
-                })
-                return json.dumps(data)
+    def _dispatch(self, tool_name: str, args: dict) -> Any:
+        c = self._client
 
-            elif tool_name == "retaindb_context":
-                query = args.get("query", "")
-                if not query:
-                    return json.dumps({"error": "query is required"})
-                data = self._api("POST", "/v1/recall", json={
-                    "project": self._project,
-                    "user_id": self._user_id,
-                    "query": query,
-                    "top_k": 5,
-                })
-                return json.dumps(data)
+        if tool_name == "retaindb_profile":
+            return c.get_profile(self._user_id)
 
-            elif tool_name == "retaindb_remember":
-                content = args.get("content", "")
-                if not content:
-                    return json.dumps({"error": "content is required"})
-                data = self._api("POST", "/v1/remember", json={
-                    "project": self._project,
-                    "user_id": self._user_id,
-                    "content": content,
-                    "memory_type": args.get("memory_type", "fact"),
-                    "importance": float(args.get("importance", 0.5)),
-                })
-                return json.dumps(data)
+        if tool_name == "retaindb_search":
+            query = args.get("query", "")
+            if not query:
+                return {"error": "query is required"}
+            return c.search(self._user_id, self._session_id, query, top_k=min(int(args.get("top_k", 8)), 20))
 
-            elif tool_name == "retaindb_forget":
-                memory_id = args.get("memory_id", "")
-                if not memory_id:
-                    return json.dumps({"error": "memory_id is required"})
-                data = self._api("DELETE", f"/v1/memory/{memory_id}")
-                return json.dumps(data)
+        if tool_name == "retaindb_context":
+            query = args.get("query", "")
+            if not query:
+                return {"error": "query is required"}
+            query_result = c.query_context(self._user_id, self._session_id, query)
+            profile = c.get_profile(self._user_id)
+            overlay = _build_overlay(profile, query_result)
+            return {"context": overlay, "raw": query_result}
 
-            return json.dumps({"error": f"Unknown tool: {tool_name}"})
-        except Exception as e:
-            return json.dumps({"error": str(e)})
+        if tool_name == "retaindb_remember":
+            content = args.get("content", "")
+            if not content:
+                return {"error": "content is required"}
+            return c.add_memory(
+                self._user_id, self._session_id, content,
+                memory_type=args.get("memory_type", "factual"),
+                importance=float(args.get("importance", 0.7)),
+            )
+
+        if tool_name == "retaindb_forget":
+            memory_id = args.get("memory_id", "")
+            if not memory_id:
+                return {"error": "memory_id is required"}
+            return c.delete_memory(memory_id)
+
+        # ── File tools ──────────────────────────────────────────────────────
+
+        if tool_name == "retaindb_upload_file":
+            local_path = args.get("local_path", "")
+            if not local_path:
+                return {"error": "local_path is required"}
+            path_obj = Path(local_path)
+            if not path_obj.exists():
+                return {"error": f"File not found: {local_path}"}
+            data = path_obj.read_bytes()
+            import mimetypes
+            mime = mimetypes.guess_type(path_obj.name)[0] or "application/octet-stream"
+            remote_path = args.get("remote_path") or f"/{path_obj.name}"
+            result = c.upload_file(data, path_obj.name, remote_path, mime, args.get("scope", "PROJECT"), None)
+            if args.get("ingest") and result.get("file", {}).get("id"):
+                ingest = c.ingest_file(result["file"]["id"], user_id=self._user_id, agent_id=self._agent_id)
+                result["ingest"] = ingest
+            return result
+
+        if tool_name == "retaindb_list_files":
+            return c.list_files(prefix=args.get("prefix"), limit=int(args.get("limit", 50)))
+
+        if tool_name == "retaindb_read_file":
+            file_id = args.get("file_id", "")
+            if not file_id:
+                return {"error": "file_id is required"}
+            meta = c.get_file(file_id)
+            file_info = meta.get("file") or {}
+            mime = (file_info.get("mime_type") or "").lower()
+            is_text = mime.startswith("text/") or any(file_info.get("name", "").endswith(e) for e in (".txt", ".md", ".json", ".csv", ".yaml", ".yml", ".xml", ".html"))
+            if not is_text:
+                # Skip the content download entirely for binary files.
+                return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": None, "note": "Binary file — use retaindb_ingest_file to extract text into memory."}
+            raw = c.read_file_content(file_id)
+            text = raw.decode("utf-8", errors="replace")
+            return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": text[:32000], "truncated": len(text) > 32000}
+
+        if tool_name == "retaindb_ingest_file":
+            file_id = args.get("file_id", "")
+            if not file_id:
+                return {"error": "file_id is required"}
+            return c.ingest_file(file_id, user_id=self._user_id, agent_id=self._agent_id)
+
+        if tool_name == "retaindb_delete_file":
+            file_id = args.get("file_id", "")
+            if not file_id:
+                return {"error": "file_id is required"}
+            return c.delete_file(file_id)
+
+        return {"error": f"Unknown tool: {tool_name}"}
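+
+    # Shared-file flow sketch (illustrative arguments and IDs):
+    #   retaindb_upload_file {"local_path": "./q1.pdf", "ingest": true}
+    #       -> {"file": {"id": "...", "rdb_uri": "rdb://..."}, "ingest": {...}}
+    #   retaindb_read_file   {"file_id": "..."}  -> text content or a binary note
+    #   retaindb_ingest_file {"file_id": "..."}  -> contents become searchable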
+
+    # ── Optional hooks ──────────────────────────────────────────────────────
 
     def on_memory_write(self, action: str, target: str, content: str) -> None:
-        if action == "add":
-            try:
-                self._api("POST", "/v1/remember", json={
-                    "project": self._project,
-                    "user_id": self._user_id,
-                    "content": content,
-                    "memory_type": "preference" if target == "user" else "fact",
-                })
-            except Exception as e:
-                logger.debug("RetainDB memory bridge failed: %s", e)
+        """Mirror built-in memory writes to RetainDB."""
+        if action != "add" or not content or not self._client:
+            return
+        try:
+            memory_type = "preference" if target == "user" else "factual"
+            self._client.add_memory(self._user_id, self._session_id, content, memory_type=memory_type)
+        except Exception as exc:
+            logger.debug("RetainDB memory mirror failed: %s", exc)
 
     def shutdown(self) -> None:
-        for t in (self._prefetch_thread, self._sync_thread):
-            if t and t.is_alive():
-                t.join(timeout=5.0)
+        if self._queue:
+            self._queue.shutdown()
 
 
 def register(ctx) -> None: