#!/usr/bin/env python3
"""Verified SSH trust helpers for Gemini infrastructure scripts."""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Sequence
import shlex
import subprocess

# Default known-hosts file: ``<parent of this file's directory>/.ssh/known_hosts``.
DEFAULT_KNOWN_HOSTS = Path(__file__).resolve().parent.parent / ".ssh" / "known_hosts"

# Anything call-compatible with ``subprocess.run`` (injectable for testing).
Runner = Callable[..., subprocess.CompletedProcess]


class SSHTrustError(RuntimeError):
    """Base class for every error raised by this module."""


class HostKeyEnrollmentError(SSHTrustError):
    """Raised when a host key cannot be enrolled or the known-hosts file is missing."""


class HostKeyVerificationError(SSHTrustError):
    """Raised when a command fails and stderr reports a host key verification failure."""


class CommandPlan:
    """A resolved command line, ready to be handed to a runner."""

    def __init__(
        self,
        argv: list[str],
        local: bool,
        remote_command: str | None = None,
        cwd: str | None = None,
    ):
        # Exact argument vector passed to the runner.
        self.argv = argv
        # True when argv runs on this machine, False when it is an ssh invocation.
        self.local = local
        # Command string sent to the remote shell; None for local plans.
        self.remote_command = remote_command
        # Working directory the runner must use for a local argv plan.
        # None when not applicable: remote plans and script plans embed
        # their ``cd`` in the command text instead.
        self.cwd = cwd

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(argv={self.argv!r}, local={self.local!r}, "
            f"remote_command={self.remote_command!r}, cwd={self.cwd!r})"
        )


def _ensure_parent(path: Path) -> None:
    """Create the parent directory of *path* (and any ancestors) if missing."""
    path.parent.mkdir(parents=True, exist_ok=True)


def enroll_host_key(
    host: str,
    *,
    port: int = 22,
    known_hosts_path: str | Path | None = None,
    runner: Runner = subprocess.run,
) -> Path:
    """Scan *host* with ``ssh-keyscan`` and append its keys to the known-hosts file.

    Args:
        host: Host name or address to scan.
        port: SSH port to scan.
        known_hosts_path: File to update; falls back to ``DEFAULT_KNOWN_HOSTS``
            when falsy.
        runner: ``subprocess.run``-compatible callable used to run the scan.

    Returns:
        The path of the known-hosts file that was written.

    Raises:
        HostKeyEnrollmentError: If the scan exits non-zero or prints nothing
            on stdout.

    Exceptions raised by the runner itself (for example
    ``subprocess.TimeoutExpired`` from the 10-second timeout when the runner
    is ``subprocess.run``) propagate unchanged.
    """
    path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS)
    _ensure_parent(path)

    cmd = ["ssh-keyscan", "-p", str(port), "-H", host]
    result = runner(cmd, capture_output=True, text=True, timeout=10)
    if result.returncode != 0 or not (result.stdout or "").strip():
        detail = (result.stderr or "").strip() or "empty ssh-keyscan output"
        raise HostKeyEnrollmentError(
            f"Could not enroll host key for {host}:{port}: {detail}"
        )

    entries: list[str] = []
    if path.exists():
        entries = [line for line in path.read_text().splitlines() if line.strip()]

    # Set gives O(1) duplicate checks; the list preserves file order.
    # NOTE(review): ``-H`` output is presumably salted per scan, so exact-line
    # de-duplication may not recognise a host enrolled earlier - confirm.
    seen = set(entries)
    for raw in result.stdout.splitlines():
        line = raw.strip()
        if line and line not in seen:
            seen.add(line)
            entries.append(line)

    path.write_text(("\n".join(entries) + "\n") if entries else "")
    return path


class VerifiedSSHExecutor:
    """Build and run commands locally or over ssh with strict host-key checking."""

    def __init__(
        self,
        *,
        user: str = "root",
        known_hosts_path: str | Path | None = None,
        connect_timeout: int = 5,
        auto_enroll: bool = False,
        runner: Runner = subprocess.run,
    ):
        """Store connection settings.

        Args:
            user: Remote login name used in ``user@host``.
            known_hosts_path: Known-hosts file; falls back to
                ``DEFAULT_KNOWN_HOSTS`` when falsy.
            connect_timeout: Value passed to ssh as ``ConnectTimeout``.
            auto_enroll: When True, a missing known-hosts file is created by
                scanning the target host instead of raising.
            runner: ``subprocess.run``-compatible callable used for every
                command this executor runs.
        """
        self.user = user
        self.known_hosts_path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS)
        self.connect_timeout = connect_timeout
        self.auto_enroll = auto_enroll
        self.runner = runner

    def _ensure_known_hosts(self, host: str, port: int) -> Path:
        """Return the known-hosts path, enrolling *host* if the file is missing.

        Only the existence of the file is checked, not whether it contains an
        entry for *host*.

        Raises:
            HostKeyEnrollmentError: If the file is missing and ``auto_enroll``
                is False, or if enrollment itself fails.
        """
        if self.known_hosts_path.exists():
            return self.known_hosts_path
        if not self.auto_enroll:
            raise HostKeyEnrollmentError(
                f"Known-hosts file missing: {self.known_hosts_path}. "
                f"Enroll {host}:{port} before connecting."
            )
        return enroll_host_key(
            host,
            port=port,
            known_hosts_path=self.known_hosts_path,
            runner=self.runner,
        )

    def _ssh_prefix(self, host: str, port: int) -> list[str]:
        """Return the ssh argv up to and including ``user@host``."""
        known_hosts = self._ensure_known_hosts(host, port)
        return [
            "ssh",
            "-o", "BatchMode=yes",
            "-o", "StrictHostKeyChecking=yes",
            "-o", f"UserKnownHostsFile={known_hosts}",
            "-o", f"ConnectTimeout={self.connect_timeout}",
            "-p", str(port),
            f"{self.user}@{host}",
        ]

    def plan(
        self,
        host: str,
        command: Sequence[str],
        *,
        local: bool = False,
        port: int = 22,
        cwd: str | None = None,
    ) -> CommandPlan:
        """Plan an argv-style command for local or remote execution.

        Args:
            host: Target host (unused when *local* is True).
            command: Argument vector; each part is converted with ``str``.
            local: Run on this machine instead of over ssh.
            port: SSH port for remote execution.
            cwd: Working directory for the command. Remote plans prepend a
                quoted ``cd``; local plans record it on the returned plan so
                the runner receives it.

        Raises:
            ValueError: If *command* is empty.
            HostKeyEnrollmentError: For remote plans when the known-hosts
                file is unavailable.
        """
        argv = [str(part) for part in command]
        if not argv:
            raise ValueError("command must not be empty")

        if local:
            # argv is executed directly (no shell), so the directory cannot be
            # embedded in the command; carry it on the plan instead.
            return CommandPlan(argv=argv, local=True, remote_command=None, cwd=cwd or None)

        remote_command = shlex.join(argv)
        if cwd:
            remote_command = f"cd {shlex.quote(cwd)} && exec {remote_command}"
        return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command)

    def plan_script(
        self,
        host: str,
        script_text: str,
        *,
        local: bool = False,
        port: int = 22,
        cwd: str | None = None,
    ) -> CommandPlan:
        """Plan a shell script for local (``sh -lc``) or remote execution.

        Args:
            host: Target host (unused when *local* is True).
            script_text: Shell source; surrounding whitespace is stripped.
            local: Run on this machine instead of over ssh.
            port: SSH port for remote execution.
            cwd: Directory to ``cd`` into before the script runs.

        Raises:
            ValueError: If *script_text* is empty or whitespace only.
            HostKeyEnrollmentError: For remote plans when the known-hosts
                file is unavailable.
        """
        remote_command = script_text.strip()
        if not remote_command:
            # Mirrors plan(): an empty script would otherwise yield an ssh
            # call with an empty command, or the invalid ``cd dir && ``.
            raise ValueError("script_text must not be empty")
        if cwd:
            remote_command = f"cd {shlex.quote(cwd)} && {remote_command}"
        if local:
            return CommandPlan(["sh", "-lc", remote_command], True, None)
        return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command)

    def _run_plan(self, plan: CommandPlan, *, timeout: int | None = None):
        """Execute *plan* with the configured runner and return its result.

        Raises:
            HostKeyVerificationError: If the command exits non-zero and its
                stderr contains "host key verification failed"
                (case-insensitive).
        """
        run_kwargs = {"capture_output": True, "text": True, "timeout": timeout}
        if plan.cwd:
            # Only passed when set, so runner calls for other plans are unchanged.
            run_kwargs["cwd"] = plan.cwd
        result = self.runner(plan.argv, **run_kwargs)

        stderr = result.stderr or ""
        if result.returncode != 0 and "host key verification failed" in stderr.lower():
            raise HostKeyVerificationError(stderr.strip() or "Host key verification failed")
        return result

    def run(
        self,
        host: str,
        command: Sequence[str],
        *,
        local: bool = False,
        port: int = 22,
        cwd: str | None = None,
        timeout: int | None = None,
    ):
        """Plan and execute an argv-style command; see ``plan`` for arguments.

        *timeout* is forwarded to the runner. Returns the runner's result.
        """
        planned = self.plan(host, command, local=local, port=port, cwd=cwd)
        return self._run_plan(planned, timeout=timeout)

    def run_script(
        self,
        host: str,
        script_text: str,
        *,
        local: bool = False,
        port: int = 22,
        cwd: str | None = None,
        timeout: int | None = None,
    ):
        """Plan and execute a shell script; see ``plan_script`` for arguments.

        *timeout* is forwarded to the runner. Returns the runner's result.
        """
        planned = self.plan_script(host, script_text, local=local, port=port, cwd=cwd)
        return self._run_plan(planned, timeout=timeout)