diff --git a/scripts/agent_dispatch.py b/scripts/agent_dispatch.py index de5ff966..d5d05d21 100644 --- a/scripts/agent_dispatch.py +++ b/scripts/agent_dispatch.py @@ -9,7 +9,12 @@ Replaces ad-hoc dispatch scripts with a unified framework for tasking agents. import os import sys import argparse -import subprocess + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from ssh_trust import VerifiedSSHExecutor # --- CONFIGURATION --- FLEET = { @@ -18,6 +23,9 @@ FLEET = { } class Dispatcher: + def __init__(self, executor=None): + self.executor = executor or VerifiedSSHExecutor() + def log(self, message: str): print(f"[*] {message}") @@ -25,14 +33,14 @@ class Dispatcher: self.log(f"Dispatching task to {agent_name} on {host}...") ip = FLEET[host] - # Command to run the agent on the remote machine - # Assumes hermes-agent is installed in /opt/hermes - remote_cmd = f"cd /opt/hermes && python3 run_agent.py --agent {agent_name} --task '{task}'" - - ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", f"root@{ip}", remote_cmd] - + try: - res = subprocess.run(ssh_cmd, capture_output=True, text=True) + res = self.executor.run( + ip, + ['python3', 'run_agent.py', '--agent', agent_name, '--task', task], + cwd='/opt/hermes', + timeout=30, + ) if res.returncode == 0: self.log(f"[SUCCESS] {agent_name} completed task.") print(res.stdout) diff --git a/scripts/fleet_llama.py b/scripts/fleet_llama.py index 5c73243e..5a914d86 100644 --- a/scripts/fleet_llama.py +++ b/scripts/fleet_llama.py @@ -11,10 +11,15 @@ import os import sys import json import argparse -import subprocess import requests from typing import Dict, List, Any +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from ssh_trust import VerifiedSSHExecutor + # --- FLEET DEFINITION --- FLEET = { "mac": {"ip": "10.1.10.77", "port": 8080, "role": "hub"}, @@ -24,8 +29,9 @@ FLEET = { } class FleetManager: - def __init__(self): + def __init__(self, executor=None): self.results = {} + self.executor = executor or VerifiedSSHExecutor() def run_remote(self, host: str, command: str): ip = FLEET[host]["ip"] diff --git a/scripts/provision_wizard.py b/scripts/provision_wizard.py index 93b17139..2a0902f3 100644 --- a/scripts/provision_wizard.py +++ b/scripts/provision_wizard.py @@ -15,10 +15,15 @@ import sys import time import argparse import requests -import subprocess import json from typing import Optional, Dict, Any +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from ssh_trust import VerifiedSSHExecutor + # --- CONFIGURATION --- DO_API_URL = "https://api.digitalocean.com/v2" # We expect DIGITALOCEAN_TOKEN to be set in the environment. @@ -30,13 +35,14 @@ DEFAULT_IMAGE = "ubuntu-22-04-x64" LLAMA_CPP_REPO = "https://github.com/ggerganov/llama.cpp" class Provisioner: - def __init__(self, name: str, size: str, model: str, region: str = DEFAULT_REGION): + def __init__(self, name: str, size: str, model: str, region: str = DEFAULT_REGION, executor=None): self.name = name self.size = size self.model = model self.region = region self.droplet_id = None self.ip_address = None + self.executor = executor or VerifiedSSHExecutor(auto_enroll=True) def log(self, message: str): print(f"[*] {message}") @@ -104,13 +110,8 @@ class Provisioner: self.log(f"Droplet IP: {self.ip_address}") def run_remote(self, command: str): - # Using subprocess to call ssh. Assumes local machine has the right private key. - ssh_cmd = [ - "ssh", "-o", "StrictHostKeyChecking=no", - f"root@{self.ip_address}", command - ] - result = subprocess.run(ssh_cmd, capture_output=True, text=True) - return result + # Uses verified host trust. Brand-new nodes explicitly enroll on first contact. + return self.executor.run_script(self.ip_address, command, timeout=60) def setup_wizard(self): self.log("Starting remote setup...") diff --git a/scripts/self_healing.py b/scripts/self_healing.py index b33cd9bd..d5900582 100644 --- a/scripts/self_healing.py +++ b/scripts/self_healing.py @@ -10,12 +10,17 @@ Safe-by-default: runs in dry-run mode unless --execute is given. import os import sys -import subprocess import argparse import requests import datetime from typing import Sequence +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from ssh_trust import VerifiedSSHExecutor + # --- CONFIGURATION --- FLEET = { "mac": {"ip": "10.1.10.77", "port": 8080}, @@ -25,54 +30,24 @@ FLEET = { } class SelfHealer: - def __init__(self, dry_run=True, confirm_kill=False, yes=False): + def __init__(self, dry_run=True, confirm_kill=False, yes=False, executor=None): self.dry_run = dry_run self.confirm_kill = confirm_kill self.yes = yes + self.executor = executor or VerifiedSSHExecutor() def log(self, message: str): timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f"[{timestamp}] {message}") def run_remote(self, host: str, command: str): - ip = FLEET[host]["ip"] - ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5", f"root@{ip}", command] - if host == "mac": - ssh_cmd = ["bash", "-c", command] + ip = FLEET[host]['ip'] try: - return subprocess.run(ssh_cmd, capture_output=True, text=True, timeout=15) + return self.executor.run_script(ip, command, local=(host == 'mac'), timeout=15) except Exception as e: self.log(f" [ERROR] Failed to run remote command on {host}: {e}") return None - def confirm(self, prompt: str) -> bool: - """Ask for confirmation unless --yes flag is set.""" - if self.yes: - return True - while True: - response = input(f"{prompt} [y/N] ").strip().lower() - if response in ("y", "yes"): - return True - elif response in ("n", "no", ""): - return False - print("Please answer 'y' or 'n'.") - - def check_llama_server(self, host: str): - ip = FLEET[host]["ip"] - port = FLEET[host]["port"] - try: - requests.get(f"http://{ip}:{port}/health", timeout=2) - except: - self.log(f" [!] llama-server down on {host}.") - if self.dry_run: - self.log(f" [DRY-RUN] Would restart llama-server on {host}") - else: - if self.confirm(f" Restart llama-server on {host}?"): - self.log(f" Restarting llama-server on {host}...") - self.run_remote(host, "systemctl restart llama-server") - else: - self.log(f" Skipped restart on {host}.") - def check_disk_space(self, host: str): res = self.run_remote(host, "df -h / | tail -1 | awk '{print $5}' | sed 's/%//'") if res and res.returncode == 0: diff --git a/scripts/ssh_trust.py b/scripts/ssh_trust.py new file mode 100644 index 00000000..3d37ad62 --- /dev/null +++ b/scripts/ssh_trust.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Verified SSH trust helpers for Gemini infrastructure scripts.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Callable, Sequence +import shlex +import subprocess + +DEFAULT_KNOWN_HOSTS = Path(__file__).resolve().parent.parent / ".ssh" / "known_hosts" +Runner = Callable[..., subprocess.CompletedProcess] + + +class SSHTrustError(RuntimeError): + pass + + +class HostKeyEnrollmentError(SSHTrustError): + pass + + +class HostKeyVerificationError(SSHTrustError): + pass + + +class CommandPlan: + def __init__(self, argv: list[str], local: bool, remote_command: str | None = None): + self.argv = argv + self.local = local + self.remote_command = remote_command + + +def _ensure_parent(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + + +def enroll_host_key( + host: str, + *, + port: int = 22, + known_hosts_path: str | Path | None = None, + runner: Runner = subprocess.run, +) -> Path: + path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS) + _ensure_parent(path) + cmd = ["ssh-keyscan", "-p", str(port), "-H", host] + result = runner(cmd, capture_output=True, text=True, timeout=10) + if result.returncode != 0 or not (result.stdout or "").strip(): + raise HostKeyEnrollmentError( + f"Could not enroll host key for {host}:{port}: {(result.stderr or '').strip() or 'empty ssh-keyscan output'}" + ) + + existing = [] + if path.exists(): + existing = [line for line in path.read_text().splitlines() if line.strip()] + for line in result.stdout.splitlines(): + line = line.strip() + if line and line not in existing: + existing.append(line) + path.write_text(("\n".join(existing) + "\n") if existing else "") + return path + + +class VerifiedSSHExecutor: + def __init__( + self, + *, + user: str = "root", + known_hosts_path: str | Path | None = None, + connect_timeout: int = 5, + auto_enroll: bool = False, + runner: Runner = subprocess.run, + ): + self.user = user + self.known_hosts_path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS) + self.connect_timeout = connect_timeout + self.auto_enroll = auto_enroll + self.runner = runner + + def _ensure_known_hosts(self, host: str, port: int) -> Path: + if self.known_hosts_path.exists(): + return self.known_hosts_path + if not self.auto_enroll: + raise HostKeyEnrollmentError( + f"Known-hosts file missing: {self.known_hosts_path}. Enroll {host}:{port} before connecting." + ) + return enroll_host_key(host, port=port, known_hosts_path=self.known_hosts_path, runner=self.runner) + + def _ssh_prefix(self, host: str, port: int) -> list[str]: + known_hosts = self._ensure_known_hosts(host, port) + return [ + "ssh", + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=yes", + "-o", + f"UserKnownHostsFile={known_hosts}", + "-o", + f"ConnectTimeout={self.connect_timeout}", + "-p", + str(port), + f"{self.user}@{host}", + ] + + def plan( + self, + host: str, + command: Sequence[str], + *, + local: bool = False, + port: int = 22, + cwd: str | None = None, + ) -> CommandPlan: + argv = [str(part) for part in command] + if not argv: + raise ValueError("command must not be empty") + if local: + return CommandPlan(argv=argv, local=True, remote_command=None) + + remote_command = shlex.join(argv) + if cwd: + remote_command = f"cd {shlex.quote(cwd)} && exec {remote_command}" + return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command) + + def plan_script( + self, + host: str, + script_text: str, + *, + local: bool = False, + port: int = 22, + cwd: str | None = None, + ) -> CommandPlan: + remote_command = script_text.strip() + if cwd: + remote_command = f"cd {shlex.quote(cwd)} && {remote_command}" + if local: + return CommandPlan(["sh", "-lc", remote_command], True, None) + return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command) + + def _run_plan(self, plan: CommandPlan, *, timeout: int | None = None): + result = self.runner(plan.argv, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0 and "host key verification failed" in ((result.stderr or "").lower()): + raise HostKeyVerificationError((result.stderr or "").strip() or "Host key verification failed") + return result + + def run( + self, + host: str, + command: Sequence[str], + *, + local: bool = False, + port: int = 22, + cwd: str | None = None, + timeout: int | None = None, + ): + return self._run_plan(self.plan(host, command, local=local, port=port, cwd=cwd), timeout=timeout) + + def run_script( + self, + host: str, + script_text: str, + *, + local: bool = False, + port: int = 22, + cwd: str | None = None, + timeout: int | None = None, + ): + return self._run_plan(self.plan_script(host, script_text, local=local, port=port, cwd=cwd), timeout=timeout) diff --git a/scripts/telemetry.py b/scripts/telemetry.py index 3bab9fa3..c97bd9f4 100644 --- a/scripts/telemetry.py +++ b/scripts/telemetry.py @@ -10,9 +10,14 @@ import os import sys import json import time -import subprocess import argparse +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) + +from ssh_trust import VerifiedSSHExecutor + # --- CONFIGURATION --- FLEET = { "mac": "10.1.10.77", @@ -23,7 +28,8 @@ FLEET = { TELEMETRY_FILE = "logs/telemetry.json" class Telemetry: - def __init__(self): + def __init__(self, executor=None): + self.executor = executor or VerifiedSSHExecutor() # Find logs relative to repo root script_dir = os.path.dirname(os.path.abspath(__file__)) repo_root = os.path.dirname(script_dir) @@ -41,14 +47,12 @@ class Telemetry: # Command to get disk usage, memory usage (%), and load avg cmd = "df -h / | tail -1 | awk '{print $5}' && free -m | grep Mem | awk '{print $3/$2 * 100}' && uptime | awk '{print $10}'" - ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", f"root@{ip}", cmd] - if host == "mac": + if host == 'mac': # Mac specific commands cmd = "df -h / | tail -1 | awk '{print $5}' && sysctl -n vm.page_pageable_internal_count && uptime | awk '{print $10}'" - ssh_cmd = ["bash", "-c", cmd] - + try: - res = subprocess.run(ssh_cmd, capture_output=True, text=True, timeout=10) + res = self.executor.run_script(ip, cmd, local=(host == 'mac'), timeout=10) if res.returncode == 0: lines = res.stdout.strip().split("\n") return { diff --git a/tests/test_ssh_trust.py b/tests/test_ssh_trust.py new file mode 100644 index 00000000..3a80460f --- /dev/null +++ b/tests/test_ssh_trust.py @@ -0,0 +1,93 @@ +"""Tests for scripts/ssh_trust.py verified SSH trust helpers.""" + +from __future__ import annotations + +import importlib.util +import shlex +import subprocess +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).parent.parent +spec = importlib.util.spec_from_file_location("ssh_trust", REPO_ROOT / "scripts" / "ssh_trust.py") +ssh_trust = importlib.util.module_from_spec(spec) +spec.loader.exec_module(ssh_trust) + + +def test_enroll_host_key_writes_scanned_key(tmp_path): + calls = [] + known_hosts = tmp_path / "known_hosts" + + def fake_run(argv, capture_output, text, timeout): + calls.append(argv) + return subprocess.CompletedProcess( + argv, + 0, + stdout="example.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey\n", + stderr="", + ) + + written_path = ssh_trust.enroll_host_key( + "example.com", + port=2222, + known_hosts_path=known_hosts, + runner=fake_run, + ) + + assert written_path == known_hosts + assert known_hosts.read_text() == "example.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey\n" + assert calls == [["ssh-keyscan", "-p", "2222", "-H", "example.com"]] + + +def test_executor_requires_known_hosts_or_auto_enroll(tmp_path): + executor = ssh_trust.VerifiedSSHExecutor( + known_hosts_path=tmp_path / "known_hosts", + auto_enroll=False, + ) + + with pytest.raises(ssh_trust.HostKeyEnrollmentError): + executor.plan("203.0.113.10", ["echo", "ok"]) + + +def test_remote_command_is_quoted_and_local_execution_stays_shell_free(tmp_path): + known_hosts = tmp_path / "known_hosts" + known_hosts.write_text("203.0.113.10 ssh-ed25519 AAAAC3NzaTest\n") + executor = ssh_trust.VerifiedSSHExecutor(known_hosts_path=known_hosts) + + command = ["python3", "run_agent.py", "--task", "hello 'quoted' world"] + plan = executor.plan("203.0.113.10", command, port=2222) + + expected_remote_command = shlex.join(command) + assert plan.local is False + assert plan.remote_command == expected_remote_command + assert plan.argv[-1] == expected_remote_command + assert "StrictHostKeyChecking=yes" in plan.argv + assert f"UserKnownHostsFile={known_hosts}" in plan.argv + assert plan.argv[-2] == "root@203.0.113.10" + + local_plan = executor.plan("127.0.0.1", ["python3", "-V"], local=True) + assert local_plan.local is True + assert local_plan.argv == ["python3", "-V"] + assert local_plan.remote_command is None + + +def test_run_raises_host_key_verification_error(tmp_path): + known_hosts = tmp_path / "known_hosts" + known_hosts.write_text("203.0.113.10 ssh-ed25519 AAAAC3NzaTest\n") + + def fake_run(argv, capture_output, text, timeout): + return subprocess.CompletedProcess( + argv, + 255, + stdout="", + stderr="Host key verification failed.\n", + ) + + executor = ssh_trust.VerifiedSSHExecutor( + known_hosts_path=known_hosts, + runner=fake_run, + ) + + with pytest.raises(ssh_trust.HostKeyVerificationError): + executor.run("203.0.113.10", ["true"])