fix(gateway): thread-safe SessionStore — protect _entries with threading.Lock (#3052)

SessionStore._entries was read and mutated without synchronisation,
causing race conditions when multiple platforms (Telegram + Discord)
received messages concurrently on the same gateway process. Two threads
could simultaneously pass the session_key check and create duplicate
sessions for the same user, splitting conversation history.

- Added threading.Lock to protect all _entries / _loaded mutations
- Split _ensure_loaded() into public wrapper + internal _ensure_loaded_locked()
- SQLite I/O is performed outside the lock to avoid blocking during
  slow disk operations
- _save() stays inside the lock since it reads _entries for serialization

Cherry-picked from PR #3012 by Kewe63. Removed unrelated changes
(delivery.py case-sensitivity, hermes_state.py schema tracking) and
stripped the UTC timezone switch to keep the change focused on threading.

Co-authored-by: Kewe63 <Kewe63@users.noreply.github.com>
This commit is contained in:
Teknium
2026-03-25 15:15:37 -07:00
committed by GitHub
parent 14cf2d85ca
commit 73e66eb3c0

View File

@@ -13,6 +13,7 @@ import logging
import os
import json
import re
import threading
import uuid
from pathlib import Path
from datetime import datetime, timedelta
@@ -22,6 +23,11 @@ from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
def _now() -> datetime:
"""Return the current local time."""
return datetime.now()
# ---------------------------------------------------------------------------
# PII redaction helpers
# ---------------------------------------------------------------------------
@@ -471,6 +477,7 @@ class SessionStore:
self.config = config
self._entries: Dict[str, SessionEntry] = {}
self._loaded = False
self._lock = threading.Lock()
self._has_active_processes_fn = has_active_processes_fn
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
@@ -486,12 +493,17 @@ class SessionStore:
def _ensure_loaded(self) -> None:
"""Load sessions index from disk if not already loaded."""
with self._lock:
self._ensure_loaded_locked()
def _ensure_loaded_locked(self) -> None:
"""Load sessions index from disk. Must be called with self._lock held."""
if self._loaded:
return
self.sessions_dir.mkdir(parents=True, exist_ok=True)
sessions_file = self.sessions_dir / "sessions.json"
if sessions_file.exists():
try:
with open(sessions_file, "r", encoding="utf-8") as f:
@@ -504,7 +516,7 @@ class SessionStore:
continue
except Exception as e:
print(f"[gateway] Warning: Failed to load sessions: {e}")
self._loaded = True
def _save(self) -> None:
@@ -556,7 +568,7 @@ class SessionStore:
if policy.mode == "none":
return False
now = datetime.now()
now = _now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
@@ -597,7 +609,7 @@ class SessionStore:
if policy.mode == "none":
return None
now = datetime.now()
now = _now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
@@ -637,87 +649,97 @@ class SessionStore:
pass # fall through to heuristic
# Fallback: check if sessions.json was loaded with existing data.
# This covers the rare case where the DB is unavailable.
self._ensure_loaded()
return len(self._entries) > 1
with self._lock:
self._ensure_loaded_locked()
return len(self._entries) > 1
def get_or_create_session(
self,
self,
source: SessionSource,
force_new: bool = False
) -> SessionEntry:
"""
Get an existing session or create a new one.
Evaluates reset policy to determine if the existing session is stale.
Creates a session record in SQLite when a new session starts.
"""
self._ensure_loaded()
session_key = self._generate_session_key(source)
now = datetime.now()
if session_key in self._entries and not force_new:
entry = self._entries[session_key]
reset_reason = self._should_reset(entry, source)
if not reset_reason:
entry.updated_at = now
self._save()
return entry
now = _now()
# SQLite calls are made outside the lock to avoid holding it during I/O.
# All _entries / _loaded mutations are protected by self._lock.
db_end_session_id = None
db_create_kwargs = None
with self._lock:
self._ensure_loaded_locked()
if session_key in self._entries and not force_new:
entry = self._entries[session_key]
reset_reason = self._should_reset(entry, source)
if not reset_reason:
entry.updated_at = now
self._save()
return entry
else:
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
auto_reset_reason = reset_reason
# Track whether the expired session had any real conversation
reset_had_activity = entry.total_tokens > 0
db_end_session_id = entry.session_id
self._pre_flushed_sessions.discard(entry.session_id)
else:
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
auto_reset_reason = reset_reason
# Track whether the expired session had any real conversation
reset_had_activity = entry.total_tokens > 0
self._pre_flushed_sessions.discard(entry.session_id)
if self._db:
try:
self._db.end_session(entry.session_id, "session_reset")
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
else:
was_auto_reset = False
auto_reset_reason = None
reset_had_activity = False
# Create new session
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
entry = SessionEntry(
session_key=session_key,
session_id=session_id,
created_at=now,
updated_at=now,
origin=source,
display_name=source.chat_name,
platform=source.platform,
chat_type=source.chat_type,
was_auto_reset=was_auto_reset,
auto_reset_reason=auto_reset_reason,
reset_had_activity=reset_had_activity,
)
self._entries[session_key] = entry
self._save()
# Create session in SQLite
if self._db:
was_auto_reset = False
auto_reset_reason = None
reset_had_activity = False
# Create new session
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
entry = SessionEntry(
session_key=session_key,
session_id=session_id,
created_at=now,
updated_at=now,
origin=source,
display_name=source.chat_name,
platform=source.platform,
chat_type=source.chat_type,
was_auto_reset=was_auto_reset,
auto_reset_reason=auto_reset_reason,
reset_had_activity=reset_had_activity,
)
self._entries[session_key] = entry
self._save()
db_create_kwargs = {
"session_id": session_id,
"source": source.platform.value,
"user_id": source.user_id,
}
# SQLite operations outside the lock
if self._db and db_end_session_id:
try:
self._db.create_session(
session_id=session_id,
source=source.platform.value,
user_id=source.user_id,
)
self._db.end_session(db_end_session_id, "session_reset")
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
if self._db and db_create_kwargs:
try:
self._db.create_session(**db_create_kwargs)
except Exception as e:
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
return entry
def update_session(
self,
self,
session_key: str,
input_tokens: int = 0,
output_tokens: int = 0,
@@ -732,91 +754,100 @@ class SessionStore:
base_url: Optional[str] = None,
) -> None:
"""Update a session's metadata after an interaction."""
self._ensure_loaded()
if session_key in self._entries:
entry = self._entries[session_key]
entry.updated_at = datetime.now()
entry.input_tokens += input_tokens
entry.output_tokens += output_tokens
entry.cache_read_tokens += cache_read_tokens
entry.cache_write_tokens += cache_write_tokens
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd += estimated_cost_usd
if cost_status:
entry.cost_status = cost_status
entry.total_tokens = (
entry.input_tokens
+ entry.output_tokens
+ entry.cache_read_tokens
+ entry.cache_write_tokens
)
self._save()
if self._db:
try:
self._db.update_token_counts(
entry.session_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
estimated_cost_usd=estimated_cost_usd,
cost_status=cost_status,
cost_source=cost_source,
billing_provider=provider,
billing_base_url=base_url,
model=model,
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
"""Force reset a session, creating a new session ID."""
self._ensure_loaded()
if session_key not in self._entries:
return None
old_entry = self._entries[session_key]
# End old session in SQLite
if self._db:
db_session_id = None
with self._lock:
self._ensure_loaded_locked()
if session_key in self._entries:
entry = self._entries[session_key]
entry.updated_at = _now()
entry.input_tokens += input_tokens
entry.output_tokens += output_tokens
entry.cache_read_tokens += cache_read_tokens
entry.cache_write_tokens += cache_write_tokens
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd += estimated_cost_usd
if cost_status:
entry.cost_status = cost_status
entry.total_tokens = (
entry.input_tokens
+ entry.output_tokens
+ entry.cache_read_tokens
+ entry.cache_write_tokens
)
self._save()
db_session_id = entry.session_id
if self._db and db_session_id:
try:
self._db.end_session(old_entry.session_id, "session_reset")
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
now = datetime.now()
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
new_entry = SessionEntry(
session_key=session_key,
session_id=session_id,
created_at=now,
updated_at=now,
origin=old_entry.origin,
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
)
self._entries[session_key] = new_entry
self._save()
# Create new session in SQLite
if self._db:
try:
self._db.create_session(
session_id=session_id,
source=old_entry.platform.value if old_entry.platform else "unknown",
user_id=old_entry.origin.user_id if old_entry.origin else None,
self._db.update_token_counts(
db_session_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
estimated_cost_usd=estimated_cost_usd,
cost_status=cost_status,
cost_source=cost_source,
billing_provider=provider,
billing_base_url=base_url,
model=model,
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
"""Force reset a session, creating a new session ID."""
db_end_session_id = None
db_create_kwargs = None
new_entry = None
with self._lock:
self._ensure_loaded_locked()
if session_key not in self._entries:
return None
old_entry = self._entries[session_key]
db_end_session_id = old_entry.session_id
now = _now()
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
new_entry = SessionEntry(
session_key=session_key,
session_id=session_id,
created_at=now,
updated_at=now,
origin=old_entry.origin,
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
)
self._entries[session_key] = new_entry
self._save()
db_create_kwargs = {
"session_id": session_id,
"source": old_entry.platform.value if old_entry.platform else "unknown",
"user_id": old_entry.origin.user_id if old_entry.origin else None,
}
if self._db and db_end_session_id:
try:
self._db.end_session(db_end_session_id, "session_reset")
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
if self._db and db_create_kwargs:
try:
self._db.create_session(**db_create_kwargs)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
return new_entry
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
@@ -827,52 +858,58 @@ class SessionStore:
generating a fresh session ID, re-uses ``target_session_id`` so the
old transcript is loaded on the next message.
"""
self._ensure_loaded()
db_end_session_id = None
new_entry = None
if session_key not in self._entries:
return None
with self._lock:
self._ensure_loaded_locked()
old_entry = self._entries[session_key]
if session_key not in self._entries:
return None
# Don't switch if already on that session
if old_entry.session_id == target_session_id:
return old_entry
old_entry = self._entries[session_key]
# End the current session in SQLite
if self._db:
# Don't switch if already on that session
if old_entry.session_id == target_session_id:
return old_entry
db_end_session_id = old_entry.session_id
now = _now()
new_entry = SessionEntry(
session_key=session_key,
session_id=target_session_id,
created_at=now,
updated_at=now,
origin=old_entry.origin,
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
)
self._entries[session_key] = new_entry
self._save()
if self._db and db_end_session_id:
try:
self._db.end_session(old_entry.session_id, "session_switch")
self._db.end_session(db_end_session_id, "session_switch")
except Exception as e:
logger.debug("Session DB end_session failed: %s", e)
now = datetime.now()
new_entry = SessionEntry(
session_key=session_key,
session_id=target_session_id,
created_at=now,
updated_at=now,
origin=old_entry.origin,
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
)
self._entries[session_key] = new_entry
self._save()
return new_entry
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
"""List all sessions, optionally filtered by activity."""
self._ensure_loaded()
entries = list(self._entries.values())
with self._lock:
self._ensure_loaded_locked()
entries = list(self._entries.values())
if active_minutes is not None:
cutoff = datetime.now() - timedelta(minutes=active_minutes)
cutoff = _now() - timedelta(minutes=active_minutes)
entries = [e for e in entries if e.updated_at >= cutoff]
entries.sort(key=lambda e: e.updated_at, reverse=True)
return entries
def get_transcript_path(self, session_id: str) -> Path:
@@ -891,17 +928,13 @@ class SessionStore:
# Write to SQLite (unless the agent already handled it)
if self._db and not skip_db:
try:
_role = message.get("role", "unknown")
self._db.append_message(
session_id=session_id,
role=_role,
role=message.get("role", "unknown"),
content=message.get("content"),
tool_name=message.get("tool_name"),
tool_calls=message.get("tool_calls"),
tool_call_id=message.get("tool_call_id"),
reasoning=message.get("reasoning") if _role == "assistant" else None,
reasoning_details=message.get("reasoning_details") if _role == "assistant" else None,
codex_reasoning_items=message.get("codex_reasoning_items") if _role == "assistant" else None,
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
@@ -922,17 +955,13 @@ class SessionStore:
try:
self._db.clear_messages(session_id)
for msg in messages:
_role = msg.get("role", "unknown")
self._db.append_message(
session_id=session_id,
role=_role,
role=msg.get("role", "unknown"),
content=msg.get("content"),
tool_name=msg.get("tool_name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
reasoning=msg.get("reasoning") if _role == "assistant" else None,
reasoning_details=msg.get("reasoning_details") if _role == "assistant" else None,
codex_reasoning_items=msg.get("codex_reasoning_items") if _role == "assistant" else None,
)
except Exception as e:
logger.debug("Failed to rewrite transcript in DB: %s", e)