Wrap json.loads() in load_transcript() with try/except JSONDecodeError so that partial JSONL lines (from mid-write crashes like OOM/SIGKILL) are skipped with a warning instead of crashing the entire transcript load. The rest of the history loads fine. Adds a logger.warning with the session ID and truncated corrupt line content for debugging visibility. Salvaged from PR #1193 by alireza78a. Closes #1193
989 lines
36 KiB
Python
989 lines
36 KiB
Python
"""
|
|
Session management for the gateway.
|
|
|
|
Handles:
|
|
- Session context tracking (where messages come from)
|
|
- Session storage (conversations persisted to disk)
|
|
- Reset policy evaluation (when to start fresh)
|
|
- Dynamic system prompt injection (agent knows its context)
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import json
|
|
import re
|
|
import uuid
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Any
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PII redaction helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Matches a whole string that looks like a phone number: an optional leading
# "+", one digit, then at least six more digits, hyphens, or whitespace.
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
|
|
|
|
|
def _hash_id(value: str) -> str:
|
|
"""Deterministic 12-char hex hash of an identifier."""
|
|
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
|
|
|
|
|
def _hash_sender_id(value: str) -> str:
    """Return ``user_<12hex>`` where the hex part is ``_hash_id(value)``."""
    return "user_" + _hash_id(value)
|
|
|
|
|
|
def _hash_chat_id(value: str) -> str:
    """Hash a chat ID while keeping any platform prefix readable.

    ``telegram:12345`` → ``telegram:<hash>``
    ``12345`` → ``<hash>``

    Only the text after the first colon is hashed, and only when there is a
    non-empty prefix in front of that colon; otherwise the whole value is
    hashed.
    """
    prefix, sep, remainder = value.partition(":")
    if sep and prefix:
        return f"{prefix}:{_hash_id(remainder)}"
    return _hash_id(value)
|
|
|
|
|
|
def _looks_like_phone(value: str) -> bool:
    """Tell whether *value*, ignoring surrounding whitespace, matches ``_PHONE_RE``."""
    candidate = value.strip()
    return _PHONE_RE.match(candidate) is not None
|
|
|
|
from .config import (
|
|
Platform,
|
|
GatewayConfig,
|
|
SessionResetPolicy,
|
|
HomeChannel,
|
|
)
|
|
|
|
|
|
@dataclass
class SessionSource:
    """
    Origin of an inbound message: platform, chat, user and optional thread.

    This information is used to:
    1. Route responses back to the right place
    2. Inject context into the system prompt
    3. Track origin for cron job delivery
    """
    platform: Platform
    chat_id: str
    chat_name: Optional[str] = None
    chat_type: str = "dm"  # "dm", "group", "channel", "thread"
    user_id: Optional[str] = None
    user_name: Optional[str] = None
    thread_id: Optional[str] = None  # For forum topics, Discord threads, etc.
    chat_topic: Optional[str] = None  # Channel topic/description (Discord, Slack)
    user_id_alt: Optional[str] = None  # Signal UUID (alternative to phone number)
    chat_id_alt: Optional[str] = None  # Signal group internal ID

    @property
    def description(self) -> str:
        """Short human-readable label for this source.

        The local platform is always ``"CLI terminal"``.  Otherwise the label
        depends on ``chat_type`` and gets a ``thread: <id>`` suffix when a
        thread ID is set.
        """
        if self.platform == Platform.LOCAL:
            return "CLI terminal"

        kind = self.chat_type
        if kind == "dm":
            head = f"DM with {self.user_name or self.user_id or 'user'}"
        elif kind in ("group", "channel"):
            head = f"{kind}: {self.chat_name or self.chat_id}"
        else:
            head = self.chat_name or self.chat_id

        segments = [head]
        if self.thread_id:
            segments.append(f"thread: {self.thread_id}")
        return ", ".join(segments)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        The two ``*_alt`` identifiers are only included when they are set.
        """
        always_present = (
            "chat_id",
            "chat_name",
            "chat_type",
            "user_id",
            "user_name",
            "thread_id",
            "chat_topic",
        )
        payload: Dict[str, Any] = {"platform": self.platform.value}
        for name in always_present:
            payload[name] = getattr(self, name)
        for name in ("user_id_alt", "chat_id_alt"):
            alt_value = getattr(self, name)
            if alt_value:
                payload[name] = alt_value
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
        """Rebuild a source from a dict shaped like the output of ``to_dict``.

        ``platform`` and ``chat_id`` are required keys; ``chat_id`` is coerced
        to ``str`` and ``chat_type`` defaults to ``"dm"``.
        """
        nullable = (
            "chat_name",
            "user_id",
            "user_name",
            "thread_id",
            "chat_topic",
            "user_id_alt",
            "chat_id_alt",
        )
        extras = {name: data.get(name) for name in nullable}
        return cls(
            platform=Platform(data["platform"]),
            chat_id=str(data["chat_id"]),
            chat_type=data.get("chat_type", "dm"),
            **extras,
        )

    @classmethod
    def local_cli(cls) -> "SessionSource":
        """Build the source that stands for the local CLI."""
        return cls(
            chat_type="dm",
            chat_name="CLI terminal",
            chat_id="cli",
            platform=Platform.LOCAL,
        )
|
|
|
|
|
|
@dataclass
class SessionContext:
    """
    Everything the agent is told about the current session.

    Used for dynamic system prompt injection so the agent understands:
    - Where messages are coming from
    - What platforms are available
    - Where it can deliver scheduled task outputs
    """
    source: SessionSource
    connected_platforms: List[Platform]
    home_channels: Dict[Platform, HomeChannel]

    # Session metadata
    session_key: str = ""
    session_id: str = ""
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        Platforms are written as their ``.value``; timestamps are ISO strings
        or ``None`` when unset.
        """
        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        source_payload = self.source.to_dict()
        platform_names = [platform.value for platform in self.connected_platforms]
        homes = {}
        for platform, channel in self.home_channels.items():
            homes[platform.value] = channel.to_dict()

        return {
            "source": source_payload,
            "connected_platforms": platform_names,
            "home_channels": homes,
            "session_key": self.session_key,
            "session_id": self.session_id,
            "created_at": _iso(self.created_at),
            "updated_at": _iso(self.updated_at),
        }
|
|
|
|
|
|
# Platforms where user IDs can be safely redacted (no in-message mention system
# that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
# and the LLM needs the real ID to tag users.
_PII_SAFE_PLATFORMS = frozenset(
    (Platform.WHATSAPP, Platform.SIGNAL, Platform.TELEGRAM)
)
|
|
|
|
|
|
def build_session_context_prompt(
    context: SessionContext,
    *,
    redact_pii: bool = False,
) -> str:
    """
    Render the "Current Session Context" section of the system prompt.

    The section tells the agent:
    - Where messages are coming from
    - What platforms are connected
    - Where it can deliver scheduled task outputs

    Redaction is only applied when *redact_pii* is True **and** the source
    platform is in ``_PII_SAFE_PLATFORMS``.  In that case user and chat IDs
    that would otherwise be printed are replaced with deterministic hashes;
    display names (``user_name``, ``chat_name``) are still shown as-is.
    The ``SessionSource`` itself is not modified.

    Returns:
        The prompt section as one newline-joined string.
    """
    origin = context.source
    is_local = origin.platform == Platform.LOCAL
    # Only apply redaction on platforms where IDs aren't needed for mentions
    redact_pii = redact_pii and origin.platform in _PII_SAFE_PLATFORMS

    out = ["## Current Session Context", ""]

    # --- Source info ---------------------------------------------------
    platform_name = origin.platform.value.title()
    if is_local:
        out.append(f"**Source:** {platform_name} (the machine running this agent)")
    else:
        if redact_pii:
            # Same wording as SessionSource.description, minus raw IDs.
            masked_user = origin.user_name or (
                _hash_sender_id(origin.user_id) if origin.user_id else "user"
            )
            masked_chat = origin.chat_name or _hash_chat_id(origin.chat_id)
            if origin.chat_type == "dm":
                desc = f"DM with {masked_user}"
            elif origin.chat_type in ("group", "channel"):
                desc = f"{origin.chat_type}: {masked_chat}"
            else:
                desc = masked_chat
        else:
            desc = origin.description
        out.append(f"**Source:** {platform_name} ({desc})")

    # --- Channel topic (context about the channel's purpose) -----------
    if origin.chat_topic:
        out.append(f"**Channel Topic:** {origin.chat_topic}")

    # --- User identity: name wins over ID ------------------------------
    if origin.user_name:
        out.append(f"**User:** {origin.user_name}")
    elif origin.user_id:
        shown_id = _hash_sender_id(origin.user_id) if redact_pii else origin.user_id
        out.append(f"**User ID:** {shown_id}")

    # --- Platform-specific behavioral notes ----------------------------
    platform_note = None
    if origin.platform == Platform.SLACK:
        platform_note = (
            "**Platform notes:** You are running inside Slack. "
            "You do NOT have access to Slack-specific APIs — you cannot search "
            "channel history, pin/unpin messages, manage channels, or list users. "
            "Do not promise to perform these actions. If the user asks, explain "
            "that you can only read messages sent directly to you and respond."
        )
    elif origin.platform == Platform.DISCORD:
        platform_note = (
            "**Platform notes:** You are running inside Discord. "
            "You do NOT have access to Discord-specific APIs — you cannot search "
            "channel history, pin messages, manage roles, or list server members. "
            "Do not promise to perform these actions. If the user asks, explain "
            "that you can only read messages sent directly to you and respond."
        )
    if platform_note is not None:
        out.extend(["", platform_note])

    # --- Connected platforms (local is always listed first) ------------
    platforms_list = ["local (files on this machine)"]
    platforms_list.extend(
        f"{p.value}: Connected ✓"
        for p in context.connected_platforms
        if p != Platform.LOCAL
    )
    out.append(f"**Connected Platforms:** {', '.join(platforms_list)}")

    # --- Home channels -------------------------------------------------
    if context.home_channels:
        out.append("")
        out.append("**Home Channels (default destinations):**")
        for platform, home in context.home_channels.items():
            if redact_pii:
                hc_id = _hash_chat_id(home.chat_id)
            else:
                hc_id = home.chat_id
            out.append(f" - {platform.value}: {home.name} (ID: {hc_id})")

    # --- Delivery options for scheduled tasks --------------------------
    out.append("")
    out.append("**Delivery options for scheduled tasks:**")

    # Origin delivery
    if is_local:
        out.append("- `\"origin\"` → Local output (saved to files)")
    else:
        _origin_label = origin.chat_name or (
            _hash_chat_id(origin.chat_id) if redact_pii else origin.chat_id
        )
        out.append(f"- `\"origin\"` → Back to this chat ({_origin_label})")

    # Local always available
    out.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")

    # Platform home channels
    for platform, home in context.home_channels.items():
        out.append(f"- `\"{platform.value}\"` → Home channel ({home.name})")

    # Note about explicit targeting
    out.append("")
    out.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*")

    return "\n".join(out)
|
|
|
|
|
|
@dataclass
class SessionEntry:
    """
    One row of the session index.

    Maps a session key to its current session ID and metadata.
    """
    session_key: str
    session_id: str
    created_at: datetime
    updated_at: datetime

    # Origin metadata for delivery routing
    origin: Optional[SessionSource] = None

    # Display metadata
    display_name: Optional[str] = None
    platform: Optional[Platform] = None
    chat_type: str = "dm"

    # Token tracking
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    total_tokens: int = 0
    estimated_cost_usd: float = 0.0
    cost_status: str = "unknown"

    # Last API-reported prompt tokens (for accurate compression pre-check)
    last_prompt_tokens: int = 0

    # Set when a session was created because the previous one expired;
    # consumed once by the message handler to inject a notice into context
    was_auto_reset: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        ``origin`` is only included when set; ``was_auto_reset`` is not
        written out.
        """
        platform_value = self.platform.value if self.platform else None
        payload = {
            "session_key": self.session_key,
            "session_id": self.session_id,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "display_name": self.display_name,
            "platform": platform_value,
            "chat_type": self.chat_type,
        }
        for counter in (
            "input_tokens",
            "output_tokens",
            "cache_read_tokens",
            "cache_write_tokens",
            "total_tokens",
            "last_prompt_tokens",
            "estimated_cost_usd",
            "cost_status",
        ):
            payload[counter] = getattr(self, counter)
        if self.origin:
            payload["origin"] = self.origin.to_dict()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
        """Rebuild an entry from a dict shaped like the output of ``to_dict``.

        ``session_key``, ``session_id``, ``created_at`` and ``updated_at`` are
        required.  An unrecognised ``platform`` value is logged at debug level
        and stored as ``None``; missing counters fall back to their defaults.
        """
        origin_data = data.get("origin")
        origin = SessionSource.from_dict(origin_data) if origin_data else None

        platform = None
        raw_platform = data.get("platform")
        if raw_platform:
            try:
                platform = Platform(data["platform"])
            except ValueError as e:
                logger.debug("Unknown platform value %r: %s", data["platform"], e)

        counter_defaults = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_tokens": 0,
            "cache_write_tokens": 0,
            "total_tokens": 0,
            "last_prompt_tokens": 0,
            "estimated_cost_usd": 0.0,
            "cost_status": "unknown",
        }

        return cls(
            session_key=data["session_key"],
            session_id=data["session_id"],
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            origin=origin,
            display_name=data.get("display_name"),
            platform=platform,
            chat_type=data.get("chat_type", "dm"),
            **{
                name: data.get(name, fallback)
                for name, fallback in counter_defaults.items()
            },
        )
|
|
|
|
|
|
def build_session_key(source: SessionSource, group_sessions_per_user: bool = True) -> str:
    """Derive the deterministic session key for a message source.

    This is the single source of truth for session key construction.  Every
    key starts with ``agent:main:<platform>:<chat_type>`` and is extended
    with whichever identifiers are available.

    DM rules:
    - chat_id is appended when present, so each private conversation is isolated.
    - thread_id is appended after it to separate threaded DMs.
    - Without chat_id, thread_id alone is used as a best-effort fallback.
    - Without either, all DMs on the platform share one session.

    Group/channel rules:
    - chat_id identifies the parent group/channel.
    - thread_id differentiates threads within that parent chat.
    - user_id_alt (preferred) or user_id is appended last to isolate
      participants, but only when ``group_sessions_per_user`` is enabled.
    - Without any identifiers, messages fall back to one session per
      platform/chat_type.
    """
    platform = source.platform.value

    if source.chat_type == "dm":
        dm_key = f"agent:main:{platform}:dm"
        if source.chat_id:
            dm_key = f"{dm_key}:{source.chat_id}"
        if source.thread_id:
            dm_key = f"{dm_key}:{source.thread_id}"
        return dm_key

    segments = ["agent:main", platform, source.chat_type]
    segments += [part for part in (source.chat_id, source.thread_id) if part]

    member = source.user_id_alt or source.user_id
    if group_sessions_per_user and member:
        segments.append(str(member))

    return ":".join(segments)
|
|
|
|
|
|
class SessionStore:
|
|
"""
|
|
Manages session storage and retrieval.
|
|
|
|
Uses SQLite (via SessionDB) for session metadata and message transcripts.
|
|
Falls back to legacy JSONL files if SQLite is unavailable.
|
|
"""
|
|
|
|
def __init__(self, sessions_dir: Path, config: GatewayConfig,
             has_active_processes_fn=None,
             on_auto_reset=None):
    """Set up an empty, not-yet-loaded store and try to open the SQLite DB.

    Args:
        sessions_dir: Directory holding ``sessions.json`` and JSONL transcripts.
        config: Gateway configuration (used for reset policies).
        has_active_processes_fn: Optional callable taking a session key and
            returning truthy when that session must not be reset.
        on_auto_reset: Accepted for backward compatibility; not used here.
    """
    self.config = config
    self.sessions_dir = sessions_dir
    self._has_active_processes_fn = has_active_processes_fn

    # The index is read lazily by _ensure_loaded().
    self._loaded = False
    self._entries: Dict[str, SessionEntry] = {}

    # on_auto_reset is deprecated — memory flush now runs proactively
    # via the background session expiry watcher in GatewayRunner.
    self._pre_flushed_sessions: set = set()  # session_ids already flushed by watcher

    # Initialize SQLite session database; any failure leaves _db as None.
    self._db = None
    try:
        from hermes_state import SessionDB
        self._db = SessionDB()
    except Exception as e:
        print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
|
|
|
|
def _ensure_loaded(self) -> None:
|
|
"""Load sessions index from disk if not already loaded."""
|
|
if self._loaded:
|
|
return
|
|
|
|
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
|
sessions_file = self.sessions_dir / "sessions.json"
|
|
|
|
if sessions_file.exists():
|
|
try:
|
|
with open(sessions_file, "r", encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
for key, entry_data in data.items():
|
|
try:
|
|
self._entries[key] = SessionEntry.from_dict(entry_data)
|
|
except (ValueError, KeyError):
|
|
# Skip entries with unknown/removed platform values
|
|
continue
|
|
except Exception as e:
|
|
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
|
|
|
self._loaded = True
|
|
|
|
def _save(self) -> None:
|
|
"""Save sessions index to disk (kept for session key -> ID mapping)."""
|
|
import tempfile
|
|
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
|
sessions_file = self.sessions_dir / "sessions.json"
|
|
|
|
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
|
fd, tmp_path = tempfile.mkstemp(
|
|
dir=str(self.sessions_dir), suffix=".tmp", prefix=".sessions_"
|
|
)
|
|
try:
|
|
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
|
json.dump(data, f, indent=2)
|
|
f.flush()
|
|
os.fsync(f.fileno())
|
|
os.replace(tmp_path, sessions_file)
|
|
except BaseException:
|
|
try:
|
|
os.unlink(tmp_path)
|
|
except OSError as e:
|
|
logger.debug("Could not remove temp file %s: %s", tmp_path, e)
|
|
raise
|
|
|
|
def _generate_session_key(self, source: SessionSource) -> str:
    """Return the session key for *source* via ``build_session_key``.

    Per-user group isolation follows ``config.group_sessions_per_user`` and
    is treated as enabled when the config has no such attribute.
    """
    per_user = getattr(self.config, "group_sessions_per_user", True)
    return build_session_key(source, group_sessions_per_user=per_user)
|
|
|
|
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
|
"""Check if a session has expired based on its reset policy.
|
|
|
|
Works from the entry alone — no SessionSource needed.
|
|
Used by the background expiry watcher to proactively flush memories.
|
|
Sessions with active background processes are never considered expired.
|
|
"""
|
|
if self._has_active_processes_fn:
|
|
if self._has_active_processes_fn(entry.session_key):
|
|
return False
|
|
|
|
policy = self.config.get_reset_policy(
|
|
platform=entry.platform,
|
|
session_type=entry.chat_type,
|
|
)
|
|
|
|
if policy.mode == "none":
|
|
return False
|
|
|
|
now = datetime.now()
|
|
|
|
if policy.mode in ("idle", "both"):
|
|
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
|
if now > idle_deadline:
|
|
return True
|
|
|
|
if policy.mode in ("daily", "both"):
|
|
today_reset = now.replace(
|
|
hour=policy.at_hour,
|
|
minute=0, second=0, microsecond=0,
|
|
)
|
|
if now.hour < policy.at_hour:
|
|
today_reset -= timedelta(days=1)
|
|
if entry.updated_at < today_reset:
|
|
return True
|
|
|
|
return False
|
|
|
|
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
|
"""
|
|
Check if a session should be reset based on policy.
|
|
|
|
Sessions with active background processes are never reset.
|
|
"""
|
|
if self._has_active_processes_fn:
|
|
session_key = self._generate_session_key(source)
|
|
if self._has_active_processes_fn(session_key):
|
|
return False
|
|
|
|
policy = self.config.get_reset_policy(
|
|
platform=source.platform,
|
|
session_type=source.chat_type
|
|
)
|
|
|
|
if policy.mode == "none":
|
|
return False
|
|
|
|
now = datetime.now()
|
|
|
|
if policy.mode in ("idle", "both"):
|
|
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
|
if now > idle_deadline:
|
|
return True
|
|
|
|
if policy.mode in ("daily", "both"):
|
|
today_reset = now.replace(
|
|
hour=policy.at_hour,
|
|
minute=0,
|
|
second=0,
|
|
microsecond=0
|
|
)
|
|
if now.hour < policy.at_hour:
|
|
today_reset -= timedelta(days=1)
|
|
|
|
if entry.updated_at < today_reset:
|
|
return True
|
|
|
|
return False
|
|
|
|
def has_any_sessions(self) -> bool:
    """Report whether more than one session has ever been recorded.

    The SQLite database is asked first because it keeps historical session
    records (ended sessions still count).  The in-memory ``_entries`` dict
    replaces entries on reset, so its length would stay at 1 for
    single-platform users.

    The current session is already in the DB by the time this is called
    (get_or_create_session runs first), hence the ``> 1`` comparison.
    """
    database = self._db
    if database:
        try:
            return database.session_count() > 1
        except Exception:
            # Deliberate best-effort: fall through to the index heuristic.
            pass

    # Fallback for the rare case where the DB is unavailable or failed:
    # count what sessions.json knows about.
    self._ensure_loaded()
    known = len(self._entries)
    return known > 1
|
|
|
|
def get_or_create_session(
    self,
    source: SessionSource,
    force_new: bool = False
) -> SessionEntry:
    """
    Return the live session for *source*, creating one when needed.

    An existing entry is reused (and its ``updated_at`` refreshed) unless
    *force_new* is set or the reset policy says it is stale.  A stale entry
    is ended in SQLite and replaced; the replacement is flagged with
    ``was_auto_reset``.  New sessions are saved to the index and recorded
    in SQLite when the database is available.
    """
    self._ensure_loaded()

    session_key = self._generate_session_key(source)
    now = datetime.now()
    was_auto_reset = False

    if not force_new and session_key in self._entries:
        current = self._entries[session_key]

        if not self._should_reset(current, source):
            current.updated_at = now
            self._save()
            return current

        # Session is being auto-reset. The background expiry watcher
        # should have already flushed memories proactively; discard
        # the marker so it doesn't accumulate.
        was_auto_reset = True
        self._pre_flushed_sessions.discard(current.session_id)
        if self._db:
            try:
                self._db.end_session(current.session_id, "session_reset")
            except Exception as e:
                logger.debug("Session DB operation failed: %s", e)

    # Create new session: timestamp plus a short random suffix.
    session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

    entry = SessionEntry(
        session_key=session_key,
        session_id=session_id,
        created_at=now,
        updated_at=now,
        origin=source,
        display_name=source.chat_name,
        platform=source.platform,
        chat_type=source.chat_type,
        was_auto_reset=was_auto_reset,
    )
    self._entries[session_key] = entry
    self._save()

    # Create session in SQLite
    if self._db:
        try:
            self._db.create_session(
                session_id=session_id,
                source=source.platform.value,
                user_id=source.user_id,
            )
        except Exception as e:
            print(f"[gateway] Warning: Failed to create SQLite session: {e}")

    return entry
|
|
|
|
def update_session(
    self,
    session_key: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    last_prompt_tokens: int = None,
    model: str = None,
    estimated_cost_usd: Optional[float] = None,
    cost_status: Optional[str] = None,
    cost_source: Optional[str] = None,
    provider: Optional[str] = None,
    base_url: Optional[str] = None,
) -> None:
    """Record the usage of one interaction on an existing session.

    Token counts and cost are added to the entry's running totals,
    ``last_prompt_tokens`` and ``cost_status`` overwrite the stored values
    when given, and ``updated_at`` is refreshed.  The index is saved and
    the same deltas are forwarded to SQLite when available.  Unknown
    session keys are ignored.
    """
    self._ensure_loaded()

    entry = self._entries.get(session_key) if session_key in self._entries else None
    if session_key not in self._entries:
        return

    entry.updated_at = datetime.now()

    # Accumulate the per-interaction deltas.
    deltas = (
        ("input_tokens", input_tokens),
        ("output_tokens", output_tokens),
        ("cache_read_tokens", cache_read_tokens),
        ("cache_write_tokens", cache_write_tokens),
    )
    for field_name, delta in deltas:
        setattr(entry, field_name, getattr(entry, field_name) + delta)

    if last_prompt_tokens is not None:
        entry.last_prompt_tokens = last_prompt_tokens
    if estimated_cost_usd is not None:
        entry.estimated_cost_usd += estimated_cost_usd
    if cost_status:
        entry.cost_status = cost_status

    # total_tokens is always recomputed from the four counters.
    entry.total_tokens = sum(getattr(entry, field_name) for field_name, _ in deltas)
    self._save()

    if not self._db:
        return
    try:
        self._db.update_token_counts(
            entry.session_id,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_read_tokens=cache_read_tokens,
            cache_write_tokens=cache_write_tokens,
            estimated_cost_usd=estimated_cost_usd,
            cost_status=cost_status,
            cost_source=cost_source,
            billing_provider=provider,
            billing_base_url=base_url,
            model=model,
        )
    except Exception as e:
        logger.debug("Session DB operation failed: %s", e)
|
|
|
|
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
    """Replace the session behind *session_key* with a fresh session ID.

    The previous session is ended in SQLite, a new entry that keeps the old
    origin and display metadata is stored and saved, and the new session is
    recorded in SQLite.  Database failures are logged at debug level only.

    Returns:
        The new entry, or ``None`` when the key is unknown.
    """
    self._ensure_loaded()

    previous = self._entries.get(session_key) if session_key in self._entries else None
    if session_key not in self._entries:
        return None

    # End old session in SQLite
    if self._db:
        try:
            self._db.end_session(previous.session_id, "session_reset")
        except Exception as e:
            logger.debug("Session DB operation failed: %s", e)

    now = datetime.now()
    stamp = now.strftime('%Y%m%d_%H%M%S')
    session_id = f"{stamp}_{uuid.uuid4().hex[:8]}"

    replacement = SessionEntry(
        session_key=session_key,
        session_id=session_id,
        created_at=now,
        updated_at=now,
        origin=previous.origin,
        display_name=previous.display_name,
        platform=previous.platform,
        chat_type=previous.chat_type,
    )
    self._entries[session_key] = replacement
    self._save()

    # Create new session in SQLite
    if self._db:
        db_source = previous.platform.value if previous.platform else "unknown"
        db_user = previous.origin.user_id if previous.origin else None
        try:
            self._db.create_session(
                session_id=session_id,
                source=db_source,
                user_id=db_user,
            )
        except Exception as e:
            logger.debug("Session DB operation failed: %s", e)

    return replacement
|
|
|
|
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
    """Point *session_key* at an already existing session ID.

    Used by ``/resume`` to restore a previously-named session.  The current
    session is ended in SQLite (like a reset), but the new entry re-uses
    ``target_session_id`` instead of generating a fresh ID, so the old
    transcript is loaded on the next message.

    Returns:
        ``None`` for an unknown key, the unchanged entry when it already
        points at the target, otherwise the newly stored entry.
    """
    self._ensure_loaded()

    if session_key not in self._entries:
        return None
    previous = self._entries[session_key]

    # Don't switch if already on that session
    if previous.session_id == target_session_id:
        return previous

    # End the current session in SQLite
    if self._db:
        try:
            self._db.end_session(previous.session_id, "session_switch")
        except Exception as e:
            logger.debug("Session DB end_session failed: %s", e)

    now = datetime.now()
    replacement = SessionEntry(
        session_key=session_key,
        session_id=target_session_id,
        created_at=now,
        updated_at=now,
        origin=previous.origin,
        display_name=previous.display_name,
        platform=previous.platform,
        chat_type=previous.chat_type,
    )
    self._entries[session_key] = replacement
    self._save()
    return replacement
|
|
|
|
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
    """Return known sessions, most recently updated first.

    When *active_minutes* is given, only sessions updated within that many
    minutes of now are included.
    """
    self._ensure_loaded()

    if active_minutes is None:
        selected = list(self._entries.values())
    else:
        cutoff = datetime.now() - timedelta(minutes=active_minutes)
        selected = [
            entry for entry in self._entries.values()
            if entry.updated_at >= cutoff
        ]

    return sorted(selected, key=lambda entry: entry.updated_at, reverse=True)
|
|
|
|
def get_transcript_path(self, session_id: str) -> Path:
    """Return ``<sessions_dir>/<session_id>.jsonl``, the legacy transcript file."""
    filename = f"{session_id}.jsonl"
    return self.sessions_dir / filename
|
|
|
|
def append_to_transcript(self, session_id: str, message: Dict[str, Any], skip_db: bool = False) -> None:
    """Append one message to a session's transcript (SQLite + legacy JSONL).

    Args:
        session_id: Session whose transcript receives the message.
        message: The message dict; it is written to JSONL unchanged.
        skip_db: When True, only write to JSONL and skip the SQLite write.
            Used when the agent already persisted messages to SQLite
            via its own _flush_messages_to_session_db(), preventing
            the duplicate-write bug (#860).
    """
    # Write to SQLite (unless the agent already handled it).  Failures are
    # logged at debug level and do not stop the JSONL write below.
    if not skip_db and self._db:
        try:
            optional_fields = {
                name: message.get(name)
                for name in ("content", "tool_name", "tool_calls", "tool_call_id")
            }
            self._db.append_message(
                session_id=session_id,
                role=message.get("role", "unknown"),
                **optional_fields,
            )
        except Exception as e:
            logger.debug("Session DB operation failed: %s", e)

    # Also write legacy JSONL (keeps existing tooling working during transition)
    record = json.dumps(message, ensure_ascii=False) + "\n"
    with open(self.get_transcript_path(session_id), "a", encoding="utf-8") as out:
        out.write(record)
|
|
|
|
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
    """Replace a session's whole transcript with *messages*.

    Used by /retry, /undo, and /compress to persist modified conversation
    history.  Both stores are rewritten: SQLite messages are cleared and
    re-inserted (failures are logged at debug level), and the JSONL file is
    overwritten with one JSON object per line.
    """
    # SQLite: clear old messages and re-insert
    database = self._db
    if database:
        try:
            database.clear_messages(session_id)
            for msg in messages:
                database.append_message(
                    session_id=session_id,
                    role=msg.get("role", "unknown"),
                    content=msg.get("content"),
                    tool_name=msg.get("tool_name"),
                    tool_calls=msg.get("tool_calls"),
                    tool_call_id=msg.get("tool_call_id"),
                )
        except Exception as e:
            logger.debug("Failed to rewrite transcript in DB: %s", e)

    # JSONL: overwrite the file
    destination = self.get_transcript_path(session_id)
    with open(destination, "w", encoding="utf-8") as out:
        for msg in messages:
            serialized = json.dumps(msg, ensure_ascii=False)
            out.write(serialized + "\n")
|
|
|
|
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
    """Load all messages from a session's transcript.

    SQLite is tried first; its result is returned when it is non-empty.
    If the database is unavailable, fails, or has no messages for the
    session, the legacy JSONL file is read instead.

    The JSONL reader tolerates damage left by an interrupted write (for
    example a process killed mid-append):

    - bytes that are not valid UTF-8 are decoded with the replacement
      character instead of raising ``UnicodeDecodeError``;
    - lines that are not valid JSON are skipped with a warning.

    Returns:
        The messages in file order, or an empty list when no transcript
        exists.
    """
    # Try SQLite first
    if self._db:
        try:
            messages = self._db.get_messages_as_conversation(session_id)
            if messages:
                return messages
        except Exception as e:
            logger.debug("Could not load messages from DB: %s", e)

    # Fall back to legacy JSONL
    transcript_path = self.get_transcript_path(session_id)
    if not transcript_path.exists():
        return []

    messages = []
    # errors="replace": a write cut off inside a multi-byte character must
    # not abort the whole load; the damaged line then fails JSON parsing
    # below and is skipped like any other partial line.
    with open(transcript_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                messages.append(json.loads(line))
            except json.JSONDecodeError:
                logger.warning(
                    "Skipping corrupt line in transcript %s: %s",
                    session_id, line[:120],
                )

    return messages
|
|
|
|
|
|
def build_session_context(
    source: SessionSource,
    config: GatewayConfig,
    session_entry: Optional[SessionEntry] = None
) -> SessionContext:
    """
    Assemble the SessionContext for *source* from the gateway config.

    Connected platforms come from ``config.get_connected_platforms()``; a
    platform appears in ``home_channels`` only when the config returns a
    home channel for it.  When *session_entry* is given, its key, ID and
    timestamps are copied onto the context.

    The result is used to inject context into the agent's system prompt.
    """
    connected = config.get_connected_platforms()

    home_channels = {}
    for platform in connected:
        home = config.get_home_channel(platform)
        if not home:
            continue
        home_channels[platform] = home

    context = SessionContext(
        source=source,
        connected_platforms=connected,
        home_channels=home_channels,
    )

    if session_entry:
        for attr in ("session_key", "session_id", "created_at", "updated_at"):
            setattr(context, attr, getattr(session_entry, attr))

    return context
|