feat(gateway): proactive async memory flush on session expiry

Previously, when a session expired (idle/daily reset), the memory flush
ran synchronously inside get_or_create_session — blocking the user's
message for 10-60s while an LLM call saved memories.

Now a background watcher task (_session_expiry_watcher) runs every 5 min,
detects expired sessions, and flushes memories proactively in a thread
pool.  By the time the user sends their next message, memories are
already saved and the response is immediate.

Changes:
- Add _is_session_expired(entry) to SessionStore — works from entry
  alone without needing a SessionSource
- Add _pre_flushed_sessions set to track already-flushed sessions
- Remove sync _on_auto_reset callback from get_or_create_session
- Refactor flush into _flush_memories_for_session (sync worker) +
  _async_flush_memories (thread pool wrapper)
- Add _session_expiry_watcher background task, started in start()
- Simplify /reset command to use shared fire-and-forget flush
- Add 10 tests for expiry detection, callback removal, tracking
This commit is contained in:
teknium1
2026-03-07 11:27:50 -08:00
parent e64d646bad
commit d80c30cc92
3 changed files with 282 additions and 42 deletions

View File

@@ -178,7 +178,6 @@ class GatewayRunner:
self.session_store = SessionStore(
self.config.sessions_dir, self.config,
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
on_auto_reset=self._flush_memories_before_reset,
)
self.delivery_router = DeliveryRouter(self.config)
self._running = False
@@ -209,15 +208,14 @@ class GatewayRunner:
from gateway.hooks import HookRegistry
self.hooks = HookRegistry()
def _flush_memories_before_reset(self, old_entry):
"""Prompt the agent to save memories/skills before an auto-reset.
Called synchronously by SessionStore before destroying an expired session.
Loads the transcript, gives the agent a real turn with memory + skills
tools, and explicitly asks it to preserve anything worth keeping.
def _flush_memories_for_session(self, old_session_id: str):
"""Prompt the agent to save memories/skills before context is lost.
Synchronous worker — meant to be called via run_in_executor from
an async context so it doesn't block the event loop.
"""
try:
history = self.session_store.load_transcript(old_entry.session_id)
history = self.session_store.load_transcript(old_session_id)
if not history or len(history) < 4:
return
@@ -231,7 +229,7 @@ class GatewayRunner:
max_iterations=8,
quiet_mode=True,
enabled_toolsets=["memory", "skills"],
session_id=old_entry.session_id,
session_id=old_session_id,
)
# Build conversation history from transcript
@@ -260,9 +258,14 @@ class GatewayRunner:
user_message=flush_prompt,
conversation_history=msgs,
)
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
except Exception as e:
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
async def _async_flush_memories(self, old_session_id: str):
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
@staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]:
@@ -464,10 +467,50 @@ class GatewayRunner:
# Check if we're restarting after a /update command
await self._send_update_notification()
# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())
logger.info("Press Ctrl+C to stop")
return True
async def _session_expiry_watcher(self, interval: int = 300):
"""Background task that proactively flushes memories for expired sessions.
Runs every `interval` seconds (default 5 min). For each session that
has expired according to its reset policy, flushes memories in a thread
pool and marks the session so it won't be flushed again.
This means memories are already saved by the time the user sends their
next message, so there's no blocking delay.
"""
await asyncio.sleep(60) # initial delay — let the gateway fully start
while self._running:
try:
self.session_store._ensure_loaded()
for key, entry in list(self.session_store._entries.items()):
if entry.session_id in self.session_store._pre_flushed_sessions:
continue # already flushed this session
if not self.session_store._is_session_expired(entry):
continue # session still active
# Session has expired — flush memories in the background
logger.info(
"Session %s expired (key=%s), flushing memories proactively",
entry.session_id, key,
)
try:
await self._async_flush_memories(entry.session_id)
self.session_store._pre_flushed_sessions.add(entry.session_id)
except Exception as e:
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
except Exception as e:
logger.debug("Session expiry watcher error: %s", e)
# Sleep in small increments so we can stop quickly
for _ in range(interval):
if not self._running:
break
await asyncio.sleep(1)
async def stop(self) -> None:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
@@ -1012,33 +1055,12 @@ class GatewayRunner:
# Get existing session key
session_key = self.session_store._generate_session_key(source)
# Memory flush before reset: load the old transcript and let a
# temporary agent save memories before the session is wiped.
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
try:
old_entry = self.session_store._entries.get(session_key)
if old_entry:
old_history = self.session_store.load_transcript(old_entry.session_id)
if old_history:
from run_agent import AIAgent
loop = asyncio.get_event_loop()
_flush_kwargs = _resolve_runtime_agent_kwargs()
def _do_flush():
tmp_agent = AIAgent(
**_flush_kwargs,
max_iterations=5,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=old_entry.session_id,
)
# Build simple message list from transcript
msgs = []
for m in old_history:
role = m.get("role")
content = m.get("content")
if role in ("user", "assistant") and content:
msgs.append({"role": role, "content": content})
tmp_agent.flush_memories(msgs)
await loop.run_in_executor(None, _do_flush)
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
except Exception as e:
logger.debug("Gateway memory flush on reset failed: %s", e)

View File

@@ -311,7 +311,9 @@ class SessionStore:
self._entries: Dict[str, SessionEntry] = {}
self._loaded = False
self._has_active_processes_fn = has_active_processes_fn
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
# Initialize SQLite session database
self._db = None
@@ -353,6 +355,44 @@ class SessionStore:
"""Generate a session key from a source."""
return build_session_key(source)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.
Works from the entry alone — no SessionSource needed.
Used by the background expiry watcher to proactively flush memories.
Sessions with active background processes are never considered expired.
"""
if self._has_active_processes_fn:
if self._has_active_processes_fn(entry.session_key):
return False
policy = self.config.get_reset_policy(
platform=entry.platform,
session_type=entry.chat_type,
)
if policy.mode == "none":
return False
now = datetime.now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
if now > idle_deadline:
return True
if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0, second=0, microsecond=0,
)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
if entry.updated_at < today_reset:
return True
return False
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
"""
Check if a session should be reset based on policy.
@@ -439,13 +479,11 @@ class SessionStore:
self._save()
return entry
else:
# Session is being auto-reset — flush memories before destroying
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
if self._on_auto_reset:
try:
self._on_auto_reset(entry)
except Exception as e:
logger.debug("Auto-reset callback failed: %s", e)
self._pre_flushed_sessions.discard(entry.session_id)
if self._db:
try:
self._db.end_session(entry.session_id, "session_reset")