fix(gateway): isolate approval session key per turn
This commit is contained in:
@@ -5849,7 +5849,12 @@ class GatewayRunner:
|
||||
# command approval blocks the agent thread (mirrors CLI input()).
|
||||
# The callback bridges sync→async to send the approval request
|
||||
# to the user immediately.
|
||||
from tools.approval import register_gateway_notify, unregister_gateway_notify
|
||||
from tools.approval import (
|
||||
register_gateway_notify,
|
||||
reset_current_session_key,
|
||||
set_current_session_key,
|
||||
unregister_gateway_notify,
|
||||
)
|
||||
|
||||
def _approval_notify_sync(approval_data: dict) -> None:
|
||||
"""Send the approval request to the user from the agent thread.
|
||||
@@ -5905,11 +5910,13 @@ class GatewayRunner:
|
||||
logger.error("Failed to send approval request: %s", _e)
|
||||
|
||||
_approval_session_key = session_key or ""
|
||||
_approval_session_token = set_current_session_key(_approval_session_key)
|
||||
register_gateway_notify(_approval_session_key, _approval_notify_sync)
|
||||
try:
|
||||
result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id)
|
||||
finally:
|
||||
unregister_gateway_notify(_approval_session_key)
|
||||
reset_current_session_key(_approval_session_token)
|
||||
result_holder[0] = result
|
||||
|
||||
# Signal the stream consumer that the agent is done
|
||||
|
||||
@@ -390,6 +390,9 @@ class TestBlockingApprovalE2E:
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -399,6 +402,7 @@ class TestBlockingApprovalE2E:
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
@@ -432,6 +436,9 @@ class TestBlockingApprovalE2E:
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -441,6 +448,7 @@ class TestBlockingApprovalE2E:
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
@@ -469,6 +477,9 @@ class TestBlockingApprovalE2E:
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -480,6 +491,7 @@ class TestBlockingApprovalE2E:
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
@@ -505,6 +517,9 @@ class TestBlockingApprovalE2E:
|
||||
|
||||
def make_agent(idx, cmd):
|
||||
def run():
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -512,6 +527,7 @@ class TestBlockingApprovalE2E:
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
return run
|
||||
|
||||
threads = [
|
||||
@@ -556,6 +572,9 @@ class TestBlockingApprovalE2E:
|
||||
|
||||
def make_agent(idx, cmd):
|
||||
def run():
|
||||
from tools.approval import reset_current_session_key, set_current_session_key
|
||||
|
||||
token = set_current_session_key(session_key)
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
@@ -563,6 +582,7 @@ class TestBlockingApprovalE2E:
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
reset_current_session_key(token)
|
||||
return run
|
||||
|
||||
threads = [
|
||||
@@ -580,8 +600,9 @@ class TestBlockingApprovalE2E:
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert results[0]["approved"] is True
|
||||
assert results[1]["approved"] is False
|
||||
assert all(r is not None for r in results)
|
||||
assert sorted(r["approved"] for r in results) == [False, True]
|
||||
assert sum("BLOCKED" in (r.get("message") or "") for r in results) == 1
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for the dangerous command approval module."""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
import tools.approval as approval_module
|
||||
@@ -148,6 +150,79 @@ class TestApproveAndCheckSession:
|
||||
assert has_pending(key) is False
|
||||
|
||||
|
||||
class TestSessionKeyContext:
|
||||
def test_context_session_key_overrides_process_env(self):
|
||||
token = approval_module.set_current_session_key("alice")
|
||||
try:
|
||||
with mock_patch.dict("os.environ", {"HERMES_SESSION_KEY": "bob"}, clear=False):
|
||||
assert approval_module.get_current_session_key() == "alice"
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
def test_gateway_runner_binds_session_key_to_context_before_agent_run(self):
|
||||
run_py = Path(__file__).resolve().parents[2] / "gateway" / "run.py"
|
||||
module = ast.parse(run_py.read_text(encoding="utf-8"))
|
||||
|
||||
run_sync = None
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "run_sync":
|
||||
run_sync = node
|
||||
break
|
||||
|
||||
assert run_sync is not None, "gateway.run.run_sync not found"
|
||||
|
||||
called_names = set()
|
||||
for node in ast.walk(run_sync):
|
||||
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
|
||||
called_names.add(node.func.id)
|
||||
|
||||
assert "set_current_session_key" in called_names
|
||||
assert "reset_current_session_key" in called_names
|
||||
|
||||
def test_context_keeps_pending_approval_attached_to_originating_session(self):
|
||||
import os
|
||||
import threading
|
||||
|
||||
clear_session("alice")
|
||||
clear_session("bob")
|
||||
pop_pending("alice")
|
||||
pop_pending("bob")
|
||||
approval_module._permanent_approved.clear()
|
||||
|
||||
alice_ready = threading.Event()
|
||||
bob_ready = threading.Event()
|
||||
|
||||
def worker_alice():
|
||||
token = approval_module.set_current_session_key("alice")
|
||||
try:
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = "alice"
|
||||
alice_ready.set()
|
||||
bob_ready.wait()
|
||||
approval_module.check_all_command_guards("rm -rf /tmp/alice-secret", "local")
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
def worker_bob():
|
||||
alice_ready.wait()
|
||||
token = approval_module.set_current_session_key("bob")
|
||||
try:
|
||||
os.environ["HERMES_SESSION_KEY"] = "bob"
|
||||
bob_ready.set()
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
t1 = threading.Thread(target=worker_alice)
|
||||
t2 = threading.Thread(target=worker_bob)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
assert pop_pending("alice") is not None
|
||||
assert pop_pending("bob") is None
|
||||
|
||||
|
||||
class TestRmFalsePositiveFix:
|
||||
"""Regression tests: filenames starting with 'r' must NOT trigger recursive delete."""
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ This module is the single source of truth for the dangerous command system:
|
||||
- Permanent allowlist persistence (config.yaml)
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -18,6 +19,33 @@ from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-thread/per-task gateway session identity.
|
||||
# Gateway runs agent turns concurrently in executor threads, so reading a
|
||||
# process-global env var for session identity is racy. Keep env fallback for
|
||||
# legacy single-threaded callers, but prefer the context-local value when set.
|
||||
_approval_session_key: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"approval_session_key",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
def set_current_session_key(session_key: str):
|
||||
"""Bind the active approval session key to the current context."""
|
||||
return _approval_session_key.set(session_key or "")
|
||||
|
||||
|
||||
def reset_current_session_key(token) -> None:
|
||||
"""Restore the prior approval session key context."""
|
||||
_approval_session_key.reset(token)
|
||||
|
||||
|
||||
def get_current_session_key(default: str = "default") -> str:
|
||||
"""Return the active session key, preferring context-local state."""
|
||||
session_key = _approval_session_key.get()
|
||||
if session_key:
|
||||
return session_key
|
||||
return os.getenv("HERMES_SESSION_KEY", default)
|
||||
|
||||
# Sensitive write targets that should trigger approval even when referenced
|
||||
# via shell expansions like $HOME or $HERMES_HOME.
|
||||
_SSH_SENSITIVE_PATH = r'(?:~|\$home|\$\{home\})/\.ssh(?:/|$)'
|
||||
@@ -534,7 +562,7 @@ def check_dangerous_command(command: str, env_type: str,
|
||||
if not is_dangerous:
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
session_key = os.getenv("HERMES_SESSION_KEY", "default")
|
||||
session_key = get_current_session_key()
|
||||
if is_approved(session_key, pattern_key):
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
@@ -660,7 +688,7 @@ def check_all_command_guards(command: str, env_type: str,
|
||||
# Collect warnings that need approval
|
||||
warnings = [] # list of (pattern_key, description, is_tirith)
|
||||
|
||||
session_key = os.getenv("HERMES_SESSION_KEY", "default")
|
||||
session_key = get_current_session_key()
|
||||
|
||||
# Tirith block/warn → approvable warning with rich findings.
|
||||
# Previously, tirith "block" was a hard block with no approval prompt.
|
||||
|
||||
Reference in New Issue
Block a user