Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
5eef3fed1a feat: warm session provisioning — pre-proficient agent sessions (#327)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m10s
Marathon sessions (100+ msgs) have lower per-tool error rates than
mid-length sessions. This implements warm session provisioning to
pre-seed new sessions with successful tool-call patterns.

agent/warm_session.py:
  - WarmSessionTemplate dataclass with ToolCallExample entries
  - extract_successful_patterns() mines SessionDB for marathon sessions
  - build_warm_conversation() converts templates into conversation_history
  - save/load/list templates persisted to ~/.hermes/warm_sessions/

tools/warm_session_tool.py:
  - warm_session tool with build/list/load/delete actions
  - Registered in the skills toolset

Usage:
  Agent calls warm_session(action='build', name='general') to mine patterns
  from existing marathon sessions. Then new sessions can start with the
  warm conversation_history injected via run_conversation().

Integration:
  No changes to run_agent.py needed — the existing conversation_history
  parameter already handles this. The warm tool builds the history,
  caller injects it.

21 tests added, all passing.

Closes #327
2026-04-13 18:48:37 -04:00
6 changed files with 775 additions and 985 deletions

333
agent/warm_session.py Normal file
View File

@@ -0,0 +1,333 @@
"""Warm Session Provisioning — pre-proficient agent sessions.
Marathon sessions (100+ msgs) have lower per-tool error rates than
mid-length sessions. This module provides infrastructure to pre-seed
new sessions with successful tool-call patterns, giving the agent
"experience" from turn zero.
Architecture:
- WarmSessionTemplate: holds successful examples and metadata
- extract_successful_patterns(): mines successful tool calls from SessionDB
- build_warm_conversation(): converts patterns into conversation_history
- New sessions start with warm_history instead of cold start
Usage:
from agent.warm_session import (
WarmSessionTemplate,
extract_successful_patterns,
build_warm_conversation,
save_template,
load_template,
list_templates,
)
"""
import json
import logging
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
from hermes_constants import get_hermes_home
logger = logging.getLogger(__name__)
TEMPLATES_DIR = get_hermes_home() / "warm_sessions"
@dataclass
class ToolCallExample:
    """A single successful tool call + result pair mined from a past session."""
    tool_name: str  # name of the tool that was called
    arguments: Dict[str, Any]  # parsed (JSON-decoded) call arguments
    result_summary: str  # truncated result for context efficiency
    result_success: bool  # whether the call succeeded (the miner records True)
    context_hint: str = ""  # optional: what task this example illustrates
@dataclass
class WarmSessionTemplate:
    """A template for pre-seeding proficient sessions.

    Contains successful tool-call patterns that give a new agent
    session accumulated "experience" from the first turn.
    """
    name: str
    description: str
    examples: List[ToolCallExample] = field(default_factory=list)
    system_prompt_addendum: str = ""  # extra system prompt context
    tags: List[str] = field(default_factory=list)
    source_session_ids: List[str] = field(default_factory=list)
    created_at: float = 0
    version: int = 1

    def __post_init__(self):
        # Stamp creation time lazily so deserialized templates keep theirs.
        self.created_at = self.created_at or time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the template (including nested examples) to plain dicts."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WarmSessionTemplate":
        """Rebuild a template from a plain dict (e.g. parsed JSON)."""
        raw_examples = data.get("examples", [])
        # JSON round-trips hand us plain dicts; tolerate pre-built instances too.
        parsed_examples = [
            ex if not isinstance(ex, dict) else ToolCallExample(**ex)
            for ex in raw_examples
        ]
        return cls(
            name=data["name"],
            description=data.get("description", ""),
            examples=parsed_examples,
            system_prompt_addendum=data.get("system_prompt_addendum", ""),
            tags=data.get("tags", []),
            source_session_ids=data.get("source_session_ids", []),
            created_at=data.get("created_at", 0),
            version=data.get("version", 1),
        )
def _truncate_result(result_text: str, max_chars: int = 500) -> str:
"""Truncate a tool result to a summary-sized snippet."""
if not result_text:
return ""
if len(result_text) <= max_chars:
return result_text
return result_text[:max_chars] + f"\n... ({len(result_text)} chars total, truncated)"
def extract_successful_patterns(
    session_db,
    min_messages: int = 20,
    max_sessions: int = 50,
    source_filter: Optional[str] = None,
    max_examples: int = 100,
) -> "List[ToolCallExample]":
    """Mine successful tool-call patterns from completed sessions.

    Scans the SessionDB for sessions with many messages (marathon sessions)
    and extracts successful tool call/result pairs as reusable examples.

    Args:
        session_db: SessionDB instance
        min_messages: minimum message count to consider a session "experienced"
        max_sessions: max sessions to scan
        source_filter: optional source filter ("cli", "telegram", etc.)
        max_examples: stop mining as soon as this many examples are collected

    Returns:
        List of ToolCallExample instances from successful sessions.
    """
    examples: List[ToolCallExample] = []
    try:
        sessions = session_db.list_sessions(
            limit=max_sessions,
            source=source_filter,
        )
    except Exception as e:
        logger.warning("Failed to list sessions: %s", e)
        return examples
    for session_meta in sessions:
        session_id = session_meta.get("id") or session_meta.get("session_id")
        if not session_id:
            continue
        msg_count = session_meta.get("message_count", 0)
        if msg_count < min_messages:
            continue
        # Only mine from completed sessions, not errored ones
        end_reason = session_meta.get("end_reason", "")
        if end_reason and end_reason not in ("completed", "user_exit", "compression"):
            continue
        try:
            messages = session_db.get_messages(session_id)
        except Exception:
            continue  # best-effort: skip sessions whose messages can't be loaded
        # Extract successful tool call/result pairs
        for msg in messages:
            if msg.get("role", "") != "assistant":
                continue
            tool_calls_raw = msg.get("tool_calls")
            if not tool_calls_raw:
                continue
            try:
                tool_calls = (
                    json.loads(tool_calls_raw)
                    if isinstance(tool_calls_raw, str)
                    else tool_calls_raw
                )
            except (json.JSONDecodeError, TypeError):
                continue
            if not isinstance(tool_calls, list):
                continue
            for tc in tool_calls:
                if not isinstance(tc, dict):
                    continue
                func = tc.get("function", {})
                tool_name = func.get("name", "")
                if not tool_name:
                    continue
                try:
                    arguments = json.loads(func.get("arguments", "{}"))
                except (json.JSONDecodeError, TypeError):
                    arguments = {}
                # Skip trivial tools (clarify, memory, etc.)
                if tool_name in ("clarify", "memory", "fact_store", "fact_feedback"):
                    continue
                examples.append(ToolCallExample(
                    tool_name=tool_name,
                    arguments=arguments,
                    result_summary="[result from successful session]",  # filled in by caller
                    result_success=True,
                ))
                # BUGFIX: the original `break` only exited this innermost
                # tool-call loop, so mining kept going over the remaining
                # messages and sessions and the cap was never enforced.
                if len(examples) >= max_examples:
                    return examples
    return examples
def build_warm_conversation(
    template: WarmSessionTemplate,
    max_examples: int = 20,
) -> List[Dict[str, Any]]:
    """Convert a template into conversation_history messages.

    Produces a synthetic conversation in which the "user" requests tasks
    and the "assistant" answers with successful tool calls, priming the
    agent with proven patterns.

    Args:
        template: WarmSessionTemplate with examples
        max_examples: max examples to include (token budget)

    Returns:
        List of OpenAI-format message dicts suitable for conversation_history.
    """
    history: List[Dict[str, Any]] = []

    addendum = template.system_prompt_addendum
    if addendum:
        preamble = (
            f"[WARM SESSION CONTEXT] The following successful tool-call patterns "
            f"are from experienced sessions. Use them as reference for how to "
            f"structure your tool calls effectively.\n\n"
            f"{addendum}"
        )
        history.append({"role": "system", "content": preamble})

    for idx, example in enumerate(template.examples[:max_examples]):
        # Synthetic user turn describing the intent.
        if example.context_hint:
            intent = f"[Warm pattern {idx+1}] {example.context_hint}"
        else:
            intent = f"[Warm pattern {idx+1}] Use the {example.tool_name} tool."
        history.append({"role": "user", "content": intent})

        # Assistant turn carrying the successful tool call.
        call_id = f"warm_{idx}_{example.tool_name}"
        history.append({
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": call_id,
                "type": "function",
                "function": {
                    "name": example.tool_name,
                    "arguments": json.dumps(example.arguments, ensure_ascii=False),
                },
            }],
        })

        # Matching tool result (synthetic success).
        result_text = example.result_summary or f"Tool {example.tool_name} executed successfully."
        history.append({
            "role": "tool",
            "tool_call_id": call_id,
            "content": result_text,
        })

    return history
def save_template(template: WarmSessionTemplate) -> Path:
    """Persist a warm session template as JSON under TEMPLATES_DIR.

    Returns the path written ( `<name>.json` ); overwrites any existing file.
    """
    TEMPLATES_DIR.mkdir(parents=True, exist_ok=True)
    target = TEMPLATES_DIR / f"{template.name}.json"
    payload = json.dumps(template.to_dict(), indent=2, ensure_ascii=False)
    target.write_text(payload)
    logger.info("Warm session template saved: %s", target)
    return target
def load_template(name: str) -> Optional[WarmSessionTemplate]:
    """Load a warm session template by name.

    Returns None when the file is missing or cannot be parsed.
    """
    source = TEMPLATES_DIR / f"{name}.json"
    if not source.exists():
        return None
    try:
        return WarmSessionTemplate.from_dict(json.loads(source.read_text()))
    except Exception as e:
        # Corrupt or incompatible file: warn and behave as "not found".
        logger.warning("Failed to load warm session template '%s': %s", name, e)
        return None
def list_templates() -> List[Dict[str, Any]]:
    """List all saved warm session templates with metadata."""
    if not TEMPLATES_DIR.exists():
        return []
    summaries: List[Dict[str, Any]] = []
    for json_path in sorted(TEMPLATES_DIR.glob("*.json")):
        try:
            raw = json.loads(json_path.read_text())
            summary = {
                "name": raw.get("name", json_path.stem),
                "description": raw.get("description", ""),
                "tags": raw.get("tags", []),
                "example_count": len(raw.get("examples", [])),
                "created_at": raw.get("created_at", 0),
            }
        except Exception:
            continue  # skip unreadable or malformed template files
        summaries.append(summary)
    return summaries
def build_from_session_db(
    session_db,
    name: str,
    description: str = "",
    min_messages: int = 20,
    max_sessions: int = 20,
    source_filter: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> WarmSessionTemplate:
    """Build and save a warm session template from existing sessions.

    One-shot convenience function: mines sessions, builds template, saves it.

    Args:
        session_db: SessionDB instance to mine.
        name: template name (also the on-disk filename stem).
        description: human-readable description; auto-generated when empty.
        min_messages: minimum message count for a session to be mined.
        max_sessions: maximum number of sessions to scan.
        source_filter: optional session source filter ("cli", "telegram", ...).
        tags: optional labels stored on the template.

    Returns:
        The built WarmSessionTemplate; persisted to disk only when at
        least one example was mined.
    """
    examples = extract_successful_patterns(
        session_db,
        min_messages=min_messages,
        max_sessions=max_sessions,
        source_filter=source_filter,
    )
    template = WarmSessionTemplate(
        name=name,
        description=description or f"Auto-generated from {max_sessions} sessions",
        examples=examples,
        tags=tags or [],
    )
    # Skip saving when nothing was mined — the template would be empty.
    if examples:
        save_template(template)
    return template

View File

@@ -1,5 +0,0 @@
"""Hermes daemon services — long-running background processes."""
from daemon.confirmation_server import ConfirmationServer
__all__ = ["ConfirmationServer"]

View File

@@ -1,664 +0,0 @@
"""Human Confirmation Daemon — route high-risk actions through human review.
HTTP server on port 6000 that intercepts dangerous operations and holds them
until a human approves or denies. Integrates with the existing approval
system (tools/approval.py) and notifies humans via Telegram/Discord/CLI.
Endpoints:
POST /confirm — submit a high-risk action for review
POST /confirm/{id}/approve — approve a pending confirmation
POST /confirm/{id}/deny — deny a pending confirmation
GET /confirm/{id} — check status of a confirmation
GET /audit — recent audit log entries
GET /health — liveness probe
Every decision is logged to SQLite for audit.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import threading
import time
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
try:
from aiohttp import web
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
web = None # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 6000
# Actions that always require human confirmation (not bypassable)
HIGH_RISK_ACTIONS = {
"deploy_production",
"delete_data",
"transfer_funds",
"modify_permissions",
"shutdown_service",
"wipe_database",
"exec_remote",
"publish_package",
"rotate_keys",
"migrate_database",
}
# Default rate limits: max N confirmations per action type per window
DEFAULT_RATE_LIMIT = 10 # max confirmations per action type
RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds
# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------
@dataclass
class ConfirmationRequest:
    """A single pending or resolved confirmation."""
    id: str
    action: str
    description: str
    details: Dict[str, Any] = field(default_factory=dict)
    requester: str = ""  # agent or user who requested
    session_key: str = ""
    status: str = "pending"  # pending | approved | denied | expired
    resolved_by: str = ""
    resolved_at: Optional[float] = None
    created_at: float = field(default_factory=time.time)
    timeout_seconds: int = 300  # 5 min default

    def to_dict(self) -> dict:
        """Serialize, adding ISO-8601 renderings of both timestamps."""
        payload = asdict(self)
        payload["created_at_iso"] = _ts_to_iso(payload["created_at"])
        resolved = payload["resolved_at"]
        payload["resolved_at_iso"] = _ts_to_iso(resolved) if resolved else None
        return payload
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
# ---------------------------------------------------------------------------
# Audit log (SQLite)
# ---------------------------------------------------------------------------
class AuditLog:
    """SQLite-backed audit log for all confirmation decisions."""

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating if needed) the audit database.

        Defaults to confirmation_audit.db under HERMES_HOME (~/.hermes).
        """
        if db_path is None:
            home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
            home.mkdir(parents=True, exist_ok=True)
            db_path = str(home / "confirmation_audit.db")
        # check_same_thread=False: the server logs from multiple threads;
        # WAL keeps readers from blocking the writer.
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("""
            CREATE TABLE IF NOT EXISTS audit_log (
                id TEXT PRIMARY KEY,
                action TEXT NOT NULL,
                description TEXT NOT NULL,
                details TEXT NOT NULL DEFAULT '{}',
                requester TEXT NOT NULL DEFAULT '',
                session_key TEXT NOT NULL DEFAULT '',
                status TEXT NOT NULL,
                resolved_by TEXT NOT NULL DEFAULT '',
                created_at REAL NOT NULL,
                resolved_at REAL,
                resolved_at_iso TEXT,
                created_at_iso TEXT
            )
        """)
        self._conn.commit()

    def log(self, req: "ConfirmationRequest") -> None:
        """Insert or replace the audit row for *req* and commit."""
        d = req.to_dict()
        self._conn.execute(
            """INSERT OR REPLACE INTO audit_log
               (id, action, description, details, requester, session_key,
                status, resolved_by, created_at, resolved_at,
                resolved_at_iso, created_at_iso)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                d["id"], d["action"], d["description"],
                json.dumps(d["details"]), d["requester"], d["session_key"],
                d["status"], d["resolved_by"],
                d["created_at"], d["resolved_at"],
                d["resolved_at_iso"], d["created_at_iso"],
            ),
        )
        self._conn.commit()

    def recent(self, limit: int = 50) -> List[dict]:
        """Return the newest *limit* audit entries as dicts (newest first)."""
        cursor = self._conn.execute(
            "SELECT * FROM audit_log ORDER BY created_at DESC LIMIT ?", (limit,)
        )
        rows = cursor.fetchall()
        # BUGFIX: column metadata lives on the Cursor, not the Connection —
        # the original read `self._conn.description`, which raises
        # AttributeError on every call (sqlite3.Connection has no such attr).
        cols = [d[0] for d in cursor.description]
        return [dict(zip(cols, row)) for row in rows]

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()
# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
class RateLimiter:
    """Simple sliding-window rate limiter per action type."""

    def __init__(self, max_per_window: int = DEFAULT_RATE_LIMIT,
                 window: int = RATE_LIMIT_WINDOW):
        self._max = max_per_window
        self._window = window
        # action name -> timestamps of attempts inside the current window
        self._timestamps: Dict[str, List[float]] = {}
        self._lock = threading.Lock()

    def check(self, action: str) -> bool:
        """Record an attempt; True while *action* is within its limit."""
        now = time.time()
        with self._lock:
            live = [t for t in self._timestamps.get(action, []) if now - t < self._window]
            self._timestamps[action] = live
            if len(live) >= self._max:
                return False
            live.append(now)
            return True

    def remaining(self, action: str) -> int:
        """How many more attempts *action* has left in the current window."""
        now = time.time()
        with self._lock:
            live = [t for t in self._timestamps.get(action, []) if now - t < self._window]
            return max(0, self._max - len(live))
# ---------------------------------------------------------------------------
# Notification dispatcher
# ---------------------------------------------------------------------------
async def _notify_human(request: ConfirmationRequest) -> None:
    """Send a notification about a pending confirmation to humans.

    Tries Telegram first, then Discord, then falls back to a log warning.
    Uses the existing send_message infrastructure.
    """
    msg = (
        f"\U0001f514 Confirmation Required\n"
        f"Action: {request.action}\n"
        f"Description: {request.description}\n"
        f"Requester: {request.requester}\n"
        f"ID: {request.id}\n\n"
        f"Approve: POST /confirm/{request.id}/approve\n"
        f"Deny: POST /confirm/{request.id}/deny"
    )
    # Channels in preference order; stop at the first successful delivery.
    for channel in ("telegram", "discord"):
        try:
            from tools.send_message_tool import _handle_send
            outcome = _handle_send({
                "target": channel,
                "message": msg,
            })
            parsed = json.loads(outcome) if isinstance(outcome, str) else outcome
            if "error" not in parsed:
                logger.info("Confirmation %s: notified via %s",
                            request.id, channel.capitalize())
                return
        except Exception as e:
            logger.debug("%s notify failed: %s", channel.capitalize(), e)
    # No channel worked — leave a loud breadcrumb for operators.
    logger.warning(
        "Confirmation %s: no messaging channel available. "
        "Action '%s' requires human review -- check /confirm/%s",
        request.id, request.action, request.id,
    )
# ---------------------------------------------------------------------------
# Whitelist manager
# ---------------------------------------------------------------------------
class Whitelist:
    """Configurable whitelist for actions that skip confirmation."""

    def __init__(self):
        # Per-session grants: session_key -> {action, ...}
        self._allowed: Dict[str, set] = {}
        # Grants that apply to every session.
        self._global_allowed: set = set()
        self._lock = threading.Lock()

    def is_whitelisted(self, action: str, session_key: str = "") -> bool:
        """True if *action* may skip confirmation, globally or for the session."""
        with self._lock:
            if action in self._global_allowed:
                return True
            return bool(session_key) and action in self._allowed.get(session_key, set())

    def add(self, action: str, session_key: str = "") -> None:
        """Grant *action*; scoped to *session_key* when given, global otherwise."""
        with self._lock:
            if not session_key:
                self._global_allowed.add(action)
            else:
                self._allowed.setdefault(session_key, set()).add(action)

    def remove(self, action: str, session_key: str = "") -> None:
        """Revoke a previous grant; unknown grants are ignored."""
        with self._lock:
            if not session_key:
                self._global_allowed.discard(action)
            else:
                self._allowed.get(session_key, set()).discard(action)
# ---------------------------------------------------------------------------
# Confirmation Server
# ---------------------------------------------------------------------------
class ConfirmationServer:
    """HTTP server for human confirmation of high-risk actions.

    Holds submitted confirmations in memory until a human approves or
    denies them (or they time out), logging every decision to SQLite.

    Usage:
        server = ConfirmationServer()
        await server.start()  # blocks
        # or:
        server.start_background()  # non-blocking
        ...
        await server.stop()
    """

    def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT,
                 db_path: Optional[str] = None,
                 rate_limit: int = DEFAULT_RATE_LIMIT):
        # Fail fast when the optional aiohttp dependency is missing.
        if not AIOHTTP_AVAILABLE:
            raise RuntimeError(
                "aiohttp is required for the confirmation daemon. "
                "Install with: pip install aiohttp"
            )
        self._host = host
        self._port = port
        self._audit = AuditLog(db_path)  # persistent decision log
        self._rate_limiter = RateLimiter(max_per_window=rate_limit)
        self._whitelist = Whitelist()
        # All known requests (pending AND resolved), keyed by id.
        # NOTE(review): entries are never evicted in this class, so the dict
        # grows for the process lifetime — confirm that is acceptable.
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._lock = threading.Lock()
        self._app: Optional[web.Application] = None
        self._runner: Optional[web.AppRunner] = None
        self._bg_thread: Optional[threading.Thread] = None

    # --- Lifecycle ---
    def _build_app(self) -> web.Application:
        """Create the aiohttp application and wire up all routes."""
        app = web.Application(client_max_size=1_048_576)  # 1 MB request-body cap
        app["server"] = self
        app.router.add_post("/confirm", self._handle_submit)
        app.router.add_post("/confirm/{req_id}/approve", self._handle_approve)
        app.router.add_post("/confirm/{req_id}/deny", self._handle_deny)
        app.router.add_get("/confirm/{req_id}", self._handle_status)
        app.router.add_get("/audit", self._handle_audit)
        app.router.add_get("/health", self._handle_health)
        return app

    async def start(self) -> None:
        """Start the server and block until stopped."""
        self._app = self._build_app()
        self._runner = web.AppRunner(self._app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self._host, self._port)
        await site.start()
        logger.info(
            "Confirmation daemon listening on http://%s:%s",
            self._host, self._port,
        )
        # Run until cancelled
        try:
            await asyncio.Event().wait()  # never set; only cancellation wakes us
        except asyncio.CancelledError:
            pass
        finally:
            await self.stop()

    def start_background(self) -> None:
        """Start the server in a background thread (non-blocking)."""
        def _run():
            # Daemon thread gets its own event loop via asyncio.run.
            asyncio.run(self.start())
        self._bg_thread = threading.Thread(target=_run, daemon=True)
        self._bg_thread.start()

    async def stop(self) -> None:
        """Gracefully stop the server and close the audit log."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
        self._audit.close()
        logger.info("Confirmation daemon stopped")

    # --- Internal helpers ---
    def _get_request(self, req_id: str) -> Optional[ConfirmationRequest]:
        """Thread-safe lookup of a request by id (None when unknown)."""
        with self._lock:
            return self._pending.get(req_id)

    def _expire_old_requests(self) -> int:
        """Mark expired requests. Returns count expired.

        NOTE(review): not invoked anywhere in this class — _handle_status
        performs its own inline expiration. Presumably intended for a
        periodic sweeper; confirm before removing.
        """
        now = time.time()
        expired = 0
        with self._lock:
            for req in list(self._pending.values()):
                if req.status == "pending" and (now - req.created_at) > req.timeout_seconds:
                    req.status = "expired"
                    req.resolved_at = now
                    req.resolved_by = "system:timeout"
                    self._audit.log(req)
                    expired += 1
        return expired

    # --- HTTP handlers ---
    async def _handle_submit(self, request: web.Request) -> web.Response:
        """POST /confirm -- submit a new confirmation request.

        Returns 202 with the pending request, 400 on bad input, 429 when
        the per-action rate limit is exhausted.
        """
        try:
            body = await request.json()
        except Exception:
            return web.json_response(
                {"error": "Invalid JSON body"}, status=400
            )
        action = (body.get("action") or "").strip()
        description = (body.get("description") or "").strip()
        details = body.get("details") or {}
        requester = (body.get("requester") or "agent").strip()
        session_key = (body.get("session_key") or "").strip()
        timeout = body.get("timeout_seconds", 300)
        if not action:
            return web.json_response(
                {"error": "Field 'action' is required"}, status=400
            )
        if not description:
            return web.json_response(
                {"error": "Field 'description' is required"}, status=400
            )
        # Whitelist check
        # NOTE(review): auto-approved actions return before the audit log
        # below — confirm whitelisted actions really need no audit row.
        if self._whitelist.is_whitelisted(action, session_key):
            auto_id = str(uuid.uuid4())[:8]
            return web.json_response({
                "id": auto_id,
                "action": action,
                "status": "auto_approved",
                "message": f"Action '{action}' is whitelisted for this session",
            })
        # Rate limit check
        if not self._rate_limiter.check(action):
            remaining = self._rate_limiter.remaining(action)
            return web.json_response({
                "error": f"Rate limit exceeded for action '{action}'",
                "remaining": remaining,
                "window_seconds": RATE_LIMIT_WINDOW,
            }, status=429)
        # Enforce timeout bounds (clamped to 30s..1h; bad values -> 5 min)
        try:
            timeout = max(30, min(int(timeout), 3600))
        except (ValueError, TypeError):
            timeout = 300
        # Create request
        req = ConfirmationRequest(
            id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            details=details,
            requester=requester,
            session_key=session_key,
            timeout_seconds=timeout,
        )
        with self._lock:
            self._pending[req.id] = req
        # Audit log
        self._audit.log(req)
        # Notify humans (fire-and-forget)
        asyncio.create_task(_notify_human(req))
        logger.info(
            "Confirmation %s submitted: action=%s requester=%s",
            req.id, action, requester,
        )
        return web.json_response({
            "id": req.id,
            "action": req.action,
            "status": req.status,
            "timeout_seconds": req.timeout_seconds,
            "message": "Confirmation pending. Awaiting human review.",
            "approve_url": f"/confirm/{req.id}/approve",
            "deny_url": f"/confirm/{req.id}/deny",
        }, status=202)

    async def _handle_approve(self, request: web.Request) -> web.Response:
        """POST /confirm/{id}/approve -- approve a pending confirmation.

        404 when unknown; 409 when already resolved.
        """
        req_id = request.match_info["req_id"]
        req = self._get_request(req_id)
        if req is None:
            return web.json_response(
                {"error": f"Confirmation '{req_id}' not found"}, status=404
            )
        if req.status != "pending":
            return web.json_response({
                "error": f"Confirmation '{req_id}' already resolved",
                "status": req.status,
            }, status=409)
        # Parse optional approver identity (a body is not required)
        try:
            body = await request.json()
            approver = (body.get("approver") or "api").strip()
        except Exception:
            approver = "api"
        req.status = "approved"
        req.resolved_by = approver
        req.resolved_at = time.time()
        # Audit log
        self._audit.log(req)
        logger.info(
            "Confirmation %s APPROVED by %s (action=%s)",
            req_id, approver, req.action,
        )
        return web.json_response({
            "id": req.id,
            "action": req.action,
            "status": "approved",
            "resolved_by": approver,
            "resolved_at": req.resolved_at,
        })

    async def _handle_deny(self, request: web.Request) -> web.Response:
        """POST /confirm/{id}/deny -- deny a pending confirmation.

        404 when unknown; 409 when already resolved. Accepts optional
        'denier' and 'reason' fields in the JSON body.
        """
        req_id = request.match_info["req_id"]
        req = self._get_request(req_id)
        if req is None:
            return web.json_response(
                {"error": f"Confirmation '{req_id}' not found"}, status=404
            )
        if req.status != "pending":
            return web.json_response({
                "error": f"Confirmation '{req_id}' already resolved",
                "status": req.status,
            }, status=409)
        try:
            body = await request.json()
            denier = (body.get("denier") or "api").strip()
            reason = (body.get("reason") or "").strip()
        except Exception:
            denier = "api"
            reason = ""
        req.status = "denied"
        req.resolved_by = denier
        req.resolved_at = time.time()
        # Audit log
        self._audit.log(req)
        logger.info(
            "Confirmation %s DENIED by %s (action=%s, reason=%s)",
            req_id, denier, req.action, reason,
        )
        resp = {
            "id": req.id,
            "action": req.action,
            "status": "denied",
            "resolved_by": denier,
            "resolved_at": req.resolved_at,
        }
        if reason:
            resp["reason"] = reason
        return web.json_response(resp)

    async def _handle_status(self, request: web.Request) -> web.Response:
        """GET /confirm/{id} -- check status of a confirmation."""
        req_id = request.match_info["req_id"]
        req = self._get_request(req_id)
        if req is None:
            return web.json_response(
                {"error": f"Confirmation '{req_id}' not found"}, status=404
            )
        # Check for expiration lazily, on read
        if req.status == "pending":
            now = time.time()
            if (now - req.created_at) > req.timeout_seconds:
                req.status = "expired"
                req.resolved_at = now
                req.resolved_by = "system:timeout"
                self._audit.log(req)
        return web.json_response(req.to_dict())

    async def _handle_audit(self, request: web.Request) -> web.Response:
        """GET /audit -- recent audit log entries (?limit=, clamped to 1..500)."""
        try:
            limit = int(request.query.get("limit", "50"))
            limit = max(1, min(limit, 500))
        except (ValueError, TypeError):
            limit = 50
        entries = self._audit.recent(limit)
        return web.json_response({
            "count": len(entries),
            "entries": entries,
        })

    async def _handle_health(self, request: web.Request) -> web.Response:
        """GET /health -- liveness probe."""
        return web.json_response({
            "status": "ok",
            "pending_count": len(self._pending),
            "timestamp": time.time(),
        })
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
    """Run the confirmation daemon as a standalone process."""
    import argparse

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    # CLI flags fall back to environment variables, then to module defaults.
    parser = argparse.ArgumentParser(description="Hermes Human Confirmation Daemon")
    parser.add_argument("--host",
                        default=os.getenv("CONFIRMATION_HOST", DEFAULT_HOST),
                        help="Bind address (default: 127.0.0.1)")
    parser.add_argument("--port", type=int,
                        default=int(os.getenv("CONFIRMATION_PORT", DEFAULT_PORT)),
                        help="Bind port (default: 6000)")
    parser.add_argument("--db-path", default=None,
                        help="SQLite database path (default: ~/.hermes/confirmation_audit.db)")
    parser.add_argument("--rate-limit", type=int,
                        default=int(os.getenv("CONFIRMATION_RATE_LIMIT", DEFAULT_RATE_LIMIT)),
                        help="Max confirmations per action per hour (default: 10)")
    args = parser.parse_args()

    if not AIOHTTP_AVAILABLE:
        print("ERROR: aiohttp is required. Install with: pip install aiohttp")
        raise SystemExit(1)

    daemon = ConfirmationServer(
        host=args.host,
        port=args.port,
        db_path=args.db_path,
        rate_limit=args.rate_limit,
    )
    print(f"Starting Confirmation Daemon on http://{args.host}:{args.port}")
    asyncio.run(daemon.start())

View File

@@ -0,0 +1,264 @@
"""Tests for warm session provisioning (#327)."""
import json
import time
from unittest.mock import MagicMock, patch
import pytest
from agent.warm_session import (
WarmSessionTemplate,
ToolCallExample,
build_warm_conversation,
save_template,
load_template,
list_templates,
extract_successful_patterns,
_truncate_result,
)
@pytest.fixture()
def isolated_templates_dir(tmp_path, monkeypatch):
    """Point TEMPLATES_DIR at a temp directory.

    Keeps persistence tests from touching the real warm_sessions directory.
    """
    tdir = tmp_path / "warm_sessions"
    tdir.mkdir()
    monkeypatch.setattr("agent.warm_session.TEMPLATES_DIR", tdir)
    return tdir
@pytest.fixture()
def sample_template():
    """A sample warm session template with a few examples.

    Three successful tool calls (terminal, read_file, search_files), each
    with a context hint; tagged "test"/"general".
    """
    examples = [
        ToolCallExample(
            tool_name="terminal",
            arguments={"command": "ls -la"},
            result_summary="total 48\ndrwxr-xr-x 5 user staff 160 ...",
            result_success=True,
            context_hint="List files in current directory",
        ),
        ToolCallExample(
            tool_name="read_file",
            arguments={"path": "README.md"},
            result_summary="# Project\n\nThis is the README.",
            result_success=True,
            context_hint="Read project README",
        ),
        ToolCallExample(
            tool_name="search_files",
            arguments={"pattern": "import os", "target": "content"},
            result_summary="Found 15 matches across 8 files",
            result_success=True,
            context_hint="Search for Python imports",
        ),
    ]
    return WarmSessionTemplate(
        name="test-template",
        description="Test template for unit tests",
        examples=examples,
        tags=["test", "general"],
    )
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
class TestToolCallExample:
    """Construction and defaults of ToolCallExample."""

    def test_creation(self):
        example = ToolCallExample(
            tool_name="terminal",
            arguments={"command": "echo hello"},
            result_summary="hello",
            result_success=True,
        )
        assert example.tool_name == "terminal"
        assert example.arguments == {"command": "echo hello"}
        assert example.result_success is True

    def test_defaults(self):
        # context_hint is the only optional field and defaults to empty.
        example = ToolCallExample(
            tool_name="read_file",
            arguments={},
            result_summary="",
            result_success=True,
        )
        assert example.context_hint == ""
class TestWarmSessionTemplate:
    """Construction and dict round-tripping of WarmSessionTemplate."""

    def test_creation(self, sample_template):
        assert sample_template.name == "test-template"
        assert len(sample_template.examples) == 3
        # __post_init__ stamps created_at when it was left at 0.
        assert sample_template.created_at > 0

    def test_round_trip_dict(self, sample_template):
        data = sample_template.to_dict()
        restored = WarmSessionTemplate.from_dict(data)
        assert restored.name == sample_template.name
        assert len(restored.examples) == len(sample_template.examples)
        assert restored.examples[0].tool_name == "terminal"

    def test_from_dict_with_plain_dicts(self):
        # from_dict must accept examples as plain dicts (the JSON on-disk
        # form), not only ToolCallExample instances.
        data = {
            "name": "plain",
            "description": "from dict",
            "examples": [
                {
                    "tool_name": "web_search",
                    "arguments": {"query": "test"},
                    "result_summary": "results found",
                    "result_success": True,
                    "context_hint": "",
                }
            ],
        }
        template = WarmSessionTemplate.from_dict(data)
        assert len(template.examples) == 1
        assert template.examples[0].tool_name == "web_search"
# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------
class TestTruncateResult:
    """Behavior of the _truncate_result helper."""

    def test_short_unchanged(self):
        assert _truncate_result("short text") == "short text"

    def test_long_truncated(self):
        oversized = "x" * 1000
        truncated = _truncate_result(oversized, max_chars=100)
        assert len(truncated) < 200  # 100 chars + truncation suffix
        assert "truncated" in truncated

    def test_empty(self):
        assert _truncate_result("") == ""
        assert _truncate_result(None) == ""
# ---------------------------------------------------------------------------
# Build conversation
# ---------------------------------------------------------------------------
class TestBuildWarmConversation:
    """build_warm_conversation() expands a template into chat messages."""

    def test_basic_conversation(self, sample_template):
        """Every example expands to a user/assistant/tool message triple."""
        messages = build_warm_conversation(sample_template)
        # Each example produces: user + assistant(tool_calls) + tool(result) = 3 messages
        assert len(messages) == 3 * 3  # 3 examples * 3 messages each

    def test_message_roles_alternate(self, sample_template):
        """Roles repeat in strict user -> assistant -> tool order."""
        observed = [msg["role"] for msg in build_warm_conversation(sample_template)]
        assert observed == ["user", "assistant", "tool"] * 3

    def test_tool_calls_have_ids(self, sample_template):
        """Synthetic tool calls carry warm_-prefixed ids and known tool names."""
        for msg in build_warm_conversation(sample_template):
            if msg["role"] != "assistant":
                continue
            call = msg["tool_calls"][0]
            assert call["id"].startswith("warm_")
            assert call["function"]["name"] in ("terminal", "read_file", "search_files")

    def test_tool_results_reference_ids(self, sample_template):
        """Each tool message points back at its assistant message's tool_call id."""
        messages = build_warm_conversation(sample_template)
        calls = [m for m in messages if m["role"] == "assistant"]
        results = [m for m in messages if m["role"] == "tool"]
        for call_msg, result_msg in zip(calls, results):
            assert result_msg["tool_call_id"] == call_msg["tool_calls"][0]["id"]

    def test_max_examples_limit(self, sample_template):
        """max_examples caps how many example triples are emitted."""
        assert len(build_warm_conversation(sample_template, max_examples=1)) == 3  # 1 example * 3 messages

    def test_system_prompt_addendum(self, sample_template):
        """A system_prompt_addendum is prepended as a system message."""
        sample_template.system_prompt_addendum = "Use Python 3.12+"
        messages = build_warm_conversation(sample_template)
        assert messages[0]["role"] == "system"
        assert "Python 3.12+" in messages[0]["content"]
# ---------------------------------------------------------------------------
# Save / Load / List
# ---------------------------------------------------------------------------
class TestTemplatePersistence:
    """Saving, loading, and listing templates on disk."""

    def test_save_and_load(self, isolated_templates_dir, sample_template):
        """A saved template can be loaded back intact by name."""
        save_template(sample_template)
        round_tripped = load_template("test-template")
        assert round_tripped is not None
        assert round_tripped.name == "test-template"
        assert len(round_tripped.examples) == 3

    def test_load_nonexistent(self, isolated_templates_dir):
        """Loading an unknown name yields None rather than raising."""
        assert load_template("does-not-exist") is None

    def test_list_templates(self, isolated_templates_dir, sample_template):
        """list_templates() reports name and example count per saved template."""
        save_template(sample_template)
        listing = list_templates()
        assert len(listing) == 1
        assert listing[0]["name"] == "test-template"
        assert listing[0]["example_count"] == 3

    def test_list_empty(self, isolated_templates_dir):
        """An empty templates directory lists as an empty collection."""
        assert list_templates() == []
# ---------------------------------------------------------------------------
# Extract patterns (mocked SessionDB)
# ---------------------------------------------------------------------------
class TestExtractPatterns:
    """extract_successful_patterns() mining against a mocked SessionDB."""

    def test_extracts_from_marathon_sessions(self):
        """Only sessions at or above min_messages contribute examples."""
        db = MagicMock()
        db.list_sessions.return_value = [
            {"id": "s1", "message_count": 50, "end_reason": "completed"},
            {"id": "s2", "message_count": 10, "end_reason": "completed"},  # too short
        ]
        terminal_call = {
            "id": "tc1",
            "type": "function",
            "function": {"name": "terminal", "arguments": json.dumps({"command": "pwd"})},
        }
        db.get_messages.return_value = [
            {"role": "assistant", "content": None, "tool_calls": json.dumps([terminal_call])},
        ]
        mined = extract_successful_patterns(db, min_messages=20)
        # Only s1 (50 msgs) qualifies, s2 (10 msgs) is skipped
        assert len(mined) == 1
        assert mined[0].tool_name == "terminal"

    def test_skips_trivial_tools(self):
        """Trivial tools (e.g. clarify) are excluded from mining."""
        db = MagicMock()
        db.list_sessions.return_value = [
            {"id": "s1", "message_count": 50, "end_reason": "completed"},
        ]
        clarify_call = {
            "id": "tc1",
            "type": "function",
            "function": {"name": "clarify", "arguments": "{}"},
        }
        db.get_messages.return_value = [
            {"role": "assistant", "content": None, "tool_calls": json.dumps([clarify_call])},
        ]
        mined = extract_successful_patterns(db)
        assert len(mined) == 0  # clarify is trivial, skipped

    def test_skips_errored_sessions(self):
        """Sessions that ended in error contribute nothing."""
        db = MagicMock()
        db.list_sessions.return_value = [
            {"id": "s1", "message_count": 50, "end_reason": "error"},
        ]
        mined = extract_successful_patterns(db)
        assert len(mined) == 0  # errored session, skipped

View File

@@ -1,316 +0,0 @@
"""Tests for the Human Confirmation Daemon."""
import asyncio
import json
import time
from unittest.mock import patch, MagicMock
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# We import after the fixtures to avoid aiohttp import issues in test envs
try:
from aiohttp import web
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
if AIOHTTP_AVAILABLE:
from daemon.confirmation_server import (
ConfirmationServer,
ConfirmationRequest,
AuditLog,
RateLimiter,
Whitelist,
HIGH_RISK_ACTIONS,
DEFAULT_RATE_LIMIT,
RATE_LIMIT_WINDOW,
)
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestRateLimiter:
    """Unit tests for the RateLimiter."""

    def test_allows_within_limit(self):
        """Requests up to the window maximum are all permitted."""
        limiter = RateLimiter(max_per_window=3, window=60)
        for _ in range(3):
            assert limiter.check("deploy") is True

    def test_blocks_over_limit(self):
        """The first request past the maximum is rejected."""
        limiter = RateLimiter(max_per_window=2, window=60)
        assert limiter.check("deploy") is True
        assert limiter.check("deploy") is True
        assert limiter.check("deploy") is False

    def test_remaining_count(self):
        """remaining() decrements as checks consume the budget."""
        limiter = RateLimiter(max_per_window=5, window=60)
        assert limiter.remaining("deploy") == 5
        limiter.check("deploy")
        assert limiter.remaining("deploy") == 4

    def test_separate_actions_independent(self):
        """Exhausting one action's budget leaves other actions unaffected."""
        limiter = RateLimiter(max_per_window=2, window=60)
        assert limiter.check("deploy") is True
        assert limiter.check("deploy") is True
        assert limiter.check("deploy") is False
        assert limiter.check("shutdown") is True  # different action
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestWhitelist:
    """Unit tests for the Whitelist."""

    def test_global_whitelist(self):
        """add() without a session scope whitelists the action globally."""
        whitelist = Whitelist()
        assert whitelist.is_whitelisted("deploy") is False
        whitelist.add("deploy")
        assert whitelist.is_whitelisted("deploy") is True

    def test_session_scoped_whitelist(self):
        """A session-scoped entry applies only to that session."""
        whitelist = Whitelist()
        assert whitelist.is_whitelisted("deploy", "session1") is False
        whitelist.add("deploy", "session1")
        assert whitelist.is_whitelisted("deploy", "session1") is True
        assert whitelist.is_whitelisted("deploy", "session2") is False

    def test_remove(self):
        """remove() revokes a previously whitelisted action."""
        whitelist = Whitelist()
        whitelist.add("deploy")
        assert whitelist.is_whitelisted("deploy") is True
        whitelist.remove("deploy")
        assert whitelist.is_whitelisted("deploy") is False
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestAuditLog:
    """Unit tests for the AuditLog."""

    def test_log_and_retrieve(self, tmp_path):
        """A logged request appears in recent() with pending status."""
        audit = AuditLog(db_path=str(tmp_path / "test_audit.db"))
        request = ConfirmationRequest(
            id="test-123",
            action="deploy_production",
            description="Deploy v2.0 to prod",
            requester="timmy",
        )
        audit.log(request)
        rows = audit.recent(limit=10)
        assert len(rows) == 1
        assert rows[0]["id"] == "test-123"
        assert rows[0]["action"] == "deploy_production"
        assert rows[0]["status"] == "pending"
        audit.close()

    def test_update_on_resolve(self, tmp_path):
        """Re-logging a resolved request updates (not duplicates) its row."""
        audit = AuditLog(db_path=str(tmp_path / "test_audit.db"))
        request = ConfirmationRequest(
            id="test-456",
            action="delete_data",
            description="Purge old records",
        )
        audit.log(request)
        # Resolve
        request.status = "approved"
        request.resolved_by = "alexander"
        request.resolved_at = time.time()
        audit.log(request)
        rows = audit.recent(limit=10)
        assert len(rows) == 1
        assert rows[0]["status"] == "approved"
        assert rows[0]["resolved_by"] == "alexander"
        audit.close()
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationRequest:
    """Unit tests for the data model."""

    def test_to_dict(self):
        """to_dict() serialises ids, status, ISO timestamps, and details."""
        request = ConfirmationRequest(
            id="abc123",
            action="deploy_production",
            description="Ship it",
            details={"version": "2.0"},
        )
        serialised = request.to_dict()
        assert serialised["id"] == "abc123"
        assert serialised["status"] == "pending"
        assert serialised["created_at_iso"] is not None
        assert serialised["resolved_at_iso"] is None
        assert serialised["details"]["version"] == "2.0"
# ---------------------------------------------------------------------------
# Integration tests (HTTP)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationHTTP(AioHTTPTestCase):
    """Full HTTP integration tests for the ConfirmationServer."""

    async def get_application(self):
        """Build the server app under test with human notifications patched out."""
        # Suppress notification during tests
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            server = ConfirmationServer(
                host="127.0.0.1",
                port=6000,
                db_path=":memory:",
            )
            self._server = server
            return server._build_app()

    @unittest_run_loop
    async def test_health(self):
        """GET /health reports status ok."""
        resp = await self.client.request("GET", "/health")
        assert resp.status == 200
        data = await resp.json()
        assert data["status"] == "ok"

    @unittest_run_loop
    async def test_submit_confirmation(self):
        """POST /confirm accepts a well-formed request as pending (202)."""
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            resp = await self.client.request("POST", "/confirm", json={
                "action": "deploy_production",
                "description": "Deploy v2.0 to production",
                "requester": "timmy",
                "session_key": "test-session",
            })
        assert resp.status == 202
        data = await resp.json()
        assert data["status"] == "pending"
        assert data["action"] == "deploy_production"
        assert "id" in data

    @unittest_run_loop
    async def test_submit_missing_action(self):
        """A confirm request without an action is rejected with 400."""
        resp = await self.client.request("POST", "/confirm", json={
            "description": "Something",
        })
        assert resp.status == 400

    @unittest_run_loop
    async def test_submit_missing_description(self):
        """A confirm request without a description is rejected with 400."""
        resp = await self.client.request("POST", "/confirm", json={
            "action": "deploy_production",
        })
        assert resp.status == 400

    @unittest_run_loop
    async def test_approve_flow(self):
        """Submit -> approve -> status lookup reflects the approval."""
        # Submit
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            submit_resp = await self.client.request("POST", "/confirm", json={
                "action": "deploy_production",
                "description": "Ship it",
            })
        assert submit_resp.status == 202
        submit_data = await submit_resp.json()
        req_id = submit_data["id"]
        # Approve
        approve_resp = await self.client.request(
            "POST", f"/confirm/{req_id}/approve",
            json={"approver": "alexander"},
        )
        assert approve_resp.status == 200
        approve_data = await approve_resp.json()
        assert approve_data["status"] == "approved"
        assert approve_data["resolved_by"] == "alexander"
        # Check status
        status_resp = await self.client.request("GET", f"/confirm/{req_id}")
        assert status_resp.status == 200
        status_data = await status_resp.json()
        assert status_data["status"] == "approved"

    @unittest_run_loop
    async def test_deny_flow(self):
        """Submit -> deny records the denied status and the reason."""
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            submit_resp = await self.client.request("POST", "/confirm", json={
                "action": "delete_data",
                "description": "Wipe everything",
            })
        req_id = (await submit_resp.json())["id"]
        deny_resp = await self.client.request(
            "POST", f"/confirm/{req_id}/deny",
            json={"denier": "alexander", "reason": "Too risky"},
        )
        assert deny_resp.status == 200
        deny_data = await deny_resp.json()
        assert deny_data["status"] == "denied"
        assert deny_data["reason"] == "Too risky"

    @unittest_run_loop
    async def test_double_approve_returns_409(self):
        """Approving an already-resolved request is a 409 conflict."""
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            submit_resp = await self.client.request("POST", "/confirm", json={
                "action": "deploy_production",
                "description": "Ship it",
            })
        req_id = (await submit_resp.json())["id"]
        # Fixed: the HTTP method was written as f"POST" (an f-string with no
        # placeholders, ruff F541); a plain string literal is intended.
        await self.client.request("POST", f"/confirm/{req_id}/approve")
        resp2 = await self.client.request("POST", f"/confirm/{req_id}/approve")
        assert resp2.status == 409

    @unittest_run_loop
    async def test_not_found(self):
        """Status lookup for an unknown request id is a 404."""
        resp = await self.client.request("GET", "/confirm/nonexistent")
        assert resp.status == 404

    @unittest_run_loop
    async def test_audit_log(self):
        """Submitted requests show up in the audit endpoint's count."""
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            await self.client.request("POST", "/confirm", json={
                "action": "deploy_production",
                "description": "Ship it",
            })
        resp = await self.client.request("GET", "/audit")
        assert resp.status == 200
        data = await resp.json()
        assert data["count"] >= 1

    @unittest_run_loop
    async def test_rate_limit(self):
        """The request after the per-action limit is refused with 429."""
        # Exhaust rate limit (default is 10)
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            for i in range(10):
                await self.client.request("POST", "/confirm", json={
                    "action": "test_rate_action",
                    "description": f"Request {i}",
                })
            # 11th should be rate-limited
            resp = await self.client.request("POST", "/confirm", json={
                "action": "test_rate_action",
                "description": "Over the limit",
            })
        assert resp.status == 429

    @unittest_run_loop
    async def test_whitelist_auto_approves(self):
        """Whitelisted actions skip the pending state and auto-approve."""
        self._server._whitelist.add("safe_action")
        resp = await self.client.request("POST", "/confirm", json={
            "action": "safe_action",
            "description": "This is whitelisted",
        })
        assert resp.status == 200
        data = await resp.json()
        assert data["status"] == "auto_approved"

178
tools/warm_session_tool.py Normal file
View File

@@ -0,0 +1,178 @@
"""Warm Session Tool — manage pre-proficient agent sessions.
Allows the agent to build, save, list, and load warm session templates
that pre-seed new sessions with successful tool-call patterns.
"""
import json
import logging
from typing import Optional
from tools.registry import registry
logger = logging.getLogger(__name__)
def warm_session(
    action: str,
    name: Optional[str] = None,
    description: str = "",
    min_messages: int = 20,
    max_sessions: int = 20,
    source_filter: Optional[str] = None,
    tags: Optional[list] = None,
) -> str:
    """Manage warm session templates for pre-proficient agent sessions.

    Actions:
        build  — mine existing sessions and create a template
        list   — show saved templates
        load   — return a template's conversation_history for injection
        delete — remove a template

    Args:
        action: One of 'build', 'list', 'load', 'delete'.
        name: Template name; required for build/load/delete.
        description: Human-readable description stored with a built template.
        min_messages: Minimum message count for a session to count as
            "experienced" when building (default 20).
        max_sessions: Cap on how many sessions are scanned when building.
        source_filter: Restrict mining to sessions from one source
            (cli, telegram, discord, ...). None scans all sources.
        tags: Organisational tags stored on the built template; None
            is normalised to an empty list.

    Returns:
        A JSON string with a "success" key; failures carry an "error" message.
    """
    # Imported lazily so merely registering the tool stays cheap and so the
    # tool module can load even if agent.warm_session has optional deps.
    # NOTE: save_template was previously imported here but never used.
    from agent.warm_session import (
        TEMPLATES_DIR,
        build_from_session_db,
        build_warm_conversation,
        list_templates,
        load_template,
    )
    if action == "list":
        templates = list_templates()
        return json.dumps({
            "success": True,
            "templates": templates,
            "count": len(templates),
        })
    if action == "build":
        if not name:
            return json.dumps({"success": False, "error": "name is required for 'build'."})
        try:
            # SessionDB lives in the host application; opening it can fail
            # when running outside a full Hermes environment.
            from hermes_state import SessionDB
            db = SessionDB()
        except Exception as e:
            return json.dumps({"success": False, "error": f"Cannot open session DB: {e}"})
        template = build_from_session_db(
            db,
            name=name,
            description=description,
            min_messages=min_messages,
            max_sessions=max_sessions,
            source_filter=source_filter,
            tags=tags or [],
        )
        return json.dumps({
            "success": True,
            "name": template.name,
            "example_count": len(template.examples),
            "description": template.description,
        })
    if action == "load":
        if not name:
            return json.dumps({"success": False, "error": "name is required for 'load'."})
        template = load_template(name)
        if not template:
            return json.dumps({"success": False, "error": f"Template '{name}' not found."})
        conversation = build_warm_conversation(template)
        return json.dumps({
            "success": True,
            "name": template.name,
            "message_count": len(conversation),
            # Preview only — callers rebuild the full history for injection.
            "conversation_preview": [
                {"role": m["role"], "content_preview": str(m.get("content", ""))[:100]}
                for m in conversation[:6]
            ],
        })
    if action == "delete":
        if not name:
            return json.dumps({"success": False, "error": "name is required for 'delete'."})
        path = TEMPLATES_DIR / f"{name}.json"
        if not path.exists():
            return json.dumps({"success": False, "error": f"Template '{name}' not found."})
        path.unlink()
        return json.dumps({"success": True, "message": f"Template '{name}' deleted."})
    return json.dumps({
        "success": False,
        "error": f"Unknown action '{action}'. Use: build, list, load, delete",
    })
# OpenAI-style function schema advertised to the model for the warm_session
# tool. Keep the parameter descriptions and defaults documented here in sync
# with the warm_session() signature above.
WARM_SESSION_SCHEMA = {
    "name": "warm_session",
    "description": (
        "Manage warm session templates for pre-proficient agent sessions. "
        "Marathon sessions have lower error rates than mid-length ones because "
        "agents accumulate successful patterns. Warm templates capture those "
        "patterns and pre-seed new sessions with experience.\n\n"
        "Actions:\n"
        " build — mine existing sessions for successful tool-call patterns, save as template\n"
        " list — show saved templates\n"
        " load — retrieve a template's conversation history for session injection\n"
        " delete — remove a template"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["build", "list", "load", "delete"],
                "description": "The action to perform.",
            },
            "name": {
                "type": "string",
                "description": "Template name. Required for build/load/delete.",
            },
            "description": {
                "type": "string",
                "description": "Description for the template. Used with 'build'.",
            },
            "min_messages": {
                "type": "integer",
                "description": "Minimum message count to consider a session experienced (default: 20).",
            },
            "max_sessions": {
                "type": "integer",
                "description": "Maximum sessions to scan when building (default: 20).",
            },
            "source_filter": {
                "type": "string",
                "description": "Filter sessions by source (cli, telegram, discord, etc.).",
            },
            "tags": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Tags for organizing templates.",
            },
        },
        # Only 'action' is mandatory at the schema level; per-action
        # requirements (e.g. 'name' for build/load/delete) are enforced
        # by warm_session() itself.
        "required": ["action"],
    },
}
# Expose warm_session in the "skills" toolset. The lambda adapts the
# registry's (args, **kw) handler convention to warm_session()'s keyword
# signature; absent optional args fall back to the same defaults as the
# function itself (tags=None is normalised inside warm_session).
registry.register(
    name="warm_session",
    toolset="skills",
    schema=WARM_SESSION_SCHEMA,
    handler=lambda args, **kw: warm_session(
        action=args.get("action", ""),
        name=args.get("name"),
        description=args.get("description", ""),
        min_messages=args.get("min_messages", 20),
        max_sessions=args.get("max_sessions", 20),
        source_filter=args.get("source_filter"),
        tags=args.get("tags"),
    ),
    emoji="🔥",
)