wip: add persistent shell to ssh and local terminal backends

This commit is contained in:
balyan.sid@gmail.com
2026-03-13 16:54:11 +05:30
parent 9d63dcc3f9
commit 861202b56c
6 changed files with 842 additions and 277 deletions

View File

@@ -0,0 +1,183 @@
"""Tests for the local persistent shell backend.
Unit tests cover config plumbing (no real shell needed).
Integration tests run real commands — no external dependencies required.
pytest tests/tools/test_local_persistent.py -v
"""
import glob as glob_mod
import pytest
from tools.environments.local import LocalEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
# ---------------------------------------------------------------------------
# Unit tests — config plumbing
# ---------------------------------------------------------------------------
class TestLocalConfig:
def test_local_persistent_default_false(self, monkeypatch):
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is False
def test_local_persistent_true(self, monkeypatch):
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is True
def test_local_persistent_yes(self, monkeypatch):
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is True
class TestMergeOutput:
"""Test the shared _merge_output static method."""
def test_stdout_only(self):
assert PersistentShellMixin._merge_output("out", "") == "out"
def test_stderr_only(self):
assert PersistentShellMixin._merge_output("", "err") == "err"
def test_both(self):
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
def test_empty(self):
assert PersistentShellMixin._merge_output("", "") == ""
def test_strips_trailing_newlines(self):
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
# ---------------------------------------------------------------------------
# One-shot regression tests — ensure refactor didn't break anything
# ---------------------------------------------------------------------------
class TestLocalOneShotRegression:
"""Verify one-shot mode still works after adding the mixin."""
def test_echo(self):
env = LocalEnvironment(persistent=False)
r = env.execute("echo hello")
assert r["returncode"] == 0
assert "hello" in r["output"]
env.cleanup()
def test_exit_code(self):
env = LocalEnvironment(persistent=False)
r = env.execute("exit 42")
assert r["returncode"] == 42
env.cleanup()
def test_state_does_not_persist(self):
"""Env vars set in one command should NOT survive in one-shot mode."""
env = LocalEnvironment(persistent=False)
env.execute("export HERMES_ONESHOT_LOCAL=yes")
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
# In one-shot mode, env var should not persist
assert r["output"].strip() == ""
env.cleanup()
# ---------------------------------------------------------------------------
# Persistent shell integration tests
# ---------------------------------------------------------------------------
class TestLocalPersistent:
"""Persistent mode: state persists across execute() calls."""
@pytest.fixture
def env(self):
e = LocalEnvironment(persistent=True)
yield e
e.cleanup()
def test_echo(self, env):
r = env.execute("echo hello-persistent")
assert r["returncode"] == 0
assert "hello-persistent" in r["output"]
def test_env_var_persists(self, env):
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
assert r["output"].strip() == "works"
def test_cwd_persists(self, env):
env.execute("cd /tmp")
r = env.execute("pwd")
assert r["output"].strip() == "/tmp"
def test_exit_code(self, env):
r = env.execute("(exit 42)")
assert r["returncode"] == 42
def test_stderr(self, env):
r = env.execute("echo oops >&2")
assert r["returncode"] == 0
assert "oops" in r["output"]
def test_multiline_output(self, env):
r = env.execute("echo a; echo b; echo c")
lines = r["output"].strip().splitlines()
assert lines == ["a", "b", "c"]
def test_timeout_then_recovery(self, env):
r = env.execute("sleep 999", timeout=2)
assert r["returncode"] in (124, 130) # timeout or interrupted
# Shell should survive — next command works
r = env.execute("echo alive")
assert r["returncode"] == 0
assert "alive" in r["output"]
def test_large_output(self, env):
r = env.execute("seq 1 1000")
assert r["returncode"] == 0
lines = r["output"].strip().splitlines()
assert len(lines) == 1000
assert lines[0] == "1"
assert lines[-1] == "1000"
def test_shell_variable_persists(self, env):
"""Shell variables (not exported) should also persist."""
env.execute("MY_LOCAL_VAR=hello123")
r = env.execute("echo $MY_LOCAL_VAR")
assert r["output"].strip() == "hello123"
def test_cleanup_removes_temp_files(self, env):
env.execute("echo warmup")
prefix = env._temp_prefix
# Temp files should exist
assert len(glob_mod.glob(f"{prefix}-*")) > 0
env.cleanup()
remaining = glob_mod.glob(f"{prefix}-*")
assert remaining == []
def test_state_does_not_leak_between_instances(self):
"""Two separate persistent instances don't share state."""
env1 = LocalEnvironment(persistent=True)
env2 = LocalEnvironment(persistent=True)
try:
env1.execute("export LEAK_TEST=from_env1")
r = env2.execute("echo $LEAK_TEST")
assert r["output"].strip() == ""
finally:
env1.cleanup()
env2.cleanup()
def test_special_characters_in_command(self, env):
"""Commands with quotes and special chars should work."""
r = env.execute("echo 'hello world'")
assert r["output"].strip() == "hello world"
def test_pipe_command(self, env):
r = env.execute("echo hello | tr 'h' 'H'")
assert r["output"].strip() == "Hello"
def test_multiple_commands_semicolon(self, env):
r = env.execute("X=42; echo $X")
assert r["output"].strip() == "42"

View File

@@ -0,0 +1,198 @@
"""Tests for the SSH remote execution environment backend.
Unit tests (no SSH required) cover pure logic: command building, output merging,
config plumbing.
Integration tests require a real SSH target. Set TERMINAL_SSH_HOST and
TERMINAL_SSH_USER to enable them. In CI, start an sshd container or enable
the localhost SSH service.
TERMINAL_SSH_HOST=localhost TERMINAL_SSH_USER=$(whoami) \
pytest tests/tools/test_ssh_environment.py -v
"""
import json
import os
import subprocess
from unittest.mock import MagicMock
import pytest
from tools.environments.ssh import SSHEnvironment
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22"))
_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "")
_has_ssh = bool(_SSH_HOST and _SSH_USER)
requires_ssh = pytest.mark.skipif(
not _has_ssh,
reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set",
)
def _run(command, task_id="ssh_test", **kwargs):
"""Call terminal_tool like an LLM would, return parsed JSON."""
from tools.terminal_tool import terminal_tool
return json.loads(terminal_tool(command, task_id=task_id, **kwargs))
def _cleanup(task_id="ssh_test"):
from tools.terminal_tool import cleanup_vm
cleanup_vm(task_id)
# ---------------------------------------------------------------------------
# Unit tests — no SSH connection needed
# ---------------------------------------------------------------------------
class TestBuildSSHCommand:
"""Pure logic: verify the ssh command list is assembled correctly."""
@pytest.fixture(autouse=True)
def _mock_connection(self, monkeypatch):
monkeypatch.setattr("tools.environments.ssh.subprocess.run",
lambda *a, **k: subprocess.CompletedProcess([], 0))
monkeypatch.setattr("tools.environments.ssh.subprocess.Popen",
lambda *a, **k: MagicMock(stdout=iter([]),
stderr=iter([]),
stdin=MagicMock()))
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
def test_base_flags(self):
env = SSHEnvironment(host="h", user="u")
cmd = " ".join(env._build_ssh_command())
for flag in ("ControlMaster=auto", "ControlPersist=300",
"BatchMode=yes", "StrictHostKeyChecking=accept-new"):
assert flag in cmd
def test_custom_port(self):
env = SSHEnvironment(host="h", user="u", port=2222)
cmd = env._build_ssh_command()
assert "-p" in cmd and "2222" in cmd
def test_key_path(self):
env = SSHEnvironment(host="h", user="u", key_path="/k")
cmd = env._build_ssh_command()
assert "-i" in cmd and "/k" in cmd
def test_user_host_suffix(self):
env = SSHEnvironment(host="h", user="u")
assert env._build_ssh_command()[-1] == "u@h"
class TestTerminalToolConfig:
def test_ssh_persistent_default_false(self, monkeypatch):
monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
from tools.terminal_tool import _get_env_config
assert _get_env_config()["ssh_persistent"] is False
def test_ssh_persistent_true(self, monkeypatch):
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["ssh_persistent"] is True
# ---------------------------------------------------------------------------
# Integration tests — real SSH, through terminal_tool() interface
# ---------------------------------------------------------------------------
def _setup_ssh_env(monkeypatch, persistent: bool):
"""Configure env vars for SSH integration tests."""
monkeypatch.setenv("TERMINAL_ENV", "ssh")
monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false")
if _SSH_PORT != 22:
monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT))
if _SSH_KEY:
monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY)
@requires_ssh
class TestOneShotSSH:
"""One-shot mode: each command is a fresh ssh invocation."""
@pytest.fixture(autouse=True)
def _setup(self, monkeypatch):
_setup_ssh_env(monkeypatch, persistent=False)
yield
_cleanup()
def test_echo(self):
r = _run("echo hello")
assert r["exit_code"] == 0
assert "hello" in r["output"]
def test_exit_code(self):
r = _run("exit 42")
assert r["exit_code"] == 42
def test_state_does_not_persist(self):
"""Env vars set in one command should NOT survive to the next."""
_run("export HERMES_ONESHOT_TEST=yes")
r = _run("echo $HERMES_ONESHOT_TEST")
assert r["output"].strip() == ""
@requires_ssh
class TestPersistentSSH:
"""Persistent mode: single long-lived shell, state persists."""
@pytest.fixture(autouse=True)
def _setup(self, monkeypatch):
_setup_ssh_env(monkeypatch, persistent=True)
yield
_cleanup()
def test_echo(self):
r = _run("echo hello-persistent")
assert r["exit_code"] == 0
assert "hello-persistent" in r["output"]
def test_env_var_persists(self):
_run("export HERMES_PERSIST_TEST=works")
r = _run("echo $HERMES_PERSIST_TEST")
assert r["output"].strip() == "works"
def test_cwd_persists(self):
_run("cd /tmp")
r = _run("pwd")
assert r["output"].strip() == "/tmp"
def test_exit_code(self):
r = _run("(exit 42)")
assert r["exit_code"] == 42
def test_stderr(self):
r = _run("echo oops >&2")
assert r["exit_code"] == 0
assert "oops" in r["output"]
def test_multiline_output(self):
r = _run("echo a; echo b; echo c")
lines = r["output"].strip().splitlines()
assert lines == ["a", "b", "c"]
def test_timeout_then_recovery(self):
r = _run("sleep 999", timeout=2)
assert r["exit_code"] == 124
# Shell should survive — next command works
r = _run("echo alive")
assert r["exit_code"] == 0
assert "alive" in r["output"]
def test_large_output(self):
r = _run("seq 1 1000")
assert r["exit_code"] == 0
lines = r["output"].strip().splitlines()
assert len(lines) == 1000
assert lines[0] == "1"
assert lines[-1] == "1000"

View File

@@ -11,6 +11,8 @@ import time
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
# Unique marker to isolate real command output from shell init/exit noise.
# printf (no trailing newline) keeps the boundaries clean for splitting.
@@ -162,6 +164,25 @@ def _clean_shell_noise(output: str) -> str:
return result
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
def _make_run_env(env: dict) -> dict:
"""Build a run environment with a sane PATH and provider-var stripping."""
merged = dict(os.environ | env)
run_env = {}
for k, v in merged.items():
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
run_env[real_key] = v
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST:
run_env[k] = v
existing_path = run_env.get("PATH", "")
if "/usr/bin" not in existing_path.split(":"):
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
return run_env
def _extract_fenced_output(raw: str) -> str:
"""Extract real command output from between fence markers.
@@ -186,7 +207,7 @@ def _extract_fenced_output(raw: str) -> str:
return raw[start:last]
class LocalEnvironment(BaseEnvironment):
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
"""Run commands directly on the host machine.
Features:
@@ -195,24 +216,72 @@ class LocalEnvironment(BaseEnvironment):
- stdin_data support for piping content (bypasses ARG_MAX limits)
- sudo -S transform via SUDO_PASSWORD env var
- Uses interactive login shell so full user env is available
- Optional persistent shell mode (cwd/env vars survive across calls)
"""
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
persistent: bool = False):
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
self.persistent = persistent
if self.persistent:
self._init_persistent_shell()
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
from tools.terminal_tool import _interrupt_event
# ------------------------------------------------------------------
# PersistentShellMixin: backend-specific implementations
# ------------------------------------------------------------------
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-local-{self._session_id}"
def _spawn_shell_process(self) -> subprocess.Popen:
user_shell = _find_bash()
run_env = _make_run_env(self.env)
return subprocess.Popen(
[user_shell, "-l"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=run_env,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
def _read_temp_files(self, *paths: str) -> list[str]:
"""Read local files directly."""
results = []
for path in paths:
try:
with open(path) as f:
results.append(f.read())
except OSError:
results.append("")
return results
def _kill_shell_children(self):
"""Kill children of the persistent shell via pkill -P."""
if self._shell_pid is None:
return
try:
subprocess.run(
["pkill", "-P", str(self._shell_pid)],
capture_output=True, timeout=5,
)
except (subprocess.TimeoutExpired, OSError, FileNotFoundError):
pass
# ------------------------------------------------------------------
# One-shot execution (original behavior)
# ------------------------------------------------------------------
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
work_dir = cwd or self.cwd or os.getcwd()
effective_timeout = timeout or self.timeout
exec_command, sudo_stdin = self._prepare_command(command)
# Merge the sudo password (if any) with caller-supplied stdin_data.
# sudo -S reads exactly one line (the password) then passes the rest
# of stdin to the child, so prepending is safe even when stdin_data
# is also present.
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
@@ -221,13 +290,7 @@ class LocalEnvironment(BaseEnvironment):
effective_stdin = stdin_data
try:
# The fence wrapper uses bash syntax (semicolons, $?, printf).
# Always use bash for the wrapper — NOT $SHELL which could be
# fish, zsh, or another shell with incompatible syntax.
# The -lic flags source rc files so tools like nvm/pyenv work.
user_shell = _find_bash()
# Wrap with output fences so we can later extract the real
# command output and discard shell init/exit noise.
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
@@ -235,24 +298,7 @@ class LocalEnvironment(BaseEnvironment):
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
)
# Ensure PATH always includes standard dirs — systemd services
# and some terminal multiplexers inherit a minimal PATH.
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# Strip Hermes-internal provider vars so external CLIs
# (e.g. codex) are not silently misrouted. Callers that
# truly need a blocked var can opt in by prefixing the key
# with _HERMES_FORCE_ in self.env (e.g. _HERMES_FORCE_OPENAI_API_KEY).
merged = dict(os.environ | self.env)
run_env = {}
for k, v in merged.items():
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
run_env[real_key] = v
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST:
run_env[k] = v
existing_path = run_env.get("PATH", "")
if "/usr/bin" not in existing_path.split(":"):
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
run_env = _make_run_env(self.env)
proc = subprocess.Popen(
[user_shell, "-lic", fenced_cmd],
@@ -295,7 +341,7 @@ class LocalEnvironment(BaseEnvironment):
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if _interrupt_event.is_set():
if is_interrupted():
try:
if _IS_WINDOWS:
proc.terminate()
@@ -332,5 +378,21 @@ class LocalEnvironment(BaseEnvironment):
except Exception as e:
return {"output": f"Execution error: {str(e)}", "returncode": 1}
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
def cleanup(self):
pass
if self.persistent:
self._cleanup_persistent_shell()

View File

@@ -0,0 +1,308 @@
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells.
Provides the shared logic for maintaining a persistent bash shell across
execute() calls. Backend-specific operations (spawning the shell, reading
temp files, killing child processes) are implemented by subclasses via
abstract methods.
The IPC protocol writes each command's stdout/stderr/exit-code/cwd to temp
files, then polls the status file for completion. A daemon thread drains
the shell's stdout to prevent pipe deadlock and detect shell death.
"""
import glob as glob_mod
import logging
import os
import shlex
import subprocess
import threading
import time
import uuid
from abc import abstractmethod
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class PersistentShellMixin:
"""Mixin that adds persistent shell capability to any BaseEnvironment.
Subclasses MUST implement:
_spawn_shell_process() -> subprocess.Popen
_read_temp_files(*paths) -> list[str]
_kill_shell_children()
Subclasses MUST also provide ``_execute_oneshot()`` for the stdin_data
fallback path (commands with piped stdin cannot use the persistent shell).
"""
# -- State (initialized by _init_persistent_shell) ---------------------
_shell_proc: subprocess.Popen | None = None
_shell_alive: bool = False
_shell_pid: int | None = None
_session_id: str = ""
# -- Abstract methods (backend-specific) -------------------------------
@abstractmethod
def _spawn_shell_process(self) -> subprocess.Popen:
"""Spawn a long-lived bash shell and return the Popen handle.
Must use ``stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True``.
"""
...
@abstractmethod
def _read_temp_files(self, *paths: str) -> list[str]:
"""Read temp files from the execution context.
Returns contents in the same order as *paths*. Falls back to
empty strings on failure.
"""
...
@abstractmethod
def _kill_shell_children(self):
"""Kill the running command's processes but keep the shell alive."""
...
# -- Overridable properties --------------------------------------------
@property
def _temp_prefix(self) -> str:
"""Base path for temp files. Override per backend."""
return f"/tmp/hermes-persistent-{self._session_id}"
# -- Shared implementation ---------------------------------------------
def _init_persistent_shell(self):
"""Call from ``__init__`` when ``persistent=True``."""
self._shell_lock = threading.Lock()
self._session_id = ""
self._shell_proc = None
self._shell_alive = False
self._shell_pid = None
self._start_persistent_shell()
def _start_persistent_shell(self):
"""Spawn the shell, create temp files, capture PID."""
self._session_id = uuid.uuid4().hex[:12]
p = self._temp_prefix
self._pshell_stdout = f"{p}-stdout"
self._pshell_stderr = f"{p}-stderr"
self._pshell_status = f"{p}-status"
self._pshell_cwd = f"{p}-cwd"
self._pshell_pid_file = f"{p}-pid"
self._shell_proc = self._spawn_shell_process()
self._shell_alive = True
self._drain_thread = threading.Thread(
target=self._drain_shell_output, daemon=True,
)
self._drain_thread.start()
# Initialize temp files and capture shell PID
init_script = (
f"touch {self._pshell_stdout} {self._pshell_stderr} "
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
f"echo $$ > {self._pshell_pid_file}\n"
f"pwd > {self._pshell_cwd}\n"
)
self._send_to_shell(init_script)
# Poll for PID file
deadline = time.monotonic() + 3.0
while time.monotonic() < deadline:
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
if pid_str.isdigit():
self._shell_pid = int(pid_str)
break
time.sleep(0.05)
else:
logger.warning("Could not read persistent shell PID")
self._shell_pid = None
if self._shell_pid:
logger.info(
"Persistent shell started (session=%s, pid=%d)",
self._session_id, self._shell_pid,
)
# Update cwd from what the shell reports
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
if reported_cwd:
self.cwd = reported_cwd
def _drain_shell_output(self):
"""Drain stdout to prevent pipe deadlock; detect shell death."""
try:
for _ in self._shell_proc.stdout:
pass # Real output goes to temp files
except Exception:
pass
self._shell_alive = False
def _send_to_shell(self, text: str):
"""Write text to the persistent shell's stdin."""
if not self._shell_alive or self._shell_proc is None:
return
try:
self._shell_proc.stdin.write(text)
self._shell_proc.stdin.flush()
except (BrokenPipeError, OSError):
self._shell_alive = False
def _read_persistent_output(self) -> tuple[str, int, str]:
"""Read stdout, stderr, status, cwd. Returns (output, exit_code, cwd)."""
stdout, stderr, status_raw, cwd = self._read_temp_files(
self._pshell_stdout, self._pshell_stderr,
self._pshell_status, self._pshell_cwd,
)
output = self._merge_output(stdout, stderr)
# Status format: "cmd_id:exit_code" — strip the ID prefix
status = status_raw.strip()
if ":" in status:
status = status.split(":", 1)[1]
try:
exit_code = int(status.strip())
except ValueError:
exit_code = 1
return output, exit_code, cwd.strip()
def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
"""Execute a command in the persistent shell."""
if not self._shell_alive:
logger.info("Persistent shell died, restarting...")
self._start_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout
# Fall back to one-shot for commands needing piped stdin
if stdin_data or sudo_stdin:
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
with self._shell_lock:
return self._execute_persistent_locked(
exec_command, cwd, effective_timeout,
)
def _execute_persistent_locked(self, command: str, cwd: str,
timeout: int) -> dict:
"""Inner persistent execution — caller must hold ``_shell_lock``."""
work_dir = cwd or self.cwd
# Each command gets a unique ID written into the status file so the
# poll loop can distinguish the *current* command's result from a
# stale value left over from the previous command. This eliminates
# the race where a fast local file read sees the old status before
# the shell has processed the truncation.
cmd_id = uuid.uuid4().hex[:8]
# Truncate temp files
truncate = (
f": > {self._pshell_stdout}\n"
f": > {self._pshell_stderr}\n"
f": > {self._pshell_status}\n"
)
self._send_to_shell(truncate)
# Escape command for eval
escaped = command.replace("'", "'\\''")
ipc_script = (
f"cd {shlex.quote(work_dir)}\n"
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
f"__EC=$?\n"
f"pwd > {self._pshell_cwd}\n"
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
)
self._send_to_shell(ipc_script)
# Poll the status file for current command's ID
deadline = time.monotonic() + timeout
poll_interval = 0.15
while True:
if is_interrupted():
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
return {
"output": output + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
if output:
return {
"output": output + f"\n[Command timed out after {timeout}s]",
"returncode": 124,
}
return self._timeout_result(timeout)
if not self._shell_alive:
return {
"output": "Persistent shell died during execution",
"returncode": 1,
}
status_content = self._read_temp_files(self._pshell_status)[0].strip()
if status_content.startswith(cmd_id + ":"):
break
time.sleep(poll_interval)
output, exit_code, new_cwd = self._read_persistent_output()
if new_cwd:
self.cwd = new_cwd
return {"output": output, "returncode": exit_code}
@staticmethod
def _merge_output(stdout: str, stderr: str) -> str:
"""Combine stdout and stderr into a single output string."""
parts = []
if stdout.strip():
parts.append(stdout.rstrip("\n"))
if stderr.strip():
parts.append(stderr.rstrip("\n"))
return "\n".join(parts)
def _cleanup_persistent_shell(self):
"""Clean up persistent shell resources. Call from ``cleanup()``."""
if self._shell_proc is None:
return
if self._session_id:
self._cleanup_temp_files()
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except subprocess.TimeoutExpired:
self._shell_proc.kill()
self._shell_alive = False
self._shell_proc = None
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0)
def _cleanup_temp_files(self):
"""Remove local temp files. Override for remote backends (SSH, Docker)."""
for f in glob_mod.glob(f"{self._temp_prefix}-*"):
try:
os.remove(f)
except OSError:
pass

View File

@@ -1,21 +1,20 @@
"""SSH remote execution environment with ControlMaster connection persistence."""
import logging
import shlex
import subprocess
import tempfile
import threading
import time
import uuid
from pathlib import Path
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class SSHEnvironment(BaseEnvironment):
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
"""Run commands on a remote machine over SSH.
Uses SSH ControlMaster for connection persistence so subsequent
@@ -47,22 +46,10 @@ class SSHEnvironment(BaseEnvironment):
self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
self._establish_connection()
# Persistent shell state
self._shell_proc: subprocess.Popen | None = None
self._shell_lock = threading.Lock()
self._shell_alive = False
self._session_id: str = ""
self._remote_stdout: str = ""
self._remote_stderr: str = ""
self._remote_status: str = ""
self._remote_cwd: str = ""
self._remote_pid: str = ""
self._remote_shell_pid: int | None = None
if self.persistent:
self._start_persistent_shell()
self._init_persistent_shell()
def _build_ssh_command(self, extra_args: list = None) -> list:
def _build_ssh_command(self, extra_args: list | None = None) -> list:
cmd = ["ssh"]
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
cmd.extend(["-o", "ControlMaster=auto"])
@@ -91,230 +78,70 @@ class SSHEnvironment(BaseEnvironment):
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
# ------------------------------------------------------------------
# Persistent shell management
# PersistentShellMixin: backend-specific implementations
# ------------------------------------------------------------------
def _start_persistent_shell(self):
"""Spawn a long-lived bash shell over SSH."""
self._session_id = uuid.uuid4().hex[:12]
prefix = f"/tmp/hermes-ssh-{self._session_id}"
self._remote_stdout = f"{prefix}-stdout"
self._remote_stderr = f"{prefix}-stderr"
self._remote_status = f"{prefix}-status"
self._remote_cwd = f"{prefix}-cwd"
self._remote_pid = f"{prefix}-pid"
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-ssh-{self._session_id}"
def _spawn_shell_process(self) -> subprocess.Popen:
cmd = self._build_ssh_command()
cmd.append("bash -l")
self._shell_proc = subprocess.Popen(
return subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self._shell_alive = True
# Start daemon thread to drain stdout/stderr and detect shell death
self._drain_thread = threading.Thread(
target=self._drain_shell_output, daemon=True
def _read_temp_files(self, *paths: str) -> list[str]:
"""Read remote files via ControlMaster one-shot SSH calls."""
if len(paths) == 1:
cmd = self._build_ssh_command()
cmd.append(f"cat {paths[0]} 2>/dev/null")
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10,
)
return [result.stdout]
except (subprocess.TimeoutExpired, OSError):
return [""]
delim = f"__HERMES_SEP_{self._session_id}__"
script = "; ".join(
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
)
self._drain_thread.start()
# Initialize remote temp files and capture shell PID
init_script = (
f"touch {self._remote_stdout} {self._remote_stderr} "
f"{self._remote_status} {self._remote_cwd} {self._remote_pid}\n"
f"echo $$ > {self._remote_pid}\n"
f"pwd > {self._remote_cwd}\n"
)
self._send_to_shell(init_script)
# Give shell time to initialize and write PID file
time.sleep(0.3)
# Read the remote shell PID
pid_str = self._read_remote_file(self._remote_pid).strip()
if pid_str.isdigit():
self._remote_shell_pid = int(pid_str)
logger.info("Persistent shell started (session=%s, pid=%d)",
self._session_id, self._remote_shell_pid)
else:
logger.warning("Could not read persistent shell PID (got %r)", pid_str)
self._remote_shell_pid = None
# Update cwd from what the shell reports
remote_cwd = self._read_remote_file(self._remote_cwd).strip()
if remote_cwd:
self.cwd = remote_cwd
def _drain_shell_output(self):
"""Drain the shell's stdout/stderr to prevent pipe deadlock.
Also detects when the shell process dies.
"""
try:
for _ in self._shell_proc.stdout:
pass # Discard — real output goes to temp files
except Exception:
pass
self._shell_alive = False
def _send_to_shell(self, text: str):
"""Write text to the persistent shell's stdin."""
if not self._shell_alive or self._shell_proc is None:
return
try:
self._shell_proc.stdin.write(text)
self._shell_proc.stdin.flush()
except (BrokenPipeError, OSError):
self._shell_alive = False
def _read_remote_file(self, path: str) -> str:
"""Read a file on the remote host via a one-shot SSH command.
Uses ControlMaster so this is very fast (~5ms on LAN).
"""
cmd = self._build_ssh_command()
cmd.append(f"cat {path} 2>/dev/null")
cmd.append(script)
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10
cmd, capture_output=True, text=True, timeout=10,
)
return result.stdout
parts = result.stdout.split(delim + "\n")
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
except (subprocess.TimeoutExpired, OSError):
return ""
return [""] * len(paths)
def _kill_shell_children(self):
"""Kill children of the persistent shell (the running command),
but not the shell itself."""
if self._remote_shell_pid is None:
if self._shell_pid is None:
return
cmd = self._build_ssh_command()
cmd.append(f"pkill -P {self._remote_shell_pid} 2>/dev/null; true")
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
"""Execute a command in the persistent shell."""
# If shell is dead, restart it
if not self._shell_alive:
logger.info("Persistent shell died, restarting...")
self._start_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout
# Fall back to one-shot for commands needing piped stdin
if stdin_data or sudo_stdin:
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data
)
with self._shell_lock:
return self._execute_persistent_locked(
exec_command, cwd, effective_timeout
)
def _execute_persistent_locked(self, command: str, cwd: str,
timeout: int) -> dict:
"""Inner persistent execution — caller must hold _shell_lock."""
work_dir = cwd or self.cwd
# Truncate temp files
truncate = (
f": > {self._remote_stdout}\n"
f": > {self._remote_stderr}\n"
f": > {self._remote_status}\n"
)
self._send_to_shell(truncate)
# Escape command for eval — use single quotes with proper escaping
escaped = command.replace("'", "'\\''")
# Send the IPC script
ipc_script = (
f"cd {shlex.quote(work_dir)}\n"
f"eval '{escaped}' < /dev/null > {self._remote_stdout} 2> {self._remote_stderr}\n"
f"__EC=$?\n"
f"pwd > {self._remote_cwd}\n"
f"echo $__EC > {self._remote_status}\n"
)
self._send_to_shell(ipc_script)
# Poll the status file
deadline = time.monotonic() + timeout
poll_interval = 0.05 # 50ms
while True:
if is_interrupted():
self._kill_shell_children()
stdout = self._read_remote_file(self._remote_stdout)
stderr = self._read_remote_file(self._remote_stderr)
output = self._merge_output(stdout, stderr)
return {
"output": output + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
self._kill_shell_children()
stdout = self._read_remote_file(self._remote_stdout)
stderr = self._read_remote_file(self._remote_stderr)
output = self._merge_output(stdout, stderr)
if output:
return {
"output": output + f"\n[Command timed out after {timeout}s]",
"returncode": 124,
}
return self._timeout_result(timeout)
if not self._shell_alive:
return {
"output": "Persistent shell died during execution",
"returncode": 1,
}
# Check if status file has content (command is done)
status_content = self._read_remote_file(self._remote_status).strip()
if status_content:
break
time.sleep(poll_interval)
# Read results
stdout = self._read_remote_file(self._remote_stdout)
stderr = self._read_remote_file(self._remote_stderr)
exit_code_str = status_content
new_cwd = self._read_remote_file(self._remote_cwd).strip()
# Parse exit code
def _cleanup_temp_files(self):
"""Remove remote temp files via SSH."""
try:
exit_code = int(exit_code_str)
except ValueError:
exit_code = 1
# Update cwd
if new_cwd:
self.cwd = new_cwd
output = self._merge_output(stdout, stderr)
return {"output": output, "returncode": exit_code}
@staticmethod
def _merge_output(stdout: str, stderr: str) -> str:
"""Combine stdout and stderr into a single output string."""
parts = []
if stdout.strip():
parts.append(stdout.rstrip("\n"))
if stderr.strip():
parts.append(stderr.rstrip("\n"))
return "\n".join(parts)
cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*")
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass
# ------------------------------------------------------------------
# One-shot execution (original behavior)
@@ -413,34 +240,9 @@ class SSHEnvironment(BaseEnvironment):
)
def cleanup(self):
# Persistent shell teardown
if self.persistent and self._shell_proc is not None:
# Remove remote temp files
if self._session_id:
try:
cmd = self._build_ssh_command()
cmd.append(
f"rm -f /tmp/hermes-ssh-{self._session_id}-*"
)
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass
# Close the shell
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except Exception:
try:
self._shell_proc.kill()
except Exception:
pass
self._shell_alive = False
self._shell_proc = None
# Persistent shell teardown (via mixin)
if self.persistent:
self._cleanup_persistent_shell()
# ControlMaster cleanup
if self.control_socket.exists():

View File

@@ -504,6 +504,8 @@ def _get_env_config() -> Dict[str, Any]:
"ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"),
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
"ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Local persistent shell (cwd/env vars survive across calls)
"local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"),
"container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB)
@@ -515,6 +517,7 @@ def _get_env_config() -> Dict[str, Any]:
def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
ssh_config: dict = None, container_config: dict = None,
local_config: dict = None,
task_id: str = "default"):
"""
Create an execution environment from mini-swe-agent.
@@ -539,7 +542,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
volumes = cc.get("docker_volumes", [])
if env_type == "local":
return _LocalEnvironment(cwd=cwd, timeout=timeout)
lc = local_config or {}
return _LocalEnvironment(cwd=cwd, timeout=timeout,
persistent=lc.get("persistent", False))
elif env_type == "docker":
return _DockerEnvironment(
@@ -938,6 +943,12 @@ def terminal_tool(
"docker_volumes": config.get("docker_volumes", []),
}
local_config = None
if env_type == "local":
local_config = {
"persistent": config.get("local_persistent", False),
}
new_env = _create_environment(
env_type=env_type,
image=image,
@@ -945,6 +956,7 @@ def terminal_tool(
timeout=effective_timeout,
ssh_config=ssh_config,
container_config=container_config,
local_config=local_config,
task_id=effective_task_id,
)
except ImportError as e: