#!/usr/bin/env python3
"""
config_drift_detector.py — Detect config drift across fleet nodes.

Collects config from all wizard nodes via SSH, compares against canonical
timmy-config golden state, and reports differences.

Usage:
    python3 scripts/config_drift_detector.py              # Report only
    python3 scripts/config_drift_detector.py --auto-sync  # Auto-fix drift with golden state
    python3 scripts/config_drift_detector.py --node allegro  # Check single node
    python3 scripts/config_drift_detector.py --json       # JSON output for automation

Exit codes:
    0 — no drift detected
    1 — drift detected
    2 — error (SSH failure, missing deps, etc.)
"""

import argparse
import json
import os
import subprocess
import sys
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

# ── Constants ─────────────────────────────────────────────────────────────────

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent
ANSIBLE_INVENTORY = REPO_ROOT / "ansible" / "inventory" / "hosts.yml"
GOLDEN_STATE_PLAYBOOK = REPO_ROOT / "ansible" / "playbooks" / "golden_state.yml"

# SSH flags shared by every remote invocation (connect fast, don't prompt
# interactively for unknown host keys).
_SSH_OPTS = ["-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]

# Config files to check on each node ({name} is replaced with the node name)
CONFIG_PATHS = [
    ".hermes/config.yaml",
    "wizards/{name}/config.yaml",
]

# Keys that define golden state (from ansible inventory vars)
GOLDEN_KEYS = [
    "providers",
    "provider",
    "model",
    "base_url",
    "api_key_env",
    "banned_providers",
    "banned_models_patterns",
]

# ── Data Models ───────────────────────────────────────────────────────────────


@dataclass
class NodeConfig:
    """Configs collected from one node, plus collection status."""

    name: str
    host: str
    # Parsed YAML keyed by the resolved remote path; the special key
    # "__banned_scan__" holds the banned-provider grep result.
    configs: dict[str, Any] = field(default_factory=dict)
    errors: list[str] = field(default_factory=list)
    reachable: bool = True


@dataclass
class DriftResult:
    """One detected difference between a node and golden state."""

    node: str
    file_path: str
    diff_type: str  # "missing", "value_mismatch", "key_missing", "extra_key"
    key: str
    canonical_value: Any = None
    node_value: Any = None
    severity: str = "warning"  # "info", "warning", "critical"


# ── Inventory Parsing ─────────────────────────────────────────────────────────


def load_inventory() -> tuple[dict, dict]:
    """Load Ansible inventory and extract wizard node definitions.

    Returns:
        ``(nodes, global_vars)`` — per-node connection info keyed by node
        name, and the inventory-wide ``all.vars`` mapping.

    Exits with code 2 if the inventory file does not exist.
    """
    if not ANSIBLE_INVENTORY.exists():
        print(f"ERROR: Inventory not found at {ANSIBLE_INVENTORY}", file=sys.stderr)
        sys.exit(2)
    with open(ANSIBLE_INVENTORY) as f:
        inventory = yaml.safe_load(f)
    wizards = inventory.get("all", {}).get("children", {}).get("wizards", {}).get("hosts", {})
    global_vars = inventory.get("all", {}).get("vars", {})
    nodes = {}
    for name, config in wizards.items():
        nodes[name] = {
            "host": config.get("ansible_host", "localhost"),
            "user": config.get("ansible_user", ""),
            "wizard_name": config.get("wizard_name", name),
            "hermes_home": config.get("hermes_home", "~/.hermes"),
            "wizard_home": config.get("wizard_home", f"~/wizards/{name}"),
            "machine_type": config.get("machine_type", "unknown"),
        }
    return nodes, global_vars


def load_golden_state(inventory_vars: dict) -> dict:
    """Extract golden state (canonical providers + ban lists) from inventory vars."""
    golden = {
        "providers": inventory_vars.get("golden_state_providers", []),
        "banned_providers": inventory_vars.get("banned_providers", []),
        "banned_models_patterns": inventory_vars.get("banned_models_patterns", []),
    }
    return golden


# ── SSH Collection ────────────────────────────────────────────────────────────


def ssh_collect(node_name: str, node_info: dict, timeout: int = 15) -> NodeConfig:
    """SSH into a node and collect config files.

    Marks the result ``reachable=False`` (and returns early) when the SSH
    connection itself fails, so downstream drift analysis can skip the node.
    Missing config files are NOT errors — not every path exists on every node.
    """
    host = node_info["host"]
    user = node_info.get("user", "")
    hermes_home = node_info.get("hermes_home", "~/.hermes")
    wizard_home = node_info.get("wizard_home", f"~/wizards/{node_name}")
    result = NodeConfig(name=node_name, host=host)

    # Build SSH target (None means run commands locally)
    if host in ("localhost", "127.0.0.1"):
        ssh_target = None
    else:
        ssh_target = f"{user}@{host}" if user else host

    # Probe connectivity once up front so a dead host is reported as
    # unreachable instead of silently yielding zero configs.
    if ssh_target is not None:
        try:
            probe = subprocess.run(
                ["ssh", *_SSH_OPTS, ssh_target, "true"],
                capture_output=True, text=True, timeout=timeout,
            )
            if probe.returncode != 0:
                result.reachable = False
                result.errors.append(f"SSH connect failed: {probe.stderr.strip()[:200]}")
                return result
        except (subprocess.TimeoutExpired, FileNotFoundError) as e:
            result.reachable = False
            result.errors.append(f"SSH connect failed: {e}")
            return result

    # Collect each config path
    for path_template in CONFIG_PATHS:
        remote_path = path_template.replace("{name}", node_name)
        if not remote_path.startswith("/"):
            # Resolve relative to the node's home directories
            if "wizards/" in remote_path:
                full_path = f"{wizard_home}/config.yaml"
            else:
                full_path = f"{hermes_home}/config.yaml" if ".hermes" in remote_path else f"~/{remote_path}"
        else:
            full_path = remote_path

        config_content = _remote_cat(ssh_target, full_path, timeout)
        if config_content is not None:
            try:
                parsed = yaml.safe_load(config_content)
                if parsed:
                    result.configs[full_path] = parsed
            except yaml.YAMLError as e:
                result.errors.append(f"YAML parse error in {full_path}: {e}")
        # Don't flag missing files as errors — some paths may not exist on all nodes

    # Also collect banned provider scan
    banned_check = _remote_grep(
        ssh_target, hermes_home,
        r"anthropic|claude-sonnet|claude-opus|claude-haiku", timeout
    )
    if banned_check:
        result.configs["__banned_scan__"] = banned_check

    return result


def _remote_cat(ssh_target: str | None, path: str, timeout: int) -> str | None:
    """Cat a file remotely (or locally). Returns None on any failure.

    Locally, ``~`` must be expanded by us — subprocess without a shell does
    no tilde expansion. Remotely, the remote shell expands it.
    """
    if ssh_target is None:
        cmd = ["cat", os.path.expanduser(path)]
    else:
        cmd = ["ssh", *_SSH_OPTS, ssh_target, f"cat {path}"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        if proc.returncode == 0:
            return proc.stdout
    except subprocess.TimeoutExpired:
        pass
    except FileNotFoundError:
        pass
    return None


def _remote_grep(ssh_target: str | None, base_path: str, pattern: str, timeout: int) -> dict:
    """Grep for banned patterns in config files under base_path.

    Returns {"matches": [...], "count": N} when anything matched, else {}.
    """
    if ssh_target is None:
        # Expand ~ for the local run; --include filters the recursive scan.
        cmd = ["grep", "-rn", "-i", pattern, os.path.expanduser(base_path),
               "--include=*.yaml", "--include=*.yml"]
    else:
        # "|| true" keeps the remote exit status 0 when nothing matches.
        cmd = ["ssh", *_SSH_OPTS, ssh_target,
               f"grep -rn -i '{pattern}' {base_path} --include='*.yaml' --include='*.yml' 2>/dev/null || true"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        if proc.stdout.strip():
            lines = proc.stdout.strip().split("\n")
            return {"matches": lines, "count": len(lines)}
    except subprocess.TimeoutExpired:
        pass
    return {}


# ── Drift Detection ───────────────────────────────────────────────────────────


def detect_drift(nodes: list[NodeConfig], golden: dict) -> list[DriftResult]:
    """Compare each node's config against golden state.

    Produces DriftResults for: banned-provider grep hits, provider-chain
    order mismatches, banned providers referenced in parsed configs, and
    missing critical top-level keys.
    """
    results = []
    for node in nodes:
        if not node.reachable:
            continue

        # Check for banned providers found by the remote grep scan
        banned_scan = node.configs.get("__banned_scan__", {})
        if banned_scan.get("count", 0) > 0:
            for match in banned_scan.get("matches", []):
                results.append(DriftResult(
                    node=node.name,
                    file_path="(config files)",
                    diff_type="banned_provider_found",
                    key="banned_provider_reference",
                    node_value=match,
                    severity="critical"
                ))

        # Check each config file
        for path, config in node.configs.items():
            if path == "__banned_scan__":
                continue
            if not isinstance(config, dict):
                continue

            # Provider chain must match golden order exactly
            node_providers = _extract_provider_chain(config)
            golden_providers = golden.get("providers", [])
            if node_providers and golden_providers:
                node_names = [p.get("name", "") for p in node_providers]
                golden_names = [p.get("name", "") for p in golden_providers]
                if node_names != golden_names:
                    results.append(DriftResult(
                        node=node.name,
                        file_path=path,
                        diff_type="value_mismatch",
                        key="provider_chain",
                        canonical_value=golden_names,
                        node_value=node_names,
                        severity="critical"
                    ))

            # Banned providers referenced inside the parsed provider entries
            for banned in golden.get("banned_providers", []):
                for provider in node_providers:
                    prov_name = provider.get("name", "").lower()
                    prov_model = provider.get("model", "").lower()
                    if banned in prov_name or banned in prov_model:
                        results.append(DriftResult(
                            node=node.name,
                            file_path=path,
                            diff_type="banned_provider_found",
                            key=f"provider.{provider.get('name', 'unknown')}",
                            node_value=provider,
                            severity="critical"
                        ))

            # Missing critical top-level keys.
            # NOTE(review): the second clause is a substring test against the
            # config's repr — it flags keys absent at top level but mentioned
            # somewhere in nested data. Looks intentional (avoids flagging
            # keys the node never uses at all), but confirm.
            critical_keys = ["display", "providers", "tools", "delegation"]
            for key in critical_keys:
                if key not in config and key in str(config):
                    results.append(DriftResult(
                        node=node.name,
                        file_path=path,
                        diff_type="key_missing",
                        key=key,
                        canonical_value="(present in golden state)",
                        severity="warning"
                    ))

    return results


def _extract_provider_chain(config: dict) -> list[dict]:
    """Extract provider list from a config dict (handles multiple formats)."""
    # Direct providers key
    if "providers" in config:
        providers = config["providers"]
        if isinstance(providers, list):
            return providers
    # Nested in display or model config
    for key in ["model", "inference", "llm"]:
        if key in config and isinstance(config[key], dict):
            if "providers" in config[key]:
                return config[key]["providers"]
    # Single provider format
    if "provider" in config and "model" in config:
        return [{"name": config["provider"], "model": config["model"]}]
    return []


# ── Auto-Sync ─────────────────────────────────────────────────────────────────


def auto_sync(drifts: list[DriftResult], nodes: list[NodeConfig]) -> list[str]:
    """Auto-sync drifted nodes using golden state playbook.

    Only nodes with at least one *critical* drift are synced. Returns a list
    of human-readable action messages.
    """
    actions = []
    drifted_nodes = {d.node for d in drifts if d.severity == "critical"}
    if not drifted_nodes:
        actions.append("No critical drift to sync.")
        return actions

    for node_name in drifted_nodes:
        node_info = next((n for n in nodes if n.name == node_name), None)
        if not node_info:
            continue
        actions.append(f"[{node_name}] Running golden state sync...")

        # Run ansible-playbook limited to this node
        cmd = [
            "ansible-playbook", str(GOLDEN_STATE_PLAYBOOK),
            "-i", str(ANSIBLE_INVENTORY),
            "-l", node_name,
            "--tags", "golden",
        ]
        try:
            proc = subprocess.run(
                cmd, capture_output=True, text=True, timeout=120, cwd=str(REPO_ROOT)
            )
            if proc.returncode == 0:
                actions.append(f"[{node_name}] Sync completed successfully.")
            else:
                actions.append(f"[{node_name}] Sync FAILED: {proc.stderr[:200]}")
        except subprocess.TimeoutExpired:
            actions.append(f"[{node_name}] Sync timed out after 120s.")
        except FileNotFoundError:
            actions.append(f"[{node_name}] ansible-playbook not found. Install Ansible or run manually.")

    return actions


# ── Reporting ─────────────────────────────────────────────────────────────────


def print_report(drifts: list[DriftResult], nodes: list[NodeConfig], golden: dict):
    """Print human-readable drift report to stdout."""
    print("=" * 70)
    print("CONFIG DRIFT DETECTION REPORT")
    print("=" * 70)
    print()

    # Summary
    reachable = sum(1 for n in nodes if n.reachable)
    print(f"Nodes checked: {len(nodes)} (reachable: {reachable})")
    print(f"Golden state providers: {' → '.join(p['name'] for p in golden.get('providers', []))}")
    print(f"Banned providers: {', '.join(golden.get('banned_providers', []))}")
    print()

    if not drifts:
        print("[OK] No config drift detected. All nodes match golden state.")
        return

    # Group by node
    by_node: dict[str, list[DriftResult]] = {}
    for d in drifts:
        by_node.setdefault(d.node, []).append(d)

    for node_name, node_drifts in sorted(by_node.items()):
        print(f"--- {node_name} ---")
        for d in node_drifts:
            severity_icon = {"critical": "[!!]", "warning": "[!]", "info": "[i]"}.get(d.severity, "[?]")
            print(f"  {severity_icon} {d.diff_type}: {d.key}")
            if d.canonical_value is not None:
                print(f"      canonical: {d.canonical_value}")
            if d.node_value is not None:
                print(f"      actual:    {d.node_value}")
        print()

    # Severity summary
    critical = sum(1 for d in drifts if d.severity == "critical")
    warning = sum(1 for d in drifts if d.severity == "warning")
    print(f"Total: {len(drifts)} drift(s) — {critical} critical, {warning} warning")


def print_json_report(drifts: list[DriftResult], nodes: list[NodeConfig], golden: dict):
    """Print JSON report for automation (consumed by CI / monitoring)."""
    report = {
        "nodes_checked": len(nodes),
        "reachable": sum(1 for n in nodes if n.reachable),
        "golden_providers": [p["name"] for p in golden.get("providers", [])],
        "drift_count": len(drifts),
        "critical_count": sum(1 for d in drifts if d.severity == "critical"),
        "drifts": [
            {
                "node": d.node,
                "file": d.file_path,
                "type": d.diff_type,
                "key": d.key,
                "canonical": d.canonical_value,
                "actual": d.node_value,
                "severity": d.severity,
            }
            for d in drifts
        ],
    }
    # default=str so non-JSON-native drift values (e.g. dicts of YAML
    # scalars already serialize; anything else stringifies) never crash.
    print(json.dumps(report, indent=2, default=str))


# ── CLI ───────────────────────────────────────────────────────────────────────


def main():
    parser = argparse.ArgumentParser(description="Detect config drift across fleet nodes")
    parser.add_argument("--node", help="Check only this node")
    parser.add_argument("--auto-sync", action="store_true",
                        help="Auto-fix critical drift with golden state")
    parser.add_argument("--json", action="store_true", help="JSON output")
    parser.add_argument("--timeout", type=int, default=15,
                        help="SSH timeout per node (seconds)")
    args = parser.parse_args()

    # Load inventory (progress goes to stderr so --json stdout stays clean)
    print("Loading inventory...", file=sys.stderr)
    node_defs, global_vars = load_inventory()
    golden = load_golden_state(global_vars)

    # Filter to single node if requested
    if args.node:
        if args.node not in node_defs:
            print(f"ERROR: Node '{args.node}' not in inventory. Available: {', '.join(node_defs.keys())}",
                  file=sys.stderr)
            sys.exit(2)
        node_defs = {args.node: node_defs[args.node]}

    # Collect configs from each node
    print(f"Collecting configs from {len(node_defs)} node(s)...", file=sys.stderr)
    nodes = []
    for name, info in node_defs.items():
        print(f"  {name} ({info['host']})...", file=sys.stderr, end=" ", flush=True)
        node_config = ssh_collect(name, info, timeout=args.timeout)
        if node_config.reachable:
            print(f"OK ({len(node_config.configs)} files)", file=sys.stderr)
        else:
            print("UNREACHABLE", file=sys.stderr)
        nodes.append(node_config)

    # Detect drift
    print("\nAnalyzing drift...", file=sys.stderr)
    drifts = detect_drift(nodes, golden)

    # Output
    if args.json:
        print_json_report(drifts, nodes, golden)
    else:
        print()
        print_report(drifts, nodes, golden)

    # Auto-sync if requested
    if args.auto_sync and drifts:
        print("\n--- AUTO-SYNC ---")
        actions = auto_sync(drifts, nodes)
        for a in actions:
            print(a)

    # Exit code: 1 if any drift at all, 0 otherwise (per module docstring)
    sys.exit(1 if drifts else 0)


if __name__ == "__main__":
    main()