feat: enhance interrupt handling and container resource configuration
- Introduced a shared interrupt signaling mechanism to allow tools to check for user interrupts during long-running operations.
- Updated the AIAgent to handle interrupts more effectively, ensuring in-progress tool calls are cancelled and multiple interrupt messages are combined into one prompt.
- Enhanced the CLI configuration to include container resource limits (CPU, memory, disk) and persistence options for Docker, Singularity, and Modal environments.
- Improved documentation to clarify interrupt behavior and container resource settings, giving users better guidance on configuration and usage.
README.md (46 changes)
@@ -361,6 +361,20 @@ Type `/` to see an autocomplete dropdown of all commands.
- `Ctrl+C` — interrupt agent (double-press to force exit)
- `Ctrl+D` — exit

### Interrupting the Agent

**CLI:**

- Type a message + Enter while the agent is working to interrupt and send new instructions
- `Ctrl+C` to interrupt (press twice within 2s to force exit)
- In-progress terminal commands are killed immediately (SIGTERM, then SIGKILL after 1s if the process resists)
- Multiple messages typed during interrupt are combined into one prompt

**Messaging Platforms (Telegram, Discord, Slack):**

- Send any message while the agent is working to interrupt
- Use `/stop` to interrupt without queuing a follow-up message
- Multiple messages sent during interrupt are combined into one prompt
- Interrupt signals are processed with highest priority (before command parsing)

---

## Features
@@ -441,6 +455,30 @@ hermes config set terminal.backend modal

**Sudo Support:** If a command needs sudo, you'll be prompted for your password (cached for the session). Or set `SUDO_PASSWORD` in `~/.hermes/.env`.

**Container Security (Docker, Singularity, Modal):**

All container backends run with security hardening by default:

- Read-only root filesystem (Docker)
- All Linux capabilities dropped
- No privilege escalation (`--security-opt no-new-privileges`)
- PID limits (256 processes)
- Full namespace isolation (`--containall` for Singularity)
- Persistent workspace via volumes, not a writable root layer

**Container Resources:**

Configure CPU, memory, disk, and persistence for all container backends:

```yaml
# In ~/.hermes/config.yaml under terminal:
terminal:
  backend: docker              # or singularity, modal
  container_cpu: 1             # CPU cores (default: 1)
  container_memory: 5120       # Memory in MB (default: 5 GB)
  container_disk: 51200        # Disk in MB (default: 50 GB)
  container_persistent: true   # Persist filesystem across sessions (default: true)
```

When `container_persistent: true`, the sandbox state (installed packages, files, config) survives across sessions. Docker uses named volumes, Singularity uses persistent overlays, and Modal uses filesystem snapshots.

### 🧠 Persistent Memory

Bounded curated memory that persists across sessions:
@@ -1348,6 +1386,14 @@ All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set t

| `MESSAGING_CWD` | Working directory for terminal in messaging (default: `~`) |
| `GATEWAY_ALLOW_ALL_USERS` | Allow all users without allowlist (`true`/`false`, default: `false`) |

**Container Resources (Docker, Singularity, Modal):**

| Variable | Description |
|----------|-------------|
| `TERMINAL_CONTAINER_CPU` | CPU cores for container backends (default: 1) |
| `TERMINAL_CONTAINER_MEMORY` | Memory in MB for container backends (default: 5120) |
| `TERMINAL_CONTAINER_DISK` | Disk in MB for container backends (default: 51200) |
| `TERMINAL_CONTAINER_PERSISTENT` | Persist container filesystem across sessions (default: `true`) |

**Agent Behavior:**

| Variable | Description |
|----------|-------------|
@@ -90,6 +90,14 @@ terminal:
  # timeout: 180
  # lifetime_seconds: 300
  # modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"
  #
  # --- Container resource limits (docker, singularity, modal -- ignored for local/ssh) ---
  # These settings apply to all container backends. They control the resources
  # allocated to the sandbox and whether its filesystem persists across sessions.
  # container_cpu: 1             # CPU cores (default: 1)
  # container_memory: 5120       # Memory in MB (default: 5120 = 5 GB)
  # container_disk: 51200        # Disk in MB (default: 51200 = 50 GB)
  # container_persistent: true   # Persist filesystem across sessions (default: true)

# -----------------------------------------------------------------------------
# SUDO SUPPORT (works with ALL backends above)
cli.py (22 changes)
@@ -225,6 +225,11 @@ def load_cli_config() -> Dict[str, Any]:
     "ssh_user": "TERMINAL_SSH_USER",
     "ssh_port": "TERMINAL_SSH_PORT",
     "ssh_key": "TERMINAL_SSH_KEY",
+    # Container resource config (docker, singularity, modal -- ignored for local/ssh)
+    "container_cpu": "TERMINAL_CONTAINER_CPU",
+    "container_memory": "TERMINAL_CONTAINER_MEMORY",
+    "container_disk": "TERMINAL_CONTAINER_DISK",
+    "container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
     # Sudo support (works with all backends)
     "sudo_password": "SUDO_PASSWORD",
 }
@@ -1807,11 +1812,20 @@ class HermesCLI:
         # nothing can interleave between the box borders.
         _cprint(f"\n{top}\n{response}\n\n{bot}")

-        # If we have a pending message from interrupt, re-queue it for process_loop
-        # instead of recursing (avoids unbounded recursion from rapid interrupts)
+        # Combine all interrupt messages (user may have typed multiple while waiting)
+        # and re-queue as one prompt for process_loop
         if pending_message and hasattr(self, '_pending_input'):
-            print(f"\n📨 Queued: '{pending_message[:50]}{'...' if len(pending_message) > 50 else ''}'")
-            self._pending_input.put(pending_message)
+            all_parts = [pending_message]
+            while not self._interrupt_queue.empty():
+                try:
+                    extra = self._interrupt_queue.get_nowait()
+                    if extra:
+                        all_parts.append(extra)
+                except queue.Empty:
+                    break
+            combined = "\n".join(all_parts)
+            print(f"\n📨 Queued: '{combined[:50]}{'...' if len(combined) > 50 else ''}'")
+            self._pending_input.put(combined)

         return response
@@ -538,6 +538,16 @@ tail -f ~/.hermes/logs/gateway.log
python cli.py --gateway
```

## Interrupting the Agent

Send any message while the agent is working to interrupt it. The message becomes the next prompt after the agent stops. Key behaviors:

- **In-progress terminal commands are killed immediately** -- SIGTERM first, SIGKILL after 1 second if the process resists. Works on local, Docker, SSH, Singularity, and Modal backends.
- **Tool calls are cancelled** -- if the model generated multiple tool calls in one batch, only the currently executing one runs. The rest are skipped.
- **Multiple messages are combined** -- if you send "Stop!" then "Do X instead" while the agent is stopping, both messages are joined into one prompt (separated by newlines).
- **`/stop` command** -- interrupts without queuing a follow-up message.
- **Priority processing** -- interrupt signals bypass command parsing and session creation for minimal latency.
## Storage Locations

| Path | Purpose |
@@ -375,6 +375,24 @@ class GatewayRunner:
         )
         return None

+        # PRIORITY: If an agent is already running for this session, interrupt it
+        # immediately. This is before command parsing to minimize latency -- the
+        # user's "stop" message reaches the agent as fast as possible.
+        _quick_key = (
+            f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}"
+            if source.chat_type != "dm"
+            else f"agent:main:{source.platform.value}:dm"
+        )
+        if _quick_key in self._running_agents:
+            running_agent = self._running_agents[_quick_key]
+            logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
+            running_agent.interrupt(event.text)
+            if _quick_key in self._pending_messages:
+                self._pending_messages[_quick_key] += "\n" + event.text
+            else:
+                self._pending_messages[_quick_key] = event.text
+            return None
+
         # Check for commands
         command = event.get_command()
         if command in ["new", "reset"]:
@@ -427,15 +445,6 @@ class GatewayRunner:
         session_entry = self.session_store.get_or_create_session(source)
         session_key = session_entry.session_key

-        # Check if there's already a running agent for this session
-        if session_key in self._running_agents:
-            running_agent = self._running_agents[session_key]
-            logger.debug("Interrupting running agent for session %s...", session_key[:20])
-            running_agent.interrupt(event.text)
-            # Store the new message to be processed after current agent finishes
-            self._pending_messages[session_key] = event.text
-            return None  # Don't respond yet - let the interrupt handle it
-
         # Build session context
         context = build_session_context(source, self.config, session_entry)
run_agent.py (45 changes)
@@ -50,7 +50,8 @@ else:

 # Import our tool system
 from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
-from tools.terminal_tool import cleanup_vm, set_interrupt_event as _set_terminal_interrupt
+from tools.terminal_tool import cleanup_vm
+from tools.interrupt import set_interrupt as _set_interrupt
 from tools.browser_tool import cleanup_browser

 import requests
@@ -266,6 +267,7 @@ class AIAgent:
             # Primary: OPENROUTER_API_KEY, fallback to direct provider keys
             client_kwargs["api_key"] = os.getenv("OPENROUTER_API_KEY", "")

+        self._client_kwargs = client_kwargs  # stored for rebuilding after interrupt
         try:
             self.client = OpenAI(**client_kwargs)
             if not self.quiet_mode:
@@ -1015,8 +1017,8 @@ class AIAgent:
         """
         self._interrupt_requested = True
         self._interrupt_message = message
-        # Signal the terminal tool to kill any running subprocess immediately
-        _set_terminal_interrupt(True)
+        # Signal all tools to abort any in-flight operations immediately
+        _set_interrupt(True)
         # Propagate interrupt to any running child agents (subagent delegation)
         for child in self._active_children:
             try:
@@ -1061,7 +1063,7 @@ class AIAgent:
                 self._todo_store.write(last_todo_response, merge=False)
                 if not self.quiet_mode:
                     print(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history")
-        _set_terminal_interrupt(False)
+        _set_interrupt(False)

     @property
     def is_interrupted(self) -> bool:
@@ -1148,8 +1150,9 @@ class AIAgent:
         Run the API call in a background thread so the main conversation loop
         can detect interrupts without waiting for the full HTTP round-trip.

-        Returns the API response, or raises InterruptedError if the agent was
-        interrupted while waiting.
+        On interrupt, closes the HTTP client to cancel the in-flight request
+        (stops token generation and avoids wasting money), then rebuilds the
+        client for future calls.
         """
         result = {"response": None, "error": None}

@@ -1161,12 +1164,19 @@ class AIAgent:

         t = threading.Thread(target=_call, daemon=True)
         t.start()
+        # Poll every 0.3s so interrupts are noticed quickly
         while t.is_alive():
             t.join(timeout=0.3)
             if self._interrupt_requested:
-                # Can't cancel the HTTP request cleanly, but we can stop
-                # waiting and let the thread finish in the background.
+                # Force-close the HTTP connection to stop token generation
+                try:
+                    self.client.close()
+                except Exception:
+                    pass
+                # Rebuild the client for future calls (cheap, no network)
+                try:
+                    self.client = OpenAI(**self._client_kwargs)
+                except Exception:
+                    pass
                 raise InterruptedError("Agent interrupted during API call")
         if result["error"] is not None:
             raise result["error"]
@@ -1392,6 +1402,23 @@ class AIAgent:
     def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str) -> None:
         """Execute tool calls from the assistant message and append results to messages."""
         for i, tool_call in enumerate(assistant_message.tool_calls, 1):
+            # SAFETY: check interrupt BEFORE starting each tool.
+            # If the user sent "stop" during a previous tool's execution,
+            # do NOT start any more tools -- skip them all immediately.
+            if self._interrupt_requested:
+                remaining_calls = assistant_message.tool_calls[i-1:]
+                if remaining_calls:
+                    print(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)")
+                for skipped_tc in remaining_calls:
+                    skip_msg = {
+                        "role": "tool",
+                        "content": "[Tool execution cancelled - user interrupted]",
+                        "tool_call_id": skipped_tc.id,
+                    }
+                    messages.append(skip_msg)
+                    self._log_msg_to_db(skip_msg)
+                break
+
             function_name = tool_call.function.name

             # Reset nudge counters when the relevant tool is actually used
tests/test_interrupt.py (221 lines, new file)
@@ -0,0 +1,221 @@
"""Tests for the interrupt system.

Run with: python -m pytest tests/test_interrupt.py -v
"""

import queue
import shutil
import threading
import time

import pytest


# ---------------------------------------------------------------------------
# Unit tests: shared interrupt module
# ---------------------------------------------------------------------------

class TestInterruptModule:
    """Tests for tools/interrupt.py"""

    def test_set_and_check(self):
        from tools.interrupt import set_interrupt, is_interrupted
        set_interrupt(False)
        assert not is_interrupted()

        set_interrupt(True)
        assert is_interrupted()

        set_interrupt(False)
        assert not is_interrupted()

    def test_thread_safety(self):
        """Set from one thread, check from another."""
        from tools.interrupt import set_interrupt, is_interrupted
        set_interrupt(False)

        seen = {"value": False}

        def _checker():
            while not is_interrupted():
                time.sleep(0.01)
            seen["value"] = True

        t = threading.Thread(target=_checker, daemon=True)
        t.start()

        time.sleep(0.05)
        assert not seen["value"]

        set_interrupt(True)
        t.join(timeout=1)
        assert seen["value"]

        set_interrupt(False)


# ---------------------------------------------------------------------------
# Unit tests: pre-tool interrupt check
# ---------------------------------------------------------------------------

class TestPreToolCheck:
    """Verify that _execute_tool_calls skips all tools when interrupted."""

    def test_all_tools_skipped_when_interrupted(self):
        """Mock an interrupted agent and verify no tools execute."""
        from unittest.mock import MagicMock

        # Build a fake assistant_message with 3 tool calls
        tc1 = MagicMock()
        tc1.id = "tc_1"
        tc1.function.name = "terminal"
        tc1.function.arguments = '{"command": "rm -rf /"}'

        tc2 = MagicMock()
        tc2.id = "tc_2"
        tc2.function.name = "terminal"
        tc2.function.arguments = '{"command": "echo hello"}'

        tc3 = MagicMock()
        tc3.id = "tc_3"
        tc3.function.name = "web_search"
        tc3.function.arguments = '{"query": "test"}'

        assistant_msg = MagicMock()
        assistant_msg.tool_calls = [tc1, tc2, tc3]

        messages = []

        # Create a minimal mock agent with _interrupt_requested = True
        agent = MagicMock()
        agent._interrupt_requested = True
        agent.log_prefix = ""
        agent._log_msg_to_db = MagicMock()

        # Bind the real method to our mock
        from run_agent import AIAgent
        AIAgent._execute_tool_calls(agent, assistant_msg, messages, "default")

        # All 3 should be skipped
        assert len(messages) == 3
        for msg in messages:
            assert msg["role"] == "tool"
            assert "cancelled" in msg["content"].lower() or "interrupted" in msg["content"].lower()

        # No actual tool handlers should have been called
        # (handle_function_call should NOT have been invoked)


# ---------------------------------------------------------------------------
# Unit tests: message combining
# ---------------------------------------------------------------------------

class TestMessageCombining:
    """Verify multiple interrupt messages are joined."""

    def test_cli_interrupt_queue_drain(self):
        """Simulate draining multiple messages from the interrupt queue."""
        q = queue.Queue()
        q.put("Stop!")
        q.put("Don't delete anything")
        q.put("Show me what you were going to delete instead")

        parts = []
        while not q.empty():
            try:
                msg = q.get_nowait()
                if msg:
                    parts.append(msg)
            except queue.Empty:
                break

        combined = "\n".join(parts)
        assert "Stop!" in combined
        assert "Don't delete anything" in combined
        assert "Show me what you were going to delete instead" in combined
        assert combined.count("\n") == 2

    def test_gateway_pending_messages_append(self):
        """Simulate gateway _pending_messages append logic."""
        pending = {}
        key = "agent:main:telegram:dm"

        # First message
        if key in pending:
            pending[key] += "\n" + "Stop!"
        else:
            pending[key] = "Stop!"

        # Second message
        if key in pending:
            pending[key] += "\n" + "Do something else instead"
        else:
            pending[key] = "Do something else instead"

        assert pending[key] == "Stop!\nDo something else instead"


# ---------------------------------------------------------------------------
# Integration tests (require local terminal)
# ---------------------------------------------------------------------------

class TestSIGKILLEscalation:
    """Test that SIGTERM-resistant processes get SIGKILL'd."""

    @pytest.mark.skipif(not shutil.which("bash"), reason="Requires bash")
    def test_sigterm_trap_killed_within_2s(self):
        """A process that traps SIGTERM should be SIGKILL'd after a 1s grace period."""
        from tools.interrupt import set_interrupt
        from tools.environments.local import LocalEnvironment

        set_interrupt(False)
        env = LocalEnvironment(cwd="/tmp", timeout=30)

        # Start execution in a thread, interrupt after 0.5s
        result_holder = {"value": None}

        def _run():
            result_holder["value"] = env.execute(
                "trap '' TERM; sleep 60",
                timeout=30,
            )

        t = threading.Thread(target=_run)
        t.start()

        time.sleep(0.5)
        set_interrupt(True)

        t.join(timeout=5)
        set_interrupt(False)

        assert result_holder["value"] is not None
        assert result_holder["value"]["returncode"] == 130
        assert "interrupted" in result_holder["value"]["output"].lower()


# ---------------------------------------------------------------------------
# Manual smoke test checklist (not automated)
# ---------------------------------------------------------------------------

SMOKE_TESTS = """
Manual Smoke Test Checklist:

1. CLI: Run `hermes`, ask it to `sleep 30` in terminal, type "stop" + Enter.
   Expected: command dies within 2s, agent responds to "stop".

2. CLI: Ask it to extract content from 5 URLs, type interrupt mid-way.
   Expected: remaining URLs are skipped, partial results returned.

3. Gateway (Telegram): Send a long task, then send "Stop".
   Expected: agent stops and responds acknowledging the stop.

4. Gateway (Telegram): Send "Stop" then "Do X instead" rapidly.
   Expected: both messages appear as the next prompt (joined by newline).

5. CLI: Start a task that generates 3+ tool calls in one batch.
   Type interrupt during the first tool call.
   Expected: only 1 tool executes, remaining are skipped.
"""
@@ -698,6 +698,10 @@ def _run_browser_command(
     except FileNotFoundError as e:
         return {"success": False, "error": str(e)}

+    from tools.interrupt import is_interrupted
+    if is_interrupted():
+        return {"success": False, "error": "Interrupted"}
+
     # Get session info (creates Browserbase session with proxies if needed)
     try:
         session_info = _get_session_info(task_id)
@@ -1,22 +1,108 @@
-"""Docker execution environment wrapping mini-swe-agent's DockerEnvironment."""
+"""Docker execution environment wrapping mini-swe-agent's DockerEnvironment.
+
+Adds security hardening, configurable resource limits (CPU, memory, disk),
+and optional filesystem persistence via named Docker volumes.
+"""

 import logging
 import os
 import subprocess
 import threading
 import time
 from typing import Optional

 from tools.environments.base import BaseEnvironment
 from tools.interrupt import is_interrupted

 logger = logging.getLogger(__name__)


+# Security flags applied to every container
+_SECURITY_ARGS = [
+    "--read-only",
+    "--cap-drop", "ALL",
+    "--security-opt", "no-new-privileges",
+    "--pids-limit", "256",
+    "--tmpfs", "/tmp:rw,noexec,nosuid,size=512m",
+    "--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
+    "--tmpfs", "/run:rw,noexec,nosuid,size=64m",
+]
+
+
 class DockerEnvironment(BaseEnvironment):
-    """Docker container execution via mini-swe-agent.
+    """Hardened Docker container execution with resource limits and persistence.

-    Wraps the upstream DockerEnvironment and adds non-blocking stdin
-    and sudo -S support.
+    Security: read-only root, all capabilities dropped, no privilege escalation,
+    PID limits, tmpfs for writable scratch. Writable overlay for /home and cwd
+    via tmpfs or bind mounts.
+
+    Persistence: when enabled, the workspace lives on named Docker volumes
+    that survive container removal and are reattached on the next creation.
     """

-    def __init__(self, image: str, cwd: str = "/", timeout: int = 60):
+    def __init__(
+        self,
+        image: str,
+        cwd: str = "/",
+        timeout: int = 60,
+        cpu: float = 0,
+        memory: int = 0,
+        disk: int = 0,
+        persistent_filesystem: bool = False,
+        task_id: str = "default",
+        network: bool = True,
+    ):
         super().__init__(cwd=cwd, timeout=timeout)
+        self._base_image = image
+        self._persistent = persistent_filesystem
+        self._task_id = task_id
+        self._container_id: Optional[str] = None

         from minisweagent.environments.docker import DockerEnvironment as _Docker
-        self._inner = _Docker(image=image, cwd=cwd, timeout=timeout)
+
+        # Build resource limit args
+        resource_args = []
+        if cpu > 0:
+            resource_args.extend(["--cpus", str(cpu)])
+        if memory > 0:
+            resource_args.extend(["--memory", f"{memory}m"])
+        if disk > 0:
+            resource_args.extend(["--storage-opt", f"size={disk}m"])
+        if not network:
+            resource_args.append("--network=none")
+
+        # Persistent volume for writable workspace that survives container restarts.
+        # Non-persistent mode uses tmpfs (ephemeral, fast, gone on cleanup).
+        self._volume_name: Optional[str] = None
+        if self._persistent:
+            self._volume_name = f"hermes-workspace-{task_id}"
+            # Create volume if it doesn't exist
+            subprocess.run(
+                ["docker", "volume", "create", self._volume_name],
+                capture_output=True, timeout=10,
+            )
+            writable_args = [
+                "-v", f"{self._volume_name}:{cwd}",
+                "-v", f"{self._volume_name}-home:/root",
+            ]
+        else:
+            writable_args = [
+                "--tmpfs", f"{cwd}:rw,exec,size=10g",
+                "--tmpfs", "/home:rw,exec,size=1g",
+                "--tmpfs", "/root:rw,exec,size=1g",
+            ]
+
+        # All containers get full security hardening (read-only root + writable
+        # mounts for the workspace). Persistence uses Docker volumes, not
+        # filesystem layer commits, so --read-only is always safe.
+        all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args
+
+        self._inner = _Docker(
+            image=image, cwd=cwd, timeout=timeout,
+            run_args=all_run_args,
+        )
+        self._container_id = self._inner.container_id

     def execute(self, command: str, cwd: str = "", *,
                 timeout: int | None = None,
@@ -38,10 +124,65 @@ class DockerEnvironment(BaseEnvironment):
         cmd.extend([self._inner.container_id, "bash", "-lc", exec_command])

         try:
-            result = subprocess.run(cmd, **self._build_run_kwargs(timeout, stdin_data))
-            return {"output": result.stdout, "returncode": result.returncode}
+            _output_chunks = []
+            proc = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+                stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
+                text=True,
+            )
+            if stdin_data:
+                try:
+                    proc.stdin.write(stdin_data)
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            def _drain():
+                try:
+                    for line in proc.stdout:
+                        _output_chunks.append(line)
+                except Exception:
+                    pass
+
+            reader = threading.Thread(target=_drain, daemon=True)
+            reader.start()
+            deadline = time.monotonic() + effective_timeout
+
+            while proc.poll() is None:
+                if is_interrupted():
+                    proc.terminate()
+                    try:
+                        proc.wait(timeout=1)
+                    except subprocess.TimeoutExpired:
+                        proc.kill()
+                    reader.join(timeout=2)
+                    return {
+                        "output": "".join(_output_chunks) + "\n[Command interrupted]",
+                        "returncode": 130,
+                    }
+                if time.monotonic() > deadline:
+                    proc.kill()
+                    reader.join(timeout=2)
+                    return self._timeout_result(effective_timeout)
+                time.sleep(0.2)
+
+            reader.join(timeout=5)
+            return {"output": "".join(_output_chunks), "returncode": proc.returncode}
         except Exception as e:
             return {"output": f"Docker execution error: {e}", "returncode": 1}

     def cleanup(self):
         """Stop and remove the container. Volumes persist if persistent=True."""
         self._inner.cleanup()

         # If NOT persistent, remove the workspace volumes too
         if not self._persistent and self._volume_name:
             for vol in [self._volume_name, f"{self._volume_name}-home"]:
                 try:
                     subprocess.run(
                         ["docker", "volume", "rm", "-f", vol],
                         capture_output=True, timeout=10,
                     )
                 except Exception:
                     pass
@@ -76,7 +76,12 @@ class LocalEnvironment(BaseEnvironment):
             while proc.poll() is None:
                 if _interrupt_event.is_set():
                     try:
-                        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
+                        pgid = os.getpgid(proc.pid)
+                        os.killpg(pgid, signal.SIGTERM)
+                        try:
+                            proc.wait(timeout=1.0)
+                        except subprocess.TimeoutExpired:
+                            os.killpg(pgid, signal.SIGKILL)
                     except (ProcessLookupError, PermissionError):
                         proc.kill()
                     reader.join(timeout=2)
@@ -1,21 +1,61 @@
-"""Modal cloud execution environment wrapping mini-swe-agent's SwerexModalEnvironment."""
+"""Modal cloud execution environment wrapping mini-swe-agent's SwerexModalEnvironment.
+
+Supports persistent filesystem snapshots: when enabled, the sandbox's filesystem
+is snapshotted on cleanup and restored on the next creation, so installed packages,
+project files, and config changes survive across sessions.
+"""

+import json
 import logging
+import threading
+import time
+import uuid
+from pathlib import Path
 from typing import Any, Dict, Optional

 from tools.environments.base import BaseEnvironment
+from tools.interrupt import is_interrupted

 logger = logging.getLogger(__name__)

+_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
+
+
+def _load_snapshots() -> Dict[str, str]:
+    """Load snapshot ID mapping from disk."""
+    if _SNAPSHOT_STORE.exists():
+        try:
+            return json.loads(_SNAPSHOT_STORE.read_text())
+        except Exception:
+            pass
+    return {}
+
+
+def _save_snapshots(data: Dict[str, str]) -> None:
+    """Persist snapshot ID mapping to disk."""
+    _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
+    _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
+
+
 class ModalEnvironment(BaseEnvironment):
     """Modal cloud execution via mini-swe-agent.

-    Wraps SwerexModalEnvironment and adds sudo -S support.
-    Async-safety patches are applied once before first use so Modal
-    works inside any event loop (Atropos, gateway, etc.).
+    Wraps SwerexModalEnvironment and adds sudo -S support, configurable
+    resources (CPU, memory, disk), and optional filesystem persistence
+    via Modal's snapshot_filesystem() API.
     """

     _patches_applied = False

-    def __init__(self, image: str, cwd: str = "/root", timeout: int = 60):
+    def __init__(
+        self,
+        image: str,
+        cwd: str = "/root",
+        timeout: int = 60,
+        modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
+        persistent_filesystem: bool = True,
+        task_id: str = "default",
+    ):
         super().__init__(cwd=cwd, timeout=timeout)

         if not ModalEnvironment._patches_applied:
@@ -26,10 +66,35 @@ class ModalEnvironment(BaseEnvironment):
                 pass
             ModalEnvironment._patches_applied = True

+        self._persistent = persistent_filesystem
+        self._task_id = task_id
+        self._base_image = image
+
+        sandbox_kwargs = dict(modal_sandbox_kwargs or {})
+
+        # If persistent, try to restore from a previous snapshot
+        restored_image = None
+        if self._persistent:
+            snapshot_id = _load_snapshots().get(self._task_id)
+            if snapshot_id:
+                try:
+                    import modal
+                    restored_image = modal.Image.from_id(snapshot_id)
+                    logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
+                except Exception as e:
+                    logger.warning("Modal: failed to restore snapshot, using base image: %s", e)
+                    restored_image = None
+
+        effective_image = restored_image if restored_image else image
+
         from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
         self._inner = SwerexModalEnvironment(
-            image=image, cwd=cwd, timeout=timeout,
-            startup_timeout=180.0, runtime_timeout=3600.0,
+            image=effective_image,
+            cwd=cwd,
+            timeout=timeout,
+            startup_timeout=180.0,
+            runtime_timeout=3600.0,
+            modal_sandbox_kwargs=sandbox_kwargs,
         )

     def execute(self, command: str, cwd: str = "", *,
@@ -42,8 +107,61 @@ class ModalEnvironment(BaseEnvironment):
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
|
||||
exec_command = self._prepare_command(command)
|
||||
return self._inner.execute(exec_command, cwd=cwd, timeout=timeout)
|
||||
|
||||
# Run in a background thread so we can poll for interrupts
|
||||
result_holder = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
result_holder["value"] = self._inner.execute(exec_command, cwd=cwd, timeout=timeout)
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
try:
|
||||
self._inner.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Modal sandbox terminated]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
if result_holder["error"]:
|
||||
return {"output": f"Modal execution error: {result_holder['error']}", "returncode": 1}
|
||||
return result_holder["value"]
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._persistent:
|
||||
try:
|
||||
sandbox = getattr(self._inner, 'deployment', None)
|
||||
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
|
||||
if sandbox:
|
||||
import asyncio
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
try:
|
||||
snapshot_id = asyncio.run(_snapshot())
|
||||
except RuntimeError:
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
snapshot_id = pool.submit(
|
||||
asyncio.run, _snapshot()
|
||||
).result(timeout=60)
|
||||
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
if hasattr(self._inner, 'stop'):
|
||||
self._inner.stop()
|
||||
|
||||
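The `cleanup()` hunk above uses a pattern worth isolating: `asyncio.run()` works when no event loop is running in the current thread, but raises `RuntimeError` when called from inside one, so the coroutine is re-run on a fresh loop in a worker thread. A minimal sketch of that fallback, with a stand-in coroutine (the real code awaits `sandbox.snapshot_filesystem.aio()`):

```python
import asyncio
import concurrent.futures


async def _snapshot() -> str:
    # Stand-in for the real snapshot coroutine; returns a fake ID.
    await asyncio.sleep(0)
    return "im-fake"


def take_snapshot() -> str:
    try:
        # Fine when this thread has no running event loop.
        return asyncio.run(_snapshot())
    except RuntimeError:
        # Already inside a running loop: asyncio.run() refuses, so run the
        # coroutine on its own loop in a worker thread and block on the result.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, _snapshot()).result(timeout=60)


assert take_snapshot() == "im-fake"      # direct path


async def main():
    return take_snapshot()               # fallback path, called inside a loop


assert asyncio.run(main()) == "im-fake"
```

Blocking on `.result()` from inside a loop stalls that loop for the duration, which is acceptable here because cleanup runs at teardown.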
@@ -1,9 +1,11 @@
 """Singularity/Apptainer persistent container environment.
 
 Also contains the Singularity-specific helpers: scratch dir management,
 Apptainer cache, and SIF image building.
+Security-hardened with --containall, --no-home, capability dropping.
+Supports configurable resource limits and optional filesystem persistence
+via writable overlay directories that survive across sessions.
 """
 
+import json
 import logging
 import os
 import shutil
@@ -12,11 +14,29 @@ import tempfile
 import threading
 import uuid
 from pathlib import Path
 from typing import Any, Dict, Optional
 
 from tools.environments.base import BaseEnvironment
+from tools.interrupt import is_interrupted
 
 logger = logging.getLogger(__name__)
 
+_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
+
+
+def _load_snapshots() -> Dict[str, str]:
+    if _SNAPSHOT_STORE.exists():
+        try:
+            return json.loads(_SNAPSHOT_STORE.read_text())
+        except Exception:
+            pass
+    return {}
+
+
+def _save_snapshots(data: Dict[str, str]) -> None:
+    _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
+    _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
+
+
 # -------------------------------------------------------------------------
 # Singularity helpers (scratch dir, SIF cache, SIF building)
@@ -116,32 +136,77 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
 # -------------------------------------------------------------------------
 
 class SingularityEnvironment(BaseEnvironment):
-    """Persistent Singularity/Apptainer container environment.
+    """Hardened Singularity/Apptainer container with resource limits and persistence.
 
-    Uses ``apptainer instance`` to create a long-running container that persists
-    state across all commands within a task.
+    Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
+    --no-home, writable-tmpfs for scratch space. The container cannot see or modify
+    the host filesystem outside of explicitly bound paths.
+
+    Persistence: when enabled, the writable overlay directory is preserved across
+    sessions so installed packages and files survive cleanup/restore.
     """
 
-    def __init__(self, image: str, cwd: str = "/root", timeout: int = 60):
+    def __init__(
+        self,
+        image: str,
+        cwd: str = "/root",
+        timeout: int = 60,
+        cpu: float = 0,
+        memory: int = 0,
+        disk: int = 0,
+        persistent_filesystem: bool = False,
+        task_id: str = "default",
+    ):
         super().__init__(cwd=cwd, timeout=timeout)
         self.executable = "apptainer" if shutil.which("apptainer") else "singularity"
         self.image = _get_or_build_sif(image, self.executable)
         self.instance_id = f"hermes_{uuid.uuid4().hex[:12]}"
         self._instance_started = False
+        self._persistent = persistent_filesystem
+        self._task_id = task_id
+        self._overlay_dir: Optional[Path] = None
+
+        # Resource limits
+        self._cpu = cpu
+        self._memory = memory
+
+        # Persistent overlay directory
+        if self._persistent:
+            overlay_base = _get_scratch_dir() / "hermes-overlays"
+            overlay_base.mkdir(parents=True, exist_ok=True)
+            self._overlay_dir = overlay_base / f"overlay-{task_id}"
+            self._overlay_dir.mkdir(parents=True, exist_ok=True)
 
         self._start_instance()
 
     def _start_instance(self):
-        cmd = [
-            self.executable, "instance", "start",
-            "--writable-tmpfs", "--containall",
-            str(self.image), self.instance_id,
-        ]
+        cmd = [self.executable, "instance", "start"]
+
+        # Security: full isolation from host
+        cmd.extend(["--containall", "--no-home"])
+
+        # Writable layer
+        if self._persistent and self._overlay_dir:
+            # Persistent writable overlay -- survives across restarts
+            cmd.extend(["--overlay", str(self._overlay_dir)])
+        else:
+            cmd.append("--writable-tmpfs")
+
+        # Resource limits (cgroup-based, may require root or appropriate config)
+        if self._memory > 0:
+            cmd.extend(["--memory", f"{self._memory}M"])
+        if self._cpu > 0:
+            cmd.extend(["--cpus", str(self._cpu)])
+
+        cmd.extend([str(self.image), self.instance_id])
 
         try:
             result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
             if result.returncode != 0:
                 raise RuntimeError(f"Failed to start instance: {result.stderr}")
             self._instance_started = True
-            logger.info("Singularity instance %s started", self.instance_id)
+            logger.info("Singularity instance %s started (persistent=%s)",
+                        self.instance_id, self._persistent)
         except subprocess.TimeoutExpired:
             raise RuntimeError("Instance start timed out")
 
@@ -151,17 +216,63 @@ class SingularityEnvironment(BaseEnvironment):
         if not self._instance_started:
             return {"output": "Instance not started", "returncode": -1}
 
+        effective_timeout = timeout or self.timeout
         cmd = [self.executable, "exec", "--pwd", cwd or self.cwd,
                f"instance://{self.instance_id}",
                "bash", "-c", self._prepare_command(command)]
 
         try:
-            result = subprocess.run(cmd, **self._build_run_kwargs(timeout, stdin_data))
-            return {"output": result.stdout, "returncode": result.returncode}
+            import time as _time
+            _output_chunks = []
+            proc = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+                stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
+                text=True,
+            )
+            if stdin_data:
+                try:
+                    proc.stdin.write(stdin_data)
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            def _drain():
+                try:
+                    for line in proc.stdout:
+                        _output_chunks.append(line)
+                except Exception:
+                    pass
+
+            reader = threading.Thread(target=_drain, daemon=True)
+            reader.start()
+            deadline = _time.monotonic() + effective_timeout
+
+            while proc.poll() is None:
+                if is_interrupted():
+                    proc.terminate()
+                    try:
+                        proc.wait(timeout=1)
+                    except subprocess.TimeoutExpired:
-                        return self._timeout_result(timeout)
+                        proc.kill()
+                    reader.join(timeout=2)
+                    return {
+                        "output": "".join(_output_chunks) + "\n[Command interrupted]",
+                        "returncode": 130,
+                    }
+                if _time.monotonic() > deadline:
+                    proc.kill()
+                    reader.join(timeout=2)
+                    return self._timeout_result(effective_timeout)
+                _time.sleep(0.2)
+
+            reader.join(timeout=5)
+            return {"output": "".join(_output_chunks), "returncode": proc.returncode}
         except Exception as e:
             return {"output": f"Singularity execution error: {e}", "returncode": 1}
 
     def cleanup(self):
+        """Stop the instance. If persistent, the overlay dir survives for next creation."""
         if self._instance_started:
             try:
                 subprocess.run(
@@ -172,3 +283,9 @@ class SingularityEnvironment(BaseEnvironment):
             except Exception as e:
                 logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
             self._instance_started = False
+
+        # Record overlay path for persistence restoration
+        if self._persistent and self._overlay_dir:
+            snapshots = _load_snapshots()
+            snapshots[self._task_id] = str(self._overlay_dir)
+            _save_snapshots(snapshots)
 
@@ -3,9 +3,12 @@
 import logging
 import subprocess
 import tempfile
+import threading
+import time
 from pathlib import Path
 
 from tools.environments.base import BaseEnvironment
+from tools.interrupt import is_interrupted
 
 logger = logging.getLogger(__name__)
 
@@ -16,6 +19,9 @@ class SSHEnvironment(BaseEnvironment):
     Uses SSH ControlMaster for connection persistence so subsequent
     commands are fast. Security benefit: the agent cannot modify its
     own code since execution happens on a separate machine.
+
+    Foreground commands are interruptible: the local ssh process is killed
+    and a remote kill is attempted over the ControlMaster socket.
     """
 
     def __init__(self, host: str, user: str, cwd: str = "/tmp",
@@ -65,15 +71,65 @@ class SSHEnvironment(BaseEnvironment):
         work_dir = cwd or self.cwd
         exec_command = self._prepare_command(command)
         wrapped = f'cd {work_dir} && {exec_command}'
+        effective_timeout = timeout or self.timeout
 
         cmd = self._build_ssh_command()
         cmd.extend(["bash", "-c", wrapped])
 
         try:
-            result = subprocess.run(cmd, **self._build_run_kwargs(timeout, stdin_data))
-            return {"output": result.stdout, "returncode": result.returncode}
+            kwargs = self._build_run_kwargs(timeout, stdin_data)
+            # Remove timeout from kwargs -- we handle it in the poll loop
+            kwargs.pop("timeout", None)
+
+            _output_chunks = []
+
+            proc = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
+                text=True,
+            )
+
+            if stdin_data:
+                try:
+                    proc.stdin.write(stdin_data)
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            def _drain():
+                try:
+                    for line in proc.stdout:
+                        _output_chunks.append(line)
+                except Exception:
+                    pass
+
+            reader = threading.Thread(target=_drain, daemon=True)
+            reader.start()
+            deadline = time.monotonic() + effective_timeout
+
+            while proc.poll() is None:
+                if is_interrupted():
+                    proc.terminate()
+                    try:
+                        proc.wait(timeout=1)
+                    except subprocess.TimeoutExpired:
-                        return self._timeout_result(timeout)
+                        proc.kill()
+                    reader.join(timeout=2)
+                    return {
+                        "output": "".join(_output_chunks) + "\n[Command interrupted]",
+                        "returncode": 130,
+                    }
+                if time.monotonic() > deadline:
+                    proc.kill()
+                    reader.join(timeout=2)
+                    return self._timeout_result(effective_timeout)
+                time.sleep(0.2)
+
+            reader.join(timeout=5)
+            return {"output": "".join(_output_chunks), "returncode": proc.returncode}
+
         except Exception as e:
             return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
 
tools/interrupt.py — 28 lines (new file)
@@ -0,0 +1,28 @@
+"""Shared interrupt signaling for all tools.
+
+Provides a global threading.Event that any tool can check to determine
+if the user has requested an interrupt. The agent's interrupt() method
+sets this event, and tools poll it during long-running operations.
+
+Usage in tools:
+    from tools.interrupt import is_interrupted
+    if is_interrupted():
+        return {"output": "[interrupted]", "returncode": 130}
+"""
+
+import threading
+
+_interrupt_event = threading.Event()
+
+
+def set_interrupt(active: bool) -> None:
+    """Called by the agent to signal or clear the interrupt."""
+    if active:
+        _interrupt_event.set()
+    else:
+        _interrupt_event.clear()
+
+
+def is_interrupted() -> bool:
+    """Check if an interrupt has been requested. Safe to call from any thread."""
+    return _interrupt_event.is_set()
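The module is just a process-global `threading.Event` behind two functions: the agent flips it, tools poll it between units of work. A runnable sketch of that handshake (the worker loop and timings are illustrative, not from the module itself):

```python
import threading
import time

_interrupt_event = threading.Event()


def set_interrupt(active: bool) -> None:
    """Signal (True) or clear (False) the shared interrupt."""
    if active:
        _interrupt_event.set()
    else:
        _interrupt_event.clear()


def is_interrupted() -> bool:
    return _interrupt_event.is_set()


results = []


def long_running_tool():
    # A tool checks the shared event between units of work.
    for step in range(50):
        if is_interrupted():
            results.append(("interrupted", step))
            return
        time.sleep(0.01)
    results.append(("done", 50))


worker = threading.Thread(target=long_running_tool)
worker.start()
time.sleep(0.05)
set_interrupt(True)      # what the agent's interrupt() path does
worker.join()
set_interrupt(False)     # cleared before the next agent turn
```

Because `threading.Event` is thread-safe, tools can poll from any worker thread without extra locking; the one caveat is that the agent must remember to clear the event, or every subsequent tool call returns immediately as interrupted.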
@@ -92,6 +92,10 @@ def _handle_send(args):
                 f"Try using a numeric channel ID instead."
             })
 
+        from tools.interrupt import is_interrupted
+        if is_interrupted():
+            return json.dumps({"error": "Interrupted"})
+
         try:
             from gateway.config import load_gateway_config, Platform
             config = load_gateway_config()
 
@@ -49,20 +49,7 @@ logger = logging.getLogger(__name__)
 # The terminal tool polls this during command execution so it can kill
 # long-running subprocesses immediately instead of blocking until timeout.
 # ---------------------------------------------------------------------------
-_interrupt_event = threading.Event()
-
-
-def set_interrupt_event(active: bool) -> None:
-    """Called by the agent to signal or clear the interrupt."""
-    if active:
-        _interrupt_event.set()
-    else:
-        _interrupt_event.clear()
-
-
-def is_interrupted() -> bool:
-    """Check if an interrupt has been requested."""
-    return _interrupt_event.is_set()
+from tools.interrupt import set_interrupt as set_interrupt_event, is_interrupted, _interrupt_event
 
 
 # Add mini-swe-agent to path if not installed
@@ -459,11 +446,18 @@ def _get_env_config() -> Dict[str, Any]:
         "ssh_host": os.getenv("TERMINAL_SSH_HOST", ""),
         "ssh_user": os.getenv("TERMINAL_SSH_USER", ""),
         "ssh_port": int(os.getenv("TERMINAL_SSH_PORT", "22")),
-        "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),  # Path to private key (optional, uses ssh-agent if empty)
+        "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
+        # Container resource config (applies to docker, singularity, modal -- ignored for local/ssh)
+        "container_cpu": float(os.getenv("TERMINAL_CONTAINER_CPU", "1")),
+        "container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")),  # MB (default 5GB)
+        "container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")),  # MB (default 50GB)
+        "container_persistent": os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("true", "1", "yes"),
     }
 
 
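The hunk above parses the `TERMINAL_CONTAINER_*` environment variables with the documented defaults (1 CPU, 5 GB memory, 50 GB disk, persistence on). Extracted into a standalone helper (the function name is ours, for illustration), the parsing behaves like this:

```python
import os


def _container_config() -> dict:
    """Parse TERMINAL_CONTAINER_* variables with the documented defaults."""
    return {
        "container_cpu": float(os.getenv("TERMINAL_CONTAINER_CPU", "1")),
        "container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")),  # MB
        "container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")),     # MB
        # Any of "true"/"1"/"yes" (case-insensitive) enables persistence.
        "container_persistent":
            os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower()
            in ("true", "1", "yes"),
    }


os.environ["TERMINAL_CONTAINER_MEMORY"] = "2048"
os.environ["TERMINAL_CONTAINER_PERSISTENT"] = "no"
cfg = _container_config()
# cfg == {"container_cpu": 1.0, "container_memory": 2048,
#         "container_disk": 51200, "container_persistent": False}
```

Note that any unrecognized value for `TERMINAL_CONTAINER_PERSISTENT` (such as "no" here) simply disables persistence, since only the three truthy spellings are accepted.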
-def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_config: dict = None):
+def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
+                        ssh_config: dict = None, container_config: dict = None,
+                        task_id: str = "default"):
     """
     Create an execution environment from mini-swe-agent.
 
@@ -473,25 +467,49 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_c
         cwd: Working directory
         timeout: Default command timeout
         ssh_config: SSH connection config (for env_type="ssh")
+        container_config: Resource config for container backends (cpu, memory, disk, persistent)
+        task_id: Task identifier for environment reuse and snapshot keying
 
     Returns:
         Environment instance with execute() method
     """
+    cc = container_config or {}
+    cpu = cc.get("container_cpu", 1)
+    memory = cc.get("container_memory", 5120)
+    disk = cc.get("container_disk", 51200)
+    persistent = cc.get("container_persistent", True)
+
     if env_type == "local":
         # Use our custom LocalEnvironment with sudo support and non-blocking stdin
         return _LocalEnvironment(cwd=cwd, timeout=timeout)
 
     elif env_type == "docker":
         # Use custom Docker wrapper with sudo support and non-blocking stdin
-        return _DockerEnvironment(image=image, cwd=cwd, timeout=timeout)
+        return _DockerEnvironment(
+            image=image, cwd=cwd, timeout=timeout,
+            cpu=cpu, memory=memory, disk=disk,
+            persistent_filesystem=persistent, task_id=task_id,
+        )
 
     elif env_type == "singularity":
         # Use custom Singularity environment with better space management
-        return _SingularityEnvironment(image=image, cwd=cwd, timeout=timeout)
+        return _SingularityEnvironment(
+            image=image, cwd=cwd, timeout=timeout,
+            cpu=cpu, memory=memory, disk=disk,
+            persistent_filesystem=persistent, task_id=task_id,
+        )
 
     elif env_type == "modal":
         # Use custom Modal wrapper with sudo support
-        return _ModalEnvironment(image=image, cwd=cwd, timeout=timeout)
+        sandbox_kwargs = {}
+        if cpu > 0:
+            sandbox_kwargs["cpu"] = cpu
+        if memory > 0:
+            sandbox_kwargs["memory"] = memory
+        if disk > 0:
+            sandbox_kwargs["ephemeral_disk"] = disk
+
+        return _ModalEnvironment(
+            image=image, cwd=cwd, timeout=timeout,
+            modal_sandbox_kwargs=sandbox_kwargs,
+            persistent_filesystem=persistent, task_id=task_id,
+        )
 
     elif env_type == "ssh":
         if not ssh_config or not ssh_config.get("host") or not ssh_config.get("user"):
@@ -502,7 +520,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_c
             port=ssh_config.get("port", 22),
             key_path=ssh_config.get("key", ""),
             cwd=cwd,
-            timeout=timeout
+            timeout=timeout,
         )
 
     else:
@@ -830,12 +848,23 @@ def terminal_tool(
             "key": config.get("ssh_key", ""),
         }
 
+        container_config = None
+        if env_type in ("docker", "singularity", "modal"):
+            container_config = {
+                "container_cpu": config.get("container_cpu", 1),
+                "container_memory": config.get("container_memory", 5120),
+                "container_disk": config.get("container_disk", 51200),
+                "container_persistent": config.get("container_persistent", True),
+            }
+
         new_env = _create_environment(
             env_type=env_type,
             image=image,
             cwd=cwd,
             timeout=effective_timeout,
-            ssh_config=ssh_config
+            ssh_config=ssh_config,
+            container_config=container_config,
+            task_id=effective_task_id,
         )
     except ImportError as e:
         return json.dumps({
 
@@ -234,6 +234,10 @@ async def vision_analyze_tool(
         should_cleanup = True
 
     try:
+        from tools.interrupt import is_interrupted
+        if is_interrupted():
+            return json.dumps({"success": False, "error": "Interrupted"})
+
         logger.info("Analyzing image: %s", image_url[:60])
         logger.info("User prompt: %s", user_prompt[:100])
 
@@ -465,11 +465,12 @@ def web_search_tool(query: str, limit: int = 5) -> str:
     }
 
     try:
+        from tools.interrupt import is_interrupted
+        if is_interrupted():
+            return json.dumps({"error": "Interrupted", "success": False})
+
         logger.info("Searching the web for: '%s' (limit: %d)", query, limit)
 
         # Use Firecrawl's v2 search functionality WITHOUT scraping
         # We only want search result metadata, not scraped content
         # Docs: https://docs.firecrawl.dev/features/search
         response = _get_firecrawl_client().search(
             query=query,
             limit=limit
@@ -601,7 +602,12 @@ async def web_extract_tool(
     # Batch scraping adds complexity without much benefit for small numbers of URLs
     results: List[Dict[str, Any]] = []
 
+    from tools.interrupt import is_interrupted as _is_interrupted
     for url in urls:
+        if _is_interrupted():
+            results.append({"url": url, "error": "Interrupted", "title": ""})
+            continue
+
         try:
             logger.info("Scraping: %s", url)
             scrape_result = _get_firecrawl_client().scrape(
@@ -876,7 +882,10 @@ async def web_crawl_tool(
     if instructions:
         logger.info("Instructions parameter ignored (not supported in crawl API)")
 
     # Use the crawl method which waits for completion automatically
+    from tools.interrupt import is_interrupted as _is_int
+    if _is_int():
+        return json.dumps({"error": "Interrupted", "success": False})
+
     try:
         crawl_result = _get_firecrawl_client().crawl(
             url=url,
 