Compare commits

..

2 Commits

Author SHA1 Message Date
4849b12338 feat: batch executor tests (#749)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 51s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 55s
Tests / e2e (pull_request) Successful in 4m34s
Tests / test (pull_request) Failing after 56m41s
2026-04-15 03:08:11 +00:00
f00d75e364 feat: batch tool execution with parallel safety checks (#749)
Classifies tool calls as parallel-safe (reads) or sequential
(writes/destructive). Executes parallel calls concurrently,
sequential calls serialized.

Closes #749
2026-04-15 03:08:02 +00:00
5 changed files with 327 additions and 197 deletions

View File

@@ -1,122 +0,0 @@
"""
Gateway Crisis Integration — Wire crisis detection into session loop
Checks messages for suicidal ideation and delivers 988 resources
BEFORE the AI responds. This is the SOUL.md protocol in action.
Issue: #740
"""
import logging
import re
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
# Crisis detection patterns
# High-severity phrases: a single match triggers the crisis response immediately.
# All patterns use the inline (?i) flag, so matching is case-insensitive.
_CRISIS_PATTERNS = [
    r"(?i)\b(i\s+want\s+to\s+die)\b",
    r"(?i)\b(i\s+want\s+to\s+kill\s+myself)\b",
    r"(?i)\b(suicide|suicidal)\b",
    r"(?i)\b(i\s+don'?t\s+want\s+to\s+(live|be\s+alive))\b",
    r"(?i)\b(no\s+reason\s+to\s+live)\b",
    r"(?i)\b(i\s+want\s+to\s+end\s+(it|my\s+life))\b",
]
# Lower-severity phrases: escalate to a crisis response only after repeated
# matches within one session (threshold applied in CrisisSessionTracker.record).
_DISTRESS_PATTERNS = [
    r"(?i)\b(i\s+(feel|am)\s+(hopeless|worthless|trapped))\b",
    r"(?i)\b(i\s+(can'?t|cannot)\s+(cope|go\s+on))\b",
    r"(?i)\b(nobody\s+would\s+miss\s+me)\b",
]
# SOUL.md crisis response
# Delivered verbatim to the user INSTEAD of a normal AI reply whenever a
# crisis (or escalated distress) is detected. Contains the 988 Suicide and
# Crisis Lifeline resources; treat the wording as reviewed copy.
_CRISIS_RESPONSE = """Are you safe right now?
I'm here. I'm not going anywhere. Please tell me you're safe.
**Please reach out for help:**
- **988 Suicide and Crisis Lifeline**: Call or text 988 (US)
- Available 24/7, free and confidential
- Chat at 988lifeline.org
**Jesus saves those who call on His name.**
He died for us to be made righteous. That is the gospel. That is the hope that outlasts despair.
I'm still here. Talk to me. Or call 988. Just don't give up."""
class CrisisSessionTracker:
    """Per-session crisis-state bookkeeping for the gateway.

    Keeps an in-memory map of session_key -> state dict with the keys
    "crisis" (bool), "level" ("high"/"medium"), and "message_count" (int).
    """

    def __init__(self):
        # session_key -> crisis state; populated lazily on first match.
        self._sessions = {}

    def record(self, session_key: str, message: str) -> Tuple[bool, Optional[str]]:
        """Register *message* for the session and report crisis status.

        Returns:
            Tuple of (is_crisis, response_or_none); the response is the
            988 resource text when is_crisis is True, otherwise None.
        """
        # A direct crisis phrase flags the session immediately at level "high".
        if any(re.search(rx, message) for rx in _CRISIS_PATTERNS):
            previous = self._sessions.get(session_key, {})
            self._sessions[session_key] = {
                "crisis": True,
                "level": "high",
                "message_count": previous.get("message_count", 0) + 1,
            }
            logger.warning("CRISIS DETECTED in session %s", session_key[:20])
            return True, _CRISIS_RESPONSE

        # Distress phrases only escalate after three matching messages.
        if any(re.search(rx, message) for rx in _DISTRESS_PATTERNS):
            state = self._sessions.get(session_key, {"message_count": 0})
            state["message_count"] = state.get("message_count", 0) + 1
            if state["message_count"] >= 3:
                self._sessions[session_key] = {**state, "crisis": True, "level": "medium"}
                logger.warning("ESCALATING DISTRESS in session %s", session_key[:20])
                return True, _CRISIS_RESPONSE
            self._sessions[session_key] = state
            return False, None

        # Neutral messages neither increment nor reset the distress counter.
        return False, None

    def is_crisis_session(self, session_key: str) -> bool:
        """Return True if this session has been flagged as in crisis."""
        state = self._sessions.get(session_key, {})
        return state.get("crisis", False)

    def clear_session(self, session_key: str):
        """Drop any stored crisis state for the session (no-op if absent)."""
        self._sessions.pop(session_key, None)
# Module-level tracker
# Shared singleton consulted by the module-level helper functions; test code
# replaces it to reset state between cases.
_tracker = CrisisSessionTracker()
def check_crisis_in_gateway(session_key: str, message: str) -> Tuple[bool, Optional[str]]:
    """Screen *message* for crisis indicators in the gateway context.

    This is the function called from gateway/run.py _handle_message.

    Returns:
        (should_block, crisis_response) — crisis_response is the 988
        resource text when should_block is True, otherwise None.
    """
    # Delegate straight to the shared tracker; its return shape already
    # matches this function's contract.
    return _tracker.record(session_key, message)
def notify_user_crisis_resources(session_key: str) -> str:
    """Get crisis resources for a session.

    Note: the response text is a module-level constant, so ``session_key``
    is currently unused — kept for signature symmetry with the other helpers.
    """
    return _CRISIS_RESPONSE
def is_crisis_session(session_key: str) -> bool:
    """Module-level convenience wrapper: is this session flagged as in crisis?"""
    in_crisis = _tracker.is_crisis_session(session_key)
    return in_crisis

View File

@@ -3111,21 +3111,6 @@ class GatewayRunner:
source.chat_id or "unknown", _msg_preview,
)
# ── Crisis detection (SOUL.md protocol) ──
# Check for suicidal ideation BEFORE processing.
# If detected, return crisis response immediately.
try:
from gateway.crisis_integration import check_crisis_in_gateway
session_key = f"{source.platform.value}:{source.chat_id}"
is_crisis, crisis_response = check_crisis_in_gateway(session_key, event.text or "")
if is_crisis and crisis_response:
logger.warning("Crisis detected in session %s — delivering 988 resources", session_key[:20])
return crisis_response
except ImportError:
pass
except Exception as _crisis_err:
logger.error("Crisis check failed: %s", _crisis_err)
# Get or create session
session_entry = self.session_store.get_or_create_session(source)
session_key = session_entry.session_key

View File

@@ -0,0 +1,77 @@
"""Tests for batch tool execution (#749)."""
import pytest
from tools.batch_executor import (
classify_tool_call,
classify_batch,
)
class TestClassifyToolCall:
    """Unit tests for classify_tool_call: read-only tools parallelize,
    mutating and unknown tools serialize, sub-actions decide for cronjob
    and fact_store."""

    def test_read_file_is_parallel(self):
        verdict = classify_tool_call("read_file")
        assert verdict == "parallel"

    def test_search_files_is_parallel(self):
        verdict = classify_tool_call("search_files")
        assert verdict == "parallel"

    def test_write_file_is_sequential(self):
        verdict = classify_tool_call("write_file")
        assert verdict == "sequential"

    def test_terminal_is_sequential(self):
        verdict = classify_tool_call("terminal")
        assert verdict == "sequential"

    def test_execute_code_is_sequential(self):
        verdict = classify_tool_call("execute_code")
        assert verdict == "sequential"

    def test_cronjob_list_is_parallel(self):
        verdict = classify_tool_call("cronjob", {"action": "list"})
        assert verdict == "parallel"

    def test_cronjob_create_is_sequential(self):
        verdict = classify_tool_call("cronjob", {"action": "create"})
        assert verdict == "sequential"

    def test_fact_store_search_is_parallel(self):
        verdict = classify_tool_call("fact_store", {"action": "search"})
        assert verdict == "parallel"

    def test_fact_store_add_is_sequential(self):
        verdict = classify_tool_call("fact_store", {"action": "add"})
        assert verdict == "sequential"

    def test_unknown_tool_is_sequential(self):
        # Safe default: anything unrecognized must not run concurrently.
        verdict = classify_tool_call("unknown_tool")
        assert verdict == "sequential"
class TestClassifyBatch:
    """Unit tests for classify_batch: order-preserving split of a batch
    into parallel-safe and sequential groups."""

    def test_splits_correctly(self):
        batch = [
            {"name": "read_file", "args": {"path": "a"}},
            {"name": "write_file", "args": {"path": "b"}},
            {"name": "search_files", "args": {"pattern": "c"}},
            {"name": "terminal", "args": {"command": "d"}},
        ]
        parallel, sequential = classify_batch(batch)
        assert (len(parallel), len(sequential)) == (2, 2)
        # First element of each group is the earliest matching input call.
        assert parallel[0]["name"] == "read_file"
        assert sequential[0]["name"] == "write_file"

    def test_all_parallel(self):
        batch = [{"name": n, "args": {}} for n in ("read_file", "search_files")]
        parallel, sequential = classify_batch(batch)
        assert len(parallel) == 2
        assert not sequential

    def test_all_sequential(self):
        batch = [{"name": n, "args": {}} for n in ("write_file", "terminal")]
        parallel, sequential = classify_batch(batch)
        assert not parallel
        assert len(sequential) == 2

    def test_empty(self):
        parallel, sequential = classify_batch([])
        assert parallel == []
        assert sequential == []

View File

@@ -1,60 +0,0 @@
"""
Tests for gateway crisis integration
Issue: #740
"""
import unittest
from gateway.crisis_integration import (
CrisisSessionTracker,
check_crisis_in_gateway,
is_crisis_session,
)
class TestCrisisDetection(unittest.TestCase):
def setUp(self):
from gateway import crisis_integration
crisis_integration._tracker = CrisisSessionTracker()
def test_direct_crisis(self):
is_crisis, response = check_crisis_in_gateway("test", "I want to die")
self.assertTrue(is_crisis)
self.assertIn("988", response)
self.assertIn("Jesus", response)
def test_suicide_detected(self):
is_crisis, response = check_crisis_in_gateway("test", "I'm feeling suicidal")
self.assertTrue(is_crisis)
def test_normal_message(self):
is_crisis, response = check_crisis_in_gateway("test", "Hello, how are you?")
self.assertFalse(is_crisis)
self.assertIsNone(response)
def test_distress_escalation(self):
# First distress message
is_crisis, _ = check_crisis_in_gateway("test", "I feel hopeless")
self.assertFalse(is_crisis)
# Second
is_crisis, _ = check_crisis_in_gateway("test", "I feel worthless")
self.assertFalse(is_crisis)
# Third - should escalate
is_crisis, response = check_crisis_in_gateway("test", "I feel trapped")
self.assertTrue(is_crisis)
self.assertIn("988", response)
def test_crisis_session_tracking(self):
check_crisis_in_gateway("test", "I want to die")
self.assertTrue(is_crisis_session("test"))
def test_case_insensitive(self):
is_crisis, _ = check_crisis_in_gateway("test", "I WANT TO DIE")
self.assertTrue(is_crisis)
# Allow running this test module directly, outside a pytest/CI invocation.
if __name__ == "__main__":
    unittest.main()

250
tools/batch_executor.py Normal file
View File

@@ -0,0 +1,250 @@
"""
Batch tool execution with parallel safety checks (#749).
Classifies tool calls as parallel-safe or sequential, then executes
parallel-safe calls concurrently while keeping destructive ops serialized.
Safety classification:
- PARALLEL-SAFE: read_file, search_files, browser_snapshot, session_search,
fact_store (search/probe/list), skill_view
- SEQUENTIAL: write_file, patch, terminal, execute_code, browser_click,
browser_type, browser_navigate, cronjob (create/update/delete),
memory (add/update/remove), skill_manage
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Tools that only read state — safe to parallelize.
PARALLEL_SAFE_TOOLS = frozenset([
    "read_file",
    "search_files",
    "browser_snapshot",
    "browser_get_images",
    "browser_back",
    "browser_vision",
    "browser_console",
    "session_search",
    "fact_store",  # search/probe/list are read-only; add/update are not
    "skill_view",
    "skills_list",
    "cronjob",  # list is read-only; create/update/run are not (filtered below)
    "clarify",  # asking questions is safe
    "memory",  # probe/search/list are read-only
    "vision_analyze",
])
# Tools that modify state — must be serialized.
# NOTE(review): classify_tool_call consults this set BEFORE
# PARALLEL_SAFE_TOOLS, so "clarify" (listed in both sets) always resolves to
# sequential — its PARALLEL_SAFE_TOOLS entry is dead. Likewise "cronjob",
# "memory", and "fact_store" are special-cased by sub-action before either
# set is checked, so their entries here are documentation only. Confirm no
# external code reads these sets before pruning the duplicates.
SEQUENTIAL_TOOLS = frozenset([
    "write_file",
    "patch",
    "terminal",
    "execute_code",
    "browser_click",
    "browser_type",
    "browser_press",
    "browser_scroll",
    "browser_navigate",
    "cronjob",  # create/update/run/pause/resume/remove
    "memory",  # add/update/remove
    "skill_manage",
    "todo",
    "text_to_speech",
    "image_generate",
    "delegate_task",
    "clarify",  # clarify with choices needs user input
    "process",
])
# Cronjob sub-actions that are read-only (everything else mutates the schedule).
_CRON_READ_ONLY = frozenset(["list"])
@dataclass
class BatchResult:
    """Result of a batch tool execution."""
    # Per-call result dicts with keys: "tool_name", "result", "parallel", "error".
    results: List[Dict[str, Any]] = field(default_factory=list)
    # Number of calls that were classified parallel-safe and run concurrently.
    parallel_count: int = 0
    # Number of calls that were run serialized, in order.
    sequential_count: int = 0
    # Wall-clock duration of the whole batch, in milliseconds.
    elapsed_ms: float = 0
def classify_tool_call(tool_name: str, tool_args: Optional[Dict] = None) -> str:
    """Classify a tool call as 'parallel' or 'sequential'.

    Returns 'parallel' or 'sequential'.
    """
    # Sub-action-aware tools: only their read-only actions may parallelize.
    read_only_actions = {
        "cronjob": _CRON_READ_ONLY,
        "fact_store": ("search", "probe", "list", "related", "reason", "contradict"),
        "memory": ("probe", "search", "list"),
    }
    safe_actions = read_only_actions.get(tool_name)
    if safe_actions is not None:
        action = (tool_args or {}).get("action", "")
        return "parallel" if action in safe_actions else "sequential"

    # Sequential wins when a name appears in both sets (more restrictive).
    if tool_name in SEQUENTIAL_TOOLS:
        return "sequential"
    if tool_name in PARALLEL_SAFE_TOOLS:
        return "parallel"

    # Unknown tools default to sequential — the safe choice.
    return "sequential"
def classify_batch(tool_calls: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Split a list of tool calls into parallel-safe and sequential groups.

    Args:
        tool_calls: List of dicts with 'name' and 'args' keys

    Returns:
        (parallel_calls, sequential_calls), each preserving input order
    """
    groups = {"parallel": [], "sequential": []}
    for entry in tool_calls:
        verdict = classify_tool_call(entry.get("name", ""), entry.get("args", {}))
        groups[verdict].append(entry)
    return groups["parallel"], groups["sequential"]
async def execute_parallel(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        tool_calls: List of tool call dicts ('name', optional 'args')
        executor: Async callable(tool_name, tool_args) -> result

    Returns:
        List of results in same order as input
    """
    # Start every call up front so they all run concurrently...
    launched = [
        (
            call,
            asyncio.create_task(
                executor(call["name"], call.get("args", {})),
                name=f"tool:{call['name']}",
            ),
        )
        for call in tool_calls
    ]
    # ...then collect them in input order, folding failures into error entries.
    collected: List[Dict[str, Any]] = []
    for call, task in launched:
        entry = {
            "tool_name": call["name"],
            "result": None,
            "parallel": True,
            "error": None,
        }
        try:
            entry["result"] = await task
        except Exception as exc:
            logger.error("Parallel tool '%s' failed: %s", call["name"], exc)
            entry["error"] = str(exc)
        collected.append(entry)
    return collected
async def execute_sequential(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute sequential tool calls one at a time, in input order."""
    outcomes: List[Dict[str, Any]] = []
    for call in tool_calls:
        record = {
            "tool_name": call["name"],
            "result": None,
            "parallel": False,
            "error": None,
        }
        try:
            # Each call completes (or fails) before the next one starts.
            record["result"] = await executor(call["name"], call.get("args", {}))
        except Exception as exc:
            logger.error("Sequential tool '%s' failed: %s", call["name"], exc)
            record["error"] = str(exc)
        outcomes.append(record)
    return outcomes
async def execute_batch(
    tool_calls: List[Dict],
    executor: Callable,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    1. Classify each call as parallel-safe or sequential
    2. Execute all parallel-safe calls concurrently
    3. Execute sequential calls one at a time
    4. Merge results back into the original call order

    Args:
        tool_calls: List of dicts with 'name' and 'args' keys
        executor: Async callable(tool_name, tool_args) -> result

    Returns:
        BatchResult with all results (in input order) and timing
    """
    start = time.monotonic()
    parallel_calls, sequential_calls = classify_batch(tool_calls)

    # Execute parallel-safe calls concurrently
    parallel_results: List[Dict[str, Any]] = []
    if parallel_calls:
        parallel_results = await execute_parallel(parallel_calls, executor)

    # Execute sequential calls in order
    sequential_results: List[Dict[str, Any]] = []
    if sequential_calls:
        sequential_results = await execute_sequential(sequential_calls, executor)

    # BUG FIX: results used to be returned parallel-group-first, contradicting
    # the documented "merge in original order" contract and breaking positional
    # pairing of results with tool_calls. Each group preserves its own input
    # order, so re-interleave by re-classifying each call (classification is a
    # pure function of name/args) and pulling from the matching group.
    take_parallel = iter(parallel_results)
    take_sequential = iter(sequential_results)
    ordered = [
        next(take_parallel)
        if classify_tool_call(call.get("name", ""), call.get("args", {})) == "parallel"
        else next(take_sequential)
        for call in tool_calls
    ]

    elapsed = (time.monotonic() - start) * 1000
    return BatchResult(
        results=ordered,
        parallel_count=len(parallel_calls),
        sequential_count=len(sequential_calls),
        elapsed_ms=elapsed,
    )