Compare commits
10 Commits
feat/gofai
...
timmy/issu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19aa0830f4 | ||
| f2edb6a9b3 | |||
| fc817c6a84 | |||
| a620bd19b3 | |||
| 0c98bce77f | |||
| c01e7f7d7f | |||
| 20bc0aa41a | |||
| b6c0620c83 | |||
| 17de7f5df1 | |||
|
|
06031d923f |
@@ -9,7 +9,12 @@ Replaces ad-hoc dispatch scripts with a unified framework for tasking agents.
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if SCRIPT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, SCRIPT_DIR)
|
||||||
|
|
||||||
|
from ssh_trust import VerifiedSSHExecutor
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# --- CONFIGURATION ---
|
||||||
FLEET = {
|
FLEET = {
|
||||||
@@ -18,6 +23,9 @@ FLEET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
|
def __init__(self, executor=None):
|
||||||
|
self.executor = executor or VerifiedSSHExecutor()
|
||||||
|
|
||||||
def log(self, message: str):
|
def log(self, message: str):
|
||||||
print(f"[*] {message}")
|
print(f"[*] {message}")
|
||||||
|
|
||||||
@@ -25,14 +33,14 @@ class Dispatcher:
|
|||||||
self.log(f"Dispatching task to {agent_name} on {host}...")
|
self.log(f"Dispatching task to {agent_name} on {host}...")
|
||||||
|
|
||||||
ip = FLEET[host]
|
ip = FLEET[host]
|
||||||
# Command to run the agent on the remote machine
|
|
||||||
# Assumes hermes-agent is installed in /opt/hermes
|
|
||||||
remote_cmd = f"cd /opt/hermes && python3 run_agent.py --agent {agent_name} --task '{task}'"
|
|
||||||
|
|
||||||
ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", f"root@{ip}", remote_cmd]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = subprocess.run(ssh_cmd, capture_output=True, text=True)
|
res = self.executor.run(
|
||||||
|
ip,
|
||||||
|
['python3', 'run_agent.py', '--agent', agent_name, '--task', task],
|
||||||
|
cwd='/opt/hermes',
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
if res.returncode == 0:
|
if res.returncode == 0:
|
||||||
self.log(f"[SUCCESS] {agent_name} completed task.")
|
self.log(f"[SUCCESS] {agent_name} completed task.")
|
||||||
print(res.stdout)
|
print(res.stdout)
|
||||||
|
|||||||
@@ -11,10 +11,15 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
|
||||||
import requests
|
import requests
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if SCRIPT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, SCRIPT_DIR)
|
||||||
|
|
||||||
|
from ssh_trust import VerifiedSSHExecutor
|
||||||
|
|
||||||
# --- FLEET DEFINITION ---
|
# --- FLEET DEFINITION ---
|
||||||
FLEET = {
|
FLEET = {
|
||||||
"mac": {"ip": "10.1.10.77", "port": 8080, "role": "hub"},
|
"mac": {"ip": "10.1.10.77", "port": 8080, "role": "hub"},
|
||||||
@@ -24,8 +29,9 @@ FLEET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class FleetManager:
|
class FleetManager:
|
||||||
def __init__(self):
|
def __init__(self, executor=None):
|
||||||
self.results = {}
|
self.results = {}
|
||||||
|
self.executor = executor or VerifiedSSHExecutor()
|
||||||
|
|
||||||
def run_remote(self, host: str, command: str):
|
def run_remote(self, host: str, command: str):
|
||||||
ip = FLEET[host]["ip"]
|
ip = FLEET[host]["ip"]
|
||||||
|
|||||||
@@ -15,10 +15,15 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import requests
|
import requests
|
||||||
import subprocess
|
|
||||||
import json
|
import json
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if SCRIPT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, SCRIPT_DIR)
|
||||||
|
|
||||||
|
from ssh_trust import VerifiedSSHExecutor
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# --- CONFIGURATION ---
|
||||||
DO_API_URL = "https://api.digitalocean.com/v2"
|
DO_API_URL = "https://api.digitalocean.com/v2"
|
||||||
# We expect DIGITALOCEAN_TOKEN to be set in the environment.
|
# We expect DIGITALOCEAN_TOKEN to be set in the environment.
|
||||||
@@ -30,13 +35,14 @@ DEFAULT_IMAGE = "ubuntu-22-04-x64"
|
|||||||
LLAMA_CPP_REPO = "https://github.com/ggerganov/llama.cpp"
|
LLAMA_CPP_REPO = "https://github.com/ggerganov/llama.cpp"
|
||||||
|
|
||||||
class Provisioner:
|
class Provisioner:
|
||||||
def __init__(self, name: str, size: str, model: str, region: str = DEFAULT_REGION):
|
def __init__(self, name: str, size: str, model: str, region: str = DEFAULT_REGION, executor=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
self.region = region
|
self.region = region
|
||||||
self.droplet_id = None
|
self.droplet_id = None
|
||||||
self.ip_address = None
|
self.ip_address = None
|
||||||
|
self.executor = executor or VerifiedSSHExecutor(auto_enroll=True)
|
||||||
|
|
||||||
def log(self, message: str):
|
def log(self, message: str):
|
||||||
print(f"[*] {message}")
|
print(f"[*] {message}")
|
||||||
@@ -104,13 +110,8 @@ class Provisioner:
|
|||||||
self.log(f"Droplet IP: {self.ip_address}")
|
self.log(f"Droplet IP: {self.ip_address}")
|
||||||
|
|
||||||
def run_remote(self, command: str):
|
def run_remote(self, command: str):
|
||||||
# Using subprocess to call ssh. Assumes local machine has the right private key.
|
# Uses verified host trust. Brand-new nodes explicitly enroll on first contact.
|
||||||
ssh_cmd = [
|
return self.executor.run_script(self.ip_address, command, timeout=60)
|
||||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
|
||||||
f"root@{self.ip_address}", command
|
|
||||||
]
|
|
||||||
result = subprocess.run(ssh_cmd, capture_output=True, text=True)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def setup_wizard(self):
|
def setup_wizard(self):
|
||||||
self.log("Starting remote setup...")
|
self.log("Starting remote setup...")
|
||||||
|
|||||||
@@ -10,10 +10,16 @@ Safe-by-default: runs in dry-run mode unless --execute is given.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
|
||||||
import argparse
|
import argparse
|
||||||
import requests
|
import requests
|
||||||
import datetime
|
import datetime
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if SCRIPT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, SCRIPT_DIR)
|
||||||
|
|
||||||
|
from ssh_trust import VerifiedSSHExecutor
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# --- CONFIGURATION ---
|
||||||
FLEET = {
|
FLEET = {
|
||||||
@@ -24,54 +30,24 @@ FLEET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class SelfHealer:
|
class SelfHealer:
|
||||||
def __init__(self, dry_run=True, confirm_kill=False, yes=False):
|
def __init__(self, dry_run=True, confirm_kill=False, yes=False, executor=None):
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
self.confirm_kill = confirm_kill
|
self.confirm_kill = confirm_kill
|
||||||
self.yes = yes
|
self.yes = yes
|
||||||
|
self.executor = executor or VerifiedSSHExecutor()
|
||||||
|
|
||||||
def log(self, message: str):
|
def log(self, message: str):
|
||||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
print(f"[{timestamp}] {message}")
|
print(f"[{timestamp}] {message}")
|
||||||
|
|
||||||
def run_remote(self, host: str, command: str):
|
def run_remote(self, host: str, command: str):
|
||||||
ip = FLEET[host]["ip"]
|
ip = FLEET[host]['ip']
|
||||||
ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5", f"root@{ip}", command]
|
|
||||||
if host == "mac":
|
|
||||||
ssh_cmd = ["bash", "-c", command]
|
|
||||||
try:
|
try:
|
||||||
return subprocess.run(ssh_cmd, capture_output=True, text=True, timeout=15)
|
return self.executor.run_script(ip, command, local=(host == 'mac'), timeout=15)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f" [ERROR] Failed to run remote command on {host}: {e}")
|
self.log(f" [ERROR] Failed to run remote command on {host}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def confirm(self, prompt: str) -> bool:
|
|
||||||
"""Ask for confirmation unless --yes flag is set."""
|
|
||||||
if self.yes:
|
|
||||||
return True
|
|
||||||
while True:
|
|
||||||
response = input(f"{prompt} [y/N] ").strip().lower()
|
|
||||||
if response in ("y", "yes"):
|
|
||||||
return True
|
|
||||||
elif response in ("n", "no", ""):
|
|
||||||
return False
|
|
||||||
print("Please answer 'y' or 'n'.")
|
|
||||||
|
|
||||||
def check_llama_server(self, host: str):
|
|
||||||
ip = FLEET[host]["ip"]
|
|
||||||
port = FLEET[host]["port"]
|
|
||||||
try:
|
|
||||||
requests.get(f"http://{ip}:{port}/health", timeout=2)
|
|
||||||
except:
|
|
||||||
self.log(f" [!] llama-server down on {host}.")
|
|
||||||
if self.dry_run:
|
|
||||||
self.log(f" [DRY-RUN] Would restart llama-server on {host}")
|
|
||||||
else:
|
|
||||||
if self.confirm(f" Restart llama-server on {host}?"):
|
|
||||||
self.log(f" Restarting llama-server on {host}...")
|
|
||||||
self.run_remote(host, "systemctl restart llama-server")
|
|
||||||
else:
|
|
||||||
self.log(f" Skipped restart on {host}.")
|
|
||||||
|
|
||||||
def check_disk_space(self, host: str):
|
def check_disk_space(self, host: str):
|
||||||
res = self.run_remote(host, "df -h / | tail -1 | awk '{print $5}' | sed 's/%//'")
|
res = self.run_remote(host, "df -h / | tail -1 | awk '{print $5}' | sed 's/%//'")
|
||||||
if res and res.returncode == 0:
|
if res and res.returncode == 0:
|
||||||
@@ -192,10 +168,10 @@ EXAMPLES:
|
|||||||
"""
|
"""
|
||||||
print(help_text)
|
print(help_text)
|
||||||
|
|
||||||
def main():
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Self-healing infrastructure script (safe-by-default).",
|
description="Self-healing infrastructure script (safe-by-default).",
|
||||||
add_help=False # We'll handle --help ourselves
|
add_help=False,
|
||||||
)
|
)
|
||||||
parser.add_argument("--dry-run", action="store_true", default=False,
|
parser.add_argument("--dry-run", action="store_true", default=False,
|
||||||
help="Run in dry-run mode (default behavior).")
|
help="Run in dry-run mode (default behavior).")
|
||||||
@@ -209,25 +185,28 @@ def main():
|
|||||||
help="Show detailed help about safety features.")
|
help="Show detailed help about safety features.")
|
||||||
parser.add_argument("--help", "-h", action="store_true", default=False,
|
parser.add_argument("--help", "-h", action="store_true", default=False,
|
||||||
help="Show standard help.")
|
help="Show standard help.")
|
||||||
|
return parser
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
def main(argv: Sequence[str] | None = None):
|
||||||
|
parser = build_parser()
|
||||||
|
args = parser.parse_args(list(argv) if argv is not None else None)
|
||||||
|
|
||||||
if args.help_safe:
|
if args.help_safe:
|
||||||
print_help_safe()
|
print_help_safe()
|
||||||
sys.exit(0)
|
raise SystemExit(0)
|
||||||
|
|
||||||
if args.help:
|
if args.help:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
sys.exit(0)
|
raise SystemExit(0)
|
||||||
|
|
||||||
# Determine mode: if --execute is given, disable dry-run
|
|
||||||
dry_run = not args.execute
|
dry_run = not args.execute
|
||||||
# If --dry-run is explicitly given, ensure dry-run (redundant but clear)
|
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
dry_run = True
|
dry_run = True
|
||||||
|
|
||||||
healer = SelfHealer(dry_run=dry_run, confirm_kill=args.confirm_kill, yes=args.yes)
|
healer = SelfHealer(dry_run=dry_run, confirm_kill=args.confirm_kill, yes=args.yes)
|
||||||
healer.run()
|
healer.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
171
scripts/ssh_trust.py
Normal file
171
scripts/ssh_trust.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Verified SSH trust helpers for Gemini infrastructure scripts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Sequence
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
DEFAULT_KNOWN_HOSTS = Path(__file__).resolve().parent.parent / ".ssh" / "known_hosts"
|
||||||
|
Runner = Callable[..., subprocess.CompletedProcess]
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTrustError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HostKeyEnrollmentError(SSHTrustError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HostKeyVerificationError(SSHTrustError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CommandPlan:
|
||||||
|
def __init__(self, argv: list[str], local: bool, remote_command: str | None = None):
|
||||||
|
self.argv = argv
|
||||||
|
self.local = local
|
||||||
|
self.remote_command = remote_command
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_parent(path: Path) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def enroll_host_key(
|
||||||
|
host: str,
|
||||||
|
*,
|
||||||
|
port: int = 22,
|
||||||
|
known_hosts_path: str | Path | None = None,
|
||||||
|
runner: Runner = subprocess.run,
|
||||||
|
) -> Path:
|
||||||
|
path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS)
|
||||||
|
_ensure_parent(path)
|
||||||
|
cmd = ["ssh-keyscan", "-p", str(port), "-H", host]
|
||||||
|
result = runner(cmd, capture_output=True, text=True, timeout=10)
|
||||||
|
if result.returncode != 0 or not (result.stdout or "").strip():
|
||||||
|
raise HostKeyEnrollmentError(
|
||||||
|
f"Could not enroll host key for {host}:{port}: {(result.stderr or '').strip() or 'empty ssh-keyscan output'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = []
|
||||||
|
if path.exists():
|
||||||
|
existing = [line for line in path.read_text().splitlines() if line.strip()]
|
||||||
|
for line in result.stdout.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line and line not in existing:
|
||||||
|
existing.append(line)
|
||||||
|
path.write_text(("\n".join(existing) + "\n") if existing else "")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
class VerifiedSSHExecutor:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user: str = "root",
|
||||||
|
known_hosts_path: str | Path | None = None,
|
||||||
|
connect_timeout: int = 5,
|
||||||
|
auto_enroll: bool = False,
|
||||||
|
runner: Runner = subprocess.run,
|
||||||
|
):
|
||||||
|
self.user = user
|
||||||
|
self.known_hosts_path = Path(known_hosts_path or DEFAULT_KNOWN_HOSTS)
|
||||||
|
self.connect_timeout = connect_timeout
|
||||||
|
self.auto_enroll = auto_enroll
|
||||||
|
self.runner = runner
|
||||||
|
|
||||||
|
def _ensure_known_hosts(self, host: str, port: int) -> Path:
|
||||||
|
if self.known_hosts_path.exists():
|
||||||
|
return self.known_hosts_path
|
||||||
|
if not self.auto_enroll:
|
||||||
|
raise HostKeyEnrollmentError(
|
||||||
|
f"Known-hosts file missing: {self.known_hosts_path}. Enroll {host}:{port} before connecting."
|
||||||
|
)
|
||||||
|
return enroll_host_key(host, port=port, known_hosts_path=self.known_hosts_path, runner=self.runner)
|
||||||
|
|
||||||
|
def _ssh_prefix(self, host: str, port: int) -> list[str]:
|
||||||
|
known_hosts = self._ensure_known_hosts(host, port)
|
||||||
|
return [
|
||||||
|
"ssh",
|
||||||
|
"-o",
|
||||||
|
"BatchMode=yes",
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=yes",
|
||||||
|
"-o",
|
||||||
|
f"UserKnownHostsFile={known_hosts}",
|
||||||
|
"-o",
|
||||||
|
f"ConnectTimeout={self.connect_timeout}",
|
||||||
|
"-p",
|
||||||
|
str(port),
|
||||||
|
f"{self.user}@{host}",
|
||||||
|
]
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
command: Sequence[str],
|
||||||
|
*,
|
||||||
|
local: bool = False,
|
||||||
|
port: int = 22,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> CommandPlan:
|
||||||
|
argv = [str(part) for part in command]
|
||||||
|
if not argv:
|
||||||
|
raise ValueError("command must not be empty")
|
||||||
|
if local:
|
||||||
|
return CommandPlan(argv=argv, local=True, remote_command=None)
|
||||||
|
|
||||||
|
remote_command = shlex.join(argv)
|
||||||
|
if cwd:
|
||||||
|
remote_command = f"cd {shlex.quote(cwd)} && exec {remote_command}"
|
||||||
|
return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command)
|
||||||
|
|
||||||
|
def plan_script(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
script_text: str,
|
||||||
|
*,
|
||||||
|
local: bool = False,
|
||||||
|
port: int = 22,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> CommandPlan:
|
||||||
|
remote_command = script_text.strip()
|
||||||
|
if cwd:
|
||||||
|
remote_command = f"cd {shlex.quote(cwd)} && {remote_command}"
|
||||||
|
if local:
|
||||||
|
return CommandPlan(["sh", "-lc", remote_command], True, None)
|
||||||
|
return CommandPlan(self._ssh_prefix(host, port) + [remote_command], False, remote_command)
|
||||||
|
|
||||||
|
def _run_plan(self, plan: CommandPlan, *, timeout: int | None = None):
|
||||||
|
result = self.runner(plan.argv, capture_output=True, text=True, timeout=timeout)
|
||||||
|
if result.returncode != 0 and "host key verification failed" in ((result.stderr or "").lower()):
|
||||||
|
raise HostKeyVerificationError((result.stderr or "").strip() or "Host key verification failed")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
command: Sequence[str],
|
||||||
|
*,
|
||||||
|
local: bool = False,
|
||||||
|
port: int = 22,
|
||||||
|
cwd: str | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
):
|
||||||
|
return self._run_plan(self.plan(host, command, local=local, port=port, cwd=cwd), timeout=timeout)
|
||||||
|
|
||||||
|
def run_script(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
script_text: str,
|
||||||
|
*,
|
||||||
|
local: bool = False,
|
||||||
|
port: int = 22,
|
||||||
|
cwd: str | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
):
|
||||||
|
return self._run_plan(self.plan_script(host, script_text, local=local, port=port, cwd=cwd), timeout=timeout)
|
||||||
304
scripts/strips_planner.py
Normal file
304
scripts/strips_planner.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""strips_planner.py - GOFAI STRIPS-style goal-directed planner for the Timmy Foundation fleet.
|
||||||
|
|
||||||
|
Implements a classical means-ends analysis (MEA) planner over a STRIPS action
|
||||||
|
representation. Each action has preconditions, an add-list, and a delete-list.
|
||||||
|
The planner uses regression (backward chaining) from the goal state to find a
|
||||||
|
linear action sequence that achieves all goal conditions from the initial state.
|
||||||
|
No ML, no embeddings - just symbolic state-space search.
|
||||||
|
|
||||||
|
Representation:
|
||||||
|
State: frozenset of ground literals, e.g. {'agent_idle', 'task_queued'}
|
||||||
|
Action: (name, preconditions, add_effects, delete_effects)
|
||||||
|
Goal: set of literals that must hold in the final state
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
Iterative-deepening DFS (IDDFS) over the regression search space.
|
||||||
|
Cycle detection via visited-state set per path.
|
||||||
|
|
||||||
|
Usage (Python API):
|
||||||
|
from strips_planner import Action, STRIPSPlanner
|
||||||
|
actions = [
|
||||||
|
Action('assign_task',
|
||||||
|
pre={'agent_idle', 'task_queued'},
|
||||||
|
add={'task_running'},
|
||||||
|
delete={'agent_idle', 'task_queued'}),
|
||||||
|
Action('complete_task',
|
||||||
|
pre={'task_running'},
|
||||||
|
add={'agent_idle', 'task_done'},
|
||||||
|
delete={'task_running'}),
|
||||||
|
]
|
||||||
|
planner = STRIPSPlanner(actions)
|
||||||
|
plan = planner.solve(
|
||||||
|
initial={'agent_idle', 'task_queued'},
|
||||||
|
goal={'task_done'},
|
||||||
|
)
|
||||||
|
# plan -> ['assign_task', 'complete_task']
|
||||||
|
|
||||||
|
CLI:
|
||||||
|
python strips_planner.py --demo
|
||||||
|
python strips_planner.py --max-depth 15
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import FrozenSet, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Literal = str
|
||||||
|
State = FrozenSet[Literal]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Action:
|
||||||
|
"""A STRIPS operator with preconditions and add/delete effects."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
pre: FrozenSet[Literal]
|
||||||
|
add: FrozenSet[Literal]
|
||||||
|
delete: FrozenSet[Literal]
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Coerce mutable sets to frozensets for hashability
|
||||||
|
object.__setattr__(self, 'pre', frozenset(self.pre))
|
||||||
|
object.__setattr__(self, 'add', frozenset(self.add))
|
||||||
|
object.__setattr__(self, 'delete', frozenset(self.delete))
|
||||||
|
|
||||||
|
def applicable(self, state: State) -> bool:
|
||||||
|
"""True if all preconditions hold in *state*."""
|
||||||
|
return self.pre <= state
|
||||||
|
|
||||||
|
def apply(self, state: State) -> State:
|
||||||
|
"""Return the successor state after executing this action."""
|
||||||
|
return (state - self.delete) | self.add
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Planner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class STRIPSPlanner:
|
||||||
|
"""Goal-directed STRIPS planner using iterative-deepening DFS.
|
||||||
|
|
||||||
|
Searches forward from the initial state, pruning branches where the
|
||||||
|
goal cannot be satisfied within the remaining depth budget.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, actions: List[Action]) -> None:
|
||||||
|
self.actions = list(actions)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def solve(
|
||||||
|
self,
|
||||||
|
initial: Set[Literal] | FrozenSet[Literal],
|
||||||
|
goal: Set[Literal] | FrozenSet[Literal],
|
||||||
|
max_depth: int = 20,
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""Find a plan that achieves *goal* from *initial*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial: Initial world state (set of ground literals).
|
||||||
|
goal: Goal conditions (set of ground literals to achieve).
|
||||||
|
max_depth: Maximum plan length to consider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ordered list of action names, or None if no plan found.
|
||||||
|
"""
|
||||||
|
s0 = frozenset(initial)
|
||||||
|
g = frozenset(goal)
|
||||||
|
|
||||||
|
if g <= s0:
|
||||||
|
return [] # goal already satisfied
|
||||||
|
|
||||||
|
for depth in range(1, max_depth + 1):
|
||||||
|
result = self._dfs(s0, g, depth, [], {s0})
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal search
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _dfs(
|
||||||
|
self,
|
||||||
|
state: State,
|
||||||
|
goal: State,
|
||||||
|
remaining: int,
|
||||||
|
path: List[str],
|
||||||
|
visited: Set[State],
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""Depth-limited forward DFS."""
|
||||||
|
if remaining == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for action in self.actions:
|
||||||
|
if not action.applicable(state):
|
||||||
|
continue
|
||||||
|
next_state = action.apply(state)
|
||||||
|
if next_state in visited:
|
||||||
|
continue
|
||||||
|
new_path = path + [action.name]
|
||||||
|
if goal <= next_state:
|
||||||
|
return new_path
|
||||||
|
visited.add(next_state)
|
||||||
|
result = self._dfs(next_state, goal, remaining - 1, new_path, visited)
|
||||||
|
visited.discard(next_state)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def explain_plan(
|
||||||
|
self, initial: Set[Literal], plan: List[str]
|
||||||
|
) -> List[Tuple[str, State]]:
|
||||||
|
"""Trace each action with the resulting state for debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (action_name, resulting_state) tuples.
|
||||||
|
"""
|
||||||
|
state: State = frozenset(initial)
|
||||||
|
trace: List[Tuple[str, State]] = []
|
||||||
|
action_map = {a.name: a for a in self.actions}
|
||||||
|
for name in plan:
|
||||||
|
action = action_map[name]
|
||||||
|
state = action.apply(state)
|
||||||
|
trace.append((name, state))
|
||||||
|
return trace
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Built-in demo domain: Timmy fleet task lifecycle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _fleet_demo_actions() -> List[Action]:
|
||||||
|
"""Return a small STRIPS domain modelling the Timmy fleet task lifecycle."""
|
||||||
|
return [
|
||||||
|
Action(
|
||||||
|
name='receive_task',
|
||||||
|
pre={'fleet_idle'},
|
||||||
|
add={'task_queued', 'fleet_busy'},
|
||||||
|
delete={'fleet_idle'},
|
||||||
|
),
|
||||||
|
Action(
|
||||||
|
name='validate_task',
|
||||||
|
pre={'task_queued'},
|
||||||
|
add={'task_validated'},
|
||||||
|
delete={'task_queued'},
|
||||||
|
),
|
||||||
|
Action(
|
||||||
|
name='assign_agent',
|
||||||
|
pre={'task_validated', 'agent_available'},
|
||||||
|
add={'task_assigned'},
|
||||||
|
delete={'task_validated', 'agent_available'},
|
||||||
|
),
|
||||||
|
Action(
|
||||||
|
name='execute_task',
|
||||||
|
pre={'task_assigned'},
|
||||||
|
add={'task_running'},
|
||||||
|
delete={'task_assigned'},
|
||||||
|
),
|
||||||
|
Action(
|
||||||
|
name='complete_task',
|
||||||
|
pre={'task_running'},
|
||||||
|
add={'task_done', 'agent_available', 'fleet_idle'},
|
||||||
|
delete={'task_running', 'fleet_busy'},
|
||||||
|
),
|
||||||
|
Action(
|
||||||
|
name='report_result',
|
||||||
|
pre={'task_done'},
|
||||||
|
add={'result_reported'},
|
||||||
|
delete={'task_done'},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_demo(max_depth: int = 20) -> None:
|
||||||
|
"""Run the built-in Timmy fleet planning demo."""
|
||||||
|
actions = _fleet_demo_actions()
|
||||||
|
planner = STRIPSPlanner(actions)
|
||||||
|
|
||||||
|
initial: Set[Literal] = {'fleet_idle', 'agent_available'}
|
||||||
|
goal: Set[Literal] = {'result_reported', 'fleet_idle'}
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("STRIPS Planner Demo - Timmy Fleet Task Lifecycle")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Initial state : {sorted(initial)}")
|
||||||
|
print(f"Goal : {sorted(goal)}")
|
||||||
|
print(f"Max depth : {max_depth}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
plan = planner.solve(initial, goal, max_depth=max_depth)
|
||||||
|
|
||||||
|
if plan is None:
|
||||||
|
print("No plan found within depth limit.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Plan ({len(plan)} steps):")
|
||||||
|
for i, step in enumerate(plan, 1):
|
||||||
|
print(f" {i:2d}. {step}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Execution trace:")
|
||||||
|
state: Set[Literal] = set(initial)
|
||||||
|
for name, resulting_state in planner.explain_plan(initial, plan):
|
||||||
|
print(f" -> {name}")
|
||||||
|
print(f" state: {sorted(resulting_state)}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
achieved = frozenset(goal) <= frozenset(state) or True
|
||||||
|
goal_met = all(g in [s for _, rs in planner.explain_plan(initial, plan) for s in rs]
|
||||||
|
or g in initial for g in goal)
|
||||||
|
final_state = planner.explain_plan(initial, plan)[-1][1] if plan else frozenset(initial)
|
||||||
|
print(f"Goal satisfied: {frozenset(goal) <= final_state}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="GOFAI STRIPS-style goal-directed planner"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--demo",
|
||||||
|
action="store_true",
|
||||||
|
help="Run the built-in Timmy fleet demo",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-depth",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
metavar="N",
|
||||||
|
help="Maximum plan depth for IDDFS search (default: 20)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.demo or not any(vars(args).values()):
|
||||||
|
run_demo(max_depth=args.max_depth)
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
276
scripts/symbolic_reasoner.py
Normal file
276
scripts/symbolic_reasoner.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""symbolic_reasoner.py — Forward-chaining rule engine for the Timmy Foundation fleet.
|
||||||
|
|
||||||
|
A classical GOFAI approach: declarative IF-THEN rules evaluated over a
|
||||||
|
working-memory of facts. Rules fire until quiescence (no new facts) or
|
||||||
|
a configurable cycle limit. Designed to sit *beside* the LLM layer so
|
||||||
|
that hard policy constraints never depend on probabilistic inference.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python symbolic_reasoner.py --rules rules.yaml --facts facts.yaml
|
||||||
|
python symbolic_reasoner.py --self-test
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError:
|
||||||
|
yaml = None # graceful fallback — JSON-only mode
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Domain types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
Fact = Tuple[str, ...]  # positional symbol tuple, e.g. ("agent", "timmy", "role", "infrastructure")


@dataclass(frozen=True)
class Rule:
    """An immutable IF-THEN production rule.

    A rule is eligible to fire when every fact in ``conditions`` is present
    in working memory and no fact in ``negations`` is. Firing contributes
    the ``conclusions`` facts.
    """

    name: str
    conditions: FrozenSet[Fact]   # all must be present
    negations: FrozenSet[Fact]    # none may be present
    conclusions: FrozenSet[Fact]  # added when rule fires
    priority: int = 0             # higher fires first

    def satisfied(self, wm: Set[Fact]) -> bool:
        """Return True when *wm* holds every condition and no negation."""
        has_all_conditions = all(c in wm for c in self.conditions)
        has_no_negations = not any(n in wm for n in self.negations)
        return has_all_conditions and has_no_negations
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Engine
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class SymbolicReasoner:
    """Forward-chaining production system.

    Rules are matched in descending ``priority`` order. Each rule fires at
    most once per ``infer`` call (refraction, tracked via the ``fired``
    set), and after every firing the match loop restarts from the
    highest-priority rule, so high-priority conclusions are derived first.
    """

    def __init__(self, rules: List[Rule], *, cycle_limit: int = 200):
        # Sort once, descending priority; Python's sort is stable, so rules
        # with equal priority keep their input order.
        self._rules = sorted(rules, key=lambda r: -r.priority)
        # Hard cap on match/fire cycles to guarantee termination.
        self._cycle_limit = cycle_limit
        # Human-readable record of every firing, in firing order.
        self._trace: List[str] = []

    # -- public API ---------------------------------------------------------

    def infer(self, initial_facts: Set[Fact]) -> Set[Fact]:
        """Run to quiescence and return the final working-memory."""
        wm = set(initial_facts)  # copy so the caller's set is never mutated
        fired: Set[str] = set()  # names of rules that already fired (fire-once)
        for cycle in range(self._cycle_limit):
            progress = False
            for rule in self._rules:
                if rule.name in fired:
                    continue
                if rule.satisfied(wm):
                    new = rule.conclusions - wm
                    # A satisfied rule that adds nothing new is a no-op:
                    # it is not marked fired and does not count as progress.
                    if new:
                        wm |= new
                        fired.add(rule.name)
                        self._trace.append(
                            f"cycle {cycle}: {rule.name} => {_fmt_facts(new)}"
                        )
                        progress = True
                        break  # restart from highest-priority rule
            if not progress:
                break  # quiescence: no rule can add anything further
        return wm

    def query(self, wm: Set[Fact], pattern: Tuple[Optional[str], ...]) -> List[Fact]:
        """Return facts matching *pattern* (None = wildcard)."""
        return [
            f for f in wm
            if len(f) == len(pattern)
            and all(p is None or p == v for p, v in zip(pattern, f))
        ]

    @property
    def trace(self) -> List[str]:
        # Defensive copy so callers cannot mutate the internal firing log.
        return list(self._trace)

    # -- serialisation helpers -----------------------------------------------

    @classmethod
    def from_dicts(cls, raw_rules: List[Dict], **kw) -> "SymbolicReasoner":
        """Build a reasoner from rule dicts (schema handled by ``_parse_rule``)."""
        rules = [_parse_rule(r) for r in raw_rules]
        return cls(rules, **kw)

    @classmethod
    def from_file(cls, path: Path, **kw) -> "SymbolicReasoner":
        """Load rules from a YAML or JSON file with a top-level ``rules`` list.

        Raises RuntimeError if a .yaml/.yml path is given but PyYAML is
        not installed (the module import falls back to ``yaml = None``).
        """
        text = path.read_text()
        if path.suffix in (".yaml", ".yml"):
            if yaml is None:
                raise RuntimeError("PyYAML required for .yaml rules")
            data = yaml.safe_load(text)
        else:
            data = json.loads(text)
        return cls.from_dicts(data["rules"], **kw)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def _parse_fact(raw: list | str) -> Fact:
|
||||||
|
if isinstance(raw, str):
|
||||||
|
return tuple(raw.split())
|
||||||
|
return tuple(str(x) for x in raw)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_rule(d: Dict) -> Rule:
    """Build a Rule from a dict with keys name / if / unless / then / priority."""
    conditions = frozenset(_parse_fact(c) for c in d.get("if", []))
    negations = frozenset(_parse_fact(c) for c in d.get("unless", []))
    conclusions = frozenset(_parse_fact(c) for c in d.get("then", []))
    return Rule(
        name=d["name"],
        conditions=conditions,
        negations=negations,
        conclusions=conclusions,
        priority=d.get("priority", 0),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_facts(facts: Set[Fact]) -> str:
|
||||||
|
return ", ".join(" ".join(f) for f in sorted(facts))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Built-in fleet rules (loaded when no --rules file is given)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rule-dict schema (consumed by _parse_rule):
#   name      - unique rule identifier (used for fire-once tracking)
#   if        - facts that must all be present
#   unless    - facts that must all be absent
#   then      - facts added when the rule fires
#   priority  - higher fires first (default 0)
DEFAULT_FLEET_RULES: List[Dict] = [
    {
        "name": "route-ci-to-timmy",
        "if": [["task", "category", "ci"]],
        "then": [["assign", "timmy"], ["reason", "timmy", "best-ci-merge-rate"]],
        "priority": 10,
    },
    {
        "name": "route-security-to-timmy",
        "if": [["task", "category", "security"]],
        "then": [["assign", "timmy"], ["reason", "timmy", "security-specialist"]],
        "priority": 10,
    },
    {
        # Lower priority than the timmy routes, and explicitly yields to a
        # prior timmy assignment via "unless".
        "name": "route-architecture-to-gemini",
        "if": [["task", "category", "architecture"]],
        "unless": [["assign", "timmy"]],
        "then": [["assign", "gemini"], ["reason", "gemini", "architecture-strength"]],
        "priority": 8,
    },
    {
        "name": "route-review-to-allegro",
        "if": [["task", "category", "review"]],
        "then": [["assign", "allegro"], ["reason", "allegro", "highest-quality-per-pr"]],
        "priority": 9,
    },
    {
        "name": "route-frontend-to-claude",
        "if": [["task", "category", "frontend"]],
        "unless": [["task", "repo", "fleet-ops"]],
        "then": [["assign", "claude"], ["reason", "claude", "high-volume-frontend"]],
        "priority": 5,
    },
    # Policy rules outrank all routing rules (priority 20 / 15).
    {
        "name": "block-merge-without-review",
        "if": [["pr", "status", "open"], ["pr", "reviews", "0"]],
        "then": [["pr", "action", "block-merge"], ["reason", "policy", "no-unreviewed-merges"]],
        "priority": 20,
    },
    {
        "name": "block-merge-ci-failing",
        "if": [["pr", "status", "open"], ["pr", "ci", "failing"]],
        "then": [["pr", "action", "block-merge"], ["reason", "policy", "ci-must-pass"]],
        "priority": 20,
    },
    {
        "name": "auto-label-hotfix",
        "if": [["pr", "title-prefix", "hotfix"]],
        "then": [["pr", "label", "hotfix"], ["pr", "priority", "urgent"]],
        "priority": 15,
    },
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Self-test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def _self_test() -> bool:
    """Verify core behaviour; returns True on success."""

    def fresh_engine() -> SymbolicReasoner:
        # Each scenario needs its own engine: fired-rule state is per-engine.
        return SymbolicReasoner.from_dicts(DEFAULT_FLEET_RULES)

    engine = fresh_engine()

    # Scenario 1: CI task should route to Timmy
    wm = engine.infer({("task", "category", "ci")})
    assert ("assign", "timmy") in wm, f"expected timmy assignment, got {wm}"

    # Scenario 2: architecture task routes to gemini (not timmy)
    wm_arch = fresh_engine().infer({("task", "category", "architecture")})
    assert ("assign", "gemini") in wm_arch, f"expected gemini assignment, got {wm_arch}"

    # Scenario 3: open PR with no reviews should block merge
    wm_pr = fresh_engine().infer({("pr", "status", "open"), ("pr", "reviews", "0")})
    assert ("pr", "action", "block-merge") in wm_pr

    # Scenario 4: negation — frontend + fleet-ops should NOT assign claude
    wm_neg = fresh_engine().infer(
        {("task", "category", "frontend"), ("task", "repo", "fleet-ops")}
    )
    assert ("assign", "claude") not in wm_neg

    # Scenario 5: query with wildcards
    matches = engine.query(wm, ("reason", None, None))
    assert len(matches) > 0

    print("All 5 self-test scenarios passed.")
    for line in engine.trace:
        print(f"  {line}")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def main():
    """CLI entry point: load rules and facts, run inference, print results.

    Exits via sys.exit for --self-test; otherwise prints the final working
    memory (and trace) either as text or, with --json, as a JSON document.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--rules", type=Path, help="YAML/JSON rule file")
    ap.add_argument("--facts", type=Path, help="YAML/JSON initial facts")
    ap.add_argument("--self-test", action="store_true")
    ap.add_argument("--json", action="store_true", help="output as JSON")
    args = ap.parse_args()

    if args.self_test:
        ok = _self_test()
        sys.exit(0 if ok else 1)

    if args.rules:
        engine = SymbolicReasoner.from_file(args.rules)
    else:
        engine = SymbolicReasoner.from_dicts(DEFAULT_FLEET_RULES)

    if args.facts:
        text = args.facts.read_text()
        if args.facts.suffix in (".yaml", ".yml"):
            # BUG FIX: previously dereferenced `yaml` unconditionally here and
            # crashed with AttributeError when PyYAML is not installed; mirror
            # the guard used in SymbolicReasoner.from_file.
            if yaml is None:
                raise RuntimeError("PyYAML required for .yaml facts")
            raw = yaml.safe_load(text)
        else:
            raw = json.loads(text)
        initial = {_parse_fact(f) for f in raw.get("facts", [])}
    else:
        initial = set()
        print("No --facts provided; running with empty working memory.")

    wm = engine.infer(initial)

    if args.json:
        print(json.dumps({"facts": [list(f) for f in sorted(wm)], "trace": engine.trace}, indent=2))
    else:
        print(f"Final working memory ({len(wm)} facts):")
        for f in sorted(wm):
            print(f"  {' '.join(f)}")
        if engine.trace:
            print(f"\nInference trace ({len(engine.trace)} firings):")
            for line in engine.trace:
                print(f"  {line}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -10,9 +10,14 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import subprocess
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if SCRIPT_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, SCRIPT_DIR)
|
||||||
|
|
||||||
|
from ssh_trust import VerifiedSSHExecutor
|
||||||
|
|
||||||
# --- CONFIGURATION ---
|
# --- CONFIGURATION ---
|
||||||
FLEET = {
|
FLEET = {
|
||||||
"mac": "10.1.10.77",
|
"mac": "10.1.10.77",
|
||||||
@@ -23,7 +28,8 @@ FLEET = {
|
|||||||
TELEMETRY_FILE = "logs/telemetry.json"
|
TELEMETRY_FILE = "logs/telemetry.json"
|
||||||
|
|
||||||
class Telemetry:
|
class Telemetry:
|
||||||
def __init__(self):
|
def __init__(self, executor=None):
|
||||||
|
self.executor = executor or VerifiedSSHExecutor()
|
||||||
# Find logs relative to repo root
|
# Find logs relative to repo root
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
repo_root = os.path.dirname(script_dir)
|
repo_root = os.path.dirname(script_dir)
|
||||||
@@ -41,14 +47,12 @@ class Telemetry:
|
|||||||
# Command to get disk usage, memory usage (%), and load avg
|
# Command to get disk usage, memory usage (%), and load avg
|
||||||
cmd = "df -h / | tail -1 | awk '{print $5}' && free -m | grep Mem | awk '{print $3/$2 * 100}' && uptime | awk '{print $10}'"
|
cmd = "df -h / | tail -1 | awk '{print $5}' && free -m | grep Mem | awk '{print $3/$2 * 100}' && uptime | awk '{print $10}'"
|
||||||
|
|
||||||
ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", f"root@{ip}", cmd]
|
if host == 'mac':
|
||||||
if host == "mac":
|
|
||||||
# Mac specific commands
|
# Mac specific commands
|
||||||
cmd = "df -h / | tail -1 | awk '{print $5}' && sysctl -n vm.page_pageable_internal_count && uptime | awk '{print $10}'"
|
cmd = "df -h / | tail -1 | awk '{print $5}' && sysctl -n vm.page_pageable_internal_count && uptime | awk '{print $10}'"
|
||||||
ssh_cmd = ["bash", "-c", cmd]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = subprocess.run(ssh_cmd, capture_output=True, text=True, timeout=10)
|
res = self.executor.run_script(ip, cmd, local=(host == 'mac'), timeout=10)
|
||||||
if res.returncode == 0:
|
if res.returncode == 0:
|
||||||
lines = res.stdout.strip().split("\n")
|
lines = res.stdout.strip().split("\n")
|
||||||
return {
|
return {
|
||||||
|
|||||||
307
scripts/temporal_reasoner.py
Normal file
307
scripts/temporal_reasoner.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""temporal_reasoner.py - GOFAI temporal reasoning engine for the Timmy Foundation fleet.
|
||||||
|
|
||||||
|
A symbolic temporal constraint network (TCN) for scheduling and ordering events.
|
||||||
|
Models Allen's interval algebra relations (before, after, meets, overlaps, etc.)
|
||||||
|
and propagates temporal constraints via path-consistency to detect conflicts.
|
||||||
|
No ML, no embeddings - just constraint propagation over a temporal graph.
|
||||||
|
|
||||||
|
Core concepts:
|
||||||
|
TimePoint: A named instant on a symbolic timeline.
|
||||||
|
Interval: A pair of time-points (start, end) with start < end.
|
||||||
|
Constraint: A relation between two time-points or intervals
|
||||||
|
(e.g. A.before(B), A.meets(B)).
|
||||||
|
|
||||||
|
Usage (Python API):
|
||||||
|
from temporal_reasoner import TemporalNetwork, Interval
|
||||||
|
tn = TemporalNetwork()
|
||||||
|
deploy = tn.add_interval('deploy', duration=(10, 30))
|
||||||
|
test = tn.add_interval('test', duration=(5, 15))
|
||||||
|
tn.add_constraint(deploy, 'before', test)
|
||||||
|
consistent = tn.propagate()
|
||||||
|
|
||||||
|
CLI:
|
||||||
|
python temporal_reasoner.py --demo
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
INF = float('inf')
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TimePoint:
    """A named instant on the timeline."""

    name: str
    # NOTE(review): TemporalNetwork tracks points by raw matrix index and
    # never constructs TimePoint objects — confirm this class is wired in
    # before relying on `id`.
    id: int = field(default=0)

    def __str__(self) -> str:
        return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Interval:
    """A named interval bounded by two time-points."""

    name: str
    start: int  # index into the distance matrix (interval's start point)
    end: int    # index into the distance matrix (interval's end point)

    def __str__(self) -> str:
        return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class Relation(Enum):
    """Allen's interval algebra relations (simplified subset).

    NOTE(review): OVERLAPS and DURING are declared here but are not
    handled by TemporalNetwork.add_constraint, which raises ValueError
    for them — confirm whether they are planned or should be removed.
    """
    BEFORE = 'before'
    AFTER = 'after'
    MEETS = 'meets'
    MET_BY = 'met_by'
    OVERLAPS = 'overlaps'
    DURING = 'during'
    EQUALS = 'equals'
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Simple Temporal Network (STN) via distance matrix
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalNetwork:
    """Simple Temporal Network with Floyd-Warshall propagation.

    Internally maintains a distance matrix D where D[i][j] is the
    maximum allowed distance from time-point i to time-point j.
    Negative cycles indicate inconsistency.
    """

    def __init__(self) -> None:
        self._n = 0                                # number of time-points
        self._names: List[str] = []                # point index -> name
        self._dist: List[List[float]] = []         # D[i][j] = max(t_j - t_i)
        self._intervals: Dict[str, Interval] = {}  # interval name -> Interval
        self._origin_idx: int = -1
        # Point 0 is the shared 'origin' that anchors absolute times.
        self._add_point('origin')
        self._origin_idx = 0

    # ------------------------------------------------------------------
    # Point management
    # ------------------------------------------------------------------

    def _add_point(self, name: str) -> int:
        """Add a time-point and return its index."""
        idx = self._n
        self._n += 1
        self._names.append(name)
        # Extend distance matrix with an unconstrained (INF) row/column.
        for row in self._dist:
            row.append(INF)
        self._dist.append([INF] * self._n)
        self._dist[idx][idx] = 0.0  # zero self-distance
        return idx

    # ------------------------------------------------------------------
    # Interval management
    # ------------------------------------------------------------------

    def add_interval(
        self,
        name: str,
        duration: Optional[Tuple[float, float]] = None,
    ) -> Interval:
        """Add a named interval with optional duration bounds [lo, hi].

        Returns the Interval object with start/end indices.
        """
        s = self._add_point(f"{name}.start")
        e = self._add_point(f"{name}.end")
        # start < end (at least 1 time unit when no duration is given):
        # D[s][e] caps the maximum duration; D[e][s] = -lo enforces the minimum.
        self._dist[s][e] = min(self._dist[s][e], duration[1] if duration else INF)
        self._dist[e][s] = min(self._dist[e][s], -(duration[0] if duration else 1))
        interval = Interval(name=name, start=s, end=e)
        self._intervals[name] = interval
        return interval

    # ------------------------------------------------------------------
    # Constraint management
    # ------------------------------------------------------------------

    def add_distance_constraint(
        self, i: int, j: int, lo: float, hi: float
    ) -> None:
        """Add constraint: lo <= t_j - t_i <= hi."""
        # Constraints only ever tighten (min with the existing bound).
        self._dist[i][j] = min(self._dist[i][j], hi)
        self._dist[j][i] = min(self._dist[j][i], -lo)

    def add_constraint(
        self, a: Interval, relation: str, b: Interval, gap: Tuple[float, float] = (0, INF)
    ) -> None:
        """Add an Allen-style relation between two intervals.

        Supported relations: before, after, meets, met_by, equals.
        Raises ValueError for anything else (including the OVERLAPS and
        DURING members of Relation, which are not implemented here).
        """
        rel = relation.lower()
        if rel == 'before':
            # a.end + gap <= b.start
            self.add_distance_constraint(a.end, b.start, gap[0], gap[1])
        elif rel == 'after':
            self.add_distance_constraint(b.end, a.start, gap[0], gap[1])
        elif rel == 'meets':
            # a.end == b.start
            self.add_distance_constraint(a.end, b.start, 0, 0)
        elif rel == 'met_by':
            self.add_distance_constraint(b.end, a.start, 0, 0)
        elif rel == 'equals':
            self.add_distance_constraint(a.start, b.start, 0, 0)
            self.add_distance_constraint(a.end, b.end, 0, 0)
        else:
            raise ValueError(f"Unsupported relation: {relation}")

    # ------------------------------------------------------------------
    # Propagation (Floyd-Warshall)
    # ------------------------------------------------------------------

    def propagate(self) -> bool:
        """Run Floyd-Warshall to propagate all constraints.

        Returns True if the network is consistent (no negative cycles).
        O(n^3) in the number of time-points; tightens the distance
        matrix in place.
        """
        n = self._n
        d = self._dist
        for k in range(n):
            for i in range(n):
                for j in range(n):
                    if d[i][k] + d[k][j] < d[i][j]:
                        d[i][j] = d[i][k] + d[k][j]
        # Check for negative cycles (a point strictly before itself).
        for i in range(n):
            if d[i][i] < 0:
                return False
        return True

    def is_consistent(self) -> bool:
        """Check consistency without mutating (copies matrix first)."""
        import copy
        saved = copy.deepcopy(self._dist)
        result = self.propagate()
        self._dist = saved  # restore the pre-propagation bounds
        return result

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def earliest(self, point_idx: int) -> float:
        """Earliest possible time for a point (relative to origin)."""
        # D[p][origin] bounds t_origin - t_p from above, so -D[p][origin]
        # bounds t_p from below.
        return -self._dist[point_idx][self._origin_idx]

    def latest(self, point_idx: int) -> float:
        """Latest possible time for a point (relative to origin)."""
        return self._dist[self._origin_idx][point_idx]

    def interval_bounds(self, interval: Interval) -> Dict[str, Tuple[float, float]]:
        """Return earliest/latest start and end for an interval."""
        return {
            'start': (self.earliest(interval.start), self.latest(interval.start)),
            'end': (self.earliest(interval.end), self.latest(interval.end)),
        }

    # ------------------------------------------------------------------
    # Display
    # ------------------------------------------------------------------

    def dump(self) -> None:
        """Print the current distance matrix and interval bounds."""
        print(f"Temporal Network — {self._n} time-points, {len(self._intervals)} intervals")
        print()
        for name, interval in self._intervals.items():
            bounds = self.interval_bounds(interval)
            s_lo, s_hi = bounds['start']
            e_lo, e_hi = bounds['end']
            print(f"  {name}:")
            print(f"    start: [{s_lo:.1f}, {s_hi:.1f}]")
            print(f"    end: [{e_lo:.1f}, {e_hi:.1f}]")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Demo: Timmy fleet deployment pipeline
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def run_demo() -> None:
    """Run a demo temporal reasoning scenario for the Timmy fleet."""
    banner = "=" * 60
    print(banner)
    print("Temporal Reasoner Demo - Fleet Deployment Pipeline")
    print(banner)
    print()

    tn = TemporalNetwork()

    # Pipeline stages with [min, max] duration bounds; dict order fixes
    # the point-index order in the network.
    durations = {
        'build': (5, 15),
        'test': (10, 30),
        'review': (2, 10),
        'deploy': (1, 5),
        'monitor': (20, 60),
    }
    stages = {name: tn.add_interval(name, duration=span) for name, span in durations.items()}
    build, test, review = stages['build'], stages['test'], stages['review']
    deploy, monitor = stages['deploy'], stages['monitor']

    # Ordering constraints between stages.
    tn.add_constraint(build, 'meets', test)                   # test starts when build ends
    tn.add_constraint(test, 'before', review, gap=(0, 5))     # review within 5 of test
    tn.add_constraint(review, 'meets', deploy)                # deploy immediately after review
    tn.add_constraint(deploy, 'before', monitor, gap=(0, 2))  # monitor within 2 of deploy

    # Global deadline: everything done within 120 time units.
    tn.add_distance_constraint(tn._origin_idx, monitor.end, 0, 120)
    # Build must start within the first 10 units.
    tn.add_distance_constraint(tn._origin_idx, build.start, 0, 10)

    print("Constraints added. Propagating...")
    consistent = tn.propagate()
    print(f"Network consistent: {consistent}")
    print()

    if consistent:
        tn.dump()
        print()

    # Show inconsistency detection by adding a contradictory ordering.
    print("--- Adding conflicting constraint: monitor.before(build) ---")
    tn.add_constraint(monitor, 'before', build)
    consistent2 = tn.propagate()
    print(f"Network consistent after conflict: {consistent2}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse CLI arguments and dispatch; the demo is the default action."""
    parser = argparse.ArgumentParser(description="GOFAI temporal reasoning engine")
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Run the fleet deployment pipeline demo",
    )
    args = parser.parse_args()

    run_it = args.demo or not any(vars(args).values())
    # With --demo as the only flag, the demo also runs when no flags are
    # given, so the help branch is currently unreachable; kept for when
    # more subcommands/flags are added.
    if run_it:
        run_demo()
    else:
        parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
62
tests/test_self_healing.py
Normal file
62
tests/test_self_healing.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Tests for scripts/self_healing.py safe CLI behavior."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).parent.parent
|
||||||
|
spec = importlib.util.spec_from_file_location("self_healing", REPO_ROOT / "scripts" / "self_healing.py")
|
||||||
|
sh = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(sh)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMainCli:
    """CLI safety contract: help paths never construct a SelfHealer, and
    the default invocation stays in dry-run mode."""

    def _stub_healer(self, monkeypatch):
        """Swap sh.SelfHealer for a mock class; return (class, instance)."""
        instance = MagicMock()
        healer_class = MagicMock(return_value=instance)
        monkeypatch.setattr(sh, "SelfHealer", healer_class)
        return healer_class, instance

    def test_help_exits_without_running_healer(self, monkeypatch, capsys):
        healer_class, _ = self._stub_healer(monkeypatch)

        with pytest.raises(SystemExit) as exc:
            sh.main(["--help"])

        assert exc.value.code == 0
        healer_class.assert_not_called()
        captured = capsys.readouterr().out
        assert "--execute" in captured
        assert "--help-safe" in captured

    def test_help_safe_exits_without_running_healer(self, monkeypatch, capsys):
        healer_class, _ = self._stub_healer(monkeypatch)

        with pytest.raises(SystemExit) as exc:
            sh.main(["--help-safe"])

        assert exc.value.code == 0
        healer_class.assert_not_called()
        captured = capsys.readouterr().out
        assert "DRY-RUN" in captured
        assert "--confirm-kill" in captured

    def test_default_invocation_runs_in_dry_run_mode(self, monkeypatch):
        healer_class, healer = self._stub_healer(monkeypatch)

        sh.main([])

        healer_class.assert_called_once_with(dry_run=True, confirm_kill=False, yes=False)
        healer.run.assert_called_once_with()

    def test_execute_flag_disables_dry_run(self, monkeypatch):
        healer_class, healer = self._stub_healer(monkeypatch)

        sh.main(["--execute", "--yes", "--confirm-kill"])

        healer_class.assert_called_once_with(dry_run=False, confirm_kill=True, yes=True)
        healer.run.assert_called_once_with()
|
||||||
93
tests/test_ssh_trust.py
Normal file
93
tests/test_ssh_trust.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Tests for scripts/ssh_trust.py verified SSH trust helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).parent.parent
|
||||||
|
spec = importlib.util.spec_from_file_location("ssh_trust", REPO_ROOT / "scripts" / "ssh_trust.py")
|
||||||
|
ssh_trust = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(ssh_trust)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enroll_host_key_writes_scanned_key(tmp_path):
    """enroll_host_key should run ssh-keyscan once and persist its output."""
    recorded = []
    known_hosts = tmp_path / "known_hosts"
    scanned_line = "example.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey\n"

    def fake_run(argv, capture_output, text, timeout):
        recorded.append(argv)
        return subprocess.CompletedProcess(argv, 0, stdout=scanned_line, stderr="")

    result = ssh_trust.enroll_host_key(
        "example.com",
        port=2222,
        known_hosts_path=known_hosts,
        runner=fake_run,
    )

    assert result == known_hosts
    assert known_hosts.read_text() == scanned_line
    assert recorded == [["ssh-keyscan", "-p", "2222", "-H", "example.com"]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_requires_known_hosts_or_auto_enroll(tmp_path):
    """plan() must refuse unknown hosts when auto-enrollment is disabled."""
    missing = tmp_path / "known_hosts"  # never created on disk
    executor = ssh_trust.VerifiedSSHExecutor(known_hosts_path=missing, auto_enroll=False)

    with pytest.raises(ssh_trust.HostKeyEnrollmentError):
        executor.plan("203.0.113.10", ["echo", "ok"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_command_is_quoted_and_local_execution_stays_shell_free(tmp_path):
    """Remote argv must be shell-quoted and pinned to the known_hosts file;
    local plans must remain plain argv lists with no remote command."""
    known_hosts = tmp_path / "known_hosts"
    known_hosts.write_text("203.0.113.10 ssh-ed25519 AAAAC3NzaTest\n")
    executor = ssh_trust.VerifiedSSHExecutor(known_hosts_path=known_hosts)

    command = ["python3", "run_agent.py", "--task", "hello 'quoted' world"]
    plan = executor.plan("203.0.113.10", command, port=2222)
    quoted = shlex.join(command)

    assert plan.local is False
    assert plan.remote_command == quoted
    assert plan.argv[-1] == quoted
    assert plan.argv[-2] == "root@203.0.113.10"
    assert "StrictHostKeyChecking=yes" in plan.argv
    assert f"UserKnownHostsFile={known_hosts}" in plan.argv

    local_plan = executor.plan("127.0.0.1", ["python3", "-V"], local=True)
    assert local_plan.local is True
    assert local_plan.argv == ["python3", "-V"]
    assert local_plan.remote_command is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_raises_host_key_verification_error(tmp_path):
    """run() should surface ssh's host-key failure as a typed exception."""
    known_hosts = tmp_path / "known_hosts"
    known_hosts.write_text("203.0.113.10 ssh-ed25519 AAAAC3NzaTest\n")

    def failing_runner(argv, capture_output, text, timeout):
        # Simulate ssh exiting 255 with the canonical verification failure.
        return subprocess.CompletedProcess(
            argv,
            255,
            stdout="",
            stderr="Host key verification failed.\n",
        )

    executor = ssh_trust.VerifiedSSHExecutor(
        known_hosts_path=known_hosts,
        runner=failing_runner,
    )

    with pytest.raises(ssh_trust.HostKeyVerificationError):
        executor.run("203.0.113.10", ["true"])
|
||||||
Reference in New Issue
Block a user