Fixes and refactors enabled by recent updates to main.
This commit is contained in:
@@ -6,6 +6,8 @@ import types
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"
|
||||
|
||||
@@ -25,7 +27,7 @@ def _reset_modules(prefixes: tuple[str, ...]):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
def _install_fake_tools_package():
|
||||
def _install_fake_tools_package(*, credential_mounts=None):
|
||||
_reset_modules(("tools", "agent", "hermes_cli"))
|
||||
|
||||
hermes_cli = types.ModuleType("hermes_cli")
|
||||
@@ -68,6 +70,9 @@ def _install_fake_tools_package():
|
||||
managed_mode=True,
|
||||
)
|
||||
)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: list(credential_mounts or []),
|
||||
)
|
||||
|
||||
return interrupt_event
|
||||
|
||||
@@ -87,6 +92,7 @@ class _FakeResponse:
|
||||
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
|
||||
calls = []
|
||||
poll_count = {"value": 0}
|
||||
@@ -112,7 +118,7 @@ def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
monkeypatch.setattr(managed_modal.time, "sleep", lambda _: None)
|
||||
monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(image="python:3.11")
|
||||
result = env.execute("echo hello")
|
||||
@@ -149,6 +155,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
|
||||
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
interrupt_event = _install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
|
||||
calls = []
|
||||
|
||||
@@ -170,7 +177,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
interrupt_event.set()
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
monkeypatch.setattr(managed_modal.time, "sleep", fake_sleep)
|
||||
monkeypatch.setattr(modal_common.time, "sleep", fake_sleep)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(image="python:3.11")
|
||||
result = env.execute("sleep 30")
|
||||
@@ -190,6 +197,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
@@ -203,7 +211,7 @@ def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeyp
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
monkeypatch.setattr(managed_modal.time, "sleep", lambda _: None)
|
||||
monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(image="python:3.11")
|
||||
result = env.execute("echo hello")
|
||||
@@ -211,3 +219,91 @@ def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeyp
|
||||
|
||||
assert result["returncode"] == 1
|
||||
assert "not found" in result["output"].lower()
|
||||
|
||||
|
||||
def test_managed_modal_create_and_cleanup_preserve_gateway_persistence_fields(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
|
||||
create_payloads = []
|
||||
terminate_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
terminate_payloads.append(json)
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
task_id="task-managed-persist",
|
||||
persistent_filesystem=False,
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert create_payloads == [{
|
||||
"image": "python:3.11",
|
||||
"cwd": "/root",
|
||||
"cpu": 1.0,
|
||||
"memoryMiB": 5120.0,
|
||||
"timeoutMs": 3_600_000,
|
||||
"idleTimeoutMs": 300_000,
|
||||
"persistentFilesystem": False,
|
||||
"logicalKey": "task-managed-persist",
|
||||
}]
|
||||
assert terminate_payloads == [{"snapshotBeforeTerminate": False}]
|
||||
|
||||
|
||||
def test_managed_modal_rejects_host_credential_passthrough():
|
||||
_install_fake_tools_package(
|
||||
credential_mounts=[{
|
||||
"host_path": "/tmp/token.json",
|
||||
"container_path": "/root/.hermes/token.json",
|
||||
}]
|
||||
)
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
|
||||
with pytest.raises(ValueError, match="credential-file passthrough"):
|
||||
managed_modal.ManagedModalEnvironment(image="python:3.11")
|
||||
|
||||
|
||||
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
calls.append((method, url, json, timeout))
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/execs"):
|
||||
return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
|
||||
if method == "GET" and "/execs/" in url:
|
||||
return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
|
||||
if method == "POST" and url.endswith("/cancel"):
|
||||
return _FakeResponse(202, {"status": "cancelling"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
monkeypatch.setattr(modal_common.time, "monotonic", lambda: next(monotonic_values))
|
||||
monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(image="python:3.11")
|
||||
result = env.execute("sleep 30", timeout=2)
|
||||
env.cleanup()
|
||||
|
||||
assert result == {
|
||||
"output": "Managed Modal exec timed out after 2s",
|
||||
"returncode": 124,
|
||||
}
|
||||
assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)
|
||||
|
||||
@@ -87,6 +87,10 @@ def _install_modal_test_modules(
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
|
||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: [],
|
||||
iter_skills_files=lambda: [],
|
||||
)
|
||||
|
||||
from_id_calls: list[str] = []
|
||||
registry_calls: list[tuple[str, list[str] | None]] = []
|
||||
|
||||
@@ -6,12 +6,15 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.modal_common import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
)
|
||||
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,12 +28,20 @@ def _request_timeout_env(name: str, default: float) -> float:
|
||||
return default
|
||||
|
||||
|
||||
class ManagedModalEnvironment(BaseEnvironment):
|
||||
@dataclass(frozen=True)
|
||||
class _ManagedModalExecHandle:
|
||||
exec_id: str
|
||||
|
||||
|
||||
class ManagedModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"""Gateway-owned Modal sandbox with Hermes-compatible execute/cleanup."""
|
||||
|
||||
_CONNECT_TIMEOUT_SECONDS = _request_timeout_env("TERMINAL_MANAGED_MODAL_CONNECT_TIMEOUT_SECONDS", 1.0)
|
||||
_POLL_READ_TIMEOUT_SECONDS = _request_timeout_env("TERMINAL_MANAGED_MODAL_POLL_READ_TIMEOUT_SECONDS", 5.0)
|
||||
_CANCEL_READ_TIMEOUT_SECONDS = _request_timeout_env("TERMINAL_MANAGED_MODAL_CANCEL_READ_TIMEOUT_SECONDS", 5.0)
|
||||
_client_timeout_grace_seconds = 10.0
|
||||
_interrupt_output = "[Command interrupted - Modal sandbox exec cancelled]"
|
||||
_unexpected_error_prefix = "Managed Modal exec failed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -43,6 +54,8 @@ class ManagedModalEnvironment(BaseEnvironment):
|
||||
):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
self._guard_unsupported_credential_passthrough()
|
||||
|
||||
gateway = resolve_managed_tool_gateway("modal")
|
||||
if gateway is None:
|
||||
raise ValueError("Managed Modal requires a configured tool gateway and Nous user token")
|
||||
@@ -56,31 +69,16 @@ class ManagedModalEnvironment(BaseEnvironment):
|
||||
self._create_idempotency_key = str(uuid.uuid4())
|
||||
self._sandbox_id = self._create_sandbox()
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# When a sudo password is present, inject it via a shell-level pipe
|
||||
# (same approach as the direct ModalEnvironment) since the gateway
|
||||
# cannot pipe subprocess stdin directly.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
|
||||
exec_cwd = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
exec_id = str(uuid.uuid4())
|
||||
payload: Dict[str, Any] = {
|
||||
"execId": exec_id,
|
||||
"command": exec_command,
|
||||
"cwd": exec_cwd,
|
||||
"timeoutMs": int(effective_timeout * 1000),
|
||||
"command": prepared.command,
|
||||
"cwd": prepared.cwd,
|
||||
"timeoutMs": int(prepared.timeout * 1000),
|
||||
}
|
||||
if stdin_data is not None:
|
||||
payload["stdinData"] = stdin_data
|
||||
if prepared.stdin_data is not None:
|
||||
payload["stdinData"] = prepared.stdin_data
|
||||
|
||||
try:
|
||||
response = self._request(
|
||||
@@ -90,81 +88,68 @@ class ManagedModalEnvironment(BaseEnvironment):
|
||||
timeout=10,
|
||||
)
|
||||
except Exception as exc:
|
||||
return {
|
||||
"output": f"Managed Modal exec failed: {exc}",
|
||||
"returncode": 1,
|
||||
}
|
||||
return ModalExecStart(
|
||||
immediate_result=self._error_result(f"Managed Modal exec failed: {exc}")
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
return {
|
||||
"output": self._format_error("Managed Modal exec failed", response),
|
||||
"returncode": 1,
|
||||
}
|
||||
return ModalExecStart(
|
||||
immediate_result=self._error_result(
|
||||
self._format_error("Managed Modal exec failed", response)
|
||||
)
|
||||
)
|
||||
|
||||
body = response.json()
|
||||
status = body.get("status")
|
||||
if status in {"completed", "failed", "cancelled", "timeout"}:
|
||||
return {
|
||||
"output": body.get("output", ""),
|
||||
"returncode": body.get("returncode", 1),
|
||||
}
|
||||
return ModalExecStart(
|
||||
immediate_result=self._result(
|
||||
body.get("output", ""),
|
||||
body.get("returncode", 1),
|
||||
)
|
||||
)
|
||||
|
||||
if body.get("execId") != exec_id:
|
||||
return {
|
||||
"output": "Managed Modal exec start did not return the expected exec id",
|
||||
"returncode": 1,
|
||||
}
|
||||
|
||||
poll_interval = 0.25
|
||||
deadline = time.monotonic() + effective_timeout + 10
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
if is_interrupted():
|
||||
self._cancel_exec(exec_id)
|
||||
return {
|
||||
"output": "[Command interrupted - Modal sandbox exec cancelled]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
try:
|
||||
status_response = self._request(
|
||||
"GET",
|
||||
f"/v1/sandboxes/{self._sandbox_id}/execs/{exec_id}",
|
||||
timeout=(self._CONNECT_TIMEOUT_SECONDS, self._POLL_READ_TIMEOUT_SECONDS),
|
||||
return ModalExecStart(
|
||||
immediate_result=self._error_result(
|
||||
"Managed Modal exec start did not return the expected exec id"
|
||||
)
|
||||
except Exception as exc:
|
||||
return {
|
||||
"output": f"Managed Modal exec poll failed: {exc}",
|
||||
"returncode": 1,
|
||||
}
|
||||
)
|
||||
|
||||
if status_response.status_code == 404:
|
||||
return {
|
||||
"output": "Managed Modal exec not found",
|
||||
"returncode": 1,
|
||||
}
|
||||
return ModalExecStart(handle=_ManagedModalExecHandle(exec_id=exec_id))
|
||||
|
||||
if status_response.status_code >= 400:
|
||||
return {
|
||||
"output": self._format_error("Managed Modal exec poll failed", status_response),
|
||||
"returncode": 1,
|
||||
}
|
||||
def _poll_modal_exec(self, handle: _ManagedModalExecHandle) -> dict | None:
|
||||
try:
|
||||
status_response = self._request(
|
||||
"GET",
|
||||
f"/v1/sandboxes/{self._sandbox_id}/execs/{handle.exec_id}",
|
||||
timeout=(self._CONNECT_TIMEOUT_SECONDS, self._POLL_READ_TIMEOUT_SECONDS),
|
||||
)
|
||||
except Exception as exc:
|
||||
return self._error_result(f"Managed Modal exec poll failed: {exc}")
|
||||
|
||||
status_body = status_response.json()
|
||||
status = status_body.get("status")
|
||||
if status in {"completed", "failed", "cancelled", "timeout"}:
|
||||
return {
|
||||
"output": status_body.get("output", ""),
|
||||
"returncode": status_body.get("returncode", 1),
|
||||
}
|
||||
if status_response.status_code == 404:
|
||||
return self._error_result("Managed Modal exec not found")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
if status_response.status_code >= 400:
|
||||
return self._error_result(
|
||||
self._format_error("Managed Modal exec poll failed", status_response)
|
||||
)
|
||||
|
||||
self._cancel_exec(exec_id)
|
||||
return {
|
||||
"output": f"Managed Modal exec timed out after {effective_timeout}s",
|
||||
"returncode": 124,
|
||||
}
|
||||
status_body = status_response.json()
|
||||
status = status_body.get("status")
|
||||
if status in {"completed", "failed", "cancelled", "timeout"}:
|
||||
return self._result(
|
||||
status_body.get("output", ""),
|
||||
status_body.get("returncode", 1),
|
||||
)
|
||||
return None
|
||||
|
||||
def _cancel_modal_exec(self, handle: _ManagedModalExecHandle) -> None:
|
||||
self._cancel_exec(handle.exec_id)
|
||||
|
||||
def _timeout_result_for_modal(self, timeout: int) -> dict:
|
||||
return self._result(f"Managed Modal exec timed out after {timeout}s", 124)
|
||||
|
||||
def cleanup(self):
|
||||
if not getattr(self, "_sandbox_id", None):
|
||||
@@ -226,6 +211,21 @@ class ManagedModalEnvironment(BaseEnvironment):
|
||||
raise RuntimeError("Managed Modal create did not return a sandbox id")
|
||||
return sandbox_id
|
||||
|
||||
def _guard_unsupported_credential_passthrough(self) -> None:
|
||||
"""Managed Modal does not sync or mount host credential files."""
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts
|
||||
except Exception:
|
||||
return
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
if mounts:
|
||||
raise ValueError(
|
||||
"Managed Modal does not support host credential-file passthrough. "
|
||||
"Use TERMINAL_MODAL_MODE=direct when skills or config require "
|
||||
"credential files inside the sandbox."
|
||||
)
|
||||
|
||||
def _request(self, method: str, path: str, *,
|
||||
json: Dict[str, Any] | None = None,
|
||||
timeout: int = 30,
|
||||
|
||||
@@ -9,13 +9,16 @@ import json
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.modal_common import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -135,9 +138,20 @@ class _AsyncWorker:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
@dataclass
|
||||
class _DirectModalExecHandle:
|
||||
thread: threading.Thread
|
||||
result_holder: Dict[str, Any]
|
||||
|
||||
|
||||
class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes."""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
_poll_interval_seconds = 0.2
|
||||
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
|
||||
_unexpected_error_prefix = "Modal execution error"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
@@ -312,36 +326,11 @@ class ModalEnvironment(BaseEnvironment):
|
||||
except Exception as e:
|
||||
logger.debug("Modal: file sync failed: %s", e)
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = "",
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> dict:
|
||||
def _before_execute(self) -> None:
|
||||
self._sync_files()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Modal sandboxes execute commands via exec() and cannot pipe
|
||||
# subprocess stdin directly. When a sudo password is present,
|
||||
# use a shell-level pipe from printf.
|
||||
if sudo_stdin is not None:
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
|
||||
effective_cwd = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
full_command = f"cd {shlex.quote(effective_cwd)} && {exec_command}"
|
||||
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
|
||||
result_holder = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
@@ -351,7 +340,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
"bash",
|
||||
"-c",
|
||||
full_command,
|
||||
timeout=effective_timeout,
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
@@ -363,42 +352,31 @@ class ModalEnvironment(BaseEnvironment):
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return output, exit_code
|
||||
return self._result(output, exit_code)
|
||||
|
||||
output, exit_code = self._worker.run_coroutine(
|
||||
result_holder["value"] = self._worker.run_coroutine(
|
||||
_do_execute(),
|
||||
timeout=effective_timeout + 30,
|
||||
timeout=prepared.timeout + 30,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": output,
|
||||
"returncode": exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Modal sandbox terminated]",
|
||||
"returncode": 130,
|
||||
}
|
||||
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
|
||||
|
||||
if result_holder["error"]:
|
||||
return {
|
||||
"output": f"Modal execution error: {result_holder['error']}",
|
||||
"returncode": 1,
|
||||
}
|
||||
return result_holder["value"]
|
||||
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
|
||||
if handle.thread.is_alive():
|
||||
return None
|
||||
if handle.result_holder["error"]:
|
||||
return self._error_result(f"Modal execution error: {handle.result_holder['error']}")
|
||||
return handle.result_holder["value"]
|
||||
|
||||
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
|
||||
178
tools/environments/modal_common.py
Normal file
178
tools/environments/modal_common.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Shared Hermes-side execution flow for Modal transports.
|
||||
|
||||
This module deliberately stops at the Hermes boundary:
|
||||
- command preparation
|
||||
- cwd/timeout normalization
|
||||
- stdin/sudo shell wrapping
|
||||
- common result shape
|
||||
- interrupt/cancel polling
|
||||
|
||||
Direct Modal and managed Modal keep separate transport logic, persistence, and
|
||||
trust-boundary decisions in their own modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shlex
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PreparedModalExec:
|
||||
"""Normalized command data passed to a transport-specific exec runner."""
|
||||
|
||||
command: str
|
||||
cwd: str
|
||||
timeout: int
|
||||
stdin_data: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModalExecStart:
|
||||
"""Transport response after starting an exec."""
|
||||
|
||||
handle: Any | None = None
|
||||
immediate_result: dict | None = None
|
||||
|
||||
|
||||
def wrap_modal_stdin_heredoc(command: str, stdin_data: str) -> str:
|
||||
"""Append stdin as a shell heredoc for transports without stdin piping."""
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
return f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
|
||||
|
||||
def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str:
|
||||
"""Feed sudo via a shell pipe for transports without direct stdin piping."""
|
||||
return f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {command}"
|
||||
|
||||
|
||||
class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
"""Common execute() flow for direct and managed Modal transports."""
|
||||
|
||||
_stdin_mode = "payload"
|
||||
_poll_interval_seconds = 0.25
|
||||
_client_timeout_grace_seconds: float | None = None
|
||||
_interrupt_output = "[Command interrupted]"
|
||||
_unexpected_error_prefix = "Modal execution error"
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = "",
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> dict:
|
||||
self._before_execute()
|
||||
prepared = self._prepare_modal_exec(
|
||||
command,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
try:
|
||||
start = self._start_modal_exec(prepared)
|
||||
except Exception as exc:
|
||||
return self._error_result(f"{self._unexpected_error_prefix}: {exc}")
|
||||
|
||||
if start.immediate_result is not None:
|
||||
return start.immediate_result
|
||||
|
||||
if start.handle is None:
|
||||
return self._error_result(
|
||||
f"{self._unexpected_error_prefix}: transport did not return an exec handle"
|
||||
)
|
||||
|
||||
deadline = None
|
||||
if self._client_timeout_grace_seconds is not None:
|
||||
deadline = time.monotonic() + prepared.timeout + self._client_timeout_grace_seconds
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
try:
|
||||
self._cancel_modal_exec(start.handle)
|
||||
except Exception:
|
||||
pass
|
||||
return self._result(self._interrupt_output, 130)
|
||||
|
||||
try:
|
||||
result = self._poll_modal_exec(start.handle)
|
||||
except Exception as exc:
|
||||
return self._error_result(f"{self._unexpected_error_prefix}: {exc}")
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
if deadline is not None and time.monotonic() >= deadline:
|
||||
try:
|
||||
self._cancel_modal_exec(start.handle)
|
||||
except Exception:
|
||||
pass
|
||||
return self._timeout_result_for_modal(prepared.timeout)
|
||||
|
||||
time.sleep(self._poll_interval_seconds)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
"""Hook for backends that need pre-exec sync or validation."""
|
||||
return None
|
||||
|
||||
def _prepare_modal_exec(
|
||||
self,
|
||||
command: str,
|
||||
*,
|
||||
cwd: str = "",
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> PreparedModalExec:
|
||||
effective_cwd = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
exec_command = command
|
||||
exec_stdin = stdin_data if self._stdin_mode == "payload" else None
|
||||
if stdin_data is not None and self._stdin_mode == "heredoc":
|
||||
exec_command = wrap_modal_stdin_heredoc(exec_command, stdin_data)
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(exec_command)
|
||||
if sudo_stdin is not None:
|
||||
exec_command = wrap_modal_sudo_pipe(exec_command, sudo_stdin)
|
||||
|
||||
return PreparedModalExec(
|
||||
command=exec_command,
|
||||
cwd=effective_cwd,
|
||||
timeout=effective_timeout,
|
||||
stdin_data=exec_stdin,
|
||||
)
|
||||
|
||||
def _result(self, output: str, returncode: int) -> dict:
|
||||
return {
|
||||
"output": output,
|
||||
"returncode": returncode,
|
||||
}
|
||||
|
||||
def _error_result(self, output: str) -> dict:
|
||||
return self._result(output, 1)
|
||||
|
||||
def _timeout_result_for_modal(self, timeout: int) -> dict:
|
||||
return self._result(f"Command timed out after {timeout}s", 124)
|
||||
|
||||
@abstractmethod
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
"""Begin a transport-specific exec."""
|
||||
|
||||
@abstractmethod
|
||||
def _poll_modal_exec(self, handle: Any) -> dict | None:
|
||||
"""Return a final result dict when complete, else ``None``."""
|
||||
|
||||
@abstractmethod
|
||||
def _cancel_modal_exec(self, handle: Any) -> None:
|
||||
"""Cancel or terminate the active transport exec."""
|
||||
Reference in New Issue
Block a user