Threads (Telegram forum topics, Discord threads, Slack threads) now default to shared sessions where all participants see the same conversation. This is the expected UX for threaded conversations where multiple users @mention the bot and interact collaboratively. Changes: - build_session_key(): when thread_id is present, user_id is no longer appended to the session key (threads are shared by default) - New config: thread_sessions_per_user (default: false) — opt-in to restore per-user isolation in threads if needed - Sender attribution: messages in shared threads are prefixed with [sender name] so the agent can tell participants apart - System prompt: shared threads show 'Multi-user thread' note instead of a per-turn User line (avoids busting prompt cache) - Wired through all callers: gateway/run.py, base.py, telegram.py, feishu.py - Regular group messages (no thread) remain per-user isolated (unchanged) - DM threads are unaffected (they have their own keying logic) Closes community request from demontut_ re: thread-based shared sessions.
1082 lines
41 KiB
Python
1082 lines
41 KiB
Python
"""
|
|
Session management for the gateway.
|
|
|
|
Handles:
|
|
- Session context tracking (where messages come from)
|
|
- Session storage (conversations persisted to disk)
|
|
- Reset policy evaluation (when to start fresh)
|
|
- Dynamic system prompt injection (agent knows its context)
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import json
|
|
import re
|
|
import threading
|
|
import uuid
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _now() -> datetime:
|
|
"""Return the current local time."""
|
|
return datetime.now()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PII redaction helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Phone-number shape: an optional leading "+", one digit, then at least six
# more characters that are digits, hyphens or whitespace (anchored with ^ / $).
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
|
|
|
|
|
def _hash_id(value: str) -> str:
|
|
"""Deterministic 12-char hex hash of an identifier."""
|
|
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
|
|
|
|
|
def _hash_sender_id(value: str) -> str:
    """Return a redacted sender label of the form ``user_<12hex>``."""
    short_hash = _hash_id(value)
    return "user_" + short_hash
|
|
|
|
|
|
def _hash_chat_id(value: str) -> str:
    """Hash a chat ID while keeping a leading ``prefix:`` readable.

    When *value* contains a colon with a non-empty prefix before it, the
    prefix is preserved and everything after the first colon is hashed.
    Otherwise (no colon, or a colon in first position) the whole value is
    hashed.

    ``telegram:12345`` → ``telegram:<hash>``
    ``12345`` → ``<hash>``
    """
    prefix, separator, remainder = value.partition(":")
    if separator and prefix:
        return f"{prefix}:{_hash_id(remainder)}"
    return _hash_id(value)
|
|
|
|
|
|
def _looks_like_phone(value: str) -> bool:
    """Report whether *value*, with surrounding whitespace removed, matches ``_PHONE_RE``."""
    candidate = value.strip()
    return _PHONE_RE.match(candidate) is not None
|
|
|
|
from .config import (
|
|
Platform,
|
|
GatewayConfig,
|
|
SessionResetPolicy, # noqa: F401 — re-exported via gateway/__init__.py
|
|
HomeChannel,
|
|
)
|
|
|
|
|
|
@dataclass
class SessionSource:
    """
    Origin of an inbound message: platform, chat, sender and optional thread.

    This information is used to:
    1. Route responses back to the right place
    2. Inject context into the system prompt
    3. Track origin for cron job delivery
    """
    platform: Platform
    chat_id: str
    chat_name: Optional[str] = None
    chat_type: str = "dm"  # "dm", "group", "channel", "thread"
    user_id: Optional[str] = None
    user_name: Optional[str] = None
    thread_id: Optional[str] = None  # For forum topics, Discord threads, etc.
    chat_topic: Optional[str] = None  # Channel topic/description (Discord, Slack)
    user_id_alt: Optional[str] = None  # Signal UUID (alternative to phone number)
    chat_id_alt: Optional[str] = None  # Signal group internal ID

    @property
    def description(self) -> str:
        """Short human-readable label for this source.

        Local sources are always "CLI terminal".  Otherwise the label names
        the DM partner or the chat, followed by ``thread: <id>`` when a
        thread ID is set.
        """
        if self.platform == Platform.LOCAL:
            return "CLI terminal"

        if self.chat_type == "dm":
            label = f"DM with {self.user_name or self.user_id or 'user'}"
        else:
            target = self.chat_name or self.chat_id
            if self.chat_type in ("group", "channel"):
                label = f"{self.chat_type}: {target}"
            else:
                # Any other chat type is shown by its bare name/ID.
                label = target

        segments = [label]
        if self.thread_id:
            segments.append(f"thread: {self.thread_id}")
        return ", ".join(segments)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        ``platform`` is stored as its enum value.  The two ``*_alt``
        identifiers are written only when they are set.
        """
        payload: Dict[str, Any] = {"platform": self.platform.value}
        for name in (
            "chat_id",
            "chat_name",
            "chat_type",
            "user_id",
            "user_name",
            "thread_id",
            "chat_topic",
        ):
            payload[name] = getattr(self, name)
        for name in ("user_id_alt", "chat_id_alt"):
            alt_value = getattr(self, name)
            if alt_value:
                payload[name] = alt_value
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
        """Rebuild a source from a ``to_dict``-style mapping.

        ``platform`` and ``chat_id`` are required keys (``chat_id`` is coerced
        to ``str``); ``chat_type`` falls back to ``"dm"`` and every other
        field to ``None``.
        """
        platform = Platform(data["platform"])
        chat_id = str(data["chat_id"])
        optional = {
            name: data.get(name)
            for name in (
                "chat_name",
                "user_id",
                "user_name",
                "thread_id",
                "chat_topic",
                "user_id_alt",
                "chat_id_alt",
            )
        }
        return cls(
            platform=platform,
            chat_id=chat_id,
            chat_type=data.get("chat_type", "dm"),
            **optional,
        )

    @classmethod
    def local_cli(cls) -> "SessionSource":
        """Build the source that stands for the local CLI."""
        cli_fields = {
            "platform": Platform.LOCAL,
            "chat_id": "cli",
            "chat_name": "CLI terminal",
            "chat_type": "dm",
        }
        return cls(**cli_fields)
|
|
|
|
|
|
@dataclass
class SessionContext:
    """
    Everything needed to render the dynamic session section of the system prompt.

    It tells the agent:
    - Where messages are coming from
    - What platforms are available
    - Where it can deliver scheduled task outputs
    """
    source: SessionSource
    connected_platforms: List[Platform]
    home_channels: Dict[Platform, HomeChannel]

    # Session metadata
    session_key: str = ""
    session_id: str = ""
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        Platforms are written as their enum values, home channels via their
        own ``to_dict()``, and timestamps as ISO-8601 strings (``None`` when
        unset).
        """
        source_payload = self.source.to_dict()
        platform_values = [p.value for p in self.connected_platforms]

        home_payload = {}
        for p, channel in self.home_channels.items():
            home_payload[p.value] = channel.to_dict()

        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None

        return {
            "source": source_payload,
            "connected_platforms": platform_values,
            "home_channels": home_payload,
            "session_key": self.session_key,
            "session_id": self.session_id,
            "created_at": created,
            "updated_at": updated,
        }
|
|
|
|
|
|
# Allow-list consulted by build_session_context_prompt(): redaction is applied
# only when the source platform is a member of this set.
_PII_SAFE_PLATFORMS = frozenset({
    Platform.WHATSAPP,
    Platform.SIGNAL,
    Platform.TELEGRAM,
})
"""Platforms where user IDs can be safely redacted (no in-message mention system
that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
and the LLM needs the real ID to tag users."""
|
|
|
|
|
|
def build_session_context_prompt(
    context: SessionContext,
    *,
    redact_pii: bool = False,
) -> str:
    """
    Render the "Current Session Context" section of the system prompt.

    The section tells the agent:
    - Where messages are coming from
    - What platforms are connected
    - Where it can deliver scheduled task outputs

    Redaction takes effect only when *redact_pii* is True **and** the source
    platform is in ``_PII_SAFE_PLATFORMS``.  In that case the user and chat
    IDs that would otherwise appear in the prompt are replaced with
    deterministic hashes.  Display names (``user_name`` / ``chat_name``) are
    emitted unchanged whenever they are present.  The ``SessionSource`` is
    never modified, so routing keeps using the original values.

    Returns the section as a single newline-joined string.
    """
    src = context.source
    is_local = src.platform == Platform.LOCAL
    # Only apply redaction on platforms where IDs aren't needed for mentions
    redact_pii = redact_pii and src.platform in _PII_SAFE_PLATFORMS
    platform_name = src.platform.value.title()

    out = ["## Current Session Context", ""]

    # --- Source line -------------------------------------------------------
    if is_local:
        out.append(f"**Source:** {platform_name} (the machine running this agent)")
    else:
        if redact_pii:
            # Build a description that never contains a raw ID.
            # NOTE(review): a user_name that is itself a phone number would
            # pass through unchanged here — confirm whether that is intended.
            if src.user_name:
                sender_label = src.user_name
            elif src.user_id:
                sender_label = _hash_sender_id(src.user_id)
            else:
                sender_label = "user"
            chat_label = src.chat_name or _hash_chat_id(src.chat_id)

            if src.chat_type == "dm":
                desc = f"DM with {sender_label}"
            elif src.chat_type in ("group", "channel"):
                desc = f"{src.chat_type}: {chat_label}"
            else:
                desc = chat_label
        else:
            desc = src.description
        out.append(f"**Source:** {platform_name} ({desc})")

    # --- Channel topic (context about the channel's purpose) ---------------
    if src.chat_topic:
        out.append(f"**Channel Topic:** {src.chat_topic}")

    # --- User identity ------------------------------------------------------
    # A non-DM source with a thread_id is treated as a shared thread: several
    # users feed one conversation, so pinning one user name here would change
    # per turn and bust the prompt cache.  Sender names are prefixed on each
    # user message by the gateway instead.
    shared_thread = src.chat_type != "dm" and src.thread_id
    if shared_thread:
        out.append(
            "**Session type:** Multi-user thread — messages are prefixed "
            "with [sender name]. Multiple users may participate."
        )
    elif src.user_name:
        out.append(f"**User:** {src.user_name}")
    elif src.user_id:
        shown_id = _hash_sender_id(src.user_id) if redact_pii else src.user_id
        out.append(f"**User ID:** {shown_id}")

    # --- Platform-specific behavioral notes ---------------------------------
    platform_note = None
    if src.platform == Platform.SLACK:
        platform_note = (
            "**Platform notes:** You are running inside Slack. "
            "You do NOT have access to Slack-specific APIs — you cannot search "
            "channel history, pin/unpin messages, manage channels, or list users. "
            "Do not promise to perform these actions. If the user asks, explain "
            "that you can only read messages sent directly to you and respond."
        )
    elif src.platform == Platform.DISCORD:
        platform_note = (
            "**Platform notes:** You are running inside Discord. "
            "You do NOT have access to Discord-specific APIs — you cannot search "
            "channel history, pin messages, manage roles, or list server members. "
            "Do not promise to perform these actions. If the user asks, explain "
            "that you can only read messages sent directly to you and respond."
        )
    if platform_note is not None:
        out.extend(["", platform_note])

    # --- Connected platforms -------------------------------------------------
    connected = ["local (files on this machine)"]
    connected.extend(
        f"{p.value}: Connected ✓"
        for p in context.connected_platforms
        if p != Platform.LOCAL
    )
    out.append(f"**Connected Platforms:** {', '.join(connected)}")

    # --- Home channels --------------------------------------------------------
    if context.home_channels:
        out.extend(["", "**Home Channels (default destinations):**"])
        for platform, home in context.home_channels.items():
            hc_id = _hash_chat_id(home.chat_id) if redact_pii else home.chat_id
            out.append(f" - {platform.value}: {home.name} (ID: {hc_id})")

    # --- Delivery options for scheduled tasks ---------------------------------
    out.extend(["", "**Delivery options for scheduled tasks:**"])

    # Origin delivery
    if is_local:
        out.append('- `"origin"` → Local output (saved to files)')
    else:
        origin_label = src.chat_name or (
            _hash_chat_id(src.chat_id) if redact_pii else src.chat_id
        )
        out.append(f'- `"origin"` → Back to this chat ({origin_label})')

    # Local always available
    out.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)')

    # Platform home channels
    for platform, home in context.home_channels.items():
        out.append(f'- `"{platform.value}"` → Home channel ({home.name})')

    # Note about explicit targeting
    out.append("")
    out.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*')

    return "\n".join(out)
|
|
|
|
|
|
@dataclass
class SessionEntry:
    """
    One row of the session index.

    Maps a session key to its current session ID and metadata.
    ``to_dict`` / ``from_dict`` round-trip everything except the three
    auto-reset fields, which are never serialized.
    """
    session_key: str
    session_id: str
    created_at: datetime
    updated_at: datetime

    # Origin metadata for delivery routing
    origin: Optional[SessionSource] = None

    # Display metadata
    display_name: Optional[str] = None
    platform: Optional[Platform] = None
    chat_type: str = "dm"

    # Token tracking
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    total_tokens: int = 0
    estimated_cost_usd: float = 0.0
    cost_status: str = "unknown"

    # Last API-reported prompt tokens (for accurate compression pre-check)
    last_prompt_tokens: int = 0

    # Set when a session was created because the previous one expired;
    # consumed once by the message handler to inject a notice into context
    was_auto_reset: bool = False
    auto_reset_reason: Optional[str] = None  # "idle" or "daily"
    reset_had_activity: bool = False  # whether the expired session had any messages

    # Set by the background expiry watcher after it successfully flushes
    # memories for this session.  Persisted to sessions.json so the flag
    # survives gateway restarts (the old in-memory _pre_flushed_sessions
    # set was lost on restart, causing redundant re-flushes).
    memory_flushed: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the persisted fields to a JSON-friendly dict.

        Timestamps become ISO-8601 strings, ``platform`` its enum value (or
        ``None``), and ``origin`` is included only when set.
        """
        payload: Dict[str, Any] = {
            "session_key": self.session_key,
            "session_id": self.session_id,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "display_name": self.display_name,
            "platform": self.platform.value if self.platform else None,
            "chat_type": self.chat_type,
        }
        # Plain scalar fields are copied verbatim, in the persisted key order.
        for name in (
            "input_tokens",
            "output_tokens",
            "cache_read_tokens",
            "cache_write_tokens",
            "total_tokens",
            "last_prompt_tokens",
            "estimated_cost_usd",
            "cost_status",
            "memory_flushed",
        ):
            payload[name] = getattr(self, name)
        if self.origin:
            payload["origin"] = self.origin.to_dict()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
        """Rebuild an entry from a ``to_dict``-style mapping.

        ``session_key``, ``session_id``, ``created_at`` and ``updated_at`` are
        required.  An unrecognized ``platform`` value is logged at debug level
        and left as ``None``; missing counters fall back to their defaults.
        """
        origin_data = data.get("origin")
        origin = SessionSource.from_dict(origin_data) if origin_data else None

        platform = None
        raw_platform = data.get("platform")
        if raw_platform:
            try:
                platform = Platform(raw_platform)
            except ValueError as e:
                logger.debug("Unknown platform value %r: %s", raw_platform, e)

        fallbacks = {
            "chat_type": "dm",
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_tokens": 0,
            "cache_write_tokens": 0,
            "total_tokens": 0,
            "last_prompt_tokens": 0,
            "estimated_cost_usd": 0.0,
            "cost_status": "unknown",
            "memory_flushed": False,
        }
        optional = {name: data.get(name, default) for name, default in fallbacks.items()}

        return cls(
            session_key=data["session_key"],
            session_id=data["session_id"],
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            origin=origin,
            display_name=data.get("display_name"),
            platform=platform,
            **optional,
        )
|
|
|
|
|
|
def build_session_key(
    source: SessionSource,
    group_sessions_per_user: bool = True,
    thread_sessions_per_user: bool = False,
) -> str:
    """Build a deterministic session key from a message source.

    This is the single source of truth for session key construction.

    DM rules:
    - DMs include chat_id when present, so each private conversation is isolated.
    - thread_id further differentiates threaded DMs within the same DM chat.
    - Without chat_id, thread_id is used as a best-effort fallback.
    - Without thread_id or chat_id, DMs share a single session.

    Group/channel rules:
    - chat_id identifies the parent group/channel.
    - user_id_alt (preferred) or user_id isolates participants within that
      parent chat when one is available and ``group_sessions_per_user`` is
      enabled.
    - thread_id differentiates threads within that parent chat.  When
      ``thread_sessions_per_user`` is False (default), threads are *shared*
      across all participants — the participant ID is NOT appended, so every
      user in the thread shares a single session.  This is the expected UX
      for threaded conversations (Telegram forum topics, Discord threads,
      Slack threads).
    - Without a participant identifier, or when isolation is disabled,
      messages fall back to one shared session per chat (or per thread).
    - Without any identifiers, messages fall back to one session per
      platform/chat_type.
    """
    platform = source.platform.value

    if source.chat_type == "dm":
        # DM keys never carry a participant ID: chat_id and/or thread_id
        # (whichever are set, in that order) are appended to the DM prefix.
        suffix = "".join(
            f":{part}" for part in (source.chat_id, source.thread_id) if part
        )
        return f"agent:main:{platform}:dm{suffix}"

    segments = ["agent:main", platform, source.chat_type]
    for identifier in (source.chat_id, source.thread_id):
        if identifier:
            segments.append(identifier)

    # Per-user isolation needs group_sessions_per_user; inside a thread it
    # additionally needs thread_sessions_per_user, because threads default to
    # one shared session for all participants.
    isolate_user = group_sessions_per_user and (
        thread_sessions_per_user or not source.thread_id
    )
    participant_id = source.user_id_alt or source.user_id
    if isolate_user and participant_id:
        segments.append(str(participant_id))

    return ":".join(segments)
|
|
|
|
|
|
class SessionStore:
|
|
"""
|
|
Manages session storage and retrieval.
|
|
|
|
Uses SQLite (via SessionDB) for session metadata and message transcripts.
|
|
Falls back to legacy JSONL files if SQLite is unavailable.
|
|
"""
|
|
|
|
def __init__(self, sessions_dir: Path, config: GatewayConfig,
|
|
has_active_processes_fn=None,
|
|
on_auto_reset=None):
|
|
self.sessions_dir = sessions_dir
|
|
self.config = config
|
|
self._entries: Dict[str, SessionEntry] = {}
|
|
self._loaded = False
|
|
self._lock = threading.Lock()
|
|
self._has_active_processes_fn = has_active_processes_fn
|
|
|
|
# Initialize SQLite session database
|
|
self._db = None
|
|
try:
|
|
from hermes_state import SessionDB
|
|
self._db = SessionDB()
|
|
except Exception as e:
|
|
print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
|
|
|
|
def _ensure_loaded(self) -> None:
|
|
"""Load sessions index from disk if not already loaded."""
|
|
with self._lock:
|
|
self._ensure_loaded_locked()
|
|
|
|
def _ensure_loaded_locked(self) -> None:
|
|
"""Load sessions index from disk. Must be called with self._lock held."""
|
|
if self._loaded:
|
|
return
|
|
|
|
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
|
sessions_file = self.sessions_dir / "sessions.json"
|
|
|
|
if sessions_file.exists():
|
|
try:
|
|
with open(sessions_file, "r", encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
for key, entry_data in data.items():
|
|
try:
|
|
self._entries[key] = SessionEntry.from_dict(entry_data)
|
|
except (ValueError, KeyError):
|
|
# Skip entries with unknown/removed platform values
|
|
continue
|
|
except Exception as e:
|
|
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
|
|
|
self._loaded = True
|
|
|
|
def _save(self) -> None:
|
|
"""Save sessions index to disk (kept for session key -> ID mapping)."""
|
|
import tempfile
|
|
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
|
sessions_file = self.sessions_dir / "sessions.json"
|
|
|
|
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
|
fd, tmp_path = tempfile.mkstemp(
|
|
dir=str(self.sessions_dir), suffix=".tmp", prefix=".sessions_"
|
|
)
|
|
try:
|
|
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
|
json.dump(data, f, indent=2)
|
|
f.flush()
|
|
os.fsync(f.fileno())
|
|
os.replace(tmp_path, sessions_file)
|
|
except BaseException:
|
|
try:
|
|
os.unlink(tmp_path)
|
|
except OSError as e:
|
|
logger.debug("Could not remove temp file %s: %s", tmp_path, e)
|
|
raise
|
|
|
|
def _generate_session_key(self, source: SessionSource) -> str:
    """Derive the session key for *source* using this store's isolation flags.

    When the config object lacks them, ``group_sessions_per_user`` defaults
    to True and ``thread_sessions_per_user`` to False.
    """
    per_user_groups = getattr(self.config, "group_sessions_per_user", True)
    per_user_threads = getattr(self.config, "thread_sessions_per_user", False)
    return build_session_key(
        source,
        group_sessions_per_user=per_user_groups,
        thread_sessions_per_user=per_user_threads,
    )
|
|
|
|
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
|
"""Check if a session has expired based on its reset policy.
|
|
|
|
Works from the entry alone — no SessionSource needed.
|
|
Used by the background expiry watcher to proactively flush memories.
|
|
Sessions with active background processes are never considered expired.
|
|
"""
|
|
if self._has_active_processes_fn:
|
|
if self._has_active_processes_fn(entry.session_key):
|
|
return False
|
|
|
|
policy = self.config.get_reset_policy(
|
|
platform=entry.platform,
|
|
session_type=entry.chat_type,
|
|
)
|
|
|
|
if policy.mode == "none":
|
|
return False
|
|
|
|
now = _now()
|
|
|
|
if policy.mode in ("idle", "both"):
|
|
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
|
if now > idle_deadline:
|
|
return True
|
|
|
|
if policy.mode in ("daily", "both"):
|
|
today_reset = now.replace(
|
|
hour=policy.at_hour,
|
|
minute=0, second=0, microsecond=0,
|
|
)
|
|
if now.hour < policy.at_hour:
|
|
today_reset -= timedelta(days=1)
|
|
if entry.updated_at < today_reset:
|
|
return True
|
|
|
|
return False
|
|
|
|
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> Optional[str]:
|
|
"""
|
|
Check if a session should be reset based on policy.
|
|
|
|
Returns the reset reason ("idle" or "daily") if a reset is needed,
|
|
or None if the session is still valid.
|
|
|
|
Sessions with active background processes are never reset.
|
|
"""
|
|
if self._has_active_processes_fn:
|
|
session_key = self._generate_session_key(source)
|
|
if self._has_active_processes_fn(session_key):
|
|
return None
|
|
|
|
policy = self.config.get_reset_policy(
|
|
platform=source.platform,
|
|
session_type=source.chat_type
|
|
)
|
|
|
|
if policy.mode == "none":
|
|
return None
|
|
|
|
now = _now()
|
|
|
|
if policy.mode in ("idle", "both"):
|
|
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
|
if now > idle_deadline:
|
|
return "idle"
|
|
|
|
if policy.mode in ("daily", "both"):
|
|
today_reset = now.replace(
|
|
hour=policy.at_hour,
|
|
minute=0,
|
|
second=0,
|
|
microsecond=0
|
|
)
|
|
if now.hour < policy.at_hour:
|
|
today_reset -= timedelta(days=1)
|
|
|
|
if entry.updated_at < today_reset:
|
|
return "daily"
|
|
|
|
return None
|
|
|
|
def has_any_sessions(self) -> bool:
|
|
"""Check if any sessions have ever been created (across all platforms).
|
|
|
|
Uses the SQLite database as the source of truth because it preserves
|
|
historical session records (ended sessions still count). The in-memory
|
|
``_entries`` dict replaces entries on reset, so ``len(_entries)`` would
|
|
stay at 1 for single-platform users — which is the bug this fixes.
|
|
|
|
The current session is already in the DB by the time this is called
|
|
(get_or_create_session runs first), so we check ``> 1``.
|
|
"""
|
|
if self._db:
|
|
try:
|
|
return self._db.session_count() > 1
|
|
except Exception:
|
|
pass # fall through to heuristic
|
|
# Fallback: check if sessions.json was loaded with existing data.
|
|
# This covers the rare case where the DB is unavailable.
|
|
with self._lock:
|
|
self._ensure_loaded_locked()
|
|
return len(self._entries) > 1
|
|
|
|
def get_or_create_session(
    self,
    source: SessionSource,
    force_new: bool = False
) -> SessionEntry:
    """
    Get an existing session or create a new one.

    Evaluates reset policy to determine if the existing session is stale.
    Creates a session record in SQLite when a new session starts.

    Args:
        source: Where the message came from; determines the session key.
        force_new: When True, an existing entry for the key is replaced
            without consulting the reset policy.

    Returns:
        The live entry for the key: the existing one with ``updated_at``
        refreshed, or a newly created one.  A new entry created because the
        policy expired the old one carries ``was_auto_reset`` and the reason.
    """
    session_key = self._generate_session_key(source)
    now = _now()

    # SQLite calls are made outside the lock to avoid holding it during I/O.
    # All _entries / _loaded mutations are protected by self._lock.
    db_end_session_id = None
    db_create_kwargs = None

    with self._lock:
        self._ensure_loaded_locked()

        if session_key in self._entries and not force_new:
            entry = self._entries[session_key]

            reset_reason = self._should_reset(entry, source)
            if not reset_reason:
                # Still valid: touch the timestamp, persist, and reuse it.
                entry.updated_at = now
                self._save()
                return entry
            else:
                # Session is being auto-reset.
                was_auto_reset = True
                auto_reset_reason = reset_reason
                # Track whether the expired session had any real conversation
                reset_had_activity = entry.total_tokens > 0
                db_end_session_id = entry.session_id
        else:
            # No entry yet, or the caller forced a new one.
            # NOTE(review): on the force_new path db_end_session_id stays
            # None, so a replaced session is not ended in SQLite here —
            # confirm callers end it themselves.
            was_auto_reset = False
            auto_reset_reason = None
            reset_had_activity = False

        # Create new session
        session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

        entry = SessionEntry(
            session_key=session_key,
            session_id=session_id,
            created_at=now,
            updated_at=now,
            origin=source,
            display_name=source.chat_name,
            platform=source.platform,
            chat_type=source.chat_type,
            was_auto_reset=was_auto_reset,
            auto_reset_reason=auto_reset_reason,
            reset_had_activity=reset_had_activity,
        )

        self._entries[session_key] = entry
        self._save()
        # Captured here, executed after the lock is released.
        db_create_kwargs = {
            "session_id": session_id,
            "source": source.platform.value,
            "user_id": source.user_id,
        }

    # SQLite operations outside the lock
    if self._db and db_end_session_id:
        try:
            self._db.end_session(db_end_session_id, "session_reset")
        except Exception as e:
            logger.debug("Session DB operation failed: %s", e)

    if self._db and db_create_kwargs:
        try:
            self._db.create_session(**db_create_kwargs)
        except Exception as e:
            print(f"[gateway] Warning: Failed to create SQLite session: {e}")

    # Seed new DM thread sessions with parent DM session history.
    # When a bot reply creates a Slack thread and the user responds in it,
    # the thread gets a new session (keyed by thread_ts).  Without seeding,
    # the thread session starts with zero context — the user's original
    # question and the bot's answer are invisible.  Fix: copy the parent
    # DM session's transcript into the new thread session so context carries
    # over while still keeping threads isolated from each other.
    if (
        source.chat_type == "dm"
        and source.thread_id
        and entry.created_at == entry.updated_at  # brand-new session
        and not was_auto_reset
    ):
        parent_source = SessionSource(
            platform=source.platform,
            chat_id=source.chat_id,
            chat_type="dm",
            user_id=source.user_id,
            # no thread_id — this is the parent DM session
        )
        parent_key = self._generate_session_key(parent_source)
        # The lock is held only for the index lookup, not for transcript I/O.
        with self._lock:
            parent_entry = self._entries.get(parent_key)
        if parent_entry and parent_entry.session_id != entry.session_id:
            try:
                parent_history = self.load_transcript(parent_entry.session_id)
                if parent_history:
                    self.rewrite_transcript(entry.session_id, parent_history)
                    logger.info(
                        "[Session] Seeded DM thread session %s with %d messages from parent %s",
                        entry.session_id, len(parent_history), parent_entry.session_id,
                    )
            except Exception as e:
                # Seeding is best-effort; the new session is still returned.
                logger.warning("[Session] Failed to seed thread session: %s", e)

    return entry
|
|
|
|
def update_session(
|
|
self,
|
|
session_key: str,
|
|
last_prompt_tokens: int = None,
|
|
) -> None:
|
|
"""Update lightweight session metadata after an interaction."""
|
|
with self._lock:
|
|
self._ensure_loaded_locked()
|
|
|
|
if session_key in self._entries:
|
|
entry = self._entries[session_key]
|
|
entry.updated_at = _now()
|
|
if last_prompt_tokens is not None:
|
|
entry.last_prompt_tokens = last_prompt_tokens
|
|
self._save()
|
|
|
|
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
|
"""Force reset a session, creating a new session ID."""
|
|
db_end_session_id = None
|
|
db_create_kwargs = None
|
|
new_entry = None
|
|
|
|
with self._lock:
|
|
self._ensure_loaded_locked()
|
|
|
|
if session_key not in self._entries:
|
|
return None
|
|
|
|
old_entry = self._entries[session_key]
|
|
db_end_session_id = old_entry.session_id
|
|
|
|
now = _now()
|
|
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
|
|
|
new_entry = SessionEntry(
|
|
session_key=session_key,
|
|
session_id=session_id,
|
|
created_at=now,
|
|
updated_at=now,
|
|
origin=old_entry.origin,
|
|
display_name=old_entry.display_name,
|
|
platform=old_entry.platform,
|
|
chat_type=old_entry.chat_type,
|
|
)
|
|
|
|
self._entries[session_key] = new_entry
|
|
self._save()
|
|
db_create_kwargs = {
|
|
"session_id": session_id,
|
|
"source": old_entry.platform.value if old_entry.platform else "unknown",
|
|
"user_id": old_entry.origin.user_id if old_entry.origin else None,
|
|
}
|
|
|
|
if self._db and db_end_session_id:
|
|
try:
|
|
self._db.end_session(db_end_session_id, "session_reset")
|
|
except Exception as e:
|
|
logger.debug("Session DB operation failed: %s", e)
|
|
|
|
if self._db and db_create_kwargs:
|
|
try:
|
|
self._db.create_session(**db_create_kwargs)
|
|
except Exception as e:
|
|
logger.debug("Session DB operation failed: %s", e)
|
|
|
|
return new_entry
|
|
|
|
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
|
|
"""Switch a session key to point at an existing session ID.
|
|
|
|
Used by ``/resume`` to restore a previously-named session.
|
|
Ends the current session in SQLite (like reset), but instead of
|
|
generating a fresh session ID, re-uses ``target_session_id`` so the
|
|
old transcript is loaded on the next message.
|
|
"""
|
|
db_end_session_id = None
|
|
new_entry = None
|
|
|
|
with self._lock:
|
|
self._ensure_loaded_locked()
|
|
|
|
if session_key not in self._entries:
|
|
return None
|
|
|
|
old_entry = self._entries[session_key]
|
|
|
|
# Don't switch if already on that session
|
|
if old_entry.session_id == target_session_id:
|
|
return old_entry
|
|
|
|
db_end_session_id = old_entry.session_id
|
|
|
|
now = _now()
|
|
new_entry = SessionEntry(
|
|
session_key=session_key,
|
|
session_id=target_session_id,
|
|
created_at=now,
|
|
updated_at=now,
|
|
origin=old_entry.origin,
|
|
display_name=old_entry.display_name,
|
|
platform=old_entry.platform,
|
|
chat_type=old_entry.chat_type,
|
|
)
|
|
|
|
self._entries[session_key] = new_entry
|
|
self._save()
|
|
|
|
if self._db and db_end_session_id:
|
|
try:
|
|
self._db.end_session(db_end_session_id, "session_switch")
|
|
except Exception as e:
|
|
logger.debug("Session DB end_session failed: %s", e)
|
|
|
|
return new_entry
|
|
|
|
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
    """Return tracked sessions ordered from most to least recently updated.

    Args:
        active_minutes: When given, only sessions whose ``updated_at`` is
            within the last ``active_minutes`` minutes are returned.
    """
    # Copy the entries while holding the lock; filtering and sorting then
    # operate on the private snapshot.
    with self._lock:
        self._ensure_loaded_locked()
        snapshot = list(self._entries.values())

    if active_minutes is not None:
        oldest_allowed = _now() - timedelta(minutes=active_minutes)
        snapshot = [
            entry for entry in snapshot if entry.updated_at >= oldest_allowed
        ]

    return sorted(snapshot, key=lambda entry: entry.updated_at, reverse=True)
|
|
|
|
def get_transcript_path(self, session_id: str) -> Path:
    """Return the location of the legacy JSONL transcript for ``session_id``.

    The result is ``<sessions_dir>/<session_id>.jsonl``. The file itself is
    neither created nor checked for existence here.
    """
    filename = f"{session_id}.jsonl"
    return self.sessions_dir / filename
|
|
|
|
def append_to_transcript(self, session_id: str, message: Dict[str, Any], skip_db: bool = False) -> None:
    """Append a message to a session's transcript (SQLite + legacy JSONL).

    Args:
        session_id: Session whose transcript receives the message.
        message: The message dict. It is written verbatim as one JSON line
            to the JSONL file; selected fields are forwarded to SQLite.
        skip_db: When True, only write to JSONL and skip the SQLite write.
            Used when the agent already persisted messages to SQLite
            via its own _flush_messages_to_session_db(), preventing
            the duplicate-write bug (#860).

    Raises:
        TypeError / ValueError: If ``message`` cannot be JSON-serialized.
        OSError: If the JSONL file cannot be opened or written.
    """
    # Write to SQLite (unless the agent already handled it). Best-effort:
    # a DB failure is logged and the JSONL write still happens.
    if self._db and not skip_db:
        try:
            role = message.get("role", "unknown")
            is_assistant = role == "assistant"
            self._db.append_message(
                session_id=session_id,
                role=role,
                content=message.get("content"),
                tool_name=message.get("tool_name"),
                tool_calls=message.get("tool_calls"),
                tool_call_id=message.get("tool_call_id"),
                # Forward reasoning fields for assistant turns, exactly as
                # rewrite_transcript() does, so the SQLite copy of an
                # appended message does not lose them.
                reasoning=message.get("reasoning") if is_assistant else None,
                reasoning_details=message.get("reasoning_details") if is_assistant else None,
                codex_reasoning_items=message.get("codex_reasoning_items") if is_assistant else None,
            )
        except Exception as e:
            logger.debug("Session DB operation failed: %s", e)

    # Also write legacy JSONL (keeps existing tooling working during transition).
    # Serialize before opening so a non-serializable message does not leave
    # behind a freshly created, empty transcript file.
    line = json.dumps(message, ensure_ascii=False) + "\n"
    transcript_path = self.get_transcript_path(session_id)
    with open(transcript_path, "a", encoding="utf-8") as f:
        f.write(line)
|
|
|
|
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
    """Replace the entire transcript for a session with new messages.

    Used by /retry, /undo, and /compress to persist modified conversation history.
    Rewrites both SQLite and legacy JSONL storage.

    Args:
        session_id: Session whose stored history is replaced.
        messages: The complete new history, in order.

    Raises:
        TypeError / ValueError: If any message cannot be JSON-serialized.
            This is raised before either store is modified, so the existing
            transcript is left intact.
        OSError: If the JSONL file cannot be opened or written.
    """
    # Materialize once so a one-shot iterable is not exhausted by the first
    # pass, then serialize everything *before* touching storage. Opening
    # the file with "w" truncates it immediately; serializing lazily inside
    # the ``with`` block would destroy the old transcript if a later
    # message turned out not to be serializable.
    messages = list(messages)
    payload = "".join(
        json.dumps(msg, ensure_ascii=False) + "\n" for msg in messages
    )

    # SQLite: clear old messages and re-insert. Best-effort: a failure is
    # logged and the JSONL rewrite still proceeds.
    if self._db:
        try:
            self._db.clear_messages(session_id)
            for msg in messages:
                role = msg.get("role", "unknown")
                is_assistant = role == "assistant"
                self._db.append_message(
                    session_id=session_id,
                    role=role,
                    content=msg.get("content"),
                    tool_name=msg.get("tool_name"),
                    tool_calls=msg.get("tool_calls"),
                    tool_call_id=msg.get("tool_call_id"),
                    # Reasoning fields are only forwarded for assistant turns.
                    reasoning=msg.get("reasoning") if is_assistant else None,
                    reasoning_details=msg.get("reasoning_details") if is_assistant else None,
                    codex_reasoning_items=msg.get("codex_reasoning_items") if is_assistant else None,
                )
        except Exception as e:
            logger.debug("Failed to rewrite transcript in DB: %s", e)

    # JSONL: overwrite the file with the pre-serialized payload.
    transcript_path = self.get_transcript_path(session_id)
    with open(transcript_path, "w", encoding="utf-8") as f:
        f.write(payload)
|
|
|
|
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
    """Return the stored messages for ``session_id``.

    Both stores are consulted: SQLite (when a DB is configured) and the
    legacy JSONL file (when it exists). Whichever holds more messages is
    returned; on a tie the SQLite result is used.
    """
    # --- SQLite -----------------------------------------------------------
    from_db = []
    if self._db:
        try:
            from_db = self._db.get_messages_as_conversation(session_id)
        except Exception as exc:
            logger.debug("Could not load messages from DB: %s", exc)

    # --- Legacy JSONL -----------------------------------------------------
    # May hold more history than SQLite for sessions created before the DB
    # layer was introduced.
    from_jsonl = []
    jsonl_path = self.get_transcript_path(session_id)
    if jsonl_path.exists():
        with open(jsonl_path, "r", encoding="utf-8") as handle:
            for raw in handle:
                record = raw.strip()
                if not record:
                    continue
                try:
                    from_jsonl.append(json.loads(record))
                except json.JSONDecodeError:
                    # A damaged line is skipped rather than failing the load.
                    logger.warning(
                        "Skipping corrupt line in transcript %s: %s",
                        session_id, record[:120],
                    )

    # --- Pick the longer source -------------------------------------------
    # Why: for a session that pre-dates SQLite storage (or was already
    # active when the DB layer was added), the first post-migration turn
    # writes only the *new* messages to SQLite, because
    # _flush_messages_to_session_db skips messages already present in
    # conversation_history on the assumption they are persisted. If the DB
    # were always preferred, the next load would return those few rows and
    # ignore the full JSONL history, leaving the model with a handful of
    # messages instead of hundreds. Taking the longer source avoids that
    # silent truncation.
    if len(from_jsonl) <= len(from_db):
        return from_db

    if from_db:
        logger.debug(
            "Session %s: JSONL has %d messages vs SQLite %d — "
            "using JSONL (legacy session not yet fully migrated)",
            session_id, len(from_jsonl), len(from_db),
        )
    return from_jsonl
|
|
|
|
|
|
def build_session_context(
    source: SessionSource,
    config: GatewayConfig,
    session_entry: Optional[SessionEntry] = None
) -> SessionContext:
    """
    Assemble the session context for a message source.

    The result is used to inject context into the agent's system prompt.
    It carries the source, the platforms the config reports as connected,
    and the home channel of each connected platform that has one. When
    ``session_entry`` is supplied, its key, ID and timestamps are copied
    onto the context as well.
    """
    connected = config.get_connected_platforms()

    # Look up each platform's home channel once; platforms without one are
    # left out of the mapping.
    channel_lookups = (
        (platform, config.get_home_channel(platform)) for platform in connected
    )
    home_channels = {
        platform: channel for platform, channel in channel_lookups if channel
    }

    context = SessionContext(
        source=source,
        connected_platforms=connected,
        home_channels=home_channels,
    )

    if not session_entry:
        return context

    context.session_key = session_entry.session_key
    context.session_id = session_entry.session_id
    context.created_at = session_entry.created_at
    context.updated_at = session_entry.updated_at
    return context
|