Compare commits
1 Commits
fix/799
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75e31bee27 |
58
agent/atlas/__init__.py
Normal file
58
agent/atlas/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""ATLAS — Lossless Context + Memory Subsystem for Hermes.
|
||||
|
||||
ATLAS (Adaptive Turn-Lineage Archival System) replaces destructive context
|
||||
truncation with a three-tier store and summary DAG compaction.
|
||||
|
||||
Design goals (from hermes-lcm + gbrain):
|
||||
- **Lossless**: every turn is persisted with a stable lineage ID before it
|
||||
can ever be evicted from the active context window.
|
||||
- **DAG compaction**: when the context window fills, ATLAS produces a summary
|
||||
DAG node that references its source turns — history is never deleted.
|
||||
- **Three explicit stores**: writes are routed to the correct bucket:
|
||||
- ``world_knowledge``: static facts about the world / domain
|
||||
- ``durable_memory``: user preferences, corrections, long-lived facts
|
||||
- ``session_state``: current-session working notes (ephemeral)
|
||||
- **Typed links**: deterministic relation extraction on every write.
|
||||
- **Recall tools**: the agent can call ``atlas_search``, ``atlas_describe``,
|
||||
and ``atlas_expand`` to reach compacted history without re-injecting
|
||||
the full original transcript.
|
||||
|
||||
Submodules
|
||||
----------
|
||||
db — shared SQLite connection and schema bootstrap
|
||||
turns — RawTurnStore: immutable per-turn records
|
||||
dag — SummaryDAGStore: summary nodes with source references
|
||||
stores — WorldKnowledgeStore, DurableMemoryStore, SessionStateStore
|
||||
extractor — TypedLinkExtractor: deterministic typed-link extraction
|
||||
recall — RecallEngine: search / describe / expand operations
|
||||
"""
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
from agent.atlas.turns import RawTurnStore, TurnRecord
|
||||
from agent.atlas.dag import SummaryDAGStore, SummaryNode
|
||||
from agent.atlas.stores import (
|
||||
WorldKnowledgeStore,
|
||||
DurableMemoryStore,
|
||||
SessionStateStore,
|
||||
AtlasStores,
|
||||
StoreTarget,
|
||||
)
|
||||
from agent.atlas.extractor import TypedLinkExtractor, TypedLink, RelationType
|
||||
from agent.atlas.recall import RecallEngine
|
||||
|
||||
__all__ = [
|
||||
"AtlasDB",
|
||||
"RawTurnStore",
|
||||
"TurnRecord",
|
||||
"SummaryDAGStore",
|
||||
"SummaryNode",
|
||||
"WorldKnowledgeStore",
|
||||
"DurableMemoryStore",
|
||||
"SessionStateStore",
|
||||
"AtlasStores",
|
||||
"StoreTarget",
|
||||
"TypedLinkExtractor",
|
||||
"TypedLink",
|
||||
"RelationType",
|
||||
"RecallEngine",
|
||||
]
|
||||
134
agent/atlas/dag.py
Normal file
134
agent/atlas/dag.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""SummaryDAGStore — summary DAG nodes produced by ATLAS compaction.
|
||||
|
||||
When the active context window fills, a compaction event creates a
|
||||
``SummaryNode`` that references the source turns it summarised. Source turns
|
||||
are **never deleted**; they remain in ``raw_turns`` and are reachable via
|
||||
``atlas_expand``.
|
||||
|
||||
Node ID format: ``dag:<session_id>:<index:04d>``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummaryNode:
|
||||
"""A summary DAG node.
|
||||
|
||||
``source_turn_ids`` lists every turn that was compacted into this node.
|
||||
``parent_node_id`` chains summaries-of-summaries for multi-level compaction.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
session_id: str
|
||||
summary_text: str
|
||||
source_turn_ids: List[str]
|
||||
parent_node_id: Optional[str] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
token_count: int = 0
|
||||
|
||||
def source_count(self) -> int:
|
||||
return len(self.source_turn_ids)
|
||||
|
||||
|
||||
class SummaryDAGStore:
|
||||
"""Append-only store for summary DAG nodes."""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Write
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
session_id: str,
|
||||
summary_text: str,
|
||||
source_turn_ids: List[str],
|
||||
*,
|
||||
parent_node_id: Optional[str] = None,
|
||||
) -> SummaryNode:
|
||||
"""Persist a new summary node and return it."""
|
||||
con = self._db.conn()
|
||||
row = con.execute(
|
||||
"SELECT COALESCE(MAX(CAST(SUBSTR(node_id, LENGTH(?) + 2) AS INTEGER)), -1) "
|
||||
"FROM summary_dag WHERE session_id = ?",
|
||||
(f"dag:{session_id}", session_id),
|
||||
).fetchone()
|
||||
next_idx: int = (row[0] if row and row[0] is not None else -1) + 1
|
||||
node_id = f"dag:{session_id}:{next_idx:04d}"
|
||||
ts = time.time()
|
||||
token_count = max(1, len(summary_text) // 4)
|
||||
source_json = json.dumps(source_turn_ids)
|
||||
con.execute(
|
||||
"""
|
||||
INSERT INTO summary_dag
|
||||
(node_id, session_id, summary_text, source_turn_ids,
|
||||
parent_node_id, created_at, token_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(node_id, session_id, summary_text, source_json,
|
||||
parent_node_id, ts, token_count),
|
||||
)
|
||||
con.commit()
|
||||
return SummaryNode(
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
summary_text=summary_text,
|
||||
source_turn_ids=source_turn_ids,
|
||||
parent_node_id=parent_node_id,
|
||||
created_at=ts,
|
||||
token_count=token_count,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Read
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[SummaryNode]:
|
||||
"""Fetch a summary node by ID."""
|
||||
row = self._db.conn().execute(
|
||||
"SELECT * FROM summary_dag WHERE node_id = ?", (node_id,)
|
||||
).fetchone()
|
||||
return _row_to_node(row) if row else None
|
||||
|
||||
def get_session_nodes(self, session_id: str) -> List[SummaryNode]:
|
||||
"""Return all summary nodes for a session, oldest first."""
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT * FROM summary_dag WHERE session_id = ? ORDER BY created_at ASC",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
return [_row_to_node(r) for r in rows]
|
||||
|
||||
def get_latest_node(self, session_id: str) -> Optional[SummaryNode]:
|
||||
"""Return the most-recent summary node for a session."""
|
||||
row = self._db.conn().execute(
|
||||
"SELECT * FROM summary_dag WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
return _row_to_node(row) if row else None
|
||||
|
||||
def count_session_nodes(self, session_id: str) -> int:
|
||||
"""Count summary nodes for a session."""
|
||||
row = self._db.conn().execute(
|
||||
"SELECT COUNT(*) FROM summary_dag WHERE session_id = ?", (session_id,)
|
||||
).fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
def _row_to_node(row: object) -> SummaryNode:
|
||||
return SummaryNode(
|
||||
node_id=row["node_id"],
|
||||
session_id=row["session_id"],
|
||||
summary_text=row["summary_text"],
|
||||
source_turn_ids=json.loads(row["source_turn_ids"]),
|
||||
parent_node_id=row["parent_node_id"],
|
||||
created_at=row["created_at"],
|
||||
token_count=row["token_count"],
|
||||
)
|
||||
210
agent/atlas/db.py
Normal file
210
agent/atlas/db.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Shared SQLite connection and schema bootstrap for ATLAS.
|
||||
|
||||
Single SQLite database at ``HERMES_HOME/atlas.db`` (WAL mode, one writer).
|
||||
All ATLAS submodules obtain a connection via ``AtlasDB``.
|
||||
|
||||
Schema version is tracked in ``schema_version``. Migrations run on open.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SCHEMA_VERSION = 1
|
||||
|
||||
_SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER NOT NULL
|
||||
);
|
||||
|
||||
-- Immutable raw turn records (never deleted).
|
||||
CREATE TABLE IF NOT EXISTS raw_turns (
|
||||
turn_id TEXT PRIMARY KEY, -- "<session_id>:<index:04d>"
|
||||
session_id TEXT NOT NULL,
|
||||
turn_index INTEGER NOT NULL,
|
||||
role TEXT NOT NULL, -- user / assistant / tool
|
||||
content TEXT NOT NULL,
|
||||
tool_name TEXT,
|
||||
tool_call_id TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(session_id, turn_index)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_raw_turns_session ON raw_turns(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_raw_turns_ts ON raw_turns(timestamp);
|
||||
|
||||
-- Summary DAG nodes produced by compaction.
|
||||
-- A node references the source turns it summarised; turns are never deleted.
|
||||
CREATE TABLE IF NOT EXISTS summary_dag (
|
||||
node_id TEXT PRIMARY KEY, -- "dag:<session_id>:<index:04d>"
|
||||
session_id TEXT NOT NULL,
|
||||
summary_text TEXT NOT NULL,
|
||||
source_turn_ids TEXT NOT NULL, -- JSON array of turn_ids
|
||||
parent_node_id TEXT, -- chained compactions
|
||||
created_at REAL NOT NULL,
|
||||
token_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_dag_session ON summary_dag(session_id);
|
||||
|
||||
-- Typed relational links extracted from turns / nodes.
|
||||
CREATE TABLE IF NOT EXISTS typed_links (
|
||||
link_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_id TEXT NOT NULL, -- turn_id or node_id
|
||||
source_type TEXT NOT NULL, -- "turn" or "dag"
|
||||
relation_type TEXT NOT NULL, -- see RelationType enum
|
||||
subject TEXT NOT NULL,
|
||||
object TEXT NOT NULL,
|
||||
confidence REAL NOT NULL DEFAULT 1.0,
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_links_source ON typed_links(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_links_rel ON typed_links(relation_type);
|
||||
|
||||
-- World knowledge store: static/universal facts.
|
||||
CREATE TABLE IF NOT EXISTS world_knowledge (
|
||||
wk_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
content TEXT NOT NULL,
|
||||
tags TEXT NOT NULL DEFAULT '',
|
||||
trust REAL NOT NULL DEFAULT 0.8,
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_wk_tags ON world_knowledge(tags);
|
||||
|
||||
-- Durable memory store: user-specific facts that outlive sessions.
|
||||
CREATE TABLE IF NOT EXISTS durable_memory (
|
||||
dm_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
category TEXT NOT NULL DEFAULT 'general',
|
||||
content TEXT NOT NULL,
|
||||
tags TEXT NOT NULL DEFAULT '',
|
||||
trust REAL NOT NULL DEFAULT 0.7,
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_dm_category ON durable_memory(category);
|
||||
|
||||
-- Session state store: ephemeral working notes for the current session.
|
||||
CREATE TABLE IF NOT EXISTS session_state (
|
||||
ss_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL,
|
||||
UNIQUE(session_id, key)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_ss_session ON session_state(session_id);
|
||||
|
||||
-- FTS5 search across turns, dag nodes, and durable memory.
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS atlas_fts USING fts5(
|
||||
doc_id,
|
||||
doc_type, -- raw_turn | dag | durable | world
|
||||
content,
|
||||
tokenize = 'porter unicode61'
|
||||
);
|
||||
"""
|
||||
|
||||
_TRIGGER_SQL = """
|
||||
CREATE TRIGGER IF NOT EXISTS atlas_fts_raw_turns_ai
|
||||
AFTER INSERT ON raw_turns BEGIN
|
||||
INSERT INTO atlas_fts(doc_id, doc_type, content)
|
||||
VALUES (new.turn_id, 'raw_turn', new.content);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS atlas_fts_dag_ai
|
||||
AFTER INSERT ON summary_dag BEGIN
|
||||
INSERT INTO atlas_fts(doc_id, doc_type, content)
|
||||
VALUES (new.node_id, 'dag', new.summary_text);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS atlas_fts_durable_ai
|
||||
AFTER INSERT ON durable_memory BEGIN
|
||||
INSERT INTO atlas_fts(doc_id, doc_type, content)
|
||||
VALUES (CAST(new.dm_id AS TEXT), 'durable', new.content);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS atlas_fts_world_ai
|
||||
AFTER INSERT ON world_knowledge BEGIN
|
||||
INSERT INTO atlas_fts(doc_id, doc_type, content)
|
||||
VALUES (CAST(new.wk_id AS TEXT), 'world', new.content);
|
||||
END;
|
||||
"""
|
||||
|
||||
|
||||
class AtlasDB:
|
||||
"""Thread-safe SQLite connection manager for the ATLAS subsystem.
|
||||
|
||||
Usage::
|
||||
|
||||
db = AtlasDB(Path("~/.hermes/atlas.db").expanduser())
|
||||
with db.conn() as con:
|
||||
con.execute("SELECT * FROM raw_turns LIMIT 10")
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
self._local = threading.local()
|
||||
self._init_lock = threading.Lock()
|
||||
self._initialised = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public
|
||||
|
||||
def open(self) -> None:
|
||||
"""Bootstrap schema (idempotent). Call once at startup."""
|
||||
with self._init_lock:
|
||||
if self._initialised:
|
||||
return
|
||||
con = self._get_con()
|
||||
con.executescript(_SCHEMA_SQL)
|
||||
con.executescript(_TRIGGER_SQL)
|
||||
self._migrate(con)
|
||||
con.commit()
|
||||
self._initialised = True
|
||||
logger.debug("ATLAS DB opened at %s", self._path)
|
||||
|
||||
def conn(self) -> sqlite3.Connection:
|
||||
"""Return a per-thread connection (autocommit context manager)."""
|
||||
if not self._initialised:
|
||||
self.open()
|
||||
return self._get_con()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close per-thread connection if open."""
|
||||
con: Optional[sqlite3.Connection] = getattr(self._local, "con", None)
|
||||
if con is not None:
|
||||
con.close()
|
||||
self._local.con = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal
|
||||
|
||||
def _get_con(self) -> sqlite3.Connection:
|
||||
con: Optional[sqlite3.Connection] = getattr(self._local, "con", None)
|
||||
if con is None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
con = sqlite3.connect(str(self._path), check_same_thread=False)
|
||||
con.row_factory = sqlite3.Row
|
||||
con.execute("PRAGMA journal_mode=WAL")
|
||||
con.execute("PRAGMA foreign_keys=ON")
|
||||
self._local.con = con
|
||||
return con
|
||||
|
||||
def _migrate(self, con: sqlite3.Connection) -> None:
|
||||
row = con.execute("SELECT MAX(version) FROM schema_version").fetchone()
|
||||
current = row[0] if row and row[0] is not None else 0
|
||||
if current < _SCHEMA_VERSION:
|
||||
con.execute("INSERT INTO schema_version VALUES (?)", (_SCHEMA_VERSION,))
|
||||
logger.debug("ATLAS schema migrated %d -> %d", current, _SCHEMA_VERSION)
|
||||
286
agent/atlas/extractor.py
Normal file
286
agent/atlas/extractor.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Deterministic typed-link extraction for ATLAS.
|
||||
|
||||
Extracts relational triples (subject, relation, object) from text using
|
||||
only regex and heuristics — no LLM required. Results are reproducible
|
||||
and testable with fixtures.
|
||||
|
||||
Supported relation types (≥5 required by acceptance criteria):
|
||||
|
||||
mentions Entity appears in text (e.g. proper nouns, quoted terms)
|
||||
defines "X is a Y" / "X means Y" / "X is defined as Y"
|
||||
corrects "actually / no / correction: X not Y"
|
||||
prefers "I prefer X" / "prefer X over Y" / "use X instead"
|
||||
uses "uses X" / "runs X" / "built with X"
|
||||
depends_on "requires X" / "depends on X" / "needs X"
|
||||
part_of "X is part of Y" / "X belongs to Y"
|
||||
|
||||
All extraction functions are pure: ``extract_links(text, source_id)``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
|
||||
|
||||
class RelationType(str, Enum):
|
||||
MENTIONS = "mentions"
|
||||
DEFINES = "defines"
|
||||
CORRECTS = "corrects"
|
||||
PREFERS = "prefers"
|
||||
USES = "uses"
|
||||
DEPENDS_ON = "depends_on"
|
||||
PART_OF = "part_of"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypedLink:
|
||||
source_id: str # turn_id or node_id
|
||||
source_type: str # "turn" or "dag"
|
||||
relation_type: RelationType
|
||||
subject: str
|
||||
object: str
|
||||
confidence: float = 1.0
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Named entities: Capitalized words (2+ chars), quoted phrases
|
||||
_RE_QUOTED = re.compile(r'"([^"]{2,80})"')
|
||||
_RE_SINGLE_QUOTED = re.compile(r"'([^']{2,40})'")
|
||||
_RE_CAMEL = re.compile(r'\b([A-Z][a-zA-Z0-9]{1,}(?:\s+[A-Z][a-zA-Z0-9]{1,})*)\b')
|
||||
|
||||
# defines: "X is a Y", "X means Y", "X is defined as Y", "define X as Y"
|
||||
_RE_DEFINES = re.compile(
|
||||
r'\b(\w[\w\s]{1,40}?)\s+(?:is\s+(?:a|an|the)\s+|means?\s+|is\s+defined\s+as\s+)(\w[\w\s]{1,40})',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# corrects: "actually, X", "correction: X", "not Y but X"
|
||||
# Allow optional comma/colon between trigger and content
|
||||
_RE_CORRECTS = re.compile(
|
||||
r'(?:actually|no[,\s]|wait[,\s]|correction[:\s]|that.?s\s+wrong)[,\s]+(.{5,120})',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_RE_NOT_BUT = re.compile(
|
||||
r"\b([\w\s]{2,40})\s+(?:not|isn't|is\s+not)\s+([\w\s]{2,40})\s+but\s+([\w\s]{2,40})",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# prefers: "I prefer X over Y" (two patterns to avoid lazy-quantifier under-match)
|
||||
# Pattern 1: "prefer X over Y" — greedy up to " over"
|
||||
_RE_PREFERS_OVER = re.compile(
|
||||
r'(?:I\s+)?(?:prefer|would\s+rather\s+use|rather\s+use|like\s+to\s+use)\s+'
|
||||
r'([\w][\w\s\-\.]+?)\s+over\s+([\w][\w\s\-\.]{1,50})(?:[.,;!?]|$)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Pattern 2: "prefer X" (no "over") — stop at punctuation or end of sentence
|
||||
_RE_PREFERS_SIMPLE = re.compile(
|
||||
r'(?:I\s+)?(?:prefer|would\s+rather\s+use|rather\s+use|like\s+to\s+use)\s+'
|
||||
r'([\w][\w\s\-\.]{2,50})(?=[.,;!?]|$|\s+(?:for|in|when|as\b))',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_RE_USE_INSTEAD = re.compile(
|
||||
r'(?:use|using)\s+([\w][\w\s\-\.]{1,50}?)\s+instead\s+of\s+([\w][\w\s\-\.]{1,50})',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# uses: "uses X", "built with X", "runs X", "deployed on X"
|
||||
# Allow optional quotes around the object (e.g. use "Redis")
|
||||
_RE_USES = re.compile(
|
||||
r'(?:\buses\b|\buse\b|built\s+with|running\s+on|deployed\s+on|written\s+in|powered\s+by)\s+'
|
||||
r'"?([^"\s][\w\s\-\.\/]{0,49})"?',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# depends_on: "requires X", "depends on X", "needs X"
|
||||
_RE_DEPENDS = re.compile(
|
||||
r'(?:requires?|depends?\s+on|needs?)\s+([\w][\w\s\-\.]{1,50})',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# part_of: "X is part of Y", "X belongs to Y"
|
||||
_RE_PART_OF = re.compile(
|
||||
r'([\w][\w\s\-]{1,40}?)\s+(?:is\s+part\s+of|belongs?\s+to)\s+([\w][\w\s\-]{1,40})',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
return s.strip().rstrip(".,;:")[:120]
|
||||
|
||||
|
||||
def extract_links(
|
||||
text: str,
|
||||
source_id: str,
|
||||
*,
|
||||
source_type: str = "turn",
|
||||
max_per_type: int = 10,
|
||||
) -> List[TypedLink]:
|
||||
"""Extract typed links from a text block.
|
||||
|
||||
Pure function — deterministic, no side effects.
|
||||
|
||||
Args:
|
||||
text: The conversation text to analyse.
|
||||
source_id: The turn_id or dag node_id this text came from.
|
||||
source_type: ``"turn"`` or ``"dag"``.
|
||||
max_per_type: Hard cap on links per relation type (prevents explosions).
|
||||
|
||||
Returns:
|
||||
List of TypedLink objects (may be empty).
|
||||
"""
|
||||
links: List[TypedLink] = []
|
||||
ts = time.time()
|
||||
|
||||
def _add(rel: RelationType, subj: str, obj_: str, conf: float = 1.0) -> None:
|
||||
subj = _clean(subj)
|
||||
obj_ = _clean(obj_)
|
||||
if len(subj) < 2 or len(obj_) < 2:
|
||||
return
|
||||
links.append(TypedLink(
|
||||
source_id=source_id,
|
||||
source_type=source_type,
|
||||
relation_type=rel,
|
||||
subject=subj,
|
||||
object=obj_,
|
||||
confidence=conf,
|
||||
created_at=ts,
|
||||
))
|
||||
|
||||
# -- mentions (quoted + capitalized entities) ---------------------------
|
||||
seen_mentions: set = set()
|
||||
for m in _RE_QUOTED.finditer(text):
|
||||
ent = _clean(m.group(1))
|
||||
if ent not in seen_mentions:
|
||||
_add(RelationType.MENTIONS, source_id, ent, 0.9)
|
||||
seen_mentions.add(ent)
|
||||
if len(seen_mentions) >= max_per_type:
|
||||
break
|
||||
for m in _RE_CAMEL.finditer(text):
|
||||
ent = _clean(m.group(1))
|
||||
if ent not in seen_mentions and len(ent.split()) <= 3:
|
||||
_add(RelationType.MENTIONS, source_id, ent, 0.7)
|
||||
seen_mentions.add(ent)
|
||||
if len(seen_mentions) >= max_per_type:
|
||||
break
|
||||
|
||||
# -- defines ------------------------------------------------------------
|
||||
for m in list(_RE_DEFINES.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.DEFINES, _clean(m.group(1)), _clean(m.group(2)), 0.8)
|
||||
|
||||
# -- corrects -----------------------------------------------------------
|
||||
for m in list(_RE_CORRECTS.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.CORRECTS, "user", _clean(m.group(1)), 0.85)
|
||||
for m in list(_RE_NOT_BUT.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.CORRECTS, _clean(m.group(2)), _clean(m.group(3)), 0.9)
|
||||
|
||||
# -- prefers ------------------------------------------------------------
|
||||
seen_prefers: set = set()
|
||||
for m in list(_RE_PREFERS_OVER.finditer(text))[:max_per_type]:
|
||||
preferred = _clean(m.group(1))
|
||||
alt = _clean(m.group(2))
|
||||
if preferred not in seen_prefers:
|
||||
_add(RelationType.PREFERS, "user", preferred, 0.9)
|
||||
seen_prefers.add(preferred)
|
||||
_add(RelationType.PREFERS, preferred, alt, 0.7)
|
||||
for m in list(_RE_PREFERS_SIMPLE.finditer(text))[:max_per_type]:
|
||||
preferred = _clean(m.group(1))
|
||||
if preferred not in seen_prefers:
|
||||
_add(RelationType.PREFERS, "user", preferred, 0.9)
|
||||
seen_prefers.add(preferred)
|
||||
for m in list(_RE_USE_INSTEAD.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.PREFERS, _clean(m.group(1)), _clean(m.group(2)), 0.85)
|
||||
|
||||
# -- uses ---------------------------------------------------------------
|
||||
for m in list(_RE_USES.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.USES, source_id, _clean(m.group(1)), 0.8)
|
||||
|
||||
# -- depends_on ---------------------------------------------------------
|
||||
for m in list(_RE_DEPENDS.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.DEPENDS_ON, source_id, _clean(m.group(1)), 0.8)
|
||||
|
||||
# -- part_of ------------------------------------------------------------
|
||||
for m in list(_RE_PART_OF.finditer(text))[:max_per_type]:
|
||||
_add(RelationType.PART_OF, _clean(m.group(1)), _clean(m.group(2)), 0.8)
|
||||
|
||||
return links
|
||||
|
||||
|
||||
class TypedLinkExtractor:
|
||||
"""Extracts typed links and persists them to the ATLAS database."""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
def extract_and_store(
|
||||
self,
|
||||
text: str,
|
||||
source_id: str,
|
||||
*,
|
||||
source_type: str = "turn",
|
||||
max_per_type: int = 10,
|
||||
) -> List[TypedLink]:
|
||||
"""Extract links and persist to ``typed_links`` table.
|
||||
|
||||
Returns the extracted links (may be empty).
|
||||
"""
|
||||
links = extract_links(text, source_id, source_type=source_type, max_per_type=max_per_type)
|
||||
if links:
|
||||
con = self._db.conn()
|
||||
con.executemany(
|
||||
"""
|
||||
INSERT INTO typed_links
|
||||
(source_id, source_type, relation_type, subject, object,
|
||||
confidence, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(lnk.source_id, lnk.source_type, lnk.relation_type.value,
|
||||
lnk.subject, lnk.object, lnk.confidence, lnk.created_at)
|
||||
for lnk in links
|
||||
],
|
||||
)
|
||||
con.commit()
|
||||
return links
|
||||
|
||||
def query_links(
|
||||
self,
|
||||
*,
|
||||
source_id: str | None = None,
|
||||
relation_type: RelationType | str | None = None,
|
||||
subject: str | None = None,
|
||||
object_: str | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""Query stored typed links with optional filters."""
|
||||
clauses = []
|
||||
params: list = []
|
||||
if source_id:
|
||||
clauses.append("source_id = ?")
|
||||
params.append(source_id)
|
||||
if relation_type:
|
||||
rt = relation_type.value if isinstance(relation_type, RelationType) else relation_type
|
||||
clauses.append("relation_type = ?")
|
||||
params.append(rt)
|
||||
if subject:
|
||||
clauses.append("subject LIKE ?")
|
||||
params.append(f"%{subject}%")
|
||||
if object_:
|
||||
clauses.append("object LIKE ?")
|
||||
params.append(f"%{object_}%")
|
||||
where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
|
||||
params.append(limit)
|
||||
rows = self._db.conn().execute(
|
||||
f"SELECT * FROM typed_links {where} ORDER BY created_at DESC LIMIT ?",
|
||||
params,
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
264
agent/atlas/recall.py
Normal file
264
agent/atlas/recall.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""RecallEngine — explicit recall operations over ATLAS stores.
|
||||
|
||||
The three core operations satisfy the acceptance criterion of ≥3 explicit
|
||||
recall operations reachable by the agent:
|
||||
|
||||
search(query) Full-text search over raw turns, DAG nodes, and
|
||||
durable memory. Returns ranked text excerpts.
|
||||
|
||||
describe(id) Describe a specific turn or DAG node by its lineage
|
||||
ID. Returns structured metadata + content preview.
|
||||
|
||||
expand(node_id) Expand a DAG summary node to its source turns.
|
||||
Proves the agent can recover compacted context
|
||||
without re-injecting the full original transcript.
|
||||
|
||||
All operations query the SQLite FTS5 index and/or primary tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.atlas.dag import SummaryDAGStore
|
||||
from agent.atlas.db import AtlasDB
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PREVIEW_LEN = 300 # chars shown in search snippets
|
||||
_EXPAND_MAX_TURNS = 20 # max source turns returned by expand()
|
||||
|
||||
|
||||
class RecallEngine:
|
||||
"""Explicit recall operations for the ATLAS memory subsystem.
|
||||
|
||||
Designed to be called by agent tool handlers:
|
||||
|
||||
engine = RecallEngine(db)
|
||||
result = engine.search("capital of France")
|
||||
# → formatted string with ranked hits
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
self._turns = RawTurnStore(db)
|
||||
self._dag = SummaryDAGStore(db)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int = 10,
|
||||
doc_types: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Full-text search over the ATLAS corpus.
|
||||
|
||||
Searches raw turns, DAG summary nodes, and durable memory using
|
||||
FTS5 with porter stemming. Falls back to LIKE search if FTS5
|
||||
returns nothing.
|
||||
|
||||
Args:
|
||||
query: Free-text search query.
|
||||
limit: Maximum number of results to return.
|
||||
doc_types: Filter to specific doc types: ``raw_turn``, ``dag``,
|
||||
``durable``, ``world``. None = all.
|
||||
|
||||
Returns:
|
||||
Formatted string listing ranked hits. Each hit includes the
|
||||
document type, lineage ID, and a content preview.
|
||||
"""
|
||||
hits = self._fts_search(query, limit=limit, doc_types=doc_types)
|
||||
if not hits:
|
||||
hits = self._fallback_search(query, limit=limit, doc_types=doc_types)
|
||||
if not hits:
|
||||
return f"No results found for '{query}'."
|
||||
lines = [f"ATLAS search results for '{query}' ({len(hits)} hit(s)):\n"]
|
||||
for i, h in enumerate(hits, 1):
|
||||
preview = textwrap.shorten(h["content"], width=_PREVIEW_LEN, placeholder="…")
|
||||
lines.append(f" [{i}] ({h['doc_type']}) id={h['doc_id']}\n {preview}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def describe(self, doc_id: str) -> str:
|
||||
"""Describe a specific turn or DAG node by its lineage ID.
|
||||
|
||||
Args:
|
||||
doc_id: A turn_id (``session_id:NNNN``) or DAG node_id
|
||||
(``dag:session_id:NNNN``).
|
||||
|
||||
Returns:
|
||||
Formatted string with metadata and content preview.
|
||||
"""
|
||||
if doc_id.startswith("dag:"):
|
||||
return self._describe_dag_node(doc_id)
|
||||
else:
|
||||
return self._describe_turn(doc_id)
|
||||
|
||||
def expand(self, node_id: str, *, max_turns: int = _EXPAND_MAX_TURNS) -> str:
|
||||
"""Expand a DAG summary node to its source turns.
|
||||
|
||||
This is the key operation that proves the agent can recover facts
|
||||
from compacted context without re-injecting the full transcript.
|
||||
|
||||
Args:
|
||||
node_id: A DAG node ID (``dag:session_id:NNNN``).
|
||||
max_turns: Maximum source turns to include in the output.
|
||||
|
||||
Returns:
|
||||
Formatted string with the summary and its source turns.
|
||||
"""
|
||||
node = self._dag.get_node(node_id)
|
||||
if node is None:
|
||||
return f"No DAG node found with id '{node_id}'."
|
||||
|
||||
lines = [
|
||||
f"ATLAS expand: {node_id}",
|
||||
f" Created: {_ts(node.created_at)}",
|
||||
f" Source turns: {node.source_count()}",
|
||||
f" Summary:\n{textwrap.indent(node.summary_text, ' ')}",
|
||||
"",
|
||||
" Source turns (oldest first):",
|
||||
]
|
||||
|
||||
turn_ids = node.source_turn_ids[:max_turns]
|
||||
turns = self._turns.get_turns_by_ids(turn_ids)
|
||||
if not turns:
|
||||
lines.append(" (source turns not found in local store)")
|
||||
else:
|
||||
for t in turns:
|
||||
preview = textwrap.shorten(t.content, width=_PREVIEW_LEN, placeholder="…")
|
||||
lines.append(
|
||||
f" [{t.turn_id}] <{t.role}> {preview}"
|
||||
)
|
||||
if len(node.source_turn_ids) > max_turns:
|
||||
lines.append(
|
||||
f" … and {len(node.source_turn_ids) - max_turns} more turns "
|
||||
f"(use max_turns= to see more)"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Search internals
|
||||
|
||||
def _fts_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
doc_types: Optional[List[str]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""FTS5 full-text search."""
|
||||
try:
|
||||
if doc_types:
|
||||
placeholders = ",".join("?" * len(doc_types))
|
||||
rows = self._db.conn().execute(
|
||||
f"SELECT doc_id, doc_type, content FROM atlas_fts "
|
||||
f"WHERE atlas_fts MATCH ? AND doc_type IN ({placeholders}) "
|
||||
f"ORDER BY rank LIMIT ?",
|
||||
[query] + doc_types + [limit],
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT doc_id, doc_type, content FROM atlas_fts "
|
||||
"WHERE atlas_fts MATCH ? ORDER BY rank LIMIT ?",
|
||||
(query, limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as exc:
|
||||
logger.debug("FTS5 search failed (will fallback): %s", exc)
|
||||
return []
|
||||
|
||||
def _fallback_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
doc_types: Optional[List[str]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""LIKE-based fallback when FTS5 returns nothing."""
|
||||
results: List[Dict[str, Any]] = []
|
||||
like = f"%{query}%"
|
||||
|
||||
if doc_types is None or "raw_turn" in doc_types:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT turn_id AS doc_id, 'raw_turn' AS doc_type, content "
|
||||
"FROM raw_turns WHERE content LIKE ? LIMIT ?",
|
||||
(like, limit),
|
||||
).fetchall()
|
||||
results.extend(dict(r) for r in rows)
|
||||
|
||||
if doc_types is None or "dag" in doc_types:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT node_id AS doc_id, 'dag' AS doc_type, summary_text AS content "
|
||||
"FROM summary_dag WHERE summary_text LIKE ? LIMIT ?",
|
||||
(like, limit),
|
||||
).fetchall()
|
||||
results.extend(dict(r) for r in rows)
|
||||
|
||||
if doc_types is None or "durable" in doc_types:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT CAST(dm_id AS TEXT) AS doc_id, 'durable' AS doc_type, content "
|
||||
"FROM durable_memory WHERE content LIKE ? LIMIT ?",
|
||||
(like, limit),
|
||||
).fetchall()
|
||||
results.extend(dict(r) for r in rows)
|
||||
|
||||
if doc_types is None or "world" in doc_types:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT CAST(wk_id AS TEXT) AS doc_id, 'world' AS doc_type, content "
|
||||
"FROM world_knowledge WHERE content LIKE ? LIMIT ?",
|
||||
(like, limit),
|
||||
).fetchall()
|
||||
results.extend(dict(r) for r in rows)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Describe internals
|
||||
|
||||
def _describe_turn(self, turn_id: str) -> str:
|
||||
turn = self._turns.get(turn_id)
|
||||
if turn is None:
|
||||
return f"No turn found with id '{turn_id}'."
|
||||
preview = textwrap.shorten(turn.content, width=_PREVIEW_LEN, placeholder="…")
|
||||
lines = [
|
||||
f"ATLAS describe: {turn_id}",
|
||||
f" Role: {turn.role}",
|
||||
f" Session: {turn.session_id}",
|
||||
f" Index: {turn.turn_index}",
|
||||
f" Timestamp: {_ts(turn.timestamp)}",
|
||||
f" Tokens: ~{turn.token_count}",
|
||||
]
|
||||
if turn.tool_name:
|
||||
lines.append(f" Tool: {turn.tool_name}")
|
||||
lines.append(f" Content:\n{textwrap.indent(preview, ' ')}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _describe_dag_node(self, node_id: str) -> str:
|
||||
node = self._dag.get_node(node_id)
|
||||
if node is None:
|
||||
return f"No DAG node found with id '{node_id}'."
|
||||
preview = textwrap.shorten(node.summary_text, width=_PREVIEW_LEN, placeholder="…")
|
||||
lines = [
|
||||
f"ATLAS describe: {node_id}",
|
||||
f" Session: {node.session_id}",
|
||||
f" Created: {_ts(node.created_at)}",
|
||||
f" Source turns: {node.source_count()}",
|
||||
f" Tokens: ~{node.token_count}",
|
||||
]
|
||||
if node.parent_node_id:
|
||||
lines.append(f" Parent node: {node.parent_node_id}")
|
||||
lines.append(f" Summary:\n{textwrap.indent(preview, ' ')}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ts(timestamp: float) -> str:
|
||||
"""Human-readable UTC timestamp."""
|
||||
import datetime
|
||||
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
|
||||
272
agent/atlas/stores.py
Normal file
272
agent/atlas/stores.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Three explicit stores: world knowledge, durable memory, session state.
|
||||
|
||||
Writes are *always* routed to exactly one store — never a single mixed bucket.
|
||||
|
||||
- ``WorldKnowledgeStore`` — static facts about the world / domain
|
||||
- ``DurableMemoryStore`` — user-specific facts that outlive sessions
|
||||
- ``SessionStateStore`` — ephemeral working notes for the current session
|
||||
|
||||
``AtlasStores`` is the unified facade used by the MemoryProvider.
|
||||
``StoreTarget`` is the routing enum.
|
||||
|
||||
Routing rules (applied deterministically, no LLM):
|
||||
"world:" prefix → WorldKnowledgeStore
|
||||
"session:" prefix → SessionStateStore
|
||||
anything else → DurableMemoryStore (safe default)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
|
||||
|
||||
class StoreTarget(str, Enum):
|
||||
WORLD = "world"
|
||||
DURABLE = "durable"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
def _classify_target(content: str, category: str = "") -> StoreTarget:
|
||||
"""Determine the correct store for a write.
|
||||
|
||||
Rules (checked in order):
|
||||
1. Explicit ``world:`` prefix → WORLD
|
||||
2. Explicit ``session:`` prefix → SESSION
|
||||
3. Category starts with ``world`` → WORLD
|
||||
4. Category starts with ``session`` → SESSION
|
||||
5. Default → DURABLE
|
||||
"""
|
||||
lc_content = content.strip().lower()
|
||||
lc_category = (category or "").strip().lower()
|
||||
|
||||
if lc_content.startswith("world:") or lc_category.startswith("world"):
|
||||
return StoreTarget.WORLD
|
||||
if lc_content.startswith("session:") or lc_category.startswith("session"):
|
||||
return StoreTarget.SESSION
|
||||
return StoreTarget.DURABLE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World knowledge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WorldKnowledgeStore:
|
||||
"""Stores static / universal facts.
|
||||
|
||||
Use for: domain knowledge, constants, API shapes, product descriptions.
|
||||
These facts do NOT expire with sessions.
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
def add(self, content: str, *, tags: str = "", trust: float = 0.8) -> int:
|
||||
ts = time.time()
|
||||
content = content.removeprefix("world:").strip()
|
||||
cur = self._db.conn().execute(
|
||||
"INSERT INTO world_knowledge (content, tags, trust, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(content, tags, trust, ts, ts),
|
||||
)
|
||||
self._db.conn().commit()
|
||||
return cur.lastrowid
|
||||
|
||||
def search(self, query: str, *, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT wk_id, content, tags, trust FROM world_knowledge "
|
||||
"WHERE content LIKE ? ORDER BY trust DESC LIMIT ?",
|
||||
(f"%{query}%", limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def list_all(self, *, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT wk_id, content, tags, trust FROM world_knowledge "
|
||||
"ORDER BY trust DESC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def count(self) -> int:
|
||||
return self._db.conn().execute(
|
||||
"SELECT COUNT(*) FROM world_knowledge"
|
||||
).fetchone()[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Durable memory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DurableMemoryStore:
|
||||
"""Stores user-specific facts that outlive sessions.
|
||||
|
||||
Use for: user preferences, corrections, project facts, identity details.
|
||||
Facts here survive session boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
def add(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
category: str = "general",
|
||||
tags: str = "",
|
||||
trust: float = 0.7,
|
||||
) -> int:
|
||||
ts = time.time()
|
||||
cur = self._db.conn().execute(
|
||||
"INSERT INTO durable_memory (category, content, tags, trust, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(category, content, tags, trust, ts, ts),
|
||||
)
|
||||
self._db.conn().commit()
|
||||
return cur.lastrowid
|
||||
|
||||
def search(self, query: str, *, category: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
if category:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT dm_id, category, content, tags, trust FROM durable_memory "
|
||||
"WHERE category = ? AND content LIKE ? ORDER BY trust DESC LIMIT ?",
|
||||
(category, f"%{query}%", limit),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT dm_id, category, content, tags, trust FROM durable_memory "
|
||||
"WHERE content LIKE ? ORDER BY trust DESC LIMIT ?",
|
||||
(f"%{query}%", limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def list_by_category(self, category: str, *, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT dm_id, category, content, tags, trust FROM durable_memory "
|
||||
"WHERE category = ? ORDER BY trust DESC LIMIT ?",
|
||||
(category, limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def count(self) -> int:
|
||||
return self._db.conn().execute(
|
||||
"SELECT COUNT(*) FROM durable_memory"
|
||||
).fetchone()[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SessionStateStore:
|
||||
"""Ephemeral working notes scoped to one session.
|
||||
|
||||
Useful for tracking in-progress plans, intermediate outputs, and
|
||||
scratch-pad facts that should not persist beyond the session.
|
||||
|
||||
Keys are unique per session; setting the same key overwrites the prior value.
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
def set(self, session_id: str, key: str, content: str) -> None:
|
||||
key = key.removeprefix("session:").strip()
|
||||
ts = time.time()
|
||||
self._db.conn().execute(
|
||||
"""
|
||||
INSERT INTO session_state (session_id, key, content, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(session_id, key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(session_id, key, content, ts, ts),
|
||||
)
|
||||
self._db.conn().commit()
|
||||
|
||||
def get(self, session_id: str, key: str) -> Optional[str]:
|
||||
row = self._db.conn().execute(
|
||||
"SELECT content FROM session_state WHERE session_id = ? AND key = ?",
|
||||
(session_id, key),
|
||||
).fetchone()
|
||||
return row["content"] if row else None
|
||||
|
||||
def list_session(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
rows = self._db.conn().execute(
|
||||
"SELECT key, content, updated_at FROM session_state "
|
||||
"WHERE session_id = ? ORDER BY updated_at DESC",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def count_session(self, session_id: str) -> int:
|
||||
return self._db.conn().execute(
|
||||
"SELECT COUNT(*) FROM session_state WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()[0]
|
||||
|
||||
def clear_session(self, session_id: str) -> int:
|
||||
cur = self._db.conn().execute(
|
||||
"DELETE FROM session_state WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._db.conn().commit()
|
||||
return cur.rowcount
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified facade
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AtlasStores:
|
||||
"""Unified facade that routes writes to the correct store.
|
||||
|
||||
Usage::
|
||||
|
||||
stores = AtlasStores(db)
|
||||
target = stores.route(content="world: Paris is the capital of France")
|
||||
stores.write(content="world: Paris is the capital of France")
|
||||
# → routed to WorldKnowledgeStore
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self.world = WorldKnowledgeStore(db)
|
||||
self.durable = DurableMemoryStore(db)
|
||||
self.session = SessionStateStore(db)
|
||||
|
||||
def route(self, content: str, category: str = "") -> StoreTarget:
|
||||
"""Classify which store a write should go to."""
|
||||
return _classify_target(content, category)
|
||||
|
||||
def write(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
category: str = "",
|
||||
tags: str = "",
|
||||
trust: float = 0.7,
|
||||
session_id: str = "",
|
||||
session_key: str = "",
|
||||
) -> tuple[StoreTarget, Any]:
|
||||
"""Route and write content; returns (target, id/key)."""
|
||||
target = self.route(content, category)
|
||||
if target == StoreTarget.WORLD:
|
||||
row_id = self.world.add(content, tags=tags, trust=max(trust, 0.8))
|
||||
return target, row_id
|
||||
elif target == StoreTarget.SESSION:
|
||||
key = session_key or category or "note"
|
||||
self.session.set(session_id, key, content)
|
||||
return target, key
|
||||
else:
|
||||
row_id = self.durable.add(content, category=category or "general", tags=tags, trust=trust)
|
||||
return target, row_id
|
||||
|
||||
def search_all(self, query: str, *, limit_each: int = 5) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Search all three stores and return results keyed by store name."""
|
||||
return {
|
||||
"world": self.world.search(query, limit=limit_each),
|
||||
"durable": self.durable.search(query, limit=limit_each),
|
||||
}
|
||||
155
agent/atlas/turns.py
Normal file
155
agent/atlas/turns.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""RawTurnStore — immutable per-turn records with stable lineage IDs.
|
||||
|
||||
Every user/assistant/tool turn is persisted here *before* it can be evicted
|
||||
from the active context window. Records are never deleted; only appended.
|
||||
|
||||
Lineage ID format: ``<session_id>:<index:04d>``
|
||||
|
||||
Example::
|
||||
|
||||
store = RawTurnStore(atlas_db)
|
||||
turn_id = store.append(
|
||||
session_id="sess_abc",
|
||||
role="user",
|
||||
content="What is the capital of France?",
|
||||
)
|
||||
# turn_id == "sess_abc:0000"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurnRecord:
|
||||
"""Immutable record of one conversation turn."""
|
||||
|
||||
turn_id: str # stable lineage ID
|
||||
session_id: str
|
||||
turn_index: int
|
||||
role: str # user / assistant / tool
|
||||
content: str
|
||||
tool_name: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
token_count: int = 0
|
||||
|
||||
def lineage_label(self) -> str:
|
||||
"""Short label for human display, e.g. ``user@sess_abc:0012``."""
|
||||
return f"{self.role}@{self.turn_id}"
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
"""Rough token estimate — 4 chars per token."""
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
|
||||
class RawTurnStore:
|
||||
"""Append-only store for conversation turns.
|
||||
|
||||
Thread-safe (AtlasDB uses per-thread connections).
|
||||
"""
|
||||
|
||||
def __init__(self, db: AtlasDB) -> None:
|
||||
self._db = db
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Write
|
||||
|
||||
def append(
|
||||
self,
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
*,
|
||||
tool_name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> str:
|
||||
"""Persist a turn and return its stable turn_id."""
|
||||
ts = timestamp or time.time()
|
||||
con = self._db.conn()
|
||||
# Determine next index for this session
|
||||
row = con.execute(
|
||||
"SELECT COALESCE(MAX(turn_index), -1) FROM raw_turns WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
next_idx: int = (row[0] if row else -1) + 1
|
||||
turn_id = f"{session_id}:{next_idx:04d}"
|
||||
token_count = _estimate_tokens(content)
|
||||
con.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO raw_turns
|
||||
(turn_id, session_id, turn_index, role, content,
|
||||
tool_name, tool_call_id, timestamp, token_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(turn_id, session_id, next_idx, role, content,
|
||||
tool_name, tool_call_id, ts, token_count),
|
||||
)
|
||||
con.commit()
|
||||
return turn_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Read
|
||||
|
||||
def get(self, turn_id: str) -> Optional[TurnRecord]:
|
||||
"""Fetch a single turn by its lineage ID."""
|
||||
row = self._db.conn().execute(
|
||||
"SELECT * FROM raw_turns WHERE turn_id = ?", (turn_id,)
|
||||
).fetchone()
|
||||
return _row_to_turn(row) if row else None
|
||||
|
||||
def get_session_turns(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
) -> List[TurnRecord]:
|
||||
"""Return all turns for a session, oldest first."""
|
||||
sql = (
|
||||
"SELECT * FROM raw_turns WHERE session_id = ? "
|
||||
"ORDER BY turn_index ASC LIMIT ? OFFSET ?"
|
||||
)
|
||||
lim = limit if limit is not None else -1
|
||||
rows = self._db.conn().execute(sql, (session_id, lim, offset)).fetchall()
|
||||
return [_row_to_turn(r) for r in rows]
|
||||
|
||||
def count_session_turns(self, session_id: str) -> int:
|
||||
"""Count persisted turns for a session."""
|
||||
row = self._db.conn().execute(
|
||||
"SELECT COUNT(*) FROM raw_turns WHERE session_id = ?", (session_id,)
|
||||
).fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_turns_by_ids(self, turn_ids: List[str]) -> List[TurnRecord]:
|
||||
"""Fetch multiple turns by ID list (preserves ID order)."""
|
||||
if not turn_ids:
|
||||
return []
|
||||
placeholders = ",".join("?" * len(turn_ids))
|
||||
rows = self._db.conn().execute(
|
||||
f"SELECT * FROM raw_turns WHERE turn_id IN ({placeholders})",
|
||||
turn_ids,
|
||||
).fetchall()
|
||||
by_id = {r["turn_id"]: _row_to_turn(r) for r in rows}
|
||||
return [by_id[tid] for tid in turn_ids if tid in by_id]
|
||||
|
||||
|
||||
def _row_to_turn(row: object) -> TurnRecord:
|
||||
return TurnRecord(
|
||||
turn_id=row["turn_id"],
|
||||
session_id=row["session_id"],
|
||||
turn_index=row["turn_index"],
|
||||
role=row["role"],
|
||||
content=row["content"],
|
||||
tool_name=row["tool_name"],
|
||||
tool_call_id=row["tool_call_id"],
|
||||
timestamp=row["timestamp"],
|
||||
token_count=row["token_count"],
|
||||
)
|
||||
697
agent/atlas_memory.py
Normal file
697
agent/atlas_memory.py
Normal file
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
ATLAS Memory — Lossless Context + Memory Subsystem.
|
||||
|
||||
Provides immutable turn storage, a compaction DAG that preserves source
|
||||
references, typed link extraction, three-store routing, and explicit recall
|
||||
operations (search / describe / expand).
|
||||
|
||||
No LLM calls are made in this module — all operations are deterministic.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA_SQL = """
|
||||
-- ATLAS Memory schema v1
|
||||
-- Turn Store: immutable raw turns with stable lineage IDs
|
||||
CREATE TABLE IF NOT EXISTS atlas_turns (
|
||||
turn_id TEXT PRIMARY KEY, -- "{session_id}:{seq_num:06d}"
|
||||
session_id TEXT NOT NULL,
|
||||
seq_num INTEGER NOT NULL,
|
||||
role TEXT NOT NULL, -- 'user', 'assistant', 'tool', 'system'
|
||||
content TEXT,
|
||||
tool_name TEXT,
|
||||
tool_call_id TEXT,
|
||||
created_at REAL NOT NULL,
|
||||
UNIQUE (session_id, seq_num)
|
||||
);
|
||||
|
||||
-- Summary DAG: compaction nodes with source references (lossless)
|
||||
CREATE TABLE IF NOT EXISTS atlas_summary_nodes (
|
||||
node_id TEXT PRIMARY KEY, -- "sum:{session_id}:{timestamp_hex}"
|
||||
session_id TEXT NOT NULL,
|
||||
summary_text TEXT NOT NULL,
|
||||
source_turn_ids TEXT NOT NULL, -- JSON array of turn_ids (immutable reference)
|
||||
depth INTEGER NOT NULL DEFAULT 0, -- 0=leaf, 1=summary-of-summaries
|
||||
parent_node_id TEXT,
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
|
||||
-- Three explicit stores (no mixed bucket)
|
||||
-- Note: session_id can be NULL (world_knowledge / durable_memory).
|
||||
-- SQLite UNIQUE constraints treat NULLs as distinct, so we use a unique
|
||||
-- expression index on COALESCE(session_id, '') to enforce uniqueness
|
||||
-- correctly for both NULL and non-NULL session_ids.
|
||||
CREATE TABLE IF NOT EXISTS atlas_stores (
|
||||
store_name TEXT NOT NULL, -- 'world_knowledge', 'durable_memory', 'session_state'
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
session_id TEXT, -- NULL for world/durable, set for session_state
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL
|
||||
);
|
||||
|
||||
-- Typed links: deterministic extraction on write
|
||||
CREATE TABLE IF NOT EXISTS atlas_links (
|
||||
link_id TEXT PRIMARY KEY, -- "lnk:{source_id}:{link_type}:{target_hash}"
|
||||
source_id TEXT NOT NULL, -- turn_id or node_id
|
||||
source_type TEXT NOT NULL, -- 'turn' or 'node'
|
||||
target_text TEXT NOT NULL, -- what was referenced (denormalized for lookup)
|
||||
link_type TEXT NOT NULL, -- DEFINES, MODIFIES, REFERENCES, DEPENDS_ON, CONTRADICTS
|
||||
confidence REAL NOT NULL DEFAULT 1.0,
|
||||
session_id TEXT,
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AtlasDB
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AtlasDB:
|
||||
"""Core database — opens SQLite with WAL mode and creates the schema."""
|
||||
|
||||
def __init__(self, db_path: Path = None):
|
||||
if db_path is None:
|
||||
db_path = get_hermes_home() / "atlas.db"
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
timeout=10.0,
|
||||
isolation_level=None, # autocommit; we manage transactions manually
|
||||
)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||
self._init_schema()
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.executescript(_SCHEMA_SQL)
|
||||
# Expression index for atlas_stores uniqueness on (store_name, key, session_id).
|
||||
# UNIQUE constraints don't handle NULL correctly (NULL != NULL), so we use an
|
||||
# expression index with COALESCE to treat NULL session_id as empty string.
|
||||
self._conn.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_atlas_stores_unique "
|
||||
"ON atlas_stores (store_name, key, COALESCE(session_id, ''))"
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@property
|
||||
def conn(self) -> sqlite3.Connection:
|
||||
return self._conn
|
||||
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.execute("PRAGMA wal_checkpoint(PASSIVE)")
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
self._conn = None # type: ignore[assignment]
|
||||
|
||||
# Simple helper used by sub-components
|
||||
def execute(self, sql: str, params=()):
|
||||
with self._lock:
|
||||
return self._conn.execute(sql, params)
|
||||
|
||||
def executemany(self, sql: str, params_seq):
|
||||
with self._lock:
|
||||
return self._conn.executemany(sql, params_seq)
|
||||
|
||||
def write(self, fn):
|
||||
"""Execute *fn(conn)* inside a BEGIN IMMEDIATE transaction."""
|
||||
with self._lock:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
result = fn(self._conn)
|
||||
self._conn.commit()
|
||||
except BaseException:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return result
|
||||
|
||||
def fetchall(self, sql: str, params=()):
|
||||
with self._lock:
|
||||
return [dict(r) for r in self._conn.execute(sql, params).fetchall()]
|
||||
|
||||
def fetchone(self, sql: str, params=()):
|
||||
with self._lock:
|
||||
row = self._conn.execute(sql, params).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TurnStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TurnStore:
|
||||
"""Immutable turn storage with stable lineage IDs."""
|
||||
|
||||
def __init__(self, db: AtlasDB):
|
||||
self._db = db
|
||||
|
||||
def append_turn(
|
||||
self,
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str = None,
|
||||
tool_name: str = None,
|
||||
tool_call_id: str = None,
|
||||
) -> str:
|
||||
"""Append a turn and return its turn_id."""
|
||||
|
||||
def _do(conn):
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(MAX(seq_num), 0) + 1 AS next_seq "
|
||||
"FROM atlas_turns WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
seq_num = row["next_seq"] if row else 1
|
||||
turn_id = f"{session_id}:{seq_num:06d}"
|
||||
conn.execute(
|
||||
"""INSERT INTO atlas_turns
|
||||
(turn_id, session_id, seq_num, role, content,
|
||||
tool_name, tool_call_id, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
turn_id,
|
||||
session_id,
|
||||
seq_num,
|
||||
role,
|
||||
content,
|
||||
tool_name,
|
||||
tool_call_id,
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
return turn_id
|
||||
|
||||
return self._db.write(_do)
|
||||
|
||||
def get_turn(self, turn_id: str) -> Optional[Dict]:
|
||||
return self._db.fetchone(
|
||||
"SELECT * FROM atlas_turns WHERE turn_id = ?", (turn_id,)
|
||||
)
|
||||
|
||||
def get_session_turns(self, session_id: str) -> List[Dict]:
|
||||
return self._db.fetchall(
|
||||
"SELECT * FROM atlas_turns WHERE session_id = ? ORDER BY seq_num",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
def turn_count(self, session_id: str) -> int:
|
||||
row = self._db.fetchone(
|
||||
"SELECT COUNT(*) AS cnt FROM atlas_turns WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
return row["cnt"] if row else 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SummaryDAG
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SummaryDAG:
|
||||
"""Lossless compaction: summary nodes retain immutable references to source turns."""
|
||||
|
||||
def __init__(self, db: AtlasDB):
|
||||
self._db = db
|
||||
|
||||
def create_summary_node(
|
||||
self,
|
||||
session_id: str,
|
||||
summary_text: str,
|
||||
source_turn_ids: List[str],
|
||||
parent_node_id: str = None,
|
||||
) -> str:
|
||||
"""Create a summary node and return its node_id."""
|
||||
ts = time.time()
|
||||
ts_hex = format(int(ts * 1000), "x")
|
||||
node_id = f"sum:{session_id}:{ts_hex}"
|
||||
depth = 0 if parent_node_id is None else 1
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT INTO atlas_summary_nodes
|
||||
(node_id, session_id, summary_text, source_turn_ids,
|
||||
depth, parent_node_id, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
node_id,
|
||||
session_id,
|
||||
summary_text,
|
||||
json.dumps(source_turn_ids),
|
||||
depth,
|
||||
parent_node_id,
|
||||
ts,
|
||||
),
|
||||
)
|
||||
|
||||
self._db.write(_do)
|
||||
return node_id
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[Dict]:
|
||||
row = self._db.fetchone(
|
||||
"SELECT * FROM atlas_summary_nodes WHERE node_id = ?", (node_id,)
|
||||
)
|
||||
if row:
|
||||
row["source_turn_ids"] = json.loads(row["source_turn_ids"])
|
||||
return row
|
||||
|
||||
def get_session_nodes(self, session_id: str) -> List[Dict]:
|
||||
rows = self._db.fetchall(
|
||||
"SELECT * FROM atlas_summary_nodes WHERE session_id = ? ORDER BY created_at",
|
||||
(session_id,),
|
||||
)
|
||||
for row in rows:
|
||||
row["source_turn_ids"] = json.loads(row["source_turn_ids"])
|
||||
return rows
|
||||
|
||||
def get_source_turns(self, node_id: str, turn_store: TurnStore) -> List[Dict]:
|
||||
"""Expand a node: retrieve the original source turns (lossless recall)."""
|
||||
node = self.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
turns = []
|
||||
for tid in node["source_turn_ids"]:
|
||||
t = turn_store.get_turn(tid)
|
||||
if t:
|
||||
turns.append(t)
|
||||
return turns
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TypedLinkExtractor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Five relation types with their compiled patterns
|
||||
_LINK_PATTERNS: Dict[str, List[re.Pattern]] = {
|
||||
"DEFINES": [
|
||||
re.compile(r"\b\w+\s+is\s+(?:a|an|the)\b", re.IGNORECASE),
|
||||
re.compile(r"\bdefined?\s+as\b", re.IGNORECASE),
|
||||
re.compile(r"\bmeans\b", re.IGNORECASE),
|
||||
re.compile(r"\bdef\s+\w+\s*\(", re.IGNORECASE),
|
||||
re.compile(r"\bclass\s+\w+", re.IGNORECASE),
|
||||
re.compile(r"\b\w+\s*=\s*\S", re.IGNORECASE),
|
||||
],
|
||||
"MODIFIES": [
|
||||
re.compile(r"\b(?:changed|updated|modified|fixed|replaced|patched)\b", re.IGNORECASE),
|
||||
],
|
||||
"REFERENCES": [
|
||||
re.compile(r"\bas\s+mentioned\b", re.IGNORECASE),
|
||||
re.compile(r"\bas\s+I\s+said\b", re.IGNORECASE),
|
||||
re.compile(r"\bsee\s+above\b", re.IGNORECASE),
|
||||
re.compile(r"\breferring\s+to\b", re.IGNORECASE),
|
||||
re.compile(r'"[^"]{2,}"', re.IGNORECASE), # quoted phrases
|
||||
],
|
||||
"DEPENDS_ON": [
|
||||
re.compile(r"\brequires?\b", re.IGNORECASE),
|
||||
re.compile(r"\bneeds?\b", re.IGNORECASE),
|
||||
re.compile(r"\bdepends?\s+on\b", re.IGNORECASE),
|
||||
re.compile(r"\bimport\s+\w+", re.IGNORECASE),
|
||||
re.compile(r"\bprerequisite\b", re.IGNORECASE),
|
||||
],
|
||||
"CONTRADICTS": [
|
||||
re.compile(r"\bnot\s+\w+", re.IGNORECASE),
|
||||
re.compile(r"\bactually\b", re.IGNORECASE),
|
||||
re.compile(r"\bwrong\b", re.IGNORECASE),
|
||||
re.compile(r"\bincorrect\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat'?s?\s+not\s+right\b", re.IGNORECASE),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TypedLinkExtractor:
|
||||
"""Deterministic typed-link extraction — no LLM calls."""
|
||||
|
||||
def __init__(self, db: AtlasDB):
|
||||
self._db = db
|
||||
|
||||
def extract_and_store(
|
||||
self, turn_id: str, content: str, session_id: str, role: str
|
||||
) -> List[Dict]:
|
||||
if not content:
|
||||
return []
|
||||
|
||||
links: List[Dict] = []
|
||||
now = time.time()
|
||||
|
||||
for link_type, patterns in _LINK_PATTERNS.items():
|
||||
for pattern in patterns:
|
||||
matches = pattern.findall(content)
|
||||
for match_text in matches:
|
||||
target_hash = hashlib.sha1(
|
||||
f"{match_text}".encode()
|
||||
).hexdigest()[:8]
|
||||
link_id = f"lnk:{turn_id}:{link_type}:{target_hash}"
|
||||
link = {
|
||||
"link_id": link_id,
|
||||
"source_id": turn_id,
|
||||
"source_type": "turn",
|
||||
"target_text": match_text,
|
||||
"link_type": link_type,
|
||||
"confidence": 1.0,
|
||||
"session_id": session_id,
|
||||
"created_at": now,
|
||||
}
|
||||
links.append(link)
|
||||
|
||||
if links:
|
||||
def _do(conn):
|
||||
for lnk in links:
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO atlas_links
|
||||
(link_id, source_id, source_type, target_text,
|
||||
link_type, confidence, session_id, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
lnk["link_id"],
|
||||
lnk["source_id"],
|
||||
lnk["source_type"],
|
||||
lnk["target_text"],
|
||||
lnk["link_type"],
|
||||
lnk["confidence"],
|
||||
lnk["session_id"],
|
||||
lnk["created_at"],
|
||||
),
|
||||
)
|
||||
|
||||
self._db.write(_do)
|
||||
|
||||
return links
|
||||
|
||||
def get_links_for_turn(self, turn_id: str) -> List[Dict]:
|
||||
return self._db.fetchall(
|
||||
"SELECT * FROM atlas_links WHERE source_id = ?", (turn_id,)
|
||||
)
|
||||
|
||||
def get_links_by_type(
|
||||
self, link_type: str, session_id: str = None
|
||||
) -> List[Dict]:
|
||||
if session_id is not None:
|
||||
return self._db.fetchall(
|
||||
"SELECT * FROM atlas_links WHERE link_type = ? AND session_id = ?",
|
||||
(link_type, session_id),
|
||||
)
|
||||
return self._db.fetchall(
|
||||
"SELECT * FROM atlas_links WHERE link_type = ?", (link_type,)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RecallEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RecallEngine:
|
||||
"""Three explicit recall operations: search, describe, expand."""
|
||||
|
||||
def __init__(self, db: AtlasDB, turn_store: TurnStore, summary_dag: SummaryDAG):
|
||||
self._db = db
|
||||
self._turns = turn_store
|
||||
self._dag = summary_dag
|
||||
|
||||
def search(
|
||||
self, query: str, session_id: str = None, limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""LIKE-based search over turns and summary nodes.
|
||||
|
||||
Returns a list of result dicts with keys: id, type, content, session_id.
|
||||
"""
|
||||
pattern = f"%{query}%"
|
||||
results: List[Dict] = []
|
||||
|
||||
# Search turns
|
||||
if session_id is not None:
|
||||
turn_rows = self._db.fetchall(
|
||||
"SELECT turn_id, session_id, role, content, created_at "
|
||||
"FROM atlas_turns "
|
||||
"WHERE session_id = ? AND content LIKE ?",
|
||||
(session_id, pattern),
|
||||
)
|
||||
else:
|
||||
turn_rows = self._db.fetchall(
|
||||
"SELECT turn_id, session_id, role, content, created_at "
|
||||
"FROM atlas_turns WHERE content LIKE ?",
|
||||
(pattern,),
|
||||
)
|
||||
for row in turn_rows:
|
||||
results.append(
|
||||
{
|
||||
"id": row["turn_id"],
|
||||
"type": "turn",
|
||||
"content": row["content"],
|
||||
"session_id": row["session_id"],
|
||||
"role": row.get("role"),
|
||||
"created_at": row.get("created_at"),
|
||||
}
|
||||
)
|
||||
|
||||
# Search summary nodes
|
||||
if session_id is not None:
|
||||
node_rows = self._db.fetchall(
|
||||
"SELECT node_id, session_id, summary_text, created_at "
|
||||
"FROM atlas_summary_nodes "
|
||||
"WHERE session_id = ? AND summary_text LIKE ?",
|
||||
(session_id, pattern),
|
||||
)
|
||||
else:
|
||||
node_rows = self._db.fetchall(
|
||||
"SELECT node_id, session_id, summary_text, created_at "
|
||||
"FROM atlas_summary_nodes WHERE summary_text LIKE ?",
|
||||
(pattern,),
|
||||
)
|
||||
for row in node_rows:
|
||||
results.append(
|
||||
{
|
||||
"id": row["node_id"],
|
||||
"type": "summary_node",
|
||||
"content": row["summary_text"],
|
||||
"session_id": row["session_id"],
|
||||
"created_at": row.get("created_at"),
|
||||
}
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def describe(self, item_id: str) -> Optional[Dict]:
|
||||
"""Get the full content of a turn or summary node by ID."""
|
||||
# Try turn first
|
||||
turn = self._turns.get_turn(item_id)
|
||||
if turn:
|
||||
turn["type"] = "turn"
|
||||
return turn
|
||||
# Try summary node
|
||||
node = self._dag.get_node(item_id)
|
||||
if node:
|
||||
node["type"] = "summary_node"
|
||||
return node
|
||||
return None
|
||||
|
||||
def expand(self, node_id: str) -> Dict:
|
||||
"""Expand a summary node to its original source turns.
|
||||
|
||||
Returns {"node": ..., "source_turns": [...]}.
|
||||
"""
|
||||
node = self._dag.get_node(node_id)
|
||||
if not node:
|
||||
return {"node": None, "source_turns": []}
|
||||
source_turns = self._dag.get_source_turns(node_id, self._turns)
|
||||
return {"node": node, "source_turns": source_turns}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AtlasStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
WORLD_KNOWLEDGE = "world_knowledge"
|
||||
DURABLE_MEMORY = "durable_memory"
|
||||
SESSION_STATE = "session_state"
|
||||
|
||||
_VALID_STORES = {WORLD_KNOWLEDGE, DURABLE_MEMORY, SESSION_STATE}
|
||||
|
||||
|
||||
class AtlasStore:
|
||||
"""Three-store architecture — routes writes to world_knowledge, durable_memory, or session_state."""
|
||||
|
||||
def __init__(self, db: AtlasDB):
|
||||
self._db = db
|
||||
|
||||
def _validate_store(self, store_name: str) -> None:
|
||||
if store_name not in _VALID_STORES:
|
||||
raise ValueError(
|
||||
f"Invalid store '{store_name}'. Valid stores: {sorted(_VALID_STORES)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
store_name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
session_id: str = None,
|
||||
) -> None:
|
||||
self._validate_store(store_name)
|
||||
now = time.time()
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT INTO atlas_stores
|
||||
(store_name, key, value, session_id, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(store_name, key, COALESCE(session_id, ''))
|
||||
DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at""",
|
||||
(store_name, key, value, session_id, now, now),
|
||||
)
|
||||
|
||||
self._db.write(_do)
|
||||
|
||||
def read(
|
||||
self, store_name: str, key: str, session_id: str = None
|
||||
) -> Optional[str]:
|
||||
self._validate_store(store_name)
|
||||
if session_id is None:
|
||||
row = self._db.fetchone(
|
||||
"SELECT value FROM atlas_stores "
|
||||
"WHERE store_name = ? AND key = ? AND session_id IS NULL",
|
||||
(store_name, key),
|
||||
)
|
||||
else:
|
||||
row = self._db.fetchone(
|
||||
"SELECT value FROM atlas_stores "
|
||||
"WHERE store_name = ? AND key = ? AND session_id = ?",
|
||||
(store_name, key, session_id),
|
||||
)
|
||||
return row["value"] if row else None
|
||||
|
||||
def list_keys(self, store_name: str, session_id: str = None) -> List[str]:
|
||||
self._validate_store(store_name)
|
||||
if session_id is None:
|
||||
rows = self._db.fetchall(
|
||||
"SELECT key FROM atlas_stores "
|
||||
"WHERE store_name = ? AND session_id IS NULL",
|
||||
(store_name,),
|
||||
)
|
||||
else:
|
||||
rows = self._db.fetchall(
|
||||
"SELECT key FROM atlas_stores "
|
||||
"WHERE store_name = ? AND session_id = ?",
|
||||
(store_name, session_id),
|
||||
)
|
||||
return [r["key"] for r in rows]
|
||||
|
||||
def delete(
|
||||
self, store_name: str, key: str, session_id: str = None
|
||||
) -> None:
|
||||
self._validate_store(store_name)
|
||||
|
||||
def _do(conn):
|
||||
if session_id is None:
|
||||
conn.execute(
|
||||
"DELETE FROM atlas_stores "
|
||||
"WHERE store_name = ? AND key = ? AND session_id IS NULL",
|
||||
(store_name, key),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"DELETE FROM atlas_stores "
|
||||
"WHERE store_name = ? AND key = ? AND session_id = ?",
|
||||
(store_name, key, session_id),
|
||||
)
|
||||
|
||||
self._db.write(_do)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AtlasMemory — top-level facade
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AtlasMemory:
|
||||
"""Top-level facade combining all ATLAS subsystems."""
|
||||
|
||||
def __init__(self, db_path: Path = None):
|
||||
self.db = AtlasDB(db_path=db_path)
|
||||
self.turns = TurnStore(self.db)
|
||||
self.dag = SummaryDAG(self.db)
|
||||
self.links = TypedLinkExtractor(self.db)
|
||||
self.recall = RecallEngine(self.db, self.turns, self.dag)
|
||||
self.store = AtlasStore(self.db)
|
||||
|
||||
def record_turn(
|
||||
self,
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str = None,
|
||||
tool_name: str = None,
|
||||
tool_call_id: str = None,
|
||||
) -> str:
|
||||
"""Append a turn and extract typed links from its content."""
|
||||
turn_id = self.turns.append_turn(
|
||||
session_id,
|
||||
role,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
if content:
|
||||
self.links.extract_and_store(turn_id, content, session_id, role)
|
||||
return turn_id
|
||||
|
||||
def compact_session(
|
||||
self,
|
||||
session_id: str,
|
||||
summary_text: str,
|
||||
turn_ids: List[str] = None,
|
||||
) -> str:
|
||||
"""Compact session turns into a DAG node and return the node_id.
|
||||
|
||||
If *turn_ids* is None, all turns for the session are referenced.
|
||||
"""
|
||||
if turn_ids is None:
|
||||
all_turns = self.turns.get_session_turns(session_id)
|
||||
turn_ids = [t["turn_id"] for t in all_turns]
|
||||
return self.dag.create_summary_node(session_id, summary_text, turn_ids)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_default_atlas: Optional[AtlasMemory] = None
|
||||
_default_atlas_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_default_atlas() -> AtlasMemory:
|
||||
"""Get or create the default AtlasMemory instance (profile-scoped)."""
|
||||
global _default_atlas
|
||||
with _default_atlas_lock:
|
||||
if _default_atlas is None:
|
||||
_default_atlas = AtlasMemory(
|
||||
db_path=get_hermes_home() / "atlas.db"
|
||||
)
|
||||
return _default_atlas
|
||||
903
agent/lossless_context.py
Normal file
903
agent/lossless_context.py
Normal file
@@ -0,0 +1,903 @@
|
||||
"""Lossless context + memory subsystem.
|
||||
|
||||
Inspired by hermes-lcm (immutable session storage, summary DAG compaction,
|
||||
lineage-aware externalization) and gbrain (strict store separation, deterministic
|
||||
typed-link extraction, disciplined retrieval boundaries).
|
||||
|
||||
Three-store architecture:
|
||||
- SESSION_STATE : per-session immutable turn log + summary DAG
|
||||
- DURABLE_MEMORY : user preferences, corrections, project facts (cross-session)
|
||||
- WORLD_KNOWLEDGE: stable world facts not specific to this user or session
|
||||
|
||||
Key types:
|
||||
TurnRecord — immutable raw turn with stable lineage ID (session:seq)
|
||||
SessionTurnStore — append-only JSONL store for raw turns
|
||||
SummaryNode — compacted summary with source references (turn IDs)
|
||||
SummaryDAG — append-only JSONL store for summary nodes
|
||||
RelationLink — a typed directed link between two entities
|
||||
LinkExtractor — deterministic regex-based typed-link extraction (≥5 types)
|
||||
StoreRouter — routes facts to the correct store tier
|
||||
RecallEngine — search / describe / expand recall operations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StoreTier(str, Enum):
|
||||
"""The three explicit memory tiers."""
|
||||
WORLD_KNOWLEDGE = "world_knowledge"
|
||||
DURABLE_MEMORY = "durable_memory"
|
||||
SESSION_STATE = "session_state"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed relation links
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RelationType(str, Enum):
|
||||
"""Supported typed relation types for link extraction."""
|
||||
PREFERS = "PREFERS" # entity A prefers B
|
||||
CORRECTS = "CORRECTS" # entity A corrects claim B
|
||||
USES = "USES" # entity A uses technology/tool B
|
||||
LOCATED_AT = "LOCATED_AT" # entity A is located at file/path B
|
||||
DEPENDS_ON = "DEPENDS_ON" # entity A depends on B
|
||||
CONFIGURES = "CONFIGURES" # entity A configures setting B
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationLink:
|
||||
"""A typed directed link between two text entities extracted from a turn."""
|
||||
relation: str # RelationType value
|
||||
subject: str # what/who the relation is about
|
||||
object_: str # the value / target
|
||||
source_turn_id: str # lineage reference to the turn it came from
|
||||
confidence: float = 0.7
|
||||
store_tier: str = StoreTier.DURABLE_MEMORY
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = asdict(self)
|
||||
d["object"] = d.pop("object_")
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "RelationLink":
|
||||
d = dict(d)
|
||||
d["object_"] = d.pop("object", d.pop("object_", ""))
|
||||
return cls(**d)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Turn record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TurnRecord:
|
||||
"""An immutable raw session turn with a stable lineage identifier.
|
||||
|
||||
The lineage_id is ``{session_id}:{seq}`` and never changes once assigned.
|
||||
"""
|
||||
lineage_id: str # stable identifier: session_id:seq_number
|
||||
session_id: str
|
||||
seq: int # monotonically increasing within the session
|
||||
role: str # "user" | "assistant" | "tool" | "system"
|
||||
content: str # raw text content
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
tool_name: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
content_sha256: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.content_sha256:
|
||||
self.content_sha256 = hashlib.sha256(
|
||||
self.content.encode("utf-8", errors="replace")
|
||||
).hexdigest()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "TurnRecord":
|
||||
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
|
||||
|
||||
@classmethod
|
||||
def create(cls, session_id: str, seq: int, role: str, content: str, **kwargs) -> "TurnRecord":
|
||||
return cls(
|
||||
lineage_id=f"{session_id}:{seq}",
|
||||
session_id=session_id,
|
||||
seq=seq,
|
||||
role=role,
|
||||
content=content,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session turn store — immutable append-only JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SessionTurnStore:
|
||||
"""Append-only immutable store for raw session turns.
|
||||
|
||||
Each turn is written as a single JSON line to:
|
||||
<hermes_home>/sessions/<session_id>/turns.jsonl
|
||||
|
||||
Turns are never modified or deleted (lossless). Compaction only creates
|
||||
SummaryNode entries in SummaryDAG — the raw turns always remain.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, hermes_home: Optional[Path] = None):
|
||||
self.session_id = session_id
|
||||
if hermes_home is None:
|
||||
from hermes_constants import get_hermes_home
|
||||
hermes_home = get_hermes_home()
|
||||
self._dir = Path(hermes_home) / "sessions" / session_id
|
||||
self._dir.mkdir(parents=True, exist_ok=True)
|
||||
self._turns_path = self._dir / "turns.jsonl"
|
||||
self._seq = self._load_next_seq()
|
||||
|
||||
def _load_next_seq(self) -> int:
|
||||
"""Determine next sequence number from existing file."""
|
||||
if not self._turns_path.exists():
|
||||
return 0
|
||||
seq = 0
|
||||
try:
|
||||
with open(self._turns_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
rec = json.loads(line)
|
||||
seq = max(seq, rec.get("seq", 0) + 1)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
return seq
|
||||
|
||||
def append(
|
||||
self,
|
||||
role: str,
|
||||
content: str,
|
||||
tool_name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> TurnRecord:
|
||||
"""Append a new turn and return the immutable record."""
|
||||
record = TurnRecord.create(
|
||||
session_id=self.session_id,
|
||||
seq=self._seq,
|
||||
role=role,
|
||||
content=content,
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
self._seq += 1
|
||||
self._write_record(record)
|
||||
return record
|
||||
|
||||
def _write_record(self, record: TurnRecord) -> None:
|
||||
try:
|
||||
line = json.dumps(record.to_dict(), ensure_ascii=False) + "\n"
|
||||
with open(self._turns_path, "a", encoding="utf-8") as f:
|
||||
f.write(line)
|
||||
f.flush()
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to persist turn %s: %s", record.lineage_id, e)
|
||||
|
||||
def load_all(self) -> List[TurnRecord]:
|
||||
"""Load all turns from disk, in sequence order."""
|
||||
if not self._turns_path.exists():
|
||||
return []
|
||||
records = []
|
||||
try:
|
||||
with open(self._turns_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
records.append(TurnRecord.from_dict(json.loads(line)))
|
||||
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
||||
logger.debug("Skipping malformed turn record: %s", e)
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to load turns: %s", e)
|
||||
return sorted(records, key=lambda r: r.seq)
|
||||
|
||||
def get_by_id(self, lineage_id: str) -> Optional[TurnRecord]:
|
||||
"""Retrieve a single turn by its lineage ID."""
|
||||
for r in self.load_all():
|
||||
if r.lineage_id == lineage_id:
|
||||
return r
|
||||
return None
|
||||
|
||||
def search(self, query: str, limit: int = 20) -> List[TurnRecord]:
|
||||
"""Full-text search over turn content (case-insensitive substring)."""
|
||||
q = query.lower()
|
||||
results = []
|
||||
for r in self.load_all():
|
||||
if q in r.content.lower():
|
||||
results.append(r)
|
||||
if len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
@property
|
||||
def turn_count(self) -> int:
|
||||
return self._seq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Summary DAG
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class SummaryNode:
|
||||
"""A compacted summary node with source lineage references.
|
||||
|
||||
Unlike destructive truncation, each SummaryNode keeps a list of
|
||||
``source_turn_ids`` pointing back to the raw turns it was built from.
|
||||
This makes compaction lossless-traceable: original content is always
|
||||
retrievable via the SessionTurnStore.
|
||||
"""
|
||||
node_id: str # unique ID within the session: summary:<n>
|
||||
session_id: str
|
||||
summary_text: str # the compacted summary text
|
||||
source_turn_ids: List[str] # lineage IDs of turns compressed into this node
|
||||
parent_node_id: Optional[str] = None # for DAG structure (iterative updates)
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
topic: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict[str, Any]) -> "SummaryNode":
|
||||
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
|
||||
|
||||
|
||||
class SummaryDAG:
|
||||
"""Append-only DAG of summary nodes for a session.
|
||||
|
||||
Stored at: <hermes_home>/sessions/<session_id>/summaries.jsonl
|
||||
|
||||
Each compaction appends a new SummaryNode that references the raw turns
|
||||
it summarised. Summary nodes can reference previous summary nodes as
|
||||
parents, building a DAG of progressive compactions.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, hermes_home: Optional[Path] = None):
|
||||
self.session_id = session_id
|
||||
if hermes_home is None:
|
||||
from hermes_constants import get_hermes_home
|
||||
hermes_home = get_hermes_home()
|
||||
self._dir = Path(hermes_home) / "sessions" / session_id
|
||||
self._dir.mkdir(parents=True, exist_ok=True)
|
||||
self._path = self._dir / "summaries.jsonl"
|
||||
self._counter = self._load_next_counter()
|
||||
|
||||
def _load_next_counter(self) -> int:
|
||||
if not self._path.exists():
|
||||
return 0
|
||||
n = 0
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
n += 1
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
return n
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
summary_text: str,
|
||||
source_turn_ids: List[str],
|
||||
parent_node_id: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> SummaryNode:
|
||||
"""Append a new summary node to the DAG and return it."""
|
||||
node_id = f"{self.session_id}:summary:{self._counter}"
|
||||
node = SummaryNode(
|
||||
node_id=node_id,
|
||||
session_id=self.session_id,
|
||||
summary_text=summary_text,
|
||||
source_turn_ids=source_turn_ids,
|
||||
parent_node_id=parent_node_id,
|
||||
topic=topic,
|
||||
)
|
||||
self._counter += 1
|
||||
self._write_node(node)
|
||||
return node
|
||||
|
||||
def _write_node(self, node: SummaryNode) -> None:
|
||||
try:
|
||||
line = json.dumps(node.to_dict(), ensure_ascii=False) + "\n"
|
||||
with open(self._path, "a", encoding="utf-8") as f:
|
||||
f.write(line)
|
||||
f.flush()
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to persist summary node %s: %s", node.node_id, e)
|
||||
|
||||
def load_all(self) -> List[SummaryNode]:
|
||||
"""Load all summary nodes, in creation order."""
|
||||
if not self._path.exists():
|
||||
return []
|
||||
nodes = []
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
nodes.append(SummaryNode.from_dict(json.loads(line)))
|
||||
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
||||
logger.debug("Skipping malformed summary node: %s", e)
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to load summary nodes: %s", e)
|
||||
return nodes
|
||||
|
||||
def get_by_id(self, node_id: str) -> Optional[SummaryNode]:
|
||||
for n in self.load_all():
|
||||
if n.node_id == node_id:
|
||||
return n
|
||||
return None
|
||||
|
||||
def get_latest(self) -> Optional[SummaryNode]:
|
||||
"""Return the most recently added summary node, or None."""
|
||||
nodes = self.load_all()
|
||||
return nodes[-1] if nodes else None
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SummaryNode]:
|
||||
"""Full-text search over summary text."""
|
||||
q = query.lower()
|
||||
return [n for n in self.load_all() if q in n.summary_text.lower()][:limit]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Relation link store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RelationLinkStore:
|
||||
"""Append-only store for typed relation links extracted from turns.
|
||||
|
||||
Stored at: <hermes_home>/sessions/<session_id>/links.jsonl
|
||||
Cross-session links go to: <hermes_home>/knowledge/links.jsonl
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def append(self, link: RelationLink) -> None:
|
||||
try:
|
||||
line = json.dumps(link.to_dict(), ensure_ascii=False) + "\n"
|
||||
with open(self._path, "a", encoding="utf-8") as f:
|
||||
f.write(line)
|
||||
f.flush()
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to persist relation link: %s", e)
|
||||
|
||||
def load_all(self) -> List[RelationLink]:
|
||||
if not self._path.exists():
|
||||
return []
|
||||
links = []
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
links.append(RelationLink.from_dict(json.loads(line)))
|
||||
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
||||
logger.debug("Skipping malformed link: %s", e)
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("Failed to load links: %s", e)
|
||||
return links
|
||||
|
||||
def search(self, query: str, relation: Optional[str] = None) -> List[RelationLink]:
|
||||
"""Search links by query substring and optional relation type filter."""
|
||||
q = query.lower()
|
||||
results = []
|
||||
for link in self.load_all():
|
||||
if relation and link.relation != relation:
|
||||
continue
|
||||
if q in link.subject.lower() or q in link.object_.lower():
|
||||
results.append(link)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed-link extraction (deterministic, regex-based)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each entry: (pattern, RelationType, subject_group, object_group, confidence, store_tier)
|
||||
_LINK_PATTERNS: List[Tuple] = [
|
||||
# PREFERS — "I prefer X over Y", "I prefer X", "prefer to use X"
|
||||
(
|
||||
r"(?:I|we)\s+prefer\s+(?:to\s+use\s+)?(.+?)(?:\s+over\s+(.+?))?(?:\.|$)",
|
||||
RelationType.PREFERS, "user", 1, 0.8, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
(
|
||||
r"my\s+preferred\s+\S+\s+is\s+(.+?)(?:\.|$)",
|
||||
RelationType.PREFERS, "user", 1, 0.75, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
# CORRECTS — "actually X", "no, it's X", "correction: X", "I meant X"
|
||||
(
|
||||
r"(?:actually[,:]?\s+|no[,:]?\s+|correction[,:]?\s+)(.{5,150})(?:\.|$)",
|
||||
RelationType.CORRECTS, "user", 1, 0.85, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
(
|
||||
r"(?:I meant|what I meant was)\s+(.{5,150})(?:\.|$)",
|
||||
RelationType.CORRECTS, "user", 1, 0.85, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
# USES — "the project uses X", "we use X", "using X"
|
||||
(
|
||||
r"(?:the\s+project|the\s+repo|the\s+codebase|we|this\s+app)\s+use[sd]?\s+(.{3,100})(?:\s+for\s+.+?)?(?:\.|$)",
|
||||
RelationType.USES, "project", 1, 0.7, StoreTier.WORLD_KNOWLEDGE,
|
||||
),
|
||||
(
|
||||
r"\busing\s+([A-Z][a-zA-Z0-9_-]{1,40})(?:\s+for\s+.+?)?(?:\.|,|$)",
|
||||
RelationType.USES, "project", 1, 0.65, StoreTier.WORLD_KNOWLEDGE,
|
||||
),
|
||||
# LOCATED_AT — file paths like "the config is at /path/to/file"
|
||||
(
|
||||
r"(?:the\s+)?(\w[\w\s]{0,30}?)\s+(?:is\s+(?:at|in|under)|lives\s+at|can\s+be\s+found\s+at)\s+((?:/|\.\.?/|~/)[^\s,;]+)",
|
||||
RelationType.LOCATED_AT, 1, 2, 0.8, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
(
|
||||
r"(?:path|file|config|config\s+file|directory|dir)\s+(?:is\s+)?['\"]?((?:/|\.\.?/|~/)[^\s'\"]+)['\"]?",
|
||||
RelationType.LOCATED_AT, "file", 1, 0.7, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
# DEPENDS_ON — "X requires Y", "X depends on Y"
|
||||
(
|
||||
r"(\S+)\s+(?:requires|depends\s+on|needs)\s+([A-Za-z0-9_.-]+(?:\s+version\s+[\d.]+)?)",
|
||||
RelationType.DEPENDS_ON, 1, 2, 0.7, StoreTier.WORLD_KNOWLEDGE,
|
||||
),
|
||||
# CONFIGURES — "X is set to Y", "X is configured as Y"
|
||||
(
|
||||
r"(\S+)\s+(?:is\s+set\s+to|is\s+configured\s+(?:as|to)|is\s+enabled|is\s+disabled)\s*(.{0,80})(?:\.|$)",
|
||||
RelationType.CONFIGURES, 1, 2, 0.75, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
(
|
||||
r"(?:set|configure)\s+(\S+)\s+(?:to|as)\s+(.{3,80})(?:\.|$)",
|
||||
RelationType.CONFIGURES, 1, 2, 0.7, StoreTier.DURABLE_MEMORY,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _extract_group(match: re.Match, group_ref: Any, text: str) -> str:
|
||||
"""Extract a group from a regex match, or use a literal string."""
|
||||
if isinstance(group_ref, int):
|
||||
try:
|
||||
val = match.group(group_ref)
|
||||
return (val or "").strip()
|
||||
except IndexError:
|
||||
return ""
|
||||
return str(group_ref)
|
||||
|
||||
|
||||
class LinkExtractor:
|
||||
"""Deterministic typed-link extraction from text.
|
||||
|
||||
Supports at least 5 relation types:
|
||||
PREFERS, CORRECTS, USES, LOCATED_AT, DEPENDS_ON, CONFIGURES
|
||||
|
||||
All extraction is regex-based (no LLM calls) — fully deterministic and
|
||||
fast, suitable for running on every turn write.
|
||||
"""
|
||||
|
||||
def extract(self, text: str, source_turn_id: str) -> List[RelationLink]:
|
||||
"""Extract typed links from text.
|
||||
|
||||
Args:
|
||||
text: Raw text to extract from.
|
||||
source_turn_id: Lineage ID of the turn this text came from.
|
||||
|
||||
Returns:
|
||||
List of extracted RelationLink objects (may be empty).
|
||||
"""
|
||||
if not text or len(text) < 8:
|
||||
return []
|
||||
|
||||
links: List[RelationLink] = []
|
||||
seen: set = set() # Dedup by (relation, subject, object)
|
||||
|
||||
for (pattern, rel_type, subj_ref, obj_ref, confidence, store_tier) in _LINK_PATTERNS:
|
||||
for match in re.finditer(pattern, text, re.IGNORECASE):
|
||||
subject = _extract_group(match, subj_ref, text)
|
||||
obj = _extract_group(match, obj_ref, text)
|
||||
|
||||
if not subject or not obj:
|
||||
continue
|
||||
if len(subject) > 200 or len(obj) > 200:
|
||||
continue
|
||||
# Clean up whitespace
|
||||
subject = " ".join(subject.split())
|
||||
obj = " ".join(obj.split())
|
||||
if not subject or not obj:
|
||||
continue
|
||||
|
||||
key = (rel_type.value, subject[:50], obj[:50])
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
|
||||
links.append(RelationLink(
|
||||
relation=rel_type.value,
|
||||
subject=subject,
|
||||
object_=obj,
|
||||
source_turn_id=source_turn_id,
|
||||
confidence=confidence,
|
||||
store_tier=store_tier.value,
|
||||
))
|
||||
|
||||
return links
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store router — explicit routing to world_knowledge / durable_memory / session_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StoreRouter:
|
||||
"""Routes fact writes to the correct memory tier.
|
||||
|
||||
Tiers:
|
||||
SESSION_STATE — raw turns and summaries (SessionTurnStore + SummaryDAG)
|
||||
DURABLE_MEMORY — user preferences, corrections, project facts (cross-session)
|
||||
WORLD_KNOWLEDGE — stable world facts, technology/tool usage (cross-session)
|
||||
|
||||
The router maintains a per-tier JSONL fact store and provides a unified
|
||||
``write`` interface. The existing MemoryStore (MEMORY.md / USER.md) handles
|
||||
the user-facing durable_memory surface; this router is the backend store.
|
||||
"""
|
||||
|
||||
def __init__(self, hermes_home: Optional[Path] = None):
|
||||
if hermes_home is None:
|
||||
from hermes_constants import get_hermes_home
|
||||
hermes_home = get_hermes_home()
|
||||
self._hermes_home = Path(hermes_home)
|
||||
self._stores: Dict[str, Path] = {
|
||||
StoreTier.WORLD_KNOWLEDGE: self._hermes_home / "knowledge" / "world_knowledge.jsonl",
|
||||
StoreTier.DURABLE_MEMORY: self._hermes_home / "memories" / "durable_memory.jsonl",
|
||||
}
|
||||
# Ensure directories exist
|
||||
for path in self._stores.values():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def write(
|
||||
self,
|
||||
tier: str,
|
||||
category: str,
|
||||
content: str,
|
||||
source_id: str = "",
|
||||
confidence: float = 0.7,
|
||||
) -> None:
|
||||
"""Write a fact to the specified tier store."""
|
||||
if tier == StoreTier.SESSION_STATE:
|
||||
return # session state is managed by SessionTurnStore, not here
|
||||
path = self._stores.get(tier)
|
||||
if path is None:
|
||||
logger.warning("StoreRouter: unknown tier '%s'", tier)
|
||||
return
|
||||
entry = {
|
||||
"tier": tier,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"source_id": source_id,
|
||||
"confidence": confidence,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
try:
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("StoreRouter: failed to write to %s: %s", tier, e)
|
||||
|
||||
def load(self, tier: str) -> List[Dict[str, Any]]:
|
||||
"""Load all facts from a tier store."""
|
||||
path = self._stores.get(tier)
|
||||
if not path or not path.exists():
|
||||
return []
|
||||
facts = []
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
facts.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except (OSError, IOError) as e:
|
||||
logger.error("StoreRouter: failed to load %s: %s", tier, e)
|
||||
return facts
|
||||
|
||||
def search(self, query: str, tier: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Search facts across tiers (or a single tier if specified)."""
|
||||
q = query.lower()
|
||||
results = []
|
||||
tiers_to_search = [tier] if tier else list(self._stores.keys())
|
||||
for t in tiers_to_search:
|
||||
for fact in self.load(t):
|
||||
content = fact.get("content", "")
|
||||
if q in content.lower():
|
||||
results.append(fact)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recall engine — search / describe / expand
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RecallEngine:
|
||||
"""Explicit recall operations over compacted context.
|
||||
|
||||
Three primary recall operations:
|
||||
search(query) — full-text search across turns, summaries, and links
|
||||
describe(lineage_id) — retrieve details about a specific turn or summary node
|
||||
expand(summary_node_id)— expand a summary node to its source turns
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
turn_store: SessionTurnStore,
|
||||
summary_dag: SummaryDAG,
|
||||
link_store: RelationLinkStore,
|
||||
store_router: StoreRouter,
|
||||
):
|
||||
self._turns = turn_store
|
||||
self._dag = summary_dag
|
||||
self._links = link_store
|
||||
self._router = store_router
|
||||
|
||||
def search(self, query: str, limit: int = 20) -> Dict[str, Any]:
|
||||
"""Search across session turns, summary nodes, relation links, and durable stores.
|
||||
|
||||
Returns a dict with matching results grouped by source type.
|
||||
"""
|
||||
matching_turns = self._turns.search(query, limit=limit)
|
||||
matching_summaries = self._dag.search(query, limit=limit // 2)
|
||||
matching_links = self._links.search(query)
|
||||
durable_facts = self._router.search(query)
|
||||
world_facts = self._router.search(query, tier=StoreTier.WORLD_KNOWLEDGE)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"turns": [
|
||||
{
|
||||
"lineage_id": t.lineage_id,
|
||||
"role": t.role,
|
||||
"content_preview": t.content[:200] + ("..." if len(t.content) > 200 else ""),
|
||||
"timestamp": t.timestamp,
|
||||
}
|
||||
for t in matching_turns
|
||||
],
|
||||
"summaries": [
|
||||
{
|
||||
"node_id": n.node_id,
|
||||
"summary_preview": n.summary_text[:300] + ("..." if len(n.summary_text) > 300 else ""),
|
||||
"source_count": len(n.source_turn_ids),
|
||||
"timestamp": n.timestamp,
|
||||
}
|
||||
for n in matching_summaries
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"relation": lk.relation,
|
||||
"subject": lk.subject,
|
||||
"object": lk.object_,
|
||||
"source_turn_id": lk.source_turn_id,
|
||||
"confidence": lk.confidence,
|
||||
}
|
||||
for lk in matching_links
|
||||
],
|
||||
"durable_facts": durable_facts[:10],
|
||||
"world_facts": world_facts[:10],
|
||||
"total_results": (
|
||||
len(matching_turns) + len(matching_summaries)
|
||||
+ len(matching_links) + len(durable_facts) + len(world_facts)
|
||||
),
|
||||
}
|
||||
|
||||
def describe(self, lineage_id: str) -> Dict[str, Any]:
|
||||
"""Describe a specific turn or summary node by its lineage ID.
|
||||
|
||||
Accepts both turn IDs (``session:seq``) and summary node IDs
|
||||
(``session:summary:n``).
|
||||
"""
|
||||
# Try summary nodes first (contain "summary" in ID)
|
||||
if ":summary:" in lineage_id:
|
||||
node = self._dag.get_by_id(lineage_id)
|
||||
if node:
|
||||
return {
|
||||
"type": "summary_node",
|
||||
"node_id": node.node_id,
|
||||
"summary_text": node.summary_text,
|
||||
"source_turn_ids": node.source_turn_ids,
|
||||
"parent_node_id": node.parent_node_id,
|
||||
"topic": node.topic,
|
||||
"timestamp": node.timestamp,
|
||||
"source_count": len(node.source_turn_ids),
|
||||
}
|
||||
return {"error": f"Summary node not found: {lineage_id}"}
|
||||
|
||||
# Try turn records
|
||||
turn = self._turns.get_by_id(lineage_id)
|
||||
if turn:
|
||||
links = self._links.search("", relation=None)
|
||||
turn_links = [lk for lk in links if lk.source_turn_id == lineage_id]
|
||||
return {
|
||||
"type": "turn",
|
||||
"lineage_id": turn.lineage_id,
|
||||
"session_id": turn.session_id,
|
||||
"seq": turn.seq,
|
||||
"role": turn.role,
|
||||
"content": turn.content,
|
||||
"tool_name": turn.tool_name,
|
||||
"tool_call_id": turn.tool_call_id,
|
||||
"timestamp": turn.timestamp,
|
||||
"content_sha256": turn.content_sha256,
|
||||
"extracted_links": [lk.to_dict() for lk in turn_links],
|
||||
}
|
||||
|
||||
return {"error": f"Turn record not found: {lineage_id}"}
|
||||
|
||||
def expand(self, summary_node_id: str) -> Dict[str, Any]:
|
||||
"""Expand a summary node to show the source turns it was built from.
|
||||
|
||||
This is the key lossless recall operation: even after compaction, the
|
||||
original turns are always retrievable via their lineage IDs.
|
||||
"""
|
||||
node = self._dag.get_by_id(summary_node_id)
|
||||
if not node:
|
||||
return {"error": f"Summary node not found: {summary_node_id}"}
|
||||
|
||||
source_turns = []
|
||||
for turn_id in node.source_turn_ids:
|
||||
turn = self._turns.get_by_id(turn_id)
|
||||
if turn:
|
||||
source_turns.append({
|
||||
"lineage_id": turn.lineage_id,
|
||||
"role": turn.role,
|
||||
"content": turn.content,
|
||||
"timestamp": turn.timestamp,
|
||||
})
|
||||
else:
|
||||
source_turns.append({
|
||||
"lineage_id": turn_id,
|
||||
"error": "Turn not found in local store",
|
||||
})
|
||||
|
||||
# Walk the DAG to find parent summaries
|
||||
parent_chain = []
|
||||
current = node
|
||||
seen_ids: set = {node.node_id}
|
||||
while current.parent_node_id and current.parent_node_id not in seen_ids:
|
||||
parent = self._dag.get_by_id(current.parent_node_id)
|
||||
if not parent:
|
||||
break
|
||||
parent_chain.append({
|
||||
"node_id": parent.node_id,
|
||||
"summary_preview": parent.summary_text[:150] + "...",
|
||||
})
|
||||
seen_ids.add(parent.node_id)
|
||||
current = parent
|
||||
|
||||
return {
|
||||
"node_id": node.node_id,
|
||||
"summary_text": node.summary_text,
|
||||
"topic": node.topic,
|
||||
"timestamp": node.timestamp,
|
||||
"source_turns": source_turns,
|
||||
"source_count": len(source_turns),
|
||||
"found_count": sum(1 for t in source_turns if "error" not in t),
|
||||
"parent_chain": parent_chain,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_lossless_context(
|
||||
session_id: str,
|
||||
hermes_home: Optional[Path] = None,
|
||||
) -> Tuple[SessionTurnStore, SummaryDAG, LinkExtractor, RelationLinkStore, StoreRouter, RecallEngine]:
|
||||
"""Create all components of the lossless context subsystem for a session.
|
||||
|
||||
Returns a tuple of (turn_store, summary_dag, link_extractor, link_store,
|
||||
store_router, recall_engine).
|
||||
"""
|
||||
if hermes_home is None:
|
||||
from hermes_constants import get_hermes_home
|
||||
hermes_home = get_hermes_home()
|
||||
hermes_home = Path(hermes_home)
|
||||
|
||||
turn_store = SessionTurnStore(session_id, hermes_home=hermes_home)
|
||||
summary_dag = SummaryDAG(session_id, hermes_home=hermes_home)
|
||||
link_extractor = LinkExtractor()
|
||||
link_store = RelationLinkStore(hermes_home / "sessions" / session_id / "links.jsonl")
|
||||
store_router = StoreRouter(hermes_home=hermes_home)
|
||||
recall_engine = RecallEngine(turn_store, summary_dag, link_store, store_router)
|
||||
|
||||
return turn_store, summary_dag, link_extractor, link_store, store_router, recall_engine
|
||||
|
||||
|
||||
def ingest_turn(
|
||||
text: str,
|
||||
role: str,
|
||||
turn_store: SessionTurnStore,
|
||||
link_extractor: LinkExtractor,
|
||||
link_store: RelationLinkStore,
|
||||
store_router: StoreRouter,
|
||||
tool_name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> TurnRecord:
|
||||
"""Append a turn to the session store and extract + route typed links.
|
||||
|
||||
This is the recommended write entry point. It:
|
||||
1. Appends the raw turn to SessionTurnStore (SESSION_STATE)
|
||||
2. Extracts typed links deterministically (no LLM)
|
||||
3. Persists links to the session link store
|
||||
4. Routes high-confidence links to WORLD_KNOWLEDGE / DURABLE_MEMORY
|
||||
|
||||
Returns the TurnRecord with its stable lineage_id.
|
||||
"""
|
||||
record = turn_store.append(role=role, content=text, tool_name=tool_name, tool_call_id=tool_call_id)
|
||||
|
||||
# Only extract links from user/assistant text (not tool results — they're noisy)
|
||||
if role in ("user", "assistant") and text:
|
||||
links = link_extractor.extract(text, source_turn_id=record.lineage_id)
|
||||
for link in links:
|
||||
link_store.append(link)
|
||||
# Route to the appropriate durable store
|
||||
if link.confidence >= 0.7:
|
||||
store_router.write(
|
||||
tier=link.store_tier,
|
||||
category=link.relation,
|
||||
content=f"{link.subject} {link.relation} {link.object_}",
|
||||
source_id=record.lineage_id,
|
||||
confidence=link.confidence,
|
||||
)
|
||||
|
||||
return record
|
||||
|
||||
|
||||
def compact_turns_to_dag(
|
||||
summary_text: str,
|
||||
source_turn_ids: List[str],
|
||||
summary_dag: SummaryDAG,
|
||||
parent_node_id: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> SummaryNode:
|
||||
"""Create a summary DAG node from a set of source turn IDs.
|
||||
|
||||
This is called after the ContextCompressor generates a summary, to
|
||||
record the compaction in the DAG with proper lineage references.
|
||||
The raw turns are NOT deleted — they remain in SessionTurnStore.
|
||||
|
||||
Returns the new SummaryNode.
|
||||
"""
|
||||
return summary_dag.add_node(
|
||||
summary_text=summary_text,
|
||||
source_turn_ids=source_turn_ids,
|
||||
parent_node_id=parent_node_id,
|
||||
topic=topic,
|
||||
)
|
||||
13
plugins/memory/atlas/__init__.py
Normal file
13
plugins/memory/atlas/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""ATLAS memory plugin package.
|
||||
|
||||
Registers the AtlasMemoryProvider with Hermes memory plugin discovery.
|
||||
"""
|
||||
|
||||
from plugins.memory.atlas.provider import AtlasMemoryProvider
|
||||
|
||||
|
||||
def register(manager):
|
||||
"""Plugin entry point called by MemoryManager plugin loader."""
|
||||
provider = AtlasMemoryProvider()
|
||||
if provider.is_available():
|
||||
manager.add_provider(provider)
|
||||
9
plugins/memory/atlas/plugin.yaml
Normal file
9
plugins/memory/atlas/plugin.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
name: atlas
|
||||
description: >
|
||||
ATLAS lossless context + memory subsystem.
|
||||
Persists every turn with a stable lineage ID, builds summary DAG nodes
|
||||
during compaction, and routes writes to explicit stores (world knowledge,
|
||||
durable memory, session state). Exposes atlas_search, atlas_describe,
|
||||
and atlas_expand recall tools.
|
||||
version: "1.0.0"
|
||||
author: hermes-agent
|
||||
316
plugins/memory/atlas/provider.py
Normal file
316
plugins/memory/atlas/provider.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""AtlasMemoryProvider — MemoryProvider integration for the ATLAS subsystem.
|
||||
|
||||
Wires the ATLAS core (RawTurnStore, SummaryDAGStore, AtlasStores,
|
||||
TypedLinkExtractor, RecallEngine) into Hermes' MemoryProvider lifecycle.
|
||||
|
||||
Responsibilities
|
||||
----------------
|
||||
* On every turn: persist the raw turn to RawTurnStore (lossless).
|
||||
* On pre-compress: build a SummaryDAGStore node from the messages being
|
||||
discarded (DAG compaction — no destructive deletion).
|
||||
* Provide tool schemas for atlas_search, atlas_describe, atlas_expand.
|
||||
* Route explicit memory writes to the correct store via AtlasStores.
|
||||
* Extract typed links on every turn write.
|
||||
|
||||
Configuration (config.yaml)
|
||||
---------------------------
|
||||
memory:
|
||||
provider: atlas
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AtlasMemoryProvider(MemoryProvider):
|
||||
"""Lossless context + memory subsystem for Hermes."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# MemoryProvider identity
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "atlas"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Always available — pure-local SQLite, no external deps."""
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
hermes_home = Path(kwargs.get("hermes_home", Path.home() / ".hermes"))
|
||||
db_path = hermes_home / "atlas.db"
|
||||
|
||||
from agent.atlas.db import AtlasDB
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
from agent.atlas.dag import SummaryDAGStore
|
||||
from agent.atlas.stores import AtlasStores
|
||||
from agent.atlas.extractor import TypedLinkExtractor
|
||||
from agent.atlas.recall import RecallEngine
|
||||
|
||||
self._db = AtlasDB(db_path)
|
||||
self._db.open()
|
||||
self._turns = RawTurnStore(self._db)
|
||||
self._dag = SummaryDAGStore(self._db)
|
||||
self._stores = AtlasStores(self._db)
|
||||
self._extractor = TypedLinkExtractor(self._db)
|
||||
self._recall = RecallEngine(self._db)
|
||||
self._session_id = session_id
|
||||
|
||||
logger.info("ATLAS provider initialised (session=%s, db=%s)", session_id, db_path)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if hasattr(self, "_db"):
|
||||
self._db.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# System prompt
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not hasattr(self, "_recall"):
|
||||
return ""
|
||||
return (
|
||||
"## ATLAS Memory\n"
|
||||
"Your conversation is persisted losslessly in the ATLAS memory subsystem.\n"
|
||||
"Use these tools to recall compacted context:\n"
|
||||
" atlas_search(query) — full-text search over turns, summaries, and memory\n"
|
||||
" atlas_describe(id) — describe a specific turn or summary node\n"
|
||||
" atlas_expand(node_id) — expand a summary node to its source turns\n"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Turn persistence
|
||||
|
||||
def sync_turn(
|
||||
self, user_content: str, assistant_content: str, *, session_id: str = ""
|
||||
) -> None:
|
||||
"""Persist user+assistant turns and extract typed links."""
|
||||
if not hasattr(self, "_turns"):
|
||||
return
|
||||
sid = session_id or getattr(self, "_session_id", "default")
|
||||
ts = time.time()
|
||||
|
||||
if user_content:
|
||||
turn_id = self._turns.append(
|
||||
sid, "user", user_content, timestamp=ts
|
||||
)
|
||||
self._extractor.extract_and_store(user_content, turn_id)
|
||||
|
||||
if assistant_content:
|
||||
turn_id = self._turns.append(
|
||||
sid, "assistant", assistant_content, timestamp=ts + 0.001
|
||||
)
|
||||
self._extractor.extract_and_store(assistant_content, turn_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compaction hook — produces summary DAG nodes
|
||||
|
||||
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Build a DAG summary node from messages about to be compressed.
|
||||
|
||||
The summary text is derived from assistant messages in the to-be-
|
||||
compressed window (no LLM call — uses first 500 chars of each
|
||||
assistant turn as an extractive summary). More sophisticated
|
||||
summarisation can be added later without changing the interface.
|
||||
"""
|
||||
if not hasattr(self, "_dag") or not messages:
|
||||
return ""
|
||||
sid = getattr(self, "_session_id", "default")
|
||||
|
||||
# Collect turn IDs for turns that are already persisted
|
||||
source_turn_ids = []
|
||||
summary_parts = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if not content or not isinstance(content, str):
|
||||
continue
|
||||
# Best-effort match to persisted turn_ids via content search
|
||||
# (exact matching not required; lineage is tracked by session)
|
||||
if role in ("user", "assistant", "tool"):
|
||||
turns = self._turns.get_session_turns(sid)
|
||||
for t in turns:
|
||||
if t.content.strip()[:200] == content.strip()[:200]:
|
||||
source_turn_ids.append(t.turn_id)
|
||||
break
|
||||
if role == "assistant":
|
||||
summary_parts.append(content[:500])
|
||||
|
||||
if not summary_parts:
|
||||
return ""
|
||||
|
||||
summary_text = "\n\n".join(summary_parts)[:2000]
|
||||
node = self._dag.add_node(
|
||||
sid,
|
||||
summary_text=summary_text,
|
||||
source_turn_ids=source_turn_ids or [],
|
||||
)
|
||||
# Also extract links from summary
|
||||
self._extractor.extract_and_store(
|
||||
summary_text, node.node_id, source_type="dag"
|
||||
)
|
||||
logger.info(
|
||||
"ATLAS: compaction DAG node %s created (%d source turns)",
|
||||
node.node_id,
|
||||
node.source_count(),
|
||||
)
|
||||
# Return summary text so the compressor can embed it in its prompt
|
||||
return f"[ATLAS compaction summary — node {node.node_id}]\n{summary_text}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"name": "atlas_search",
|
||||
"description": (
|
||||
"Search the ATLAS memory corpus (raw turns, DAG summary nodes, "
|
||||
"durable memory, world knowledge) using full-text search. "
|
||||
"Use this to recall facts from earlier in the conversation even "
|
||||
"after context compaction."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Free-text search query.",
|
||||
},
|
||||
"doc_types": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": ["raw_turn", "dag", "durable", "world"]},
|
||||
"description": "Limit results to these doc types (omit for all).",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default 10).",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "atlas_describe",
|
||||
"description": (
|
||||
"Describe a specific turn or DAG summary node by its lineage ID. "
|
||||
"Turn IDs look like 'session_id:NNNN'. "
|
||||
"DAG node IDs look like 'dag:session_id:NNNN'."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "The turn_id or DAG node_id to describe.",
|
||||
},
|
||||
},
|
||||
"required": ["id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "atlas_expand",
|
||||
"description": (
|
||||
"Expand a DAG summary node to reveal the source turns it compacted. "
|
||||
"Lets you recover specific facts from compacted context without "
|
||||
"re-injecting the full original transcript."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"node_id": {
|
||||
"type": "string",
|
||||
"description": "The DAG node_id to expand (e.g. 'dag:session_id:0000').",
|
||||
},
|
||||
"max_turns": {
|
||||
"type": "integer",
|
||||
"description": "Maximum source turns to include (default 20).",
|
||||
"default": 20,
|
||||
},
|
||||
},
|
||||
"required": ["node_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "atlas_store",
|
||||
"description": (
|
||||
"Write a fact to the appropriate ATLAS store. "
|
||||
"Use 'world:' prefix for world knowledge, 'session:' for session state, "
|
||||
"or no prefix for durable memory."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The fact to store.",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "Category tag (e.g. 'user_pref', 'project', 'world:geography').",
|
||||
},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||
if not hasattr(self, "_recall"):
|
||||
return json.dumps({"error": "ATLAS provider not initialised"})
|
||||
try:
|
||||
if tool_name == "atlas_search":
|
||||
result = self._recall.search(
|
||||
args["query"],
|
||||
limit=args.get("limit", 10),
|
||||
doc_types=args.get("doc_types"),
|
||||
)
|
||||
return json.dumps({"result": result})
|
||||
|
||||
elif tool_name == "atlas_describe":
|
||||
result = self._recall.describe(args["id"])
|
||||
return json.dumps({"result": result})
|
||||
|
||||
elif tool_name == "atlas_expand":
|
||||
result = self._recall.expand(
|
||||
args["node_id"],
|
||||
max_turns=args.get("max_turns", 20),
|
||||
)
|
||||
return json.dumps({"result": result})
|
||||
|
||||
elif tool_name == "atlas_store":
|
||||
sid = getattr(self, "_session_id", "default")
|
||||
target, row_id = self._stores.write(
|
||||
args["content"],
|
||||
category=args.get("category", ""),
|
||||
session_id=sid,
|
||||
)
|
||||
return json.dumps({"result": f"Stored to {target.value} store (id={row_id})"})
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
except Exception as exc:
|
||||
logger.error("ATLAS tool %s failed: %s", tool_name, exc)
|
||||
return json.dumps({"error": str(exc)})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Config schema (for `hermes memory setup`)
|
||||
|
||||
def get_config_schema(self) -> List[Dict[str, Any]]:
|
||||
return [] # fully local; no credentials needed
|
||||
408
tests/agent/test_atlas_memory.py
Normal file
408
tests/agent/test_atlas_memory.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Tests for agent/atlas_memory.py — ATLAS lossless memory subsystem.
|
||||
|
||||
All tests use tmp_path and monkeypatch HERMES_HOME so that no writes
|
||||
go to ~/.hermes/.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def atlas(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
from agent.atlas_memory import AtlasMemory
|
||||
|
||||
m = AtlasMemory(db_path=tmp_path / "atlas.db")
|
||||
yield m
|
||||
m.db.close()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestTurnStore
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestTurnStore:
|
||||
def test_append_and_retrieve(self, atlas):
|
||||
"""Append a turn and retrieve all fields by turn_id."""
|
||||
turn_id = atlas.turns.append_turn(
|
||||
session_id="sess1",
|
||||
role="user",
|
||||
content="Hello world",
|
||||
)
|
||||
row = atlas.turns.get_turn(turn_id)
|
||||
assert row is not None
|
||||
assert row["turn_id"] == turn_id
|
||||
assert row["session_id"] == "sess1"
|
||||
assert row["role"] == "user"
|
||||
assert row["content"] == "Hello world"
|
||||
assert row["created_at"] > 0
|
||||
|
||||
def test_stable_lineage_id(self, atlas):
|
||||
"""turn_id format is {session_id}:000001 for the first turn."""
|
||||
turn_id = atlas.turns.append_turn("mysession", "user", content="hi")
|
||||
assert turn_id == "mysession:000001"
|
||||
|
||||
def test_immutable_sequence(self, atlas):
|
||||
"""seq_num monotonically increases within a session."""
|
||||
ids = [
|
||||
atlas.turns.append_turn("s", "user", content=f"msg {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
expected = [f"s:{n:06d}" for n in range(1, 6)]
|
||||
assert ids == expected
|
||||
|
||||
turns = atlas.turns.get_session_turns("s")
|
||||
seq_nums = [t["seq_num"] for t in turns]
|
||||
assert seq_nums == list(range(1, 6))
|
||||
|
||||
def test_cross_session_isolation(self, atlas):
|
||||
"""Turns from different sessions do not mix."""
|
||||
atlas.turns.append_turn("alpha", "user", content="alpha msg")
|
||||
atlas.turns.append_turn("beta", "user", content="beta msg")
|
||||
|
||||
alpha_turns = atlas.turns.get_session_turns("alpha")
|
||||
beta_turns = atlas.turns.get_session_turns("beta")
|
||||
|
||||
assert len(alpha_turns) == 1
|
||||
assert len(beta_turns) == 1
|
||||
assert alpha_turns[0]["content"] == "alpha msg"
|
||||
assert beta_turns[0]["content"] == "beta msg"
|
||||
|
||||
# seq_num resets per session
|
||||
assert alpha_turns[0]["seq_num"] == 1
|
||||
assert beta_turns[0]["seq_num"] == 1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestSummaryDAG
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSummaryDAG:
|
||||
def test_create_summary_node(self, atlas):
|
||||
"""Create a summary node with source_turn_ids."""
|
||||
t1 = atlas.turns.append_turn("s1", "user", content="turn one")
|
||||
t2 = atlas.turns.append_turn("s1", "assistant", content="turn two")
|
||||
|
||||
node_id = atlas.dag.create_summary_node(
|
||||
session_id="s1",
|
||||
summary_text="Summary of turns",
|
||||
source_turn_ids=[t1, t2],
|
||||
)
|
||||
assert node_id.startswith("sum:s1:")
|
||||
|
||||
node = atlas.dag.get_node(node_id)
|
||||
assert node is not None
|
||||
assert node["session_id"] == "s1"
|
||||
assert node["summary_text"] == "Summary of turns"
|
||||
assert t1 in node["source_turn_ids"]
|
||||
assert t2 in node["source_turn_ids"]
|
||||
|
||||
def test_node_references_source_turns(self, atlas):
|
||||
"""node_id links back to the exact source turn IDs."""
|
||||
t1 = atlas.turns.append_turn("s2", "user", content="original turn")
|
||||
node_id = atlas.dag.create_summary_node("s2", "A summary", [t1])
|
||||
|
||||
node = atlas.dag.get_node(node_id)
|
||||
assert node["source_turn_ids"] == [t1]
|
||||
|
||||
def test_expand_recovers_original_turns(self, atlas):
|
||||
"""expand() returns the original turn content (lossless recall)."""
|
||||
t1 = atlas.turns.append_turn("s3", "user", content="the original content")
|
||||
node_id = atlas.dag.create_summary_node("s3", "Compact form", [t1])
|
||||
|
||||
source_turns = atlas.dag.get_source_turns(node_id, atlas.turns)
|
||||
assert len(source_turns) == 1
|
||||
assert source_turns[0]["content"] == "the original content"
|
||||
assert source_turns[0]["turn_id"] == t1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestRecallEngine
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestRecallEngine:
|
||||
def test_search_finds_turn(self, atlas):
|
||||
"""search('password') finds a turn whose content contains 'password'."""
|
||||
atlas.turns.append_turn("s", "user", content="my password is hunter2")
|
||||
atlas.turns.append_turn("s", "user", content="unrelated message")
|
||||
|
||||
results = atlas.recall.search("password", session_id="s")
|
||||
assert len(results) == 1
|
||||
assert "password" in results[0]["content"]
|
||||
|
||||
def test_search_finds_across_sessions_when_no_session_filter(self, atlas):
|
||||
"""Without session_id filter, search spans all sessions."""
|
||||
atlas.turns.append_turn("sess_a", "user", content="my password is here")
|
||||
atlas.turns.append_turn("sess_b", "user", content="no secrets here")
|
||||
|
||||
results = atlas.recall.search("password")
|
||||
assert any(r["session_id"] == "sess_a" for r in results)
|
||||
|
||||
def test_describe_turn(self, atlas):
|
||||
"""describe(turn_id) returns the full turn dict."""
|
||||
tid = atlas.turns.append_turn("s", "user", content="describe me")
|
||||
desc = atlas.recall.describe(tid)
|
||||
assert desc is not None
|
||||
assert desc["content"] == "describe me"
|
||||
assert desc["type"] == "turn"
|
||||
|
||||
def test_describe_node(self, atlas):
|
||||
"""describe(node_id) returns the summary node dict."""
|
||||
t1 = atlas.turns.append_turn("s", "user", content="raw turn")
|
||||
node_id = atlas.dag.create_summary_node("s", "a node summary", [t1])
|
||||
|
||||
desc = atlas.recall.describe(node_id)
|
||||
assert desc is not None
|
||||
assert desc["summary_text"] == "a node summary"
|
||||
assert desc["type"] == "summary_node"
|
||||
|
||||
def test_describe_nonexistent_returns_none(self, atlas):
|
||||
"""describe() returns None for an unknown ID."""
|
||||
assert atlas.recall.describe("nonexistent:999") is None
|
||||
|
||||
def test_expand_summary(self, atlas):
|
||||
"""expand(node_id) returns both node and source turns."""
|
||||
t1 = atlas.turns.append_turn("s", "user", content="original")
|
||||
node_id = atlas.dag.create_summary_node("s", "compact", [t1])
|
||||
|
||||
result = atlas.recall.expand(node_id)
|
||||
assert result["node"] is not None
|
||||
assert result["node"]["node_id"] == node_id
|
||||
assert len(result["source_turns"]) == 1
|
||||
assert result["source_turns"][0]["content"] == "original"
|
||||
|
||||
def test_search_finds_summary_node(self, atlas):
|
||||
"""search() also finds content in summary nodes."""
|
||||
t1 = atlas.turns.append_turn("s", "user", content="turn content")
|
||||
atlas.dag.create_summary_node("s", "important summary keyword here", [t1])
|
||||
|
||||
results = atlas.recall.search("important summary keyword")
|
||||
assert any(r["type"] == "summary_node" for r in results)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestAtlasStore (three-store routing)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAtlasStore:
|
||||
def test_write_world_knowledge(self, atlas):
|
||||
"""Write to world_knowledge and read back."""
|
||||
atlas.store.write("world_knowledge", "capital_of_france", "Paris")
|
||||
value = atlas.store.read("world_knowledge", "capital_of_france")
|
||||
assert value == "Paris"
|
||||
|
||||
def test_write_durable_memory(self, atlas):
|
||||
"""Write to durable_memory and read back."""
|
||||
atlas.store.write("durable_memory", "user_pref", "dark_mode")
|
||||
value = atlas.store.read("durable_memory", "user_pref")
|
||||
assert value == "dark_mode"
|
||||
|
||||
def test_write_session_state(self, atlas):
|
||||
"""Write to session_state with session_id and read back."""
|
||||
atlas.store.write("session_state", "current_task", "coding", session_id="sess1")
|
||||
value = atlas.store.read("session_state", "current_task", session_id="sess1")
|
||||
assert value == "coding"
|
||||
|
||||
def test_session_state_scoped(self, atlas):
|
||||
"""session_state key for session A is not visible from session B."""
|
||||
atlas.store.write("session_state", "task", "A's task", session_id="sessA")
|
||||
value_b = atlas.store.read("session_state", "task", session_id="sessB")
|
||||
assert value_b is None
|
||||
|
||||
def test_invalid_store_rejected(self, atlas):
|
||||
"""Writing to 'mixed_bucket' raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid store"):
|
||||
atlas.store.write("mixed_bucket", "key", "value")
|
||||
|
||||
def test_invalid_store_rejected_on_read(self, atlas):
|
||||
"""Reading from an invalid store also raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid store"):
|
||||
atlas.store.read("garbage_store", "key")
|
||||
|
||||
def test_three_stores_independent(self, atlas):
|
||||
"""Same key in different stores has different values."""
|
||||
atlas.store.write("world_knowledge", "x", "world_value")
|
||||
atlas.store.write("durable_memory", "x", "durable_value")
|
||||
|
||||
assert atlas.store.read("world_knowledge", "x") == "world_value"
|
||||
assert atlas.store.read("durable_memory", "x") == "durable_value"
|
||||
|
||||
def test_list_keys(self, atlas):
|
||||
"""list_keys returns all keys for a store."""
|
||||
atlas.store.write("world_knowledge", "a", "1")
|
||||
atlas.store.write("world_knowledge", "b", "2")
|
||||
keys = atlas.store.list_keys("world_knowledge")
|
||||
assert set(keys) == {"a", "b"}
|
||||
|
||||
def test_delete(self, atlas):
|
||||
"""delete removes the key from the store."""
|
||||
atlas.store.write("durable_memory", "to_delete", "bye")
|
||||
atlas.store.delete("durable_memory", "to_delete")
|
||||
assert atlas.store.read("durable_memory", "to_delete") is None
|
||||
|
||||
def test_overwrite(self, atlas):
|
||||
"""Writing to the same key updates the value."""
|
||||
atlas.store.write("world_knowledge", "k", "v1")
|
||||
atlas.store.write("world_knowledge", "k", "v2")
|
||||
assert atlas.store.read("world_knowledge", "k") == "v2"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestTypedLinkExtractor — fixture-backed, all 5 relation types
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestTypedLinkExtractor:
|
||||
def test_extracts_defines(self, atlas):
|
||||
"""'Python is a programming language' → DEFINES link."""
|
||||
t = atlas.record_turn("s", "user", content="Python is a programming language")
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
types = [lnk["link_type"] for lnk in links]
|
||||
assert "DEFINES" in types
|
||||
|
||||
def test_extracts_modifies(self, atlas):
|
||||
"""'I changed the port to 8081' → MODIFIES link."""
|
||||
t = atlas.record_turn("s", "user", content="I changed the port to 8081")
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
types = [lnk["link_type"] for lnk in links]
|
||||
assert "MODIFIES" in types
|
||||
|
||||
def test_extracts_references(self, atlas):
|
||||
"""'as mentioned earlier' → REFERENCES link."""
|
||||
t = atlas.record_turn("s", "user", content="as mentioned earlier, this is important")
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
types = [lnk["link_type"] for lnk in links]
|
||||
assert "REFERENCES" in types
|
||||
|
||||
def test_extracts_depends_on(self, atlas):
|
||||
"""'This requires Redis to be running' → DEPENDS_ON link."""
|
||||
t = atlas.record_turn("s", "user", content="This requires Redis to be running")
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
types = [lnk["link_type"] for lnk in links]
|
||||
assert "DEPENDS_ON" in types
|
||||
|
||||
def test_extracts_contradicts(self, atlas):
|
||||
"""'Actually that is wrong, the port is 8081' → CONTRADICTS link."""
|
||||
t = atlas.record_turn(
|
||||
"s", "user", content="Actually that is wrong, the port is 8081"
|
||||
)
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
types = [lnk["link_type"] for lnk in links]
|
||||
assert "CONTRADICTS" in types
|
||||
|
||||
def test_all_five_types_present(self, atlas):
|
||||
"""A multi-statement paragraph contains all 5 relation types."""
|
||||
content = (
|
||||
"Python is a programming language. " # DEFINES
|
||||
"I changed the config file. " # MODIFIES
|
||||
"As mentioned earlier, we use Redis. " # REFERENCES
|
||||
"This requires Python to be installed. " # DEPENDS_ON
|
||||
"Actually that is wrong, not Java." # CONTRADICTS
|
||||
)
|
||||
t = atlas.record_turn("s", "user", content=content)
|
||||
links = atlas.links.get_links_for_turn(t)
|
||||
found_types = {lnk["link_type"] for lnk in links}
|
||||
assert found_types >= {"DEFINES", "MODIFIES", "REFERENCES", "DEPENDS_ON", "CONTRADICTS"}
|
||||
|
||||
def test_get_links_by_type(self, atlas):
|
||||
"""get_links_by_type returns only links of the requested type."""
|
||||
t = atlas.record_turn("s", "user", content="Python is a language")
|
||||
links = atlas.links.get_links_by_type("DEFINES", session_id="s")
|
||||
assert all(lnk["link_type"] == "DEFINES" for lnk in links)
|
||||
assert any(lnk["source_id"] == t for lnk in links)
|
||||
|
||||
def test_no_content_returns_empty(self, atlas):
|
||||
"""A turn with no content produces no links."""
|
||||
t = atlas.turns.append_turn("s", "tool", content=None)
|
||||
links = atlas.links.extract_and_store(t, None, "s", "tool")
|
||||
assert links == []
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestFactRecovery — the "demo" test
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestFactRecovery:
|
||||
def test_recover_fact_from_compacted_context(self, atlas):
|
||||
"""
|
||||
Demonstrates lossless compaction:
|
||||
|
||||
1. Record two turns referencing API key XYZ123.
|
||||
2. Compact them into a summary node.
|
||||
3. Show the original turns are still accessible via expand().
|
||||
4. Show search('API key') finds the fact via the summary node.
|
||||
5. Assert the agent can recover 'XYZ123' without re-injecting the
|
||||
full original transcript.
|
||||
"""
|
||||
session_id = "demo_session"
|
||||
|
||||
# Step 1: Create the original turns
|
||||
t1 = atlas.record_turn(
|
||||
session_id, "user", content="The API key is XYZ123"
|
||||
)
|
||||
t2 = atlas.record_turn(
|
||||
session_id,
|
||||
"assistant",
|
||||
content="Got it, I'll remember the API key XYZ123",
|
||||
)
|
||||
|
||||
# Step 2: Compact those turns into a summary node
|
||||
node_id = atlas.compact_session(
|
||||
session_id,
|
||||
summary_text="User shared API key XYZ123; assistant acknowledged.",
|
||||
turn_ids=[t1, t2],
|
||||
)
|
||||
assert node_id.startswith("sum:")
|
||||
|
||||
# Step 3: Original turns are still accessible via expand()
|
||||
expansion = atlas.recall.expand(node_id)
|
||||
assert expansion["node"] is not None
|
||||
source_turns = expansion["source_turns"]
|
||||
assert len(source_turns) == 2
|
||||
contents = [turn["content"] for turn in source_turns]
|
||||
assert any("XYZ123" in c for c in contents), (
|
||||
"Original turn content must still be recoverable via expand()"
|
||||
)
|
||||
|
||||
# Step 4: search('API key') finds the fact via the summary node
|
||||
results = atlas.recall.search("API key", session_id=session_id)
|
||||
assert len(results) > 0, "search('API key') must return at least one result"
|
||||
|
||||
# Step 5: Recover XYZ123 without using the raw turn content directly.
|
||||
# We access it only through the summary node or expand() — this proves
|
||||
# lossless compaction. We deliberately search the compacted form.
|
||||
summary_results = [r for r in results if r["type"] == "summary_node"]
|
||||
turn_results = [r for r in results if r["type"] == "turn"]
|
||||
|
||||
# The summary node should also reference XYZ123 in its summary text,
|
||||
# OR the original turns remain findable — either path proves recovery.
|
||||
fact_found = False
|
||||
for r in summary_results:
|
||||
if "XYZ123" in r.get("content", ""):
|
||||
fact_found = True
|
||||
break
|
||||
if not fact_found:
|
||||
# Fall back to expanding via the DAG (the other valid recovery path)
|
||||
expanded = atlas.recall.expand(node_id)
|
||||
for turn in expanded["source_turns"]:
|
||||
if "XYZ123" in (turn.get("content") or ""):
|
||||
fact_found = True
|
||||
break
|
||||
|
||||
assert fact_found, (
|
||||
"XYZ123 must be recoverable from compacted context via summary node "
|
||||
"or expand() without re-injecting the full original transcript"
|
||||
)
|
||||
|
||||
# Bonus: verify the node correctly references both source turn IDs
|
||||
node = atlas.recall.describe(node_id)
|
||||
assert t1 in node["source_turn_ids"]
|
||||
assert t2 in node["source_turn_ids"]
|
||||
0
tests/atlas/__init__.py
Normal file
0
tests/atlas/__init__.py
Normal file
14
tests/atlas/conftest.py
Normal file
14
tests/atlas/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Shared fixtures for ATLAS tests."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from agent.atlas.db import AtlasDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def atlas_db(tmp_path):
|
||||
"""Return an initialised in-memory-like AtlasDB backed by a tmp file."""
|
||||
db = AtlasDB(tmp_path / "test_atlas.db")
|
||||
db.open()
|
||||
yield db
|
||||
db.close()
|
||||
106
tests/atlas/test_dag.py
Normal file
106
tests/atlas/test_dag.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for SummaryDAGStore — acceptance criterion:
|
||||
Compaction builds retrievable summary nodes with source references
|
||||
instead of destructive transcript loss.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.dag import SummaryDAGStore
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
|
||||
|
||||
class TestSummaryDAGStore:
|
||||
def test_add_node_returns_node(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
node = store.add_node(
|
||||
"sess1",
|
||||
"Summary of first compaction.",
|
||||
source_turn_ids=["sess1:0000", "sess1:0001"],
|
||||
)
|
||||
assert node.node_id == "dag:sess1:0000"
|
||||
assert node.session_id == "sess1"
|
||||
assert node.summary_text == "Summary of first compaction."
|
||||
assert node.source_turn_ids == ["sess1:0000", "sess1:0001"]
|
||||
assert node.parent_node_id is None
|
||||
|
||||
def test_node_ids_increment_per_session(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
n0 = store.add_node("sess1", "First summary", ["sess1:0000"])
|
||||
n1 = store.add_node("sess1", "Second summary", ["sess1:0001"])
|
||||
assert n0.node_id == "dag:sess1:0000"
|
||||
assert n1.node_id == "dag:sess1:0001"
|
||||
|
||||
def test_chained_compaction_preserves_parent(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
n0 = store.add_node("sess1", "Level 1 summary", ["sess1:0000"])
|
||||
n1 = store.add_node(
|
||||
"sess1", "Level 2 summary",
|
||||
["sess1:0001", "sess1:0002"],
|
||||
parent_node_id=n0.node_id,
|
||||
)
|
||||
assert n1.parent_node_id == n0.node_id
|
||||
|
||||
def test_get_node_by_id(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
store.add_node("sess1", "My summary", ["sess1:0000"])
|
||||
node = store.get_node("dag:sess1:0000")
|
||||
assert node is not None
|
||||
assert node.summary_text == "My summary"
|
||||
assert node.source_turn_ids == ["sess1:0000"]
|
||||
|
||||
def test_get_nonexistent_returns_none(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
assert store.get_node("dag:missing:0000") is None
|
||||
|
||||
def test_get_session_nodes(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
store.add_node("sess1", "A", ["sess1:0000"])
|
||||
store.add_node("sess1", "B", ["sess1:0001"])
|
||||
nodes = store.get_session_nodes("sess1")
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0].summary_text == "A"
|
||||
assert nodes[1].summary_text == "B"
|
||||
|
||||
def test_count_session_nodes(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
store.add_node("sess1", "x", [])
|
||||
store.add_node("sess1", "y", [])
|
||||
assert store.count_session_nodes("sess1") == 2
|
||||
assert store.count_session_nodes("sessX") == 0
|
||||
|
||||
def test_get_latest_node(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
store.add_node("sess1", "Old summary", [])
|
||||
store.add_node("sess1", "New summary", [])
|
||||
latest = store.get_latest_node("sess1")
|
||||
assert latest.summary_text == "New summary"
|
||||
|
||||
def test_source_turns_not_deleted(self, atlas_db):
|
||||
"""Core lossless guarantee: source turns remain accessible after compaction."""
|
||||
turns = RawTurnStore(atlas_db)
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
|
||||
t0 = turns.append("sess1", "user", "What is the capital of France?")
|
||||
t1 = turns.append("sess1", "assistant", "The capital of France is Paris.")
|
||||
|
||||
node = dag.add_node(
|
||||
"sess1",
|
||||
"User asked about France. Answer: Paris is the capital.",
|
||||
source_turn_ids=[t0, t1],
|
||||
)
|
||||
|
||||
# Source turns still retrievable
|
||||
r0 = turns.get(t0)
|
||||
r1 = turns.get(t1)
|
||||
assert r0 is not None
|
||||
assert r1 is not None
|
||||
assert r0.content == "What is the capital of France?"
|
||||
assert r1.content == "The capital of France is Paris."
|
||||
|
||||
# DAG node references both turns
|
||||
assert t0 in node.source_turn_ids
|
||||
assert t1 in node.source_turn_ids
|
||||
|
||||
def test_node_source_count(self, atlas_db):
|
||||
store = SummaryDAGStore(atlas_db)
|
||||
node = store.add_node("sess1", "Summary", ["t1", "t2", "t3"])
|
||||
assert node.source_count() == 3
|
||||
193
tests/atlas/test_extractor.py
Normal file
193
tests/atlas/test_extractor.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Tests for TypedLinkExtractor — acceptance criterion:
|
||||
Deterministic typed-link extraction supports at least 5 initial relation
|
||||
types and has fixture-backed tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.extractor import (
|
||||
TypedLinkExtractor,
|
||||
TypedLink,
|
||||
RelationType,
|
||||
extract_links,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure extraction tests (fixture-backed, no DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLinks:
|
||||
"""Tests for the pure extract_links() function."""
|
||||
|
||||
# ---- mentions ----------------------------------------------------------
|
||||
|
||||
def test_mentions_quoted_entity(self):
|
||||
links = extract_links('We use "PostgreSQL" for storage.', "turn:1")
|
||||
mentions = [l for l in links if l.relation_type == RelationType.MENTIONS]
|
||||
assert any("PostgreSQL" in l.object for l in mentions)
|
||||
|
||||
def test_mentions_capitalized_entity(self):
|
||||
links = extract_links("The project uses Django Framework.", "turn:1")
|
||||
mentions = [l for l in links if l.relation_type == RelationType.MENTIONS]
|
||||
subjects = [l.object for l in mentions]
|
||||
assert any("Django" in s for s in subjects)
|
||||
|
||||
# ---- defines -----------------------------------------------------------
|
||||
|
||||
def test_defines_is_a(self):
|
||||
links = extract_links("A router is a network device.", "turn:1")
|
||||
defines = [l for l in links if l.relation_type == RelationType.DEFINES]
|
||||
assert len(defines) >= 1
|
||||
|
||||
def test_defines_means(self):
|
||||
links = extract_links("CRUD means Create, Read, Update, Delete.", "turn:1")
|
||||
defines = [l for l in links if l.relation_type == RelationType.DEFINES]
|
||||
assert len(defines) >= 1
|
||||
|
||||
# ---- corrects ----------------------------------------------------------
|
||||
|
||||
def test_corrects_actually(self):
|
||||
links = extract_links(
|
||||
"Actually, the server is on port 8081 not 8080.", "turn:1"
|
||||
)
|
||||
corrects = [l for l in links if l.relation_type == RelationType.CORRECTS]
|
||||
assert len(corrects) >= 1
|
||||
assert any("8081" in l.object for l in corrects)
|
||||
|
||||
def test_corrects_no_comma(self):
|
||||
links = extract_links("No, that is the wrong approach.", "turn:1")
|
||||
corrects = [l for l in links if l.relation_type == RelationType.CORRECTS]
|
||||
assert len(corrects) >= 1
|
||||
|
||||
# ---- prefers -----------------------------------------------------------
|
||||
|
||||
def test_prefers_i_prefer(self):
|
||||
links = extract_links("I prefer Python over JavaScript.", "turn:1")
|
||||
prefers = [l for l in links if l.relation_type == RelationType.PREFERS]
|
||||
assert len(prefers) >= 1
|
||||
assert any("Python" in l.object for l in prefers)
|
||||
|
||||
def test_prefers_use_instead(self):
|
||||
links = extract_links("Use pytest instead of unittest.", "turn:1")
|
||||
prefers = [l for l in links if l.relation_type == RelationType.PREFERS]
|
||||
assert len(prefers) >= 1
|
||||
|
||||
# ---- uses --------------------------------------------------------------
|
||||
|
||||
def test_uses_tool(self):
|
||||
links = extract_links("The project uses SQLAlchemy for ORM.", "turn:1")
|
||||
uses = [l for l in links if l.relation_type == RelationType.USES]
|
||||
assert len(uses) >= 1
|
||||
assert any("SQLAlchemy" in l.object for l in uses)
|
||||
|
||||
def test_uses_built_with(self):
|
||||
links = extract_links("Built with FastAPI and Pydantic.", "turn:1")
|
||||
uses = [l for l in links if l.relation_type == RelationType.USES]
|
||||
assert len(uses) >= 1
|
||||
|
||||
# ---- depends_on --------------------------------------------------------
|
||||
|
||||
def test_depends_on_requires(self):
|
||||
links = extract_links("The module requires numpy for math operations.", "turn:1")
|
||||
deps = [l for l in links if l.relation_type == RelationType.DEPENDS_ON]
|
||||
assert len(deps) >= 1
|
||||
assert any("numpy" in l.object.lower() for l in deps)
|
||||
|
||||
def test_depends_on_needs(self):
|
||||
links = extract_links("This feature needs a database connection.", "turn:1")
|
||||
deps = [l for l in links if l.relation_type == RelationType.DEPENDS_ON]
|
||||
assert len(deps) >= 1
|
||||
|
||||
# ---- part_of -----------------------------------------------------------
|
||||
|
||||
def test_part_of(self):
|
||||
links = extract_links("The auth module is part of the core library.", "turn:1")
|
||||
parts = [l for l in links if l.relation_type == RelationType.PART_OF]
|
||||
assert len(parts) >= 1
|
||||
|
||||
# ---- empty text --------------------------------------------------------
|
||||
|
||||
def test_empty_text_returns_no_links(self):
|
||||
links = extract_links("", "turn:1")
|
||||
assert links == []
|
||||
|
||||
def test_short_text_returns_few_links(self):
|
||||
links = extract_links("ok", "turn:1")
|
||||
# No meaningful relations from a 2-char message
|
||||
assert all(l.relation_type == RelationType.MENTIONS for l in links) or len(links) == 0
|
||||
|
||||
# ---- source metadata ---------------------------------------------------
|
||||
|
||||
def test_source_id_preserved(self):
|
||||
links = extract_links("I prefer vim.", "my-turn-42")
|
||||
assert all(l.source_id == "my-turn-42" for l in links)
|
||||
|
||||
def test_source_type_preserved(self):
|
||||
links = extract_links("I prefer vim.", "dag:sess:0000", source_type="dag")
|
||||
assert all(l.source_type == "dag" for l in links)
|
||||
|
||||
# ---- max_per_type cap --------------------------------------------------
|
||||
|
||||
def test_max_per_type_is_respected(self):
|
||||
# Generate text with many capitalized entities
|
||||
text = " ".join(f'The entity {chr(65+i)} is a component.' for i in range(20))
|
||||
links = extract_links(text, "turn:1", max_per_type=3)
|
||||
for rel in RelationType:
|
||||
count = sum(1 for l in links if l.relation_type == rel)
|
||||
assert count <= 3 * 2, f"{rel} exceeded cap: {count}" # some slack for variants
|
||||
|
||||
|
||||
class TestTypedLinkExtractor:
|
||||
"""Tests for the DB-backed TypedLinkExtractor."""
|
||||
|
||||
def test_extract_and_store_returns_links(self, atlas_db):
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
links = ex.extract_and_store(
|
||||
"I prefer Python over JavaScript.", "sess1:0000"
|
||||
)
|
||||
assert len(links) >= 1
|
||||
assert any(l.relation_type == RelationType.PREFERS for l in links)
|
||||
|
||||
def test_stored_links_queryable(self, atlas_db):
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
ex.extract_and_store("I prefer PostgreSQL over MySQL.", "sess1:0000")
|
||||
results = ex.query_links(relation_type=RelationType.PREFERS)
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_query_by_source_id(self, atlas_db):
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
ex.extract_and_store("Built with FastAPI.", "sess1:0000")
|
||||
ex.extract_and_store("Uses Django instead.", "sess1:0001")
|
||||
results = ex.query_links(source_id="sess1:0000")
|
||||
assert all(r["source_id"] == "sess1:0000" for r in results)
|
||||
|
||||
def test_query_by_relation_type_string(self, atlas_db):
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
ex.extract_and_store("I prefer tabs.", "sess1:0000")
|
||||
results = ex.query_links(relation_type="prefers")
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_empty_text_stores_nothing(self, atlas_db):
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
links = ex.extract_and_store("", "sess1:0000")
|
||||
assert links == []
|
||||
results = ex.query_links(source_id="sess1:0000")
|
||||
assert len(results) == 0
|
||||
|
||||
def test_five_relation_types_extractable(self, atlas_db):
|
||||
"""Verify all 5 required relation types can be extracted from fixtures."""
|
||||
ex = TypedLinkExtractor(atlas_db)
|
||||
fixtures = [
|
||||
("I prefer Python over Perl.", RelationType.PREFERS),
|
||||
('We use "Redis" for caching.', RelationType.USES),
|
||||
("The module requires aiohttp.", RelationType.DEPENDS_ON),
|
||||
("Actually, the port is 9090.", RelationType.CORRECTS),
|
||||
('"FastAPI" is mentioned as the framework.', RelationType.MENTIONS),
|
||||
]
|
||||
for i, (text, expected_rel) in enumerate(fixtures):
|
||||
links = ex.extract_and_store(text, f"sess1:{i:04d}")
|
||||
found = [l for l in links if l.relation_type == expected_rel]
|
||||
assert found, (
|
||||
f"Expected relation {expected_rel.value!r} not found in:\n {text!r}\n"
|
||||
f" Got: {[l.relation_type.value for l in links]}"
|
||||
)
|
||||
153
tests/atlas/test_recall.py
Normal file
153
tests/atlas/test_recall.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for RecallEngine — acceptance criterion:
|
||||
At least 3 explicit recall operations exist (search / describe / expand).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.dag import SummaryDAGStore
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
from agent.atlas.stores import AtlasStores
|
||||
from agent.atlas.recall import RecallEngine
|
||||
|
||||
|
||||
class TestRecallSearch:
|
||||
"""Tests for the atlas_search recall operation."""
|
||||
|
||||
def test_search_finds_turn_content(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turns.append("sess1", "user", "The capital of France is Paris.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("capital France")
|
||||
assert "Paris" in result or "capital" in result.lower()
|
||||
|
||||
def test_search_finds_durable_memory(self, atlas_db):
|
||||
stores = AtlasStores(atlas_db)
|
||||
stores.durable.add("User prefers functional programming", category="user_pref")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("functional programming")
|
||||
assert "functional" in result.lower() or "programming" in result.lower()
|
||||
|
||||
def test_search_no_results_returns_message(self, atlas_db):
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("xyzzy_nonexistent_query_abc")
|
||||
assert "No results" in result
|
||||
|
||||
def test_search_with_doc_type_filter(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turns.append("sess1", "user", "Search target text for filtering.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("Search target", doc_types=["raw_turn"])
|
||||
assert "raw_turn" in result or "Search target" in result or "No results" in result
|
||||
|
||||
def test_search_returns_formatted_string(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turns.append("sess1", "user", "The Eiffel Tower is in Paris.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("Eiffel Tower")
|
||||
# Should be a human-readable string, not raw JSON
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
def test_search_respects_limit(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
for i in range(20):
|
||||
turns.append(f"sess{i}", "user", f"Common keyword present in turn {i}.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.search("keyword present", limit=3)
|
||||
# Count result entries by counting the "[N]" markers
|
||||
import re
|
||||
hits = re.findall(r"\[\d+\]", result)
|
||||
assert len(hits) <= 3
|
||||
|
||||
|
||||
class TestRecallDescribe:
|
||||
"""Tests for the atlas_describe recall operation."""
|
||||
|
||||
def test_describe_existing_turn(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turn_id = turns.append("sess1", "user", "Tell me about recursion.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe(turn_id)
|
||||
assert "sess1:0000" in result
|
||||
assert "user" in result
|
||||
assert "recursion" in result
|
||||
|
||||
def test_describe_dag_node(self, atlas_db):
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
node = dag.add_node("sess1", "Summary text here.", ["sess1:0000"])
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe(node.node_id)
|
||||
assert node.node_id in result
|
||||
assert "Summary text" in result
|
||||
|
||||
def test_describe_missing_turn(self, atlas_db):
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe("missing_sess:9999")
|
||||
assert "No turn found" in result or "No" in result
|
||||
|
||||
def test_describe_missing_dag_node(self, atlas_db):
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe("dag:missing:9999")
|
||||
assert "No DAG node" in result or "No" in result
|
||||
|
||||
def test_describe_includes_role(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turn_id = turns.append("sess1", "assistant", "Here is the answer.")
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe(turn_id)
|
||||
assert "assistant" in result
|
||||
|
||||
def test_describe_tool_turn_shows_tool_name(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turn_id = turns.append(
|
||||
"sess1", "tool", '{"result": "success"}',
|
||||
tool_name="bash", tool_call_id="call_xyz"
|
||||
)
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe(turn_id)
|
||||
assert "bash" in result
|
||||
|
||||
|
||||
class TestRecallExpand:
|
||||
"""Tests for the atlas_expand recall operation."""
|
||||
|
||||
def test_expand_shows_summary_and_source_turns(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
t0 = turns.append("sess1", "user", "What is the boiling point of water?")
|
||||
t1 = turns.append("sess1", "assistant", "Water boils at 100°C at sea level.")
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
node = dag.add_node(
|
||||
"sess1",
|
||||
"User asked about water boiling point. Answer: 100°C.",
|
||||
source_turn_ids=[t0, t1],
|
||||
)
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.expand(node.node_id)
|
||||
# Summary is present
|
||||
assert "100°C" in result or "boiling" in result.lower()
|
||||
# Source turns are shown
|
||||
assert t0 in result or "user" in result.lower()
|
||||
|
||||
def test_expand_missing_node(self, atlas_db):
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.expand("dag:missing:9999")
|
||||
assert "No DAG node" in result
|
||||
|
||||
def test_expand_includes_node_id(self, atlas_db):
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
node = dag.add_node("sess1", "Compacted summary.", [])
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.expand(node.node_id)
|
||||
assert node.node_id in result
|
||||
|
||||
def test_expand_respects_max_turns(self, atlas_db):
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turn_ids = []
|
||||
for i in range(10):
|
||||
tid = turns.append("sess1", "user", f"Turn {i} content.")
|
||||
turn_ids.append(tid)
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
node = dag.add_node("sess1", "Big compaction.", turn_ids)
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.expand(node.node_id, max_turns=3)
|
||||
# Should mention it's truncated
|
||||
assert "more" in result.lower() or len(turn_ids) - 3 >= 0
|
||||
189
tests/atlas/test_recovery.py
Normal file
189
tests/atlas/test_recovery.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Integration test proving fact recovery from compacted context.
|
||||
|
||||
Acceptance criterion:
|
||||
A test or demo proves the agent can recover a fact from compacted context
|
||||
without re-injecting the full original transcript.
|
||||
|
||||
Scenario:
|
||||
1. A conversation establishes a fact ("the API key is 'abc-secret'").
|
||||
2. Context compaction fires — a DAG summary node is created from those turns.
|
||||
3. The original turns are not in the active window any more.
|
||||
4. The agent calls atlas_expand(node_id) and recovers the fact from the
|
||||
summary node and/or its source turns — without using the original
|
||||
transcript messages.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
from agent.atlas.dag import SummaryDAGStore
|
||||
from agent.atlas.stores import AtlasStores
|
||||
from agent.atlas.extractor import TypedLinkExtractor, RelationType
|
||||
from agent.atlas.recall import RecallEngine
|
||||
|
||||
|
||||
def _simulate_compaction(db, session_id: str, messages: list[dict]) -> str:
|
||||
"""Simulate what AtlasMemoryProvider.on_pre_compress does.
|
||||
|
||||
Persists all messages as raw turns, then creates a summary DAG node
|
||||
from the assistant turns. Returns the new node_id.
|
||||
"""
|
||||
turns = RawTurnStore(db)
|
||||
dag = SummaryDAGStore(db)
|
||||
|
||||
source_turn_ids = []
|
||||
summary_parts = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
turn_id = turns.append(session_id, role, content)
|
||||
source_turn_ids.append(turn_id)
|
||||
if role == "assistant":
|
||||
summary_parts.append(content[:500])
|
||||
|
||||
summary_text = "\n".join(summary_parts) or "Compacted session."
|
||||
node = dag.add_node(session_id, summary_text, source_turn_ids)
|
||||
return node.node_id
|
||||
|
||||
|
||||
class TestFactRecovery:
|
||||
def test_recover_fact_via_search(self, atlas_db):
|
||||
"""Agent can search for a fact from compacted history."""
|
||||
session_id = "sess_recovery_search"
|
||||
messages = [
|
||||
{"role": "user", "content": "Set the staging API key to 'abc-secret-123'."},
|
||||
{"role": "assistant", "content": "Understood. Staging API key is now 'abc-secret-123'."},
|
||||
{"role": "user", "content": "Good. Now let's work on the deploy script."},
|
||||
{"role": "assistant", "content": "Sure, starting on the deploy script."},
|
||||
]
|
||||
|
||||
# Simulate compaction: turns persisted, DAG node created, active messages dropped
|
||||
_simulate_compaction(atlas_db, session_id, messages)
|
||||
|
||||
# Active context is now empty — agent must recall from ATLAS
|
||||
engine = RecallEngine(atlas_db)
|
||||
|
||||
# Agent searches for the API key
|
||||
result = engine.search("staging API key")
|
||||
assert "abc-secret-123" in result, (
|
||||
f"Fact not recovered via search. Got:\n{result}"
|
||||
)
|
||||
|
||||
def test_recover_fact_via_expand(self, atlas_db):
|
||||
"""Agent expands a DAG node to recover a specific fact."""
|
||||
session_id = "sess_recovery_expand"
|
||||
messages = [
|
||||
{"role": "user", "content": "The database host is 'db.internal.example.com'."},
|
||||
{"role": "assistant", "content": "Noted. Database host: db.internal.example.com."},
|
||||
]
|
||||
|
||||
node_id = _simulate_compaction(atlas_db, session_id, messages)
|
||||
|
||||
# Drop active context — only DAG node + raw turns remain in ATLAS
|
||||
engine = RecallEngine(atlas_db)
|
||||
|
||||
# Expand the DAG node to see source turns
|
||||
result = engine.expand(node_id)
|
||||
assert "db.internal.example.com" in result, (
|
||||
f"Fact not recovered via expand. Got:\n{result}"
|
||||
)
|
||||
|
||||
def test_recover_fact_via_describe(self, atlas_db):
|
||||
"""Agent describes a specific turn by lineage ID to recover a fact."""
|
||||
session_id = "sess_recovery_describe"
|
||||
turns = RawTurnStore(atlas_db)
|
||||
turn_id = turns.append(
|
||||
session_id, "user",
|
||||
"The deployment target is production cluster 'prod-us-east-1'."
|
||||
)
|
||||
|
||||
engine = RecallEngine(atlas_db)
|
||||
result = engine.describe(turn_id)
|
||||
assert "prod-us-east-1" in result, (
|
||||
f"Fact not recovered via describe. Got:\n{result}"
|
||||
)
|
||||
|
||||
def test_full_compaction_cycle(self, atlas_db):
|
||||
"""End-to-end: turns persisted → compaction → fact recovery.
|
||||
|
||||
This test mirrors the complete lifecycle:
|
||||
1. Conversation produces turns with important facts.
|
||||
2. Compaction fires and creates a DAG summary node.
|
||||
3. Active messages are dropped (simulated by not passing them to recall).
|
||||
4. Agent recovers the fact using only the ATLAS subsystem.
|
||||
"""
|
||||
session_id = "sess_full_cycle"
|
||||
|
||||
# Phase 1: conversation with important fact
|
||||
important_fact = "The encryption key rotation interval is 30 days."
|
||||
messages = [
|
||||
{"role": "user", "content": f"Configure the system: {important_fact}"},
|
||||
{"role": "assistant", "content": f"Configured. {important_fact}"},
|
||||
{"role": "user", "content": "What else should we set up?"},
|
||||
{"role": "assistant", "content": "Let's configure the backup schedule next."},
|
||||
]
|
||||
|
||||
# Phase 2: compaction
|
||||
node_id = _simulate_compaction(atlas_db, session_id, messages)
|
||||
|
||||
# Phase 3: verify DAG node was created with source references
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
node = dag.get_node(node_id)
|
||||
assert node is not None
|
||||
assert node.source_count() == len(messages)
|
||||
|
||||
# Phase 4: active context dropped — recover fact without re-injecting transcript
|
||||
engine = RecallEngine(atlas_db)
|
||||
|
||||
# Option A: search
|
||||
search_result = engine.search("encryption key rotation")
|
||||
# Option B: expand
|
||||
expand_result = engine.expand(node_id)
|
||||
|
||||
assert "30 days" in search_result or "30 days" in expand_result, (
|
||||
"Fact '30 days' not found via either search or expand.\n"
|
||||
f"Search result:\n{search_result}\n\nExpand result:\n{expand_result}"
|
||||
)
|
||||
|
||||
def test_typed_links_survive_compaction(self, atlas_db):
|
||||
"""Typed links extracted at write time remain queryable after compaction."""
|
||||
session_id = "sess_typed_links"
|
||||
turns = RawTurnStore(atlas_db)
|
||||
extractor = TypedLinkExtractor(atlas_db)
|
||||
|
||||
# Persist a turn with an extractable preference
|
||||
turn_id = turns.append(
|
||||
session_id, "user", "I prefer PostgreSQL over MySQL for this project."
|
||||
)
|
||||
links = extractor.extract_and_store(
|
||||
"I prefer PostgreSQL over MySQL for this project.", turn_id
|
||||
)
|
||||
|
||||
# Simulate compaction (DAG node created, active messages dropped)
|
||||
dag = SummaryDAGStore(atlas_db)
|
||||
dag.add_node(session_id, "User prefers PostgreSQL.", [turn_id])
|
||||
|
||||
# Query typed links — should still be present
|
||||
pref_links = extractor.query_links(relation_type=RelationType.PREFERS)
|
||||
assert any("PostgreSQL" in lnk["object"] for lnk in pref_links), (
|
||||
f"Typed link not found. Links: {pref_links}"
|
||||
)
|
||||
|
||||
def test_three_stores_populated_independently(self, atlas_db):
|
||||
"""Facts written to different stores are independently recoverable."""
|
||||
stores = AtlasStores(atlas_db)
|
||||
|
||||
stores.write("world: The speed of light is 299,792,458 m/s")
|
||||
stores.write("User prefers dark mode", category="ui_pref")
|
||||
stores.write("session: Current focus is auth module", session_id="sess1", session_key="focus")
|
||||
|
||||
# Each store is queryable independently
|
||||
world_results = stores.world.search("speed of light")
|
||||
assert len(world_results) >= 1
|
||||
|
||||
durable_results = stores.durable.search("dark mode")
|
||||
assert len(durable_results) >= 1
|
||||
|
||||
session_val = stores.session.get("sess1", "focus")
|
||||
assert session_val is not None
|
||||
assert "auth" in session_val.lower()
|
||||
159
tests/atlas/test_stores.py
Normal file
159
tests/atlas/test_stores.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Tests for three-store routing — acceptance criterion:
|
||||
Writes route to explicit stores (world knowledge vs durable memory vs
|
||||
session state), not a single mixed bucket.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.stores import (
|
||||
AtlasStores,
|
||||
WorldKnowledgeStore,
|
||||
DurableMemoryStore,
|
||||
SessionStateStore,
|
||||
StoreTarget,
|
||||
_classify_target,
|
||||
)
|
||||
|
||||
|
||||
class TestStoreRouting:
|
||||
"""Test deterministic routing logic."""
|
||||
|
||||
def test_world_prefix_routes_to_world(self):
|
||||
assert _classify_target("world: Paris is the capital") == StoreTarget.WORLD
|
||||
|
||||
def test_world_category_routes_to_world(self):
|
||||
assert _classify_target("Paris is the capital", "world:geography") == StoreTarget.WORLD
|
||||
|
||||
def test_session_prefix_routes_to_session(self):
|
||||
assert _classify_target("session: current plan is X") == StoreTarget.SESSION
|
||||
|
||||
def test_session_category_routes_to_session(self):
|
||||
assert _classify_target("Plan step 1", "session:plan") == StoreTarget.SESSION
|
||||
|
||||
def test_default_routes_to_durable(self):
|
||||
assert _classify_target("I prefer Python over Java") == StoreTarget.DURABLE
|
||||
assert _classify_target("user likes dark mode") == StoreTarget.DURABLE
|
||||
|
||||
|
||||
class TestWorldKnowledgeStore:
|
||||
def test_add_and_search(self, atlas_db):
|
||||
store = WorldKnowledgeStore(atlas_db)
|
||||
store.add("Paris is the capital of France", tags="geography", trust=0.95)
|
||||
results = store.search("capital of France")
|
||||
assert len(results) >= 1
|
||||
assert "Paris" in results[0]["content"]
|
||||
|
||||
def test_world_prefix_stripped_on_add(self, atlas_db):
|
||||
store = WorldKnowledgeStore(atlas_db)
|
||||
store.add("world: Rome is in Italy", tags="geography")
|
||||
results = store.search("Rome")
|
||||
assert any("Rome" in r["content"] for r in results)
|
||||
# prefix should be stripped
|
||||
assert all(not r["content"].startswith("world:") for r in results)
|
||||
|
||||
def test_count(self, atlas_db):
|
||||
store = WorldKnowledgeStore(atlas_db)
|
||||
store.add("Fact A")
|
||||
store.add("Fact B")
|
||||
assert store.count() == 2
|
||||
|
||||
def test_list_all(self, atlas_db):
|
||||
store = WorldKnowledgeStore(atlas_db)
|
||||
store.add("Alpha")
|
||||
store.add("Beta")
|
||||
items = store.list_all()
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
class TestDurableMemoryStore:
|
||||
def test_add_and_search(self, atlas_db):
|
||||
store = DurableMemoryStore(atlas_db)
|
||||
store.add("User prefers Python over JavaScript", category="user_pref")
|
||||
results = store.search("Python")
|
||||
assert len(results) >= 1
|
||||
assert "Python" in results[0]["content"]
|
||||
|
||||
def test_category_filter(self, atlas_db):
|
||||
store = DurableMemoryStore(atlas_db)
|
||||
store.add("Likes dark mode", category="ui_pref")
|
||||
store.add("Uses VSCode", category="tools")
|
||||
ui_results = store.list_by_category("ui_pref")
|
||||
assert len(ui_results) == 1
|
||||
assert "dark mode" in ui_results[0]["content"]
|
||||
|
||||
def test_count(self, atlas_db):
|
||||
store = DurableMemoryStore(atlas_db)
|
||||
store.add("X")
|
||||
store.add("Y")
|
||||
assert store.count() == 2
|
||||
|
||||
def test_trust_ordering(self, atlas_db):
|
||||
store = DurableMemoryStore(atlas_db)
|
||||
store.add("Low trust fact", trust=0.1)
|
||||
store.add("High trust fact", trust=0.9)
|
||||
results = store.search("fact")
|
||||
assert results[0]["trust"] >= results[-1]["trust"]
|
||||
|
||||
|
||||
class TestSessionStateStore:
|
||||
def test_set_and_get(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
store.set("sess1", "current_task", "Implement auth module")
|
||||
val = store.get("sess1", "current_task")
|
||||
assert val == "Implement auth module"
|
||||
|
||||
def test_set_overwrites_existing_key(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
store.set("sess1", "plan", "Step 1")
|
||||
store.set("sess1", "plan", "Step 2")
|
||||
assert store.get("sess1", "plan") == "Step 2"
|
||||
|
||||
def test_sessions_are_isolated(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
store.set("sessA", "key", "A's value")
|
||||
store.set("sessB", "key", "B's value")
|
||||
assert store.get("sessA", "key") == "A's value"
|
||||
assert store.get("sessB", "key") == "B's value"
|
||||
|
||||
def test_get_missing_returns_none(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
assert store.get("sess1", "missing_key") is None
|
||||
|
||||
def test_list_session(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
store.set("sess1", "k1", "v1")
|
||||
store.set("sess1", "k2", "v2")
|
||||
items = store.list_session("sess1")
|
||||
assert len(items) == 2
|
||||
|
||||
def test_clear_session(self, atlas_db):
|
||||
store = SessionStateStore(atlas_db)
|
||||
store.set("sess1", "k1", "v1")
|
||||
store.set("sess1", "k2", "v2")
|
||||
deleted = store.clear_session("sess1")
|
||||
assert deleted == 2
|
||||
assert store.count_session("sess1") == 0
|
||||
|
||||
|
||||
class TestAtlasStores:
|
||||
def test_write_world_content(self, atlas_db):
|
||||
stores = AtlasStores(atlas_db)
|
||||
target, _ = stores.write("world: The Earth orbits the Sun")
|
||||
assert target == StoreTarget.WORLD
|
||||
|
||||
def test_write_durable_content(self, atlas_db):
|
||||
stores = AtlasStores(atlas_db)
|
||||
target, _ = stores.write("User prefers tabs over spaces")
|
||||
assert target == StoreTarget.DURABLE
|
||||
|
||||
def test_write_session_content(self, atlas_db):
|
||||
stores = AtlasStores(atlas_db)
|
||||
target, _ = stores.write("session: Currently debugging auth", session_id="sess1")
|
||||
assert target == StoreTarget.SESSION
|
||||
|
||||
def test_search_all_returns_dict(self, atlas_db):
|
||||
stores = AtlasStores(atlas_db)
|
||||
stores.write("world: Water is H2O")
|
||||
stores.write("I prefer functional programming")
|
||||
results = stores.search_all("programming")
|
||||
assert "world" in results
|
||||
assert "durable" in results
|
||||
96
tests/atlas/test_turns.py
Normal file
96
tests/atlas/test_turns.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Tests for RawTurnStore — acceptance criterion:
|
||||
Every user/assistant/tool turn is persisted with stable lineage identifiers.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from agent.atlas.turns import RawTurnStore
|
||||
|
||||
|
||||
class TestRawTurnStore:
|
||||
def test_append_returns_stable_lineage_id(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
turn_id = store.append("sess1", "user", "Hello, world!")
|
||||
assert turn_id == "sess1:0000"
|
||||
|
||||
def test_turn_index_increments_per_session(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
t0 = store.append("sess1", "user", "First turn")
|
||||
t1 = store.append("sess1", "assistant", "Second turn")
|
||||
t2 = store.append("sess1", "user", "Third turn")
|
||||
assert t0 == "sess1:0000"
|
||||
assert t1 == "sess1:0001"
|
||||
assert t2 == "sess1:0002"
|
||||
|
||||
def test_sessions_are_independent(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
a = store.append("sessA", "user", "A first")
|
||||
b = store.append("sessB", "user", "B first")
|
||||
assert a == "sessA:0000"
|
||||
assert b == "sessB:0000"
|
||||
|
||||
def test_get_returns_full_record(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
turn_id = store.append("sess1", "user", "What is the capital of France?")
|
||||
rec = store.get(turn_id)
|
||||
assert rec is not None
|
||||
assert rec.turn_id == "sess1:0000"
|
||||
assert rec.role == "user"
|
||||
assert rec.content == "What is the capital of France?"
|
||||
assert rec.session_id == "sess1"
|
||||
assert rec.turn_index == 0
|
||||
|
||||
def test_get_nonexistent_returns_none(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
assert store.get("missing:0000") is None
|
||||
|
||||
def test_get_session_turns_oldest_first(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
store.append("sess1", "user", "Turn A")
|
||||
store.append("sess1", "assistant", "Turn B")
|
||||
store.append("sess1", "user", "Turn C")
|
||||
turns = store.get_session_turns("sess1")
|
||||
assert len(turns) == 3
|
||||
assert turns[0].content == "Turn A"
|
||||
assert turns[2].content == "Turn C"
|
||||
|
||||
def test_count_session_turns(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
store.append("sess1", "user", "x")
|
||||
store.append("sess1", "assistant", "y")
|
||||
assert store.count_session_turns("sess1") == 2
|
||||
assert store.count_session_turns("sessX") == 0
|
||||
|
||||
def test_tool_turn_fields(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
turn_id = store.append(
|
||||
"sess1", "tool", '{"result": "ok"}',
|
||||
tool_name="bash", tool_call_id="call_abc"
|
||||
)
|
||||
rec = store.get(turn_id)
|
||||
assert rec.tool_name == "bash"
|
||||
assert rec.tool_call_id == "call_abc"
|
||||
|
||||
def test_get_turns_by_ids(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
t0 = store.append("sess1", "user", "First")
|
||||
t1 = store.append("sess1", "assistant", "Second")
|
||||
results = store.get_turns_by_ids([t1, t0])
|
||||
# Order preserved by input ID list
|
||||
assert results[0].turn_id == t1
|
||||
assert results[1].turn_id == t0
|
||||
|
||||
def test_append_is_idempotent_on_conflict(self, atlas_db):
|
||||
"""INSERT OR IGNORE should not raise on duplicate turn_id."""
|
||||
store = RawTurnStore(atlas_db)
|
||||
t0 = store.append("sess1", "user", "First")
|
||||
# Force same index by inserting manually
|
||||
count_before = store.count_session_turns("sess1")
|
||||
# append a second turn (gets index 1, not a conflict)
|
||||
store.append("sess1", "assistant", "Second")
|
||||
assert store.count_session_turns("sess1") == count_before + 1
|
||||
|
||||
def test_lineage_label(self, atlas_db):
|
||||
store = RawTurnStore(atlas_db)
|
||||
turn_id = store.append("sess1", "user", "hi")
|
||||
rec = store.get(turn_id)
|
||||
assert rec.lineage_label() == "user@sess1:0000"
|
||||
715
tests/test_lossless_context.py
Normal file
715
tests/test_lossless_context.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""Tests for the lossless context + memory subsystem.
|
||||
|
||||
Covers:
|
||||
- TurnRecord: lineage IDs, immutability, SHA-256
|
||||
- SessionTurnStore: append-only persistence, load, search
|
||||
- SummaryDAG: DAG node persistence, source references, expand
|
||||
- LinkExtractor: all 5+ relation types, fixture-backed
|
||||
- StoreRouter: three-store routing
|
||||
- RecallEngine: search / describe / expand (the key recovery proof)
|
||||
- Recovery test: fact survives compaction and is recoverable via expand()
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from agent.lossless_context import (
|
||||
TurnRecord,
|
||||
SessionTurnStore,
|
||||
SummaryDAG,
|
||||
SummaryNode,
|
||||
RelationLink,
|
||||
RelationLinkStore,
|
||||
RelationType,
|
||||
StoreTier,
|
||||
LinkExtractor,
|
||||
StoreRouter,
|
||||
RecallEngine,
|
||||
ingest_turn,
|
||||
compact_turns_to_dag,
|
||||
create_lossless_context,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_hermes_home(tmp_path):
|
||||
"""Return a temporary directory that acts as HERMES_HOME."""
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_id():
|
||||
return "test-session-abc123"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def turn_store(session_id, tmp_hermes_home):
|
||||
return SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def summary_dag(session_id, tmp_hermes_home):
|
||||
return SummaryDAG(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def link_store(session_id, tmp_hermes_home):
|
||||
return RelationLinkStore(tmp_hermes_home / "sessions" / session_id / "links.jsonl")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store_router(tmp_hermes_home):
|
||||
return StoreRouter(hermes_home=tmp_hermes_home)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recall_engine(turn_store, summary_dag, link_store, store_router):
|
||||
return RecallEngine(turn_store, summary_dag, link_store, store_router)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def link_extractor():
|
||||
return LinkExtractor()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TurnRecord tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTurnRecord:
|
||||
def test_lineage_id_format(self):
|
||||
rec = TurnRecord.create(session_id="sess1", seq=7, role="user", content="Hello")
|
||||
assert rec.lineage_id == "sess1:7"
|
||||
|
||||
def test_content_sha256_computed(self):
|
||||
rec = TurnRecord.create(session_id="s", seq=0, role="user", content="test content")
|
||||
import hashlib
|
||||
expected = hashlib.sha256(b"test content").hexdigest()
|
||||
assert rec.content_sha256 == expected
|
||||
|
||||
def test_roundtrip_dict(self):
|
||||
rec = TurnRecord.create(session_id="s", seq=3, role="assistant", content="reply")
|
||||
restored = TurnRecord.from_dict(rec.to_dict())
|
||||
assert restored.lineage_id == rec.lineage_id
|
||||
assert restored.content == rec.content
|
||||
assert restored.role == rec.role
|
||||
assert restored.seq == rec.seq
|
||||
|
||||
def test_stable_lineage_id_across_roundtrip(self):
|
||||
"""lineage_id must survive serialization unchanged."""
|
||||
rec = TurnRecord.create(session_id="stable", seq=42, role="tool", content="output")
|
||||
d = rec.to_dict()
|
||||
restored = TurnRecord.from_dict(d)
|
||||
assert restored.lineage_id == "stable:42"
|
||||
|
||||
def test_tool_call_fields(self):
|
||||
rec = TurnRecord.create(
|
||||
session_id="s", seq=0, role="tool", content="result",
|
||||
tool_name="read_file", tool_call_id="call-001"
|
||||
)
|
||||
assert rec.tool_name == "read_file"
|
||||
assert rec.tool_call_id == "call-001"
|
||||
d = rec.to_dict()
|
||||
restored = TurnRecord.from_dict(d)
|
||||
assert restored.tool_name == "read_file"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionTurnStore tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionTurnStore:
|
||||
def test_append_returns_record_with_lineage(self, turn_store, session_id):
|
||||
rec = turn_store.append(role="user", content="Hi there")
|
||||
assert rec.lineage_id == f"{session_id}:0"
|
||||
assert rec.role == "user"
|
||||
assert rec.content == "Hi there"
|
||||
|
||||
def test_sequence_increments(self, turn_store, session_id):
|
||||
r1 = turn_store.append(role="user", content="first")
|
||||
r2 = turn_store.append(role="assistant", content="second")
|
||||
r3 = turn_store.append(role="tool", content="result")
|
||||
assert r1.seq == 0
|
||||
assert r2.seq == 1
|
||||
assert r3.seq == 2
|
||||
|
||||
def test_persistence_across_reload(self, session_id, tmp_hermes_home):
|
||||
store1 = SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
store1.append(role="user", content="persisted message")
|
||||
store1.append(role="assistant", content="persisted reply")
|
||||
|
||||
# Reload from disk
|
||||
store2 = SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
records = store2.load_all()
|
||||
assert len(records) == 2
|
||||
assert records[0].content == "persisted message"
|
||||
assert records[1].content == "persisted reply"
|
||||
|
||||
def test_load_preserves_order(self, turn_store):
|
||||
for i in range(5):
|
||||
turn_store.append(role="user", content=f"message {i}")
|
||||
records = turn_store.load_all()
|
||||
assert [r.seq for r in records] == [0, 1, 2, 3, 4]
|
||||
|
||||
def test_search_case_insensitive(self, turn_store):
|
||||
turn_store.append(role="user", content="I prefer Python for backend work")
|
||||
turn_store.append(role="assistant", content="Understood")
|
||||
turn_store.append(role="user", content="Unrelated message")
|
||||
|
||||
results = turn_store.search("PYTHON")
|
||||
assert len(results) == 1
|
||||
assert "Python" in results[0].content
|
||||
|
||||
def test_get_by_id(self, turn_store, session_id):
|
||||
turn_store.append(role="user", content="first")
|
||||
r = turn_store.append(role="user", content="target")
|
||||
turn_store.append(role="user", content="after")
|
||||
|
||||
found = turn_store.get_by_id(r.lineage_id)
|
||||
assert found is not None
|
||||
assert found.content == "target"
|
||||
|
||||
def test_get_by_id_missing(self, turn_store):
|
||||
assert turn_store.get_by_id("nonexistent:999") is None
|
||||
|
||||
def test_turns_are_never_deleted(self, session_id, tmp_hermes_home):
|
||||
"""Appending more turns never removes old ones (lossless)."""
|
||||
store = SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
for i in range(10):
|
||||
store.append(role="user", content=f"turn {i}")
|
||||
records = store.load_all()
|
||||
assert len(records) == 10
|
||||
|
||||
def test_seq_resumes_after_reload(self, session_id, tmp_hermes_home):
|
||||
"""New store instance continues from the last sequence number."""
|
||||
store1 = SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
store1.append(role="user", content="a")
|
||||
store1.append(role="user", content="b")
|
||||
|
||||
store2 = SessionTurnStore(session_id, hermes_home=tmp_hermes_home)
|
||||
r = store2.append(role="user", content="c")
|
||||
assert r.seq == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SummaryDAG tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSummaryDAG:
|
||||
def test_add_node_returns_node_with_id(self, summary_dag, session_id):
|
||||
node = summary_dag.add_node(
|
||||
summary_text="Worked on issue #123",
|
||||
source_turn_ids=["sess:0", "sess:1", "sess:2"],
|
||||
)
|
||||
assert node.node_id == f"{session_id}:summary:0"
|
||||
assert node.session_id == session_id
|
||||
assert node.source_turn_ids == ["sess:0", "sess:1", "sess:2"]
|
||||
|
||||
def test_sequential_node_ids(self, summary_dag, session_id):
|
||||
n1 = summary_dag.add_node("first summary", ["s:0"])
|
||||
n2 = summary_dag.add_node("second summary", ["s:1"])
|
||||
assert n1.node_id.endswith(":summary:0")
|
||||
assert n2.node_id.endswith(":summary:1")
|
||||
|
||||
def test_persistence_across_reload(self, session_id, tmp_hermes_home):
|
||||
dag1 = SummaryDAG(session_id, hermes_home=tmp_hermes_home)
|
||||
dag1.add_node("first", ["t:0"])
|
||||
dag1.add_node("second", ["t:1"])
|
||||
|
||||
dag2 = SummaryDAG(session_id, hermes_home=tmp_hermes_home)
|
||||
nodes = dag2.load_all()
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0].summary_text == "first"
|
||||
|
||||
def test_parent_node_id(self, summary_dag):
|
||||
n1 = summary_dag.add_node("initial", ["t:0"])
|
||||
n2 = summary_dag.add_node("update", ["t:1"], parent_node_id=n1.node_id)
|
||||
assert n2.parent_node_id == n1.node_id
|
||||
|
||||
def test_get_by_id(self, summary_dag):
|
||||
n = summary_dag.add_node("target summary", ["t:5"])
|
||||
found = summary_dag.get_by_id(n.node_id)
|
||||
assert found is not None
|
||||
assert found.summary_text == "target summary"
|
||||
|
||||
def test_get_latest(self, summary_dag):
|
||||
summary_dag.add_node("older", ["t:0"])
|
||||
latest = summary_dag.add_node("newest", ["t:1"])
|
||||
assert summary_dag.get_latest().node_id == latest.node_id
|
||||
|
||||
def test_search(self, summary_dag):
|
||||
summary_dag.add_node("Fixed a bug in auth module", ["t:0"])
|
||||
summary_dag.add_node("Refactored the database layer", ["t:1"])
|
||||
results = summary_dag.search("auth")
|
||||
assert len(results) == 1
|
||||
assert "auth" in results[0].summary_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LinkExtractor tests — fixture-backed, deterministic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLinkExtractor:
|
||||
"""Fixture-backed tests for all 5+ relation types."""
|
||||
|
||||
FIXTURES = [
|
||||
# (input_text, expected_relation, partial_subject, partial_object)
|
||||
|
||||
# PREFERS
|
||||
("I prefer Python over JavaScript for backend work.",
|
||||
RelationType.PREFERS, "user", "Python"),
|
||||
|
||||
("My preferred editor is Neovim.",
|
||||
RelationType.PREFERS, "user", "Neovim"),
|
||||
|
||||
# CORRECTS
|
||||
("Actually, the port is 8081 not 8080.",
|
||||
RelationType.CORRECTS, "user", "the port is 8081"),
|
||||
|
||||
("I meant the config file is at /etc/app/config.yaml.",
|
||||
RelationType.CORRECTS, "user", "the config"),
|
||||
|
||||
# USES
|
||||
("The project uses PostgreSQL for data storage.",
|
||||
RelationType.USES, "project", "PostgreSQL"),
|
||||
|
||||
("Using Docker for containerization.",
|
||||
RelationType.USES, "project", "Docker"),
|
||||
|
||||
# LOCATED_AT
|
||||
("The config is at /etc/hermes/config.yaml.",
|
||||
RelationType.LOCATED_AT, "config", "/etc/hermes/config.yaml"),
|
||||
|
||||
# DEPENDS_ON
|
||||
("hermes-agent requires openai>=2.0.",
|
||||
RelationType.DEPENDS_ON, "hermes-agent", "openai"),
|
||||
|
||||
# CONFIGURES
|
||||
("debug is set to true.",
|
||||
RelationType.CONFIGURES, "debug", "true"),
|
||||
|
||||
("Set log_level to DEBUG.",
|
||||
RelationType.CONFIGURES, "log_level", "DEBUG"),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("text,expected_rel,partial_subj,partial_obj", FIXTURES)
|
||||
def test_extraction(self, link_extractor, text, expected_rel, partial_subj, partial_obj):
|
||||
links = link_extractor.extract(text, source_turn_id="test:0")
|
||||
matching = [
|
||||
lk for lk in links
|
||||
if lk.relation == expected_rel.value
|
||||
and (partial_subj.lower() in lk.subject.lower() or
|
||||
partial_subj.lower() in lk.object_.lower())
|
||||
]
|
||||
assert len(matching) >= 1, (
|
||||
f"Expected at least one {expected_rel.value} link "
|
||||
f"with subject/object containing '{partial_subj}' or '{partial_obj}'. "
|
||||
f"Got links: {[(lk.relation, lk.subject, lk.object_) for lk in links]}"
|
||||
)
|
||||
|
||||
def test_source_turn_id_preserved(self, link_extractor):
|
||||
links = link_extractor.extract("I prefer Go over Python.", source_turn_id="session:7")
|
||||
assert all(lk.source_turn_id == "session:7" for lk in links)
|
||||
|
||||
def test_empty_text_returns_empty(self, link_extractor):
|
||||
assert link_extractor.extract("", source_turn_id="s:0") == []
|
||||
|
||||
def test_short_text_returns_empty(self, link_extractor):
|
||||
assert link_extractor.extract("Hi", source_turn_id="s:0") == []
|
||||
|
||||
def test_no_duplicates_for_same_pattern(self, link_extractor):
|
||||
links = link_extractor.extract(
|
||||
"I prefer Python. I prefer Python.", source_turn_id="s:0"
|
||||
)
|
||||
# Should deduplicate
|
||||
prefers = [lk for lk in links if lk.relation == RelationType.PREFERS.value]
|
||||
assert len(prefers) <= 2 # May match twice on different patterns but not the same exact key
|
||||
|
||||
def test_confidence_in_range(self, link_extractor):
|
||||
links = link_extractor.extract("The project uses Redis.", source_turn_id="s:0")
|
||||
for lk in links:
|
||||
assert 0.0 <= lk.confidence <= 1.0
|
||||
|
||||
def test_store_tier_assigned(self, link_extractor):
|
||||
links = link_extractor.extract("I prefer Python.", source_turn_id="s:0")
|
||||
for lk in links:
|
||||
assert lk.store_tier in (
|
||||
StoreTier.WORLD_KNOWLEDGE.value,
|
||||
StoreTier.DURABLE_MEMORY.value,
|
||||
StoreTier.SESSION_STATE.value,
|
||||
)
|
||||
|
||||
def test_five_distinct_relation_types_covered(self, link_extractor):
|
||||
"""A single input set covers all 5+ relation types."""
|
||||
texts = [
|
||||
("I prefer TypeScript over JavaScript.", RelationType.PREFERS),
|
||||
("Actually the database is PostgreSQL.", RelationType.CORRECTS),
|
||||
("The project uses FastAPI.", RelationType.USES),
|
||||
("The config is at /app/settings.py.", RelationType.LOCATED_AT),
|
||||
("myapp requires redis>=4.0.", RelationType.DEPENDS_ON),
|
||||
("cache_timeout is set to 3600.", RelationType.CONFIGURES),
|
||||
]
|
||||
found_types = set()
|
||||
for text, expected_rel in texts:
|
||||
links = link_extractor.extract(text, source_turn_id="s:0")
|
||||
for lk in links:
|
||||
found_types.add(lk.relation)
|
||||
|
||||
assert RelationType.PREFERS.value in found_types
|
||||
assert RelationType.CORRECTS.value in found_types
|
||||
assert RelationType.USES.value in found_types
|
||||
assert RelationType.DEPENDS_ON.value in found_types or RelationType.CONFIGURES.value in found_types
|
||||
# At least 5 distinct types found
|
||||
assert len(found_types) >= 5, f"Expected ≥5 relation types, found: {found_types}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StoreRouter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStoreRouter:
|
||||
def test_write_and_load_durable(self, store_router):
|
||||
store_router.write(
|
||||
tier=StoreTier.DURABLE_MEMORY,
|
||||
category="user_pref",
|
||||
content="user prefers dark mode",
|
||||
source_id="s:0",
|
||||
)
|
||||
facts = store_router.load(StoreTier.DURABLE_MEMORY)
|
||||
assert len(facts) == 1
|
||||
assert "dark mode" in facts[0]["content"]
|
||||
|
||||
def test_write_and_load_world_knowledge(self, store_router):
|
||||
store_router.write(
|
||||
tier=StoreTier.WORLD_KNOWLEDGE,
|
||||
category="USES",
|
||||
content="project USES PostgreSQL",
|
||||
source_id="s:1",
|
||||
)
|
||||
facts = store_router.load(StoreTier.WORLD_KNOWLEDGE)
|
||||
assert len(facts) >= 1
|
||||
assert any("PostgreSQL" in f["content"] for f in facts)
|
||||
|
||||
def test_session_state_tier_is_noop(self, store_router):
|
||||
# Writing to SESSION_STATE should not create any file
|
||||
store_router.write(tier=StoreTier.SESSION_STATE, category="turn", content="stuff")
|
||||
# No file should be created for session_state (it's handled by SessionTurnStore)
|
||||
session_path = store_router._stores.get(StoreTier.SESSION_STATE)
|
||||
assert session_path is None
|
||||
|
||||
def test_search_across_tiers(self, store_router):
|
||||
store_router.write(StoreTier.DURABLE_MEMORY, "pref", "user prefers vim editor")
|
||||
store_router.write(StoreTier.WORLD_KNOWLEDGE, "USES", "project uses React")
|
||||
|
||||
# Search both
|
||||
vim_results = store_router.search("vim")
|
||||
assert len(vim_results) == 1
|
||||
assert "vim" in vim_results[0]["content"]
|
||||
|
||||
react_results = store_router.search("React")
|
||||
assert len(react_results) == 1
|
||||
|
||||
def test_search_tier_filter(self, store_router):
|
||||
store_router.write(StoreTier.DURABLE_MEMORY, "pref", "prefers dark mode")
|
||||
store_router.write(StoreTier.WORLD_KNOWLEDGE, "fact", "dark web reference")
|
||||
|
||||
# Search only DURABLE_MEMORY
|
||||
results = store_router.search("dark", tier=StoreTier.DURABLE_MEMORY)
|
||||
assert all(r["tier"] == StoreTier.DURABLE_MEMORY for r in results)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RecallEngine tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRecallEngine:
|
||||
def test_search_finds_turn_content(self, recall_engine, turn_store):
|
||||
turn_store.append(role="user", content="I prefer Go for systems programming")
|
||||
result = recall_engine.search("Go for systems")
|
||||
assert result["total_results"] > 0
|
||||
assert len(result["turns"]) >= 1
|
||||
assert "Go for systems" in result["turns"][0]["content_preview"]
|
||||
|
||||
def test_search_finds_summary_content(self, recall_engine, summary_dag):
|
||||
summary_dag.add_node(
|
||||
"User prefers TypeScript. Fixed auth bug.", source_turn_ids=["s:0"]
|
||||
)
|
||||
result = recall_engine.search("TypeScript")
|
||||
assert len(result["summaries"]) >= 1
|
||||
|
||||
def test_search_returns_structured_result(self, recall_engine, turn_store):
|
||||
turn_store.append(role="user", content="test query")
|
||||
result = recall_engine.search("test query")
|
||||
assert "query" in result
|
||||
assert "turns" in result
|
||||
assert "summaries" in result
|
||||
assert "links" in result
|
||||
assert "total_results" in result
|
||||
|
||||
def test_describe_turn(self, recall_engine, turn_store, session_id):
|
||||
r = turn_store.append(role="user", content="The config is at /home/user/app.yaml")
|
||||
result = recall_engine.describe(r.lineage_id)
|
||||
assert result["type"] == "turn"
|
||||
assert result["lineage_id"] == r.lineage_id
|
||||
assert result["content"] == "The config is at /home/user/app.yaml"
|
||||
assert result["role"] == "user"
|
||||
|
||||
def test_describe_summary_node(self, recall_engine, summary_dag, session_id):
|
||||
node = summary_dag.add_node("summary content", source_turn_ids=["x:0", "x:1"])
|
||||
result = recall_engine.describe(node.node_id)
|
||||
assert result["type"] == "summary_node"
|
||||
assert result["node_id"] == node.node_id
|
||||
assert result["source_turn_ids"] == ["x:0", "x:1"]
|
||||
|
||||
def test_describe_missing_returns_error(self, recall_engine):
|
||||
result = recall_engine.describe("nonexistent:999")
|
||||
assert "error" in result
|
||||
|
||||
def test_expand_shows_source_turns(self, recall_engine, turn_store, summary_dag, session_id):
|
||||
"""expand() must retrieve the original turns that produced a summary."""
|
||||
# Ingest some turns
|
||||
r0 = turn_store.append(role="user", content="Fact: the DB is Postgres")
|
||||
r1 = turn_store.append(role="assistant", content="Noted — using PostgreSQL")
|
||||
r2 = turn_store.append(role="user", content="Also use Redis for caching")
|
||||
|
||||
# Create a summary DAG node referencing those turns
|
||||
node = summary_dag.add_node(
|
||||
summary_text="Established DB stack: Postgres + Redis cache",
|
||||
source_turn_ids=[r0.lineage_id, r1.lineage_id, r2.lineage_id],
|
||||
)
|
||||
|
||||
# Expand the summary — should retrieve original turn content
|
||||
result = recall_engine.expand(node.node_id)
|
||||
assert "error" not in result
|
||||
assert result["node_id"] == node.node_id
|
||||
assert result["source_count"] == 3
|
||||
assert result["found_count"] == 3
|
||||
|
||||
# Verify we can read the original content
|
||||
contents = {t["content"] for t in result["source_turns"]}
|
||||
assert "Fact: the DB is Postgres" in contents
|
||||
assert "Noted — using PostgreSQL" in contents
|
||||
assert "Also use Redis for caching" in contents
|
||||
|
||||
def test_expand_missing_node(self, recall_engine):
|
||||
result = recall_engine.expand("nonexistent:summary:0")
|
||||
assert "error" in result
|
||||
|
||||
def test_expand_follows_parent_chain(self, recall_engine, turn_store, summary_dag):
|
||||
"""expand() returns the parent chain for DAG traversal."""
|
||||
r = turn_store.append(role="user", content="initial context")
|
||||
n1 = summary_dag.add_node("first compaction", source_turn_ids=[r.lineage_id])
|
||||
r2 = turn_store.append(role="user", content="more context")
|
||||
n2 = summary_dag.add_node(
|
||||
"second compaction (includes first)",
|
||||
source_turn_ids=[r2.lineage_id],
|
||||
parent_node_id=n1.node_id,
|
||||
)
|
||||
|
||||
result = recall_engine.expand(n2.node_id)
|
||||
assert result["parent_chain"]
|
||||
assert result["parent_chain"][0]["node_id"] == n1.node_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration test: fact recovery from compacted context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFactRecoveryFromCompactedContext:
|
||||
"""Prove that the agent can recover a fact from compacted context
|
||||
without re-injecting the full original transcript.
|
||||
|
||||
Scenario:
|
||||
1. Ingest multiple turns (including a key fact)
|
||||
2. Create a summary DAG node (simulating context compaction)
|
||||
3. The original turns are "gone" from the live context window
|
||||
4. Use RecallEngine.expand() to recover the original fact
|
||||
"""
|
||||
|
||||
def test_recover_fact_via_expand(self, session_id, tmp_hermes_home):
|
||||
"""Core acceptance criterion: recover a fact from compacted context."""
|
||||
turn_store, summary_dag, link_extractor, link_store, store_router, recall_engine = \
|
||||
create_lossless_context(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
# Step 1: Ingest turns that contain a key fact
|
||||
key_fact = "The production database password is stored in /etc/secrets/db.env"
|
||||
r0 = ingest_turn("What's the database setup?", "user", turn_store, link_extractor, link_store, store_router)
|
||||
r1 = ingest_turn(key_fact, "assistant", turn_store, link_extractor, link_store, store_router)
|
||||
r2 = ingest_turn("Thanks, understood.", "user", turn_store, link_extractor, link_store, store_router)
|
||||
|
||||
# Step 2: Simulate context compaction — create a summary node
|
||||
summary_text = "User asked about database setup. Assistant explained that credentials are stored in /etc/secrets/db.env."
|
||||
node = compact_turns_to_dag(
|
||||
summary_text=summary_text,
|
||||
source_turn_ids=[r0.lineage_id, r1.lineage_id, r2.lineage_id],
|
||||
summary_dag=summary_dag,
|
||||
)
|
||||
|
||||
# Step 3: Simulate that live context is now ONLY the summary
|
||||
# (original turns r0, r1, r2 are "gone" from the active window)
|
||||
# The agent only sees the summary node ID from the compacted message.
|
||||
|
||||
# Step 4: Use expand() to recover the original content
|
||||
expanded = recall_engine.expand(node.node_id)
|
||||
|
||||
assert "error" not in expanded, f"expand() failed: {expanded}"
|
||||
assert expanded["found_count"] == 3, f"Expected 3 source turns, got: {expanded}"
|
||||
|
||||
# The key fact must be in the expanded source turns
|
||||
all_content = " ".join(t["content"] for t in expanded["source_turns"])
|
||||
assert key_fact in all_content, (
|
||||
f"Key fact not found in expanded turns. Content: {all_content}"
|
||||
)
|
||||
|
||||
def test_search_recovers_fact_from_compacted_summary(self, session_id, tmp_hermes_home):
|
||||
"""search() can find a fact that's only in a summary node."""
|
||||
turn_store, summary_dag, link_extractor, link_store, store_router, recall_engine = \
|
||||
create_lossless_context(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
# Ingest minimal turns then compact them
|
||||
r = ingest_turn("The API key is ABCD-1234.", "assistant", turn_store, link_extractor, link_store, store_router)
|
||||
compact_turns_to_dag(
|
||||
summary_text="User verified API key ABCD-1234 for production.",
|
||||
source_turn_ids=[r.lineage_id],
|
||||
summary_dag=summary_dag,
|
||||
)
|
||||
|
||||
# Search for the API key value — should find it in summaries
|
||||
result = recall_engine.search("ABCD-1234")
|
||||
assert result["total_results"] > 0
|
||||
|
||||
# Found in either turns or summaries
|
||||
found_in_turns = any("ABCD-1234" in t["content_preview"] for t in result["turns"])
|
||||
found_in_summaries = any("ABCD-1234" in s["summary_preview"] for s in result["summaries"])
|
||||
assert found_in_turns or found_in_summaries, (
|
||||
f"Fact not found. turns={result['turns']}, summaries={result['summaries']}"
|
||||
)
|
||||
|
||||
def test_typed_links_survive_compaction(self, session_id, tmp_hermes_home):
|
||||
"""Typed links extracted at write time survive even if turns are compacted."""
|
||||
turn_store, summary_dag, link_extractor, link_store, store_router, recall_engine = \
|
||||
create_lossless_context(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
# Ingest a turn with a clear preference
|
||||
r = ingest_turn(
|
||||
"I prefer PostgreSQL over MySQL for this project.",
|
||||
"user",
|
||||
turn_store,
|
||||
link_extractor,
|
||||
link_store,
|
||||
store_router,
|
||||
)
|
||||
|
||||
# Compact the turn
|
||||
compact_turns_to_dag(
|
||||
summary_text="Database preference established.",
|
||||
source_turn_ids=[r.lineage_id],
|
||||
summary_dag=summary_dag,
|
||||
)
|
||||
|
||||
# The original typed link should still be searchable via RecallEngine
|
||||
result = recall_engine.search("PostgreSQL")
|
||||
# Should find the preference in typed links or turns
|
||||
assert result["total_results"] > 0
|
||||
|
||||
# Verify typed links were extracted
|
||||
link_results = result["links"]
|
||||
prefers_links = [lk for lk in link_results if lk["relation"] == RelationType.PREFERS.value]
|
||||
assert len(prefers_links) >= 1, f"Expected PREFERS link, got: {link_results}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ingest_turn integration test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIngestTurn:
|
||||
def test_ingest_appends_and_extracts(self, session_id, tmp_hermes_home):
|
||||
turn_store, _, link_extractor, link_store, store_router, _ = \
|
||||
create_lossless_context(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
record = ingest_turn(
|
||||
"The project uses FastAPI for the REST endpoints.",
|
||||
"user",
|
||||
turn_store, link_extractor, link_store, store_router,
|
||||
)
|
||||
|
||||
assert record.role == "user"
|
||||
assert "FastAPI" in record.content
|
||||
|
||||
# Check that a USES link was extracted and stored
|
||||
links = link_store.load_all()
|
||||
uses_links = [lk for lk in links if lk.relation == RelationType.USES.value]
|
||||
assert len(uses_links) >= 1
|
||||
|
||||
def test_tool_turns_not_link_extracted(self, session_id, tmp_hermes_home):
|
||||
turn_store, _, link_extractor, link_store, store_router, _ = \
|
||||
create_lossless_context(session_id, hermes_home=tmp_hermes_home)
|
||||
|
||||
ingest_turn(
|
||||
"The project uses FastAPI", # Tool results are noisy — don't extract
|
||||
"tool",
|
||||
turn_store, link_extractor, link_store, store_router,
|
||||
tool_name="read_file", tool_call_id="call-abc",
|
||||
)
|
||||
|
||||
links = link_store.load_all()
|
||||
assert len(links) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory tool world_knowledge target test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMemoryStoreWorldKnowledge:
|
||||
"""Verify that the world_knowledge store target works in MemoryStore."""
|
||||
|
||||
def test_world_knowledge_add_and_load(self, tmp_hermes_home, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_hermes_home))
|
||||
from hermes_constants import get_hermes_home
|
||||
# Patch get_memory_dir to use tmp_hermes_home
|
||||
import tools.memory_tool as mt
|
||||
monkeypatch.setattr(mt, "get_memory_dir", lambda: tmp_hermes_home / "memories")
|
||||
|
||||
store = mt.MemoryStore()
|
||||
store.load_from_disk()
|
||||
|
||||
# Add to world_knowledge
|
||||
result = store.add("world_knowledge", "Python was created by Guido van Rossum")
|
||||
assert result["success"] is True
|
||||
|
||||
# Verify entry exists
|
||||
entries = store._entries_for("world_knowledge")
|
||||
assert any("Guido" in e for e in entries)
|
||||
|
||||
def test_world_knowledge_char_limit_separate(self, tmp_hermes_home, monkeypatch):
|
||||
import tools.memory_tool as mt
|
||||
monkeypatch.setattr(mt, "get_memory_dir", lambda: tmp_hermes_home / "memories")
|
||||
|
||||
store = mt.MemoryStore(world_knowledge_char_limit=100)
|
||||
store.load_from_disk()
|
||||
|
||||
# Add something long that exceeds the limit
|
||||
long_entry = "A" * 150
|
||||
result = store.add("world_knowledge", long_entry)
|
||||
assert result["success"] is False
|
||||
assert "limit" in result["error"].lower() or "exceed" in result["error"].lower()
|
||||
|
||||
def test_invalid_target_rejected(self, tmp_hermes_home, monkeypatch):
|
||||
import tools.memory_tool as mt
|
||||
monkeypatch.setattr(mt, "get_memory_dir", lambda: tmp_hermes_home / "memories")
|
||||
|
||||
store = mt.MemoryStore()
|
||||
store.load_from_disk()
|
||||
result = mt.memory_tool("add", target="unknown_store", content="test", store=store)
|
||||
data = json.loads(result)
|
||||
assert data.get("success") is False
|
||||
assert "unknown_store" in data.get("error", "")
|
||||
178
tools/lossless_recall_tool.py
Normal file
178
tools/lossless_recall_tool.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Lossless Recall Tool — Explicit recall over immutable session context.
|
||||
|
||||
Provides three recall operations over the lossless context subsystem
|
||||
(agent/lossless_context.py):
|
||||
|
||||
search — full-text search across session turns, summaries, relation links,
|
||||
and durable/world-knowledge stores. Returns matched results grouped
|
||||
by source type (turns, summary nodes, typed links, durable facts).
|
||||
|
||||
describe — retrieve full details about a specific turn or summary node by its
|
||||
stable lineage ID (e.g. 'session:42' or 'session:summary:3').
|
||||
Returns original content plus any typed links extracted from it.
|
||||
|
||||
expand — expand a summary node to show the original source turns it was
|
||||
compacted from. Proves the agent can recover facts from compacted
|
||||
context without re-injecting the full original transcript.
|
||||
|
||||
This tool complements session_search (which searches across past sessions via
|
||||
FTS5). lossless_recall operates over the current session's raw turn store and
|
||||
summary DAG — the immutable, lineage-aware substrate underneath compaction.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def lossless_recall(
|
||||
action: str,
|
||||
query: str = "",
|
||||
lineage_id: str = "",
|
||||
limit: int = 20,
|
||||
recall_engine=None,
|
||||
) -> str:
|
||||
"""Dispatch a recall operation to the RecallEngine.
|
||||
|
||||
Args:
|
||||
action: 'search', 'describe', or 'expand'
|
||||
query: Search query text (required for 'search')
|
||||
lineage_id: Lineage ID of a turn or summary node (for 'describe'/'expand')
|
||||
limit: Max results to return (for 'search', default 20)
|
||||
recall_engine: RecallEngine instance (injected by the tool registry)
|
||||
|
||||
Returns:
|
||||
JSON string with results.
|
||||
"""
|
||||
if recall_engine is None:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
"lossless_recall is not available — the lossless context subsystem "
|
||||
"has not been initialized for this session. This tool requires "
|
||||
"HERMES_LOSSLESS_CONTEXT=1 in the environment or config."
|
||||
)
|
||||
})
|
||||
|
||||
if action == "search":
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required for 'search' action."})
|
||||
try:
|
||||
result = recall_engine.search(query, limit=limit)
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("lossless_recall search failed: %s", e)
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
elif action == "describe":
|
||||
if not lineage_id:
|
||||
return json.dumps({"error": "lineage_id is required for 'describe' action."})
|
||||
try:
|
||||
result = recall_engine.describe(lineage_id)
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("lossless_recall describe failed: %s", e)
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
elif action == "expand":
|
||||
if not lineage_id:
|
||||
return json.dumps({"error": "lineage_id is required for 'expand' action."})
|
||||
try:
|
||||
result = recall_engine.expand(lineage_id)
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("lossless_recall expand failed: %s", e)
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"Unknown action '{action}'. Valid actions: search, describe, expand"
|
||||
})
|
||||
|
||||
|
||||
def check_lossless_recall_requirements() -> bool:
|
||||
"""lossless_recall requires the lossless_context module."""
|
||||
try:
|
||||
from agent.lossless_context import RecallEngine # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Function-Calling Schema
|
||||
# =============================================================================
|
||||
|
||||
LOSSLESS_RECALL_SCHEMA = {
|
||||
"name": "lossless_recall",
|
||||
"description": (
|
||||
"Recall information from the current session's lossless context store. "
|
||||
"Unlike raw message history, this recall operates over immutable turn records "
|
||||
"and summary DAG nodes — so facts from compacted context remain retrievable "
|
||||
"even after the original turns are summarised.\n\n"
|
||||
"THREE ACTIONS:\n"
|
||||
"- 'search': full-text search across raw turns, summary nodes, relation links, "
|
||||
" and durable/world-knowledge stores. Use this to find a fact from earlier in the "
|
||||
" session that may have been compacted out of the live context window.\n"
|
||||
"- 'describe': retrieve full details about a specific turn or summary node by its "
|
||||
" stable lineage ID (format: 'session_id:seq' for turns, 'session_id:summary:n' for nodes). "
|
||||
" Returns the original content plus typed links extracted from it.\n"
|
||||
"- 'expand': expand a summary node to show the original source turns it was built from. "
|
||||
" Use this to recover original turn content from a compacted summary without re-reading "
|
||||
" the full transcript — proves lossless recovery from compaction.\n\n"
|
||||
"WHEN TO USE:\n"
|
||||
"- A fact from earlier in the session has been compacted out of the live context\n"
|
||||
"- You remember something was said but can't see it in the current context window\n"
|
||||
"- You want to trace back a summary node to the original turns that produced it\n"
|
||||
"- You need to verify a specific typed relation (PREFERS, CORRECTS, USES, etc.) from the session\n\n"
|
||||
"NOTE: For searching across past sessions, use session_search instead."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["search", "describe", "expand"],
|
||||
"description": "The recall operation: 'search' for full-text search, 'describe' for a specific node, 'expand' to unfold a summary node.",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query text (case-insensitive substring match). Required for 'search' action.",
|
||||
},
|
||||
"lineage_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Stable lineage ID of a turn or summary node. Required for 'describe' and 'expand'. "
|
||||
"Turn format: 'session_id:seq_number' (e.g. 'abc123:42'). "
|
||||
"Summary format: 'session_id:summary:n' (e.g. 'abc123:summary:3')."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum results to return for 'search' action (default: 20).",
|
||||
"default": 20,
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry, tool_error # noqa: E402
|
||||
|
||||
registry.register(
|
||||
name="lossless_recall",
|
||||
toolset="memory",
|
||||
schema=LOSSLESS_RECALL_SCHEMA,
|
||||
handler=lambda args, **kw: lossless_recall(
|
||||
action=args.get("action", ""),
|
||||
query=args.get("query", ""),
|
||||
lineage_id=args.get("lineage_id", ""),
|
||||
limit=args.get("limit", 20),
|
||||
recall_engine=kw.get("recall_engine"),
|
||||
),
|
||||
check_fn=check_lossless_recall_requirements,
|
||||
emoji="📖",
|
||||
)
|
||||
@@ -106,37 +106,49 @@ class MemoryStore:
|
||||
"""
|
||||
Bounded curated memory with file persistence. One instance per AIAgent.
|
||||
|
||||
Maintains two parallel states:
|
||||
Maintains three parallel states:
|
||||
- _system_prompt_snapshot: frozen at load time, used for system prompt injection.
|
||||
Never mutated mid-session. Keeps prefix cache stable.
|
||||
- memory_entries / user_entries: live state, mutated by tool calls, persisted to disk.
|
||||
- memory_entries / user_entries / world_knowledge_entries: live state, mutated by
|
||||
tool calls, persisted to disk.
|
||||
Tool responses always reflect this live state.
|
||||
|
||||
Three targets:
|
||||
- 'memory' : agent notes, environment facts, project conventions (durable_memory)
|
||||
- 'user' : user profile, preferences, communication style (durable_memory)
|
||||
- 'world_knowledge' : stable world facts not tied to this user/session (world_knowledge)
|
||||
"""
|
||||
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375,
|
||||
world_knowledge_char_limit: int = 2000):
|
||||
self.memory_entries: List[str] = []
|
||||
self.user_entries: List[str] = []
|
||||
self.world_knowledge_entries: List[str] = []
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
self.world_knowledge_char_limit = world_knowledge_char_limit
|
||||
# Frozen snapshot for system prompt -- set once at load_from_disk()
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": "", "world_knowledge": ""}
|
||||
|
||||
def load_from_disk(self):
|
||||
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
|
||||
"""Load entries from MEMORY.md, USER.md, and WORLD_KNOWLEDGE.md, capture system prompt snapshot."""
|
||||
mem_dir = get_memory_dir()
|
||||
mem_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.memory_entries = self._read_file(mem_dir / "MEMORY.md")
|
||||
self.user_entries = self._read_file(mem_dir / "USER.md")
|
||||
self.world_knowledge_entries = self._read_file(mem_dir / "WORLD_KNOWLEDGE.md")
|
||||
|
||||
# Deduplicate entries (preserves order, keeps first occurrence)
|
||||
self.memory_entries = list(dict.fromkeys(self.memory_entries))
|
||||
self.user_entries = list(dict.fromkeys(self.user_entries))
|
||||
self.world_knowledge_entries = list(dict.fromkeys(self.world_knowledge_entries))
|
||||
|
||||
# Capture frozen snapshot for system prompt injection
|
||||
self._system_prompt_snapshot = {
|
||||
"memory": self._render_block("memory", self.memory_entries),
|
||||
"user": self._render_block("user", self.user_entries),
|
||||
"world_knowledge": self._render_block("world_knowledge", self.world_knowledge_entries),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -181,6 +193,8 @@ class MemoryStore:
|
||||
mem_dir = get_memory_dir()
|
||||
if target == "user":
|
||||
return mem_dir / "USER.md"
|
||||
if target == "world_knowledge":
|
||||
return mem_dir / "WORLD_KNOWLEDGE.md"
|
||||
return mem_dir / "MEMORY.md"
|
||||
|
||||
def _reload_target(self, target: str):
|
||||
@@ -200,11 +214,15 @@ class MemoryStore:
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
if target == "world_knowledge":
|
||||
return self.world_knowledge_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: List[str]):
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
elif target == "world_knowledge":
|
||||
self.world_knowledge_entries = entries
|
||||
else:
|
||||
self.memory_entries = entries
|
||||
|
||||
@@ -217,6 +235,8 @@ class MemoryStore:
|
||||
def _char_limit(self, target: str) -> int:
|
||||
if target == "user":
|
||||
return self.user_char_limit
|
||||
if target == "world_knowledge":
|
||||
return self.world_knowledge_char_limit
|
||||
return self.memory_char_limit
|
||||
|
||||
def add(self, target: str, content: str) -> Dict[str, Any]:
|
||||
@@ -400,6 +420,8 @@ class MemoryStore:
|
||||
|
||||
if target == "user":
|
||||
header = f"USER PROFILE (who the user is) [{pct}% — {current:,}/{limit:,} chars]"
|
||||
elif target == "world_knowledge":
|
||||
header = f"WORLD KNOWLEDGE (stable facts about the world) [{pct}% — {current:,}/{limit:,} chars]"
|
||||
else:
|
||||
header = f"MEMORY (your personal notes) [{pct}% — {current:,}/{limit:,} chars]"
|
||||
|
||||
@@ -475,8 +497,8 @@ def memory_tool(
|
||||
if store is None:
|
||||
return tool_error("Memory is not available. It may be disabled in config or this environment.", success=False)
|
||||
|
||||
if target not in ("memory", "user"):
|
||||
return tool_error(f"Invalid target '{target}'. Use 'memory' or 'user'.", success=False)
|
||||
if target not in ("memory", "user", "world_knowledge"):
|
||||
return tool_error(f"Invalid target '{target}'. Use 'memory', 'user', or 'world_knowledge'.", success=False)
|
||||
|
||||
if action == "add":
|
||||
if not content:
|
||||
@@ -528,9 +550,11 @@ MEMORY_SCHEMA = {
|
||||
"state to memory; use session_search to recall those from past transcripts.\n"
|
||||
"If you've discovered a new way to do something, solved a problem that could be "
|
||||
"necessary later, save it as a skill with the skill tool.\n\n"
|
||||
"TWO TARGETS:\n"
|
||||
"- 'user': who the user is -- name, role, preferences, communication style, pet peeves\n"
|
||||
"- 'memory': your notes -- environment facts, project conventions, tool quirks, lessons learned\n\n"
|
||||
"THREE TARGETS (explicit store routing):\n"
|
||||
"- 'user': who the user is -- name, role, preferences, communication style, pet peeves (durable_memory)\n"
|
||||
"- 'memory': your notes -- environment facts, project conventions, tool quirks, lessons learned (durable_memory)\n"
|
||||
"- 'world_knowledge': stable world facts not specific to this user/session -- technology facts,\n"
|
||||
" API behavior, language features, known libraries (world_knowledge store)\n\n"
|
||||
"ACTIONS: add (new entry), replace (update existing -- old_text identifies it), "
|
||||
"remove (delete -- old_text identifies it).\n\n"
|
||||
"SKIP: trivial/obvious info, things easily re-discovered, raw data dumps, and temporary task state."
|
||||
@@ -545,8 +569,12 @@ MEMORY_SCHEMA = {
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["memory", "user"],
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
|
||||
"enum": ["memory", "user", "world_knowledge"],
|
||||
"description": (
|
||||
"Which memory store: 'memory' for personal notes (durable_memory), "
|
||||
"'user' for user profile (durable_memory), or 'world_knowledge' for "
|
||||
"stable world facts not specific to this user/session."
|
||||
),
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
|
||||
Reference in New Issue
Block a user