# ---------------------------------------------------------------------------
# Unit tests — config plumbing
# ---------------------------------------------------------------------------

class TestLocalConfig:
    """``TERMINAL_LOCAL_PERSISTENT`` is surfaced as ``local_persistent``."""

    def test_local_persistent_default_false(self, monkeypatch):
        monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
        from tools.terminal_tool import _get_env_config
        assert _get_env_config()["local_persistent"] is False

    def test_local_persistent_true(self, monkeypatch):
        monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
        from tools.terminal_tool import _get_env_config
        assert _get_env_config()["local_persistent"] is True

    def test_local_persistent_yes(self, monkeypatch):
        monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
        from tools.terminal_tool import _get_env_config
        assert _get_env_config()["local_persistent"] is True


class TestMergeOutput:
    """Test the shared _merge_output static method."""

    def test_stdout_only(self):
        assert PersistentShellMixin._merge_output("out", "") == "out"

    def test_stderr_only(self):
        assert PersistentShellMixin._merge_output("", "err") == "err"

    def test_both(self):
        assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"

    def test_empty(self):
        assert PersistentShellMixin._merge_output("", "") == ""

    def test_strips_trailing_newlines(self):
        assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"


# ---------------------------------------------------------------------------
# One-shot regression tests — ensure refactor didn't break anything
# ---------------------------------------------------------------------------

class TestLocalOneShotRegression:
    """Verify one-shot mode still works after adding the mixin."""

    @pytest.fixture
    def env(self):
        # Teardown lives in the fixture so cleanup() still runs when an
        # assertion in the test body fails.
        e = LocalEnvironment(persistent=False)
        yield e
        e.cleanup()

    def test_echo(self, env):
        r = env.execute("echo hello")
        assert r["returncode"] == 0
        assert "hello" in r["output"]

    def test_exit_code(self, env):
        r = env.execute("exit 42")
        assert r["returncode"] == 42

    def test_state_does_not_persist(self, env):
        """Env vars set in one command should NOT survive in one-shot mode."""
        env.execute("export HERMES_ONESHOT_LOCAL=yes")
        r = env.execute("echo $HERMES_ONESHOT_LOCAL")
        # In one-shot mode, env var should not persist
        assert r["output"].strip() == ""


# ---------------------------------------------------------------------------
# Persistent shell integration tests
# ---------------------------------------------------------------------------

class TestLocalPersistent:
    """Persistent mode: state persists across execute() calls."""

    @pytest.fixture
    def env(self):
        e = LocalEnvironment(persistent=True)
        yield e
        e.cleanup()

    def test_echo(self, env):
        r = env.execute("echo hello-persistent")
        assert r["returncode"] == 0
        assert "hello-persistent" in r["output"]

    def test_env_var_persists(self, env):
        env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
        r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
        assert r["output"].strip() == "works"

    def test_cwd_persists(self, env):
        env.execute("cd /tmp")
        r = env.execute("pwd")
        assert r["output"].strip() == "/tmp"

    def test_exit_code(self, env):
        # A subshell is used so the exit status is reported without
        # terminating the persistent shell itself.
        r = env.execute("(exit 42)")
        assert r["returncode"] == 42

    def test_stderr(self, env):
        r = env.execute("echo oops >&2")
        assert r["returncode"] == 0
        assert "oops" in r["output"]

    def test_multiline_output(self, env):
        r = env.execute("echo a; echo b; echo c")
        lines = r["output"].strip().splitlines()
        assert lines == ["a", "b", "c"]

    def test_timeout_then_recovery(self, env):
        r = env.execute("sleep 999", timeout=2)
        assert r["returncode"] in (124, 130)  # timeout or interrupted
        # Shell should survive — next command works
        r = env.execute("echo alive")
        assert r["returncode"] == 0
        assert "alive" in r["output"]

    def test_large_output(self, env):
        r = env.execute("seq 1 1000")
        assert r["returncode"] == 0
        lines = r["output"].strip().splitlines()
        assert len(lines) == 1000
        assert lines[0] == "1"
        assert lines[-1] == "1000"

    def test_shell_variable_persists(self, env):
        """Shell variables (not exported) should also persist."""
        env.execute("MY_LOCAL_VAR=hello123")
        r = env.execute("echo $MY_LOCAL_VAR")
        assert r["output"].strip() == "hello123"

    def test_cleanup_removes_temp_files(self, env):
        env.execute("echo warmup")
        prefix = env._temp_prefix
        # Temp files should exist
        assert len(glob_mod.glob(f"{prefix}-*")) > 0
        env.cleanup()
        remaining = glob_mod.glob(f"{prefix}-*")
        assert remaining == []

    def test_state_does_not_leak_between_instances(self):
        """Two separate persistent instances don't share state."""
        env1 = LocalEnvironment(persistent=True)
        try:
            # Constructed inside the outer try so env1 is still cleaned up
            # if spawning the second shell fails.
            env2 = LocalEnvironment(persistent=True)
            try:
                env1.execute("export LEAK_TEST=from_env1")
                r = env2.execute("echo $LEAK_TEST")
                assert r["output"].strip() == ""
            finally:
                env2.cleanup()
        finally:
            env1.cleanup()

    def test_special_characters_in_command(self, env):
        """Commands with quotes and special chars should work."""
        r = env.execute("echo 'hello world'")
        assert r["output"].strip() == "hello world"

    def test_pipe_command(self, env):
        r = env.execute("echo hello | tr 'h' 'H'")
        assert r["output"].strip() == "Hello"

    def test_multiple_commands_semicolon(self, env):
        r = env.execute("X=42; echo $X")
        assert r["output"].strip() == "42"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _ssh_port_from_env(default: int = 22) -> int:
    """Return ``TERMINAL_SSH_PORT`` as an int, or *default* if unset/invalid.

    Parsing is tolerant because this runs at import time: an empty or
    non-numeric value must not abort collection of the unit tests in this
    module, which need no SSH target at all.
    """
    raw = os.getenv("TERMINAL_SSH_PORT", "").strip()
    try:
        return int(raw) if raw else default
    except ValueError:
        return default


_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
_SSH_PORT = _ssh_port_from_env()
_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "")

_has_ssh = bool(_SSH_HOST and _SSH_USER)

requires_ssh = pytest.mark.skipif(
    not _has_ssh,
    reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set",
)


def _run(command, task_id="ssh_test", **kwargs):
    """Call terminal_tool like an LLM would, return parsed JSON."""
    from tools.terminal_tool import terminal_tool
    return json.loads(terminal_tool(command, task_id=task_id, **kwargs))


def _cleanup(task_id="ssh_test"):
    """Tear down the environment that ``_run`` created for *task_id*."""
    from tools.terminal_tool import cleanup_vm
    cleanup_vm(task_id)


# ---------------------------------------------------------------------------
# Unit tests — no SSH connection needed
# ---------------------------------------------------------------------------

class TestBuildSSHCommand:
    """Pure logic: verify the ssh command list is assembled correctly."""

    @pytest.fixture(autouse=True)
    def _mock_connection(self, monkeypatch):
        # Stub out every process/sleep call the ssh module makes so that
        # constructing SSHEnvironment never touches the network.
        monkeypatch.setattr("tools.environments.ssh.subprocess.run",
                            lambda *a, **k: subprocess.CompletedProcess([], 0))
        monkeypatch.setattr("tools.environments.ssh.subprocess.Popen",
                            lambda *a, **k: MagicMock(stdout=iter([]),
                                                      stderr=iter([]),
                                                      stdin=MagicMock()))
        monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)

    def test_base_flags(self):
        env = SSHEnvironment(host="h", user="u")
        cmd = " ".join(env._build_ssh_command())
        for flag in ("ControlMaster=auto", "ControlPersist=300",
                     "BatchMode=yes", "StrictHostKeyChecking=accept-new"):
            assert flag in cmd

    def test_custom_port(self):
        env = SSHEnvironment(host="h", user="u", port=2222)
        cmd = env._build_ssh_command()
        assert "-p" in cmd and "2222" in cmd

    def test_key_path(self):
        env = SSHEnvironment(host="h", user="u", key_path="/k")
        cmd = env._build_ssh_command()
        assert "-i" in cmd and "/k" in cmd

    def test_user_host_suffix(self):
        env = SSHEnvironment(host="h", user="u")
        assert env._build_ssh_command()[-1] == "u@h"


class TestTerminalToolConfig:
    """``TERMINAL_SSH_PERSISTENT`` is surfaced as ``ssh_persistent``."""

    def test_ssh_persistent_default_false(self, monkeypatch):
        monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
        from tools.terminal_tool import _get_env_config
        assert _get_env_config()["ssh_persistent"] is False

    def test_ssh_persistent_true(self, monkeypatch):
        monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true")
        from tools.terminal_tool import _get_env_config
        assert _get_env_config()["ssh_persistent"] is True


# ---------------------------------------------------------------------------
# Integration tests — real SSH, through terminal_tool() interface
# ---------------------------------------------------------------------------

def _setup_ssh_env(monkeypatch, persistent: bool):
    """Configure env vars for SSH integration tests."""
    monkeypatch.setenv("TERMINAL_ENV", "ssh")
    monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
    monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
    monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false")
    if _SSH_PORT != 22:
        monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT))
    if _SSH_KEY:
        monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY)


@requires_ssh
class TestOneShotSSH:
    """One-shot mode: each command is a fresh ssh invocation."""

    @pytest.fixture(autouse=True)
    def _setup(self, monkeypatch):
        _setup_ssh_env(monkeypatch, persistent=False)
        yield
        _cleanup()

    def test_echo(self):
        r = _run("echo hello")
        assert r["exit_code"] == 0
        assert "hello" in r["output"]

    def test_exit_code(self):
        r = _run("exit 42")
        assert r["exit_code"] == 42

    def test_state_does_not_persist(self):
        """Env vars set in one command should NOT survive to the next."""
        _run("export HERMES_ONESHOT_TEST=yes")
        r = _run("echo $HERMES_ONESHOT_TEST")
        assert r["output"].strip() == ""


@requires_ssh
class TestPersistentSSH:
    """Persistent mode: single long-lived shell, state persists."""

    @pytest.fixture(autouse=True)
    def _setup(self, monkeypatch):
        _setup_ssh_env(monkeypatch, persistent=True)
        yield
        _cleanup()

    def test_echo(self):
        r = _run("echo hello-persistent")
        assert r["exit_code"] == 0
        assert "hello-persistent" in r["output"]

    def test_env_var_persists(self):
        _run("export HERMES_PERSIST_TEST=works")
        r = _run("echo $HERMES_PERSIST_TEST")
        assert r["output"].strip() == "works"

    def test_cwd_persists(self):
        _run("cd /tmp")
        r = _run("pwd")
        assert r["output"].strip() == "/tmp"

    def test_exit_code(self):
        # A subshell is used so the exit status is reported without
        # terminating the persistent shell itself.
        r = _run("(exit 42)")
        assert r["exit_code"] == 42

    def test_stderr(self):
        r = _run("echo oops >&2")
        assert r["exit_code"] == 0
        assert "oops" in r["output"]

    def test_multiline_output(self):
        r = _run("echo a; echo b; echo c")
        lines = r["output"].strip().splitlines()
        assert lines == ["a", "b", "c"]

    def test_timeout_then_recovery(self):
        r = _run("sleep 999", timeout=2)
        assert r["exit_code"] == 124
        # Shell should survive — next command works
        r = _run("echo alive")
        assert r["exit_code"] == 0
        assert "alive" in r["output"]

    def test_large_output(self):
        r = _run("seq 1 1000")
        assert r["exit_code"] == 0
        lines = r["output"].strip().splitlines()
        assert len(lines) == 1000
        assert lines[0] == "1"
        assert lines[-1] == "1000"
# Standard directories appended to PATH when "/usr/bin" is missing from it.
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"


def _make_run_env(env: dict | None) -> dict:
    """Build a run environment with a sane PATH and provider-var stripping.

    Args:
        env: Extra variables layered on top of ``os.environ``; values here
            win on key collisions.  ``None`` is treated as an empty mapping.

    Returns:
        A new dict suitable for ``subprocess.Popen(env=...)``:

        * keys starting with ``_HERMES_PROVIDER_ENV_FORCE_PREFIX`` are
          exported under the name with that prefix removed;
        * keys listed in ``_HERMES_PROVIDER_ENV_BLOCKLIST`` are dropped;
        * every other key is passed through unchanged;
        * if ``/usr/bin`` is not one of the ``:``-separated ``PATH``
          entries, ``_SANE_PATH`` is appended (or used as the whole
          ``PATH`` when none was set).
    """
    # ``env`` may legitimately be None (it is the constructor default of
    # LocalEnvironment), and ``os.environ | None`` would raise TypeError.
    merged = {**os.environ, **(env or {})}

    prefix_len = len(_HERMES_PROVIDER_ENV_FORCE_PREFIX)
    run_env = {}
    for key, value in merged.items():
        if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
            # Explicit opt-in: re-export under the real variable name.
            run_env[key[prefix_len:]] = value
        elif key not in _HERMES_PROVIDER_ENV_BLOCKLIST:
            run_env[key] = value

    existing_path = run_env.get("PATH", "")
    if "/usr/bin" not in existing_path.split(":"):
        run_env["PATH"] = (
            f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
        )
    return run_env
Features: @@ -195,24 +216,72 @@ class LocalEnvironment(BaseEnvironment): - stdin_data support for piping content (bypasses ARG_MAX limits) - sudo -S transform via SUDO_PASSWORD env var - Uses interactive login shell so full user env is available + - Optional persistent shell mode (cwd/env vars survive across calls) """ - def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): + def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None, + persistent: bool = False): super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env) + self.persistent = persistent + if self.persistent: + self._init_persistent_shell() - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - from tools.terminal_tool import _interrupt_event + # ------------------------------------------------------------------ + # PersistentShellMixin: backend-specific implementations + # ------------------------------------------------------------------ + @property + def _temp_prefix(self) -> str: + return f"/tmp/hermes-local-{self._session_id}" + + def _spawn_shell_process(self) -> subprocess.Popen: + user_shell = _find_bash() + run_env = _make_run_env(self.env) + return subprocess.Popen( + [user_shell, "-l"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=run_env, + preexec_fn=None if _IS_WINDOWS else os.setsid, + ) + + def _read_temp_files(self, *paths: str) -> list[str]: + """Read local files directly.""" + results = [] + for path in paths: + try: + with open(path) as f: + results.append(f.read()) + except OSError: + results.append("") + return results + + def _kill_shell_children(self): + """Kill children of the persistent shell via pkill -P.""" + if self._shell_pid is None: + return + try: + subprocess.run( + ["pkill", "-P", str(self._shell_pid)], + capture_output=True, timeout=5, + ) + except (subprocess.TimeoutExpired, OSError, FileNotFoundError): + 
pass + + # ------------------------------------------------------------------ + # One-shot execution (original behavior) + # ------------------------------------------------------------------ + + def _execute_oneshot(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: work_dir = cwd or self.cwd or os.getcwd() effective_timeout = timeout or self.timeout exec_command, sudo_stdin = self._prepare_command(command) # Merge the sudo password (if any) with caller-supplied stdin_data. - # sudo -S reads exactly one line (the password) then passes the rest - # of stdin to the child, so prepending is safe even when stdin_data - # is also present. if sudo_stdin is not None and stdin_data is not None: effective_stdin = sudo_stdin + stdin_data elif sudo_stdin is not None: @@ -221,13 +290,7 @@ class LocalEnvironment(BaseEnvironment): effective_stdin = stdin_data try: - # The fence wrapper uses bash syntax (semicolons, $?, printf). - # Always use bash for the wrapper — NOT $SHELL which could be - # fish, zsh, or another shell with incompatible syntax. - # The -lic flags source rc files so tools like nvm/pyenv work. user_shell = _find_bash() - # Wrap with output fences so we can later extract the real - # command output and discard shell init/exit noise. fenced_cmd = ( f"printf '{_OUTPUT_FENCE}';" f" {exec_command};" @@ -235,24 +298,7 @@ class LocalEnvironment(BaseEnvironment): f" printf '{_OUTPUT_FENCE}';" f" exit $__hermes_rc" ) - # Ensure PATH always includes standard dirs — systemd services - # and some terminal multiplexers inherit a minimal PATH. - _SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" - # Strip Hermes-internal provider vars so external CLIs - # (e.g. codex) are not silently misrouted. Callers that - # truly need a blocked var can opt in by prefixing the key - # with _HERMES_FORCE_ in self.env (e.g. _HERMES_FORCE_OPENAI_API_KEY). 
- merged = dict(os.environ | self.env) - run_env = {} - for k, v in merged.items(): - if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX): - real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):] - run_env[real_key] = v - elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST: - run_env[k] = v - existing_path = run_env.get("PATH", "") - if "/usr/bin" not in existing_path.split(":"): - run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH + run_env = _make_run_env(self.env) proc = subprocess.Popen( [user_shell, "-lic", fenced_cmd], @@ -295,7 +341,7 @@ class LocalEnvironment(BaseEnvironment): deadline = time.monotonic() + effective_timeout while proc.poll() is None: - if _interrupt_event.is_set(): + if is_interrupted(): try: if _IS_WINDOWS: proc.terminate() @@ -332,5 +378,21 @@ class LocalEnvironment(BaseEnvironment): except Exception as e: return {"output": f"Execution error: {str(e)}", "returncode": 1} + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + if self.persistent: + return self._execute_persistent( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + def cleanup(self): - pass + if self.persistent: + self._cleanup_persistent_shell() diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py new file mode 100644 index 000000000..f0bd438f0 --- /dev/null +++ b/tools/environments/persistent_shell.py @@ -0,0 +1,308 @@ +"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells. + +Provides the shared logic for maintaining a persistent bash shell across +execute() calls. 
Backend-specific operations (spawning the shell, reading +temp files, killing child processes) are implemented by subclasses via +abstract methods. + +The IPC protocol writes each command's stdout/stderr/exit-code/cwd to temp +files, then polls the status file for completion. A daemon thread drains +the shell's stdout to prevent pipe deadlock and detect shell death. +""" + +import glob as glob_mod +import logging +import os +import shlex +import subprocess +import threading +import time +import uuid +from abc import abstractmethod + +from tools.interrupt import is_interrupted + +logger = logging.getLogger(__name__) + + +class PersistentShellMixin: + """Mixin that adds persistent shell capability to any BaseEnvironment. + + Subclasses MUST implement: + _spawn_shell_process() -> subprocess.Popen + _read_temp_files(*paths) -> list[str] + _kill_shell_children() + + Subclasses MUST also provide ``_execute_oneshot()`` for the stdin_data + fallback path (commands with piped stdin cannot use the persistent shell). + """ + + # -- State (initialized by _init_persistent_shell) --------------------- + _shell_proc: subprocess.Popen | None = None + _shell_alive: bool = False + _shell_pid: int | None = None + _session_id: str = "" + + # -- Abstract methods (backend-specific) ------------------------------- + + @abstractmethod + def _spawn_shell_process(self) -> subprocess.Popen: + """Spawn a long-lived bash shell and return the Popen handle. + + Must use ``stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True``. + """ + ... + + @abstractmethod + def _read_temp_files(self, *paths: str) -> list[str]: + """Read temp files from the execution context. + + Returns contents in the same order as *paths*. Falls back to + empty strings on failure. + """ + ... + + @abstractmethod + def _kill_shell_children(self): + """Kill the running command's processes but keep the shell alive.""" + ... 
+ + # -- Overridable properties -------------------------------------------- + + @property + def _temp_prefix(self) -> str: + """Base path for temp files. Override per backend.""" + return f"/tmp/hermes-persistent-{self._session_id}" + + # -- Shared implementation --------------------------------------------- + + def _init_persistent_shell(self): + """Call from ``__init__`` when ``persistent=True``.""" + self._shell_lock = threading.Lock() + self._session_id = "" + self._shell_proc = None + self._shell_alive = False + self._shell_pid = None + self._start_persistent_shell() + + def _start_persistent_shell(self): + """Spawn the shell, create temp files, capture PID.""" + self._session_id = uuid.uuid4().hex[:12] + p = self._temp_prefix + self._pshell_stdout = f"{p}-stdout" + self._pshell_stderr = f"{p}-stderr" + self._pshell_status = f"{p}-status" + self._pshell_cwd = f"{p}-cwd" + self._pshell_pid_file = f"{p}-pid" + + self._shell_proc = self._spawn_shell_process() + self._shell_alive = True + + self._drain_thread = threading.Thread( + target=self._drain_shell_output, daemon=True, + ) + self._drain_thread.start() + + # Initialize temp files and capture shell PID + init_script = ( + f"touch {self._pshell_stdout} {self._pshell_stderr} " + f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" + f"echo $$ > {self._pshell_pid_file}\n" + f"pwd > {self._pshell_cwd}\n" + ) + self._send_to_shell(init_script) + + # Poll for PID file + deadline = time.monotonic() + 3.0 + while time.monotonic() < deadline: + pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() + if pid_str.isdigit(): + self._shell_pid = int(pid_str) + break + time.sleep(0.05) + else: + logger.warning("Could not read persistent shell PID") + self._shell_pid = None + + if self._shell_pid: + logger.info( + "Persistent shell started (session=%s, pid=%d)", + self._session_id, self._shell_pid, + ) + + # Update cwd from what the shell reports + reported_cwd = 
self._read_temp_files(self._pshell_cwd)[0].strip() + if reported_cwd: + self.cwd = reported_cwd + + def _drain_shell_output(self): + """Drain stdout to prevent pipe deadlock; detect shell death.""" + try: + for _ in self._shell_proc.stdout: + pass # Real output goes to temp files + except Exception: + pass + self._shell_alive = False + + def _send_to_shell(self, text: str): + """Write text to the persistent shell's stdin.""" + if not self._shell_alive or self._shell_proc is None: + return + try: + self._shell_proc.stdin.write(text) + self._shell_proc.stdin.flush() + except (BrokenPipeError, OSError): + self._shell_alive = False + + def _read_persistent_output(self) -> tuple[str, int, str]: + """Read stdout, stderr, status, cwd. Returns (output, exit_code, cwd).""" + stdout, stderr, status_raw, cwd = self._read_temp_files( + self._pshell_stdout, self._pshell_stderr, + self._pshell_status, self._pshell_cwd, + ) + output = self._merge_output(stdout, stderr) + # Status format: "cmd_id:exit_code" — strip the ID prefix + status = status_raw.strip() + if ":" in status: + status = status.split(":", 1)[1] + try: + exit_code = int(status.strip()) + except ValueError: + exit_code = 1 + return output, exit_code, cwd.strip() + + def _execute_persistent(self, command: str, cwd: str, *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + """Execute a command in the persistent shell.""" + if not self._shell_alive: + logger.info("Persistent shell died, restarting...") + self._start_persistent_shell() + + exec_command, sudo_stdin = self._prepare_command(command) + effective_timeout = timeout or self.timeout + + # Fall back to one-shot for commands needing piped stdin + if stdin_data or sudo_stdin: + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data, + ) + + with self._shell_lock: + return self._execute_persistent_locked( + exec_command, cwd, effective_timeout, + ) + + def _execute_persistent_locked(self, command: str, cwd: 
str, + timeout: int) -> dict: + """Inner persistent execution — caller must hold ``_shell_lock``.""" + work_dir = cwd or self.cwd + + # Each command gets a unique ID written into the status file so the + # poll loop can distinguish the *current* command's result from a + # stale value left over from the previous command. This eliminates + # the race where a fast local file read sees the old status before + # the shell has processed the truncation. + cmd_id = uuid.uuid4().hex[:8] + + # Truncate temp files + truncate = ( + f": > {self._pshell_stdout}\n" + f": > {self._pshell_stderr}\n" + f": > {self._pshell_status}\n" + ) + self._send_to_shell(truncate) + + # Escape command for eval + escaped = command.replace("'", "'\\''") + + ipc_script = ( + f"cd {shlex.quote(work_dir)}\n" + f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n" + f"__EC=$?\n" + f"pwd > {self._pshell_cwd}\n" + f"echo {cmd_id}:$__EC > {self._pshell_status}\n" + ) + self._send_to_shell(ipc_script) + + # Poll the status file for current command's ID + deadline = time.monotonic() + timeout + poll_interval = 0.15 + + while True: + if is_interrupted(): + self._kill_shell_children() + output, _, _ = self._read_persistent_output() + return { + "output": output + "\n[Command interrupted]", + "returncode": 130, + } + + if time.monotonic() > deadline: + self._kill_shell_children() + output, _, _ = self._read_persistent_output() + if output: + return { + "output": output + f"\n[Command timed out after {timeout}s]", + "returncode": 124, + } + return self._timeout_result(timeout) + + if not self._shell_alive: + return { + "output": "Persistent shell died during execution", + "returncode": 1, + } + + status_content = self._read_temp_files(self._pshell_status)[0].strip() + if status_content.startswith(cmd_id + ":"): + break + + time.sleep(poll_interval) + + output, exit_code, new_cwd = self._read_persistent_output() + if new_cwd: + self.cwd = new_cwd + return {"output": output, 
"returncode": exit_code} + + @staticmethod + def _merge_output(stdout: str, stderr: str) -> str: + """Combine stdout and stderr into a single output string.""" + parts = [] + if stdout.strip(): + parts.append(stdout.rstrip("\n")) + if stderr.strip(): + parts.append(stderr.rstrip("\n")) + return "\n".join(parts) + + def _cleanup_persistent_shell(self): + """Clean up persistent shell resources. Call from ``cleanup()``.""" + if self._shell_proc is None: + return + + if self._session_id: + self._cleanup_temp_files() + + try: + self._shell_proc.stdin.close() + except Exception: + pass + try: + self._shell_proc.terminate() + self._shell_proc.wait(timeout=3) + except subprocess.TimeoutExpired: + self._shell_proc.kill() + + self._shell_alive = False + self._shell_proc = None + + if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): + self._drain_thread.join(timeout=1.0) + + def _cleanup_temp_files(self): + """Remove local temp files. Override for remote backends (SSH, Docker).""" + for f in glob_mod.glob(f"{self._temp_prefix}-*"): + try: + os.remove(f) + except OSError: + pass diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 7a31006db..13893dedd 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -1,21 +1,20 @@ """SSH remote execution environment with ControlMaster connection persistence.""" import logging -import shlex import subprocess import tempfile import threading import time -import uuid from pathlib import Path from tools.environments.base import BaseEnvironment +from tools.environments.persistent_shell import PersistentShellMixin from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) -class SSHEnvironment(BaseEnvironment): +class SSHEnvironment(PersistentShellMixin, BaseEnvironment): """Run commands on a remote machine over SSH. 
Uses SSH ControlMaster for connection persistence so subsequent @@ -47,22 +46,10 @@ class SSHEnvironment(BaseEnvironment): self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" self._establish_connection() - # Persistent shell state - self._shell_proc: subprocess.Popen | None = None - self._shell_lock = threading.Lock() - self._shell_alive = False - self._session_id: str = "" - self._remote_stdout: str = "" - self._remote_stderr: str = "" - self._remote_status: str = "" - self._remote_cwd: str = "" - self._remote_pid: str = "" - self._remote_shell_pid: int | None = None - if self.persistent: - self._start_persistent_shell() + self._init_persistent_shell() - def _build_ssh_command(self, extra_args: list = None) -> list: + def _build_ssh_command(self, extra_args: list | None = None) -> list: cmd = ["ssh"] cmd.extend(["-o", f"ControlPath={self.control_socket}"]) cmd.extend(["-o", "ControlMaster=auto"]) @@ -91,230 +78,70 @@ class SSHEnvironment(BaseEnvironment): raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") # ------------------------------------------------------------------ - # Persistent shell management + # PersistentShellMixin: backend-specific implementations # ------------------------------------------------------------------ - def _start_persistent_shell(self): - """Spawn a long-lived bash shell over SSH.""" - self._session_id = uuid.uuid4().hex[:12] - prefix = f"/tmp/hermes-ssh-{self._session_id}" - self._remote_stdout = f"{prefix}-stdout" - self._remote_stderr = f"{prefix}-stderr" - self._remote_status = f"{prefix}-status" - self._remote_cwd = f"{prefix}-cwd" - self._remote_pid = f"{prefix}-pid" + @property + def _temp_prefix(self) -> str: + return f"/tmp/hermes-ssh-{self._session_id}" + def _spawn_shell_process(self) -> subprocess.Popen: cmd = self._build_ssh_command() cmd.append("bash -l") - - self._shell_proc = subprocess.Popen( + return subprocess.Popen( cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE, text=True, ) - self._shell_alive = True - # Start daemon thread to drain stdout/stderr and detect shell death - self._drain_thread = threading.Thread( - target=self._drain_shell_output, daemon=True + def _read_temp_files(self, *paths: str) -> list[str]: + """Read remote files via ControlMaster one-shot SSH calls.""" + if len(paths) == 1: + cmd = self._build_ssh_command() + cmd.append(f"cat {paths[0]} 2>/dev/null") + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=10, + ) + return [result.stdout] + except (subprocess.TimeoutExpired, OSError): + return [""] + + delim = f"__HERMES_SEP_{self._session_id}__" + script = "; ".join( + f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths ) - self._drain_thread.start() - - # Initialize remote temp files and capture shell PID - init_script = ( - f"touch {self._remote_stdout} {self._remote_stderr} " - f"{self._remote_status} {self._remote_cwd} {self._remote_pid}\n" - f"echo $$ > {self._remote_pid}\n" - f"pwd > {self._remote_cwd}\n" - ) - self._send_to_shell(init_script) - - # Give shell time to initialize and write PID file - time.sleep(0.3) - - # Read the remote shell PID - pid_str = self._read_remote_file(self._remote_pid).strip() - if pid_str.isdigit(): - self._remote_shell_pid = int(pid_str) - logger.info("Persistent shell started (session=%s, pid=%d)", - self._session_id, self._remote_shell_pid) - else: - logger.warning("Could not read persistent shell PID (got %r)", pid_str) - self._remote_shell_pid = None - - # Update cwd from what the shell reports - remote_cwd = self._read_remote_file(self._remote_cwd).strip() - if remote_cwd: - self.cwd = remote_cwd - - def _drain_shell_output(self): - """Drain the shell's stdout/stderr to prevent pipe deadlock. - - Also detects when the shell process dies. 
- """ - try: - for _ in self._shell_proc.stdout: - pass # Discard — real output goes to temp files - except Exception: - pass - self._shell_alive = False - - def _send_to_shell(self, text: str): - """Write text to the persistent shell's stdin.""" - if not self._shell_alive or self._shell_proc is None: - return - try: - self._shell_proc.stdin.write(text) - self._shell_proc.stdin.flush() - except (BrokenPipeError, OSError): - self._shell_alive = False - - def _read_remote_file(self, path: str) -> str: - """Read a file on the remote host via a one-shot SSH command. - - Uses ControlMaster so this is very fast (~5ms on LAN). - """ cmd = self._build_ssh_command() - cmd.append(f"cat {path} 2>/dev/null") + cmd.append(script) try: result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10 + cmd, capture_output=True, text=True, timeout=10, ) - return result.stdout + parts = result.stdout.split(delim + "\n") + return [parts[i] if i < len(parts) else "" for i in range(len(paths))] except (subprocess.TimeoutExpired, OSError): - return "" + return [""] * len(paths) def _kill_shell_children(self): - """Kill children of the persistent shell (the running command), - but not the shell itself.""" - if self._remote_shell_pid is None: + if self._shell_pid is None: return cmd = self._build_ssh_command() - cmd.append(f"pkill -P {self._remote_shell_pid} 2>/dev/null; true") + cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true") try: subprocess.run(cmd, capture_output=True, timeout=5) except (subprocess.TimeoutExpired, OSError): pass - def _execute_persistent(self, command: str, cwd: str, *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command in the persistent shell.""" - # If shell is dead, restart it - if not self._shell_alive: - logger.info("Persistent shell died, restarting...") - self._start_persistent_shell() - - exec_command, sudo_stdin = self._prepare_command(command) - effective_timeout = timeout or self.timeout - - 
# Fall back to one-shot for commands needing piped stdin - if stdin_data or sudo_stdin: - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data - ) - - with self._shell_lock: - return self._execute_persistent_locked( - exec_command, cwd, effective_timeout - ) - - def _execute_persistent_locked(self, command: str, cwd: str, - timeout: int) -> dict: - """Inner persistent execution — caller must hold _shell_lock.""" - work_dir = cwd or self.cwd - - # Truncate temp files - truncate = ( - f": > {self._remote_stdout}\n" - f": > {self._remote_stderr}\n" - f": > {self._remote_status}\n" - ) - self._send_to_shell(truncate) - - # Escape command for eval — use single quotes with proper escaping - escaped = command.replace("'", "'\\''") - - # Send the IPC script - ipc_script = ( - f"cd {shlex.quote(work_dir)}\n" - f"eval '{escaped}' < /dev/null > {self._remote_stdout} 2> {self._remote_stderr}\n" - f"__EC=$?\n" - f"pwd > {self._remote_cwd}\n" - f"echo $__EC > {self._remote_status}\n" - ) - self._send_to_shell(ipc_script) - - # Poll the status file - deadline = time.monotonic() + timeout - poll_interval = 0.05 # 50ms - - while True: - if is_interrupted(): - self._kill_shell_children() - stdout = self._read_remote_file(self._remote_stdout) - stderr = self._read_remote_file(self._remote_stderr) - output = self._merge_output(stdout, stderr) - return { - "output": output + "\n[Command interrupted]", - "returncode": 130, - } - - if time.monotonic() > deadline: - self._kill_shell_children() - stdout = self._read_remote_file(self._remote_stdout) - stderr = self._read_remote_file(self._remote_stderr) - output = self._merge_output(stdout, stderr) - if output: - return { - "output": output + f"\n[Command timed out after {timeout}s]", - "returncode": 124, - } - return self._timeout_result(timeout) - - if not self._shell_alive: - return { - "output": "Persistent shell died during execution", - "returncode": 1, - } - - # Check if status file has content 
(command is done) - status_content = self._read_remote_file(self._remote_status).strip() - if status_content: - break - - time.sleep(poll_interval) - - # Read results - stdout = self._read_remote_file(self._remote_stdout) - stderr = self._read_remote_file(self._remote_stderr) - exit_code_str = status_content - new_cwd = self._read_remote_file(self._remote_cwd).strip() - - # Parse exit code + def _cleanup_temp_files(self): + """Remove remote temp files via SSH.""" try: - exit_code = int(exit_code_str) - except ValueError: - exit_code = 1 - - # Update cwd - if new_cwd: - self.cwd = new_cwd - - output = self._merge_output(stdout, stderr) - return {"output": output, "returncode": exit_code} - - @staticmethod - def _merge_output(stdout: str, stderr: str) -> str: - """Combine stdout and stderr into a single output string.""" - parts = [] - if stdout.strip(): - parts.append(stdout.rstrip("\n")) - if stderr.strip(): - parts.append(stderr.rstrip("\n")) - return "\n".join(parts) + cmd = self._build_ssh_command() + cmd.append(f"rm -f {self._temp_prefix}-*") + subprocess.run(cmd, capture_output=True, timeout=5) + except (OSError, subprocess.SubprocessError): + pass # ------------------------------------------------------------------ # One-shot execution (original behavior) @@ -413,34 +240,9 @@ class SSHEnvironment(BaseEnvironment): ) def cleanup(self): - # Persistent shell teardown - if self.persistent and self._shell_proc is not None: - # Remove remote temp files - if self._session_id: - try: - cmd = self._build_ssh_command() - cmd.append( - f"rm -f /tmp/hermes-ssh-{self._session_id}-*" - ) - subprocess.run(cmd, capture_output=True, timeout=5) - except (OSError, subprocess.SubprocessError): - pass - - # Close the shell - try: - self._shell_proc.stdin.close() - except Exception: - pass - try: - self._shell_proc.terminate() - self._shell_proc.wait(timeout=3) - except Exception: - try: - self._shell_proc.kill() - except Exception: - pass - self._shell_alive = False - 
self._shell_proc = None + # Persistent shell teardown (via mixin) + if self.persistent: + self._cleanup_persistent_shell() # ControlMaster cleanup if self.control_socket.exists(): diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index c7f72040a..b273ec028 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -504,6 +504,8 @@ def _get_env_config() -> Dict[str, Any]: "ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"), "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), "ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"), + # Local persistent shell (cwd/env vars survive across calls) + "local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"), # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB) @@ -515,6 +517,7 @@ def _get_env_config() -> Dict[str, Any]: def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_config: dict = None, container_config: dict = None, + local_config: dict = None, task_id: str = "default"): """ Create an execution environment from mini-swe-agent. 
@@ -539,7 +542,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, volumes = cc.get("docker_volumes", []) if env_type == "local": - return _LocalEnvironment(cwd=cwd, timeout=timeout) + lc = local_config or {} + return _LocalEnvironment(cwd=cwd, timeout=timeout, + persistent=lc.get("persistent", False)) elif env_type == "docker": return _DockerEnvironment( @@ -938,6 +943,12 @@ def terminal_tool( "docker_volumes": config.get("docker_volumes", []), } + local_config = None + if env_type == "local": + local_config = { + "persistent": config.get("local_persistent", False), + } + new_env = _create_environment( env_type=env_type, image=image, @@ -945,6 +956,7 @@ def terminal_tool( timeout=effective_timeout, ssh_config=ssh_config, container_config=container_config, + local_config=local_config, task_id=effective_task_id, ) except ImportError as e: