diff --git a/tests/tools/test_local_env_blocklist.py b/tests/tools/test_local_env_blocklist.py index f2346c3ef..3325a088a 100644 --- a/tests/tools/test_local_env_blocklist.py +++ b/tests/tools/test_local_env_blocklist.py @@ -26,8 +26,7 @@ def _make_fake_popen(captured: dict): proc = MagicMock() proc.poll.return_value = 0 proc.returncode = 0 - proc.stdout = iter([]) - proc.stdout.close = lambda: None + proc.stdout = MagicMock(__iter__=lambda s: iter([]), __next__=lambda s: (_ for _ in ()).throw(StopIteration)) proc.stdin = MagicMock() return proc return fake_popen diff --git a/tests/tools/test_local_persistent.py b/tests/tools/test_local_persistent.py new file mode 100644 index 000000000..b20cca5be --- /dev/null +++ b/tests/tools/test_local_persistent.py @@ -0,0 +1,152 @@ +"""Tests for the local persistent shell backend.""" + +import glob as glob_mod + +import pytest + +from tools.environments.local import LocalEnvironment +from tools.environments.persistent_shell import PersistentShellMixin + + +class TestLocalConfig: + def test_local_persistent_default_false(self, monkeypatch): + monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False) + from tools.terminal_tool import _get_env_config + assert _get_env_config()["local_persistent"] is False + + def test_local_persistent_true(self, monkeypatch): + monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true") + from tools.terminal_tool import _get_env_config + assert _get_env_config()["local_persistent"] is True + + def test_local_persistent_yes(self, monkeypatch): + monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes") + from tools.terminal_tool import _get_env_config + assert _get_env_config()["local_persistent"] is True + + +class TestMergeOutput: + def test_stdout_only(self): + assert PersistentShellMixin._merge_output("out", "") == "out" + + def test_stderr_only(self): + assert PersistentShellMixin._merge_output("", "err") == "err" + + def test_both(self): + assert PersistentShellMixin._merge_output("out", "err") == "out\nerr" + + def test_empty(self): + assert PersistentShellMixin._merge_output("", "") == "" + + def test_strips_trailing_newlines(self): + assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr" + + +class TestLocalOneShotRegression: + def test_echo(self): + env = LocalEnvironment(persistent=False) + r = env.execute("echo hello") + assert r["returncode"] == 0 + assert "hello" in r["output"] + env.cleanup() + + def test_exit_code(self): + env = LocalEnvironment(persistent=False) + r = env.execute("exit 42") + assert r["returncode"] == 42 + env.cleanup() + + def test_state_does_not_persist(self): + env = LocalEnvironment(persistent=False) + env.execute("export HERMES_ONESHOT_LOCAL=yes") + r = env.execute("echo $HERMES_ONESHOT_LOCAL") + assert r["output"].strip() == "" + env.cleanup() + + +class TestLocalPersistent: + @pytest.fixture + def env(self): + e = LocalEnvironment(persistent=True) + yield e + e.cleanup() + + def test_echo(self, env): + r = env.execute("echo hello-persistent") + assert r["returncode"] == 0 + assert "hello-persistent" in r["output"] + + def test_env_var_persists(self, env): + env.execute("export HERMES_LOCAL_PERSIST_TEST=works") + r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST") + assert r["output"].strip() == "works" + + def test_cwd_persists(self, env): + env.execute("cd /tmp") + r = env.execute("pwd") + assert r["output"].strip() == "/tmp" + + def test_exit_code(self, env): + r = env.execute("(exit 42)") + assert r["returncode"] == 42 + + def test_stderr(self, env): + r = env.execute("echo oops >&2") + assert r["returncode"] == 0 + assert "oops" in r["output"] + + def test_multiline_output(self, env): + r = env.execute("echo a; echo b; echo c") + lines = r["output"].strip().splitlines() + assert lines == ["a", "b", "c"] + + def test_timeout_then_recovery(self, env): + r = env.execute("sleep 999", timeout=2) + assert r["returncode"] in (124, 130) + r = env.execute("echo alive") + assert r["returncode"] == 0 + assert "alive" in r["output"] + + def test_large_output(self, env): + r = env.execute("seq 1 1000") + assert r["returncode"] == 0 + lines = r["output"].strip().splitlines() + assert len(lines) == 1000 + assert lines[0] == "1" + assert lines[-1] == "1000" + + def test_shell_variable_persists(self, env): + env.execute("MY_LOCAL_VAR=hello123") + r = env.execute("echo $MY_LOCAL_VAR") + assert r["output"].strip() == "hello123" + + def test_cleanup_removes_temp_files(self, env): + env.execute("echo warmup") + prefix = env._temp_prefix + assert len(glob_mod.glob(f"{prefix}-*")) > 0 + env.cleanup() + remaining = glob_mod.glob(f"{prefix}-*") + assert remaining == [] + + def test_state_does_not_leak_between_instances(self): + env1 = LocalEnvironment(persistent=True) + env2 = LocalEnvironment(persistent=True) + try: + env1.execute("export LEAK_TEST=from_env1") + r = env2.execute("echo $LEAK_TEST") + assert r["output"].strip() == "" + finally: + env1.cleanup() + env2.cleanup() + + def test_special_characters_in_command(self, env): + r = env.execute("echo 'hello world'") + assert r["output"].strip() == "hello world" + + def test_pipe_command(self, env): + r = env.execute("echo hello | tr 'h' 'H'") + assert r["output"].strip() == "Hello" + + def test_multiple_commands_semicolon(self, env): + r = env.execute("X=42; echo $X") + assert r["output"].strip() == "42" diff --git a/tests/tools/test_ssh_environment.py b/tests/tools/test_ssh_environment.py new file mode 100644 index 000000000..65469e5f5 --- /dev/null +++ b/tests/tools/test_ssh_environment.py @@ -0,0 +1,167 @@ +"""Tests for the SSH remote execution environment backend.""" + +import json +import os +import subprocess +from unittest.mock import MagicMock + +import pytest + +from tools.environments.ssh import SSHEnvironment + +_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "") +_SSH_USER = os.getenv("TERMINAL_SSH_USER", "") +_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22")) +_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "") + +_has_ssh = bool(_SSH_HOST and _SSH_USER) + +requires_ssh = pytest.mark.skipif( + not _has_ssh, + reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set", +) + + +def _run(command, task_id="ssh_test", **kwargs): + from tools.terminal_tool import terminal_tool + return json.loads(terminal_tool(command, task_id=task_id, **kwargs)) + + +def _cleanup(task_id="ssh_test"): + from tools.terminal_tool import cleanup_vm + cleanup_vm(task_id) + + +class TestBuildSSHCommand: + + @pytest.fixture(autouse=True) + def _mock_connection(self, monkeypatch): + monkeypatch.setattr("tools.environments.ssh.subprocess.run", + lambda *a, **k: subprocess.CompletedProcess([], 0)) + monkeypatch.setattr("tools.environments.ssh.subprocess.Popen", + lambda *a, **k: MagicMock(stdout=iter([]), + stderr=iter([]), + stdin=MagicMock())) + monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None) + + def test_base_flags(self): + env = SSHEnvironment(host="h", user="u") + cmd = " ".join(env._build_ssh_command()) + for flag in ("ControlMaster=auto", "ControlPersist=300", + "BatchMode=yes", "StrictHostKeyChecking=accept-new"): + assert flag in cmd + + def test_custom_port(self): + env = SSHEnvironment(host="h", user="u", port=2222) + cmd = env._build_ssh_command() + assert "-p" in cmd and "2222" in cmd + + def test_key_path(self): + env = SSHEnvironment(host="h", user="u", key_path="/k") + cmd = env._build_ssh_command() + assert "-i" in cmd and "/k" in cmd + + def test_user_host_suffix(self): + env = SSHEnvironment(host="h", user="u") + assert env._build_ssh_command()[-1] == "u@h" + + +class TestTerminalToolConfig: + def test_ssh_persistent_default_false(self, monkeypatch): + monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False) + from tools.terminal_tool import _get_env_config + assert _get_env_config()["ssh_persistent"] is False + + def test_ssh_persistent_true(self, monkeypatch): + monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true") + from tools.terminal_tool import _get_env_config + assert _get_env_config()["ssh_persistent"] is True + + +def _setup_ssh_env(monkeypatch, persistent: bool): + monkeypatch.setenv("TERMINAL_ENV", "ssh") + monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST) + monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER) + monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false") + if _SSH_PORT != 22: + monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT)) + if _SSH_KEY: + monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY) + + +@requires_ssh +class TestOneShotSSH: + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch): + _setup_ssh_env(monkeypatch, persistent=False) + yield + _cleanup() + + def test_echo(self): + r = _run("echo hello") + assert r["exit_code"] == 0 + assert "hello" in r["output"] + + def test_exit_code(self): + r = _run("exit 42") + assert r["exit_code"] == 42 + + def test_state_does_not_persist(self): + _run("export HERMES_ONESHOT_TEST=yes") + r = _run("echo $HERMES_ONESHOT_TEST") + assert r["output"].strip() == "" + + +@requires_ssh +class TestPersistentSSH: + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch): + _setup_ssh_env(monkeypatch, persistent=True) + yield + _cleanup() + + def test_echo(self): + r = _run("echo hello-persistent") + assert r["exit_code"] == 0 + assert "hello-persistent" in r["output"] + + def test_env_var_persists(self): + _run("export HERMES_PERSIST_TEST=works") + r = _run("echo $HERMES_PERSIST_TEST") + assert r["output"].strip() == "works" + + def test_cwd_persists(self): + _run("cd /tmp") + r = _run("pwd") + assert r["output"].strip() == "/tmp" + + def test_exit_code(self): + r = _run("(exit 42)") + assert r["exit_code"] == 42 + + def test_stderr(self): + r = _run("echo oops >&2") + assert r["exit_code"] == 0 + assert "oops" in r["output"] + + def test_multiline_output(self): + r = _run("echo a; echo b; echo c") + lines = r["output"].strip().splitlines() + assert lines == ["a", "b", "c"] + + def test_timeout_then_recovery(self): + r = _run("sleep 999", timeout=2) + assert r["exit_code"] == 124 + r = _run("echo alive") + assert r["exit_code"] == 0 + assert "alive" in r["output"] + + def test_large_output(self): + r = _run("seq 1 1000") + assert r["exit_code"] == 0 + lines = r["output"].strip().splitlines() + assert len(lines) == 1000 + assert lines[0] == "1" + assert lines[-1] == "1000" diff --git a/tools/environments/local.py b/tools/environments/local.py index bd82ded10..ed46dc7a1 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -1,5 +1,6 @@ """Local execution environment with interrupt support and non-blocking I/O.""" +import glob import os import platform import shutil @@ -11,6 +12,8 @@ import time _IS_WINDOWS = platform.system() == "Windows" from tools.environments.base import BaseEnvironment +from tools.environments.persistent_shell import PersistentShellMixin +from tools.interrupt import is_interrupted # Unique marker to isolate real command output from shell init/exit noise. # printf (no trailing newline) keeps the boundaries clean for splitting. @@ -244,6 +247,25 @@ def _clean_shell_noise(output: str) -> str: return result +_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" + + +def _make_run_env(env: dict) -> dict: + """Build a run environment with a sane PATH and provider-var stripping.""" + merged = dict(os.environ | env) + run_env = {} + for k, v in merged.items(): + if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX): + real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):] + run_env[real_key] = v + elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST: + run_env[k] = v + existing_path = run_env.get("PATH", "") + if "/usr/bin" not in existing_path.split(":"): + run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH + return run_env + + def _extract_fenced_output(raw: str) -> str: """Extract real command output from between fence markers. @@ -268,7 +290,7 @@ def _extract_fenced_output(raw: str) -> str: return raw[start:last] -class LocalEnvironment(BaseEnvironment): +class LocalEnvironment(PersistentShellMixin, BaseEnvironment): """Run commands directly on the host machine. Features: @@ -277,24 +299,66 @@ class LocalEnvironment(BaseEnvironment): - stdin_data support for piping content (bypasses ARG_MAX limits) - sudo -S transform via SUDO_PASSWORD env var - Uses interactive login shell so full user env is available + - Optional persistent shell mode (cwd/env vars survive across calls) """ - def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): + def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None, + persistent: bool = False): super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env) + self.persistent = persistent + if self.persistent: + self._init_persistent_shell() - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - from tools.terminal_tool import _interrupt_event + @property + def _temp_prefix(self) -> str: + return f"/tmp/hermes-local-{self._session_id}" + def _spawn_shell_process(self) -> subprocess.Popen: + user_shell = _find_bash() + run_env = _make_run_env(self.env) + return subprocess.Popen( + [user_shell, "-l"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True, + env=run_env, + preexec_fn=None if _IS_WINDOWS else os.setsid, + ) + + def _read_temp_files(self, *paths: str) -> list[str]: + results = [] + for path in paths: + if os.path.exists(path): + with open(path) as f: + results.append(f.read()) + else: + results.append("") + return results + + def _kill_shell_children(self): + if self._shell_pid is None: + return + try: + subprocess.run( + ["pkill", "-P", str(self._shell_pid)], + capture_output=True, timeout=5, + ) + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + + def _cleanup_temp_files(self): + for f in glob.glob(f"{self._temp_prefix}-*"): + if os.path.exists(f): + os.remove(f) + + def _execute_oneshot(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: work_dir = cwd or self.cwd or os.getcwd() effective_timeout = timeout or self.timeout exec_command, sudo_stdin = self._prepare_command(command) - # Merge the sudo password (if any) with caller-supplied stdin_data. - # sudo -S reads exactly one line (the password) then passes the rest - # of stdin to the child, so prepending is safe even when stdin_data - # is also present. if sudo_stdin is not None and stdin_data is not None: effective_stdin = sudo_stdin + stdin_data elif sudo_stdin is not None: @@ -302,110 +366,87 @@ class LocalEnvironment(BaseEnvironment): else: effective_stdin = stdin_data - try: - # The fence wrapper uses bash syntax (semicolons, $?, printf). - # Always use bash for the wrapper — NOT $SHELL which could be - # fish, zsh, or another shell with incompatible syntax. - # The -lic flags source rc files so tools like nvm/pyenv work. - user_shell = _find_bash() - # Wrap with output fences so we can later extract the real - # command output and discard shell init/exit noise. - fenced_cmd = ( - f"printf '{_OUTPUT_FENCE}';" - f" {exec_command};" - f" __hermes_rc=$?;" - f" printf '{_OUTPUT_FENCE}';" - f" exit $__hermes_rc" - ) - # Ensure PATH always includes standard dirs — systemd services - # and some terminal multiplexers inherit a minimal PATH. - _SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" - # Strip Hermes-managed provider/tool/gateway vars so external CLIs - # are not silently misrouted or handed Hermes secrets. Callers that - # truly need a blocked var can opt in by prefixing the key with - # _HERMES_FORCE_ in self.env (e.g. _HERMES_FORCE_OPENAI_API_KEY). - run_env = _sanitize_subprocess_env(os.environ, self.env) - existing_path = run_env.get("PATH", "") - if "/usr/bin" not in existing_path.split(":"): - run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH + user_shell = _find_bash() + fenced_cmd = ( + f"printf '{_OUTPUT_FENCE}';" + f" {exec_command};" + f" __hermes_rc=$?;" + f" printf '{_OUTPUT_FENCE}';" + f" exit $__hermes_rc" + ) + run_env = _make_run_env(self.env) - proc = subprocess.Popen( - [user_shell, "-lic", fenced_cmd], - text=True, - cwd=work_dir, - env=run_env, - encoding="utf-8", - errors="replace", - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, - preexec_fn=None if _IS_WINDOWS else os.setsid, - ) + proc = subprocess.Popen( + [user_shell, "-lic", fenced_cmd], + text=True, + cwd=work_dir, + env=run_env, + encoding="utf-8", + errors="replace", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, + preexec_fn=None if _IS_WINDOWS else os.setsid, + ) - if effective_stdin is not None: - def _write_stdin(): - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - threading.Thread(target=_write_stdin, daemon=True).start() - - _output_chunks: list[str] = [] - - def _drain_stdout(): + if effective_stdin is not None: + def _write_stdin(): try: - for line in proc.stdout: - _output_chunks.append(line) - except ValueError: + proc.stdin.write(effective_stdin) + proc.stdin.close() + except (BrokenPipeError, OSError): pass - finally: - try: - proc.stdout.close() - except Exception: - pass + threading.Thread(target=_write_stdin, daemon=True).start() - reader = threading.Thread(target=_drain_stdout, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout + _output_chunks: list[str] = [] - while proc.poll() is None: - if _interrupt_event.is_set(): - try: - if _IS_WINDOWS: - proc.terminate() - else: - pgid = os.getpgid(proc.pid) - os.killpg(pgid, signal.SIGTERM) - try: - proc.wait(timeout=1.0) - except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", - "returncode": 130, - } - if time.monotonic() > deadline: - try: - if _IS_WINDOWS: - proc.terminate() - else: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) + def _drain_stdout(): + try: + for line in proc.stdout: + _output_chunks.append(line) + except ValueError: + pass + finally: + try: + proc.stdout.close() + except Exception: + pass - reader.join(timeout=5) - output = _extract_fenced_output("".join(_output_chunks)) - return {"output": output, "returncode": proc.returncode} + reader = threading.Thread(target=_drain_stdout, daemon=True) + reader.start() + deadline = time.monotonic() + effective_timeout - except Exception as e: - return {"output": f"Execution error: {str(e)}", "returncode": 1} + while proc.poll() is None: + if is_interrupted(): + try: + if _IS_WINDOWS: + proc.terminate() + else: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) + try: + proc.wait(timeout=1.0) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): + proc.kill() + reader.join(timeout=2) + return { + "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", + "returncode": 130, + } + if time.monotonic() > deadline: + try: + if _IS_WINDOWS: + proc.terminate() + else: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except (ProcessLookupError, PermissionError): + proc.kill() + reader.join(timeout=2) + return self._timeout_result(effective_timeout) + time.sleep(0.2) - def cleanup(self): - pass + reader.join(timeout=5) + output = _extract_fenced_output("".join(_output_chunks)) + return {"output": output, "returncode": proc.returncode} diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py new file mode 100644 index 000000000..4b89db471 --- /dev/null +++ b/tools/environments/persistent_shell.py @@ -0,0 +1,272 @@ +"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells.""" + +import logging +import shlex +import subprocess +import threading +import time +import uuid +from abc import abstractmethod + +from tools.interrupt import is_interrupted + +logger = logging.getLogger(__name__) + + +class PersistentShellMixin: + """Mixin that adds persistent shell capability to any BaseEnvironment. + + Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``, + ``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``. + """ + + persistent: bool + + @abstractmethod + def _spawn_shell_process(self) -> subprocess.Popen: ... + + @abstractmethod + def _read_temp_files(self, *paths: str) -> list[str]: ... + + @abstractmethod + def _kill_shell_children(self): ... + + @abstractmethod + def _execute_oneshot(self, command: str, cwd: str, *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: ... + + @abstractmethod + def _cleanup_temp_files(self): ... + + _session_id: str = "" + _poll_interval: float = 0.01 + + @property + def _temp_prefix(self) -> str: + return f"/tmp/hermes-persistent-{self._session_id}" + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _init_persistent_shell(self): + self._shell_lock = threading.Lock() + self._shell_proc: subprocess.Popen | None = None + self._shell_alive: bool = False + self._shell_pid: int | None = None + + self._session_id = uuid.uuid4().hex[:12] + p = self._temp_prefix + self._pshell_stdout = f"{p}-stdout" + self._pshell_stderr = f"{p}-stderr" + self._pshell_status = f"{p}-status" + self._pshell_cwd = f"{p}-cwd" + self._pshell_pid_file = f"{p}-pid" + + self._shell_proc = self._spawn_shell_process() + self._shell_alive = True + + self._drain_thread = threading.Thread( + target=self._drain_shell_output, daemon=True, + ) + self._drain_thread.start() + + init_script = ( + f"export TERM=${{TERM:-dumb}}\n" + f"touch {self._pshell_stdout} {self._pshell_stderr} " + f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" + f"echo $$ > {self._pshell_pid_file}\n" + f"pwd > {self._pshell_cwd}\n" + ) + self._send_to_shell(init_script) + + deadline = time.monotonic() + 3.0 + while time.monotonic() < deadline: + pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() + if pid_str.isdigit(): + self._shell_pid = int(pid_str) + break + time.sleep(0.05) + else: + logger.warning("Could not read persistent shell PID") + self._shell_pid = None + + if self._shell_pid: + logger.info( + "Persistent shell started (session=%s, pid=%d)", + self._session_id, self._shell_pid, + ) + + reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip() + if reported_cwd: + self.cwd = reported_cwd + + def _cleanup_persistent_shell(self): + if self._shell_proc is None: + return + + if self._session_id: + self._cleanup_temp_files() + + try: + self._shell_proc.stdin.close() + except Exception: + pass + try: + self._shell_proc.terminate() + self._shell_proc.wait(timeout=3) + except subprocess.TimeoutExpired: + self._shell_proc.kill() + + self._shell_alive = False + self._shell_proc = None + + if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): + self._drain_thread.join(timeout=1.0) + + # ------------------------------------------------------------------ + # execute() / cleanup() — shared dispatcher, subclasses inherit + # ------------------------------------------------------------------ + + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + if self.persistent: + return self._execute_persistent( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + + def cleanup(self): + if self.persistent: + self._cleanup_persistent_shell() + + # ------------------------------------------------------------------ + # Shell I/O + # ------------------------------------------------------------------ + + def _drain_shell_output(self): + try: + for _ in self._shell_proc.stdout: + pass + except Exception: + pass + self._shell_alive = False + + def _send_to_shell(self, text: str): + if not self._shell_alive or self._shell_proc is None: + return + try: + self._shell_proc.stdin.write(text) + self._shell_proc.stdin.flush() + except (BrokenPipeError, OSError): + self._shell_alive = False + + def _read_persistent_output(self) -> tuple[str, int, str]: + stdout, stderr, status_raw, cwd = self._read_temp_files( + self._pshell_stdout, self._pshell_stderr, + self._pshell_status, self._pshell_cwd, + ) + output = self._merge_output(stdout, stderr) + status = status_raw.strip() + if ":" in status: + status = status.split(":", 1)[1] + try: + exit_code = int(status.strip()) + except ValueError: + exit_code = 1 + return output, exit_code, cwd.strip() + + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + + def _execute_persistent(self, command: str, cwd: str, *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + if not self._shell_alive: + logger.info("Persistent shell died, restarting...") + self._init_persistent_shell() + + exec_command, sudo_stdin = self._prepare_command(command) + effective_timeout = timeout or self.timeout + if stdin_data or sudo_stdin: + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + + with self._shell_lock: + return self._execute_persistent_locked( + exec_command, cwd, effective_timeout, + ) + + def _execute_persistent_locked(self, command: str, cwd: str, + timeout: int) -> dict: + work_dir = cwd or self.cwd + cmd_id = uuid.uuid4().hex[:8] + truncate = ( + f": > {self._pshell_stdout}\n" + f": > {self._pshell_stderr}\n" + f": > {self._pshell_status}\n" + ) + self._send_to_shell(truncate) + escaped = command.replace("'", "'\\''") + + ipc_script = ( + f"cd {shlex.quote(work_dir)}\n" + f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n" + f"__EC=$?\n" + f"pwd > {self._pshell_cwd}\n" + f"echo {cmd_id}:$__EC > {self._pshell_status}\n" + ) + self._send_to_shell(ipc_script) + deadline = time.monotonic() + timeout + poll_interval = self._poll_interval + + while True: + if is_interrupted(): + self._kill_shell_children() + output, _, _ = self._read_persistent_output() + return { + "output": output + "\n[Command interrupted]", + "returncode": 130, + } + + if time.monotonic() > deadline: + self._kill_shell_children() + output, _, _ = self._read_persistent_output() + if output: + return { + "output": output + f"\n[Command timed out after {timeout}s]", + "returncode": 124, + } + return self._timeout_result(timeout) + + if not self._shell_alive: + return { + "output": "Persistent shell died during execution", + "returncode": 1, + } + + status_content = self._read_temp_files(self._pshell_status)[0].strip() + if status_content.startswith(cmd_id + ":"): + break + + time.sleep(poll_interval) + + output, exit_code, new_cwd = self._read_persistent_output() + if new_cwd: + self.cwd = new_cwd + return {"output": output, "returncode": exit_code} + + @staticmethod + def _merge_output(stdout: str, stderr: str) -> str: + parts = [] + if stdout.strip(): + parts.append(stdout.rstrip("\n")) + if stderr.strip(): + parts.append(stderr.rstrip("\n")) + return "\n".join(parts) diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 83cc335b1..90532dda0 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -8,12 +8,13 @@ import time from pathlib import Path from tools.environments.base import BaseEnvironment +from tools.environments.persistent_shell import PersistentShellMixin from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) -class SSHEnvironment(BaseEnvironment): +class SSHEnvironment(PersistentShellMixin, BaseEnvironment): """Run commands on a remote machine over SSH. Uses SSH ControlMaster for connection persistence so subsequent @@ -22,22 +23,33 @@ class SSHEnvironment(BaseEnvironment): Foreground commands are interruptible: the local ssh process is killed and a remote kill is attempted over the ControlMaster socket. + + When ``persistent=True``, a single long-lived bash shell is kept alive + over SSH and state (cwd, env vars, shell variables) persists across + ``execute()`` calls. Output capture uses file-based IPC on the remote + host (stdout/stderr/exit-code written to temp files, polled via fast + ControlMaster one-shot reads). """ def __init__(self, host: str, user: str, cwd: str = "~", - timeout: int = 60, port: int = 22, key_path: str = ""): + timeout: int = 60, port: int = 22, key_path: str = "", + persistent: bool = False): super().__init__(cwd=cwd, timeout=timeout) self.host = host self.user = user self.port = port self.key_path = key_path + self.persistent = persistent self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" self.control_dir.mkdir(parents=True, exist_ok=True) self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" self._establish_connection() - def _build_ssh_command(self, extra_args: list = None) -> list: + if self.persistent: + self._init_persistent_shell() + + def _build_ssh_command(self, extra_args: list | None = None) -> list: cmd = ["ssh"] cmd.extend(["-o", f"ControlPath={self.control_socket}"]) cmd.extend(["-o", "ControlMaster=auto"]) @@ -65,15 +77,76 @@ class SSHEnvironment(BaseEnvironment): except subprocess.TimeoutExpired: raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + _poll_interval: float = 0.15 + + @property + def _temp_prefix(self) -> str: + return f"/tmp/hermes-ssh-{self._session_id}" + + def _spawn_shell_process(self) -> subprocess.Popen: + cmd = self._build_ssh_command() + cmd.append("bash -l") + return subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True, + ) + + def _read_temp_files(self, *paths: str) -> list[str]: + if len(paths) == 1: + cmd = self._build_ssh_command() + cmd.append(f"cat {paths[0]} 2>/dev/null") + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=10, + ) + return [result.stdout] + except (subprocess.TimeoutExpired, OSError): + return [""] + + delim = f"__HERMES_SEP_{self._session_id}__" + script = "; ".join( + f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths + ) + cmd = self._build_ssh_command() + cmd.append(script) + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=10, + ) + parts = result.stdout.split(delim + "\n") + return [parts[i] if i < len(parts) else "" for i in range(len(paths))] + except (subprocess.TimeoutExpired, OSError): + return [""] * len(paths) + + def _kill_shell_children(self): + if self._shell_pid is None: + return + cmd = self._build_ssh_command() + cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true") + try: + subprocess.run(cmd, capture_output=True, timeout=5) + except (subprocess.TimeoutExpired, OSError): + pass + + def _cleanup_temp_files(self): + cmd = self._build_ssh_command() + cmd.append(f"rm -f {self._temp_prefix}-*") + try: + subprocess.run(cmd, capture_output=True, timeout=5) + except (subprocess.TimeoutExpired, OSError): + pass + + def _execute_oneshot(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: work_dir = cwd or self.cwd exec_command, sudo_stdin = self._prepare_command(command) wrapped = f'cd {work_dir} && {exec_command}' effective_timeout = timeout or self.timeout - # Merge sudo password (if any) with caller-supplied stdin_data. if sudo_stdin is not None and stdin_data is not None: effective_stdin = sudo_stdin + stdin_data elif sudo_stdin is not None: @@ -82,66 +155,60 @@ class SSHEnvironment(BaseEnvironment): effective_stdin = stdin_data cmd = self._build_ssh_command() - cmd.extend(["bash", "-c", wrapped]) + cmd.append(wrapped) - try: - kwargs = self._build_run_kwargs(timeout, effective_stdin) - # Remove timeout from kwargs -- we handle it in the poll loop - kwargs.pop("timeout", None) + kwargs = self._build_run_kwargs(timeout, effective_stdin) + kwargs.pop("timeout", None) + _output_chunks = [] + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, + text=True, + ) - _output_chunks = [] + if effective_stdin: + try: + proc.stdin.write(effective_stdin) + proc.stdin.close() + except (BrokenPipeError, OSError): + pass - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) + def _drain(): + try: + for line in proc.stdout: + _output_chunks.append(line) + except Exception: + pass - if effective_stdin: + reader = threading.Thread(target=_drain, daemon=True) + reader.start() + deadline = time.monotonic() + effective_timeout + + while proc.poll() is None: + if is_interrupted(): + proc.terminate() try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: + proc.wait(timeout=1) + except subprocess.TimeoutExpired: proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) + reader.join(timeout=2) + return { + "output": "".join(_output_chunks) + "\n[Command interrupted]", + "returncode": 130, + } + if time.monotonic() > deadline: + proc.kill() + reader.join(timeout=2) + return self._timeout_result(effective_timeout) + time.sleep(0.2) - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - - except Exception as e: - return {"output": f"SSH execution error: {str(e)}", "returncode": 1} + reader.join(timeout=5) + return {"output": "".join(_output_chunks), "returncode": proc.returncode} def cleanup(self): + super().cleanup() if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", diff --git a/tools/file_tools.py b/tools/file_tools.py index e2535b06a..b41fc893e 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -114,12 +114,31 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations: "container_persistent": config.get("container_persistent", True), "docker_volumes": config.get("docker_volumes", []), } + + ssh_config = None + if env_type == "ssh": + ssh_config = { + "host": config.get("ssh_host", ""), + "user": config.get("ssh_user", ""), + "port": config.get("ssh_port", 22), + "key": config.get("ssh_key", ""), + "persistent": config.get("ssh_persistent", False), + } + + local_config = None + if env_type == "local": + local_config = { + "persistent": config.get("local_persistent", False), + } + terminal_env = _create_environment( env_type=env_type, image=image, cwd=cwd, timeout=config["timeout"], + ssh_config=ssh_config, container_config=container_config, + local_config=local_config, task_id=task_id, ) diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index bf1d2b6b3..327e12210 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -471,6 +471,8 @@ def _get_env_config() -> Dict[str, Any]: # is running inside the container/remote). if env_type == "local": default_cwd = os.getcwd() + elif env_type == "ssh": + default_cwd = "~" else: default_cwd = "/root" @@ -503,6 +505,8 @@ def _get_env_config() -> Dict[str, Any]: "ssh_user": os.getenv("TERMINAL_SSH_USER", ""), "ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"), "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), + "ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"), + "local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"), # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB) @@ -514,6 +518,7 @@ def _get_env_config() -> Dict[str, Any]: def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_config: dict = None, container_config: dict = None, + local_config: dict = None, task_id: str = "default"): """ Create an execution environment from mini-swe-agent. @@ -538,7 +543,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, volumes = cc.get("docker_volumes", []) if env_type == "local": - return _LocalEnvironment(cwd=cwd, timeout=timeout) + lc = local_config or {} + return _LocalEnvironment(cwd=cwd, timeout=timeout, + persistent=lc.get("persistent", False)) elif env_type == "docker": return _DockerEnvironment( @@ -594,6 +601,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, key_path=ssh_config.get("key", ""), cwd=cwd, timeout=timeout, + persistent=ssh_config.get("persistent", False), ) else: @@ -923,6 +931,7 @@ def terminal_tool( "user": config.get("ssh_user", ""), "port": config.get("ssh_port", 22), "key": config.get("ssh_key", ""), + "persistent": config.get("ssh_persistent", False), } container_config = None @@ -935,6 +944,12 @@ def terminal_tool( "docker_volumes": config.get("docker_volumes", []), } + local_config = None + if env_type == "local": + local_config = { + "persistent": config.get("local_persistent", False), + } + new_env = _create_environment( env_type=env_type, image=image, @@ -942,6 +957,7 @@ def terminal_tool( timeout=effective_timeout, ssh_config=ssh_config, container_config=container_config, + local_config=local_config, task_id=effective_task_id, ) except ImportError as e: