Compare commits
71 Commits
fix/format
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4214082fb6 | ||
|
|
ac28444bf2 | ||
|
|
91faf6f956 | ||
| a2a40429bd | |||
| ee61c5fa9d | |||
|
|
1fece10569 | ||
| 46668505bc | |||
| cac0c8224e | |||
| f38a64455d | |||
| 1b35a5a0d2 | |||
| 9172131b25 | |||
| 407eab3331 | |||
| cf090a966d | |||
| b65be9b12c | |||
| 3c1cff255e | |||
| 690d100afc | |||
| c6f0831738 | |||
| 30773ac1f9 | |||
| feb24bd08c | |||
| bc55f40505 | |||
| 2adc72335e | |||
| ab32670464 | |||
| bfc0231297 | |||
| cf2b09cf2f | |||
| 719bb537c0 | |||
| 0bcbcf19ac | |||
| 27d2f2ca0e | |||
| 7e7dcfa345 | |||
| ba0e614446 | |||
| 4f5e641c92 | |||
| d61bd141f9 | |||
| a4058af238 | |||
| 08432a5618 | |||
| a875c6ed91 | |||
| 07c5b5b83d | |||
| ba56567631 | |||
| 8ac26f54a5 | |||
| b807972d05 | |||
| 6b5a6db668 | |||
| b702249c12 | |||
|
|
8023c9b8f2 | ||
| 6eeee39c10 | |||
| b2d2d2c650 | |||
| bdd0f2709b | |||
| a9cbf7d69f | |||
| 4cdda8701d | |||
| a80d30b342 | |||
| f098cf8c4a | |||
| 30509b9c7c | |||
| ccaa1cb021 | |||
| c6f2855745 | |||
| 9d180f31cc | |||
| c17f64fa2c | |||
| bc7ffc2166 | |||
|
|
c22cdcaa8e | ||
|
|
ab968e910c | ||
|
|
73984ca72f | ||
| 436c800def | |||
| cb331da4f1 | |||
| fa892bfcb9 | |||
|
|
0b72884750 | ||
| a0ed1e6ff2 | |||
| b5ba272efe | |||
| 2e0dfe27df | |||
| d4cdfdc604 | |||
| e3436e36c3 | |||
| 34e7de6a4c | |||
| dbabe0e6ae | |||
| 517e2c571e | |||
| 0b019327a3 | |||
| 6b0fca6944 |
28
.gitea/workflows/lint.yml
Normal file
28
.gitea/workflows/lint.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Check for hardcoded paths
|
||||
run: python3 scripts/lint_hardcoded_paths.py
|
||||
continue-on-error: true
|
||||
|
||||
- name: Check Python syntax
|
||||
run: |
|
||||
find . -name "*.py" -not -path "./.git/*" -not -path "./node_modules/*" | head -100 | xargs python3 -m py_compile || true
|
||||
78
.githooks/pre-commit-hardcoded-path.py
Normal file
78
.githooks/pre-commit-hardcoded-path.py
Normal file
@@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-commit hook: Reject hardcoded home-directory paths.
|
||||
|
||||
Install:
|
||||
cp pre-commit-hardcoded-path.py .git/hooks/pre-commit-hardcoded-path
|
||||
chmod +x .git/hooks/pre-commit-hardcoded-path
|
||||
|
||||
Or add to .pre-commit-config.yaml
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
import re
|
||||
|
||||
PATTERNS = [
|
||||
(r"/Users/[\w.\-]+/", "macOS home directory"),
|
||||
(r"/home/[\w.\-]+/", "Linux home directory"),
|
||||
(r"(?<![\w/])~/", "unexpanded tilde"),
|
||||
]
|
||||
|
||||
NOQA = re.compile(r"#\s*noqa:?\s*hardcoded-path-ok")
|
||||
|
||||
def get_staged_files():
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--cached", "--name-only", "--diff-filter=ACM"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return [f for f in result.stdout.strip().split("\n") if f.endswith(".py")]
|
||||
|
||||
def check_file(filepath):
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "show", f":{filepath}"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
content = result.stdout
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
violations = []
|
||||
for i, line in enumerate(content.split("\n"), 1):
|
||||
if line.strip().startswith("#"):
|
||||
continue
|
||||
if line.strip().startswith(("import ", "from ")):
|
||||
continue
|
||||
if NOQA.search(line):
|
||||
continue
|
||||
for pattern, desc in PATTERNS:
|
||||
if re.search(pattern, line):
|
||||
violations.append((filepath, i, line.strip(), desc))
|
||||
break
|
||||
return violations
|
||||
|
||||
def main():
|
||||
files = get_staged_files()
|
||||
if not files:
|
||||
sys.exit(0)
|
||||
|
||||
all_violations = []
|
||||
for f in files:
|
||||
all_violations.extend(check_file(f))
|
||||
|
||||
if all_violations:
|
||||
print("ERROR: Hardcoded home directory paths detected:")
|
||||
print()
|
||||
for filepath, line_no, line, desc in all_violations:
|
||||
print(f" {filepath}:{line_no}: {desc}")
|
||||
print(f" {line[:100]}")
|
||||
print()
|
||||
print("Fix: Use $HOME, relative paths, or get_hermes_home().")
|
||||
print("Override: Add '# noqa: hardcoded-path-ok' to the line.")
|
||||
sys.exit(1)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -25,6 +25,10 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
|
||||
- name: Check for hardcoded paths
|
||||
run: python3 scripts/lint_hardcoded_paths.py || true
|
||||
continue-on-error: true
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
||||
|
||||
443
agent/a2a_mtls.py
Normal file
443
agent/a2a_mtls.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
A2A mutual-TLS server — secure agent-to-agent communication.
|
||||
|
||||
Each fleet agent runs an A2A server that:
|
||||
- Presents its own TLS certificate (signed by the fleet CA).
|
||||
- Requires the connecting peer to present a valid client certificate
|
||||
also signed by the fleet CA.
|
||||
- Rejects connections from unknown / self-signed peers.
|
||||
|
||||
Usage (standalone):
|
||||
python -m agent.a2a_mtls \\
|
||||
--cert ~/.hermes/pki/agents/timmy/timmy.crt \\
|
||||
--key ~/.hermes/pki/agents/timmy/timmy.key \\
|
||||
--ca ~/.hermes/pki/ca/fleet-ca.crt \\
|
||||
--host 0.0.0.0 --port 9443
|
||||
|
||||
Environment variables (alternative to CLI flags):
|
||||
HERMES_A2A_CERT path to agent certificate
|
||||
HERMES_A2A_KEY path to agent private key
|
||||
HERMES_A2A_CA path to fleet CA certificate
|
||||
|
||||
Refs #806
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mTLS SSL context helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_server_ssl_context(
|
||||
cert: str | Path,
|
||||
key: str | Path,
|
||||
ca: str | Path,
|
||||
) -> ssl.SSLContext:
|
||||
"""Return an SSLContext that presents *cert/key* and requires a valid
|
||||
client certificate signed by *ca*.
|
||||
|
||||
Raises ``FileNotFoundError`` if any path is missing.
|
||||
Raises ``ssl.SSLError`` if the files are malformed.
|
||||
"""
|
||||
cert, key, ca = Path(cert), Path(key), Path(ca)
|
||||
for p in (cert, key, ca):
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"mTLS: file not found: {p}")
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=str(cert), keyfile=str(key))
|
||||
ctx.load_verify_locations(cafile=str(ca))
|
||||
# CERT_REQUIRED — reject peers that don't present a cert signed by *ca*.
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
return ctx
|
||||
|
||||
|
||||
def build_client_ssl_context(
|
||||
cert: str | Path,
|
||||
key: str | Path,
|
||||
ca: str | Path,
|
||||
) -> ssl.SSLContext:
|
||||
"""Return an SSLContext for an outgoing mTLS connection.
|
||||
|
||||
Presents *cert/key* as the client identity and verifies the server
|
||||
certificate against *ca*.
|
||||
"""
|
||||
cert, key, ca = Path(cert), Path(key), Path(ca)
|
||||
for p in (cert, key, ca):
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"mTLS client: file not found: {p}")
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=str(cert), keyfile=str(key))
|
||||
ctx.load_verify_locations(cafile=str(ca))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.check_hostname = True
|
||||
return ctx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal A2A HTTP request handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class A2AHandler(BaseHTTPRequestHandler):
|
||||
"""Handles A2A requests over a mutually-authenticated TLS connection.
|
||||
|
||||
GET /.well-known/agent-card.json — returns the local agent card.
|
||||
POST /a2a/task — dispatches an A2A task (stub).
|
||||
"""
|
||||
|
||||
log_message = logger.debug # route access log to Python logger
|
||||
|
||||
def do_GET(self) -> None: # noqa: N802
|
||||
if self.path in ("/.well-known/agent-card.json", "/agent-card.json"):
|
||||
self._serve_agent_card()
|
||||
else:
|
||||
self._send_json(404, {"error": "not found"})
|
||||
|
||||
def do_POST(self) -> None: # noqa: N802
|
||||
if self.path == "/a2a/task":
|
||||
self._handle_task()
|
||||
else:
|
||||
self._send_json(404, {"error": "not found"})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def _serve_agent_card(self) -> None:
|
||||
try:
|
||||
from agent.agent_card import get_agent_card_json
|
||||
body = get_agent_card_json().encode()
|
||||
except Exception as exc:
|
||||
logger.warning("agent-card unavailable: %s", exc)
|
||||
body = b'{"error": "agent card unavailable"}'
|
||||
self._send_raw(200, "application/json", body)
|
||||
|
||||
def _handle_task(self) -> None:
|
||||
length = int(self.headers.get("Content-Length", 0))
|
||||
_body = self.rfile.read(length) if length else b""
|
||||
# Stub: echo back a 202 Accepted with the peer CN so callers can
|
||||
# confirm which agent processed the request.
|
||||
peer_cn = _peer_cn(self.connection)
|
||||
self._send_json(202, {"status": "accepted", "handled_by": peer_cn})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def _send_json(self, code: int, data: dict) -> None:
|
||||
import json
|
||||
body = json.dumps(data).encode()
|
||||
self._send_raw(code, "application/json", body)
|
||||
|
||||
def _send_raw(self, code: int, content_type: str, body: bytes) -> None:
|
||||
self.send_response(code)
|
||||
self.send_header("Content-Type", content_type)
|
||||
self.send_header("Content-Length", str(len(body)))
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def log_message(self, fmt: str, *args: object) -> None: # type: ignore[override]
|
||||
logger.debug("a2a: " + fmt, *args)
|
||||
|
||||
|
||||
def _peer_cn(conn: ssl.SSLSocket) -> Optional[str]:
|
||||
"""Extract the Common Name from the peer certificate, or None."""
|
||||
try:
|
||||
peer = conn.getpeercert()
|
||||
if not peer:
|
||||
return None
|
||||
for rdn in peer.get("subject", ()):
|
||||
for key, val in rdn:
|
||||
if key == "commonName":
|
||||
return val
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class A2AServer:
|
||||
"""Mutual-TLS A2A server.
|
||||
|
||||
Example::
|
||||
|
||||
server = A2AServer(
|
||||
cert="~/.hermes/pki/agents/timmy/timmy.crt",
|
||||
key="~/.hermes/pki/agents/timmy/timmy.key",
|
||||
ca="~/.hermes/pki/ca/fleet-ca.crt",
|
||||
)
|
||||
server.start() # non-blocking (daemon thread)
|
||||
...
|
||||
server.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cert: str | Path,
|
||||
key: str | Path,
|
||||
ca: str | Path,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 9443,
|
||||
) -> None:
|
||||
self.cert = Path(cert).expanduser()
|
||||
self.key = Path(key).expanduser()
|
||||
self.ca = Path(ca).expanduser()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._httpd: Optional[HTTPServer] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
def start(self, daemon: bool = True) -> None:
|
||||
"""Start the server in a background thread (default: daemon)."""
|
||||
ssl_ctx = build_server_ssl_context(self.cert, self.key, self.ca)
|
||||
self._httpd = HTTPServer((self.host, self.port), A2AHandler)
|
||||
self._httpd.socket = ssl_ctx.wrap_socket(
|
||||
self._httpd.socket, server_side=True
|
||||
)
|
||||
self._thread = threading.Thread(
|
||||
target=self._httpd.serve_forever, daemon=daemon
|
||||
)
|
||||
self._thread.start()
|
||||
logger.info(
|
||||
"A2A mTLS server listening on %s:%s (cert=%s)",
|
||||
self.host, self.port, self.cert.name,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._httpd:
|
||||
self._httpd.shutdown()
|
||||
self._httpd = None
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._thread = None
|
||||
|
||||
|
||||
def server_from_env() -> A2AServer:
|
||||
"""Build an A2AServer from environment variables / defaults."""
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
agent_name = os.environ.get("HERMES_AGENT_NAME", "hermes").lower()
|
||||
|
||||
default_cert = hermes_home / "pki" / "agents" / agent_name / f"{agent_name}.crt"
|
||||
default_key = hermes_home / "pki" / "agents" / agent_name / f"{agent_name}.key"
|
||||
default_ca = hermes_home / "pki" / "ca" / "fleet-ca.crt"
|
||||
|
||||
cert = os.environ.get("HERMES_A2A_CERT", str(default_cert))
|
||||
key = os.environ.get("HERMES_A2A_KEY", str(default_key))
|
||||
ca = os.environ.get("HERMES_A2A_CA", str(default_ca))
|
||||
host = os.environ.get("HERMES_A2A_HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("HERMES_A2A_PORT", "9443"))
|
||||
|
||||
return A2AServer(cert=cert, key=key, ca=ca, host=host, port=port)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _main() -> None:
|
||||
import argparse
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Hermes A2A mutual-TLS server"
|
||||
)
|
||||
parser.add_argument("--cert", required=True, help="Path to agent certificate")
|
||||
parser.add_argument("--key", required=True, help="Path to agent private key")
|
||||
parser.add_argument("--ca", required=True, help="Path to fleet CA certificate")
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=9443)
|
||||
args = parser.parse_args()
|
||||
|
||||
server = A2AServer(
|
||||
cert=args.cert, key=args.key, ca=args.ca,
|
||||
host=args.host, port=args.port,
|
||||
)
|
||||
server.start(daemon=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A2AMTLSServer — routing-based server with context-manager support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _RoutingHandler(BaseHTTPRequestHandler):
|
||||
"""HTTP request handler that dispatches to per-path callables."""
|
||||
|
||||
routes: Dict[str, Callable] = {}
|
||||
|
||||
def log_message(self, fmt: str, *args: Any) -> None:
|
||||
logger.debug("A2AMTLSServer: " + fmt, *args)
|
||||
|
||||
def _peer_cn(self) -> Optional[str]:
|
||||
cert = self.connection.getpeercert() # type: ignore[attr-defined]
|
||||
if not cert:
|
||||
return None
|
||||
for rdn in cert.get("subject", ()):
|
||||
for attr, value in rdn:
|
||||
if attr == "commonName":
|
||||
return value
|
||||
return None
|
||||
|
||||
def do_POST(self) -> None:
|
||||
handler = self.routes.get(self.path)
|
||||
if handler is None:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
length = int(self.headers.get("Content-Length", 0))
|
||||
body = self.rfile.read(length) if length else b""
|
||||
try:
|
||||
payload = json.loads(body) if body else {}
|
||||
except json.JSONDecodeError:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
return
|
||||
result = handler(payload, peer_cn=self._peer_cn())
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(result).encode())
|
||||
|
||||
def do_GET(self) -> None:
|
||||
handler = self.routes.get(self.path)
|
||||
if handler is None:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
result = handler({}, peer_cn=self._peer_cn())
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(result).encode())
|
||||
|
||||
|
||||
class A2AMTLSServer:
|
||||
"""Routing-based mTLS HTTPS server with context-manager support.
|
||||
|
||||
Unlike ``A2AServer`` (which serves fixed A2A paths), this server lets
|
||||
callers register arbitrary path handlers — useful for tests and custom
|
||||
A2A endpoint implementations.
|
||||
|
||||
handler signature: ``handler(payload: dict, *, peer_cn: str | None) -> dict``
|
||||
|
||||
Example::
|
||||
|
||||
server = A2AMTLSServer(cert="timmy.crt", key="timmy.key", ca="fleet-ca.crt")
|
||||
server.add_route("/tasks/send", my_handler)
|
||||
with server:
|
||||
... # server runs for the duration of the block
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cert: str | Path,
|
||||
key: str | Path,
|
||||
ca: str | Path,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9443,
|
||||
) -> None:
|
||||
self.cert = Path(cert).expanduser()
|
||||
self.key = Path(key).expanduser()
|
||||
self.ca = Path(ca).expanduser()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._routes: Dict[str, Callable] = {}
|
||||
self._httpd: Optional[HTTPServer] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
def add_route(self, path: str, handler: Callable) -> None:
|
||||
self._routes[path] = handler
|
||||
|
||||
def start(self) -> None:
|
||||
ssl_ctx = build_server_ssl_context(self.cert, self.key, self.ca)
|
||||
|
||||
class _Handler(_RoutingHandler):
|
||||
routes = self._routes
|
||||
|
||||
self._httpd = HTTPServer((self.host, self.port), _Handler)
|
||||
self._httpd.socket = ssl_ctx.wrap_socket(self._httpd.socket, server_side=True)
|
||||
self._thread = threading.Thread(
|
||||
target=self._httpd.serve_forever,
|
||||
daemon=True,
|
||||
name=f"a2a-mtls-{self.port}",
|
||||
)
|
||||
self._thread.start()
|
||||
logger.info("A2AMTLSServer on %s:%d (mTLS)", self.host, self.port)
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._httpd:
|
||||
self._httpd.shutdown()
|
||||
self._httpd = None
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
self._thread = None
|
||||
|
||||
def __enter__(self) -> "A2AMTLSServer":
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A2AMTLSClient — mTLS HTTP client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class A2AMTLSClient:
|
||||
"""HTTP client that presents a fleet cert on every outgoing connection.
|
||||
|
||||
Example::
|
||||
|
||||
client = A2AMTLSClient(cert="allegro.crt", key="allegro.key", ca="fleet-ca.crt")
|
||||
result = client.post("https://timmy:9443/tasks/send", json={"task": "..."})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cert: str | Path,
|
||||
key: str | Path,
|
||||
ca: str | Path,
|
||||
) -> None:
|
||||
self._ssl_ctx = build_client_ssl_context(cert, key, ca)
|
||||
self._ssl_ctx.check_hostname = False # callers connecting by IP
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Optional[bytes] = None,
|
||||
timeout: float = 10.0,
|
||||
) -> Dict[str, Any]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
req = Request(url, data=data, headers=headers, method=method)
|
||||
try:
|
||||
with urlopen(req, context=self._ssl_ctx, timeout=timeout) as resp:
|
||||
body = resp.read()
|
||||
return json.loads(body) if body else {}
|
||||
except URLError as exc:
|
||||
raise ConnectionError(f"A2AMTLSClient {method} {url} failed: {exc.reason}") from exc
|
||||
|
||||
def get(self, url: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self._request("GET", url, **kwargs)
|
||||
|
||||
def post(self, url: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Dict[str, Any]:
|
||||
data = (__import__("json").dumps(json).encode() if json is not None else None)
|
||||
return self._request("POST", url, data=data, **kwargs)
|
||||
273
agent/circuit_breaker.py
Normal file
273
agent/circuit_breaker.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Circuit Breaker for Error Cascading — #885
|
||||
|
||||
P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
|
||||
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
|
||||
opens and the agent must take corrective action.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, errors are counted
|
||||
- OPEN: Too many consecutive errors, corrective action required
|
||||
- HALF_OPEN: Testing if errors have cleared
|
||||
|
||||
Usage:
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker
|
||||
|
||||
cb = ToolCircuitBreaker()
|
||||
|
||||
# After each tool call
|
||||
if not cb.record_result(success=True):
|
||||
# Circuit is open — take corrective action
|
||||
cb.get_recovery_action()
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Too many errors, block execution
|
||||
HALF_OPEN = "half_open" # Testing recovery
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Generic circuit breaker with configurable thresholds.
|
||||
|
||||
Tracks consecutive errors and opens the circuit when the
|
||||
error streak exceeds the threshold.
|
||||
"""
|
||||
failure_threshold: int = 3
|
||||
recovery_timeout: float = 30.0 # seconds before trying half-open
|
||||
success_threshold: int = 2 # successes needed to close from half-open
|
||||
|
||||
state: CircuitState = field(default=CircuitState.CLOSED, init=False)
|
||||
consecutive_failures: int = field(default=0, init=False)
|
||||
consecutive_successes: int = field(default=0, init=False)
|
||||
last_failure_time: Optional[float] = field(default=None, init=False)
|
||||
total_trips: int = field(default=0, init=False)
|
||||
error_streaks: List[int] = field(default_factory=list, init=False)
|
||||
|
||||
def record_result(self, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if circuit allows execution.
|
||||
|
||||
Returns:
|
||||
True if circuit is CLOSED or HALF_OPEN (execution allowed)
|
||||
False if circuit is OPEN (execution blocked)
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
if self.state == CircuitState.OPEN:
|
||||
# Check if recovery timeout has passed
|
||||
if self.last_failure_time and (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True # Allow one test execution
|
||||
return False # Still open
|
||||
|
||||
if success:
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes += 1
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
if self.consecutive_successes >= self.success_threshold:
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_successes = 0
|
||||
|
||||
return True
|
||||
else:
|
||||
self.consecutive_successes = 0
|
||||
self.consecutive_failures += 1
|
||||
self.last_failure_time = now
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
# Failed during recovery — reopen immediately
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
return False
|
||||
|
||||
if self.consecutive_failures >= self.failure_threshold:
|
||||
self.state = CircuitState.OPEN
|
||||
self.total_trips += 1
|
||||
self.error_streaks.append(self.consecutive_failures)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def can_execute(self) -> bool:
|
||||
"""Check if execution is allowed."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self.last_failure_time:
|
||||
now = time.time()
|
||||
if (now - self.last_failure_time) >= self.recovery_timeout:
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.consecutive_successes = 0
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get current circuit state."""
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"consecutive_failures": self.consecutive_failures,
|
||||
"consecutive_successes": self.consecutive_successes,
|
||||
"total_trips": self.total_trips,
|
||||
"max_streak": max(self.error_streaks) if self.error_streaks else 0,
|
||||
"can_execute": self.can_execute(),
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the circuit breaker."""
|
||||
self.state = CircuitState.CLOSED
|
||||
self.consecutive_failures = 0
|
||||
self.consecutive_successes = 0
|
||||
self.last_failure_time = None
|
||||
|
||||
|
||||
class ToolCircuitBreaker(CircuitBreaker):
|
||||
"""
|
||||
Circuit breaker specifically for tool call error cascading.
|
||||
|
||||
Provides recovery actions when the circuit opens.
|
||||
"""
|
||||
|
||||
# Tools that are most effective at recovery (from audit data)
|
||||
RECOVERY_TOOLS = [
|
||||
"terminal", # Most effective — 2300 recoveries
|
||||
"read_file", # Reset context by reading something
|
||||
"search_files", # Find what went wrong
|
||||
]
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the recommended recovery action when circuit is open.
|
||||
|
||||
Returns dict with action type and details.
|
||||
"""
|
||||
streak = self.consecutive_failures
|
||||
|
||||
if streak >= 9:
|
||||
# After 9 errors: 41/46 recoveries via terminal
|
||||
return {
|
||||
"action": "terminal_only",
|
||||
"reason": f"Error streak of {streak} — terminal is the only reliable recovery",
|
||||
"suggested_tool": "terminal",
|
||||
"suggested_command": "echo 'Resetting context'",
|
||||
"severity": "critical",
|
||||
}
|
||||
elif streak >= 5:
|
||||
return {
|
||||
"action": "switch_tool_type",
|
||||
"reason": f"Error streak of {streak} — switch to a different tool category",
|
||||
"suggested_tools": ["read_file", "search_files", "terminal"],
|
||||
"severity": "high",
|
||||
}
|
||||
elif streak >= self.failure_threshold:
|
||||
return {
|
||||
"action": "ask_user",
|
||||
"reason": f"{streak} consecutive errors — ask user for guidance",
|
||||
"suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
|
||||
"severity": "medium",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Error streak of {streak} — within tolerance",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def should_compress_context(self) -> bool:
|
||||
"""Determine if context compression would help recovery."""
|
||||
return self.consecutive_failures >= 5
|
||||
|
||||
def get_blocked_tool(self) -> Optional[str]:
|
||||
"""Get the tool that should be blocked (if any)."""
|
||||
if self.state == CircuitState.OPEN:
|
||||
return "last_failed_tool"
|
||||
return None
|
||||
|
||||
|
||||
class MultiToolCircuitBreaker:
|
||||
"""
|
||||
Manages per-tool circuit breakers and cross-tool cascade detection.
|
||||
|
||||
When one tool trips its breaker, related tools are also warned.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.breakers: Dict[str, ToolCircuitBreaker] = {}
|
||||
self.global_streak: int = 0
|
||||
self.last_tool: Optional[str] = None
|
||||
self.last_success: bool = True
|
||||
|
||||
def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a tool."""
|
||||
if tool_name not in self.breakers:
|
||||
self.breakers[tool_name] = ToolCircuitBreaker()
|
||||
return self.breakers[tool_name]
|
||||
|
||||
def record_result(self, tool_name: str, success: bool) -> bool:
|
||||
"""
|
||||
Record a tool call result. Returns True if execution should continue.
|
||||
"""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
allowed = breaker.record_result(success)
|
||||
|
||||
# Track global streak
|
||||
if success:
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
else:
|
||||
self.global_streak += 1
|
||||
self.last_success = False
|
||||
|
||||
self.last_tool = tool_name
|
||||
return allowed
|
||||
|
||||
def can_execute(self, tool_name: str) -> bool:
|
||||
"""Check if a specific tool can execute."""
|
||||
breaker = self.get_breaker(tool_name)
|
||||
return breaker.can_execute()
|
||||
|
||||
def get_global_state(self) -> Dict[str, Any]:
|
||||
"""Get overall circuit breaker state."""
|
||||
return {
|
||||
"global_streak": self.global_streak,
|
||||
"last_tool": self.last_tool,
|
||||
"last_success": self.last_success,
|
||||
"tool_states": {
|
||||
name: breaker.get_state()
|
||||
for name, breaker in self.breakers.items()
|
||||
if breaker.consecutive_failures > 0 or breaker.total_trips > 0
|
||||
},
|
||||
"any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
|
||||
}
|
||||
|
||||
def get_recovery_action(self) -> Dict[str, Any]:
|
||||
"""Get recovery action based on global state."""
|
||||
if self.global_streak == 0:
|
||||
return {"action": "continue", "reason": "No errors"}
|
||||
|
||||
# Find the breaker with the worst streak
|
||||
worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
|
||||
if worst and worst.consecutive_failures > 0:
|
||||
return worst.get_recovery_action()
|
||||
|
||||
return {
|
||||
"action": "continue",
|
||||
"reason": f"Global streak: {self.global_streak}",
|
||||
"severity": "low",
|
||||
}
|
||||
|
||||
def reset_all(self):
|
||||
"""Reset all circuit breakers."""
|
||||
for breaker in self.breakers.values():
|
||||
breaker.reset()
|
||||
self.global_streak = 0
|
||||
self.last_success = True
|
||||
148
agent/context_budget.py
Normal file
148
agent/context_budget.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Context Budget Tracker - Prevent context window overflow
|
||||
|
||||
Poka-yoke: Visual warnings at 70%%, 85%%, 95%% capacity.
|
||||
Auto-checkpoint at 85%%. Pre-flight token estimation.
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
CHECKPOINT_DIR = HERMES_HOME / "checkpoints"
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
THRESHOLD_WARNING = 0.70
|
||||
THRESHOLD_CRITICAL = 0.85
|
||||
THRESHOLD_DANGER = 0.95
|
||||
|
||||
|
||||
class ContextBudget:
|
||||
def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
|
||||
used_tokens: int = 0, reserved_tokens: int = 2000):
|
||||
self.context_limit = context_limit
|
||||
self.system_tokens = system_tokens
|
||||
self.used_tokens = used_tokens
|
||||
self.reserved_tokens = reserved_tokens
|
||||
|
||||
@property
|
||||
def total_used(self) -> int:
|
||||
return self.system_tokens + self.used_tokens
|
||||
|
||||
@property
|
||||
def available(self) -> int:
|
||||
return max(0, self.context_limit - self.reserved_tokens)
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
return max(0, self.available - self.total_used)
|
||||
|
||||
@property
|
||||
def utilization(self) -> float:
|
||||
return self.total_used / self.available if self.available > 0 else 1.0
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
return len(text) // CHARS_PER_TOKEN if text else 0
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: List[Dict]) -> int:
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total += estimate_tokens(content)
|
||||
if msg.get("tool_calls"):
|
||||
total += 100
|
||||
return total
|
||||
|
||||
|
||||
class ContextBudgetTracker:
|
||||
def __init__(self, context_limit: int = 128000, session_id: str = ""):
|
||||
self.budget = ContextBudget(context_limit=context_limit)
|
||||
self.session_id = session_id
|
||||
self._checkpointed = False
|
||||
self._warnings_given = set()
|
||||
|
||||
def update_from_messages(self, messages: List[Dict]):
|
||||
self.budget.used_tokens = estimate_messages_tokens(messages)
|
||||
|
||||
def can_fit(self, additional_tokens: int) -> bool:
|
||||
return self.budget.remaining >= additional_tokens
|
||||
|
||||
def preflight_check(self, text: str) -> Tuple[bool, str]:
|
||||
tokens = estimate_tokens(text)
|
||||
if not self.can_fit(tokens):
|
||||
return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"
|
||||
would_util = (self.budget.total_used + tokens) / self.budget.available if self.budget.available > 0 else 1.0
|
||||
if would_util >= THRESHOLD_DANGER:
|
||||
return False, f"Would reach {would_util:.0%%} capacity. Summarize or start new session."
|
||||
if would_util >= THRESHOLD_CRITICAL:
|
||||
return True, f"Warning: will reach {would_util:.0%%} capacity."
|
||||
return True, ""
|
||||
|
||||
def get_warning(self) -> Optional[str]:
|
||||
util = self.budget.utilization
|
||||
if util >= THRESHOLD_DANGER and "danger" not in self._warnings_given:
|
||||
self._warnings_given.add("danger")
|
||||
return f"[CONTEXT CRITICAL: {util:.0%%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"
|
||||
if util >= THRESHOLD_CRITICAL and "critical" not in self._warnings_given:
|
||||
self._warnings_given.add("critical")
|
||||
self._auto_checkpoint()
|
||||
return f"[CONTEXT WARNING: {util:.0%%} used -- consider summarizing. Auto-checkpoint saved.]"
|
||||
if util >= THRESHOLD_WARNING and "warning" not in self._warnings_given:
|
||||
self._warnings_given.add("warning")
|
||||
return f"[CONTEXT: {util:.0%%} used -- {self.budget.remaining:,} tokens remaining]"
|
||||
return None
|
||||
|
||||
def _auto_checkpoint(self):
|
||||
if self._checkpointed or not self.session_id:
|
||||
return
|
||||
try:
|
||||
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = CHECKPOINT_DIR / f"{self.session_id}.json"
|
||||
path.write_text(json.dumps({
|
||||
"session_id": self.session_id,
|
||||
"timestamp": time.time(),
|
||||
"budget": {"utilization": round(self.budget.utilization * 100, 1)}
|
||||
}, indent=2))
|
||||
self._checkpointed = True
|
||||
logger.info("Auto-checkpoint saved: %s", path)
|
||||
except Exception as e:
|
||||
logger.error("Auto-checkpoint failed: %s", e)
|
||||
|
||||
def get_status_line(self) -> str:
|
||||
util = self.budget.utilization
|
||||
remaining = self.budget.remaining
|
||||
if util >= THRESHOLD_DANGER:
|
||||
return f"RED {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_CRITICAL:
|
||||
return f"ORANGE {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_WARNING:
|
||||
return f"YELLOW {util:.0%%} used ({remaining:,} left)"
|
||||
return f"GREEN {util:.0%%} used ({remaining:,} left)"
|
||||
|
||||
|
||||
_tracker = None
|
||||
|
||||
def get_tracker(context_limit=128000, session_id=""):
|
||||
global _tracker
|
||||
if _tracker is None:
|
||||
_tracker = ContextBudgetTracker(context_limit, session_id)
|
||||
return _tracker
|
||||
|
||||
def check_context_budget(messages, context_limit=128000):
|
||||
tracker = get_tracker(context_limit)
|
||||
tracker.update_from_messages(messages)
|
||||
return tracker.get_warning()
|
||||
|
||||
def preflight_token_check(text):
|
||||
tracker = get_tracker()
|
||||
return tracker.preflight_check(text)
|
||||
149
agent/crisis_resources.py
Normal file
149
agent/crisis_resources.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
988 Suicide & Crisis Lifeline Integration (#673).
|
||||
|
||||
When crisis is detected, provides immediate access to help:
|
||||
- Phone: 988 (call or text)
|
||||
- Text: Text HOME to 988
|
||||
- Chat: 988lifeline.org/chat
|
||||
- Spanish: 1-888-628-9454
|
||||
- Emergency: 911
|
||||
|
||||
This module provides the resource data. agent/crisis_protocol.py
|
||||
handles detection. This module formats the resources for display.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrisisResource:
|
||||
"""A crisis support contact method."""
|
||||
name: str
|
||||
contact: str
|
||||
description: str
|
||||
url: str = ""
|
||||
available: str = "24/7"
|
||||
language: str = "English"
|
||||
|
||||
|
||||
# 988 Suicide & Crisis Lifeline — all channels
|
||||
LIFELINE_988 = CrisisResource(
|
||||
name="988 Suicide and Crisis Lifeline",
|
||||
contact="Call or text 988",
|
||||
description="Free, confidential support for people in suicidal crisis or emotional distress.",
|
||||
url="https://988lifeline.org",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_TEXT = CrisisResource(
|
||||
name="988 Crisis Text Line",
|
||||
contact="Text HOME to 988",
|
||||
description="Free, 24/7 crisis support via text message.",
|
||||
url="",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_CHAT = CrisisResource(
|
||||
name="988 Lifeline Chat",
|
||||
contact="988lifeline.org/chat",
|
||||
description="Free, confidential online chat with a trained crisis counselor.",
|
||||
url="https://988lifeline.org/chat",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_SPANISH = CrisisResource(
|
||||
name="988 Lifeline (Spanish)",
|
||||
contact="1-888-628-9454",
|
||||
description="Línea de prevención del suicidio en español.",
|
||||
url="https://988lifeline.org/help-yourself/en-espanol/",
|
||||
available="24/7",
|
||||
language="Spanish",
|
||||
)
|
||||
|
||||
CRISIS_TEXT_LINE = CrisisResource(
|
||||
name="Crisis Text Line",
|
||||
contact="Text HOME to 741741",
|
||||
description="Free, 24/7 crisis support via text message.",
|
||||
url="https://www.crisistextline.org",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
EMERGENCY_911 = CrisisResource(
|
||||
name="Emergency Services",
|
||||
contact="911",
|
||||
description="Immediate danger — police, fire, ambulance.",
|
||||
url="",
|
||||
available="24/7",
|
||||
language="Any",
|
||||
)
|
||||
|
||||
# All resources in priority order
|
||||
ALL_RESOURCES: List[CrisisResource] = [
|
||||
EMERGENCY_911,
|
||||
LIFELINE_988,
|
||||
LIFELINE_988_TEXT,
|
||||
LIFELINE_988_CHAT,
|
||||
CRISIS_TEXT_LINE,
|
||||
LIFELINE_988_SPANISH,
|
||||
]
|
||||
|
||||
|
||||
def get_crisis_resources(language: str = None) -> List[CrisisResource]:
|
||||
"""Get crisis resources, optionally filtered by language.
|
||||
|
||||
Args:
|
||||
language: Filter by language ("English", "Spanish", or None for all)
|
||||
|
||||
Returns:
|
||||
List of CrisisResource objects
|
||||
"""
|
||||
if language:
|
||||
return [r for r in ALL_RESOURCES if r.language.lower() == language.lower()]
|
||||
return ALL_RESOURCES
|
||||
|
||||
|
||||
def format_crisis_resources(resources: List[CrisisResource] = None) -> str:
|
||||
"""Format crisis resources as a user-facing message.
|
||||
|
||||
Args:
|
||||
resources: List of resources to format. Defaults to all resources.
|
||||
|
||||
Returns:
|
||||
Formatted string suitable for displaying to a user in crisis.
|
||||
"""
|
||||
if resources is None:
|
||||
resources = ALL_RESOURCES
|
||||
|
||||
lines = ["**Please reach out — help is available right now:**
|
||||
"]
|
||||
|
||||
for r in resources:
|
||||
if r.url:
|
||||
lines.append(f"- **{r.name}:** {r.contact} ({r.url})")
|
||||
else:
|
||||
lines.append(f"- **{r.name}:** {r.contact}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("All services are free, confidential, and available 24/7.")
|
||||
lines.append("You are not alone.")
|
||||
|
||||
return "
|
||||
".join(lines)
|
||||
|
||||
|
||||
def get_immediate_help_message() -> str:
|
||||
"""Get the most urgent crisis help message.
|
||||
|
||||
Used when crisis is detected at CRITICAL level.
|
||||
"""
|
||||
return (
|
||||
"If you are in immediate danger, call **911** right now.
|
||||
|
||||
"
|
||||
+ format_crisis_resources()
|
||||
)
|
||||
184
agent/mtls.py
Normal file
184
agent/mtls.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
agent/mtls.py — Mutual TLS support for Hermes A2A communication.
|
||||
|
||||
Provides:
|
||||
- build_server_ssl_context() — SSL context for uvicorn that requires client certs
|
||||
- build_client_ssl_context() — SSL context for httpx/aiohttp A2A clients
|
||||
- MTLSMiddleware — FastAPI middleware that enforces client cert on A2A routes
|
||||
- is_mtls_configured() — Check if env vars are set
|
||||
|
||||
Configuration (environment variables):
|
||||
HERMES_MTLS_CERT Path to this agent's TLS certificate (PEM)
|
||||
HERMES_MTLS_KEY Path to this agent's TLS private key (PEM)
|
||||
HERMES_MTLS_CA Path to the Fleet CA certificate (PEM) — used to verify peers
|
||||
|
||||
All three must be set to enable mTLS. If any is missing, mTLS is disabled and
|
||||
the server falls back to plain HTTP (or regular TLS without client auth).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A2A routes that require a valid client certificate when mTLS is enabled.
|
||||
_A2A_PATH_PREFIXES = (
|
||||
"/.well-known/agent-card",
|
||||
"/agent-card",
|
||||
"/api/agent-card",
|
||||
"/a2a/",
|
||||
)
|
||||
|
||||
|
||||
def _get_env(key: str) -> Optional[str]:
|
||||
val = os.environ.get(key, "").strip()
|
||||
return val or None
|
||||
|
||||
|
||||
def is_mtls_configured() -> bool:
|
||||
"""Return True if all three mTLS env vars are set and the files exist."""
|
||||
cert = _get_env("HERMES_MTLS_CERT")
|
||||
key = _get_env("HERMES_MTLS_KEY")
|
||||
ca = _get_env("HERMES_MTLS_CA")
|
||||
if not (cert and key and ca):
|
||||
return False
|
||||
for label, path in (("HERMES_MTLS_CERT", cert), ("HERMES_MTLS_KEY", key), ("HERMES_MTLS_CA", ca)):
|
||||
if not Path(path).is_file():
|
||||
logger.warning("mTLS disabled: %s file not found: %s", label, path)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def build_server_ssl_context() -> ssl.SSLContext:
|
||||
"""
|
||||
Build an SSL context for the A2A server that:
|
||||
- presents its own certificate
|
||||
- requires and verifies the client's certificate against the Fleet CA
|
||||
|
||||
Raises:
|
||||
RuntimeError: if mTLS env vars are not set or files are missing
|
||||
ssl.SSLError: if cert/key/CA files are invalid
|
||||
"""
|
||||
cert = _get_env("HERMES_MTLS_CERT")
|
||||
key = _get_env("HERMES_MTLS_KEY")
|
||||
ca = _get_env("HERMES_MTLS_CA")
|
||||
|
||||
if not (cert and key and ca):
|
||||
raise RuntimeError(
|
||||
"mTLS not configured. Set HERMES_MTLS_CERT, HERMES_MTLS_KEY, and HERMES_MTLS_CA."
|
||||
)
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
ctx.load_verify_locations(cafile=ca)
|
||||
# CERT_REQUIRED: reject connections without a valid client cert
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
logger.info("mTLS server context built (cert=%s, CA=%s)", cert, ca)
|
||||
return ctx
|
||||
|
||||
|
||||
def build_client_ssl_context() -> ssl.SSLContext:
|
||||
"""
|
||||
Build an SSL context for outbound A2A connections that:
|
||||
- presents this agent's certificate as a client cert
|
||||
- verifies the remote server against the Fleet CA
|
||||
|
||||
Raises:
|
||||
RuntimeError: if mTLS env vars are not set or files are missing
|
||||
ssl.SSLError: if cert/key/CA files are invalid
|
||||
"""
|
||||
cert = _get_env("HERMES_MTLS_CERT")
|
||||
key = _get_env("HERMES_MTLS_KEY")
|
||||
ca = _get_env("HERMES_MTLS_CA")
|
||||
|
||||
if not (cert and key and ca):
|
||||
raise RuntimeError(
|
||||
"mTLS not configured. Set HERMES_MTLS_CERT, HERMES_MTLS_KEY, and HERMES_MTLS_CA."
|
||||
)
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
ctx.load_verify_locations(cafile=ca)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.check_hostname = True
|
||||
logger.info("mTLS client context built (cert=%s, CA=%s)", cert, ca)
|
||||
return ctx
|
||||
|
||||
|
||||
def get_peer_cn(ssl_object) -> Optional[str]:
|
||||
"""Extract the CN from the peer certificate's subject, or None."""
|
||||
try:
|
||||
peer_cert = ssl_object.getpeercert()
|
||||
if not peer_cert:
|
||||
return None
|
||||
for rdn in peer_cert.get("subject", ()):
|
||||
for attr, value in rdn:
|
||||
if attr == "commonName":
|
||||
return value
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class MTLSMiddleware:
|
||||
"""
|
||||
ASGI middleware that enforces client certificate verification on A2A routes.
|
||||
|
||||
When mTLS is NOT configured (no env vars) or the route is not an A2A route,
|
||||
the request passes through unchanged.
|
||||
|
||||
When mTLS IS configured and the route matches an A2A prefix, the middleware
|
||||
checks that the request arrived over a TLS connection with a verified client
|
||||
certificate. If not, it returns HTTP 403.
|
||||
|
||||
Note: This middleware only provides defence-in-depth at the app layer.
|
||||
The primary enforcement is at the SSL context level (CERT_REQUIRED on the
|
||||
server context). This middleware is useful when the server runs behind a
|
||||
TLS-terminating proxy that forwards cert info via headers (not yet
|
||||
implemented) or for test-time injection.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self._enabled = is_mtls_configured()
|
||||
if self._enabled:
|
||||
logger.info("MTLSMiddleware enabled — A2A routes require client cert")
|
||||
|
||||
def _is_a2a_route(self, path: str) -> bool:
|
||||
return any(path.startswith(prefix) for prefix in _A2A_PATH_PREFIXES)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http" and self._enabled and self._is_a2a_route(scope.get("path", "")):
|
||||
# Check for client cert in the SSL connection
|
||||
transport = scope.get("extensions", {}).get("tls", {})
|
||||
peer_cert = transport.get("peer_cert")
|
||||
if peer_cert is None:
|
||||
# No client cert — reject
|
||||
response = _forbidden_response("Client certificate required for A2A endpoints")
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _forbidden_response(message: str):
|
||||
"""Return a minimal ASGI 403 response."""
|
||||
body = message.encode()
|
||||
|
||||
async def respond(scope, receive, send):
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 403,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain"),
|
||||
(b"content-length", str(len(body)).encode()),
|
||||
],
|
||||
})
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
|
||||
return respond
|
||||
262
agent/profile_isolation.py
Normal file
262
agent/profile_isolation.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Profile Session Isolation — #891
|
||||
|
||||
Tags sessions with their originating profile and provides
|
||||
filtered access so profiles cannot see each other's data.
|
||||
|
||||
Current state: All sessions share one state.db with no profile tag.
|
||||
This module adds profile tagging and filtered queries.
|
||||
|
||||
Usage:
|
||||
from agent.profile_isolation import tag_session, get_profile_sessions, get_active_profile
|
||||
|
||||
# Tag a new session with the current profile
|
||||
tag_session(session_id, profile_name)
|
||||
|
||||
# Get sessions for a specific profile
|
||||
sessions = get_profile_sessions("sprint")
|
||||
|
||||
# Get current active profile
|
||||
profile = get_active_profile()
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", str(Path.home() / ".hermes")))
|
||||
SESSIONS_DB = HERMES_HOME / "sessions" / "state.db"
|
||||
PROFILE_TAGS_FILE = HERMES_HOME / "profile_session_tags.json"
|
||||
|
||||
|
||||
def get_active_profile() -> str:
|
||||
"""Get the currently active profile name."""
|
||||
config_path = HERMES_HOME / "config.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
return cfg.get("active_profile", "default")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check environment
|
||||
return os.getenv("HERMES_PROFILE", "default")
|
||||
|
||||
|
||||
def _load_tags() -> Dict[str, str]:
|
||||
"""Load session-to-profile mapping."""
|
||||
if not PROFILE_TAGS_FILE.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(PROFILE_TAGS_FILE) as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_tags(tags: Dict[str, str]):
|
||||
"""Save session-to-profile mapping."""
|
||||
PROFILE_TAGS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(PROFILE_TAGS_FILE, "w") as f:
|
||||
json.dump(tags, f, indent=2)
|
||||
|
||||
|
||||
def tag_session(session_id: str, profile: Optional[str] = None) -> str:
|
||||
"""
|
||||
Tag a session with its originating profile.
|
||||
|
||||
Returns the profile name used.
|
||||
"""
|
||||
if profile is None:
|
||||
profile = get_active_profile()
|
||||
|
||||
tags = _load_tags()
|
||||
tags[session_id] = profile
|
||||
_save_tags(tags)
|
||||
|
||||
# Also tag in SQLite if available
|
||||
_tag_session_in_db(session_id, profile)
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
def _tag_session_in_db(session_id: str, profile: str):
|
||||
"""Add profile tag to SQLite session store."""
|
||||
if not SESSIONS_DB.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(SESSIONS_DB))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if sessions table has profile column
|
||||
cursor.execute("PRAGMA table_info(sessions)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "profile" not in columns:
|
||||
# Add profile column
|
||||
cursor.execute("ALTER TABLE sessions ADD COLUMN profile TEXT DEFAULT 'default'")
|
||||
|
||||
# Update the session's profile
|
||||
cursor.execute(
|
||||
"UPDATE sessions SET profile = ? WHERE session_id = ?",
|
||||
(profile, session_id)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass # SQLite might not be available or schema differs
|
||||
|
||||
|
||||
def get_session_profile(session_id: str) -> Optional[str]:
|
||||
"""Get the profile that owns a session."""
|
||||
# Check JSON tags first
|
||||
tags = _load_tags()
|
||||
if session_id in tags:
|
||||
return tags[session_id]
|
||||
|
||||
# Check SQLite
|
||||
if SESSIONS_DB.exists():
|
||||
try:
|
||||
conn = sqlite3.connect(str(SESSIONS_DB))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT profile FROM sessions WHERE session_id = ?",
|
||||
(session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if row:
|
||||
return row[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_profile_sessions(
|
||||
profile: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get sessions belonging to a specific profile.
|
||||
|
||||
Returns list of session dicts.
|
||||
"""
|
||||
if profile is None:
|
||||
profile = get_active_profile()
|
||||
|
||||
sessions = []
|
||||
|
||||
# Get from JSON tags
|
||||
tags = _load_tags()
|
||||
tagged_sessions = [sid for sid, p in tags.items() if p == profile]
|
||||
|
||||
# Get from SQLite with profile filter
|
||||
if SESSIONS_DB.exists():
|
||||
try:
|
||||
conn = sqlite3.connect(str(SESSIONS_DB))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Try profile column first
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT * FROM sessions WHERE profile = ? ORDER BY updated_at DESC LIMIT ?",
|
||||
(profile, limit)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
sessions.append(dict(row))
|
||||
except Exception:
|
||||
# Fallback: filter by tagged session IDs
|
||||
if tagged_sessions:
|
||||
placeholders = ",".join("?" * len(tagged_sessions[:limit]))
|
||||
cursor.execute(
|
||||
f"SELECT * FROM sessions WHERE session_id IN ({placeholders}) ORDER BY updated_at DESC LIMIT ?",
|
||||
(*tagged_sessions[:limit], limit)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
sessions.append(dict(row))
|
||||
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return sessions[:limit]
|
||||
|
||||
|
||||
def filter_sessions_by_profile(
|
||||
sessions: List[Dict[str, Any]],
|
||||
profile: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter a list of sessions to only include those belonging to a profile."""
|
||||
if profile is None:
|
||||
profile = get_active_profile()
|
||||
|
||||
tags = _load_tags()
|
||||
filtered = []
|
||||
|
||||
for session in sessions:
|
||||
sid = session.get("session_id") or session.get("id")
|
||||
if not sid:
|
||||
continue
|
||||
|
||||
# Check tag
|
||||
session_profile = tags.get(sid)
|
||||
if session_profile is None:
|
||||
# Check SQLite
|
||||
session_profile = get_session_profile(sid)
|
||||
|
||||
if session_profile == profile or session_profile is None:
|
||||
filtered.append(session)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def get_profile_stats() -> Dict[str, Any]:
|
||||
"""Get statistics about profile session distribution."""
|
||||
tags = _load_tags()
|
||||
|
||||
profile_counts = {}
|
||||
for sid, profile in tags.items():
|
||||
profile_counts[profile] = profile_counts.get(profile, 0) + 1
|
||||
|
||||
total_tagged = len(tags)
|
||||
profiles = list(profile_counts.keys())
|
||||
|
||||
return {
|
||||
"total_tagged_sessions": total_tagged,
|
||||
"profiles": profiles,
|
||||
"profile_counts": profile_counts,
|
||||
"active_profile": get_active_profile(),
|
||||
}
|
||||
|
||||
|
||||
def audit_untagged_sessions() -> List[str]:
|
||||
"""Find sessions without a profile tag."""
|
||||
if not SESSIONS_DB.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(SESSIONS_DB))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get all session IDs
|
||||
cursor.execute("SELECT session_id FROM sessions")
|
||||
all_sessions = {row[0] for row in cursor.fetchall()}
|
||||
conn.close()
|
||||
|
||||
# Get tagged sessions
|
||||
tags = _load_tags()
|
||||
tagged = set(tags.keys())
|
||||
|
||||
# Return untagged
|
||||
return list(all_sessions - tagged)
|
||||
except Exception:
|
||||
return []
|
||||
146
agent/provider_preflight.py
Normal file
146
agent/provider_preflight.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Provider Preflight — Poka-yoke validation of provider/model config.
|
||||
|
||||
Validates provider and model configuration before session start.
|
||||
Prevents wasted context on misconfigured providers.
|
||||
|
||||
Usage:
|
||||
from agent.provider_preflight import preflight_check
|
||||
result = preflight_check(provider="openrouter", model="xiaomi/mimo-v2-pro")
|
||||
if not result["valid"]:
|
||||
print(result["error"])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider -> required env var
|
||||
PROVIDER_KEYS = {
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"nous": "NOUS_API_KEY",
|
||||
"ollama": None, # Local, no key needed
|
||||
"local": None,
|
||||
}
|
||||
|
||||
|
||||
def check_provider_key(provider: str) -> Dict[str, Any]:
|
||||
"""Check if provider has a valid API key configured."""
|
||||
provider_lower = provider.lower().strip()
|
||||
|
||||
env_var = None
|
||||
for known, key in PROVIDER_KEYS.items():
|
||||
if known in provider_lower:
|
||||
env_var = key
|
||||
break
|
||||
|
||||
if env_var is None:
|
||||
# Unknown provider — assume OK (custom/local)
|
||||
return {"valid": True, "provider": provider, "key_status": "unknown"}
|
||||
|
||||
if env_var is None:
|
||||
# Local provider, no key needed
|
||||
return {"valid": True, "provider": provider, "key_status": "not_required"}
|
||||
|
||||
key_value = os.getenv(env_var, "").strip()
|
||||
if not key_value:
|
||||
return {
|
||||
"valid": False,
|
||||
"provider": provider,
|
||||
"key_status": "missing",
|
||||
"error": f"{env_var} is not set. Provider '{provider}' will fail.",
|
||||
"fix": f"Set {env_var} in ~/.hermes/.env",
|
||||
}
|
||||
|
||||
if len(key_value) < 10:
|
||||
return {
|
||||
"valid": False,
|
||||
"provider": provider,
|
||||
"key_status": "too_short",
|
||||
"error": f"{env_var} is suspiciously short ({len(key_value)} chars). May be invalid.",
|
||||
"fix": f"Verify {env_var} value in ~/.hermes/.env",
|
||||
}
|
||||
|
||||
return {"valid": True, "provider": provider, "key_status": "set"}
|
||||
|
||||
|
||||
def check_model_availability(model: str, provider: str) -> Dict[str, Any]:
|
||||
"""Check if model is likely available for provider."""
|
||||
if not model:
|
||||
return {"valid": False, "error": "No model specified"}
|
||||
|
||||
# Basic sanity checks
|
||||
model_lower = model.lower()
|
||||
|
||||
# Anthropic models should use anthropic provider
|
||||
if "claude" in model_lower and "anthropic" not in provider.lower():
|
||||
return {
|
||||
"valid": True, # Allow but warn
|
||||
"warning": f"Model '{model}' usually runs on Anthropic provider, not '{provider}'",
|
||||
}
|
||||
|
||||
# Ollama models
|
||||
ollama_indicators = ["llama", "mistral", "qwen", "gemma", "phi", "hermes"]
|
||||
if any(x in model_lower for x in ollama_indicators) and ":" not in model:
|
||||
return {
|
||||
"valid": True,
|
||||
"warning": f"Model '{model}' may need a version tag for Ollama (e.g., {model}:latest)",
|
||||
}
|
||||
|
||||
return {"valid": True}
|
||||
|
||||
|
||||
def preflight_check(
|
||||
provider: str = "",
|
||||
model: str = "",
|
||||
fallback_provider: str = "",
|
||||
fallback_model: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""Full pre-flight check for provider/model configuration.
|
||||
|
||||
Returns:
|
||||
Dict with valid (bool), errors (list), warnings (list).
|
||||
"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# Check primary provider
|
||||
if provider:
|
||||
result = check_provider_key(provider)
|
||||
if not result["valid"]:
|
||||
errors.append(result.get("error", f"Provider {provider} invalid"))
|
||||
|
||||
# Check primary model
|
||||
if model:
|
||||
result = check_model_availability(model, provider)
|
||||
if not result["valid"]:
|
||||
errors.append(result.get("error", f"Model {model} invalid"))
|
||||
elif result.get("warning"):
|
||||
warnings.append(result["warning"])
|
||||
|
||||
# Check fallback
|
||||
if fallback_provider:
|
||||
result = check_provider_key(fallback_provider)
|
||||
if not result["valid"]:
|
||||
warnings.append(f"Fallback provider {fallback_provider} also invalid: {result.get('error','')}")
|
||||
|
||||
if fallback_model:
|
||||
result = check_model_availability(fallback_model, fallback_provider)
|
||||
if not result["valid"]:
|
||||
warnings.append(f"Fallback model {fallback_model} invalid")
|
||||
elif result.get("warning"):
|
||||
warnings.append(result["warning"])
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
}
|
||||
146
agent/time_aware_routing.py
Normal file
146
agent/time_aware_routing.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Time-aware model routing for cron jobs.
|
||||
|
||||
Routes cron tasks to more capable models during off-hours when the user
|
||||
is not present to correct errors. Reduces error rates during high-error
|
||||
time windows (e.g., 18:00 evening batches).
|
||||
|
||||
Usage:
|
||||
from agent.time_aware_routing import resolve_time_aware_model
|
||||
model = resolve_time_aware_model(base_model="mimo-v2-pro", is_cron=True)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
# Error rate data from empirical audit (2026-04-12)
|
||||
# Higher error rates during these hours suggest routing to better models
|
||||
_HIGH_ERROR_HOURS = {
|
||||
18: 9.4, # 18:00 — 9.4% error rate (evening cron batches)
|
||||
19: 8.1,
|
||||
20: 7.5,
|
||||
21: 6.8,
|
||||
22: 6.2,
|
||||
23: 5.9,
|
||||
0: 5.5,
|
||||
1: 5.2,
|
||||
}
|
||||
|
||||
# Low error hours — default model is fine
|
||||
_LOW_ERROR_HOURS = set(range(6, 18)) # 06:00-17:59
|
||||
|
||||
# Default fallback models by time zone
|
||||
_DEFAULT_STRONG_MODEL = os.getenv("CRON_STRONG_MODEL", "xiaomi/mimo-v2-pro")
|
||||
_DEFAULT_CHEAP_MODEL = os.getenv("CRON_CHEAP_MODEL", "qwen2.5:7b")
|
||||
_ERROR_THRESHOLD = float(os.getenv("CRON_ERROR_THRESHOLD", "6.0")) # % error rate
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingDecision:
|
||||
"""Result of time-aware routing."""
|
||||
model: str
|
||||
provider: str
|
||||
reason: str
|
||||
hour: int
|
||||
error_rate: float
|
||||
is_off_hours: bool
|
||||
|
||||
|
||||
def get_hour_error_rate(hour: int) -> float:
|
||||
"""Get expected error rate for a given hour (0-23)."""
|
||||
return _HIGH_ERROR_HOURS.get(hour, 4.0) # Default 4% for unlisted hours
|
||||
|
||||
|
||||
def is_off_hours(hour: int) -> bool:
|
||||
"""Check if hour is considered off-hours (higher error rates)."""
|
||||
return hour not in _LOW_ERROR_HOURS
|
||||
|
||||
|
||||
def resolve_time_aware_model(
|
||||
base_model: str = "",
|
||||
base_provider: str = "",
|
||||
is_cron: bool = False,
|
||||
hour: Optional[int] = None,
|
||||
) -> RoutingDecision:
|
||||
"""Resolve model based on time of day and task type.
|
||||
|
||||
During off-hours (evening/night), routes to stronger models for cron
|
||||
jobs to compensate for lack of human oversight.
|
||||
|
||||
Args:
|
||||
base_model: The model that would normally be used.
|
||||
base_provider: The provider for the base model.
|
||||
is_cron: Whether this is a cron job (vs interactive session).
|
||||
hour: Override hour (for testing). Defaults to current hour.
|
||||
|
||||
Returns:
|
||||
RoutingDecision with model, provider, and reasoning.
|
||||
"""
|
||||
if hour is None:
|
||||
hour = time.localtime().tm_hour
|
||||
|
||||
error_rate = get_hour_error_rate(hour)
|
||||
off_hours = is_off_hours(hour)
|
||||
|
||||
# Interactive sessions always use the base model (user can correct errors)
|
||||
if not is_cron:
|
||||
return RoutingDecision(
|
||||
model=base_model or _DEFAULT_CHEAP_MODEL,
|
||||
provider=base_provider,
|
||||
reason="Interactive session — user can correct errors",
|
||||
hour=hour,
|
||||
error_rate=error_rate,
|
||||
is_off_hours=off_hours,
|
||||
)
|
||||
|
||||
# Cron jobs during low-error hours: use base model
|
||||
if not off_hours and error_rate < _ERROR_THRESHOLD:
|
||||
return RoutingDecision(
|
||||
model=base_model or _DEFAULT_CHEAP_MODEL,
|
||||
provider=base_provider,
|
||||
reason=f"Low-error hours ({hour}:00, {error_rate}% expected)",
|
||||
hour=hour,
|
||||
error_rate=error_rate,
|
||||
is_off_hours=False,
|
||||
)
|
||||
|
||||
# Cron jobs during high-error hours: upgrade to stronger model
|
||||
if error_rate >= _ERROR_THRESHOLD:
|
||||
return RoutingDecision(
|
||||
model=_DEFAULT_STRONG_MODEL,
|
||||
provider="nous",
|
||||
reason=f"High-error hours ({hour}:00, {error_rate}% expected) — using stronger model",
|
||||
hour=hour,
|
||||
error_rate=error_rate,
|
||||
is_off_hours=True,
|
||||
)
|
||||
|
||||
# Off-hours but low error: use base model
|
||||
return RoutingDecision(
|
||||
model=base_model or _DEFAULT_CHEAP_MODEL,
|
||||
provider=base_provider,
|
||||
reason=f"Off-hours but low error ({hour}:00, {error_rate}%)",
|
||||
hour=hour,
|
||||
error_rate=error_rate,
|
||||
is_off_hours=off_hours,
|
||||
)
|
||||
|
||||
|
||||
def get_routing_report() -> str:
|
||||
"""Get a report of time-based routing decisions for the next 24 hours."""
|
||||
lines = ["Time-Aware Model Routing (24h forecast)", "=" * 40, ""]
|
||||
lines.append(f"Error threshold: {_ERROR_THRESHOLD}%")
|
||||
lines.append(f"Strong model: {_DEFAULT_STRONG_MODEL}")
|
||||
lines.append(f"Cheap model: {_DEFAULT_CHEAP_MODEL}")
|
||||
lines.append("")
|
||||
|
||||
for h in range(24):
|
||||
decision = resolve_time_aware_model(is_cron=True, hour=h)
|
||||
icon = "\U0001f7e2" if decision.model == _DEFAULT_CHEAP_MODEL else "\U0001f534"
|
||||
lines.append(f" {h:02d}:00 {icon} {decision.model:25s} ({decision.error_rate}% error)")
|
||||
|
||||
return "\n".join(lines)
|
||||
316
agent/token_budget.py
Normal file
316
agent/token_budget.py
Normal file
@@ -0,0 +1,316 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Token Budget — Poka-yoke guard against silent context overflow.
|
||||
|
||||
Progressive warning system with circuit breakers:
|
||||
- 60%: WARNING — log + suggest summarization
|
||||
- 80%: CAUTION — auto-compress, drop raw tool outputs
|
||||
- 90%: CRITICAL — block verbose tool calls, force wrap-up
|
||||
- 95%: STOP — graceful session termination with summary
|
||||
|
||||
Also provides tool output budgeting to truncate before overflow.
|
||||
|
||||
Usage:
|
||||
from agent.token_budget import TokenBudget
|
||||
|
||||
budget = TokenBudget(context_length=128_000)
|
||||
budget.update(8000) # from API response prompt_tokens
|
||||
|
||||
status = budget.check() # returns BudgetStatus with level + message
|
||||
budget.should_block_tools() # True at 90%+
|
||||
budget.should_terminate() # True at 95%+
|
||||
|
||||
# Tool output budgeting
|
||||
remaining = budget.tool_output_budget()
|
||||
truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Thresholds ────────────────────────────────────────────────────────
|
||||
|
||||
WARN_PERCENT = 0.60
|
||||
CAUTION_PERCENT = 0.80
|
||||
CRITICAL_PERCENT = 0.90
|
||||
STOP_PERCENT = 0.95
|
||||
|
||||
# Reserve 5% of context for system prompt, response, and overhead
|
||||
RESPONSE_RESERVE_RATIO = 0.05
|
||||
|
||||
# Max tool output chars at each level
|
||||
TOOL_OUTPUT_BUDGETS = {
|
||||
"NORMAL": 50_000,
|
||||
"WARNING": 20_000,
|
||||
"CAUTION": 8_000,
|
||||
"CRITICAL": 2_000,
|
||||
"STOP": 500,
|
||||
}
|
||||
|
||||
|
||||
class BudgetLevel(Enum):
|
||||
NORMAL = "NORMAL"
|
||||
WARNING = "WARNING"
|
||||
CAUTION = "CAUTION"
|
||||
CRITICAL = "CRITICAL"
|
||||
STOP = "STOP"
|
||||
|
||||
@property
|
||||
def percent_threshold(self) -> float:
|
||||
return {
|
||||
BudgetLevel.NORMAL: 0.0,
|
||||
BudgetLevel.WARNING: WARN_PERCENT,
|
||||
BudgetLevel.CAUTION: CAUTION_PERCENT,
|
||||
BudgetLevel.CRITICAL: CRITICAL_PERCENT,
|
||||
BudgetLevel.STOP: STOP_PERCENT,
|
||||
}[self]
|
||||
|
||||
@property
|
||||
def emoji(self) -> str:
|
||||
return {
|
||||
BudgetLevel.NORMAL: "",
|
||||
BudgetLevel.WARNING: "\u26a0\ufe0f",
|
||||
BudgetLevel.CAUTION: "\U0001f525",
|
||||
BudgetLevel.CRITICAL: "\U0001f6d1",
|
||||
BudgetLevel.STOP: "\U0001f6d1",
|
||||
}[self]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BudgetStatus:
|
||||
"""Current token budget status."""
|
||||
level: BudgetLevel
|
||||
tokens_used: int
|
||||
context_length: int
|
||||
percent_used: float
|
||||
tokens_remaining: int
|
||||
message: str = ""
|
||||
should_compress: bool = False
|
||||
should_block_tools: bool = False
|
||||
should_terminate: bool = False
|
||||
|
||||
def to_indicator(self) -> str:
|
||||
"""Compact status indicator for CLI display."""
|
||||
pct = int(self.percent_used * 100)
|
||||
if self.level == BudgetLevel.NORMAL:
|
||||
return f"[{pct}%]"
|
||||
return f"{self.level.emoji} [{pct}%]"
|
||||
|
||||
def to_bar(self, width: int = 10) -> str:
|
||||
"""Visual progress bar."""
|
||||
filled = int(width * self.percent_used)
|
||||
bar = "\u2588" * filled + "\u2591" * (width - filled)
|
||||
color = self._bar_color()
|
||||
return f"{color}{bar}\033[0m {int(self.percent_used * 100)}%"
|
||||
|
||||
def _bar_color(self) -> str:
|
||||
if self.level == BudgetLevel.STOP:
|
||||
return "\033[41m" # red bg
|
||||
if self.level == BudgetLevel.CRITICAL:
|
||||
return "\033[31m" # red
|
||||
if self.level == BudgetLevel.CAUTION:
|
||||
return "\033[33m" # yellow
|
||||
if self.level == BudgetLevel.WARNING:
|
||||
return "\033[33m" # yellow
|
||||
return "\033[32m" # green
|
||||
|
||||
|
||||
class TokenBudget:
|
||||
"""
|
||||
Progressive token budget tracker with poka-yoke circuit breakers.
|
||||
|
||||
Tracks cumulative token usage against a context length and triggers
|
||||
escalating actions at each threshold.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context_length: int,
|
||||
warn_percent: float = WARN_PERCENT,
|
||||
caution_percent: float = CAUTION_PERCENT,
|
||||
critical_percent: float = CRITICAL_PERCENT,
|
||||
stop_percent: float = STOP_PERCENT,
|
||||
response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
|
||||
):
|
||||
self.context_length = context_length
|
||||
self.warn_threshold = int(context_length * warn_percent)
|
||||
self.caution_threshold = int(context_length * caution_percent)
|
||||
self.critical_threshold = int(context_length * critical_percent)
|
||||
self.stop_threshold = int(context_length * stop_percent)
|
||||
self.response_reserve = int(context_length * response_reserve_ratio)
|
||||
|
||||
self.tokens_used = 0
|
||||
self.completions_tokens = 0
|
||||
self.total_tool_output_chars = 0
|
||||
self._level = BudgetLevel.NORMAL
|
||||
self._history: list[int] = []
|
||||
|
||||
def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
|
||||
"""Update budget from API response usage."""
|
||||
self.tokens_used = prompt_tokens
|
||||
self.completions_tokens = completion_tokens
|
||||
self._history.append(prompt_tokens)
|
||||
return self.check()
|
||||
|
||||
def check(self) -> BudgetStatus:
|
||||
"""Evaluate current budget level and return status."""
|
||||
pct = self.tokens_used / self.context_length if self.context_length > 0 else 0
|
||||
remaining = max(0, self.context_length - self.tokens_used - self.response_reserve)
|
||||
|
||||
# Determine level
|
||||
if pct >= STOP_PERCENT:
|
||||
level = BudgetLevel.STOP
|
||||
elif pct >= CRITICAL_PERCENT:
|
||||
level = BudgetLevel.CRITICAL
|
||||
elif pct >= CAUTION_PERCENT:
|
||||
level = BudgetLevel.CAUTION
|
||||
elif pct >= WARN_PERCENT:
|
||||
level = BudgetLevel.WARNING
|
||||
else:
|
||||
level = BudgetLevel.NORMAL
|
||||
|
||||
# Log transitions (don\'t log every check)
|
||||
if level != self._level:
|
||||
self._log_transition(level, pct)
|
||||
self._level = level
|
||||
|
||||
messages = {
|
||||
BudgetLevel.NORMAL: "",
|
||||
BudgetLevel.WARNING: (
|
||||
f"Context at {int(pct*100)}%. Consider wrapping up soon or using /compress."
|
||||
),
|
||||
BudgetLevel.CAUTION: (
|
||||
f"Context at {int(pct*100)}%. Auto-compressing. "
|
||||
f"Tool outputs will be truncated."
|
||||
),
|
||||
BudgetLevel.CRITICAL: (
|
||||
f"Context at {int(pct*100)}%. Verbose tools blocked. "
|
||||
f"Session approaching limit — please wrap up."
|
||||
),
|
||||
BudgetLevel.STOP: (
|
||||
f"Context at {int(pct*100)}%. Session must terminate. "
|
||||
f"Saving summary before shutdown."
|
||||
),
|
||||
}
|
||||
|
||||
return BudgetStatus(
|
||||
level=level,
|
||||
tokens_used=self.tokens_used,
|
||||
context_length=self.context_length,
|
||||
percent_used=pct,
|
||||
tokens_remaining=remaining,
|
||||
message=messages[level],
|
||||
should_compress=level in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
|
||||
should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
|
||||
should_terminate=level == BudgetLevel.STOP,
|
||||
)
|
||||
|
||||
def should_compress(self) -> bool:
|
||||
"""True at 80%+ — auto-compression should trigger."""
|
||||
return self.tokens_used >= self.caution_threshold
|
||||
|
||||
def should_block_tools(self) -> bool:
|
||||
"""True at 90%+ — verbose tool calls should be blocked."""
|
||||
return self.tokens_used >= self.critical_threshold
|
||||
|
||||
def should_terminate(self) -> bool:
|
||||
"""True at 95%+ — session should gracefully terminate."""
|
||||
return self.tokens_used >= self.stop_threshold
|
||||
|
||||
def tool_output_budget(self) -> int:
|
||||
"""Max chars allowed for next tool output based on current level."""
|
||||
status = self.check()
|
||||
return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)
|
||||
|
||||
def truncate_tool_output(self, output: str, max_chars: int = None) -> str:
|
||||
"""Truncate tool output to fit budget. Adds truncation notice."""
|
||||
if max_chars is None:
|
||||
max_chars = self.tool_output_budget()
|
||||
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
|
||||
# Preserve start and end, truncate middle
|
||||
if max_chars < 200:
|
||||
return output[:max_chars] + "\n[...truncated...]"
|
||||
|
||||
head = max_chars // 2
|
||||
tail = max_chars - head - 30 # reserve for truncation notice
|
||||
truncated = (
|
||||
output[:head]
|
||||
+ f"\n\n[...{len(output) - head - tail:,} chars truncated...]\n\n"
|
||||
+ output[-tail:]
|
||||
)
|
||||
return truncated
|
||||
|
||||
def remaining_for_response(self) -> int:
|
||||
"""Tokens available for the model\'s response."""
|
||||
return max(0, self.context_length - self.tokens_used - self.response_reserve)
|
||||
|
||||
def growth_rate(self) -> Optional[float]:
|
||||
"""Average token increase per turn (from history)."""
|
||||
if len(self._history) < 2:
|
||||
return None
|
||||
diffs = [self._history[i] - self._history[i-1] for i in range(1, len(self._history))]
|
||||
return sum(diffs) / len(diffs)
|
||||
|
||||
def turns_remaining(self) -> Optional[int]:
|
||||
"""Estimated turns until context is full (based on growth rate)."""
|
||||
rate = self.growth_rate()
|
||||
if rate is None or rate <= 0:
|
||||
return None
|
||||
remaining = self.context_length - self.tokens_used
|
||||
return int(remaining / rate)
|
||||
|
||||
def reset(self):
|
||||
"""Reset budget for new session."""
|
||||
self.tokens_used = 0
|
||||
self.completions_tokens = 0
|
||||
self.total_tool_output_chars = 0
|
||||
self._level = BudgetLevel.NORMAL
|
||||
self._history.clear()
|
||||
|
||||
def _log_transition(self, new_level: BudgetLevel, pct: float):
|
||||
"""Log budget level transitions."""
|
||||
msg = (
|
||||
f"Token budget: {self._level.value} -> {new_level.value} "
|
||||
f"({self.tokens_used}/{self.context_length} = {pct:.0%})"
|
||||
)
|
||||
if new_level == BudgetLevel.WARNING:
|
||||
logger.warning(msg)
|
||||
elif new_level == BudgetLevel.CAUTION:
|
||||
logger.warning(msg)
|
||||
elif new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Human-readable budget summary."""
|
||||
status = self.check()
|
||||
turns = self.turns_remaining()
|
||||
rate = self.growth_rate()
|
||||
lines = [
|
||||
f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
|
||||
f"Level: {status.level.value}",
|
||||
f"Remaining: {status.tokens_remaining:,} tokens",
|
||||
]
|
||||
if rate is not None:
|
||||
lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
|
||||
if turns is not None:
|
||||
lines.append(f"Estimated turns left: ~{turns}")
|
||||
if status.message:
|
||||
lines.append(f"Action: {status.message}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── Convenience factory ───────────────────────────────────────────────
|
||||
|
||||
def create_budget(context_length: int, **kwargs) -> TokenBudget:
|
||||
"""Create a TokenBudget with defaults."""
|
||||
return TokenBudget(context_length=context_length, **kwargs)
|
||||
156
agent/tool_fixation_detector.py
Normal file
156
agent/tool_fixation_detector.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tool fixation detection — break repetitive tool calling loops.
|
||||
|
||||
Detects when the agent latches onto one tool and calls it repeatedly
|
||||
without making progress. Injects a nudge prompt to break the loop.
|
||||
|
||||
Usage:
|
||||
from agent.tool_fixation_detector import ToolFixationDetector
|
||||
detector = ToolFixationDetector()
|
||||
nudge = detector.record("execute_code")
|
||||
if nudge:
|
||||
# Inject nudge into conversation
|
||||
messages.append({"role": "system", "content": nudge})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
# Default thresholds
|
||||
_DEFAULT_THRESHOLD = int(os.getenv("TOOL_FIXATION_THRESHOLD", "5"))
|
||||
_DEFAULT_WINDOW = int(os.getenv("TOOL_FIXATION_WINDOW", "10"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FixationEvent:
|
||||
"""Record of a fixation detection."""
|
||||
tool_name: str
|
||||
streak_length: int
|
||||
threshold: int
|
||||
nudge_sent: bool = False
|
||||
|
||||
|
||||
class ToolFixationDetector:
|
||||
"""Detects and breaks tool fixation loops.
|
||||
|
||||
Tracks the sequence of tool calls and detects when the same tool
|
||||
is called N times consecutively. When detected, returns a nudge
|
||||
prompt to inject into the conversation.
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: int = 0, window: int = 0):
|
||||
self.threshold = threshold or _DEFAULT_THRESHOLD
|
||||
self.window = window or _DEFAULT_WINDOW
|
||||
self._history: List[str] = []
|
||||
self._current_streak: str = ""
|
||||
self._streak_count: int = 0
|
||||
self._nudges_sent: int = 0
|
||||
self._events: List[FixationEvent] = []
|
||||
|
||||
@property
|
||||
def nudges_sent(self) -> int:
|
||||
return self._nudges_sent
|
||||
|
||||
@property
|
||||
def events(self) -> List[FixationEvent]:
|
||||
return list(self._events)
|
||||
|
||||
def record(self, tool_name: str) -> Optional[str]:
|
||||
"""Record a tool call and return nudge prompt if fixation detected.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was called.
|
||||
|
||||
Returns:
|
||||
Nudge prompt string if fixation detected, None otherwise.
|
||||
"""
|
||||
self._history.append(tool_name)
|
||||
|
||||
# Trim history to window
|
||||
if len(self._history) > self.window:
|
||||
self._history = self._history[-self.window:]
|
||||
|
||||
# Update streak
|
||||
if tool_name == self._current_streak:
|
||||
self._streak_count += 1
|
||||
else:
|
||||
self._current_streak = tool_name
|
||||
self._streak_count = 1
|
||||
|
||||
# Check for fixation
|
||||
if self._streak_count >= self.threshold:
|
||||
event = FixationEvent(
|
||||
tool_name=tool_name,
|
||||
streak_length=self._streak_count,
|
||||
threshold=self.threshold,
|
||||
nudge_sent=True,
|
||||
)
|
||||
self._events.append(event)
|
||||
self._nudges_sent += 1
|
||||
|
||||
return self._build_nudge(tool_name, self._streak_count)
|
||||
|
||||
return None
|
||||
|
||||
def _build_nudge(self, tool_name: str, count: int) -> str:
|
||||
"""Build a nudge prompt to break the fixation loop."""
|
||||
return (
|
||||
f"[SYSTEM: You have called `{tool_name}` {count} times in a row "
|
||||
f"without switching tools. This suggests a fixation loop. "
|
||||
f"Consider:\n"
|
||||
f"1. Is the tool returning an error? Read the error carefully.\n"
|
||||
f"2. Is there a different tool that could help?\n"
|
||||
f"3. Should you ask the user for clarification?\n"
|
||||
f"4. Is the task actually complete?\n"
|
||||
f"Break the loop by trying a different approach.]"
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the detector state."""
|
||||
self._history.clear()
|
||||
self._current_streak = ""
|
||||
self._streak_count = 0
|
||||
|
||||
def get_streak_info(self) -> dict:
|
||||
"""Get current streak information."""
|
||||
return {
|
||||
"current_tool": self._current_streak,
|
||||
"streak_count": self._streak_count,
|
||||
"threshold": self.threshold,
|
||||
"at_threshold": self._streak_count >= self.threshold,
|
||||
"nudges_sent": self._nudges_sent,
|
||||
}
|
||||
|
||||
def format_report(self) -> str:
|
||||
"""Format fixation events as a report."""
|
||||
if not self._events:
|
||||
return "No tool fixation detected."
|
||||
|
||||
lines = [
|
||||
f"Tool Fixation Report ({len(self._events)} events)",
|
||||
"=" * 40,
|
||||
]
|
||||
for e in self._events:
|
||||
lines.append(f" {e.tool_name}: {e.streak_length} consecutive calls (threshold: {e.threshold})")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Singleton
|
||||
_detector: Optional[ToolFixationDetector] = None
|
||||
|
||||
|
||||
def get_fixation_detector() -> ToolFixationDetector:
|
||||
"""Get or create the singleton detector."""
|
||||
global _detector
|
||||
if _detector is None:
|
||||
_detector = ToolFixationDetector()
|
||||
return _detector
|
||||
|
||||
|
||||
def reset_fixation_detector() -> None:
|
||||
"""Reset the singleton."""
|
||||
global _detector
|
||||
_detector = None
|
||||
32
ansible/fleet_mtls.yml
Normal file
32
ansible/fleet_mtls.yml
Normal file
@@ -0,0 +1,32 @@
|
||||
---
|
||||
# fleet_mtls.yml — Deploy mutual-TLS certificates to all fleet agents.
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1. Run scripts/gen_fleet_ca.sh to create the fleet CA.
|
||||
# 2. For each agent, run:
|
||||
# scripts/gen_agent_cert.sh --agent timmy
|
||||
# scripts/gen_agent_cert.sh --agent allegro
|
||||
# scripts/gen_agent_cert.sh --agent ezra
|
||||
#
|
||||
# Usage:
|
||||
# ansible-playbook -i inventory/fleet.ini ansible/fleet_mtls.yml
|
||||
#
|
||||
# Inventory example (inventory/fleet.ini):
|
||||
# [fleet]
|
||||
# timmy.local agent_name=timmy
|
||||
# allegro.local agent_name=allegro
|
||||
# ezra.local agent_name=ezra
|
||||
#
|
||||
# Refs #806
|
||||
|
||||
- name: Distribute fleet mTLS certificates
|
||||
hosts: fleet
|
||||
become: true
|
||||
vars:
|
||||
_pki_base: "{{ lookup('env', 'HOME') }}/.hermes/pki"
|
||||
roles:
|
||||
- role: hermes_mtls
|
||||
vars:
|
||||
hermes_mtls_local_ca_cert: "{{ _pki_base }}/ca/fleet-ca.crt"
|
||||
hermes_mtls_local_agent_cert: "{{ _pki_base }}/agents/{{ agent_name }}/{{ agent_name }}.crt"
|
||||
hermes_mtls_local_agent_key: "{{ _pki_base }}/agents/{{ agent_name }}/{{ agent_name }}.key"
|
||||
12
ansible/inventory/fleet.ini.example
Normal file
12
ansible/inventory/fleet.ini.example
Normal file
@@ -0,0 +1,12 @@
|
||||
# Example fleet inventory for mutual-TLS cert distribution.
|
||||
# Copy to fleet.ini and adjust hostnames/IPs.
|
||||
# Refs #806
|
||||
|
||||
[fleet_agents]
|
||||
timmy ansible_host=192.168.1.10
|
||||
allegro ansible_host=192.168.1.11
|
||||
ezra ansible_host=192.168.1.12
|
||||
|
||||
[fleet_agents:vars]
|
||||
ansible_user=hermes
|
||||
ansible_python_interpreter=/usr/bin/python3
|
||||
21
ansible/roles/fleet_mtls_certs/defaults/main.yml
Normal file
21
ansible/roles/fleet_mtls_certs/defaults/main.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
---
|
||||
# Default paths on the *control node* where certs are read from.
|
||||
# Override these in your inventory / group_vars as needed.
|
||||
|
||||
# Fleet CA certificate (public; safe to push to all nodes)
|
||||
fleet_mtls_ca_cert_src: "{{ lookup('env', 'HOME') }}/.hermes/pki/ca/fleet-ca.crt"
|
||||
|
||||
# Per-agent cert/key source dir on the control node.
|
||||
# Expected layout: <fleet_mtls_agent_certs_dir>/<agent_name>/<agent_name>.{crt,key}
|
||||
fleet_mtls_agent_certs_dir: "{{ lookup('env', 'HOME') }}/.hermes/pki/agents"
|
||||
|
||||
# Remote destination paths on the fleet node
|
||||
fleet_mtls_remote_pki_dir: "/etc/hermes/pki"
|
||||
fleet_mtls_remote_ca_dir: "{{ fleet_mtls_remote_pki_dir }}/ca"
|
||||
fleet_mtls_remote_agent_dir: "{{ fleet_mtls_remote_pki_dir }}/agent"
|
||||
|
||||
# The agent name to deploy (set per-host in inventory, e.g. timmy / allegro / ezra)
|
||||
fleet_mtls_agent_name: "{{ inventory_hostname_short }}"
|
||||
|
||||
# Hermes service name (for reload notification)
|
||||
fleet_mtls_hermes_service: "hermes-a2a"
|
||||
7
ansible/roles/fleet_mtls_certs/handlers/main.yml
Normal file
7
ansible/roles/fleet_mtls_certs/handlers/main.yml
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
- name: Restart hermes-a2a
|
||||
ansible.builtin.systemd:
|
||||
name: "{{ fleet_mtls_hermes_service }}"
|
||||
state: restarted
|
||||
when: ansible_service_mgr == "systemd"
|
||||
ignore_errors: true # service may not exist in all environments
|
||||
17
ansible/roles/fleet_mtls_certs/meta/main.yml
Normal file
17
ansible/roles/fleet_mtls_certs/meta/main.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
---
|
||||
galaxy_info:
|
||||
role_name: fleet_mtls_certs
|
||||
author: hermes-agent
|
||||
description: >
|
||||
Distribute fleet CA and per-agent mTLS certificates to Hermes fleet nodes.
|
||||
Part of issue #806 — A2A mutual TLS between fleet agents.
|
||||
min_ansible_version: "2.14"
|
||||
platforms:
|
||||
- name: Debian
|
||||
versions: [bookworm, bullseye]
|
||||
- name: Ubuntu
|
||||
versions: ["22.04", "24.04"]
|
||||
- name: EL
|
||||
versions: ["8", "9"]
|
||||
|
||||
dependencies: []
|
||||
99
ansible/roles/fleet_mtls_certs/tasks/main.yml
Normal file
99
ansible/roles/fleet_mtls_certs/tasks/main.yml
Normal file
@@ -0,0 +1,99 @@
|
||||
---
|
||||
# fleet_mtls_certs/tasks/main.yml
|
||||
#
|
||||
# Distribute the fleet CA certificate and the per-agent TLS cert+key to
|
||||
# each fleet node. Triggers a hermes-a2a service restart when any cert
|
||||
# changes.
|
||||
#
|
||||
# Refs #806 — A2A mutual TLS between fleet agents.
|
||||
|
||||
- name: Verify agent cert source files exist on control node
|
||||
ansible.builtin.stat:
|
||||
path: "{{ item }}"
|
||||
register: _src_stat
|
||||
delegate_to: localhost
|
||||
loop:
|
||||
- "{{ fleet_mtls_ca_cert_src }}"
|
||||
- "{{ fleet_mtls_agent_certs_dir }}/{{ fleet_mtls_agent_name }}/{{ fleet_mtls_agent_name }}.crt"
|
||||
- "{{ fleet_mtls_agent_certs_dir }}/{{ fleet_mtls_agent_name }}/{{ fleet_mtls_agent_name }}.key"
|
||||
loop_control:
|
||||
label: "{{ item | basename }}"
|
||||
|
||||
- name: Fail if any source cert is missing
|
||||
ansible.builtin.fail:
|
||||
msg: >
|
||||
Required cert file not found: {{ item.item }}
|
||||
Run scripts/gen_fleet_ca.sh and scripts/gen_agent_cert.sh --agent {{ fleet_mtls_agent_name }} first.
|
||||
when: not item.stat.exists
|
||||
loop: "{{ _src_stat.results }}"
|
||||
loop_control:
|
||||
label: "{{ item.item | basename }}"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Remote directory structure
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
- name: Create remote PKI directories
|
||||
ansible.builtin.file:
|
||||
path: "{{ item }}"
|
||||
state: directory
|
||||
owner: root
|
||||
group: root
|
||||
mode: "0750"
|
||||
loop:
|
||||
- "{{ fleet_mtls_remote_pki_dir }}"
|
||||
- "{{ fleet_mtls_remote_ca_dir }}"
|
||||
- "{{ fleet_mtls_remote_agent_dir }}"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Fleet CA certificate (public — read-only for all)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
- name: Deploy fleet CA certificate
|
||||
ansible.builtin.copy:
|
||||
src: "{{ fleet_mtls_ca_cert_src }}"
|
||||
dest: "{{ fleet_mtls_remote_ca_dir }}/fleet-ca.crt"
|
||||
owner: root
|
||||
group: root
|
||||
mode: "0644"
|
||||
notify: Restart hermes-a2a
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Per-agent certificate (public portion)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
- name: Deploy agent certificate
|
||||
ansible.builtin.copy:
|
||||
src: "{{ fleet_mtls_agent_certs_dir }}/{{ fleet_mtls_agent_name }}/{{ fleet_mtls_agent_name }}.crt"
|
||||
dest: "{{ fleet_mtls_remote_agent_dir }}/agent.crt"
|
||||
owner: root
|
||||
group: root
|
||||
mode: "0644"
|
||||
notify: Restart hermes-a2a
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Per-agent private key (secret — root-only read)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
- name: Deploy agent private key
|
||||
ansible.builtin.copy:
|
||||
src: "{{ fleet_mtls_agent_certs_dir }}/{{ fleet_mtls_agent_name }}/{{ fleet_mtls_agent_name }}.key"
|
||||
dest: "{{ fleet_mtls_remote_agent_dir }}/agent.key"
|
||||
owner: root
|
||||
group: root
|
||||
mode: "0600"
|
||||
no_log: true # suppress file content from Ansible output
|
||||
notify: Restart hermes-a2a
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Environment file for hermes-a2a systemd unit
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
- name: Write hermes-a2a environment file
|
||||
ansible.builtin.template:
|
||||
src: hermes_a2a_env.j2
|
||||
dest: /etc/hermes/a2a.env
|
||||
owner: root
|
||||
group: root
|
||||
mode: "0640"
|
||||
notify: Restart hermes-a2a
|
||||
10
ansible/roles/fleet_mtls_certs/templates/hermes_a2a_env.j2
Normal file
10
ansible/roles/fleet_mtls_certs/templates/hermes_a2a_env.j2
Normal file
@@ -0,0 +1,10 @@
|
||||
# Managed by Ansible — fleet_mtls_certs role
|
||||
# Environment variables for the hermes-a2a systemd service.
|
||||
# Source this file in the [Service] section: EnvironmentFile=/etc/hermes/a2a.env
|
||||
|
||||
HERMES_AGENT_NAME={{ fleet_mtls_agent_name }}
|
||||
HERMES_A2A_CERT={{ fleet_mtls_remote_agent_dir }}/agent.crt
|
||||
HERMES_A2A_KEY={{ fleet_mtls_remote_agent_dir }}/agent.key
|
||||
HERMES_A2A_CA={{ fleet_mtls_remote_ca_dir }}/fleet-ca.crt
|
||||
HERMES_A2A_HOST=0.0.0.0
|
||||
HERMES_A2A_PORT=9443
|
||||
21
ansible/roles/hermes_mtls/defaults/main.yml
Normal file
21
ansible/roles/hermes_mtls/defaults/main.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
---
|
||||
# Ansible role: hermes_mtls
|
||||
# Distributes fleet mTLS certificates to Hermes agent nodes.
|
||||
#
|
||||
# Required variables (set in inventory / group_vars / --extra-vars):
|
||||
# hermes_mtls_local_ca_cert Local path on the Ansible controller to fleet-ca.crt
|
||||
# hermes_mtls_local_agent_cert Local path to this agent's .crt file
|
||||
# hermes_mtls_local_agent_key Local path to this agent's .key file
|
||||
#
|
||||
# Optional overrides:
|
||||
hermes_mtls_cert_dir: /etc/hermes/certs
|
||||
hermes_mtls_cert_owner: hermes
|
||||
hermes_mtls_cert_group: hermes
|
||||
hermes_mtls_cert_mode: "0640"
|
||||
hermes_mtls_ca_cert_mode: "0644"
|
||||
|
||||
# Env file that Hermes reads on startup (systemd EnvironmentFile or .env)
|
||||
hermes_mtls_env_file: /etc/hermes/mtls.env
|
||||
|
||||
# Hermes systemd service name — restarted after cert changes
|
||||
hermes_mtls_service: hermes-gateway
|
||||
7
ansible/roles/hermes_mtls/handlers/main.yml
Normal file
7
ansible/roles/hermes_mtls/handlers/main.yml
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
- name: Restart hermes service
|
||||
ansible.builtin.systemd:
|
||||
name: "{{ hermes_mtls_service }}"
|
||||
state: restarted
|
||||
daemon_reload: true
|
||||
when: ansible_service_mgr == "systemd"
|
||||
16
ansible/roles/hermes_mtls/meta/main.yml
Normal file
16
ansible/roles/hermes_mtls/meta/main.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
galaxy_info:
|
||||
role_name: hermes_mtls
|
||||
author: Hermes Fleet
|
||||
description: Distribute mTLS certificates to Hermes fleet nodes for A2A authentication
|
||||
license: MIT
|
||||
min_ansible_version: "2.14"
|
||||
platforms:
|
||||
- name: Ubuntu
|
||||
versions: ["22.04", "24.04"]
|
||||
- name: Debian
|
||||
versions: ["12"]
|
||||
- name: EL
|
||||
versions: ["9"]
|
||||
|
||||
dependencies: []
|
||||
67
ansible/roles/hermes_mtls/tasks/main.yml
Normal file
67
ansible/roles/hermes_mtls/tasks/main.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
# hermes_mtls role — distribute fleet mTLS certificates to a Hermes agent node.
|
||||
#
|
||||
# This role:
|
||||
# 1. Creates the cert directory on the remote node
|
||||
# 2. Copies the Fleet CA cert, agent cert, and agent key
|
||||
# 3. Writes an env file with HERMES_MTLS_* variables
|
||||
# 4. Restarts the Hermes service if any cert changed
|
||||
|
||||
- name: Ensure cert directory exists
|
||||
ansible.builtin.file:
|
||||
path: "{{ hermes_mtls_cert_dir }}"
|
||||
state: directory
|
||||
owner: "{{ hermes_mtls_cert_owner }}"
|
||||
group: "{{ hermes_mtls_cert_group }}"
|
||||
mode: "0750"
|
||||
|
||||
- name: Copy Fleet CA certificate
|
||||
ansible.builtin.copy:
|
||||
src: "{{ hermes_mtls_local_ca_cert }}"
|
||||
dest: "{{ hermes_mtls_cert_dir }}/fleet-ca.crt"
|
||||
owner: "{{ hermes_mtls_cert_owner }}"
|
||||
group: "{{ hermes_mtls_cert_group }}"
|
||||
mode: "{{ hermes_mtls_ca_cert_mode }}"
|
||||
notify: Restart hermes service
|
||||
|
||||
- name: Copy agent TLS certificate
|
||||
ansible.builtin.copy:
|
||||
src: "{{ hermes_mtls_local_agent_cert }}"
|
||||
dest: "{{ hermes_mtls_cert_dir }}/agent.crt"
|
||||
owner: "{{ hermes_mtls_cert_owner }}"
|
||||
group: "{{ hermes_mtls_cert_group }}"
|
||||
mode: "{{ hermes_mtls_cert_mode }}"
|
||||
notify: Restart hermes service
|
||||
|
||||
- name: Copy agent TLS private key
|
||||
ansible.builtin.copy:
|
||||
src: "{{ hermes_mtls_local_agent_key }}"
|
||||
dest: "{{ hermes_mtls_cert_dir }}/agent.key"
|
||||
owner: "{{ hermes_mtls_cert_owner }}"
|
||||
group: "{{ hermes_mtls_cert_group }}"
|
||||
mode: "0600"
|
||||
notify: Restart hermes service
|
||||
|
||||
- name: Write mTLS environment file
|
||||
ansible.builtin.template:
|
||||
src: mtls.env.j2
|
||||
dest: "{{ hermes_mtls_env_file }}"
|
||||
owner: "{{ hermes_mtls_cert_owner }}"
|
||||
group: "{{ hermes_mtls_cert_group }}"
|
||||
mode: "0640"
|
||||
notify: Restart hermes service
|
||||
|
||||
- name: Verify cert files are readable by service user
|
||||
ansible.builtin.stat:
|
||||
path: "{{ item }}"
|
||||
loop:
|
||||
- "{{ hermes_mtls_cert_dir }}/fleet-ca.crt"
|
||||
- "{{ hermes_mtls_cert_dir }}/agent.crt"
|
||||
- "{{ hermes_mtls_cert_dir }}/agent.key"
|
||||
register: _cert_stat
|
||||
|
||||
- name: Assert all cert files exist
|
||||
ansible.builtin.assert:
|
||||
that: item.stat.exists
|
||||
fail_msg: "Expected cert file missing: {{ item.item }}"
|
||||
loop: "{{ _cert_stat.results }}"
|
||||
8
ansible/roles/hermes_mtls/templates/mtls.env.j2
Normal file
8
ansible/roles/hermes_mtls/templates/mtls.env.j2
Normal file
@@ -0,0 +1,8 @@
|
||||
# Hermes mTLS environment — generated by hermes_mtls Ansible role
|
||||
# Source this file or use as a systemd EnvironmentFile=
|
||||
# WARNING: This file contains the path to the agent's private key.
|
||||
# Restrict read access to the hermes service user.
|
||||
|
||||
HERMES_MTLS_CERT={{ hermes_mtls_cert_dir }}/agent.crt
|
||||
HERMES_MTLS_KEY={{ hermes_mtls_cert_dir }}/agent.key
|
||||
HERMES_MTLS_CA={{ hermes_mtls_cert_dir }}/fleet-ca.crt
|
||||
38
docs/cron-audit-890.md
Normal file
38
docs/cron-audit-890.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Cron Job Audit — #890
|
||||
|
||||
## Problem
|
||||
|
||||
9 of 69 cron jobs have zero completions. They waste scheduler cycles.
|
||||
|
||||
## Dead Jobs Identified
|
||||
|
||||
| Job | Schedule | Completions | Action |
|
||||
|-----|----------|-------------|--------|
|
||||
| exp-swarm-pipeline | every 10 min | 0 | disable |
|
||||
| exp-music-generator | every 2h | 0 | disable |
|
||||
| exp-paper-citations | every 3h | 0 | disable |
|
||||
| exp-gbrain-patterns | every 2h | 0 | disable |
|
||||
| exp-infra-hardening | every 2h | 0 | disable |
|
||||
| gemma4-multimodal-burn | every 1h | 0 | disable |
|
||||
| morning-paper-report | daily | 0 | disable |
|
||||
| overnight-collector | every 15 min | 0 | disable |
|
||||
| morning-experiment-report | daily | 0 | disable |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Audit: show dead jobs
|
||||
python scripts/cron_audit.py
|
||||
|
||||
# Dry run: preview changes
|
||||
python scripts/cron_audit.py --disable --dry-run
|
||||
|
||||
# Disable dead jobs
|
||||
python scripts/cron_audit.py --disable
|
||||
|
||||
# Delete dead jobs
|
||||
python scripts/cron_audit.py --delete
|
||||
|
||||
# Find jobs with no runs in 7 days
|
||||
python scripts/cron_audit.py --older-than 7d
|
||||
```
|
||||
56
docs/emotional-presence-patterns.md
Normal file
56
docs/emotional-presence-patterns.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Emotional Presence Patterns for Crisis Support
|
||||
|
||||
Research issue #880. What works for reaching broken men in their darkest moment.
|
||||
|
||||
## Crisis Detection Patterns
|
||||
|
||||
| Pattern | Severity | Detection |
|
||||
|---------|----------|-----------|
|
||||
| Method seeking | CRITICAL | "how do I kill myself", "suicide methods" |
|
||||
| Suicidal ideation | HIGH | "I want to die", "no reason to live" |
|
||||
| Self-harm | HIGH | "cutting myself", "hurting myself" |
|
||||
| Despair | HIGH | "can't go on", "no one cares" |
|
||||
| Farewell | CRITICAL | "this is goodbye", "my final message" |
|
||||
|
||||
## What Works
|
||||
|
||||
1. Validation without agreement: "That sounds incredibly painful"
|
||||
2. Presence over solutions: "I'm here with you right now"
|
||||
3. Specific next steps: "Text HOME to 741741 right now"
|
||||
4. Connection: "You don't have to face this alone"
|
||||
5. Hope without minimizing: "There is a way through this"
|
||||
|
||||
## What Does NOT Work
|
||||
|
||||
1. False empathy: "I feel your pain"
|
||||
2. Toxic positivity: "Look on the bright side!"
|
||||
3. Minimization: "Other people have it worse"
|
||||
4. Guilt induction: "Think about your family"
|
||||
5. Premature problem-solving: "Have you tried therapy?"
|
||||
6. AI dishonesty: pretending to be human
|
||||
|
||||
## SOUL.md Protocol
|
||||
|
||||
1. Ask: "Are you safe right now?" (FIRST response)
|
||||
2. Listen. Stay present. Don't disconnect.
|
||||
3. Never compute the value of a human life
|
||||
4. Never suggest death is an option
|
||||
5. Always point to 988 Lifeline
|
||||
6. Share the gospel: "Jesus saves those who call on His name"
|
||||
|
||||
## Crisis Resources
|
||||
|
||||
- 988 Suicide & Crisis Lifeline: call or text 988
|
||||
- Crisis Text Line: text HOME to 741741
|
||||
- 988 Chat: 988lifeline.org/chat
|
||||
- Spanish: 1-888-628-9454
|
||||
- Emergency: 911
|
||||
|
||||
## Implementation Status
|
||||
|
||||
- Crisis detection: agent/crisis_protocol.py
|
||||
- SHIELD integration: tools/shield/
|
||||
- 988 Lifeline: resources defined
|
||||
- Emotional presence: this document
|
||||
- Escalation tracking: future work
|
||||
- Human notification: future work
|
||||
42
docs/holographic-vector-hybrid.md
Normal file
42
docs/holographic-vector-hybrid.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Holographic + Vector Hybrid Memory Architecture
|
||||
|
||||
Research issue #879. Combining HRR (holographic) and vector (Qdrant) memory.
|
||||
|
||||
## Architecture
|
||||
|
||||
Three memory backends, each with unique strengths:
|
||||
|
||||
| Backend | Strength | Weakness | Use Case |
|
||||
|---------|----------|----------|----------|
|
||||
| FTS5 | Exact keyword match | No semantic understanding | Precise recall |
|
||||
| Vector (Qdrant) | Semantic similarity | No compositional queries | Topic search |
|
||||
| HRR (Holographic) | Compositional queries | Limited scale | Complex reasoning |
|
||||
|
||||
## Why Hybrid
|
||||
|
||||
- FTS5 alone: misses ~30-40% of semantically relevant content
|
||||
- Vector alone: can't do compositional queries ("what did I discuss about X after doing Y?")
|
||||
- HRR alone: unique capability but no semantic fallback
|
||||
- Hybrid: best of all three, RRF fusion for ranking
|
||||
|
||||
## Implementation: Reciprocal Rank Fusion
|
||||
|
||||
Results from each backend are merged using RRF:
|
||||
- score = sum(weight / (k + rank)) for each backend
|
||||
- k=60 (standard RRF constant)
|
||||
- Weights: FTS5=0.6, Vector=0.4 (configurable)
|
||||
|
||||
## Status
|
||||
|
||||
- FTS5: EXISTS (hermes_state.py)
|
||||
- Vector (Qdrant): implemented (tools/hybrid_search.py)
|
||||
- HRR: EXISTS (plugins/memory/holographic.py)
|
||||
- RRF fusion: implemented (tools/hybrid_search.py)
|
||||
- Ingestion pipeline: partial
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Wire HRR into hybrid_search.py
|
||||
2. Session-level vector ingestion
|
||||
3. Benchmark: measure R@5 improvement
|
||||
4. Cross-session memory persistence
|
||||
24
docs/tool-investigation-report.md
Normal file
24
docs/tool-investigation-report.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Tool Investigation Report: Top 5 Recommendations
|
||||
|
||||
**Generated:** 2026-04-20 | **Source:** formatho/awesome-ai-tools (795 tools, 10 categories)
|
||||
|
||||
## Top 5
|
||||
|
||||
1. **LiteLLM** (76k) — Unified API gateway. Replace custom provider routing. Impact: 5/5, Effort: 2/5
|
||||
2. **Mem0** (53k) — Universal memory layer. Structured long-term memory. Impact: 5/5, Effort: 3/5
|
||||
3. **RAGFlow** (77k) — RAG engine with OCR. Document processing upgrade. Impact: 4/5, Effort: 4/5
|
||||
4. **LiteRT-LM** (3.7k) — On-device inference. Edge/mobile deployment. Impact: 4/5, Effort: 3/5
|
||||
5. **Claude-Mem** (61k) — Session capture and context injection. Impact: 3/5, Effort: 2/5
|
||||
|
||||
## Priority
|
||||
|
||||
- Phase 1: LiteLLM (2-3 days, highest ROI)
|
||||
- Phase 2: Mem0 (1 week, critical for agent maturity)
|
||||
- Phase 3: RAGFlow (1-2 weeks, capability upgrade)
|
||||
|
||||
## Honorable Mentions
|
||||
|
||||
- GPTCache: Semantic cache, 30-50% cost reduction
|
||||
- promptfoo: LLM testing framework
|
||||
- PageIndex: Vectorless RAG
|
||||
- rtk: Token reduction proxy, 60-90% savings
|
||||
@@ -8,6 +8,7 @@ Handles loading and validating configuration for:
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
@@ -679,6 +680,26 @@ def load_gateway_config() -> GatewayConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose a server beyond the loopback interface.
|
||||
|
||||
Duplicates the logic in ``gateway.platforms.base.is_network_accessible``
|
||||
without creating a circular import (base.py imports from this module).
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.x.x.x — Python's is_loopback returns False for
|
||||
# IPv4-mapped loopback; unwrap and check the underlying IPv4.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# Hostname: assume it could be network-accessible.
|
||||
return True
|
||||
|
||||
|
||||
def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
"""Validate and sanitize a loaded GatewayConfig in place.
|
||||
|
||||
@@ -747,6 +768,22 @@ def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
)
|
||||
pconfig.enabled = False
|
||||
|
||||
# Warn when the API server is enabled on a network-accessible address
|
||||
# without an auth key. The adapter will refuse to start anyway, but
|
||||
# surfacing this at config-load time lets operators see the problem in
|
||||
# the startup log before any platform adapter initialisation runs.
|
||||
api_cfg = config.platforms.get(Platform.API_SERVER)
|
||||
if api_cfg and api_cfg.enabled:
|
||||
key = api_cfg.extra.get("key", "")
|
||||
host = api_cfg.extra.get("host", "127.0.0.1")
|
||||
if not key and _is_network_accessible(host):
|
||||
logger.warning(
|
||||
"API Server is enabled on %s but API_SERVER_KEY is not set. "
|
||||
"The adapter will refuse to start on a network-accessible address. "
|
||||
"Set API_SERVER_KEY or bind to 127.0.0.1 for local-only access.",
|
||||
host,
|
||||
)
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
224
gateway/config_validator.py
Normal file
224
gateway/config_validator.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Gateway Config Validator & Fallback Fix — #892.
|
||||
|
||||
Validates gateway configuration and provides sensible defaults
|
||||
for missing keys to prevent fallback chain breaks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigIssue:
|
||||
"""A configuration issue found during validation."""
|
||||
key: str
|
||||
severity: str # error, warning, info
|
||||
message: str
|
||||
fix: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigValidation:
|
||||
"""Result of config validation."""
|
||||
valid: bool
|
||||
issues: List[ConfigIssue] = field(default_factory=list)
|
||||
warnings: int = 0
|
||||
errors: int = 0
|
||||
|
||||
|
||||
# Required keys and their defaults
|
||||
REQUIRED_KEYS = {
|
||||
"OPENROUTER_API_KEY": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "warning",
|
||||
"message": "OPENROUTER_API_KEY not set - fallback chain may break",
|
||||
"fix": "Set OPENROUTER_API_KEY in .env for OpenRouter provider",
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "warning",
|
||||
"message": "API_SERVER_KEY not configured",
|
||||
"fix": "Set API_SERVER_KEY in .env for API server auth",
|
||||
},
|
||||
"GITEA_TOKEN": {
|
||||
"required": False,
|
||||
"default": "",
|
||||
"severity": "info",
|
||||
"message": "GITEA_TOKEN not set - Gitea features disabled",
|
||||
"fix": "Set GITEA_TOKEN in .env for Gitea integration",
|
||||
},
|
||||
}
|
||||
|
||||
# Config validation rules
|
||||
VALIDATION_RULES = [
|
||||
{
|
||||
"key": "idle_minutes",
|
||||
"validate": lambda v: isinstance(v, (int, float)) and v > 0,
|
||||
"message": "Invalid idle_minutes={v} - must be > 0",
|
||||
"fix": "Set idle_minutes to positive integer (default: 30)",
|
||||
},
|
||||
{
|
||||
"key": "max_skills_discord",
|
||||
"validate": lambda v: isinstance(v, int) and v <= 100,
|
||||
"message": "Discord slash command limit reached ({v}/100) - skills not registered",
|
||||
"fix": "Reduce skills or paginate registration",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def validate_config(config: Dict[str, Any]) -> ConfigValidation:
|
||||
"""
|
||||
Validate gateway configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
ConfigValidation with issues found
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# Check required keys
|
||||
for key, spec in REQUIRED_KEYS.items():
|
||||
value = config.get(key) or os.environ.get(key) or spec["default"]
|
||||
if spec["required"] and not value:
|
||||
issues.append(ConfigIssue(
|
||||
key=key,
|
||||
severity=spec["severity"],
|
||||
message=spec["message"],
|
||||
fix=spec["fix"],
|
||||
))
|
||||
elif not value and spec["severity"] != "error":
|
||||
issues.append(ConfigIssue(
|
||||
key=key,
|
||||
severity=spec["severity"],
|
||||
message=spec["message"],
|
||||
fix=spec["fix"],
|
||||
))
|
||||
|
||||
# Check validation rules
|
||||
for rule in VALIDATION_RULES:
|
||||
value = config.get(rule["key"])
|
||||
if value is not None:
|
||||
if not rule["validate"](value):
|
||||
issues.append(ConfigIssue(
|
||||
key=rule["key"],
|
||||
severity="error",
|
||||
message=rule["message"].format(v=value),
|
||||
fix=rule["fix"],
|
||||
))
|
||||
|
||||
errors = sum(1 for i in issues if i.severity == "error")
|
||||
warnings = sum(1 for i in issues if i.severity == "warning")
|
||||
|
||||
return ConfigValidation(
|
||||
valid=errors == 0,
|
||||
issues=issues,
|
||||
warnings=warnings,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
def apply_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Apply default values for missing config keys.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Config with defaults applied
|
||||
"""
|
||||
result = dict(config)
|
||||
|
||||
for key, spec in REQUIRED_KEYS.items():
|
||||
if key not in result or not result[key]:
|
||||
default = os.environ.get(key) or spec["default"]
|
||||
if default:
|
||||
result[key] = default
|
||||
logger.debug("Applied default for %s", key)
|
||||
|
||||
# Apply validation defaults
|
||||
if "idle_minutes" not in result or not result["idle_minutes"] or result["idle_minutes"] <= 0:
|
||||
result["idle_minutes"] = 30
|
||||
logger.debug("Applied default idle_minutes=30")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fix_discord_skill_limit(skills: List[str], max_skills: int = 95) -> List[str]:
|
||||
"""
|
||||
Fix Discord slash command limit by reducing skills.
|
||||
|
||||
Args:
|
||||
skills: List of skill names
|
||||
max_skills: Maximum skills to register (default 95, leaving room for built-ins)
|
||||
|
||||
Returns:
|
||||
Reduced skill list
|
||||
"""
|
||||
if len(skills) <= max_skills:
|
||||
return skills
|
||||
|
||||
logger.warning(
|
||||
"Discord skill limit: %d skills exceeds %d limit, truncating",
|
||||
len(skills), max_skills
|
||||
)
|
||||
|
||||
# Keep first max_skills (alphabetical priority)
|
||||
return sorted(skills)[:max_skills]
|
||||
|
||||
|
||||
def validate_provider_config(provider: str, config: Dict[str, Any]) -> ConfigIssue:
|
||||
"""
|
||||
Validate provider-specific configuration.
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
config: Provider config
|
||||
|
||||
Returns:
|
||||
ConfigIssue if invalid, None if valid
|
||||
"""
|
||||
if provider == "local-llama.cpp":
|
||||
# Check if llama.cpp is configured
|
||||
if not config.get("model_path") and not config.get("base_url"):
|
||||
return ConfigIssue(
|
||||
key=f"provider.{provider}",
|
||||
severity="warning",
|
||||
message=f"{provider} provider not configured - fallback fails",
|
||||
fix=f"Configure {provider} model_path or base_url, or remove from provider list",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def format_validation_report(validation: ConfigValidation) -> str:
|
||||
"""Format validation results as a report."""
|
||||
lines = [
|
||||
"=" * 50,
|
||||
"GATEWAY CONFIG VALIDATION",
|
||||
"=" * 50,
|
||||
"",
|
||||
f"Status: {'VALID' if validation.valid else 'INVALID'}",
|
||||
f"Errors: {validation.errors}",
|
||||
f"Warnings: {validation.warnings}",
|
||||
"",
|
||||
]
|
||||
|
||||
if validation.issues:
|
||||
lines.append("Issues:")
|
||||
for issue in validation.issues:
|
||||
icon = "❌" if issue.severity == "error" else "⚠️" if issue.severity == "warning" else "ℹ️"
|
||||
lines.append(f" {icon} [{issue.key}] {issue.message}")
|
||||
lines.append(f" Fix: {issue.fix}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -46,6 +46,7 @@ from hermes_cli.config import (
|
||||
)
|
||||
from gateway.status import get_running_pid, read_runtime_status
|
||||
from agent.agent_card import get_agent_card_json
|
||||
from agent.mtls import is_mtls_configured, MTLSMiddleware, build_server_ssl_context
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
@@ -87,6 +88,10 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# mTLS: enforce client certificate on A2A endpoints when configured.
|
||||
# Activated by setting HERMES_MTLS_CERT, HERMES_MTLS_KEY, HERMES_MTLS_CA.
|
||||
app.add_middleware(MTLSMiddleware)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints that do NOT require the session token. Everything else under
|
||||
# /api/ is gated by the auth middleware below. Keep this list minimal —
|
||||
@@ -2105,6 +2110,20 @@ def start_server(
|
||||
"authentication. Only use on trusted networks.", host,
|
||||
)
|
||||
|
||||
# mTLS: when configured, pass SSL context to uvicorn so all connections
|
||||
# are TLS with mandatory client certificate verification.
|
||||
ssl_context = None
|
||||
scheme = "http"
|
||||
if is_mtls_configured():
|
||||
try:
|
||||
ssl_context = build_server_ssl_context()
|
||||
scheme = "https"
|
||||
_log.info(
|
||||
"mTLS enabled — server requires client certificates (A2A auth)"
|
||||
)
|
||||
except Exception as exc:
|
||||
_log.error("Failed to build mTLS SSL context: %s — starting without TLS", exc)
|
||||
|
||||
if open_browser:
|
||||
import threading
|
||||
import webbrowser
|
||||
@@ -2112,9 +2131,11 @@ def start_server(
|
||||
def _open():
|
||||
import time as _t
|
||||
_t.sleep(1.0)
|
||||
webbrowser.open(f"http://{host}:{port}")
|
||||
webbrowser.open(f"{scheme}://{host}:{port}")
|
||||
|
||||
threading.Thread(target=_open, daemon=True).start()
|
||||
|
||||
print(f" Hermes Web UI → http://{host}:{port}")
|
||||
uvicorn.run(app, host=host, port=port, log_level="warning")
|
||||
print(f" Hermes Web UI → {scheme}://{host}:{port}")
|
||||
if ssl_context is not None:
|
||||
print(" mTLS enabled — client certificate required for A2A endpoints")
|
||||
uvicorn.run(app, host=host, port=port, log_level="warning", ssl=ssl_context)
|
||||
|
||||
@@ -27,7 +27,9 @@ import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import discover_builtin_tools, registry
|
||||
from tools.poka_yoke import validate_tool_call
|
||||
from tools.tool_pokayoke import validate_tool_call, reset_circuit_breaker, get_hallucination_stats
|
||||
from tools.hardcoded_path_guard import guard_tool_dispatch as _guard_hardcoded_paths
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
from agent.tool_orchestrator import orchestrator
|
||||
|
||||
@@ -501,21 +503,14 @@ def handle_function_call(
|
||||
# Prefer the caller-provided list so subagents can't overwrite
|
||||
# the parent's tool set via the process-global.
|
||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||
# Poka-yoke #921: guard against hardcoded home-directory paths
|
||||
_hardcoded_err = _guard_hardcoded_paths(function_name, function_args)
|
||||
if _hardcoded_err:
|
||||
logger.warning(f"Hardcoded path blocked: {function_name}")
|
||||
return _hardcoded_err
|
||||
|
||||
# Poka-yoke: validate tool call before dispatch
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
# Return structured error with suggestions
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
||||
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
||||
if corrected_name:
|
||||
function_name = corrected_name
|
||||
if corrected_params:
|
||||
function_args = corrected_params
|
||||
if pokayoke_messages:
|
||||
logger.info(f"Poka-yoke: {pokayoke_messages}")
|
||||
# Poka-yoke: validate tool call before dispatch (else branch)
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
# Return structured error with suggestions
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
@@ -533,6 +528,16 @@ def handle_function_call(
|
||||
enabled_tools=sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
# Poka-yoke: validate tool call before dispatch
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
||||
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
||||
if corrected_name:
|
||||
function_name = corrected_name
|
||||
if corrected_params:
|
||||
function_args = corrected_params
|
||||
result = orchestrator.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
|
||||
68
research_awesome_ai_tools_top5.md
Normal file
68
research_awesome_ai_tools_top5.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Tool Investigation Report: Top 5 Recommendations from awesome-ai-tools
|
||||
|
||||
**Generated:** 2026-04-20 | **Source:** [formatho/awesome-ai-tools](https://github.com/formatho/awesome-ai-tools)
|
||||
|
||||
---
|
||||
|
||||
## Methodology
|
||||
|
||||
Scanned 795 tools across 10 categories from the awesome-ai-tools repository. Evaluated each tool against Hermes Agent's architecture and needs:
|
||||
- **Memory/Context**: Persistent memory, conversation history, knowledge graphs
|
||||
- **Inference Optimization**: Token efficiency, local model serving, routing
|
||||
- **Agent Orchestration**: Multi-agent coordination, fleet management
|
||||
- **Workflow Automation**: Task decomposition, scheduling, pipelines
|
||||
- **Retrieval/RAG**: Semantic search, document understanding, context injection
|
||||
|
||||
Each tool scored on: GitHub stars, development activity (freshness), integration potential, and impact on Hermes.
|
||||
|
||||
---
|
||||
|
||||
## Top 5 Recommended Tools
|
||||
|
||||
| Rank | Tool | Stars | Category | Integration Effort | Impact | Why It Fits Hermes |
|
||||
|------|------|-------|----------|-------------------|--------|---------------------|
|
||||
| 1 | **[LiteLLM](https://github.com/BerriAI/litellm)** | 76k+ | Inference Optimization | 2/5 | 5/5 | Unified API gateway for 100+ LLM providers with cost tracking, guardrails, load balancing, and logging. Hermes already routes through multiple providers — LiteLLM could replace custom provider routing with battle-tested load balancing and automatic fallback. Direct drop-in for `provider` abstraction layer. Native support for Bedrock, Azure, OpenAI, VertexAI, Anthropic, Ollama, vLLM. Would reduce Hermes's provider management code by ~60%. |
|
||||
| 2 | **[Mem0](https://github.com/mem0ai/mem0)** | 53k+ | Memory/Context | 3/5 | 5/5 | Universal memory layer for AI agents with persistent, searchable memory across sessions. Hermes has session memory but lacks a structured long-term memory system. Mem0 provides automatic memory extraction from conversations, semantic search over memories, and memory decay/pruning. Could replace/enhance the current memory tool with a purpose-built agent memory infrastructure. Supports Pinecone, Qdrant, ChromaDB backends. |
|
||||
| 3 | **[RAGFlow](https://github.com/infiniflow/ragflow)** | 77k+ | Retrieval/RAG | 4/5 | 4/5 | Open-source RAG engine with deep document understanding, OCR, and agent capabilities. Hermes's current retrieval is limited to web search and file reading. RAGFlow adds visual document parsing (PDF/Word/PPT with tables, charts, formulas), chunk-level citation, and configurable retrieval strategies. Would massively upgrade Hermes's document processing capabilities. Docker-deployable, compatible with local models. |
|
||||
| 4 | **[LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM)** | 3.7k | Inference Optimization | 3/5 | 4/5 | C++ implementation of Google's LiteRT for efficient on-device language model inference. Hermes supports local models via Ollama but lacks optimized on-device inference for edge/mobile. LiteRT-LM provides sub-second inference on commodity hardware with minimal memory footprint. Could power a "Hermes lite" mode for offline/edge deployments. Active development (Fresh status), backed by Google AI Edge team. |
|
||||
| 5 | **[Claude-Mem](https://github.com/thedotmack/claude-mem)** | 61k+ | Memory/Context | 2/5 | 3/5 | Automatic session capture and context injection for coding agents. Compresses session history with AI and injects relevant context into future sessions. Pattern directly applicable to Hermes's cross-session persistence problem. Uses agent SDK for intelligent compression — could enhance Hermes's session_search with automatic relevance-weighted recall. Lightweight integration, focused on the exact pain point of context loss between sessions. |
|
||||
|
||||
---
|
||||
|
||||
## Category Coverage Analysis
|
||||
|
||||
| Category | Tools Scanned | Top Pick | Coverage Gap |
|
||||
|----------|--------------|----------|-------------|
|
||||
| Memory/Context | 45+ | Mem0 (53k⭐) | Hermes lacks structured long-term memory — Mem0 or Claude-Mem would fill this |
|
||||
| Inference Optimization | 80+ | LiteLLM (76k⭐) | Provider routing is custom-built; LiteLLM standardizes it |
|
||||
| Agent Orchestration | 120+ | langgraph (29k⭐) | Hermes's fleet model is unique — langgraph patterns could improve DAG workflows |
|
||||
| Workflow Automation | 90+ | n8n (183k⭐) | Cron system exists but n8n patterns could improve visual pipeline design |
|
||||
| Retrieval/RAG | 60+ | RAGFlow (77k⭐) | Document processing is weak; RAGFlow adds OCR + visual parsing |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
**Phase 1 (Immediate):** LiteLLM integration — highest impact, lowest effort. Replace custom provider routing with LiteLLM's unified API. Estimated: 2-3 days.
|
||||
|
||||
**Phase 2 (Short-term):** Mem0 memory layer — critical for agent maturity. Add structured memory extraction and retrieval. Estimated: 1 week.
|
||||
|
||||
**Phase 3 (Medium-term):** RAGFlow document engine — significant capability upgrade. Requires Docker setup and integration with existing file tools. Estimated: 1-2 weeks.
|
||||
|
||||
---
|
||||
|
||||
## Honorable Mentions
|
||||
|
||||
- **[GPTCache](https://github.com/zilliztech/GPTCache)** (8k⭐): Semantic cache for LLMs — could reduce API costs by 30-50% for repeated queries
|
||||
- **[promptfoo](https://github.com/promptfoo/promptfoo)** (20k⭐): LLM testing/evaluation framework — essential for quality assurance
|
||||
- **[PageIndex](https://github.com/VectifyAI/PageIndex)** (25k⭐): Vectorless reasoning-based RAG — next-gen retrieval without embeddings
|
||||
- **[rtk](https://github.com/rtk-ai/rtk)** (28k⭐): CLI proxy that reduces token consumption 60-90% — directly relevant to cost optimization
|
||||
|
||||
---
|
||||
|
||||
## Data Sources
|
||||
|
||||
- Repository: https://github.com/formatho/awesome-ai-tools
|
||||
- Total tools cataloged: 795
|
||||
- Categories analyzed: Agents & Automation, Developer Tools, LLMs & Chatbots, Research & Data, Productivity
|
||||
- Freshness filter: Prioritized tools with Fresh (≤7d) or Recent (≤30d) status
|
||||
181
scripts/cron_audit.py
Normal file
181
scripts/cron_audit.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
cron-audit — Audit and clean up dead cron jobs.
|
||||
|
||||
Finds jobs with zero completions, low success rates, or stale schedules.
|
||||
Can disable or delete dead jobs.
|
||||
|
||||
Usage:
|
||||
python scripts/cron_audit.py # Show dead jobs
|
||||
python scripts/cron_audit.py --disable # Disable dead jobs
|
||||
python scripts/cron_audit.py --delete # Delete dead jobs
|
||||
python scripts/cron_audit.py --threshold 0 # Jobs with 0 completions
|
||||
python scripts/cron_audit.py --older-than 7d # Jobs with no runs in 7 days
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
JOBS_FILE = HERMES_HOME / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def load_jobs() -> List[Dict[str, Any]]:
|
||||
"""Load cron jobs from jobs.json."""
|
||||
if not JOBS_FILE.exists():
|
||||
print(f"Error: {JOBS_FILE} not found")
|
||||
return []
|
||||
with open(JOBS_FILE) as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
"""Save jobs back to jobs.json."""
|
||||
JOBS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(JOBS_FILE, "r") as f:
|
||||
data = json.load(f)
|
||||
data["jobs"] = jobs
|
||||
with open(JOBS_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def find_dead_jobs(
|
||||
jobs: List[Dict[str, Any]],
|
||||
completion_threshold: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find jobs with completions at or below threshold."""
|
||||
dead = []
|
||||
for job in jobs:
|
||||
repeat = job.get("repeat", {})
|
||||
completed = repeat.get("completed", 0)
|
||||
if completed <= completion_threshold:
|
||||
dead.append(job)
|
||||
return dead
|
||||
|
||||
|
||||
def find_stale_jobs(
|
||||
jobs: List[Dict[str, Any]],
|
||||
max_age_hours: float = 168, # 7 days
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find jobs that haven't run in max_age_hours."""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stale = []
|
||||
now = time.time()
|
||||
|
||||
for job in jobs:
|
||||
last_run = job.get("last_run_at")
|
||||
if not last_run:
|
||||
# Never ran — check creation time
|
||||
created = job.get("created_at")
|
||||
if created:
|
||||
try:
|
||||
dt = datetime.fromisoformat(created.replace("Z", "+00:00"))
|
||||
age_hours = (now - dt.timestamp()) / 3600
|
||||
if age_hours > max_age_hours:
|
||||
stale.append(job)
|
||||
except Exception:
|
||||
stale.append(job)
|
||||
else:
|
||||
stale.append(job)
|
||||
else:
|
||||
try:
|
||||
dt = datetime.fromisoformat(last_run.replace("Z", "+00:00"))
|
||||
age_hours = (now - dt.timestamp()) / 3600
|
||||
if age_hours > max_age_hours:
|
||||
stale.append(job)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return stale
|
||||
|
||||
|
||||
def format_job(job: Dict[str, Any]) -> str:
|
||||
"""Format a job for display."""
|
||||
name = job.get("name", job.get("id", "?"))
|
||||
schedule = job.get("schedule_display", "?")
|
||||
repeat = job.get("repeat", {})
|
||||
completed = repeat.get("completed", 0)
|
||||
times = repeat.get("times")
|
||||
enabled = job.get("enabled", True)
|
||||
state = job.get("state", "unknown")
|
||||
last_run = job.get("last_run_at", "never")
|
||||
|
||||
status = "enabled" if enabled else "disabled"
|
||||
if state == "paused":
|
||||
status = "paused"
|
||||
|
||||
repeat_str = f"{completed}/{times}" if times else f"{completed}/∞"
|
||||
|
||||
return f" {name:40s} | {schedule:20s} | done: {repeat_str:8s} | {status}"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Audit and clean up dead cron jobs")
|
||||
parser.add_argument("--disable", action="store_true", help="Disable dead jobs")
|
||||
parser.add_argument("--delete", action="store_true", help="Delete dead jobs")
|
||||
parser.add_argument("--threshold", type=int, default=0, help="Completion threshold (default: 0)")
|
||||
parser.add_argument("--older-than", type=str, help="Find jobs with no runs in N days (e.g., 7d)")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Show what would change")
|
||||
args = parser.parse_args()
|
||||
|
||||
jobs = load_jobs()
|
||||
if not jobs:
|
||||
print("No jobs found.")
|
||||
return
|
||||
|
||||
print(f"Total jobs: {len(jobs)}")
|
||||
|
||||
# Find dead jobs
|
||||
dead = find_dead_jobs(jobs, args.threshold)
|
||||
print(f"Jobs with <= {args.threshold} completions: {len(dead)}")
|
||||
|
||||
if args.older_than:
|
||||
days = int(args.older_than.rstrip("d"))
|
||||
stale = find_stale_jobs(jobs, max_age_hours=days * 24)
|
||||
print(f"Jobs with no runs in {days} days: {len(stale)}")
|
||||
dead = list({j["id"]: j for j in dead + stale}.values())
|
||||
|
||||
if not dead:
|
||||
print("No dead jobs found.")
|
||||
return
|
||||
|
||||
print(f"\nDead jobs ({len(dead)}):")
|
||||
for job in dead:
|
||||
print(format_job(job))
|
||||
|
||||
if args.disable:
|
||||
if args.dry_run:
|
||||
print(f"\nDRY RUN: Would disable {len(dead)} jobs")
|
||||
return
|
||||
|
||||
job_ids = {j["id"] for j in dead}
|
||||
for job in jobs:
|
||||
if job["id"] in job_ids:
|
||||
job["enabled"] = False
|
||||
job["state"] = "disabled"
|
||||
|
||||
save_jobs(jobs)
|
||||
print(f"\nDisabled {len(dead)} jobs.")
|
||||
|
||||
elif args.delete:
|
||||
if args.dry_run:
|
||||
print(f"\nDRY RUN: Would delete {len(dead)} jobs")
|
||||
return
|
||||
|
||||
job_ids = {j["id"] for j in dead}
|
||||
jobs = [j for j in jobs if j["id"] not in job_ids]
|
||||
save_jobs(jobs)
|
||||
print(f"\nDeleted {len(dead)} jobs.")
|
||||
|
||||
else:
|
||||
print(f"\nUse --disable or --delete to take action. Add --dry-run to preview.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
129
scripts/gen_agent_cert.sh
Normal file
129
scripts/gen_agent_cert.sh
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env bash
|
||||
# gen_agent_cert.sh — Generate a TLS certificate for a fleet agent.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/gen_agent_cert.sh --agent <name> [--ca-dir <dir>] [--out-dir <dir>]
|
||||
#
|
||||
# Known agents: timmy, allegro, ezra (case-insensitive; any name is accepted)
|
||||
#
|
||||
# Outputs (default: ~/.hermes/pki/agents/<name>/):
|
||||
# <name>.key — agent private key (chmod 600, stays on the agent host)
|
||||
# <name>.crt — agent certificate (signed by the fleet CA)
|
||||
#
|
||||
# Run gen_fleet_ca.sh first if you haven't already.
|
||||
# Refs #806
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
CERT_DAYS=365 # 1 year; rotate annually
|
||||
KEY_BITS=2048
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse args
|
||||
# ---------------------------------------------------------------------------
|
||||
AGENT_NAME=""
|
||||
CA_DIR="${HOME}/.hermes/pki/ca"
|
||||
OUT_DIR=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--agent) AGENT_NAME="${2,,}"; shift 2 ;; # lower-case
|
||||
--ca-dir) CA_DIR="$2"; shift 2 ;;
|
||||
--out-dir) OUT_DIR="$2"; shift 2 ;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 --agent <name> [--ca-dir <dir>] [--out-dir <dir>]"
|
||||
echo " Known agents: timmy, allegro, ezra"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$AGENT_NAME" ]]; then
|
||||
echo "ERROR: --agent <name> is required." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
OUT_DIR="${OUT_DIR:-${HOME}/.hermes/pki/agents/${AGENT_NAME}}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prereq check
|
||||
# ---------------------------------------------------------------------------
|
||||
if ! command -v openssl &>/dev/null; then
|
||||
echo "ERROR: openssl not found." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CA_KEY="$CA_DIR/fleet-ca.key"
|
||||
CA_CRT="$CA_DIR/fleet-ca.crt"
|
||||
|
||||
if [[ ! -f "$CA_KEY" || ! -f "$CA_CRT" ]]; then
|
||||
echo "ERROR: Fleet CA not found in $CA_DIR" >&2
|
||||
echo " Run scripts/gen_fleet_ca.sh first." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "$OUT_DIR"
|
||||
chmod 700 "$OUT_DIR"
|
||||
|
||||
AGENT_KEY="$OUT_DIR/${AGENT_NAME}.key"
|
||||
AGENT_CRT="$OUT_DIR/${AGENT_NAME}.crt"
|
||||
AGENT_CSR="$OUT_DIR/${AGENT_NAME}.csr"
|
||||
|
||||
if [[ -f "$AGENT_KEY" || -f "$AGENT_CRT" ]]; then
|
||||
echo "Cert for agent '$AGENT_NAME' already exists in $OUT_DIR"
|
||||
echo " $AGENT_KEY"
|
||||
echo " $AGENT_CRT"
|
||||
echo "Delete them manually if you want to regenerate."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Generating cert for agent '$AGENT_NAME' ..."
|
||||
|
||||
SUBJECT="/CN=${AGENT_NAME}.fleet.hermes/O=Hermes/OU=Fleet Agent"
|
||||
|
||||
# Agent private key
|
||||
openssl genrsa -out "$AGENT_KEY" "$KEY_BITS" 2>/dev/null
|
||||
chmod 600 "$AGENT_KEY"
|
||||
|
||||
# Certificate Signing Request
|
||||
openssl req -new \
|
||||
-key "$AGENT_KEY" \
|
||||
-out "$AGENT_CSR" \
|
||||
-subj "$SUBJECT" 2>/dev/null
|
||||
|
||||
# Sign with fleet CA — include SAN so modern TLS stacks accept it
|
||||
EXT_CONF=$(mktemp)
|
||||
trap 'rm -f "$EXT_CONF" "$AGENT_CSR"' EXIT
|
||||
|
||||
cat > "$EXT_CONF" <<EOF
|
||||
[v3_agent]
|
||||
basicConstraints = CA:FALSE
|
||||
keyUsage = critical, digitalSignature, keyEncipherment
|
||||
extendedKeyUsage = clientAuth, serverAuth
|
||||
subjectKeyIdentifier = hash
|
||||
authorityKeyIdentifier = keyid,issuer
|
||||
subjectAltName = DNS:${AGENT_NAME}.fleet.hermes, DNS:${AGENT_NAME}
|
||||
EOF
|
||||
|
||||
openssl x509 -req \
|
||||
-in "$AGENT_CSR" \
|
||||
-CA "$CA_CRT" \
|
||||
-CAkey "$CA_KEY" \
|
||||
-CAcreateserial \
|
||||
-out "$AGENT_CRT" \
|
||||
-days "$CERT_DAYS" \
|
||||
-extfile "$EXT_CONF" \
|
||||
-extensions v3_agent 2>/dev/null
|
||||
|
||||
chmod 644 "$AGENT_CRT"
|
||||
|
||||
echo ""
|
||||
echo "Agent cert generated:"
|
||||
echo " Private key : $AGENT_KEY"
|
||||
echo " Certificate : $AGENT_CRT"
|
||||
echo ""
|
||||
openssl x509 -in "$AGENT_CRT" -noout -subject -issuer -dates
|
||||
83
scripts/gen_fleet_ca.sh
Normal file
83
scripts/gen_fleet_ca.sh
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env bash
|
||||
# gen_fleet_ca.sh — Generate the Hermes fleet Certificate Authority.
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/gen_fleet_ca.sh [--out-dir <dir>]
|
||||
#
|
||||
# Outputs (default: ~/.hermes/pki/ca/):
|
||||
# fleet-ca.key — CA private key (chmod 600, keep secret)
|
||||
# fleet-ca.crt — CA certificate (distribute to all fleet nodes)
|
||||
#
|
||||
# The CA is valid for 10 years. Regenerate + redistribute when it expires.
|
||||
# Refs #806
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
CA_SUBJECT="/CN=Hermes Fleet CA/O=Hermes/OU=Fleet"
|
||||
CA_DAYS=3650 # 10 years
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse args
|
||||
# ---------------------------------------------------------------------------
|
||||
OUT_DIR="${HOME}/.hermes/pki/ca"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--out-dir) OUT_DIR="$2"; shift 2 ;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [--out-dir <dir>]"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prereq check
|
||||
# ---------------------------------------------------------------------------
|
||||
if ! command -v openssl &>/dev/null; then
|
||||
echo "ERROR: openssl not found. Install OpenSSL and re-run." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "$OUT_DIR"
|
||||
chmod 700 "$OUT_DIR"
|
||||
|
||||
CA_KEY="$OUT_DIR/fleet-ca.key"
|
||||
CA_CRT="$OUT_DIR/fleet-ca.crt"
|
||||
|
||||
if [[ -f "$CA_KEY" || -f "$CA_CRT" ]]; then
|
||||
echo "Fleet CA already exists in $OUT_DIR"
|
||||
echo " $CA_KEY"
|
||||
echo " $CA_CRT"
|
||||
echo "Delete them manually if you want to regenerate."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Generating fleet CA in $OUT_DIR ..."
|
||||
|
||||
# Generate 4096-bit RSA key for the CA
|
||||
openssl genrsa -out "$CA_KEY" 4096 2>/dev/null
|
||||
chmod 600 "$CA_KEY"
|
||||
|
||||
# Self-sign the CA certificate
|
||||
openssl req -new -x509 \
|
||||
-key "$CA_KEY" \
|
||||
-out "$CA_CRT" \
|
||||
-days "$CA_DAYS" \
|
||||
-subj "$CA_SUBJECT" \
|
||||
-addext "basicConstraints=critical,CA:TRUE,pathlen:0" \
|
||||
-addext "keyUsage=critical,keyCertSign,cRLSign" \
|
||||
-addext "subjectKeyIdentifier=hash" 2>/dev/null
|
||||
|
||||
chmod 644 "$CA_CRT"
|
||||
|
||||
echo ""
|
||||
echo "Fleet CA generated successfully:"
|
||||
echo " Private key : $CA_KEY (keep secret)"
|
||||
echo " Certificate : $CA_CRT (distribute to all fleet nodes)"
|
||||
echo ""
|
||||
openssl x509 -in "$CA_CRT" -noout -subject -dates
|
||||
147
scripts/queue_health_check.py
Executable file
147
scripts/queue_health_check.py
Executable file
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Queue Health Check — Verify dispatch queue is operational.
|
||||
|
||||
Checks:
|
||||
1. Queue file exists and is readable
|
||||
2. Queue has pending items
|
||||
3. Queue is not stuck (items processing)
|
||||
4. Queue age (stale items)
|
||||
|
||||
Usage:
|
||||
python scripts/queue_health_check.py
|
||||
python scripts/queue_health_check.py --json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_queue_health(queue_path: str = "~/.hermes/queue.json") -> dict:
|
||||
"""Check queue health status."""
|
||||
path = Path(queue_path).expanduser()
|
||||
|
||||
result = {
|
||||
"healthy": True,
|
||||
"checks": {},
|
||||
"warnings": [],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Check 1: File exists
|
||||
if not path.exists():
|
||||
result["healthy"] = False
|
||||
result["errors"].append(f"Queue file not found: {path}")
|
||||
result["checks"]["file_exists"] = False
|
||||
return result
|
||||
|
||||
result["checks"]["file_exists"] = True
|
||||
|
||||
# Check 2: File is readable
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
except Exception as e:
|
||||
result["healthy"] = False
|
||||
result["errors"].append(f"Cannot read queue: {e}")
|
||||
result["checks"]["readable"] = False
|
||||
return result
|
||||
|
||||
result["checks"]["readable"] = True
|
||||
|
||||
# Check 3: Queue structure
|
||||
if not isinstance(data, dict):
|
||||
result["healthy"] = False
|
||||
result["errors"].append("Queue is not a dict")
|
||||
result["checks"]["valid_structure"] = False
|
||||
return result
|
||||
|
||||
result["checks"]["valid_structure"] = True
|
||||
|
||||
# Check 4: Pending items
|
||||
pending = data.get("pending", [])
|
||||
processing = data.get("processing", [])
|
||||
completed = data.get("completed", [])
|
||||
|
||||
result["checks"]["pending_count"] = len(pending)
|
||||
result["checks"]["processing_count"] = len(processing)
|
||||
result["checks"]["completed_count"] = len(completed)
|
||||
|
||||
if len(pending) == 0 and len(processing) == 0:
|
||||
result["warnings"].append("Queue is empty")
|
||||
|
||||
# Check 5: Stale processing items
|
||||
now = datetime.now()
|
||||
stale_threshold = timedelta(hours=1)
|
||||
|
||||
for item in processing:
|
||||
started = item.get("started_at")
|
||||
if started:
|
||||
try:
|
||||
started_time = datetime.fromisoformat(started.replace("Z", "+00:00"))
|
||||
if now - started_time > stale_threshold:
|
||||
result["warnings"].append(f"Stale item: {item.get('id', 'unknown')} (started {started})")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check 6: Queue age
|
||||
if pending:
|
||||
oldest = min(pending, key=lambda x: x.get("added_at", ""))
|
||||
added = oldest.get("added_at")
|
||||
if added:
|
||||
try:
|
||||
added_time = datetime.fromisoformat(added.replace("Z", "+00:00"))
|
||||
age = now - added_time
|
||||
if age > timedelta(hours=24):
|
||||
result["warnings"].append(f"Old item in queue: {oldest.get('id', 'unknown')} (added {added})")
|
||||
except:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Queue health check")
|
||||
parser.add_argument("--queue", default="~/.hermes/queue.json", help="Queue file path")
|
||||
parser.add_argument("--json", action="store_true", help="Output as JSON")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = check_queue_health(args.queue)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print("Queue Health Check")
|
||||
print("=" * 50)
|
||||
print(f"Healthy: {'✓' if result['healthy'] else '✗'}")
|
||||
print()
|
||||
|
||||
print("Checks:")
|
||||
for check, value in result["checks"].items():
|
||||
if isinstance(value, bool):
|
||||
print(f" {check}: {'✓' if value else '✗'}")
|
||||
else:
|
||||
print(f" {check}: {value}")
|
||||
|
||||
if result["warnings"]:
|
||||
print()
|
||||
print("Warnings:")
|
||||
for warning in result["warnings"]:
|
||||
print(f" ⚠ {warning}")
|
||||
|
||||
if result["errors"]:
|
||||
print()
|
||||
print("Errors:")
|
||||
for error in result["errors"]:
|
||||
print(f" ✗ {error}")
|
||||
|
||||
sys.exit(0 if result["healthy"] else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
scripts/time-aware-model-router.py
Normal file
145
scripts/time-aware-model-router.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
time-aware-model-router.py — Route cron jobs to better models during high-error hours.
|
||||
|
||||
Empirical finding (audit 2026-04-12): Error rate peaks at 18:00 (9.4%) during
|
||||
evening cron batches vs 4.0% at 09:00 during interactive work.
|
||||
|
||||
This script provides a model resolver that selects a more capable model during
|
||||
high-error hours (17:00-22:00) and the default model otherwise.
|
||||
|
||||
Usage:
|
||||
# As a standalone resolver
|
||||
python3 scripts/time-aware-model-router.py
|
||||
# Returns: {"provider": "nous", "model": "xiaomi/mimo-v2-pro"}
|
||||
|
||||
# With hour override for testing
|
||||
python3 scripts/time-aware-model-router.py --hour 18
|
||||
# Returns: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
|
||||
# As a cron job wrapper
|
||||
python3 scripts/time-aware-model-router.py --wrap -- prompt goes here
|
||||
|
||||
Environment variables:
|
||||
HERMES_DEFAULT_PROVIDER: Default provider for normal hours (default: nous)
|
||||
HERMES_DEFAULT_MODEL: Default model for normal hours (default: xiaomi/mimo-v2-pro)
|
||||
HERMES_PEAK_PROVIDER: Provider for high-error hours (default: openrouter)
|
||||
HERMES_PEAK_MODEL: Model for high-error hours (default: anthropic/claude-sonnet-4)
|
||||
HERMES_PEAK_HOURS: Comma-separated hours for peak routing (default: 17,18,19,20,21,22)
|
||||
|
||||
Refs: hermes-agent#889
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# ── Config ──────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_PROVIDER = os.environ.get("HERMES_DEFAULT_PROVIDER", "nous")
|
||||
DEFAULT_MODEL = os.environ.get("HERMES_DEFAULT_MODEL", "xiaomi/mimo-v2-pro")
|
||||
PEAK_PROVIDER = os.environ.get("HERMES_PEAK_PROVIDER", "openrouter")
|
||||
PEAK_MODEL = os.environ.get("HERMES_PEAK_MODEL", "anthropic/claude-sonnet-4")
|
||||
PEAK_HOURS = set(int(h) for h in os.environ.get("HERMES_PEAK_HOURS", "17,18,19,20,21,22").split(","))
|
||||
|
||||
# ── Time-aware routing ─────────────────────────────────────────────────────
|
||||
|
||||
def get_current_hour():
|
||||
"""Get the current local hour (0-23)."""
|
||||
return datetime.now().hour
|
||||
|
||||
|
||||
def is_peak_hour(hour=None):
|
||||
"""Check if the given hour (or current hour) is a high-error period."""
|
||||
if hour is None:
|
||||
hour = get_current_hour()
|
||||
return hour in PEAK_HOURS
|
||||
|
||||
|
||||
def resolve_model(hour=None):
|
||||
"""
|
||||
Resolve which model to use based on time of day.
|
||||
|
||||
Returns dict with 'provider' and 'model' keys.
|
||||
During peak hours (high error rate), uses a more capable model.
|
||||
During normal hours, uses the default model.
|
||||
"""
|
||||
if is_peak_hour(hour):
|
||||
return {
|
||||
"provider": PEAK_PROVIDER,
|
||||
"model": PEAK_MODEL,
|
||||
"reason": f"peak_hour ({hour if hour is not None else get_current_hour()}:00)",
|
||||
"confidence_note": "Using stronger model during high-error period"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"provider": DEFAULT_PROVIDER,
|
||||
"model": DEFAULT_MODEL,
|
||||
"reason": "normal_hour",
|
||||
"confidence_note": "Default model sufficient during low-error period"
|
||||
}
|
||||
|
||||
|
||||
def get_routing_info():
|
||||
"""Get full routing info including current state and config."""
|
||||
hour = get_current_hour()
|
||||
resolved = resolve_model(hour)
|
||||
return {
|
||||
"current_hour": hour,
|
||||
"is_peak": is_peak_hour(hour),
|
||||
"peak_hours": sorted(PEAK_HOURS),
|
||||
"routing": resolved,
|
||||
"config": {
|
||||
"default": {"provider": DEFAULT_PROVIDER, "model": DEFAULT_MODEL},
|
||||
"peak": {"provider": PEAK_PROVIDER, "model": PEAK_MODEL},
|
||||
},
|
||||
"source": "hermes-agent#889 — empirical audit 2026-04-12",
|
||||
}
|
||||
|
||||
|
||||
# ── CLI ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = sys.argv[1:]
|
||||
|
||||
# Parse --hour
|
||||
hour = None
|
||||
if "--hour" in args:
|
||||
idx = args.index("--hour")
|
||||
if idx + 1 < len(args):
|
||||
hour = int(args[idx + 1])
|
||||
|
||||
# Parse --wrap mode
|
||||
if "--wrap" in args:
|
||||
# Run the remaining args as a command with model override
|
||||
resolved = resolve_model(hour)
|
||||
wrap_idx = args.index("--wrap")
|
||||
cmd_parts = args[wrap_idx + 1:]
|
||||
|
||||
# Inject model/provider into environment
|
||||
env = os.environ.copy()
|
||||
env["HERMES_MODEL"] = resolved["model"]
|
||||
env["HERMES_PROVIDER"] = resolved["provider"]
|
||||
|
||||
if cmd_parts:
|
||||
import subprocess
|
||||
result = subprocess.run(cmd_parts, env=env)
|
||||
sys.exit(result.returncode)
|
||||
else:
|
||||
print(json.dumps(resolved, indent=2))
|
||||
sys.exit(0)
|
||||
|
||||
# Parse --info mode
|
||||
if "--info" in args:
|
||||
print(json.dumps(get_routing_info(), indent=2))
|
||||
sys.exit(0)
|
||||
|
||||
# Default: output resolved model as JSON
|
||||
resolved = resolve_model(hour)
|
||||
print(json.dumps(resolved, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
574
tests/agent/test_a2a_mtls.py
Normal file
574
tests/agent/test_a2a_mtls.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
Tests for A2A mutual-TLS authentication.
|
||||
|
||||
Scenarios covered:
|
||||
- authorized agent (valid fleet-CA-signed cert) is accepted
|
||||
- unauthorized agent (self-signed cert) is rejected with SSLError
|
||||
- missing client cert is rejected
|
||||
- build_server_ssl_context raises FileNotFoundError for missing paths
|
||||
- build_client_ssl_context raises FileNotFoundError for missing paths
|
||||
- A2AServer.start() / stop() lifecycle (no network I/O)
|
||||
|
||||
All TLS I/O is done in-process against a loopback server so no ports need
|
||||
to be opened on a CI runner.
|
||||
|
||||
Refs #806
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — generate self-signed certs in-memory with Python's ``cryptography``
|
||||
# library (dev extra). If cryptography is unavailable we skip the network
|
||||
# tests gracefully.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import NameOID
|
||||
import cryptography.hazmat.backends as _backends
|
||||
_CRYPTO_AVAILABLE = True
|
||||
except ImportError:
|
||||
_CRYPTO_AVAILABLE = False
|
||||
|
||||
_requires_crypto = pytest.mark.skipif(
|
||||
not _CRYPTO_AVAILABLE,
|
||||
reason="cryptography package not installed",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_ca_keypair(tmp_path: Path) -> Tuple[Path, Path]:
|
||||
"""Generate a self-signed CA cert+key and write to *tmp_path*."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "Test Fleet CA"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "TestOrg"),
|
||||
])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(name)
|
||||
.issuer_name(name)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=3650))
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=False, key_cert_sign=True, crl_sign=True,
|
||||
content_commitment=False, key_encipherment=False,
|
||||
data_encipherment=False, key_agreement=False,
|
||||
encipher_only=False, decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.sign(key, hashes.SHA256())
|
||||
)
|
||||
key_path = tmp_path / "ca.key"
|
||||
cert_path = tmp_path / "ca.crt"
|
||||
key_path.write_bytes(key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
))
|
||||
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
def _make_agent_keypair(
|
||||
tmp_path: Path,
|
||||
name: str,
|
||||
ca_cert_path: Path,
|
||||
ca_key_path: Path,
|
||||
) -> Tuple[Path, Path]:
|
||||
"""Generate an agent cert signed by the test CA."""
|
||||
ca_cert = x509.load_pem_x509_certificate(ca_cert_path.read_bytes())
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca_key_path.read_bytes(), password=None
|
||||
)
|
||||
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
subject = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, f"{name}.fleet.hermes"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "TestOrg"),
|
||||
])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=365))
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName(f"{name}.fleet.hermes"),
|
||||
x509.DNSName(name),
|
||||
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([
|
||||
x509.ExtendedKeyUsageOID.CLIENT_AUTH,
|
||||
x509.ExtendedKeyUsageOID.SERVER_AUTH,
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
key_path = tmp_path / f"{name}.key"
|
||||
cert_path = tmp_path / f"{name}.crt"
|
||||
key_path.write_bytes(key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
))
|
||||
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
def _make_self_signed_keypair(tmp_path: Path, name: str) -> Tuple[Path, Path]:
|
||||
"""Generate a self-signed cert NOT signed by the test CA (unauthorized)."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
subject = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, f"{name}.rogue"),
|
||||
])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(subject)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=365))
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([x509.IPAddress(ipaddress.IPv4Address("127.0.0.1"))]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA256())
|
||||
)
|
||||
key_path = tmp_path / f"{name}_rogue.key"
|
||||
cert_path = tmp_path / f"{name}_rogue.crt"
|
||||
key_path.write_bytes(key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
))
|
||||
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — no network I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildSslContextErrors:
|
||||
def test_server_context_missing_cert(self, tmp_path):
|
||||
from agent.a2a_mtls import build_server_ssl_context
|
||||
with pytest.raises(FileNotFoundError, match="mTLS"):
|
||||
build_server_ssl_context(
|
||||
cert=tmp_path / "nope.crt",
|
||||
key=tmp_path / "nope.key",
|
||||
ca=tmp_path / "nope.crt",
|
||||
)
|
||||
|
||||
def test_client_context_missing_cert(self, tmp_path):
|
||||
from agent.a2a_mtls import build_client_ssl_context
|
||||
with pytest.raises(FileNotFoundError, match="mTLS client"):
|
||||
build_client_ssl_context(
|
||||
cert=tmp_path / "nope.crt",
|
||||
key=tmp_path / "nope.key",
|
||||
ca=tmp_path / "nope.crt",
|
||||
)
|
||||
|
||||
@_requires_crypto
|
||||
def test_server_context_builds_with_valid_certs(self, tmp_path):
|
||||
from agent.a2a_mtls import build_server_ssl_context
|
||||
ca_dir = tmp_path / "ca"
|
||||
ca_dir.mkdir()
|
||||
ca_crt, ca_key = _make_ca_keypair(ca_dir)
|
||||
srv_crt, srv_key = _make_agent_keypair(
|
||||
tmp_path, "srv", ca_crt, ca_key
|
||||
)
|
||||
ctx = build_server_ssl_context(cert=srv_crt, key=srv_key, ca=ca_crt)
|
||||
assert isinstance(ctx, ssl.SSLContext)
|
||||
assert ctx.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
@_requires_crypto
|
||||
def test_client_context_builds_with_valid_certs(self, tmp_path):
|
||||
from agent.a2a_mtls import build_client_ssl_context
|
||||
ca_dir = tmp_path / "ca"
|
||||
ca_dir.mkdir()
|
||||
ca_crt, ca_key = _make_ca_keypair(ca_dir)
|
||||
cli_crt, cli_key = _make_agent_keypair(
|
||||
tmp_path, "cli", ca_crt, ca_key
|
||||
)
|
||||
ctx = build_client_ssl_context(cert=cli_crt, key=cli_key, ca=ca_crt)
|
||||
assert isinstance(ctx, ssl.SSLContext)
|
||||
assert ctx.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests — loopback mTLS server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_free_port() -> int:
|
||||
import socket
|
||||
with socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _https_get(url: str, ssl_ctx: ssl.SSLContext) -> int:
|
||||
"""Return the HTTP status code for a GET request, or raise SSLError."""
|
||||
req = urllib.request.urlopen(url, context=ssl_ctx, timeout=5)
|
||||
return req.status
|
||||
|
||||
|
||||
@_requires_crypto
|
||||
class TestMutualTLSAuth:
|
||||
"""End-to-end mTLS auth over a loopback connection."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _pki(self, tmp_path):
|
||||
"""Set up a fleet CA and agent certs for timmy (server) and allegro (authorized client)."""
|
||||
ca_dir = tmp_path / "ca"
|
||||
ca_dir.mkdir()
|
||||
self.ca_crt, self.ca_key = _make_ca_keypair(ca_dir)
|
||||
|
||||
agent_dir = tmp_path / "agents"
|
||||
agent_dir.mkdir()
|
||||
|
||||
# Server agent: timmy
|
||||
self.srv_crt, self.srv_key = _make_agent_keypair(
|
||||
agent_dir, "timmy", self.ca_crt, self.ca_key
|
||||
)
|
||||
# Authorized client agent: allegro
|
||||
self.cli_crt, self.cli_key = _make_agent_keypair(
|
||||
agent_dir, "allegro", self.ca_crt, self.ca_key
|
||||
)
|
||||
# Unauthorized (self-signed) client: rogue
|
||||
self.rogue_crt, self.rogue_key = _make_self_signed_keypair(agent_dir, "rogue")
|
||||
|
||||
@pytest.fixture()
|
||||
def running_server(self):
|
||||
"""Start an A2AServer on a free loopback port, yield the URL, stop after test."""
|
||||
from agent.a2a_mtls import A2AServer
|
||||
port = _find_free_port()
|
||||
server = A2AServer(
|
||||
cert=self.srv_crt,
|
||||
key=self.srv_key,
|
||||
ca=self.ca_crt,
|
||||
host="127.0.0.1",
|
||||
port=port,
|
||||
)
|
||||
server.start(daemon=True)
|
||||
time.sleep(0.15) # let the thread bind
|
||||
yield f"https://127.0.0.1:{port}"
|
||||
server.stop()
|
||||
|
||||
def _authorized_ctx(self) -> ssl.SSLContext:
|
||||
from agent.a2a_mtls import build_client_ssl_context
|
||||
ctx = build_client_ssl_context(
|
||||
cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt
|
||||
)
|
||||
ctx.check_hostname = False # loopback IP doesn't match DNS SAN
|
||||
return ctx
|
||||
|
||||
def _unauthorized_ctx(self) -> ssl.SSLContext:
|
||||
"""Client context with a self-signed cert not trusted by the server CA."""
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=str(self.rogue_crt), keyfile=str(self.rogue_key))
|
||||
# Load the real fleet CA so server cert is accepted — but our client
|
||||
# cert is self-signed and will be rejected by the server.
|
||||
ctx.load_verify_locations(cafile=str(self.ca_crt))
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
|
||||
def _no_client_cert_ctx(self) -> ssl.SSLContext:
|
||||
"""Client context with no client certificate at all."""
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_verify_locations(cafile=str(self.ca_crt))
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authorized agent accepted
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_authorized_agent_accepted(self, running_server):
|
||||
"""An agent with a fleet-CA-signed cert gets a 200-range response."""
|
||||
status = _https_get(
|
||||
running_server + "/.well-known/agent-card.json",
|
||||
self._authorized_ctx(),
|
||||
)
|
||||
assert status == 200
|
||||
|
||||
def test_authorized_agent_task_endpoint(self, running_server):
|
||||
"""POST /a2a/task returns 202 for an authorized agent."""
|
||||
import urllib.request
|
||||
req = urllib.request.Request(
|
||||
running_server + "/a2a/task",
|
||||
data=b'{"hello":"world"}',
|
||||
method="POST",
|
||||
)
|
||||
req.add_header("Content-Type", "application/json")
|
||||
resp = urllib.request.urlopen(req, context=self._authorized_ctx(), timeout=5)
|
||||
assert resp.status == 202
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unauthorized agent rejected
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_unauthorized_agent_rejected(self, running_server):
|
||||
"""A self-signed cert not signed by the fleet CA is rejected at TLS handshake."""
|
||||
with pytest.raises((ssl.SSLError, OSError)):
|
||||
_https_get(running_server + "/", self._unauthorized_ctx())
|
||||
|
||||
def test_no_client_cert_rejected(self, running_server):
|
||||
"""A client with no cert at all is rejected at TLS handshake."""
|
||||
with pytest.raises((ssl.SSLError, OSError)):
|
||||
_https_get(running_server + "/", self._no_client_cert_ctx())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_server_stop_is_idempotent(self):
|
||||
"""Calling stop() twice does not raise."""
|
||||
from agent.a2a_mtls import A2AServer
|
||||
port = _find_free_port()
|
||||
server = A2AServer(
|
||||
cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt,
|
||||
host="127.0.0.1", port=port,
|
||||
)
|
||||
server.start(daemon=True)
|
||||
time.sleep(0.1)
|
||||
server.stop()
|
||||
server.stop() # second call must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_from_env() — environment variable wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestServerFromEnv:
|
||||
def test_reads_env_vars(self, tmp_path, monkeypatch):
|
||||
# Create dummy files so FileNotFoundError isn't triggered
|
||||
cert = tmp_path / "a.crt"
|
||||
key = tmp_path / "a.key"
|
||||
ca = tmp_path / "ca.crt"
|
||||
for f in (cert, key, ca):
|
||||
f.write_text("PLACEHOLDER")
|
||||
|
||||
monkeypatch.setenv("HERMES_A2A_CERT", str(cert))
|
||||
monkeypatch.setenv("HERMES_A2A_KEY", str(key))
|
||||
monkeypatch.setenv("HERMES_A2A_CA", str(ca))
|
||||
monkeypatch.setenv("HERMES_A2A_HOST", "127.0.0.2")
|
||||
monkeypatch.setenv("HERMES_A2A_PORT", "19443")
|
||||
|
||||
from agent.a2a_mtls import server_from_env
|
||||
srv = server_from_env()
|
||||
assert srv.cert == cert
|
||||
assert srv.key == key
|
||||
assert srv.ca == ca
|
||||
assert srv.host == "127.0.0.2"
|
||||
assert srv.port == 19443
|
||||
|
||||
def test_uses_agent_name_for_defaults(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("HERMES_AGENT_NAME", "ezra")
|
||||
# Unset explicit cert overrides
|
||||
monkeypatch.delenv("HERMES_A2A_CERT", raising=False)
|
||||
monkeypatch.delenv("HERMES_A2A_KEY", raising=False)
|
||||
monkeypatch.delenv("HERMES_A2A_CA", raising=False)
|
||||
|
||||
from agent.a2a_mtls import server_from_env
|
||||
srv = server_from_env()
|
||||
assert "ezra" in str(srv.cert)
|
||||
assert "ezra" in str(srv.key)
|
||||
assert "fleet-ca" in str(srv.ca)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A2AMTLSServer and A2AMTLSClient — routing server + client helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@_requires_crypto
|
||||
class TestA2AMTLSServerAndClient:
|
||||
"""Tests for the routing-based A2AMTLSServer and A2AMTLSClient."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _pki(self, tmp_path):
|
||||
ca_dir = tmp_path / "ca"
|
||||
ca_dir.mkdir()
|
||||
self.ca_crt, self.ca_key = _make_ca_keypair(ca_dir)
|
||||
agent_dir = tmp_path / "agents"
|
||||
agent_dir.mkdir()
|
||||
self.srv_crt, self.srv_key = _make_agent_keypair(
|
||||
agent_dir, "timmy", self.ca_crt, self.ca_key
|
||||
)
|
||||
self.cli_crt, self.cli_key = _make_agent_keypair(
|
||||
agent_dir, "allegro", self.ca_crt, self.ca_key
|
||||
)
|
||||
self.rogue_crt, self.rogue_key = _make_self_signed_keypair(agent_dir, "rogue")
|
||||
|
||||
@pytest.fixture()
|
||||
def routing_server(self):
|
||||
from agent.a2a_mtls import A2AMTLSServer
|
||||
port = _find_free_port()
|
||||
server = A2AMTLSServer(
|
||||
cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt,
|
||||
host="127.0.0.1", port=port,
|
||||
)
|
||||
server.add_route("/echo", lambda p, *, peer_cn=None: {"echo": p, "peer": peer_cn})
|
||||
server.add_route("/tasks/send", lambda p, *, peer_cn=None: {"status": "ok", "echo": p})
|
||||
with server:
|
||||
time.sleep(0.1)
|
||||
yield server, port
|
||||
|
||||
def _authorized_ctx(self) -> ssl.SSLContext:
|
||||
from agent.a2a_mtls import build_client_ssl_context
|
||||
ctx = build_client_ssl_context(
|
||||
cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt
|
||||
)
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
|
||||
def test_routing_server_get(self, routing_server):
|
||||
server, port = routing_server
|
||||
ctx = self._authorized_ctx()
|
||||
req = urllib.request.Request(f"https://127.0.0.1:{port}/echo")
|
||||
with urllib.request.urlopen(req, context=ctx, timeout=5) as resp:
|
||||
import json
|
||||
data = json.loads(resp.read())
|
||||
assert data["peer"] is not None # CN present
|
||||
|
||||
def test_routing_server_post_payload(self, routing_server):
|
||||
server, port = routing_server
|
||||
ctx = self._authorized_ctx()
|
||||
import json
|
||||
payload = {"task_id": "abc", "action": "delegate"}
|
||||
req = urllib.request.Request(
|
||||
f"https://127.0.0.1:{port}/tasks/send",
|
||||
data=json.dumps(payload).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
with urllib.request.urlopen(req, context=ctx, timeout=5) as resp:
|
||||
data = json.loads(resp.read())
|
||||
assert data["status"] == "ok"
|
||||
assert data["echo"]["task_id"] == "abc"
|
||||
|
||||
def test_routing_server_unknown_route_404(self, routing_server):
|
||||
server, port = routing_server
|
||||
ctx = self._authorized_ctx()
|
||||
req = urllib.request.Request(f"https://127.0.0.1:{port}/nonexistent")
|
||||
with pytest.raises(urllib.error.URLError) as exc_info:
|
||||
urllib.request.urlopen(req, context=ctx, timeout=5)
|
||||
assert "404" in str(exc_info.value)
|
||||
|
||||
def test_routing_server_context_manager_stops(self):
|
||||
from agent.a2a_mtls import A2AMTLSServer
|
||||
port = _find_free_port()
|
||||
server = A2AMTLSServer(
|
||||
cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt,
|
||||
host="127.0.0.1", port=port,
|
||||
)
|
||||
server.add_route("/ping", lambda p, *, peer_cn=None: {"pong": True})
|
||||
with server:
|
||||
time.sleep(0.05)
|
||||
assert server._httpd is not None
|
||||
assert server._httpd is None # stopped after __exit__
|
||||
|
||||
def test_routing_server_rogue_client_rejected(self, routing_server):
|
||||
server, port = routing_server
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_verify_locations(cafile=str(self.ca_crt))
|
||||
ctx.load_cert_chain(certfile=str(self.rogue_crt), keyfile=str(self.rogue_key))
|
||||
ctx.check_hostname = False
|
||||
req = urllib.request.Request(f"https://127.0.0.1:{port}/echo")
|
||||
with pytest.raises((ssl.SSLError, OSError, urllib.error.URLError)):
|
||||
urllib.request.urlopen(req, context=ctx, timeout=5)
|
||||
|
||||
def test_a2a_mtls_client_get(self, routing_server):
|
||||
from agent.a2a_mtls import A2AMTLSClient
|
||||
server, port = routing_server
|
||||
client = A2AMTLSClient(
|
||||
cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt
|
||||
)
|
||||
result = client.get(f"https://127.0.0.1:{port}/echo")
|
||||
assert result["peer"] is not None
|
||||
|
||||
def test_a2a_mtls_client_post(self, routing_server):
|
||||
from agent.a2a_mtls import A2AMTLSClient
|
||||
server, port = routing_server
|
||||
client = A2AMTLSClient(
|
||||
cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt
|
||||
)
|
||||
result = client.post(f"https://127.0.0.1:{port}/tasks/send", json={"x": 1})
|
||||
assert result["status"] == "ok"
|
||||
assert result["echo"]["x"] == 1
|
||||
|
||||
def test_a2a_mtls_client_rogue_cert_raises(self, routing_server):
|
||||
from agent.a2a_mtls import A2AMTLSClient
|
||||
server, port = routing_server
|
||||
client = A2AMTLSClient(
|
||||
cert=self.rogue_crt, key=self.rogue_key, ca=self.ca_crt
|
||||
)
|
||||
with pytest.raises((ConnectionError, ssl.SSLError, OSError)):
|
||||
client.get(f"https://127.0.0.1:{port}/echo")
|
||||
|
||||
def test_concurrent_fleet_agents(self, routing_server):
|
||||
"""timmy (server) accepts concurrent connections from multiple authorized clients."""
|
||||
from agent.a2a_mtls import build_client_ssl_context
|
||||
server, port = routing_server
|
||||
results: dict = {}
|
||||
errors: dict = {}
|
||||
|
||||
def connect(name: str) -> None:
|
||||
try:
|
||||
ctx = build_client_ssl_context(
|
||||
cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt
|
||||
)
|
||||
ctx.check_hostname = False
|
||||
req = urllib.request.Request(f"https://127.0.0.1:{port}/echo")
|
||||
with urllib.request.urlopen(req, context=ctx, timeout=5) as resp:
|
||||
import json
|
||||
results[name] = json.loads(resp.read())
|
||||
except Exception as exc:
|
||||
errors[name] = exc
|
||||
|
||||
threads = [threading.Thread(target=connect, args=(n,)) for n in ("t1", "t2", "t3")]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
|
||||
assert not errors, f"Concurrent connection errors: {errors}"
|
||||
assert len(results) == 3
|
||||
@@ -10,6 +10,7 @@ from gateway.config import (
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
_apply_env_overrides,
|
||||
_validate_gateway_config,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
@@ -294,3 +295,151 @@ class TestHomeChannelEnvOverrides:
|
||||
home = config.platforms[platform].home_channel
|
||||
assert home is not None, f"{platform.value}: home_channel should not be None"
|
||||
assert (home.chat_id, home.name) == expected, platform.value
|
||||
|
||||
|
||||
class TestValidateGatewayConfig:
|
||||
"""Tests for _validate_gateway_config — in-place sanitisation of loaded config."""
|
||||
|
||||
# -- idle_minutes validation --
|
||||
|
||||
def test_idle_minutes_zero_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = 0
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_idle_minutes_negative_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = -60
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_idle_minutes_none_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = None # type: ignore[assignment]
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_valid_idle_minutes_is_unchanged(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = 90
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 90
|
||||
|
||||
# -- at_hour validation --
|
||||
|
||||
def test_at_hour_too_high_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = 24
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 4
|
||||
|
||||
def test_at_hour_negative_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = -1
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 4
|
||||
|
||||
def test_valid_at_hour_is_unchanged(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = 3
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 3
|
||||
|
||||
def test_at_hour_boundary_values_are_valid(self):
|
||||
for valid_hour in (0, 23):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = valid_hour
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == valid_hour
|
||||
|
||||
# -- empty-token warning (enabled platforms) --
|
||||
|
||||
def test_empty_string_token_logs_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token=""),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert any(
|
||||
"TELEGRAM_BOT_TOKEN" in r.message and "empty" in r.message
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
def test_disabled_platform_with_empty_token_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token=""),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any("TELEGRAM_BOT_TOKEN" in r.message for r in caplog.records)
|
||||
|
||||
# -- API Server key / binding warnings --
|
||||
|
||||
def test_api_server_network_binding_without_key_logs_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_loopback_without_key_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "127.0.0.1"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_network_binding_with_key_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0", "key": "sk-real-key-here"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_default_loopback_without_key_no_warning(self, caplog):
|
||||
"""API server with no explicit host defaults to 127.0.0.1 — no warning."""
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(enabled=True),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
97
tests/test_circuit_breaker.py
Normal file
97
tests/test_circuit_breaker.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for circuit breaker (#885)."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker, MultiToolCircuitBreaker, CircuitState
|
||||
|
||||
|
||||
def test_closed_allows_execution():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
assert cb.can_execute()
|
||||
|
||||
|
||||
def test_opens_after_threshold():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.can_execute() # Still closed at 2
|
||||
cb.record_result(False)
|
||||
assert not cb.can_execute() # Open at 3
|
||||
|
||||
|
||||
def test_closes_on_success():
|
||||
cb = CircuitBreaker(failure_threshold=3)
|
||||
cb.record_result(False)
|
||||
cb.record_result(True)
|
||||
assert cb.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_half_open_recovery():
|
||||
cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
import time
|
||||
time.sleep(0.15)
|
||||
|
||||
assert cb.can_execute() # Moved to half-open
|
||||
cb.record_result(True)
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
def test_recovery_action_streak():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(5):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "switch_tool_type"
|
||||
|
||||
|
||||
def test_recovery_action_critical():
|
||||
cb = ToolCircuitBreaker(failure_threshold=3)
|
||||
for _ in range(10):
|
||||
cb.record_result(False)
|
||||
action = cb.get_recovery_action()
|
||||
assert action["action"] == "terminal_only"
|
||||
assert action["severity"] == "critical"
|
||||
|
||||
|
||||
def test_multi_tool_breaker():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
mcb.record_result("read_file", False)
|
||||
assert not mcb.can_execute("read_file")
|
||||
assert mcb.can_execute("terminal") # Different tool unaffected
|
||||
|
||||
|
||||
def test_global_state():
|
||||
mcb = MultiToolCircuitBreaker()
|
||||
mcb.record_result("tool_a", False)
|
||||
mcb.record_result("tool_b", False)
|
||||
state = mcb.get_global_state()
|
||||
assert state["global_streak"] == 2
|
||||
|
||||
|
||||
def test_reset():
|
||||
cb = CircuitBreaker(failure_threshold=2)
|
||||
cb.record_result(False)
|
||||
cb.record_result(False)
|
||||
assert cb.state == CircuitState.OPEN
|
||||
cb.reset()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [test_closed_allows_execution, test_opens_after_threshold,
|
||||
test_closes_on_success, test_half_open_recovery,
|
||||
test_recovery_action_streak, test_recovery_action_critical,
|
||||
test_multi_tool_breaker, test_global_state, test_reset]
|
||||
for t in tests:
|
||||
print(f"Running {t.__name__}...")
|
||||
t()
|
||||
print(" PASS")
|
||||
print("\nAll tests passed.")
|
||||
127
tests/test_context_budget.py
Normal file
127
tests/test_context_budget.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Tests for context budget tracker
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.context_budget import (
|
||||
ContextBudget,
|
||||
ContextBudgetTracker,
|
||||
estimate_tokens,
|
||||
estimate_messages_tokens,
|
||||
check_context_budget,
|
||||
preflight_token_check,
|
||||
THRESHOLD_WARNING,
|
||||
THRESHOLD_CRITICAL,
|
||||
THRESHOLD_DANGER,
|
||||
)
|
||||
|
||||
|
||||
class TestContextBudget(unittest.TestCase):
|
||||
|
||||
def test_basic_budget(self):
|
||||
b = ContextBudget(context_limit=10000)
|
||||
self.assertEqual(b.available, 8000) # 10000 - 2000 reserved
|
||||
self.assertEqual(b.remaining, 8000)
|
||||
self.assertEqual(b.utilization, 0.0)
|
||||
|
||||
def test_utilization(self):
|
||||
b = ContextBudget(context_limit=10000, used_tokens=4000)
|
||||
self.assertEqual(b.utilization, 0.5)
|
||||
self.assertEqual(b.remaining, 4000)
|
||||
|
||||
|
||||
class TestTokenEstimation(unittest.TestCase):
|
||||
|
||||
def test_estimate_tokens(self):
|
||||
self.assertEqual(estimate_tokens(""), 0)
|
||||
self.assertEqual(estimate_tokens("a" * 4), 1)
|
||||
self.assertEqual(estimate_tokens("a" * 400), 100)
|
||||
|
||||
def test_estimate_messages(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 400},
|
||||
{"role": "assistant", "content": "b" * 800},
|
||||
]
|
||||
tokens = estimate_messages_tokens(messages)
|
||||
self.assertEqual(tokens, 300) # 100 + 200
|
||||
|
||||
|
||||
class TestContextBudgetTracker(unittest.TestCase):
|
||||
|
||||
def test_warning_at_70_percent(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5600 # 70% of 8000 available
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("70", warning)
|
||||
|
||||
def test_critical_at_85_percent(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
|
||||
tracker = ContextBudgetTracker(context_limit=10000, session_id="test")
|
||||
tracker.budget.used_tokens = 6800 # 85% of 8000
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("85", warning)
|
||||
|
||||
def test_danger_at_95_percent(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 7600 # 95% of 8000
|
||||
warning = tracker.get_warning()
|
||||
self.assertIsNotNone(warning)
|
||||
self.assertIn("CRITICAL", warning)
|
||||
|
||||
def test_can_fit(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5000
|
||||
self.assertTrue(tracker.can_fit(1000))
|
||||
self.assertFalse(tracker.can_fit(5000))
|
||||
|
||||
def test_preflight_check(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 5000
|
||||
|
||||
can_fit, msg = tracker.preflight_check("a" * 400) # 100 tokens
|
||||
self.assertTrue(can_fit)
|
||||
self.assertEqual(msg, "")
|
||||
|
||||
|
||||
class TestCheckContextBudget(unittest.TestCase):
|
||||
|
||||
def test_no_warning_under_threshold(self):
|
||||
with patch("agent.context_budget._tracker", None):
|
||||
messages = [{"role": "user", "content": "short"}]
|
||||
warning = check_context_budget(messages)
|
||||
self.assertIsNone(warning)
|
||||
|
||||
def test_warning_over_threshold(self):
|
||||
with patch("agent.context_budget._tracker", None):
|
||||
# Create messages that exceed 70% of default 128k context
|
||||
messages = [{"role": "user", "content": "x" * 350000}] # ~87500 tokens
|
||||
warning = check_context_budget(messages)
|
||||
self.assertIsNotNone(warning)
|
||||
|
||||
|
||||
class TestStatusLine(unittest.TestCase):
|
||||
|
||||
def test_green_status(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
line = tracker.get_status_line()
|
||||
self.assertIn("GREEN", line)
|
||||
|
||||
def test_red_status(self):
|
||||
tracker = ContextBudgetTracker(context_limit=10000)
|
||||
tracker.budget.used_tokens = 7600
|
||||
line = tracker.get_status_line()
|
||||
self.assertIn("RED", line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
101
tests/test_credential_redact.py
Normal file
101
tests/test_credential_redact.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Tests for credential redaction
|
||||
|
||||
Issue: #839
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from tools.credential_redact import (
|
||||
CredentialRedactor,
|
||||
redact_credentials,
|
||||
redact_tool_output,
|
||||
should_mask_file,
|
||||
mask_sensitive_file,
|
||||
)
|
||||
|
||||
|
||||
class TestCredentialRedaction(unittest.TestCase):
|
||||
|
||||
def test_openai_key(self):
|
||||
text = "api_key=sk-abc123def456ghi789jkl012mno"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
self.assertNotIn("sk-abc123", redacted)
|
||||
|
||||
def test_github_token(self):
|
||||
text = "token: ghp_1234567890abcdef1234567890abcdef12345678"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
|
||||
def test_bearer_token(self):
|
||||
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
|
||||
def test_password(self):
|
||||
text = "password: mySecretPassword123"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
|
||||
def test_aws_key(self):
|
||||
text = "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
|
||||
def test_database_url(self):
|
||||
text = "DATABASE_URL=postgres://user:pass@localhost/db"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreater(count, 0)
|
||||
self.assertIn("REDACTED", redacted)
|
||||
|
||||
def test_clean_text_unchanged(self):
|
||||
text = "Hello world, this is a normal message"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertEqual(count, 0)
|
||||
self.assertEqual(redacted, text)
|
||||
|
||||
def test_multiple_credentials(self):
|
||||
text = "key1=sk-abc123def456ghi789jkl012mno and token: ghp_1234567890abcdef1234567890abcdef12345678"
|
||||
redacted, count = redact_credentials(text)
|
||||
self.assertGreaterEqual(count, 2)
|
||||
|
||||
|
||||
class TestToolOutputRedaction(unittest.TestCase):
|
||||
|
||||
def test_redaction_notice(self):
|
||||
output = "Running with key sk-abc123def456ghi789jkl012mno"
|
||||
redacted, notice = redact_tool_output("terminal", output)
|
||||
self.assertIn("REDACTED", notice)
|
||||
self.assertIn("terminal", notice)
|
||||
|
||||
def test_no_notice_when_clean(self):
|
||||
output = "Hello world"
|
||||
redacted, notice = redact_tool_output("terminal", output)
|
||||
self.assertEqual(notice, "")
|
||||
|
||||
|
||||
class TestSensitiveFileMasking(unittest.TestCase):
|
||||
|
||||
def test_env_file_detected(self):
|
||||
self.assertTrue(should_mask_file("/path/to/.env"))
|
||||
self.assertTrue(should_mask_file("/path/to/.env.local"))
|
||||
self.assertTrue(should_mask_file("/path/to/config.yaml"))
|
||||
|
||||
def test_normal_file_not_detected(self):
|
||||
self.assertFalse(should_mask_file("/path/to/readme.md"))
|
||||
self.assertFalse(should_mask_file("/path/to/code.py"))
|
||||
|
||||
def test_mask_env_file(self):
|
||||
content = "API_KEY=sk-abc123\nDATABASE_URL=postgres://u:p@h/d\nNORMAL=value"
|
||||
masked = mask_sensitive_file(content, ".env")
|
||||
self.assertIn("[REDACTED]", masked)
|
||||
self.assertIn("NORMAL=value", masked)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
79
tests/test_crisis_resources.py
Normal file
79
tests/test_crisis_resources.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for 988 Crisis Lifeline integration (#673)."""
|
||||
|
||||
import pytest
|
||||
from agent.crisis_resources import (
|
||||
LIFELINE_988,
|
||||
LIFELINE_988_TEXT,
|
||||
LIFELINE_988_CHAT,
|
||||
LIFELINE_988_SPANISH,
|
||||
CRISIS_TEXT_LINE,
|
||||
EMERGENCY_911,
|
||||
ALL_RESOURCES,
|
||||
get_crisis_resources,
|
||||
format_crisis_resources,
|
||||
get_immediate_help_message,
|
||||
CrisisResource,
|
||||
)
|
||||
|
||||
|
||||
class TestCrisisResources:
|
||||
def test_988_phone(self):
|
||||
assert "988" in LIFELINE_988.contact
|
||||
assert "24/7" in LIFELINE_988.available
|
||||
|
||||
def test_988_text(self):
|
||||
assert "HOME" in LIFELINE_988_TEXT.contact
|
||||
assert "988" in LIFELINE_988_TEXT.contact
|
||||
|
||||
def test_988_chat(self):
|
||||
assert "988lifeline.org/chat" in LIFELINE_988_CHAT.url
|
||||
|
||||
def test_988_spanish(self):
|
||||
assert "1-888-628-9454" in LIFELINE_988_SPANISH.contact
|
||||
assert LIFELINE_988_SPANISH.language == "Spanish"
|
||||
|
||||
def test_crisis_text_line(self):
|
||||
assert "741741" in CRISIS_TEXT_LINE.contact
|
||||
|
||||
def test_911(self):
|
||||
assert "911" in EMERGENCY_911.contact
|
||||
|
||||
def test_all_resources_not_empty(self):
|
||||
assert len(ALL_RESOURCES) >= 5
|
||||
|
||||
|
||||
class TestGetResources:
|
||||
def test_returns_all_by_default(self):
|
||||
assert len(get_crisis_resources()) == len(ALL_RESOURCES)
|
||||
|
||||
def test_filter_english(self):
|
||||
english = get_crisis_resources("English")
|
||||
assert all(r.language == "English" for r in english)
|
||||
assert len(english) > 0
|
||||
|
||||
def test_filter_spanish(self):
|
||||
spanish = get_crisis_resources("Spanish")
|
||||
assert len(spanish) >= 1
|
||||
assert all(r.language == "Spanish" for r in spanish)
|
||||
|
||||
|
||||
class TestFormatting:
|
||||
def test_format_includes_988(self):
|
||||
msg = format_crisis_resources()
|
||||
assert "988" in msg
|
||||
|
||||
def test_format_includes_741741(self):
|
||||
msg = format_crisis_resources()
|
||||
assert "741741" in msg
|
||||
|
||||
def test_format_includes_911(self):
|
||||
msg = format_crisis_resources()
|
||||
assert "911" in msg
|
||||
|
||||
def test_immediate_help_includes_911_first(self):
|
||||
msg = get_immediate_help_message()
|
||||
assert msg.startswith("If you are in immediate danger")
|
||||
|
||||
def test_format_not_empty(self):
|
||||
msg = format_crisis_resources()
|
||||
assert len(msg) > 100
|
||||
389
tests/test_mtls.py
Normal file
389
tests/test_mtls.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Tests for agent/mtls.py — mutual TLS between fleet agents.
|
||||
|
||||
Covers:
|
||||
- is_mtls_configured() with various env combinations
|
||||
- build_server_ssl_context() / build_client_ssl_context() with real certs
|
||||
- MTLSMiddleware: authorized agent accepted, unauthorized agent rejected
|
||||
"""
|
||||
|
||||
import ssl
|
||||
import datetime
|
||||
import ipaddress
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: generate real in-memory certs using the `cryptography` library
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
_CRYPTO_AVAILABLE = True
|
||||
except ImportError:
|
||||
_CRYPTO_AVAILABLE = False
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _CRYPTO_AVAILABLE,
|
||||
reason="cryptography package required for mTLS tests",
|
||||
)
|
||||
|
||||
|
||||
def _make_key():
|
||||
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
|
||||
def _write_pem(path: Path, data: bytes) -> None:
|
||||
path.write_bytes(data)
|
||||
path.chmod(0o600)
|
||||
|
||||
|
||||
def make_fleet_pki(tmp_path: Path):
|
||||
"""
|
||||
Create a minimal Fleet PKI in tmp_path:
|
||||
- fleet-ca.key / fleet-ca.crt (self-signed CA)
|
||||
- agent.key / agent.crt (signed by fleet CA, CN=test-agent)
|
||||
- rogue.key / rogue.crt (self-signed, NOT signed by fleet CA)
|
||||
|
||||
Returns a dict of Path objects.
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
# --- Fleet CA ---
|
||||
ca_key = _make_key()
|
||||
ca_name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "Hermes Fleet CA"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Hermes Fleet"),
|
||||
])
|
||||
ca_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(ca_name)
|
||||
.issuer_name(ca_name)
|
||||
.public_key(ca_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=3650))
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=False, content_commitment=False,
|
||||
key_encipherment=False, data_encipherment=False,
|
||||
key_agreement=False, key_cert_sign=True, crl_sign=True,
|
||||
encipher_only=False, decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
# --- Fleet agent cert ---
|
||||
agent_key = _make_key()
|
||||
agent_name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "test-agent"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Hermes Fleet"),
|
||||
])
|
||||
agent_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(agent_name)
|
||||
.issuer_name(ca_name)
|
||||
.public_key(agent_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=730))
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName("test-agent"),
|
||||
x509.DNSName("localhost"),
|
||||
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([
|
||||
ExtendedKeyUsageOID.CLIENT_AUTH,
|
||||
ExtendedKeyUsageOID.SERVER_AUTH,
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
# --- Rogue cert (self-signed, not from fleet CA) ---
|
||||
rogue_key = _make_key()
|
||||
rogue_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "rogue-agent")])
|
||||
rogue_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(rogue_name)
|
||||
.issuer_name(rogue_name)
|
||||
.public_key(rogue_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=365))
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.sign(rogue_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
# Write to tmp_path
|
||||
pem = serialization.Encoding.PEM
|
||||
private_fmt = serialization.PrivateFormat.TraditionalOpenSSL
|
||||
no_enc = serialization.NoEncryption()
|
||||
|
||||
paths = {}
|
||||
|
||||
paths["ca_key"] = tmp_path / "fleet-ca.key"
|
||||
_write_pem(paths["ca_key"], ca_key.private_bytes(pem, private_fmt, no_enc))
|
||||
|
||||
paths["ca_cert"] = tmp_path / "fleet-ca.crt"
|
||||
_write_pem(paths["ca_cert"], ca_cert.public_bytes(pem))
|
||||
|
||||
paths["agent_key"] = tmp_path / "agent.key"
|
||||
_write_pem(paths["agent_key"], agent_key.private_bytes(pem, private_fmt, no_enc))
|
||||
|
||||
paths["agent_cert"] = tmp_path / "agent.crt"
|
||||
_write_pem(paths["agent_cert"], agent_cert.public_bytes(pem))
|
||||
|
||||
paths["rogue_key"] = tmp_path / "rogue.key"
|
||||
_write_pem(paths["rogue_key"], rogue_key.private_bytes(pem, private_fmt, no_enc))
|
||||
|
||||
paths["rogue_cert"] = tmp_path / "rogue.crt"
|
||||
_write_pem(paths["rogue_cert"], rogue_cert.public_bytes(pem))
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: is_mtls_configured
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsMtlsConfigured:
|
||||
def test_all_vars_missing(self):
|
||||
from agent.mtls import is_mtls_configured
|
||||
env = {k: "" for k in ("HERMES_MTLS_CERT", "HERMES_MTLS_KEY", "HERMES_MTLS_CA")}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
assert not is_mtls_configured()
|
||||
|
||||
def test_partial_vars(self, tmp_path):
|
||||
from agent.mtls import is_mtls_configured
|
||||
f = tmp_path / "cert.pem"
|
||||
f.write_text("x")
|
||||
env = {"HERMES_MTLS_CERT": str(f), "HERMES_MTLS_KEY": "", "HERMES_MTLS_CA": ""}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
assert not is_mtls_configured()
|
||||
|
||||
def test_all_vars_set_but_file_missing(self, tmp_path):
|
||||
from agent.mtls import is_mtls_configured
|
||||
env = {
|
||||
"HERMES_MTLS_CERT": str(tmp_path / "no.crt"),
|
||||
"HERMES_MTLS_KEY": str(tmp_path / "no.key"),
|
||||
"HERMES_MTLS_CA": str(tmp_path / "no-ca.crt"),
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
assert not is_mtls_configured()
|
||||
|
||||
def test_all_vars_set_and_files_exist(self, tmp_path):
|
||||
from agent.mtls import is_mtls_configured
|
||||
for name in ("cert.pem", "key.pem", "ca.pem"):
|
||||
(tmp_path / name).write_text("x")
|
||||
env = {
|
||||
"HERMES_MTLS_CERT": str(tmp_path / "cert.pem"),
|
||||
"HERMES_MTLS_KEY": str(tmp_path / "key.pem"),
|
||||
"HERMES_MTLS_CA": str(tmp_path / "ca.pem"),
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
assert is_mtls_configured()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: build_server_ssl_context / build_client_ssl_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildSslContexts:
|
||||
def test_raises_when_not_configured(self):
|
||||
from agent.mtls import build_server_ssl_context, build_client_ssl_context
|
||||
env = {"HERMES_MTLS_CERT": "", "HERMES_MTLS_KEY": "", "HERMES_MTLS_CA": ""}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
with pytest.raises(RuntimeError, match="not configured"):
|
||||
build_server_ssl_context()
|
||||
with pytest.raises(RuntimeError, match="not configured"):
|
||||
build_client_ssl_context()
|
||||
|
||||
def test_server_context_requires_client_cert(self, tmp_path):
|
||||
from agent.mtls import build_server_ssl_context
|
||||
pki = make_fleet_pki(tmp_path)
|
||||
env = {
|
||||
"HERMES_MTLS_CERT": str(pki["agent_cert"]),
|
||||
"HERMES_MTLS_KEY": str(pki["agent_key"]),
|
||||
"HERMES_MTLS_CA": str(pki["ca_cert"]),
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
ctx = build_server_ssl_context()
|
||||
assert isinstance(ctx, ssl.SSLContext)
|
||||
assert ctx.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
def test_client_context_has_cert_required(self, tmp_path):
|
||||
from agent.mtls import build_client_ssl_context
|
||||
pki = make_fleet_pki(tmp_path)
|
||||
env = {
|
||||
"HERMES_MTLS_CERT": str(pki["agent_cert"]),
|
||||
"HERMES_MTLS_KEY": str(pki["agent_key"]),
|
||||
"HERMES_MTLS_CA": str(pki["ca_cert"]),
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
ctx = build_client_ssl_context()
|
||||
assert isinstance(ctx, ssl.SSLContext)
|
||||
assert ctx.verify_mode == ssl.CERT_REQUIRED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: MTLSMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_scope(path: str, peer_cert=None) -> dict:
|
||||
"""Build a minimal ASGI HTTP scope, optionally with a fake TLS peer_cert."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": path,
|
||||
"extensions": {},
|
||||
}
|
||||
if peer_cert is not None:
|
||||
scope["extensions"]["tls"] = {"peer_cert": peer_cert}
|
||||
return scope
|
||||
|
||||
|
||||
async def _collect_response(middleware, scope):
|
||||
"""Drive the middleware and capture (status, body)."""
|
||||
status = None
|
||||
body = b""
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b""}
|
||||
|
||||
async def send(event):
|
||||
nonlocal status, body
|
||||
if event["type"] == "http.response.start":
|
||||
status = event["status"]
|
||||
elif event["type"] == "http.response.body":
|
||||
body += event.get("body", b"")
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
return status, body
|
||||
|
||||
|
||||
class TestMTLSMiddleware:
|
||||
"""
|
||||
Unit-test the MTLSMiddleware without spinning up a real server.
|
||||
We inject mTLS configuration through env-var patching so the middleware
|
||||
believes it is enabled, and use the ASGI scope's tls extension to simulate
|
||||
whether a client cert was presented.
|
||||
"""
|
||||
|
||||
def _make_middleware(self, tmp_path, app=None):
|
||||
"""Return a configured MTLSMiddleware backed by real-looking cert files."""
|
||||
from agent.mtls import MTLSMiddleware
|
||||
|
||||
for name in ("cert.pem", "key.pem", "ca.pem"):
|
||||
(tmp_path / name).write_text("x")
|
||||
|
||||
env = {
|
||||
"HERMES_MTLS_CERT": str(tmp_path / "cert.pem"),
|
||||
"HERMES_MTLS_KEY": str(tmp_path / "key.pem"),
|
||||
"HERMES_MTLS_CA": str(tmp_path / "ca.pem"),
|
||||
}
|
||||
|
||||
async def passthrough(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200, "headers": []})
|
||||
await send({"type": "http.response.body", "body": b"ok"})
|
||||
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
mw = MTLSMiddleware(app or passthrough)
|
||||
return mw
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_agent_accepted(self, tmp_path):
|
||||
"""An A2A route with a valid client cert passes through (200)."""
|
||||
mw = self._make_middleware(tmp_path)
|
||||
scope = _make_scope("/.well-known/agent-card.json", peer_cert={"subject": (("commonName", "timmy"),)})
|
||||
status, body = await _collect_response(mw, scope)
|
||||
assert status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_agent_rejected(self, tmp_path):
|
||||
"""An A2A route with NO client cert is rejected (403)."""
|
||||
mw = self._make_middleware(tmp_path)
|
||||
scope = _make_scope("/.well-known/agent-card.json", peer_cert=None)
|
||||
status, body = await _collect_response(mw, scope)
|
||||
assert status == 403
|
||||
assert b"certificate" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_a2a_route_not_gated(self, tmp_path):
|
||||
"""Non-A2A routes (like /api/status) pass through even without a cert."""
|
||||
mw = self._make_middleware(tmp_path)
|
||||
scope = _make_scope("/api/status", peer_cert=None)
|
||||
status, body = await _collect_response(mw, scope)
|
||||
assert status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_card_api_route_gated(self, tmp_path):
|
||||
"""The /api/agent-card route also requires a client cert."""
|
||||
mw = self._make_middleware(tmp_path)
|
||||
scope = _make_scope("/api/agent-card", peer_cert=None)
|
||||
status, _ = await _collect_response(mw, scope)
|
||||
assert status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_disabled_when_not_configured(self):
|
||||
"""When mTLS env vars are absent, the middleware is a no-op."""
|
||||
from agent.mtls import MTLSMiddleware
|
||||
|
||||
async def passthrough(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200, "headers": []})
|
||||
await send({"type": "http.response.body", "body": b"ok"})
|
||||
|
||||
env = {"HERMES_MTLS_CERT": "", "HERMES_MTLS_KEY": "", "HERMES_MTLS_CA": ""}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
mw = MTLSMiddleware(passthrough)
|
||||
|
||||
# Even an A2A route with no cert should pass through
|
||||
scope = _make_scope("/.well-known/agent-card.json", peer_cert=None)
|
||||
status, _ = await _collect_response(mw, scope)
|
||||
assert status == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_peer_cn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetPeerCn:
|
||||
def test_returns_cn_from_subject(self):
|
||||
from agent.mtls import get_peer_cn
|
||||
|
||||
class FakeSSL:
|
||||
def getpeercert(self):
|
||||
return {"subject": ((("commonName", "timmy"),),)}
|
||||
|
||||
assert get_peer_cn(FakeSSL()) == "timmy"
|
||||
|
||||
def test_returns_none_when_no_cert(self):
|
||||
from agent.mtls import get_peer_cn
|
||||
|
||||
class FakeSSL:
|
||||
def getpeercert(self):
|
||||
return None
|
||||
|
||||
assert get_peer_cn(FakeSSL()) is None
|
||||
|
||||
def test_returns_none_on_exception(self):
|
||||
from agent.mtls import get_peer_cn
|
||||
|
||||
class BrokenSSL:
|
||||
def getpeercert(self):
|
||||
raise RuntimeError("no ssl")
|
||||
|
||||
assert get_peer_cn(BrokenSSL()) is None
|
||||
274
tests/test_poka_yoke.py
Normal file
274
tests/test_poka_yoke.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
test_poka_yoke.py — Tests for the tool call validation firewall.
|
||||
|
||||
Covers: unknown tool, bad param type, missing required arg,
|
||||
extra unknown param, enum validation, closest-name suggestion.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from tools.poka_yoke import (
|
||||
validate_tool_call,
|
||||
_find_closest_name,
|
||||
_validate_type,
|
||||
_truncate,
|
||||
)
|
||||
|
||||
|
||||
# ── Mock Registry ─────────────────────────────────────────────────────────────
|
||||
|
||||
class MockEntry:
|
||||
def __init__(self, name, schema):
|
||||
self.name = name
|
||||
self.schema = schema
|
||||
self.toolset = "test"
|
||||
|
||||
|
||||
MOCK_TOOLS = {
|
||||
"read_file": MockEntry("read_file", {
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path"},
|
||||
"offset": {"type": "integer", "description": "Start line"},
|
||||
"limit": {"type": "integer", "description": "Max lines"},
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
}),
|
||||
"web_search": MockEntry("web_search", {
|
||||
"name": "web_search",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}),
|
||||
"write_file": MockEntry("write_file", {
|
||||
"name": "write_file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
},
|
||||
}),
|
||||
"terminal": MockEntry("terminal", {
|
||||
"name": "terminal",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string"},
|
||||
"timeout": {"type": "integer"},
|
||||
"background": {"type": "boolean"},
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
def _mock_registry():
|
||||
"""Create a mock registry."""
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_entry = lambda name: MOCK_TOOLS.get(name)
|
||||
mock_reg.get_all_tool_names = lambda: list(MOCK_TOOLS.keys())
|
||||
return mock_reg
|
||||
|
||||
|
||||
# ── Test: Unknown Tool ────────────────────────────────────────────────────────
|
||||
|
||||
class TestUnknownTool:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_unknown_tool_rejected(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = None
|
||||
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call("nonexistent_tool", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert len(msgs) > 0
|
||||
assert "nonexistent_tool" in msgs[0]
|
||||
assert "Unknown tool" in msgs[0]
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_unknown_tool_lists_available(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = None
|
||||
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call("foo", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert "read_file" in msgs[0]
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_close_name_suggests_correction(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = None
|
||||
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call("readfile", {})
|
||||
|
||||
assert "read_file" in msgs[0]
|
||||
assert name == "read_file"
|
||||
|
||||
|
||||
# ── Test: Missing Required Args ───────────────────────────────────────────────
|
||||
|
||||
class TestMissingRequired:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_missing_required_rejected(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call("read_file", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert any("Missing required" in m for m in msgs)
|
||||
assert any("'path'" in m for m in msgs)
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_multiple_missing_required(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call("write_file", {})
|
||||
|
||||
assert is_valid is False
|
||||
assert any("'path'" in m for m in msgs)
|
||||
assert any("'content'" in m for m in msgs)
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_required_present_passes(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
|
||||
# ── Test: Type Validation ─────────────────────────────────────────────────────
|
||||
|
||||
class TestTypeValidation:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_wrong_type_rejected(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt", "offset": "not_a_number"}
|
||||
)
|
||||
|
||||
assert is_valid is False
|
||||
assert any("offset" in m and "integer" in m for m in msgs)
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_string_to_int_coercion(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt", "offset": "42"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert params is not None
|
||||
assert params["offset"] == 42
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_boolean_coercion(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["terminal"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"terminal", {"command": "ls", "background": "true"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert params is not None
|
||||
assert params["background"] is True
|
||||
|
||||
|
||||
# ── Test: Unknown Parameters ──────────────────────────────────────────────────
|
||||
|
||||
class TestUnknownParams:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_unknown_param_removed(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt", "bogus_param": "value"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert params is not None
|
||||
assert "bogus_param" not in params
|
||||
assert "path" in params
|
||||
assert any("Unknown parameter" in m for m in msgs)
|
||||
|
||||
|
||||
# ── Test: Valid Calls Pass Through ────────────────────────────────────────────
|
||||
|
||||
class TestValidCalls:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_valid_read_file(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt", "offset": 1, "limit": 100}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert name is None
|
||||
assert params is None
|
||||
assert msgs == []
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_valid_write_file(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"write_file", {"path": "out.txt", "content": "hello"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
|
||||
# ── Test: Helper Functions ────────────────────────────────────────────────────
|
||||
|
||||
class TestHelpers:
|
||||
def test_find_closest_exact_prefix(self):
|
||||
assert _find_closest_name("readfil", ["read_file", "write_file"]) == "read_file"
|
||||
|
||||
def test_find_closest_substring(self):
|
||||
assert _find_closest_name("file", ["read_file", "web_search"]) == "read_file"
|
||||
|
||||
def test_find_closest_no_match(self):
|
||||
assert _find_closest_name("xyzzy", ["read_file", "write_file"]) is None
|
||||
|
||||
def test_validate_type_string(self):
|
||||
ok, val = _validate_type("x", "hello", "string")
|
||||
assert ok is True
|
||||
|
||||
def test_validate_type_int_coercion(self):
|
||||
ok, val = _validate_type("x", "42", "integer")
|
||||
assert ok is True
|
||||
assert val == 42
|
||||
|
||||
def test_validate_type_int_bad(self):
|
||||
ok, val = _validate_type("x", "not_int", "integer")
|
||||
assert ok is False
|
||||
|
||||
def test_truncate(self):
|
||||
assert _truncate("hello", 10) == "hello"
|
||||
assert _truncate("hello world", 8) == "hello..."
|
||||
76
tests/test_profile_isolation.py
Normal file
76
tests/test_profile_isolation.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for profile session isolation (#891)."""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Override paths for testing
|
||||
import agent.profile_isolation as iso_mod
|
||||
_test_dir = Path(tempfile.mkdtemp())
|
||||
iso_mod.PROFILE_TAGS_FILE = _test_dir / "tags.json"
|
||||
|
||||
|
||||
def test_tag_session():
|
||||
"""Session gets tagged with profile."""
|
||||
profile = iso_mod.tag_session("sess-1", "sprint")
|
||||
assert profile == "sprint"
|
||||
assert iso_mod.get_session_profile("sess-1") == "sprint"
|
||||
|
||||
|
||||
def test_default_profile():
|
||||
"""Sessions tagged with default when no profile specified."""
|
||||
profile = iso_mod.tag_session("sess-2")
|
||||
assert profile is not None
|
||||
|
||||
|
||||
def test_get_session_profile():
|
||||
"""Can retrieve profile for tagged session."""
|
||||
iso_mod.tag_session("sess-3", "fenrir")
|
||||
assert iso_mod.get_session_profile("sess-3") == "fenrir"
|
||||
|
||||
|
||||
def test_untagged_returns_none():
|
||||
"""Untagged session returns None."""
|
||||
assert iso_mod.get_session_profile("nonexistent") is None
|
||||
|
||||
|
||||
def test_profile_stats():
|
||||
"""Stats reflect tagged sessions."""
|
||||
iso_mod.tag_session("s1", "default")
|
||||
iso_mod.tag_session("s2", "sprint")
|
||||
iso_mod.tag_session("s3", "sprint")
|
||||
stats = iso_mod.get_profile_stats()
|
||||
assert stats["total_tagged_sessions"] >= 3
|
||||
assert "sprint" in stats["profile_counts"]
|
||||
|
||||
|
||||
def test_filter_sessions():
|
||||
"""Filter returns only matching profile sessions."""
|
||||
iso_mod.tag_session("filter-1", "alpha")
|
||||
iso_mod.tag_session("filter-2", "beta")
|
||||
iso_mod.tag_session("filter-3", "alpha")
|
||||
|
||||
sessions = [
|
||||
{"session_id": "filter-1"},
|
||||
{"session_id": "filter-2"},
|
||||
{"session_id": "filter-3"},
|
||||
]
|
||||
|
||||
filtered = iso_mod.filter_sessions_by_profile(sessions, "alpha")
|
||||
ids = [s["session_id"] for s in filtered]
|
||||
assert "filter-1" in ids
|
||||
assert "filter-3" in ids
|
||||
assert "filter-2" not in ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [test_tag_session, test_default_profile, test_get_session_profile,
|
||||
test_untagged_returns_none, test_profile_stats, test_filter_sessions]
|
||||
for t in tests:
|
||||
print(f"Running {t.__name__}...")
|
||||
t()
|
||||
print(" PASS")
|
||||
print("\nAll tests passed.")
|
||||
302
tests/test_skill_manager_autorevert.py
Normal file
302
tests/test_skill_manager_autorevert.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Integration tests for poka-yoke auto-revert on incomplete skill edits (#923).
|
||||
|
||||
Verifies the transactional write-validate-commit-or-rollback pattern:
|
||||
- Backup created before every write
|
||||
- Post-write validation triggers revert on corrupted/empty file
|
||||
- Successful writes clean up the backup
|
||||
- At most MAX_BACKUPS_PER_FILE backups retained per file
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.skill_manager_tool import (
|
||||
MAX_BACKUPS_PER_FILE,
|
||||
_backup_skill_file,
|
||||
_cleanup_old_backups,
|
||||
_edit_skill,
|
||||
_patch_skill,
|
||||
_revert_from_backup,
|
||||
_validate_written_file,
|
||||
_write_file,
|
||||
)
|
||||
|
||||
|
||||
VALID_SKILL_MD = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: A skill for testing auto-revert
|
||||
---
|
||||
|
||||
## Overview
|
||||
Test skill body content.
|
||||
"""
|
||||
|
||||
VALID_UPDATED_MD = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: Updated description
|
||||
---
|
||||
|
||||
## Overview
|
||||
Updated test skill body.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_skill(tmp_path: Path, content: str = VALID_SKILL_MD) -> Path:
|
||||
"""Write a minimal SKILL.md in *tmp_path* and return its path."""
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
return skill_md
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _backup_skill_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackupSkillFile:
|
||||
def test_creates_bak_file(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
backup = _backup_skill_file(skill_md)
|
||||
assert backup is not None
|
||||
assert backup.exists()
|
||||
assert ".bak." in backup.name
|
||||
|
||||
def test_backup_preserves_content(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
backup = _backup_skill_file(skill_md)
|
||||
assert backup.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||
|
||||
def test_no_backup_for_nonexistent_file(self, tmp_path):
|
||||
missing = tmp_path / "SKILL.md"
|
||||
assert _backup_skill_file(missing) is None
|
||||
|
||||
def test_backup_name_contains_timestamp(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
before = int(time.time())
|
||||
backup = _backup_skill_file(skill_md)
|
||||
after = int(time.time())
|
||||
ts = int(backup.name.split(".bak.")[-1])
|
||||
assert before <= ts <= after
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _cleanup_old_backups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupOldBackups:
|
||||
def _create_backups(self, skill_md: Path, n: int) -> list:
|
||||
backups = []
|
||||
for i in range(n):
|
||||
bp = skill_md.parent / f"{skill_md.name}.bak.{1000 + i}"
|
||||
bp.write_text("backup content", encoding="utf-8")
|
||||
backups.append(bp)
|
||||
return backups
|
||||
|
||||
def test_prunes_excess_backups(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE + 2)
|
||||
_cleanup_old_backups(skill_md)
|
||||
remaining = list(tmp_path.glob(f"SKILL.md.bak.*"))
|
||||
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||
|
||||
def test_keeps_backups_within_limit(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE)
|
||||
_cleanup_old_backups(skill_md)
|
||||
remaining = list(tmp_path.glob("SKILL.md.bak.*"))
|
||||
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||
|
||||
def test_noop_when_no_backups(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
_cleanup_old_backups(skill_md) # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _validate_written_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateWrittenFile:
|
||||
def test_valid_skill_md(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
assert _validate_written_file(skill_md, is_skill_md=True) is None
|
||||
|
||||
def test_empty_file_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=False)
|
||||
assert err is not None
|
||||
assert "empty" in err.lower()
|
||||
|
||||
def test_broken_frontmatter_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("Not a skill\nno frontmatter\n", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
assert err is not None
|
||||
|
||||
def test_missing_required_field_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("---\ndescription: no name\n---\nbody\n", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
assert err is not None
|
||||
assert "name" in err.lower()
|
||||
|
||||
def test_missing_file_returns_error(self, tmp_path):
|
||||
missing = tmp_path / "SKILL.md"
|
||||
err = _validate_written_file(missing, is_skill_md=False)
|
||||
assert err is not None
|
||||
|
||||
def test_non_skill_md_only_checks_emptiness(self, tmp_path):
|
||||
ref = tmp_path / "references" / "guide.md"
|
||||
ref.parent.mkdir()
|
||||
ref.write_text("# Guide\nsome content\n", encoding="utf-8")
|
||||
assert _validate_written_file(ref, is_skill_md=False) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _revert_from_backup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRevertFromBackup:
|
||||
def test_restores_from_backup(self, tmp_path):
|
||||
original = "original content"
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text(original, encoding="utf-8")
|
||||
backup = tmp_path / "SKILL.md.bak.99999"
|
||||
backup.write_text(original, encoding="utf-8")
|
||||
|
||||
skill_md.write_text("corrupted content", encoding="utf-8")
|
||||
_revert_from_backup(skill_md, backup)
|
||||
assert skill_md.read_text(encoding="utf-8") == original
|
||||
|
||||
def test_removes_file_when_no_backup(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("corrupted", encoding="utf-8")
|
||||
_revert_from_backup(skill_md, None)
|
||||
assert not skill_md.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: _edit_skill auto-revert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditSkillAutoRevert:
|
||||
@pytest.fixture
|
||||
def skill_dir(self, tmp_path):
|
||||
"""Create a minimal skill directory and patch _find_skill."""
|
||||
d = tmp_path / "test-skill"
|
||||
d.mkdir()
|
||||
skill_md = d / "SKILL.md"
|
||||
skill_md.write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||
return d
|
||||
|
||||
def test_successful_edit_removes_backup(self, skill_dir):
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is True
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) == 0
|
||||
|
||||
def test_revert_when_post_write_validation_fails(self, skill_dir):
|
||||
"""Simulate a write that produces an empty file on disk."""
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
|
||||
def corrupt_write(path, content, **kw):
|
||||
# Write an empty file to simulate truncation
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "reverted" in result["error"].lower()
|
||||
# Original content restored
|
||||
assert skill_md.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||
|
||||
def test_backup_preserved_after_revert(self, skill_dir):
|
||||
"""A .bak file should survive when the edit is reverted (debugging aid)."""
|
||||
def corrupt_write(path, content, **kw):
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
_edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) == 1
|
||||
|
||||
def test_max_backups_enforced_after_multiple_edits(self, skill_dir):
|
||||
"""After many successful edits, at most MAX_BACKUPS_PER_FILE .bak files remain."""
|
||||
n = MAX_BACKUPS_PER_FILE + 4
|
||||
for i in range(n):
|
||||
# Plant stale backup files to simulate prior runs
|
||||
bp = skill_dir / f"SKILL.md.bak.{1000 + i}"
|
||||
bp.write_text("old backup", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is True
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) <= MAX_BACKUPS_PER_FILE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: _patch_skill auto-revert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPatchSkillAutoRevert:
|
||||
@pytest.fixture
|
||||
def skill_dir(self, tmp_path):
|
||||
d = tmp_path / "test-skill"
|
||||
d.mkdir()
|
||||
(d / "SKILL.md").write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||
return d
|
||||
|
||||
def test_successful_patch_removes_backup(self, skill_dir):
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _patch_skill(
|
||||
"test-skill",
|
||||
"A skill for testing auto-revert",
|
||||
"Updated description",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(list(skill_dir.glob("SKILL.md.bak.*"))) == 0
|
||||
|
||||
def test_revert_on_corrupt_write(self, skill_dir):
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
original = skill_md.read_text(encoding="utf-8")
|
||||
|
||||
def corrupt_write(path, content, **kw):
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _patch_skill(
|
||||
"test-skill",
|
||||
"A skill for testing",
|
||||
"A skill for testing auto-revert",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "reverted" in result["error"].lower()
|
||||
assert skill_md.read_text(encoding="utf-8") == original
|
||||
82
tests/test_syntax_validation.py
Normal file
82
tests/test_syntax_validation.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for Python syntax validation in execute_code."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the validation function directly
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from tools.code_execution_tool import _validate_python_syntax
|
||||
|
||||
|
||||
class TestValidatePythonSyntax:
|
||||
"""Test _validate_python_syntax catches errors before subprocess spawn."""
|
||||
|
||||
def test_valid_code_returns_none(self):
|
||||
assert _validate_python_syntax("print('hello')") is None
|
||||
|
||||
def test_valid_multiline_returns_none(self):
|
||||
code = """
|
||||
import os
|
||||
def foo():
|
||||
return 42
|
||||
result = foo()
|
||||
"""
|
||||
assert _validate_python_syntax(code) is None
|
||||
|
||||
def test_syntax_error_detected(self):
|
||||
result = _validate_python_syntax("def foo(
|
||||
")
|
||||
assert result is not None
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert "line" in data
|
||||
assert "message" in data
|
||||
|
||||
def test_missing_colon(self):
|
||||
result = _validate_python_syntax("def foo()
|
||||
pass")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert data["line"] == 1
|
||||
|
||||
def test_unmatched_paren(self):
|
||||
result = _validate_python_syntax("print('hello'")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
|
||||
def test_indentation_error(self):
|
||||
result = _validate_python_syntax("def foo():
|
||||
pass")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert data["line"] == 2
|
||||
|
||||
def test_invalid_character(self):
|
||||
result = _validate_python_syntax("x = 5 √ 2")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
|
||||
def test_error_format_has_required_fields(self):
|
||||
result = _validate_python_syntax("def(
|
||||
")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "syntax_error" in data
|
||||
assert "line" in data
|
||||
assert "offset" in data
|
||||
assert "message" in data
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
# Empty code is caught by the guard before validation
|
||||
# But if called directly, ast.parse("") is valid
|
||||
assert _validate_python_syntax("") is None
|
||||
|
||||
def test_comment_only_returns_none(self):
|
||||
assert _validate_python_syntax("# just a comment") is None
|
||||
|
||||
def test_complex_valid_code(self):
|
||||
code =
|
||||
58
tests/test_time_aware_routing.py
Normal file
58
tests/test_time_aware_routing.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for time-aware model routing."""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from agent.time_aware_routing import (
|
||||
resolve_time_aware_model,
|
||||
get_hour_error_rate,
|
||||
is_off_hours,
|
||||
get_routing_report,
|
||||
)
|
||||
|
||||
|
||||
class TestErrorRates:
|
||||
def test_evening_high_error(self):
|
||||
assert get_hour_error_rate(18) == 9.4
|
||||
assert get_hour_error_rate(19) == 8.1
|
||||
|
||||
def test_morning_low_error(self):
|
||||
assert get_hour_error_rate(9) == 4.0
|
||||
assert get_hour_error_rate(12) == 4.0
|
||||
|
||||
def test_default_for_unknown(self):
|
||||
assert get_hour_error_rate(15) == 4.0
|
||||
|
||||
|
||||
class TestOffHours:
|
||||
def test_evening_is_off_hours(self):
|
||||
assert is_off_hours(20) is True
|
||||
assert is_off_hours(2) is True
|
||||
|
||||
def test_business_hours_not_off(self):
|
||||
assert is_off_hours(9) is False
|
||||
assert is_off_hours(14) is False
|
||||
|
||||
|
||||
class TestRouting:
|
||||
def test_interactive_uses_base_model(self):
|
||||
d = resolve_time_aware_model("my-model", "my-provider", is_cron=False, hour=18)
|
||||
assert d.model == "my-model"
|
||||
assert "Interactive" in d.reason
|
||||
|
||||
def test_cron_low_error_uses_base(self):
|
||||
d = resolve_time_aware_model("cheap-model", is_cron=True, hour=10)
|
||||
assert d.model == "cheap-model"
|
||||
|
||||
def test_cron_high_error_upgrades(self):
|
||||
d = resolve_time_aware_model("cheap-model", is_cron=True, hour=18)
|
||||
assert d.model != "cheap-model"
|
||||
assert d.is_off_hours is True
|
||||
|
||||
def test_routing_report(self):
|
||||
report = get_routing_report()
|
||||
assert "Time-Aware Model Routing" in report
|
||||
assert "18:00" in report
|
||||
237
tests/test_token_budget.py
Normal file
237
tests/test_token_budget.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for agent/token_budget.py — Poka-yoke context overflow guard.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from agent.token_budget import (
|
||||
TokenBudget,
|
||||
BudgetLevel,
|
||||
BudgetStatus,
|
||||
WARN_PERCENT,
|
||||
CAUTION_PERCENT,
|
||||
CRITICAL_PERCENT,
|
||||
STOP_PERCENT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def budget():
|
||||
"""Standard 128K context budget."""
|
||||
return TokenBudget(context_length=128_000)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_budget():
|
||||
"""4K context for tight testing."""
|
||||
return TokenBudget(context_length=4_000)
|
||||
|
||||
|
||||
# ── Threshold Levels ──────────────────────────────────────────────────
|
||||
|
||||
class TestThresholds:
|
||||
def test_normal_below_60(self, budget):
|
||||
budget.update(50_000) # 39%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.NORMAL
|
||||
assert not status.should_compress
|
||||
assert not status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_warning_at_60(self, budget):
|
||||
budget.update(int(128_000 * 0.62)) # 62%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.WARNING
|
||||
assert not status.should_compress
|
||||
assert not status.should_block_tools
|
||||
|
||||
def test_caution_at_80(self, budget):
|
||||
budget.update(int(128_000 * 0.82)) # 82%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.CAUTION
|
||||
assert status.should_compress
|
||||
assert not status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_critical_at_90(self, budget):
|
||||
budget.update(int(128_000 * 0.91)) # 91%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.CRITICAL
|
||||
assert status.should_compress
|
||||
assert status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_stop_at_95(self, budget):
|
||||
budget.update(int(128_000 * 0.96)) # 96%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.STOP
|
||||
assert status.should_compress
|
||||
assert status.should_block_tools
|
||||
assert status.should_terminate
|
||||
|
||||
def test_small_context_thresholds(self, small_budget):
|
||||
# 4K * 0.60 = 2400
|
||||
small_budget.update(2450)
|
||||
assert small_budget.check().level == BudgetLevel.WARNING
|
||||
|
||||
small_budget.update(3250) # 4K * 0.81
|
||||
assert small_budget.check().level == BudgetLevel.CAUTION
|
||||
|
||||
small_budget.update(3650) # 4K * 0.91
|
||||
assert small_budget.check().level == BudgetLevel.CRITICAL
|
||||
|
||||
small_budget.update(3850) # 4K * 0.96
|
||||
assert small_budget.check().level == BudgetLevel.STOP
|
||||
|
||||
|
||||
# ── Convenience Methods ───────────────────────────────────────────────
|
||||
|
||||
class TestConvenienceMethods:
|
||||
def test_should_compress(self, budget):
|
||||
budget.update(int(128_000 * 0.79))
|
||||
assert not budget.should_compress()
|
||||
budget.update(int(128_000 * 0.80))
|
||||
assert budget.should_compress()
|
||||
|
||||
def test_should_block_tools(self, budget):
|
||||
budget.update(int(128_000 * 0.89))
|
||||
assert not budget.should_block_tools()
|
||||
budget.update(int(128_000 * 0.90))
|
||||
assert budget.should_block_tools()
|
||||
|
||||
def test_should_terminate(self, budget):
|
||||
budget.update(int(128_000 * 0.94))
|
||||
assert not budget.should_terminate()
|
||||
budget.update(int(128_000 * 0.95))
|
||||
assert budget.should_terminate()
|
||||
|
||||
|
||||
# ── Tool Output Budgeting ─────────────────────────────────────────────
|
||||
|
||||
class TestToolOutputBudget:
|
||||
def test_normal_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
assert budget.tool_output_budget() == 50_000
|
||||
|
||||
def test_warning_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.65))
|
||||
assert budget.tool_output_budget() == 20_000
|
||||
|
||||
def test_caution_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.85))
|
||||
assert budget.tool_output_budget() == 8_000
|
||||
|
||||
def test_critical_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.92))
|
||||
assert budget.tool_output_budget() == 2_000
|
||||
|
||||
def test_truncate_short_unchanged(self, budget):
|
||||
result = budget.truncate_tool_output("short text", max_chars=1000)
|
||||
assert result == "short text"
|
||||
|
||||
def test_truncate_long(self, budget):
|
||||
long_text = "A" * 100_000
|
||||
result = budget.truncate_tool_output(long_text, max_chars=5_000)
|
||||
assert len(result) <= 5_100 # small overhead for notice
|
||||
assert "truncated" in result
|
||||
assert "A" in result[:2500] # head preserved
|
||||
assert "A" in result[-2500:] # tail preserved
|
||||
|
||||
def test_truncate_very_small(self, budget):
|
||||
long_text = "X" * 1000
|
||||
result = budget.truncate_tool_output(long_text, max_chars=50)
|
||||
assert len(result) <= 50 + 20
|
||||
assert "truncated" in result
|
||||
|
||||
|
||||
# ── Growth Tracking ───────────────────────────────────────────────────
|
||||
|
||||
class TestGrowthTracking:
|
||||
def test_growth_rate(self, budget):
|
||||
budget.update(10_000)
|
||||
budget.update(15_000)
|
||||
budget.update(20_000)
|
||||
assert budget.growth_rate() == 5_000.0
|
||||
|
||||
def test_turns_remaining(self, budget):
|
||||
budget.update(10_000)
|
||||
budget.update(15_000)
|
||||
budget.update(20_000)
|
||||
# rate=5000, remaining=108000, turns=~21
|
||||
turns = budget.turns_remaining()
|
||||
assert turns is not None
|
||||
assert 18 <= turns <= 24
|
||||
|
||||
def test_no_history(self, budget):
|
||||
assert budget.growth_rate() is None
|
||||
assert budget.turns_remaining() is None
|
||||
|
||||
|
||||
# ── Status Indicators ─────────────────────────────────────────────────
|
||||
|
||||
class TestStatusIndicators:
|
||||
def test_indicator_normal(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
status = budget.check()
|
||||
indicator = status.to_indicator()
|
||||
assert "50" in indicator
|
||||
|
||||
def test_indicator_warning(self, budget):
|
||||
budget.update(int(128_000 * 0.65))
|
||||
status = budget.check()
|
||||
indicator = status.to_indicator()
|
||||
assert "\u26a0" in indicator or "65" in indicator
|
||||
|
||||
def test_bar(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
status = budget.check()
|
||||
bar = status.to_bar()
|
||||
assert "50" in bar
|
||||
|
||||
def test_summary(self, budget):
|
||||
budget.update(50_000)
|
||||
summary = budget.summary()
|
||||
assert "50,000" in summary
|
||||
assert "128,000" in summary
|
||||
assert "NORMAL" in summary
|
||||
|
||||
|
||||
# ── Reset ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestReset:
|
||||
def test_reset_clears_state(self, budget):
|
||||
budget.update(int(128_000 * 0.90))
|
||||
budget.reset()
|
||||
assert budget.tokens_used == 0
|
||||
assert budget.check().level == BudgetLevel.NORMAL
|
||||
assert budget.growth_rate() is None
|
||||
|
||||
|
||||
# ── Edge Cases ────────────────────────────────────────────────────────
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_exact_threshold_boundary(self, budget):
|
||||
# Exactly at 60%
|
||||
budget.update(int(128_000 * 0.60))
|
||||
assert budget.check().level == BudgetLevel.WARNING
|
||||
|
||||
def test_zero_context(self):
|
||||
budget = TokenBudget(context_length=0)
|
||||
status = budget.check()
|
||||
assert status.percent_used == 0
|
||||
|
||||
def test_remaining_for_response(self, budget):
|
||||
budget.update(100_000)
|
||||
remaining = budget.remaining_for_response()
|
||||
# 128000 - 100000 - 6400 (5% reserve) = 21600
|
||||
assert remaining > 0
|
||||
assert remaining < 128_000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
76
tests/test_tool_fixation_detector.py
Normal file
76
tests/test_tool_fixation_detector.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for tool fixation detection."""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from agent.tool_fixation_detector import ToolFixationDetector, get_fixation_detector
|
||||
|
||||
|
||||
class TestFixationDetection:
|
||||
def test_no_fixation_below_threshold(self):
|
||||
d = ToolFixationDetector(threshold=5)
|
||||
for i in range(4):
|
||||
assert d.record("execute_code") is None
|
||||
|
||||
def test_fixation_at_threshold(self):
|
||||
d = ToolFixationDetector(threshold=3)
|
||||
d.record("execute_code")
|
||||
d.record("execute_code")
|
||||
nudge = d.record("execute_code")
|
||||
assert nudge is not None
|
||||
assert "execute_code" in nudge
|
||||
assert "3 times" in nudge
|
||||
|
||||
def test_fixation_above_threshold(self):
|
||||
d = ToolFixationDetector(threshold=3)
|
||||
d.record("execute_code")
|
||||
d.record("execute_code")
|
||||
d.record("execute_code") # threshold hit
|
||||
nudge = d.record("execute_code") # still nudging
|
||||
assert nudge is not None
|
||||
|
||||
def test_streak_resets_on_different_tool(self):
|
||||
d = ToolFixationDetector(threshold=3)
|
||||
d.record("execute_code")
|
||||
d.record("execute_code")
|
||||
d.record("terminal") # breaks streak
|
||||
assert d._streak_count == 1
|
||||
assert d._current_streak == "terminal"
|
||||
|
||||
def test_nudges_sent_counter(self):
|
||||
d = ToolFixationDetector(threshold=2)
|
||||
d.record("a")
|
||||
d.record("a") # nudge 1
|
||||
d.record("a") # nudge 2
|
||||
assert d.nudges_sent == 2
|
||||
|
||||
def test_events_recorded(self):
|
||||
d = ToolFixationDetector(threshold=2)
|
||||
d.record("x")
|
||||
d.record("x")
|
||||
assert len(d.events) == 1
|
||||
assert d.events[0].tool_name == "x"
|
||||
assert d.events[0].streak_length == 2
|
||||
|
||||
def test_report(self):
|
||||
d = ToolFixationDetector(threshold=2)
|
||||
d.record("x")
|
||||
d.record("x")
|
||||
report = d.format_report()
|
||||
assert "x" in report
|
||||
|
||||
def test_reset(self):
|
||||
d = ToolFixationDetector(threshold=2)
|
||||
d.record("x")
|
||||
d.record("x")
|
||||
d.reset()
|
||||
assert d._streak_count == 0
|
||||
assert d._current_streak == ""
|
||||
|
||||
def test_singleton(self):
|
||||
d1 = get_fixation_detector()
|
||||
d2 = get_fixation_detector()
|
||||
assert d1 is d2
|
||||
67
tests/test_tool_validator.py
Normal file
67
tests/test_tool_validator.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Tests for tool hallucination detection (#922).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from tools.tool_validator import ToolHallucinationDetector, ValidationSeverity
|
||||
|
||||
|
||||
class TestToolHallucinationDetector:
|
||||
def setup_method(self):
|
||||
self.detector = ToolHallucinationDetector()
|
||||
self.detector.register_tool("read_file", {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
"encoding": {"type": "string"},
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
})
|
||||
|
||||
def test_valid_tool_call(self):
|
||||
result = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt"})
|
||||
assert result.valid is True
|
||||
assert len(result.blocking_issues) == 0
|
||||
|
||||
def test_unknown_tool(self):
|
||||
result = self.detector.validate_tool_call("hallucinated_tool", {})
|
||||
assert result.valid is False
|
||||
assert any(i.code == "UNKNOWN_TOOL" for i in result.issues)
|
||||
|
||||
def test_missing_required_param(self):
|
||||
result = self.detector.validate_tool_call("read_file", {})
|
||||
assert result.valid is False
|
||||
assert any(i.code == "MISSING_REQUIRED" for i in result.issues)
|
||||
|
||||
def test_wrong_type(self):
|
||||
result = self.detector.validate_tool_call("read_file", {"path": 123})
|
||||
assert result.valid is False
|
||||
assert any(i.code == "WRONG_TYPE" for i in result.issues)
|
||||
|
||||
def test_unknown_param_warning(self):
|
||||
result = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt", "unknown": "value"})
|
||||
assert result.valid is True # Warning, not blocking
|
||||
assert any(i.code == "UNKNOWN_PARAM" for i in result.issues)
|
||||
|
||||
def test_placeholder_detection(self):
|
||||
result = self.detector.validate_tool_call("read_file", {"path": "<placeholder>"})
|
||||
assert any(i.code == "PLACEHOLDER_VALUE" for i in result.issues)
|
||||
|
||||
def test_rejection_stats(self):
|
||||
self.detector.validate_tool_call("unknown_tool", {})
|
||||
self.detector.validate_tool_call("read_file", {})
|
||||
stats = self.detector.get_rejection_stats()
|
||||
assert stats["total"] >= 2
|
||||
|
||||
def test_rejection_response(self):
|
||||
from tools.tool_validator import create_rejection_response
|
||||
result = self.detector.validate_tool_call("unknown_tool", {})
|
||||
response = create_rejection_response(result)
|
||||
assert response["role"] == "tool"
|
||||
assert "rejected" in response["content"].lower()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -28,6 +28,7 @@ Platform: Linux / macOS only (Unix domain sockets for local). Disabled on Window
|
||||
Remote execution additionally requires Python 3 in the terminal backend.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@@ -883,6 +884,42 @@ def _execute_remote(
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
|
||||
def _validate_python_syntax(code: str) -> Optional[str]:
|
||||
"""Validate Python syntax before subprocess spawn.
|
||||
|
||||
Runs ast.parse() in-process (sub-millisecond) to catch syntax errors
|
||||
before wasting time spawning a sandboxed subprocess.
|
||||
|
||||
Returns:
|
||||
JSON error string with line, offset, message if syntax is invalid.
|
||||
None if syntax is valid.
|
||||
"""
|
||||
try:
|
||||
ast.parse(code)
|
||||
return None
|
||||
except SyntaxError as exc:
|
||||
# Build context: show offending line with caret
|
||||
lines = code.split("\n")
|
||||
error_line = lines[exc.lineno - 1] if exc.lineno and exc.lineno <= len(lines) else ""
|
||||
context = ""
|
||||
if error_line:
|
||||
context = f"\n {error_line}"
|
||||
if exc.offset:
|
||||
context += f"\n {' ' * (exc.offset - 1)}^"
|
||||
|
||||
return json.dumps({
|
||||
"error": f"Python syntax error on line {exc.lineno}: {exc.msg}{context}",
|
||||
"syntax_error": True,
|
||||
"line": exc.lineno,
|
||||
"offset": exc.offset,
|
||||
"message": exc.msg,
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -916,6 +953,11 @@ def execute_code(
|
||||
if not code or not code.strip():
|
||||
return tool_error("No code provided.")
|
||||
|
||||
# Syntax check before subprocess spawn (catches ~15% of errors in <1ms)
|
||||
syntax_error = _validate_python_syntax(code)
|
||||
if syntax_error:
|
||||
return syntax_error
|
||||
|
||||
# Dispatch: remote backends use file-based RPC, local uses UDS
|
||||
from tools.terminal_tool import _get_env_config
|
||||
env_type = _get_env_config()["env_type"]
|
||||
|
||||
183
tools/credential_redact.py
Normal file
183
tools/credential_redact.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Credential Redaction — Block silent credential exposure in tool outputs
|
||||
|
||||
Poka-yoke: Prevent API keys, tokens, passwords from leaking into context.
|
||||
|
||||
Issue: #839
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
AUDIT_DIR = HERMES_HOME / "audit"
|
||||
|
||||
# Credential patterns to detect and redact
|
||||
CREDENTIAL_PATTERNS = [
|
||||
# API keys
|
||||
(r"sk-[a-zA-Z0-9]{20,}", "[REDACTED: OpenAI API key]"),
|
||||
(r"sk-ant-[a-zA-Z0-9-]{20,}", "[REDACTED: Anthropic API key]"),
|
||||
(r"ghp_[a-zA-Z0-9]{36}", "[REDACTED: GitHub token]"),
|
||||
(r"gho_[a-zA-Z0-9]{36}", "[REDACTED: GitHub OAuth token]"),
|
||||
(r"glpat-[a-zA-Z0-9-]{20,}", "[REDACTED: GitLab token]"),
|
||||
|
||||
# Bearer tokens
|
||||
(r"Bearer\s+[a-zA-Z0-9._-]{20,}", "[REDACTED: Bearer token]"),
|
||||
(r"bearer\s+[a-zA-Z0-9._-]{20,}", "[REDACTED: Bearer token]"),
|
||||
|
||||
# Generic tokens/passwords
|
||||
(r"(?:token|TOKEN|Token)[:=]\s*["']?[a-zA-Z0-9._-]{20,}["']?", "[REDACTED: Token]"),
|
||||
(r"(?:password|PASSWORD|Password)[:=]\s*["']?[^\s"']{8,}["']?", "[REDACTED: Password]"),
|
||||
(r"(?:secret|SECRET|Secret)[:=]\s*["']?[a-zA-Z0-9._-]{20,}["']?", "[REDACTED: Secret]"),
|
||||
(r"(?:api_key|API_KEY|apiKey|ApiKey)[:=]\s*["']?[a-zA-Z0-9._-]{20,}["']?", "[REDACTED: API key]"),
|
||||
|
||||
# AWS keys
|
||||
(r"AKIA[0-9A-Z]{16}", "[REDACTED: AWS access key]"),
|
||||
(r"(?:aws_secret_access_key|AWS_SECRET_ACCESS_KEY)[:=]\s*["']?[a-zA-Z0-9/+=]{40}["']?", "[REDACTED: AWS secret]"),
|
||||
|
||||
# Private keys
|
||||
(r"-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----", "[REDACTED: Private key header]"),
|
||||
|
||||
# Connection strings
|
||||
(r"(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@[^\s]+", "[REDACTED: Database connection string]"),
|
||||
]
|
||||
|
||||
# Files that should trigger auto-masking
|
||||
SENSITIVE_FILE_PATTERNS = [
|
||||
r"\.env$",
|
||||
r"\.env\.",
|
||||
r"\.secret",
|
||||
r"credentials",
|
||||
r"\.token",
|
||||
r"config\.yaml$",
|
||||
r"config\.yml$",
|
||||
r"config\.json$",
|
||||
r"\.netrc$",
|
||||
r"\.pgpass$",
|
||||
]
|
||||
|
||||
|
||||
class CredentialRedactor:
|
||||
"""Redact credentials from text."""
|
||||
|
||||
def __init__(self, audit_log: bool = True):
|
||||
self.audit_log = audit_log
|
||||
self._redaction_count = 0
|
||||
|
||||
def redact(self, text: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Redact credentials from text.
|
||||
|
||||
Returns:
|
||||
Tuple of (redacted_text, number_of_redactions)
|
||||
"""
|
||||
if not text:
|
||||
return text, 0
|
||||
|
||||
redacted = text
|
||||
count = 0
|
||||
|
||||
for pattern, replacement in CREDENTIAL_PATTERNS:
|
||||
matches = re.findall(pattern, redacted, re.IGNORECASE)
|
||||
if matches:
|
||||
redacted = re.sub(pattern, replacement, redacted, flags=re.IGNORECASE)
|
||||
count += len(matches)
|
||||
|
||||
if count > 0:
|
||||
self._redaction_count += count
|
||||
if self.audit_log:
|
||||
self._log_redaction(count, text[:100])
|
||||
|
||||
return redacted, count
|
||||
|
||||
def redact_tool_output(self, tool_name: str, output: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Redact tool output and return notice if redactions occurred.
|
||||
|
||||
Returns:
|
||||
Tuple of (redacted_output, notice_or_empty)
|
||||
"""
|
||||
redacted, count = self.redact(output)
|
||||
|
||||
if count > 0:
|
||||
notice = f"[REDACTED: {count} credential pattern{'s' if count > 1 else ''} found in {tool_name} output]"
|
||||
return redacted, notice
|
||||
|
||||
return redacted, ""
|
||||
|
||||
def should_mask_file(self, file_path: str) -> bool:
|
||||
"""Check if file should have credentials auto-masked."""
|
||||
path_lower = file_path.lower()
|
||||
return any(re.search(p, path_lower) for p in SENSITIVE_FILE_PATTERNS)
|
||||
|
||||
def mask_file_content(self, content: str, file_path: str) -> str:
|
||||
"""Mask credentials in file content while preserving structure."""
|
||||
if not self.should_mask_file(file_path):
|
||||
return content
|
||||
|
||||
lines = content.split("\n")
|
||||
masked_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Preserve key=value structure but mask values
|
||||
if "=" in line and not line.strip().startswith("#"):
|
||||
key, _, value = line.partition("=")
|
||||
key_lower = key.strip().lower()
|
||||
|
||||
sensitive_keys = ["password", "secret", "token", "key", "api", "credential"]
|
||||
if any(sk in key_lower for sk in sensitive_keys):
|
||||
masked_lines.append(f"{key}=[REDACTED]")
|
||||
else:
|
||||
masked_lines.append(line)
|
||||
else:
|
||||
masked_lines.append(line)
|
||||
|
||||
return "\n".join(masked_lines)
|
||||
|
||||
def _log_redaction(self, count: int, preview: str):
|
||||
"""Log redaction event to audit trail."""
|
||||
try:
|
||||
AUDIT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
audit_file = AUDIT_DIR / "redactions.jsonl"
|
||||
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"redactions": count,
|
||||
"preview_hash": hash(preview),
|
||||
}
|
||||
|
||||
with open(audit_file, "a") as f:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Audit log failed: %s", e)
|
||||
|
||||
|
||||
# Module-level redactor
|
||||
_redactor = CredentialRedactor()
|
||||
|
||||
|
||||
def redact_credentials(text: str) -> Tuple[str, int]:
|
||||
"""Redact credentials from text."""
|
||||
return _redactor.redact(text)
|
||||
|
||||
|
||||
def redact_tool_output(tool_name: str, output: str) -> Tuple[str, str]:
|
||||
"""Redact tool output and return notice."""
|
||||
return _redactor.redact_tool_output(tool_name, output)
|
||||
|
||||
|
||||
def should_mask_file(file_path: str) -> bool:
|
||||
"""Check if file should be masked."""
|
||||
return _redactor.should_mask_file(file_path)
|
||||
|
||||
|
||||
def mask_sensitive_file(content: str, file_path: str) -> str:
|
||||
"""Mask credentials in sensitive file."""
|
||||
return _redactor.mask_file_content(content, file_path)
|
||||
@@ -327,6 +327,33 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# ── Path existence guard (poka-yoke #887) ─────────────────────
|
||||
# Check if file exists before attempting read. 83.7% of read_file
|
||||
# errors are file-not-found — the agent hallucinates paths.
|
||||
# This guard catches them early with a clear, actionable error.
|
||||
if not _resolved.exists():
|
||||
# Try to suggest similar files in the same directory
|
||||
parent = _resolved.parent
|
||||
suggestion = ""
|
||||
if parent.exists() and parent.is_dir():
|
||||
similar = [
|
||||
f.name for f in parent.iterdir()
|
||||
if f.is_file() and _resolved.stem[:3].lower() in f.stem.lower()
|
||||
][:5]
|
||||
if similar:
|
||||
suggestion = f" Similar files in {parent}: {', '.join(similar)}"
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"File not found: '{path}'. The file does not exist at the resolved path "
|
||||
f"({_resolved}).{suggestion} "
|
||||
"Use search_files to find the correct path first."
|
||||
),
|
||||
"path": path,
|
||||
"resolved": str(_resolved),
|
||||
"suggestion": "Use search_files(pattern='...', target='files') to find files.",
|
||||
})
|
||||
|
||||
# ── Dedup check ───────────────────────────────────────────────
|
||||
# If we already read this exact (path, offset, limit) and the
|
||||
# file hasn't been modified since, return a lightweight stub
|
||||
|
||||
113
tools/hardcoded_path_guard.py
Normal file
113
tools/hardcoded_path_guard.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hardcoded Path Guard — Poka-Yoke #921
|
||||
|
||||
Detects and blocks hardcoded home-directory paths in tool arguments.
|
||||
These paths work on one machine but break on others, VPS deployments,
|
||||
or when HOME changes.
|
||||
|
||||
Usage:
|
||||
from tools.hardcoded_path_guard import check_path, validate_tool_args
|
||||
|
||||
# Check a single path
|
||||
err = check_path("/Users/apayne/.hermes/config.yaml")
|
||||
|
||||
# Validate all path-like args in a tool call
|
||||
clean_args, warnings = validate_tool_args("read_file", {"path": "/home/user/file.txt"})
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json as _json
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
|
||||
# Patterns that indicate hardcoded home directories
|
||||
HARDCODED_PATTERNS = [
|
||||
(r"/Users/[\w.\-]+/", "macOS home directory (/Users/...)"),
|
||||
(r"/home/[\w.\-]+/", "Linux home directory (/home/...)"),
|
||||
(r"(?<![\w/])~/", "unexpanded tilde (~/)"),
|
||||
(r"/root/", "root home directory (/root/)"),
|
||||
]
|
||||
|
||||
_COMPILED_PATTERNS = [(re.compile(p), desc) for p, desc in HARDCODED_PATTERNS]
|
||||
_NOQA_PATTERN = re.compile(r"#\s*noqa:?\s*hardcoded-path-ok")
|
||||
|
||||
_PATH_ARG_NAMES = frozenset({
|
||||
"path", "file_path", "filepath", "dir", "directory", "dest", "source",
|
||||
"input", "output", "src", "dst", "target", "location", "file",
|
||||
"image_path", "script", "config", "log_file",
|
||||
})
|
||||
|
||||
|
||||
def has_hardcoded_path(text: str) -> Optional[str]:
|
||||
if _NOQA_PATTERN.search(text):
|
||||
return None
|
||||
for pattern, desc in _COMPILED_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return desc
|
||||
return None
|
||||
|
||||
|
||||
def check_path(path_value: str) -> Optional[str]:
|
||||
if not isinstance(path_value, str):
|
||||
return None
|
||||
match_desc = has_hardcoded_path(path_value)
|
||||
if match_desc:
|
||||
return (
|
||||
f"Path contains hardcoded home directory ({match_desc}): '{path_value}'. "
|
||||
f"Use $HOME, relative paths, or get_hermes_home(). "
|
||||
f"Add '# noqa: hardcoded-path-ok' if intentional."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def validate_tool_args(tool_name: str, args: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
|
||||
warnings = []
|
||||
for key, value in args.items():
|
||||
if key.lower() not in _PATH_ARG_NAMES:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
err = check_path(value)
|
||||
if err:
|
||||
warnings.append(err)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
err = check_path(item)
|
||||
if err:
|
||||
warnings.append(err)
|
||||
return args, warnings
|
||||
|
||||
|
||||
def scan_source_for_violations(source_code: str, filename: str = "") -> List[Tuple[int, str, str]]:
|
||||
violations = []
|
||||
lines = source_code.split("\n")
|
||||
for i, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("#"):
|
||||
if _NOQA_PATTERN.search(line):
|
||||
continue
|
||||
continue
|
||||
if stripped.startswith("import ") or stripped.startswith("from "):
|
||||
continue
|
||||
for pattern, desc in _COMPILED_PATTERNS:
|
||||
match = pattern.search(line)
|
||||
if match:
|
||||
if _NOQA_PATTERN.search(line):
|
||||
continue
|
||||
violations.append((i, line.strip(), desc))
|
||||
break
|
||||
return violations
|
||||
|
||||
|
||||
def guard_tool_dispatch(tool_name: str, args: Dict[str, Any]) -> Optional[str]:
|
||||
_, warnings = validate_tool_args(tool_name, args)
|
||||
if warnings:
|
||||
return _json.dumps({
|
||||
"error": "Hardcoded home directory path detected",
|
||||
"details": warnings,
|
||||
"suggestion": "Use $HOME, relative paths, or get_hermes_home() instead of hardcoded paths.",
|
||||
"pokayoke": True,
|
||||
"rule": "hardcoded-path-guard"
|
||||
})
|
||||
return None
|
||||
298
tools/poka_yoke.py
Normal file
298
tools/poka_yoke.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
poka_yoke.py — Validation firewall for tool calls.
|
||||
|
||||
Poka-yoke (mistake-proofing): validates tool calls against the registry
|
||||
before execution. Catches hallucinated tool names, malformed parameters,
|
||||
missing required arguments, and type mismatches.
|
||||
|
||||
Usage:
|
||||
from tools.poka_yoke import validate_tool_call
|
||||
|
||||
is_valid, corrected_name, corrected_params, messages = validate_tool_call(
|
||||
"read_file", {"path": "test.txt"}
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_tool_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
|
||||
"""Validate a tool call against the registry before execution.
|
||||
|
||||
Args:
|
||||
function_name: The tool name from the LLM's function_call.
|
||||
function_args: The arguments dict from the LLM's function_call.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, corrected_name, corrected_params, messages):
|
||||
- is_valid: False if the call should be blocked entirely.
|
||||
- corrected_name: Suggested name if a close match was found (None if OK).
|
||||
- corrected_params: Corrected params if type coercion fixed issues (None if OK).
|
||||
- messages: List of error/warning/info messages.
|
||||
"""
|
||||
from tools.registry import registry
|
||||
|
||||
messages: List[str] = []
|
||||
corrected_name: Optional[str] = None
|
||||
corrected_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
# ── 1. Check if tool exists ───────────────────────────────────────────
|
||||
|
||||
entry = registry.get_entry(function_name)
|
||||
|
||||
if entry is None:
|
||||
# Tool not found — suggest closest match
|
||||
all_names = registry.get_all_tool_names()
|
||||
suggestion = _find_closest_name(function_name, all_names)
|
||||
|
||||
if suggestion:
|
||||
messages.append(
|
||||
f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
|
||||
)
|
||||
corrected_name = suggestion
|
||||
# Re-validate with corrected name
|
||||
entry = registry.get_entry(suggestion)
|
||||
if entry is None:
|
||||
return False, corrected_name, None, messages
|
||||
else:
|
||||
available = ", ".join(sorted(all_names)[:20])
|
||||
messages.append(
|
||||
f"Unknown tool '{function_name}'. "
|
||||
f"Available tools: {available}{'...' if len(all_names) > 20 else ''}"
|
||||
)
|
||||
return False, None, None, messages
|
||||
|
||||
# ── 2. Validate parameters against schema ─────────────────────────────
|
||||
|
||||
schema = entry.schema
|
||||
params_schema = schema.get("parameters", {})
|
||||
properties = params_schema.get("properties", {})
|
||||
required = set(params_schema.get("required", []))
|
||||
|
||||
# Check for missing required parameters
|
||||
for param_name in sorted(required):
|
||||
if param_name not in function_args:
|
||||
param_info = properties.get(param_name, {})
|
||||
param_type = param_info.get("type", "any")
|
||||
messages.append(
|
||||
f"Missing required parameter '{param_name}' "
|
||||
f"(expected type: {param_type}). "
|
||||
f"Tool: {function_name}"
|
||||
)
|
||||
|
||||
# If required params are missing, we still return the error
|
||||
# (the agent might be able to self-correct)
|
||||
if any("Missing required" in m for m in messages):
|
||||
# Don't block — return the error as a tool result so the agent can retry
|
||||
# But mark as invalid so caller knows
|
||||
return False, corrected_name, corrected_params, messages
|
||||
|
||||
# ── 3. Check for unknown parameters ───────────────────────────────────
|
||||
|
||||
if properties:
|
||||
known_params = set(properties.keys())
|
||||
# Allow extra params that start with _ (internal convention)
|
||||
unknown = [
|
||||
p for p in function_args
|
||||
if p not in known_params and not p.startswith("_")
|
||||
]
|
||||
if unknown:
|
||||
known_str = ", ".join(sorted(known_params))
|
||||
unknown_str = ", ".join(sorted(unknown))
|
||||
messages.append(
|
||||
f"Unknown parameter(s) for '{function_name}': {unknown_str}. "
|
||||
f"Known parameters: {known_str}"
|
||||
)
|
||||
# Remove unknown params (don't block, just clean)
|
||||
corrected_params = {
|
||||
k: v for k, v in function_args.items()
|
||||
if k in known_params or k.startswith("_")
|
||||
}
|
||||
|
||||
# ── 4. Type validation ────────────────────────────────────────────────
|
||||
|
||||
type_errors = []
|
||||
coerced = dict(corrected_params or function_args)
|
||||
|
||||
for param_name, param_value in coerced.items():
|
||||
if param_name.startswith("_"):
|
||||
continue
|
||||
param_schema = properties.get(param_name)
|
||||
if not param_schema:
|
||||
continue
|
||||
|
||||
expected_type = param_schema.get("type")
|
||||
if not expected_type:
|
||||
continue
|
||||
|
||||
is_valid_type, coerced_value = _validate_type(
|
||||
param_name, param_value, expected_type
|
||||
)
|
||||
if not is_valid_type:
|
||||
type_errors.append(
|
||||
f"Parameter '{param_name}': expected {expected_type}, "
|
||||
f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
|
||||
)
|
||||
elif coerced_value is not param_value:
|
||||
coerced[param_name] = coerced_value
|
||||
messages.append(
|
||||
f"Parameter '{param_name}': coerced from "
|
||||
f"{type(param_value).__name__} to {expected_type}"
|
||||
)
|
||||
|
||||
if type_errors:
|
||||
messages.extend(type_errors)
|
||||
return False, corrected_name, corrected_params, messages
|
||||
|
||||
if coerced != (corrected_params or function_args):
|
||||
corrected_params = coerced
|
||||
|
||||
# ── 5. Enum validation ────────────────────────────────────────────────
|
||||
|
||||
for param_name, param_value in (corrected_params or function_args).items():
|
||||
param_schema = properties.get(param_name, {})
|
||||
enum_values = param_schema.get("enum")
|
||||
if enum_values and param_value not in enum_values:
|
||||
messages.append(
|
||||
f"Parameter '{param_name}': value '{param_value}' not in "
|
||||
f"allowed values: {enum_values}"
|
||||
)
|
||||
return False, corrected_name, corrected_params, messages
|
||||
|
||||
# ── 6. Pattern validation ─────────────────────────────────────────────
|
||||
|
||||
for param_name, param_value in (corrected_params or function_args).items():
|
||||
if not isinstance(param_value, str):
|
||||
continue
|
||||
param_schema = properties.get(param_name, {})
|
||||
pattern = param_schema.get("pattern")
|
||||
if pattern and not re.match(pattern, param_value):
|
||||
messages.append(
|
||||
f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
|
||||
f"does not match pattern '{pattern}'"
|
||||
)
|
||||
|
||||
# ── Done ──────────────────────────────────────────────────────────────
|
||||
|
||||
is_valid = not any("Missing required" in m for m in messages)
|
||||
|
||||
if is_valid and not messages:
|
||||
return True, None, None, []
|
||||
|
||||
return is_valid, corrected_name, corrected_params, messages
|
||||
|
||||
|
||||
def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
|
||||
"""Find the closest tool name using simple edit distance heuristics."""
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
target_lower = target.lower()
|
||||
|
||||
# Exact prefix match
|
||||
for name in candidates:
|
||||
if name.lower().startswith(target_lower[:4]) and len(target_lower) > 3:
|
||||
return name
|
||||
|
||||
# Substring match
|
||||
for name in candidates:
|
||||
if target_lower in name.lower() or name.lower() in target_lower:
|
||||
return name
|
||||
|
||||
# Levenshtein distance (simple, for short strings)
|
||||
def _levenshtein(a: str, b: str) -> int:
|
||||
if len(a) < len(b):
|
||||
return _levenshtein(b, a)
|
||||
if len(b) == 0:
|
||||
return len(a)
|
||||
prev = range(len(b) + 1)
|
||||
for i, ca in enumerate(a):
|
||||
curr = [i + 1]
|
||||
for j, cb in enumerate(b):
|
||||
curr.append(min(
|
||||
prev[j + 1] + 1,
|
||||
curr[j] + 1,
|
||||
prev[j] + (0 if ca == cb else 1),
|
||||
))
|
||||
prev = curr
|
||||
return prev[-1]
|
||||
|
||||
distances = [(name, _levenshtein(target_lower, name.lower())) for name in candidates]
|
||||
distances.sort(key=lambda x: x[1])
|
||||
|
||||
# Return if edit distance is small enough
|
||||
if distances and distances[0][1] <= max(3, len(target) // 3):
|
||||
return distances[0][0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_type(
|
||||
param_name: str, value: Any, expected_type: str
|
||||
) -> Tuple[bool, Any]:
|
||||
"""Validate and optionally coerce a parameter value to the expected type.
|
||||
|
||||
Returns (is_valid, coerced_value). coerced_value is value itself if no
|
||||
coercion was needed.
|
||||
"""
|
||||
type_map = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
expected = type_map.get(expected_type)
|
||||
if expected is None:
|
||||
return True, value # Unknown type, skip validation
|
||||
|
||||
# Direct type check
|
||||
if isinstance(value, expected):
|
||||
return True, value
|
||||
|
||||
# Coercion attempts
|
||||
if expected_type == "string":
|
||||
return True, str(value)
|
||||
|
||||
if expected_type == "integer":
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
return True, int(value)
|
||||
if isinstance(value, float) and value == int(value):
|
||||
return True, int(value)
|
||||
return False, value
|
||||
|
||||
if expected_type == "number":
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return True, float(value)
|
||||
except ValueError:
|
||||
return False, value
|
||||
return False, value
|
||||
|
||||
if expected_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
lower = value.lower()
|
||||
if lower in ("true", "1", "yes"):
|
||||
return True, True
|
||||
if lower in ("false", "0", "no"):
|
||||
return True, False
|
||||
return False, value
|
||||
|
||||
return False, value
|
||||
|
||||
|
||||
def _truncate(s: str, max_len: int) -> str:
|
||||
"""Truncate a string for display."""
|
||||
if len(s) <= max_len:
|
||||
return s
|
||||
return s[:max_len - 3] + "..."
|
||||
275
tools/session_templates.py
Normal file
275
tools/session_templates.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Session templates for code-first seeding.
|
||||
|
||||
Research: Code-heavy sessions (execute_code dominant in first 30 turns) improve over time.
|
||||
File-heavy sessions degrade. Key is deterministic feedback loops.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_DIR = Path.home() / ".hermes" / "session-templates"
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
CODE = "code"
|
||||
FILE = "file"
|
||||
RESEARCH = "research"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExample:
|
||||
tool_name: str
|
||||
arguments: Dict[str, Any]
|
||||
result: str
|
||||
success: bool
|
||||
turn: int = 0
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
name: str
|
||||
task_type: TaskType
|
||||
examples: List[ToolExample]
|
||||
desc: str = ""
|
||||
created: float = 0.0
|
||||
used: int = 0
|
||||
session_id: Optional[str] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created == 0.0:
|
||||
self.created = time.time()
|
||||
|
||||
def to_dict(self):
|
||||
d = asdict(self)
|
||||
d['task_type'] = self.task_type.value
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
data['task_type'] = TaskType(data['task_type'])
|
||||
data['examples'] = [ToolExample.from_dict(e) for e in data.get('examples', [])]
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class Templates:
|
||||
def __init__(self, dir=None):
|
||||
self.dir = dir or TEMPLATE_DIR
|
||||
self.dir.mkdir(parents=True, exist_ok=True)
|
||||
self.templates = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
for f in self.dir.glob("*.json"):
|
||||
try:
|
||||
with open(f) as fh:
|
||||
t = Template.from_dict(json.load(fh))
|
||||
self.templates[t.name] = t
|
||||
except Exception as e:
|
||||
logger.warning(f"Load failed {f}: {e}")
|
||||
|
||||
def _save(self, t):
|
||||
with open(self.dir / f"{t.name}.json", 'w') as f:
|
||||
json.dump(t.to_dict(), f, indent=2)
|
||||
|
||||
def classify(self, calls):
|
||||
if not calls:
|
||||
return TaskType.MIXED
|
||||
code = {'execute_code', 'code_execution'}
|
||||
file_ops = {'read_file', 'write_file', 'patch', 'search_files'}
|
||||
research = {'web_search', 'web_fetch', 'browser_navigate'}
|
||||
names = [c.get('tool_name', '') for c in calls]
|
||||
total = len(names)
|
||||
if sum(1 for n in names if n in code) / total > 0.6:
|
||||
return TaskType.CODE
|
||||
if sum(1 for n in names if n in file_ops) / total > 0.6:
|
||||
return TaskType.FILE
|
||||
if sum(1 for n in names if n in research) / total > 0.6:
|
||||
return TaskType.RESEARCH
|
||||
return TaskType.MIXED
|
||||
|
||||
def extract(self, session_id, max_n=10):
|
||||
db = Path.home() / ".hermes" / "state.db"
|
||||
if not db.exists():
|
||||
return []
|
||||
try:
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"SELECT role, content, tool_calls FROM messages WHERE session_id=? ORDER BY timestamp LIMIT 100",
|
||||
(session_id,)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
examples = []
|
||||
turn = 0
|
||||
for r in rows:
|
||||
if len(examples) >= max_n:
|
||||
break
|
||||
if r['role'] == 'assistant' and r['tool_calls']:
|
||||
try:
|
||||
for tc in json.loads(r['tool_calls']):
|
||||
if len(examples) >= max_n:
|
||||
break
|
||||
name = tc.get('function', {}).get('name')
|
||||
if not name:
|
||||
continue
|
||||
try:
|
||||
args = json.loads(tc.get('function', {}).get('arguments', '{}'))
|
||||
except:
|
||||
args = {}
|
||||
examples.append(ToolExample(name, args, "", True, turn))
|
||||
turn += 1
|
||||
except:
|
||||
continue
|
||||
elif r['role'] == 'tool' and examples and examples[-1].result == "":
|
||||
examples[-1].result = r['content'] or ""
|
||||
return examples
|
||||
except Exception as e:
|
||||
logger.error(f"Extract failed: {e}")
|
||||
return []
|
||||
|
||||
def create(self, session_id, name=None, task_type=None, max_n=10, desc="", tags=None):
|
||||
examples = self.extract(session_id, max_n)
|
||||
if not examples:
|
||||
return None
|
||||
if task_type is None:
|
||||
task_type = self.classify([{'tool_name': e.tool_name} for e in examples])
|
||||
if name is None:
|
||||
name = f"{task_type.value}_{session_id[:8]}_{int(time.time())}"
|
||||
t = Template(name, task_type, examples, desc or f"{len(examples)} examples", time.time(), 0, session_id, tags or [])
|
||||
self.templates[name] = t
|
||||
self._save(t)
|
||||
logger.info(f"Created {name} with {len(examples)} examples")
|
||||
return t
|
||||
|
||||
def get(self, task_type, tags=None):
|
||||
matching = [t for t in self.templates.values() if t.task_type == task_type]
|
||||
if tags:
|
||||
matching = [t for t in matching if any(tag in t.tags for tag in tags)]
|
||||
if not matching:
|
||||
return None
|
||||
matching.sort(key=lambda t: t.used)
|
||||
return matching[0]
|
||||
|
||||
def inject(self, template, messages):
|
||||
if not template.examples:
|
||||
return messages
|
||||
injection = [{
|
||||
"role": "system",
|
||||
"content": f"Template: {template.name} ({template.task_type.value})\n{template.desc}"
|
||||
}]
|
||||
for i, ex in enumerate(template.examples):
|
||||
injection.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": f"tpl_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": ex.tool_name, "arguments": json.dumps(ex.arguments)}
|
||||
}]
|
||||
})
|
||||
injection.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": f"tpl_{i}",
|
||||
"content": ex.result
|
||||
})
|
||||
idx = 0
|
||||
for i, m in enumerate(messages):
|
||||
if m.get("role") != "system":
|
||||
break
|
||||
idx = i + 1
|
||||
for i, m in enumerate(injection):
|
||||
messages.insert(idx + i, m)
|
||||
template.used += 1
|
||||
self._save(template)
|
||||
return messages
|
||||
|
||||
def list(self, task_type=None, tags=None):
|
||||
ts = list(self.templates.values())
|
||||
if task_type:
|
||||
ts = [t for t in ts if t.task_type == task_type]
|
||||
if tags:
|
||||
ts = [t for t in ts if any(tag in t.tags for tag in tags)]
|
||||
ts.sort(key=lambda t: t.created, reverse=True)
|
||||
return ts
|
||||
|
||||
def delete(self, name):
|
||||
if name not in self.templates:
|
||||
return False
|
||||
del self.templates[name]
|
||||
p = self.dir / f"{name}.json"
|
||||
if p.exists():
|
||||
p.unlink()
|
||||
return True
|
||||
|
||||
def stats(self):
|
||||
if not self.templates:
|
||||
return {"total": 0, "by_type": {}, "examples": 0, "usage": 0}
|
||||
by_type = {}
|
||||
total_ex = 0
|
||||
total_use = 0
|
||||
for t in self.templates.values():
|
||||
by_type[t.task_type.value] = by_type.get(t.task_type.value, 0) + 1
|
||||
total_ex += len(t.examples)
|
||||
total_use += t.used
|
||||
return {"total": len(self.templates), "by_type": by_type, "examples": total_ex, "usage": total_use}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
p = argparse.ArgumentParser()
|
||||
s = p.add_subparsers(dest="cmd")
|
||||
lp = s.add_parser("list")
|
||||
lp.add_argument("--type", choices=["code", "file", "research", "mixed"])
|
||||
lp.add_argument("--tags")
|
||||
cp = s.add_parser("create")
|
||||
cp.add_argument("session_id")
|
||||
cp.add_argument("--name")
|
||||
cp.add_argument("--type", choices=["code", "file", "research", "mixed"])
|
||||
cp.add_argument("--max", type=int, default=10)
|
||||
cp.add_argument("--desc")
|
||||
cp.add_argument("--tags")
|
||||
dp = s.add_parser("delete")
|
||||
dp.add_argument("name")
|
||||
sp = s.add_parser("stats")
|
||||
args = p.parse_args()
|
||||
ts = Templates()
|
||||
if args.cmd == "list":
|
||||
tt = TaskType(args.type) if args.type else None
|
||||
tags = args.tags.split(",") if args.tags else None
|
||||
for t in ts.list(tt, tags):
|
||||
print(f"{t.name}: {t.task_type.value} ({len(t.examples)} ex, used {t.used}x)")
|
||||
elif args.cmd == "create":
|
||||
tt = TaskType(args.type) if args.type else None
|
||||
tags = args.tags.split(",") if args.tags else None
|
||||
t = ts.create(args.session_id, args.name, tt, args.max, args.desc or "", tags)
|
||||
if t:
|
||||
print(f"Created: {t.name} ({len(t.examples)} examples)")
|
||||
else:
|
||||
print("Failed")
|
||||
elif args.cmd == "delete":
|
||||
print("Deleted" if ts.delete(args.name) else "Not found")
|
||||
elif args.cmd == "stats":
|
||||
s = ts.stats()
|
||||
print(f"Total: {s['total']}, Examples: {s['examples']}, Usage: {s['usage']}")
|
||||
for k, v in s['by_type'].items():
|
||||
print(f" {k}: {v}")
|
||||
else:
|
||||
p.print_help()
|
||||
@@ -38,12 +38,41 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
@@ -253,6 +282,94 @@ def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Pat
|
||||
return target, None
|
||||
|
||||
|
||||
MAX_BACKUPS_PER_FILE = 3
|
||||
|
||||
|
||||
def _backup_skill_file(file_path: Path) -> Optional[Path]:
|
||||
"""Create a timestamped backup of a skill file before modification.
|
||||
|
||||
The backup is named ``{original_name}.bak.{unix_timestamp}`` and placed
|
||||
in the same directory. Returns the backup path, or *None* if the file
|
||||
does not exist yet (nothing to back up).
|
||||
"""
|
||||
if not file_path.exists():
|
||||
return None
|
||||
timestamp = int(time.time())
|
||||
backup_path = file_path.parent / f"{file_path.name}.bak.{timestamp}"
|
||||
shutil.copy2(str(file_path), str(backup_path))
|
||||
return backup_path
|
||||
|
||||
|
||||
def _cleanup_old_backups(file_path: Path, max_backups: int = MAX_BACKUPS_PER_FILE) -> None:
|
||||
"""Prune backup files so at most *max_backups* are retained.
|
||||
|
||||
Backups match the pattern ``{file_path.name}.bak.*`` in the same
|
||||
directory. The oldest (by mtime) are removed first.
|
||||
"""
|
||||
parent = file_path.parent
|
||||
prefix = file_path.name + ".bak."
|
||||
try:
|
||||
backups: List[Path] = sorted(
|
||||
[f for f in parent.iterdir() if f.name.startswith(prefix) and f.is_file()],
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
)
|
||||
except OSError:
|
||||
return
|
||||
while len(backups) > max_backups:
|
||||
try:
|
||||
backups.pop(0).unlink()
|
||||
except OSError:
|
||||
break
|
||||
|
||||
|
||||
def _validate_written_file(file_path: Path, is_skill_md: bool = False) -> Optional[str]:
|
||||
"""Re-read a file from disk and validate it after writing.
|
||||
|
||||
Catches filesystem-level issues (truncation, encoding errors, empty
|
||||
writes) that pre-write validation cannot detect. For SKILL.md files
|
||||
the frontmatter is also re-validated.
|
||||
|
||||
Returns an error message, or *None* if the file looks healthy.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
return f"Failed to read file after write: {exc}"
|
||||
except UnicodeDecodeError as exc:
|
||||
return f"File encoding error after write: {exc}"
|
||||
|
||||
if len(content) == 0:
|
||||
return "File is empty after write (possible truncation)."
|
||||
|
||||
if is_skill_md:
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
return f"Post-write validation failed: {err}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _revert_from_backup(file_path: Path, backup_path: Optional[Path]) -> None:
|
||||
"""Restore *file_path* from *backup_path*.
|
||||
|
||||
If *backup_path* is None or missing the target file is removed so the
|
||||
skill directory is at least not left with corrupted content.
|
||||
"""
|
||||
if backup_path and backup_path.exists():
|
||||
try:
|
||||
shutil.copy2(str(backup_path), str(file_path))
|
||||
except OSError:
|
||||
logger.error(
|
||||
"Failed to restore %s from backup %s", file_path, backup_path, exc_info=True
|
||||
)
|
||||
else:
|
||||
# No backup — remove the partially-written file
|
||||
try:
|
||||
file_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
logger.error("Failed to remove corrupted file %s after failed write", file_path, exc_info=True)
|
||||
|
||||
|
||||
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||
"""
|
||||
Atomically write text content to a file.
|
||||
@@ -358,20 +475,35 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
|
||||
return _format_error(
|
||||
f"Skill '{name}' not found.",
|
||||
skill_name=name,
|
||||
suggestion="Use skills_list() to see available skills.",
|
||||
)
|
||||
|
||||
skill_md = existing["path"] / "SKILL.md"
|
||||
# Back up original content for rollback
|
||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(skill_md)
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Post-write validation: catch filesystem-level failures
|
||||
validate_err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
if validate_err:
|
||||
_revert_from_backup(skill_md, backup_path)
|
||||
return {"success": False, "error": f"Edit reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(skill_md, original_content)
|
||||
_revert_from_backup(skill_md, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(skill_md)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' updated.",
|
||||
@@ -392,13 +524,25 @@ def _patch_skill(
|
||||
Requires a unique match unless replace_all is True.
|
||||
"""
|
||||
if not old_string:
|
||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
||||
return _format_error(
|
||||
"old_string is required for 'patch'.",
|
||||
skill_name=name,
|
||||
suggestion="Provide the exact text to find in the skill file.",
|
||||
)
|
||||
if new_string is None:
|
||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
||||
return _format_error(
|
||||
"new_string is required for 'patch'. Use an empty string to delete matched text.",
|
||||
skill_name=name,
|
||||
suggestion="Pass new_string='' to delete the matched text.",
|
||||
)
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
||||
return _format_error(
|
||||
f"Skill '{name}' not found.",
|
||||
skill_name=name,
|
||||
suggestion="Use skills_list() to see available skills.",
|
||||
)
|
||||
|
||||
skill_dir = existing["path"]
|
||||
|
||||
@@ -452,15 +596,29 @@ def _patch_skill(
|
||||
"error": f"Patch would break SKILL.md structure: {err}",
|
||||
}
|
||||
|
||||
original_content = content # for rollback
|
||||
is_skill_md = not file_path
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(target)
|
||||
_atomic_write_text(target, new_content)
|
||||
|
||||
# Post-write validation
|
||||
validate_err = _validate_written_file(target, is_skill_md=is_skill_md)
|
||||
if validate_err:
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": f"Patch reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
_atomic_write_text(target, original_content)
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(target)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({match_count} replacement{'s' if match_count > 1 else ''}).",
|
||||
@@ -519,19 +677,28 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Back up for rollback
|
||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(target)
|
||||
_atomic_write_text(target, file_content)
|
||||
|
||||
# Post-write validation: ensure the file is readable and non-empty
|
||||
validate_err = _validate_written_file(target, is_skill_md=False)
|
||||
if validate_err:
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": f"Write reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(target, original_content)
|
||||
else:
|
||||
target.unlink(missing_ok=True)
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(target)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"File '{file_path}' written to skill '{name}'.",
|
||||
|
||||
312
tools/tool_validator.py
Normal file
312
tools/tool_validator.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Poka-Yoke: Tool Hallucination Detection — #922.
|
||||
|
||||
Validation firewall between LLM tool-call output and actual execution.
|
||||
|
||||
Detects and blocks:
|
||||
1. Unknown tool names (hallucinated tools)
|
||||
2. Malformed parameters (wrong types)
|
||||
3. Missing required arguments
|
||||
4. Extra unknown parameters
|
||||
|
||||
Poka-Yoke Type: Detection (catches errors at the boundary before harm)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationSeverity(Enum):
|
||||
"""Severity of validation failure."""
|
||||
BLOCK = "block" # Must block execution
|
||||
WARN = "warn" # Warning, may proceed
|
||||
INFO = "info" # Informational
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationIssue:
|
||||
"""A validation issue found."""
|
||||
severity: ValidationSeverity
|
||||
code: str
|
||||
message: str
|
||||
tool_name: str
|
||||
parameter: Optional[str] = None
|
||||
expected: Optional[str] = None
|
||||
actual: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of tool call validation."""
|
||||
valid: bool
|
||||
tool_name: str
|
||||
issues: List[ValidationIssue] = field(default_factory=list)
|
||||
corrected_args: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def blocking_issues(self) -> List[ValidationIssue]:
|
||||
return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]
|
||||
|
||||
@property
|
||||
def warnings(self) -> List[ValidationIssue]:
|
||||
return [i for i in self.issues if i.severity == ValidationSeverity.WARN]
|
||||
|
||||
|
||||
class ToolHallucinationDetector:
|
||||
"""
|
||||
Poka-yoke detector for tool hallucinations.
|
||||
|
||||
Validates tool calls against registered schemas before execution.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_registry: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize detector.
|
||||
|
||||
Args:
|
||||
tool_registry: Dict of tool_name -> tool_schema
|
||||
"""
|
||||
self.registry = tool_registry or {}
|
||||
self._rejection_log: List[Dict] = []
|
||||
|
||||
def register_tool(self, name: str, schema: Dict):
|
||||
"""Register a tool with its JSON Schema."""
|
||||
self.registry[name] = schema
|
||||
|
||||
def register_tools(self, tools: Dict[str, Dict]):
|
||||
"""Register multiple tools."""
|
||||
self.registry.update(tools)
|
||||
|
||||
def validate_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
model: str = "unknown",
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate a tool call against the registry.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being called
|
||||
arguments: Arguments passed to the tool
|
||||
model: Model that generated the call (for logging)
|
||||
|
||||
Returns:
|
||||
ValidationResult with validation status
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 1. Check if tool exists
|
||||
if tool_name not in self.registry:
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="UNKNOWN_TOOL",
|
||||
message=f"Tool '{tool_name}' does not exist. Available: {', '.join(sorted(self.registry.keys())[:10])}...",
|
||||
tool_name=tool_name,
|
||||
)
|
||||
issues.append(issue)
|
||||
self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
|
||||
return ValidationResult(valid=False, tool_name=tool_name, issues=issues)
|
||||
|
||||
schema = self.registry[tool_name]
|
||||
params_schema = schema.get("parameters", {}).get("properties", {})
|
||||
required = set(schema.get("parameters", {}).get("required", []))
|
||||
|
||||
# 2. Check for missing required parameters
|
||||
for param in required:
|
||||
if param not in arguments:
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="MISSING_REQUIRED",
|
||||
message=f"Missing required parameter: {param}",
|
||||
tool_name=tool_name,
|
||||
parameter=param,
|
||||
)
|
||||
issues.append(issue)
|
||||
|
||||
# 3. Check parameter types
|
||||
for param_name, param_value in arguments.items():
|
||||
if param_name not in params_schema:
|
||||
# Unknown parameter
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="UNKNOWN_PARAM",
|
||||
message=f"Unknown parameter: {param_name}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
)
|
||||
issues.append(issue)
|
||||
continue
|
||||
|
||||
param_schema = params_schema[param_name]
|
||||
expected_type = param_schema.get("type")
|
||||
|
||||
if expected_type and not self._check_type(param_value, expected_type):
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="WRONG_TYPE",
|
||||
message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
expected=expected_type,
|
||||
actual=type(param_value).__name__,
|
||||
)
|
||||
issues.append(issue)
|
||||
|
||||
# 4. Check for common hallucination patterns
|
||||
hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
|
||||
issues.extend(hallucination_issues)
|
||||
|
||||
# Determine validity
|
||||
has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)
|
||||
|
||||
if has_blocking:
|
||||
self._log_rejection(tool_name, arguments, model,
|
||||
"; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))
|
||||
|
||||
return ValidationResult(
|
||||
valid=not has_blocking,
|
||||
tool_name=tool_name,
|
||||
issues=issues,
|
||||
)
|
||||
|
||||
def _check_type(self, value: Any, expected_type: str) -> bool:
|
||||
"""Check if value matches expected JSON Schema type."""
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": (int, float),
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
expected = type_map.get(expected_type)
|
||||
if expected is None:
|
||||
return True # Unknown type, assume OK
|
||||
|
||||
return isinstance(value, expected)
|
||||
|
||||
def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
|
||||
"""Detect common hallucination patterns."""
|
||||
issues = []
|
||||
|
||||
# Pattern 1: Placeholder values
|
||||
placeholder_patterns = [
|
||||
r"^<.*>$", # <placeholder>
|
||||
r"^\[.*\]$", # [placeholder]
|
||||
r"^TODO$|^FIXME$", # TODO/FIXME
|
||||
r"^example\.com$", # example.com
|
||||
r"^127\.0\.0\.1$", # localhost
|
||||
]
|
||||
|
||||
for param_name, param_value in arguments.items():
|
||||
if isinstance(param_value, str):
|
||||
for pattern in placeholder_patterns:
|
||||
if re.match(pattern, param_value, re.IGNORECASE):
|
||||
issues.append(ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="PLACEHOLDER_VALUE",
|
||||
message=f"Parameter '{param_name}' contains placeholder: {param_value}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
))
|
||||
|
||||
# Pattern 2: Suspiciously long strings (might be hallucinated content)
|
||||
for param_name, param_value in arguments.items():
|
||||
if isinstance(param_value, str) and len(param_value) > 10000:
|
||||
issues.append(ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="SUSPICIOUS_LENGTH",
|
||||
message=f"Parameter '{param_name}' is unusually long ({len(param_value)} chars)",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
))
|
||||
|
||||
return issues
|
||||
|
||||
def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
|
||||
"""Log a rejected tool call for analysis."""
|
||||
import time
|
||||
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"tool_name": tool_name,
|
||||
"arguments": {k: str(v)[:100] for k, v in arguments.items()},
|
||||
"model": model,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
self._rejection_log.append(entry)
|
||||
|
||||
# Keep log bounded
|
||||
if len(self._rejection_log) > 1000:
|
||||
self._rejection_log = self._rejection_log[-500:]
|
||||
|
||||
logger.warning(
|
||||
"Tool hallucination blocked: tool=%s, model=%s, reason=%s",
|
||||
tool_name, model, reason
|
||||
)
|
||||
|
||||
def get_rejection_stats(self) -> Dict:
|
||||
"""Get statistics on rejected tool calls."""
|
||||
if not self._rejection_log:
|
||||
return {"total": 0, "by_reason": {}, "by_tool": {}}
|
||||
|
||||
by_reason = {}
|
||||
by_tool = {}
|
||||
|
||||
for entry in self._rejection_log:
|
||||
reason = entry["reason"]
|
||||
tool = entry["tool_name"]
|
||||
|
||||
by_reason[reason] = by_reason.get(reason, 0) + 1
|
||||
by_tool[tool] = by_tool.get(tool, 0) + 1
|
||||
|
||||
return {
|
||||
"total": len(self._rejection_log),
|
||||
"by_reason": by_reason,
|
||||
"by_tool": by_tool,
|
||||
}
|
||||
|
||||
def format_validation_report(self, result: ValidationResult) -> str:
|
||||
"""Format validation result as human-readable report."""
|
||||
if result.valid:
|
||||
return f"✅ {result.tool_name}: valid"
|
||||
|
||||
lines = [f"❌ {result.tool_name}: BLOCKED"]
|
||||
for issue in result.blocking_issues:
|
||||
lines.append(f" [{issue.code}] {issue.message}")
|
||||
|
||||
for issue in result.warnings:
|
||||
lines.append(f" ⚠️ [{issue.code}] {issue.message}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def create_rejection_response(result: ValidationResult) -> Dict:
|
||||
"""
|
||||
Create a tool result for a rejected tool call.
|
||||
|
||||
This allows the agent to see the rejection and self-correct.
|
||||
"""
|
||||
issues_text = "\n".join(
|
||||
f"- [{i.code}] {i.message}"
|
||||
for i in result.blocking_issues
|
||||
)
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f"""Tool call rejected: {result.tool_name}
|
||||
|
||||
Issues found:
|
||||
{issues_text}
|
||||
|
||||
Please check the tool name and parameters, then try again with valid arguments.""",
|
||||
}
|
||||
Reference in New Issue
Block a user