fix(gateway): isolate approval session key per turn
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Tests for the dangerous command approval module."""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
import tools.approval as approval_module
|
||||
@@ -148,6 +150,79 @@ class TestApproveAndCheckSession:
|
||||
assert has_pending(key) is False
|
||||
|
||||
|
||||
class TestSessionKeyContext:
|
||||
def test_context_session_key_overrides_process_env(self):
|
||||
token = approval_module.set_current_session_key("alice")
|
||||
try:
|
||||
with mock_patch.dict("os.environ", {"HERMES_SESSION_KEY": "bob"}, clear=False):
|
||||
assert approval_module.get_current_session_key() == "alice"
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
def test_gateway_runner_binds_session_key_to_context_before_agent_run(self):
|
||||
run_py = Path(__file__).resolve().parents[2] / "gateway" / "run.py"
|
||||
module = ast.parse(run_py.read_text(encoding="utf-8"))
|
||||
|
||||
run_sync = None
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "run_sync":
|
||||
run_sync = node
|
||||
break
|
||||
|
||||
assert run_sync is not None, "gateway.run.run_sync not found"
|
||||
|
||||
called_names = set()
|
||||
for node in ast.walk(run_sync):
|
||||
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
|
||||
called_names.add(node.func.id)
|
||||
|
||||
assert "set_current_session_key" in called_names
|
||||
assert "reset_current_session_key" in called_names
|
||||
|
||||
def test_context_keeps_pending_approval_attached_to_originating_session(self):
|
||||
import os
|
||||
import threading
|
||||
|
||||
clear_session("alice")
|
||||
clear_session("bob")
|
||||
pop_pending("alice")
|
||||
pop_pending("bob")
|
||||
approval_module._permanent_approved.clear()
|
||||
|
||||
alice_ready = threading.Event()
|
||||
bob_ready = threading.Event()
|
||||
|
||||
def worker_alice():
|
||||
token = approval_module.set_current_session_key("alice")
|
||||
try:
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = "alice"
|
||||
alice_ready.set()
|
||||
bob_ready.wait()
|
||||
approval_module.check_all_command_guards("rm -rf /tmp/alice-secret", "local")
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
def worker_bob():
|
||||
alice_ready.wait()
|
||||
token = approval_module.set_current_session_key("bob")
|
||||
try:
|
||||
os.environ["HERMES_SESSION_KEY"] = "bob"
|
||||
bob_ready.set()
|
||||
finally:
|
||||
approval_module.reset_current_session_key(token)
|
||||
|
||||
t1 = threading.Thread(target=worker_alice)
|
||||
t2 = threading.Thread(target=worker_bob)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
assert pop_pending("alice") is not None
|
||||
assert pop_pending("bob") is None
|
||||
|
||||
|
||||
class TestRmFalsePositiveFix:
|
||||
"""Regression tests: filenames starting with 'r' must NOT trigger recursive delete."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user