diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 83cc335b1..7a31006db 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -1,10 +1,12 @@ """SSH remote execution environment with ControlMaster connection persistence.""" import logging +import shlex import subprocess import tempfile import threading import time +import uuid from pathlib import Path from tools.environments.base import BaseEnvironment @@ -22,21 +24,44 @@ class SSHEnvironment(BaseEnvironment): Foreground commands are interruptible: the local ssh process is killed and a remote kill is attempted over the ControlMaster socket. + + When ``persistent=True``, a single long-lived bash shell is kept alive + over SSH and state (cwd, env vars, shell variables) persists across + ``execute()`` calls. Output capture uses file-based IPC on the remote + host (stdout/stderr/exit-code written to temp files, polled via fast + ControlMaster one-shot reads). """ def __init__(self, host: str, user: str, cwd: str = "~", - timeout: int = 60, port: int = 22, key_path: str = ""): + timeout: int = 60, port: int = 22, key_path: str = "", + persistent: bool = False): super().__init__(cwd=cwd, timeout=timeout) self.host = host self.user = user self.port = port self.key_path = key_path + self.persistent = persistent self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" self.control_dir.mkdir(parents=True, exist_ok=True) self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" self._establish_connection() + # Persistent shell state + self._shell_proc: subprocess.Popen | None = None + self._shell_lock = threading.Lock() + self._shell_alive = False + self._session_id: str = "" + self._remote_stdout: str = "" + self._remote_stderr: str = "" + self._remote_status: str = "" + self._remote_cwd: str = "" + self._remote_pid: str = "" + self._remote_shell_pid: int | None = None + + if self.persistent: + self._start_persistent_shell() + def _build_ssh_command(self, extra_args: list = None) -> list: cmd = ["ssh"] cmd.extend(["-o", f"ControlPath={self.control_socket}"]) @@ -65,9 +90,240 @@ class SSHEnvironment(BaseEnvironment): except subprocess.TimeoutExpired: raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + # ------------------------------------------------------------------ + # Persistent shell management + # ------------------------------------------------------------------ + + def _start_persistent_shell(self): + """Spawn a long-lived bash shell over SSH.""" + self._session_id = uuid.uuid4().hex[:12] + prefix = f"/tmp/hermes-ssh-{self._session_id}" + self._remote_stdout = f"{prefix}-stdout" + self._remote_stderr = f"{prefix}-stderr" + self._remote_status = f"{prefix}-status" + self._remote_cwd = f"{prefix}-cwd" + self._remote_pid = f"{prefix}-pid" + + cmd = self._build_ssh_command() + cmd.append("bash -l") + + self._shell_proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + self._shell_alive = True + + # Start daemon thread to drain stdout/stderr and detect shell death + self._drain_thread = threading.Thread( + target=self._drain_shell_output, daemon=True + ) + self._drain_thread.start() + + # Initialize remote temp files and capture shell PID + init_script = ( + f"touch {self._remote_stdout} {self._remote_stderr} " + f"{self._remote_status} {self._remote_cwd} {self._remote_pid}\n" + f"echo $$ > {self._remote_pid}\n" + f"pwd > {self._remote_cwd}\n" + ) + self._send_to_shell(init_script) + + # Give shell time to initialize and write PID file + time.sleep(0.3) + + # Read the remote shell PID + pid_str = self._read_remote_file(self._remote_pid).strip() + if pid_str.isdigit(): + self._remote_shell_pid = int(pid_str) + logger.info("Persistent shell started (session=%s, pid=%d)", + self._session_id, self._remote_shell_pid) + else: + logger.warning("Could not read persistent shell PID (got %r)", pid_str) + self._remote_shell_pid = None + + # Update cwd from what the shell reports + remote_cwd = self._read_remote_file(self._remote_cwd).strip() + if remote_cwd: + self.cwd = remote_cwd + + def _drain_shell_output(self): + """Drain the shell's stdout/stderr to prevent pipe deadlock. + + Also detects when the shell process dies. + """ + try: + for _ in self._shell_proc.stdout: + pass # Discard — real output goes to temp files + except Exception: + pass + self._shell_alive = False + + def _send_to_shell(self, text: str): + """Write text to the persistent shell's stdin.""" + if not self._shell_alive or self._shell_proc is None: + return + try: + self._shell_proc.stdin.write(text) + self._shell_proc.stdin.flush() + except (BrokenPipeError, OSError): + self._shell_alive = False + + def _read_remote_file(self, path: str) -> str: + """Read a file on the remote host via a one-shot SSH command. + + Uses ControlMaster so this is very fast (~5ms on LAN). + """ + cmd = self._build_ssh_command() + cmd.append(f"cat {path} 2>/dev/null") + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=10 + ) + return result.stdout + except (subprocess.TimeoutExpired, OSError): + return "" + + def _kill_shell_children(self): + """Kill children of the persistent shell (the running command), + but not the shell itself.""" + if self._remote_shell_pid is None: + return + cmd = self._build_ssh_command() + cmd.append(f"pkill -P {self._remote_shell_pid} 2>/dev/null; true") + try: + subprocess.run(cmd, capture_output=True, timeout=5) + except (subprocess.TimeoutExpired, OSError): + pass + + def _execute_persistent(self, command: str, cwd: str, *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + """Execute a command in the persistent shell.""" + # If shell is dead, restart it + if not self._shell_alive: + logger.info("Persistent shell died, restarting...") + self._start_persistent_shell() + + exec_command, sudo_stdin = self._prepare_command(command) + effective_timeout = timeout or self.timeout + + # Fall back to one-shot for commands needing piped stdin + if stdin_data or sudo_stdin: + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data + ) + + with self._shell_lock: + return self._execute_persistent_locked( + exec_command, cwd, effective_timeout + ) + + def _execute_persistent_locked(self, command: str, cwd: str, + timeout: int) -> dict: + """Inner persistent execution — caller must hold _shell_lock.""" + work_dir = cwd or self.cwd + + # Truncate temp files + truncate = ( + f": > {self._remote_stdout}\n" + f": > {self._remote_stderr}\n" + f": > {self._remote_status}\n" + ) + self._send_to_shell(truncate) + + # Escape command for eval — use single quotes with proper escaping + escaped = command.replace("'", "'\\''") + + # Send the IPC script + ipc_script = ( + f"cd {shlex.quote(work_dir)}\n" + f"eval '{escaped}' < /dev/null > {self._remote_stdout} 2> {self._remote_stderr}\n" + f"__EC=$?\n" + f"pwd > {self._remote_cwd}\n" + f"echo $__EC > {self._remote_status}\n" + ) + self._send_to_shell(ipc_script) + + # Poll the status file + deadline = time.monotonic() + timeout + poll_interval = 0.05 # 50ms + + while True: + if is_interrupted(): + self._kill_shell_children() + stdout = self._read_remote_file(self._remote_stdout) + stderr = self._read_remote_file(self._remote_stderr) + output = self._merge_output(stdout, stderr) + return { + "output": output + "\n[Command interrupted]", + "returncode": 130, + } + + if time.monotonic() > deadline: + self._kill_shell_children() + stdout = self._read_remote_file(self._remote_stdout) + stderr = self._read_remote_file(self._remote_stderr) + output = self._merge_output(stdout, stderr) + if output: + return { + "output": output + f"\n[Command timed out after {timeout}s]", + "returncode": 124, + } + return self._timeout_result(timeout) + + if not self._shell_alive: + return { + "output": "Persistent shell died during execution", + "returncode": 1, + } + + # Check if status file has content (command is done) + status_content = self._read_remote_file(self._remote_status).strip() + if status_content: + break + + time.sleep(poll_interval) + + # Read results + stdout = self._read_remote_file(self._remote_stdout) + stderr = self._read_remote_file(self._remote_stderr) + exit_code_str = status_content + new_cwd = self._read_remote_file(self._remote_cwd).strip() + + # Parse exit code + try: + exit_code = int(exit_code_str) + except ValueError: + exit_code = 1 + + # Update cwd + if new_cwd: + self.cwd = new_cwd + + output = self._merge_output(stdout, stderr) + return {"output": output, "returncode": exit_code} + + @staticmethod + def _merge_output(stdout: str, stderr: str) -> str: + """Combine stdout and stderr into a single output string.""" + parts = [] + if stdout.strip(): + parts.append(stdout.rstrip("\n")) + if stderr.strip(): + parts.append(stderr.rstrip("\n")) + return "\n".join(parts) + + # ------------------------------------------------------------------ + # One-shot execution (original behavior) + # ------------------------------------------------------------------ + + def _execute_oneshot(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + """Execute a command via a fresh one-shot SSH invocation.""" work_dir = cwd or self.cwd exec_command, sudo_stdin = self._prepare_command(command) wrapped = f'cd {work_dir} && {exec_command}' @@ -141,7 +397,52 @@ class SSHEnvironment(BaseEnvironment): except Exception as e: return {"output": f"SSH execution error: {str(e)}", "returncode": 1} + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + if self.persistent: + return self._execute_persistent( + command, cwd, timeout=timeout, stdin_data=stdin_data + ) + return self._execute_oneshot( + command, cwd, timeout=timeout, stdin_data=stdin_data + ) + def cleanup(self): + # Persistent shell teardown + if self.persistent and self._shell_proc is not None: + # Remove remote temp files + if self._session_id: + try: + cmd = self._build_ssh_command() + cmd.append( + f"rm -f /tmp/hermes-ssh-{self._session_id}-*" + ) + subprocess.run(cmd, capture_output=True, timeout=5) + except (OSError, subprocess.SubprocessError): + pass + + # Close the shell + try: + self._shell_proc.stdin.close() + except Exception: + pass + try: + self._shell_proc.terminate() + self._shell_proc.wait(timeout=3) + except Exception: + try: + self._shell_proc.kill() + except Exception: + pass + self._shell_alive = False + self._shell_proc = None + + # ControlMaster cleanup if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index bf1d2b6b3..c7f72040a 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -503,6 +503,7 @@ def _get_env_config() -> Dict[str, Any]: "ssh_user": os.getenv("TERMINAL_SSH_USER", ""), "ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"), "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), + "ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"), # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB) @@ -594,6 +595,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, key_path=ssh_config.get("key", ""), cwd=cwd, timeout=timeout, + persistent=ssh_config.get("persistent", False), ) else: @@ -923,6 +925,7 @@ def terminal_tool( "user": config.get("ssh_user", ""), "port": config.get("ssh_port", 22), "key": config.get("ssh_key", ""), + "persistent": config.get("ssh_persistent", False), } container_config = None