Merge pull request #1495 from NousResearch/fix/814-group-session-isolation

fix(gateway): default group sessions to per-user isolation
This commit is contained in:
Teknium
2026-03-16 00:25:43 -07:00
committed by GitHub
11 changed files with 307 additions and 29 deletions

View File

@@ -175,7 +175,10 @@ class GatewayConfig:
# STT settings
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
# Session isolation in shared chats
group_sessions_per_user: bool = True # Isolate group/channel sessions per participant when user IDs are available
def get_connected_platforms(self) -> List[Platform]:
"""Return list of platforms that are enabled and configured."""
connected = []
@@ -240,6 +243,7 @@ class GatewayConfig:
"sessions_dir": str(self.sessions_dir),
"always_log_local": self.always_log_local,
"stt_enabled": self.stt_enabled,
"group_sessions_per_user": self.group_sessions_per_user,
}
@classmethod
@@ -280,6 +284,8 @@ class GatewayConfig:
if stt_enabled is None:
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
group_sessions_per_user = data.get("group_sessions_per_user")
return cls(
platforms=platforms,
default_reset_policy=default_policy,
@@ -290,6 +296,7 @@ class GatewayConfig:
sessions_dir=sessions_dir,
always_log_local=data.get("always_log_local", True),
stt_enabled=_coerce_bool(stt_enabled, True),
group_sessions_per_user=_coerce_bool(group_sessions_per_user, True),
)
@@ -345,6 +352,14 @@ def load_gateway_config() -> GatewayConfig:
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
# Bridge group session isolation from config.yaml into gateway runtime.
# Secure default is per-user isolation in shared chats.
if "group_sessions_per_user" in yaml_cfg:
config.group_sessions_per_user = _coerce_bool(
yaml_cfg.get("group_sessions_per_user"),
True,
)
# Bridge discord settings from config.yaml to env vars
# (env vars take precedence — only set if not already defined)
discord_cfg = yaml_cfg.get("discord", {})

View File

@@ -752,7 +752,10 @@ class BasePlatformAdapter(ABC):
if not self._message_handler:
return
session_key = build_session_key(event.source)
session_key = build_session_key(
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
)
# Check if there's already an active handler for this session
if session_key in self._active_sessions:

View File

@@ -829,7 +829,10 @@ class TelegramAdapter(BasePlatformAdapter):
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
"""Return a batching key for Telegram photos/albums."""
from gateway.session import build_session_key
session_key = build_session_key(event.source)
session_key = build_session_key(
event.source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
)
media_group_id = getattr(msg, "media_group_id", None)
if media_group_id:
return f"{session_key}:album:{media_group_id}"

View File

@@ -572,6 +572,21 @@ class GatewayRunner:
def exit_reason(self) -> Optional[str]:
return self._exit_reason
def _session_key_for_source(self, source: SessionSource) -> str:
"""Resolve the current session key for a source, honoring gateway config when available."""
if hasattr(self, "session_store") and self.session_store is not None:
try:
session_key = self.session_store._generate_session_key(source)
if isinstance(session_key, str) and session_key:
return session_key
except Exception:
pass
config = getattr(self, "config", None)
return build_session_key(
source,
group_sessions_per_user=getattr(config, "group_sessions_per_user", True),
)
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
"""React to a non-retryable adapter failure after startup."""
logger.error(
@@ -1000,6 +1015,12 @@ class GatewayRunner:
config: Any
) -> Optional[BasePlatformAdapter]:
"""Create the appropriate adapter for a platform."""
if hasattr(config, "extra") and isinstance(config.extra, dict):
config.extra.setdefault(
"group_sessions_per_user",
self.config.group_sessions_per_user,
)
if platform == Platform.TELEGRAM:
from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements
if not check_telegram_requirements():
@@ -1171,7 +1192,7 @@ class GatewayRunner:
# Special case: Telegram/photo bursts often arrive as multiple near-
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
# let the adapter-level batching/queueing logic absorb them.
_quick_key = build_session_key(source)
_quick_key = self._session_key_for_source(source)
if _quick_key in self._running_agents:
if event.get_command() == "status":
return await self._handle_status_command(event)
@@ -1360,7 +1381,7 @@ class GatewayRunner:
logger.debug("Skill command check failed (non-fatal): %s", e)
# Check for pending exec approval responses
session_key_preview = build_session_key(source)
session_key_preview = self._session_key_for_source(source)
if session_key_preview in self._pending_approvals:
user_text = event.text.strip().lower()
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
@@ -1912,7 +1933,7 @@ class GatewayRunner:
source = event.source
# Get existing session key
session_key = self.session_store._generate_session_key(source)
session_key = self._session_key_for_source(source)
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
@@ -3144,7 +3165,7 @@ class GatewayRunner:
return "Session database not available."
source = event.source
session_key = build_session_key(source)
session_key = self._session_key_for_source(source)
name = event.get_command_args().strip()
if not name:
@@ -3218,7 +3239,7 @@ class GatewayRunner:
async def _handle_usage_command(self, event: MessageEvent) -> str:
"""Handle /usage command -- show token usage for the session's last agent run."""
source = event.source
session_key = build_session_key(source)
session_key = self._session_key_for_source(source)
agent = self._running_agents.get(session_key)
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:

View File

@@ -315,7 +315,7 @@ class SessionEntry:
)
def build_session_key(source: SessionSource) -> str:
def build_session_key(source: SessionSource, group_sessions_per_user: bool = True) -> str:
"""Build a deterministic session key from a message source.
This is the single source of truth for session key construction.
@@ -328,7 +328,11 @@ def build_session_key(source: SessionSource) -> str:
Group/channel rules:
- chat_id identifies the parent group/channel.
- user_id/user_id_alt isolates participants within that parent chat when available when
``group_sessions_per_user`` is enabled.
- thread_id differentiates threads within that parent chat.
- Without participant identifiers, or when isolation is disabled, messages fall back to one
shared session per chat.
- Without identifiers, messages fall back to one session per platform/chat_type.
"""
platform = source.platform.value
@@ -340,13 +344,18 @@ def build_session_key(source: SessionSource) -> str:
if source.thread_id:
return f"agent:main:{platform}:dm:{source.thread_id}"
return f"agent:main:{platform}:dm"
participant_id = source.user_id_alt or source.user_id
key_parts = ["agent:main", platform, source.chat_type]
if source.chat_id:
if source.thread_id:
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
key_parts.append(source.chat_id)
if source.thread_id:
return f"agent:main:{platform}:{source.chat_type}:{source.thread_id}"
return f"agent:main:{platform}:{source.chat_type}"
key_parts.append(source.thread_id)
if group_sessions_per_user and participant_id:
key_parts.append(str(participant_id))
return ":".join(key_parts)
class SessionStore:
@@ -425,7 +434,10 @@ class SessionStore:
def _generate_session_key(self, source: SessionSource) -> str:
"""Generate a session key from a source."""
return build_session_key(source)
return build_session_key(
source,
group_sessions_per_user=getattr(self.config, "group_sessions_per_user", True),
)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.