feat(environments): unified spawn-per-call execution layer
Replace dual execution model (PersistentShellMixin + per-backend oneshot) with spawn-per-call + session snapshot for all backends except ManagedModal. Core changes: - Every command spawns a fresh bash process; session snapshot (env vars, functions, aliases) captured at init and re-sourced before each command - CWD persists via file-based read (local) or in-band stdout markers (remote) - ProcessHandle protocol + _ThreadedProcessHandle adapter for SDK backends - cancel_fn wired for Modal (sandbox.terminate) and Daytona (sandbox.stop) - Shared utilities extracted: _pipe_stdin, _popen_bash, _load_json_store, _save_json_store, _file_mtime_key, _SYNC_INTERVAL_SECONDS - Rate-limited file sync unified in base _before_execute() with _sync_files() hook - execute_oneshot() removed; all 11 call sites in code_execution_tool.py migrated to execute() - Daytona timeout wrapper replaced with SDK-native timeout parameter - persistent_shell.py deleted (291 lines) Backend-specific: - Local: process-group kill via os.killpg, file-based CWD read - Docker: -e env flags only on init_session, not per-command - SSH: shlex.quote transport, ControlMaster connection reuse - Singularity: apptainer exec with instance://, no forced --pwd - Modal: _AsyncWorker + _ThreadedProcessHandle, cancel_fn -> sandbox.terminate - Daytona: SDK-level timeout (not shell wrapper), cancel_fn -> sandbox.stop - ManagedModal: unchanged (gateway owns execution); docstring added explaining why
This commit is contained in:
174
tests/tools/test_base_environment.py
Normal file
174
tests/tools/test_base_environment.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Tests for BaseEnvironment unified execution model.
|
||||
|
||||
Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(),
|
||||
init_session() failure handling, and the CWD marker contract.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _cwd_marker
|
||||
|
||||
|
||||
class _TestableEnv(BaseEnvironment):
|
||||
"""Concrete subclass for testing base class methods."""
|
||||
|
||||
def __init__(self, cwd="/tmp", timeout=10):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None):
|
||||
raise NotImplementedError("Use mock")
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestWrapCommand:
|
||||
def test_basic_shape(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||
|
||||
assert "source" in wrapped
|
||||
assert "cd /tmp" in wrapped or "cd '/tmp'" in wrapped
|
||||
assert "eval 'echo hello'" in wrapped
|
||||
assert "__hermes_ec=$?" in wrapped
|
||||
assert "export -p >" in wrapped
|
||||
assert "pwd -P >" in wrapped
|
||||
assert env._cwd_marker in wrapped
|
||||
assert "exit $__hermes_ec" in wrapped
|
||||
|
||||
def test_no_snapshot_skips_source(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = False
|
||||
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||
|
||||
assert "source" not in wrapped
|
||||
|
||||
def test_single_quote_escaping(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("echo 'hello world'", "/tmp")
|
||||
|
||||
assert "eval 'echo '\\''hello world'\\'''" in wrapped
|
||||
|
||||
def test_tilde_not_quoted(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("ls", "~")
|
||||
|
||||
assert "cd ~" in wrapped
|
||||
assert "cd '~'" not in wrapped
|
||||
|
||||
def test_cd_failure_exit_126(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("ls", "/nonexistent")
|
||||
|
||||
assert "exit 126" in wrapped
|
||||
|
||||
|
||||
class TestExtractCwdFromOutput:
|
||||
def test_happy_path(self):
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"hello\n{marker}/home/user{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/home/user"
|
||||
assert marker not in result["output"]
|
||||
|
||||
def test_missing_marker(self):
|
||||
env = _TestableEnv()
|
||||
result = {"output": "hello world\n"}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/tmp" # unchanged
|
||||
|
||||
def test_marker_in_command_output(self):
|
||||
"""If the marker appears in command output AND as the real marker,
|
||||
rfind grabs the last (real) one."""
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/correct/path"
|
||||
|
||||
def test_output_cleaned(self):
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"hello\n{marker}/tmp{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert "hello" in result["output"]
|
||||
assert marker not in result["output"]
|
||||
|
||||
|
||||
class TestEmbedStdinHeredoc:
|
||||
def test_heredoc_format(self):
|
||||
result = BaseEnvironment._embed_stdin_heredoc("cat", "hello world")
|
||||
|
||||
assert result.startswith("cat << '")
|
||||
assert "hello world" in result
|
||||
assert "HERMES_STDIN_" in result
|
||||
|
||||
def test_unique_delimiter_each_call(self):
|
||||
r1 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||
r2 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||
|
||||
# Extract delimiters
|
||||
d1 = r1.split("'")[1]
|
||||
d2 = r2.split("'")[1]
|
||||
assert d1 != d2 # UUID-based, should be unique
|
||||
|
||||
|
||||
class TestInitSessionFailure:
|
||||
def test_snapshot_ready_false_on_failure(self):
|
||||
env = _TestableEnv()
|
||||
|
||||
def failing_run_bash(*args, **kwargs):
|
||||
raise RuntimeError("bash not found")
|
||||
|
||||
env._run_bash = failing_run_bash
|
||||
env.init_session()
|
||||
|
||||
assert env._snapshot_ready is False
|
||||
|
||||
def test_login_flag_when_snapshot_not_ready(self):
|
||||
"""When _snapshot_ready=False, execute() should pass login=True to _run_bash."""
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = False
|
||||
|
||||
calls = []
|
||||
def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None):
|
||||
calls.append({"login": login})
|
||||
# Return a mock process handle
|
||||
mock = MagicMock()
|
||||
mock.poll.return_value = 0
|
||||
mock.returncode = 0
|
||||
mock.stdout = iter([])
|
||||
return mock
|
||||
|
||||
env._run_bash = mock_run_bash
|
||||
env.execute("echo test")
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["login"] is True
|
||||
|
||||
|
||||
class TestCwdMarker:
|
||||
def test_marker_contains_session_id(self):
|
||||
env = _TestableEnv()
|
||||
assert env._session_id in env._cwd_marker
|
||||
|
||||
def test_unique_per_instance(self):
|
||||
env1 = _TestableEnv()
|
||||
env2 = _TestableEnv()
|
||||
assert env1._cwd_marker != env2._cwd_marker
|
||||
@@ -22,21 +22,19 @@ import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from tools.environments.local import (
|
||||
LocalEnvironment,
|
||||
_clean_shell_noise,
|
||||
_extract_fenced_output,
|
||||
_OUTPUT_FENCE,
|
||||
_SHELL_NOISE_SUBSTRINGS,
|
||||
)
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
|
||||
# ── Shared noise detection ───────────────────────────────────────────────
|
||||
# Every known shell noise pattern. If ANY of these appear in output that
|
||||
# isn't explicitly expected, the test fails with a clear message.
|
||||
# Known shell noise patterns that should never appear in command output.
|
||||
|
||||
_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [
|
||||
_ALL_NOISE_PATTERNS = [
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
"bash: ",
|
||||
"Inappropriate ioctl",
|
||||
"Auto-suggestions:",
|
||||
@@ -88,134 +86,6 @@ def populated_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
# ── _clean_shell_noise unit tests ────────────────────────────────────────
|
||||
|
||||
class TestCleanShellNoise:
|
||||
def test_single_noise_line(self):
|
||||
output = "bash: no job control in this shell\nhello world\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello world\n"
|
||||
|
||||
def test_double_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"actual output here\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output here\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_tcsetattr_noise(self):
|
||||
output = (
|
||||
"bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"real content\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "real content\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_triple_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"clean\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "clean\n"
|
||||
|
||||
def test_no_noise_untouched(self):
|
||||
assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _clean_shell_noise("") == ""
|
||||
|
||||
def test_only_noise_produces_empty(self):
|
||||
output = "bash: no job control in this shell\n"
|
||||
result = _clean_shell_noise(output)
|
||||
_assert_clean(result)
|
||||
|
||||
def test_noise_in_middle_not_stripped(self):
|
||||
"""Noise in the middle is real output and should be preserved."""
|
||||
output = "real\nbash: no job control in this shell\nmore real\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == output
|
||||
|
||||
def test_zsh_restored_session(self):
|
||||
output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_saving_session_trailing(self):
|
||||
output = "hello\nSaving session...completed.\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_oh_my_zsh_banner(self):
|
||||
output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_full_noise_sandwich(self):
|
||||
"""Both leading and trailing zsh noise stripped."""
|
||||
output = (
|
||||
"Restored session: Mon Mar 2\n"
|
||||
"command not found: docker\n"
|
||||
"Oh My Zsh on!\n"
|
||||
"actual output\n"
|
||||
"Saving session...completed.\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output\n"
|
||||
|
||||
def test_last_login_stripped(self):
|
||||
output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
|
||||
# ── _extract_fenced_output unit tests ────────────────────────────────────
|
||||
|
||||
class TestExtractFencedOutput:
|
||||
def test_normal_fenced_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n"
|
||||
assert _extract_fenced_output(raw) == "hello world\n"
|
||||
|
||||
def test_no_trailing_newline(self):
|
||||
"""printf output with no trailing newline is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == "exact"
|
||||
|
||||
def test_no_fences_falls_back(self):
|
||||
"""Without fences, falls back to pattern-based cleaning."""
|
||||
raw = "bash: no job control in this shell\nhello\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_only_start_fence(self):
|
||||
"""Only start fence (e.g. user command called exit)."""
|
||||
raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_user_outputs_fence_string(self):
|
||||
"""If user command outputs the fence marker, it is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise"
|
||||
result = _extract_fenced_output(raw)
|
||||
# first fence -> last fence captures the middle including user's fence
|
||||
assert _OUTPUT_FENCE in result
|
||||
assert "real\n" in result
|
||||
|
||||
def test_empty_command_output(self):
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == ""
|
||||
|
||||
def test_multiline_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n"
|
||||
assert _extract_fenced_output(raw) == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
# ── LocalEnvironment.execute() ───────────────────────────────────────────
|
||||
|
||||
class TestLocalEnvironmentExecute:
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Tests for the local persistent shell backend."""
|
||||
|
||||
import glob as glob_mod
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
|
||||
|
||||
class TestLocalConfig:
|
||||
def test_local_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is False
|
||||
|
||||
def test_local_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
def test_local_persistent_yes(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
|
||||
class TestMergeOutput:
|
||||
def test_stdout_only(self):
|
||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
||||
|
||||
def test_stderr_only(self):
|
||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
||||
|
||||
def test_both(self):
|
||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
||||
|
||||
def test_empty(self):
|
||||
assert PersistentShellMixin._merge_output("", "") == ""
|
||||
|
||||
def test_strips_trailing_newlines(self):
|
||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
||||
|
||||
|
||||
class TestLocalOneShotRegression:
|
||||
def test_echo(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("echo hello")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello" in r["output"]
|
||||
env.cleanup()
|
||||
|
||||
def test_exit_code(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("exit 42")
|
||||
assert r["returncode"] == 42
|
||||
env.cleanup()
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
||||
assert r["output"].strip() == ""
|
||||
env.cleanup()
|
||||
|
||||
def test_oneshot_heredoc_does_not_leak_fence_wrapper(self):
|
||||
"""Heredoc closing line must not be merged with the fence wrapper tail."""
|
||||
env = LocalEnvironment(persistent=False)
|
||||
cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF"
|
||||
r = env.execute(cmd)
|
||||
env.cleanup()
|
||||
assert r["returncode"] == 0
|
||||
assert "heredoc body line" in r["output"]
|
||||
assert "__hermes_rc" not in r["output"]
|
||||
assert "printf '" not in r["output"]
|
||||
assert "exit $" not in r["output"]
|
||||
|
||||
|
||||
class TestLocalPersistent:
|
||||
@pytest.fixture
|
||||
def env(self):
|
||||
e = LocalEnvironment(persistent=True)
|
||||
yield e
|
||||
e.cleanup()
|
||||
|
||||
def test_echo(self, env):
|
||||
r = env.execute("echo hello-persistent")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self, env):
|
||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self, env):
|
||||
env.execute("cd /tmp")
|
||||
r = env.execute("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self, env):
|
||||
r = env.execute("(exit 42)")
|
||||
assert r["returncode"] == 42
|
||||
|
||||
def test_stderr(self, env):
|
||||
r = env.execute("echo oops >&2")
|
||||
assert r["returncode"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self, env):
|
||||
r = env.execute("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self, env):
|
||||
r = env.execute("sleep 999", timeout=2)
|
||||
assert r["returncode"] in (124, 130)
|
||||
r = env.execute("echo alive")
|
||||
assert r["returncode"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self, env):
|
||||
r = env.execute("seq 1 1000")
|
||||
assert r["returncode"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
|
||||
def test_shell_variable_persists(self, env):
|
||||
env.execute("MY_LOCAL_VAR=hello123")
|
||||
r = env.execute("echo $MY_LOCAL_VAR")
|
||||
assert r["output"].strip() == "hello123"
|
||||
|
||||
def test_cleanup_removes_temp_files(self, env):
|
||||
env.execute("echo warmup")
|
||||
prefix = env._temp_prefix
|
||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
||||
env.cleanup()
|
||||
remaining = glob_mod.glob(f"{prefix}-*")
|
||||
assert remaining == []
|
||||
|
||||
def test_state_does_not_leak_between_instances(self):
|
||||
env1 = LocalEnvironment(persistent=True)
|
||||
env2 = LocalEnvironment(persistent=True)
|
||||
try:
|
||||
env1.execute("export LEAK_TEST=from_env1")
|
||||
r = env2.execute("echo $LEAK_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
finally:
|
||||
env1.cleanup()
|
||||
env2.cleanup()
|
||||
|
||||
def test_special_characters_in_command(self, env):
|
||||
r = env.execute("echo 'hello world'")
|
||||
assert r["output"].strip() == "hello world"
|
||||
|
||||
def test_pipe_command(self, env):
|
||||
r = env.execute("echo hello | tr 'h' 'H'")
|
||||
assert r["output"].strip() == "Hello"
|
||||
|
||||
def test_multiple_commands_semicolon(self, env):
|
||||
r = env.execute("X=42; echo $X")
|
||||
assert r["output"].strip() == "42"
|
||||
@@ -110,7 +110,7 @@ class _FakeResponse:
|
||||
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
poll_count = {"value": 0}
|
||||
@@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
|
||||
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
interrupt_event = _install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
|
||||
@@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
@@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough():
|
||||
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
|
||||
144
tests/tools/test_threaded_process_handle.py
Normal file
144
tests/tools/test_threaded_process_handle.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tests for _ThreadedProcessHandle — the adapter for SDK backends."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import _ThreadedProcessHandle
|
||||
|
||||
|
||||
class TestBasicExecution:
|
||||
def test_successful_execution(self):
|
||||
def exec_fn():
|
||||
return ("hello world", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 0
|
||||
output = handle.stdout.read()
|
||||
assert "hello world" in output
|
||||
|
||||
def test_nonzero_exit_code(self):
|
||||
def exec_fn():
|
||||
return ("error occurred", 42)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 42
|
||||
output = handle.stdout.read()
|
||||
assert "error occurred" in output
|
||||
|
||||
def test_exception_in_exec_fn(self):
|
||||
def exec_fn():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 1
|
||||
|
||||
def test_empty_output(self):
|
||||
def exec_fn():
|
||||
return ("", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 0
|
||||
output = handle.stdout.read()
|
||||
assert output == ""
|
||||
|
||||
|
||||
class TestPolling:
|
||||
def test_poll_returns_none_while_running(self):
|
||||
event = threading.Event()
|
||||
|
||||
def exec_fn():
|
||||
event.wait(timeout=5)
|
||||
return ("done", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
assert handle.poll() is None
|
||||
|
||||
event.set()
|
||||
handle.wait(timeout=5)
|
||||
assert handle.poll() == 0
|
||||
|
||||
def test_poll_returns_returncode_when_done(self):
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
assert handle.poll() == 0
|
||||
|
||||
|
||||
class TestCancelFn:
|
||||
def test_cancel_fn_called_on_kill(self):
|
||||
called = threading.Event()
|
||||
|
||||
def cancel():
|
||||
called.set()
|
||||
|
||||
def exec_fn():
|
||||
time.sleep(10)
|
||||
return ("", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
handle.kill()
|
||||
assert called.is_set()
|
||||
|
||||
def test_cancel_fn_none_is_safe(self):
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None)
|
||||
handle.kill() # should not raise
|
||||
handle.wait(timeout=5)
|
||||
assert handle.returncode == 0
|
||||
|
||||
def test_cancel_fn_exception_swallowed(self):
|
||||
def cancel():
|
||||
raise RuntimeError("cancel failed")
|
||||
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
handle.kill() # should not raise despite cancel raising
|
||||
handle.wait(timeout=5)
|
||||
|
||||
|
||||
class TestStdoutPipe:
|
||||
def test_stdout_is_readable(self):
|
||||
def exec_fn():
|
||||
return ("line1\nline2\nline3\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
lines = handle.stdout.readlines()
|
||||
assert len(lines) == 3
|
||||
assert lines[0] == "line1\n"
|
||||
|
||||
def test_stdout_iterable(self):
|
||||
def exec_fn():
|
||||
return ("a\nb\nc\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
collected = list(handle.stdout)
|
||||
assert len(collected) == 3
|
||||
|
||||
def test_unicode_output(self):
|
||||
def exec_fn():
|
||||
return ("hello 世界 🌍\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
output = handle.stdout.read()
|
||||
assert "世界" in output
|
||||
assert "🌍" in output
|
||||
@@ -18,7 +18,7 @@ Architecture (two transports):
|
||||
2. Parent ships both files to the remote environment
|
||||
3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.)
|
||||
4. Tool calls are written as request files; a polling thread on the parent
|
||||
reads them via execute_oneshot(), dispatches, and writes response files
|
||||
reads them via env.execute(), dispatches, and writes response files
|
||||
5. The script polls for response files and continues
|
||||
|
||||
In both cases, only the script's stdout is returned to the LLM; intermediate
|
||||
@@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None:
|
||||
quotes are fine.
|
||||
"""
|
||||
encoded = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded}' | base64 -d > {remote_path}",
|
||||
cwd="/",
|
||||
timeout=30,
|
||||
@@ -555,9 +555,9 @@ def _rpc_poll_loop(
|
||||
):
|
||||
"""Poll the remote filesystem for tool call requests and dispatch them.
|
||||
|
||||
Runs in a background thread. Uses ``env.execute_oneshot()`` so it can
|
||||
operate concurrently with the script-execution thread that holds
|
||||
``env.execute()`` (important for persistent-shell backends like SSH).
|
||||
Runs in a background thread. Each ``env.execute()`` spawns an
|
||||
independent process, so these calls run safely concurrent with the
|
||||
script-execution thread.
|
||||
"""
|
||||
from model_tools import handle_function_call
|
||||
|
||||
@@ -566,7 +566,7 @@ def _rpc_poll_loop(
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# List pending request files (skip .tmp partials)
|
||||
ls_result = env.execute_oneshot(
|
||||
ls_result = env.execute(
|
||||
f"ls -1 {rpc_dir}/req_* 2>/dev/null || true",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
@@ -590,7 +590,7 @@ def _rpc_poll_loop(
|
||||
call_start = time.monotonic()
|
||||
|
||||
# Read request
|
||||
read_result = env.execute_oneshot(
|
||||
read_result = env.execute(
|
||||
f"cat {req_file}",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
@@ -600,7 +600,7 @@ def _rpc_poll_loop(
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed RPC request in %s", req_file)
|
||||
# Remove bad request to avoid infinite retry
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
continue
|
||||
|
||||
tool_name = request.get("tool", "")
|
||||
@@ -664,7 +664,7 @@ def _rpc_poll_loop(
|
||||
encoded_result = base64.b64encode(
|
||||
tool_result.encode("utf-8")
|
||||
).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded_result}' | base64 -d > {res_file}.tmp"
|
||||
f" && mv {res_file}.tmp {res_file}",
|
||||
cwd="/",
|
||||
@@ -672,7 +672,7 @@ def _rpc_poll_loop(
|
||||
)
|
||||
|
||||
# Remove the request file
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
|
||||
except Exception as e:
|
||||
if not stop_event.is_set():
|
||||
@@ -717,7 +717,7 @@ def _execute_remote(
|
||||
|
||||
try:
|
||||
# Verify Python is available on the remote
|
||||
py_check = env.execute_oneshot(
|
||||
py_check = env.execute(
|
||||
"command -v python3 >/dev/null 2>&1 && echo OK",
|
||||
cwd="/", timeout=15,
|
||||
)
|
||||
@@ -734,7 +734,7 @@ def _execute_remote(
|
||||
})
|
||||
|
||||
# Create sandbox directory on remote
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10,
|
||||
)
|
||||
|
||||
@@ -806,7 +806,7 @@ def _execute_remote(
|
||||
|
||||
# Clean up remote sandbox dir
|
||||
try:
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"rm -rf {sandbox_dir}", cwd="/", timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -1,11 +1,27 @@
|
||||
"""Base class for all Hermes execution environment backends."""
|
||||
"""Base class for all Hermes execution environment backends.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process.
|
||||
A session snapshot (env vars, functions, aliases) is captured once at init and
|
||||
re-sourced before each command. CWD persists via in-band stdout markers (remote)
|
||||
or a temp file (local).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import IO, Callable, Protocol
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_sandbox_dir() -> Path:
|
||||
@@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path:
|
||||
return p
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface for all Hermes execution backends.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants and utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Subclasses implement execute() and cleanup(). Shared helpers eliminate
|
||||
duplicated subprocess boilerplate across backends.
|
||||
_SYNC_INTERVAL_SECONDS = 5.0
|
||||
|
||||
|
||||
def _pipe_stdin(proc: subprocess.Popen, data: str) -> None:
|
||||
"""Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks."""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
proc.stdin.write(data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write, daemon=True).start()
|
||||
|
||||
|
||||
def _popen_bash(
|
||||
cmd: list[str], stdin_data: str | None = None, **kwargs
|
||||
) -> subprocess.Popen:
|
||||
"""Spawn a subprocess with standard stdout/stderr/stdin setup.
|
||||
|
||||
If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`.
|
||||
Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass
|
||||
this and call :func:`_pipe_stdin` directly.
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
text=True,
|
||||
**kwargs,
|
||||
)
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
return proc
|
||||
|
||||
|
||||
def _load_json_store(path: Path) -> dict:
|
||||
"""Load a JSON file as a dict, returning ``{}`` on any error."""
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_json_store(path: Path, data: dict) -> None:
|
||||
"""Write *data* as pretty-printed JSON to *path*."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def _file_mtime_key(host_path: str) -> tuple[float, int] | None:
|
||||
"""Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable."""
|
||||
try:
|
||||
st = Path(host_path).stat()
|
||||
return (st.st_mtime, st.st_size)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ProcessHandle protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProcessHandle(Protocol):
|
||||
"""Duck type that every backend's _run_bash() must return.
|
||||
|
||||
subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona)
|
||||
return _ThreadedProcessHandle which adapts their blocking calls.
|
||||
"""
|
||||
|
||||
def poll(self) -> int | None: ...
|
||||
def kill(self) -> None: ...
|
||||
def wait(self, timeout: float | None = None) -> int: ...
|
||||
|
||||
@property
|
||||
def stdout(self) -> IO[str] | None: ...
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None: ...
|
||||
|
||||
|
||||
class _ThreadedProcessHandle:
|
||||
"""Adapter for SDK backends (Modal, Daytona) that have no real subprocess.
|
||||
|
||||
Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background
|
||||
thread and exposes a ProcessHandle-compatible interface. An optional
|
||||
``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation
|
||||
(e.g. Modal sandbox.terminate, Daytona sandbox.stop).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exec_fn: Callable[[], tuple[str, int]],
|
||||
cancel_fn: Callable[[], None] | None = None,
|
||||
):
|
||||
self._cancel_fn = cancel_fn
|
||||
self._done = threading.Event()
|
||||
self._returncode: int | None = None
|
||||
self._error: Exception | None = None
|
||||
|
||||
# Pipe for stdout — drain thread in _wait_for_process reads the read end.
|
||||
read_fd, write_fd = os.pipe()
|
||||
self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace")
|
||||
self._write_fd = write_fd
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
output, exit_code = exec_fn()
|
||||
self._returncode = exit_code
|
||||
# Write output into the pipe so drain thread picks it up.
|
||||
try:
|
||||
os.write(self._write_fd, output.encode("utf-8", errors="replace"))
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._error = exc
|
||||
self._returncode = 1
|
||||
finally:
|
||||
try:
|
||||
os.close(self._write_fd)
|
||||
except OSError:
|
||||
pass
|
||||
self._done.set()
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
t.start()
|
||||
|
||||
@property
|
||||
def stdout(self):
|
||||
return self._stdout
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None:
|
||||
return self._returncode
|
||||
|
||||
def poll(self) -> int | None:
|
||||
return self._returncode if self._done.is_set() else None
|
||||
|
||||
def kill(self):
|
||||
if self._cancel_fn:
|
||||
try:
|
||||
self._cancel_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def wait(self, timeout: float | None = None) -> int:
|
||||
self._done.wait(timeout=timeout)
|
||||
return self._returncode
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CWD marker for remote backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cwd_marker(session_id: str) -> str:
|
||||
return f"__HERMES_CWD_{session_id}__"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseEnvironment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface and unified execution flow for all Hermes backends.
|
||||
|
||||
Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class
|
||||
provides ``execute()`` with session snapshot sourcing, CWD tracking,
|
||||
interrupt handling, and timeout enforcement.
|
||||
"""
|
||||
|
||||
# Subclasses that embed stdin as a heredoc (Modal, Daytona) set this.
|
||||
_stdin_mode: str = "pipe" # "pipe" or "heredoc"
|
||||
|
||||
# Snapshot creation timeout (override for slow cold-starts).
|
||||
_snapshot_timeout: int = 30
|
||||
|
||||
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh"
|
||||
self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt"
|
||||
self._cwd_marker = _cwd_marker(self._session_id)
|
||||
self._snapshot_ready = False
|
||||
self._last_sync_time: float | None = (
|
||||
None # set to 0 by backends that need file sync
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Abstract methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_bash(
|
||||
self,
|
||||
cmd_string: str,
|
||||
*,
|
||||
login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None,
|
||||
) -> ProcessHandle:
|
||||
"""Spawn a bash process to run *cmd_string*.
|
||||
|
||||
Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle).
|
||||
Must be overridden by every backend.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()")
|
||||
|
||||
@abstractmethod
|
||||
def cleanup(self):
|
||||
"""Release backend resources (container, instance, connection)."""
|
||||
...
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session snapshot (init_session)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def init_session(self):
|
||||
"""Capture login shell environment into a snapshot file.
|
||||
|
||||
Called once after backend construction. On success, sets
|
||||
``_snapshot_ready = True`` so subsequent commands source the snapshot
|
||||
instead of running with ``bash -l``.
|
||||
"""
|
||||
# Full capture: env vars, functions (filtered), aliases, shell options.
|
||||
bootstrap = (
|
||||
f"export -p > {self._snapshot_path}\n"
|
||||
f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n"
|
||||
f"alias -p >> {self._snapshot_path}\n"
|
||||
f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n"
|
||||
f"echo 'set +e' >> {self._snapshot_path}\n"
|
||||
f"echo 'set +u' >> {self._snapshot_path}\n"
|
||||
f"pwd -P > {self._cwd_file} 2>/dev/null || true\n"
|
||||
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n"
|
||||
)
|
||||
try:
|
||||
proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout)
|
||||
result = self._wait_for_process(proc, timeout=self._snapshot_timeout)
|
||||
self._snapshot_ready = True
|
||||
self._update_cwd(result)
|
||||
logger.info(
|
||||
"Session snapshot created (session=%s, cwd=%s)",
|
||||
self._session_id,
|
||||
self.cwd,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"init_session failed (session=%s): %s — "
|
||||
"falling back to bash -l per command",
|
||||
self._session_id,
|
||||
exc,
|
||||
)
|
||||
self._snapshot_ready = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command wrapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wrap_command(self, command: str, cwd: str) -> str:
|
||||
"""Build the full bash script that sources snapshot, cd's, runs command,
|
||||
re-dumps env vars, and emits CWD markers."""
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
parts = []
|
||||
|
||||
# Source snapshot (env vars from previous commands)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"source {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# cd to working directory — let bash expand ~ natively
|
||||
quoted_cwd = (
|
||||
shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd
|
||||
)
|
||||
parts.append(f"cd {quoted_cwd} || exit 126")
|
||||
|
||||
# Run the actual command
|
||||
parts.append(f"eval '{escaped}'")
|
||||
parts.append("__hermes_ec=$?")
|
||||
|
||||
# Re-dump env vars to snapshot (last-writer-wins for concurrent calls)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# Write CWD to file (local reads this) and stdout marker (remote parses this)
|
||||
parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true")
|
||||
# Use a distinct line for the marker. The leading \n ensures
|
||||
# the marker starts on its own line even if the command doesn't
|
||||
# end with a newline (e.g. printf 'exact'). We'll strip this
|
||||
# injected newline in _extract_cwd_from_output.
|
||||
parts.append(
|
||||
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\""
|
||||
)
|
||||
parts.append("exit $__hermes_ec")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stdin heredoc embedding (for SDK backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _embed_stdin_heredoc(command: str, stdin_data: str) -> str:
|
||||
"""Append stdin_data as a shell heredoc to the command string."""
|
||||
delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}"
|
||||
return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Process lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict:
|
||||
"""Poll-based wait with interrupt checking and stdout draining.
|
||||
|
||||
Shared across all backends — not overridden.
|
||||
"""
|
||||
output_chunks: list[str] = []
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
output_chunks.append(line)
|
||||
except UnicodeDecodeError:
|
||||
output_chunks.clear()
|
||||
output_chunks.append(
|
||||
"[binary output detected — raw bytes not displayable]"
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
drain_thread = threading.Thread(target=_drain, daemon=True)
|
||||
drain_thread.start()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
partial = "".join(output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg
|
||||
if partial
|
||||
else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
|
||||
drain_thread.join(timeout=5)
|
||||
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"output": "".join(output_chunks), "returncode": proc.returncode}
|
||||
|
||||
def _kill_process(self, proc: ProcessHandle):
|
||||
"""Terminate a process. Subclasses may override for process-group kill."""
|
||||
try:
|
||||
proc.kill()
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CWD extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Extract CWD from command output. Override for local file-based read."""
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
def _extract_cwd_from_output(self, result: dict):
|
||||
"""Parse the __HERMES_CWD_{session}__ marker from stdout output.
|
||||
|
||||
Updates self.cwd and strips the marker from result["output"].
|
||||
Used by remote backends (Docker, SSH, Modal, Daytona, Singularity).
|
||||
"""
|
||||
output = result.get("output", "")
|
||||
marker = self._cwd_marker
|
||||
last = output.rfind(marker)
|
||||
if last == -1:
|
||||
return
|
||||
|
||||
# Find the opening marker before this closing one
|
||||
search_start = max(0, last - 4096) # CWD path won't be >4KB
|
||||
first = output.rfind(marker, search_start, last)
|
||||
if first == -1 or first == last:
|
||||
return
|
||||
|
||||
cwd_path = output[first + len(marker) : last].strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
|
||||
# Strip the marker line AND the \n we injected before it.
|
||||
# The wrapper emits: printf '\n__MARKER__%s__MARKER__\n'
|
||||
# So the output looks like: <cmd output>\n__MARKER__path__MARKER__\n
|
||||
# We want to remove everything from the injected \n onwards.
|
||||
line_start = output.rfind("\n", 0, first)
|
||||
if line_start == -1:
|
||||
line_start = first
|
||||
line_end = output.find("\n", last + len(marker))
|
||||
line_end = line_end + 1 if line_end != -1 else len(output)
|
||||
|
||||
result["output"] = output[:line_start] + output[line_end:]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _before_execute(self):
|
||||
"""Rate-limited file sync before each command.
|
||||
|
||||
Backends that need pre-command sync set ``self._last_sync_time = 0``
|
||||
in ``__init__`` and override :meth:`_sync_files`. Backends needing
|
||||
extra pre-exec logic (e.g. Daytona sandbox restart check) override
|
||||
this method and call ``super()._before_execute()``.
|
||||
"""
|
||||
if self._last_sync_time is not None:
|
||||
now = time.monotonic()
|
||||
if now - self._last_sync_time >= _SYNC_INTERVAL_SECONDS:
|
||||
self._sync_files()
|
||||
self._last_sync_time = now
|
||||
|
||||
def _sync_files(self):
|
||||
"""Push files to remote environment. Called rate-limited by _before_execute."""
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unified execute()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = "",
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
self._before_execute()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
effective_cwd = cwd or self.cwd
|
||||
|
||||
# Merge sudo stdin with caller stdin
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# Embed stdin as heredoc for backends that need it
|
||||
if effective_stdin and self._stdin_mode == "heredoc":
|
||||
exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin)
|
||||
effective_stdin = None
|
||||
|
||||
wrapped = self._wrap_command(exec_command, effective_cwd)
|
||||
|
||||
# Use login shell if snapshot failed (so user's profile still loads)
|
||||
login = not self._snapshot_ready
|
||||
|
||||
proc = self._run_bash(
|
||||
wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin
|
||||
)
|
||||
result = self._wait_for_process(proc, timeout=effective_timeout)
|
||||
self._update_cwd(result)
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def stop(self):
|
||||
"""Alias for cleanup (compat with older callers)."""
|
||||
self.cleanup()
|
||||
@@ -57,53 +544,12 @@ class BaseEnvironment(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers (eliminate duplication across backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available.
|
||||
|
||||
Returns:
|
||||
(transformed_command, sudo_stdin) — see _transform_sudo_command
|
||||
for the full contract. Callers that drive a subprocess directly
|
||||
should prepend sudo_stdin (when not None) to any stdin_data they
|
||||
pass to Popen. Callers that embed stdin via heredoc (modal,
|
||||
daytona) handle sudo_stdin in their own execute() method.
|
||||
"""
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
"timeout": timeout or self.timeout,
|
||||
"encoding": "utf-8",
|
||||
"errors": "replace",
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.STDOUT,
|
||||
}
|
||||
if stdin_data is not None:
|
||||
kw["input"] = stdin_data
|
||||
else:
|
||||
kw["stdin"] = subprocess.DEVNULL
|
||||
return kw
|
||||
|
||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command bypassing any persistent shell.
|
||||
|
||||
Safe for concurrent use alongside a long-running execute() call.
|
||||
Backends that maintain a persistent shell (SSH, Local) override this
|
||||
to route through their oneshot path, avoiding the shell lock.
|
||||
Non-persistent backends delegate to execute().
|
||||
"""
|
||||
return self.execute(command, cwd=cwd, timeout=timeout,
|
||||
stdin_data=stdin_data)
|
||||
|
||||
def _timeout_result(self, timeout: int | None) -> dict:
|
||||
"""Standard return dict when a command times out."""
|
||||
return {
|
||||
|
||||
@@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,22 +25,25 @@ logger = logging.getLogger(__name__)
|
||||
class DaytonaEnvironment(BaseEnvironment):
|
||||
"""Daytona cloud sandbox execution backend.
|
||||
|
||||
Uses stopped/started sandbox lifecycle for filesystem persistence
|
||||
instead of snapshots, making it faster and stateless on the host.
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls.
|
||||
cancel_fn wired to sandbox.stop() for interrupt support.
|
||||
Shell timeout wrapper preserved (SDK timeout unreliable).
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120,
|
||||
disk: int = 10240,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
self._requested_cwd = cwd
|
||||
requested_cwd = cwd
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
@@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._SandboxState = SandboxState
|
||||
self._DaytonaError = DaytonaError
|
||||
self._daytona = Daytona()
|
||||
self._sandbox = None
|
||||
self._lock = threading.Lock()
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
memory_gib = max(1, math.ceil(memory / 1024))
|
||||
disk_gib = max(1, math.ceil(disk / 1024))
|
||||
if disk_gib > 10:
|
||||
warnings.warn(
|
||||
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
||||
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
|
||||
f"Capping to 10GB.",
|
||||
stacklevel=2,
|
||||
)
|
||||
disk_gib = 10
|
||||
@@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
labels = {"hermes_task_id": task_id}
|
||||
sandbox_name = f"hermes-{task_id}"
|
||||
|
||||
# Try to resume an existing sandbox for this task
|
||||
if self._persistent:
|
||||
# 1. Try name-based lookup (new path)
|
||||
try:
|
||||
self._sandbox = self._daytona.get(sandbox_name)
|
||||
self._sandbox.start()
|
||||
@@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
||||
if self._sandbox is None:
|
||||
try:
|
||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||
@@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
if self._sandbox is None:
|
||||
self._sandbox = self._daytona.create(
|
||||
CreateSandboxFromImageParams(
|
||||
@@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
|
||||
# Detect remote home dir first so mounts go to the right place.
|
||||
# Detect remote home dir
|
||||
self._remote_home = "/root"
|
||||
try:
|
||||
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
||||
if home:
|
||||
self._remote_home = home
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
if requested_cwd in ("~", "/home/daytona"):
|
||||
self.cwd = home
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd)
|
||||
|
||||
# Track synced files to avoid redundant uploads.
|
||||
# Key: remote_path, Value: (mtime, size)
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
|
||||
# Upload credential files and skills directory into the sandbox.
|
||||
self._sync_skills_and_credentials()
|
||||
self._sync_files()
|
||||
self.init_session()
|
||||
|
||||
def _upload_if_changed(self, host_path: str, remote_path: str) -> bool:
|
||||
"""Upload a file if its mtime/size changed since last sync."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
if self._synced_files.get(remote_path) == file_key:
|
||||
return False
|
||||
@@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
logger.debug("Daytona: upload failed %s: %s", host_path, e)
|
||||
return False
|
||||
|
||||
def _sync_skills_and_credentials(self) -> None:
|
||||
"""Upload changed credential files and skill files into the sandbox."""
|
||||
def _sync_files(self) -> None:
|
||||
container_base = f"{self._remote_home}/.hermes"
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts, iter_skills_files
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||
if self._upload_if_changed(mount_entry["host_path"], remote_path):
|
||||
logger.debug("Daytona: synced credential %s", remote_path)
|
||||
|
||||
self._upload_if_changed(mount_entry["host_path"], remote_path)
|
||||
for entry in iter_skills_files(container_base=container_base):
|
||||
if self._upload_if_changed(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Daytona: synced skill %s", entry["container_path"])
|
||||
self._upload_if_changed(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Daytona: could not sync skills/credentials: %s", e)
|
||||
|
||||
@@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
server-side timeout is not enforced and the SDK has no client-side
|
||||
fallback), so we wrap the command with the shell ``timeout`` utility
|
||||
which reliably kills the process and returns exit code 124.
|
||||
"""
|
||||
# Wrap with shell `timeout` to enforce the deadline reliably.
|
||||
# Add a small buffer so the shell timeout fires before any SDK-level
|
||||
# timeout would, giving us a clean exit code 124.
|
||||
timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"
|
||||
|
||||
result_holder: dict = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command, cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
"returncode": response.exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
# Wait for timeout + generous buffer for network/SDK overhead
|
||||
deadline = time.monotonic() + timeout + 10
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Daytona sandbox stopped]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
# Shell timeout didn't fire and SDK is hung — force stop
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if result_holder["error"]:
|
||||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
def _before_execute(self):
|
||||
"""Ensure sandbox is ready, then rate-limited file sync via base class."""
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
# Incremental sync before each command so mid-session credential
|
||||
# refreshes and skill updates are picked up.
|
||||
self._sync_skills_and_credentials()
|
||||
super()._before_execute()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call."""
|
||||
sandbox = self._sandbox
|
||||
lock = self._lock
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
def cancel():
|
||||
with lock:
|
||||
try:
|
||||
sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Daytona sandboxes execute commands via the Daytona SDK and cannot
|
||||
# pipe subprocess stdin directly the way a local Popen can. When a
|
||||
# sudo password is present, use a shell-level pipe from printf so that
|
||||
# the password feeds sudo -S without appearing as an echo argument
|
||||
# embedded in the shell string. The password is still visible in the
|
||||
# remote sandbox's command line, but it is not exposed on the user's
|
||||
# local machine — which is the primary threat being mitigated.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
effective_cwd = cwd or self.cwd or None
|
||||
effective_timeout = timeout or self.timeout
|
||||
if login:
|
||||
shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}"
|
||||
else:
|
||||
shell_cmd = f"bash -c {shlex.quote(cmd_string)}"
|
||||
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
response = sandbox.process.exec(shell_cmd, timeout=timeout)
|
||||
return (response.result or "", response.exit_code)
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
try:
|
||||
self._ensure_sandbox_ready()
|
||||
except Exception:
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
if "error" not in result:
|
||||
return result
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
|
||||
return result
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
with self._lock:
|
||||
|
||||
@@ -8,18 +8,14 @@ persistence via bind mounts.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment):
|
||||
self._container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
||||
|
||||
# Build the init-time env forwarding args (used only by init_session
|
||||
# to inject host env vars into the snapshot; subsequent commands get
|
||||
# them from the snapshot file).
|
||||
self._init_env_args = self._build_init_env_args()
|
||||
|
||||
# Initialize session snapshot inside the container
|
||||
self.init_session()
|
||||
|
||||
def _build_init_env_args(self) -> list[str]:
|
||||
"""Build -e KEY=VALUE args for injecting host env vars into init_session.
|
||||
|
||||
These are used once during init_session() so that export -p captures
|
||||
them into the snapshot. Subsequent execute() calls don't need -e flags.
|
||||
"""
|
||||
exec_env: dict[str, str] = dict(self._env)
|
||||
|
||||
explicit_forward_keys = set(self._forward_env)
|
||||
passthrough_keys: set[str] = set()
|
||||
try:
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
passthrough_keys = set(get_all_passthrough())
|
||||
except Exception:
|
||||
pass
|
||||
# Explicit docker_forward_env entries are an intentional opt-in and must
|
||||
# win over the generic Hermes secret blocklist. Only implicit passthrough
|
||||
# keys are filtered.
|
||||
forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
if value is not None:
|
||||
exec_env[key] = value
|
||||
|
||||
args = []
|
||||
for key in sorted(exec_env):
|
||||
args.extend(["-e", f"{key}={exec_env[key]}"])
|
||||
return args
|
||||
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn a bash process inside the Docker container."""
|
||||
assert self._container_id, "Container not started"
|
||||
cmd = [self._docker_exe, "exec"]
|
||||
if stdin_data is not None:
|
||||
cmd.append("-i")
|
||||
|
||||
# Only inject -e env args during init_session (login=True).
|
||||
# Subsequent commands get env vars from the snapshot.
|
||||
if login:
|
||||
cmd.extend(self._init_env_args)
|
||||
|
||||
cmd.extend([self._container_id])
|
||||
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||
else:
|
||||
cmd.extend(["bash", "-c", cmd_string])
|
||||
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
@@ -471,112 +530,6 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# docker exec -w doesn't expand ~, so prepend a cd into the command.
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
exec_command = f"cd ~ && {exec_command}"
|
||||
work_dir = "/"
|
||||
elif work_dir.startswith("~/"):
|
||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
||||
work_dir = "/"
|
||||
|
||||
assert self._container_id, "Container not started"
|
||||
cmd = [self._docker_exe, "exec"]
|
||||
if effective_stdin is not None:
|
||||
cmd.append("-i")
|
||||
cmd.extend(["-w", work_dir])
|
||||
# Build the per-exec environment: start with explicit docker_env values
|
||||
# (static config), then overlay docker_forward_env / skill env_passthrough
|
||||
# (dynamic from host process). Forward values take precedence.
|
||||
exec_env: dict[str, str] = dict(self._env)
|
||||
|
||||
explicit_forward_keys = set(self._forward_env)
|
||||
passthrough_keys: set[str] = set()
|
||||
try:
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
passthrough_keys = set(get_all_passthrough())
|
||||
except Exception:
|
||||
pass
|
||||
# Explicit docker_forward_env entries are an intentional opt-in and must
|
||||
# win over the generic Hermes secret blocklist. Only implicit passthrough
|
||||
# keys are filtered.
|
||||
forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
if value is not None:
|
||||
exec_env[key] = value
|
||||
|
||||
for key in sorted(exec_env):
|
||||
cmd.extend(["-e", f"{key}={exec_env[key]}"])
|
||||
cmd.extend([self._container_id, "bash", "-lc", exec_command])
|
||||
|
||||
try:
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Docker execution error: {e}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
||||
if self._container_id:
|
||||
|
||||
@@ -1,42 +1,22 @@
|
||||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
||||
"""Local execution environment — spawn-per-call with session snapshot."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _pipe_stdin
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
# Unique marker to isolate real command output from shell init/exit noise.
|
||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
||||
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
|
||||
|
||||
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
||||
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
|
||||
# but can break external CLIs (e.g. codex) that also honor them.
|
||||
# See: https://github.com/NousResearch/hermes-agent/issues/1002
|
||||
#
|
||||
# Built dynamically from the provider registry so new providers are
|
||||
# automatically covered without manual blocklist maintenance.
|
||||
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||
|
||||
|
||||
def _build_provider_env_blocklist() -> frozenset:
|
||||
"""Derive the blocklist from provider, tool, and gateway config.
|
||||
|
||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
||||
every registered provider, plus tool/messaging env vars from the
|
||||
optional config registry, so new Hermes-managed secrets are blocked
|
||||
in subprocesses without having to maintain multiple static lists.
|
||||
"""
|
||||
"""Derive the blocklist from provider, tool, and gateway config."""
|
||||
blocked: set[str] = set()
|
||||
|
||||
try:
|
||||
@@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
||||
blocked.update({
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_API_BASE", # legacy alias
|
||||
"OPENAI_API_BASE",
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_ORGANIZATION",
|
||||
"OPENROUTER_API_KEY",
|
||||
"ANTHROPIC_BASE_URL",
|
||||
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
|
||||
"ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"LLM_MODEL",
|
||||
# Expanded isolation for other major providers (Issue #1002)
|
||||
"GOOGLE_API_KEY", # Gemini / Google AI Studio
|
||||
"DEEPSEEK_API_KEY", # DeepSeek
|
||||
"MISTRAL_API_KEY", # Mistral AI
|
||||
"GROQ_API_KEY", # Groq
|
||||
"TOGETHER_API_KEY", # Together AI
|
||||
"PERPLEXITY_API_KEY", # Perplexity
|
||||
"COHERE_API_KEY", # Cohere
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"GOOGLE_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"TOGETHER_API_KEY",
|
||||
"PERPLEXITY_API_KEY",
|
||||
"COHERE_API_KEY",
|
||||
"FIREWORKS_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"HELICONE_API_KEY",
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
@@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
# Skills Hub / GitHub app auth paths and aliases.
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
# Remote sandbox backend credentials.
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"DAYTONA_API_KEY",
|
||||
@@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
||||
|
||||
|
||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
||||
|
||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
||||
intentionally for callers that truly need it. Vars registered via
|
||||
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
|
||||
bypass the blocklist.
|
||||
"""
|
||||
"""Filter Hermes-managed secrets from a subprocess environment."""
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
@@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non
|
||||
|
||||
|
||||
def _find_bash() -> str:
|
||||
"""Find bash for command execution.
|
||||
|
||||
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
|
||||
must use bash — not the user's $SHELL which could be fish/zsh/etc.
|
||||
On Windows: uses Git Bash (bundled with Git for Windows).
|
||||
"""
|
||||
"""Find bash for command execution."""
|
||||
if not _IS_WINDOWS:
|
||||
return (
|
||||
shutil.which("bash")
|
||||
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
||||
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
||||
or os.environ.get("SHELL") # last resort: whatever they have
|
||||
or os.environ.get("SHELL")
|
||||
or "/bin/sh"
|
||||
)
|
||||
|
||||
# Windows: look for Git Bash (installed with Git for Windows).
|
||||
# Allow override via env var (same pattern as Claude Code).
|
||||
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
||||
if custom and os.path.isfile(custom):
|
||||
return custom
|
||||
|
||||
# shutil.which finds bash.exe if Git\bin is on PATH
|
||||
found = shutil.which("bash")
|
||||
if found:
|
||||
return found
|
||||
|
||||
# Check common Git for Windows install locations
|
||||
for candidate in (
|
||||
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
||||
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
||||
@@ -209,60 +169,7 @@ def _find_bash() -> str:
|
||||
_find_shell = _find_bash
|
||||
|
||||
|
||||
# Noise lines emitted by interactive shells when stdin is not a terminal.
|
||||
# Used as a fallback when output fence markers are missing.
|
||||
_SHELL_NOISE_SUBSTRINGS = (
|
||||
# bash
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
# zsh / oh-my-zsh / macOS terminal session
|
||||
"Restored session:",
|
||||
"Saving session...",
|
||||
"Last login:",
|
||||
"command not found:",
|
||||
"Oh My Zsh",
|
||||
"compinit:",
|
||||
)
|
||||
|
||||
|
||||
def _clean_shell_noise(output: str) -> str:
|
||||
"""Strip shell startup/exit warnings that leak when using -i without a TTY.
|
||||
|
||||
Removes lines matching known noise patterns from both the beginning
|
||||
and end of the output. Lines in the middle are left untouched.
|
||||
"""
|
||||
|
||||
def _is_noise(line: str) -> bool:
|
||||
return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)
|
||||
|
||||
lines = output.split("\n")
|
||||
|
||||
# Strip leading noise
|
||||
while lines and _is_noise(lines[0]):
|
||||
lines.pop(0)
|
||||
|
||||
# Strip trailing noise (walk backwards, skip empty lines from split)
|
||||
end = len(lines) - 1
|
||||
while end >= 0 and (not lines[end] or _is_noise(lines[end])):
|
||||
end -= 1
|
||||
|
||||
if end < 0:
|
||||
return ""
|
||||
|
||||
cleaned = lines[: end + 1]
|
||||
result = "\n".join(cleaned)
|
||||
|
||||
# Preserve trailing newline if original had one
|
||||
if output.endswith("\n") and result and not result.endswith("\n"):
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
# Standard PATH entries for environments with minimal PATH.
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
@@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict:
|
||||
return run_env
|
||||
|
||||
|
||||
def _extract_fenced_output(raw: str) -> str:
|
||||
"""Extract real command output from between fence markers.
|
||||
|
||||
The execute() method wraps each command with printf(FENCE) markers.
|
||||
This function finds the first and last fence and returns only the
|
||||
content between them, which is the actual command output free of
|
||||
any shell init/exit noise.
|
||||
|
||||
Falls back to pattern-based _clean_shell_noise if fences are missing.
|
||||
"""
|
||||
first = raw.find(_OUTPUT_FENCE)
|
||||
if first == -1:
|
||||
return _clean_shell_noise(raw)
|
||||
|
||||
start = first + len(_OUTPUT_FENCE)
|
||||
last = raw.rfind(_OUTPUT_FENCE)
|
||||
|
||||
if last <= first:
|
||||
# Only start fence found (e.g. user command called `exit`)
|
||||
return _clean_shell_noise(raw[start:])
|
||||
|
||||
return raw[start:last]
|
||||
|
||||
|
||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
class LocalEnvironment(BaseEnvironment):
|
||||
"""Run commands directly on the host machine.
|
||||
|
||||
Features:
|
||||
- Popen + polling for interrupt support (user can cancel mid-command)
|
||||
- Background stdout drain thread to prevent pipe buffer deadlocks
|
||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
||||
- sudo -S transform via SUDO_PASSWORD env var
|
||||
- Uses interactive login shell so full user env is available
|
||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
||||
Spawn-per-call: every execute() spawns a fresh bash process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via file-based read after each command.
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
||||
persistent: bool = False):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
self.persistent = persistent
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
self.init_session()
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-local-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
user_shell = _find_bash()
|
||||
run_env = _make_run_env(self.env)
|
||||
return subprocess.Popen(
|
||||
[user_shell, "-l"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
env=run_env,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
results = []
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
results.append(f.read())
|
||||
else:
|
||||
results.append("")
|
||||
return results
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
try:
|
||||
subprocess.run(
|
||||
["pkill", "-P", str(self._shell_pid)],
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
effective_timeout = timeout or self.timeout
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
user_shell = _find_bash()
|
||||
# Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line).
|
||||
# A trailing `; __hermes_rc` glued to `<<EOF` / a closing `EOF` line breaks
|
||||
# heredoc parsing: the delimiter must be alone on its line, otherwise the
|
||||
# rest of this script becomes heredoc body and leaks into stdout (e.g. gh
|
||||
# issue/PR flows that use here-documents for bodies).
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"{exec_command}\n"
|
||||
f"__hermes_rc=$?\n"
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"exit $__hermes_rc\n"
|
||||
)
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
bash = _find_bash()
|
||||
args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string]
|
||||
run_env = _make_run_env(self.env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
args,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
|
||||
return proc
|
||||
|
||||
def _kill_process(self, proc):
|
||||
"""Kill the entire process group (all children)."""
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
def _drain_stdout():
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Read CWD from temp file (local-only, no round-trip needed)."""
|
||||
try:
|
||||
cwd_path = open(self._cwd_file).read().strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
except (OSError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
partial = "".join(_output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {effective_timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg if partial else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
# Still strip the marker from output so it's not visible
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
def cleanup(self):
|
||||
"""Clean up temp files."""
|
||||
for f in (self._snapshot_path, self._cwd_file):
|
||||
try:
|
||||
os.unlink(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,7 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.modal_common import (
|
||||
from tools.environments.modal_utils import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
|
||||
@@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.modal_common import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
_load_json_store,
|
||||
_save_json_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
|
||||
_DIRECT_SNAPSHOT_NAMESPACE = "direct"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
def _load_snapshots() -> dict:
|
||||
return _load_json_store(_SNAPSHOT_STORE)
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
def _save_snapshots(data: dict) -> None:
|
||||
_save_json_store(_SNAPSHOT_STORE, data)
|
||||
|
||||
|
||||
def _direct_snapshot_key(task_id: str) -> str:
|
||||
@@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str:
|
||||
|
||||
|
||||
def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]:
|
||||
"""Return a snapshot id and whether it came from the legacy key format."""
|
||||
snapshots = _load_snapshots()
|
||||
|
||||
namespaced_key = _direct_snapshot_key(task_id)
|
||||
snapshot_id = snapshots.get(namespaced_key)
|
||||
if isinstance(snapshot_id, str) and snapshot_id:
|
||||
return snapshot_id, False
|
||||
|
||||
legacy_snapshot_id = snapshots.get(task_id)
|
||||
if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id:
|
||||
return legacy_snapshot_id, True
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||
"""Persist the direct Modal snapshot id under the direct namespace."""
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[_direct_snapshot_key(task_id)] = snapshot_id
|
||||
snapshots.pop(task_id, None)
|
||||
@@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||
|
||||
|
||||
def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None:
|
||||
"""Remove direct Modal snapshot entries for a task, including legacy keys."""
|
||||
snapshots = _load_snapshots()
|
||||
updated = False
|
||||
|
||||
for key in (_direct_snapshot_key(task_id), task_id):
|
||||
value = snapshots.get(key)
|
||||
if value is None:
|
||||
@@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non
|
||||
if snapshot_id is None or value == snapshot_id:
|
||||
snapshots.pop(key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
_save_snapshots(snapshots)
|
||||
|
||||
|
||||
def _resolve_modal_image(image_spec: Any) -> Any:
|
||||
"""Convert registry references or snapshot ids into Modal image objects."""
|
||||
"""Convert registry references or snapshot ids into Modal image objects.
|
||||
|
||||
Includes add_python support for ubuntu/debian images (absorbed from PR 4511).
|
||||
"""
|
||||
import modal as _modal
|
||||
|
||||
if not isinstance(image_spec, str):
|
||||
@@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any:
|
||||
if image_spec.startswith("im-"):
|
||||
return _modal.Image.from_id(image_spec)
|
||||
|
||||
# PR 4511: add python to ubuntu/debian images that don't have it
|
||||
lower = image_spec.lower()
|
||||
add_python = any(base in lower for base in ("ubuntu", "debian"))
|
||||
|
||||
setup_commands = [
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
]
|
||||
if add_python:
|
||||
setup_commands.insert(0,
|
||||
"RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true"
|
||||
)
|
||||
|
||||
return _modal.Image.from_registry(
|
||||
image_spec,
|
||||
setup_dockerfile_commands=[
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
],
|
||||
setup_dockerfile_commands=setup_commands,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,19 +135,15 @@ class _AsyncWorker:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DirectModalExecHandle:
|
||||
thread: threading.Thread
|
||||
result_holder: Dict[str, Any]
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes.
|
||||
|
||||
|
||||
class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes."""
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls.
|
||||
cancel_fn wired to sandbox.terminate for interrupt support.
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
_poll_interval_seconds = 0.2
|
||||
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
|
||||
_unexpected_error_prefix = "Modal execution error"
|
||||
_snapshot_timeout = 60 # Modal cold starts can be slow
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -170,6 +163,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
self._app = None
|
||||
self._worker = _AsyncWorker()
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||
|
||||
@@ -199,27 +193,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
remote_path=mount_entry["container_path"],
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Modal: mounting credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
|
||||
# Mount individual skill files (symlinks filtered out).
|
||||
skills_files = iter_skills_files()
|
||||
for entry in skills_files:
|
||||
for entry in iter_skills_files():
|
||||
cred_mounts.append(
|
||||
_modal.Mount.from_local_file(
|
||||
entry["host_path"],
|
||||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if skills_files:
|
||||
logger.info("Modal: mounting %d skill files", len(skills_files))
|
||||
|
||||
# Mount host-side cache files (documents, images, audio,
|
||||
# screenshots). New files arriving mid-session are picked up
|
||||
# by _sync_files() before each command execution.
|
||||
cache_files = iter_cache_files()
|
||||
for entry in cache_files:
|
||||
cred_mounts.append(
|
||||
@@ -228,8 +208,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if cache_files:
|
||||
logger.info("Modal: mounting %d cache files", len(cache_files))
|
||||
except Exception as e:
|
||||
logger.debug("Modal: could not load credential file mounts: %s", e)
|
||||
|
||||
@@ -243,8 +221,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
existing_mounts.extend(cred_mounts)
|
||||
create_kwargs["mounts"] = existing_mounts
|
||||
sandbox = await _modal.Sandbox.create.aio(
|
||||
"sleep",
|
||||
"infinity",
|
||||
"sleep", "infinity",
|
||||
image=image_spec,
|
||||
app=app,
|
||||
timeout=int(create_kwargs.pop("timeout", 3600)),
|
||||
@@ -255,57 +232,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
try:
|
||||
target_image_spec = restored_snapshot_id or image
|
||||
try:
|
||||
# _resolve_modal_image keeps the Modal bootstrap fix together:
|
||||
# it applies setup_dockerfile_commands with ensurepip before
|
||||
# Modal builds registry images, while snapshot ids restore via
|
||||
# modal.Image.from_id() without rebuilding.
|
||||
effective_image = _resolve_modal_image(target_image_spec)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(effective_image),
|
||||
timeout=300,
|
||||
_create_sandbox(effective_image), timeout=300,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not restored_snapshot_id:
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
"Modal: failed to restore snapshot %s, retrying with base image: %s",
|
||||
restored_snapshot_id[:20],
|
||||
exc,
|
||||
restored_snapshot_id[:20], exc,
|
||||
)
|
||||
_delete_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
base_image = _resolve_modal_image(image)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(base_image),
|
||||
timeout=300,
|
||||
_create_sandbox(base_image), timeout=300,
|
||||
)
|
||||
else:
|
||||
if restored_snapshot_id and restored_from_legacy_key:
|
||||
_store_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
logger.info(
|
||||
"Modal: migrated legacy snapshot entry for task %s",
|
||||
self._task_id,
|
||||
)
|
||||
except Exception:
|
||||
self._worker.stop()
|
||||
raise
|
||||
|
||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||
self.init_session()
|
||||
|
||||
def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool:
|
||||
"""Push a single file into the sandbox if changed. Returns True if synced."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
"""Push a single file into the sandbox if changed."""
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
|
||||
if self._synced_files.get(container_path) == file_key:
|
||||
return False
|
||||
|
||||
try:
|
||||
content = hp.read_bytes()
|
||||
content = Path(host_path).read_bytes()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -326,85 +287,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
return True
|
||||
|
||||
def _sync_files(self) -> None:
|
||||
"""Push credential, skill, and cache files into the running sandbox.
|
||||
|
||||
Runs before each command. Uses mtime+size caching so only changed
|
||||
files are pushed (~13μs overhead in the no-op case). Cache files
|
||||
are especially important here — new uploads/screenshots may appear
|
||||
mid-session after sandbox creation.
|
||||
"""
|
||||
"""Push credential, skill, and cache files into the running sandbox."""
|
||||
try:
|
||||
from tools.credential_files import (
|
||||
get_credential_file_mounts,
|
||||
iter_skills_files,
|
||||
iter_cache_files,
|
||||
)
|
||||
|
||||
for entry in get_credential_file_mounts():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced credential %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_skills_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced skill file %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_cache_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced cache file %s", entry["container_path"])
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Modal: file sync failed: %s", e)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
self._sync_files()
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec."""
|
||||
sandbox = self._sandbox
|
||||
worker = self._worker
|
||||
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
|
||||
result_holder = {"value": None, "error": None}
|
||||
def cancel():
|
||||
worker.run_coroutine(sandbox.terminate.aio(), timeout=15)
|
||||
|
||||
def _run():
|
||||
try:
|
||||
async def _do_execute():
|
||||
process = await self._sandbox.exec.aio(
|
||||
"bash",
|
||||
"-c",
|
||||
full_command,
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return self._result(output, exit_code)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
async def _do():
|
||||
args = ["bash"]
|
||||
if login:
|
||||
args.extend(["-l", "-c", cmd_string])
|
||||
else:
|
||||
args.extend(["-c", cmd_string])
|
||||
process = await sandbox.exec.aio(*args, timeout=timeout)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return output, exit_code
|
||||
|
||||
result_holder["value"] = self._worker.run_coroutine(
|
||||
_do_execute(),
|
||||
timeout=prepared.timeout + 30,
|
||||
)
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
return worker.run_coroutine(_do(), timeout=timeout + 30)
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
|
||||
|
||||
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
|
||||
if handle.thread.is_alive():
|
||||
return None
|
||||
if handle.result_holder["error"]:
|
||||
return self._error_result(f"Modal execution error: {handle.result_holder['error']}")
|
||||
return handle.result_holder["value"]
|
||||
|
||||
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
@@ -426,17 +357,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
_store_direct_snapshot(self._task_id, snapshot_id)
|
||||
logger.info(
|
||||
"Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20],
|
||||
self._task_id,
|
||||
snapshot_id[:20], self._task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str:
|
||||
|
||||
|
||||
class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
"""Common execute() flow for direct and managed Modal transports."""
|
||||
"""Execution flow for the *managed* Modal transport (gateway-owned sandbox).
|
||||
|
||||
This deliberately overrides :meth:`BaseEnvironment.execute` because the
|
||||
tool-gateway handles command preparation, CWD tracking, and env-snapshot
|
||||
management on the server side. The base class's ``_wrap_command`` /
|
||||
``_wait_for_process`` / snapshot machinery does not apply here — the
|
||||
gateway owns that responsibility. See ``ManagedModalEnvironment`` for the
|
||||
concrete subclass.
|
||||
"""
|
||||
|
||||
_stdin_mode = "payload"
|
||||
_poll_interval_seconds = 0.25
|
||||
@@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
"""Hook for backends that need pre-exec sync or validation."""
|
||||
return None
|
||||
pass
|
||||
|
||||
def _prepare_modal_exec(
|
||||
self,
|
||||
@@ -1,290 +0,0 @@
|
||||
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersistentShellMixin:
|
||||
"""Mixin that adds persistent shell capability to any BaseEnvironment.
|
||||
|
||||
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
|
||||
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
|
||||
"""
|
||||
|
||||
persistent: bool
|
||||
|
||||
@abstractmethod
|
||||
def _spawn_shell_process(self) -> subprocess.Popen: ...
|
||||
|
||||
@abstractmethod
|
||||
def _read_temp_files(self, *paths: str) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _kill_shell_children(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def _execute_oneshot(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict: ...
|
||||
|
||||
@abstractmethod
|
||||
def _cleanup_temp_files(self): ...
|
||||
|
||||
_session_id: str = ""
|
||||
_poll_interval_start: float = 0.01 # initial poll interval (10ms)
|
||||
_poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-persistent-{self._session_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_persistent_shell(self):
|
||||
self._shell_lock = threading.Lock()
|
||||
self._shell_proc: subprocess.Popen | None = None
|
||||
self._shell_alive: bool = False
|
||||
self._shell_pid: int | None = None
|
||||
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
p = self._temp_prefix
|
||||
self._pshell_stdout = f"{p}-stdout"
|
||||
self._pshell_stderr = f"{p}-stderr"
|
||||
self._pshell_status = f"{p}-status"
|
||||
self._pshell_cwd = f"{p}-cwd"
|
||||
self._pshell_pid_file = f"{p}-pid"
|
||||
|
||||
self._shell_proc = self._spawn_shell_process()
|
||||
self._shell_alive = True
|
||||
|
||||
self._drain_thread = threading.Thread(
|
||||
target=self._drain_shell_output, daemon=True,
|
||||
)
|
||||
self._drain_thread.start()
|
||||
|
||||
init_script = (
|
||||
f"export TERM=${{TERM:-dumb}}\n"
|
||||
f"touch {self._pshell_stdout} {self._pshell_stderr} "
|
||||
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
|
||||
f"echo $$ > {self._pshell_pid_file}\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
)
|
||||
self._send_to_shell(init_script)
|
||||
|
||||
deadline = time.monotonic() + 3.0
|
||||
while time.monotonic() < deadline:
|
||||
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
|
||||
if pid_str.isdigit():
|
||||
self._shell_pid = int(pid_str)
|
||||
break
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
logger.warning("Could not read persistent shell PID")
|
||||
self._shell_pid = None
|
||||
|
||||
if self._shell_pid:
|
||||
logger.info(
|
||||
"Persistent shell started (session=%s, pid=%d)",
|
||||
self._session_id, self._shell_pid,
|
||||
)
|
||||
|
||||
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
|
||||
if reported_cwd:
|
||||
self.cwd = reported_cwd
|
||||
|
||||
def _cleanup_persistent_shell(self):
|
||||
if self._shell_proc is None:
|
||||
return
|
||||
|
||||
if self._session_id:
|
||||
self._cleanup_temp_files()
|
||||
|
||||
try:
|
||||
self._shell_proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._shell_proc.terminate()
|
||||
self._shell_proc.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._shell_proc.kill()
|
||||
|
||||
self._shell_alive = False
|
||||
self._shell_proc = None
|
||||
|
||||
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
|
||||
self._drain_thread.join(timeout=1.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# execute() / cleanup() — shared dispatcher, subclasses inherit
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if self.persistent:
|
||||
return self._execute_persistent(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Always use the oneshot (non-persistent) execution path.
|
||||
|
||||
This bypasses _shell_lock so it can run concurrently with a
|
||||
long-running command in the persistent shell — used by
|
||||
execute_code's file-based RPC polling thread.
|
||||
"""
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
if self.persistent:
|
||||
self._cleanup_persistent_shell()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shell I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drain_shell_output(self):
|
||||
try:
|
||||
for _ in self._shell_proc.stdout:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._shell_alive = False
|
||||
|
||||
def _send_to_shell(self, text: str):
|
||||
if not self._shell_alive or self._shell_proc is None:
|
||||
return
|
||||
try:
|
||||
self._shell_proc.stdin.write(text)
|
||||
self._shell_proc.stdin.flush()
|
||||
except (BrokenPipeError, OSError):
|
||||
self._shell_alive = False
|
||||
|
||||
def _read_persistent_output(self) -> tuple[str, int, str]:
|
||||
stdout, stderr, status_raw, cwd = self._read_temp_files(
|
||||
self._pshell_stdout, self._pshell_stderr,
|
||||
self._pshell_status, self._pshell_cwd,
|
||||
)
|
||||
output = self._merge_output(stdout, stderr)
|
||||
status = status_raw.strip()
|
||||
if ":" in status:
|
||||
status = status.split(":", 1)[1]
|
||||
try:
|
||||
exit_code = int(status.strip())
|
||||
except ValueError:
|
||||
exit_code = 1
|
||||
return output, exit_code, cwd.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_persistent(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._shell_alive:
|
||||
logger.info("Persistent shell died, restarting...")
|
||||
self._init_persistent_shell()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
if stdin_data or sudo_stdin:
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
with self._shell_lock:
|
||||
return self._execute_persistent_locked(
|
||||
exec_command, cwd, effective_timeout,
|
||||
)
|
||||
|
||||
def _execute_persistent_locked(self, command: str, cwd: str,
|
||||
timeout: int) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
cmd_id = uuid.uuid4().hex[:8]
|
||||
truncate = (
|
||||
f": > {self._pshell_stdout}\n"
|
||||
f": > {self._pshell_stderr}\n"
|
||||
f": > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(truncate)
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
ipc_script = (
|
||||
f"cd {shlex.quote(work_dir)}\n"
|
||||
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
|
||||
f"__EC=$?\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(ipc_script)
|
||||
deadline = time.monotonic() + timeout
|
||||
poll_interval = self._poll_interval_start # starts at 10ms, backs off to 250ms
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
return {
|
||||
"output": output + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
if output:
|
||||
return {
|
||||
"output": output + f"\n[Command timed out after {timeout}s]",
|
||||
"returncode": 124,
|
||||
}
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if not self._shell_alive:
|
||||
return {
|
||||
"output": "Persistent shell died during execution",
|
||||
"returncode": 1,
|
||||
}
|
||||
|
||||
status_content = self._read_temp_files(self._pshell_status)[0].strip()
|
||||
if status_content.startswith(cmd_id + ":"):
|
||||
break
|
||||
|
||||
time.sleep(poll_interval)
|
||||
# Exponential backoff: fast start (10ms) for quick commands,
|
||||
# ramps up to 250ms for long-running commands — reduces I/O by 10-25x
|
||||
# on WSL2 where polling keeps the VM hot and memory pressure high.
|
||||
poll_interval = min(poll_interval * 1.5, self._poll_interval_max)
|
||||
|
||||
output, exit_code, new_cwd = self._read_persistent_output()
|
||||
if new_cwd:
|
||||
self.cwd = new_cwd
|
||||
return {"output": output, "returncode": exit_code}
|
||||
|
||||
@staticmethod
|
||||
def _merge_output(stdout: str, stderr: str) -> str:
|
||||
parts = []
|
||||
if stdout.strip():
|
||||
parts.append(stdout.rstrip("\n"))
|
||||
if stderr.strip():
|
||||
parts.append(stderr.rstrip("\n"))
|
||||
return "\n".join(parts)
|
||||
@@ -5,20 +5,22 @@ Supports configurable resource limits and optional filesystem persistence
|
||||
via writable overlay directories that survive across sessions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_load_json_store,
|
||||
_popen_bash,
|
||||
_save_json_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,11 +28,7 @@ _SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json"
|
||||
|
||||
|
||||
def _find_singularity_executable() -> str:
|
||||
"""Locate the apptainer or singularity CLI binary.
|
||||
|
||||
Returns the executable name (``"apptainer"`` or ``"singularity"``).
|
||||
Raises ``RuntimeError`` with install instructions if neither is found.
|
||||
"""
|
||||
"""Locate the apptainer or singularity CLI binary."""
|
||||
if shutil.which("apptainer"):
|
||||
return "apptainer"
|
||||
if shutil.which("singularity"):
|
||||
@@ -43,66 +41,34 @@ def _find_singularity_executable() -> str:
|
||||
|
||||
|
||||
def _ensure_singularity_available() -> str:
|
||||
"""Preflight check: resolve the executable and verify it responds.
|
||||
|
||||
Returns the executable name on success.
|
||||
Raises ``RuntimeError`` with an actionable message on failure.
|
||||
"""
|
||||
"""Preflight check: resolve the executable and verify it responds."""
|
||||
exe = _find_singularity_executable()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[exe, "version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
[exe, "version"], capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"Singularity backend selected but the resolved executable '{exe}' "
|
||||
"could not be executed. Check your installation."
|
||||
f"Singularity backend selected but '{exe}' could not be executed."
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' timed out. The runtime may be misconfigured."
|
||||
)
|
||||
raise RuntimeError(f"'{exe} version' timed out.")
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()[:200]
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' failed (exit code {result.returncode}): {stderr}"
|
||||
)
|
||||
|
||||
raise RuntimeError(f"'{exe} version' failed (exit code {result.returncode}): {stderr}")
|
||||
return exe
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
def _load_snapshots() -> dict:
|
||||
return _load_json_store(_SNAPSHOT_STORE)
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
def _save_snapshots(data: dict) -> None:
|
||||
_save_json_store(_SNAPSHOT_STORE, data)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
Resolution order:
|
||||
1. TERMINAL_SCRATCH_DIR (explicit override)
|
||||
2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root)
|
||||
3. /scratch (common on HPC clusters)
|
||||
4. ~/.hermes/sandboxes/singularity (fallback)
|
||||
"""
|
||||
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
|
||||
if custom_scratch:
|
||||
scratch_path = Path(custom_scratch)
|
||||
@@ -124,7 +90,6 @@ def _get_scratch_dir() -> Path:
|
||||
|
||||
|
||||
def _get_apptainer_cache_dir() -> Path:
|
||||
"""Get the Apptainer cache directory for SIF images."""
|
||||
cache_dir = os.getenv("APPTAINER_CACHEDIR")
|
||||
if cache_dir:
|
||||
cache_path = Path(cache_dir)
|
||||
@@ -140,11 +105,6 @@ _sif_build_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
"""Get or build a SIF image from a docker:// URL.
|
||||
|
||||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith('docker://'):
|
||||
@@ -193,19 +153,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
return image
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
|
||||
--no-home, writable-tmpfs for scratch space. The container cannot see or modify
|
||||
the host filesystem outside of explicitly bound paths.
|
||||
|
||||
Persistence: when enabled, the writable overlay directory is preserved across
|
||||
sessions so installed packages and files survive cleanup/restore.
|
||||
Spawn-per-call: every execute() spawns a fresh ``apptainer exec ... bash -c`` process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via in-band stdout markers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -227,12 +180,9 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
self._memory = memory
|
||||
|
||||
# Persistent overlay directory
|
||||
if self._persistent:
|
||||
overlay_base = _get_scratch_dir() / "hermes-overlays"
|
||||
overlay_base.mkdir(parents=True, exist_ok=True)
|
||||
@@ -240,42 +190,26 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
self._overlay_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._start_instance()
|
||||
self.init_session()
|
||||
|
||||
def _start_instance(self):
|
||||
cmd = [self.executable, "instance", "start"]
|
||||
|
||||
# Security: full isolation from host
|
||||
cmd.extend(["--containall", "--no-home"])
|
||||
|
||||
# Writable layer
|
||||
if self._persistent and self._overlay_dir:
|
||||
# Persistent writable overlay -- survives across restarts
|
||||
cmd.extend(["--overlay", str(self._overlay_dir)])
|
||||
else:
|
||||
cmd.append("--writable-tmpfs")
|
||||
|
||||
# Mount credential files and skills directory (read-only).
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"])
|
||||
logger.info(
|
||||
"Singularity: binding credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
for skills_mount in get_skills_directory_mount():
|
||||
cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"])
|
||||
logger.info(
|
||||
"Singularity: binding skills dir %s -> %s",
|
||||
skills_mount["host_path"],
|
||||
skills_mount["container_path"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Singularity: could not load credential/skills mounts: %s", e)
|
||||
|
||||
# Resource limits (cgroup-based, may require root or appropriate config)
|
||||
if self._memory > 0:
|
||||
cmd.extend(["--memory", f"{self._memory}M"])
|
||||
if self._cpu > 0:
|
||||
@@ -288,94 +222,29 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn a bash process inside the Singularity instance."""
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
raise RuntimeError("Singularity instance not started")
|
||||
|
||||
effective_timeout = timeout or self.timeout
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
cmd = [self.executable, "exec",
|
||||
f"instance://{self.instance_id}"]
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
cmd.extend(["bash", "-c", cmd_string])
|
||||
|
||||
# apptainer exec --pwd doesn't expand ~, so prepend a cd into the command.
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
exec_command = f"cd ~ && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
elif work_dir.startswith("~/"):
|
||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = _time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if _time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
_time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Singularity execution error: {e}", "returncode": 1}
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop the instance. If persistent, the overlay dir survives for next creation."""
|
||||
"""Stop the instance. If persistent, the overlay dir survives."""
|
||||
if self._instance_started:
|
||||
try:
|
||||
subprocess.run(
|
||||
@@ -387,7 +256,6 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
|
||||
self._instance_started = False
|
||||
|
||||
# Record overlay path for persistence restoration
|
||||
if self._persistent and self._overlay_dir:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = str(self._overlay_dir)
|
||||
|
||||
@@ -5,13 +5,9 @@ import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,32 +20,22 @@ def _ensure_ssh_available() -> None:
|
||||
)
|
||||
|
||||
|
||||
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
class SSHEnvironment(BaseEnvironment):
|
||||
"""Run commands on a remote machine over SSH.
|
||||
|
||||
Uses SSH ControlMaster for connection persistence so subsequent
|
||||
commands are fast. Security benefit: the agent cannot modify its
|
||||
own code since execution happens on a separate machine.
|
||||
|
||||
Foreground commands are interruptible: the local ssh process is killed
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
|
||||
When ``persistent=True``, a single long-lived bash shell is kept alive
|
||||
over SSH and state (cwd, env vars, shell variables) persists across
|
||||
``execute()`` calls. Output capture uses file-based IPC on the remote
|
||||
host (stdout/stderr/exit-code written to temp files, polled via fast
|
||||
ControlMaster one-shot reads).
|
||||
Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via in-band stdout markers.
|
||||
Uses SSH ControlMaster for connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = "",
|
||||
persistent: bool = False):
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
self.port = port
|
||||
self.key_path = key_path
|
||||
self.persistent = persistent
|
||||
|
||||
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
||||
self.control_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -57,10 +43,10 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
_ensure_ssh_available()
|
||||
self._establish_connection()
|
||||
self._remote_home = self._detect_remote_home()
|
||||
self._sync_skills_and_credentials()
|
||||
self._last_sync_time: float = 0 # guarantees first _before_execute syncs
|
||||
self._sync_files()
|
||||
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
self.init_session()
|
||||
|
||||
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
||||
cmd = ["ssh"]
|
||||
@@ -102,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
return home
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: guess from username
|
||||
if self.user == "root":
|
||||
return "/root"
|
||||
return f"/home/{self.user}"
|
||||
|
||||
def _sync_skills_and_credentials(self) -> None:
|
||||
def _sync_files(self) -> None:
|
||||
"""Rsync skills directory and credential files to the remote host."""
|
||||
try:
|
||||
container_base = f"{self._remote_home}/.hermes"
|
||||
@@ -122,7 +107,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
rsync_base.extend(["-e", ssh_opts])
|
||||
dest_prefix = f"{self.user}@{self.host}"
|
||||
|
||||
# Sync individual credential files (remap /root/.hermes to detected home)
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||
parent_dir = str(Path(remote_path).parent)
|
||||
@@ -136,7 +120,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
else:
|
||||
logger.debug("SSH: rsync credential failed: %s", result.stderr.strip())
|
||||
|
||||
# Sync skill directories (local + external, remap to detected home)
|
||||
for skills_mount in get_skills_directory_mount(container_base=container_base):
|
||||
remote_path = skills_mount["container_path"]
|
||||
mkdir_cmd = self._build_ssh_command()
|
||||
@@ -154,152 +137,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
except Exception as e:
|
||||
logger.debug("SSH: could not sync skills/credentials: %s", e)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
# Incremental sync before each command so mid-session credential
|
||||
# refreshes and skill updates are picked up.
|
||||
self._sync_skills_and_credentials()
|
||||
return super().execute(command, cwd, timeout=timeout, stdin_data=stdin_data)
|
||||
|
||||
_poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-ssh-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn an SSH process that runs bash on the remote host."""
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append("bash -l")
|
||||
return subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
if len(paths) == 1:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"cat {paths[0]} 2>/dev/null")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
return [result.stdout]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""]
|
||||
|
||||
delim = f"__HERMES_SEP_{self._session_id}__"
|
||||
script = "; ".join(
|
||||
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
|
||||
)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(script)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
parts = result.stdout.split(delim + "\n")
|
||||
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""] * len(paths)
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"rm -f {self._temp_prefix}-*")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
wrapped = f'cd ~ && {exec_command}'
|
||||
elif work_dir.startswith("~/"):
|
||||
wrapped = f'cd ~/{shlex.quote(work_dir[2:])} && {exec_command}'
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)])
|
||||
else:
|
||||
wrapped = f'cd {shlex.quote(work_dir)} && {exec_command}'
|
||||
effective_timeout = timeout or self.timeout
|
||||
cmd.extend(["bash", "-c", shlex.quote(cmd_string)])
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(wrapped)
|
||||
|
||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
||||
kwargs.pop("timeout", None)
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
||||
@@ -611,9 +611,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
docker_env = cc.get("docker_env", {})
|
||||
|
||||
if env_type == "local":
|
||||
lc = local_config or {}
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout,
|
||||
persistent=lc.get("persistent", False))
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
|
||||
elif env_type == "docker":
|
||||
return _DockerEnvironment(
|
||||
@@ -705,7 +703,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
key_path=ssh_config.get("key", ""),
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
persistent=ssh_config.get("persistent", False),
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user