diff --git a/gateway/session.py b/gateway/session.py
index 981a8ab42..68bac4b8c 100644
--- a/gateway/session.py
+++ b/gateway/session.py
@@ -13,6 +13,7 @@ import logging
 import os
 import json
 import re
+import threading
 import uuid
 from pathlib import Path
 from datetime import datetime, timedelta
@@ -22,6 +23,11 @@ from typing import Dict, List, Optional, Any
 
 logger = logging.getLogger(__name__)
 
+def _now() -> datetime:
+    """Return the current local time."""
+    return datetime.now()
+
+
 # ---------------------------------------------------------------------------
 # PII redaction helpers
 # ---------------------------------------------------------------------------
@@ -471,6 +477,7 @@ class SessionStore:
         self.config = config
         self._entries: Dict[str, SessionEntry] = {}
         self._loaded = False
+        self._lock = threading.Lock()
         self._has_active_processes_fn = has_active_processes_fn
         # on_auto_reset is deprecated — memory flush now runs proactively
         # via the background session expiry watcher in GatewayRunner.
@@ -486,12 +493,17 @@ class SessionStore:
 
     def _ensure_loaded(self) -> None:
         """Load sessions index from disk if not already loaded."""
+        with self._lock:
+            self._ensure_loaded_locked()
+
+    def _ensure_loaded_locked(self) -> None:
+        """Load sessions index from disk. Must be called with self._lock held."""
         if self._loaded:
             return
-        
+
         self.sessions_dir.mkdir(parents=True, exist_ok=True)
         sessions_file = self.sessions_dir / "sessions.json"
-        
+
         if sessions_file.exists():
             try:
                 with open(sessions_file, "r", encoding="utf-8") as f:
@@ -504,7 +516,7 @@ class SessionStore:
                             continue
             except Exception as e:
                 print(f"[gateway] Warning: Failed to load sessions: {e}")
-        
+
         self._loaded = True
 
     def _save(self) -> None:
@@ -556,7 +568,7 @@ class SessionStore:
         if policy.mode == "none":
             return False
 
-        now = datetime.now()
+        now = _now()
 
         if policy.mode in ("idle", "both"):
             idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
@@ -597,7 +609,7 @@ class SessionStore:
         if policy.mode == "none":
             return None
 
-        now = datetime.now()
+        now = _now()
 
         if policy.mode in ("idle", "both"):
             idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
@@ -637,87 +649,97 @@ class SessionStore:
                 pass  # fall through to heuristic
         # Fallback: check if sessions.json was loaded with existing data.
         # This covers the rare case where the DB is unavailable.
-        self._ensure_loaded()
-        return len(self._entries) > 1
-    
+        with self._lock:
+            self._ensure_loaded_locked()
+            return len(self._entries) > 1
+
     def get_or_create_session(
-        self, 
+        self,
         source: SessionSource,
         force_new: bool = False
     ) -> SessionEntry:
         """
         Get an existing session or create a new one.
-        
+
         Evaluates reset policy to determine if the existing session is stale.
         Creates a session record in SQLite when a new session starts.
         """
-        self._ensure_loaded()
-
         session_key = self._generate_session_key(source)
-        now = datetime.now()
-        
-        if session_key in self._entries and not force_new:
-            entry = self._entries[session_key]
-            
-            reset_reason = self._should_reset(entry, source)
-            if not reset_reason:
-                entry.updated_at = now
-                self._save()
-                return entry
+        now = _now()
+
+        # SQLite calls are made outside the lock to avoid holding it during I/O.
+        # All _entries / _loaded mutations are protected by self._lock.
+        db_end_session_id = None
+        db_create_kwargs = None
+
+        with self._lock:
+            self._ensure_loaded_locked()
+
+            if session_key in self._entries and not force_new:
+                entry = self._entries[session_key]
+
+                reset_reason = self._should_reset(entry, source)
+                if not reset_reason:
+                    entry.updated_at = now
+                    self._save()
+                    return entry
+                else:
+                    # Session is being auto-reset. The background expiry watcher
+                    # should have already flushed memories proactively; discard
+                    # the marker so it doesn't accumulate.
+                    was_auto_reset = True
+                    auto_reset_reason = reset_reason
+                    # Track whether the expired session had any real conversation
+                    reset_had_activity = entry.total_tokens > 0
+                    db_end_session_id = entry.session_id
+                    self._pre_flushed_sessions.discard(entry.session_id)
             else:
-                # Session is being auto-reset. The background expiry watcher
-                # should have already flushed memories proactively; discard
-                # the marker so it doesn't accumulate.
-                was_auto_reset = True
-                auto_reset_reason = reset_reason
-                # Track whether the expired session had any real conversation
-                reset_had_activity = entry.total_tokens > 0
-                self._pre_flushed_sessions.discard(entry.session_id)
-                if self._db:
-                    try:
-                        self._db.end_session(entry.session_id, "session_reset")
-                    except Exception as e:
-                        logger.debug("Session DB operation failed: %s", e)
-        else:
-            was_auto_reset = False
-            auto_reset_reason = None
-            reset_had_activity = False
-        
-        # Create new session
-        session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
-        
-        entry = SessionEntry(
-            session_key=session_key,
-            session_id=session_id,
-            created_at=now,
-            updated_at=now,
-            origin=source,
-            display_name=source.chat_name,
-            platform=source.platform,
-            chat_type=source.chat_type,
-            was_auto_reset=was_auto_reset,
-            auto_reset_reason=auto_reset_reason,
-            reset_had_activity=reset_had_activity,
-        )
-        
-        self._entries[session_key] = entry
-        self._save()
-        
-        # Create session in SQLite
-        if self._db:
+                was_auto_reset = False
+                auto_reset_reason = None
+                reset_had_activity = False
+
+            # Create new session
+            session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
+
+            entry = SessionEntry(
+                session_key=session_key,
+                session_id=session_id,
+                created_at=now,
+                updated_at=now,
+                origin=source,
+                display_name=source.chat_name,
+                platform=source.platform,
+                chat_type=source.chat_type,
+                was_auto_reset=was_auto_reset,
+                auto_reset_reason=auto_reset_reason,
+                reset_had_activity=reset_had_activity,
+            )
+
+            self._entries[session_key] = entry
+            self._save()
+            db_create_kwargs = {
+                "session_id": session_id,
+                "source": source.platform.value,
+                "user_id": source.user_id,
+            }
+
+        # SQLite operations outside the lock
+        if self._db and db_end_session_id:
             try:
-                self._db.create_session(
-                    session_id=session_id,
-                    source=source.platform.value,
-                    user_id=source.user_id,
-                )
+                self._db.end_session(db_end_session_id, "session_reset")
+            except Exception as e:
+                logger.debug("Session DB operation failed: %s", e)
+
+        if self._db and db_create_kwargs:
+            try:
+                self._db.create_session(**db_create_kwargs)
             except Exception as e:
                 print(f"[gateway] Warning: Failed to create SQLite session: {e}")
-        
+
         return entry
-    
+
     def update_session(
-        self, 
+        self,
         session_key: str,
         input_tokens: int = 0,
         output_tokens: int = 0,
@@ -732,91 +754,100 @@ class SessionStore:
         base_url: Optional[str] = None,
     ) -> None:
         """Update a session's metadata after an interaction."""
-        self._ensure_loaded()
-        
-        if session_key in self._entries:
-            entry = self._entries[session_key]
-            entry.updated_at = datetime.now()
-            entry.input_tokens += input_tokens
-            entry.output_tokens += output_tokens
-            entry.cache_read_tokens += cache_read_tokens
-            entry.cache_write_tokens += cache_write_tokens
-            if last_prompt_tokens is not None:
-                entry.last_prompt_tokens = last_prompt_tokens
-            if estimated_cost_usd is not None:
-                entry.estimated_cost_usd += estimated_cost_usd
-            if cost_status:
-                entry.cost_status = cost_status
-            entry.total_tokens = (
-                entry.input_tokens
-                + entry.output_tokens
-                + entry.cache_read_tokens
-                + entry.cache_write_tokens
-            )
-            self._save()
-            
-            if self._db:
-                try:
-                    self._db.update_token_counts(
-                        entry.session_id,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        cache_read_tokens=cache_read_tokens,
-                        cache_write_tokens=cache_write_tokens,
-                        estimated_cost_usd=estimated_cost_usd,
-                        cost_status=cost_status,
-                        cost_source=cost_source,
-                        billing_provider=provider,
-                        billing_base_url=base_url,
-                        model=model,
-                    )
-                except Exception as e:
-                    logger.debug("Session DB operation failed: %s", e)
-    
-    def reset_session(self, session_key: str) -> Optional[SessionEntry]:
-        """Force reset a session, creating a new session ID."""
-        self._ensure_loaded()
-        
-        if session_key not in self._entries:
-            return None
-        
-        old_entry = self._entries[session_key]
-        
-        # End old session in SQLite
-        if self._db:
+        db_session_id = None
+
+        with self._lock:
+            self._ensure_loaded_locked()
+
+            if session_key in self._entries:
+                entry = self._entries[session_key]
+                entry.updated_at = _now()
+                entry.input_tokens += input_tokens
+                entry.output_tokens += output_tokens
+                entry.cache_read_tokens += cache_read_tokens
+                entry.cache_write_tokens += cache_write_tokens
+                if last_prompt_tokens is not None:
+                    entry.last_prompt_tokens = last_prompt_tokens
+                if estimated_cost_usd is not None:
+                    entry.estimated_cost_usd += estimated_cost_usd
+                if cost_status:
+                    entry.cost_status = cost_status
+                entry.total_tokens = (
+                    entry.input_tokens
+                    + entry.output_tokens
+                    + entry.cache_read_tokens
+                    + entry.cache_write_tokens
+                )
+                self._save()
+                db_session_id = entry.session_id
+
+        if self._db and db_session_id:
             try:
-                self._db.end_session(old_entry.session_id, "session_reset")
-            except Exception as e:
-                logger.debug("Session DB operation failed: %s", e)
-        
-        now = datetime.now()
-        session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
-        
-        new_entry = SessionEntry(
-            session_key=session_key,
-            session_id=session_id,
-            created_at=now,
-            updated_at=now,
-            origin=old_entry.origin,
-            display_name=old_entry.display_name,
-            platform=old_entry.platform,
-            chat_type=old_entry.chat_type,
-        )
-        
-        self._entries[session_key] = new_entry
-        self._save()
-        
-        # Create new session in SQLite
-        if self._db:
-            try:
-                self._db.create_session(
-                    session_id=session_id,
-                    source=old_entry.platform.value if old_entry.platform else "unknown",
-                    user_id=old_entry.origin.user_id if old_entry.origin else None,
+                self._db.update_token_counts(
+                    db_session_id,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    cache_read_tokens=cache_read_tokens,
+                    cache_write_tokens=cache_write_tokens,
+                    estimated_cost_usd=estimated_cost_usd,
+                    cost_status=cost_status,
+                    cost_source=cost_source,
+                    billing_provider=provider,
+                    billing_base_url=base_url,
+                    model=model,
                 )
             except Exception as e:
                 logger.debug("Session DB operation failed: %s", e)
-        
+
+    def reset_session(self, session_key: str) -> Optional[SessionEntry]:
+        """Force reset a session, creating a new session ID."""
+        db_end_session_id = None
+        db_create_kwargs = None
+        new_entry = None
+
+        with self._lock:
+            self._ensure_loaded_locked()
+
+            if session_key not in self._entries:
+                return None
+
+            old_entry = self._entries[session_key]
+            db_end_session_id = old_entry.session_id
+
+            now = _now()
+            session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
+
+            new_entry = SessionEntry(
+                session_key=session_key,
+                session_id=session_id,
+                created_at=now,
+                updated_at=now,
+                origin=old_entry.origin,
+                display_name=old_entry.display_name,
+                platform=old_entry.platform,
+                chat_type=old_entry.chat_type,
+            )
+
+            self._entries[session_key] = new_entry
+            self._save()
+            db_create_kwargs = {
+                "session_id": session_id,
+                "source": old_entry.platform.value if old_entry.platform else "unknown",
+                "user_id": old_entry.origin.user_id if old_entry.origin else None,
+            }
+
+        if self._db and db_end_session_id:
+            try:
+                self._db.end_session(db_end_session_id, "session_reset")
+            except Exception as e:
+                logger.debug("Session DB operation failed: %s", e)
+
+        if self._db and db_create_kwargs:
+            try:
+                self._db.create_session(**db_create_kwargs)
+            except Exception as e:
+                logger.debug("Session DB operation failed: %s", e)
+
         return new_entry
 
     def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
@@ -827,52 +858,58 @@ class SessionStore:
         generating a fresh session ID, re-uses ``target_session_id`` so the old
         transcript is loaded on the next message.
         """
-        self._ensure_loaded()
+        db_end_session_id = None
+        new_entry = None
 
-        if session_key not in self._entries:
-            return None
+        with self._lock:
+            self._ensure_loaded_locked()
 
-        old_entry = self._entries[session_key]
+            if session_key not in self._entries:
+                return None
 
-        # Don't switch if already on that session
-        if old_entry.session_id == target_session_id:
-            return old_entry
+            old_entry = self._entries[session_key]
 
-        # End the current session in SQLite
-        if self._db:
+            # Don't switch if already on that session
+            if old_entry.session_id == target_session_id:
+                return old_entry
+
+            db_end_session_id = old_entry.session_id
+
+            now = _now()
+            new_entry = SessionEntry(
+                session_key=session_key,
+                session_id=target_session_id,
+                created_at=now,
+                updated_at=now,
+                origin=old_entry.origin,
+                display_name=old_entry.display_name,
+                platform=old_entry.platform,
+                chat_type=old_entry.chat_type,
+            )
+
+            self._entries[session_key] = new_entry
+            self._save()
+
+        if self._db and db_end_session_id:
             try:
-                self._db.end_session(old_entry.session_id, "session_switch")
+                self._db.end_session(db_end_session_id, "session_switch")
             except Exception as e:
                 logger.debug("Session DB end_session failed: %s", e)
 
-        now = datetime.now()
-        new_entry = SessionEntry(
-            session_key=session_key,
-            session_id=target_session_id,
-            created_at=now,
-            updated_at=now,
-            origin=old_entry.origin,
-            display_name=old_entry.display_name,
-            platform=old_entry.platform,
-            chat_type=old_entry.chat_type,
-        )
-
-        self._entries[session_key] = new_entry
-        self._save()
         return new_entry
 
     def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
         """List all sessions, optionally filtered by activity."""
-        self._ensure_loaded()
-        
-        entries = list(self._entries.values())
-        
+        with self._lock:
+            self._ensure_loaded_locked()
+            entries = list(self._entries.values())
+
         if active_minutes is not None:
-            cutoff = datetime.now() - timedelta(minutes=active_minutes)
+            cutoff = _now() - timedelta(minutes=active_minutes)
             entries = [e for e in entries if e.updated_at >= cutoff]
-        
+
         entries.sort(key=lambda e: e.updated_at, reverse=True)
-        
+
         return entries
 
     def get_transcript_path(self, session_id: str) -> Path: