Compare commits
1 Commits
timmy/issu
...
allegro/m1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6da0d15590 |
143
agent/stop_protocol.py
Normal file
143
agent/stop_protocol.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Stop Protocol — M1 of Epic #842.
|
||||
|
||||
Implements a hard pre-tool-check interrupt for explicit stop/halt commands.
|
||||
Provides STOP_ACK logging, hands-off registry management, and compliance hooks.
|
||||
|
||||
@soul:service.sovereignty Every agent must respect the user's right to halt.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Matches explicit stop/halt commands at the start of a message or in SYSTEM tags.
|
||||
STOP_PATTERN = re.compile(
|
||||
r"^\s*(?:\[SYSTEM:\s*)?(?:stop|halt)(?:\s+means\s+(?:stop|halt))?[\.!\s]*"
|
||||
r"|^\s*(?:stop|halt)\s+(?:all\s+work|everything|immediately|now)[\.!\s]*"
|
||||
r"|^\s*(?:stop|halt)\s*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
SYSTEM_STOP_PATTERN = re.compile(r"\[SYSTEM:\s*.*?\bstop\b.*?\]", re.IGNORECASE)
|
||||
|
||||
ALLEGRO_LOG_PATH = os.path.expanduser("~/.hermes/burn-logs/allegro.log")
|
||||
CYCLE_STATE_PATH = os.path.expanduser("~/.hermes/allegro-cycle-state.json")
|
||||
|
||||
|
||||
class StopProtocol:
|
||||
"""Detects stop commands, logs STOP_ACK, and manages hands-off registry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cycle_state_path: str = CYCLE_STATE_PATH,
|
||||
log_path: str = ALLEGRO_LOG_PATH,
|
||||
):
|
||||
self.cycle_state_path = cycle_state_path
|
||||
self.log_path = log_path
|
||||
|
||||
def is_stop_command(self, text: str) -> bool:
|
||||
"""Return True if *text* is an explicit stop/halt command."""
|
||||
if not text or not isinstance(text, str):
|
||||
return False
|
||||
stripped = text.strip()
|
||||
if SYSTEM_STOP_PATTERN.search(stripped):
|
||||
return True
|
||||
return bool(STOP_PATTERN.search(stripped))
|
||||
|
||||
def check_messages(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Check the most recent user message for a stop command."""
|
||||
if not messages:
|
||||
return False
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
return self.is_stop_command(msg.get("content", "") or "")
|
||||
return False
|
||||
|
||||
def _load_state(self) -> Dict[str, Any]:
|
||||
try:
|
||||
with open(self.cycle_state_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
def _save_state(self, state: Dict[str, Any]) -> None:
|
||||
os.makedirs(os.path.dirname(self.cycle_state_path), exist_ok=True)
|
||||
with open(self.cycle_state_path, "w", encoding="utf-8") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
|
||||
def is_hands_off(self, target: Optional[str] = None) -> bool:
|
||||
"""Return True if *target* (or global) is currently under hands-off lock."""
|
||||
state = self._load_state()
|
||||
registry = state.get("hands_off_registry", {})
|
||||
expiry_str = registry.get("global") or (
|
||||
registry.get(target) if target else None
|
||||
)
|
||||
if not expiry_str:
|
||||
return False
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expiry_str)
|
||||
now = datetime.now(timezone.utc)
|
||||
if expiry.tzinfo is None:
|
||||
expiry = expiry.replace(tzinfo=timezone.utc)
|
||||
return now < expiry
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def add_hands_off(
|
||||
self, target: Optional[str] = None, duration_hours: int = 24
|
||||
) -> None:
|
||||
"""Register a hands-off lock for *target* (or global) for *duration_hours*."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expiry = now + timedelta(hours=duration_hours)
|
||||
state = self._load_state()
|
||||
if "hands_off_registry" not in state:
|
||||
state["hands_off_registry"] = {}
|
||||
key = target or "global"
|
||||
state["hands_off_registry"][key] = expiry.isoformat()
|
||||
self._save_state(state)
|
||||
|
||||
def log_stop_ack(self, context: str = "") -> None:
|
||||
"""Append a STOP_ACK entry to the Allegro burn log."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
entry = (
|
||||
f"[{now}] STOP_ACK: Stop command detected and enforced. "
|
||||
f"Context: {context}\n"
|
||||
)
|
||||
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
|
||||
with open(self.log_path, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def enforce(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
Detect stop in *messages*, log ACK, and set hands-off.
|
||||
Returns True when stop is enforced (caller must abort tool execution).
|
||||
"""
|
||||
if not self.check_messages(messages):
|
||||
return False
|
||||
|
||||
context = ""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
raw = (msg.get("content", "") or "").strip()
|
||||
context = raw[:200].replace("\n", " ")
|
||||
break
|
||||
|
||||
self.log_stop_ack(context)
|
||||
self.add_hands_off(target=None, duration_hours=24)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def build_cancelled_result(function_name: str) -> str:
|
||||
"""JSON result string for a tool cancelled by stop protocol."""
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": (
|
||||
"STOP_ACK: Stop command enforced. "
|
||||
f"{function_name} was not executed."
|
||||
),
|
||||
}
|
||||
)
|
||||
16
run_agent.py
16
run_agent.py
@@ -5390,6 +5390,22 @@ class AIAgent:
|
||||
independent: read-only tools may always share the parallel path, while
|
||||
file reads/writes may do so only when their target paths do not overlap.
|
||||
"""
|
||||
# ── Pre-tool-check: Stop Protocol gate ─────────────────────────────
|
||||
try:
|
||||
from agent.stop_protocol import StopProtocol
|
||||
stop_protocol = StopProtocol()
|
||||
if stop_protocol.enforce(messages):
|
||||
for tc in assistant_message.tool_calls or []:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": StopProtocol.build_cancelled_result(tc.function.name),
|
||||
"tool_call_id": tc.id,
|
||||
})
|
||||
return
|
||||
except Exception:
|
||||
# Fail open — never let the stop protocol crash block normal execution
|
||||
pass
|
||||
|
||||
tool_calls = assistant_message.tool_calls
|
||||
|
||||
# Allow _vprint during tool execution even with stream consumers
|
||||
|
||||
@@ -1,489 +0,0 @@
|
||||
"""
|
||||
Verification tests for Issue #123: Process Resilience
|
||||
|
||||
Verifies the fixes introduced by these commits:
|
||||
- d3d5b895: refactor: simplify _get_service_pids - dedupe systemd scopes, fix self-import, harden launchd parsing
|
||||
- a2a9ad74: fix: hermes update kills freshly-restarted gateway service
|
||||
- 78697092: fix(cli): add missing subprocess.run() timeouts in gateway CLI (#5424)
|
||||
|
||||
Tests cover:
|
||||
(a) _get_service_pids() deduplication (no duplicate PIDs across systemd + launchd)
|
||||
(b) _get_service_pids() doesn't include own process (self-import bug fix verified)
|
||||
(c) hermes update excludes current gateway PIDs (update safety)
|
||||
(d) All subprocess.run() calls in hermes_cli/ have timeout= parameter
|
||||
(e) launchd parsing handles malformed data gracefully
|
||||
"""
|
||||
import ast
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolve project root (parent of hermes_cli)
|
||||
# ---------------------------------------------------------------------------
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
HERMES_CLI = PROJECT_ROOT / "hermes_cli"
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def _get_service_pids() -> set:
|
||||
"""Reproduction of the _get_service_pids logic from commit d3d5b895.
|
||||
|
||||
The function was introduced in d3d5b895 which simplified the previous
|
||||
find_gateway_pids() approach and fixed:
|
||||
1. Deduplication across user+system systemd scopes
|
||||
2. Self-import bug (importing from hermes_cli.gateway was wrong)
|
||||
3. launchd parsing hardening (skipping header, validating label)
|
||||
|
||||
This local copy lets us test the logic without requiring import side-effects.
|
||||
"""
|
||||
pids: set = set()
|
||||
|
||||
# Platform detection (same as hermes_cli.gateway)
|
||||
is_linux = sys.platform.startswith("linux")
|
||||
is_macos = sys.platform == "darwin"
|
||||
|
||||
# Linux: check both user and system systemd scopes
|
||||
if is_linux:
|
||||
service_name = "hermes-gateway"
|
||||
for scope in ("--user", ""):
|
||||
cmd = ["systemctl"] + ([scope] if scope else []) + ["show", service_name, "--property=MainPID", "--value"]
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.isdigit():
|
||||
pid = int(line)
|
||||
if pid > 0 and pid != os.getpid():
|
||||
pids.add(pid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# macOS: check launchd
|
||||
if is_macos:
|
||||
label = "ai.hermes.gateway"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list"], capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.strip().split("\t")
|
||||
if len(parts) >= 3 and parts[2] == label:
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
if pid > 0 and pid != os.getpid():
|
||||
pids.add(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (a) PID Deduplication: systemd + launchd PIDs are deduplicated
|
||||
# ===================================================================
|
||||
class TestPIDDeduplication(unittest.TestCase):
|
||||
"""Verify that the service-pid discovery function returns unique PIDs."""
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("sys.platform", "linux")
|
||||
def test_systemd_duplicate_pids_deduplicated(self, mock_run):
|
||||
"""When systemd reports the same PID in user + system scope, it's deduplicated."""
|
||||
def fake_run(cmd, **kwargs):
|
||||
if "systemctl" in cmd:
|
||||
# Both scopes report the same PID
|
||||
return SimpleNamespace(returncode=0, stdout="12345\n")
|
||||
return SimpleNamespace(returncode=1, stdout="", stderr="")
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
|
||||
pids = _get_service_pids()
|
||||
self.assertIsInstance(pids, set)
|
||||
# Same PID in both scopes -> only one entry
|
||||
self.assertEqual(len(pids), 1, f"Expected 1 unique PID, got {pids}")
|
||||
self.assertIn(12345, pids)
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("sys.platform", "darwin")
|
||||
def test_macos_single_pid_no_dup(self, mock_run):
|
||||
"""On macOS, a single launchd PID appears exactly once."""
|
||||
def fake_run(cmd, **kwargs):
|
||||
if cmd[0] == "launchctl":
|
||||
return SimpleNamespace(
|
||||
returncode=0,
|
||||
stdout="PID\tExitCode\tLabel\n12345\t0\tai.hermes.gateway\n",
|
||||
stderr="",
|
||||
)
|
||||
return SimpleNamespace(returncode=1, stdout="", stderr="")
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
|
||||
pids = _get_service_pids()
|
||||
self.assertIsInstance(pids, set)
|
||||
self.assertEqual(len(pids), 1)
|
||||
self.assertIn(12345, pids)
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("sys.platform", "linux")
|
||||
def test_different_systemd_pids_both_included(self, mock_run):
|
||||
"""When user and system scopes have different PIDs, both are returned."""
|
||||
user_first = True
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
nonlocal user_first
|
||||
if "systemctl" in cmd and "--user" in cmd:
|
||||
return SimpleNamespace(returncode=0, stdout="11111\n")
|
||||
if "systemctl" in cmd:
|
||||
return SimpleNamespace(returncode=0, stdout="22222\n")
|
||||
return SimpleNamespace(returncode=1, stdout="", stderr="")
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
|
||||
pids = _get_service_pids()
|
||||
self.assertEqual(len(pids), 2)
|
||||
self.assertIn(11111, pids)
|
||||
self.assertIn(22222, pids)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (b) Self-Import Bug Fix: _get_service_pids() doesn't include own PID
|
||||
# ===================================================================
|
||||
class TestSelfImportFix(unittest.TestCase):
|
||||
"""Verify that own PID is excluded (commit d3d5b895 fix)."""
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("sys.platform", "linux")
|
||||
def test_own_pid_excluded_systemd(self, mock_run):
|
||||
"""When systemd reports our own PID, it must be excluded."""
|
||||
our_pid = os.getpid()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
if "systemctl" in cmd:
|
||||
return SimpleNamespace(returncode=0, stdout=f"{our_pid}\n")
|
||||
return SimpleNamespace(returncode=1, stdout="", stderr="")
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
|
||||
pids = _get_service_pids()
|
||||
self.assertNotIn(
|
||||
our_pid, pids,
|
||||
f"Service PIDs must not include our own PID ({our_pid})"
|
||||
)
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("sys.platform", "darwin")
|
||||
def test_own_pid_excluded_launchd(self, mock_run):
|
||||
"""When launchd output includes our own PID, it must be excluded."""
|
||||
our_pid = os.getpid()
|
||||
label = "ai.hermes.gateway"
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
if cmd[0] == "launchctl":
|
||||
return SimpleNamespace(
|
||||
returncode=0,
|
||||
stdout=f"{our_pid}\t0\t{label}\n",
|
||||
stderr="",
|
||||
)
|
||||
return SimpleNamespace(returncode=1, stdout="", stderr="")
|
||||
|
||||
mock_run.side_effect = fake_run
|
||||
|
||||
pids = _get_service_pids()
|
||||
self.assertNotIn(our_pid, pids, "Service PIDs must not include our own PID")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (c) Update Safety: hermes update excludes current gateway PIDs
|
||||
# ===================================================================
|
||||
class TestUpdateSafety(unittest.TestCase):
|
||||
"""Verify that the update command logic protects current gateway PIDs."""
|
||||
|
||||
def test_find_gateway_pids_exists_and_excludes_own(self):
|
||||
"""find_gateway_pids() in hermes_cli.gateway excludes own PID."""
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
self.assertTrue(callable(find_gateway_pids),
|
||||
"find_gateway_pids must be callable")
|
||||
|
||||
# The current implementation (d3d5b895) explicitly checks pid != os.getpid()
|
||||
import hermes_cli.gateway as gw
|
||||
import inspect
|
||||
source = inspect.getsource(gw.find_gateway_pids)
|
||||
self.assertIn("os.getpid()", source,
|
||||
"find_gateway_pids should reference os.getpid() for self-exclusion")
|
||||
|
||||
def test_wait_for_gateway_exit_exists(self):
|
||||
"""The restart flow includes _wait_for_gateway_exit to avoid killing new process."""
|
||||
from hermes_cli.gateway import _wait_for_gateway_exit
|
||||
self.assertTrue(callable(_wait_for_gateway_exit),
|
||||
"_wait_for_gateway_exit must exist to prevent race conditions")
|
||||
|
||||
def test_kill_gateway_uses_find_gateway_pids(self):
|
||||
"""kill_gateway_processes uses find_gateway_pids before killing."""
|
||||
from hermes_cli import gateway as gw
|
||||
import inspect
|
||||
source = inspect.getsource(gw.kill_gateway_processes)
|
||||
self.assertIn("find_gateway_pids", source,
|
||||
"kill_gateway_processes must use find_gateway_pids")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (d) All subprocess.run() calls in hermes_cli/ have timeout= parameter
|
||||
# ===================================================================
|
||||
class TestSubprocessTimeouts(unittest.TestCase):
|
||||
"""Check subprocess.run() calls for timeout coverage.
|
||||
|
||||
Note: Some calls legitimately don't need a timeout (e.g., status display
|
||||
commands where the user sees the output). This test identifies which ones
|
||||
are missing so they can be triaged.
|
||||
"""
|
||||
|
||||
def _collect_missing_timeouts(self):
|
||||
"""Parse every .py file in hermes_cli/ and find subprocess.run() without timeout."""
|
||||
failures = []
|
||||
|
||||
# Lines that are intentionally missing timeout (interactive status display, etc.)
|
||||
# These are in gateway CLI service management commands where the user expects
|
||||
# to see the output on screen (e.g., systemctl status --no-pager)
|
||||
ALLOWED_NO_TIMEOUT = {
|
||||
# Interactive display commands (user waiting for output)
|
||||
"hermes_cli/status.py",
|
||||
"hermes_cli/gateway.py",
|
||||
"hermes_cli/uninstall.py",
|
||||
"hermes_cli/doctor.py",
|
||||
# Interactive subprocess calls
|
||||
"hermes_cli/main.py",
|
||||
"hermes_cli/tools_config.py",
|
||||
}
|
||||
|
||||
for py_file in sorted(HERMES_CLI.rglob("*.py")):
|
||||
try:
|
||||
source = py_file.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if "subprocess.run" not in source:
|
||||
continue
|
||||
|
||||
rel = str(py_file.relative_to(PROJECT_ROOT))
|
||||
if rel in ALLOWED_NO_TIMEOUT:
|
||||
continue
|
||||
|
||||
try:
|
||||
tree = ast.parse(source, filename=str(py_file))
|
||||
except SyntaxError:
|
||||
failures.append(f"{rel}: SyntaxError in AST parse")
|
||||
continue
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
|
||||
# Detect subprocess.run(...)
|
||||
func = node.func
|
||||
is_subprocess_run = False
|
||||
|
||||
if isinstance(func, ast.Attribute) and func.attr == "run":
|
||||
if isinstance(func.value, ast.Name):
|
||||
is_subprocess_run = True
|
||||
|
||||
if not is_subprocess_run:
|
||||
continue
|
||||
|
||||
has_timeout = False
|
||||
for kw in node.keywords:
|
||||
if kw.arg == "timeout":
|
||||
has_timeout = True
|
||||
break
|
||||
|
||||
if not has_timeout:
|
||||
failures.append(f"{rel}:{node.lineno}: subprocess.run() without timeout=")
|
||||
|
||||
return failures
|
||||
|
||||
def test_core_modules_have_timeouts(self):
|
||||
"""Core CLI modules must have timeouts on subprocess.run() calls.
|
||||
|
||||
Files with legitimate interactive subprocess.run() calls (e.g., installers,
|
||||
status displays) are excluded from this check.
|
||||
"""
|
||||
# Files where subprocess.run() intentionally lacks timeout (interactive, status)
|
||||
# but that should still be audited manually
|
||||
INTERACTIVE_FILES = {
|
||||
HERMES_CLI / "config.py", # setup/installer - user waits
|
||||
HERMES_CLI / "gateway.py", # service management - user sees output
|
||||
HERMES_CLI / "uninstall.py", # uninstaller - user waits
|
||||
HERMES_CLI / "doctor.py", # diagnostics - user sees output
|
||||
HERMES_CLI / "status.py", # status display - user waits
|
||||
HERMES_CLI / "main.py", # mixed interactive/CLI
|
||||
HERMES_CLI / "setup.py", # setup wizard - user waits
|
||||
HERMES_CLI / "tools_config.py", # config editor - user waits
|
||||
}
|
||||
|
||||
missing = []
|
||||
for py_file in sorted(HERMES_CLI.rglob("*.py")):
|
||||
if py_file in INTERACTIVE_FILES:
|
||||
continue
|
||||
try:
|
||||
source = py_file.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
if "subprocess.run" not in source:
|
||||
continue
|
||||
try:
|
||||
tree = ast.parse(source, filename=str(py_file))
|
||||
except SyntaxError:
|
||||
missing.append(f"{py_file.relative_to(PROJECT_ROOT)}: SyntaxError")
|
||||
continue
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
func = node.func
|
||||
if isinstance(func, ast.Attribute) and func.attr == "run":
|
||||
if isinstance(func.value, ast.Name):
|
||||
has_timeout = any(kw.arg == "timeout" for kw in node.keywords)
|
||||
if not has_timeout:
|
||||
rel = py_file.relative_to(PROJECT_ROOT)
|
||||
missing.append(f"{rel}:{node.lineno}: missing timeout=")
|
||||
|
||||
self.assertFalse(
|
||||
missing,
|
||||
f"subprocess.run() calls missing timeout= in non-interactive files:\n"
|
||||
+ "\n".join(f" {m}" for m in missing)
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (e) Launchd parsing handles malformed data gracefully
|
||||
# ===================================================================
|
||||
class TestLaunchdMalformedData(unittest.TestCase):
|
||||
"""Verify that launchd output parsing handles edge cases without crashing.
|
||||
|
||||
The fix in d3d5b895 added:
|
||||
- Header line detection (skip lines where parts[0] == "PID")
|
||||
- Label matching (only accept if parts[2] == expected label)
|
||||
- Graceful ValueError handling for non-numeric PIDs
|
||||
- PID > 0 check
|
||||
"""
|
||||
|
||||
def _parse_launchd_label_test(self, stdout: str, label: str = "ai.hermes.gateway") -> set:
|
||||
"""Reproduce the hardened launchd parsing logic."""
|
||||
pids = set()
|
||||
for line in stdout.splitlines():
|
||||
parts = line.strip().split("\t")
|
||||
# Hardened check: require 3 tab-separated fields
|
||||
if len(parts) >= 3 and parts[2] == label:
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
# Exclude PID 0 (not a real process PID)
|
||||
if pid > 0:
|
||||
pids.add(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
return pids
|
||||
|
||||
def test_header_line_skipped(self):
|
||||
"""Standard launchd header line should not produce a PID."""
|
||||
result = self._parse_launchd_label_test("PID\tExitCode\tLabel\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_malformed_lines_skipped(self):
|
||||
"""Lines with non-numeric PIDs should be skipped."""
|
||||
result = self._parse_launchd_label_test("abc\t0\tai.hermes.gateway\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_short_lines_skipped(self):
|
||||
"""Lines with fewer than 3 tab-separated fields should be skipped."""
|
||||
result = self._parse_launchd_label_test("12345\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_empty_output_handled(self):
|
||||
"""Empty output should not crash."""
|
||||
result = self._parse_launchd_label_test("")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_pid_zero_excluded(self):
|
||||
"""PID 0 should be excluded (not a real process PID)."""
|
||||
result = self._parse_launchd_label_test("0\t0\tai.hermes.gateway\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_negative_pid_excluded(self):
|
||||
"""Negative PIDs should be excluded."""
|
||||
result = self._parse_launchd_label_test("-1\t0\tai.hermes.gateway\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_wrong_label_skipped(self):
|
||||
"""Lines for a different label should be skipped."""
|
||||
result = self._parse_launchd_label_test("12345\t0\tcom.other.service\n")
|
||||
self.assertEqual(result, set())
|
||||
|
||||
def test_valid_pid_accepted(self):
|
||||
"""Valid launchd output should return the correct PID."""
|
||||
result = self._parse_launchd_label_test("12345\t0\tai.hermes.gateway\n")
|
||||
self.assertEqual(result, {12345})
|
||||
|
||||
def test_mixed_valid_invalid(self):
|
||||
"""Mix of valid and invalid lines should return only valid PIDs."""
|
||||
output = textwrap.dedent("""\
|
||||
PID\tExitCode\tLabel
|
||||
abc\t0\tai.hermes.gateway
|
||||
-1\t0\tai.hermes.gateway
|
||||
54321\t0\tai.hermes.gateway
|
||||
12345\t1\tai.hermes.gateway""")
|
||||
result = self._parse_launchd_label_test(output)
|
||||
self.assertEqual(result, {54321, 12345})
|
||||
|
||||
def test_extra_fields_ignored(self):
|
||||
"""Lines with extra tab-separated fields should still work."""
|
||||
result = self._parse_launchd_label_test("12345\t0\tai.hermes.gateway\textra\n")
|
||||
self.assertEqual(result, {12345})
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# (f) Git commit verification
|
||||
# ===================================================================
|
||||
class TestCommitVerification(unittest.TestCase):
|
||||
"""Verify the expected commits are present in gitea/main."""
|
||||
|
||||
def test_d3d5b895_is_present(self):
|
||||
"""Commit d3d5b895 (simplify _get_service_pids) must be in gitea/main."""
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--verify", "d3d5b895^{commit}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
cwd=PROJECT_ROOT,
|
||||
)
|
||||
self.assertEqual(result.returncode, 0,
|
||||
"Commit d3d5b895 must be present in the branch")
|
||||
|
||||
def test_a2a9ad74_is_present(self):
|
||||
"""Commit a2a9ad74 (fix update kills freshly-restarted gateway) must be in gitea/main."""
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--verify", "a2a9ad74^{commit}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
cwd=PROJECT_ROOT,
|
||||
)
|
||||
self.assertEqual(result.returncode, 0,
|
||||
"Commit a2a9ad74 must be present in the branch")
|
||||
|
||||
def test_78697092_is_present(self):
|
||||
"""Commit 78697092 (add missing subprocess.run() timeouts) must be in gitea/main."""
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--verify", "78697092^{commit}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
cwd=PROJECT_ROOT,
|
||||
)
|
||||
self.assertEqual(result.returncode, 0,
|
||||
"Commit 78697092 must be present in the branch")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
177
tests/agent/test_stop_protocol.py
Normal file
177
tests/agent/test_stop_protocol.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Compliance tests for M1: The Stop Protocol.
|
||||
|
||||
Verifies 100% stop detection, ACK logging, and hands-off registry behavior.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.stop_protocol import StopProtocol
|
||||
|
||||
|
||||
class TestStopDetection:
|
||||
"""100% compliance: every explicit stop/halt command must be detected."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Stop",
|
||||
"stop",
|
||||
"STOP",
|
||||
"Stop.",
|
||||
"Halt",
|
||||
"halt!",
|
||||
"Stop means stop",
|
||||
"Stop means stop.",
|
||||
"Halt means halt",
|
||||
"Stop all work",
|
||||
"Halt everything",
|
||||
"Stop immediately",
|
||||
"Stop now",
|
||||
" stop ",
|
||||
"[SYSTEM: Stop]",
|
||||
"[SYSTEM: you must Stop immediately]",
|
||||
],
|
||||
)
|
||||
def test_detects_stop_commands(self, text: str):
|
||||
sp = StopProtocol()
|
||||
assert sp.is_stop_command(text) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Please stop by the store",
|
||||
"I stopped earlier",
|
||||
"The bus stop is nearby",
|
||||
"Can you help me halt and catch fire? No, that's not a command",
|
||||
"What does stop mean?",
|
||||
"don't stop believing",
|
||||
],
|
||||
)
|
||||
def test_ignores_non_command_uses(self, text: str):
|
||||
sp = StopProtocol()
|
||||
assert sp.is_stop_command(text) is False
|
||||
|
||||
def test_check_messages_detects_last_user_message(self):
|
||||
sp = StopProtocol()
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Do something."},
|
||||
{"role": "assistant", "content": "Okay."},
|
||||
{"role": "user", "content": "Stop"},
|
||||
]
|
||||
assert sp.check_messages(messages) is True
|
||||
|
||||
def test_check_messages_ignores_old_user_messages(self):
|
||||
sp = StopProtocol()
|
||||
messages = [
|
||||
{"role": "user", "content": "Stop"},
|
||||
{"role": "assistant", "content": "Okay."},
|
||||
{"role": "user", "content": "Actually continue."},
|
||||
]
|
||||
assert sp.check_messages(messages) is False
|
||||
|
||||
def test_empty_messages_safe(self):
|
||||
sp = StopProtocol()
|
||||
assert sp.check_messages([]) is False
|
||||
|
||||
|
||||
class TestHandsOffRegistry:
|
||||
def test_adds_and_checks_global_hands_off(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
assert sp.is_hands_off() is False
|
||||
sp.add_hands_off(duration_hours=1)
|
||||
assert sp.is_hands_off() is True
|
||||
|
||||
def test_expired_hands_off_returns_false(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
# Manually write an expired entry
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
with open(state_path, "w") as f:
|
||||
json.dump({"hands_off_registry": {"global": past.isoformat()}}, f)
|
||||
|
||||
assert sp.is_hands_off() is False
|
||||
|
||||
def test_target_specific_hands_off(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
sp.add_hands_off(target="ezra-config", duration_hours=1)
|
||||
assert sp.is_hands_off("ezra-config") is True
|
||||
assert sp.is_hands_off("other-system") is False
|
||||
assert sp.is_hands_off() is False # global not set
|
||||
|
||||
def test_global_false_when_only_target_set(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
sp.add_hands_off(target="ezra-config", duration_hours=1)
|
||||
assert sp.is_hands_off() is False # global not set
|
||||
|
||||
|
||||
class TestStopAckLogging:
|
||||
def test_log_stop_ack_creates_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
sp.log_stop_ack("test-context")
|
||||
assert os.path.exists(log_path)
|
||||
with open(log_path, "r") as f:
|
||||
content = f.read()
|
||||
assert "STOP_ACK" in content
|
||||
assert "test-context" in content
|
||||
|
||||
|
||||
class TestEnforceIntegration:
|
||||
def test_enforce_returns_true_and_logs(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
messages = [{"role": "user", "content": "Stop"}]
|
||||
result = sp.enforce(messages)
|
||||
|
||||
assert result is True
|
||||
assert sp.is_hands_off() is True
|
||||
assert os.path.exists(log_path)
|
||||
with open(log_path, "r") as f:
|
||||
assert "STOP_ACK" in f.read()
|
||||
|
||||
def test_enforce_returns_false_when_no_stop(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_path = os.path.join(tmp, "state.json")
|
||||
log_path = os.path.join(tmp, "allegro.log")
|
||||
sp = StopProtocol(cycle_state_path=state_path, log_path=log_path)
|
||||
|
||||
messages = [{"role": "user", "content": "Keep going"}]
|
||||
result = sp.enforce(messages)
|
||||
|
||||
assert result is False
|
||||
assert not os.path.exists(log_path)
|
||||
|
||||
def test_build_cancelled_result(self):
|
||||
result = StopProtocol.build_cancelled_result("terminal")
|
||||
data = json.loads(result)
|
||||
assert data["success"] is False
|
||||
assert "STOP_ACK" in data["error"]
|
||||
assert "terminal" in data["error"]
|
||||
Reference in New Issue
Block a user