diff --git a/gateway/run.py b/gateway/run.py index 58c52f4b4..37fc3d8f1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5849,7 +5849,12 @@ class GatewayRunner: # command approval blocks the agent thread (mirrors CLI input()). # The callback bridges sync→async to send the approval request # to the user immediately. - from tools.approval import register_gateway_notify, unregister_gateway_notify + from tools.approval import ( + register_gateway_notify, + reset_current_session_key, + set_current_session_key, + unregister_gateway_notify, + ) def _approval_notify_sync(approval_data: dict) -> None: """Send the approval request to the user from the agent thread. @@ -5905,11 +5910,13 @@ class GatewayRunner: logger.error("Failed to send approval request: %s", _e) _approval_session_key = session_key or "" + _approval_session_token = set_current_session_key(_approval_session_key) register_gateway_notify(_approval_session_key, _approval_notify_sync) try: result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id) finally: unregister_gateway_notify(_approval_session_key) + reset_current_session_key(_approval_session_token) result_holder[0] = result # Signal the stream consumer that the agent is done diff --git a/tests/gateway/test_approve_deny_commands.py b/tests/gateway/test_approve_deny_commands.py index ddb3ebef5..d360e0cfb 100644 --- a/tests/gateway/test_approve_deny_commands.py +++ b/tests/gateway/test_approve_deny_commands.py @@ -390,6 +390,9 @@ class TestBlockingApprovalE2E: result_holder = [None] def agent_thread(): + from tools.approval import reset_current_session_key, set_current_session_key + + token = set_current_session_key(session_key) os.environ["HERMES_EXEC_ASK"] = "1" os.environ["HERMES_SESSION_KEY"] = session_key try: @@ -399,6 +402,7 @@ class TestBlockingApprovalE2E: finally: os.environ.pop("HERMES_EXEC_ASK", None) os.environ.pop("HERMES_SESSION_KEY", None) + reset_current_session_key(token) t = 
threading.Thread(target=agent_thread) t.start() @@ -432,6 +436,9 @@ class TestBlockingApprovalE2E: result_holder = [None] def agent_thread(): + from tools.approval import reset_current_session_key, set_current_session_key + + token = set_current_session_key(session_key) os.environ["HERMES_EXEC_ASK"] = "1" os.environ["HERMES_SESSION_KEY"] = session_key try: @@ -441,6 +448,7 @@ class TestBlockingApprovalE2E: finally: os.environ.pop("HERMES_EXEC_ASK", None) os.environ.pop("HERMES_SESSION_KEY", None) + reset_current_session_key(token) t = threading.Thread(target=agent_thread) t.start() @@ -469,6 +477,9 @@ class TestBlockingApprovalE2E: result_holder = [None] def agent_thread(): + from tools.approval import reset_current_session_key, set_current_session_key + + token = set_current_session_key(session_key) os.environ["HERMES_EXEC_ASK"] = "1" os.environ["HERMES_SESSION_KEY"] = session_key try: @@ -480,6 +491,7 @@ class TestBlockingApprovalE2E: finally: os.environ.pop("HERMES_EXEC_ASK", None) os.environ.pop("HERMES_SESSION_KEY", None) + reset_current_session_key(token) t = threading.Thread(target=agent_thread) t.start() @@ -505,6 +517,9 @@ class TestBlockingApprovalE2E: def make_agent(idx, cmd): def run(): + from tools.approval import reset_current_session_key, set_current_session_key + + token = set_current_session_key(session_key) os.environ["HERMES_EXEC_ASK"] = "1" os.environ["HERMES_SESSION_KEY"] = session_key try: @@ -512,6 +527,7 @@ class TestBlockingApprovalE2E: finally: os.environ.pop("HERMES_EXEC_ASK", None) os.environ.pop("HERMES_SESSION_KEY", None) + reset_current_session_key(token) return run threads = [ @@ -556,6 +572,9 @@ class TestBlockingApprovalE2E: def make_agent(idx, cmd): def run(): + from tools.approval import reset_current_session_key, set_current_session_key + + token = set_current_session_key(session_key) os.environ["HERMES_EXEC_ASK"] = "1" os.environ["HERMES_SESSION_KEY"] = session_key try: @@ -563,6 +582,7 @@ class TestBlockingApprovalE2E: 
class TestSessionKeyContext:
    """Approval session identity must follow the execution context.

    These tests cover the context-local session key exposed by
    ``tools.approval`` (``set_current_session_key`` /
    ``reset_current_session_key`` / ``get_current_session_key``).
    """

    def test_context_session_key_overrides_process_env(self):
        """A bound context key wins over a conflicting ``HERMES_SESSION_KEY``."""
        token = approval_module.set_current_session_key("alice")
        try:
            with mock_patch.dict("os.environ", {"HERMES_SESSION_KEY": "bob"}, clear=False):
                assert approval_module.get_current_session_key() == "alice"
        finally:
            approval_module.reset_current_session_key(token)

    def test_gateway_runner_binds_session_key_to_context_before_agent_run(self):
        """``run_sync`` in gateway/run.py binds the key and restores it in ``finally``.

        The gateway source is parsed with ``ast`` rather than imported, so the
        check is purely structural.
        """
        run_py = Path(__file__).resolve().parents[2] / "gateway" / "run.py"
        module = ast.parse(run_py.read_text(encoding="utf-8"))

        run_sync = next(
            (
                node
                for node in ast.walk(module)
                if isinstance(node, ast.FunctionDef) and node.name == "run_sync"
            ),
            None,
        )
        assert run_sync is not None, "gateway.run.run_sync not found"

        def called_names(root):
            """Return the bare function names called anywhere below *root*."""
            return {
                node.func.id
                for node in ast.walk(root)
                if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            }

        names = called_names(run_sync)
        assert "set_current_session_key" in names
        assert "reset_current_session_key" in names

        # A reset outside ``finally`` would be skipped when the agent run
        # raises, leaving the key bound on a reused executor thread.
        cleanup_names = set()
        for node in ast.walk(run_sync):
            if isinstance(node, ast.Try):
                for stmt in node.finalbody:
                    cleanup_names |= called_names(stmt)
        assert "reset_current_session_key" in cleanup_names, (
            "reset_current_session_key must be called from a finally block"
        )

    def test_context_keeps_pending_approval_attached_to_originating_session(self):
        """A pending approval is filed under the caller's context key.

        Alice binds her key, then Bob overwrites ``HERMES_SESSION_KEY`` in the
        shared process environment before Alice runs the command guards.  The
        pending approval must still land on "alice".
        """
        import os
        import threading

        wait_seconds = 5

        clear_session("alice")
        clear_session("bob")
        pop_pending("alice")
        pop_pending("bob")
        approval_module._permanent_approved.clear()

        alice_ready = threading.Event()
        bob_ready = threading.Event()
        failures = []  # exceptions raised inside worker threads

        def worker_alice():
            token = approval_module.set_current_session_key("alice")
            try:
                os.environ["HERMES_SESSION_KEY"] = "alice"
                alice_ready.set()
                if not bob_ready.wait(wait_seconds):
                    raise TimeoutError("bob never overwrote HERMES_SESSION_KEY")
                approval_module.check_all_command_guards("rm -rf /tmp/alice-secret", "local")
            except Exception as exc:
                # Re-surfaced by the main thread; otherwise it would be lost.
                failures.append(exc)
            finally:
                approval_module.reset_current_session_key(token)

        def worker_bob():
            if not alice_ready.wait(wait_seconds):
                failures.append(TimeoutError("alice never bound her session key"))
                return
            token = approval_module.set_current_session_key("bob")
            try:
                os.environ["HERMES_SESSION_KEY"] = "bob"
                bob_ready.set()
            finally:
                approval_module.reset_current_session_key(token)

        # patch.dict snapshots os.environ and restores it on exit, so the
        # writes made by the workers cannot leak into later tests.
        with mock_patch.dict("os.environ", {"HERMES_EXEC_ASK": "1"}, clear=False):
            threads = [
                threading.Thread(target=worker_alice, name="alice", daemon=True),
                threading.Thread(target=worker_bob, name="bob", daemon=True),
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join(timeout=wait_seconds * 2)

            hung = [thread.name for thread in threads if thread.is_alive()]
            assert not hung, f"worker threads did not finish: {hung}"
            assert not failures, f"worker threads raised: {failures!r}"

            assert pop_pending("alice") is not None
            assert pop_pending("bob") is None
# Session identity for the approval system, scoped to the current execution
# context instead of the whole process.
#
# Reading a process-global environment variable for session identity is racy
# when several agent turns run concurrently.  The environment fallback is kept
# for callers that never bind a key; a non-empty context-local value always
# takes precedence.
#
# NOTE(review): threads spawned while a key is bound may start with a fresh
# context and would then fall back to the environment -- confirm that command
# guards run on the thread that performed the binding.
_approval_session_key: contextvars.ContextVar[str] = contextvars.ContextVar(
    "approval_session_key",
    default="",
)


def set_current_session_key(session_key: Optional[str]) -> "contextvars.Token[str]":
    """Bind the active approval session key to the current context.

    Args:
        session_key: Key to bind.  ``None`` and ``""`` are stored as ``""``,
            which ``get_current_session_key`` treats as "not bound".

    Returns:
        A token that must be handed to ``reset_current_session_key`` to
        restore whatever was bound before this call.
    """
    return _approval_session_key.set(session_key or "")


def reset_current_session_key(token: "contextvars.Token[str]") -> None:
    """Restore the approval session key that was active before *token*.

    Args:
        token: Value returned by the matching ``set_current_session_key``.

    Raises:
        ValueError: If *token* was created in a different context.
        RuntimeError: If *token* has already been used for a reset.
    """
    _approval_session_key.reset(token)


def get_current_session_key(default: str = "default") -> str:
    """Return the active session key, preferring context-local state.

    Lookup order: a non-empty context-local key, then the
    ``HERMES_SESSION_KEY`` environment variable, then *default*.
    """
    bound_key = _approval_session_key.get()
    if bound_key:
        return bound_key
    return os.getenv("HERMES_SESSION_KEY", default)


class SessionKeyBinding:
    """Bind an approval session key for the duration of a ``with`` block.

    Equivalent to calling ``set_current_session_key`` on entry and
    ``reset_current_session_key`` on exit, with the reset guaranteed even
    when the block raises.  The same instance may be entered more than once;
    each exit undoes the most recent entry.

    Example::

        with SessionKeyBinding(session_key):
            agent.run_conversation(...)
    """

    __slots__ = ("_session_key", "_tokens")

    def __init__(self, session_key: Optional[str]) -> None:
        self._session_key = session_key
        self._tokens: list = []  # one token per active ``with`` entry

    def __enter__(self) -> str:
        self._tokens.append(set_current_session_key(self._session_key))
        return self._session_key or ""

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None lets any exception from the block propagate.
        reset_current_session_key(self._tokens.pop())