#!/usr/bin/env python3
"""
config_drift_detector.py — Detect config drift across fleet nodes.

Collects config from all wizard nodes via SSH, compares against
canonical timmy-config golden state, and reports differences.

Usage:
    python3 scripts/config_drift_detector.py               # Report only
    python3 scripts/config_drift_detector.py --auto-sync   # Auto-fix drift with golden state
    python3 scripts/config_drift_detector.py --node allegro  # Check single node
    python3 scripts/config_drift_detector.py --json        # JSON output for automation

Exit codes:
    0 — no drift detected
    1 — drift detected
    2 — error (SSH failure, missing deps, etc.)
"""

import argparse
import json
import os
import subprocess
import sys
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

# ── Constants ─────────────────────────────────────────────────────────────────

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent
ANSIBLE_INVENTORY = REPO_ROOT / "ansible" / "inventory" / "hosts.yml"
GOLDEN_STATE_PLAYBOOK = REPO_ROOT / "ansible" / "playbooks" / "golden_state.yml"

# Config files to check on each node ({name} is replaced with the node name)
CONFIG_PATHS = [
    ".hermes/config.yaml",
    "wizards/{name}/config.yaml",
]

# Keys that define golden state (from ansible inventory vars)
GOLDEN_KEYS = [
    "providers",
    "provider",
    "model",
    "base_url",
    "api_key_env",
    "banned_providers",
    "banned_models_patterns",
]


# ── Data Models ───────────────────────────────────────────────────────────────

@dataclass
class NodeConfig:
    """Configs collected from one fleet node, plus collection status."""
    name: str
    host: str
    # Mapping of remote file path -> parsed YAML (plus the synthetic
    # "__banned_scan__" key holding grep results).
    configs: dict[str, Any] = field(default_factory=dict)
    errors: list[str] = field(default_factory=list)
    reachable: bool = True


@dataclass
class DriftResult:
    """A single detected difference between a node and golden state."""
    node: str
    file_path: str
    diff_type: str  # "missing", "value_mismatch", "key_missing", "extra_key"
    key: str
    canonical_value: Any = None
    node_value: Any = None
    severity: str = "warning"  # "info", "warning", "critical"


# ── Inventory Parsing ─────────────────────────────────────────────────────────

def load_inventory() -> tuple[dict, dict]:
    """Load Ansible inventory and extract wizard node definitions.

    Returns:
        (nodes, global_vars) — per-node connection info keyed by node name,
        and the inventory-wide ``all.vars`` mapping.

    Exits with code 2 if the inventory file is missing.
    """
    if not ANSIBLE_INVENTORY.exists():
        print(f"ERROR: Inventory not found at {ANSIBLE_INVENTORY}", file=sys.stderr)
        sys.exit(2)

    with open(ANSIBLE_INVENTORY) as f:
        inventory = yaml.safe_load(f)

    wizards = inventory.get("all", {}).get("children", {}).get("wizards", {}).get("hosts", {})
    global_vars = inventory.get("all", {}).get("vars", {})

    nodes = {}
    for name, config in wizards.items():
        nodes[name] = {
            "host": config.get("ansible_host", "localhost"),
            "user": config.get("ansible_user", ""),
            "wizard_name": config.get("wizard_name", name),
            "hermes_home": config.get("hermes_home", "~/.hermes"),
            "wizard_home": config.get("wizard_home", f"~/wizards/{name}"),
            "machine_type": config.get("machine_type", "unknown"),
        }

    return nodes, global_vars


def load_golden_state(inventory_vars: dict) -> dict:
    """Extract golden state (canonical providers + ban lists) from inventory vars."""
    golden = {
        "providers": inventory_vars.get("golden_state_providers", []),
        "banned_providers": inventory_vars.get("banned_providers", []),
        "banned_models_patterns": inventory_vars.get("banned_models_patterns", []),
    }
    return golden


# ── SSH Collection ────────────────────────────────────────────────────────────

def ssh_collect(node_name: str, node_info: dict, timeout: int = 15) -> NodeConfig:
    """SSH into a node and collect config files.

    Local nodes (host == localhost/127.0.0.1) are read directly. Missing
    files are not treated as errors, since not every path exists on every
    node; YAML parse failures are recorded in ``result.errors``.
    """
    host = node_info["host"]
    user = node_info.get("user", "")
    hermes_home = node_info.get("hermes_home", "~/.hermes")
    wizard_home = node_info.get("wizard_home", f"~/wizards/{node_name}")

    result = NodeConfig(name=node_name, host=host)

    # Build SSH target (None means "run commands locally")
    if host in ("localhost", "127.0.0.1"):
        ssh_target = None
    else:
        ssh_target = f"{user}@{host}" if user else host

    # Probe remote connectivity once up front so unreachable nodes are
    # reported as such instead of silently yielding zero config files.
    if ssh_target is not None and not _ssh_reachable(ssh_target, timeout):
        result.reachable = False
        result.errors.append(f"SSH connection to {ssh_target} failed")
        return result

    # Collect each config path
    for path_template in CONFIG_PATHS:
        # Resolve path template
        remote_path = path_template.replace("{name}", node_name)
        if not remote_path.startswith("/"):
            # Resolve relative to home
            if "wizards/" in remote_path:
                full_path = f"{wizard_home}/config.yaml"
            else:
                full_path = f"{hermes_home}/config.yaml" if ".hermes" in remote_path else f"~/{remote_path}"
        else:
            full_path = remote_path

        config_content = _remote_cat(ssh_target, full_path, timeout)
        if config_content is not None:
            try:
                parsed = yaml.safe_load(config_content)
                if parsed:
                    result.configs[full_path] = parsed
            except yaml.YAMLError as e:
                result.errors.append(f"YAML parse error in {full_path}: {e}")
        # Don't flag missing files as errors — some paths may not exist on all nodes

    # Also scan config trees for references to banned providers/models
    banned_check = _remote_grep(
        ssh_target,
        hermes_home,
        r"anthropic|claude-sonnet|claude-opus|claude-haiku",
        timeout
    )
    if banned_check:
        result.configs["__banned_scan__"] = banned_check

    return result


def _ssh_reachable(ssh_target: str, timeout: int) -> bool:
    """Probe *ssh_target* with a no-op remote command; True iff SSH succeeds."""
    cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
           ssh_target, "true"]
    try:
        proc = subprocess.run(cmd, capture_output=True, timeout=timeout)
        return proc.returncode == 0
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # Timeout, or no ssh binary on this machine — either way unreachable.
        return False


def _remote_cat(ssh_target: str | None, path: str, timeout: int) -> str | None:
    """Cat a file remotely (or locally). Returns None on any failure."""
    if ssh_target is None:
        # Local run: argv lists bypass the shell, so expand ~ ourselves.
        cmd = ["cat", os.path.expanduser(path)]
    else:
        # Remote run: path is passed unquoted so the remote shell expands ~.
        cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
               ssh_target, f"cat {path}"]

    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        if proc.returncode == 0:
            return proc.stdout
    except subprocess.TimeoutExpired:
        pass
    except FileNotFoundError:
        pass
    return None


def _remote_grep(ssh_target: str | None, base_path: str, pattern: str, timeout: int) -> dict:
    """Grep for banned patterns in config files.

    Returns {"matches": [...], "count": N} when the pattern is found,
    otherwise an empty dict (also on timeout or missing grep/ssh binary).
    """
    if ssh_target is None:
        # Local run: expand ~ (no shell involved with an argv list).
        cmd = ["grep", "-rn", "-i", pattern, os.path.expanduser(base_path),
               "--include=*.yaml", "--include=*.yml"]
    else:
        # "|| true" keeps the remote exit status 0 even when nothing matches.
        cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
               ssh_target, f"grep -rn -i '{pattern}' {base_path} --include='*.yaml' --include='*.yml' 2>/dev/null || true"]

    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        if proc.stdout.strip():
            lines = proc.stdout.strip().split("\n")
            return {"matches": lines, "count": len(lines)}
    except subprocess.TimeoutExpired:
        pass
    except FileNotFoundError:
        pass
    return {}


# ── Drift Detection ───────────────────────────────────────────────────────────

def detect_drift(nodes: list[NodeConfig], golden: dict) -> list[DriftResult]:
    """Compare each node's config against golden state.

    Produces critical drifts for banned-provider hits and provider-chain
    mismatches, and warnings for missing critical keys. Unreachable nodes
    are skipped entirely.
    """
    results = []

    for node in nodes:
        if not node.reachable:
            continue

        # Check for banned providers found by the remote grep scan
        banned_scan = node.configs.get("__banned_scan__", {})
        if banned_scan.get("count", 0) > 0:
            for match in banned_scan.get("matches", []):
                results.append(DriftResult(
                    node=node.name,
                    file_path="(config files)",
                    diff_type="banned_provider_found",
                    key="banned_provider_reference",
                    node_value=match,
                    severity="critical"
                ))

        # Check each config file
        for path, config in node.configs.items():
            if path == "__banned_scan__":
                continue

            # Check provider chain against the canonical ordering
            if isinstance(config, dict):
                node_providers = _extract_provider_chain(config)
                golden_providers = golden.get("providers", [])

                if node_providers and golden_providers:
                    # Compare provider names in order
                    node_names = [p.get("name", "") for p in node_providers]
                    golden_names = [p.get("name", "") for p in golden_providers]

                    if node_names != golden_names:
                        results.append(DriftResult(
                            node=node.name,
                            file_path=path,
                            diff_type="value_mismatch",
                            key="provider_chain",
                            canonical_value=golden_names,
                            node_value=node_names,
                            severity="critical"
                        ))

                # Check for banned providers in node config
                # (case-insensitive: both sides lowercased before substring test)
                for banned in golden.get("banned_providers", []):
                    banned_lc = str(banned).lower()
                    for provider in node_providers:
                        prov_name = provider.get("name", "").lower()
                        prov_model = provider.get("model", "").lower()
                        if banned_lc in prov_name or banned_lc in prov_model:
                            results.append(DriftResult(
                                node=node.name,
                                file_path=path,
                                diff_type="banned_provider_found",
                                key=f"provider.{provider.get('name', 'unknown')}",
                                node_value=provider,
                                severity="critical"
                            ))

                # Check for missing critical keys.
                # NOTE(review): `key in str(config)` only fires when the key
                # name appears somewhere in the stringified config but NOT as
                # a top-level key — presumably a heuristic to limit false
                # positives; confirm intent before changing.
                critical_keys = ["display", "providers", "tools", "delegation"]
                for key in critical_keys:
                    if key not in config and key in str(config):
                        results.append(DriftResult(
                            node=node.name,
                            file_path=path,
                            diff_type="key_missing",
                            key=key,
                            canonical_value="(present in golden state)",
                            severity="warning"
                        ))

    return results


def _extract_provider_chain(config: dict) -> list[dict]:
    """Extract provider list from a config dict (handles multiple formats)."""
    # Direct providers key
    if "providers" in config:
        providers = config["providers"]
        if isinstance(providers, list):
            return providers

    # Nested in display or model config
    for key in ["model", "inference", "llm"]:
        if key in config and isinstance(config[key], dict):
            if "providers" in config[key]:
                return config[key]["providers"]

    # Single provider format
    if "provider" in config and "model" in config:
        return [{"name": config["provider"], "model": config["model"]}]

    return []


# ── Auto-Sync ─────────────────────────────────────────────────────────────────

def auto_sync(drifts: list[DriftResult], nodes: list[NodeConfig]) -> list[str]:
    """Auto-sync drifted nodes using the golden state playbook.

    Only nodes with at least one *critical* drift are synced. Returns a list
    of human-readable action/result strings.
    """
    actions = []

    drifted_nodes = set(d.node for d in drifts if d.severity == "critical")
    if not drifted_nodes:
        actions.append("No critical drift to sync.")
        return actions

    for node_name in drifted_nodes:
        node_info = next((n for n in nodes if n.name == node_name), None)
        if not node_info:
            continue

        actions.append(f"[{node_name}] Running golden state sync...")

        # Run ansible-playbook limited to this node
        cmd = [
            "ansible-playbook",
            str(GOLDEN_STATE_PLAYBOOK),
            "-i", str(ANSIBLE_INVENTORY),
            "-l", node_name,
            "--tags", "golden",
        ]

        try:
            proc = subprocess.run(
                cmd, capture_output=True, text=True, timeout=120,
                cwd=str(REPO_ROOT)
            )
            if proc.returncode == 0:
                actions.append(f"[{node_name}] Sync completed successfully.")
            else:
                actions.append(f"[{node_name}] Sync FAILED: {proc.stderr[:200]}")
        except subprocess.TimeoutExpired:
            actions.append(f"[{node_name}] Sync timed out after 120s.")
        except FileNotFoundError:
            actions.append(f"[{node_name}] ansible-playbook not found. Install Ansible or run manually.")

    return actions


# ── Reporting ─────────────────────────────────────────────────────────────────

def print_report(drifts: list[DriftResult], nodes: list[NodeConfig], golden: dict):
    """Print a human-readable drift report to stdout."""
    print("=" * 70)
    print("CONFIG DRIFT DETECTION REPORT")
    print("=" * 70)
    print()

    # Summary
    reachable = sum(1 for n in nodes if n.reachable)
    print(f"Nodes checked: {len(nodes)} (reachable: {reachable})")
    print(f"Golden state providers: {' → '.join(p['name'] for p in golden.get('providers', []))}")
    print(f"Banned providers: {', '.join(golden.get('banned_providers', []))}")
    print()

    if not drifts:
        print("[OK] No config drift detected. All nodes match golden state.")
        return

    # Group by node
    by_node: dict[str, list[DriftResult]] = {}
    for d in drifts:
        by_node.setdefault(d.node, []).append(d)

    for node_name, node_drifts in sorted(by_node.items()):
        print(f"--- {node_name} ---")
        for d in node_drifts:
            severity_icon = {"critical": "[!!]", "warning": "[!]", "info": "[i]"}.get(d.severity, "[?]")
            print(f"  {severity_icon} {d.diff_type}: {d.key}")
            if d.canonical_value is not None:
                print(f"      canonical: {d.canonical_value}")
            if d.node_value is not None:
                print(f"      actual:    {d.node_value}")
        print()

    # Severity summary
    critical = sum(1 for d in drifts if d.severity == "critical")
    warning = sum(1 for d in drifts if d.severity == "warning")
    print(f"Total: {len(drifts)} drift(s) — {critical} critical, {warning} warning")


def print_json_report(drifts: list[DriftResult], nodes: list[NodeConfig], golden: dict):
    """Print a machine-readable JSON report to stdout (for automation)."""
    report = {
        "nodes_checked": len(nodes),
        "reachable": sum(1 for n in nodes if n.reachable),
        "golden_providers": [p["name"] for p in golden.get("providers", [])],
        "drift_count": len(drifts),
        "critical_count": sum(1 for d in drifts if d.severity == "critical"),
        "drifts": [
            {
                "node": d.node,
                "file": d.file_path,
                "type": d.diff_type,
                "key": d.key,
                "canonical": d.canonical_value,
                "actual": d.node_value,
                "severity": d.severity,
            }
            for d in drifts
        ],
    }
    # default=str: drift values may be arbitrary YAML-derived objects
    print(json.dumps(report, indent=2, default=str))


# ── CLI ───────────────────────────────────────────────────────────────────────

def main():
    """CLI entry point: collect, compare, report, optionally auto-sync.

    Exit codes: 0 = no drift, 1 = drift detected, 2 = usage/setup error.
    """
    parser = argparse.ArgumentParser(description="Detect config drift across fleet nodes")
    parser.add_argument("--node", help="Check only this node")
    parser.add_argument("--auto-sync", action="store_true", help="Auto-fix critical drift with golden state")
    parser.add_argument("--json", action="store_true", help="JSON output")
    parser.add_argument("--timeout", type=int, default=15, help="SSH timeout per node (seconds)")
    args = parser.parse_args()

    # Load inventory (progress goes to stderr so --json stdout stays clean)
    print("Loading inventory...", file=sys.stderr)
    node_defs, global_vars = load_inventory()
    golden = load_golden_state(global_vars)

    # Filter to single node if requested
    if args.node:
        if args.node not in node_defs:
            print(f"ERROR: Node '{args.node}' not in inventory. Available: {', '.join(node_defs.keys())}",
                  file=sys.stderr)
            sys.exit(2)
        node_defs = {args.node: node_defs[args.node]}

    # Collect configs from each node
    print(f"Collecting configs from {len(node_defs)} node(s)...", file=sys.stderr)
    nodes = []
    for name, info in node_defs.items():
        print(f"  {name} ({info['host']})...", file=sys.stderr, end=" ", flush=True)
        node_config = ssh_collect(name, info, timeout=args.timeout)
        if node_config.reachable:
            print(f"OK ({len(node_config.configs)} files)", file=sys.stderr)
        else:
            print("UNREACHABLE", file=sys.stderr)
        nodes.append(node_config)

    # Detect drift
    print("\nAnalyzing drift...", file=sys.stderr)
    drifts = detect_drift(nodes, golden)

    # Output
    if args.json:
        print_json_report(drifts, nodes, golden)
    else:
        print()
        print_report(drifts, nodes, golden)

    # Auto-sync if requested
    if args.auto_sync and drifts:
        print("\n--- AUTO-SYNC ---")
        actions = auto_sync(drifts, nodes)
        for a in actions:
            print(a)

    # Exit code: any drift (critical or warning) is exit 1, per module docstring
    sys.exit(1 if drifts else 0)


if __name__ == "__main__":
    main()