#!/usr/bin/env python3
"""
SSH Trust Enforcement Utility for timmy-config.

Validates SSH connections before executing remote commands.
Enforces host key verification, key permission checks, and connection timeouts.

Exit codes:
    0 = success
    1 = connection failed
    2 = host key mismatch
    3 = timeout

Usage:
    python3 scripts/ssh_trust.py [options] user@host [command]
    python3 scripts/ssh_trust.py --audit [PATH]          # scan for StrictHostKeyChecking=no
    python3 scripts/ssh_trust.py --check-host user@host
    python3 scripts/ssh_trust.py --dry-run user@host "uptime"
"""

import argparse
import datetime
import json
import logging
import os
import re
import shlex
import stat
import subprocess
import sys
from pathlib import Path
from typing import Optional

# --- Exit codes ---
EXIT_SUCCESS = 0
EXIT_CONN_FAIL = 1
EXIT_HOST_KEY_MISMATCH = 2
EXIT_TIMEOUT = 3

# --- Defaults ---
DEFAULT_SSH_KEY = os.path.expanduser("~/.ssh/id_rsa")
DEFAULT_TIMEOUT = 30
DEFAULT_CONNECT_TIMEOUT = 10
LOG_DIR = os.path.expanduser("~/.ssh_trust_logs")

# Logs record remote commands and their output, so log directories created by
# this script are private to the current user (umask may only tighten this).
_LOG_DIR_MODE = 0o700

logger = logging.getLogger("ssh_trust")


def setup_logging(log_dir: str = LOG_DIR) -> str:
    """Configure root logging to a timestamped file in *log_dir* and to stderr.

    Returns:
        Path of the log file that the file handler writes to.
    """
    os.makedirs(log_dir, mode=_LOG_DIR_MODE, exist_ok=True)
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"ssh_trust_{ts}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stderr),
        ],
    )
    return log_file


def check_ssh_key(key_path: str = DEFAULT_SSH_KEY) -> tuple[bool, str]:
    """Check that the SSH private key exists and is not exposed to others.

    The key must be a regular file whose permissions are 0o600 or the
    stricter read-only 0o400: the owner must be able to read it and no other
    permission bit (group, other, execute, setuid, ...) may be set.

    Returns:
        (True, message) when the key passes, otherwise (False, reason).
    """
    path = Path(key_path)
    if not path.exists():
        return False, f"SSH key not found: {key_path}"
    if not path.is_file():
        return False, f"SSH key {key_path} is not a regular file"
    perms = stat.S_IMODE(path.stat().st_mode)
    # Reject any bit outside owner read/write, and require owner read.
    if perms & ~0o600 or not perms & stat.S_IRUSR:
        return False, (
            f"SSH key {key_path} has insecure permissions {oct(perms)} "
            "(expected 0o600 or 0o400)"
        )
    return True, f"SSH key {key_path} OK (permissions {oct(perms)})"


def verify_host_key_fingerprint(host: str, key_path: Optional[str] = None) -> tuple[int, str]:
    """
    Verify that a host key for *host* is already present in known_hosts.

    Only the hostname part of a ``user@host`` target is looked up, via
    ``ssh-keygen -F``. When no entry is found, ``ssh-keyscan`` is run only to
    distinguish an unreachable host from an unknown one; the scanned key is
    neither compared nor trusted nor written anywhere.

    Args:
        host: SSH target, ``host`` or ``user@host``.
        key_path: Unused; kept for backward compatibility.

    Returns:
        (EXIT_SUCCESS, message) if an entry for the host is in known_hosts.
        (EXIT_TIMEOUT, message) if ssh-keyscan timed out.
        (EXIT_HOST_KEY_MISMATCH, message) if the host key is unknown or
        could not be retrieved.
    """
    # Extract hostname/IP from user@host format
    target = host.split("@")[-1]

    # Look the host up in known_hosts.
    try:
        result = subprocess.run(
            ["ssh-keygen", "-F", target],
            capture_output=True, text=True, timeout=10,
        )
        if result.returncode == 0 and result.stdout.strip():
            return EXIT_SUCCESS, f"Host key for {target} found in known_hosts"
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # Best effort: the probe below still yields a meaningful failure.
        pass

    # Unknown host: probe it only to produce a more specific error message.
    try:
        scan = subprocess.run(
            ["ssh-keyscan", "-T", "5", target],
            capture_output=True, text=True, timeout=15,
        )
        if scan.returncode != 0 or not scan.stdout.strip():
            return EXIT_HOST_KEY_MISMATCH, f"Cannot retrieve host key for {target}"
    except subprocess.TimeoutExpired:
        return EXIT_TIMEOUT, f"ssh-keyscan timed out for {target}"
    except FileNotFoundError:
        return EXIT_HOST_KEY_MISMATCH, "ssh-keygen/ssh-keyscan not available"

    return EXIT_HOST_KEY_MISMATCH, (
        f"Host key for {target} NOT found in known_hosts. "
        f"To add it safely, run: ssh-keyscan {target} >> ~/.ssh/known_hosts"
    )


def test_connection(
    host: str,
    timeout: int = DEFAULT_CONNECT_TIMEOUT,
    key_path: Optional[str] = None,
) -> tuple[int, str]:
    """
    Test an SSH connection by running ``echo ssh_trust_ok`` on *host*.

    Uses ``BatchMode=yes`` and ``NumberOfPasswordPrompts=0`` so auth or host
    key problems fail fast instead of prompting.

    Args:
        host: SSH target (``user@host``).
        timeout: ssh ConnectTimeout in seconds; the whole subprocess is
            allowed ``timeout + 5`` seconds.
        key_path: Optional identity file passed as ``-i`` so the test uses
            the same key as the real execution.

    Returns:
        (exit_code, message) using the module's EXIT_* codes.
    """
    cmd = [
        "ssh",
        "-o", "StrictHostKeyChecking=yes",
        "-o", "BatchMode=yes",
        "-o", f"ConnectTimeout={timeout}",
        "-o", "NumberOfPasswordPrompts=0",
    ]
    if key_path:
        cmd += ["-i", key_path]
    cmd += [host, "echo ssh_trust_ok"]

    logger.info("Testing connection: ssh %s 'echo ssh_trust_ok'", host)
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 5)
    except subprocess.TimeoutExpired:
        return EXIT_TIMEOUT, f"Connection test timed out for {host}"
    except FileNotFoundError:
        return EXIT_CONN_FAIL, "ssh client not available"

    if result.returncode == 0 and "ssh_trust_ok" in result.stdout:
        return EXIT_SUCCESS, f"Connection to {host} OK"

    stderr = result.stderr.strip()
    if "Host key verification failed" in stderr:
        return EXIT_HOST_KEY_MISMATCH, f"Host key verification failed for {host}"
    if "timed out" in stderr.lower():
        return EXIT_TIMEOUT, f"Connection timed out for {host}"
    return EXIT_CONN_FAIL, f"Connection to {host} failed: {stderr}"


# The name starts with "test_", so pytest would otherwise collect this helper
# as a test (and fail on its missing "host" fixture) in any test module that
# imports it.
test_connection.__test__ = False


def execute_remote(
    host: str,
    command: str,
    timeout: int = DEFAULT_TIMEOUT,
    dry_run: bool = False,
    key_path: str = DEFAULT_SSH_KEY,
    connect_timeout: int = DEFAULT_CONNECT_TIMEOUT,
    log_dir: str = LOG_DIR,
) -> tuple[int, str]:
    """
    Safely execute a remote command via SSH.

    Pre-flight checks:
        1. SSH key exists with 0o600 (or 0o400) permissions
        2. Host key verified in known_hosts
        3. Connection test passes using the same key

    Then executes with StrictHostKeyChecking=yes (never 'no') and
    BatchMode=yes, and appends a JSON record to the daily executions file
    in *log_dir*.

    Returns:
        (exit_code, message): on success the message is the command's
        stripped stdout; otherwise it describes the failure.
    """
    ts = datetime.datetime.now().isoformat()
    logger.info("=" * 60)
    logger.info("SSH Trust Enforcement - %s", ts)
    logger.info("Host: %s | Command: %s", host, command)
    logger.info("=" * 60)

    # Step 1: Check SSH key
    key_ok, key_msg = check_ssh_key(key_path)
    logger.info("[KEY CHECK] %s", key_msg)
    if not key_ok:
        return EXIT_CONN_FAIL, key_msg

    # Step 2: Verify host key fingerprint
    hk_code, hk_msg = verify_host_key_fingerprint(host)
    logger.info("[HOST KEY] %s", hk_msg)
    if hk_code != EXIT_SUCCESS:
        return hk_code, hk_msg

    # Step 3: Test connection with the identity that step 5 will use
    conn_code, conn_msg = test_connection(host, connect_timeout, key_path=key_path)
    logger.info("[CONN TEST] %s", conn_msg)
    if conn_code != EXIT_SUCCESS:
        return conn_code, conn_msg

    # Step 4: Dry run
    if dry_run:
        dry_msg = f"[DRY RUN] Would execute: ssh {host} {command}"
        logger.info(dry_msg)
        return EXIT_SUCCESS, dry_msg

    # Step 5: Execute (BatchMode so a prompt can never block until timeout)
    ssh_cmd = [
        "ssh",
        "-o", "StrictHostKeyChecking=yes",
        "-o", "BatchMode=yes",
        "-o", f"ConnectTimeout={connect_timeout}",
        "-i", key_path,
        host,
        command,
    ]
    logger.info("[EXEC] %s", shlex.join(ssh_cmd))

    try:
        result = subprocess.run(
            ssh_cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.error("[TIMEOUT] Command timed out after %ds", timeout)
        _log_execution_record(ts, host, command, -1, "", "TIMEOUT", log_dir=log_dir)
        return EXIT_TIMEOUT, f"Remote command timed out after {timeout}s"
    except FileNotFoundError:
        return EXIT_CONN_FAIL, "ssh client not available"

    logger.info("[RESULT] exit_code=%d", result.returncode)
    if result.stdout:
        logger.info("[STDOUT] %s", result.stdout.strip())
    if result.stderr:
        logger.warning("[STDERR] %s", result.stderr.strip())

    # Log the full record
    _log_execution_record(
        ts, host, command, result.returncode, result.stdout, result.stderr,
        log_dir=log_dir,
    )

    if result.returncode == 0:
        return EXIT_SUCCESS, result.stdout.strip()
    return EXIT_CONN_FAIL, f"Remote command failed (exit {result.returncode}): {result.stderr.strip()}"


def _log_execution_record(
    timestamp: str,
    host: str,
    command: str,
    exit_code: int,
    stdout: str,
    stderr: str,
    log_dir: str = LOG_DIR,
) -> None:
    """Append a JSON-lines execution record to today's file in *log_dir*.

    stdout and stderr are truncated to their first 1000 characters.
    """
    os.makedirs(log_dir, mode=_LOG_DIR_MODE, exist_ok=True)
    log_file = os.path.join(log_dir, f"executions_{datetime.date.today().isoformat()}.jsonl")
    record = {
        "timestamp": timestamp,
        "host": host,
        "command": command,
        "exit_code": exit_code,
        "stdout": stdout[:1000],
        "stderr": stderr[:1000],
    }
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


# -------------------------------------------------------------------
# Audit: scan for StrictHostKeyChecking=no
# -------------------------------------------------------------------

# Matches both the "Key=value" form (e.g. "-o StrictHostKeyChecking=no") and
# the ssh_config "Key value" form; OpenSSH accepts "off" as a synonym for "no".
AUDIT_PATTERN = re.compile(
    r"StrictHostKeyChecking(?:\s*=\s*|\s+)(?:no|off)\b", re.IGNORECASE
)
SKIP_DIRS = {".git", "__pycache__", "node_modules", ".venv", "venv"}


def audit_repo(repo_path: str) -> list[dict]:
    """Scan *repo_path* recursively for lines that disable host key checking.

    Directories named in SKIP_DIRS are pruned, unreadable files are skipped,
    and this script's own source is excluded because its documentation and
    report strings mention the pattern.

    Returns:
        One dict per matching line with keys "file" (path relative to
        repo_path), "line" (1-based line number) and "content" (stripped line).
    """
    self_path = os.path.realpath(__file__)
    findings = []
    for root, dirs, files in os.walk(repo_path):
        # Prune in place so os.walk does not descend into skipped directories.
        dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
        for fname in files:
            fpath = os.path.join(root, fname)
            if os.path.realpath(fpath) == self_path:
                continue
            try:
                with open(fpath, "r", encoding="utf-8", errors="ignore") as f:
                    for line_num, line in enumerate(f, 1):
                        if AUDIT_PATTERN.search(line):
                            findings.append({
                                "file": os.path.relpath(fpath, repo_path),
                                "line": line_num,
                                "content": line.strip(),
                            })
            except OSError:
                continue
    return findings


def format_audit_report(findings: list[dict]) -> str:
    """Format audit findings (as returned by audit_repo) as a readable report."""
    if not findings:
        return "No instances of StrictHostKeyChecking=no found."

    lines = [
        f"SECURITY AUDIT: Found {len(findings)} instance(s) of StrictHostKeyChecking=no",
        "=" * 70,
    ]
    for item in findings:
        lines.append(f"  {item['file']}:{item['line']}")
        lines.append(f"    {item['content']}")
        lines.append("")
    lines.append("=" * 70)
    lines.append(
        "Recommendation: Replace with StrictHostKeyChecking=yes or 'accept-new', "
        "and use ssh_trust.py for safe remote execution."
    )
    return "\n".join(lines)


# -------------------------------------------------------------------
# CLI
# -------------------------------------------------------------------

def main():
    """Parse CLI arguments, dispatch to audit/check-host/execute, and exit.

    The process exit status is one of the EXIT_* codes (audit mode exits 0,
    or 1 when PATH is not a directory).
    """
    parser = argparse.ArgumentParser(
        description="SSH Trust Enforcement Utility",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Exit codes:
  0  Success
  1  Connection failed
  2  Host key mismatch
  3  Timeout

Examples:
  %(prog)s --dry-run user@host "uptime"
  %(prog)s user@host "df -h"
  %(prog)s --check-host user@host
  %(prog)s --audit
  %(prog)s --audit /path/to/repo
""",
    )
    parser.add_argument("host", nargs="?", help="SSH target (user@host)")
    parser.add_argument("command", nargs="?", default="echo ok", help="Remote command to execute")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be executed without running")
    parser.add_argument("--check-host", metavar="HOST", help="Only verify host key (no command execution)")
    parser.add_argument("--audit", nargs="?", const=".", metavar="PATH", help="Audit repo for StrictHostKeyChecking=no")
    parser.add_argument("--key", default=DEFAULT_SSH_KEY, help="Path to SSH private key")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="Command timeout in seconds")
    parser.add_argument("--connect-timeout", type=int, default=DEFAULT_CONNECT_TIMEOUT, help="Connection timeout in seconds")
    parser.add_argument("--log-dir", default=LOG_DIR, help="Log directory")
    parser.add_argument("--json", action="store_true", help="Output results as JSON")

    args = parser.parse_args()
    log_file = setup_logging(args.log_dir)

    # Audit mode
    if args.audit is not None:
        repo_path = os.path.abspath(args.audit)
        if not os.path.isdir(repo_path):
            print(f"Error: {repo_path} is not a directory", file=sys.stderr)
            sys.exit(1)
        findings = audit_repo(repo_path)
        if args.json:
            print(json.dumps({"findings": findings, "count": len(findings)}, indent=2))
        else:
            print(format_audit_report(findings))
        sys.exit(0)

    # Check-host mode
    if args.check_host:
        code, msg = verify_host_key_fingerprint(args.check_host)
        if args.json:
            print(json.dumps({"host": args.check_host, "exit_code": code, "message": msg}))
        else:
            print(msg)
        sys.exit(code)

    # Require host for execution modes
    if not args.host:
        parser.error("host argument is required (unless using --audit or --check-host)")

    code, msg = execute_remote(
        host=args.host,
        command=args.command,
        timeout=args.timeout,
        dry_run=args.dry_run,
        key_path=args.key,
        connect_timeout=args.connect_timeout,
        log_dir=args.log_dir,
    )

    if args.json:
        print(json.dumps({
            "host": args.host,
            "command": args.command,
            "exit_code": code,
            "message": msg,
            "log_file": log_file,
        }))
    else:
        print(msg)

    sys.exit(code)


if __name__ == "__main__":
    main()